mirror of
https://github.com/crawlab-team/crawlab.git
synced 2026-01-22 17:31:03 +01:00
refactor: updated grpc services
This commit is contained in:
@@ -7,7 +7,6 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/apex/log"
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
"github.com/crawlab-team/crawlab/core/constants"
|
||||
"github.com/crawlab-team/crawlab/core/entity"
|
||||
"github.com/crawlab-team/crawlab/core/fs"
|
||||
@@ -44,16 +43,16 @@ type RunnerV2 struct {
|
||||
bufferSize int
|
||||
|
||||
// internals
|
||||
cmd *exec.Cmd // process command instance
|
||||
pid int // process id
|
||||
tid primitive.ObjectID // task id
|
||||
t *models.TaskV2 // task model.Task
|
||||
s *models.SpiderV2 // spider model.Spider
|
||||
ch chan constants.TaskSignal // channel to communicate between Service and RunnerV2
|
||||
err error // standard process error
|
||||
cwd string // working directory
|
||||
c *client2.GrpcClientV2 // grpc client
|
||||
sub grpc.TaskService_SubscribeClient // grpc task service stream client
|
||||
cmd *exec.Cmd // process command instance
|
||||
pid int // process id
|
||||
tid primitive.ObjectID // task id
|
||||
t *models.TaskV2 // task model.Task
|
||||
s *models.SpiderV2 // spider model.Spider
|
||||
ch chan constants.TaskSignal // channel to communicate between Service and RunnerV2
|
||||
err error // standard process error
|
||||
cwd string // working directory
|
||||
c *client2.GrpcClientV2 // grpc client
|
||||
conn grpc.TaskService_ConnectClient // grpc task service stream client
|
||||
|
||||
// log internals
|
||||
scannerStdout *bufio.Reader
|
||||
@@ -76,7 +75,7 @@ func (r *RunnerV2) Init() (err error) {
|
||||
}
|
||||
|
||||
// grpc task service stream client
|
||||
if err := r.initSub(); err != nil {
|
||||
if err := r.initConnection(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -177,26 +176,19 @@ func (r *RunnerV2) Cancel(force bool) (err error) {
|
||||
}
|
||||
|
||||
// make sure the process does not exist
|
||||
op := func() error {
|
||||
if exists, _ := process.PidExists(int32(r.pid)); exists {
|
||||
ticker := time.NewTicker(1 * time.Second)
|
||||
timeout := time.After(r.svc.GetCancelTimeout())
|
||||
for {
|
||||
select {
|
||||
case <-timeout:
|
||||
return errors.New(fmt.Sprintf("task process %d still exists", r.pid))
|
||||
case <-ticker.C:
|
||||
if exists, _ := process.PidExists(int32(r.pid)); exists {
|
||||
return errors.New(fmt.Sprintf("task process %d still exists", r.pid))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), r.svc.GetExitWatchDuration())
|
||||
defer cancel()
|
||||
b := backoff.WithContext(backoff.NewConstantBackOff(1*time.Second), ctx)
|
||||
if err := backoff.Retry(op, b); err != nil {
|
||||
log.Errorf("Error canceling task %s: %v", r.tid, err)
|
||||
return trace.TraceError(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CleanUp clean up task runner
|
||||
func (r *RunnerV2) CleanUp() (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *RunnerV2) SetSubscribeTimeout(timeout time.Duration) {
|
||||
@@ -537,8 +529,8 @@ func (r *RunnerV2) updateTask(status string, e error) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *RunnerV2) initSub() (err error) {
|
||||
r.sub, err = r.c.TaskClient.Subscribe(context.Background())
|
||||
func (r *RunnerV2) initConnection() (err error) {
|
||||
r.conn, err = r.c.TaskClient.Connect(context.Background())
|
||||
if err != nil {
|
||||
return trace.TraceError(err)
|
||||
}
|
||||
@@ -546,20 +538,18 @@ func (r *RunnerV2) initSub() (err error) {
|
||||
}
|
||||
|
||||
func (r *RunnerV2) writeLogLines(lines []string) {
|
||||
data, err := json.Marshal(&entity.StreamMessageTaskData{
|
||||
TaskId: r.tid,
|
||||
Logs: lines,
|
||||
})
|
||||
linesBytes, err := json.Marshal(lines)
|
||||
if err != nil {
|
||||
trace.PrintError(err)
|
||||
log.Errorf("Error marshaling log lines: %v", err)
|
||||
return
|
||||
}
|
||||
msg := &grpc.StreamMessage{
|
||||
Code: grpc.StreamMessageCode_INSERT_LOGS,
|
||||
Data: data,
|
||||
msg := &grpc.TaskServiceConnectRequest{
|
||||
Code: grpc.TaskServiceConnectCode_INSERT_LOGS,
|
||||
TaskId: r.tid.Hex(),
|
||||
Data: linesBytes,
|
||||
}
|
||||
if err := r.sub.Send(msg); err != nil {
|
||||
trace.PrintError(err)
|
||||
if err := r.conn.Send(msg); err != nil {
|
||||
log.Errorf("Error sending log lines: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,8 +2,8 @@ package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/apex/log"
|
||||
"github.com/crawlab-team/crawlab/core/constants"
|
||||
errors2 "github.com/crawlab-team/crawlab/core/errors"
|
||||
@@ -13,9 +13,11 @@ import (
|
||||
models2 "github.com/crawlab-team/crawlab/core/models/models/v2"
|
||||
"github.com/crawlab-team/crawlab/core/models/service"
|
||||
nodeconfig "github.com/crawlab-team/crawlab/core/node/config"
|
||||
"github.com/crawlab-team/crawlab/grpc"
|
||||
"github.com/crawlab-team/crawlab/trace"
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
"io"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
@@ -50,7 +52,7 @@ func (svc *ServiceV2) Start() {
|
||||
}
|
||||
|
||||
go svc.ReportStatus()
|
||||
go svc.Fetch()
|
||||
go svc.FetchAndRunTasks()
|
||||
}
|
||||
|
||||
func (svc *ServiceV2) Stop() {
|
||||
@@ -58,12 +60,7 @@ func (svc *ServiceV2) Stop() {
|
||||
}
|
||||
|
||||
func (svc *ServiceV2) Run(taskId primitive.ObjectID) (err error) {
|
||||
return svc.run(taskId)
|
||||
}
|
||||
|
||||
func (svc *ServiceV2) Reset() {
|
||||
svc.mu.Lock()
|
||||
defer svc.mu.Unlock()
|
||||
return svc.runTask(taskId)
|
||||
}
|
||||
|
||||
func (svc *ServiceV2) Cancel(taskId primitive.ObjectID, force bool) (err error) {
|
||||
@@ -77,7 +74,7 @@ func (svc *ServiceV2) Cancel(taskId primitive.ObjectID, force bool) (err error)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (svc *ServiceV2) Fetch() {
|
||||
func (svc *ServiceV2) FetchAndRunTasks() {
|
||||
ticker := time.NewTicker(svc.fetchInterval)
|
||||
for {
|
||||
if svc.stopped {
|
||||
@@ -102,10 +99,9 @@ func (svc *ServiceV2) Fetch() {
|
||||
continue
|
||||
}
|
||||
|
||||
// fetch task
|
||||
tid, err := svc.fetch()
|
||||
// fetch task id
|
||||
tid, err := svc.fetchTask()
|
||||
if err != nil {
|
||||
trace.PrintError(err)
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -115,8 +111,7 @@ func (svc *ServiceV2) Fetch() {
|
||||
}
|
||||
|
||||
// run task
|
||||
if err := svc.run(tid); err != nil {
|
||||
trace.PrintError(err)
|
||||
if err := svc.runTask(tid); err != nil {
|
||||
t, err := svc.GetTaskById(tid)
|
||||
if err != nil && t.Status != constants.TaskStatusCancelled {
|
||||
t.Error = err.Error()
|
||||
@@ -281,30 +276,36 @@ func (svc *ServiceV2) reportStatus() (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (svc *ServiceV2) fetch() (tid primitive.ObjectID, err error) {
|
||||
func (svc *ServiceV2) fetchTask() (tid primitive.ObjectID, err error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), svc.fetchTimeout)
|
||||
defer cancel()
|
||||
res, err := svc.c.TaskClient.Fetch(ctx, svc.c.NewRequest(nil))
|
||||
res, err := svc.c.TaskClient.FetchTask(ctx, svc.c.NewRequest(nil))
|
||||
if err != nil {
|
||||
return tid, trace.TraceError(err)
|
||||
return primitive.NilObjectID, fmt.Errorf("fetchTask task error: %v", err)
|
||||
}
|
||||
if err := json.Unmarshal(res.Data, &tid); err != nil {
|
||||
return tid, trace.TraceError(err)
|
||||
// validate task id
|
||||
tid, err = primitive.ObjectIDFromHex(res.GetTaskId())
|
||||
if err != nil {
|
||||
return primitive.NilObjectID, fmt.Errorf("invalid task id: %s", res.GetTaskId())
|
||||
}
|
||||
return tid, nil
|
||||
}
|
||||
|
||||
func (svc *ServiceV2) run(taskId primitive.ObjectID) (err error) {
|
||||
func (svc *ServiceV2) runTask(taskId primitive.ObjectID) (err error) {
|
||||
// attempt to get runner from pool
|
||||
_, ok := svc.runners.Load(taskId)
|
||||
if ok {
|
||||
return trace.TraceError(errors2.ErrorTaskAlreadyExists)
|
||||
err = fmt.Errorf("task[%s] already exists", taskId.Hex())
|
||||
log.Errorf("run task error: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// create a new task runner
|
||||
r, err := NewTaskRunnerV2(taskId, svc)
|
||||
if err != nil {
|
||||
return trace.TraceError(err)
|
||||
err = fmt.Errorf("failed to create task runner: %v", err)
|
||||
log.Errorf("run task error: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// add runner to pool
|
||||
@@ -312,16 +313,18 @@ func (svc *ServiceV2) run(taskId primitive.ObjectID) (err error) {
|
||||
|
||||
// create a goroutine to run task
|
||||
go func() {
|
||||
// delete runner from pool
|
||||
defer svc.deleteRunner(r.GetTaskId())
|
||||
defer func(r interfaces.TaskRunner) {
|
||||
err := r.CleanUp()
|
||||
if err != nil {
|
||||
log.Errorf("task[%s] clean up error: %v", r.GetTaskId().Hex(), err)
|
||||
}
|
||||
}(r)
|
||||
// run task process (blocking)
|
||||
// error or finish after task runner ends
|
||||
// get subscription stream
|
||||
stopCh := make(chan struct{})
|
||||
stream, err := svc.subscribeTask(r.GetTaskId())
|
||||
if err == nil {
|
||||
// create a goroutine to handle stream messages
|
||||
go svc.handleStreamMessages(r.GetTaskId(), stream, stopCh)
|
||||
} else {
|
||||
log.Errorf("failed to subscribe task[%s]: %v", r.GetTaskId().Hex(), err)
|
||||
log.Warnf("task[%s] will not be able to receive stream messages", r.GetTaskId().Hex())
|
||||
}
|
||||
|
||||
// run task process (blocking) error or finish after task runner ends
|
||||
if err := r.Run(); err != nil {
|
||||
switch {
|
||||
case errors.Is(err, constants.ErrTaskError):
|
||||
@@ -333,11 +336,64 @@ func (svc *ServiceV2) run(taskId primitive.ObjectID) (err error) {
|
||||
}
|
||||
}
|
||||
log.Infof("task[%s] finished", r.GetTaskId().Hex())
|
||||
|
||||
// send stopCh signal to stream message handler
|
||||
stopCh <- struct{}{}
|
||||
|
||||
// delete runner from pool
|
||||
svc.deleteRunner(r.GetTaskId())
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (svc *ServiceV2) subscribeTask(taskId primitive.ObjectID) (stream grpc.TaskService_SubscribeClient, err error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
req := &grpc.TaskServiceSubscribeRequest{
|
||||
TaskId: taskId.Hex(),
|
||||
}
|
||||
stream, err = svc.c.TaskClient.Subscribe(ctx, req)
|
||||
if err != nil {
|
||||
log.Errorf("failed to subscribe task[%s]: %v", taskId.Hex(), err)
|
||||
return nil, err
|
||||
}
|
||||
return stream, nil
|
||||
}
|
||||
|
||||
func (svc *ServiceV2) handleStreamMessages(id primitive.ObjectID, stream grpc.TaskService_SubscribeClient, stopCh chan struct{}) {
|
||||
for {
|
||||
select {
|
||||
case <-stopCh:
|
||||
err := stream.CloseSend()
|
||||
if err != nil {
|
||||
log.Errorf("task[%s] failed to close stream: %v", id.Hex(), err)
|
||||
return
|
||||
}
|
||||
return
|
||||
default:
|
||||
msg, err := stream.Recv()
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
return
|
||||
}
|
||||
log.Errorf("task[%s] stream error: %v", id.Hex(), err)
|
||||
continue
|
||||
}
|
||||
switch msg.Code {
|
||||
case grpc.TaskServiceSubscribeCode_CANCEL:
|
||||
log.Infof("task[%s] received cancel signal", id.Hex())
|
||||
go func() {
|
||||
if err := svc.Cancel(id, true); err != nil {
|
||||
log.Errorf("task[%s] failed to cancel: %v", id.Hex(), err)
|
||||
}
|
||||
log.Infof("task[%s] cancelled", id.Hex())
|
||||
}()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func newTaskHandlerServiceV2() (svc2 *ServiceV2, err error) {
|
||||
// service
|
||||
svc := &ServiceV2{
|
||||
|
||||
Reference in New Issue
Block a user