refactor: code cleanup

This commit is contained in:
Marvin Zhang
2024-10-18 15:03:32 +08:00
parent 354327ffb1
commit 1b852fb96a
171 changed files with 131 additions and 12522 deletions

View File

@@ -1,40 +0,0 @@
package handler
import (
"github.com/crawlab-team/crawlab/core/interfaces"
"time"
)
// Option is a functional option that is applied to an
// interfaces.TaskHandlerService to configure it.
type Option func(svc interfaces.TaskHandlerService)
// WithConfigPath returns an Option that, when applied, passes path to the
// service's SetConfigPath method.
func WithConfigPath(path string) Option {
	setConfigPath := func(svc interfaces.TaskHandlerService) {
		svc.SetConfigPath(path)
	}
	return setConfigPath
}
// WithExitWatchDuration returns an Option that, when applied, passes duration
// to the service's SetExitWatchDuration method.
func WithExitWatchDuration(duration time.Duration) Option {
	setExitWatchDuration := func(svc interfaces.TaskHandlerService) {
		svc.SetExitWatchDuration(duration)
	}
	return setExitWatchDuration
}
// WithReportInterval returns an Option that, when applied, passes interval to
// the service's SetReportInterval method.
func WithReportInterval(interval time.Duration) Option {
	setReportInterval := func(svc interfaces.TaskHandlerService) {
		svc.SetReportInterval(interval)
	}
	return setReportInterval
}
// WithCancelTimeout returns an Option that, when applied, passes timeout to
// the service's SetCancelTimeout method.
func WithCancelTimeout(timeout time.Duration) Option {
	setCancelTimeout := func(svc interfaces.TaskHandlerService) {
		svc.SetCancelTimeout(timeout)
	}
	return setCancelTimeout
}
// RunnerOption is a functional option that is applied to an
// interfaces.TaskRunner to configure it.
type RunnerOption func(r interfaces.TaskRunner)
// WithSubscribeTimeout returns a RunnerOption that, when applied, passes
// timeout to the runner's SetSubscribeTimeout method.
func WithSubscribeTimeout(timeout time.Duration) RunnerOption {
	setSubscribeTimeout := func(r interfaces.TaskRunner) {
		r.SetSubscribeTimeout(timeout)
	}
	return setSubscribeTimeout
}

View File

