mirror of
https://github.com/crawlab-team/crawlab.git
synced 2026-01-28 17:50:56 +01:00
refactor: improve goroutine management and context handling in task and stream operations; ensure graceful shutdown and prevent leaks
This commit is contained in:
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/crawlab-team/crawlab/core/interfaces"
|
||||
"github.com/crawlab-team/crawlab/grpc"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
grpc2 "google.golang.org/grpc"
|
||||
)
|
||||
|
||||
// Service operations for task management
|
||||
@@ -92,10 +93,8 @@ func (svc *Service) executeTask(taskId primitive.ObjectID) (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
// subscribeTask attempts to subscribe to task stream
|
||||
func (svc *Service) subscribeTask(taskId primitive.ObjectID) (stream grpc.TaskService_SubscribeClient, err error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
// subscribeTaskWithContext attempts to subscribe to task stream with provided context
|
||||
func (svc *Service) subscribeTaskWithContext(ctx context.Context, taskId primitive.ObjectID) (stream grpc.TaskService_SubscribeClient, err error) {
|
||||
req := &grpc.TaskServiceSubscribeRequest{
|
||||
TaskId: taskId.Hex(),
|
||||
}
|
||||
@@ -103,7 +102,13 @@ func (svc *Service) subscribeTask(taskId primitive.ObjectID) (stream grpc.TaskSe
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get task client: %v", err)
|
||||
}
|
||||
stream, err = taskClient.Subscribe(ctx, req)
|
||||
|
||||
// Fail fast instead of blocking when the connection is not ready (WaitForReady(false) is also gRPC's default); cancellation itself comes from ctx
|
||||
opts := []grpc2.CallOption{
|
||||
grpc2.WaitForReady(false), // Don't wait for connection if not ready
|
||||
}
|
||||
|
||||
stream, err = taskClient.Subscribe(ctx, req, opts...)
|
||||
if err != nil {
|
||||
svc.Errorf("failed to subscribe task[%s]: %v", taskId.Hex(), err)
|
||||
return nil, err
|
||||
|
||||
@@ -73,24 +73,30 @@ func (sm *StreamManager) Stop() {
|
||||
}
|
||||
|
||||
func (sm *StreamManager) AddTaskStream(taskId primitive.ObjectID) error {
|
||||
// Check if we're at the stream limit
|
||||
streamCount := 0
|
||||
sm.streams.Range(func(key, value interface{}) bool {
|
||||
streamCount++
|
||||
return true
|
||||
})
|
||||
// Check if stream already exists
|
||||
if _, exists := sm.streams.Load(taskId); exists {
|
||||
sm.service.Debugf("stream already exists for task[%s], skipping", taskId.Hex())
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if we're at the stream limit
|
||||
streamCount := sm.getStreamCount()
|
||||
if streamCount >= sm.maxStreams {
|
||||
sm.service.Warnf("stream limit reached (%d/%d), rejecting new stream for task[%s]",
|
||||
streamCount, sm.maxStreams, taskId.Hex())
|
||||
return fmt.Errorf("stream limit reached (%d)", sm.maxStreams)
|
||||
}
|
||||
|
||||
// Create new stream
|
||||
stream, err := sm.service.subscribeTask(taskId)
|
||||
// Create a context for this specific stream that can be cancelled
|
||||
ctx, cancel := context.WithCancel(sm.ctx)
|
||||
|
||||
// Create stream with the cancellable context
|
||||
stream, err := sm.service.subscribeTaskWithContext(ctx, taskId)
|
||||
if err != nil {
|
||||
cancel() // Clean up the context if stream creation fails
|
||||
return fmt.Errorf("failed to subscribe to task stream: %v", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(sm.ctx)
|
||||
taskStream := &TaskStream{
|
||||
taskId: taskId,
|
||||
stream: stream,
|
||||
@@ -100,6 +106,8 @@ func (sm *StreamManager) AddTaskStream(taskId primitive.ObjectID) error {
|
||||
}
|
||||
|
||||
sm.streams.Store(taskId, taskStream)
|
||||
sm.service.Infof("created stream for task[%s], total streams: %d/%d",
|
||||
taskId.Hex(), streamCount+1, sm.maxStreams)
|
||||
|
||||
// Start listening for messages in a single goroutine per stream
|
||||
sm.wg.Add(1)
|
||||
@@ -111,11 +119,21 @@ func (sm *StreamManager) AddTaskStream(taskId primitive.ObjectID) error {
|
||||
// RemoveTaskStream removes the entry stored under taskId from sm.streams and,
// when that entry is a *TaskStream, closes it. It does nothing if no entry is
// stored for taskId.
func (sm *StreamManager) RemoveTaskStream(taskId primitive.ObjectID) {
|
||||
// LoadAndDelete fetches and removes the entry in a single call.
if value, ok := sm.streams.LoadAndDelete(taskId); ok {
|
||||
if ts, ok := value.(*TaskStream); ok {
|
||||
// Counted after the deletion above, so the logged total excludes this stream.
sm.service.Debugf("stream removed, total streams: %d/%d", sm.getStreamCount(), sm.maxStreams)
|
||||
ts.Close()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// getStreamCount returns the number of entries currently held in sm.streams.
// It counts by walking every entry with Range, so each call costs one full
// iteration over the collection.
// NOTE(review): sm.streams looks like a sync.Map (Range/Load/Store/LoadAndDelete
// are used on it) — confirm against the StreamManager declaration.
func (sm *StreamManager) getStreamCount() int {
|
||||
streamCount := 0
|
||||
sm.streams.Range(func(key, value interface{}) bool {
|
||||
streamCount++
|
||||
// Returning true keeps the iteration going so every entry is counted.
return true
|
||||
})
|
||||
return streamCount
|
||||
}
|
||||
|
||||
func (sm *StreamManager) streamListener(ts *TaskStream) {
|
||||
defer sm.wg.Done()
|
||||
defer func() {
|
||||
@@ -124,6 +142,7 @@ func (sm *StreamManager) streamListener(ts *TaskStream) {
|
||||
}
|
||||
ts.Close()
|
||||
sm.streams.Delete(ts.taskId)
|
||||
sm.service.Debugf("stream listener finished cleanup for task[%s]", ts.taskId.Hex())
|
||||
}()
|
||||
|
||||
sm.service.Debugf("stream listener started for task[%s]", ts.taskId.Hex())
|
||||
@@ -134,40 +153,67 @@ func (sm *StreamManager) streamListener(ts *TaskStream) {
|
||||
sm.service.Debugf("stream listener stopped for task[%s]", ts.taskId.Hex())
|
||||
return
|
||||
case <-sm.ctx.Done():
|
||||
sm.service.Debugf("stream manager shutdown, stopping listener for task[%s]", ts.taskId.Hex())
|
||||
return
|
||||
default:
|
||||
msg, err := ts.stream.Recv()
|
||||
// Run Recv() on a helper goroutine so the listener can still react to cancellation while Recv() blocks (no timer is involved)
|
||||
resultChan := make(chan struct {
|
||||
msg *grpc.TaskServiceSubscribeResponse
|
||||
err error
|
||||
}, 1)
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
sm.service.Debugf("stream EOF for task[%s]", ts.taskId.Hex())
|
||||
return
|
||||
}
|
||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||
return
|
||||
}
|
||||
sm.service.Errorf("stream error for task[%s]: %v", ts.taskId.Hex(), err)
|
||||
return
|
||||
}
|
||||
// Start receive operation in a separate goroutine
|
||||
go func() {
|
||||
msg, err := ts.stream.Recv()
|
||||
resultChan <- struct {
|
||||
msg *grpc.TaskServiceSubscribeResponse
|
||||
err error
|
||||
}{msg, err}
|
||||
}()
|
||||
|
||||
// Update last active time
|
||||
ts.mu.Lock()
|
||||
ts.lastActive = time.Now()
|
||||
ts.mu.Unlock()
|
||||
|
||||
// Send message to processor
|
||||
// Wait for the receive result or for cancellation
|
||||
select {
|
||||
case sm.messageQueue <- &StreamMessage{
|
||||
taskId: ts.taskId,
|
||||
msg: msg,
|
||||
err: nil,
|
||||
}:
|
||||
case result := <-resultChan:
|
||||
if result.err != nil {
|
||||
if errors.Is(result.err, io.EOF) {
|
||||
sm.service.Debugf("stream EOF for task[%s] - server closed stream", ts.taskId.Hex())
|
||||
return
|
||||
}
|
||||
if errors.Is(result.err, context.Canceled) || errors.Is(result.err, context.DeadlineExceeded) {
|
||||
sm.service.Debugf("stream context cancelled for task[%s]", ts.taskId.Hex())
|
||||
return
|
||||
}
|
||||
sm.service.Debugf("stream error for task[%s]: %v - likely server closed", ts.taskId.Hex(), result.err)
|
||||
return
|
||||
}
|
||||
|
||||
// Update last active time
|
||||
ts.mu.Lock()
|
||||
ts.lastActive = time.Now()
|
||||
ts.mu.Unlock()
|
||||
|
||||
// Send message to processor (non-blocking)
|
||||
select {
|
||||
case sm.messageQueue <- &StreamMessage{
|
||||
taskId: ts.taskId,
|
||||
msg: result.msg,
|
||||
err: nil,
|
||||
}:
|
||||
case <-ts.ctx.Done():
|
||||
return
|
||||
case <-sm.ctx.Done():
|
||||
return
|
||||
default:
|
||||
sm.service.Warnf("message queue full, dropping message for task[%s]", ts.taskId.Hex())
|
||||
}
|
||||
|
||||
case <-ts.ctx.Done():
|
||||
sm.service.Debugf("stream listener stopped for task[%s]", ts.taskId.Hex())
|
||||
return
|
||||
|
||||
case <-sm.ctx.Done():
|
||||
sm.service.Debugf("stream manager shutdown, stopping listener for task[%s]", ts.taskId.Hex())
|
||||
return
|
||||
default:
|
||||
sm.service.Warnf("message queue full, dropping message for task[%s]", ts.taskId.Hex())
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -250,8 +296,17 @@ func (sm *StreamManager) cleanupInactiveStreams() {
|
||||
}
|
||||
|
||||
// Close cancels the stream's context and then, if a stream is set, closes its
// send direction. A CloseSend error is not returned to the caller; it is
// printed to standard output with fmt.Printf.
func (ts *TaskStream) Close() {
|
||||
// Cancel the context first - this should interrupt any blocking operations
|
||||
ts.cancel()
|
||||
|
||||
if ts.stream != nil {
|
||||
// NOTE(review): this page shows removed and added diff lines together without
// +/- markers. The hunk header (8 old lines, 17 new) indicates the next line is
// the removed pre-change call, superseded by the error-checked call below —
// confirm against the committed file.
_ = ts.stream.CloseSend()
|
||||
// Try to close send direction
|
||||
err := ts.stream.CloseSend()
|
||||
if err != nil {
|
||||
fmt.Printf("failed to close stream send for task[%s]: %v\n", ts.taskId.Hex(), err)
|
||||
}
|
||||
|
||||
// Note: The stream.Recv() should now fail with context.Canceled
|
||||
// due to the cancelled context passed to subscribeTaskWithContext
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user