From 98843fdbaccfdcad2e771f887dd26cba7e763bd5 Mon Sep 17 00:00:00 2001 From: Marvin Zhang Date: Mon, 17 Mar 2025 17:14:08 +0800 Subject: [PATCH] refactor: standardize sorting logic across controllers - Introduced a consistent approach for sorting in multiple controller methods by utilizing GetSortOptionFromString for improved clarity and maintainability. - Updated the GetListParams structure to ensure sorting is handled uniformly across different entities. - Enhanced error handling for sorting parameters to provide clearer feedback on invalid formats. - Improved comments for better understanding of sorting logic and its integration within the controller methods. --- core/controllers/base.go | 58 ++++++++++++++++++++++++++++++++----- core/controllers/project.go | 7 ++++- core/controllers/spider.go | 24 ++++++++------- core/controllers/task.go | 7 ++++- core/controllers/token.go | 8 ++++- core/controllers/user.go | 8 ++++- core/controllers/utils.go | 21 ++++++++++++++ 7 files changed, 111 insertions(+), 22 deletions(-) diff --git a/core/controllers/base.go b/core/controllers/base.go index 52cc3bcf..ade6602c 100644 --- a/core/controllers/base.go +++ b/core/controllers/base.go @@ -1,6 +1,8 @@ package controllers import ( + "github.com/loopfz/gadgeto/tonic" + "net/http" "time" "github.com/crawlab-team/crawlab/core/interfaces" @@ -13,6 +15,33 @@ import ( mongo2 "go.mongodb.org/mongo-driver/mongo" ) +func init() { + tonic.SetErrorHook(func(context *gin.Context, err error) (int, interface{}) { + response := gin.H{ + "error": errors.Unwrap(err).Error(), + } + status := http.StatusInternalServerError + constErr, ok := errors.AsType[errors.ConstError](err) + if ok { + switch { + case errors.Is(constErr, errors.NotFound): + status = http.StatusNotFound + case errors.Is(constErr, errors.BadRequest): + status = http.StatusBadRequest + case errors.Is(constErr, errors.Unauthorized): + status = http.StatusUnauthorized + case errors.Is(constErr, errors.Forbidden): + status = http.StatusForbidden + default: + status = http.StatusInternalServerError + } + } else { + status = http.StatusInternalServerError + } + return status, response + }) +} + type Action struct { Method string Path string @@ -27,7 +56,7 @@ type BaseController[T any] struct { // GetListParams represents parameters for GetList with pagination type GetListParams struct { Conditions string `query:"conditions" description:"Filter conditions. Format: [{\"key\":\"name\",\"op\":\"eq\",\"value\":\"test\"}]"` - Sort bson.D `query:"sort" description:"Sort options"` + Sort string `query:"sort" description:"Sort options"` Page int `query:"page" default:"1" description:"Page number"` Size int `query:"size" default:"10" description:"Page size"` All bool `query:"all" default:"false" description:"Whether to get all items"` @@ -201,15 +230,19 @@ func (ctr *BaseController[T]) DeleteList(_ *gin.Context, params *DeleteListParam // GetAll retrieves all items based on filter and sort func (ctr *BaseController[T]) GetAll(params *GetListParams) (response *ListResponse[T], err error) { + // Get filter query query, err := GetFilterQueryFromListParams(params) if err != nil { return GetErrorListResponse[T](errors.BadRequestf("invalid request parameters: %v", err)) } - sort := params.Sort - if sort == nil { - sort = bson.D{{"_id", -1}} + + // Get sort options + sort, err := GetSortOptionFromString(params.Sort) + if err != nil { + return GetErrorListResponse[T](errors.BadRequestf("invalid sort format: %v", err)) } + // Get models models, err := ctr.modelSvc.GetMany(query, &mongo.FindOptions{ Sort: sort, }) @@ -217,22 +250,33 @@ func (ctr *BaseController[T]) GetAll(params *GetListParams) (response *ListRespo return nil, err } + // Total count total, err := ctr.modelSvc.Count(query) if err != nil { return nil, err } + // Response return GetListResponse(models, total) } // GetWithPagination retrieves items with pagination func (ctr *BaseController[T]) GetWithPagination(params *GetListParams) (response *ListResponse[T], err error) { + // Get filter query query, err := GetFilterQueryFromListParams(params) if err != nil { return GetErrorListResponse[T](errors.BadRequestf("invalid request parameters: %v", err)) } + + // Get sort options + sort, err := GetSortOptionFromString(params.Sort) + if err != nil { + return GetErrorListResponse[T](errors.BadRequestf("invalid sort format: %v", err)) + } + + // Get models models, err := ctr.modelSvc.GetMany(query, &mongo.FindOptions{ - Sort: params.Sort, + Sort: sort, Skip: params.Size * (params.Page - 1), Limit: params.Size, }) @@ -244,13 +288,13 @@ func (ctr *BaseController[T]) GetWithPagination(params *GetListParams) (response } } - // total count + // Total count total, err := ctr.modelSvc.Count(query) if err != nil { return nil, err } - // response + // Response return GetListResponse(models, total) } diff --git a/core/controllers/project.go b/core/controllers/project.go index 8ec44378..ef5ebfc4 100644 --- a/core/controllers/project.go +++ b/core/controllers/project.go @@ -21,9 +21,14 @@ func GetProjectList(c *gin.Context, params *GetListParams) (response *ListRespon return GetErrorListResponse[models.Project](errors.BadRequestf("invalid request parameters: %v", err)) } + sort, err := GetSortOptionFromString(params.Sort) + if err != nil { + return GetErrorListResponse[models.Project](errors.BadRequestf("invalid request parameters: %v", err)) + } + // get list projects, err := service.NewModelService[models.Project]().GetMany(query, &mongo.FindOptions{ - Sort: params.Sort, + Sort: sort, Skip: params.Size * (params.Page - 1), Limit: params.Size, }) diff --git a/core/controllers/spider.go b/core/controllers/spider.go index 722efeb3..a47e5ca0 100644 --- a/core/controllers/spider.go +++ b/core/controllers/spider.go @@ -93,9 +93,14 @@ func getSpiderListWithStats(params *GetListParams) (response *ListResponse[model if err != nil { return GetErrorListResponse[models.Spider](errors.BadRequestf("invalid request parameters: %v", err)) } - // get list + + sort, err := GetSortOptionFromString(params.Sort) + if err != nil { + return GetErrorListResponse[models.Spider](errors.BadRequestf("invalid request parameters: %v", err)) + } + spiders, err := service.NewModelService[models.Spider]().GetMany(query, &mongo2.FindOptions{ - Sort: params.Sort, + Sort: sort, Skip: params.Size * (params.Page - 1), Limit: params.Size, }) @@ -662,7 +667,7 @@ type PostSpiderExportParams struct { Id string `path:"id"` } -func PostSpiderExport(c *gin.Context, params *PostSpiderExportParams) (err error) { +func PostSpiderExport(c *gin.Context, _ *PostSpiderExportParams) (err error) { rootPath, err := getSpiderRootPathByContext(c) if err != nil { return err @@ -712,9 +717,6 @@ func PostSpiderRun(c *gin.Context, params *PostSpiderRunParams) (response *Respo ScheduleId: scheduleId, Priority: params.Priority, } - if err := c.ShouldBindJSON(&opts); err != nil { - return GetErrorResponse[[]primitive.ObjectID](err) - } // user if u := GetUserFromContext(c); u != nil { @@ -731,7 +733,9 @@ func PostSpiderRun(c *gin.Context, params *PostSpiderRunParams) (response *Respo } type GetSpiderResultsParams struct { - Id string `path:"id"` + Id string `path:"id"` + Page int `query:"page"` + Size int `query:"size"` } func GetSpiderResults(c *gin.Context, params *GetSpiderResultsParams) (response *ListResponse[bson.M], err error) { @@ -745,8 +749,6 @@ func GetSpiderResults(c *gin.Context, params *GetSpiderResultsParams) (response return GetErrorListResponse[bson.M](err) } - // params - pagination := MustGetPagination(c) query := getResultListQuery(c) col := mongo2.GetMongoCol(s.ColName) @@ -754,8 +756,8 @@ func GetSpiderResults(c *gin.Context, params *GetSpiderResultsParams) (response var results []bson.M err = col.Find(mongo2.GetMongoQuery(query), mongo2.GetMongoOpts(&mongo2.ListOptions{ Sort: []mongo2.ListSort{{"_id", mongo2.SortDirectionDesc}}, - Skip: pagination.Size * (pagination.Page - 1), - Limit: pagination.Size, + Skip: params.Size * (params.Page - 1), + Limit: params.Size, })).All(&results) if err != nil { return GetErrorListResponse[bson.M](err) diff --git a/core/controllers/task.go b/core/controllers/task.go index 25fb80e8..d772b457 100644 --- a/core/controllers/task.go +++ b/core/controllers/task.go @@ -84,9 +84,14 @@ func GetTaskList(c *gin.Context, params *GetTaskListParams) (response *ListRespo return GetErrorListResponse[models.Task](err) } + sort, err := GetSortOptionFromString(params.GetListParams.Sort) + if err != nil { + return GetErrorListResponse[models.Task](err) + } + // get tasks tasks, err := service.NewModelService[models.Task]().GetMany(query, &mongo3.FindOptions{ - Sort: params.Sort, + Sort: sort, Skip: params.Size * (params.Page - 1), Limit: params.Size, }) diff --git a/core/controllers/token.go b/core/controllers/token.go index 1fab41c2..4c052e02 100644 --- a/core/controllers/token.go +++ b/core/controllers/token.go @@ -51,9 +51,15 @@ func GetTokenList(c *gin.Context, params *GetListParams) (response *ListResponse // Add filter for tokens created by the current user query["created_by"] = u.Id + // Get sort options + sort, err := GetSortOptionFromString(params.Sort) + if err != nil { + return GetErrorListResponse[models.Token](errors.BadRequestf("invalid request parameters: %v", err)) + } + // Get tokens with pagination tokens, err := service.NewModelService[models.Token]().GetMany(query, &mongo.FindOptions{ - Sort: params.Sort, + Sort: sort, Skip: params.Size * (params.Page - 1), Limit: params.Size, }) diff --git a/core/controllers/user.go b/core/controllers/user.go index 435b534f..9dd55664 100644 --- a/core/controllers/user.go +++ b/core/controllers/user.go @@ -28,8 +28,14 @@ func GetUserList(_ *gin.Context, params *GetListParams) (response *ListResponse[ if err != nil { return GetErrorListResponse[models.User](err) } + + sort, err := GetSortOptionFromString(params.Sort) + if err != nil { + return GetErrorListResponse[models.User](err) + } + users, err := service.NewModelService[models.User]().GetMany(query, &mongo.FindOptions{ - Sort: params.Sort, + Sort: sort, Skip: params.Size * (params.Page - 1), Limit: params.Size, }) diff --git a/core/controllers/utils.go b/core/controllers/utils.go index 2cae437d..0313bc44 100644 --- a/core/controllers/utils.go +++ b/core/controllers/utils.go @@ -205,6 +205,27 @@ func MustGetPagination(c *gin.Context) (p *entity.Pagination) { return p } +func GetSortsFromString(sortStr string) (sorts []entity.Sort, err error) { + if sortStr == "" { + return nil, nil + } + if err := json.Unmarshal([]byte(sortStr), &sorts); err != nil { + return nil, err + } + return sorts, nil +} + +func GetSortOptionFromString(sortStr string) (sort bson.D, err error) { + sorts, err := GetSortsFromString(sortStr) + if err != nil { + return nil, err + } + if sorts == nil || len(sorts) == 0 { + return bson.D{{"_id", -1}}, nil + } + return SortsToOption(sorts) +} + // GetSorts Get entity.Sort from gin.Context func GetSorts(c *gin.Context) (sorts []entity.Sort, err error) { // bind