Files
conduit/ssh/context.go
2025-10-10 10:05:28 +02:00

205 lines
4.4 KiB
Go

package ssh
import (
	"encoding/binary"
	"encoding/hex"
	"io"
	"math/rand"
	"net"
	"sync"
	"time"

	"git.maze.io/maze/conduit/logger"
	"golang.org/x/crypto/ssh"
)
// seed is the random source used to draw connection identifiers (see
// newSSHContext).
//
// A Source returned by rand.NewSource is not safe for concurrent use. Contexts
// are presumably created from per-connection goroutines, so the source is
// wrapped in lockedSource. The variable keeps its static type rand.Source, so
// existing uses such as seed.Int63() keep compiling unchanged.
var seed rand.Source = &lockedSource{src: rand.NewSource(time.Now().UnixNano())}

// lockedSource is a rand.Source that serializes every call to an underlying
// source which is not itself safe for concurrent use.
type lockedSource struct {
	mu  sync.Mutex
	src rand.Source
}

// Int63 returns a non-negative pseudo-random 63-bit integer from the
// underlying source while holding the lock.
func (s *lockedSource) Int63() int64 {
	s.mu.Lock()
	defer s.mu.Unlock()
	return s.src.Int63()
}

// Seed reseeds the underlying source while holding the lock.
func (s *lockedSource) Seed(v int64) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.src.Seed(v)
}
// Context describes one SSH client connection. It is handed to channel and
// port-forward handlers. Connection metadata (user, session ID, version
// strings, addresses) comes from the embedded ssh.ConnMetadata. The remaining
// methods expose the connection objects themselves.
type Context interface {
	ssh.ConnMetadata

	// ID returns an identifier for this connection, intended for logging
	// and correlation.
	ID() string

	// Conn returns the SSH-level connection as an [ssh.Conn].
	Conn() ssh.Conn

	// NetConn returns the underlying transport connection as a [net.Conn].
	NetConn() net.Conn

	// Close terminates the client connection.
	Close() error
}
// sshContext is the concrete Context built for each accepted SSH connection.
type sshContext struct {
	// id is the random value that ID renders as hex.
	id uint64

	// server is the Server that accepted the connection. Its ChannelHandler
	// and PortForwardHandler are consulted when the client opens channels.
	server *Server

	// netConn is the raw transport connection.
	netConn net.Conn

	// conn is the SSH server connection (presumably established over netConn).
	conn *ssh.ServerConn

	// log is a logger carrying a "context" field set to this context's ID.
	log logger.Structured
}
// newSSHContext assembles the per-connection context. It draws a random
// identifier from seed and derives the context's logger from log by adding a
// "context" field that holds that identifier.
func newSSHContext(server *Server, netConn net.Conn, conn *ssh.ServerConn, log logger.Structured) *sshContext {
	c := new(sshContext)
	c.id = uint64(seed.Int63())
	c.server = server
	c.netConn = netConn
	c.conn = conn
	// ID() depends on c.id, so the logger is derived only after id is set.
	c.log = log.Value("context", c.ID())
	return c
}
// User reports the user ID for this connection, as recorded on the SSH
// server connection.
func (ctx *sshContext) User() string {
	serverConn := ctx.conn
	return serverConn.User()
}
// SessionID reports the session hash (H) negotiated for this connection, as
// recorded on the SSH server connection.
func (ctx *sshContext) SessionID() []byte {
	serverConn := ctx.conn
	return serverConn.SessionID()
}
// ClientVersion reports the version string the client announced. This is the
// string that was hashed into the session ID.
func (ctx *sshContext) ClientVersion() []byte {
	serverConn := ctx.conn
	return serverConn.ClientVersion()
}
// ServerVersion reports the version string this server announced. This is
// the string that was hashed into the session ID.
func (ctx *sshContext) ServerVersion() []byte {
	serverConn := ctx.conn
	return serverConn.ServerVersion()
}
// RemoteAddr reports the peer's address, as seen by the raw transport
// connection.
func (ctx *sshContext) RemoteAddr() net.Addr {
	transport := ctx.netConn
	return transport.RemoteAddr()
}
// LocalAddr reports this server's side of the connection, as seen by the raw
// transport connection.
func (ctx *sshContext) LocalAddr() net.Addr {
	transport := ctx.netConn
	return transport.LocalAddr()
}
// handleChannels services the client's channel-open requests until the
// channels stream is closed, then returns io.EOF.
//
// Each request is dispatched as follows:
//   - If Server.ChannelHandler has an entry for the channel type (or, failing
//     that, for ChannelTypeDefault), the channel is accepted and passed to
//     that handler.
//   - Otherwise, if the type is ChannelTypeDirectTCPIP and a
//     Server.PortForwardHandler is configured, the request is forwarded in its
//     own goroutine. Failures there are logged and do not stop this loop.
//   - Anything else is rejected with ssh.Prohibited.
//
// An error from accepting or rejecting a channel, or from a channel handler,
// ends the loop and is returned.
func (ctx *sshContext) handleChannels(channels <-chan ssh.NewChannel) (err error) {
	for newChan := range channels {
		kind := newChan.ChannelType()
		log := ctx.log.Value("channel", kind)
		log.Trace("Client requested new channel")

		handler, ok := ctx.server.ChannelHandler[kind]
		if !ok {
			handler = ctx.server.ChannelHandler[ChannelTypeDefault]
		}

		switch {
		case handler != nil:
			var (
				channel  ssh.Channel
				requests <-chan *ssh.Request
			)
			if channel, requests, err = newChan.Accept(); err != nil {
				return
			}
			// NOTE(review): HandleChannel runs synchronously, so no further
			// channel requests are serviced until it returns. Confirm that
			// handlers return promptly or spawn their own goroutines for
			// long-lived channels.
			if err = handler.HandleChannel(ctx, channel, requests, newChan.ExtraData()); err != nil {
				return
			}

		case kind == ChannelTypeDirectTCPIP && ctx.server.PortForwardHandler != nil:
			// Forwarding blocks until the tunnelled connection ends, and this
			// loop must keep draining channels. One failed forward (e.g. an
			// unresolvable host) must also not end channel handling for the
			// whole connection. The channel is passed as an argument so the
			// goroutine is correct under pre-Go 1.22 loop-variable semantics;
			// log is declared inside the loop body and is already per-iteration.
			go func(nc ssh.NewChannel) {
				if ferr := ctx.handleDirectTCPIP(nc); ferr != nil {
					log.Value("error", ferr.Error()).Debug("Direct TCP/IP forwarding failed")
				}
			}(newChan)

		default:
			log.Debug("Rejecting unsupported channel type")
			if err = newChan.Reject(ssh.Prohibited, ""); err != nil {
				return
			}
		}
	}

	// The channels stream was closed: our client hung up.
	return io.EOF
}
// handleDirectTCPIP services one "direct-tcpip" (local port forwarding)
// channel request. It decodes the requested destination and resolves a host
// name to its first IP address. It then asks Server.PortForwardHandler for a
// connection and, on success, accepts the channel and relays data until the
// client side of the channel finishes.
//
// When the forward cannot be set up, the channel is rejected and the error is
// returned. A malformed payload gets ssh.Prohibited; resolution or handler
// failures get ssh.ConnectionFailed. The forwarded connection and the channel
// are both closed when the function returns.
func (ctx *sshContext) handleDirectTCPIP(newChan ssh.NewChannel) (err error) {
	// Extra data of a direct-tcpip open request (RFC 4254, section 7.2).
	var payload struct {
		Host       string
		Port       uint32
		OriginAddr string
		OriginPort uint32
	}
	if err = ssh.Unmarshal(newChan.ExtraData(), &payload); err != nil {
		_ = newChan.Reject(ssh.Prohibited, "")
		return
	}

	ip := net.ParseIP(payload.Host)
	if ip == nil {
		// Host is a name rather than a literal address; use the first result.
		var ips []net.IP
		if ips, err = net.LookupIP(payload.Host); err != nil {
			_ = newChan.Reject(ssh.ConnectionFailed, err.Error())
			return
		}
		if len(ips) == 0 {
			// Report this like every other rejection instead of returning nil.
			err = &net.DNSError{Err: "no such host", Name: payload.Host, IsNotFound: true}
			_ = newChan.Reject(ssh.ConnectionFailed, "")
			return
		}
		ip = ips[0]
	}

	raddr := &net.TCPAddr{IP: ip, Port: int(payload.Port)}

	var conn net.Conn
	if payload.OriginAddr == "" && payload.OriginPort == 0 {
		// No originator was sent. Pass an untyped nil so that, if the
		// handler's parameter is an interface (e.g. net.Addr), it is truly
		// nil rather than a non-nil interface wrapping a nil *net.TCPAddr.
		conn, err = ctx.server.PortForwardHandler.HandlePortForwardRequest(ctx, raddr, nil)
	} else {
		laddr := &net.TCPAddr{IP: net.ParseIP(payload.OriginAddr), Port: int(payload.OriginPort)}
		conn, err = ctx.server.PortForwardHandler.HandlePortForwardRequest(ctx, raddr, laddr)
	}
	if err != nil {
		_ = newChan.Reject(ssh.ConnectionFailed, err.Error())
		return
	}
	defer func() { _ = conn.Close() }()

	var (
		channel  ssh.Channel
		requests <-chan *ssh.Request
	)
	if channel, requests, err = newChan.Accept(); err != nil {
		return
	}
	defer func() { _ = channel.Close() }()
	go ssh.DiscardRequests(requests)

	// Remote -> client. Once the forwarded connection reaches EOF, send EOF on
	// the channel so the client learns the remote side is done. Without it,
	// the client keeps waiting and the copy below never returns.
	go func() {
		_, _ = io.Copy(channel, conn)
		_ = channel.CloseWrite()
	}()

	// Client -> remote. Returning closes both ends via the deferred calls,
	// which also unblocks the goroutine above.
	_, err = io.Copy(conn, channel)
	return
}
// Conn exposes the SSH server connection through the ssh.Conn interface.
func (ctx *sshContext) Conn() ssh.Conn {
	var sc ssh.Conn = ctx.conn
	return sc
}
// NetConn exposes the raw transport connection this context was built with.
func (ctx *sshContext) NetConn() net.Conn {
	transport := ctx.netConn
	return transport
}
// Close shuts down the SSH server connection and returns its error, if any.
func (ctx *sshContext) Close() error {
	serverConn := ctx.conn
	return serverConn.Close()
}
// ID renders the context's 64-bit identifier as 16 lowercase hex digits,
// from its big-endian byte encoding.
func (ctx *sshContext) ID() string {
	buf := make([]byte, 8)
	binary.BigEndian.PutUint64(buf, ctx.id)
	return hex.EncodeToString(buf)
}