From e8edf9dac4dcd82d55cb7f88e870c78f2cb95c8d Mon Sep 17 00:00:00 2001 From: marvzhang Date: Thu, 5 Dec 2019 11:51:54 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E8=BF=90=E8=A1=8C=E7=B1=BB?= =?UTF-8?q?=E5=88=AB?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/constants/task.go | 6 +++ backend/routes/task.go | 73 ++++++++++++++++++++++++++---------- backend/services/schedule.go | 4 +- backend/services/task.go | 28 ++++++++++++++ 4 files changed, 90 insertions(+), 21 deletions(-) diff --git a/backend/constants/task.go b/backend/constants/task.go index b6fb615c..63144e8b 100644 --- a/backend/constants/task.go +++ b/backend/constants/task.go @@ -19,3 +19,9 @@ const ( TaskFinish string = "finish" TaskCancel string = "cancel" ) + +const ( + RunTypeAllNodes string = "all-nodes" + RunTypeRandom string = "random" + RunTypeSelectedNodes string = "selected-nodes" +) diff --git a/backend/routes/task.go b/backend/routes/task.go index 9c0aa43f..45a16675 100644 --- a/backend/routes/task.go +++ b/backend/routes/task.go @@ -9,7 +9,6 @@ import ( "encoding/csv" "github.com/gin-gonic/gin" "github.com/globalsign/mgo/bson" - uuid "github.com/satori/go.uuid" "net/http" ) @@ -86,32 +85,68 @@ func GetTask(c *gin.Context) { } func PutTask(c *gin.Context) { - // 生成任务ID - id := uuid.NewV4() + type TaskRequestBody struct { + SpiderId bson.ObjectId `json:"spider_id"` + RunType string `json:"run_type"` + NodeIds []bson.ObjectId `json:"node_ids"` + Param string `json:"param"` + } // 绑定数据 - var t model.Task - if err := c.ShouldBindJSON(&t); err != nil { + var reqBody TaskRequestBody + if err := c.ShouldBindJSON(&reqBody); err != nil { HandleError(http.StatusBadRequest, c, err) return } - t.Id = id.String() - t.Status = constants.StatusPending - // 如果没有传入node_id,则置为null - if t.NodeId.Hex() == "" { - t.NodeId = bson.ObjectIdHex(constants.ObjectIdNull) - } + if reqBody.RunType == constants.RunTypeAllNodes { + // 所有节点 + nodes, err := model.GetNodeList(nil) + if err != nil { + HandleError(http.StatusInternalServerError, c, err) + return + } + for _, node := range nodes { + t := model.Task{ + SpiderId: reqBody.SpiderId, + NodeId: node.Id, + Param: reqBody.Param, + } - // 将任务存入数据库 - if err := model.AddTask(t); err != nil { - HandleError(http.StatusInternalServerError, c, err) - return - } + if err := services.AddTask(t); err != nil { + HandleError(http.StatusInternalServerError, c, err) + return + } + } - // 加入任务队列 - if err := services.AssignTask(t); err != nil { - HandleError(http.StatusInternalServerError, c, err) + } else if reqBody.RunType == constants.RunTypeRandom { + // 随机 + t := model.Task{ + SpiderId: reqBody.SpiderId, + Param: reqBody.Param, + } + if err := services.AddTask(t); err != nil { + HandleError(http.StatusInternalServerError, c, err) + return + } + + } else if reqBody.RunType == constants.RunTypeSelectedNodes { + // 指定节点 + for _, nodeId := range reqBody.NodeIds { + t := model.Task{ + SpiderId: reqBody.SpiderId, + NodeId: nodeId, + Param: reqBody.Param, + } + + if err := services.AddTask(t); err != nil { + HandleError(http.StatusInternalServerError, c, err) + return + } + } + + } else { + HandleErrorF(http.StatusBadRequest, c, "invalid run_type") return } diff --git a/backend/services/schedule.go b/backend/services/schedule.go index d4c1635b..52a6492e 100644 --- a/backend/services/schedule.go +++ b/backend/services/schedule.go @@ -15,7 +15,7 @@ type Scheduler struct { cron *cron.Cron } -func AddTask(s model.Schedule) func() { +func AddScheduleTask(s model.Schedule) func() { return func() { node, err := model.GetNodeByKey(s.NodeKey) if err != nil || node.Id.Hex() == "" { @@ -97,7 +97,7 @@ func (s *Scheduler) AddJob(job model.Schedule) error { spec := job.Cron // 添加任务 - eid, err := s.cron.AddFunc(spec, AddTask(job)) + eid, err := s.cron.AddFunc(spec, AddScheduleTask(job)) if err != nil { log.Errorf("add func task error: %s", err.Error()) debug.PrintStack() diff --git a/backend/services/task.go b/backend/services/task.go index 0339118a..7267ec5b 100644 --- a/backend/services/task.go +++ b/backend/services/task.go @@ -10,6 +10,8 @@ import ( "encoding/json" "errors" "github.com/apex/log" + "github.com/globalsign/mgo/bson" + uuid "github.com/satori/go.uuid" "github.com/spf13/viper" "os" "os/exec" @@ -581,6 +583,32 @@ func CancelTask(id string) (err error) { return nil } +func AddTask(t model.Task) error { + // 生成任务ID + id := uuid.NewV4() + t.Id = id.String() + + // 设置任务状态 + t.Status = constants.StatusPending + + // 如果没有传入node_id,则置为null + if t.NodeId.Hex() == "" { + t.NodeId = bson.ObjectIdHex(constants.ObjectIdNull) + } + + // 将任务存入数据库 + if err := model.AddTask(t); err != nil { + return err + } + + // 加入任务队列 + if err := AssignTask(t); err != nil { + return err + } + + return nil +} + func HandleTaskError(t model.Task, err error) { log.Error("handle task error:" + err.Error()) t.Status = constants.StatusError