@@ -1,103 +0,0 @@
package handler
import (
"github.com/crawlab-team/crawlab/core/models/models"
"github.com/google/uuid"
"github.com/spf13/viper"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)
// MockRunner is a test double that embeds both testify's mock.Mock and the
// package's Runner, so it exposes Runner's fields and methods alongside the
// mock expectation API.
//
// NOTE(review): Go embedding does not provide virtual dispatch. Methods
// promoted from Runner (e.g. syncFiles) presumably call Runner's own
// downloadFile rather than the override defined on MockRunner — confirm
// whether the mocked downloadFile is ever reached.
type MockRunner struct {
// Records calls and holds the expectations registered via On(...).
mock.Mock
// The runner type under test; its members are promoted onto MockRunner.
Runner
}
// downloadFile records the invocation (url and filePath) with the embedded
// mock and returns the error configured as the first return value of the
// matching expectation.
func (m *MockRunner) downloadFile(url string, filePath string) error {
	return m.Called(url, filePath).Error(0)
}
// newMockRunner builds a MockRunner whose spider field (s, promoted from the
// embedded Runner) points at an empty models.Spider.
func newMockRunner() *MockRunner {
	runner := new(MockRunner)
	// s is a promoted field, so it must be assigned after construction
	// rather than inside a composite literal.
	runner.s = &models.Spider{}
	return runner
}
// TestSyncFiles_SuccessWithDummyFiles checks that syncFiles returns no error
// when the stub server advertises two files on "/scan" and serves content on
// "/download", and that both files exist in the spider's workspace directory
// afterwards.
func TestSyncFiles_SuccessWithDummyFiles(t *testing.T) {
	// Create a test server that responds with a list of files
	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		switch {
		case strings.HasSuffix(r.URL.Path, "/scan"):
			// t.Errorf (not Fatalf) because the handler runs on a server goroutine.
			if _, err := w.Write([]byte(`{"file1.txt":{"path": "file1.txt", "hash": "hash1"}, "file2.txt":{"path": "file2.txt", "hash": "hash2"}}`)); err != nil {
				t.Errorf("writing scan response: %v", err)
			}
		case strings.HasSuffix(r.URL.Path, "/download"):
			if _, err := w.Write([]byte("file content")); err != nil {
				t.Errorf("writing download response: %v", err)
			}
		}
	}))
	defer ts.Close()

	// Create a mock runner
	r := newMockRunner()
	r.On("downloadFile", mock.Anything, mock.Anything).Return(nil)

	// Set the master URL to the test server URL
	viper.Set("api.endpoint", ts.URL)

	// Prepare an isolated workspace; register cleanup before creating anything
	// so a partially created tree is still removed on failure.
	localPath := filepath.Join(os.TempDir(), uuid.New().String())
	defer os.RemoveAll(localPath)
	spiderDir := filepath.Join(localPath, r.s.GetId().Hex())
	if err := os.MkdirAll(spiderDir, os.ModePerm); err != nil {
		t.Fatalf("creating workspace directory %q: %v", spiderDir, err)
	}
	viper.Set("workspace", localPath)

	// Call the method under test
	err := r.syncFiles()
	assert.NoError(t, err)

	// Assert that the files were downloaded
	assert.FileExists(t, filepath.Join(spiderDir, "file1.txt"))
	assert.FileExists(t, filepath.Join(spiderDir, "file2.txt"))
}
// TestSyncFiles_DeletesFilesNotOnMaster checks that syncFiles returns no error
// when the stub server reports an empty file list on "/scan", and that a file
// present only in the local workspace no longer exists afterwards.
func TestSyncFiles_DeletesFilesNotOnMaster(t *testing.T) {
	// Create a test server that responds with an empty list of files
	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if strings.HasSuffix(r.URL.Path, "/scan") {
			// t.Errorf (not Fatalf) because the handler runs on a server goroutine.
			if _, err := w.Write([]byte(`{}`)); err != nil {
				t.Errorf("writing scan response: %v", err)
			}
			return
		}
	}))
	defer ts.Close()

	// Create a mock runner
	r := newMockRunner()
	r.On("downloadFile", mock.Anything, mock.Anything).Return(nil)

	// Set the master URL to the test server URL
	viper.Set("api.endpoint", ts.URL)

	// Prepare an isolated workspace; register cleanup before creating anything
	// so a partially created tree is still removed on failure.
	localPath := filepath.Join(os.TempDir(), uuid.New().String())
	defer os.RemoveAll(localPath)
	spiderDir := filepath.Join(localPath, r.s.GetId().Hex())
	if err := os.MkdirAll(spiderDir, os.ModePerm); err != nil {
		t.Fatalf("creating workspace directory %q: %v", spiderDir, err)
	}
	viper.Set("workspace", localPath)

	// Create a dummy file that should be deleted. Fail fast if it cannot be
	// created: otherwise the final NoFileExists assertion would pass vacuously.
	dummyFilePath := filepath.Join(spiderDir, "dummy.txt")
	dummyFile, err := os.Create(dummyFilePath)
	if err != nil {
		t.Fatalf("creating dummy file %q: %v", dummyFilePath, err)
	}
	if err := dummyFile.Close(); err != nil {
		t.Fatalf("closing dummy file %q: %v", dummyFilePath, err)
	}

	// Call the method under test
	err = r.syncFiles()
	assert.NoError(t, err)

	// Assert that the dummy file was deleted
	assert.NoFileExists(t, dummyFilePath)
}

View File

