508 lines
11 KiB
Go
508 lines
11 KiB
Go
package server
|
|
|
|
import (
|
|
"context"
|
|
"io"
|
|
"net"
|
|
"strconv"
|
|
"sync"
|
|
"sync/atomic"
|
|
|
|
"github.com/Adaptix-Framework/axc2"
|
|
"github.com/gorilla/websocket"
|
|
)
|
|
|
|
// Buffer pool sizing.
const (
	// TunnelBufferSize is the length of every scratch buffer created by the
	// TunnelManager buffer pool (32 KiB).
	TunnelBufferSize = 32 * 1024

	// TunnelBufferPoolCap is not referenced in this part of the file.
	// NOTE(review): confirm where (or whether) it is used.
	TunnelBufferPoolCap = 256
)

// Per-channel ingress queue settings. Both watermarks are percentages of the
// queue capacity: a writer requests a pause from the agent once the backlog
// exceeds the high mark, and the ingress pump requests a resume once the
// backlog falls below the low mark.
const (
	tunnelIngressQueueDepth = 1024
	tunnelIngressHiWM       = 80
	tunnelIngressLoWM       = 20
)
|
|
|
|
type TunnelManager struct {
|
|
ts *Teamserver
|
|
|
|
tunnels sync.Map // tunnelId string -> *Tunnel
|
|
channelIndex sync.Map // "tunnelId:channelId" string -> *ChannelEntry
|
|
|
|
bufferPool sync.Pool
|
|
|
|
stats TunnelStats
|
|
}
|
|
|
|
func (tm *TunnelManager) SendTunnelFlowControl(channelId int, pause bool) {
|
|
entry, ok := tm.GetChannelByIdOnly(channelId)
|
|
if !ok || entry.Tunnel == nil {
|
|
return
|
|
}
|
|
if tm.ts == nil {
|
|
return
|
|
}
|
|
agent, err := tm.ts.getAgent(entry.Tunnel.Data.AgentId)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
var task adaptix.TaskData
|
|
if pause {
|
|
task = entry.Tunnel.Callbacks.Pause(channelId)
|
|
} else {
|
|
task = entry.Tunnel.Callbacks.Resume(channelId)
|
|
}
|
|
|
|
if task.Type == 0 {
|
|
return
|
|
}
|
|
tunnelManageTask(agent, task)
|
|
}
|
|
|
|
// channelKey builds the channelIndex key "<tunnelId>:<channelId>", with the
// channel ID written in base 10.
func channelKey(tunnelId string, channelId int) string {
	// 20 bytes is enough for any int64 in decimal, including the sign.
	key := make([]byte, 0, len(tunnelId)+1+20)
	key = append(key, tunnelId...)
	key = append(key, ':')
	key = strconv.AppendInt(key, int64(channelId), 10)
	return string(key)
}
|
|
|
|
type ChannelEntry struct {
|
|
TunnelId string
|
|
Tunnel *Tunnel
|
|
Channel *TunnelChannel
|
|
}
|
|
|
|
type TunnelStats struct {
|
|
ActiveTunnels atomic.Int64
|
|
ActiveChannels atomic.Int64
|
|
TotalBytesSent atomic.Uint64
|
|
TotalBytesRecv atomic.Uint64
|
|
}
|
|
|
|
type TunnelChannelSafe struct {
|
|
TunnelChannel
|
|
mu sync.Mutex
|
|
closed atomic.Bool
|
|
ctx context.Context
|
|
cancel context.CancelFunc
|
|
}
|
|
|
|
// NewTunnelManager returns an empty TunnelManager bound to ts. Its buffer
// pool lazily allocates TunnelBufferSize-byte slices, stored as *[]byte.
func NewTunnelManager(ts *Teamserver) *TunnelManager {
	newBuffer := func() interface{} {
		b := make([]byte, TunnelBufferSize)
		return &b
	}

	return &TunnelManager{
		ts:         ts,
		bufferPool: sync.Pool{New: newBuffer},
	}
}
|
|
|
|
func (tm *TunnelManager) GetBuffer() []byte {
|
|
return *tm.bufferPool.Get().(*[]byte)
|
|
}
|
|
|
|
func (tm *TunnelManager) PutBuffer(buf []byte) {
|
|
if cap(buf) >= TunnelBufferSize {
|
|
tm.bufferPool.Put(&buf)
|
|
}
|
|
}
|
|
|
|
func (tm *TunnelManager) GetTunnel(tunnelId string) (*Tunnel, bool) {
|
|
value, ok := tm.tunnels.Load(tunnelId)
|
|
if !ok {
|
|
return nil, false
|
|
}
|
|
return value.(*Tunnel), true
|
|
}
|
|
|
|
func (tm *TunnelManager) PutTunnel(tunnel *Tunnel) {
|
|
tm.tunnels.Store(tunnel.Data.TunnelId, tunnel)
|
|
tm.stats.ActiveTunnels.Add(1)
|
|
}
|
|
|
|
func (tm *TunnelManager) DeleteTunnel(tunnelId string) (*Tunnel, bool) {
|
|
value, ok := tm.tunnels.LoadAndDelete(tunnelId)
|
|
if !ok {
|
|
return nil, false
|
|
}
|
|
tm.stats.ActiveTunnels.Add(-1)
|
|
return value.(*Tunnel), true
|
|
}
|
|
|
|
func (tm *TunnelManager) TunnelExists(tunnelId string) bool {
|
|
_, ok := tm.tunnels.Load(tunnelId)
|
|
return ok
|
|
}
|
|
|
|
func (tm *TunnelManager) ForEachTunnel(fn func(tunnelId string, tunnel *Tunnel) bool) {
|
|
tm.tunnels.Range(func(key, value interface{}) bool {
|
|
return fn(key.(string), value.(*Tunnel))
|
|
})
|
|
}
|
|
|
|
// RegisterChannel records channel in two places:
//   - the manager's index, under channelKey(tunnelId, channel.channelId);
//   - tunnel.connections, keyed by the decimal channel ID.
//
// ActiveChannels is incremented only when the index key was not already
// present; re-registering a key replaces the old entry and leaves the
// counter unchanged. UnregisterChannel and CloseAllChannels decrement once
// per removed key, so this keeps the counter balanced.
func (tm *TunnelManager) RegisterChannel(tunnelId string, tunnel *Tunnel, channel *TunnelChannel) {
	entry := &ChannelEntry{
		TunnelId: tunnelId,
		Tunnel:   tunnel,
		Channel:  channel,
	}

	key := channelKey(tunnelId, channel.channelId)
	_, replaced := tm.channelIndex.Swap(key, entry)

	tunnel.connections.Put(strconv.Itoa(channel.channelId), channel)

	if !replaced {
		tm.stats.ActiveChannels.Add(1)
	}
}
|
|
|
|
// UnregisterChannel removes the (tunnelId, channelId) entry from the index
// and from its tunnel's connections, then decrements ActiveChannels.
// It is a no-op when the entry is not registered, and it does not close
// any of the channel's connections or pipes.
func (tm *TunnelManager) UnregisterChannel(tunnelId string, channelId int) {
	removed, found := tm.channelIndex.LoadAndDelete(channelKey(tunnelId, channelId))
	if !found {
		return
	}
	removed.(*ChannelEntry).Tunnel.connections.Delete(strconv.Itoa(channelId))
	tm.stats.ActiveChannels.Add(-1)
}
|
|
|
|
func (tm *TunnelManager) GetChannel(tunnelId string, channelId int) (*ChannelEntry, bool) {
|
|
key := channelKey(tunnelId, channelId)
|
|
value, ok := tm.channelIndex.Load(key)
|
|
if !ok {
|
|
return nil, false
|
|
}
|
|
return value.(*ChannelEntry), true
|
|
}
|
|
|
|
func (tm *TunnelManager) GetChannelByIdOnly(channelId int) (*ChannelEntry, bool) {
|
|
var result *ChannelEntry
|
|
found := false
|
|
tm.channelIndex.Range(func(key, value interface{}) bool {
|
|
entry := value.(*ChannelEntry)
|
|
if entry.Channel.channelId == channelId {
|
|
result = entry
|
|
found = true
|
|
return false
|
|
}
|
|
return true
|
|
})
|
|
return result, found
|
|
}
|
|
|
|
func (tm *TunnelManager) ChannelExists(tunnelId string, channelId int) bool {
|
|
key := channelKey(tunnelId, channelId)
|
|
_, ok := tm.channelIndex.Load(key)
|
|
return ok
|
|
}
|
|
|
|
// ChannelExistsInTunnel reports whether tunnel.connections holds channelId,
// keyed by its decimal string. Unlike ChannelExists, it consults the
// tunnel's own map rather than the manager's index.
func (tm *TunnelManager) ChannelExistsInTunnel(tunnel *Tunnel, channelId int) bool {
	key := strconv.Itoa(channelId)
	return tunnel.connections.Contains(key)
}
|
|
|
|
func (tm *TunnelManager) CloseChannel(tunnelId string, channelId int) {
|
|
entry, ok := tm.GetChannel(tunnelId, channelId)
|
|
if !ok {
|
|
return
|
|
}
|
|
tm.closeChannelInternal(entry.Tunnel, entry.Channel)
|
|
}
|
|
|
|
func (tm *TunnelManager) CloseChannelByIdOnly(channelId int, writeOnly bool) {
|
|
entry, ok := tm.GetChannelByIdOnly(channelId)
|
|
if !ok {
|
|
return
|
|
}
|
|
if writeOnly {
|
|
if entry.Channel != nil && entry.Channel.pwTun != nil {
|
|
_ = entry.Channel.pwTun.Close()
|
|
}
|
|
} else {
|
|
tm.closeChannelInternal(entry.Tunnel, entry.Channel)
|
|
}
|
|
}
|
|
|
|
func (tm *TunnelManager) closeChannelInternal(tunnel *Tunnel, channel *TunnelChannel) {
|
|
if channel == nil {
|
|
return
|
|
}
|
|
|
|
if channel.conn != nil {
|
|
_ = channel.conn.Close()
|
|
}
|
|
if channel.wsconn != nil {
|
|
_ = channel.wsconn.Close()
|
|
}
|
|
|
|
if channel.pwTun != nil {
|
|
_ = channel.pwTun.Close()
|
|
}
|
|
if channel.prTun != nil {
|
|
_ = channel.prTun.Close()
|
|
}
|
|
if channel.pwSrv != nil {
|
|
_ = channel.pwSrv.Close()
|
|
}
|
|
if channel.prSrv != nil {
|
|
_ = channel.prSrv.Close()
|
|
}
|
|
|
|
if tunnel != nil {
|
|
tm.UnregisterChannel(tunnel.Data.TunnelId, channel.channelId)
|
|
}
|
|
}
|
|
|
|
func (tm *TunnelManager) CloseAllChannels(tunnel *Tunnel) {
|
|
var channelKeys []string
|
|
tm.channelIndex.Range(func(key, value interface{}) bool {
|
|
entry := value.(*ChannelEntry)
|
|
if entry.TunnelId == tunnel.Data.TunnelId {
|
|
channelKeys = append(channelKeys, key.(string))
|
|
}
|
|
return true
|
|
})
|
|
|
|
for _, k := range channelKeys {
|
|
if value, ok := tm.channelIndex.LoadAndDelete(k); ok {
|
|
entry := value.(*ChannelEntry)
|
|
tm.closeChannelInternal(nil, entry.Channel)
|
|
tm.stats.ActiveChannels.Add(-1)
|
|
}
|
|
}
|
|
|
|
tunnel.connections.CutMap()
|
|
}
|
|
|
|
func (tm *TunnelManager) WriteToChannel(tunnelId string, channelId int, data []byte) bool {
|
|
entry, ok := tm.GetChannel(tunnelId, channelId)
|
|
if !ok {
|
|
return false
|
|
}
|
|
|
|
if entry.Channel != nil && entry.Channel.ingressChan != nil {
|
|
curLen := len(entry.Channel.ingressChan)
|
|
capLen := cap(entry.Channel.ingressChan)
|
|
if capLen > 0 && curLen > (capLen*tunnelIngressHiWM)/100 {
|
|
if entry.Channel.flowPaused.CompareAndSwap(false, true) {
|
|
tm.SendTunnelFlowControl(channelId, true)
|
|
}
|
|
}
|
|
|
|
select {
|
|
case entry.Channel.ingressChan <- data:
|
|
tm.stats.TotalBytesRecv.Add(uint64(len(data)))
|
|
return true
|
|
default:
|
|
return false
|
|
}
|
|
}
|
|
|
|
if entry.Channel != nil && entry.Channel.pwTun != nil {
|
|
_, err := entry.Channel.pwTun.Write(data)
|
|
if err == nil {
|
|
tm.stats.TotalBytesRecv.Add(uint64(len(data)))
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
func (tm *TunnelManager) WriteToChannelByIdOnly(channelId int, data []byte) bool {
|
|
entry, ok := tm.GetChannelByIdOnly(channelId)
|
|
if !ok {
|
|
return false
|
|
}
|
|
|
|
if entry.Channel != nil && entry.Channel.ingressChan != nil {
|
|
curLen := len(entry.Channel.ingressChan)
|
|
capLen := cap(entry.Channel.ingressChan)
|
|
if capLen > 0 && curLen > (capLen*tunnelIngressHiWM)/100 {
|
|
if entry.Channel.flowPaused.CompareAndSwap(false, true) {
|
|
tm.SendTunnelFlowControl(channelId, true)
|
|
}
|
|
}
|
|
|
|
select {
|
|
case entry.Channel.ingressChan <- data:
|
|
tm.stats.TotalBytesRecv.Add(uint64(len(data)))
|
|
return true
|
|
default:
|
|
return false
|
|
}
|
|
}
|
|
|
|
if entry.Channel != nil && entry.Channel.pwTun != nil {
|
|
_, err := entry.Channel.pwTun.Write(data)
|
|
if err == nil {
|
|
tm.stats.TotalBytesRecv.Add(uint64(len(data)))
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
// GetChannelPipes returns the prSrv reader and pwTun writer of the channel
// registered under (tunnelId, channelId). It returns ErrChannelNotFound when
// that channel is not registered.
func (tm *TunnelManager) GetChannelPipes(tunnelId string, channelId int) (*io.PipeReader, *io.PipeWriter, error) {
	entry, found := tm.GetChannel(tunnelId, channelId)
	if found {
		ch := entry.Channel
		return ch.prSrv, ch.pwTun, nil
	}
	return nil, nil, ErrChannelNotFound
}
|
|
|
|
// GetChannelPipesByIdOnly is GetChannelPipes for callers that only know the
// channel ID. The channel is resolved with GetChannelByIdOnly, i.e. the
// first match in any tunnel.
func (tm *TunnelManager) GetChannelPipesByIdOnly(channelId int) (*io.PipeReader, *io.PipeWriter, error) {
	entry, found := tm.GetChannelByIdOnly(channelId)
	if found {
		ch := entry.Channel
		return ch.prSrv, ch.pwTun, nil
	}
	return nil, nil, ErrChannelNotFound
}
|
|
|
|
func (tm *TunnelManager) GetStats() *TunnelStats {
|
|
return &tm.stats
|
|
}
|
|
|
|
func (tm *TunnelManager) PauseChannel(channelId int) {
|
|
entry, ok := tm.GetChannelByIdOnly(channelId)
|
|
if ok && entry.Channel != nil {
|
|
entry.Channel.paused.Store(true)
|
|
}
|
|
}
|
|
|
|
func (tm *TunnelManager) ResumeChannel(channelId int) {
|
|
entry, ok := tm.GetChannelByIdOnly(channelId)
|
|
if ok && entry.Channel != nil {
|
|
entry.Channel.paused.Store(false)
|
|
}
|
|
}
|
|
|
|
func (tm *TunnelManager) ListTunnels() []adaptix.TunnelData {
|
|
var tunnels []adaptix.TunnelData
|
|
tm.tunnels.Range(func(key, value interface{}) bool {
|
|
tunnel := value.(*Tunnel)
|
|
tunnels = append(tunnels, tunnel.Data)
|
|
return true
|
|
})
|
|
return tunnels
|
|
}
|
|
|
|
/// SAFE TUNNEL CHANNEL
|
|
|
|
type SafeTunnelChannel struct {
|
|
*TunnelChannel
|
|
tm *TunnelManager
|
|
mu sync.Mutex
|
|
closed atomic.Bool
|
|
closing atomic.Bool
|
|
ctx context.Context
|
|
cancel context.CancelFunc
|
|
}
|
|
|
|
// NewSafeTunnelChannel creates a channel wrapper for channelId over the given
// connection(s) and protocol. It:
//   - allocates the server and tunnel pipe pairs;
//   - creates an ingress queue of tunnelIngressQueueDepth entries;
//   - starts the ingress pump goroutine.
//
// Call Close to release it.
func NewSafeTunnelChannel(tm *TunnelManager, channelId int, conn net.Conn, wsconn *websocket.Conn, protocol string) *SafeTunnelChannel {
	inner := &TunnelChannel{
		channelId: channelId,
		conn:      conn,
		wsconn:    wsconn,
		protocol:  protocol,
	}
	inner.prSrv, inner.pwSrv = io.Pipe()
	inner.prTun, inner.pwTun = io.Pipe()
	inner.ingressChan = make(chan []byte, tunnelIngressQueueDepth)

	ctx, cancel := context.WithCancel(context.Background())
	stc := &SafeTunnelChannel{
		TunnelChannel: inner,
		tm:            tm,
		ctx:           ctx,
		cancel:        cancel,
	}

	go stc.ingressPump()
	return stc
}
|
|
|
|
func (stc *SafeTunnelChannel) ingressPump() {
|
|
defer func() {
|
|
if stc.pwTun != nil {
|
|
_ = stc.pwTun.Close()
|
|
}
|
|
}()
|
|
|
|
for data := range stc.ingressChan {
|
|
// RESUME
|
|
if stc.flowPaused.Load() {
|
|
capLen := cap(stc.ingressChan)
|
|
if capLen > 0 && len(stc.ingressChan) < (capLen*tunnelIngressLoWM)/100 {
|
|
if stc.flowPaused.CompareAndSwap(true, false) {
|
|
if stc.tm != nil {
|
|
stc.tm.SendTunnelFlowControl(stc.channelId, false)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
if stc.conn != nil {
|
|
if _, err := stc.conn.Write(data); err != nil {
|
|
return
|
|
}
|
|
} else if stc.wsconn != nil {
|
|
if err := stc.wsconn.WriteMessage(websocket.BinaryMessage, data); err != nil {
|
|
return
|
|
}
|
|
} else if stc.pwTun != nil {
|
|
if _, err := stc.pwTun.Write(data); err != nil {
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (stc *SafeTunnelChannel) Close() bool {
|
|
if stc.closed.Swap(true) {
|
|
return false
|
|
}
|
|
|
|
stc.closing.Store(true)
|
|
if stc.ingressChan != nil {
|
|
close(stc.ingressChan)
|
|
stc.ingressChan = nil
|
|
}
|
|
|
|
stc.cancel()
|
|
|
|
stc.mu.Lock()
|
|
defer stc.mu.Unlock()
|
|
|
|
if stc.conn != nil {
|
|
_ = stc.conn.Close()
|
|
}
|
|
if stc.wsconn != nil {
|
|
_ = stc.wsconn.Close()
|
|
}
|
|
if stc.pwTun != nil {
|
|
_ = stc.pwTun.Close()
|
|
}
|
|
if stc.prTun != nil {
|
|
_ = stc.prTun.Close()
|
|
}
|
|
if stc.pwSrv != nil {
|
|
_ = stc.pwSrv.Close()
|
|
}
|
|
if stc.prSrv != nil {
|
|
_ = stc.prSrv.Close()
|
|
}
|
|
|
|
return true
|
|
}
|
|
|
|
func (stc *SafeTunnelChannel) IsClosed() bool {
|
|
return stc.closed.Load()
|
|
}
|
|
|
|
// Context returns the channel's context, which Close cancels.
func (stc *SafeTunnelChannel) Context() context.Context {
	ctx := stc.ctx
	return ctx
}
|
|
|
|
/// UTILS
|
|
|
|
// Sentinel errors used by the tunnel code.
var (
	ErrChannelNotFound     = errorString("tunnel channel not found")
	ErrTunnelNotFound      = errorString("tunnel not found")
	ErrAgentNotFound       = errorString("agent not found")
	ErrTunnelAlreadyActive = errorString("tunnel already active")
)

// errorString is a string-backed error type. Two values are equal when
// their text is equal, so the sentinels above can be compared with == or
// errors.Is.
type errorString string

// Error implements the error interface by returning the text unchanged.
func (e errorString) Error() string { return string(e) }
|