feat: added modules

This commit is contained in:
Marvin Zhang
2024-06-14 15:59:48 +08:00
parent 0b67fd9ece
commit dc21bce11f
138 changed files with 3231 additions and 120 deletions

9
db/.editorconfig Normal file
View File

@@ -0,0 +1,9 @@
root = true
[*]
charset = utf-8
end_of_line = lf
indent_size = 4
indent_style = tab
insert_final_newline = true
trim_trailing_whitespace = true

5
db/.gitignore vendored Normal file
View File

@@ -0,0 +1,5 @@
.idea
.DS_Store
tmp/
vendor/
go.sum

29
db/LICENSE Normal file
View File

@@ -0,0 +1,29 @@
BSD 3-Clause License
Copyright (c) 2020, Crawlab Team
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

2
db/README.md Normal file
View File

@@ -0,0 +1,2 @@
# crawlab-db
Backend database module for Crawlab

6
db/errors/base.go Normal file
View File

@@ -0,0 +1,6 @@
package errors
const (
errorPrefixMongo = "mongo"
errorPrefixRedis = "redis"
)

10
db/errors/errors.go Normal file
View File

@@ -0,0 +1,10 @@
package errors
import "errors"
var (
ErrInvalidType = errors.New("invalid type")
ErrMissingValue = errors.New("missing value")
ErrNoCursor = errors.New("no cursor")
ErrAlreadyLocked = errors.New("already locked")
)

15
db/errors/redis.go Normal file
View File

@@ -0,0 +1,15 @@
package errors
import (
"errors"
"fmt"
)
var (
ErrorRedisInvalidType = NewRedisError("invalid type")
ErrorRedisLocked = NewRedisError("locked")
)
func NewRedisError(msg string) (err error) {
return errors.New(fmt.Sprintf("%s: %s", errorPrefixRedis, msg))
}

8
db/generic/base.go Normal file
View File

@@ -0,0 +1,8 @@
package generic
const (
DataSourceTypeMongo = "mongo"
DataSourceTypeMysql = "mysql"
DataSourceTypePostgres = "postgres"
DataSourceTypeElasticSearch = "postgres"
)

15
db/generic/list.go Normal file
View File

@@ -0,0 +1,15 @@
package generic
type ListQueryCondition struct {
Key string
Op string
Value interface{}
}
type ListQuery []ListQueryCondition
type ListOptions struct {
Skip int
Limit int
Sort []ListSort
}

7
db/generic/op.go Normal file
View File

@@ -0,0 +1,7 @@
package generic
type Op string
const (
OpEqual = "eq"
)

13
db/generic/sort.go Normal file
View File

@@ -0,0 +1,13 @@
package generic
type SortDirection string
const (
SortDirectionAsc SortDirection = "asc"
SortDirectionDesc SortDirection = "desc"
)
type ListSort struct {
Key string
Direction SortDirection
}

47
db/go.mod Normal file
View File

@@ -0,0 +1,47 @@
module github.com/crawlab-team/crawlab/db
go 1.22
require (
github.com/apex/log v1.9.0
github.com/cenkalti/backoff/v4 v4.1.0
github.com/crawlab-team/go-trace v0.1.0
github.com/gomodule/redigo v2.0.0+incompatible
github.com/jmoiron/sqlx v1.2.0
github.com/spf13/viper v1.7.1
github.com/stretchr/testify v1.6.1
go.mongodb.org/mongo-driver v1.15.0
)
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/fsnotify/fsnotify v1.4.7 // indirect
github.com/golang/snappy v0.0.1 // indirect
github.com/hashicorp/hcl v1.0.0 // indirect
github.com/klauspost/compress v1.13.6 // indirect
github.com/lib/pq v1.1.1 // indirect
github.com/logrusorgru/aurora v0.0.0-20181002194514-a7b3b318ed4e // indirect
github.com/magiconair/properties v1.8.1 // indirect
github.com/mitchellh/mapstructure v1.1.2 // indirect
github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe // indirect
github.com/pelletier/go-toml v1.7.0 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/spf13/afero v1.1.2 // indirect
github.com/spf13/cast v1.3.0 // indirect
github.com/spf13/jwalterweatherman v1.0.0 // indirect
github.com/spf13/pflag v1.0.3 // indirect
github.com/subosito/gotenv v1.2.0 // indirect
github.com/xdg-go/pbkdf2 v1.0.0 // indirect
github.com/xdg-go/scram v1.1.2 // indirect
github.com/xdg-go/stringprep v1.0.4 // indirect
github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d // indirect
github.com/ztrue/tracerr v0.3.0 // indirect
golang.org/x/crypto v0.21.0 // indirect
golang.org/x/sync v0.1.0 // indirect
golang.org/x/sys v0.18.0 // indirect
golang.org/x/text v0.14.0 // indirect
gopkg.in/ini.v1 v1.51.0 // indirect
gopkg.in/yaml.v2 v2.2.8 // indirect
gopkg.in/yaml.v3 v3.0.0 // indirect
)

39
db/interfaces.go Normal file
View File

@@ -0,0 +1,39 @@
package db
import "time"
type RedisClient interface {
Ping() (err error)
Keys(pattern string) (values []string, err error)
AllKeys() (values []string, err error)
Get(collection string) (value string, err error)
Set(collection string, value string) (err error)
Del(collection string) (err error)
RPush(collection string, value interface{}) (err error)
LPush(collection string, value interface{}) (err error)
LPop(collection string) (value string, err error)
RPop(collection string) (value string, err error)
LLen(collection string) (count int, err error)
BRPop(collection string, timeout int) (value string, err error)
BLPop(collection string, timeout int) (value string, err error)
HSet(collection string, key string, value string) (err error)
HGet(collection string, key string) (value string, err error)
HDel(collection string, key string) (err error)
HScan(collection string) (results map[string]string, err error)
HKeys(collection string) (results []string, err error)
ZAdd(collection string, score float32, value interface{}) (err error)
ZCount(collection string, min string, max string) (count int, err error)
ZCountAll(collection string) (count int, err error)
ZScan(collection string, pattern string, count int) (results []string, err error)
ZPopMax(collection string, count int) (results []string, err error)
ZPopMin(collection string, count int) (results []string, err error)
ZPopMaxOne(collection string) (value string, err error)
ZPopMinOne(collection string) (value string, err error)
BZPopMax(collection string, timeout int) (value string, err error)
BZPopMin(collection string, timeout int) (value string, err error)
Lock(lockKey string) (value int64, err error)
UnLock(lockKey string, value int64)
MemoryStats() (stats map[string]int64, err error)
SetBackoffMaxInterval(interval time.Duration)
SetTimeout(timeout int)
}

147
db/mongo/client.go Normal file
View File

