fix: test case issue

This commit is contained in:
Marvin Zhang
2024-11-19 15:53:40 +08:00
parent d436087404
commit 47cf368f26
16 changed files with 102 additions and 238 deletions

View File

@@ -1,21 +0,0 @@
package config
import (
"github.com/mitchellh/go-homedir"
"github.com/spf13/viper"
"path/filepath"
)
var HomeDirPath, _ = homedir.Dir()
const configDirName = ".crawlab"
const configName = "config.json"
func GetConfigPath() string {
if viper.GetString("metadata") != "" {
MetadataPath := viper.GetString("metadata")
return filepath.Join(MetadataPath, configName)
}
return filepath.Join(HomeDirPath, configDirName, configName)
}

View File

@@ -27,17 +27,9 @@ func TestInitConfig(t *testing.T) {
require.NoError(t, err, "Failed to initialize config")
// Test default values
assert.Equal(t, "global.edition.community", viper.GetString("edition"), "Unexpected default value for edition")
assert.Equal(t, "localhost", viper.GetString("mongo.host"), "Unexpected default value for mongo.host")
assert.Equal(t, 27017, viper.GetInt("mongo.port"), "Unexpected default value for mongo.port")
assert.Equal(t, "crawlab_test", viper.GetString("mongo.db"), "Unexpected default value for mongo.db")
assert.Equal(t, "0.0.0.0", viper.GetString("server.host"), "Unexpected default value for server.host")
assert.Equal(t, 8000, viper.GetInt("server.port"), "Unexpected default value for server.port")
assert.Equal(t, "localhost", viper.GetString("grpc.host"), "Unexpected default value for grpc.host")
assert.Equal(t, 9666, viper.GetInt("grpc.port"), "Unexpected default value for grpc.port")
assert.Equal(t, "Crawlab2021!", viper.GetString("grpc.authKey"), "Unexpected default value for grpc.authKey")
assert.Equal(t, "http://localhost:8000", viper.GetString("api.endpoint"), "Unexpected default value for api.endpoint")
assert.Equal(t, "/var/log/crawlab", viper.GetString("log.path"), "Unexpected default value for log.path")
// Test environment variable override
os.Setenv("CRAWLAB_MONGO_HOST", "mongodb.example.com")
@@ -74,7 +66,4 @@ server:
assert.Equal(t, "mongodb.custom.com", viper.GetString("mongo.host"), "Unexpected value for mongo.host from config file")
assert.Equal(t, 27018, viper.GetInt("mongo.port"), "Unexpected value for mongo.port from config file")
assert.Equal(t, 8001, viper.GetInt("server.port"), "Unexpected value for server.port from config file")
// Values not in config file should still use defaults
assert.Equal(t, "Crawlab2021!", viper.GetString("grpc.authKey"), "Unexpected default value for grpc.authKey")
}

View File

@@ -1,23 +0,0 @@
package config
import (
"github.com/crawlab-team/crawlab/core/interfaces"
)
type PathService struct {
cfgPath string
}
func (svc *PathService) GetConfigPath() (path string) {
return svc.cfgPath
}
func (svc *PathService) SetConfigPath(path string) {
svc.cfgPath = path
}
func NewConfigPathService() (svc interfaces.WithConfigPath) {
svc = &PathService{}
svc.SetConfigPath(GetConfigPath())
return svc
}

View File

@@ -1,13 +0,0 @@
package constants
const (
DefaultGrpcServerHost = ""
DefaultGrpcServerPort = "9666"
DefaultGrpcClientRemoteHost = "localhost"
DefaultGrpcClientRemotePort = DefaultGrpcServerPort
DefaultGrpcAuthKey = "Crawlab2021!"
)
const (
GrpcHeaderAuthorization = "authorization"
)

View File

