chore: update Go version and dependencies

- Updated Go version in go.work and backend/go.mod to 1.23.7
- Updated various dependencies in go.sum and backend/go.sum
- Refactored models to remove generic type parameters from BaseModel
- Introduced new utility functions for consistent API responses
- Removed unused utility files from controllers
This commit is contained in:
Marvin Zhang
2025-03-12 23:20:06 +08:00
parent d6badb533d
commit ddff881954
58 changed files with 1555 additions and 1306 deletions

View File

@@ -1,7 +1,8 @@
package controllers
import (
"github.com/crawlab-team/crawlab/core/entity"
"time"
"github.com/crawlab-team/crawlab/core/interfaces"
"github.com/crawlab-team/crawlab/core/models/service"
"github.com/crawlab-team/crawlab/core/mongo"
@@ -23,28 +24,32 @@ type BaseController[T any] struct {
actions []Action
}
type GetByIdParams struct {
Id string `path:"id" description:"The ID of the item to get"`
}
// GetAllParams represents parameters for GetAll method
type GetAllParams struct {
Query bson.M `json:"query"`
Sort bson.D `json:"sort"`
}
// GetListParams represents parameters for GetList with pagination
type GetListParams struct {
Query bson.M `json:"query"`
Sort bson.D `json:"sort"`
Pagination *entity.Pagination `json:"pagination"`
All bool `query:"all" description:"Whether to get all items"`
Conditions string `query:"conditions" description:"Filter conditions. Format: [{\"key\":\"name\",\"op\":\"eq\",\"value\":\"test\"}]"`
Sort bson.D `query:"sort" description:"Sort options"`
Page int `query:"page" default:"1" description:"Page number"`
Size int `query:"size" default:"10" description:"Page size"`
All bool `query:"all" default:"false" description:"Whether to get all items"`
}
func (ctr *BaseController[T]) GetList(_ *gin.Context, params *GetListParams) (response *ListResponse[T], err error) {
// get all if query field "all" is set true
if params.All {
return ctr.GetAll(params)
}
return ctr.GetWithPagination(params)
}
type GetByIdParams struct {
Id string `path:"id" description:"The ID of the item to get"`
}
func (ctr *BaseController[T]) GetById(_ *gin.Context, params *GetByIdParams) (response *Response[T], err error) {
id, err := primitive.ObjectIDFromHex(params.Id)
if err != nil {
return nil, errors.BadRequestf("invalid id: %s", params.Id)
return GetErrorResponse[T](errors.BadRequestf("invalid id format"))
}
model, err := ctr.modelSvc.GetById(id)
@@ -52,143 +57,132 @@ func (ctr *BaseController[T]) GetById(_ *gin.Context, params *GetByIdParams) (re
return nil, err
}
return GetSuccessDataResponse(*model)
return GetDataResponse(*model)
}
func (ctr *BaseController[T]) GetList(c *gin.Context, params *GetListParams) (response *ListResponse[T], err error) {
// get all if query field "all" is set true
all := params.All || MustGetFilterAll(c)
// Prepare parameters
query := MustGetFilterQuery(c)
sort := MustGetSortOption(c)
if all {
allParams := &GetAllParams{
Query: query,
Sort: sort,
}
return ctr.GetAll(c, allParams)
}
// get list with pagination
pagination := MustGetPagination(c)
listParams := &GetListParams{
Query: query,
Sort: sort,
Pagination: pagination,
}
return ctr.GetWithPagination(c, listParams)
type PostParams[T any] struct {
Data T `json:"data"`
}
func (ctr *BaseController[T]) Post(c *gin.Context) (response *Response[T], err error) {
var model T
if err := c.ShouldBindJSON(&model); err != nil {
return GetErrorDataResponse[T](err)
}
func (ctr *BaseController[T]) Post(c *gin.Context, params *PostParams[T]) (response *Response[T], err error) {
u := GetUserFromContext(c)
m := any(&model).(interfaces.Model)
m := any(&params.Data).(interfaces.Model)
m.SetId(primitive.NewObjectID())
m.SetCreated(u.Id)
m.SetUpdated(u.Id)
col := ctr.modelSvc.GetCol()
res, err := col.GetCollection().InsertOne(col.GetContext(), m)
if err != nil {
return GetErrorDataResponse[T](err)
return GetErrorResponse[T](err)
}
result, err := ctr.modelSvc.GetById(res.InsertedID.(primitive.ObjectID))
if err != nil {
return GetErrorDataResponse[T](err)
return GetErrorResponse[T](err)
}
return GetSuccessDataResponse(*result)
return GetDataResponse(*result)
}
func (ctr *BaseController[T]) PutById(c *gin.Context) (response *Response[T], err error) {
id, err := primitive.ObjectIDFromHex(c.Param("id"))
if err != nil {
return GetErrorDataResponse[T](err)
}
type PutByIdParams[T any] struct {
Id string `path:"id" description:"The ID of the item to update"`
Data T `json:"data"`
}
var model T
if err := c.ShouldBindJSON(&model); err != nil {
return GetErrorDataResponse[T](err)
func (ctr *BaseController[T]) PutById(c *gin.Context, params *PutByIdParams[T]) (response *Response[T], err error) {
id, err := primitive.ObjectIDFromHex(params.Id)
if err != nil {
return GetErrorResponse[T](errors.BadRequestf("invalid id format: %v", err))
}
u := GetUserFromContext(c)
m := any(&model).(interfaces.Model)
m := any(&params.Data).(interfaces.Model)
m.SetUpdated(u.Id)
if m.GetId().IsZero() {
m.SetId(id)
}
if err := ctr.modelSvc.ReplaceById(id, model); err != nil {
return GetErrorDataResponse[T](err)
if err := ctr.modelSvc.ReplaceById(id, params.Data); err != nil {
return GetErrorResponse[T](err)
}
result, err := ctr.modelSvc.GetById(id)
if err != nil {
return GetErrorDataResponse[T](err)
return GetErrorResponse[T](err)
}
return GetSuccessDataResponse(*result)
return GetDataResponse(*result)
}
func (ctr *BaseController[T]) PatchList(c *gin.Context) (res *Response[T], err error) {
type Payload struct {
Ids []primitive.ObjectID `json:"ids"`
Update bson.M `json:"update"`
type PatchParams struct {
Ids []string `json:"ids" description:"The IDs of the items to update" validate:"required"`
Update bson.M `json:"update" description:"The update object" validate:"required"`
}
func (ctr *BaseController[T]) PatchList(c *gin.Context, params *PatchParams) (res *Response[T], err error) {
var ids []primitive.ObjectID
for _, id := range params.Ids {
objectId, err := primitive.ObjectIDFromHex(id)
if err != nil {
return GetErrorResponse[T](errors.BadRequestf("invalid id format: %v", err))
}
ids = append(ids, objectId)
}
var payload Payload
if err := c.ShouldBindJSON(&payload); err != nil {
return GetErrorDataResponse[T](err)
}
// Get user from context for updated_by
u := GetUserFromContext(c)
// query
query := bson.M{
"_id": bson.M{
"$in": payload.Ids,
"$in": ids,
},
}
// Add updated_by and updated_ts to the update object
updateObj := params.Update
updateObj["updated_by"] = u.Id
updateObj["updated_ts"] = time.Now()
// update
if err := ctr.modelSvc.UpdateMany(query, bson.M{"$set": payload.Update}); err != nil {
return GetErrorDataResponse[T](err)
if err := ctr.modelSvc.UpdateMany(query, bson.M{"$set": updateObj}); err != nil {
return GetErrorResponse[T](err)
}
// Return an empty response with success status
var emptyModel T
return GetSuccessDataResponse(emptyModel)
return GetDataResponse(emptyModel)
}
func (ctr *BaseController[T]) DeleteById(c *gin.Context) (res *Response[T], err error) {
id, err := primitive.ObjectIDFromHex(c.Param("id"))
type DeleteByIdParams struct {
Id string `path:"id" description:"The ID of the item to get"`
}
func (ctr *BaseController[T]) DeleteById(c *gin.Context, params *DeleteByIdParams) (res *Response[T], err error) {
params.Id = c.Param("id")
id, err := primitive.ObjectIDFromHex(params.Id)
if err != nil {
return GetErrorDataResponse[T](err)
return GetErrorResponse[T](errors.BadRequestf("invalid id format: %v", err))
}
if err := ctr.modelSvc.DeleteById(id); err != nil {
return GetErrorDataResponse[T](err)
return GetErrorResponse[T](err)
}
var emptyModel T
return GetSuccessDataResponse(emptyModel)
return GetDataResponse(emptyModel)
}
func (ctr *BaseController[T]) DeleteList(c *gin.Context) (res *Response[T], err error) {
type Payload struct {
Ids []string `json:"ids"`
}
var payload Payload
if err := c.ShouldBindJSON(&payload); err != nil {
return GetErrorDataResponse[T](err)
}
type DeleteListParams struct {
Ids []string `json:"ids" description:"The IDs of the items to delete"`
}
func (ctr *BaseController[T]) DeleteList(_ *gin.Context, params *DeleteListParams) (res *Response[T], err error) {
var ids []primitive.ObjectID
for _, id := range payload.Ids {
for _, id := range params.Ids {
objectId, err := primitive.ObjectIDFromHex(id)
if err != nil {
return GetErrorDataResponse[T](err)
return GetErrorResponse[T](err)
}
ids = append(ids, objectId)
}
@@ -198,16 +192,19 @@ func (ctr *BaseController[T]) DeleteList(c *gin.Context) (res *Response[T], err
"$in": ids,
},
}); err != nil {
return GetErrorDataResponse[T](err)
return GetErrorResponse[T](err)
}
var emptyModel T
return GetSuccessDataResponse(emptyModel)
return GetDataResponse(emptyModel)
}
// GetAll retrieves all items based on filter and sort
func (ctr *BaseController[T]) GetAll(_ *gin.Context, params *GetAllParams) (response *ListResponse[T], err error) {
query := params.Query
func (ctr *BaseController[T]) GetAll(params *GetListParams) (response *ListResponse[T], err error) {
query, err := GetFilterQueryFromListParams(params)
if err != nil {
return GetErrorListResponse[T](errors.BadRequestf("invalid request parameters: %v", err))
}
sort := params.Sort
if sort == nil {
sort = bson.D{{"_id", -1}}
@@ -225,29 +222,23 @@ func (ctr *BaseController[T]) GetAll(_ *gin.Context, params *GetAllParams) (resp
return nil, err
}
return GetSuccessListResponse(models, total)
return GetListResponse(models, total)
}
// GetWithPagination retrieves items with pagination
func (ctr *BaseController[T]) GetWithPagination(_ *gin.Context, params *GetListParams) (response *ListResponse[T], err error) {
// params
pagination := params.Pagination
query := params.Query
sort := params.Sort
if pagination == nil {
pagination = GetDefaultPagination()
func (ctr *BaseController[T]) GetWithPagination(params *GetListParams) (response *ListResponse[T], err error) {
query, err := GetFilterQueryFromListParams(params)
if err != nil {
return GetErrorListResponse[T](errors.BadRequestf("invalid request parameters: %v", err))
}
// get list
models, err := ctr.modelSvc.GetMany(query, &mongo.FindOptions{
Sort: sort,
Skip: pagination.Size * (pagination.Page - 1),
Limit: pagination.Size,
Sort: params.Sort,
Skip: params.Size * (params.Page - 1),
Limit: params.Size,
})
if err != nil {
if errors.Is(err, mongo2.ErrNoDocuments) {
return GetSuccessListResponse[T](nil, 0)
return GetListResponse[T](nil, 0)
} else {
return nil, err
}
@@ -260,36 +251,7 @@ func (ctr *BaseController[T]) GetWithPagination(_ *gin.Context, params *GetListP
}
// response
return GetSuccessListResponse(models, total)
}
// getAll is kept for backward compatibility
func (ctr *BaseController[T]) getAll(c *gin.Context) (response *ListResponse[T], err error) {
query := MustGetFilterQuery(c)
sort := MustGetSortOption(c)
params := &GetAllParams{
Query: query,
Sort: sort,
}
return ctr.GetAll(c, params)
}
// getList is kept for backward compatibility
func (ctr *BaseController[T]) getList(c *gin.Context) (response *ListResponse[T], err error) {
// params
pagination := MustGetPagination(c)
query := MustGetFilterQuery(c)
sort := MustGetSortOption(c)
params := &GetListParams{
Query: query,
Sort: sort,
Pagination: pagination,
}
return ctr.GetWithPagination(c, params)
return GetListResponse(models, total)
}
func NewController[T any](actions ...Action) *BaseController[T] {

View File

@@ -4,48 +4,54 @@ import (
"errors"
"fmt"
"github.com/crawlab-team/crawlab/core/fs"
"github.com/crawlab-team/crawlab/core/interfaces"
"github.com/gin-gonic/gin"
"io"
"os"
"sync"
)
func GetBaseFileListDir(rootPath string, c *gin.Context) {
path := c.Query("path")
type GetBaseFileListDirParams struct {
Path string `path:"path"`
}
func GetBaseFileListDir(rootPath string, params *GetBaseFileListDirParams) (response *Response[[]interfaces.FsFileInfo], err error) {
path := params.Path
fsSvc, err := fs.GetBaseFileFsSvc(rootPath)
if err != nil {
HandleErrorBadRequest(c, err)
return
return GetErrorResponse[[]interfaces.FsFileInfo](err)
}
files, err := fsSvc.List(path)
if err != nil {
if !errors.Is(err, os.ErrNotExist) {
HandleErrorInternalServerError(c, err)
return
return GetErrorResponse[[]interfaces.FsFileInfo](err)
}
}
HandleSuccessWithData(c, files)
//HandleSuccessWithData(c, files)
return GetDataResponse[[]interfaces.FsFileInfo](files)
}
func GetBaseFileFile(rootPath string, c *gin.Context) {
path := c.Query("path")
type GetBaseFileFileParams struct {
Path string `path:"path"`
}
func GetBaseFileFile(rootPath string, params *GetBaseFileFileParams) (response *Response[string], err error) {
path := params.Path
fsSvc, err := fs.GetBaseFileFsSvc(rootPath)
if err != nil {
HandleErrorBadRequest(c, err)
return
return GetErrorResponse[string](err)
}
data, err := fsSvc.GetFile(path)
if err != nil {
HandleErrorInternalServerError(c, err)
return
return GetErrorResponse[string](err)
}
HandleSuccessWithData(c, string(data))
return GetDataResponse[string](string(data))
}
func GetBaseFileFileInfo(rootPath string, c *gin.Context) {

View File

@@ -4,19 +4,30 @@ import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"strconv"
"testing"
"time"
"github.com/crawlab-team/crawlab/core/entity"
"github.com/crawlab-team/crawlab/core/controllers"
"github.com/crawlab-team/crawlab/core/middlewares"
"github.com/crawlab-team/crawlab/core/models/models"
"github.com/crawlab-team/crawlab/core/models/service"
"github.com/crawlab-team/crawlab/core/mongo"
"github.com/crawlab-team/crawlab/core/user"
"github.com/loopfz/gadgeto/tonic"
"github.com/spf13/viper"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/require"
"github.com/wI2L/fizz"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/bson/primitive"
)
func init() {
@@ -26,7 +37,12 @@ func init() {
// TestModel is a simple struct to be used as a model in tests
type TestModel models.TestModel
//type TestModel struct {
// Name string `json:"name" bson:"name"`
//}
var TestToken string
var TestUserId primitive.ObjectID
// SetupTestDB sets up the test database
func SetupTestDB() {
@@ -50,11 +66,12 @@ func SetupTestDB() {
panic(err)
}
TestToken = token
TestUserId = u.Id
}
// SetupRouter sets up the gin router for testing
func SetupRouter() *gin.Engine {
router := gin.Default()
func SetupRouter() *fizz.Fizz {
router := fizz.New()
return router
}
@@ -77,7 +94,7 @@ func TestBaseController_GetById(t *testing.T) {
// Set up the router
router := SetupRouter()
router.Use(middlewares.AuthorizationMiddleware())
router.GET("/testmodels/:id", ctr.GetById)
router.GET("/testmodels/:id", nil, tonic.Handler(ctr.GetById, 200))
// Create a test request
req, _ := http.NewRequest("GET", "/testmodels/"+id.Hex(), nil)
@@ -106,30 +123,34 @@ func TestBaseController_Post(t *testing.T) {
// Set up the router
router := SetupRouter()
router.Use(middlewares.AuthorizationMiddleware())
router.POST("/testmodels", ctr.Post)
router.POST("/testmodels", nil, tonic.Handler(ctr.Post, 200))
// Create a test request
testModel := TestModel{Name: "test"}
jsonValue, _ := json.Marshal(testModel)
requestBody := controllers.PostParams[TestModel]{
Data: testModel,
}
jsonValue, _ := json.Marshal(requestBody)
req, _ := http.NewRequest("POST", "/testmodels", bytes.NewBuffer(jsonValue))
req.Header.Set("Authorization", TestToken)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
// Serve the request
router.ServeHTTP(w, req)
// Check the response
assert.Equal(t, http.StatusOK, w.Code)
require.Equal(t, http.StatusOK, w.Code)
var response controllers.Response[TestModel]
err := json.Unmarshal(w.Body.Bytes(), &response)
assert.NoError(t, err)
assert.Equal(t, "test", response.Data.Name)
require.NoError(t, err)
require.Equal(t, "test", response.Data.Name)
// Check if the document was inserted into the database
result, err := service.NewModelService[TestModel]().GetById(response.Data.Id)
assert.NoError(t, err)
assert.Equal(t, "test", result.Name)
require.NoError(t, err)
require.Equal(t, "test", result.Name)
}
func TestBaseController_DeleteById(t *testing.T) {
@@ -146,7 +167,7 @@ func TestBaseController_DeleteById(t *testing.T) {
// Set up the router
router := SetupRouter()
router.Use(middlewares.AuthorizationMiddleware())
router.DELETE("/testmodels/:id", ctr.DeleteById)
router.DELETE("/testmodels/:id", nil, tonic.Handler(ctr.DeleteById, 200))
// Create a test request
req, _ := http.NewRequest("DELETE", "/testmodels/"+id.Hex(), nil)
@@ -163,3 +184,321 @@ func TestBaseController_DeleteById(t *testing.T) {
_, err = service.NewModelService[TestModel]().GetById(id)
assert.Error(t, err)
}
func TestBaseController_GetList(t *testing.T) {
SetupTestDB()
defer CleanupTestDB()
// Insert test documents
modelSvc := service.NewModelService[TestModel]()
for i := 0; i < 15; i++ {
_, err := modelSvc.InsertOne(TestModel{Name: fmt.Sprintf("test%d", i)})
assert.NoError(t, err)
}
// Initialize the controller
ctr := controllers.NewController[TestModel]()
// Set up the router
router := SetupRouter()
router.Use(middlewares.AuthorizationMiddleware())
router.GET("/testmodels/list", nil, tonic.Handler(ctr.GetList, 200))
// Test case 1: Get with pagination
t.Run("test_get_with_pagination", func(t *testing.T) {
var testData = []struct {
Page int
ExpectedDataCount int
ExpectedTotalCount int
}{
{1, 10, 15},
{2, 5, 15},
}
for _, data := range testData {
params := url.Values{}
params.Add("page", strconv.Itoa(data.Page))
params.Add("size", "10")
requestUrl := url.URL{Path: "/testmodels/list", RawQuery: params.Encode()}
req, _ := http.NewRequest("GET", requestUrl.String(), nil)
req.Header.Set("Authorization", TestToken)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
// Serve the request
router.ServeHTTP(w, req)
// Check the response
assert.Equal(t, http.StatusOK, w.Code)
var response controllers.ListResponse[TestModel]
err := json.Unmarshal(w.Body.Bytes(), &response)
assert.NoError(t, err)
assert.Equal(t, data.ExpectedDataCount, len(response.Data))
assert.Equal(t, data.ExpectedTotalCount, response.Total)
}
})
// Test case 2: Get all
t.Run("test_get_all", func(t *testing.T) {
params := url.Values{}
params.Add("all", "true")
requestUrl := url.URL{Path: "/testmodels/list", RawQuery: params.Encode()}
req, _ := http.NewRequest("GET", requestUrl.String(), nil)
req.Header.Set("Authorization", TestToken)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
// Serve the request
router.ServeHTTP(w, req)
// Check the response
assert.Equal(t, http.StatusOK, w.Code)
var response controllers.ListResponse[TestModel]
err := json.Unmarshal(w.Body.Bytes(), &response)
assert.NoError(t, err)
assert.Equal(t, 15, len(response.Data))
assert.Equal(t, 15, response.Total)
})
// Test case 3: Get with query filter
t.Run("test_get_with_query_filter", func(t *testing.T) {
cond := []entity.Condition{
{Key: "name", Op: "eq", Value: "test1"},
}
condBytes, err := json.Marshal(cond)
require.Nil(t, err)
params := url.Values{}
params.Add("conditions", string(condBytes))
params.Add("page", "1")
params.Add("size", "10")
requestUrl := url.URL{Path: "/testmodels/list", RawQuery: params.Encode()}
req, _ := http.NewRequest("GET", requestUrl.String(), nil)
req.Header.Set("Authorization", TestToken)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
// Serve the request
router.ServeHTTP(w, req)
// Check the response
assert.Equal(t, http.StatusOK, w.Code)
var response controllers.ListResponse[TestModel]
err = json.Unmarshal(w.Body.Bytes(), &response)
assert.NoError(t, err)
assert.Equal(t, 1, len(response.Data))
assert.Equal(t, 1, response.Total)
})
}
func TestBaseController_PutById(t *testing.T) {
SetupTestDB()
defer CleanupTestDB()
// Insert a test document
id, err := service.NewModelService[TestModel]().InsertOne(TestModel{Name: "test"})
assert.NoError(t, err)
// Initialize the controller
ctr := controllers.NewController[TestModel]()
// Set up the router
router := SetupRouter()
router.Use(middlewares.AuthorizationMiddleware())
router.PUT("/testmodels/:id", nil, tonic.Handler(ctr.PutById, 200))
// Create a test request
updatedModel := TestModel{Name: "updated"}
requestParams := controllers.PutByIdParams[TestModel]{
Data: updatedModel,
}
jsonValue, _ := json.Marshal(requestParams)
req, _ := http.NewRequest("PUT", "/testmodels/"+id.Hex(), bytes.NewBuffer(jsonValue))
req.Header.Set("Authorization", TestToken)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
// Serve the request
router.ServeHTTP(w, req)
// Check the response
assert.Equal(t, http.StatusOK, w.Code)
var response controllers.Response[TestModel]
err = json.Unmarshal(w.Body.Bytes(), &response)
assert.NoError(t, err)
assert.Equal(t, "updated", response.Data.Name)
// Check if the document was updated in the database
result, err := service.NewModelService[TestModel]().GetById(id)
assert.NoError(t, err)
assert.Equal(t, "updated", result.Name)
// Test with invalid ID
req, _ = http.NewRequest("PUT", "/testmodels/invalid-id", bytes.NewBuffer(jsonValue))
req.Header.Set("Authorization", TestToken)
req.Header.Set("Content-Type", "application/json")
w = httptest.NewRecorder()
// Serve the request
router.ServeHTTP(w, req)
// Check the response
assert.Equal(t, http.StatusBadRequest, w.Code)
}
func TestBaseController_PatchList(t *testing.T) {
SetupTestDB()
defer CleanupTestDB()
// Insert test documents
modelSvc := service.NewModelService[TestModel]()
var ids []primitive.ObjectID
for i := 0; i < 3; i++ {
id, err := modelSvc.InsertOne(TestModel{Name: fmt.Sprintf("test%d", i)})
assert.NoError(t, err)
ids = append(ids, id)
}
// Initialize the controller
ctr := controllers.NewController[TestModel]()
// Set up the router
router := SetupRouter()
router.Use(middlewares.AuthorizationMiddleware())
router.PATCH("/testmodels", nil, tonic.Handler(ctr.PatchList, 200))
// Create a test request
t.Run("test_patch_list", func(t *testing.T) {
var idStrings []string
for _, id := range ids {
idStrings = append(idStrings, id.Hex())
}
requestBody := controllers.PatchParams{
Ids: idStrings,
Update: bson.M{"name": "patched"},
}
jsonValue, _ := json.Marshal(requestBody)
req, _ := http.NewRequest("PATCH", "/testmodels", bytes.NewBuffer(jsonValue))
req.Header.Set("Authorization", TestToken)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
// Get the user ID
userId := TestUserId
// Record time before the update
beforeUpdate := time.Now()
time.Sleep(100 * time.Millisecond)
// Serve the request
router.ServeHTTP(w, req)
// Check the response
assert.Equal(t, http.StatusOK, w.Code)
time.Sleep(100 * time.Millisecond)
// Record time after the update
afterUpdate := time.Now()
// Check if the documents were updated in the database
for _, id := range ids {
result, err := modelSvc.GetById(id)
assert.NoError(t, err)
assert.Equal(t, "patched", result.Name)
// Verify updated_by is set to the current user's ID
assert.Equal(t, userId, result.UpdatedBy)
// Verify updated_ts is set to a recent timestamp
assert.GreaterOrEqual(t, result.UpdatedAt.UnixMilli(), beforeUpdate.UnixMilli())
assert.LessOrEqual(t, result.UpdatedAt.UnixMilli(), afterUpdate.UnixMilli())
}
})
// Test with invalid ID
t.Run("test_patch_list_with_invalid_id", func(t *testing.T) {
requestBody := controllers.PatchParams{
Ids: []string{"invalid-id"},
Update: bson.M{"name": "patched"},
}
jsonValue, _ := json.Marshal(requestBody)
req, _ := http.NewRequest("PATCH", "/testmodels", bytes.NewBuffer(jsonValue))
req.Header.Set("Authorization", TestToken)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
// Serve the request
router.ServeHTTP(w, req)
// Check the response
assert.Equal(t, http.StatusBadRequest, w.Code)
})
}
func TestBaseController_DeleteList(t *testing.T) {
SetupTestDB()
defer CleanupTestDB()
// Insert test documents
modelSvc := service.NewModelService[TestModel]()
var ids []primitive.ObjectID
for i := 0; i < 3; i++ {
id, err := modelSvc.InsertOne(TestModel{Name: fmt.Sprintf("test%d", i)})
assert.NoError(t, err)
ids = append(ids, id)
}
// Initialize the controller
ctr := controllers.NewController[TestModel]()
// Set up the router
router := SetupRouter()
router.Use(middlewares.AuthorizationMiddleware())
router.DELETE("/testmodels", nil, tonic.Handler(ctr.DeleteList, 200))
// Create a test request
var idStrings []string
for _, id := range ids {
idStrings = append(idStrings, id.Hex())
}
requestBody := controllers.DeleteListParams{
Ids: idStrings,
}
jsonValue, _ := json.Marshal(requestBody)
req, _ := http.NewRequest("DELETE", "/testmodels", bytes.NewBuffer(jsonValue))
req.Header.Set("Authorization", TestToken)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
// Serve the request
router.ServeHTTP(w, req)
// Check the response
assert.Equal(t, http.StatusOK, w.Code)
// Check if the documents were deleted from the database
for _, id := range ids {
_, err := modelSvc.GetById(id)
assert.Error(t, err)
}
// Test with invalid ID
requestBody = controllers.DeleteListParams{
Ids: []string{"invalid-id"},
}
jsonValue, _ = json.Marshal(requestBody)
req, _ = http.NewRequest("DELETE", "/testmodels", bytes.NewBuffer(jsonValue))
req.Header.Set("Authorization", TestToken)
req.Header.Set("Content-Type", "application/json")
w = httptest.NewRecorder()
// Serve the request
router.ServeHTTP(w, req)
// Check the response
assert.Equal(t, http.StatusBadRequest, w.Code)
}

View File

@@ -1,34 +1,31 @@
package controllers
import (
"errors"
"github.com/crawlab-team/crawlab/core/models/models"
"github.com/crawlab-team/crawlab/core/models/service"
"github.com/crawlab-team/crawlab/core/mongo"
"github.com/gin-gonic/gin"
"github.com/juju/errors"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/bson/primitive"
mongo2 "go.mongodb.org/mongo-driver/mongo"
)
func GetProjectList(c *gin.Context) {
// get all list
all := MustGetFilterAll(c)
if all {
NewController[models.Project]().getAll(c)
return
func GetProjectList(c *gin.Context, params *GetListParams) (response *ListResponse[models.Project], err error) {
if params.All {
return NewController[models.Project]().GetAll(params)
}
// params
pagination := MustGetPagination(c)
query := MustGetFilterQuery(c)
sort := MustGetSortOption(c)
query, err := GetFilterQueryFromListParams(params)
if err != nil {
return GetErrorListResponse[models.Project](errors.BadRequestf("invalid request parameters: %v", err))
}
// get list
projects, err := service.NewModelService[models.Project]().GetMany(query, &mongo.FindOptions{
Sort: sort,
Skip: pagination.Size * (pagination.Page - 1),
Limit: pagination.Size,
Sort: params.Sort,
Skip: params.Size * (params.Page - 1),
Limit: params.Size,
})
if err != nil {
if err.Error() != mongo2.ErrNoDocuments.Error() {
@@ -82,5 +79,5 @@ func GetProjectList(c *gin.Context) {
projects[i].Spiders = cache[p.Id]
}
HandleSuccessWithListData(c, projects, total)
return GetListResponse[models.Project](projects, total)
}

View File

@@ -29,13 +29,13 @@ func GetGlobalFizzWrapper() *openapi.FizzWrapper {
// NewRouterGroups initializes the router groups with their respective middleware
func NewRouterGroups(app *gin.Engine) (groups *RouterGroups) {
// Create OpenAPI wrapper
wrapper := openapi.NewFizzWrapper(app)
globalWrapper = openapi.NewFizzWrapper(app)
return &RouterGroups{
AuthGroup: app.Group("/", middlewares.AuthorizationMiddleware()),
SyncAuthGroup: app.Group("/", middlewares.SyncAuthorizationMiddleware()),
AnonymousGroup: app.Group("/"),
Wrapper: wrapper,
Wrapper: globalWrapper,
}
}
@@ -232,33 +232,6 @@ func InitRoutes(app *gin.Engine) (err error) {
HandlerFunc: GetProjectList,
},
}...))
RegisterController(groups.AuthGroup, "/schedules", NewController[models.Schedule]([]Action{
{
Method: http.MethodPost,
Path: "",
HandlerFunc: PostSchedule,
},
{
Method: http.MethodPut,
Path: "/:id",
HandlerFunc: PutScheduleById,
},
{
Method: http.MethodPost,
Path: "/:id/enable",
HandlerFunc: PostScheduleEnable,
},
{
Method: http.MethodPost,
Path: "/:id/disable",
HandlerFunc: PostScheduleDisable,
},
{
Method: http.MethodPost,
Path: "/:id/run",
HandlerFunc: PostScheduleRun,
},
}...))
RegisterController(groups.AuthGroup, "/spiders", NewController[models.Spider]([]Action{
{
Method: http.MethodGet,
@@ -351,6 +324,35 @@ func InitRoutes(app *gin.Engine) (err error) {
HandlerFunc: GetSpiderResults,
},
}...))
groups.AnonymousGroup.GET("/openapi.json", GetOpenAPI)
return
RegisterController(groups.AuthGroup, "/schedules", NewController[models.Schedule]([]Action{
{
Method: http.MethodPost,
Path: "",
HandlerFunc: PostSchedule,
},
{
Method: http.MethodPut,
Path: "/:id",
HandlerFunc: PutScheduleById,
},
{
Method: http.MethodPost,
Path: "/:id/enable",
HandlerFunc: PostScheduleEnable,
},
{
Method: http.MethodPost,
Path: "/:id/disable",
HandlerFunc: PostScheduleDisable,
},
{
Method: http.MethodPost,
Path: "/:id/run",
HandlerFunc: PostScheduleRun,
},
}...))
RegisterController(groups.AuthGroup, "/tasks", NewController[models.Task]([]Action{
{
Method: http.MethodGet,

View File

@@ -60,7 +60,7 @@ func TestRegisterController_Routes(t *testing.T) {
func TestInitRoutes_ProjectsRoute(t *testing.T) {
router := gin.Default()
controllers.InitRoutes(router)
_ = controllers.InitRoutes(router)
// Check if the projects route is registered
routes := router.Routes()

View File

@@ -1,59 +1,55 @@
package controllers
import (
"errors"
"github.com/crawlab-team/crawlab/core/constants"
"github.com/crawlab-team/crawlab/core/models/models"
mongo2 "github.com/crawlab-team/crawlab/core/mongo"
"github.com/crawlab-team/crawlab/core/spider"
"math"
"os"
"path/filepath"
"sync"
"github.com/crawlab-team/crawlab/core/constants"
"github.com/crawlab-team/crawlab/core/fs"
"github.com/crawlab-team/crawlab/core/interfaces"
"github.com/crawlab-team/crawlab/core/models/models"
"github.com/crawlab-team/crawlab/core/models/service"
mongo2 "github.com/crawlab-team/crawlab/core/mongo"
"github.com/crawlab-team/crawlab/core/spider"
"github.com/crawlab-team/crawlab/core/spider/admin"
"github.com/crawlab-team/crawlab/core/utils"
"github.com/gin-gonic/gin"
"github.com/juju/errors"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/mongo"
)
func GetSpiderById(c *gin.Context) {
id, err := primitive.ObjectIDFromHex(c.Param("id"))
// GetSpiderById handles getting a spider by ID
func GetSpiderById(_ *gin.Context, params *GetByIdParams) (response *Response[models.Spider], err error) {
id, err := primitive.ObjectIDFromHex(params.Id)
if err != nil {
HandleErrorBadRequest(c, err)
return
return GetErrorResponse[models.Spider](errors.BadRequestf("invalid id format"))
}
s, err := service.NewModelService[models.Spider]().GetById(id)
if errors.Is(err, mongo.ErrNoDocuments) {
HandleErrorNotFound(c, err)
return
return GetErrorResponse[models.Spider](errors.NotFoundf("spider not found"))
}
if err != nil {
HandleErrorInternalServerError(c, err)
return
return GetErrorResponse[models.Spider](err)
}
// stat
s.Stat, err = service.NewModelService[models.SpiderStat]().GetById(s.Id)
if err != nil {
if !errors.Is(err, mongo.ErrNoDocuments) {
HandleErrorInternalServerError(c, err)
return
return GetErrorResponse[models.Spider](err)
}
}
// data collection (compatible to old version) # TODO: remove in the future
// data collection (compatible to old version)
if s.ColName == "" && !s.ColId.IsZero() {
col, err := service.NewModelService[models.DataCollection]().GetById(s.ColId)
if err != nil {
if !errors.Is(err, mongo.ErrNoDocuments) {
HandleErrorInternalServerError(c, err)
return
return GetErrorResponse[models.Spider](err)
}
} else {
s.ColName = col.Name
@@ -65,55 +61,51 @@ func GetSpiderById(c *gin.Context) {
s.Git, err = service.NewModelService[models.Git]().GetById(s.GitId)
if err != nil {
if !errors.Is(err, mongo.ErrNoDocuments) {
HandleErrorInternalServerError(c, err)
return
return GetErrorResponse[models.Spider](err)
}
}
}
HandleSuccessWithData(c, s)
return GetDataResponse(*s)
}
func GetSpiderList(c *gin.Context, params *GetListParams) {
// GetSpiderList handles getting a list of spiders with optional stats
func GetSpiderList(c *gin.Context, params *GetListParams) (response *ListResponse[models.Spider], err error) {
// get all list
all := MustGetFilterAll(c)
all := params.All
if all {
NewController[models.Spider]().getAll(c)
return
return NewController[models.Spider]().GetAll(params)
}
// get list
withStats := c.Query("stats")
if withStats == "" {
NewController[models.Spider]().GetList(c, params)
return
return NewController[models.Spider]().GetList(c, params)
}
// get list with stats
getSpiderListWithStats(c)
return getSpiderListWithStats(params)
}
func getSpiderListWithStats(c *gin.Context) {
// params
pagination := MustGetPagination(c)
query := MustGetFilterQuery(c)
sort := MustGetSortOption(c)
func getSpiderListWithStats(params *GetListParams) (response *ListResponse[models.Spider], err error) {
query, err := GetFilterQueryFromListParams(params)
if err != nil {
return GetErrorListResponse[models.Spider](errors.BadRequestf("invalid request parameters: %v", err))
}
// get list
spiders, err := service.NewModelService[models.Spider]().GetMany(query, &mongo2.FindOptions{
Sort: sort,
Skip: pagination.Size * (pagination.Page - 1),
Limit: pagination.Size,
Sort: params.Sort,
Skip: params.Size * (params.Page - 1),
Limit: params.Size,
})
if err != nil {
if err.Error() != mongo.ErrNoDocuments.Error() {
HandleErrorInternalServerError(c, err)
if !errors.Is(err, mongo.ErrNoDocuments) {
return GetErrorListResponse[models.Spider](err)
}
return
return GetListResponse[models.Spider]([]models.Spider{}, 0)
}
if len(spiders) == 0 {
HandleSuccessWithListData(c, []models.Spider{}, 0)
return
return GetListResponse[models.Spider]([]models.Spider{}, 0)
}
// ids
@@ -129,15 +121,13 @@ func getSpiderListWithStats(c *gin.Context) {
// total count
total, err := service.NewModelService[models.Spider]().Count(query)
if err != nil {
HandleErrorInternalServerError(c, err)
return
return GetErrorListResponse[models.Spider](err)
}
// stat list
spiderStats, err := service.NewModelService[models.SpiderStat]().GetMany(bson.M{"_id": bson.M{"$in": ids}}, nil)
if err != nil {
HandleErrorInternalServerError(c, err)
return
return GetErrorListResponse[models.Spider](err)
}
// cache stat list to dict
@@ -170,15 +160,13 @@ func getSpiderListWithStats(c *gin.Context) {
}
tasks, err = service.NewModelService[models.Task]().GetMany(queryTask, nil)
if err != nil {
HandleErrorInternalServerError(c, err)
return
return GetErrorListResponse[models.Spider](err)
}
// task stats list
taskStats, err := service.NewModelService[models.TaskStat]().GetMany(queryTask, nil)
if err != nil {
HandleErrorInternalServerError(c, err)
return
return GetErrorListResponse[models.Spider](err)
}
// cache task stats to dict
@@ -201,8 +189,7 @@ func getSpiderListWithStats(c *gin.Context) {
if len(gitIds) > 0 && utils.IsPro() {
gits, err = service.NewModelService[models.Git]().GetMany(bson.M{"_id": bson.M{"$in": gitIds}}, nil)
if err != nil {
HandleErrorInternalServerError(c, err)
return
return GetErrorListResponse[models.Spider](err)
}
}
@@ -240,16 +227,12 @@ func getSpiderListWithStats(c *gin.Context) {
}
// response
HandleSuccessWithListData(c, data, total)
return GetListResponse(data, total)
}
func PostSpider(c *gin.Context) {
// bind
var s models.Spider
if err := c.ShouldBindJSON(&s); err != nil {
HandleErrorBadRequest(c, err)
return
}
// PostSpider handles creating a new spider
func PostSpider(c *gin.Context, params *PostParams[models.Spider]) (response *Response[models.Spider], err error) {
s := params.Data
if s.Mode == "" {
s.Mode = constants.RunTypeRandom
@@ -266,8 +249,7 @@ func PostSpider(c *gin.Context) {
s.SetUpdated(u.Id)
id, err := service.NewModelService[models.Spider]().InsertOne(s)
if err != nil {
HandleErrorInternalServerError(c, err)
return
return GetErrorResponse[models.Spider](err)
}
s.SetId(id)
@@ -278,20 +260,17 @@ func PostSpider(c *gin.Context) {
st.SetUpdated(u.Id)
_, err = service.NewModelService[models.SpiderStat]().InsertOne(st)
if err != nil {
HandleErrorInternalServerError(c, err)
return
return GetErrorResponse[models.Spider](err)
}
// create folder
fsSvc, err := getSpiderFsSvcById(id)
if err != nil {
HandleErrorInternalServerError(c, err)
return
return GetErrorResponse[models.Spider](err)
}
err = fsSvc.CreateDir(".")
if err != nil {
HandleErrorInternalServerError(c, err)
return
return GetErrorResponse[models.Spider](err)
}
// create template if available
@@ -299,63 +278,52 @@ func PostSpider(c *gin.Context) {
if templateSvc := spider.GetSpiderTemplateRegistryService(); templateSvc != nil {
err = templateSvc.CreateTemplate(s.Id)
if err != nil {
HandleErrorInternalServerError(c, err)
return
return GetErrorResponse[models.Spider](err)
}
}
}
HandleSuccessWithData(c, s)
return GetDataResponse(s)
}
func PutSpiderById(c *gin.Context) {
id, err := primitive.ObjectIDFromHex(c.Param("id"))
// PutSpiderById handles updating a spider by ID
func PutSpiderById(c *gin.Context, params *PutByIdParams[models.Spider]) (response *Response[models.Spider], err error) {
id, err := primitive.ObjectIDFromHex(params.Id)
if err != nil {
HandleErrorBadRequest(c, err)
return
}
// bind
var s models.Spider
if err := c.ShouldBindJSON(&s); err != nil {
HandleErrorBadRequest(c, err)
return
return GetErrorResponse[models.Spider](errors.BadRequestf("invalid id format"))
}
u := GetUserFromContext(c)
modelSvc := service.NewModelService[models.Spider]()
// save
s.SetUpdated(u.Id)
err = modelSvc.ReplaceById(id, s)
if err != nil {
HandleErrorInternalServerError(c, err)
return
params.Data.SetUpdated(u.Id)
if params.Data.Id.IsZero() {
params.Data.SetId(id)
}
_s, err := modelSvc.GetById(id)
err = modelSvc.ReplaceById(id, params.Data)
if err != nil {
HandleErrorInternalServerError(c, err)
return
return GetErrorResponse[models.Spider](err)
}
s = *_s
HandleSuccessWithData(c, s)
s, err := modelSvc.GetById(id)
if err != nil {
return GetErrorResponse[models.Spider](err)
}
return GetDataResponse(*s)
}
func DeleteSpiderById(c *gin.Context) {
id, err := primitive.ObjectIDFromHex(c.Param("id"))
// DeleteSpiderById handles deleting a spider by ID
func DeleteSpiderById(_ *gin.Context, params *DeleteByIdParams) (response *Response[models.Spider], err error) {
id, err := primitive.ObjectIDFromHex(params.Id)
if err != nil {
HandleErrorBadRequest(c, err)
return
return GetErrorResponse[models.Spider](errors.BadRequestf("invalid id format"))
}
// spider
s, err := service.NewModelService[models.Spider]().GetById(id)
if err != nil {
HandleErrorInternalServerError(c, err)
return
return GetErrorResponse[models.Spider](errors.NotFoundf("spider not found"))
}
if err := mongo2.RunTransaction(func(context mongo.SessionContext) (err error) {
@@ -416,14 +384,13 @@ func DeleteSpiderById(c *gin.Context) {
return nil
}); err != nil {
HandleErrorInternalServerError(c, err)
return
return GetErrorResponse[models.Spider](err)
}
if !s.GitId.IsZero() {
go func() {
// delete spider directory
fsSvc, err := getSpiderFsSvcById(id)
fsSvc, err := getSpiderFsSvcById(s.Id)
if err != nil {
logger.Errorf("failed to get spider fs service: %v", err)
return
@@ -436,34 +403,39 @@ func DeleteSpiderById(c *gin.Context) {
}()
}
HandleSuccess(c)
return GetDataResponse(models.Spider{})
}
func DeleteSpiderList(c *gin.Context) {
var payload struct {
Ids []primitive.ObjectID `json:"ids"`
}
if err := c.ShouldBindJSON(&payload); err != nil {
HandleErrorBadRequest(c, err)
return
type DeleteSpiderListParams struct {
Ids []string `json:"ids" validate:"required"`
}
// DeleteSpiderList handles deleting multiple spiders
func DeleteSpiderList(_ *gin.Context, params *DeleteSpiderListParams) (response *Response[models.Spider], err error) {
var ids []primitive.ObjectID
for _, id := range params.Ids {
_id, err := primitive.ObjectIDFromHex(id)
if err != nil {
return GetErrorResponse[models.Spider](errors.BadRequestf("invalid id format"))
}
ids = append(ids, _id)
}
// Fetch spiders before deletion
spiders, err := service.NewModelService[models.Spider]().GetMany(bson.M{
"_id": bson.M{
"$in": payload.Ids,
"$in": ids,
},
}, nil)
if err != nil {
HandleErrorInternalServerError(c, err)
return
return nil, err
}
if err := mongo2.RunTransaction(func(context mongo.SessionContext) (err error) {
// delete spiders
if err := service.NewModelService[models.Spider]().DeleteMany(bson.M{
"_id": bson.M{
"$in": payload.Ids,
"$in": ids,
},
}); err != nil {
return err
@@ -472,14 +444,14 @@ func DeleteSpiderList(c *gin.Context) {
// delete spider stats
if err := service.NewModelService[models.SpiderStat]().DeleteMany(bson.M{
"_id": bson.M{
"$in": payload.Ids,
"$in": ids,
},
}); err != nil {
return err
}
// related tasks
tasks, err := service.NewModelService[models.Task]().GetMany(bson.M{"spider_id": bson.M{"$in": payload.Ids}}, nil)
tasks, err := service.NewModelService[models.Task]().GetMany(bson.M{"spider_id": bson.M{"$in": ids}}, nil)
if err != nil {
return err
}
@@ -521,8 +493,7 @@ func DeleteSpiderList(c *gin.Context) {
return nil
}); err != nil {
HandleErrorInternalServerError(c, err)
return
return GetErrorResponse[models.Spider](err)
}
// Delete spider directories
@@ -554,25 +525,25 @@ func DeleteSpiderList(c *gin.Context) {
wg.Wait()
}()
HandleSuccess(c)
return GetDataResponse(models.Spider{})
}
func GetSpiderListDir(c *gin.Context) {
func GetSpiderListDir(c *gin.Context, params *GetBaseFileListDirParams) (response *Response[[]interfaces.FsFileInfo], err error) {
rootPath, err := getSpiderRootPathByContext(c)
if err != nil {
HandleErrorForbidden(c, err)
return
}
GetBaseFileListDir(rootPath, c)
return GetBaseFileListDir(rootPath, params)
}
func GetSpiderFile(c *gin.Context) {
func GetSpiderFile(c *gin.Context, params *GetBaseFileFileParams) (response *Response[string], err error) {
rootPath, err := getSpiderRootPathByContext(c)
if err != nil {
HandleErrorForbidden(c, err)
return
}
GetBaseFileFile(rootPath, c)
return GetBaseFileFile(rootPath, params)
}
func GetSpiderFileInfo(c *gin.Context) {
@@ -648,18 +619,16 @@ func PostSpiderExport(c *gin.Context) {
PostBaseFileExport(rootPath, c)
}
func PostSpiderRun(c *gin.Context) {
func PostSpiderRun(c *gin.Context) (response *Response[[]primitive.ObjectID], err error) {
id, err := primitive.ObjectIDFromHex(c.Param("id"))
if err != nil {
HandleErrorBadRequest(c, err)
return
return GetErrorResponse[[]primitive.ObjectID](errors.BadRequestf("invalid id format"))
}
// options
var opts interfaces.SpiderRunOptions
if err := c.ShouldBindJSON(&opts); err != nil {
HandleErrorInternalServerError(c, err)
return
return GetErrorResponse[[]primitive.ObjectID](err)
}
// user
@@ -670,11 +639,11 @@ func PostSpiderRun(c *gin.Context) {
// schedule tasks
taskIds, err := admin.GetSpiderAdminService().Schedule(id, &opts)
if err != nil {
HandleErrorInternalServerError(c, err)
return
return GetErrorResponse[[]primitive.ObjectID](err)
}
HandleSuccessWithData(c, taskIds)
return GetDataResponse(taskIds)
}
func GetSpiderResults(c *gin.Context) {
@@ -733,17 +702,13 @@ func getSpiderFsSvcById(id primitive.ObjectID) (svc interfaces.FsService, err er
}
func getSpiderRootPathByContext(c *gin.Context) (rootPath string, err error) {
// spider id
id, err := primitive.ObjectIDFromHex(c.Param("id"))
if err != nil {
return "", err
}
// spider
s, err := service.NewModelService[models.Spider]().GetById(id)
if err != nil {
return "", err
}
return utils.GetSpiderRootPath(s)
}

View File

@@ -3,11 +3,13 @@ package controllers_test
import (
"bytes"
"encoding/json"
"github.com/crawlab-team/crawlab/core/models/models"
"net/http"
"net/http/httptest"
"testing"
"github.com/crawlab-team/crawlab/core/models/models"
"github.com/loopfz/gadgeto/tonic"
"github.com/crawlab-team/crawlab/core/controllers"
"github.com/crawlab-team/crawlab/core/middlewares"
"github.com/crawlab-team/crawlab/core/models/service"
@@ -24,9 +26,9 @@ func TestCreateSpider(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.Default()
router := SetupRouter()
router.Use(middlewares.AuthorizationMiddleware())
router.POST("/spiders", controllers.PostSpider)
router.POST("/spiders", nil, tonic.Handler(controllers.PostSpider, 200))
payload := models.Spider{
Name: "Test Spider",
@@ -54,9 +56,9 @@ func TestGetSpiderById(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.Default()
router := SetupRouter()
router.Use(middlewares.AuthorizationMiddleware())
router.GET("/spiders/:id", controllers.GetSpiderById)
router.GET("/spiders/:id", nil, tonic.Handler(controllers.GetSpiderById, 200))
model := models.Spider{
Name: "Test Spider",
@@ -89,9 +91,9 @@ func TestUpdateSpiderById(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.Default()
router := SetupRouter()
router.Use(middlewares.AuthorizationMiddleware())
router.PUT("/spiders/:id", controllers.PutSpiderById)
router.PUT("/spiders/:id", nil, tonic.Handler(controllers.PutSpiderById, 200))
model := models.Spider{
Name: "Test Spider",
@@ -110,7 +112,10 @@ func TestUpdateSpiderById(t *testing.T) {
ColName: "test_spider",
}
payload.SetId(id)
jsonValue, _ := json.Marshal(payload)
requestBody := controllers.PutByIdParams[models.Spider]{
Data: payload,
}
jsonValue, _ := json.Marshal(requestBody)
req, _ := http.NewRequest("PUT", "/spiders/"+spiderId, bytes.NewBuffer(jsonValue))
req.Header.Set("Authorization", TestToken)
resp := httptest.NewRecorder()
@@ -136,9 +141,9 @@ func TestDeleteSpiderById(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.Default()
router := SetupRouter()
router.Use(middlewares.AuthorizationMiddleware())
router.DELETE("/spiders/:id", controllers.DeleteSpiderById)
router.DELETE("/spiders/:id", nil, tonic.Handler(controllers.DeleteSpiderById, 200))
model := models.Spider{
Name: "Test Spider",
@@ -186,9 +191,9 @@ func TestDeleteSpiderList(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.Default()
router := SetupRouter()
router.Use(middlewares.AuthorizationMiddleware())
router.DELETE("/spiders", controllers.DeleteSpiderList)
router.DELETE("/spiders", nil, tonic.Handler(controllers.DeleteSpiderList, 200))
modelList := []models.Spider{
{

View File

@@ -1,11 +1,11 @@
package controllers
import (
"errors"
"fmt"
"github.com/crawlab-team/crawlab/core/mongo"
"regexp"
"github.com/crawlab-team/crawlab/core/mongo"
"github.com/juju/errors"
"github.com/crawlab-team/crawlab/core/models/models"
"github.com/crawlab-team/crawlab/core/models/service"
"github.com/crawlab-team/crawlab/core/utils"
@@ -15,34 +15,31 @@ import (
mongo2 "go.mongodb.org/mongo-driver/mongo"
)
func GetUserById(c *gin.Context) {
id, err := primitive.ObjectIDFromHex(c.Param("id"))
func GetUserById(c *gin.Context, params *GetByIdParams) (response *Response[models.User], err error) {
id, err := primitive.ObjectIDFromHex(params.Id)
if err != nil {
HandleErrorBadRequest(c, err)
return
}
getUserById(id, c)
return getUserById(id)
}
func GetUserList(c *gin.Context) {
// params
pagination := MustGetPagination(c)
query := MustGetFilterQuery(c)
sort := MustGetSortOption(c)
// get users
func GetUserList(_ *gin.Context, params *GetListParams) (response *ListResponse[models.User], err error) {
query, err := GetFilterQueryFromListParams(params)
if err != nil {
return GetErrorListResponse[models.User](err)
}
users, err := service.NewModelService[models.User]().GetMany(query, &mongo.FindOptions{
Sort: sort,
Skip: pagination.Size * (pagination.Page - 1),
Limit: pagination.Size,
Sort: params.Sort,
Skip: params.Size * (params.Page - 1),
Limit: params.Size,
})
if err != nil {
if errors.Is(err, mongo2.ErrNoDocuments) {
HandleSuccessWithListData(c, nil, 0)
return GetListResponse[models.User](nil, 0)
} else {
HandleErrorInternalServerError(c, err)
return GetErrorListResponse[models.User](err)
}
return
}
// get roles
@@ -58,8 +55,7 @@ func GetUserList(c *gin.Context) {
"_id": bson.M{"$in": roleIds},
}, nil)
if err != nil {
HandleErrorInternalServerError(c, err)
return
return GetErrorListResponse[models.User](err)
}
rolesMap := make(map[primitive.ObjectID]models.Role)
for _, role := range roles {
@@ -80,149 +76,124 @@ func GetUserList(c *gin.Context) {
// total count
total, err := service.NewModelService[models.User]().Count(query)
if err != nil {
HandleErrorInternalServerError(c, err)
return
return GetErrorListResponse[models.User](err)
}
// response
HandleSuccessWithListData(c, users, total)
return GetListResponse[models.User](users, total)
}
func PostUser(c *gin.Context) {
var payload struct {
Username string `json:"username"`
Password string `json:"password"`
Role string `json:"role"`
RoleId primitive.ObjectID `json:"role_id"`
Email string `json:"email"`
}
if err := c.ShouldBindJSON(&payload); err != nil {
HandleErrorBadRequest(c, err)
return
}
type PostUserParams struct {
Username string `json:"username" validate:"required"`
Password string `json:"password" validate:"required"`
Role string `json:"role"`
RoleId primitive.ObjectID `json:"role_id"`
Email string `json:"email"`
}
func PostUser(c *gin.Context, params *PostUserParams) (response *Response[models.User], err error) {
// Validate email format
if payload.Email != "" {
if params.Email != "" {
emailRegex := regexp.MustCompile(`^[a-zA-Z0-9._%+\-]+@[a-zA-Z0-9.\-]+\.[a-zA-Z]{2,}$`)
if !emailRegex.MatchString(payload.Email) {
HandleErrorBadRequest(c, fmt.Errorf("invalid email format"))
return
if !emailRegex.MatchString(params.Email) {
return GetErrorResponse[models.User](errors.BadRequestf("invalid email format"))
}
}
if !payload.RoleId.IsZero() {
_, err := service.NewModelService[models.Role]().GetById(payload.RoleId)
if !params.RoleId.IsZero() {
_, err := service.NewModelService[models.Role]().GetById(params.RoleId)
if err != nil {
HandleErrorBadRequest(c, err)
return
return GetErrorResponse[models.User](errors.BadRequestf("role not found: %v", err))
}
}
u := GetUserFromContext(c)
model := models.User{
Username: payload.Username,
Password: utils.EncryptMd5(payload.Password),
Role: payload.Role,
RoleId: payload.RoleId,
Email: payload.Email,
Username: params.Username,
Password: utils.EncryptMd5(params.Password),
Role: params.Role,
RoleId: params.RoleId,
Email: params.Email,
}
model.SetCreated(u.Id)
model.SetUpdated(u.Id)
id, err := service.NewModelService[models.User]().InsertOne(model)
if err != nil {
HandleErrorInternalServerError(c, err)
return
return GetErrorResponse[models.User](err)
}
result, err := service.NewModelService[models.User]().GetById(id)
if err != nil {
HandleErrorInternalServerError(c, err)
return
return GetErrorResponse[models.User](err)
}
HandleSuccessWithData(c, result)
return GetDataResponse[models.User](*result)
}
func PutUserById(c *gin.Context) {
id, err := primitive.ObjectIDFromHex(c.Param("id"))
func PutUserById(c *gin.Context, params *PutByIdParams[models.User]) (response *Response[models.User], err error) {
id, err := primitive.ObjectIDFromHex(params.Id)
if err != nil {
HandleErrorBadRequest(c, err)
return
return GetErrorResponse[models.User](errors.BadRequestf("invalid user id: %v", err))
}
putUser(id, c)
return putUser(id, GetUserFromContext(c).Id, params.Data)
}
func PostUserChangePassword(c *gin.Context) {
// get id
type PostUserChangePasswordParams struct {
Id string `path:"id"`
Password string `json:"password" validate:"required"`
}
func PostUserChangePassword(c *gin.Context, params *PostUserChangePasswordParams) (response *Response[models.User], err error) {
id, err := primitive.ObjectIDFromHex(c.Param("id"))
if err != nil {
HandleErrorBadRequest(c, err)
return
return GetErrorResponse[models.User](errors.BadRequestf("invalid user id: %v", err))
}
postUserChangePassword(id, c)
return postUserChangePassword(id, GetUserFromContext(c).Id, params.Password)
}
func DeleteUserById(c *gin.Context) {
id, err := primitive.ObjectIDFromHex(c.Param("id"))
func DeleteUserById(_ *gin.Context, params *DeleteByIdParams) (response *Response[models.User], err error) {
id, err := primitive.ObjectIDFromHex(params.Id)
if err != nil {
HandleErrorBadRequest(c, err)
return
return GetErrorResponse[models.User](errors.BadRequestf("invalid user id: %v", err))
}
user, err := service.NewModelService[models.User]().GetById(id)
if err != nil {
HandleErrorInternalServerError(c, err)
return
return GetErrorResponse[models.User](err)
}
if user.RootAdmin {
HandleErrorForbidden(c, errors.New("root admin cannot be deleted"))
return
return GetErrorResponse[models.User](errors.New("root admin cannot be deleted"))
}
if err := service.NewModelService[models.User]().DeleteById(id); err != nil {
HandleErrorInternalServerError(c, err)
return
return GetErrorResponse[models.User](err)
}
HandleSuccess(c)
return GetDataResponse[models.User](models.User{})
}
func DeleteUserList(c *gin.Context) {
type Payload struct {
Ids []string `json:"ids"`
}
var payload Payload
if err := c.ShouldBindJSON(&payload); err != nil {
HandleErrorBadRequest(c, err)
return
}
func DeleteUserList(_ *gin.Context, params *DeleteListParams) (response *Response[models.User], err error) {
// Convert string IDs to ObjectIDs
var ids []primitive.ObjectID
for _, id := range payload.Ids {
for _, id := range params.Ids {
objectId, err := primitive.ObjectIDFromHex(id)
if err != nil {
HandleErrorBadRequest(c, err)
return
return GetErrorResponse[models.User](errors.BadRequestf("invalid user id: %v", err))
}
ids = append(ids, objectId)
}
// Check if root admin is in the list
_, err := service.NewModelService[models.User]().GetOne(bson.M{
_, err = service.NewModelService[models.User]().GetOne(bson.M{
"_id": bson.M{
"$in": ids,
},
"root_admin": true,
}, nil)
if err == nil {
HandleErrorForbidden(c, errors.New("root admin cannot be deleted"))
return
return GetErrorResponse[models.User](errors.New("root admin cannot be deleted"))
}
if !errors.Is(err, mongo2.ErrNoDocuments) {
HandleErrorInternalServerError(c, err)
return
return GetErrorResponse[models.User](err)
}
// Delete users
@@ -231,34 +202,43 @@ func DeleteUserList(c *gin.Context) {
"$in": ids,
},
}); err != nil {
HandleErrorInternalServerError(c, err)
return
return GetErrorResponse[models.User](err)
}
HandleSuccess(c)
return GetDataResponse[models.User](models.User{})
}
func GetUserMe(c *gin.Context) {
func GetUserMe(c *gin.Context) (response *Response[models.User], err error) {
u := GetUserFromContext(c)
getUserByIdWithRoutes(u.Id, c)
return getUserByIdWithRoutes(u.Id)
}
func PutUserMe(c *gin.Context) {
type PutUserMeParams struct {
Data models.User `json:"data"`
}
func PutUserMe(c *gin.Context, params *PutUserMeParams) (response *Response[models.User], err error) {
u := GetUserFromContext(c)
putUser(u.Id, c)
return putUser(u.Id, u.Id, params.Data)
}
func PostUserMeChangePassword(c *gin.Context) {
type PostUserMeChangePasswordParams struct {
Password string `json:"password" validate:"required"`
}
func PostUserMeChangePassword(c *gin.Context, params *PostUserMeChangePasswordParams) (response *Response[models.User], err error) {
u := GetUserFromContext(c)
postUserChangePassword(u.Id, c)
return postUserChangePassword(u.Id, u.Id, params.Password)
}
func getUserById(userId primitive.ObjectID, c *gin.Context) {
func getUserById(userId primitive.ObjectID) (response *Response[models.User], err error) {
// get user
user, err := service.NewModelService[models.User]().GetById(userId)
if err != nil {
HandleErrorInternalServerError(c, err)
return
if errors.Is(err, mongo2.ErrNoDocuments) {
return GetErrorResponse[models.User](errors.BadRequestf("user not found: %v", err))
}
return GetErrorResponse[models.User](err)
}
// get role
@@ -266,61 +246,58 @@ func getUserById(userId primitive.ObjectID, c *gin.Context) {
if !user.RoleId.IsZero() {
role, err := service.NewModelService[models.Role]().GetById(user.RoleId)
if err != nil {
HandleErrorInternalServerError(c, err)
return
return GetErrorResponse[models.User](errors.BadRequestf("role not found: %v", err))
}
user.Role = role.Name
user.RootAdminRole = role.RootAdmin
}
}
HandleSuccessWithData(c, user)
return GetDataResponse[models.User](*user)
}
func getUserByIdWithRoutes(userId primitive.ObjectID, c *gin.Context) {
func getUserByIdWithRoutes(userId primitive.ObjectID) (response *Response[models.User], err error) {
if !utils.IsPro() {
getUserById(userId, c)
return
return getUserById(userId)
}
// get user
user, err := service.NewModelService[models.User]().GetById(userId)
if err != nil {
HandleErrorInternalServerError(c, err)
return
if errors.Is(err, mongo2.ErrNoDocuments) {
return GetErrorResponse[models.User](errors.BadRequestf("user not found: %v", err))
}
return GetErrorResponse[models.User](err)
}
// get role
if !user.RoleId.IsZero() {
role, err := service.NewModelService[models.Role]().GetById(user.RoleId)
if err != nil {
HandleErrorInternalServerError(c, err)
return
if errors.Is(err, mongo2.ErrNoDocuments) {
return GetErrorResponse[models.User](errors.BadRequestf("role not found: %v", err))
}
return GetErrorResponse[models.User](err)
}
user.Role = role.Name
user.RootAdminRole = role.RootAdmin
user.Routes = role.Routes
}
HandleSuccessWithData(c, user)
return GetDataResponse[models.User](*user)
}
func putUser(userId primitive.ObjectID, c *gin.Context) {
// get payload
var user models.User
if err := c.ShouldBindJSON(&user); err != nil {
HandleErrorBadRequest(c, err)
return
}
func putUser(userId, by primitive.ObjectID, user models.User) (response *Response[models.User], err error) {
// model service
modelSvc := service.NewModelService[models.User]()
// update user
userDb, err := modelSvc.GetById(userId)
if err != nil {
HandleErrorInternalServerError(c, err)
return
if errors.Is(err, mongo2.ErrNoDocuments) {
return GetErrorResponse[models.User](errors.BadRequestf("user not found: %v", err))
}
return GetErrorResponse[models.User](err)
}
// if root admin, disallow changing username and role
@@ -332,53 +309,34 @@ func putUser(userId primitive.ObjectID, c *gin.Context) {
// disallow changing password
user.Password = userDb.Password
// current user
u := GetUserFromContext(c)
// update user
user.SetUpdated(u.Id)
user.SetUpdated(by)
if user.Id.IsZero() {
user.Id = userId
}
if err := modelSvc.ReplaceById(userId, user); err != nil {
HandleErrorInternalServerError(c, err)
return
return GetErrorResponse[models.User](err)
}
// handle success
HandleSuccess(c)
return GetDataResponse[models.User](user)
}
func postUserChangePassword(userId primitive.ObjectID, c *gin.Context) {
// get payload
var payload struct {
Password string `json:"password"`
func postUserChangePassword(userId, by primitive.ObjectID, password string) (response *Response[models.User], err error) {
if len(password) < 5 {
return GetErrorResponse[models.User](errors.BadRequestf("password must be at least 5 characters"))
}
if err := c.ShouldBindJSON(&payload); err != nil {
HandleErrorBadRequest(c, err)
return
}
if len(payload.Password) < 5 {
HandleErrorBadRequest(c, errors.New("password must be at least 5 characters"))
return
}
// current user
u := GetUserFromContext(c)
// update password
userDb, err := service.NewModelService[models.User]().GetById(userId)
if err != nil {
HandleErrorInternalServerError(c, err)
return
return GetErrorResponse[models.User](err)
}
userDb.SetUpdated(u.Id)
userDb.Password = utils.EncryptMd5(payload.Password)
userDb.SetUpdated(by)
userDb.Password = utils.EncryptMd5(password)
if err := service.NewModelService[models.User]().ReplaceById(userDb.Id, *userDb); err != nil {
HandleErrorInternalServerError(c, err)
return
return GetErrorResponse[models.User](err)
}
// handle success
HandleSuccess(c)
return GetDataResponse[models.User](models.User{})
}

View File

@@ -1,6 +1,8 @@
package controllers_test
import (
"bytes"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
@@ -12,7 +14,7 @@ import (
"github.com/crawlab-team/crawlab/core/models/models"
"github.com/crawlab-team/crawlab/core/models/service"
"github.com/crawlab-team/crawlab/core/user"
"github.com/gin-gonic/gin"
"github.com/loopfz/gadgeto/tonic"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.mongodb.org/mongo-driver/bson"
@@ -36,9 +38,9 @@ func TestGetUserById_Success(t *testing.T) {
require.Nil(t, err)
u.SetId(id)
router := gin.Default()
router := SetupRouter()
router.Use(middlewares.AuthorizationMiddleware())
router.GET("/users/:id", controllers.GetUserById)
router.GET("/users/:id", nil, tonic.Handler(controllers.GetUserById, 200))
// Test valid ID
req, err := http.NewRequest(http.MethodGet, "/users/"+id.Hex(), nil)
@@ -79,9 +81,9 @@ func TestGetUserList_Success(t *testing.T) {
assert.Nil(t, err)
}
router := gin.Default()
router := SetupRouter()
router.Use(middlewares.AuthorizationMiddleware())
router.GET("/users", controllers.GetUserList)
router.GET("/users", nil, tonic.Handler(controllers.GetUserList, 200))
// Test default pagination
req, err := http.NewRequest(http.MethodGet, "/users", nil)
@@ -108,9 +110,9 @@ func TestPostUser_Success(t *testing.T) {
SetupTestDB()
defer CleanupTestDB()
router := gin.Default()
router := SetupRouter()
router.Use(middlewares.AuthorizationMiddleware())
router.POST("/users", controllers.PostUser)
router.POST("/users", nil, tonic.Handler(controllers.PostUser, 200))
// Test creating a new user with valid data
reqBody := strings.NewReader(`{
@@ -161,9 +163,9 @@ func TestPutUserById_Success(t *testing.T) {
require.Nil(t, err)
u.SetId(id)
router := gin.Default()
router := SetupRouter()
router.Use(middlewares.AuthorizationMiddleware())
router.PUT("/users/:id", controllers.PutUserById)
router.PUT("/users/:id", nil, tonic.Handler(controllers.PutUserById, 200))
// Test case 1: Regular user update
reqBody := strings.NewReader(`{
@@ -214,9 +216,9 @@ func TestPostUserChangePassword_Success(t *testing.T) {
require.Nil(t, err)
u.SetId(id)
router := gin.Default()
router := SetupRouter()
router.Use(middlewares.AuthorizationMiddleware())
router.POST("/users/:id/change-password", controllers.PostUserChangePassword)
router.POST("/users/:id/change-password", nil, tonic.Handler(controllers.PostUserChangePassword, 200))
// Add validation for minimum password length
// Test case 1: Valid password
@@ -252,9 +254,9 @@ func TestGetUserMe_Success(t *testing.T) {
require.Nil(t, err)
u.SetId(id)
router := gin.Default()
router := SetupRouter()
router.Use(middlewares.AuthorizationMiddleware())
router.GET("/users/me", controllers.GetUserMe)
router.GET("/users/me", nil, tonic.Handler(controllers.GetUserMe, 200))
req, _ := http.NewRequest(http.MethodGet, "/users/me", nil)
req.Header.Set("Content-Type", "application/json")
@@ -288,23 +290,26 @@ func TestPutUserMe_Success(t *testing.T) {
require.Nil(t, err)
// Create router
router := gin.Default()
router := SetupRouter()
router.Use(middlewares.AuthorizationMiddleware())
router.PUT("/users/me", controllers.PutUserMe)
router.PUT("/users/me", nil, tonic.Handler(controllers.PutUserMe, 200))
// Test valid update
reqBody := strings.NewReader(`{
"username": "updateduser",
"email": "updated@example.com"
}`)
req, err := http.NewRequest(http.MethodPut, "/users/me", reqBody)
reqParams := controllers.PutUserMeParams{
Data: models.User{
Username: "updateduser",
Email: "updated@example.com",
},
}
jsonValue, _ := json.Marshal(reqParams)
req, err := http.NewRequest(http.MethodPut, "/users/me", bytes.NewBuffer(jsonValue))
assert.Nil(t, err)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", token)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
assert.Equalf(t, http.StatusOK, w.Code, "response body: %s", w.Body.String())
// Verify the update
updatedUser, err := modelSvc.GetById(id)
@@ -338,9 +343,9 @@ func TestPostUserMeChangePassword_Success(t *testing.T) {
require.Nil(t, err)
// Create router
router := gin.Default()
router := SetupRouter()
router.Use(middlewares.AuthorizationMiddleware())
router.POST("/users/me/change-password", controllers.PostUserMeChangePassword)
router.POST("/users/me/change-password", nil, tonic.Handler(controllers.PostUserMeChangePassword, 200))
// Test valid password change
password := "newValidPassword123"
@@ -388,9 +393,9 @@ func TestDeleteUserById_Success(t *testing.T) {
require.Nil(t, err)
u.SetId(id)
router := gin.Default()
router := SetupRouter()
router.Use(middlewares.AuthorizationMiddleware())
router.DELETE("/users/:id", controllers.DeleteUserById)
router.DELETE("/users/:id", nil, tonic.Handler(controllers.DeleteUserById, 200))
// Test deleting normal user
req, err := http.NewRequest(http.MethodDelete, "/users/"+id.Hex(), nil)
@@ -423,7 +428,7 @@ func TestDeleteUserById_Success(t *testing.T) {
w = httptest.NewRecorder()
router.ServeHTTP(w, req)
assert.Equal(t, http.StatusForbidden, w.Code)
assert.Equalf(t, http.StatusBadRequest, w.Code, "response body: %s", w.Body.String())
// Test deleting with invalid ID
req, err = http.NewRequest(http.MethodDelete, "/users/invalid-id", nil)
@@ -460,9 +465,9 @@ func TestDeleteUserList_Success(t *testing.T) {
}
}
router := gin.Default()
router := SetupRouter()
router.Use(middlewares.AuthorizationMiddleware())
router.DELETE("/users", controllers.DeleteUserList)
router.DELETE("/users", nil, tonic.Handler(controllers.DeleteUserList, 200))
// Test deleting normal users
reqBody := strings.NewReader(fmt.Sprintf(`{"ids":["%s","%s"]}`,
@@ -492,7 +497,7 @@ func TestDeleteUserList_Success(t *testing.T) {
w = httptest.NewRecorder()
router.ServeHTTP(w, req)
assert.Equal(t, http.StatusForbidden, w.Code)
assert.Equalf(t, http.StatusBadRequest, w.Code, "response body: %s", w.Body.String())
// Test with mix of valid and invalid ids
reqBody = strings.NewReader(fmt.Sprintf(`{"ids":["%s","invalid-id"]}`, normalUserIds[0].Hex()))

331
core/controllers/utils.go Normal file
View File

@@ -0,0 +1,331 @@
package controllers
import (
"encoding/json"
"errors"
"github.com/crawlab-team/crawlab/core/constants"
"github.com/crawlab-team/crawlab/core/entity"
"github.com/crawlab-team/crawlab/core/models/models"
"github.com/crawlab-team/crawlab/core/mongo"
"github.com/crawlab-team/crawlab/core/utils"
"github.com/crawlab-team/crawlab/trace"
"github.com/gin-gonic/gin"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/bson/primitive"
"net/http"
"reflect"
)
var logger = utils.NewLogger("Controllers")
func GetUserFromContext(c *gin.Context) (u *models.User) {
value, ok := c.Get(constants.UserContextKey)
if !ok {
return nil
}
u, ok = value.(*models.User)
if !ok {
return nil
}
return u
}
func GetFilterQueryFromListParams(params *GetListParams) (q bson.M, err error) {
if params.Conditions == "" {
return nil, nil
}
conditions, err := GetFilterFromConditionString(params.Conditions)
if err != nil {
return nil, err
}
return utils.FilterToQuery(conditions)
}
func GetFilterQueryFromConditionString(condStr string) (q bson.M, err error) {
conditions, err := GetFilterFromConditionString(condStr)
if err != nil {
return nil, err
}
return utils.FilterToQuery(conditions)
}
func GetFilterFromConditionString(condStr string) (f *entity.Filter, err error) {
var conditions []*entity.Condition
if err := json.Unmarshal([]byte(condStr), &conditions); err != nil {
return nil, err
}
for i, cond := range conditions {
v := reflect.ValueOf(cond.Value)
switch v.Kind() {
case reflect.String:
item := cond.Value.(string)
// attempt to convert object id
id, err := primitive.ObjectIDFromHex(item)
if err == nil {
conditions[i].Value = id
} else {
conditions[i].Value = item
}
case reflect.Slice, reflect.Array:
var items []interface{}
for i := 0; i < v.Len(); i++ {
vItem := v.Index(i)
item := vItem.Interface()
// string
stringItem, ok := item.(string)
if ok {
id, err := primitive.ObjectIDFromHex(stringItem)
if err == nil {
items = append(items, id)
} else {
items = append(items, stringItem)
}
continue
}
// default
items = append(items, item)
}
conditions[i].Value = items
default:
return nil, errors.New("invalid type")
}
}
return &entity.Filter{
IsOr: false,
Conditions: conditions,
}, nil
}
// GetFilter Get entity.Filter from gin.Context
func GetFilter(c *gin.Context) (f *entity.Filter, err error) {
condStr := c.Query(constants.FilterQueryFieldConditions)
return GetFilterFromConditionString(condStr)
}
// GetFilterQuery Get bson.M from gin.Context
func GetFilterQuery(c *gin.Context) (q bson.M, err error) {
f, err := GetFilter(c)
if err != nil {
return nil, err
}
if f == nil {
return nil, nil
}
// TODO: implement logic OR
return utils.FilterToQuery(f)
}
func MustGetFilterQuery(c *gin.Context) (q bson.M) {
q, err := GetFilterQuery(c)
if err != nil {
return nil
}
return q
}
func getResultListQuery(c *gin.Context) (q mongo.ListQuery) {
f, err := GetFilter(c)
if err != nil {
return q
}
for _, cond := range f.Conditions {
q = append(q, mongo.ListQueryCondition{
Key: cond.Key,
Op: cond.Op,
Value: utils.NormalizeObjectId(cond.Value),
})
}
return q
}
func GetDefaultPagination() (p *entity.Pagination) {
return &entity.Pagination{
Page: constants.PaginationDefaultPage,
Size: constants.PaginationDefaultSize,
}
}
func GetPagination(c *gin.Context) (p *entity.Pagination, err error) {
var _p entity.Pagination
if err := c.ShouldBindQuery(&_p); err != nil {
return GetDefaultPagination(), err
}
if _p.Page == 0 {
_p.Page = constants.PaginationDefaultPage
}
if _p.Size == 0 {
_p.Size = constants.PaginationDefaultSize
}
return &_p, nil
}
func MustGetPagination(c *gin.Context) (p *entity.Pagination) {
p, err := GetPagination(c)
if err != nil || p == nil {
return GetDefaultPagination()
}
return p
}
// GetSorts Get entity.Sort from gin.Context
func GetSorts(c *gin.Context) (sorts []entity.Sort, err error) {
// bind
sortStr := c.Query(constants.SortQueryField)
if err := json.Unmarshal([]byte(sortStr), &sorts); err != nil {
return nil, err
}
return sorts, nil
}
// GetSortsOption Get entity.Sort from gin.Context
func GetSortsOption(c *gin.Context) (sort bson.D, err error) {
sorts, err := GetSorts(c)
if err != nil {
return nil, err
}
if sorts == nil || len(sorts) == 0 {
return bson.D{{"_id", -1}}, nil
}
return SortsToOption(sorts)
}
func MustGetSortOption(c *gin.Context) (sort bson.D) {
sort, err := GetSortsOption(c)
if err != nil {
return nil
}
return sort
}
// SortsToOption Translate entity.Sort to bson.D
func SortsToOption(sorts []entity.Sort) (sort bson.D, err error) {
sort = bson.D{}
for _, s := range sorts {
switch s.Direction {
case constants.ASCENDING:
sort = append(sort, bson.E{Key: s.Key, Value: 1})
case constants.DESCENDING:
sort = append(sort, bson.E{Key: s.Key, Value: -1})
}
}
if len(sort) == 0 {
sort = bson.D{{"_id", -1}}
}
return sort, nil
}
type Response[T any] struct {
Status string `json:"status"`
Message string `json:"message"`
Data T `json:"data"`
Error string `json:"error"`
}
type ListResponse[T any] struct {
Status string `json:"status"`
Message string `json:"message"`
Total int `json:"total"`
Data []T `json:"data"`
Error string `json:"error"`
}
func GetDataResponse[T any](model T) (res *Response[T], err error) {
return &Response[T]{
Status: constants.HttpResponseStatusOk,
Message: constants.HttpResponseMessageSuccess,
Data: model,
}, nil
}
func GetListResponse[T any](models []T, total int) (res *ListResponse[T], err error) {
return &ListResponse[T]{
Status: constants.HttpResponseStatusOk,
Message: constants.HttpResponseMessageSuccess,
Data: models,
Total: total,
}, nil
}
func GetErrorResponse[T any](err error) (res *Response[T], err2 error) {
return &Response[T]{
Status: constants.HttpResponseStatusOk,
Message: constants.HttpResponseMessageError,
Error: err.Error(),
}, err
}
func GetErrorListResponse[T any](err error) (res *ListResponse[T], err2 error) {
return &ListResponse[T]{
Status: constants.HttpResponseStatusOk,
Message: constants.HttpResponseMessageError,
Error: err.Error(),
}, err
}
func handleError(statusCode int, c *gin.Context, err error) {
if utils.IsDev() {
trace.PrintError(err)
}
c.AbortWithStatusJSON(statusCode, entity.Response{
Status: constants.HttpResponseStatusOk,
Message: constants.HttpResponseMessageError,
Error: err.Error(),
})
}
func HandleError(statusCode int, c *gin.Context, err error) {
handleError(statusCode, c, err)
}
func HandleErrorBadRequest(c *gin.Context, err error) {
HandleError(http.StatusBadRequest, c, err)
}
func HandleErrorForbidden(c *gin.Context, err error) {
HandleError(http.StatusForbidden, c, err)
}
func HandleErrorUnauthorized(c *gin.Context, err error) {
HandleError(http.StatusUnauthorized, c, err)
}
func HandleErrorNotFound(c *gin.Context, err error) {
HandleError(http.StatusNotFound, c, err)
}
func HandleErrorInternalServerError(c *gin.Context, err error) {
HandleError(http.StatusInternalServerError, c, err)
}
func HandleSuccess(c *gin.Context) {
c.AbortWithStatusJSON(http.StatusOK, entity.Response{
Status: constants.HttpResponseStatusOk,
Message: constants.HttpResponseMessageSuccess,
})
}
func HandleSuccessWithData(c *gin.Context, data interface{}) {
c.AbortWithStatusJSON(http.StatusOK, entity.Response{
Status: constants.HttpResponseStatusOk,
Message: constants.HttpResponseMessageSuccess,
Data: data,
})
}
func HandleSuccessWithListData(c *gin.Context, data interface{}, total int) {
c.AbortWithStatusJSON(http.StatusOK, entity.ListResponse{
Status: constants.HttpResponseStatusOk,
Message: constants.HttpResponseMessageSuccess,
Data: data,
Total: total,
})
}

View File

@@ -1,19 +0,0 @@
package controllers
import (
"github.com/crawlab-team/crawlab/core/constants"
"github.com/crawlab-team/crawlab/core/models/models"
"github.com/gin-gonic/gin"
)
func GetUserFromContext(c *gin.Context) (u *models.User) {
value, ok := c.Get(constants.UserContextKey)
if !ok {
return nil
}
u, ok = value.(*models.User)
if !ok {
return nil
}
return u
}

View File

@@ -1,141 +0,0 @@
package controllers
import (
"encoding/json"
errors2 "errors"
"github.com/crawlab-team/crawlab/core/constants"
"github.com/crawlab-team/crawlab/core/entity"
"github.com/crawlab-team/crawlab/core/mongo"
"github.com/crawlab-team/crawlab/core/utils"
"github.com/gin-gonic/gin"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/bson/primitive"
"reflect"
"strings"
)
// GetFilter Get entity.Filter from gin.Context
func GetFilter(c *gin.Context) (f *entity.Filter, err error) {
// bind
condStr := c.Query(constants.FilterQueryFieldConditions)
var conditions []*entity.Condition
if err := json.Unmarshal([]byte(condStr), &conditions); err != nil {
return nil, err
}
// attempt to convert object id
for i, cond := range conditions {
v := reflect.ValueOf(cond.Value)
switch v.Kind() {
case reflect.String:
item := cond.Value.(string)
id, err := primitive.ObjectIDFromHex(item)
if err == nil {
conditions[i].Value = id
} else {
conditions[i].Value = item
}
case reflect.Slice, reflect.Array:
var items []interface{}
for i := 0; i < v.Len(); i++ {
vItem := v.Index(i)
item := vItem.Interface()
// string
stringItem, ok := item.(string)
if ok {
id, err := primitive.ObjectIDFromHex(stringItem)
if err == nil {
items = append(items, id)
} else {
items = append(items, stringItem)
}
continue
}
// default
items = append(items, item)
}
conditions[i].Value = items
default:
return nil, errors2.New("invalid type")
}
}
return &entity.Filter{
IsOr: false,
Conditions: conditions,
}, nil
}
// GetFilterQuery Get bson.M from gin.Context
func GetFilterQuery(c *gin.Context) (q bson.M, err error) {
f, err := GetFilter(c)
if err != nil {
return nil, err
}
if f == nil {
return nil, nil
}
// TODO: implement logic OR
return utils.FilterToQuery(f)
}
func MustGetFilterQuery(c *gin.Context) (q bson.M) {
q, err := GetFilterQuery(c)
if err != nil {
return nil
}
return q
}
// GetFilterAll Get all from gin.Context
func GetFilterAll(c *gin.Context) (res bool, err error) {
resStr := c.Query(constants.FilterQueryFieldAll)
switch strings.ToUpper(resStr) {
case "1":
return true, nil
case "0":
return false, nil
case "Y":
return true, nil
case "N":
return false, nil
case "T":
return true, nil
case "F":
return false, nil
case "TRUE":
return true, nil
case "FALSE":
return false, nil
default:
return false, errors2.New("invalid value")
}
}
func MustGetFilterAll(c *gin.Context) (res bool) {
res, err := GetFilterAll(c)
if err != nil {
return false
}
return res
}
func getResultListQuery(c *gin.Context) (q mongo.ListQuery) {
f, err := GetFilter(c)
if err != nil {
return q
}
for _, cond := range f.Conditions {
q = append(q, mongo.ListQueryCondition{
Key: cond.Key,
Op: cond.Op,
Value: utils.NormalizeObjectId(cond.Value),
})
}
return q
}

View File

@@ -1,110 +0,0 @@
package controllers
import (
"net/http"
"github.com/crawlab-team/crawlab/core/constants"
"github.com/crawlab-team/crawlab/core/entity"
"github.com/crawlab-team/crawlab/core/utils"
"github.com/crawlab-team/crawlab/trace"
"github.com/gin-gonic/gin"
)
type Response[T any] struct {
Status string `json:"status"`
Message string `json:"message"`
Data T `json:"data"`
Error string `json:"error"`
}
type ListResponse[T any] struct {
Status string `json:"status"`
Message string `json:"message"`
Total int `json:"total"`
Data []T `json:"data"`
Error string `json:"error"`
}
func GetSuccessDataResponse[T any](model T) (res *Response[T], err error) {
return &Response[T]{
Status: constants.HttpResponseStatusOk,
Message: constants.HttpResponseMessageSuccess,
Data: model,
}, nil
}
func GetSuccessListResponse[T any](models []T, total int) (res *ListResponse[T], err error) {
return &ListResponse[T]{
Status: constants.HttpResponseStatusOk,
Message: constants.HttpResponseMessageSuccess,
Data: models,
Total: total,
}, nil
}
func GetErrorDataResponse[T any](err error) (res *Response[T], err2 error) {
return &Response[T]{
Status: constants.HttpResponseStatusOk,
Message: constants.HttpResponseMessageError,
Error: err.Error(),
}, err
}
func handleError(statusCode int, c *gin.Context, err error) {
if utils.IsDev() {
trace.PrintError(err)
}
c.AbortWithStatusJSON(statusCode, entity.Response{
Status: constants.HttpResponseStatusOk,
Message: constants.HttpResponseMessageError,
Error: err.Error(),
})
}
func HandleError(statusCode int, c *gin.Context, err error) {
handleError(statusCode, c, err)
}
func HandleErrorBadRequest(c *gin.Context, err error) {
HandleError(http.StatusBadRequest, c, err)
}
func HandleErrorForbidden(c *gin.Context, err error) {
HandleError(http.StatusForbidden, c, err)
}
func HandleErrorUnauthorized(c *gin.Context, err error) {
HandleError(http.StatusUnauthorized, c, err)
}
func HandleErrorNotFound(c *gin.Context, err error) {
HandleError(http.StatusNotFound, c, err)
}
func HandleErrorInternalServerError(c *gin.Context, err error) {
HandleError(http.StatusInternalServerError, c, err)
}
func HandleSuccess(c *gin.Context) {
c.AbortWithStatusJSON(http.StatusOK, entity.Response{
Status: constants.HttpResponseStatusOk,
Message: constants.HttpResponseMessageSuccess,
})
}
func HandleSuccessWithData(c *gin.Context, data interface{}) {
c.AbortWithStatusJSON(http.StatusOK, entity.Response{
Status: constants.HttpResponseStatusOk,
Message: constants.HttpResponseMessageSuccess,
Data: data,
})
}
func HandleSuccessWithListData(c *gin.Context, data interface{}, total int) {
c.AbortWithStatusJSON(http.StatusOK, entity.ListResponse{
Status: constants.HttpResponseStatusOk,
Message: constants.HttpResponseMessageSuccess,
Data: data,
Total: total,
})
}

View File

@@ -1,5 +0,0 @@
package controllers
import "github.com/crawlab-team/crawlab/core/utils"
var logger = utils.NewLogger("Controllers")

View File

@@ -1,36 +0,0 @@
package controllers
import (
"github.com/crawlab-team/crawlab/core/constants"
"github.com/crawlab-team/crawlab/core/entity"
"github.com/gin-gonic/gin"
)
func GetDefaultPagination() (p *entity.Pagination) {
return &entity.Pagination{
Page: constants.PaginationDefaultPage,
Size: constants.PaginationDefaultSize,
}
}
func GetPagination(c *gin.Context) (p *entity.Pagination, err error) {
var _p entity.Pagination
if err := c.ShouldBindQuery(&_p); err != nil {
return GetDefaultPagination(), err
}
if _p.Page == 0 {
_p.Page = constants.PaginationDefaultPage
}
if _p.Size == 0 {
_p.Size = constants.PaginationDefaultSize
}
return &_p, nil
}
func MustGetPagination(c *gin.Context) (p *entity.Pagination) {
p, err := GetPagination(c)
if err != nil || p == nil {
return GetDefaultPagination()
}
return p
}

View File

@@ -1,58 +0,0 @@
package controllers
import (
"encoding/json"
"github.com/crawlab-team/crawlab/core/constants"
"github.com/crawlab-team/crawlab/core/entity"
"github.com/gin-gonic/gin"
"go.mongodb.org/mongo-driver/bson"
)
// GetSorts Get entity.Sort from gin.Context
func GetSorts(c *gin.Context) (sorts []entity.Sort, err error) {
// bind
sortStr := c.Query(constants.SortQueryField)
if err := json.Unmarshal([]byte(sortStr), &sorts); err != nil {
return nil, err
}
return sorts, nil
}
// GetSortsOption Get entity.Sort from gin.Context
func GetSortsOption(c *gin.Context) (sort bson.D, err error) {
sorts, err := GetSorts(c)
if err != nil {
return nil, err
}
if sorts == nil || len(sorts) == 0 {
return bson.D{{"_id", -1}}, nil
}
return SortsToOption(sorts)
}
func MustGetSortOption(c *gin.Context) (sort bson.D) {
sort, err := GetSortsOption(c)
if err != nil {
return nil
}
return sort
}
// SortsToOption Translate entity.Sort to bson.D
func SortsToOption(sorts []entity.Sort) (sort bson.D, err error) {
sort = bson.D{}
for _, s := range sorts {
switch s.Direction {
case constants.ASCENDING:
sort = append(sort, bson.E{Key: s.Key, Value: 1})
case constants.DESCENDING:
sort = append(sort, bson.E{Key: s.Key, Value: -1})
}
}
if len(sort) == 0 {
sort = bson.D{{"_id", -1}}
}
return sort, nil
}

View File

@@ -1,54 +0,0 @@
package controllers
import (
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"io"
http2 "net/http"
)
type WsWriter struct {
io.Writer
io.Closer
conn *websocket.Conn
}
func (w *WsWriter) Write(data []byte) (n int, err error) {
logger.Infof("websocket write: %s", string(data))
err = w.conn.WriteMessage(websocket.TextMessage, data)
if err != nil {
return 0, err
}
return len(data), nil
}
func (w *WsWriter) Close() (err error) {
return w.conn.Close()
}
func (w *WsWriter) CloseWithText(text string) {
_ = w.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, text))
}
func (w *WsWriter) CloseWithError(err error) {
_ = w.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseInternalServerErr, err.Error()))
}
func NewWsWriter(c *gin.Context) (writer *WsWriter, err error) {
upgrader := websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
CheckOrigin: func(r *http2.Request) bool {
return true
},
}
conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
if err != nil {
logger.Errorf("websocket open connection error: %v", err)
}
return &WsWriter{
conn: conn,
}, nil
}