@@ -0,0 +1,147 @@
package mongo
import (
"context"
"encoding/json"
"fmt"
"github.com/apex/log"
"github.com/cenkalti/backoff/v4"
"github.com/crawlab-team/go-trace"
"github.com/spf13/viper"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
"sync"
)
var AppName = "crawlab-db"
var _clientMap = map[string]*mongo.Client{}
var _mu sync.Mutex
func GetMongoClient(opts ...ClientOption) (c *mongo.Client, err error) {
// client options
_opts := &ClientOptions{}
for _, op := range opts {
op(_opts)
}
if _opts.Uri == "" {
_opts.Uri = viper.GetString("mongo.uri")
}
if _opts.Host == "" {
_opts.Host = viper.GetString("mongo.host")
if _opts.Host == "" {
_opts.Host = "localhost"
}
}
if _opts.Port == "" {
_opts.Port = viper.GetString("mongo.port")
if _opts.Port == "" {
_opts.Port = "27017"
}
}
if _opts.Db == "" {
_opts.Db = viper.GetString("mongo.db")
if _opts.Db == "" {
_opts.Db = "admin"
}
}
if len(_opts.Hosts) == 0 {
_opts.Hosts = viper.GetStringSlice("mongo.hosts")
}
if _opts.Username == "" {
_opts.Username = viper.GetString("mongo.username")
}
if _opts.Password == "" {
_opts.Password = viper.GetString("mongo.password")
}
if _opts.AuthSource == "" {
_opts.AuthSource = viper.GetString("mongo.authSource")
if _opts.AuthSource == "" {
_opts.AuthSource = "admin"
}
}
if _opts.AuthMechanism == "" {
_opts.AuthMechanism = viper.GetString("mongo.authMechanism")
}
if _opts.AuthMechanismProperties == nil {
_opts.AuthMechanismProperties = viper.GetStringMapString("mongo.authMechanismProperties")
}
// client options key json string
_optsKeyBytes, err := json.Marshal(_opts)
if err != nil {
return nil, trace.TraceError(err)
}
_optsKey := string(_optsKeyBytes)
// attempt to get client by client options
c, ok := _clientMap[_optsKey]
if ok {
return c, nil
}
// create new mongo client
c, err = newMongoClient(_opts.Context, _opts)
if err != nil {
return nil, err
}
// add to map
_mu.Lock()
_clientMap[_optsKey] = c
_mu.Unlock()
return c, nil
}
func newMongoClient(ctx context.Context, _opts *ClientOptions) (c *mongo.Client, err error) {
// mongo client options
mongoOpts := &options.ClientOptions{
AppName: &AppName,
}
if _opts.Uri != "" {
// uri is set
mongoOpts.ApplyURI(_opts.Uri)
} else {
// uri is unset
// username and password are set
if _opts.Username != "" && _opts.Password != "" {
mongoOpts.SetAuth(options.Credential{
AuthMechanism: _opts.AuthMechanism,
AuthMechanismProperties: _opts.AuthMechanismProperties,
AuthSource: _opts.AuthSource,
Username: _opts.Username,
Password: _opts.Password,
PasswordSet: true,
})
}
if len(_opts.Hosts) > 0 {
// hosts are set
mongoOpts.SetHosts(_opts.Hosts)
} else {
// hosts are unset
mongoOpts.ApplyURI(fmt.Sprintf("mongodb://%s:%s/%s", _opts.Host, _opts.Port, _opts.Db))
}
}
// attempt to connect with retry
bp := backoff.NewExponentialBackOff()
err = backoff.Retry(func() error {
errMsg := fmt.Sprintf("waiting for connect mongo database, after %f seconds try again.", bp.NextBackOff().Seconds())
c, err = mongo.NewClient(mongoOpts)
if err != nil {
log.WithError(err).Warnf(errMsg)
return err
}
if err := c.Connect(ctx); err != nil {
log.WithError(err).Warnf(errMsg)
return err
}
return nil
}, bp)
return c, nil
}

View File

@@ -0,0 +1,79 @@
package mongo
import "context"
type ClientOption func(options *ClientOptions)
type ClientOptions struct {
Context context.Context
Uri string
Host string
Port string
Db string
Hosts []string
Username string
Password string
AuthSource string
AuthMechanism string
AuthMechanismProperties map[string]string
}
func WithContext(ctx context.Context) ClientOption {
return func(options *ClientOptions) {
options.Context = ctx
}
}
func WithUri(value string) ClientOption {
return func(options *ClientOptions) {
options.Uri = value
}
}
func WithHost(value string) ClientOption {
return func(options *ClientOptions) {
options.Host = value
}
}
func WithPort(value string) ClientOption {
return func(options *ClientOptions) {
options.Port = value
}
}
func WithDb(value string) ClientOption {
return func(options *ClientOptions) {
options.Db = value
}
}
func WithHosts(value []string) ClientOption {
return func(options *ClientOptions) {
options.Hosts = value
}
}
func WithUsername(value string) ClientOption {
return func(options *ClientOptions) {
options.Username = value
}
}
func WithPassword(value string) ClientOption {
return func(options *ClientOptions) {
options.Password = value
}
}
func WithAuthSource(value string) ClientOption {
return func(options *ClientOptions) {
options.AuthSource = value
}
}
func WithAuthMechanism(value string) ClientOption {
return func(options *ClientOptions) {
options.AuthMechanism = value
}
}

23
db/mongo/client_test.go Normal file
View File

@@ -0,0 +1,23 @@
package mongo
import (
"github.com/stretchr/testify/require"
"testing"
)
func setupMongoTest() (err error) {
return nil
}
func cleanupMongoTest() {
}
func TestMongoInitMongo(t *testing.T) {
err := setupMongoTest()
require.Nil(t, err)
_, err = GetMongoClient()
require.Nil(t, err)
cleanupMongoTest()
}

301
db/mongo/col.go Normal file
View File