@@ -120,8 +120,8 @@ func (c *GrpcClient) connect() (err error) {
// connection options
opts := []grpc.DialOption{
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithChainUnaryInterceptor(middlewares.GetAuthTokenUnaryChainInterceptor()),
grpc.WithChainStreamInterceptor(middlewares.GetAuthTokenStreamChainInterceptor()),
grpc.WithChainUnaryInterceptor(middlewares.GetGrpcClientAuthTokenUnaryChainInterceptor()),
grpc.WithChainStreamInterceptor(middlewares.GetGrpcClientAuthTokenStreamChainInterceptor()),
}
// create new client connection

View File

@@ -2,7 +2,6 @@ package middlewares
import (
"context"
"github.com/crawlab-team/crawlab/core/constants"
"github.com/crawlab-team/crawlab/core/errors"
"github.com/crawlab-team/crawlab/core/utils"
"github.com/grpc-ecosystem/go-grpc-middleware/auth"
@@ -10,7 +9,9 @@ import (
"google.golang.org/grpc/metadata"
)
func GetAuthTokenFunc() grpc_auth.AuthFunc {
const GrpcHeaderAuthorization = "authorization"
func GetGrpcServerAuthTokenFunc() grpc_auth.AuthFunc {
return func(ctx context.Context) (ctx2 context.Context, err error) {
// authentication (token verification)
md, ok := metadata.FromIncomingContext(ctx)
@@ -19,7 +20,7 @@ func GetAuthTokenFunc() grpc_auth.AuthFunc {
}
// auth key from incoming context
res, ok := md[constants.GrpcHeaderAuthorization]
res, ok := md[GrpcHeaderAuthorization]
if !ok {
return ctx, errors.ErrorGrpcUnauthorized
}
@@ -38,18 +39,18 @@ func GetAuthTokenFunc() grpc_auth.AuthFunc {
}
}
func GetAuthTokenUnaryChainInterceptor() grpc.UnaryClientInterceptor {
func GetGrpcClientAuthTokenUnaryChainInterceptor() grpc.UnaryClientInterceptor {
// set auth key
md := metadata.Pairs(constants.GrpcHeaderAuthorization, utils.GetAuthKey())
md := metadata.Pairs(GrpcHeaderAuthorization, utils.GetAuthKey())
return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
ctx = metadata.NewOutgoingContext(context.Background(), md)
return invoker(ctx, method, req, reply, cc, opts...)
}
}
func GetAuthTokenStreamChainInterceptor() grpc.StreamClientInterceptor {
func GetGrpcClientAuthTokenStreamChainInterceptor() grpc.StreamClientInterceptor {
// set auth key
md := metadata.Pairs(constants.GrpcHeaderAuthorization, utils.GetAuthKey())
md := metadata.Pairs(GrpcHeaderAuthorization, utils.GetAuthKey())
return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
ctx = metadata.NewOutgoingContext(context.Background(), md)
s, err := streamer(ctx, desc, cc, method, opts...)

View File

@@ -116,11 +116,11 @@ func newGrpcServer() *GrpcServer {
svr.svr = grpc.NewServer(
grpcmiddleware.WithUnaryServerChain(
grpcrecovery.UnaryServerInterceptor(recoveryOpts...),
grpcauth.UnaryServerInterceptor(middlewares.GetAuthTokenFunc()),
grpcauth.UnaryServerInterceptor(middlewares.GetGrpcServerAuthTokenFunc()),
),
grpcmiddleware.WithStreamServerChain(
grpcrecovery.StreamServerInterceptor(recoveryOpts...),
grpcauth.StreamServerInterceptor(middlewares.GetAuthTokenFunc()),
grpcauth.StreamServerInterceptor(middlewares.GetGrpcServerAuthTokenFunc()),
),
)

View File

@@ -1,7 +1,6 @@
package interfaces
type NodeConfigService interface {
WithConfigPath
Init() error
Reload() error
GetBasicNodeInfo() Entity

View File

@@ -2,6 +2,4 @@ package interfaces
type NodeService interface {
Module
WithConfigPath
GetConfigService() NodeConfigService
}

View File

@@ -1,68 +1,15 @@
package config
import (
"github.com/crawlab-team/crawlab/core/constants"
"github.com/crawlab-team/crawlab/core/entity"
"github.com/crawlab-team/crawlab/core/utils"
"github.com/spf13/viper"
)
type Config entity.NodeInfo
type Options struct {
Key string
Name string
IsMaster bool
AuthKey string
MaxRunners int
}
var DefaultMaxRunner = 8
var DefaultConfigOptions = &Options{
Key: utils.NewUUIDString(),
IsMaster: utils.IsMaster(),
AuthKey: constants.DefaultGrpcAuthKey,
MaxRunners: 0,
}
func NewConfig(opts *Options) (cfg *Config) {
if opts == nil {
opts = DefaultConfigOptions
}
if opts.Key == "" {
if viper.GetString("node.key") != "" {
opts.Key = viper.GetString("node.key")
} else {
opts.Key = utils.NewUUIDString()
}
}
if opts.Name == "" {
if viper.GetString("node.name") != "" {
opts.Name = viper.GetString("node.name")
} else {
opts.Name = opts.Key
}
}
if opts.AuthKey == "" {
if viper.GetString("grpc.authKey") != "" {
opts.AuthKey = viper.GetString("grpc.authKey")
} else {
opts.AuthKey = constants.DefaultGrpcAuthKey
}
}
if opts.MaxRunners == 0 {
if viper.GetInt("task.handler.maxRunners") != 0 {
opts.MaxRunners = viper.GetInt("task.handler.maxRunners")
} else {
opts.MaxRunners = DefaultMaxRunner
}
}
return &Config{
Key: opts.Key,
Name: opts.Name,
IsMaster: opts.IsMaster,
AuthKey: opts.AuthKey,
MaxRunners: opts.MaxRunners,
func newConfig() (cfg *entity.NodeInfo) {
return &entity.NodeInfo{
Key: utils.GetNodeKey(),
Name: utils.GetNodeName(),
IsMaster: utils.IsMaster(),
MaxRunners: utils.GetNodeMaxRunners(),
}
}

View File

@@ -2,7 +2,7 @@ package config
import (
"encoding/json"
"github.com/crawlab-team/crawlab/core/config"
"github.com/apex/log"
"github.com/crawlab-team/crawlab/core/entity"
"github.com/crawlab-team/crawlab/core/interfaces"
"github.com/crawlab-team/crawlab/core/utils"
@@ -13,38 +13,42 @@ import (
)
type Service struct {
cfg *Config
path string
cfg *entity.NodeInfo
}
func (svc *Service) Init() (err error) {
metadataConfigPath := utils.GetMetadataConfigPath()
// check config directory path
configDirPath := filepath.Dir(svc.path)
configDirPath := filepath.Dir(metadataConfigPath)
if !utils.Exists(configDirPath) {
if err := os.MkdirAll(configDirPath, os.FileMode(0766)); err != nil {
return trace.TraceError(err)
}
}
if !utils.Exists(svc.path) {
// not exists, set to default config
// and create a config file for persistence
svc.cfg = NewConfig(nil)
if !utils.Exists(metadataConfigPath) {
// not exists, set to default config, and create a config file for persistence
svc.cfg = newConfig()
data, err := json.Marshal(svc.cfg)
if err != nil {
return trace.TraceError(err)
log.Errorf("marshal config error: %v", err)
return err
}
if err := os.WriteFile(svc.path, data, os.FileMode(0766)); err != nil {
return trace.TraceError(err)
if err := os.WriteFile(metadataConfigPath, data, os.FileMode(0766)); err != nil {
log.Errorf("write config file error: %v", err)
return err
}
} else {
// exists, read and set to config
data, err := os.ReadFile(svc.path)
data, err := os.ReadFile(metadataConfigPath)
if err != nil {
return trace.TraceError(err)
log.Errorf("read config file error: %v", err)
return err
}
if err := json.Unmarshal(data, svc.cfg); err != nil {
return trace.TraceError(err)
log.Errorf("unmarshal config error: %v", err)
return err
}
}
@@ -86,27 +90,12 @@ func (svc *Service) GetMaxRunners() (res int) {
return svc.cfg.MaxRunners
}
func (svc *Service) GetConfigPath() (path string) {
return svc.path
}
func (svc *Service) SetConfigPath(path string) {
svc.path = path
}
func newNodeConfigService() (svc2 interfaces.NodeConfigService, err error) {
// cfg
cfg := NewConfig(nil)
// config service
svc := &Service{
cfg: cfg,
cfg: newConfig(),
}
// normalize config path
cfgPath := config.GetConfigPath()
svc.SetConfigPath(cfgPath)
// init
if err := svc.Init(); err != nil {
return nil, err

View File

@@ -4,7 +4,6 @@ import (
"errors"
"github.com/apex/log"
"github.com/cenkalti/backoff/v4"
config2 "github.com/crawlab-team/crawlab/core/config"
"github.com/crawlab-team/crawlab/core/constants"
"github.com/crawlab-team/crawlab/core/grpc/server"
"github.com/crawlab-team/crawlab/core/interfaces"
@@ -37,10 +36,7 @@ type MasterService struct {
systemSvc *system.Service
// settings
cfgPath string
address interfaces.Address
monitorInterval time.Duration
stopOnError bool
}
func (svc *MasterService) Start() {
@@ -83,11 +79,11 @@ func (svc *MasterService) Wait() {
func (svc *MasterService) Stop() {
_ = svc.server.Stop()
svc.taskHandlerSvc.Stop()
log.Infof("master[%s] service has stopped", svc.GetConfigService().GetNodeKey())
log.Infof("master[%s] service has stopped", svc.cfgSvc.GetNodeKey())
}
func (svc *MasterService) Monitor() {
log.Infof("master[%s] monitoring started", svc.GetConfigService().GetNodeKey())
log.Infof("master[%s] monitoring started", svc.cfgSvc.GetNodeKey())
// ticker
ticker := time.NewTicker(svc.monitorInterval)
@@ -96,12 +92,7 @@ func (svc *MasterService) Monitor() {
// monitor
err := svc.monitor()
if err != nil {
trace.PrintError(err)
if svc.stopOnError {
log.Errorf("master[%s] monitor error, now stopping...", svc.GetConfigService().GetNodeKey())
svc.Stop()
return
}
log.Errorf("master[%s] monitor error: %v", svc.cfgSvc.GetNodeKey(), err)
}
// wait
@@ -109,25 +100,10 @@ func (svc *MasterService) Monitor() {
}
}
func (svc *MasterService) GetConfigService() (cfgSvc interfaces.NodeConfigService) {
return svc.cfgSvc
}
func (svc *MasterService) GetConfigPath() (path string) {
return svc.cfgPath
}
func (svc *MasterService) SetConfigPath(path string) {
svc.cfgPath = path
}
func (svc *MasterService) SetMonitorInterval(duration time.Duration) {
svc.monitorInterval = duration
}
func (svc *MasterService) Register() (err error) {
nodeKey := svc.GetConfigService().GetNodeKey()
nodeName := svc.GetConfigService().GetNodeName()
nodeKey := svc.cfgSvc.GetNodeKey()
nodeName := svc.cfgSvc.GetNodeName()
nodeMaxRunners := svc.cfgSvc.GetMaxRunners()
node, err := service.NewModelService[models.Node]().GetOne(bson.M{"key": nodeKey}, nil)
if err != nil && err.Error() == mongo2.ErrNoDocuments.Error() {
// not exists
@@ -135,7 +111,7 @@ func (svc *MasterService) Register() (err error) {
node := models.Node{
Key: nodeKey,
Name: nodeName,
MaxRunners: config.DefaultConfigOptions.MaxRunners,
MaxRunners: nodeMaxRunners,
IsMaster: true,
Status: constants.NodeStatusOnline,
Enabled: true,
@@ -233,7 +209,7 @@ func (svc *MasterService) getAllWorkerNodes() (nodes []models.Node, err error) {
}
func (svc *MasterService) updateMasterNodeStatus() (err error) {
nodeKey := svc.GetConfigService().GetNodeKey()
nodeKey := svc.cfgSvc.GetNodeKey()
node, err := service.NewModelService[models.Node]().GetOne(bson.M{"key": nodeKey}, nil)
if err != nil {
return err
@@ -318,10 +294,8 @@ func (svc *MasterService) sendNotification(node *models.Node) {
func newMasterService() *MasterService {
return &MasterService{
cfgPath: config2.GetConfigPath(),
cfgSvc: config.GetNodeConfigService(),
monitorInterval: 15 * time.Second,
stopOnError: false,
server: server.GetGrpcServer(),
taskSchedulerSvc: scheduler.GetTaskSchedulerService(),
taskHandlerSvc: handler.GetTaskHandlerService(),

View File

@@ -9,7 +9,6 @@ import (
"github.com/apex/log"
"github.com/cenkalti/backoff/v4"
"github.com/crawlab-team/crawlab/core/config"
"github.com/crawlab-team/crawlab/core/grpc/client"
"github.com/crawlab-team/crawlab/core/interfaces"
client2 "github.com/crawlab-team/crawlab/core/models/client"
@@ -28,7 +27,6 @@ type WorkerService struct {
handlerSvc *handler.Service
// settings
cfgPath string
address interfaces.Address
heartbeatInterval time.Duration
@@ -116,14 +114,6 @@ func (svc *WorkerService) GetConfigService() (cfgSvc interfaces.NodeConfigServic
return svc.cfgSvc
}
func (svc *WorkerService) GetConfigPath() (path string) {
return svc.cfgPath
}
func (svc *WorkerService) SetConfigPath(path string) {
svc.cfgPath = path
}
func (svc *WorkerService) subscribe() {
// Configure exponential backoff
b := backoff.NewExponentialBackOff()
@@ -195,7 +185,6 @@ func (svc *WorkerService) sendHeartbeat() {
func newWorkerService() *WorkerService {
return &WorkerService{
cfgPath: config.GetConfigPath(),
heartbeatInterval: 15 * time.Second,
cfgSvc: nodeconfig.GetNodeConfigService(),
client: client.GetGrpcClient(),

View File

@@ -30,7 +30,6 @@ import (
"github.com/crawlab-team/crawlab/grpc"
"github.com/crawlab-team/crawlab/trace"
"github.com/shirou/gopsutil/process"
"github.com/spf13/viper"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/bson/primitive"
)
@@ -298,14 +297,8 @@ func (r *Runner) configureEnv() {
// Default envs
r.cmd.Env = append(os.Environ(), "CRAWLAB_TASK_ID="+r.tid.Hex())
if utils.GetGrpcAddress() != "" {
r.cmd.Env = append(r.cmd.Env, "CRAWLAB_GRPC_ADDRESS="+utils.GetGrpcAddress())
}
if viper.GetString("grpc.authKey") != "" {
r.cmd.Env = append(r.cmd.Env, "CRAWLAB_GRPC_AUTH_KEY="+viper.GetString("grpc.authKey"))
} else {
r.cmd.Env = append(r.cmd.Env, "CRAWLAB_GRPC_AUTH_KEY="+constants.DefaultGrpcAuthKey)
}
r.cmd.Env = append(r.cmd.Env, "CRAWLAB_GRPC_ADDRESS="+utils.GetGrpcAddress())
r.cmd.Env = append(r.cmd.Env, "CRAWLAB_GRPC_AUTH_KEY="+utils.GetAuthKey())
// Global environment variables
envs, err := client.NewModelService[models.Environment]().GetMany(nil, nil)

View File

@@ -2,8 +2,11 @@ package utils
import (
"fmt"
"github.com/apex/log"
"github.com/gin-gonic/gin"
"github.com/mitchellh/go-homedir"
"github.com/spf13/viper"
"path/filepath"
)
const (
@@ -21,6 +24,9 @@ const (
DefaultApiAllowCredentials = "true"
DefaultApiAllowMethods = "DELETE, POST, OPTIONS, GET, PUT"
DefaultApiAllowHeaders = "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With"
DefaultNodeMaxRunners = 16
MetadataConfigDirName = ".crawlab"
MetadataConfigName = "config.json"
)
func IsDev() bool {
@@ -160,3 +166,52 @@ func GetApiEndpoint() string {
}
return DefaultApiEndpoint
}
func IsMaster() bool {
return EnvIsTrue("node.master", false)
}
func GetNodeType() string {
if IsMaster() {
return "master"
} else {
return "worker"
}
}
func GetNodeKey() string {
if res := viper.GetString("node.key"); res != "" {
return res
}
return NewUUIDString()
}
func GetNodeName() string {
if res := viper.GetString("node.name"); res != "" {
return res
}
return GetNodeKey()
}
func GetNodeMaxRunners() int {
if res := viper.GetInt("node.maxRunners"); res != 0 {
return res
}
return DefaultNodeMaxRunners
}
func GetMetadataConfigPath() string {
var homeDirPath, err = homedir.Dir()
if err != nil {
log.Errorf("failed to get home directory: %v", err)
log.Errorf("please set metadata directory path using either CRAWLAB_METADATA environment variable or the metadata path in the configuration file")
panic(err)
}
if viper.GetString("metadata") != "" {
metadataPath := viper.GetString("metadata")
return filepath.Join(metadataPath, MetadataConfigName)
}
return filepath.Join(homeDirPath, MetadataConfigDirName, MetadataConfigName)
}

View File

@@ -1,13 +0,0 @@
package utils
func IsMaster() bool {
return EnvIsTrue("node.master", false)
}
func GetNodeType() string {
if IsMaster() {
return "master"
} else {
return "worker"
}
}