mirror of https://github.com/crawlab-team/crawlab.git (synced 2026-01-21 17:21:09 +01:00)
feat: added ipc to task runner
@@ -13,6 +13,7 @@ import (
    "path/filepath"
    "strings"
    "sync"
    "syscall"
    "time"

    "github.com/crawlab-team/crawlab/core/models/models"
@@ -34,33 +35,54 @@ import (
    "go.mongodb.org/mongo-driver/bson/primitive"
)

// Runner represents a task execution handler that manages the lifecycle of a running task
type Runner struct {
    // dependencies
    svc *Service // task handler service
    fsSvc interfaces.FsService // task fs service

    // settings
    subscribeTimeout time.Duration
    bufferSize int
    subscribeTimeout time.Duration // maximum time to wait for task subscription
    bufferSize int // buffer size for reading process output

    // internals
    cmd *exec.Cmd // process command instance
    pid int // process id
    tid primitive.ObjectID // task id
    t *models.Task // task model.Task
    s *models.Spider // spider model.Spider
    ch chan constants.TaskSignal // channel to communicate between Service and Runner
    err error // standard process error
    cwd string // working directory
    c *client2.GrpcClient // grpc client
    conn grpc.TaskService_ConnectClient // grpc task service stream client
    t *models.Task // task model instance
    s *models.Spider // spider model instance
    ch chan constants.TaskSignal // channel for task status communication
    err error // captures any process execution errors
    cwd string // current working directory for task
    c *client2.GrpcClient // gRPC client for communication
    conn grpc.TaskService_ConnectClient // gRPC stream connection for task service

    // log internals
    scannerStdout *bufio.Reader
    scannerStderr *bufio.Reader
    logBatchSize int
    // log handling
    scannerStdout *bufio.Reader // reader for process stdout
    scannerStderr *bufio.Reader // reader for process stderr
    logBatchSize int // number of log lines to batch before sending

    // IPC (Inter-Process Communication)
    stdinPipe io.WriteCloser // pipe for writing to child process
    stdoutPipe io.ReadCloser // pipe for reading from child process
    ipcChan chan IPCMessage // channel for sending IPC messages
    ipcHandler func(IPCMessage) // callback for handling received IPC messages

    // goroutine management
    ctx context.Context // context for controlling goroutine lifecycle
    cancel context.CancelFunc // function to cancel the context
    done chan struct{} // channel to signal completion
    wg sync.WaitGroup // wait group for goroutine synchronization
}

// IPCMessage defines the structure for messages exchanged between parent and child processes
type IPCMessage struct {
    Type string `json:"type"` // message type identifier
    Payload interface{} `json:"payload"` // message content
    IPC bool `json:"ipc"` // explicitly marks the message as an IPC message
}
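
For orientation, the sketch below shows what a child process might print to stdout so that the runner's startIPCReader (further down in this diff) treats the line as an IPC message rather than a plain log line: one JSON object per line, with field names matching the json tags above and the ipc flag set to true. The child program itself is hypothetical and not part of this commit.

// Hypothetical child-side sketch (not part of this commit): emit one IPC
// message per line on stdout. Lines that are not valid JSON, or that lack
// "ipc": true, are forwarded to the task log instead.
package main

import (
    "encoding/json"
    "fmt"
    "os"
)

func main() {
    msg := map[string]interface{}{
        "type":    "insert_data",                                  // IPCMessage.Type
        "payload": []map[string]interface{}{{"title": "example"}}, // IPCMessage.Payload
        "ipc":     true,                                           // IPCMessage.IPC
    }
    line, err := json.Marshal(msg)
    if err != nil {
        fmt.Fprintln(os.Stderr, "marshal error:", err)
        return
    }
    fmt.Println(string(line))                    // newline-delimited, read by bufio.Scanner
    fmt.Println("this line is treated as a log") // no ipc flag, so it becomes a log line
}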

// Init initializes the task runner by updating the task status and establishing gRPC connections
func (r *Runner) Init() (err error) {
    // update task
    if err := r.updateTask("", nil); err != nil {
@@ -83,6 +105,8 @@ func (r *Runner) Init() (err error) {
    return nil
}

// Run executes the task and manages its lifecycle, including file synchronization, process execution,
// and status monitoring. Returns an error if the task execution fails.
func (r *Runner) Run() (err error) {
    // log task started
    log.Infof("task[%s] started", r.tid.Hex())
@@ -106,17 +130,11 @@ func (r *Runner) Run() (err error) {
    // configure environment variables
    r.configureEnv()

    // configure logging
    r.configureLogging()

    // start process
    if err := r.cmd.Start(); err != nil {
        return r.updateTask(constants.TaskStatusError, err)
    }

    // start logging
    go r.startLogging()

    // process id
    if r.cmd.Process == nil {
        return r.updateTask(constants.TaskStatusError, constants.ErrNotExists)
@@ -135,6 +153,20 @@ func (r *Runner) Run() (err error) {
    // start health check
    go r.startHealthCheck()

    // Start IPC reader
    go r.startIPCReader()

    // Start IPC handler
    go r.handleIPC()

    // Ensure cleanup when Run() exits
    defer func() {
        r.cancel()       // Cancel context to stop all goroutines
        r.wg.Wait()      // Wait for all goroutines to finish
        close(r.done)    // Signal that everything is done
        close(r.ipcChan) // Close IPC channel
    }()

    // declare task status
    status := ""

@@ -166,8 +198,13 @@ func (r *Runner) Run() (err error) {
    return err
}

// Cancel terminates the running task. If force is true, the process will be killed immediately
// without waiting for graceful shutdown.
func (r *Runner) Cancel(force bool) (err error) {
    // kill process
    // Signal goroutines to stop
    r.cancel()

    // Kill process
    opts := &sys_exec.KillProcessOptions{
        Timeout: r.svc.GetCancelTimeout(),
        Force: force,
@@ -176,18 +213,26 @@ func (r *Runner) Cancel(force bool) (err error) {
        return err
    }

    // make sure the process does not exist
    ticker := time.NewTicker(1 * time.Second)
    timeout := time.After(r.svc.GetCancelTimeout())
    // Wait for process to be killed and goroutines to stop
    ticker := time.NewTicker(time.Second)
    defer ticker.Stop()
    for {
        select {
        case <-timeout:
            return fmt.Errorf("task process %d still exists", r.pid)
        case <-ticker.C:
            if exists, _ := process.PidExists(int32(r.pid)); exists {
                return fmt.Errorf("task process %d still exists", r.pid)
            p, err := os.FindProcess(r.pid)
            if err != nil {
                // process not exists, exit
                return nil
            }
            err = p.Signal(syscall.Signal(0))
            if err == nil {
                // process still exists, continue
                continue
            }
            return nil
        case <-time.After(r.svc.GetCancelTimeout()):
            // timeout
            return fmt.Errorf("timeout waiting for task to stop")
        }
    }
}
@@ -200,6 +245,8 @@ func (r *Runner) GetTaskId() (id primitive.ObjectID) {
    return r.tid
}

// configureCmd builds and configures the command to be executed, including setting up IPC pipes
// and processing command parameters
func (r *Runner) configureCmd() (err error) {
    var cmdStr string

@@ -221,71 +268,62 @@ func (r *Runner) configureCmd() (err error) {
    r.cmd, err = sys_exec.BuildCmd(cmdStr)
    if err != nil {
        log.Errorf("error building command: %v", err)
        trace.PrintError(err)
        return err
    }

    // set working directory
    r.cmd.Dir = r.cwd

    // Configure pipes for IPC
    r.stdinPipe, err = r.cmd.StdinPipe()
    if err != nil {
        return trace.TraceError(err)
    }

    // Add stdout pipe for IPC
    r.stdoutPipe, err = r.cmd.StdoutPipe()
    if err != nil {
        return trace.TraceError(err)
    }

    // Initialize IPC channel
    r.ipcChan = make(chan IPCMessage)

    return nil
}

func (r *Runner) configureLogging() {
    // set stdout reader
    stdout, _ := r.cmd.StdoutPipe()
    r.scannerStdout = bufio.NewReaderSize(stdout, r.bufferSize)

    // set stderr reader
    stderr, _ := r.cmd.StderrPipe()
    r.scannerStderr = bufio.NewReaderSize(stderr, r.bufferSize)
}

func (r *Runner) startLogging() {
    // start reading stdout
    go r.startLoggingReaderStdout()

    // start reading stderr
    go r.startLoggingReaderStderr()
}

func (r *Runner) startLoggingReaderStdout() {
    for {
        line, err := r.scannerStdout.ReadString(byte('\n'))
        if err != nil {
            break
        }
        line = strings.TrimSuffix(line, "\n")
        r.writeLogLines([]string{line})
    }
}

func (r *Runner) startLoggingReaderStderr() {
    for {
        line, err := r.scannerStderr.ReadString(byte('\n'))
        if err != nil {
            break
        }
        line = strings.TrimSuffix(line, "\n")
        r.writeLogLines([]string{line})
    }
}

// startHealthCheck periodically verifies that the process is still running
// If the process disappears unexpectedly, it signals a task lost condition
func (r *Runner) startHealthCheck() {
    r.wg.Add(1)
    defer r.wg.Done()

    if r.cmd.ProcessState == nil || r.cmd.ProcessState.Exited() {
        return
    }

    ticker := time.NewTicker(1 * time.Second)
    defer ticker.Stop()

    for {
        exists, _ := process.PidExists(int32(r.pid))
        if !exists {
            // process lost
            r.ch <- constants.TaskSignalLost
        select {
        case <-r.ctx.Done():
            return
        case <-ticker.C:
            exists, _ := process.PidExists(int32(r.pid))
            if !exists {
                // process lost
                r.ch <- constants.TaskSignalLost
                return
            }
        }
        time.Sleep(1 * time.Second)
    }
}

// configureEnv sets up the environment variables for the task process, including:
// - Node.js paths
// - Crawlab-specific variables
// - Global environment variables from the system
func (r *Runner) configureEnv() {
    // By default, add Node.js's global node_modules to PATH
    envPath := os.Getenv("PATH")
@@ -312,6 +350,9 @@ func (r *Runner) configureEnv() {
    for _, env := range envs {
        r.cmd.Env = append(r.cmd.Env, env.Key+"="+env.Value)
    }

    // Add environment variable for child processes to identify they're running under Crawlab
    r.cmd.Env = append(r.cmd.Env, "CRAWLAB_PARENT_PID="+fmt.Sprint(os.Getpid()))
}
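
As a companion sketch (hypothetical, not part of this commit), a child process could use the CRAWLAB_PARENT_PID variable set above to detect that it was launched by a Crawlab task runner:

// Hypothetical child-side check for the CRAWLAB_PARENT_PID environment variable.
package main

import (
    "fmt"
    "os"
    "strconv"
)

func main() {
    v, ok := os.LookupEnv("CRAWLAB_PARENT_PID")
    if !ok {
        fmt.Println("not running under Crawlab")
        return
    }
    if ppid, err := strconv.Atoi(v); err == nil {
        fmt.Printf("running under Crawlab, parent pid %d\n", ppid)
    }
}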

func (r *Runner) createHttpRequest(method, path string) (*http.Response, error) {
@@ -337,6 +378,11 @@ func (r *Runner) createHttpRequest(method, path string) (*http.Response, error)
    return http.DefaultClient.Do(req)
}

// syncFiles synchronizes files between master and worker nodes:
// 1. Gets file list from master
// 2. Compares with local files
// 3. Downloads new/modified files
// 4. Deletes files that no longer exist on master
func (r *Runner) syncFiles() (err error) {
    workingDir := ""
    if !r.s.GitId.IsZero() {
@@ -440,6 +486,7 @@ func (r *Runner) syncFiles() (err error) {
    return err
}

// downloadFile downloads a file from the master node and saves it to the local file system
func (r *Runner) downloadFile(path string, filePath string, fileInfo *entity.FsFileInfo) error {
    resp, err := r.createHttpRequest("GET", "/download?path="+path)
    if err != nil {
@@ -478,14 +525,17 @@ func (r *Runner) downloadFile(path string, filePath string, fileInfo *entity.FsF
    return nil
}

// getHttpRequestHeaders returns the headers for HTTP requests to the master node
func (r *Runner) getHttpRequestHeaders() (headers map[string]string) {
    headers = make(map[string]string)
    headers["Authorization"] = utils.GetAuthKey()
    return headers
}

// wait for process to finish and send task signal (constants.TaskSignal)
// to task runner's channel (Runner.ch) according to exit code
// wait monitors the process execution and sends appropriate signals based on the exit status:
// - TaskSignalFinish for successful completion
// - TaskSignalCancel for cancellation
// - TaskSignalError for execution errors
func (r *Runner) wait() {
    // wait for process to finish
    if err := r.cmd.Wait(); err != nil {
@@ -512,7 +562,8 @@ func (r *Runner) wait() {
    r.ch <- constants.TaskSignalFinish
}

// updateTask update and get updated info of task (Runner.t)
// updateTask updates the task status and related statistics in the database
// If running on a worker node, updates are sent to the master
func (r *Runner) updateTask(status string, e error) (err error) {
    if r.t != nil && status != "" {
        // update task status
@@ -549,14 +600,17 @@ func (r *Runner) updateTask(status string, e error) (err error) {
    return nil
}

// initConnection establishes a gRPC connection to the task service
func (r *Runner) initConnection() (err error) {
    r.conn, err = r.c.TaskClient.Connect(context.Background())
    if err != nil {
        return trace.TraceError(err)
        log.Errorf("error connecting to task service: %v", err)
        return err
    }
    return nil
}

// writeLogLines marshals log lines to JSON and sends them to the task service
func (r *Runner) writeLogLines(lines []string) {
    linesBytes, err := json.Marshal(lines)
    if err != nil {
@@ -574,6 +628,9 @@ func (r *Runner) writeLogLines(lines []string) {
    }
}

// _updateTaskStat updates task statistics based on the current status:
// - For running tasks: sets start time and wait duration
// - For completed tasks: sets end time and calculates durations
func (r *Runner) _updateTaskStat(status string) {
    ts, err := client.NewModelService[models.TaskStat]().GetById(r.tid)
    if err != nil {
@@ -610,6 +667,7 @@ func (r *Runner) _updateTaskStat(status string) {
    }
}

// sendNotification sends a notification to the task service
func (r *Runner) sendNotification() {
    req := &grpc.TaskServiceSendNotificationRequest{
        NodeKey: r.svc.GetNodeConfigService().GetNodeKey(),
@@ -623,6 +681,10 @@ func (r *Runner) sendNotification() {
    }
}

// _updateSpiderStat updates spider statistics based on task completion:
// - Updates last task ID
// - Increments task counts
// - Updates duration metrics
func (r *Runner) _updateSpiderStat(status string) {
    // task stat
    ts, err := client.NewModelService[models.TaskStat]().GetById(r.tid)
@@ -677,6 +739,7 @@ func (r *Runner) _updateSpiderStat(status string) {
    }
}

// configureCwd sets the working directory for the task based on the spider's configuration
func (r *Runner) configureCwd() {
    workspacePath := utils.GetWorkspace()
    if r.s.GitId.IsZero() {
@@ -688,6 +751,162 @@ func (r *Runner) configureCwd() {
    }
}

// handleIPC forwards queued IPC messages to the child process
// Messages are converted to JSON and written to the child process's stdin
func (r *Runner) handleIPC() {
    for msg := range r.ipcChan {
        // Convert message to JSON
        jsonData, err := json.Marshal(msg)
        if err != nil {
            log.Errorf("error marshaling IPC message: %v", err)
            continue
        }

        // Write to child process's stdin
        _, err = fmt.Fprintln(r.stdinPipe, string(jsonData))
        if err != nil {
            log.Errorf("error writing to child process: %v", err)
        }
    }
}

// SendToChild sends a message to the child process through the IPC channel
// msgType: type of message being sent
// payload: data to be sent to the child process
func (r *Runner) SendToChild(msgType string, payload interface{}) {
    r.ipcChan <- IPCMessage{
        Type: msgType,
        Payload: payload,
        IPC: true, // Explicitly mark as IPC message
    }
}

// SetIPCHandler sets the handler for incoming IPC messages
func (r *Runner) SetIPCHandler(handler func(IPCMessage)) {
    r.ipcHandler = handler
}
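
A minimal sketch of how calling code might use these two methods together, assuming r is an initialized *Runner whose Run() is in progress; the "pause" message type and its payload are illustrative and not defined by this commit:

// Illustrative wiring inside the handler package (assumptions noted above).
func wireIPC(r *Runner) {
    // Receive messages emitted by the child process on stdout.
    r.SetIPCHandler(func(msg IPCMessage) {
        log.Infof("IPC message from child: type=%s payload=%v", msg.Type, msg.Payload)
    })

    // Queue a message for the child; handleIPC JSON-encodes it and writes it
    // to the child's stdin as a single newline-terminated line.
    r.SendToChild("pause", map[string]interface{}{"reason": "maintenance"})
}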

// startIPCReader continuously reads IPC messages from the child process's stdout
// Messages are parsed and either handled by the IPC handler or written to logs
func (r *Runner) startIPCReader() {
    r.wg.Add(1)
    defer r.wg.Done()

    scanner := bufio.NewScanner(r.stdoutPipe)
    for {
        select {
        case <-r.ctx.Done():
            return
        default:
            if !scanner.Scan() {
                return
            }
            line := scanner.Text()

            var ipcMsg IPCMessage
            if err := json.Unmarshal([]byte(line), &ipcMsg); err == nil && ipcMsg.IPC {
                // Only handle as IPC if it's valid JSON AND has IPC flag set
                if r.ipcHandler != nil {
                    r.ipcHandler(ipcMsg)
                } else {
                    // Default handler (insert data)
                    if ipcMsg.Type == "" || ipcMsg.Type == "insert_data" {
                        r.handleIPCInsertDataMessage(ipcMsg)
                    } else {
                        log.Warnf("no IPC handler set for message: %v", ipcMsg)
                    }
                }
            } else {
                // Everything else is treated as logs
                r.writeLogLines([]string{line})
            }
        }
    }
}

// handleIPCInsertDataMessage converts the IPC message payload to JSON and sends it to the master node
func (r *Runner) handleIPCInsertDataMessage(ipcMsg IPCMessage) {
    // Validate message
    if ipcMsg.Payload == nil {
        log.Errorf("empty payload in IPC message")
        return
    }

    // Convert payload to data to be inserted
    var records []map[string]interface{}

    switch payload := ipcMsg.Payload.(type) {
    case []interface{}: // Handle array of objects
        records = make([]map[string]interface{}, 0, len(payload))
        for i, item := range payload {
            if itemMap, ok := item.(map[string]interface{}); ok {
                records = append(records, itemMap)
            } else {
                log.Errorf("invalid record at index %d: %v", i, item)
                continue
            }
        }
    case []map[string]interface{}: // Handle direct array of maps
        records = payload
    case map[string]interface{}: // Handle single object
        records = []map[string]interface{}{payload}
    case interface{}: // Handle generic interface
        if itemMap, ok := payload.(map[string]interface{}); ok {
            records = []map[string]interface{}{itemMap}
        } else {
            log.Errorf("invalid payload type: %T", payload)
            return
        }
    default:
        log.Errorf("unsupported payload type: %T, value: %v", payload, ipcMsg.Payload)
        return
    }

    // Validate records
    if len(records) == 0 {
        log.Warnf("no valid records to insert")
        return
    }

    // Marshal data with error handling
    data, err := json.Marshal(records)
    if err != nil {
        log.Errorf("error marshaling records: %v", err)
        return
    }

    // Validate connection
    if r.conn == nil {
        log.Errorf("gRPC connection not initialized")
        return
    }

    // Send IPC message to master with context and timeout
    ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
    defer cancel()

    // Create gRPC message
    grpcMsg := &grpc.TaskServiceConnectRequest{
        Code: grpc.TaskServiceConnectCode_INSERT_DATA,
        TaskId: r.tid.Hex(),
        Data: data,
    }

    // Use context for sending
    select {
    case <-ctx.Done():
        log.Errorf("timeout sending IPC message")
        return
    default:
        if err := r.conn.Send(grpcMsg); err != nil {
            log.Errorf("error sending IPC message: %v", err)
            return
        }
    }
}
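
To make the accepted payload shapes concrete, either of the following stdout lines (illustrative data, not from this commit) would reach this function when no custom IPC handler is set: an array of objects or a single object, both forwarded to the master as an INSERT_DATA request.

{"type":"insert_data","payload":[{"title":"item 1"},{"title":"item 2"}],"ipc":true}
{"type":"insert_data","payload":{"title":"single item"},"ipc":true}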

// newTaskRunner creates a new task runner instance with the specified task ID
// It initializes all necessary components and establishes required connections
func newTaskRunner(id primitive.ObjectID, svc *Service) (r2 *Runner, err error) {
    // validate options
    if id.IsZero() {
@@ -722,6 +941,10 @@ func newTaskRunner(id primitive.ObjectID, svc *Service) (r2 *Runner, err error)
    // grpc client
    r.c = client2.GetGrpcClient()

    // Initialize context and done channel
    r.ctx, r.cancel = context.WithCancel(context.Background())
    r.done = make(chan struct{})

    // initialize task runner
    if err := r.Init(); err != nil {
        return r, err

core/task/handler/runner_test.go (new file, 140 lines)
@@ -0,0 +1,140 @@
package handler

import (
    "encoding/json"
    "fmt"
    "io"
    "os"
    "syscall"
    "testing"
    "time"

    "github.com/crawlab-team/crawlab/core/constants"
    "github.com/crawlab-team/crawlab/core/models/models"
    "github.com/crawlab-team/crawlab/core/models/service"
    "github.com/stretchr/testify/assert"
    "github.com/stretchr/testify/require"
    "go.mongodb.org/mongo-driver/bson/primitive"
)

func setupTest(t *testing.T) *Runner {
    // Create a test spider
    spider := &models.Spider{
        Name: "Test Spider",
    }
    spiderId, err := service.NewModelService[models.Spider]().InsertOne(*spider)
    require.NoError(t, err)
    spider.Id = spiderId

    // Create a test task
    task := &models.Task{
        SpiderId: spiderId,
        Status: constants.TaskStatusPending,
        Type: "test",
        Mode: "test",
        NodeId: primitive.NewObjectID(),
        Cmd: "python script.py",
    }
    taskId, err := service.NewModelService[models.Task]().InsertOne(*task)
    require.NoError(t, err)
    task.Id = taskId

    // Create a test runner
    svc := newTaskHandlerService()
    runner, _ := newTaskRunner(task.Id, svc)
    err = runner.updateTask("", nil)
    require.Nil(t, err)
    _ = runner.Init()
    err = runner.configureCmd()
    require.Nil(t, err)

    return runner
}

func TestRunner_HandleIPC(t *testing.T) {
    // Setup test data
    runner := setupTest(t)

    // Create a pipe for testing
    pr, pw := io.Pipe()
    runner.stdoutPipe = pr

    // Start IPC reader
    go runner.startIPCReader()

    // Create test message
    testMsg := IPCMessage{
        Type: "test_type",
        Payload: map[string]interface{}{"key": "value"},
        IPC: true,
    }

    // Create a channel to signal that the message was handled
    handled := make(chan bool)
    runner.SetIPCHandler(func(msg IPCMessage) {
        assert.Equal(t, testMsg.Type, msg.Type)
        assert.Equal(t, testMsg.Payload, msg.Payload)
        handled <- true
    })

    // Convert message to JSON and write to pipe
    go func() {
        jsonData, err := json.Marshal(testMsg)
        if err != nil {
            t.Errorf("failed to marshal test message: %v", err)
            return
        }

        // Write message followed by newline
        _, err = fmt.Fprintln(pw, string(jsonData))
        if err != nil {
            t.Errorf("failed to write to pipe: %v", err)
            return
        }
    }()

    select {
    case <-handled:
        // Message was handled successfully
    case <-time.After(3 * time.Second):
        t.Fatal("timeout waiting for IPC message to be handled")
    }

    // Clean up
    pw.Close()
    pr.Close()
}

func TestRunner_Cancel(t *testing.T) {
    // Setup
    runner := setupTest(t)

    // Start a long-running command
    runner.t.Cmd = "sleep 10"
    err := runner.cmd.Start()
    assert.NoError(t, err)
    runner.pid = runner.cmd.Process.Pid

    // Test cancel
    err = runner.Cancel(true)
    assert.NoError(t, err)

    // Verify process was killed
    // Wait a short time for the process to be killed
    time.Sleep(100 * time.Millisecond)

    process, err := os.FindProcess(runner.pid)
    require.NoError(t, err)
    err = process.Signal(syscall.Signal(0))
    assert.Error(t, err) // Process should not exist
}

// Helper function to create a temporary workspace for testing
func createTestWorkspace(t *testing.T) string {
    dir, err := os.MkdirTemp("", "crawlab-test-*")
    assert.NoError(t, err)
    t.Cleanup(func() {
        os.RemoveAll(dir)
    })
    return dir
}
@@ -1,7 +1,7 @@
scrapy>=2.9.0
pymongo
bs4
crawlab-sdk>=0.6.0
crawlab-sdk>=0.7.0rc1
crawlab-demo<=0.1.0
selenium
pyopenssl