From 975c601f38b727d68143377da8472a2ac678a914 Mon Sep 17 00:00:00 2001 From: Marvin Zhang Date: Mon, 9 Jun 2025 10:21:00 +0800 Subject: [PATCH] refactor: simplify list retrieval methods and enhance error handling --- core/controllers/base.go | 35 ------ core/controllers/project.go | 23 ++-- core/controllers/router.go | 14 +++ core/controllers/schedule.go | 101 ++++++++++++++++ core/controllers/spider.go | 205 ++++++++++++++++++++------------- core/controllers/task_test.go | 1 - core/controllers/utils.go | 42 ++----- core/models/models/schedule.go | 3 + core/models/models/spider.go | 7 +- core/models/models/task.go | 11 +- 10 files changed, 271 insertions(+), 171 deletions(-) diff --git a/core/controllers/base.go b/core/controllers/base.go index fcfc931f..b96cd3da 100644 --- a/core/controllers/base.go +++ b/core/controllers/base.go @@ -67,15 +67,9 @@ type GetListParams struct { Sort string `query:"sort" default:"-_id" description:"Sort options"` Page int `query:"page" default:"1" description:"Page number" minimum:"1"` Size int `query:"size" default:"10" description:"Page size" minimum:"1"` - All bool `query:"all" default:"false" description:"Whether to get all items"` } func (ctr *BaseController[T]) GetList(_ *gin.Context, params *GetListParams) (response *ListResponse[T], err error) { - // get all if query field "all" is set true - if params.All { - return ctr.GetAll(params) - } - return ctr.GetWithPagination(params) } @@ -283,35 +277,6 @@ func (ctr *BaseController[T]) DeleteList(_ *gin.Context, params *DeleteListParam return GetDataResponse(emptyModel) } -// GetAll retrieves all items based on filter and sort -func (ctr *BaseController[T]) GetAll(params *GetListParams) (response *ListResponse[T], err error) { - // Get filter query - query := ConvertToBsonMFromListParams(params) - - // Get sort options - sort, err := GetSortOptionFromString(params.Sort) - if err != nil { - return GetErrorListResponse[T](errors.BadRequestf("invalid sort format: %v", err)) - } - - // 
Get models - models, err := ctr.modelSvc.GetMany(query, &mongo.FindOptions{ - Sort: sort, - }) - if err != nil { - return nil, err - } - - // Total count - total, err := ctr.modelSvc.Count(query) - if err != nil { - return nil, err - } - - // Response - return GetListResponse(models, total) -} - // GetWithPagination retrieves items with pagination func (ctr *BaseController[T]) GetWithPagination(params *GetListParams) (response *ListResponse[T], err error) { // Get filter query diff --git a/core/controllers/project.go b/core/controllers/project.go index 8f81b65f..59a7e263 100644 --- a/core/controllers/project.go +++ b/core/controllers/project.go @@ -11,11 +11,7 @@ import ( mongo2 "go.mongodb.org/mongo-driver/mongo" ) -func GetProjectList(c *gin.Context, params *GetListParams) (response *ListResponse[models.Project], err error) { - if params.All { - return NewController[models.Project]().GetAll(params) - } - +func GetProjectList(_ *gin.Context, params *GetListParams) (response *ListResponse[models.Project], err error) { query := ConvertToBsonMFromListParams(params) sort, err := GetSortOptionFromString(params.Sort) @@ -31,20 +27,18 @@ func GetProjectList(c *gin.Context, params *GetListParams) (response *ListRespon }) if err != nil { if err.Error() != mongo2.ErrNoDocuments.Error() { - HandleErrorInternalServerError(c, err) + return GetErrorListResponse[models.Project](err) } return } if len(projects) == 0 { - HandleSuccessWithListData(c, []models.Project{}, 0) - return + return GetEmptyListResponse[models.Project]() } // total count total, err := service.NewModelService[models.Project]().Count(query) if err != nil { - HandleErrorInternalServerError(c, err) - return + return GetErrorListResponse[models.Project](err) } // project ids @@ -64,16 +58,13 @@ func GetProjectList(c *gin.Context, params *GetListParams) (response *ListRespon }, }, nil) if err != nil { - HandleErrorInternalServerError(c, err) - return + return GetErrorListResponse[models.Project](err) } for _, s 
:= range spiders { _, ok := cache[s.ProjectId] - if !ok { - HandleErrorInternalServerError(c, errors.New("project id not found")) - return + if ok { + cache[s.ProjectId]++ } - cache[s.ProjectId]++ } // assign diff --git a/core/controllers/router.go b/core/controllers/router.go index 1bf765ec..09c25e4a 100644 --- a/core/controllers/router.go +++ b/core/controllers/router.go @@ -376,6 +376,20 @@ func InitRoutes(app *gin.Engine) (err error) { }, }...)) RegisterController(groups.AuthGroup.Group("", "Schedules", "APIs for schedules management"), "/schedules", NewController[models.Schedule]([]Action{ + { + Method: http.MethodGet, + Path: "/:id", + Name: "Get Schedule by ID", + Description: "Get a single schedule by ID", + HandlerFunc: GetScheduleById, + }, + { + Method: http.MethodGet, + Path: "", + Name: "Get Schedule List", + Description: "Get a list of schedules", + HandlerFunc: GetScheduleList, + }, { Method: http.MethodPost, Path: "", diff --git a/core/controllers/schedule.go b/core/controllers/schedule.go index bb29f79b..e7c1e71d 100644 --- a/core/controllers/schedule.go +++ b/core/controllers/schedule.go @@ -2,6 +2,9 @@ package controllers import ( errors2 "errors" + mongo2 "github.com/crawlab-team/crawlab/core/mongo" + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/mongo" "github.com/crawlab-team/crawlab/core/interfaces" "github.com/crawlab-team/crawlab/core/models/models" @@ -13,6 +16,104 @@ import ( "go.mongodb.org/mongo-driver/bson/primitive" ) +// GetScheduleById handles getting a schedule by ID; a malformed ID yields a bad-request error and an unknown ID a not-found error +func GetScheduleById(_ *gin.Context, params *GetByIdParams) (response *Response[models.Schedule], err error) { + id, err := primitive.ObjectIDFromHex(params.Id) + if err != nil { + return GetErrorResponse[models.Schedule](errors.BadRequestf("invalid id format")) + } + s, err := service.NewModelService[models.Schedule]().GetById(id) + if errors.Is(err, mongo.ErrNoDocuments) { + return GetErrorResponse[models.Schedule](errors.NotFoundf("schedule not 
found")) + } + if err != nil { + return GetErrorResponse[models.Schedule](err) + } + + // spider + if !s.SpiderId.IsZero() { + s.Spider, err = service.NewModelService[models.Spider]().GetById(s.SpiderId) + if err != nil { + if !errors.Is(err, mongo.ErrNoDocuments) { + return GetErrorResponse[models.Schedule](err) + } + } + } + + return GetDataResponse(*s) +} + +func GetScheduleList(_ *gin.Context, params *GetListParams) (response *ListResponse[models.Schedule], err error) { + query := ConvertToBsonMFromListParams(params) + + sort, err := GetSortOptionFromString(params.Sort) + if err != nil { + return GetErrorListResponse[models.Schedule](errors.BadRequestf("invalid request parameters: %v", err)) + } + + schedules, err := service.NewModelService[models.Schedule]().GetMany(query, &mongo2.FindOptions{ + Sort: sort, + Skip: params.Size * (params.Page - 1), + Limit: params.Size, + }) + if err != nil { + if !errors.Is(err, mongo.ErrNoDocuments) { + return GetErrorListResponse[models.Schedule](err) + } + return GetListResponse[models.Schedule]([]models.Schedule{}, 0) + } + if len(schedules) == 0 { + return GetListResponse[models.Schedule]([]models.Schedule{}, 0) + } + + // total count + total, err := service.NewModelService[models.Schedule]().Count(query) + if err != nil { + return GetErrorListResponse[models.Schedule](err) + } + + // ids + var ids []primitive.ObjectID + var spiderIds []primitive.ObjectID + for _, s := range schedules { + ids = append(ids, s.Id) + if !s.SpiderId.IsZero() { + spiderIds = append(spiderIds, s.SpiderId) + } + } + + // spider dict cache + var spiders []models.Spider + if len(spiderIds) > 0 { + spiders, err = service.NewModelService[models.Spider]().GetMany(bson.M{"_id": bson.M{"$in": spiderIds}}, nil) + if err != nil { + return GetErrorListResponse[models.Schedule](err) + } + } + dictSpider := map[primitive.ObjectID]models.Spider{} + for _, p := range spiders { + dictSpider[p.Id] = p + } + + // iterate list again + var data []models.Schedule + 
for _, s := range schedules { + // spider + if !s.SpiderId.IsZero() { + p, ok := dictSpider[s.SpiderId] + if ok { + s.Spider = &p + } + } + + // add to list + data = append(data, s) + } + + // response + return GetListResponse(data, total) +} + type PostScheduleParams struct { Data models.Schedule `json:"data" description:"The data to create" validate:"required"` } diff --git a/core/controllers/spider.go b/core/controllers/spider.go index a08006fa..12cb6502 100644 --- a/core/controllers/spider.go +++ b/core/controllers/spider.go @@ -45,6 +45,16 @@ func GetSpiderById(_ *gin.Context, params *GetByIdParams) (response *Response[mo } } + // project + if !s.ProjectId.IsZero() { + s.Project, err = service.NewModelService[models.Project]().GetById(s.ProjectId) + if err != nil { + if !errors.Is(err, mongo.ErrNoDocuments) { + return GetErrorResponse[models.Spider](err) + } + } + } + // data collection (compatible to old version) if s.ColName == "" && !s.ColId.IsZero() { col, err := service.NewModelService[models.DataCollection]().GetById(s.ColId) @@ -72,23 +82,17 @@ func GetSpiderById(_ *gin.Context, params *GetByIdParams) (response *Response[mo // GetSpiderList handles getting a list of spiders with optional stats func GetSpiderList(c *gin.Context, params *GetListParams) (response *ListResponse[models.Spider], err error) { - // get all list - all := params.All - if all { - return NewController[models.Spider]().GetAll(params) - } - // get list withStats := c.Query("stats") if withStats == "" { - return NewController[models.Spider]().GetList(c, params) + return getSpiderList(params) } // get list with stats return getSpiderListWithStats(params) } -func getSpiderListWithStats(params *GetListParams) (response *ListResponse[models.Spider], err error) { +func getSpiderList(params *GetListParams) (response *ListResponse[models.Spider], err error) { query := ConvertToBsonMFromListParams(params) sort, err := GetSortOptionFromString(params.Sort) @@ -111,83 +115,40 @@ func 
getSpiderListWithStats(params *GetListParams) (response *ListResponse[model return GetListResponse[models.Spider]([]models.Spider{}, 0) } - // ids - var ids []primitive.ObjectID - var gitIds []primitive.ObjectID - for _, s := range spiders { - ids = append(ids, s.Id) - if !s.GitId.IsZero() { - gitIds = append(gitIds, s.GitId) - } - } - // total count total, err := service.NewModelService[models.Spider]().Count(query) if err != nil { return GetErrorListResponse[models.Spider](err) } - // stat list - spiderStats, err := service.NewModelService[models.SpiderStat]().GetMany(bson.M{"_id": bson.M{"$in": ids}}, nil) - if err != nil { - return GetErrorListResponse[models.Spider](err) - } - - // cache stat list to dict - dict := map[primitive.ObjectID]models.SpiderStat{} - var taskIds []primitive.ObjectID - for _, st := range spiderStats { - if st.Tasks > 0 { - taskCount := int64(st.Tasks) - st.AverageWaitDuration = int64(math.Round(float64(st.WaitDuration) / float64(taskCount))) - st.AverageRuntimeDuration = int64(math.Round(float64(st.RuntimeDuration) / float64(taskCount))) - st.AverageTotalDuration = int64(math.Round(float64(st.TotalDuration) / float64(taskCount))) + // ids + var ids []primitive.ObjectID + var gitIds []primitive.ObjectID + var projectIds []primitive.ObjectID + for _, s := range spiders { + ids = append(ids, s.Id) + if !s.GitId.IsZero() { + gitIds = append(gitIds, s.GitId) } - dict[st.Id] = st - - if !st.LastTaskId.IsZero() { - taskIds = append(taskIds, st.LastTaskId) + if !s.ProjectId.IsZero() { + projectIds = append(projectIds, s.ProjectId) } } - // task list and stats - var tasks []models.Task - dictTask := map[primitive.ObjectID]models.Task{} - dictTaskStat := map[primitive.ObjectID]models.TaskStat{} - if len(taskIds) > 0 { - // task list - queryTask := bson.M{ - "_id": bson.M{ - "$in": taskIds, - }, - } - tasks, err = service.NewModelService[models.Task]().GetMany(queryTask, nil) + // project dict cache + var projects []models.Project + if 
len(projectIds) > 0 { + projects, err = service.NewModelService[models.Project]().GetMany(bson.M{"_id": bson.M{"$in": projectIds}}, nil) if err != nil { return GetErrorListResponse[models.Spider](err) } - - // task stats list - taskStats, err := service.NewModelService[models.TaskStat]().GetMany(queryTask, nil) - if err != nil { - return GetErrorListResponse[models.Spider](err) - } - - // cache task stats to dict - for _, st := range taskStats { - dictTaskStat[st.Id] = st - } - - // cache task list to dict - for _, t := range tasks { - st, ok := dictTaskStat[t.Id] - if ok { - t.Stat = &st - } - dictTask[t.SpiderId] = t - } + } + dictProject := map[primitive.ObjectID]models.Project{} + for _, p := range projects { + dictProject[p.Id] = p } - // git list + // git dict cache var gits []models.Git if len(gitIds) > 0 && utils.IsPro() { gits, err = service.NewModelService[models.Git]().GetMany(bson.M{"_id": bson.M{"$in": gitIds}}, nil) @@ -195,8 +156,6 @@ func getSpiderListWithStats(params *GetListParams) (response *ListResponse[model return GetErrorListResponse[models.Spider](err) } } - - // cache git list to dict dictGit := map[primitive.ObjectID]models.Git{} for _, g := range gits { dictGit[g.Id] = g @@ -205,15 +164,11 @@ func getSpiderListWithStats(params *GetListParams) (response *ListResponse[model // iterate list again var data []models.Spider for _, s := range spiders { - // spider stat - st, ok := dict[s.Id] - if ok { - s.Stat = &st - - // last task - t, ok := dictTask[s.Id] + // project + if !s.ProjectId.IsZero() { + p, ok := dictProject[s.ProjectId] if ok { - s.Stat.LastTask = &t + s.Project = &p } } @@ -233,6 +188,96 @@ func getSpiderListWithStats(params *GetListParams) (response *ListResponse[model return GetListResponse(data, total) } +func getSpiderListWithStats(params *GetListParams) (response *ListResponse[models.Spider], err error) { + response, err = getSpiderList(params) + if err != nil { + return GetErrorListResponse[models.Spider](err) + } + + // 
spider ids (kept non-nil so that "$in" receives an array even when the page is empty) + ids := make([]primitive.ObjectID, 0, len(response.Data)) + for _, s := range response.Data { + ids = append(ids, s.Id) + } + + // spider stat dict, keyed by stat id (looked up below with the spider id) + spiderStats, err := service.NewModelService[models.SpiderStat]().GetMany(bson.M{"_id": bson.M{"$in": ids}}, nil) + if err != nil { + return GetErrorListResponse[models.Spider](err) + } + dictSpiderStat := map[primitive.ObjectID]models.SpiderStat{} + + // fill averages per stat and collect the last task id of every stat + var lastTasks []models.Task + var lastTaskIds []primitive.ObjectID + for _, st := range spiderStats { + if st.Tasks > 0 { + taskCount := int64(st.Tasks) + st.AverageWaitDuration = int64(math.Round(float64(st.WaitDuration) / float64(taskCount))) + st.AverageRuntimeDuration = int64(math.Round(float64(st.RuntimeDuration) / float64(taskCount))) + st.AverageTotalDuration = int64(math.Round(float64(st.TotalDuration) / float64(taskCount))) + } + dictSpiderStat[st.Id] = st + + if !st.LastTaskId.IsZero() { + lastTaskIds = append(lastTaskIds, st.LastTaskId) + } + } + dictLastTask := map[primitive.ObjectID]models.Task{} + dictLastTaskStat := map[primitive.ObjectID]models.TaskStat{} + if len(lastTaskIds) > 0 { + // task list + queryTask := bson.M{ + "_id": bson.M{ + "$in": lastTaskIds, + }, + } + lastTasks, err = service.NewModelService[models.Task]().GetMany(queryTask, nil) + if err != nil { + return GetErrorListResponse[models.Spider](err) + } + + // task stats list (queried with the same task ids) + taskStats, err := service.NewModelService[models.TaskStat]().GetMany(queryTask, nil) + if err != nil { + return GetErrorListResponse[models.Spider](err) + } + + for _, st := range taskStats { + dictLastTaskStat[st.Id] = st + } + + for _, t := range lastTasks { + st, ok := dictLastTaskStat[t.Id] + if ok { + t.Stat = &st + } + dictLastTask[t.Id] = t + } + } + + // iterate list again + for i, s := range response.Data { + // spider stat + st, ok := dictSpiderStat[s.Id] + if ok { + s.Stat = &st + } + + // last task and stat (Stat stays nil when the spider has no stat record; the dict is keyed by task id) + if s.Stat != nil && !s.Stat.LastTaskId.IsZero() { + t, ok := 
dictLastTask[s.Stat.LastTaskId] + if ok { + s.Stat.LastTask = &t + } + } + + response.Data[i] = s + } + + return response, nil +} + // PostSpider handles creating a new spider func PostSpider(c *gin.Context, params *PostParams[models.Spider]) (response *Response[models.Spider], err error) { s := params.Data diff --git a/core/controllers/task_test.go b/core/controllers/task_test.go index 363f4d14..78b9133c 100644 --- a/core/controllers/task_test.go +++ b/core/controllers/task_test.go @@ -41,7 +41,6 @@ func createTestTask(t *testing.T) (task *models.Task, spiderId primitive.ObjectI Mode: constants.RunTypeAllNodes, Param: "test param", Cmd: "python main.py", - UserId: TestUserId, } // Set timestamps diff --git a/core/controllers/utils.go b/core/controllers/utils.go index 5d87d5c7..ddfa0d39 100644 --- a/core/controllers/utils.go +++ b/core/controllers/utils.go @@ -376,6 +376,16 @@ func GetListResponse[T any](models []T, total int) (res *ListResponse[T], err er }, nil } +func GetEmptyListResponse[T any]() (res *ListResponse[T], err error) { + return &ListResponse[T]{ + Status: constants.HttpResponseStatusOk, + Message: constants.HttpResponseMessageSuccess, + Data: []T{}, + Total: 0, + }, nil + +} + func GetVoidResponse() (res *VoidResponse, err error) { return &VoidResponse{ Status: constants.HttpResponseStatusOk, @@ -422,33 +432,10 @@ func HandleError(statusCode int, c *gin.Context, err error) { handleError(statusCode, c, err) } -func HandleErrorBadRequest(c *gin.Context, err error) { - HandleError(http.StatusBadRequest, c, err) -} - -func HandleErrorForbidden(c *gin.Context, err error) { - HandleError(http.StatusForbidden, c, err) -} - -func HandleErrorUnauthorized(c *gin.Context, err error) { - HandleError(http.StatusUnauthorized, c, err) -} - -func HandleErrorNotFound(c *gin.Context, err error) { - HandleError(http.StatusNotFound, c, err) -} - func HandleErrorInternalServerError(c *gin.Context, err error) { HandleError(http.StatusInternalServerError, c, err) } -func 
HandleSuccess(c *gin.Context) { - c.AbortWithStatusJSON(http.StatusOK, entity.Response{ - Status: constants.HttpResponseStatusOk, - Message: constants.HttpResponseMessageSuccess, - }) -} - func HandleSuccessWithData(c *gin.Context, data interface{}) { c.AbortWithStatusJSON(http.StatusOK, entity.Response{ Status: constants.HttpResponseStatusOk, @@ -456,12 +443,3 @@ func HandleSuccessWithData(c *gin.Context, data interface{}) { Data: data, }) } - -func HandleSuccessWithListData(c *gin.Context, data interface{}, total int) { - c.AbortWithStatusJSON(http.StatusOK, entity.ListResponse{ - Status: constants.HttpResponseStatusOk, - Message: constants.HttpResponseMessageSuccess, - Data: data, - Total: total, - }) -} diff --git a/core/models/models/schedule.go b/core/models/models/schedule.go index ad650220..db0fdbc9 100644 --- a/core/models/models/schedule.go +++ b/core/models/models/schedule.go @@ -19,4 +19,7 @@ type Schedule struct { NodeIds []primitive.ObjectID `json:"node_ids" bson:"node_ids" description:"Node IDs"` Priority int `json:"priority" bson:"priority" description:"Priority"` Enabled bool `json:"enabled" bson:"enabled" description:"Enabled"` + + // associated data + Spider *Spider `json:"spider" bson:"-" description:"Spider"` } diff --git a/core/models/models/spider.go b/core/models/models/spider.go index 50186592..3ebbf7de 100644 --- a/core/models/models/spider.go +++ b/core/models/models/spider.go @@ -12,14 +12,12 @@ type Spider struct { ColName string `json:"col_name,omitempty" bson:"col_name" description:"Data collection name"` DbName string `json:"db_name,omitempty" bson:"db_name" description:"Database name"` DataSourceId primitive.ObjectID `json:"data_source_id" bson:"data_source_id" description:"Data source id"` - DataSource *Database `json:"data_source,omitempty" bson:"-"` Description string `json:"description" bson:"description" description:"Description"` ProjectId primitive.ObjectID `json:"project_id" bson:"project_id" description:"Project ID"` Mode 
string `json:"mode" bson:"mode" description:"Default task mode" enum:"random,all,selected-nodes"` NodeIds []primitive.ObjectID `json:"node_ids" bson:"node_ids" description:"Default node ids, used in selected-nodes mode"` GitId primitive.ObjectID `json:"git_id" bson:"git_id" description:"Related Git ID"` GitRootPath string `json:"git_root_path" bson:"git_root_path" description:"Git root path"` - Git *Git `json:"git,omitempty" bson:"-"` Template string `json:"template,omitempty" bson:"template,omitempty" description:"Spider template"` TemplateParams *SpiderTemplateParams `json:"template_params,omitempty" bson:"template_params,omitempty" description:"Spider template params"` @@ -30,6 +28,11 @@ type Spider struct { Cmd string `json:"cmd" bson:"cmd" description:"Execute command"` Param string `json:"param" bson:"param" description:"Default task param"` Priority int `json:"priority" bson:"priority" description:"Priority" default:"5" minimum:"1" maximum:"10"` + + // associated data + Project *Project `json:"project,omitempty" bson:"-"` + Git *Git `json:"git,omitempty" bson:"-"` + DataSource *Database `json:"data_source,omitempty" bson:"-"` } type SpiderTemplateParams struct { diff --git a/core/models/models/task.go b/core/models/models/task.go index 1a9704e0..17e2e422 100644 --- a/core/models/models/task.go +++ b/core/models/models/task.go @@ -18,9 +18,10 @@ type Task struct { Mode string `json:"mode" bson:"mode" description:"Mode"` Priority int `json:"priority" bson:"priority" description:"Priority"` NodeIds []primitive.ObjectID `json:"node_ids,omitempty" bson:"-"` - Stat *TaskStat `json:"stat,omitempty" bson:"-"` - Spider *Spider `json:"spider,omitempty" bson:"-"` - Schedule *Schedule `json:"schedule,omitempty" bson:"-"` - Node *Node `json:"node,omitempty" bson:"-"` - UserId primitive.ObjectID `json:"-" bson:"-"` + + // associated data + Stat *TaskStat `json:"stat,omitempty" bson:"-"` + Spider *Spider `json:"spider,omitempty" bson:"-"` + Schedule *Schedule 
`json:"schedule,omitempty" bson:"-"` + Node *Node `json:"node,omitempty" bson:"-"` }