Files
crawlab/core/models/service/base_service.go
2024-06-14 15:59:48 +08:00

421 lines
10 KiB
Go

package service
import (
"encoding/json"
"github.com/apex/log"
"github.com/crawlab-team/crawlab/core/constants"
"github.com/crawlab-team/crawlab/core/errors"
"github.com/crawlab-team/crawlab/core/interfaces"
"github.com/crawlab-team/crawlab/core/models/delegate"
models2 "github.com/crawlab-team/crawlab/core/models/models"
"github.com/crawlab-team/crawlab/core/utils"
"github.com/crawlab-team/crawlab/db/mongo"
"github.com/crawlab-team/go-trace"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/bson/primitive"
"reflect"
"strings"
"sync"
"time"
)
// BaseService is the generic model service used by this package: it binds
// query results from one MongoDB collection to the model identified by id.
type BaseService struct {
	// id identifies the model type that query results are bound to
	// (passed to NewBasicBinder / NewListBinder).
	id interfaces.ModelId
	// col is the backing collection. NewBaseService fills it from the
	// model's collection name when no option sets it; most helpers in this
	// file report constants.ErrMissingCol when it is nil.
	col *mongo.Col
}
// GetModelId reports the model identifier this service binds results to.
func (svc *BaseService) GetModelId() (id interfaces.ModelId) {
	id = svc.id
	return id
}
// SetModelId rebinds the service so subsequent reads decode into the model
// identified by id.
func (svc *BaseService) SetModelId(id interfaces.ModelId) {
	// Plain field assignment; no validation is performed.
	svc.id = id
}
// GetCol returns the collection backing this service (may be nil).
func (svc *BaseService) GetCol() (col *mongo.Col) {
	col = svc.col
	return col
}
// SetCol replaces the collection backing this service.
func (svc *BaseService) SetCol(col *mongo.Col) {
	// Stored as-is; a nil col makes later operations report ErrMissingCol.
	svc.col = col
}
// GetById loads the document with the given _id and binds it to the
// service's model type.
func (svc *BaseService) GetById(id primitive.ObjectID) (res interfaces.Model, err error) {
	binder := NewBasicBinder(svc.id, svc.findId(id))
	res, err = binder.Bind()
	return res, err
}
// Get runs query (with opts) against the collection and binds a single
// result to the service's model type.
func (svc *BaseService) Get(query bson.M, opts *mongo.FindOptions) (res interfaces.Model, err error) {
	binder := NewBasicBinder(svc.id, svc.find(query, opts))
	res, err = binder.Bind()
	return res, err
}
// GetList runs query (with opts) against the collection and binds all
// results to the service's model list type.
//
// When the collection is unset, the error comes from find
// (constants.ErrMissingCol); the debug logging below must therefore not
// dereference svc.col before that check happens.
func (svc *BaseService) GetList(query bson.M, opts *mongo.FindOptions) (l interfaces.List, err error) {
	tic := time.Now()
	log.Debugf("baseService.GetList -> svc.find:start")
	log.Debugf("baseService.GetList -> svc.id: %v", svc.id)
	// Debugf evaluates its arguments regardless of log level, so guard the
	// collection access to avoid a nil dereference.
	if svc.col != nil {
		log.Debugf("baseService.GetList -> svc.col.GetName(): %v", svc.col.GetName())
	}
	log.Debugf("baseService.GetList -> query: %v", query)
	log.Debugf("baseService.GetList -> opts: %v", opts)
	fr := svc.find(query, opts)
	log.Debugf("baseService.GetList -> svc.find:end. elapsed: %d ms", time.Since(tic).Milliseconds())
	return NewListBinder(svc.id, fr).Bind()
}
// DeleteById deletes the document with the given _id through its model
// delegate; args may carry the acting user.
func (svc *BaseService) DeleteById(id primitive.ObjectID, args ...interface{}) (err error) {
	err = svc.deleteId(id, args...)
	return err
}
// Delete deletes a single document matching query through its model
// delegate. args may carry the acting user (resolved via
// utils.GetUserFromArgs); they are forwarded so the delegate receives it.
func (svc *BaseService) Delete(query bson.M, args ...interface{}) (err error) {
	return svc.delete(query, args...)
}
// DeleteList deletes every document matching query, one at a time, through
// each model's delegate. args may carry the acting user and are forwarded.
func (svc *BaseService) DeleteList(query bson.M, args ...interface{}) (err error) {
	return svc.deleteList(query, args...)
}
// ForceDeleteList removes every document matching query directly from the
// collection, bypassing model delegates. args are forwarded for symmetry
// with the other delete helpers.
func (svc *BaseService) ForceDeleteList(query bson.M, args ...interface{}) (err error) {
	err = svc.forceDeleteList(query, args...)
	return err
}
// UpdateById applies update to the document with the given _id and stamps
// its artifact with the update time and user. args may carry the acting
// user and are forwarded so the artifact's update_uid is recorded.
func (svc *BaseService) UpdateById(id primitive.ObjectID, update bson.M, args ...interface{}) (err error) {
	return svc.updateId(id, update, args...)
}
// Update applies update (restricted to fields when non-nil) to every
// document matching query. args may carry the acting user and are forwarded
// so the artifacts' update_uid is recorded.
func (svc *BaseService) Update(query bson.M, update bson.M, fields []string, args ...interface{}) (err error) {
	return svc.update(query, update, fields, args...)
}
// UpdateDoc applies the values of doc (restricted to fields when non-nil) to
// every document matching query. args may carry the acting user and are
// forwarded so the artifacts' update_uid is recorded.
func (svc *BaseService) UpdateDoc(query bson.M, doc interfaces.Model, fields []string, args ...interface{}) (err error) {
	return svc.update(query, doc, fields, args...)
}
// Insert inserts docs on behalf of u; see insert for _id normalization and
// artifact handling. A nil collection is reported by insert as
// constants.ErrMissingCol.
func (svc *BaseService) Insert(u interfaces.User, docs ...interface{}) (err error) {
	// Debugf evaluates arguments even when debug logging is off, so do not
	// dereference a nil collection here.
	if svc.col != nil {
		log.Debugf("baseService.Insert -> svc.col.GetName(): %v", svc.col.GetName())
	}
	log.Debugf("baseService.Insert -> docs: %v", docs)
	return svc.insert(u, docs...)
}
// Count returns the number of documents matching query.
func (svc *BaseService) Count(query bson.M) (total int, err error) {
	total, err = svc.count(query)
	return total, err
}
// findId looks up a document by _id; with no collection configured it
// returns a FindResult carrying constants.ErrMissingCol.
func (svc *BaseService) findId(id primitive.ObjectID) (fr *mongo.FindResult) {
	if col := svc.col; col != nil {
		return col.FindId(id)
	}
	return mongo.NewFindResultWithError(constants.ErrMissingCol)
}
// find runs query with opts; with no collection configured it returns a
// FindResult carrying constants.ErrMissingCol.
func (svc *BaseService) find(query bson.M, opts *mongo.FindOptions) (fr *mongo.FindResult) {
	if col := svc.col; col != nil {
		return col.Find(query, opts)
	}
	return mongo.NewFindResultWithError(constants.ErrMissingCol)
}
// deleteId loads the document with the given _id and deletes it through a
// model delegate acting as the user found in args (if any).
func (svc *BaseService) deleteId(id primitive.ObjectID, args ...interface{}) (err error) {
	if svc.col == nil {
		return trace.TraceError(constants.ErrMissingCol)
	}
	doc, err := NewBasicBinder(svc.id, svc.findId(id)).Bind()
	if err != nil {
		return err
	}
	user := svc._getUserFromArgs(args...)
	return delegate.NewModelDelegate(doc, user).Delete()
}
// delete resolves a single document matching query to its _id and removes
// it via deleteId, passing along the user found in args.
func (svc *BaseService) delete(query bson.M, args ...interface{}) (err error) {
	if svc.col == nil {
		return trace.TraceError(constants.ErrMissingCol)
	}
	// Only the _id is needed, so decode into the shared base model.
	target := models2.BaseModel{}
	if err = svc.find(query, nil).One(&target); err != nil {
		return err
	}
	user := svc._getUserFromArgs(args...)
	return svc.deleteId(target.Id, user)
}
// deleteList deletes every document matching query, one by one, through
// model delegates; it stops at the first failing delete.
func (svc *BaseService) deleteList(query bson.M, args ...interface{}) (err error) {
	if svc.col == nil {
		return trace.TraceError(constants.ErrMissingCol)
	}
	list, err := NewListBinder(svc.id, svc.find(query, nil)).Bind()
	if err != nil {
		return err
	}
	items := list.GetModels()
	for i := range items {
		d := delegate.NewModelDelegate(items[i], svc._getUserFromArgs(args...))
		if err = d.Delete(); err != nil {
			return err
		}
	}
	return nil
}
// forceDeleteList removes all documents matching query directly from the
// collection (no model delegates are involved). args are accepted for
// signature symmetry and are not used.
//
// Like every other helper here, a missing collection is reported as
// constants.ErrMissingCol instead of dereferencing a nil *mongo.Col.
func (svc *BaseService) forceDeleteList(query bson.M, args ...interface{}) (err error) {
	if svc.col == nil {
		return trace.TraceError(constants.ErrMissingCol)
	}
	return svc.col.Delete(query)
}
// count returns the number of documents matching query, or
// constants.ErrMissingCol when no collection is configured.
func (svc *BaseService) count(query bson.M) (total int, err error) {
	col := svc.col
	if col == nil {
		err = trace.TraceError(constants.ErrMissingCol)
		return total, err
	}
	return col.Count(query)
}
// update normalizes update into a MongoDB update document (keeping only
// fields when non-nil) and applies it to all documents matching query.
func (svc *BaseService) update(query bson.M, update interface{}, fields []string, args ...interface{}) (err error) {
	normalized, err := svc._getUpdateBsonM(update, fields)
	if err != nil {
		return err
	}
	user := svc._getUserFromArgs(args...)
	return svc._update(query, normalized, user)
}
// updateId normalizes update into a MongoDB update document and applies it
// to the document with the given _id.
func (svc *BaseService) updateId(id primitive.ObjectID, update interface{}, args ...interface{}) (err error) {
	normalized, err := svc._getUpdateBsonM(update, nil)
	if err != nil {
		return err
	}
	user := svc._getUserFromArgs(args...)
	return svc._updateById(id, normalized, user)
}
// insert writes docs to the collection and then saves each inserted model
// through a model delegate acting as u (the delegate save upserts the
// model's artifact).
//
// Documents of type map[string]interface{} get their "_id" normalized in
// place before insertion:
//   - missing "_id": a new ObjectID is generated;
//   - string "_id": parsed as a hex ObjectID (invalid hex is an error and
//     the map is left untouched);
//   - primitive.ObjectID "_id": kept as-is;
//   - any other "_id" type (including nil): errors.ErrorModelInvalidType.
//
// Other document types are handed to InsertMany unchanged.
func (svc *BaseService) insert(u interfaces.User, docs ...interface{}) (err error) {
	if svc.col == nil {
		return trace.TraceError(constants.ErrMissingCol)
	}

	// Maps are reference types, so updating d updates the element in docs.
	for _, doc := range docs {
		d, isMap := doc.(map[string]interface{})
		if !isMap {
			continue
		}
		rawId, exists := d["_id"]
		if !exists {
			d["_id"] = primitive.NewObjectID()
			continue
		}
		switch v := rawId.(type) {
		case string:
			oid, parseErr := primitive.ObjectIDFromHex(v)
			if parseErr != nil {
				return trace.TraceError(parseErr)
			}
			d["_id"] = oid
		case primitive.ObjectID:
			// already the right type
		default:
			return trace.TraceError(errors.ErrorModelInvalidType)
		}
	}

	ids, err := svc.col.InsertMany(docs)
	if err != nil {
		return err
	}

	// Re-read the inserted documents as models so their delegates can
	// upsert the associated artifacts.
	query := bson.M{
		"_id": bson.M{
			"$in": ids,
		},
	}
	list, err := NewListBinder(svc.id, svc.col.Find(query, nil)).Bind()
	if err != nil {
		return err
	}
	for _, doc := range list.GetModels() {
		if err := delegate.NewModelDelegate(doc, u).Save(); err != nil {
			return err
		}
	}
	return nil
}
// _update applies update to every document matching query, then stamps the
// corresponding artifacts with the update time and the user found in args.
//
// Artifacts share the _id of their model (see _updateById), so the matched
// model ids are captured before the update and used to address the artifact
// collection. Querying the artifact collection with the model query would
// match on model fields the artifacts do not carry.
func (svc *BaseService) _update(query bson.M, update interface{}, args ...interface{}) (err error) {
	// ids of query, captured before the update may change matching fields
	list, err := NewListBinder(svc.id, svc.find(query, nil)).Bind()
	if err != nil {
		return err
	}
	models := list.GetModels()
	ids := make([]primitive.ObjectID, 0, len(models))
	for _, doc := range models {
		ids = append(ids, doc.GetId())
	}

	// update model objects
	if err := svc.col.Update(query, update); err != nil {
		return err
	}

	// nothing matched, so there are no artifacts to stamp
	if len(ids) == 0 {
		return nil
	}

	// update artifacts of exactly the models that matched
	u := svc._getUserFromArgs(args...)
	artifactQuery := bson.M{
		"_id": bson.M{
			"$in": ids,
		},
	}
	return mongo.GetMongoCol(interfaces.ModelColNameArtifact).Update(artifactQuery, svc._getUpdateArtifactUpdate(u))
}
// _updateById applies update to the document with the given _id, then
// stamps the artifact with the same _id with the update time and the user
// found in args.
//
// Unlike the original, a missing collection is reported as
// constants.ErrMissingCol (matching the other helpers) instead of
// dereferencing a nil *mongo.Col.
func (svc *BaseService) _updateById(id primitive.ObjectID, update interface{}, args ...interface{}) (err error) {
	if svc.col == nil {
		return trace.TraceError(constants.ErrMissingCol)
	}
	// update model object
	if err := svc.col.UpdateId(id, update); err != nil {
		return err
	}
	// update artifact
	u := svc._getUserFromArgs(args...)
	return mongo.GetMongoCol(interfaces.ModelColNameArtifact).UpdateId(id, svc._getUpdateArtifactUpdate(u))
}
// _getUpdateBsonM converts update into a MongoDB update document.
//
// Accepted inputs:
//   - interfaces.Model: JSON-encoded and decoded back into a bson.M, then
//     handled as a bson.M (keys are therefore the model's json field names);
//   - bson.M or map[string]interface{}: used directly;
//   - a struct value: copied into a new pointer and re-dispatched, so it is
//     handled as a Model when *T implements interfaces.Model.
//
// Anything else yields errors.ErrorModelInvalidType.
//
// When fields is non-nil only those top-level keys are kept; a new map is
// built so the caller's map is never modified. If the resulting document
// has no "$"-prefixed operator key it is wrapped as {"$set": doc}.
//
// NOTE(review): the JSON round-trip turns values with custom JSON encodings
// (e.g. primitive.ObjectID, time.Time) into strings before they reach $set.
// Confirm callers only update plain scalar fields through the Model path.
func (svc *BaseService) _getUpdateBsonM(update interface{}, fields []string) (res bson.M, err error) {
	switch u := update.(type) {
	case interfaces.Model:
		data, err := json.Marshal(u)
		if err != nil {
			return nil, err
		}
		var updateBsonM bson.M
		if err := json.Unmarshal(data, &updateBsonM); err != nil {
			return nil, err
		}
		return svc._getUpdateBsonM(updateBsonM, fields)
	case bson.M:
		doc := u
		if fields != nil {
			// copy only the selected keys instead of deleting from the input
			doc = make(bson.M, len(fields))
			for _, f := range fields {
				if v, ok := u[f]; ok {
					doc[f] = v
				}
			}
		}
		// "$set" itself starts with "$", so a single check covers both an
		// explicit $set and any other operator.
		if !svc._containsDollar(doc) {
			doc = bson.M{
				"$set": doc,
			}
		}
		return doc, nil
	case map[string]interface{}:
		// bson.M is a distinct named type; accept plain maps too.
		return svc._getUpdateBsonM(bson.M(u), fields)
	}

	v := reflect.ValueOf(update)
	if v.Kind() == reflect.Struct {
		// A value obtained from reflect.ValueOf is never addressable, so copy
		// it into a fresh *T. If *T implements interfaces.Model the recursive
		// call takes the Model branch; otherwise the pointer kind falls
		// through to the error below, so recursion terminates.
		ptr := reflect.New(v.Type())
		ptr.Elem().Set(v)
		return svc._getUpdateBsonM(ptr.Interface(), fields)
	}
	return nil, errors.ErrorModelInvalidType
}
// _getUpdateArtifactUpdate builds the $set document that stamps an artifact
// with the current time and the id of u (the zero ObjectID when u is nil).
func (svc *BaseService) _getUpdateArtifactUpdate(u interfaces.User) (res bson.M) {
	uid := primitive.NilObjectID
	if u != nil {
		uid = u.GetId()
	}
	stamp := bson.M{
		"_sys.update_ts":  time.Now(),
		"_sys.update_uid": uid,
	}
	res = bson.M{"$set": stamp}
	return res
}
// _getUserFromArgs extracts the acting user from variadic args by
// delegating to utils.GetUserFromArgs.
func (svc *BaseService) _getUserFromArgs(args ...interface{}) (u interfaces.User) {
	u = utils.GetUserFromArgs(args...)
	return u
}
// _containsDollar reports whether any top-level key of updateBsonM is a
// MongoDB operator (starts with "$").
func (svc *BaseService) _containsDollar(updateBsonM bson.M) (ok bool) {
	for key := range updateBsonM {
		if !strings.HasPrefix(key, "$") {
			continue
		}
		return true
	}
	return false
}
// NewBaseService builds a BaseService for model id, applies opts in order,
// and falls back to the model's default collection when no option set one.
func NewBaseService(id interfaces.ModelId, opts ...BaseServiceOption) (svc2 interfaces.ModelBaseService) {
	s := &BaseService{id: id}
	for _, apply := range opts {
		apply(s)
	}
	if s.GetCol() == nil {
		s.SetCol(mongo.GetMongoCol(models2.GetModelColName(id)))
	}
	return s
}
// store caches base services, keyed either by interfaces.ModelId
// (GetBaseService) or by collection name string (GetBaseServiceByColName).
var store = sync.Map{}

