mirror of
https://github.com/crawlab-team/crawlab.git
synced 2026-01-21 17:21:09 +01:00
feat: added modules
This commit is contained in:
9
db/.editorconfig
Normal file
9
db/.editorconfig
Normal file
@@ -0,0 +1,9 @@
|
||||
root = true
|
||||
|
||||
[*]
|
||||
charset = utf-8
|
||||
end_of_line = lf
|
||||
indent_size = 4
|
||||
indent_style = tab
|
||||
insert_final_newline = true
|
||||
trim_trailing_whitespace = true
|
||||
5
db/.gitignore
vendored
Normal file
5
db/.gitignore
vendored
Normal file
@@ -0,0 +1,5 @@
|
||||
.idea
|
||||
.DS_Store
|
||||
tmp/
|
||||
vendor/
|
||||
go.sum
|
||||
29
db/LICENSE
Normal file
29
db/LICENSE
Normal file
@@ -0,0 +1,29 @@
|
||||
BSD 3-Clause License
|
||||
|
||||
Copyright (c) 2020, Crawlab Team
|
||||
All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
|
||||
1. Redistributions of source code must retain the above copyright notice, this
|
||||
list of conditions and the following disclaimer.
|
||||
|
||||
2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
this list of conditions and the following disclaimer in the documentation
|
||||
and/or other materials provided with the distribution.
|
||||
|
||||
3. Neither the name of the copyright holder nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
2
db/README.md
Normal file
2
db/README.md
Normal file
@@ -0,0 +1,2 @@
|
||||
# crawlab-db
|
||||
Backend database module for Crawlab
|
||||
6
db/errors/base.go
Normal file
6
db/errors/base.go
Normal file
@@ -0,0 +1,6 @@
|
||||
package errors
|
||||
|
||||
const (
|
||||
errorPrefixMongo = "mongo"
|
||||
errorPrefixRedis = "redis"
|
||||
)
|
||||
10
db/errors/errors.go
Normal file
10
db/errors/errors.go
Normal file
@@ -0,0 +1,10 @@
|
||||
package errors
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
ErrInvalidType = errors.New("invalid type")
|
||||
ErrMissingValue = errors.New("missing value")
|
||||
ErrNoCursor = errors.New("no cursor")
|
||||
ErrAlreadyLocked = errors.New("already locked")
|
||||
)
|
||||
15
db/errors/redis.go
Normal file
15
db/errors/redis.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package errors
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrorRedisInvalidType = NewRedisError("invalid type")
|
||||
ErrorRedisLocked = NewRedisError("locked")
|
||||
)
|
||||
|
||||
func NewRedisError(msg string) (err error) {
|
||||
return errors.New(fmt.Sprintf("%s: %s", errorPrefixRedis, msg))
|
||||
}
|
||||
8
db/generic/base.go
Normal file
8
db/generic/base.go
Normal file
@@ -0,0 +1,8 @@
|
||||
package generic
|
||||
|
||||
const (
|
||||
DataSourceTypeMongo = "mongo"
|
||||
DataSourceTypeMysql = "mysql"
|
||||
DataSourceTypePostgres = "postgres"
|
||||
DataSourceTypeElasticSearch = "postgres"
|
||||
)
|
||||
15
db/generic/list.go
Normal file
15
db/generic/list.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package generic
|
||||
|
||||
type ListQueryCondition struct {
|
||||
Key string
|
||||
Op string
|
||||
Value interface{}
|
||||
}
|
||||
|
||||
type ListQuery []ListQueryCondition
|
||||
|
||||
type ListOptions struct {
|
||||
Skip int
|
||||
Limit int
|
||||
Sort []ListSort
|
||||
}
|
||||
7
db/generic/op.go
Normal file
7
db/generic/op.go
Normal file
@@ -0,0 +1,7 @@
|
||||
package generic
|
||||
|
||||
type Op string
|
||||
|
||||
const (
|
||||
OpEqual = "eq"
|
||||
)
|
||||
13
db/generic/sort.go
Normal file
13
db/generic/sort.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package generic
|
||||
|
||||
type SortDirection string
|
||||
|
||||
const (
|
||||
SortDirectionAsc SortDirection = "asc"
|
||||
SortDirectionDesc SortDirection = "desc"
|
||||
)
|
||||
|
||||
type ListSort struct {
|
||||
Key string
|
||||
Direction SortDirection
|
||||
}
|
||||
47
db/go.mod
Normal file
47
db/go.mod
Normal file
@@ -0,0 +1,47 @@
|
||||
module github.com/crawlab-team/crawlab/db
|
||||
|
||||
go 1.22
|
||||
|
||||
require (
|
||||
github.com/apex/log v1.9.0
|
||||
github.com/cenkalti/backoff/v4 v4.1.0
|
||||
github.com/crawlab-team/go-trace v0.1.0
|
||||
github.com/gomodule/redigo v2.0.0+incompatible
|
||||
github.com/jmoiron/sqlx v1.2.0
|
||||
github.com/spf13/viper v1.7.1
|
||||
github.com/stretchr/testify v1.6.1
|
||||
go.mongodb.org/mongo-driver v1.15.0
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/fsnotify/fsnotify v1.4.7 // indirect
|
||||
github.com/golang/snappy v0.0.1 // indirect
|
||||
github.com/hashicorp/hcl v1.0.0 // indirect
|
||||
github.com/klauspost/compress v1.13.6 // indirect
|
||||
github.com/lib/pq v1.1.1 // indirect
|
||||
github.com/logrusorgru/aurora v0.0.0-20181002194514-a7b3b318ed4e // indirect
|
||||
github.com/magiconair/properties v1.8.1 // indirect
|
||||
github.com/mitchellh/mapstructure v1.1.2 // indirect
|
||||
github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe // indirect
|
||||
github.com/pelletier/go-toml v1.7.0 // indirect
|
||||
github.com/pkg/errors v0.9.1 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/spf13/afero v1.1.2 // indirect
|
||||
github.com/spf13/cast v1.3.0 // indirect
|
||||
github.com/spf13/jwalterweatherman v1.0.0 // indirect
|
||||
github.com/spf13/pflag v1.0.3 // indirect
|
||||
github.com/subosito/gotenv v1.2.0 // indirect
|
||||
github.com/xdg-go/pbkdf2 v1.0.0 // indirect
|
||||
github.com/xdg-go/scram v1.1.2 // indirect
|
||||
github.com/xdg-go/stringprep v1.0.4 // indirect
|
||||
github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d // indirect
|
||||
github.com/ztrue/tracerr v0.3.0 // indirect
|
||||
golang.org/x/crypto v0.21.0 // indirect
|
||||
golang.org/x/sync v0.1.0 // indirect
|
||||
golang.org/x/sys v0.18.0 // indirect
|
||||
golang.org/x/text v0.14.0 // indirect
|
||||
gopkg.in/ini.v1 v1.51.0 // indirect
|
||||
gopkg.in/yaml.v2 v2.2.8 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.0 // indirect
|
||||
)
|
||||
39
db/interfaces.go
Normal file
39
db/interfaces.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package db
|
||||
|
||||
import "time"
|
||||
|
||||
type RedisClient interface {
|
||||
Ping() (err error)
|
||||
Keys(pattern string) (values []string, err error)
|
||||
AllKeys() (values []string, err error)
|
||||
Get(collection string) (value string, err error)
|
||||
Set(collection string, value string) (err error)
|
||||
Del(collection string) (err error)
|
||||
RPush(collection string, value interface{}) (err error)
|
||||
LPush(collection string, value interface{}) (err error)
|
||||
LPop(collection string) (value string, err error)
|
||||
RPop(collection string) (value string, err error)
|
||||
LLen(collection string) (count int, err error)
|
||||
BRPop(collection string, timeout int) (value string, err error)
|
||||
BLPop(collection string, timeout int) (value string, err error)
|
||||
HSet(collection string, key string, value string) (err error)
|
||||
HGet(collection string, key string) (value string, err error)
|
||||
HDel(collection string, key string) (err error)
|
||||
HScan(collection string) (results map[string]string, err error)
|
||||
HKeys(collection string) (results []string, err error)
|
||||
ZAdd(collection string, score float32, value interface{}) (err error)
|
||||
ZCount(collection string, min string, max string) (count int, err error)
|
||||
ZCountAll(collection string) (count int, err error)
|
||||
ZScan(collection string, pattern string, count int) (results []string, err error)
|
||||
ZPopMax(collection string, count int) (results []string, err error)
|
||||
ZPopMin(collection string, count int) (results []string, err error)
|
||||
ZPopMaxOne(collection string) (value string, err error)
|
||||
ZPopMinOne(collection string) (value string, err error)
|
||||
BZPopMax(collection string, timeout int) (value string, err error)
|
||||
BZPopMin(collection string, timeout int) (value string, err error)
|
||||
Lock(lockKey string) (value int64, err error)
|
||||
UnLock(lockKey string, value int64)
|
||||
MemoryStats() (stats map[string]int64, err error)
|
||||
SetBackoffMaxInterval(interval time.Duration)
|
||||
SetTimeout(timeout int)
|
||||
}
|
||||
147
db/mongo/client.go
Normal file
147
db/mongo/client.go
Normal file
@@ -0,0 +1,147 @@
|
||||
package mongo
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/apex/log"
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
"github.com/crawlab-team/go-trace"
|
||||
"github.com/spf13/viper"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"go.mongodb.org/mongo-driver/mongo/options"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var AppName = "crawlab-db"
|
||||
|
||||
var _clientMap = map[string]*mongo.Client{}
|
||||
var _mu sync.Mutex
|
||||
|
||||
func GetMongoClient(opts ...ClientOption) (c *mongo.Client, err error) {
|
||||
// client options
|
||||
_opts := &ClientOptions{}
|
||||
for _, op := range opts {
|
||||
op(_opts)
|
||||
}
|
||||
if _opts.Uri == "" {
|
||||
_opts.Uri = viper.GetString("mongo.uri")
|
||||
}
|
||||
if _opts.Host == "" {
|
||||
_opts.Host = viper.GetString("mongo.host")
|
||||
if _opts.Host == "" {
|
||||
_opts.Host = "localhost"
|
||||
}
|
||||
}
|
||||
if _opts.Port == "" {
|
||||
_opts.Port = viper.GetString("mongo.port")
|
||||
if _opts.Port == "" {
|
||||
_opts.Port = "27017"
|
||||
}
|
||||
}
|
||||
if _opts.Db == "" {
|
||||
_opts.Db = viper.GetString("mongo.db")
|
||||
if _opts.Db == "" {
|
||||
_opts.Db = "admin"
|
||||
}
|
||||
}
|
||||
if len(_opts.Hosts) == 0 {
|
||||
_opts.Hosts = viper.GetStringSlice("mongo.hosts")
|
||||
}
|
||||
if _opts.Username == "" {
|
||||
_opts.Username = viper.GetString("mongo.username")
|
||||
}
|
||||
if _opts.Password == "" {
|
||||
_opts.Password = viper.GetString("mongo.password")
|
||||
}
|
||||
if _opts.AuthSource == "" {
|
||||
_opts.AuthSource = viper.GetString("mongo.authSource")
|
||||
if _opts.AuthSource == "" {
|
||||
_opts.AuthSource = "admin"
|
||||
}
|
||||
}
|
||||
if _opts.AuthMechanism == "" {
|
||||
_opts.AuthMechanism = viper.GetString("mongo.authMechanism")
|
||||
}
|
||||
if _opts.AuthMechanismProperties == nil {
|
||||
_opts.AuthMechanismProperties = viper.GetStringMapString("mongo.authMechanismProperties")
|
||||
}
|
||||
|
||||
// client options key json string
|
||||
_optsKeyBytes, err := json.Marshal(_opts)
|
||||
if err != nil {
|
||||
return nil, trace.TraceError(err)
|
||||
}
|
||||
_optsKey := string(_optsKeyBytes)
|
||||
|
||||
// attempt to get client by client options
|
||||
c, ok := _clientMap[_optsKey]
|
||||
if ok {
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// create new mongo client
|
||||
c, err = newMongoClient(_opts.Context, _opts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// add to map
|
||||
_mu.Lock()
|
||||
_clientMap[_optsKey] = c
|
||||
_mu.Unlock()
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func newMongoClient(ctx context.Context, _opts *ClientOptions) (c *mongo.Client, err error) {
|
||||
// mongo client options
|
||||
mongoOpts := &options.ClientOptions{
|
||||
AppName: &AppName,
|
||||
}
|
||||
|
||||
if _opts.Uri != "" {
|
||||
// uri is set
|
||||
mongoOpts.ApplyURI(_opts.Uri)
|
||||
} else {
|
||||
// uri is unset
|
||||
|
||||
// username and password are set
|
||||
if _opts.Username != "" && _opts.Password != "" {
|
||||
mongoOpts.SetAuth(options.Credential{
|
||||
AuthMechanism: _opts.AuthMechanism,
|
||||
AuthMechanismProperties: _opts.AuthMechanismProperties,
|
||||
AuthSource: _opts.AuthSource,
|
||||
Username: _opts.Username,
|
||||
Password: _opts.Password,
|
||||
PasswordSet: true,
|
||||
})
|
||||
}
|
||||
|
||||
if len(_opts.Hosts) > 0 {
|
||||
// hosts are set
|
||||
mongoOpts.SetHosts(_opts.Hosts)
|
||||
} else {
|
||||
// hosts are unset
|
||||
mongoOpts.ApplyURI(fmt.Sprintf("mongodb://%s:%s/%s", _opts.Host, _opts.Port, _opts.Db))
|
||||
}
|
||||
}
|
||||
|
||||
// attempt to connect with retry
|
||||
bp := backoff.NewExponentialBackOff()
|
||||
err = backoff.Retry(func() error {
|
||||
errMsg := fmt.Sprintf("waiting for connect mongo database, after %f seconds try again.", bp.NextBackOff().Seconds())
|
||||
c, err = mongo.NewClient(mongoOpts)
|
||||
if err != nil {
|
||||
log.WithError(err).Warnf(errMsg)
|
||||
return err
|
||||
}
|
||||
if err := c.Connect(ctx); err != nil {
|
||||
log.WithError(err).Warnf(errMsg)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}, bp)
|
||||
|
||||
return c, nil
|
||||
}
|
||||
79
db/mongo/client_options.go
Normal file
79
db/mongo/client_options.go
Normal file
@@ -0,0 +1,79 @@
|
||||
package mongo
|
||||
|
||||
import "context"
|
||||
|
||||
type ClientOption func(options *ClientOptions)
|
||||
|
||||
type ClientOptions struct {
|
||||
Context context.Context
|
||||
Uri string
|
||||
Host string
|
||||
Port string
|
||||
Db string
|
||||
Hosts []string
|
||||
Username string
|
||||
Password string
|
||||
AuthSource string
|
||||
AuthMechanism string
|
||||
AuthMechanismProperties map[string]string
|
||||
}
|
||||
|
||||
func WithContext(ctx context.Context) ClientOption {
|
||||
return func(options *ClientOptions) {
|
||||
options.Context = ctx
|
||||
}
|
||||
}
|
||||
|
||||
func WithUri(value string) ClientOption {
|
||||
return func(options *ClientOptions) {
|
||||
options.Uri = value
|
||||
}
|
||||
}
|
||||
|
||||
func WithHost(value string) ClientOption {
|
||||
return func(options *ClientOptions) {
|
||||
options.Host = value
|
||||
}
|
||||
}
|
||||
|
||||
func WithPort(value string) ClientOption {
|
||||
return func(options *ClientOptions) {
|
||||
options.Port = value
|
||||
}
|
||||
}
|
||||
|
||||
func WithDb(value string) ClientOption {
|
||||
return func(options *ClientOptions) {
|
||||
options.Db = value
|
||||
}
|
||||
}
|
||||
|
||||
func WithHosts(value []string) ClientOption {
|
||||
return func(options *ClientOptions) {
|
||||
options.Hosts = value
|
||||
}
|
||||
}
|
||||
|
||||
func WithUsername(value string) ClientOption {
|
||||
return func(options *ClientOptions) {
|
||||
options.Username = value
|
||||
}
|
||||
}
|
||||
|
||||
func WithPassword(value string) ClientOption {
|
||||
return func(options *ClientOptions) {
|
||||
options.Password = value
|
||||
}
|
||||
}
|
||||
|
||||
func WithAuthSource(value string) ClientOption {
|
||||
return func(options *ClientOptions) {
|
||||
options.AuthSource = value
|
||||
}
|
||||
}
|
||||
|
||||
func WithAuthMechanism(value string) ClientOption {
|
||||
return func(options *ClientOptions) {
|
||||
options.AuthMechanism = value
|
||||
}
|
||||
}
|
||||
23
db/mongo/client_test.go
Normal file
23
db/mongo/client_test.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package mongo
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/require"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func setupMongoTest() (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
func cleanupMongoTest() {
|
||||
}
|
||||
|
||||
func TestMongoInitMongo(t *testing.T) {
|
||||
err := setupMongoTest()
|
||||
require.Nil(t, err)
|
||||
|
||||
_, err = GetMongoClient()
|
||||
require.Nil(t, err)
|
||||
|
||||
cleanupMongoTest()
|
||||
}
|
||||
301
db/mongo/col.go
Normal file
301
db/mongo/col.go
Normal file
@@ -0,0 +1,301 @@
|
||||
package mongo
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/crawlab-team/crawlab/db/errors"
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"go.mongodb.org/mongo-driver/mongo/options"
|
||||
)
|
||||
|
||||
type ColInterface interface {
|
||||
Insert(doc interface{}) (id primitive.ObjectID, err error)
|
||||
InsertMany(docs []interface{}) (ids []primitive.ObjectID, err error)
|
||||
UpdateId(id primitive.ObjectID, update interface{}) (err error)
|
||||
Update(query bson.M, update interface{}) (err error)
|
||||
UpdateWithOptions(query bson.M, update interface{}, opts *options.UpdateOptions) (err error)
|
||||
ReplaceId(id primitive.ObjectID, doc interface{}) (err error)
|
||||
Replace(query bson.M, doc interface{}) (err error)
|
||||
ReplaceWithOptions(query bson.M, doc interface{}, opts *options.ReplaceOptions) (err error)
|
||||
DeleteId(id primitive.ObjectID) (err error)
|
||||
Delete(query bson.M) (err error)
|
||||
DeleteWithOptions(query bson.M, opts *options.DeleteOptions) (err error)
|
||||
Find(query bson.M, opts *FindOptions) (fr *FindResult)
|
||||
FindId(id primitive.ObjectID) (fr *FindResult)
|
||||
Count(query bson.M) (total int, err error)
|
||||
Aggregate(pipeline mongo.Pipeline, opts *options.AggregateOptions) (fr *FindResult)
|
||||
CreateIndex(indexModel mongo.IndexModel) (err error)
|
||||
CreateIndexes(indexModels []mongo.IndexModel) (err error)
|
||||
MustCreateIndex(indexModel mongo.IndexModel)
|
||||
MustCreateIndexes(indexModels []mongo.IndexModel)
|
||||
DeleteIndex(name string) (err error)
|
||||
DeleteAllIndexes() (err error)
|
||||
ListIndexes() (indexes []map[string]interface{}, err error)
|
||||
GetContext() (ctx context.Context)
|
||||
GetName() (name string)
|
||||
GetCollection() (c *mongo.Collection)
|
||||
}
|
||||
|
||||
type FindOptions struct {
|
||||
Skip int
|
||||
Limit int
|
||||
Sort bson.D
|
||||
}
|
||||
|
||||
type Col struct {
|
||||
ctx context.Context
|
||||
db *mongo.Database
|
||||
c *mongo.Collection
|
||||
}
|
||||
|
||||
func (col *Col) Insert(doc interface{}) (id primitive.ObjectID, err error) {
|
||||
res, err := col.c.InsertOne(col.ctx, doc)
|
||||
if err != nil {
|
||||
return primitive.NilObjectID, trace.TraceError(err)
|
||||
}
|
||||
if id, ok := res.InsertedID.(primitive.ObjectID); ok {
|
||||
return id, nil
|
||||
}
|
||||
return primitive.NilObjectID, trace.TraceError(errors.ErrInvalidType)
|
||||
}
|
||||
|
||||
func (col *Col) InsertMany(docs []interface{}) (ids []primitive.ObjectID, err error) {
|
||||
res, err := col.c.InsertMany(col.ctx, docs)
|
||||
if err != nil {
|
||||
return nil, trace.TraceError(err)
|
||||
}
|
||||
for _, v := range res.InsertedIDs {
|
||||
switch v.(type) {
|
||||
case primitive.ObjectID:
|
||||
id := v.(primitive.ObjectID)
|
||||
ids = append(ids, id)
|
||||
default:
|
||||
return nil, trace.TraceError(errors.ErrInvalidType)
|
||||
}
|
||||
}
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
func (col *Col) UpdateId(id primitive.ObjectID, update interface{}) (err error) {
|
||||
_, err = col.c.UpdateOne(col.ctx, bson.M{"_id": id}, update)
|
||||
if err != nil {
|
||||
return trace.TraceError(err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (col *Col) Update(query bson.M, update interface{}) (err error) {
|
||||
return col.UpdateWithOptions(query, update, nil)
|
||||
}
|
||||
|
||||
func (col *Col) UpdateWithOptions(query bson.M, update interface{}, opts *options.UpdateOptions) (err error) {
|
||||
if opts == nil {
|
||||
_, err = col.c.UpdateMany(col.ctx, query, update)
|
||||
} else {
|
||||
_, err = col.c.UpdateMany(col.ctx, query, update, opts)
|
||||
}
|
||||
if err != nil {
|
||||
return trace.TraceError(err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (col *Col) ReplaceId(id primitive.ObjectID, doc interface{}) (err error) {
|
||||
return col.Replace(bson.M{"_id": id}, doc)
|
||||
}
|
||||
|
||||
func (col *Col) Replace(query bson.M, doc interface{}) (err error) {
|
||||
return col.ReplaceWithOptions(query, doc, nil)
|
||||
}
|
||||
|
||||
func (col *Col) ReplaceWithOptions(query bson.M, doc interface{}, opts *options.ReplaceOptions) (err error) {
|
||||
if opts == nil {
|
||||
_, err = col.c.ReplaceOne(col.ctx, query, doc)
|
||||
} else {
|
||||
_, err = col.c.ReplaceOne(col.ctx, query, doc, opts)
|
||||
}
|
||||
if err != nil {
|
||||
return trace.TraceError(err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (col *Col) DeleteId(id primitive.ObjectID) (err error) {
|
||||
_, err = col.c.DeleteOne(col.ctx, bson.M{"_id": id})
|
||||
if err != nil {
|
||||
return trace.TraceError(err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (col *Col) Delete(query bson.M) (err error) {
|
||||
return col.DeleteWithOptions(query, nil)
|
||||
}
|
||||
|
||||
func (col *Col) DeleteWithOptions(query bson.M, opts *options.DeleteOptions) (err error) {
|
||||
if opts == nil {
|
||||
_, err = col.c.DeleteMany(col.ctx, query)
|
||||
} else {
|
||||
_, err = col.c.DeleteMany(col.ctx, query, opts)
|
||||
}
|
||||
if err != nil {
|
||||
return trace.TraceError(err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (col *Col) Find(query bson.M, opts *FindOptions) (fr *FindResult) {
|
||||
_opts := &options.FindOptions{}
|
||||
if opts != nil {
|
||||
if opts.Skip != 0 {
|
||||
skipInt64 := int64(opts.Skip)
|
||||
_opts.Skip = &skipInt64
|
||||
}
|
||||
if opts.Limit != 0 {
|
||||
limitInt64 := int64(opts.Limit)
|
||||
_opts.Limit = &limitInt64
|
||||
}
|
||||
if opts.Sort != nil {
|
||||
_opts.Sort = opts.Sort
|
||||
}
|
||||
}
|
||||
cur, err := col.c.Find(col.ctx, query, _opts)
|
||||
if err != nil {
|
||||
return &FindResult{
|
||||
col: col,
|
||||
err: err,
|
||||
}
|
||||
}
|
||||
fr = &FindResult{
|
||||
col: col,
|
||||
cur: cur,
|
||||
}
|
||||
return fr
|
||||
}
|
||||
|
||||
func (col *Col) FindId(id primitive.ObjectID) (fr *FindResult) {
|
||||
res := col.c.FindOne(col.ctx, bson.M{"_id": id})
|
||||
if res.Err() != nil {
|
||||
return &FindResult{
|
||||
col: col,
|
||||
err: res.Err(),
|
||||
}
|
||||
}
|
||||
fr = &FindResult{
|
||||
col: col,
|
||||
res: res,
|
||||
}
|
||||
return fr
|
||||
}
|
||||
|
||||
func (col *Col) Count(query bson.M) (total int, err error) {
|
||||
totalInt64, err := col.c.CountDocuments(col.ctx, query)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
total = int(totalInt64)
|
||||
return total, nil
|
||||
}
|
||||
|
||||
func (col *Col) Aggregate(pipeline mongo.Pipeline, opts *options.AggregateOptions) (fr *FindResult) {
|
||||
cur, err := col.c.Aggregate(col.ctx, pipeline, opts)
|
||||
if err != nil {
|
||||
return &FindResult{
|
||||
col: col,
|
||||
err: err,
|
||||
}
|
||||
}
|
||||
if cur.Err() != nil {
|
||||
return &FindResult{
|
||||
col: col,
|
||||
err: cur.Err(),
|
||||
}
|
||||
}
|
||||
fr = &FindResult{
|
||||
col: col,
|
||||
cur: cur,
|
||||
}
|
||||
return fr
|
||||
}
|
||||
|
||||
func (col *Col) CreateIndex(indexModel mongo.IndexModel) (err error) {
|
||||
_, err = col.c.Indexes().CreateOne(col.ctx, indexModel)
|
||||
if err != nil {
|
||||
return trace.TraceError(err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (col *Col) CreateIndexes(indexModels []mongo.IndexModel) (err error) {
|
||||
_, err = col.c.Indexes().CreateMany(col.ctx, indexModels)
|
||||
if err != nil {
|
||||
return trace.TraceError(err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (col *Col) MustCreateIndex(indexModel mongo.IndexModel) {
|
||||
_, _ = col.c.Indexes().CreateOne(col.ctx, indexModel)
|
||||
}
|
||||
|
||||
func (col *Col) MustCreateIndexes(indexModels []mongo.IndexModel) {
|
||||
_, _ = col.c.Indexes().CreateMany(col.ctx, indexModels)
|
||||
}
|
||||
|
||||
func (col *Col) DeleteIndex(name string) (err error) {
|
||||
_, err = col.c.Indexes().DropOne(col.ctx, name)
|
||||
if err != nil {
|
||||
return trace.TraceError(err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (col *Col) DeleteAllIndexes() (err error) {
|
||||
_, err = col.c.Indexes().DropAll(col.ctx)
|
||||
if err != nil {
|
||||
return trace.TraceError(err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (col *Col) ListIndexes() (indexes []map[string]interface{}, err error) {
|
||||
cur, err := col.c.Indexes().List(col.ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := cur.All(col.ctx, &indexes); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return indexes, nil
|
||||
}
|
||||
|
||||
func (col *Col) GetContext() (ctx context.Context) {
|
||||
return col.ctx
|
||||
}
|
||||
|
||||
func (col *Col) GetName() (name string) {
|
||||
return col.c.Name()
|
||||
}
|
||||
|
||||
func (col *Col) GetCollection() (c *mongo.Collection) {
|
||||
return col.c
|
||||
}
|
||||
|
||||
func GetMongoCol(colName string) (col *Col) {
|
||||
return GetMongoColWithDb(colName, nil)
|
||||
}
|
||||
|
||||
func GetMongoColWithDb(colName string, db *mongo.Database) (col *Col) {
|
||||
ctx := context.Background()
|
||||
if db == nil {
|
||||
db = GetMongoDb("")
|
||||
}
|
||||
c := db.Collection(colName)
|
||||
col = &Col{
|
||||
ctx: ctx,
|
||||
db: db,
|
||||
c: c,
|
||||
}
|
||||
return col
|
||||
}
|
||||
463
db/mongo/col_test.go
Normal file
463
db/mongo/col_test.go
Normal file
@@ -0,0 +1,463 @@
|
||||
package mongo
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/spf13/viper"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"strconv"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type ColTestObject struct {
|
||||
dbName string
|
||||
colName string
|
||||
col *Col
|
||||
}
|
||||
|
||||
type TestDocument struct {
|
||||
Key string `bson:"key"`
|
||||
Value int `bson:"value"`
|
||||
Tags []string `bson:"tags"`
|
||||
}
|
||||
|
||||
type TestAggregateResult struct {
|
||||
Id string `bson:"_id"`
|
||||
Count int `bson:"count"`
|
||||
Value int `bson:"value"`
|
||||
}
|
||||
|
||||
func setupColTest() (to *ColTestObject, err error) {
|
||||
dbName := "test_db"
|
||||
colName := "test_col"
|
||||
viper.Set("mongo.db", dbName)
|
||||
col := GetMongoCol(colName)
|
||||
if err := col.db.Drop(col.ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &ColTestObject{
|
||||
dbName: dbName,
|
||||
colName: colName,
|
||||
col: col,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func cleanupColTest(to *ColTestObject) {
|
||||
_ = to.col.db.Drop(to.col.ctx)
|
||||
}
|
||||
|
||||
func TestGetMongoCol(t *testing.T) {
|
||||
colName := "test_col"
|
||||
|
||||
col := GetMongoCol(colName)
|
||||
require.Equal(t, colName, col.c.Name())
|
||||
}
|
||||
|
||||
func TestGetMongoColWithDb(t *testing.T) {
|
||||
dbName := "test_db"
|
||||
colName := "test_col"
|
||||
|
||||
col := GetMongoColWithDb(colName, dbName)
|
||||
require.Equal(t, colName, col.c.Name())
|
||||
require.Equal(t, dbName, col.db.Name())
|
||||
}
|
||||
|
||||
func TestCol_Insert(t *testing.T) {
|
||||
to, err := setupColTest()
|
||||
require.Nil(t, err)
|
||||
|
||||
id, err := to.col.Insert(bson.M{"key": "value"})
|
||||
require.Nil(t, err)
|
||||
require.IsType(t, primitive.ObjectID{}, id)
|
||||
|
||||
var doc map[string]string
|
||||
err = to.col.FindId(id).One(&doc)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, doc["key"], "value")
|
||||
|
||||
cleanupColTest(to)
|
||||
}
|
||||
|
||||
func TestCol_InsertMany(t *testing.T) {
|
||||
to, err := setupColTest()
|
||||
require.Nil(t, err)
|
||||
|
||||
n := 10
|
||||
var docs []interface{}
|
||||
for i := 0; i < n; i++ {
|
||||
docs = append(docs, bson.M{"key": fmt.Sprintf("value-%d", i)})
|
||||
}
|
||||
ids, err := to.col.InsertMany(docs)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, n, len(ids))
|
||||
|
||||
var resDocs []map[string]string
|
||||
err = to.col.Find(nil, &FindOptions{Sort: bson.D{{"_id", 1}}}).All(&resDocs)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, n, len(resDocs))
|
||||
for i, doc := range resDocs {
|
||||
require.Equal(t, fmt.Sprintf("value-%d", i), doc["key"])
|
||||
}
|
||||
|
||||
cleanupColTest(to)
|
||||
}
|
||||
|
||||
func TestCol_UpdateId(t *testing.T) {
|
||||
to, err := setupColTest()
|
||||
require.Nil(t, err)
|
||||
|
||||
id, err := to.col.Insert(bson.M{"key": "old-value"})
|
||||
require.Nil(t, err)
|
||||
|
||||
err = to.col.UpdateId(id, bson.M{
|
||||
"$set": bson.M{
|
||||
"key": "new-value",
|
||||
},
|
||||
})
|
||||
require.Nil(t, err)
|
||||
|
||||
var doc map[string]string
|
||||
err = to.col.FindId(id).One(&doc)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "new-value", doc["key"])
|
||||
|
||||
cleanupColTest(to)
|
||||
}
|
||||
|
||||
func TestCol_Update(t *testing.T) {
|
||||
to, err := setupColTest()
|
||||
require.Nil(t, err)
|
||||
|
||||
n := 10
|
||||
var docs []interface{}
|
||||
for i := 0; i < n; i++ {
|
||||
docs = append(docs, bson.M{"key": fmt.Sprintf("old-value-%d", i)})
|
||||
}
|
||||
|
||||
err = to.col.Update(nil, bson.M{
|
||||
"$set": bson.M{
|
||||
"key": "new-value",
|
||||
},
|
||||
})
|
||||
require.Nil(t, err)
|
||||
|
||||
var resDocs []map[string]string
|
||||
err = to.col.Find(nil, &FindOptions{Sort: bson.D{{"_id", 1}}}).All(&resDocs)
|
||||
require.Nil(t, err)
|
||||
for _, doc := range resDocs {
|
||||
require.Equal(t, "new-value", doc["key"])
|
||||
}
|
||||
|
||||
cleanupColTest(to)
|
||||
}
|
||||
|
||||
func TestCol_ReplaceId(t *testing.T) {
|
||||
to, err := setupColTest()
|
||||
require.Nil(t, err)
|
||||
|
||||
id, err := to.col.Insert(bson.M{"key": "old-value"})
|
||||
require.Nil(t, err)
|
||||
|
||||
var doc map[string]interface{}
|
||||
err = to.col.FindId(id).One(&doc)
|
||||
require.Nil(t, err)
|
||||
doc["key"] = "new-value"
|
||||
|
||||
err = to.col.ReplaceId(id, doc)
|
||||
require.Nil(t, err)
|
||||
|
||||
err = to.col.FindId(id).One(doc)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "new-value", doc["key"])
|
||||
|
||||
cleanupColTest(to)
|
||||
}
|
||||
|
||||
func TestCol_Replace(t *testing.T) {
|
||||
to, err := setupColTest()
|
||||
require.Nil(t, err)
|
||||
|
||||
id, err := to.col.Insert(bson.M{"key": "old-value"})
|
||||
require.Nil(t, err)
|
||||
|
||||
var doc map[string]interface{}
|
||||
err = to.col.FindId(id).One(&doc)
|
||||
require.Nil(t, err)
|
||||
doc["key"] = "new-value"
|
||||
|
||||
err = to.col.Replace(bson.M{"key": "old-value"}, doc)
|
||||
require.Nil(t, err)
|
||||
|
||||
err = to.col.FindId(id).One(&doc)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "new-value", doc["key"])
|
||||
|
||||
cleanupColTest(to)
|
||||
}
|
||||
|
||||
func TestCol_DeleteId(t *testing.T) {
|
||||
to, err := setupColTest()
|
||||
require.Nil(t, err)
|
||||
|
||||
id, err := to.col.Insert(bson.M{"key": "value"})
|
||||
require.Nil(t, err)
|
||||
|
||||
err = to.col.DeleteId(id)
|
||||
require.Nil(t, err)
|
||||
|
||||
total, err := to.col.Count(nil)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 0, total)
|
||||
|
||||
cleanupColTest(to)
|
||||
}
|
||||
|
||||
func TestCol_Delete(t *testing.T) {
|
||||
to, err := setupColTest()
|
||||
require.Nil(t, err)
|
||||
|
||||
n := 10
|
||||
var docs []interface{}
|
||||
for i := 0; i < n; i++ {
|
||||
docs = append(docs, bson.M{"key": fmt.Sprintf("value-%d", i)})
|
||||
}
|
||||
ids, err := to.col.InsertMany(docs)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, n, len(ids))
|
||||
|
||||
err = to.col.Delete(bson.M{"key": "value-0"})
|
||||
require.Nil(t, err)
|
||||
|
||||
total, err := to.col.Count(nil)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, n-1, total)
|
||||
|
||||
err = to.col.Delete(nil)
|
||||
require.Nil(t, err)
|
||||
|
||||
total, err = to.col.Count(nil)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 0, total)
|
||||
|
||||
cleanupColTest(to)
|
||||
}
|
||||
|
||||
func TestCol_FindId(t *testing.T) {
|
||||
to, err := setupColTest()
|
||||
require.Nil(t, err)
|
||||
|
||||
id, err := to.col.Insert(bson.M{"key": "value"})
|
||||
require.Nil(t, err)
|
||||
|
||||
var doc map[string]string
|
||||
err = to.col.FindId(id).One(&doc)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "value", doc["key"])
|
||||
|
||||
cleanupColTest(to)
|
||||
}
|
||||
|
||||
func TestCol_Find(t *testing.T) {
|
||||
to, err := setupColTest()
|
||||
require.Nil(t, err)
|
||||
|
||||
n := 10
|
||||
var docs []interface{}
|
||||
for i := 0; i < n; i++ {
|
||||
docs = append(docs, TestDocument{
|
||||
Key: fmt.Sprintf("value-%d", i),
|
||||
Tags: []string{"test tag"},
|
||||
})
|
||||
}
|
||||
ids, err := to.col.InsertMany(docs)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, n, len(ids))
|
||||
|
||||
err = to.col.Find(nil, nil).All(&docs)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, n, len(docs))
|
||||
|
||||
err = to.col.Find(bson.M{"key": bson.M{"$gte": fmt.Sprintf("value-%d", 5)}}, nil).All(&docs)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, n-5, len(docs))
|
||||
|
||||
err = to.col.Find(nil, &FindOptions{
|
||||
Skip: 5,
|
||||
}).All(&docs)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, n-5, len(docs))
|
||||
|
||||
err = to.col.Find(nil, &FindOptions{
|
||||
Limit: 5,
|
||||
}).All(&docs)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 5, len(docs))
|
||||
|
||||
var resDocs []TestDocument
|
||||
err = to.col.Find(nil, &FindOptions{
|
||||
Sort: bson.D{{"key", 1}},
|
||||
}).All(&resDocs)
|
||||
require.Nil(t, err)
|
||||
require.Greater(t, len(resDocs), 0)
|
||||
require.Equal(t, "value-0", resDocs[0].Key)
|
||||
|
||||
err = to.col.Find(nil, &FindOptions{
|
||||
Sort: bson.D{{"key", -1}},
|
||||
}).All(&resDocs)
|
||||
require.Nil(t, err)
|
||||
require.Greater(t, len(resDocs), 0)
|
||||
require.Equal(t, fmt.Sprintf("value-%d", n-1), resDocs[0].Key)
|
||||
|
||||
var resDocs2 []TestDocument
|
||||
err = to.col.Find(bson.M{"tags": bson.M{"$in": []string{"test tag"}}}, nil).All(&resDocs2)
|
||||
require.Nil(t, err)
|
||||
require.Greater(t, len(resDocs2), 0)
|
||||
|
||||
cleanupColTest(to)
|
||||
}
|
||||
|
||||
func TestCol_CreateIndex(t *testing.T) {
|
||||
to, err := setupColTest()
|
||||
require.Nil(t, err)
|
||||
|
||||
err = to.col.CreateIndex(mongo.IndexModel{
|
||||
Keys: bson.D{{"key", 1}},
|
||||
})
|
||||
require.Nil(t, err)
|
||||
|
||||
indexes, err := to.col.ListIndexes()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 2, len(indexes))
|
||||
|
||||
cleanupColTest(to)
|
||||
}
|
||||
|
||||
func TestCol_Aggregate(t *testing.T) {
|
||||
to, err := setupColTest()
|
||||
require.Nil(t, err)
|
||||
|
||||
n := 10
|
||||
v := 2
|
||||
var docs []interface{}
|
||||
for i := 0; i < n; i++ {
|
||||
docs = append(docs, TestDocument{
|
||||
Key: fmt.Sprintf("%d", i%2),
|
||||
Value: v,
|
||||
})
|
||||
}
|
||||
ids, err := to.col.InsertMany(docs)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, n, len(ids))
|
||||
|
||||
pipeline := mongo.Pipeline{
|
||||
{
|
||||
{
|
||||
"$group",
|
||||
bson.D{
|
||||
{"_id", "$key"},
|
||||
{
|
||||
"count",
|
||||
bson.D{{"$sum", 1}},
|
||||
},
|
||||
{
|
||||
"value",
|
||||
bson.D{{"$sum", "$value"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
{
|
||||
"$sort",
|
||||
bson.D{{"_id", 1}},
|
||||
},
|
||||
},
|
||||
}
|
||||
var results []TestAggregateResult
|
||||
err = to.col.Aggregate(pipeline, nil).All(&results)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 2, len(results))
|
||||
|
||||
for i, r := range results {
|
||||
require.Equal(t, strconv.Itoa(i), r.Id)
|
||||
require.Equal(t, n/2, r.Count)
|
||||
require.Equal(t, n*v/2, r.Value)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCol_CreateIndexes(t *testing.T) {
|
||||
to, err := setupColTest()
|
||||
require.Nil(t, err)
|
||||
|
||||
err = to.col.CreateIndexes([]mongo.IndexModel{
|
||||
{
|
||||
Keys: bson.D{{"key", 1}},
|
||||
},
|
||||
{
|
||||
Keys: bson.D{{"empty-key", 1}},
|
||||
},
|
||||
})
|
||||
require.Nil(t, err)
|
||||
|
||||
indexes, err := to.col.ListIndexes()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 3, len(indexes))
|
||||
|
||||
cleanupColTest(to)
|
||||
}
|
||||
|
||||
func TestCol_DeleteIndex(t *testing.T) {
|
||||
to, err := setupColTest()
|
||||
require.Nil(t, err)
|
||||
|
||||
err = to.col.CreateIndex(mongo.IndexModel{
|
||||
Keys: bson.D{{"key", 1}},
|
||||
})
|
||||
require.Nil(t, err)
|
||||
|
||||
indexes, err := to.col.ListIndexes()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 2, len(indexes))
|
||||
for _, index := range indexes {
|
||||
name, ok := index["name"].(string)
|
||||
require.True(t, ok)
|
||||
|
||||
if name != "_id_" {
|
||||
err = to.col.DeleteIndex(name)
|
||||
require.Nil(t, err)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
indexes, err = to.col.ListIndexes()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, len(indexes))
|
||||
|
||||
cleanupColTest(to)
|
||||
}
|
||||
|
||||
func TestCol_DeleteIndexes(t *testing.T) {
|
||||
to, err := setupColTest()
|
||||
require.Nil(t, err)
|
||||
|
||||
err = to.col.CreateIndexes([]mongo.IndexModel{
|
||||
{
|
||||
Keys: bson.D{{"key", 1}},
|
||||
},
|
||||
{
|
||||
Keys: bson.D{{"empty-key", 1}},
|
||||
},
|
||||
})
|
||||
require.Nil(t, err)
|
||||
|
||||
err = to.col.DeleteAllIndexes()
|
||||
require.Nil(t, err)
|
||||
|
||||
indexes, err := to.col.ListIndexes()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, len(indexes))
|
||||
|
||||
cleanupColTest(to)
|
||||
}
|
||||
35
db/mongo/db.go
Normal file
35
db/mongo/db.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package mongo
|
||||
|
||||
import (
|
||||
"github.com/crawlab-team/go-trace"
|
||||
"github.com/spf13/viper"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
)
|
||||
|
||||
func GetMongoDb(dbName string, opts ...DbOption) (db *mongo.Database) {
|
||||
if dbName == "" {
|
||||
dbName = viper.GetString("mongo.db")
|
||||
}
|
||||
if dbName == "" {
|
||||
dbName = "test"
|
||||
}
|
||||
|
||||
_opts := &DbOptions{}
|
||||
for _, op := range opts {
|
||||
op(_opts)
|
||||
}
|
||||
|
||||
var c *mongo.Client
|
||||
if _opts.client == nil {
|
||||
var err error
|
||||
c, err = GetMongoClient()
|
||||
if err != nil {
|
||||
trace.PrintError(err)
|
||||
return nil
|
||||
}
|
||||
} else {
|
||||
c = _opts.client
|
||||
}
|
||||
|
||||
return c.Database(dbName, nil)
|
||||
}
|
||||
15
db/mongo/db_options.go
Normal file
15
db/mongo/db_options.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package mongo
|
||||
|
||||
import "go.mongodb.org/mongo-driver/mongo"
|
||||
|
||||
type DbOption func(options *DbOptions)
|
||||
|
||||
type DbOptions struct {
|
||||
client *mongo.Client
|
||||
}
|
||||
|
||||
func WithDbClient(c *mongo.Client) DbOption {
|
||||
return func(options *DbOptions) {
|
||||
options.client = c
|
||||
}
|
||||
}
|
||||
17
db/mongo/db_test.go
Normal file
17
db/mongo/db_test.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package mongo
|
||||
|
||||
import (
|
||||
"github.com/spf13/viper"
|
||||
"github.com/stretchr/testify/require"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestMongoGetDb(t *testing.T) {
|
||||
dbName := "test_db"
|
||||
viper.Set("mongo.db", dbName)
|
||||
err := InitMongo()
|
||||
require.Nil(t, err)
|
||||
|
||||
db := GetMongoDb("")
|
||||
require.Equal(t, dbName, db.Name())
|
||||
}
|
||||
82
db/mongo/result.go
Normal file
82
db/mongo/result.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package mongo
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/crawlab-team/crawlab/db/errors"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
)
|
||||
|
||||
type FindResultInterface interface {
|
||||
One(val interface{}) (err error)
|
||||
All(val interface{}) (err error)
|
||||
GetCol() (col *Col)
|
||||
GetSingleResult() (res *mongo.SingleResult)
|
||||
GetCursor() (cur *mongo.Cursor)
|
||||
GetError() (err error)
|
||||
}
|
||||
|
||||
func NewFindResult() (fr *FindResult) {
|
||||
return &FindResult{}
|
||||
}
|
||||
|
||||
func NewFindResultWithError(err error) (fr *FindResult) {
|
||||
return &FindResult{
|
||||
err: err,
|
||||
}
|
||||
}
|
||||
|
||||
type FindResult struct {
|
||||
col *Col
|
||||
res *mongo.SingleResult
|
||||
cur *mongo.Cursor
|
||||
err error
|
||||
}
|
||||
|
||||
func (fr *FindResult) GetError() (err error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (fr *FindResult) One(val interface{}) (err error) {
|
||||
if fr.err != nil {
|
||||
return fr.err
|
||||
}
|
||||
if fr.cur != nil {
|
||||
if !fr.cur.TryNext(fr.col.ctx) {
|
||||
return mongo.ErrNoDocuments
|
||||
}
|
||||
return fr.cur.Decode(val)
|
||||
}
|
||||
return fr.res.Decode(val)
|
||||
}
|
||||
|
||||
func (fr *FindResult) All(val interface{}) (err error) {
|
||||
if fr.err != nil {
|
||||
return fr.err
|
||||
}
|
||||
var ctx context.Context
|
||||
if fr.col == nil {
|
||||
ctx = context.Background()
|
||||
} else {
|
||||
ctx = fr.col.ctx
|
||||
}
|
||||
if fr.cur == nil {
|
||||
return errors.ErrNoCursor
|
||||
}
|
||||
if !fr.cur.TryNext(ctx) {
|
||||
return ctx.Err()
|
||||
}
|
||||
return fr.cur.All(ctx, val)
|
||||
}
|
||||
|
||||
func (fr *FindResult) GetCol() (col *Col) {
|
||||
return fr.col
|
||||
}
|
||||
|
||||
func (fr *FindResult) GetSingleResult() (res *mongo.SingleResult) {
|
||||
return fr.res
|
||||
}
|
||||
|
||||
func (fr *FindResult) GetCursor() (cur *mongo.Cursor) {
|
||||
return fr.cur
|
||||
}
|
||||
45
db/mongo/transaction.go
Normal file
45
db/mongo/transaction.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package mongo
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/crawlab-team/go-trace"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
)
|
||||
|
||||
func RunTransaction(fn func(mongo.SessionContext) error) (err error) {
|
||||
return RunTransactionWithContext(context.Background(), fn)
|
||||
}
|
||||
|
||||
func RunTransactionWithContext(ctx context.Context, fn func(mongo.SessionContext) error) (err error) {
|
||||
// default client
|
||||
c, err := GetMongoClient()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// start session
|
||||
s, err := c.StartSession()
|
||||
if err != nil {
|
||||
return trace.TraceError(err)
|
||||
}
|
||||
|
||||
// start transaction
|
||||
if err := s.StartTransaction(); err != nil {
|
||||
return trace.TraceError(err)
|
||||
}
|
||||
|
||||
// perform operation
|
||||
if err := mongo.WithSession(ctx, s, func(sc mongo.SessionContext) error {
|
||||
if err := fn(sc); err != nil {
|
||||
return trace.TraceError(err)
|
||||
}
|
||||
if err = s.CommitTransaction(sc); err != nil {
|
||||
return trace.TraceError(err)
|
||||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
return trace.TraceError(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
549
db/redis/client.go
Normal file
549
db/redis/client.go
Normal file
@@ -0,0 +1,549 @@
|
||||
package redis
|
||||
|
||||
import (
|
||||
"github.com/apex/log"
|
||||
"github.com/crawlab-team/crawlab/db"
|
||||
"github.com/crawlab-team/crawlab/db/errors"
|
||||
"github.com/crawlab-team/crawlab/db/utils"
|
||||
"github.com/gomodule/redigo/redis"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Client struct {
|
||||
// settings
|
||||
backoffMaxInterval time.Duration
|
||||
timeout int
|
||||
|
||||
// internals
|
||||
pool *redis.Pool
|
||||
}
|
||||
|
||||
func (client *Client) Ping() error {
|
||||
c := client.pool.Get()
|
||||
defer utils.Close(c)
|
||||
if _, err := redis.String(c.Do("PING")); err != nil {
|
||||
if err != redis.ErrNil {
|
||||
return trace.TraceError(err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (client *Client) Keys(pattern string) (values []string, err error) {
|
||||
c := client.pool.Get()
|
||||
defer utils.Close(c)
|
||||
|
||||
values, err = redis.Strings(c.Do("KEYS", pattern))
|
||||
if err != nil {
|
||||
return nil, trace.TraceError(err)
|
||||
}
|
||||
return values, nil
|
||||
}
|
||||
|
||||
func (client *Client) AllKeys() (values []string, err error) {
|
||||
return client.Keys("*")
|
||||
}
|
||||
|
||||
func (client *Client) Get(collection string) (value string, err error) {
|
||||
c := client.pool.Get()
|
||||
defer utils.Close(c)
|
||||
|
||||
value, err = redis.String(c.Do("GET", collection))
|
||||
if err != nil {
|
||||
return "", trace.TraceError(err)
|
||||
}
|
||||
return value, nil
|
||||
}
|
||||
|
||||
func (client *Client) Set(collection string, value string) (err error) {
|
||||
c := client.pool.Get()
|
||||
defer utils.Close(c)
|
||||
|
||||
value, err = redis.String(c.Do("SET", collection, value))
|
||||
if err != nil {
|
||||
return trace.TraceError(err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (client *Client) Del(collection string) error {
|
||||
c := client.pool.Get()
|
||||
defer utils.Close(c)
|
||||
|
||||
if _, err := c.Do("DEL", collection); err != nil {
|
||||
return trace.TraceError(err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (client *Client) RPush(collection string, value interface{}) error {
|
||||
c := client.pool.Get()
|
||||
defer utils.Close(c)
|
||||
|
||||
if _, err := c.Do("RPUSH", collection, value); err != nil {
|
||||
return trace.TraceError(err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (client *Client) LPush(collection string, value interface{}) error {
|
||||
c := client.pool.Get()
|
||||
defer utils.Close(c)
|
||||
|
||||
if _, err := c.Do("LPUSH", collection, value); err != nil {
|
||||
if err != redis.ErrNil {
|
||||
return trace.TraceError(err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (client *Client) LPop(collection string) (string, error) {
|
||||
c := client.pool.Get()
|
||||
defer utils.Close(c)
|
||||
|
||||
value, err := redis.String(c.Do("LPOP", collection))
|
||||
if err != nil {
|
||||
if err != redis.ErrNil {
|
||||
return value, trace.TraceError(err)
|
||||
}
|
||||
return value, err
|
||||
}
|
||||
return value, nil
|
||||
}
|
||||
|
||||
func (client *Client) RPop(collection string) (string, error) {
|
||||
c := client.pool.Get()
|
||||
defer utils.Close(c)
|
||||
|
||||
value, err := redis.String(c.Do("RPOP", collection))
|
||||
if err != nil {
|
||||
if err != redis.ErrNil {
|
||||
return value, trace.TraceError(err)
|
||||
}
|
||||
return value, err
|
||||
}
|
||||
return value, nil
|
||||
}
|
||||
|
||||
func (client *Client) LLen(collection string) (int, error) {
|
||||
c := client.pool.Get()
|
||||
defer utils.Close(c)
|
||||
|
||||
value, err := redis.Int(c.Do("LLEN", collection))
|
||||
if err != nil {
|
||||
return 0, trace.TraceError(err)
|
||||
}
|
||||
return value, nil
|
||||
}
|
||||
|
||||
func (client *Client) BRPop(collection string, timeout int) (value string, err error) {
|
||||
if timeout <= 0 {
|
||||
timeout = 60
|
||||
}
|
||||
c := client.pool.Get()
|
||||
defer utils.Close(c)
|
||||
|
||||
values, err := redis.Strings(c.Do("BRPOP", collection, timeout))
|
||||
if err != nil {
|
||||
if err != redis.ErrNil {
|
||||
return value, trace.TraceError(err)
|
||||
}
|
||||
return value, err
|
||||
}
|
||||
return values[1], nil
|
||||
}
|
||||
|
||||
func (client *Client) BLPop(collection string, timeout int) (value string, err error) {
|
||||
if timeout <= 0 {
|
||||
timeout = 60
|
||||
}
|
||||
c := client.pool.Get()
|
||||
defer utils.Close(c)
|
||||
|
||||
values, err := redis.Strings(c.Do("BLPOP", collection, timeout))
|
||||
if err != nil {
|
||||
if err != redis.ErrNil {
|
||||
return value, trace.TraceError(err)
|
||||
}
|
||||
return value, err
|
||||
}
|
||||
return values[1], nil
|
||||
}
|
||||
|
||||
func (client *Client) HSet(collection string, key string, value string) error {
|
||||
c := client.pool.Get()
|
||||
defer utils.Close(c)
|
||||
|
||||
if _, err := c.Do("HSET", collection, key, value); err != nil {
|
||||
if err != redis.ErrNil {
|
||||
return trace.TraceError(err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (client *Client) HGet(collection string, key string) (string, error) {
|
||||
c := client.pool.Get()
|
||||
defer utils.Close(c)
|
||||
value, err := redis.String(c.Do("HGET", collection, key))
|
||||
if err != nil && err != redis.ErrNil {
|
||||
if err != redis.ErrNil {
|
||||
return value, trace.TraceError(err)
|
||||
}
|
||||
return value, err
|
||||
}
|
||||
return value, nil
|
||||
}
|
||||
|
||||
func (client *Client) HDel(collection string, key string) error {
|
||||
c := client.pool.Get()
|
||||
defer utils.Close(c)
|
||||
|
||||
if _, err := c.Do("HDEL", collection, key); err != nil {
|
||||
return trace.TraceError(err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (client *Client) HScan(collection string) (results map[string]string, err error) {
|
||||
c := client.pool.Get()
|
||||
defer utils.Close(c)
|
||||
|
||||
var (
|
||||
cursor int64
|
||||
items []string
|
||||
)
|
||||
|
||||
results = map[string]string{}
|
||||
|
||||
for {
|
||||
values, err := redis.Values(c.Do("HSCAN", collection, cursor))
|
||||
if err != nil {
|
||||
if err != redis.ErrNil {
|
||||
return nil, trace.TraceError(err)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
values, err = redis.Scan(values, &cursor, &items)
|
||||
if err != nil {
|
||||
if err != redis.ErrNil {
|
||||
return nil, trace.TraceError(err)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
for i := 0; i < len(items); i += 2 {
|
||||
key := items[i]
|
||||
value := items[i+1]
|
||||
results[key] = value
|
||||
}
|
||||
if cursor == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func (client *Client) HKeys(collection string) (results []string, err error) {
|
||||
c := client.pool.Get()
|
||||
defer utils.Close(c)
|
||||
|
||||
results, err = redis.Strings(c.Do("HKEYS", collection))
|
||||
if err != nil {
|
||||
if err != redis.ErrNil {
|
||||
return results, trace.TraceError(err)
|
||||
}
|
||||
return results, err
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func (client *Client) ZAdd(collection string, score float32, value interface{}) (err error) {
|
||||
c := client.pool.Get()
|
||||
defer utils.Close(c)
|
||||
|
||||
if _, err := c.Do("ZADD", collection, score, value); err != nil {
|
||||
return trace.TraceError(err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (client *Client) ZCount(collection string, min string, max string) (count int, err error) {
|
||||
c := client.pool.Get()
|
||||
defer utils.Close(c)
|
||||
|
||||
count, err = redis.Int(c.Do("ZCOUNT", collection, min, max))
|
||||
if err != nil {
|
||||
return 0, trace.TraceError(err)
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func (client *Client) ZCountAll(collection string) (count int, err error) {
|
||||
return client.ZCount(collection, "-inf", "+inf")
|
||||
}
|
||||
|
||||
func (client *Client) ZScan(collection string, pattern string, count int) (values []string, err error) {
|
||||
c := client.pool.Get()
|
||||
defer utils.Close(c)
|
||||
|
||||
values, err = redis.Strings(c.Do("ZSCAN", collection, 0, pattern, count))
|
||||
if err != nil {
|
||||
if err != redis.ErrNil {
|
||||
return nil, trace.TraceError(err)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return values, nil
|
||||
}
|
||||
|
||||
func (client *Client) ZPopMax(collection string, count int) (results []string, err error) {
|
||||
c := client.pool.Get()
|
||||
defer utils.Close(c)
|
||||
|
||||
results = []string{}
|
||||
|
||||
values, err := redis.Strings(c.Do("ZPOPMAX", collection, count))
|
||||
if err != nil {
|
||||
if err != redis.ErrNil {
|
||||
return nil, trace.TraceError(err)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for i := 0; i < len(values); i += 2 {
|
||||
v := values[i]
|
||||
results = append(results, v)
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func (client *Client) ZPopMin(collection string, count int) (results []string, err error) {
|
||||
c := client.pool.Get()
|
||||
defer utils.Close(c)
|
||||
|
||||
results = []string{}
|
||||
|
||||
values, err := redis.Strings(c.Do("ZPOPMIN", collection, count))
|
||||
if err != nil {
|
||||
if err != redis.ErrNil {
|
||||
return nil, trace.TraceError(err)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for i := 0; i < len(values); i += 2 {
|
||||
v := values[i]
|
||||
results = append(results, v)
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func (client *Client) ZPopMaxOne(collection string) (value string, err error) {
|
||||
c := client.pool.Get()
|
||||
defer utils.Close(c)
|
||||
|
||||
values, err := client.ZPopMax(collection, 1)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if values == nil || len(values) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
return values[0], nil
|
||||
}
|
||||
|
||||
func (client *Client) ZPopMinOne(collection string) (value string, err error) {
|
||||
c := client.pool.Get()
|
||||
defer utils.Close(c)
|
||||
|
||||
values, err := client.ZPopMin(collection, 1)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if values == nil || len(values) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
return values[0], nil
|
||||
}
|
||||
|
||||
func (client *Client) BZPopMax(collection string, timeout int) (value string, err error) {
|
||||
c := client.pool.Get()
|
||||
defer utils.Close(c)
|
||||
|
||||
values, err := redis.Strings(c.Do("BZPOPMAX", collection, timeout))
|
||||
if err != nil {
|
||||
if err != redis.ErrNil {
|
||||
return "", trace.TraceError(err)
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
if len(values) < 3 {
|
||||
return "", trace.TraceError(errors.ErrorRedisInvalidType)
|
||||
}
|
||||
return values[1], nil
|
||||
}
|
||||
|
||||
func (client *Client) BZPopMin(collection string, timeout int) (value string, err error) {
|
||||
c := client.pool.Get()
|
||||
defer utils.Close(c)
|
||||
|
||||
values, err := redis.Strings(c.Do("BZPOPMIN", collection, timeout))
|
||||
if err != nil {
|
||||
if err != redis.ErrNil {
|
||||
return "", trace.TraceError(err)
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
if len(values) < 3 {
|
||||
return "", trace.TraceError(errors.ErrorRedisInvalidType)
|
||||
}
|
||||
return values[1], nil
|
||||
}
|
||||
|
||||
func (client *Client) Lock(lockKey string) (value int64, err error) {
|
||||
c := client.pool.Get()
|
||||
defer utils.Close(c)
|
||||
lockKey = client.getLockKey(lockKey)
|
||||
|
||||
ts := time.Now().Unix()
|
||||
ok, err := c.Do("SET", lockKey, ts, "NX", "PX", 30000)
|
||||
if err != nil {
|
||||
if err != redis.ErrNil {
|
||||
return value, trace.TraceError(err)
|
||||
}
|
||||
return value, err
|
||||
}
|
||||
if ok == nil {
|
||||
return 0, trace.TraceError(errors.ErrorRedisLocked)
|
||||
}
|
||||
return ts, nil
|
||||
}
|
||||
|
||||
func (client *Client) UnLock(lockKey string, value int64) {
|
||||
c := client.pool.Get()
|
||||
defer utils.Close(c)
|
||||
lockKey = client.getLockKey(lockKey)
|
||||
|
||||
getValue, err := redis.Int64(c.Do("GET", lockKey))
|
||||
if err != nil {
|
||||
log.Errorf("get lockKey error: %s", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if getValue != value {
|
||||
log.Errorf("the lockKey value diff: %d, %d", value, getValue)
|
||||
return
|
||||
}
|
||||
|
||||
v, err := redis.Int64(c.Do("DEL", lockKey))
|
||||
if err != nil {
|
||||
log.Errorf("unlock failed, error: %s", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if v == 0 {
|
||||
log.Errorf("unlock failed: key=%s", lockKey)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (client *Client) MemoryStats() (stats map[string]int64, err error) {
|
||||
stats = map[string]int64{}
|
||||
c := client.pool.Get()
|
||||
defer utils.Close(c)
|
||||
values, err := redis.Values(c.Do("MEMORY", "STATS"))
|
||||
for i, v := range values {
|
||||
t := reflect.TypeOf(v)
|
||||
if t.Kind() == reflect.Slice {
|
||||
vc, _ := redis.String(v, err)
|
||||
if utils.ContainsString(MemoryStatsMetrics, vc) {
|
||||
stats[vc], _ = redis.Int64(values[i+1], err)
|
||||
}
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
if err != redis.ErrNil {
|
||||
return stats, trace.TraceError(err)
|
||||
}
|
||||
return stats, err
|
||||
}
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
func (client *Client) SetBackoffMaxInterval(interval time.Duration) {
|
||||
client.backoffMaxInterval = interval
|
||||
}
|
||||
|
||||
func (client *Client) SetTimeout(timeout int) {
|
||||
client.timeout = timeout
|
||||
}
|
||||
|
||||
func (client *Client) init() (err error) {
|
||||
b := backoff.NewExponentialBackOff()
|
||||
b.MaxInterval = client.backoffMaxInterval
|
||||
if err := backoff.Retry(func() error {
|
||||
err := client.Ping()
|
||||
if err != nil {
|
||||
log.WithError(err).Warnf("waiting for redis pool active connection. will after %f seconds try again.", b.NextBackOff().Seconds())
|
||||
}
|
||||
return nil
|
||||
}, b); err != nil {
|
||||
return trace.TraceError(err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (client *Client) getLockKey(lockKey string) string {
|
||||
lockKey = strings.ReplaceAll(lockKey, ":", "-")
|
||||
return "nodes:lock:" + lockKey
|
||||
}
|
||||
|
||||
func (client *Client) getTimeout(timeout int) (res int) {
|
||||
if timeout == 0 {
|
||||
return client.timeout
|
||||
}
|
||||
return timeout
|
||||
}
|
||||
|
||||
var client db.RedisClient
|
||||
|
||||
func NewRedisClient(opts ...Option) (client *Client, err error) {
|
||||
// client
|
||||
client = &Client{
|
||||
backoffMaxInterval: 20 * time.Second,
|
||||
pool: NewRedisPool(),
|
||||
}
|
||||
|
||||
// apply options
|
||||
for _, opt := range opts {
|
||||
opt(client)
|
||||
}
|
||||
|
||||
// init
|
||||
if err := client.init(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func GetRedisClient() (c db.RedisClient, err error) {
|
||||
if client != nil {
|
||||
return client, nil
|
||||
}
|
||||
c, err = NewRedisClient()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
10
db/redis/constants.go
Normal file
10
db/redis/constants.go
Normal file
@@ -0,0 +1,10 @@
|
||||
package redis
|
||||
|
||||
var MemoryStatsMetrics = []string{
|
||||
"peak.allocated",
|
||||
"total.allocated",
|
||||
"startup.allocated",
|
||||
"overhead.total",
|
||||
"keys.count",
|
||||
"dataset.bytes",
|
||||
}
|
||||
20
db/redis/options.go
Normal file
20
db/redis/options.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package redis
|
||||
|
||||
import (
|
||||
"github.com/crawlab-team/crawlab/db"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Option func(c db.RedisClient)
|
||||
|
||||
func WithBackoffMaxInterval(interval time.Duration) Option {
|
||||
return func(c db.RedisClient) {
|
||||
c.SetBackoffMaxInterval(interval)
|
||||
}
|
||||
}
|
||||
|
||||
func WithTimeout(timeout int) Option {
|
||||
return func(c db.RedisClient) {
|
||||
c.SetTimeout(timeout)
|
||||
}
|
||||
}
|
||||
54
db/redis/pool.go
Normal file
54
db/redis/pool.go
Normal file
@@ -0,0 +1,54 @@
|
||||
package redis
|
||||
|
||||
import (
|
||||
"github.com/crawlab-team/go-trace"
|
||||
"github.com/gomodule/redigo/redis"
|
||||
"github.com/spf13/viper"
|
||||
"time"
|
||||
)
|
||||
|
||||
func NewRedisPool() *redis.Pool {
|
||||
var address = viper.GetString("redis.address")
|
||||
var port = viper.GetString("redis.port")
|
||||
var database = viper.GetString("redis.database")
|
||||
var password = viper.GetString("redis.password")
|
||||
|
||||
// normalize params
|
||||
if address == "" {
|
||||
address = "localhost"
|
||||
}
|
||||
if port == "" {
|
||||
port = "6379"
|
||||
}
|
||||
if database == "" {
|
||||
database = "1"
|
||||
}
|
||||
|
||||
var url string
|
||||
if password == "" {
|
||||
url = "redis://" + address + ":" + port + "/" + database
|
||||
} else {
|
||||
url = "redis://x:" + password + "@" + address + ":" + port + "/" + database
|
||||
}
|
||||
return &redis.Pool{
|
||||
Dial: func() (conn redis.Conn, e error) {
|
||||
return redis.DialURL(url,
|
||||
redis.DialConnectTimeout(time.Second*10),
|
||||
redis.DialReadTimeout(time.Second*600),
|
||||
redis.DialWriteTimeout(time.Second*10),
|
||||
)
|
||||
},
|
||||
TestOnBorrow: func(c redis.Conn, t time.Time) error {
|
||||
if time.Since(t) < time.Minute {
|
||||
return nil
|
||||
}
|
||||
_, err := c.Do("PING")
|
||||
return trace.TraceError(err)
|
||||
},
|
||||
MaxIdle: 10,
|
||||
MaxActive: 0,
|
||||
IdleTimeout: 300 * time.Second,
|
||||
Wait: false,
|
||||
MaxConnLifetime: 0,
|
||||
}
|
||||
}
|
||||
91
db/redis/test/base.go
Normal file
91
db/redis/test/base.go
Normal file
@@ -0,0 +1,91 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"github.com/crawlab-team/crawlab/db"
|
||||
"github.com/crawlab-team/crawlab/db/redis"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func init() {
|
||||
var err error
|
||||
T, err = NewTest()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
type Test struct {
|
||||
client db.RedisClient
|
||||
TestCollection string
|
||||
TestMessage string
|
||||
TestMessages []string
|
||||
TestMessagesMap map[string]string
|
||||
TestKeysAlpha []string
|
||||
TestKeysBeta []string
|
||||
TestLockKey string
|
||||
}
|
||||
|
||||
func (t *Test) Setup(t2 *testing.T) {
|
||||
t2.Cleanup(t.Cleanup)
|
||||
}
|
||||
|
||||
func (t *Test) Cleanup() {
|
||||
keys, _ := t.client.AllKeys()
|
||||
for _, key := range keys {
|
||||
_ = t.client.Del(key)
|
||||
}
|
||||
}
|
||||
|
||||
var T *Test
|
||||
|
||||
func NewTest() (t *Test, err error) {
|
||||
// test
|
||||
t = &Test{}
|
||||
|
||||
// client
|
||||
t.client, err = redis.GetRedisClient()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// test collection
|
||||
t.TestCollection = "test_collection"
|
||||
|
||||
// test message
|
||||
t.TestMessage = "this is a test message"
|
||||
|
||||
// test messages
|
||||
t.TestMessages = []string{
|
||||
"test message 1",
|
||||
"test message 2",
|
||||
"test message 3",
|
||||
}
|
||||
|
||||
// test messages map
|
||||
t.TestMessagesMap = map[string]string{
|
||||
"test key 1": "test value 1",
|
||||
"test key 2": "test value 2",
|
||||
"test key 3": "test value 3",
|
||||
}
|
||||
|
||||
// test keys alpha
|
||||
t.TestKeysAlpha = []string{
|
||||
"test key alpha 1",
|
||||
"test key alpha 2",
|
||||
"test key alpha 3",
|
||||
}
|
||||
|
||||
// test keys beta
|
||||
t.TestKeysBeta = []string{
|
||||
"test key beta 1",
|
||||
"test key beta 2",
|
||||
"test key beta 3",
|
||||
"test key beta 4",
|
||||
"test key beta 5",
|
||||
}
|
||||
|
||||
// test lock key
|
||||
t.TestLockKey = "test lock key"
|
||||
|
||||
return t, nil
|
||||
}
|
||||
273
db/redis/test/client_test.go
Normal file
273
db/redis/test/client_test.go
Normal file
@@ -0,0 +1,273 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"github.com/crawlab-team/crawlab/db/redis"
|
||||
"github.com/stretchr/testify/require"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestRedisClient_Ping(t *testing.T) {
|
||||
var err error
|
||||
T.Setup(t)
|
||||
|
||||
err = T.client.Ping()
|
||||
require.Nil(t, err)
|
||||
}
|
||||
|
||||
func TestRedisClient_Get_Set(t *testing.T) {
|
||||
var err error
|
||||
T.Setup(t)
|
||||
|
||||
err = T.client.Set(T.TestCollection, T.TestMessage)
|
||||
require.Nil(t, err)
|
||||
|
||||
value, err := T.client.Get(T.TestCollection)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, T.TestMessage, value)
|
||||
}
|
||||
|
||||
func TestRedisClient_Keys_AllKeys(t *testing.T) {
|
||||
var err error
|
||||
T.Setup(t)
|
||||
|
||||
for _, key := range T.TestKeysAlpha {
|
||||
err = T.client.Set(key, key)
|
||||
require.Nil(t, err)
|
||||
}
|
||||
for _, key := range T.TestKeysBeta {
|
||||
err = T.client.Set(key, key)
|
||||
require.Nil(t, err)
|
||||
}
|
||||
|
||||
keys, err := T.client.Keys("*alpha*")
|
||||
require.Nil(t, err)
|
||||
require.Len(t, keys, len(T.TestKeysAlpha))
|
||||
|
||||
keys, err = T.client.Keys("*beta*")
|
||||
require.Nil(t, err)
|
||||
require.Len(t, keys, len(T.TestKeysBeta))
|
||||
|
||||
keys, err = T.client.AllKeys()
|
||||
require.Nil(t, err)
|
||||
require.Len(t, keys, len(T.TestKeysAlpha)+len(T.TestKeysBeta))
|
||||
}
|
||||
|
||||
func TestRedisClient_RPush_LPop_LLen(t *testing.T) {
|
||||
var err error
|
||||
T.Setup(t)
|
||||
|
||||
for _, msg := range T.TestMessages {
|
||||
err = T.client.RPush(T.TestCollection, msg)
|
||||
require.Nil(t, err)
|
||||
}
|
||||
|
||||
n, err := T.client.LLen(T.TestCollection)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, len(T.TestMessages), n)
|
||||
|
||||
value, err := T.client.LPop(T.TestCollection)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, T.TestMessages[0], value)
|
||||
}
|
||||
|
||||
func TestRedisClient_LPush_RPop(t *testing.T) {
|
||||
var err error
|
||||
T.Setup(t)
|
||||
|
||||
for _, msg := range T.TestMessages {
|
||||
err = T.client.LPush(T.TestCollection, msg)
|
||||
require.Nil(t, err)
|
||||
}
|
||||
|
||||
n, err := T.client.LLen(T.TestCollection)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, len(T.TestMessages), n)
|
||||
|
||||
value, err := T.client.RPop(T.TestCollection)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, T.TestMessages[0], value)
|
||||
}
|
||||
|
||||
func TestRedisClient_BRPop(t *testing.T) {
|
||||
var err error
|
||||
T.Setup(t)
|
||||
|
||||
isErr := true
|
||||
go func(t *testing.T) {
|
||||
value, err := T.client.BRPop(T.TestCollection, 0)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, T.TestMessage, value)
|
||||
isErr = false
|
||||
}(t)
|
||||
|
||||
err = T.client.LPush(T.TestCollection, T.TestMessage)
|
||||
require.Nil(t, err)
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
require.False(t, isErr)
|
||||
}
|
||||
|
||||
func TestRedisClient_BLPop(t *testing.T) {
|
||||
var err error
|
||||
T.Setup(t)
|
||||
|
||||
isErr := true
|
||||
go func(t *testing.T) {
|
||||
value, err := T.client.BLPop(T.TestCollection, 0)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, T.TestMessage, value)
|
||||
isErr = false
|
||||
}(t)
|
||||
|
||||
err = T.client.RPush(T.TestCollection, T.TestMessage)
|
||||
require.Nil(t, err)
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
require.False(t, isErr)
|
||||
}
|
||||
|
||||
func TestRedisClient_HSet_HGet_HDel(t *testing.T) {
|
||||
var err error
|
||||
T.Setup(t)
|
||||
|
||||
for k, v := range T.TestMessagesMap {
|
||||
err = T.client.HSet(T.TestCollection, k, v)
|
||||
require.Nil(t, err)
|
||||
}
|
||||
|
||||
for k, v := range T.TestMessagesMap {
|
||||
vr, err := T.client.HGet(T.TestCollection, k)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, v, vr)
|
||||
}
|
||||
|
||||
for k := range T.TestMessagesMap {
|
||||
err = T.client.HDel(T.TestCollection, k)
|
||||
require.Nil(t, err)
|
||||
|
||||
v, err := T.client.HGet(T.TestCollection, k)
|
||||
require.Nil(t, err)
|
||||
require.Empty(t, v)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedisClient_HScan(t *testing.T) {
|
||||
var err error
|
||||
T.Setup(t)
|
||||
|
||||
for k, v := range T.TestMessagesMap {
|
||||
err = T.client.HSet(T.TestCollection, k, v)
|
||||
require.Nil(t, err)
|
||||
}
|
||||
|
||||
results, err := T.client.HScan(T.TestCollection)
|
||||
require.Nil(t, err)
|
||||
|
||||
for k, vr := range results {
|
||||
v, ok := T.TestMessagesMap[k]
|
||||
require.True(t, ok)
|
||||
require.Equal(t, v, vr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedisClient_HKeys(t *testing.T) {
|
||||
var err error
|
||||
T.Setup(t)
|
||||
|
||||
for k, v := range T.TestMessagesMap {
|
||||
err = T.client.HSet(T.TestCollection, k, v)
|
||||
require.Nil(t, err)
|
||||
}
|
||||
|
||||
keys, err := T.client.HKeys(T.TestCollection)
|
||||
require.Nil(t, err)
|
||||
|
||||
for _, k := range keys {
|
||||
_, ok := T.TestMessagesMap[k]
|
||||
require.True(t, ok)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedisClient_ZAdd_ZCount_ZCountAll_ZPopMax_ZPopMin(t *testing.T) {
|
||||
var err error
|
||||
T.Setup(t)
|
||||
|
||||
for i, v := range T.TestMessages {
|
||||
score := float32(i)
|
||||
err = T.client.ZAdd(T.TestCollection, score, v)
|
||||
require.Nil(t, err)
|
||||
}
|
||||
|
||||
count, err := T.client.ZCountAll(T.TestCollection)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, len(T.TestMessages), count)
|
||||
|
||||
value, err := T.client.ZPopMaxOne(T.TestCollection)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, T.TestMessages[len(T.TestMessages)-1], value)
|
||||
|
||||
value, err = T.client.ZPopMinOne(T.TestCollection)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, T.TestMessages[0], value)
|
||||
}
|
||||
|
||||
func TestRedisClient_BZPopMax_BZPopMin(t *testing.T) {
|
||||
var err error
|
||||
T.Setup(t)
|
||||
|
||||
isErr := true
|
||||
go func(t *testing.T) {
|
||||
value, err := T.client.BZPopMax(T.TestCollection, 0)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, T.TestMessage, value)
|
||||
isErr = false
|
||||
}(t)
|
||||
|
||||
err = T.client.ZAdd(T.TestCollection, 1, T.TestMessage)
|
||||
require.Nil(t, err)
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
require.False(t, isErr)
|
||||
|
||||
isErr = true
|
||||
go func(t *testing.T) {
|
||||
value, err := T.client.BZPopMin(T.TestCollection, 0)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, T.TestMessage, value)
|
||||
isErr = false
|
||||
}(t)
|
||||
|
||||
err = T.client.ZAdd(T.TestCollection, 1, T.TestMessage)
|
||||
require.Nil(t, err)
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
require.False(t, isErr)
|
||||
}
|
||||
|
||||
func TestRedisClient_Lock_Unlock(t *testing.T) {
|
||||
var err error
|
||||
T.Setup(t)
|
||||
|
||||
ts, err := T.client.Lock(T.TestLockKey)
|
||||
require.Nil(t, err)
|
||||
|
||||
_, err = T.client.Lock(T.TestLockKey)
|
||||
require.NotNil(t, err)
|
||||
|
||||
T.client.UnLock(T.TestLockKey, ts)
|
||||
|
||||
ts, err = T.client.Lock(T.TestLockKey)
|
||||
require.Nil(t, err)
|
||||
|
||||
}
|
||||
|
||||
func TestRedisClient_MemoryStats(t *testing.T) {
|
||||
var err error
|
||||
T.Setup(t)
|
||||
|
||||
stats, err := T.client.MemoryStats()
|
||||
require.Nil(t, err)
|
||||
|
||||
for _, k := range redis.MemoryStatsMetrics {
|
||||
v, ok := stats[k]
|
||||
require.True(t, ok)
|
||||
require.Greater(t, v, int64(-1))
|
||||
}
|
||||
}
|
||||
36
db/sql/sql.go
Normal file
36
db/sql/sql.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package sql
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/crawlab-team/go-trace"
|
||||
"github.com/jmoiron/sqlx"
|
||||
)
|
||||
|
||||
func GetSqlDatabaseConnectionString(dataSourceType string, host string, port string, username string, password string, database string) (connStr string, err error) {
|
||||
if dataSourceType == "mysql" {
|
||||
connStr = fmt.Sprintf("%s:%s@(%s:%s)/%s?charset=utf8&parseTime=True&loc=Local", username, password, host, port, database)
|
||||
} else if dataSourceType == "postgres" {
|
||||
connStr = fmt.Sprintf("host=%s port=%s user=%s dbname=%s password=%s sslmode=%s", host, port, username, database, password, "disable")
|
||||
} else {
|
||||
err = errors.New(dataSourceType + " is not implemented")
|
||||
return connStr, trace.TraceError(err)
|
||||
}
|
||||
return connStr, nil
|
||||
}
|
||||
|
||||
func GetSqlConn(dataSourceType string, host string, port string, username string, password string, database string) (db *sqlx.DB, err error) {
|
||||
// get database connection string
|
||||
connStr, err := GetSqlDatabaseConnectionString(dataSourceType, host, port, username, password, database)
|
||||
if err != nil {
|
||||
return db, trace.TraceError(err)
|
||||
}
|
||||
|
||||
// get database instance
|
||||
db, err = sqlx.Open(dataSourceType, connStr)
|
||||
if err != nil {
|
||||
return db, trace.TraceError(err)
|
||||
}
|
||||
|
||||
return db, nil
|
||||
}
|
||||
19
db/utils/utils.go
Normal file
19
db/utils/utils.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package utils
|
||||
|
||||
import "io"
|
||||
|
||||
func Close(c io.Closer) {
|
||||
err := c.Close()
|
||||
if err != nil {
|
||||
//log.WithError(err).Error("关闭资源文件失败。")
|
||||
}
|
||||
}
|
||||
|
||||
func ContainsString(list []string, item string) bool {
|
||||
for _, d := range list {
|
||||
if d == item {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
Reference in New Issue
Block a user