@@ -0,0 +1,301 @@
package mongo
import (
"context"
"github.com/crawlab-team/crawlab/db/errors"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
)
type ColInterface interface {
Insert(doc interface{}) (id primitive.ObjectID, err error)
InsertMany(docs []interface{}) (ids []primitive.ObjectID, err error)
UpdateId(id primitive.ObjectID, update interface{}) (err error)
Update(query bson.M, update interface{}) (err error)
UpdateWithOptions(query bson.M, update interface{}, opts *options.UpdateOptions) (err error)
ReplaceId(id primitive.ObjectID, doc interface{}) (err error)
Replace(query bson.M, doc interface{}) (err error)
ReplaceWithOptions(query bson.M, doc interface{}, opts *options.ReplaceOptions) (err error)
DeleteId(id primitive.ObjectID) (err error)
Delete(query bson.M) (err error)
DeleteWithOptions(query bson.M, opts *options.DeleteOptions) (err error)
Find(query bson.M, opts *FindOptions) (fr *FindResult)
FindId(id primitive.ObjectID) (fr *FindResult)
Count(query bson.M) (total int, err error)
Aggregate(pipeline mongo.Pipeline, opts *options.AggregateOptions) (fr *FindResult)
CreateIndex(indexModel mongo.IndexModel) (err error)
CreateIndexes(indexModels []mongo.IndexModel) (err error)
MustCreateIndex(indexModel mongo.IndexModel)
MustCreateIndexes(indexModels []mongo.IndexModel)
DeleteIndex(name string) (err error)
DeleteAllIndexes() (err error)
ListIndexes() (indexes []map[string]interface{}, err error)
GetContext() (ctx context.Context)
GetName() (name string)
GetCollection() (c *mongo.Collection)
}
type FindOptions struct {
Skip int
Limit int
Sort bson.D
}
type Col struct {
ctx context.Context
db *mongo.Database
c *mongo.Collection
}
func (col *Col) Insert(doc interface{}) (id primitive.ObjectID, err error) {
res, err := col.c.InsertOne(col.ctx, doc)
if err != nil {
return primitive.NilObjectID, trace.TraceError(err)
}
if id, ok := res.InsertedID.(primitive.ObjectID); ok {
return id, nil
}
return primitive.NilObjectID, trace.TraceError(errors.ErrInvalidType)
}
func (col *Col) InsertMany(docs []interface{}) (ids []primitive.ObjectID, err error) {
res, err := col.c.InsertMany(col.ctx, docs)
if err != nil {
return nil, trace.TraceError(err)
}
for _, v := range res.InsertedIDs {
switch v.(type) {
case primitive.ObjectID:
id := v.(primitive.ObjectID)
ids = append(ids, id)
default:
return nil, trace.TraceError(errors.ErrInvalidType)
}
}
return ids, nil
}
func (col *Col) UpdateId(id primitive.ObjectID, update interface{}) (err error) {
_, err = col.c.UpdateOne(col.ctx, bson.M{"_id": id}, update)
if err != nil {
return trace.TraceError(err)
}
return nil
}
func (col *Col) Update(query bson.M, update interface{}) (err error) {
return col.UpdateWithOptions(query, update, nil)
}
func (col *Col) UpdateWithOptions(query bson.M, update interface{}, opts *options.UpdateOptions) (err error) {
if opts == nil {
_, err = col.c.UpdateMany(col.ctx, query, update)
} else {
_, err = col.c.UpdateMany(col.ctx, query, update, opts)
}
if err != nil {
return trace.TraceError(err)
}
return nil
}
func (col *Col) ReplaceId(id primitive.ObjectID, doc interface{}) (err error) {
return col.Replace(bson.M{"_id": id}, doc)
}
func (col *Col) Replace(query bson.M, doc interface{}) (err error) {
return col.ReplaceWithOptions(query, doc, nil)
}
func (col *Col) ReplaceWithOptions(query bson.M, doc interface{}, opts *options.ReplaceOptions) (err error) {
if opts == nil {
_, err = col.c.ReplaceOne(col.ctx, query, doc)
} else {
_, err = col.c.ReplaceOne(col.ctx, query, doc, opts)
}
if err != nil {
return trace.TraceError(err)
}
return nil
}
func (col *Col) DeleteId(id primitive.ObjectID) (err error) {
_, err = col.c.DeleteOne(col.ctx, bson.M{"_id": id})
if err != nil {
return trace.TraceError(err)
}
return nil
}
func (col *Col) Delete(query bson.M) (err error) {
return col.DeleteWithOptions(query, nil)
}
func (col *Col) DeleteWithOptions(query bson.M, opts *options.DeleteOptions) (err error) {
if opts == nil {
_, err = col.c.DeleteMany(col.ctx, query)
} else {
_, err = col.c.DeleteMany(col.ctx, query, opts)
}
if err != nil {
return trace.TraceError(err)
}
return nil
}
func (col *Col) Find(query bson.M, opts *FindOptions) (fr *FindResult) {
_opts := &options.FindOptions{}
if opts != nil {
if opts.Skip != 0 {
skipInt64 := int64(opts.Skip)
_opts.Skip = &skipInt64
}
if opts.Limit != 0 {
limitInt64 := int64(opts.Limit)
_opts.Limit = &limitInt64
}
if opts.Sort != nil {
_opts.Sort = opts.Sort
}
}
cur, err := col.c.Find(col.ctx, query, _opts)
if err != nil {
return &FindResult{
col: col,
err: err,
}
}
fr = &FindResult{
col: col,
cur: cur,
}
return fr
}
func (col *Col) FindId(id primitive.ObjectID) (fr *FindResult) {
res := col.c.FindOne(col.ctx, bson.M{"_id": id})
if res.Err() != nil {
return &FindResult{
col: col,
err: res.Err(),
}
}
fr = &FindResult{
col: col,
res: res,
}
return fr
}
func (col *Col) Count(query bson.M) (total int, err error) {
totalInt64, err := col.c.CountDocuments(col.ctx, query)
if err != nil {
return 0, err
}
total = int(totalInt64)
return total, nil
}
func (col *Col) Aggregate(pipeline mongo.Pipeline, opts *options.AggregateOptions) (fr *FindResult) {
cur, err := col.c.Aggregate(col.ctx, pipeline, opts)
if err != nil {
return &FindResult{
col: col,
err: err,
}
}
if cur.Err() != nil {
return &FindResult{
col: col,
err: cur.Err(),
}
}
fr = &FindResult{
col: col,
cur: cur,
}
return fr
}
func (col *Col) CreateIndex(indexModel mongo.IndexModel) (err error) {
_, err = col.c.Indexes().CreateOne(col.ctx, indexModel)
if err != nil {
return trace.TraceError(err)
}
return nil
}
func (col *Col) CreateIndexes(indexModels []mongo.IndexModel) (err error) {
_, err = col.c.Indexes().CreateMany(col.ctx, indexModels)
if err != nil {
return trace.TraceError(err)
}
return nil
}
func (col *Col) MustCreateIndex(indexModel mongo.IndexModel) {
_, _ = col.c.Indexes().CreateOne(col.ctx, indexModel)
}
func (col *Col) MustCreateIndexes(indexModels []mongo.IndexModel) {
_, _ = col.c.Indexes().CreateMany(col.ctx, indexModels)
}
func (col *Col) DeleteIndex(name string) (err error) {
_, err = col.c.Indexes().DropOne(col.ctx, name)
if err != nil {
return trace.TraceError(err)
}
return nil
}
func (col *Col) DeleteAllIndexes() (err error) {
_, err = col.c.Indexes().DropAll(col.ctx)
if err != nil {
return trace.TraceError(err)
}
return nil
}
func (col *Col) ListIndexes() (indexes []map[string]interface{}, err error) {
cur, err := col.c.Indexes().List(col.ctx)
if err != nil {
return nil, err
}
if err := cur.All(col.ctx, &indexes); err != nil {
return nil, err
}
return indexes, nil
}
func (col *Col) GetContext() (ctx context.Context) {
return col.ctx
}
func (col *Col) GetName() (name string) {
return col.c.Name()
}
func (col *Col) GetCollection() (c *mongo.Collection) {
return col.c
}
func GetMongoCol(colName string) (col *Col) {
return GetMongoColWithDb(colName, nil)
}
func GetMongoColWithDb(colName string, db *mongo.Database) (col *Col) {
ctx := context.Background()
if db == nil {
db = GetMongoDb("")
}
c := db.Collection(colName)
col = &Col{
ctx: ctx,
db: db,
c: c,
}
return col
}

463
db/mongo/col_test.go Normal file
View File

