diff --git a/core/task/handler/runner_test.go b/core/task/handler/runner_test.go index b17113f6..10eca89d 100644 --- a/core/task/handler/runner_test.go +++ b/core/task/handler/runner_test.go @@ -11,6 +11,7 @@ import ( "github.com/apex/log" "github.com/crawlab-team/crawlab/core/utils" + "github.com/crawlab-team/crawlab/grpc" "github.com/crawlab-team/crawlab/core/constants" "github.com/crawlab-team/crawlab/core/models/models" @@ -174,3 +175,180 @@ func TestRunner_Cancel(t *testing.T) { } t.Errorf("Process with PID %d was not killed within timeout", runner.pid) } + +func TestRunner_HandleIPCData(t *testing.T) { + // Setup test data + runner := setupTest(t) + + // Create pipes for testing + pr, pw := io.Pipe() + defer pr.Close() + defer pw.Close() + runner.stdoutPipe = pr + + // Start IPC reader + go runner.startIPCReader() + + // Test cases + testCases := []struct { + name string + payload interface{} + expected int // expected number of records + }{ + { + name: "single object", + payload: map[string]interface{}{ + "field1": "value1", + "field2": 123, + }, + expected: 1, + }, + { + name: "array of objects", + payload: []map[string]interface{}{ + { + "field1": "value1", + "field2": 123, + }, + { + "field1": "value2", + "field2": 456, + }, + }, + expected: 2, + }, + { + name: "empty payload", + payload: nil, + expected: 0, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Create a channel to track processed messages + processed := make(chan int) + + // Mock the gRPC connection + runner.conn = &mockConnectClient{ + sendFunc: func(req *grpc.TaskServiceConnectRequest) error { + // Verify the request + assert.Equal(t, grpc.TaskServiceConnectCode_INSERT_DATA, req.Code) + assert.Equal(t, runner.tid.Hex(), req.TaskId) + + // If payload was nil, we expect no data + if tc.payload == nil { + processed <- 0 + return nil + } + + // Unmarshal the data to verify record count + var records []map[string]interface{} + err := json.Unmarshal(req.Data, &records) + assert.NoError(t, err) + processed <- len(records) + return nil + }, + } + + // Create test message + testMsg := IPCMessage{ + Type: IPCMessageData, + Payload: tc.payload, + IPC: true, + } + + // Convert message to JSON and write to pipe + go func() { + jsonData, err := json.Marshal(testMsg) + assert.NoError(t, err) + _, err = fmt.Fprintln(pw, string(jsonData)) + assert.NoError(t, err) + }() + + // Wait for processing with timeout + select { + case recordCount := <-processed: + assert.Equal(t, tc.expected, recordCount) + case <-time.After(3 * time.Second): + t.Fatal("timeout waiting for IPC message to be processed") + } + }) + } +} + +// mockConnectClient is a mock implementation of the gRPC Connect client +type mockConnectClient struct { + grpc.TaskService_ConnectClient + sendFunc func(*grpc.TaskServiceConnectRequest) error +} + +func (m *mockConnectClient) Send(req *grpc.TaskServiceConnectRequest) error { + if m.sendFunc != nil { + return m.sendFunc(req) + } + return nil +} + +func TestRunner_HandleIPCInvalidData(t *testing.T) { + // Setup test data + runner := setupTest(t) + + // Create pipes for testing + pr, pw := io.Pipe() + defer pr.Close() + defer pw.Close() + runner.stdoutPipe = pr + + // Start IPC reader + go runner.startIPCReader() + + // Test cases for invalid data + testCases := []struct { + name string + message string // Raw message to send + }{ + { + name: "invalid json", + message: "{ invalid json", + }, + { + name: "non-ipc json", + message: `{"type": "data", "payload": {"field": "value"}}`, // Missing IPC flag + }, + { + name: "invalid payload type", + message: `{"type": "data", "payload": "invalid", "ipc": true}`, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Create a channel to ensure no data is processed + processed := make(chan struct{}) + + // Mock the gRPC connection + runner.conn = &mockConnectClient{ + sendFunc: func(req *grpc.TaskServiceConnectRequest) error { + // This should not be called for invalid data + processed <- struct{}{} + return nil + }, + } + + // Write test message to pipe + go func() { + _, err := fmt.Fprintln(pw, tc.message) + assert.NoError(t, err) + }() + + // Wait briefly to ensure no processing occurs + select { + case <-processed: + t.Error("invalid message was processed") + case <-time.After(1 * time.Second): + // Success - no processing occurred + } + }) + } +}