Transparent SSH jump host with auditing.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

474 lines
12 KiB

package stronghold
import (
"errors"
"fmt"
"io"
"net"
"strconv"
"sync"
"time"
"github.com/sirupsen/logrus"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/agent"
"golang.org/x/crypto/ssh/terminal"
"git.maze.io/maze/stronghold/authority"
)
// Request types exchanged with the client: RFC 4254 global and channel
// requests plus the OpenSSH agent-forwarding and keepalive extensions.
const (
	requestTypeAgent        = "auth-agent-req@openssh.com"
	requestTypeEnv          = "env"
	requestTypeKeepAlive    = "keepalive@openssh.com"
	requestTypePTY          = "pty-req"
	requestTypeSession      = "session"
	requestTypeShell        = "shell"
	requestTypeTCPIPForward = "tcpip-forward"
	requestTypeWindowChange = "window-change"
	requestTypeX11          = "x11-req"
)

// Channel types opened by, or accepted from, the client.
const (
	channelTypeAgent          = "auth-agent@openssh.com"
	channelTypeDirectTCPIP    = "direct-tcpip"
	channelTypeForwardedTCPIP = "forwarded-tcpip"
	channelTypeSession        = "session"
)

// Timing parameters.
const (
	// keepAliveInterval is how long handleRequests waits for a client request
	// before sending a keepalive request of its own.
	keepAliveInterval = 30 * time.Second
	// connectTimeout bounds dialing the upstream host and waiting for a
	// pty/shell request on a session channel.
	connectTimeout = 30 * time.Second
)
// clientConn holds the server-side state for one connected client: the
// underlying network and SSH connections, a session ID used in log fields,
// and the global requests (pty, session, env) recorded so far.
type clientConn struct {
// CertificateAuthority provides host certificates for proxied SSH
// connections; handleDirectTCPIPChannel calls HostCertificate on it for each
// target host.
// NOTE(review): it is used without a nil check — confirm callers always set it.
CertificateAuthority *authority.CertificateAuthority
// SessionHandler is called by handleSession to serve an SSH session channel,
// with the recorded pty request (may be nil) and a channel of window changes.
// When nil, handleSession writes a fixed "not supported" notice instead.
SessionHandler func(conn ssh.ConnMetadata, channel ssh.Channel, pty *PTY, windowChanges <-chan *WindowChange) error
// netConn is the raw network connection; its remote address is reported on
// the channel connections handed to the intercepting SSH server.
netConn net.Conn
// sshConn is the established SSH connection to the client.
sshConn ssh.Conn
// id is the session identifier (from sessionID), logged as "session".
id string
// closed is closed by handleNewChannels once the client's channel stream
// ends; handleRequests and handleNewChannels return when it is closed.
closed chan struct{}
// recordDir is passed to newSSHInterceptingServer for proxied connections.
recordDir string
// requestMutex guards ptyRequest, sessionRequest and envRequest.
requestMutex sync.RWMutex
// ptyRequest is set by handlePTYRequest; once set, further pty requests
// are rejected as duplicates.
ptyRequest *PTY
// sessionRequest is set by handleSessionRequest; once set, further session
// requests are rejected as duplicates.
sessionRequest *sessionRequest
// envRequest maps variable names to values received in "env" requests.
envRequest map[string]string
// agent is a client for the user's forwarded ssh-agent, opened lazily by
// handleDirectTCPIPChannel. It is not guarded by requestMutex.
agent agent.Agent
}
// newClient wraps an established SSH connection and its underlying network
// connection in a clientConn. The session ID is derived from sshConn via
// sessionID; recordDir is kept for the intercepting servers created for
// direct TCP/IP channels. The closed channel and env map are ready for use.
func newClient(netConn net.Conn, sshConn ssh.Conn, recordDir string) *clientConn {
	client := new(clientConn)
	client.netConn = netConn
	client.sshConn = sshConn
	client.id = sessionID(sshConn)
	client.closed = make(chan struct{})
	client.recordDir = recordDir
	client.envRequest = make(map[string]string)
	return client
}
// log returns a logrus entry tagged with this client's session ID, remote
// address and SSH user name, under the "client" tag.
func (client *clientConn) log() *logrus.Entry {
	fields := make(logrus.Fields, 4)
	fields["tag"] = "client"
	fields["session"] = client.id
	fields["src"] = client.sshConn.RemoteAddr().String()
	fields["user"] = client.sshConn.User()
	return logrus.WithFields(fields)
}
// handleRequests services global (connection-level) requests from the client.
// It returns when the client is marked closed, when the request channel is
// closed, or when sending a keepalive fails.
//
// Each request is dispatched by type. A handler error is logged and turns the
// reply into a failure. Unknown types and "x11-req" are refused. If no request
// arrives for keepAliveInterval, a keepalive request (without reply) is sent
// to the client.
func (client *clientConn) handleRequests(requests <-chan *ssh.Request) {
	log := client.log()

	// One reusable timer instead of time.After per iteration, which would
	// allocate a new timer for every request received.
	idle := time.NewTimer(keepAliveInterval)
	defer idle.Stop()

	for {
		select {
		case <-client.closed:
			return
		case request := <-requests:
			if request == nil {
				log.Debug("client request channel closed")
				return
			}
			reqLog := log.WithField("request", request.Type)
			reqLog.Debug("new OOB request")

			var (
				ok  bool
				err error
			)
			switch request.Type {
			case requestTypeAgent:
				ok, err = client.handleAgentRequest(request)
			case requestTypeEnv:
				ok, err = client.handleEnvRequest(request)
			case requestTypeKeepAlive:
				ok = true
			case requestTypeTCPIPForward:
				ok, err = client.handleTCPIPForwardRequest(request)
			case requestTypePTY:
				ok, err = client.handlePTYRequest(request)
			case requestTypeSession:
				ok, err = client.handleSessionRequest(request)
			case requestTypeX11:
				// Request type is ignored (not supported).
			default:
				reqLog.Debug("unknown/unhandled request type")
			}
			if err != nil {
				reqLog.WithError(err).Warn("failed to handle request")
				ok = false
			}
			if request.WantReply {
				// Best effort: a failed reply surfaces as a closed connection.
				_ = request.Reply(ok, nil)
			}
		case <-idle.C:
			log.Debug("sending keepalive request")
			if _, _, err := client.sshConn.SendRequest(requestTypeKeepAlive, false, nil); err != nil {
				log.WithError(err).Debug("keepalive failed, no longer handling requests")
				return
			}
		}

		// Restart the idle period. Stop and drain first, so a tick that fired
		// while a request was being handled is not delivered late.
		if !idle.Stop() {
			select {
			case <-idle.C:
			default:
			}
		}
		idle.Reset(keepAliveInterval)
	}
}
// handleAgentRequest acknowledges an "auth-agent-req@openssh.com" request.
// Nothing is recorded here; the agent channel is opened later, on demand, by
// handleDirectTCPIPChannel when client.agent is still nil. It always
// succeeds.
func (client *clientConn) handleAgentRequest(request *ssh.Request) (bool, error) {
	const accepted = true
	return accepted, nil
}
// envRequest is the payload of an "env" request, RFC 4254 Section 6.4.
// Field order is the wire order used by ssh.Unmarshal; do not reorder.
type envRequest struct {
// Name is the environment variable name.
Name string
// Value is the environment variable value.
Value string
}
// handleEnvRequest decodes an "env" request and stores the variable in
// client.envRequest, overwriting any earlier value for the same name.
// A malformed payload yields (false, decode error) and stores nothing.
func (client *clientConn) handleEnvRequest(request *ssh.Request) (ok bool, err error) {
	env := envRequest{}
	if err = ssh.Unmarshal(request.Payload, &env); err != nil {
		return false, err
	}

	client.requestMutex.Lock()
	defer client.requestMutex.Unlock()
	client.envRequest[env.Name] = env.Value
	return true, nil
}
// PTY is the payload of a "pty-req" request, RFC 4254 Section 6.2.
// Field order is the wire order used by ssh.Unmarshal; do not reorder.
type PTY struct {
// Term is the TERM environment variable value (e.g. "vt100").
Term string
// Columns is the terminal width in characters.
Columns uint32
// Rows is the terminal height in rows.
Rows uint32
// Width is the terminal width in pixels.
Width uint32
// Height is the terminal height in pixels.
Height uint32
// ModeList holds the encoded terminal modes.
ModeList string
}
// WindowChange is the payload of a "window-change" request, RFC 4254
// Section 6.7. Field order is the wire order used by ssh.Unmarshal; do not
// reorder.
type WindowChange struct {
// Width is the terminal width in columns.
Width uint32
// Height is the terminal height in rows.
Height uint32
// PixelWidth is the terminal width in pixels.
PixelWidth uint32
// PixelHeight is the terminal height in pixels.
PixelHeight uint32
}
// handlePTYRequest records a "pty-req" request (RFC 4254 Section 6.2) in
// client.ptyRequest. Only one pty request is accepted per client; a second
// one returns an error.
//
// The payload is decoded into a local value and stored only on success. A
// malformed request therefore leaves client.ptyRequest unset, instead of
// storing a partially decoded PTY that would be handed to the session handler
// and would make every later pty request look like a duplicate.
func (client *clientConn) handlePTYRequest(request *ssh.Request) (ok bool, err error) {
	client.requestMutex.Lock()
	defer client.requestMutex.Unlock()

	if client.ptyRequest != nil {
		return false, errors.New("duplicate pty request")
	}
	pty := new(PTY)
	if err = ssh.Unmarshal(request.Payload, pty); err != nil {
		return false, err
	}
	client.ptyRequest = pty
	return true, nil
}
// sessionRequest is the decoded payload of a "session" request.
// NOTE(review): "session" is a channel type (RFC 4254 Section 6.1), not a
// global request. These fields mirror the SSH_MSG_CHANNEL_OPEN fields
// (sender channel, initial window size, maximum packet size; RFC 4254
// Section 5.1) — confirm the intended payload.
// Field order is the wire order used by ssh.Unmarshal; do not reorder.
type sessionRequest struct {
SenderChannel uint32
WindowSize uint32
MaxPacketSize uint32
}
// handleSessionRequest records a "session" request in client.sessionRequest.
// Only one session request is accepted per client; a second one returns an
// error.
//
// The payload is decoded into a local value and stored only on success. A
// malformed request therefore does not leave a partially decoded value behind
// that would cause every later session request to be rejected as a duplicate.
func (client *clientConn) handleSessionRequest(request *ssh.Request) (ok bool, err error) {
	client.requestMutex.Lock()
	defer client.requestMutex.Unlock()

	if client.sessionRequest != nil {
		// Duplicate session request, only one is allowed.
		return false, errors.New("duplicate session request")
	}
	session := new(sessionRequest)
	if err = ssh.Unmarshal(request.Payload, session); err != nil {
		return false, err
	}
	client.sessionRequest = session
	return true, nil
}
// forwardTCPIPRequest is the payload of a "tcpip-forward" global request,
// RFC 4254 Section 7.1. Field order is the wire order used by
// ssh.Unmarshal; do not reorder.
type forwardTCPIPRequest struct {
// TargetHost is the address the client asks the server to bind.
TargetHost string
// TargetPort is the port number to bind.
TargetPort uint32
}
// forwardedTCPIPResponse is the extra data of a "forwarded-tcpip" channel
// open, RFC 4254 Section 7.2. Field order is the wire order used by
// ssh.Marshal; do not reorder.
type forwardedTCPIPResponse struct {
// TargetHost is the address that was connected.
TargetHost string
// TargetPort is the port that was connected.
TargetPort uint32
// OriginAddr is the originator IP address.
OriginAddr string
// OriginPort is the originator port.
OriginPort uint32
}
// handleTCPIPForwardRequest handles a "tcpip-forward" global request.
//
// The requested address is logged, and a "forwarded-tcpip" channel describing
// that address (with an empty originator) is opened back to the client. No
// listener is created and no data is relayed, so this is effectively a
// placeholder. Because nothing uses the opened channel, it is closed right
// away instead of staying open (together with its request-draining
// goroutine) until the connection ends.
//
// A malformed payload, or a failure to open the channel, yields
// (false, err).
func (client *clientConn) handleTCPIPForwardRequest(request *ssh.Request) (ok bool, err error) {
	var payload forwardTCPIPRequest
	if err = ssh.Unmarshal(request.Payload, &payload); err != nil {
		return false, err
	}

	remoteAddr := net.JoinHostPort(payload.TargetHost, strconv.FormatUint(uint64(payload.TargetPort), 10))
	log := client.log().WithField("dst", remoteAddr)
	log.Info("TCP/IP forward request")

	response := forwardedTCPIPResponse{
		TargetHost: payload.TargetHost,
		TargetPort: payload.TargetPort,
	}
	channel, requests, err := client.sshConn.OpenChannel(channelTypeForwardedTCPIP, ssh.Marshal(&response))
	if err != nil {
		return false, err
	}
	go ssh.DiscardRequests(requests)

	// NOTE(review): nothing reads from or writes to this channel. Close it
	// rather than leak it; replace with real forwarding if that is intended.
	if closeErr := channel.Close(); closeErr != nil {
		log.WithError(closeErr).Debug("closing unused forwarded-tcpip channel")
	}
	return true, nil
}
// handleNewChannels dispatches channel-open requests from the client until
// the connection is marked closed or the channel stream ends.
//
// Dispatch by channel type:
//   - "direct-tcpip" channels go to handleDirectTCPIPChannel;
//   - "session" channels go to handleSessionChannelRequest;
//   - any other type is rejected with ssh.UnknownChannelType.
//
// Errors from individual channel handlers are logged, not returned; io.EOF
// (including wrapped EOFs) is logged at debug level only. When the channel
// stream is closed, client.closed is closed so handleRequests stops too, and
// io.EOF is returned. A failure to reject a channel is returned as-is.
func (client *clientConn) handleNewChannels(channels <-chan ssh.NewChannel) error {
	log := client.log()
	for {
		select {
		case <-client.closed:
			return io.EOF
		case newChannel := <-channels:
			if newChannel == nil {
				close(client.closed)
				return io.EOF
			}

			channelType := newChannel.ChannelType()
			chanLog := log.WithField("channel", channelType)
			chanLog.Debug("new channel requested by client")

			var channelErr error
			switch channelType {
			case channelTypeDirectTCPIP:
				channelErr = client.handleDirectTCPIPChannel(newChannel)
			case channelTypeSession:
				channelErr = client.handleSessionChannelRequest(newChannel)
			default:
				chanLog.Debug("unknown/unhandled channel type")
				if rejectErr := newChannel.Reject(ssh.UnknownChannelType, "unknown channel type"); rejectErr != nil {
					return rejectErr
				}
			}

			switch {
			case errors.Is(channelErr, io.EOF):
				chanLog.Debug("end client channel")
			case channelErr != nil:
				chanLog.WithError(channelErr).Warn("failed to handle new channel")
			}
		}
	}
}
// directTCPIP is the extra data of a "direct-tcpip" channel open, RFC 4254
// Section 7.2. Field order is the wire order used by ssh.Unmarshal; do not
// reorder.
type directTCPIP struct {
// TargetHost is the host to connect to.
TargetHost string
// TargetPort is the port to connect to.
TargetPort uint32
// SourceIP is the originator IP address.
SourceIP string
// SourcePort is the originator port.
SourcePort uint32
}
// handleDirectTCPIPChannel proxies a "direct-tcpip" channel (RFC 4254
// Section 7.2) through an intercepting SSH server.
//
// Steps:
//  1. Decode the target host/port from the channel's extra data.
//  2. If client.agent is nil, open an "auth-agent@openssh.com" channel to the
//     client and store an agent client for it in client.agent.
//  3. Get host keys for the target from client.CertificateAuthority.
//  4. Create an upstream SSH client as the same user, hand it the agent
//     (remote.ConnectAgent), and dial the target within connectTimeout.
//  5. Accept the channel, run an SSH handshake with the client over it using
//     the target's host keys, and start server.ProxyClient in a goroutine.
//
// If a step before acceptance fails, the failure is logged, the channel is
// rejected with the reason, and Reject's own result is returned. Decode
// failures return the decode error. Once the upstream connection is dialed,
// it is closed on any later error, which is then returned.
func (client *clientConn) handleDirectTCPIPChannel(newChannel ssh.NewChannel) (err error) {
	var (
		log     = client.log()
		payload directTCPIP
	)
	if err = ssh.Unmarshal(newChannel.ExtraData(), &payload); err != nil {
		log.WithError(err).Warn("invalid direct TCP/IP payload")
		_ = newChannel.Reject(ssh.ConnectionFailed, "invalid direct TCP/IP payload")
		return err
	}

	if client.agent == nil {
		log.Debug("connecting to ssh-agent")
		agentChannel, agentRequests, openErr := client.sshConn.OpenChannel(channelTypeAgent, nil)
		if openErr != nil {
			log.WithError(openErr).Warn("error opening channel to ssh-agent")
			return newChannel.Reject(ssh.ConnectionFailed, openErr.Error())
		}
		go ssh.DiscardRequests(agentRequests)
		log.Debug("connected to ssh-agent")
		client.agent = agent.NewClient(agentChannel)
	}

	// Connect as an SSH client to the requested endpoint before accepting the
	// channel, so a connection failure can be reported as the reject reason.
	serverKeys, err := client.CertificateAuthority.HostCertificate(payload.TargetHost)
	if err != nil {
		log.WithError(err).Warn("error obtaining host certificate")
		return newChannel.Reject(ssh.ConnectionFailed, err.Error())
	}
	remote := newSSHClient(client.sshConn.User())
	if err = remote.ConnectAgent(client.agent); err != nil {
		log.WithError(err).Warn("error attaching ssh-agent to upstream client")
		return newChannel.Reject(ssh.Prohibited, err.Error())
	}
	addr := net.JoinHostPort(payload.TargetHost, strconv.FormatUint(uint64(payload.TargetPort), 10))
	if err = remote.DialTimeout(addr, connectTimeout); err != nil {
		log.WithError(err).WithField("dst", addr).Warn("error connecting to target")
		return newChannel.Reject(ssh.ConnectionFailed, err.Error())
	}

	// Now that we are connected as an SSH client, accept the new channel.
	channel, channelRequests, err := newChannel.Accept()
	if err != nil {
		// The upstream connection is already established; do not leak it.
		_ = remote.Close()
		return err
	}
	go ssh.DiscardRequests(channelRequests)

	// Present a (fake) SSH server on the newly accepted channel and negotiate
	// SSH with the client over it.
	server := newSSHInterceptingServer(client.recordDir)
	server.AddHostKeys(serverKeys...)
	if err = server.Handshake(channelConn{Channel: channel, remoteAddr: client.netConn.RemoteAddr()}); err != nil {
		_ = remote.Close()
		_ = channel.Close()
		return err
	}

	// Serve the proxied SSH traffic.
	go server.ProxyClient(remote)
	return nil
}
// handleSessionChannelRequest accepts a "session" channel and serves it in a
// new goroutine via handleSession. Only an Accept failure is reported; the
// session's own outcome is handled inside handleSession.
func (client *clientConn) handleSessionChannelRequest(newChannel ssh.NewChannel) error {
	channel, requests, err := newChannel.Accept()
	if err == nil {
		go client.handleSession(channel, requests)
	}
	return err
}
// handleSession serves one accepted "session" channel and closes it on return.
//
// Requests on the channel are processed by a background goroutine:
//   - "shell" is accepted and releases the session handler;
//   - "pty-req" is recorded via handlePTYRequest and releases the handler;
//   - "window-change" payloads are forwarded on the window-change channel
//     with a non-blocking send, so a change is dropped when nobody is
//     receiving at that moment;
//   - anything else is refused.
//
// That goroutine closes the window-change channel when the client's request
// channel closes.
//
// The session handler starts after the first "shell" or "pty-req" request,
// or after connectTimeout if neither arrives. If client.SessionHandler is
// nil, a fixed notice is written to the channel instead.
func (client *clientConn) handleSession(channel ssh.Channel, requests <-chan *ssh.Request) {
	defer channel.Close()

	var (
		log   = client.log()
		winch = make(chan *WindowChange)
		ready = make(chan struct{})
		once  sync.Once
	)

	// release lets the session handler start. Only the first call has an
	// effect, whichever of pty-req, shell or the timeout triggers it.
	release := func() { once.Do(func() { close(ready) }) }

	// Timers created by time.AfterFunc have a nil C channel, so the timeout
	// is reported from the callback rather than from a select case.
	timeout := time.AfterFunc(connectTimeout, func() {
		log.Debug("timeout waiting for a pty request")
		release()
	})
	defer timeout.Stop()

	go func() {
		defer close(winch)
		for request := range requests {
			log.WithField("request", request.Type).Debug("new session OOB request")
			var ok bool
			switch request.Type {
			case requestTypeShell:
				ok = true
				timeout.Stop()
				release()
			case requestTypePTY:
				var err error
				if ok, err = client.handlePTYRequest(request); err != nil {
					log.WithError(err).Debug("pty request refused")
				}
				timeout.Stop()
				release()
			case requestTypeWindowChange:
				payload := new(WindowChange)
				if err := ssh.Unmarshal(request.Payload, payload); err == nil {
					ok = true
					select {
					case winch <- payload:
					default:
					}
				}
			default:
				log.Debug("unknown session OOB request")
			}
			if request.WantReply {
				_ = request.Reply(ok, nil)
			}
		}
		log.Debug("session request channel closed")
	}()

	<-ready

	// ptyRequest is written under requestMutex by handlePTYRequest, possibly
	// concurrently (e.g. a pty-req arriving after the timeout released us).
	client.requestMutex.RLock()
	pty := client.ptyRequest
	client.requestMutex.RUnlock()

	if client.SessionHandler == nil {
		log.Debug("no session handler, sending default shell")
		term := terminal.NewTerminal(channel, "")
		_, _ = fmt.Fprintln(term, "This SSH server does not support interactive sessions.")
		return
	}

	log.Debug("starting session handler")
	if err := client.SessionHandler(client.sshConn, channel, pty, winch); err != nil && !errors.Is(err, io.EOF) {
		log.WithError(err).Error("session terminated")
	}
}