From 9560da66b5ef6c0e0e7cfbf9c7236c3132d2e15a Mon Sep 17 00:00:00 2001 From: Marvin Zhang Date: Fri, 22 Nov 2024 17:43:59 +0800 Subject: [PATCH] test: fix test cases issues --- core/go.mod | 4 - core/go.sum | 9 -- core/sys_exec/sys_exec.go | 58 ----------- core/sys_exec/sys_exec_darwin.go | 30 ------ core/sys_exec/sys_exec_linux.go | 30 ------ core/sys_exec/sys_exec_windows.go | 18 ---- core/task/handler/runner.go | 13 ++- core/task/handler/runner_test.go | 41 +++++--- core/utils/encrypt_test.go | 4 - core/utils/file.go | 153 +----------------------------- core/utils/file_test.go | 143 ++++++---------------------- core/utils/process.go | 135 ++++++++++++++------------ core/utils/process_test.go | 20 ++++ 13 files changed, 161 insertions(+), 497 deletions(-) delete mode 100644 core/sys_exec/sys_exec.go delete mode 100644 core/sys_exec/sys_exec_darwin.go delete mode 100644 core/sys_exec/sys_exec_linux.go delete mode 100644 core/sys_exec/sys_exec_windows.go create mode 100644 core/utils/process_test.go diff --git a/core/go.mod b/core/go.mod index 7f2b9268..c826bd91 100644 --- a/core/go.mod +++ b/core/go.mod @@ -32,7 +32,6 @@ require ( github.com/pkg/errors v0.9.1 github.com/robfig/cron/v3 v3.0.0 github.com/shirou/gopsutil v3.21.11+incompatible - github.com/smartystreets/goconvey v1.6.4 github.com/spf13/cobra v1.3.0 github.com/spf13/viper v1.19.0 github.com/stretchr/testify v1.9.0 @@ -80,13 +79,11 @@ require ( github.com/google/s2a-go v0.1.8 // indirect github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect github.com/googleapis/gax-go/v2 v2.13.0 // indirect - github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect github.com/hashicorp/hcl v1.0.0 // indirect github.com/inconshreveable/mousetrap v1.0.0 // indirect github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99 // indirect github.com/json-iterator/go v1.1.12 // indirect - github.com/jtolds/gls v4.20.0+incompatible // indirect github.com/kevinburke/ssh_config v1.2.0 // indirect github.com/klauspost/compress v1.17.7 // indirect github.com/klauspost/cpuid/v2 v2.2.7 // indirect @@ -104,7 +101,6 @@ require ( github.com/sagikazarmark/slog-shim v0.1.0 // indirect github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3 // indirect github.com/skeema/knownhosts v1.2.2 // indirect - github.com/smartystreets/assertions v1.0.0 // indirect github.com/sourcegraph/conc v0.3.0 // indirect github.com/spf13/afero v1.11.0 // indirect github.com/spf13/cast v1.6.0 // indirect diff --git a/core/go.sum b/core/go.sum index e9811fce..07391107 100644 --- a/core/go.sum +++ b/core/go.sum @@ -300,8 +300,6 @@ github.com/googleapis/gax-go/v2 v2.1.0/go.mod h1:Q3nei7sK6ybPYH7twZdmQpAd1MKb7pf github.com/googleapis/gax-go/v2 v2.1.1/go.mod h1:hddJymUZASv3XPyGkUpKj8pPO47Rmb0eJc8R6ouapiM= github.com/googleapis/gax-go/v2 v2.13.0 h1:yitjD5f7jQHhyDsnhKEBU52NdvvdSeGzlAnDPT0hH1s= github.com/googleapis/gax-go/v2 v2.13.0/go.mod h1:Z/fvTZXF8/uw7Xu5GuslPw+bplx6SS338j1Is2S+B7A= -github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1 h1:EGx4pi6eqNxGaHF6qqu48+N2wcFQ5qg5FXgOdqsJ5d8= -github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/grpc-ecosystem/go-grpc-middleware v1.0.1-0.20190118093823-f849b5445de4 h1:z53tR0945TRRQO/fLEVPI6SMv7ZflF0TEaTAoU7tOzg= @@ -363,8 +361,6 @@ github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnr github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU= github.com/jstemmer/go-junit-report v0.9.1/go.mod h1:Brl9GWCQeLvo8nXZwPNNblvFj/XSXhF0NWZEnDohbsk= -github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo= -github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w= github.com/kevinburke/ssh_config v1.2.0 h1:x584FjTGwHzMwvHx18PXxbBVzfnxogHaAReU4gf13a4= github.com/kevinburke/ssh_config v1.2.0/go.mod h1:CT57kijsi8u/K/BOFA39wgDQJ9CxiF4nAY/ojJ6r6mM= @@ -491,12 +487,8 @@ github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6Mwd github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= github.com/skeema/knownhosts v1.2.2 h1:Iug2P4fLmDw9f41PB6thxUkNUkJzB5i+1/exaj40L3A= github.com/skeema/knownhosts v1.2.2/go.mod h1:xYbVRSPxqBZFrdmDyMmsOs+uX1UZC3nTN3ThzgDxUwo= -github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc= -github.com/smartystreets/assertions v1.0.0 h1:UVQPSSmc3qtTi+zPPkCXvZX9VvW/xT/NsRvKfwY81a8= github.com/smartystreets/assertions v1.0.0/go.mod h1:kHHU4qYBaI3q23Pp3VPrmWhuIUrLW/7eUrw0BU5VaoM= github.com/smartystreets/go-aws-auth v0.0.0-20180515143844-0c1422d1fdb9/go.mod h1:SnhjPscd9TpLiy1LpzGSKh3bXCfxxXuqd9xmQJy3slM= -github.com/smartystreets/goconvey v1.6.4 h1:fv0U8FUIMPNf1L9lnHLvLhgicrIVChEkdzIKYqbNC9s= -github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA= github.com/smartystreets/gunit v1.0.0/go.mod h1:qwPWnhz6pn0NnRBP++URONOVyNkPyr4SauJk4cUOwJs= github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo= github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0= @@ -860,7 +852,6 @@ golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3 golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190312151545-0bb0c0a6e846/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190312170243-e65039ee4138/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= -golang.org/x/tools v0.0.0-20190328211700-ab21143f2384/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/tools v0.0.0-20190506145303-2d16b83fe98c/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= diff --git a/core/sys_exec/sys_exec.go b/core/sys_exec/sys_exec.go deleted file mode 100644 index cb472091..00000000 --- a/core/sys_exec/sys_exec.go +++ /dev/null @@ -1,58 +0,0 @@ -package sys_exec - -import ( - "errors" - "github.com/apex/log" - "github.com/shirou/gopsutil/process" - "os/exec" -) - -type KillProcessOptions struct { - Force bool -} - -func KillProcess(cmd *exec.Cmd, opts *KillProcessOptions) error { - // process - p, err := process.NewProcess(int32(cmd.Process.Pid)) - if err != nil { - return err - } - - // kill process - return killProcessRecursive(p, opts.Force) -} - -func killProcessRecursive(p *process.Process, force bool) (err error) { - // children processes - cps, err := p.Children() - if err != nil { - if !errors.Is(err, process.ErrorNoChildren) { - log.Errorf("failed to get children processes: %v", err) - } else if errors.Is(err, process.ErrorProcessNotRunning) { - return nil - } - return killProcess(p, force) - } - - // iterate children processes - for _, cp := range cps { - if err := killProcessRecursive(cp, force); err != nil { - return err - } - } - - return killProcess(p, force) -} - -func killProcess(p *process.Process, force bool) (err error) { - if force { - err = p.Kill() - } else { - err = p.Terminate() - } - if err != nil { - log.Errorf("failed to kill process (force: %v): %v", force, err) - return err - } - return nil -} diff --git a/core/sys_exec/sys_exec_darwin.go b/core/sys_exec/sys_exec_darwin.go deleted file mode 100644 index b6db18c2..00000000 --- a/core/sys_exec/sys_exec_darwin.go +++ /dev/null @@ -1,30 +0,0 @@ -//go:build darwin -// +build darwin - -package sys_exec - -import ( - "errors" - "os/exec" - "strings" - "syscall" -) - -func BuildCmd(cmdStr string) (cmd *exec.Cmd, err error) { - if cmdStr == "" { - return nil, errors.New("command string is empty") - } - args := strings.Split(cmdStr, " ") - return exec.Command(args[0], args[1:]...), nil -} - -func SetPgid(cmd *exec.Cmd) { - if cmd == nil { - return - } - if cmd.SysProcAttr == nil { - cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true} - } else { - cmd.SysProcAttr.Setpgid = true - } -} diff --git a/core/sys_exec/sys_exec_linux.go b/core/sys_exec/sys_exec_linux.go deleted file mode 100644 index 33af73e3..00000000 --- a/core/sys_exec/sys_exec_linux.go +++ /dev/null @@ -1,30 +0,0 @@ -//go:build linux -// +build linux - -package sys_exec - -import ( - "errors" - "os/exec" - "strings" - "syscall" -) - -func BuildCmd(cmdStr string) (cmd *exec.Cmd, err error) { - if cmdStr == "" { - return nil, errors.New("command string is empty") - } - args := strings.Split(cmdStr, " ") - return exec.Command(args[0], args[1:]...), nil -} - -func SetPgid(cmd *exec.Cmd) { - if cmd == nil { - return - } - if cmd.SysProcAttr == nil { - cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true} - } else { - cmd.SysProcAttr.Setpgid = true - } -} diff --git a/core/sys_exec/sys_exec_windows.go b/core/sys_exec/sys_exec_windows.go deleted file mode 100644 index 1f1afd5a..00000000 --- a/core/sys_exec/sys_exec_windows.go +++ /dev/null @@ -1,18 +0,0 @@ -//go:build windows -// +build windows - -package sys_exec - -import ( - "errors" - "os/exec" - "strings" -) - -func BuildCmd(cmdStr string) (cmd *exec.Cmd, err error) { - if cmdStr == "" { - return nil, errors.New("command string is empty") - } - args := strings.Split(cmdStr, " ") - return exec.Command(args[0], args[1:]...), nil -} diff --git a/core/task/handler/runner.go b/core/task/handler/runner.go index 39eeea03..474a2eb3 100644 --- a/core/task/handler/runner.go +++ b/core/task/handler/runner.go @@ -27,7 +27,6 @@ import ( "github.com/crawlab-team/crawlab/core/interfaces" "github.com/crawlab-team/crawlab/core/models/client" "github.com/crawlab-team/crawlab/core/models/service" - "github.com/crawlab-team/crawlab/core/sys_exec" "github.com/crawlab-team/crawlab/core/utils" "github.com/crawlab-team/crawlab/grpc" "github.com/crawlab-team/crawlab/trace" @@ -180,13 +179,12 @@ func (r *Runner) Cancel(force bool) (err error) { r.cancel() // Kill process - err = sys_exec.KillProcess(r.cmd, &sys_exec.KillProcessOptions{ - Force: force, - }) + err = utils.KillProcess(r.cmd, force) if err != nil { log.Errorf("kill process error: %v", err) return err } + log.Debugf("attempt to kill process[%d]", r.pid) // Create a context with timeout ctx, cancel := context.WithTimeout(context.Background(), r.svc.GetCancelTimeout()) @@ -236,7 +234,7 @@ func (r *Runner) configureCmd() (err error) { } // get cmd instance - r.cmd, err = sys_exec.BuildCmd(cmdStr) + r.cmd, err = utils.BuildCmd(cmdStr) if err != nil { log.Errorf("error building command: %v", err) return err @@ -511,28 +509,33 @@ func (r *Runner) getHttpRequestHeaders() (headers map[string]string) { func (r *Runner) wait() (err error) { // start a goroutine to wait for process to finish go func() { + log.Debugf("waiting for process[%d] to finish", r.pid) err = r.cmd.Wait() if err != nil { var exitError *exec.ExitError if !errors.As(err, &exitError) { r.ch <- constants.TaskSignalError + log.Debugf("process[%d] exited with error: %v", r.pid, err) return } exitCode := exitError.ExitCode() if exitCode == -1 { // cancel error r.ch <- constants.TaskSignalCancel + log.Debugf("process[%d] cancelled", r.pid) return } // standard error r.err = err r.ch <- constants.TaskSignalError + log.Debugf("process[%d] exited with error: %v", r.pid, err) return } // success r.ch <- constants.TaskSignalFinish + log.Debugf("process[%d] exited successfully", r.pid) }() // declare task status diff --git a/core/task/handler/runner_test.go b/core/task/handler/runner_test.go index f38825f3..bb7c1e15 100644 --- a/core/task/handler/runner_test.go +++ b/core/task/handler/runner_test.go @@ -2,6 +2,7 @@ package handler import ( "bufio" + "context" "encoding/json" "fmt" "io" @@ -46,7 +47,7 @@ func setupTest(t *testing.T) *Runner { case "windows": task.Cmd = "ping -n 10 127.0.0.1" default: // linux and darwin (macOS) - task.Cmd = "sleep 10" + task.Cmd = "ping -c 10 127.0.0.1" } taskId, err := service.NewModelService[models.Task]().InsertOne(*task) require.NoError(t, err) @@ -146,12 +147,18 @@ func TestRunner_Cancel(t *testing.T) { } }() - // Wait a bit longer on Windows for the process to start properly - waitTime := 100 * time.Millisecond - if runtime.GOOS == "windows" { - waitTime = 1 * time.Second - } - time.Sleep(waitTime) + // Wait for process to finish + go func() { + err = runner.cmd.Wait() + if err != nil { + log.Errorf("process[%d] exited with error: %v", runner.pid, err) + return + } + log.Infof("process[%d] exited successfully", runner.pid) + }() + + // Wait for a certain period for the process to start properly + time.Sleep(1 * time.Second) // Verify process exists before attempting to cancel if !utils.ProcessIdExists(runner.pid) { @@ -162,17 +169,25 @@ func TestRunner_Cancel(t *testing.T) { go func() { err = runner.Cancel(true) assert.NoError(t, err) + log.Infof("process[%d] cancelled", runner.pid) }() // Wait for process to be killed, with shorter timeout - deadline := time.Now().Add(5 * time.Second) - for time.Now().Before(deadline) { - if !utils.ProcessIdExists(runner.pid) { - return // Process was killed + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + ticker := time.NewTicker(100 * time.Millisecond) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + t.Fatalf("Process with PID %d was not killed within timeout", runner.pid) + case <-ticker.C: + exists := utils.ProcessIdExists(runner.pid) + if !exists { + return // Exit the test when process is killed + } } - time.Sleep(100 * time.Millisecond) } - t.Errorf("Process with PID %d was not killed within timeout", runner.pid) } func TestRunner_HandleIPCData(t *testing.T) { diff --git a/core/utils/encrypt_test.go b/core/utils/encrypt_test.go index d393fa5b..e23adb34 100644 --- a/core/utils/encrypt_test.go +++ b/core/utils/encrypt_test.go @@ -1,7 +1,6 @@ package utils import ( - "fmt" "github.com/stretchr/testify/require" "testing" ) @@ -12,9 +11,6 @@ func TestEncryptAesPassword(t *testing.T) { require.Nil(t, err) decryptedText, err := DecryptAES(encryptedText) require.Nil(t, err) - fmt.Println(fmt.Sprintf("plainText: %s", plainText)) - fmt.Println(fmt.Sprintf("encryptedText: %s", encryptedText)) - fmt.Println(fmt.Sprintf("decryptedText: %s", decryptedText)) require.Equal(t, decryptedText, plainText) require.NotEqual(t, decryptedText, encryptedText) } diff --git a/core/utils/file.go b/core/utils/file.go index 9e09d4d5..b75501d6 100644 --- a/core/utils/file.go +++ b/core/utils/file.go @@ -41,9 +41,7 @@ func IsDir(path string) bool { return s.IsDir() } -// ListDir Add: 增加error类型作为第二返回值 -// 在其他函数如 /task/log/file_driver.go中的 *FileLogDriver.cleanup()函数调用时 -// 可以通过判断err是否为nil来判断是否有错误发生 +// ListDir returns a list of files metadata in the directory func ListDir(path string) ([]fs.FileInfo, error) { list, err := os.ReadDir(path) if err != nil { @@ -65,153 +63,8 @@ func ListDir(path string) ([]fs.FileInfo, error) { return res, nil } -func DeCompress(srcFile *os.File, dstPath string) error { - // 如果保存路径不存在,创建一个 - if !Exists(dstPath) { - if err := os.MkdirAll(dstPath, os.ModePerm); err != nil { - debug.PrintStack() - return err - } - } - - // 读取zip文件 - zipFile, err := zip.OpenReader(srcFile.Name()) - if err != nil { - log.Errorf("Unzip File Error:" + err.Error()) - debug.PrintStack() - return err - } - defer Close(zipFile) - - // 遍历zip内所有文件和目录 - for _, innerFile := range zipFile.File { - // 获取该文件数据 - info := innerFile.FileInfo() - - // 如果是目录,则创建一个 - if info.IsDir() { - err = os.MkdirAll(filepath.Join(dstPath, innerFile.Name), os.ModeDir|os.ModePerm) - if err != nil { - log.Errorf("Unzip File Error : " + err.Error()) - debug.PrintStack() - return err - } - continue - } - - // 如果文件目录不存在,则创建一个 - dirPath := filepath.Join(dstPath, filepath.Dir(innerFile.Name)) - if !Exists(dirPath) { - if err = os.MkdirAll(dirPath, os.ModeDir|os.ModePerm); err != nil { - log.Errorf("Unzip File Error : " + err.Error()) - debug.PrintStack() - return err - } - } - - // 打开该文件 - srcFile, err := innerFile.Open() - if err != nil { - log.Errorf("Unzip File Error : " + err.Error()) - debug.PrintStack() - continue - } - - // 创建新文件 - newFilePath := filepath.Join(dstPath, innerFile.Name) - newFile, err := os.OpenFile(newFilePath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, info.Mode()) - if err != nil { - log.Errorf("Unzip File Error : " + err.Error()) - debug.PrintStack() - continue - } - - // 拷贝该文件到新文件中 - if _, err := io.Copy(newFile, srcFile); err != nil { - debug.PrintStack() - return err - } - - // 关闭该文件 - if err := srcFile.Close(); err != nil { - debug.PrintStack() - return err - } - - // 关闭新文件 - if err := newFile.Close(); err != nil { - debug.PrintStack() - return err - } - } - return nil -} - -// Compress 压缩文件 -// files 文件数组,可以是不同dir下的文件或者文件夹 -// dest 压缩文件存放地址 -func Compress(files []*os.File, dest string) error { - d, _ := os.Create(dest) - defer Close(d) - w := zip.NewWriter(d) - defer Close(w) - for _, file := range files { - if err := _Compress(file, "", w); err != nil { - return err - } - } - return nil -} - -func _Compress(file *os.File, prefix string, zw *zip.Writer) error { - info, err := file.Stat() - if err != nil { - debug.PrintStack() - return err - } - if info.IsDir() { - prefix = prefix + "/" + info.Name() - fileInfos, err := file.Readdir(-1) - if err != nil { - debug.PrintStack() - return err - } - for _, fi := range fileInfos { - f, err := os.Open(file.Name() + "/" + fi.Name()) - if err != nil { - debug.PrintStack() - return err - } - err = _Compress(f, prefix, zw) - if err != nil { - debug.PrintStack() - return err - } - } - } else { - header, err := zip.FileInfoHeader(info) - if err != nil { - debug.PrintStack() - return err - } - header.Name = prefix + "/" + header.Name - writer, err := zw.CreateHeader(header) - if err != nil { - debug.PrintStack() - return err - } - _, err = io.Copy(writer, file) - Close(file) - if err != nil { - debug.PrintStack() - return err - } - } - return nil -} - -func ZipDirectory(dir, zipfile string) error { - zipFile, err := os.Create(zipfile) +func ZipDirectory(dir, filePath string) error { + zipFile, err := os.Create(filePath) if err != nil { return err } diff --git a/core/utils/file_test.go b/core/utils/file_test.go index 4af32d0d..6e15ff02 100644 --- a/core/utils/file_test.go +++ b/core/utils/file_test.go @@ -1,129 +1,40 @@ package utils import ( - "archive/zip" - . "github.com/smartystreets/goconvey/convey" - "io" - "log" - "os" - "runtime/debug" "testing" + + "github.com/stretchr/testify/assert" ) func TestExists(t *testing.T) { - var pathString = "../config" - var wrongPathString = "test" + // Test cases + pathString := "../config" + wrongPathString := "test" - Convey("Test path or file is Exists or not", t, func() { - res := Exists(pathString) - Convey("The result should be true", func() { - So(res, ShouldEqual, true) - }) - wrongRes := Exists(wrongPathString) - Convey("The result should be false", func() { - So(wrongRes, ShouldEqual, false) - }) - }) + // Test existing path + res := Exists(pathString) + assert.True(t, res, "Expected existing path to return true") + + // Test non-existing path + wrongRes := Exists(wrongPathString) + assert.False(t, wrongRes, "Expected non-existing path to return false") } func TestIsDir(t *testing.T) { - var pathString = "../config" - var fileString = "../config/config.go" - var wrongString = "test" - - Convey("Test path is folder or not", t, func() { - res := IsDir(pathString) - So(res, ShouldEqual, true) - fileRes := IsDir(fileString) - So(fileRes, ShouldEqual, false) - wrongRes := IsDir(wrongString) - So(wrongRes, ShouldEqual, false) - }) -} - -func TestCompress(t *testing.T) { - err := os.Mkdir("testCompress", os.ModePerm) - if err != nil { - t.Error("create testCompress failed") - } - var pathString = "testCompress" - var files []*os.File - var disPath = "testCompress" - file, err := os.Open(pathString) - if err != nil { - t.Error("open source path failed") - } - files = append(files, file) - Convey("Verify dispath is valid path", t, func() { - er := Compress(files, disPath) - Convey("err should be nil", func() { - So(er, ShouldEqual, nil) - }) - }) - _ = os.RemoveAll("testCompress") - -} -func Zip(zipFile string, fileList []string) error { - // 创建 zip 包文件 - fw, err := os.Create(zipFile) - if err != nil { - log.Fatal() - } - defer Close(fw) - - // 实例化新的 zip.Writer - zw := zip.NewWriter(fw) - defer Close(zw) - - for _, fileName := range fileList { - fr, err := os.Open(fileName) - if err != nil { - return err - } - fi, err := fr.Stat() - if err != nil { - return err - } - // 写入文件的头信息 - fh, err := zip.FileInfoHeader(fi) - if err != nil { - return err - } - w, err := zw.CreateHeader(fh) - if err != nil { - return err - } - // 写入文件内容 - _, err = io.Copy(w, fr) - if err != nil { - return err - } - } - return nil -} - -func TestDeCompress(t *testing.T) { - err := os.Mkdir("testDeCompress", os.ModePerm) - if err != nil { - t.Error(err) - - } - err = Zip("demo.zip", []string{}) - if err != nil { - t.Error("create zip file failed") - } - tmpFile, err := os.OpenFile("demo.zip", os.O_RDONLY, 0777) - if err != nil { - debug.PrintStack() - t.Error("open demo.zip failed") - } - var dstPath = "./testDeCompress" - Convey("Test DeCopmress func", t, func() { - - err := DeCompress(tmpFile, dstPath) - So(err, ShouldEqual, nil) - }) - _ = os.RemoveAll("testDeCompress") - _ = os.Remove("demo.zip") + // Test cases + pathString := "../config" + fileString := "../config/config.go" + wrongString := "test" + // Test directory path + res := IsDir(pathString) + assert.True(t, res, "Expected directory path to return true") + + // Test file path + fileRes := IsDir(fileString) + assert.False(t, fileRes, "Expected file path to return false") + + // Test non-existing path + wrongRes := IsDir(wrongString) + assert.False(t, wrongRes, "Expected non-existing path to return false") } diff --git a/core/utils/process.go b/core/utils/process.go index a30095d1..c2967ebe 100644 --- a/core/utils/process.go +++ b/core/utils/process.go @@ -3,85 +3,100 @@ package utils import ( "errors" "github.com/apex/log" - "github.com/crawlab-team/crawlab/trace" "github.com/shirou/gopsutil/process" "os/exec" "runtime" "strings" ) -func ProcessIdExists(pid int) (ok bool) { - //// Find process by pid - //p, err := os.FindProcess(pid) - //if err != nil { - // // Process not found - // return false - //} - // - //// Check if process exists - //err = p.Signal(syscall.Signal(0)) - //if err == nil { - // // Process exists - // return true - //} - // - //// Process not found - //return false +func BuildCmd(cmdStr string) (cmd *exec.Cmd, err error) { + if cmdStr == "" { + return nil, errors.New("command string is empty") + } + args := strings.Split(cmdStr, " ") + return exec.Command(args[0], args[1:]...), nil +} - ok, err := process.PidExists(int32(pid)) +func ProcessIdExists(pid int) (exists bool) { + if runtime.GOOS == "windows" { + return processIdExistsWindows(pid) + } else { + return processIdExistsLinuxMac(pid) + } +} + +func processIdExistsWindows(pid int) (exists bool) { + exists, err := process.PidExists(int32(pid)) if err != nil { log.Errorf("error checking if process exists: %v", err) } - return ok - - //processIds, err := process.Pids() - //if err != nil { - // log.Errorf("error getting process pids: %v", err) - // return false - //} - // - //for _, _pid := range processIds { - // if int(_pid) == pid { - // return true - // } - //} - // - //return false + return exists } -func ListProcess(text string) (lines []string, err error) { - if runtime.GOOS == "windows" { - return listProcessWindow(text) - } else { - return listProcessLinuxMac(text) +func processIdExistsLinuxMac(pid int) (exists bool) { + exists, err := process.PidExists(int32(pid)) + if err != nil { + log.Errorf("error checking if process exists: %v", err) } + return exists } -func listProcessWindow(text string) (lines []string, err error) { - cmd := exec.Command("tasklist", "/fi", text) - out, err := cmd.CombinedOutput() - var exitError *exec.ExitError - ok := errors.As(err, &exitError) - if !ok { - return nil, trace.TraceError(err) +func GetProcesses() (processes []*process.Process, err error) { + processes, err = process.Processes() + if err != nil { + log.Errorf("error getting processes: %v", err) + return nil, err } - lines = strings.Split(string(out), "\n") - return lines, nil + return processes, nil } -func listProcessLinuxMac(text string) (lines []string, err error) { - cmd := exec.Command("ps", "aux") - out, err := cmd.CombinedOutput() - var exitError *exec.ExitError - ok := errors.As(err, &exitError) - if !ok { - return nil, trace.TraceError(err) +type KillProcessOptions struct { + Force bool +} + +func KillProcess(cmd *exec.Cmd, force bool) error { + // process + p, err := process.NewProcess(int32(cmd.Process.Pid)) + if err != nil { + log.Errorf("failed to get process: %v", err) + return err } - _lines := strings.Split(string(out), "\n") - for _, l := range _lines { - if strings.Contains(l, text) { - lines = append(lines, l) + + // kill process + return killProcessRecursive(p, force) +} + +func killProcessRecursive(p *process.Process, force bool) (err error) { + // children processes + cps, err := p.Children() + if err != nil { + if !errors.Is(err, process.ErrorNoChildren) { + log.Errorf("failed to get children processes: %v", err) + } else if errors.Is(err, process.ErrorProcessNotRunning) { + return nil + } + return killProcess(p, force) + } + + // iterate children processes + for _, cp := range cps { + if err := killProcessRecursive(cp, force); err != nil { + return err } } - return lines, nil + + return killProcess(p, force) +} + +func killProcess(p *process.Process, force bool) (err error) { + if force { + err = p.Kill() + } else { + err = p.Terminate() + } + if err != nil { + log.Errorf("failed to kill process (force: %v): %v", force, err) + return err + } + return nil } diff --git a/core/utils/process_test.go b/core/utils/process_test.go new file mode 100644 index 00000000..f0b3cd67 --- /dev/null +++ b/core/utils/process_test.go @@ -0,0 +1,20 @@ +package utils + +import ( + "os" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestProcessIdExists(t *testing.T) { + t.Run("existing process", func(t *testing.T) { + currentPid := os.Getpid() + assert.True(t, ProcessIdExists(currentPid), "should detect current process") + }) + + t.Run("non-existent process", func(t *testing.T) { + invalidPid := 99999999 + assert.False(t, ProcessIdExists(invalidPid), "should not detect non-existent process") + }) +}