Files
crawlab/core/mongo/client.go

159 lines
3.6 KiB
Go

package mongo
import (
"context"
"encoding/json"
"fmt"
"github.com/cenkalti/backoff/v4"
"github.com/spf13/viper"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
"sync"
"time"
)
var (
	// AppName is set as the driver's AppName option on every client built by
	// newMongoClient.
	AppName = "crawlab-db"

	// _clientMap caches connected clients keyed by the JSON encoding of their
	// resolved ClientOptions; _mu serializes access to it.
	_clientMap = map[string]*mongo.Client{}
	_mu        sync.Mutex
)
// GetMongoClient returns a connected *mongo.Client for the given options,
// reusing a previously created client when the resolved options are identical.
//
// Fields not set through opts are filled from viper keys under "mongo.*"
// (uri, host, port, db, hosts, username, password, authSource, authMechanism,
// authMechanismProperties). The fallbacks are host "localhost", port 27017,
// db "admin" and authSource "admin". The resolved options, serialized as JSON,
// form the cache key.
//
// All access to the client cache happens under _mu. When two callers race to
// create a client for the same key, only one client is kept and the duplicate
// is disconnected.
func GetMongoClient(opts ...ClientOption) (c *mongo.Client, err error) {
	// client options
	_opts := &ClientOptions{}
	for _, op := range opts {
		op(_opts)
	}
	if _opts.Uri == "" {
		_opts.Uri = viper.GetString("mongo.uri")
	}
	if _opts.Host == "" {
		_opts.Host = viper.GetString("mongo.host")
		if _opts.Host == "" {
			_opts.Host = "localhost"
		}
	}
	if _opts.Port == 0 {
		_opts.Port = viper.GetInt("mongo.port")
		if _opts.Port == 0 {
			_opts.Port = 27017
		}
	}
	if _opts.Db == "" {
		_opts.Db = viper.GetString("mongo.db")
		if _opts.Db == "" {
			_opts.Db = "admin"
		}
	}
	if len(_opts.Hosts) == 0 {
		_opts.Hosts = viper.GetStringSlice("mongo.hosts")
	}
	if _opts.Username == "" {
		_opts.Username = viper.GetString("mongo.username")
	}
	if _opts.Password == "" {
		_opts.Password = viper.GetString("mongo.password")
	}
	if _opts.AuthSource == "" {
		_opts.AuthSource = viper.GetString("mongo.authSource")
		if _opts.AuthSource == "" {
			_opts.AuthSource = "admin"
		}
	}
	if _opts.AuthMechanism == "" {
		_opts.AuthMechanism = viper.GetString("mongo.authMechanism")
	}
	if _opts.AuthMechanismProperties == nil {
		_opts.AuthMechanismProperties = viper.GetStringMapString("mongo.authMechanismProperties")
	}

	// client options key json string
	_optsKeyBytes, err := json.Marshal(_opts)
	if err != nil {
		return nil, err
	}
	_optsKey := string(_optsKeyBytes)

	// Attempt to get the client by its options. The read must hold _mu
	// because other goroutines may be inserting into _clientMap concurrently.
	_mu.Lock()
	c, ok := _clientMap[_optsKey]
	_mu.Unlock()
	if ok {
		return c, nil
	}

	// Create the new client outside the lock. newMongoClient retries with
	// exponential backoff, and holding _mu here would stall callers for
	// unrelated keys.
	c, err = newMongoClient(_opts)
	if err != nil {
		logger.Errorf("create mongo client error: %v", err)
		return nil, err
	}

	// Add to the map unless a concurrent caller with the same options stored
	// a client first. In that case keep the cached client so every caller
	// shares one instance, and release the duplicate instead of leaking it.
	_mu.Lock()
	cached, ok := _clientMap[_optsKey]
	if !ok {
		_clientMap[_optsKey] = c
	}
	_mu.Unlock()
	if ok {
		ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
		defer cancel()
		if dErr := c.Disconnect(ctx); dErr != nil {
			logger.Errorf("disconnect duplicate mongo client error: %v", dErr)
		}
		return cached, nil
	}
	return c, nil
}
// newMongoClient builds driver options from _opts and connects. It retries
// with exponential backoff until an attempt both connects and answers a ping,
// or the backoff gives up.
//
// If _opts.Uri is set, it is applied and the remaining connection fields are
// not consulted. Otherwise:
//   - credentials are attached only when both Username and Password are
//     non-empty;
//   - the seed list is _opts.Hosts if non-empty, else
//     "mongodb://<Host>:<Port>/<Db>".
//
// A failed attempt's client is disconnected before the next retry, and c is
// only assigned from a successful attempt. On error, c is nil.
func newMongoClient(_opts *ClientOptions) (c *mongo.Client, err error) {
	// mongo client options
	mongoOpts := &options.ClientOptions{
		AppName: &AppName,
	}
	if _opts.Uri != "" {
		// uri is set
		mongoOpts.ApplyURI(_opts.Uri)
	} else {
		// uri is unset
		if _opts.Username != "" && _opts.Password != "" {
			// username and password are set
			mongoOpts.SetAuth(options.Credential{
				AuthMechanism:           _opts.AuthMechanism,
				AuthMechanismProperties: _opts.AuthMechanismProperties,
				AuthSource:              _opts.AuthSource,
				Username:                _opts.Username,
				Password:                _opts.Password,
				PasswordSet:             true,
			})
		}
		if len(_opts.Hosts) > 0 {
			// hosts are set
			mongoOpts.SetHosts(_opts.Hosts)
		} else {
			// hosts are unset
			mongoOpts.ApplyURI(fmt.Sprintf("mongodb://%s:%d/%s", _opts.Host, _opts.Port, _opts.Db))
		}
	}

	// Attempt to connect with retry. Each attempt owns its own client, which
	// is published to c only after a successful ping.
	op := func() error {
		ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
		defer cancel()
		logger.Infof("connecting to mongo")
		client, connErr := mongo.Connect(ctx, mongoOpts)
		if connErr != nil {
			return connErr
		}
		if pingErr := client.Ping(ctx, nil); pingErr != nil {
			logger.Errorf("ping mongo error: %v", pingErr)
			// mongo.Connect starts background monitoring for this client.
			// Release it so failed attempts do not accumulate across retries.
			// Use a fresh bounded context because ctx may already have expired.
			dctx, dcancel := context.WithTimeout(context.Background(), 5*time.Second)
			defer dcancel()
			// Best-effort cleanup: the ping error is what drives the retry.
			_ = client.Disconnect(dctx)
			return pingErr
		}
		logger.Infof("connected to mongo")
		c = client
		return nil
	}
	b := backoff.NewExponentialBackOff()
	n := func(err error, duration time.Duration) {
		logger.Errorf("connect to mongo error: %v. retrying in %.1fs", err, duration.Seconds())
	}
	err = backoff.RetryNotify(op, b, n)
	if err != nil {
		logger.Errorf("connect to mongo error: %v", err)
		return nil, err
	}
	return c, nil
}