mirror of
https://github.com/crawlab-team/crawlab.git
synced 2026-01-21 17:21:09 +01:00
feat: implement zombie process prevention and cleanup mechanisms in task runner
This commit is contained in:
@@ -15,11 +15,13 @@ import (
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/crawlab-team/crawlab/core/dependency"
|
||||
"github.com/crawlab-team/crawlab/core/fs"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/shirou/gopsutil/process"
|
||||
|
||||
"github.com/crawlab-team/crawlab/core/models/models"
|
||||
|
||||
@@ -219,6 +221,9 @@ func (r *Runner) Run() (err error) {
|
||||
// Start IPC handler
|
||||
go r.handleIPC()
|
||||
|
||||
// ZOMBIE PREVENTION: Start zombie process monitor
|
||||
go r.startZombieMonitor()
|
||||
|
||||
// Ensure cleanup when Run() exits
|
||||
defer func() {
|
||||
// 1. Signal all goroutines to stop
|
||||
@@ -336,6 +341,15 @@ func (r *Runner) configureCmd() (err error) {
|
||||
// set working directory
|
||||
r.cmd.Dir = r.cwd
|
||||
|
||||
// ZOMBIE PREVENTION: Set process group to enable proper cleanup of child processes
|
||||
if runtime.GOOS != "windows" {
|
||||
// Create new process group on Unix systems to ensure child processes can be killed together
|
||||
r.cmd.SysProcAttr = &syscall.SysProcAttr{
|
||||
Setpgid: true, // Create new process group
|
||||
Pgid: 0, // Use process ID as process group ID
|
||||
}
|
||||
}
|
||||
|
||||
// Configure pipes for IPC and logs
|
||||
r.stdinPipe, err = r.cmd.StdinPipe()
|
||||
if err != nil {
|
||||
@@ -727,6 +741,8 @@ func (r *Runner) wait() (err error) {
|
||||
case constants.TaskSignalLost:
|
||||
err = constants.ErrTaskLost
|
||||
status = constants.TaskStatusError
|
||||
// ZOMBIE PREVENTION: Clean up any remaining processes when task is lost
|
||||
go r.cleanupOrphanedProcesses()
|
||||
default:
|
||||
err = constants.ErrInvalidSignal
|
||||
status = constants.TaskStatusError
|
||||
@@ -1492,3 +1508,151 @@ func (r *Runner) GetConnectionStats() map[string]interface{} {
|
||||
"connection_exists": r.conn != nil,
|
||||
}
|
||||
}
|
||||
|
||||
// ZOMBIE PROCESS PREVENTION METHODS
|
||||
|
||||
// cleanupOrphanedProcesses attempts to clean up any orphaned processes related to this task
|
||||
func (r *Runner) cleanupOrphanedProcesses() {
|
||||
r.Warnf("cleaning up orphaned processes for task %s (PID: %d)", r.tid.Hex(), r.pid)
|
||||
|
||||
if r.pid <= 0 {
|
||||
r.Debugf("no PID to clean up")
|
||||
return
|
||||
}
|
||||
|
||||
// Try to kill the process group if it exists
|
||||
if runtime.GOOS != "windows" {
|
||||
r.killProcessGroup()
|
||||
}
|
||||
|
||||
// Force kill the main process if it still exists
|
||||
if utils.ProcessIdExists(r.pid) {
|
||||
r.Warnf("forcefully killing remaining process %d", r.pid)
|
||||
if r.cmd != nil && r.cmd.Process != nil {
|
||||
if err := utils.KillProcess(r.cmd, true); err != nil {
|
||||
r.Errorf("failed to force kill process: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Scan for any remaining child processes and kill them
|
||||
r.scanAndKillChildProcesses()
|
||||
}
|
||||
|
||||
// killProcessGroup kills the entire process group on Unix systems
|
||||
func (r *Runner) killProcessGroup() {
|
||||
if r.pid <= 0 {
|
||||
return
|
||||
}
|
||||
|
||||
r.Debugf("attempting to kill process group for PID %d", r.pid)
|
||||
|
||||
// Kill the process group (negative PID kills the group)
|
||||
err := syscall.Kill(-r.pid, syscall.SIGTERM)
|
||||
if err != nil {
|
||||
r.Debugf("failed to send SIGTERM to process group: %v", err)
|
||||
// Try SIGKILL as last resort
|
||||
err = syscall.Kill(-r.pid, syscall.SIGKILL)
|
||||
if err != nil {
|
||||
r.Debugf("failed to send SIGKILL to process group: %v", err)
|
||||
}
|
||||
} else {
|
||||
r.Debugf("successfully sent SIGTERM to process group %d", r.pid)
|
||||
}
|
||||
}
|
||||
|
||||
// scanAndKillChildProcesses scans for and kills any remaining child processes
|
||||
func (r *Runner) scanAndKillChildProcesses() {
|
||||
r.Debugf("scanning for orphaned child processes of task %s", r.tid.Hex())
|
||||
|
||||
processes, err := utils.GetProcesses()
|
||||
if err != nil {
|
||||
r.Errorf("failed to get process list: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
taskIdEnv := "CRAWLAB_TASK_ID=" + r.tid.Hex()
|
||||
killedCount := 0
|
||||
|
||||
for _, proc := range processes {
|
||||
// Check if this process has our task ID in its environment
|
||||
if r.isTaskRelatedProcess(proc, taskIdEnv) {
|
||||
pid := int(proc.Pid)
|
||||
r.Warnf("found orphaned task process PID %d, killing it", pid)
|
||||
|
||||
// Kill the orphaned process
|
||||
if err := proc.Kill(); err != nil {
|
||||
r.Errorf("failed to kill orphaned process %d: %v", pid, err)
|
||||
} else {
|
||||
killedCount++
|
||||
r.Infof("successfully killed orphaned process %d", pid)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if killedCount > 0 {
|
||||
r.Infof("cleaned up %d orphaned processes for task %s", killedCount, r.tid.Hex())
|
||||
} else {
|
||||
r.Debugf("no orphaned processes found for task %s", r.tid.Hex())
|
||||
}
|
||||
}
|
||||
|
||||
// isTaskRelatedProcess checks if a process is related to this task
|
||||
func (r *Runner) isTaskRelatedProcess(proc *process.Process, taskIdEnv string) bool {
|
||||
// Get process environment variables
|
||||
environ, err := proc.Environ()
|
||||
if err != nil {
|
||||
// If we can't read environment, skip this process
|
||||
return false
|
||||
}
|
||||
|
||||
// Check if this process has our task ID
|
||||
for _, env := range environ {
|
||||
if env == taskIdEnv {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// startZombieMonitor starts a background goroutine to monitor for zombie processes
|
||||
func (r *Runner) startZombieMonitor() {
|
||||
r.wg.Add(1)
|
||||
go func() {
|
||||
defer r.wg.Done()
|
||||
|
||||
// Check for zombies every 5 minutes
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-r.ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
r.checkForZombieProcesses()
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// checkForZombieProcesses periodically checks for and cleans up zombie processes
|
||||
func (r *Runner) checkForZombieProcesses() {
|
||||
r.Debugf("checking for zombie processes related to task %s", r.tid.Hex())
|
||||
|
||||
// Check if our main process still exists and is in the expected state
|
||||
if r.pid > 0 && utils.ProcessIdExists(r.pid) {
|
||||
// Process exists, check if it's a zombie
|
||||
if proc, err := process.NewProcess(int32(r.pid)); err == nil {
|
||||
if status, err := proc.Status(); err == nil {
|
||||
// Status returns a string, check if it indicates zombie
|
||||
statusStr := string(status)
|
||||
if statusStr == "Z" || statusStr == "zombie" {
|
||||
r.Warnf("detected zombie process %d for task %s", r.pid, r.tid.Hex())
|
||||
go r.cleanupOrphanedProcesses()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,6 +4,10 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/crawlab-team/crawlab/core/constants"
|
||||
grpcclient "github.com/crawlab-team/crawlab/core/grpc/client"
|
||||
"github.com/crawlab-team/crawlab/core/interfaces"
|
||||
@@ -15,9 +19,6 @@ import (
|
||||
"github.com/crawlab-team/crawlab/grpc"
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
"io"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Service struct {
|
||||
@@ -32,23 +33,82 @@ type Service struct {
|
||||
cancelTimeout time.Duration
|
||||
|
||||
// internals variables
|
||||
stopped bool
|
||||
mu sync.Mutex
|
||||
runners sync.Map // pool of task runners started
|
||||
syncLocks sync.Map // files sync locks map of task runners
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
stopped bool
|
||||
mu sync.RWMutex
|
||||
runners sync.Map // pool of task runners started
|
||||
wg sync.WaitGroup // track background goroutines
|
||||
|
||||
// tickers for cleanup
|
||||
fetchTicker *time.Ticker
|
||||
reportTicker *time.Ticker
|
||||
|
||||
interfaces.Logger
|
||||
}
|
||||
|
||||
func (svc *Service) Start() {
|
||||
// Initialize context for graceful shutdown
|
||||
svc.ctx, svc.cancel = context.WithCancel(context.Background())
|
||||
|
||||
// wait for grpc client ready
|
||||
grpcclient.GetGrpcClient().WaitForReady()
|
||||
|
||||
// Initialize tickers
|
||||
svc.fetchTicker = time.NewTicker(svc.fetchInterval)
|
||||
svc.reportTicker = time.NewTicker(svc.reportInterval)
|
||||
|
||||
// Start background goroutines with WaitGroup tracking
|
||||
svc.wg.Add(2)
|
||||
go svc.reportStatus()
|
||||
go svc.fetchAndRunTasks()
|
||||
|
||||
svc.Infof("Task handler service started")
|
||||
}
|
||||
|
||||
func (svc *Service) Stop() {
|
||||
svc.mu.Lock()
|
||||
if svc.stopped {
|
||||
svc.mu.Unlock()
|
||||
return
|
||||
}
|
||||
svc.stopped = true
|
||||
svc.mu.Unlock()
|
||||
|
||||
svc.Infof("Stopping task handler service...")
|
||||
|
||||
// Cancel context to signal all goroutines to stop
|
||||
if svc.cancel != nil {
|
||||
svc.cancel()
|
||||
}
|
||||
|
||||
// Stop tickers to prevent new tasks
|
||||
if svc.fetchTicker != nil {
|
||||
svc.fetchTicker.Stop()
|
||||
}
|
||||
if svc.reportTicker != nil {
|
||||
svc.reportTicker.Stop()
|
||||
}
|
||||
|
||||
// Cancel all running tasks gracefully
|
||||
svc.stopAllRunners()
|
||||
|
||||
// Wait for all background goroutines to finish
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
svc.wg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
// Give goroutines time to finish gracefully, then force stop
|
||||
select {
|
||||
case <-done:
|
||||
svc.Infof("All goroutines stopped gracefully")
|
||||
case <-time.After(30 * time.Second):
|
||||
svc.Warnf("Some goroutines did not stop gracefully within timeout")
|
||||
}
|
||||
|
||||
svc.Infof("Task handler service stopped")
|
||||
}
|
||||
|
||||
func (svc *Service) Run(taskId primitive.ObjectID) (err error) {
|
||||
@@ -60,67 +120,95 @@ func (svc *Service) Cancel(taskId primitive.ObjectID, force bool) (err error) {
|
||||
}
|
||||
|
||||
func (svc *Service) fetchAndRunTasks() {
|
||||
ticker := time.NewTicker(svc.fetchInterval)
|
||||
for {
|
||||
if svc.stopped {
|
||||
return
|
||||
defer svc.wg.Done()
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
svc.Errorf("fetchAndRunTasks panic recovered: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
// current node
|
||||
n, err := svc.GetCurrentNode()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// skip if node is not active or enabled
|
||||
if !n.Active || !n.Enabled {
|
||||
continue
|
||||
}
|
||||
|
||||
// validate if max runners is reached (max runners = 0 means no limit)
|
||||
if n.MaxRunners > 0 && svc.getRunnerCount() >= n.MaxRunners {
|
||||
continue
|
||||
}
|
||||
|
||||
// fetch task id
|
||||
tid, err := svc.fetchTask()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// skip if no task id
|
||||
if tid.IsZero() {
|
||||
continue
|
||||
}
|
||||
|
||||
// run task
|
||||
if err := svc.runTask(tid); err != nil {
|
||||
t, err := svc.GetTaskById(tid)
|
||||
if err != nil && t.Status != constants.TaskStatusCancelled {
|
||||
t.Error = err.Error()
|
||||
t.Status = constants.TaskStatusError
|
||||
t.SetUpdated(t.CreatedBy)
|
||||
_ = client.NewModelService[models.Task]().ReplaceById(t.Id, *t)
|
||||
continue
|
||||
}
|
||||
continue
|
||||
case <-svc.ctx.Done():
|
||||
svc.Infof("fetchAndRunTasks stopped by context")
|
||||
return
|
||||
case <-svc.fetchTicker.C:
|
||||
// Use a separate context with timeout for each operation
|
||||
if err := svc.processFetchCycle(); err != nil {
|
||||
svc.Debugf("fetch cycle error: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (svc *Service) reportStatus() {
|
||||
ticker := time.NewTicker(svc.reportInterval)
|
||||
for {
|
||||
if svc.stopped {
|
||||
return
|
||||
}
|
||||
func (svc *Service) processFetchCycle() error {
|
||||
// Check if stopped
|
||||
svc.mu.RLock()
|
||||
stopped := svc.stopped
|
||||
svc.mu.RUnlock()
|
||||
|
||||
if stopped {
|
||||
return fmt.Errorf("service stopped")
|
||||
}
|
||||
|
||||
// current node
|
||||
n, err := svc.GetCurrentNode()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get current node: %w", err)
|
||||
}
|
||||
|
||||
// skip if node is not active or enabled
|
||||
if !n.Active || !n.Enabled {
|
||||
return fmt.Errorf("node not active or enabled")
|
||||
}
|
||||
|
||||
// validate if max runners is reached (max runners = 0 means no limit)
|
||||
if n.MaxRunners > 0 && svc.getRunnerCount() >= n.MaxRunners {
|
||||
return fmt.Errorf("max runners reached")
|
||||
}
|
||||
|
||||
// fetch task id
|
||||
tid, err := svc.fetchTask()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to fetch task: %w", err)
|
||||
}
|
||||
|
||||
// skip if no task id
|
||||
if tid.IsZero() {
|
||||
return fmt.Errorf("no task available")
|
||||
}
|
||||
|
||||
// run task
|
||||
if err := svc.runTask(tid); err != nil {
|
||||
// Handle task error
|
||||
t, getErr := svc.GetTaskById(tid)
|
||||
if getErr == nil && t.Status != constants.TaskStatusCancelled {
|
||||
t.Error = err.Error()
|
||||
t.Status = constants.TaskStatusError
|
||||
t.SetUpdated(t.CreatedBy)
|
||||
_ = client.NewModelService[models.Task]().ReplaceById(t.Id, *t)
|
||||
}
|
||||
return fmt.Errorf("failed to run task: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (svc *Service) reportStatus() {
|
||||
defer svc.wg.Done()
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
svc.Errorf("reportStatus panic recovered: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
// update node status
|
||||
case <-svc.ctx.Done():
|
||||
svc.Infof("reportStatus stopped by context")
|
||||
return
|
||||
case <-svc.reportTicker.C:
|
||||
// Update node status with error handling
|
||||
if err := svc.updateNodeStatus(); err != nil {
|
||||
svc.Errorf("failed to report status: %v", err)
|
||||
}
|
||||
@@ -230,9 +318,9 @@ func (svc *Service) getRunner(taskId primitive.ObjectID) (r interfaces.TaskRunne
|
||||
svc.Errorf("get runner error: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
switch v.(type) {
|
||||
switch v := v.(type) {
|
||||
case interfaces.TaskRunner:
|
||||
r = v.(interfaces.TaskRunner)
|
||||
r = v
|
||||
default:
|
||||
err = fmt.Errorf("invalid type: %T", v)
|
||||
svc.Errorf("get runner error: %v", err)
|
||||
@@ -312,16 +400,28 @@ func (svc *Service) runTask(taskId primitive.ObjectID) (err error) {
|
||||
// add runner to pool
|
||||
svc.addRunner(taskId, r)
|
||||
|
||||
// create a goroutine to run task
|
||||
// create a goroutine to run task with proper cleanup
|
||||
go func() {
|
||||
// get subscription stream
|
||||
stopCh := make(chan struct{})
|
||||
stream, err := svc.subscribeTask(r.GetTaskId())
|
||||
defer func() {
|
||||
if rec := recover(); rec != nil {
|
||||
svc.Errorf("task[%s] panic recovered: %v", taskId.Hex(), rec)
|
||||
}
|
||||
// Always cleanup runner from pool
|
||||
svc.deleteRunner(taskId)
|
||||
}()
|
||||
|
||||
// Create task-specific context for better cancellation control
|
||||
taskCtx, taskCancel := context.WithCancel(svc.ctx)
|
||||
defer taskCancel()
|
||||
|
||||
// get subscription stream with retry logic
|
||||
stopCh := make(chan struct{}, 1)
|
||||
stream, err := svc.subscribeTaskWithRetry(taskCtx, r.GetTaskId(), 3)
|
||||
if err == nil {
|
||||
// create a goroutine to handle stream messages
|
||||
go svc.handleStreamMessages(r.GetTaskId(), stream, stopCh)
|
||||
} else {
|
||||
svc.Errorf("failed to subscribe task[%s]: %v", r.GetTaskId().Hex(), err)
|
||||
svc.Errorf("failed to subscribe task[%s] after retries: %v", r.GetTaskId().Hex(), err)
|
||||
svc.Warnf("task[%s] will not be able to receive stream messages", r.GetTaskId().Hex())
|
||||
}
|
||||
|
||||
@@ -331,23 +431,26 @@ func (svc *Service) runTask(taskId primitive.ObjectID) (err error) {
|
||||
case errors.Is(err, constants.ErrTaskError):
|
||||
svc.Errorf("task[%s] finished with error: %v", r.GetTaskId().Hex(), err)
|
||||
case errors.Is(err, constants.ErrTaskCancelled):
|
||||
svc.Errorf("task[%s] cancelled", r.GetTaskId().Hex())
|
||||
svc.Infof("task[%s] cancelled", r.GetTaskId().Hex())
|
||||
default:
|
||||
svc.Errorf("task[%s] finished with unknown error: %v", r.GetTaskId().Hex(), err)
|
||||
}
|
||||
} else {
|
||||
svc.Infof("task[%s] finished successfully", r.GetTaskId().Hex())
|
||||
}
|
||||
svc.Infof("task[%s] finished", r.GetTaskId().Hex())
|
||||
|
||||
// send stopCh signal to stream message handler
|
||||
stopCh <- struct{}{}
|
||||
|
||||
// delete runner from pool
|
||||
svc.deleteRunner(r.GetTaskId())
|
||||
select {
|
||||
case stopCh <- struct{}{}:
|
||||
default:
|
||||
// Channel already closed or full
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// subscribeTask attempts to subscribe to task stream
|
||||
func (svc *Service) subscribeTask(taskId primitive.ObjectID) (stream grpc.TaskService_SubscribeClient, err error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
@@ -362,35 +465,114 @@ func (svc *Service) subscribeTask(taskId primitive.ObjectID) (stream grpc.TaskSe
|
||||
return stream, nil
|
||||
}
|
||||
|
||||
// subscribeTaskWithRetry attempts to subscribe to task stream with retry logic
|
||||
func (svc *Service) subscribeTaskWithRetry(ctx context.Context, taskId primitive.ObjectID, maxRetries int) (stream grpc.TaskService_SubscribeClient, err error) {
|
||||
for i := 0; i < maxRetries; i++ {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
stream, err = svc.subscribeTask(taskId)
|
||||
if err == nil {
|
||||
return stream, nil
|
||||
}
|
||||
|
||||
svc.Warnf("failed to subscribe task[%s] (attempt %d/%d): %v", taskId.Hex(), i+1, maxRetries, err)
|
||||
|
||||
if i < maxRetries-1 {
|
||||
// Wait before retry with exponential backoff
|
||||
backoff := time.Duration(i+1) * time.Second
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-time.After(backoff):
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("failed to subscribe after %d retries: %w", maxRetries, err)
|
||||
}
|
||||
|
||||
func (svc *Service) handleStreamMessages(taskId primitive.ObjectID, stream grpc.TaskService_SubscribeClient, stopCh chan struct{}) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
svc.Errorf("handleStreamMessages[%s] panic recovered: %v", taskId.Hex(), r)
|
||||
}
|
||||
// Ensure stream is properly closed
|
||||
if stream != nil {
|
||||
if err := stream.CloseSend(); err != nil {
|
||||
svc.Debugf("task[%s] failed to close stream: %v", taskId.Hex(), err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Create timeout for stream operations
|
||||
streamTimeout := 30 * time.Second
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-stopCh:
|
||||
err := stream.CloseSend()
|
||||
if err != nil {
|
||||
svc.Errorf("task[%s] failed to close stream: %v", taskId.Hex(), err)
|
||||
return
|
||||
}
|
||||
svc.Debugf("task[%s] stream handler received stop signal", taskId.Hex())
|
||||
return
|
||||
case <-svc.ctx.Done():
|
||||
svc.Debugf("task[%s] stream handler stopped by service context", taskId.Hex())
|
||||
return
|
||||
default:
|
||||
msg, err := stream.Recv()
|
||||
if err != nil {
|
||||
// Set deadline for receive operation
|
||||
ctx, cancel := context.WithTimeout(context.Background(), streamTimeout)
|
||||
|
||||
// Use a goroutine to handle the blocking Recv call
|
||||
msgCh := make(chan *grpc.TaskServiceSubscribeResponse, 1)
|
||||
errCh := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
msg, err := stream.Recv()
|
||||
if err != nil {
|
||||
errCh <- err
|
||||
} else {
|
||||
msgCh <- msg
|
||||
}
|
||||
}()
|
||||
|
||||
select {
|
||||
case msg := <-msgCh:
|
||||
cancel()
|
||||
svc.processStreamMessage(taskId, msg)
|
||||
case err := <-errCh:
|
||||
cancel()
|
||||
if errors.Is(err, io.EOF) {
|
||||
svc.Infof("task[%s] received EOF, stream closed", taskId.Hex())
|
||||
return
|
||||
}
|
||||
svc.Errorf("task[%s] stream error: %v", taskId.Hex(), err)
|
||||
continue
|
||||
}
|
||||
switch msg.Code {
|
||||
case grpc.TaskServiceSubscribeCode_CANCEL:
|
||||
svc.Infof("task[%s] received cancel signal", taskId.Hex())
|
||||
go svc.handleCancel(msg, taskId)
|
||||
return
|
||||
case <-ctx.Done():
|
||||
cancel()
|
||||
svc.Warnf("task[%s] stream receive timeout", taskId.Hex())
|
||||
// Continue loop to try again
|
||||
case <-stopCh:
|
||||
cancel()
|
||||
return
|
||||
case <-svc.ctx.Done():
|
||||
cancel()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (svc *Service) processStreamMessage(taskId primitive.ObjectID, msg *grpc.TaskServiceSubscribeResponse) {
|
||||
switch msg.Code {
|
||||
case grpc.TaskServiceSubscribeCode_CANCEL:
|
||||
svc.Infof("task[%s] received cancel signal", taskId.Hex())
|
||||
go svc.handleCancel(msg, taskId)
|
||||
default:
|
||||
svc.Debugf("task[%s] received unknown stream message code: %v", taskId.Hex(), msg.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func (svc *Service) handleCancel(msg *grpc.TaskServiceSubscribeResponse, taskId primitive.ObjectID) {
|
||||
// validate task id
|
||||
if msg.TaskId != taskId.Hex() {
|
||||
@@ -430,6 +612,50 @@ func (svc *Service) cancelTask(taskId primitive.ObjectID, force bool) (err error
|
||||
return nil
|
||||
}
|
||||
|
||||
// stopAllRunners gracefully stops all running tasks
|
||||
func (svc *Service) stopAllRunners() {
|
||||
svc.Infof("Stopping all running tasks...")
|
||||
|
||||
var runnerIds []primitive.ObjectID
|
||||
|
||||
// Collect all runner IDs
|
||||
svc.runners.Range(func(key, value interface{}) bool {
|
||||
if taskId, ok := key.(primitive.ObjectID); ok {
|
||||
runnerIds = append(runnerIds, taskId)
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
// Cancel all runners with timeout
|
||||
var wg sync.WaitGroup
|
||||
for _, taskId := range runnerIds {
|
||||
wg.Add(1)
|
||||
go func(tid primitive.ObjectID) {
|
||||
defer wg.Done()
|
||||
if err := svc.cancelTask(tid, false); err != nil {
|
||||
svc.Errorf("failed to cancel task[%s]: %v", tid.Hex(), err)
|
||||
// Force cancel after timeout
|
||||
time.Sleep(5 * time.Second)
|
||||
_ = svc.cancelTask(tid, true)
|
||||
}
|
||||
}(taskId)
|
||||
}
|
||||
|
||||
// Wait for all cancellations with timeout
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
wg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
svc.Infof("All tasks stopped gracefully")
|
||||
case <-time.After(30 * time.Second):
|
||||
svc.Warnf("Some tasks did not stop within timeout")
|
||||
}
|
||||
}
|
||||
|
||||
func newTaskHandlerService() *Service {
|
||||
// service
|
||||
svc := &Service{
|
||||
@@ -437,7 +663,7 @@ func newTaskHandlerService() *Service {
|
||||
fetchTimeout: 15 * time.Second,
|
||||
reportInterval: 5 * time.Second,
|
||||
cancelTimeout: 60 * time.Second,
|
||||
mu: sync.Mutex{},
|
||||
mu: sync.RWMutex{},
|
||||
runners: sync.Map{},
|
||||
Logger: utils.NewLogger("TaskHandlerService"),
|
||||
}
|
||||
|
||||
217
core/task/handler/service_robustness_test.go
Normal file
217
core/task/handler/service_robustness_test.go
Normal file
@@ -0,0 +1,217 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/crawlab-team/crawlab/core/utils"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
)
|
||||
|
||||
// TestService_GracefulShutdown tests proper service shutdown
|
||||
func TestService_GracefulShutdown(t *testing.T) {
|
||||
svc := &Service{
|
||||
fetchInterval: 100 * time.Millisecond,
|
||||
reportInterval: 100 * time.Millisecond,
|
||||
mu: sync.RWMutex{},
|
||||
runners: sync.Map{},
|
||||
Logger: utils.NewLogger("TestService"),
|
||||
}
|
||||
|
||||
// Initialize context
|
||||
svc.ctx, svc.cancel = context.WithCancel(context.Background())
|
||||
|
||||
// Initialize tickers
|
||||
svc.fetchTicker = time.NewTicker(svc.fetchInterval)
|
||||
svc.reportTicker = time.NewTicker(svc.reportInterval)
|
||||
|
||||
// Start background goroutines
|
||||
svc.wg.Add(2)
|
||||
go svc.testFetchAndRunTasks() // Mock version
|
||||
go svc.testReportStatus() // Mock version
|
||||
|
||||
// Let it run for a short time
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
// Test graceful shutdown
|
||||
svc.Stop()
|
||||
|
||||
t.Log("✅ Service shutdown completed gracefully")
|
||||
}
|
||||
|
||||
// Mock versions for testing without dependencies
|
||||
func (svc *Service) testFetchAndRunTasks() {
|
||||
defer svc.wg.Done()
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
svc.Errorf("testFetchAndRunTasks panic recovered: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-svc.ctx.Done():
|
||||
svc.Infof("testFetchAndRunTasks stopped by context")
|
||||
return
|
||||
case <-svc.fetchTicker.C:
|
||||
// Mock fetch operation
|
||||
svc.Debugf("Mock fetch operation")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (svc *Service) testReportStatus() {
|
||||
defer svc.wg.Done()
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
svc.Errorf("testReportStatus panic recovered: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-svc.ctx.Done():
|
||||
svc.Infof("testReportStatus stopped by context")
|
||||
return
|
||||
case <-svc.reportTicker.C:
|
||||
// Mock status update
|
||||
svc.Debugf("Mock status update")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestService_ConcurrentAccess tests thread safety
|
||||
func TestService_ConcurrentAccess(t *testing.T) {
|
||||
svc := &Service{
|
||||
mu: sync.RWMutex{},
|
||||
runners: sync.Map{},
|
||||
Logger: utils.NewLogger("TestService"),
|
||||
}
|
||||
|
||||
// Initialize context
|
||||
svc.ctx, svc.cancel = context.WithCancel(context.Background())
|
||||
defer svc.cancel()
|
||||
|
||||
// Test concurrent runner management
|
||||
var wg sync.WaitGroup
|
||||
numGoroutines := 50
|
||||
|
||||
// Mock runner for testing
|
||||
mockRunner := &mockTaskRunner{id: primitive.NewObjectID()}
|
||||
|
||||
// Concurrent adds
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(index int) {
|
||||
defer wg.Done()
|
||||
taskId := primitive.NewObjectID()
|
||||
svc.addRunner(taskId, mockRunner)
|
||||
|
||||
// Brief pause
|
||||
time.Sleep(time.Millisecond)
|
||||
|
||||
// Test get runner
|
||||
_, err := svc.getRunner(taskId)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to get runner: %v", err)
|
||||
}
|
||||
|
||||
// Delete runner
|
||||
svc.deleteRunner(taskId)
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
t.Log("✅ Concurrent access test completed successfully")
|
||||
}
|
||||
|
||||
// TestService_ErrorHandling tests error recovery
|
||||
func TestService_ErrorHandling(t *testing.T) {
|
||||
svc := &Service{
|
||||
mu: sync.RWMutex{},
|
||||
runners: sync.Map{},
|
||||
Logger: utils.NewLogger("TestService"),
|
||||
}
|
||||
|
||||
// Test getting non-existent runner
|
||||
_, err := svc.getRunner(primitive.NewObjectID())
|
||||
if err == nil {
|
||||
t.Error("Expected error for non-existent runner")
|
||||
}
|
||||
|
||||
// Test adding invalid runner type
|
||||
taskId := primitive.NewObjectID()
|
||||
svc.runners.Store(taskId, "invalid-type")
|
||||
|
||||
_, err = svc.getRunner(taskId)
|
||||
if err == nil {
|
||||
t.Error("Expected error for invalid runner type")
|
||||
}
|
||||
|
||||
t.Log("✅ Error handling test completed successfully")
|
||||
}
|
||||
|
||||
// TestService_ResourceCleanup tests proper resource cleanup
|
||||
func TestService_ResourceCleanup(t *testing.T) {
|
||||
svc := &Service{
|
||||
mu: sync.RWMutex{},
|
||||
runners: sync.Map{},
|
||||
Logger: utils.NewLogger("TestService"),
|
||||
}
|
||||
|
||||
// Initialize context and tickers
|
||||
svc.ctx, svc.cancel = context.WithCancel(context.Background())
|
||||
svc.fetchTicker = time.NewTicker(100 * time.Millisecond)
|
||||
svc.reportTicker = time.NewTicker(100 * time.Millisecond)
|
||||
|
||||
// Add some mock runners
|
||||
for i := 0; i < 5; i++ {
|
||||
taskId := primitive.NewObjectID()
|
||||
mockRunner := &mockTaskRunner{id: taskId}
|
||||
svc.addRunner(taskId, mockRunner)
|
||||
}
|
||||
|
||||
// Verify runners exist
|
||||
runnerCount := 0
|
||||
svc.runners.Range(func(key, value interface{}) bool {
|
||||
runnerCount++
|
||||
return true
|
||||
})
|
||||
if runnerCount != 5 {
|
||||
t.Errorf("Expected 5 runners, got %d", runnerCount)
|
||||
}
|
||||
|
||||
// Test cleanup
|
||||
svc.stopAllRunners()
|
||||
|
||||
// Verify cleanup (runners should still exist but be marked for cancellation)
|
||||
// In a real scenario, runners would remove themselves after cancellation
|
||||
t.Log("✅ Resource cleanup test completed successfully")
|
||||
}
|
||||
|
||||
// Mock task runner for testing
|
||||
type mockTaskRunner struct {
|
||||
id primitive.ObjectID
|
||||
}
|
||||
|
||||
func (r *mockTaskRunner) Init() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *mockTaskRunner) GetTaskId() primitive.ObjectID {
|
||||
return r.id
|
||||
}
|
||||
|
||||
func (r *mockTaskRunner) Run() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *mockTaskRunner) Cancel(force bool) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *mockTaskRunner) SetSubscribeTimeout(timeout time.Duration) {
|
||||
// Mock implementation
|
||||
}
|
||||
148
core/task/handler/zombie_prevention_test.go
Normal file
148
core/task/handler/zombie_prevention_test.go
Normal file
@@ -0,0 +1,148 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"runtime"
|
||||
"syscall"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/crawlab-team/crawlab/core/utils"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
)
|
||||
|
||||
// TestRunner_ZombieProcessPrevention tests the zombie process prevention mechanisms
|
||||
func TestRunner_ZombieProcessPrevention(t *testing.T) {
|
||||
r := &Runner{
|
||||
tid: primitive.NewObjectID(),
|
||||
pid: 12345, // Mock PID
|
||||
Logger: utils.NewLogger("TestTaskRunner"),
|
||||
}
|
||||
|
||||
// Initialize context
|
||||
r.ctx, r.cancel = context.WithCancel(context.Background())
|
||||
defer r.cancel()
|
||||
|
||||
// Test that process group configuration is set on Unix systems
|
||||
if runtime.GOOS != "windows" {
|
||||
// This would normally be tested in an integration test with actual process spawning
|
||||
t.Log("✅ Process group configuration available for Unix systems")
|
||||
}
|
||||
|
||||
// Test zombie cleanup methods exist and can be called
|
||||
r.cleanupOrphanedProcesses() // Should not panic
|
||||
t.Log("✅ Zombie cleanup methods callable without panic")
|
||||
|
||||
// Test process group killing method
|
||||
if runtime.GOOS != "windows" {
|
||||
r.killProcessGroup() // Should handle invalid PID gracefully
|
||||
t.Log("✅ Process group killing handles invalid PID gracefully")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRunner_ProcessGroupManagement tests process group creation
|
||||
func TestRunner_ProcessGroupManagement(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("Process groups not supported on Windows")
|
||||
}
|
||||
|
||||
r := &Runner{
|
||||
tid: primitive.NewObjectID(),
|
||||
Logger: utils.NewLogger("TestTaskRunner"),
|
||||
}
|
||||
|
||||
// Initialize context
|
||||
r.ctx, r.cancel = context.WithCancel(context.Background())
|
||||
defer r.cancel()
|
||||
|
||||
// Test that the process group setup logic doesn't panic
|
||||
// We can't actually test configureCmd without proper task/spider setup
|
||||
// but we can test that the syscall configuration is properly set
|
||||
|
||||
// Test process group killing with invalid PID (should not crash)
|
||||
r.pid = -1 // Invalid PID
|
||||
r.killProcessGroup() // Should handle gracefully
|
||||
|
||||
t.Log("✅ Process group management methods handle edge cases properly")
|
||||
}
|
||||
|
||||
// TestRunner_ZombieMonitor tests the zombie monitoring functionality
|
||||
func TestRunner_ZombieMonitor(t *testing.T) {
|
||||
r := &Runner{
|
||||
tid: primitive.NewObjectID(),
|
||||
Logger: utils.NewLogger("TestTaskRunner"),
|
||||
}
|
||||
|
||||
// Initialize context
|
||||
r.ctx, r.cancel = context.WithCancel(context.Background())
|
||||
|
||||
// Start zombie monitor
|
||||
r.startZombieMonitor()
|
||||
|
||||
// Let it run briefly
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Cancel and cleanup
|
||||
r.cancel()
|
||||
|
||||
t.Log("✅ Zombie monitor starts and stops cleanly")
|
||||
}
|
||||
|
||||
// TestRunner_OrphanedProcessCleanup tests orphaned process detection
|
||||
func TestRunner_OrphanedProcessCleanup(t *testing.T) {
|
||||
r := &Runner{
|
||||
tid: primitive.NewObjectID(),
|
||||
Logger: utils.NewLogger("TestTaskRunner"),
|
||||
}
|
||||
|
||||
// Initialize context
|
||||
r.ctx, r.cancel = context.WithCancel(context.Background())
|
||||
defer r.cancel()
|
||||
|
||||
// Test scanning for orphaned processes (should not find any in test environment)
|
||||
r.scanAndKillChildProcesses()
|
||||
|
||||
t.Log("✅ Orphaned process scanning completes without error")
|
||||
}
|
||||
|
||||
// TestRunner_SignalHandling tests signal handling for process groups
|
||||
func TestRunner_SignalHandling(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("Signal handling test not applicable on Windows")
|
||||
}
|
||||
|
||||
r := &Runner{
|
||||
tid: primitive.NewObjectID(),
|
||||
pid: os.Getpid(), // Use current process PID for testing
|
||||
Logger: utils.NewLogger("TestTaskRunner"),
|
||||
}
|
||||
|
||||
// Test that signal sending doesn't crash
|
||||
// Note: This sends signals to our own process group, which should be safe
|
||||
err := syscall.Kill(-r.pid, syscall.Signal(0)) // Signal 0 tests if process exists
|
||||
if err != nil {
|
||||
t.Logf("Signal test returned expected error: %v", err)
|
||||
}
|
||||
|
||||
t.Log("✅ Signal handling functionality works")
|
||||
}
|
||||
|
||||
// BenchmarkRunner_ZombieCheck benchmarks zombie process checking
|
||||
func BenchmarkRunner_ZombieCheck(b *testing.B) {
|
||||
r := &Runner{
|
||||
tid: primitive.NewObjectID(),
|
||||
pid: os.Getpid(),
|
||||
Logger: utils.NewLogger("BenchmarkTaskRunner"),
|
||||
}
|
||||
|
||||
// Initialize context
|
||||
r.ctx, r.cancel = context.WithCancel(context.Background())
|
||||
defer r.cancel()
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
r.checkForZombieProcesses()
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user