@@ -0,0 +1,463 @@
package mongo
import (
"fmt"
"github.com/spf13/viper"
"github.com/stretchr/testify/require"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/mongo"
"strconv"
"testing"
)
type ColTestObject struct {
dbName string
colName string
col *Col
}
type TestDocument struct {
Key string `bson:"key"`
Value int `bson:"value"`
Tags []string `bson:"tags"`
}
type TestAggregateResult struct {
Id string `bson:"_id"`
Count int `bson:"count"`
Value int `bson:"value"`
}
func setupColTest() (to *ColTestObject, err error) {
dbName := "test_db"
colName := "test_col"
viper.Set("mongo.db", dbName)
col := GetMongoCol(colName)
if err := col.db.Drop(col.ctx); err != nil {
return nil, err
}
return &ColTestObject{
dbName: dbName,
colName: colName,
col: col,
}, nil
}
func cleanupColTest(to *ColTestObject) {
_ = to.col.db.Drop(to.col.ctx)
}
func TestGetMongoCol(t *testing.T) {
colName := "test_col"
col := GetMongoCol(colName)
require.Equal(t, colName, col.c.Name())
}
func TestGetMongoColWithDb(t *testing.T) {
dbName := "test_db"
colName := "test_col"
col := GetMongoColWithDb(colName, dbName)
require.Equal(t, colName, col.c.Name())
require.Equal(t, dbName, col.db.Name())
}
func TestCol_Insert(t *testing.T) {
to, err := setupColTest()
require.Nil(t, err)
id, err := to.col.Insert(bson.M{"key": "value"})
require.Nil(t, err)
require.IsType(t, primitive.ObjectID{}, id)
var doc map[string]string
err = to.col.FindId(id).One(&doc)
require.Nil(t, err)
require.Equal(t, doc["key"], "value")
cleanupColTest(to)
}
func TestCol_InsertMany(t *testing.T) {
to, err := setupColTest()
require.Nil(t, err)
n := 10
var docs []interface{}
for i := 0; i < n; i++ {
docs = append(docs, bson.M{"key": fmt.Sprintf("value-%d", i)})
}
ids, err := to.col.InsertMany(docs)
require.Nil(t, err)
require.Equal(t, n, len(ids))
var resDocs []map[string]string
err = to.col.Find(nil, &FindOptions{Sort: bson.D{{"_id", 1}}}).All(&resDocs)
require.Nil(t, err)
require.Equal(t, n, len(resDocs))
for i, doc := range resDocs {
require.Equal(t, fmt.Sprintf("value-%d", i), doc["key"])
}
cleanupColTest(to)
}
func TestCol_UpdateId(t *testing.T) {
to, err := setupColTest()
require.Nil(t, err)
id, err := to.col.Insert(bson.M{"key": "old-value"})
require.Nil(t, err)
err = to.col.UpdateId(id, bson.M{
"$set": bson.M{
"key": "new-value",
},
})
require.Nil(t, err)
var doc map[string]string
err = to.col.FindId(id).One(&doc)
require.Nil(t, err)
require.Equal(t, "new-value", doc["key"])
cleanupColTest(to)
}
func TestCol_Update(t *testing.T) {
to, err := setupColTest()
require.Nil(t, err)
n := 10
var docs []interface{}
for i := 0; i < n; i++ {
docs = append(docs, bson.M{"key": fmt.Sprintf("old-value-%d", i)})
}
err = to.col.Update(nil, bson.M{
"$set": bson.M{
"key": "new-value",
},
})
require.Nil(t, err)
var resDocs []map[string]string
err = to.col.Find(nil, &FindOptions{Sort: bson.D{{"_id", 1}}}).All(&resDocs)
require.Nil(t, err)
for _, doc := range resDocs {
require.Equal(t, "new-value", doc["key"])
}
cleanupColTest(to)
}
func TestCol_ReplaceId(t *testing.T) {
to, err := setupColTest()
require.Nil(t, err)
id, err := to.col.Insert(bson.M{"key": "old-value"})
require.Nil(t, err)
var doc map[string]interface{}
err = to.col.FindId(id).One(&doc)
require.Nil(t, err)
doc["key"] = "new-value"
err = to.col.ReplaceId(id, doc)
require.Nil(t, err)
err = to.col.FindId(id).One(doc)
require.Nil(t, err)
require.Equal(t, "new-value", doc["key"])
cleanupColTest(to)
}
func TestCol_Replace(t *testing.T) {
to, err := setupColTest()
require.Nil(t, err)
id, err := to.col.Insert(bson.M{"key": "old-value"})
require.Nil(t, err)
var doc map[string]interface{}
err = to.col.FindId(id).One(&doc)
require.Nil(t, err)
doc["key"] = "new-value"
err = to.col.Replace(bson.M{"key": "old-value"}, doc)
require.Nil(t, err)
err = to.col.FindId(id).One(&doc)
require.Nil(t, err)
require.Equal(t, "new-value", doc["key"])
cleanupColTest(to)
}
func TestCol_DeleteId(t *testing.T) {
to, err := setupColTest()
require.Nil(t, err)
id, err := to.col.Insert(bson.M{"key": "value"})
require.Nil(t, err)
err = to.col.DeleteId(id)
require.Nil(t, err)
total, err := to.col.Count(nil)
require.Nil(t, err)
require.Equal(t, 0, total)
cleanupColTest(to)
}
func TestCol_Delete(t *testing.T) {
to, err := setupColTest()
require.Nil(t, err)
n := 10
var docs []interface{}
for i := 0; i < n; i++ {
docs = append(docs, bson.M{"key": fmt.Sprintf("value-%d", i)})
}
ids, err := to.col.InsertMany(docs)
require.Nil(t, err)
require.Equal(t, n, len(ids))
err = to.col.Delete(bson.M{"key": "value-0"})
require.Nil(t, err)
total, err := to.col.Count(nil)
require.Nil(t, err)
require.Equal(t, n-1, total)
err = to.col.Delete(nil)
require.Nil(t, err)
total, err = to.col.Count(nil)
require.Nil(t, err)
require.Equal(t, 0, total)
cleanupColTest(to)
}
func TestCol_FindId(t *testing.T) {
to, err := setupColTest()
require.Nil(t, err)
id, err := to.col.Insert(bson.M{"key": "value"})
require.Nil(t, err)
var doc map[string]string
err = to.col.FindId(id).One(&doc)
require.Nil(t, err)
require.Equal(t, "value", doc["key"])
cleanupColTest(to)
}
func TestCol_Find(t *testing.T) {
to, err := setupColTest()
require.Nil(t, err)
n := 10
var docs []interface{}
for i := 0; i < n; i++ {
docs = append(docs, TestDocument{
Key: fmt.Sprintf("value-%d", i),
Tags: []string{"test tag"},
})
}
ids, err := to.col.InsertMany(docs)
require.Nil(t, err)
require.Equal(t, n, len(ids))
err = to.col.Find(nil, nil).All(&docs)
require.Nil(t, err)
require.Equal(t, n, len(docs))
err = to.col.Find(bson.M{"key": bson.M{"$gte": fmt.Sprintf("value-%d", 5)}}, nil).All(&docs)
require.Nil(t, err)
require.Equal(t, n-5, len(docs))
err = to.col.Find(nil, &FindOptions{
Skip: 5,
}).All(&docs)
require.Nil(t, err)
require.Equal(t, n-5, len(docs))
err = to.col.Find(nil, &FindOptions{
Limit: 5,
}).All(&docs)
require.Nil(t, err)
require.Equal(t, 5, len(docs))
var resDocs []TestDocument
err = to.col.Find(nil, &FindOptions{
Sort: bson.D{{"key", 1}},
}).All(&resDocs)
require.Nil(t, err)
require.Greater(t, len(resDocs), 0)
require.Equal(t, "value-0", resDocs[0].Key)
err = to.col.Find(nil, &FindOptions{
Sort: bson.D{{"key", -1}},
}).All(&resDocs)
require.Nil(t, err)
require.Greater(t, len(resDocs), 0)
require.Equal(t, fmt.Sprintf("value-%d", n-1), resDocs[0].Key)
var resDocs2 []TestDocument
err = to.col.Find(bson.M{"tags": bson.M{"$in": []string{"test tag"}}}, nil).All(&resDocs2)
require.Nil(t, err)
require.Greater(t, len(resDocs2), 0)
cleanupColTest(to)
}
func TestCol_CreateIndex(t *testing.T) {
to, err := setupColTest()
require.Nil(t, err)
err = to.col.CreateIndex(mongo.IndexModel{
Keys: bson.D{{"key", 1}},
})
require.Nil(t, err)
indexes, err := to.col.ListIndexes()
require.Nil(t, err)
require.Equal(t, 2, len(indexes))
cleanupColTest(to)
}
func TestCol_Aggregate(t *testing.T) {
to, err := setupColTest()
require.Nil(t, err)
n := 10
v := 2
var docs []interface{}
for i := 0; i < n; i++ {
docs = append(docs, TestDocument{
Key: fmt.Sprintf("%d", i%2),
Value: v,
})
}
ids, err := to.col.InsertMany(docs)
require.Nil(t, err)
require.Equal(t, n, len(ids))
pipeline := mongo.Pipeline{
{
{
"$group",
bson.D{
{"_id", "$key"},
{
"count",
bson.D{{"$sum", 1}},
},
{
"value",
bson.D{{"$sum", "$value"}},
},
},
},
},
{
{
"$sort",
bson.D{{"_id", 1}},
},
},
}
var results []TestAggregateResult
err = to.col.Aggregate(pipeline, nil).All(&results)
require.Nil(t, err)
require.Equal(t, 2, len(results))
for i, r := range results {
require.Equal(t, strconv.Itoa(i), r.Id)
require.Equal(t, n/2, r.Count)
require.Equal(t, n*v/2, r.Value)
}
}
func TestCol_CreateIndexes(t *testing.T) {
to, err := setupColTest()
require.Nil(t, err)
err = to.col.CreateIndexes([]mongo.IndexModel{
{
Keys: bson.D{{"key", 1}},
},
{
Keys: bson.D{{"empty-key", 1}},
},
})
require.Nil(t, err)
indexes, err := to.col.ListIndexes()
require.Nil(t, err)
require.Equal(t, 3, len(indexes))
cleanupColTest(to)
}
func TestCol_DeleteIndex(t *testing.T) {
to, err := setupColTest()
require.Nil(t, err)
err = to.col.CreateIndex(mongo.IndexModel{
Keys: bson.D{{"key", 1}},
})
require.Nil(t, err)
indexes, err := to.col.ListIndexes()
require.Nil(t, err)
require.Equal(t, 2, len(indexes))
for _, index := range indexes {
name, ok := index["name"].(string)
require.True(t, ok)
if name != "_id_" {
err = to.col.DeleteIndex(name)
require.Nil(t, err)
break
}
}
indexes, err = to.col.ListIndexes()
require.Nil(t, err)
require.Equal(t, 1, len(indexes))
cleanupColTest(to)
}
func TestCol_DeleteIndexes(t *testing.T) {
to, err := setupColTest()
require.Nil(t, err)
err = to.col.CreateIndexes([]mongo.IndexModel{
{
Keys: bson.D{{"key", 1}},
},
{
Keys: bson.D{{"empty-key", 1}},
},
})
require.Nil(t, err)
err = to.col.DeleteAllIndexes()
require.Nil(t, err)
indexes, err := to.col.ListIndexes()
require.Nil(t, err)
require.Equal(t, 1, len(indexes))
cleanupColTest(to)
}

