diff --git a/core/notification/service.go b/core/notification/service.go index a49b68bc..ff472bde 100644 --- a/core/notification/service.go +++ b/core/notification/service.go @@ -26,11 +26,26 @@ type Service struct { func (svc *Service) Send(s *models.NotificationSetting, args ...any) { title := s.Title + // Use bounded goroutine pool to prevent unlimited goroutine creation + const maxWorkers = 5 wg := sync.WaitGroup{} - wg.Add(len(s.ChannelIds)) + semaphore := make(chan struct{}, maxWorkers) + for _, chId := range s.ChannelIds { + wg.Add(1) + + // Acquire semaphore + semaphore <- struct{}{} + go func(chId primitive.ObjectID) { - defer wg.Done() + defer func() { + <-semaphore // Release semaphore + wg.Done() + if r := recover(); r != nil { + svc.Errorf("[NotificationService] channel handler panic for %s: %v", chId.Hex(), r) + } + }() + ch, err := service.NewModelService[models.NotificationChannel]().GetById(chId) if err != nil { svc.Errorf("[NotificationService] get channel error: %v", err) @@ -62,8 +77,8 @@ func (svc *Service) SendMail(s *models.NotificationSetting, ch *models.Notificat svc.Errorf("[NotificationService] send mail error: %v", err) } - // save request - go svc.saveRequest(r, err) + // save request synchronously to avoid unbounded goroutines + svc.saveRequest(r, err) } func (svc *Service) SendIM(ch *models.NotificationChannel, title, content string) { @@ -76,8 +91,8 @@ func (svc *Service) SendIM(ch *models.NotificationChannel, title, content string svc.Errorf("[NotificationService] send mobile notification error: %v", err) } - // save request - go svc.saveRequest(r, err) + // save request synchronously to avoid unbounded goroutines + svc.saveRequest(r, err) } func (svc *Service) SendTestMessage(locale string, ch *models.NotificationChannel, toMail []string) (err error) { @@ -126,8 +141,8 @@ func (svc *Service) SendTestMessage(locale string, ch *models.NotificationChanne return fmt.Errorf("unsupported notification channel type: %s", ch.Type) } - // Save request - go svc.saveRequest(r, err) + // Save request synchronously to avoid unbounded goroutines + svc.saveRequest(r, err) return err } @@ -361,21 +376,21 @@ func (svc *Service) geContentWithVariables(template string, variables []entity.N func (svc *Service) getVariableData(args ...any) (vd VariableData) { for _, arg := range args { - switch arg.(type) { + switch arg := arg.(type) { case *models.Task: - vd.Task = arg.(*models.Task) + vd.Task = arg case *models.TaskStat: - vd.TaskStat = arg.(*models.TaskStat) + vd.TaskStat = arg case *models.Spider: - vd.Spider = arg.(*models.Spider) + vd.Spider = arg case *models.Node: - vd.Node = arg.(*models.Node) + vd.Node = arg case *models.Schedule: - vd.Schedule = arg.(*models.Schedule) + vd.Schedule = arg case *models.NotificationAlert: - vd.Alert = arg.(*models.NotificationAlert) + vd.Alert = arg case *models.Metric: - vd.Metric = arg.(*models.Metric) + vd.Metric = arg } } return vd @@ -383,7 +398,7 @@ func (svc *Service) getVariableData(args ...any) (vd VariableData) { func (svc *Service) parseTemplateVariables(template string) (variables []entity.NotificationVariable) { // regex pattern - regex := regexp.MustCompile("\\$\\{(\\w+):(\\w+)}") + regex := regexp.MustCompile(`\$\{(\w+):(\w+)\}`) // find all matches matches := regex.FindAllStringSubmatch(template, -1) @@ -500,21 +515,42 @@ func (svc *Service) SendNodeNotification(node *models.Node) { return } + // Use bounded goroutine pool for node notifications + const maxWorkers = 3 + var wg sync.WaitGroup + semaphore := make(chan 
struct{}, maxWorkers) + for _, s := range settings { - // send notification - switch s.Trigger { - case constants.NotificationTriggerNodeStatusChange: - go svc.Send(&s, args...) - case constants.NotificationTriggerNodeOnline: - if node.Status == constants.NodeStatusOnline { - go svc.Send(&s, args...) + wg.Add(1) + + // Acquire semaphore + semaphore <- struct{}{} + + go func(setting models.NotificationSetting) { + defer func() { + <-semaphore // Release semaphore + wg.Done() + if r := recover(); r != nil { + svc.Errorf("[NotificationService] node notification panic for setting %s: %v", setting.Id.Hex(), r) + } + }() + + // send notification + switch setting.Trigger { + case constants.NotificationTriggerNodeStatusChange: + svc.Send(&setting, args...) + case constants.NotificationTriggerNodeOnline: + if node.Status == constants.NodeStatusOnline { + svc.Send(&setting, args...) + } + case constants.NotificationTriggerNodeOffline: + if node.Status == constants.NodeStatusOffline { + svc.Send(&setting, args...) + } } - case constants.NotificationTriggerNodeOffline: - if node.Status == constants.NodeStatusOffline { - go svc.Send(&s, args...) - } - } + }(s) } + wg.Wait() } func (svc *Service) createRequestMail(s *models.NotificationSetting, ch *models.NotificationChannel, title, content string) (res *models.NotificationRequest, err error) { diff --git a/core/schedule/logger.go b/core/schedule/logger.go index 0eabb42d..d09538e5 100644 --- a/core/schedule/logger.go +++ b/core/schedule/logger.go @@ -4,7 +4,6 @@ import ( "github.com/crawlab-team/crawlab/core/interfaces" "github.com/crawlab-team/crawlab/core/utils" "github.com/robfig/cron/v3" - "strings" ) type CronLogger struct { @@ -12,21 +11,11 @@ type CronLogger struct { } func (l *CronLogger) Info(msg string, keysAndValues ...interface{}) { - p := l.getPlaceholder(len(keysAndValues)) - l.Infof("cron: %s %s", msg, p) + l.Infof("%s %v", msg, keysAndValues) } func (l *CronLogger) Error(err error, msg string, keysAndValues ...interface{}) { - p := l.getPlaceholder(len(keysAndValues)) - l.Errorf("cron: %s %v %s", msg, err, p) -} - -func (l *CronLogger) getPlaceholder(n int) (s string) { - var arr []string - for i := 0; i < n; i++ { - arr = append(arr, "%v") - } - return strings.Join(arr, " ") + l.Errorf("%s %v %v", msg, err, keysAndValues) } func NewCronLogger() cron.Logger { diff --git a/core/task/handler/runner.go b/core/task/handler/runner.go index 7832cc93..2c1c16fc 100644 --- a/core/task/handler/runner.go +++ b/core/task/handler/runner.go @@ -1752,9 +1752,7 @@ func (r *Runner) checkForZombieProcesses() { // Process exists, check if it's a zombie if proc, err := process.NewProcess(int32(r.pid)); err == nil { if status, err := proc.Status(); err == nil { - // Status returns a string, check if it indicates zombie - statusStr := string(status) - if statusStr == "Z" || statusStr == "zombie" { + if status == "Z" || status == "zombie" { r.Warnf("detected zombie process %d for task %s", r.pid, r.tid.Hex()) go r.cleanupOrphanedProcesses() } diff --git a/core/task/handler/service.go b/core/task/handler/service.go index 26b8aefa..89abc230 100644 --- a/core/task/handler/service.go +++ b/core/task/handler/service.go @@ -45,9 +45,147 @@ type Service struct { fetchTicker *time.Ticker reportTicker *time.Ticker + // worker pool for bounded task execution + workerPool *TaskWorkerPool + maxWorkers int + + // stream manager for leak-free stream handling + streamManager *StreamManager + interfaces.Logger } +// StreamManager manages task streams without goroutine leaks 
+type StreamManager struct { + streams sync.Map // map[primitive.ObjectID]*TaskStream + ctx context.Context + cancel context.CancelFunc + wg sync.WaitGroup + service *Service + messageQueue chan *StreamMessage + maxStreams int +} + +// TaskStream represents a single task's stream +type TaskStream struct { + taskId primitive.ObjectID + stream grpc.TaskService_SubscribeClient + ctx context.Context + cancel context.CancelFunc + lastActive time.Time + mu sync.RWMutex +} + +// StreamMessage represents a message from a stream +type StreamMessage struct { + taskId primitive.ObjectID + msg *grpc.TaskServiceSubscribeResponse + err error +} + +// taskRequest represents a task execution request +type taskRequest struct { + taskId primitive.ObjectID +} + +// TaskWorkerPool manages a bounded pool of workers for task execution +type TaskWorkerPool struct { + workers int + ctx context.Context + cancel context.CancelFunc + wg sync.WaitGroup + taskQueue chan taskRequest + service *Service +} + +func NewTaskWorkerPool(workers int, service *Service) *TaskWorkerPool { + ctx, cancel := context.WithCancel(context.Background()) + // Use a more generous queue size to handle task bursts + // Queue size is workers * 5 to allow for better buffering + queueSize := workers * 5 + if queueSize < 50 { + queueSize = 50 // Minimum queue size + } + + return &TaskWorkerPool{ + workers: workers, + ctx: ctx, + cancel: cancel, + taskQueue: make(chan taskRequest, queueSize), + service: service, + } +} + +func (pool *TaskWorkerPool) Start() { + for i := 0; i < pool.workers; i++ { + pool.wg.Add(1) + go pool.worker(i) + } +} + +func (pool *TaskWorkerPool) Stop() { + pool.cancel() + close(pool.taskQueue) + pool.wg.Wait() +} + +func (pool *TaskWorkerPool) SubmitTask(taskId primitive.ObjectID) error { + req := taskRequest{ + taskId: taskId, + } + + select { + case pool.taskQueue <- req: + pool.service.Debugf("task[%s] queued for parallel execution, queue usage: %d/%d", + taskId.Hex(), len(pool.taskQueue), cap(pool.taskQueue)) + return nil // Return immediately - task will execute in parallel + case <-pool.ctx.Done(): + return fmt.Errorf("worker pool is shutting down") + default: + queueLen := len(pool.taskQueue) + queueCap := cap(pool.taskQueue) + pool.service.Warnf("task queue is full (%d/%d), consider increasing task.workers configuration", + queueLen, queueCap) + return fmt.Errorf("task queue is full (%d/%d), consider increasing task.workers configuration", + queueLen, queueCap) + } +} + +func (pool *TaskWorkerPool) worker(workerID int) { + defer pool.wg.Done() + defer func() { + if r := recover(); r != nil { + pool.service.Errorf("worker %d panic recovered: %v", workerID, r) + } + }() + + pool.service.Debugf("worker %d started", workerID) + + for { + select { + case <-pool.ctx.Done(): + pool.service.Debugf("worker %d shutting down", workerID) + return + case req, ok := <-pool.taskQueue: + if !ok { + pool.service.Debugf("worker %d: task queue closed", workerID) + return + } + + // Execute task asynchronously - each worker handles one task at a time + // but multiple workers can process different tasks simultaneously + pool.service.Debugf("worker %d processing task[%s]", workerID, req.taskId.Hex()) + err := pool.service.executeTask(req.taskId) + if err != nil { + pool.service.Errorf("worker %d failed to execute task[%s]: %v", + workerID, req.taskId.Hex(), err) + } else { + pool.service.Debugf("worker %d completed task[%s]", workerID, req.taskId.Hex()) + } + } + } +} + func (svc *Service) Start() { // Initialize context for graceful 
shutdown svc.ctx, svc.cancel = context.WithCancel(context.Background()) @@ -59,7 +197,14 @@ func (svc *Service) Start() { svc.fetchTicker = time.NewTicker(svc.fetchInterval) svc.reportTicker = time.NewTicker(svc.reportInterval) - // Start goroutine monitoring + // Initialize and start worker pool + svc.workerPool = NewTaskWorkerPool(svc.maxWorkers, svc) + svc.workerPool.Start() + + // Initialize and start stream manager + svc.streamManager.Start() + + // Start goroutine monitoring (adds to WaitGroup internally) svc.startGoroutineMonitoring() // Start background goroutines with WaitGroup tracking @@ -67,9 +212,10 @@ func (svc *Service) Start() { go svc.reportStatus() go svc.fetchAndRunTasks() - svc.Infof("Task handler service started") + queueSize := cap(svc.workerPool.taskQueue) + svc.Infof("Task handler service started with %d workers and queue size %d", svc.maxWorkers, queueSize) - // Start the stuck task cleanup routine + // Start the stuck task cleanup routine (adds to WaitGroup internally) svc.startStuckTaskCleanup() } @@ -89,6 +235,16 @@ func (svc *Service) Stop() { svc.cancel() } + // Stop worker pool first + if svc.workerPool != nil { + svc.workerPool.Stop() + } + + // Stop stream manager + if svc.streamManager != nil { + svc.streamManager.Stop() + } + // Stop tickers to prevent new tasks if svc.fetchTicker != nil { svc.fetchTicker.Stop() @@ -119,7 +275,9 @@ func (svc *Service) Stop() { } func (svc *Service) startGoroutineMonitoring() { + svc.wg.Add(1) // Track goroutine monitoring in WaitGroup go func() { + defer svc.wg.Done() defer func() { if r := recover(); r != nil { svc.Errorf("[TaskHandler] goroutine monitoring panic: %v", r) @@ -217,7 +375,7 @@ func (svc *Service) processFetchCycle() error { return fmt.Errorf("no task available") } - // run task + // run task - now using worker pool instead of unlimited goroutines if err := svc.runTask(tid); err != nil { // Handle task error t, getErr := svc.GetTaskById(tid) @@ -432,65 +590,64 @@ func (svc *Service) runTask(taskId primitive.ObjectID) (err error) { return err } + // Use worker pool for bounded task execution + return svc.workerPool.SubmitTask(taskId) +} + +// executeTask is the actual task execution logic called by worker pool +func (svc *Service) executeTask(taskId primitive.ObjectID) (err error) { + // attempt to get runner from pool + _, ok := svc.runners.Load(taskId) + if ok { + err = fmt.Errorf("task[%s] already exists", taskId.Hex()) + svc.Errorf("execute task error: %v", err) + return err + } + // create a new task runner r, err := newTaskRunner(taskId, svc) if err != nil { err = fmt.Errorf("failed to create task runner: %v", err) - svc.Errorf("run task error: %v", err) + svc.Errorf("execute task error: %v", err) return err } // add runner to pool svc.addRunner(taskId, r) - // create a goroutine to run task with proper cleanup - go func() { - defer func() { - if rec := recover(); rec != nil { - svc.Errorf("task[%s] panic recovered: %v", taskId.Hex(), rec) - } - // Always cleanup runner from pool - svc.deleteRunner(taskId) - }() - - // Create task-specific context for better cancellation control - taskCtx, taskCancel := context.WithCancel(svc.ctx) - defer taskCancel() - - // get subscription stream with retry logic - stopCh := make(chan struct{}, 1) - stream, err := svc.subscribeTaskWithRetry(taskCtx, r.GetTaskId(), 3) - if err == nil { - // create a goroutine to handle stream messages - go svc.handleStreamMessages(r.GetTaskId(), stream, stopCh) - } else { - svc.Errorf("failed to subscribe task[%s] after retries: %v", 
r.GetTaskId().Hex(), err) - svc.Warnf("task[%s] will not be able to receive stream messages", r.GetTaskId().Hex()) - } - - // run task process (blocking) error or finish after task runner ends - if err := r.Run(); err != nil { - switch { - case errors.Is(err, constants.ErrTaskError): - svc.Errorf("task[%s] finished with error: %v", r.GetTaskId().Hex(), err) - case errors.Is(err, constants.ErrTaskCancelled): - svc.Infof("task[%s] cancelled", r.GetTaskId().Hex()) - default: - svc.Errorf("task[%s] finished with unknown error: %v", r.GetTaskId().Hex(), err) - } - } else { - svc.Infof("task[%s] finished successfully", r.GetTaskId().Hex()) - } - - // send stopCh signal to stream message handler - select { - case stopCh <- struct{}{}: - default: - // Channel already closed or full + // Ensure cleanup always happens + defer func() { + if rec := recover(); rec != nil { + svc.Errorf("task[%s] panic recovered: %v", taskId.Hex(), rec) } + // Always cleanup runner from pool and stream + svc.deleteRunner(taskId) + svc.streamManager.RemoveTaskStream(taskId) }() - return nil + // Add task to stream manager for cancellation support + if err := svc.streamManager.AddTaskStream(r.GetTaskId()); err != nil { + svc.Warnf("failed to add task[%s] to stream manager: %v", r.GetTaskId().Hex(), err) + svc.Warnf("task[%s] will not be able to receive cancellation messages", r.GetTaskId().Hex()) + } else { + svc.Debugf("task[%s] added to stream manager for cancellation support", r.GetTaskId().Hex()) + } + + // run task process (blocking) error or finish after task runner ends + if err := r.Run(); err != nil { + switch { + case errors.Is(err, constants.ErrTaskError): + svc.Errorf("task[%s] finished with error: %v", r.GetTaskId().Hex(), err) + case errors.Is(err, constants.ErrTaskCancelled): + svc.Infof("task[%s] cancelled", r.GetTaskId().Hex()) + default: + svc.Errorf("task[%s] finished with unknown error: %v", r.GetTaskId().Hex(), err) + } + } else { + svc.Infof("task[%s] finished successfully", r.GetTaskId().Hex()) + } + + return err } // subscribeTask attempts to subscribe to task stream @@ -542,6 +699,77 @@ func (svc *Service) subscribeTaskWithRetry(ctx context.Context, taskId primitive return nil, fmt.Errorf("failed to subscribe after %d retries: %w", maxRetries, err) } +func (svc *Service) handleStreamMessagesSync(ctx context.Context, taskId primitive.ObjectID, stream grpc.TaskService_SubscribeClient) { + defer func() { + if r := recover(); r != nil { + svc.Errorf("handleStreamMessagesSync[%s] panic recovered: %v", taskId.Hex(), r) + } + // Ensure stream is properly closed + if stream != nil { + if err := stream.CloseSend(); err != nil { + svc.Debugf("task[%s] failed to close stream: %v", taskId.Hex(), err) + } + } + }() + + svc.Debugf("task[%s] starting synchronous stream message handling", taskId.Hex()) + + for { + select { + case <-ctx.Done(): + svc.Debugf("task[%s] stream handler stopped by context", taskId.Hex()) + return + case <-svc.ctx.Done(): + svc.Debugf("task[%s] stream handler stopped by service context", taskId.Hex()) + return + default: + // Set a reasonable timeout for stream receive + recvCtx, cancel := context.WithTimeout(ctx, 5*time.Second) + + // Create a done channel to handle the recv operation + done := make(chan struct { + msg *grpc.TaskServiceSubscribeResponse + err error + }, 1) + + // Use a goroutine only for the blocking recv call, but ensure it's cleaned up + go func() { + defer cancel() + msg, err := stream.Recv() + select { + case done <- struct { + msg 
*grpc.TaskServiceSubscribeResponse + err error + }{msg, err}: + case <-recvCtx.Done(): + // Timeout occurred, abandon this receive + } + }() + + select { + case result := <-done: + cancel() // Clean up the context + if result.err != nil { + if errors.Is(result.err, io.EOF) { + svc.Infof("task[%s] received EOF, stream closed", taskId.Hex()) + return + } + svc.Errorf("task[%s] stream error: %v", taskId.Hex(), result.err) + return + } + svc.processStreamMessage(taskId, result.msg) + case <-recvCtx.Done(): + cancel() + // Timeout on receive - continue to next iteration + svc.Debugf("task[%s] stream receive timeout", taskId.Hex()) + case <-ctx.Done(): + cancel() + return + } + } + } +} + func (svc *Service) handleStreamMessages(taskId primitive.ObjectID, stream grpc.TaskService_SubscribeClient, stopCh chan struct{}) { defer func() { if r := recover(); r != nil { @@ -714,12 +942,26 @@ func (svc *Service) stopAllRunners() { return true }) - // Cancel all runners with timeout + // Cancel all runners with bounded concurrency to prevent goroutine explosion + const maxConcurrentCancellations = 10 var wg sync.WaitGroup + semaphore := make(chan struct{}, maxConcurrentCancellations) + for _, taskId := range runnerIds { wg.Add(1) + + // Acquire semaphore to limit concurrent cancellations + semaphore <- struct{}{} + go func(tid primitive.ObjectID) { - defer wg.Done() + defer func() { + <-semaphore // Release semaphore + wg.Done() + if r := recover(); r != nil { + svc.Errorf("stopAllRunners panic for task[%s]: %v", tid.Hex(), r) + } + }() + if err := svc.cancelTask(tid, false); err != nil { svc.Errorf("failed to cancel task[%s]: %v", tid.Hex(), err) // Force cancel after timeout @@ -745,7 +987,15 @@ func (svc *Service) stopAllRunners() { } func (svc *Service) startStuckTaskCleanup() { + svc.wg.Add(1) // Track this goroutine in the WaitGroup go func() { + defer svc.wg.Done() // Ensure WaitGroup is decremented + defer func() { + if r := recover(); r != nil { + svc.Errorf("startStuckTaskCleanup panic recovered: %v", r) + } + }() + ticker := time.NewTicker(5 * time.Minute) // Check every 5 minutes defer ticker.Stop() @@ -817,6 +1067,222 @@ func (svc *Service) checkAndCleanupStuckTasks() { } } +func NewStreamManager(service *Service) *StreamManager { + ctx, cancel := context.WithCancel(context.Background()) + return &StreamManager{ + ctx: ctx, + cancel: cancel, + service: service, + messageQueue: make(chan *StreamMessage, 100), // Buffered channel for messages + maxStreams: 50, // Limit concurrent streams + } +} + +func (sm *StreamManager) Start() { + sm.wg.Add(2) + go sm.messageProcessor() + go sm.streamCleaner() +} + +func (sm *StreamManager) Stop() { + sm.cancel() + close(sm.messageQueue) + + // Close all active streams + sm.streams.Range(func(key, value interface{}) bool { + if ts, ok := value.(*TaskStream); ok { + ts.Close() + } + return true + }) + + sm.wg.Wait() +} + +func (sm *StreamManager) AddTaskStream(taskId primitive.ObjectID) error { + // Check if we're at the stream limit + streamCount := 0 + sm.streams.Range(func(key, value interface{}) bool { + streamCount++ + return true + }) + + if streamCount >= sm.maxStreams { + return fmt.Errorf("stream limit reached (%d)", sm.maxStreams) + } + + // Create new stream + stream, err := sm.service.subscribeTask(taskId) + if err != nil { + return fmt.Errorf("failed to subscribe to task stream: %v", err) + } + + ctx, cancel := context.WithCancel(sm.ctx) + taskStream := &TaskStream{ + taskId: taskId, + stream: stream, + ctx: ctx, + cancel: cancel, + lastActive: 
time.Now(), + } + + sm.streams.Store(taskId, taskStream) + + // Start listening for messages in a single goroutine per stream + sm.wg.Add(1) + go sm.streamListener(taskStream) + + return nil +} + +func (sm *StreamManager) RemoveTaskStream(taskId primitive.ObjectID) { + if value, ok := sm.streams.LoadAndDelete(taskId); ok { + if ts, ok := value.(*TaskStream); ok { + ts.Close() + } + } +} + +func (sm *StreamManager) streamListener(ts *TaskStream) { + defer sm.wg.Done() + defer func() { + if r := recover(); r != nil { + sm.service.Errorf("stream listener panic for task[%s]: %v", ts.taskId.Hex(), r) + } + ts.Close() + sm.streams.Delete(ts.taskId) + }() + + sm.service.Debugf("stream listener started for task[%s]", ts.taskId.Hex()) + + for { + select { + case <-ts.ctx.Done(): + sm.service.Debugf("stream listener stopped for task[%s]", ts.taskId.Hex()) + return + case <-sm.ctx.Done(): + return + default: + msg, err := ts.stream.Recv() + + if err != nil { + if errors.Is(err, io.EOF) { + sm.service.Debugf("stream EOF for task[%s]", ts.taskId.Hex()) + return + } + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return + } + sm.service.Errorf("stream error for task[%s]: %v", ts.taskId.Hex(), err) + return + } + + // Update last active time + ts.mu.Lock() + ts.lastActive = time.Now() + ts.mu.Unlock() + + // Send message to processor + select { + case sm.messageQueue <- &StreamMessage{ + taskId: ts.taskId, + msg: msg, + err: nil, + }: + case <-ts.ctx.Done(): + return + case <-sm.ctx.Done(): + return + default: + sm.service.Warnf("message queue full, dropping message for task[%s]", ts.taskId.Hex()) + } + } + } +} + +func (sm *StreamManager) messageProcessor() { + defer sm.wg.Done() + defer func() { + if r := recover(); r != nil { + sm.service.Errorf("message processor panic: %v", r) + } + }() + + sm.service.Debugf("stream message processor started") + + for { + select { + case <-sm.ctx.Done(): + sm.service.Debugf("stream message processor shutting down") + return + case msg, ok := <-sm.messageQueue: + if !ok { + return + } + sm.processMessage(msg) + } + } +} + +func (sm *StreamManager) processMessage(streamMsg *StreamMessage) { + if streamMsg.err != nil { + sm.service.Errorf("stream message error for task[%s]: %v", streamMsg.taskId.Hex(), streamMsg.err) + return + } + + // Process the actual message + sm.service.processStreamMessage(streamMsg.taskId, streamMsg.msg) +} + +func (sm *StreamManager) streamCleaner() { + defer sm.wg.Done() + defer func() { + if r := recover(); r != nil { + sm.service.Errorf("stream cleaner panic: %v", r) + } + }() + + ticker := time.NewTicker(2 * time.Minute) + defer ticker.Stop() + + for { + select { + case <-sm.ctx.Done(): + return + case <-ticker.C: + sm.cleanupInactiveStreams() + } + } +} + +func (sm *StreamManager) cleanupInactiveStreams() { + now := time.Now() + inactiveThreshold := 10 * time.Minute + + sm.streams.Range(func(key, value interface{}) bool { + taskId := key.(primitive.ObjectID) + ts := value.(*TaskStream) + + ts.mu.RLock() + lastActive := ts.lastActive + ts.mu.RUnlock() + + if now.Sub(lastActive) > inactiveThreshold { + sm.service.Debugf("cleaning up inactive stream for task[%s]", taskId.Hex()) + sm.RemoveTaskStream(taskId) + } + + return true + }) +} + +func (ts *TaskStream) Close() { + ts.cancel() + if ts.stream != nil { + _ = ts.stream.CloseSend() + } +} + func newTaskHandlerService() *Service { // service svc := &Service{ @@ -824,6 +1290,7 @@ func newTaskHandlerService() *Service { fetchTimeout: 15 * time.Second, 
reportInterval: 5 * time.Second, cancelTimeout: 60 * time.Second, + maxWorkers: utils.GetTaskWorkers(), // Use configurable worker count mu: sync.RWMutex{}, runners: sync.Map{}, Logger: utils.NewLogger("TaskHandlerService"), @@ -835,6 +1302,9 @@ func newTaskHandlerService() *Service { // grpc client svc.c = grpcclient.GetGrpcClient() + // initialize stream manager + svc.streamManager = NewStreamManager(svc) + return svc } diff --git a/core/task/scheduler/service.go b/core/task/scheduler/service.go index 8f75337d..6b42da37 100644 --- a/core/task/scheduler/service.go +++ b/core/task/scheduler/service.go @@ -28,16 +28,63 @@ type Service struct { // settings interval time.Duration - // internals + // internals for lifecycle management + ctx context.Context + cancel context.CancelFunc + stopped bool + mu sync.RWMutex + wg sync.WaitGroup + interfaces.Logger } func (svc *Service) Start() { + // Initialize context for graceful shutdown + svc.ctx, svc.cancel = context.WithCancel(context.Background()) + + // Start background goroutines with proper tracking + svc.wg.Add(2) go svc.initTaskStatus() go svc.cleanupTasks() + + svc.Infof("Task scheduler service started") utils.DefaultWait() } +func (svc *Service) Stop() { + svc.mu.Lock() + if svc.stopped { + svc.mu.Unlock() + return + } + svc.stopped = true + svc.mu.Unlock() + + svc.Infof("Stopping task scheduler service...") + + // Cancel context to signal all goroutines to stop + if svc.cancel != nil { + svc.cancel() + } + + // Wait for all background goroutines to finish + done := make(chan struct{}) + go func() { + svc.wg.Wait() + close(done) + }() + + // Give goroutines time to finish gracefully, then force stop + select { + case <-done: + svc.Infof("All goroutines stopped gracefully") + case <-time.After(30 * time.Second): + svc.Warnf("Some goroutines did not stop gracefully within timeout") + } + + svc.Infof("Task scheduler service stopped") +} + func (svc *Service) Enqueue(t *models.Task, by primitive.ObjectID) (t2 *models.Task, err error) { // set task status t.Status = constants.TaskStatusPending @@ -192,6 +239,13 @@ func (svc *Service) SaveTask(t *models.Task, by primitive.ObjectID) (err error) // initTaskStatus initialize task status of existing tasks func (svc *Service) initTaskStatus() { + defer svc.wg.Done() + defer func() { + if r := recover(); r != nil { + svc.Errorf("initTaskStatus panic recovered: %v", r) + } + }() + // set status of running tasks as TaskStatusAbnormal runningTasks, err := service.NewModelService[models.Task]().GetMany(bson.M{ "status": bson.M{ @@ -209,15 +263,37 @@ func (svc *Service) initTaskStatus() { svc.Errorf("failed to get running tasks: %v", err) return } + + // Use bounded worker pool for task updates + var wg sync.WaitGroup + semaphore := make(chan struct{}, 10) // Limit concurrent updates + for _, t := range runningTasks { - go func(t *models.Task) { - t.Status = constants.TaskStatusAbnormal - if err := svc.SaveTask(t, primitive.NilObjectID); err != nil { - svc.Errorf("failed to set task status as TaskStatusAbnormal: %s", t.Id.Hex()) - return - } - }(&t) + select { + case <-svc.ctx.Done(): + svc.Infof("initTaskStatus stopped by context") + return + case semaphore <- struct{}{}: + wg.Add(1) + go func(task models.Task) { + defer func() { + <-semaphore + wg.Done() + if r := recover(); r != nil { + svc.Errorf("task status update panic for task %s: %v", task.Id.Hex(), r) + } + }() + + task.Status = constants.TaskStatusAbnormal + if err := svc.SaveTask(&task, primitive.NilObjectID); err != nil { + svc.Errorf("failed 
to set task status as TaskStatusAbnormal: %s", task.Id.Hex()) + } + }(t) + } } + + wg.Wait() + svc.Infof("initTaskStatus completed") } func (svc *Service) isMasterNode(t *models.Task) (ok bool, err error) { @@ -239,41 +315,67 @@ func (svc *Service) isMasterNode(t *models.Task) (ok bool, err error) { } func (svc *Service) cleanupTasks() { + defer svc.wg.Done() + defer func() { + if r := recover(); r != nil { + svc.Errorf("cleanupTasks panic recovered: %v", r) + } + }() + + ticker := time.NewTicker(30 * time.Minute) + defer ticker.Stop() + for { - // task stats over 30 days ago - taskStats, err := service.NewModelService[models.TaskStat]().GetMany(bson.M{ - "created_at": bson.M{ - "$lt": time.Now().Add(-30 * 24 * time.Hour), - }, - }, nil) - if err != nil { - time.Sleep(30 * time.Minute) - continue + select { + case <-svc.ctx.Done(): + svc.Infof("cleanupTasks stopped by context") + return + case <-ticker.C: + svc.performCleanup() + } + } +} + +func (svc *Service) performCleanup() { + defer func() { + if r := recover(); r != nil { + svc.Errorf("performCleanup panic recovered: %v", r) + } + }() + + // task stats over 30 days ago + taskStats, err := service.NewModelService[models.TaskStat]().GetMany(bson.M{ + "created_at": bson.M{ + "$lt": time.Now().Add(-30 * 24 * time.Hour), + }, + }, nil) + if err != nil { + svc.Errorf("failed to get old task stats: %v", err) + return + } + + // task ids + var ids []primitive.ObjectID + for _, ts := range taskStats { + ids = append(ids, ts.Id) + } + + if len(ids) > 0 { + // remove tasks + if err := service.NewModelService[models.Task]().DeleteMany(bson.M{ + "_id": bson.M{"$in": ids}, + }); err != nil { + svc.Warnf("failed to remove tasks: %v", err) } - // task ids - var ids []primitive.ObjectID - for _, ts := range taskStats { - ids = append(ids, ts.Id) + // remove task stats + if err := service.NewModelService[models.TaskStat]().DeleteMany(bson.M{ + "_id": bson.M{"$in": ids}, + }); err != nil { + svc.Warnf("failed to remove task stats: %v", err) } - if len(ids) > 0 { - // remove tasks - if err := service.NewModelService[models.Task]().DeleteMany(bson.M{ - "_id": bson.M{"$in": ids}, - }); err != nil { - svc.Warnf("failed to remove tasks: %v", err) - } - - // remove task stats - if err := service.NewModelService[models.TaskStat]().DeleteMany(bson.M{ - "_id": bson.M{"$in": ids}, - }); err != nil { - svc.Warnf("failed to remove task stats: %v", err) - } - } - - time.Sleep(30 * time.Minute) + svc.Infof("cleaned up %d old tasks", len(ids)) } } diff --git a/core/utils/config.go b/core/utils/config.go index e1c57911..806af447 100644 --- a/core/utils/config.go +++ b/core/utils/config.go @@ -28,7 +28,8 @@ const ( DefaultApiAllowHeaders = "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With" DefaultApiPort = 8080 DefaultApiPath = "/api" - DefaultNodeMaxRunners = 0 // 0 means no limit + DefaultNodeMaxRunners = 0 // 0 means no limit + DefaultTaskWorkers = 30 // Default number of task workers DefaultInstallRoot = "/app/install" DefaultInstallEnvs = "" MetadataConfigDirName = ".crawlab" @@ -241,6 +242,13 @@ func GetNodeMaxRunners() int { return DefaultNodeMaxRunners } +func GetTaskWorkers() int { + if res := viper.GetInt("task.workers"); res != 0 { + return res + } + return DefaultTaskWorkers +} + func GetMetadataConfigPath() string { var homeDirPath, err = homedir.Dir() if err != nil { diff --git a/frontend/crawlab-ui/src/components/core/metric/MetricMonitoringDetail.vue 
b/frontend/crawlab-ui/src/components/core/metric/MetricMonitoringDetail.vue index 940646d2..c47341ee 100644 --- a/frontend/crawlab-ui/src/components/core/metric/MetricMonitoringDetail.vue +++ b/frontend/crawlab-ui/src/components/core/metric/MetricMonitoringDetail.vue @@ -30,8 +30,8 @@ const t = translate; const { activeId } = useDetail(props.ns); const timeRange = ref(props.defaultTimeRange || '1h'); -const timeRanges = ['1h', '24h', '7d', '30d']; -const timeUnits = ['5m', '1h', '6h', '1d']; +const timeRanges = ['15m', '1h', '24h', '7d', '30d']; +const timeUnits = ['1m', '5m', '1h', '6h', '1d']; const timeRangeOptions = computed(() => { return timeRanges.map((value, index) => { const label = t('components.metric.timeRanges.' + value); diff --git a/frontend/crawlab-ui/src/i18n/lang/en/components/metric.ts b/frontend/crawlab-ui/src/i18n/lang/en/components/metric.ts index 4bde5e85..a2c7df42 100644 --- a/frontend/crawlab-ui/src/i18n/lang/en/components/metric.ts +++ b/frontend/crawlab-ui/src/i18n/lang/en/components/metric.ts @@ -37,6 +37,7 @@ const metric: LComponentsMetric = { y: 'yr', }, timeRanges: { + '15m': 'Past 15 Minutes', '1h': 'Past 1 Hour', '24h': 'Past 24 Hours', '7d': 'Past 7 Days', diff --git a/frontend/crawlab-ui/src/i18n/lang/zh/components/metric.ts b/frontend/crawlab-ui/src/i18n/lang/zh/components/metric.ts index a97c17e8..6a91a425 100644 --- a/frontend/crawlab-ui/src/i18n/lang/zh/components/metric.ts +++ b/frontend/crawlab-ui/src/i18n/lang/zh/components/metric.ts @@ -37,6 +37,7 @@ const metric: LComponentsMetric = { y: '年', }, timeRanges: { + '15m': '过去 15 分钟', '1h': '过去 1 小时', '24h': '过去 24 小时', '7d': '过去 7 天', diff --git a/frontend/crawlab-ui/src/interfaces/i18n/components/metric.d.ts b/frontend/crawlab-ui/src/interfaces/i18n/components/metric.d.ts index 84daecea..1262d0b1 100644 --- a/frontend/crawlab-ui/src/interfaces/i18n/components/metric.d.ts +++ b/frontend/crawlab-ui/src/interfaces/i18n/components/metric.d.ts @@ -37,6 +37,7 @@ interface LComponentsMetric { y: string; }; timeRanges: { + '15m': string; '1h': string; '24h': string; '7d': string;
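The recurring shape in this patch — in `Send`, `SendNodeNotification`, `stopAllRunners`, and `initTaskStatus` — is a semaphore-plus-WaitGroup fan-out: a buffered channel caps concurrency, a deferred block releases the slot and recovers panics, and `wg.Wait()` joins the work. Below is a minimal standalone sketch of that pattern under assumed names (`runBounded`, an `int` job payload, a worker count of 3 — all illustrative, not taken from the patch).

```go
package main

import (
	"fmt"
	"sync"
)

// runBounded fans work out across goroutines while a buffered channel caps
// how many run at once, mirroring the semaphore-plus-WaitGroup shape used in
// the patch. maxWorkers and the job type here are illustrative.
func runBounded(jobs []int, maxWorkers int, handle func(int)) {
	var wg sync.WaitGroup
	semaphore := make(chan struct{}, maxWorkers)

	for _, job := range jobs {
		wg.Add(1)
		semaphore <- struct{}{} // acquire a slot; blocks while maxWorkers are busy

		go func(job int) {
			defer func() {
				<-semaphore // release the slot
				wg.Done()
				if r := recover(); r != nil {
					fmt.Printf("job %d panicked: %v\n", job, r)
				}
			}()
			handle(job)
		}(job)
	}

	wg.Wait() // all in-flight work has finished before returning
}

func main() {
	jobs := []int{1, 2, 3, 4, 5, 6, 7, 8}
	runBounded(jobs, 3, func(job int) {
		fmt.Println("handled job", job)
	})
}
```

Acquiring the semaphore before spawning the goroutine is what bounds goroutine creation itself, not just concurrent execution — the loop blocks once the pool is saturated.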
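`TaskWorkerPool` takes the next step: instead of one goroutine per task, a fixed set of workers drains a buffered job queue, and `SubmitTask` reports a full queue back to the caller. The sketch below is a simplified, self-contained version of that pattern under assumed names (`pool`, `newPool`, an `int` job type); it differs from the patch in one deliberate way — `stop` drains the queue on close rather than cancelling a context first, which is the gentler shutdown choice.

```go
package main

import (
	"errors"
	"fmt"
	"sync"
)

// pool is a simplified bounded worker pool: a fixed number of workers drain a
// buffered job queue, so bursts are queued instead of spawning one goroutine
// per task.
type pool struct {
	jobs chan int
	wg   sync.WaitGroup
}

func newPool(workers, queueSize int, handle func(worker, job int)) *pool {
	p := &pool{jobs: make(chan int, queueSize)}
	for i := 0; i < workers; i++ {
		p.wg.Add(1)
		go func(worker int) {
			defer p.wg.Done()
			for job := range p.jobs { // exits when the queue is closed and drained
				handle(worker, job)
			}
		}(i)
	}
	return p
}

// submit enqueues a job without blocking; a full queue is surfaced to the
// caller, matching the "task queue is full" error returned by SubmitTask.
func (p *pool) submit(job int) error {
	select {
	case p.jobs <- job:
		return nil
	default:
		return errors.New("task queue is full")
	}
}

// stop closes the queue and waits for queued jobs to drain. The patch instead
// cancels a context before closing, trading queued work for faster shutdown.
func (p *pool) stop() {
	close(p.jobs)
	p.wg.Wait()
}

func main() {
	p := newPool(3, 15, func(worker, job int) {
		fmt.Printf("worker %d ran job %d\n", worker, job)
	})
	for job := 0; job < 10; job++ {
		if err := p.submit(job); err != nil {
			fmt.Println("dropped job:", job, err)
		}
	}
	p.stop()
}
```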
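The pool size is wired to a new `task.workers` setting through `GetTaskWorkers`, which falls back to `DefaultTaskWorkers` (30) when the key is unset or zero. A minimal sketch of that viper lookup and an override is shown below; the `viper.Set` call and the environment-variable note are illustrative — how overrides actually reach viper depends on the application's existing config bootstrap, which this patch does not change.

```go
package main

import (
	"fmt"

	"github.com/spf13/viper"
)

const defaultTaskWorkers = 30 // mirrors DefaultTaskWorkers in core/utils/config.go

// taskWorkers returns the configured worker count, falling back to the
// default when task.workers is unset or zero, as GetTaskWorkers does.
func taskWorkers() int {
	if n := viper.GetInt("task.workers"); n != 0 {
		return n
	}
	return defaultTaskWorkers
}

func main() {
	fmt.Println("workers (default):", taskWorkers())

	// Illustrative override; in practice the value would come from the
	// config file or environment binding the application already uses.
	viper.Set("task.workers", 50)
	fmt.Println("workers (override):", taskWorkers())
}
```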