mirror of
https://github.com/crawlab-team/crawlab.git
synced 2026-01-21 17:21:09 +01:00
refactor: update gRPC client access patterns to use safe getter methods for improved error handling
This commit is contained in:
@@ -17,28 +17,47 @@ import (
|
||||
"google.golang.org/grpc/keepalive"
|
||||
)
|
||||
|
||||
type GrpcClient struct {
|
||||
// dependencies
|
||||
nodeCfgSvc interfaces.NodeConfigService
|
||||
// GrpcClient provides a robust gRPC client with connection management and client registration.
|
||||
//
|
||||
// The client handles connection lifecycle and ensures that gRPC service clients are properly
|
||||
// initialized before use. All client fields are private and can only be accessed through
|
||||
// safe getter methods that ensure registration before returning clients.
|
||||
//
|
||||
// Example usage:
|
||||
// client := GetGrpcClient()
|
||||
//
|
||||
// // Safe access pattern - always use getter methods
|
||||
// nodeClient, err := client.GetNodeClient()
|
||||
// if err != nil {
|
||||
// return fmt.Errorf("failed to get node client: %v", err)
|
||||
// }
|
||||
// resp, err := nodeClient.Register(ctx, req)
|
||||
//
|
||||
// // Alternative with timeout
|
||||
// taskClient, err := client.GetTaskClientWithTimeout(5 * time.Second)
|
||||
// if err != nil {
|
||||
// return fmt.Errorf("failed to get task client: %v", err)
|
||||
// }
|
||||
// resp, err := taskClient.Connect(ctx)
|
||||
|
||||
type GrpcClient struct {
|
||||
// settings
|
||||
address string
|
||||
timeout time.Duration
|
||||
|
||||
// internals
|
||||
conn *grpc.ClientConn
|
||||
err error
|
||||
once sync.Once
|
||||
stopped bool
|
||||
stop chan struct{}
|
||||
interfaces.Logger
|
||||
|
||||
// clients
|
||||
NodeClient grpc2.NodeServiceClient
|
||||
TaskClient grpc2.TaskServiceClient
|
||||
ModelBaseServiceClient grpc2.ModelBaseServiceClient
|
||||
DependencyClient grpc2.DependencyServiceClient
|
||||
MetricClient grpc2.MetricServiceClient
|
||||
// clients (private to enforce safe access through getter methods)
|
||||
nodeClient grpc2.NodeServiceClient
|
||||
taskClient grpc2.TaskServiceClient
|
||||
modelBaseServiceClient grpc2.ModelBaseServiceClient
|
||||
dependencyClient grpc2.DependencyServiceClient
|
||||
metricClient grpc2.MetricServiceClient
|
||||
|
||||
// Add new fields for state management
|
||||
state connectivity.State
|
||||
@@ -46,14 +65,17 @@ type GrpcClient struct {
|
||||
reconnect chan struct{}
|
||||
|
||||
// Circuit breaker fields
|
||||
failureCount int
|
||||
lastFailure time.Time
|
||||
circuitBreaker bool
|
||||
cbMux sync.RWMutex
|
||||
failureCount int
|
||||
lastFailure time.Time
|
||||
cbMux sync.RWMutex
|
||||
|
||||
// Reconnection control
|
||||
reconnecting bool
|
||||
reconnectMux sync.Mutex
|
||||
|
||||
// Registration status
|
||||
registered bool
|
||||
registeredMux sync.RWMutex
|
||||
}
|
||||
|
||||
func (c *GrpcClient) Start() {
|
||||
@@ -76,6 +98,7 @@ func (c *GrpcClient) Start() {
|
||||
func (c *GrpcClient) Stop() (err error) {
|
||||
// set stopped flag
|
||||
c.stopped = true
|
||||
c.setRegistered(false)
|
||||
c.stop <- struct{}{}
|
||||
c.Infof("stopped")
|
||||
|
||||
@@ -111,11 +134,14 @@ func (c *GrpcClient) WaitForReady() {
|
||||
}
|
||||
|
||||
func (c *GrpcClient) register() {
|
||||
c.NodeClient = grpc2.NewNodeServiceClient(c.conn)
|
||||
c.ModelBaseServiceClient = grpc2.NewModelBaseServiceClient(c.conn)
|
||||
c.TaskClient = grpc2.NewTaskServiceClient(c.conn)
|
||||
c.DependencyClient = grpc2.NewDependencyServiceClient(c.conn)
|
||||
c.MetricClient = grpc2.NewMetricServiceClient(c.conn)
|
||||
c.nodeClient = grpc2.NewNodeServiceClient(c.conn)
|
||||
c.modelBaseServiceClient = grpc2.NewModelBaseServiceClient(c.conn)
|
||||
c.taskClient = grpc2.NewTaskServiceClient(c.conn)
|
||||
c.dependencyClient = grpc2.NewDependencyServiceClient(c.conn)
|
||||
c.metricClient = grpc2.NewMetricServiceClient(c.conn)
|
||||
|
||||
// Mark as registered
|
||||
c.setRegistered(true)
|
||||
}
|
||||
|
||||
func (c *GrpcClient) Context() (ctx context.Context, cancel context.CancelFunc) {
|
||||
@@ -127,6 +153,10 @@ func (c *GrpcClient) IsReady() (res bool) {
|
||||
return c.conn != nil && state == connectivity.Ready
|
||||
}
|
||||
|
||||
func (c *GrpcClient) IsReadyAndRegistered() (res bool) {
|
||||
return c.IsReady() && c.IsRegistered()
|
||||
}
|
||||
|
||||
func (c *GrpcClient) IsClosed() (res bool) {
|
||||
if c.conn != nil {
|
||||
return c.conn.GetState() == connectivity.Shutdown
|
||||
@@ -211,6 +241,192 @@ func (c *GrpcClient) getState() connectivity.State {
|
||||
return c.state
|
||||
}
|
||||
|
||||
func (c *GrpcClient) setRegistered(registered bool) {
|
||||
c.registeredMux.Lock()
|
||||
defer c.registeredMux.Unlock()
|
||||
c.registered = registered
|
||||
}
|
||||
|
||||
func (c *GrpcClient) IsRegistered() bool {
|
||||
c.registeredMux.RLock()
|
||||
defer c.registeredMux.RUnlock()
|
||||
return c.registered
|
||||
}
|
||||
|
||||
func (c *GrpcClient) WaitForRegistered() {
|
||||
ticker := time.NewTicker(100 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
if c.IsRegistered() {
|
||||
c.Debugf("client is now registered")
|
||||
return
|
||||
}
|
||||
case <-c.stop:
|
||||
c.Errorf("client has stopped while waiting for registration")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Safe client getters that ensure registration before returning clients
|
||||
// These methods will wait for registration to complete or return an error if the client is stopped
|
||||
|
||||
func (c *GrpcClient) GetNodeClient() (grpc2.NodeServiceClient, error) {
|
||||
if c.stopped {
|
||||
return nil, fmt.Errorf("grpc client is stopped")
|
||||
}
|
||||
if !c.IsRegistered() {
|
||||
c.Debugf("waiting for node client registration")
|
||||
c.WaitForRegistered()
|
||||
if c.stopped {
|
||||
return nil, fmt.Errorf("grpc client stopped while waiting for registration")
|
||||
}
|
||||
}
|
||||
return c.nodeClient, nil
|
||||
}
|
||||
|
||||
func (c *GrpcClient) GetTaskClient() (grpc2.TaskServiceClient, error) {
|
||||
if c.stopped {
|
||||
return nil, fmt.Errorf("grpc client is stopped")
|
||||
}
|
||||
if !c.IsRegistered() {
|
||||
c.Debugf("waiting for task client registration")
|
||||
c.WaitForRegistered()
|
||||
if c.stopped {
|
||||
return nil, fmt.Errorf("grpc client stopped while waiting for registration")
|
||||
}
|
||||
}
|
||||
return c.taskClient, nil
|
||||
}
|
||||
|
||||
func (c *GrpcClient) GetModelBaseServiceClient() (grpc2.ModelBaseServiceClient, error) {
|
||||
if c.stopped {
|
||||
return nil, fmt.Errorf("grpc client is stopped")
|
||||
}
|
||||
if !c.IsRegistered() {
|
||||
c.Debugf("waiting for model base service client registration")
|
||||
c.WaitForRegistered()
|
||||
if c.stopped {
|
||||
return nil, fmt.Errorf("grpc client stopped while waiting for registration")
|
||||
}
|
||||
}
|
||||
return c.modelBaseServiceClient, nil
|
||||
}
|
||||
|
||||
func (c *GrpcClient) GetDependencyClient() (grpc2.DependencyServiceClient, error) {
|
||||
if c.stopped {
|
||||
return nil, fmt.Errorf("grpc client is stopped")
|
||||
}
|
||||
if !c.IsRegistered() {
|
||||
c.Debugf("waiting for dependency client registration")
|
||||
c.WaitForRegistered()
|
||||
if c.stopped {
|
||||
return nil, fmt.Errorf("grpc client stopped while waiting for registration")
|
||||
}
|
||||
}
|
||||
return c.dependencyClient, nil
|
||||
}
|
||||
|
||||
func (c *GrpcClient) GetMetricClient() (grpc2.MetricServiceClient, error) {
|
||||
if c.stopped {
|
||||
return nil, fmt.Errorf("grpc client is stopped")
|
||||
}
|
||||
if !c.IsRegistered() {
|
||||
c.Debugf("waiting for metric client registration")
|
||||
c.WaitForRegistered()
|
||||
if c.stopped {
|
||||
return nil, fmt.Errorf("grpc client stopped while waiting for registration")
|
||||
}
|
||||
}
|
||||
return c.metricClient, nil
|
||||
}
|
||||
|
||||
// Safe client getters with timeout - these methods will wait up to the specified timeout
|
||||
// for registration to complete before returning an error
|
||||
|
||||
func (c *GrpcClient) GetNodeClientWithTimeout(timeout time.Duration) (grpc2.NodeServiceClient, error) {
|
||||
if c.stopped {
|
||||
return nil, fmt.Errorf("grpc client is stopped")
|
||||
}
|
||||
if !c.IsRegistered() {
|
||||
if err := c.waitForRegisteredWithTimeout(timeout); err != nil {
|
||||
return nil, fmt.Errorf("failed to get node client: %w", err)
|
||||
}
|
||||
}
|
||||
return c.nodeClient, nil
|
||||
}
|
||||
|
||||
func (c *GrpcClient) GetTaskClientWithTimeout(timeout time.Duration) (grpc2.TaskServiceClient, error) {
|
||||
if c.stopped {
|
||||
return nil, fmt.Errorf("grpc client is stopped")
|
||||
}
|
||||
if !c.IsRegistered() {
|
||||
if err := c.waitForRegisteredWithTimeout(timeout); err != nil {
|
||||
return nil, fmt.Errorf("failed to get task client: %w", err)
|
||||
}
|
||||
}
|
||||
return c.taskClient, nil
|
||||
}
|
||||
|
||||
func (c *GrpcClient) GetModelBaseServiceClientWithTimeout(timeout time.Duration) (grpc2.ModelBaseServiceClient, error) {
|
||||
if c.stopped {
|
||||
return nil, fmt.Errorf("grpc client is stopped")
|
||||
}
|
||||
if !c.IsRegistered() {
|
||||
if err := c.waitForRegisteredWithTimeout(timeout); err != nil {
|
||||
return nil, fmt.Errorf("failed to get model base service client: %w", err)
|
||||
}
|
||||
}
|
||||
return c.modelBaseServiceClient, nil
|
||||
}
|
||||
|
||||
func (c *GrpcClient) GetDependencyClientWithTimeout(timeout time.Duration) (grpc2.DependencyServiceClient, error) {
|
||||
if c.stopped {
|
||||
return nil, fmt.Errorf("grpc client is stopped")
|
||||
}
|
||||
if !c.IsRegistered() {
|
||||
if err := c.waitForRegisteredWithTimeout(timeout); err != nil {
|
||||
return nil, fmt.Errorf("failed to get dependency client: %w", err)
|
||||
}
|
||||
}
|
||||
return c.dependencyClient, nil
|
||||
}
|
||||
|
||||
func (c *GrpcClient) GetMetricClientWithTimeout(timeout time.Duration) (grpc2.MetricServiceClient, error) {
|
||||
if c.stopped {
|
||||
return nil, fmt.Errorf("grpc client is stopped")
|
||||
}
|
||||
if !c.IsRegistered() {
|
||||
if err := c.waitForRegisteredWithTimeout(timeout); err != nil {
|
||||
return nil, fmt.Errorf("failed to get metric client: %w", err)
|
||||
}
|
||||
}
|
||||
return c.metricClient, nil
|
||||
}
|
||||
|
||||
func (c *GrpcClient) waitForRegisteredWithTimeout(timeout time.Duration) error {
|
||||
ticker := time.NewTicker(100 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
timer := time.NewTimer(timeout)
|
||||
defer timer.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
if c.IsRegistered() {
|
||||
c.Debugf("client is now registered")
|
||||
return nil
|
||||
}
|
||||
case <-timer.C:
|
||||
return fmt.Errorf("timeout waiting for client registration after %v", timeout)
|
||||
case <-c.stop:
|
||||
return fmt.Errorf("client has stopped while waiting for registration")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *GrpcClient) connect() (err error) {
|
||||
// Start reconnection loop with proper cleanup
|
||||
go func() {
|
||||
@@ -299,6 +515,9 @@ func (c *GrpcClient) recordSuccess() {
|
||||
|
||||
func (c *GrpcClient) doConnect() (err error) {
|
||||
op := func() error {
|
||||
// Mark as not registered during connection attempt
|
||||
c.setRegistered(false)
|
||||
|
||||
// Close existing connection if any
|
||||
if c.conn != nil {
|
||||
if err := c.conn.Close(); err != nil {
|
||||
@@ -378,13 +597,15 @@ func (c *GrpcClient) doConnect() (err error) {
|
||||
}
|
||||
|
||||
func newGrpcClient() (c *GrpcClient) {
|
||||
return &GrpcClient{
|
||||
client := &GrpcClient{
|
||||
address: utils.GetGrpcAddress(),
|
||||
timeout: 10 * time.Second,
|
||||
stop: make(chan struct{}),
|
||||
Logger: utils.NewLogger("GrpcClient"),
|
||||
state: connectivity.Idle,
|
||||
}
|
||||
|
||||
return client
|
||||
}
|
||||
|
||||
var _client *GrpcClient
|
||||
|
||||
@@ -2,6 +2,10 @@ package client
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sync"
|
||||
|
||||
"github.com/crawlab-team/crawlab/core/grpc/client"
|
||||
"github.com/crawlab-team/crawlab/core/interfaces"
|
||||
"github.com/crawlab-team/crawlab/core/mongo"
|
||||
@@ -9,8 +13,6 @@ import (
|
||||
"github.com/crawlab-team/crawlab/grpc"
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
"reflect"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -27,7 +29,11 @@ type ModelService[T any] struct {
|
||||
func (svc *ModelService[T]) GetById(id primitive.ObjectID) (model *T, err error) {
|
||||
ctx, cancel := client.GetGrpcClient().Context()
|
||||
defer cancel()
|
||||
res, err := client.GetGrpcClient().ModelBaseServiceClient.GetById(ctx, &grpc.ModelServiceGetByIdRequest{
|
||||
modelClient, err := client.GetGrpcClient().GetModelBaseServiceClient()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get model base service client: %v", err)
|
||||
}
|
||||
res, err := modelClient.GetById(ctx, &grpc.ModelServiceGetByIdRequest{
|
||||
NodeKey: svc.cfg.GetNodeKey(),
|
||||
ModelType: svc.modelType,
|
||||
Id: id.Hex(),
|
||||
@@ -49,7 +55,11 @@ func (svc *ModelService[T]) GetOne(query bson.M, options *mongo.FindOptions) (mo
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
res, err := client.GetGrpcClient().ModelBaseServiceClient.GetOne(ctx, &grpc.ModelServiceGetOneRequest{
|
||||
modelClient, err := client.GetGrpcClient().GetModelBaseServiceClient()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get model base service client: %v", err)
|
||||
}
|
||||
res, err := modelClient.GetOne(ctx, &grpc.ModelServiceGetOneRequest{
|
||||
NodeKey: svc.cfg.GetNodeKey(),
|
||||
ModelType: svc.modelType,
|
||||
Query: queryData,
|
||||
@@ -72,7 +82,11 @@ func (svc *ModelService[T]) GetMany(query bson.M, options *mongo.FindOptions) (m
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
res, err := client.GetGrpcClient().ModelBaseServiceClient.GetMany(ctx, &grpc.ModelServiceGetManyRequest{
|
||||
modelClient, err := client.GetGrpcClient().GetModelBaseServiceClient()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get model base service client: %v", err)
|
||||
}
|
||||
res, err := modelClient.GetMany(ctx, &grpc.ModelServiceGetManyRequest{
|
||||
NodeKey: svc.cfg.GetNodeKey(),
|
||||
ModelType: svc.modelType,
|
||||
Query: queryData,
|
||||
@@ -87,7 +101,11 @@ func (svc *ModelService[T]) GetMany(query bson.M, options *mongo.FindOptions) (m
|
||||
func (svc *ModelService[T]) DeleteById(id primitive.ObjectID) (err error) {
|
||||
ctx, cancel := client.GetGrpcClient().Context()
|
||||
defer cancel()
|
||||
_, err = client.GetGrpcClient().ModelBaseServiceClient.DeleteById(ctx, &grpc.ModelServiceDeleteByIdRequest{
|
||||
modelClient, err := client.GetGrpcClient().GetModelBaseServiceClient()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get model base service client: %v", err)
|
||||
}
|
||||
_, err = modelClient.DeleteById(ctx, &grpc.ModelServiceDeleteByIdRequest{
|
||||
NodeKey: svc.cfg.GetNodeKey(),
|
||||
ModelType: svc.modelType,
|
||||
Id: id.Hex(),
|
||||
@@ -105,7 +123,11 @@ func (svc *ModelService[T]) DeleteOne(query bson.M) (err error) {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = client.GetGrpcClient().ModelBaseServiceClient.DeleteOne(ctx, &grpc.ModelServiceDeleteOneRequest{
|
||||
modelClient, err := client.GetGrpcClient().GetModelBaseServiceClient()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get model base service client: %v", err)
|
||||
}
|
||||
_, err = modelClient.DeleteOne(ctx, &grpc.ModelServiceDeleteOneRequest{
|
||||
NodeKey: svc.cfg.GetNodeKey(),
|
||||
ModelType: svc.modelType,
|
||||
Query: queryData,
|
||||
@@ -123,7 +145,11 @@ func (svc *ModelService[T]) DeleteMany(query bson.M) (err error) {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = client.GetGrpcClient().ModelBaseServiceClient.DeleteMany(ctx, &grpc.ModelServiceDeleteManyRequest{
|
||||
modelClient, err := client.GetGrpcClient().GetModelBaseServiceClient()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get model base service client: %v", err)
|
||||
}
|
||||
_, err = modelClient.DeleteMany(ctx, &grpc.ModelServiceDeleteManyRequest{
|
||||
NodeKey: svc.cfg.GetNodeKey(),
|
||||
ModelType: svc.modelType,
|
||||
Query: queryData,
|
||||
@@ -141,7 +167,11 @@ func (svc *ModelService[T]) UpdateById(id primitive.ObjectID, update bson.M) (er
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = client.GetGrpcClient().ModelBaseServiceClient.UpdateById(ctx, &grpc.ModelServiceUpdateByIdRequest{
|
||||
modelClient, err := client.GetGrpcClient().GetModelBaseServiceClient()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get model base service client: %v", err)
|
||||
}
|
||||
_, err = modelClient.UpdateById(ctx, &grpc.ModelServiceUpdateByIdRequest{
|
||||
NodeKey: svc.cfg.GetNodeKey(),
|
||||
ModelType: svc.modelType,
|
||||
Id: id.Hex(),
|
||||
@@ -164,7 +194,11 @@ func (svc *ModelService[T]) UpdateOne(query bson.M, update bson.M) (err error) {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = client.GetGrpcClient().ModelBaseServiceClient.UpdateOne(ctx, &grpc.ModelServiceUpdateOneRequest{
|
||||
modelClient, err := client.GetGrpcClient().GetModelBaseServiceClient()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get model base service client: %v", err)
|
||||
}
|
||||
_, err = modelClient.UpdateOne(ctx, &grpc.ModelServiceUpdateOneRequest{
|
||||
NodeKey: svc.cfg.GetNodeKey(),
|
||||
ModelType: svc.modelType,
|
||||
Query: queryData,
|
||||
@@ -187,7 +221,11 @@ func (svc *ModelService[T]) UpdateMany(query bson.M, update bson.M) (err error)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = client.GetGrpcClient().ModelBaseServiceClient.UpdateMany(ctx, &grpc.ModelServiceUpdateManyRequest{
|
||||
modelClient, err := client.GetGrpcClient().GetModelBaseServiceClient()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get model base service client: %v", err)
|
||||
}
|
||||
_, err = modelClient.UpdateMany(ctx, &grpc.ModelServiceUpdateManyRequest{
|
||||
NodeKey: svc.cfg.GetNodeKey(),
|
||||
ModelType: svc.modelType,
|
||||
Query: queryData,
|
||||
@@ -203,7 +241,11 @@ func (svc *ModelService[T]) ReplaceById(id primitive.ObjectID, model T) (err err
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = client.GetGrpcClient().ModelBaseServiceClient.ReplaceById(ctx, &grpc.ModelServiceReplaceByIdRequest{
|
||||
modelClient, err := client.GetGrpcClient().GetModelBaseServiceClient()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get model base service client: %v", err)
|
||||
}
|
||||
_, err = modelClient.ReplaceById(ctx, &grpc.ModelServiceReplaceByIdRequest{
|
||||
NodeKey: svc.cfg.GetNodeKey(),
|
||||
ModelType: svc.modelType,
|
||||
Id: id.Hex(),
|
||||
@@ -226,7 +268,11 @@ func (svc *ModelService[T]) ReplaceOne(query bson.M, model T) (err error) {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = client.GetGrpcClient().ModelBaseServiceClient.ReplaceOne(ctx, &grpc.ModelServiceReplaceOneRequest{
|
||||
modelClient, err := client.GetGrpcClient().GetModelBaseServiceClient()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get model base service client: %v", err)
|
||||
}
|
||||
_, err = modelClient.ReplaceOne(ctx, &grpc.ModelServiceReplaceOneRequest{
|
||||
NodeKey: svc.cfg.GetNodeKey(),
|
||||
ModelType: svc.modelType,
|
||||
Query: queryData,
|
||||
@@ -245,7 +291,11 @@ func (svc *ModelService[T]) InsertOne(model T) (id primitive.ObjectID, err error
|
||||
if err != nil {
|
||||
return primitive.NilObjectID, err
|
||||
}
|
||||
res, err := client.GetGrpcClient().ModelBaseServiceClient.InsertOne(ctx, &grpc.ModelServiceInsertOneRequest{
|
||||
modelClient, err := client.GetGrpcClient().GetModelBaseServiceClient()
|
||||
if err != nil {
|
||||
return primitive.NilObjectID, fmt.Errorf("failed to get model base service client: %v", err)
|
||||
}
|
||||
res, err := modelClient.InsertOne(ctx, &grpc.ModelServiceInsertOneRequest{
|
||||
NodeKey: svc.cfg.GetNodeKey(),
|
||||
ModelType: svc.modelType,
|
||||
Model: modelData,
|
||||
@@ -263,7 +313,11 @@ func (svc *ModelService[T]) InsertMany(models []T) (ids []primitive.ObjectID, er
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
res, err := client.GetGrpcClient().ModelBaseServiceClient.InsertMany(ctx, &grpc.ModelServiceInsertManyRequest{
|
||||
modelClient, err := client.GetGrpcClient().GetModelBaseServiceClient()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get model base service client: %v", err)
|
||||
}
|
||||
res, err := modelClient.InsertMany(ctx, &grpc.ModelServiceInsertManyRequest{
|
||||
NodeKey: svc.cfg.GetNodeKey(),
|
||||
ModelType: svc.modelType,
|
||||
Models: modelsData,
|
||||
@@ -285,7 +339,11 @@ func (svc *ModelService[T]) UpsertOne(query bson.M, model T) (id primitive.Objec
|
||||
if err != nil {
|
||||
return primitive.NilObjectID, err
|
||||
}
|
||||
res, err := client.GetGrpcClient().ModelBaseServiceClient.UpsertOne(ctx, &grpc.ModelServiceUpsertOneRequest{
|
||||
modelClient, err := client.GetGrpcClient().GetModelBaseServiceClient()
|
||||
if err != nil {
|
||||
return primitive.NilObjectID, fmt.Errorf("failed to get model base service client: %v", err)
|
||||
}
|
||||
res, err := modelClient.UpsertOne(ctx, &grpc.ModelServiceUpsertOneRequest{
|
||||
NodeKey: svc.cfg.GetNodeKey(),
|
||||
ModelType: svc.modelType,
|
||||
Query: queryData,
|
||||
@@ -305,7 +363,11 @@ func (svc *ModelService[T]) Count(query bson.M) (total int, err error) {
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
res, err := client.GetGrpcClient().ModelBaseServiceClient.Count(ctx, &grpc.ModelServiceCountRequest{
|
||||
modelClient, err := client.GetGrpcClient().GetModelBaseServiceClient()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to get model base service client: %v", err)
|
||||
}
|
||||
res, err := modelClient.Count(ctx, &grpc.ModelServiceCountRequest{
|
||||
NodeKey: svc.cfg.GetNodeKey(),
|
||||
ModelType: svc.modelType,
|
||||
Query: queryData,
|
||||
|
||||
@@ -4,13 +4,14 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/crawlab-team/crawlab/core/controllers"
|
||||
"github.com/gin-gonic/gin"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/crawlab-team/crawlab/core/controllers"
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/crawlab-team/crawlab/core/models/models"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
@@ -85,7 +86,11 @@ func (svc *WorkerService) register() {
|
||||
op := func() (err error) {
|
||||
ctx, cancel := client.GetGrpcClient().Context()
|
||||
defer cancel()
|
||||
_, err = client.GetGrpcClient().NodeClient.Register(ctx, &grpc.NodeServiceRegisterRequest{
|
||||
nodeClient, err := client.GetGrpcClient().GetNodeClient()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get node client: %v", err)
|
||||
}
|
||||
_, err = nodeClient.Register(ctx, &grpc.NodeServiceRegisterRequest{
|
||||
NodeKey: svc.cfgSvc.GetNodeKey(),
|
||||
NodeName: svc.cfgSvc.GetNodeName(),
|
||||
MaxRunners: int32(svc.cfgSvc.GetMaxRunners()),
|
||||
@@ -152,7 +157,12 @@ func (svc *WorkerService) subscribe() {
|
||||
// Use backoff for connection attempts
|
||||
operation := func() error {
|
||||
svc.Debugf("attempting to subscribe to master")
|
||||
stream, err := client.GetGrpcClient().NodeClient.Subscribe(context.Background(), &grpc.NodeServiceSubscribeRequest{
|
||||
nodeClient, err := client.GetGrpcClient().GetNodeClient()
|
||||
if err != nil {
|
||||
svc.Errorf("failed to get node client: %v", err)
|
||||
return err
|
||||
}
|
||||
stream, err := nodeClient.Subscribe(context.Background(), &grpc.NodeServiceSubscribeRequest{
|
||||
NodeKey: svc.cfgSvc.GetNodeKey(),
|
||||
})
|
||||
if err != nil {
|
||||
@@ -199,7 +209,12 @@ func (svc *WorkerService) subscribe() {
|
||||
func (svc *WorkerService) sendHeartbeat() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), svc.heartbeatInterval)
|
||||
defer cancel()
|
||||
_, err := client.GetGrpcClient().NodeClient.SendHeartbeat(ctx, &grpc.NodeServiceSendHeartbeatRequest{
|
||||
nodeClient, err := client.GetGrpcClient().GetNodeClient()
|
||||
if err != nil {
|
||||
svc.Errorf("failed to get node client: %v", err)
|
||||
return
|
||||
}
|
||||
_, err = nodeClient.SendHeartbeat(ctx, &grpc.NodeServiceSendHeartbeatRequest{
|
||||
NodeKey: svc.cfgSvc.GetNodeKey(),
|
||||
})
|
||||
if err != nil {
|
||||
|
||||
@@ -885,7 +885,12 @@ func (r *Runner) initConnection() (err error) {
|
||||
r.connMutex.Lock()
|
||||
defer r.connMutex.Unlock()
|
||||
|
||||
r.conn, err = client2.GetGrpcClient().TaskClient.Connect(context.Background())
|
||||
taskClient, err := client2.GetGrpcClient().GetTaskClient()
|
||||
if err != nil {
|
||||
r.Errorf("failed to get task client: %v", err)
|
||||
return err
|
||||
}
|
||||
r.conn, err = taskClient.Connect(context.Background())
|
||||
if err != nil {
|
||||
r.Errorf("error connecting to task service: %v", err)
|
||||
return err
|
||||
@@ -1004,7 +1009,12 @@ func (r *Runner) reconnectWithRetry() error {
|
||||
}
|
||||
|
||||
// Attempt reconnection
|
||||
conn, err := client2.GetGrpcClient().TaskClient.Connect(context.Background())
|
||||
taskClient, err := client2.GetGrpcClient().GetTaskClient()
|
||||
if err != nil {
|
||||
r.Warnf("reconnection attempt %d failed to get task client: %v", attempt+1, err)
|
||||
continue
|
||||
}
|
||||
conn, err := taskClient.Connect(context.Background())
|
||||
if err != nil {
|
||||
r.Warnf("reconnection attempt %d failed: %v", attempt+1, err)
|
||||
continue
|
||||
@@ -1162,7 +1172,12 @@ func (r *Runner) sendNotification() {
|
||||
NodeKey: r.svc.GetNodeConfigService().GetNodeKey(),
|
||||
TaskId: r.tid.Hex(),
|
||||
}
|
||||
_, err := client2.GetGrpcClient().TaskClient.SendNotification(context.Background(), req)
|
||||
taskClient, err := client2.GetGrpcClient().GetTaskClient()
|
||||
if err != nil {
|
||||
r.Errorf("failed to get task client: %v", err)
|
||||
return
|
||||
}
|
||||
_, err = taskClient.SendNotification(context.Background(), req)
|
||||
if err != nil {
|
||||
r.Errorf("error sending notification: %v", err)
|
||||
return
|
||||
|
||||
@@ -405,7 +405,11 @@ func (svc *Service) updateNodeStatus() (err error) {
|
||||
func (svc *Service) fetchTask() (tid primitive.ObjectID, err error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), svc.fetchTimeout)
|
||||
defer cancel()
|
||||
res, err := svc.c.TaskClient.FetchTask(ctx, &grpc.TaskServiceFetchTaskRequest{
|
||||
taskClient, err := svc.c.GetTaskClient()
|
||||
if err != nil {
|
||||
return primitive.NilObjectID, fmt.Errorf("failed to get task client: %v", err)
|
||||
}
|
||||
res, err := taskClient.FetchTask(ctx, &grpc.TaskServiceFetchTaskRequest{
|
||||
NodeKey: svc.cfgSvc.GetNodeKey(),
|
||||
})
|
||||
if err != nil {
|
||||
@@ -496,7 +500,11 @@ func (svc *Service) subscribeTask(taskId primitive.ObjectID) (stream grpc.TaskSe
|
||||
req := &grpc.TaskServiceSubscribeRequest{
|
||||
TaskId: taskId.Hex(),
|
||||
}
|
||||
stream, err = svc.c.TaskClient.Subscribe(ctx, req)
|
||||
taskClient, err := svc.c.GetTaskClient()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get task client: %v", err)
|
||||
}
|
||||
stream, err = taskClient.Subscribe(ctx, req)
|
||||
if err != nil {
|
||||
svc.Errorf("failed to subscribe task[%s]: %v", taskId.Hex(), err)
|
||||
return nil, err
|
||||
|
||||
Reference in New Issue
Block a user