Files
crawlab/core/grpc/client/client.go
2024-06-14 16:37:48 +08:00

421 lines
9.6 KiB
Go

package client
import (
"context"
"encoding/json"
"github.com/apex/log"
"github.com/cenkalti/backoff/v4"
"github.com/crawlab-team/crawlab/core/constants"
"github.com/crawlab-team/crawlab/core/container"
"github.com/crawlab-team/crawlab/core/entity"
"github.com/crawlab-team/crawlab/core/errors"
"github.com/crawlab-team/crawlab/core/grpc/middlewares"
"github.com/crawlab-team/crawlab/core/interfaces"
"github.com/crawlab-team/crawlab/core/utils"
grpc2 "github.com/crawlab-team/crawlab/grpc"
"github.com/crawlab-team/crawlab/trace"
"github.com/spf13/viper"
"google.golang.org/grpc"
"google.golang.org/grpc/connectivity"
"io"
"time"
)
// Client is the gRPC client used by a node to talk to a remote gRPC server.
// It owns the underlying connection, the node subscription stream, and the
// generated service clients that Register builds on top of the connection.
type Client struct {
// dependencies
// nodeCfgSvc supplies the node key sent with each request and is handed to
// the auth-token interceptors installed when dialing.
nodeCfgSvc interfaces.NodeConfigService
// settings
// cfgPath is only stored and returned by Get/SetConfigPath in this file.
cfgPath string
// address is the remote gRPC server address dialed by _connect.
address interfaces.Address
// timeout bounds the dial in _connect and the contexts returned by Context.
timeout time.Duration
// subscribeType selects the subscription routine in subscribe; only the
// node type is handled there, anything else is rejected.
subscribeType string
// handleMessage controls whether Start launches handleStreamMessage.
handleMessage bool
// internals
conn *grpc.ClientConn
// stream is the subscription stream; handleStreamMessage sets it to nil
// after a receive error and resubscribes on the next iteration.
stream grpc2.NodeService_SubscribeClient
// msgCh receives every message read from stream by handleStreamMessage.
msgCh chan *grpc2.StreamMessage
// err holds the last receive error seen by handleStreamMessage; it is
// reset to nil after a message is successfully forwarded.
err error
// grpc clients
// The clients below are created by Register from conn.
ModelDelegateClient grpc2.ModelDelegateClient
ModelBaseServiceClient grpc2.ModelBaseServiceClient
NodeClient grpc2.NodeServiceClient
TaskClient grpc2.TaskServiceClient
MessageClient grpc2.MessageServiceClient
}
// Init is part of the client lifecycle. All setup happens in NewClient,
// so there is nothing left to initialize and it always succeeds.
func (c *Client) Init() (err error) {
	return
}
// Start brings the client up in order. It:
//  1. dials the server (with retries),
//  2. creates the service clients,
//  3. subscribes this node to the server stream, and
//  4. when message handling is enabled, starts the stream-reading goroutine.
//
// It stops at, and returns, the first failing step's error.
func (c *Client) Start() (err error) {
	steps := []func() error{
		c.connect,
		c.Register,
		c.subscribe,
	}
	for _, step := range steps {
		if err = step(); err != nil {
			return err
		}
	}

	if !c.handleMessage {
		return nil
	}
	go c.handleStreamMessage()
	return nil
}
// Stop unsubscribes this node from the remote server and closes the gRPC
// connection. It is a no-op when the client has no connection.
//
// The connection is closed even when unsubscribing fails. Returning early
// there would leave the ClientConn open with no other owner to close it.
// If closing fails, that error is returned. Otherwise the unsubscribe
// error, if any, is returned after cleanup has completed.
func (c *Client) Stop() (err error) {
	if c.conn == nil {
		return nil
	}

	// grpc server address
	address := c.address.String()

	// Unsubscribe while the connection is still open; a failure here must
	// not prevent the connection from being closed below.
	unsubscribeErr := c.unsubscribe()
	if unsubscribeErr != nil {
		log.Warnf("grpc client failed to unsubscribe from %s, closing connection anyway: %v", address, unsubscribeErr)
	} else {
		log.Infof("grpc client unsubscribed from %s", address)
	}

	if err = c.conn.Close(); err != nil {
		return err
	}
	log.Infof("grpc client disconnected from %s", address)

	return unsubscribeErr
}
// Register builds the generated gRPC service clients on top of c.conn
// (Start calls it right after connect) and logs them. It always returns nil.
func (c *Client) Register() (err error) {
	conn := c.conn

	c.ModelDelegateClient = grpc2.NewModelDelegateClient(conn)
	c.ModelBaseServiceClient = grpc2.NewModelBaseServiceClient(conn)
	c.NodeClient = grpc2.NewNodeServiceClient(conn)
	c.TaskClient = grpc2.NewTaskServiceClient(conn)
	c.MessageClient = grpc2.NewMessageServiceClient(conn)

	log.Infof("[GrpcClient] grpc client registered client services")
	log.Debugf("[GrpcClient] ModelDelegateClient: %v", c.ModelDelegateClient)
	log.Debugf("[GrpcClient] ModelBaseServiceClient: %v", c.ModelBaseServiceClient)
	log.Debugf("[GrpcClient] NodeClient: %v", c.NodeClient)
	log.Debugf("[GrpcClient] TaskClient: %v", c.TaskClient)
	log.Debugf("[GrpcClient] MessageClient: %v", c.MessageClient)

	return nil
}
// GetModelDelegateClient returns the ModelDelegate client built by Register
// (nil before Register has run).
func (c *Client) GetModelDelegateClient() (res grpc2.ModelDelegateClient) {
	res = c.ModelDelegateClient
	return res
}
// GetModelBaseServiceClient returns the ModelBaseService client built by
// Register (nil before Register has run).
func (c *Client) GetModelBaseServiceClient() (res grpc2.ModelBaseServiceClient) {
	res = c.ModelBaseServiceClient
	return res
}
// GetNodeClient returns the NodeService client built by Register
// (nil before Register has run).
func (c *Client) GetNodeClient() grpc2.NodeServiceClient {
	nodeClient := c.NodeClient
	return nodeClient
}
// GetTaskClient returns the TaskService client built by Register
// (nil before Register has run).
func (c *Client) GetTaskClient() grpc2.TaskServiceClient {
	taskClient := c.TaskClient
	return taskClient
}
// GetMessageClient returns the MessageService client built by Register
// (nil before Register has run).
func (c *Client) GetMessageClient() grpc2.MessageServiceClient {
	messageClient := c.MessageClient
	return messageClient
}
// SetAddress sets the remote server address used by the next dial.
func (c *Client) SetAddress(addr interfaces.Address) {
	c.address = addr
}
// SetTimeout sets the duration used for dialing and for contexts returned
// by Context.
func (c *Client) SetTimeout(d time.Duration) {
	c.timeout = d
}
// SetSubscribeType selects which subscription routine subscribe runs.
func (c *Client) SetSubscribeType(subscribeType string) {
	c.subscribeType = subscribeType
}
// SetHandleMessage controls whether Start launches the goroutine that
// forwards stream messages to the message channel.
func (c *Client) SetHandleMessage(enabled bool) {
	c.handleMessage = enabled
}
// Context returns a background-derived context that expires after the
// client's timeout. The caller must invoke the returned cancel function.
func (c *Client) Context() (ctx context.Context, cancel context.CancelFunc) {
	ctx, cancel = context.WithTimeout(context.Background(), c.timeout)
	return ctx, cancel
}
// NewRequest wraps d in a request stamped with this node's key. The payload
// is encoded by getRequestData ([]byte as-is, other values as JSON).
func (c *Client) NewRequest(d interface{}) (req *grpc2.Request) {
	req = &grpc2.Request{}
	req.NodeKey = c.nodeCfgSvc.GetNodeKey()
	req.Data = c.getRequestData(d)
	return req
}
// GetConfigPath returns the config path stored via SetConfigPath.
func (c *Client) GetConfigPath() (path string) {
	path = c.cfgPath
	return path
}
// SetConfigPath stores the config path returned by GetConfigPath.
func (c *Client) SetConfigPath(p string) {
	c.cfgPath = p
}
// NewModelBaseServiceRequest JSON-encodes params and wraps it, together with
// the model id, in a GrpcBaseServiceMessage request. A marshalling failure
// is returned (after passing through trace.TraceError) with a nil request.
func (c *Client) NewModelBaseServiceRequest(id interfaces.ModelId, params interfaces.GrpcBaseServiceParams) (req *grpc2.Request, err error) {
	var payload []byte
	if payload, err = json.Marshal(params); err != nil {
		return nil, trace.TraceError(err)
	}
	return c.NewRequest(&entity.GrpcBaseServiceMessage{ModelId: id, Data: payload}), nil
}
// GetMessageChannel returns the channel that handleStreamMessage writes
// received stream messages to.
func (c *Client) GetMessageChannel() (msgCh chan *grpc2.StreamMessage) {
	msgCh = c.msgCh
	return msgCh
}
// Restart starts the client again when its connection is unusable. That is
// the case when there is no connection yet, or when the existing one is
// shut down or in transient failure. Otherwise it does nothing.
//
// A connection in transient failure has not been closed. Start replaces
// c.conn, so the old connection is closed first to avoid leaking it. A nil
// connection is handled here directly because needRestart would otherwise
// call GetState on a nil *grpc.ClientConn.
func (c *Client) Restart() (err error) {
	if c.conn != nil {
		if !c.needRestart() {
			return nil
		}
		if !c.IsClosed() {
			// Best effort: the connection is being replaced either way, and
			// a failure to close it must not block reconnecting.
			_ = c.conn.Close()
		}
	}
	return c.Start()
}
// IsStarted reports whether a connection object exists. It stays true after
// Stop, because Stop closes the connection but does not clear c.conn.
func (c *Client) IsStarted() (res bool) {
	res = c.conn != nil
	return res
}
// IsClosed reports whether the connection exists and is in the Shutdown
// state. A client without a connection is not considered closed.
func (c *Client) IsClosed() (res bool) {
	res = c.conn != nil && c.conn.GetState() == connectivity.Shutdown
	return res
}
// Err returns the last stream receive error recorded by handleStreamMessage,
// or nil once a later message has been forwarded successfully.
func (c *Client) Err() (err error) {
	err = c.err
	return err
}
// GetStream returns the current subscription stream (nil when not
// subscribed or after a receive error has dropped it).
func (c *Client) GetStream() (stream grpc2.NodeService_SubscribeClient) {
	stream = c.stream
	return stream
}
// connect dials the server via _connect, retrying with exponential backoff
// and reporting each failed attempt through the shared backoff notifier.
func (c *Client) connect() (err error) {
	policy := backoff.NewExponentialBackOff()
	notify := utils.BackoffErrorNotify("grpc client connect")
	return backoff.RetryNotify(c._connect, policy, notify)
}
// _connect performs a single blocking dial to c.address, bounded by
// c.timeout. It installs the auth-token unary and stream interceptors.
// On failure the dial error is traced and the generic
// ErrorGrpcClientFailedToStart is returned, so connect can retry.
func (c *Client) _connect() (err error) {
	address := c.address.String()

	ctx, cancel := context.WithTimeout(context.Background(), c.timeout)
	defer cancel()

	// TODO: configure dial options
	opts := []grpc.DialOption{
		grpc.WithInsecure(),
		grpc.WithBlock(),
		grpc.WithChainUnaryInterceptor(middlewares.GetAuthTokenUnaryChainInterceptor(c.nodeCfgSvc)),
		grpc.WithChainStreamInterceptor(middlewares.GetAuthTokenStreamChainInterceptor(c.nodeCfgSvc)),
	}

	conn, dialErr := grpc.DialContext(ctx, address, opts...)
	c.conn = conn
	if dialErr != nil {
		_ = trace.TraceError(dialErr)
		return errors.ErrorGrpcClientFailedToStart
	}

	log.Infof("[GrpcClient] grpc client connected to %s", address)
	return nil
}
// subscribe opens the subscription stream for the configured subscribe type,
// retrying with exponential backoff. Only the node type is supported. Any
// other type fails immediately with ErrorGrpcInvalidType.
func (c *Client) subscribe() (err error) {
	if c.subscribeType != constants.GrpcSubscribeTypeNode {
		return errors.ErrorGrpcInvalidType
	}
	return backoff.RetryNotify(
		c._subscribeNode,
		backoff.NewExponentialBackOff(),
		utils.BackoffErrorNotify("grpc client subscribe"),
	)
}
// _subscribeNode makes one attempt to open the node subscription stream,
// identifying this node as a non-master. The result is stored in c.stream
// whether or not the call succeeds. A background context is used because
// the stream must outlive any single timeout.
func (c *Client) _subscribeNode() (err error) {
	info := &entity.NodeInfo{
		Key:      c.nodeCfgSvc.GetNodeKey(),
		IsMaster: false,
	}
	req := c.NewRequest(info)

	nodeClient := c.GetNodeClient()
	if c.stream, err = nodeClient.Subscribe(context.Background(), req); err != nil {
		return trace.TraceError(err)
	}

	log.Infof("[GrpcClient] grpc client subscribed to remote server")
	return nil
}
// unsubscribe sends an Unsubscribe request for this (non-master) node.
//
// The RPC is bounded by the client's timeout (via Context, consistent with
// the dial in _connect). Stop calls it on the shutdown path, which should
// not wait without limit on a slow or unreachable server.
func (c *Client) unsubscribe() (err error) {
	req := c.NewRequest(&entity.NodeInfo{
		Key:      c.nodeCfgSvc.GetNodeKey(),
		IsMaster: false,
	})

	ctx, cancel := c.Context()
	defer cancel()

	if _, err = c.GetNodeClient().Unsubscribe(ctx, req); err != nil {
		return trace.TraceError(err)
	}
	return nil
}
// handleStreamMessage reads messages from the subscription stream and
// forwards each one to c.msgCh. Start runs it in its own goroutine.
//
//   - When c.stream is nil, it resubscribes. subscribe already retries with
//     exponential backoff, so it is called directly instead of being wrapped
//     in a second backoff loop. If resubscribing fails, the loop exits.
//   - io.EOF from Recv ends the loop.
//   - Any other Recv error is stored in c.err. If the connection is shut
//     down, the loop exits. Otherwise the stream is dropped and, after a
//     one-second pause, the next iteration resubscribes.
//   - After a message is sent on c.msgCh, c.err is cleared. NewClient creates
//     msgCh unbuffered, so that send waits for a reader.
func (c *Client) handleStreamMessage() {
	log.Infof("[GrpcClient] start handling stream message...")
	for {
		// resubscribe if the stream was dropped after an error
		if c.stream == nil {
			if err := c.subscribe(); err != nil {
				log.Errorf("[GrpcClient] failed to resubscribe, stop handling stream message: %v", err)
				return
			}
		}

		msg, err := c.stream.Recv()
		if err != nil {
			c.err = err

			// the server ended the stream
			if err == io.EOF {
				log.Infof("[GrpcClient] received EOF signal, disconnecting")
				return
			}

			// connection closed (e.g. by Stop)
			if c.IsClosed() {
				return
			}

			trace.PrintError(err)
			c.stream = nil
			time.Sleep(1 * time.Second)
			continue
		}

		// only log real messages, not the nil msg returned alongside an error
		log.Debugf("[GrpcClient] received message: %v", msg)

		c.msgCh <- msg
		c.err = nil
	}
}
// needRestart reports whether the connection must be re-established.
// It returns true when there is no connection at all, or when the connection
// is shut down or in transient failure. It returns false for the idle,
// connecting and ready states.
//
// The nil check prevents calling GetState on a nil *grpc.ClientConn, which
// would panic when Restart is used before the client ever connected.
func (c *Client) needRestart() bool {
	if c.conn == nil {
		return true
	}
	switch c.conn.GetState() {
	case connectivity.Shutdown, connectivity.TransientFailure:
		return true
	default:
		// Idle, Connecting and Ready can recover or are already usable.
		return false
	}
}
// getRequestData encodes d as a request payload:
//   - a []byte value is used as-is,
//   - nil yields a nil payload,
//   - any other value is JSON-encoded.
//
// It panics if JSON encoding fails. Its caller, NewRequest, has no error
// return through which the failure could be reported.
func (c *Client) getRequestData(d interface{}) (data []byte) {
	if raw, ok := d.([]byte); ok {
		return raw
	}
	if d == nil {
		return data
	}
	encoded, err := json.Marshal(d)
	if err != nil {
		panic(err)
	}
	return encoded
}
// NewClient builds a Client with default settings:
//   - the default remote host/port (overridden by the "grpc.address" config
//     value when set),
//   - a 10s timeout,
//   - an unbuffered message channel,
//   - node subscription, with message handling enabled.
//
// It injects the NodeConfigService from the container and runs Init.
func NewClient() (res interfaces.GrpcClient, err error) {
	client := &Client{
		address: entity.NewAddress(&entity.AddressOptions{
			Host: constants.DefaultGrpcClientRemoteHost,
			Port: constants.DefaultGrpcClientRemotePort,
		}),
		timeout:       10 * time.Second,
		msgCh:         make(chan *grpc2.StreamMessage),
		subscribeType: constants.GrpcSubscribeTypeNode,
		handleMessage: true,
	}

	// optional address override from configuration
	if configured := viper.GetString("grpc.address"); configured != "" {
		if client.address, err = entity.NewAddressFromString(configured); err != nil {
			return nil, trace.TraceError(err)
		}
	}

	// dependency injection
	inject := func(nodeCfgSvc interfaces.NodeConfigService) {
		client.nodeCfgSvc = nodeCfgSvc
	}
	if err = container.GetContainer().Invoke(inject); err != nil {
		return nil, err
	}

	if err = client.Init(); err != nil {
		return nil, err
	}
	return client, nil
}
// _client caches the GrpcClient resolved by GetClient.
var _client interfaces.GrpcClient

// GetClient returns the cached GrpcClient. On the first call, or after a
// failed attempt, it resolves the client from the dependency container.
// No locking is used around the cache.
func GetClient() (c interfaces.GrpcClient, err error) {
	if _client == nil {
		if _client, err = createClient(); err != nil {
			return nil, err
		}
	}
	return _client, nil
}
// createClient resolves a GrpcClient from the dependency container. An
// Invoke failure is passed through trace.TraceError and returned with a
// nil client.
func createClient() (client2 interfaces.GrpcClient, err error) {
	resolve := func(client interfaces.GrpcClient) {
		client2 = client
	}
	if err = container.GetContainer().Invoke(resolve); err != nil {
		return nil, trace.TraceError(err)
	}
	return client2, nil
}