2026-04-06 00:20:51 -05:00

508 lines
11 KiB
Go

package server
import (
"context"
"io"
"net"
"strconv"
"sync"
"sync/atomic"
"github.com/Adaptix-Framework/axc2"
"github.com/gorilla/websocket"
)
// Buffer sizing used by the TunnelManager buffer pool.
const (
	// TunnelBufferSize is the length in bytes (32 KiB) of each buffer allocated
	// by the TunnelManager buffer pool.
	TunnelBufferSize = 32 << 10
	// TunnelBufferPoolCap is not referenced in this section of the file.
	// NOTE(review): confirm it is used elsewhere.
	TunnelBufferPoolCap = 256
)

// Ingress queue sizing and flow-control watermarks for SafeTunnelChannel.
const (
	// tunnelIngressQueueDepth is the capacity of each channel's ingress queue.
	tunnelIngressQueueDepth = 1024
	// tunnelIngressHiWM is the fill percentage above which a writer asks the
	// agent to pause the channel.
	tunnelIngressHiWM = 80
	// tunnelIngressLoWM is the fill percentage below which the ingress pump
	// asks the agent to resume the channel.
	tunnelIngressLoWM = 20
)
// TunnelManager tracks registered tunnels and their channels, keeps a pool of
// reusable byte buffers, and maintains aggregate counters.
type TunnelManager struct {
	// ts is used to look up agents when sending flow-control tasks; it may be
	// nil, in which case SendTunnelFlowControl does nothing.
	ts *Teamserver
	tunnels sync.Map // tunnelId string -> *Tunnel
	channelIndex sync.Map // "tunnelId:channelId" string -> *ChannelEntry
	// bufferPool holds *[]byte values; see GetBuffer and PutBuffer.
	bufferPool sync.Pool
	// stats holds the counters exposed through GetStats.
	stats TunnelStats
}
// SendTunnelFlowControl asks the agent that owns channelId's tunnel to pause
// (pause == true) or resume (pause == false) that channel.
//
// The channel is resolved with GetChannelByIdOnly. The task is produced by the
// tunnel's Callbacks.Pause / Callbacks.Resume and handed to tunnelManageTask.
// This is a silent no-op in any of these cases:
//   - the manager has no Teamserver;
//   - the channel or its tunnel is unknown;
//   - the agent lookup fails;
//   - the callback returns a task whose Type is 0.
func (tm *TunnelManager) SendTunnelFlowControl(channelId int, pause bool) {
	if tm.ts == nil {
		return
	}
	entry, found := tm.GetChannelByIdOnly(channelId)
	if !found || entry.Tunnel == nil {
		return
	}
	tunnel := entry.Tunnel

	agent, err := tm.ts.getAgent(tunnel.Data.AgentId)
	if err != nil {
		return
	}

	var task adaptix.TaskData
	switch {
	case pause:
		task = tunnel.Callbacks.Pause(channelId)
	default:
		task = tunnel.Callbacks.Resume(channelId)
	}
	// A zero task Type means the callback produced nothing to send.
	if task.Type != 0 {
		tunnelManageTask(agent, task)
	}
}
// channelKey builds the TunnelManager.channelIndex key for a channel. The key
// is the tunnel id, a colon, then the decimal channel id (e.g. "abc:42").
// Because the channel id is always the text after the last colon, tunnel ids
// that themselves contain ':' still produce distinct keys.
func channelKey(tunnelId string, channelId int) string {
	idText := strconv.FormatInt(int64(channelId), 10)
	return tunnelId + ":" + idText
}
// ChannelEntry is the value stored in TunnelManager.channelIndex for every
// channel added through RegisterChannel.
type ChannelEntry struct {
	// TunnelId is the id the channel was registered under; CloseAllChannels
	// selects entries by comparing against it.
	TunnelId string
	// Tunnel is the tunnel the channel belongs to.
	Tunnel *Tunnel
	// Channel is the registered channel itself.
	Channel *TunnelChannel
}
// TunnelStats holds atomic counters maintained by TunnelManager; they can be
// read concurrently without additional locking.
type TunnelStats struct {
	// ActiveTunnels is incremented by PutTunnel and decremented by DeleteTunnel.
	ActiveTunnels atomic.Int64
	// ActiveChannels is incremented by RegisterChannel and decremented when an
	// entry is removed from the channel index (UnregisterChannel, CloseAllChannels).
	ActiveChannels atomic.Int64
	// TotalBytesSent is not updated anywhere in this section of the file.
	// NOTE(review): confirm it is maintained elsewhere.
	TotalBytesSent atomic.Uint64
	// TotalBytesRecv counts bytes accepted by WriteToChannel and
	// WriteToChannelByIdOnly.
	TotalBytesRecv atomic.Uint64
}
// TunnelChannelSafe embeds a TunnelChannel by value together with a mutex, a
// closed flag and a cancelable context.
//
// NOTE(review): this type is not referenced anywhere in this section. It looks
// superseded by SafeTunnelChannel, which embeds *TunnelChannel; confirm it is
// unused elsewhere and consider removing it. Because it contains a sync.Mutex,
// values must not be copied.
type TunnelChannelSafe struct {
	TunnelChannel
	mu sync.Mutex
	closed atomic.Bool
	ctx context.Context
	cancel context.CancelFunc
}
// NewTunnelManager returns an empty manager bound to ts. Its buffer pool
// allocates fresh TunnelBufferSize-byte slices, stored as *[]byte, on demand.
func NewTunnelManager(ts *Teamserver) *TunnelManager {
	allocBuffer := func() any {
		fresh := make([]byte, TunnelBufferSize)
		return &fresh
	}
	return &TunnelManager{
		ts:         ts,
		bufferPool: sync.Pool{New: allocBuffer},
	}
}
// GetBuffer takes a buffer from the manager's pool, allocating one when the
// pool is empty. Return it with PutBuffer when done.
func (tm *TunnelManager) GetBuffer() []byte {
	bufPtr := tm.bufferPool.Get().(*[]byte)
	return *bufPtr
}
// PutBuffer returns buf to the manager's pool for reuse by GetBuffer.
//
// Buffers whose capacity is below TunnelBufferSize are dropped. Accepted
// buffers are resliced back to exactly TunnelBufferSize bytes before pooling.
// Callers commonly return a sub-slice such as buf[:n]; pooling that length
// unchanged would make a later GetBuffer hand out a short buffer.
func (tm *TunnelManager) PutBuffer(buf []byte) {
	if cap(buf) < TunnelBufferSize {
		return
	}
	buf = buf[:TunnelBufferSize]
	tm.bufferPool.Put(&buf)
}
// GetTunnel looks up the tunnel registered under tunnelId. It reports false
// (with a nil tunnel) when no such tunnel is registered.
func (tm *TunnelManager) GetTunnel(tunnelId string) (*Tunnel, bool) {
	if stored, found := tm.tunnels.Load(tunnelId); found {
		return stored.(*Tunnel), true
	}
	return nil, false
}
// PutTunnel registers tunnel under tunnel.Data.TunnelId. Any tunnel already
// stored under that id is replaced.
//
// ActiveTunnels is incremented only when the id was not present before.
// DeleteTunnel decrements once per removal, so counting a replacement as a new
// tunnel would leave the counter permanently too high.
func (tm *TunnelManager) PutTunnel(tunnel *Tunnel) {
	id := tunnel.Data.TunnelId
	if _, existed := tm.tunnels.LoadOrStore(id, tunnel); existed {
		// Keep the original overwrite semantics without counting the id twice.
		// NOTE(review): a DeleteTunnel racing between LoadOrStore and Store can
		// leave the id stored but uncounted; sync.Map.Swap (Go 1.20+) would make
		// this atomic if the module's Go version allows it.
		tm.tunnels.Store(id, tunnel)
		return
	}
	tm.stats.ActiveTunnels.Add(1)
}
// DeleteTunnel removes the tunnel registered under tunnelId and returns it.
// ActiveTunnels is decremented only when something was actually removed. It
// reports false (with a nil tunnel) when the id was not registered.
func (tm *TunnelManager) DeleteTunnel(tunnelId string) (*Tunnel, bool) {
	removed, existed := tm.tunnels.LoadAndDelete(tunnelId)
	if existed {
		tm.stats.ActiveTunnels.Add(-1)
		return removed.(*Tunnel), true
	}
	return nil, false
}
// TunnelExists reports whether a tunnel is registered under tunnelId.
func (tm *TunnelManager) TunnelExists(tunnelId string) bool {
	_, present := tm.tunnels.Load(tunnelId)
	return present
}
// ForEachTunnel calls fn for every registered tunnel until fn returns false.
// Iteration order and consistency follow sync.Map.Range.
func (tm *TunnelManager) ForEachTunnel(fn func(tunnelId string, tunnel *Tunnel) bool) {
	visit := func(k, v any) bool {
		return fn(k.(string), v.(*Tunnel))
	}
	tm.tunnels.Range(visit)
}
// RegisterChannel indexes channel under (tunnelId, channel.channelId) and also
// adds it to tunnel.connections, keyed by the decimal channel id.
//
// Re-registering an existing key replaces the stored entry. In that case
// ActiveChannels is not incremented again, because UnregisterChannel and
// CloseAllChannels decrement only once per key.
func (tm *TunnelManager) RegisterChannel(tunnelId string, tunnel *Tunnel, channel *TunnelChannel) {
	entry := &ChannelEntry{
		TunnelId: tunnelId,
		Tunnel:   tunnel,
		Channel:  channel,
	}
	key := channelKey(tunnelId, channel.channelId)
	_, replaced := tm.channelIndex.LoadOrStore(key, entry)
	if replaced {
		// Keep the original overwrite semantics for a duplicate key.
		tm.channelIndex.Store(key, entry)
	}
	tunnel.connections.Put(strconv.Itoa(channel.channelId), channel)
	if !replaced {
		tm.stats.ActiveChannels.Add(1)
	}
}
// UnregisterChannel removes (tunnelId, channelId) from the channel index and
// from its tunnel's connections map, then decrements ActiveChannels. It does
// nothing when the channel is not indexed. Connections and pipes are not
// closed here.
func (tm *TunnelManager) UnregisterChannel(tunnelId string, channelId int) {
	removed, existed := tm.channelIndex.LoadAndDelete(channelKey(tunnelId, channelId))
	if !existed {
		return
	}
	removed.(*ChannelEntry).Tunnel.connections.Delete(strconv.Itoa(channelId))
	tm.stats.ActiveChannels.Add(-1)
}
// GetChannel returns the index entry for (tunnelId, channelId). It reports
// false (with a nil entry) when the channel is not registered.
func (tm *TunnelManager) GetChannel(tunnelId string, channelId int) (*ChannelEntry, bool) {
	if stored, found := tm.channelIndex.Load(channelKey(tunnelId, channelId)); found {
		return stored.(*ChannelEntry), true
	}
	return nil, false
}
// GetChannelByIdOnly finds a registered channel by its channel id alone,
// ignoring which tunnel it belongs to.
//
// It scans the whole index (O(number of channels)) and returns the first match
// in sync.Map.Range order. If the same channel id is registered under several
// tunnels, which one is returned is unspecified.
func (tm *TunnelManager) GetChannelByIdOnly(channelId int) (*ChannelEntry, bool) {
	var match *ChannelEntry
	tm.channelIndex.Range(func(_, v any) bool {
		candidate := v.(*ChannelEntry)
		if candidate.Channel.channelId != channelId {
			return true // keep scanning
		}
		match = candidate
		return false
	})
	return match, match != nil
}
// ChannelExists reports whether (tunnelId, channelId) is present in the
// channel index.
func (tm *TunnelManager) ChannelExists(tunnelId string, channelId int) bool {
	_, present := tm.channelIndex.Load(channelKey(tunnelId, channelId))
	return present
}
// ChannelExistsInTunnel reports whether tunnel.connections holds an entry for
// channelId (keyed by its decimal string form). It consults only the tunnel's
// own map, not the manager's channel index.
func (tm *TunnelManager) ChannelExistsInTunnel(tunnel *Tunnel, channelId int) bool {
	connKey := strconv.Itoa(channelId)
	return tunnel.connections.Contains(connKey)
}
// CloseChannel closes and unregisters the channel (tunnelId, channelId). It
// does nothing when the channel is not registered.
func (tm *TunnelManager) CloseChannel(tunnelId string, channelId int) {
	if entry, found := tm.GetChannel(tunnelId, channelId); found {
		tm.closeChannelInternal(entry.Tunnel, entry.Channel)
	}
}
// CloseChannelByIdOnly closes the channel found by GetChannelByIdOnly.
//
// With writeOnly set, only the channel's pwTun pipe writer is closed; the
// connection, the other pipes and the registration are left in place.
// Otherwise the channel is fully closed and unregistered. Unknown ids are
// ignored.
func (tm *TunnelManager) CloseChannelByIdOnly(channelId int, writeOnly bool) {
	entry, found := tm.GetChannelByIdOnly(channelId)
	switch {
	case !found:
		return
	case !writeOnly:
		tm.closeChannelInternal(entry.Tunnel, entry.Channel)
	case entry.Channel != nil && entry.Channel.pwTun != nil:
		_ = entry.Channel.pwTun.Close()
	}
}
// closeChannelInternal closes every non-nil resource of channel, in this
// order: conn, wsconn, then the pipe ends pwTun, prTun, pwSrv, prSrv. Close
// errors are ignored. When tunnel is non-nil, the channel is then unregistered
// from the manager.
//
// NOTE(review): ingressChan is not closed here, so an ingressPump started for a
// SafeTunnelChannel keeps waiting on its queue unless SafeTunnelChannel.Close
// is also called. Confirm that callers do so.
func (tm *TunnelManager) closeChannelInternal(tunnel *Tunnel, channel *TunnelChannel) {
	if channel == nil {
		return
	}
	if c := channel.conn; c != nil {
		_ = c.Close()
	}
	if ws := channel.wsconn; ws != nil {
		_ = ws.Close()
	}
	if w := channel.pwTun; w != nil {
		_ = w.Close()
	}
	if r := channel.prTun; r != nil {
		_ = r.Close()
	}
	if w := channel.pwSrv; w != nil {
		_ = w.Close()
	}
	if r := channel.prSrv; r != nil {
		_ = r.Close()
	}
	if tunnel == nil {
		return
	}
	tm.UnregisterChannel(tunnel.Data.TunnelId, channel.channelId)
}
// CloseAllChannels closes every indexed channel that belongs to tunnel. It
// removes those channels from the index, decrementing ActiveChannels for each
// removed entry, and finally calls tunnel.connections.CutMap().
//
// Matching keys are collected first and removed afterwards, so the index is not
// mutated during Range. Each channel is closed with a nil tunnel because its
// index entry has already been deleted here.
func (tm *TunnelManager) CloseAllChannels(tunnel *Tunnel) {
	targetId := tunnel.Data.TunnelId

	var matched []string
	tm.channelIndex.Range(func(k, v any) bool {
		if v.(*ChannelEntry).TunnelId == targetId {
			matched = append(matched, k.(string))
		}
		return true
	})

	for _, key := range matched {
		removed, existed := tm.channelIndex.LoadAndDelete(key)
		if !existed {
			continue
		}
		tm.closeChannelInternal(nil, removed.(*ChannelEntry).Channel)
		tm.stats.ActiveChannels.Add(-1)
	}
	tunnel.connections.CutMap()
}
// WriteToChannel delivers data received for (tunnelId, channelId) to that
// channel and reports whether the data was accepted. See writeToChannelEntry
// for the delivery rules.
func (tm *TunnelManager) WriteToChannel(tunnelId string, channelId int, data []byte) bool {
	entry, ok := tm.GetChannel(tunnelId, channelId)
	if !ok {
		return false
	}
	return tm.writeToChannelEntry(entry, channelId, data)
}

