diff --git a/core/constants/node.go b/core/constants/node.go index 464c517c..1736a413 100644 --- a/core/constants/node.go +++ b/core/constants/node.go @@ -1,8 +1,6 @@ package constants const ( - NodeStatusUnregistered = "u" - NodeStatusRegistered = "r" - NodeStatusOnline = "on" - NodeStatusOffline = "off" + NodeStatusOnline = "on" + NodeStatusOffline = "off" ) diff --git a/core/constants/task.go b/core/constants/task.go index aaa71c11..22a10662 100644 --- a/core/constants/task.go +++ b/core/constants/task.go @@ -2,6 +2,7 @@ package constants const ( TaskStatusPending = "pending" + TaskStatusAssigned = "assigned" TaskStatusRunning = "running" TaskStatusFinished = "finished" TaskStatusError = "error" diff --git a/core/controllers/spider_v2.go b/core/controllers/spider_v2.go index 9a8e6ca7..2cc2ae26 100644 --- a/core/controllers/spider_v2.go +++ b/core/controllers/spider_v2.go @@ -653,7 +653,7 @@ func PostSpiderRun(c *gin.Context) { opts.UserId = u.GetId() } - adminSvc, err := admin.GetSpiderAdminServiceV2() + adminSvc, err := admin.GetSpiderAdminService() if err != nil { HandleErrorInternalServerError(c, err) return diff --git a/core/controllers/task_v2.go b/core/controllers/task_v2.go index 829ad385..2fb15252 100644 --- a/core/controllers/task_v2.go +++ b/core/controllers/task_v2.go @@ -305,7 +305,7 @@ func PostTaskRun(c *gin.Context) { } // run - adminSvc, err := admin.GetSpiderAdminServiceV2() + adminSvc, err := admin.GetSpiderAdminService() if err != nil { HandleErrorInternalServerError(c, err) return @@ -350,7 +350,7 @@ func PostTaskRestart(c *gin.Context) { } // run - adminSvc, err := admin.GetSpiderAdminServiceV2() + adminSvc, err := admin.GetSpiderAdminService() if err != nil { HandleErrorInternalServerError(c, err) return @@ -399,7 +399,7 @@ func PostTaskCancel(c *gin.Context) { u := GetUserFromContext(c) // cancel - schedulerSvc, err := scheduler.GetTaskSchedulerServiceV2() + schedulerSvc, err := scheduler.GetTaskSchedulerService() if err != nil { HandleErrorInternalServerError(c, err) return diff --git a/core/grpc/client/client.go b/core/grpc/client/client.go index 9d3c9165..94d6d939 100644 --- a/core/grpc/client/client.go +++ b/core/grpc/client/client.go @@ -141,6 +141,8 @@ func (c *GrpcClient) connect() (err error) { return err } + // connect + c.conn.Connect() log.Infof("[GrpcClient] grpc client connected to %s", address) return nil diff --git a/core/grpc/server/node_service_server.go b/core/grpc/server/node_service_server.go index 3c504266..661aef6f 100644 --- a/core/grpc/server/node_service_server.go +++ b/core/grpc/server/node_service_server.go @@ -16,7 +16,6 @@ import ( "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson/primitive" "go.mongodb.org/mongo-driver/mongo" - "io" "sync" "time" ) @@ -30,8 +29,7 @@ type NodeServiceServer struct { cfgSvc interfaces.NodeConfigService // internals - server *GrpcServer - subs map[primitive.ObjectID]grpc.NodeService_SubscribeServer + subs map[primitive.ObjectID]grpc.NodeService_SubscribeServer } // Register from handler/worker to master @@ -46,7 +44,7 @@ func (svr NodeServiceServer) Register(_ context.Context, req *grpc.NodeServiceRe node, err = service.NewModelServiceV2[models.NodeV2]().GetOne(bson.M{"key": req.NodeKey}, nil) if err == nil { // register existing - node.Status = constants.NodeStatusRegistered + node.Status = constants.NodeStatusOnline node.Active = true node.ActiveAt = time.Now() err = service.NewModelServiceV2[models.NodeV2]().ReplaceById(node.Id, *node) @@ -59,7 +57,7 @@ func (svr NodeServiceServer) Register(_ context.Context, req *grpc.NodeServiceRe node = &models.NodeV2{ Key: req.NodeKey, Name: req.NodeName, - Status: constants.NodeStatusRegistered, + Status: constants.NodeStatusOnline, Active: true, ActiveAt: time.Now(), Enabled: true, @@ -94,11 +92,6 @@ func (svr NodeServiceServer) SendHeartbeat(_ context.Context, req *grpc.NodeServ } oldStatus := node.Status - // validate status - if node.Status == constants.NodeStatusUnregistered { - return HandleError(errors.ErrorNodeUnregistered) - } - // update status node.Status = constants.NodeStatusOnline node.Active = true @@ -136,26 +129,14 @@ func (svr NodeServiceServer) Subscribe(request *grpc.NodeServiceSubscribeRequest // TODO: send notification - // create a new goroutine to receive messages from the stream to listen for EOF (end of stream) - go func() { - for { - select { - case <-stream.Context().Done(): - nodeServiceMutex.Lock() - delete(svr.subs, node.Id) - nodeServiceMutex.Unlock() - return - default: - err := stream.RecvMsg(nil) - if err == io.EOF { - nodeServiceMutex.Lock() - delete(svr.subs, node.Id) - nodeServiceMutex.Unlock() - return - } - } - } - }() + // wait for stream to close + <-stream.Context().Done() + + // unsubscribe + nodeServiceMutex.Lock() + delete(svr.subs, node.Id) + nodeServiceMutex.Unlock() + log.Infof("[NodeServiceServer] master unsubscribed from node[%s]", request.NodeKey) return nil } @@ -167,16 +148,18 @@ func (svr NodeServiceServer) GetSubscribeStream(nodeId primitive.ObjectID) (stre return stream, ok } -var nodeSvrV2 *NodeServiceServer -var nodeSvrV2Once = new(sync.Once) +var nodeSvr *NodeServiceServer +var nodeSvrOnce = new(sync.Once) func NewNodeServiceServer() (res *NodeServiceServer, err error) { - if nodeSvrV2 != nil { - return nodeSvrV2, nil + if nodeSvr != nil { + return nodeSvr, nil } - nodeSvrV2Once.Do(func() { - nodeSvrV2 = &NodeServiceServer{} - nodeSvrV2.cfgSvc = nodeconfig.GetNodeConfigService() + nodeSvrOnce.Do(func() { + nodeSvr = &NodeServiceServer{ + subs: make(map[primitive.ObjectID]grpc.NodeService_SubscribeServer), + } + nodeSvr.cfgSvc = nodeconfig.GetNodeConfigService() if err != nil { log.Errorf("[NodeServiceServer] error: %s", err.Error()) } @@ -184,5 +167,5 @@ func NewNodeServiceServer() (res *NodeServiceServer, err error) { if err != nil { return nil, err } - return nodeSvrV2, nil + return nodeSvr, nil } diff --git a/core/grpc/server/task_service_server.go b/core/grpc/server/task_service_server.go index f2230786..7800c6ad 100644 --- a/core/grpc/server/task_service_server.go +++ b/core/grpc/server/task_service_server.go @@ -54,26 +54,14 @@ func (svr TaskServiceServer) Subscribe(req *grpc.TaskServiceSubscribeRequest, st svr.subs[taskId] = stream taskServiceMutex.Unlock() - // create a new goroutine to receive messages from the stream to listen for EOF (end of stream) - go func() { - for { - select { - case <-stream.Context().Done(): - taskServiceMutex.Lock() - delete(svr.subs, taskId) - taskServiceMutex.Unlock() - return - default: - err := stream.RecvMsg(nil) - if err == io.EOF { - taskServiceMutex.Lock() - delete(svr.subs, taskId) - taskServiceMutex.Unlock() - return - } - } - } - }() + // wait for stream to close + <-stream.Context().Done() + + // remove stream + taskServiceMutex.Lock() + delete(svr.subs, taskId) + taskServiceMutex.Unlock() + log.Infof("[TaskServiceServer] task stream closed: %s", taskId.Hex()) return nil } @@ -129,33 +117,47 @@ func (svr TaskServiceServer) FetchTask(ctx context.Context, request *grpc.TaskSe var tid primitive.ObjectID opts := &mongo.FindOptions{ Sort: bson.D{ - {"p", 1}, + {"priority", 1}, {"_id", 1}, }, Limit: 1, } if err := mongo.RunTransactionWithContext(ctx, func(sc mongo2.SessionContext) (err error) { - // get task queue item assigned to this node - tid, err = svr.getTaskQueueItemIdAndDequeue(bson.M{"nid": n.Id}, opts, n.Id) - if err != nil { + // fetch task for the given node + t, err := service.NewModelServiceV2[models2.TaskV2]().GetOne(bson.M{ + "node_id": n.Id, + "status": constants.TaskStatusPending, + }, opts) + if err == nil { + tid = t.Id + t.Status = constants.TaskStatusAssigned + return svr.saveTask(t) + } else if !errors.Is(err, mongo2.ErrNoDocuments) { + log.Errorf("error fetching task for node[%s]: %v", nodeKey, err) return err } - if !tid.IsZero() { - return nil - } - // get task queue item assigned to any node (random mode) - tid, err = svr.getTaskQueueItemIdAndDequeue(bson.M{"nid": nil}, opts, n.Id) - if !tid.IsZero() { - return nil - } - if err != nil { + // fetch task for any node + t, err = service.NewModelServiceV2[models2.TaskV2]().GetOne(bson.M{ + "node_id": primitive.NilObjectID, + "status": constants.TaskStatusPending, + }, opts) + if err == nil { + tid = t.Id + t.NodeId = n.Id + t.Status = constants.TaskStatusAssigned + return svr.saveTask(t) + } else if !errors.Is(err, mongo2.ErrNoDocuments) { + log.Errorf("error fetching task for any node: %v", err) return err } + + // no task found return nil }); err != nil { return nil, err } + return &grpc.TaskServiceFetchTaskResponse{TaskId: tid.Hex()}, nil } @@ -284,32 +286,16 @@ func (svr TaskServiceServer) handleInsertLogs(taskId primitive.ObjectID, msg *gr return svr.statsSvc.InsertLogs(taskId, logs...) } -func (svr TaskServiceServer) getTaskQueueItemIdAndDequeue(query bson.M, opts *mongo.FindOptions, nid primitive.ObjectID) (tid primitive.ObjectID, err error) { - tq, err := service.NewModelServiceV2[models2.TaskQueueItemV2]().GetOne(query, opts) - if err != nil { - if errors.Is(err, mongo2.ErrNoDocuments) { - return tid, nil - } - return tid, trace.TraceError(err) - } - t, err := service.NewModelServiceV2[models2.TaskV2]().GetById(tq.Id) - if err == nil { - t.NodeId = nid - err = service.NewModelServiceV2[models2.TaskV2]().ReplaceById(t.Id, *t) - if err != nil { - return tid, trace.TraceError(err) - } - } - err = service.NewModelServiceV2[models2.TaskQueueItemV2]().DeleteById(tq.Id) - if err != nil { - return tid, trace.TraceError(err) - } - return tq.Id, nil +func (svr TaskServiceServer) saveTask(t *models2.TaskV2) (err error) { + t.SetUpdated(t.CreatedBy) + return service.NewModelServiceV2[models2.TaskV2]().ReplaceById(t.Id, *t) } func NewTaskServiceServer() (res *TaskServiceServer, err error) { // task server - svr := &TaskServiceServer{} + svr := &TaskServiceServer{ + subs: make(map[primitive.ObjectID]grpc.TaskService_SubscribeServer), + } svr.cfgSvc = nodeconfig.GetNodeConfigService() diff --git a/core/models/common/init.go b/core/models/common/init_index.go similarity index 98% rename from core/models/common/init.go rename to core/models/common/init_index.go index 69fa4348..c3872a42 100644 --- a/core/models/common/init.go +++ b/core/models/common/init_index.go @@ -45,6 +45,7 @@ func InitIndexes() { {Keys: bson.M{"parent_id": 1}}, {Keys: bson.M{"has_sub": 1}}, {Keys: bson.M{"created_ts": -1}, Options: (&options.IndexOptions{}).SetExpireAfterSeconds(60 * 60 * 24 * 30)}, + {Keys: bson.M{"node_id": 1, "status": 1}}, }) // task stats diff --git a/core/models/models/v2/base_v2.go b/core/models/models/v2/base_v2.go index 92cfed63..851f261b 100644 --- a/core/models/models/v2/base_v2.go +++ b/core/models/models/v2/base_v2.go @@ -91,7 +91,6 @@ func GetModelInstances() []any { *new(SettingV2), *new(SpiderV2), *new(SpiderStatV2), - *new(TaskQueueItemV2), *new(TaskStatV2), *new(TaskV2), *new(TokenV2), diff --git a/core/models/models/v2/task_queue_item_v2.go b/core/models/models/v2/task_queue_item_v2.go deleted file mode 100644 index f222aef3..00000000 --- a/core/models/models/v2/task_queue_item_v2.go +++ /dev/null @@ -1,12 +0,0 @@ -package models - -import ( - "go.mongodb.org/mongo-driver/bson/primitive" -) - -type TaskQueueItemV2 struct { - any `collection:"task_queue"` - BaseModelV2[TaskQueueItemV2] `bson:",inline"` - Priority int `json:"p" bson:"p"` - NodeId primitive.ObjectID `json:"nid,omitempty" bson:"nid,omitempty"` -} diff --git a/core/models/models/v2/task_v2.go b/core/models/models/v2/task_v2.go index 849cdf40..ad7fb05b 100644 --- a/core/models/models/v2/task_v2.go +++ b/core/models/models/v2/task_v2.go @@ -18,11 +18,8 @@ type TaskV2 struct { Type string `json:"type" bson:"type"` Mode string `json:"mode" bson:"mode"` NodeIds []primitive.ObjectID `json:"node_ids" bson:"node_ids"` - ParentId primitive.ObjectID `json:"parent_id" bson:"parent_id"` Priority int `json:"priority" bson:"priority"` Stat *TaskStatV2 `json:"stat,omitempty" bson:"-"` - HasSub bool `json:"has_sub" json:"has_sub"` - SubTasks []TaskV2 `json:"sub_tasks,omitempty" bson:"-"` Spider *SpiderV2 `json:"spider,omitempty" bson:"-"` Schedule *ScheduleV2 `json:"schedule,omitempty" bson:"-"` Node *NodeV2 `json:"node,omitempty" bson:"-"` diff --git a/core/node/service/master_service.go b/core/node/service/master_service.go index fa6736a5..02cf0d52 100644 --- a/core/node/service/master_service.go +++ b/core/node/service/master_service.go @@ -31,7 +31,7 @@ type MasterService struct { // dependencies cfgSvc interfaces.NodeConfigService server *server.GrpcServer - schedulerSvc *scheduler.ServiceV2 + schedulerSvc *scheduler.Service handlerSvc *handler.Service scheduleSvc *schedule.ServiceV2 systemSvc *system.ServiceV2 @@ -339,7 +339,7 @@ func newMasterServiceV2() (res *MasterService, err error) { } // scheduler service - svc.schedulerSvc, err = scheduler.GetTaskSchedulerServiceV2() + svc.schedulerSvc, err = scheduler.GetTaskSchedulerService() if err != nil { return nil, err } diff --git a/core/node/service/worker_service.go b/core/node/service/worker_service.go index 3d26c892..712c6456 100644 --- a/core/node/service/worker_service.go +++ b/core/node/service/worker_service.go @@ -2,8 +2,11 @@ package service import ( "context" - "errors" + "sync" + "time" + "github.com/apex/log" + "github.com/cenkalti/backoff/v4" "github.com/crawlab-team/crawlab/core/config" "github.com/crawlab-team/crawlab/core/grpc/client" "github.com/crawlab-team/crawlab/core/interfaces" @@ -15,9 +18,6 @@ import ( "github.com/crawlab-team/crawlab/grpc" "github.com/crawlab-team/crawlab/trace" "go.mongodb.org/mongo-driver/bson" - "io" - "sync" - "time" ) type WorkerService struct { @@ -49,13 +49,13 @@ func (svc *WorkerService) Start() { } // register to master - svc.Register() + svc.register() // subscribe - svc.Subscribe() + go svc.subscribe() // start sending heartbeat to master - go svc.ReportStatus() + go svc.reportStatus() // start handler go svc.handlerSvc.Start() @@ -78,7 +78,7 @@ func (svc *WorkerService) Stop() { log.Infof("worker[%s] service has stopped", svc.cfgSvc.GetNodeKey()) } -func (svc *WorkerService) Register() { +func (svc *WorkerService) register() { ctx, cancel := svc.client.Context() defer cancel() _, err := svc.client.NodeClient.Register(ctx, &grpc.NodeServiceRegisterRequest{ @@ -87,17 +87,19 @@ func (svc *WorkerService) Register() { MaxRunners: int32(svc.cfgSvc.GetMaxRunners()), }) if err != nil { + log.Fatalf("failed to register worker[%s] to master: %v", svc.cfgSvc.GetNodeKey(), err) panic(err) } svc.n, err = client2.NewModelServiceV2[models.NodeV2]().GetOne(bson.M{"key": svc.GetConfigService().GetNodeKey()}, nil) if err != nil { + log.Fatalf("failed to get node: %v", err) panic(err) } log.Infof("worker[%s] registered to master. id: %s", svc.GetConfigService().GetNodeKey(), svc.n.Id.Hex()) return } -func (svc *WorkerService) ReportStatus() { +func (svc *WorkerService) reportStatus() { ticker := time.NewTicker(svc.heartbeatInterval) for { // return if client is closed @@ -106,8 +108,8 @@ func (svc *WorkerService) ReportStatus() { return } - // report status - svc.reportStatus() + // send heartbeat + svc.sendHeartbeat() // sleep <-ticker.C @@ -126,49 +128,65 @@ func (svc *WorkerService) SetConfigPath(path string) { svc.cfgPath = path } -func (svc *WorkerService) GetAddress() (address interfaces.Address) { - return svc.address -} +func (svc *WorkerService) subscribe() { + // Configure exponential backoff + b := backoff.NewExponentialBackOff() + b.InitialInterval = 1 * time.Second + b.MaxInterval = 1 * time.Minute + b.MaxElapsedTime = 10 * time.Minute + b.Multiplier = 2.0 -func (svc *WorkerService) SetAddress(address interfaces.Address) { - svc.address = address -} - -func (svc *WorkerService) SetHeartbeatInterval(duration time.Duration) { - svc.heartbeatInterval = duration -} - -func (svc *WorkerService) Subscribe() { - stream, err := svc.client.NodeClient.Subscribe(context.Background(), &grpc.NodeServiceSubscribeRequest{ - NodeKey: svc.cfgSvc.GetNodeKey(), - }) - if err != nil { - log.Errorf("failed to subscribe to master: %v", err) - return - } for { if svc.stopped { return } - select { - default: - msg, err := stream.Recv() + + // Use backoff for connection attempts + operation := func() error { + stream, err := svc.client.NodeClient.Subscribe(context.Background(), &grpc.NodeServiceSubscribeRequest{ + NodeKey: svc.cfgSvc.GetNodeKey(), + }) if err != nil { - if errors.Is(err, io.EOF) { - return - } - log.Errorf("failed to receive message from master: %v", err) - continue + log.Errorf("failed to subscribe to master: %v", err) + return err } - switch msg.Code { - case grpc.NodeServiceSubscribeCode_PING: - // do nothing + + // Handle messages + for { + if svc.stopped { + return nil + } + + msg, err := stream.Recv() + if err != nil { + if svc.client.IsClosed() { + log.Errorf("connection to master is closed: %v", err) + return err + } + log.Errorf("failed to receive message from master: %v", err) + return err + } + + switch msg.Code { + case grpc.NodeServiceSubscribeCode_PING: + // do nothing + } } } + + // Execute with backoff + err := backoff.Retry(operation, b) + if err != nil { + log.Errorf("subscription failed after max retries: %v", err) + return + } + + // Wait before attempting to reconnect + time.Sleep(time.Second) } } -func (svc *WorkerService) reportStatus() { +func (svc *WorkerService) sendHeartbeat() { ctx, cancel := context.WithTimeout(context.Background(), svc.heartbeatInterval) defer cancel() _, err := svc.client.NodeClient.SendHeartbeat(ctx, &grpc.NodeServiceSendHeartbeatRequest{ diff --git a/core/schedule/service_v2.go b/core/schedule/service_v2.go index a1a244ac..0b102574 100644 --- a/core/schedule/service_v2.go +++ b/core/schedule/service_v2.go @@ -20,7 +20,7 @@ type ServiceV2 struct { // dependencies interfaces.WithConfigPath modelSvc *service.ModelServiceV2[models2.ScheduleV2] - adminSvc *admin.ServiceV2 + adminSvc *admin.Service // settings variables loc *time.Location @@ -246,7 +246,7 @@ func NewScheduleServiceV2() (svc2 *ServiceV2, err error) { skip: false, updateInterval: 1 * time.Minute, } - svc.adminSvc, err = admin.GetSpiderAdminServiceV2() + svc.adminSvc, err = admin.GetSpiderAdminService() if err != nil { return nil, err } diff --git a/core/spider/admin/service_v2.go b/core/spider/admin/service.go similarity index 55% rename from core/spider/admin/service_v2.go rename to core/spider/admin/service.go index ee4046dc..b09a9bd5 100644 --- a/core/spider/admin/service_v2.go +++ b/core/spider/admin/service.go @@ -11,24 +11,22 @@ import ( "github.com/crawlab-team/crawlab/core/node/config" "github.com/crawlab-team/crawlab/core/task/scheduler" "github.com/crawlab-team/crawlab/trace" - "github.com/robfig/cron/v3" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson/primitive" "sync" ) -type ServiceV2 struct { +type Service struct { // dependencies nodeCfgSvc interfaces.NodeConfigService - schedulerSvc *scheduler.ServiceV2 - cron *cron.Cron + schedulerSvc *scheduler.Service syncLock bool // settings cfgPath string } -func (svc *ServiceV2) Schedule(id primitive.ObjectID, opts *interfaces.SpiderRunOptions) (taskIds []primitive.ObjectID, err error) { +func (svc *Service) Schedule(id primitive.ObjectID, opts *interfaces.SpiderRunOptions) (taskIds []primitive.ObjectID, err error) { // spider s, err := service.NewModelServiceV2[models2.SpiderV2]().GetById(id) if err != nil { @@ -39,53 +37,58 @@ func (svc *ServiceV2) Schedule(id primitive.ObjectID, opts *interfaces.SpiderRun return svc.scheduleTasks(s, opts) } -func (svc *ServiceV2) scheduleTasks(s *models2.SpiderV2, opts *interfaces.SpiderRunOptions) (taskIds []primitive.ObjectID, err error) { - // main task - t := &models2.TaskV2{ - SpiderId: s.Id, - Mode: opts.Mode, - NodeIds: opts.NodeIds, - Cmd: opts.Cmd, - Param: opts.Param, - ScheduleId: opts.ScheduleId, - Priority: opts.Priority, - } - t.SetId(primitive.NewObjectID()) - - // normalize - if t.Mode == "" { - t.Mode = s.Mode - } - if t.NodeIds == nil { - t.NodeIds = s.NodeIds - } - if t.Cmd == "" { - t.Cmd = s.Cmd - } - if t.Param == "" { - t.Param = s.Param - } - if t.Priority == 0 { - t.Priority = s.Priority - } - +func (svc *Service) scheduleTasks(s *models2.SpiderV2, opts *interfaces.SpiderRunOptions) (taskIds []primitive.ObjectID, err error) { + // get node ids nodeIds, err := svc.getNodeIds(opts) if err != nil { return nil, err } - if len(nodeIds) > 0 { - t.NodeId = nodeIds[0] + + // iterate node ids + for _, nodeId := range nodeIds { + // task + t := &models2.TaskV2{ + SpiderId: s.Id, + NodeId: nodeId, + NodeIds: opts.NodeIds, + Mode: opts.Mode, + Cmd: opts.Cmd, + Param: opts.Param, + ScheduleId: opts.ScheduleId, + Priority: opts.Priority, + } + + // normalize + if t.Mode == "" { + t.Mode = s.Mode + } + if t.NodeIds == nil { + t.NodeIds = s.NodeIds + } + if t.Cmd == "" { + t.Cmd = s.Cmd + } + if t.Param == "" { + t.Param = s.Param + } + if t.Priority == 0 { + t.Priority = s.Priority + } + + // enqueue task + t, err = svc.schedulerSvc.Enqueue(t, opts.UserId) + if err != nil { + return nil, err + } + + // append task id + taskIds = append(taskIds, t.Id) } - t2, err := svc.schedulerSvc.Enqueue(t, opts.UserId) - if err != nil { - return nil, err - } - taskIds = append(taskIds, t2.Id) return taskIds, nil } -func (svc *ServiceV2) getNodeIds(opts *interfaces.SpiderRunOptions) (nodeIds []primitive.ObjectID, err error) { +func (svc *Service) getNodeIds(opts *interfaces.SpiderRunOptions) (nodeIds []primitive.ObjectID, err error) { if opts.Mode == constants.RunTypeAllNodes { query := bson.M{ "active": true, @@ -101,11 +104,15 @@ func (svc *ServiceV2) getNodeIds(opts *interfaces.SpiderRunOptions) (nodeIds []p } } else if opts.Mode == constants.RunTypeSelectedNodes { nodeIds = opts.NodeIds + } else if opts.Mode == constants.RunTypeRandom { + nodeIds = []primitive.ObjectID{primitive.NilObjectID} + } else { + return nil, errors.New("invalid run mode") } return nodeIds, nil } -func (svc *ServiceV2) isMultiTask(opts *interfaces.SpiderRunOptions) (res bool) { +func (svc *Service) isMultiTask(opts *interfaces.SpiderRunOptions) (res bool) { if opts.Mode == constants.RunTypeAllNodes { query := bson.M{ "active": true, @@ -127,19 +134,16 @@ func (svc *ServiceV2) isMultiTask(opts *interfaces.SpiderRunOptions) (res bool) } } -func newSpiderAdminServiceV2() (svc2 *ServiceV2, err error) { - svc := &ServiceV2{ +func newSpiderAdminService() (svc2 *Service, err error) { + svc := &Service{ nodeCfgSvc: config.GetNodeConfigService(), cfgPath: config2.GetConfigPath(), } - svc.schedulerSvc, err = scheduler.GetTaskSchedulerServiceV2() + svc.schedulerSvc, err = scheduler.GetTaskSchedulerService() if err != nil { return nil, err } - // cron - svc.cron = cron.New() - // validate node type if !svc.nodeCfgSvc.IsMaster() { return nil, errors.New("only master node can run spider admin service") @@ -148,21 +152,21 @@ func newSpiderAdminServiceV2() (svc2 *ServiceV2, err error) { return svc, nil } -var svcV2 *ServiceV2 -var svcV2Once = new(sync.Once) +var svc *Service +var svcOnce = new(sync.Once) -func GetSpiderAdminServiceV2() (svc2 *ServiceV2, err error) { - if svcV2 != nil { - return svcV2, nil +func GetSpiderAdminService() (svc2 *Service, err error) { + if svc != nil { + return svc, nil } - svcV2Once.Do(func() { - svcV2, err = newSpiderAdminServiceV2() + svcOnce.Do(func() { + svc, err = newSpiderAdminService() if err != nil { - log2.Errorf("[GetSpiderAdminServiceV2] error: %v", err) + log2.Errorf("[GetSpiderAdminService] error: %v", err) } }) if err != nil { return nil, err } - return svcV2, nil + return svc, nil } diff --git a/core/task/handler/runner_v2.go b/core/task/handler/runner.go similarity index 92% rename from core/task/handler/runner_v2.go rename to core/task/handler/runner.go index 09f74d90..34fe048b 100644 --- a/core/task/handler/runner_v2.go +++ b/core/task/handler/runner.go @@ -33,7 +33,7 @@ import ( "time" ) -type RunnerV2 struct { +type Runner struct { // dependencies svc *Service // task handler service fsSvc interfaces.FsService // task fs service @@ -48,7 +48,7 @@ type RunnerV2 struct { tid primitive.ObjectID // task id t *models.TaskV2 // task model.Task s *models.SpiderV2 // spider model.Spider - ch chan constants.TaskSignal // channel to communicate between Service and RunnerV2 + ch chan constants.TaskSignal // channel to communicate between Service and Runner err error // standard process error cwd string // working directory c *client2.GrpcClient // grpc client @@ -60,7 +60,7 @@ type RunnerV2 struct { logBatchSize int } -func (r *RunnerV2) Init() (err error) { +func (r *Runner) Init() (err error) { // update task if err := r.updateTask("", nil); err != nil { return err @@ -82,7 +82,7 @@ func (r *RunnerV2) Init() (err error) { return nil } -func (r *RunnerV2) Run() (err error) { +func (r *Runner) Run() (err error) { // log task started log.Infof("task[%s] started", r.tid.Hex()) @@ -165,7 +165,7 @@ func (r *RunnerV2) Run() (err error) { return err } -func (r *RunnerV2) Cancel(force bool) (err error) { +func (r *Runner) Cancel(force bool) (err error) { // kill process opts := &sys_exec.KillProcessOptions{ Timeout: r.svc.GetCancelTimeout(), @@ -191,15 +191,15 @@ func (r *RunnerV2) Cancel(force bool) (err error) { } } -func (r *RunnerV2) SetSubscribeTimeout(timeout time.Duration) { +func (r *Runner) SetSubscribeTimeout(timeout time.Duration) { r.subscribeTimeout = timeout } -func (r *RunnerV2) GetTaskId() (id primitive.ObjectID) { +func (r *Runner) GetTaskId() (id primitive.ObjectID) { return r.tid } -func (r *RunnerV2) configureCmd() (err error) { +func (r *Runner) configureCmd() (err error) { var cmdStr string // customized spider @@ -230,7 +230,7 @@ func (r *RunnerV2) configureCmd() (err error) { return nil } -func (r *RunnerV2) configureLogging() { +func (r *Runner) configureLogging() { // set stdout reader stdout, _ := r.cmd.StdoutPipe() r.scannerStdout = bufio.NewReaderSize(stdout, r.bufferSize) @@ -240,7 +240,7 @@ func (r *RunnerV2) configureLogging() { r.scannerStderr = bufio.NewReaderSize(stderr, r.bufferSize) } -func (r *RunnerV2) startLogging() { +func (r *Runner) startLogging() { // start reading stdout go r.startLoggingReaderStdout() @@ -248,7 +248,7 @@ func (r *RunnerV2) startLogging() { go r.startLoggingReaderStderr() } -func (r *RunnerV2) startLoggingReaderStdout() { +func (r *Runner) startLoggingReaderStdout() { for { line, err := r.scannerStdout.ReadString(byte('\n')) if err != nil { @@ -259,7 +259,7 @@ func (r *RunnerV2) startLoggingReaderStdout() { } } -func (r *RunnerV2) startLoggingReaderStderr() { +func (r *Runner) startLoggingReaderStderr() { for { line, err := r.scannerStderr.ReadString(byte('\n')) if err != nil { @@ -270,7 +270,7 @@ func (r *RunnerV2) startLoggingReaderStderr() { } } -func (r *RunnerV2) startHealthCheck() { +func (r *Runner) startHealthCheck() { if r.cmd.ProcessState == nil || r.cmd.ProcessState.Exited() { return } @@ -285,7 +285,7 @@ func (r *RunnerV2) startHealthCheck() { } } -func (r *RunnerV2) configureEnv() { +func (r *Runner) configureEnv() { // 默认把Node.js的全局node_modules加入环境变量 envPath := os.Getenv("PATH") nodePath := "/usr/lib/node_modules" @@ -316,7 +316,7 @@ func (r *RunnerV2) configureEnv() { } } -func (r *RunnerV2) syncFiles() (err error) { +func (r *Runner) syncFiles() (err error) { var id string var workingDir string if r.s.GitId.IsZero() { @@ -425,7 +425,7 @@ func (r *RunnerV2) syncFiles() (err error) { return err } -func (r *RunnerV2) downloadFile(url string, filePath string, fileInfo *entity.FsFileInfo) error { +func (r *Runner) downloadFile(url string, filePath string, fileInfo *entity.FsFileInfo) error { // get file response resp, err := http.Get(url) if err != nil { @@ -465,8 +465,8 @@ func (r *RunnerV2) downloadFile(url string, filePath string, fileInfo *entity.Fs } // wait for process to finish and send task signal (constants.TaskSignal) -// to task runner's channel (RunnerV2.ch) according to exit code -func (r *RunnerV2) wait() { +// to task runner's channel (Runner.ch) according to exit code +func (r *Runner) wait() { // wait for process to finish if err := r.cmd.Wait(); err != nil { var exitError *exec.ExitError @@ -492,8 +492,8 @@ func (r *RunnerV2) wait() { r.ch <- constants.TaskSignalFinish } -// updateTask update and get updated info of task (RunnerV2.t) -func (r *RunnerV2) updateTask(status string, e error) (err error) { +// updateTask update and get updated info of task (Runner.t) +func (r *Runner) updateTask(status string, e error) (err error) { if r.t != nil && status != "" { // update task status r.t.Status = status @@ -529,7 +529,7 @@ func (r *RunnerV2) updateTask(status string, e error) (err error) { return nil } -func (r *RunnerV2) initConnection() (err error) { +func (r *Runner) initConnection() (err error) { r.conn, err = r.c.TaskClient.Connect(context.Background()) if err != nil { return trace.TraceError(err) @@ -537,7 +537,7 @@ func (r *RunnerV2) initConnection() (err error) { return nil } -func (r *RunnerV2) writeLogLines(lines []string) { +func (r *Runner) writeLogLines(lines []string) { linesBytes, err := json.Marshal(lines) if err != nil { log.Errorf("Error marshaling log lines: %v", err) @@ -554,7 +554,7 @@ func (r *RunnerV2) writeLogLines(lines []string) { } } -func (r *RunnerV2) _updateTaskStat(status string) { +func (r *Runner) _updateTaskStat(status string) { ts, err := client.NewModelServiceV2[models.TaskStatV2]().GetById(r.tid) if err != nil { trace.PrintError(err) @@ -590,7 +590,7 @@ func (r *RunnerV2) _updateTaskStat(status string) { } } -func (r *RunnerV2) sendNotification() { +func (r *Runner) sendNotification() { req := &grpc.TaskServiceSendNotificationRequest{ NodeKey: r.svc.GetNodeConfigService().GetNodeKey(), TaskId: r.tid.Hex(), @@ -603,7 +603,7 @@ func (r *RunnerV2) sendNotification() { } } -func (r *RunnerV2) _updateSpiderStat(status string) { +func (r *Runner) _updateSpiderStat(status string) { // task stat ts, err := client.NewModelServiceV2[models.TaskStatV2]().GetById(r.tid) if err != nil { @@ -657,7 +657,7 @@ func (r *RunnerV2) _updateSpiderStat(status string) { } } -func (r *RunnerV2) configureCwd() { +func (r *Runner) configureCwd() { workspacePath := viper.GetString("workspace") if r.s.GitId.IsZero() { // not git @@ -668,14 +668,14 @@ func (r *RunnerV2) configureCwd() { } } -func NewTaskRunnerV2(id primitive.ObjectID, svc *Service) (r2 *RunnerV2, err error) { +func NewTaskRunnerV2(id primitive.ObjectID, svc *Service) (r2 *Runner, err error) { // validate options if id.IsZero() { return nil, constants.ErrInvalidOptions } // runner - r := &RunnerV2{ + r := &Runner{ subscribeTimeout: 30 * time.Second, bufferSize: 1024 * 1024, svc: svc, diff --git a/core/task/handler/service.go b/core/task/handler/service.go index e02c7c3a..001fbc71 100644 --- a/core/task/handler/service.go +++ b/core/task/handler/service.go @@ -51,8 +51,8 @@ func (svc *Service) Start() { } } - go svc.ReportStatus() - go svc.FetchAndRunTasks() + go svc.reportStatus() + go svc.fetchAndRunTasks() } func (svc *Service) Stop() { @@ -67,7 +67,7 @@ func (svc *Service) Cancel(taskId primitive.ObjectID, force bool) (err error) { return svc.cancelTask(taskId, force) } -func (svc *Service) FetchAndRunTasks() { +func (svc *Service) fetchAndRunTasks() { ticker := time.NewTicker(svc.fetchInterval) for { if svc.stopped { @@ -119,7 +119,7 @@ func (svc *Service) FetchAndRunTasks() { } } -func (svc *Service) ReportStatus() { +func (svc *Service) reportStatus() { ticker := time.NewTicker(svc.reportInterval) for { if svc.stopped { @@ -128,9 +128,9 @@ func (svc *Service) ReportStatus() { select { case <-ticker.C: - // report handler status - if err := svc.reportStatus(); err != nil { - trace.PrintError(err) + // update node status + if err := svc.updateNodeStatus(); err != nil { + log.Errorf("failed to report status: %v", err) } } } @@ -178,6 +178,19 @@ func (svc *Service) GetTaskById(id primitive.ObjectID) (t *models2.TaskV2, err e return t, nil } +func (svc *Service) UpdateTask(t *models2.TaskV2) (err error) { + t.SetUpdated(t.CreatedBy) + if svc.cfgSvc.IsMaster() { + err = service.NewModelServiceV2[models2.TaskV2]().ReplaceById(t.Id, *t) + } else { + err = client.NewModelServiceV2[models2.TaskV2]().ReplaceById(t.Id, *t) + } + if err != nil { + return err + } + return nil +} + func (svc *Service) GetSpiderById(id primitive.ObjectID) (s *models2.SpiderV2, err error) { if svc.cfgSvc.IsMaster() { s, err = service.NewModelServiceV2[models2.SpiderV2]().GetById(id) @@ -194,12 +207,14 @@ func (svc *Service) GetSpiderById(id primitive.ObjectID) (s *models2.SpiderV2, e func (svc *Service) getRunnerCount() (count int) { n, err := svc.GetCurrentNode() if err != nil { - trace.PrintError(err) + log.Errorf("failed to get current node: %v", err) return } query := bson.M{ "node_id": n.Id, - "status": constants.TaskStatusRunning, + "status": bson.M{ + "$in": []string{constants.TaskStatusAssigned, constants.TaskStatusRunning}, + }, } if svc.cfgSvc.IsMaster() { count, err = service.NewModelServiceV2[models2.TaskV2]().Count(query) @@ -242,7 +257,7 @@ func (svc *Service) deleteRunner(taskId primitive.ObjectID) { svc.runners.Delete(taskId) } -func (svc *Service) reportStatus() (err error) { +func (svc *Service) updateNodeStatus() (err error) { // current node n, err := svc.GetCurrentNode() if err != nil { @@ -398,8 +413,19 @@ func (svc *Service) handleCancel(msg *grpc.TaskServiceSubscribeResponse, taskId log.Errorf("task[%s] failed to cancel: %v", taskId.Hex(), err) return } - log.Infof("task[%s] cancelled", taskId.Hex()) + + // set task status as "cancelled" + t, err := svc.GetTaskById(taskId) + if err != nil { + log.Errorf("task[%s] failed to get task: %v", taskId.Hex(), err) + return + } + t.Status = constants.TaskStatusCancelled + err = svc.UpdateTask(t) + if err != nil { + log.Errorf("task[%s] failed to update task: %v", taskId.Hex(), err) + } } func (svc *Service) cancelTask(taskId primitive.ObjectID, force bool) (err error) { diff --git a/core/task/scheduler/service_v2.go b/core/task/scheduler/service.go similarity index 75% rename from core/task/scheduler/service_v2.go rename to core/task/scheduler/service.go index a97760d9..c739c870 100644 --- a/core/task/scheduler/service_v2.go +++ b/core/task/scheduler/service.go @@ -21,7 +21,7 @@ import ( "time" ) -type ServiceV2 struct { +type Service struct { // dependencies nodeCfgSvc interfaces.NodeConfigService svr *server.GrpcServer @@ -31,13 +31,13 @@ type ServiceV2 struct { interval time.Duration } -func (svc *ServiceV2) Start() { +func (svc *Service) Start() { go svc.initTaskStatus() go svc.cleanupTasks() utils.DefaultWait() } -func (svc *ServiceV2) Enqueue(t *models2.TaskV2, by primitive.ObjectID) (t2 *models2.TaskV2, err error) { +func (svc *Service) Enqueue(t *models2.TaskV2, by primitive.ObjectID) (t2 *models2.TaskV2, err error) { // set task status t.Status = constants.TaskStatusPending t.SetCreated(by) @@ -45,32 +45,17 @@ func (svc *ServiceV2) Enqueue(t *models2.TaskV2, by primitive.ObjectID) (t2 *mod // add task taskModelSvc := service.NewModelServiceV2[models2.TaskV2]() - id, err := taskModelSvc.InsertOne(*t) + t.Id, err = taskModelSvc.InsertOne(*t) if err != nil { return nil, err } - // task queue item - tq := models2.TaskQueueItemV2{ - Priority: t.Priority, - NodeId: t.NodeId, - } - tq.SetId(id) - tq.SetCreated(by) - tq.SetUpdated(by) - // task stat ts := models2.TaskStatV2{} - ts.SetId(id) + ts.SetId(t.Id) ts.SetCreated(by) ts.SetUpdated(by) - // enqueue task - _, err = service.NewModelServiceV2[models2.TaskQueueItemV2]().InsertOne(tq) - if err != nil { - return nil, trace.TraceError(err) - } - // add task stat _, err = service.NewModelServiceV2[models2.TaskStatV2]().InsertOne(ts) if err != nil { @@ -81,7 +66,7 @@ func (svc *ServiceV2) Enqueue(t *models2.TaskV2, by primitive.ObjectID) (t2 *mod return t, nil } -func (svc *ServiceV2) Cancel(id, by primitive.ObjectID, force bool) (err error) { +func (svc *Service) Cancel(id, by primitive.ObjectID, force bool) (err error) { // task t, err := service.NewModelServiceV2[models2.TaskV2]().GetById(id) if err != nil { @@ -92,14 +77,10 @@ func (svc *ServiceV2) Cancel(id, by primitive.ObjectID, force bool) (err error) // initial status initialStatus := t.Status - // set status of pending tasks as "cancelled" and remove from task item queue + // set status of pending tasks as "cancelled" if initialStatus == constants.TaskStatusPending { - // remove from task item queue - if err := service.NewModelServiceV2[models2.TaskQueueItemV2]().DeleteById(t.Id); err != nil { - log.Errorf("failed to delete task queue item: %s", t.Id.Hex()) - return err - } - return nil + t.Status = constants.TaskStatusCancelled + return svc.SaveTask(t, by) } // whether task is running on master node @@ -120,7 +101,7 @@ func (svc *ServiceV2) Cancel(id, by primitive.ObjectID, force bool) (err error) } } -func (svc *ServiceV2) cancelOnMaster(t *models2.TaskV2, by primitive.ObjectID, force bool) (err error) { +func (svc *Service) cancelOnMaster(t *models2.TaskV2, by primitive.ObjectID, force bool) (err error) { if err := svc.handlerSvc.Cancel(t.Id, force); err != nil { log.Errorf("failed to cancel task on master: %s", t.Id.Hex()) return err @@ -131,7 +112,7 @@ func (svc *ServiceV2) cancelOnMaster(t *models2.TaskV2, by primitive.ObjectID, f return svc.SaveTask(t, by) } -func (svc *ServiceV2) cancelOnWorker(t *models2.TaskV2, by primitive.ObjectID, force bool) (err error) { +func (svc *Service) cancelOnWorker(t *models2.TaskV2, by primitive.ObjectID, force bool) (err error) { // get subscribe stream stream, ok := svc.svr.TaskSvr.GetSubscribeStream(t.Id) if !ok { @@ -156,11 +137,11 @@ func (svc *ServiceV2) cancelOnWorker(t *models2.TaskV2, by primitive.ObjectID, f return nil } -func (svc *ServiceV2) SetInterval(interval time.Duration) { +func (svc *Service) SetInterval(interval time.Duration) { svc.interval = interval } -func (svc *ServiceV2) SaveTask(t *models2.TaskV2, by primitive.ObjectID) (err error) { +func (svc *Service) SaveTask(t *models2.TaskV2, by primitive.ObjectID) (err error) { if t.Id.IsZero() { t.SetCreated(by) t.SetUpdated(by) @@ -173,12 +154,13 @@ func (svc *ServiceV2) SaveTask(t *models2.TaskV2, by primitive.ObjectID) (err er } // initTaskStatus initialize task status of existing tasks -func (svc *ServiceV2) initTaskStatus() { +func (svc *Service) initTaskStatus() { // set status of running tasks as TaskStatusAbnormal runningTasks, err := service.NewModelServiceV2[models2.TaskV2]().GetMany(bson.M{ "status": bson.M{ "$in": []string{ constants.TaskStatusPending, + constants.TaskStatusAssigned, constants.TaskStatusRunning, }, }, @@ -187,7 +169,8 @@ func (svc *ServiceV2) initTaskStatus() { if errors2.Is(err, mongo2.ErrNoDocuments) { return } - trace.PrintError(err) + log.Errorf("failed to get running tasks: %v", err) + return } for _, t := range runningTasks { go func(t *models2.TaskV2) { @@ -197,12 +180,9 @@ func (svc *ServiceV2) initTaskStatus() { } }(&t) } - if err := service.NewModelServiceV2[models2.TaskQueueItemV2]().DeleteMany(nil); err != nil { - return - } } -func (svc *ServiceV2) isMasterNode(t *models2.TaskV2) (ok bool, err error) { +func (svc *Service) isMasterNode(t *models2.TaskV2) (ok bool, err error) { if t.NodeId.IsZero() { return false, trace.TraceError(errors.ErrorTaskNoNodeId) } @@ -216,7 +196,7 @@ func (svc *ServiceV2) isMasterNode(t *models2.TaskV2) (ok bool, err error) { return n.IsMaster, nil } -func (svc *ServiceV2) cleanupTasks() { +func (svc *Service) cleanupTasks() { for { // task stats over 30 days ago taskStats, err := service.NewModelServiceV2[models2.TaskStatV2]().GetMany(bson.M{ @@ -255,9 +235,9 @@ func (svc *ServiceV2) cleanupTasks() { } } -func NewTaskSchedulerServiceV2() (svc2 *ServiceV2, err error) { +func NewTaskSchedulerService() (svc2 *Service, err error) { // service - svc := &ServiceV2{ + svc := &Service{ interval: 5 * time.Second, } svc.nodeCfgSvc = nodeconfig.GetNodeConfigService() @@ -275,15 +255,15 @@ func NewTaskSchedulerServiceV2() (svc2 *ServiceV2, err error) { return svc, nil } -var svcV2 *ServiceV2 +var svc *Service -func GetTaskSchedulerServiceV2() (svr *ServiceV2, err error) { - if svcV2 != nil { - return svcV2, nil +func GetTaskSchedulerService() (svr *Service, err error) { + if svc != nil { + return svc, nil } - svcV2, err = NewTaskSchedulerServiceV2() + svc, err = NewTaskSchedulerService() if err != nil { return nil, err } - return svcV2, nil + return svc, nil }