diff --git a/core/node/service/master_service.go b/core/node/service/master_service.go
index dd3caeeb..759730df 100644
--- a/core/node/service/master_service.go
+++ b/core/node/service/master_service.go
@@ -339,7 +339,7 @@ func newMasterService() *MasterService {
         systemSvc:             system.GetSystemService(),
         healthSvc:             GetHealthService(),
         nodeMonitoringSvc:     NewNodeMonitoringService(cfgSvc),
-        taskReconciliationSvc: NewTaskReconciliationService(server),
+        taskReconciliationSvc: NewTaskReconciliationService(server, handler.GetTaskHandlerService()),
         Logger:                utils.NewLogger("MasterService"),
     }
 }
diff --git a/core/node/service/task_reconciliation_service.go b/core/node/service/task_reconciliation_service.go
index fd57490b..e3ab0f76 100644
--- a/core/node/service/task_reconciliation_service.go
+++ b/core/node/service/task_reconciliation_service.go
@@ -12,6 +12,7 @@ import (
     "github.com/crawlab-team/crawlab/core/interfaces"
     "github.com/crawlab-team/crawlab/core/models/models"
     "github.com/crawlab-team/crawlab/core/models/service"
+    "github.com/crawlab-team/crawlab/core/task/handler"
     "github.com/crawlab-team/crawlab/core/utils"
     "github.com/crawlab-team/crawlab/grpc"
     "go.mongodb.org/mongo-driver/bson"
@@ -20,7 +21,8 @@ import (
 
 // TaskReconciliationService handles task status reconciliation for node disconnection scenarios
 type TaskReconciliationService struct {
-    server *server.GrpcServer
+    server         *server.GrpcServer
+    taskHandlerSvc *handler.Service // access to task handlers and their status caches
     interfaces.Logger
 }
@@ -63,7 +65,36 @@ func (svc *TaskReconciliationService) HandleTasksForOfflineNode(node *models.Nod
     }
 }
 
+// triggerWorkerStatusSync triggers synchronization of cached status from worker to database
+func (svc *TaskReconciliationService) triggerWorkerStatusSync(task *models.Task) error {
+    // Check if we have access to task handler service (only on worker nodes)
+    if svc.taskHandlerSvc == nil {
+        return fmt.Errorf("task handler service not available - not on worker node")
+    }
+
+    // Get the task runner for this task
+    taskRunner := svc.taskHandlerSvc.GetTaskRunner(task.Id)
+    if taskRunner == nil {
+        return fmt.Errorf("no active task runner found for task %s", task.Id.Hex())
+    }
+
+    // Cast to concrete Runner type to access status cache methods
+    runner, ok := taskRunner.(*handler.Runner)
+    if !ok {
+        return fmt.Errorf("task runner is not of expected type for task %s", task.Id.Hex())
+    }
+
+    // Trigger sync of pending status updates
+    if err := runner.SyncPendingStatusUpdates(); err != nil {
+        return fmt.Errorf("failed to sync pending status updates: %w", err)
+    }
+
+    svc.Infof("successfully triggered status sync for task[%s]", task.Id.Hex())
+    return nil
+}
+
 // HandleNodeReconnection reconciles tasks that were marked as disconnected when the node comes back online
+// Now leverages worker-side status cache for more accurate reconciliation
 func (svc *TaskReconciliationService) HandleNodeReconnection(node *models.Node) {
     // Find all disconnected tasks on this node
     query := bson.M{
@@ -86,6 +117,11 @@ func (svc *TaskReconciliationService) HandleNodeReconnection(node *models.Node)
 
     // For each disconnected task, try to get its actual status from the worker node
     for _, task := range disconnectedTasks {
+        // First, try to trigger status sync from worker cache if we're on the worker node
+        if err := svc.triggerWorkerStatusSync(&task); err != nil {
+            svc.Debugf("could not trigger worker status sync for task[%s]: %v", task.Id.Hex(), err)
+        }
+
         actualStatus, err := svc.GetActualTaskStatusFromWorker(node, &task)
         if err != nil {
             svc.Warnf("failed to get actual status for task[%s] from reconnected node[%s]: %v", task.Id.Hex(), node.Key, err)
@@ -121,19 +157,57 @@ func (svc *TaskReconciliationService) HandleNodeReconnection(node *models.Node)
 }
 
 // GetActualTaskStatusFromWorker queries the worker node to get the actual status of a task
+// Now prioritizes worker-side status cache over heuristics
 func (svc *TaskReconciliationService) GetActualTaskStatusFromWorker(node *models.Node, task *models.Task) (status string, err error) {
-    // First, try to get the actual process status from the worker
+    // First priority: get status from worker-side task runner cache
+    cachedStatus, err := svc.getStatusFromWorkerCache(task)
+    if err == nil && cachedStatus != "" {
+        svc.Debugf("retrieved cached status for task[%s]: %s", task.Id.Hex(), cachedStatus)
+        return cachedStatus, nil
+    }
+
+    // Second priority: query process status from worker
     actualProcessStatus, err := svc.queryProcessStatusFromWorker(node, task)
     if err != nil {
         svc.Warnf("failed to query process status from worker node[%s] for task[%s]: %v", node.Key, task.Id.Hex(), err)
-        // Fall back to heuristic detection
-        return svc.detectTaskStatusFromHeuristics(task)
+        // Return error instead of falling back to unreliable heuristics
+        return "", fmt.Errorf("unable to determine actual task status: %w", err)
     }
 
     // Synchronize task status with actual process status
     return svc.syncTaskStatusWithProcess(task, actualProcessStatus)
 }
 
+// getStatusFromWorkerCache retrieves task status from worker-side task runner cache
+func (svc *TaskReconciliationService) getStatusFromWorkerCache(task *models.Task) (string, error) {
+    // Check if we have access to task handler service (only on worker nodes)
+    if svc.taskHandlerSvc == nil {
+        return "", fmt.Errorf("task handler service not available - not on worker node")
+    }
+
+    // Get the task runner for this task
+    taskRunner := svc.taskHandlerSvc.GetTaskRunner(task.Id)
+    if taskRunner == nil {
+        return "", fmt.Errorf("no active task runner found for task %s", task.Id.Hex())
+    }
+
+    // Cast to concrete Runner type to access status cache methods
+    runner, ok := taskRunner.(*handler.Runner)
+    if !ok {
+        return "", fmt.Errorf("task runner is not of expected type for task %s", task.Id.Hex())
+    }
+
+    // Get cached status from the runner
+    cachedSnapshot := runner.GetCachedTaskStatus()
+    if cachedSnapshot == nil {
+        return "", fmt.Errorf("no cached status available for task %s", task.Id.Hex())
+    }
+
+    svc.Infof("retrieved cached status for task[%s]: %s (cached at %v)",
+        task.Id.Hex(), cachedSnapshot.Status, cachedSnapshot.Timestamp)
+    return cachedSnapshot.Status, nil
+}
+
 // queryProcessStatusFromWorker directly queries the worker node for the actual process status
 func (svc *TaskReconciliationService) queryProcessStatusFromWorker(node *models.Node, task *models.Task) (processStatus string, err error) {
     // Check if there's an active stream for this task
@@ -142,7 +216,7 @@ func (svc *TaskReconciliationService) queryProcessStatusFromWorker(node *models.
     // Check if the node is still connected via subscription
     nodeStream, nodeConnected := svc.server.NodeSvr.GetSubscribeStream(node.Id)
     if !nodeConnected {
-        return svc.inferProcessStatusFromLocalState(task, hasActiveStream)
+        return "", fmt.Errorf("node[%s] is not connected", node.Key)
     }
 
     // Query the worker for actual process status
@@ -150,102 +224,41 @@
         // Send a process status query to the worker
         actualStatus, err := svc.requestProcessStatusFromWorker(nodeStream, task, 5*time.Second)
         if err != nil {
-            svc.Warnf("failed to get process status from worker: %v", err)
-            return svc.inferProcessStatusFromLocalState(task, hasActiveStream)
+            return "", fmt.Errorf("failed to get process status from worker: %w", err)
         }
 
         return actualStatus, nil
     }
 
-    return svc.inferProcessStatusFromLocalState(task, hasActiveStream)
+    // If we can't query the worker directly, fall back to whether an active task stream still exists
+    if hasActiveStream {
+        return constants.TaskStatusRunning, nil // Task likely still running if stream exists
+    }
+    return "", fmt.Errorf("unable to determine process status for task[%s] on node[%s]", task.Id.Hex(), node.Key)
 }
 
 // requestProcessStatusFromWorker sends a status query request to the worker node
 func (svc *TaskReconciliationService) requestProcessStatusFromWorker(nodeStream grpc.NodeService_SubscribeServer, task *models.Task, timeout time.Duration) (string, error) {
     // Check if task has a valid PID
     if task.Pid <= 0 {
-        return svc.inferProcessStatusFromLocalState(task, false)
+        return "", fmt.Errorf("task[%s] has invalid PID: %d", task.Id.Hex(), task.Pid)
     }
 
     // Get the node for this task
     node, err := service.NewModelService[models.Node]().GetById(task.NodeId)
     if err != nil {
-        svc.Warnf("failed to get node[%s] for task[%s]: %v", task.NodeId.Hex(), task.Id.Hex(), err)
-        _, hasActiveStream := svc.server.TaskSvr.GetSubscribeStream(task.Id)
-        return svc.inferProcessStatusFromLocalState(task, hasActiveStream)
+        return "", fmt.Errorf("failed to get node[%s] for task[%s]: %w", task.NodeId.Hex(), task.Id.Hex(), err)
     }
 
-    // Attempt to query worker directly (future implementation)
-    // This will return an error until worker discovery infrastructure is built
+    // Attempt to query worker directly
     workerStatus, err := svc.queryWorkerProcessStatus(node, task, timeout)
     if err != nil {
-        svc.Debugf("direct worker query not available, falling back to heuristics: %v", err)
-
-        // Fallback to heuristic detection
-        _, hasActiveStream := svc.server.TaskSvr.GetSubscribeStream(task.Id)
-        return svc.inferProcessStatusFromLocalState(task, hasActiveStream)
+        return "", fmt.Errorf("worker process status query failed: %w", err)
     }
 
     svc.Infof("successfully queried worker process status for task[%s]: %s", task.Id.Hex(), workerStatus)
     return workerStatus, nil
 }
 
-// inferProcessStatusFromLocalState uses local information to infer process status
-func (svc *TaskReconciliationService) inferProcessStatusFromLocalState(task *models.Task, hasActiveStream bool) (string, error) {
-    // Check if task has been updated recently (within last 30 seconds)
-    isRecentlyUpdated := time.Since(task.UpdatedAt) < 30*time.Second
-
-    switch {
-    case hasActiveStream && isRecentlyUpdated:
-        // Active stream and recent updates = likely running
-        return constants.TaskStatusRunning, nil
-
-    case !hasActiveStream && isRecentlyUpdated:
-        // No stream but recent updates = likely just finished
-        if task.Error != "" {
-            return constants.TaskStatusError, nil
-        }
-        return constants.TaskStatusFinished, nil
-
-    case !hasActiveStream && !isRecentlyUpdated:
-        // No stream and stale = process likely finished or failed
-        return svc.checkFinalTaskState(task), nil
-
-    case hasActiveStream && !isRecentlyUpdated:
-        // Stream exists but no recent updates - could be a long-running task
-        // Don't assume abnormal - the task might be legitimately running without frequent updates
-        return constants.TaskStatusRunning, nil
-
-    default:
-        // Fallback
-        return constants.TaskStatusError, nil
-    }
-}
-
-// checkFinalTaskState determines the final state of a task without active streams
-func (svc *TaskReconciliationService) checkFinalTaskState(task *models.Task) string {
-    // Check the current task status and error state
-    switch task.Status {
-    case constants.TaskStatusFinished, constants.TaskStatusError, constants.TaskStatusCancelled, constants.TaskStatusAbnormal:
-        // Already in a final state
-        return task.Status
-    case constants.TaskStatusRunning:
-        // Running status but no stream = process likely completed
-        if task.Error != "" {
-            return constants.TaskStatusError
-        }
-        return constants.TaskStatusFinished
-    case constants.TaskStatusPending, constants.TaskStatusAssigned:
-        // Never started running but lost connection
-        return constants.TaskStatusError
-    case constants.TaskStatusNodeDisconnected:
-        // Task is marked as disconnected - keep this status since we can't determine final state
-        // Don't assume abnormal until we can actually verify the process state
-        return constants.TaskStatusNodeDisconnected
-    default:
-        return constants.TaskStatusError
-    }
-}
-
 // mapProcessStatusToTaskStatus converts gRPC process status to task status
 func (svc *TaskReconciliationService) mapProcessStatusToTaskStatus(processStatus grpc.ProcessStatus, exitCode int32, task *models.Task) string {
     switch processStatus {
@@ -277,10 +290,9 @@ func (svc *TaskReconciliationService) mapProcessStatusToTaskStatus(processStatus
     case grpc.ProcessStatus_PROCESS_UNKNOWN:
         fallthrough
     default:
-        // Unknown status - use heuristic detection
-        _, hasActiveStream := svc.server.TaskSvr.GetSubscribeStream(task.Id)
-        status, _ := svc.inferProcessStatusFromLocalState(task, hasActiveStream)
-        return status
+        // Unknown status - return error instead of using heuristics
+        svc.Warnf("unknown process status %v for task[%s]", processStatus, task.Id.Hex())
+        return constants.TaskStatusError
     }
 }
 
@@ -379,13 +391,6 @@ func (svc *TaskReconciliationService) updateTaskStatusReliably(task *models.Task
     }, backoff.WithMaxRetries(backoff.NewConstantBackOff(500*time.Millisecond), 3))
 }
 
-// detectTaskStatusFromHeuristics provides fallback detection when worker communication fails
-func (svc *TaskReconciliationService) detectTaskStatusFromHeuristics(task *models.Task) (string, error) {
-    // Use improved heuristic detection
-    _, hasActiveStream := svc.server.TaskSvr.GetSubscribeStream(task.Id)
-    return svc.inferProcessStatusFromLocalState(task, hasActiveStream)
-}
-
 // StartPeriodicReconciliation starts a background service to periodically reconcile task status
 func (svc *TaskReconciliationService) StartPeriodicReconciliation() {
     go svc.runPeriodicReconciliation()
@@ -470,102 +475,11 @@ func (svc *TaskReconciliationService) ForceReconcileTask(taskId primitive.Object
     return svc.reconcileTaskStatus(task)
 }
 
-// detectTaskStatusFromActivity analyzes task activity to determine its actual status
-func (svc *TaskReconciliationService) detectTaskStatusFromActivity(task *models.Task, hasActiveStream bool) (string, error) {
-    // Check if task has been updated recently (within last 30 seconds)
-    if time.Since(task.UpdatedAt) < 30*time.Second {
-        // Task was recently updated, likely still active
-        if hasActiveStream {
-            return constants.TaskStatusRunning, nil
-        }
-        // Recently updated but no stream - check if it finished
-        return svc.checkTaskCompletion(task), nil
-    }
-
-    // Task hasn't been updated recently
-    if !hasActiveStream {
-        // No stream and no recent activity - likely finished or failed
-        return svc.checkTaskCompletion(task), nil
-    }
-
-    // Has stream but no recent updates - might be stuck
-    return constants.TaskStatusRunning, nil
-}
-
-// checkTaskCompletion determines if a task completed successfully or failed
-func (svc *TaskReconciliationService) checkTaskCompletion(task *models.Task) string {
-    // Refresh task from database to get latest status
-    latestTask, err := service.NewModelService[models.Task]().GetById(task.Id)
-    if err != nil {
-        svc.Warnf("failed to refresh task[%s] from database: %v", task.Id.Hex(), err)
-        return constants.TaskStatusError
-    }
-
-    // If task status was already updated to a final state, return that
-    switch latestTask.Status {
-    case constants.TaskStatusFinished, constants.TaskStatusError, constants.TaskStatusCancelled:
-        return latestTask.Status
-    case constants.TaskStatusAbnormal:
-        // Abnormal status is also final - keep it
-        return latestTask.Status
-    case constants.TaskStatusRunning:
-        // Task shows as running but has no active stream - need to determine actual status
-        if latestTask.Error != "" {
-            return constants.TaskStatusError
-        }
-        return constants.TaskStatusFinished
-    case constants.TaskStatusPending, constants.TaskStatusAssigned:
-        // Tasks that never started running but lost connection - mark as error
-        return constants.TaskStatusError
-    case constants.TaskStatusNodeDisconnected:
-        // Node disconnected status should be handled by reconnection logic
-        // Keep the disconnected status since we don't know the actual final state
-        return constants.TaskStatusNodeDisconnected
-    default:
-        // Unknown status - mark as error
-        svc.Warnf("task[%s] has unknown status: %s", task.Id.Hex(), latestTask.Status)
-        return constants.TaskStatusError
-    }
-}
-
-// inferTaskStatusFromStream provides a fallback status inference based on stream presence
-func (svc *TaskReconciliationService) inferTaskStatusFromStream(taskId primitive.ObjectID, hasActiveStream bool) string {
-    if !hasActiveStream {
-        // No active stream could mean:
-        // 1. Task finished successfully
-        // 2. Task failed and stream was closed
-        // 3. Worker disconnected ungracefully
-        //
-        // To determine which, we should check the task in the database
-        task, err := service.NewModelService[models.Task]().GetById(taskId)
-        if err != nil {
-            // If we can't find the task, assume it's in an error state
-            return constants.TaskStatusError
-        }
-
-        // If the task was last seen running and now has no stream,
-        // it likely finished or errored
-        switch task.Status {
-        case constants.TaskStatusRunning:
-            // Task was running but stream is gone - likely finished
-            return constants.TaskStatusFinished
-        case constants.TaskStatusPending, constants.TaskStatusAssigned:
-            // Task never started running - likely error
-            return constants.TaskStatusError
-        default:
-            // Return the last known status
-            return task.Status
-        }
-    }
-
-    // Stream exists, so task is likely still running
-    return constants.TaskStatusRunning
-}
-
-func NewTaskReconciliationService(server *server.GrpcServer) *TaskReconciliationService {
+func NewTaskReconciliationService(server *server.GrpcServer, taskHandlerSvc *handler.Service) *TaskReconciliationService {
     return &TaskReconciliationService{
-        server: server,
-        Logger: utils.NewLogger("TaskReconciliationService"),
+        server:         server,
+        taskHandlerSvc: taskHandlerSvc,
+        Logger:         utils.NewLogger("TaskReconciliationService"),
     }
 }
@@ -636,7 +550,13 @@ func GetTaskReconciliationService() *TaskReconciliationService {
     taskReconciliationServiceOnce.Do(func() {
         // Get the server from gRPC server singleton
         grpcServer := server.GetGrpcServer()
-        taskReconciliationService = NewTaskReconciliationService(grpcServer)
+        // Try to get task handler service (will be nil on master nodes)
+        var taskHandlerSvc *handler.Service
+        if !utils.IsMaster() {
+            // Only worker nodes have task handler service
+            taskHandlerSvc = handler.GetTaskHandlerService()
+        }
+        taskReconciliationService = NewTaskReconciliationService(grpcServer, taskHandlerSvc)
     })
     return taskReconciliationService
 }
diff --git a/core/task/handler/runner.go b/core/task/handler/runner.go
index f5fef6b8..980723a7 100644
--- a/core/task/handler/runner.go
+++ b/core/task/handler/runner.go
@@ -79,6 +79,12 @@ func newTaskRunner(id primitive.ObjectID, svc *Service) (r *Runner, err error) {
     r.ctx, r.cancel = context.WithCancel(svc.ctx)
     r.done = make(chan struct{})
 
+    // Initialize status cache for disconnection resilience
+    if err := r.initStatusCache(); err != nil {
+        r.Errorf("error initializing status cache: %v", err)
+        errs.Errors = append(errs.Errors, err)
+    }
+
     // initialize task runner
     if err := r.Init(); err != nil {
         r.Errorf("error initializing task runner: %v", err)
@@ -138,16 +144,21 @@ type Runner struct {
 
     // circuit breaker for log connections to prevent cascading failures
     logConnHealthy bool // tracks if log connection is healthy
-    logConnMutex sync.RWMutex // mutex for log connection health state
-    lastLogSendFailure time.Time // last time log send failed
-    logCircuitOpenTime time.Time // when circuit breaker was opened
-    logFailureCount int // consecutive log send failures
+    logConnMutex           sync.RWMutex  // mutex for log connection health state
+    lastLogSendFailure     time.Time     // last time log send failed
+    logCircuitOpenTime     time.Time     // when circuit breaker was opened
+    logFailureCount        int           // consecutive log send failures
     logCircuitOpenDuration time.Duration // how long to keep circuit open after failures
 
     // configurable timeouts for robust task execution
     ipcTimeout          time.Duration // timeout for IPC operations
     healthCheckInterval time.Duration // interval for health checks
     connHealthInterval  time.Duration // interval for connection health checks
+
+    // status cache for disconnection resilience
+    statusCache      *TaskStatusCache     // local status cache that survives disconnections
+    pendingUpdates   []TaskStatusSnapshot // status updates to sync when reconnected
+    statusCacheMutex sync.RWMutex         // mutex for status cache operations
 }
 
 // Init initializes the task runner by updating the task status and establishing gRPC connections
@@ -268,6 +279,9 @@ func (r *Runner) Run() (err error) {
         if r.ipcChan != nil {
             close(r.ipcChan)
         }
+
+        // 6. Clean up status cache for completed tasks
+        r.cleanupStatusCache()
     }()
 
     // wait for process to finish
@@ -499,6 +513,9 @@ func (r *Runner) updateTask(status string, e error) (err error) {
     }
 
     if r.t != nil && status != "" {
+        // Cache status locally first (always succeeds)
+        r.cacheTaskStatus(status, e)
+
         // update task status
         r.t.Status = status
         if e != nil {
@@ -507,18 +524,22 @@ func (r *Runner) updateTask(status string, e error) (err error) {
         if utils.IsMaster() {
             err = service.NewModelService[models.Task]().ReplaceById(r.t.Id, *r.t)
             if err != nil {
-                return err
+                r.Warnf("failed to update task in database, but cached locally: %v", err)
+                // Don't return error - the status is cached and will be synced later
             }
         } else {
             err = client.NewModelService[models.Task]().ReplaceById(r.t.Id, *r.t)
             if err != nil {
-                return err
+                r.Warnf("failed to update task in database, but cached locally: %v", err)
+                // Don't return error - the status is cached and will be synced later
             }
         }
 
-        // update stats
-        r.updateTaskStat(status)
-        r.updateSpiderStat(status)
+        // update stats (only if database update succeeded)
+        if err == nil {
+            r.updateTaskStat(status)
+            r.updateSpiderStat(status)
+        }
 
         // send notification
         go r.sendNotification()
@@ -703,7 +724,7 @@ func (r *Runner) reconnectWithRetry() error {
         r.lastConnCheck = time.Now()
         r.connRetryAttempts = 0
         r.Infof("successfully reconnected to task service after %d attempts", attempt+1)
-
+
         // Reset log circuit breaker when connection is restored
         r.logConnMutex.Lock()
         if !r.logConnHealthy {
@@ -712,7 +733,14 @@ func (r *Runner) reconnectWithRetry() error {
             r.logConnHealthy = true
             r.logFailureCount = 0
             r.Logger.Info("log circuit breaker reset after successful reconnection")
         }
         r.logConnMutex.Unlock()
-
+
+        // Sync pending status updates after successful reconnection
+        go func() {
+            if err := r.syncPendingStatusUpdates(); err != nil {
+                r.Errorf("failed to sync pending status updates after reconnection: %v", err)
+            }
+        }()
+
         return nil
     }
diff --git a/core/task/handler/runner_status_cache.go b/core/task/handler/runner_status_cache.go
new file mode 100644
index 00000000..f555766d
--- /dev/null
+++ b/core/task/handler/runner_status_cache.go
@@ -0,0 +1,203 @@
+package handler
+
+import (
+    "encoding/json"
+    "fmt"
+    "os"
+    "path/filepath"
+    "sync"
+    "time"
+
+    "github.com/crawlab-team/crawlab/core/models/client"
+    "github.com/crawlab-team/crawlab/core/models/models"
+    "github.com/crawlab-team/crawlab/core/models/service"
+    "github.com/crawlab-team/crawlab/core/utils"
+    "go.mongodb.org/mongo-driver/bson/primitive"
+)
+
+// TaskStatusSnapshot represents a point-in-time status of a task for caching
+type TaskStatusSnapshot struct {
+    TaskId    primitive.ObjectID `json:"task_id"`
+    Status    string             `json:"status"`
+    Error     string             `json:"error,omitempty"`
+    Pid       int                `json:"pid,omitempty"`
+    Timestamp time.Time          `json:"timestamp"`
+    StartedAt *time.Time         `json:"started_at,omitempty"`
+    EndedAt   *time.Time         `json:"ended_at,omitempty"`
+}
+
+// TaskStatusCache manages local task status storage for disconnection resilience
+type TaskStatusCache struct {
+    mu        sync.RWMutex
+    snapshots map[primitive.ObjectID]*TaskStatusSnapshot
+    filePath  string
+    dirty     bool // tracks if cache needs to be persisted
+}
+
+// initStatusCache sets up the on-disk cache location and loads any previously persisted snapshots
+func (r *Runner) initStatusCache() error {
+    cacheDir := filepath.Join(utils.GetWorkspace(), ".crawlab", "task_cache")
+    if err := os.MkdirAll(cacheDir, 0755); err != nil {
+        return fmt.Errorf("failed to create cache directory: %w", err)
+    }
+
+    r.statusCache = &TaskStatusCache{
+        snapshots: make(map[primitive.ObjectID]*TaskStatusSnapshot),
+        filePath:  filepath.Join(cacheDir, fmt.Sprintf("task_%s.json", r.tid.Hex())),
+        dirty:     false,
+    }
+
+    if err := r.loadStatusCache(); err != nil {
+        r.Warnf("failed to load existing status cache: %v", err)
+    }
+
+    r.pendingUpdates = make([]TaskStatusSnapshot, 0)
+    return nil
+}
+
+// loadStatusCache reads persisted snapshots from disk into the in-memory cache
+func (r *Runner) loadStatusCache() error {
+    if _, err := os.Stat(r.statusCache.filePath); os.IsNotExist(err) {
+        return nil
+    }
+    data, err := os.ReadFile(r.statusCache.filePath)
+    if err != nil {
+        return fmt.Errorf("failed to read cache file: %w", err)
+    }
+    var snapshots []TaskStatusSnapshot
+    if err := json.Unmarshal(data, &snapshots); err != nil {
+        return fmt.Errorf("failed to unmarshal cache data: %w", err)
+    }
+    r.statusCache.mu.Lock()
+    defer r.statusCache.mu.Unlock()
+    for i := range snapshots {
+        r.statusCache.snapshots[snapshots[i].TaskId] = &snapshots[i]
+    }
+    r.Debugf("loaded %d task status snapshots from cache", len(snapshots))
+    return nil
+}
+
+// persistStatusCache writes the in-memory snapshots to disk when the cache is dirty
+func (r *Runner) persistStatusCache() error {
+    r.statusCache.mu.RLock()
+    if !r.statusCache.dirty {
+        r.statusCache.mu.RUnlock()
+        return nil
+    }
+    snapshots := make([]TaskStatusSnapshot, 0, len(r.statusCache.snapshots))
+    for _, snapshot := range r.statusCache.snapshots {
+        snapshots = append(snapshots, *snapshot)
+    }
+    r.statusCache.mu.RUnlock()
+    data, err := json.Marshal(snapshots)
+    if err != nil {
+        return fmt.Errorf("failed to marshal cache data: %w", err)
+    }
+    if err := os.WriteFile(r.statusCache.filePath, data, 0644); err != nil {
+        return fmt.Errorf("failed to write cache file: %w", err)
+    }
+    r.statusCache.mu.Lock()
+    r.statusCache.dirty = false
+    r.statusCache.mu.Unlock()
+    return nil
+}
+
+// cacheTaskStatus records the latest task status locally and queues it for syncing
+func (r *Runner) cacheTaskStatus(status string, err error) {
+    snapshot := &TaskStatusSnapshot{
+        TaskId:    r.tid,
+        Status:    status,
+        Pid:       r.pid,
+        Timestamp: time.Now(),
+    }
+    if err != nil {
+        snapshot.Error = err.Error()
+    }
+    // Store in cache
+    r.statusCache.mu.Lock()
+    r.statusCache.snapshots[r.tid] = snapshot
+    r.statusCache.dirty = true
+    r.statusCache.mu.Unlock()
+    // Add to pending updates for sync when reconnected
+    r.statusCacheMutex.Lock()
+    r.pendingUpdates = append(r.pendingUpdates, *snapshot)
+    r.statusCacheMutex.Unlock()
+    go func() {
+        if err := r.persistStatusCache(); err != nil {
+            r.Errorf("failed to persist status cache: %v", err)
+        }
+    }()
+    r.Debugf("cached task status: %s (pid: %d)", status, r.pid)
+}
+
+// syncPendingStatusUpdates pushes queued status snapshots to the database after a reconnection
+func (r *Runner) syncPendingStatusUpdates() error {
+    r.statusCacheMutex.Lock()
+    pendingCount := len(r.pendingUpdates)
+    if pendingCount == 0 {
+        r.statusCacheMutex.Unlock()
+        return nil
+    }
+    updates := make([]TaskStatusSnapshot, pendingCount)
+    copy(updates, r.pendingUpdates)
+    r.pendingUpdates = r.pendingUpdates[:0]
+    r.statusCacheMutex.Unlock()
+    r.Infof("syncing %d pending status updates to master node", pendingCount)
+    for _, update := range updates {
+        if err := r.syncStatusUpdate(update); err != nil {
+            r.Errorf("failed to sync status update for task %s: %v", update.TaskId.Hex(), err)
+            r.statusCacheMutex.Lock()
+            r.pendingUpdates = append(r.pendingUpdates, update)
+            r.statusCacheMutex.Unlock()
+            return err
+        }
+    }
+    r.Infof("successfully synced %d status updates", pendingCount)
+    return nil
+}
+
+// syncStatusUpdate writes a single cached snapshot to the task record unless the database copy is newer
+func (r *Runner) syncStatusUpdate(snapshot TaskStatusSnapshot) error {
+    task, err := r.svc.GetTaskById(snapshot.TaskId)
+    if err != nil {
+        return fmt.Errorf("failed to get task %s: %w", snapshot.TaskId.Hex(), err)
+    }
+    if task.UpdatedAt.After(snapshot.Timestamp) {
+        r.Debugf("skipping status sync for task %s - database is newer", snapshot.TaskId.Hex())
+        return nil
+    }
+    task.Status = snapshot.Status
+    task.Error = snapshot.Error
+    task.Pid = snapshot.Pid
+    if utils.IsMaster() {
+        err = service.NewModelService[models.Task]().ReplaceById(task.Id, *task)
+    } else {
+        err = client.NewModelService[models.Task]().ReplaceById(task.Id, *task)
+    }
+    if err != nil {
+        return fmt.Errorf("failed to update task in database: %w", err)
+    }
+    r.Debugf("synced status update for task %s: %s", snapshot.TaskId.Hex(), snapshot.Status)
+    return nil
+}
+
+// getCachedTaskStatus returns the most recent cached snapshot for this runner's task, if any
+func (r *Runner) getCachedTaskStatus() *TaskStatusSnapshot {
+    r.statusCache.mu.RLock()
+    defer r.statusCache.mu.RUnlock()
+    if snapshot, exists := r.statusCache.snapshots[r.tid]; exists {
+        return snapshot
+    }
+    return nil
+}
+
+// cleanupStatusCache removes the on-disk cache file once the task has completed
+func (r *Runner) cleanupStatusCache() {
+    if r.statusCache != nil && r.statusCache.filePath != "" {
+        if err := os.Remove(r.statusCache.filePath); err != nil && !os.IsNotExist(err) {
+            r.Warnf("failed to remove status cache file: %v", err)
+        }
+    }
+}
+
+// GetCachedTaskStatus retrieves the cached status for this task (public method for external access)
+func (r *Runner) GetCachedTaskStatus() *TaskStatusSnapshot {
+    return r.getCachedTaskStatus()
+}
+
+// SyncPendingStatusUpdates syncs all pending status updates to the master node (public method for external access)
+func (r *Runner) SyncPendingStatusUpdates() error {
+    return r.syncPendingStatusUpdates()
+}
diff --git a/core/task/handler/service_operations.go b/core/task/handler/service_operations.go
index f9e44089..1b6c1b2a 100644
--- a/core/task/handler/service_operations.go
+++ b/core/task/handler/service_operations.go
@@ -280,3 +280,12 @@ func (svc *Service) deleteRunner(taskId primitive.ObjectID) {
     svc.Debugf("delete runner: taskId[%v]", taskId)
     svc.runners.Delete(taskId)
 }
+
+// GetTaskRunner returns the task runner for the given task ID (public method for external access)
+func (svc *Service) GetTaskRunner(taskId primitive.ObjectID) interfaces.TaskRunner {
+    r, err := svc.getRunner(taskId)
+    if err != nil {
+        return nil
+    }
+    return r
+}