// WriteToChannelByIdOnly behaves like WriteToChannel for callers that only
// know the channel id. The entry is resolved with GetChannelByIdOnly, which
// scans all registered channels.
func (tm *TunnelManager) WriteToChannelByIdOnly(channelId int, data []byte) bool {
	entry, ok := tm.GetChannelByIdOnly(channelId)
	if !ok {
		return false
	}
	return tm.writeToChannelEntry(entry, channelId, data)
}

// writeToChannelEntry holds the delivery logic shared by WriteToChannel and
// WriteToChannelByIdOnly.
//
// If the channel has an ingress queue:
//   - When the queue is more than tunnelIngressHiWM percent full, a pause
//     request is sent once, guarded by flowPaused.
//   - The data is then enqueued without blocking. A full queue, or one closed
//     concurrently by SafeTunnelChannel.Close, rejects the data.
//
// Otherwise the data is written synchronously to the channel's pwTun pipe.
// TotalBytesRecv is incremented only for accepted data.
func (tm *TunnelManager) writeToChannelEntry(entry *ChannelEntry, channelId int, data []byte) bool {
	ch := entry.Channel
	if ch == nil {
		return false
	}

	// Read the queue field exactly once. SafeTunnelChannel.Close closes the
	// queue and then resets this field to nil, so re-reading it between the
	// checks and the send could observe different values.
	if ingress := ch.ingressChan; ingress != nil {
		if limit := cap(ingress); limit > 0 && len(ingress) > (limit*tunnelIngressHiWM)/100 {
			if ch.flowPaused.CompareAndSwap(false, true) {
				tm.SendTunnelFlowControl(channelId, true)
			}
		}
		if !trySendIngress(ingress, data) {
			return false
		}
		tm.stats.TotalBytesRecv.Add(uint64(len(data)))
		return true
	}

	if ch.pwTun == nil {
		return false
	}
	if _, err := ch.pwTun.Write(data); err != nil {
		return false
	}
	tm.stats.TotalBytesRecv.Add(uint64(len(data)))
	return true
}

