refactor: improve goroutine management and context handling in task and stream operations; ensure graceful shutdown and prevent leaks

Marvin Zhang
2025-08-07 00:16:46 +08:00
parent 784ffc8b52
commit 44dd68918f
5 changed files with 209 additions and 94 deletions
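
For review context: the bug class this commit targets is the fire-and-forget goroutine that nothing waits on or cancels. A minimal standalone sketch of the before/after shape (illustrative only, not code from this repository):

package main

import (
	"context"
	"fmt"
	"sync"
	"time"
)

func main() {
	// Leak-prone: nothing can wait for or stop this goroutine.
	go func() {
		time.Sleep(time.Hour)
	}()

	// Tracked: the owner can cancel it and wait for it to exit,
	// which is the pattern this commit applies throughout.
	ctx, cancel := context.WithCancel(context.Background())
	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		defer wg.Done()
		select {
		case <-ctx.Done():
			fmt.Println("worker stopped cleanly")
		case <-time.After(time.Hour):
		}
	}()

	cancel()
	wg.Wait()
}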

View File

@@ -214,20 +214,18 @@ func DeleteSpiderById(_ *gin.Context, params *DeleteByIdParams) (response *Respo
 		return GetErrorResponse[models.Spider](err)
 	}
+	// Delete spider directory synchronously to prevent goroutine leaks
 	if !s.GitId.IsZero() {
-		go func() {
-			// delete spider directory
-			fsSvc, err := getSpiderFsSvcById(s.Id)
-			if err != nil {
-				logger.Errorf("failed to get spider fs service: %v", err)
-				return
-			}
+		// delete spider directory
+		fsSvc, err := getSpiderFsSvcById(s.Id)
+		if err != nil {
+			logger.Errorf("failed to get spider fs service: %v", err)
+		} else {
 			err = fsSvc.Delete(".")
 			if err != nil {
 				logger.Errorf("failed to delete spider directory: %v", err)
-				return
 			}
-		}()
+		}
 	}
 	return GetDataResponse(models.Spider{})
@@ -323,34 +321,39 @@ func DeleteSpiderList(_ *gin.Context, params *DeleteSpiderListParams) (response
 		return GetErrorResponse[models.Spider](err)
 	}
-	// Delete spider directories
-	go func() {
-		wg := sync.WaitGroup{}
-		wg.Add(len(spiders))
-		for i := range spiders {
-			go func(s *models.Spider) {
-				defer wg.Done()
-				// Skip spider with git
-				if !s.GitId.IsZero() {
-					return
-				}
-				// Delete spider directory
-				fsSvc, err := getSpiderFsSvcById(s.Id)
-				if err != nil {
-					logger.Errorf("failed to get spider fs service: %v", err)
-					return
-				}
-				err = fsSvc.Delete(".")
-				if err != nil {
-					logger.Errorf("failed to delete spider directory: %v", err)
-					return
-				}
-			}(&spiders[i])
-		}
-		wg.Wait()
-	}()
+	// Delete spider directories synchronously to prevent goroutine leaks
+	wg := sync.WaitGroup{}
+	semaphore := make(chan struct{}, 5) // Limit concurrent operations
+	for i := range spiders {
+		// Skip spider with git
+		if !spiders[i].GitId.IsZero() {
+			continue
+		}
+		wg.Add(1)
+		semaphore <- struct{}{} // Acquire semaphore
+		func(s *models.Spider) {
+			defer func() {
+				<-semaphore // Release semaphore
+				wg.Done()
+			}()
+			// Delete spider directory
+			fsSvc, err := getSpiderFsSvcById(s.Id)
+			if err != nil {
+				logger.Errorf("failed to get spider fs service: %v", err)
+				return
+			}
+			err = fsSvc.Delete(".")
+			if err != nil {
+				logger.Errorf("failed to delete spider directory: %v", err)
+				return
+			}
+		}(&spiders[i])
+	}
+	wg.Wait()
 	return GetDataResponse(models.Spider{})
 }
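
Review note: as committed, the closure above is invoked without the go keyword, so deletions actually run one at a time and the semaphore never holds more than one slot. The sketch below shows the bounded-concurrency variant that the semaphore comment implies, with deleteSpiderDir as a hypothetical stand-in for the fsSvc.Delete(".") call:

package main

import (
	"fmt"
	"sync"
)

// deleteSpiderDir is a hypothetical stand-in for fsSvc.Delete(".").
func deleteSpiderDir(id int) error {
	fmt.Printf("deleting directory for spider %d\n", id)
	return nil
}

func main() {
	ids := []int{1, 2, 3, 4, 5, 6, 7}
	var wg sync.WaitGroup
	semaphore := make(chan struct{}, 5) // at most 5 deletions in flight
	for _, id := range ids {
		wg.Add(1)
		semaphore <- struct{}{} // acquire a slot; blocks while 5 are busy
		go func(id int) {
			defer func() {
				<-semaphore // release the slot
				wg.Done()
			}()
			if err := deleteSpiderDir(id); err != nil {
				fmt.Printf("delete failed for spider %d: %v\n", id, err)
			}
		}(id)
	}
	wg.Wait() // the caller returns only after every deletion finishes
}

Either way the handler blocks until all deletions finish, which is what prevents the leak; the go keyword only determines whether up to five of them overlap.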

View File

@@ -112,6 +112,9 @@ type GrpcClient struct {
 	healthClient       grpc_health_v1.HealthClient
 	healthCheckEnabled bool
 	healthCheckMux     sync.RWMutex
+
+	// Goroutine management
+	wg sync.WaitGroup
 }
 
 func (c *GrpcClient) Start() {
func (c *GrpcClient) Start() {
@@ -131,11 +134,18 @@ func (c *GrpcClient) Start() {
 			// Don't fatal here, let reconnection handle it
 		}
 
-		// start monitoring after connection attempt
-		go c.monitorState()
+		// start monitoring after connection attempt with proper tracking
+		c.wg.Add(2) // Track both monitoring goroutines
+		go func() {
+			defer c.wg.Done()
+			c.monitorState()
+		}()
 
 		// start health monitoring
-		go c.startHealthMonitor()
+		go func() {
+			defer c.wg.Done()
+			c.startHealthMonitor()
+		}()
 	})
 }
@@ -157,6 +167,21 @@ func (c *GrpcClient) Stop() error {
 	default:
 	}
 
+	// Wait for goroutines to finish
+	done := make(chan struct{})
+	go func() {
+		c.wg.Wait()
+		close(done)
+	}()
+
+	// Give goroutines time to finish gracefully, then force stop
+	select {
+	case <-done:
+		c.Debugf("all goroutines stopped gracefully")
+	case <-time.After(10 * time.Second):
+		c.Warnf("some goroutines did not stop gracefully within timeout")
+	}
+
 	// Close connection
 	if c.conn != nil {
 		if err := c.conn.Close(); err != nil {
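
Aside: Stop() uses the standard wait-with-timeout idiom. sync.WaitGroup.Wait has no deadline, so it runs in a helper goroutine that closes a channel, and that channel races a timer. A standalone sketch (names illustrative):

package main

import (
	"fmt"
	"sync"
	"time"
)

func main() {
	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		defer wg.Done()
		time.Sleep(100 * time.Millisecond) // simulated background work
	}()

	done := make(chan struct{})
	go func() {
		wg.Wait() // Wait has no timeout, so wrap it
		close(done)
	}()

	select {
	case <-done:
		fmt.Println("all goroutines stopped gracefully")
	case <-time.After(10 * time.Second):
		fmt.Println("timeout: some goroutines are still running")
	}
}

One caveat worth knowing: if the timeout fires, the helper goroutine itself lingers until Wait eventually returns. That is acceptable in Stop() because the process is shutting down anyway.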

View File

@@ -59,7 +59,10 @@ func (svr TaskServiceServer) Subscribe(req *grpc.TaskServiceSubscribeRequest, st
svr.Infof("task stream opened: %s", taskId.Hex())
// add stream
// Create a context based on client stream
ctx := stream.Context()
// add stream and track cancellation function
taskServiceMutex.Lock()
svr.subs[taskId] = stream
taskServiceMutex.Unlock()
@@ -72,22 +75,34 @@ func (svr TaskServiceServer) Subscribe(req *grpc.TaskServiceSubscribeRequest, st
svr.Infof("task stream closed: %s", taskId.Hex())
}()
// wait for stream to close with timeout protection
ctx := stream.Context()
// send periodic heartbeat to detect client disconnection and check for task completion
heartbeatTicker := time.NewTicker(10 * time.Second) // More frequent for faster completion detection
defer heartbeatTicker.Stop()
// Create a context with timeout to prevent indefinite hanging
timeoutCtx, cancel := context.WithTimeout(context.Background(), 24*time.Hour)
defer cancel()
for {
select {
case <-ctx.Done():
// Stream context cancelled normally (client disconnected or task finished)
svr.Debugf("task stream context done: %s", taskId.Hex())
return ctx.Err()
select {
case <-ctx.Done():
// Stream context cancelled normally
svr.Debugf("task stream context done: %s", taskId.Hex())
return ctx.Err()
case <-timeoutCtx.Done():
// Timeout reached - this prevents indefinite hanging
svr.Warnf("task stream timeout reached for task: %s", taskId.Hex())
return errors.New("stream timeout")
case <-heartbeatTicker.C:
// Check if task has finished and close stream if so
if svr.isTaskFinished(taskId) {
svr.Infof("task[%s] finished, closing stream", taskId.Hex())
return nil
}
// Check if the context is still valid
select {
case <-ctx.Done():
svr.Debugf("task stream context cancelled during heartbeat check: %s", taskId.Hex())
return ctx.Err()
default:
// Context is still valid, continue
svr.Debugf("task stream heartbeat check passed: %s", taskId.Hex())
}
}
}
}
@@ -471,6 +486,18 @@ func (svr *TaskServiceServer) Stop() error {
 	return nil
 }
 
+// isTaskFinished checks if a task has completed execution
+func (svr TaskServiceServer) isTaskFinished(taskId primitive.ObjectID) bool {
+	task, err := service.NewModelService[models.Task]().GetById(taskId)
+	if err != nil {
+		svr.Debugf("error checking task[%s] status: %v", taskId.Hex(), err)
+		return false
+	}
+	// Task is finished if it's not in pending or running state
+	return task.Status != constants.TaskStatusPending && task.Status != constants.TaskStatusRunning
+}
+
 var _taskServiceServer *TaskServiceServer
 var _taskServiceServerOnce sync.Once
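
Aside: the Subscribe loop above combines three exit paths: the client's stream context, an absolute deadline, and a ticker that polls task state. A standalone sketch of that loop shape, with isFinished standing in for the isTaskFinished lookup:

package main

import (
	"context"
	"errors"
	"fmt"
	"time"
)

func serveStream(streamCtx context.Context, isFinished func() bool) error {
	heartbeat := time.NewTicker(10 * time.Second)
	defer heartbeat.Stop()

	// Absolute upper bound so the handler can never hang forever.
	deadlineCtx, cancel := context.WithTimeout(context.Background(), 24*time.Hour)
	defer cancel()

	for {
		select {
		case <-streamCtx.Done(): // client disconnected
			return streamCtx.Err()
		case <-deadlineCtx.Done(): // hard timeout
			return errors.New("stream timeout")
		case <-heartbeat.C: // periodic completion check
			if isFinished() {
				return nil
			}
		}
	}
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	go func() { time.Sleep(50 * time.Millisecond); cancel() }()
	fmt.Println(serveStream(ctx, func() bool { return false }))
}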

View File

@@ -11,6 +11,7 @@ import (
"github.com/crawlab-team/crawlab/core/interfaces"
"github.com/crawlab-team/crawlab/grpc"
"go.mongodb.org/mongo-driver/bson/primitive"
grpc2 "google.golang.org/grpc"
)
// Service operations for task management
@@ -92,10 +93,8 @@ func (svc *Service) executeTask(taskId primitive.ObjectID) (err error) {
 	return err
 }
 
-// subscribeTask attempts to subscribe to task stream
-func (svc *Service) subscribeTask(taskId primitive.ObjectID) (stream grpc.TaskService_SubscribeClient, err error) {
-	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
-	defer cancel()
-
+// subscribeTaskWithContext attempts to subscribe to task stream with provided context
+func (svc *Service) subscribeTaskWithContext(ctx context.Context, taskId primitive.ObjectID) (stream grpc.TaskService_SubscribeClient, err error) {
 	req := &grpc.TaskServiceSubscribeRequest{
 		TaskId: taskId.Hex(),
 	}
@@ -103,7 +102,13 @@ func (svc *Service) subscribeTask(taskId primitive.ObjectID) (stream grpc.TaskSe
 	if err != nil {
 		return nil, fmt.Errorf("failed to get task client: %v", err)
 	}
-	stream, err = taskClient.Subscribe(ctx, req)
+
+	// Use call options to ensure proper cancellation behavior
+	opts := []grpc2.CallOption{
+		grpc2.WaitForReady(false), // Don't wait for connection if not ready
+	}
+	stream, err = taskClient.Subscribe(ctx, req, opts...)
 	if err != nil {
 		svc.Errorf("failed to subscribe task[%s]: %v", taskId.Hex(), err)
 		return nil, err
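
Note on the call option: WaitForReady(false) is in fact gRPC's default fail-fast behavior; if the connection is not ready, the RPC errors immediately (typically codes.Unavailable) instead of queueing until the transport recovers, so the option here mostly documents intent. A sketch of the call-site shape, assuming the generated taskClient and a caller-supplied ctx are in scope:

// ctx comes from the caller so that cancelling it tears down the stream.
opts := []grpc2.CallOption{
	grpc2.WaitForReady(false), // default: fail fast while the connection is down
	// grpc2.WaitForReady(true) would instead block until the channel is READY
	// or ctx expires, which could mask the cancellation behavior wanted here.
}
stream, err := taskClient.Subscribe(ctx, req, opts...)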

View File

@@ -73,24 +73,30 @@ func (sm *StreamManager) Stop() {
 }
 
 func (sm *StreamManager) AddTaskStream(taskId primitive.ObjectID) error {
-	// Check if we're at the stream limit
-	streamCount := 0
-	sm.streams.Range(func(key, value interface{}) bool {
-		streamCount++
-		return true
-	})
-
 	// Check if stream already exists
 	if _, exists := sm.streams.Load(taskId); exists {
 		sm.service.Debugf("stream already exists for task[%s], skipping", taskId.Hex())
 		return nil
 	}
 
+	// Check if we're at the stream limit
+	streamCount := sm.getStreamCount()
 	if streamCount >= sm.maxStreams {
 		sm.service.Warnf("stream limit reached (%d/%d), rejecting new stream for task[%s]",
 			streamCount, sm.maxStreams, taskId.Hex())
 		return fmt.Errorf("stream limit reached (%d)", sm.maxStreams)
 	}
 
-	// Create new stream
-	stream, err := sm.service.subscribeTask(taskId)
+	// Create a context for this specific stream that can be cancelled
+	ctx, cancel := context.WithCancel(sm.ctx)
+
+	// Create stream with the cancellable context
+	stream, err := sm.service.subscribeTaskWithContext(ctx, taskId)
 	if err != nil {
+		cancel() // Clean up the context if stream creation fails
 		return fmt.Errorf("failed to subscribe to task stream: %v", err)
 	}
 
-	ctx, cancel := context.WithCancel(sm.ctx)
 	taskStream := &TaskStream{
 		taskId: taskId,
 		stream: stream,
@@ -100,6 +106,8 @@ func (sm *StreamManager) AddTaskStream(taskId primitive.ObjectID) error {
 	}
 
 	sm.streams.Store(taskId, taskStream)
+	sm.service.Infof("created stream for task[%s], total streams: %d/%d",
+		taskId.Hex(), streamCount+1, sm.maxStreams)
 
 	// Start listening for messages in a single goroutine per stream
 	sm.wg.Add(1)
@@ -111,11 +119,21 @@ func (sm *StreamManager) AddTaskStream(taskId primitive.ObjectID) error {
 func (sm *StreamManager) RemoveTaskStream(taskId primitive.ObjectID) {
 	if value, ok := sm.streams.LoadAndDelete(taskId); ok {
 		if ts, ok := value.(*TaskStream); ok {
+			sm.service.Debugf("stream removed, total streams: %d/%d", sm.getStreamCount(), sm.maxStreams)
 			ts.Close()
 		}
 	}
 }
 
+func (sm *StreamManager) getStreamCount() int {
+	streamCount := 0
+	sm.streams.Range(func(key, value interface{}) bool {
+		streamCount++
+		return true
+	})
+	return streamCount
+}
+
 func (sm *StreamManager) streamListener(ts *TaskStream) {
 	defer sm.wg.Done()
 	defer func() {
func (sm *StreamManager) streamListener(ts *TaskStream) {
defer sm.wg.Done()
defer func() {
@@ -124,6 +142,7 @@ func (sm *StreamManager) streamListener(ts *TaskStream) {
 		}
 		ts.Close()
 		sm.streams.Delete(ts.taskId)
+		sm.service.Debugf("stream listener finished cleanup for task[%s]", ts.taskId.Hex())
 	}()
 
 	sm.service.Debugf("stream listener started for task[%s]", ts.taskId.Hex())
@@ -134,40 +153,67 @@ func (sm *StreamManager) streamListener(ts *TaskStream) {
sm.service.Debugf("stream listener stopped for task[%s]", ts.taskId.Hex())
return
case <-sm.ctx.Done():
sm.service.Debugf("stream manager shutdown, stopping listener for task[%s]", ts.taskId.Hex())
return
default:
msg, err := ts.stream.Recv()
// Use a timeout wrapper to handle cases where Recv() might hang
resultChan := make(chan struct {
msg *grpc.TaskServiceSubscribeResponse
err error
}, 1)
if err != nil {
if errors.Is(err, io.EOF) {
sm.service.Debugf("stream EOF for task[%s]", ts.taskId.Hex())
return
}
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return
}
sm.service.Errorf("stream error for task[%s]: %v", ts.taskId.Hex(), err)
return
}
// Start receive operation in a separate goroutine
go func() {
msg, err := ts.stream.Recv()
resultChan <- struct {
msg *grpc.TaskServiceSubscribeResponse
err error
}{msg, err}
}()
// Update last active time
ts.mu.Lock()
ts.lastActive = time.Now()
ts.mu.Unlock()
// Send message to processor
// Wait for result, timeout, or cancellation
select {
case sm.messageQueue <- &StreamMessage{
taskId: ts.taskId,
msg: msg,
err: nil,
}:
case result := <-resultChan:
if result.err != nil {
if errors.Is(result.err, io.EOF) {
sm.service.Debugf("stream EOF for task[%s] - server closed stream", ts.taskId.Hex())
return
}
if errors.Is(result.err, context.Canceled) || errors.Is(result.err, context.DeadlineExceeded) {
sm.service.Debugf("stream context cancelled for task[%s]", ts.taskId.Hex())
return
}
sm.service.Debugf("stream error for task[%s]: %v - likely server closed", ts.taskId.Hex(), result.err)
return
}
// Update last active time
ts.mu.Lock()
ts.lastActive = time.Now()
ts.mu.Unlock()
// Send message to processor (non-blocking)
select {
case sm.messageQueue <- &StreamMessage{
taskId: ts.taskId,
msg: result.msg,
err: nil,
}:
case <-ts.ctx.Done():
return
case <-sm.ctx.Done():
return
default:
sm.service.Warnf("message queue full, dropping message for task[%s]", ts.taskId.Hex())
}
case <-ts.ctx.Done():
sm.service.Debugf("stream listener stopped for task[%s]", ts.taskId.Hex())
return
case <-sm.ctx.Done():
sm.service.Debugf("stream manager shutdown, stopping listener for task[%s]", ts.taskId.Hex())
return
default:
sm.service.Warnf("message queue full, dropping message for task[%s]", ts.taskId.Hex())
}
}
}
@@ -250,8 +296,17 @@ func (sm *StreamManager) cleanupInactiveStreams() {
 }
 
 func (ts *TaskStream) Close() {
+	// Cancel the context first - this should interrupt any blocking operations
+	ts.cancel()
+
 	if ts.stream != nil {
-		_ = ts.stream.CloseSend()
+		// Try to close send direction
+		err := ts.stream.CloseSend()
+		if err != nil {
+			fmt.Printf("failed to close stream send for task[%s]: %v\n", ts.taskId.Hex(), err)
+		}
+
+		// Note: The stream.Recv() should now fail with context.Canceled
+		// due to the cancelled context passed to subscribeTaskWithContext
 	}
 }
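
Closing note: the streamListener change wraps the blocking Recv() in a goroutine so the listener can observe cancellation, and the channel is buffered (capacity 1) so the receiving goroutine can always deliver its result and exit even if no one is listening anymore. A standalone sketch, with recvMsg standing in for stream.Recv():

package main

import (
	"context"
	"fmt"
	"time"
)

type recvResult struct {
	msg string
	err error
}

// recvMsg is a placeholder for a blocking stream.Recv() call.
func recvMsg() (string, error) {
	time.Sleep(2 * time.Second)
	return "payload", nil
}

func interruptibleRecv(ctx context.Context) (string, error) {
	// Buffered so the goroutine can send its result and exit
	// even after the caller has stopped waiting.
	resultChan := make(chan recvResult, 1)
	go func() {
		msg, err := recvMsg()
		resultChan <- recvResult{msg, err}
	}()

	select {
	case r := <-resultChan:
		return r.msg, r.err
	case <-ctx.Done():
		return "", ctx.Err() // unblock the caller; the goroutine drains into the buffer
	}
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
	defer cancel()
	fmt.Println(interruptibleRecv(ctx))
}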