feat: added modules

This commit is contained in:
Marvin Zhang
2024-06-14 15:42:50 +08:00
parent f1833fed21
commit 0b67fd9ece
626 changed files with 60104 additions and 0 deletions

17
core/utils/args.go Normal file
View File

@@ -0,0 +1,17 @@
package utils
import "github.com/crawlab-team/crawlab/core/interfaces"
// GetUserFromArgs scans the variadic arguments in order and returns the first
// one whose dynamic type satisfies interfaces.User. It returns nil when no
// argument does.
func GetUserFromArgs(args ...interface{}) (u interfaces.User) {
	for i := range args {
		if user, isUser := args[i].(interfaces.User); isUser {
			return user
		}
	}
	return nil
}

46
core/utils/array.go Normal file
View File

@@ -0,0 +1,46 @@
package utils
import (
"errors"
"math/rand"
"reflect"
"time"
)
// StringArrayContains reports whether str is an element of arr.
func StringArrayContains(arr []string, str string) bool {
	for i := 0; i < len(arr); i++ {
		if arr[i] == str {
			return true
		}
	}
	return false
}
// GetArrayItems converts any slice or array value into a []interface{} holding
// the same elements in order.
//
// It returns an "invalid type" error when array is nil, when it is neither a
// slice nor an array, or when an element is itself a nil interface value (the
// interface assertion below fails for nil). An empty slice or array yields a
// nil result with a nil error.
func GetArrayItems(array interface{}) (res []interface{}, err error) {
	// reflect.TypeOf(nil) is a nil Type; calling Kind on it would panic.
	if array == nil {
		return nil, errors.New("invalid type")
	}

	s := reflect.ValueOf(array)
	if kind := s.Kind(); kind != reflect.Slice && kind != reflect.Array {
		return nil, errors.New("invalid type")
	}

	for i := 0; i < s.Len(); i++ {
		obj, ok := s.Index(i).Interface().(interface{})
		if !ok {
			return nil, errors.New("invalid type")
		}
		res = append(res, obj)
	}
	return res, nil
}
// ShuffleArray permutes slice in place with a Fisher-Yates shuffle driven by a
// generator seeded from the current Unix time in seconds. It always returns nil.
func ShuffleArray(slice []interface{}) (err error) {
	rng := rand.New(rand.NewSource(time.Now().Unix()))
	// Walk the unshuffled prefix from its full length down to one element,
	// swapping a random member of the prefix into its last position.
	for remaining := len(slice); remaining > 0; remaining-- {
		pick := rng.Intn(remaining)
		last := remaining - 1
		slice[last], slice[pick] = slice[pick], slice[last]
	}
	return nil
}

15
core/utils/backoff.go Normal file
View File

@@ -0,0 +1,15 @@
package utils
import (
"github.com/apex/log"
"github.com/cenkalti/backoff/v4"
"github.com/crawlab-team/go-trace"
"time"
)
// BackoffErrorNotify builds a backoff.Notify callback that logs the failed
// attempt (tagged with prefix) together with the delay before the next
// attempt, then prints the error through the trace package.
func BackoffErrorNotify(prefix string) backoff.Notify {
	notify := func(err error, duration time.Duration) {
		seconds := duration.Seconds()
		log.Errorf("%s error: %v. reattempt in %.1f seconds...", prefix, err, seconds)
		trace.PrintError(err)
	}
	return notify
}

View File

@@ -0,0 +1,99 @@
package binders
import (
"github.com/crawlab-team/crawlab/core/errors"
"github.com/crawlab-team/crawlab/core/interfaces"
)
// NewColNameBinder returns a binder that resolves the collection name for the
// given model id.
func NewColNameBinder(id interfaces.ModelId) (b *ColNameBinder) {
	b = new(ColNameBinder)
	b.id = id
	return b
}
// ColNameBinder maps a model id to the name of the collection that stores it.
type ColNameBinder struct {
	// id is the model id whose collection name Bind looks up.
	id interfaces.ModelId
}
// Bind returns the collection name constant registered for the binder's model
// id. Unknown ids yield a nil result and errors.ErrorModelNotImplemented.
func (b *ColNameBinder) Bind() (res interface{}, err error) {
	switch b.id {
	// system models
	case interfaces.ModelIdArtifact:
		res = interfaces.ModelColNameArtifact
	case interfaces.ModelIdTag:
		res = interfaces.ModelColNameTag

	// operation models
	case interfaces.ModelIdNode:
		res = interfaces.ModelColNameNode
	case interfaces.ModelIdProject:
		res = interfaces.ModelColNameProject
	case interfaces.ModelIdSpider:
		res = interfaces.ModelColNameSpider
	case interfaces.ModelIdTask:
		res = interfaces.ModelColNameTask
	case interfaces.ModelIdJob:
		res = interfaces.ModelColNameJob
	case interfaces.ModelIdSchedule:
		res = interfaces.ModelColNameSchedule
	case interfaces.ModelIdUser:
		res = interfaces.ModelColNameUser
	case interfaces.ModelIdSetting:
		res = interfaces.ModelColNameSetting
	case interfaces.ModelIdToken:
		res = interfaces.ModelColNameToken
	case interfaces.ModelIdVariable:
		res = interfaces.ModelColNameVariable
	case interfaces.ModelIdTaskQueue:
		res = interfaces.ModelColNameTaskQueue
	case interfaces.ModelIdTaskStat:
		res = interfaces.ModelColNameTaskStat
	case interfaces.ModelIdSpiderStat:
		res = interfaces.ModelColNameSpiderStat
	case interfaces.ModelIdDataSource:
		res = interfaces.ModelColNameDataSource
	case interfaces.ModelIdDataCollection:
		res = interfaces.ModelColNameDataCollection
	case interfaces.ModelIdPassword:
		res = interfaces.ModelColNamePasswords
	case interfaces.ModelIdExtraValue:
		res = interfaces.ModelColNameExtraValues
	case interfaces.ModelIdGit:
		res = interfaces.ModelColNameGit
	case interfaces.ModelIdRole:
		res = interfaces.ModelColNameRole
	case interfaces.ModelIdUserRole:
		res = interfaces.ModelColNameUserRole
	case interfaces.ModelIdPermission:
		res = interfaces.ModelColNamePermission
	case interfaces.ModelIdRolePermission:
		res = interfaces.ModelColNameRolePermission
	case interfaces.ModelIdEnvironment:
		res = interfaces.ModelColNameEnvironment
	case interfaces.ModelIdDependencySetting:
		res = interfaces.ModelColNameDependencySetting

	// invalid
	default:
		return res, errors.ErrorModelNotImplemented
	}
	return res, nil
}
// MustBind is like Bind but panics with the error when the model id is not
// recognised.
func (b *ColNameBinder) MustBind() (res interface{}) {
	var err error
	if res, err = b.Bind(); err != nil {
		panic(err)
	}
	return res
}
// BindString is like Bind but returns the collection name as a string.
// The unchecked assertion panics if Bind ever yields a non-string value.
func (b *ColNameBinder) BindString() (res string, err error) {
	bound, bindErr := b.Bind()
	if bindErr != nil {
		return "", bindErr
	}
	return bound.(string), nil
}
// MustBindString is like MustBind but returns the collection name as a string.
// It panics when the model id is unknown or the bound value is not a string.
func (b *ColNameBinder) MustBindString() (res string) {
	res = b.MustBind().(string)
	return res
}

12
core/utils/bool.go Normal file
View File

@@ -0,0 +1,12 @@
package utils
import "github.com/spf13/viper"
// EnvIsTrue reads key from viper and interprets it as a flag. When the string
// form of the value is empty it returns defaultOk; otherwise it returns true if
// the value is exactly "Y" or if viper's boolean parsing of the key is true.
func EnvIsTrue(key string, defaultOk bool) bool {
	raw := viper.GetString(key)
	if raw == "" {
		return defaultOk
	}
	if raw == "Y" {
		return true
	}
	return viper.GetBool(key)
}

125
core/utils/bson.go Normal file
View File

@@ -0,0 +1,125 @@
package utils
import (
"github.com/emirpasic/gods/sets/hashset"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/bson/primitive"
"reflect"
)
// BsonMEqual reports whether two bson.M documents are equal according to the
// package's recursive comparison (bsonMEqual) rather than reflect.DeepEqual.
func BsonMEqual(v1, v2 bson.M) (ok bool) {
	return bsonMEqual(v1, v2)
}
// bsonMEqual reports whether v1 and v2 hold the same keys with equal values.
//
// Values are compared per key:
//   - bson.M values are compared recursively and only ever equal another bson.M;
//   - bson.A values are compared with reflect.DeepEqual and only ever equal
//     another bson.A;
//   - everything else is compared with ==, falling back to reflect.DeepEqual
//     when both sides share an uncomparable dynamic type (for example a slice),
//     where == would panic at run time.
//
// Previously an empty bson.M on the left compared equal to any non-document
// value on the right, because the right-hand side was silently treated as a
// nil document.
func bsonMEqual(v1, v2 bson.M) (ok bool) {
	// Build the union of keys so a key missing on either side is detected.
	allKeys := hashset.New()
	for key := range v1 {
		allKeys.Add(key)
	}
	for key := range v2 {
		allKeys.Add(key)
	}

	for _, keyRes := range allKeys.Values() {
		key := keyRes.(string)
		v1Value, found := v1[key]
		if !found {
			return false
		}
		v2Value, found := v2[key]
		if !found {
			return false
		}

		switch left := v1Value.(type) {
		case bson.M:
			right, isDoc := v2Value.(bson.M)
			if !isDoc || !bsonMEqual(left, right) {
				return false
			}
		case bson.A:
			right, isArr := v2Value.(bson.A)
			if !isArr || !reflect.DeepEqual(left, right) {
				return false
			}
		default:
			// A plain value never equals a document or an array.
			switch v2Value.(type) {
			case bson.M, bson.A:
				return false
			}
			// == on interface values panics only when both dynamic types are
			// identical and uncomparable; use DeepEqual for exactly that case.
			t1 := reflect.TypeOf(v1Value)
			if t1 != nil && t1 == reflect.TypeOf(v2Value) && !t1.Comparable() {
				if !reflect.DeepEqual(v1Value, v2Value) {
					return false
				}
			} else if v1Value != v2Value {
				return false
			}
		}
	}
	return true
}
// NormalizeBsonMObjectId rewrites m in place: every string value that parses as
// an ObjectID hex string is replaced by the corresponding primitive.ObjectID,
// and nested bson.M values are processed recursively. Other values (including
// arrays) are left untouched. The same map is returned for convenience.
func NormalizeBsonMObjectId(m bson.M) (res bson.M) {
	for key, raw := range m {
		if text, isString := raw.(string); isString {
			if oid, err := primitive.ObjectIDFromHex(text); err == nil {
				m[key] = oid
			}
			continue
		}
		if nested, isDoc := raw.(bson.M); isDoc {
			m[key] = NormalizeBsonMObjectId(nested)
		}
	}
	return m
}
// DenormalizeBsonMObjectId is the inverse of NormalizeBsonMObjectId: it rewrites
// m in place, replacing every primitive.ObjectID value with its hex string and
// processing nested bson.M values recursively. The same map is returned.
//
// Nested documents previously recursed into NormalizeBsonMObjectId, which left
// nested ObjectIDs untouched and turned nested hex strings into ObjectIDs.
func DenormalizeBsonMObjectId(m bson.M) (res bson.M) {
	for k, v := range m {
		switch val := v.(type) {
		case primitive.ObjectID:
			m[k] = val.Hex()
		case bson.M:
			m[k] = DenormalizeBsonMObjectId(val)
		}
	}
	return m
}
// NormalizeObjectId returns v converted to a primitive.ObjectID when v is a
// string holding a valid ObjectID hex value; otherwise it returns v unchanged.
func NormalizeObjectId(v interface{}) (res interface{}) {
	if text, isString := v.(string); isString {
		if oid, err := primitive.ObjectIDFromHex(text); err == nil {
			return oid
		}
	}
	return v
}

