Mirror of https://github.com/crawlab-team/crawlab.git (synced 2026-01-21 17:21:09 +01:00)
refactor: improve goroutine management and context handling in task and stream operations; ensure graceful shutdown and prevent leaks
@@ -214,20 +214,18 @@ func DeleteSpiderById(_ *gin.Context, params *DeleteByIdParams) (response *Respo
return GetErrorResponse[models.Spider](err)
}

// Delete spider directory synchronously to prevent goroutine leaks
if !s.GitId.IsZero() {
go func() {
// delete spider directory
fsSvc, err := getSpiderFsSvcById(s.Id)
if err != nil {
logger.Errorf("failed to get spider fs service: %v", err)
return
}
// delete spider directory
fsSvc, err := getSpiderFsSvcById(s.Id)
if err != nil {
logger.Errorf("failed to get spider fs service: %v", err)
} else {
err = fsSvc.Delete(".")
if err != nil {
logger.Errorf("failed to delete spider directory: %v", err)
return
}
}()
}
}

return GetDataResponse(models.Spider{})
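
The change above drops the fire-and-forget goroutine and runs the cleanup inline, so DeleteSpiderById cannot leave work running after the handler returns. A minimal sketch of that pattern, assuming a stand-in fileService interface and the standard log package in place of the project's fs service and logger:

    package spiders

    import "log"

    // fileService stands in for the project's spider fs service.
    type fileService interface {
        Delete(path string) error
    }

    // deleteSpiderDir removes the spider's working directory synchronously.
    // Calling it inline (not via `go deleteSpiderDir(...)`) guarantees the
    // cleanup finishes or logs its error before the handler returns, so no
    // goroutine can outlive the request.
    func deleteSpiderDir(svc fileService) {
        if err := svc.Delete("."); err != nil {
            log.Printf("failed to delete spider directory: %v", err)
        }
    }
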
@@ -323,34 +321,39 @@ func DeleteSpiderList(_ *gin.Context, params *DeleteSpiderListParams) (response
return GetErrorResponse[models.Spider](err)
}

// Delete spider directories
go func() {
wg := sync.WaitGroup{}
wg.Add(len(spiders))
for i := range spiders {
go func(s *models.Spider) {
defer wg.Done()

// Skip spider with git
if !s.GitId.IsZero() {
return
}

// Delete spider directory
fsSvc, err := getSpiderFsSvcById(s.Id)
if err != nil {
logger.Errorf("failed to get spider fs service: %v", err)
return
}
err = fsSvc.Delete(".")
if err != nil {
logger.Errorf("failed to delete spider directory: %v", err)
return
}
}(&spiders[i])
// Delete spider directories synchronously to prevent goroutine leaks
wg := sync.WaitGroup{}
semaphore := make(chan struct{}, 5) // Limit concurrent operations

for i := range spiders {
// Skip spider with git
if !spiders[i].GitId.IsZero() {
continue
}
wg.Wait()
}()

wg.Add(1)
semaphore <- struct{}{} // Acquire semaphore

func(s *models.Spider) {
defer func() {
<-semaphore // Release semaphore
wg.Done()
}()

// Delete spider directory
fsSvc, err := getSpiderFsSvcById(s.Id)
if err != nil {
logger.Errorf("failed to get spider fs service: %v", err)
return
}
err = fsSvc.Delete(".")
if err != nil {
logger.Errorf("failed to delete spider directory: %v", err)
return
}
}(&spiders[i])
}
wg.Wait()

return GetDataResponse(models.Spider{})
}
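
The list handler keeps deletions concurrent but bounds them: a buffered channel acts as a counting semaphore and a WaitGroup makes the handler block until every deletion finishes. A self-contained sketch of the same pattern, with a placeholder deleteDir standing in for the per-spider fs call; each deletion runs in its own goroutine, which is how the semaphore bound applies:

    package main

    import (
        "fmt"
        "sync"
        "time"
    )

    // deleteDir is a placeholder for the per-spider directory deletion.
    func deleteDir(id int) error {
        time.Sleep(100 * time.Millisecond) // simulate filesystem I/O
        return nil
    }

    func deleteAll(ids []int) {
        var wg sync.WaitGroup
        semaphore := make(chan struct{}, 5) // at most 5 deletions in flight

        for _, id := range ids {
            wg.Add(1)
            semaphore <- struct{}{} // acquire a slot; blocks when 5 are busy
            go func(id int) {
                defer func() {
                    <-semaphore // release the slot
                    wg.Done()
                }()
                if err := deleteDir(id); err != nil {
                    fmt.Printf("delete %d failed: %v\n", id, err)
                }
            }(id)
        }

        wg.Wait() // return only after every deletion has finished
    }

    func main() {
        deleteAll([]int{1, 2, 3, 4, 5, 6, 7, 8})
    }
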
@@ -112,6 +112,9 @@ type GrpcClient struct {
healthClient grpc_health_v1.HealthClient
healthCheckEnabled bool
healthCheckMux sync.RWMutex

// Goroutine management
wg sync.WaitGroup
}

func (c *GrpcClient) Start() {
@@ -131,11 +134,18 @@ func (c *GrpcClient) Start() {
// Don't fatal here, let reconnection handle it
}

// start monitoring after connection attempt
go c.monitorState()
// start monitoring after connection attempt with proper tracking
c.wg.Add(2) // Track both monitoring goroutines
go func() {
defer c.wg.Done()
c.monitorState()
}()

// start health monitoring
go c.startHealthMonitor()
go func() {
defer c.wg.Done()
c.startHealthMonitor()
}()
})
}

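
Start() now registers each background monitor with the client's WaitGroup before launching it, so a later Stop() has something concrete to wait on. A sketch of the launch side, under the assumption of a minimal client struct with placeholder monitor loops:

    package grpcclient

    import (
        "context"
        "sync"
    )

    // client is a minimal stand-in for the gRPC client with tracked goroutines.
    type client struct {
        ctx    context.Context
        cancel context.CancelFunc
        wg     sync.WaitGroup
    }

    func (c *client) start() {
        c.wg.Add(2) // count both monitors before they start
        go func() {
            defer c.wg.Done()
            c.monitorState()
        }()
        go func() {
            defer c.wg.Done()
            c.healthMonitor()
        }()
    }

    // Placeholder loops: each exits when the client context is cancelled.
    func (c *client) monitorState()  { <-c.ctx.Done() }
    func (c *client) healthMonitor() { <-c.ctx.Done() }
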
@@ -157,6 +167,21 @@ func (c *GrpcClient) Stop() error {
default:
}

// Wait for goroutines to finish
done := make(chan struct{})
go func() {
c.wg.Wait()
close(done)
}()

// Give goroutines time to finish gracefully, then force stop
select {
case <-done:
c.Debugf("all goroutines stopped gracefully")
case <-time.After(10 * time.Second):
c.Warnf("some goroutines did not stop gracefully within timeout")
}

// Close connection
if c.conn != nil {
if err := c.conn.Close(); err != nil {

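
The shutdown side, continuing the client sketch above (add "time" to its imports): cancel the shared context, then wait for the tracked goroutines, but never longer than a fixed grace period, so a stuck goroutine cannot block Stop() forever:

    // stop cancels the shared context, waits for the tracked goroutines,
    // and gives up after a 10-second grace period.
    func (c *client) stop() {
        c.cancel() // ask both monitors to exit

        done := make(chan struct{})
        go func() {
            c.wg.Wait() // returns once every tracked goroutine called Done
            close(done)
        }()

        select {
        case <-done:
            // all goroutines exited gracefully
        case <-time.After(10 * time.Second):
            // grace period exceeded; proceed with closing the connection anyway
        }
    }
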
@@ -59,7 +59,10 @@ func (svr TaskServiceServer) Subscribe(req *grpc.TaskServiceSubscribeRequest, st

svr.Infof("task stream opened: %s", taskId.Hex())

// add stream
// Create a context based on client stream
ctx := stream.Context()

// add stream and track cancellation function
taskServiceMutex.Lock()
svr.subs[taskId] = stream
taskServiceMutex.Unlock()
@@ -72,22 +75,34 @@ func (svr TaskServiceServer) Subscribe(req *grpc.TaskServiceSubscribeRequest, st
svr.Infof("task stream closed: %s", taskId.Hex())
}()

// wait for stream to close with timeout protection
ctx := stream.Context()
// send periodic heartbeat to detect client disconnection and check for task completion
heartbeatTicker := time.NewTicker(10 * time.Second) // More frequent for faster completion detection
defer heartbeatTicker.Stop()

// Create a context with timeout to prevent indefinite hanging
timeoutCtx, cancel := context.WithTimeout(context.Background(), 24*time.Hour)
defer cancel()
for {
select {
case <-ctx.Done():
// Stream context cancelled normally (client disconnected or task finished)
svr.Debugf("task stream context done: %s", taskId.Hex())
return ctx.Err()

select {
case <-ctx.Done():
// Stream context cancelled normally
svr.Debugf("task stream context done: %s", taskId.Hex())
return ctx.Err()
case <-timeoutCtx.Done():
// Timeout reached - this prevents indefinite hanging
svr.Warnf("task stream timeout reached for task: %s", taskId.Hex())
return errors.New("stream timeout")
case <-heartbeatTicker.C:
// Check if task has finished and close stream if so
if svr.isTaskFinished(taskId) {
svr.Infof("task[%s] finished, closing stream", taskId.Hex())
return nil
}

// Check if the context is still valid
select {
case <-ctx.Done():
svr.Debugf("task stream context cancelled during heartbeat check: %s", taskId.Hex())
return ctx.Err()
default:
// Context is still valid, continue
svr.Debugf("task stream heartbeat check passed: %s", taskId.Hex())
}
}
}
}

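
The server now drives each subscription from a single select loop that exits on three conditions: the client's stream context is cancelled, an overall 24-hour safety timeout fires, or a periodic ticker finds the task finished. A self-contained sketch of that loop shape, with isFinished as a placeholder for the task-status lookup:

    package taskstream

    import (
        "context"
        "errors"
        "log"
        "time"
    )

    // serveStream blocks until the client disconnects (ctx), the upper-bound
    // timeout expires, or the periodic check reports the task as finished.
    func serveStream(ctx context.Context, isFinished func() bool) error {
        heartbeat := time.NewTicker(10 * time.Second)
        defer heartbeat.Stop()

        // Hard upper bound so a stream can never hang indefinitely.
        timeoutCtx, cancel := context.WithTimeout(context.Background(), 24*time.Hour)
        defer cancel()

        for {
            select {
            case <-ctx.Done():
                // client disconnected or the RPC was cancelled
                return ctx.Err()
            case <-timeoutCtx.Done():
                return errors.New("stream timeout")
            case <-heartbeat.C:
                if isFinished() {
                    log.Println("task finished, closing stream")
                    return nil
                }
            }
        }
    }
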
@@ -471,6 +486,18 @@ func (svr *TaskServiceServer) Stop() error {
return nil
}

// isTaskFinished checks if a task has completed execution
func (svr TaskServiceServer) isTaskFinished(taskId primitive.ObjectID) bool {
task, err := service.NewModelService[models.Task]().GetById(taskId)
if err != nil {
svr.Debugf("error checking task[%s] status: %v", taskId.Hex(), err)
return false
}

// Task is finished if it's not in pending or running state
return task.Status != constants.TaskStatusPending && task.Status != constants.TaskStatusRunning
}

var _taskServiceServer *TaskServiceServer
var _taskServiceServerOnce sync.Once

@@ -11,6 +11,7 @@ import (
"github.com/crawlab-team/crawlab/core/interfaces"
"github.com/crawlab-team/crawlab/grpc"
"go.mongodb.org/mongo-driver/bson/primitive"
grpc2 "google.golang.org/grpc"
)

// Service operations for task management
@@ -92,10 +93,8 @@ func (svc *Service) executeTask(taskId primitive.ObjectID) (err error) {
return err
}

// subscribeTask attempts to subscribe to task stream
func (svc *Service) subscribeTask(taskId primitive.ObjectID) (stream grpc.TaskService_SubscribeClient, err error) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
// subscribeTaskWithContext attempts to subscribe to task stream with provided context
func (svc *Service) subscribeTaskWithContext(ctx context.Context, taskId primitive.ObjectID) (stream grpc.TaskService_SubscribeClient, err error) {
req := &grpc.TaskServiceSubscribeRequest{
TaskId: taskId.Hex(),
}
@@ -103,7 +102,13 @@ func (svc *Service) subscribeTask(taskId primitive.ObjectID) (stream grpc.TaskSe
if err != nil {
return nil, fmt.Errorf("failed to get task client: %v", err)
}
stream, err = taskClient.Subscribe(ctx, req)

// Use call options to ensure proper cancellation behavior
opts := []grpc2.CallOption{
grpc2.WaitForReady(false), // Don't wait for connection if not ready
}

stream, err = taskClient.Subscribe(ctx, req, opts...)
if err != nil {
svc.Errorf("failed to subscribe task[%s]: %v", taskId.Hex(), err)
return nil, err

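
The subscription now uses the caller's context instead of a short-lived internal one, and passes grpc.WaitForReady(false) so the call fails fast when the connection is not ready instead of queueing. A sketch of that shape, with the subscriber interface standing in for the generated task-service client:

    package taskclient

    import (
        "context"
        "fmt"

        "google.golang.org/grpc"
    )

    // subscriber stands in for the generated client's Subscribe method.
    type subscriber interface {
        Subscribe(ctx context.Context, taskID string, opts ...grpc.CallOption) (grpc.ClientStream, error)
    }

    // subscribeWithContext opens the stream with the caller's context, so the
    // caller (not this helper) decides when the stream is cancelled.
    func subscribeWithContext(ctx context.Context, c subscriber, taskID string) (grpc.ClientStream, error) {
        opts := []grpc.CallOption{
            grpc.WaitForReady(false), // fail fast if the connection is not ready
        }
        stream, err := c.Subscribe(ctx, taskID, opts...)
        if err != nil {
            return nil, fmt.Errorf("failed to subscribe task[%s]: %w", taskID, err)
        }
        return stream, nil
    }
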
@@ -73,24 +73,30 @@ func (sm *StreamManager) Stop() {
}

func (sm *StreamManager) AddTaskStream(taskId primitive.ObjectID) error {
// Check if we're at the stream limit
streamCount := 0
sm.streams.Range(func(key, value interface{}) bool {
streamCount++
return true
})
// Check if stream already exists
if _, exists := sm.streams.Load(taskId); exists {
sm.service.Debugf("stream already exists for task[%s], skipping", taskId.Hex())
return nil
}

// Check if we're at the stream limit
streamCount := sm.getStreamCount()
if streamCount >= sm.maxStreams {
sm.service.Warnf("stream limit reached (%d/%d), rejecting new stream for task[%s]",
streamCount, sm.maxStreams, taskId.Hex())
return fmt.Errorf("stream limit reached (%d)", sm.maxStreams)
}

// Create new stream
stream, err := sm.service.subscribeTask(taskId)
// Create a context for this specific stream that can be cancelled
ctx, cancel := context.WithCancel(sm.ctx)

// Create stream with the cancellable context
stream, err := sm.service.subscribeTaskWithContext(ctx, taskId)
if err != nil {
cancel() // Clean up the context if stream creation fails
return fmt.Errorf("failed to subscribe to task stream: %v", err)
}

ctx, cancel := context.WithCancel(sm.ctx)
taskStream := &TaskStream{
taskId: taskId,
stream: stream,
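
Each stream now owns a context derived from the manager's context, created before the subscribe call so the cancel function can either be stored on the TaskStream or invoked immediately when the subscription fails. A sketch of that ownership pattern, with subscribe as a placeholder for the real gRPC call:

    package streammgr

    import (
        "context"
        "fmt"
    )

    type taskStream struct {
        ctx    context.Context
        cancel context.CancelFunc
    }

    // openStream derives a per-stream context from the manager context, so
    // cancelling either the whole manager or just this stream unblocks it.
    func openStream(managerCtx context.Context, subscribe func(context.Context) error) (*taskStream, error) {
        ctx, cancel := context.WithCancel(managerCtx)

        if err := subscribe(ctx); err != nil {
            cancel() // release the context if the stream never opened
            return nil, fmt.Errorf("failed to subscribe to task stream: %w", err)
        }

        return &taskStream{ctx: ctx, cancel: cancel}, nil
    }
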
@@ -100,6 +106,8 @@ func (sm *StreamManager) AddTaskStream(taskId primitive.ObjectID) error {
}

sm.streams.Store(taskId, taskStream)
sm.service.Infof("created stream for task[%s], total streams: %d/%d",
taskId.Hex(), streamCount+1, sm.maxStreams)

// Start listening for messages in a single goroutine per stream
sm.wg.Add(1)
@@ -111,11 +119,21 @@ func (sm *StreamManager) AddTaskStream(taskId primitive.ObjectID) error {
func (sm *StreamManager) RemoveTaskStream(taskId primitive.ObjectID) {
if value, ok := sm.streams.LoadAndDelete(taskId); ok {
if ts, ok := value.(*TaskStream); ok {
sm.service.Debugf("stream removed, total streams: %d/%d", sm.getStreamCount(), sm.maxStreams)
ts.Close()
}
}
}

func (sm *StreamManager) getStreamCount() int {
streamCount := 0
sm.streams.Range(func(key, value interface{}) bool {
streamCount++
return true
})
return streamCount
}

func (sm *StreamManager) streamListener(ts *TaskStream) {
defer sm.wg.Done()
defer func() {
@@ -124,6 +142,7 @@ func (sm *StreamManager) streamListener(ts *TaskStream) {
}
ts.Close()
sm.streams.Delete(ts.taskId)
sm.service.Debugf("stream listener finished cleanup for task[%s]", ts.taskId.Hex())
}()

sm.service.Debugf("stream listener started for task[%s]", ts.taskId.Hex())
@@ -134,40 +153,67 @@ func (sm *StreamManager) streamListener(ts *TaskStream) {
sm.service.Debugf("stream listener stopped for task[%s]", ts.taskId.Hex())
return
case <-sm.ctx.Done():
sm.service.Debugf("stream manager shutdown, stopping listener for task[%s]", ts.taskId.Hex())
return
default:
msg, err := ts.stream.Recv()
// Use a timeout wrapper to handle cases where Recv() might hang
resultChan := make(chan struct {
msg *grpc.TaskServiceSubscribeResponse
err error
}, 1)

if err != nil {
if errors.Is(err, io.EOF) {
sm.service.Debugf("stream EOF for task[%s]", ts.taskId.Hex())
return
}
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return
}
sm.service.Errorf("stream error for task[%s]: %v", ts.taskId.Hex(), err)
return
}
// Start receive operation in a separate goroutine
go func() {
msg, err := ts.stream.Recv()
resultChan <- struct {
msg *grpc.TaskServiceSubscribeResponse
err error
}{msg, err}
}()

// Update last active time
ts.mu.Lock()
ts.lastActive = time.Now()
ts.mu.Unlock()

// Send message to processor
// Wait for result, timeout, or cancellation
select {
case sm.messageQueue <- &StreamMessage{
taskId: ts.taskId,
msg: msg,
err: nil,
}:
case result := <-resultChan:
if result.err != nil {
if errors.Is(result.err, io.EOF) {
sm.service.Debugf("stream EOF for task[%s] - server closed stream", ts.taskId.Hex())
return
}
if errors.Is(result.err, context.Canceled) || errors.Is(result.err, context.DeadlineExceeded) {
sm.service.Debugf("stream context cancelled for task[%s]", ts.taskId.Hex())
return
}
sm.service.Debugf("stream error for task[%s]: %v - likely server closed", ts.taskId.Hex(), result.err)
return
}

// Update last active time
ts.mu.Lock()
ts.lastActive = time.Now()
ts.mu.Unlock()

// Send message to processor (non-blocking)
select {
case sm.messageQueue <- &StreamMessage{
taskId: ts.taskId,
msg: result.msg,
err: nil,
}:
case <-ts.ctx.Done():
return
case <-sm.ctx.Done():
return
default:
sm.service.Warnf("message queue full, dropping message for task[%s]", ts.taskId.Hex())
}

case <-ts.ctx.Done():
sm.service.Debugf("stream listener stopped for task[%s]", ts.taskId.Hex())
return

case <-sm.ctx.Done():
sm.service.Debugf("stream manager shutdown, stopping listener for task[%s]", ts.taskId.Hex())
return
default:
sm.service.Warnf("message queue full, dropping message for task[%s]", ts.taskId.Hex())
}
}
}
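
The listener no longer calls Recv() directly in the loop, where it could block with no way to observe shutdown; each receive runs in a helper goroutine and the loop selects over the result channel and both cancellation contexts. A self-contained sketch of the wrapper, with a generic recv function in place of the gRPC stream:

    package streammgr

    import "context"

    // recvResult carries one receive outcome across the goroutine boundary.
    type recvResult struct {
        msg string
        err error
    }

    // listen wraps a blocking recv so that cancelling either the stream
    // context or the manager context stops the loop even while recv waits.
    func listen(streamCtx, managerCtx context.Context, recv func() (string, error), handle func(string)) {
        for {
            resultChan := make(chan recvResult, 1) // buffered: helper never blocks on send

            go func() {
                msg, err := recv()
                resultChan <- recvResult{msg: msg, err: err}
            }()

            select {
            case res := <-resultChan:
                if res.err != nil {
                    return // EOF, cancellation, or a real error: stop listening
                }
                handle(res.msg)
            case <-streamCtx.Done():
                return
            case <-managerCtx.Done():
                return
            }
        }
    }

Because the result channel is buffered, the helper goroutine can always deliver its result and exit; the blocked Recv() itself is released by the cancelled stream context, which is what the Close() change below relies on.
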
@@ -250,8 +296,17 @@ func (sm *StreamManager) cleanupInactiveStreams() {
}

func (ts *TaskStream) Close() {
// Cancel the context first - this should interrupt any blocking operations
ts.cancel()

if ts.stream != nil {
_ = ts.stream.CloseSend()
// Try to close send direction
err := ts.stream.CloseSend()
if err != nil {
fmt.Printf("failed to close stream send for task[%s]: %v\n", ts.taskId.Hex(), err)
}

// Note: The stream.Recv() should now fail with context.Canceled
// due to the cancelled context passed to subscribeTaskWithContext
}
}

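
Close() cancels the per-stream context before closing the send direction, and it is that cancellation which actually releases a Recv() still waiting on the server. A short sketch continuing the per-stream-context example above (same package, fmt already imported), with closer standing in for the gRPC client stream:

    // closer matches the part of a gRPC client stream used here.
    type closer interface {
        CloseSend() error
    }

    // closeStream cancels the stream's context first (unblocking any pending
    // receive), then closes the send direction of the stream.
    func (ts *taskStream) closeStream(stream closer) {
        ts.cancel()

        if stream != nil {
            if err := stream.CloseSend(); err != nil {
                fmt.Printf("failed to close stream send: %v\n", err)
            }
        }
    }
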