From 7d3082c3b4ebd52877d6c649bd413c72a25caae9 Mon Sep 17 00:00:00 2001 From: Marvin Zhang Date: Thu, 21 Nov 2024 18:24:29 +0800 Subject: [PATCH] feat: added ipc to task runner --- core/task/handler/runner.go | 377 ++++++++++++++---- core/task/handler/runner_test.go | 140 +++++++ .../install/python/requirements.txt | 2 +- 3 files changed, 441 insertions(+), 78 deletions(-) create mode 100644 core/task/handler/runner_test.go diff --git a/core/task/handler/runner.go b/core/task/handler/runner.go index f83202e3..702579a6 100644 --- a/core/task/handler/runner.go +++ b/core/task/handler/runner.go @@ -13,6 +13,7 @@ import ( "path/filepath" "strings" "sync" + "syscall" "time" "github.com/crawlab-team/crawlab/core/models/models" @@ -34,33 +35,54 @@ import ( "go.mongodb.org/mongo-driver/bson/primitive" ) +// Runner represents a task execution handler that manages the lifecycle of a running task type Runner struct { // dependencies svc *Service // task handler service fsSvc interfaces.FsService // task fs service // settings - subscribeTimeout time.Duration - bufferSize int + subscribeTimeout time.Duration // maximum time to wait for task subscription + bufferSize int // buffer size for reading process output // internals cmd *exec.Cmd // process command instance pid int // process id tid primitive.ObjectID // task id - t *models.Task // task model.Task - s *models.Spider // spider model.Spider - ch chan constants.TaskSignal // channel to communicate between Service and Runner - err error // standard process error - cwd string // working directory - c *client2.GrpcClient // grpc client - conn grpc.TaskService_ConnectClient // grpc task service stream client + t *models.Task // task model instance + s *models.Spider // spider model instance + ch chan constants.TaskSignal // channel for task status communication + err error // captures any process execution errors + cwd string // current working directory for task + c *client2.GrpcClient // gRPC client for communication + conn grpc.TaskService_ConnectClient // gRPC stream connection for task service - // log internals - scannerStdout *bufio.Reader - scannerStderr *bufio.Reader - logBatchSize int + // log handling + scannerStdout *bufio.Reader // reader for process stdout + scannerStderr *bufio.Reader // reader for process stderr + logBatchSize int // number of log lines to batch before sending + + // IPC (Inter-Process Communication) + stdinPipe io.WriteCloser // pipe for writing to child process + stdoutPipe io.ReadCloser // pipe for reading from child process + ipcChan chan IPCMessage // channel for sending IPC messages + ipcHandler func(IPCMessage) // callback for handling received IPC messages + + // goroutine management + ctx context.Context // context for controlling goroutine lifecycle + cancel context.CancelFunc // function to cancel the context + done chan struct{} // channel to signal completion + wg sync.WaitGroup // wait group for goroutine synchronization } +// IPCMessage defines the structure for messages exchanged between parent and child processes +type IPCMessage struct { + Type string `json:"type"` // message type identifier + Payload interface{} `json:"payload"` // message content + IPC bool `json:"ipc"` // Add this field to explicitly mark IPC messages +} + +// Init initializes the task runner by updating the task status and establishing gRPC connections func (r *Runner) Init() (err error) { // update task if err := r.updateTask("", nil); err != nil { @@ -83,6 +105,8 @@ func (r *Runner) Init() (err error) { return nil } +// Run executes the task and manages its lifecycle, including file synchronization, process execution, +// and status monitoring. Returns an error if the task execution fails. func (r *Runner) Run() (err error) { // log task started log.Infof("task[%s] started", r.tid.Hex()) @@ -106,17 +130,11 @@ func (r *Runner) Run() (err error) { // configure environment variables r.configureEnv() - // configure logging - r.configureLogging() - // start process if err := r.cmd.Start(); err != nil { return r.updateTask(constants.TaskStatusError, err) } - // start logging - go r.startLogging() - // process id if r.cmd.Process == nil { return r.updateTask(constants.TaskStatusError, constants.ErrNotExists) @@ -135,6 +153,20 @@ func (r *Runner) Run() (err error) { // start health check go r.startHealthCheck() + // Start IPC reader + go r.startIPCReader() + + // Start IPC handler + go r.handleIPC() + + // Ensure cleanup when Run() exits + defer func() { + r.cancel() // Cancel context to stop all goroutines + r.wg.Wait() // Wait for all goroutines to finish + close(r.done) // Signal that everything is done + close(r.ipcChan) // Close IPC channel + }() + // declare task status status := "" @@ -166,8 +198,13 @@ func (r *Runner) Run() (err error) { return err } +// Cancel terminates the running task. If force is true, the process will be killed immediately +// without waiting for graceful shutdown. func (r *Runner) Cancel(force bool) (err error) { - // kill process + // Signal goroutines to stop + r.cancel() + + // Kill process opts := &sys_exec.KillProcessOptions{ Timeout: r.svc.GetCancelTimeout(), Force: force, @@ -176,18 +213,26 @@ func (r *Runner) Cancel(force bool) (err error) { return err } - // make sure the process does not exist - ticker := time.NewTicker(1 * time.Second) - timeout := time.After(r.svc.GetCancelTimeout()) + // Wait for process to be killed and goroutines to stop + ticker := time.NewTicker(time.Second) + defer ticker.Stop() for { select { - case <-timeout: - return fmt.Errorf("task process %d still exists", r.pid) case <-ticker.C: - if exists, _ := process.PidExists(int32(r.pid)); exists { - return fmt.Errorf("task process %d still exists", r.pid) + p, err := os.FindProcess(r.pid) + if err != nil { + // process not exists, exit + return nil + } + err = p.Signal(syscall.Signal(0)) + if err == nil { + // process still exists, continue + continue } return nil + case <-time.After(r.svc.GetCancelTimeout()): + // timeout + return fmt.Errorf("timeout waiting for task to stop") } } } @@ -200,6 +245,8 @@ func (r *Runner) GetTaskId() (id primitive.ObjectID) { return r.tid } +// configureCmd builds and configures the command to be executed, including setting up IPC pipes +// and processing command parameters func (r *Runner) configureCmd() (err error) { var cmdStr string @@ -221,71 +268,62 @@ func (r *Runner) configureCmd() (err error) { r.cmd, err = sys_exec.BuildCmd(cmdStr) if err != nil { log.Errorf("error building command: %v", err) - trace.PrintError(err) return err } // set working directory r.cmd.Dir = r.cwd + // Configure pipes for IPC + r.stdinPipe, err = r.cmd.StdinPipe() + if err != nil { + return trace.TraceError(err) + } + + // Add stdout pipe for IPC + r.stdoutPipe, err = r.cmd.StdoutPipe() + if err != nil { + return trace.TraceError(err) + } + + // Initialize IPC channel + r.ipcChan = make(chan IPCMessage) + return nil } -func (r *Runner) configureLogging() { - // set stdout reader - stdout, _ := r.cmd.StdoutPipe() - r.scannerStdout = bufio.NewReaderSize(stdout, r.bufferSize) - - // set stderr reader - stderr, _ := r.cmd.StderrPipe() - r.scannerStderr = bufio.NewReaderSize(stderr, r.bufferSize) -} - -func (r *Runner) startLogging() { - // start reading stdout - go r.startLoggingReaderStdout() - - // start reading stderr - go r.startLoggingReaderStderr() -} - -func (r *Runner) startLoggingReaderStdout() { - for { - line, err := r.scannerStdout.ReadString(byte('\n')) - if err != nil { - break - } - line = strings.TrimSuffix(line, "\n") - r.writeLogLines([]string{line}) - } -} - -func (r *Runner) startLoggingReaderStderr() { - for { - line, err := r.scannerStderr.ReadString(byte('\n')) - if err != nil { - break - } - line = strings.TrimSuffix(line, "\n") - r.writeLogLines([]string{line}) - } -} - +// startHealthCheck periodically verifies that the process is still running +// If the process disappears unexpectedly, it signals a task lost condition func (r *Runner) startHealthCheck() { + r.wg.Add(1) + defer r.wg.Done() + if r.cmd.ProcessState == nil || r.cmd.ProcessState.Exited() { return } + + ticker := time.NewTicker(1 * time.Second) + defer ticker.Stop() + for { - exists, _ := process.PidExists(int32(r.pid)) - if !exists { - // process lost - r.ch <- constants.TaskSignalLost + select { + case <-r.ctx.Done(): return + case <-ticker.C: + exists, _ := process.PidExists(int32(r.pid)) + if !exists { + // process lost + r.ch <- constants.TaskSignalLost + return + } } - time.Sleep(1 * time.Second) } } +// configureEnv sets up the environment variables for the task process, including: +// - Node.js paths +// - Crawlab-specific variables +// - Global environment variables from the system func (r *Runner) configureEnv() { // By default, add Node.js's global node_modules to PATH envPath := os.Getenv("PATH") @@ -312,6 +350,9 @@ func (r *Runner) configureEnv() { for _, env := range envs { r.cmd.Env = append(r.cmd.Env, env.Key+"="+env.Value) } + + // Add environment variable for child processes to identify they're running under Crawlab + r.cmd.Env = append(r.cmd.Env, "CRAWLAB_PARENT_PID="+fmt.Sprint(os.Getpid())) } func (r *Runner) createHttpRequest(method, path string) (*http.Response, error) { @@ -337,6 +378,11 @@ func (r *Runner) createHttpRequest(method, path string) (*http.Response, error) return http.DefaultClient.Do(req) } +// syncFiles synchronizes files between master and worker nodes: +// 1. Gets file list from master +// 2. Compares with local files +// 3. Downloads new/modified files +// 4. Deletes files that no longer exist on master func (r *Runner) syncFiles() (err error) { workingDir := "" if !r.s.GitId.IsZero() { @@ -440,6 +486,7 @@ func (r *Runner) syncFiles() (err error) { return err } +// downloadFile downloads a file from the master node and saves it to the local file system func (r *Runner) downloadFile(path string, filePath string, fileInfo *entity.FsFileInfo) error { resp, err := r.createHttpRequest("GET", "/download?path="+path) if err != nil { @@ -478,14 +525,17 @@ func (r *Runner) downloadFile(path string, filePath string, fileInfo *entity.FsF return nil } +// getHttpRequestHeaders returns the headers for HTTP requests to the master node func (r *Runner) getHttpRequestHeaders() (headers map[string]string) { headers = make(map[string]string) headers["Authorization"] = utils.GetAuthKey() return headers } -// wait for process to finish and send task signal (constants.TaskSignal) -// to task runner's channel (Runner.ch) according to exit code +// wait monitors the process execution and sends appropriate signals based on the exit status: +// - TaskSignalFinish for successful completion +// - TaskSignalCancel for cancellation +// - TaskSignalError for execution errors func (r *Runner) wait() { // wait for process to finish if err := r.cmd.Wait(); err != nil { @@ -512,7 +562,8 @@ func (r *Runner) wait() { r.ch <- constants.TaskSignalFinish } -// updateTask update and get updated info of task (Runner.t) +// updateTask updates the task status and related statistics in the database +// If running on a worker node, updates are sent to the master func (r *Runner) updateTask(status string, e error) (err error) { if r.t != nil && status != "" { // update task status @@ -549,14 +600,17 @@ func (r *Runner) updateTask(status string, e error) (err error) { return nil } +// initConnection establishes a gRPC connection to the task service func (r *Runner) initConnection() (err error) { r.conn, err = r.c.TaskClient.Connect(context.Background()) if err != nil { - return trace.TraceError(err) + log.Errorf("error connecting to task service: %v", err) + return err } return nil } +// writeLogLines marshals log lines to JSON and sends them to the task service func (r *Runner) writeLogLines(lines []string) { linesBytes, err := json.Marshal(lines) if err != nil { @@ -574,6 +628,9 @@ func (r *Runner) writeLogLines(lines []string) { } } +// _updateTaskStat updates task statistics based on the current status: +// - For running tasks: sets start time and wait duration +// - For completed tasks: sets end time and calculates durations func (r *Runner) _updateTaskStat(status string) { ts, err := client.NewModelService[models.TaskStat]().GetById(r.tid) if err != nil { @@ -610,6 +667,7 @@ func (r *Runner) _updateTaskStat(status string) { } } +// sendNotification sends a notification to the task service func (r *Runner) sendNotification() { req := &grpc.TaskServiceSendNotificationRequest{ NodeKey: r.svc.GetNodeConfigService().GetNodeKey(), @@ -623,6 +681,10 @@ func (r *Runner) sendNotification() { } } +// _updateSpiderStat updates spider statistics based on task completion: +// - Updates last task ID +// - Increments task counts +// - Updates duration metrics func (r *Runner) _updateSpiderStat(status string) { // task stat ts, err := client.NewModelService[models.TaskStat]().GetById(r.tid) @@ -677,6 +739,7 @@ func (r *Runner) _updateSpiderStat(status string) { } } +// configureCwd sets the working directory for the task based on the spider's configuration func (r *Runner) configureCwd() { workspacePath := utils.GetWorkspace() if r.s.GitId.IsZero() { @@ -688,6 +751,162 @@ func (r *Runner) configureCwd() { } } +// handleIPC processes incoming IPC messages from the child process +// Messages are converted to JSON and written to the child process's stdin +func (r *Runner) handleIPC() { + for msg := range r.ipcChan { + // Convert message to JSON + jsonData, err := json.Marshal(msg) + if err != nil { + log.Errorf("error marshaling IPC message: %v", err) + continue + } + + // Write to child process's stdin + _, err = fmt.Fprintln(r.stdinPipe, string(jsonData)) + if err != nil { + log.Errorf("error writing to child process: %v", err) + } + } +} + +// SendToChild sends a message to the child process through the IPC channel +// msgType: type of message being sent +// payload: data to be sent to the child process +func (r *Runner) SendToChild(msgType string, payload interface{}) { + r.ipcChan <- IPCMessage{ + Type: msgType, + Payload: payload, + IPC: true, // Explicitly mark as IPC message + } +} + +// SetIPCHandler sets the handler for incoming IPC messages +func (r *Runner) SetIPCHandler(handler func(IPCMessage)) { + r.ipcHandler = handler +} + +// startIPCReader continuously reads IPC messages from the child process's stdout +// Messages are parsed and either handled by the IPC handler or written to logs +func (r *Runner) startIPCReader() { + r.wg.Add(1) + defer r.wg.Done() + + scanner := bufio.NewScanner(r.stdoutPipe) + for { + select { + case <-r.ctx.Done(): + return + default: + if !scanner.Scan() { + return + } + line := scanner.Text() + + var ipcMsg IPCMessage + if err := json.Unmarshal([]byte(line), &ipcMsg); err == nil && ipcMsg.IPC { + // Only handle as IPC if it's valid JSON AND has IPC flag set + if r.ipcHandler != nil { + r.ipcHandler(ipcMsg) + } else { + // Default handler (insert data) + if ipcMsg.Type == "" || ipcMsg.Type == "insert_data" { + r.handleIPCInsertDataMessage(ipcMsg) + } else { + log.Warnf("no IPC handler set for message: %v", ipcMsg) + } + } + } else { + // Everything else is treated as logs + r.writeLogLines([]string{line}) + } + } + } +} + +// handleIPCInsertDataMessage converts the IPC message payload to JSON and sends it to the master node +func (r *Runner) handleIPCInsertDataMessage(ipcMsg IPCMessage) { + // Validate message + if ipcMsg.Payload == nil { + log.Errorf("empty payload in IPC message") + return + } + + // Convert payload to data to be inserted + var records []map[string]interface{} + + switch payload := ipcMsg.Payload.(type) { + case []interface{}: // Handle array of objects + records = make([]map[string]interface{}, 0, len(payload)) + for i, item := range payload { + if itemMap, ok := item.(map[string]interface{}); ok { + records = append(records, itemMap) + } else { + log.Errorf("invalid record at index %d: %v", i, item) + continue + } + } + case []map[string]interface{}: // Handle direct array of maps + records = payload + case map[string]interface{}: // Handle single object + records = []map[string]interface{}{payload} + case interface{}: // Handle generic interface + if itemMap, ok := payload.(map[string]interface{}); ok { + records = []map[string]interface{}{itemMap} + } else { + log.Errorf("invalid payload type: %T", payload) + return + } + default: + log.Errorf("unsupported payload type: %T, value: %v", payload, ipcMsg.Payload) + return + } + + // Validate records + if len(records) == 0 { + log.Warnf("no valid records to insert") + return + } + + // Marshal data with error handling + data, err := json.Marshal(records) + if err != nil { + log.Errorf("error marshaling records: %v", err) + return + } + + // Validate connection + if r.conn == nil { + log.Errorf("gRPC connection not initialized") + return + } + + // Send IPC message to master with context and timeout + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + // Create gRPC message + grpcMsg := &grpc.TaskServiceConnectRequest{ + Code: grpc.TaskServiceConnectCode_INSERT_DATA, + TaskId: r.tid.Hex(), + Data: data, + } + + // Use context for sending + select { + case <-ctx.Done(): + log.Errorf("timeout sending IPC message") + return + default: + if err := r.conn.Send(grpcMsg); err != nil { + log.Errorf("error sending IPC message: %v", err) + return + } + } +} + +// newTaskRunner creates a new task runner instance with the specified task ID +// It initializes all necessary components and establishes required connections func newTaskRunner(id primitive.ObjectID, svc *Service) (r2 *Runner, err error) { // validate options if id.IsZero() { @@ -722,6 +941,10 @@ func newTaskRunner(id primitive.ObjectID, svc *Service) (r2 *Runner, err error) // grpc client r.c = client2.GetGrpcClient() + // Initialize context and done channel + r.ctx, r.cancel = context.WithCancel(context.Background()) + r.done = make(chan struct{}) + // initialize task runner if err := r.Init(); err != nil { return r, err diff --git a/core/task/handler/runner_test.go b/core/task/handler/runner_test.go new file mode 100644 index 00000000..613a8bd7 --- /dev/null +++ b/core/task/handler/runner_test.go @@ -0,0 +1,140 @@ +package handler + +import ( + "encoding/json" + "fmt" + "io" + "os" + "syscall" + "testing" + "time" + + "github.com/crawlab-team/crawlab/core/constants" + "github.com/crawlab-team/crawlab/core/models/models" + "github.com/crawlab-team/crawlab/core/models/service" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.mongodb.org/mongo-driver/bson/primitive" +) + +func setupTest(t *testing.T) *Runner { + // Create a test spider + spider := &models.Spider{ + Name: "Test Spider", + } + spiderId, err := service.NewModelService[models.Spider]().InsertOne(*spider) + require.NoError(t, err) + spider.Id = spiderId + + // Create a test task + task := &models.Task{ + SpiderId: spiderId, + Status: constants.TaskStatusPending, + Type: "test", + Mode: "test", + NodeId: primitive.NewObjectID(), + Cmd: "python script.py", + } + taskId, err := service.NewModelService[models.Task]().InsertOne(*task) + require.NoError(t, err) + task.Id = taskId + + // Create a test runner + svc := newTaskHandlerService() + runner, _ := newTaskRunner(task.Id, svc) + err = runner.updateTask("", nil) + require.Nil(t, err) + _ = runner.Init() + err = runner.configureCmd() + require.Nil(t, err) + + return runner +} + +func TestRunner_HandleIPC(t *testing.T) { + // Setup test data + runner := setupTest(t) + + // Create a pipe for testing + pr, pw := io.Pipe() + runner.stdoutPipe = pr + + // Start IPC reader + go runner.startIPCReader() + + // Create test message + testMsg := IPCMessage{ + Type: "test_type", + Payload: map[string]interface{}{"key": "value"}, + IPC: true, + } + + // Create a channel to signal that the message was handled + handled := make(chan bool) + runner.SetIPCHandler(func(msg IPCMessage) { + assert.Equal(t, testMsg.Type, msg.Type) + assert.Equal(t, testMsg.Payload, msg.Payload) + handled <- true + }) + + // Convert message to JSON and write to pipe + go func() { + jsonData, err := json.Marshal(testMsg) + if err != nil { + t.Errorf("failed to marshal test message: %v", err) + return + } + + // Write message followed by newline + _, err = fmt.Fprintln(pw, string(jsonData)) + if err != nil { + t.Errorf("failed to write to pipe: %v", err) + return + } + }() + + select { + case <-handled: + // Message was handled successfully + case <-time.After(3 * time.Second): + t.Fatal("timeout waiting for IPC message to be handled") + } + + // Clean up + pw.Close() + pr.Close() +} + +func TestRunner_Cancel(t *testing.T) { + // Setup + runner := setupTest(t) + + // Start a long-running command + runner.t.Cmd = "sleep 10" + err := runner.cmd.Start() + assert.NoError(t, err) + runner.pid = runner.cmd.Process.Pid + + // Test cancel + err = runner.Cancel(true) + assert.NoError(t, err) + + // Verify process was killed + // Wait a short time for the process to be killed + time.Sleep(100 * time.Millisecond) + + process, err := os.FindProcess(runner.pid) + require.NoError(t, err) + err = process.Signal(syscall.Signal(0)) + assert.Error(t, err) // Process should not exist +} + +// Helper function to create a temporary workspace for testing +func createTestWorkspace(t *testing.T) string { + dir, err := os.MkdirTemp("", "crawlab-test-*") + assert.NoError(t, err) + t.Cleanup(func() { + os.RemoveAll(dir) + }) + return dir +} diff --git a/docker/base-image/install/python/requirements.txt b/docker/base-image/install/python/requirements.txt index 97902d7d..86646d6d 100644 --- a/docker/base-image/install/python/requirements.txt +++ b/docker/base-image/install/python/requirements.txt @@ -1,7 +1,7 @@ scrapy>=2.9.0 pymongo bs4 -crawlab-sdk>=0.6.0 +crawlab-sdk>=0.7.0rc1 crawlab-demo<=0.1.0 selenium pyopenssl