mirror of
https://github.com/crawlab-team/crawlab.git
synced 2026-01-29 18:00:51 +01:00
refactor: updated grpc services
This commit is contained in:
@@ -104,9 +104,9 @@ func NewServerV2() (app NodeApp) {
|
||||
// node service
|
||||
var err error
|
||||
if utils.IsMaster() {
|
||||
svr.nodeSvc, err = service.GetMasterServiceV2()
|
||||
svr.nodeSvc, err = service.GetMasterService()
|
||||
} else {
|
||||
svr.nodeSvc, err = service.GetWorkerServiceV2()
|
||||
svr.nodeSvc, err = service.GetWorkerService()
|
||||
}
|
||||
if err != nil {
|
||||
panic(err)
|
||||
|
||||
@@ -17,29 +17,29 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
type DependenciesServerV2 struct {
|
||||
grpc.UnimplementedDependenciesServiceV2Server
|
||||
type DependencyServiceServer struct {
|
||||
grpc.UnimplementedDependencyServiceV2Server
|
||||
mu *sync.Mutex
|
||||
streams map[string]*grpc.DependenciesServiceV2_ConnectServer
|
||||
streams map[string]*grpc.DependencyServiceV2_ConnectServer
|
||||
}
|
||||
|
||||
func (svr DependenciesServerV2) Connect(req *grpc.DependenciesServiceV2ConnectRequest, stream grpc.DependenciesServiceV2_ConnectServer) (err error) {
|
||||
func (svr DependencyServiceServer) Connect(req *grpc.DependencyServiceV2ConnectRequest, stream grpc.DependencyServiceV2_ConnectServer) (err error) {
|
||||
svr.mu.Lock()
|
||||
svr.streams[req.NodeKey] = &stream
|
||||
svr.mu.Unlock()
|
||||
log.Info("[DependenciesServerV2] connected: " + req.NodeKey)
|
||||
log.Info("[DependencyServiceServer] connected: " + req.NodeKey)
|
||||
|
||||
// Keep this scope alive because once this scope exits - the stream is closed
|
||||
for {
|
||||
select {
|
||||
case <-stream.Context().Done():
|
||||
log.Info("[DependenciesServerV2] disconnected: " + req.NodeKey)
|
||||
log.Info("[DependencyServiceServer] disconnected: " + req.NodeKey)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (svr DependenciesServerV2) Sync(ctx context.Context, request *grpc.DependenciesServiceV2SyncRequest) (response *grpc.Response, err error) {
|
||||
func (svr DependencyServiceServer) Sync(ctx context.Context, request *grpc.DependencyServiceV2SyncRequest) (response *grpc.Response, err error) {
|
||||
n, err := service.NewModelServiceV2[models2.NodeV2]().GetOne(bson.M{"key": request.NodeKey}, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -51,7 +51,7 @@ func (svr DependenciesServerV2) Sync(ctx context.Context, request *grpc.Dependen
|
||||
}, nil)
|
||||
if err != nil {
|
||||
if !errors.Is(err, mongo.ErrNoDocuments) {
|
||||
log.Errorf("[DependenciesServiceV2] get dependencies from db error: %v", err)
|
||||
log.Errorf("[DependencyServiceV2] get dependencies from db error: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
@@ -94,7 +94,7 @@ func (svr DependenciesServerV2) Sync(ctx context.Context, request *grpc.Dependen
|
||||
"_id": bson.M{"$in": depIdsToDelete},
|
||||
})
|
||||
if err != nil {
|
||||
log.Errorf("[DependenciesServerV2] delete dependencies in db error: %v", err)
|
||||
log.Errorf("[DependencyServiceServer] delete dependencies in db error: %v", err)
|
||||
trace.PrintError(err)
|
||||
return err
|
||||
}
|
||||
@@ -103,7 +103,7 @@ func (svr DependenciesServerV2) Sync(ctx context.Context, request *grpc.Dependen
|
||||
if len(depsToInsert) > 0 {
|
||||
_, err = service.NewModelServiceV2[models2.DependencyV2]().InsertMany(depsToInsert)
|
||||
if err != nil {
|
||||
log.Errorf("[DependenciesServerV2] insert dependencies in db error: %v", err)
|
||||
log.Errorf("[DependencyServiceServer] insert dependencies in db error: %v", err)
|
||||
trace.PrintError(err)
|
||||
return err
|
||||
}
|
||||
@@ -115,7 +115,7 @@ func (svr DependenciesServerV2) Sync(ctx context.Context, request *grpc.Dependen
|
||||
return nil, err
|
||||
}
|
||||
|
||||
func (svr DependenciesServerV2) UpdateTaskLog(stream grpc.DependenciesServiceV2_UpdateTaskLogServer) (err error) {
|
||||
func (svr DependencyServiceServer) UpdateTaskLog(stream grpc.DependencyServiceV2_UpdateTaskLogServer) (err error) {
|
||||
var t *models2.DependencyTaskV2
|
||||
for {
|
||||
req, err := stream.Recv()
|
||||
@@ -152,7 +152,7 @@ func (svr DependenciesServerV2) UpdateTaskLog(stream grpc.DependenciesServiceV2_
|
||||
}
|
||||
}
|
||||
|
||||
func (svr DependenciesServerV2) GetStream(key string) (stream *grpc.DependenciesServiceV2_ConnectServer, err error) {
|
||||
func (svr DependencyServiceServer) GetStream(key string) (stream *grpc.DependencyServiceV2_ConnectServer, err error) {
|
||||
svr.mu.Lock()
|
||||
defer svr.mu.Unlock()
|
||||
stream, ok := svr.streams[key]
|
||||
@@ -162,19 +162,19 @@ func (svr DependenciesServerV2) GetStream(key string) (stream *grpc.Dependencies
|
||||
return stream, nil
|
||||
}
|
||||
|
||||
func NewDependenciesServerV2() *DependenciesServerV2 {
|
||||
return &DependenciesServerV2{
|
||||
func NewDependencyServerV2() *DependencyServiceServer {
|
||||
return &DependencyServiceServer{
|
||||
mu: new(sync.Mutex),
|
||||
streams: make(map[string]*grpc.DependenciesServiceV2_ConnectServer),
|
||||
streams: make(map[string]*grpc.DependencyServiceV2_ConnectServer),
|
||||
}
|
||||
}
|
||||
|
||||
var depSvc *DependenciesServerV2
|
||||
var depSvc *DependencyServiceServer
|
||||
|
||||
func GetDependenciesServerV2() *DependenciesServerV2 {
|
||||
func GetDependencyServerV2() *DependencyServiceServer {
|
||||
if depSvc != nil {
|
||||
return depSvc
|
||||
}
|
||||
depSvc = NewDependenciesServerV2()
|
||||
depSvc = NewDependencyServerV2()
|
||||
return depSvc
|
||||
}
|
||||
@@ -11,15 +11,15 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
type MetricsServerV2 struct {
|
||||
grpc.UnimplementedMetricsServiceV2Server
|
||||
type MetricServiceServer struct {
|
||||
grpc.UnimplementedMetricServiceV2Server
|
||||
}
|
||||
|
||||
func (svr MetricsServerV2) Send(_ context.Context, req *grpc.MetricsServiceV2SendRequest) (res *grpc.Response, err error) {
|
||||
log.Info("[MetricsServerV2] received metric from node: " + req.NodeKey)
|
||||
func (svr MetricServiceServer) Send(_ context.Context, req *grpc.MetricServiceV2SendRequest) (res *grpc.Response, err error) {
|
||||
log.Info("[MetricServiceServer] received metric from node: " + req.NodeKey)
|
||||
n, err := service.NewModelServiceV2[models2.NodeV2]().GetOne(bson.M{"key": req.NodeKey}, nil)
|
||||
if err != nil {
|
||||
log.Errorf("[MetricsServerV2] error getting node: %v", err)
|
||||
log.Errorf("[MetricServiceServer] error getting node: %v", err)
|
||||
return HandleError(err)
|
||||
}
|
||||
metric := models2.MetricV2{
|
||||
@@ -42,20 +42,20 @@ func (svr MetricsServerV2) Send(_ context.Context, req *grpc.MetricsServiceV2Sen
|
||||
metric.CreatedAt = time.Unix(req.Timestamp, 0)
|
||||
_, err = service.NewModelServiceV2[models2.MetricV2]().InsertOne(metric)
|
||||
if err != nil {
|
||||
log.Errorf("[MetricsServerV2] error inserting metric: %v", err)
|
||||
log.Errorf("[MetricServiceServer] error inserting metric: %v", err)
|
||||
return HandleError(err)
|
||||
}
|
||||
return HandleSuccess()
|
||||
}
|
||||
|
||||
func newMetricsServerV2() *MetricsServerV2 {
|
||||
return &MetricsServerV2{}
|
||||
func newMetricsServerV2() *MetricServiceServer {
|
||||
return &MetricServiceServer{}
|
||||
}
|
||||
|
||||
var metricsServerV2 *MetricsServerV2
|
||||
var metricsServerV2 *MetricServiceServer
|
||||
var metricsServerV2Once = &sync.Once{}
|
||||
|
||||
func GetMetricsServerV2() *MetricsServerV2 {
|
||||
func GetMetricsServerV2() *MetricServiceServer {
|
||||
if metricsServerV2 != nil {
|
||||
return metricsServerV2
|
||||
}
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"github.com/apex/log"
|
||||
"github.com/crawlab-team/crawlab/core/constants"
|
||||
"github.com/crawlab-team/crawlab/core/entity"
|
||||
"github.com/crawlab-team/crawlab/core/errors"
|
||||
"github.com/crawlab-team/crawlab/core/interfaces"
|
||||
"github.com/crawlab-team/crawlab/core/models/models/v2"
|
||||
@@ -17,36 +16,34 @@ import (
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"io"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type NodeServerV2 struct {
|
||||
var nodeServiceMutex = sync.Mutex{}
|
||||
|
||||
type NodeServiceServer struct {
|
||||
grpc.UnimplementedNodeServiceServer
|
||||
|
||||
// dependencies
|
||||
cfgSvc interfaces.NodeConfigService
|
||||
|
||||
// internals
|
||||
server *GrpcServerV2
|
||||
server *GrpcServer
|
||||
subs map[primitive.ObjectID]grpc.NodeService_SubscribeServer
|
||||
}
|
||||
|
||||
// Register from handler/worker to master
|
||||
func (svr NodeServerV2) Register(_ context.Context, req *grpc.NodeServiceRegisterRequest) (res *grpc.Response, err error) {
|
||||
// unmarshall data
|
||||
if req.IsMaster {
|
||||
// error: cannot register master node
|
||||
return HandleError(errors.ErrorGrpcNotAllowed)
|
||||
}
|
||||
|
||||
func (svr NodeServiceServer) Register(_ context.Context, req *grpc.NodeServiceRegisterRequest) (res *grpc.Response, err error) {
|
||||
// node key
|
||||
if req.Key == "" {
|
||||
if req.NodeKey == "" {
|
||||
return HandleError(errors.ErrorModelMissingRequiredData)
|
||||
}
|
||||
|
||||
// find in db
|
||||
var node *models.NodeV2
|
||||
node, err = service.NewModelServiceV2[models.NodeV2]().GetOne(bson.M{"key": req.Key}, nil)
|
||||
node, err = service.NewModelServiceV2[models.NodeV2]().GetOne(bson.M{"key": req.NodeKey}, nil)
|
||||
if err == nil {
|
||||
// register existing
|
||||
node.Status = constants.NodeStatusRegistered
|
||||
@@ -56,12 +53,12 @@ func (svr NodeServerV2) Register(_ context.Context, req *grpc.NodeServiceRegiste
|
||||
if err != nil {
|
||||
return HandleError(err)
|
||||
}
|
||||
log.Infof("[NodeServerV2] updated worker[%s] in db. id: %s", req.Key, node.Id.Hex())
|
||||
log.Infof("[NodeServiceServer] updated worker[%s] in db. id: %s", req.NodeKey, node.Id.Hex())
|
||||
} else if errors2.Is(err, mongo.ErrNoDocuments) {
|
||||
// register new
|
||||
node = &models.NodeV2{
|
||||
Key: req.Key,
|
||||
Name: req.Name,
|
||||
Key: req.NodeKey,
|
||||
Name: req.NodeName,
|
||||
Status: constants.NodeStatusRegistered,
|
||||
Active: true,
|
||||
ActiveAt: time.Now(),
|
||||
@@ -74,21 +71,21 @@ func (svr NodeServerV2) Register(_ context.Context, req *grpc.NodeServiceRegiste
|
||||
if err != nil {
|
||||
return HandleError(err)
|
||||
}
|
||||
log.Infof("[NodeServerV2] added worker[%s] in db. id: %s", req.Key, node.Id.Hex())
|
||||
log.Infof("[NodeServiceServer] added worker[%s] in db. id: %s", req.NodeKey, node.Id.Hex())
|
||||
} else {
|
||||
// error
|
||||
return HandleError(err)
|
||||
}
|
||||
|
||||
log.Infof("[NodeServerV2] master registered worker[%s]", req.Key)
|
||||
log.Infof("[NodeServiceServer] master registered worker[%s]", req.NodeKey)
|
||||
|
||||
return HandleSuccessWithData(node)
|
||||
}
|
||||
|
||||
// SendHeartbeat from worker to master
|
||||
func (svr NodeServerV2) SendHeartbeat(_ context.Context, req *grpc.NodeServiceSendHeartbeatRequest) (res *grpc.Response, err error) {
|
||||
func (svr NodeServiceServer) SendHeartbeat(_ context.Context, req *grpc.NodeServiceSendHeartbeatRequest) (res *grpc.Response, err error) {
|
||||
// find in db
|
||||
node, err := service.NewModelServiceV2[models.NodeV2]().GetOne(bson.M{"key": req.Key}, nil)
|
||||
node, err := service.NewModelServiceV2[models.NodeV2]().GetOne(bson.M{"key": req.NodeKey}, nil)
|
||||
if err != nil {
|
||||
if errors2.Is(err, mongo.ErrNoDocuments) {
|
||||
return HandleError(errors.ErrorNodeNotExists)
|
||||
@@ -122,64 +119,66 @@ func (svr NodeServerV2) SendHeartbeat(_ context.Context, req *grpc.NodeServiceSe
|
||||
return HandleSuccessWithData(node)
|
||||
}
|
||||
|
||||
func (svr NodeServerV2) Subscribe(request *grpc.Request, stream grpc.NodeService_SubscribeServer) (err error) {
|
||||
log.Infof("[NodeServerV2] master received subscribe request from node[%s]", request.NodeKey)
|
||||
func (svr NodeServiceServer) Subscribe(request *grpc.NodeServiceSubscribeRequest, stream grpc.NodeService_SubscribeServer) (err error) {
|
||||
log.Infof("[NodeServiceServer] master received subscribe request from node[%s]", request.NodeKey)
|
||||
|
||||
// finished channel
|
||||
finished := make(chan bool)
|
||||
|
||||
// set subscribe
|
||||
svr.server.SetSubscribe("node:"+request.NodeKey, &entity.GrpcSubscribe{
|
||||
Stream: stream,
|
||||
Finished: finished,
|
||||
})
|
||||
ctx := stream.Context()
|
||||
|
||||
log.Infof("[NodeServerV2] master subscribed node[%s]", request.NodeKey)
|
||||
|
||||
// Keep this scope alive because once this scope exits - the stream is closed
|
||||
for {
|
||||
select {
|
||||
case <-finished:
|
||||
log.Infof("[NodeServerV2] closing stream for node[%s]", request.NodeKey)
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
log.Infof("[NodeServerV2] node[%s] has disconnected", request.NodeKey)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (svr NodeServerV2) Unsubscribe(_ context.Context, req *grpc.Request) (res *grpc.Response, err error) {
|
||||
sub, err := svr.server.GetSubscribe("node:" + req.NodeKey)
|
||||
// find in db
|
||||
node, err := service.NewModelServiceV2[models.NodeV2]().GetOne(bson.M{"key": request.NodeKey}, nil)
|
||||
if err != nil {
|
||||
return nil, errors.ErrorGrpcSubscribeNotExists
|
||||
log.Errorf("[NodeServiceServer] error getting node: %v", err)
|
||||
return err
|
||||
}
|
||||
select {
|
||||
case sub.GetFinished() <- true:
|
||||
log.Infof("unsubscribed node[%s]", req.NodeKey)
|
||||
default:
|
||||
// Default case is to avoid blocking in case client has already unsubscribed
|
||||
}
|
||||
svr.server.DeleteSubscribe(req.NodeKey)
|
||||
return &grpc.Response{
|
||||
Code: grpc.ResponseCode_OK,
|
||||
Message: "unsubscribed successfully",
|
||||
}, nil
|
||||
|
||||
// subscribe
|
||||
nodeServiceMutex.Lock()
|
||||
svr.subs[node.Id] = stream
|
||||
nodeServiceMutex.Unlock()
|
||||
|
||||
// TODO: send notification
|
||||
|
||||
// create a new goroutine to receive messages from the stream to listen for EOF (end of stream)
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-stream.Context().Done():
|
||||
nodeServiceMutex.Lock()
|
||||
delete(svr.subs, node.Id)
|
||||
nodeServiceMutex.Unlock()
|
||||
return
|
||||
default:
|
||||
err := stream.RecvMsg(nil)
|
||||
if err == io.EOF {
|
||||
nodeServiceMutex.Lock()
|
||||
delete(svr.subs, node.Id)
|
||||
nodeServiceMutex.Unlock()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
var nodeSvrV2 *NodeServerV2
|
||||
func (svr NodeServiceServer) GetSubscribeStream(nodeId primitive.ObjectID) (stream grpc.NodeService_SubscribeServer, ok bool) {
|
||||
nodeServiceMutex.Lock()
|
||||
defer nodeServiceMutex.Unlock()
|
||||
stream, ok = svr.subs[nodeId]
|
||||
return stream, ok
|
||||
}
|
||||
|
||||
var nodeSvrV2 *NodeServiceServer
|
||||
var nodeSvrV2Once = new(sync.Once)
|
||||
|
||||
func NewNodeServerV2() (res *NodeServerV2, err error) {
|
||||
func NewNodeServiceServer() (res *NodeServiceServer, err error) {
|
||||
if nodeSvrV2 != nil {
|
||||
return nodeSvrV2, nil
|
||||
}
|
||||
nodeSvrV2Once.Do(func() {
|
||||
nodeSvrV2 = &NodeServerV2{}
|
||||
nodeSvrV2 = &NodeServiceServer{}
|
||||
nodeSvrV2.cfgSvc = nodeconfig.GetNodeConfigService()
|
||||
if err != nil {
|
||||
log.Errorf("[NodeServerV2] error: %s", err.Error())
|
||||
log.Errorf("[NodeServiceServer] error: %s", err.Error())
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
@@ -1,7 +1,6 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/apex/log"
|
||||
"github.com/crawlab-team/crawlab/core/constants"
|
||||
@@ -17,7 +16,6 @@ import (
|
||||
grpc_recovery "github.com/grpc-ecosystem/go-grpc-middleware/recovery"
|
||||
errors2 "github.com/pkg/errors"
|
||||
"github.com/spf13/viper"
|
||||
"go/types"
|
||||
"google.golang.org/grpc"
|
||||
"net"
|
||||
"sync"
|
||||
@@ -28,7 +26,7 @@ var (
|
||||
mutexSubsV2 = &sync.Mutex{}
|
||||
)
|
||||
|
||||
type GrpcServerV2 struct {
|
||||
type GrpcServer struct {
|
||||
// settings
|
||||
cfgPath string
|
||||
address interfaces.Address
|
||||
@@ -42,22 +40,22 @@ type GrpcServerV2 struct {
|
||||
nodeCfgSvc interfaces.NodeConfigService
|
||||
|
||||
// servers
|
||||
NodeSvr *NodeServerV2
|
||||
TaskSvr *TaskServerV2
|
||||
NodeSvr *NodeServiceServer
|
||||
TaskSvr *TaskServiceServer
|
||||
ModelBaseServiceSvr *ModelBaseServiceServerV2
|
||||
DependenciesSvr *DependenciesServerV2
|
||||
MetricsSvr *MetricsServerV2
|
||||
DependencySvr *DependencyServiceServer
|
||||
MetricSvr *MetricServiceServer
|
||||
}
|
||||
|
||||
func (svr *GrpcServerV2) GetConfigPath() (path string) {
|
||||
func (svr *GrpcServer) GetConfigPath() (path string) {
|
||||
return svr.cfgPath
|
||||
}
|
||||
|
||||
func (svr *GrpcServerV2) SetConfigPath(path string) {
|
||||
func (svr *GrpcServer) SetConfigPath(path string) {
|
||||
svr.cfgPath = path
|
||||
}
|
||||
|
||||
func (svr *GrpcServerV2) Init() (err error) {
|
||||
func (svr *GrpcServer) Init() (err error) {
|
||||
// register
|
||||
if err := svr.Register(); err != nil {
|
||||
return err
|
||||
@@ -66,7 +64,7 @@ func (svr *GrpcServerV2) Init() (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (svr *GrpcServerV2) Start() (err error) {
|
||||
func (svr *GrpcServer) Start() (err error) {
|
||||
// grpc server binding address
|
||||
address := svr.address.String()
|
||||
|
||||
@@ -92,7 +90,7 @@ func (svr *GrpcServerV2) Start() (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (svr *GrpcServerV2) Stop() (err error) {
|
||||
func (svr *GrpcServer) Stop() (err error) {
|
||||
// skip if listener is nil
|
||||
if svr.l == nil {
|
||||
return nil
|
||||
@@ -115,27 +113,27 @@ func (svr *GrpcServerV2) Stop() (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (svr *GrpcServerV2) Register() (err error) {
|
||||
func (svr *GrpcServer) Register() (err error) {
|
||||
grpc2.RegisterNodeServiceServer(svr.svr, *svr.NodeSvr)
|
||||
grpc2.RegisterModelBaseServiceV2Server(svr.svr, *svr.ModelBaseServiceSvr)
|
||||
grpc2.RegisterTaskServiceServer(svr.svr, *svr.TaskSvr)
|
||||
grpc2.RegisterDependenciesServiceV2Server(svr.svr, *svr.DependenciesSvr)
|
||||
grpc2.RegisterMetricsServiceV2Server(svr.svr, *svr.MetricsSvr)
|
||||
grpc2.RegisterDependencyServiceV2Server(svr.svr, *svr.DependencySvr)
|
||||
grpc2.RegisterMetricServiceV2Server(svr.svr, *svr.MetricSvr)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (svr *GrpcServerV2) recoveryHandlerFunc(p interface{}) (err error) {
|
||||
func (svr *GrpcServer) recoveryHandlerFunc(p interface{}) (err error) {
|
||||
err = errors.NewError(errors.ErrorPrefixGrpc, fmt.Sprintf("%v", p))
|
||||
trace.PrintError(err)
|
||||
return err
|
||||
}
|
||||
|
||||
func (svr *GrpcServerV2) SetAddress(address interfaces.Address) {
|
||||
func (svr *GrpcServer) SetAddress(address interfaces.Address) {
|
||||
|
||||
}
|
||||
|
||||
func (svr *GrpcServerV2) GetSubscribe(key string) (sub interfaces.GrpcSubscribe, err error) {
|
||||
func (svr *GrpcServer) GetSubscribe(key string) (sub interfaces.GrpcSubscribe, err error) {
|
||||
mutexSubsV2.Lock()
|
||||
defer mutexSubsV2.Unlock()
|
||||
sub, ok := subsV2[key]
|
||||
@@ -145,55 +143,25 @@ func (svr *GrpcServerV2) GetSubscribe(key string) (sub interfaces.GrpcSubscribe,
|
||||
return sub, nil
|
||||
}
|
||||
|
||||
func (svr *GrpcServerV2) SetSubscribe(key string, sub interfaces.GrpcSubscribe) {
|
||||
func (svr *GrpcServer) SetSubscribe(key string, sub interfaces.GrpcSubscribe) {
|
||||
mutexSubsV2.Lock()
|
||||
defer mutexSubsV2.Unlock()
|
||||
subsV2[key] = sub
|
||||
}
|
||||
|
||||
func (svr *GrpcServerV2) DeleteSubscribe(key string) {
|
||||
func (svr *GrpcServer) DeleteSubscribe(key string) {
|
||||
mutexSubsV2.Lock()
|
||||
defer mutexSubsV2.Unlock()
|
||||
delete(subsV2, key)
|
||||
}
|
||||
|
||||
func (svr *GrpcServerV2) SendStreamMessage(key string, code grpc2.StreamMessageCode) (err error) {
|
||||
return svr.SendStreamMessageWithData(key, code, nil)
|
||||
}
|
||||
|
||||
func (svr *GrpcServerV2) SendStreamMessageWithData(key string, code grpc2.StreamMessageCode, d interface{}) (err error) {
|
||||
var data []byte
|
||||
switch d.(type) {
|
||||
case types.Nil:
|
||||
// do nothing
|
||||
case []byte:
|
||||
data = d.([]byte)
|
||||
default:
|
||||
var err error
|
||||
data, err = json.Marshal(d)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
sub, err := svr.GetSubscribe(key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
msg := &grpc2.StreamMessage{
|
||||
Code: code,
|
||||
Key: svr.nodeCfgSvc.GetNodeKey(),
|
||||
Data: data,
|
||||
}
|
||||
return sub.GetStream().Send(msg)
|
||||
}
|
||||
|
||||
func (svr *GrpcServerV2) IsStopped() (res bool) {
|
||||
func (svr *GrpcServer) IsStopped() (res bool) {
|
||||
return svr.stopped
|
||||
}
|
||||
|
||||
func NewGrpcServerV2() (svr *GrpcServerV2, err error) {
|
||||
func NewGrpcServerV2() (svr *GrpcServer, err error) {
|
||||
// server
|
||||
svr = &GrpcServerV2{
|
||||
svr = &GrpcServer{
|
||||
address: entity.NewAddress(&entity.AddressOptions{
|
||||
Host: constants.DefaultGrpcServerHost,
|
||||
Port: constants.DefaultGrpcServerPort,
|
||||
@@ -209,17 +177,17 @@ func NewGrpcServerV2() (svr *GrpcServerV2, err error) {
|
||||
|
||||
svr.nodeCfgSvc = nodeconfig.GetNodeConfigService()
|
||||
|
||||
svr.NodeSvr, err = NewNodeServerV2()
|
||||
svr.NodeSvr, err = NewNodeServiceServer()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
svr.ModelBaseServiceSvr = NewModelBaseServiceV2Server()
|
||||
svr.TaskSvr, err = NewTaskServerV2()
|
||||
svr.TaskSvr, err = NewTaskServiceServer()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
svr.DependenciesSvr = GetDependenciesServerV2()
|
||||
svr.MetricsSvr = GetMetricsServerV2()
|
||||
svr.DependencySvr = GetDependenciesServerV2()
|
||||
svr.MetricSvr = GetMetricsServerV2()
|
||||
|
||||
// recovery options
|
||||
recoveryOpts := []grpc_recovery.Option{
|
||||
@@ -246,9 +214,9 @@ func NewGrpcServerV2() (svr *GrpcServerV2, err error) {
|
||||
return svr, nil
|
||||
}
|
||||
|
||||
var _serverV2 *GrpcServerV2
|
||||
var _serverV2 *GrpcServer
|
||||
|
||||
func GetGrpcServerV2() (svr *GrpcServerV2, err error) {
|
||||
func GetGrpcServerV2() (svr *GrpcServer, err error) {
|
||||
if _serverV2 != nil {
|
||||
return _serverV2, nil
|
||||
}
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"errors"
|
||||
"github.com/apex/log"
|
||||
"github.com/crawlab-team/crawlab/core/constants"
|
||||
"github.com/crawlab-team/crawlab/core/entity"
|
||||
"github.com/crawlab-team/crawlab/core/interfaces"
|
||||
models2 "github.com/crawlab-team/crawlab/core/models/models/v2"
|
||||
"github.com/crawlab-team/crawlab/core/models/service"
|
||||
@@ -22,9 +21,12 @@ import (
|
||||
mongo2 "go.mongodb.org/mongo-driver/mongo"
|
||||
"io"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type TaskServerV2 struct {
|
||||
var taskServiceMutex = sync.Mutex{}
|
||||
|
||||
type TaskServiceServer struct {
|
||||
grpc.UnimplementedTaskServiceServer
|
||||
|
||||
// dependencies
|
||||
@@ -33,10 +35,52 @@ type TaskServerV2 struct {
|
||||
|
||||
// internals
|
||||
server interfaces.GrpcServer
|
||||
subs map[primitive.ObjectID]grpc.TaskService_SubscribeServer
|
||||
}
|
||||
|
||||
// Subscribe to task stream when a task runner in a node starts
|
||||
func (svr TaskServerV2) Subscribe(stream grpc.TaskService_SubscribeServer) (err error) {
|
||||
func (svr TaskServiceServer) Subscribe(req *grpc.TaskServiceSubscribeRequest, stream grpc.TaskService_SubscribeServer) (err error) {
|
||||
// task id
|
||||
taskId, err := primitive.ObjectIDFromHex(req.TaskId)
|
||||
if err != nil {
|
||||
return errors.New("invalid task id")
|
||||
}
|
||||
|
||||
// validate stream
|
||||
if stream == nil {
|
||||
return errors.New("invalid stream")
|
||||
}
|
||||
|
||||
// add stream
|
||||
taskServiceMutex.Lock()
|
||||
svr.subs[taskId] = stream
|
||||
taskServiceMutex.Unlock()
|
||||
|
||||
// create a new goroutine to receive messages from the stream to listen for EOF (end of stream)
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-stream.Context().Done():
|
||||
taskServiceMutex.Lock()
|
||||
delete(svr.subs, taskId)
|
||||
taskServiceMutex.Unlock()
|
||||
return
|
||||
default:
|
||||
err := stream.RecvMsg(nil)
|
||||
if err == io.EOF {
|
||||
taskServiceMutex.Lock()
|
||||
delete(svr.subs, taskId)
|
||||
taskServiceMutex.Unlock()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Connect to task stream when a task runner in a node starts
|
||||
func (svr TaskServiceServer) Connect(stream grpc.TaskService_ConnectServer) (err error) {
|
||||
for {
|
||||
msg, err := stream.Recv()
|
||||
if err == io.EOF {
|
||||
@@ -49,11 +93,19 @@ func (svr TaskServerV2) Subscribe(stream grpc.TaskService_SubscribeServer) (err
|
||||
trace.PrintError(err)
|
||||
continue
|
||||
}
|
||||
|
||||
// validate task id
|
||||
taskId, err := primitive.ObjectIDFromHex(msg.TaskId)
|
||||
if err != nil {
|
||||
log.Errorf("invalid task id: %s", msg.TaskId)
|
||||
continue
|
||||
}
|
||||
|
||||
switch msg.Code {
|
||||
case grpc.StreamMessageCode_INSERT_DATA:
|
||||
err = svr.handleInsertData(msg)
|
||||
case grpc.StreamMessageCode_INSERT_LOGS:
|
||||
err = svr.handleInsertLogs(msg)
|
||||
case grpc.TaskServiceConnectCode_INSERT_DATA:
|
||||
err = svr.handleInsertData(taskId, msg)
|
||||
case grpc.TaskServiceConnectCode_INSERT_LOGS:
|
||||
err = svr.handleInsertLogs(taskId, msg)
|
||||
default:
|
||||
err = errors.New("invalid stream message code")
|
||||
log.Errorf("invalid stream message code: %d", msg.Code)
|
||||
@@ -65,8 +117,8 @@ func (svr TaskServerV2) Subscribe(stream grpc.TaskService_SubscribeServer) (err
|
||||
}
|
||||
}
|
||||
|
||||
// Fetch tasks to be executed by a task handler
|
||||
func (svr TaskServerV2) Fetch(ctx context.Context, request *grpc.Request) (response *grpc.Response, err error) {
|
||||
// FetchTask tasks to be executed by a task handler
|
||||
func (svr TaskServiceServer) FetchTask(ctx context.Context, request *grpc.TaskServiceFetchTaskRequest) (response *grpc.TaskServiceFetchTaskResponse, err error) {
|
||||
nodeKey := request.GetNodeKey()
|
||||
if nodeKey == "" {
|
||||
return nil, errors.New("invalid node key")
|
||||
@@ -105,10 +157,10 @@ func (svr TaskServerV2) Fetch(ctx context.Context, request *grpc.Request) (respo
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return HandleSuccessWithData(tid)
|
||||
return &grpc.TaskServiceFetchTaskResponse{TaskId: tid.Hex()}, nil
|
||||
}
|
||||
|
||||
func (svr TaskServerV2) SendNotification(_ context.Context, request *grpc.TaskServiceSendNotificationRequest) (response *grpc.Response, err error) {
|
||||
func (svr TaskServiceServer) SendNotification(_ context.Context, request *grpc.TaskServiceSendNotificationRequest) (response *grpc.Response, err error) {
|
||||
if !utils.IsPro() {
|
||||
return nil, nil
|
||||
}
|
||||
@@ -208,27 +260,32 @@ func (svr TaskServerV2) SendNotification(_ context.Context, request *grpc.TaskSe
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (svr TaskServerV2) handleInsertData(msg *grpc.StreamMessage) (err error) {
|
||||
data, err := svr.deserialize(msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
func (svr TaskServiceServer) GetSubscribeStream(taskId primitive.ObjectID) (stream grpc.TaskService_SubscribeServer, ok bool) {
|
||||
taskServiceMutex.Lock()
|
||||
defer taskServiceMutex.Unlock()
|
||||
stream, ok = svr.subs[taskId]
|
||||
return stream, ok
|
||||
}
|
||||
|
||||
func (svr TaskServiceServer) handleInsertData(taskId primitive.ObjectID, msg *grpc.TaskServiceConnectRequest) (err error) {
|
||||
var records []map[string]interface{}
|
||||
for _, d := range data.Records {
|
||||
records = append(records, d)
|
||||
}
|
||||
return svr.statsSvc.InsertData(data.TaskId, records...)
|
||||
}
|
||||
|
||||
func (svr TaskServerV2) handleInsertLogs(msg *grpc.StreamMessage) (err error) {
|
||||
data, err := svr.deserialize(msg)
|
||||
err = json.Unmarshal(msg.Data, &records)
|
||||
if err != nil {
|
||||
return err
|
||||
return trace.TraceError(err)
|
||||
}
|
||||
return svr.statsSvc.InsertLogs(data.TaskId, data.Logs...)
|
||||
return svr.statsSvc.InsertData(taskId, records...)
|
||||
}
|
||||
|
||||
func (svr TaskServerV2) getTaskQueueItemIdAndDequeue(query bson.M, opts *mongo.FindOptions, nid primitive.ObjectID) (tid primitive.ObjectID, err error) {
|
||||
func (svr TaskServiceServer) handleInsertLogs(taskId primitive.ObjectID, msg *grpc.TaskServiceConnectRequest) (err error) {
|
||||
var logs []string
|
||||
err = json.Unmarshal(msg.Data, &logs)
|
||||
if err != nil {
|
||||
return trace.TraceError(err)
|
||||
}
|
||||
return svr.statsSvc.InsertLogs(taskId, logs...)
|
||||
}
|
||||
|
||||
func (svr TaskServiceServer) getTaskQueueItemIdAndDequeue(query bson.M, opts *mongo.FindOptions, nid primitive.ObjectID) (tid primitive.ObjectID, err error) {
|
||||
tq, err := service.NewModelServiceV2[models2.TaskQueueItemV2]().GetOne(query, opts)
|
||||
if err != nil {
|
||||
if errors.Is(err, mongo2.ErrNoDocuments) {
|
||||
@@ -251,19 +308,9 @@ func (svr TaskServerV2) getTaskQueueItemIdAndDequeue(query bson.M, opts *mongo.F
|
||||
return tq.Id, nil
|
||||
}
|
||||
|
||||
func (svr TaskServerV2) deserialize(msg *grpc.StreamMessage) (data entity.StreamMessageTaskData, err error) {
|
||||
if err := json.Unmarshal(msg.Data, &data); err != nil {
|
||||
return data, trace.TraceError(err)
|
||||
}
|
||||
if data.TaskId.IsZero() {
|
||||
return data, errors.New("invalid task id")
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func NewTaskServerV2() (res *TaskServerV2, err error) {
|
||||
func NewTaskServiceServer() (res *TaskServiceServer, err error) {
|
||||
// task server
|
||||
svr := &TaskServerV2{}
|
||||
svr := &TaskServiceServer{}
|
||||
|
||||
svr.cfgSvc = nodeconfig.GetNodeConfigService()
|
||||
|
||||
@@ -9,6 +9,4 @@ type NodeMasterService interface {
|
||||
Monitor()
|
||||
SetMonitorInterval(duration time.Duration)
|
||||
Register() error
|
||||
StopOnError()
|
||||
GetServer() GrpcServer
|
||||
}
|
||||
|
||||
@@ -3,6 +3,5 @@ package interfaces
|
||||
type NodeService interface {
|
||||
Module
|
||||
WithConfigPath
|
||||
WithAddress
|
||||
GetConfigService() NodeConfigService
|
||||
}
|
||||
|
||||
@@ -5,7 +5,6 @@ import "time"
|
||||
type NodeWorkerService interface {
|
||||
NodeService
|
||||
Register()
|
||||
Recv()
|
||||
ReportStatus()
|
||||
SetHeartbeatInterval(duration time.Duration)
|
||||
}
|
||||
|
||||
@@ -11,5 +11,4 @@ type TaskRunner interface {
|
||||
Cancel(force bool) (err error)
|
||||
SetSubscribeTimeout(timeout time.Duration)
|
||||
GetTaskId() (id primitive.ObjectID)
|
||||
CleanUp() (err error)
|
||||
}
|
||||
|
||||
@@ -29,14 +29,14 @@ func teardownTestDB() {
|
||||
db.Drop(context.Background())
|
||||
}
|
||||
|
||||
func startSvr(svr *server.GrpcServerV2) {
|
||||
func startSvr(svr *server.GrpcServer) {
|
||||
err := svr.Start()
|
||||
if err != nil {
|
||||
log.Errorf("failed to start grpc server: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func stopSvr(svr *server.GrpcServerV2) {
|
||||
func stopSvr(svr *server.GrpcServer) {
|
||||
err := svr.Stop()
|
||||
if err != nil {
|
||||
log.Errorf("failed to stop grpc server: %v", err)
|
||||
|
||||
@@ -27,10 +27,10 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
type MasterServiceV2 struct {
|
||||
type MasterService struct {
|
||||
// dependencies
|
||||
cfgSvc interfaces.NodeConfigService
|
||||
server *server.GrpcServerV2
|
||||
server *server.GrpcServer
|
||||
schedulerSvc *scheduler.ServiceV2
|
||||
handlerSvc *handler.ServiceV2
|
||||
scheduleSvc *schedule.ServiceV2
|
||||
@@ -43,12 +43,12 @@ type MasterServiceV2 struct {
|
||||
stopOnError bool
|
||||
}
|
||||
|
||||
func (svc *MasterServiceV2) Init() (err error) {
|
||||
func (svc *MasterService) Init() (err error) {
|
||||
// do nothing
|
||||
return nil
|
||||
}
|
||||
|
||||
func (svc *MasterServiceV2) Start() {
|
||||
func (svc *MasterService) Start() {
|
||||
// create indexes
|
||||
common.InitIndexes()
|
||||
|
||||
@@ -81,17 +81,17 @@ func (svc *MasterServiceV2) Start() {
|
||||
svc.Stop()
|
||||
}
|
||||
|
||||
func (svc *MasterServiceV2) Wait() {
|
||||
func (svc *MasterService) Wait() {
|
||||
utils.DefaultWait()
|
||||
}
|
||||
|
||||
func (svc *MasterServiceV2) Stop() {
|
||||
func (svc *MasterService) Stop() {
|
||||
_ = svc.server.Stop()
|
||||
svc.handlerSvc.Stop()
|
||||
log.Infof("master[%s] service has stopped", svc.GetConfigService().GetNodeKey())
|
||||
}
|
||||
|
||||
func (svc *MasterServiceV2) Monitor() {
|
||||
func (svc *MasterService) Monitor() {
|
||||
log.Infof("master[%s] monitoring started", svc.GetConfigService().GetNodeKey())
|
||||
|
||||
// ticker
|
||||
@@ -114,31 +114,23 @@ func (svc *MasterServiceV2) Monitor() {
|
||||
}
|
||||
}
|
||||
|
||||
func (svc *MasterServiceV2) GetConfigService() (cfgSvc interfaces.NodeConfigService) {
|
||||
func (svc *MasterService) GetConfigService() (cfgSvc interfaces.NodeConfigService) {
|
||||
return svc.cfgSvc
|
||||
}
|
||||
|
||||
func (svc *MasterServiceV2) GetConfigPath() (path string) {
|
||||
func (svc *MasterService) GetConfigPath() (path string) {
|
||||
return svc.cfgPath
|
||||
}
|
||||
|
||||
func (svc *MasterServiceV2) SetConfigPath(path string) {
|
||||
func (svc *MasterService) SetConfigPath(path string) {
|
||||
svc.cfgPath = path
|
||||
}
|
||||
|
||||
func (svc *MasterServiceV2) GetAddress() (address interfaces.Address) {
|
||||
return svc.address
|
||||
}
|
||||
|
||||
func (svc *MasterServiceV2) SetAddress(address interfaces.Address) {
|
||||
svc.address = address
|
||||
}
|
||||
|
||||
func (svc *MasterServiceV2) SetMonitorInterval(duration time.Duration) {
|
||||
func (svc *MasterService) SetMonitorInterval(duration time.Duration) {
|
||||
svc.monitorInterval = duration
|
||||
}
|
||||
|
||||
func (svc *MasterServiceV2) Register() (err error) {
|
||||
func (svc *MasterService) Register() (err error) {
|
||||
nodeKey := svc.GetConfigService().GetNodeKey()
|
||||
nodeName := svc.GetConfigService().GetNodeName()
|
||||
node, err := service.NewModelServiceV2[models2.NodeV2]().GetOne(bson.M{"key": nodeKey}, nil)
|
||||
@@ -181,15 +173,7 @@ func (svc *MasterServiceV2) Register() (err error) {
|
||||
}
|
||||
}
|
||||
|
||||
func (svc *MasterServiceV2) StopOnError() {
|
||||
svc.stopOnError = true
|
||||
}
|
||||
|
||||
func (svc *MasterServiceV2) GetServer() (svr interfaces.GrpcServer) {
|
||||
return svc.server
|
||||
}
|
||||
|
||||
func (svc *MasterServiceV2) monitor() (err error) {
|
||||
func (svc *MasterService) monitor() (err error) {
|
||||
// update master node status in db
|
||||
if err := svc.updateMasterNodeStatus(); err != nil {
|
||||
if err.Error() == mongo2.ErrNoDocuments.Error() {
|
||||
@@ -238,7 +222,7 @@ func (svc *MasterServiceV2) monitor() (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (svc *MasterServiceV2) getAllWorkerNodes() (nodes []models2.NodeV2, err error) {
|
||||
func (svc *MasterService) getAllWorkerNodes() (nodes []models2.NodeV2, err error) {
|
||||
query := bson.M{
|
||||
"key": bson.M{"$ne": svc.cfgSvc.GetNodeKey()}, // not self
|
||||
"active": true, // active
|
||||
@@ -253,7 +237,7 @@ func (svc *MasterServiceV2) getAllWorkerNodes() (nodes []models2.NodeV2, err err
|
||||
return nodes, nil
|
||||
}
|
||||
|
||||
func (svc *MasterServiceV2) updateMasterNodeStatus() (err error) {
|
||||
func (svc *MasterService) updateMasterNodeStatus() (err error) {
|
||||
nodeKey := svc.GetConfigService().GetNodeKey()
|
||||
node, err := service.NewModelServiceV2[models2.NodeV2]().GetOne(bson.M{"key": nodeKey}, nil)
|
||||
if err != nil {
|
||||
@@ -280,7 +264,7 @@ func (svc *MasterServiceV2) updateMasterNodeStatus() (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (svc *MasterServiceV2) setWorkerNodeOffline(node *models2.NodeV2) {
|
||||
func (svc *MasterService) setWorkerNodeOffline(node *models2.NodeV2) {
|
||||
node.Status = constants.NodeStatusOffline
|
||||
node.Active = false
|
||||
err := backoff.Retry(func() error {
|
||||
@@ -292,24 +276,28 @@ func (svc *MasterServiceV2) setWorkerNodeOffline(node *models2.NodeV2) {
|
||||
svc.sendNotification(node)
|
||||
}
|
||||
|
||||
func (svc *MasterServiceV2) subscribeNode(n *models2.NodeV2) (ok bool) {
|
||||
_, err := svc.server.GetSubscribe("node:" + n.Key)
|
||||
func (svc *MasterService) subscribeNode(n *models2.NodeV2) (ok bool) {
|
||||
_, ok = svc.server.NodeSvr.GetSubscribeStream(n.Id)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (svc *MasterService) pingNodeClient(n *models2.NodeV2) (ok bool) {
|
||||
stream, ok := svc.server.NodeSvr.GetSubscribeStream(n.Id)
|
||||
if !ok {
|
||||
log.Errorf("cannot get worker node client[%s]", n.Key)
|
||||
return false
|
||||
}
|
||||
err := stream.Send(&grpc.NodeServiceSubscribeResponse{
|
||||
Code: grpc.NodeServiceSubscribeCode_PING,
|
||||
})
|
||||
if err != nil {
|
||||
log.Errorf("cannot subscribe worker node[%s]: %v", n.Key, err)
|
||||
log.Errorf("failed to ping worker node client[%s]: %v", n.Key, err)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (svc *MasterServiceV2) pingNodeClient(n *models2.NodeV2) (ok bool) {
|
||||
if err := svc.server.SendStreamMessage("node:"+n.Key, grpc.StreamMessageCode_PING); err != nil {
|
||||
log.Errorf("cannot ping worker node client[%s]: %v", n.Key, err)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (svc *MasterServiceV2) updateNodeAvailableRunners(node *models2.NodeV2) (err error) {
|
||||
func (svc *MasterService) updateNodeAvailableRunners(node *models2.NodeV2) (err error) {
|
||||
query := bson.M{
|
||||
"node_id": node.Id,
|
||||
"status": constants.TaskStatusRunning,
|
||||
@@ -326,16 +314,16 @@ func (svc *MasterServiceV2) updateNodeAvailableRunners(node *models2.NodeV2) (er
|
||||
return nil
|
||||
}
|
||||
|
||||
func (svc *MasterServiceV2) sendNotification(node *models2.NodeV2) {
|
||||
func (svc *MasterService) sendNotification(node *models2.NodeV2) {
|
||||
if !utils.IsPro() {
|
||||
return
|
||||
}
|
||||
go notification.GetNotificationServiceV2().SendNodeNotification(node)
|
||||
}
|
||||
|
||||
func newMasterServiceV2() (res *MasterServiceV2, err error) {
|
||||
func newMasterServiceV2() (res *MasterService, err error) {
|
||||
// master service
|
||||
svc := &MasterServiceV2{
|
||||
svc := &MasterService{
|
||||
cfgPath: config2.GetConfigPath(),
|
||||
monitorInterval: 15 * time.Second,
|
||||
stopOnError: false,
|
||||
@@ -379,15 +367,15 @@ func newMasterServiceV2() (res *MasterServiceV2, err error) {
|
||||
return svc, nil
|
||||
}
|
||||
|
||||
var masterServiceV2 *MasterServiceV2
|
||||
var masterServiceV2Once = new(sync.Once)
|
||||
var masterService *MasterService
|
||||
var masterServiceOnce = new(sync.Once)
|
||||
|
||||
func GetMasterServiceV2() (res *MasterServiceV2, err error) {
|
||||
masterServiceV2Once.Do(func() {
|
||||
masterServiceV2, err = newMasterServiceV2()
|
||||
func GetMasterService() (res *MasterService, err error) {
|
||||
masterServiceOnce.Do(func() {
|
||||
masterService, err = newMasterServiceV2()
|
||||
if err != nil {
|
||||
log.Errorf("failed to get master service: %v", err)
|
||||
}
|
||||
})
|
||||
return masterServiceV2, err
|
||||
return masterService, err
|
||||
}
|
||||
@@ -2,7 +2,7 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"github.com/apex/log"
|
||||
"github.com/crawlab-team/crawlab/core/config"
|
||||
"github.com/crawlab-team/crawlab/core/grpc/client"
|
||||
@@ -15,11 +15,12 @@ import (
|
||||
"github.com/crawlab-team/crawlab/grpc"
|
||||
"github.com/crawlab-team/crawlab/trace"
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
"io"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type WorkerServiceV2 struct {
|
||||
type WorkerService struct {
|
||||
// dependencies
|
||||
cfgSvc interfaces.NodeConfigService
|
||||
client *client.GrpcClientV2
|
||||
@@ -31,16 +32,17 @@ type WorkerServiceV2 struct {
|
||||
heartbeatInterval time.Duration
|
||||
|
||||
// internals
|
||||
n *models.NodeV2
|
||||
s grpc.NodeService_SubscribeClient
|
||||
stopped bool
|
||||
n *models.NodeV2
|
||||
s grpc.NodeService_SubscribeClient
|
||||
}
|
||||
|
||||
func (svc *WorkerServiceV2) Init() (err error) {
|
||||
func (svc *WorkerService) Init() (err error) {
|
||||
// do nothing
|
||||
return nil
|
||||
}
|
||||
|
||||
func (svc *WorkerServiceV2) Start() {
|
||||
func (svc *WorkerService) Start() {
|
||||
// start grpc client
|
||||
if err := svc.client.Start(); err != nil {
|
||||
panic(err)
|
||||
@@ -49,8 +51,8 @@ func (svc *WorkerServiceV2) Start() {
|
||||
// register to master
|
||||
svc.Register()
|
||||
|
||||
// start receiving stream messages
|
||||
go svc.Recv()
|
||||
// subscribe
|
||||
svc.Subscribe()
|
||||
|
||||
// start sending heartbeat to master
|
||||
go svc.ReportStatus()
|
||||
@@ -65,24 +67,23 @@ func (svc *WorkerServiceV2) Start() {
|
||||
svc.Stop()
|
||||
}
|
||||
|
||||
func (svc *WorkerServiceV2) Wait() {
|
||||
func (svc *WorkerService) Wait() {
|
||||
utils.DefaultWait()
|
||||
}
|
||||
|
||||
func (svc *WorkerServiceV2) Stop() {
|
||||
func (svc *WorkerService) Stop() {
|
||||
svc.stopped = true
|
||||
_ = svc.client.Stop()
|
||||
svc.handlerSvc.Stop()
|
||||
log.Infof("worker[%s] service has stopped", svc.cfgSvc.GetNodeKey())
|
||||
}
|
||||
|
||||
func (svc *WorkerServiceV2) Register() {
|
||||
func (svc *WorkerService) Register() {
|
||||
ctx, cancel := svc.client.Context()
|
||||
defer cancel()
|
||||
_, err := svc.client.NodeClient.Register(ctx, &grpc.NodeServiceRegisterRequest{
|
||||
Key: svc.cfgSvc.GetNodeKey(),
|
||||
Name: svc.cfgSvc.GetNodeName(),
|
||||
IsMaster: svc.cfgSvc.IsMaster(),
|
||||
AuthKey: svc.cfgSvc.GetAuthKey(),
|
||||
NodeKey: svc.cfgSvc.GetNodeKey(),
|
||||
NodeName: svc.cfgSvc.GetNodeName(),
|
||||
MaxRunners: int32(svc.cfgSvc.GetMaxRunners()),
|
||||
})
|
||||
if err != nil {
|
||||
@@ -96,56 +97,7 @@ func (svc *WorkerServiceV2) Register() {
|
||||
return
|
||||
}
|
||||
|
||||
func (svc *WorkerServiceV2) Recv() {
|
||||
msgCh := svc.client.GetMessageChannel()
|
||||
for {
|
||||
// return if client is closed
|
||||
if svc.client.IsClosed() {
|
||||
return
|
||||
}
|
||||
|
||||
// receive message from channel
|
||||
msg := <-msgCh
|
||||
|
||||
// handle message
|
||||
if err := svc.handleStreamMessage(msg); err != nil {
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (svc *WorkerServiceV2) handleStreamMessage(msg *grpc.StreamMessage) (err error) {
|
||||
log.Debugf("[WorkerServiceV2] handle msg: %v", msg)
|
||||
switch msg.Code {
|
||||
case grpc.StreamMessageCode_PING:
|
||||
_, err := svc.client.NodeClient.SendHeartbeat(context.Background(), &grpc.NodeServiceSendHeartbeatRequest{
|
||||
Key: svc.cfgSvc.GetNodeKey(),
|
||||
})
|
||||
if err != nil {
|
||||
return trace.TraceError(err)
|
||||
}
|
||||
case grpc.StreamMessageCode_RUN_TASK:
|
||||
var t models.TaskV2
|
||||
if err := json.Unmarshal(msg.Data, &t); err != nil {
|
||||
return trace.TraceError(err)
|
||||
}
|
||||
if err := svc.handlerSvc.Run(t.Id); err != nil {
|
||||
return trace.TraceError(err)
|
||||
}
|
||||
case grpc.StreamMessageCode_CANCEL_TASK:
|
||||
var t models.TaskV2
|
||||
if err := json.Unmarshal(msg.Data, &t); err != nil {
|
||||
return trace.TraceError(err)
|
||||
}
|
||||
if err := svc.handlerSvc.Cancel(t.Id); err != nil {
|
||||
return trace.TraceError(err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (svc *WorkerServiceV2) ReportStatus() {
|
||||
func (svc *WorkerService) ReportStatus() {
|
||||
ticker := time.NewTicker(svc.heartbeatInterval)
|
||||
for {
|
||||
// return if client is closed
|
||||
@@ -162,46 +114,76 @@ func (svc *WorkerServiceV2) ReportStatus() {
|
||||
}
|
||||
}
|
||||
|
||||
func (svc *WorkerServiceV2) GetConfigService() (cfgSvc interfaces.NodeConfigService) {
|
||||
func (svc *WorkerService) GetConfigService() (cfgSvc interfaces.NodeConfigService) {
|
||||
return svc.cfgSvc
|
||||
}
|
||||
|
||||
func (svc *WorkerServiceV2) GetConfigPath() (path string) {
|
||||
func (svc *WorkerService) GetConfigPath() (path string) {
|
||||
return svc.cfgPath
|
||||
}
|
||||
|
||||
func (svc *WorkerServiceV2) SetConfigPath(path string) {
|
||||
func (svc *WorkerService) SetConfigPath(path string) {
|
||||
svc.cfgPath = path
|
||||
}
|
||||
|
||||
func (svc *WorkerServiceV2) GetAddress() (address interfaces.Address) {
|
||||
func (svc *WorkerService) GetAddress() (address interfaces.Address) {
|
||||
return svc.address
|
||||
}
|
||||
|
||||
func (svc *WorkerServiceV2) SetAddress(address interfaces.Address) {
|
||||
func (svc *WorkerService) SetAddress(address interfaces.Address) {
|
||||
svc.address = address
|
||||
}
|
||||
|
||||
func (svc *WorkerServiceV2) SetHeartbeatInterval(duration time.Duration) {
|
||||
func (svc *WorkerService) SetHeartbeatInterval(duration time.Duration) {
|
||||
svc.heartbeatInterval = duration
|
||||
}
|
||||
|
||||
func (svc *WorkerServiceV2) reportStatus() {
|
||||
func (svc *WorkerService) Subscribe() {
|
||||
stream, err := svc.client.NodeClient.Subscribe(context.Background(), &grpc.NodeServiceSubscribeRequest{
|
||||
NodeKey: svc.cfgSvc.GetNodeKey(),
|
||||
})
|
||||
if err != nil {
|
||||
log.Errorf("failed to subscribe to master: %v", err)
|
||||
return
|
||||
}
|
||||
for {
|
||||
if svc.stopped {
|
||||
return
|
||||
}
|
||||
select {
|
||||
default:
|
||||
msg, err := stream.Recv()
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
return
|
||||
}
|
||||
log.Errorf("failed to receive message from master: %v", err)
|
||||
continue
|
||||
}
|
||||
switch msg.Code {
|
||||
case grpc.NodeServiceSubscribeCode_PING:
|
||||
// do nothing
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (svc *WorkerService) reportStatus() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), svc.heartbeatInterval)
|
||||
defer cancel()
|
||||
_, err := svc.client.NodeClient.SendHeartbeat(ctx, &grpc.NodeServiceSendHeartbeatRequest{
|
||||
Key: svc.cfgSvc.GetNodeKey(),
|
||||
NodeKey: svc.cfgSvc.GetNodeKey(),
|
||||
})
|
||||
if err != nil {
|
||||
trace.PrintError(err)
|
||||
}
|
||||
}
|
||||
|
||||
var workerServiceV2 *WorkerServiceV2
|
||||
var workerServiceV2 *WorkerService
|
||||
var workerServiceV2Once = new(sync.Once)
|
||||
|
||||
func newWorkerServiceV2() (res *WorkerServiceV2, err error) {
|
||||
svc := &WorkerServiceV2{
|
||||
func newWorkerService() (res *WorkerService, err error) {
|
||||
svc := &WorkerService{
|
||||
cfgPath: config.GetConfigPath(),
|
||||
heartbeatInterval: 15 * time.Second,
|
||||
}
|
||||
@@ -233,9 +215,9 @@ func newWorkerServiceV2() (res *WorkerServiceV2, err error) {
|
||||
return svc, nil
|
||||
}
|
||||
|
||||
func GetWorkerServiceV2() (res *WorkerServiceV2, err error) {
|
||||
func GetWorkerService() (res *WorkerService, err error) {
|
||||
workerServiceV2Once.Do(func() {
|
||||
workerServiceV2, err = newWorkerServiceV2()
|
||||
workerServiceV2, err = newWorkerService()
|
||||
if err != nil {
|
||||
log.Errorf("failed to get worker service: %v", err)
|
||||
}
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/apex/log"
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
"github.com/crawlab-team/crawlab/core/constants"
|
||||
"github.com/crawlab-team/crawlab/core/entity"
|
||||
"github.com/crawlab-team/crawlab/core/fs"
|
||||
@@ -44,16 +43,16 @@ type RunnerV2 struct {
|
||||
bufferSize int
|
||||
|
||||
// internals
|
||||
cmd *exec.Cmd // process command instance
|
||||
pid int // process id
|
||||
tid primitive.ObjectID // task id
|
||||
t *models.TaskV2 // task model.Task
|
||||
s *models.SpiderV2 // spider model.Spider
|
||||
ch chan constants.TaskSignal // channel to communicate between Service and RunnerV2
|
||||
err error // standard process error
|
||||
cwd string // working directory
|
||||
c *client2.GrpcClientV2 // grpc client
|
||||
sub grpc.TaskService_SubscribeClient // grpc task service stream client
|
||||
cmd *exec.Cmd // process command instance
|
||||
pid int // process id
|
||||
tid primitive.ObjectID // task id
|
||||
t *models.TaskV2 // task model.Task
|
||||
s *models.SpiderV2 // spider model.Spider
|
||||
ch chan constants.TaskSignal // channel to communicate between Service and RunnerV2
|
||||
err error // standard process error
|
||||
cwd string // working directory
|
||||
c *client2.GrpcClientV2 // grpc client
|
||||
conn grpc.TaskService_ConnectClient // grpc task service stream client
|
||||
|
||||
// log internals
|
||||
scannerStdout *bufio.Reader
|
||||
@@ -76,7 +75,7 @@ func (r *RunnerV2) Init() (err error) {
|
||||
}
|
||||
|
||||
// grpc task service stream client
|
||||
if err := r.initSub(); err != nil {
|
||||
if err := r.initConnection(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -177,26 +176,19 @@ func (r *RunnerV2) Cancel(force bool) (err error) {
|
||||
}
|
||||
|
||||
// make sure the process does not exist
|
||||
op := func() error {
|
||||
if exists, _ := process.PidExists(int32(r.pid)); exists {
|
||||
ticker := time.NewTicker(1 * time.Second)
|
||||
timeout := time.After(r.svc.GetCancelTimeout())
|
||||
for {
|
||||
select {
|
||||
case <-timeout:
|
||||
return errors.New(fmt.Sprintf("task process %d still exists", r.pid))
|
||||
case <-ticker.C:
|
||||
if exists, _ := process.PidExists(int32(r.pid)); exists {
|
||||
return errors.New(fmt.Sprintf("task process %d still exists", r.pid))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), r.svc.GetExitWatchDuration())
|
||||
defer cancel()
|
||||
b := backoff.WithContext(backoff.NewConstantBackOff(1*time.Second), ctx)
|
||||
if err := backoff.Retry(op, b); err != nil {
|
||||
log.Errorf("Error canceling task %s: %v", r.tid, err)
|
||||
return trace.TraceError(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CleanUp clean up task runner
|
||||
func (r *RunnerV2) CleanUp() (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *RunnerV2) SetSubscribeTimeout(timeout time.Duration) {
|
||||
@@ -537,8 +529,8 @@ func (r *RunnerV2) updateTask(status string, e error) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *RunnerV2) initSub() (err error) {
|
||||
r.sub, err = r.c.TaskClient.Subscribe(context.Background())
|
||||
func (r *RunnerV2) initConnection() (err error) {
|
||||
r.conn, err = r.c.TaskClient.Connect(context.Background())
|
||||
if err != nil {
|
||||
return trace.TraceError(err)
|
||||
}
|
||||
@@ -546,20 +538,18 @@ func (r *RunnerV2) initSub() (err error) {
|
||||
}
|
||||
|
||||
func (r *RunnerV2) writeLogLines(lines []string) {
|
||||
data, err := json.Marshal(&entity.StreamMessageTaskData{
|
||||
TaskId: r.tid,
|
||||
Logs: lines,
|
||||
})
|
||||
linesBytes, err := json.Marshal(lines)
|
||||
if err != nil {
|
||||
trace.PrintError(err)
|
||||
log.Errorf("Error marshaling log lines: %v", err)
|
||||
return
|
||||
}
|
||||
msg := &grpc.StreamMessage{
|
||||
Code: grpc.StreamMessageCode_INSERT_LOGS,
|
||||
Data: data,
|
||||
msg := &grpc.TaskServiceConnectRequest{
|
||||
Code: grpc.TaskServiceConnectCode_INSERT_LOGS,
|
||||
TaskId: r.tid.Hex(),
|
||||
Data: linesBytes,
|
||||
}
|
||||
if err := r.sub.Send(msg); err != nil {
|
||||
trace.PrintError(err)
|
||||
if err := r.conn.Send(msg); err != nil {
|
||||
log.Errorf("Error sending log lines: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,8 +2,8 @@ package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/apex/log"
|
||||
"github.com/crawlab-team/crawlab/core/constants"
|
||||
errors2 "github.com/crawlab-team/crawlab/core/errors"
|
||||
@@ -13,9 +13,11 @@ import (
|
||||
models2 "github.com/crawlab-team/crawlab/core/models/models/v2"
|
||||
"github.com/crawlab-team/crawlab/core/models/service"
|
||||
nodeconfig "github.com/crawlab-team/crawlab/core/node/config"
|
||||
"github.com/crawlab-team/crawlab/grpc"
|
||||
"github.com/crawlab-team/crawlab/trace"
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
"io"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
@@ -50,7 +52,7 @@ func (svc *ServiceV2) Start() {
|
||||
}
|
||||
|
||||
go svc.ReportStatus()
|
||||
go svc.Fetch()
|
||||
go svc.FetchAndRunTasks()
|
||||
}
|
||||
|
||||
func (svc *ServiceV2) Stop() {
|
||||
@@ -58,12 +60,7 @@ func (svc *ServiceV2) Stop() {
|
||||
}
|
||||
|
||||
func (svc *ServiceV2) Run(taskId primitive.ObjectID) (err error) {
|
||||
return svc.run(taskId)
|
||||
}
|
||||
|
||||
func (svc *ServiceV2) Reset() {
|
||||
svc.mu.Lock()
|
||||
defer svc.mu.Unlock()
|
||||
return svc.runTask(taskId)
|
||||
}
|
||||
|
||||
func (svc *ServiceV2) Cancel(taskId primitive.ObjectID, force bool) (err error) {
|
||||
@@ -77,7 +74,7 @@ func (svc *ServiceV2) Cancel(taskId primitive.ObjectID, force bool) (err error)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (svc *ServiceV2) Fetch() {
|
||||
func (svc *ServiceV2) FetchAndRunTasks() {
|
||||
ticker := time.NewTicker(svc.fetchInterval)
|
||||
for {
|
||||
if svc.stopped {
|
||||
@@ -102,10 +99,9 @@ func (svc *ServiceV2) Fetch() {
|
||||
continue
|
||||
}
|
||||
|
||||
// fetch task
|
||||
tid, err := svc.fetch()
|
||||
// fetch task id
|
||||
tid, err := svc.fetchTask()
|
||||
if err != nil {
|
||||
trace.PrintError(err)
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -115,8 +111,7 @@ func (svc *ServiceV2) Fetch() {
|
||||
}
|
||||
|
||||
// run task
|
||||
if err := svc.run(tid); err != nil {
|
||||
trace.PrintError(err)
|
||||
if err := svc.runTask(tid); err != nil {
|
||||
t, err := svc.GetTaskById(tid)
|
||||
if err != nil && t.Status != constants.TaskStatusCancelled {
|
||||
t.Error = err.Error()
|
||||
@@ -281,30 +276,36 @@ func (svc *ServiceV2) reportStatus() (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (svc *ServiceV2) fetch() (tid primitive.ObjectID, err error) {
|
||||
func (svc *ServiceV2) fetchTask() (tid primitive.ObjectID, err error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), svc.fetchTimeout)
|
||||
defer cancel()
|
||||
res, err := svc.c.TaskClient.Fetch(ctx, svc.c.NewRequest(nil))
|
||||
res, err := svc.c.TaskClient.FetchTask(ctx, svc.c.NewRequest(nil))
|
||||
if err != nil {
|
||||
return tid, trace.TraceError(err)
|
||||
return primitive.NilObjectID, fmt.Errorf("fetchTask task error: %v", err)
|
||||
}
|
||||
if err := json.Unmarshal(res.Data, &tid); err != nil {
|
||||
return tid, trace.TraceError(err)
|
||||
// validate task id
|
||||
tid, err = primitive.ObjectIDFromHex(res.GetTaskId())
|
||||
if err != nil {
|
||||
return primitive.NilObjectID, fmt.Errorf("invalid task id: %s", res.GetTaskId())
|
||||
}
|
||||
return tid, nil
|
||||
}
|
||||
|
||||
func (svc *ServiceV2) run(taskId primitive.ObjectID) (err error) {
|
||||
func (svc *ServiceV2) runTask(taskId primitive.ObjectID) (err error) {
|
||||
// attempt to get runner from pool
|
||||
_, ok := svc.runners.Load(taskId)
|
||||
if ok {
|
||||
return trace.TraceError(errors2.ErrorTaskAlreadyExists)
|
||||
err = fmt.Errorf("task[%s] already exists", taskId.Hex())
|
||||
log.Errorf("run task error: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// create a new task runner
|
||||
r, err := NewTaskRunnerV2(taskId, svc)
|
||||
if err != nil {
|
||||
return trace.TraceError(err)
|
||||
err = fmt.Errorf("failed to create task runner: %v", err)
|
||||
log.Errorf("run task error: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// add runner to pool
|
||||
@@ -312,16 +313,18 @@ func (svc *ServiceV2) run(taskId primitive.ObjectID) (err error) {
|
||||
|
||||
// create a goroutine to run task
|
||||
go func() {
|
||||
// delete runner from pool
|
||||
defer svc.deleteRunner(r.GetTaskId())
|
||||
defer func(r interfaces.TaskRunner) {
|
||||
err := r.CleanUp()
|
||||
if err != nil {
|
||||
log.Errorf("task[%s] clean up error: %v", r.GetTaskId().Hex(), err)
|
||||
}
|
||||
}(r)
|
||||
// run task process (blocking)
|
||||
// error or finish after task runner ends
|
||||
// get subscription stream
|
||||
stopCh := make(chan struct{})
|
||||
stream, err := svc.subscribeTask(r.GetTaskId())
|
||||
if err == nil {
|
||||
// create a goroutine to handle stream messages
|
||||
go svc.handleStreamMessages(r.GetTaskId(), stream, stopCh)
|
||||
} else {
|
||||
log.Errorf("failed to subscribe task[%s]: %v", r.GetTaskId().Hex(), err)
|
||||
log.Warnf("task[%s] will not be able to receive stream messages", r.GetTaskId().Hex())
|
||||
}
|
||||
|
||||
// run task process (blocking) error or finish after task runner ends
|
||||
if err := r.Run(); err != nil {
|
||||
switch {
|
||||
case errors.Is(err, constants.ErrTaskError):
|
||||
@@ -333,11 +336,64 @@ func (svc *ServiceV2) run(taskId primitive.ObjectID) (err error) {
|
||||
}
|
||||
}
|
||||
log.Infof("task[%s] finished", r.GetTaskId().Hex())
|
||||
|
||||
// send stopCh signal to stream message handler
|
||||
stopCh <- struct{}{}
|
||||
|
||||
// delete runner from pool
|
||||
svc.deleteRunner(r.GetTaskId())
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (svc *ServiceV2) subscribeTask(taskId primitive.ObjectID) (stream grpc.TaskService_SubscribeClient, err error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
req := &grpc.TaskServiceSubscribeRequest{
|
||||
TaskId: taskId.Hex(),
|
||||
}
|
||||
stream, err = svc.c.TaskClient.Subscribe(ctx, req)
|
||||
if err != nil {
|
||||
log.Errorf("failed to subscribe task[%s]: %v", taskId.Hex(), err)
|
||||
return nil, err
|
||||
}
|
||||
return stream, nil
|
||||
}
|
||||
|
||||
func (svc *ServiceV2) handleStreamMessages(id primitive.ObjectID, stream grpc.TaskService_SubscribeClient, stopCh chan struct{}) {
|
||||
for {
|
||||
select {
|
||||
case <-stopCh:
|
||||
err := stream.CloseSend()
|
||||
if err != nil {
|
||||
log.Errorf("task[%s] failed to close stream: %v", id.Hex(), err)
|
||||
return
|
||||
}
|
||||
return
|
||||
default:
|
||||
msg, err := stream.Recv()
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
return
|
||||
}
|
||||
log.Errorf("task[%s] stream error: %v", id.Hex(), err)
|
||||
continue
|
||||
}
|
||||
switch msg.Code {
|
||||
case grpc.TaskServiceSubscribeCode_CANCEL:
|
||||
log.Infof("task[%s] received cancel signal", id.Hex())
|
||||
go func() {
|
||||
if err := svc.Cancel(id, true); err != nil {
|
||||
log.Errorf("task[%s] failed to cancel: %v", id.Hex(), err)
|
||||
}
|
||||
log.Infof("task[%s] cancelled", id.Hex())
|
||||
}()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func newTaskHandlerServiceV2() (svc2 *ServiceV2, err error) {
|
||||
// service
|
||||
svc := &ServiceV2{
|
||||
|
||||
@@ -23,7 +23,7 @@ import (
|
||||
type ServiceV2 struct {
|
||||
// dependencies
|
||||
nodeCfgSvc interfaces.NodeConfigService
|
||||
svr *server.GrpcServerV2
|
||||
svr *server.GrpcServer
|
||||
handlerSvc *handler.ServiceV2
|
||||
|
||||
// settings
|
||||
|
||||
Reference in New Issue
Block a user