Files
crawlab/core/controllers/base.go
Marvin Zhang d6badb533d feat: enhance API routing and OpenAPI documentation support
- Introduced a new OpenAPI wrapper using Fizz for improved API documentation
- Refactored base controller to support more flexible route handling
- Added dynamic route registration with OpenAPI metadata
- Implemented generic response types for consistent API responses
- Updated router initialization to support OpenAPI documentation endpoint
- Improved route and resource naming utilities
- Migrated existing controllers to use new routing and response mechanisms
2025-03-11 23:45:06 +08:00

302 lines
7.1 KiB
Go

package controllers
import (
"github.com/crawlab-team/crawlab/core/entity"
"github.com/crawlab-team/crawlab/core/interfaces"
"github.com/crawlab-team/crawlab/core/models/service"
"github.com/crawlab-team/crawlab/core/mongo"
"github.com/gin-gonic/gin"
"github.com/juju/errors"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/bson/primitive"
mongo2 "go.mongodb.org/mongo-driver/mongo"
)
// Action describes one additional route that can be attached to a controller.
type Action struct {
	// Method is the HTTP method of the route.
	Method string
	// Path is the route path.
	Path string
	// HandlerFunc is the handler for the route. It is stored untyped; the
	// accepted handler signatures are decided by whatever code registers
	// the action.
	HandlerFunc any
}
// BaseController provides generic CRUD handlers for the model type T.
// All persistence goes through modelSvc.
type BaseController[T any] struct {
// modelSvc is the model service every handler uses to read and write T.
modelSvc *service.ModelService[T]
// actions holds the extra actions passed to NewController. The handlers
// in this file do not read it.
actions []Action
}
// GetByIdParams represents parameters for the GetById method.
type GetByIdParams struct {
// Id is the hex-encoded ObjectID of the item; GetById rejects values that
// primitive.ObjectIDFromHex cannot parse.
Id string `path:"id" description:"The ID of the item to get"`
}
// GetAllParams represents parameters for the GetAll method.
type GetAllParams struct {
// Query is the filter passed to the model service for both fetching and counting.
Query bson.M `json:"query"`
// Sort is the sort specification; GetAll falls back to descending _id when it is nil.
Sort bson.D `json:"sort"`
}
// GetListParams represents parameters for GetList and GetWithPagination.
//
// GetList itself only reads All from this struct; it derives the query, sort
// and pagination from the gin context and passes a freshly built value on to
// GetWithPagination.
type GetListParams struct {
// Query is the filter passed to the model service.
Query bson.M `json:"query"`
// Sort is the sort specification passed to the model service.
Sort bson.D `json:"sort"`
// Pagination selects the page; GetWithPagination substitutes
// GetDefaultPagination() when it is nil.
Pagination *entity.Pagination `json:"pagination"`
// All requests the unpaginated result set when true.
All bool `query:"all" description:"Whether to get all items"`
}
// GetById looks up a single item by the hex ObjectID in params.Id.
//
// A value that is not a valid ObjectID hex string yields a bad-request error;
// any error from the model service is returned unchanged.
func (ctr *BaseController[T]) GetById(_ *gin.Context, params *GetByIdParams) (response *Response[T], err error) {
	oid, parseErr := primitive.ObjectIDFromHex(params.Id)
	if parseErr != nil {
		return nil, errors.BadRequestf("invalid id: %s", params.Id)
	}

	item, err := ctr.modelSvc.GetById(oid)
	if err != nil {
		return nil, err
	}

	return GetSuccessDataResponse(*item)
}
// GetList returns items matching the filter and sort options carried by the
// request context.
//
// When params.All is set, or the context's "all" filter is true, the full
// result set is returned via GetAll; otherwise one page is returned via
// GetWithPagination using the pagination read from the context.
func (ctr *BaseController[T]) GetList(c *gin.Context, params *GetListParams) (response *ListResponse[T], err error) {
	// The context is only consulted for "all" when the parameter is not set.
	fetchAll := params.All || MustGetFilterAll(c)

	// Filter and sort always come from the request context.
	filter := MustGetFilterQuery(c)
	order := MustGetSortOption(c)

	if !fetchAll {
		// Paginated listing.
		return ctr.GetWithPagination(c, &GetListParams{
			Query:      filter,
			Sort:       order,
			Pagination: MustGetPagination(c),
		})
	}

	// Unpaginated listing.
	return ctr.GetAll(c, &GetAllParams{
		Query: filter,
		Sort:  order,
	})
}
// Post creates a new item from the JSON request body.
//
// The body is bound into a T, which is then given a fresh ObjectID and has
// its created/updated fields set from the user on the context before being
// inserted. The stored document is read back and returned.
//
// Every failure (binding, T not implementing interfaces.Model, insert,
// unexpected inserted id type, read-back) is returned through
// GetErrorDataResponse instead of panicking.
func (ctr *BaseController[T]) Post(c *gin.Context) (response *Response[T], err error) {
	var model T
	if err := c.ShouldBindJSON(&model); err != nil {
		return GetErrorDataResponse[T](err)
	}

	u := GetUserFromContext(c)

	// T is only constrained by `any`, so check the assertion rather than
	// letting a non-Model type panic the handler.
	m, ok := any(&model).(interfaces.Model)
	if !ok {
		return GetErrorDataResponse[T](errors.Errorf("model type %T does not implement interfaces.Model", model))
	}
	m.SetId(primitive.NewObjectID())
	m.SetCreated(u.Id)
	m.SetUpdated(u.Id)

	col := ctr.modelSvc.GetCol()
	res, err := col.GetCollection().InsertOne(col.GetContext(), m)
	if err != nil {
		return GetErrorDataResponse[T](err)
	}

	// InsertedID is untyped; verify it is an ObjectID before using it.
	insertedID, ok := res.InsertedID.(primitive.ObjectID)
	if !ok {
		return GetErrorDataResponse[T](errors.Errorf("unexpected inserted id type %T", res.InsertedID))
	}

	result, err := ctr.modelSvc.GetById(insertedID)
	if err != nil {
		return GetErrorDataResponse[T](err)
	}
	return GetSuccessDataResponse(*result)
}
// PutById replaces the item identified by the "id" path parameter with the
// JSON request body.
//
// The body is bound into a T and its updated field is set from the user on
// the context before the replacement. The stored document is read back and
// returned.
//
// Every failure (invalid id, binding, T not implementing interfaces.Model,
// replace, read-back) is returned through GetErrorDataResponse instead of
// panicking.
func (ctr *BaseController[T]) PutById(c *gin.Context) (response *Response[T], err error) {
	id, err := primitive.ObjectIDFromHex(c.Param("id"))
	if err != nil {
		return GetErrorDataResponse[T](err)
	}

	var model T
	if err := c.ShouldBindJSON(&model); err != nil {
		return GetErrorDataResponse[T](err)
	}

	u := GetUserFromContext(c)

	// T is only constrained by `any`, so check the assertion rather than
	// letting a non-Model type panic the handler.
	m, ok := any(&model).(interfaces.Model)
	if !ok {
		return GetErrorDataResponse[T](errors.Errorf("model type %T does not implement interfaces.Model", model))
	}
	// m points at model, so this mutation is visible in the value replaced below.
	m.SetUpdated(u.Id)

	if err := ctr.modelSvc.ReplaceById(id, model); err != nil {
		return GetErrorDataResponse[T](err)
	}

	result, err := ctr.modelSvc.GetById(id)
	if err != nil {
		return GetErrorDataResponse[T](err)
	}
	return GetSuccessDataResponse(*result)
}
// PatchList applies the same field updates to several items at once.
//
// The JSON body carries "ids" (the ObjectIDs to update) and "update" (the
// fields to set). The update is applied with $set to every document whose
// _id is in the list. On success the response carries the zero value of T.
func (ctr *BaseController[T]) PatchList(c *gin.Context) (res *Response[T], err error) {
	var body struct {
		Ids    []primitive.ObjectID `json:"ids"`
		Update bson.M               `json:"update"`
	}
	if bindErr := c.ShouldBindJSON(&body); bindErr != nil {
		return GetErrorDataResponse[T](bindErr)
	}

	// Select the targeted documents and set the requested fields on them.
	filter := bson.M{"_id": bson.M{"$in": body.Ids}}
	change := bson.M{"$set": body.Update}
	if updateErr := ctr.modelSvc.UpdateMany(filter, change); updateErr != nil {
		return GetErrorDataResponse[T](updateErr)
	}

	// Nothing is read back; success is reported with an empty model.
	var zero T
	return GetSuccessDataResponse(zero)
}
// DeleteById removes the item identified by the "id" path parameter.
//
// An invalid id or a failure in the model service is returned through
// GetErrorDataResponse. On success the response carries the zero value of T.
func (ctr *BaseController[T]) DeleteById(c *gin.Context) (res *Response[T], err error) {
	oid, parseErr := primitive.ObjectIDFromHex(c.Param("id"))
	if parseErr != nil {
		return GetErrorDataResponse[T](parseErr)
	}

	if deleteErr := ctr.modelSvc.DeleteById(oid); deleteErr != nil {
		return GetErrorDataResponse[T](deleteErr)
	}

	var zero T
	return GetSuccessDataResponse(zero)
}
// DeleteList removes every item whose id appears in the JSON body's "ids"
// array of hex-encoded ObjectIDs.
//
// A binding failure, any id that is not valid hex, or a failure in the model
// service is returned through GetErrorDataResponse; no deletion is attempted
// if any id fails to parse. On success the response carries the zero value
// of T.
func (ctr *BaseController[T]) DeleteList(c *gin.Context) (res *Response[T], err error) {
	type Payload struct {
		Ids []string `json:"ids"`
	}
	var payload Payload
	if err := c.ShouldBindJSON(&payload); err != nil {
		return GetErrorDataResponse[T](err)
	}

	// The final length is known, so allocate once. Starting from a non-nil
	// slice also means an empty payload produces an empty $in array rather
	// than a nil slice (which the driver's default encoding writes as null —
	// NOTE(review): confirm the registry used by modelSvc).
	ids := make([]primitive.ObjectID, 0, len(payload.Ids))
	for _, id := range payload.Ids {
		objectId, err := primitive.ObjectIDFromHex(id)
		if err != nil {
			return GetErrorDataResponse[T](err)
		}
		ids = append(ids, objectId)
	}

	if err := ctr.modelSvc.DeleteMany(bson.M{
		"_id": bson.M{
			"$in": ids,
		},
	}); err != nil {
		return GetErrorDataResponse[T](err)
	}

	var emptyModel T
	return GetSuccessDataResponse(emptyModel)
}
// GetAll returns every item matching params.Query, ordered by params.Sort,
// together with the total number of matches.
//
// When no sort is given, items are ordered by descending _id. Errors from the
// model service are returned unchanged.
func (ctr *BaseController[T]) GetAll(_ *gin.Context, params *GetAllParams) (response *ListResponse[T], err error) {
	order := params.Sort
	if order == nil {
		order = bson.D{{Key: "_id", Value: -1}}
	}

	items, err := ctr.modelSvc.GetMany(params.Query, &mongo.FindOptions{Sort: order})
	if err != nil {
		return nil, err
	}

	count, err := ctr.modelSvc.Count(params.Query)
	if err != nil {
		return nil, err
	}

	return GetSuccessListResponse(items, count)
}
// GetWithPagination returns one page of items matching params.Query, ordered
// by params.Sort, together with the total number of matches.
//
// params.Pagination falls back to GetDefaultPagination() when nil. A
// "no documents" error from the fetch is reported as an empty successful
// list; any other error is returned unchanged.
func (ctr *BaseController[T]) GetWithPagination(_ *gin.Context, params *GetListParams) (response *ListResponse[T], err error) {
	page := params.Pagination
	if page == nil {
		page = GetDefaultPagination()
	}

	// Page numbers are 1-based in this computation: page 1 skips nothing.
	opts := &mongo.FindOptions{
		Sort:  params.Sort,
		Skip:  page.Size * (page.Page - 1),
		Limit: page.Size,
	}

	items, err := ctr.modelSvc.GetMany(params.Query, opts)
	if err != nil {
		if errors.Is(err, mongo2.ErrNoDocuments) {
			return GetSuccessListResponse[T](nil, 0)
		}
		return nil, err
	}

	// The total ignores pagination and counts every match of the filter.
	count, err := ctr.modelSvc.Count(params.Query)
	if err != nil {
		return nil, err
	}

	return GetSuccessListResponse(items, count)
}
// getAll is kept for backward compatibility. It reads the filter and sort
// options from the request context and delegates to GetAll.
func (ctr *BaseController[T]) getAll(c *gin.Context) (response *ListResponse[T], err error) {
	filter := MustGetFilterQuery(c)
	order := MustGetSortOption(c)
	return ctr.GetAll(c, &GetAllParams{Query: filter, Sort: order})
}
// getList is kept for backward compatibility. It reads pagination, filter and
// sort options from the request context and delegates to GetWithPagination.
func (ctr *BaseController[T]) getList(c *gin.Context) (response *ListResponse[T], err error) {
	// Read in the same order as before: pagination, then filter, then sort.
	page := MustGetPagination(c)
	filter := MustGetFilterQuery(c)
	order := MustGetSortOption(c)

	return ctr.GetWithPagination(c, &GetListParams{
		Query:      filter,
		Sort:       order,
		Pagination: page,
	})
}
// NewController builds a BaseController for the model type T, backed by a new
// model service and carrying the given extra actions.
func NewController[T any](actions ...Action) *BaseController[T] {
	return &BaseController[T]{
		modelSvc: service.NewModelService[T](),
		actions:  actions,
	}
}