diff --git a/core/grpc/client/client.go b/core/grpc/client/client.go index 98442811..494473e3 100644 --- a/core/grpc/client/client.go +++ b/core/grpc/client/client.go @@ -17,28 +17,47 @@ import ( "google.golang.org/grpc/keepalive" ) -type GrpcClient struct { - // dependencies - nodeCfgSvc interfaces.NodeConfigService +// GrpcClient provides a robust gRPC client with connection management and client registration. +// +// The client handles connection lifecycle and ensures that gRPC service clients are properly +// initialized before use. All client fields are private and can only be accessed through +// safe getter methods that ensure registration before returning clients. +// +// Example usage: +// client := GetGrpcClient() +// +// // Safe access pattern - always use getter methods +// nodeClient, err := client.GetNodeClient() +// if err != nil { +// return fmt.Errorf("failed to get node client: %v", err) +// } +// resp, err := nodeClient.Register(ctx, req) +// +// // Alternative with timeout +// taskClient, err := client.GetTaskClientWithTimeout(5 * time.Second) +// if err != nil { +// return fmt.Errorf("failed to get task client: %v", err) +// } +// resp, err := taskClient.Connect(ctx) +type GrpcClient struct { // settings address string timeout time.Duration // internals conn *grpc.ClientConn - err error once sync.Once stopped bool stop chan struct{} interfaces.Logger - // clients - NodeClient grpc2.NodeServiceClient - TaskClient grpc2.TaskServiceClient - ModelBaseServiceClient grpc2.ModelBaseServiceClient - DependencyClient grpc2.DependencyServiceClient - MetricClient grpc2.MetricServiceClient + // clients (private to enforce safe access through getter methods) + nodeClient grpc2.NodeServiceClient + taskClient grpc2.TaskServiceClient + modelBaseServiceClient grpc2.ModelBaseServiceClient + dependencyClient grpc2.DependencyServiceClient + metricClient grpc2.MetricServiceClient // Add new fields for state management state connectivity.State @@ -46,14 +65,17 @@ type GrpcClient struct { reconnect chan struct{} // Circuit breaker fields - failureCount int - lastFailure time.Time - circuitBreaker bool - cbMux sync.RWMutex + failureCount int + lastFailure time.Time + cbMux sync.RWMutex // Reconnection control reconnecting bool reconnectMux sync.Mutex + + // Registration status + registered bool + registeredMux sync.RWMutex } func (c *GrpcClient) Start() { @@ -76,6 +98,7 @@ func (c *GrpcClient) Start() { func (c *GrpcClient) Stop() (err error) { // set stopped flag c.stopped = true + c.setRegistered(false) c.stop <- struct{}{} c.Infof("stopped") @@ -111,11 +134,14 @@ func (c *GrpcClient) WaitForReady() { } func (c *GrpcClient) register() { - c.NodeClient = grpc2.NewNodeServiceClient(c.conn) - c.ModelBaseServiceClient = grpc2.NewModelBaseServiceClient(c.conn) - c.TaskClient = grpc2.NewTaskServiceClient(c.conn) - c.DependencyClient = grpc2.NewDependencyServiceClient(c.conn) - c.MetricClient = grpc2.NewMetricServiceClient(c.conn) + c.nodeClient = grpc2.NewNodeServiceClient(c.conn) + c.modelBaseServiceClient = grpc2.NewModelBaseServiceClient(c.conn) + c.taskClient = grpc2.NewTaskServiceClient(c.conn) + c.dependencyClient = grpc2.NewDependencyServiceClient(c.conn) + c.metricClient = grpc2.NewMetricServiceClient(c.conn) + + // Mark as registered + c.setRegistered(true) } func (c *GrpcClient) Context() (ctx context.Context, cancel context.CancelFunc) { @@ -127,6 +153,10 @@ func (c *GrpcClient) IsReady() (res bool) { return c.conn != nil && state == connectivity.Ready } +func (c *GrpcClient) IsReadyAndRegistered() (res bool) { + return c.IsReady() && c.IsRegistered() +} + func (c *GrpcClient) IsClosed() (res bool) { if c.conn != nil { return c.conn.GetState() == connectivity.Shutdown @@ -211,6 +241,192 @@ func (c *GrpcClient) getState() connectivity.State { return c.state } +func (c *GrpcClient) setRegistered(registered bool) { + c.registeredMux.Lock() + defer c.registeredMux.Unlock() + c.registered = registered +} + +func (c *GrpcClient) IsRegistered() bool { + c.registeredMux.RLock() + defer c.registeredMux.RUnlock() + return c.registered +} + +func (c *GrpcClient) WaitForRegistered() { + ticker := time.NewTicker(100 * time.Millisecond) + defer ticker.Stop() + for { + select { + case <-ticker.C: + if c.IsRegistered() { + c.Debugf("client is now registered") + return + } + case <-c.stop: + c.Errorf("client has stopped while waiting for registration") + return + } + } +} + +// Safe client getters that ensure registration before returning clients +// These methods will wait for registration to complete or return an error if the client is stopped + +func (c *GrpcClient) GetNodeClient() (grpc2.NodeServiceClient, error) { + if c.stopped { + return nil, fmt.Errorf("grpc client is stopped") + } + if !c.IsRegistered() { + c.Debugf("waiting for node client registration") + c.WaitForRegistered() + if c.stopped { + return nil, fmt.Errorf("grpc client stopped while waiting for registration") + } + } + return c.nodeClient, nil +} + +func (c *GrpcClient) GetTaskClient() (grpc2.TaskServiceClient, error) { + if c.stopped { + return nil, fmt.Errorf("grpc client is stopped") + } + if !c.IsRegistered() { + c.Debugf("waiting for task client registration") + c.WaitForRegistered() + if c.stopped { + return nil, fmt.Errorf("grpc client stopped while waiting for registration") + } + } + return c.taskClient, nil +} + +func (c *GrpcClient) GetModelBaseServiceClient() (grpc2.ModelBaseServiceClient, error) { + if c.stopped { + return nil, fmt.Errorf("grpc client is stopped") + } + if !c.IsRegistered() { + c.Debugf("waiting for model base service client registration") + c.WaitForRegistered() + if c.stopped { + return nil, fmt.Errorf("grpc client stopped while waiting for registration") + } + } + return c.modelBaseServiceClient, nil +} + +func (c *GrpcClient) GetDependencyClient() (grpc2.DependencyServiceClient, error) { + if c.stopped { + return nil, fmt.Errorf("grpc client is stopped") + } + if !c.IsRegistered() { + c.Debugf("waiting for dependency client registration") + c.WaitForRegistered() + if c.stopped { + return nil, fmt.Errorf("grpc client stopped while waiting for registration") + } + } + return c.dependencyClient, nil +} + +func (c *GrpcClient) GetMetricClient() (grpc2.MetricServiceClient, error) { + if c.stopped { + return nil, fmt.Errorf("grpc client is stopped") + } + if !c.IsRegistered() { + c.Debugf("waiting for metric client registration") + c.WaitForRegistered() + if c.stopped { + return nil, fmt.Errorf("grpc client stopped while waiting for registration") + } + } + return c.metricClient, nil +} + +// Safe client getters with timeout - these methods will wait up to the specified timeout +// for registration to complete before returning an error + +func (c *GrpcClient) GetNodeClientWithTimeout(timeout time.Duration) (grpc2.NodeServiceClient, error) { + if c.stopped { + return nil, fmt.Errorf("grpc client is stopped") + } + if !c.IsRegistered() { + if err := c.waitForRegisteredWithTimeout(timeout); err != nil { + return nil, fmt.Errorf("failed to get node client: %w", err) + } + } + return c.nodeClient, nil +} + +func (c *GrpcClient) GetTaskClientWithTimeout(timeout time.Duration) (grpc2.TaskServiceClient, error) { + if c.stopped { + return nil, fmt.Errorf("grpc client is stopped") + } + if !c.IsRegistered() { + if err := c.waitForRegisteredWithTimeout(timeout); err != nil { + return nil, fmt.Errorf("failed to get task client: %w", err) + } + } + return c.taskClient, nil +} + +func (c *GrpcClient) GetModelBaseServiceClientWithTimeout(timeout time.Duration) (grpc2.ModelBaseServiceClient, error) { + if c.stopped { + return nil, fmt.Errorf("grpc client is stopped") + } + if !c.IsRegistered() { + if err := c.waitForRegisteredWithTimeout(timeout); err != nil { + return nil, fmt.Errorf("failed to get model base service client: %w", err) + } + } + return c.modelBaseServiceClient, nil +} + +func (c *GrpcClient) GetDependencyClientWithTimeout(timeout time.Duration) (grpc2.DependencyServiceClient, error) { + if c.stopped { + return nil, fmt.Errorf("grpc client is stopped") + } + if !c.IsRegistered() { + if err := c.waitForRegisteredWithTimeout(timeout); err != nil { + return nil, fmt.Errorf("failed to get dependency client: %w", err) + } + } + return c.dependencyClient, nil +} + +func (c *GrpcClient) GetMetricClientWithTimeout(timeout time.Duration) (grpc2.MetricServiceClient, error) { + if c.stopped { + return nil, fmt.Errorf("grpc client is stopped") + } + if !c.IsRegistered() { + if err := c.waitForRegisteredWithTimeout(timeout); err != nil { + return nil, fmt.Errorf("failed to get metric client: %w", err) + } + } + return c.metricClient, nil +} + +func (c *GrpcClient) waitForRegisteredWithTimeout(timeout time.Duration) error { + ticker := time.NewTicker(100 * time.Millisecond) + defer ticker.Stop() + timer := time.NewTimer(timeout) + defer timer.Stop() + + for { + select { + case <-ticker.C: + if c.IsRegistered() { + c.Debugf("client is now registered") + return nil + } + case <-timer.C: + return fmt.Errorf("timeout waiting for client registration after %v", timeout) + case <-c.stop: + return fmt.Errorf("client has stopped while waiting for registration") + } + } +} + func (c *GrpcClient) connect() (err error) { // Start reconnection loop with proper cleanup go func() { @@ -299,6 +515,9 @@ func (c *GrpcClient) recordSuccess() { func (c *GrpcClient) doConnect() (err error) { op := func() error { + // Mark as not registered during connection attempt + c.setRegistered(false) + // Close existing connection if any if c.conn != nil { if err := c.conn.Close(); err != nil { @@ -378,13 +597,15 @@ func (c *GrpcClient) doConnect() (err error) { } func newGrpcClient() (c *GrpcClient) { - return &GrpcClient{ + client := &GrpcClient{ address: utils.GetGrpcAddress(), timeout: 10 * time.Second, stop: make(chan struct{}), Logger: utils.NewLogger("GrpcClient"), state: connectivity.Idle, } + + return client } var _client *GrpcClient diff --git a/core/models/client/model_service.go b/core/models/client/model_service.go index df82de19..f9154676 100644 --- a/core/models/client/model_service.go +++ b/core/models/client/model_service.go @@ -2,6 +2,10 @@ package client import ( "encoding/json" + "fmt" + "reflect" + "sync" + "github.com/crawlab-team/crawlab/core/grpc/client" "github.com/crawlab-team/crawlab/core/interfaces" "github.com/crawlab-team/crawlab/core/mongo" @@ -9,8 +13,6 @@ import ( "github.com/crawlab-team/crawlab/grpc" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson/primitive" - "reflect" - "sync" ) var ( @@ -27,7 +29,11 @@ type ModelService[T any] struct { func (svc *ModelService[T]) GetById(id primitive.ObjectID) (model *T, err error) { ctx, cancel := client.GetGrpcClient().Context() defer cancel() - res, err := client.GetGrpcClient().ModelBaseServiceClient.GetById(ctx, &grpc.ModelServiceGetByIdRequest{ + modelClient, err := client.GetGrpcClient().GetModelBaseServiceClient() + if err != nil { + return nil, fmt.Errorf("failed to get model base service client: %v", err) + } + res, err := modelClient.GetById(ctx, &grpc.ModelServiceGetByIdRequest{ NodeKey: svc.cfg.GetNodeKey(), ModelType: svc.modelType, Id: id.Hex(), @@ -49,7 +55,11 @@ func (svc *ModelService[T]) GetOne(query bson.M, options *mongo.FindOptions) (mo if err != nil { return nil, err } - res, err := client.GetGrpcClient().ModelBaseServiceClient.GetOne(ctx, &grpc.ModelServiceGetOneRequest{ + modelClient, err := client.GetGrpcClient().GetModelBaseServiceClient() + if err != nil { + return nil, fmt.Errorf("failed to get model base service client: %v", err) + } + res, err := modelClient.GetOne(ctx, &grpc.ModelServiceGetOneRequest{ NodeKey: svc.cfg.GetNodeKey(), ModelType: svc.modelType, Query: queryData, @@ -72,7 +82,11 @@ func (svc *ModelService[T]) GetMany(query bson.M, options *mongo.FindOptions) (m if err != nil { return nil, err } - res, err := client.GetGrpcClient().ModelBaseServiceClient.GetMany(ctx, &grpc.ModelServiceGetManyRequest{ + modelClient, err := client.GetGrpcClient().GetModelBaseServiceClient() + if err != nil { + return nil, fmt.Errorf("failed to get model base service client: %v", err) + } + res, err := modelClient.GetMany(ctx, &grpc.ModelServiceGetManyRequest{ NodeKey: svc.cfg.GetNodeKey(), ModelType: svc.modelType, Query: queryData, @@ -87,7 +101,11 @@ func (svc *ModelService[T]) GetMany(query bson.M, options *mongo.FindOptions) (m func (svc *ModelService[T]) DeleteById(id primitive.ObjectID) (err error) { ctx, cancel := client.GetGrpcClient().Context() defer cancel() - _, err = client.GetGrpcClient().ModelBaseServiceClient.DeleteById(ctx, &grpc.ModelServiceDeleteByIdRequest{ + modelClient, err := client.GetGrpcClient().GetModelBaseServiceClient() + if err != nil { + return fmt.Errorf("failed to get model base service client: %v", err) + } + _, err = modelClient.DeleteById(ctx, &grpc.ModelServiceDeleteByIdRequest{ NodeKey: svc.cfg.GetNodeKey(), ModelType: svc.modelType, Id: id.Hex(), @@ -105,7 +123,11 @@ func (svc *ModelService[T]) DeleteOne(query bson.M) (err error) { if err != nil { return err } - _, err = client.GetGrpcClient().ModelBaseServiceClient.DeleteOne(ctx, &grpc.ModelServiceDeleteOneRequest{ + modelClient, err := client.GetGrpcClient().GetModelBaseServiceClient() + if err != nil { + return fmt.Errorf("failed to get model base service client: %v", err) + } + _, err = modelClient.DeleteOne(ctx, &grpc.ModelServiceDeleteOneRequest{ NodeKey: svc.cfg.GetNodeKey(), ModelType: svc.modelType, Query: queryData, @@ -123,7 +145,11 @@ func (svc *ModelService[T]) DeleteMany(query bson.M) (err error) { if err != nil { return err } - _, err = client.GetGrpcClient().ModelBaseServiceClient.DeleteMany(ctx, &grpc.ModelServiceDeleteManyRequest{ + modelClient, err := client.GetGrpcClient().GetModelBaseServiceClient() + if err != nil { + return fmt.Errorf("failed to get model base service client: %v", err) + } + _, err = modelClient.DeleteMany(ctx, &grpc.ModelServiceDeleteManyRequest{ NodeKey: svc.cfg.GetNodeKey(), ModelType: svc.modelType, Query: queryData, @@ -141,7 +167,11 @@ func (svc *ModelService[T]) UpdateById(id primitive.ObjectID, update bson.M) (er if err != nil { return err } - _, err = client.GetGrpcClient().ModelBaseServiceClient.UpdateById(ctx, &grpc.ModelServiceUpdateByIdRequest{ + modelClient, err := client.GetGrpcClient().GetModelBaseServiceClient() + if err != nil { + return fmt.Errorf("failed to get model base service client: %v", err) + } + _, err = modelClient.UpdateById(ctx, &grpc.ModelServiceUpdateByIdRequest{ NodeKey: svc.cfg.GetNodeKey(), ModelType: svc.modelType, Id: id.Hex(), @@ -164,7 +194,11 @@ func (svc *ModelService[T]) UpdateOne(query bson.M, update bson.M) (err error) { if err != nil { return err } - _, err = client.GetGrpcClient().ModelBaseServiceClient.UpdateOne(ctx, &grpc.ModelServiceUpdateOneRequest{ + modelClient, err := client.GetGrpcClient().GetModelBaseServiceClient() + if err != nil { + return fmt.Errorf("failed to get model base service client: %v", err) + } + _, err = modelClient.UpdateOne(ctx, &grpc.ModelServiceUpdateOneRequest{ NodeKey: svc.cfg.GetNodeKey(), ModelType: svc.modelType, Query: queryData, @@ -187,7 +221,11 @@ func (svc *ModelService[T]) UpdateMany(query bson.M, update bson.M) (err error) if err != nil { return err } - _, err = client.GetGrpcClient().ModelBaseServiceClient.UpdateMany(ctx, &grpc.ModelServiceUpdateManyRequest{ + modelClient, err := client.GetGrpcClient().GetModelBaseServiceClient() + if err != nil { + return fmt.Errorf("failed to get model base service client: %v", err) + } + _, err = modelClient.UpdateMany(ctx, &grpc.ModelServiceUpdateManyRequest{ NodeKey: svc.cfg.GetNodeKey(), ModelType: svc.modelType, Query: queryData, @@ -203,7 +241,11 @@ func (svc *ModelService[T]) ReplaceById(id primitive.ObjectID, model T) (err err if err != nil { return err } - _, err = client.GetGrpcClient().ModelBaseServiceClient.ReplaceById(ctx, &grpc.ModelServiceReplaceByIdRequest{ + modelClient, err := client.GetGrpcClient().GetModelBaseServiceClient() + if err != nil { + return fmt.Errorf("failed to get model base service client: %v", err) + } + _, err = modelClient.ReplaceById(ctx, &grpc.ModelServiceReplaceByIdRequest{ NodeKey: svc.cfg.GetNodeKey(), ModelType: svc.modelType, Id: id.Hex(), @@ -226,7 +268,11 @@ func (svc *ModelService[T]) ReplaceOne(query bson.M, model T) (err error) { if err != nil { return err } - _, err = client.GetGrpcClient().ModelBaseServiceClient.ReplaceOne(ctx, &grpc.ModelServiceReplaceOneRequest{ + modelClient, err := client.GetGrpcClient().GetModelBaseServiceClient() + if err != nil { + return fmt.Errorf("failed to get model base service client: %v", err) + } + _, err = modelClient.ReplaceOne(ctx, &grpc.ModelServiceReplaceOneRequest{ NodeKey: svc.cfg.GetNodeKey(), ModelType: svc.modelType, Query: queryData, @@ -245,7 +291,11 @@ func (svc *ModelService[T]) InsertOne(model T) (id primitive.ObjectID, err error if err != nil { return primitive.NilObjectID, err } - res, err := client.GetGrpcClient().ModelBaseServiceClient.InsertOne(ctx, &grpc.ModelServiceInsertOneRequest{ + modelClient, err := client.GetGrpcClient().GetModelBaseServiceClient() + if err != nil { + return primitive.NilObjectID, fmt.Errorf("failed to get model base service client: %v", err) + } + res, err := modelClient.InsertOne(ctx, &grpc.ModelServiceInsertOneRequest{ NodeKey: svc.cfg.GetNodeKey(), ModelType: svc.modelType, Model: modelData, @@ -263,7 +313,11 @@ func (svc *ModelService[T]) InsertMany(models []T) (ids []primitive.ObjectID, er if err != nil { return nil, err } - res, err := client.GetGrpcClient().ModelBaseServiceClient.InsertMany(ctx, &grpc.ModelServiceInsertManyRequest{ + modelClient, err := client.GetGrpcClient().GetModelBaseServiceClient() + if err != nil { + return nil, fmt.Errorf("failed to get model base service client: %v", err) + } + res, err := modelClient.InsertMany(ctx, &grpc.ModelServiceInsertManyRequest{ NodeKey: svc.cfg.GetNodeKey(), ModelType: svc.modelType, Models: modelsData, @@ -285,7 +339,11 @@ func (svc *ModelService[T]) UpsertOne(query bson.M, model T) (id primitive.Objec if err != nil { return primitive.NilObjectID, err } - res, err := client.GetGrpcClient().ModelBaseServiceClient.UpsertOne(ctx, &grpc.ModelServiceUpsertOneRequest{ + modelClient, err := client.GetGrpcClient().GetModelBaseServiceClient() + if err != nil { + return primitive.NilObjectID, fmt.Errorf("failed to get model base service client: %v", err) + } + res, err := modelClient.UpsertOne(ctx, &grpc.ModelServiceUpsertOneRequest{ NodeKey: svc.cfg.GetNodeKey(), ModelType: svc.modelType, Query: queryData, @@ -305,7 +363,11 @@ func (svc *ModelService[T]) Count(query bson.M) (total int, err error) { if err != nil { return 0, err } - res, err := client.GetGrpcClient().ModelBaseServiceClient.Count(ctx, &grpc.ModelServiceCountRequest{ + modelClient, err := client.GetGrpcClient().GetModelBaseServiceClient() + if err != nil { + return 0, fmt.Errorf("failed to get model base service client: %v", err) + } + res, err := modelClient.Count(ctx, &grpc.ModelServiceCountRequest{ NodeKey: svc.cfg.GetNodeKey(), ModelType: svc.modelType, Query: queryData, diff --git a/core/node/service/worker_service.go b/core/node/service/worker_service.go index e4078118..89242763 100644 --- a/core/node/service/worker_service.go +++ b/core/node/service/worker_service.go @@ -4,13 +4,14 @@ import ( "context" "errors" "fmt" - "github.com/crawlab-team/crawlab/core/controllers" - "github.com/gin-gonic/gin" "net" "net/http" "sync" "time" + "github.com/crawlab-team/crawlab/core/controllers" + "github.com/gin-gonic/gin" + "github.com/crawlab-team/crawlab/core/models/models" "github.com/cenkalti/backoff/v4" @@ -85,7 +86,11 @@ func (svc *WorkerService) register() { op := func() (err error) { ctx, cancel := client.GetGrpcClient().Context() defer cancel() - _, err = client.GetGrpcClient().NodeClient.Register(ctx, &grpc.NodeServiceRegisterRequest{ + nodeClient, err := client.GetGrpcClient().GetNodeClient() + if err != nil { + return fmt.Errorf("failed to get node client: %v", err) + } + _, err = nodeClient.Register(ctx, &grpc.NodeServiceRegisterRequest{ NodeKey: svc.cfgSvc.GetNodeKey(), NodeName: svc.cfgSvc.GetNodeName(), MaxRunners: int32(svc.cfgSvc.GetMaxRunners()), @@ -152,7 +157,12 @@ func (svc *WorkerService) subscribe() { // Use backoff for connection attempts operation := func() error { svc.Debugf("attempting to subscribe to master") - stream, err := client.GetGrpcClient().NodeClient.Subscribe(context.Background(), &grpc.NodeServiceSubscribeRequest{ + nodeClient, err := client.GetGrpcClient().GetNodeClient() + if err != nil { + svc.Errorf("failed to get node client: %v", err) + return err + } + stream, err := nodeClient.Subscribe(context.Background(), &grpc.NodeServiceSubscribeRequest{ NodeKey: svc.cfgSvc.GetNodeKey(), }) if err != nil { @@ -199,7 +209,12 @@ func (svc *WorkerService) subscribe() { func (svc *WorkerService) sendHeartbeat() { ctx, cancel := context.WithTimeout(context.Background(), svc.heartbeatInterval) defer cancel() - _, err := client.GetGrpcClient().NodeClient.SendHeartbeat(ctx, &grpc.NodeServiceSendHeartbeatRequest{ + nodeClient, err := client.GetGrpcClient().GetNodeClient() + if err != nil { + svc.Errorf("failed to get node client: %v", err) + return + } + _, err = nodeClient.SendHeartbeat(ctx, &grpc.NodeServiceSendHeartbeatRequest{ NodeKey: svc.cfgSvc.GetNodeKey(), }) if err != nil { diff --git a/core/task/handler/runner.go b/core/task/handler/runner.go index 7b3e2a16..7832cc93 100644 --- a/core/task/handler/runner.go +++ b/core/task/handler/runner.go @@ -885,7 +885,12 @@ func (r *Runner) initConnection() (err error) { r.connMutex.Lock() defer r.connMutex.Unlock() - r.conn, err = client2.GetGrpcClient().TaskClient.Connect(context.Background()) + taskClient, err := client2.GetGrpcClient().GetTaskClient() + if err != nil { + r.Errorf("failed to get task client: %v", err) + return err + } + r.conn, err = taskClient.Connect(context.Background()) if err != nil { r.Errorf("error connecting to task service: %v", err) return err @@ -1004,7 +1009,12 @@ func (r *Runner) reconnectWithRetry() error { } // Attempt reconnection - conn, err := client2.GetGrpcClient().TaskClient.Connect(context.Background()) + taskClient, err := client2.GetGrpcClient().GetTaskClient() + if err != nil { + r.Warnf("reconnection attempt %d failed to get task client: %v", attempt+1, err) + continue + } + conn, err := taskClient.Connect(context.Background()) if err != nil { r.Warnf("reconnection attempt %d failed: %v", attempt+1, err) continue @@ -1162,7 +1172,12 @@ func (r *Runner) sendNotification() { NodeKey: r.svc.GetNodeConfigService().GetNodeKey(), TaskId: r.tid.Hex(), } - _, err := client2.GetGrpcClient().TaskClient.SendNotification(context.Background(), req) + taskClient, err := client2.GetGrpcClient().GetTaskClient() + if err != nil { + r.Errorf("failed to get task client: %v", err) + return + } + _, err = taskClient.SendNotification(context.Background(), req) if err != nil { r.Errorf("error sending notification: %v", err) return diff --git a/core/task/handler/service.go b/core/task/handler/service.go index 03f085ab..bcd1b37d 100644 --- a/core/task/handler/service.go +++ b/core/task/handler/service.go @@ -405,7 +405,11 @@ func (svc *Service) updateNodeStatus() (err error) { func (svc *Service) fetchTask() (tid primitive.ObjectID, err error) { ctx, cancel := context.WithTimeout(context.Background(), svc.fetchTimeout) defer cancel() - res, err := svc.c.TaskClient.FetchTask(ctx, &grpc.TaskServiceFetchTaskRequest{ + taskClient, err := svc.c.GetTaskClient() + if err != nil { + return primitive.NilObjectID, fmt.Errorf("failed to get task client: %v", err) + } + res, err := taskClient.FetchTask(ctx, &grpc.TaskServiceFetchTaskRequest{ NodeKey: svc.cfgSvc.GetNodeKey(), }) if err != nil { @@ -496,7 +500,11 @@ func (svc *Service) subscribeTask(taskId primitive.ObjectID) (stream grpc.TaskSe req := &grpc.TaskServiceSubscribeRequest{ TaskId: taskId.Hex(), } - stream, err = svc.c.TaskClient.Subscribe(ctx, req) + taskClient, err := svc.c.GetTaskClient() + if err != nil { + return nil, fmt.Errorf("failed to get task client: %v", err) + } + stream, err = taskClient.Subscribe(ctx, req) if err != nil { svc.Errorf("failed to subscribe task[%s]: %v", taskId.Hex(), err) return nil, err