// loadOrCreateBaseService returns the service cached under key, creating
// one with create and caching it if absent. LoadOrStore makes concurrent
// first calls converge on a single cached instance instead of each caller
// storing its own.
func loadOrCreateBaseService(key interface{}, create func() interfaces.ModelBaseService) interfaces.ModelBaseService {
	if res, ok := store.Load(key); ok {
		if svc, ok := res.(interfaces.ModelBaseService); ok {
			return svc
		}
	}
	svc := create()
	actual, loaded := store.LoadOrStore(key, svc)
	if !loaded {
		return svc
	}
	if existing, ok := actual.(interfaces.ModelBaseService); ok {
		return existing
	}
	// The slot holds a value of an unexpected type; overwrite it, as the
	// previous implementation did.
	store.Store(key, svc)
	return svc
}

// GetBaseService returns the cached base service for model id, creating it
// with the model's default collection on first use.
func GetBaseService(id interfaces.ModelId) (svc interfaces.ModelBaseService) {
	return loadOrCreateBaseService(id, func() interfaces.ModelBaseService {
		return NewBaseService(id)
	})
}

// GetBaseServiceByColName returns the base service cached under colName,
// creating one for model id bound to that collection on first use. The
// cache key is colName only, so a later call with a different id but the
// same colName returns the first service.
func GetBaseServiceByColName(id interfaces.ModelId, colName string) (svc interfaces.ModelBaseService) {
	return loadOrCreateBaseService(colName, func() interfaces.ModelBaseService {
		col := mongo.GetMongoCol(colName)
		return NewBaseService(id, WithBaseServiceCol(col))
	})
}