feat: enhance gRPC client with state management and reconnection logic

- Introduced connection-state tracking in GrpcClient to monitor and react to state transitions.
- Added a reconnect channel and a state-monitoring goroutine to trigger automatic reconnections on state changes.
- Updated the connect method to run a background reconnection loop that re-dials on connection loss.
- Enhanced logging of connection state changes and of errors during connection attempts.
- Refactored tests to properly initialize the gRPC client and server, improving test reliability and coverage.
Author: Marvin Zhang
Date: 2024-12-21 21:41:00 +08:00
Parent: 75da4fff0f
Commit: 29af5a366b
3 changed files with 374 additions and 283 deletions
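Taken together, these changes give the client a supervised lifecycle. A minimal caller-side sketch, assuming only the accessor and method names visible in the diffs below:

    package main

    import (
        "log"

        "github.com/crawlab-team/crawlab/core/grpc/client"
    )

    func main() {
        c := client.GetGrpcClient()

        // Start connects, launches the state-monitoring goroutine,
        // and arms the reconnection loop.
        if err := c.Start(); err != nil {
            log.Fatalf("failed to start grpc client: %v", err)
        }
        defer func() { _ = c.Stop() }()

        // Block until the underlying channel reports connectivity.Ready.
        c.WaitForReady()

        // ... issue RPCs via the service clients, e.g. c.MetricClient ...
    }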

View File

@@ -2,7 +2,6 @@ package client
import (
"context"
"encoding/json"
"fmt"
"sync"
"time"
@@ -39,10 +38,21 @@ type GrpcClient struct {
ModelBaseServiceClient grpc2.ModelBaseServiceClient
DependencyClient grpc2.DependencyServiceClient
MetricClient grpc2.MetricServiceClient
// Add new fields for state management
state connectivity.State
stateMux sync.RWMutex
reconnect chan struct{}
}
func (c *GrpcClient) Start() (err error) {
c.once.Do(func() {
// initialize reconnect channel
c.reconnect = make(chan struct{})
// start state monitor
go c.monitorState()
// connect
err = c.connect()
if err != nil {
@@ -69,6 +79,7 @@ func (c *GrpcClient) Stop() (err error) {
// close connection
if err := c.conn.Close(); err != nil {
log.Errorf("grpc client failed to close connection: %v", err)
return err
}
log.Infof("grpc client disconnected from %s", c.address)
@@ -83,6 +94,7 @@ func (c *GrpcClient) WaitForReady() {
select {
case <-ticker.C:
if c.IsReady() {
log.Debugf("grpc client ready")
return
}
case <-c.stop:
@@ -104,7 +116,8 @@ func (c *GrpcClient) Context() (ctx context.Context, cancel context.CancelFunc)
}
func (c *GrpcClient) IsReady() (res bool) {
    if c.conn == nil {
        return false
    }
    return c.conn.GetState() == connectivity.Ready
}
func (c *GrpcClient) IsClosed() (res bool) {
@@ -114,24 +127,75 @@ func (c *GrpcClient) IsClosed() (res bool) {
return false
}
func (c *GrpcClient) getRequestData(d interface{}) (data []byte) {
    if d == nil {
        return data
    }
    switch d.(type) {
    case []byte:
        data = d.([]byte)
    default:
        var err error
        data, err = json.Marshal(d)
        if err != nil {
            panic(err)
        }
    }
    return data
}

func (c *GrpcClient) monitorState() {
    for {
        select {
        case <-c.stop:
            return
        default:
            if c.conn == nil {
                time.Sleep(time.Second)
                continue
            }
            previous := c.getState()
            current := c.conn.GetState()
            if previous != current {
                c.setState(current)
                log.Infof("[GrpcClient] state changed from %s to %s", previous, current)
                // Trigger reconnect if connection is lost or becomes idle from ready state
                if current == connectivity.TransientFailure ||
                    current == connectivity.Shutdown ||
                    (previous == connectivity.Ready && current == connectivity.Idle) {
                    select {
                    case c.reconnect <- struct{}{}:
                        log.Infof("[GrpcClient] triggering reconnection due to state change to %s", current)
                    default:
                    }
                }
            }
            time.Sleep(time.Second)
        }
    }
}
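The select with an empty default in monitorState is Go's non-blocking send idiom: if nothing is currently draining the reconnect channel, the signal is dropped instead of stalling the monitor loop. A self-contained illustration of just that idiom:

    package main

    import "fmt"

    func main() {
        reconnect := make(chan struct{}) // unbuffered, like GrpcClient.reconnect

        // No goroutine is ready to receive, so the send cannot proceed;
        // the default branch fires and the caller moves on without blocking.
        select {
        case reconnect <- struct{}{}:
            fmt.Println("reconnect triggered")
        default:
            fmt.Println("no receiver ready; signal dropped")
        }
    }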
func (c *GrpcClient) setState(state connectivity.State) {
    c.stateMux.Lock()
    defer c.stateMux.Unlock()
    c.state = state
}

func (c *GrpcClient) getState() connectivity.State {
    c.stateMux.RLock()
    defer c.stateMux.RUnlock()
    return c.state
}
func (c *GrpcClient) connect() (err error) {
    // Start reconnection loop
    go func() {
        for {
            select {
            case <-c.stop:
                return
            case <-c.reconnect:
                if !c.stopped {
                    log.Infof("[GrpcClient] attempting to reconnect to %s", c.address)
                    if err := c.doConnect(); err != nil {
                        log.Errorf("[GrpcClient] reconnection failed: %v", err)
                    }
                }
            }
        }
    }()
    return c.doConnect()
}
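The goroutine above follows a common supervisor shape: a long-lived loop that re-dials on demand and exits via a stop channel. A condensed, self-contained sketch of that shape (the names and the flaky dial are illustrative, not from this commit):

    package main

    import (
        "errors"
        "fmt"
        "math/rand"
        "time"
    )

    // supervise re-runs dial whenever a signal arrives on reconnect
    // and exits when stop is closed.
    func supervise(dial func() error, reconnect <-chan struct{}, stop <-chan struct{}) {
        for {
            select {
            case <-stop:
                return
            case <-reconnect:
                if err := dial(); err != nil {
                    fmt.Println("reconnect failed:", err)
                }
            }
        }
    }

    func main() {
        reconnect := make(chan struct{}, 1)
        stop := make(chan struct{})

        dial := func() error {
            if rand.Intn(2) == 0 { // simulate a flaky connection attempt
                return errors.New("dial failed")
            }
            fmt.Println("connected")
            return nil
        }

        go supervise(dial, reconnect, stop)
        reconnect <- struct{}{}
        time.Sleep(100 * time.Millisecond)
        close(stop)
    }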
func (c *GrpcClient) doConnect() (err error) {
op := func() error {
// connection options
opts := []grpc.DialOption{
@@ -164,10 +228,9 @@ func (c *GrpcClient) connect() (err error) {
return nil
}
b := backoff.NewExponentialBackOff(
backoff.WithInitialInterval(5*time.Second),
backoff.WithMaxElapsedTime(10*time.Minute),
)
b := backoff.NewExponentialBackOff()
b.InitialInterval = 5 * time.Second
b.MaxElapsedTime = 10 * time.Minute
n := func(err error, duration time.Duration) {
log.Errorf("[GrpcClient] grpc client failed to connect to %s: %v, retrying in %s", c.address, err, duration)
}
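An operation, a backoff policy, and a notify callback like op, b, and n here are the three inputs to backoff.RetryNotify from github.com/cenkalti/backoff/v4, which is presumably what the elided lines pass them to. A self-contained sketch with shortened intervals (the commit itself uses 5s and 10min):

    package main

    import (
        "errors"
        "log"
        "time"

        "github.com/cenkalti/backoff/v4"
    )

    func main() {
        attempts := 0
        op := func() error {
            attempts++
            if attempts < 3 {
                return errors.New("connection refused") // fail twice, then succeed
            }
            return nil
        }

        b := backoff.NewExponentialBackOff()
        b.InitialInterval = 100 * time.Millisecond
        b.MaxElapsedTime = 5 * time.Second

        n := func(err error, d time.Duration) {
            log.Printf("connect failed: %v, retrying in %s", err, d)
        }

        // RetryNotify re-runs op per the backoff policy, calling n before
        // each sleep, until op returns nil or MaxElapsedTime is exceeded.
        if err := backoff.RetryNotify(op, b, n); err != nil {
            log.Fatalf("gave up: %v", err)
        }
        log.Printf("connected after %d attempts", attempts)
    }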
@@ -179,6 +242,7 @@ func newGrpcClient() (c *GrpcClient) {
address: utils.GetGrpcAddress(),
timeout: 10 * time.Second,
stop: make(chan struct{}),
state: connectivity.Idle,
}
}

View File

@@ -75,14 +75,14 @@ type Runner struct {
// Init initializes the task runner by updating the task status and establishing gRPC connections
func (r *Runner) Init() (err error) {
// wait for grpc client ready
client2.GetGrpcClient().WaitForReady()
// update task
if err := r.updateTask("", nil); err != nil {
return err
}
// grpc task service stream client
if err := r.initConnection(); err != nil {
return err

View File

@@ -6,6 +6,9 @@ import (
"encoding/json"
"fmt"
"github.com/crawlab-team/crawlab/core/entity"
"github.com/crawlab-team/crawlab/core/grpc/client"
"github.com/crawlab-team/crawlab/core/grpc/server"
"github.com/crawlab-team/crawlab/core/utils"
"io"
"runtime"
"testing"
@@ -15,7 +18,6 @@ import (
"github.com/crawlab-team/crawlab/core/constants"
"github.com/crawlab-team/crawlab/core/models/models"
"github.com/crawlab-team/crawlab/core/models/service"
"github.com/crawlab-team/crawlab/core/utils"
"github.com/crawlab-team/crawlab/grpc"
"github.com/spf13/viper"
"github.com/stretchr/testify/assert"
@@ -23,11 +25,30 @@ import (
"go.mongodb.org/mongo-driver/bson/primitive"
)
func setupTest(t *testing.T) *Runner {
func setupGrpc(t *testing.T) {
    // Mock IsMaster function by setting viper config
    viper.Set("node.master", true)
    defer viper.Set("node.master", nil) // cleanup after test

    // Start a gRPC server
    svr := server.GetGrpcServer()
    err := svr.Start()
    require.Nil(t, err)

    // Start a gRPC client
    err = client.GetGrpcClient().Start()
    require.Nil(t, err)

    // Cleanup
    t.Cleanup(func() {
        err = svr.Stop()
        if err != nil {
            log.Warnf("failed to stop gRPC server: %v", err)
        }
    })
}
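setupGrpc boots the real server and client singletons. For more isolated unit tests, grpc-go also offers an in-memory transport; a sketch of that alternative using google.golang.org/grpc/test/bufconn (not what this commit does, shown for context):

    package main

    import (
        "context"
        "log"
        "net"
        "time"

        "google.golang.org/grpc"
        "google.golang.org/grpc/credentials/insecure"
        "google.golang.org/grpc/test/bufconn"
    )

    func main() {
        lis := bufconn.Listen(1024 * 1024)
        srv := grpc.NewServer()
        // register service implementations on srv here
        go func() {
            if err := srv.Serve(lis); err != nil {
                log.Printf("server exited: %v", err)
            }
        }()
        defer srv.Stop()

        ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
        defer cancel()

        conn, err := grpc.DialContext(ctx, "bufnet",
            grpc.WithContextDialer(func(ctx context.Context, _ string) (net.Conn, error) {
                return lis.Dial()
            }),
            grpc.WithTransportCredentials(insecure.NewCredentials()),
        )
        if err != nil {
            log.Fatalf("dial failed: %v", err)
        }
        defer conn.Close()

        log.Printf("in-memory channel state: %s", conn.GetState())
    }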
func setupRunner(t *testing.T) *Runner {
    // Create a test spider
    spider := &models.Spider{
        Name: "Test Spider",
@@ -72,226 +93,297 @@ func setupTest(t *testing.T) *Runner {
    return runner
}

func TestRunner(t *testing.T) {
    // Setup test data
    setupGrpc(t)

    t.Run("HandleIPC", func(t *testing.T) {
        // Create a runner
        runner := setupRunner(t)

        // Create a pipe for testing
        pr, pw := io.Pipe()
        defer pr.Close()
        defer pw.Close()
        runner.stdoutPipe = pr

        // Start IPC reader
        go runner.startIPCReader()

        // Create test message
        testMsg := entity.IPCMessage{
            Type:    "test_type",
            Payload: map[string]interface{}{"key": "value"},
            IPC:     true,
        }

        // Create a channel to signal that the message was handled
        handled := make(chan bool)
        runner.SetIPCHandler(func(msg entity.IPCMessage) {
            assert.Equal(t, testMsg.Type, msg.Type)
            assert.Equal(t, testMsg.Payload, msg.Payload)
            handled <- true
        })

        // Convert message to JSON and write to pipe
        go func() {
            jsonData, err := json.Marshal(testMsg)
            if err != nil {
                t.Errorf("failed to marshal test message: %v", err)
                return
            }
            // Write message followed by newline
            _, err = fmt.Fprintln(pw, string(jsonData))
            if err != nil {
                t.Errorf("failed to write to pipe: %v", err)
                return
            }
        }()

        select {
        case <-handled:
            // Message was handled successfully
            log.Info("IPC message was handled successfully")
        case <-time.After(3 * time.Second):
            t.Fatal("timeout waiting for IPC message to be handled")
        }
    })

    t.Run("Cancel", func(t *testing.T) {
        // Create a runner
        runner := setupRunner(t)

        // Create pipes for stdout
        pr, pw := io.Pipe()
        runner.cmd.Stdout = pw
        runner.cmd.Stderr = pw

        // Start the command
        err := runner.cmd.Start()
        assert.NoError(t, err)
        log.Infof("started process with PID: %d", runner.cmd.Process.Pid)
        runner.pid = runner.cmd.Process.Pid

        // Read and print command output
        go func() {
            scanner := bufio.NewScanner(pr)
            for scanner.Scan() {
                log.Info(scanner.Text())
            }
        }()

        // Wait for process to finish
        go func() {
            err = runner.cmd.Wait()
            if err != nil {
                log.Errorf("process[%d] exited with error: %v", runner.pid, err)
                return
            }
            log.Infof("process[%d] exited successfully", runner.pid)
        }()

        // Wait for a certain period for the process to start properly
        time.Sleep(1 * time.Second)

        // Verify process exists before attempting to cancel
        if !utils.ProcessIdExists(runner.pid) {
            t.Fatalf("Process with PID %d was not started successfully", runner.pid)
        }

        // Test cancel
        go func() {
            err = runner.Cancel(true)
            assert.NoError(t, err)
            log.Infof("process[%d] cancelled", runner.pid)
        }()

        // Wait for process to be killed, with shorter timeout
        ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
        defer cancel()
        ticker := time.NewTicker(100 * time.Millisecond)
        defer ticker.Stop()
        for {
            select {
            case <-ctx.Done():
                t.Fatalf("Process with PID %d was not killed within timeout", runner.pid)
            case <-ticker.C:
                exists := utils.ProcessIdExists(runner.pid)
                if !exists {
                    return // Exit the test when process is killed
                }
            }
        }
    })

    t.Run("HandleIPCData", func(t *testing.T) {
        // Create a runner
        runner := setupRunner(t)

        // Create pipes for testing
        pr, pw := io.Pipe()
        defer pr.Close()
        defer pw.Close()
        runner.stdoutPipe = pr

        // Start IPC reader
        go runner.startIPCReader()

        // Test cases
        testCases := []struct {
            name     string
            payload  interface{}
            expected int // expected number of records
        }{
            {
                name: "single object",
                payload: map[string]interface{}{
                    "field1": "value1",
                    "field2": 123,
                },
                expected: 1,
            },
            {
                name: "array of objects",
                payload: []map[string]interface{}{
                    {
                        "field1": "value1",
                        "field2": 123,
                    },
                    {
                        "field1": "value2",
                        "field2": 456,
                    },
                },
                expected: 2,
            },
            {
                name:     "empty payload",
                payload:  nil,
                expected: 0,
            },
        }

        for _, tc := range testCases {
            t.Run(tc.name, func(t *testing.T) {
                // Create a channel to track processed messages
                processed := make(chan int)

                // Mock the gRPC connection
                runner.conn = &mockConnectClient{
                    sendFunc: func(req *grpc.TaskServiceConnectRequest) error {
                        // Verify the request
                        assert.Equal(t, grpc.TaskServiceConnectCode_INSERT_DATA, req.Code)
                        assert.Equal(t, runner.tid.Hex(), req.TaskId)

                        // If payload was nil, we expect no data
                        if tc.payload == nil {
                            processed <- 0
                            return nil
                        }

                        // Unmarshal the data to verify record count
                        var records []map[string]interface{}
                        err := json.Unmarshal(req.Data, &records)
                        assert.NoError(t, err)
                        processed <- len(records)
                        return nil
                    },
                }

                // Create test message
                testMsg := entity.IPCMessage{
                    Type:    constants.IPCMessageData,
                    Payload: tc.payload,
                    IPC:     true,
                }

                // Convert message to JSON and write to pipe
                go func() {
                    jsonData, err := json.Marshal(testMsg)
                    assert.NoError(t, err)
                    _, err = fmt.Fprintln(pw, string(jsonData))
                    assert.NoError(t, err)
                }()

                // Wait for processing with timeout
                select {
                case recordCount := <-processed:
                    assert.Equal(t, tc.expected, recordCount)
                case <-time.After(1 * time.Second):
                    if tc.expected > 0 {
                        t.Fatal("timeout waiting for IPC message to be processed")
                    }
                }
            })
        }
    })

    t.Run("HandleIPCInvalidData", func(t *testing.T) {
        // Create a runner
        runner := setupRunner(t)

        // Create pipes for testing
        pr, pw := io.Pipe()
        defer pr.Close()
        defer pw.Close()
        runner.stdoutPipe = pr

        // Start IPC reader
        go runner.startIPCReader()

        // Test cases for invalid data
        testCases := []struct {
            name    string
            message string // Raw message to send
        }{
            {
                name:    "invalid json",
                message: "{ invalid json",
            },
            {
                name:    "non-ipc json",
                message: `{"type": "data", "payload": {"field": "value"}}`, // Missing IPC flag
            },
            {
                name:    "invalid payload type",
                message: `{"type": "data", "payload": "invalid", "ipc": true}`,
            },
        }

        for _, tc := range testCases {
            t.Run(tc.name, func(t *testing.T) {
                // Create a channel to ensure no data is processed
                processed := make(chan struct{})

                // Mock the gRPC connection
                runner.conn = &mockConnectClient{
                    sendFunc: func(req *grpc.TaskServiceConnectRequest) error {
                        if req.Code == grpc.TaskServiceConnectCode_INSERT_DATA {
                            // This should not be called for invalid data
                            processed <- struct{}{}
                        }
                        return nil
                    },
                }

                // Write test message to pipe
                go func() {
                    _, err := fmt.Fprintln(pw, tc.message)
                    assert.NoError(t, err)
                }()

                // Wait briefly to ensure no processing occurs
                select {
                case <-processed:
                    t.Error("invalid message was processed")
                case <-time.After(1 * time.Second):
                    // Success - no processing occurred
                }
            })
        }
    })
}
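For context, the IPC contract these tests exercise is newline-delimited JSON on the child process's stdout, with an ipc: true marker separating data messages from ordinary log output. A minimal producer-side sketch of that convention (field names taken from the test messages above):

    package main

    import (
        "encoding/json"
        "fmt"
    )

    func main() {
        msg := map[string]interface{}{
            "type":    "data", // presumably constants.IPCMessageData on the Go side
            "payload": []map[string]interface{}{{"field1": "value1", "field2": 123}},
            "ipc":     true, // without this flag the line is treated as plain output
        }
        line, err := json.Marshal(msg)
        if err != nil {
            panic(err)
        }
        fmt.Println(string(line)) // one JSON object per line
    }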
// mockConnectClient is a mock implementation of the gRPC Connect client
@@ -306,68 +398,3 @@ func (m *mockConnectClient) Send(req *grpc.TaskServiceConnectRequest) error {
    }
    return nil
}