diff --git a/core/entity/notification_variable.go b/core/entity/notification_variable.go index ad00c02f..631bcc95 100644 --- a/core/entity/notification_variable.go +++ b/core/entity/notification_variable.go @@ -1,6 +1,12 @@ package entity +import "fmt" + type NotificationVariable struct { Category string `json:"category"` Name string `json:"name"` } + +func (v *NotificationVariable) GetKey() string { + return fmt.Sprintf("${%s:%s}", v.Category, v.Name) +} diff --git a/core/grpc/server/task_server_v2.go b/core/grpc/server/task_server_v2.go index 51ccc198..9cc28ed9 100644 --- a/core/grpc/server/task_server_v2.go +++ b/core/grpc/server/task_server_v2.go @@ -109,29 +109,57 @@ func (svr TaskServerV2) Fetch(ctx context.Context, request *grpc.Request) (respo return HandleSuccessWithData(tid) } -func (svr TaskServerV2) SendNotification(_ context.Context, request *grpc.Request) (response *grpc.Response, err error) { - // task - var t = new(models2.TaskV2) - if err := json.Unmarshal(request.Data, t); err != nil { - return nil, trace.TraceError(err) - } - t, err = service.NewModelServiceV2[models2.TaskV2]().GetById(t.Id) +func (svr TaskServerV2) SendNotification(_ context.Context, request *grpc.TaskServiceSendNotificationRequest) (response *grpc.Response, err error) { + // task id + taskId, err := primitive.ObjectIDFromHex(request.TaskId) if err != nil { + log.Errorf("invalid task id: %s", request.TaskId) return nil, trace.TraceError(err) } - // serialize task data - td, err := json.Marshal(t) + // arguments + var args []any + + // task + task, err := service.NewModelServiceV2[models2.TaskV2]().GetById(taskId) + if err != nil { + log.Errorf("task not found: %s", request.TaskId) + return nil, trace.TraceError(err) + } + args = append(args, task) + + // task stat + taskStat, err := service.NewModelServiceV2[models2.TaskStatV2]().GetById(task.Id) + if err != nil { + log.Errorf("task stat not found for task: %s", request.TaskId) + return nil, trace.TraceError(err) + } + args = append(args, taskStat) + + // spider + spider, err := service.NewModelServiceV2[models2.SpiderV2]().GetById(task.SpiderId) + if err != nil { + log.Errorf("spider not found for task: %s", request.TaskId) + return nil, trace.TraceError(err) + } + args = append(args, spider) + + // node + node, err := service.NewModelServiceV2[models2.NodeV2]().GetById(task.NodeId) if err != nil { return nil, trace.TraceError(err) } - var e bson.M - if err := json.Unmarshal(td, &e); err != nil { - return nil, trace.TraceError(err) - } - ts, err := service.NewModelServiceV2[models2.TaskStatV2]().GetById(t.Id) - if err != nil { - return nil, trace.TraceError(err) + args = append(args, node) + + // schedule + var schedule *models2.ScheduleV2 + if !task.ScheduleId.IsZero() { + schedule, err = service.NewModelServiceV2[models2.ScheduleV2]().GetById(task.ScheduleId) + if err != nil { + log.Errorf("schedule not found for task: %s", request.TaskId) + return nil, trace.TraceError(err) + } + args = append(args, schedule) } // settings @@ -156,17 +184,17 @@ func (svr TaskServerV2) SendNotification(_ context.Context, request *grpc.Reques // send notification switch trigger { case constants.NotificationTriggerTaskFinish: - if t.Status != constants.TaskStatusPending && t.Status != constants.TaskStatusRunning { - _ = svc.Send(&s, e) + if task.Status != constants.TaskStatusPending && task.Status != constants.TaskStatusRunning { + go svc.Send(&s, args...) } case constants.NotificationTriggerTaskError: - if t.Status == constants.TaskStatusError || t.Status == constants.TaskStatusAbnormal { - _ = svc.Send(&s, e) + if task.Status == constants.TaskStatusError || task.Status == constants.TaskStatusAbnormal { + go svc.Send(&s, args...) } case constants.NotificationTriggerTaskEmptyResults: - if t.Status != constants.TaskStatusPending && t.Status != constants.TaskStatusRunning { - if ts.ResultCount == 0 { - _ = svc.Send(&s, e) + if task.Status != constants.TaskStatusPending && task.Status != constants.TaskStatusRunning { + if taskStat.ResultCount == 0 { + go svc.Send(&s, args...) } } } diff --git a/core/models/models/v2/notification_setting_v2.go b/core/models/models/v2/notification_setting_v2.go index b56ed3a7..7164a505 100644 --- a/core/models/models/v2/notification_setting_v2.go +++ b/core/models/models/v2/notification_setting_v2.go @@ -14,6 +14,7 @@ type NotificationSettingV2 struct { TemplateMarkdown string `json:"template_markdown,omitempty" bson:"template_markdown,omitempty"` TemplateRichText string `json:"template_rich_text,omitempty" bson:"template_rich_text,omitempty"` TemplateRichTextJson string `json:"template_rich_text_json,omitempty" bson:"template_rich_text_json,omitempty"` + TemplateTheme string `json:"template_theme,omitempty" bson:"template_theme,omitempty"` TaskTrigger string `json:"task_trigger" bson:"task_trigger"` TriggerTarget string `json:"trigger_target" bson:"trigger_target"` Trigger string `json:"trigger" bson:"trigger"` diff --git a/core/notification/entity.go b/core/notification/entity.go new file mode 100644 index 00000000..40b69eec --- /dev/null +++ b/core/notification/entity.go @@ -0,0 +1,11 @@ +package notification + +import "github.com/crawlab-team/crawlab/core/models/models/v2" + +type VariableDataTask struct { + Task *models.TaskV2 `json:"task"` + TaskStat *models.TaskStatV2 `json:"task_stat"` + Spider *models.SpiderV2 `json:"spider"` + Node *models.NodeV2 `json:"node"` + Schedule *models.ScheduleV2 `json:"schedule"` +} diff --git a/core/notification/mail.go b/core/notification/mail.go index c2f63289..289381aa 100644 --- a/core/notification/mail.go +++ b/core/notification/mail.go @@ -2,30 +2,19 @@ package notification import ( "errors" + "github.com/PuerkitoBio/goquery" "github.com/apex/log" "github.com/crawlab-team/crawlab/core/models/models/v2" - "github.com/matcornic/hermes/v2" + "github.com/crawlab-team/crawlab/trace" "gopkg.in/gomail.v2" "net/mail" + "regexp" "runtime/debug" "strconv" "strings" ) func SendMail(m *models.NotificationSettingMail, to, cc, title, content string) error { - // theme - theme := new(MailThemeFlat) - - // hermes instance - h := hermes.Hermes{ - Theme: theme, - Product: hermes.Product{ - Logo: "", - Name: "Crawlab", - Copyright: "© 2024 Crawlab-Team", - }, - } - // config port, _ := strconv.Atoi(m.Port) password := m.Password @@ -44,38 +33,19 @@ func SendMail(m *models.NotificationSettingMail, to, cc, title, content string) Subject: title, } - // add style - content += theme.GetStyle() - - // markdown - markdown := hermes.Markdown(content + GetFooter()) - - // email instance - email := hermes.Email{ - Body: hermes.Body{ - Signature: "Happy Crawling ☺", - FreeMarkdown: markdown, - }, + // convert html to text + text := content + if isHtml(text) { + text = convertHtmlToText(text) } - // generate html - html, err := h.GenerateHTML(email) - if err != nil { - log.Errorf(err.Error()) - debug.PrintStack() - return err - } - - // generate text - text, err := h.GeneratePlainText(email) - if err != nil { - log.Errorf(err.Error()) - debug.PrintStack() - return err + // apply theme + if isHtml(content) { + content = GetTheme() + content } // send the email - if err := send(smtpConfig, options, html, text); err != nil { + if err := send(smtpConfig, options, content, text); err != nil { log.Errorf(err.Error()) debug.PrintStack() return err @@ -84,6 +54,21 @@ func SendMail(m *models.NotificationSettingMail, to, cc, title, content string) return nil } +func isHtml(content string) bool { + regex := regexp.MustCompile(`(?i)<\s*(html|head|body|div|span|p|a|img|table|tr|td|th|tbody|thead|tfoot|ul|ol|li|dl|dt|dd|form|input|textarea|button|select|option|optgroup|fieldset|legend|label|iframe|embed|object|param|video|audio|source|canvas|svg|math|style|link|script|meta|base|title|br|hr|b|strong|i|em|u|s|strike|del|ins|mark|small|sub|sup|big|pre|code|q|blockquote|abbr|address|bdo|cite|dfn|kbd|var|samp|ruby|rt|rp|time|progress|meter|output|area|map)`) + return regex.MatchString(content) +} + +func convertHtmlToText(content string) string { + doc, err := goquery.NewDocumentFromReader(strings.NewReader(content)) + if err != nil { + log.Errorf("failed to convert html to text: %v", err) + trace.PrintError(err) + return "" + } + return doc.Text() +} + type smtpAuthentication struct { Server string Port int @@ -170,9 +155,3 @@ func getRecipientList(value string) (values []string) { } return values } - -func GetFooter() string { - return ` -[Github](https://github.com/crawlab-team/crawlab) | [Documentation](http://docs.crawlab.cn) | [Docker](https://hub.docker.com/r/tikazyq/crawlab) -` -} diff --git a/core/notification/service_v2.go b/core/notification/service_v2.go index ed91f0bf..44118333 100644 --- a/core/notification/service_v2.go +++ b/core/notification/service_v2.go @@ -1,12 +1,14 @@ package notification import ( + "fmt" "github.com/apex/log" "github.com/crawlab-team/crawlab/core/constants" "github.com/crawlab-team/crawlab/core/entity" "github.com/crawlab-team/crawlab/core/models/models/v2" "github.com/crawlab-team/crawlab/core/models/service" "regexp" + "strings" "sync" ) @@ -166,82 +168,121 @@ Please find the task data as below. return nil } -func (svc *ServiceV2) Send(s *models.NotificationSettingV2, args ...any) (err error) { +func (svc *ServiceV2) Send(s *models.NotificationSettingV2, args ...any) { content := svc.getContent(s, args...) switch s.Type { case TypeMail: - return svc.SendMail(s, content) + svc.SendMail(s, content) case TypeMobile: - return svc.SendMobile(s, content) + svc.SendMobile(s, content) } - return nil } -func (svc *ServiceV2) SendMail(s *models.NotificationSettingV2, content string) (err error) { +func (svc *ServiceV2) SendMail(s *models.NotificationSettingV2, content string) { // TODO: parse to/cc/bcc // send mail - if err := SendMail(&s.Mail, s.Mail.To, s.Mail.Cc, s.Title, content); err != nil { - return err + err := SendMail(&s.Mail, s.Mail.To, s.Mail.Cc, s.Title, content) + if err != nil { + log.Errorf("[NotificationServiceV2] send mail error: %v", err) } - - return nil } -func (svc *ServiceV2) SendMobile(s *models.NotificationSettingV2, content string) (err error) { - // send - if err := SendMobileNotification(s.Mobile.Webhook, s.Title, content); err != nil { - return err +func (svc *ServiceV2) SendMobile(s *models.NotificationSettingV2, content string) { + err := SendMobileNotification(s.Mobile.Webhook, s.Title, content) + if err != nil { + log.Errorf("[NotificationServiceV2] send mobile notification error: %v", err) } - - return nil } func (svc *ServiceV2) getContent(s *models.NotificationSettingV2, args ...any) (content string) { switch s.TriggerTarget { case constants.NotificationTriggerTargetTask: - //task := new(models.TaskV2) - //taskStat := new(models.TaskStatV2) - //spider := new(models.SpiderV2) - //node := new(models.NodeV2) - //for _, arg := range args { - // switch arg.(type) { - // case models.TaskV2: - // task = arg.(*models.TaskV2) - // case models.TaskStatV2: - // taskStat = arg.(*models.TaskStatV2) - // case models.SpiderV2: - // spider = arg.(*models.SpiderV2) - // case models.NodeV2: - // node = arg.(*models.NodeV2) - // } - //} + vd := svc.getTaskVariableData(args...) switch s.TemplateMode { case constants.NotificationTemplateModeMarkdown: - // TODO: implement + variables := svc.parseTemplateVariables(s.TemplateMarkdown) + return svc.getTaskContent(s.TemplateMarkdown, variables, vd) case constants.NotificationTemplateModeRichText: - //s.TemplateRichText + variables := svc.parseTemplateVariables(s.TemplateRichText) + return svc.getTaskContent(s.TemplateRichText, variables, vd) } case constants.NotificationTriggerTargetNode: + // TODO: implement } return content } -func (svc *ServiceV2) parseTemplateVariables(s *models.NotificationSettingV2) (variables []entity.NotificationVariable) { +func (svc *ServiceV2) getTaskContent(template string, variables []entity.NotificationVariable, vd VariableDataTask) (content string) { + content = template + for _, v := range variables { + switch v.Category { + case "task": + switch v.Name { + case "id": + content = strings.ReplaceAll(content, v.GetKey(), vd.Task.Id.Hex()) + case "status": + content = strings.ReplaceAll(content, v.GetKey(), vd.Task.Status) + case "priority": + content = strings.ReplaceAll(content, v.GetKey(), fmt.Sprintf("%d", vd.Task.Priority)) + case "mode": + content = strings.ReplaceAll(content, v.GetKey(), vd.Task.Mode) + case "cmd": + content = strings.ReplaceAll(content, v.GetKey(), vd.Task.Cmd) + case "param": + content = strings.ReplaceAll(content, v.GetKey(), vd.Task.Param) + } + } + } + return content +} + +func (svc *ServiceV2) getTaskVariableData(args ...any) (vd VariableDataTask) { + for _, arg := range args { + switch arg.(type) { + case *models.TaskV2: + vd.Task = arg.(*models.TaskV2) + case *models.TaskStatV2: + vd.TaskStat = arg.(*models.TaskStatV2) + case *models.SpiderV2: + vd.Spider = arg.(*models.SpiderV2) + case *models.NodeV2: + vd.Node = arg.(*models.NodeV2) + case *models.ScheduleV2: + vd.Schedule = arg.(*models.ScheduleV2) + } + } + return vd +} + +func (svc *ServiceV2) parseTemplateVariables(template string) (variables []entity.NotificationVariable) { + // regex pattern regex := regexp.MustCompile("\\$\\{(\\w+):(\\w+)}") // find all matches - matches := regex.FindAllStringSubmatch(s.Template, -1) + matches := regex.FindAllStringSubmatch(template, -1) + + // variables map + variablesMap := make(map[string]entity.NotificationVariable) // iterate over matches for _, match := range matches { - variables = append(variables, entity.NotificationVariable{ + variable := entity.NotificationVariable{ Category: match[1], Name: match[2], - }) + } + key := fmt.Sprintf("%s:%s", variable.Category, variable.Name) + if _, ok := variablesMap[key]; !ok { + variablesMap[key] = variable + } + } + + // convert map to slice + for _, variable := range variablesMap { + variables = append(variables, variable) } return variables diff --git a/core/notification/service_v2_test.go b/core/notification/service_v2_test.go index 11c9ca7d..a2ca0b28 100644 --- a/core/notification/service_v2_test.go +++ b/core/notification/service_v2_test.go @@ -2,7 +2,6 @@ package notification import ( "github.com/crawlab-team/crawlab/core/entity" - "github.com/crawlab-team/crawlab/core/models/models/v2" "testing" "github.com/stretchr/testify/assert" @@ -16,11 +15,11 @@ func TestParseTemplateVariables_WithValidTemplate_ReturnsVariables(t *testing.T) {Category: "task", Name: "id"}, {Category: "task", Name: "status"}, } - setting := models.NotificationSettingV2{Template: template} - variables := svc.parseTemplateVariables(&setting) + variables := svc.parseTemplateVariables(template) - assert.Equal(t, expected, variables) + // contains all expected variables + assert.ElementsMatch(t, expected, variables) } func TestParseTemplateVariables_WithRepeatedVariables_ReturnsUniqueVariables(t *testing.T) { @@ -31,9 +30,9 @@ func TestParseTemplateVariables_WithRepeatedVariables_ReturnsUniqueVariables(t * {Category: "task", Name: "id"}, {Category: "task", Name: "status"}, } - setting := models.NotificationSettingV2{Template: template} - variables := svc.parseTemplateVariables(&setting) + variables := svc.parseTemplateVariables(template) - assert.Equal(t, expected, variables) + // contains all expected variables + assert.ElementsMatch(t, expected, variables) } diff --git a/core/notification/theme.go b/core/notification/theme.go new file mode 100644 index 00000000..f93fb510 --- /dev/null +++ b/core/notification/theme.go @@ -0,0 +1,265 @@ +package notification + +const defaultTheme = `` + +func GetTheme() string { + return defaultTheme +} diff --git a/core/task/handler/runner.go b/core/task/handler/runner.go deleted file mode 100644 index 32c7e5bf..00000000 --- a/core/task/handler/runner.go +++ /dev/null @@ -1,691 +0,0 @@ -package handler - -import ( - "bufio" - "context" - "encoding/json" - "fmt" - "github.com/apex/log" - "github.com/cenkalti/backoff/v4" - "github.com/crawlab-team/crawlab/core/constants" - "github.com/crawlab-team/crawlab/core/container" - "github.com/crawlab-team/crawlab/core/entity" - "github.com/crawlab-team/crawlab/core/errors" - fs2 "github.com/crawlab-team/crawlab/core/fs" - "github.com/crawlab-team/crawlab/core/interfaces" - "github.com/crawlab-team/crawlab/core/models/client" - "github.com/crawlab-team/crawlab/core/models/delegate" - "github.com/crawlab-team/crawlab/core/models/models" - "github.com/crawlab-team/crawlab/core/sys_exec" - "github.com/crawlab-team/crawlab/core/utils" - "github.com/crawlab-team/crawlab/db/mongo" - grpc "github.com/crawlab-team/crawlab/grpc" - "github.com/crawlab-team/crawlab/trace" - "github.com/shirou/gopsutil/process" - "github.com/sirupsen/logrus" - "github.com/spf13/viper" - "go.mongodb.org/mongo-driver/bson" - "go.mongodb.org/mongo-driver/bson/primitive" - "io" - "net/http" - "os" - "os/exec" - "path/filepath" - "strings" - "sync" - "time" -) - -type Runner struct { - // dependencies - svc interfaces.TaskHandlerService // task handler service - fsSvc interfaces.FsServiceV2 // task fs service - hookSvc interfaces.TaskHookService // task hook service - - // settings - subscribeTimeout time.Duration - bufferSize int - - // internals - cmd *exec.Cmd // process command instance - pid int // process id - tid primitive.ObjectID // task id - t interfaces.Task // task model.Task - s interfaces.Spider // spider model.Spider - ch chan constants.TaskSignal // channel to communicate between Service and Runner - err error // standard process error - envs []models.Env // environment variables - cwd string // working directory - c interfaces.GrpcClient // grpc client - sub grpc.TaskService_SubscribeClient // grpc task service stream client - - // log internals - scannerStdout *bufio.Reader - scannerStderr *bufio.Reader - logBatchSize int -} - -func (r *Runner) Init() (err error) { - // update task - if err := r.updateTask("", nil); err != nil { - return err - } - - // start grpc client - if !r.c.IsStarted() { - r.c.Start() - } - - // working directory - workspacePath := viper.GetString("workspace") - r.cwd = filepath.Join(workspacePath, r.s.GetId().Hex()) - - // sync files from master - if !utils.IsMaster() { - if err := r.syncFiles(); err != nil { - return err - } - } - - // grpc task service stream client - if err := r.initSub(); err != nil { - return err - } - - // pre actions - if r.hookSvc != nil { - if err := r.hookSvc.PreActions(r.t, r.s, r.fsSvc, r.svc); err != nil { - return err - } - } - - return nil -} - -func (r *Runner) Run() (err error) { - // log task started - log.Infof("task[%s] started", r.tid.Hex()) - - // configure cmd - r.configureCmd() - - // configure environment variables - r.configureEnv() - - // configure logging - r.configureLogging() - - // start process - if err := r.cmd.Start(); err != nil { - return r.updateTask(constants.TaskStatusError, err) - } - - // start logging - go r.startLogging() - - // process id - if r.cmd.Process == nil { - return r.updateTask(constants.TaskStatusError, constants.ErrNotExists) - } - r.pid = r.cmd.Process.Pid - r.t.SetPid(r.pid) - - // update task status (processing) - if err := r.updateTask(constants.TaskStatusRunning, nil); err != nil { - return err - } - - // wait for process to finish - go r.wait() - - // start health check - go r.startHealthCheck() - - // declare task status - status := "" - - // wait for signal - signal := <-r.ch - switch signal { - case constants.TaskSignalFinish: - err = nil - status = constants.TaskStatusFinished - case constants.TaskSignalCancel: - err = constants.ErrTaskCancelled - status = constants.TaskStatusCancelled - case constants.TaskSignalError: - err = r.err - status = constants.TaskStatusError - case constants.TaskSignalLost: - err = constants.ErrTaskLost - status = constants.TaskStatusError - default: - err = constants.ErrInvalidSignal - status = constants.TaskStatusError - } - - // update task status - if err := r.updateTask(status, err); err != nil { - return err - } - - // post actions - if r.hookSvc != nil { - if err := r.hookSvc.PostActions(r.t, r.s, r.fsSvc, r.svc); err != nil { - return err - } - } - - return err -} - -func (r *Runner) Cancel() (err error) { - // kill process - opts := &sys_exec.KillProcessOptions{ - Timeout: r.svc.GetCancelTimeout(), - Force: true, - } - if err := sys_exec.KillProcess(r.cmd, opts); err != nil { - return err - } - - // make sure the process does not exist - op := func() error { - if exists, _ := process.PidExists(int32(r.pid)); exists { - return errors.ErrorTaskProcessStillExists - } - return nil - } - ctx, cancel := context.WithTimeout(context.Background(), r.svc.GetExitWatchDuration()) - defer cancel() - b := backoff.WithContext(backoff.NewConstantBackOff(1*time.Second), ctx) - if err := backoff.Retry(op, b); err != nil { - return trace.TraceError(errors.ErrorTaskUnableToCancel) - } - - return nil -} - -// CleanUp clean up task runner -func (r *Runner) CleanUp() (err error) { - return nil -} - -func (r *Runner) SetSubscribeTimeout(timeout time.Duration) { - r.subscribeTimeout = timeout -} - -func (r *Runner) GetTaskId() (id primitive.ObjectID) { - return r.tid -} - -func (r *Runner) configureCmd() { - var cmdStr string - - // customized spider - if r.t.GetCmd() == "" { - cmdStr = r.s.GetCmd() - } else { - cmdStr = r.t.GetCmd() - } - - // parameters - if r.t.GetParam() != "" { - cmdStr += " " + r.t.GetParam() - } else if r.s.GetParam() != "" { - cmdStr += " " + r.s.GetParam() - } - - // get cmd instance - r.cmd, _ = sys_exec.BuildCmd(cmdStr) - - // set working directory - r.cmd.Dir = r.cwd - - // configure pgid to allow killing sub processes - //sys_exec.SetPgid(r.cmd) -} - -func (r *Runner) configureLogging() { - // set stdout reader - stdout, _ := r.cmd.StdoutPipe() - r.scannerStdout = bufio.NewReaderSize(stdout, r.bufferSize) - - // set stderr reader - stderr, _ := r.cmd.StderrPipe() - r.scannerStderr = bufio.NewReaderSize(stderr, r.bufferSize) -} - -func (r *Runner) startLogging() { - // start reading stdout - go r.startLoggingReaderStdout() - - // start reading stderr - go r.startLoggingReaderStderr() -} - -func (r *Runner) startLoggingReaderStdout() { - for { - line, err := r.scannerStdout.ReadString(byte('\n')) - if err != nil { - break - } - line = strings.TrimSuffix(line, "\n") - r.writeLogLines([]string{line}) - } -} - -func (r *Runner) startLoggingReaderStderr() { - for { - line, err := r.scannerStderr.ReadString(byte('\n')) - if err != nil { - break - } - line = strings.TrimSuffix(line, "\n") - r.writeLogLines([]string{line}) - } -} - -func (r *Runner) startHealthCheck() { - if r.cmd.ProcessState == nil || r.cmd.ProcessState.Exited() { - return - } - for { - exists, _ := process.PidExists(int32(r.pid)) - if !exists { - // process lost - r.ch <- constants.TaskSignalLost - return - } - time.Sleep(1 * time.Second) - } -} - -func (r *Runner) configureEnv() { - // 默认把Node.js的全局node_modules加入环境变量 - envPath := os.Getenv("PATH") - nodePath := "/usr/lib/node_modules" - if !strings.Contains(envPath, nodePath) { - _ = os.Setenv("PATH", nodePath+":"+envPath) - } - _ = os.Setenv("NODE_PATH", nodePath) - - // default envs - r.cmd.Env = append(os.Environ(), "CRAWLAB_TASK_ID="+r.tid.Hex()) - if viper.GetString("grpc.address") != "" { - r.cmd.Env = append(r.cmd.Env, "CRAWLAB_GRPC_ADDRESS="+viper.GetString("grpc.address")) - } - if viper.GetString("grpc.authKey") != "" { - r.cmd.Env = append(r.cmd.Env, "CRAWLAB_GRPC_AUTH_KEY="+viper.GetString("grpc.authKey")) - } else { - r.cmd.Env = append(r.cmd.Env, "CRAWLAB_GRPC_AUTH_KEY="+constants.DefaultGrpcAuthKey) - } - - // global environment variables - envs, err := r.svc.GetModelEnvironmentService().GetEnvironmentList(nil, nil) - if err != nil { - trace.PrintError(err) - return - } - for _, env := range envs { - r.cmd.Env = append(r.cmd.Env, env.GetKey()+"="+env.GetValue()) - } -} - -func (r *Runner) syncFiles() (err error) { - masterURL := fmt.Sprintf("%s/sync/%s", viper.GetString("api.endpoint"), r.s.GetId().Hex()) - workspacePath := viper.GetString("workspace") - workerDir := filepath.Join(workspacePath, r.s.GetId().Hex()) - - // get file list from master - resp, err := http.Get(masterURL + "/scan") - if err != nil { - fmt.Println("Error getting file list from master:", err) - return trace.TraceError(err) - } - defer resp.Body.Close() - body, err := io.ReadAll(resp.Body) - if err != nil { - fmt.Println("Error reading response body:", err) - return trace.TraceError(err) - } - var masterFiles map[string]entity.FsFileInfo - err = json.Unmarshal(body, &masterFiles) - if err != nil { - fmt.Println("Error unmarshaling JSON:", err) - return trace.TraceError(err) - } - - // create a map for master files - masterFilesMap := make(map[string]entity.FsFileInfo) - for _, file := range masterFiles { - masterFilesMap[file.Path] = file - } - - // create worker directory if not exists - if _, err := os.Stat(workerDir); os.IsNotExist(err) { - if err := os.MkdirAll(workerDir, os.ModePerm); err != nil { - fmt.Println("Error creating worker directory:", err) - return trace.TraceError(err) - } - } - - // get file list from worker - workerFiles, err := utils.ScanDirectory(workerDir) - if err != nil { - fmt.Println("Error scanning worker directory:", err) - return trace.TraceError(err) - } - - // set up wait group and error channel - var wg sync.WaitGroup - errCh := make(chan error, 1) - - // delete files that are deleted on master node - for path, workerFile := range workerFiles { - if _, exists := masterFilesMap[path]; !exists { - fmt.Println("Deleting file:", path) - err := os.Remove(workerFile.FullPath) - if err != nil { - fmt.Println("Error deleting file:", err) - } - } - } - - // download files that are new or modified on master node - for path, masterFile := range masterFilesMap { - workerFile, exists := workerFiles[path] - if !exists || masterFile.Hash != workerFile.Hash { - wg.Add(1) - go func(path string, masterFile entity.FsFileInfo) { - defer wg.Done() - logrus.Infof("File needs to be synchronized: %s", path) - err := r.downloadFile(masterURL+"/download?path="+path, filepath.Join(workerDir, path)) - if err != nil { - logrus.Errorf("Error downloading file: %v", err) - select { - case errCh <- err: - default: - } - } - }(path, masterFile) - } - } - - wg.Wait() - close(errCh) - if err := <-errCh; err != nil { - return err - } - - return nil -} - -func (r *Runner) downloadFile(url string, filePath string) error { - resp, err := http.Get(url) - if err != nil { - return err - } - defer resp.Body.Close() - - out, err := os.Create(filePath) - if err != nil { - return err - } - defer out.Close() - - _, err = io.Copy(out, resp.Body) - return err -} - -// wait for process to finish and send task signal (constants.TaskSignal) -// to task runner's channel (Runner.ch) according to exit code -func (r *Runner) wait() { - // wait for process to finish - if err := r.cmd.Wait(); err != nil { - exitError, ok := err.(*exec.ExitError) - if !ok { - r.ch <- constants.TaskSignalError - return - } - exitCode := exitError.ExitCode() - if exitCode == -1 { - // cancel error - r.ch <- constants.TaskSignalCancel - return - } - - // standard error - r.err = err - r.ch <- constants.TaskSignalError - return - } - - // success - r.ch <- constants.TaskSignalFinish -} - -// updateTask update and get updated info of task (Runner.t) -func (r *Runner) updateTask(status string, e error) (err error) { - if r.t != nil && status != "" { - // update task status - r.t.SetStatus(status) - if e != nil { - r.t.SetError(e.Error()) - } - if r.svc.GetNodeConfigService().IsMaster() { - if err := delegate.NewModelDelegate(r.t).Save(); err != nil { - return err - } - } else { - if err := client.NewModelDelegate(r.t, client.WithDelegateConfigPath(r.svc.GetConfigPath())).Save(); err != nil { - return err - } - } - - // send notification - go r.sendNotification() - - // update stats - go func() { - r._updateTaskStat(status) - r._updateSpiderStat(status) - }() - } - - // get task - r.t, err = r.svc.GetTaskById(r.tid) - if err != nil { - return err - } - - return nil -} - -func (r *Runner) initSub() (err error) { - r.sub, err = r.c.GetTaskClient().Subscribe(context.Background()) - if err != nil { - return trace.TraceError(err) - } - return nil -} - -func (r *Runner) writeLogLines(lines []string) { - data, err := json.Marshal(&entity.StreamMessageTaskData{ - TaskId: r.tid, - Logs: lines, - }) - if err != nil { - trace.PrintError(err) - return - } - msg := &grpc.StreamMessage{ - Code: grpc.StreamMessageCode_INSERT_LOGS, - Data: data, - } - if err := r.sub.Send(msg); err != nil { - trace.PrintError(err) - return - } -} - -func (r *Runner) _updateTaskStat(status string) { - ts, err := r.svc.GetModelTaskStatService().GetTaskStatById(r.tid) - if err != nil { - trace.PrintError(err) - return - } - switch status { - case constants.TaskStatusPending: - // do nothing - case constants.TaskStatusRunning: - ts.SetStartTs(time.Now()) - ts.SetWaitDuration(ts.GetStartTs().Sub(ts.GetCreateTs()).Milliseconds()) - case constants.TaskStatusFinished, constants.TaskStatusError, constants.TaskStatusCancelled: - ts.SetEndTs(time.Now()) - ts.SetRuntimeDuration(ts.GetEndTs().Sub(ts.GetStartTs()).Milliseconds()) - ts.SetTotalDuration(ts.GetEndTs().Sub(ts.GetCreateTs()).Milliseconds()) - } - if r.svc.GetNodeConfigService().IsMaster() { - if err := delegate.NewModelDelegate(ts).Save(); err != nil { - trace.PrintError(err) - return - } - } else { - if err := client.NewModelDelegate(ts, client.WithDelegateConfigPath(r.svc.GetConfigPath())).Save(); err != nil { - trace.PrintError(err) - return - } - } -} - -func (r *Runner) sendNotification() { - data, err := json.Marshal(r.t) - if err != nil { - trace.PrintError(err) - return - } - req := &grpc.Request{ - NodeKey: r.svc.GetNodeConfigService().GetNodeKey(), - Data: data, - } - _, err = r.c.GetTaskClient().SendNotification(context.Background(), req) - if err != nil { - trace.PrintError(err) - return - } -} - -func (r *Runner) _updateSpiderStat(status string) { - // task stat - ts, err := r.svc.GetModelTaskStatService().GetTaskStatById(r.tid) - if err != nil { - trace.PrintError(err) - return - } - - // update - var update bson.M - switch status { - case constants.TaskStatusPending, constants.TaskStatusRunning: - update = bson.M{ - "$set": bson.M{ - "last_task_id": r.tid, // last task id - }, - "$inc": bson.M{ - "tasks": 1, // task count - "wait_duration": ts.GetWaitDuration(), // wait duration - }, - } - case constants.TaskStatusFinished, constants.TaskStatusError, constants.TaskStatusCancelled: - update = bson.M{ - "$inc": bson.M{ - "results": ts.GetResultCount(), // results - "runtime_duration": ts.GetRuntimeDuration() / 1000, // runtime duration - "total_duration": ts.GetTotalDuration() / 1000, // total duration - }, - } - default: - trace.PrintError(errors.ErrorTaskInvalidType) - return - } - - // perform update - if r.svc.GetNodeConfigService().IsMaster() { - if err := mongo.GetMongoCol(interfaces.ModelColNameSpiderStat).UpdateId(r.s.GetId(), update); err != nil { - trace.PrintError(err) - return - } - } else { - modelSvc, err := client.NewBaseServiceDelegate( - client.WithBaseServiceModelId(interfaces.ModelIdSpiderStat), - client.WithBaseServiceConfigPath(r.svc.GetConfigPath()), - ) - if err != nil { - trace.PrintError(err) - return - } - if err := modelSvc.UpdateById(r.s.GetId(), update); err != nil { - trace.PrintError(err) - return - } - } - -} - -func NewTaskRunner(id primitive.ObjectID, svc interfaces.TaskHandlerService, opts ...RunnerOption) (r2 interfaces.TaskRunner, err error) { - // validate options - if id.IsZero() { - return nil, constants.ErrInvalidOptions - } - - // runner - r := &Runner{ - subscribeTimeout: 30 * time.Second, - bufferSize: 1024 * 1024, - svc: svc, - tid: id, - ch: make(chan constants.TaskSignal), - logBatchSize: 20, - } - - // apply options - for _, opt := range opts { - opt(r) - } - - // task - r.t, err = svc.GetTaskById(id) - if err != nil { - return nil, err - } - - // spider - r.s, err = svc.GetSpiderById(r.t.GetSpiderId()) - if err != nil { - return nil, err - } - - // task fs service - r.fsSvc = fs2.NewFsServiceV2(filepath.Join(viper.GetString("workspace"), r.s.GetId().Hex())) - - // dependency injection - if err := container.GetContainer().Invoke(func( - c interfaces.GrpcClient, - ) { - r.c = c - }); err != nil { - return nil, trace.TraceError(err) - } - - _ = container.GetContainer().Invoke(func(hookSvc interfaces.TaskHookService) { - r.hookSvc = hookSvc - }) - - // initialize task runner - if err := r.Init(); err != nil { - return r, err - } - - return r, nil -} diff --git a/core/task/handler/runner_v2.go b/core/task/handler/runner_v2.go index 7fec5e18..b4baac24 100644 --- a/core/task/handler/runner_v2.go +++ b/core/task/handler/runner_v2.go @@ -605,17 +605,13 @@ func (r *RunnerV2) _updateTaskStat(status string) { } func (r *RunnerV2) sendNotification() { - data, err := json.Marshal(r.t) - if err != nil { - trace.PrintError(err) - return - } - req := &grpc.Request{ + req := &grpc.TaskServiceSendNotificationRequest{ NodeKey: r.svc.GetNodeConfigService().GetNodeKey(), - Data: data, + TaskId: r.tid.Hex(), } - _, err = r.c.TaskClient.SendNotification(context.Background(), req) + _, err := r.c.TaskClient.SendNotification(context.Background(), req) if err != nil { + log.Errorf("Error sending notification: %v", err) trace.PrintError(err) return } diff --git a/core/task/handler/service.go b/core/task/handler/service.go deleted file mode 100644 index 69a99884..00000000 --- a/core/task/handler/service.go +++ /dev/null @@ -1,506 +0,0 @@ -package handler - -import ( - "context" - "encoding/json" - "errors" - "github.com/apex/log" - "github.com/crawlab-team/crawlab/core/constants" - "github.com/crawlab-team/crawlab/core/container" - errors2 "github.com/crawlab-team/crawlab/core/errors" - "github.com/crawlab-team/crawlab/core/interfaces" - "github.com/crawlab-team/crawlab/core/models/client" - "github.com/crawlab-team/crawlab/core/models/delegate" - "github.com/crawlab-team/crawlab/core/models/service" - "github.com/crawlab-team/crawlab/core/task" - "github.com/crawlab-team/crawlab/trace" - "go.mongodb.org/mongo-driver/bson" - "go.mongodb.org/mongo-driver/bson/primitive" - "go.mongodb.org/mongo-driver/mongo" - "sync" - "time" -) - -type Service struct { - // dependencies - interfaces.TaskBaseService - cfgSvc interfaces.NodeConfigService - modelSvc service.ModelService - clientModelSvc interfaces.GrpcClientModelService - clientModelNodeSvc interfaces.GrpcClientModelNodeService - clientModelSpiderSvc interfaces.GrpcClientModelSpiderService - clientModelTaskSvc interfaces.GrpcClientModelTaskService - clientModelTaskStatSvc interfaces.GrpcClientModelTaskStatService - clientModelEnvironmentSvc interfaces.GrpcClientModelEnvironmentService - c interfaces.GrpcClient // grpc client - - // settings - //maxRunners int - exitWatchDuration time.Duration - reportInterval time.Duration - fetchInterval time.Duration - fetchTimeout time.Duration - cancelTimeout time.Duration - - // internals variables - stopped bool - mu sync.Mutex - runners sync.Map // pool of task runners started - syncLocks sync.Map // files sync locks map of task runners -} - -func (svc *Service) Start() { - // Initialize gRPC if not started - if !svc.c.IsStarted() { - svc.c.Start() - } - - go svc.ReportStatus() - go svc.Fetch() -} - -func (svc *Service) Run(taskId primitive.ObjectID) (err error) { - return svc.run(taskId) -} - -func (svc *Service) Reset() { - svc.mu.Lock() - defer svc.mu.Unlock() -} - -func (svc *Service) Cancel(taskId primitive.ObjectID) (err error) { - r, err := svc.getRunner(taskId) - if err != nil { - return err - } - if err := r.Cancel(); err != nil { - return err - } - return nil -} - -func (svc *Service) Fetch() { - for { - // wait - time.Sleep(svc.fetchInterval) - - // current node - n, err := svc.GetCurrentNode() - if err != nil { - continue - } - - // skip if node is not active or enabled - if !n.GetActive() || !n.GetEnabled() { - continue - } - - // validate if there are available runners - if svc.getRunnerCount() >= n.GetMaxRunners() { - continue - } - - // stop - if svc.stopped { - return - } - - // fetch task - tid, err := svc.fetch() - if err != nil { - trace.PrintError(err) - continue - } - - // skip if no task id - if tid.IsZero() { - continue - } - - // run task - if err := svc.run(tid); err != nil { - trace.PrintError(err) - t, err := svc.GetTaskById(tid) - if err == nil && t.GetStatus() != constants.TaskStatusCancelled { - t.SetError(err.Error()) - _ = svc.SaveTask(t, constants.TaskStatusError) - continue - } - continue - } - } -} - -func (svc *Service) ReportStatus() { - for { - if svc.stopped { - return - } - - // report handler status - if err := svc.reportStatus(); err != nil { - trace.PrintError(err) - } - - // wait - time.Sleep(svc.reportInterval) - } -} - -func (svc *Service) IsSyncLocked(path string) (ok bool) { - _, ok = svc.syncLocks.Load(path) - return ok -} - -func (svc *Service) LockSync(path string) { - svc.syncLocks.Store(path, true) -} - -func (svc *Service) UnlockSync(path string) { - svc.syncLocks.Delete(path) -} - -//func (svc *Service) GetMaxRunners() (maxRunners int) { -// return svc.maxRunners -//} -// -//func (svc *Service) SetMaxRunners(maxRunners int) { -// svc.maxRunners = maxRunners -//} - -func (svc *Service) GetExitWatchDuration() (duration time.Duration) { - return svc.exitWatchDuration -} - -func (svc *Service) SetExitWatchDuration(duration time.Duration) { - svc.exitWatchDuration = duration -} - -func (svc *Service) GetFetchInterval() (interval time.Duration) { - return svc.fetchInterval -} - -func (svc *Service) SetFetchInterval(interval time.Duration) { - svc.fetchInterval = interval -} - -func (svc *Service) GetReportInterval() (interval time.Duration) { - return svc.reportInterval -} - -func (svc *Service) SetReportInterval(interval time.Duration) { - svc.reportInterval = interval -} - -func (svc *Service) GetCancelTimeout() (timeout time.Duration) { - return svc.cancelTimeout -} - -func (svc *Service) SetCancelTimeout(timeout time.Duration) { - svc.cancelTimeout = timeout -} - -func (svc *Service) GetModelService() (modelSvc interfaces.GrpcClientModelService) { - return svc.clientModelSvc -} - -func (svc *Service) GetModelSpiderService() (modelSpiderSvc interfaces.GrpcClientModelSpiderService) { - return svc.clientModelSpiderSvc -} - -func (svc *Service) GetModelTaskService() (modelTaskSvc interfaces.GrpcClientModelTaskService) { - return svc.clientModelTaskSvc -} - -func (svc *Service) GetModelTaskStatService() (modelTaskSvc interfaces.GrpcClientModelTaskStatService) { - return svc.clientModelTaskStatSvc -} - -func (svc *Service) GetModelEnvironmentService() (modelTaskSvc interfaces.GrpcClientModelEnvironmentService) { - return svc.clientModelEnvironmentSvc -} - -func (svc *Service) GetNodeConfigService() (cfgSvc interfaces.NodeConfigService) { - return svc.cfgSvc -} - -func (svc *Service) GetCurrentNode() (n interfaces.Node, err error) { - // node key - nodeKey := svc.cfgSvc.GetNodeKey() - - // current node - if svc.cfgSvc.IsMaster() { - n, err = svc.modelSvc.GetNodeByKey(nodeKey, nil) - } else { - n, err = svc.clientModelNodeSvc.GetNodeByKey(nodeKey) - } - if err != nil { - return nil, err - } - - return n, nil -} - -func (svc *Service) GetTaskById(id primitive.ObjectID) (t interfaces.Task, err error) { - if svc.cfgSvc.IsMaster() { - t, err = svc.modelSvc.GetTaskById(id) - } else { - t, err = svc.clientModelTaskSvc.GetTaskById(id) - } - if err != nil { - return nil, err - } - - return t, nil -} - -func (svc *Service) GetSpiderById(id primitive.ObjectID) (s interfaces.Spider, err error) { - if svc.cfgSvc.IsMaster() { - s, err = svc.modelSvc.GetSpiderById(id) - } else { - s, err = svc.clientModelSpiderSvc.GetSpiderById(id) - } - if err != nil { - return nil, err - } - - return s, nil -} - -func (svc *Service) getRunners() (runners []*Runner) { - svc.mu.Lock() - defer svc.mu.Unlock() - svc.runners.Range(func(key, value interface{}) bool { - r := value.(Runner) - runners = append(runners, &r) - return true - }) - return runners -} - -func (svc *Service) getRunnerCount() (count int) { - n, err := svc.GetCurrentNode() - if err != nil { - trace.PrintError(err) - return - } - query := bson.M{ - "node_id": n.GetId(), - "status": constants.TaskStatusRunning, - } - if svc.cfgSvc.IsMaster() { - count, err = svc.modelSvc.GetBaseService(interfaces.ModelIdTask).Count(query) - if err != nil { - trace.PrintError(err) - return - } - } else { - count, err = svc.clientModelTaskSvc.Count(query) - if err != nil { - trace.PrintError(err) - return - } - } - return count -} - -func (svc *Service) getRunner(taskId primitive.ObjectID) (r interfaces.TaskRunner, err error) { - log.Debugf("[TaskHandlerService] getRunner: taskId[%v]", taskId) - v, ok := svc.runners.Load(taskId) - if !ok { - return nil, trace.TraceError(errors2.ErrorTaskNotExists) - } - switch v.(type) { - case interfaces.TaskRunner: - r = v.(interfaces.TaskRunner) - default: - return nil, trace.TraceError(errors2.ErrorModelInvalidType) - } - return r, nil -} - -func (svc *Service) addRunner(taskId primitive.ObjectID, r interfaces.TaskRunner) { - log.Debugf("[TaskHandlerService] addRunner: taskId[%v]", taskId) - svc.runners.Store(taskId, r) -} - -func (svc *Service) deleteRunner(taskId primitive.ObjectID) { - log.Debugf("[TaskHandlerService] deleteRunner: taskId[%v]", taskId) - svc.runners.Delete(taskId) -} - -func (svc *Service) saveTask(t interfaces.Task, status string) (err error) { - // normalize status - if status == "" { - status = constants.TaskStatusPending - } - - // set task status - t.SetStatus(status) - - // attempt to get task from database - _, err = svc.clientModelTaskSvc.GetTaskById(t.GetId()) - if err != nil { - // if task does not exist, add to database - if err == mongo.ErrNoDocuments { - if err := client.NewModelDelegate(t, client.WithDelegateConfigPath(svc.GetConfigPath())).Add(); err != nil { - return err - } - return nil - } else { - return err - } - } else { - // otherwise, update - if err := client.NewModelDelegate(t, client.WithDelegateConfigPath(svc.GetConfigPath())).Save(); err != nil { - return err - } - return nil - } -} - -func (svc *Service) reportStatus() (err error) { - // current node - n, err := svc.GetCurrentNode() - if err != nil { - return err - } - - // available runners of handler - ar := n.GetMaxRunners() - svc.getRunnerCount() - - // set available runners - n.SetAvailableRunners(ar) - - // save node - if svc.cfgSvc.IsMaster() { - err = delegate.NewModelDelegate(n).Save() - } else { - err = client.NewModelDelegate(n, client.WithDelegateConfigPath(svc.GetConfigPath())).Save() - } - if err != nil { - return err - } - - return nil -} - -func (svc *Service) fetch() (tid primitive.ObjectID, err error) { - ctx, cancel := context.WithTimeout(context.Background(), svc.fetchTimeout) - defer cancel() - res, err := svc.c.GetTaskClient().Fetch(ctx, svc.c.NewRequest(nil)) - if err != nil { - return tid, trace.TraceError(err) - } - if err := json.Unmarshal(res.Data, &tid); err != nil { - return tid, trace.TraceError(err) - } - return tid, nil -} - -func (svc *Service) run(taskId primitive.ObjectID) (err error) { - // attempt to get runner from pool - _, ok := svc.runners.Load(taskId) - if ok { - return trace.TraceError(errors2.ErrorTaskAlreadyExists) - } - - // create a new task runner - r, err := NewTaskRunner(taskId, svc) - if err != nil { - return trace.TraceError(err) - } - - // add runner to pool - svc.addRunner(taskId, r) - - // create a goroutine to run task - go func() { - // delete runner from pool - defer svc.deleteRunner(r.GetTaskId()) - defer func(r interfaces.TaskRunner) { - err := r.CleanUp() - if err != nil { - log.Errorf("task[%s] clean up error: %v", r.GetTaskId().Hex(), err) - } - }(r) - // run task process (blocking) - // error or finish after task runner ends - if err := r.Run(); err != nil { - switch { - case errors.Is(err, constants.ErrTaskError): - log.Errorf("task[%s] finished with error: %v", r.GetTaskId().Hex(), err) - case errors.Is(err, constants.ErrTaskCancelled): - log.Errorf("task[%s] cancelled", r.GetTaskId().Hex()) - default: - log.Errorf("task[%s] finished with unknown error: %v", r.GetTaskId().Hex(), err) - } - } - log.Infof("task[%s] finished", r.GetTaskId().Hex()) - }() - - return nil -} - -func NewTaskHandlerService() (svc2 interfaces.TaskHandlerService, err error) { - // base service - baseSvc, err := task.NewBaseService() - if err != nil { - return nil, trace.TraceError(err) - } - - // service - svc := &Service{ - TaskBaseService: baseSvc, - exitWatchDuration: 60 * time.Second, - fetchInterval: 1 * time.Second, - fetchTimeout: 15 * time.Second, - reportInterval: 5 * time.Second, - cancelTimeout: 5 * time.Second, - mu: sync.Mutex{}, - runners: sync.Map{}, - syncLocks: sync.Map{}, - } - - // dependency injection - if err := container.GetContainer().Invoke(func( - cfgSvc interfaces.NodeConfigService, - modelSvc service.ModelService, - clientModelSvc interfaces.GrpcClientModelService, - clientModelNodeSvc interfaces.GrpcClientModelNodeService, - clientModelSpiderSvc interfaces.GrpcClientModelSpiderService, - clientModelTaskSvc interfaces.GrpcClientModelTaskService, - clientModelTaskStatSvc interfaces.GrpcClientModelTaskStatService, - clientModelEnvironmentSvc interfaces.GrpcClientModelEnvironmentService, - c interfaces.GrpcClient, - ) { - svc.cfgSvc = cfgSvc - svc.modelSvc = modelSvc - svc.clientModelSvc = clientModelSvc - svc.clientModelNodeSvc = clientModelNodeSvc - svc.clientModelSpiderSvc = clientModelSpiderSvc - svc.clientModelTaskSvc = clientModelTaskSvc - svc.clientModelTaskStatSvc = clientModelTaskStatSvc - svc.clientModelEnvironmentSvc = clientModelEnvironmentSvc - svc.c = c - }); err != nil { - return nil, trace.TraceError(err) - } - - log.Debugf("[NewTaskHandlerService] svc[cfgPath: %s]", svc.cfgSvc.GetConfigPath()) - - return svc, nil -} - -var _service interfaces.TaskHandlerService - -func GetTaskHandlerService() (svr interfaces.TaskHandlerService, err error) { - if _service != nil { - return _service, nil - } - _service, err = NewTaskHandlerService() - if err != nil { - return nil, err - } - return _service, nil -} diff --git a/core/task/handler/service_v2.go b/core/task/handler/service_v2.go index be7c2725..28c2c73c 100644 --- a/core/task/handler/service_v2.go +++ b/core/task/handler/service_v2.go @@ -245,11 +245,11 @@ func (svc *ServiceV2) GetSpiderById(id primitive.ObjectID) (s *models2.SpiderV2, return s, nil } -func (svc *ServiceV2) getRunners() (runners []*Runner) { +func (svc *ServiceV2) getRunners() (runners []*RunnerV2) { svc.mu.Lock() defer svc.mu.Unlock() svc.runners.Range(func(key, value interface{}) bool { - r := value.(Runner) + r := value.(RunnerV2) runners = append(runners, &r) return true }) diff --git a/grpc/proto/services/task_service.proto b/grpc/proto/services/task_service.proto index f004e565..6b67caf0 100644 --- a/grpc/proto/services/task_service.proto +++ b/grpc/proto/services/task_service.proto @@ -7,8 +7,13 @@ import "entity/stream_message.proto"; package grpc; option go_package = ".;grpc"; +message TaskServiceSendNotificationRequest { + string node_key = 1; + string task_id = 2; +} + service TaskService { rpc Subscribe(stream StreamMessage) returns (Response){}; rpc Fetch(Request) returns (Response){}; - rpc SendNotification(Request) returns (Response){}; + rpc SendNotification(TaskServiceSendNotificationRequest) returns (Response){}; } diff --git a/grpc/task_service.pb.go b/grpc/task_service.pb.go index 5d03cbb3..5aa365a9 100644 --- a/grpc/task_service.pb.go +++ b/grpc/task_service.pb.go @@ -10,6 +10,7 @@ import ( protoreflect "google.golang.org/protobuf/reflect/protoreflect" protoimpl "google.golang.org/protobuf/runtime/protoimpl" reflect "reflect" + sync "sync" ) const ( @@ -19,6 +20,61 @@ const ( _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) ) +type TaskServiceSendNotificationRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + NodeKey string `protobuf:"bytes,1,opt,name=node_key,json=nodeKey,proto3" json:"node_key,omitempty"` + TaskId string `protobuf:"bytes,2,opt,name=task_id,json=taskId,proto3" json:"task_id,omitempty"` +} + +func (x *TaskServiceSendNotificationRequest) Reset() { + *x = TaskServiceSendNotificationRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_services_task_service_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *TaskServiceSendNotificationRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*TaskServiceSendNotificationRequest) ProtoMessage() {} + +func (x *TaskServiceSendNotificationRequest) ProtoReflect() protoreflect.Message { + mi := &file_services_task_service_proto_msgTypes[0] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use TaskServiceSendNotificationRequest.ProtoReflect.Descriptor instead. +func (*TaskServiceSendNotificationRequest) Descriptor() ([]byte, []int) { + return file_services_task_service_proto_rawDescGZIP(), []int{0} +} + +func (x *TaskServiceSendNotificationRequest) GetNodeKey() string { + if x != nil { + return x.NodeKey + } + return "" +} + +func (x *TaskServiceSendNotificationRequest) GetTaskId() string { + if x != nil { + return x.TaskId + } + return "" +} + var File_services_task_service_proto protoreflect.FileDescriptor var file_services_task_service_proto_rawDesc = []byte{ @@ -28,33 +84,54 @@ var file_services_task_service_proto_rawDesc = []byte{ 0x65, 0x73, 0x74, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x1a, 0x15, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x2f, 0x72, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x1a, 0x1b, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x2f, 0x73, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x5f, - 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x32, 0xa2, 0x01, - 0x0a, 0x0b, 0x54, 0x61, 0x73, 0x6b, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x34, 0x0a, - 0x09, 0x53, 0x75, 0x62, 0x73, 0x63, 0x72, 0x69, 0x62, 0x65, 0x12, 0x13, 0x2e, 0x67, 0x72, 0x70, - 0x63, 0x2e, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, - 0x0e, 0x2e, 0x67, 0x72, 0x70, 0x63, 0x2e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, - 0x00, 0x28, 0x01, 0x12, 0x28, 0x0a, 0x05, 0x46, 0x65, 0x74, 0x63, 0x68, 0x12, 0x0d, 0x2e, 0x67, - 0x72, 0x70, 0x63, 0x2e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x0e, 0x2e, 0x67, 0x72, - 0x70, 0x63, 0x2e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, - 0x10, 0x53, 0x65, 0x6e, 0x64, 0x4e, 0x6f, 0x74, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, - 0x6e, 0x12, 0x0d, 0x2e, 0x67, 0x72, 0x70, 0x63, 0x2e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, - 0x1a, 0x0e, 0x2e, 0x67, 0x72, 0x70, 0x63, 0x2e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, - 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2e, 0x3b, 0x67, 0x72, 0x70, 0x63, 0x62, 0x06, 0x70, 0x72, - 0x6f, 0x74, 0x6f, 0x33, + 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x58, 0x0a, + 0x22, 0x54, 0x61, 0x73, 0x6b, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x53, 0x65, 0x6e, 0x64, + 0x4e, 0x6f, 0x74, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x71, 0x75, + 0x65, 0x73, 0x74, 0x12, 0x19, 0x0a, 0x08, 0x6e, 0x6f, 0x64, 0x65, 0x5f, 0x6b, 0x65, 0x79, 0x18, + 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x6e, 0x6f, 0x64, 0x65, 0x4b, 0x65, 0x79, 0x12, 0x17, + 0x0a, 0x07, 0x74, 0x61, 0x73, 0x6b, 0x5f, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x06, 0x74, 0x61, 0x73, 0x6b, 0x49, 0x64, 0x32, 0xbd, 0x01, 0x0a, 0x0b, 0x54, 0x61, 0x73, 0x6b, + 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x34, 0x0a, 0x09, 0x53, 0x75, 0x62, 0x73, 0x63, + 0x72, 0x69, 0x62, 0x65, 0x12, 0x13, 0x2e, 0x67, 0x72, 0x70, 0x63, 0x2e, 0x53, 0x74, 0x72, 0x65, + 0x61, 0x6d, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x0e, 0x2e, 0x67, 0x72, 0x70, 0x63, + 0x2e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x28, 0x01, 0x12, 0x28, 0x0a, + 0x05, 0x46, 0x65, 0x74, 0x63, 0x68, 0x12, 0x0d, 0x2e, 0x67, 0x72, 0x70, 0x63, 0x2e, 0x52, 0x65, + 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x0e, 0x2e, 0x67, 0x72, 0x70, 0x63, 0x2e, 0x52, 0x65, 0x73, + 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4e, 0x0a, 0x10, 0x53, 0x65, 0x6e, 0x64, 0x4e, + 0x6f, 0x74, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x28, 0x2e, 0x67, 0x72, + 0x70, 0x63, 0x2e, 0x54, 0x61, 0x73, 0x6b, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x53, 0x65, + 0x6e, 0x64, 0x4e, 0x6f, 0x74, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, + 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x0e, 0x2e, 0x67, 0x72, 0x70, 0x63, 0x2e, 0x52, 0x65, 0x73, + 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2e, 0x3b, 0x67, 0x72, 0x70, + 0x63, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } +var ( + file_services_task_service_proto_rawDescOnce sync.Once + file_services_task_service_proto_rawDescData = file_services_task_service_proto_rawDesc +) + +func file_services_task_service_proto_rawDescGZIP() []byte { + file_services_task_service_proto_rawDescOnce.Do(func() { + file_services_task_service_proto_rawDescData = protoimpl.X.CompressGZIP(file_services_task_service_proto_rawDescData) + }) + return file_services_task_service_proto_rawDescData +} + +var file_services_task_service_proto_msgTypes = make([]protoimpl.MessageInfo, 1) var file_services_task_service_proto_goTypes = []any{ - (*StreamMessage)(nil), // 0: grpc.StreamMessage - (*Request)(nil), // 1: grpc.Request - (*Response)(nil), // 2: grpc.Response + (*TaskServiceSendNotificationRequest)(nil), // 0: grpc.TaskServiceSendNotificationRequest + (*StreamMessage)(nil), // 1: grpc.StreamMessage + (*Request)(nil), // 2: grpc.Request + (*Response)(nil), // 3: grpc.Response } var file_services_task_service_proto_depIdxs = []int32{ - 0, // 0: grpc.TaskService.Subscribe:input_type -> grpc.StreamMessage - 1, // 1: grpc.TaskService.Fetch:input_type -> grpc.Request - 1, // 2: grpc.TaskService.SendNotification:input_type -> grpc.Request - 2, // 3: grpc.TaskService.Subscribe:output_type -> grpc.Response - 2, // 4: grpc.TaskService.Fetch:output_type -> grpc.Response - 2, // 5: grpc.TaskService.SendNotification:output_type -> grpc.Response + 1, // 0: grpc.TaskService.Subscribe:input_type -> grpc.StreamMessage + 2, // 1: grpc.TaskService.Fetch:input_type -> grpc.Request + 0, // 2: grpc.TaskService.SendNotification:input_type -> grpc.TaskServiceSendNotificationRequest + 3, // 3: grpc.TaskService.Subscribe:output_type -> grpc.Response + 3, // 4: grpc.TaskService.Fetch:output_type -> grpc.Response + 3, // 5: grpc.TaskService.SendNotification:output_type -> grpc.Response 3, // [3:6] is the sub-list for method output_type 0, // [0:3] is the sub-list for method input_type 0, // [0:0] is the sub-list for extension type_name @@ -70,18 +147,33 @@ func file_services_task_service_proto_init() { file_entity_request_proto_init() file_entity_response_proto_init() file_entity_stream_message_proto_init() + if !protoimpl.UnsafeEnabled { + file_services_task_service_proto_msgTypes[0].Exporter = func(v any, i int) any { + switch v := v.(*TaskServiceSendNotificationRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + } type x struct{} out := protoimpl.TypeBuilder{ File: protoimpl.DescBuilder{ GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: file_services_task_service_proto_rawDesc, NumEnums: 0, - NumMessages: 0, + NumMessages: 1, NumExtensions: 0, NumServices: 1, }, GoTypes: file_services_task_service_proto_goTypes, DependencyIndexes: file_services_task_service_proto_depIdxs, + MessageInfos: file_services_task_service_proto_msgTypes, }.Build() File_services_task_service_proto = out.File file_services_task_service_proto_rawDesc = nil diff --git a/grpc/task_service_grpc.pb.go b/grpc/task_service_grpc.pb.go index 6cb274dd..8ae3a469 100644 --- a/grpc/task_service_grpc.pb.go +++ b/grpc/task_service_grpc.pb.go @@ -30,7 +30,7 @@ const ( type TaskServiceClient interface { Subscribe(ctx context.Context, opts ...grpc.CallOption) (TaskService_SubscribeClient, error) Fetch(ctx context.Context, in *Request, opts ...grpc.CallOption) (*Response, error) - SendNotification(ctx context.Context, in *Request, opts ...grpc.CallOption) (*Response, error) + SendNotification(ctx context.Context, in *TaskServiceSendNotificationRequest, opts ...grpc.CallOption) (*Response, error) } type taskServiceClient struct { @@ -86,7 +86,7 @@ func (c *taskServiceClient) Fetch(ctx context.Context, in *Request, opts ...grpc return out, nil } -func (c *taskServiceClient) SendNotification(ctx context.Context, in *Request, opts ...grpc.CallOption) (*Response, error) { +func (c *taskServiceClient) SendNotification(ctx context.Context, in *TaskServiceSendNotificationRequest, opts ...grpc.CallOption) (*Response, error) { cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(Response) err := c.cc.Invoke(ctx, TaskService_SendNotification_FullMethodName, in, out, cOpts...) @@ -102,7 +102,7 @@ func (c *taskServiceClient) SendNotification(ctx context.Context, in *Request, o type TaskServiceServer interface { Subscribe(TaskService_SubscribeServer) error Fetch(context.Context, *Request) (*Response, error) - SendNotification(context.Context, *Request) (*Response, error) + SendNotification(context.Context, *TaskServiceSendNotificationRequest) (*Response, error) mustEmbedUnimplementedTaskServiceServer() } @@ -116,7 +116,7 @@ func (UnimplementedTaskServiceServer) Subscribe(TaskService_SubscribeServer) err func (UnimplementedTaskServiceServer) Fetch(context.Context, *Request) (*Response, error) { return nil, status.Errorf(codes.Unimplemented, "method Fetch not implemented") } -func (UnimplementedTaskServiceServer) SendNotification(context.Context, *Request) (*Response, error) { +func (UnimplementedTaskServiceServer) SendNotification(context.Context, *TaskServiceSendNotificationRequest) (*Response, error) { return nil, status.Errorf(codes.Unimplemented, "method SendNotification not implemented") } func (UnimplementedTaskServiceServer) mustEmbedUnimplementedTaskServiceServer() {} @@ -177,7 +177,7 @@ func _TaskService_Fetch_Handler(srv interface{}, ctx context.Context, dec func(i } func _TaskService_SendNotification_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { - in := new(Request) + in := new(TaskServiceSendNotificationRequest) if err := dec(in); err != nil { return nil, err } @@ -189,7 +189,7 @@ func _TaskService_SendNotification_Handler(srv interface{}, ctx context.Context, FullMethod: TaskService_SendNotification_FullMethodName, } handler := func(ctx context.Context, req interface{}) (interface{}, error) { - return srv.(TaskServiceServer).SendNotification(ctx, req.(*Request)) + return srv.(TaskServiceServer).SendNotification(ctx, req.(*TaskServiceSendNotificationRequest)) } return interceptor(ctx, in, info, handler) }