Initial import
This commit is contained in:
204
ssh/context.go
Normal file
204
ssh/context.go
Normal file
@@ -0,0 +1,204 @@
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
|
||||
"git.maze.io/maze/conduit/logger"
|
||||
)
|
||||
|
||||
var seed = rand.NewSource(time.Now().UnixNano())
|
||||
|
||||
type Context interface {
|
||||
ssh.ConnMetadata
|
||||
|
||||
// ID is the unique identifier.
|
||||
ID() string
|
||||
|
||||
// Conn is the [ssh.Conn].
|
||||
Conn() ssh.Conn
|
||||
|
||||
// NetConn is the underlying [net.Conn].
|
||||
NetConn() net.Conn
|
||||
|
||||
// Close the client connection.
|
||||
Close() error
|
||||
}
|
||||
|
||||
type sshContext struct {
|
||||
id uint64
|
||||
server *Server
|
||||
netConn net.Conn
|
||||
conn *ssh.ServerConn
|
||||
log logger.Structured
|
||||
}
|
||||
|
||||
func newSSHContext(server *Server, netConn net.Conn, conn *ssh.ServerConn, log logger.Structured) *sshContext {
|
||||
ctx := &sshContext{
|
||||
id: uint64(seed.Int63()),
|
||||
server: server,
|
||||
netConn: netConn,
|
||||
conn: conn,
|
||||
}
|
||||
ctx.log = log.Value("context", ctx.ID())
|
||||
return ctx
|
||||
}
|
||||
|
||||
// User returns the user ID for this connection.
|
||||
func (ctx *sshContext) User() string {
|
||||
return ctx.conn.User()
|
||||
}
|
||||
|
||||
// SessionID returns the session hash, also denoted by H.
|
||||
func (ctx *sshContext) SessionID() []byte {
|
||||
return ctx.conn.SessionID()
|
||||
}
|
||||
|
||||
// ClientVersion returns the client's version string as hashed
|
||||
// into the session ID.
|
||||
func (ctx *sshContext) ClientVersion() []byte {
|
||||
return ctx.conn.ClientVersion()
|
||||
}
|
||||
|
||||
// ServerVersion returns the server's version string as hashed
|
||||
// into the session ID.
|
||||
func (ctx *sshContext) ServerVersion() []byte {
|
||||
return ctx.conn.ServerVersion()
|
||||
}
|
||||
|
||||
// RemoteAddr returns the remote address for this connection.
|
||||
func (ctx *sshContext) RemoteAddr() net.Addr {
|
||||
return ctx.netConn.RemoteAddr()
|
||||
}
|
||||
|
||||
// LocalAddr returns the local address for this connection.
|
||||
func (ctx *sshContext) LocalAddr() net.Addr {
|
||||
return ctx.netConn.LocalAddr()
|
||||
}
|
||||
|
||||
func (ctx *sshContext) handleChannels(channels <-chan ssh.NewChannel) (err error) {
|
||||
for newChan := range channels {
|
||||
var (
|
||||
kind = newChan.ChannelType()
|
||||
log = ctx.log.Value("channel", kind)
|
||||
)
|
||||
log.Trace("Client requested new channel")
|
||||
|
||||
handler, ok := ctx.server.ChannelHandler[kind]
|
||||
if !ok {
|
||||
handler = ctx.server.ChannelHandler[ChannelTypeDefault]
|
||||
}
|
||||
|
||||
if handler != nil {
|
||||
var (
|
||||
channel ssh.Channel
|
||||
requests <-chan *ssh.Request
|
||||
)
|
||||
if channel, requests, err = newChan.Accept(); err != nil {
|
||||
return
|
||||
}
|
||||
if err = handler.HandleChannel(ctx, channel, requests, newChan.ExtraData()); err != nil {
|
||||
return
|
||||
}
|
||||
} else if kind == ChannelTypeDirectTCPIP && ctx.server.PortForwardHandler != nil {
|
||||
if err = ctx.handleDirectTCPIP(newChan); err != nil {
|
||||
return
|
||||
}
|
||||
} else {
|
||||
ctx.log.Debug("Rejecting unsupported channel type")
|
||||
if err = newChan.Reject(ssh.Prohibited, ""); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Our client hang up.
|
||||
return io.EOF
|
||||
}
|
||||
|
||||
func (ctx *sshContext) handleDirectTCPIP(newChan ssh.NewChannel) (err error) {
|
||||
var payload struct {
|
||||
Host string
|
||||
Port uint32
|
||||
OriginAddr string
|
||||
OriginPort uint32
|
||||
}
|
||||
if err = ssh.Unmarshal(newChan.ExtraData(), &payload); err != nil {
|
||||
_ = newChan.Reject(ssh.Prohibited, "")
|
||||
return
|
||||
}
|
||||
|
||||
var ip net.IP
|
||||
if ip = net.ParseIP(payload.Host); ip == nil {
|
||||
// Not an IP
|
||||
var ips []net.IP
|
||||
if ips, err = net.LookupIP(payload.Host); err != nil {
|
||||
_ = newChan.Reject(ssh.ConnectionFailed, err.Error())
|
||||
return
|
||||
} else if len(ips) == 0 {
|
||||
_ = newChan.Reject(ssh.ConnectionFailed, "")
|
||||
return
|
||||
}
|
||||
ip = ips[0]
|
||||
}
|
||||
|
||||
var (
|
||||
raddr = &net.TCPAddr{
|
||||
IP: ip,
|
||||
Port: int(payload.Port),
|
||||
}
|
||||
laddr = &net.TCPAddr{
|
||||
IP: net.ParseIP(payload.OriginAddr),
|
||||
Port: int(payload.OriginPort),
|
||||
}
|
||||
)
|
||||
if payload.OriginAddr == "" && payload.OriginPort == 0 {
|
||||
laddr = nil
|
||||
}
|
||||
|
||||
var conn net.Conn
|
||||
if conn, err = ctx.server.PortForwardHandler.HandlePortForwardRequest(ctx, raddr, laddr); err != nil {
|
||||
_ = newChan.Reject(ssh.ConnectionFailed, err.Error())
|
||||
return
|
||||
}
|
||||
defer func() { _ = conn.Close() }()
|
||||
|
||||
var (
|
||||
channel ssh.Channel
|
||||
requests <-chan *ssh.Request
|
||||
)
|
||||
if channel, requests, err = newChan.Accept(); err != nil {
|
||||
return
|
||||
}
|
||||
defer func() { _ = channel.Close() }()
|
||||
go ssh.DiscardRequests(requests)
|
||||
|
||||
go io.Copy(channel, conn)
|
||||
_, err = io.Copy(conn, channel)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (ctx *sshContext) Conn() ssh.Conn {
|
||||
return ctx.conn
|
||||
}
|
||||
|
||||
func (ctx *sshContext) NetConn() net.Conn {
|
||||
return ctx.netConn
|
||||
}
|
||||
|
||||
func (ctx *sshContext) Close() error {
|
||||
return ctx.conn.Close()
|
||||
}
|
||||
|
||||
func (ctx *sshContext) ID() string {
|
||||
var b [8]byte
|
||||
binary.BigEndian.PutUint64(b[:], ctx.id)
|
||||
return hex.EncodeToString(b[:])
|
||||
}
|
Reference in New Issue
Block a user