@@ -10,13 +10,12 @@ import (
"github.com/cenkalti/backoff/v4"
"github.com/crawlab-team/crawlab/core/constants"
"github.com/crawlab-team/crawlab/core/entity"
fs2 "github.com/crawlab-team/crawlab/core/fs"
"github.com/crawlab-team/crawlab/core/fs"
client2 "github.com/crawlab-team/crawlab/core/grpc/client"
"github.com/crawlab-team/crawlab/core/interfaces"
"github.com/crawlab-team/crawlab/core/models/client"
"github.com/crawlab-team/crawlab/core/models/models"
models2 "github.com/crawlab-team/crawlab/core/models/models/v2"
service2 "github.com/crawlab-team/crawlab/core/models/service"
"github.com/crawlab-team/crawlab/core/models/models/v2"
"github.com/crawlab-team/crawlab/core/models/service"
"github.com/crawlab-team/crawlab/core/sys_exec"
"github.com/crawlab-team/crawlab/core/utils"
"github.com/crawlab-team/crawlab/grpc"
@@ -45,17 +44,16 @@ type RunnerV2 struct {
bufferSize int
// internals
cmd *exec.Cmd // process command instance
pid int // process id
tid primitive.ObjectID // task id
t *models2.TaskV2 // task model.Task
s *models2.SpiderV2 // spider model.Spider
ch chan constants.TaskSignal // channel to communicate between Service and RunnerV2
err error // standard process error
envs []models.Env // environment variables
cwd string // working directory
c *client2.GrpcClientV2 // grpc client
sub grpc.TaskService_SubscribeClient // grpc task service stream client
cmd *exec.Cmd // process command instance
pid int // process id
tid primitive.ObjectID // task id
t *models.TaskV2 // task model.Task
s *models.SpiderV2 // spider model.Spider
ch chan constants.TaskSignal // channel to communicate between Service and RunnerV2
err error // standard process error
cwd string // working directory
c *client2.GrpcClientV2 // grpc client
sub grpc.TaskService_SubscribeClient // grpc task service stream client
// log internals
scannerStdout *bufio.Reader
@@ -316,7 +314,7 @@ func (r *RunnerV2) configureEnv() {
}
// global environment variables
envs, err := client.NewModelServiceV2[models2.EnvironmentV2]().GetMany(nil, nil)
envs, err := client.NewModelServiceV2[models.EnvironmentV2]().GetMany(nil, nil)
if err != nil {
trace.PrintError(err)
return
@@ -511,12 +509,12 @@ func (r *RunnerV2) updateTask(status string, e error) (err error) {
r.t.Error = e.Error()
}
if r.svc.GetNodeConfigService().IsMaster() {
err = service2.NewModelServiceV2[models2.TaskV2]().ReplaceById(r.t.Id, *r.t)
err = service.NewModelServiceV2[models.TaskV2]().ReplaceById(r.t.Id, *r.t)
if err != nil {
return err
}
} else {
err = client.NewModelServiceV2[models2.TaskV2]().ReplaceById(r.t.Id, *r.t)
err = client.NewModelServiceV2[models.TaskV2]().ReplaceById(r.t.Id, *r.t)
if err != nil {
return err
}
@@ -567,7 +565,7 @@ func (r *RunnerV2) writeLogLines(lines []string) {
}
func (r *RunnerV2) _updateTaskStat(status string) {
ts, err := client.NewModelServiceV2[models2.TaskStatV2]().GetById(r.tid)
ts, err := client.NewModelServiceV2[models.TaskStatV2]().GetById(r.tid)
if err != nil {
trace.PrintError(err)
return
@@ -588,13 +586,13 @@ func (r *RunnerV2) _updateTaskStat(status string) {
ts.TotalDuration = ts.EndTs.Sub(ts.BaseModelV2.CreatedAt).Milliseconds()
}
if r.svc.GetNodeConfigService().IsMaster() {
err = service2.NewModelServiceV2[models2.TaskStatV2]().ReplaceById(ts.Id, *ts)
err = service.NewModelServiceV2[models.TaskStatV2]().ReplaceById(ts.Id, *ts)
if err != nil {
trace.PrintError(err)
return
}
} else {
err = client.NewModelServiceV2[models2.TaskStatV2]().ReplaceById(ts.Id, *ts)
err = client.NewModelServiceV2[models.TaskStatV2]().ReplaceById(ts.Id, *ts)
if err != nil {
trace.PrintError(err)
return
@@ -617,7 +615,7 @@ func (r *RunnerV2) sendNotification() {
func (r *RunnerV2) _updateSpiderStat(status string) {
// task stat
ts, err := client.NewModelServiceV2[models2.TaskStatV2]().GetById(r.tid)
ts, err := client.NewModelServiceV2[models.TaskStatV2]().GetById(r.tid)
if err != nil {
trace.PrintError(err)
return
@@ -655,13 +653,13 @@ func (r *RunnerV2) _updateSpiderStat(status string) {
// perform update
if r.svc.GetNodeConfigService().IsMaster() {
err = service2.NewModelServiceV2[models2.SpiderStatV2]().UpdateById(r.s.Id, update)
err = service.NewModelServiceV2[models.SpiderStatV2]().UpdateById(r.s.Id, update)
if err != nil {
trace.PrintError(err)
return
}
} else {
err = client.NewModelServiceV2[models2.SpiderStatV2]().UpdateById(r.s.Id, update)
err = client.NewModelServiceV2[models.SpiderStatV2]().UpdateById(r.s.Id, update)
if err != nil {
trace.PrintError(err)
return
@@ -709,7 +707,7 @@ func NewTaskRunnerV2(id primitive.ObjectID, svc *ServiceV2) (r2 *RunnerV2, err e
}
// task fs service
r.fsSvc = fs2.NewFsServiceV2(filepath.Join(viper.GetString("workspace"), r.s.Id.Hex()))
r.fsSvc = fs.NewFsServiceV2(filepath.Join(viper.GetString("workspace"), r.s.Id.Hex()))
// grpc client
r.c = client2.GetGrpcClientV2()