refactor: standardize filter handling across controllers

- Replaced "conditions" parameter with "filter" in various controller methods to unify the filtering approach.
- Updated related functions to utilize the new filter parameter, enhancing consistency in query handling.
- Improved overall code readability and maintainability by aligning naming conventions and refactoring filter-related logic.
This commit is contained in:
Marvin Zhang
2025-03-26 22:08:01 +08:00
parent 863ba153f6
commit f86173c973
17 changed files with 274 additions and 269 deletions

View File

@@ -1,9 +1,7 @@
package constants
const (
FilterQueryFieldConditions = "conditions"
FilterQueryFieldAll = "all"
FilterQueryFieldFilter = "filter"
FilterQueryFieldFilter = "filter"
)
const (

View File

@@ -25,9 +25,8 @@ func init() {
response := gin.H{
"error": err.Error(),
}
status := http.StatusInternalServerError
constErr, ok := errors.AsType[errors.ConstError](err)
if ok {
var status int
if constErr, ok := errors.AsType[errors.ConstError](err); ok {
switch {
case errors.Is(constErr, errors.NotFound):
status = http.StatusNotFound
@@ -40,6 +39,8 @@ func init() {
default:
status = http.StatusInternalServerError
}
} else if _, ok := errors.AsType[tonic.BindError](err); ok {
status = http.StatusBadRequest
} else {
status = http.StatusInternalServerError
}
@@ -62,11 +63,11 @@ type BaseController[T any] struct {
// GetListParams represents parameters for GetList with pagination
type GetListParams struct {
Conditions string `query:"conditions" description:"Filter conditions. Format: [{\"key\":\"name\",\"op\":\"eq\",\"value\":\"test\"}]"`
Sort string `query:"sort" description:"Sort options"`
Page int `query:"page" default:"1" description:"Page number" minimum:"1"`
Size int `query:"size" default:"10" description:"Page size" minimum:"1"`
All bool `query:"all" default:"false" description:"Whether to get all items"`
Filter string `query:"filter" description:"Filter query"`
Sort string `query:"sort" description:"Sort options"`
Page int `query:"page" default:"1" description:"Page number" minimum:"1"`
Size int `query:"size" default:"10" description:"Page size" minimum:"1"`
All bool `query:"all" default:"false" description:"Whether to get all items"`
}
func (ctr *BaseController[T]) GetList(_ *gin.Context, params *GetListParams) (response *ListResponse[T], err error) {
@@ -221,7 +222,7 @@ func (ctr *BaseController[T]) DeleteList(_ *gin.Context, params *DeleteListParam
for _, id := range params.Ids {
objectId, err := primitive.ObjectIDFromHex(id)
if err != nil {
return GetErrorResponse[T](err)
return GetErrorResponse[T](errors.BadRequestf("invalid id format: %v", err))
}
ids = append(ids, objectId)
}
@@ -241,10 +242,7 @@ func (ctr *BaseController[T]) DeleteList(_ *gin.Context, params *DeleteListParam
// GetAll retrieves all items based on filter and sort
func (ctr *BaseController[T]) GetAll(params *GetListParams) (response *ListResponse[T], err error) {
// Get filter query
query, err := GetFilterQueryFromListParams(params)
if err != nil {
return GetErrorListResponse[T](errors.BadRequestf("invalid request parameters: %v", err))
}
query := ConvertToBsonMFromListParams(params)
// Get sort options
sort, err := GetSortOptionFromString(params.Sort)
@@ -273,10 +271,7 @@ func (ctr *BaseController[T]) GetAll(params *GetListParams) (response *ListRespo
// GetWithPagination retrieves items with pagination
func (ctr *BaseController[T]) GetWithPagination(params *GetListParams) (response *ListResponse[T], err error) {
// Get filter query
query, err := GetFilterQueryFromListParams(params)
if err != nil {
return GetErrorListResponse[T](errors.BadRequestf("invalid request parameters: %v", err))
}
query := ConvertToBsonMFromListParams(params)
// Get sort options
sort, err := GetSortOptionFromString(params.Sort)

View File

@@ -269,7 +269,7 @@ func TestBaseController_GetList(t *testing.T) {
condBytes, err := json.Marshal(cond)
require.Nil(t, err)
params := url.Values{}
params.Add("conditions", string(condBytes))
params.Add("filter", string(condBytes))
params.Add("page", "1")
params.Add("size", "10")
requestUrl := url.URL{Path: "/testmodels/list", RawQuery: params.Encode()}

View File

@@ -11,16 +11,13 @@ import (
)
type PostExportParams struct {
Type string `path:"type" validate:"required"`
Target string `query:"target" validate:"required"`
Conditions string `query:"conditions" description:"Filter conditions. Format: [{\"key\":\"name\",\"op\":\"eq\",\"value\":\"test\"}]"`
Type string `path:"type" validate:"required"`
Target string `query:"target" validate:"required"`
Filter string `query:"filter" description:"Filter query"`
}
func PostExport(_ *gin.Context, params *PostExportParams) (response *Response[string], err error) {
query, err := GetFilterQueryFromConditionString(params.Conditions)
if err != nil {
return GetErrorResponse[string](err)
}
query := ConvertToBsonMFromFilter(params.Filter)
var exportId string
switch params.Type {
case constants.ExportTypeCsv:

View File

@@ -10,36 +10,36 @@ import (
)
type GetFilterColFieldOptionsParams struct {
Col string `path:"col" validate:"required"`
Conditions string `query:"conditions" description:"Filter conditions. Format: [{\"key\":\"name\",\"op\":\"eq\",\"value\":\"test\"}]"`
Col string `path:"col" validate:"required"`
Filter string `query:"filter" description:"Filter query"`
}
func GetFilterColFieldOptions(c *gin.Context, params *GetFilterColFieldOptionsParams) (response *Response[[]entity.FilterSelectOption], err error) {
return GetFilterColFieldOptionsWithValueLabel(c, &GetFilterColFieldOptionsWithValueLabelParams{
Col: params.Col,
Conditions: params.Conditions,
Col: params.Col,
Filter: params.Filter,
})
}
type GetFilterColFieldOptionsWithValueParams struct {
Col string `path:"col" validate:"required"`
Value string `path:"value"`
Conditions string `query:"conditions" description:"Filter conditions. Format: [{\"key\":\"name\",\"op\":\"eq\",\"value\":\"test\"}]"`
Col string `path:"col" validate:"required"`
Value string `path:"value"`
Filter string `query:"filter" description:"Filter query"`
}
func GetFilterColFieldOptionsWithValue(c *gin.Context, params *GetFilterColFieldOptionsWithValueParams) (response *Response[[]entity.FilterSelectOption], err error) {
return GetFilterColFieldOptionsWithValueLabel(c, &GetFilterColFieldOptionsWithValueLabelParams{
Col: params.Col,
Value: params.Value,
Conditions: params.Conditions,
Col: params.Col,
Value: params.Value,
Filter: params.Filter,
})
}
type GetFilterColFieldOptionsWithValueLabelParams struct {
Col string `path:"col" validate:"required"`
Value string `path:"value"`
Label string `path:"label"`
Conditions string `query:"conditions" description:"Filter conditions. Format: [{\"key\":\"name\",\"op\":\"eq\",\"value\":\"test\"}]"`
Col string `path:"col" validate:"required"`
Value string `path:"value"`
Label string `path:"label"`
Filter string `query:"filter" description:"Filter query"`
}
func GetFilterColFieldOptionsWithValueLabel(_ *gin.Context, params *GetFilterColFieldOptionsWithValueLabelParams) (response *Response[[]entity.FilterSelectOption], err error) {
@@ -53,11 +53,8 @@ func GetFilterColFieldOptionsWithValueLabel(_ *gin.Context, params *GetFilterCol
}
pipelines := mongo2.Pipeline{}
if params.Conditions != "" {
query, err := GetFilterFromConditionString(params.Conditions)
if err != nil {
return GetErrorResponse[[]entity.FilterSelectOption](errors.Trace(err))
}
if params.Filter != "" {
query := ConvertToFilter(params.Filter)
pipelines = append(pipelines, bson.D{{"$match", query}})
}
pipelines = append(

View File

@@ -16,10 +16,7 @@ func GetProjectList(c *gin.Context, params *GetListParams) (response *ListRespon
return NewController[models.Project]().GetAll(params)
}
query, err := GetFilterQueryFromListParams(params)
if err != nil {
return GetErrorListResponse[models.Project](errors.BadRequestf("invalid request parameters: %v", err))
}
query := ConvertToBsonMFromListParams(params)
sort, err := GetSortOptionFromString(params.Sort)
if err != nil {

View File

@@ -243,8 +243,6 @@ func InitRoutes(app *gin.Engine) (err error) {
// Register resource controllers with their respective endpoints
// Each RegisterController call sets up standard CRUD operations
// Additional custom actions can be specified in the controller initialization
RegisterController(groups.AuthGroup.Group("", "Data Collections", "APIs for data collections management"), "/data/collections", NewController[models.DataCollection]())
RegisterController(groups.AuthGroup.Group("", "Environments", "APIs for environment variables management"), "/environments", NewController[models.Environment]())
RegisterController(groups.AuthGroup.Group("", "Nodes", "APIs for nodes management"), "/nodes", NewController[models.Node]())
RegisterController(groups.AuthGroup.Group("", "Projects", "APIs for projects management"), "/projects", NewController[models.Project]([]Action{
{
@@ -478,22 +476,6 @@ func InitRoutes(app *gin.Engine) (err error) {
HandlerFunc: GetTaskLogs,
},
}...))
RegisterController(groups.AuthGroup.Group("", "Tokens", "APIs for PAT management"), "/tokens", NewController[models.Token]([]Action{
{
Method: http.MethodPost,
Path: "",
Name: "Create Token",
Description: "Create a new token",
HandlerFunc: PostToken,
},
{
Method: http.MethodGet,
Path: "",
Name: "Get Token List",
Description: "Get a list of tokens",
HandlerFunc: GetTokenList,
},
}...))
RegisterController(groups.AuthGroup.Group("", "Users", "APIs for users management"), "/users", NewController[models.User]([]Action{
{
Method: http.MethodGet,
@@ -566,6 +548,24 @@ func InitRoutes(app *gin.Engine) (err error) {
HandlerFunc: PostUserMeChangePassword,
},
}...))
RegisterController(groups.AuthGroup.Group("", "Tokens", "APIs for PAT management"), "/tokens", NewController[models.Token]([]Action{
{
Method: http.MethodPost,
Path: "",
Name: "Create Token",
Description: "Create a new token",
HandlerFunc: PostToken,
},
{
Method: http.MethodGet,
Path: "",
Name: "Get Token List",
Description: "Get a list of tokens",
HandlerFunc: GetTokenList,
},
}...))
RegisterController(groups.AuthGroup.Group("", "Environments", "APIs for environment variables management"), "/environments", NewController[models.Environment]())
RegisterController(groups.AuthGroup.Group("", "Data Collections", "APIs for data collections management"), "/data/collections", NewController[models.DataCollection]())
// Register standalone action routes that don't fit the standard CRUD pattern
RegisterActions(groups.AuthGroup.Group("", "Export", "APIs for exporting data"), "/export", []Action{

View File

@@ -35,7 +35,7 @@ func TestRegisterController_Routes(t *testing.T) {
controllers.RegisterController(groups.AuthGroup, basePath, ctr)
// Check if all routes are registered
routes := router.Routes()
routes := controllers.GetGlobalFizzWrapper().GetFizz().Engine().Routes()
var methodPaths []string
for _, route := range routes {
@@ -64,7 +64,7 @@ func TestInitRoutes_ProjectsRoute(t *testing.T) {
_ = controllers.InitRoutes(router)
// Check if the projects route is registered
routes := router.Routes()
routes := controllers.GetGlobalFizzWrapper().GetFizz().Engine().Routes()
var methodPaths []string
for _, route := range routes {

View File

@@ -89,10 +89,7 @@ func GetSpiderList(c *gin.Context, params *GetListParams) (response *ListRespons
}
func getSpiderListWithStats(params *GetListParams) (response *ListResponse[models.Spider], err error) {
query, err := GetFilterQueryFromListParams(params)
if err != nil {
return GetErrorListResponse[models.Spider](errors.BadRequestf("invalid request parameters: %v", err))
}
query := ConvertToBsonMFromListParams(params)
sort, err := GetSortOptionFromString(params.Sort)
if err != nil {
@@ -749,7 +746,7 @@ func GetSpiderResults(c *gin.Context, params *GetSpiderResultsParams) (response
return GetErrorListResponse[bson.M](err)
}
query := getResultListQuery(c)
query := GetResultListQuery(c)
col := mongo2.GetMongoCol(s.ColName)

View File

@@ -79,10 +79,7 @@ func GetTaskList(c *gin.Context, params *GetTaskListParams) (response *ListRespo
}
// get query
query, err := GetFilterQueryFromListParams(params.GetListParams)
if err != nil {
return GetErrorListResponse[models.Task](err)
}
query := ConvertToBsonMFromListParams(params.GetListParams)
sort, err := GetSortOptionFromString(params.GetListParams.Sort)
if err != nil {

View File

@@ -43,10 +43,7 @@ func GetTokenList(c *gin.Context, params *GetListParams) (response *ListResponse
u := GetUserFromContext(c)
// Get filter query
query, err := GetFilterQueryFromListParams(params)
if err != nil {
return GetErrorListResponse[models.Token](errors.BadRequestf("invalid request parameters: %v", err))
}
query := ConvertToBsonMFromListParams(params)
// Add filter for tokens created by the current user
query["created_by"] = u.Id

View File

@@ -24,10 +24,7 @@ func GetUserById(_ *gin.Context, params *GetByIdParams) (response *Response[mode
}
func GetUserList(_ *gin.Context, params *GetListParams) (response *ListResponse[models.User], err error) {
query, err := GetFilterQueryFromListParams(params)
if err != nil {
return GetErrorListResponse[models.User](err)
}
query := ConvertToBsonMFromListParams(params)
sort, err := GetSortOptionFromString(params.Sort)
if err != nil {
@@ -171,7 +168,7 @@ func DeleteUserById(_ *gin.Context, params *DeleteByIdParams) (response *Respons
return GetErrorResponse[models.User](err)
}
if user.RootAdmin {
return GetErrorResponse[models.User](errors.New("root admin cannot be deleted"))
return GetErrorResponse[models.User](errors.Forbiddenf("root admin cannot be deleted"))
}
if err := service.NewModelService[models.User]().DeleteById(id); err != nil {
@@ -200,7 +197,7 @@ func DeleteUserList(_ *gin.Context, params *DeleteListParams) (response *Respons
"root_admin": true,
}, nil)
if err == nil {
return GetErrorResponse[models.User](errors.New("root admin cannot be deleted"))
return GetErrorResponse[models.User](errors.Forbiddenf("root admin cannot be deleted"))
}
if !errors.Is(err, mongo2.ErrNoDocuments) {
return GetErrorResponse[models.User](err)

View File

@@ -150,7 +150,7 @@ func TestPostUser_Success(t *testing.T) {
w = httptest.NewRecorder()
router.ServeHTTP(w, req)
assert.Equal(t, http.StatusBadRequest, w.Code)
assert.Equalf(t, http.StatusBadRequest, w.Code, w.Body.String())
}
func TestPutUserById_Success(t *testing.T) {
@@ -428,7 +428,7 @@ func TestDeleteUserById_Success(t *testing.T) {
w = httptest.NewRecorder()
router.ServeHTTP(w, req)
assert.Equalf(t, http.StatusBadRequest, w.Code, "response body: %s", w.Body.String())
assert.Equalf(t, http.StatusForbidden, w.Code, "response body: %s", w.Body.String())
// Test deleting with invalid ID
req, err = http.NewRequest(http.MethodDelete, "/users/invalid-id", nil)
@@ -497,7 +497,7 @@ func TestDeleteUserList_Success(t *testing.T) {
w = httptest.NewRecorder()
router.ServeHTTP(w, req)
assert.Equalf(t, http.StatusBadRequest, w.Code, "response body: %s", w.Body.String())
assert.Equalf(t, http.StatusForbidden, w.Code, "response body: %s", w.Body.String())
// Test with mix of valid and invalid ids
reqBody = strings.NewReader(fmt.Sprintf(`{"ids":["%s","invalid-id"]}`, normalUserIds[0].Hex()))

View File

@@ -4,6 +4,7 @@ import (
"encoding/json"
"net/http"
"reflect"
"strings"
"github.com/crawlab-team/crawlab/core/constants"
"github.com/crawlab-team/crawlab/core/entity"
@@ -30,142 +31,152 @@ func GetUserFromContext(c *gin.Context) (u *models.User) {
return u
}
func GetFilterQueryFromListParams(params *GetListParams) (q bson.M, err error) {
if params.Conditions == "" {
return nil, nil
func ConvertToBsonMFromListParams(params *GetListParams) (q bson.M) {
if params.Filter == "" {
return nil
}
conditions, err := GetFilterFromConditionString(params.Conditions)
if err != nil {
return nil, err
}
return utils.FilterToQuery(conditions)
filter := ConvertToFilter(params.Filter)
return utils.FilterToQuery(filter)
}
func GetFilterQueryFromConditionString(condStr string) (q bson.M, err error) {
if condStr == "" {
return nil, nil
func ConvertToBsonMFromFilter(filterStr string) (q bson.M) {
if filterStr == "" {
return nil
}
conditions, err := GetFilterFromConditionString(condStr)
if err != nil {
return nil, err
}
return utils.FilterToQuery(conditions)
filter := ConvertToFilter(filterStr)
return utils.FilterToQuery(filter)
}
func GetFilterFromConditionString(condStr string) (f *entity.Filter, err error) {
if condStr == "" {
return nil, nil
func ConvertToFilter(filterStr string) (f *entity.Filter) {
if filterStr == "" {
return nil
}
trimmedFilterStr := strings.TrimSpace(filterStr)
if strings.HasPrefix(trimmedFilterStr, "[") {
return convertToFilterArray(filterStr)
} else if strings.HasPrefix(trimmedFilterStr, "{") {
return convertToFilterMap(filterStr)
} else {
return nil
}
}
func convertToFilterMap(filterStr string) (f *entity.Filter) {
var filter map[string]interface{}
if err := json.Unmarshal([]byte(filterStr), &filter); err != nil {
return nil
}
var conditions []*entity.Condition
if err := json.Unmarshal([]byte(condStr), &conditions); err != nil {
return nil, err
}
for i, cond := range conditions {
v := reflect.ValueOf(cond.Value)
switch v.Kind() {
case reflect.String:
item := cond.Value.(string)
// attempt to convert object id
id, err := primitive.ObjectIDFromHex(item)
if err == nil {
conditions[i].Value = id
} else {
conditions[i].Value = item
}
case reflect.Float64:
// JSON numbers are decoded as float64 by default
switch cond.Value.(type) {
case float64:
num := cond.Value.(float64)
// Check if it's a whole number
if num == float64(int64(num)) {
conditions[i].Value = int64(num)
} else {
conditions[i].Value = num
}
case int:
num := cond.Value.(int)
conditions[i].Value = int64(num)
case int64:
num := cond.Value.(int64)
conditions[i].Value = num
}
case reflect.Bool:
conditions[i].Value = cond.Value.(bool)
case reflect.Slice, reflect.Array:
var items []interface{}
for i := 0; i < v.Len(); i++ {
vItem := v.Index(i)
item := vItem.Interface()
switch typedItem := item.(type) {
case string:
// Try to convert to ObjectID first
if id, err := primitive.ObjectIDFromHex(typedItem); err == nil {
items = append(items, id)
} else {
items = append(items, typedItem)
}
case float64:
if typedItem == float64(int64(typedItem)) {
items = append(items, int64(typedItem))
} else {
items = append(items, typedItem)
}
case bool:
items = append(items, typedItem)
default:
items = append(items, item)
}
}
conditions[i].Value = items
default:
conditions[i].Value = cond.Value
}
for k, v := range filter {
conditions = append(conditions, &entity.Condition{
Key: k,
Op: constants.FilterOpEqual,
Value: convertFilterValue(v),
})
}
return &entity.Filter{
IsOr: false,
Conditions: conditions,
}, nil
}
}
// GetFilter Get entity.Filter from gin.Context
func GetFilter(c *gin.Context) (f *entity.Filter, err error) {
condStr := c.Query(constants.FilterQueryFieldConditions)
return GetFilterFromConditionString(condStr)
func convertToFilterArray(filterStr string) (f *entity.Filter) {
var conditions []*entity.Condition
if err := json.Unmarshal([]byte(filterStr), &conditions); err != nil {
return nil
}
for i, cond := range conditions {
conditions[i].Value = convertFilterValue(cond.Value)
}
return &entity.Filter{
IsOr: false,
Conditions: conditions,
}
}
// GetFilterQuery Get bson.M from gin.Context
func GetFilterQuery(c *gin.Context) (q bson.M, err error) {
f, err := GetFilter(c)
if err != nil {
return nil, err
func convertFilterValue(value interface{}) interface{} {
v := reflect.ValueOf(value)
switch v.Kind() {
case reflect.String:
item := value.(string)
// attempt to convert object id
id, err := primitive.ObjectIDFromHex(item)
if err == nil {
return id
} else {
return item
}
case reflect.Float64:
// JSON numbers are decoded as float64 by default
switch value.(type) {
case float64:
num := value.(float64)
// Check if it's a whole number
if num == float64(int64(num)) {
return int64(num)
} else {
return num
}
case int:
num := value.(int)
return int64(num)
case int64:
num := value.(int64)
return num
}
case reflect.Bool:
return value.(bool)
case reflect.Slice, reflect.Array:
var items []interface{}
for i := 0; i < v.Len(); i++ {
vItem := v.Index(i)
item := vItem.Interface()
switch typedItem := item.(type) {
case string:
// Try to convert to ObjectID first
if id, err := primitive.ObjectIDFromHex(typedItem); err == nil {
items = append(items, id)
} else {
items = append(items, typedItem)
}
case float64:
if typedItem == float64(int64(typedItem)) {
items = append(items, int64(typedItem))
} else {
items = append(items, typedItem)
}
case bool:
items = append(items, typedItem)
default:
items = append(items, item)
}
}
return items
default:
return value
}
return value
}
if f == nil {
return nil, nil
}
// TODO: implement logic OR
// GetFilterFromContext Get entity.Filter from gin.Context
func GetFilterFromContext(c *gin.Context) (f *entity.Filter) {
filterStr := c.GetString(constants.FilterQueryFieldFilter)
return ConvertToFilter(filterStr)
}
// ConvertToBsonMFromContext Get bson.M from gin.Context
func ConvertToBsonMFromContext(c *gin.Context) (q bson.M) {
f := GetFilterFromContext(c)
return utils.FilterToQuery(f)
}
func MustGetFilterQuery(c *gin.Context) (q bson.M) {
q, err := GetFilterQuery(c)
if err != nil {
return nil
}
return q
}
func getResultListQuery(c *gin.Context) (q mongo.ListQuery) {
f, err := GetFilter(c)
if err != nil {
return q
}
func GetResultListQuery(c *gin.Context) (q mongo.ListQuery) {
f := GetFilterFromContext(c)
for _, cond := range f.Conditions {
q = append(q, mongo.ListQueryCondition{
Key: cond.Key,
@@ -176,35 +187,6 @@ func getResultListQuery(c *gin.Context) (q mongo.ListQuery) {
return q
}
func GetDefaultPagination() (p *entity.Pagination) {
return &entity.Pagination{
Page: constants.PaginationDefaultPage,
Size: constants.PaginationDefaultSize,
}
}
func GetPagination(c *gin.Context) (p *entity.Pagination, err error) {
var _p entity.Pagination
if err := c.ShouldBindQuery(&_p); err != nil {
return GetDefaultPagination(), err
}
if _p.Page == 0 {
_p.Page = constants.PaginationDefaultPage
}
if _p.Size == 0 {
_p.Size = constants.PaginationDefaultSize
}
return &_p, nil
}
func MustGetPagination(c *gin.Context) (p *entity.Pagination) {
p, err := GetPagination(c)
if err != nil || p == nil {
return GetDefaultPagination()
}
return p
}
func GetSortsFromString(sortStr string) (sorts []entity.Sort, err error) {
if sortStr == "" {
return nil, nil

View File

@@ -1,9 +1,11 @@
package controllers_test
import (
"sort"
"testing"
"github.com/crawlab-team/crawlab/core/controllers"
"github.com/crawlab-team/crawlab/core/entity"
"github.com/crawlab-team/crawlab/core/models/models"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
@@ -15,8 +17,7 @@ import (
func TestGetFilterFromConditionString(t *testing.T) {
// Simple condition with string value
condStr := `[{"key":"name","op":"eq","value":"test"}]`
filter, err := controllers.GetFilterFromConditionString(condStr)
require.NoError(t, err)
filter := controllers.ConvertToFilter(condStr)
require.NotNil(t, filter)
require.Len(t, filter.Conditions, 1)
assert.Equal(t, "name", filter.Conditions[0].Key)
@@ -25,8 +26,7 @@ func TestGetFilterFromConditionString(t *testing.T) {
// Multiple conditions with different types
condStr = `[{"key":"name","op":"eq","value":"test"},{"key":"priority","op":"gt","value":5}]`
filter, err = controllers.GetFilterFromConditionString(condStr)
require.NoError(t, err)
filter = controllers.ConvertToFilter(condStr)
require.NotNil(t, filter)
require.Len(t, filter.Conditions, 2)
assert.Equal(t, "name", filter.Conditions[0].Key)
@@ -38,60 +38,54 @@ func TestGetFilterFromConditionString(t *testing.T) {
// Invalid JSON should return error
condStr = `[{"key":"name","op":"eq","value":"test"`
_, err = controllers.GetFilterFromConditionString(condStr)
assert.Error(t, err)
filter = controllers.ConvertToFilter(condStr)
assert.Nil(t, filter)
}
func TestGetFilterQueryFromConditionString(t *testing.T) {
// Simple equality condition
condStr := `[{"key":"name","op":"eq","value":"test"}]`
query, err := controllers.GetFilterQueryFromConditionString(condStr)
require.NoError(t, err)
query := controllers.ConvertToBsonMFromFilter(condStr)
require.NotNil(t, query)
expected := bson.M{"name": "test"}
assert.Equal(t, expected, query)
// Greater than condition
condStr = `[{"key":"priority","op":"gt","value":5}]`
query, err = controllers.GetFilterQueryFromConditionString(condStr)
require.NoError(t, err)
query = controllers.ConvertToBsonMFromFilter(condStr)
require.NotNil(t, query)
expected = bson.M{"priority": bson.M{"$gt": int64(5)}}
assert.Equal(t, expected, query)
// Multiple conditions
condStr = `[{"key":"name","op":"eq","value":"test"},{"key":"priority","op":"gt","value":5}]`
query, err = controllers.GetFilterQueryFromConditionString(condStr)
require.NoError(t, err)
query = controllers.ConvertToBsonMFromFilter(condStr)
require.NotNil(t, query)
expected = bson.M{"name": "test", "priority": bson.M{"$gt": int64(5)}}
assert.Equal(t, expected, query)
// Contains operator
condStr = `[{"key":"name","op":"c","value":"test"}]`
query, err = controllers.GetFilterQueryFromConditionString(condStr)
require.NoError(t, err)
query = controllers.ConvertToBsonMFromFilter(condStr)
require.NotNil(t, query)
expectedRegex := bson.M{"name": bson.M{"$regex": "test", "$options": "i"}}
assert.Equal(t, expectedRegex, query)
// Invalid condition should return error
condStr = `[{"key":"name","op":"invalid_op","value":"test"}]`
_, err = controllers.GetFilterQueryFromConditionString(condStr)
assert.Error(t, err)
query = controllers.ConvertToBsonMFromFilter(condStr)
assert.Nil(t, query)
}
func TestGetFilterQueryFromListParams(t *testing.T) {
// No conditions
params := &controllers.GetListParams{}
query, err := controllers.GetFilterQueryFromListParams(params)
require.NoError(t, err)
query := controllers.ConvertToBsonMFromListParams(params)
assert.Nil(t, query)
// With conditions
params.Conditions = `[{"key":"name","op":"eq","value":"test"}]`
query, err = controllers.GetFilterQueryFromListParams(params)
require.NoError(t, err)
params.Filter = `[{"key":"name","op":"eq","value":"test"}]`
query = controllers.ConvertToBsonMFromListParams(params)
require.NotNil(t, query)
expected := bson.M{"name": "test"}
assert.Equal(t, expected, query)
@@ -176,3 +170,52 @@ func TestGetErrorListResponse(t *testing.T) {
assert.Nil(t, resp.Data)
assert.Equal(t, 0, resp.Total)
}
func TestConvertToFilterMap(t *testing.T) {
// Simple map with string value
condStr := `{"name": "test"}`
filter := controllers.ConvertToFilter(condStr)
require.NotNil(t, filter)
require.Len(t, filter.Conditions, 1)
assert.Equal(t, "name", filter.Conditions[0].Key)
assert.Equal(t, "eq", filter.Conditions[0].Op)
assert.Equal(t, "test", filter.Conditions[0].Value)
// Map with multiple fields of different types
condStr = `{"name": "test", "priority": 5, "active": true}`
filter = controllers.ConvertToFilter(condStr)
require.NotNil(t, filter)
require.Len(t, filter.Conditions, 3)
// Sort conditions to ensure consistent test results
sortConditions(filter.Conditions)
assert.Equal(t, "active", filter.Conditions[0].Key)
assert.Equal(t, "eq", filter.Conditions[0].Op)
assert.Equal(t, true, filter.Conditions[0].Value)
assert.Equal(t, "name", filter.Conditions[1].Key)
assert.Equal(t, "eq", filter.Conditions[1].Op)
assert.Equal(t, "test", filter.Conditions[1].Value)
assert.Equal(t, "priority", filter.Conditions[2].Key)
assert.Equal(t, "eq", filter.Conditions[2].Op)
assert.Equal(t, int64(5), filter.Conditions[2].Value)
// Map with ObjectID string
id := primitive.NewObjectID()
condStr = `{"_id": "` + id.Hex() + `"}`
filter = controllers.ConvertToFilter(condStr)
require.NotNil(t, filter)
require.Len(t, filter.Conditions, 1)
assert.Equal(t, "_id", filter.Conditions[0].Key)
assert.Equal(t, "eq", filter.Conditions[0].Op)
assert.Equal(t, id, filter.Conditions[0].Value)
// Invalid JSON should return nil
condStr = `{"name": "test"`
filter = controllers.ConvertToFilter(condStr)
assert.Nil(t, filter)
}
func sortConditions(conditions []*entity.Condition) {
sort.Slice(conditions, func(i, j int) bool {
return conditions[i].Key < conditions[j].Key
})
}

View File

@@ -35,9 +35,10 @@ const (
DefaultPyenvPath = "/root/.pyenv"
DefaultNodeModulesPath = "/usr/lib/node_modules"
DefaultGoPath = "/root/go"
DefaultMCPServerBaseUrl = "http://localhost:9000"
DefaultMCPClientBaseUrl = "http://localhost:9000/sse"
DefaultOpenAPIUrl = "http://localhost:8000/openapi.json"
DefaultMCPServerHost = "0.0.0.0"
DefaultMCPServerPort = 9777
DefaultMCPClientBaseUrl = "http://localhost:9777/sse"
DefaultOpenAPIUrlPath = "/openapi.json"
)
func IsDev() bool {
@@ -289,11 +290,16 @@ func GetGoPath() string {
return DefaultGoPath
}
func GetMCPServerBaseUrl() string {
if res := viper.GetString("mcp.server.base_url"); res != "" {
return res
func GetMCPServerAddress() string {
host := viper.GetString("mcp.server.host")
if host == "" {
host = DefaultMCPServerHost
}
return DefaultMCPServerBaseUrl
port := viper.GetInt("mcp.server.port")
if port == 0 {
port = DefaultMCPServerPort
}
return fmt.Sprintf("%s:%d", host, port)
}
func GetMCPClientBaseUrl() string {
@@ -307,5 +313,5 @@ func GetOpenAPIUrl() string {
if res := viper.GetString("openapi.url"); res != "" {
return res
}
return DefaultOpenAPIUrl
return GetApiEndpoint() + DefaultOpenAPIUrlPath
}

View File

@@ -1,16 +1,15 @@
package utils
import (
errors2 "errors"
"github.com/crawlab-team/crawlab/core/constants"
"github.com/crawlab-team/crawlab/core/interfaces"
"go.mongodb.org/mongo-driver/bson"
)
// FilterToQuery Translate entity.Filter to bson.M
func FilterToQuery(f interfaces.Filter) (q bson.M, err error) {
func FilterToQuery(f interfaces.Filter) (q bson.M) {
if f == nil || f.IsNil() {
return nil, nil
return nil
}
q = bson.M{}
@@ -42,8 +41,11 @@ func FilterToQuery(f interfaces.Filter) (q bson.M, err error) {
case constants.FilterOpLessThanEqual:
q[key] = bson.M{"$lte": value}
default:
return nil, errors2.New("invalid operation")
// ignore invalid operation
}
}
return q, nil
if len(q) == 0 {
return nil
}
return q
}