35
db/mongo/db.go Normal file
View File

@@ -0,0 +1,35 @@
package mongo
import (
"github.com/crawlab-team/go-trace"
"github.com/spf13/viper"
"go.mongodb.org/mongo-driver/mongo"
)
func GetMongoDb(dbName string, opts ...DbOption) (db *mongo.Database) {
if dbName == "" {
dbName = viper.GetString("mongo.db")
}
if dbName == "" {
dbName = "test"
}
_opts := &DbOptions{}
for _, op := range opts {
op(_opts)
}
var c *mongo.Client
if _opts.client == nil {
var err error
c, err = GetMongoClient()
if err != nil {
trace.PrintError(err)
return nil
}
} else {
c = _opts.client
}
return c.Database(dbName, nil)
}

15
db/mongo/db_options.go Normal file
View File

@@ -0,0 +1,15 @@
package mongo
import "go.mongodb.org/mongo-driver/mongo"
type DbOption func(options *DbOptions)
type DbOptions struct {
client *mongo.Client
}
func WithDbClient(c *mongo.Client) DbOption {
return func(options *DbOptions) {
options.client = c
}
}

17
db/mongo/db_test.go Normal file
View File

@@ -0,0 +1,17 @@
package mongo
import (
"github.com/spf13/viper"
"github.com/stretchr/testify/require"
"testing"
)
func TestMongoGetDb(t *testing.T) {
dbName := "test_db"
viper.Set("mongo.db", dbName)
err := InitMongo()
require.Nil(t, err)
db := GetMongoDb("")
require.Equal(t, dbName, db.Name())
}

82
db/mongo/result.go Normal file
View File

@@ -0,0 +1,82 @@
package mongo
import (
"context"
"github.com/crawlab-team/crawlab/db/errors"
"go.mongodb.org/mongo-driver/mongo"
)
type FindResultInterface interface {
One(val interface{}) (err error)
All(val interface{}) (err error)
GetCol() (col *Col)
GetSingleResult() (res *mongo.SingleResult)
GetCursor() (cur *mongo.Cursor)
GetError() (err error)
}
func NewFindResult() (fr *FindResult) {
return &FindResult{}
}
func NewFindResultWithError(err error) (fr *FindResult) {
return &FindResult{
err: err,
}
}
type FindResult struct {
col *Col
res *mongo.SingleResult
cur *mongo.Cursor
err error
}
func (fr *FindResult) GetError() (err error) {
//TODO implement me
panic("implement me")
}
func (fr *FindResult) One(val interface{}) (err error) {
if fr.err != nil {
return fr.err
}
if fr.cur != nil {
if !fr.cur.TryNext(fr.col.ctx) {
return mongo.ErrNoDocuments
}
return fr.cur.Decode(val)
}
return fr.res.Decode(val)
}
func (fr *FindResult) All(val interface{}) (err error) {
if fr.err != nil {
return fr.err
}
var ctx context.Context
if fr.col == nil {
ctx = context.Background()
} else {
ctx = fr.col.ctx
}
if fr.cur == nil {
return errors.ErrNoCursor
}
if !fr.cur.TryNext(ctx) {
return ctx.Err()
}
return fr.cur.All(ctx, val)
}
func (fr *FindResult) GetCol() (col *Col) {
return fr.col
}
func (fr *FindResult) GetSingleResult() (res *mongo.SingleResult) {
return fr.res
}
func (fr *FindResult) GetCursor() (cur *mongo.Cursor) {
return fr.cur
}

45
db/mongo/transaction.go Normal file
View File

@@ -0,0 +1,45 @@
package mongo
import (
"context"
"github.com/crawlab-team/go-trace"
"go.mongodb.org/mongo-driver/mongo"
)
func RunTransaction(fn func(mongo.SessionContext) error) (err error) {
return RunTransactionWithContext(context.Background(), fn)
}
func RunTransactionWithContext(ctx context.Context, fn func(mongo.SessionContext) error) (err error) {
// default client
c, err := GetMongoClient()
if err != nil {
return err
}
// start session
s, err := c.StartSession()
if err != nil {
return trace.TraceError(err)
}
// start transaction
if err := s.StartTransaction(); err != nil {
return trace.TraceError(err)
}
// perform operation
if err := mongo.WithSession(ctx, s, func(sc mongo.SessionContext) error {
if err := fn(sc); err != nil {
return trace.TraceError(err)
}
if err = s.CommitTransaction(sc); err != nil {
return trace.TraceError(err)
}
return nil
}); err != nil {
return trace.TraceError(err)
}
return nil
}

549
db/redis/client.go Normal file
View File

