refactor: optimized code

This commit is contained in:
Marvin Zhang
2024-11-24 23:14:26 +08:00
parent 9281f44853
commit 601db5a567
6 changed files with 108 additions and 57 deletions

View File

@@ -4,6 +4,10 @@ import (
"context"
"encoding/json"
"errors"
"io"
"strings"
"sync"
"github.com/apex/log"
"github.com/crawlab-team/crawlab/core/constants"
"github.com/crawlab-team/crawlab/core/interfaces"
@@ -19,9 +23,6 @@ import (
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/bson/primitive"
mongo2 "go.mongodb.org/mongo-driver/mongo"
"io"
"strings"
"sync"
)
var taskServiceMutex = sync.Mutex{}
@@ -67,38 +68,66 @@ func (svr TaskServiceServer) Subscribe(req *grpc.TaskServiceSubscribeRequest, st
}
// Connect to task stream when a task runner in a node starts
// Connect handles the bidirectional streaming connection from task runners in nodes.
// It receives messages containing task data and logs, processes them, and handles any errors.
func (svr TaskServiceServer) Connect(stream grpc.TaskService_ConnectServer) (err error) {
// NOTE(review): this span appears to be a flattened diff hunk; it seems to carry both
// the pre-change and post-change versions of some statements (the two task-id parsing
// sections and the two handleInsertData calls below) — confirm which lines are live.
// spider id and task id to track which spider/task this connection belongs to
var spiderId primitive.ObjectID
var taskId primitive.ObjectID
// continuously receive messages from the stream
for {
// receive next message from stream
// `:=` declares a loop-scoped err that shadows the named result; the returns visible
// in this loop all pass an explicit value, so the named result is never relied upon here.
msg, err := stream.Recv()
if err == io.EOF {
// stream has ended normally
return nil
}
if err != nil {
// handle graceful context cancellation
// the match is made on the error text rather than on an error value or status code
if strings.HasSuffix(err.Error(), "context canceled") {
return nil
}
// log other stream receive errors and continue
// NOTE(review): if Recv keeps returning the same error once the stream has failed,
// this `continue` would loop without blocking and log on every pass — confirm against
// the gRPC stream contract and consider returning the error instead.
log.Errorf("error receiving stream message: %v", err)
continue
}
// validate task id
taskId, err := primitive.ObjectIDFromHex(msg.TaskId)
if err != nil {
log.Errorf("invalid task id: %s", msg.TaskId)
continue
// validate and parse the task ID from the message if not already set
// once taskId is non-zero, msg.TaskId of later messages is no longer parsed, so every
// subsequent message on this connection is handled under the first valid task id
if taskId.IsZero() {
taskId, err = primitive.ObjectIDFromHex(msg.TaskId)
if err != nil {
log.Errorf("invalid task id: %s", msg.TaskId)
continue
}
}
// get spider id if not already set
// this only needs to be done once per connection
// the lookup is retried on each message for as long as spiderId is still the zero value
if spiderId.IsZero() {
t, err := service.NewModelService[models.Task]().GetById(taskId)
if err != nil {
// NOTE(review): the record fetched is a Task and the id printed is the task id,
// although the message text says "spider" — confirm the intended wording.
// The `continue` skips the switch below, so the current message is not handled.
log.Errorf("error getting spider[%s]: %v", taskId.Hex(), err)
continue
}
spiderId = t.SpiderId
}
// handle different message types based on the code
switch msg.Code {
case grpc.TaskServiceConnectCode_INSERT_DATA:
err = svr.handleInsertData(taskId, msg)
// handle scraped data insertion
err = svr.handleInsertData(taskId, spiderId, msg)
case grpc.TaskServiceConnectCode_INSERT_LOGS:
// handle task log insertion
err = svr.handleInsertLogs(taskId, msg)
default:
err = errors.New("invalid stream message code")
// invalid message code received
log.Errorf("invalid stream message code: %d", msg.Code)
continue
}
if err != nil {
// log any errors from handlers
// handler errors are only logged; the loop keeps receiving afterwards
log.Errorf("grpc error[%d]: %v", msg.Code, err)
}
}
@@ -170,7 +199,7 @@ func (svr TaskServiceServer) SendNotification(_ context.Context, request *grpc.T
taskId, err := primitive.ObjectIDFromHex(request.TaskId)
if err != nil {
log.Errorf("invalid task id: %s", request.TaskId)
return nil, trace.TraceError(err)
return nil, err
}
// arguments
@@ -179,31 +208,32 @@ func (svr TaskServiceServer) SendNotification(_ context.Context, request *grpc.T
// task
task, err := service.NewModelService[models.Task]().GetById(taskId)
if err != nil {
log.Errorf("task not found: %s", request.TaskId)
return nil, trace.TraceError(err)
log.Errorf("error getting task[%s]: %v", request.TaskId, err)
return nil, err
}
args = append(args, task)
// task stat
taskStat, err := service.NewModelService[models.TaskStat]().GetById(task.Id)
if err != nil {
log.Errorf("task stat not found for task: %s", request.TaskId)
return nil, trace.TraceError(err)
log.Errorf("error getting task stat for task[%s]: %v", request.TaskId, err)
return nil, err
}
args = append(args, taskStat)
// spider
spider, err := service.NewModelService[models.Spider]().GetById(task.SpiderId)
if err != nil {
log.Errorf("spider not found for task: %s", request.TaskId)
return nil, trace.TraceError(err)
log.Errorf("error getting spider[%s]: %v", task.SpiderId.Hex(), err)
return nil, err
}
args = append(args, spider)
// node
node, err := service.NewModelService[models.Node]().GetById(task.NodeId)
if err != nil {
return nil, trace.TraceError(err)
log.Errorf("error getting node[%s]: %v", task.NodeId.Hex(), err)
return nil, err
}
args = append(args, node)
@@ -212,8 +242,8 @@ func (svr TaskServiceServer) SendNotification(_ context.Context, request *grpc.T
if !task.ScheduleId.IsZero() {
schedule, err = service.NewModelService[models.Schedule]().GetById(task.ScheduleId)
if err != nil {
log.Errorf("schedule not found for task: %s", request.TaskId)
return nil, trace.TraceError(err)
log.Errorf("error getting schedule[%s]: %v", task.ScheduleId.Hex(), err)
return nil, err
}
args = append(args, schedule)
}
@@ -226,7 +256,8 @@ func (svr TaskServiceServer) SendNotification(_ context.Context, request *grpc.T
},
}, nil)
if err != nil {
return nil, trace.TraceError(err)
log.Errorf("error getting notification settings: %v", err)
return nil, err
}
// notification service
@@ -268,11 +299,16 @@ func (svr TaskServiceServer) GetSubscribeStream(taskId primitive.ObjectID) (stre
return stream, ok
}
// handleInsertData decodes msg.Data as a JSON array of objects, sets the
// constants.TaskKey and constants.SpiderKey entries on each record, and passes the
// records to statsSvc.InsertData for the given task. It returns the unmarshal error
// or whatever InsertData returns.
// NOTE(review): two signature lines and two error-return paths appear below; this looks
// like the old and new versions from a flattened diff — confirm which lines are live.
func (svr TaskServiceServer) handleInsertData(taskId primitive.ObjectID, msg *grpc.TaskServiceConnectRequest) (err error) {
func (svr TaskServiceServer) handleInsertData(taskId, spiderId primitive.ObjectID, msg *grpc.TaskServiceConnectRequest) (err error) {
var records []map[string]interface{}
err = json.Unmarshal(msg.Data, &records)
if err != nil {
return trace.TraceError(err)
// the error is logged here and also by Connect when this function returns non-nil,
// so a single failure produces two log lines
log.Errorf("error unmarshalling data: %v", err)
return err
}
for i := range records {
// NOTE(review): encoding/json decodes a JSON null array element to a nil map, and
// assigning to a nil map panics — consider skipping or initialising such entries.
records[i][constants.TaskKey] = taskId
records[i][constants.SpiderKey] = spiderId
}
return svr.statsSvc.InsertData(taskId, records...)
}
@@ -281,7 +317,8 @@ func (svr TaskServiceServer) handleInsertLogs(taskId primitive.ObjectID, msg *gr
var logs []string
err = json.Unmarshal(msg.Data, &logs)
if err != nil {
return trace.TraceError(err)
log.Errorf("error unmarshalling logs: %v", err)
return err
}
return svr.statsSvc.InsertLogs(taskId, logs...)
}