diff --git a/core/controllers/sync.go b/core/controllers/sync.go index c03230a5..6ade48ab 100644 --- a/core/controllers/sync.go +++ b/core/controllers/sync.go @@ -1,12 +1,20 @@ package controllers import ( - "github.com/crawlab-team/crawlab/core/entity" - "github.com/juju/errors" + "context" "path/filepath" + "sync/atomic" + "github.com/crawlab-team/crawlab/core/entity" "github.com/crawlab-team/crawlab/core/utils" "github.com/gin-gonic/gin" + "github.com/juju/errors" + "golang.org/x/sync/semaphore" +) + +var ( + syncDownloadSemaphore = semaphore.NewWeighted(utils.GetSyncDownloadMaxConcurrency()) + syncDownloadInFlight int64 ) func GetSyncScan(c *gin.Context) (response *Response[entity.FsFileInfoMap], err error) { @@ -14,12 +22,30 @@ func GetSyncScan(c *gin.Context) (response *Response[entity.FsFileInfoMap], err dirPath := filepath.Join(workspacePath, c.Param("id"), c.Param("path")) files, err := utils.ScanDirectory(dirPath) if err != nil { + logger.Warnf("sync scan failed id=%s path=%s: %v", c.Param("id"), c.Param("path"), err) return GetErrorResponse[entity.FsFileInfoMap](err) } return GetDataResponse(files) } func GetSyncDownload(c *gin.Context) (err error) { + ctx := c.Request.Context() + if ctx == nil { + ctx = context.Background() + } + + if err := syncDownloadSemaphore.Acquire(ctx, 1); err != nil { + logger.Warnf("failed to acquire sync download slot for id=%s path=%s: %v", c.Param("id"), c.Query("path"), err) + return errors.Annotate(err, "acquire sync download slot") + } + current := atomic.AddInt64(&syncDownloadInFlight, 1) + logger.Debugf("sync download in-flight=%d id=%s path=%s", current, c.Param("id"), c.Query("path")) + defer func() { + newVal := atomic.AddInt64(&syncDownloadInFlight, -1) + logger.Debugf("sync download completed in-flight=%d id=%s path=%s", newVal, c.Param("id"), c.Query("path")) + syncDownloadSemaphore.Release(1) + }() + workspacePath := utils.GetWorkspace() filePath := filepath.Join(workspacePath, c.Param("id"), c.Query("path")) if !utils.Exists(filePath) { diff --git a/core/node/service/task_reconciliation_service.go b/core/node/service/task_reconciliation_service.go index e3ab0f76..61f82e4a 100644 --- a/core/node/service/task_reconciliation_service.go +++ b/core/node/service/task_reconciliation_service.go @@ -3,6 +3,7 @@ package service import ( "context" "fmt" + "strings" "sync" "time" @@ -26,6 +27,8 @@ type TaskReconciliationService struct { interfaces.Logger } +const staleReconciliationThreshold = 15 * time.Minute + // HandleTasksForOfflineNode updates all running tasks on an offline node to abnormal status func (svc *TaskReconciliationService) HandleTasksForOfflineNode(node *models.Node) { // Find all running tasks on the offline node @@ -51,6 +54,7 @@ func (svc *TaskReconciliationService) HandleTasksForOfflineNode(node *models.Nod for _, task := range runningTasks { task.Status = constants.TaskStatusNodeDisconnected task.Error = "Task temporarily disconnected due to worker node offline" + task.SetUpdated(primitive.NilObjectID) // Update the task in database err := backoff.Retry(func() error { @@ -144,6 +148,7 @@ func (svc *TaskReconciliationService) HandleNodeReconnection(node *models.Node) } // Update the task in database + task.SetUpdated(primitive.NilObjectID) err = backoff.Retry(func() error { return service.NewModelService[models.Task]().ReplaceById(task.Id, task) }, backoff.WithMaxRetries(backoff.NewConstantBackOff(500*time.Millisecond), 3)) @@ -385,12 +390,66 @@ func (svc *TaskReconciliationService) updateTaskStatusReliably(task *models.Task // The disconnect reason should already be in the error field } + task.SetUpdated(primitive.NilObjectID) + // Update with retry logic return backoff.Retry(func() error { return service.NewModelService[models.Task]().ReplaceById(task.Id, *task) }, backoff.WithMaxRetries(backoff.NewConstantBackOff(500*time.Millisecond), 3)) } +func (svc *TaskReconciliationService) shouldMarkTaskAbnormal(task *models.Task) bool { + if task == nil { + return false + } + + if svc.IsTaskStatusFinal(task.Status) { + return false + } + + if task.Status != constants.TaskStatusNodeDisconnected { + return false + } + + lastUpdated := task.UpdatedAt + if lastUpdated.IsZero() { + lastUpdated = task.CreatedAt + } + + if lastUpdated.IsZero() { + return false + } + + return time.Since(lastUpdated) >= staleReconciliationThreshold +} + +func (svc *TaskReconciliationService) markTaskAbnormal(task *models.Task, cause error) error { + if task == nil { + return fmt.Errorf("task is nil") + } + + reasonParts := make([]string, 0, 2) + if cause != nil { + reasonParts = append(reasonParts, fmt.Sprintf("last reconciliation error: %v", cause)) + } + reasonParts = append(reasonParts, fmt.Sprintf("task status not reconciled for %s", staleReconciliationThreshold)) + reason := strings.Join(reasonParts, "; ") + + if task.Error == "" { + task.Error = reason + } else if !strings.Contains(task.Error, reason) { + task.Error = fmt.Sprintf("%s; %s", task.Error, reason) + } + + if err := svc.updateTaskStatusReliably(task, constants.TaskStatusAbnormal); err != nil { + svc.Errorf("failed to mark task[%s] abnormal after reconciliation timeout: %v", task.Id.Hex(), err) + return err + } + + svc.Warnf("marked task[%s] as abnormal after %s of unresolved reconciliation", task.Id.Hex(), staleReconciliationThreshold) + return nil +} + // StartPeriodicReconciliation starts a background service to periodically reconcile task status func (svc *TaskReconciliationService) StartPeriodicReconciliation() { go svc.runPeriodicReconciliation() @@ -451,8 +510,12 @@ func (svc *TaskReconciliationService) reconcileTaskStatus(task *models.Task) err actualStatus, err := svc.GetActualTaskStatusFromWorker(node, task) if err != nil { svc.Warnf("failed to get actual status for task[%s]: %v", task.Id.Hex(), err) - // Don't change the status if we can't determine the actual state - // This is more honest than making assumptions + if svc.shouldMarkTaskAbnormal(task) { + if updateErr := svc.markTaskAbnormal(task, err); updateErr != nil { + return updateErr + } + return nil + } return err } diff --git a/core/task/handler/runner_sync.go b/core/task/handler/runner_sync.go index 2c211220..4850916b 100644 --- a/core/task/handler/runner_sync.go +++ b/core/task/handler/runner_sync.go @@ -5,17 +5,32 @@ import ( "errors" "fmt" "io" + "math/rand" "net/http" "net/url" "os" "path/filepath" "strings" "sync" + "time" "github.com/crawlab-team/crawlab/core/entity" "github.com/crawlab-team/crawlab/core/utils" ) +const ( + syncHTTPRequestMaxRetries = 5 + syncHTTPRequestInitialBackoff = 200 * time.Millisecond + syncHTTPRequestMaxBackoff = 3 * time.Second + syncHTTPRequestClientTimeout = 30 * time.Second +) + +var ( + syncHttpClient = &http.Client{Timeout: syncHTTPRequestClientTimeout} + jitterRand = rand.New(rand.NewSource(time.Now().UnixNano())) + jitterMutex sync.Mutex +) + // syncFiles synchronizes files between master and worker nodes: // 1. Gets file list from master // 2. Compares with local files @@ -136,10 +151,8 @@ func (r *Runner) syncFiles() (err error) { } func (r *Runner) performHttpRequest(method, path string, params url.Values) (*http.Response, error) { - // Normalize path path = strings.TrimPrefix(path, "/") - // Construct master URL var id string if r.s.GitId.IsZero() { id = r.s.Id.Hex() @@ -148,17 +161,74 @@ func (r *Runner) performHttpRequest(method, path string, params url.Values) (*ht } requestUrl := fmt.Sprintf("%s/sync/%s/%s?%s", utils.GetApiEndpoint(), id, path, params.Encode()) - // Create and execute request - req, err := http.NewRequest(method, requestUrl, nil) - if err != nil { - return nil, fmt.Errorf("error creating request: %v", err) + backoff := syncHTTPRequestInitialBackoff + var lastErr error + + for attempt := range syncHTTPRequestMaxRetries { + req, err := http.NewRequest(method, requestUrl, nil) + if err != nil { + return nil, fmt.Errorf("error creating request: %v", err) + } + + for k, v := range r.getHttpRequestHeaders() { + req.Header.Set(k, v) + } + + resp, err := syncHttpClient.Do(req) + if err == nil && !shouldRetryStatus(resp.StatusCode) { + return resp, nil + } + + if err == nil { + lastErr = fmt.Errorf("received retryable status %d for %s", resp.StatusCode, requestUrl) + _ = resp.Body.Close() + } else { + lastErr = err + } + + wait := backoff + jitterDuration(backoff/2) + if wait > syncHTTPRequestMaxBackoff { + wait = syncHTTPRequestMaxBackoff + } + + r.Warnf("retrying %s %s in %s (attempt %d/%d): %v", method, requestUrl, wait, attempt+1, syncHTTPRequestMaxRetries, lastErr) + time.Sleep(wait) + + if backoff < syncHTTPRequestMaxBackoff { + backoff *= 2 + if backoff > syncHTTPRequestMaxBackoff { + backoff = syncHTTPRequestMaxBackoff + } + } } - for k, v := range r.getHttpRequestHeaders() { - req.Header.Set(k, v) + if lastErr == nil { + lastErr = fmt.Errorf("exceeded max retries for %s", requestUrl) } + return nil, lastErr +} - return http.DefaultClient.Do(req) +func shouldRetryStatus(status int) bool { + if status == http.StatusTooManyRequests || status == http.StatusRequestTimeout || status == http.StatusTooEarly { + return true + } + switch status { + case http.StatusInternalServerError, + http.StatusBadGateway, + http.StatusServiceUnavailable, + http.StatusGatewayTimeout: + return true + } + return status >= 500 +} + +func jitterDuration(max time.Duration) time.Duration { + if max <= 0 { + return 0 + } + jitterMutex.Lock() + defer jitterMutex.Unlock() + return time.Duration(jitterRand.Int63n(int64(max))) } // downloadFile downloads a file from the master node and saves it to the local file system diff --git a/core/utils/config.go b/core/utils/config.go index 5a549c21..7856182e 100644 --- a/core/utils/config.go +++ b/core/utils/config.go @@ -12,36 +12,38 @@ import ( ) const ( - DefaultWorkspace = "crawlab_workspace" - DefaultTaskLogPath = "/var/log/crawlab/tasks" - DefaultServerHost = "0.0.0.0" - DefaultServerPort = 8000 - DefaultGrpcHost = "localhost" - DefaultGrpcPort = 9666 - DefaultGrpcServerHost = "0.0.0.0" - DefaultGrpcServerPort = 9666 - DefaultAuthKey = "Crawlab2024!" - DefaultApiEndpoint = "http://localhost:8000" - DefaultApiAllowOrigin = "*" - DefaultApiAllowCredentials = "true" - DefaultApiAllowMethods = "DELETE, POST, OPTIONS, GET, PUT" - DefaultApiAllowHeaders = "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With" - DefaultApiPort = 8080 - DefaultApiPath = "/api" - DefaultNodeMaxRunners = 20 // Default max concurrent task runners per node - DefaultTaskQueueSize = 100 // Default task queue size per node - DefaultInstallRoot = "/app/install" - DefaultInstallEnvs = "" - MetadataConfigDirName = ".crawlab" - MetadataConfigName = "config.json" - DefaultPyenvPath = "/root/.pyenv" - DefaultNodeModulesPath = "/usr/lib/node_modules" - DefaultNodeBinPath = "/usr/lib/node_bin" - DefaultGoPath = "/root/go" - DefaultMCPServerHost = "0.0.0.0" - DefaultMCPServerPort = 9777 - DefaultMCPClientBaseUrl = "http://localhost:9777/sse" - DefaultOpenAPIUrlPath = "/openapi.json" + DefaultWorkspace = "crawlab_workspace" + DefaultTaskLogPath = "/var/log/crawlab/tasks" + DefaultServerHost = "0.0.0.0" + DefaultServerPort = 8000 + DefaultGrpcHost = "localhost" + DefaultGrpcPort = 9666 + DefaultGrpcServerHost = "0.0.0.0" + DefaultGrpcServerPort = 9666 + DefaultAuthKey = "Crawlab2024!" + DefaultApiEndpoint = "http://localhost:8000" + DefaultApiAllowOrigin = "*" + DefaultApiAllowCredentials = "true" + DefaultApiAllowMethods = "DELETE, POST, OPTIONS, GET, PUT" + DefaultApiAllowHeaders = "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With" + DefaultApiPort = 8080 + DefaultApiPath = "/api" + DefaultNodeMaxRunners = 20 // Default max concurrent task runners per node + DefaultTaskQueueSize = 100 // Default task queue size per node + DefaultInstallRoot = "/app/install" + DefaultInstallEnvs = "" + MetadataConfigDirName = ".crawlab" + MetadataConfigName = "config.json" + DefaultPyenvPath = "/root/.pyenv" + DefaultNodeModulesPath = "/usr/lib/node_modules" + DefaultNodeBinPath = "/usr/lib/node_bin" + DefaultGoPath = "/root/go" + DefaultMCPServerHost = "0.0.0.0" + DefaultMCPServerPort = 9777 + DefaultMCPClientBaseUrl = "http://localhost:9777/sse" + DefaultOpenAPIUrlPath = "/openapi.json" + DefaultSyncDownloadMaxConcurrency = 16 + DefaultMinFileDescriptorLimit = 8192 ) func IsDev() bool { @@ -332,3 +334,17 @@ func GetOpenAPIUrl() string { } return GetApiEndpoint() + DefaultOpenAPIUrlPath } + +func GetSyncDownloadMaxConcurrency() int64 { + if res := viper.GetInt("sync.download.max_concurrency"); res > 0 { + return int64(res) + } + return int64(DefaultSyncDownloadMaxConcurrency) +} + +func GetMinFileDescriptorLimit() uint64 { + if res := viper.GetUint64("system.fd_min"); res > 0 { + return res + } + return DefaultMinFileDescriptorLimit +} diff --git a/core/utils/file.go b/core/utils/file.go index 21328607..175eac87 100644 --- a/core/utils/file.go +++ b/core/utils/file.go @@ -7,12 +7,16 @@ import ( "fmt" "io" "io/fs" + "maps" "os" "path" "path/filepath" "regexp" + "sync" + "time" "github.com/crawlab-team/crawlab/core/entity" + "golang.org/x/sync/singleflight" ) func OpenFile(fileName string) *os.File { @@ -184,11 +188,54 @@ func GetFileHash(filePath string) (res string, err error) { } const IgnoreFileRegexPattern = `(^node_modules|__pycache__)/|\.(tmp|temp|log|swp|swo|bak|orig|lock|pid|pyc|pyo)$` +const scanDirectoryCacheTTL = 3 * time.Second -func ScanDirectory(dir string) (res entity.FsFileInfoMap, err error) { +var ( + scanDirectoryGroup singleflight.Group + scanDirectoryCache = struct { + sync.RWMutex + items map[string]scanDirectoryCacheEntry + }{items: make(map[string]scanDirectoryCacheEntry)} +) + +type scanDirectoryCacheEntry struct { + data entity.FsFileInfoMap + expiresAt time.Time +} + +func ScanDirectory(dir string) (entity.FsFileInfoMap, error) { + if res, ok := getScanDirectoryCache(dir); ok { + return cloneFsFileInfoMap(res), nil + } + + v, err, _ := scanDirectoryGroup.Do(dir, func() (any, error) { + if res, ok := getScanDirectoryCache(dir); ok { + return cloneFsFileInfoMap(res), nil + } + + files, err := scanDirectoryInternal(dir) + if err != nil { + return nil, err + } + + setScanDirectoryCache(dir, files) + return cloneFsFileInfoMap(files), nil + }) + if err != nil { + return nil, err + } + + res, ok := v.(entity.FsFileInfoMap) + if !ok { + return nil, fmt.Errorf("unexpected cache value type: %T", v) + } + + return cloneFsFileInfoMap(res), nil +} + +func scanDirectoryInternal(dir string) (entity.FsFileInfoMap, error) { files := make(entity.FsFileInfoMap) - // Compile the ignore pattern regex ignoreRegex, err := regexp.Compile(IgnoreFileRegexPattern) if err != nil { return nil, fmt.Errorf("failed to compile ignore pattern: %v", err) @@ -204,7 +251,6 @@ func ScanDirectory(dir string) (res entity.FsFileInfoMap, err error) { return err } - // Skip files that match the ignore pattern if ignoreRegex.MatchString(relPath) { if info.IsDir() { return filepath.SkipDir @@ -239,3 +285,33 @@ func ScanDirectory(dir string) (res entity.FsFileInfoMap, err error) { return files, nil } + +func getScanDirectoryCache(dir string) (entity.FsFileInfoMap, bool) { + scanDirectoryCache.RLock() + defer scanDirectoryCache.RUnlock() + + entry, ok := scanDirectoryCache.items[dir] + if !ok || time.Now().After(entry.expiresAt) { + return nil, false + } + return entry.data, true +} + +func setScanDirectoryCache(dir string, data entity.FsFileInfoMap) { + scanDirectoryCache.Lock() + defer scanDirectoryCache.Unlock() + + scanDirectoryCache.items[dir] = scanDirectoryCacheEntry{ + data: data, + expiresAt: time.Now().Add(scanDirectoryCacheTTL), + } +} + +func cloneFsFileInfoMap(src entity.FsFileInfoMap) entity.FsFileInfoMap { + if src == nil { + return nil + } + dst := make(entity.FsFileInfoMap, len(src)) + maps.Copy(dst, src) + return dst +} diff --git a/core/utils/rlimit_stub.go b/core/utils/rlimit_stub.go new file mode 100644 index 00000000..5e38adf9 --- /dev/null +++ b/core/utils/rlimit_stub.go @@ -0,0 +1,7 @@ +//go:build windows || plan9 + +package utils + +func EnsureFileDescriptorLimit(_ uint64) { + // no-op on unsupported platforms +} diff --git a/core/utils/rlimit_unix.go b/core/utils/rlimit_unix.go new file mode 100644 index 00000000..b9027dd5 --- /dev/null +++ b/core/utils/rlimit_unix.go @@ -0,0 +1,30 @@ +//go:build !windows && !plan9 + +package utils + +import "syscall" + +func EnsureFileDescriptorLimit(min uint64) { + var rLimit syscall.Rlimit + if err := syscall.Getrlimit(syscall.RLIMIT_NOFILE, &rLimit); err != nil { + logger.Warnf("failed to get rlimit: %v", err) + return + } + + if rLimit.Cur >= min { + return + } + + newLimit := min + if rLimit.Max < newLimit { + rLimit.Max = newLimit + } + rLimit.Cur = newLimit + + if err := syscall.Setrlimit(syscall.RLIMIT_NOFILE, &rLimit); err != nil { + logger.Warnf("failed to raise rlimit to %d: %v", newLimit, err) + return + } + + logger.Infof("increased file descriptor limit to %d", newLimit) +}