@@ -0,0 +1,549 @@
package redis
import (
"github.com/apex/log"
"github.com/crawlab-team/crawlab/db"
"github.com/crawlab-team/crawlab/db/errors"
"github.com/crawlab-team/crawlab/db/utils"
"github.com/gomodule/redigo/redis"
"reflect"
"strings"
"time"
)
type Client struct {
// settings
backoffMaxInterval time.Duration
timeout int
// internals
pool *redis.Pool
}
func (client *Client) Ping() error {
c := client.pool.Get()
defer utils.Close(c)
if _, err := redis.String(c.Do("PING")); err != nil {
if err != redis.ErrNil {
return trace.TraceError(err)
}
return err
}
return nil
}
func (client *Client) Keys(pattern string) (values []string, err error) {
c := client.pool.Get()
defer utils.Close(c)
values, err = redis.Strings(c.Do("KEYS", pattern))
if err != nil {
return nil, trace.TraceError(err)
}
return values, nil
}
func (client *Client) AllKeys() (values []string, err error) {
return client.Keys("*")
}
func (client *Client) Get(collection string) (value string, err error) {
c := client.pool.Get()
defer utils.Close(c)
value, err = redis.String(c.Do("GET", collection))
if err != nil {
return "", trace.TraceError(err)
}
return value, nil
}
func (client *Client) Set(collection string, value string) (err error) {
c := client.pool.Get()
defer utils.Close(c)
value, err = redis.String(c.Do("SET", collection, value))
if err != nil {
return trace.TraceError(err)
}
return nil
}
func (client *Client) Del(collection string) error {
c := client.pool.Get()
defer utils.Close(c)
if _, err := c.Do("DEL", collection); err != nil {
return trace.TraceError(err)
}
return nil
}
func (client *Client) RPush(collection string, value interface{}) error {
c := client.pool.Get()
defer utils.Close(c)
if _, err := c.Do("RPUSH", collection, value); err != nil {
return trace.TraceError(err)
}
return nil
}
func (client *Client) LPush(collection string, value interface{}) error {
c := client.pool.Get()
defer utils.Close(c)
if _, err := c.Do("LPUSH", collection, value); err != nil {
if err != redis.ErrNil {
return trace.TraceError(err)
}
return err
}
return nil
}
func (client *Client) LPop(collection string) (string, error) {
c := client.pool.Get()
defer utils.Close(c)
value, err := redis.String(c.Do("LPOP", collection))
if err != nil {
if err != redis.ErrNil {
return value, trace.TraceError(err)
}
return value, err
}
return value, nil
}
func (client *Client) RPop(collection string) (string, error) {
c := client.pool.Get()
defer utils.Close(c)
value, err := redis.String(c.Do("RPOP", collection))
if err != nil {
if err != redis.ErrNil {
return value, trace.TraceError(err)
}
return value, err
}
return value, nil
}
func (client *Client) LLen(collection string) (int, error) {
c := client.pool.Get()
defer utils.Close(c)
value, err := redis.Int(c.Do("LLEN", collection))
if err != nil {
return 0, trace.TraceError(err)
}
return value, nil
}
func (client *Client) BRPop(collection string, timeout int) (value string, err error) {
if timeout <= 0 {
timeout = 60
}
c := client.pool.Get()
defer utils.Close(c)
values, err := redis.Strings(c.Do("BRPOP", collection, timeout))
if err != nil {
if err != redis.ErrNil {
return value, trace.TraceError(err)
}
return value, err
}
return values[1], nil
}
func (client *Client) BLPop(collection string, timeout int) (value string, err error) {
if timeout <= 0 {
timeout = 60
}
c := client.pool.Get()
defer utils.Close(c)
values, err := redis.Strings(c.Do("BLPOP", collection, timeout))
if err != nil {
if err != redis.ErrNil {
return value, trace.TraceError(err)
}
return value, err
}
return values[1], nil
}
func (client *Client) HSet(collection string, key string, value string) error {
c := client.pool.Get()
defer utils.Close(c)
if _, err := c.Do("HSET", collection, key, value); err != nil {
if err != redis.ErrNil {
return trace.TraceError(err)
}
return err
}
return nil
}
func (client *Client) HGet(collection string, key string) (string, error) {
c := client.pool.Get()
defer utils.Close(c)
value, err := redis.String(c.Do("HGET", collection, key))
if err != nil && err != redis.ErrNil {
if err != redis.ErrNil {
return value, trace.TraceError(err)
}
return value, err
}
return value, nil
}
func (client *Client) HDel(collection string, key string) error {
c := client.pool.Get()
defer utils.Close(c)
if _, err := c.Do("HDEL", collection, key); err != nil {
return trace.TraceError(err)
}
return nil
}
func (client *Client) HScan(collection string) (results map[string]string, err error) {
c := client.pool.Get()
defer utils.Close(c)
var (
cursor int64
items []string
)
results = map[string]string{}
for {
values, err := redis.Values(c.Do("HSCAN", collection, cursor))
if err != nil {
if err != redis.ErrNil {
return nil, trace.TraceError(err)
}
return nil, err
}
values, err = redis.Scan(values, &cursor, &items)
if err != nil {
if err != redis.ErrNil {
return nil, trace.TraceError(err)
}
return nil, err
}
for i := 0; i < len(items); i += 2 {
key := items[i]
value := items[i+1]
results[key] = value
}
if cursor == 0 {
break
}
}
return results, nil
}
func (client *Client) HKeys(collection string) (results []string, err error) {
c := client.pool.Get()
defer utils.Close(c)
results, err = redis.Strings(c.Do("HKEYS", collection))
if err != nil {
if err != redis.ErrNil {
return results, trace.TraceError(err)
}
return results, err
}
return results, nil
}
func (client *Client) ZAdd(collection string, score float32, value interface{}) (err error) {
c := client.pool.Get()
defer utils.Close(c)
if _, err := c.Do("ZADD", collection, score, value); err != nil {
return trace.TraceError(err)
}
return nil
}
func (client *Client) ZCount(collection string, min string, max string) (count int, err error) {
c := client.pool.Get()
defer utils.Close(c)
count, err = redis.Int(c.Do("ZCOUNT", collection, min, max))
if err != nil {
return 0, trace.TraceError(err)
}
return count, nil
}
func (client *Client) ZCountAll(collection string) (count int, err error) {
return client.ZCount(collection, "-inf", "+inf")
}
func (client *Client) ZScan(collection string, pattern string, count int) (values []string, err error) {
c := client.pool.Get()
defer utils.Close(c)
values, err = redis.Strings(c.Do("ZSCAN", collection, 0, pattern, count))
if err != nil {
if err != redis.ErrNil {
return nil, trace.TraceError(err)
}
return nil, err
}
return values, nil
}
func (client *Client) ZPopMax(collection string, count int) (results []string, err error) {
c := client.pool.Get()
defer utils.Close(c)
results = []string{}
values, err := redis.Strings(c.Do("ZPOPMAX", collection, count))
if err != nil {
if err != redis.ErrNil {
return nil, trace.TraceError(err)
}
return nil, err
}
for i := 0; i < len(values); i += 2 {
v := values[i]
results = append(results, v)
}
return results, nil
}
func (client *Client) ZPopMin(collection string, count int) (results []string, err error) {
c := client.pool.Get()
defer utils.Close(c)
results = []string{}
values, err := redis.Strings(c.Do("ZPOPMIN", collection, count))
if err != nil {
if err != redis.ErrNil {
return nil, trace.TraceError(err)
}
return nil, err
}
for i := 0; i < len(values); i += 2 {
v := values[i]
results = append(results, v)
}
return results, nil
}
func (client *Client) ZPopMaxOne(collection string) (value string, err error) {
c := client.pool.Get()
defer utils.Close(c)
values, err := client.ZPopMax(collection, 1)
if err != nil {
return "", err
}
if values == nil || len(values) == 0 {
return "", nil
}
return values[0], nil
}
func (client *Client) ZPopMinOne(collection string) (value string, err error) {
c := client.pool.Get()
defer utils.Close(c)
values, err := client.ZPopMin(collection, 1)
if err != nil {
return "", err
}
if values == nil || len(values) == 0 {
return "", nil
}
return values[0], nil
}
func (client *Client) BZPopMax(collection string, timeout int) (value string, err error) {
c := client.pool.Get()
defer utils.Close(c)
values, err := redis.Strings(c.Do("BZPOPMAX", collection, timeout))
if err != nil {
if err != redis.ErrNil {
return "", trace.TraceError(err)
}
return "", err
}
if len(values) < 3 {
return "", trace.TraceError(errors.ErrorRedisInvalidType)
}
return values[1], nil
}
func (client *Client) BZPopMin(collection string, timeout int) (value string, err error) {
c := client.pool.Get()
defer utils.Close(c)
values, err := redis.Strings(c.Do("BZPOPMIN", collection, timeout))
if err != nil {
if err != redis.ErrNil {
return "", trace.TraceError(err)
}
return "", err
}
if len(values) < 3 {
return "", trace.TraceError(errors.ErrorRedisInvalidType)
}
return values[1], nil
}
func (client *Client) Lock(lockKey string) (value int64, err error) {
c := client.pool.Get()
defer utils.Close(c)
lockKey = client.getLockKey(lockKey)
ts := time.Now().Unix()
ok, err := c.Do("SET", lockKey, ts, "NX", "PX", 30000)
if err != nil {
if err != redis.ErrNil {
return value, trace.TraceError(err)
}
return value, err
}
if ok == nil {
return 0, trace.TraceError(errors.ErrorRedisLocked)
}
return ts, nil
}
func (client *Client) UnLock(lockKey string, value int64) {
c := client.pool.Get()
defer utils.Close(c)
lockKey = client.getLockKey(lockKey)
getValue, err := redis.Int64(c.Do("GET", lockKey))
if err != nil {
log.Errorf("get lockKey error: %s", err.Error())
return
}
if getValue != value {
log.Errorf("the lockKey value diff: %d, %d", value, getValue)
return
}
v, err := redis.Int64(c.Do("DEL", lockKey))
if err != nil {
log.Errorf("unlock failed, error: %s", err.Error())
return
}
if v == 0 {
log.Errorf("unlock failed: key=%s", lockKey)
return
}
}
func (client *Client) MemoryStats() (stats map[string]int64, err error) {
stats = map[string]int64{}
c := client.pool.Get()
defer utils.Close(c)
values, err := redis.Values(c.Do("MEMORY", "STATS"))
for i, v := range values {
t := reflect.TypeOf(v)
if t.Kind() == reflect.Slice {
vc, _ := redis.String(v, err)
if utils.ContainsString(MemoryStatsMetrics, vc) {
stats[vc], _ = redis.Int64(values[i+1], err)
}
}
}
if err != nil {
if err != redis.ErrNil {
return stats, trace.TraceError(err)
}
return stats, err
}
return stats, nil
}
func (client *Client) SetBackoffMaxInterval(interval time.Duration) {
client.backoffMaxInterval = interval
}
func (client *Client) SetTimeout(timeout int) {
client.timeout = timeout
}
func (client *Client) init() (err error) {
b := backoff.NewExponentialBackOff()
b.MaxInterval = client.backoffMaxInterval
if err := backoff.Retry(func() error {
err := client.Ping()
if err != nil {
log.WithError(err).Warnf("waiting for redis pool active connection. will after %f seconds try again.", b.NextBackOff().Seconds())
}
return nil
}, b); err != nil {
return trace.TraceError(err)
}
return nil
}
func (client *Client) getLockKey(lockKey string) string {
lockKey = strings.ReplaceAll(lockKey, ":", "-")
return "nodes:lock:" + lockKey
}
func (client *Client) getTimeout(timeout int) (res int) {
if timeout == 0 {
return client.timeout
}
return timeout
}
var client db.RedisClient
func NewRedisClient(opts ...Option) (client *Client, err error) {
// client
client = &Client{
backoffMaxInterval: 20 * time.Second,
pool: NewRedisPool(),
}
// apply options
for _, opt := range opts {
opt(client)
}
// init
if err := client.init(); err != nil {
return nil, err
}
return client, nil
}
func GetRedisClient() (c db.RedisClient, err error) {
if client != nil {
return client, nil
}
c, err = NewRedisClient()
if err != nil {
return nil, err
}
return c, nil
}

