diff --git a/core/constants/filter.go b/core/constants/filter.go index ad22256a..6c3ca189 100644 --- a/core/constants/filter.go +++ b/core/constants/filter.go @@ -1,9 +1,7 @@ package constants const ( - FilterQueryFieldConditions = "conditions" - FilterQueryFieldAll = "all" - FilterQueryFieldFilter = "filter" + FilterQueryFieldFilter = "filter" ) const ( diff --git a/core/controllers/base.go b/core/controllers/base.go index faab778c..d14587f0 100644 --- a/core/controllers/base.go +++ b/core/controllers/base.go @@ -25,9 +25,8 @@ func init() { response := gin.H{ "error": err.Error(), } - status := http.StatusInternalServerError - constErr, ok := errors.AsType[errors.ConstError](err) - if ok { + var status int + if constErr, ok := errors.AsType[errors.ConstError](err); ok { switch { case errors.Is(constErr, errors.NotFound): status = http.StatusNotFound @@ -40,6 +39,8 @@ func init() { default: status = http.StatusInternalServerError } + } else if _, ok := errors.AsType[tonic.BindError](err); ok { + status = http.StatusBadRequest } else { status = http.StatusInternalServerError } @@ -62,11 +63,11 @@ type BaseController[T any] struct { // GetListParams represents parameters for GetList with pagination type GetListParams struct { - Conditions string `query:"conditions" description:"Filter conditions. Format: [{\"key\":\"name\",\"op\":\"eq\",\"value\":\"test\"}]"` - Sort string `query:"sort" description:"Sort options"` - Page int `query:"page" default:"1" description:"Page number" minimum:"1"` - Size int `query:"size" default:"10" description:"Page size" minimum:"1"` - All bool `query:"all" default:"false" description:"Whether to get all items"` + Filter string `query:"filter" description:"Filter query"` + Sort string `query:"sort" description:"Sort options"` + Page int `query:"page" default:"1" description:"Page number" minimum:"1"` + Size int `query:"size" default:"10" description:"Page size" minimum:"1"` + All bool `query:"all" default:"false" description:"Whether to get all items"` } func (ctr *BaseController[T]) GetList(_ *gin.Context, params *GetListParams) (response *ListResponse[T], err error) { @@ -221,7 +222,7 @@ func (ctr *BaseController[T]) DeleteList(_ *gin.Context, params *DeleteListParam for _, id := range params.Ids { objectId, err := primitive.ObjectIDFromHex(id) if err != nil { - return GetErrorResponse[T](err) + return GetErrorResponse[T](errors.BadRequestf("invalid id format: %v", err)) } ids = append(ids, objectId) } @@ -241,10 +242,7 @@ func (ctr *BaseController[T]) DeleteList(_ *gin.Context, params *DeleteListParam // GetAll retrieves all items based on filter and sort func (ctr *BaseController[T]) GetAll(params *GetListParams) (response *ListResponse[T], err error) { // Get filter query - query, err := GetFilterQueryFromListParams(params) - if err != nil { - return GetErrorListResponse[T](errors.BadRequestf("invalid request parameters: %v", err)) - } + query := ConvertToBsonMFromListParams(params) // Get sort options sort, err := GetSortOptionFromString(params.Sort) @@ -273,10 +271,7 @@ func (ctr *BaseController[T]) GetAll(params *GetListParams) (response *ListRespo // GetWithPagination retrieves items with pagination func (ctr *BaseController[T]) GetWithPagination(params *GetListParams) (response *ListResponse[T], err error) { // Get filter query - query, err := GetFilterQueryFromListParams(params) - if err != nil { - return GetErrorListResponse[T](errors.BadRequestf("invalid request parameters: %v", err)) - } + query := ConvertToBsonMFromListParams(params) // Get sort options sort, err := GetSortOptionFromString(params.Sort) diff --git a/core/controllers/base_test.go b/core/controllers/base_test.go index 0217b636..671ed9db 100644 --- a/core/controllers/base_test.go +++ b/core/controllers/base_test.go @@ -269,7 +269,7 @@ func TestBaseController_GetList(t *testing.T) { condBytes, err := json.Marshal(cond) require.Nil(t, err) params := url.Values{} - params.Add("conditions", string(condBytes)) + params.Add("filter", string(condBytes)) params.Add("page", "1") params.Add("size", "10") requestUrl := url.URL{Path: "/testmodels/list", RawQuery: params.Encode()} diff --git a/core/controllers/export.go b/core/controllers/export.go index fe91e531..4d25923b 100644 --- a/core/controllers/export.go +++ b/core/controllers/export.go @@ -11,16 +11,13 @@ import ( ) type PostExportParams struct { - Type string `path:"type" validate:"required"` - Target string `query:"target" validate:"required"` - Conditions string `query:"conditions" description:"Filter conditions. Format: [{\"key\":\"name\",\"op\":\"eq\",\"value\":\"test\"}]"` + Type string `path:"type" validate:"required"` + Target string `query:"target" validate:"required"` + Filter string `query:"filter" description:"Filter query"` } func PostExport(_ *gin.Context, params *PostExportParams) (response *Response[string], err error) { - query, err := GetFilterQueryFromConditionString(params.Conditions) - if err != nil { - return GetErrorResponse[string](err) - } + query := ConvertToBsonMFromFilter(params.Filter) var exportId string switch params.Type { case constants.ExportTypeCsv: diff --git a/core/controllers/filter.go b/core/controllers/filter.go index 0afcbe5f..581af7d3 100644 --- a/core/controllers/filter.go +++ b/core/controllers/filter.go @@ -10,36 +10,36 @@ import ( ) type GetFilterColFieldOptionsParams struct { - Col string `path:"col" validate:"required"` - Conditions string `query:"conditions" description:"Filter conditions. Format: [{\"key\":\"name\",\"op\":\"eq\",\"value\":\"test\"}]"` + Col string `path:"col" validate:"required"` + Filter string `query:"filter" description:"Filter query"` } func GetFilterColFieldOptions(c *gin.Context, params *GetFilterColFieldOptionsParams) (response *Response[[]entity.FilterSelectOption], err error) { return GetFilterColFieldOptionsWithValueLabel(c, &GetFilterColFieldOptionsWithValueLabelParams{ - Col: params.Col, - Conditions: params.Conditions, + Col: params.Col, + Filter: params.Filter, }) } type GetFilterColFieldOptionsWithValueParams struct { - Col string `path:"col" validate:"required"` - Value string `path:"value"` - Conditions string `query:"conditions" description:"Filter conditions. Format: [{\"key\":\"name\",\"op\":\"eq\",\"value\":\"test\"}]"` + Col string `path:"col" validate:"required"` + Value string `path:"value"` + Filter string `query:"filter" description:"Filter query"` } func GetFilterColFieldOptionsWithValue(c *gin.Context, params *GetFilterColFieldOptionsWithValueParams) (response *Response[[]entity.FilterSelectOption], err error) { return GetFilterColFieldOptionsWithValueLabel(c, &GetFilterColFieldOptionsWithValueLabelParams{ - Col: params.Col, - Value: params.Value, - Conditions: params.Conditions, + Col: params.Col, + Value: params.Value, + Filter: params.Filter, }) } type GetFilterColFieldOptionsWithValueLabelParams struct { - Col string `path:"col" validate:"required"` - Value string `path:"value"` - Label string `path:"label"` - Conditions string `query:"conditions" description:"Filter conditions. Format: [{\"key\":\"name\",\"op\":\"eq\",\"value\":\"test\"}]"` + Col string `path:"col" validate:"required"` + Value string `path:"value"` + Label string `path:"label"` + Filter string `query:"filter" description:"Filter query"` } func GetFilterColFieldOptionsWithValueLabel(_ *gin.Context, params *GetFilterColFieldOptionsWithValueLabelParams) (response *Response[[]entity.FilterSelectOption], err error) { @@ -53,11 +53,8 @@ func GetFilterColFieldOptionsWithValueLabel(_ *gin.Context, params *GetFilterCol } pipelines := mongo2.Pipeline{} - if params.Conditions != "" { - query, err := GetFilterFromConditionString(params.Conditions) - if err != nil { - return GetErrorResponse[[]entity.FilterSelectOption](errors.Trace(err)) - } + if params.Filter != "" { + query := ConvertToFilter(params.Filter) pipelines = append(pipelines, bson.D{{"$match", query}}) } pipelines = append( diff --git a/core/controllers/project.go b/core/controllers/project.go index ef5ebfc4..8f81b65f 100644 --- a/core/controllers/project.go +++ b/core/controllers/project.go @@ -16,10 +16,7 @@ func GetProjectList(c *gin.Context, params *GetListParams) (response *ListRespon return NewController[models.Project]().GetAll(params) } - query, err := GetFilterQueryFromListParams(params) - if err != nil { - return GetErrorListResponse[models.Project](errors.BadRequestf("invalid request parameters: %v", err)) - } + query := ConvertToBsonMFromListParams(params) sort, err := GetSortOptionFromString(params.Sort) if err != nil { diff --git a/core/controllers/router.go b/core/controllers/router.go index c0f9aa0b..5989c3c5 100644 --- a/core/controllers/router.go +++ b/core/controllers/router.go @@ -243,8 +243,6 @@ func InitRoutes(app *gin.Engine) (err error) { // Register resource controllers with their respective endpoints // Each RegisterController call sets up standard CRUD operations // Additional custom actions can be specified in the controller initialization - RegisterController(groups.AuthGroup.Group("", "Data Collections", "APIs for data collections management"), "/data/collections", NewController[models.DataCollection]()) - RegisterController(groups.AuthGroup.Group("", "Environments", "APIs for environment variables management"), "/environments", NewController[models.Environment]()) RegisterController(groups.AuthGroup.Group("", "Nodes", "APIs for nodes management"), "/nodes", NewController[models.Node]()) RegisterController(groups.AuthGroup.Group("", "Projects", "APIs for projects management"), "/projects", NewController[models.Project]([]Action{ { @@ -478,22 +476,6 @@ func InitRoutes(app *gin.Engine) (err error) { HandlerFunc: GetTaskLogs, }, }...)) - RegisterController(groups.AuthGroup.Group("", "Tokens", "APIs for PAT management"), "/tokens", NewController[models.Token]([]Action{ - { - Method: http.MethodPost, - Path: "", - Name: "Create Token", - Description: "Create a new token", - HandlerFunc: PostToken, - }, - { - Method: http.MethodGet, - Path: "", - Name: "Get Token List", - Description: "Get a list of tokens", - HandlerFunc: GetTokenList, - }, - }...)) RegisterController(groups.AuthGroup.Group("", "Users", "APIs for users management"), "/users", NewController[models.User]([]Action{ { Method: http.MethodGet, @@ -566,6 +548,24 @@ func InitRoutes(app *gin.Engine) (err error) { HandlerFunc: PostUserMeChangePassword, }, }...)) + RegisterController(groups.AuthGroup.Group("", "Tokens", "APIs for PAT management"), "/tokens", NewController[models.Token]([]Action{ + { + Method: http.MethodPost, + Path: "", + Name: "Create Token", + Description: "Create a new token", + HandlerFunc: PostToken, + }, + { + Method: http.MethodGet, + Path: "", + Name: "Get Token List", + Description: "Get a list of tokens", + HandlerFunc: GetTokenList, + }, + }...)) + RegisterController(groups.AuthGroup.Group("", "Environments", "APIs for environment variables management"), "/environments", NewController[models.Environment]()) + RegisterController(groups.AuthGroup.Group("", "Data Collections", "APIs for data collections management"), "/data/collections", NewController[models.DataCollection]()) // Register standalone action routes that don't fit the standard CRUD pattern RegisterActions(groups.AuthGroup.Group("", "Export", "APIs for exporting data"), "/export", []Action{ diff --git a/core/controllers/router_test.go b/core/controllers/router_test.go index d15cc2c1..93db1e71 100644 --- a/core/controllers/router_test.go +++ b/core/controllers/router_test.go @@ -35,7 +35,7 @@ func TestRegisterController_Routes(t *testing.T) { controllers.RegisterController(groups.AuthGroup, basePath, ctr) // Check if all routes are registered - routes := router.Routes() + routes := controllers.GetGlobalFizzWrapper().GetFizz().Engine().Routes() var methodPaths []string for _, route := range routes { @@ -64,7 +64,7 @@ func TestInitRoutes_ProjectsRoute(t *testing.T) { _ = controllers.InitRoutes(router) // Check if the projects route is registered - routes := router.Routes() + routes := controllers.GetGlobalFizzWrapper().GetFizz().Engine().Routes() var methodPaths []string for _, route := range routes { diff --git a/core/controllers/spider.go b/core/controllers/spider.go index b41e7eb5..09ed677e 100644 --- a/core/controllers/spider.go +++ b/core/controllers/spider.go @@ -89,10 +89,7 @@ func GetSpiderList(c *gin.Context, params *GetListParams) (response *ListRespons } func getSpiderListWithStats(params *GetListParams) (response *ListResponse[models.Spider], err error) { - query, err := GetFilterQueryFromListParams(params) - if err != nil { - return GetErrorListResponse[models.Spider](errors.BadRequestf("invalid request parameters: %v", err)) - } + query := ConvertToBsonMFromListParams(params) sort, err := GetSortOptionFromString(params.Sort) if err != nil { @@ -749,7 +746,7 @@ func GetSpiderResults(c *gin.Context, params *GetSpiderResultsParams) (response return GetErrorListResponse[bson.M](err) } - query := getResultListQuery(c) + query := GetResultListQuery(c) col := mongo2.GetMongoCol(s.ColName) diff --git a/core/controllers/task.go b/core/controllers/task.go index ea9cfe80..e3618b92 100644 --- a/core/controllers/task.go +++ b/core/controllers/task.go @@ -79,10 +79,7 @@ func GetTaskList(c *gin.Context, params *GetTaskListParams) (response *ListRespo } // get query - query, err := GetFilterQueryFromListParams(params.GetListParams) - if err != nil { - return GetErrorListResponse[models.Task](err) - } + query := ConvertToBsonMFromListParams(params.GetListParams) sort, err := GetSortOptionFromString(params.GetListParams.Sort) if err != nil { diff --git a/core/controllers/token.go b/core/controllers/token.go index 44a05913..e231ca87 100644 --- a/core/controllers/token.go +++ b/core/controllers/token.go @@ -43,10 +43,7 @@ func GetTokenList(c *gin.Context, params *GetListParams) (response *ListResponse u := GetUserFromContext(c) // Get filter query - query, err := GetFilterQueryFromListParams(params) - if err != nil { - return GetErrorListResponse[models.Token](errors.BadRequestf("invalid request parameters: %v", err)) - } + query := ConvertToBsonMFromListParams(params) // Add filter for tokens created by the current user query["created_by"] = u.Id diff --git a/core/controllers/user.go b/core/controllers/user.go index 4c40ad2d..b7bcff38 100644 --- a/core/controllers/user.go +++ b/core/controllers/user.go @@ -24,10 +24,7 @@ func GetUserById(_ *gin.Context, params *GetByIdParams) (response *Response[mode } func GetUserList(_ *gin.Context, params *GetListParams) (response *ListResponse[models.User], err error) { - query, err := GetFilterQueryFromListParams(params) - if err != nil { - return GetErrorListResponse[models.User](err) - } + query := ConvertToBsonMFromListParams(params) sort, err := GetSortOptionFromString(params.Sort) if err != nil { @@ -171,7 +168,7 @@ func DeleteUserById(_ *gin.Context, params *DeleteByIdParams) (response *Respons return GetErrorResponse[models.User](err) } if user.RootAdmin { - return GetErrorResponse[models.User](errors.New("root admin cannot be deleted")) + return GetErrorResponse[models.User](errors.Forbiddenf("root admin cannot be deleted")) } if err := service.NewModelService[models.User]().DeleteById(id); err != nil { @@ -200,7 +197,7 @@ func DeleteUserList(_ *gin.Context, params *DeleteListParams) (response *Respons "root_admin": true, }, nil) if err == nil { - return GetErrorResponse[models.User](errors.New("root admin cannot be deleted")) + return GetErrorResponse[models.User](errors.Forbiddenf("root admin cannot be deleted")) } if !errors.Is(err, mongo2.ErrNoDocuments) { return GetErrorResponse[models.User](err) diff --git a/core/controllers/user_test.go b/core/controllers/user_test.go index 45b4e2fb..55377667 100644 --- a/core/controllers/user_test.go +++ b/core/controllers/user_test.go @@ -150,7 +150,7 @@ func TestPostUser_Success(t *testing.T) { w = httptest.NewRecorder() router.ServeHTTP(w, req) - assert.Equal(t, http.StatusBadRequest, w.Code) + assert.Equalf(t, http.StatusBadRequest, w.Code, w.Body.String()) } func TestPutUserById_Success(t *testing.T) { @@ -428,7 +428,7 @@ func TestDeleteUserById_Success(t *testing.T) { w = httptest.NewRecorder() router.ServeHTTP(w, req) - assert.Equalf(t, http.StatusBadRequest, w.Code, "response body: %s", w.Body.String()) + assert.Equalf(t, http.StatusForbidden, w.Code, "response body: %s", w.Body.String()) // Test deleting with invalid ID req, err = http.NewRequest(http.MethodDelete, "/users/invalid-id", nil) @@ -497,7 +497,7 @@ func TestDeleteUserList_Success(t *testing.T) { w = httptest.NewRecorder() router.ServeHTTP(w, req) - assert.Equalf(t, http.StatusBadRequest, w.Code, "response body: %s", w.Body.String()) + assert.Equalf(t, http.StatusForbidden, w.Code, "response body: %s", w.Body.String()) // Test with mix of valid and invalid ids reqBody = strings.NewReader(fmt.Sprintf(`{"ids":["%s","invalid-id"]}`, normalUserIds[0].Hex())) diff --git a/core/controllers/utils.go b/core/controllers/utils.go index 0313bc44..8ae3951d 100644 --- a/core/controllers/utils.go +++ b/core/controllers/utils.go @@ -4,6 +4,7 @@ import ( "encoding/json" "net/http" "reflect" + "strings" "github.com/crawlab-team/crawlab/core/constants" "github.com/crawlab-team/crawlab/core/entity" @@ -30,142 +31,152 @@ func GetUserFromContext(c *gin.Context) (u *models.User) { return u } -func GetFilterQueryFromListParams(params *GetListParams) (q bson.M, err error) { - if params.Conditions == "" { - return nil, nil +func ConvertToBsonMFromListParams(params *GetListParams) (q bson.M) { + if params.Filter == "" { + return nil } - conditions, err := GetFilterFromConditionString(params.Conditions) - if err != nil { - return nil, err - } - return utils.FilterToQuery(conditions) + filter := ConvertToFilter(params.Filter) + return utils.FilterToQuery(filter) } -func GetFilterQueryFromConditionString(condStr string) (q bson.M, err error) { - if condStr == "" { - return nil, nil +func ConvertToBsonMFromFilter(filterStr string) (q bson.M) { + if filterStr == "" { + return nil } - conditions, err := GetFilterFromConditionString(condStr) - if err != nil { - return nil, err - } - return utils.FilterToQuery(conditions) + filter := ConvertToFilter(filterStr) + return utils.FilterToQuery(filter) } -func GetFilterFromConditionString(condStr string) (f *entity.Filter, err error) { - if condStr == "" { - return nil, nil +func ConvertToFilter(filterStr string) (f *entity.Filter) { + if filterStr == "" { + return nil } + trimmedFilterStr := strings.TrimSpace(filterStr) + if strings.HasPrefix(trimmedFilterStr, "[") { + return convertToFilterArray(filterStr) + } else if strings.HasPrefix(trimmedFilterStr, "{") { + return convertToFilterMap(filterStr) + } else { + return nil + } +} + +func convertToFilterMap(filterStr string) (f *entity.Filter) { + var filter map[string]interface{} + if err := json.Unmarshal([]byte(filterStr), &filter); err != nil { + return nil + } + var conditions []*entity.Condition - if err := json.Unmarshal([]byte(condStr), &conditions); err != nil { - return nil, err - } - - for i, cond := range conditions { - v := reflect.ValueOf(cond.Value) - switch v.Kind() { - case reflect.String: - item := cond.Value.(string) - // attempt to convert object id - id, err := primitive.ObjectIDFromHex(item) - if err == nil { - conditions[i].Value = id - } else { - conditions[i].Value = item - } - case reflect.Float64: - // JSON numbers are decoded as float64 by default - switch cond.Value.(type) { - case float64: - num := cond.Value.(float64) - // Check if it's a whole number - if num == float64(int64(num)) { - conditions[i].Value = int64(num) - } else { - conditions[i].Value = num - } - case int: - num := cond.Value.(int) - conditions[i].Value = int64(num) - case int64: - num := cond.Value.(int64) - conditions[i].Value = num - } - case reflect.Bool: - conditions[i].Value = cond.Value.(bool) - case reflect.Slice, reflect.Array: - var items []interface{} - for i := 0; i < v.Len(); i++ { - vItem := v.Index(i) - item := vItem.Interface() - - switch typedItem := item.(type) { - case string: - // Try to convert to ObjectID first - if id, err := primitive.ObjectIDFromHex(typedItem); err == nil { - items = append(items, id) - } else { - items = append(items, typedItem) - } - case float64: - if typedItem == float64(int64(typedItem)) { - items = append(items, int64(typedItem)) - } else { - items = append(items, typedItem) - } - case bool: - items = append(items, typedItem) - default: - items = append(items, item) - } - } - conditions[i].Value = items - default: - conditions[i].Value = cond.Value - } + for k, v := range filter { + conditions = append(conditions, &entity.Condition{ + Key: k, + Op: constants.FilterOpEqual, + Value: convertFilterValue(v), + }) } return &entity.Filter{ IsOr: false, Conditions: conditions, - }, nil + } } -// GetFilter Get entity.Filter from gin.Context -func GetFilter(c *gin.Context) (f *entity.Filter, err error) { - condStr := c.Query(constants.FilterQueryFieldConditions) - return GetFilterFromConditionString(condStr) +func convertToFilterArray(filterStr string) (f *entity.Filter) { + var conditions []*entity.Condition + if err := json.Unmarshal([]byte(filterStr), &conditions); err != nil { + return nil + } + + for i, cond := range conditions { + conditions[i].Value = convertFilterValue(cond.Value) + } + + return &entity.Filter{ + IsOr: false, + Conditions: conditions, + } } -// GetFilterQuery Get bson.M from gin.Context -func GetFilterQuery(c *gin.Context) (q bson.M, err error) { - f, err := GetFilter(c) - if err != nil { - return nil, err +func convertFilterValue(value interface{}) interface{} { + v := reflect.ValueOf(value) + switch v.Kind() { + case reflect.String: + item := value.(string) + // attempt to convert object id + id, err := primitive.ObjectIDFromHex(item) + if err == nil { + return id + } else { + return item + } + case reflect.Float64: + // JSON numbers are decoded as float64 by default + switch value.(type) { + case float64: + num := value.(float64) + // Check if it's a whole number + if num == float64(int64(num)) { + return int64(num) + } else { + return num + } + case int: + num := value.(int) + return int64(num) + case int64: + num := value.(int64) + return num + } + case reflect.Bool: + return value.(bool) + case reflect.Slice, reflect.Array: + var items []interface{} + for i := 0; i < v.Len(); i++ { + vItem := v.Index(i) + item := vItem.Interface() + + switch typedItem := item.(type) { + case string: + // Try to convert to ObjectID first + if id, err := primitive.ObjectIDFromHex(typedItem); err == nil { + items = append(items, id) + } else { + items = append(items, typedItem) + } + case float64: + if typedItem == float64(int64(typedItem)) { + items = append(items, int64(typedItem)) + } else { + items = append(items, typedItem) + } + case bool: + items = append(items, typedItem) + default: + items = append(items, item) + } + } + return items + default: + return value } + return value +} - if f == nil { - return nil, nil - } - - // TODO: implement logic OR +// GetFilterFromContext Get entity.Filter from gin.Context +func GetFilterFromContext(c *gin.Context) (f *entity.Filter) { + filterStr := c.GetString(constants.FilterQueryFieldFilter) + return ConvertToFilter(filterStr) +} +// ConvertToBsonMFromContext Get bson.M from gin.Context +func ConvertToBsonMFromContext(c *gin.Context) (q bson.M) { + f := GetFilterFromContext(c) return utils.FilterToQuery(f) } -func MustGetFilterQuery(c *gin.Context) (q bson.M) { - q, err := GetFilterQuery(c) - if err != nil { - return nil - } - return q -} - -func getResultListQuery(c *gin.Context) (q mongo.ListQuery) { - f, err := GetFilter(c) - if err != nil { - return q - } +func GetResultListQuery(c *gin.Context) (q mongo.ListQuery) { + f := GetFilterFromContext(c) for _, cond := range f.Conditions { q = append(q, mongo.ListQueryCondition{ Key: cond.Key, @@ -176,35 +187,6 @@ func getResultListQuery(c *gin.Context) (q mongo.ListQuery) { return q } -func GetDefaultPagination() (p *entity.Pagination) { - return &entity.Pagination{ - Page: constants.PaginationDefaultPage, - Size: constants.PaginationDefaultSize, - } -} - -func GetPagination(c *gin.Context) (p *entity.Pagination, err error) { - var _p entity.Pagination - if err := c.ShouldBindQuery(&_p); err != nil { - return GetDefaultPagination(), err - } - if _p.Page == 0 { - _p.Page = constants.PaginationDefaultPage - } - if _p.Size == 0 { - _p.Size = constants.PaginationDefaultSize - } - return &_p, nil -} - -func MustGetPagination(c *gin.Context) (p *entity.Pagination) { - p, err := GetPagination(c) - if err != nil || p == nil { - return GetDefaultPagination() - } - return p -} - func GetSortsFromString(sortStr string) (sorts []entity.Sort, err error) { if sortStr == "" { return nil, nil diff --git a/core/controllers/utils_test.go b/core/controllers/utils_test.go index 29e872a3..ff336c41 100644 --- a/core/controllers/utils_test.go +++ b/core/controllers/utils_test.go @@ -1,9 +1,11 @@ package controllers_test import ( + "sort" "testing" "github.com/crawlab-team/crawlab/core/controllers" + "github.com/crawlab-team/crawlab/core/entity" "github.com/crawlab-team/crawlab/core/models/models" "github.com/gin-gonic/gin" "github.com/stretchr/testify/assert" @@ -15,8 +17,7 @@ import ( func TestGetFilterFromConditionString(t *testing.T) { // Simple condition with string value condStr := `[{"key":"name","op":"eq","value":"test"}]` - filter, err := controllers.GetFilterFromConditionString(condStr) - require.NoError(t, err) + filter := controllers.ConvertToFilter(condStr) require.NotNil(t, filter) require.Len(t, filter.Conditions, 1) assert.Equal(t, "name", filter.Conditions[0].Key) @@ -25,8 +26,7 @@ func TestGetFilterFromConditionString(t *testing.T) { // Multiple conditions with different types condStr = `[{"key":"name","op":"eq","value":"test"},{"key":"priority","op":"gt","value":5}]` - filter, err = controllers.GetFilterFromConditionString(condStr) - require.NoError(t, err) + filter = controllers.ConvertToFilter(condStr) require.NotNil(t, filter) require.Len(t, filter.Conditions, 2) assert.Equal(t, "name", filter.Conditions[0].Key) @@ -38,60 +38,54 @@ func TestGetFilterFromConditionString(t *testing.T) { // Invalid JSON should return error condStr = `[{"key":"name","op":"eq","value":"test"` - _, err = controllers.GetFilterFromConditionString(condStr) - assert.Error(t, err) + filter = controllers.ConvertToFilter(condStr) + assert.Nil(t, filter) } func TestGetFilterQueryFromConditionString(t *testing.T) { // Simple equality condition condStr := `[{"key":"name","op":"eq","value":"test"}]` - query, err := controllers.GetFilterQueryFromConditionString(condStr) - require.NoError(t, err) + query := controllers.ConvertToBsonMFromFilter(condStr) require.NotNil(t, query) expected := bson.M{"name": "test"} assert.Equal(t, expected, query) // Greater than condition condStr = `[{"key":"priority","op":"gt","value":5}]` - query, err = controllers.GetFilterQueryFromConditionString(condStr) - require.NoError(t, err) + query = controllers.ConvertToBsonMFromFilter(condStr) require.NotNil(t, query) expected = bson.M{"priority": bson.M{"$gt": int64(5)}} assert.Equal(t, expected, query) // Multiple conditions condStr = `[{"key":"name","op":"eq","value":"test"},{"key":"priority","op":"gt","value":5}]` - query, err = controllers.GetFilterQueryFromConditionString(condStr) - require.NoError(t, err) + query = controllers.ConvertToBsonMFromFilter(condStr) require.NotNil(t, query) expected = bson.M{"name": "test", "priority": bson.M{"$gt": int64(5)}} assert.Equal(t, expected, query) // Contains operator condStr = `[{"key":"name","op":"c","value":"test"}]` - query, err = controllers.GetFilterQueryFromConditionString(condStr) - require.NoError(t, err) + query = controllers.ConvertToBsonMFromFilter(condStr) require.NotNil(t, query) expectedRegex := bson.M{"name": bson.M{"$regex": "test", "$options": "i"}} assert.Equal(t, expectedRegex, query) // Invalid condition should return error condStr = `[{"key":"name","op":"invalid_op","value":"test"}]` - _, err = controllers.GetFilterQueryFromConditionString(condStr) - assert.Error(t, err) + query = controllers.ConvertToBsonMFromFilter(condStr) + assert.Nil(t, query) } func TestGetFilterQueryFromListParams(t *testing.T) { // No conditions params := &controllers.GetListParams{} - query, err := controllers.GetFilterQueryFromListParams(params) - require.NoError(t, err) + query := controllers.ConvertToBsonMFromListParams(params) assert.Nil(t, query) // With conditions - params.Conditions = `[{"key":"name","op":"eq","value":"test"}]` - query, err = controllers.GetFilterQueryFromListParams(params) - require.NoError(t, err) + params.Filter = `[{"key":"name","op":"eq","value":"test"}]` + query = controllers.ConvertToBsonMFromListParams(params) require.NotNil(t, query) expected := bson.M{"name": "test"} assert.Equal(t, expected, query) @@ -176,3 +170,52 @@ func TestGetErrorListResponse(t *testing.T) { assert.Nil(t, resp.Data) assert.Equal(t, 0, resp.Total) } + +func TestConvertToFilterMap(t *testing.T) { + // Simple map with string value + condStr := `{"name": "test"}` + filter := controllers.ConvertToFilter(condStr) + require.NotNil(t, filter) + require.Len(t, filter.Conditions, 1) + assert.Equal(t, "name", filter.Conditions[0].Key) + assert.Equal(t, "eq", filter.Conditions[0].Op) + assert.Equal(t, "test", filter.Conditions[0].Value) + + // Map with multiple fields of different types + condStr = `{"name": "test", "priority": 5, "active": true}` + filter = controllers.ConvertToFilter(condStr) + require.NotNil(t, filter) + require.Len(t, filter.Conditions, 3) + // Sort conditions to ensure consistent test results + sortConditions(filter.Conditions) + assert.Equal(t, "active", filter.Conditions[0].Key) + assert.Equal(t, "eq", filter.Conditions[0].Op) + assert.Equal(t, true, filter.Conditions[0].Value) + assert.Equal(t, "name", filter.Conditions[1].Key) + assert.Equal(t, "eq", filter.Conditions[1].Op) + assert.Equal(t, "test", filter.Conditions[1].Value) + assert.Equal(t, "priority", filter.Conditions[2].Key) + assert.Equal(t, "eq", filter.Conditions[2].Op) + assert.Equal(t, int64(5), filter.Conditions[2].Value) + + // Map with ObjectID string + id := primitive.NewObjectID() + condStr = `{"_id": "` + id.Hex() + `"}` + filter = controllers.ConvertToFilter(condStr) + require.NotNil(t, filter) + require.Len(t, filter.Conditions, 1) + assert.Equal(t, "_id", filter.Conditions[0].Key) + assert.Equal(t, "eq", filter.Conditions[0].Op) + assert.Equal(t, id, filter.Conditions[0].Value) + + // Invalid JSON should return nil + condStr = `{"name": "test"` + filter = controllers.ConvertToFilter(condStr) + assert.Nil(t, filter) +} + +func sortConditions(conditions []*entity.Condition) { + sort.Slice(conditions, func(i, j int) bool { + return conditions[i].Key < conditions[j].Key + }) +} diff --git a/core/utils/config.go b/core/utils/config.go index 714fcabb..a2f43b45 100644 --- a/core/utils/config.go +++ b/core/utils/config.go @@ -35,9 +35,10 @@ const ( DefaultPyenvPath = "/root/.pyenv" DefaultNodeModulesPath = "/usr/lib/node_modules" DefaultGoPath = "/root/go" - DefaultMCPServerBaseUrl = "http://localhost:9000" - DefaultMCPClientBaseUrl = "http://localhost:9000/sse" - DefaultOpenAPIUrl = "http://localhost:8000/openapi.json" + DefaultMCPServerHost = "0.0.0.0" + DefaultMCPServerPort = 9777 + DefaultMCPClientBaseUrl = "http://localhost:9777/sse" + DefaultOpenAPIUrlPath = "/openapi.json" ) func IsDev() bool { @@ -289,11 +290,16 @@ func GetGoPath() string { return DefaultGoPath } -func GetMCPServerBaseUrl() string { - if res := viper.GetString("mcp.server.base_url"); res != "" { - return res +func GetMCPServerAddress() string { + host := viper.GetString("mcp.server.host") + if host == "" { + host = DefaultMCPServerHost } - return DefaultMCPServerBaseUrl + port := viper.GetInt("mcp.server.port") + if port == 0 { + port = DefaultMCPServerPort + } + return fmt.Sprintf("%s:%d", host, port) } func GetMCPClientBaseUrl() string { @@ -307,5 +313,5 @@ func GetOpenAPIUrl() string { if res := viper.GetString("openapi.url"); res != "" { return res } - return DefaultOpenAPIUrl + return GetApiEndpoint() + DefaultOpenAPIUrlPath } diff --git a/core/utils/filter.go b/core/utils/filter.go index 16fe804a..1884de3d 100644 --- a/core/utils/filter.go +++ b/core/utils/filter.go @@ -1,16 +1,15 @@ package utils import ( - errors2 "errors" "github.com/crawlab-team/crawlab/core/constants" "github.com/crawlab-team/crawlab/core/interfaces" "go.mongodb.org/mongo-driver/bson" ) // FilterToQuery Translate entity.Filter to bson.M -func FilterToQuery(f interfaces.Filter) (q bson.M, err error) { +func FilterToQuery(f interfaces.Filter) (q bson.M) { if f == nil || f.IsNil() { - return nil, nil + return nil } q = bson.M{} @@ -42,8 +41,11 @@ func FilterToQuery(f interfaces.Filter) (q bson.M, err error) { case constants.FilterOpLessThanEqual: q[key] = bson.M{"$lte": value} default: - return nil, errors2.New("invalid operation") + // ignore invalid operation } } - return q, nil + if len(q) == 0 { + return nil + } + return q }