mirror of
https://github.com/crawlab-team/crawlab.git
synced 2026-01-21 17:21:09 +01:00
refactor: optimized code
This commit is contained in:
@@ -21,17 +21,26 @@ func init() {
|
||||
|
||||
type Api struct {
|
||||
// internals
|
||||
app *gin.Engine
|
||||
ln net.Listener
|
||||
srv *http.Server
|
||||
app *gin.Engine
|
||||
ln net.Listener
|
||||
srv *http.Server
|
||||
initialized bool
|
||||
}
|
||||
|
||||
func (app *Api) Init() {
|
||||
// skip if already initialized
|
||||
if app.initialized {
|
||||
return
|
||||
}
|
||||
|
||||
// initialize middlewares
|
||||
_ = app.initModuleWithApp("middlewares", middlewares.InitMiddlewares)
|
||||
|
||||
// initialize routes
|
||||
_ = app.initModuleWithApp("routes", controllers.InitRoutes)
|
||||
|
||||
// set initialized
|
||||
app.initialized = true
|
||||
}
|
||||
|
||||
func (app *Api) Start() {
|
||||
@@ -90,9 +99,11 @@ func (app *Api) initModuleWithApp(name string, fn func(app *gin.Engine) error) (
|
||||
}
|
||||
|
||||
func newApi() *Api {
|
||||
return &Api{
|
||||
api := &Api{
|
||||
app: gin.New(),
|
||||
}
|
||||
api.Init()
|
||||
return api
|
||||
}
|
||||
|
||||
var api *Api
|
||||
|
||||
6
core/constants/ipc.go
Normal file
6
core/constants/ipc.go
Normal file
@@ -0,0 +1,6 @@
|
||||
package constants
|
||||
|
||||
const (
|
||||
IPCMessageData = "data" // IPCMessageData is the message type identifier for data messages
|
||||
IPCMessageLog = "log" // IPCMessageLog is the message type identifier for log messages
|
||||
)
|
||||
8
core/entity/ipc.go
Normal file
8
core/entity/ipc.go
Normal file
@@ -0,0 +1,8 @@
|
||||
package entity
|
||||
|
||||
// IPCMessage defines the structure for messages exchanged between parent and child processes
|
||||
type IPCMessage struct {
|
||||
Type string `json:"type"` // message type identifier
|
||||
Payload interface{} `json:"payload"` // message content
|
||||
IPC bool `json:"ipc"` // Add this field to explicitly mark IPC messages
|
||||
}
|
||||
@@ -4,6 +4,10 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/apex/log"
|
||||
"github.com/crawlab-team/crawlab/core/constants"
|
||||
"github.com/crawlab-team/crawlab/core/interfaces"
|
||||
@@ -19,9 +23,6 @@ import (
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
mongo2 "go.mongodb.org/mongo-driver/mongo"
|
||||
"io"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var taskServiceMutex = sync.Mutex{}
|
||||
@@ -67,38 +68,66 @@ func (svr TaskServiceServer) Subscribe(req *grpc.TaskServiceSubscribeRequest, st
|
||||
}
|
||||
|
||||
// Connect to task stream when a task runner in a node starts
|
||||
// Connect handles the bidirectional streaming connection from task runners in nodes.
|
||||
// It receives messages containing task data and logs, processes them, and handles any errors.
|
||||
func (svr TaskServiceServer) Connect(stream grpc.TaskService_ConnectServer) (err error) {
|
||||
// spider id and task id to track which spider/task this connection belongs to
|
||||
var spiderId primitive.ObjectID
|
||||
var taskId primitive.ObjectID
|
||||
|
||||
// continuously receive messages from the stream
|
||||
for {
|
||||
// receive next message from stream
|
||||
msg, err := stream.Recv()
|
||||
if err == io.EOF {
|
||||
// stream has ended normally
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
// handle graceful context cancellation
|
||||
if strings.HasSuffix(err.Error(), "context canceled") {
|
||||
return nil
|
||||
}
|
||||
// log other stream receive errors and continue
|
||||
log.Errorf("error receiving stream message: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// validate task id
|
||||
taskId, err := primitive.ObjectIDFromHex(msg.TaskId)
|
||||
if err != nil {
|
||||
log.Errorf("invalid task id: %s", msg.TaskId)
|
||||
continue
|
||||
// validate and parse the task ID from the message if not already set
|
||||
if taskId.IsZero() {
|
||||
taskId, err = primitive.ObjectIDFromHex(msg.TaskId)
|
||||
if err != nil {
|
||||
log.Errorf("invalid task id: %s", msg.TaskId)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// get spider id if not already set
|
||||
// this only needs to be done once per connection
|
||||
if spiderId.IsZero() {
|
||||
t, err := service.NewModelService[models.Task]().GetById(taskId)
|
||||
if err != nil {
|
||||
log.Errorf("error getting spider[%s]: %v", taskId.Hex(), err)
|
||||
continue
|
||||
}
|
||||
spiderId = t.SpiderId
|
||||
}
|
||||
|
||||
// handle different message types based on the code
|
||||
switch msg.Code {
|
||||
case grpc.TaskServiceConnectCode_INSERT_DATA:
|
||||
err = svr.handleInsertData(taskId, msg)
|
||||
// handle scraped data insertion
|
||||
err = svr.handleInsertData(taskId, spiderId, msg)
|
||||
case grpc.TaskServiceConnectCode_INSERT_LOGS:
|
||||
// handle task log insertion
|
||||
err = svr.handleInsertLogs(taskId, msg)
|
||||
default:
|
||||
err = errors.New("invalid stream message code")
|
||||
// invalid message code received
|
||||
log.Errorf("invalid stream message code: %d", msg.Code)
|
||||
continue
|
||||
}
|
||||
if err != nil {
|
||||
// log any errors from handlers
|
||||
log.Errorf("grpc error[%d]: %v", msg.Code, err)
|
||||
}
|
||||
}
|
||||
@@ -170,7 +199,7 @@ func (svr TaskServiceServer) SendNotification(_ context.Context, request *grpc.T
|
||||
taskId, err := primitive.ObjectIDFromHex(request.TaskId)
|
||||
if err != nil {
|
||||
log.Errorf("invalid task id: %s", request.TaskId)
|
||||
return nil, trace.TraceError(err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// arguments
|
||||
@@ -179,31 +208,32 @@ func (svr TaskServiceServer) SendNotification(_ context.Context, request *grpc.T
|
||||
// task
|
||||
task, err := service.NewModelService[models.Task]().GetById(taskId)
|
||||
if err != nil {
|
||||
log.Errorf("task not found: %s", request.TaskId)
|
||||
return nil, trace.TraceError(err)
|
||||
log.Errorf("error getting task[%s]: %v", request.TaskId, err)
|
||||
return nil, err
|
||||
}
|
||||
args = append(args, task)
|
||||
|
||||
// task stat
|
||||
taskStat, err := service.NewModelService[models.TaskStat]().GetById(task.Id)
|
||||
if err != nil {
|
||||
log.Errorf("task stat not found for task: %s", request.TaskId)
|
||||
return nil, trace.TraceError(err)
|
||||
log.Errorf("error getting task stat for task[%s]: %v", request.TaskId, err)
|
||||
return nil, err
|
||||
}
|
||||
args = append(args, taskStat)
|
||||
|
||||
// spider
|
||||
spider, err := service.NewModelService[models.Spider]().GetById(task.SpiderId)
|
||||
if err != nil {
|
||||
log.Errorf("spider not found for task: %s", request.TaskId)
|
||||
return nil, trace.TraceError(err)
|
||||
log.Errorf("error getting spider[%s]: %v", task.SpiderId.Hex(), err)
|
||||
return nil, err
|
||||
}
|
||||
args = append(args, spider)
|
||||
|
||||
// node
|
||||
node, err := service.NewModelService[models.Node]().GetById(task.NodeId)
|
||||
if err != nil {
|
||||
return nil, trace.TraceError(err)
|
||||
log.Errorf("error getting node[%s]: %v", task.NodeId.Hex(), err)
|
||||
return nil, err
|
||||
}
|
||||
args = append(args, node)
|
||||
|
||||
@@ -212,8 +242,8 @@ func (svr TaskServiceServer) SendNotification(_ context.Context, request *grpc.T
|
||||
if !task.ScheduleId.IsZero() {
|
||||
schedule, err = service.NewModelService[models.Schedule]().GetById(task.ScheduleId)
|
||||
if err != nil {
|
||||
log.Errorf("schedule not found for task: %s", request.TaskId)
|
||||
return nil, trace.TraceError(err)
|
||||
log.Errorf("error getting schedule[%s]: %v", task.ScheduleId.Hex(), err)
|
||||
return nil, err
|
||||
}
|
||||
args = append(args, schedule)
|
||||
}
|
||||
@@ -226,7 +256,8 @@ func (svr TaskServiceServer) SendNotification(_ context.Context, request *grpc.T
|
||||
},
|
||||
}, nil)
|
||||
if err != nil {
|
||||
return nil, trace.TraceError(err)
|
||||
log.Errorf("error getting notification settings: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// notification service
|
||||
@@ -268,11 +299,16 @@ func (svr TaskServiceServer) GetSubscribeStream(taskId primitive.ObjectID) (stre
|
||||
return stream, ok
|
||||
}
|
||||
|
||||
func (svr TaskServiceServer) handleInsertData(taskId primitive.ObjectID, msg *grpc.TaskServiceConnectRequest) (err error) {
|
||||
func (svr TaskServiceServer) handleInsertData(taskId, spiderId primitive.ObjectID, msg *grpc.TaskServiceConnectRequest) (err error) {
|
||||
var records []map[string]interface{}
|
||||
err = json.Unmarshal(msg.Data, &records)
|
||||
if err != nil {
|
||||
return trace.TraceError(err)
|
||||
log.Errorf("error unmarshalling data: %v", err)
|
||||
return err
|
||||
}
|
||||
for i := range records {
|
||||
records[i][constants.TaskKey] = taskId
|
||||
records[i][constants.SpiderKey] = spiderId
|
||||
}
|
||||
return svr.statsSvc.InsertData(taskId, records...)
|
||||
}
|
||||
@@ -281,7 +317,8 @@ func (svr TaskServiceServer) handleInsertLogs(taskId primitive.ObjectID, msg *gr
|
||||
var logs []string
|
||||
err = json.Unmarshal(msg.Data, &logs)
|
||||
if err != nil {
|
||||
return trace.TraceError(err)
|
||||
log.Errorf("error unmarshalling logs: %v", err)
|
||||
return err
|
||||
}
|
||||
return svr.statsSvc.InsertLogs(taskId, logs...)
|
||||
}
|
||||
|
||||
@@ -62,10 +62,10 @@ type Runner struct {
|
||||
logBatchSize int // number of log lines to batch before sending
|
||||
|
||||
// IPC (Inter-Process Communication)
|
||||
stdinPipe io.WriteCloser // pipe for writing to child process
|
||||
stdoutPipe io.ReadCloser // pipe for reading from child process
|
||||
ipcChan chan IPCMessage // channel for sending IPC messages
|
||||
ipcHandler func(IPCMessage) // callback for handling received IPC messages
|
||||
stdinPipe io.WriteCloser // pipe for writing to child process
|
||||
stdoutPipe io.ReadCloser // pipe for reading from child process
|
||||
ipcChan chan entity.IPCMessage // channel for sending IPC messages
|
||||
ipcHandler func(entity.IPCMessage) // callback for handling received IPC messages
|
||||
|
||||
// goroutine management
|
||||
ctx context.Context // context for controlling goroutine lifecycle
|
||||
@@ -74,18 +74,6 @@ type Runner struct {
|
||||
wg sync.WaitGroup // wait group for goroutine synchronization
|
||||
}
|
||||
|
||||
const (
|
||||
IPCMessageData = "data" // IPCMessageData is the message type identifier for data messages
|
||||
IPCMessageLog = "log" // IPCMessageLog is the message type identifier for log messages
|
||||
)
|
||||
|
||||
// IPCMessage defines the structure for messages exchanged between parent and child processes
|
||||
type IPCMessage struct {
|
||||
Type string `json:"type"` // message type identifier
|
||||
Payload interface{} `json:"payload"` // message content
|
||||
IPC bool `json:"ipc"` // Add this field to explicitly mark IPC messages
|
||||
}
|
||||
|
||||
// Init initializes the task runner by updating the task status and establishing gRPC connections
|
||||
func (r *Runner) Init() (err error) {
|
||||
// update task
|
||||
@@ -162,10 +150,11 @@ func (r *Runner) Run() (err error) {
|
||||
|
||||
// Ensure cleanup when Run() exits
|
||||
defer func() {
|
||||
r.cancel() // Cancel context to stop all goroutines
|
||||
r.wg.Wait() // Wait for all goroutines to finish
|
||||
close(r.done) // Signal that everything is done
|
||||
close(r.ipcChan) // Close IPC channel
|
||||
_ = r.conn.CloseSend() // Close gRPC connection
|
||||
r.cancel() // Cancel context to stop all goroutines
|
||||
r.wg.Wait() // Wait for all goroutines to finish
|
||||
close(r.done) // Signal that everything is done
|
||||
close(r.ipcChan) // Close IPC channel
|
||||
}()
|
||||
|
||||
// wait for process to finish
|
||||
@@ -258,7 +247,7 @@ func (r *Runner) configureCmd() (err error) {
|
||||
}
|
||||
|
||||
// Initialize IPC channel
|
||||
r.ipcChan = make(chan IPCMessage)
|
||||
r.ipcChan = make(chan entity.IPCMessage)
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -782,7 +771,7 @@ func (r *Runner) handleIPC() {
|
||||
// msgType: type of message being sent
|
||||
// payload: data to be sent to the child process
|
||||
func (r *Runner) SendToChild(msgType string, payload interface{}) {
|
||||
r.ipcChan <- IPCMessage{
|
||||
r.ipcChan <- entity.IPCMessage{
|
||||
Type: msgType,
|
||||
Payload: payload,
|
||||
IPC: true, // Explicitly mark as IPC message
|
||||
@@ -790,7 +779,7 @@ func (r *Runner) SendToChild(msgType string, payload interface{}) {
|
||||
}
|
||||
|
||||
// SetIPCHandler sets the handler for incoming IPC messages
|
||||
func (r *Runner) SetIPCHandler(handler func(IPCMessage)) {
|
||||
func (r *Runner) SetIPCHandler(handler func(entity.IPCMessage)) {
|
||||
r.ipcHandler = handler
|
||||
}
|
||||
|
||||
@@ -811,7 +800,7 @@ func (r *Runner) startIPCReader() {
|
||||
}
|
||||
line := scanner.Text()
|
||||
|
||||
var ipcMsg IPCMessage
|
||||
var ipcMsg entity.IPCMessage
|
||||
err := json.Unmarshal([]byte(line), &ipcMsg)
|
||||
if err == nil && ipcMsg.IPC {
|
||||
// Only handle as IPC if it's valid JSON AND has IPC flag set
|
||||
@@ -819,7 +808,7 @@ func (r *Runner) startIPCReader() {
|
||||
r.ipcHandler(ipcMsg)
|
||||
} else {
|
||||
// Default handler (insert data)
|
||||
if ipcMsg.Type == "" || ipcMsg.Type == IPCMessageData {
|
||||
if ipcMsg.Type == "" || ipcMsg.Type == constants.IPCMessageData {
|
||||
r.handleIPCInsertDataMessage(ipcMsg)
|
||||
} else {
|
||||
log.Warnf("no IPC handler set for message: %v", ipcMsg)
|
||||
@@ -834,7 +823,7 @@ func (r *Runner) startIPCReader() {
|
||||
}
|
||||
|
||||
// handleIPCInsertDataMessage converts the IPC message payload to JSON and sends it to the master node
|
||||
func (r *Runner) handleIPCInsertDataMessage(ipcMsg IPCMessage) {
|
||||
func (r *Runner) handleIPCInsertDataMessage(ipcMsg entity.IPCMessage) {
|
||||
// Validate message
|
||||
if ipcMsg.Payload == nil {
|
||||
log.Errorf("empty payload in IPC message")
|
||||
|
||||
@@ -63,11 +63,11 @@ func (svc *Service) InsertData(taskId primitive.ObjectID, records ...map[string]
|
||||
count++
|
||||
}
|
||||
} else {
|
||||
var records2 []interface{}
|
||||
var recordsToInsert []interface{}
|
||||
for _, record := range records {
|
||||
records2 = append(records2, svc.normalizeRecord(item, record))
|
||||
recordsToInsert = append(recordsToInsert, svc.normalizeRecord(item, record))
|
||||
}
|
||||
_, err = mongo.GetMongoCol(tableName).InsertMany(records2)
|
||||
_, err = mongo.GetMongoCol(tableName).InsertMany(recordsToInsert)
|
||||
if err != nil {
|
||||
log2.Errorf("failed to insert data: %v", err)
|
||||
return err
|
||||
|
||||
Reference in New Issue
Block a user