10
db/redis/constants.go Normal file
View File

@@ -0,0 +1,10 @@
package redis
var MemoryStatsMetrics = []string{
"peak.allocated",
"total.allocated",
"startup.allocated",
"overhead.total",
"keys.count",
"dataset.bytes",
}

20
db/redis/options.go Normal file
View File

@@ -0,0 +1,20 @@
package redis
import (
"github.com/crawlab-team/crawlab/db"
"time"
)
type Option func(c db.RedisClient)
func WithBackoffMaxInterval(interval time.Duration) Option {
return func(c db.RedisClient) {
c.SetBackoffMaxInterval(interval)
}
}
func WithTimeout(timeout int) Option {
return func(c db.RedisClient) {
c.SetTimeout(timeout)
}
}

54
db/redis/pool.go Normal file
View File

@@ -0,0 +1,54 @@
package redis
import (
"github.com/crawlab-team/go-trace"
"github.com/gomodule/redigo/redis"
"github.com/spf13/viper"
"time"
)
func NewRedisPool() *redis.Pool {
var address = viper.GetString("redis.address")
var port = viper.GetString("redis.port")
var database = viper.GetString("redis.database")
var password = viper.GetString("redis.password")
// normalize params
if address == "" {
address = "localhost"
}
if port == "" {
port = "6379"
}
if database == "" {
database = "1"
}
var url string
if password == "" {
url = "redis://" + address + ":" + port + "/" + database
} else {
url = "redis://x:" + password + "@" + address + ":" + port + "/" + database
}
return &redis.Pool{
Dial: func() (conn redis.Conn, e error) {
return redis.DialURL(url,
redis.DialConnectTimeout(time.Second*10),
redis.DialReadTimeout(time.Second*600),
redis.DialWriteTimeout(time.Second*10),
)
},
TestOnBorrow: func(c redis.Conn, t time.Time) error {
if time.Since(t) < time.Minute {
return nil
}
_, err := c.Do("PING")
return trace.TraceError(err)
},
MaxIdle: 10,
MaxActive: 0,
IdleTimeout: 300 * time.Second,
Wait: false,
MaxConnLifetime: 0,
}
}

91
db/redis/test/base.go Normal file
View File

@@ -0,0 +1,91 @@
package test
import (
"github.com/crawlab-team/crawlab/db"
"github.com/crawlab-team/crawlab/db/redis"
"testing"
)
func init() {
var err error
T, err = NewTest()
if err != nil {
panic(err)
}
}
type Test struct {
client db.RedisClient
TestCollection string
TestMessage string
TestMessages []string
TestMessagesMap map[string]string
TestKeysAlpha []string
TestKeysBeta []string
TestLockKey string
}
func (t *Test) Setup(t2 *testing.T) {
t2.Cleanup(t.Cleanup)
}
func (t *Test) Cleanup() {
keys, _ := t.client.AllKeys()
for _, key := range keys {
_ = t.client.Del(key)
}
}
var T *Test
func NewTest() (t *Test, err error) {
// test
t = &Test{}
// client
t.client, err = redis.GetRedisClient()
if err != nil {
return nil, err
}
// test collection
t.TestCollection = "test_collection"
// test message
t.TestMessage = "this is a test message"
// test messages
t.TestMessages = []string{
"test message 1",
"test message 2",
"test message 3",
}
// test messages map
t.TestMessagesMap = map[string]string{
"test key 1": "test value 1",
"test key 2": "test value 2",
"test key 3": "test value 3",
}
// test keys alpha
t.TestKeysAlpha = []string{
"test key alpha 1",
"test key alpha 2",
"test key alpha 3",
}
// test keys beta
t.TestKeysBeta = []string{
"test key beta 1",
"test key beta 2",
"test key beta 3",
"test key beta 4",
"test key beta 5",
}
// test lock key
t.TestLockKey = "test lock key"
return t, nil
}

View File

