feat: enhance task reconciliation with worker-side status caching and synchronization

Marvin Zhang
2025-09-17 11:03:35 +08:00
parent 8c2c23d9b6
commit afa5fab4c1
5 changed files with 355 additions and 195 deletions


@@ -79,6 +79,12 @@ func newTaskRunner(id primitive.ObjectID, svc *Service) (r *Runner, err error) {
r.ctx, r.cancel = context.WithCancel(svc.ctx)
r.done = make(chan struct{})
// Initialize status cache for disconnection resilience
if err := r.initStatusCache(); err != nil {
r.Errorf("error initializing status cache: %v", err)
errs.Errors = append(errs.Errors, err)
}
// initialize task runner
if err := r.Init(); err != nil {
r.Errorf("error initializing task runner: %v", err)
@@ -138,16 +144,21 @@ type Runner struct {
// circuit breaker for log connections to prevent cascading failures
logConnHealthy bool // tracks if log connection is healthy
logConnMutex sync.RWMutex // mutex for log connection health state
lastLogSendFailure time.Time // last time log send failed
logCircuitOpenTime time.Time // when circuit breaker was opened
logFailureCount int // consecutive log send failures
logCircuitOpenDuration time.Duration // how long to keep circuit open after failures
// configurable timeouts for robust task execution
ipcTimeout time.Duration // timeout for IPC operations
healthCheckInterval time.Duration // interval for health checks
connHealthInterval time.Duration // interval for connection health checks
// status cache for disconnection resilience
statusCache *TaskStatusCache // local status cache that survives disconnections
pendingUpdates []TaskStatusSnapshot // status updates to sync when reconnected
statusCacheMutex sync.RWMutex // mutex for status cache operations
}
// Init initializes the task runner by updating the task status and establishing gRPC connections
@@ -268,6 +279,9 @@ func (r *Runner) Run() (err error) {
if r.ipcChan != nil {
close(r.ipcChan)
}
// 6. Clean up status cache for completed tasks
r.cleanupStatusCache()
}()
// wait for process to finish
@@ -499,6 +513,9 @@ func (r *Runner) updateTask(status string, e error) (err error) {
}
if r.t != nil && status != "" {
// Cache status locally first (always succeeds)
r.cacheTaskStatus(status, e)
// update task status
r.t.Status = status
if e != nil {
@@ -507,18 +524,22 @@ func (r *Runner) updateTask(status string, e error) (err error) {
if utils.IsMaster() {
err = service.NewModelService[models.Task]().ReplaceById(r.t.Id, *r.t)
if err != nil {
r.Warnf("failed to update task in database, but cached locally: %v", err)
// Don't return error - the status is cached and will be synced later
}
} else {
err = client.NewModelService[models.Task]().ReplaceById(r.t.Id, *r.t)
if err != nil {
r.Warnf("failed to update task in database, but cached locally: %v", err)
// Don't return error - the status is cached and will be synced later
}
}
// update stats (only if database update succeeded)
if err == nil {
r.updateTaskStat(status)
r.updateSpiderStat(status)
}
// send notification
go r.sendNotification()
@@ -703,7 +724,7 @@ func (r *Runner) reconnectWithRetry() error {
r.lastConnCheck = time.Now()
r.connRetryAttempts = 0
r.Infof("successfully reconnected to task service after %d attempts", attempt+1)
// Reset log circuit breaker when connection is restored
r.logConnMutex.Lock()
if !r.logConnHealthy {
@@ -712,7 +733,14 @@ func (r *Runner) reconnectWithRetry() error {
r.Logger.Info("log circuit breaker reset after successful reconnection")
}
r.logConnMutex.Unlock()
// Sync pending status updates after successful reconnection
go func() {
if err := r.syncPendingStatusUpdates(); err != nil {
r.Errorf("failed to sync pending status updates after reconnection: %v", err)
}
}()
return nil
}
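
For context, below is a minimal, self-contained sketch of the "cache locally first, sync after reconnect" pattern the hunks above implement. The statusStore type and the remoteWrite callback are illustrative stand-ins, not Crawlab APIs; the real Runner additionally persists snapshots to disk and replaces the task document through the model services.

package main

import (
	"errors"
	"fmt"
	"sync"
	"time"
)

// statusUpdate is a simplified stand-in for TaskStatusSnapshot.
type statusUpdate struct {
	Status    string
	Timestamp time.Time
}

// statusStore queues updates locally so a dropped master connection never
// loses a status transition.
type statusStore struct {
	mu      sync.Mutex
	pending []statusUpdate
}

// record caches the update locally; it always succeeds.
func (s *statusStore) record(status string) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.pending = append(s.pending, statusUpdate{Status: status, Timestamp: time.Now()})
}

// flush replays queued updates in order; on the first failure the unsent tail
// is put back so the next reconnection attempt can retry it.
func (s *statusStore) flush(remoteWrite func(statusUpdate) error) error {
	s.mu.Lock()
	queued := s.pending
	s.pending = nil
	s.mu.Unlock()

	for i, u := range queued {
		if err := remoteWrite(u); err != nil {
			s.mu.Lock()
			s.pending = append(append([]statusUpdate{}, queued[i:]...), s.pending...)
			s.mu.Unlock()
			return err
		}
	}
	return nil
}

func main() {
	store := &statusStore{}
	store.record("running")
	store.record("finished")

	offline := func(statusUpdate) error { return errors.New("master unreachable") }
	online := func(u statusUpdate) error { fmt.Println("synced:", u.Status); return nil }

	_ = store.flush(offline) // connection down: both updates stay queued
	_ = store.flush(online)  // reconnected: both updates are synced in order
}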


@@ -0,0 +1,203 @@
package handler
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"sync"
"time"
"github.com/crawlab-team/crawlab/core/models/client"
"github.com/crawlab-team/crawlab/core/models/models"
"github.com/crawlab-team/crawlab/core/models/service"
"github.com/crawlab-team/crawlab/core/utils"
"go.mongodb.org/mongo-driver/bson/primitive"
)
// TaskStatusSnapshot represents a point-in-time status of a task for caching
type TaskStatusSnapshot struct {
TaskId primitive.ObjectID `json:"task_id"`
Status string `json:"status"`
Error string `json:"error,omitempty"`
Pid int `json:"pid,omitempty"`
Timestamp time.Time `json:"timestamp"`
StartedAt *time.Time `json:"started_at,omitempty"`
EndedAt *time.Time `json:"ended_at,omitempty"`
}
// TaskStatusCache manages local task status storage for disconnection resilience
type TaskStatusCache struct {
mu sync.RWMutex
snapshots map[primitive.ObjectID]*TaskStatusSnapshot
filePath string
dirty bool // tracks if cache needs to be persisted
}
func (r *Runner) initStatusCache() error {
cacheDir := filepath.Join(utils.GetWorkspace(), ".crawlab", "task_cache")
if err := os.MkdirAll(cacheDir, 0755); err != nil {
return fmt.Errorf("failed to create cache directory: %w", err)
}
r.statusCache = &TaskStatusCache{
snapshots: make(map[primitive.ObjectID]*TaskStatusSnapshot),
filePath: filepath.Join(cacheDir, fmt.Sprintf("task_%s.json", r.tid.Hex())),
dirty: false,
}
if err := r.loadStatusCache(); err != nil {
r.Warnf("failed to load existing status cache: %v", err)
}
r.pendingUpdates = make([]TaskStatusSnapshot, 0)
return nil
}
func (r *Runner) loadStatusCache() error {
if _, err := os.Stat(r.statusCache.filePath); os.IsNotExist(err) {
return nil
}
data, err := os.ReadFile(r.statusCache.filePath)
if err != nil {
return fmt.Errorf("failed to read cache file: %w", err)
}
var snapshots []TaskStatusSnapshot
if err := json.Unmarshal(data, &snapshots); err != nil {
return fmt.Errorf("failed to unmarshal cache data: %w", err)
}
r.statusCache.mu.Lock()
defer r.statusCache.mu.Unlock()
for i := range snapshots {
// copy before taking the address so each map entry points to its own snapshot,
// not to the shared loop variable
snapshot := snapshots[i]
r.statusCache.snapshots[snapshot.TaskId] = &snapshot
}
r.Debugf("loaded %d task status snapshots from cache", len(snapshots))
return nil
}
func (r *Runner) persistStatusCache() error {
r.statusCache.mu.RLock()
if !r.statusCache.dirty {
r.statusCache.mu.RUnlock()
return nil
}
snapshots := make([]TaskStatusSnapshot, 0, len(r.statusCache.snapshots))
for _, snapshot := range r.statusCache.snapshots {
snapshots = append(snapshots, *snapshot)
}
r.statusCache.mu.RUnlock()
data, err := json.Marshal(snapshots)
if err != nil {
return fmt.Errorf("failed to marshal cache data: %w", err)
}
if err := os.WriteFile(r.statusCache.filePath, data, 0644); err != nil {
return fmt.Errorf("failed to write cache file: %w", err)
}
r.statusCache.mu.Lock()
r.statusCache.dirty = false
r.statusCache.mu.Unlock()
return nil
}
func (r *Runner) cacheTaskStatus(status string, err error) {
snapshot := &TaskStatusSnapshot{
TaskId: r.tid,
Status: status,
Pid: r.pid,
Timestamp: time.Now(),
}
if err != nil {
snapshot.Error = err.Error()
}
// Store in cache
r.statusCache.mu.Lock()
r.statusCache.snapshots[r.tid] = snapshot
r.statusCache.dirty = true
r.statusCache.mu.Unlock()
// Add to pending updates for sync when reconnected
r.statusCacheMutex.Lock()
r.pendingUpdates = append(r.pendingUpdates, *snapshot)
r.statusCacheMutex.Unlock()
go func() {
if err := r.persistStatusCache(); err != nil {
r.Errorf("failed to persist status cache: %v", err)
}
}()
r.Debugf("cached task status: %s (pid: %d)", status, r.pid)
}
func (r *Runner) syncPendingStatusUpdates() error {
r.statusCacheMutex.Lock()
pendingCount := len(r.pendingUpdates)
if pendingCount == 0 {
r.statusCacheMutex.Unlock()
return nil
}
updates := make([]TaskStatusSnapshot, pendingCount)
copy(updates, r.pendingUpdates)
r.pendingUpdates = r.pendingUpdates[:0]
r.statusCacheMutex.Unlock()
r.Infof("syncing %d pending status updates to master node", pendingCount)
for _, update := range updates {
if err := r.syncStatusUpdate(update); err != nil {
r.Errorf("failed to sync status update for task %s: %v", update.TaskId.Hex(), err)
r.statusCacheMutex.Lock()
r.pendingUpdates = append(r.pendingUpdates, update)
r.statusCacheMutex.Unlock()
return err
}
}
r.Infof("successfully synced %d status updates", pendingCount)
return nil
}
func (r *Runner) syncStatusUpdate(snapshot TaskStatusSnapshot) error {
task, err := r.svc.GetTaskById(snapshot.TaskId)
if err != nil {
return fmt.Errorf("failed to get task %s: %w", snapshot.TaskId.Hex(), err)
}
if task.UpdatedAt.After(snapshot.Timestamp) {
r.Debugf("skipping status sync for task %s - database is newer", snapshot.TaskId.Hex())
return nil
}
task.Status = snapshot.Status
task.Error = snapshot.Error
task.Pid = snapshot.Pid
if utils.IsMaster() {
err = service.NewModelService[models.Task]().ReplaceById(task.Id, *task)
} else {
err = client.NewModelService[models.Task]().ReplaceById(task.Id, *task)
}
if err != nil {
return fmt.Errorf("failed to update task in database: %w", err)
}
r.Debugf("synced status update for task %s: %s", snapshot.TaskId.Hex(), snapshot.Status)
return nil
}
func (r *Runner) getCachedTaskStatus() *TaskStatusSnapshot {
r.statusCache.mu.RLock()
defer r.statusCache.mu.RUnlock()
if snapshot, exists := r.statusCache.snapshots[r.tid]; exists {
return snapshot
}
return nil
}
func (r *Runner) cleanupStatusCache() {
if r.statusCache != nil && r.statusCache.filePath != "" {
if err := os.Remove(r.statusCache.filePath); err != nil && !os.IsNotExist(err) {
r.Warnf("failed to remove status cache file: %v", err)
}
}
}
// GetCachedTaskStatus retrieves the cached status for this task (public method for external access)
func (r *Runner) GetCachedTaskStatus() *TaskStatusSnapshot {
return r.getCachedTaskStatus()
}
// SyncPendingStatusUpdates syncs all pending status updates to the master node (public method for external access)
func (r *Runner) SyncPendingStatusUpdates() error {
return r.syncPendingStatusUpdates()
}
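
As a rough illustration of what persistStatusCache writes to task_<id>.json, the snippet below marshals one snapshot using the same struct tags. The ObjectID, PID, and timestamp values are made up, and the struct is a trimmed copy (started_at/ended_at omitted for brevity); only the populated fields appear because of the omitempty tags.

package main

import (
	"encoding/json"
	"fmt"
	"time"

	"go.mongodb.org/mongo-driver/bson/primitive"
)

// Trimmed copy of TaskStatusSnapshot showing the persisted JSON shape.
type TaskStatusSnapshot struct {
	TaskId    primitive.ObjectID `json:"task_id"`
	Status    string             `json:"status"`
	Error     string             `json:"error,omitempty"`
	Pid       int                `json:"pid,omitempty"`
	Timestamp time.Time          `json:"timestamp"`
}

func main() {
	id, _ := primitive.ObjectIDFromHex("65f1c0ffee0000000000abcd")
	snapshots := []TaskStatusSnapshot{{
		TaskId:    id,
		Status:    "running",
		Pid:       4242,
		Timestamp: time.Date(2025, 9, 17, 11, 3, 35, 0, time.UTC),
	}}
	data, _ := json.Marshal(snapshots)
	// Prints a JSON array such as:
	// [{"task_id":"65f1c0ffee0000000000abcd","status":"running","pid":4242,"timestamp":"2025-09-17T11:03:35Z"}]
	fmt.Println(string(data))
}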


@@ -280,3 +280,12 @@ func (svc *Service) deleteRunner(taskId primitive.ObjectID) {
svc.Debugf("delete runner: taskId[%v]", taskId)
svc.runners.Delete(taskId)
}
// GetTaskRunner returns the task runner for the given task ID (public method for external access)
func (svc *Service) GetTaskRunner(taskId primitive.ObjectID) interfaces.TaskRunner {
r, err := svc.getRunner(taskId)
if err != nil {
return nil
}
return r
}
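
Finally, a hedged sketch of how a reconciliation or reconnection hook might combine GetTaskRunner with the runner's new sync method. The pendingStatusSyncer interface and the onWorkerReconnected helper are assumptions for illustration; this commit does not show whether interfaces.TaskRunner itself exposes SyncPendingStatusUpdates, hence the type assertion.

// pendingStatusSyncer is an assumed narrow interface for runners that cache
// status updates; *Runner would satisfy it via SyncPendingStatusUpdates above.
type pendingStatusSyncer interface {
	SyncPendingStatusUpdates() error
}

// onWorkerReconnected is a hypothetical hook: once the worker re-establishes
// its connection, flush any status updates the runner cached while offline.
func (svc *Service) onWorkerReconnected(taskId primitive.ObjectID) {
	runner := svc.GetTaskRunner(taskId)
	if runner == nil {
		return // task no longer tracked on this node
	}
	if s, ok := runner.(pendingStatusSyncer); ok {
		if err := s.SyncPendingStatusUpdates(); err != nil {
			svc.Debugf("failed to sync cached status updates for task %s: %v", taskId.Hex(), err)
		}
	}
}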