57
core/utils/cache.go Normal file
View File

@@ -0,0 +1,57 @@
package utils
import (
"github.com/crawlab-team/crawlab-db/mongo"
"github.com/crawlab-team/crawlab/core/constants"
"go.mongodb.org/mongo-driver/bson"
mongo2 "go.mongodb.org/mongo-driver/mongo"
"time"
)
// GetFromDbCache returns the cached string stored under key in the cache
// collection. On a cache miss it calls getFn, stores the result together with
// the current time, and returns it. A cached document whose value is missing or
// not a string is deleted and the lookup is retried from scratch.
func GetFromDbCache(key string, getFn func() (string, error)) (res string, err error) {
	col := mongo.GetMongoCol(constants.CacheColName)

	var doc bson.M
	findErr := col.Find(bson.M{
		constants.CacheColKey: key,
	}, nil).One(&doc)
	switch {
	case findErr == nil:
		// cache hit: fall through to value extraction below
	case findErr != mongo2.ErrNoDocuments:
		return "", findErr
	default:
		// cache miss: compute the value and persist it
		value, getErr := getFn()
		if getErr != nil {
			return "", getErr
		}
		doc = bson.M{
			constants.CacheColKey:   key,
			constants.CacheColValue: value,
			constants.CacheColTime:  time.Now(),
		}
		if _, insertErr := col.Insert(doc); insertErr != nil {
			return "", insertErr
		}
		return value, nil
	}

	// A missing key yields a nil interface, so one assertion covers both the
	// "absent" and the "wrong type" cases.
	if cached, isString := doc[constants.CacheColValue].(string); isString {
		return cached, nil
	}

	// Corrupt entry: drop it and start over.
	if delErr := col.Delete(bson.M{constants.CacheColKey: key}); delErr != nil {
		return "", delErr
	}
	return GetFromDbCache(key, getFn)
}

40
core/utils/chan.go Normal file
View File

@@ -0,0 +1,40 @@
package utils
import (
"sync"
)
// TaskExecChanMap is the package-level channel registry.
var TaskExecChanMap = NewChanMap()

// ChanMap is a registry of string channels keyed by name, backed by a sync.Map.
type ChanMap struct {
	m sync.Map
}

// NewChanMap returns an empty ChanMap.
func NewChanMap() *ChanMap {
	return &ChanMap{}
}

// Chan returns the channel registered under key, creating a channel with a
// buffer of 10 when none exists yet.
func (cm *ChanMap) Chan(key string) chan string {
	return cm.loadOrCreate(key, 10)
}

// ChanBlocked returns the channel registered under key, creating an unbuffered
// channel when none exists yet.
func (cm *ChanMap) ChanBlocked(key string) chan string {
	return cm.loadOrCreate(key, 0)
}

// HasChanKey reports whether a channel is registered under key.
func (cm *ChanMap) HasChanKey(key string) bool {
	_, ok := cm.m.Load(key)
	return ok
}

// loadOrCreate returns the channel stored under key, registering a new one with
// the given buffer size if the key is absent. The separate Load-then-Store it
// replaces let two concurrent callers each create and return a different
// channel for the same key; LoadOrStore makes exactly one of them win.
func (cm *ChanMap) loadOrCreate(key string, size int) chan string {
	// Fast path avoids allocating a channel when the key already exists.
	if ch, ok := cm.m.Load(key); ok {
		return ch.(chan string)
	}
	ch, _ := cm.m.LoadOrStore(key, make(chan string, size))
	return ch.(chan string)
}

78
core/utils/chan_test.go Normal file
View File

@@ -0,0 +1,78 @@
package utils
import (
. "github.com/smartystreets/goconvey/convey"
"sync"
"testing"
)
// TestNewChanMap checks that a ChanMap built by NewChanMap stores and returns
// the same channel as a ChanMap assembled by hand around a sync.Map.
func TestNewChanMap(t *testing.T) {
	rawMap := sync.Map{}
	sharedChan := make(chan string)
	label := "test"
	Convey("Call NewChanMap to generate ChanMap", t, func() {
		rawMap.Store("test", sharedChan)
		expected := ChanMap{rawMap}
		actual := NewChanMap()
		actual.m.Store("test", sharedChan)
		Convey(label, func() {
			fromActual, found := actual.m.Load("test")
			So(found, ShouldBeTrue)
			fromExpected, found := expected.m.Load("test")
			So(found, ShouldBeTrue)
			So(fromActual, ShouldResemble, fromExpected)
		})
	})
}
// TestChan covers ChanMap.Chan for a key that already exists and for a new key,
// which must be registered with a buffer of 10.
func TestChan(t *testing.T) {
	rawMap := sync.Map{}
	existing := make(chan string)
	rawMap.Store("test", existing)
	registry := ChanMap{rawMap}

	Convey("Test Chan use exist key", t, func() {
		got := registry.Chan("test")
		Convey("ch1 should equal chanTest", func() {
			So(got, ShouldEqual, existing)
		})
	})

	Convey("Test Chan use no-exist key", t, func() {
		created := registry.Chan("test2")
		Convey("ch2 should equal chanMapTest.m[test2]", func() {
			stored, found := registry.m.Load("test2")
			So(found, ShouldBeTrue)
			So(stored, ShouldEqual, created)
		})
		Convey("Cap of chanMapTest.m[test2] should equal 10", func() {
			So(10, ShouldEqual, cap(created))
		})
	})
}
// TestChanBlocked covers ChanMap.ChanBlocked for a key that already exists and
// for a new key, which must be registered as an unbuffered channel.
//
// The case descriptions were copied from TestChan and claimed a capacity of 10
// while asserting 0; they now describe what is actually checked.
func TestChanBlocked(t *testing.T) {
	mapTest := sync.Map{}
	chanTest := make(chan string)
	mapTest.Store("test", chanTest)
	chanMapTest := ChanMap{mapTest}

	Convey("Test ChanBlocked use exist key", t, func() {
		ch1 := chanMapTest.ChanBlocked("test")
		Convey("ch1 should equal chanTest", func() {
			So(ch1, ShouldEqual, chanTest)
		})
	})

	Convey("Test ChanBlocked use no-exist key", t, func() {
		ch2 := chanMapTest.ChanBlocked("test2")
		Convey("ch2 should equal chanMapTest.m[test2]", func() {
			v, ok := chanMapTest.m.Load("test2")
			So(ok, ShouldBeTrue)
			So(v, ShouldEqual, ch2)
		})
		Convey("Cap of chanMapTest.m[test2] should equal 0", func() {
			So(cap(ch2), ShouldEqual, 0)
		})
	})
}

60
core/utils/cockroachdb.go Normal file
View File

@@ -0,0 +1,60 @@
package utils
import (
"context"
"fmt"
"github.com/crawlab-team/crawlab/core/constants"
"github.com/crawlab-team/crawlab/core/models/models"
"github.com/upper/db/v4"
"github.com/upper/db/v4/adapter/mssql"
"time"
)
// GetCockroachdbSession opens a database session for the given data source
// using a background context, i.e. without any deadline.
func GetCockroachdbSession(ds *models.DataSource) (s db.Session, err error) {
	ctx := context.Background()
	return getCockroachdbSession(ctx, ds)
}
// GetCockroachdbSessionWithTimeout opens a database session for the given data
// source, giving up once timeout has elapsed.
func GetCockroachdbSessionWithTimeout(ds *models.DataSource, timeout time.Duration) (s db.Session, err error) {
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()
	s, err = getCockroachdbSession(ctx, ds)
	return s, err
}
// getCockroachdbSession opens a session for ds, substituting the default host
// and port for empty settings, and stops waiting when ctx is done.
//
// The connection attempt runs in its own goroutine and hands its outcome back
// over a buffered channel. The previous version wrote the named results from
// that goroutine, racing with the caller after a timeout. If ctx ends first,
// the function returns (nil, ctx.Err()) and a late-opened session is closed in
// the background instead of being leaked.
//
// NOTE(review): the session is opened with the mssql adapter although the data
// source is CockroachDB — confirm this is intentional.
func getCockroachdbSession(ctx context.Context, ds *models.DataSource) (s db.Session, err error) {
	// normalize settings
	host := ds.Host
	port := ds.Port
	if ds.Host == "" {
		host = constants.DefaultHost
	}
	if ds.Port == "" {
		port = constants.DefaultCockroachdbPort
	}

	// connect settings
	settings := mssql.ConnectionURL{
		User:     ds.Username,
		Password: ds.Password,
		Database: ds.Database,
		Host:     fmt.Sprintf("%s:%s", host, port),
		Options:  nil,
	}

	// session
	type openResult struct {
		sess db.Session
		err  error
	}
	// Buffered so the opener never blocks, even if nobody is waiting any more.
	resCh := make(chan openResult, 1)
	go func() {
		sess, openErr := mssql.Open(settings)
		resCh <- openResult{sess: sess, err: openErr}
	}()

	// wait for done
	select {
	case <-ctx.Done():
		// Reap the connection if it eventually opens after we gave up.
		go func() {
			if r := <-resCh; r.err == nil && r.sess != nil {
				_ = r.sess.Close()
			}
		}()
		return nil, ctx.Err()
	case r := <-resCh:
		return r.sess, r.err
	}
}

176
core/utils/cron.go Normal file
View File

@@ -0,0 +1,176 @@
package utils
import (
"fmt"
"math"
"strconv"
"strings"
)
// cronBounds describes the inclusive range of values a cron field accepts,
// plus an optional table of symbolic names (lower case) for those values.
type cronBounds struct {
	min, max uint
	names    map[string]uint
}

// cronUtils bundles the bounds of every cron field with the marker bit used to
// remember that a field was written as a wildcard.
type cronUtils struct {
	// Accepted range for each field.
	seconds cronBounds
	minutes cronBounds
	hours   cronBounds
	dom     cronBounds
	months  cronBounds
	dow     cronBounds

	// Bit OR-ed into a field's bit set when the expression contained a star.
	starBit uint64
}