@@ -0,0 +1,273 @@
package test
import (
"github.com/crawlab-team/crawlab/db/redis"
"github.com/stretchr/testify/require"
"testing"
"time"
)
func TestRedisClient_Ping(t *testing.T) {
var err error
T.Setup(t)
err = T.client.Ping()
require.Nil(t, err)
}
func TestRedisClient_Get_Set(t *testing.T) {
var err error
T.Setup(t)
err = T.client.Set(T.TestCollection, T.TestMessage)
require.Nil(t, err)
value, err := T.client.Get(T.TestCollection)
require.Nil(t, err)
require.Equal(t, T.TestMessage, value)
}
func TestRedisClient_Keys_AllKeys(t *testing.T) {
var err error
T.Setup(t)
for _, key := range T.TestKeysAlpha {
err = T.client.Set(key, key)
require.Nil(t, err)
}
for _, key := range T.TestKeysBeta {
err = T.client.Set(key, key)
require.Nil(t, err)
}
keys, err := T.client.Keys("*alpha*")
require.Nil(t, err)
require.Len(t, keys, len(T.TestKeysAlpha))
keys, err = T.client.Keys("*beta*")
require.Nil(t, err)
require.Len(t, keys, len(T.TestKeysBeta))
keys, err = T.client.AllKeys()
require.Nil(t, err)
require.Len(t, keys, len(T.TestKeysAlpha)+len(T.TestKeysBeta))
}
func TestRedisClient_RPush_LPop_LLen(t *testing.T) {
var err error
T.Setup(t)
for _, msg := range T.TestMessages {
err = T.client.RPush(T.TestCollection, msg)
require.Nil(t, err)
}
n, err := T.client.LLen(T.TestCollection)
require.Nil(t, err)
require.Equal(t, len(T.TestMessages), n)
value, err := T.client.LPop(T.TestCollection)
require.Nil(t, err)
require.Equal(t, T.TestMessages[0], value)
}
func TestRedisClient_LPush_RPop(t *testing.T) {
var err error
T.Setup(t)
for _, msg := range T.TestMessages {
err = T.client.LPush(T.TestCollection, msg)
require.Nil(t, err)
}
n, err := T.client.LLen(T.TestCollection)
require.Nil(t, err)
require.Equal(t, len(T.TestMessages), n)
value, err := T.client.RPop(T.TestCollection)
require.Nil(t, err)
require.Equal(t, T.TestMessages[0], value)
}
func TestRedisClient_BRPop(t *testing.T) {
var err error
T.Setup(t)
isErr := true
go func(t *testing.T) {
value, err := T.client.BRPop(T.TestCollection, 0)
require.Nil(t, err)
require.Equal(t, T.TestMessage, value)
isErr = false
}(t)
err = T.client.LPush(T.TestCollection, T.TestMessage)
require.Nil(t, err)
time.Sleep(500 * time.Millisecond)
require.False(t, isErr)
}
func TestRedisClient_BLPop(t *testing.T) {
var err error
T.Setup(t)
isErr := true
go func(t *testing.T) {
value, err := T.client.BLPop(T.TestCollection, 0)
require.Nil(t, err)
require.Equal(t, T.TestMessage, value)
isErr = false
}(t)
err = T.client.RPush(T.TestCollection, T.TestMessage)
require.Nil(t, err)
time.Sleep(500 * time.Millisecond)
require.False(t, isErr)
}
func TestRedisClient_HSet_HGet_HDel(t *testing.T) {
var err error
T.Setup(t)
for k, v := range T.TestMessagesMap {
err = T.client.HSet(T.TestCollection, k, v)
require.Nil(t, err)
}
for k, v := range T.TestMessagesMap {
vr, err := T.client.HGet(T.TestCollection, k)
require.Nil(t, err)
require.Equal(t, v, vr)
}
for k := range T.TestMessagesMap {
err = T.client.HDel(T.TestCollection, k)
require.Nil(t, err)
v, err := T.client.HGet(T.TestCollection, k)
require.Nil(t, err)
require.Empty(t, v)
}
}
func TestRedisClient_HScan(t *testing.T) {
var err error
T.Setup(t)
for k, v := range T.TestMessagesMap {
err = T.client.HSet(T.TestCollection, k, v)
require.Nil(t, err)
}
results, err := T.client.HScan(T.TestCollection)
require.Nil(t, err)
for k, vr := range results {
v, ok := T.TestMessagesMap[k]
require.True(t, ok)
require.Equal(t, v, vr)
}
}
func TestRedisClient_HKeys(t *testing.T) {
var err error
T.Setup(t)
for k, v := range T.TestMessagesMap {
err = T.client.HSet(T.TestCollection, k, v)
require.Nil(t, err)
}
keys, err := T.client.HKeys(T.TestCollection)
require.Nil(t, err)
for _, k := range keys {
_, ok := T.TestMessagesMap[k]
require.True(t, ok)
}
}
func TestRedisClient_ZAdd_ZCount_ZCountAll_ZPopMax_ZPopMin(t *testing.T) {
var err error
T.Setup(t)
for i, v := range T.TestMessages {
score := float32(i)
err = T.client.ZAdd(T.TestCollection, score, v)
require.Nil(t, err)
}
count, err := T.client.ZCountAll(T.TestCollection)
require.Nil(t, err)
require.Equal(t, len(T.TestMessages), count)
value, err := T.client.ZPopMaxOne(T.TestCollection)
require.Nil(t, err)
require.Equal(t, T.TestMessages[len(T.TestMessages)-1], value)
value, err = T.client.ZPopMinOne(T.TestCollection)
require.Nil(t, err)
require.Equal(t, T.TestMessages[0], value)
}
func TestRedisClient_BZPopMax_BZPopMin(t *testing.T) {
var err error
T.Setup(t)
isErr := true
go func(t *testing.T) {
value, err := T.client.BZPopMax(T.TestCollection, 0)
require.Nil(t, err)
require.Equal(t, T.TestMessage, value)
isErr = false
}(t)
err = T.client.ZAdd(T.TestCollection, 1, T.TestMessage)
require.Nil(t, err)
time.Sleep(500 * time.Millisecond)
require.False(t, isErr)
isErr = true
go func(t *testing.T) {
value, err := T.client.BZPopMin(T.TestCollection, 0)
require.Nil(t, err)
require.Equal(t, T.TestMessage, value)
isErr = false
}(t)
err = T.client.ZAdd(T.TestCollection, 1, T.TestMessage)
require.Nil(t, err)
time.Sleep(500 * time.Millisecond)
require.False(t, isErr)
}
func TestRedisClient_Lock_Unlock(t *testing.T) {
var err error
T.Setup(t)
ts, err := T.client.Lock(T.TestLockKey)
require.Nil(t, err)
_, err = T.client.Lock(T.TestLockKey)
require.NotNil(t, err)
T.client.UnLock(T.TestLockKey, ts)
ts, err = T.client.Lock(T.TestLockKey)
require.Nil(t, err)
}
func TestRedisClient_MemoryStats(t *testing.T) {
var err error
T.Setup(t)
stats, err := T.client.MemoryStats()
require.Nil(t, err)
for _, k := range redis.MemoryStatsMetrics {
v, ok := stats[k]
require.True(t, ok)
require.Greater(t, v, int64(-1))
}
}

36
db/sql/sql.go Normal file
View File

@@ -0,0 +1,36 @@
package sql
import (
"errors"
"fmt"
"github.com/crawlab-team/go-trace"
"github.com/jmoiron/sqlx"
)
func GetSqlDatabaseConnectionString(dataSourceType string, host string, port string, username string, password string, database string) (connStr string, err error) {
if dataSourceType == "mysql" {
connStr = fmt.Sprintf("%s:%s@(%s:%s)/%s?charset=utf8&parseTime=True&loc=Local", username, password, host, port, database)
} else if dataSourceType == "postgres" {
connStr = fmt.Sprintf("host=%s port=%s user=%s dbname=%s password=%s sslmode=%s", host, port, username, database, password, "disable")
} else {
err = errors.New(dataSourceType + " is not implemented")
return connStr, trace.TraceError(err)
}
return connStr, nil
}
func GetSqlConn(dataSourceType string, host string, port string, username string, password string, database string) (db *sqlx.DB, err error) {
// get database connection string
connStr, err := GetSqlDatabaseConnectionString(dataSourceType, host, port, username, password, database)
if err != nil {
return db, trace.TraceError(err)
}
// get database instance
db, err = sqlx.Open(dataSourceType, connStr)
if err != nil {
return db, trace.TraceError(err)
}
return db, nil
}

19
db/utils/utils.go Normal file
View File

@@ -0,0 +1,19 @@
package utils
import "io"
func Close(c io.Closer) {
err := c.Close()
if err != nil {
//log.WithError(err).Error("关闭资源文件失败。")
}
}
func ContainsString(list []string, item string) bool {
for _, d := range list {
if d == item {
return true
}
}
return false
}