diff --git a/core/controllers/data_source_v2.go b/core/controllers/data_source_v2.go new file mode 100644 index 00000000..6962dc95 --- /dev/null +++ b/core/controllers/data_source_v2.go @@ -0,0 +1,120 @@ +package controllers + +import ( + "github.com/crawlab-team/crawlab/core/ds" + "github.com/crawlab-team/crawlab/core/errors" + "github.com/crawlab-team/crawlab/core/models/models" + "github.com/crawlab-team/crawlab/core/models/service" + "github.com/gin-gonic/gin" + "go.mongodb.org/mongo-driver/bson/primitive" +) + +func PostDataSource(c *gin.Context) { + // data source + var payload struct { + Name string `json:"name"` + Type string `json:"type"` + Description string `json:"description"` + Host string `json:"host"` + Port string `json:"port"` + Url string `json:"url"` + Hosts []string `json:"hosts"` + Database string `json:"database"` + Username string `json:"username"` + Password string `json:"-,omitempty"` + ConnectType string `json:"connect_type"` + Status string `json:"status"` + Error string `json:"error"` + Extra map[string]string `json:"extra,omitempty"` + } + if err := c.ShouldBindJSON(&payload); err != nil { + HandleErrorBadRequest(c, err) + return + } + + u := GetUserFromContextV2(c) + + // add data source to db + dataSource := models.DataSourceV2{ + Name: payload.Name, + Type: payload.Type, + Description: payload.Description, + Host: payload.Host, + Port: payload.Port, + Url: payload.Url, + Hosts: payload.Hosts, + Database: payload.Database, + Username: payload.Username, + Password: payload.Password, + ConnectType: payload.ConnectType, + Status: payload.Status, + Error: payload.Error, + Extra: payload.Extra, + } + dataSource.SetCreated(u.Id) + dataSource.SetUpdated(u.Id) + id, err := service.NewModelServiceV2[models.DataSourceV2]().InsertOne(dataSource) + if err != nil { + HandleErrorInternalServerError(c, err) + return + } + dataSource.Id = id + + // check data source status + go func() { + _ = ds.GetDataSourceServiceV2().CheckStatus(id) + }() + + HandleSuccessWithData(c, dataSource) +} + +func PutDataSourceById(c *gin.Context) { + id, err := primitive.ObjectIDFromHex(c.Param("id")) + if err != nil { + HandleErrorInternalServerError(c, err) + return + } + + // data source + var dataSource models.DataSourceV2 + if err := c.ShouldBindJSON(&dataSource); err != nil { + HandleErrorBadRequest(c, err) + return + } + + err = service.NewModelServiceV2[models.DataSourceV2]().ReplaceById(id, dataSource) + if err != nil { + HandleErrorInternalServerError(c, err) + return + } + + // check data source status + go func() { + _ = ds.GetDataSourceServiceV2().CheckStatus(id) + }() +} + +func PostDataSourceChangePassword(c *gin.Context) { + id, err := primitive.ObjectIDFromHex(c.Param("id")) + if err != nil { + HandleErrorBadRequest(c, err) + return + } + var payload struct { + Password string `json:"password"` + } + if err := c.ShouldBindJSON(&payload); err != nil { + HandleErrorBadRequest(c, err) + return + } + if payload.Password == "" { + HandleErrorBadRequest(c, errors.ErrorDataSourceMissingRequiredFields) + return + } + u := GetUserFromContextV2(c) + if err := ds.GetDataSourceServiceV2().ChangePassword(id, payload.Password, u.Id); err != nil { + HandleErrorInternalServerError(c, err) + return + } + HandleSuccess(c) +} diff --git a/core/controllers/router_v2.go b/core/controllers/router_v2.go index ec23515e..fa8a47a6 100644 --- a/core/controllers/router_v2.go +++ b/core/controllers/router_v2.go @@ -56,10 +56,26 @@ func InitRoutes(app *gin.Engine) (err error) { groups := NewRouterGroups(app) RegisterController(groups.AuthGroup, "/data/collections", NewControllerV2[models.DataCollectionV2]()) - RegisterController(groups.AuthGroup, "/data-sources", NewControllerV2[models.DataSourceV2]()) + RegisterController(groups.AuthGroup, "/data-sources", NewControllerV2[models.DataSourceV2]([]Action{ + { + Method: http.MethodPost, + Path: "", + HandlerFunc: PostDataSource, + }, + { + Method: http.MethodPut, + Path: "/:id", + HandlerFunc: PutDataSourceById, + }, + { + Method: http.MethodPost, + Path: "/:id/change-password", + HandlerFunc: PostDataSourceChangePassword, + }, + }...)) RegisterController(groups.AuthGroup, "/environments", NewControllerV2[models.EnvironmentV2]()) RegisterController(groups.AuthGroup, "/nodes", NewControllerV2[models.NodeV2]()) - RegisterController(groups.AuthGroup, "/notifications/settings", NewControllerV2[models.SettingV2]()) + RegisterController(groups.AuthGroup, "/notifications/settings", NewControllerV2[models.NotificationSettingV2]()) RegisterController(groups.AuthGroup, "/permissions", NewControllerV2[models.PermissionV2]()) RegisterController(groups.AuthGroup, "/projects", NewControllerV2[models.ProjectV2]([]Action{ { diff --git a/core/ds/service_v2.go b/core/ds/service_v2.go new file mode 100644 index 00000000..12093949 --- /dev/null +++ b/core/ds/service_v2.go @@ -0,0 +1,271 @@ +package ds + +import ( + "github.com/apex/log" + "github.com/crawlab-team/crawlab/core/constants" + constants2 "github.com/crawlab-team/crawlab/core/constants" + "github.com/crawlab-team/crawlab/core/models/models" + "github.com/crawlab-team/crawlab/core/models/service" + "github.com/crawlab-team/crawlab/core/result" + "github.com/crawlab-team/crawlab/core/utils" + utils2 "github.com/crawlab-team/crawlab/core/utils" + "github.com/crawlab-team/crawlab/trace" + "go.mongodb.org/mongo-driver/bson/primitive" + "sync" + "time" +) + +type ServiceV2 struct { + // internals + timeout time.Duration + monitorInterval time.Duration + stopped bool +} + +func (svc *ServiceV2) Init() { + // result service registry + reg := result.GetResultServiceRegistry() + + // register result services + reg.Register(constants.DataSourceTypeMongo, NewDataSourceMongoService) + reg.Register(constants.DataSourceTypeMysql, NewDataSourceMysqlService) + reg.Register(constants.DataSourceTypePostgresql, NewDataSourcePostgresqlService) + reg.Register(constants.DataSourceTypeMssql, NewDataSourceMssqlService) + reg.Register(constants.DataSourceTypeSqlite, NewDataSourceSqliteService) + reg.Register(constants.DataSourceTypeCockroachdb, NewDataSourceCockroachdbService) + reg.Register(constants.DataSourceTypeElasticSearch, NewDataSourceElasticsearchService) + reg.Register(constants.DataSourceTypeKafka, NewDataSourceKafkaService) +} + +func (svc *ServiceV2) Start() { + // start monitoring + go svc.Monitor() +} + +func (svc *ServiceV2) Wait() { + utils.DefaultWait() +} + +func (svc *ServiceV2) Stop() { + svc.stopped = true +} + +func (svc *ServiceV2) ChangePassword(id primitive.ObjectID, password string, by primitive.ObjectID) (err error) { + dataSource, err := service.NewModelServiceV2[models.DataSourceV2]().GetById(id) + if err != nil { + return err + } + dataSource.Password, err = utils.EncryptAES(password) + if err != nil { + return err + } + dataSource.SetUpdated(by) + err = service.NewModelServiceV2[models.DataSourceV2]().ReplaceById(id, *dataSource) + if err != nil { + return err + } + return nil +} + +func (svc *ServiceV2) Monitor() { + for { + // return if stopped + if svc.stopped { + return + } + + // monitor + if err := svc.monitor(); err != nil { + trace.PrintError(err) + } + + // wait + time.Sleep(svc.monitorInterval) + } +} + +func (svc *ServiceV2) CheckStatus(id primitive.ObjectID) (err error) { + ds, err := service.NewModelServiceV2[models.DataSourceV2]().GetById(id) + if err != nil { + return err + } + return svc.checkStatus(ds, true) +} + +func (svc *ServiceV2) SetTimeout(duration time.Duration) { + svc.timeout = duration +} + +func (svc *ServiceV2) SetMonitorInterval(duration time.Duration) { + svc.monitorInterval = duration +} + +func (svc *ServiceV2) monitor() (err error) { + // start + tic := time.Now() + log.Debugf("[DataSourceService] start monitoring") + + // data source list + dataSources, err := service.NewModelServiceV2[models.DataSourceV2]().GetMany(nil, nil) + if err != nil { + return err + } + + // waiting group + wg := sync.WaitGroup{} + wg.Add(len(dataSources)) + + // iterate data source list + for _, ds := range dataSources { + // async operation + go func(ds *models.DataSourceV2) { + // check status and save + _ = svc.checkStatus(ds, true) + + // release + wg.Done() + }(&ds) + } + + // wait + wg.Wait() + + // finish + toc := time.Now() + log.Debugf("[DataSourceService] finished monitoring. elapsed: %d ms", (toc.Sub(tic)).Milliseconds()) + + return nil +} + +func (svc *ServiceV2) checkStatus(ds *models.DataSourceV2, save bool) (err error) { + // check status + if err := svc._checkStatus(ds); err != nil { + ds.Status = constants2.DataSourceStatusOffline + ds.Error = err.Error() + } else { + ds.Status = constants2.DataSourceStatusOnline + ds.Error = "" + } + + // save + if save { + return svc._save(ds) + } + + return nil +} + +func (svc *ServiceV2) _save(ds *models.DataSourceV2) (err error) { + log.Debugf("[DataSourceService] saving data source: name=%s, type=%s, status=%s, error=%s", ds.Name, ds.Type, ds.Status, ds.Error) + return service.NewModelServiceV2[models.DataSourceV2]().ReplaceById(ds.Id, *ds) +} + +func (svc *ServiceV2) _checkStatus(ds *models.DataSourceV2) (err error) { + switch ds.Type { + case constants.DataSourceTypeMongo: + _, err := utils2.GetMongoClientWithTimeoutV2(ds, svc.timeout) + if err != nil { + return err + } + case constants.DataSourceTypeMysql: + s, err := utils2.GetMysqlSessionWithTimeoutV2(ds, svc.timeout) + if err != nil { + return err + } + if s != nil { + err := s.Close() + if err != nil { + return err + } + } + case constants.DataSourceTypePostgresql: + s, err := utils2.GetPostgresqlSessionWithTimeoutV2(ds, svc.timeout) + if err != nil { + return err + } + if s != nil { + err := s.Close() + if err != nil { + return err + } + } + case constants.DataSourceTypeMssql: + s, err := utils2.GetMssqlSessionWithTimeoutV2(ds, svc.timeout) + if err != nil { + return err + } + if s != nil { + err := s.Close() + if err != nil { + return err + } + } + case constants.DataSourceTypeSqlite: + s, err := utils2.GetSqliteSessionWithTimeoutV2(ds, svc.timeout) + if err != nil { + return err + } + if s != nil { + err := s.Close() + if err != nil { + return err + } + } + case constants.DataSourceTypeCockroachdb: + s, err := utils2.GetCockroachdbSessionWithTimeoutV2(ds, svc.timeout) + if err != nil { + return err + } + if s != nil { + err := s.Close() + if err != nil { + return err + } + } + case constants.DataSourceTypeElasticSearch: + _, err := utils2.GetElasticsearchClientWithTimeoutV2(ds, svc.timeout) + if err != nil { + return err + } + case constants.DataSourceTypeKafka: + c, err := utils2.GetKafkaConnectionWithTimeoutV2(ds, svc.timeout) + if err != nil { + return err + } + if c != nil { + err := c.Close() + if err != nil { + return err + } + } + default: + log.Warnf("[DataSourceService] invalid data source type: %s", ds.Type) + } + return nil +} + +func NewDataSourceServiceV2() *ServiceV2 { + // service + svc := &ServiceV2{ + monitorInterval: 15 * time.Second, + timeout: 10 * time.Second, + } + + // initialize + svc.Init() + + // start + svc.Start() + + return svc +} + +var _dsSvcV2 *ServiceV2 + +func GetDataSourceServiceV2() *ServiceV2 { + if _dsSvcV2 != nil { + return _dsSvcV2 + } + _dsSvcV2 = NewDataSourceServiceV2() + return _dsSvcV2 +} diff --git a/core/grpc/server/model_base_service_v2_server.go b/core/grpc/server/model_base_service_v2_server.go index 17771925..3050e118 100644 --- a/core/grpc/server/model_base_service_v2_server.go +++ b/core/grpc/server/model_base_service_v2_server.go @@ -26,6 +26,7 @@ var ( *new(models.EnvironmentV2), *new(models.GitV2), *new(models.NodeV2), + *new(models.NotificationSettingV2), *new(models.PermissionV2), *new(models.ProjectV2), *new(models.RolePermissionV2), diff --git a/core/models/models/data_source_v2.go b/core/models/models/data_source_v2.go index eb81e199..16415510 100644 --- a/core/models/models/data_source_v2.go +++ b/core/models/models/data_source_v2.go @@ -12,7 +12,7 @@ type DataSourceV2 struct { Hosts []string `json:"hosts" bson:"hosts"` Database string `json:"database" bson:"database"` Username string `json:"username" bson:"username"` - Password string `json:"password,omitempty" bson:"-"` + Password string `json:"-,omitempty" bson:"password"` ConnectType string `json:"connect_type" bson:"connect_type"` Status string `json:"status" bson:"status"` Error string `json:"error" bson:"error"` diff --git a/core/models/models/notification_setting_v2.go b/core/models/models/notification_setting_v2.go index 6b565410..f4e18362 100644 --- a/core/models/models/notification_setting_v2.go +++ b/core/models/models/notification_setting_v2.go @@ -3,17 +3,19 @@ package models import "go.mongodb.org/mongo-driver/bson/primitive" type NotificationSettingV2 struct { - Id primitive.ObjectID `json:"_id" bson:"_id"` - Type string `json:"type" bson:"type"` - Name string `json:"name" bson:"name"` - Description string `json:"description" bson:"description"` - Enabled bool `json:"enabled" bson:"enabled"` - Global bool `json:"global" bson:"global"` - Title string `json:"title,omitempty" bson:"title,omitempty"` - Template string `json:"template,omitempty" bson:"template,omitempty"` - TaskTrigger string `json:"task_trigger" bson:"task_trigger"` - Mail NotificationSettingMail `json:"mail,omitempty" bson:"mail,omitempty"` - Mobile NotificationSettingMobile `json:"mobile,omitempty" bson:"mobile,omitempty"` + any `collection:"notification_settings"` + BaseModelV2[NotificationSettingV2] `bson:",inline"` + Id primitive.ObjectID `json:"_id" bson:"_id"` + Type string `json:"type" bson:"type"` + Name string `json:"name" bson:"name"` + Description string `json:"description" bson:"description"` + Enabled bool `json:"enabled" bson:"enabled"` + Global bool `json:"global" bson:"global"` + Title string `json:"title,omitempty" bson:"title,omitempty"` + Template string `json:"template,omitempty" bson:"template,omitempty"` + TaskTrigger string `json:"task_trigger" bson:"task_trigger"` + Mail NotificationSettingMail `json:"mail,omitempty" bson:"mail,omitempty"` + Mobile NotificationSettingMobile `json:"mobile,omitempty" bson:"mobile,omitempty"` } type NotificationSettingMail struct { diff --git a/core/models/models/spider_v2.go b/core/models/models/spider_v2.go index 939bad51..ac3bf300 100644 --- a/core/models/models/spider_v2.go +++ b/core/models/models/spider_v2.go @@ -5,7 +5,7 @@ import ( ) type SpiderV2 struct { - any `collection:"spiders"` // spider id + any `collection:"spiders"` BaseModelV2[SpiderV2] `bson:",inline"` Name string `json:"name" bson:"name"` // spider name Type string `json:"type" bson:"type"` // spider type diff --git a/core/node/service/master_service_v2.go b/core/node/service/master_service_v2.go index 13aa69c1..6c168c10 100644 --- a/core/node/service/master_service_v2.go +++ b/core/node/service/master_service_v2.go @@ -36,9 +36,9 @@ type MasterServiceV2 struct { schedulerSvc *scheduler.ServiceV2 handlerSvc *handler.ServiceV2 scheduleSvc *schedule.ServiceV2 - notificationSvc *notification.Service + notificationSvc *notification.ServiceV2 spiderAdminSvc *admin.ServiceV2 - systemSvc *system.Service + systemSvc *system.ServiceV2 // settings cfgPath string @@ -368,7 +368,7 @@ func NewMasterServiceV2() (res interfaces.NodeMasterService, err error) { } // notification service - svc.notificationSvc = notification.GetService() + svc.notificationSvc = notification.GetServiceV2() // spider admin service svc.spiderAdminSvc, err = admin.GetSpiderAdminServiceV2() @@ -377,7 +377,7 @@ func NewMasterServiceV2() (res interfaces.NodeMasterService, err error) { } // system service - svc.systemSvc = system.GetService() + svc.systemSvc = system.GetServiceV2() // init if err := svc.Init(); err != nil { diff --git a/core/notification/service_v2.go b/core/notification/service_v2.go index 6db6a01a..10ddca3e 100644 --- a/core/notification/service_v2.go +++ b/core/notification/service_v2.go @@ -41,7 +41,6 @@ func (svc *ServiceV2) initData() (err error) { // data to initialize settings := []models.NotificationSettingV2{ { - Id: primitive.NewObjectID(), Type: TypeMail, Enabled: true, Name: "任务通知(邮件)", @@ -77,7 +76,6 @@ func (svc *ServiceV2) initData() (err error) { }, }, { - Id: primitive.NewObjectID(), Type: TypeMail, Enabled: true, Name: "Task Change (Mail)", @@ -113,7 +111,6 @@ Please find the task data as below. }, }, { - Id: primitive.NewObjectID(), Type: TypeMobile, Enabled: true, Name: "任务通知(移动端)", @@ -142,7 +139,6 @@ Please find the task data as below. Mobile: models.NotificationSettingMobile{}, }, { - Id: primitive.NewObjectID(), Type: TypeMobile, Enabled: true, Name: "Task Change (Mobile)", diff --git a/core/system/service_v2.go b/core/system/service_v2.go new file mode 100644 index 00000000..e61513ed --- /dev/null +++ b/core/system/service_v2.go @@ -0,0 +1,67 @@ +package system + +import ( + "github.com/crawlab-team/crawlab/core/models/models" + "github.com/crawlab-team/crawlab/core/models/service" + "go.mongodb.org/mongo-driver/bson" +) + +type ServiceV2 struct { +} + +func (svc *ServiceV2) Init() (err error) { + // initialize data + if err := svc.initData(); err != nil { + return err + } + + return nil +} + +func (svc *ServiceV2) initData() (err error) { + total, err := service.NewModelServiceV2[models.SettingV2]().Count(bson.M{ + "key": "site_title", + }) + if err != nil { + return err + } + if total > 0 { + return nil + } + + // data to initialize + settings := []models.SettingV2{ + { + Key: "site_title", + Value: bson.M{ + "customize_site_title": false, + "site_title": "", + }, + }, + } + _, err = service.NewModelServiceV2[models.SettingV2]().InsertMany(settings) + if err != nil { + return err + } + return nil +} + +func NewServiceV2() *ServiceV2 { + // service + svc := &ServiceV2{} + + if err := svc.Init(); err != nil { + panic(err) + } + + return svc +} + +var _serviceV2 *ServiceV2 + +func GetServiceV2() *ServiceV2 { + if _serviceV2 == nil { + _serviceV2 = NewServiceV2() + } + return _serviceV2 +} diff --git a/core/utils/cockroachdb.go b/core/utils/cockroachdb.go index 650c32f1..feccc5ec 100644 --- a/core/utils/cockroachdb.go +++ b/core/utils/cockroachdb.go @@ -58,3 +58,48 @@ func getCockroachdbSession(ctx context.Context, ds *models.DataSource) (s db.Ses return s, err } + +func GetCockroachdbSessionWithTimeoutV2(ds *models.DataSourceV2, timeout time.Duration) (s db.Session, err error) { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + return getCockroachdbSessionV2(ctx, ds) +} + +func getCockroachdbSessionV2(ctx context.Context, ds *models.DataSourceV2) (s db.Session, err error) { + // normalize settings + host := ds.Host + port := ds.Port + if ds.Host == "" { + host = constants.DefaultHost + } + if ds.Port == "" { + port = constants.DefaultCockroachdbPort + } + + // connect settings + settings := mssql.ConnectionURL{ + User: ds.Username, + Password: ds.Password, + Database: ds.Database, + Host: fmt.Sprintf("%s:%s", host, port), + Options: nil, + } + + // session + done := make(chan struct{}) + go func() { + s, err = mssql.Open(settings) + close(done) + }() + + // wait for done + select { + case <-ctx.Done(): + if ctx.Err() != nil { + err = ctx.Err() + } + case <-done: + } + + return s, err +} diff --git a/core/utils/es.go b/core/utils/es.go index 6636d0da..7723205b 100644 --- a/core/utils/es.go +++ b/core/utils/es.go @@ -50,34 +50,75 @@ func getElasticsearchClient(ctx context.Context, ds *models.DataSource) (c *elas Addresses: addresses, Username: ds.Username, Password: ds.Password, - //CloudID: "", - //APIKey: "", - //ServiceToken: "", - //CertificateFingerprint: "", - //Header: nil, - //CACert: nil, - //RetryOnStatus: nil, - //DisableRetry: false, - //EnableRetryOnTimeout: false, - //MaxRetries: 0, - //CompressRequestBody: false, - //DiscoverNodesOnStart: false, - //DiscoverNodesInterval: 0, - //EnableMetrics: false, - //EnableDebugLogger: false, - //EnableCompatibilityMode: false, - //DisableMetaHeader: false, - //UseResponseCheckOnly: false, RetryBackoff: func(i int) time.Duration { if i == 1 { rb.Reset() } return rb.NextBackOff() }, - //Transport: nil, - //Logger: nil, - //Selector: nil, - //ConnectionPoolFunc: nil, + } + + // es client + done := make(chan struct{}) + go func() { + c, err = elasticsearch.NewClient(cfg) + if err != nil { + return + } + var res *esapi.Response + res, err = c.Info() + fmt.Println(res) + close(done) + }() + + // wait for done + select { + case <-ctx.Done(): + if ctx.Err() != nil { + err = ctx.Err() + } + case <-done: + } + + return c, err +} + +func GetElasticsearchClientWithTimeoutV2(ds *models.DataSourceV2, timeout time.Duration) (c *elasticsearch.Client, err error) { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + return getElasticsearchClientV2(ctx, ds) +} + +func getElasticsearchClientV2(ctx context.Context, ds *models.DataSourceV2) (c *elasticsearch.Client, err error) { + // normalize settings + host := ds.Host + port := ds.Port + if ds.Host == "" { + host = constants.DefaultHost + } + if ds.Port == "" { + port = constants.DefaultElasticsearchPort + } + + // es hosts + addresses := []string{ + fmt.Sprintf("http://%s:%s", host, port), + } + + // retry backoff + rb := backoff.NewExponentialBackOff() + + // es client options + cfg := elasticsearch.Config{ + Addresses: addresses, + Username: ds.Username, + Password: ds.Password, + RetryBackoff: func(i int) time.Duration { + if i == 1 { + rb.Reset() + } + return rb.NextBackOff() + }, } // es client diff --git a/core/utils/kafka.go b/core/utils/kafka.go index 1bf005f2..e3360392 100644 --- a/core/utils/kafka.go +++ b/core/utils/kafka.go @@ -39,3 +39,30 @@ func getKafkaConnection(ctx context.Context, ds *models.DataSource) (c *kafka.Co // kafka connection return kafka.DialLeader(ctx, network, address, topic, partition) } + +func GetKafkaConnectionWithTimeoutV2(ds *models.DataSourceV2, timeout time.Duration) (c *kafka.Conn, err error) { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + return getKafkaConnectionV2(ctx, ds) +} + +func getKafkaConnectionV2(ctx context.Context, ds *models.DataSourceV2) (c *kafka.Conn, err error) { + // normalize settings + host := ds.Host + port := ds.Port + if ds.Host == "" { + host = constants.DefaultHost + } + if ds.Port == "" { + port = constants.DefaultKafkaPort + } + + // kafka connection address + network := "tcp" + address := fmt.Sprintf("%s:%s", host, port) + topic := ds.Database + partition := 0 // TODO: parameterize + + // kafka connection + return kafka.DialLeader(ctx, network, address, topic, partition) +} diff --git a/core/utils/mongo.go b/core/utils/mongo.go index 5807f24a..23e92541 100644 --- a/core/utils/mongo.go +++ b/core/utils/mongo.go @@ -54,6 +54,12 @@ func GetMongoClientWithTimeout(ds *models.DataSource, timeout time.Duration) (c return getMongoClient(ctx, ds) } +func GetMongoClientWithTimeoutV2(ds *models.DataSourceV2, timeout time.Duration) (c *mongo2.Client, err error) { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + return getMongoClientV2(ctx, ds) +} + func getMongoClient(ctx context.Context, ds *models.DataSource) (c *mongo2.Client, err error) { // normalize settings if ds.Host == "" { @@ -92,3 +98,42 @@ func getMongoClient(ctx context.Context, ds *models.DataSource) (c *mongo2.Clien // client return mongo.GetMongoClient(opts...) } + +func getMongoClientV2(ctx context.Context, ds *models.DataSourceV2) (c *mongo2.Client, err error) { + // normalize settings + if ds.Host == "" { + ds.Host = constants.DefaultHost + } + if ds.Port == "" { + ds.Port = constants.DefaultMongoPort + } + + // options + var opts []mongo.ClientOption + opts = append(opts, mongo.WithContext(ctx)) + opts = append(opts, mongo.WithUri(ds.Url)) + opts = append(opts, mongo.WithHost(ds.Host)) + opts = append(opts, mongo.WithPort(ds.Port)) + opts = append(opts, mongo.WithDb(ds.Database)) + opts = append(opts, mongo.WithUsername(ds.Username)) + opts = append(opts, mongo.WithPassword(ds.Password)) + opts = append(opts, mongo.WithHosts(ds.Hosts)) + + // extra + if ds.Extra != nil { + // auth source + authSource, ok := ds.Extra["auth_source"] + if ok { + opts = append(opts, mongo.WithAuthSource(authSource)) + } + + // auth mechanism + authMechanism, ok := ds.Extra["auth_mechanism"] + if ok { + opts = append(opts, mongo.WithAuthMechanism(authMechanism)) + } + } + + // client + return mongo.GetMongoClient(opts...) +} diff --git a/core/utils/mssql.go b/core/utils/mssql.go index 0fb21353..03329eae 100644 --- a/core/utils/mssql.go +++ b/core/utils/mssql.go @@ -58,3 +58,48 @@ func getMssqlSession(ctx context.Context, ds *models.DataSource) (s db.Session, return s, err } + +func GetMssqlSessionWithTimeoutV2(ds *models.DataSourceV2, timeout time.Duration) (s db.Session, err error) { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + return getMssqlSessionV2(ctx, ds) +} + +func getMssqlSessionV2(ctx context.Context, ds *models.DataSourceV2) (s db.Session, err error) { + // normalize settings + host := ds.Host + port := ds.Port + if ds.Host == "" { + host = constants.DefaultHost + } + if ds.Port == "" { + port = constants.DefaultMssqlPort + } + + // connect settings + settings := mssql.ConnectionURL{ + User: ds.Username, + Password: ds.Password, + Database: ds.Database, + Host: fmt.Sprintf("%s:%s", host, port), + Options: nil, + } + + // session + done := make(chan struct{}) + go func() { + s, err = mssql.Open(settings) + close(done) + }() + + // wait for done + select { + case <-ctx.Done(): + if ctx.Err() != nil { + err = ctx.Err() + } + case <-done: + } + + return s, err +} diff --git a/core/utils/mysql.go b/core/utils/mysql.go index 8fc9ae18..c1e78a20 100644 --- a/core/utils/mysql.go +++ b/core/utils/mysql.go @@ -58,3 +58,48 @@ func getMysqlSession(ctx context.Context, ds *models.DataSource) (s db.Session, return s, err } + +func GetMysqlSessionWithTimeoutV2(ds *models.DataSourceV2, timeout time.Duration) (s db.Session, err error) { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + return getMysqlSessionV2(ctx, ds) +} + +func getMysqlSessionV2(ctx context.Context, ds *models.DataSourceV2) (s db.Session, err error) { + // normalize settings + host := ds.Host + port := ds.Port + if ds.Host == "" { + host = constants.DefaultHost + } + if ds.Port == "" { + port = constants.DefaultMysqlPort + } + + // connect settings + settings := mysql.ConnectionURL{ + User: ds.Username, + Password: ds.Password, + Database: ds.Database, + Host: fmt.Sprintf("%s:%s", host, port), + Options: nil, + } + + // session + done := make(chan struct{}) + go func() { + s, err = mysql.Open(settings) + close(done) + }() + + // wait for done + select { + case <-ctx.Done(): + if ctx.Err() != nil { + err = ctx.Err() + } + case <-done: + } + + return s, err +} diff --git a/core/utils/postgresql.go b/core/utils/postgresql.go index 40c8a208..cf3ba7dc 100644 --- a/core/utils/postgresql.go +++ b/core/utils/postgresql.go @@ -58,3 +58,48 @@ func getPostgresqlSession(ctx context.Context, ds *models.DataSource) (s db.Sess return s, err } + +func GetPostgresqlSessionWithTimeoutV2(ds *models.DataSourceV2, timeout time.Duration) (s db.Session, err error) { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + return getPostgresqlSessionV2(ctx, ds) +} + +func getPostgresqlSessionV2(ctx context.Context, ds *models.DataSourceV2) (s db.Session, err error) { + // normalize settings + host := ds.Host + port := ds.Port + if ds.Host == "" { + host = constants.DefaultHost + } + if ds.Port == "" { + port = constants.DefaultPostgresqlPort + } + + // connect settings + settings := postgresql.ConnectionURL{ + User: ds.Username, + Password: ds.Password, + Database: ds.Database, + Host: fmt.Sprintf("%s:%s", host, port), + Options: nil, + } + + // session + done := make(chan struct{}) + go func() { + s, err = postgresql.Open(settings) + close(done) + }() + + // wait for done + select { + case <-ctx.Done(): + if ctx.Err() != nil { + err = ctx.Err() + } + case <-done: + } + + return s, err +} diff --git a/core/utils/sqlite.go b/core/utils/sqlite.go index 1d6ff682..83a06cbc 100644 --- a/core/utils/sqlite.go +++ b/core/utils/sqlite.go @@ -43,3 +43,35 @@ func getSqliteSession(ctx context.Context, ds *models.DataSource) (s db.Session, return s, err } + +func GetSqliteSessionWithTimeoutV2(ds *models.DataSourceV2, timeout time.Duration) (s db.Session, err error) { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + return getSqliteSessionV2(ctx, ds) +} + +func getSqliteSessionV2(ctx context.Context, ds *models.DataSourceV2) (s db.Session, err error) { + // connect settings + settings := sqlite.ConnectionURL{ + Database: ds.Database, + Options: nil, + } + + // session + done := make(chan struct{}) + go func() { + s, err = sqlite.Open(settings) + close(done) + }() + + // wait for done + select { + case <-ctx.Done(): + if ctx.Err() != nil { + err = ctx.Err() + } + case <-done: + } + + return s, err +}