Transparent SSH jump host with auditing.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

319 lines
8.1 KiB

package stronghold
import (
"encoding/hex"
"io"
"net"
"path/filepath"
"time"
"maze.io/x/ttyrec"
"github.com/sirupsen/logrus"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/agent"
)
// channelConn allows you to use an ssh channel as net.Conn.
//
// The embedded ssh.Channel supplies the data-transfer methods; the address
// and deadline methods that net.Conn additionally requires are defined on
// this type.
type channelConn struct {
ssh.Channel
// localAddr is the value reported by LocalAddr.
localAddr net.Addr
// remoteAddr is the value reported by RemoteAddr.
remoteAddr net.Addr
}
// LocalAddr returns the local address stored in the wrapper.
func (c channelConn) LocalAddr() net.Addr {
	return c.localAddr
}

// RemoteAddr returns the remote address stored in the wrapper.
func (c channelConn) RemoteAddr() net.Addr {
	return c.remoteAddr
}
// SetDeadline is a no-op: the argument is ignored and nil is always returned.
func (channelConn) SetDeadline(time.Time) error {
	return nil
}

// SetReadDeadline is a no-op: the argument is ignored and nil is always
// returned.
func (channelConn) SetReadDeadline(time.Time) error {
	return nil
}

// SetWriteDeadline is a no-op: the argument is ignored and nil is always
// returned.
func (channelConn) SetWriteDeadline(time.Time) error {
	return nil
}
// sshClient holds the outbound SSH connection of the proxy together with the
// channel and request streams that belong to it.
//
// Conn, NewChannels, Requests and netConn are populated by DialTimeout;
// config is created by newSSHClient and extended by ConnectAgent.
type sshClient struct {
// Conn is the SSH connection returned by ssh.NewClientConn; nil until
// DialTimeout succeeds and again after Close.
ssh.Conn
// NewChannels delivers channels opened by the remote SSH server.
NewChannels <-chan ssh.NewChannel
// Requests delivers global requests sent by the remote SSH server.
Requests <-chan *ssh.Request
// netConn is the TCP connection underneath Conn.
netConn net.Conn
// config is the client configuration used for the SSH handshake.
config *ssh.ClientConfig
}
// newSSHClient returns an unconnected sshClient that will authenticate as
// user.
//
// The configuration starts without any auth methods (ConnectAgent appends
// them) and uses ssh.InsecureIgnoreHostKey, so the host key of the server
// that is dialed later is not verified.
func newSSHClient(user string) *sshClient {
	config := &ssh.ClientConfig{
		User:            user,
		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
		ClientVersion:   version,
		Timeout:         connectTimeout,
	}
	return &sshClient{config: config}
}
// Close shuts down the SSH connection and the TCP connection underneath it,
// then clears both fields. It returns the error from closing the SSH
// connection; calling it on a client without an SSH connection is a no-op
// that returns nil.
func (client *sshClient) Close() error {
	if client.Conn == nil {
		// Never connected, or already closed.
		return nil
	}

	closeErr := client.Conn.Close()
	// Best effort: only the SSH-level close error is reported to the caller.
	_ = client.netConn.Close()
	client.netConn, client.Conn = nil, nil

	return closeErr
}
// ConnectAgent asks agent for its signers and appends them to the client's
// auth methods as a public-key method. The signer list is fetched once, at
// call time. If the agent cannot list its signers, the error is returned and
// the configuration is left unchanged.
func (client *sshClient) ConnectAgent(agent agent.Agent) error {
	keys, err := agent.Signers()
	if err == nil {
		client.config.Auth = append(client.config.Auth, ssh.PublicKeys(keys...))
	}
	return err
}
// DialTimeout opens a TCP connection to address and performs the SSH client
// handshake on it with the client's configuration.
//
// timeout bounds the TCP connect and, separately, the SSH handshake; a zero
// timeout leaves both unbounded. On success Conn, NewChannels, Requests and
// netConn are populated. On failure the TCP connection (if one was opened) is
// closed, netConn is reset to nil and the error is returned.
func (client *sshClient) DialTimeout(address string, timeout time.Duration) error {
	log := logrus.WithFields(logrus.Fields{
		"dst":     address,
		"timeout": timeout,
	})
	log.Info("connecting to SSH server")

	var err error
	if client.netConn, err = net.DialTimeout("tcp", address, timeout); err != nil {
		return err
	}

	// ssh.NewClientConn applies no timeout of its own (ClientConfig.Timeout
	// only covers the TCP connect done by ssh.Dial), so bound the handshake
	// with a deadline on the raw connection; otherwise a peer that accepts
	// the TCP connection but never speaks SSH blocks this call forever.
	if timeout > 0 {
		if err = client.netConn.SetDeadline(time.Now().Add(timeout)); err != nil {
			_ = client.netConn.Close()
			client.netConn = nil
			return err
		}
	}

	log.Debug("doing SSH client handshake")
	if client.Conn, client.NewChannels, client.Requests, err = ssh.NewClientConn(client.netConn, address, client.config); err != nil {
		log.WithError(err).Error("connection failed")
		_ = client.netConn.Close()
		client.netConn = nil
		return err
	}

	// The deadline must not outlive the handshake, or the established
	// session would be cut off once it expires.
	if timeout > 0 {
		if err = client.netConn.SetDeadline(time.Time{}); err != nil {
			log.WithError(err).Error("connection failed")
			_ = client.Close()
			return err
		}
	}
	return nil
}
// sshInterceptingServer is the inbound side of the proxy: it terminates the
// SSH connection of a connecting peer and relays its channels and requests
// to an sshClient while recording session output.
type sshInterceptingServer struct {
// ServerConn is the SSH connection produced by Handshake.
*ssh.ServerConn
// NewChannels delivers channels opened by the connected peer.
NewChannels <-chan ssh.NewChannel
// Requests delivers global requests sent by the connected peer.
Requests <-chan *ssh.Request
// config is the server configuration used by Handshake; host keys are
// added to it through AddHostKeys.
config *ssh.ServerConfig
// netConn is the network connection handed to Handshake.
netConn net.Conn
// root is the directory in which ProxyClient creates the transcript file.
root string
// recorder is set by ProxyClient; multiplex additionally writes the data
// relayed from the destination on "session" channels to it.
recorder io.Writer
}
// newSSHInterceptingServer returns a server whose transcripts are written
// below root.
//
// The SSH configuration enables NoClientAuth, allows up to 128 authentication
// attempts and installs PermitAllPasswords / PermitAllPublicKeys as the
// password and public-key callbacks. No host keys are configured yet; add
// them with AddHostKeys.
func newSSHInterceptingServer(root string) *sshInterceptingServer {
	config := &ssh.ServerConfig{
		NoClientAuth:      true,
		MaxAuthTries:      128,
		PasswordCallback:  PermitAllPasswords,
		PublicKeyCallback: PermitAllPublicKeys,
		ServerVersion:     version,
	}
	return &sshInterceptingServer{
		root:   root,
		config: config,
	}
}
// AddHostKeys registers every signer, in argument order, as a host key on the
// server configuration.
func (server *sshInterceptingServer) AddHostKeys(signers ...ssh.Signer) {
	for i := range signers {
		server.config.AddHostKey(signers[i])
	}
}
// Close closes the network connection (when one is held) and then the SSH
// server connection (when one exists). Only the error from closing the SSH
// connection is returned; without an SSH connection the result is nil.
func (server *sshInterceptingServer) Close() error {
	if raw := server.netConn; raw != nil {
		// Best effort: the error of the raw close is intentionally dropped.
		_ = raw.Close()
		server.netConn = nil
	}

	if server.ServerConn == nil {
		return nil
	}
	return server.ServerConn.Close()
}
// Handshake runs the server side of the SSH handshake on conn using the
// server configuration.
//
// On success ServerConn, NewChannels, Requests and netConn are set. On
// failure the handshake error is returned; conn is not closed here and
// netConn is left untouched.
func (server *sshInterceptingServer) Handshake(conn net.Conn) error {
	logrus.WithField("dst", conn.RemoteAddr()).Info("starting SSH server connection")

	var err error
	server.ServerConn, server.NewChannels, server.Requests, err = ssh.NewServerConn(conn, server.config)
	if err != nil {
		return err
	}

	server.netConn = conn
	return nil
}
// ProxyClient relays traffic between the peer connected to server and the
// already-dialed client until one of the two connections goes away.
//
// A transcript file named after the hex-encoded session ID of client is
// created below server.root with a ".ttyrec" suffix, and a ttyrec encoder
// writing to it is installed as server.recorder. The method then dispatches
// new channels and global requests from either side to the opposite side.
// It returns as soon as any of the four streams yields nil (i.e. is closed);
// the deferred calls then close the transcript, the server connection and
// the client connection, in that order.
func (server *sshInterceptingServer) ProxyClient(client *sshClient) {
	defer client.Close()
	defer server.Close()

	log := logrus.WithFields(logrus.Fields{
		"src": server.RemoteAddr(),
		"dst": client.RemoteAddr(),
		"tag": "proxy",
	})

	transcript := newSafeFileWriter(filepath.Join(server.root, hex.EncodeToString(client.SessionID()))+".ttyrec", 0600)
	server.recorder = ttyrec.NewEncoder(transcript)
	defer func() {
		_ = transcript.Close()
	}()

	for {
		select {
		case newChannel := <-server.NewChannels:
			if newChannel == nil {
				return
			}
			log := log.WithFields(logrus.Fields{
				"channel": newChannel.ChannelType(),
				"flow":    "server->client",
			})
			// Channel opened by the connected peer: open its twin on the
			// outbound client connection.
			server.multiplexNewChannel(log, client, newChannel)

		case newChannel := <-client.NewChannels:
			if newChannel == nil {
				return
			}
			log := log.WithFields(logrus.Fields{
				"channel": newChannel.ChannelType(),
				"flow":    "client->server",
			})
			// Channel opened by the destination: open its twin on the
			// inbound server connection.
			server.multiplexNewChannel(log, server, newChannel)

		case request := <-server.Requests:
			if request == nil {
				return
			}
			log := log.WithFields(logrus.Fields{
				"request": request.Type,
				"flow":    "server->client",
			})
			server.relayRequest(log, client, request)

		case request := <-client.Requests:
			if request == nil {
				return
			}
			// Logged under "request" (not "channel") so both global-request
			// branches use the same field name for the request type.
			log := log.WithFields(logrus.Fields{
				"request": request.Type,
				"flow":    "client->server",
			})
			server.relayRequest(log, server, request)
		}
	}
}
// multiplexNewChannel mirrors an incoming channel-open onto conn.
//
// A channel of the same type and with the same extra data is opened on conn
// first. If that fails, newChannel is rejected with the reason and message
// of the ssh.OpenChannelError when available, otherwise with ssh.Prohibited
// and the error text. If opening succeeds but accepting newChannel fails,
// the freshly opened channel is closed and its requests are discarded. When
// both ends exist, a goroutine running multiplex takes over the pair.
func (server *sshInterceptingServer) multiplexNewChannel(log *logrus.Entry, conn ssh.Conn, newChannel ssh.NewChannel) {
	log.Debug("new channel")

	channelType := newChannel.ChannelType()

	clientChannel, clientRequests, err := conn.OpenChannel(channelType, newChannel.ExtraData())
	if err != nil {
		log.WithError(err).Warn("failed to open channel")
		switch openErr := err.(type) {
		case *ssh.OpenChannelError:
			// Pass the far side's own rejection through unchanged.
			_ = newChannel.Reject(openErr.Reason, openErr.Message)
		default:
			_ = newChannel.Reject(ssh.Prohibited, err.Error())
		}
		return
	}

	serverChannel, serverRequests, err := newChannel.Accept()
	if err != nil {
		log.WithError(err).Warn("failed to accept channel")
		// The outbound half is already open; drain and close it again.
		go ssh.DiscardRequests(clientRequests)
		_ = clientChannel.Close()
		return
	}

	go server.multiplex(log, channelType, clientChannel, serverChannel, clientRequests, serverRequests)
}
// multiplex relays data and channel requests between a pair of channels
// until the pair is finished, then closes both.
//
// clientChannel/clientRequests are the half opened towards the destination;
// serverChannel/serverRequests are the accepted half. Data flows in both
// directions; for channels of type "session" the data coming from
// clientChannel is additionally written to server.recorder.
//
// End-of-stream is propagated: when one direction reaches EOF, the write
// side of the opposite channel is closed so the peer observes the half-close
// (for example the end of stdin).
//
// Shutdown: when serverRequests is closed the function returns immediately.
// When clientRequests is closed the function keeps servicing serverRequests
// until the data still pending on clientChannel has been relayed, so the
// tail of the output is neither lost for the peer nor for the recording.
//
// NOTE(review): extended data (Channel.Stderr) is not relayed here.
func (server *sshInterceptingServer) multiplex(log *logrus.Entry, channelType string, clientChannel, serverChannel ssh.Channel, clientRequests, serverRequests <-chan *ssh.Request) {
	var serverWriters io.Writer
	switch channelType {
	case "session":
		serverWriters = io.MultiWriter(serverChannel, server.recorder)
	default:
		serverWriters = serverChannel
	}

	go func() {
		if _, err := io.Copy(clientChannel, serverChannel); err == nil {
			// Source reached EOF: forward the half-close. The error is
			// ignored because the channel may already be closed.
			_ = clientChannel.CloseWrite()
		}
	}()

	// drained is closed once everything readable from clientChannel has been
	// relayed (or relaying failed).
	drained := make(chan struct{})
	go func() {
		defer close(drained)
		if _, err := io.Copy(serverWriters, clientChannel); err == nil {
			// Source reached EOF: forward the half-close. The error is
			// ignored because the channel may already be closed.
			_ = serverChannel.CloseWrite()
		}
	}()

	defer func() {
		_ = clientChannel.Close()
		_ = serverChannel.Close()
	}()

	// clientDone stays nil (never ready in the select) until clientRequests
	// has been closed; from then on it is the drained signal.
	var clientDone <-chan struct{}

	for {
		select {
		case request := <-clientRequests:
			if request == nil {
				// The destination closed its half. Do not tear the pair
				// down yet: buffered output may still be in flight. A nil
				// channel disables this case for the rest of the loop.
				clientRequests = nil
				clientDone = drained
				continue
			}
			log := log.WithFields(logrus.Fields{
				"request": request.Type,
				"flow":    "client->server",
			})
			log.Debug("new inbound channel request")
			server.relayChannelRequest(log, serverChannel, request)

		case request := <-serverRequests:
			if request == nil {
				return
			}
			log := log.WithFields(logrus.Fields{
				"request": request.Type,
				"flow":    "server->client",
			})
			server.relayChannelRequest(log, clientChannel, request)

		case <-clientDone:
			return
		}
	}
}
// relayRequest forwards a global request to conn and, when the requester
// asked for a reply, answers it.
//
// On success the reply carries the status and payload returned by conn. If
// forwarding fails, the requester still receives a negative reply so it is
// not left waiting for an answer that would never come. Errors are logged,
// not returned.
func (server *sshInterceptingServer) relayRequest(log *logrus.Entry, conn ssh.Conn, request *ssh.Request) {
	log.Debug("request")

	ok, payload, err := conn.SendRequest(request.Type, request.WantReply, request.Payload)
	if err != nil {
		log.WithError(err).Warn("failed to relay request")
		if request.WantReply {
			// Every request with WantReply set must be answered.
			if replyErr := request.Reply(false, nil); replyErr != nil {
				log.WithError(replyErr).Warn("failed to reply to relayed request")
			}
		}
		return
	}

	log.WithFields(logrus.Fields{
		"ok":         ok,
		"want_reply": request.WantReply,
	}).Debug("request response")

	if request.WantReply {
		if err = request.Reply(ok, payload); err != nil {
			log.WithError(err).Warn("failed to reply to relayed request")
		}
	}
}
// relayChannelRequest forwards a channel request to channel and, when the
// requester asked for a reply, answers it with the status returned by
// channel.
//
// If forwarding fails, the requester still receives a negative reply so it
// is not left waiting for an answer that would never come. Errors are
// logged, not returned.
func (server *sshInterceptingServer) relayChannelRequest(log *logrus.Entry, channel ssh.Channel, request *ssh.Request) {
	log.Debug("channel request")

	ok, err := channel.SendRequest(request.Type, request.WantReply, request.Payload)
	if err != nil {
		log.WithError(err).Warn("failed to relay channel request")
		if request.WantReply {
			// Every request with WantReply set must be answered.
			if replyErr := request.Reply(false, nil); replyErr != nil {
				log.WithError(replyErr).Warn("failed to reply to relayed channel request")
			}
		}
		return
	}

	log.WithFields(logrus.Fields{
		"ok":         ok,
		"want_reply": request.WantReply,
	}).Debug("channel request response")

	if request.WantReply {
		if err = request.Reply(ok, nil); err != nil {
			log.WithError(err).Warn("failed to reply to relayed channel request")
		}
	}
}