From 00daa0ed96993ad8275ee5066d1010dcecf0bdc5 Mon Sep 17 00:00:00 2001
From: Marvin Zhang
Date: Tue, 8 Jul 2025 13:39:39 +0800
Subject: [PATCH] fix: enhance gRPC client reconnection logic and add goroutine monitoring for potential leaks

---
 core/grpc/client/client.go   | 57 ++++++++++++++++---------
 core/task/handler/service.go | 81 +++++++++++++++++++++++++++++-------
 2 files changed, 103 insertions(+), 35 deletions(-)

diff --git a/core/grpc/client/client.go b/core/grpc/client/client.go
index 73458cef..98442811 100644
--- a/core/grpc/client/client.go
+++ b/core/grpc/client/client.go
@@ -41,15 +41,19 @@ type GrpcClient struct {
 	MetricClient grpc2.MetricServiceClient
 
 	// Add new fields for state management
-	state     connectivity.State
-	stateMux  sync.RWMutex
-	reconnect chan struct{}
-
+	state     connectivity.State
+	stateMux  sync.RWMutex
+	reconnect chan struct{}
+
 	// Circuit breaker fields
 	failureCount   int
 	lastFailure    time.Time
 	circuitBreaker bool
 	cbMux          sync.RWMutex
+
+	// Reconnection control
+	reconnecting bool
+	reconnectMux sync.Mutex
 }
 
 func (c *GrpcClient) Start() {
@@ -133,7 +137,7 @@ func (c *GrpcClient) IsClosed() (res bool) {
 func (c *GrpcClient) monitorState() {
 	idleStartTime := time.Time{}
 	idleGracePeriod := 30 * time.Second // Allow IDLE state for 30 seconds before considering it a problem
-
+
 	for {
 		select {
 		case <-c.stop:
@@ -179,8 +183,8 @@
 		}
 
 		// Check if IDLE state has exceeded grace period
-		if current == connectivity.Idle && !idleStartTime.IsZero() &&
-			time.Since(idleStartTime) > idleGracePeriod && !c.isCircuitBreakerOpen() {
+		if current == connectivity.Idle && !idleStartTime.IsZero() &&
+			time.Since(idleStartTime) > idleGracePeriod && !c.isCircuitBreakerOpen() {
 			c.Warnf("connection has been IDLE for %v, triggering reconnection", time.Since(idleStartTime))
 			select {
 			case c.reconnect <- struct{}{}:
@@ -215,13 +219,23 @@ func (c *GrpcClient) connect() (err error) {
 				c.Errorf("reconnection loop panic: %v", r)
 			}
 		}()
-
+
 		for {
 			select {
 			case <-c.stop:
 				c.Debugf("reconnection loop stopping")
 				return
 			case <-c.reconnect:
+				// Check if we're already reconnecting to avoid multiple attempts
+				c.reconnectMux.Lock()
+				if c.reconnecting {
+					c.Debugf("reconnection already in progress, skipping")
+					c.reconnectMux.Unlock()
+					continue
+				}
+				c.reconnecting = true
+				c.reconnectMux.Unlock()
+
 				if !c.stopped && !c.isCircuitBreakerOpen() {
 					c.Infof("attempting to reconnect to %s", c.address)
 					if err := c.doConnect(); err != nil {
@@ -235,6 +249,11 @@ func (c *GrpcClient) connect() (err error) {
 				} else if c.isCircuitBreakerOpen() {
 					c.Debugf("circuit breaker is open, skipping reconnection attempt")
 				}
+
+				// Reset reconnecting flag
+				c.reconnectMux.Lock()
+				c.reconnecting = false
+				c.reconnectMux.Unlock()
 			}
 		}
 	}()
@@ -246,7 +265,7 @@
 func (c *GrpcClient) isCircuitBreakerOpen() bool {
 	c.cbMux.RLock()
 	defer c.cbMux.RUnlock()
-
+
 	// Circuit breaker opens after 5 consecutive failures
 	if c.failureCount >= 5 {
 		// Auto-recover after 1 minute
@@ -314,7 +333,7 @@ func (c *GrpcClient) doConnect() (err error) {
 		// wait for connection to be ready with shorter timeout for faster failure detection
 		ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
 		defer cancel()
-
+
 		// Wait for state to change from connecting
 		for c.conn.GetState() == connectivity.Connecting {
 			if !c.conn.WaitForStateChange(ctx, connectivity.Connecting) {
@@ -330,30 +349,30 @@ func (c *GrpcClient) doConnect() (err error) {
 
 		// success
 		c.Infof("connected to %s", c.address)
%s", c.address) - + // Register services after successful connection c.register() return nil } - + // Configure backoff with more reasonable settings b := backoff.NewExponentialBackOff() - b.InitialInterval = 1 * time.Second // Start with shorter interval - b.MaxInterval = 30 * time.Second // Cap the max interval - b.MaxElapsedTime = 5 * time.Minute // Reduce max retry time - b.Multiplier = 1.5 // Gentler exponential growth - + b.InitialInterval = 1 * time.Second // Start with shorter interval + b.MaxInterval = 30 * time.Second // Cap the max interval + b.MaxElapsedTime = 5 * time.Minute // Reduce max retry time + b.Multiplier = 1.5 // Gentler exponential growth + n := func(err error, duration time.Duration) { c.Errorf("failed to connect to %s: %v, retrying in %s", c.address, err, duration) } - + err = backoff.RetryNotify(op, b, n) if err != nil { c.recordFailure() return err } - + c.recordSuccess() return nil } diff --git a/core/task/handler/service.go b/core/task/handler/service.go index 459d7710..03f085ab 100644 --- a/core/task/handler/service.go +++ b/core/task/handler/service.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "io" + "runtime" "sync" "time" @@ -58,6 +59,9 @@ func (svc *Service) Start() { svc.fetchTicker = time.NewTicker(svc.fetchInterval) svc.reportTicker = time.NewTicker(svc.reportInterval) + // Start goroutine monitoring + svc.startGoroutineMonitoring() + // Start background goroutines with WaitGroup tracking svc.wg.Add(2) go svc.reportStatus() @@ -114,6 +118,38 @@ func (svc *Service) Stop() { svc.Infof("Task handler service stopped") } +func (svc *Service) startGoroutineMonitoring() { + go func() { + defer func() { + if r := recover(); r != nil { + svc.Errorf("[TaskHandler] goroutine monitoring panic: %v", r) + } + }() + + ticker := time.NewTicker(5 * time.Minute) + defer ticker.Stop() + + initialCount := runtime.NumGoroutine() + svc.Infof("[TaskHandler] initial goroutine count: %d", initialCount) + + for { + select { + case <-svc.ctx.Done(): + svc.Infof("[TaskHandler] goroutine monitoring shutting down") + return + case <-ticker.C: + currentCount := runtime.NumGoroutine() + if currentCount > initialCount+50 { // Alert if 50+ more goroutines than initial + svc.Warnf("[TaskHandler] potential goroutine leak detected - current: %d, initial: %d, diff: %d", + currentCount, initialCount, currentCount-initialCount) + } else { + svc.Debugf("[TaskHandler] goroutine count: %d (initial: %d)", currentCount, initialCount) + } + } + } + }() +} + func (svc *Service) Run(taskId primitive.ObjectID) (err error) { return svc.runTask(taskId) } @@ -526,31 +562,43 @@ func (svc *Service) handleStreamMessages(taskId primitive.ObjectID, stream grpc. 
 		// Set deadline for receive operation
 		ctx, cancel := context.WithTimeout(context.Background(), streamTimeout)
 
-		// Use a goroutine to handle the blocking Recv call
-		msgCh := make(chan *grpc.TaskServiceSubscribeResponse, 1)
-		errCh := make(chan error, 1)
+		// Create a buffered channel to receive the result
+		result := make(chan struct {
+			msg *grpc.TaskServiceSubscribeResponse
+			err error
+		}, 1)
 
+		// Use a single goroutine to handle the blocking Recv call
 		go func() {
+			defer func() {
+				if r := recover(); r != nil {
+					svc.Errorf("task[%s] stream recv goroutine panic: %v", taskId.Hex(), r)
+				}
+			}()
+
 			msg, err := stream.Recv()
-			if err != nil {
-				errCh <- err
-			} else {
-				msgCh <- msg
+			select {
+			case result <- struct {
+				msg *grpc.TaskServiceSubscribeResponse
+				err error
+			}{msg, err}:
+			case <-ctx.Done():
+				// Context cancelled, don't send result
 			}
 		}()
 
 		select {
-		case msg := <-msgCh:
+		case res := <-result:
 			cancel()
-			svc.processStreamMessage(taskId, msg)
-		case err := <-errCh:
-			cancel()
-			if errors.Is(err, io.EOF) {
-				svc.Infof("task[%s] received EOF, stream closed", taskId.Hex())
+			if res.err != nil {
+				if errors.Is(res.err, io.EOF) {
+					svc.Infof("task[%s] received EOF, stream closed", taskId.Hex())
+					return
+				}
+				svc.Errorf("task[%s] stream error: %v", taskId.Hex(), res.err)
 				return
 			}
-			svc.Errorf("task[%s] stream error: %v", taskId.Hex(), err)
-			return
+			svc.processStreamMessage(taskId, res.msg)
 		case <-ctx.Done():
 			cancel()
 			svc.Warnf("task[%s] stream receive timeout", taskId.Hex())
@@ -570,7 +618,8 @@ func (svc *Service) processStreamMessage(taskId primitive.ObjectID, msg *grpc.Ta
 	switch msg.Code {
 	case grpc.TaskServiceSubscribeCode_CANCEL:
 		svc.Infof("task[%s] received cancel signal", taskId.Hex())
-		go svc.handleCancel(msg, taskId)
+		// Handle cancel synchronously to avoid goroutine accumulation
+		svc.handleCancel(msg, taskId)
 	default:
 		svc.Debugf("task[%s] received unknown stream message code: %v", taskId.Hex(), msg.Code)
 	}
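
Note on the circuit-breaker logic above: the patch keeps the breaker state (failureCount, lastFailure) on GrpcClient behind cbMux, opens the breaker after 5 consecutive failures, and auto-recovers one minute after the last failure. Below is a minimal standalone sketch of that pattern -- illustrative names, not the actual GrpcClient code -- showing how the open/record operations fit together:

// Standalone sketch of the circuit-breaker pattern used in this patch.
// The breaker type and method names are illustrative; thresholds mirror
// the patch: open after 5 consecutive failures, recover after 1 minute.
package main

import (
	"fmt"
	"sync"
	"time"
)

type breaker struct {
	mu           sync.RWMutex
	failureCount int
	lastFailure  time.Time
}

// isOpen reports whether connection attempts should be skipped.
func (b *breaker) isOpen() bool {
	b.mu.RLock()
	defer b.mu.RUnlock()
	if b.failureCount >= 5 {
		// Still within the recovery window: stay open. Once a minute
		// has passed since the last failure, allow a fresh attempt.
		return time.Since(b.lastFailure) < time.Minute
	}
	return false
}

// recordFailure counts a failed attempt and stamps the failure time.
func (b *breaker) recordFailure() {
	b.mu.Lock()
	defer b.mu.Unlock()
	b.failureCount++
	b.lastFailure = time.Now()
}

// recordSuccess resets the breaker after a successful connection.
func (b *breaker) recordSuccess() {
	b.mu.Lock()
	defer b.mu.Unlock()
	b.failureCount = 0
}

func main() {
	var b breaker
	for i := 0; i < 5; i++ {
		b.recordFailure()
	}
	fmt.Println(b.isOpen()) // true: 5 consecutive failures within the window
	b.recordSuccess()
	fmt.Println(b.isOpen()) // false: one success closes the breaker
}

Because recordSuccess fully resets the count, a single successful dial closes the breaker immediately, which matches how doConnect calls c.recordSuccess() only after backoff.RetryNotify returns without error.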