mirror of
https://github.com/crawlab-team/crawlab.git
synced 2026-02-01 18:20:17 +01:00
chore: update Go version and dependencies
- Updated Go version in go.work and backend/go.mod to 1.23.7 - Updated various dependencies in go.sum and backend/go.sum - Refactored models to remove generic type parameters from BaseModel - Introduced new utility functions for consistent API responses - Removed unused utility files from controllers
This commit is contained in:
@@ -1,7 +1,8 @@
|
||||
package controllers
|
||||
|
||||
import (
|
||||
"github.com/crawlab-team/crawlab/core/entity"
|
||||
"time"
|
||||
|
||||
"github.com/crawlab-team/crawlab/core/interfaces"
|
||||
"github.com/crawlab-team/crawlab/core/models/service"
|
||||
"github.com/crawlab-team/crawlab/core/mongo"
|
||||
@@ -23,28 +24,32 @@ type BaseController[T any] struct {
|
||||
actions []Action
|
||||
}
|
||||
|
||||
type GetByIdParams struct {
|
||||
Id string `path:"id" description:"The ID of the item to get"`
|
||||
}
|
||||
|
||||
// GetAllParams represents parameters for GetAll method
|
||||
type GetAllParams struct {
|
||||
Query bson.M `json:"query"`
|
||||
Sort bson.D `json:"sort"`
|
||||
}
|
||||
|
||||
// GetListParams represents parameters for GetList with pagination
|
||||
type GetListParams struct {
|
||||
Query bson.M `json:"query"`
|
||||
Sort bson.D `json:"sort"`
|
||||
Pagination *entity.Pagination `json:"pagination"`
|
||||
All bool `query:"all" description:"Whether to get all items"`
|
||||
Conditions string `query:"conditions" description:"Filter conditions. Format: [{\"key\":\"name\",\"op\":\"eq\",\"value\":\"test\"}]"`
|
||||
Sort bson.D `query:"sort" description:"Sort options"`
|
||||
Page int `query:"page" default:"1" description:"Page number"`
|
||||
Size int `query:"size" default:"10" description:"Page size"`
|
||||
All bool `query:"all" default:"false" description:"Whether to get all items"`
|
||||
}
|
||||
|
||||
func (ctr *BaseController[T]) GetList(_ *gin.Context, params *GetListParams) (response *ListResponse[T], err error) {
|
||||
// get all if query field "all" is set true
|
||||
if params.All {
|
||||
return ctr.GetAll(params)
|
||||
}
|
||||
|
||||
return ctr.GetWithPagination(params)
|
||||
}
|
||||
|
||||
type GetByIdParams struct {
|
||||
Id string `path:"id" description:"The ID of the item to get"`
|
||||
}
|
||||
|
||||
func (ctr *BaseController[T]) GetById(_ *gin.Context, params *GetByIdParams) (response *Response[T], err error) {
|
||||
id, err := primitive.ObjectIDFromHex(params.Id)
|
||||
if err != nil {
|
||||
return nil, errors.BadRequestf("invalid id: %s", params.Id)
|
||||
return GetErrorResponse[T](errors.BadRequestf("invalid id format"))
|
||||
}
|
||||
|
||||
model, err := ctr.modelSvc.GetById(id)
|
||||
@@ -52,143 +57,132 @@ func (ctr *BaseController[T]) GetById(_ *gin.Context, params *GetByIdParams) (re
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return GetSuccessDataResponse(*model)
|
||||
return GetDataResponse(*model)
|
||||
}
|
||||
|
||||
func (ctr *BaseController[T]) GetList(c *gin.Context, params *GetListParams) (response *ListResponse[T], err error) {
|
||||
// get all if query field "all" is set true
|
||||
all := params.All || MustGetFilterAll(c)
|
||||
|
||||
// Prepare parameters
|
||||
query := MustGetFilterQuery(c)
|
||||
sort := MustGetSortOption(c)
|
||||
|
||||
if all {
|
||||
allParams := &GetAllParams{
|
||||
Query: query,
|
||||
Sort: sort,
|
||||
}
|
||||
return ctr.GetAll(c, allParams)
|
||||
}
|
||||
|
||||
// get list with pagination
|
||||
pagination := MustGetPagination(c)
|
||||
listParams := &GetListParams{
|
||||
Query: query,
|
||||
Sort: sort,
|
||||
Pagination: pagination,
|
||||
}
|
||||
return ctr.GetWithPagination(c, listParams)
|
||||
type PostParams[T any] struct {
|
||||
Data T `json:"data"`
|
||||
}
|
||||
|
||||
func (ctr *BaseController[T]) Post(c *gin.Context) (response *Response[T], err error) {
|
||||
var model T
|
||||
if err := c.ShouldBindJSON(&model); err != nil {
|
||||
return GetErrorDataResponse[T](err)
|
||||
}
|
||||
func (ctr *BaseController[T]) Post(c *gin.Context, params *PostParams[T]) (response *Response[T], err error) {
|
||||
u := GetUserFromContext(c)
|
||||
m := any(&model).(interfaces.Model)
|
||||
m := any(¶ms.Data).(interfaces.Model)
|
||||
m.SetId(primitive.NewObjectID())
|
||||
m.SetCreated(u.Id)
|
||||
m.SetUpdated(u.Id)
|
||||
col := ctr.modelSvc.GetCol()
|
||||
res, err := col.GetCollection().InsertOne(col.GetContext(), m)
|
||||
if err != nil {
|
||||
return GetErrorDataResponse[T](err)
|
||||
return GetErrorResponse[T](err)
|
||||
}
|
||||
|
||||
result, err := ctr.modelSvc.GetById(res.InsertedID.(primitive.ObjectID))
|
||||
if err != nil {
|
||||
return GetErrorDataResponse[T](err)
|
||||
return GetErrorResponse[T](err)
|
||||
}
|
||||
|
||||
return GetSuccessDataResponse(*result)
|
||||
return GetDataResponse(*result)
|
||||
}
|
||||
|
||||
func (ctr *BaseController[T]) PutById(c *gin.Context) (response *Response[T], err error) {
|
||||
id, err := primitive.ObjectIDFromHex(c.Param("id"))
|
||||
if err != nil {
|
||||
return GetErrorDataResponse[T](err)
|
||||
}
|
||||
type PutByIdParams[T any] struct {
|
||||
Id string `path:"id" description:"The ID of the item to update"`
|
||||
Data T `json:"data"`
|
||||
}
|
||||
|
||||
var model T
|
||||
if err := c.ShouldBindJSON(&model); err != nil {
|
||||
return GetErrorDataResponse[T](err)
|
||||
func (ctr *BaseController[T]) PutById(c *gin.Context, params *PutByIdParams[T]) (response *Response[T], err error) {
|
||||
id, err := primitive.ObjectIDFromHex(params.Id)
|
||||
if err != nil {
|
||||
return GetErrorResponse[T](errors.BadRequestf("invalid id format: %v", err))
|
||||
}
|
||||
|
||||
u := GetUserFromContext(c)
|
||||
m := any(&model).(interfaces.Model)
|
||||
m := any(¶ms.Data).(interfaces.Model)
|
||||
m.SetUpdated(u.Id)
|
||||
if m.GetId().IsZero() {
|
||||
m.SetId(id)
|
||||
}
|
||||
|
||||
if err := ctr.modelSvc.ReplaceById(id, model); err != nil {
|
||||
return GetErrorDataResponse[T](err)
|
||||
if err := ctr.modelSvc.ReplaceById(id, params.Data); err != nil {
|
||||
return GetErrorResponse[T](err)
|
||||
}
|
||||
|
||||
result, err := ctr.modelSvc.GetById(id)
|
||||
if err != nil {
|
||||
return GetErrorDataResponse[T](err)
|
||||
return GetErrorResponse[T](err)
|
||||
}
|
||||
|
||||
return GetSuccessDataResponse(*result)
|
||||
return GetDataResponse(*result)
|
||||
}
|
||||
|
||||
func (ctr *BaseController[T]) PatchList(c *gin.Context) (res *Response[T], err error) {
|
||||
type Payload struct {
|
||||
Ids []primitive.ObjectID `json:"ids"`
|
||||
Update bson.M `json:"update"`
|
||||
type PatchParams struct {
|
||||
Ids []string `json:"ids" description:"The IDs of the items to update" validate:"required"`
|
||||
Update bson.M `json:"update" description:"The update object" validate:"required"`
|
||||
}
|
||||
|
||||
func (ctr *BaseController[T]) PatchList(c *gin.Context, params *PatchParams) (res *Response[T], err error) {
|
||||
var ids []primitive.ObjectID
|
||||
for _, id := range params.Ids {
|
||||
objectId, err := primitive.ObjectIDFromHex(id)
|
||||
if err != nil {
|
||||
return GetErrorResponse[T](errors.BadRequestf("invalid id format: %v", err))
|
||||
}
|
||||
ids = append(ids, objectId)
|
||||
}
|
||||
|
||||
var payload Payload
|
||||
if err := c.ShouldBindJSON(&payload); err != nil {
|
||||
return GetErrorDataResponse[T](err)
|
||||
}
|
||||
// Get user from context for updated_by
|
||||
u := GetUserFromContext(c)
|
||||
|
||||
// query
|
||||
query := bson.M{
|
||||
"_id": bson.M{
|
||||
"$in": payload.Ids,
|
||||
"$in": ids,
|
||||
},
|
||||
}
|
||||
|
||||
// Add updated_by and updated_ts to the update object
|
||||
updateObj := params.Update
|
||||
updateObj["updated_by"] = u.Id
|
||||
updateObj["updated_ts"] = time.Now()
|
||||
|
||||
// update
|
||||
if err := ctr.modelSvc.UpdateMany(query, bson.M{"$set": payload.Update}); err != nil {
|
||||
return GetErrorDataResponse[T](err)
|
||||
if err := ctr.modelSvc.UpdateMany(query, bson.M{"$set": updateObj}); err != nil {
|
||||
return GetErrorResponse[T](err)
|
||||
}
|
||||
|
||||
// Return an empty response with success status
|
||||
var emptyModel T
|
||||
return GetSuccessDataResponse(emptyModel)
|
||||
return GetDataResponse(emptyModel)
|
||||
}
|
||||
|
||||
func (ctr *BaseController[T]) DeleteById(c *gin.Context) (res *Response[T], err error) {
|
||||
id, err := primitive.ObjectIDFromHex(c.Param("id"))
|
||||
type DeleteByIdParams struct {
|
||||
Id string `path:"id" description:"The ID of the item to get"`
|
||||
}
|
||||
|
||||
func (ctr *BaseController[T]) DeleteById(c *gin.Context, params *DeleteByIdParams) (res *Response[T], err error) {
|
||||
params.Id = c.Param("id")
|
||||
id, err := primitive.ObjectIDFromHex(params.Id)
|
||||
if err != nil {
|
||||
return GetErrorDataResponse[T](err)
|
||||
return GetErrorResponse[T](errors.BadRequestf("invalid id format: %v", err))
|
||||
}
|
||||
|
||||
if err := ctr.modelSvc.DeleteById(id); err != nil {
|
||||
return GetErrorDataResponse[T](err)
|
||||
return GetErrorResponse[T](err)
|
||||
}
|
||||
|
||||
var emptyModel T
|
||||
return GetSuccessDataResponse(emptyModel)
|
||||
return GetDataResponse(emptyModel)
|
||||
}
|
||||
|
||||
func (ctr *BaseController[T]) DeleteList(c *gin.Context) (res *Response[T], err error) {
|
||||
type Payload struct {
|
||||
Ids []string `json:"ids"`
|
||||
}
|
||||
|
||||
var payload Payload
|
||||
if err := c.ShouldBindJSON(&payload); err != nil {
|
||||
return GetErrorDataResponse[T](err)
|
||||
}
|
||||
type DeleteListParams struct {
|
||||
Ids []string `json:"ids" description:"The IDs of the items to delete"`
|
||||
}
|
||||
|
||||
func (ctr *BaseController[T]) DeleteList(_ *gin.Context, params *DeleteListParams) (res *Response[T], err error) {
|
||||
var ids []primitive.ObjectID
|
||||
for _, id := range payload.Ids {
|
||||
for _, id := range params.Ids {
|
||||
objectId, err := primitive.ObjectIDFromHex(id)
|
||||
if err != nil {
|
||||
return GetErrorDataResponse[T](err)
|
||||
return GetErrorResponse[T](err)
|
||||
}
|
||||
ids = append(ids, objectId)
|
||||
}
|
||||
@@ -198,16 +192,19 @@ func (ctr *BaseController[T]) DeleteList(c *gin.Context) (res *Response[T], err
|
||||
"$in": ids,
|
||||
},
|
||||
}); err != nil {
|
||||
return GetErrorDataResponse[T](err)
|
||||
return GetErrorResponse[T](err)
|
||||
}
|
||||
|
||||
var emptyModel T
|
||||
return GetSuccessDataResponse(emptyModel)
|
||||
return GetDataResponse(emptyModel)
|
||||
}
|
||||
|
||||
// GetAll retrieves all items based on filter and sort
|
||||
func (ctr *BaseController[T]) GetAll(_ *gin.Context, params *GetAllParams) (response *ListResponse[T], err error) {
|
||||
query := params.Query
|
||||
func (ctr *BaseController[T]) GetAll(params *GetListParams) (response *ListResponse[T], err error) {
|
||||
query, err := GetFilterQueryFromListParams(params)
|
||||
if err != nil {
|
||||
return GetErrorListResponse[T](errors.BadRequestf("invalid request parameters: %v", err))
|
||||
}
|
||||
sort := params.Sort
|
||||
if sort == nil {
|
||||
sort = bson.D{{"_id", -1}}
|
||||
@@ -225,29 +222,23 @@ func (ctr *BaseController[T]) GetAll(_ *gin.Context, params *GetAllParams) (resp
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return GetSuccessListResponse(models, total)
|
||||
return GetListResponse(models, total)
|
||||
}
|
||||
|
||||
// GetWithPagination retrieves items with pagination
|
||||
func (ctr *BaseController[T]) GetWithPagination(_ *gin.Context, params *GetListParams) (response *ListResponse[T], err error) {
|
||||
// params
|
||||
pagination := params.Pagination
|
||||
query := params.Query
|
||||
sort := params.Sort
|
||||
|
||||
if pagination == nil {
|
||||
pagination = GetDefaultPagination()
|
||||
func (ctr *BaseController[T]) GetWithPagination(params *GetListParams) (response *ListResponse[T], err error) {
|
||||
query, err := GetFilterQueryFromListParams(params)
|
||||
if err != nil {
|
||||
return GetErrorListResponse[T](errors.BadRequestf("invalid request parameters: %v", err))
|
||||
}
|
||||
|
||||
// get list
|
||||
models, err := ctr.modelSvc.GetMany(query, &mongo.FindOptions{
|
||||
Sort: sort,
|
||||
Skip: pagination.Size * (pagination.Page - 1),
|
||||
Limit: pagination.Size,
|
||||
Sort: params.Sort,
|
||||
Skip: params.Size * (params.Page - 1),
|
||||
Limit: params.Size,
|
||||
})
|
||||
if err != nil {
|
||||
if errors.Is(err, mongo2.ErrNoDocuments) {
|
||||
return GetSuccessListResponse[T](nil, 0)
|
||||
return GetListResponse[T](nil, 0)
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
@@ -260,36 +251,7 @@ func (ctr *BaseController[T]) GetWithPagination(_ *gin.Context, params *GetListP
|
||||
}
|
||||
|
||||
// response
|
||||
return GetSuccessListResponse(models, total)
|
||||
}
|
||||
|
||||
// getAll is kept for backward compatibility
|
||||
func (ctr *BaseController[T]) getAll(c *gin.Context) (response *ListResponse[T], err error) {
|
||||
query := MustGetFilterQuery(c)
|
||||
sort := MustGetSortOption(c)
|
||||
|
||||
params := &GetAllParams{
|
||||
Query: query,
|
||||
Sort: sort,
|
||||
}
|
||||
|
||||
return ctr.GetAll(c, params)
|
||||
}
|
||||
|
||||
// getList is kept for backward compatibility
|
||||
func (ctr *BaseController[T]) getList(c *gin.Context) (response *ListResponse[T], err error) {
|
||||
// params
|
||||
pagination := MustGetPagination(c)
|
||||
query := MustGetFilterQuery(c)
|
||||
sort := MustGetSortOption(c)
|
||||
|
||||
params := &GetListParams{
|
||||
Query: query,
|
||||
Sort: sort,
|
||||
Pagination: pagination,
|
||||
}
|
||||
|
||||
return ctr.GetWithPagination(c, params)
|
||||
return GetListResponse(models, total)
|
||||
}
|
||||
|
||||
func NewController[T any](actions ...Action) *BaseController[T] {
|
||||
|
||||
@@ -4,48 +4,54 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/crawlab-team/crawlab/core/fs"
|
||||
"github.com/crawlab-team/crawlab/core/interfaces"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"os"
|
||||
"sync"
|
||||
)
|
||||
|
||||
func GetBaseFileListDir(rootPath string, c *gin.Context) {
|
||||
path := c.Query("path")
|
||||
type GetBaseFileListDirParams struct {
|
||||
Path string `path:"path"`
|
||||
}
|
||||
|
||||
func GetBaseFileListDir(rootPath string, params *GetBaseFileListDirParams) (response *Response[[]interfaces.FsFileInfo], err error) {
|
||||
path := params.Path
|
||||
|
||||
fsSvc, err := fs.GetBaseFileFsSvc(rootPath)
|
||||
if err != nil {
|
||||
HandleErrorBadRequest(c, err)
|
||||
return
|
||||
return GetErrorResponse[[]interfaces.FsFileInfo](err)
|
||||
}
|
||||
|
||||
files, err := fsSvc.List(path)
|
||||
if err != nil {
|
||||
if !errors.Is(err, os.ErrNotExist) {
|
||||
HandleErrorInternalServerError(c, err)
|
||||
return
|
||||
return GetErrorResponse[[]interfaces.FsFileInfo](err)
|
||||
}
|
||||
}
|
||||
|
||||
HandleSuccessWithData(c, files)
|
||||
//HandleSuccessWithData(c, files)
|
||||
return GetDataResponse[[]interfaces.FsFileInfo](files)
|
||||
}
|
||||
|
||||
func GetBaseFileFile(rootPath string, c *gin.Context) {
|
||||
path := c.Query("path")
|
||||
type GetBaseFileFileParams struct {
|
||||
Path string `path:"path"`
|
||||
}
|
||||
|
||||
func GetBaseFileFile(rootPath string, params *GetBaseFileFileParams) (response *Response[string], err error) {
|
||||
path := params.Path
|
||||
|
||||
fsSvc, err := fs.GetBaseFileFsSvc(rootPath)
|
||||
if err != nil {
|
||||
HandleErrorBadRequest(c, err)
|
||||
return
|
||||
return GetErrorResponse[string](err)
|
||||
}
|
||||
|
||||
data, err := fsSvc.GetFile(path)
|
||||
if err != nil {
|
||||
HandleErrorInternalServerError(c, err)
|
||||
return
|
||||
return GetErrorResponse[string](err)
|
||||
}
|
||||
|
||||
HandleSuccessWithData(c, string(data))
|
||||
return GetDataResponse[string](string(data))
|
||||
}
|
||||
|
||||
func GetBaseFileFileInfo(rootPath string, c *gin.Context) {
|
||||
|
||||
@@ -4,19 +4,30 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/crawlab-team/crawlab/core/entity"
|
||||
|
||||
"github.com/crawlab-team/crawlab/core/controllers"
|
||||
"github.com/crawlab-team/crawlab/core/middlewares"
|
||||
"github.com/crawlab-team/crawlab/core/models/models"
|
||||
"github.com/crawlab-team/crawlab/core/models/service"
|
||||
"github.com/crawlab-team/crawlab/core/mongo"
|
||||
"github.com/crawlab-team/crawlab/core/user"
|
||||
"github.com/loopfz/gadgeto/tonic"
|
||||
"github.com/spf13/viper"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/wI2L/fizz"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
)
|
||||
|
||||
func init() {
|
||||
@@ -26,7 +37,12 @@ func init() {
|
||||
// TestModel is a simple struct to be used as a model in tests
|
||||
type TestModel models.TestModel
|
||||
|
||||
//type TestModel struct {
|
||||
// Name string `json:"name" bson:"name"`
|
||||
//}
|
||||
|
||||
var TestToken string
|
||||
var TestUserId primitive.ObjectID
|
||||
|
||||
// SetupTestDB sets up the test database
|
||||
func SetupTestDB() {
|
||||
@@ -50,11 +66,12 @@ func SetupTestDB() {
|
||||
panic(err)
|
||||
}
|
||||
TestToken = token
|
||||
TestUserId = u.Id
|
||||
}
|
||||
|
||||
// SetupRouter sets up the gin router for testing
|
||||
func SetupRouter() *gin.Engine {
|
||||
router := gin.Default()
|
||||
func SetupRouter() *fizz.Fizz {
|
||||
router := fizz.New()
|
||||
return router
|
||||
}
|
||||
|
||||
@@ -77,7 +94,7 @@ func TestBaseController_GetById(t *testing.T) {
|
||||
// Set up the router
|
||||
router := SetupRouter()
|
||||
router.Use(middlewares.AuthorizationMiddleware())
|
||||
router.GET("/testmodels/:id", ctr.GetById)
|
||||
router.GET("/testmodels/:id", nil, tonic.Handler(ctr.GetById, 200))
|
||||
|
||||
// Create a test request
|
||||
req, _ := http.NewRequest("GET", "/testmodels/"+id.Hex(), nil)
|
||||
@@ -106,30 +123,34 @@ func TestBaseController_Post(t *testing.T) {
|
||||
// Set up the router
|
||||
router := SetupRouter()
|
||||
router.Use(middlewares.AuthorizationMiddleware())
|
||||
router.POST("/testmodels", ctr.Post)
|
||||
router.POST("/testmodels", nil, tonic.Handler(ctr.Post, 200))
|
||||
|
||||
// Create a test request
|
||||
testModel := TestModel{Name: "test"}
|
||||
jsonValue, _ := json.Marshal(testModel)
|
||||
requestBody := controllers.PostParams[TestModel]{
|
||||
Data: testModel,
|
||||
}
|
||||
jsonValue, _ := json.Marshal(requestBody)
|
||||
req, _ := http.NewRequest("POST", "/testmodels", bytes.NewBuffer(jsonValue))
|
||||
req.Header.Set("Authorization", TestToken)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// Serve the request
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
// Check the response
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
var response controllers.Response[TestModel]
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "test", response.Data.Name)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "test", response.Data.Name)
|
||||
|
||||
// Check if the document was inserted into the database
|
||||
result, err := service.NewModelService[TestModel]().GetById(response.Data.Id)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "test", result.Name)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "test", result.Name)
|
||||
}
|
||||
|
||||
func TestBaseController_DeleteById(t *testing.T) {
|
||||
@@ -146,7 +167,7 @@ func TestBaseController_DeleteById(t *testing.T) {
|
||||
// Set up the router
|
||||
router := SetupRouter()
|
||||
router.Use(middlewares.AuthorizationMiddleware())
|
||||
router.DELETE("/testmodels/:id", ctr.DeleteById)
|
||||
router.DELETE("/testmodels/:id", nil, tonic.Handler(ctr.DeleteById, 200))
|
||||
|
||||
// Create a test request
|
||||
req, _ := http.NewRequest("DELETE", "/testmodels/"+id.Hex(), nil)
|
||||
@@ -163,3 +184,321 @@ func TestBaseController_DeleteById(t *testing.T) {
|
||||
_, err = service.NewModelService[TestModel]().GetById(id)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestBaseController_GetList(t *testing.T) {
|
||||
SetupTestDB()
|
||||
defer CleanupTestDB()
|
||||
|
||||
// Insert test documents
|
||||
modelSvc := service.NewModelService[TestModel]()
|
||||
for i := 0; i < 15; i++ {
|
||||
_, err := modelSvc.InsertOne(TestModel{Name: fmt.Sprintf("test%d", i)})
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// Initialize the controller
|
||||
ctr := controllers.NewController[TestModel]()
|
||||
|
||||
// Set up the router
|
||||
router := SetupRouter()
|
||||
router.Use(middlewares.AuthorizationMiddleware())
|
||||
router.GET("/testmodels/list", nil, tonic.Handler(ctr.GetList, 200))
|
||||
|
||||
// Test case 1: Get with pagination
|
||||
t.Run("test_get_with_pagination", func(t *testing.T) {
|
||||
var testData = []struct {
|
||||
Page int
|
||||
ExpectedDataCount int
|
||||
ExpectedTotalCount int
|
||||
}{
|
||||
{1, 10, 15},
|
||||
{2, 5, 15},
|
||||
}
|
||||
for _, data := range testData {
|
||||
params := url.Values{}
|
||||
params.Add("page", strconv.Itoa(data.Page))
|
||||
params.Add("size", "10")
|
||||
requestUrl := url.URL{Path: "/testmodels/list", RawQuery: params.Encode()}
|
||||
req, _ := http.NewRequest("GET", requestUrl.String(), nil)
|
||||
req.Header.Set("Authorization", TestToken)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// Serve the request
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
// Check the response
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
var response controllers.ListResponse[TestModel]
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, data.ExpectedDataCount, len(response.Data))
|
||||
assert.Equal(t, data.ExpectedTotalCount, response.Total)
|
||||
}
|
||||
})
|
||||
|
||||
// Test case 2: Get all
|
||||
t.Run("test_get_all", func(t *testing.T) {
|
||||
params := url.Values{}
|
||||
params.Add("all", "true")
|
||||
requestUrl := url.URL{Path: "/testmodels/list", RawQuery: params.Encode()}
|
||||
req, _ := http.NewRequest("GET", requestUrl.String(), nil)
|
||||
req.Header.Set("Authorization", TestToken)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// Serve the request
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
// Check the response
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
var response controllers.ListResponse[TestModel]
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 15, len(response.Data))
|
||||
assert.Equal(t, 15, response.Total)
|
||||
})
|
||||
|
||||
// Test case 3: Get with query filter
|
||||
t.Run("test_get_with_query_filter", func(t *testing.T) {
|
||||
cond := []entity.Condition{
|
||||
{Key: "name", Op: "eq", Value: "test1"},
|
||||
}
|
||||
condBytes, err := json.Marshal(cond)
|
||||
require.Nil(t, err)
|
||||
params := url.Values{}
|
||||
params.Add("conditions", string(condBytes))
|
||||
params.Add("page", "1")
|
||||
params.Add("size", "10")
|
||||
requestUrl := url.URL{Path: "/testmodels/list", RawQuery: params.Encode()}
|
||||
req, _ := http.NewRequest("GET", requestUrl.String(), nil)
|
||||
req.Header.Set("Authorization", TestToken)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// Serve the request
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
// Check the response
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
var response controllers.ListResponse[TestModel]
|
||||
err = json.Unmarshal(w.Body.Bytes(), &response)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, len(response.Data))
|
||||
assert.Equal(t, 1, response.Total)
|
||||
})
|
||||
}
|
||||
|
||||
func TestBaseController_PutById(t *testing.T) {
|
||||
SetupTestDB()
|
||||
defer CleanupTestDB()
|
||||
|
||||
// Insert a test document
|
||||
id, err := service.NewModelService[TestModel]().InsertOne(TestModel{Name: "test"})
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Initialize the controller
|
||||
ctr := controllers.NewController[TestModel]()
|
||||
|
||||
// Set up the router
|
||||
router := SetupRouter()
|
||||
router.Use(middlewares.AuthorizationMiddleware())
|
||||
router.PUT("/testmodels/:id", nil, tonic.Handler(ctr.PutById, 200))
|
||||
|
||||
// Create a test request
|
||||
updatedModel := TestModel{Name: "updated"}
|
||||
requestParams := controllers.PutByIdParams[TestModel]{
|
||||
Data: updatedModel,
|
||||
}
|
||||
jsonValue, _ := json.Marshal(requestParams)
|
||||
req, _ := http.NewRequest("PUT", "/testmodels/"+id.Hex(), bytes.NewBuffer(jsonValue))
|
||||
req.Header.Set("Authorization", TestToken)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// Serve the request
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
// Check the response
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
var response controllers.Response[TestModel]
|
||||
err = json.Unmarshal(w.Body.Bytes(), &response)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "updated", response.Data.Name)
|
||||
|
||||
// Check if the document was updated in the database
|
||||
result, err := service.NewModelService[TestModel]().GetById(id)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "updated", result.Name)
|
||||
|
||||
// Test with invalid ID
|
||||
req, _ = http.NewRequest("PUT", "/testmodels/invalid-id", bytes.NewBuffer(jsonValue))
|
||||
req.Header.Set("Authorization", TestToken)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w = httptest.NewRecorder()
|
||||
|
||||
// Serve the request
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
// Check the response
|
||||
assert.Equal(t, http.StatusBadRequest, w.Code)
|
||||
}
|
||||
|
||||
func TestBaseController_PatchList(t *testing.T) {
|
||||
SetupTestDB()
|
||||
defer CleanupTestDB()
|
||||
|
||||
// Insert test documents
|
||||
modelSvc := service.NewModelService[TestModel]()
|
||||
var ids []primitive.ObjectID
|
||||
for i := 0; i < 3; i++ {
|
||||
id, err := modelSvc.InsertOne(TestModel{Name: fmt.Sprintf("test%d", i)})
|
||||
assert.NoError(t, err)
|
||||
ids = append(ids, id)
|
||||
}
|
||||
|
||||
// Initialize the controller
|
||||
ctr := controllers.NewController[TestModel]()
|
||||
|
||||
// Set up the router
|
||||
router := SetupRouter()
|
||||
router.Use(middlewares.AuthorizationMiddleware())
|
||||
router.PATCH("/testmodels", nil, tonic.Handler(ctr.PatchList, 200))
|
||||
|
||||
// Create a test request
|
||||
t.Run("test_patch_list", func(t *testing.T) {
|
||||
var idStrings []string
|
||||
for _, id := range ids {
|
||||
idStrings = append(idStrings, id.Hex())
|
||||
}
|
||||
requestBody := controllers.PatchParams{
|
||||
Ids: idStrings,
|
||||
Update: bson.M{"name": "patched"},
|
||||
}
|
||||
jsonValue, _ := json.Marshal(requestBody)
|
||||
req, _ := http.NewRequest("PATCH", "/testmodels", bytes.NewBuffer(jsonValue))
|
||||
req.Header.Set("Authorization", TestToken)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// Get the user ID
|
||||
userId := TestUserId
|
||||
|
||||
// Record time before the update
|
||||
beforeUpdate := time.Now()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Serve the request
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
// Check the response
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Record time after the update
|
||||
afterUpdate := time.Now()
|
||||
|
||||
// Check if the documents were updated in the database
|
||||
for _, id := range ids {
|
||||
result, err := modelSvc.GetById(id)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "patched", result.Name)
|
||||
|
||||
// Verify updated_by is set to the current user's ID
|
||||
assert.Equal(t, userId, result.UpdatedBy)
|
||||
|
||||
// Verify updated_ts is set to a recent timestamp
|
||||
assert.GreaterOrEqual(t, result.UpdatedAt.UnixMilli(), beforeUpdate.UnixMilli())
|
||||
assert.LessOrEqual(t, result.UpdatedAt.UnixMilli(), afterUpdate.UnixMilli())
|
||||
}
|
||||
})
|
||||
|
||||
// Test with invalid ID
|
||||
t.Run("test_patch_list_with_invalid_id", func(t *testing.T) {
|
||||
requestBody := controllers.PatchParams{
|
||||
Ids: []string{"invalid-id"},
|
||||
Update: bson.M{"name": "patched"},
|
||||
}
|
||||
jsonValue, _ := json.Marshal(requestBody)
|
||||
req, _ := http.NewRequest("PATCH", "/testmodels", bytes.NewBuffer(jsonValue))
|
||||
req.Header.Set("Authorization", TestToken)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// Serve the request
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
// Check the response
|
||||
assert.Equal(t, http.StatusBadRequest, w.Code)
|
||||
})
|
||||
}
|
||||
|
||||
func TestBaseController_DeleteList(t *testing.T) {
|
||||
SetupTestDB()
|
||||
defer CleanupTestDB()
|
||||
|
||||
// Insert test documents
|
||||
modelSvc := service.NewModelService[TestModel]()
|
||||
var ids []primitive.ObjectID
|
||||
for i := 0; i < 3; i++ {
|
||||
id, err := modelSvc.InsertOne(TestModel{Name: fmt.Sprintf("test%d", i)})
|
||||
assert.NoError(t, err)
|
||||
ids = append(ids, id)
|
||||
}
|
||||
|
||||
// Initialize the controller
|
||||
ctr := controllers.NewController[TestModel]()
|
||||
|
||||
// Set up the router
|
||||
router := SetupRouter()
|
||||
router.Use(middlewares.AuthorizationMiddleware())
|
||||
router.DELETE("/testmodels", nil, tonic.Handler(ctr.DeleteList, 200))
|
||||
|
||||
// Create a test request
|
||||
var idStrings []string
|
||||
for _, id := range ids {
|
||||
idStrings = append(idStrings, id.Hex())
|
||||
}
|
||||
requestBody := controllers.DeleteListParams{
|
||||
Ids: idStrings,
|
||||
}
|
||||
jsonValue, _ := json.Marshal(requestBody)
|
||||
req, _ := http.NewRequest("DELETE", "/testmodels", bytes.NewBuffer(jsonValue))
|
||||
req.Header.Set("Authorization", TestToken)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// Serve the request
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
// Check the response
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
// Check if the documents were deleted from the database
|
||||
for _, id := range ids {
|
||||
_, err := modelSvc.GetById(id)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
// Test with invalid ID
|
||||
requestBody = controllers.DeleteListParams{
|
||||
Ids: []string{"invalid-id"},
|
||||
}
|
||||
jsonValue, _ = json.Marshal(requestBody)
|
||||
req, _ = http.NewRequest("DELETE", "/testmodels", bytes.NewBuffer(jsonValue))
|
||||
req.Header.Set("Authorization", TestToken)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w = httptest.NewRecorder()
|
||||
|
||||
// Serve the request
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
// Check the response
|
||||
assert.Equal(t, http.StatusBadRequest, w.Code)
|
||||
}
|
||||
|
||||
@@ -1,34 +1,31 @@
|
||||
package controllers
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/crawlab-team/crawlab/core/models/models"
|
||||
"github.com/crawlab-team/crawlab/core/models/service"
|
||||
"github.com/crawlab-team/crawlab/core/mongo"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/juju/errors"
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
mongo2 "go.mongodb.org/mongo-driver/mongo"
|
||||
)
|
||||
|
||||
func GetProjectList(c *gin.Context) {
|
||||
// get all list
|
||||
all := MustGetFilterAll(c)
|
||||
if all {
|
||||
NewController[models.Project]().getAll(c)
|
||||
return
|
||||
func GetProjectList(c *gin.Context, params *GetListParams) (response *ListResponse[models.Project], err error) {
|
||||
if params.All {
|
||||
return NewController[models.Project]().GetAll(params)
|
||||
}
|
||||
|
||||
// params
|
||||
pagination := MustGetPagination(c)
|
||||
query := MustGetFilterQuery(c)
|
||||
sort := MustGetSortOption(c)
|
||||
query, err := GetFilterQueryFromListParams(params)
|
||||
if err != nil {
|
||||
return GetErrorListResponse[models.Project](errors.BadRequestf("invalid request parameters: %v", err))
|
||||
}
|
||||
|
||||
// get list
|
||||
projects, err := service.NewModelService[models.Project]().GetMany(query, &mongo.FindOptions{
|
||||
Sort: sort,
|
||||
Skip: pagination.Size * (pagination.Page - 1),
|
||||
Limit: pagination.Size,
|
||||
Sort: params.Sort,
|
||||
Skip: params.Size * (params.Page - 1),
|
||||
Limit: params.Size,
|
||||
})
|
||||
if err != nil {
|
||||
if err.Error() != mongo2.ErrNoDocuments.Error() {
|
||||
@@ -82,5 +79,5 @@ func GetProjectList(c *gin.Context) {
|
||||
projects[i].Spiders = cache[p.Id]
|
||||
}
|
||||
|
||||
HandleSuccessWithListData(c, projects, total)
|
||||
return GetListResponse[models.Project](projects, total)
|
||||
}
|
||||
|
||||
@@ -29,13 +29,13 @@ func GetGlobalFizzWrapper() *openapi.FizzWrapper {
|
||||
// NewRouterGroups initializes the router groups with their respective middleware
|
||||
func NewRouterGroups(app *gin.Engine) (groups *RouterGroups) {
|
||||
// Create OpenAPI wrapper
|
||||
wrapper := openapi.NewFizzWrapper(app)
|
||||
globalWrapper = openapi.NewFizzWrapper(app)
|
||||
|
||||
return &RouterGroups{
|
||||
AuthGroup: app.Group("/", middlewares.AuthorizationMiddleware()),
|
||||
SyncAuthGroup: app.Group("/", middlewares.SyncAuthorizationMiddleware()),
|
||||
AnonymousGroup: app.Group("/"),
|
||||
Wrapper: wrapper,
|
||||
Wrapper: globalWrapper,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -232,33 +232,6 @@ func InitRoutes(app *gin.Engine) (err error) {
|
||||
HandlerFunc: GetProjectList,
|
||||
},
|
||||
}...))
|
||||
RegisterController(groups.AuthGroup, "/schedules", NewController[models.Schedule]([]Action{
|
||||
{
|
||||
Method: http.MethodPost,
|
||||
Path: "",
|
||||
HandlerFunc: PostSchedule,
|
||||
},
|
||||
{
|
||||
Method: http.MethodPut,
|
||||
Path: "/:id",
|
||||
HandlerFunc: PutScheduleById,
|
||||
},
|
||||
{
|
||||
Method: http.MethodPost,
|
||||
Path: "/:id/enable",
|
||||
HandlerFunc: PostScheduleEnable,
|
||||
},
|
||||
{
|
||||
Method: http.MethodPost,
|
||||
Path: "/:id/disable",
|
||||
HandlerFunc: PostScheduleDisable,
|
||||
},
|
||||
{
|
||||
Method: http.MethodPost,
|
||||
Path: "/:id/run",
|
||||
HandlerFunc: PostScheduleRun,
|
||||
},
|
||||
}...))
|
||||
RegisterController(groups.AuthGroup, "/spiders", NewController[models.Spider]([]Action{
|
||||
{
|
||||
Method: http.MethodGet,
|
||||
@@ -351,6 +324,35 @@ func InitRoutes(app *gin.Engine) (err error) {
|
||||
HandlerFunc: GetSpiderResults,
|
||||
},
|
||||
}...))
|
||||
groups.AnonymousGroup.GET("/openapi.json", GetOpenAPI)
|
||||
return
|
||||
RegisterController(groups.AuthGroup, "/schedules", NewController[models.Schedule]([]Action{
|
||||
{
|
||||
Method: http.MethodPost,
|
||||
Path: "",
|
||||
HandlerFunc: PostSchedule,
|
||||
},
|
||||
{
|
||||
Method: http.MethodPut,
|
||||
Path: "/:id",
|
||||
HandlerFunc: PutScheduleById,
|
||||
},
|
||||
{
|
||||
Method: http.MethodPost,
|
||||
Path: "/:id/enable",
|
||||
HandlerFunc: PostScheduleEnable,
|
||||
},
|
||||
{
|
||||
Method: http.MethodPost,
|
||||
Path: "/:id/disable",
|
||||
HandlerFunc: PostScheduleDisable,
|
||||
},
|
||||
{
|
||||
Method: http.MethodPost,
|
||||
Path: "/:id/run",
|
||||
HandlerFunc: PostScheduleRun,
|
||||
},
|
||||
}...))
|
||||
RegisterController(groups.AuthGroup, "/tasks", NewController[models.Task]([]Action{
|
||||
{
|
||||
Method: http.MethodGet,
|
||||
|
||||
@@ -60,7 +60,7 @@ func TestRegisterController_Routes(t *testing.T) {
|
||||
func TestInitRoutes_ProjectsRoute(t *testing.T) {
|
||||
router := gin.Default()
|
||||
|
||||
controllers.InitRoutes(router)
|
||||
_ = controllers.InitRoutes(router)
|
||||
|
||||
// Check if the projects route is registered
|
||||
routes := router.Routes()
|
||||
|
||||
@@ -1,59 +1,55 @@
|
||||
package controllers
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/crawlab-team/crawlab/core/constants"
|
||||
"github.com/crawlab-team/crawlab/core/models/models"
|
||||
mongo2 "github.com/crawlab-team/crawlab/core/mongo"
|
||||
"github.com/crawlab-team/crawlab/core/spider"
|
||||
"math"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
|
||||
"github.com/crawlab-team/crawlab/core/constants"
|
||||
"github.com/crawlab-team/crawlab/core/fs"
|
||||
"github.com/crawlab-team/crawlab/core/interfaces"
|
||||
"github.com/crawlab-team/crawlab/core/models/models"
|
||||
"github.com/crawlab-team/crawlab/core/models/service"
|
||||
mongo2 "github.com/crawlab-team/crawlab/core/mongo"
|
||||
"github.com/crawlab-team/crawlab/core/spider"
|
||||
"github.com/crawlab-team/crawlab/core/spider/admin"
|
||||
"github.com/crawlab-team/crawlab/core/utils"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/juju/errors"
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
)
|
||||
|
||||
func GetSpiderById(c *gin.Context) {
|
||||
id, err := primitive.ObjectIDFromHex(c.Param("id"))
|
||||
// GetSpiderById handles getting a spider by ID
|
||||
func GetSpiderById(_ *gin.Context, params *GetByIdParams) (response *Response[models.Spider], err error) {
|
||||
id, err := primitive.ObjectIDFromHex(params.Id)
|
||||
if err != nil {
|
||||
HandleErrorBadRequest(c, err)
|
||||
return
|
||||
return GetErrorResponse[models.Spider](errors.BadRequestf("invalid id format"))
|
||||
}
|
||||
s, err := service.NewModelService[models.Spider]().GetById(id)
|
||||
if errors.Is(err, mongo.ErrNoDocuments) {
|
||||
HandleErrorNotFound(c, err)
|
||||
return
|
||||
return GetErrorResponse[models.Spider](errors.NotFoundf("spider not found"))
|
||||
}
|
||||
if err != nil {
|
||||
HandleErrorInternalServerError(c, err)
|
||||
return
|
||||
return GetErrorResponse[models.Spider](err)
|
||||
}
|
||||
|
||||
// stat
|
||||
s.Stat, err = service.NewModelService[models.SpiderStat]().GetById(s.Id)
|
||||
if err != nil {
|
||||
if !errors.Is(err, mongo.ErrNoDocuments) {
|
||||
HandleErrorInternalServerError(c, err)
|
||||
return
|
||||
return GetErrorResponse[models.Spider](err)
|
||||
}
|
||||
}
|
||||
|
||||
// data collection (compatible to old version) # TODO: remove in the future
|
||||
// data collection (compatible to old version)
|
||||
if s.ColName == "" && !s.ColId.IsZero() {
|
||||
col, err := service.NewModelService[models.DataCollection]().GetById(s.ColId)
|
||||
if err != nil {
|
||||
if !errors.Is(err, mongo.ErrNoDocuments) {
|
||||
HandleErrorInternalServerError(c, err)
|
||||
return
|
||||
return GetErrorResponse[models.Spider](err)
|
||||
}
|
||||
} else {
|
||||
s.ColName = col.Name
|
||||
@@ -65,55 +61,51 @@ func GetSpiderById(c *gin.Context) {
|
||||
s.Git, err = service.NewModelService[models.Git]().GetById(s.GitId)
|
||||
if err != nil {
|
||||
if !errors.Is(err, mongo.ErrNoDocuments) {
|
||||
HandleErrorInternalServerError(c, err)
|
||||
return
|
||||
return GetErrorResponse[models.Spider](err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
HandleSuccessWithData(c, s)
|
||||
return GetDataResponse(*s)
|
||||
}
|
||||
|
||||
func GetSpiderList(c *gin.Context, params *GetListParams) {
|
||||
// GetSpiderList handles getting a list of spiders with optional stats
|
||||
func GetSpiderList(c *gin.Context, params *GetListParams) (response *ListResponse[models.Spider], err error) {
|
||||
// get all list
|
||||
all := MustGetFilterAll(c)
|
||||
all := params.All
|
||||
if all {
|
||||
NewController[models.Spider]().getAll(c)
|
||||
return
|
||||
return NewController[models.Spider]().GetAll(params)
|
||||
}
|
||||
|
||||
// get list
|
||||
withStats := c.Query("stats")
|
||||
if withStats == "" {
|
||||
NewController[models.Spider]().GetList(c, params)
|
||||
return
|
||||
return NewController[models.Spider]().GetList(c, params)
|
||||
}
|
||||
|
||||
// get list with stats
|
||||
getSpiderListWithStats(c)
|
||||
return getSpiderListWithStats(params)
|
||||
}
|
||||
|
||||
func getSpiderListWithStats(c *gin.Context) {
|
||||
// params
|
||||
pagination := MustGetPagination(c)
|
||||
query := MustGetFilterQuery(c)
|
||||
sort := MustGetSortOption(c)
|
||||
|
||||
func getSpiderListWithStats(params *GetListParams) (response *ListResponse[models.Spider], err error) {
|
||||
query, err := GetFilterQueryFromListParams(params)
|
||||
if err != nil {
|
||||
return GetErrorListResponse[models.Spider](errors.BadRequestf("invalid request parameters: %v", err))
|
||||
}
|
||||
// get list
|
||||
spiders, err := service.NewModelService[models.Spider]().GetMany(query, &mongo2.FindOptions{
|
||||
Sort: sort,
|
||||
Skip: pagination.Size * (pagination.Page - 1),
|
||||
Limit: pagination.Size,
|
||||
Sort: params.Sort,
|
||||
Skip: params.Size * (params.Page - 1),
|
||||
Limit: params.Size,
|
||||
})
|
||||
if err != nil {
|
||||
if err.Error() != mongo.ErrNoDocuments.Error() {
|
||||
HandleErrorInternalServerError(c, err)
|
||||
if !errors.Is(err, mongo.ErrNoDocuments) {
|
||||
return GetErrorListResponse[models.Spider](err)
|
||||
}
|
||||
return
|
||||
return GetListResponse[models.Spider]([]models.Spider{}, 0)
|
||||
}
|
||||
if len(spiders) == 0 {
|
||||
HandleSuccessWithListData(c, []models.Spider{}, 0)
|
||||
return
|
||||
return GetListResponse[models.Spider]([]models.Spider{}, 0)
|
||||
}
|
||||
|
||||
// ids
|
||||
@@ -129,15 +121,13 @@ func getSpiderListWithStats(c *gin.Context) {
|
||||
// total count
|
||||
total, err := service.NewModelService[models.Spider]().Count(query)
|
||||
if err != nil {
|
||||
HandleErrorInternalServerError(c, err)
|
||||
return
|
||||
return GetErrorListResponse[models.Spider](err)
|
||||
}
|
||||
|
||||
// stat list
|
||||
spiderStats, err := service.NewModelService[models.SpiderStat]().GetMany(bson.M{"_id": bson.M{"$in": ids}}, nil)
|
||||
if err != nil {
|
||||
HandleErrorInternalServerError(c, err)
|
||||
return
|
||||
return GetErrorListResponse[models.Spider](err)
|
||||
}
|
||||
|
||||
// cache stat list to dict
|
||||
@@ -170,15 +160,13 @@ func getSpiderListWithStats(c *gin.Context) {
|
||||
}
|
||||
tasks, err = service.NewModelService[models.Task]().GetMany(queryTask, nil)
|
||||
if err != nil {
|
||||
HandleErrorInternalServerError(c, err)
|
||||
return
|
||||
return GetErrorListResponse[models.Spider](err)
|
||||
}
|
||||
|
||||
// task stats list
|
||||
taskStats, err := service.NewModelService[models.TaskStat]().GetMany(queryTask, nil)
|
||||
if err != nil {
|
||||
HandleErrorInternalServerError(c, err)
|
||||
return
|
||||
return GetErrorListResponse[models.Spider](err)
|
||||
}
|
||||
|
||||
// cache task stats to dict
|
||||
@@ -201,8 +189,7 @@ func getSpiderListWithStats(c *gin.Context) {
|
||||
if len(gitIds) > 0 && utils.IsPro() {
|
||||
gits, err = service.NewModelService[models.Git]().GetMany(bson.M{"_id": bson.M{"$in": gitIds}}, nil)
|
||||
if err != nil {
|
||||
HandleErrorInternalServerError(c, err)
|
||||
return
|
||||
return GetErrorListResponse[models.Spider](err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -240,16 +227,12 @@ func getSpiderListWithStats(c *gin.Context) {
|
||||
}
|
||||
|
||||
// response
|
||||
HandleSuccessWithListData(c, data, total)
|
||||
return GetListResponse(data, total)
|
||||
}
|
||||
|
||||
func PostSpider(c *gin.Context) {
|
||||
// bind
|
||||
var s models.Spider
|
||||
if err := c.ShouldBindJSON(&s); err != nil {
|
||||
HandleErrorBadRequest(c, err)
|
||||
return
|
||||
}
|
||||
// PostSpider handles creating a new spider
|
||||
func PostSpider(c *gin.Context, params *PostParams[models.Spider]) (response *Response[models.Spider], err error) {
|
||||
s := params.Data
|
||||
|
||||
if s.Mode == "" {
|
||||
s.Mode = constants.RunTypeRandom
|
||||
@@ -266,8 +249,7 @@ func PostSpider(c *gin.Context) {
|
||||
s.SetUpdated(u.Id)
|
||||
id, err := service.NewModelService[models.Spider]().InsertOne(s)
|
||||
if err != nil {
|
||||
HandleErrorInternalServerError(c, err)
|
||||
return
|
||||
return GetErrorResponse[models.Spider](err)
|
||||
}
|
||||
s.SetId(id)
|
||||
|
||||
@@ -278,20 +260,17 @@ func PostSpider(c *gin.Context) {
|
||||
st.SetUpdated(u.Id)
|
||||
_, err = service.NewModelService[models.SpiderStat]().InsertOne(st)
|
||||
if err != nil {
|
||||
HandleErrorInternalServerError(c, err)
|
||||
return
|
||||
return GetErrorResponse[models.Spider](err)
|
||||
}
|
||||
|
||||
// create folder
|
||||
fsSvc, err := getSpiderFsSvcById(id)
|
||||
if err != nil {
|
||||
HandleErrorInternalServerError(c, err)
|
||||
return
|
||||
return GetErrorResponse[models.Spider](err)
|
||||
}
|
||||
err = fsSvc.CreateDir(".")
|
||||
if err != nil {
|
||||
HandleErrorInternalServerError(c, err)
|
||||
return
|
||||
return GetErrorResponse[models.Spider](err)
|
||||
}
|
||||
|
||||
// create template if available
|
||||
@@ -299,63 +278,52 @@ func PostSpider(c *gin.Context) {
|
||||
if templateSvc := spider.GetSpiderTemplateRegistryService(); templateSvc != nil {
|
||||
err = templateSvc.CreateTemplate(s.Id)
|
||||
if err != nil {
|
||||
HandleErrorInternalServerError(c, err)
|
||||
return
|
||||
return GetErrorResponse[models.Spider](err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
HandleSuccessWithData(c, s)
|
||||
return GetDataResponse(s)
|
||||
}
|
||||
|
||||
func PutSpiderById(c *gin.Context) {
|
||||
id, err := primitive.ObjectIDFromHex(c.Param("id"))
|
||||
// PutSpiderById handles updating a spider by ID
|
||||
func PutSpiderById(c *gin.Context, params *PutByIdParams[models.Spider]) (response *Response[models.Spider], err error) {
|
||||
id, err := primitive.ObjectIDFromHex(params.Id)
|
||||
if err != nil {
|
||||
HandleErrorBadRequest(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// bind
|
||||
var s models.Spider
|
||||
if err := c.ShouldBindJSON(&s); err != nil {
|
||||
HandleErrorBadRequest(c, err)
|
||||
return
|
||||
return GetErrorResponse[models.Spider](errors.BadRequestf("invalid id format"))
|
||||
}
|
||||
|
||||
u := GetUserFromContext(c)
|
||||
|
||||
modelSvc := service.NewModelService[models.Spider]()
|
||||
|
||||
// save
|
||||
s.SetUpdated(u.Id)
|
||||
err = modelSvc.ReplaceById(id, s)
|
||||
if err != nil {
|
||||
HandleErrorInternalServerError(c, err)
|
||||
return
|
||||
params.Data.SetUpdated(u.Id)
|
||||
if params.Data.Id.IsZero() {
|
||||
params.Data.SetId(id)
|
||||
}
|
||||
|
||||
_s, err := modelSvc.GetById(id)
|
||||
err = modelSvc.ReplaceById(id, params.Data)
|
||||
if err != nil {
|
||||
HandleErrorInternalServerError(c, err)
|
||||
return
|
||||
return GetErrorResponse[models.Spider](err)
|
||||
}
|
||||
s = *_s
|
||||
|
||||
HandleSuccessWithData(c, s)
|
||||
s, err := modelSvc.GetById(id)
|
||||
if err != nil {
|
||||
return GetErrorResponse[models.Spider](err)
|
||||
}
|
||||
|
||||
return GetDataResponse(*s)
|
||||
}
|
||||
|
||||
func DeleteSpiderById(c *gin.Context) {
|
||||
id, err := primitive.ObjectIDFromHex(c.Param("id"))
|
||||
// DeleteSpiderById handles deleting a spider by ID
|
||||
func DeleteSpiderById(_ *gin.Context, params *DeleteByIdParams) (response *Response[models.Spider], err error) {
|
||||
id, err := primitive.ObjectIDFromHex(params.Id)
|
||||
if err != nil {
|
||||
HandleErrorBadRequest(c, err)
|
||||
return
|
||||
return GetErrorResponse[models.Spider](errors.BadRequestf("invalid id format"))
|
||||
}
|
||||
|
||||
// spider
|
||||
s, err := service.NewModelService[models.Spider]().GetById(id)
|
||||
if err != nil {
|
||||
HandleErrorInternalServerError(c, err)
|
||||
return
|
||||
return GetErrorResponse[models.Spider](errors.NotFoundf("spider not found"))
|
||||
}
|
||||
|
||||
if err := mongo2.RunTransaction(func(context mongo.SessionContext) (err error) {
|
||||
@@ -416,14 +384,13 @@ func DeleteSpiderById(c *gin.Context) {
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
HandleErrorInternalServerError(c, err)
|
||||
return
|
||||
return GetErrorResponse[models.Spider](err)
|
||||
}
|
||||
|
||||
if !s.GitId.IsZero() {
|
||||
go func() {
|
||||
// delete spider directory
|
||||
fsSvc, err := getSpiderFsSvcById(id)
|
||||
fsSvc, err := getSpiderFsSvcById(s.Id)
|
||||
if err != nil {
|
||||
logger.Errorf("failed to get spider fs service: %v", err)
|
||||
return
|
||||
@@ -436,34 +403,39 @@ func DeleteSpiderById(c *gin.Context) {
|
||||
}()
|
||||
}
|
||||
|
||||
HandleSuccess(c)
|
||||
return GetDataResponse(models.Spider{})
|
||||
}
|
||||
|
||||
func DeleteSpiderList(c *gin.Context) {
|
||||
var payload struct {
|
||||
Ids []primitive.ObjectID `json:"ids"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&payload); err != nil {
|
||||
HandleErrorBadRequest(c, err)
|
||||
return
|
||||
type DeleteSpiderListParams struct {
|
||||
Ids []string `json:"ids" validate:"required"`
|
||||
}
|
||||
|
||||
// DeleteSpiderList handles deleting multiple spiders
|
||||
func DeleteSpiderList(_ *gin.Context, params *DeleteSpiderListParams) (response *Response[models.Spider], err error) {
|
||||
var ids []primitive.ObjectID
|
||||
for _, id := range params.Ids {
|
||||
_id, err := primitive.ObjectIDFromHex(id)
|
||||
if err != nil {
|
||||
return GetErrorResponse[models.Spider](errors.BadRequestf("invalid id format"))
|
||||
}
|
||||
ids = append(ids, _id)
|
||||
}
|
||||
|
||||
// Fetch spiders before deletion
|
||||
spiders, err := service.NewModelService[models.Spider]().GetMany(bson.M{
|
||||
"_id": bson.M{
|
||||
"$in": payload.Ids,
|
||||
"$in": ids,
|
||||
},
|
||||
}, nil)
|
||||
if err != nil {
|
||||
HandleErrorInternalServerError(c, err)
|
||||
return
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := mongo2.RunTransaction(func(context mongo.SessionContext) (err error) {
|
||||
// delete spiders
|
||||
if err := service.NewModelService[models.Spider]().DeleteMany(bson.M{
|
||||
"_id": bson.M{
|
||||
"$in": payload.Ids,
|
||||
"$in": ids,
|
||||
},
|
||||
}); err != nil {
|
||||
return err
|
||||
@@ -472,14 +444,14 @@ func DeleteSpiderList(c *gin.Context) {
|
||||
// delete spider stats
|
||||
if err := service.NewModelService[models.SpiderStat]().DeleteMany(bson.M{
|
||||
"_id": bson.M{
|
||||
"$in": payload.Ids,
|
||||
"$in": ids,
|
||||
},
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// related tasks
|
||||
tasks, err := service.NewModelService[models.Task]().GetMany(bson.M{"spider_id": bson.M{"$in": payload.Ids}}, nil)
|
||||
tasks, err := service.NewModelService[models.Task]().GetMany(bson.M{"spider_id": bson.M{"$in": ids}}, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -521,8 +493,7 @@ func DeleteSpiderList(c *gin.Context) {
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
HandleErrorInternalServerError(c, err)
|
||||
return
|
||||
return GetErrorResponse[models.Spider](err)
|
||||
}
|
||||
|
||||
// Delete spider directories
|
||||
@@ -554,25 +525,25 @@ func DeleteSpiderList(c *gin.Context) {
|
||||
wg.Wait()
|
||||
}()
|
||||
|
||||
HandleSuccess(c)
|
||||
return GetDataResponse(models.Spider{})
|
||||
}
|
||||
|
||||
func GetSpiderListDir(c *gin.Context) {
|
||||
func GetSpiderListDir(c *gin.Context, params *GetBaseFileListDirParams) (response *Response[[]interfaces.FsFileInfo], err error) {
|
||||
rootPath, err := getSpiderRootPathByContext(c)
|
||||
if err != nil {
|
||||
HandleErrorForbidden(c, err)
|
||||
return
|
||||
}
|
||||
GetBaseFileListDir(rootPath, c)
|
||||
return GetBaseFileListDir(rootPath, params)
|
||||
}
|
||||
|
||||
func GetSpiderFile(c *gin.Context) {
|
||||
func GetSpiderFile(c *gin.Context, params *GetBaseFileFileParams) (response *Response[string], err error) {
|
||||
rootPath, err := getSpiderRootPathByContext(c)
|
||||
if err != nil {
|
||||
HandleErrorForbidden(c, err)
|
||||
return
|
||||
}
|
||||
GetBaseFileFile(rootPath, c)
|
||||
return GetBaseFileFile(rootPath, params)
|
||||
}
|
||||
|
||||
func GetSpiderFileInfo(c *gin.Context) {
|
||||
@@ -648,18 +619,16 @@ func PostSpiderExport(c *gin.Context) {
|
||||
PostBaseFileExport(rootPath, c)
|
||||
}
|
||||
|
||||
func PostSpiderRun(c *gin.Context) {
|
||||
func PostSpiderRun(c *gin.Context) (response *Response[[]primitive.ObjectID], err error) {
|
||||
id, err := primitive.ObjectIDFromHex(c.Param("id"))
|
||||
if err != nil {
|
||||
HandleErrorBadRequest(c, err)
|
||||
return
|
||||
return GetErrorResponse[[]primitive.ObjectID](errors.BadRequestf("invalid id format"))
|
||||
}
|
||||
|
||||
// options
|
||||
var opts interfaces.SpiderRunOptions
|
||||
if err := c.ShouldBindJSON(&opts); err != nil {
|
||||
HandleErrorInternalServerError(c, err)
|
||||
return
|
||||
return GetErrorResponse[[]primitive.ObjectID](err)
|
||||
}
|
||||
|
||||
// user
|
||||
@@ -670,11 +639,11 @@ func PostSpiderRun(c *gin.Context) {
|
||||
// schedule tasks
|
||||
taskIds, err := admin.GetSpiderAdminService().Schedule(id, &opts)
|
||||
if err != nil {
|
||||
HandleErrorInternalServerError(c, err)
|
||||
return
|
||||
return GetErrorResponse[[]primitive.ObjectID](err)
|
||||
}
|
||||
|
||||
HandleSuccessWithData(c, taskIds)
|
||||
return GetDataResponse(taskIds)
|
||||
}
|
||||
|
||||
func GetSpiderResults(c *gin.Context) {
|
||||
@@ -733,17 +702,13 @@ func getSpiderFsSvcById(id primitive.ObjectID) (svc interfaces.FsService, err er
|
||||
}
|
||||
|
||||
func getSpiderRootPathByContext(c *gin.Context) (rootPath string, err error) {
|
||||
// spider id
|
||||
id, err := primitive.ObjectIDFromHex(c.Param("id"))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// spider
|
||||
s, err := service.NewModelService[models.Spider]().GetById(id)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return utils.GetSpiderRootPath(s)
|
||||
}
|
||||
|
||||
@@ -3,11 +3,13 @@ package controllers_test
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"github.com/crawlab-team/crawlab/core/models/models"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/crawlab-team/crawlab/core/models/models"
|
||||
"github.com/loopfz/gadgeto/tonic"
|
||||
|
||||
"github.com/crawlab-team/crawlab/core/controllers"
|
||||
"github.com/crawlab-team/crawlab/core/middlewares"
|
||||
"github.com/crawlab-team/crawlab/core/models/service"
|
||||
@@ -24,9 +26,9 @@ func TestCreateSpider(t *testing.T) {
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
router := gin.Default()
|
||||
router := SetupRouter()
|
||||
router.Use(middlewares.AuthorizationMiddleware())
|
||||
router.POST("/spiders", controllers.PostSpider)
|
||||
router.POST("/spiders", nil, tonic.Handler(controllers.PostSpider, 200))
|
||||
|
||||
payload := models.Spider{
|
||||
Name: "Test Spider",
|
||||
@@ -54,9 +56,9 @@ func TestGetSpiderById(t *testing.T) {
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
router := gin.Default()
|
||||
router := SetupRouter()
|
||||
router.Use(middlewares.AuthorizationMiddleware())
|
||||
router.GET("/spiders/:id", controllers.GetSpiderById)
|
||||
router.GET("/spiders/:id", nil, tonic.Handler(controllers.GetSpiderById, 200))
|
||||
|
||||
model := models.Spider{
|
||||
Name: "Test Spider",
|
||||
@@ -89,9 +91,9 @@ func TestUpdateSpiderById(t *testing.T) {
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
router := gin.Default()
|
||||
router := SetupRouter()
|
||||
router.Use(middlewares.AuthorizationMiddleware())
|
||||
router.PUT("/spiders/:id", controllers.PutSpiderById)
|
||||
router.PUT("/spiders/:id", nil, tonic.Handler(controllers.PutSpiderById, 200))
|
||||
|
||||
model := models.Spider{
|
||||
Name: "Test Spider",
|
||||
@@ -110,7 +112,10 @@ func TestUpdateSpiderById(t *testing.T) {
|
||||
ColName: "test_spider",
|
||||
}
|
||||
payload.SetId(id)
|
||||
jsonValue, _ := json.Marshal(payload)
|
||||
requestBody := controllers.PutByIdParams[models.Spider]{
|
||||
Data: payload,
|
||||
}
|
||||
jsonValue, _ := json.Marshal(requestBody)
|
||||
req, _ := http.NewRequest("PUT", "/spiders/"+spiderId, bytes.NewBuffer(jsonValue))
|
||||
req.Header.Set("Authorization", TestToken)
|
||||
resp := httptest.NewRecorder()
|
||||
@@ -136,9 +141,9 @@ func TestDeleteSpiderById(t *testing.T) {
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
router := gin.Default()
|
||||
router := SetupRouter()
|
||||
router.Use(middlewares.AuthorizationMiddleware())
|
||||
router.DELETE("/spiders/:id", controllers.DeleteSpiderById)
|
||||
router.DELETE("/spiders/:id", nil, tonic.Handler(controllers.DeleteSpiderById, 200))
|
||||
|
||||
model := models.Spider{
|
||||
Name: "Test Spider",
|
||||
@@ -186,9 +191,9 @@ func TestDeleteSpiderList(t *testing.T) {
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
router := gin.Default()
|
||||
router := SetupRouter()
|
||||
router.Use(middlewares.AuthorizationMiddleware())
|
||||
router.DELETE("/spiders", controllers.DeleteSpiderList)
|
||||
router.DELETE("/spiders", nil, tonic.Handler(controllers.DeleteSpiderList, 200))
|
||||
|
||||
modelList := []models.Spider{
|
||||
{
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
package controllers
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/crawlab-team/crawlab/core/mongo"
|
||||
"regexp"
|
||||
|
||||
"github.com/crawlab-team/crawlab/core/mongo"
|
||||
"github.com/juju/errors"
|
||||
|
||||
"github.com/crawlab-team/crawlab/core/models/models"
|
||||
"github.com/crawlab-team/crawlab/core/models/service"
|
||||
"github.com/crawlab-team/crawlab/core/utils"
|
||||
@@ -15,34 +15,31 @@ import (
|
||||
mongo2 "go.mongodb.org/mongo-driver/mongo"
|
||||
)
|
||||
|
||||
func GetUserById(c *gin.Context) {
|
||||
id, err := primitive.ObjectIDFromHex(c.Param("id"))
|
||||
func GetUserById(c *gin.Context, params *GetByIdParams) (response *Response[models.User], err error) {
|
||||
id, err := primitive.ObjectIDFromHex(params.Id)
|
||||
if err != nil {
|
||||
HandleErrorBadRequest(c, err)
|
||||
return
|
||||
}
|
||||
getUserById(id, c)
|
||||
return getUserById(id)
|
||||
}
|
||||
|
||||
func GetUserList(c *gin.Context) {
|
||||
// params
|
||||
pagination := MustGetPagination(c)
|
||||
query := MustGetFilterQuery(c)
|
||||
sort := MustGetSortOption(c)
|
||||
|
||||
// get users
|
||||
func GetUserList(_ *gin.Context, params *GetListParams) (response *ListResponse[models.User], err error) {
|
||||
query, err := GetFilterQueryFromListParams(params)
|
||||
if err != nil {
|
||||
return GetErrorListResponse[models.User](err)
|
||||
}
|
||||
users, err := service.NewModelService[models.User]().GetMany(query, &mongo.FindOptions{
|
||||
Sort: sort,
|
||||
Skip: pagination.Size * (pagination.Page - 1),
|
||||
Limit: pagination.Size,
|
||||
Sort: params.Sort,
|
||||
Skip: params.Size * (params.Page - 1),
|
||||
Limit: params.Size,
|
||||
})
|
||||
if err != nil {
|
||||
if errors.Is(err, mongo2.ErrNoDocuments) {
|
||||
HandleSuccessWithListData(c, nil, 0)
|
||||
return GetListResponse[models.User](nil, 0)
|
||||
} else {
|
||||
HandleErrorInternalServerError(c, err)
|
||||
return GetErrorListResponse[models.User](err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// get roles
|
||||
@@ -58,8 +55,7 @@ func GetUserList(c *gin.Context) {
|
||||
"_id": bson.M{"$in": roleIds},
|
||||
}, nil)
|
||||
if err != nil {
|
||||
HandleErrorInternalServerError(c, err)
|
||||
return
|
||||
return GetErrorListResponse[models.User](err)
|
||||
}
|
||||
rolesMap := make(map[primitive.ObjectID]models.Role)
|
||||
for _, role := range roles {
|
||||
@@ -80,149 +76,124 @@ func GetUserList(c *gin.Context) {
|
||||
// total count
|
||||
total, err := service.NewModelService[models.User]().Count(query)
|
||||
if err != nil {
|
||||
HandleErrorInternalServerError(c, err)
|
||||
return
|
||||
return GetErrorListResponse[models.User](err)
|
||||
}
|
||||
|
||||
// response
|
||||
HandleSuccessWithListData(c, users, total)
|
||||
return GetListResponse[models.User](users, total)
|
||||
}
|
||||
|
||||
func PostUser(c *gin.Context) {
|
||||
var payload struct {
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
Role string `json:"role"`
|
||||
RoleId primitive.ObjectID `json:"role_id"`
|
||||
Email string `json:"email"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&payload); err != nil {
|
||||
HandleErrorBadRequest(c, err)
|
||||
return
|
||||
}
|
||||
type PostUserParams struct {
|
||||
Username string `json:"username" validate:"required"`
|
||||
Password string `json:"password" validate:"required"`
|
||||
Role string `json:"role"`
|
||||
RoleId primitive.ObjectID `json:"role_id"`
|
||||
Email string `json:"email"`
|
||||
}
|
||||
|
||||
func PostUser(c *gin.Context, params *PostUserParams) (response *Response[models.User], err error) {
|
||||
// Validate email format
|
||||
if payload.Email != "" {
|
||||
if params.Email != "" {
|
||||
emailRegex := regexp.MustCompile(`^[a-zA-Z0-9._%+\-]+@[a-zA-Z0-9.\-]+\.[a-zA-Z]{2,}$`)
|
||||
if !emailRegex.MatchString(payload.Email) {
|
||||
HandleErrorBadRequest(c, fmt.Errorf("invalid email format"))
|
||||
return
|
||||
if !emailRegex.MatchString(params.Email) {
|
||||
return GetErrorResponse[models.User](errors.BadRequestf("invalid email format"))
|
||||
}
|
||||
}
|
||||
|
||||
if !payload.RoleId.IsZero() {
|
||||
_, err := service.NewModelService[models.Role]().GetById(payload.RoleId)
|
||||
if !params.RoleId.IsZero() {
|
||||
_, err := service.NewModelService[models.Role]().GetById(params.RoleId)
|
||||
if err != nil {
|
||||
HandleErrorBadRequest(c, err)
|
||||
return
|
||||
return GetErrorResponse[models.User](errors.BadRequestf("role not found: %v", err))
|
||||
}
|
||||
}
|
||||
u := GetUserFromContext(c)
|
||||
model := models.User{
|
||||
Username: payload.Username,
|
||||
Password: utils.EncryptMd5(payload.Password),
|
||||
Role: payload.Role,
|
||||
RoleId: payload.RoleId,
|
||||
Email: payload.Email,
|
||||
Username: params.Username,
|
||||
Password: utils.EncryptMd5(params.Password),
|
||||
Role: params.Role,
|
||||
RoleId: params.RoleId,
|
||||
Email: params.Email,
|
||||
}
|
||||
model.SetCreated(u.Id)
|
||||
model.SetUpdated(u.Id)
|
||||
id, err := service.NewModelService[models.User]().InsertOne(model)
|
||||
if err != nil {
|
||||
HandleErrorInternalServerError(c, err)
|
||||
return
|
||||
return GetErrorResponse[models.User](err)
|
||||
}
|
||||
|
||||
result, err := service.NewModelService[models.User]().GetById(id)
|
||||
if err != nil {
|
||||
HandleErrorInternalServerError(c, err)
|
||||
return
|
||||
return GetErrorResponse[models.User](err)
|
||||
}
|
||||
|
||||
HandleSuccessWithData(c, result)
|
||||
return GetDataResponse[models.User](*result)
|
||||
}
|
||||
|
||||
func PutUserById(c *gin.Context) {
|
||||
id, err := primitive.ObjectIDFromHex(c.Param("id"))
|
||||
func PutUserById(c *gin.Context, params *PutByIdParams[models.User]) (response *Response[models.User], err error) {
|
||||
id, err := primitive.ObjectIDFromHex(params.Id)
|
||||
if err != nil {
|
||||
HandleErrorBadRequest(c, err)
|
||||
return
|
||||
return GetErrorResponse[models.User](errors.BadRequestf("invalid user id: %v", err))
|
||||
}
|
||||
putUser(id, c)
|
||||
return putUser(id, GetUserFromContext(c).Id, params.Data)
|
||||
}
|
||||
|
||||
func PostUserChangePassword(c *gin.Context) {
|
||||
// get id
|
||||
type PostUserChangePasswordParams struct {
|
||||
Id string `path:"id"`
|
||||
Password string `json:"password" validate:"required"`
|
||||
}
|
||||
|
||||
func PostUserChangePassword(c *gin.Context, params *PostUserChangePasswordParams) (response *Response[models.User], err error) {
|
||||
id, err := primitive.ObjectIDFromHex(c.Param("id"))
|
||||
if err != nil {
|
||||
HandleErrorBadRequest(c, err)
|
||||
return
|
||||
return GetErrorResponse[models.User](errors.BadRequestf("invalid user id: %v", err))
|
||||
}
|
||||
|
||||
postUserChangePassword(id, c)
|
||||
return postUserChangePassword(id, GetUserFromContext(c).Id, params.Password)
|
||||
}
|
||||
|
||||
func DeleteUserById(c *gin.Context) {
|
||||
id, err := primitive.ObjectIDFromHex(c.Param("id"))
|
||||
func DeleteUserById(_ *gin.Context, params *DeleteByIdParams) (response *Response[models.User], err error) {
|
||||
id, err := primitive.ObjectIDFromHex(params.Id)
|
||||
if err != nil {
|
||||
HandleErrorBadRequest(c, err)
|
||||
return
|
||||
return GetErrorResponse[models.User](errors.BadRequestf("invalid user id: %v", err))
|
||||
}
|
||||
|
||||
user, err := service.NewModelService[models.User]().GetById(id)
|
||||
if err != nil {
|
||||
HandleErrorInternalServerError(c, err)
|
||||
return
|
||||
return GetErrorResponse[models.User](err)
|
||||
}
|
||||
if user.RootAdmin {
|
||||
HandleErrorForbidden(c, errors.New("root admin cannot be deleted"))
|
||||
return
|
||||
return GetErrorResponse[models.User](errors.New("root admin cannot be deleted"))
|
||||
}
|
||||
|
||||
if err := service.NewModelService[models.User]().DeleteById(id); err != nil {
|
||||
HandleErrorInternalServerError(c, err)
|
||||
return
|
||||
return GetErrorResponse[models.User](err)
|
||||
}
|
||||
|
||||
HandleSuccess(c)
|
||||
return GetDataResponse[models.User](models.User{})
|
||||
}
|
||||
|
||||
func DeleteUserList(c *gin.Context) {
|
||||
type Payload struct {
|
||||
Ids []string `json:"ids"`
|
||||
}
|
||||
|
||||
var payload Payload
|
||||
if err := c.ShouldBindJSON(&payload); err != nil {
|
||||
HandleErrorBadRequest(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
func DeleteUserList(_ *gin.Context, params *DeleteListParams) (response *Response[models.User], err error) {
|
||||
// Convert string IDs to ObjectIDs
|
||||
var ids []primitive.ObjectID
|
||||
for _, id := range payload.Ids {
|
||||
for _, id := range params.Ids {
|
||||
objectId, err := primitive.ObjectIDFromHex(id)
|
||||
if err != nil {
|
||||
HandleErrorBadRequest(c, err)
|
||||
return
|
||||
return GetErrorResponse[models.User](errors.BadRequestf("invalid user id: %v", err))
|
||||
}
|
||||
ids = append(ids, objectId)
|
||||
}
|
||||
|
||||
// Check if root admin is in the list
|
||||
_, err := service.NewModelService[models.User]().GetOne(bson.M{
|
||||
_, err = service.NewModelService[models.User]().GetOne(bson.M{
|
||||
"_id": bson.M{
|
||||
"$in": ids,
|
||||
},
|
||||
"root_admin": true,
|
||||
}, nil)
|
||||
if err == nil {
|
||||
HandleErrorForbidden(c, errors.New("root admin cannot be deleted"))
|
||||
return
|
||||
return GetErrorResponse[models.User](errors.New("root admin cannot be deleted"))
|
||||
}
|
||||
if !errors.Is(err, mongo2.ErrNoDocuments) {
|
||||
HandleErrorInternalServerError(c, err)
|
||||
return
|
||||
return GetErrorResponse[models.User](err)
|
||||
}
|
||||
|
||||
// Delete users
|
||||
@@ -231,34 +202,43 @@ func DeleteUserList(c *gin.Context) {
|
||||
"$in": ids,
|
||||
},
|
||||
}); err != nil {
|
||||
HandleErrorInternalServerError(c, err)
|
||||
return
|
||||
return GetErrorResponse[models.User](err)
|
||||
}
|
||||
|
||||
HandleSuccess(c)
|
||||
return GetDataResponse[models.User](models.User{})
|
||||
}
|
||||
|
||||
func GetUserMe(c *gin.Context) {
|
||||
func GetUserMe(c *gin.Context) (response *Response[models.User], err error) {
|
||||
u := GetUserFromContext(c)
|
||||
getUserByIdWithRoutes(u.Id, c)
|
||||
return getUserByIdWithRoutes(u.Id)
|
||||
}
|
||||
|
||||
func PutUserMe(c *gin.Context) {
|
||||
type PutUserMeParams struct {
|
||||
Data models.User `json:"data"`
|
||||
}
|
||||
|
||||
func PutUserMe(c *gin.Context, params *PutUserMeParams) (response *Response[models.User], err error) {
|
||||
u := GetUserFromContext(c)
|
||||
putUser(u.Id, c)
|
||||
return putUser(u.Id, u.Id, params.Data)
|
||||
}
|
||||
|
||||
func PostUserMeChangePassword(c *gin.Context) {
|
||||
type PostUserMeChangePasswordParams struct {
|
||||
Password string `json:"password" validate:"required"`
|
||||
}
|
||||
|
||||
func PostUserMeChangePassword(c *gin.Context, params *PostUserMeChangePasswordParams) (response *Response[models.User], err error) {
|
||||
u := GetUserFromContext(c)
|
||||
postUserChangePassword(u.Id, c)
|
||||
return postUserChangePassword(u.Id, u.Id, params.Password)
|
||||
}
|
||||
|
||||
func getUserById(userId primitive.ObjectID, c *gin.Context) {
|
||||
func getUserById(userId primitive.ObjectID) (response *Response[models.User], err error) {
|
||||
// get user
|
||||
user, err := service.NewModelService[models.User]().GetById(userId)
|
||||
if err != nil {
|
||||
HandleErrorInternalServerError(c, err)
|
||||
return
|
||||
if errors.Is(err, mongo2.ErrNoDocuments) {
|
||||
return GetErrorResponse[models.User](errors.BadRequestf("user not found: %v", err))
|
||||
}
|
||||
return GetErrorResponse[models.User](err)
|
||||
}
|
||||
|
||||
// get role
|
||||
@@ -266,61 +246,58 @@ func getUserById(userId primitive.ObjectID, c *gin.Context) {
|
||||
if !user.RoleId.IsZero() {
|
||||
role, err := service.NewModelService[models.Role]().GetById(user.RoleId)
|
||||
if err != nil {
|
||||
HandleErrorInternalServerError(c, err)
|
||||
return
|
||||
return GetErrorResponse[models.User](errors.BadRequestf("role not found: %v", err))
|
||||
}
|
||||
user.Role = role.Name
|
||||
user.RootAdminRole = role.RootAdmin
|
||||
}
|
||||
}
|
||||
|
||||
HandleSuccessWithData(c, user)
|
||||
return GetDataResponse[models.User](*user)
|
||||
}
|
||||
|
||||
func getUserByIdWithRoutes(userId primitive.ObjectID, c *gin.Context) {
|
||||
func getUserByIdWithRoutes(userId primitive.ObjectID) (response *Response[models.User], err error) {
|
||||
if !utils.IsPro() {
|
||||
getUserById(userId, c)
|
||||
return
|
||||
return getUserById(userId)
|
||||
}
|
||||
|
||||
// get user
|
||||
user, err := service.NewModelService[models.User]().GetById(userId)
|
||||
if err != nil {
|
||||
HandleErrorInternalServerError(c, err)
|
||||
return
|
||||
if errors.Is(err, mongo2.ErrNoDocuments) {
|
||||
return GetErrorResponse[models.User](errors.BadRequestf("user not found: %v", err))
|
||||
}
|
||||
return GetErrorResponse[models.User](err)
|
||||
}
|
||||
|
||||
// get role
|
||||
if !user.RoleId.IsZero() {
|
||||
role, err := service.NewModelService[models.Role]().GetById(user.RoleId)
|
||||
if err != nil {
|
||||
HandleErrorInternalServerError(c, err)
|
||||
return
|
||||
if errors.Is(err, mongo2.ErrNoDocuments) {
|
||||
return GetErrorResponse[models.User](errors.BadRequestf("role not found: %v", err))
|
||||
}
|
||||
return GetErrorResponse[models.User](err)
|
||||
}
|
||||
user.Role = role.Name
|
||||
user.RootAdminRole = role.RootAdmin
|
||||
user.Routes = role.Routes
|
||||
}
|
||||
|
||||
HandleSuccessWithData(c, user)
|
||||
return GetDataResponse[models.User](*user)
|
||||
}
|
||||
|
||||
func putUser(userId primitive.ObjectID, c *gin.Context) {
|
||||
// get payload
|
||||
var user models.User
|
||||
if err := c.ShouldBindJSON(&user); err != nil {
|
||||
HandleErrorBadRequest(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
func putUser(userId, by primitive.ObjectID, user models.User) (response *Response[models.User], err error) {
|
||||
// model service
|
||||
modelSvc := service.NewModelService[models.User]()
|
||||
|
||||
// update user
|
||||
userDb, err := modelSvc.GetById(userId)
|
||||
if err != nil {
|
||||
HandleErrorInternalServerError(c, err)
|
||||
return
|
||||
if errors.Is(err, mongo2.ErrNoDocuments) {
|
||||
return GetErrorResponse[models.User](errors.BadRequestf("user not found: %v", err))
|
||||
}
|
||||
return GetErrorResponse[models.User](err)
|
||||
}
|
||||
|
||||
// if root admin, disallow changing username and role
|
||||
@@ -332,53 +309,34 @@ func putUser(userId primitive.ObjectID, c *gin.Context) {
|
||||
// disallow changing password
|
||||
user.Password = userDb.Password
|
||||
|
||||
// current user
|
||||
u := GetUserFromContext(c)
|
||||
|
||||
// update user
|
||||
user.SetUpdated(u.Id)
|
||||
user.SetUpdated(by)
|
||||
if user.Id.IsZero() {
|
||||
user.Id = userId
|
||||
}
|
||||
if err := modelSvc.ReplaceById(userId, user); err != nil {
|
||||
HandleErrorInternalServerError(c, err)
|
||||
return
|
||||
return GetErrorResponse[models.User](err)
|
||||
}
|
||||
|
||||
// handle success
|
||||
HandleSuccess(c)
|
||||
return GetDataResponse[models.User](user)
|
||||
}
|
||||
|
||||
func postUserChangePassword(userId primitive.ObjectID, c *gin.Context) {
|
||||
// get payload
|
||||
var payload struct {
|
||||
Password string `json:"password"`
|
||||
func postUserChangePassword(userId, by primitive.ObjectID, password string) (response *Response[models.User], err error) {
|
||||
if len(password) < 5 {
|
||||
return GetErrorResponse[models.User](errors.BadRequestf("password must be at least 5 characters"))
|
||||
}
|
||||
if err := c.ShouldBindJSON(&payload); err != nil {
|
||||
HandleErrorBadRequest(c, err)
|
||||
return
|
||||
}
|
||||
if len(payload.Password) < 5 {
|
||||
HandleErrorBadRequest(c, errors.New("password must be at least 5 characters"))
|
||||
return
|
||||
}
|
||||
|
||||
// current user
|
||||
u := GetUserFromContext(c)
|
||||
|
||||
// update password
|
||||
userDb, err := service.NewModelService[models.User]().GetById(userId)
|
||||
if err != nil {
|
||||
HandleErrorInternalServerError(c, err)
|
||||
return
|
||||
return GetErrorResponse[models.User](err)
|
||||
}
|
||||
userDb.SetUpdated(u.Id)
|
||||
userDb.Password = utils.EncryptMd5(payload.Password)
|
||||
userDb.SetUpdated(by)
|
||||
userDb.Password = utils.EncryptMd5(password)
|
||||
if err := service.NewModelService[models.User]().ReplaceById(userDb.Id, *userDb); err != nil {
|
||||
HandleErrorInternalServerError(c, err)
|
||||
return
|
||||
return GetErrorResponse[models.User](err)
|
||||
}
|
||||
|
||||
// handle success
|
||||
HandleSuccess(c)
|
||||
return GetDataResponse[models.User](models.User{})
|
||||
}
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package controllers_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
@@ -12,7 +14,7 @@ import (
|
||||
"github.com/crawlab-team/crawlab/core/models/models"
|
||||
"github.com/crawlab-team/crawlab/core/models/service"
|
||||
"github.com/crawlab-team/crawlab/core/user"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/loopfz/gadgeto/tonic"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
@@ -36,9 +38,9 @@ func TestGetUserById_Success(t *testing.T) {
|
||||
require.Nil(t, err)
|
||||
u.SetId(id)
|
||||
|
||||
router := gin.Default()
|
||||
router := SetupRouter()
|
||||
router.Use(middlewares.AuthorizationMiddleware())
|
||||
router.GET("/users/:id", controllers.GetUserById)
|
||||
router.GET("/users/:id", nil, tonic.Handler(controllers.GetUserById, 200))
|
||||
|
||||
// Test valid ID
|
||||
req, err := http.NewRequest(http.MethodGet, "/users/"+id.Hex(), nil)
|
||||
@@ -79,9 +81,9 @@ func TestGetUserList_Success(t *testing.T) {
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
|
||||
router := gin.Default()
|
||||
router := SetupRouter()
|
||||
router.Use(middlewares.AuthorizationMiddleware())
|
||||
router.GET("/users", controllers.GetUserList)
|
||||
router.GET("/users", nil, tonic.Handler(controllers.GetUserList, 200))
|
||||
|
||||
// Test default pagination
|
||||
req, err := http.NewRequest(http.MethodGet, "/users", nil)
|
||||
@@ -108,9 +110,9 @@ func TestPostUser_Success(t *testing.T) {
|
||||
SetupTestDB()
|
||||
defer CleanupTestDB()
|
||||
|
||||
router := gin.Default()
|
||||
router := SetupRouter()
|
||||
router.Use(middlewares.AuthorizationMiddleware())
|
||||
router.POST("/users", controllers.PostUser)
|
||||
router.POST("/users", nil, tonic.Handler(controllers.PostUser, 200))
|
||||
|
||||
// Test creating a new user with valid data
|
||||
reqBody := strings.NewReader(`{
|
||||
@@ -161,9 +163,9 @@ func TestPutUserById_Success(t *testing.T) {
|
||||
require.Nil(t, err)
|
||||
u.SetId(id)
|
||||
|
||||
router := gin.Default()
|
||||
router := SetupRouter()
|
||||
router.Use(middlewares.AuthorizationMiddleware())
|
||||
router.PUT("/users/:id", controllers.PutUserById)
|
||||
router.PUT("/users/:id", nil, tonic.Handler(controllers.PutUserById, 200))
|
||||
|
||||
// Test case 1: Regular user update
|
||||
reqBody := strings.NewReader(`{
|
||||
@@ -214,9 +216,9 @@ func TestPostUserChangePassword_Success(t *testing.T) {
|
||||
require.Nil(t, err)
|
||||
u.SetId(id)
|
||||
|
||||
router := gin.Default()
|
||||
router := SetupRouter()
|
||||
router.Use(middlewares.AuthorizationMiddleware())
|
||||
router.POST("/users/:id/change-password", controllers.PostUserChangePassword)
|
||||
router.POST("/users/:id/change-password", nil, tonic.Handler(controllers.PostUserChangePassword, 200))
|
||||
|
||||
// Add validation for minimum password length
|
||||
// Test case 1: Valid password
|
||||
@@ -252,9 +254,9 @@ func TestGetUserMe_Success(t *testing.T) {
|
||||
require.Nil(t, err)
|
||||
u.SetId(id)
|
||||
|
||||
router := gin.Default()
|
||||
router := SetupRouter()
|
||||
router.Use(middlewares.AuthorizationMiddleware())
|
||||
router.GET("/users/me", controllers.GetUserMe)
|
||||
router.GET("/users/me", nil, tonic.Handler(controllers.GetUserMe, 200))
|
||||
|
||||
req, _ := http.NewRequest(http.MethodGet, "/users/me", nil)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
@@ -288,23 +290,26 @@ func TestPutUserMe_Success(t *testing.T) {
|
||||
require.Nil(t, err)
|
||||
|
||||
// Create router
|
||||
router := gin.Default()
|
||||
router := SetupRouter()
|
||||
router.Use(middlewares.AuthorizationMiddleware())
|
||||
router.PUT("/users/me", controllers.PutUserMe)
|
||||
router.PUT("/users/me", nil, tonic.Handler(controllers.PutUserMe, 200))
|
||||
|
||||
// Test valid update
|
||||
reqBody := strings.NewReader(`{
|
||||
"username": "updateduser",
|
||||
"email": "updated@example.com"
|
||||
}`)
|
||||
req, err := http.NewRequest(http.MethodPut, "/users/me", reqBody)
|
||||
reqParams := controllers.PutUserMeParams{
|
||||
Data: models.User{
|
||||
Username: "updateduser",
|
||||
Email: "updated@example.com",
|
||||
},
|
||||
}
|
||||
jsonValue, _ := json.Marshal(reqParams)
|
||||
req, err := http.NewRequest(http.MethodPut, "/users/me", bytes.NewBuffer(jsonValue))
|
||||
assert.Nil(t, err)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", token)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Equalf(t, http.StatusOK, w.Code, "response body: %s", w.Body.String())
|
||||
|
||||
// Verify the update
|
||||
updatedUser, err := modelSvc.GetById(id)
|
||||
@@ -338,9 +343,9 @@ func TestPostUserMeChangePassword_Success(t *testing.T) {
|
||||
require.Nil(t, err)
|
||||
|
||||
// Create router
|
||||
router := gin.Default()
|
||||
router := SetupRouter()
|
||||
router.Use(middlewares.AuthorizationMiddleware())
|
||||
router.POST("/users/me/change-password", controllers.PostUserMeChangePassword)
|
||||
router.POST("/users/me/change-password", nil, tonic.Handler(controllers.PostUserMeChangePassword, 200))
|
||||
|
||||
// Test valid password change
|
||||
password := "newValidPassword123"
|
||||
@@ -388,9 +393,9 @@ func TestDeleteUserById_Success(t *testing.T) {
|
||||
require.Nil(t, err)
|
||||
u.SetId(id)
|
||||
|
||||
router := gin.Default()
|
||||
router := SetupRouter()
|
||||
router.Use(middlewares.AuthorizationMiddleware())
|
||||
router.DELETE("/users/:id", controllers.DeleteUserById)
|
||||
router.DELETE("/users/:id", nil, tonic.Handler(controllers.DeleteUserById, 200))
|
||||
|
||||
// Test deleting normal user
|
||||
req, err := http.NewRequest(http.MethodDelete, "/users/"+id.Hex(), nil)
|
||||
@@ -423,7 +428,7 @@ func TestDeleteUserById_Success(t *testing.T) {
|
||||
|
||||
w = httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusForbidden, w.Code)
|
||||
assert.Equalf(t, http.StatusBadRequest, w.Code, "response body: %s", w.Body.String())
|
||||
|
||||
// Test deleting with invalid ID
|
||||
req, err = http.NewRequest(http.MethodDelete, "/users/invalid-id", nil)
|
||||
@@ -460,9 +465,9 @@ func TestDeleteUserList_Success(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
router := gin.Default()
|
||||
router := SetupRouter()
|
||||
router.Use(middlewares.AuthorizationMiddleware())
|
||||
router.DELETE("/users", controllers.DeleteUserList)
|
||||
router.DELETE("/users", nil, tonic.Handler(controllers.DeleteUserList, 200))
|
||||
|
||||
// Test deleting normal users
|
||||
reqBody := strings.NewReader(fmt.Sprintf(`{"ids":["%s","%s"]}`,
|
||||
@@ -492,7 +497,7 @@ func TestDeleteUserList_Success(t *testing.T) {
|
||||
|
||||
w = httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusForbidden, w.Code)
|
||||
assert.Equalf(t, http.StatusBadRequest, w.Code, "response body: %s", w.Body.String())
|
||||
|
||||
// Test with mix of valid and invalid ids
|
||||
reqBody = strings.NewReader(fmt.Sprintf(`{"ids":["%s","invalid-id"]}`, normalUserIds[0].Hex()))
|
||||
|
||||
331
core/controllers/utils.go
Normal file
331
core/controllers/utils.go
Normal file
@@ -0,0 +1,331 @@
|
||||
package controllers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"github.com/crawlab-team/crawlab/core/constants"
|
||||
"github.com/crawlab-team/crawlab/core/entity"
|
||||
"github.com/crawlab-team/crawlab/core/models/models"
|
||||
"github.com/crawlab-team/crawlab/core/mongo"
|
||||
"github.com/crawlab-team/crawlab/core/utils"
|
||||
"github.com/crawlab-team/crawlab/trace"
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
"net/http"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
var logger = utils.NewLogger("Controllers")
|
||||
|
||||
func GetUserFromContext(c *gin.Context) (u *models.User) {
|
||||
value, ok := c.Get(constants.UserContextKey)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
u, ok = value.(*models.User)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return u
|
||||
}
|
||||
|
||||
func GetFilterQueryFromListParams(params *GetListParams) (q bson.M, err error) {
|
||||
if params.Conditions == "" {
|
||||
return nil, nil
|
||||
}
|
||||
conditions, err := GetFilterFromConditionString(params.Conditions)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return utils.FilterToQuery(conditions)
|
||||
}
|
||||
|
||||
func GetFilterQueryFromConditionString(condStr string) (q bson.M, err error) {
|
||||
conditions, err := GetFilterFromConditionString(condStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return utils.FilterToQuery(conditions)
|
||||
}
|
||||
|
||||
func GetFilterFromConditionString(condStr string) (f *entity.Filter, err error) {
|
||||
var conditions []*entity.Condition
|
||||
if err := json.Unmarshal([]byte(condStr), &conditions); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for i, cond := range conditions {
|
||||
v := reflect.ValueOf(cond.Value)
|
||||
switch v.Kind() {
|
||||
case reflect.String:
|
||||
item := cond.Value.(string)
|
||||
// attempt to convert object id
|
||||
id, err := primitive.ObjectIDFromHex(item)
|
||||
if err == nil {
|
||||
conditions[i].Value = id
|
||||
} else {
|
||||
conditions[i].Value = item
|
||||
}
|
||||
case reflect.Slice, reflect.Array:
|
||||
var items []interface{}
|
||||
for i := 0; i < v.Len(); i++ {
|
||||
vItem := v.Index(i)
|
||||
item := vItem.Interface()
|
||||
|
||||
// string
|
||||
stringItem, ok := item.(string)
|
||||
if ok {
|
||||
id, err := primitive.ObjectIDFromHex(stringItem)
|
||||
if err == nil {
|
||||
items = append(items, id)
|
||||
} else {
|
||||
items = append(items, stringItem)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// default
|
||||
items = append(items, item)
|
||||
}
|
||||
conditions[i].Value = items
|
||||
default:
|
||||
return nil, errors.New("invalid type")
|
||||
}
|
||||
}
|
||||
|
||||
return &entity.Filter{
|
||||
IsOr: false,
|
||||
Conditions: conditions,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetFilter Get entity.Filter from gin.Context
|
||||
func GetFilter(c *gin.Context) (f *entity.Filter, err error) {
|
||||
condStr := c.Query(constants.FilterQueryFieldConditions)
|
||||
return GetFilterFromConditionString(condStr)
|
||||
}
|
||||
|
||||
// GetFilterQuery Get bson.M from gin.Context
|
||||
func GetFilterQuery(c *gin.Context) (q bson.M, err error) {
|
||||
f, err := GetFilter(c)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if f == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// TODO: implement logic OR
|
||||
|
||||
return utils.FilterToQuery(f)
|
||||
}
|
||||
|
||||
func MustGetFilterQuery(c *gin.Context) (q bson.M) {
|
||||
q, err := GetFilterQuery(c)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return q
|
||||
}
|
||||
|
||||
func getResultListQuery(c *gin.Context) (q mongo.ListQuery) {
|
||||
f, err := GetFilter(c)
|
||||
if err != nil {
|
||||
return q
|
||||
}
|
||||
for _, cond := range f.Conditions {
|
||||
q = append(q, mongo.ListQueryCondition{
|
||||
Key: cond.Key,
|
||||
Op: cond.Op,
|
||||
Value: utils.NormalizeObjectId(cond.Value),
|
||||
})
|
||||
}
|
||||
return q
|
||||
}
|
||||
|
||||
func GetDefaultPagination() (p *entity.Pagination) {
|
||||
return &entity.Pagination{
|
||||
Page: constants.PaginationDefaultPage,
|
||||
Size: constants.PaginationDefaultSize,
|
||||
}
|
||||
}
|
||||
|
||||
func GetPagination(c *gin.Context) (p *entity.Pagination, err error) {
|
||||
var _p entity.Pagination
|
||||
if err := c.ShouldBindQuery(&_p); err != nil {
|
||||
return GetDefaultPagination(), err
|
||||
}
|
||||
if _p.Page == 0 {
|
||||
_p.Page = constants.PaginationDefaultPage
|
||||
}
|
||||
if _p.Size == 0 {
|
||||
_p.Size = constants.PaginationDefaultSize
|
||||
}
|
||||
return &_p, nil
|
||||
}
|
||||
|
||||
func MustGetPagination(c *gin.Context) (p *entity.Pagination) {
|
||||
p, err := GetPagination(c)
|
||||
if err != nil || p == nil {
|
||||
return GetDefaultPagination()
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
// GetSorts Get entity.Sort from gin.Context
|
||||
func GetSorts(c *gin.Context) (sorts []entity.Sort, err error) {
|
||||
// bind
|
||||
sortStr := c.Query(constants.SortQueryField)
|
||||
if err := json.Unmarshal([]byte(sortStr), &sorts); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return sorts, nil
|
||||
}
|
||||
|
||||
// GetSortsOption Get entity.Sort from gin.Context
|
||||
func GetSortsOption(c *gin.Context) (sort bson.D, err error) {
|
||||
sorts, err := GetSorts(c)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if sorts == nil || len(sorts) == 0 {
|
||||
return bson.D{{"_id", -1}}, nil
|
||||
}
|
||||
|
||||
return SortsToOption(sorts)
|
||||
}
|
||||
|
||||
func MustGetSortOption(c *gin.Context) (sort bson.D) {
|
||||
sort, err := GetSortsOption(c)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return sort
|
||||
}
|
||||
|
||||
// SortsToOption Translate entity.Sort to bson.D
|
||||
func SortsToOption(sorts []entity.Sort) (sort bson.D, err error) {
|
||||
sort = bson.D{}
|
||||
for _, s := range sorts {
|
||||
switch s.Direction {
|
||||
case constants.ASCENDING:
|
||||
sort = append(sort, bson.E{Key: s.Key, Value: 1})
|
||||
case constants.DESCENDING:
|
||||
sort = append(sort, bson.E{Key: s.Key, Value: -1})
|
||||
}
|
||||
}
|
||||
if len(sort) == 0 {
|
||||
sort = bson.D{{"_id", -1}}
|
||||
}
|
||||
return sort, nil
|
||||
}
|
||||
|
||||
type Response[T any] struct {
|
||||
Status string `json:"status"`
|
||||
Message string `json:"message"`
|
||||
Data T `json:"data"`
|
||||
Error string `json:"error"`
|
||||
}
|
||||
|
||||
type ListResponse[T any] struct {
|
||||
Status string `json:"status"`
|
||||
Message string `json:"message"`
|
||||
Total int `json:"total"`
|
||||
Data []T `json:"data"`
|
||||
Error string `json:"error"`
|
||||
}
|
||||
|
||||
func GetDataResponse[T any](model T) (res *Response[T], err error) {
|
||||
return &Response[T]{
|
||||
Status: constants.HttpResponseStatusOk,
|
||||
Message: constants.HttpResponseMessageSuccess,
|
||||
Data: model,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func GetListResponse[T any](models []T, total int) (res *ListResponse[T], err error) {
|
||||
return &ListResponse[T]{
|
||||
Status: constants.HttpResponseStatusOk,
|
||||
Message: constants.HttpResponseMessageSuccess,
|
||||
Data: models,
|
||||
Total: total,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func GetErrorResponse[T any](err error) (res *Response[T], err2 error) {
|
||||
return &Response[T]{
|
||||
Status: constants.HttpResponseStatusOk,
|
||||
Message: constants.HttpResponseMessageError,
|
||||
Error: err.Error(),
|
||||
}, err
|
||||
}
|
||||
|
||||
func GetErrorListResponse[T any](err error) (res *ListResponse[T], err2 error) {
|
||||
return &ListResponse[T]{
|
||||
Status: constants.HttpResponseStatusOk,
|
||||
Message: constants.HttpResponseMessageError,
|
||||
Error: err.Error(),
|
||||
}, err
|
||||
}
|
||||
|
||||
func handleError(statusCode int, c *gin.Context, err error) {
|
||||
if utils.IsDev() {
|
||||
trace.PrintError(err)
|
||||
}
|
||||
c.AbortWithStatusJSON(statusCode, entity.Response{
|
||||
Status: constants.HttpResponseStatusOk,
|
||||
Message: constants.HttpResponseMessageError,
|
||||
Error: err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
func HandleError(statusCode int, c *gin.Context, err error) {
|
||||
handleError(statusCode, c, err)
|
||||
}
|
||||
|
||||
func HandleErrorBadRequest(c *gin.Context, err error) {
|
||||
HandleError(http.StatusBadRequest, c, err)
|
||||
}
|
||||
|
||||
func HandleErrorForbidden(c *gin.Context, err error) {
|
||||
HandleError(http.StatusForbidden, c, err)
|
||||
}
|
||||
|
||||
func HandleErrorUnauthorized(c *gin.Context, err error) {
|
||||
HandleError(http.StatusUnauthorized, c, err)
|
||||
}
|
||||
|
||||
func HandleErrorNotFound(c *gin.Context, err error) {
|
||||
HandleError(http.StatusNotFound, c, err)
|
||||
}
|
||||
|
||||
func HandleErrorInternalServerError(c *gin.Context, err error) {
|
||||
HandleError(http.StatusInternalServerError, c, err)
|
||||
}
|
||||
|
||||
func HandleSuccess(c *gin.Context) {
|
||||
c.AbortWithStatusJSON(http.StatusOK, entity.Response{
|
||||
Status: constants.HttpResponseStatusOk,
|
||||
Message: constants.HttpResponseMessageSuccess,
|
||||
})
|
||||
}
|
||||
|
||||
func HandleSuccessWithData(c *gin.Context, data interface{}) {
|
||||
c.AbortWithStatusJSON(http.StatusOK, entity.Response{
|
||||
Status: constants.HttpResponseStatusOk,
|
||||
Message: constants.HttpResponseMessageSuccess,
|
||||
Data: data,
|
||||
})
|
||||
}
|
||||
|
||||
func HandleSuccessWithListData(c *gin.Context, data interface{}, total int) {
|
||||
c.AbortWithStatusJSON(http.StatusOK, entity.ListResponse{
|
||||
Status: constants.HttpResponseStatusOk,
|
||||
Message: constants.HttpResponseMessageSuccess,
|
||||
Data: data,
|
||||
Total: total,
|
||||
})
|
||||
}
|
||||
@@ -1,19 +0,0 @@
|
||||
package controllers
|
||||
|
||||
import (
|
||||
"github.com/crawlab-team/crawlab/core/constants"
|
||||
"github.com/crawlab-team/crawlab/core/models/models"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func GetUserFromContext(c *gin.Context) (u *models.User) {
|
||||
value, ok := c.Get(constants.UserContextKey)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
u, ok = value.(*models.User)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return u
|
||||
}
|
||||
@@ -1,141 +0,0 @@
|
||||
package controllers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
errors2 "errors"
|
||||
"github.com/crawlab-team/crawlab/core/constants"
|
||||
"github.com/crawlab-team/crawlab/core/entity"
|
||||
"github.com/crawlab-team/crawlab/core/mongo"
|
||||
"github.com/crawlab-team/crawlab/core/utils"
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
"reflect"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// GetFilter Get entity.Filter from gin.Context
|
||||
func GetFilter(c *gin.Context) (f *entity.Filter, err error) {
|
||||
// bind
|
||||
condStr := c.Query(constants.FilterQueryFieldConditions)
|
||||
var conditions []*entity.Condition
|
||||
if err := json.Unmarshal([]byte(condStr), &conditions); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// attempt to convert object id
|
||||
for i, cond := range conditions {
|
||||
v := reflect.ValueOf(cond.Value)
|
||||
switch v.Kind() {
|
||||
case reflect.String:
|
||||
item := cond.Value.(string)
|
||||
id, err := primitive.ObjectIDFromHex(item)
|
||||
if err == nil {
|
||||
conditions[i].Value = id
|
||||
} else {
|
||||
conditions[i].Value = item
|
||||
}
|
||||
case reflect.Slice, reflect.Array:
|
||||
var items []interface{}
|
||||
for i := 0; i < v.Len(); i++ {
|
||||
vItem := v.Index(i)
|
||||
item := vItem.Interface()
|
||||
|
||||
// string
|
||||
stringItem, ok := item.(string)
|
||||
if ok {
|
||||
id, err := primitive.ObjectIDFromHex(stringItem)
|
||||
if err == nil {
|
||||
items = append(items, id)
|
||||
} else {
|
||||
items = append(items, stringItem)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// default
|
||||
items = append(items, item)
|
||||
}
|
||||
conditions[i].Value = items
|
||||
default:
|
||||
return nil, errors2.New("invalid type")
|
||||
}
|
||||
}
|
||||
|
||||
return &entity.Filter{
|
||||
IsOr: false,
|
||||
Conditions: conditions,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetFilterQuery Get bson.M from gin.Context
|
||||
func GetFilterQuery(c *gin.Context) (q bson.M, err error) {
|
||||
f, err := GetFilter(c)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if f == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// TODO: implement logic OR
|
||||
|
||||
return utils.FilterToQuery(f)
|
||||
}
|
||||
|
||||
func MustGetFilterQuery(c *gin.Context) (q bson.M) {
|
||||
q, err := GetFilterQuery(c)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return q
|
||||
}
|
||||
|
||||
// GetFilterAll Get all from gin.Context
|
||||
func GetFilterAll(c *gin.Context) (res bool, err error) {
|
||||
resStr := c.Query(constants.FilterQueryFieldAll)
|
||||
switch strings.ToUpper(resStr) {
|
||||
case "1":
|
||||
return true, nil
|
||||
case "0":
|
||||
return false, nil
|
||||
case "Y":
|
||||
return true, nil
|
||||
case "N":
|
||||
return false, nil
|
||||
case "T":
|
||||
return true, nil
|
||||
case "F":
|
||||
return false, nil
|
||||
case "TRUE":
|
||||
return true, nil
|
||||
case "FALSE":
|
||||
return false, nil
|
||||
default:
|
||||
return false, errors2.New("invalid value")
|
||||
}
|
||||
}
|
||||
|
||||
func MustGetFilterAll(c *gin.Context) (res bool) {
|
||||
res, err := GetFilterAll(c)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func getResultListQuery(c *gin.Context) (q mongo.ListQuery) {
|
||||
f, err := GetFilter(c)
|
||||
if err != nil {
|
||||
return q
|
||||
}
|
||||
for _, cond := range f.Conditions {
|
||||
q = append(q, mongo.ListQueryCondition{
|
||||
Key: cond.Key,
|
||||
Op: cond.Op,
|
||||
Value: utils.NormalizeObjectId(cond.Value),
|
||||
})
|
||||
}
|
||||
return q
|
||||
}
|
||||
@@ -1,110 +0,0 @@
|
||||
package controllers
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/crawlab-team/crawlab/core/constants"
|
||||
"github.com/crawlab-team/crawlab/core/entity"
|
||||
"github.com/crawlab-team/crawlab/core/utils"
|
||||
"github.com/crawlab-team/crawlab/trace"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type Response[T any] struct {
|
||||
Status string `json:"status"`
|
||||
Message string `json:"message"`
|
||||
Data T `json:"data"`
|
||||
Error string `json:"error"`
|
||||
}
|
||||
|
||||
type ListResponse[T any] struct {
|
||||
Status string `json:"status"`
|
||||
Message string `json:"message"`
|
||||
Total int `json:"total"`
|
||||
Data []T `json:"data"`
|
||||
Error string `json:"error"`
|
||||
}
|
||||
|
||||
func GetSuccessDataResponse[T any](model T) (res *Response[T], err error) {
|
||||
return &Response[T]{
|
||||
Status: constants.HttpResponseStatusOk,
|
||||
Message: constants.HttpResponseMessageSuccess,
|
||||
Data: model,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func GetSuccessListResponse[T any](models []T, total int) (res *ListResponse[T], err error) {
|
||||
return &ListResponse[T]{
|
||||
Status: constants.HttpResponseStatusOk,
|
||||
Message: constants.HttpResponseMessageSuccess,
|
||||
Data: models,
|
||||
Total: total,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func GetErrorDataResponse[T any](err error) (res *Response[T], err2 error) {
|
||||
return &Response[T]{
|
||||
Status: constants.HttpResponseStatusOk,
|
||||
Message: constants.HttpResponseMessageError,
|
||||
Error: err.Error(),
|
||||
}, err
|
||||
}
|
||||
|
||||
func handleError(statusCode int, c *gin.Context, err error) {
|
||||
if utils.IsDev() {
|
||||
trace.PrintError(err)
|
||||
}
|
||||
c.AbortWithStatusJSON(statusCode, entity.Response{
|
||||
Status: constants.HttpResponseStatusOk,
|
||||
Message: constants.HttpResponseMessageError,
|
||||
Error: err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
func HandleError(statusCode int, c *gin.Context, err error) {
|
||||
handleError(statusCode, c, err)
|
||||
}
|
||||
|
||||
func HandleErrorBadRequest(c *gin.Context, err error) {
|
||||
HandleError(http.StatusBadRequest, c, err)
|
||||
}
|
||||
|
||||
func HandleErrorForbidden(c *gin.Context, err error) {
|
||||
HandleError(http.StatusForbidden, c, err)
|
||||
}
|
||||
|
||||
func HandleErrorUnauthorized(c *gin.Context, err error) {
|
||||
HandleError(http.StatusUnauthorized, c, err)
|
||||
}
|
||||
|
||||
func HandleErrorNotFound(c *gin.Context, err error) {
|
||||
HandleError(http.StatusNotFound, c, err)
|
||||
}
|
||||
|
||||
func HandleErrorInternalServerError(c *gin.Context, err error) {
|
||||
HandleError(http.StatusInternalServerError, c, err)
|
||||
}
|
||||
|
||||
func HandleSuccess(c *gin.Context) {
|
||||
c.AbortWithStatusJSON(http.StatusOK, entity.Response{
|
||||
Status: constants.HttpResponseStatusOk,
|
||||
Message: constants.HttpResponseMessageSuccess,
|
||||
})
|
||||
}
|
||||
|
||||
func HandleSuccessWithData(c *gin.Context, data interface{}) {
|
||||
c.AbortWithStatusJSON(http.StatusOK, entity.Response{
|
||||
Status: constants.HttpResponseStatusOk,
|
||||
Message: constants.HttpResponseMessageSuccess,
|
||||
Data: data,
|
||||
})
|
||||
}
|
||||
|
||||
func HandleSuccessWithListData(c *gin.Context, data interface{}, total int) {
|
||||
c.AbortWithStatusJSON(http.StatusOK, entity.ListResponse{
|
||||
Status: constants.HttpResponseStatusOk,
|
||||
Message: constants.HttpResponseMessageSuccess,
|
||||
Data: data,
|
||||
Total: total,
|
||||
})
|
||||
}
|
||||
@@ -1,5 +0,0 @@
|
||||
package controllers
|
||||
|
||||
import "github.com/crawlab-team/crawlab/core/utils"
|
||||
|
||||
var logger = utils.NewLogger("Controllers")
|
||||
@@ -1,36 +0,0 @@
|
||||
package controllers
|
||||
|
||||
import (
|
||||
"github.com/crawlab-team/crawlab/core/constants"
|
||||
"github.com/crawlab-team/crawlab/core/entity"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func GetDefaultPagination() (p *entity.Pagination) {
|
||||
return &entity.Pagination{
|
||||
Page: constants.PaginationDefaultPage,
|
||||
Size: constants.PaginationDefaultSize,
|
||||
}
|
||||
}
|
||||
|
||||
func GetPagination(c *gin.Context) (p *entity.Pagination, err error) {
|
||||
var _p entity.Pagination
|
||||
if err := c.ShouldBindQuery(&_p); err != nil {
|
||||
return GetDefaultPagination(), err
|
||||
}
|
||||
if _p.Page == 0 {
|
||||
_p.Page = constants.PaginationDefaultPage
|
||||
}
|
||||
if _p.Size == 0 {
|
||||
_p.Size = constants.PaginationDefaultSize
|
||||
}
|
||||
return &_p, nil
|
||||
}
|
||||
|
||||
func MustGetPagination(c *gin.Context) (p *entity.Pagination) {
|
||||
p, err := GetPagination(c)
|
||||
if err != nil || p == nil {
|
||||
return GetDefaultPagination()
|
||||
}
|
||||
return p
|
||||
}
|
||||
@@ -1,58 +0,0 @@
|
||||
package controllers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"github.com/crawlab-team/crawlab/core/constants"
|
||||
"github.com/crawlab-team/crawlab/core/entity"
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
)
|
||||
|
||||
// GetSorts Get entity.Sort from gin.Context
|
||||
func GetSorts(c *gin.Context) (sorts []entity.Sort, err error) {
|
||||
// bind
|
||||
sortStr := c.Query(constants.SortQueryField)
|
||||
if err := json.Unmarshal([]byte(sortStr), &sorts); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return sorts, nil
|
||||
}
|
||||
|
||||
// GetSortsOption Get entity.Sort from gin.Context
|
||||
func GetSortsOption(c *gin.Context) (sort bson.D, err error) {
|
||||
sorts, err := GetSorts(c)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if sorts == nil || len(sorts) == 0 {
|
||||
return bson.D{{"_id", -1}}, nil
|
||||
}
|
||||
|
||||
return SortsToOption(sorts)
|
||||
}
|
||||
|
||||
func MustGetSortOption(c *gin.Context) (sort bson.D) {
|
||||
sort, err := GetSortsOption(c)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return sort
|
||||
}
|
||||
|
||||
// SortsToOption Translate entity.Sort to bson.D
|
||||
func SortsToOption(sorts []entity.Sort) (sort bson.D, err error) {
|
||||
sort = bson.D{}
|
||||
for _, s := range sorts {
|
||||
switch s.Direction {
|
||||
case constants.ASCENDING:
|
||||
sort = append(sort, bson.E{Key: s.Key, Value: 1})
|
||||
case constants.DESCENDING:
|
||||
sort = append(sort, bson.E{Key: s.Key, Value: -1})
|
||||
}
|
||||
}
|
||||
if len(sort) == 0 {
|
||||
sort = bson.D{{"_id", -1}}
|
||||
}
|
||||
return sort, nil
|
||||
}
|
||||
@@ -1,54 +0,0 @@
|
||||
package controllers
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
"io"
|
||||
http2 "net/http"
|
||||
)
|
||||
|
||||
type WsWriter struct {
|
||||
io.Writer
|
||||
io.Closer
|
||||
conn *websocket.Conn
|
||||
}
|
||||
|
||||
func (w *WsWriter) Write(data []byte) (n int, err error) {
|
||||
logger.Infof("websocket write: %s", string(data))
|
||||
err = w.conn.WriteMessage(websocket.TextMessage, data)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
func (w *WsWriter) Close() (err error) {
|
||||
return w.conn.Close()
|
||||
}
|
||||
|
||||
func (w *WsWriter) CloseWithText(text string) {
|
||||
_ = w.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, text))
|
||||
}
|
||||
|
||||
func (w *WsWriter) CloseWithError(err error) {
|
||||
_ = w.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseInternalServerErr, err.Error()))
|
||||
}
|
||||
|
||||
func NewWsWriter(c *gin.Context) (writer *WsWriter, err error) {
|
||||
upgrader := websocket.Upgrader{
|
||||
ReadBufferSize: 1024,
|
||||
WriteBufferSize: 1024,
|
||||
CheckOrigin: func(r *http2.Request) bool {
|
||||
return true
|
||||
},
|
||||
}
|
||||
|
||||
conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
|
||||
if err != nil {
|
||||
logger.Errorf("websocket open connection error: %v", err)
|
||||
}
|
||||
|
||||
return &WsWriter{
|
||||
conn: conn,
|
||||
}, nil
|
||||
}
|
||||
Reference in New Issue
Block a user