// getRange turns one cron range expression into a bit set, where bit i is set
// when value i is selected. The accepted grammar is
//
//	number | number "-" number [ "/" number ]
//
// with "*" or "?" standing for the full range of r. A parse or bounds problem
// is reported as an error.
func (u *cronUtils) getRange(expr string, r cronBounds) (uint64, error) {
	var (
		lo, hi, stride uint
		star           uint64
		err            error
	)
	parts := strings.Split(expr, "/")
	limits := strings.Split(parts[0], "-")

	if first := limits[0]; first == "*" || first == "?" {
		lo, hi, star = r.min, r.max, CronUtils.starBit
	} else {
		if lo, err = u.parseIntOrName(first, r.names); err != nil {
			return 0, err
		}
		if len(limits) > 2 {
			return 0, fmt.Errorf("too many hyphens: %s", expr)
		}
		hi = lo
		if len(limits) == 2 {
			if hi, err = u.parseIntOrName(limits[1], r.names); err != nil {
				return 0, err
			}
		}
	}

	if len(parts) > 2 {
		return 0, fmt.Errorf("too many slashes: %s", expr)
	}
	stride = 1
	if len(parts) == 2 {
		if stride, err = u.mustParseInt(parts[1]); err != nil {
			return 0, err
		}
		// "N/step" is shorthand for "N-max/step".
		if len(limits) == 1 {
			hi = r.max
		}
		// A stepped wildcard no longer counts as a plain star.
		if stride > 1 {
			star = 0
		}
	}

	switch {
	case lo < r.min:
		return 0, fmt.Errorf("beginning of range (%d) below minimum (%d): %s", lo, r.min, expr)
	case hi > r.max:
		return 0, fmt.Errorf("end of range (%d) above maximum (%d): %s", hi, r.max, expr)
	case lo > hi:
		return 0, fmt.Errorf("beginning of range (%d) beyond end of range (%d): %s", lo, hi, expr)
	case stride == 0:
		return 0, fmt.Errorf("step of range should be a positive number: %s", expr)
	}

	return u.getBits(lo, hi, stride) | star, nil
}

// parseIntOrName resolves expr through the names table (case-insensitively)
// when one is given, and otherwise parses it as a non-negative integer.
func (u *cronUtils) parseIntOrName(expr string, names map[string]uint) (uint, error) {
	if names != nil {
		if value, known := names[strings.ToLower(expr)]; known {
			return value, nil
		}
	}
	return u.mustParseInt(expr)
}

// mustParseInt parses expr as a non-negative decimal integer. Despite its
// name it reports failures as an error rather than panicking.
func (u *cronUtils) mustParseInt(expr string) (uint, error) {
	value, err := strconv.Atoi(expr)
	switch {
	case err != nil:
		return 0, fmt.Errorf("failed to parse int from %s: %s", expr, err)
	case value < 0:
		return 0, fmt.Errorf("negative number (%d) not allowed: %s", value, expr)
	}
	return uint(value), nil
}

// getBits returns a bit set with every step-th bit in [min, max] set,
// starting at min.
func (u *cronUtils) getBits(min, max, step uint) uint64 {
	// A step of one is a contiguous run, which two shifts can produce.
	if step == 1 {
		var ones uint64 = math.MaxUint64
		return ^(ones << (max + 1)) & (ones << min)
	}

	// Otherwise set the bits one by one.
	var mask uint64
	for pos := min; pos <= max; pos += step {
		mask |= 1 << pos
	}
	return mask
}

// all returns the bit set covering the whole range of r with the star bit set.
func (u *cronUtils) all(r cronBounds) uint64 {
	return u.getBits(r.min, r.max, 1) | CronUtils.starBit
}

// CronUtils is the shared cron helper configured with the standard field
// bounds and month / weekday names.
var CronUtils = cronUtils{
	// Accepted range for each field.
	seconds: cronBounds{min: 0, max: 59},
	minutes: cronBounds{min: 0, max: 59},
	hours:   cronBounds{min: 0, max: 23},
	dom:     cronBounds{min: 1, max: 31},
	months: cronBounds{min: 1, max: 12, names: map[string]uint{
		"jan": 1,
		"feb": 2,
		"mar": 3,
		"apr": 4,
		"may": 5,
		"jun": 6,
		"jul": 7,
		"aug": 8,
		"sep": 9,
		"oct": 10,
		"nov": 11,
		"dec": 12,
	}},
	dow: cronBounds{min: 0, max: 6, names: map[string]uint{
		"sun": 0,
		"mon": 1,
		"tue": 2,
		"wed": 3,
		"thu": 4,
		"fri": 5,
		"sat": 6,
	}},

	// The top bit marks a field that was written as a star.
	starBit: 1 << 63,
}

18
core/utils/debug.go Normal file
View File

@@ -0,0 +1,18 @@
package utils
import (
"fmt"
"github.com/spf13/viper"
"time"
)
// IsDebug reports whether the "debug" configuration flag is enabled.
func IsDebug() bool {
	enabled := viper.GetBool("debug")
	return enabled
}
// LogDebug prints msg to standard output, prefixed with "[DEBUG]" and the
// current local time, but only when debug mode is enabled.
func LogDebug(msg string) {
	if !IsDebug() {
		return
	}
	// Printf with a trailing newline produces the same output as the former
	// Println(Sprintf(...)) without building an intermediate string.
	fmt.Printf("[DEBUG] %s: %s\n", time.Now().Format("2006-01-02 15:04:05"), msg)
}

57
core/utils/demo.go Normal file
View File

@@ -0,0 +1,57 @@
package utils
import (
"fmt"
"github.com/crawlab-team/crawlab-db/mongo"
"github.com/crawlab-team/crawlab/core/sys_exec"
"github.com/crawlab-team/go-trace"
"github.com/spf13/viper"
)
// GetApiAddress returns the configured "api.address", falling back to
// "http://localhost:8000" when it is not set.
func GetApiAddress() (res string) {
	if configured := viper.GetString("api.address"); configured != "" {
		return configured
	}
	return "http://localhost:8000"
}
// IsDemo reports whether the "demo" flag is on. It defaults to true when the
// flag is unset.
func IsDemo() (ok bool) {
	ok = EnvIsTrue("demo", true)
	return ok
}
// InitializedDemo reports whether the "users" collection already contains
// documents. A failed count is treated as "initialized".
func InitializedDemo() (ok bool) {
	count, err := mongo.GetMongoCol("users").Count(nil)
	return err != nil || count > 0
}
func ImportDemo() (err error) {
cmdStr := fmt.Sprintf("crawlab-cli login -a %s && crawlab-demo import", GetApiAddress())
cmd := sys_exec.BuildCmd(cmdStr)
if err := cmd.Run(); err != nil {
trace.PrintError(err)
}
return nil
}
func ReimportDemo() (err error) {
cmdStr := fmt.Sprintf("crawlab-cli login -a %s && crawlab-demo reimport", GetApiAddress())
cmd := sys_exec.BuildCmd(cmdStr)
if err := cmd.Run(); err != nil {
trace.PrintError(err)
}
return nil
}
func CleanupDemo() (err error) {
cmdStr := fmt.Sprintf("crawlab-cli login -a %s && crawlab-demo reimport", GetApiAddress())
cmd := sys_exec.BuildCmd(cmdStr)
if err := cmd.Run(); err != nil {
trace.PrintError(err)
}
return nil
}

18
core/utils/di.go Normal file
View File

@@ -0,0 +1,18 @@
package utils
import (
"github.com/crawlab-team/go-trace"
"github.com/spf13/viper"
"go.uber.org/dig"
"os"
)
// VisualizeContainer writes dig's visualization of the container c to standard
// output when "debug.di.visualize" is enabled, and does nothing otherwise.
func VisualizeContainer(c *dig.Container) (err error) {
	if viper.GetBool("debug.di.visualize") {
		if vizErr := dig.Visualize(c, os.Stdout); vizErr != nil {
			return trace.TraceError(vizErr)
		}
	}
	return nil
}

5
core/utils/docker.go Normal file
View File

@@ -0,0 +1,5 @@
package utils
// IsDocker reports whether the "docker" flag is on. It defaults to false when
// the flag is unset.
func IsDocker() (ok bool) {
	ok = EnvIsTrue("docker", false)
	return ok
}

80
core/utils/encrypt.go Normal file
View File

@@ -0,0 +1,80 @@
package utils
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"crypto/hmac"
"crypto/md5"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"fmt"
"github.com/crawlab-team/crawlab/core/constants"
"io"
)
// GetSecretKey returns the encryption key used by EncryptAES and DecryptAES,
// which is the compile-time constant constants.DefaultEncryptServerKey.
func GetSecretKey() string {
	secret := constants.DefaultEncryptServerKey
	return secret
}
// GetSecretKeyBytes returns the secret key as a freshly allocated byte slice.
func GetSecretKeyBytes() []byte {
	secret := GetSecretKey()
	return []byte(secret)
}
// ComputeHmacSha256 computes the HMAC-SHA256 of message keyed with secret and
// returns the standard base64 encoding of the digest's lower-case hex string
// (that is, base64 over the hex text, not over the raw digest bytes).
func ComputeHmacSha256(message string, secret string) string {
	mac := hmac.New(sha256.New, []byte(secret))
	mac.Write([]byte(message))
	digestHex := hex.EncodeToString(mac.Sum(nil))
	encoded := base64.StdEncoding.EncodeToString([]byte(digestHex))
	return encoded
}
// EncryptMd5 returns the lower-case hexadecimal MD5 digest of str.
func EncryptMd5(str string) string {
	hasher := md5.New()
	_, _ = io.WriteString(hasher, str)
	return fmt.Sprintf("%x", hasher.Sum(nil))
}
// padding appends between 1 and blockSize bytes to src so that its length
// becomes a multiple of blockSize. Every appended byte holds the number of
// bytes appended.
func padding(src []byte, blockSize int) []byte {
	missing := blockSize - len(src)%blockSize
	filler := bytes.Repeat([]byte{byte(missing)}, missing)
	src = append(src, filler...)
	return src
}
// unPadding strips the padding added by padding: it reads the last byte as the
// number of trailing bytes to drop and returns src without them.
//
// Input that cannot carry valid padding — an empty slice, or a final byte
// larger than the slice length — is returned unchanged. Both cases previously
// panicked with an out-of-range index.
func unPadding(src []byte) []byte {
	n := len(src)
	if n == 0 {
		return src
	}
	unPadNum := int(src[n-1])
	if unPadNum > n {
		return src
	}
	return src[:n-unPadNum]
}
// EncryptAES pads src, encrypts it with AES in CBC mode using the package
// secret key, and returns the ciphertext as a hex string.
//
// The secret key is also passed as the CBC initialization vector.
// NOTE(review): cipher.NewCBCEncrypter requires the IV length to equal the
// block size, so this presumably relies on the key being exactly 16 bytes —
// confirm against constants.DefaultEncryptServerKey.
func EncryptAES(src string) (res string, err error) {
	key := GetSecretKeyBytes()
	block, err := aes.NewCipher(key)
	if err != nil {
		return res, err
	}

	buf := padding([]byte(src), block.BlockSize())
	// Encrypt in place.
	cipher.NewCBCEncrypter(block, key).CryptBlocks(buf, buf)
	return hex.EncodeToString(buf), nil
}
// DecryptAES reverses EncryptAES: it hex-decodes src, decrypts it with AES in
// CBC mode using the package secret key (also used as the IV), strips the
// padding, and returns the plaintext.
//
// It returns an error when src is not valid hex, when the decoded ciphertext is
// empty or not a whole number of cipher blocks, or when the decrypted padding
// byte exceeds the data length. The latter two cases previously panicked inside
// CryptBlocks or unPadding.
func DecryptAES(src string) (res string, err error) {
	srcBytes, err := hex.DecodeString(src)
	if err != nil {
		return res, err
	}
	key := GetSecretKeyBytes()
	block, err := aes.NewCipher(key)
	if err != nil {
		return res, err
	}

	// CryptBlocks panics unless the input is a whole number of blocks.
	blockSize := block.BlockSize()
	if len(srcBytes) == 0 || len(srcBytes)%blockSize != 0 {
		return "", fmt.Errorf("invalid ciphertext length %d: must be a non-zero multiple of %d", len(srcBytes), blockSize)
	}

	blockMode := cipher.NewCBCDecrypter(block, key)
	// Decrypt in place.
	blockMode.CryptBlocks(srcBytes, srcBytes)

	// Reject padding that would slice past the start of the buffer.
	if padNum := int(srcBytes[len(srcBytes)-1]); padNum > len(srcBytes) {
		return "", fmt.Errorf("invalid padding size %d for %d bytes of data", padNum, len(srcBytes))
	}
	res = string(unPadding(srcBytes))
	return res, nil
}

