refactor: updated grpc services

This commit is contained in:
Marvin Zhang
2024-10-30 18:42:23 +08:00
parent 789f71fd80
commit fa1433007f
64 changed files with 2704 additions and 5132 deletions

View File

@@ -104,9 +104,9 @@ func NewServerV2() (app NodeApp) {
// node service
var err error
if utils.IsMaster() {
svr.nodeSvc, err = service.GetMasterServiceV2()
svr.nodeSvc, err = service.GetMasterService()
} else {
svr.nodeSvc, err = service.GetWorkerServiceV2()
svr.nodeSvc, err = service.GetWorkerService()
}
if err != nil {
panic(err)

View File

@@ -17,29 +17,29 @@ import (
"time"
)
type DependenciesServerV2 struct {
grpc.UnimplementedDependenciesServiceV2Server
type DependencyServiceServer struct {
grpc.UnimplementedDependencyServiceV2Server
mu *sync.Mutex
streams map[string]*grpc.DependenciesServiceV2_ConnectServer
streams map[string]*grpc.DependencyServiceV2_ConnectServer
}
func (svr DependenciesServerV2) Connect(req *grpc.DependenciesServiceV2ConnectRequest, stream grpc.DependenciesServiceV2_ConnectServer) (err error) {
func (svr DependencyServiceServer) Connect(req *grpc.DependencyServiceV2ConnectRequest, stream grpc.DependencyServiceV2_ConnectServer) (err error) {
svr.mu.Lock()
svr.streams[req.NodeKey] = &stream
svr.mu.Unlock()
log.Info("[DependenciesServerV2] connected: " + req.NodeKey)
log.Info("[DependencyServiceServer] connected: " + req.NodeKey)
// Keep this scope alive because once this scope exits - the stream is closed
for {
select {
case <-stream.Context().Done():
log.Info("[DependenciesServerV2] disconnected: " + req.NodeKey)
log.Info("[DependencyServiceServer] disconnected: " + req.NodeKey)
return nil
}
}
}
func (svr DependenciesServerV2) Sync(ctx context.Context, request *grpc.DependenciesServiceV2SyncRequest) (response *grpc.Response, err error) {
func (svr DependencyServiceServer) Sync(ctx context.Context, request *grpc.DependencyServiceV2SyncRequest) (response *grpc.Response, err error) {
n, err := service.NewModelServiceV2[models2.NodeV2]().GetOne(bson.M{"key": request.NodeKey}, nil)
if err != nil {
return nil, err
@@ -51,7 +51,7 @@ func (svr DependenciesServerV2) Sync(ctx context.Context, request *grpc.Dependen
}, nil)
if err != nil {
if !errors.Is(err, mongo.ErrNoDocuments) {
log.Errorf("[DependenciesServiceV2] get dependencies from db error: %v", err)
log.Errorf("[DependencyServiceV2] get dependencies from db error: %v", err)
return nil, err
}
}
@@ -94,7 +94,7 @@ func (svr DependenciesServerV2) Sync(ctx context.Context, request *grpc.Dependen
"_id": bson.M{"$in": depIdsToDelete},
})
if err != nil {
log.Errorf("[DependenciesServerV2] delete dependencies in db error: %v", err)
log.Errorf("[DependencyServiceServer] delete dependencies in db error: %v", err)
trace.PrintError(err)
return err
}
@@ -103,7 +103,7 @@ func (svr DependenciesServerV2) Sync(ctx context.Context, request *grpc.Dependen
if len(depsToInsert) > 0 {
_, err = service.NewModelServiceV2[models2.DependencyV2]().InsertMany(depsToInsert)
if err != nil {
log.Errorf("[DependenciesServerV2] insert dependencies in db error: %v", err)
log.Errorf("[DependencyServiceServer] insert dependencies in db error: %v", err)
trace.PrintError(err)
return err
}
@@ -115,7 +115,7 @@ func (svr DependenciesServerV2) Sync(ctx context.Context, request *grpc.Dependen
return nil, err
}
func (svr DependenciesServerV2) UpdateTaskLog(stream grpc.DependenciesServiceV2_UpdateTaskLogServer) (err error) {
func (svr DependencyServiceServer) UpdateTaskLog(stream grpc.DependencyServiceV2_UpdateTaskLogServer) (err error) {
var t *models2.DependencyTaskV2
for {
req, err := stream.Recv()
@@ -152,7 +152,7 @@ func (svr DependenciesServerV2) UpdateTaskLog(stream grpc.DependenciesServiceV2_
}
}
func (svr DependenciesServerV2) GetStream(key string) (stream *grpc.DependenciesServiceV2_ConnectServer, err error) {
func (svr DependencyServiceServer) GetStream(key string) (stream *grpc.DependencyServiceV2_ConnectServer, err error) {
svr.mu.Lock()
defer svr.mu.Unlock()
stream, ok := svr.streams[key]
@@ -162,19 +162,19 @@ func (svr DependenciesServerV2) GetStream(key string) (stream *grpc.Dependencies
return stream, nil
}
func NewDependenciesServerV2() *DependenciesServerV2 {
return &DependenciesServerV2{
func NewDependencyServerV2() *DependencyServiceServer {
return &DependencyServiceServer{
mu: new(sync.Mutex),
streams: make(map[string]*grpc.DependenciesServiceV2_ConnectServer),
streams: make(map[string]*grpc.DependencyServiceV2_ConnectServer),
}
}
var depSvc *DependenciesServerV2
var depSvc *DependencyServiceServer
func GetDependenciesServerV2() *DependenciesServerV2 {
func GetDependencyServerV2() *DependencyServiceServer {
if depSvc != nil {
return depSvc
}
depSvc = NewDependenciesServerV2()
depSvc = NewDependencyServerV2()
return depSvc
}

View File

@@ -11,15 +11,15 @@ import (
"time"
)
type MetricsServerV2 struct {
grpc.UnimplementedMetricsServiceV2Server
type MetricServiceServer struct {
grpc.UnimplementedMetricServiceV2Server
}
func (svr MetricsServerV2) Send(_ context.Context, req *grpc.MetricsServiceV2SendRequest) (res *grpc.Response, err error) {
log.Info("[MetricsServerV2] received metric from node: " + req.NodeKey)
func (svr MetricServiceServer) Send(_ context.Context, req *grpc.MetricServiceV2SendRequest) (res *grpc.Response, err error) {
log.Info("[MetricServiceServer] received metric from node: " + req.NodeKey)
n, err := service.NewModelServiceV2[models2.NodeV2]().GetOne(bson.M{"key": req.NodeKey}, nil)
if err != nil {
log.Errorf("[MetricsServerV2] error getting node: %v", err)
log.Errorf("[MetricServiceServer] error getting node: %v", err)
return HandleError(err)
}
metric := models2.MetricV2{
@@ -42,20 +42,20 @@ func (svr MetricsServerV2) Send(_ context.Context, req *grpc.MetricsServiceV2Sen
metric.CreatedAt = time.Unix(req.Timestamp, 0)
_, err = service.NewModelServiceV2[models2.MetricV2]().InsertOne(metric)
if err != nil {
log.Errorf("[MetricsServerV2] error inserting metric: %v", err)
log.Errorf("[MetricServiceServer] error inserting metric: %v", err)
return HandleError(err)
}
return HandleSuccess()
}
func newMetricsServerV2() *MetricsServerV2 {
return &MetricsServerV2{}
func newMetricsServerV2() *MetricServiceServer {
return &MetricServiceServer{}
}
var metricsServerV2 *MetricsServerV2
var metricsServerV2 *MetricServiceServer
var metricsServerV2Once = &sync.Once{}
func GetMetricsServerV2() *MetricsServerV2 {
func GetMetricsServerV2() *MetricServiceServer {
if metricsServerV2 != nil {
return metricsServerV2
}

View File

@@ -4,7 +4,6 @@ import (
"context"
"github.com/apex/log"
"github.com/crawlab-team/crawlab/core/constants"
"github.com/crawlab-team/crawlab/core/entity"
"github.com/crawlab-team/crawlab/core/errors"
"github.com/crawlab-team/crawlab/core/interfaces"
"github.com/crawlab-team/crawlab/core/models/models/v2"
@@ -17,36 +16,34 @@ import (
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/mongo"
"io"
"sync"
"time"
)
type NodeServerV2 struct {
var nodeServiceMutex = sync.Mutex{}
type NodeServiceServer struct {
grpc.UnimplementedNodeServiceServer
// dependencies
cfgSvc interfaces.NodeConfigService
// internals
server *GrpcServerV2
server *GrpcServer
subs map[primitive.ObjectID]grpc.NodeService_SubscribeServer
}
// Register from handler/worker to master
func (svr NodeServerV2) Register(_ context.Context, req *grpc.NodeServiceRegisterRequest) (res *grpc.Response, err error) {
// unmarshall data
if req.IsMaster {
// error: cannot register master node
return HandleError(errors.ErrorGrpcNotAllowed)
}
func (svr NodeServiceServer) Register(_ context.Context, req *grpc.NodeServiceRegisterRequest) (res *grpc.Response, err error) {
// node key
if req.Key == "" {
if req.NodeKey == "" {
return HandleError(errors.ErrorModelMissingRequiredData)
}
// find in db
var node *models.NodeV2
node, err = service.NewModelServiceV2[models.NodeV2]().GetOne(bson.M{"key": req.Key}, nil)
node, err = service.NewModelServiceV2[models.NodeV2]().GetOne(bson.M{"key": req.NodeKey}, nil)
if err == nil {
// register existing
node.Status = constants.NodeStatusRegistered
@@ -56,12 +53,12 @@ func (svr NodeServerV2) Register(_ context.Context, req *grpc.NodeServiceRegiste
if err != nil {
return HandleError(err)
}
log.Infof("[NodeServerV2] updated worker[%s] in db. id: %s", req.Key, node.Id.Hex())
log.Infof("[NodeServiceServer] updated worker[%s] in db. id: %s", req.NodeKey, node.Id.Hex())
} else if errors2.Is(err, mongo.ErrNoDocuments) {
// register new
node = &models.NodeV2{
Key: req.Key,
Name: req.Name,
Key: req.NodeKey,
Name: req.NodeName,
Status: constants.NodeStatusRegistered,
Active: true,
ActiveAt: time.Now(),
@@ -74,21 +71,21 @@ func (svr NodeServerV2) Register(_ context.Context, req *grpc.NodeServiceRegiste
if err != nil {
return HandleError(err)
}
log.Infof("[NodeServerV2] added worker[%s] in db. id: %s", req.Key, node.Id.Hex())
log.Infof("[NodeServiceServer] added worker[%s] in db. id: %s", req.NodeKey, node.Id.Hex())
} else {
// error
return HandleError(err)
}
log.Infof("[NodeServerV2] master registered worker[%s]", req.Key)
log.Infof("[NodeServiceServer] master registered worker[%s]", req.NodeKey)
return HandleSuccessWithData(node)
}
// SendHeartbeat from worker to master
func (svr NodeServerV2) SendHeartbeat(_ context.Context, req *grpc.NodeServiceSendHeartbeatRequest) (res *grpc.Response, err error) {
func (svr NodeServiceServer) SendHeartbeat(_ context.Context, req *grpc.NodeServiceSendHeartbeatRequest) (res *grpc.Response, err error) {
// find in db
node, err := service.NewModelServiceV2[models.NodeV2]().GetOne(bson.M{"key": req.Key}, nil)
node, err := service.NewModelServiceV2[models.NodeV2]().GetOne(bson.M{"key": req.NodeKey}, nil)
if err != nil {
if errors2.Is(err, mongo.ErrNoDocuments) {
return HandleError(errors.ErrorNodeNotExists)
@@ -122,64 +119,66 @@ func (svr NodeServerV2) SendHeartbeat(_ context.Context, req *grpc.NodeServiceSe
return HandleSuccessWithData(node)
}
func (svr NodeServerV2) Subscribe(request *grpc.Request, stream grpc.NodeService_SubscribeServer) (err error) {
log.Infof("[NodeServerV2] master received subscribe request from node[%s]", request.NodeKey)
func (svr NodeServiceServer) Subscribe(request *grpc.NodeServiceSubscribeRequest, stream grpc.NodeService_SubscribeServer) (err error) {
log.Infof("[NodeServiceServer] master received subscribe request from node[%s]", request.NodeKey)
// finished channel
finished := make(chan bool)
// set subscribe
svr.server.SetSubscribe("node:"+request.NodeKey, &entity.GrpcSubscribe{
Stream: stream,
Finished: finished,
})
ctx := stream.Context()
log.Infof("[NodeServerV2] master subscribed node[%s]", request.NodeKey)
// Keep this scope alive because once this scope exits - the stream is closed
for {
select {
case <-finished:
log.Infof("[NodeServerV2] closing stream for node[%s]", request.NodeKey)
return nil
case <-ctx.Done():
log.Infof("[NodeServerV2] node[%s] has disconnected", request.NodeKey)
return nil
}
}
}
func (svr NodeServerV2) Unsubscribe(_ context.Context, req *grpc.Request) (res *grpc.Response, err error) {
sub, err := svr.server.GetSubscribe("node:" + req.NodeKey)
// find in db
node, err := service.NewModelServiceV2[models.NodeV2]().GetOne(bson.M{"key": request.NodeKey}, nil)
if err != nil {
return nil, errors.ErrorGrpcSubscribeNotExists
log.Errorf("[NodeServiceServer] error getting node: %v", err)
return err
}
select {
case sub.GetFinished() <- true:
log.Infof("unsubscribed node[%s]", req.NodeKey)
default:
// Default case is to avoid blocking in case client has already unsubscribed
}
svr.server.DeleteSubscribe(req.NodeKey)
return &grpc.Response{
Code: grpc.ResponseCode_OK,
Message: "unsubscribed successfully",
}, nil
// subscribe
nodeServiceMutex.Lock()
svr.subs[node.Id] = stream
nodeServiceMutex.Unlock()
// TODO: send notification
// create a new goroutine to receive messages from the stream to listen for EOF (end of stream)
go func() {
for {
select {
case <-stream.Context().Done():
nodeServiceMutex.Lock()
delete(svr.subs, node.Id)
nodeServiceMutex.Unlock()
return
default:
err := stream.RecvMsg(nil)
if err == io.EOF {
nodeServiceMutex.Lock()
delete(svr.subs, node.Id)
nodeServiceMutex.Unlock()
return
}
}
}
}()
return nil
}
var nodeSvrV2 *NodeServerV2
func (svr NodeServiceServer) GetSubscribeStream(nodeId primitive.ObjectID) (stream grpc.NodeService_SubscribeServer, ok bool) {
nodeServiceMutex.Lock()
defer nodeServiceMutex.Unlock()
stream, ok = svr.subs[nodeId]
return stream, ok
}
var nodeSvrV2 *NodeServiceServer
var nodeSvrV2Once = new(sync.Once)
func NewNodeServerV2() (res *NodeServerV2, err error) {
func NewNodeServiceServer() (res *NodeServiceServer, err error) {
if nodeSvrV2 != nil {
return nodeSvrV2, nil
}
nodeSvrV2Once.Do(func() {
nodeSvrV2 = &NodeServerV2{}
nodeSvrV2 = &NodeServiceServer{}
nodeSvrV2.cfgSvc = nodeconfig.GetNodeConfigService()
if err != nil {
log.Errorf("[NodeServerV2] error: %s", err.Error())
log.Errorf("[NodeServiceServer] error: %s", err.Error())
}
})
if err != nil {

View File

@@ -1,7 +1,6 @@
package server
import (
"encoding/json"
"fmt"
"github.com/apex/log"
"github.com/crawlab-team/crawlab/core/constants"
@@ -17,7 +16,6 @@ import (
grpc_recovery "github.com/grpc-ecosystem/go-grpc-middleware/recovery"
errors2 "github.com/pkg/errors"
"github.com/spf13/viper"
"go/types"
"google.golang.org/grpc"
"net"
"sync"
@@ -28,7 +26,7 @@ var (
mutexSubsV2 = &sync.Mutex{}
)
type GrpcServerV2 struct {
type GrpcServer struct {
// settings
cfgPath string
address interfaces.Address
@@ -42,22 +40,22 @@ type GrpcServerV2 struct {
nodeCfgSvc interfaces.NodeConfigService
// servers
NodeSvr *NodeServerV2
TaskSvr *TaskServerV2
NodeSvr *NodeServiceServer
TaskSvr *TaskServiceServer
ModelBaseServiceSvr *ModelBaseServiceServerV2
DependenciesSvr *DependenciesServerV2
MetricsSvr *MetricsServerV2
DependencySvr *DependencyServiceServer
MetricSvr *MetricServiceServer
}
func (svr *GrpcServerV2) GetConfigPath() (path string) {
func (svr *GrpcServer) GetConfigPath() (path string) {
return svr.cfgPath
}
func (svr *GrpcServerV2) SetConfigPath(path string) {
func (svr *GrpcServer) SetConfigPath(path string) {
svr.cfgPath = path
}
func (svr *GrpcServerV2) Init() (err error) {
func (svr *GrpcServer) Init() (err error) {
// register
if err := svr.Register(); err != nil {
return err
@@ -66,7 +64,7 @@ func (svr *GrpcServerV2) Init() (err error) {
return nil
}
func (svr *GrpcServerV2) Start() (err error) {
func (svr *GrpcServer) Start() (err error) {
// grpc server binding address
address := svr.address.String()
@@ -92,7 +90,7 @@ func (svr *GrpcServerV2) Start() (err error) {
return nil
}
func (svr *GrpcServerV2) Stop() (err error) {
func (svr *GrpcServer) Stop() (err error) {
// skip if listener is nil
if svr.l == nil {
return nil
@@ -115,27 +113,27 @@ func (svr *GrpcServerV2) Stop() (err error) {
return nil
}
func (svr *GrpcServerV2) Register() (err error) {
func (svr *GrpcServer) Register() (err error) {
grpc2.RegisterNodeServiceServer(svr.svr, *svr.NodeSvr)
grpc2.RegisterModelBaseServiceV2Server(svr.svr, *svr.ModelBaseServiceSvr)
grpc2.RegisterTaskServiceServer(svr.svr, *svr.TaskSvr)
grpc2.RegisterDependenciesServiceV2Server(svr.svr, *svr.DependenciesSvr)
grpc2.RegisterMetricsServiceV2Server(svr.svr, *svr.MetricsSvr)
grpc2.RegisterDependencyServiceV2Server(svr.svr, *svr.DependencySvr)
grpc2.RegisterMetricServiceV2Server(svr.svr, *svr.MetricSvr)
return nil
}
func (svr *GrpcServerV2) recoveryHandlerFunc(p interface{}) (err error) {
func (svr *GrpcServer) recoveryHandlerFunc(p interface{}) (err error) {
err = errors.NewError(errors.ErrorPrefixGrpc, fmt.Sprintf("%v", p))
trace.PrintError(err)
return err
}
func (svr *GrpcServerV2) SetAddress(address interfaces.Address) {
func (svr *GrpcServer) SetAddress(address interfaces.Address) {
}
func (svr *GrpcServerV2) GetSubscribe(key string) (sub interfaces.GrpcSubscribe, err error) {
func (svr *GrpcServer) GetSubscribe(key string) (sub interfaces.GrpcSubscribe, err error) {
mutexSubsV2.Lock()
defer mutexSubsV2.Unlock()
sub, ok := subsV2[key]
@@ -145,55 +143,25 @@ func (svr *GrpcServerV2) GetSubscribe(key string) (sub interfaces.GrpcSubscribe,
return sub, nil
}
func (svr *GrpcServerV2) SetSubscribe(key string, sub interfaces.GrpcSubscribe) {
func (svr *GrpcServer) SetSubscribe(key string, sub interfaces.GrpcSubscribe) {
mutexSubsV2.Lock()
defer mutexSubsV2.Unlock()
subsV2[key] = sub
}
func (svr *GrpcServerV2) DeleteSubscribe(key string) {
func (svr *GrpcServer) DeleteSubscribe(key string) {
mutexSubsV2.Lock()
defer mutexSubsV2.Unlock()
delete(subsV2, key)
}
func (svr *GrpcServerV2) SendStreamMessage(key string, code grpc2.StreamMessageCode) (err error) {
return svr.SendStreamMessageWithData(key, code, nil)
}
func (svr *GrpcServerV2) SendStreamMessageWithData(key string, code grpc2.StreamMessageCode, d interface{}) (err error) {
var data []byte
switch d.(type) {
case types.Nil:
// do nothing
case []byte:
data = d.([]byte)
default:
var err error
data, err = json.Marshal(d)
if err != nil {
return err
}
}
sub, err := svr.GetSubscribe(key)
if err != nil {
return err
}
msg := &grpc2.StreamMessage{
Code: code,
Key: svr.nodeCfgSvc.GetNodeKey(),
Data: data,
}
return sub.GetStream().Send(msg)
}
func (svr *GrpcServerV2) IsStopped() (res bool) {
func (svr *GrpcServer) IsStopped() (res bool) {
return svr.stopped
}
func NewGrpcServerV2() (svr *GrpcServerV2, err error) {
func NewGrpcServerV2() (svr *GrpcServer, err error) {
// server
svr = &GrpcServerV2{
svr = &GrpcServer{
address: entity.NewAddress(&entity.AddressOptions{
Host: constants.DefaultGrpcServerHost,
Port: constants.DefaultGrpcServerPort,
@@ -209,17 +177,17 @@ func NewGrpcServerV2() (svr *GrpcServerV2, err error) {
svr.nodeCfgSvc = nodeconfig.GetNodeConfigService()
svr.NodeSvr, err = NewNodeServerV2()
svr.NodeSvr, err = NewNodeServiceServer()
if err != nil {
return nil, err
}
svr.ModelBaseServiceSvr = NewModelBaseServiceV2Server()
svr.TaskSvr, err = NewTaskServerV2()
svr.TaskSvr, err = NewTaskServiceServer()
if err != nil {
return nil, err
}
svr.DependenciesSvr = GetDependenciesServerV2()
svr.MetricsSvr = GetMetricsServerV2()
svr.DependencySvr = GetDependenciesServerV2()
svr.MetricSvr = GetMetricsServerV2()
// recovery options
recoveryOpts := []grpc_recovery.Option{
@@ -246,9 +214,9 @@ func NewGrpcServerV2() (svr *GrpcServerV2, err error) {
return svr, nil
}
var _serverV2 *GrpcServerV2
var _serverV2 *GrpcServer
func GetGrpcServerV2() (svr *GrpcServerV2, err error) {
func GetGrpcServerV2() (svr *GrpcServer, err error) {
if _serverV2 != nil {
return _serverV2, nil
}

View File

@@ -6,7 +6,6 @@ import (
"errors"
"github.com/apex/log"
"github.com/crawlab-team/crawlab/core/constants"
"github.com/crawlab-team/crawlab/core/entity"
"github.com/crawlab-team/crawlab/core/interfaces"
models2 "github.com/crawlab-team/crawlab/core/models/models/v2"
"github.com/crawlab-team/crawlab/core/models/service"
@@ -22,9 +21,12 @@ import (
mongo2 "go.mongodb.org/mongo-driver/mongo"
"io"
"strings"
"sync"
)
type TaskServerV2 struct {
var taskServiceMutex = sync.Mutex{}
type TaskServiceServer struct {
grpc.UnimplementedTaskServiceServer
// dependencies
@@ -33,10 +35,52 @@ type TaskServerV2 struct {
// internals
server interfaces.GrpcServer
subs map[primitive.ObjectID]grpc.TaskService_SubscribeServer
}
// Subscribe to task stream when a task runner in a node starts
func (svr TaskServerV2) Subscribe(stream grpc.TaskService_SubscribeServer) (err error) {
func (svr TaskServiceServer) Subscribe(req *grpc.TaskServiceSubscribeRequest, stream grpc.TaskService_SubscribeServer) (err error) {
// task id
taskId, err := primitive.ObjectIDFromHex(req.TaskId)
if err != nil {
return errors.New("invalid task id")
}
// validate stream
if stream == nil {
return errors.New("invalid stream")
}
// add stream
taskServiceMutex.Lock()
svr.subs[taskId] = stream
taskServiceMutex.Unlock()
// create a new goroutine to receive messages from the stream to listen for EOF (end of stream)
go func() {
for {
select {
case <-stream.Context().Done():
taskServiceMutex.Lock()
delete(svr.subs, taskId)
taskServiceMutex.Unlock()
return
default:
err := stream.RecvMsg(nil)
if err == io.EOF {
taskServiceMutex.Lock()
delete(svr.subs, taskId)
taskServiceMutex.Unlock()
return
}
}
}
}()
return nil
}
// Connect to task stream when a task runner in a node starts
func (svr TaskServiceServer) Connect(stream grpc.TaskService_ConnectServer) (err error) {
for {
msg, err := stream.Recv()
if err == io.EOF {
@@ -49,11 +93,19 @@ func (svr TaskServerV2) Subscribe(stream grpc.TaskService_SubscribeServer) (err
trace.PrintError(err)
continue
}
// validate task id
taskId, err := primitive.ObjectIDFromHex(msg.TaskId)
if err != nil {
log.Errorf("invalid task id: %s", msg.TaskId)
continue
}
switch msg.Code {
case grpc.StreamMessageCode_INSERT_DATA:
err = svr.handleInsertData(msg)
case grpc.StreamMessageCode_INSERT_LOGS:
err = svr.handleInsertLogs(msg)
case grpc.TaskServiceConnectCode_INSERT_DATA:
err = svr.handleInsertData(taskId, msg)
case grpc.TaskServiceConnectCode_INSERT_LOGS:
err = svr.handleInsertLogs(taskId, msg)
default:
err = errors.New("invalid stream message code")
log.Errorf("invalid stream message code: %d", msg.Code)
@@ -65,8 +117,8 @@ func (svr TaskServerV2) Subscribe(stream grpc.TaskService_SubscribeServer) (err
}
}
// Fetch tasks to be executed by a task handler
func (svr TaskServerV2) Fetch(ctx context.Context, request *grpc.Request) (response *grpc.Response, err error) {
// FetchTask tasks to be executed by a task handler
func (svr TaskServiceServer) FetchTask(ctx context.Context, request *grpc.TaskServiceFetchTaskRequest) (response *grpc.TaskServiceFetchTaskResponse, err error) {
nodeKey := request.GetNodeKey()
if nodeKey == "" {
return nil, errors.New("invalid node key")
@@ -105,10 +157,10 @@ func (svr TaskServerV2) Fetch(ctx context.Context, request *grpc.Request) (respo
}); err != nil {
return nil, err
}
return HandleSuccessWithData(tid)
return &grpc.TaskServiceFetchTaskResponse{TaskId: tid.Hex()}, nil
}
func (svr TaskServerV2) SendNotification(_ context.Context, request *grpc.TaskServiceSendNotificationRequest) (response *grpc.Response, err error) {
func (svr TaskServiceServer) SendNotification(_ context.Context, request *grpc.TaskServiceSendNotificationRequest) (response *grpc.Response, err error) {
if !utils.IsPro() {
return nil, nil
}
@@ -208,27 +260,32 @@ func (svr TaskServerV2) SendNotification(_ context.Context, request *grpc.TaskSe
return nil, nil
}
func (svr TaskServerV2) handleInsertData(msg *grpc.StreamMessage) (err error) {
data, err := svr.deserialize(msg)
if err != nil {
return err
}
func (svr TaskServiceServer) GetSubscribeStream(taskId primitive.ObjectID) (stream grpc.TaskService_SubscribeServer, ok bool) {
taskServiceMutex.Lock()
defer taskServiceMutex.Unlock()
stream, ok = svr.subs[taskId]
return stream, ok
}
func (svr TaskServiceServer) handleInsertData(taskId primitive.ObjectID, msg *grpc.TaskServiceConnectRequest) (err error) {
var records []map[string]interface{}
for _, d := range data.Records {
records = append(records, d)
}
return svr.statsSvc.InsertData(data.TaskId, records...)
}
func (svr TaskServerV2) handleInsertLogs(msg *grpc.StreamMessage) (err error) {
data, err := svr.deserialize(msg)
err = json.Unmarshal(msg.Data, &records)
if err != nil {
return err
return trace.TraceError(err)
}
return svr.statsSvc.InsertLogs(data.TaskId, data.Logs...)
return svr.statsSvc.InsertData(taskId, records...)
}
func (svr TaskServerV2) getTaskQueueItemIdAndDequeue(query bson.M, opts *mongo.FindOptions, nid primitive.ObjectID) (tid primitive.ObjectID, err error) {
func (svr TaskServiceServer) handleInsertLogs(taskId primitive.ObjectID, msg *grpc.TaskServiceConnectRequest) (err error) {
var logs []string
err = json.Unmarshal(msg.Data, &logs)
if err != nil {
return trace.TraceError(err)
}
return svr.statsSvc.InsertLogs(taskId, logs...)
}
func (svr TaskServiceServer) getTaskQueueItemIdAndDequeue(query bson.M, opts *mongo.FindOptions, nid primitive.ObjectID) (tid primitive.ObjectID, err error) {
tq, err := service.NewModelServiceV2[models2.TaskQueueItemV2]().GetOne(query, opts)
if err != nil {
if errors.Is(err, mongo2.ErrNoDocuments) {
@@ -251,19 +308,9 @@ func (svr TaskServerV2) getTaskQueueItemIdAndDequeue(query bson.M, opts *mongo.F
return tq.Id, nil
}
func (svr TaskServerV2) deserialize(msg *grpc.StreamMessage) (data entity.StreamMessageTaskData, err error) {
if err := json.Unmarshal(msg.Data, &data); err != nil {
return data, trace.TraceError(err)
}
if data.TaskId.IsZero() {
return data, errors.New("invalid task id")
}
return data, nil
}
func NewTaskServerV2() (res *TaskServerV2, err error) {
func NewTaskServiceServer() (res *TaskServiceServer, err error) {
// task server
svr := &TaskServerV2{}
svr := &TaskServiceServer{}
svr.cfgSvc = nodeconfig.GetNodeConfigService()

View File

@@ -9,6 +9,4 @@ type NodeMasterService interface {
Monitor()
SetMonitorInterval(duration time.Duration)
Register() error
StopOnError()
GetServer() GrpcServer
}

View File

@@ -3,6 +3,5 @@ package interfaces
type NodeService interface {
Module
WithConfigPath
WithAddress
GetConfigService() NodeConfigService
}

View File

@@ -5,7 +5,6 @@ import "time"
type NodeWorkerService interface {
NodeService
Register()
Recv()
ReportStatus()
SetHeartbeatInterval(duration time.Duration)
}

View File

@@ -11,5 +11,4 @@ type TaskRunner interface {
Cancel(force bool) (err error)
SetSubscribeTimeout(timeout time.Duration)
GetTaskId() (id primitive.ObjectID)
CleanUp() (err error)
}

View File

@@ -29,14 +29,14 @@ func teardownTestDB() {
db.Drop(context.Background())
}
func startSvr(svr *server.GrpcServerV2) {
func startSvr(svr *server.GrpcServer) {
err := svr.Start()
if err != nil {
log.Errorf("failed to start grpc server: %v", err)
}
}
func stopSvr(svr *server.GrpcServerV2) {
func stopSvr(svr *server.GrpcServer) {
err := svr.Stop()
if err != nil {
log.Errorf("failed to stop grpc server: %v", err)

View File

@@ -27,10 +27,10 @@ import (
"time"
)
type MasterServiceV2 struct {
type MasterService struct {
// dependencies
cfgSvc interfaces.NodeConfigService
server *server.GrpcServerV2
server *server.GrpcServer
schedulerSvc *scheduler.ServiceV2
handlerSvc *handler.ServiceV2
scheduleSvc *schedule.ServiceV2
@@ -43,12 +43,12 @@ type MasterServiceV2 struct {
stopOnError bool
}
func (svc *MasterServiceV2) Init() (err error) {
func (svc *MasterService) Init() (err error) {
// do nothing
return nil
}
func (svc *MasterServiceV2) Start() {
func (svc *MasterService) Start() {
// create indexes
common.InitIndexes()
@@ -81,17 +81,17 @@ func (svc *MasterServiceV2) Start() {
svc.Stop()
}
func (svc *MasterServiceV2) Wait() {
func (svc *MasterService) Wait() {
utils.DefaultWait()
}
func (svc *MasterServiceV2) Stop() {
func (svc *MasterService) Stop() {
_ = svc.server.Stop()
svc.handlerSvc.Stop()
log.Infof("master[%s] service has stopped", svc.GetConfigService().GetNodeKey())
}
func (svc *MasterServiceV2) Monitor() {
func (svc *MasterService) Monitor() {
log.Infof("master[%s] monitoring started", svc.GetConfigService().GetNodeKey())
// ticker
@@ -114,31 +114,23 @@ func (svc *MasterServiceV2) Monitor() {
}
}
func (svc *MasterServiceV2) GetConfigService() (cfgSvc interfaces.NodeConfigService) {
func (svc *MasterService) GetConfigService() (cfgSvc interfaces.NodeConfigService) {
return svc.cfgSvc
}
func (svc *MasterServiceV2) GetConfigPath() (path string) {
func (svc *MasterService) GetConfigPath() (path string) {
return svc.cfgPath
}
func (svc *MasterServiceV2) SetConfigPath(path string) {
func (svc *MasterService) SetConfigPath(path string) {
svc.cfgPath = path
}
func (svc *MasterServiceV2) GetAddress() (address interfaces.Address) {
return svc.address
}
func (svc *MasterServiceV2) SetAddress(address interfaces.Address) {
svc.address = address
}
func (svc *MasterServiceV2) SetMonitorInterval(duration time.Duration) {
func (svc *MasterService) SetMonitorInterval(duration time.Duration) {
svc.monitorInterval = duration
}
func (svc *MasterServiceV2) Register() (err error) {
func (svc *MasterService) Register() (err error) {
nodeKey := svc.GetConfigService().GetNodeKey()
nodeName := svc.GetConfigService().GetNodeName()
node, err := service.NewModelServiceV2[models2.NodeV2]().GetOne(bson.M{"key": nodeKey}, nil)
@@ -181,15 +173,7 @@ func (svc *MasterServiceV2) Register() (err error) {
}
}
func (svc *MasterServiceV2) StopOnError() {
svc.stopOnError = true
}
func (svc *MasterServiceV2) GetServer() (svr interfaces.GrpcServer) {
return svc.server
}
func (svc *MasterServiceV2) monitor() (err error) {
func (svc *MasterService) monitor() (err error) {
// update master node status in db
if err := svc.updateMasterNodeStatus(); err != nil {
if err.Error() == mongo2.ErrNoDocuments.Error() {
@@ -238,7 +222,7 @@ func (svc *MasterServiceV2) monitor() (err error) {
return nil
}
func (svc *MasterServiceV2) getAllWorkerNodes() (nodes []models2.NodeV2, err error) {
func (svc *MasterService) getAllWorkerNodes() (nodes []models2.NodeV2, err error) {
query := bson.M{
"key": bson.M{"$ne": svc.cfgSvc.GetNodeKey()}, // not self
"active": true, // active
@@ -253,7 +237,7 @@ func (svc *MasterServiceV2) getAllWorkerNodes() (nodes []models2.NodeV2, err err
return nodes, nil
}
func (svc *MasterServiceV2) updateMasterNodeStatus() (err error) {
func (svc *MasterService) updateMasterNodeStatus() (err error) {
nodeKey := svc.GetConfigService().GetNodeKey()
node, err := service.NewModelServiceV2[models2.NodeV2]().GetOne(bson.M{"key": nodeKey}, nil)
if err != nil {
@@ -280,7 +264,7 @@ func (svc *MasterServiceV2) updateMasterNodeStatus() (err error) {
return nil
}
func (svc *MasterServiceV2) setWorkerNodeOffline(node *models2.NodeV2) {
func (svc *MasterService) setWorkerNodeOffline(node *models2.NodeV2) {
node.Status = constants.NodeStatusOffline
node.Active = false
err := backoff.Retry(func() error {
@@ -292,24 +276,28 @@ func (svc *MasterServiceV2) setWorkerNodeOffline(node *models2.NodeV2) {
svc.sendNotification(node)
}
func (svc *MasterServiceV2) subscribeNode(n *models2.NodeV2) (ok bool) {
_, err := svc.server.GetSubscribe("node:" + n.Key)
func (svc *MasterService) subscribeNode(n *models2.NodeV2) (ok bool) {
_, ok = svc.server.NodeSvr.GetSubscribeStream(n.Id)
return ok
}
func (svc *MasterService) pingNodeClient(n *models2.NodeV2) (ok bool) {
stream, ok := svc.server.NodeSvr.GetSubscribeStream(n.Id)
if !ok {
log.Errorf("cannot get worker node client[%s]", n.Key)
return false
}
err := stream.Send(&grpc.NodeServiceSubscribeResponse{
Code: grpc.NodeServiceSubscribeCode_PING,
})
if err != nil {
log.Errorf("cannot subscribe worker node[%s]: %v", n.Key, err)
log.Errorf("failed to ping worker node client[%s]: %v", n.Key, err)
return false
}
return true
}
func (svc *MasterServiceV2) pingNodeClient(n *models2.NodeV2) (ok bool) {
if err := svc.server.SendStreamMessage("node:"+n.Key, grpc.StreamMessageCode_PING); err != nil {
log.Errorf("cannot ping worker node client[%s]: %v", n.Key, err)
return false
}
return true
}
func (svc *MasterServiceV2) updateNodeAvailableRunners(node *models2.NodeV2) (err error) {
func (svc *MasterService) updateNodeAvailableRunners(node *models2.NodeV2) (err error) {
query := bson.M{
"node_id": node.Id,
"status": constants.TaskStatusRunning,
@@ -326,16 +314,16 @@ func (svc *MasterServiceV2) updateNodeAvailableRunners(node *models2.NodeV2) (er
return nil
}
func (svc *MasterServiceV2) sendNotification(node *models2.NodeV2) {
func (svc *MasterService) sendNotification(node *models2.NodeV2) {
if !utils.IsPro() {
return
}
go notification.GetNotificationServiceV2().SendNodeNotification(node)
}
func newMasterServiceV2() (res *MasterServiceV2, err error) {
func newMasterServiceV2() (res *MasterService, err error) {
// master service
svc := &MasterServiceV2{
svc := &MasterService{
cfgPath: config2.GetConfigPath(),
monitorInterval: 15 * time.Second,
stopOnError: false,
@@ -379,15 +367,15 @@ func newMasterServiceV2() (res *MasterServiceV2, err error) {
return svc, nil
}
var masterServiceV2 *MasterServiceV2
var masterServiceV2Once = new(sync.Once)
var masterService *MasterService
var masterServiceOnce = new(sync.Once)
func GetMasterServiceV2() (res *MasterServiceV2, err error) {
masterServiceV2Once.Do(func() {
masterServiceV2, err = newMasterServiceV2()
func GetMasterService() (res *MasterService, err error) {
masterServiceOnce.Do(func() {
masterService, err = newMasterServiceV2()
if err != nil {
log.Errorf("failed to get master service: %v", err)
}
})
return masterServiceV2, err
return masterService, err
}

View File

@@ -2,7 +2,7 @@ package service
import (
"context"
"encoding/json"
"errors"
"github.com/apex/log"
"github.com/crawlab-team/crawlab/core/config"
"github.com/crawlab-team/crawlab/core/grpc/client"
@@ -15,11 +15,12 @@ import (
"github.com/crawlab-team/crawlab/grpc"
"github.com/crawlab-team/crawlab/trace"
"go.mongodb.org/mongo-driver/bson"
"io"
"sync"
"time"
)
type WorkerServiceV2 struct {
type WorkerService struct {
// dependencies
cfgSvc interfaces.NodeConfigService
client *client.GrpcClientV2
@@ -31,16 +32,17 @@ type WorkerServiceV2 struct {
heartbeatInterval time.Duration
// internals
n *models.NodeV2
s grpc.NodeService_SubscribeClient
stopped bool
n *models.NodeV2
s grpc.NodeService_SubscribeClient
}
func (svc *WorkerServiceV2) Init() (err error) {
func (svc *WorkerService) Init() (err error) {
// do nothing
return nil
}
func (svc *WorkerServiceV2) Start() {
func (svc *WorkerService) Start() {
// start grpc client
if err := svc.client.Start(); err != nil {
panic(err)
@@ -49,8 +51,8 @@ func (svc *WorkerServiceV2) Start() {
// register to master
svc.Register()
// start receiving stream messages
go svc.Recv()
// subscribe
svc.Subscribe()
// start sending heartbeat to master
go svc.ReportStatus()
@@ -65,24 +67,23 @@ func (svc *WorkerServiceV2) Start() {
svc.Stop()
}
func (svc *WorkerServiceV2) Wait() {
func (svc *WorkerService) Wait() {
utils.DefaultWait()
}
func (svc *WorkerServiceV2) Stop() {
func (svc *WorkerService) Stop() {
svc.stopped = true
_ = svc.client.Stop()
svc.handlerSvc.Stop()
log.Infof("worker[%s] service has stopped", svc.cfgSvc.GetNodeKey())
}
func (svc *WorkerServiceV2) Register() {
func (svc *WorkerService) Register() {
ctx, cancel := svc.client.Context()
defer cancel()
_, err := svc.client.NodeClient.Register(ctx, &grpc.NodeServiceRegisterRequest{
Key: svc.cfgSvc.GetNodeKey(),
Name: svc.cfgSvc.GetNodeName(),
IsMaster: svc.cfgSvc.IsMaster(),
AuthKey: svc.cfgSvc.GetAuthKey(),
NodeKey: svc.cfgSvc.GetNodeKey(),
NodeName: svc.cfgSvc.GetNodeName(),
MaxRunners: int32(svc.cfgSvc.GetMaxRunners()),
})
if err != nil {
@@ -96,56 +97,7 @@ func (svc *WorkerServiceV2) Register() {
return
}
func (svc *WorkerServiceV2) Recv() {
msgCh := svc.client.GetMessageChannel()
for {
// return if client is closed
if svc.client.IsClosed() {
return
}
// receive message from channel
msg := <-msgCh
// handle message
if err := svc.handleStreamMessage(msg); err != nil {
continue
}
}
}
func (svc *WorkerServiceV2) handleStreamMessage(msg *grpc.StreamMessage) (err error) {
log.Debugf("[WorkerServiceV2] handle msg: %v", msg)
switch msg.Code {
case grpc.StreamMessageCode_PING:
_, err := svc.client.NodeClient.SendHeartbeat(context.Background(), &grpc.NodeServiceSendHeartbeatRequest{
Key: svc.cfgSvc.GetNodeKey(),
})
if err != nil {
return trace.TraceError(err)
}
case grpc.StreamMessageCode_RUN_TASK:
var t models.TaskV2
if err := json.Unmarshal(msg.Data, &t); err != nil {
return trace.TraceError(err)
}
if err := svc.handlerSvc.Run(t.Id); err != nil {
return trace.TraceError(err)
}
case grpc.StreamMessageCode_CANCEL_TASK:
var t models.TaskV2
if err := json.Unmarshal(msg.Data, &t); err != nil {
return trace.TraceError(err)
}
if err := svc.handlerSvc.Cancel(t.Id); err != nil {
return trace.TraceError(err)
}
}
return nil
}
func (svc *WorkerServiceV2) ReportStatus() {
func (svc *WorkerService) ReportStatus() {
ticker := time.NewTicker(svc.heartbeatInterval)
for {
// return if client is closed
@@ -162,46 +114,76 @@ func (svc *WorkerServiceV2) ReportStatus() {
}
}
func (svc *WorkerServiceV2) GetConfigService() (cfgSvc interfaces.NodeConfigService) {
func (svc *WorkerService) GetConfigService() (cfgSvc interfaces.NodeConfigService) {
return svc.cfgSvc
}
func (svc *WorkerServiceV2) GetConfigPath() (path string) {
func (svc *WorkerService) GetConfigPath() (path string) {
return svc.cfgPath
}
func (svc *WorkerServiceV2) SetConfigPath(path string) {
func (svc *WorkerService) SetConfigPath(path string) {
svc.cfgPath = path
}
func (svc *WorkerServiceV2) GetAddress() (address interfaces.Address) {
func (svc *WorkerService) GetAddress() (address interfaces.Address) {
return svc.address
}
func (svc *WorkerServiceV2) SetAddress(address interfaces.Address) {
func (svc *WorkerService) SetAddress(address interfaces.Address) {
svc.address = address
}
func (svc *WorkerServiceV2) SetHeartbeatInterval(duration time.Duration) {
func (svc *WorkerService) SetHeartbeatInterval(duration time.Duration) {
svc.heartbeatInterval = duration
}
func (svc *WorkerServiceV2) reportStatus() {
func (svc *WorkerService) Subscribe() {
stream, err := svc.client.NodeClient.Subscribe(context.Background(), &grpc.NodeServiceSubscribeRequest{
NodeKey: svc.cfgSvc.GetNodeKey(),
})
if err != nil {
log.Errorf("failed to subscribe to master: %v", err)
return
}
for {
if svc.stopped {
return
}
select {
default:
msg, err := stream.Recv()
if err != nil {
if errors.Is(err, io.EOF) {
return
}
log.Errorf("failed to receive message from master: %v", err)
continue
}
switch msg.Code {
case grpc.NodeServiceSubscribeCode_PING:
// do nothing
}
}
}
}
func (svc *WorkerService) reportStatus() {
ctx, cancel := context.WithTimeout(context.Background(), svc.heartbeatInterval)
defer cancel()
_, err := svc.client.NodeClient.SendHeartbeat(ctx, &grpc.NodeServiceSendHeartbeatRequest{
Key: svc.cfgSvc.GetNodeKey(),
NodeKey: svc.cfgSvc.GetNodeKey(),
})
if err != nil {
trace.PrintError(err)
}
}
var workerServiceV2 *WorkerServiceV2
var workerServiceV2 *WorkerService
var workerServiceV2Once = new(sync.Once)
func newWorkerServiceV2() (res *WorkerServiceV2, err error) {
svc := &WorkerServiceV2{
func newWorkerService() (res *WorkerService, err error) {
svc := &WorkerService{
cfgPath: config.GetConfigPath(),
heartbeatInterval: 15 * time.Second,
}
@@ -233,9 +215,9 @@ func newWorkerServiceV2() (res *WorkerServiceV2, err error) {
return svc, nil
}
func GetWorkerServiceV2() (res *WorkerServiceV2, err error) {
func GetWorkerService() (res *WorkerService, err error) {
workerServiceV2Once.Do(func() {
workerServiceV2, err = newWorkerServiceV2()
workerServiceV2, err = newWorkerService()
if err != nil {
log.Errorf("failed to get worker service: %v", err)
}

View File

@@ -7,7 +7,6 @@ import (
"errors"
"fmt"
"github.com/apex/log"
"github.com/cenkalti/backoff/v4"
"github.com/crawlab-team/crawlab/core/constants"
"github.com/crawlab-team/crawlab/core/entity"
"github.com/crawlab-team/crawlab/core/fs"
@@ -44,16 +43,16 @@ type RunnerV2 struct {
bufferSize int
// internals
cmd *exec.Cmd // process command instance
pid int // process id
tid primitive.ObjectID // task id
t *models.TaskV2 // task model.Task
s *models.SpiderV2 // spider model.Spider
ch chan constants.TaskSignal // channel to communicate between Service and RunnerV2
err error // standard process error
cwd string // working directory
c *client2.GrpcClientV2 // grpc client
sub grpc.TaskService_SubscribeClient // grpc task service stream client
cmd *exec.Cmd // process command instance
pid int // process id
tid primitive.ObjectID // task id
t *models.TaskV2 // task model.Task
s *models.SpiderV2 // spider model.Spider
ch chan constants.TaskSignal // channel to communicate between Service and RunnerV2
err error // standard process error
cwd string // working directory
c *client2.GrpcClientV2 // grpc client
conn grpc.TaskService_ConnectClient // grpc task service stream client
// log internals
scannerStdout *bufio.Reader
@@ -76,7 +75,7 @@ func (r *RunnerV2) Init() (err error) {
}
// grpc task service stream client
if err := r.initSub(); err != nil {
if err := r.initConnection(); err != nil {
return err
}
@@ -177,26 +176,19 @@ func (r *RunnerV2) Cancel(force bool) (err error) {
}
// make sure the process does not exist
op := func() error {
if exists, _ := process.PidExists(int32(r.pid)); exists {
ticker := time.NewTicker(1 * time.Second)
timeout := time.After(r.svc.GetCancelTimeout())
for {
select {
case <-timeout:
return errors.New(fmt.Sprintf("task process %d still exists", r.pid))
case <-ticker.C:
if exists, _ := process.PidExists(int32(r.pid)); exists {
return errors.New(fmt.Sprintf("task process %d still exists", r.pid))
}
return nil
}
return nil
}
ctx, cancel := context.WithTimeout(context.Background(), r.svc.GetExitWatchDuration())
defer cancel()
b := backoff.WithContext(backoff.NewConstantBackOff(1*time.Second), ctx)
if err := backoff.Retry(op, b); err != nil {
log.Errorf("Error canceling task %s: %v", r.tid, err)
return trace.TraceError(err)
}
return nil
}
// CleanUp clean up task runner
func (r *RunnerV2) CleanUp() (err error) {
return nil
}
func (r *RunnerV2) SetSubscribeTimeout(timeout time.Duration) {
@@ -537,8 +529,8 @@ func (r *RunnerV2) updateTask(status string, e error) (err error) {
return nil
}
func (r *RunnerV2) initSub() (err error) {
r.sub, err = r.c.TaskClient.Subscribe(context.Background())
func (r *RunnerV2) initConnection() (err error) {
r.conn, err = r.c.TaskClient.Connect(context.Background())
if err != nil {
return trace.TraceError(err)
}
@@ -546,20 +538,18 @@ func (r *RunnerV2) initSub() (err error) {
}
func (r *RunnerV2) writeLogLines(lines []string) {
data, err := json.Marshal(&entity.StreamMessageTaskData{
TaskId: r.tid,
Logs: lines,
})
linesBytes, err := json.Marshal(lines)
if err != nil {
trace.PrintError(err)
log.Errorf("Error marshaling log lines: %v", err)
return
}
msg := &grpc.StreamMessage{
Code: grpc.StreamMessageCode_INSERT_LOGS,
Data: data,
msg := &grpc.TaskServiceConnectRequest{
Code: grpc.TaskServiceConnectCode_INSERT_LOGS,
TaskId: r.tid.Hex(),
Data: linesBytes,
}
if err := r.sub.Send(msg); err != nil {
trace.PrintError(err)
if err := r.conn.Send(msg); err != nil {
log.Errorf("Error sending log lines: %v", err)
return
}
}

View File

@@ -2,8 +2,8 @@ package handler
import (
"context"
"encoding/json"
"errors"
"fmt"
"github.com/apex/log"
"github.com/crawlab-team/crawlab/core/constants"
errors2 "github.com/crawlab-team/crawlab/core/errors"
@@ -13,9 +13,11 @@ import (
models2 "github.com/crawlab-team/crawlab/core/models/models/v2"
"github.com/crawlab-team/crawlab/core/models/service"
nodeconfig "github.com/crawlab-team/crawlab/core/node/config"
"github.com/crawlab-team/crawlab/grpc"
"github.com/crawlab-team/crawlab/trace"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/bson/primitive"
"io"
"sync"
"time"
)
@@ -50,7 +52,7 @@ func (svc *ServiceV2) Start() {
}
go svc.ReportStatus()
go svc.Fetch()
go svc.FetchAndRunTasks()
}
func (svc *ServiceV2) Stop() {
@@ -58,12 +60,7 @@ func (svc *ServiceV2) Stop() {
}
func (svc *ServiceV2) Run(taskId primitive.ObjectID) (err error) {
return svc.run(taskId)
}
func (svc *ServiceV2) Reset() {
svc.mu.Lock()
defer svc.mu.Unlock()
return svc.runTask(taskId)
}
func (svc *ServiceV2) Cancel(taskId primitive.ObjectID, force bool) (err error) {
@@ -77,7 +74,7 @@ func (svc *ServiceV2) Cancel(taskId primitive.ObjectID, force bool) (err error)
return nil
}
func (svc *ServiceV2) Fetch() {
func (svc *ServiceV2) FetchAndRunTasks() {
ticker := time.NewTicker(svc.fetchInterval)
for {
if svc.stopped {
@@ -102,10 +99,9 @@ func (svc *ServiceV2) Fetch() {
continue
}
// fetch task
tid, err := svc.fetch()
// fetch task id
tid, err := svc.fetchTask()
if err != nil {
trace.PrintError(err)
continue
}
@@ -115,8 +111,7 @@ func (svc *ServiceV2) Fetch() {
}
// run task
if err := svc.run(tid); err != nil {
trace.PrintError(err)
if err := svc.runTask(tid); err != nil {
t, err := svc.GetTaskById(tid)
if err != nil && t.Status != constants.TaskStatusCancelled {
t.Error = err.Error()
@@ -281,30 +276,36 @@ func (svc *ServiceV2) reportStatus() (err error) {
return nil
}
func (svc *ServiceV2) fetch() (tid primitive.ObjectID, err error) {
func (svc *ServiceV2) fetchTask() (tid primitive.ObjectID, err error) {
ctx, cancel := context.WithTimeout(context.Background(), svc.fetchTimeout)
defer cancel()
res, err := svc.c.TaskClient.Fetch(ctx, svc.c.NewRequest(nil))
res, err := svc.c.TaskClient.FetchTask(ctx, svc.c.NewRequest(nil))
if err != nil {
return tid, trace.TraceError(err)
return primitive.NilObjectID, fmt.Errorf("fetchTask task error: %v", err)
}
if err := json.Unmarshal(res.Data, &tid); err != nil {
return tid, trace.TraceError(err)
// validate task id
tid, err = primitive.ObjectIDFromHex(res.GetTaskId())
if err != nil {
return primitive.NilObjectID, fmt.Errorf("invalid task id: %s", res.GetTaskId())
}
return tid, nil
}
func (svc *ServiceV2) run(taskId primitive.ObjectID) (err error) {
func (svc *ServiceV2) runTask(taskId primitive.ObjectID) (err error) {
// attempt to get runner from pool
_, ok := svc.runners.Load(taskId)
if ok {
return trace.TraceError(errors2.ErrorTaskAlreadyExists)
err = fmt.Errorf("task[%s] already exists", taskId.Hex())
log.Errorf("run task error: %v", err)
return err
}
// create a new task runner
r, err := NewTaskRunnerV2(taskId, svc)
if err != nil {
return trace.TraceError(err)
err = fmt.Errorf("failed to create task runner: %v", err)
log.Errorf("run task error: %v", err)
return err
}
// add runner to pool
@@ -312,16 +313,18 @@ func (svc *ServiceV2) run(taskId primitive.ObjectID) (err error) {
// create a goroutine to run task
go func() {
// delete runner from pool
defer svc.deleteRunner(r.GetTaskId())
defer func(r interfaces.TaskRunner) {
err := r.CleanUp()
if err != nil {
log.Errorf("task[%s] clean up error: %v", r.GetTaskId().Hex(), err)
}
}(r)
// run task process (blocking)
// error or finish after task runner ends
// get subscription stream
stopCh := make(chan struct{})
stream, err := svc.subscribeTask(r.GetTaskId())
if err == nil {
// create a goroutine to handle stream messages
go svc.handleStreamMessages(r.GetTaskId(), stream, stopCh)
} else {
log.Errorf("failed to subscribe task[%s]: %v", r.GetTaskId().Hex(), err)
log.Warnf("task[%s] will not be able to receive stream messages", r.GetTaskId().Hex())
}
// run task process (blocking) error or finish after task runner ends
if err := r.Run(); err != nil {
switch {
case errors.Is(err, constants.ErrTaskError):
@@ -333,11 +336,64 @@ func (svc *ServiceV2) run(taskId primitive.ObjectID) (err error) {
}
}
log.Infof("task[%s] finished", r.GetTaskId().Hex())
// send stopCh signal to stream message handler
stopCh <- struct{}{}
// delete runner from pool
svc.deleteRunner(r.GetTaskId())
}()
return nil
}
func (svc *ServiceV2) subscribeTask(taskId primitive.ObjectID) (stream grpc.TaskService_SubscribeClient, err error) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
req := &grpc.TaskServiceSubscribeRequest{
TaskId: taskId.Hex(),
}
stream, err = svc.c.TaskClient.Subscribe(ctx, req)
if err != nil {
log.Errorf("failed to subscribe task[%s]: %v", taskId.Hex(), err)
return nil, err
}
return stream, nil
}
func (svc *ServiceV2) handleStreamMessages(id primitive.ObjectID, stream grpc.TaskService_SubscribeClient, stopCh chan struct{}) {
for {
select {
case <-stopCh:
err := stream.CloseSend()
if err != nil {
log.Errorf("task[%s] failed to close stream: %v", id.Hex(), err)
return
}
return
default:
msg, err := stream.Recv()
if err != nil {
if errors.Is(err, io.EOF) {
return
}
log.Errorf("task[%s] stream error: %v", id.Hex(), err)
continue
}
switch msg.Code {
case grpc.TaskServiceSubscribeCode_CANCEL:
log.Infof("task[%s] received cancel signal", id.Hex())
go func() {
if err := svc.Cancel(id, true); err != nil {
log.Errorf("task[%s] failed to cancel: %v", id.Hex(), err)
}
log.Infof("task[%s] cancelled", id.Hex())
}()
}
}
}
}
func newTaskHandlerServiceV2() (svc2 *ServiceV2, err error) {
// service
svc := &ServiceV2{

View File

@@ -23,7 +23,7 @@ import (
type ServiceV2 struct {
// dependencies
nodeCfgSvc interfaces.NodeConfigService
svr *server.GrpcServerV2
svr *server.GrpcServer
handlerSvc *handler.ServiceV2
// settings