From a9346a093400fa919725bf3b8dc8cc5372ec4bbc Mon Sep 17 00:00:00 2001 From: yaziming Date: Sun, 1 Sep 2019 17:18:08 +0800 Subject: [PATCH] backend: 1. Mongo dial add 5 seconds connection timeout. 2. Redis uses connection pool mode. 3. Redis pool new connection have 10 seconds write timeout and read timeout and connection timeout. --- backend/database/mongo.go | 3 +- backend/database/pubsub.go | 145 +++++++++++++++++++------------------ backend/database/redis.go | 75 +++++++++---------- backend/services/log.go | 2 +- backend/services/node.go | 59 ++++++++------- backend/services/spider.go | 36 ++++----- backend/services/system.go | 2 +- backend/services/task.go | 2 +- backend/utils/file.go | 2 +- 9 files changed, 166 insertions(+), 160 deletions(-) diff --git a/backend/database/mongo.go b/backend/database/mongo.go index 6b155791..1c2d6433 100644 --- a/backend/database/mongo.go +++ b/backend/database/mongo.go @@ -3,6 +3,7 @@ package database import ( "github.com/globalsign/mgo" "github.com/spf13/viper" + "time" ) var Session *mgo.Session @@ -44,7 +45,7 @@ func InitMongo() error { } else { uri = "mongodb://" + mongoUsername + ":" + mongoPassword + "@" + mongoHost + ":" + mongoPort + "/" + mongoDb + "?authSource=" + mongoAuth } - sess, err := mgo.Dial(uri) + sess, err := mgo.DialWithTimeout(uri, time.Second*5) if err != nil { return err } diff --git a/backend/database/pubsub.go b/backend/database/pubsub.go index 01e52fa1..152dc9a3 100644 --- a/backend/database/pubsub.go +++ b/backend/database/pubsub.go @@ -1,90 +1,97 @@ package database import ( - "errors" + "context" "fmt" "github.com/apex/log" "github.com/gomodule/redigo/redis" + errors2 "github.com/pkg/errors" "time" - "unsafe" ) -type SubscribeCallback func(channel, message string) +type ConsumeFunc func(message redis.Message) error -type Subscriber struct { - client redis.PubSubConn - cbMap map[string]SubscribeCallback -} - -func (c *Subscriber) Connect() { - conn, err := GetRedisConn() - if err != nil { - log.Fatalf("redis dial failed.") - } - - c.client = redis.PubSubConn{Conn: conn} - c.cbMap = make(map[string]SubscribeCallback) - - //retry connect redis 5 times, or panic - index := 0 - go func(i int) { - for { - log.Debug("wait...") - switch res := c.client.Receive().(type) { - case redis.Message: - i = 0 - channel := (*string)(unsafe.Pointer(&res.Channel)) - message := (*string)(unsafe.Pointer(&res.Data)) - c.cbMap[*channel](*channel, *message) - case redis.Subscription: - fmt.Printf("%s: %s %d\n", res.Channel, res.Kind, res.Count) - case error: - log.Error("error handle redis connection...") - - time.Sleep(2 * time.Second) - if i > 5 { - panic(errors.New("redis connection failed too many times, panic")) - } - con, err := GetRedisConn() - i += 1 - if err != nil { - log.Error("redis dial failed") - continue - } - c.client = redis.PubSubConn{Conn: con} - - continue - } - } - }(index) - -} - -func (c *Subscriber) Close() { - err := c.client.Close() +func (r *Redis) Close() { + err := r.pool.Close() if err != nil { log.Errorf("redis close error.") } } +func (r *Redis) subscribe(ctx context.Context, consume ConsumeFunc, channel ...string) error { + psc := redis.PubSubConn{Conn: r.pool.Get()} + if err := psc.Subscribe(redis.Args{}.AddFlat(channel)); err != nil { + return err + } + done := make(chan error, 1) + tick := time.NewTicker(time.Second * 3) + defer tick.Stop() + go func() { + defer func() { _ = psc.Close() }() + for { + switch msg := psc.Receive().(type) { + case error: + done <- fmt.Errorf("redis pubsub receive err: %v", msg) + return + case redis.Message: + fmt.Println(msg) + if err := consume(msg); err != nil { + fmt.Printf("redis pubsub consume message err: %v", err) + continue + } + case redis.Subscription: + fmt.Println(msg) + // + //if msg.Count == 0 { + // // all channels are unsubscribed + // return + //} + } -func (c *Subscriber) Subscribe(channel interface{}, cb SubscribeCallback) { - err := c.client.Subscribe(channel) - if err != nil { - log.Fatalf("redis Subscribe error.") + } + }() + // start a new goroutine to receive message + for { + select { + case <-ctx.Done(): + if err := psc.Unsubscribe(); err != nil { + fmt.Printf("redis pubsub unsubscribe err: %v", err) + } + return nil + case <-tick.C: + //fmt.Printf("ping message \n") + if err := psc.Ping(""); err != nil { + done <- err + } + case err := <-done: + close(done) + return err + } } - c.cbMap[channel.(string)] = cb } +func (r *Redis) Subscribe(ctx context.Context, consume ConsumeFunc, channel ...string) error { + index := 0 + go func() { + for { + err := r.subscribe(ctx, consume, channel...) + fmt.Println(err) -func Publish(channel string, msg string) error { - c, err := GetRedisConn() - if err != nil { - return err - } - - if _, err := c.Do("PUBLISH", channel, msg); err != nil { - return err - } - + if err == nil { + break + } + time.Sleep(5 * time.Second) + index += 1 + fmt.Printf("try reconnect %d times \n", index) + } + }() return nil } +func (r *Redis) Publish(channel, message string) (n int, err error) { + conn := r.pool.Get() + defer func() { _ = conn.Close() }() + n, err = redis.Int(conn.Do("PUBLISH", channel, message)) + if err != nil { + return 0, errors2.Wrapf(err, "redis publish %s %s", channel, message) + } + return +} diff --git a/backend/database/redis.go b/backend/database/redis.go index ffebf776..d159cccd 100644 --- a/backend/database/redis.go +++ b/backend/database/redis.go @@ -1,24 +1,24 @@ package database import ( + "fmt" "github.com/gomodule/redigo/redis" "github.com/spf13/viper" "runtime/debug" + "time" ) -var RedisClient = Redis{} - -type ConsumeFunc func(channel string, message []byte) error +var RedisClient *Redis type Redis struct { + pool *redis.Pool } +func NewRedisClient() *Redis { + return &Redis{pool: NewRedisPool()} +} func (r *Redis) RPush(collection string, value interface{}) error { - c, err := GetRedisConn() - if err != nil { - debug.PrintStack() - return err - } + c := r.pool.Get() defer c.Close() if _, err := c.Do("RPUSH", collection, value); err != nil { @@ -29,11 +29,7 @@ func (r *Redis) RPush(collection string, value interface{}) error { } func (r *Redis) LPop(collection string) (string, error) { - c, err := GetRedisConn() - if err != nil { - debug.PrintStack() - return "", err - } + c := r.pool.Get() defer c.Close() value, err2 := redis.String(c.Do("LPOP", collection)) @@ -44,11 +40,7 @@ func (r *Redis) LPop(collection string) (string, error) { } func (r *Redis) HSet(collection string, key string, value string) error { - c, err := GetRedisConn() - if err != nil { - debug.PrintStack() - return err - } + c := r.pool.Get() defer c.Close() if _, err := c.Do("HSET", collection, key, value); err != nil { @@ -59,11 +51,7 @@ func (r *Redis) HSet(collection string, key string, value string) error { } func (r *Redis) HGet(collection string, key string) (string, error) { - c, err := GetRedisConn() - if err != nil { - debug.PrintStack() - return "", err - } + c := r.pool.Get() defer c.Close() value, err2 := redis.String(c.Do("HGET", collection, key)) @@ -74,11 +62,7 @@ func (r *Redis) HGet(collection string, key string) (string, error) { } func (r *Redis) HDel(collection string, key string) error { - c, err := GetRedisConn() - if err != nil { - debug.PrintStack() - return err - } + c := r.pool.Get() defer c.Close() if _, err := c.Do("HDEL", collection, key); err != nil { @@ -88,11 +72,7 @@ func (r *Redis) HDel(collection string, key string) error { } func (r *Redis) HKeys(collection string) ([]string, error) { - c, err := GetRedisConn() - if err != nil { - debug.PrintStack() - return []string{}, err - } + c := r.pool.Get() defer c.Close() value, err2 := redis.Strings(c.Do("HKeys", collection)) @@ -102,7 +82,7 @@ func (r *Redis) HKeys(collection string) ([]string, error) { return value, nil } -func GetRedisConn() (redis.Conn, error) { +func NewRedisPool() *redis.Pool { var address = viper.GetString("redis.address") var port = viper.GetString("redis.port") var database = viper.GetString("redis.database") @@ -114,14 +94,31 @@ func GetRedisConn() (redis.Conn, error) { } else { url = "redis://x:" + password + "@" + address + ":" + port + "/" + database } - c, err := redis.DialURL(url) - if err != nil { - debug.PrintStack() - return c, err + fmt.Println(url) + return &redis.Pool{ + Dial: func() (conn redis.Conn, e error) { + return redis.DialURL(url, + redis.DialConnectTimeout(time.Second*10), + redis.DialReadTimeout(time.Second*10), + redis.DialWriteTimeout(time.Second*10), + ) + }, + TestOnBorrow: func(c redis.Conn, t time.Time) error { + if time.Since(t) < time.Minute { + return nil + } + _, err := c.Do("PING") + return err + }, + MaxIdle: 10, + MaxActive: 0, + IdleTimeout: 300 * time.Second, + Wait: false, + MaxConnLifetime: 0, } - return c, nil } func InitRedis() error { + RedisClient = NewRedisClient() return nil } diff --git a/backend/services/log.go b/backend/services/log.go index a83926f2..da1d72e1 100644 --- a/backend/services/log.go +++ b/backend/services/log.go @@ -71,7 +71,7 @@ func GetRemoteLog(task model.Task) (logStr string, err error) { // 发布获取日志消息 channel := "nodes:" + task.NodeId.Hex() - if err := database.Publish(channel, string(msgBytes)); err != nil { + if _, err := database.RedisClient.Publish(channel, string(msgBytes)); err != nil { log.Errorf(err.Error()) return "", err } diff --git a/backend/services/node.go b/backend/services/node.go index 9685a1bb..eb24f759 100644 --- a/backend/services/node.go +++ b/backend/services/node.go @@ -1,6 +1,7 @@ package services import ( + "context" "crawlab/constants" "crawlab/database" "crawlab/lib/cron" @@ -10,6 +11,7 @@ import ( "fmt" "github.com/apex/log" "github.com/globalsign/mgo/bson" + "github.com/gomodule/redigo/redis" "github.com/spf13/viper" "runtime/debug" "time" @@ -258,13 +260,12 @@ func UpdateNodeData() { } } -func MasterNodeCallback(channel string, msgStr string) { +func MasterNodeCallback(message redis.Message) (err error) { // 反序列化 var msg NodeMessage - if err := json.Unmarshal([]byte(msgStr), &msg); err != nil { - log.Errorf(err.Error()) - debug.PrintStack() - return + if err := json.Unmarshal(message.Data, &msg); err != nil { + + return err } if msg.Type == constants.MsgTypeGetLog { @@ -281,16 +282,15 @@ func MasterNodeCallback(channel string, msgStr string) { sysInfoBytes, _ := json.Marshal(&msg.SysInfo) ch <- string(sysInfoBytes) } + return nil } -func WorkerNodeCallback(channel string, msgStr string) { +func WorkerNodeCallback(message redis.Message) (err error) { // 反序列化 msg := NodeMessage{} - fmt.Println(msgStr) - if err := json.Unmarshal([]byte(msgStr), &msg); err != nil { - log.Errorf(err.Error()) - debug.PrintStack() - return + if err := json.Unmarshal(message.Data, &msg); err != nil { + + return err } if msg.Type == constants.MsgTypeGetLog { @@ -317,16 +317,14 @@ func WorkerNodeCallback(channel string, msgStr string) { // 序列化 msgSdBytes, err := json.Marshal(&msgSd) if err != nil { - log.Errorf(err.Error()) - debug.PrintStack() - return + return err } // 发布消息给主节点 log.Info("publish get log msg to master") - if err := database.Publish("nodes:master", string(msgSdBytes)); err != nil { - log.Errorf(err.Error()) - return + if _, err := database.RedisClient.Publish("nodes:master", string(msgSdBytes)); err != nil { + + return err } } else if msg.Type == constants.MsgTypeCancelTask { // 取消任务 @@ -336,8 +334,7 @@ func WorkerNodeCallback(channel string, msgStr string) { // 获取环境信息 sysInfo, err := GetLocalSystemInfo() if err != nil { - log.Errorf(err.Error()) - return + return err } msgSd := NodeMessage{ Type: constants.MsgTypeGetSystemInfo, @@ -348,14 +345,14 @@ func WorkerNodeCallback(channel string, msgStr string) { if err != nil { log.Errorf(err.Error()) debug.PrintStack() - return + return err } - fmt.Println(msgSd) - if err := database.Publish("nodes:master", string(msgSdBytes)); err != nil { + if _, err := database.RedisClient.Publish("nodes:master", string(msgSdBytes)); err != nil { log.Errorf(err.Error()) - return + return err } } + return } // 初始化节点服务 @@ -373,25 +370,27 @@ func InitNodeService() error { // 首次更新节点数据(注册到Redis) UpdateNodeData() - // 消息订阅 - var sub database.Subscriber - sub.Connect() - // 获取当前节点 node, err := GetCurrentNode() if err != nil { log.Errorf(err.Error()) return err } - + ctx := context.Background() if IsMaster() { // 如果为主节点,订阅主节点通信频道 channel := "nodes:master" - sub.Subscribe(channel, MasterNodeCallback) + err := database.RedisClient.Subscribe(ctx, MasterNodeCallback, channel) + if err != nil { + return err + } } else { // 若为工作节点,订阅单独指定通信频道 channel := "nodes:" + node.Id.Hex() - sub.Subscribe(channel, WorkerNodeCallback) + err := database.RedisClient.Subscribe(ctx, WorkerNodeCallback, channel) + if err != nil { + return err + } } // 如果为主节点,每30秒刷新所有节点信息 diff --git a/backend/services/spider.go b/backend/services/spider.go index 47c1fa33..61a561a9 100644 --- a/backend/services/spider.go +++ b/backend/services/spider.go @@ -1,6 +1,7 @@ package services import ( + "context" "crawlab/constants" "crawlab/database" "crawlab/lib/cron" @@ -11,6 +12,7 @@ import ( "github.com/apex/log" "github.com/globalsign/mgo" "github.com/globalsign/mgo/bson" + "github.com/gomodule/redigo/redis" "github.com/pkg/errors" "github.com/satori/go.uuid" "github.com/spf13/viper" @@ -142,10 +144,7 @@ func ZipSpider(spider model.Spider) (filePath string, err error) { // 临时文件路径 randomId := uuid.NewV4() - if err != nil { - debug.PrintStack() - return "", err - } + filePath = filepath.Join( viper.GetString("other.tmppath"), randomId.String()+".zip", @@ -302,7 +301,7 @@ func PublishSpider(spider model.Spider) (err error) { return } channel := "files:upload" - if err = database.Publish(channel, string(msgStr)); err != nil { + if _, err = database.RedisClient.Publish(channel, string(msgStr)); err != nil { log.Errorf(err.Error()) debug.PrintStack() return @@ -312,16 +311,16 @@ func PublishSpider(spider model.Spider) (err error) { } // 上传爬虫回调 -func OnFileUpload(channel string, msgStr string) { +func OnFileUpload(message redis.Message) (err error) { s, gf := database.GetGridFs("files") defer s.Close() // 反序列化消息 var msg SpiderUploadMessage - if err := json.Unmarshal([]byte(msgStr), &msg); err != nil { + if err := json.Unmarshal(message.Data, &msg); err != nil { log.Errorf(err.Error()) debug.PrintStack() - return + return err } // 从GridFS获取该文件 @@ -329,7 +328,7 @@ func OnFileUpload(channel string, msgStr string) { if err != nil { log.Errorf("open file id: " + msg.FileId + ", spider id:" + msg.SpiderId + ", error: " + err.Error()) debug.PrintStack() - return + return err } defer f.Close() @@ -342,7 +341,7 @@ func OnFileUpload(channel string, msgStr string) { if err != nil { log.Errorf(err.Error()) debug.PrintStack() - return + return err } defer tmpFile.Close() @@ -350,7 +349,7 @@ func OnFileUpload(channel string, msgStr string) { if _, err := io.Copy(tmpFile, f); err != nil { log.Errorf(err.Error()) debug.PrintStack() - return + return err } // 解压缩临时文件到目标文件夹 @@ -361,22 +360,23 @@ func OnFileUpload(channel string, msgStr string) { if err := utils.DeCompress(tmpFile, dstPath); err != nil { log.Errorf(err.Error()) debug.PrintStack() - return + return err } // 关闭临时文件 if err := tmpFile.Close(); err != nil { log.Errorf(err.Error()) debug.PrintStack() - return + return err } // 删除临时文件 if err := os.Remove(tmpFilePath); err != nil { log.Errorf(err.Error()) debug.PrintStack() - return + return err } + return nil } // 启动爬虫服务 @@ -401,9 +401,11 @@ func InitSpiderService() error { // 订阅文件上传 channel := "files:upload" - var sub database.Subscriber - sub.Connect() - sub.Subscribe(channel, OnFileUpload) + + //sub.Connect() + ctx := context.Background() + return database.RedisClient.Subscribe(ctx, OnFileUpload, channel) + } // 启动定时任务 diff --git a/backend/services/system.go b/backend/services/system.go index 5f50dec9..ff177aa0 100644 --- a/backend/services/system.go +++ b/backend/services/system.go @@ -112,7 +112,7 @@ func GetRemoteSystemInfo(id string) (sysInfo model.SystemInfo, err error) { // 序列化 msgBytes, _ := json.Marshal(&msg) - if err := database.Publish("nodes:"+id, string(msgBytes)); err != nil { + if _, err := database.RedisClient.Publish("nodes:"+id, string(msgBytes)); err != nil { return model.SystemInfo{}, err } diff --git a/backend/services/task.go b/backend/services/task.go index 1b0a5676..6ba6b257 100644 --- a/backend/services/task.go +++ b/backend/services/task.go @@ -466,7 +466,7 @@ func CancelTask(id string) (err error) { } // 发布消息 - if err := database.Publish("nodes:"+task.NodeId.Hex(), string(msgBytes)); err != nil { + if _, err := database.RedisClient.Publish("nodes:"+task.NodeId.Hex(), string(msgBytes)); err != nil { return err } } diff --git a/backend/utils/file.go b/backend/utils/file.go index 9a4300a1..dda73c13 100644 --- a/backend/utils/file.go +++ b/backend/utils/file.go @@ -179,11 +179,11 @@ func _Compress(file *os.File, prefix string, zw *zip.Writer) error { } } else { header, err := zip.FileInfoHeader(info) - header.Name = prefix + "/" + header.Name if err != nil { debug.PrintStack() return err } + header.Name = prefix + "/" + header.Name writer, err := zw.CreateHeader(header) if err != nil { debug.PrintStack()