refactor: code cleanup

This commit is contained in:
Marvin Zhang
2024-10-31 12:59:58 +08:00
parent 1c67ad2889
commit 53afb0064e
97 changed files with 388 additions and 1561 deletions

View File

@@ -35,8 +35,8 @@ import (
type RunnerV2 struct {
// dependencies
svc *ServiceV2 // task handler service
fsSvc interfaces.FsServiceV2 // task fs service
svc *Service // task handler service
fsSvc interfaces.FsService // task fs service
// settings
subscribeTimeout time.Duration
@@ -51,7 +51,7 @@ type RunnerV2 struct {
ch chan constants.TaskSignal // channel to communicate between Service and RunnerV2
err error // standard process error
cwd string // working directory
c *client2.GrpcClientV2 // grpc client
c *client2.GrpcClient // grpc client
conn grpc.TaskService_ConnectClient // grpc task service stream client
// log internals
@@ -668,7 +668,7 @@ func (r *RunnerV2) configureCwd() {
}
}
func NewTaskRunnerV2(id primitive.ObjectID, svc *ServiceV2) (r2 *RunnerV2, err error) {
func NewTaskRunnerV2(id primitive.ObjectID, svc *Service) (r2 *RunnerV2, err error) {
// validate options
if id.IsZero() {
return nil, constants.ErrInvalidOptions
@@ -700,7 +700,7 @@ func NewTaskRunnerV2(id primitive.ObjectID, svc *ServiceV2) (r2 *RunnerV2, err e
r.fsSvc = fs.NewFsServiceV2(filepath.Join(viper.GetString("workspace"), r.s.Id.Hex()))
// grpc client
r.c = client2.GetGrpcClientV2()
r.c = client2.GetGrpcClient()
// initialize task runner
if err := r.Init(); err != nil {

View File

@@ -22,10 +22,10 @@ import (
"time"
)
type ServiceV2 struct {
type Service struct {
// dependencies
cfgSvc interfaces.NodeConfigService
c *grpcclient.GrpcClientV2 // grpc client
c *grpcclient.GrpcClient // grpc client
// settings
//maxRunners int
@@ -42,7 +42,7 @@ type ServiceV2 struct {
syncLocks sync.Map // files sync locks map of task runners
}
func (svc *ServiceV2) Start() {
func (svc *Service) Start() {
// Initialize gRPC if not started
if !svc.c.IsStarted() {
err := svc.c.Start()
@@ -55,26 +55,19 @@ func (svc *ServiceV2) Start() {
go svc.FetchAndRunTasks()
}
func (svc *ServiceV2) Stop() {
func (svc *Service) Stop() {
svc.stopped = true
}
func (svc *ServiceV2) Run(taskId primitive.ObjectID) (err error) {
func (svc *Service) Run(taskId primitive.ObjectID) (err error) {
return svc.runTask(taskId)
}
func (svc *ServiceV2) Cancel(taskId primitive.ObjectID, force bool) (err error) {
r, err := svc.getRunner(taskId)
if err != nil {
return err
}
if err := r.Cancel(force); err != nil {
return err
}
return nil
func (svc *Service) Cancel(taskId primitive.ObjectID, force bool) (err error) {
return svc.cancelTask(taskId, force)
}
func (svc *ServiceV2) FetchAndRunTasks() {
func (svc *Service) FetchAndRunTasks() {
ticker := time.NewTicker(svc.fetchInterval)
for {
if svc.stopped {
@@ -126,7 +119,7 @@ func (svc *ServiceV2) FetchAndRunTasks() {
}
}
func (svc *ServiceV2) ReportStatus() {
func (svc *Service) ReportStatus() {
ticker := time.NewTicker(svc.reportInterval)
for {
if svc.stopped {
@@ -143,19 +136,19 @@ func (svc *ServiceV2) ReportStatus() {
}
}
func (svc *ServiceV2) GetExitWatchDuration() (duration time.Duration) {
func (svc *Service) GetExitWatchDuration() (duration time.Duration) {
return svc.exitWatchDuration
}
func (svc *ServiceV2) GetCancelTimeout() (timeout time.Duration) {
func (svc *Service) GetCancelTimeout() (timeout time.Duration) {
return svc.cancelTimeout
}
func (svc *ServiceV2) GetNodeConfigService() (cfgSvc interfaces.NodeConfigService) {
func (svc *Service) GetNodeConfigService() (cfgSvc interfaces.NodeConfigService) {
return svc.cfgSvc
}
func (svc *ServiceV2) GetCurrentNode() (n *models2.NodeV2, err error) {
func (svc *Service) GetCurrentNode() (n *models2.NodeV2, err error) {
// node key
nodeKey := svc.cfgSvc.GetNodeKey()
@@ -172,7 +165,7 @@ func (svc *ServiceV2) GetCurrentNode() (n *models2.NodeV2, err error) {
return n, nil
}
func (svc *ServiceV2) GetTaskById(id primitive.ObjectID) (t *models2.TaskV2, err error) {
func (svc *Service) GetTaskById(id primitive.ObjectID) (t *models2.TaskV2, err error) {
if svc.cfgSvc.IsMaster() {
t, err = service.NewModelServiceV2[models2.TaskV2]().GetById(id)
} else {
@@ -185,7 +178,7 @@ func (svc *ServiceV2) GetTaskById(id primitive.ObjectID) (t *models2.TaskV2, err
return t, nil
}
func (svc *ServiceV2) GetSpiderById(id primitive.ObjectID) (s *models2.SpiderV2, err error) {
func (svc *Service) GetSpiderById(id primitive.ObjectID) (s *models2.SpiderV2, err error) {
if svc.cfgSvc.IsMaster() {
s, err = service.NewModelServiceV2[models2.SpiderV2]().GetById(id)
} else {
@@ -198,7 +191,7 @@ func (svc *ServiceV2) GetSpiderById(id primitive.ObjectID) (s *models2.SpiderV2,
return s, nil
}
func (svc *ServiceV2) getRunnerCount() (count int) {
func (svc *Service) getRunnerCount() (count int) {
n, err := svc.GetCurrentNode()
if err != nil {
trace.PrintError(err)
@@ -224,7 +217,7 @@ func (svc *ServiceV2) getRunnerCount() (count int) {
return count
}
func (svc *ServiceV2) getRunner(taskId primitive.ObjectID) (r interfaces.TaskRunner, err error) {
func (svc *Service) getRunner(taskId primitive.ObjectID) (r interfaces.TaskRunner, err error) {
log.Debugf("[TaskHandlerService] getRunner: taskId[%v]", taskId)
v, ok := svc.runners.Load(taskId)
if !ok {
@@ -239,17 +232,17 @@ func (svc *ServiceV2) getRunner(taskId primitive.ObjectID) (r interfaces.TaskRun
return r, nil
}
func (svc *ServiceV2) addRunner(taskId primitive.ObjectID, r interfaces.TaskRunner) {
func (svc *Service) addRunner(taskId primitive.ObjectID, r interfaces.TaskRunner) {
log.Debugf("[TaskHandlerService] addRunner: taskId[%v]", taskId)
svc.runners.Store(taskId, r)
}
func (svc *ServiceV2) deleteRunner(taskId primitive.ObjectID) {
func (svc *Service) deleteRunner(taskId primitive.ObjectID) {
log.Debugf("[TaskHandlerService] deleteRunner: taskId[%v]", taskId)
svc.runners.Delete(taskId)
}
func (svc *ServiceV2) reportStatus() (err error) {
func (svc *Service) reportStatus() (err error) {
// current node
n, err := svc.GetCurrentNode()
if err != nil {
@@ -276,10 +269,12 @@ func (svc *ServiceV2) reportStatus() (err error) {
return nil
}
func (svc *ServiceV2) fetchTask() (tid primitive.ObjectID, err error) {
func (svc *Service) fetchTask() (tid primitive.ObjectID, err error) {
ctx, cancel := context.WithTimeout(context.Background(), svc.fetchTimeout)
defer cancel()
res, err := svc.c.TaskClient.FetchTask(ctx, svc.c.NewRequest(nil))
res, err := svc.c.TaskClient.FetchTask(ctx, &grpc.TaskServiceFetchTaskRequest{
NodeKey: svc.cfgSvc.GetNodeKey(),
})
if err != nil {
return primitive.NilObjectID, fmt.Errorf("fetchTask task error: %v", err)
}
@@ -291,7 +286,7 @@ func (svc *ServiceV2) fetchTask() (tid primitive.ObjectID, err error) {
return tid, nil
}
func (svc *ServiceV2) runTask(taskId primitive.ObjectID) (err error) {
func (svc *Service) runTask(taskId primitive.ObjectID) (err error) {
// attempt to get runner from pool
_, ok := svc.runners.Load(taskId)
if ok {
@@ -347,7 +342,7 @@ func (svc *ServiceV2) runTask(taskId primitive.ObjectID) (err error) {
return nil
}
func (svc *ServiceV2) subscribeTask(taskId primitive.ObjectID) (stream grpc.TaskService_SubscribeClient, err error) {
func (svc *Service) subscribeTask(taskId primitive.ObjectID) (stream grpc.TaskService_SubscribeClient, err error) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
req := &grpc.TaskServiceSubscribeRequest{
@@ -361,13 +356,13 @@ func (svc *ServiceV2) subscribeTask(taskId primitive.ObjectID) (stream grpc.Task
return stream, nil
}
func (svc *ServiceV2) handleStreamMessages(id primitive.ObjectID, stream grpc.TaskService_SubscribeClient, stopCh chan struct{}) {
func (svc *Service) handleStreamMessages(taskId primitive.ObjectID, stream grpc.TaskService_SubscribeClient, stopCh chan struct{}) {
for {
select {
case <-stopCh:
err := stream.CloseSend()
if err != nil {
log.Errorf("task[%s] failed to close stream: %v", id.Hex(), err)
log.Errorf("task[%s] failed to close stream: %v", taskId.Hex(), err)
return
}
return
@@ -375,28 +370,52 @@ func (svc *ServiceV2) handleStreamMessages(id primitive.ObjectID, stream grpc.Ta
msg, err := stream.Recv()
if err != nil {
if errors.Is(err, io.EOF) {
log.Infof("task[%s] received EOF, stream closed", taskId.Hex())
return
}
log.Errorf("task[%s] stream error: %v", id.Hex(), err)
log.Errorf("task[%s] stream error: %v", taskId.Hex(), err)
continue
}
switch msg.Code {
case grpc.TaskServiceSubscribeCode_CANCEL:
log.Infof("task[%s] received cancel signal", id.Hex())
go func() {
if err := svc.Cancel(id, true); err != nil {
log.Errorf("task[%s] failed to cancel: %v", id.Hex(), err)
}
log.Infof("task[%s] cancelled", id.Hex())
}()
log.Infof("task[%s] received cancel signal", taskId.Hex())
go svc.handleCancel(msg, taskId)
}
}
}
}
func newTaskHandlerServiceV2() (svc2 *ServiceV2, err error) {
// handleCancel processes a CANCEL message received on the subscribe stream of
// the task identified by taskId.
//
// A message whose TaskId does not match taskId is logged and ignored. Otherwise
// the task is cancelled with the force flag carried by the message, and the
// outcome is logged.
func (svc *Service) handleCancel(msg *grpc.TaskServiceSubscribeResponse, taskId primitive.ObjectID) {
	tid := taskId.Hex()

	// ignore signals addressed to a different task
	if msg.TaskId != tid {
		log.Errorf("task[%s] received cancel signal for another task[%s]", tid, msg.TaskId)
		return
	}

	// cancel the task and report the outcome
	if err := svc.cancelTask(taskId, msg.Force); err != nil {
		log.Errorf("task[%s] failed to cancel: %v", tid, err)
		return
	}
	log.Infof("task[%s] cancelled", tid)
}
// cancelTask looks up the runner registered for taskId and asks it to cancel,
// passing force through unchanged. It returns the lookup error if no runner
// can be obtained, otherwise whatever the runner's Cancel returns.
func (svc *Service) cancelTask(taskId primitive.ObjectID, force bool) (err error) {
	runner, err := svc.getRunner(taskId)
	if err != nil {
		return err
	}
	return runner.Cancel(force)
}
func newTaskHandlerService() (svc2 *Service, err error) {
// service
svc := &ServiceV2{
svc := &Service{
exitWatchDuration: 60 * time.Second,
fetchInterval: 1 * time.Second,
fetchTimeout: 15 * time.Second,
@@ -410,22 +429,22 @@ func newTaskHandlerServiceV2() (svc2 *ServiceV2, err error) {
svc.cfgSvc = nodeconfig.GetNodeConfigService()
// grpc client
svc.c = grpcclient.GetGrpcClientV2()
svc.c = grpcclient.GetGrpcClient()
log.Debugf("[NewTaskHandlerService] svc[cfgPath: %s]", svc.cfgSvc.GetConfigPath())
return svc, nil
}
var _serviceV2 *ServiceV2
var _serviceV2 *Service
var _serviceV2Once = new(sync.Once)
func GetTaskHandlerServiceV2() (svr *ServiceV2, err error) {
func GetTaskHandlerService() (svr *Service, err error) {
if _serviceV2 != nil {
return _serviceV2, nil
}
_serviceV2Once.Do(func() {
_serviceV2, err = newTaskHandlerServiceV2()
_serviceV2, err = newTaskHandlerService()
if err != nil {
log.Errorf("failed to create task handler service: %v", err)
}

View File

@@ -2,6 +2,7 @@ package scheduler
import (
errors2 "errors"
"fmt"
"github.com/apex/log"
"github.com/crawlab-team/crawlab/core/constants"
"github.com/crawlab-team/crawlab/core/errors"
@@ -24,7 +25,7 @@ type ServiceV2 struct {
// dependencies
nodeCfgSvc interfaces.NodeConfigService
svr *server.GrpcServer
handlerSvc *handler.ServiceV2
handlerSvc *handler.Service
// settings
interval time.Duration
@@ -80,25 +81,23 @@ func (svc *ServiceV2) Enqueue(t *models2.TaskV2, by primitive.ObjectID) (t2 *mod
return t, nil
}
func (svc *ServiceV2) Cancel(id primitive.ObjectID, by primitive.ObjectID, force bool) (err error) {
func (svc *ServiceV2) Cancel(id, by primitive.ObjectID, force bool) (err error) {
// task
t, err := service.NewModelServiceV2[models2.TaskV2]().GetById(id)
if err != nil {
return trace.TraceError(err)
log.Errorf("task not found: %s", id.Hex())
return err
}
// initial status
initialStatus := t.Status
// set task status as "cancelled"
t.Status = constants.TaskStatusCancelled
_ = svc.SaveTask(t, by)
// set status of pending tasks as "cancelled" and remove from task item queue
if initialStatus == constants.TaskStatusPending {
// remove from task item queue
if err := service.NewModelServiceV2[models2.TaskQueueItemV2]().DeleteById(t.Id); err != nil {
return trace.TraceError(err)
log.Errorf("failed to delete task queue item: %s", t.Id.Hex())
return err
}
return nil
}
@@ -106,34 +105,57 @@ func (svc *ServiceV2) Cancel(id primitive.ObjectID, by primitive.ObjectID, force
// whether task is running on master node
isMasterTask, err := svc.isMasterNode(t)
if err != nil {
// when error, force status being set as "cancelled"
t.Status = constants.TaskStatusCancelled
err := fmt.Errorf("failed to check if task is running on master node: %s", t.Id.Hex())
t.Status = constants.TaskStatusAbnormal
t.Error = err.Error()
return svc.SaveTask(t, by)
}
// node
n, err := service.NewModelServiceV2[models2.NodeV2]().GetById(t.NodeId)
if err != nil {
return trace.TraceError(err)
}
if isMasterTask {
// cancel task on master
if err := svc.handlerSvc.Cancel(id, force); err != nil {
return trace.TraceError(err)
}
// cancel success
return nil
return svc.cancelOnMaster(t, by, force)
} else {
// send to cancel task on worker nodes
if err := svc.svr.SendStreamMessageWithData("node:"+n.Key, grpc.StreamMessageCode_CANCEL_TASK, t); err != nil {
return trace.TraceError(err)
}
// cancel success
return nil
return svc.cancelOnWorker(t, by, force)
}
}
// cancelOnMaster cancels task t through the local task handler service and, if
// that succeeds, stores the task with status "cancelled" on behalf of user by.
// A handler failure is logged and returned without touching the task status.
func (svc *ServiceV2) cancelOnMaster(t *models2.TaskV2, by primitive.ObjectID, force bool) (err error) {
	err = svc.handlerSvc.Cancel(t.Id, force)
	if err != nil {
		log.Errorf("failed to cancel task on master: %s", t.Id.Hex())
		return err
	}

	// the handler accepted the cancellation; persist the final status
	t.Status = constants.TaskStatusCancelled
	return svc.SaveTask(t, by)
}
// cancelOnWorker requests cancellation of task t by sending a CANCEL message,
// carrying the task id and the force flag, over the task's subscribe stream.
//
// If no subscribe stream is registered for the task, the task is marked as
// abnormal with the reason stored in t.Error, and the result of saving it is
// returned. A failure to send the message is logged and returned to the caller.
func (svc *ServiceV2) cancelOnWorker(t *models2.TaskV2, by primitive.ObjectID, force bool) (err error) {
	// get subscribe stream
	stream, ok := svc.svr.TaskSvr.GetSubscribeStream(t.Id)
	if !ok {
		err := fmt.Errorf("stream not found for task: %s", t.Id.Hex())
		// log the message verbatim: a dynamic string must never be used as
		// the format argument of a printf-style function
		log.Error(err.Error())
		t.Status = constants.TaskStatusAbnormal
		t.Error = err.Error()
		return svc.SaveTask(t, by)
	}

	// send cancel request
	err = stream.Send(&grpc.TaskServiceSubscribeResponse{
		Code:   grpc.TaskServiceSubscribeCode_CANCEL,
		TaskId: t.Id.Hex(),
		Force:  force,
	})
	if err != nil {
		log.Errorf("failed to send cancel request to worker: %s", t.Id.Hex())
		return err
	}
	return nil
}
// SetInterval replaces the interval setting stored on the scheduler service.
func (svc *ServiceV2) SetInterval(interval time.Duration) {
svc.interval = interval
}
@@ -244,7 +266,7 @@ func NewTaskSchedulerServiceV2() (svc2 *ServiceV2, err error) {
log.Errorf("failed to get grpc server: %v", err)
return nil, err
}
svc.handlerSvc, err = handler.GetTaskHandlerServiceV2()
svc.handlerSvc, err = handler.GetTaskHandlerService()
if err != nil {
log.Errorf("failed to get task handler service: %v", err)
return nil, err

View File

@@ -28,7 +28,7 @@ type databaseServiceItem struct {
time time.Time
}
type ServiceV2 struct {
type Service struct {
// dependencies
nodeCfgSvc interfaces.NodeConfigService
@@ -39,12 +39,12 @@ type ServiceV2 struct {
logDriver log.Driver
}
func (svc *ServiceV2) Init() (err error) {
func (svc *Service) Init() (err error) {
go svc.cleanup()
return nil
}
func (svc *ServiceV2) InsertData(taskId primitive.ObjectID, records ...map[string]interface{}) (err error) {
func (svc *Service) InsertData(taskId primitive.ObjectID, records ...map[string]interface{}) (err error) {
count := 0
item, err := svc.getDatabaseServiceItem(taskId)
@@ -80,11 +80,11 @@ func (svc *ServiceV2) InsertData(taskId primitive.ObjectID, records ...map[strin
return nil
}
func (svc *ServiceV2) InsertLogs(id primitive.ObjectID, logs ...string) (err error) {
func (svc *Service) InsertLogs(id primitive.ObjectID, logs ...string) (err error) {
return svc.logDriver.WriteLines(id.Hex(), logs)
}
func (svc *ServiceV2) getDatabaseServiceItem(taskId primitive.ObjectID) (item *databaseServiceItem, err error) {
func (svc *Service) getDatabaseServiceItem(taskId primitive.ObjectID) (item *databaseServiceItem, err error) {
// atomic operation
svc.mu.Lock()
defer svc.mu.Unlock()
@@ -136,7 +136,7 @@ func (svc *ServiceV2) getDatabaseServiceItem(taskId primitive.ObjectID) (item *d
return item, nil
}
func (svc *ServiceV2) updateTaskStats(id primitive.ObjectID, resultCount int) {
func (svc *Service) updateTaskStats(id primitive.ObjectID, resultCount int) {
err := service.NewModelServiceV2[models2.TaskStatV2]().UpdateById(id, bson.M{
"$inc": bson.M{
"result_count": resultCount,
@@ -147,7 +147,7 @@ func (svc *ServiceV2) updateTaskStats(id primitive.ObjectID, resultCount int) {
}
}
func (svc *ServiceV2) cleanup() {
func (svc *Service) cleanup() {
for {
// atomic operation
svc.mu.Lock()
@@ -164,7 +164,7 @@ func (svc *ServiceV2) cleanup() {
}
}
func (svc *ServiceV2) normalizeRecord(item *databaseServiceItem, record map[string]interface{}) (res map[string]interface{}) {
func (svc *Service) normalizeRecord(item *databaseServiceItem, record map[string]interface{}) (res map[string]interface{}) {
res = record
// set task id
@@ -176,9 +176,9 @@ func (svc *ServiceV2) normalizeRecord(item *databaseServiceItem, record map[stri
return res
}
func NewTaskStatsServiceV2() (svc2 *ServiceV2, err error) {
func NewTaskStatsServiceV2() (svc2 *Service, err error) {
// service
svc := &ServiceV2{
svc := &Service{
mu: sync.Mutex{},
databaseServiceItems: map[string]*databaseServiceItem{},
databaseServiceTll: 10 * time.Minute,
@@ -195,9 +195,9 @@ func NewTaskStatsServiceV2() (svc2 *ServiceV2, err error) {
return svc, nil
}
var _serviceV2 *ServiceV2
var _serviceV2 *Service
func GetTaskStatsServiceV2() (svr *ServiceV2, err error) {
func GetTaskStatsServiceV2() (svr *Service, err error) {
if _serviceV2 != nil {
return _serviceV2, nil
}