View File

@@ -0,0 +1,20 @@
package utils
import (
"fmt"
"github.com/stretchr/testify/require"
"testing"
)
// TestEncryptAesPassword checks that DecryptAES(EncryptAES(x)) returns x and
// that the ciphertext differs from the plaintext.
func TestEncryptAesPassword(t *testing.T) {
	plainText := "crawlab"
	encryptedText, err := EncryptAES(plainText)
	require.NoError(t, err)
	decryptedText, err := DecryptAES(encryptedText)
	require.NoError(t, err)
	fmt.Printf("plainText: %s\n", plainText)
	fmt.Printf("encryptedText: %s\n", encryptedText)
	fmt.Printf("decryptedText: %s\n", decryptedText)
	// testify expects (expected, actual); the order matters for failure output.
	require.Equal(t, plainText, decryptedText)
	require.NotEqual(t, encryptedText, decryptedText)
}

159
core/utils/es.go Normal file
View File

@@ -0,0 +1,159 @@
package utils
import (
"bytes"
"context"
"encoding/json"
"fmt"
"github.com/cenkalti/backoff/v4"
"github.com/crawlab-team/crawlab-db/generic"
"github.com/crawlab-team/crawlab/core/constants"
"github.com/crawlab-team/crawlab/core/models/models"
"github.com/crawlab-team/go-trace"
"github.com/elastic/go-elasticsearch/v8"
"github.com/elastic/go-elasticsearch/v8/esapi"
"go.mongodb.org/mongo-driver/bson/primitive"
"time"
)
// GetElasticsearchClient builds an Elasticsearch client for the given data
// source using a background context, i.e. without any deadline.
func GetElasticsearchClient(ds *models.DataSource) (c *elasticsearch.Client, err error) {
	ctx := context.Background()
	return getElasticsearchClient(ctx, ds)
}
// GetElasticsearchClientWithTimeout builds an Elasticsearch client for the
// given data source, giving up once timeout has elapsed.
func GetElasticsearchClientWithTimeout(ds *models.DataSource, timeout time.Duration) (c *elasticsearch.Client, err error) {
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()
	c, err = getElasticsearchClient(ctx, ds)
	return c, err
}
// getElasticsearchClient creates an Elasticsearch client for ds (substituting
// the default host and port for empty settings), verifies it with an Info
// request, and stops waiting when ctx is done.
//
// The client is built in its own goroutine, which hands its outcome back over a
// buffered channel. This fixes three problems in the previous version:
//   - when NewClient failed, the goroutine returned without signalling, so a
//     caller using a background context blocked forever;
//   - the goroutine wrote the named results, racing with the caller after a
//     timeout;
//   - the body of the Info response was never closed.
//
// On timeout the function returns (nil, ctx.Err()). If Info fails, the client
// is still returned together with the error, as before.
func getElasticsearchClient(ctx context.Context, ds *models.DataSource) (c *elasticsearch.Client, err error) {
	// normalize settings
	host := ds.Host
	port := ds.Port
	if ds.Host == "" {
		host = constants.DefaultHost
	}
	if ds.Port == "" {
		port = constants.DefaultElasticsearchPort
	}

	// es hosts
	addresses := []string{
		fmt.Sprintf("http://%s:%s", host, port),
	}

	// retry backoff
	rb := backoff.NewExponentialBackOff()

	// es client options
	cfg := elasticsearch.Config{
		Addresses: addresses,
		Username:  ds.Username,
		Password:  ds.Password,
		RetryBackoff: func(i int) time.Duration {
			// Start a fresh backoff sequence on the first retry of a request.
			if i == 1 {
				rb.Reset()
			}
			return rb.NextBackOff()
		},
	}

	// es client
	type clientResult struct {
		client *elasticsearch.Client
		err    error
	}
	// Buffered so the goroutine never blocks, even if nobody is waiting any more.
	resCh := make(chan clientResult, 1)
	go func() {
		client, buildErr := elasticsearch.NewClient(cfg)
		if buildErr != nil {
			resCh <- clientResult{err: buildErr}
			return
		}
		var res *esapi.Response
		res, infoErr := client.Info()
		fmt.Println(res)
		if res != nil && res.Body != nil {
			_ = res.Body.Close()
		}
		resCh <- clientResult{client: client, err: infoErr}
	}()

	// wait for done
	select {
	case <-ctx.Done():
		return nil, ctx.Err()
	case r := <-resCh:
		return r.client, r.err
	}
}
// GetElasticsearchQuery encodes query as a JSON Elasticsearch request body.
// A non-empty query becomes {"query": {"match": ...}}; an empty one becomes {}.
// Encoding errors are printed through the trace package and the (possibly
// partial) buffer is returned regardless.
func GetElasticsearchQuery(query generic.ListQuery) (buf *bytes.Buffer) {
	body := make(map[string]interface{})
	if len(query) != 0 {
		body["query"] = map[string]interface{}{
			"match": getElasticsearchQueryMatch(query),
		}
	}

	buf = new(bytes.Buffer)
	if encodeErr := json.NewEncoder(buf).Encode(body); encodeErr != nil {
		trace.PrintError(encodeErr)
	}
	return buf
}
// GetElasticsearchQueryWithOptions encodes query as a JSON Elasticsearch
// request body and adds paging from opts: "size" is opts.Limit and "from" is
// opts.Skip. A non-empty query is added as {"query": {"match": ...}}.
//
// A nil opts is tolerated and simply omits the paging fields; it previously
// caused a nil-pointer panic. Encoding errors are printed through the trace
// package and the (possibly partial) buffer is returned regardless.
func GetElasticsearchQueryWithOptions(query generic.ListQuery, opts *generic.ListOptions) (buf *bytes.Buffer) {
	q := map[string]interface{}{}
	if opts != nil {
		q["size"] = opts.Limit
		q["from"] = opts.Skip
		// TODO: sort
	}
	if len(query) > 0 {
		match := getElasticsearchQueryMatch(query)
		q["query"] = map[string]interface{}{
			"match": match,
		}
	}
	buf = &bytes.Buffer{}
	if err := json.NewEncoder(buf).Encode(q); err != nil {
		trace.PrintError(err)
	}
	return buf
}
// getElasticsearchQueryMatch converts the conditions of query into the body of
// an Elasticsearch "match" clause. ObjectID values are rendered as hex strings.
// Equality conditions map the key directly to the value; any other operator is
// emitted as {key: {op: value}}.
func getElasticsearchQueryMatch(query generic.ListQuery) (match map[string]interface{}) {
	match = make(map[string]interface{})
	for _, c := range query {
		if oid, isObjectID := c.Value.(primitive.ObjectID); isObjectID {
			c.Value = oid.Hex()
		}
		if c.Op == generic.OpEqual {
			match[c.Key] = c.Value
			continue
		}
		match[c.Key] = map[string]interface{}{
			c.Op: c.Value,
		}
	}
	return match
}

382
core/utils/file.go Normal file
View File

@@ -0,0 +1,382 @@
package utils
import (
"archive/zip"
"crypto/md5"
"encoding/hex"
"fmt"
"github.com/apex/log"
"github.com/crawlab-team/crawlab/core/constants"
"github.com/crawlab-team/crawlab/core/entity"
"io"
"io/fs"
"os"
"path"
"path/filepath"
"runtime/debug"
)
// OpenFile opens fileName for reading and writing, creating it if necessary.
// On failure it logs the error, prints the stack, and returns nil.
func OpenFile(fileName string) *os.File {
	file, err := os.OpenFile(fileName, os.O_CREATE|os.O_RDWR, os.ModePerm)
	if err == nil {
		return file
	}
	log.Errorf("create file error: %s, file_name: %s", err.Error(), fileName)
	debug.PrintStack()
	return nil
}
// Exists reports whether path can be stat-ed. When os.Stat fails, the result is
// os.IsExist(err), which is false for a missing path.
func Exists(path string) bool {
	_, err := os.Stat(path) // os.Stat fetches the file information
	return err == nil || os.IsExist(err)
}
// IsDir reports whether path exists and is a directory. Any os.Stat failure
// yields false.
func IsDir(path string) bool {
	info, err := os.Stat(path)
	return err == nil && info.IsDir()
}
// ListDir returns the fs.FileInfo of every entry in the directory path.
//
// The error is returned as a second value so that callers (for example
// *FileLogDriver.cleanup() in /task/log/file_driver.go) can detect failures by
// checking it against nil. Failures are also logged together with a stack
// trace.
func ListDir(path string) ([]fs.FileInfo, error) {
	list, err := os.ReadDir(path)
	if err != nil {
		// Log through a constant format: the error text may contain '%'.
		log.Errorf("%s", err.Error())
		debug.PrintStack()
		return nil, err
	}

	var res []fs.FileInfo
	for _, item := range list {
		info, err := item.Info()
		if err != nil {
			log.Errorf("%s", err.Error())
			debug.PrintStack()
			return nil, err
		}
		res = append(res, info)
	}
	return res, nil
}
// DeCompress extracts the zip archive located at srcFile.Name() into dstPath,
// creating dstPath when it does not exist.
//
// Entries that cannot be opened, or whose destination file cannot be created,
// are logged and skipped; any other failure aborts the extraction and is
// returned. Entries whose name would resolve outside dstPath (for example via
// "../") are rejected with an error instead of being written.
func DeCompress(srcFile *os.File, dstPath string) error {
	// Create the destination directory if it does not exist yet.
	if !Exists(dstPath) {
		if err := os.MkdirAll(dstPath, os.ModePerm); err != nil {
			debug.PrintStack()
			return err
		}
	}

	// Open the zip archive.
	zipFile, err := zip.OpenReader(srcFile.Name())
	if err != nil {
		log.Errorf("Unzip File Error : %s", err.Error())
		debug.PrintStack()
		return err
	}
	defer Close(zipFile)

	// Walk every file and directory inside the archive.
	for _, innerFile := range zipFile.File {
		if err := extractZipEntry(innerFile, dstPath); err != nil {
			return err
		}
	}
	return nil
}

