package stronghold
|
|
|
|
import (
|
|
"encoding/hex"
|
|
"io"
|
|
"net"
|
|
"path/filepath"
|
|
"time"
|
|
|
|
"maze.io/x/ttyrec"
|
|
|
|
"github.com/sirupsen/logrus"
|
|
"golang.org/x/crypto/ssh"
|
|
"golang.org/x/crypto/ssh/agent"
|
|
)
|
|
|
|
// channelConn allows you to use an ssh channel as net.Conn
|
|
type channelConn struct {
|
|
ssh.Channel
|
|
localAddr net.Addr
|
|
remoteAddr net.Addr
|
|
}
|
|
|
|
func (c channelConn) LocalAddr() net.Addr { return c.localAddr }
|
|
func (c channelConn) RemoteAddr() net.Addr { return c.remoteAddr }
|
|
func (channelConn) SetDeadline(_ time.Time) error { return nil }
|
|
func (channelConn) SetReadDeadline(_ time.Time) error { return nil }
|
|
func (channelConn) SetWriteDeadline(_ time.Time) error { return nil }
|
|
|
|
type sshClient struct {
|
|
ssh.Conn
|
|
NewChannels <-chan ssh.NewChannel
|
|
Requests <-chan *ssh.Request
|
|
netConn net.Conn
|
|
config *ssh.ClientConfig
|
|
}
|
|
|
|
func newSSHClient(user string) *sshClient {
|
|
return &sshClient{
|
|
config: &ssh.ClientConfig{
|
|
User: user,
|
|
Auth: nil,
|
|
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
|
ClientVersion: version,
|
|
Timeout: connectTimeout,
|
|
},
|
|
}
|
|
}
|
|
|
|
func (client *sshClient) Close() error {
|
|
if client.Conn == nil {
|
|
return nil
|
|
}
|
|
|
|
err := client.Conn.Close()
|
|
_ = client.netConn.Close()
|
|
client.netConn = nil
|
|
client.Conn = nil
|
|
return err
|
|
}
|
|
|
|
func (client *sshClient) ConnectAgent(agent agent.Agent) error {
|
|
signers, err := agent.Signers()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
client.config.Auth = append(client.config.Auth, ssh.PublicKeys(signers...))
|
|
return nil
|
|
}
|
|
|
|
func (client *sshClient) DialTimeout(address string, timeout time.Duration) error {
|
|
log := logrus.WithFields(logrus.Fields{
|
|
"dst": address,
|
|
"timeout": timeout,
|
|
})
|
|
log.Info("connecting to SSH server")
|
|
|
|
var err error
|
|
if client.netConn, err = net.DialTimeout("tcp", address, timeout); err != nil {
|
|
return err
|
|
}
|
|
|
|
log.Debug("doing SSH client handshake")
|
|
if client.Conn, client.NewChannels, client.Requests, err = ssh.NewClientConn(client.netConn, address, client.config); err != nil {
|
|
log.WithError(err).Error("connection failed")
|
|
_ = client.netConn.Close()
|
|
client.netConn = nil
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
type sshInterceptingServer struct {
|
|
*ssh.ServerConn
|
|
NewChannels <-chan ssh.NewChannel
|
|
Requests <-chan *ssh.Request
|
|
config *ssh.ServerConfig
|
|
netConn net.Conn
|
|
root string
|
|
recorder io.Writer
|
|
}
|
|
|
|
func newSSHInterceptingServer(root string) *sshInterceptingServer {
|
|
return &sshInterceptingServer{
|
|
config: &ssh.ServerConfig{
|
|
Config: ssh.Config{},
|
|
NoClientAuth: true,
|
|
MaxAuthTries: 128,
|
|
PasswordCallback: PermitAllPasswords,
|
|
PublicKeyCallback: PermitAllPublicKeys,
|
|
ServerVersion: version,
|
|
},
|
|
root: root,
|
|
}
|
|
}
|
|
|
|
func (server *sshInterceptingServer) AddHostKeys(signers ...ssh.Signer) {
|
|
for _, signer := range signers {
|
|
server.config.AddHostKey(signer)
|
|
}
|
|
}
|
|
|
|
func (server *sshInterceptingServer) Close() error {
|
|
if server.netConn != nil {
|
|
_ = server.netConn.Close()
|
|
server.netConn = nil
|
|
}
|
|
if server.ServerConn != nil {
|
|
return server.ServerConn.Close()
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (server *sshInterceptingServer) Handshake(conn net.Conn) error {
|
|
var err error
|
|
log := logrus.WithFields(logrus.Fields{
|
|
"dst": conn.RemoteAddr(),
|
|
})
|
|
log.Info("starting SSH server connection")
|
|
if server.ServerConn, server.NewChannels, server.Requests, err = ssh.NewServerConn(conn, server.config); err != nil {
|
|
return err
|
|
}
|
|
server.netConn = conn
|
|
return nil
|
|
}
|
|
|
|
func (server *sshInterceptingServer) ProxyClient(client *sshClient) {
|
|
defer client.Close()
|
|
defer server.Close()
|
|
|
|
log := logrus.WithFields(logrus.Fields{
|
|
"src": server.RemoteAddr(),
|
|
"dst": client.RemoteAddr(),
|
|
"tag": "proxy",
|
|
})
|
|
|
|
transcript := newSafeFileWriter(filepath.Join(server.root, hex.EncodeToString(client.SessionID()))+".ttyrec", 0600)
|
|
server.recorder = ttyrec.NewEncoder(transcript)
|
|
defer func() {
|
|
_ = transcript.Close()
|
|
}()
|
|
|
|
for {
|
|
select {
|
|
case newChannel := <-server.NewChannels:
|
|
if newChannel == nil {
|
|
return
|
|
}
|
|
log := log.WithFields(logrus.Fields{
|
|
"channel": newChannel.ChannelType(),
|
|
"flow": "server->client",
|
|
})
|
|
server.multiplexNewChannel(log, client, newChannel)
|
|
|
|
case newChannel := <-client.NewChannels:
|
|
if newChannel == nil {
|
|
return
|
|
}
|
|
log := log.WithFields(logrus.Fields{
|
|
"channel": newChannel.ChannelType(),
|
|
"flow": "client->server",
|
|
})
|
|
server.multiplexNewChannel(log, server, newChannel)
|
|
|
|
case request := <-server.Requests:
|
|
if request == nil {
|
|
return
|
|
}
|
|
log := log.WithFields(logrus.Fields{
|
|
"request": request.Type,
|
|
"flow": "server->client",
|
|
})
|
|
server.relayRequest(log, client, request)
|
|
|
|
case request := <-client.Requests:
|
|
if request == nil {
|
|
return
|
|
}
|
|
log := log.WithFields(logrus.Fields{
|
|
"channel": request.Type,
|
|
"flow": "client->server",
|
|
})
|
|
server.relayRequest(log, server, request)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (server *sshInterceptingServer) multiplexNewChannel(log *logrus.Entry, conn ssh.Conn, newChannel ssh.NewChannel) {
|
|
log.Debug("new channel")
|
|
|
|
var (
|
|
clientChannel, serverChannel ssh.Channel
|
|
clientRequests, serverRequests <-chan *ssh.Request
|
|
err error
|
|
)
|
|
if clientChannel, clientRequests, err = conn.OpenChannel(newChannel.ChannelType(), newChannel.ExtraData()); err != nil {
|
|
log.WithError(err).Warn("failed to open channel")
|
|
if err, ok := err.(*ssh.OpenChannelError); ok {
|
|
_ = newChannel.Reject(err.Reason, err.Message)
|
|
} else {
|
|
_ = newChannel.Reject(ssh.Prohibited, err.Error())
|
|
}
|
|
return
|
|
}
|
|
|
|
if serverChannel, serverRequests, err = newChannel.Accept(); err != nil {
|
|
log.WithError(err).Warn("failed to accept channel")
|
|
go ssh.DiscardRequests(clientRequests)
|
|
_ = clientChannel.Close()
|
|
return
|
|
}
|
|
|
|
go server.multiplex(log, newChannel.ChannelType(), clientChannel, serverChannel, clientRequests, serverRequests)
|
|
}
|
|
|
|
func (server *sshInterceptingServer) multiplex(log *logrus.Entry, channelType string, clientChannel, serverChannel ssh.Channel, clientRequests, serverRequests <-chan *ssh.Request) {
|
|
var serverWriters io.Writer
|
|
switch channelType {
|
|
case "session":
|
|
serverWriters = io.MultiWriter(serverChannel, server.recorder)
|
|
default:
|
|
serverWriters = serverChannel
|
|
}
|
|
|
|
go func() {
|
|
_, _ = io.Copy(clientChannel, serverChannel)
|
|
}()
|
|
go func() {
|
|
_, _ = io.Copy(serverWriters, clientChannel)
|
|
}()
|
|
|
|
defer func() {
|
|
_ = clientChannel.Close()
|
|
_ = serverChannel.Close()
|
|
}()
|
|
|
|
for {
|
|
select {
|
|
case request := <-clientRequests:
|
|
if request == nil {
|
|
return
|
|
}
|
|
log := log.WithFields(logrus.Fields{
|
|
"request": request.Type,
|
|
"flow": "client->server",
|
|
})
|
|
log.WithField("request", request.Type).Debug("new inbound channel request")
|
|
server.relayChannelRequest(log, serverChannel, request)
|
|
|
|
case request := <-serverRequests:
|
|
if request == nil {
|
|
return
|
|
}
|
|
log := log.WithFields(logrus.Fields{
|
|
"request": request.Type,
|
|
"flow": "server->client",
|
|
})
|
|
server.relayChannelRequest(log, clientChannel, request)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (server *sshInterceptingServer) relayRequest(log *logrus.Entry, conn ssh.Conn, request *ssh.Request) {
|
|
log.Debug("request")
|
|
ok, payload, err := conn.SendRequest(request.Type, request.WantReply, request.Payload)
|
|
if err != nil {
|
|
log.WithError(err).Warn("failed to relay request")
|
|
return
|
|
}
|
|
log.WithFields(logrus.Fields{
|
|
"ok": ok,
|
|
"want_reply": request.WantReply,
|
|
}).Debug("request response")
|
|
if request.WantReply {
|
|
if err = request.Reply(ok, payload); err != nil {
|
|
log.WithError(err).Warn("failed to reply to relayed request")
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
func (server *sshInterceptingServer) relayChannelRequest(log *logrus.Entry, channel ssh.Channel, request *ssh.Request) {
|
|
log.Debug("channel request")
|
|
ok, err := channel.SendRequest(request.Type, request.WantReply, request.Payload)
|
|
if err != nil {
|
|
log.WithError(err).Warn("failed to relay channel request")
|
|
return
|
|
}
|
|
log.WithFields(logrus.Fields{
|
|
"ok": ok,
|
|
"want_reply": request.WantReply,
|
|
}).Debug("channel request response")
|
|
if request.WantReply {
|
|
if err = request.Reply(ok, nil); err != nil {
|
|
log.WithError(err).Warn("failed to reply to relayed channel request")
|
|
}
|
|
}
|
|
}
|