diff --git a/core/controllers/spider.go b/core/controllers/spider.go
index a4ef5931..44ade81c 100644
--- a/core/controllers/spider.go
+++ b/core/controllers/spider.go
@@ -214,20 +214,18 @@ func DeleteSpiderById(_ *gin.Context, params *DeleteByIdParams) (response *Respo
 		return GetErrorResponse[models.Spider](err)
 	}

+	// Delete spider directory synchronously to prevent goroutine leaks
 	if !s.GitId.IsZero() {
-		go func() {
-			// delete spider directory
-			fsSvc, err := getSpiderFsSvcById(s.Id)
-			if err != nil {
-				logger.Errorf("failed to get spider fs service: %v", err)
-				return
-			}
+		// delete spider directory
+		fsSvc, err := getSpiderFsSvcById(s.Id)
+		if err != nil {
+			logger.Errorf("failed to get spider fs service: %v", err)
+		} else {
 			err = fsSvc.Delete(".")
 			if err != nil {
 				logger.Errorf("failed to delete spider directory: %v", err)
-				return
 			}
-		}()
+		}
 	}

 	return GetDataResponse(models.Spider{})
@@ -323,34 +321,39 @@ func DeleteSpiderList(_ *gin.Context, params *DeleteSpiderListParams) (response
 		return GetErrorResponse[models.Spider](err)
 	}

-	// Delete spider directories
-	go func() {
-		wg := sync.WaitGroup{}
-		wg.Add(len(spiders))
-		for i := range spiders {
-			go func(s *models.Spider) {
-				defer wg.Done()
-
-				// Skip spider with git
-				if !s.GitId.IsZero() {
-					return
-				}
-
-				// Delete spider directory
-				fsSvc, err := getSpiderFsSvcById(s.Id)
-				if err != nil {
-					logger.Errorf("failed to get spider fs service: %v", err)
-					return
-				}
-				err = fsSvc.Delete(".")
-				if err != nil {
-					logger.Errorf("failed to delete spider directory: %v", err)
-					return
-				}
-			}(&spiders[i])
+	// Delete spider directories with bounded concurrency and wait for completion to prevent goroutine leaks
+	wg := sync.WaitGroup{}
+	semaphore := make(chan struct{}, 5) // Limit concurrent operations
+
+	for i := range spiders {
+		// Skip spider with git
+		if !spiders[i].GitId.IsZero() {
+			continue
 		}
-		wg.Wait()
-	}()
+
+		wg.Add(1)
+		semaphore <- struct{}{} // Acquire semaphore
+
+		go func(s *models.Spider) {
+			defer func() {
+				<-semaphore // Release semaphore
+				wg.Done()
+			}()
+
+			// Delete spider directory
+			fsSvc, err := getSpiderFsSvcById(s.Id)
+			if err != nil {
+				logger.Errorf("failed to get spider fs service: %v", err)
+				return
+			}
+			err = fsSvc.Delete(".")
+			if err != nil {
+				logger.Errorf("failed to delete spider directory: %v", err)
+				return
+			}
+		}(&spiders[i])
+	}
+	wg.Wait()

 	return GetDataResponse(models.Spider{})
 }
diff --git a/core/grpc/client/client.go b/core/grpc/client/client.go
index 272d2747..806e6ee1 100644
--- a/core/grpc/client/client.go
+++ b/core/grpc/client/client.go
@@ -112,6 +112,9 @@ type GrpcClient struct {
 	healthClient       grpc_health_v1.HealthClient
 	healthCheckEnabled bool
 	healthCheckMux     sync.RWMutex
+
+	// Goroutine management
+	wg sync.WaitGroup
 }

 func (c *GrpcClient) Start() {
@@ -131,11 +134,18 @@ func (c *GrpcClient) Start() {
 			// Don't fatal here, let reconnection handle it
 		}

-		// start monitoring after connection attempt
-		go c.monitorState()
+		// start monitoring after connection attempt with proper tracking
+		c.wg.Add(2) // Track both monitoring goroutines
+		go func() {
+			defer c.wg.Done()
+			c.monitorState()
+		}()

 		// start health monitoring
-		go c.startHealthMonitor()
+		go func() {
+			defer c.wg.Done()
+			c.startHealthMonitor()
+		}()
 	})
 }

@@ -157,6 +167,21 @@ func (c *GrpcClient) Stop() error {
 	default:
 	}

+	// Wait for goroutines to finish
+	done := make(chan struct{})
+	go func() {
+		c.wg.Wait()
+		close(done)
+	}()
+
+	// Give goroutines time to finish gracefully before closing the connection
+	select {
+	case <-done:
+		c.Debugf("all goroutines stopped gracefully")
stopped gracefully") + case <-time.After(10 * time.Second): + c.Warnf("some goroutines did not stop gracefully within timeout") + } + // Close connection if c.conn != nil { if err := c.conn.Close(); err != nil { diff --git a/core/grpc/server/task_service_server.go b/core/grpc/server/task_service_server.go index 859f3422..a8fc43cf 100644 --- a/core/grpc/server/task_service_server.go +++ b/core/grpc/server/task_service_server.go @@ -59,7 +59,10 @@ func (svr TaskServiceServer) Subscribe(req *grpc.TaskServiceSubscribeRequest, st svr.Infof("task stream opened: %s", taskId.Hex()) - // add stream + // Create a context based on client stream + ctx := stream.Context() + + // add stream and track cancellation function taskServiceMutex.Lock() svr.subs[taskId] = stream taskServiceMutex.Unlock() @@ -72,22 +75,34 @@ func (svr TaskServiceServer) Subscribe(req *grpc.TaskServiceSubscribeRequest, st svr.Infof("task stream closed: %s", taskId.Hex()) }() - // wait for stream to close with timeout protection - ctx := stream.Context() + // send periodic heartbeat to detect client disconnection and check for task completion + heartbeatTicker := time.NewTicker(10 * time.Second) // More frequent for faster completion detection + defer heartbeatTicker.Stop() - // Create a context with timeout to prevent indefinite hanging - timeoutCtx, cancel := context.WithTimeout(context.Background(), 24*time.Hour) - defer cancel() + for { + select { + case <-ctx.Done(): + // Stream context cancelled normally (client disconnected or task finished) + svr.Debugf("task stream context done: %s", taskId.Hex()) + return ctx.Err() - select { - case <-ctx.Done(): - // Stream context cancelled normally - svr.Debugf("task stream context done: %s", taskId.Hex()) - return ctx.Err() - case <-timeoutCtx.Done(): - // Timeout reached - this prevents indefinite hanging - svr.Warnf("task stream timeout reached for task: %s", taskId.Hex()) - return errors.New("stream timeout") + case <-heartbeatTicker.C: + // Check if task has finished and close stream if so + if svr.isTaskFinished(taskId) { + svr.Infof("task[%s] finished, closing stream", taskId.Hex()) + return nil + } + + // Check if the context is still valid + select { + case <-ctx.Done(): + svr.Debugf("task stream context cancelled during heartbeat check: %s", taskId.Hex()) + return ctx.Err() + default: + // Context is still valid, continue + svr.Debugf("task stream heartbeat check passed: %s", taskId.Hex()) + } + } } } @@ -471,6 +486,18 @@ func (svr *TaskServiceServer) Stop() error { return nil } +// isTaskFinished checks if a task has completed execution +func (svr TaskServiceServer) isTaskFinished(taskId primitive.ObjectID) bool { + task, err := service.NewModelService[models.Task]().GetById(taskId) + if err != nil { + svr.Debugf("error checking task[%s] status: %v", taskId.Hex(), err) + return false + } + + // Task is finished if it's not in pending or running state + return task.Status != constants.TaskStatusPending && task.Status != constants.TaskStatusRunning +} + var _taskServiceServer *TaskServiceServer var _taskServiceServerOnce sync.Once diff --git a/core/task/handler/service_operations.go b/core/task/handler/service_operations.go index a28dedce..6e9f7468 100644 --- a/core/task/handler/service_operations.go +++ b/core/task/handler/service_operations.go @@ -11,6 +11,7 @@ import ( "github.com/crawlab-team/crawlab/core/interfaces" "github.com/crawlab-team/crawlab/grpc" "go.mongodb.org/mongo-driver/bson/primitive" + grpc2 "google.golang.org/grpc" ) // Service operations for task 
@@ -92,10 +93,8 @@ func (svc *Service) executeTask(taskId primitive.ObjectID) (err error) {
 	return err
 }

-// subscribeTask attempts to subscribe to task stream
-func (svc *Service) subscribeTask(taskId primitive.ObjectID) (stream grpc.TaskService_SubscribeClient, err error) {
-	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
-	defer cancel()
+// subscribeTaskWithContext attempts to subscribe to task stream with provided context
+func (svc *Service) subscribeTaskWithContext(ctx context.Context, taskId primitive.ObjectID) (stream grpc.TaskService_SubscribeClient, err error) {
 	req := &grpc.TaskServiceSubscribeRequest{
 		TaskId: taskId.Hex(),
 	}
@@ -103,7 +102,13 @@ func (svc *Service) subscribeTask(taskId primitive.ObjectID) (stream grpc.TaskSe
 	if err != nil {
 		return nil, fmt.Errorf("failed to get task client: %v", err)
 	}
-	stream, err = taskClient.Subscribe(ctx, req)
+
+	// Use call options so the subscribe call fails fast instead of blocking on an unready connection
+	opts := []grpc2.CallOption{
+		grpc2.WaitForReady(false), // Don't wait for connection if not ready
+	}
+
+	stream, err = taskClient.Subscribe(ctx, req, opts...)
 	if err != nil {
 		svc.Errorf("failed to subscribe task[%s]: %v", taskId.Hex(), err)
 		return nil, err
diff --git a/core/task/handler/stream_manager.go b/core/task/handler/stream_manager.go
index 3f7e128f..abd54d86 100644
--- a/core/task/handler/stream_manager.go
+++ b/core/task/handler/stream_manager.go
@@ -73,24 +73,30 @@ func (sm *StreamManager) Stop() {
 }

 func (sm *StreamManager) AddTaskStream(taskId primitive.ObjectID) error {
-	// Check if we're at the stream limit
-	streamCount := 0
-	sm.streams.Range(func(key, value interface{}) bool {
-		streamCount++
-		return true
-	})
+	// Check if stream already exists
+	if _, exists := sm.streams.Load(taskId); exists {
+		sm.service.Debugf("stream already exists for task[%s], skipping", taskId.Hex())
+		return nil
+	}

+	// Check if we're at the stream limit
+	streamCount := sm.getStreamCount()
 	if streamCount >= sm.maxStreams {
+		sm.service.Warnf("stream limit reached (%d/%d), rejecting new stream for task[%s]",
+			streamCount, sm.maxStreams, taskId.Hex())
 		return fmt.Errorf("stream limit reached (%d)", sm.maxStreams)
 	}

-	// Create new stream
-	stream, err := sm.service.subscribeTask(taskId)
+	// Create a context for this specific stream that can be cancelled
+	ctx, cancel := context.WithCancel(sm.ctx)
+
+	// Create stream with the cancellable context
+	stream, err := sm.service.subscribeTaskWithContext(ctx, taskId)
 	if err != nil {
+		cancel() // Clean up the context if stream creation fails
 		return fmt.Errorf("failed to subscribe to task stream: %v", err)
 	}

-	ctx, cancel := context.WithCancel(sm.ctx)
 	taskStream := &TaskStream{
 		taskId:  taskId,
 		stream:  stream,
@@ -100,6 +106,8 @@
 	}

 	sm.streams.Store(taskId, taskStream)
+	sm.service.Infof("created stream for task[%s], total streams: %d/%d",
+		taskId.Hex(), streamCount+1, sm.maxStreams)

 	// Start listening for messages in a single goroutine per stream
 	sm.wg.Add(1)
@@ -111,11 +119,21 @@
 func (sm *StreamManager) RemoveTaskStream(taskId primitive.ObjectID) {
 	if value, ok := sm.streams.LoadAndDelete(taskId); ok {
 		if ts, ok := value.(*TaskStream); ok {
+			sm.service.Debugf("stream removed, total streams: %d/%d", sm.getStreamCount(), sm.maxStreams)
 			ts.Close()
 		}
 	}
 }

+func (sm *StreamManager) getStreamCount() int {
+	streamCount := 0
+	sm.streams.Range(func(key, value interface{}) bool {
+		streamCount++
+		return true
+	})
+	return streamCount
+}
+
 func (sm *StreamManager) streamListener(ts *TaskStream) {
 	defer sm.wg.Done()
 	defer func() {
@@ -124,6 +142,7 @@
 		}
 		ts.Close()
 		sm.streams.Delete(ts.taskId)
+		sm.service.Debugf("stream listener finished cleanup for task[%s]", ts.taskId.Hex())
 	}()

 	sm.service.Debugf("stream listener started for task[%s]", ts.taskId.Hex())
@@ -134,40 +153,67 @@
 			sm.service.Debugf("stream listener stopped for task[%s]", ts.taskId.Hex())
 			return
 		case <-sm.ctx.Done():
+			sm.service.Debugf("stream manager shutdown, stopping listener for task[%s]", ts.taskId.Hex())
 			return
 		default:
-			msg, err := ts.stream.Recv()
+			// Wrap Recv() in a result channel so a blocked read can be abandoned on cancellation
+			resultChan := make(chan struct {
+				msg *grpc.TaskServiceSubscribeResponse
+				err error
+			}, 1)

-			if err != nil {
-				if errors.Is(err, io.EOF) {
-					sm.service.Debugf("stream EOF for task[%s]", ts.taskId.Hex())
-					return
-				}
-				if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
-					return
-				}
-				sm.service.Errorf("stream error for task[%s]: %v", ts.taskId.Hex(), err)
-				return
-			}
+			// Start receive operation in a separate goroutine
+			go func() {
+				msg, err := ts.stream.Recv()
+				resultChan <- struct {
+					msg *grpc.TaskServiceSubscribeResponse
+					err error
+				}{msg, err}
+			}()

-			// Update last active time
-			ts.mu.Lock()
-			ts.lastActive = time.Now()
-			ts.mu.Unlock()
-
-			// Send message to processor
+			// Wait for a result or cancellation
 			select {
-			case sm.messageQueue <- &StreamMessage{
-				taskId: ts.taskId,
-				msg:    msg,
-				err:    nil,
-			}:
+			case result := <-resultChan:
+				if result.err != nil {
+					if errors.Is(result.err, io.EOF) {
+						sm.service.Debugf("stream EOF for task[%s] - server closed stream", ts.taskId.Hex())
+						return
+					}
+					if errors.Is(result.err, context.Canceled) || errors.Is(result.err, context.DeadlineExceeded) {
+						sm.service.Debugf("stream context cancelled for task[%s]", ts.taskId.Hex())
+						return
+					}
+					sm.service.Debugf("stream error for task[%s]: %v - likely server closed", ts.taskId.Hex(), result.err)
+					return
+				}
+
+				// Update last active time
+				ts.mu.Lock()
+				ts.lastActive = time.Now()
+				ts.mu.Unlock()
+
+				// Send message to processor (non-blocking)
+				select {
+				case sm.messageQueue <- &StreamMessage{
+					taskId: ts.taskId,
+					msg:    result.msg,
+					err:    nil,
+				}:
+				case <-ts.ctx.Done():
+					return
+				case <-sm.ctx.Done():
+					return
+				default:
+					sm.service.Warnf("message queue full, dropping message for task[%s]", ts.taskId.Hex())
+				}
 			case <-ts.ctx.Done():
+				sm.service.Debugf("stream listener stopped for task[%s]", ts.taskId.Hex())
 				return
 			case <-sm.ctx.Done():
+				sm.service.Debugf("stream manager shutdown, stopping listener for task[%s]", ts.taskId.Hex())
 				return
-			default:
-				sm.service.Warnf("message queue full, dropping message for task[%s]", ts.taskId.Hex())
 			}
 		}
 	}
@@ -250,8 +296,17 @@ func (sm *StreamManager) cleanupInactiveStreams() {
 }

 func (ts *TaskStream) Close() {
+	// Cancel the context first - this should interrupt any blocking operations
 	ts.cancel()
+
 	if ts.stream != nil {
-		_ = ts.stream.CloseSend()
+		// Try to close send direction
+		err := ts.stream.CloseSend()
+		if err != nil {
+			fmt.Printf("failed to close stream send for task[%s]: %v\n", ts.taskId.Hex(), err)
+		}
+
+		// Note: The stream.Recv() should now fail with context.Canceled
+		// due to the cancelled context passed to subscribeTaskWithContext
 	}
 }
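
Note on the pattern in streamListener above: the key change is that the blocking stream.Recv() call now runs in its own goroutine, and the listener selects on the result channel and the stream/manager contexts, so TaskStream.Close() can interrupt a read that would otherwise hang. Below is a minimal, self-contained sketch of that pattern. It is illustrative only and not taken from the Crawlab codebase; the receiver interface and slowReceiver type are hypothetical stand-ins for a gRPC client stream.

// cancellable_recv.go - minimal sketch of the cancellable-receive pattern.
package main

import (
	"context"
	"errors"
	"fmt"
	"time"
)

// receiver is a hypothetical stand-in for a gRPC client stream's Recv method.
type receiver interface {
	Recv() (string, error)
}

// slowReceiver simulates a read that blocks for a while before returning.
type slowReceiver struct{}

func (slowReceiver) Recv() (string, error) {
	time.Sleep(2 * time.Second)
	return "message", nil
}

// recvWithContext runs r.Recv() in its own goroutine so the caller can stop
// waiting as soon as ctx is cancelled. The result channel is buffered, which
// lets the helper goroutine finish its send and exit even if nobody is
// listening anymore.
func recvWithContext(ctx context.Context, r receiver) (string, error) {
	type result struct {
		msg string
		err error
	}
	ch := make(chan result, 1)
	go func() {
		msg, err := r.Recv()
		ch <- result{msg: msg, err: err}
	}()
	select {
	case res := <-ch:
		return res.msg, res.err
	case <-ctx.Done():
		return "", ctx.Err()
	}
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
	defer cancel()

	if _, err := recvWithContext(ctx, slowReceiver{}); errors.Is(err, context.DeadlineExceeded) {
		fmt.Println("stopped waiting:", err)
	}
}

The buffered result channel is what keeps the helper goroutine from leaking permanently: even if the caller has stopped waiting, the goroutine can complete its send and exit as soon as Recv returns, which is why cancelling the stream's context in Close() still matters.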