diff --git a/core/grpc/client/client.go b/core/grpc/client/client.go index 3cbad922..acc70f6a 100644 --- a/core/grpc/client/client.go +++ b/core/grpc/client/client.go @@ -212,6 +212,7 @@ func (c *GrpcClient) WaitForReady() { } func (c *GrpcClient) register() { + c.Debugf("registering gRPC service clients") c.nodeClient = grpc2.NewNodeServiceClient(c.conn) c.modelBaseServiceClient = grpc2.NewModelBaseServiceClient(c.conn) c.taskClient = grpc2.NewTaskServiceClient(c.conn) @@ -224,6 +225,7 @@ func (c *GrpcClient) register() { // Mark as registered c.setRegistered(true) + c.Infof("gRPC service clients successfully registered") } func (c *GrpcClient) Context() (ctx context.Context, cancel context.CancelFunc) { @@ -503,6 +505,11 @@ func (c *GrpcClient) GetNodeClientWithTimeout(timeout time.Duration) (grpc2.Node if c.stopped { return nil, fmt.Errorf("grpc client is stopped") } + // Check if connection is in bad state and needs reconnection + if c.conn != nil && (c.conn.GetState() == connectivity.Shutdown || c.conn.GetState() == connectivity.TransientFailure) { + c.Debugf("connection in bad state (%s), triggering reconnection", c.conn.GetState()) + c.triggerReconnection(fmt.Sprintf("bad connection state: %s", c.conn.GetState())) + } if !c.IsRegistered() { if err := c.waitForRegisteredWithTimeout(timeout); err != nil { return nil, fmt.Errorf("failed to get node client: %w", err) @@ -515,6 +522,11 @@ func (c *GrpcClient) GetTaskClientWithTimeout(timeout time.Duration) (grpc2.Task if c.stopped { return nil, fmt.Errorf("grpc client is stopped") } + // Check if connection is in bad state and needs reconnection + if c.conn != nil && (c.conn.GetState() == connectivity.Shutdown || c.conn.GetState() == connectivity.TransientFailure) { + c.Debugf("connection in bad state (%s), triggering reconnection", c.conn.GetState()) + c.triggerReconnection(fmt.Sprintf("bad connection state: %s", c.conn.GetState())) + } if !c.IsRegistered() { if err := c.waitForRegisteredWithTimeout(timeout); err != nil { return nil, fmt.Errorf("failed to get task client: %w", err) @@ -527,6 +539,11 @@ func (c *GrpcClient) GetModelBaseServiceClientWithTimeout(timeout time.Duration) if c.stopped { return nil, fmt.Errorf("grpc client is stopped") } + // Check if connection is in bad state and needs reconnection + if c.conn != nil && (c.conn.GetState() == connectivity.Shutdown || c.conn.GetState() == connectivity.TransientFailure) { + c.Debugf("connection in bad state (%s), triggering reconnection", c.conn.GetState()) + c.triggerReconnection(fmt.Sprintf("bad connection state: %s", c.conn.GetState())) + } if !c.IsRegistered() { if err := c.waitForRegisteredWithTimeout(timeout); err != nil { return nil, fmt.Errorf("failed to get model base service client: %w", err) @@ -539,6 +556,11 @@ func (c *GrpcClient) GetDependencyClientWithTimeout(timeout time.Duration) (grpc if c.stopped { return nil, fmt.Errorf("grpc client is stopped") } + // Check if connection is in bad state and needs reconnection + if c.conn != nil && (c.conn.GetState() == connectivity.Shutdown || c.conn.GetState() == connectivity.TransientFailure) { + c.Debugf("connection in bad state (%s), triggering reconnection", c.conn.GetState()) + c.triggerReconnection(fmt.Sprintf("bad connection state: %s", c.conn.GetState())) + } if !c.IsRegistered() { if err := c.waitForRegisteredWithTimeout(timeout); err != nil { return nil, fmt.Errorf("failed to get dependency client: %w", err) @@ -551,6 +573,11 @@ func (c *GrpcClient) GetMetricClientWithTimeout(timeout time.Duration) (grpc2.Me if c.stopped { return nil, fmt.Errorf("grpc client is stopped") } + // Check if connection is in bad state and needs reconnection + if c.conn != nil && (c.conn.GetState() == connectivity.Shutdown || c.conn.GetState() == connectivity.TransientFailure) { + c.Debugf("connection in bad state (%s), triggering reconnection", c.conn.GetState()) + c.triggerReconnection(fmt.Sprintf("bad connection state: %s", c.conn.GetState())) + } if !c.IsRegistered() { if err := c.waitForRegisteredWithTimeout(timeout); err != nil { return nil, fmt.Errorf("failed to get metric client: %w", err) @@ -706,7 +733,7 @@ func (c *GrpcClient) executeReconnection() { c.reconnectMux.Unlock() }() - c.Infof("executing reconnection to %s", c.address) + c.Infof("executing reconnection to %s (current state: %s)", c.address, c.getState()) if err := c.doConnect(); err != nil { c.Errorf("reconnection failed: %v", err) @@ -714,10 +741,11 @@ func (c *GrpcClient) executeReconnection() { // Exponential backoff before allowing next attempt backoffDuration := c.calculateBackoff() + c.Warnf("will retry reconnection after %v backoff", backoffDuration) time.Sleep(backoffDuration) } else { c.recordSuccess() - c.Infof("reconnection successful") + c.Infof("reconnection successful - connection state: %s, registered: %v", c.getState(), c.IsRegistered()) } } @@ -795,10 +823,12 @@ func (c *GrpcClient) calculateBackoff() time.Duration { } func (c *GrpcClient) doConnect() error { + c.Debugf("attempting connection to %s", c.address) c.setRegistered(false) // Close existing connection if c.conn != nil { + c.Debugf("closing existing connection (state: %s)", c.conn.GetState()) c.conn.Close() c.conn = nil } @@ -817,14 +847,16 @@ func (c *GrpcClient) doConnect() error { ctx, cancel := context.WithTimeout(context.Background(), connectionTimeout) defer cancel() + c.Debugf("initiating connection to %s", c.address) c.conn.Connect() if err := c.waitForConnectionReady(ctx); err != nil { + c.Errorf("failed to reach ready state: %v", err) c.conn.Close() c.conn = nil return err } - c.Infof("connected to %s", c.address) + c.Infof("connected to %s (state: %s)", c.address, c.conn.GetState()) c.register() return nil