351 lines
6.9 KiB
Go
351 lines
6.9 KiB
Go
package eventing
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"math/rand/v2"
|
|
"sort"
|
|
"sync"
|
|
"time"
|
|
|
|
"AdaptixServer/core/utils/logs"
|
|
)
|
|
|
|
const (
|
|
DefaultWorkerCount = 4
|
|
DefaultQueueSize = 256
|
|
DefaultHookTimeout = 30 * time.Second
|
|
DefaultPreHookTimeout = 5 * time.Second
|
|
)
|
|
|
|
type HookFunc func(event any) error
|
|
|
|
type Hook struct {
|
|
ID string
|
|
Name string
|
|
Phase HookPhase
|
|
Priority int
|
|
Timeout time.Duration
|
|
Handler HookFunc
|
|
}
|
|
|
|
type eventTask struct {
|
|
eventType EventType
|
|
event any
|
|
hook *Hook
|
|
}
|
|
|
|
type EventManager struct {
|
|
hooks map[EventType][]*Hook
|
|
mu sync.RWMutex
|
|
|
|
taskQueue chan eventTask
|
|
workerWg sync.WaitGroup
|
|
shutdownCh chan struct{}
|
|
running bool
|
|
}
|
|
|
|
func NewEventManager() *EventManager {
|
|
em := &EventManager{
|
|
hooks: make(map[EventType][]*Hook),
|
|
taskQueue: make(chan eventTask, DefaultQueueSize),
|
|
shutdownCh: make(chan struct{}),
|
|
running: true,
|
|
}
|
|
|
|
em.startWorkers(DefaultWorkerCount)
|
|
return em
|
|
}
|
|
|
|
func (em *EventManager) startWorkers(count int) {
|
|
for i := 0; i < count; i++ {
|
|
em.workerWg.Add(1)
|
|
go em.worker(i)
|
|
}
|
|
}
|
|
|
|
func (em *EventManager) worker(id int) {
|
|
defer em.workerWg.Done()
|
|
|
|
for {
|
|
select {
|
|
case <-em.shutdownCh:
|
|
return
|
|
case task, ok := <-em.taskQueue:
|
|
if !ok {
|
|
return
|
|
}
|
|
_ = em.executeHookWithTimeout(task.hook, task.event, DefaultHookTimeout)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (em *EventManager) executeHookWithTimeout(hook *Hook, event any, defaultTimeout time.Duration) error {
|
|
timeout := hook.Timeout
|
|
if timeout == 0 {
|
|
timeout = defaultTimeout
|
|
}
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
|
defer cancel()
|
|
|
|
done := make(chan error, 1)
|
|
go func() {
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
done <- fmt.Errorf("hook %s panicked: %v", hook.Name, r)
|
|
}
|
|
}()
|
|
done <- hook.Handler(event)
|
|
}()
|
|
|
|
select {
|
|
case err := <-done:
|
|
return err
|
|
case <-ctx.Done():
|
|
logs.Warn("", fmt.Sprintf("Hook %s timed out after %v", hook.Name, timeout))
|
|
return fmt.Errorf("hook %s timed out", hook.Name)
|
|
}
|
|
}
|
|
|
|
func (em *EventManager) Shutdown() {
|
|
em.mu.Lock()
|
|
if !em.running {
|
|
em.mu.Unlock()
|
|
return
|
|
}
|
|
em.running = false
|
|
em.mu.Unlock()
|
|
|
|
close(em.shutdownCh)
|
|
close(em.taskQueue)
|
|
em.workerWg.Wait()
|
|
}
|
|
|
|
func (em *EventManager) Register(eventType EventType, hook *Hook) string {
|
|
em.mu.Lock()
|
|
defer em.mu.Unlock()
|
|
|
|
if hook.ID == "" {
|
|
hook.ID = em.generateUniqueID()
|
|
}
|
|
|
|
em.hooks[eventType] = append(em.hooks[eventType], hook)
|
|
|
|
sort.Slice(em.hooks[eventType], func(i, j int) bool {
|
|
return em.hooks[eventType][i].Priority < em.hooks[eventType][j].Priority
|
|
})
|
|
|
|
return hook.ID
|
|
}
|
|
|
|
func (em *EventManager) generateUniqueID() string {
|
|
for {
|
|
id := fmt.Sprintf("%08x", rand.Uint32())
|
|
if !em.hookIDExists(id) {
|
|
return id
|
|
}
|
|
}
|
|
}
|
|
|
|
func (em *EventManager) hookIDExists(id string) bool {
|
|
for _, hooks := range em.hooks {
|
|
for _, hook := range hooks {
|
|
if hook.ID == id {
|
|
return true
|
|
}
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
func (em *EventManager) Unregister(hookID string) bool {
|
|
em.mu.Lock()
|
|
defer em.mu.Unlock()
|
|
|
|
for eventType, hooks := range em.hooks {
|
|
for i, hook := range hooks {
|
|
if hook.ID == hookID {
|
|
em.hooks[eventType] = append(hooks[:i], hooks[i+1:]...)
|
|
return true
|
|
}
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
func (em *EventManager) UnregisterByName(name string) int {
|
|
em.mu.Lock()
|
|
defer em.mu.Unlock()
|
|
|
|
count := 0
|
|
for eventType, hooks := range em.hooks {
|
|
filtered := make([]*Hook, 0, len(hooks))
|
|
for _, hook := range hooks {
|
|
if hook.Name != name {
|
|
filtered = append(filtered, hook)
|
|
} else {
|
|
count++
|
|
}
|
|
}
|
|
em.hooks[eventType] = filtered
|
|
}
|
|
return count
|
|
}
|
|
|
|
func (em *EventManager) Emit(eventType EventType, phase HookPhase, event any) bool {
|
|
em.mu.RLock()
|
|
hooks := make([]*Hook, len(em.hooks[eventType]))
|
|
copy(hooks, em.hooks[eventType])
|
|
em.mu.RUnlock()
|
|
|
|
baseEvent := getBaseEvent(event)
|
|
if baseEvent != nil {
|
|
baseEvent.Type = eventType
|
|
baseEvent.Phase = phase
|
|
}
|
|
|
|
for _, hook := range hooks {
|
|
if hook.Phase != phase {
|
|
continue
|
|
}
|
|
|
|
var err error
|
|
if phase == HookPre {
|
|
err = em.executeHookWithTimeout(hook, event, DefaultPreHookTimeout)
|
|
} else {
|
|
err = em.executeHookWithTimeout(hook, event, DefaultHookTimeout)
|
|
}
|
|
|
|
if err != nil {
|
|
if baseEvent != nil {
|
|
baseEvent.Error = err
|
|
if phase == HookPre {
|
|
baseEvent.Cancelled = true
|
|
return false
|
|
}
|
|
}
|
|
logs.Warn("", fmt.Sprintf("Hook %s error: %v", hook.Name, err))
|
|
}
|
|
|
|
if baseEvent != nil && phase == HookPre && baseEvent.Cancelled {
|
|
return false
|
|
}
|
|
}
|
|
|
|
return true
|
|
}
|
|
|
|
func (em *EventManager) EmitAsync(eventType EventType, event any) {
|
|
em.mu.RLock()
|
|
hooks := make([]*Hook, len(em.hooks[eventType]))
|
|
copy(hooks, em.hooks[eventType])
|
|
running := em.running
|
|
em.mu.RUnlock()
|
|
|
|
if !running {
|
|
return
|
|
}
|
|
|
|
baseEvent := getBaseEvent(event)
|
|
if baseEvent != nil {
|
|
baseEvent.Type = eventType
|
|
baseEvent.Phase = HookPost
|
|
}
|
|
|
|
for _, hook := range hooks {
|
|
if hook.Phase != HookPost {
|
|
continue
|
|
}
|
|
|
|
select {
|
|
case em.taskQueue <- eventTask{
|
|
eventType: eventType,
|
|
event: event,
|
|
hook: hook,
|
|
}:
|
|
default:
|
|
logs.Warn("", fmt.Sprintf("Event queue full, dropping hook %s for event %s", hook.Name, eventType))
|
|
}
|
|
}
|
|
}
|
|
|
|
func (em *EventManager) ListHooks(eventType EventType) []Hook {
|
|
em.mu.RLock()
|
|
defer em.mu.RUnlock()
|
|
|
|
result := make([]Hook, 0, len(em.hooks[eventType]))
|
|
for _, hook := range em.hooks[eventType] {
|
|
result = append(result, *hook)
|
|
}
|
|
return result
|
|
}
|
|
|
|
func (em *EventManager) ListAllHooks() map[EventType][]Hook {
|
|
em.mu.RLock()
|
|
defer em.mu.RUnlock()
|
|
|
|
result := make(map[EventType][]Hook)
|
|
for eventType, hooks := range em.hooks {
|
|
list := make([]Hook, 0, len(hooks))
|
|
for _, hook := range hooks {
|
|
list = append(list, *hook)
|
|
}
|
|
result[eventType] = list
|
|
}
|
|
return result
|
|
}
|
|
|
|
func getBaseEvent(event any) *BaseEvent {
|
|
switch e := event.(type) {
|
|
case *EventCredentialsAdd:
|
|
return &e.BaseEvent
|
|
case *EventCredentialsEdit:
|
|
return &e.BaseEvent
|
|
case *EventCredentialsRemove:
|
|
return &e.BaseEvent
|
|
case *EventDataAgentNew:
|
|
return &e.BaseEvent
|
|
case *EventDataAgentCheckin:
|
|
return &e.BaseEvent
|
|
case *EventDataAgentTerminate:
|
|
return &e.BaseEvent
|
|
case *EventDataAgentUpdate:
|
|
return &e.BaseEvent
|
|
case *EventDataTaskCreate:
|
|
return &e.BaseEvent
|
|
case *EventDataTaskStart:
|
|
return &e.BaseEvent
|
|
case *EventDataTaskComplete:
|
|
return &e.BaseEvent
|
|
case *EventDataListenerStart:
|
|
return &e.BaseEvent
|
|
case *EventDataListenerStop:
|
|
return &e.BaseEvent
|
|
case *EventDataDownloadStart:
|
|
return &e.BaseEvent
|
|
case *EventDataDownloadFinish:
|
|
return &e.BaseEvent
|
|
case *EventDataScreenshotAdd:
|
|
return &e.BaseEvent
|
|
case *EventDataTunnelStart:
|
|
return &e.BaseEvent
|
|
case *EventDataTunnelStop:
|
|
return &e.BaseEvent
|
|
case *EventDataClientConnect:
|
|
return &e.BaseEvent
|
|
case *EventDataClientDisconnect:
|
|
return &e.BaseEvent
|
|
case *EventDataTargetAdd:
|
|
return &e.BaseEvent
|
|
case *EventDataTargetRemove:
|
|
return &e.BaseEvent
|
|
case *EventDataPivotCreate:
|
|
return &e.BaseEvent
|
|
case *EventDataPivotRemove:
|
|
return &e.BaseEvent
|
|
default:
|
|
return nil
|
|
}
|
|
}
|