205 lines
4.4 KiB
Go
205 lines
4.4 KiB
Go
package ssh
|
|
|
|
import (
|
|
"encoding/binary"
|
|
"encoding/hex"
|
|
"io"
|
|
"math/rand"
|
|
"net"
|
|
"time"
|
|
|
|
"golang.org/x/crypto/ssh"
|
|
|
|
"git.maze.io/maze/conduit/logger"
|
|
)
|
|
|
|
// seed is the package-level random source from which connection context
// identifiers are drawn (see newSSHContext). It is seeded once, from the wall
// clock, when the package is initialized.
//
// NOTE(review): math/rand documents sources created with rand.NewSource as not
// safe for concurrent use by multiple goroutines — confirm that contexts are
// never created concurrently, or guard access to seed.
var seed = rand.NewSource(time.Now().UnixNano())
|
|
|
|
type Context interface {
|
|
ssh.ConnMetadata
|
|
|
|
// ID is the unique identifier.
|
|
ID() string
|
|
|
|
// Conn is the [ssh.Conn].
|
|
Conn() ssh.Conn
|
|
|
|
// NetConn is the underlying [net.Conn].
|
|
NetConn() net.Conn
|
|
|
|
// Close the client connection.
|
|
Close() error
|
|
}
|
|
|
|
type sshContext struct {
|
|
id uint64
|
|
server *Server
|
|
netConn net.Conn
|
|
conn *ssh.ServerConn
|
|
log logger.Structured
|
|
}
|
|
|
|
func newSSHContext(server *Server, netConn net.Conn, conn *ssh.ServerConn, log logger.Structured) *sshContext {
|
|
ctx := &sshContext{
|
|
id: uint64(seed.Int63()),
|
|
server: server,
|
|
netConn: netConn,
|
|
conn: conn,
|
|
}
|
|
ctx.log = log.Value("context", ctx.ID())
|
|
return ctx
|
|
}
|
|
|
|
// User returns the user ID for this connection.
|
|
func (ctx *sshContext) User() string {
|
|
return ctx.conn.User()
|
|
}
|
|
|
|
// SessionID returns the session hash, also denoted by H.
|
|
func (ctx *sshContext) SessionID() []byte {
|
|
return ctx.conn.SessionID()
|
|
}
|
|
|
|
// ClientVersion returns the client's version string as hashed
|
|
// into the session ID.
|
|
func (ctx *sshContext) ClientVersion() []byte {
|
|
return ctx.conn.ClientVersion()
|
|
}
|
|
|
|
// ServerVersion returns the server's version string as hashed
|
|
// into the session ID.
|
|
func (ctx *sshContext) ServerVersion() []byte {
|
|
return ctx.conn.ServerVersion()
|
|
}
|
|
|
|
// RemoteAddr returns the remote address for this connection.
|
|
func (ctx *sshContext) RemoteAddr() net.Addr {
|
|
return ctx.netConn.RemoteAddr()
|
|
}
|
|
|
|
// LocalAddr returns the local address for this connection.
|
|
func (ctx *sshContext) LocalAddr() net.Addr {
|
|
return ctx.netConn.LocalAddr()
|
|
}
|
|
|
|
func (ctx *sshContext) handleChannels(channels <-chan ssh.NewChannel) (err error) {
|
|
for newChan := range channels {
|
|
var (
|
|
kind = newChan.ChannelType()
|
|
log = ctx.log.Value("channel", kind)
|
|
)
|
|
log.Trace("Client requested new channel")
|
|
|
|
handler, ok := ctx.server.ChannelHandler[kind]
|
|
if !ok {
|
|
handler = ctx.server.ChannelHandler[ChannelTypeDefault]
|
|
}
|
|
|
|
if handler != nil {
|
|
var (
|
|
channel ssh.Channel
|
|
requests <-chan *ssh.Request
|
|
)
|
|
if channel, requests, err = newChan.Accept(); err != nil {
|
|
return
|
|
}
|
|
if err = handler.HandleChannel(ctx, channel, requests, newChan.ExtraData()); err != nil {
|
|
return
|
|
}
|
|
} else if kind == ChannelTypeDirectTCPIP && ctx.server.PortForwardHandler != nil {
|
|
if err = ctx.handleDirectTCPIP(newChan); err != nil {
|
|
return
|
|
}
|
|
} else {
|
|
ctx.log.Debug("Rejecting unsupported channel type")
|
|
if err = newChan.Reject(ssh.Prohibited, ""); err != nil {
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
// Our client hang up.
|
|
return io.EOF
|
|
}
|
|
|
|
func (ctx *sshContext) handleDirectTCPIP(newChan ssh.NewChannel) (err error) {
|
|
var payload struct {
|
|
Host string
|
|
Port uint32
|
|
OriginAddr string
|
|
OriginPort uint32
|
|
}
|
|
if err = ssh.Unmarshal(newChan.ExtraData(), &payload); err != nil {
|
|
_ = newChan.Reject(ssh.Prohibited, "")
|
|
return
|
|
}
|
|
|
|
var ip net.IP
|
|
if ip = net.ParseIP(payload.Host); ip == nil {
|
|
// Not an IP
|
|
var ips []net.IP
|
|
if ips, err = net.LookupIP(payload.Host); err != nil {
|
|
_ = newChan.Reject(ssh.ConnectionFailed, err.Error())
|
|
return
|
|
} else if len(ips) == 0 {
|
|
_ = newChan.Reject(ssh.ConnectionFailed, "")
|
|
return
|
|
}
|
|
ip = ips[0]
|
|
}
|
|
|
|
var (
|
|
raddr = &net.TCPAddr{
|
|
IP: ip,
|
|
Port: int(payload.Port),
|
|
}
|
|
laddr = &net.TCPAddr{
|
|
IP: net.ParseIP(payload.OriginAddr),
|
|
Port: int(payload.OriginPort),
|
|
}
|
|
)
|
|
if payload.OriginAddr == "" && payload.OriginPort == 0 {
|
|
laddr = nil
|
|
}
|
|
|
|
var conn net.Conn
|
|
if conn, err = ctx.server.PortForwardHandler.HandlePortForwardRequest(ctx, raddr, laddr); err != nil {
|
|
_ = newChan.Reject(ssh.ConnectionFailed, err.Error())
|
|
return
|
|
}
|
|
defer func() { _ = conn.Close() }()
|
|
|
|
var (
|
|
channel ssh.Channel
|
|
requests <-chan *ssh.Request
|
|
)
|
|
if channel, requests, err = newChan.Accept(); err != nil {
|
|
return
|
|
}
|
|
defer func() { _ = channel.Close() }()
|
|
go ssh.DiscardRequests(requests)
|
|
|
|
go io.Copy(channel, conn)
|
|
_, err = io.Copy(conn, channel)
|
|
|
|
return
|
|
}
|
|
|
|
func (ctx *sshContext) Conn() ssh.Conn {
|
|
return ctx.conn
|
|
}
|
|
|
|
func (ctx *sshContext) NetConn() net.Conn {
|
|
return ctx.netConn
|
|
}
|
|
|
|
func (ctx *sshContext) Close() error {
|
|
return ctx.conn.Close()
|
|
}
|
|
|
|
func (ctx *sshContext) ID() string {
|
|
var b [8]byte
|
|
binary.BigEndian.PutUint64(b[:], ctx.id)
|
|
return hex.EncodeToString(b[:])
|
|
}
|