// extractZipEntry writes a single archive entry below dstPath. It returns nil
// for entries that are skipped after a logged open/create failure.
func extractZipEntry(innerFile *zip.File, dstPath string) error {
	// Refuse entries that escape the destination directory (zip slip).
	targetPath := filepath.Join(dstPath, innerFile.Name)
	rel, err := filepath.Rel(dstPath, targetPath)
	if err != nil || rel == ".." || (len(rel) >= 3 && rel[:3] == ".."+string(filepath.Separator)) {
		return fmt.Errorf("illegal file path in zip archive: %s", innerFile.Name)
	}

	info := innerFile.FileInfo()

	// Directory entry: just create it.
	if info.IsDir() {
		if err := os.MkdirAll(targetPath, os.ModeDir|os.ModePerm); err != nil {
			log.Errorf("Unzip File Error : %s", err.Error())
			debug.PrintStack()
			return err
		}
		return nil
	}

	// Make sure the parent directory of the file exists.
	dirPath := filepath.Dir(targetPath)
	if !Exists(dirPath) {
		if err := os.MkdirAll(dirPath, os.ModeDir|os.ModePerm); err != nil {
			log.Errorf("Unzip File Error : %s", err.Error())
			debug.PrintStack()
			return err
		}
	}

	// Open the archived file.
	entry, err := innerFile.Open()
	if err != nil {
		log.Errorf("Unzip File Error : %s", err.Error())
		debug.PrintStack()
		return nil
	}

	// Create the destination file.
	newFile, err := os.OpenFile(targetPath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, info.Mode())
	if err != nil {
		_ = entry.Close()
		log.Errorf("Unzip File Error : %s", err.Error())
		debug.PrintStack()
		return nil
	}

	// Copy the archived content, releasing both handles on failure.
	if _, err := io.Copy(newFile, entry); err != nil {
		_ = entry.Close()
		_ = newFile.Close()
		debug.PrintStack()
		return err
	}

	// Close the archived file.
	if err := entry.Close(); err != nil {
		_ = newFile.Close()
		debug.PrintStack()
		return err
	}

	// Close the destination file.
	if err := newFile.Close(); err != nil {
		debug.PrintStack()
		return err
	}
	return nil
}
// Compress writes the given files into a new zip archive.
// files may contain files or directories from different parent directories.
// dest is the path of the archive to create.
//
// NOTE(review): the error from os.Create is discarded, so an unusable dest does
// not surface as an error here — confirm whether callers rely on that before
// changing it.
func Compress(files []*os.File, dest string) error {
	out, _ := os.Create(dest)
	defer Close(out)

	zw := zip.NewWriter(out)
	defer Close(zw)

	for i := range files {
		if err := _Compress(files[i], "", zw); err != nil {
			return err
		}
	}
	return nil
}
// _Compress adds file to the zip writer zw under the entry prefix.
//
// A regular file is written as "<prefix>/<name>" and is closed before
// returning, including on error. A directory is traversed recursively with its
// name appended to prefix; every child handle opened during the traversal is
// closed here, whereas directory children and failed children used to leak.
// The directory handle passed in by the caller is left open.
func _Compress(file *os.File, prefix string, zw *zip.Writer) error {
	info, err := file.Stat()
	if err != nil {
		debug.PrintStack()
		return err
	}

	if !info.IsDir() {
		// Regular file: this function owns the handle from here on.
		defer Close(file)

		header, err := zip.FileInfoHeader(info)
		if err != nil {
			debug.PrintStack()
			return err
		}
		header.Name = prefix + "/" + header.Name
		writer, err := zw.CreateHeader(header)
		if err != nil {
			debug.PrintStack()
			return err
		}
		if _, err = io.Copy(writer, file); err != nil {
			debug.PrintStack()
			return err
		}
		return nil
	}

	prefix = prefix + "/" + info.Name()
	fileInfos, err := file.Readdir(-1)
	if err != nil {
		debug.PrintStack()
		return err
	}
	for _, fi := range fileInfos {
		f, err := os.Open(filepath.Join(file.Name(), fi.Name()))
		if err != nil {
			debug.PrintStack()
			return err
		}
		err = _Compress(f, prefix, zw)
		// Regular files are already closed by the recursive call, in which
		// case this reports an ignorable "already closed" error; directory
		// handles are released only here.
		_ = f.Close()
		if err != nil {
			debug.PrintStack()
			return err
		}
	}
	return nil
}
// TrimFileData returns nil when data equals the constants.EmptyFileData
// placeholder, and data unchanged otherwise.
func TrimFileData(data []byte) (res []byte) {
	if string(data) != constants.EmptyFileData {
		return data
	}
	return res
}
// ZipDirectory writes every regular file below dir into a new zip archive at
// zipfile. Entry names are relative to the parent of dir, so they start with
// dir's base name, and always use forward slashes as the zip format requires.
//
// Errors from finalizing the archive or closing the output file are returned
// (unless an earlier error already occurred); previously they were dropped,
// which could report success for a truncated archive.
func ZipDirectory(dir, zipfile string) (err error) {
	out, err := os.Create(zipfile)
	if err != nil {
		return err
	}
	defer func() {
		if closeErr := out.Close(); err == nil {
			err = closeErr
		}
	}()

	zipWriter := zip.NewWriter(out)
	// Runs before the file is closed (defers are LIFO) and writes the central
	// directory, so its error matters.
	defer func() {
		if closeErr := zipWriter.Close(); err == nil {
			err = closeErr
		}
	}()

	baseDir := filepath.Dir(dir)

	return filepath.Walk(dir, func(path string, info os.FileInfo, walkErr error) error {
		if walkErr != nil {
			return walkErr
		}
		if info.IsDir() {
			return nil
		}

		relPath, err := filepath.Rel(baseDir, path)
		if err != nil {
			return err
		}

		entry, err := zipWriter.Create(filepath.ToSlash(relPath))
		if err != nil {
			return err
		}

		file, err := os.Open(path)
		if err != nil {
			return err
		}
		defer file.Close()

		_, err = io.Copy(entry, file)
		return err
	})
}
// CopyFile copies the single file src to dst, creating or truncating dst, and
// then applies src's file mode to dst.
//
// An error from closing dst is returned (unless an earlier error already
// occurred); it was previously ignored even though a failed close can mean the
// copied data never reached the disk.
func CopyFile(src, dst string) (err error) {
	srcFd, err := os.Open(src)
	if err != nil {
		return err
	}
	defer srcFd.Close()

	srcInfo, err := srcFd.Stat()
	if err != nil {
		return err
	}

	dstFd, err := os.Create(dst)
	if err != nil {
		return err
	}
	defer func() {
		if closeErr := dstFd.Close(); err == nil {
			err = closeErr
		}
	}()

	if _, err = io.Copy(dstFd, srcFd); err != nil {
		return err
	}
	return os.Chmod(dst, srcInfo.Mode())
}
// CopyDir recursively copies the directory src to dst, creating dst with src's
// mode when needed.
//
// The copy is best-effort: a failure on one entry does not stop the remaining
// entries from being copied. The first failure encountered is returned, wrapped
// with the paths involved. Previously such failures were only printed to
// standard output and the function returned nil.
func CopyDir(src string, dst string) error {
	srcInfo, err := os.Stat(src)
	if err != nil {
		return err
	}

	if err = os.MkdirAll(dst, srcInfo.Mode()); err != nil {
		return err
	}

	fds, err := os.ReadDir(src)
	if err != nil {
		return err
	}

	var firstErr error
	for _, fd := range fds {
		srcfp := path.Join(src, fd.Name())
		dstfp := path.Join(dst, fd.Name())

		copyEntry := CopyFile
		if fd.IsDir() {
			copyEntry = CopyDir
		}
		if err := copyEntry(srcfp, dstfp); err != nil && firstErr == nil {
			firstErr = fmt.Errorf("copy %s to %s: %w", srcfp, dstfp, err)
		}
	}
	return firstErr
}
// GetFileHash returns the lower-case hexadecimal MD5 digest of the content of
// the file at filePath.
func GetFileHash(filePath string) (res string, err error) {
	f, err := os.Open(filePath)
	if err != nil {
		return "", err
	}
	defer f.Close()

	hasher := md5.New()
	_, err = io.Copy(hasher, f)
	if err != nil {
		return "", err
	}
	res = hex.EncodeToString(hasher.Sum(nil))
	return res, nil
}
// ScanDirectory walks dir and returns one entity.FsFileInfo per regular file,
// keyed by the file's path relative to dir. Each entry records the name, the
// relative and full paths, extension, size, modification time, mode, and the
// file hash from GetFileHash. Directories are not included. The first error
// encountered aborts the walk and is returned with a nil map.
func ScanDirectory(dir string) (res map[string]entity.FsFileInfo, err error) {
	found := map[string]entity.FsFileInfo{}

	visit := func(fullPath string, info os.FileInfo, walkErr error) error {
		if walkErr != nil {
			return walkErr
		}
		if info.IsDir() {
			return nil
		}

		digest, hashErr := GetFileHash(fullPath)
		if hashErr != nil {
			return hashErr
		}

		rel, relErr := filepath.Rel(dir, fullPath)
		if relErr != nil {
			return relErr
		}

		entry := entity.FsFileInfo{
			Name:      info.Name(),
			Path:      rel,
			FullPath:  fullPath,
			Extension: filepath.Ext(fullPath),
			FileSize:  info.Size(),
			ModTime:   info.ModTime(),
			Mode:      info.Mode(),
			Hash:      digest,
		}
		found[rel] = entry
		return nil
	}

	if walkErr := filepath.Walk(dir, visit); walkErr != nil {
		return nil, walkErr
	}
	return found, nil
}

129
core/utils/file_test.go Normal file
View File

@@ -0,0 +1,129 @@
package utils
import (
"archive/zip"
. "github.com/smartystreets/goconvey/convey"
"io"
"log"
"os"
"runtime/debug"
"testing"
)
// TestExists checks Exists against a path expected to be present ("../config")
// and one expected to be absent ("test").
func TestExists(t *testing.T) {
	var presentPath = "../config"
	var absentPath = "test"

	Convey("Test path or file is Exists or not", t, func() {
		found := Exists(presentPath)
		Convey("The result should be true", func() {
			So(found, ShouldEqual, true)
		})
		missing := Exists(absentPath)
		Convey("The result should be false", func() {
			So(missing, ShouldEqual, false)
		})
	})
}
// TestIsDir checks IsDir against a directory, a regular file, and a missing
// path.
func TestIsDir(t *testing.T) {
	var dirPath = "../config"
	var filePath = "../config/config.go"
	var missingPath = "test"

	Convey("Test path is folder or not", t, func() {
		onDir := IsDir(dirPath)
		So(onDir, ShouldEqual, true)
		onFile := IsDir(filePath)
		So(onFile, ShouldEqual, false)
		onMissing := IsDir(missingPath)
		So(onMissing, ShouldEqual, false)
	})
}
// TestCompress zips an empty scratch directory into a separate archive file.
//
// The archive used to be written to the path of the source directory itself,
// which cannot be created as a file, so the test never produced a real archive.
// Setup failures now abort the test, and the scratch directory, archive, and
// open handle are always released.
func TestCompress(t *testing.T) {
	var pathString = "testCompress"
	var disPath = "testCompress.zip"

	if err := os.Mkdir(pathString, os.ModePerm); err != nil {
		t.Fatalf("create testCompress failed: %v", err)
	}
	defer func() {
		_ = os.RemoveAll(pathString)
		_ = os.Remove(disPath)
	}()

	file, err := os.Open(pathString)
	if err != nil {
		t.Fatalf("open source path failed: %v", err)
	}
	defer file.Close()

	files := []*os.File{file}
	Convey("Verify dispath is valid path", t, func() {
		er := Compress(files, disPath)
		Convey("err should be nil", func() {
			So(er, ShouldEqual, nil)
		})
	})
}
// Zip creates (or truncates) zipFile and stores every file named in fileList
// in it, in order.
//
// Entry headers are derived from each file's os.FileInfo. The first error
// encountered is returned; that includes a failure to create zipFile and a
// failure to finalize the archive when the writers are closed.
func Zip(zipFile string, fileList []string) (err error) {
	// Create the zip archive file.
	fw, err := os.Create(zipFile)
	if err != nil {
		return err
	}
	defer func() {
		if cerr := fw.Close(); cerr != nil && err == nil {
			err = cerr
		}
	}()

	// Instantiate a new zip.Writer. It is closed before fw (deferred calls run
	// in reverse order), which writes out the central directory.
	zw := zip.NewWriter(fw)
	defer func() {
		if cerr := zw.Close(); cerr != nil && err == nil {
			err = cerr
		}
	}()

	for _, fileName := range fileList {
		if err = zipAddFile(zw, fileName); err != nil {
			return err
		}
	}
	return nil
}

