2026-04-06 00:20:51 -05:00

258 lines
5.3 KiB
Go

package server
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"strconv"
"sync"
"github.com/gorilla/websocket"
)
var (
ErrTerminalNotFound = errors.New("terminal not found")
ErrInvalidTerminalType = errors.New("invalid terminal type")
)
type TerminalChannelData struct {
AgentId string `json:"agent_id"`
TerminalId string `json:"terminal_id"`
Program string `json:"program"`
SizeH int `json:"size_h"`
SizeW int `json:"size_w"`
OemCP int `json:"oem_cp"`
}
func (ts *Teamserver) TsAgentTerminalCreateChannel(terminalData string, wsconn *websocket.Conn) error {
var td TerminalChannelData
if err := json.Unmarshal([]byte(terminalData), &td); err != nil {
return errors.New("invalid terminal data")
}
termId, err := strconv.ParseInt(td.TerminalId, 16, 64)
if err != nil {
return errors.New("TerminalId not supported")
}
agent, err := ts.getAgent(td.AgentId)
if err != nil {
return err
}
if !agent.Active {
return fmt.Errorf("agent '%v' not active", td.AgentId)
}
terminal := &Terminal{
TerminalId: int(termId),
agent: agent,
wsconn: wsconn,
CodePage: td.OemCP,
}
terminal.prSrv, terminal.pwSrv = io.Pipe()
terminal.prTun, terminal.pwTun = io.Pipe()
terminal.Callbacks = agent.TerminalCallbacks()
taskData := terminal.Callbacks.Start(terminal.TerminalId, td.Program, td.SizeH, td.SizeW, td.OemCP)
tunnelManageTask(agent, taskData)
ts.terminals.Put(td.TerminalId, terminal)
return nil
}
func (ts *Teamserver) TsAgentTerminalCloseChannel(terminalId string, status string) error {
_ = ts.TsTerminalConnClose(terminalId, status)
return nil
}
///
func (ts *Teamserver) TsTerminalConnExists(terminalId string) bool {
return ts.terminals.Contains(terminalId)
}
func (ts *Teamserver) TsTerminalGetPipe(AgentId string, terminalId string) (*io.PipeReader, *io.PipeWriter, error) {
value, ok := ts.terminals.Get(terminalId)
if !ok {
return nil, nil, ErrTerminalNotFound
}
terminal, ok := value.(*Terminal)
if !ok {
return nil, nil, ErrInvalidTerminalType
}
return terminal.prSrv, terminal.pwTun, nil
}
func (ts *Teamserver) TsTerminalConnResume(agentId string, terminalId string, ioDirect bool) {
agent, err := ts.getAgent(agentId)
if err != nil {
return
}
value, ok := ts.terminals.Get(terminalId)
if !ok {
return
}
terminal, ok := value.(*Terminal)
if !ok {
return
}
relayWebsocketToTerminal(ts, agent, terminal, terminalId, ioDirect)
}
func (ts *Teamserver) TsTerminalConnData(terminalId string, data []byte) {
value, ok := ts.terminals.Get(terminalId)
if !ok {
return
}
terminal, ok := value.(*Terminal)
if !ok {
return
}
terminal.mu.Lock()
defer terminal.mu.Unlock()
if terminal.closed || terminal.pwTun == nil {
return
}
_, _ = terminal.pwTun.Write(data)
}
func (ts *Teamserver) TsTerminalConnClose(terminalId string, status string) error {
value, ok := ts.terminals.GetDelete(terminalId)
if !ok {
return ErrTerminalNotFound
}
terminal, ok := value.(*Terminal)
if !ok {
return ErrInvalidTerminalType
}
closeTerminalResources(terminal)
return nil
}
func closeTerminalResources(terminal *Terminal) {
if terminal == nil {
return
}
terminal.mu.Lock()
if terminal.closed {
terminal.mu.Unlock()
return
}
terminal.closed = true
terminal.mu.Unlock()
if terminal.wsconn != nil {
_ = terminal.wsconn.Close()
}
if terminal.pwTun != nil {
_ = terminal.pwTun.Close()
}
if terminal.prTun != nil {
_ = terminal.prTun.Close()
}
if terminal.pwSrv != nil {
_ = terminal.pwSrv.Close()
}
if terminal.prSrv != nil {
_ = terminal.prSrv.Close()
}
}
///
func relayWebsocketToTerminal(ts *Teamserver, agent *Agent, terminal *Terminal, terminalId string, direct bool) {
ctx, cancel := context.WithCancel(context.Background())
var once sync.Once
var wsWriteMu sync.Mutex
finish := func() {
once.Do(func() {
cancel()
ts.terminals.Delete(terminalId)
closeTerminalResources(terminal)
taskData := terminal.Callbacks.Close(terminal.TerminalId)
tunnelManageTask(agent, taskData)
})
}
go func() {
defer finish()
if terminal.wsconn == nil || terminal.pwSrv == nil {
return
}
for {
terminal.mu.Lock()
closed := terminal.closed
terminal.mu.Unlock()
if closed {
break
}
_, msg, err := terminal.wsconn.ReadMessage()
if err != nil {
break
}
if _, err := terminal.pwSrv.Write(msg); err != nil {
break
}
}
_ = terminal.pwSrv.Close()
}()
go func() {
defer finish()
if terminal.wsconn == nil || terminal.prTun == nil {
return
}
buf := ts.TunnelManager.GetBuffer()
defer ts.TunnelManager.PutBuffer(buf)
for {
n, err := terminal.prTun.Read(buf)
if n > 0 {
wsWriteMu.Lock()
writeErr := terminal.wsconn.WriteMessage(websocket.BinaryMessage, buf[:n])
wsWriteMu.Unlock()
if writeErr != nil {
break
}
}
if err != nil {
break
}
}
}()
if !direct {
go func() {
buf := ts.TunnelManager.GetBuffer()
defer ts.TunnelManager.PutBuffer(buf)
for {
select {
case <-ctx.Done():
return
default:
n, err := terminal.prSrv.Read(buf)
if n > 0 {
taskData := terminal.Callbacks.Write(terminal.TerminalId, terminal.CodePage, buf[:n])
relayPipeToTaskData(agent, terminal.TerminalId, taskData)
}
if err != nil {
return
}
}
}
}()
}
}