From 8357dc6f301dfedac4e1eaee4c2003027a95ca37 Mon Sep 17 00:00:00 2001 From: Marvin Zhang Date: Fri, 22 Nov 2024 13:32:27 +0800 Subject: [PATCH] fix: runner cancel issue --- core/sys_exec/sys_exec.go | 17 ++++------ core/task/handler/runner.go | 26 +++++++++------ core/task/handler/runner_test.go | 57 ++++++++++++++++++++++++++------ core/utils/process.go | 44 ++++++++++++++++++++++-- 4 files changed, 110 insertions(+), 34 deletions(-) diff --git a/core/sys_exec/sys_exec.go b/core/sys_exec/sys_exec.go index fb22f32d..3a320759 100644 --- a/core/sys_exec/sys_exec.go +++ b/core/sys_exec/sys_exec.go @@ -1,7 +1,7 @@ package sys_exec import ( - "github.com/crawlab-team/crawlab/trace" + "github.com/apex/log" "github.com/shirou/gopsutil/process" "os/exec" ) @@ -17,19 +17,15 @@ func KillProcess(cmd *exec.Cmd, opts *KillProcessOptions) error { return err } - // kill function - killFunc := func(p *process.Process) error { - return killProcessRecursive(p, opts.Force) - } - - // without timeout - return killFunc(p) + // kill process + return killProcessRecursive(p, opts.Force) } func killProcessRecursive(p *process.Process, force bool) (err error) { // children processes cps, err := p.Children() if err != nil { + log.Errorf("failed to get children processes: %v", err) return killProcess(p, force) } @@ -40,7 +36,7 @@ func killProcessRecursive(p *process.Process, force bool) (err error) { } } - return nil + return killProcess(p, force) } func killProcess(p *process.Process, force bool) (err error) { @@ -50,7 +46,8 @@ func killProcess(p *process.Process, force bool) (err error) { err = p.Terminate() } if err != nil { - return trace.TraceError(err) + log.Errorf("failed to kill process (force: %v): %v", force, err) + return err } return nil } diff --git a/core/task/handler/runner.go b/core/task/handler/runner.go index 51f1a397..9937a874 100644 --- a/core/task/handler/runner.go +++ b/core/task/handler/runner.go @@ -6,8 +6,6 @@ import ( "encoding/json" "errors" "fmt" - "github.com/crawlab-team/crawlab/core/fs" - "github.com/hashicorp/go-multierror" "io" "net/http" "os" @@ -17,6 +15,9 @@ import ( "sync" "time" + "github.com/crawlab-team/crawlab/core/fs" + "github.com/hashicorp/go-multierror" + "github.com/crawlab-team/crawlab/core/models/models" "github.com/apex/log" @@ -187,19 +188,22 @@ func (r *Runner) Cancel(force bool) (err error) { return err } - // Wait for process to be killed and goroutines to stop - ticker := time.NewTicker(time.Second) + // Create a context with timeout + ctx, cancel := context.WithTimeout(context.Background(), r.svc.GetCancelTimeout()) + defer cancel() + + // Wait for process to be killed with context + ticker := time.NewTicker(100 * time.Millisecond) defer ticker.Stop() + for { select { - case <-ticker.C: - if utils.ProcessIdExists(r.pid) { - continue - } - return nil - case <-time.After(r.svc.GetCancelTimeout()): - // timeout + case <-ctx.Done(): return fmt.Errorf("timeout waiting for task to stop") + case <-ticker.C: + if !utils.ProcessIdExists(r.pid) { + return nil + } } } } diff --git a/core/task/handler/runner_test.go b/core/task/handler/runner_test.go index 1323526d..b17113f6 100644 --- a/core/task/handler/runner_test.go +++ b/core/task/handler/runner_test.go @@ -1,14 +1,17 @@ package handler import ( + "bufio" "encoding/json" "fmt" - "github.com/apex/log" - "github.com/crawlab-team/crawlab/core/utils" "io" + "runtime" "testing" "time" + "github.com/apex/log" + "github.com/crawlab-team/crawlab/core/utils" + "github.com/crawlab-team/crawlab/core/constants" "github.com/crawlab-team/crawlab/core/models/models" "github.com/crawlab-team/crawlab/core/models/service" @@ -38,7 +41,12 @@ func setupTest(t *testing.T) *Runner { Type: "test", Mode: "test", NodeId: primitive.NewObjectID(), - Cmd: "python script.py", + } + switch runtime.GOOS { + case "windows": + task.Cmd = "ping -n 10 127.0.0.1" + default: // linux and darwin (macOS) + task.Cmd = "sleep 10" } taskId, err := service.NewModelService[models.Task]().InsertOne(*task) require.NoError(t, err) @@ -119,21 +127,50 @@ func TestRunner_Cancel(t *testing.T) { // Setup runner := setupTest(t) - // Start a long-running command - runner.t.Cmd = "sleep 10" + // Create pipes for stdout + pr, pw := io.Pipe() + runner.cmd.Stdout = pw + runner.cmd.Stderr = pw + + // Start the command err := runner.cmd.Start() assert.NoError(t, err) + log.Infof("started process with PID: %d", runner.cmd.Process.Pid) runner.pid = runner.cmd.Process.Pid + // Read and print command output + go func() { + scanner := bufio.NewScanner(pr) + for scanner.Scan() { + log.Info(scanner.Text()) + } + }() + + // Wait a bit longer on Windows for the process to start properly + waitTime := 100 * time.Millisecond + if runtime.GOOS == "windows" { + waitTime = 1 * time.Second + } + time.Sleep(waitTime) + + // Verify process exists before attempting to cancel + if !utils.ProcessIdExists(runner.pid) { + t.Fatalf("Process with PID %d was not started successfully", runner.pid) + } + // Test cancel go func() { err = runner.Cancel(true) assert.NoError(t, err) }() - // Verify process was killed - // Wait a short time for the process to be killed - time.Sleep(100 * time.Millisecond) - exists := utils.ProcessIdExists(runner.pid) - assert.False(t, exists) + // Wait for process to be killed, with shorter timeout + deadline := time.Now().Add(5 * time.Second) + for time.Now().Before(deadline) { + if !utils.ProcessIdExists(runner.pid) { + return // Process was killed + } + time.Sleep(100 * time.Millisecond) + } + t.Errorf("Process with PID %d was not killed within timeout", runner.pid) } diff --git a/core/utils/process.go b/core/utils/process.go index 70e868bd..a30095d1 100644 --- a/core/utils/process.go +++ b/core/utils/process.go @@ -1,6 +1,8 @@ package utils import ( + "errors" + "github.com/apex/log" "github.com/crawlab-team/crawlab/trace" "github.com/shirou/gopsutil/process" "os/exec" @@ -9,8 +11,42 @@ import ( ) func ProcessIdExists(pid int) (ok bool) { - ok, _ = process.PidExists(int32(pid)) + //// Find process by pid + //p, err := os.FindProcess(pid) + //if err != nil { + // // Process not found + // return false + //} + // + //// Check if process exists + //err = p.Signal(syscall.Signal(0)) + //if err == nil { + // // Process exists + // return true + //} + // + //// Process not found + //return false + + ok, err := process.PidExists(int32(pid)) + if err != nil { + log.Errorf("error checking if process exists: %v", err) + } return ok + + //processIds, err := process.Pids() + //if err != nil { + // log.Errorf("error getting process pids: %v", err) + // return false + //} + // + //for _, _pid := range processIds { + // if int(_pid) == pid { + // return true + // } + //} + // + //return false } func ListProcess(text string) (lines []string, err error) { @@ -24,7 +60,8 @@ func ListProcess(text string) (lines []string, err error) { func listProcessWindow(text string) (lines []string, err error) { cmd := exec.Command("tasklist", "/fi", text) out, err := cmd.CombinedOutput() - _, ok := err.(*exec.ExitError) + var exitError *exec.ExitError + ok := errors.As(err, &exitError) if !ok { return nil, trace.TraceError(err) } @@ -35,7 +72,8 @@ func listProcessWindow(text string) (lines []string, err error) { func listProcessLinuxMac(text string) (lines []string, err error) { cmd := exec.Command("ps", "aux") out, err := cmd.CombinedOutput() - _, ok := err.(*exec.ExitError) + var exitError *exec.ExitError + ok := errors.As(err, &exitError) if !ok { return nil, trace.TraceError(err) }