// zipAddFile appends the file at fileName to zw and closes the source file
// before returning, so handles do not accumulate across a long file list.
func zipAddFile(zw *zip.Writer, fileName string) error {
	fr, err := os.Open(fileName)
	if err != nil {
		return err
	}
	defer func() {
		// The source was only read, so a close failure is logged, not returned.
		if cerr := fr.Close(); cerr != nil {
			log.Printf("close %s: %v", fileName, cerr)
		}
	}()

	fi, err := fr.Stat()
	if err != nil {
		return err
	}

	// Write the file's header information.
	fh, err := zip.FileInfoHeader(fi)
	if err != nil {
		return err
	}
	w, err := zw.CreateHeader(fh)
	if err != nil {
		return err
	}

	// Write the file's content.
	_, err = io.Copy(w, fr)
	return err
}
// TestDeCompress builds an empty zip archive with Zip, opens it and passes
// the handle to DeCompress with a scratch directory as destination,
// expecting a nil error.
//
// Setup failures abort the test immediately, and all cleanup is deferred so
// the scratch directory, the archive and the open handle are released even
// when the test stops early.
func TestDeCompress(t *testing.T) {
	const (
		dstDir  = "testDeCompress"
		zipName = "demo.zip"
	)

	if err := os.Mkdir(dstDir, os.ModePerm); err != nil {
		t.Fatal(err)
	}
	defer func() { _ = os.RemoveAll(dstDir) }()

	// Registered before Zip runs so a partially written archive is removed too.
	defer func() { _ = os.Remove(zipName) }()
	if err := Zip(zipName, []string{}); err != nil {
		t.Fatalf("create zip file failed: %v", err)
	}

	tmpFile, err := os.OpenFile(zipName, os.O_RDONLY, 0777)
	if err != nil {
		debug.PrintStack()
		t.Fatalf("open demo.zip failed: %v", err)
	}
	// Best-effort close: DeCompress may already have closed the handle.
	defer func() { _ = tmpFile.Close() }()

	var dstPath = "./testDeCompress"
	Convey("Test DeCopmress func", t, func() {
		err := DeCompress(tmpFile, dstPath)
		So(err, ShouldEqual, nil)
	})
}

49
core/utils/filter.go Normal file
View File

@@ -0,0 +1,49 @@
package utils
import (
"github.com/crawlab-team/crawlab/core/constants"
"github.com/crawlab-team/crawlab/core/errors"
"github.com/crawlab-team/crawlab/core/interfaces"
"go.mongodb.org/mongo-driver/bson"
)
// FilterToQuery translates an interfaces.Filter into a MongoDB query document.
//
// A nil filter (or one whose IsNil reports true) yields (nil, nil). Every
// condition becomes one entry keyed by the condition's key. Conditions with
// constants.FilterOpNotSet are skipped, and an unrecognized operator aborts
// the translation with errors.ErrorFilterInvalidOperation.
//
// "Contains", "regex" and "search" are all emitted as a case-insensitive
// $regex match. "Not contains" is emitted with the same case-insensitive
// option so that it is the exact complement of "contains".
//
// NOTE(review): conditions sharing a key overwrite one another, so only the
// last condition for a given key ends up in the query.
func FilterToQuery(f interfaces.Filter) (q bson.M, err error) {
	if f == nil || f.IsNil() {
		return nil, nil
	}

	q = bson.M{}
	for _, cond := range f.GetConditions() {
		key := cond.GetKey()
		op := cond.GetOp()
		value := cond.GetValue()

		switch op {
		case constants.FilterOpNotSet:
			// do nothing
		case constants.FilterOpEqual:
			q[key] = value
		case constants.FilterOpNotEqual:
			q[key] = bson.M{"$ne": value}
		case constants.FilterOpContains, constants.FilterOpRegex, constants.FilterOpSearch:
			q[key] = bson.M{"$regex": value, "$options": "i"}
		case constants.FilterOpNotContains:
			// Same options as the positive match above; without "i" a value could
			// satisfy both "contains" and "not contains".
			q[key] = bson.M{"$not": bson.M{"$regex": value, "$options": "i"}}
		case constants.FilterOpIn:
			q[key] = bson.M{"$in": value}
		case constants.FilterOpNotIn:
			q[key] = bson.M{"$nin": value}
		case constants.FilterOpGreaterThan:
			q[key] = bson.M{"$gt": value}
		case constants.FilterOpGreaterThanEqual:
			q[key] = bson.M{"$gte": value}
		case constants.FilterOpLessThan:
			q[key] = bson.M{"$lt": value}
		case constants.FilterOpLessThanEqual:
			q[key] = bson.M{"$lte": value}
		default:
			return nil, errors.ErrorFilterInvalidOperation
		}
	}
	return q, nil
}

36
core/utils/git.go Normal file
View File

@@ -0,0 +1,36 @@
package utils
import (
vcs "github.com/crawlab-team/crawlab-vcs"
"github.com/crawlab-team/crawlab/core/constants"
"github.com/crawlab-team/crawlab/core/interfaces"
"github.com/crawlab-team/crawlab/core/models/models"
)
// InitGitClientAuth copies the authentication settings of g onto gitClient.
//
// For HTTP auth the username and password are set. For SSH auth the username
// is set and the value returned by g.GetPassword() is installed as the
// private key. Any other auth type leaves gitClient untouched.
func InitGitClientAuth(g interfaces.Git, gitClient *vcs.GitClient) {
	authType := g.GetAuthType()
	if authType == constants.GitAuthTypeHttp {
		gitClient.SetAuthType(vcs.GitAuthTypeHTTP)
		gitClient.SetUsername(g.GetUsername())
		gitClient.SetPassword(g.GetPassword())
	} else if authType == constants.GitAuthTypeSsh {
		gitClient.SetAuthType(vcs.GitAuthTypeSSH)
		gitClient.SetUsername(g.GetUsername())
		gitClient.SetPrivateKey(g.GetPassword())
	}
}
// InitGitClientAuthV2 copies the authentication settings of a models.GitV2
// record onto gitClient.
//
// For HTTP auth the username and password are set. For SSH auth the username
// is set and g.Password is installed as the private key. Any other auth type
// leaves gitClient untouched.
func InitGitClientAuthV2(g *models.GitV2, gitClient *vcs.GitClient) {
	authType := g.AuthType
	if authType == constants.GitAuthTypeHttp {
		gitClient.SetAuthType(vcs.GitAuthTypeHTTP)
		gitClient.SetUsername(g.Username)
		gitClient.SetPassword(g.Password)
	} else if authType == constants.GitAuthTypeSsh {
		gitClient.SetAuthType(vcs.GitAuthTypeSSH)
		gitClient.SetUsername(g.Username)
		gitClient.SetPrivateKey(g.Password)
	}
}

36
core/utils/helpers.go Normal file
View File

@@ -0,0 +1,36 @@
package utils
import (
"github.com/crawlab-team/go-trace"
"io"
"reflect"
"unsafe"
)
// BytesToString reinterprets b as a string without copying the bytes.
//
// The result shares b's backing memory, so the caller must not modify b
// while the returned string is still in use.
func BytesToString(b []byte) string {
	header := unsafe.Pointer(&b)
	return *(*string)(header)
}
// Close closes c and reports a close failure through trace.PrintError
// instead of returning it, which makes it convenient to use with defer.
func Close(c io.Closer) {
	if err := c.Close(); err != nil {
		trace.PrintError(err)
	}
}
// Contains reports whether val is an element of array.
//
// array may be any slice or array; elements are compared with
// reflect.DeepEqual. A nil array, or a value that is neither a slice nor an
// array, yields false.
func Contains(array interface{}, val interface{}) (fla bool) {
	// reflect.TypeOf(nil) is a nil Type, so a nil argument must be rejected
	// before asking for its Kind.
	if array == nil {
		return false
	}

	s := reflect.ValueOf(array)
	switch s.Kind() {
	case reflect.Slice, reflect.Array:
		for i := 0; i < s.Len(); i++ {
			if reflect.DeepEqual(val, s.Index(i).Interface()) {
				return true
			}
		}
	}
	return false
}

32
core/utils/http.go Normal file
View File

@@ -0,0 +1,32 @@
package utils
import (
"github.com/crawlab-team/crawlab/core/constants"
"github.com/crawlab-team/crawlab/core/entity"
"github.com/crawlab-team/go-trace"
"github.com/gin-gonic/gin"
"net/http"
)
// handleError aborts the gin request with statusCode and a JSON body built
// from err. When print is true the error is also passed to trace.PrintError.
//
// NOTE(review): the body's Status field is set to HttpResponseStatusOk even
// though this is an error response — confirm that this is intended.
func handleError(statusCode int, c *gin.Context, err error, print bool) {
	if print {
		trace.PrintError(err)
	}
	body := entity.Response{
		Status:  constants.HttpResponseStatusOk,
		Message: constants.HttpResponseMessageError,
		Error:   err.Error(),
	}
	c.AbortWithStatusJSON(statusCode, body)
}
// HandleError aborts the request with statusCode and a JSON error body, and
// always prints the error trace.
func HandleError(statusCode int, c *gin.Context, err error) {
	const printTrace = true
	handleError(statusCode, c, err, printTrace)
}
// HandleErrorUnauthorized aborts the request with HTTP 401 and a JSON error
// body built from err.
func HandleErrorUnauthorized(c *gin.Context, err error) {
	const status = http.StatusUnauthorized
	HandleError(status, c, err)
}
// HandleErrorInternalServerError aborts the request with HTTP 500 and a JSON
// error body built from err.
func HandleErrorInternalServerError(c *gin.Context, err error) {
	const status = http.StatusInternalServerError
	HandleError(status, c, err)
}

30
core/utils/init.go Normal file
View File

@@ -0,0 +1,30 @@
package utils
import (
"github.com/crawlab-team/crawlab/core/interfaces"
"sync"
)
// moduleInitializedMap records, per module id, whether InitModule has already
// run that module's initializer successfully.
var moduleInitializedMap sync.Map

