refactor: updated grpc services

This commit is contained in:
Marvin Zhang
2024-10-30 18:42:23 +08:00
parent 789f71fd80
commit fa1433007f
64 changed files with 2704 additions and 5132 deletions

View File

@@ -7,7 +7,6 @@ import (
"errors"
"fmt"
"github.com/apex/log"
"github.com/cenkalti/backoff/v4"
"github.com/crawlab-team/crawlab/core/constants"
"github.com/crawlab-team/crawlab/core/entity"
"github.com/crawlab-team/crawlab/core/fs"
@@ -44,16 +43,16 @@ type RunnerV2 struct {
bufferSize int
// internals
cmd *exec.Cmd // process command instance
pid int // process id
tid primitive.ObjectID // task id
t *models.TaskV2 // task model.Task
s *models.SpiderV2 // spider model.Spider
ch chan constants.TaskSignal // channel to communicate between Service and RunnerV2
err error // standard process error
cwd string // working directory
c *client2.GrpcClientV2 // grpc client
sub grpc.TaskService_SubscribeClient // grpc task service stream client
cmd *exec.Cmd // process command instance
pid int // process id
tid primitive.ObjectID // task id
t *models.TaskV2 // task model.Task
s *models.SpiderV2 // spider model.Spider
ch chan constants.TaskSignal // channel to communicate between Service and RunnerV2
err error // standard process error
cwd string // working directory
c *client2.GrpcClientV2 // grpc client
conn grpc.TaskService_ConnectClient // grpc task service stream client
// log internals
scannerStdout *bufio.Reader
@@ -76,7 +75,7 @@ func (r *RunnerV2) Init() (err error) {
}
// grpc task service stream client
if err := r.initSub(); err != nil {
if err := r.initConnection(); err != nil {
return err
}
@@ -177,26 +176,19 @@ func (r *RunnerV2) Cancel(force bool) (err error) {
}
// make sure the process does not exist
op := func() error {
if exists, _ := process.PidExists(int32(r.pid)); exists {
ticker := time.NewTicker(1 * time.Second)
timeout := time.After(r.svc.GetCancelTimeout())
for {
select {
case <-timeout:
return errors.New(fmt.Sprintf("task process %d still exists", r.pid))
case <-ticker.C:
if exists, _ := process.PidExists(int32(r.pid)); exists {
return errors.New(fmt.Sprintf("task process %d still exists", r.pid))
}
return nil
}
return nil
}
ctx, cancel := context.WithTimeout(context.Background(), r.svc.GetExitWatchDuration())
defer cancel()
b := backoff.WithContext(backoff.NewConstantBackOff(1*time.Second), ctx)
if err := backoff.Retry(op, b); err != nil {
log.Errorf("Error canceling task %s: %v", r.tid, err)
return trace.TraceError(err)
}
return nil
}
// CleanUp clean up task runner
func (r *RunnerV2) CleanUp() (err error) {
return nil
}
func (r *RunnerV2) SetSubscribeTimeout(timeout time.Duration) {
@@ -537,8 +529,8 @@ func (r *RunnerV2) updateTask(status string, e error) (err error) {
return nil
}
func (r *RunnerV2) initSub() (err error) {
r.sub, err = r.c.TaskClient.Subscribe(context.Background())
func (r *RunnerV2) initConnection() (err error) {
r.conn, err = r.c.TaskClient.Connect(context.Background())
if err != nil {
return trace.TraceError(err)
}
@@ -546,20 +538,18 @@ func (r *RunnerV2) initSub() (err error) {
}
func (r *RunnerV2) writeLogLines(lines []string) {
data, err := json.Marshal(&entity.StreamMessageTaskData{
TaskId: r.tid,
Logs: lines,
})
linesBytes, err := json.Marshal(lines)
if err != nil {
trace.PrintError(err)
log.Errorf("Error marshaling log lines: %v", err)
return
}
msg := &grpc.StreamMessage{
Code: grpc.StreamMessageCode_INSERT_LOGS,
Data: data,
msg := &grpc.TaskServiceConnectRequest{
Code: grpc.TaskServiceConnectCode_INSERT_LOGS,
TaskId: r.tid.Hex(),
Data: linesBytes,
}
if err := r.sub.Send(msg); err != nil {
trace.PrintError(err)
if err := r.conn.Send(msg); err != nil {
log.Errorf("Error sending log lines: %v", err)
return
}
}

View File

@@ -2,8 +2,8 @@ package handler
import (
"context"
"encoding/json"
"errors"
"fmt"
"github.com/apex/log"
"github.com/crawlab-team/crawlab/core/constants"
errors2 "github.com/crawlab-team/crawlab/core/errors"
@@ -13,9 +13,11 @@ import (
models2 "github.com/crawlab-team/crawlab/core/models/models/v2"
"github.com/crawlab-team/crawlab/core/models/service"
nodeconfig "github.com/crawlab-team/crawlab/core/node/config"
"github.com/crawlab-team/crawlab/grpc"
"github.com/crawlab-team/crawlab/trace"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/bson/primitive"
"io"
"sync"
"time"
)
@@ -50,7 +52,7 @@ func (svc *ServiceV2) Start() {
}
go svc.ReportStatus()
go svc.Fetch()
go svc.FetchAndRunTasks()
}
func (svc *ServiceV2) Stop() {
@@ -58,12 +60,7 @@ func (svc *ServiceV2) Stop() {
}
func (svc *ServiceV2) Run(taskId primitive.ObjectID) (err error) {
return svc.run(taskId)
}
func (svc *ServiceV2) Reset() {
svc.mu.Lock()
defer svc.mu.Unlock()
return svc.runTask(taskId)
}
func (svc *ServiceV2) Cancel(taskId primitive.ObjectID, force bool) (err error) {
@@ -77,7 +74,7 @@ func (svc *ServiceV2) Cancel(taskId primitive.ObjectID, force bool) (err error)
return nil
}
func (svc *ServiceV2) Fetch() {
func (svc *ServiceV2) FetchAndRunTasks() {
ticker := time.NewTicker(svc.fetchInterval)
for {
if svc.stopped {
@@ -102,10 +99,9 @@ func (svc *ServiceV2) Fetch() {
continue
}
// fetch task
tid, err := svc.fetch()
// fetch task id
tid, err := svc.fetchTask()
if err != nil {
trace.PrintError(err)
continue
}
@@ -115,8 +111,7 @@ func (svc *ServiceV2) Fetch() {
}
// run task
if err := svc.run(tid); err != nil {
trace.PrintError(err)
if err := svc.runTask(tid); err != nil {
t, err := svc.GetTaskById(tid)
if err != nil && t.Status != constants.TaskStatusCancelled {
t.Error = err.Error()
@@ -281,30 +276,36 @@ func (svc *ServiceV2) reportStatus() (err error) {
return nil
}
func (svc *ServiceV2) fetch() (tid primitive.ObjectID, err error) {
func (svc *ServiceV2) fetchTask() (tid primitive.ObjectID, err error) {
ctx, cancel := context.WithTimeout(context.Background(), svc.fetchTimeout)
defer cancel()
res, err := svc.c.TaskClient.Fetch(ctx, svc.c.NewRequest(nil))
res, err := svc.c.TaskClient.FetchTask(ctx, svc.c.NewRequest(nil))
if err != nil {
return tid, trace.TraceError(err)
return primitive.NilObjectID, fmt.Errorf("fetchTask task error: %v", err)
}
if err := json.Unmarshal(res.Data, &tid); err != nil {
return tid, trace.TraceError(err)
// validate task id
tid, err = primitive.ObjectIDFromHex(res.GetTaskId())
if err != nil {
return primitive.NilObjectID, fmt.Errorf("invalid task id: %s", res.GetTaskId())
}
return tid, nil
}
func (svc *ServiceV2) run(taskId primitive.ObjectID) (err error) {
func (svc *ServiceV2) runTask(taskId primitive.ObjectID) (err error) {
// attempt to get runner from pool
_, ok := svc.runners.Load(taskId)
if ok {
return trace.TraceError(errors2.ErrorTaskAlreadyExists)
err = fmt.Errorf("task[%s] already exists", taskId.Hex())
log.Errorf("run task error: %v", err)
return err
}
// create a new task runner
r, err := NewTaskRunnerV2(taskId, svc)
if err != nil {
return trace.TraceError(err)
err = fmt.Errorf("failed to create task runner: %v", err)
log.Errorf("run task error: %v", err)
return err
}
// add runner to pool
@@ -312,16 +313,18 @@ func (svc *ServiceV2) run(taskId primitive.ObjectID) (err error) {
// create a goroutine to run task
go func() {
// delete runner from pool
defer svc.deleteRunner(r.GetTaskId())
defer func(r interfaces.TaskRunner) {
err := r.CleanUp()
if err != nil {
log.Errorf("task[%s] clean up error: %v", r.GetTaskId().Hex(), err)
}
}(r)
// run task process (blocking)
// error or finish after task runner ends
// get subscription stream
stopCh := make(chan struct{})
stream, err := svc.subscribeTask(r.GetTaskId())
if err == nil {
// create a goroutine to handle stream messages
go svc.handleStreamMessages(r.GetTaskId(), stream, stopCh)
} else {
log.Errorf("failed to subscribe task[%s]: %v", r.GetTaskId().Hex(), err)
log.Warnf("task[%s] will not be able to receive stream messages", r.GetTaskId().Hex())
}
// run task process (blocking) error or finish after task runner ends
if err := r.Run(); err != nil {
switch {
case errors.Is(err, constants.ErrTaskError):
@@ -333,11 +336,64 @@ func (svc *ServiceV2) run(taskId primitive.ObjectID) (err error) {
}
}
log.Infof("task[%s] finished", r.GetTaskId().Hex())
// send stopCh signal to stream message handler
stopCh <- struct{}{}
// delete runner from pool
svc.deleteRunner(r.GetTaskId())
}()
return nil
}
func (svc *ServiceV2) subscribeTask(taskId primitive.ObjectID) (stream grpc.TaskService_SubscribeClient, err error) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
req := &grpc.TaskServiceSubscribeRequest{
TaskId: taskId.Hex(),
}
stream, err = svc.c.TaskClient.Subscribe(ctx, req)
if err != nil {
log.Errorf("failed to subscribe task[%s]: %v", taskId.Hex(), err)
return nil, err
}
return stream, nil
}
func (svc *ServiceV2) handleStreamMessages(id primitive.ObjectID, stream grpc.TaskService_SubscribeClient, stopCh chan struct{}) {
for {
select {
case <-stopCh:
err := stream.CloseSend()
if err != nil {
log.Errorf("task[%s] failed to close stream: %v", id.Hex(), err)
return
}
return
default:
msg, err := stream.Recv()
if err != nil {
if errors.Is(err, io.EOF) {
return
}
log.Errorf("task[%s] stream error: %v", id.Hex(), err)
continue
}
switch msg.Code {
case grpc.TaskServiceSubscribeCode_CANCEL:
log.Infof("task[%s] received cancel signal", id.Hex())
go func() {
if err := svc.Cancel(id, true); err != nil {
log.Errorf("task[%s] failed to cancel: %v", id.Hex(), err)
}
log.Infof("task[%s] cancelled", id.Hex())
}()
}
}
}
}
func newTaskHandlerServiceV2() (svc2 *ServiceV2, err error) {
// service
svc := &ServiceV2{