diff --git a/core/apps/api.go b/core/apps/api.go index 047ca535..39c13c54 100644 --- a/core/apps/api.go +++ b/core/apps/api.go @@ -21,17 +21,26 @@ func init() { type Api struct { // internals - app *gin.Engine - ln net.Listener - srv *http.Server + app *gin.Engine + ln net.Listener + srv *http.Server + initialized bool } func (app *Api) Init() { + // skip if already initialized + if app.initialized { + return + } + // initialize middlewares _ = app.initModuleWithApp("middlewares", middlewares.InitMiddlewares) // initialize routes _ = app.initModuleWithApp("routes", controllers.InitRoutes) + + // set initialized + app.initialized = true } func (app *Api) Start() { @@ -90,9 +99,11 @@ func (app *Api) initModuleWithApp(name string, fn func(app *gin.Engine) error) ( } func newApi() *Api { - return &Api{ + api := &Api{ app: gin.New(), } + api.Init() + return api } var api *Api diff --git a/core/constants/ipc.go b/core/constants/ipc.go new file mode 100644 index 00000000..1b88b72a --- /dev/null +++ b/core/constants/ipc.go @@ -0,0 +1,6 @@ +package constants + +const ( + IPCMessageData = "data" // IPCMessageData is the message type identifier for data messages + IPCMessageLog = "log" // IPCMessageLog is the message type identifier for log messages +) diff --git a/core/entity/ipc.go b/core/entity/ipc.go new file mode 100644 index 00000000..6446891e --- /dev/null +++ b/core/entity/ipc.go @@ -0,0 +1,8 @@ +package entity + +// IPCMessage defines the structure for messages exchanged between parent and child processes +type IPCMessage struct { + Type string `json:"type"` // message type identifier + Payload interface{} `json:"payload"` // message content + IPC bool `json:"ipc"` // Add this field to explicitly mark IPC messages +} diff --git a/core/grpc/server/task_service_server.go b/core/grpc/server/task_service_server.go index e9f69be6..2dec2bf5 100644 --- a/core/grpc/server/task_service_server.go +++ b/core/grpc/server/task_service_server.go @@ -4,6 +4,10 @@ import ( "context" "encoding/json" "errors" + "io" + "strings" + "sync" + "github.com/apex/log" "github.com/crawlab-team/crawlab/core/constants" "github.com/crawlab-team/crawlab/core/interfaces" @@ -19,9 +23,6 @@ import ( "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson/primitive" mongo2 "go.mongodb.org/mongo-driver/mongo" - "io" - "strings" - "sync" ) var taskServiceMutex = sync.Mutex{} @@ -67,38 +68,66 @@ func (svr TaskServiceServer) Subscribe(req *grpc.TaskServiceSubscribeRequest, st } // Connect to task stream when a task runner in a node starts +// Connect handles the bidirectional streaming connection from task runners in nodes. +// It receives messages containing task data and logs, processes them, and handles any errors. func (svr TaskServiceServer) Connect(stream grpc.TaskService_ConnectServer) (err error) { + // spider id and task id to track which spider/task this connection belongs to + var spiderId primitive.ObjectID + var taskId primitive.ObjectID + + // continuously receive messages from the stream for { + // receive next message from stream msg, err := stream.Recv() if err == io.EOF { + // stream has ended normally return nil } if err != nil { + // handle graceful context cancellation if strings.HasSuffix(err.Error(), "context canceled") { return nil } + // log other stream receive errors and continue log.Errorf("error receiving stream message: %v", err) continue } - // validate task id - taskId, err := primitive.ObjectIDFromHex(msg.TaskId) - if err != nil { - log.Errorf("invalid task id: %s", msg.TaskId) - continue + // validate and parse the task ID from the message if not already set + if taskId.IsZero() { + taskId, err = primitive.ObjectIDFromHex(msg.TaskId) + if err != nil { + log.Errorf("invalid task id: %s", msg.TaskId) + continue + } } + // get spider id if not already set + // this only needs to be done once per connection + if spiderId.IsZero() { + t, err := service.NewModelService[models.Task]().GetById(taskId) + if err != nil { + log.Errorf("error getting spider[%s]: %v", taskId.Hex(), err) + continue + } + spiderId = t.SpiderId + } + + // handle different message types based on the code switch msg.Code { case grpc.TaskServiceConnectCode_INSERT_DATA: - err = svr.handleInsertData(taskId, msg) + // handle scraped data insertion + err = svr.handleInsertData(taskId, spiderId, msg) case grpc.TaskServiceConnectCode_INSERT_LOGS: + // handle task log insertion err = svr.handleInsertLogs(taskId, msg) default: - err = errors.New("invalid stream message code") + // invalid message code received log.Errorf("invalid stream message code: %d", msg.Code) continue } if err != nil { + // log any errors from handlers log.Errorf("grpc error[%d]: %v", msg.Code, err) } } @@ -170,7 +199,7 @@ func (svr TaskServiceServer) SendNotification(_ context.Context, request *grpc.T taskId, err := primitive.ObjectIDFromHex(request.TaskId) if err != nil { log.Errorf("invalid task id: %s", request.TaskId) - return nil, trace.TraceError(err) + return nil, err } // arguments @@ -179,31 +208,32 @@ func (svr TaskServiceServer) SendNotification(_ context.Context, request *grpc.T // task task, err := service.NewModelService[models.Task]().GetById(taskId) if err != nil { - log.Errorf("task not found: %s", request.TaskId) - return nil, trace.TraceError(err) + log.Errorf("error getting task[%s]: %v", request.TaskId, err) + return nil, err } args = append(args, task) // task stat taskStat, err := service.NewModelService[models.TaskStat]().GetById(task.Id) if err != nil { - log.Errorf("task stat not found for task: %s", request.TaskId) - return nil, trace.TraceError(err) + log.Errorf("error getting task stat for task[%s]: %v", request.TaskId, err) + return nil, err } args = append(args, taskStat) // spider spider, err := service.NewModelService[models.Spider]().GetById(task.SpiderId) if err != nil { - log.Errorf("spider not found for task: %s", request.TaskId) - return nil, trace.TraceError(err) + log.Errorf("error getting spider[%s]: %v", task.SpiderId.Hex(), err) + return nil, err } args = append(args, spider) // node node, err := service.NewModelService[models.Node]().GetById(task.NodeId) if err != nil { - return nil, trace.TraceError(err) + log.Errorf("error getting node[%s]: %v", task.NodeId.Hex(), err) + return nil, err } args = append(args, node) @@ -212,8 +242,8 @@ func (svr TaskServiceServer) SendNotification(_ context.Context, request *grpc.T if !task.ScheduleId.IsZero() { schedule, err = service.NewModelService[models.Schedule]().GetById(task.ScheduleId) if err != nil { - log.Errorf("schedule not found for task: %s", request.TaskId) - return nil, trace.TraceError(err) + log.Errorf("error getting schedule[%s]: %v", task.ScheduleId.Hex(), err) + return nil, err } args = append(args, schedule) } @@ -226,7 +256,8 @@ func (svr TaskServiceServer) SendNotification(_ context.Context, request *grpc.T }, }, nil) if err != nil { - return nil, trace.TraceError(err) + log.Errorf("error getting notification settings: %v", err) + return nil, err } // notification service @@ -268,11 +299,16 @@ func (svr TaskServiceServer) GetSubscribeStream(taskId primitive.ObjectID) (stre return stream, ok } -func (svr TaskServiceServer) handleInsertData(taskId primitive.ObjectID, msg *grpc.TaskServiceConnectRequest) (err error) { +func (svr TaskServiceServer) handleInsertData(taskId, spiderId primitive.ObjectID, msg *grpc.TaskServiceConnectRequest) (err error) { var records []map[string]interface{} err = json.Unmarshal(msg.Data, &records) if err != nil { - return trace.TraceError(err) + log.Errorf("error unmarshalling data: %v", err) + return err + } + for i := range records { + records[i][constants.TaskKey] = taskId + records[i][constants.SpiderKey] = spiderId } return svr.statsSvc.InsertData(taskId, records...) } @@ -281,7 +317,8 @@ func (svr TaskServiceServer) handleInsertLogs(taskId primitive.ObjectID, msg *gr var logs []string err = json.Unmarshal(msg.Data, &logs) if err != nil { - return trace.TraceError(err) + log.Errorf("error unmarshalling logs: %v", err) + return err } return svr.statsSvc.InsertLogs(taskId, logs...) } diff --git a/core/task/handler/runner.go b/core/task/handler/runner.go index 474a2eb3..5e22f468 100644 --- a/core/task/handler/runner.go +++ b/core/task/handler/runner.go @@ -62,10 +62,10 @@ type Runner struct { logBatchSize int // number of log lines to batch before sending // IPC (Inter-Process Communication) - stdinPipe io.WriteCloser // pipe for writing to child process - stdoutPipe io.ReadCloser // pipe for reading from child process - ipcChan chan IPCMessage // channel for sending IPC messages - ipcHandler func(IPCMessage) // callback for handling received IPC messages + stdinPipe io.WriteCloser // pipe for writing to child process + stdoutPipe io.ReadCloser // pipe for reading from child process + ipcChan chan entity.IPCMessage // channel for sending IPC messages + ipcHandler func(entity.IPCMessage) // callback for handling received IPC messages // goroutine management ctx context.Context // context for controlling goroutine lifecycle @@ -74,18 +74,6 @@ type Runner struct { wg sync.WaitGroup // wait group for goroutine synchronization } -const ( - IPCMessageData = "data" // IPCMessageData is the message type identifier for data messages - IPCMessageLog = "log" // IPCMessageLog is the message type identifier for log messages -) - -// IPCMessage defines the structure for messages exchanged between parent and child processes -type IPCMessage struct { - Type string `json:"type"` // message type identifier - Payload interface{} `json:"payload"` // message content - IPC bool `json:"ipc"` // Add this field to explicitly mark IPC messages -} - // Init initializes the task runner by updating the task status and establishing gRPC connections func (r *Runner) Init() (err error) { // update task @@ -162,10 +150,11 @@ func (r *Runner) Run() (err error) { // Ensure cleanup when Run() exits defer func() { - r.cancel() // Cancel context to stop all goroutines - r.wg.Wait() // Wait for all goroutines to finish - close(r.done) // Signal that everything is done - close(r.ipcChan) // Close IPC channel + _ = r.conn.CloseSend() // Close gRPC connection + r.cancel() // Cancel context to stop all goroutines + r.wg.Wait() // Wait for all goroutines to finish + close(r.done) // Signal that everything is done + close(r.ipcChan) // Close IPC channel }() // wait for process to finish @@ -258,7 +247,7 @@ func (r *Runner) configureCmd() (err error) { } // Initialize IPC channel - r.ipcChan = make(chan IPCMessage) + r.ipcChan = make(chan entity.IPCMessage) return nil } @@ -782,7 +771,7 @@ func (r *Runner) handleIPC() { // msgType: type of message being sent // payload: data to be sent to the child process func (r *Runner) SendToChild(msgType string, payload interface{}) { - r.ipcChan <- IPCMessage{ + r.ipcChan <- entity.IPCMessage{ Type: msgType, Payload: payload, IPC: true, // Explicitly mark as IPC message @@ -790,7 +779,7 @@ func (r *Runner) SendToChild(msgType string, payload interface{}) { } // SetIPCHandler sets the handler for incoming IPC messages -func (r *Runner) SetIPCHandler(handler func(IPCMessage)) { +func (r *Runner) SetIPCHandler(handler func(entity.IPCMessage)) { r.ipcHandler = handler } @@ -811,7 +800,7 @@ func (r *Runner) startIPCReader() { } line := scanner.Text() - var ipcMsg IPCMessage + var ipcMsg entity.IPCMessage err := json.Unmarshal([]byte(line), &ipcMsg) if err == nil && ipcMsg.IPC { // Only handle as IPC if it's valid JSON AND has IPC flag set @@ -819,7 +808,7 @@ func (r *Runner) startIPCReader() { r.ipcHandler(ipcMsg) } else { // Default handler (insert data) - if ipcMsg.Type == "" || ipcMsg.Type == IPCMessageData { + if ipcMsg.Type == "" || ipcMsg.Type == constants.IPCMessageData { r.handleIPCInsertDataMessage(ipcMsg) } else { log.Warnf("no IPC handler set for message: %v", ipcMsg) @@ -834,7 +823,7 @@ func (r *Runner) startIPCReader() { } // handleIPCInsertDataMessage converts the IPC message payload to JSON and sends it to the master node -func (r *Runner) handleIPCInsertDataMessage(ipcMsg IPCMessage) { +func (r *Runner) handleIPCInsertDataMessage(ipcMsg entity.IPCMessage) { // Validate message if ipcMsg.Payload == nil { log.Errorf("empty payload in IPC message") diff --git a/core/task/stats/service.go b/core/task/stats/service.go index 6130201f..75d64741 100644 --- a/core/task/stats/service.go +++ b/core/task/stats/service.go @@ -63,11 +63,11 @@ func (svc *Service) InsertData(taskId primitive.ObjectID, records ...map[string] count++ } } else { - var records2 []interface{} + var recordsToInsert []interface{} for _, record := range records { - records2 = append(records2, svc.normalizeRecord(item, record)) + recordsToInsert = append(recordsToInsert, svc.normalizeRecord(item, record)) } - _, err = mongo.GetMongoCol(tableName).InsertMany(records2) + _, err = mongo.GetMongoCol(tableName).InsertMany(recordsToInsert) if err != nil { log2.Errorf("failed to insert data: %v", err) return err