// InitModule runs fn for the module identified by id unless an earlier call
// with the same id already succeeded, in which case it returns nil without
// calling fn. A module is only marked as initialized when fn returns nil, so
// a failed initialization is attempted again on the next call.
//
// NOTE(review): the lookup and the store are separate steps, so two
// concurrent first calls for the same id can both run fn.
func InitModule(id interfaces.ModuleId, fn func() error) (err error) {
	if res, ok := moduleInitializedMap.Load(id); ok {
		if initialized, _ := res.(bool); initialized {
			return nil
		}
	}

	if err = fn(); err != nil {
		return err
	}

	moduleInitializedMap.Store(id, true)
	return nil
}
// ForceInitModule runs fn unconditionally and returns its error. It neither
// consults nor updates the record of initialized modules.
func ForceInitModule(fn func() error) (err error) {
	err = fn()
	return err
}

12
core/utils/json.go Normal file
View File

@@ -0,0 +1,12 @@
package utils
import "encoding/json"
// JsonToBytes returns d unchanged when it already is a []byte; any other
// value is encoded with json.Marshal.
func JsonToBytes(d interface{}) (bytes []byte, err error) {
	if raw, ok := d.([]byte); ok {
		return raw, nil
	}
	return json.Marshal(d)
}

41
core/utils/kafka.go Normal file
View File

@@ -0,0 +1,41 @@
package utils
import (
"context"
"fmt"
"github.com/crawlab-team/crawlab/core/constants"
"github.com/crawlab-team/crawlab/core/models/models"
"github.com/segmentio/kafka-go"
"time"
)
// GetKafkaConnection dials the Kafka leader described by ds using a
// background context, i.e. without a deadline.
func GetKafkaConnection(ds *models.DataSource) (c *kafka.Conn, err error) {
	ctx := context.Background()
	return getKafkaConnection(ctx, ds)
}
// GetKafkaConnectionWithTimeout dials the Kafka leader described by ds with
// a context that is cancelled after timeout.
func GetKafkaConnectionWithTimeout(ds *models.DataSource, timeout time.Duration) (c *kafka.Conn, err error) {
	dialCtx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()
	c, err = getKafkaConnection(dialCtx, ds)
	return c, err
}
// getKafkaConnection dials the leader of partition 0 of the topic named by
// ds.Database over TCP.
//
// An empty ds.Host or ds.Port falls back to constants.DefaultHost and
// constants.DefaultKafkaPort; ds itself is not modified.
func getKafkaConnection(ctx context.Context, ds *models.DataSource) (c *kafka.Conn, err error) {
	// normalize settings
	host, port := ds.Host, ds.Port
	if host == "" {
		host = constants.DefaultHost
	}
	if port == "" {
		port = constants.DefaultKafkaPort
	}

	// kafka connection address
	const (
		network   = "tcp"
		partition = 0 // TODO: parameterize
	)
	address := fmt.Sprintf("%s:%s", host, port)

	// kafka connection
	return kafka.DialLeader(ctx, network, address, ds.Database, partition)
}

94
core/utils/mongo.go Normal file
View File

@@ -0,0 +1,94 @@
package utils
import (
"context"
"github.com/crawlab-team/crawlab-db/generic"
"github.com/crawlab-team/crawlab-db/mongo"
"github.com/crawlab-team/crawlab/core/constants"
"github.com/crawlab-team/crawlab/core/models/models"
"go.mongodb.org/mongo-driver/bson"
mongo2 "go.mongodb.org/mongo-driver/mongo"
"time"
)
// GetMongoQuery converts a generic.ListQuery into a bson.M filter.
//
// Equality conditions map the key directly to the value. Every other
// condition maps the key to a sub-document keyed by the condition's operator.
// A later condition on the same key replaces an earlier one.
func GetMongoQuery(query generic.ListQuery) (res bson.M) {
	res = make(bson.M)
	for _, cond := range query {
		if cond.Op == generic.OpEqual {
			res[cond.Key] = cond.Value
			continue
		}
		res[cond.Key] = bson.M{cond.Op: cond.Value}
	}
	return res
}
// GetMongoOpts converts generic list options into mongo.FindOptions.
//
// Skip and Limit are copied as-is. Each sort entry becomes a bson.E whose
// value is 1 for ascending and -1 for descending; any other direction is
// treated as ascending.
func GetMongoOpts(opts *generic.ListOptions) (res *mongo.FindOptions) {
	var sort bson.D
	for _, s := range opts.Sort {
		order := 1
		switch s.Direction {
		case generic.SortDirectionAsc:
			order = 1
		case generic.SortDirectionDesc:
			order = -1
		}
		sort = append(sort, bson.E{Key: s.Key, Value: order})
	}

	res = &mongo.FindOptions{
		Skip:  opts.Skip,
		Limit: opts.Limit,
		Sort:  sort,
	}
	return res
}
// GetMongoClient builds a MongoDB client for ds using a background context,
// i.e. without a deadline.
func GetMongoClient(ds *models.DataSource) (c *mongo2.Client, err error) {
	ctx := context.Background()
	return getMongoClient(ctx, ds)
}
// GetMongoClientWithTimeout builds a MongoDB client for ds with a context
// that is cancelled after timeout.
func GetMongoClientWithTimeout(ds *models.DataSource, timeout time.Duration) (c *mongo2.Client, err error) {
	connCtx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()
	c, err = getMongoClient(connCtx, ds)
	return c, err
}
// getMongoClient assembles client options from ds and asks the mongo package
// for a matching client.
//
// An empty ds.Host or ds.Port is replaced with constants.DefaultHost and
// constants.DefaultMongoPort. The optional "auth_source" and
// "auth_mechanism" entries of ds.Extra are forwarded when present.
//
// NOTE(review): the defaults are written back into ds, so the caller's value
// is modified; the other data-source helpers use local copies instead.
func getMongoClient(ctx context.Context, ds *models.DataSource) (c *mongo2.Client, err error) {
	// normalize settings
	if ds.Host == "" {
		ds.Host = constants.DefaultHost
	}
	if ds.Port == "" {
		ds.Port = constants.DefaultMongoPort
	}

	// options
	opts := []mongo.ClientOption{
		mongo.WithContext(ctx),
		mongo.WithUri(ds.Url),
		mongo.WithHost(ds.Host),
		mongo.WithPort(ds.Port),
		mongo.WithDb(ds.Database),
		mongo.WithUsername(ds.Username),
		mongo.WithPassword(ds.Password),
		mongo.WithHosts(ds.Hosts),
	}

	// extra (a lookup on a nil map simply reports "not present")
	if authSource, ok := ds.Extra["auth_source"]; ok {
		opts = append(opts, mongo.WithAuthSource(authSource))
	}
	if authMechanism, ok := ds.Extra["auth_mechanism"]; ok {
		opts = append(opts, mongo.WithAuthMechanism(authMechanism))
	}

	// client
	return mongo.GetMongoClient(opts...)
}

60
core/utils/mssql.go Normal file
View File

@@ -0,0 +1,60 @@
package utils
import (
"context"
"fmt"
"github.com/crawlab-team/crawlab/core/constants"
"github.com/crawlab-team/crawlab/core/models/models"
"github.com/upper/db/v4"
"github.com/upper/db/v4/adapter/mssql"
"time"
)
// GetMssqlSession opens an MSSQL session for ds using a background context,
// i.e. it waits for the connection attempt without a deadline.
func GetMssqlSession(ds *models.DataSource) (s db.Session, err error) {
	ctx := context.Background()
	return getMssqlSession(ctx, ds)
}
// GetMssqlSessionWithTimeout opens an MSSQL session for ds with a context
// that is cancelled after timeout.
func GetMssqlSessionWithTimeout(ds *models.DataSource, timeout time.Duration) (s db.Session, err error) {
	openCtx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()
	s, err = getMssqlSession(openCtx, ds)
	return s, err
}
// getMssqlSession opens an MSSQL session for ds, giving up when ctx is done.
//
// An empty ds.Host or ds.Port falls back to constants.DefaultHost and
// constants.DefaultMssqlPort; ds itself is not modified.
//
// mssql.Open is called without a context, so it runs in its own goroutine
// and hands its result back over a channel rather than writing to the
// function's results, which would race with the return on timeout. If ctx
// finishes first the function returns (nil, ctx.Err()), and a session that
// is opened afterwards is closed in the background instead of being leaked.
func getMssqlSession(ctx context.Context, ds *models.DataSource) (s db.Session, err error) {
	// normalize settings
	host := ds.Host
	port := ds.Port
	if host == "" {
		host = constants.DefaultHost
	}
	if port == "" {
		port = constants.DefaultMssqlPort
	}

	// connect settings
	settings := mssql.ConnectionURL{
		User:     ds.Username,
		Password: ds.Password,
		Database: ds.Database,
		Host:     fmt.Sprintf("%s:%s", host, port),
		Options:  nil,
	}

	// session
	type openResult struct {
		sess db.Session
		err  error
	}
	// Buffered so the opener can always deliver its result and exit, even
	// when nobody is waiting any more.
	resCh := make(chan openResult, 1)
	go func() {
		sess, openErr := mssql.Open(settings)
		resCh <- openResult{sess: sess, err: openErr}
	}()

	// wait for done
	select {
	case <-ctx.Done():
		// Release the session if the pending Open succeeds after all.
		go func() {
			if r := <-resCh; r.err == nil && r.sess != nil {
				_ = r.sess.Close()
			}
		}()
		return nil, ctx.Err()
	case r := <-resCh:
		return r.sess, r.err
	}
}

60
core/utils/mysql.go Normal file
View File

@@ -0,0 +1,60 @@
package utils
import (
"context"
"fmt"
"github.com/crawlab-team/crawlab/core/constants"
"github.com/crawlab-team/crawlab/core/models/models"
"github.com/upper/db/v4"
"github.com/upper/db/v4/adapter/mysql"
"time"
)
// GetMysqlSession opens a MySQL session for ds using a background context,
// i.e. it waits for the connection attempt without a deadline.
func GetMysqlSession(ds *models.DataSource) (s db.Session, err error) {
	ctx := context.Background()
	return getMysqlSession(ctx, ds)
}
// GetMysqlSessionWithTimeout opens a MySQL session for ds with a context
// that is cancelled after timeout.
func GetMysqlSessionWithTimeout(ds *models.DataSource, timeout time.Duration) (s db.Session, err error) {
	openCtx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()
	s, err = getMysqlSession(openCtx, ds)
	return s, err
}
// getMysqlSession opens a MySQL session for ds, giving up when ctx is done.
//
// An empty ds.Host or ds.Port falls back to constants.DefaultHost and
// constants.DefaultMysqlPort; ds itself is not modified.
//
// mysql.Open is called without a context, so it runs in its own goroutine
// and hands its result back over a channel rather than writing to the
// function's results, which would race with the return on timeout. If ctx
// finishes first the function returns (nil, ctx.Err()), and a session that
// is opened afterwards is closed in the background instead of being leaked.
func getMysqlSession(ctx context.Context, ds *models.DataSource) (s db.Session, err error) {
	// normalize settings
	host := ds.Host
	port := ds.Port
	if host == "" {
		host = constants.DefaultHost
	}
	if port == "" {
		port = constants.DefaultMysqlPort
	}

	// connect settings
	settings := mysql.ConnectionURL{
		User:     ds.Username,
		Password: ds.Password,
		Database: ds.Database,
		Host:     fmt.Sprintf("%s:%s", host, port),
		Options:  nil,
	}

	// session
	type openResult struct {
		sess db.Session
		err  error
	}
	// Buffered so the opener can always deliver its result and exit, even
	// when nobody is waiting any more.
	resCh := make(chan openResult, 1)
	go func() {
		sess, openErr := mysql.Open(settings)
		resCh <- openResult{sess: sess, err: openErr}
	}()

	// wait for done
	select {
	case <-ctx.Done():
		// Release the session if the pending Open succeeds after all.
		go func() {
			if r := <-resCh; r.err == nil && r.sess != nil {
				_ = r.sess.Close()
			}
		}()
		return nil, ctx.Err()
	case r := <-resCh:
		return r.sess, r.err
	}
}

