mirror of https://github.com/crawlab-team/crawlab.git (synced 2026-01-21 17:21:09 +01:00)
fix: enhance task service resilience with connection health monitoring and periodic cleanup
core/grpc/server/task_service_server.go
@@ -5,10 +5,12 @@ import (
    "encoding/json"
    "errors"
    "fmt"
    mongo3 "github.com/crawlab-team/crawlab/core/mongo"
    "io"
    "strings"
    "sync"
    "time"

    mongo3 "github.com/crawlab-team/crawlab/core/mongo"

    "github.com/crawlab-team/crawlab/core/constants"
    "github.com/crawlab-team/crawlab/core/interfaces"
@@ -35,6 +37,11 @@ type TaskServiceServer struct {

    // internals
    subs map[primitive.ObjectID]grpc.TaskService_SubscribeServer

    // cleanup mechanism
    cleanupCtx    context.Context
    cleanupCancel context.CancelFunc

    interfaces.Logger
}

@@ -50,21 +57,38 @@ func (svr TaskServiceServer) Subscribe(req *grpc.TaskServiceSubscribeRequest, st
        return errors.New("invalid stream")
    }

    svr.Infof("task stream opened: %s", taskId.Hex())

    // add stream
    taskServiceMutex.Lock()
    svr.subs[taskId] = stream
    taskServiceMutex.Unlock()

    // wait for stream to close
    <-stream.Context().Done()
    // ensure cleanup on exit
    defer func() {
        taskServiceMutex.Lock()
        delete(svr.subs, taskId)
        taskServiceMutex.Unlock()
        svr.Infof("task stream closed: %s", taskId.Hex())
    }()

    // remove stream
    taskServiceMutex.Lock()
    delete(svr.subs, taskId)
    taskServiceMutex.Unlock()
    svr.Infof("task stream closed: %s", taskId.Hex())
    // wait for stream to close with timeout protection
    ctx := stream.Context()

    return nil
    // Create a context with timeout to prevent indefinite hanging
    timeoutCtx, cancel := context.WithTimeout(context.Background(), 24*time.Hour)
    defer cancel()

    select {
    case <-ctx.Done():
        // Stream context cancelled normally
        svr.Debugf("task stream context done: %s", taskId.Hex())
        return ctx.Err()
    case <-timeoutCtx.Done():
        // Timeout reached - this prevents indefinite hanging
        svr.Warnf("task stream timeout reached for task: %s", taskId.Hex())
        return errors.New("stream timeout")
    }
}

// Connect to task stream when a task runner in a node starts
@@ -75,22 +99,49 @@ func (svr TaskServiceServer) Connect(stream grpc.TaskService_ConnectServer) (err
    var spiderId primitive.ObjectID
    var taskId primitive.ObjectID

    // Add timeout protection for the entire connection
    ctx := stream.Context()

    // Log connection start
    svr.Debugf("task connect stream started")

    defer func() {
        if taskId != primitive.NilObjectID {
            svr.Debugf("task connect stream ended for task: %s", taskId.Hex())
        } else {
            svr.Debugf("task connect stream ended")
        }
    }()

    // continuously receive messages from the stream
    for {
        // receive next message from stream
        // Check context cancellation before each receive
        select {
        case <-ctx.Done():
            svr.Debugf("task connect stream context cancelled")
            return ctx.Err()
        default:
        }

        // receive next message from stream with timeout
        msg, err := stream.Recv()
        if err == io.EOF {
            // stream has ended normally
            svr.Debugf("task connect stream ended normally (EOF)")
            return nil
        }
        if err != nil {
            // handle graceful context cancellation
            if strings.HasSuffix(err.Error(), "context canceled") {
            if strings.HasSuffix(err.Error(), "context canceled") ||
                strings.Contains(err.Error(), "context deadline exceeded") ||
                strings.Contains(err.Error(), "transport is closing") {
                svr.Debugf("task connect stream cancelled gracefully: %v", err)
                return nil
            }
            // log other stream receive errors and continue
            // log other stream receive errors
            svr.Errorf("error receiving stream message: %v", err)
            continue
            // Return error instead of continuing to prevent infinite error loops
            return err
        }

        // validate and parse the task ID from the message if not already set
@@ -100,6 +151,7 @@ func (svr TaskServiceServer) Connect(stream grpc.TaskService_ConnectServer) (err
            svr.Errorf("invalid task id: %s", msg.TaskId)
            continue
        }
        svr.Debugf("task connect stream set task id: %s", taskId.Hex())
    }

    // get spider id if not already set
@@ -149,8 +201,8 @@ func (svr TaskServiceServer) FetchTask(ctx context.Context, request *grpc.TaskSe
    var tid primitive.ObjectID
    opts := &mongo3.FindOptions{
        Sort: bson.D{
            {"priority", 1},
            {"_id", 1},
            {Key: "priority", Value: 1},
            {Key: "_id", Value: 1},
        },
        Limit: 1,
    }
@@ -302,6 +354,51 @@ func (svr TaskServiceServer) GetSubscribeStream(taskId primitive.ObjectID) (stre
    return stream, ok
}

// cleanupStaleStreams periodically checks for and removes stale streams
func (svr *TaskServiceServer) cleanupStaleStreams() {
    ticker := time.NewTicker(10 * time.Minute) // Check every 10 minutes
    defer ticker.Stop()

    for {
        select {
        case <-svr.cleanupCtx.Done():
            svr.Debugf("stream cleanup routine shutting down")
            return
        case <-ticker.C:
            svr.performStreamCleanup()
        }
    }
}

// performStreamCleanup checks each stream and removes those that are no longer active
func (svr *TaskServiceServer) performStreamCleanup() {
    taskServiceMutex.Lock()
    defer taskServiceMutex.Unlock()

    var staleTaskIds []primitive.ObjectID

    for taskId, stream := range svr.subs {
        // Check if stream context is still active
        select {
        case <-stream.Context().Done():
            // Stream is done, mark for removal
            staleTaskIds = append(staleTaskIds, taskId)
        default:
            // Stream is still active, continue
        }
    }

    // Remove stale streams
    for _, taskId := range staleTaskIds {
        delete(svr.subs, taskId)
        svr.Infof("cleaned up stale stream for task: %s", taskId.Hex())
    }

    if len(staleTaskIds) > 0 {
        svr.Infof("cleaned up %d stale streams", len(staleTaskIds))
    }
}

func (svr TaskServiceServer) handleInsertData(taskId, spiderId primitive.ObjectID, msg *grpc.TaskServiceConnectRequest) (err error) {
    var records []map[string]interface{}
    err = json.Unmarshal(msg.Data, &records)
@@ -332,12 +429,46 @@ func (svr TaskServiceServer) saveTask(t *models.Task) (err error) {
}

func newTaskServiceServer() *TaskServiceServer {
    return &TaskServiceServer{
        cfgSvc:   nodeconfig.GetNodeConfigService(),
        subs:     make(map[primitive.ObjectID]grpc.TaskService_SubscribeServer),
        statsSvc: stats.GetTaskStatsService(),
        Logger:   utils.NewLogger("GrpcTaskServiceServer"),
    ctx, cancel := context.WithCancel(context.Background())

    server := &TaskServiceServer{
        cfgSvc:        nodeconfig.GetNodeConfigService(),
        subs:          make(map[primitive.ObjectID]grpc.TaskService_SubscribeServer),
        statsSvc:      stats.GetTaskStatsService(),
        cleanupCtx:    ctx,
        cleanupCancel: cancel,
        Logger:        utils.NewLogger("GrpcTaskServiceServer"),
    }

    // Start the cleanup routine
    go server.cleanupStaleStreams()

    return server
}

// Stop gracefully shuts down the task service server
func (svr *TaskServiceServer) Stop() error {
    svr.Infof("stopping task service server...")

    // Cancel cleanup routine
    if svr.cleanupCancel != nil {
        svr.cleanupCancel()
    }

    // Clean up all remaining streams
    taskServiceMutex.Lock()
    streamCount := len(svr.subs)
    for taskId := range svr.subs {
        delete(svr.subs, taskId)
    }
    taskServiceMutex.Unlock()

    if streamCount > 0 {
        svr.Infof("cleaned up %d remaining streams on shutdown", streamCount)
    }

    svr.Infof("task service server stopped")
    return nil
}

var _taskServiceServer *TaskServiceServer
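The two cleanup helpers added above reduce to a simple pattern: a mutex-guarded map of subscriber streams that a background goroutine sweeps on a ticker, dropping entries whose context has been cancelled, until a shutdown context fires. Below is a minimal, self-contained sketch of that pattern; the names (registry, sweep, run) are illustrative only and are not part of the crawlab codebase.

package main

import (
    "context"
    "fmt"
    "sync"
    "time"
)

// registry is an illustrative stand-in for the server's map of per-task
// subscriber streams; only the stream contexts matter for cleanup.
type registry struct {
    mu   sync.Mutex
    subs map[string]context.Context
}

// sweep removes entries whose context has been cancelled (the role played by
// performStreamCleanup in the diff above).
func (r *registry) sweep() {
    r.mu.Lock()
    defer r.mu.Unlock()
    for id, ctx := range r.subs {
        select {
        case <-ctx.Done():
            delete(r.subs, id)
            fmt.Println("cleaned up stale stream:", id)
        default: // still active, keep it
        }
    }
}

// run sweeps on a ticker until the shutdown context fires (the role played by
// cleanupStaleStreams in the diff above).
func (r *registry) run(shutdown context.Context, interval time.Duration) {
    ticker := time.NewTicker(interval)
    defer ticker.Stop()
    for {
        select {
        case <-shutdown.Done():
            return
        case <-ticker.C:
            r.sweep()
        }
    }
}

func main() {
    live, cancelLive := context.WithCancel(context.Background())
    defer cancelLive()
    stale, cancelStale := context.WithCancel(context.Background())
    cancelStale() // simulate a subscriber whose stream has already gone away

    r := &registry{subs: map[string]context.Context{"live": live, "stale": stale}}

    shutdown, stop := context.WithCancel(context.Background())
    go r.run(shutdown, 10*time.Millisecond)
    time.Sleep(50 * time.Millisecond)
    stop()

    r.mu.Lock()
    fmt.Println("remaining subscribers:", len(r.subs)) // expected: 1
    r.mu.Unlock()
}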
core/grpc/server/task_service_server_test.go (new file, 238 lines)
@@ -0,0 +1,238 @@
package server

import (
    "context"
    "sync"
    "testing"
    "time"

    "github.com/crawlab-team/crawlab/core/utils"
    "github.com/crawlab-team/crawlab/grpc"
    "go.mongodb.org/mongo-driver/bson/primitive"
    "google.golang.org/grpc/metadata"
)

// Mock stream for testing
type mockSubscribeStream struct {
    ctx    context.Context
    cancel context.CancelFunc
}

func (m *mockSubscribeStream) Context() context.Context {
    return m.ctx
}

func (m *mockSubscribeStream) Send(*grpc.TaskServiceSubscribeResponse) error {
    return nil
}

func (m *mockSubscribeStream) SetHeader(metadata.MD) error  { return nil }
func (m *mockSubscribeStream) SendHeader(metadata.MD) error { return nil }
func (m *mockSubscribeStream) SetTrailer(metadata.MD)       {}
func (m *mockSubscribeStream) RecvMsg(interface{}) error    { return nil }
func (m *mockSubscribeStream) SendMsg(interface{}) error    { return nil }

func newMockSubscribeStream() *mockSubscribeStream {
    ctx, cancel := context.WithCancel(context.Background())
    return &mockSubscribeStream{
        ctx:    ctx,
        cancel: cancel,
    }
}

func TestTaskServiceServer_Subscribe_Timeout(t *testing.T) {
    server := &TaskServiceServer{
        subs:   make(map[primitive.ObjectID]grpc.TaskService_SubscribeServer),
        Logger: utils.NewLogger("TestTaskServiceServer"),
    }

    taskId := primitive.NewObjectID()
    mockStream := newMockSubscribeStream()

    req := &grpc.TaskServiceSubscribeRequest{
        TaskId: taskId.Hex(),
    }

    // Start subscribe in goroutine
    done := make(chan error, 1)
    go func() {
        err := server.Subscribe(req, mockStream)
        done <- err
    }()

    // Wait a moment for subscription to be added
    time.Sleep(100 * time.Millisecond)

    // Verify stream was added
    taskServiceMutex.Lock()
    _, exists := server.subs[taskId]
    taskServiceMutex.Unlock()

    if !exists {
        t.Fatal("Stream was not added to subscription map")
    }

    // Cancel the mock stream context
    mockStream.cancel()

    // Wait for subscribe to complete
    select {
    case err := <-done:
        if err == nil {
            t.Error("Expected error from cancelled context")
        }
        t.Logf("✅ Subscribe returned with error as expected: %v", err)
    case <-time.After(5 * time.Second):
        t.Fatal("Subscribe didn't return within timeout")
    }

    // Verify stream was cleaned up
    taskServiceMutex.Lock()
    _, exists = server.subs[taskId]
    taskServiceMutex.Unlock()

    if exists {
        t.Error("Stream was not cleaned up from subscription map")
    } else {
        t.Log("✅ Stream properly cleaned up")
    }
}

func TestTaskServiceServer_StreamCleanup(t *testing.T) {
    ctx, cancel := context.WithCancel(context.Background())
    defer cancel()

    server := &TaskServiceServer{
        subs:          make(map[primitive.ObjectID]grpc.TaskService_SubscribeServer),
        cleanupCtx:    ctx,
        cleanupCancel: cancel,
        Logger:        utils.NewLogger("TestTaskServiceServer"),
    }

    // Add some mock streams
    taskId1 := primitive.NewObjectID()
    taskId2 := primitive.NewObjectID()

    mockStream1 := newMockSubscribeStream()
    mockStream2 := newMockSubscribeStream()

    taskServiceMutex.Lock()
    server.subs[taskId1] = mockStream1
    server.subs[taskId2] = mockStream2
    taskServiceMutex.Unlock()

    // Cancel one stream
    mockStream1.cancel()

    // Wait a moment
    time.Sleep(100 * time.Millisecond)

    // Perform cleanup
    server.performStreamCleanup()

    // Verify only the cancelled stream was removed
    taskServiceMutex.Lock()
    _, exists1 := server.subs[taskId1]
    _, exists2 := server.subs[taskId2]
    taskServiceMutex.Unlock()

    if exists1 {
        t.Error("Cancelled stream was not cleaned up")
    } else {
        t.Log("✅ Cancelled stream properly cleaned up")
    }

    if !exists2 {
        t.Error("Active stream was incorrectly removed")
    } else {
        t.Log("✅ Active stream preserved")
    }

    // Clean up remaining
    mockStream2.cancel()
}

func TestTaskServiceServer_Stop(t *testing.T) {
    server := newTaskServiceServer()

    // Add some mock streams
    taskId := primitive.NewObjectID()
    mockStream := newMockSubscribeStream()

    taskServiceMutex.Lock()
    server.subs[taskId] = mockStream
    taskServiceMutex.Unlock()

    // Stop the server
    err := server.Stop()
    if err != nil {
        t.Fatalf("Stop returned error: %v", err)
    }

    // Verify all streams are cleaned up
    taskServiceMutex.Lock()
    streamCount := len(server.subs)
    taskServiceMutex.Unlock()

    if streamCount != 0 {
        t.Errorf("Expected 0 streams after stop, got %d", streamCount)
    } else {
        t.Log("✅ All streams cleaned up on stop")
    }

    // Verify cleanup context is cancelled
    select {
    case <-server.cleanupCtx.Done():
        t.Log("✅ Cleanup context properly cancelled")
    default:
        t.Error("Cleanup context not cancelled")
    }
}

func TestTaskServiceServer_ConcurrentAccess(t *testing.T) {
    server := &TaskServiceServer{
        subs:   make(map[primitive.ObjectID]grpc.TaskService_SubscribeServer),
        Logger: utils.NewLogger("TestTaskServiceServer"),
    }

    var wg sync.WaitGroup

    // Start multiple goroutines adding/removing streams
    for i := 0; i < 10; i++ {
        wg.Add(1)
        go func(id int) {
            defer wg.Done()

            taskId := primitive.NewObjectID()
            mockStream := newMockSubscribeStream()
            defer mockStream.cancel()

            // Add stream
            taskServiceMutex.Lock()
            server.subs[taskId] = mockStream
            taskServiceMutex.Unlock()

            // Do some work
            time.Sleep(10 * time.Millisecond)

            // Remove stream
            taskServiceMutex.Lock()
            delete(server.subs, taskId)
            taskServiceMutex.Unlock()
        }(i)
    }

    // Wait for all goroutines to complete
    done := make(chan bool)
    go func() {
        wg.Wait()
        done <- true
    }()

    select {
    case <-done:
        t.Log("✅ Concurrent access test completed successfully")
    case <-time.After(5 * time.Second):
        t.Fatal("Concurrent access test timed out")
    }
}
core/task/handler/runner.go
@@ -12,6 +12,7 @@ import (
    "os"
    "os/exec"
    "path/filepath"
    "runtime"
    "strings"
    "sync"
    "time"
@@ -52,6 +53,12 @@ func newTaskRunner(id primitive.ObjectID, svc *Service) (r *Runner, err error) {
        ch:           make(chan constants.TaskSignal),
        logBatchSize: 20,
        Logger:       utils.NewLogger("TaskRunner"),
        // treat all tasks as potentially long-running
        maxConnRetries:      10,
        connRetryDelay:      10 * time.Second,
        ipcTimeout:          60 * time.Second, // generous timeout for all tasks
        healthCheckInterval: 5 * time.Second,  // check process every 5 seconds
        connHealthInterval:  60 * time.Second, // check connection health every minute
    }

    // multi error
@@ -123,6 +130,19 @@ type Runner struct {
    cancel context.CancelFunc // function to cancel the context
    done   chan struct{}      // channel to signal completion
    wg     sync.WaitGroup     // wait group for goroutine synchronization
    // connection management for robust task execution
    connMutex         sync.RWMutex  // mutex for connection access
    connHealthTicker  *time.Ticker  // ticker for connection health checks
    lastConnCheck     time.Time     // last successful connection check
    connRetryAttempts int           // current retry attempts
    maxConnRetries    int           // maximum connection retry attempts
    connRetryDelay    time.Duration // delay between connection retries
    resourceCleanup   *time.Ticker  // periodic resource cleanup

    // configurable timeouts for robust task execution
    ipcTimeout          time.Duration // timeout for IPC operations
    healthCheckInterval time.Duration // interval for health checks
    connHealthInterval  time.Duration // interval for connection health checks
}

// Init initializes the task runner by updating the task status and establishing gRPC connections
@@ -204,7 +224,15 @@ func (r *Runner) Run() (err error) {
    // 1. Signal all goroutines to stop
    r.cancel()

    // 2. Wait for all goroutines to finish with timeout
    // 2. Stop tickers to prevent resource leaks
    if r.connHealthTicker != nil {
        r.connHealthTicker.Stop()
    }
    if r.resourceCleanup != nil {
        r.resourceCleanup.Stop()
    }

    // 3. Wait for all goroutines to finish with timeout
    done := make(chan struct{})
    go func() {
        r.wg.Wait()
@@ -214,17 +242,20 @@ func (r *Runner) Run() (err error) {
    select {
    case <-done:
        // All goroutines finished normally
    case <-time.After(5 * time.Second):
    case <-time.After(10 * time.Second): // Increased timeout for long-running tasks
        // Timeout waiting for goroutines, proceed with cleanup
        r.Warnf("timeout waiting for goroutines to finish, proceeding with cleanup")
    }

    // 3. Close gRPC connection after all goroutines have stopped
    // 4. Close gRPC connection after all goroutines have stopped
    r.connMutex.Lock()
    if r.conn != nil {
        _ = r.conn.CloseSend()
        r.conn = nil
    }
    r.connMutex.Unlock()

    // 4. Close channels after everything has stopped
    // 5. Close channels after everything has stopped
    close(r.done)
    if r.ipcChan != nil {
        close(r.ipcChan)
@@ -346,7 +377,7 @@ func (r *Runner) startHealthCheck() {
        return
    }

    ticker := time.NewTicker(1 * time.Second)
    ticker := time.NewTicker(r.healthCheckInterval)
    defer ticker.Stop()

    for {
@@ -447,9 +478,7 @@ func (r *Runner) configureEnv() {

func (r *Runner) performHttpRequest(method, path string, params url.Values) (*http.Response, error) {
    // Normalize path
    if strings.HasPrefix(path, "/") {
        path = path[1:]
    }
    path = strings.TrimPrefix(path, "/")

    // Construct master URL
    var id string
@@ -768,17 +797,167 @@ func (r *Runner) updateTask(status string, e error) (err error) {
    return nil
}

// initConnection establishes a gRPC connection to the task service
// initConnection establishes a gRPC connection to the task service with retry logic
func (r *Runner) initConnection() (err error) {
    r.connMutex.Lock()
    defer r.connMutex.Unlock()

    r.conn, err = client2.GetGrpcClient().TaskClient.Connect(context.Background())
    if err != nil {
        r.Errorf("error connecting to task service: %v", err)
        return err
    }

    r.lastConnCheck = time.Now()
    r.connRetryAttempts = 0
    // Start connection health monitoring for all tasks (potentially long-running)
    go r.monitorConnectionHealth()

    // Start periodic resource cleanup for all tasks
    go r.performPeriodicCleanup()

    return nil
}

// monitorConnectionHealth periodically checks gRPC connection health and reconnects if needed
func (r *Runner) monitorConnectionHealth() {
    r.wg.Add(1)
    defer r.wg.Done()

    r.connHealthTicker = time.NewTicker(r.connHealthInterval)
    defer r.connHealthTicker.Stop()

    for {
        select {
        case <-r.ctx.Done():
            return
        case <-r.connHealthTicker.C:
            if r.isConnectionHealthy() {
                r.lastConnCheck = time.Now()
                r.connRetryAttempts = 0
            } else {
                r.Warnf("gRPC connection unhealthy, attempting reconnection (attempt %d/%d)",
                    r.connRetryAttempts+1, r.maxConnRetries)
                if err := r.reconnectWithRetry(); err != nil {
                    r.Errorf("failed to reconnect after %d attempts: %v", r.maxConnRetries, err)
                }
            }
        }
    }
}

// isConnectionHealthy checks if the gRPC connection is still healthy
func (r *Runner) isConnectionHealthy() bool {
    r.connMutex.RLock()
    defer r.connMutex.RUnlock()

    if r.conn == nil {
        return false
    }
    // Try to send a ping-like message to test connection
    // Use a simple log message as ping since PING code doesn't exist
    testMsg := &grpc.TaskServiceConnectRequest{
        Code:   grpc.TaskServiceConnectCode_INSERT_LOGS,
        TaskId: r.tid.Hex(),
        Data:   []byte(`["[HEALTH CHECK] connection test"]`),
    }

    if err := r.conn.Send(testMsg); err != nil {
        r.Debugf("connection health check failed: %v", err)
        return false
    }

    return true
}

// reconnectWithRetry attempts to reconnect to the gRPC service with exponential backoff
func (r *Runner) reconnectWithRetry() error {
    r.connMutex.Lock()
    defer r.connMutex.Unlock()

    for attempt := 0; attempt < r.maxConnRetries; attempt++ {
        r.connRetryAttempts = attempt + 1

        // Close existing connection
        if r.conn != nil {
            _ = r.conn.CloseSend()
            r.conn = nil
        }

        // Wait before retry (exponential backoff)
        if attempt > 0 {
            backoffDelay := time.Duration(attempt) * r.connRetryDelay
            r.Debugf("waiting %v before retry attempt %d", backoffDelay, attempt+1)

            select {
            case <-r.ctx.Done():
                return fmt.Errorf("context cancelled during reconnection")
            case <-time.After(backoffDelay):
            }
        }

        // Attempt reconnection
        conn, err := client2.GetGrpcClient().TaskClient.Connect(context.Background())
        if err != nil {
            r.Warnf("reconnection attempt %d failed: %v", attempt+1, err)
            continue
        }

        r.conn = conn
        r.lastConnCheck = time.Now()
        r.connRetryAttempts = 0
        r.Infof("successfully reconnected to task service after %d attempts", attempt+1)
        return nil
    }

    return fmt.Errorf("failed to reconnect after %d attempts", r.maxConnRetries)
}

// performPeriodicCleanup runs periodic cleanup for all tasks
func (r *Runner) performPeriodicCleanup() {
    r.wg.Add(1)
    defer r.wg.Done()

    // Cleanup every 10 minutes for all tasks
    r.resourceCleanup = time.NewTicker(10 * time.Minute)
    defer r.resourceCleanup.Stop()

    for {
        select {
        case <-r.ctx.Done():
            return
        case <-r.resourceCleanup.C:
            r.runPeriodicCleanup()
        }
    }
}

// runPeriodicCleanup performs memory and resource cleanup
func (r *Runner) runPeriodicCleanup() {
    r.Debugf("performing periodic cleanup for task")

    // Force garbage collection for memory management
    runtime.GC()

    // Log current resource usage
    var m runtime.MemStats
    runtime.ReadMemStats(&m)
    r.Debugf("memory usage - alloc: %d KB, sys: %d KB, num_gc: %d",
        m.Alloc/1024, m.Sys/1024, m.NumGC)

    // Check if IPC channel is getting full
    if r.ipcChan != nil {
        select {
        case <-r.ipcChan:
            r.Debugf("drained stale IPC message during cleanup")
        default:
            // Channel is not full, good
        }
    }
}

// writeLogLines marshals log lines to JSON and sends them to the task service
// Uses connection-safe approach for robust task execution
func (r *Runner) writeLogLines(lines []string) {
    // Check if context is cancelled or connection is closed
    select {
@@ -787,8 +966,14 @@ func (r *Runner) writeLogLines(lines []string) {
    default:
    }

    // Use connection with mutex for thread safety
    r.connMutex.RLock()
    conn := r.conn
    r.connMutex.RUnlock()

    // Check if connection is available
    if r.conn == nil {
    if conn == nil {
        r.Debugf("no connection available for sending log lines")
        return
    }

@@ -797,18 +982,22 @@ func (r *Runner) writeLogLines(lines []string) {
        r.Errorf("error marshaling log lines: %v", err)
        return
    }

    msg := &grpc.TaskServiceConnectRequest{
        Code:   grpc.TaskServiceConnectCode_INSERT_LOGS,
        TaskId: r.tid.Hex(),
        Data:   linesBytes,
    }
    if err := r.conn.Send(msg); err != nil {

    if err := conn.Send(msg); err != nil {
        // Don't log errors if context is cancelled (expected during shutdown)
        select {
        case <-r.ctx.Done():
            return
        default:
            r.Errorf("error sending log lines: %v", err)
            // Mark connection as unhealthy for reconnection
            r.lastConnCheck = time.Time{}
        }
        return
    }
@@ -1087,14 +1276,19 @@ func (r *Runner) handleIPCInsertDataMessage(ipcMsg entity.IPCMessage) {
    default:
    }

    // Use connection with mutex for thread safety
    r.connMutex.RLock()
    conn := r.conn
    r.connMutex.RUnlock()

    // Validate connection
    if r.conn == nil {
    if conn == nil {
        r.Errorf("gRPC connection not initialized")
        return
    }

    // Send IPC message to master with context and timeout
    ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
    ctx, cancel := context.WithTimeout(context.Background(), r.ipcTimeout)
    defer cancel()

    // Create gRPC message
@@ -1112,13 +1306,15 @@ func (r *Runner) handleIPCInsertDataMessage(ipcMsg entity.IPCMessage) {
    case <-r.ctx.Done():
        return
    default:
        if err := r.conn.Send(grpcMsg); err != nil {
        if err := conn.Send(grpcMsg); err != nil {
            // Don't log errors if context is cancelled (expected during shutdown)
            select {
            case <-r.ctx.Done():
                return
            default:
                r.Errorf("error sending IPC message: %v", err)
                // Mark connection as unhealthy for reconnection
                r.lastConnCheck = time.Time{}
            }
            return
        }
@@ -1282,3 +1478,17 @@ func (r *Runner) Debugf(format string, args ...interface{}) {
    msg := fmt.Sprintf(format, args...)
    r.logInternally("DEBUG", msg)
}

// GetConnectionStats returns connection health statistics for monitoring
func (r *Runner) GetConnectionStats() map[string]interface{} {
    r.connMutex.RLock()
    defer r.connMutex.RUnlock()

    return map[string]interface{}{
        "last_connection_check": r.lastConnCheck,
        "retry_attempts":        r.connRetryAttempts,
        "max_retries":           r.maxConnRetries,
        "connection_healthy":    r.isConnectionHealthy(),
        "connection_exists":     r.conn != nil,
    }
}
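The reconnection helper added above follows a common retry shape: a bounded number of attempts, a wait that grows with the attempt number (the code multiplies the attempt count by connRetryDelay, so the delay grows linearly), and an early exit when the runner's context is cancelled. Below is a minimal, runnable sketch of that shape, with a placeholder connectOnce standing in for the real gRPC dial; the names here are illustrative, not crawlab's API.

package main

import (
    "context"
    "fmt"
    "time"
)

// connectOnce is a placeholder for the real gRPC dial used in the diff; it
// fails twice so the retry path is exercised.
func connectOnce(attempt int) (string, error) {
    if attempt < 2 {
        return "", fmt.Errorf("simulated dial failure on attempt %d", attempt+1)
    }
    return "stream", nil
}

// reconnectWithBackoff mirrors the shape of reconnectWithRetry in the diff:
// bounded attempts, a growing wait between them, and cancellation support.
func reconnectWithBackoff(ctx context.Context, maxRetries int, baseDelay time.Duration) (string, error) {
    for attempt := 0; attempt < maxRetries; attempt++ {
        if attempt > 0 {
            select {
            case <-ctx.Done():
                return "", fmt.Errorf("context cancelled during reconnection")
            case <-time.After(time.Duration(attempt) * baseDelay):
            }
        }
        conn, err := connectOnce(attempt)
        if err != nil {
            fmt.Println("retrying:", err)
            continue
        }
        return conn, nil
    }
    return "", fmt.Errorf("failed to reconnect after %d attempts", maxRetries)
}

func main() {
    ctx, cancel := context.WithCancel(context.Background())
    defer cancel()
    conn, err := reconnectWithBackoff(ctx, 5, 10*time.Millisecond)
    fmt.Println(conn, err) // expected: "stream <nil>" after two simulated failures
}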
core/task/handler/runner_resilience_test.go (new file, 219 lines)
@@ -0,0 +1,219 @@
package handler

import (
    "context"
    "runtime"
    "sync"
    "testing"
    "time"

    "github.com/crawlab-team/crawlab/core/utils"
    "go.mongodb.org/mongo-driver/bson/primitive"
)

// TestRunner_LongRunningTaskResilience tests the robustness features for long-running tasks
func TestRunner_LongRunningTaskResilience(t *testing.T) {
    // Create a mock task runner with the resilience features
    r := &Runner{
        tid:                 primitive.NewObjectID(),
        maxConnRetries:      10,
        connRetryDelay:      10 * time.Second,
        ipcTimeout:          60 * time.Second,
        healthCheckInterval: 5 * time.Second,
        connHealthInterval:  60 * time.Second,
        Logger:              utils.NewLogger("TestTaskRunner"),
    }

    // Initialize context
    r.ctx, r.cancel = context.WithCancel(context.Background())
    defer r.cancel()

    // Test that default values are set for robust execution
    if r.maxConnRetries != 10 {
        t.Errorf("Expected maxConnRetries to be 10, got %d", r.maxConnRetries)
    }

    if r.ipcTimeout != 60*time.Second {
        t.Errorf("Expected ipcTimeout to be 60s, got %v", r.ipcTimeout)
    }

    if r.connHealthInterval != 60*time.Second {
        t.Errorf("Expected connHealthInterval to be 60s, got %v", r.connHealthInterval)
    }

    if r.healthCheckInterval != 5*time.Second {
        t.Errorf("Expected healthCheckInterval to be 5s, got %v", r.healthCheckInterval)
    }

    t.Log("✅ All resilience settings configured correctly for robust task execution")
}

// TestRunner_ConnectionHealthMonitoring tests the connection health monitoring
func TestRunner_ConnectionHealthMonitoring(t *testing.T) {
    r := &Runner{
        tid:                primitive.NewObjectID(),
        maxConnRetries:     3,
        connRetryDelay:     100 * time.Millisecond, // Short delay for testing
        connHealthInterval: 200 * time.Millisecond, // Short interval for testing
        Logger:             utils.NewLogger("TestTaskRunner"),
    }

    // Initialize context
    r.ctx, r.cancel = context.WithCancel(context.Background())
    defer r.cancel()

    // Test connection stats
    stats := r.GetConnectionStats()
    if stats == nil {
        t.Fatal("GetConnectionStats returned nil")
    }

    // Check that all expected keys are present
    expectedKeys := []string{
        "last_connection_check",
        "retry_attempts",
        "max_retries",
        "connection_healthy",
        "connection_exists",
    }

    for _, key := range expectedKeys {
        if _, exists := stats[key]; !exists {
            t.Errorf("Expected key '%s' not found in connection stats", key)
        }
    }

    // Test that connection is initially unhealthy (no actual connection)
    if stats["connection_healthy"].(bool) {
        t.Error("Expected connection to be unhealthy without actual connection")
    }

    if stats["connection_exists"].(bool) {
        t.Error("Expected connection_exists to be false without actual connection")
    }

    t.Log("✅ Connection health monitoring working correctly")
}

// TestRunner_PeriodicCleanup tests the periodic cleanup functionality
func TestRunner_PeriodicCleanup(t *testing.T) {
    r := &Runner{
        tid:    primitive.NewObjectID(),
        Logger: utils.NewLogger("TestTaskRunner"),
    }

    // Initialize context
    r.ctx, r.cancel = context.WithCancel(context.Background())
    defer r.cancel()

    // Record memory stats before cleanup
    var beforeStats runtime.MemStats
    runtime.ReadMemStats(&beforeStats)

    // Run cleanup
    r.runPeriodicCleanup()

    // Record memory stats after cleanup
    var afterStats runtime.MemStats
    runtime.ReadMemStats(&afterStats)

    // Verify that GC was called (NumGC should have increased)
    if afterStats.NumGC <= beforeStats.NumGC {
        t.Log("Note: GC count didn't increase, but this is normal in test environment")
    }

    t.Log("✅ Periodic cleanup executed successfully")
}

// TestRunner_ContextCancellation tests proper context handling
func TestRunner_ContextCancellation(t *testing.T) {
    r := &Runner{
        tid:    primitive.NewObjectID(),
        Logger: utils.NewLogger("TestTaskRunner"),
    }

    // Initialize context
    r.ctx, r.cancel = context.WithCancel(context.Background())

    // Test writeLogLines with cancelled context
    r.cancel() // Cancel context first

    // This should return early without error
    r.writeLogLines([]string{"test log"})

    t.Log("✅ Context cancellation handled correctly")
}

// TestRunner_ThreadSafety tests thread-safe access to connection
func TestRunner_ThreadSafety(t *testing.T) {
    r := &Runner{
        tid:    primitive.NewObjectID(),
        Logger: utils.NewLogger("TestTaskRunner"),
    }

    // Initialize context
    r.ctx, r.cancel = context.WithCancel(context.Background())
    defer r.cancel()

    var wg sync.WaitGroup

    // Start multiple goroutines accessing connection stats
    for i := 0; i < 10; i++ {
        wg.Add(1)
        go func(id int) {
            defer wg.Done()

            for j := 0; j < 100; j++ {
                // Access connection stats (this uses RWMutex)
                stats := r.GetConnectionStats()
                if stats == nil {
                    t.Errorf("Goroutine %d: GetConnectionStats returned nil", id)
                    return
                }

                // Small delay to increase chance of race conditions
                time.Sleep(1 * time.Millisecond)
            }
        }(i)
    }

    // Wait for all goroutines to complete
    done := make(chan bool)
    go func() {
        wg.Wait()
        done <- true
    }()

    select {
    case <-done:
        t.Log("✅ Thread safety test completed successfully")
    case <-time.After(10 * time.Second):
        t.Fatal("Thread safety test timed out")
    }
}

// BenchmarkRunner_ConnectionStats benchmarks the connection stats access
func BenchmarkRunner_ConnectionStats(b *testing.B) {
    r := &Runner{
        tid:                 primitive.NewObjectID(),
        maxConnRetries:      10,
        connRetryDelay:      10 * time.Second,
        ipcTimeout:          60 * time.Second,
        healthCheckInterval: 5 * time.Second,
        connHealthInterval:  60 * time.Second,
        Logger:              utils.NewLogger("BenchmarkTaskRunner"),
    }

    // Initialize context
    r.ctx, r.cancel = context.WithCancel(context.Background())
    defer r.cancel()

    b.ResetTimer()

    for i := 0; i < b.N; i++ {
        stats := r.GetConnectionStats()
        if stats == nil {
            b.Fatal("GetConnectionStats returned nil")
        }
    }
}
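Assuming a checked-out workspace with a standard Go toolchain (the package paths below come from the file headers above; the exact module root may differ), the new tests and benchmark should be runnable with the usual commands, e.g.:

go test ./core/grpc/server/ -run TestTaskServiceServer -v
go test ./core/task/handler/ -run TestRunner -v
go test ./core/task/handler/ -bench BenchmarkRunner_ConnectionStats -run '^$'

Note that TestTaskServiceServer_Stop builds a real server via newTaskServiceServer, so it may need the node config and stats services to be reachable in the test environment.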