Files
crawlab/core/grpc/server/task_service_server.go
Marvin Zhang 8c2c23d9b6 feat: Update gRPC service definitions and implement CheckProcess method
- Downgraded protoc-gen-go-grpc and protoc versions for compatibility.
- Added CheckProcess method to TaskService with corresponding request and response types.
- Updated Subscribe and Connect methods to use new generic client stream types.
- Refactored server and client implementations for Subscribe and Connect methods.
- Ensured backward compatibility by maintaining existing method signatures where applicable.
- Added necessary handler for CheckProcess in the service descriptor.
2025-09-17 10:37:03 +08:00

601 lines
17 KiB
Go

package server
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"strings"
"sync"
"time"
mongo3 "github.com/crawlab-team/crawlab/core/mongo"
"github.com/crawlab-team/crawlab/core/constants"
"github.com/crawlab-team/crawlab/core/interfaces"
"github.com/crawlab-team/crawlab/core/models/models"
"github.com/crawlab-team/crawlab/core/models/service"
nodeconfig "github.com/crawlab-team/crawlab/core/node/config"
"github.com/crawlab-team/crawlab/core/notification"
"github.com/crawlab-team/crawlab/core/task/stats"
"github.com/crawlab-team/crawlab/core/utils"
"github.com/crawlab-team/crawlab/grpc"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/bson/primitive"
mongo2 "go.mongodb.org/mongo-driver/mongo"
)
// taskServiceMutex is the package-level lock taken around every access to
// TaskServiceServer.subs in this file (registering, looking up, cleaning up
// and clearing subscribe streams).
var taskServiceMutex = sync.Mutex{}
// TaskServiceServer is the server-side implementation of the task gRPC
// service: it tracks open Subscribe streams per task, ingests data/logs sent
// over Connect streams, hands out pending tasks, triggers notifications and
// reports process status.
//
// Its methods use value receivers, so each call operates on a copy of the
// struct; shared state is reached through the subs map and the package-level
// taskServiceMutex.
type TaskServiceServer struct {
grpc.UnimplementedTaskServiceServer
// dependencies
cfgSvc interfaces.NodeConfigService // node configuration service, set in newTaskServiceServer
statsSvc *stats.Service // receives scraped records and log lines (InsertData / InsertLogs)
// internals
// subs maps a task id to its currently registered Subscribe stream;
// every access in this file is guarded by taskServiceMutex.
subs map[primitive.ObjectID]grpc.TaskService_SubscribeServer
// cleanup mechanism
// cleanupCtx is cancelled through cleanupCancel (see Stop) to end the
// cleanupStaleStreams goroutine started by newTaskServiceServer.
cleanupCtx context.Context
cleanupCancel context.CancelFunc
interfaces.Logger
}
// Subscribe registers the server-side stream for the task named in req and
// keeps the RPC open until the client's stream context ends or the task is
// observed to have finished.
//
// It returns an "invalid task id" error when req.TaskId is not a valid
// ObjectID hex string and an "invalid stream" error when stream is nil.
// While open, the stream is reachable through GetSubscribeStream. Every 10
// seconds the task status is checked; once the task is finished the method
// returns nil, which closes the stream. When the stream context ends first,
// ctx.Err() is returned.
func (svr TaskServiceServer) Subscribe(req *grpc.TaskServiceSubscribeRequest, stream grpc.TaskService_SubscribeServer) (err error) {
	// task id
	taskId, err := primitive.ObjectIDFromHex(req.TaskId)
	if err != nil {
		return errors.New("invalid task id")
	}

	// validate stream
	if stream == nil {
		return errors.New("invalid stream")
	}

	svr.Infof("task stream opened: %s", taskId.Hex())

	// the handler's lifetime is bound to the client stream's context
	ctx := stream.Context()

	// register the stream so other components can look it up by task id
	taskServiceMutex.Lock()
	svr.subs[taskId] = stream
	taskServiceMutex.Unlock()

	// ensure cleanup on exit
	defer func() {
		taskServiceMutex.Lock()
		// Only remove the entry if it is still ours: a newer Subscribe call for
		// the same task may have replaced it, and deleting unconditionally
		// would unregister that live stream.
		// NOTE(review): assumes the concrete stream type is comparable
		// (generated gRPC stream implementations are pointers) — confirm.
		if current, ok := svr.subs[taskId]; ok && current == stream {
			delete(svr.subs, taskId)
		}
		taskServiceMutex.Unlock()
		svr.Infof("task stream closed: %s", taskId.Hex())
	}()

	// periodic heartbeat: detects client disconnection and task completion
	heartbeatTicker := time.NewTicker(10 * time.Second)
	defer heartbeatTicker.Stop()

	for {
		select {
		case <-ctx.Done():
			// stream context ended (client disconnected or RPC cancelled)
			svr.Debugf("task stream context done: %s", taskId.Hex())
			return ctx.Err()
		case <-heartbeatTicker.C:
			// close the stream once the task is no longer pending/running
			if svr.isTaskFinished(taskId) {
				svr.Infof("task[%s] finished, closing stream", taskId.Hex())
				return nil
			}

			// the status lookup may have taken a while; re-check the context
			if ctxErr := ctx.Err(); ctxErr != nil {
				svr.Debugf("task stream context cancelled during heartbeat check: %s", taskId.Hex())
				return ctxErr
			}
			svr.Debugf("task stream heartbeat check passed: %s", taskId.Hex())
		}
	}
}
// Connect serves the client-streaming connection opened by a task runner.
//
// The task id is taken from the first message that carries a valid one, and
// the spider id is then resolved once from that task. Each message is
// dispatched by its code: INSERT_DATA and INSERT_LOGS are forwarded to the
// matching handler, PING is only logged. Handler errors are logged and the
// loop keeps going. Messages with an invalid task id, an unresolvable task or
// an unknown code are logged and skipped.
//
// The method returns nil on EOF or when the receive error text indicates a
// cancelled/expired context or a closing transport, ctx.Err() when the stream
// context is already done before a receive, and the receive error otherwise.
func (svr TaskServiceServer) Connect(stream grpc.TaskService_ConnectServer) (err error) {
	// identifiers of the task/spider this connection belongs to
	var (
		taskId   primitive.ObjectID
		spiderId primitive.ObjectID
	)

	ctx := stream.Context()

	svr.Debugf("task connect stream started")
	defer func() {
		if taskId == primitive.NilObjectID {
			svr.Debugf("task connect stream ended")
			return
		}
		svr.Debugf("task connect stream ended for task: %s", taskId.Hex())
	}()

	for {
		// bail out early if the stream context has already ended
		select {
		case <-ctx.Done():
			svr.Debugf("task connect stream context cancelled")
			return ctx.Err()
		default:
		}

		msg, recvErr := stream.Recv()
		switch {
		case recvErr == io.EOF:
			// client finished sending
			svr.Debugf("task connect stream ended normally (EOF)")
			return nil
		case recvErr != nil:
			reason := recvErr.Error()
			graceful := strings.HasSuffix(reason, "context canceled") ||
				strings.Contains(reason, "context deadline exceeded") ||
				strings.Contains(reason, "transport is closing")
			if graceful {
				svr.Debugf("task connect stream cancelled gracefully: %v", recvErr)
				return nil
			}
			// any other receive error ends the handler instead of looping on it
			svr.Errorf("error receiving stream message: %v", recvErr)
			return recvErr
		}

		// resolve the task id once per connection
		if taskId.IsZero() {
			var parseErr error
			taskId, parseErr = primitive.ObjectIDFromHex(msg.TaskId)
			if parseErr != nil {
				svr.Errorf("invalid task id: %s", msg.TaskId)
				continue
			}
			svr.Debugf("task connect stream set task id: %s", taskId.Hex())
		}

		// resolve the spider id once per connection (retried until it succeeds)
		if spiderId.IsZero() {
			task, lookupErr := service.NewModelService[models.Task]().GetById(taskId)
			if lookupErr != nil {
				svr.Errorf("error getting spider[%s]: %v", taskId.Hex(), lookupErr)
				continue
			}
			spiderId = task.SpiderId
		}

		// dispatch on the message code
		var handleErr error
		switch msg.Code {
		case grpc.TaskServiceConnectCode_INSERT_DATA:
			handleErr = svr.handleInsertData(taskId, spiderId, msg)
		case grpc.TaskServiceConnectCode_INSERT_LOGS:
			handleErr = svr.handleInsertLogs(taskId, msg)
		case grpc.TaskServiceConnectCode_PING:
			// health-check ping: nothing to do beyond logging it
			svr.Debugf("received ping from task[%s]", taskId.Hex())
		default:
			svr.Errorf("invalid stream message code: %d", msg.Code)
			continue
		}

		if handleErr != nil {
			svr.Errorf("grpc error[%d]: %v", msg.Code, handleErr)
		}
	}
}
// FetchTask picks the next pending task for the requesting node and marks it
// as assigned.
//
// The node is looked up by request's node key. Pending tasks already bound to
// that node are preferred; otherwise a pending task bound to no node (nil
// ObjectID) is taken and bound to the requesting node. Candidates are ordered
// by priority, then by _id, both ascending. The response carries the hex form
// of the chosen task id, or of the zero ObjectID when nothing is pending.
//
// An error is returned when the node key is empty, the node cannot be loaded,
// a task query fails for a reason other than "no documents", or saving the
// task fails.
//
// NOTE(review): the callback receives a session context but does not pass it
// to the model service calls; whether those calls take part in the
// transaction cannot be determined from this file — confirm.
func (svr TaskServiceServer) FetchTask(ctx context.Context, request *grpc.TaskServiceFetchTaskRequest) (response *grpc.TaskServiceFetchTaskResponse, err error) {
	nodeKey := request.GetNodeKey()
	if nodeKey == "" {
		err = fmt.Errorf("invalid node key")
		svr.Errorf("error fetching task: %v", err)
		return nil, err
	}

	node, err := service.NewModelService[models.Node]().GetOne(bson.M{"key": nodeKey}, nil)
	if err != nil {
		svr.Errorf("error getting node[%s]: %v", nodeKey, err)
		return nil, err
	}

	// lowest priority value first, ties broken by insertion order
	findOpts := &mongo3.FindOptions{
		Sort: bson.D{
			{Key: "priority", Value: 1},
			{Key: "_id", Value: 1},
		},
		Limit: 1,
	}

	var assignedId primitive.ObjectID
	claim := func(sc mongo2.SessionContext) error {
		// first choice: a pending task bound to this node
		pinned, findErr := service.NewModelService[models.Task]().GetOne(bson.M{
			"node_id": node.Id,
			"status":  constants.TaskStatusPending,
		}, findOpts)
		switch {
		case findErr == nil:
			assignedId = pinned.Id
			pinned.Status = constants.TaskStatusAssigned
			return svr.saveTask(pinned)
		case !errors.Is(findErr, mongo2.ErrNoDocuments):
			svr.Errorf("error fetching task for node[%s]: %v", nodeKey, findErr)
			return findErr
		}

		// second choice: a pending task not bound to any node
		unpinned, findErr := service.NewModelService[models.Task]().GetOne(bson.M{
			"node_id": primitive.NilObjectID,
			"status":  constants.TaskStatusPending,
		}, findOpts)
		switch {
		case findErr == nil:
			assignedId = unpinned.Id
			unpinned.NodeId = node.Id
			unpinned.Status = constants.TaskStatusAssigned
			return svr.saveTask(unpinned)
		case !errors.Is(findErr, mongo2.ErrNoDocuments):
			svr.Errorf("error fetching task for any node: %v", findErr)
			return findErr
		}

		// nothing pending
		return nil
	}

	if txErr := mongo3.RunTransactionWithContext(ctx, claim); txErr != nil {
		return nil, txErr
	}

	return &grpc.TaskServiceFetchTaskResponse{TaskId: assignedId.Hex()}, nil
}
// SendNotification dispatches task notifications for the task named in the
// request.
//
// Outside the Pro edition it returns a NOTIFICATION_DISABLED response without
// doing anything. Otherwise it loads the task, its stat record, spider, node
// and (when the task has one) schedule, then loads every enabled notification
// setting whose trigger matches the task trigger pattern. For each setting
// the trigger (falling back to the legacy TaskTrigger field when empty)
// decides whether to send:
//   - task finish: the task is neither pending nor running;
//   - task error: the task status is error or abnormal;
//   - empty results: the task is neither pending nor running and its result
//     count is zero.
//
// Sends run in their own goroutines and are not awaited, so the returned
// NOTIFICATION_SUCCESS only means dispatching was started. Any lookup failure
// is logged and returned as the error.
func (svr TaskServiceServer) SendNotification(_ context.Context, request *grpc.TaskServiceSendNotificationRequest) (response *grpc.TaskServiceSendNotificationResponse, err error) {
	if !utils.IsPro() {
		return &grpc.TaskServiceSendNotificationResponse{
			Code:    grpc.TaskServiceSendNotificationResponseCode_NOTIFICATION_DISABLED,
			Message: "Notification service is disabled (Pro version required)",
		}, nil
	}

	// task id
	taskId, err := primitive.ObjectIDFromHex(request.TaskId)
	if err != nil {
		svr.Errorf("invalid task id: %s", request.TaskId)
		return nil, err
	}

	// arguments handed to the notification service, in lookup order
	var args []any

	// task
	task, err := service.NewModelService[models.Task]().GetById(taskId)
	if err != nil {
		svr.Errorf("error getting task[%s]: %v", request.TaskId, err)
		return nil, err
	}
	args = append(args, task)

	// task stat
	taskStat, err := service.NewModelService[models.TaskStat]().GetById(task.Id)
	if err != nil {
		svr.Errorf("error getting task stat for task[%s]: %v", request.TaskId, err)
		return nil, err
	}
	args = append(args, taskStat)

	// spider
	spider, err := service.NewModelService[models.Spider]().GetById(task.SpiderId)
	if err != nil {
		svr.Errorf("error getting spider[%s]: %v", task.SpiderId.Hex(), err)
		return nil, err
	}
	args = append(args, spider)

	// node
	node, err := service.NewModelService[models.Node]().GetById(task.NodeId)
	if err != nil {
		svr.Errorf("error getting node[%s]: %v", task.NodeId.Hex(), err)
		return nil, err
	}
	args = append(args, node)

	// schedule (only for scheduled tasks)
	var schedule *models.Schedule
	if !task.ScheduleId.IsZero() {
		schedule, err = service.NewModelService[models.Schedule]().GetById(task.ScheduleId)
		if err != nil {
			svr.Errorf("error getting schedule[%s]: %v", task.ScheduleId.Hex(), err)
			return nil, err
		}
		args = append(args, schedule)
	}

	// settings
	settings, err := service.NewModelService[models.NotificationSetting]().GetMany(bson.M{
		"enabled": true,
		"trigger": bson.M{
			"$regex": constants.NotificationTriggerPatternTask,
		},
	}, nil)
	if err != nil {
		svr.Errorf("error getting notification settings: %v", err)
		return nil, err
	}

	// the task status is loop-invariant, so classify it once
	taskEnded := task.Status != constants.TaskStatusPending && task.Status != constants.TaskStatusRunning
	taskFailed := task.Status == constants.TaskStatusError || task.Status == constants.TaskStatusAbnormal

	// notification service
	svc := notification.GetNotificationService()
	for i := range settings {
		// Take a per-iteration copy: the goroutines below keep a pointer to it,
		// and before Go 1.22 a range variable is shared by all iterations, so
		// pointing at it would let a later iteration overwrite the setting
		// while an earlier Send is still reading it.
		setting := settings[i]

		// compatible with old settings
		trigger := setting.Trigger
		if trigger == "" {
			trigger = setting.TaskTrigger
		}

		// send notification
		switch trigger {
		case constants.NotificationTriggerTaskFinish:
			if taskEnded {
				go svc.Send(&setting, args...)
			}
		case constants.NotificationTriggerTaskError:
			if taskFailed {
				go svc.Send(&setting, args...)
			}
		case constants.NotificationTriggerTaskEmptyResults:
			if taskEnded && taskStat.ResultCount == 0 {
				go svc.Send(&setting, args...)
			}
		}
	}

	return &grpc.TaskServiceSendNotificationResponse{
		Code:    grpc.TaskServiceSendNotificationResponseCode_NOTIFICATION_SUCCESS,
		Message: "Notification sent successfully",
	}, nil
}
// CheckProcess reports the state of the process with the requested pid on
// the machine running this server.
//
// Every response echoes the request's task id and pid, and the returned error
// is always nil; problems are conveyed through Status and ErrorMessage:
//   - task id is not a valid ObjectID hex: PROCESS_UNKNOWN, "invalid task id";
//   - pid <= 0: PROCESS_NOT_FOUND, "invalid process id";
//   - utils.ProcessIdExists reports false: PROCESS_NOT_FOUND, exit code -1;
//   - otherwise: whatever getProcessDetails returns for the pid.
func (svr TaskServiceServer) CheckProcess(_ context.Context, request *grpc.TaskServiceCheckProcessRequest) (response *grpc.TaskServiceCheckProcessResponse, err error) {
	// reply builds a response that echoes the request identifiers
	reply := func(status grpc.ProcessStatus, exitCode int32, message string) (*grpc.TaskServiceCheckProcessResponse, error) {
		return &grpc.TaskServiceCheckProcessResponse{
			TaskId:       request.TaskId,
			Pid:          request.Pid,
			Status:       status,
			ExitCode:     exitCode,
			ErrorMessage: message,
		}, nil
	}

	// the task id is only validated, it is not used for any lookup
	if _, parseErr := primitive.ObjectIDFromHex(request.TaskId); parseErr != nil {
		svr.Errorf("invalid task id: %s", request.TaskId)
		return reply(grpc.ProcessStatus_PROCESS_UNKNOWN, 0, "invalid task id")
	}

	pid := int(request.Pid)
	if pid <= 0 {
		return reply(grpc.ProcessStatus_PROCESS_NOT_FOUND, 0, "invalid process id")
	}

	// cheap existence probe before querying details
	if !utils.ProcessIdExists(pid) {
		return reply(grpc.ProcessStatus_PROCESS_NOT_FOUND, -1, "")
	}

	state, code, detail := svr.getProcessDetails(pid)
	return reply(state, int32(code), detail)
}
// getProcessDetails looks up pid in the process list returned by
// utils.GetProcesses and maps its reported status to a grpc.ProcessStatus.
//
// Returns (status, exitCode, errorMessage):
//   - listing or status query failed: PROCESS_UNKNOWN, -1, the failure text;
//   - running / sleeping / disk-sleep: PROCESS_RUNNING, 0, "";
//   - zombie: PROCESS_ZOMBIE, 0, "process is zombie";
//   - stopped / tracing stop: PROCESS_FINISHED, 0, "process stopped";
//   - any other status text: PROCESS_UNKNOWN, -1, "unknown process status: …";
//   - pid not present in the list: PROCESS_NOT_FOUND, -1, "process not found".
//
// The exit code is never the process's real exit code; it is 0 for a
// recognised status and -1 otherwise.
func (svr TaskServiceServer) getProcessDetails(pid int) (status grpc.ProcessStatus, exitCode int, errorMessage string) {
	// snapshot of all processes visible to this server
	procs, err := utils.GetProcesses()
	if err != nil {
		return grpc.ProcessStatus_PROCESS_UNKNOWN, -1, fmt.Sprintf("failed to get processes: %v", err)
	}

	// find the specific process
	for _, p := range procs {
		if int(p.Pid) != pid {
			continue
		}

		rawStatus, err := p.Status()
		if err != nil {
			return grpc.ProcessStatus_PROCESS_UNKNOWN, -1, fmt.Sprintf("failed to get process status: %v", err)
		}

		// Map the reported status to our enum. Besides the long names the
		// original code matched, also accept the short forms.
		// NOTE(review): presumably the process library here is gopsutil, whose
		// string-returning Status() reports single-letter codes (R, S, D, Z, T)
		// and whose v3 API names the stopped state "stop" — confirm against
		// utils.GetProcesses.
		switch strings.ToLower(rawStatus) {
		case "running", "sleep", "disk-sleep", "r", "s", "d":
			return grpc.ProcessStatus_PROCESS_RUNNING, 0, ""
		case "zombie", "z":
			return grpc.ProcessStatus_PROCESS_ZOMBIE, 0, "process is zombie"
		case "stopped", "stop", "tracing-stop", "t":
			return grpc.ProcessStatus_PROCESS_FINISHED, 0, "process stopped"
		default:
			return grpc.ProcessStatus_PROCESS_UNKNOWN, -1, fmt.Sprintf("unknown process status: %s", rawStatus)
		}
	}

	// process not found
	return grpc.ProcessStatus_PROCESS_NOT_FOUND, -1, "process not found"
}
// GetSubscribeStream returns the Subscribe stream currently registered for
// taskId; ok is false when no stream is registered for that task. The lookup
// is performed while holding taskServiceMutex.
func (svr TaskServiceServer) GetSubscribeStream(taskId primitive.ObjectID) (stream grpc.TaskService_SubscribeServer, ok bool) {
	taskServiceMutex.Lock()
	registered, found := svr.subs[taskId]
	taskServiceMutex.Unlock()
	return registered, found
}
// cleanupStaleStreams runs performStreamCleanup every 10 minutes and returns
// once cleanupCtx is cancelled. It blocks, so it is meant to be run in its
// own goroutine.
func (svr TaskServiceServer) cleanupStaleStreams() {
	const sweepInterval = 10 * time.Minute

	sweep := time.NewTicker(sweepInterval)
	defer sweep.Stop()

	shutdown := svr.cleanupCtx.Done()
	for {
		select {
		case <-sweep.C:
			svr.performStreamCleanup()
		case <-shutdown:
			svr.Debugf("stream cleanup routine shutting down")
			return
		}
	}
}
// performStreamCleanup removes from subs every stream whose context is
// already done, logging each removal and, if any were removed, the total.
// The whole sweep runs while holding taskServiceMutex.
func (svr TaskServiceServer) performStreamCleanup() {
	taskServiceMutex.Lock()
	defer taskServiceMutex.Unlock()

	removed := 0
	for id, sub := range svr.subs {
		// a nil Err means the stream's context is still live
		if sub.Context().Err() == nil {
			continue
		}
		// deleting the current key while ranging over a map is permitted in Go
		delete(svr.subs, id)
		svr.Infof("cleaned up stale stream for task: %s", id.Hex())
		removed++
	}

	if removed > 0 {
		svr.Infof("cleaned up %d stale streams", removed)
	}
}
// handleInsertData decodes msg.Data as a JSON array of objects, stamps every
// object with the task id and spider id (under constants.TaskKey and
// constants.SpiderKey) and passes them to the stats service. A decode failure
// is logged and returned; otherwise the result of InsertData is returned.
func (svr TaskServiceServer) handleInsertData(taskId, spiderId primitive.ObjectID, msg *grpc.TaskServiceConnectRequest) (err error) {
	var rows []map[string]any
	if decodeErr := json.Unmarshal(msg.Data, &rows); decodeErr != nil {
		svr.Errorf("error unmarshalling data: %v", decodeErr)
		return decodeErr
	}

	// maps are reference values, so writing through the range copy updates rows
	for _, row := range rows {
		row[constants.TaskKey] = taskId
		row[constants.SpiderKey] = spiderId
	}

	return svr.statsSvc.InsertData(taskId, rows...)
}
// handleInsertLogs decodes msg.Data as a JSON array of strings and passes the
// lines to the stats service for taskId. A decode failure is logged and
// returned; otherwise the result of InsertLogs is returned.
func (svr TaskServiceServer) handleInsertLogs(taskId primitive.ObjectID, msg *grpc.TaskServiceConnectRequest) (err error) {
	var lines []string
	if decodeErr := json.Unmarshal(msg.Data, &lines); decodeErr != nil {
		svr.Errorf("error unmarshalling logs: %v", decodeErr)
		return decodeErr
	}
	return svr.statsSvc.InsertLogs(taskId, lines...)
}
// saveTask marks t as updated by its own creator (t.CreatedBy) and replaces
// the stored task document with t's current contents, keyed by t.Id.
func (svr TaskServiceServer) saveTask(t *models.Task) (err error) {
	t.SetUpdated(t.CreatedBy)

	tasks := service.NewModelService[models.Task]()
	return tasks.ReplaceById(t.Id, *t)
}
// newTaskServiceServer builds a TaskServiceServer with its dependencies, an
// empty stream registry and a cancellable cleanup context, and starts the
// background cleanupStaleStreams goroutine before returning it.
func newTaskServiceServer() *TaskServiceServer {
	cleanupCtx, cleanupCancel := context.WithCancel(context.Background())

	svr := &TaskServiceServer{
		cleanupCtx:    cleanupCtx,
		cleanupCancel: cleanupCancel,
	}
	// dependencies are resolved in the same order as before
	svr.cfgSvc = nodeconfig.GetNodeConfigService()
	svr.subs = make(map[primitive.ObjectID]grpc.TaskService_SubscribeServer)
	svr.statsSvc = stats.GetTaskStatsService()
	svr.Logger = utils.NewLogger("GrpcTaskServiceServer")

	// background sweep of stale streams; ends when cleanupCtx is cancelled
	go svr.cleanupStaleStreams()

	return svr
}
// Stop shuts the task service server down: it cancels the cleanup context
// (ending the cleanupStaleStreams goroutine) and empties the stream registry
// under taskServiceMutex. Registered streams are only forgotten here; this
// method does not close them. It always returns nil.
func (svr TaskServiceServer) Stop() error {
	svr.Infof("stopping task service server...")

	// stop the background cleanup routine
	if cancel := svr.cleanupCancel; cancel != nil {
		cancel()
	}

	// drop every remaining stream registration
	taskServiceMutex.Lock()
	dropped := 0
	for id := range svr.subs {
		delete(svr.subs, id)
		dropped++
	}
	taskServiceMutex.Unlock()

	if dropped > 0 {
		svr.Infof("cleaned up %d remaining streams on shutdown", dropped)
	}

	svr.Infof("task service server stopped")
	return nil
}
// isTaskFinished reports whether the stored task is in any status other than
// pending or running. When the task cannot be loaded the error is logged at
// debug level and false is returned.
//
// NOTE(review): a task in the assigned status (set by FetchTask) also counts
// as finished under this rule — confirm that is intended.
func (svr TaskServiceServer) isTaskFinished(taskId primitive.ObjectID) bool {
	task, lookupErr := service.NewModelService[models.Task]().GetById(taskId)
	if lookupErr != nil {
		svr.Debugf("error checking task[%s] status: %v", taskId.Hex(), lookupErr)
		return false
	}

	switch task.Status {
	case constants.TaskStatusPending, constants.TaskStatusRunning:
		return false
	default:
		return true
	}
}
// _taskServiceServer holds the process-wide TaskServiceServer instance; it is
// created on the first call to GetTaskServiceServer.
var _taskServiceServer *TaskServiceServer
// _taskServiceServerOnce ensures _taskServiceServer is created only once.
var _taskServiceServerOnce sync.Once
// GetTaskServiceServer returns the shared TaskServiceServer, constructing it
// (and thereby starting its cleanup goroutine) on the first call. Later calls
// return the same instance.
func GetTaskServiceServer() *TaskServiceServer {
	initialize := func() {
		_taskServiceServer = newTaskServiceServer()
	}
	_taskServiceServerOnce.Do(initialize)
	return _taskServiceServer
}