package mongo import ( "context" "encoding/json" "fmt" "github.com/cenkalti/backoff/v4" "github.com/spf13/viper" "go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo/options" "sync" "time" ) var AppName = "crawlab-db" var _clientMap = map[string]*mongo.Client{} var _mu sync.Mutex func GetMongoClient(opts ...ClientOption) (c *mongo.Client, err error) { // client options _opts := &ClientOptions{} for _, op := range opts { op(_opts) } if _opts.Uri == "" { _opts.Uri = viper.GetString("mongo.uri") } if _opts.Host == "" { _opts.Host = viper.GetString("mongo.host") if _opts.Host == "" { _opts.Host = "localhost" } } if _opts.Port == 0 { _opts.Port = viper.GetInt("mongo.port") if _opts.Port == 0 { _opts.Port = 27017 } } if _opts.Db == "" { _opts.Db = viper.GetString("mongo.db") if _opts.Db == "" { _opts.Db = "admin" } } if len(_opts.Hosts) == 0 { _opts.Hosts = viper.GetStringSlice("mongo.hosts") } if _opts.Username == "" { _opts.Username = viper.GetString("mongo.username") } if _opts.Password == "" { _opts.Password = viper.GetString("mongo.password") } if _opts.AuthSource == "" { _opts.AuthSource = viper.GetString("mongo.authSource") if _opts.AuthSource == "" { _opts.AuthSource = "admin" } } if _opts.AuthMechanism == "" { _opts.AuthMechanism = viper.GetString("mongo.authMechanism") } if _opts.AuthMechanismProperties == nil { _opts.AuthMechanismProperties = viper.GetStringMapString("mongo.authMechanismProperties") } // client options key json string _optsKeyBytes, err := json.Marshal(_opts) if err != nil { return nil, err } _optsKey := string(_optsKeyBytes) // attempt to get client by client options c, ok := _clientMap[_optsKey] if ok { return c, nil } // create new mongo client c, err = newMongoClient(_opts) if err != nil { logger.Errorf("create mongo client error: %v", err) return nil, err } // add to map _mu.Lock() _clientMap[_optsKey] = c _mu.Unlock() return c, nil } func newMongoClient(_opts *ClientOptions) (c *mongo.Client, err error) { // mongo client options mongoOpts := &options.ClientOptions{ AppName: &AppName, } if _opts.Uri != "" { // uri is set mongoOpts.ApplyURI(_opts.Uri) } else { // uri is unset if _opts.Username != "" && _opts.Password != "" { // username and password are set mongoOpts.SetAuth(options.Credential{ AuthMechanism: _opts.AuthMechanism, AuthMechanismProperties: _opts.AuthMechanismProperties, AuthSource: _opts.AuthSource, Username: _opts.Username, Password: _opts.Password, PasswordSet: true, }) } if len(_opts.Hosts) > 0 { // hosts are set mongoOpts.SetHosts(_opts.Hosts) } else { // hosts are unset mongoOpts.ApplyURI(fmt.Sprintf("mongodb://%s:%d/%s", _opts.Host, _opts.Port, _opts.Db)) } } // attempt to connect with retry op := func() error { ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) defer cancel() logger.Infof("connecting to mongo") c, err = mongo.Connect(ctx, mongoOpts) if err != nil { return err } err = c.Ping(ctx, nil) if err != nil { logger.Errorf("ping mongo error: %v", err) return err } logger.Infof("connected to mongo") return nil } b := backoff.NewExponentialBackOff() n := func(err error, duration time.Duration) { logger.Errorf("connect to mongo error: %v. retrying in %.1fs", err, duration.Seconds()) } err = backoff.RetryNotify(op, b, n) if err != nil { logger.Errorf("connect to mongo error: %v", err) return nil, err } return c, nil }