fix: enhance gRPC client reconnection logic and add goroutine monitoring for potential leaks

Marvin Zhang
2025-07-08 13:39:39 +08:00
parent f8e9c45a85
commit 00daa0ed96
2 changed files with 103 additions and 35 deletions

View File

@@ -41,15 +41,19 @@ type GrpcClient struct {
MetricClient grpc2.MetricServiceClient
// Add new fields for state management
state connectivity.State
stateMux sync.RWMutex
reconnect chan struct{}
// Circuit breaker fields
failureCount int
lastFailure time.Time
circuitBreaker bool
cbMux sync.RWMutex
// Reconnection control
reconnecting bool
reconnectMux sync.Mutex
}
func (c *GrpcClient) Start() {
@@ -133,7 +137,7 @@ func (c *GrpcClient) IsClosed() (res bool) {
func (c *GrpcClient) monitorState() {
idleStartTime := time.Time{}
idleGracePeriod := 30 * time.Second // Allow IDLE state for 30 seconds before considering it a problem
for {
select {
case <-c.stop:
@@ -179,8 +183,8 @@ func (c *GrpcClient) monitorState() {
}
// Check if IDLE state has exceeded grace period
if current == connectivity.Idle && !idleStartTime.IsZero() &&
time.Since(idleStartTime) > idleGracePeriod && !c.isCircuitBreakerOpen() {
c.Warnf("connection has been IDLE for %v, triggering reconnection", time.Since(idleStartTime))
select {
case c.reconnect <- struct{}{}:
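The hunk is truncated at this point. Presumably the send on c.reconnect is non-blocking so the monitor loop never stalls behind a pending signal; a minimal sketch of the likely shape (the default branch is an assumption, not code shown in this commit):

select {
case c.reconnect <- struct{}{}:
default:
    // A reconnection signal is already pending; don't block the monitor loop.
}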
@@ -215,13 +219,23 @@ func (c *GrpcClient) connect() (err error) {
c.Errorf("reconnection loop panic: %v", r)
}
}()
for {
select {
case <-c.stop:
c.Debugf("reconnection loop stopping")
return
case <-c.reconnect:
// Check if we're already reconnecting to avoid multiple attempts
c.reconnectMux.Lock()
if c.reconnecting {
c.Debugf("reconnection already in progress, skipping")
c.reconnectMux.Unlock()
continue
}
c.reconnecting = true
c.reconnectMux.Unlock()
if !c.stopped && !c.isCircuitBreakerOpen() {
c.Infof("attempting to reconnect to %s", c.address)
if err := c.doConnect(); err != nil {
@@ -235,6 +249,11 @@ func (c *GrpcClient) connect() (err error) {
} else if c.isCircuitBreakerOpen() {
c.Debugf("circuit breaker is open, skipping reconnection attempt")
}
// Reset reconnecting flag
c.reconnectMux.Lock()
c.reconnecting = false
c.reconnectMux.Unlock()
}
}
}()
@@ -246,7 +265,7 @@ func (c *GrpcClient) connect() (err error) {
func (c *GrpcClient) isCircuitBreakerOpen() bool {
c.cbMux.RLock()
defer c.cbMux.RUnlock()
// Circuit breaker opens after 5 consecutive failures
if c.failureCount >= 5 {
// Auto-recover after 1 minute
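The hunk cuts off right after the auto-recovery comment. For context, a minimal sketch of how the complete check might read, assuming the breaker simply reports itself closed once the last failure is more than a minute old (the recovery branch is an assumption, not code shown in this commit):

func (c *GrpcClient) isCircuitBreakerOpen() bool {
    c.cbMux.RLock()
    defer c.cbMux.RUnlock()
    // Circuit breaker opens after 5 consecutive failures.
    if c.failureCount >= 5 {
        // Auto-recover after 1 minute: once the last failure is old enough,
        // report the breaker as closed and allow a new attempt.
        if time.Since(c.lastFailure) > time.Minute {
            return false
        }
        return true
    }
    return false
}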
@@ -314,7 +333,7 @@ func (c *GrpcClient) doConnect() (err error) {
// wait for connection to be ready with shorter timeout for faster failure detection
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
// Wait for state to change from connecting
for c.conn.GetState() == connectivity.Connecting {
if !c.conn.WaitForStateChange(ctx, connectivity.Connecting) {
@@ -330,30 +349,30 @@ func (c *GrpcClient) doConnect() (err error) {
// success
c.Infof("connected to %s", c.address)
// Register services after successful connection
c.register()
return nil
}
// Configure backoff with more reasonable settings
b := backoff.NewExponentialBackOff()
b.InitialInterval = 1 * time.Second // Start with shorter interval
b.MaxInterval = 30 * time.Second // Cap the max interval
b.MaxElapsedTime = 5 * time.Minute // Reduce max retry time
b.Multiplier = 1.5 // Gentler exponential growth
n := func(err error, duration time.Duration) {
c.Errorf("failed to connect to %s: %v, retrying in %s", c.address, err, duration)
}
err = backoff.RetryNotify(op, b, n)
if err != nil {
c.recordFailure()
return err
}
c.recordSuccess()
return nil
}
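doConnect reports outcomes through recordFailure and recordSuccess, which are not part of the visible hunks. A plausible sketch of those helpers, assuming they do nothing beyond maintaining the circuit-breaker fields declared in the struct (the bodies are assumptions; only the names appear in this diff):

func (c *GrpcClient) recordFailure() {
    c.cbMux.Lock()
    defer c.cbMux.Unlock()
    c.failureCount++
    c.lastFailure = time.Now()
}

func (c *GrpcClient) recordSuccess() {
    c.cbMux.Lock()
    defer c.cbMux.Unlock()
    // A successful connection resets the breaker.
    c.failureCount = 0
}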

View File

@@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"io"
"runtime"
"sync"
"time"
@@ -58,6 +59,9 @@ func (svc *Service) Start() {
svc.fetchTicker = time.NewTicker(svc.fetchInterval)
svc.reportTicker = time.NewTicker(svc.reportInterval)
// Start goroutine monitoring
svc.startGoroutineMonitoring()
// Start background goroutines with WaitGroup tracking
svc.wg.Add(2)
go svc.reportStatus()
@@ -114,6 +118,38 @@ func (svc *Service) Stop() {
svc.Infof("Task handler service stopped")
}
func (svc *Service) startGoroutineMonitoring() {
go func() {
defer func() {
if r := recover(); r != nil {
svc.Errorf("[TaskHandler] goroutine monitoring panic: %v", r)
}
}()
ticker := time.NewTicker(5 * time.Minute)
defer ticker.Stop()
initialCount := runtime.NumGoroutine()
svc.Infof("[TaskHandler] initial goroutine count: %d", initialCount)
for {
select {
case <-svc.ctx.Done():
svc.Infof("[TaskHandler] goroutine monitoring shutting down")
return
case <-ticker.C:
currentCount := runtime.NumGoroutine()
if currentCount > initialCount+50 { // Alert if 50+ more goroutines than initial
svc.Warnf("[TaskHandler] potential goroutine leak detected - current: %d, initial: %d, diff: %d",
currentCount, initialCount, currentCount-initialCount)
} else {
svc.Debugf("[TaskHandler] goroutine count: %d (initial: %d)", currentCount, initialCount)
}
}
}
}()
}
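Note that runtime.NumGoroutine() counts goroutines for the whole process, so the initial-plus-50 threshold flags growth anywhere in the agent, not only in the task handler, and the 5-minute ticker keeps the overhead negligible. If the warning fires, a stack dump is what makes it actionable; one possible follow-up (not part of this commit) using the standard runtime/pprof goroutine profile:

import (
    "os"
    "runtime/pprof"
)

// dumpGoroutines writes one entry per unique goroutine stack with a count
// (debug level 1), which is usually enough to spot the leaking call site.
func dumpGoroutines() {
    _ = pprof.Lookup("goroutine").WriteTo(os.Stderr, 1)
}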
func (svc *Service) Run(taskId primitive.ObjectID) (err error) {
return svc.runTask(taskId)
}
@@ -526,31 +562,43 @@ func (svc *Service) handleStreamMessages(taskId primitive.ObjectID, stream grpc.
// Set deadline for receive operation
ctx, cancel := context.WithTimeout(context.Background(), streamTimeout)
// Create a buffered channel to receive the result
result := make(chan struct {
msg *grpc.TaskServiceSubscribeResponse
err error
}, 1)
// Use a single goroutine to handle the blocking Recv call
go func() {
defer func() {
if r := recover(); r != nil {
svc.Errorf("task[%s] stream recv goroutine panic: %v", taskId.Hex(), r)
}
}()
msg, err := stream.Recv()
select {
case result <- struct {
msg *grpc.TaskServiceSubscribeResponse
err error
}{msg, err}:
case <-ctx.Done():
// Context cancelled, don't send result
}
}()
select {
case res := <-result:
cancel()
if res.err != nil {
if errors.Is(res.err, io.EOF) {
svc.Infof("task[%s] received EOF, stream closed", taskId.Hex())
return
}
svc.Errorf("task[%s] stream error: %v", taskId.Hex(), res.err)
return
}
svc.processStreamMessage(taskId, res.msg)
case <-ctx.Done():
cancel()
svc.Warnf("task[%s] stream receive timeout", taskId.Hex())
@@ -570,7 +618,8 @@ func (svc *Service) processStreamMessage(taskId primitive.ObjectID, msg *grpc.Ta
switch msg.Code {
case grpc.TaskServiceSubscribeCode_CANCEL:
svc.Infof("task[%s] received cancel signal", taskId.Hex())
go svc.handleCancel(msg, taskId)
// Handle cancel synchronously to avoid goroutine accumulation
svc.handleCancel(msg, taskId)
default:
svc.Debugf("task[%s] received unknown stream message code: %v", taskId.Hex(), msg.Code)
}