// trySendIngress performs a non-blocking send of data on ingress. It returns
// false when the queue is full or has already been closed. Without the
// recover, a send racing with SafeTunnelChannel.Close would panic with
// "send on closed channel" and take down the caller's goroutine.
func trySendIngress(ingress chan<- []byte, data []byte) (sent bool) {
	defer func() {
		if recover() != nil {
			sent = false
		}
	}()
	select {
	case ingress <- data:
		return true
	default:
		return false
	}
}
// GetChannelPipes returns the channel's prSrv reader and pwTun writer for
// (tunnelId, channelId). It returns ErrChannelNotFound when the channel is not
// registered.
func (tm *TunnelManager) GetChannelPipes(tunnelId string, channelId int) (*io.PipeReader, *io.PipeWriter, error) {
	if entry, found := tm.GetChannel(tunnelId, channelId); found {
		ch := entry.Channel
		return ch.prSrv, ch.pwTun, nil
	}
	return nil, nil, ErrChannelNotFound
}
// GetChannelPipesByIdOnly behaves like GetChannelPipes, but resolves the
// channel with GetChannelByIdOnly (a scan by channel id alone).
func (tm *TunnelManager) GetChannelPipesByIdOnly(channelId int) (*io.PipeReader, *io.PipeWriter, error) {
	if entry, found := tm.GetChannelByIdOnly(channelId); found {
		ch := entry.Channel
		return ch.prSrv, ch.pwTun, nil
	}
	return nil, nil, ErrChannelNotFound
}
// GetStats returns a pointer to the manager's live counters (not a snapshot).
// Read the fields with their atomic Load methods.
func (tm *TunnelManager) GetStats() *TunnelStats {
	live := &tm.stats
	return live
}
// PauseChannel sets the paused flag on the channel found by
// GetChannelByIdOnly. Unknown ids are ignored. This flag is separate from
// flowPaused, which the ingress watermark logic uses.
func (tm *TunnelManager) PauseChannel(channelId int) {
	entry, found := tm.GetChannelByIdOnly(channelId)
	if !found || entry.Channel == nil {
		return
	}
	entry.Channel.paused.Store(true)
}
// ResumeChannel clears the paused flag on the channel found by
// GetChannelByIdOnly. Unknown ids are ignored.
func (tm *TunnelManager) ResumeChannel(channelId int) {
	entry, found := tm.GetChannelByIdOnly(channelId)
	if !found || entry.Channel == nil {
		return
	}
	entry.Channel.paused.Store(false)
}
// ListTunnels returns a copy of the Data of every registered tunnel, in
// sync.Map.Range order. It returns nil when no tunnels are registered.
func (tm *TunnelManager) ListTunnels() []adaptix.TunnelData {
	var snapshot []adaptix.TunnelData
	tm.tunnels.Range(func(_, v any) bool {
		snapshot = append(snapshot, v.(*Tunnel).Data)
		return true
	})
	return snapshot
}
/// SAFE TUNNEL CHANNEL
// SafeTunnelChannel wraps a *TunnelChannel with close-once semantics and an
// ingress pump goroutine, which is started by NewSafeTunnelChannel.
type SafeTunnelChannel struct {
	*TunnelChannel
	// tm is used by ingressPump to send resume requests; it may be nil.
	tm *TunnelManager
	// mu is held by Close while the connections and pipes are being closed.
	mu sync.Mutex
	// closed is set by the first Close call; later calls return false.
	closed atomic.Bool
	// closing is set by Close; it is not read anywhere in this section.
	closing atomic.Bool
	// ctx is canceled by Close and exposed through Context.
	ctx context.Context
	cancel context.CancelFunc
}
// NewSafeTunnelChannel builds a SafeTunnelChannel around conn/wsconn and
// starts its ingress pump. It:
//   - creates the server-side (prSrv/pwSrv) and tunnel-side (prTun/pwTun)
//     pipe pairs;
//   - allocates an ingress queue of tunnelIngressQueueDepth entries;
//   - starts ingressPump in a new goroutine.
//
// The channel is not registered with tm here.
// NOTE(review): presumably callers register it via RegisterChannel; confirm.
func NewSafeTunnelChannel(tm *TunnelManager, channelId int, conn net.Conn, wsconn *websocket.Conn, protocol string) *SafeTunnelChannel {
	base := &TunnelChannel{
		channelId: channelId,
		conn:      conn,
		wsconn:    wsconn,
		protocol:  protocol,
	}
	base.prSrv, base.pwSrv = io.Pipe()
	base.prTun, base.pwTun = io.Pipe()
	base.ingressChan = make(chan []byte, tunnelIngressQueueDepth)

	ctx, cancel := context.WithCancel(context.Background())
	stc := &SafeTunnelChannel{
		TunnelChannel: base,
		tm:            tm,
		ctx:           ctx,
		cancel:        cancel,
	}
	go stc.ingressPump()
	return stc
}
// ingressPump drains the ingress queue in order. It forwards each chunk to the
// first available sink: conn, else wsconn (as one binary WebSocket message),
// else the pwTun pipe.
//
// It stops when the queue is closed or a sink write fails. On exit it closes
// pwTun. While flow is paused, it sends a resume request once the queue drops
// below tunnelIngressLoWM percent of its capacity.
func (stc *SafeTunnelChannel) ingressPump() {
	// Read the queue field exactly once. Close() closes the channel and then
	// resets stc.ingressChan to nil from another goroutine, so reading the
	// field inside the loop would be a data race.
	queue := stc.ingressChan

	defer func() {
		if stc.pwTun != nil {
			_ = stc.pwTun.Close()
		}
	}()

	// If Close() already ran, the field is nil. Ranging over a nil channel
	// blocks forever and would leak this goroutine.
	if queue == nil {
		return
	}

	// Fill level below which a paused channel is resumed. It is 0 for an
	// unbuffered queue, so the resume branch is never taken in that case.
	resumeBelow := (cap(queue) * tunnelIngressLoWM) / 100

	for data := range queue {
		if stc.flowPaused.Load() && len(queue) < resumeBelow {
			if stc.flowPaused.CompareAndSwap(true, false) && stc.tm != nil {
				stc.tm.SendTunnelFlowControl(stc.channelId, false)
			}
		}

		var err error
		switch {
		case stc.conn != nil:
			_, err = stc.conn.Write(data)
		case stc.wsconn != nil:
			err = stc.wsconn.WriteMessage(websocket.BinaryMessage, data)
		case stc.pwTun != nil:
			_, err = stc.pwTun.Write(data)
		}
		if err != nil {
			return
		}
	}
}
// Close shuts the channel down exactly once. Only the first call returns true.
//
// Steps, in order:
//  1. Mark the channel as closing.
//  2. Close the ingress queue (which ends ingressPump's loop) and nil out the
//     field.
//  3. Cancel the channel context.
//  4. Under mu, close conn, wsconn and all four pipe ends. Close errors are
//     ignored.
//
// Later calls return false and do nothing.
func (stc *SafeTunnelChannel) Close() bool {
	alreadyClosed := stc.closed.Swap(true)
	if alreadyClosed {
		return false
	}
	stc.closing.Store(true)

	if queue := stc.ingressChan; queue != nil {
		close(queue)
		stc.ingressChan = nil
	}
	stc.cancel()

	stc.mu.Lock()
	defer stc.mu.Unlock()

	if c := stc.conn; c != nil {
		_ = c.Close()
	}
	if ws := stc.wsconn; ws != nil {
		_ = ws.Close()
	}
	if w := stc.pwTun; w != nil {
		_ = w.Close()
	}
	if r := stc.prTun; r != nil {
		_ = r.Close()
	}
	if w := stc.pwSrv; w != nil {
		_ = w.Close()
	}
	if r := stc.prSrv; r != nil {
		_ = r.Close()
	}
	return true
}
// IsClosed reports whether Close has been called on this channel.
func (stc *SafeTunnelChannel) IsClosed() bool {
	done := stc.closed.Load()
	return done
}
// Context returns the channel's context, which is canceled when Close runs.
func (stc *SafeTunnelChannel) Context() context.Context {
	lifetime := stc.ctx
	return lifetime
}
/// UTILS
// Sentinel errors for tunnel operations. ErrChannelNotFound is returned by
// GetChannelPipes and GetChannelPipesByIdOnly; the other three have no users
// in this section of the file.
var (
	ErrChannelNotFound     = errorString("tunnel channel not found")
	ErrTunnelNotFound      = errorString("tunnel not found")
	ErrAgentNotFound       = errorString("agent not found")
	ErrTunnelAlreadyActive = errorString("tunnel already active")
)

// errorString is a string-backed error type. Values are comparable, so the
// sentinels above can be matched with == or errors.Is.
type errorString string

// Error implements the error interface by returning the message unchanged.
func (msg errorString) Error() string {
	return string(msg)
}