package scheduler

import (
    "context"
    errors2 "errors"
    "fmt"
    "sync"
    "time"

    "github.com/crawlab-team/crawlab/core/constants"
    "github.com/crawlab-team/crawlab/core/grpc/server"
    "github.com/crawlab-team/crawlab/core/interfaces"
    "github.com/crawlab-team/crawlab/core/models/models"
    "github.com/crawlab-team/crawlab/core/models/service"
    "github.com/crawlab-team/crawlab/core/task/handler"
    "github.com/crawlab-team/crawlab/core/utils"
    "github.com/crawlab-team/crawlab/grpc"
    "go.mongodb.org/mongo-driver/bson"
    "go.mongodb.org/mongo-driver/bson/primitive"
    mongo2 "go.mongodb.org/mongo-driver/mongo"
)

type Service struct {
    // dependencies
    svr        *server.GrpcServer
    handlerSvc *handler.Service

    // settings
    interval time.Duration

    // internals for lifecycle management
    ctx     context.Context
    cancel  context.CancelFunc
    stopped bool
    mu      sync.RWMutex
    wg      sync.WaitGroup

    interfaces.Logger
}

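// Start initializes the service context, launches the background goroutines
// for task status initialization and periodic cleanup, and then blocks via
// utils.DefaultWait().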
func (svc *Service) Start() {
    // Initialize context for graceful shutdown
    svc.ctx, svc.cancel = context.WithCancel(context.Background())

    // Start background goroutines with proper tracking
    svc.wg.Add(2)
    go svc.initTaskStatus()
    go svc.cleanupTasks()

    svc.Infof("Task scheduler service started")
    utils.DefaultWait()
}

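// Stop cancels the service context to signal all background goroutines to
// exit and waits up to 30 seconds for them to finish. Repeated calls are
// no-ops.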
func (svc *Service) Stop() {
    svc.mu.Lock()
    if svc.stopped {
        svc.mu.Unlock()
        return
    }
    svc.stopped = true
    svc.mu.Unlock()

    svc.Infof("Stopping task scheduler service...")

    // Cancel context to signal all goroutines to stop
    if svc.cancel != nil {
        svc.cancel()
    }

    // Wait for all background goroutines to finish
    done := make(chan struct{})
    go func() {
        svc.wg.Wait()
        close(done)
    }()

    // Give goroutines time to finish gracefully, then force stop
    select {
    case <-done:
        svc.Infof("All goroutines stopped gracefully")
    case <-time.After(30 * time.Second):
        svc.Warnf("Some goroutines did not stop gracefully within timeout")
    }

    svc.Infof("Task scheduler service stopped")
}

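// Enqueue inserts a new task with status "pending" along with its task stat
// record, and returns the persisted task.
//
// Minimal usage sketch (nodeId and userId are placeholder ObjectIDs, not
// defined in this file):
//
//  t, err := GetTaskSchedulerService().Enqueue(&models.Task{NodeId: nodeId}, userId)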
func (svc *Service) Enqueue(t *models.Task, by primitive.ObjectID) (t2 *models.Task, err error) {
    // set task status
    t.Status = constants.TaskStatusPending
    t.SetCreated(by)
    t.SetUpdated(by)

    // add task
    taskModelSvc := service.NewModelService[models.Task]()
    t.Id, err = taskModelSvc.InsertOne(*t)
    if err != nil {
        return nil, err
    }

    // task stat
    ts := models.TaskStat{}
    ts.SetId(t.Id)
    ts.SetCreated(by)
    ts.SetUpdated(by)

    // add task stat
    _, err = service.NewModelService[models.TaskStat]().InsertOne(ts)
    if err != nil {
        svc.Errorf("failed to add task stat: %s", t.Id.Hex())
        return nil, err
    }

    // success
    return t, nil
}

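// Cancel cancels the task with the given id: pending tasks are marked as
// cancelled directly, tasks running on the master node are cancelled via the
// local handler service, and tasks running on worker nodes receive a cancel
// signal over their gRPC subscribe stream.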
func (svc *Service) Cancel(id, by primitive.ObjectID, force bool) (err error) {
    // task
    t, err := service.NewModelService[models.Task]().GetById(id)
    if err != nil {
        svc.Errorf("task not found: %s", id.Hex())
        return err
    }

    // initial status
    initialStatus := t.Status

    // set status of pending tasks as "cancelled"
    if initialStatus == constants.TaskStatusPending {
        t.Status = constants.TaskStatusCancelled
        return svc.SaveTask(t, by)
    }

    // whether task is running on master node
    isMasterTask, err := svc.isMasterNode(t)
    if err != nil {
        err := fmt.Errorf("failed to check if task is running on master node: %s", t.Id.Hex())
        t.Status = constants.TaskStatusAbnormal
        t.Error = err.Error()
        return svc.SaveTask(t, by)
    }

    if isMasterTask {
        // cancel task on master
        return svc.cancelOnMaster(t, by, force)
    } else {
        // send to cancel task on worker nodes
        return svc.cancelOnWorker(t, by, force)
    }
}

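// cancelOnMaster cancels a task running on the master node through the task
// handler service and marks it as cancelled.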
func (svc *Service) cancelOnMaster(t *models.Task, by primitive.ObjectID, force bool) (err error) {
    if err := svc.handlerSvc.Cancel(t.Id, force); err != nil {
        svc.Errorf("failed to cancel task (%s) on master: %v", t.Id.Hex(), err)
        return err
    }

    // set task status as "cancelled"
    t.Status = constants.TaskStatusCancelled
    return svc.SaveTask(t, by)
}

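// cancelOnWorker sends a cancel signal to the worker node running the task
// via its gRPC subscribe stream. If the stream is missing, or the send fails
// or times out, the task is marked as cancelled anyway so it does not get
// stuck. For force cancellation, the task status is re-checked after a short
// delay and overridden if the task still reports as running.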
func (svc *Service) cancelOnWorker(t *models.Task, by primitive.ObjectID, force bool) (err error) {
    // get subscribe stream
    stream, ok := svc.svr.TaskSvr.GetSubscribeStream(t.Id)
    if !ok {
        svc.Warnf("stream not found for task (%s), task may already be finished or connection lost", t.Id.Hex())
        // Task might have finished or connection lost, mark as cancelled anyway
        t.Status = constants.TaskStatusCancelled
        t.Error = "cancel signal could not be delivered - stream not found"
        return svc.SaveTask(t, by)
    }

    // send cancel request with timeout
    ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
    defer cancel()

    // Create a channel to handle the send operation
    sendDone := make(chan error, 1)
    go func() {
        err := stream.Send(&grpc.TaskServiceSubscribeResponse{
            Code:   grpc.TaskServiceSubscribeCode_CANCEL,
            TaskId: t.Id.Hex(),
            Force:  force,
        })
        sendDone <- err
    }()

    select {
    case err = <-sendDone:
        if err != nil {
            svc.Errorf("failed to send cancel task (%s) request to worker: %v", t.Id.Hex(), err)
            // If sending failed, still mark task as cancelled to avoid stuck state
            t.Status = constants.TaskStatusCancelled
            t.Error = fmt.Sprintf("cancel signal delivery failed: %v", err)
            return svc.SaveTask(t, by)
        }
        svc.Infof("cancel signal sent for task (%s) with force=%v", t.Id.Hex(), force)
    case <-ctx.Done():
        svc.Errorf("timeout sending cancel request for task (%s)", t.Id.Hex())
        // Mark as cancelled even if signal couldn't be delivered
        t.Status = constants.TaskStatusCancelled
        t.Error = "cancel signal delivery timeout"
        return svc.SaveTask(t, by)
    }

    // For force cancellation, wait a bit and verify cancellation
    if force {
        time.Sleep(5 * time.Second)
        // Re-fetch task to check current status
        currentTask, fetchErr := service.NewModelService[models.Task]().GetById(t.Id)
        if fetchErr == nil && currentTask.Status == constants.TaskStatusRunning {
            svc.Warnf("task (%s) still running after force cancel, marking as cancelled", t.Id.Hex())
            currentTask.Status = constants.TaskStatusCancelled
            currentTask.Error = "forced cancellation - task was unresponsive"
            return svc.SaveTask(currentTask, by)
        }
    }

    return nil
}

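// SetInterval sets the scheduler interval setting.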
func (svc *Service) SetInterval(interval time.Duration) {
    svc.interval = interval
}

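// SaveTask persists the task: it inserts a new document when the task id is
// zero, and otherwise replaces the existing document by id.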
func (svc *Service) SaveTask(t *models.Task, by primitive.ObjectID) (err error) {
    if t.Id.IsZero() {
        t.SetCreated(by)
        t.SetUpdated(by)
        _, err = service.NewModelService[models.Task]().InsertOne(*t)
        return err
    } else {
        t.SetUpdated(by)
        return service.NewModelService[models.Task]().ReplaceById(t.Id, *t)
    }
}

// initTaskStatus marks tasks left in pending, assigned, or running state from
// a previous run as abnormal.
func (svc *Service) initTaskStatus() {
    defer svc.wg.Done()
    defer func() {
        if r := recover(); r != nil {
            svc.Errorf("initTaskStatus panic recovered: %v", r)
        }
    }()

    // set status of unfinished tasks as TaskStatusAbnormal
    runningTasks, err := service.NewModelService[models.Task]().GetMany(bson.M{
        "status": bson.M{
            "$in": []string{
                constants.TaskStatusPending,
                constants.TaskStatusAssigned,
                constants.TaskStatusRunning,
            },
        },
    }, nil)
    if err != nil {
        if errors2.Is(err, mongo2.ErrNoDocuments) {
            return
        }
        svc.Errorf("failed to get running tasks: %v", err)
        return
    }

    // Use bounded worker pool for task updates
    var wg sync.WaitGroup
    semaphore := make(chan struct{}, 10) // Limit concurrent updates

    for _, t := range runningTasks {
        select {
        case <-svc.ctx.Done():
            svc.Infof("initTaskStatus stopped by context")
            return
        case semaphore <- struct{}{}:
            wg.Add(1)
            go func(task models.Task) {
                defer func() {
                    <-semaphore
                    wg.Done()
                    if r := recover(); r != nil {
                        svc.Errorf("task status update panic for task %s: %v", task.Id.Hex(), r)
                    }
                }()

                task.Status = constants.TaskStatusAbnormal
                if err := svc.SaveTask(&task, primitive.NilObjectID); err != nil {
                    svc.Errorf("failed to set task status as TaskStatusAbnormal: %s", task.Id.Hex())
                }
            }(t)
        }
    }

    wg.Wait()
    svc.Infof("initTaskStatus completed")
}

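// isMasterNode reports whether the node the task is assigned to is the master
// node.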
func (svc *Service) isMasterNode(t *models.Task) (ok bool, err error) {
    if t.NodeId.IsZero() {
        err = fmt.Errorf("task %s has no node id", t.Id.Hex())
        svc.Errorf("%v", err)
        return false, err
    }
    n, err := service.NewModelService[models.Node]().GetById(t.NodeId)
    if err != nil {
        if errors2.Is(err, mongo2.ErrNoDocuments) {
            svc.Errorf("node not found: %s", t.NodeId.Hex())
            return false, err
        }
        svc.Errorf("failed to get node: %s", t.NodeId.Hex())
        return false, err
    }
    return n.IsMaster, nil
}

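// cleanupTasks runs performCleanup every 30 minutes until the service context
// is cancelled.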
func (svc *Service) cleanupTasks() {
    defer svc.wg.Done()
    defer func() {
        if r := recover(); r != nil {
            svc.Errorf("cleanupTasks panic recovered: %v", r)
        }
    }()

    ticker := time.NewTicker(30 * time.Minute)
    defer ticker.Stop()

    for {
        select {
        case <-svc.ctx.Done():
            svc.Infof("cleanupTasks stopped by context")
            return
        case <-ticker.C:
            svc.performCleanup()
        }
    }
}

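// performCleanup deletes tasks and task stats whose stat records were created
// more than 30 days ago.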
func (svc *Service) performCleanup() {
    defer func() {
        if r := recover(); r != nil {
            svc.Errorf("performCleanup panic recovered: %v", r)
        }
    }()

    // task stats created over 30 days ago
    taskStats, err := service.NewModelService[models.TaskStat]().GetMany(bson.M{
        "created_at": bson.M{
            "$lt": time.Now().Add(-30 * 24 * time.Hour),
        },
    }, nil)
    if err != nil {
        svc.Errorf("failed to get old task stats: %v", err)
        return
    }

    // task ids
    var ids []primitive.ObjectID
    for _, ts := range taskStats {
        ids = append(ids, ts.Id)
    }

    if len(ids) > 0 {
        // remove tasks
        if err := service.NewModelService[models.Task]().DeleteMany(bson.M{
            "_id": bson.M{"$in": ids},
        }); err != nil {
            svc.Warnf("failed to remove tasks: %v", err)
        }

        // remove task stats
        if err := service.NewModelService[models.TaskStat]().DeleteMany(bson.M{
            "_id": bson.M{"$in": ids},
        }); err != nil {
            svc.Warnf("failed to remove task stats: %v", err)
        }

        svc.Infof("cleaned up %d old tasks", len(ids))
    }
}

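// newTaskSchedulerService creates a task scheduler service with its default
// 5-second interval and wired-up dependencies.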
func newTaskSchedulerService() *Service {
    return &Service{
        interval:   5 * time.Second,
        svr:        server.GetGrpcServer(),
        handlerSvc: handler.GetTaskHandlerService(),
        Logger:     utils.NewLogger("TaskScheduler"),
    }
}

var _service *Service
var _serviceOnce sync.Once

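// GetTaskSchedulerService returns the singleton task scheduler service,
// creating it on first use.
//
// Typical usage sketch (an assumption, since Start blocks via
// utils.DefaultWait, the caller would run it from its own goroutine):
//
//  svc := GetTaskSchedulerService()
//  go svc.Start()
//  defer svc.Stop()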
func GetTaskSchedulerService() *Service {
    _serviceOnce.Do(func() {
        _service = newTaskSchedulerService()
    })
    return _service
}