Files
crawlab/core/controllers/task.go

430 lines
12 KiB
Go

package controllers
import (
"os"
"path/filepath"
"strings"
"sync"
"github.com/crawlab-team/crawlab/core/constants"
"github.com/crawlab-team/crawlab/core/interfaces"
"github.com/crawlab-team/crawlab/core/models/models"
"github.com/crawlab-team/crawlab/core/models/service"
mongo2 "github.com/crawlab-team/crawlab/core/mongo"
"github.com/crawlab-team/crawlab/core/spider/admin"
"github.com/crawlab-team/crawlab/core/task/log"
"github.com/crawlab-team/crawlab/core/task/scheduler"
"github.com/crawlab-team/crawlab/core/utils"
"github.com/gin-gonic/gin"
"github.com/juju/errors"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/mongo"
)
// GetTaskByIdParams holds the path parameter consumed by GetTaskById.
type GetTaskByIdParams struct {
	// Id is the task ID taken from the "id" path segment; the tag restricts
	// it to 24 hexadecimal characters.
	Id string `path:"id" description:"Task ID" format:"objectid" pattern:"^[0-9a-fA-F]{24}$"`
}
// GetTaskById fetches a single task as a TaskDTO.
//
// params.Id is parsed as a hex ObjectID; the by-id aggregation pipeline is
// extended by addTaskPipelines before being run against the TaskDTO
// collection. A parse or query failure is returned through GetErrorResponse;
// an empty result yields a NotFound error with a nil response.
func GetTaskById(_ *gin.Context, params *GetTaskByIdParams) (response *Response[models.TaskDTO], err error) {
	taskId, err := primitive.ObjectIDFromHex(params.Id)
	if err != nil {
		return GetErrorResponse[models.TaskDTO](err)
	}

	// by-id match stages followed by the task-specific stages
	stages := addTaskPipelines(service.GetByIdPipeline(taskId))

	// run the aggregation
	var found []models.TaskDTO
	if err = service.GetCollection[models.TaskDTO]().Aggregate(stages, nil).All(&found); err != nil {
		return GetErrorResponse[models.TaskDTO](err)
	}

	// nothing matched the given id
	if len(found) < 1 {
		return nil, errors.NotFoundf("task %s not found", params.Id)
	}
	return GetDataResponse(found[0])
}
// GetTaskList returns one page of tasks as TaskDTOs together with the total
// number of tasks matching the list filter.
//
// The filter, sort option and skip/limit are derived from params. When the
// count is zero the aggregation is skipped and an empty list response is
// returned. Any failure is reported through GetErrorListResponse.
func GetTaskList(_ *gin.Context, params *GetListParams) (response *ListResponse[models.TaskDTO], err error) {
	// translate list parameters into a filter, a sort option and a window
	filter := ConvertToBsonMFromListParams(params)
	sortOpt, err := GetSortOptionFromString(params.Sort)
	if err != nil {
		return GetErrorListResponse[models.TaskDTO](err)
	}
	skip, limit := GetSkipLimitFromListParams(params)

	// count first so an empty result can short-circuit
	count, err := service.NewModelService[models.TaskDTO]().Count(filter)
	if err != nil {
		return GetErrorListResponse[models.TaskDTO](err)
	}
	if count == 0 {
		return GetEmptyListResponse[models.TaskDTO]()
	}

	// pagination stages followed by the task-specific stages
	stages := addTaskPipelines(service.GetPaginationPipeline(filter, sortOpt, skip, limit))

	var page []models.TaskDTO
	if err = service.GetCollection[models.TaskDTO]().Aggregate(stages, nil).All(&page); err != nil {
		return GetErrorListResponse[models.TaskDTO](err)
	}
	return GetListResponse(page, count)
}
// DeleteTaskByIdParams holds the path parameter consumed by DeleteTaskById.
type DeleteTaskByIdParams struct {
	// Id is the task ID taken from the "id" path segment; the tag restricts
	// it to 24 hexadecimal characters.
	Id string `path:"id" description:"Task ID" format:"objectid" pattern:"^[0-9a-fA-F]{24}$"`
}
// DeleteTaskById removes one task, its task stat record and its log directory.
//
// The task and task stat deletions run inside a single transaction. A missing
// task or a failed task deletion aborts the transaction with an error; a
// missing or undeletable task stat is only logged as a warning. Removal of the
// log directory happens after the transaction and is likewise best-effort.
func DeleteTaskById(_ *gin.Context, params *DeleteTaskByIdParams) (response *VoidResponse, err error) {
	taskId, err := primitive.ObjectIDFromHex(params.Id)
	if err != nil {
		return GetErrorVoidResponse(err)
	}

	// database records
	removeRecords := func(_ mongo.SessionContext) error {
		// the task must exist and must be deletable
		if _, getErr := service.NewModelService[models.Task]().GetById(taskId); getErr != nil {
			return getErr
		}
		if delErr := service.NewModelService[models.Task]().DeleteById(taskId); delErr != nil {
			return delErr
		}

		// the task stat is optional: failures are logged, not propagated
		if _, getErr := service.NewModelService[models.TaskStat]().GetById(taskId); getErr != nil {
			logger.Warnf("delete task stat error: %s", getErr.Error())
			return nil
		}
		if delErr := service.NewModelService[models.TaskStat]().DeleteById(taskId); delErr != nil {
			logger.Warnf("delete task stat error: %s", delErr.Error())
		}
		return nil
	}
	if txErr := mongo2.RunTransaction(removeRecords); txErr != nil {
		return GetErrorVoidResponse(txErr)
	}

	// log directory of the task
	logDir := filepath.Join(utils.GetTaskLogPath(), taskId.Hex())
	if rmErr := os.RemoveAll(logDir); rmErr != nil {
		logger.Warnf("failed to remove task log directory: %s", logDir)
	}

	return GetVoidResponse()
}
// DeleteTaskListParams is the request body consumed by DeleteTaskList.
type DeleteTaskListParams struct {
	// Ids lists the tasks to delete; DeleteTaskList parses each entry as a
	// hex ObjectID.
	Ids []string `json:"ids"`
}
// DeleteTaskList removes several tasks, their task stat records and their log
// directories.
//
// Every entry of params.Ids must be a valid hex ObjectID; the first invalid
// one aborts the request. Tasks and task stats are deleted inside one
// transaction: a failure deleting tasks is returned to the caller, while a
// failure deleting task stats is only logged as a warning. Log directories are
// removed concurrently afterwards, and the call waits for all removals to
// finish; removal failures are logged, not returned.
func DeleteTaskList(_ *gin.Context, params *DeleteTaskListParams) (response *VoidResponse, err error) {
	var ids []primitive.ObjectID
	for _, hexId := range params.Ids {
		id, err := primitive.ObjectIDFromHex(hexId)
		if err != nil {
			return GetErrorVoidResponse(err)
		}
		ids = append(ids, id)
	}

	// Task stats are keyed by the task's _id, so one filter selects both.
	filter := bson.M{
		"_id": bson.M{
			"$in": ids,
		},
	}

	if err := mongo2.RunTransaction(func(_ mongo.SessionContext) error {
		// delete tasks
		if err := service.NewModelService[models.Task]().DeleteMany(filter); err != nil {
			return err
		}

		// delete task stats (best-effort). This must target TaskStat: the
		// previous code issued the delete against Task a second time, which
		// left every task stat record behind.
		if err := service.NewModelService[models.TaskStat]().DeleteMany(filter); err != nil {
			logger.Warnf("delete task stat error: %s", err.Error())
			return nil
		}
		return nil
	}); err != nil {
		return GetErrorVoidResponse(err)
	}

	// delete task logs, one goroutine per task
	var wg sync.WaitGroup
	wg.Add(len(ids))
	for _, id := range ids {
		go func(taskId primitive.ObjectID) {
			// deferred so the counter is released on every exit path
			defer wg.Done()

			logPath := filepath.Join(utils.GetTaskLogPath(), taskId.Hex())
			if err := os.RemoveAll(logPath); err != nil {
				logger.Warnf("failed to remove task log directory: %s", logPath)
			}
		}(id)
	}
	wg.Wait()

	return GetVoidResponse()
}
// PostTaskRunParams is the request body consumed by PostTaskRun.
//
// SpiderId is required and, like every entry of NodeIds, is parsed as a hex
// ObjectID by the handler. Mode, Cmd, Param and Priority are copied into the
// spider run options as given.
type PostTaskRunParams struct {
	SpiderId string   `json:"spider_id" validate:"required"`
	Mode     string   `json:"mode"`
	NodeIds  []string `json:"node_ids"`
	Cmd      string   `json:"cmd"`
	Param    string   `json:"param"`
	Priority int      `json:"priority"`
}
// PostTaskRun schedules a run of the spider named in params and returns the
// IDs of the tasks produced by the spider admin service.
//
// SpiderId and every NodeIds entry must be valid hex ObjectIDs, and the spider
// must exist; otherwise an error response is returned. When a user is attached
// to the gin context, that user's ID is recorded in the run options.
func PostTaskRun(c *gin.Context, params *PostTaskRunParams) (response *Response[[]primitive.ObjectID], err error) {
	spiderId, err := primitive.ObjectIDFromHex(params.SpiderId)
	if err != nil {
		return GetErrorResponse[[]primitive.ObjectID](err)
	}

	// target nodes (ranging over a nil slice is a no-op, so no nil guard)
	var targetNodes []primitive.ObjectID
	for _, raw := range params.NodeIds {
		parsed, parseErr := primitive.ObjectIDFromHex(raw)
		if parseErr != nil {
			return GetErrorResponse[[]primitive.ObjectID](parseErr)
		}
		targetNodes = append(targetNodes, parsed)
	}

	// the spider must exist
	spider, err := service.NewModelService[models.Spider]().GetById(spiderId)
	if err != nil {
		return GetErrorResponse[[]primitive.ObjectID](err)
	}

	// run options, including the requesting user when known
	runOpts := &interfaces.SpiderRunOptions{
		Mode:     params.Mode,
		NodeIds:  targetNodes,
		Cmd:      params.Cmd,
		Param:    params.Param,
		Priority: params.Priority,
	}
	currentUser := GetUserFromContext(c)
	if currentUser != nil {
		runOpts.UserId = currentUser.Id
	}

	// hand over to the spider admin service
	scheduled, err := admin.GetSpiderAdminService().Schedule(spider.Id, runOpts)
	if err != nil {
		return GetErrorResponse[[]primitive.ObjectID](err)
	}
	return GetDataResponse(scheduled)
}
// PostTaskRestartParams holds the path parameter consumed by PostTaskRestart.
type PostTaskRestartParams struct {
	// Id is the task ID taken from the "id" path segment; the tag restricts
	// it to 24 hexadecimal characters.
	Id string `path:"id" description:"Task ID" format:"objectid" pattern:"^[0-9a-fA-F]{24}$"`
}
// PostTaskRestart schedules a new run based on an existing task and returns
// the IDs of the tasks produced by the spider admin service.
//
// The run options are copied from the stored task. If the task has a non-zero
// NodeId, the run is pinned to that node: NodeIds becomes that single node and
// Mode becomes constants.RunTypeSelectedNodes. When a user is attached to the
// gin context, that user's ID is recorded in the run options.
func PostTaskRestart(c *gin.Context, params *PostTaskRestartParams) (response *Response[[]primitive.ObjectID], err error) {
	taskId, err := primitive.ObjectIDFromHex(params.Id)
	if err != nil {
		return GetErrorResponse[[]primitive.ObjectID](err)
	}

	// load the task being restarted
	task, err := service.NewModelService[models.Task]().GetById(taskId)
	if err != nil {
		return GetErrorResponse[[]primitive.ObjectID](err)
	}

	// start from the task's own settings
	runOpts := &interfaces.SpiderRunOptions{
		Mode:     task.Mode,
		NodeIds:  task.NodeIds,
		Cmd:      task.Cmd,
		Param:    task.Param,
		Priority: task.Priority,
	}

	// a task with an assigned node is re-run on exactly that node
	if assigned := task.NodeId; !assigned.IsZero() {
		runOpts.NodeIds = []primitive.ObjectID{assigned}
		runOpts.Mode = constants.RunTypeSelectedNodes
	}

	// requesting user, when known
	currentUser := GetUserFromContext(c)
	if currentUser != nil {
		runOpts.UserId = currentUser.Id
	}

	// hand over to the spider admin service
	scheduled, err := admin.GetSpiderAdminService().Schedule(task.SpiderId, runOpts)
	if err != nil {
		return GetErrorResponse[[]primitive.ObjectID](err)
	}
	return GetDataResponse(scheduled)
}
// PostTaskCancelParams carries the inputs consumed by PostTaskCancel.
//
// Id is the task ID from the "id" path segment (24 hexadecimal characters per
// its tag). Force comes from the JSON body and is tagged with a default of
// false; when it is set the handler skips its cancellable-status check.
type PostTaskCancelParams struct {
	Id    string `path:"id" description:"Task ID" format:"objectid" pattern:"^[0-9a-fA-F]{24}$"`
	Force bool   `json:"force,omitempty" description:"Force cancel" default:"false"`
}
// PostTaskCancel asks the task scheduler service to cancel a task on behalf of
// the requesting user.
//
// The task must exist. Unless params.Force is set, the task's status must also
// satisfy utils.IsCancellable, otherwise "task is not cancellable" is
// returned. If no user is attached to the gin context, an Unauthorized error
// is returned instead of dereferencing a nil user.
func PostTaskCancel(c *gin.Context, params *PostTaskCancelParams) (response *VoidResponse, err error) {
	// id
	id, err := primitive.ObjectIDFromHex(params.Id)
	if err != nil {
		return GetErrorVoidResponse(err)
	}

	// task
	t, err := service.NewModelService[models.Task]().GetById(id)
	if err != nil {
		return GetErrorVoidResponse(err)
	}

	// validate - only check if task is cancellable when force is false
	if !params.Force && !utils.IsCancellable(t.Status) {
		return GetErrorVoidResponse(errors.New("task is not cancellable"))
	}

	// user - GetUserFromContext may return nil (the other handlers in this
	// file guard against it), and the scheduler needs the user's ID
	u := GetUserFromContext(c)
	if u == nil {
		return GetErrorVoidResponse(errors.Unauthorizedf("user not found in context"))
	}

	// cancel
	schedulerSvc := scheduler.GetTaskSchedulerService()
	if err = schedulerSvc.Cancel(id, u.Id, params.Force); err != nil {
		return GetErrorVoidResponse(err)
	}

	return GetVoidResponse()
}
// GetTaskLogsParams carries the path and query parameters consumed by
// GetTaskLogs.
//
// Id is the task ID from the "id" path segment. Page and Size select the
// window of log lines (tag defaults: 1 and 10000, both with a minimum of 1).
// Latest chooses whether the latest logs are requested (tag default: true).
type GetTaskLogsParams struct {
	Id     string `path:"id" description:"Task ID" format:"objectid" pattern:"^[0-9a-fA-F]{24}$"`
	Page   int    `query:"page" description:"Page (default: 1)" default:"1" minimum:"1"`
	Size   int    `query:"size" description:"Size (default: 10000)" default:"10000" minimum:"1"`
	Latest bool   `query:"latest" description:"Whether to get latest logs (default: true)" default:"true"`
}
// GetTaskLogs returns one page of log lines for a task together with the total
// line count reported by the file log driver.
//
// The skip offset is (Page - 1) * Size, and Latest is handed to the driver
// unchanged. NOTE(review): the comments this replaces stated that the driver
// counts the offset from the end of the log when Latest is true and from the
// beginning otherwise — confirm against the log driver.
//
// A Find error whose message ends in "Status:404 Not Found" is reported as an
// empty list rather than as an error; every other failure is returned through
// GetErrorListResponse.
func GetTaskLogs(_ *gin.Context, params *GetTaskLogsParams) (response *ListResponse[string], err error) {
	// id
	id, err := primitive.ObjectIDFromHex(params.Id)
	if err != nil {
		return GetErrorListResponse[string](err)
	}

	// total count first, for pagination
	logDriver := log.GetFileLogDriver()
	total, err := logDriver.Count(id.Hex(), "")
	if err != nil {
		return GetErrorListResponse[string](err)
	}

	// The same offset is used whether or not Latest is set, so no branching
	// is needed here.
	skip := (params.Page - 1) * params.Size

	// logs
	logs, err := logDriver.Find(id.Hex(), "", skip, params.Size, params.Latest)
	if err != nil {
		if strings.HasSuffix(err.Error(), "Status:404 Not Found") {
			return GetListResponse[string](nil, 0)
		}
		return GetErrorListResponse[string](err)
	}

	return GetListResponse(logs, total)
}
// GetTaskResultsParams describes a task ID path parameter plus page/size query
// parameters (tag defaults: page 1, size 10, both with a minimum of 1).
//
// NOTE(review): GetTaskResults in this file is declared with
// *GetSpiderResultsParams rather than this type — confirm whether this struct
// is used elsewhere.
type GetTaskResultsParams struct {
	Id   string `path:"id" description:"Task ID" format:"objectid" pattern:"^[0-9a-fA-F]{24}$"`
	Page int    `query:"page" description:"Page" default:"1" minimum:"1"`
	Size int    `query:"size" description:"Size" default:"10" minimum:"1"`
}
// GetTaskResults returns one page of result documents produced by a task,
// newest _id first, together with the total number of matching documents.
//
// The task is loaded to find its spider, and the spider's ColName names the
// collection that is queried. The filter built from the gin context is
// restricted to documents whose "_tid" equals the task's ID.
//
// NOTE(review): the parameter type is *GetSpiderResultsParams (declared
// outside this block), not GetTaskResultsParams — confirm this is intentional.
func GetTaskResults(c *gin.Context, params *GetSpiderResultsParams) (response *ListResponse[bson.M], err error) {
	taskId, err := primitive.ObjectIDFromHex(params.Id)
	if err != nil {
		return GetErrorListResponse[bson.M](errors.BadRequestf("invalid id format"))
	}

	// task, then the spider that owns the result collection
	task, err := service.NewModelService[models.Task]().GetById(taskId)
	if err != nil {
		return GetErrorListResponse[bson.M](err)
	}
	spider, err := service.NewModelService[models.Spider]().GetById(task.SpiderId)
	if err != nil {
		return GetErrorListResponse[bson.M](err)
	}

	// filter from the request, always scoped to this task
	filter := ConvertToBsonMFromContext(c)
	if filter == nil {
		filter = bson.M{}
	}
	filter["_tid"] = task.Id

	// requested page, sorted by _id descending
	listOpts := mongo2.ListOptions{
		Sort:  []mongo2.ListSort{{"_id", mongo2.SortDirectionDesc}},
		Skip:  (params.Page - 1) * params.Size,
		Limit: params.Size,
	}
	var rows []bson.M
	if err = mongo2.GetMongoCol(spider.ColName).Find(filter, mongo2.GetMongoOpts(&listOpts)).All(&rows); err != nil {
		return GetErrorListResponse[bson.M](err)
	}

	// total number of documents matching the same filter
	count, err := mongo2.GetMongoCol(spider.ColName).Count(filter)
	if err != nil {
		return GetErrorListResponse[bson.M](err)
	}
	return GetListResponse(rows, count)
}
// addTaskPipelines appends the task-related join stages to pipelines and
// returns the extended slice: first the TaskStat join on "_id" = "_id" stored
// as "_stat", then the default joins for Node, Spider and Schedule, in that
// order.
func addTaskPipelines(pipelines []bson.D) []bson.D {
	statJoin := service.GetJoinPipeline[models.TaskStat]("_id", "_id", "_stat")
	nodeJoin := service.GetDefaultJoinPipeline[models.Node]()
	spiderJoin := service.GetDefaultJoinPipeline[models.Spider]()
	scheduleJoin := service.GetDefaultJoinPipeline[models.Schedule]()

	for _, stages := range [][]bson.D{statJoin, nodeJoin, spiderJoin, scheduleJoin} {
		pipelines = append(pipelines, stages...)
	}
	return pipelines
}