// Mirror of https://github.com/crawlab-team/crawlab.git
// (synced 2026-01-21 17:21:09 +01:00; 621 lines, 16 KiB, Go)
package client
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/cenkalti/backoff/v4"
|
|
"github.com/crawlab-team/crawlab/core/grpc/middlewares"
|
|
"github.com/crawlab-team/crawlab/core/interfaces"
|
|
"github.com/crawlab-team/crawlab/core/utils"
|
|
grpc2 "github.com/crawlab-team/crawlab/grpc"
|
|
"google.golang.org/grpc"
|
|
"google.golang.org/grpc/connectivity"
|
|
"google.golang.org/grpc/credentials/insecure"
|
|
"google.golang.org/grpc/keepalive"
|
|
)
|
|
|
|
// GrpcClient provides a robust gRPC client with connection management and client registration.
//
// The client handles connection lifecycle and ensures that gRPC service clients are properly
// initialized before use. All client fields are private and can only be accessed through
// safe getter methods that ensure registration before returning clients.
//
// Example usage:
//
//	client := GetGrpcClient()
//
//	// Safe access pattern - always use getter methods
//	nodeClient, err := client.GetNodeClient()
//	if err != nil {
//		return fmt.Errorf("failed to get node client: %v", err)
//	}
//	resp, err := nodeClient.Register(ctx, req)
//
//	// Alternative with timeout
//	taskClient, err := client.GetTaskClientWithTimeout(5 * time.Second)
//	if err != nil {
//		return fmt.Errorf("failed to get task client: %v", err)
//	}
//	stream, err := taskClient.Connect(ctx)
type GrpcClient struct {
	// settings
	address string        // server address passed to grpc.NewClient
	timeout time.Duration // deadline used by Context()

	// internals
	conn    *grpc.ClientConn // current connection; replaced on every connection attempt
	once    sync.Once        // ensures Start's initialisation runs at most once
	stopped bool             // set by Stop; NOTE(review): read/written without synchronisation
	stop    chan struct{}    // stop signal selected on by the background loops and WaitFor* helpers
	interfaces.Logger

	// clients (private to enforce safe access through getter methods)
	nodeClient             grpc2.NodeServiceClient
	taskClient             grpc2.TaskServiceClient
	modelBaseServiceClient grpc2.ModelBaseServiceClient
	dependencyClient       grpc2.DependencyServiceClient
	metricClient           grpc2.MetricServiceClient

	// Add new fields for state management
	state     connectivity.State // last connectivity state seen by monitorState
	stateMux  sync.RWMutex       // guards state
	reconnect chan struct{}      // monitorState -> reconnection loop signal; created in Start

	// Circuit breaker fields
	failureCount int          // consecutive failed connection attempts; breaker opens at 5
	lastFailure  time.Time    // time of the most recent recorded failure
	cbMux        sync.RWMutex // guards failureCount and lastFailure

	// Reconnection control
	reconnecting bool       // true while the reconnection loop is handling a request
	reconnectMux sync.Mutex // guards reconnecting

	// Registration status
	registered    bool         // true once register() has created the service clients
	registeredMux sync.RWMutex // guards registered
}
|
|
|
|
func (c *GrpcClient) Start() {
|
|
c.once.Do(func() {
|
|
// initialize reconnect channel
|
|
c.reconnect = make(chan struct{})
|
|
|
|
// start state monitor
|
|
go c.monitorState()
|
|
|
|
// connect (this will also register services)
|
|
err := c.connect()
|
|
if err != nil {
|
|
c.Fatalf("failed to connect: %v", err)
|
|
return
|
|
}
|
|
})
|
|
}
|
|
|
|
func (c *GrpcClient) Stop() (err error) {
|
|
// set stopped flag
|
|
c.stopped = true
|
|
c.setRegistered(false)
|
|
c.stop <- struct{}{}
|
|
c.Infof("stopped")
|
|
|
|
// skip if connection is nil
|
|
if c.conn == nil {
|
|
return nil
|
|
}
|
|
|
|
// close connection
|
|
if err := c.conn.Close(); err != nil {
|
|
c.Errorf("failed to close connection: %v", err)
|
|
return err
|
|
}
|
|
c.Infof("disconnected from %s", c.address)
|
|
|
|
return nil
|
|
}
|
|
|
|
func (c *GrpcClient) WaitForReady() {
|
|
ticker := time.NewTicker(1 * time.Second)
|
|
defer ticker.Stop()
|
|
for {
|
|
select {
|
|
case <-ticker.C:
|
|
if c.IsReady() {
|
|
c.Debugf("client is now ready")
|
|
return
|
|
}
|
|
case <-c.stop:
|
|
c.Errorf("client has stopped")
|
|
}
|
|
}
|
|
}
|
|
|
|
func (c *GrpcClient) register() {
|
|
c.nodeClient = grpc2.NewNodeServiceClient(c.conn)
|
|
c.modelBaseServiceClient = grpc2.NewModelBaseServiceClient(c.conn)
|
|
c.taskClient = grpc2.NewTaskServiceClient(c.conn)
|
|
c.dependencyClient = grpc2.NewDependencyServiceClient(c.conn)
|
|
c.metricClient = grpc2.NewMetricServiceClient(c.conn)
|
|
|
|
// Mark as registered
|
|
c.setRegistered(true)
|
|
}
|
|
|
|
func (c *GrpcClient) Context() (ctx context.Context, cancel context.CancelFunc) {
|
|
return context.WithTimeout(context.Background(), c.timeout)
|
|
}
|
|
|
|
func (c *GrpcClient) IsReady() (res bool) {
|
|
state := c.conn.GetState()
|
|
return c.conn != nil && state == connectivity.Ready
|
|
}
|
|
|
|
func (c *GrpcClient) IsReadyAndRegistered() (res bool) {
|
|
return c.IsReady() && c.IsRegistered()
|
|
}
|
|
|
|
func (c *GrpcClient) IsClosed() (res bool) {
|
|
if c.conn != nil {
|
|
return c.conn.GetState() == connectivity.Shutdown
|
|
}
|
|
return false
|
|
}
|
|
|
|
// monitorState polls the connection's connectivity state roughly once per
// second until a value is received from c.stop.
//
// On every observed state change it stores the new state and logs the transition. Then:
//   - TRANSIENT_FAILURE / SHUTDOWN: requests a reconnection on c.reconnect,
//     unless the circuit breaker is open.
//   - READY -> IDLE: starts a 30s grace timer.
//   - READY: clears the grace timer and records a circuit-breaker success.
//
// If the connection stays IDLE past the grace period, a reconnection is
// requested once and the timer is cleared.
//
// Reconnection requests are non-blocking sends on the unbuffered c.reconnect
// channel. A request is dropped if the reconnection loop is not currently
// waiting to receive, e.g. while it is busy in doConnect.
//
// NOTE(review): c.conn is read here without synchronisation while doConnect may
// replace it concurrently — potential data race, confirm with -race.
func (c *GrpcClient) monitorState() {
	idleStartTime := time.Time{}
	idleGracePeriod := 30 * time.Second // Allow IDLE state for 30 seconds before considering it a problem

	for {
		select {
		case <-c.stop:
			return
		default:
			// No connection has been created yet; check again shortly.
			if c.conn == nil {
				time.Sleep(time.Second)
				continue
			}

			previous := c.getState()
			current := c.conn.GetState()

			if previous != current {
				c.setState(current)
				c.Infof("state changed from %s to %s", previous, current)

				// Handle state transitions more intelligently
				switch current {
				case connectivity.TransientFailure, connectivity.Shutdown:
					// Always reconnect on actual failures, but respect circuit breaker
					if !c.isCircuitBreakerOpen() {
						// Non-blocking: drop the request if the reconnection loop is busy.
						select {
						case c.reconnect <- struct{}{}:
							c.Infof("triggering reconnection due to state change to %s", current)
						default:
						}
					} else {
						c.Debugf("circuit breaker open, not triggering reconnection for state %s", current)
					}
				case connectivity.Idle:
					// Only a READY -> IDLE transition starts the grace timer.
					if previous == connectivity.Ready {
						// Start grace period timer for IDLE state
						idleStartTime = time.Now()
						c.Debugf("connection went IDLE, starting grace period")
					}
				case connectivity.Ready:
					// Reset idle timer when connection becomes ready
					idleStartTime = time.Time{}
					// Record successful connection
					c.recordSuccess()
				}
			}

			// Check if IDLE state has exceeded grace period
			if current == connectivity.Idle && !idleStartTime.IsZero() &&
				time.Since(idleStartTime) > idleGracePeriod && !c.isCircuitBreakerOpen() {
				c.Warnf("connection has been IDLE for %v, triggering reconnection", time.Since(idleStartTime))
				select {
				case c.reconnect <- struct{}{}:
					c.Infof("triggering reconnection due to prolonged IDLE state")
				default:
				}
				idleStartTime = time.Time{} // Reset timer to avoid repeated reconnections
			}

			// Polling interval between state checks.
			time.Sleep(time.Second)
		}
	}
}
|
|
|
|
func (c *GrpcClient) setState(state connectivity.State) {
|
|
c.stateMux.Lock()
|
|
defer c.stateMux.Unlock()
|
|
c.state = state
|
|
}
|
|
|
|
func (c *GrpcClient) getState() connectivity.State {
|
|
c.stateMux.RLock()
|
|
defer c.stateMux.RUnlock()
|
|
return c.state
|
|
}
|
|
|
|
func (c *GrpcClient) setRegistered(registered bool) {
|
|
c.registeredMux.Lock()
|
|
defer c.registeredMux.Unlock()
|
|
c.registered = registered
|
|
}
|
|
|
|
func (c *GrpcClient) IsRegistered() bool {
|
|
c.registeredMux.RLock()
|
|
defer c.registeredMux.RUnlock()
|
|
return c.registered
|
|
}
|
|
|
|
func (c *GrpcClient) WaitForRegistered() {
|
|
ticker := time.NewTicker(100 * time.Millisecond)
|
|
defer ticker.Stop()
|
|
for {
|
|
select {
|
|
case <-ticker.C:
|
|
if c.IsRegistered() {
|
|
c.Debugf("client is now registered")
|
|
return
|
|
}
|
|
case <-c.stop:
|
|
c.Errorf("client has stopped while waiting for registration")
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
// Safe client getters that ensure registration before returning clients
|
|
// These methods will wait for registration to complete or return an error if the client is stopped
|
|
|
|
func (c *GrpcClient) GetNodeClient() (grpc2.NodeServiceClient, error) {
|
|
if c.stopped {
|
|
return nil, fmt.Errorf("grpc client is stopped")
|
|
}
|
|
if !c.IsRegistered() {
|
|
c.Debugf("waiting for node client registration")
|
|
c.WaitForRegistered()
|
|
if c.stopped {
|
|
return nil, fmt.Errorf("grpc client stopped while waiting for registration")
|
|
}
|
|
}
|
|
return c.nodeClient, nil
|
|
}
|
|
|
|
func (c *GrpcClient) GetTaskClient() (grpc2.TaskServiceClient, error) {
|
|
if c.stopped {
|
|
return nil, fmt.Errorf("grpc client is stopped")
|
|
}
|
|
if !c.IsRegistered() {
|
|
c.Debugf("waiting for task client registration")
|
|
c.WaitForRegistered()
|
|
if c.stopped {
|
|
return nil, fmt.Errorf("grpc client stopped while waiting for registration")
|
|
}
|
|
}
|
|
return c.taskClient, nil
|
|
}
|
|
|
|
func (c *GrpcClient) GetModelBaseServiceClient() (grpc2.ModelBaseServiceClient, error) {
|
|
if c.stopped {
|
|
return nil, fmt.Errorf("grpc client is stopped")
|
|
}
|
|
if !c.IsRegistered() {
|
|
c.Debugf("waiting for model base service client registration")
|
|
c.WaitForRegistered()
|
|
if c.stopped {
|
|
return nil, fmt.Errorf("grpc client stopped while waiting for registration")
|
|
}
|
|
}
|
|
return c.modelBaseServiceClient, nil
|
|
}
|
|
|
|
func (c *GrpcClient) GetDependencyClient() (grpc2.DependencyServiceClient, error) {
|
|
if c.stopped {
|
|
return nil, fmt.Errorf("grpc client is stopped")
|
|
}
|
|
if !c.IsRegistered() {
|
|
c.Debugf("waiting for dependency client registration")
|
|
c.WaitForRegistered()
|
|
if c.stopped {
|
|
return nil, fmt.Errorf("grpc client stopped while waiting for registration")
|
|
}
|
|
}
|
|
return c.dependencyClient, nil
|
|
}
|
|
|
|
func (c *GrpcClient) GetMetricClient() (grpc2.MetricServiceClient, error) {
|
|
if c.stopped {
|
|
return nil, fmt.Errorf("grpc client is stopped")
|
|
}
|
|
if !c.IsRegistered() {
|
|
c.Debugf("waiting for metric client registration")
|
|
c.WaitForRegistered()
|
|
if c.stopped {
|
|
return nil, fmt.Errorf("grpc client stopped while waiting for registration")
|
|
}
|
|
}
|
|
return c.metricClient, nil
|
|
}
|
|
|
|
// Safe client getters with timeout - these methods will wait up to the specified timeout
|
|
// for registration to complete before returning an error
|
|
|
|
func (c *GrpcClient) GetNodeClientWithTimeout(timeout time.Duration) (grpc2.NodeServiceClient, error) {
|
|
if c.stopped {
|
|
return nil, fmt.Errorf("grpc client is stopped")
|
|
}
|
|
if !c.IsRegistered() {
|
|
if err := c.waitForRegisteredWithTimeout(timeout); err != nil {
|
|
return nil, fmt.Errorf("failed to get node client: %w", err)
|
|
}
|
|
}
|
|
return c.nodeClient, nil
|
|
}
|
|
|
|
func (c *GrpcClient) GetTaskClientWithTimeout(timeout time.Duration) (grpc2.TaskServiceClient, error) {
|
|
if c.stopped {
|
|
return nil, fmt.Errorf("grpc client is stopped")
|
|
}
|
|
if !c.IsRegistered() {
|
|
if err := c.waitForRegisteredWithTimeout(timeout); err != nil {
|
|
return nil, fmt.Errorf("failed to get task client: %w", err)
|
|
}
|
|
}
|
|
return c.taskClient, nil
|
|
}
|
|
|
|
func (c *GrpcClient) GetModelBaseServiceClientWithTimeout(timeout time.Duration) (grpc2.ModelBaseServiceClient, error) {
|
|
if c.stopped {
|
|
return nil, fmt.Errorf("grpc client is stopped")
|
|
}
|
|
if !c.IsRegistered() {
|
|
if err := c.waitForRegisteredWithTimeout(timeout); err != nil {
|
|
return nil, fmt.Errorf("failed to get model base service client: %w", err)
|
|
}
|
|
}
|
|
return c.modelBaseServiceClient, nil
|
|
}
|
|
|
|
func (c *GrpcClient) GetDependencyClientWithTimeout(timeout time.Duration) (grpc2.DependencyServiceClient, error) {
|
|
if c.stopped {
|
|
return nil, fmt.Errorf("grpc client is stopped")
|
|
}
|
|
if !c.IsRegistered() {
|
|
if err := c.waitForRegisteredWithTimeout(timeout); err != nil {
|
|
return nil, fmt.Errorf("failed to get dependency client: %w", err)
|
|
}
|
|
}
|
|
return c.dependencyClient, nil
|
|
}
|
|
|
|
func (c *GrpcClient) GetMetricClientWithTimeout(timeout time.Duration) (grpc2.MetricServiceClient, error) {
|
|
if c.stopped {
|
|
return nil, fmt.Errorf("grpc client is stopped")
|
|
}
|
|
if !c.IsRegistered() {
|
|
if err := c.waitForRegisteredWithTimeout(timeout); err != nil {
|
|
return nil, fmt.Errorf("failed to get metric client: %w", err)
|
|
}
|
|
}
|
|
return c.metricClient, nil
|
|
}
|
|
|
|
func (c *GrpcClient) waitForRegisteredWithTimeout(timeout time.Duration) error {
|
|
ticker := time.NewTicker(100 * time.Millisecond)
|
|
defer ticker.Stop()
|
|
timer := time.NewTimer(timeout)
|
|
defer timer.Stop()
|
|
|
|
for {
|
|
select {
|
|
case <-ticker.C:
|
|
if c.IsRegistered() {
|
|
c.Debugf("client is now registered")
|
|
return nil
|
|
}
|
|
case <-timer.C:
|
|
return fmt.Errorf("timeout waiting for client registration after %v", timeout)
|
|
case <-c.stop:
|
|
return fmt.Errorf("client has stopped while waiting for registration")
|
|
}
|
|
}
|
|
}
|
|
|
|
func (c *GrpcClient) connect() (err error) {
|
|
// Start reconnection loop with proper cleanup
|
|
go func() {
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
c.Errorf("reconnection loop panic: %v", r)
|
|
}
|
|
}()
|
|
|
|
for {
|
|
select {
|
|
case <-c.stop:
|
|
c.Debugf("reconnection loop stopping")
|
|
return
|
|
case <-c.reconnect:
|
|
// Check if we're already reconnecting to avoid multiple attempts
|
|
c.reconnectMux.Lock()
|
|
if c.reconnecting {
|
|
c.Debugf("reconnection already in progress, skipping")
|
|
c.reconnectMux.Unlock()
|
|
continue
|
|
}
|
|
c.reconnecting = true
|
|
c.reconnectMux.Unlock()
|
|
|
|
if !c.stopped && !c.isCircuitBreakerOpen() {
|
|
c.Infof("attempting to reconnect to %s", c.address)
|
|
if err := c.doConnect(); err != nil {
|
|
c.Errorf("reconnection failed: %v", err)
|
|
c.recordFailure()
|
|
// Add a brief delay before allowing next reconnection attempt
|
|
time.Sleep(2 * time.Second)
|
|
} else {
|
|
c.recordSuccess()
|
|
}
|
|
} else if c.isCircuitBreakerOpen() {
|
|
c.Debugf("circuit breaker is open, skipping reconnection attempt")
|
|
}
|
|
|
|
// Reset reconnecting flag
|
|
c.reconnectMux.Lock()
|
|
c.reconnecting = false
|
|
c.reconnectMux.Unlock()
|
|
}
|
|
}
|
|
}()
|
|
|
|
return c.doConnect()
|
|
}
|
|
|
|
// Circuit breaker methods
|
|
func (c *GrpcClient) isCircuitBreakerOpen() bool {
|
|
c.cbMux.RLock()
|
|
defer c.cbMux.RUnlock()
|
|
|
|
// Circuit breaker opens after 5 consecutive failures
|
|
if c.failureCount >= 5 {
|
|
// Auto-recover after 1 minute
|
|
if time.Since(c.lastFailure) > 1*time.Minute {
|
|
return false
|
|
}
|
|
return true
|
|
}
|
|
return false
|
|
}
|
|
|
|
func (c *GrpcClient) recordFailure() {
|
|
c.cbMux.Lock()
|
|
defer c.cbMux.Unlock()
|
|
c.failureCount++
|
|
c.lastFailure = time.Now()
|
|
if c.failureCount >= 5 {
|
|
c.Warnf("circuit breaker opened after %d consecutive failures", c.failureCount)
|
|
}
|
|
}
|
|
|
|
func (c *GrpcClient) recordSuccess() {
|
|
c.cbMux.Lock()
|
|
defer c.cbMux.Unlock()
|
|
if c.failureCount > 0 {
|
|
c.Infof("connection restored, resetting circuit breaker (was %d failures)", c.failureCount)
|
|
}
|
|
c.failureCount = 0
|
|
c.lastFailure = time.Time{}
|
|
}
|
|
|
|
func (c *GrpcClient) doConnect() (err error) {
|
|
op := func() error {
|
|
// Mark as not registered during connection attempt
|
|
c.setRegistered(false)
|
|
|
|
// Close existing connection if any
|
|
if c.conn != nil {
|
|
if err := c.conn.Close(); err != nil {
|
|
c.Debugf("failed to close existing connection: %v", err)
|
|
}
|
|
}
|
|
|
|
// connection options with better settings for stability
|
|
opts := []grpc.DialOption{
|
|
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
|
grpc.WithChainUnaryInterceptor(middlewares.GetGrpcClientAuthTokenUnaryChainInterceptor()),
|
|
grpc.WithChainStreamInterceptor(middlewares.GetGrpcClientAuthTokenStreamChainInterceptor()),
|
|
// Add keep-alive settings to maintain connection health
|
|
grpc.WithKeepaliveParams(keepalive.ClientParameters{
|
|
Time: 30 * time.Second, // Send ping every 30 seconds
|
|
Timeout: 5 * time.Second, // Wait 5 seconds for ping response
|
|
PermitWithoutStream: true, // Send pings even without active streams
|
|
}),
|
|
}
|
|
|
|
// create new client connection
|
|
c.conn, err = grpc.NewClient(c.address, opts...)
|
|
if err != nil {
|
|
c.Errorf("failed to create connection to %s: %v", c.address, err)
|
|
return err
|
|
}
|
|
|
|
// connect
|
|
c.Infof("connecting to %s", c.address)
|
|
c.conn.Connect()
|
|
|
|
// wait for connection to be ready with shorter timeout for faster failure detection
|
|
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
|
defer cancel()
|
|
|
|
// Wait for state to change from connecting
|
|
for c.conn.GetState() == connectivity.Connecting {
|
|
if !c.conn.WaitForStateChange(ctx, connectivity.Connecting) {
|
|
return fmt.Errorf("failed to connect to %s: connection timeout", c.address)
|
|
}
|
|
}
|
|
|
|
// Check final state
|
|
state := c.conn.GetState()
|
|
if state != connectivity.Ready {
|
|
return fmt.Errorf("failed to connect to %s: final state is %s", c.address, state)
|
|
}
|
|
|
|
// success
|
|
c.Infof("connected to %s", c.address)
|
|
|
|
// Register services after successful connection
|
|
c.register()
|
|
|
|
return nil
|
|
}
|
|
|
|
// Configure backoff with more reasonable settings
|
|
b := backoff.NewExponentialBackOff()
|
|
b.InitialInterval = 1 * time.Second // Start with shorter interval
|
|
b.MaxInterval = 30 * time.Second // Cap the max interval
|
|
b.MaxElapsedTime = 5 * time.Minute // Reduce max retry time
|
|
b.Multiplier = 1.5 // Gentler exponential growth
|
|
|
|
n := func(err error, duration time.Duration) {
|
|
c.Errorf("failed to connect to %s: %v, retrying in %s", c.address, err, duration)
|
|
}
|
|
|
|
err = backoff.RetryNotify(op, b, n)
|
|
if err != nil {
|
|
c.recordFailure()
|
|
return err
|
|
}
|
|
|
|
c.recordSuccess()
|
|
return nil
|
|
}
|
|
|
|
func newGrpcClient() (c *GrpcClient) {
|
|
client := &GrpcClient{
|
|
address: utils.GetGrpcAddress(),
|
|
timeout: 10 * time.Second,
|
|
stop: make(chan struct{}),
|
|
Logger: utils.NewLogger("GrpcClient"),
|
|
state: connectivity.Idle,
|
|
}
|
|
|
|
return client
|
|
}
|
|
|
|
var _client *GrpcClient
|
|
var _clientOnce sync.Once
|
|
|
|
func GetGrpcClient() *GrpcClient {
|
|
_clientOnce.Do(func() {
|
|
_client = newGrpcClient()
|
|
go _client.Start()
|
|
})
|
|
return _client
|
|
}
|