package client import ( "context" "fmt" "sync" "time" "github.com/apex/log" "github.com/cenkalti/backoff/v4" "github.com/crawlab-team/crawlab/core/grpc/middlewares" "github.com/crawlab-team/crawlab/core/interfaces" "github.com/crawlab-team/crawlab/core/utils" grpc2 "github.com/crawlab-team/crawlab/grpc" "google.golang.org/grpc" "google.golang.org/grpc/connectivity" "google.golang.org/grpc/credentials/insecure" ) type GrpcClient struct { // dependencies nodeCfgSvc interfaces.NodeConfigService // settings address string timeout time.Duration // internals conn *grpc.ClientConn err error once sync.Once stopped bool stop chan struct{} // clients NodeClient grpc2.NodeServiceClient TaskClient grpc2.TaskServiceClient ModelBaseServiceClient grpc2.ModelBaseServiceClient DependencyClient grpc2.DependencyServiceClient MetricClient grpc2.MetricServiceClient // Add new fields for state management state connectivity.State stateMux sync.RWMutex reconnect chan struct{} } func (c *GrpcClient) Start() (err error) { c.once.Do(func() { // initialize reconnect channel c.reconnect = make(chan struct{}) // start state monitor go c.monitorState() // connect err = c.connect() if err != nil { return } // register rpc services c.register() }) return err } func (c *GrpcClient) Stop() (err error) { // set stopped flag c.stopped = true c.stop <- struct{}{} log.Infof("[GrpcClient] stopped") // skip if connection is nil if c.conn == nil { return nil } // close connection if err := c.conn.Close(); err != nil { log.Errorf("grpc client failed to close connection: %v", err) return err } log.Infof("grpc client disconnected from %s", c.address) return nil } func (c *GrpcClient) WaitForReady() { ticker := time.NewTicker(1 * time.Second) defer ticker.Stop() for { select { case <-ticker.C: if c.IsReady() { log.Debugf("grpc client ready") return } case <-c.stop: log.Errorf("grpc client stopped") } } } func (c *GrpcClient) register() { c.NodeClient = grpc2.NewNodeServiceClient(c.conn) c.ModelBaseServiceClient = grpc2.NewModelBaseServiceClient(c.conn) c.TaskClient = grpc2.NewTaskServiceClient(c.conn) c.DependencyClient = grpc2.NewDependencyServiceClient(c.conn) c.MetricClient = grpc2.NewMetricServiceClient(c.conn) } func (c *GrpcClient) Context() (ctx context.Context, cancel context.CancelFunc) { return context.WithTimeout(context.Background(), c.timeout) } func (c *GrpcClient) IsReady() (res bool) { state := c.conn.GetState() return c.conn != nil && state == connectivity.Ready } func (c *GrpcClient) IsClosed() (res bool) { if c.conn != nil { return c.conn.GetState() == connectivity.Shutdown } return false } func (c *GrpcClient) monitorState() { for { select { case <-c.stop: return default: if c.conn == nil { time.Sleep(time.Second) continue } previous := c.getState() current := c.conn.GetState() if previous != current { c.setState(current) log.Infof("[GrpcClient] state changed from %s to %s", previous, current) // Trigger reconnect if connection is lost or becomes idle from ready state if current == connectivity.TransientFailure || current == connectivity.Shutdown || (previous == connectivity.Ready && current == connectivity.Idle) { select { case c.reconnect <- struct{}{}: log.Infof("[GrpcClient] triggering reconnection due to state change to %s", current) default: } } } time.Sleep(time.Second) } } } func (c *GrpcClient) setState(state connectivity.State) { c.stateMux.Lock() defer c.stateMux.Unlock() c.state = state } func (c *GrpcClient) getState() connectivity.State { c.stateMux.RLock() defer c.stateMux.RUnlock() return c.state } func (c *GrpcClient) connect() (err error) { // Start reconnection loop go func() { for { select { case <-c.stop: return case <-c.reconnect: if !c.stopped { log.Infof("[GrpcClient] attempting to reconnect to %s", c.address) if err := c.doConnect(); err != nil { log.Errorf("[GrpcClient] reconnection failed: %v", err) } } } } }() return c.doConnect() } func (c *GrpcClient) doConnect() (err error) { op := func() error { // connection options opts := []grpc.DialOption{ grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithChainUnaryInterceptor(middlewares.GetGrpcClientAuthTokenUnaryChainInterceptor()), grpc.WithChainStreamInterceptor(middlewares.GetGrpcClientAuthTokenStreamChainInterceptor()), } // create new client connection c.conn, err = grpc.NewClient(c.address, opts...) if err != nil { log.Errorf("[GrpcClient] grpc client failed to connect to %s: %v", c.address, err) return err } // connect log.Infof("[GrpcClient] grpc client connecting to %s", c.address) c.conn.Connect() // wait for connection to be ready ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() ok := c.conn.WaitForStateChange(ctx, connectivity.Ready) if !ok { return fmt.Errorf("[GrpcClient] grpc client failed to connect to %s: timed out", c.address) } // success log.Infof("[GrpcClient] grpc client connected to %s", c.address) return nil } b := backoff.NewExponentialBackOff() b.InitialInterval = 5 * time.Second b.MaxElapsedTime = 10 * time.Minute n := func(err error, duration time.Duration) { log.Errorf("[GrpcClient] grpc client failed to connect to %s: %v, retrying in %s", c.address, err, duration) } return backoff.RetryNotify(op, b, n) } func newGrpcClient() (c *GrpcClient) { return &GrpcClient{ address: utils.GetGrpcAddress(), timeout: 10 * time.Second, stop: make(chan struct{}), state: connectivity.Idle, } } var _client *GrpcClient var _clientOnce sync.Once func GetGrpcClient() *GrpcClient { _clientOnce.Do(func() { _client = newGrpcClient() go func() { err := _client.Start() if err != nil { log.Fatalf("[GrpcClient] failed to start: %v", err) } }() }) return _client }