refactor: improve goroutine management and context handling in task and stream operations; ensure graceful shutdown and prevent leaks

commit 44dd68918f
parent 784ffc8b52
Author: Marvin Zhang
Date: 2025-08-07 00:16:46 +08:00

5 changed files with 209 additions and 94 deletions
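For context on the context-handling half of the change: the graceful shutdown below relies on Go's standard parent/child context relationship, where cancelling the manager's root context automatically cancels every per-stream context derived from it via context.WithCancel. A minimal, self-contained sketch of that pattern (all identifiers are illustrative, not taken from the codebase):

```go
package main

import (
	"context"
	"fmt"
	"sync"
	"time"
)

func main() {
	// Root context owned by the manager; Stop() would call cancel.
	rootCtx, cancel := context.WithCancel(context.Background())

	var wg sync.WaitGroup
	for i := 0; i < 3; i++ {
		// Each worker gets a child context derived from the root, so it
		// can be cancelled individually or via the root.
		workerCtx, workerCancel := context.WithCancel(rootCtx)
		defer workerCancel()

		wg.Add(1)
		go func(id int, ctx context.Context) {
			defer wg.Done()
			<-ctx.Done() // unblocks when this worker or the root is cancelled
			fmt.Printf("worker %d stopped: %v\n", id, ctx.Err())
		}(i, workerCtx)
	}

	time.Sleep(100 * time.Millisecond)
	cancel()  // cancelling the root fans out to all children
	wg.Wait() // graceful shutdown: wait for every goroutine to exit
}
```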


@@ -11,6 +11,7 @@ import (
"github.com/crawlab-team/crawlab/core/interfaces"
"github.com/crawlab-team/crawlab/grpc"
"go.mongodb.org/mongo-driver/bson/primitive"
grpc2 "google.golang.org/grpc"
)
// Service operations for task management
@@ -92,10 +93,8 @@ func (svc *Service) executeTask(taskId primitive.ObjectID) (err error) {
 	return err
 }
 
-// subscribeTask attempts to subscribe to task stream
-func (svc *Service) subscribeTask(taskId primitive.ObjectID) (stream grpc.TaskService_SubscribeClient, err error) {
-	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
-	defer cancel()
+// subscribeTaskWithContext attempts to subscribe to task stream with provided context
+func (svc *Service) subscribeTaskWithContext(ctx context.Context, taskId primitive.ObjectID) (stream grpc.TaskService_SubscribeClient, err error) {
 	req := &grpc.TaskServiceSubscribeRequest{
 		TaskId: taskId.Hex(),
 	}
@@ -103,7 +102,13 @@ func (svc *Service) subscribeTask(taskId primitive.ObjectID) (stream grpc.TaskSe
 	if err != nil {
 		return nil, fmt.Errorf("failed to get task client: %v", err)
 	}
-	stream, err = taskClient.Subscribe(ctx, req)
+	// Use call options to ensure proper cancellation behavior
+	opts := []grpc2.CallOption{
+		grpc2.WaitForReady(false), // Don't wait for connection if not ready
+	}
+	stream, err = taskClient.Subscribe(ctx, req, opts...)
 	if err != nil {
 		svc.Errorf("failed to subscribe task[%s]: %v", taskId.Hex(), err)
 		return nil, err
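A note on the new call option: in grpc-go, WaitForReady(false) is already the library default ("fail fast"), so passing it explicitly documents intent rather than changing behavior: if the connection is not ready, the RPC fails immediately instead of queueing until the transport recovers. A sketch of the difference, using the stock gRPC health-check client (the target address is a placeholder):

```go
package main

import (
	"context"
	"log"
	"time"

	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"
	healthpb "google.golang.org/grpc/health/grpc_health_v1"
)

func main() {
	conn, err := grpc.Dial("localhost:9666",
		grpc.WithTransportCredentials(insecure.NewCredentials()))
	if err != nil {
		log.Fatal(err)
	}
	defer conn.Close()

	client := healthpb.NewHealthClient(conn)
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	// Fail fast (the default): if the transport is in a failure state,
	// the RPC returns codes.Unavailable immediately instead of queueing.
	_, err = client.Check(ctx, &healthpb.HealthCheckRequest{},
		grpc.WaitForReady(false))
	log.Printf("fail-fast call: %v", err)

	// Blocking variant: the RPC waits until the connection is ready or
	// ctx expires, so a dead server surfaces as a deadline error instead.
	_, err = client.Check(ctx, &healthpb.HealthCheckRequest{},
		grpc.WaitForReady(true))
	log.Printf("wait-for-ready call: %v", err)
}
```

Because the subscription call now receives a caller-provided context rather than a hard-coded 10-second timeout, the same cancellation signal that stops the listener also tears down the long-lived stream.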


@@ -73,24 +73,30 @@ func (sm *StreamManager) Stop() {
 }
 
 func (sm *StreamManager) AddTaskStream(taskId primitive.ObjectID) error {
-	// Check if we're at the stream limit
-	streamCount := 0
-	sm.streams.Range(func(key, value interface{}) bool {
-		streamCount++
-		return true
-	})
+	// Check if stream already exists
+	if _, exists := sm.streams.Load(taskId); exists {
+		sm.service.Debugf("stream already exists for task[%s], skipping", taskId.Hex())
+		return nil
+	}
+
+	// Check if we're at the stream limit
+	streamCount := sm.getStreamCount()
 	if streamCount >= sm.maxStreams {
 		sm.service.Warnf("stream limit reached (%d/%d), rejecting new stream for task[%s]",
 			streamCount, sm.maxStreams, taskId.Hex())
 		return fmt.Errorf("stream limit reached (%d)", sm.maxStreams)
 	}
 
-	// Create new stream
-	stream, err := sm.service.subscribeTask(taskId)
+	// Create a context for this specific stream that can be cancelled
+	ctx, cancel := context.WithCancel(sm.ctx)
+
+	// Create stream with the cancellable context
+	stream, err := sm.service.subscribeTaskWithContext(ctx, taskId)
 	if err != nil {
+		cancel() // Clean up the context if stream creation fails
 		return fmt.Errorf("failed to subscribe to task stream: %v", err)
 	}
 
-	ctx, cancel := context.WithCancel(sm.ctx)
 	taskStream := &TaskStream{
 		taskId: taskId,
 		stream: stream,
@@ -100,6 +106,8 @@ func (sm *StreamManager) AddTaskStream(taskId primitive.ObjectID) error {
 	}
 	sm.streams.Store(taskId, taskStream)
+	sm.service.Infof("created stream for task[%s], total streams: %d/%d",
+		taskId.Hex(), streamCount+1, sm.maxStreams)
 
 	// Start listening for messages in a single goroutine per stream
 	sm.wg.Add(1)
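One design note on the existence check introduced above: Load followed by Store are two separate operations on the sync.Map, so two goroutines could in principle both pass the check before either stores. sync.Map also offers LoadOrStore, which performs the check-and-insert atomically; a minimal sketch of that alternative (types are simplified placeholders):

```go
package main

import (
	"fmt"
	"sync"
)

type taskStream struct{ id string }

func main() {
	var streams sync.Map

	add := func(id string) bool {
		// LoadOrStore returns loaded=true if the key was already present,
		// in which case the new value was NOT stored.
		_, loaded := streams.LoadOrStore(id, &taskStream{id: id})
		return !loaded
	}

	fmt.Println(add("task-1")) // true: registered
	fmt.Println(add("task-1")) // false: already exists, skipped
}
```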
@@ -111,11 +119,21 @@ func (sm *StreamManager) AddTaskStream(taskId primitive.ObjectID) error {
 func (sm *StreamManager) RemoveTaskStream(taskId primitive.ObjectID) {
 	if value, ok := sm.streams.LoadAndDelete(taskId); ok {
 		if ts, ok := value.(*TaskStream); ok {
+			sm.service.Debugf("stream removed, total streams: %d/%d", sm.getStreamCount(), sm.maxStreams)
 			ts.Close()
 		}
 	}
 }
 
+func (sm *StreamManager) getStreamCount() int {
+	streamCount := 0
+	sm.streams.Range(func(key, value interface{}) bool {
+		streamCount++
+		return true
+	})
+	return streamCount
+}
+
 func (sm *StreamManager) streamListener(ts *TaskStream) {
 	defer sm.wg.Done()
 	defer func() {
@@ -124,6 +142,7 @@ func (sm *StreamManager) streamListener(ts *TaskStream) {
 		}
 		ts.Close()
 		sm.streams.Delete(ts.taskId)
+		sm.service.Debugf("stream listener finished cleanup for task[%s]", ts.taskId.Hex())
 	}()
 
 	sm.service.Debugf("stream listener started for task[%s]", ts.taskId.Hex())
@@ -134,40 +153,67 @@ func (sm *StreamManager) streamListener(ts *TaskStream) {
 			sm.service.Debugf("stream listener stopped for task[%s]", ts.taskId.Hex())
 			return
 		case <-sm.ctx.Done():
+			sm.service.Debugf("stream manager shutdown, stopping listener for task[%s]", ts.taskId.Hex())
 			return
 		default:
-			msg, err := ts.stream.Recv()
+			// Use a timeout wrapper to handle cases where Recv() might hang
+			resultChan := make(chan struct {
+				msg *grpc.TaskServiceSubscribeResponse
+				err error
+			}, 1)
-			if err != nil {
-				if errors.Is(err, io.EOF) {
-					sm.service.Debugf("stream EOF for task[%s]", ts.taskId.Hex())
-					return
-				}
-				if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
-					return
-				}
-				sm.service.Errorf("stream error for task[%s]: %v", ts.taskId.Hex(), err)
-				return
-			}
+
+			// Start receive operation in a separate goroutine
+			go func() {
+				msg, err := ts.stream.Recv()
+				resultChan <- struct {
+					msg *grpc.TaskServiceSubscribeResponse
+					err error
+				}{msg, err}
+			}()
-
-			// Update last active time
-			ts.mu.Lock()
-			ts.lastActive = time.Now()
-			ts.mu.Unlock()
-			// Send message to processor
+
+			// Wait for result, timeout, or cancellation
 			select {
-			case sm.messageQueue <- &StreamMessage{
-				taskId: ts.taskId,
-				msg:    msg,
-				err:    nil,
-			}:
+			case result := <-resultChan:
+				if result.err != nil {
+					if errors.Is(result.err, io.EOF) {
+						sm.service.Debugf("stream EOF for task[%s] - server closed stream", ts.taskId.Hex())
+						return
+					}
+					if errors.Is(result.err, context.Canceled) || errors.Is(result.err, context.DeadlineExceeded) {
+						sm.service.Debugf("stream context cancelled for task[%s]", ts.taskId.Hex())
+						return
+					}
+					sm.service.Debugf("stream error for task[%s]: %v - likely server closed", ts.taskId.Hex(), result.err)
+					return
+				}
+
+				// Update last active time
+				ts.mu.Lock()
+				ts.lastActive = time.Now()
+				ts.mu.Unlock()
+
+				// Send message to processor (non-blocking)
+				select {
+				case sm.messageQueue <- &StreamMessage{
+					taskId: ts.taskId,
+					msg:    result.msg,
+					err:    nil,
+				}:
+				case <-ts.ctx.Done():
+					return
+				case <-sm.ctx.Done():
+					return
+				default:
+					sm.service.Warnf("message queue full, dropping message for task[%s]", ts.taskId.Hex())
+				}
 			case <-ts.ctx.Done():
 				sm.service.Debugf("stream listener stopped for task[%s]", ts.taskId.Hex())
 				return
 			case <-sm.ctx.Done():
 				sm.service.Debugf("stream manager shutdown, stopping listener for task[%s]", ts.taskId.Hex())
 				return
-			default:
-				sm.service.Warnf("message queue full, dropping message for task[%s]", ts.taskId.Hex())
 			}
 		}
 	}
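The reworked listener above is an instance of a common Go pattern for making a blocking call cancellable: run the call in its own goroutine, deliver the result over a channel with capacity 1, and select on that channel alongside the cancellation signals. The buffer size is what prevents a leak on the send side: even if the listener has already returned, the in-flight goroutine can complete its send and exit. A standalone sketch (blockingRecv is a stand-in for stream.Recv):

```go
package main

import (
	"context"
	"fmt"
	"time"
)

// blockingRecv stands in for a call like stream.Recv() that can block
// indefinitely and takes no context parameter of its own.
func blockingRecv() (string, error) {
	time.Sleep(2 * time.Second)
	return "message", nil
}

func recvWithCancel(ctx context.Context) (string, error) {
	type result struct {
		msg string
		err error
	}
	// Capacity 1 is essential: if ctx is cancelled first, the goroutine
	// can still complete its send and exit instead of blocking forever.
	resultChan := make(chan result, 1)

	go func() {
		msg, err := blockingRecv()
		resultChan <- result{msg, err}
	}()

	select {
	case r := <-resultChan:
		return r.msg, r.err
	case <-ctx.Done():
		return "", ctx.Err()
	}
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
	defer cancel()
	msg, err := recvWithCancel(ctx)
	fmt.Println(msg, err)
}
```

The remaining caveat, visible in the sketch, is that the goroutine itself lingers until the blocking call returns; in the diff that is bounded because cancelling the per-stream context passed to subscribeTaskWithContext makes Recv fail promptly.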
@@ -250,8 +296,17 @@ func (sm *StreamManager) cleanupInactiveStreams() {
 }
 
 func (ts *TaskStream) Close() {
+	// Cancel the context first - this should interrupt any blocking operations
+	ts.cancel()
+
 	if ts.stream != nil {
-		_ = ts.stream.CloseSend()
+		// Try to close send direction
+		err := ts.stream.CloseSend()
+		if err != nil {
+			fmt.Printf("failed to close stream send for task[%s]: %v\n", ts.taskId.Hex(), err)
+		}
+
+		// Note: The stream.Recv() should now fail with context.Canceled
+		// due to the cancelled context passed to subscribeTaskWithContext
 	}
 }
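Finally, the graceful-shutdown half of the commit message comes down to pairing the cancel fan-out with the WaitGroup the listeners are registered on. A sketch of a shutdown helper in that spirit, with a deadline so Stop cannot hang forever (names are illustrative, not from the codebase):

```go
package main

import (
	"context"
	"fmt"
	"sync"
	"time"
)

func stopWithTimeout(cancel context.CancelFunc, wg *sync.WaitGroup, d time.Duration) error {
	cancel() // fan out cancellation to every per-stream context

	done := make(chan struct{})
	go func() {
		wg.Wait() // wait for all listener goroutines to return
		close(done)
	}()

	select {
	case <-done:
		return nil // clean shutdown: no goroutines leaked
	case <-time.After(d):
		return fmt.Errorf("shutdown timed out after %v", d)
	}
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	var wg sync.WaitGroup

	wg.Add(1)
	go func() {
		defer wg.Done()
		<-ctx.Done() // a stand-in for a streamListener loop
	}()

	fmt.Println(stopWithTimeout(cancel, &wg, time.Second))
}
```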