Files
crawlab/core/ds/es.go

221 lines
5.3 KiB
Go

package ds
import (
"context"
"encoding/json"
"errors"
"fmt"
"github.com/crawlab-team/crawlab/core/constants"
constants2 "github.com/crawlab-team/crawlab/core/constants"
"github.com/crawlab-team/crawlab/core/entity"
entity2 "github.com/crawlab-team/crawlab/core/entity"
"github.com/crawlab-team/crawlab/core/interfaces"
"github.com/crawlab-team/crawlab/core/models/models"
"github.com/crawlab-team/crawlab/core/models/service"
"github.com/crawlab-team/crawlab/core/utils"
"github.com/crawlab-team/crawlab/db/generic"
"github.com/crawlab-team/crawlab/trace"
"github.com/elastic/go-elasticsearch/v8"
"github.com/elastic/go-elasticsearch/v8/esapi"
"github.com/google/uuid"
"go.mongodb.org/mongo-driver/bson/primitive"
"strings"
"sync"
"time"
)
type ElasticsearchService struct {
// dependencies
modelSvc service.ModelService
// internals
dc *models.DataCollection // models.DataCollection
ds *models.DataSource // models.DataSource
c *elasticsearch.Client // elasticsearch.Client
t time.Time
}
func (svc *ElasticsearchService) Insert(records ...interface{}) (err error) {
// wait group
var wg sync.WaitGroup
wg.Add(len(records))
// iterate records
for _, r := range records {
// async operation
go func(r interface{}) {
switch r.(type) {
case entity.Result:
// convert type to entity.Result
d := r.(entity.Result)
// get document id
id := d.GetValue("id")
var docId string
switch id.(type) {
case string:
docId = id.(string)
}
if docId == "" {
docId = uuid.New().String() // generate new uuid if id is empty
}
// collection
d[constants2.DataCollectionKey] = svc.dc.Name
// index request
req := esapi.IndexRequest{
Index: svc.getIndexName(),
DocumentID: docId,
Body: strings.NewReader(d.String()),
}
// perform request
res, err := req.Do(context.Background(), svc.c)
if err != nil {
trace.PrintError(err)
wg.Done()
return
}
defer res.Body.Close()
if res.IsError() {
trace.PrintError(errors.New(fmt.Sprintf("[ElasticsearchService] [%s] error inserting record: %v", res.Status(), r)))
}
// release
wg.Done()
default:
wg.Done()
return
}
}(r)
}
// wait
wg.Wait()
return nil
}
func (svc *ElasticsearchService) List(query generic.ListQuery, opts *generic.ListOptions) (results []interface{}, err error) {
data, err := svc.getListResponse(query, opts, false)
if err != nil {
return nil, err
}
for _, hit := range data.Hits.Hits {
results = append(results, hit.Source)
}
return results, nil
}
func (svc *ElasticsearchService) Count(query generic.ListQuery) (n int, err error) {
data, err := svc.getListResponse(query, nil, true)
if err != nil {
return n, err
}
return int(data.Hits.Total.Value), nil
}
func (svc *ElasticsearchService) getListResponse(query generic.ListQuery, opts *generic.ListOptions, trackTotalHits bool) (data *entity2.ElasticsearchResponseData, err error) {
if opts == nil {
opts = &generic.ListOptions{}
}
query = append(query, generic.ListQueryCondition{
Key: constants2.DataCollectionKey,
Op: constants2.FilterOpEqual,
Value: svc.dc.Name,
})
res, err := svc.c.Search(
svc.c.Search.WithContext(context.Background()),
svc.c.Search.WithIndex(svc.getIndexName()),
svc.c.Search.WithBody(utils.GetElasticsearchQueryWithOptions(query, opts)),
svc.c.Search.WithTrackTotalHits(trackTotalHits),
)
if err != nil {
return nil, trace.TraceError(err)
}
defer res.Body.Close()
if res.IsError() {
err = errors.New(fmt.Sprintf("[ElasticsearchService] [%s] error listing records: response=%s, query=%v opts=%v", res.Status(), res.String(), query, opts))
trace.PrintError(err)
return nil, err
}
data = &entity2.ElasticsearchResponseData{}
if err := json.NewDecoder(res.Body).Decode(data); err != nil {
return nil, trace.TraceError(err)
}
return data, nil
}
func (svc *ElasticsearchService) getIndexName() (index string) {
if svc.ds.Database == "" {
return svc.dc.Name
} else {
return svc.ds.Name
}
}
func NewDataSourceElasticsearchService(colId primitive.ObjectID, dsId primitive.ObjectID) (svc2 interfaces.ResultService, err error) {
// service
svc := &ElasticsearchService{}
// dependency injection
svc.modelSvc, err = service.GetService()
if err != nil {
return nil, err
}
// data source
if dsId.IsZero() {
svc.ds = &models.DataSource{}
} else {
svc.ds, err = svc.modelSvc.GetDataSourceById(dsId)
if err != nil {
return nil, err
}
}
// data source defaults
if svc.ds.Host == "" {
svc.ds.Host = constants.DefaultHost
}
if svc.ds.Port == 0 {
svc.ds.Port = constants.DefaultElasticsearchPort
}
// data source password
pwd, err := svc.modelSvc.GetPasswordById(svc.ds.Id)
if err == nil {
svc.ds.Password, err = utils.DecryptAES(pwd.Password)
if err != nil {
return nil, err
}
}
// data collection
svc.dc, err = svc.modelSvc.GetDataCollectionById(colId)
if err != nil {
return nil, err
}
// client
svc.c, err = utils.GetElasticsearchClient(svc.ds)
if err != nil {
return nil, err
}
return svc, nil
}
func (svc *ElasticsearchService) Index(fields []string) {
// TODO: implement me
}
func (svc *ElasticsearchService) SetTime(t time.Time) {
svc.t = t
}
func (svc *ElasticsearchService) GetTime() (t time.Time) {
return svc.t
}