mirror of
https://github.com/crawlab-team/crawlab.git
synced 2026-01-22 17:31:03 +01:00
239 lines
6.3 KiB
Go
239 lines
6.3 KiB
Go
package server
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"github.com/apex/log"
|
|
"github.com/crawlab-team/crawlab/core/constants"
|
|
"github.com/crawlab-team/crawlab/core/container"
|
|
"github.com/crawlab-team/crawlab/core/entity"
|
|
"github.com/crawlab-team/crawlab/core/errors"
|
|
"github.com/crawlab-team/crawlab/core/interfaces"
|
|
"github.com/crawlab-team/crawlab/core/models/delegate"
|
|
"github.com/crawlab-team/crawlab/core/models/models"
|
|
"github.com/crawlab-team/crawlab/core/models/service"
|
|
"github.com/crawlab-team/crawlab/core/notification"
|
|
"github.com/crawlab-team/crawlab/core/utils"
|
|
"github.com/crawlab-team/crawlab/db/mongo"
|
|
grpc "github.com/crawlab-team/crawlab/grpc"
|
|
"github.com/crawlab-team/crawlab/trace"
|
|
"go.mongodb.org/mongo-driver/bson"
|
|
"go.mongodb.org/mongo-driver/bson/primitive"
|
|
mongo2 "go.mongodb.org/mongo-driver/mongo"
|
|
"io"
|
|
"strings"
|
|
)
|
|
|
|
type TaskServer struct {
|
|
grpc.UnimplementedTaskServiceServer
|
|
|
|
// dependencies
|
|
modelSvc service.ModelService
|
|
cfgSvc interfaces.NodeConfigService
|
|
statsSvc interfaces.TaskStatsService
|
|
|
|
// internals
|
|
server interfaces.GrpcServer
|
|
}
|
|
|
|
// Subscribe to task stream when a task runner in a node starts
|
|
func (svr TaskServer) Subscribe(stream grpc.TaskService_SubscribeServer) (err error) {
|
|
for {
|
|
msg, err := stream.Recv()
|
|
utils.LogDebug(msg.String())
|
|
if err == io.EOF {
|
|
return nil
|
|
}
|
|
if err != nil {
|
|
if strings.HasSuffix(err.Error(), "context canceled") {
|
|
return nil
|
|
}
|
|
trace.PrintError(err)
|
|
continue
|
|
}
|
|
switch msg.Code {
|
|
case grpc.StreamMessageCode_INSERT_DATA:
|
|
err = svr.handleInsertData(msg)
|
|
case grpc.StreamMessageCode_INSERT_LOGS:
|
|
err = svr.handleInsertLogs(msg)
|
|
default:
|
|
err = errors.ErrorGrpcInvalidCode
|
|
log.Errorf("invalid stream message code: %d", msg.Code)
|
|
continue
|
|
}
|
|
if err != nil {
|
|
log.Errorf("grpc error[%d]: %v", msg.Code, err)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Fetch tasks to be executed by a task handler
|
|
func (svr TaskServer) Fetch(ctx context.Context, request *grpc.Request) (response *grpc.Response, err error) {
|
|
nodeKey := request.GetNodeKey()
|
|
if nodeKey == "" {
|
|
return nil, trace.TraceError(errors.ErrorGrpcInvalidNodeKey)
|
|
}
|
|
n, err := svr.modelSvc.GetNodeByKey(nodeKey, nil)
|
|
if err != nil {
|
|
return nil, trace.TraceError(err)
|
|
}
|
|
var tid primitive.ObjectID
|
|
opts := &mongo.FindOptions{
|
|
Sort: bson.D{
|
|
{"p", 1},
|
|
{"_id", 1},
|
|
},
|
|
Limit: 1,
|
|
}
|
|
if err := mongo.RunTransactionWithContext(ctx, func(sc mongo2.SessionContext) (err error) {
|
|
// get task queue item assigned to this node
|
|
tid, err = svr.getTaskQueueItemIdAndDequeue(bson.M{"nid": n.Id}, opts, n.Id)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if !tid.IsZero() {
|
|
return nil
|
|
}
|
|
|
|
// get task queue item assigned to any node (random mode)
|
|
tid, err = svr.getTaskQueueItemIdAndDequeue(bson.M{"nid": nil}, opts, n.Id)
|
|
if !tid.IsZero() {
|
|
return nil
|
|
}
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}); err != nil {
|
|
return nil, err
|
|
}
|
|
return HandleSuccessWithData(tid)
|
|
}
|
|
|
|
func (svr TaskServer) SendNotification(ctx context.Context, request *grpc.Request) (response *grpc.Response, err error) {
|
|
svc := notification.GetService()
|
|
var t = new(models.Task)
|
|
if err := json.Unmarshal(request.Data, t); err != nil {
|
|
return nil, trace.TraceError(err)
|
|
}
|
|
t, err = svr.modelSvc.GetTaskById(t.Id)
|
|
if err != nil {
|
|
return nil, trace.TraceError(err)
|
|
}
|
|
td, err := json.Marshal(t)
|
|
if err != nil {
|
|
return nil, trace.TraceError(err)
|
|
}
|
|
var e bson.M
|
|
if err := json.Unmarshal(td, &e); err != nil {
|
|
return nil, trace.TraceError(err)
|
|
}
|
|
ts, err := svr.modelSvc.GetTaskStatById(t.Id)
|
|
if err != nil {
|
|
return nil, trace.TraceError(err)
|
|
}
|
|
settings, _, err := svc.GetSettingList(bson.M{
|
|
"enabled": true,
|
|
}, nil, nil)
|
|
if err != nil {
|
|
return nil, trace.TraceError(err)
|
|
}
|
|
for _, s := range settings {
|
|
switch s.TaskTrigger {
|
|
case constants.NotificationTriggerTaskFinish:
|
|
if t.Status != constants.TaskStatusPending && t.Status != constants.TaskStatusRunning {
|
|
_ = svc.Send(s, e)
|
|
}
|
|
case constants.NotificationTriggerTaskError:
|
|
if t.Status == constants.TaskStatusError || t.Status == constants.TaskStatusAbnormal {
|
|
_ = svc.Send(s, e)
|
|
}
|
|
case constants.NotificationTriggerTaskEmptyResults:
|
|
if t.Status != constants.TaskStatusPending && t.Status != constants.TaskStatusRunning {
|
|
if ts.ResultCount == 0 {
|
|
_ = svc.Send(s, e)
|
|
}
|
|
}
|
|
case constants.NotificationTriggerTaskNever:
|
|
}
|
|
}
|
|
return nil, nil
|
|
}
|
|
|
|
func (svr TaskServer) handleInsertData(msg *grpc.StreamMessage) (err error) {
|
|
data, err := svr.deserialize(msg)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
var records []interface{}
|
|
for _, d := range data.Records {
|
|
res, ok := d[constants.TaskKey]
|
|
if ok {
|
|
switch res.(type) {
|
|
case string:
|
|
id, err := primitive.ObjectIDFromHex(res.(string))
|
|
if err == nil {
|
|
d[constants.TaskKey] = id
|
|
}
|
|
}
|
|
}
|
|
records = append(records, d)
|
|
}
|
|
return svr.statsSvc.InsertData(data.TaskId, records...)
|
|
}
|
|
|
|
func (svr TaskServer) handleInsertLogs(msg *grpc.StreamMessage) (err error) {
|
|
data, err := svr.deserialize(msg)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return svr.statsSvc.InsertLogs(data.TaskId, data.Logs...)
|
|
}
|
|
|
|
func (svr TaskServer) getTaskQueueItemIdAndDequeue(query bson.M, opts *mongo.FindOptions, nid primitive.ObjectID) (tid primitive.ObjectID, err error) {
|
|
var tq models.TaskQueueItem
|
|
if err := mongo.GetMongoCol(interfaces.ModelColNameTaskQueue).Find(query, opts).One(&tq); err != nil {
|
|
if err == mongo2.ErrNoDocuments {
|
|
return tid, nil
|
|
}
|
|
return tid, trace.TraceError(err)
|
|
}
|
|
t, err := svr.modelSvc.GetTaskById(tq.Id)
|
|
if err == nil {
|
|
t.NodeId = nid
|
|
_ = delegate.NewModelDelegate(t).Save()
|
|
}
|
|
_ = delegate.NewModelDelegate(&tq).Delete()
|
|
return tq.Id, nil
|
|
}
|
|
|
|
func (svr TaskServer) deserialize(msg *grpc.StreamMessage) (data entity.StreamMessageTaskData, err error) {
|
|
if err := json.Unmarshal(msg.Data, &data); err != nil {
|
|
return data, trace.TraceError(err)
|
|
}
|
|
if data.TaskId.IsZero() {
|
|
return data, trace.TraceError(errors.ErrorGrpcInvalidType)
|
|
}
|
|
return data, nil
|
|
}
|
|
|
|
func NewTaskServer() (res *TaskServer, err error) {
|
|
// task server
|
|
svr := &TaskServer{}
|
|
|
|
// dependency injection
|
|
if err := container.GetContainer().Invoke(func(
|
|
modelSvc service.ModelService,
|
|
statsSvc interfaces.TaskStatsService,
|
|
cfgSvc interfaces.NodeConfigService,
|
|
) {
|
|
svr.modelSvc = modelSvc
|
|
svr.statsSvc = statsSvc
|
|
svr.cfgSvc = cfgSvc
|
|
}); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return svr, nil
|
|
}
|