diff --git a/core/grpc/client/client.go b/core/grpc/client/client.go index a3616bef..4ff2ef04 100644 --- a/core/grpc/client/client.go +++ b/core/grpc/client/client.go @@ -2,7 +2,6 @@ package client import ( "context" - "encoding/json" "fmt" "sync" "time" @@ -39,10 +38,21 @@ type GrpcClient struct { ModelBaseServiceClient grpc2.ModelBaseServiceClient DependencyClient grpc2.DependencyServiceClient MetricClient grpc2.MetricServiceClient + + // Add new fields for state management + state connectivity.State + stateMux sync.RWMutex + reconnect chan struct{} } func (c *GrpcClient) Start() (err error) { c.once.Do(func() { + // initialize reconnect channel + c.reconnect = make(chan struct{}) + + // start state monitor + go c.monitorState() + // connect err = c.connect() if err != nil { @@ -69,6 +79,7 @@ func (c *GrpcClient) Stop() (err error) { // close connection if err := c.conn.Close(); err != nil { + log.Errorf("grpc client failed to close connection: %v", err) return err } log.Infof("grpc client disconnected from %s", c.address) @@ -83,6 +94,7 @@ func (c *GrpcClient) WaitForReady() { select { case <-ticker.C: if c.IsReady() { + log.Debugf("grpc client ready") return } case <-c.stop: @@ -104,7 +116,8 @@ func (c *GrpcClient) Context() (ctx context.Context, cancel context.CancelFunc) } func (c *GrpcClient) IsReady() (res bool) { - return c.conn != nil && c.conn.GetState() == connectivity.Ready + state := c.conn.GetState() + return c.conn != nil && state == connectivity.Ready } func (c *GrpcClient) IsClosed() (res bool) { @@ -114,24 +127,75 @@ func (c *GrpcClient) IsClosed() (res bool) { return false } -func (c *GrpcClient) getRequestData(d interface{}) (data []byte) { - if d == nil { - return data - } - switch d.(type) { - case []byte: - data = d.([]byte) - default: - var err error - data, err = json.Marshal(d) - if err != nil { - panic(err) +func (c *GrpcClient) monitorState() { + for { + select { + case <-c.stop: + return + default: + if c.conn == nil { + time.Sleep(time.Second) + continue + } + + previous := c.getState() + current := c.conn.GetState() + + if previous != current { + c.setState(current) + log.Infof("[GrpcClient] state changed from %s to %s", previous, current) + + // Trigger reconnect if connection is lost or becomes idle from ready state + if current == connectivity.TransientFailure || + current == connectivity.Shutdown || + (previous == connectivity.Ready && current == connectivity.Idle) { + select { + case c.reconnect <- struct{}{}: + log.Infof("[GrpcClient] triggering reconnection due to state change to %s", current) + default: + } + } + } + + time.Sleep(time.Second) } } - return data +} + +func (c *GrpcClient) setState(state connectivity.State) { + c.stateMux.Lock() + defer c.stateMux.Unlock() + c.state = state +} + +func (c *GrpcClient) getState() connectivity.State { + c.stateMux.RLock() + defer c.stateMux.RUnlock() + return c.state } func (c *GrpcClient) connect() (err error) { + // Start reconnection loop + go func() { + for { + select { + case <-c.stop: + return + case <-c.reconnect: + if !c.stopped { + log.Infof("[GrpcClient] attempting to reconnect to %s", c.address) + if err := c.doConnect(); err != nil { + log.Errorf("[GrpcClient] reconnection failed: %v", err) + } + } + } + } + }() + + return c.doConnect() +} + +func (c *GrpcClient) doConnect() (err error) { op := func() error { // connection options opts := []grpc.DialOption{ @@ -164,10 +228,9 @@ func (c *GrpcClient) connect() (err error) { return nil } - b := backoff.NewExponentialBackOff( - backoff.WithInitialInterval(5*time.Second), - backoff.WithMaxElapsedTime(10*time.Minute), - ) + b := backoff.NewExponentialBackOff() + b.InitialInterval = 5 * time.Second + b.MaxElapsedTime = 10 * time.Minute n := func(err error, duration time.Duration) { log.Errorf("[GrpcClient] grpc client failed to connect to %s: %v, retrying in %s", c.address, err, duration) } @@ -179,6 +242,7 @@ func newGrpcClient() (c *GrpcClient) { address: utils.GetGrpcAddress(), timeout: 10 * time.Second, stop: make(chan struct{}), + state: connectivity.Idle, } } diff --git a/core/task/handler/runner.go b/core/task/handler/runner.go index 13d87bce..e33499df 100644 --- a/core/task/handler/runner.go +++ b/core/task/handler/runner.go @@ -75,14 +75,14 @@ type Runner struct { // Init initializes the task runner by updating the task status and establishing gRPC connections func (r *Runner) Init() (err error) { + // wait for grpc client ready + client2.GetGrpcClient().WaitForReady() + // update task if err := r.updateTask("", nil); err != nil { return err } - // wait for grpc client ready - client2.GetGrpcClient().WaitForReady() - // grpc task service stream client if err := r.initConnection(); err != nil { return err diff --git a/core/task/handler/runner_test.go b/core/task/handler/runner_test.go index 2dd431c6..ffa451d2 100644 --- a/core/task/handler/runner_test.go +++ b/core/task/handler/runner_test.go @@ -6,6 +6,9 @@ import ( "encoding/json" "fmt" "github.com/crawlab-team/crawlab/core/entity" + "github.com/crawlab-team/crawlab/core/grpc/client" + "github.com/crawlab-team/crawlab/core/grpc/server" + "github.com/crawlab-team/crawlab/core/utils" "io" "runtime" "testing" @@ -15,7 +18,6 @@ import ( "github.com/crawlab-team/crawlab/core/constants" "github.com/crawlab-team/crawlab/core/models/models" "github.com/crawlab-team/crawlab/core/models/service" - "github.com/crawlab-team/crawlab/core/utils" "github.com/crawlab-team/crawlab/grpc" "github.com/spf13/viper" "github.com/stretchr/testify/assert" @@ -23,11 +25,30 @@ import ( "go.mongodb.org/mongo-driver/bson/primitive" ) -func setupTest(t *testing.T) *Runner { +func setupGrpc(t *testing.T) { // Mock IsMaster function by setting viper config viper.Set("node.master", true) defer viper.Set("node.master", nil) // cleanup after test + // Start a gRPC server + svr := server.GetGrpcServer() + err := svr.Start() + require.Nil(t, err) + + // Start a gRPC client + err = client.GetGrpcClient().Start() + require.Nil(t, err) + + // Cleanup + t.Cleanup(func() { + err = svr.Stop() + if err != nil { + log.Warnf("failed to stop gRPC server: %v", err) + } + }) +} + +func setupRunner(t *testing.T) *Runner { // Create a test spider spider := &models.Spider{ Name: "Test Spider", @@ -72,226 +93,297 @@ func setupTest(t *testing.T) *Runner { return runner } -func TestRunner_HandleIPC(t *testing.T) { +func TestRunner(t *testing.T) { // Setup test data - runner := setupTest(t) + setupGrpc(t) - // Create a pipe for testing - pr, pw := io.Pipe() - defer pr.Close() - defer pw.Close() - runner.stdoutPipe = pr + t.Run("HandleIPC", func(t *testing.T) { + // Create a runner + runner := setupRunner(t) - // Start IPC reader - go runner.startIPCReader() + // Create a pipe for testing + pr, pw := io.Pipe() + defer pr.Close() + defer pw.Close() + runner.stdoutPipe = pr - // Create test message - testMsg := entity.IPCMessage{ - Type: "test_type", - Payload: map[string]interface{}{"key": "value"}, - IPC: true, - } + // Start IPC reader + go runner.startIPCReader() + + // Create test message + testMsg := entity.IPCMessage{ + Type: "test_type", + Payload: map[string]interface{}{"key": "value"}, + IPC: true, + } + + // Create a channel to signal that the message was handled + handled := make(chan bool) + runner.SetIPCHandler(func(msg entity.IPCMessage) { + assert.Equal(t, testMsg.Type, msg.Type) + assert.Equal(t, testMsg.Payload, msg.Payload) + handled <- true + }) + + // Convert message to JSON and write to pipe + go func() { + jsonData, err := json.Marshal(testMsg) + if err != nil { + t.Errorf("failed to marshal test message: %v", err) + return + } + + // Write message followed by newline + _, err = fmt.Fprintln(pw, string(jsonData)) + if err != nil { + t.Errorf("failed to write to pipe: %v", err) + return + } + }() + + select { + case <-handled: + // Message was handled successfully + log.Info("IPC message was handled successfully") + case <-time.After(3 * time.Second): + t.Fatal("timeout waiting for IPC message to be handled") + } - // Create a channel to signal that the message was handled - handled := make(chan bool) - runner.SetIPCHandler(func(msg entity.IPCMessage) { - assert.Equal(t, testMsg.Type, msg.Type) - assert.Equal(t, testMsg.Payload, msg.Payload) - handled <- true }) - // Convert message to JSON and write to pipe - go func() { - jsonData, err := json.Marshal(testMsg) - if err != nil { - t.Errorf("failed to marshal test message: %v", err) - return - } + t.Run("Cancel", func(t *testing.T) { + // Create a runner + runner := setupRunner(t) - // Write message followed by newline - _, err = fmt.Fprintln(pw, string(jsonData)) - if err != nil { - t.Errorf("failed to write to pipe: %v", err) - return - } - }() + // Create pipes for stdout + pr, pw := io.Pipe() + runner.cmd.Stdout = pw + runner.cmd.Stderr = pw - select { - case <-handled: - // Message was handled successfully - log.Info("IPC message was handled successfully") - case <-time.After(3 * time.Second): - t.Fatal("timeout waiting for IPC message to be handled") - } -} - -func TestRunner_Cancel(t *testing.T) { - // Setup - runner := setupTest(t) - - // Create pipes for stdout - pr, pw := io.Pipe() - runner.cmd.Stdout = pw - runner.cmd.Stderr = pw - - // Start the command - err := runner.cmd.Start() - assert.NoError(t, err) - log.Infof("started process with PID: %d", runner.cmd.Process.Pid) - runner.pid = runner.cmd.Process.Pid - - // Read and print command output - go func() { - scanner := bufio.NewScanner(pr) - for scanner.Scan() { - log.Info(scanner.Text()) - } - }() - - // Wait for process to finish - go func() { - err = runner.cmd.Wait() - if err != nil { - log.Errorf("process[%d] exited with error: %v", runner.pid, err) - return - } - log.Infof("process[%d] exited successfully", runner.pid) - }() - - // Wait for a certain period for the process to start properly - time.Sleep(1 * time.Second) - - // Verify process exists before attempting to cancel - if !utils.ProcessIdExists(runner.pid) { - t.Fatalf("Process with PID %d was not started successfully", runner.pid) - } - - // Test cancel - go func() { - err = runner.Cancel(true) + // Start the command + err := runner.cmd.Start() assert.NoError(t, err) - log.Infof("process[%d] cancelled", runner.pid) - }() + log.Infof("started process with PID: %d", runner.cmd.Process.Pid) + runner.pid = runner.cmd.Process.Pid - // Wait for process to be killed, with shorter timeout - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - ticker := time.NewTicker(100 * time.Millisecond) - defer ticker.Stop() - for { - select { - case <-ctx.Done(): - t.Fatalf("Process with PID %d was not killed within timeout", runner.pid) - case <-ticker.C: - exists := utils.ProcessIdExists(runner.pid) - if !exists { - return // Exit the test when process is killed + // Read and print command output + go func() { + scanner := bufio.NewScanner(pr) + for scanner.Scan() { + log.Info(scanner.Text()) + } + }() + + // Wait for process to finish + go func() { + err = runner.cmd.Wait() + if err != nil { + log.Errorf("process[%d] exited with error: %v", runner.pid, err) + return + } + log.Infof("process[%d] exited successfully", runner.pid) + }() + + // Wait for a certain period for the process to start properly + time.Sleep(1 * time.Second) + + // Verify process exists before attempting to cancel + if !utils.ProcessIdExists(runner.pid) { + t.Fatalf("Process with PID %d was not started successfully", runner.pid) + } + + // Test cancel + go func() { + err = runner.Cancel(true) + assert.NoError(t, err) + log.Infof("process[%d] cancelled", runner.pid) + }() + + // Wait for process to be killed, with shorter timeout + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + ticker := time.NewTicker(100 * time.Millisecond) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + t.Fatalf("Process with PID %d was not killed within timeout", runner.pid) + case <-ticker.C: + exists := utils.ProcessIdExists(runner.pid) + if !exists { + return // Exit the test when process is killed + } } } - } -} + }) -func TestRunner_HandleIPCData(t *testing.T) { - // Setup test data - runner := setupTest(t) + t.Run("HandleIPCData", func(t *testing.T) { + // Create a runner + runner := setupRunner(t) - // Create pipes for testing - pr, pw := io.Pipe() - defer pr.Close() - defer pw.Close() - runner.stdoutPipe = pr + // Create pipes for testing + pr, pw := io.Pipe() + defer pr.Close() + defer pw.Close() + runner.stdoutPipe = pr - // Start IPC reader - go runner.startIPCReader() + // Start IPC reader + go runner.startIPCReader() - // Test cases - testCases := []struct { - name string - payload interface{} - expected int // expected number of records - }{ - { - name: "single object", - payload: map[string]interface{}{ - "field1": "value1", - "field2": 123, - }, - expected: 1, - }, - { - name: "array of objects", - payload: []map[string]interface{}{ - { + // Test cases + testCases := []struct { + name string + payload interface{} + expected int // expected number of records + }{ + { + name: "single object", + payload: map[string]interface{}{ "field1": "value1", "field2": 123, }, - { - "field1": "value2", - "field2": 456, - }, + expected: 1, }, - expected: 2, - }, - { - name: "empty payload", - payload: nil, - expected: 0, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - // Create a channel to track processed messages - processed := make(chan int) - - // Mock the gRPC connection - runner.conn = &mockConnectClient{ - sendFunc: func(req *grpc.TaskServiceConnectRequest) error { - // Verify the request - assert.Equal(t, grpc.TaskServiceConnectCode_INSERT_DATA, req.Code) - assert.Equal(t, runner.tid.Hex(), req.TaskId) - - // If payload was nil, we expect no data - if tc.payload == nil { - processed <- 0 - return nil - } - - // Unmarshal the data to verify record count - var records []map[string]interface{} - err := json.Unmarshal(req.Data, &records) - assert.NoError(t, err) - processed <- len(records) - return nil + { + name: "array of objects", + payload: []map[string]interface{}{ + { + "field1": "value1", + "field2": 123, + }, + { + "field1": "value2", + "field2": 456, + }, }, - } + expected: 2, + }, + { + name: "empty payload", + payload: nil, + expected: 0, + }, + } - // Create test message - testMsg := entity.IPCMessage{ - Type: constants.IPCMessageData, - Payload: tc.payload, - IPC: true, - } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Create a channel to track processed messages + processed := make(chan int) - // Convert message to JSON and write to pipe - go func() { - jsonData, err := json.Marshal(testMsg) - assert.NoError(t, err) - _, err = fmt.Fprintln(pw, string(jsonData)) - assert.NoError(t, err) - }() + // Mock the gRPC connection + runner.conn = &mockConnectClient{ + sendFunc: func(req *grpc.TaskServiceConnectRequest) error { + // Verify the request + assert.Equal(t, grpc.TaskServiceConnectCode_INSERT_DATA, req.Code) + assert.Equal(t, runner.tid.Hex(), req.TaskId) - // Wait for processing with timeout - select { - case recordCount := <-processed: - assert.Equal(t, tc.expected, recordCount) - case <-time.After(1 * time.Second): - if tc.expected > 0 { - t.Fatal("timeout waiting for IPC message to be processed") + // If payload was nil, we expect no data + if tc.payload == nil { + processed <- 0 + return nil + } + + // Unmarshal the data to verify record count + var records []map[string]interface{} + err := json.Unmarshal(req.Data, &records) + assert.NoError(t, err) + processed <- len(records) + return nil + }, } - } - }) - } + + // Create test message + testMsg := entity.IPCMessage{ + Type: constants.IPCMessageData, + Payload: tc.payload, + IPC: true, + } + + // Convert message to JSON and write to pipe + go func() { + jsonData, err := json.Marshal(testMsg) + assert.NoError(t, err) + _, err = fmt.Fprintln(pw, string(jsonData)) + assert.NoError(t, err) + }() + + // Wait for processing with timeout + select { + case recordCount := <-processed: + assert.Equal(t, tc.expected, recordCount) + case <-time.After(1 * time.Second): + if tc.expected > 0 { + t.Fatal("timeout waiting for IPC message to be processed") + } + } + }) + } + }) + + t.Run("HandleIPCInvalidData", func(t *testing.T) { + // Create a runner + runner := setupRunner(t) + + // Create pipes for testing + pr, pw := io.Pipe() + defer pr.Close() + defer pw.Close() + runner.stdoutPipe = pr + + // Start IPC reader + go runner.startIPCReader() + + // Test cases for invalid data + testCases := []struct { + name string + message string // Raw message to send + }{ + { + name: "invalid json", + message: "{ invalid json", + }, + { + name: "non-ipc json", + message: `{"type": "data", "payload": {"field": "value"}}`, // Missing IPC flag + }, + { + name: "invalid payload type", + message: `{"type": "data", "payload": "invalid", "ipc": true}`, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Create a channel to ensure no data is processed + processed := make(chan struct{}) + + // Mock the gRPC connection + runner.conn = &mockConnectClient{ + sendFunc: func(req *grpc.TaskServiceConnectRequest) error { + if req.Code == grpc.TaskServiceConnectCode_INSERT_DATA { + // This should not be called for invalid data + processed <- struct{}{} + } + return nil + }, + } + + // Write test message to pipe + go func() { + _, err := fmt.Fprintln(pw, tc.message) + assert.NoError(t, err) + }() + + // Wait briefly to ensure no processing occurs + select { + case <-processed: + t.Error("invalid message was processed") + case <-time.After(1 * time.Second): + // Success - no processing occurred + } + }) + } + }) } // mockConnectClient is a mock implementation of the gRPC Connect client @@ -306,68 +398,3 @@ func (m *mockConnectClient) Send(req *grpc.TaskServiceConnectRequest) error { } return nil } - -func TestRunner_HandleIPCInvalidData(t *testing.T) { - // Setup test data - runner := setupTest(t) - - // Create pipes for testing - pr, pw := io.Pipe() - defer pr.Close() - defer pw.Close() - runner.stdoutPipe = pr - - // Start IPC reader - go runner.startIPCReader() - - // Test cases for invalid data - testCases := []struct { - name string - message string // Raw message to send - }{ - { - name: "invalid json", - message: "{ invalid json", - }, - { - name: "non-ipc json", - message: `{"type": "data", "payload": {"field": "value"}}`, // Missing IPC flag - }, - { - name: "invalid payload type", - message: `{"type": "data", "payload": "invalid", "ipc": true}`, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - // Create a channel to ensure no data is processed - processed := make(chan struct{}) - - // Mock the gRPC connection - runner.conn = &mockConnectClient{ - sendFunc: func(req *grpc.TaskServiceConnectRequest) error { - if req.Code == grpc.TaskServiceConnectCode_INSERT_DATA { - // This should not be called for invalid data - processed <- struct{}{} - } - return nil - }, - } - - // Write test message to pipe - go func() { - _, err := fmt.Fprintln(pw, tc.message) - assert.NoError(t, err) - }() - - // Wait briefly to ensure no processing occurs - select { - case <-processed: - t.Error("invalid message was processed") - case <-time.After(1 * time.Second): - // Success - no processing occurred - } - }) - } -}