13
core/utils/node.go Normal file
View File

@@ -0,0 +1,13 @@
package utils
// IsMaster reports whether the "node.master" setting is true, defaulting to
// false as passed to EnvIsTrue.
func IsMaster() bool {
	master := EnvIsTrue("node.master", false)
	return master
}
// GetNodeType returns "master" when IsMaster reports true and "worker"
// otherwise.
func GetNodeType() string {
	if !IsMaster() {
		return "worker"
	}
	return "master"
}

13
core/utils/os.go Normal file
View File

@@ -0,0 +1,13 @@
package utils
import (
"os"
"os/signal"
"syscall"
)
// DefaultWait blocks the calling goroutine until the process receives SIGINT
// or SIGTERM.
func DefaultWait() {
	// Buffer of one so a signal arriving before the receive is not dropped.
	signals := make(chan os.Signal, 1)
	signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)
	<-signals
}

60
core/utils/postgresql.go Normal file
View File

@@ -0,0 +1,60 @@
package utils
import (
"context"
"fmt"
"github.com/crawlab-team/crawlab/core/constants"
"github.com/crawlab-team/crawlab/core/models/models"
"github.com/upper/db/v4"
"github.com/upper/db/v4/adapter/postgresql"
"time"
)
// GetPostgresqlSession opens a PostgreSQL session for ds using a background
// context, i.e. it waits for the connection attempt without a deadline.
func GetPostgresqlSession(ds *models.DataSource) (s db.Session, err error) {
	ctx := context.Background()
	return getPostgresqlSession(ctx, ds)
}
// GetPostgresqlSessionWithTimeout opens a PostgreSQL session for ds with a
// context that is cancelled after timeout.
func GetPostgresqlSessionWithTimeout(ds *models.DataSource, timeout time.Duration) (s db.Session, err error) {
	openCtx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()
	s, err = getPostgresqlSession(openCtx, ds)
	return s, err
}
// getPostgresqlSession opens a PostgreSQL session for ds, giving up when ctx
// is done.
//
// An empty ds.Host or ds.Port falls back to constants.DefaultHost and
// constants.DefaultPostgresqlPort; ds itself is not modified.
//
// postgresql.Open is called without a context, so it runs in its own
// goroutine and hands its result back over a channel rather than writing to
// the function's results, which would race with the return on timeout. If
// ctx finishes first the function returns (nil, ctx.Err()), and a session
// that is opened afterwards is closed in the background instead of being
// leaked.
func getPostgresqlSession(ctx context.Context, ds *models.DataSource) (s db.Session, err error) {
	// normalize settings
	host := ds.Host
	port := ds.Port
	if host == "" {
		host = constants.DefaultHost
	}
	if port == "" {
		port = constants.DefaultPostgresqlPort
	}

	// connect settings
	settings := postgresql.ConnectionURL{
		User:     ds.Username,
		Password: ds.Password,
		Database: ds.Database,
		Host:     fmt.Sprintf("%s:%s", host, port),
		Options:  nil,
	}

	// session
	type openResult struct {
		sess db.Session
		err  error
	}
	// Buffered so the opener can always deliver its result and exit, even
	// when nobody is waiting any more.
	resCh := make(chan openResult, 1)
	go func() {
		sess, openErr := postgresql.Open(settings)
		resCh <- openResult{sess: sess, err: openErr}
	}()

	// wait for done
	select {
	case <-ctx.Done():
		// Release the session if the pending Open succeeds after all.
		go func() {
			if r := <-resCh; r.err == nil && r.sess != nil {
				_ = r.sess.Close()
			}
		}()
		return nil, ctx.Err()
	case r := <-resCh:
		return r.sess, r.err
	}
}

23
core/utils/result.go Normal file
View File

@@ -0,0 +1,23 @@
package utils
import (
"encoding/json"
"github.com/crawlab-team/crawlab/core/interfaces"
)
// GetResultHash builds a map holding the value of each key in keys, encodes
// it as JSON and returns EncryptMd5 of the encoded text.
//
// When value does not implement interfaces.Result no key can be read, so the
// hash is computed over an empty map.
func GetResultHash(value interface{}, keys []string) (res string, err error) {
	fields := make(map[string]interface{})
	if result, ok := value.(interfaces.Result); ok {
		for _, key := range keys {
			fields[key] = result.GetValue(key)
		}
	}

	encoded, err := json.Marshal(fields)
	if err != nil {
		return "", err
	}
	return EncryptMd5(string(encoded)), nil
}

14
core/utils/rpc.go Normal file
View File

@@ -0,0 +1,14 @@
package utils
import "encoding/json"
// ObjectToString converts params to its JSON text.
//
// The json.Marshal error is discarded; whatever bytes Marshal returned are
// converted with BytesToString.
func ObjectToString(params interface{}) string {
	encoded, _ := json.Marshal(params)
	return BytesToString(encoded)
}
// GetRpcParam returns the RPC parameter stored under key, or the empty
// string when params has no such key.
func GetRpcParam(key string, params map[string]string) string {
	value := params[key]
	return value
}

8
core/utils/spider.go Normal file
View File

@@ -0,0 +1,8 @@
package utils
// GetSpiderCol returns col when it is non-empty; otherwise it derives a
// collection name by prefixing name with "results_".
func GetSpiderCol(col string, name string) string {
	if col != "" {
		return col
	}
	return "results_" + name
}

27
core/utils/sql.go Normal file
View File

@@ -0,0 +1,27 @@
package utils
import (
"github.com/crawlab-team/crawlab-db/generic"
"github.com/upper/db/v4"
"go.mongodb.org/mongo-driver/bson/primitive"
)
// GetSqlQuery converts a generic.ListQuery into an upper/db condition map.
//
// Values of type primitive.ObjectID are replaced by their hex string first.
// Equality conditions map the key directly to the value; every other
// condition maps the key to a nested db.Cond keyed by the operator. A later
// condition on the same key replaces an earlier one.
func GetSqlQuery(query generic.ListQuery) (res db.Cond) {
	res = make(db.Cond)
	for _, c := range query {
		if oid, ok := c.Value.(primitive.ObjectID); ok {
			c.Value = oid.Hex()
		}

		if c.Op == generic.OpEqual {
			res[c.Key] = c.Value
			continue
		}
		res[c.Key] = db.Cond{c.Op: c.Value}
	}

	// TODO: sort

	return res
}

45
core/utils/sqlite.go Normal file
View File

@@ -0,0 +1,45 @@
package utils
import (
"context"
"github.com/crawlab-team/crawlab/core/models/models"
"github.com/upper/db/v4"
"github.com/upper/db/v4/adapter/sqlite"
"time"
)
// GetSqliteSession opens a SQLite session for ds using a background context,
// i.e. it waits for the open attempt without a deadline.
func GetSqliteSession(ds *models.DataSource) (s db.Session, err error) {
	ctx := context.Background()
	return getSqliteSession(ctx, ds)
}
// GetSqliteSessionWithTimeout opens a SQLite session for ds with a context
// that is cancelled after timeout.
func GetSqliteSessionWithTimeout(ds *models.DataSource, timeout time.Duration) (s db.Session, err error) {
	openCtx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()
	s, err = getSqliteSession(openCtx, ds)
	return s, err
}
// getSqliteSession opens a SQLite session for the database named by
// ds.Database, giving up when ctx is done.
//
// sqlite.Open is called without a context, so it runs in its own goroutine
// and hands its result back over a channel rather than writing to the
// function's results, which would race with the return on timeout. If ctx
// finishes first the function returns (nil, ctx.Err()), and a session that
// is opened afterwards is closed in the background instead of being leaked.
func getSqliteSession(ctx context.Context, ds *models.DataSource) (s db.Session, err error) {
	// connect settings
	settings := sqlite.ConnectionURL{
		Database: ds.Database,
		Options:  nil,
	}

	// session
	type openResult struct {
		sess db.Session
		err  error
	}
	// Buffered so the opener can always deliver its result and exit, even
	// when nobody is waiting any more.
	resCh := make(chan openResult, 1)
	go func() {
		sess, openErr := sqlite.Open(settings)
		resCh <- openResult{sess: sess, err: openErr}
	}()

	// wait for done
	select {
	case <-ctx.Done():
		// Release the session if the pending Open succeeds after all.
		go func() {
			if r := <-resCh; r.err == nil && r.sess != nil {
				_ = r.sess.Close()
			}
		}()
		return nil, ctx.Err()
	case r := <-resCh:
		return r.sess, r.err
	}
}

1
core/utils/stats.go Normal file
View File

@@ -0,0 +1 @@
package utils

7
core/utils/system.go Normal file
View File

@@ -0,0 +1,7 @@
package utils
import "github.com/spf13/viper"
// IsPro reports whether the "info.edition" configuration value equals
// "global.edition.pro".
func IsPro() bool {
	edition := viper.GetString("info.edition")
	return edition == "global.edition.pro"
}

13
core/utils/task.go Normal file
View File

@@ -0,0 +1,13 @@
package utils
import "github.com/crawlab-team/crawlab/core/constants"
// IsCancellable reports whether a task in the given status can be cancelled,
// which is the case only for the pending and running statuses.
func IsCancellable(status string) bool {
	return status == constants.TaskStatusPending ||
		status == constants.TaskStatusRunning
}

18
core/utils/time.go Normal file
View File

@@ -0,0 +1,18 @@
package utils
import (
"time"
)
// GetLocalTime returns t expressed in the time.Local location. The instant
// itself is unchanged.
func GetLocalTime(t time.Time) time.Time {
	local := t.In(time.Local)
	return local
}
// GetTimeString formats t as "YYYY-MM-DD hh:mm:ss" in t's own location.
func GetTimeString(t time.Time) string {
	const layout = "2006-01-02 15:04:05"
	return t.Format(layout)
}
func GetLocalTimeString(t time.Time) string {
t = GetLocalTime(t)
return GetTimeString(t)
}

8
core/utils/uuid.go Normal file
View File

@@ -0,0 +1,8 @@
package utils
import "github.com/google/uuid"
// NewUUIDString returns the string form of the UUID produced by
// uuid.NewUUID.
//
// The error from uuid.NewUUID is discarded; in that case the string form of
// whatever UUID value was returned alongside the error is used.
func NewUUIDString() (res string) {
	id, _ := uuid.NewUUID()
	res = id.String()
	return res
}