From 6912b9250147ed9224d6f7b040b3d2802613c946 Mon Sep 17 00:00:00 2001 From: Marvin Zhang Date: Thu, 7 Aug 2025 15:40:48 +0800 Subject: [PATCH] refactor: enhance context handling across task runner and service components; ensure proper cancellation chains and prevent goroutine leaks --- core/task/handler/runner.go | 18 ++++--- core/task/handler/runner_config.go | 67 ++++++++++++++++++++----- core/task/handler/runner_ipc.go | 4 +- core/task/handler/service.go | 7 ++- core/task/handler/service_operations.go | 6 +-- core/task/handler/stream_manager.go | 3 +- core/task/handler/worker_pool.go | 3 +- core/task/scheduler/service.go | 4 +- 8 files changed, 83 insertions(+), 29 deletions(-) diff --git a/core/task/handler/runner.go b/core/task/handler/runner.go index 2c57eab5..a741db34 100644 --- a/core/task/handler/runner.go +++ b/core/task/handler/runner.go @@ -72,8 +72,8 @@ func newTaskRunner(id primitive.ObjectID, svc *Service) (r *Runner, err error) { } } - // Initialize context and done channel - r.ctx, r.cancel = context.WithCancel(context.Background()) + // Initialize context and done channel - use service context for proper cancellation chain + r.ctx, r.cancel = context.WithCancel(svc.ctx) r.done = make(chan struct{}) // initialize task runner @@ -297,7 +297,7 @@ func (r *Runner) Cancel(force bool) (err error) { force = true } else { // Wait for graceful termination with shorter timeout - ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + ctx, cancel := context.WithTimeout(r.ctx, 15*time.Second) defer cancel() ticker := time.NewTicker(500 * time.Millisecond) @@ -329,7 +329,7 @@ forceKill: } // Wait for process to be killed with timeout - ctx, cancel := context.WithTimeout(context.Background(), r.svc.GetCancelTimeout()) + ctx, cancel := context.WithTimeout(r.ctx, r.svc.GetCancelTimeout()) defer cancel() ticker := time.NewTicker(100 * time.Millisecond) @@ -731,11 +731,17 @@ func (r *Runner) sendNotification() { r.Errorf("failed to get task client: %v", err) return } - ctx, cancel := context.WithTimeout(r.ctx, 10*time.Second) + + // Use independent context for async notification - prevents cancellation due to task lifecycle + // This ensures notifications are sent even if the task runner is being cleaned up + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() + _, err = taskClient.SendNotification(ctx, req) if err != nil { - r.Errorf("error sending notification: %v", err) + if !errors.Is(ctx.Err(), context.DeadlineExceeded) { + r.Errorf("error sending notification: %v", err) + } return } } diff --git a/core/task/handler/runner_config.go b/core/task/handler/runner_config.go index 627c65af..6d887b47 100644 --- a/core/task/handler/runner_config.go +++ b/core/task/handler/runner_config.go @@ -23,27 +23,40 @@ func (r *Runner) configurePythonPath() { pyenvBinPath := pyenvRoot + "/bin" // Configure global pyenv path - _ = os.Setenv("PYENV_ROOT", pyenvRoot) - _ = os.Setenv("PATH", pyenvShimsPath+":"+os.Getenv("PATH")) - _ = os.Setenv("PATH", pyenvBinPath+":"+os.Getenv("PATH")) + r.cmd.Env = append(r.cmd.Env, "PYENV_ROOT="+pyenvRoot) + + // Update PATH with pyenv paths + currentPath := r.getEnvFromCmd("PATH") + if currentPath == "" { + currentPath = os.Getenv("PATH") + } + newPath := pyenvBinPath + ":" + pyenvShimsPath + ":" + currentPath + r.setEnvInCmd("PATH", newPath) } // configureNodePath sets up the Node.js environment paths, handling both nvm and default installations func (r *Runner) configureNodePath() { // Configure nvm-based Node.js paths - envPath := os.Getenv("PATH") + currentPath := r.getEnvFromCmd("PATH") + if currentPath == "" { + currentPath = os.Getenv("PATH") + } // Configure global node_modules path nodePath := utils.GetNodeModulesPath() - if !strings.Contains(envPath, nodePath) { - _ = os.Setenv("PATH", nodePath+":"+envPath) + if !strings.Contains(currentPath, nodePath) { + currentPath = nodePath + ":" + currentPath + r.setEnvInCmd("PATH", currentPath) } - _ = os.Setenv("NODE_PATH", nodePath) + r.cmd.Env = append(r.cmd.Env, "NODE_PATH="+nodePath) // Configure global node_bin path nodeBinPath := utils.GetNodeBinPath() - if !strings.Contains(envPath, nodeBinPath) { - _ = os.Setenv("PATH", nodeBinPath+":"+os.Getenv("PATH")) + // Get the updated PATH after the node_modules path was added + updatedPath := r.getEnvFromCmd("PATH") + if !strings.Contains(updatedPath, nodeBinPath) { + newPath := nodeBinPath + ":" + updatedPath + r.setEnvInCmd("PATH", newPath) } } @@ -51,7 +64,7 @@ func (r *Runner) configureGoPath() { // Configure global go path goPath := utils.GetGoPath() if goPath != "" { - _ = os.Setenv("GOPATH", goPath) + r.cmd.Env = append(r.cmd.Env, "GOPATH="+goPath) } } @@ -60,6 +73,9 @@ func (r *Runner) configureGoPath() { // - Crawlab-specific variables // - Global environment variables from the system func (r *Runner) configureEnv() { + // Default envs - initialize first so configuration functions can modify them + r.cmd.Env = os.Environ() + // Configure Python path r.configurePythonPath() @@ -69,9 +85,6 @@ func (r *Runner) configureEnv() { // Configure Go path r.configureGoPath() - // Default envs - r.cmd.Env = os.Environ() - // Remove CRAWLAB_ prefixed environment variables for i := 0; i < len(r.cmd.Env); i++ { env := r.cmd.Env[i] @@ -177,3 +190,31 @@ func (r *Runner) configureCmd() (err error) { return nil } + +// getEnvFromCmd retrieves an environment variable value from r.cmd.Env +func (r *Runner) getEnvFromCmd(key string) string { + prefix := key + "=" + for _, env := range r.cmd.Env { + if after, ok := strings.CutPrefix(env, prefix); ok { + return after + } + } + return "" +} + +// setEnvInCmd sets or updates an environment variable in r.cmd.Env +func (r *Runner) setEnvInCmd(key, value string) { + envVar := key + "=" + value + prefix := key + "=" + + // Check if the environment variable already exists and update it + for i, env := range r.cmd.Env { + if strings.HasPrefix(env, prefix) { + r.cmd.Env[i] = envVar + return + } + } + + // If not found, append it + r.cmd.Env = append(r.cmd.Env, envVar) +} diff --git a/core/task/handler/runner_ipc.go b/core/task/handler/runner_ipc.go index 667c2d95..56908ea6 100644 --- a/core/task/handler/runner_ipc.go +++ b/core/task/handler/runner_ipc.go @@ -145,8 +145,8 @@ func (r *Runner) handleIPCInsertDataMessage(ipcMsg entity.IPCMessage) { return } - // Send IPC message to master with context and timeout - ctx, cancel := context.WithTimeout(context.Background(), r.ipcTimeout) + // Send IPC message to master with context and timeout - use runner's context + ctx, cancel := context.WithTimeout(r.ctx, r.ipcTimeout) defer cancel() // Create gRPC message diff --git a/core/task/handler/service.go b/core/task/handler/service.go index 6f1abf33..33f3494a 100644 --- a/core/task/handler/service.go +++ b/core/task/handler/service.go @@ -376,6 +376,10 @@ func (svc *Service) updateNodeStatus() (err error) { // set available runners n.CurrentRunners = svc.getRunnerCount() + // Log goroutine count for leak monitoring + currentGoroutines := runtime.NumGoroutine() + svc.Debugf("Node status update - runners: %d, goroutines: %d", n.CurrentRunners, currentGoroutines) + // save node n.SetUpdated(n.CreatedBy) if svc.cfgSvc.IsMaster() { @@ -391,7 +395,8 @@ func (svc *Service) updateNodeStatus() (err error) { } func (svc *Service) fetchTask() (tid primitive.ObjectID, err error) { - ctx, cancel := context.WithTimeout(context.Background(), svc.fetchTimeout) + // Use service context with timeout for fetch operation + ctx, cancel := context.WithTimeout(svc.ctx, svc.fetchTimeout) defer cancel() taskClient, err := svc.c.GetTaskClient() if err != nil { diff --git a/core/task/handler/service_operations.go b/core/task/handler/service_operations.go index 6e9f7468..f9e44089 100644 --- a/core/task/handler/service_operations.go +++ b/core/task/handler/service_operations.go @@ -58,7 +58,7 @@ func (svc *Service) executeTask(taskId primitive.ObjectID) (err error) { // add runner to pool svc.addRunner(taskId, r) - // Ensure cleanup always happens + // Ensure cleanup always happens - CRITICAL for preventing goroutine leaks defer func() { if rec := recover(); rec != nil { svc.Errorf("task[%s] panic recovered: %v", taskId.Hex(), rec) @@ -163,8 +163,8 @@ func (svc *Service) cancelTask(taskId primitive.ObjectID, force bool) (err error return nil } - // Attempt cancellation with timeout - cancelCtx, cancelFunc := context.WithTimeout(context.Background(), 30*time.Second) + // Attempt cancellation with timeout - use service context + cancelCtx, cancelFunc := context.WithTimeout(svc.ctx, 30*time.Second) defer cancelFunc() cancelDone := make(chan error, 1) diff --git a/core/task/handler/stream_manager.go b/core/task/handler/stream_manager.go index dbc1169d..d0e8e73a 100644 --- a/core/task/handler/stream_manager.go +++ b/core/task/handler/stream_manager.go @@ -41,7 +41,8 @@ type StreamMessage struct { } func NewStreamManager(service *Service) *StreamManager { - ctx, cancel := context.WithCancel(context.Background()) + // Use service context for proper cancellation chain + ctx, cancel := context.WithCancel(service.ctx) return &StreamManager{ ctx: ctx, cancel: cancel, diff --git a/core/task/handler/worker_pool.go b/core/task/handler/worker_pool.go index 84aa7ebd..1bbb26dc 100644 --- a/core/task/handler/worker_pool.go +++ b/core/task/handler/worker_pool.go @@ -24,7 +24,8 @@ type TaskWorkerPool struct { } func NewTaskWorkerPool(workers int, service *Service) *TaskWorkerPool { - ctx, cancel := context.WithCancel(context.Background()) + // Use service context for proper cancellation chain + ctx, cancel := context.WithCancel(service.ctx) // Use a more generous queue size to handle task bursts // Queue size is workers * 5 to allow for better buffering queueSize := workers * 5 diff --git a/core/task/scheduler/service.go b/core/task/scheduler/service.go index 6b42da37..192fc837 100644 --- a/core/task/scheduler/service.go +++ b/core/task/scheduler/service.go @@ -172,8 +172,8 @@ func (svc *Service) cancelOnWorker(t *models.Task, by primitive.ObjectID, force return svc.SaveTask(t, by) } - // send cancel request with timeout - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + // send cancel request with timeout - use service context + ctx, cancel := context.WithTimeout(svc.ctx, 30*time.Second) defer cancel() // Create a channel to handle the send operation