fix(grpc/client): trigger reconnection on bad conn state and improve connection logging

- Trigger reconnection proactively from Get*WithTimeout when underlying connection is in
  SHUTDOWN or TRANSIENT_FAILURE to avoid returning stale/unusable clients.
- Add debug/info logs around client registration, connection attempts, closing existing
  connections, connection initiation, reconnection start, backoff retry and successful
  reconnection (including current state and registration status).
- Surface more context in reconnection and connection logs to aid diagnostics.
This commit is contained in:
Marvin Zhang
2025-10-20 11:34:41 +08:00
parent 6020fef30b
commit 4baa5fad59

View File

@@ -212,6 +212,7 @@ func (c *GrpcClient) WaitForReady() {
}
func (c *GrpcClient) register() {
c.Debugf("registering gRPC service clients")
c.nodeClient = grpc2.NewNodeServiceClient(c.conn)
c.modelBaseServiceClient = grpc2.NewModelBaseServiceClient(c.conn)
c.taskClient = grpc2.NewTaskServiceClient(c.conn)
@@ -224,6 +225,7 @@ func (c *GrpcClient) register() {
// Mark as registered
c.setRegistered(true)
c.Infof("gRPC service clients successfully registered")
}
func (c *GrpcClient) Context() (ctx context.Context, cancel context.CancelFunc) {
@@ -503,6 +505,11 @@ func (c *GrpcClient) GetNodeClientWithTimeout(timeout time.Duration) (grpc2.Node
if c.stopped {
return nil, fmt.Errorf("grpc client is stopped")
}
// Check if connection is in bad state and needs reconnection
if c.conn != nil && (c.conn.GetState() == connectivity.Shutdown || c.conn.GetState() == connectivity.TransientFailure) {
c.Debugf("connection in bad state (%s), triggering reconnection", c.conn.GetState())
c.triggerReconnection(fmt.Sprintf("bad connection state: %s", c.conn.GetState()))
}
if !c.IsRegistered() {
if err := c.waitForRegisteredWithTimeout(timeout); err != nil {
return nil, fmt.Errorf("failed to get node client: %w", err)
@@ -515,6 +522,11 @@ func (c *GrpcClient) GetTaskClientWithTimeout(timeout time.Duration) (grpc2.Task
if c.stopped {
return nil, fmt.Errorf("grpc client is stopped")
}
// Check if connection is in bad state and needs reconnection
if c.conn != nil && (c.conn.GetState() == connectivity.Shutdown || c.conn.GetState() == connectivity.TransientFailure) {
c.Debugf("connection in bad state (%s), triggering reconnection", c.conn.GetState())
c.triggerReconnection(fmt.Sprintf("bad connection state: %s", c.conn.GetState()))
}
if !c.IsRegistered() {
if err := c.waitForRegisteredWithTimeout(timeout); err != nil {
return nil, fmt.Errorf("failed to get task client: %w", err)
@@ -527,6 +539,11 @@ func (c *GrpcClient) GetModelBaseServiceClientWithTimeout(timeout time.Duration)
if c.stopped {
return nil, fmt.Errorf("grpc client is stopped")
}
// Check if connection is in bad state and needs reconnection
if c.conn != nil && (c.conn.GetState() == connectivity.Shutdown || c.conn.GetState() == connectivity.TransientFailure) {
c.Debugf("connection in bad state (%s), triggering reconnection", c.conn.GetState())
c.triggerReconnection(fmt.Sprintf("bad connection state: %s", c.conn.GetState()))
}
if !c.IsRegistered() {
if err := c.waitForRegisteredWithTimeout(timeout); err != nil {
return nil, fmt.Errorf("failed to get model base service client: %w", err)
@@ -539,6 +556,11 @@ func (c *GrpcClient) GetDependencyClientWithTimeout(timeout time.Duration) (grpc
if c.stopped {
return nil, fmt.Errorf("grpc client is stopped")
}
// Check if connection is in bad state and needs reconnection
if c.conn != nil && (c.conn.GetState() == connectivity.Shutdown || c.conn.GetState() == connectivity.TransientFailure) {
c.Debugf("connection in bad state (%s), triggering reconnection", c.conn.GetState())
c.triggerReconnection(fmt.Sprintf("bad connection state: %s", c.conn.GetState()))
}
if !c.IsRegistered() {
if err := c.waitForRegisteredWithTimeout(timeout); err != nil {
return nil, fmt.Errorf("failed to get dependency client: %w", err)
@@ -551,6 +573,11 @@ func (c *GrpcClient) GetMetricClientWithTimeout(timeout time.Duration) (grpc2.Me
if c.stopped {
return nil, fmt.Errorf("grpc client is stopped")
}
// Check if connection is in bad state and needs reconnection
if c.conn != nil && (c.conn.GetState() == connectivity.Shutdown || c.conn.GetState() == connectivity.TransientFailure) {
c.Debugf("connection in bad state (%s), triggering reconnection", c.conn.GetState())
c.triggerReconnection(fmt.Sprintf("bad connection state: %s", c.conn.GetState()))
}
if !c.IsRegistered() {
if err := c.waitForRegisteredWithTimeout(timeout); err != nil {
return nil, fmt.Errorf("failed to get metric client: %w", err)
@@ -706,7 +733,7 @@ func (c *GrpcClient) executeReconnection() {
c.reconnectMux.Unlock()
}()
c.Infof("executing reconnection to %s", c.address)
c.Infof("executing reconnection to %s (current state: %s)", c.address, c.getState())
if err := c.doConnect(); err != nil {
c.Errorf("reconnection failed: %v", err)
@@ -714,10 +741,11 @@ func (c *GrpcClient) executeReconnection() {
// Exponential backoff before allowing next attempt
backoffDuration := c.calculateBackoff()
c.Warnf("will retry reconnection after %v backoff", backoffDuration)
time.Sleep(backoffDuration)
} else {
c.recordSuccess()
c.Infof("reconnection successful")
c.Infof("reconnection successful - connection state: %s, registered: %v", c.getState(), c.IsRegistered())
}
}
@@ -795,10 +823,12 @@ func (c *GrpcClient) calculateBackoff() time.Duration {
}
func (c *GrpcClient) doConnect() error {
c.Debugf("attempting connection to %s", c.address)
c.setRegistered(false)
// Close existing connection
if c.conn != nil {
c.Debugf("closing existing connection (state: %s)", c.conn.GetState())
c.conn.Close()
c.conn = nil
}
@@ -817,14 +847,16 @@ func (c *GrpcClient) doConnect() error {
ctx, cancel := context.WithTimeout(context.Background(), connectionTimeout)
defer cancel()
c.Debugf("initiating connection to %s", c.address)
c.conn.Connect()
if err := c.waitForConnectionReady(ctx); err != nil {
c.Errorf("failed to reach ready state: %v", err)
c.conn.Close()
c.conn = nil
return err
}
c.Infof("connected to %s", c.address)
c.Infof("connected to %s (state: %s)", c.address, c.conn.GetState())
c.register()
return nil