package stronghold
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"strconv"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/sirupsen/logrus"
|
|
"golang.org/x/crypto/ssh"
|
|
"golang.org/x/crypto/ssh/agent"
|
|
"golang.org/x/crypto/ssh/terminal"
|
|
|
|
"git.maze.io/maze/stronghold/authority"
|
|
)
|
|
|
|
// Request types handled on a client connection. These are the wire names
// from RFC 4254 plus the OpenSSH "@openssh.com" extensions.
const (
	requestTypeAgent        = "auth-agent-req@openssh.com"
	requestTypeEnv          = "env"
	requestTypeKeepAlive    = "keepalive@openssh.com"
	requestTypePTY          = "pty-req"
	requestTypeSession      = "session"
	requestTypeShell        = "shell"
	requestTypeTCPIPForward = "tcpip-forward"
	requestTypeWindowChange = "window-change"
	requestTypeX11          = "x11-req"
)

// Channel types. "direct-tcpip" and "session" are opened by the client;
// "auth-agent@openssh.com" and "forwarded-tcpip" are opened by this side.
const (
	channelTypeAgent          = "auth-agent@openssh.com"
	channelTypeDirectTCPIP    = "direct-tcpip"
	channelTypeForwardedTCPIP = "forwarded-tcpip"
	channelTypeSession        = "session"
)

// Timing parameters.
const (
	// keepAliveInterval is how long the global request loop may sit idle
	// before it sends a keepalive request to the client.
	keepAliveInterval = 30 * time.Second

	// connectTimeout bounds dialling the upstream server, and also how long
	// a session waits for a pty/shell request before starting its handler.
	connectTimeout = 30 * time.Second
)
|
|
|
|
// clientConn is the server-side state for one SSH client connected to the
// server: the underlying network and SSH connections, plus the requests the
// client has made so far on that connection.
type clientConn struct {
	// CertificateAuthority supplies, via HostCertificate, the host keys
	// presented to the client when a proxied (direct-tcpip) SSH connection
	// is intercepted.
	CertificateAuthority *authority.CertificateAuthority

	// SessionHandler is a callback for SSH shell sessions. When nil, a
	// session channel only receives a notice that interactive sessions are
	// not supported.
	SessionHandler func(conn ssh.ConnMetadata, channel ssh.Channel, pty *PTY, windowChanges <-chan *WindowChange) error

	netConn net.Conn
	sshConn ssh.Conn
	// id is the session identifier, derived from sshConn by sessionID.
	id string
	// closed is closed once the client's new-channel stream ends; the
	// request loops watch it to stop.
	closed chan struct{}
	// recordDir is handed to the intercepting server of every proxied
	// connection; presumably where sessions are recorded — TODO confirm.
	recordDir string

	// requestMutex guards ptyRequest, sessionRequest and envRequest.
	requestMutex   sync.RWMutex
	ptyRequest     *PTY
	sessionRequest *sessionRequest
	envRequest     map[string]string

	// agent is the client's forwarded ssh-agent. It is opened lazily on
	// the first direct-tcpip channel and reused afterwards.
	agent agent.Agent
}
|
|
|
|
func newClient(netConn net.Conn, sshConn ssh.Conn, recordDir string) *clientConn {
|
|
return &clientConn{
|
|
netConn: netConn,
|
|
sshConn: sshConn,
|
|
id: sessionID(sshConn),
|
|
closed: make(chan struct{}),
|
|
recordDir: recordDir,
|
|
envRequest: make(map[string]string),
|
|
}
|
|
}
|
|
|
|
func (client *clientConn) log() *logrus.Entry {
|
|
return logrus.WithFields(logrus.Fields{
|
|
"tag": "client",
|
|
"session": client.id,
|
|
"src": client.sshConn.RemoteAddr().String(),
|
|
"user": client.sshConn.User(),
|
|
})
|
|
}
|
|
|
|
func (client *clientConn) handleRequests(requests <-chan *ssh.Request) {
|
|
log := client.log()
|
|
for {
|
|
select {
|
|
case <-client.closed:
|
|
return
|
|
|
|
case request := <-requests:
|
|
if request == nil {
|
|
log.Debug("client request channel closed")
|
|
return
|
|
}
|
|
|
|
var (
|
|
log = log.WithField("request", request.Type)
|
|
ok bool
|
|
err error
|
|
response []byte
|
|
)
|
|
log.Debug("new OOB request")
|
|
switch request.Type {
|
|
case requestTypeAgent:
|
|
ok, err = client.handleAgentRequest(request)
|
|
|
|
case requestTypeEnv:
|
|
ok, err = client.handleEnvRequest(request)
|
|
|
|
case requestTypeKeepAlive:
|
|
ok = true
|
|
|
|
case requestTypeTCPIPForward:
|
|
ok, err = client.handleTCPIPForwardRequest(request)
|
|
|
|
case requestTypePTY:
|
|
ok, err = client.handlePTYRequest(request)
|
|
|
|
case requestTypeSession:
|
|
ok, err = client.handleSessionRequest(request)
|
|
|
|
case requestTypeX11:
|
|
// Request type is ignored (not supported)
|
|
|
|
default:
|
|
log.Debug("unknown/unhandled request type")
|
|
}
|
|
|
|
if err != nil {
|
|
log.WithError(err).Warn("failed to handle request")
|
|
ok = false
|
|
}
|
|
|
|
if request.WantReply {
|
|
_ = request.Reply(ok, response)
|
|
}
|
|
|
|
case <-time.After(keepAliveInterval):
|
|
log.Debug("sending keepalive request")
|
|
if _, _, err := client.sshConn.SendRequest(requestTypeKeepAlive, false, nil); err != nil {
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (client *clientConn) handleAgentRequest(request *ssh.Request) (bool, error) {
|
|
return true, nil
|
|
}
|
|
|
|
// envRequest is the payload of an "env" request (RFC 4254 Section 6.4): a
// variable name followed by its value. ssh.Unmarshal decodes the fields in
// declaration order, so the order must match the wire layout.
type envRequest struct {
	Name  string // environment variable name
	Value string // environment variable value
}
|
|
|
|
func (client *clientConn) handleEnvRequest(request *ssh.Request) (ok bool, err error) {
|
|
var payload envRequest
|
|
if err = ssh.Unmarshal(request.Payload, &payload); err != nil {
|
|
return
|
|
}
|
|
client.requestMutex.Lock()
|
|
client.envRequest[payload.Name] = payload.Value
|
|
client.requestMutex.Unlock()
|
|
return true, nil
|
|
}
|
|
|
|
// PTY is the payload of a "pty-req" request (RFC 4254 Section 6.2).
// ssh.Unmarshal decodes the fields in declaration order, so the order must
// match the wire layout.
type PTY struct {
	Term     string // value of the TERM environment variable
	Columns  uint32 // terminal width, in characters
	Rows     uint32 // terminal height, in rows
	Width    uint32 // terminal width, in pixels
	Height   uint32 // terminal height, in pixels
	ModeList string // encoded terminal modes
}
|
|
|
|
// WindowChange is the payload of a "window-change" request (RFC 4254
// Section 6.7). Note that Width and Height are measured in characters and
// rows; the pixel dimensions are carried separately. ssh.Unmarshal decodes
// the fields in declaration order, so the order must match the wire layout.
type WindowChange struct {
	Width       uint32 // terminal width, in columns
	Height      uint32 // terminal height, in rows
	PixelWidth  uint32 // terminal width, in pixels
	PixelHeight uint32 // terminal height, in pixels
}
|
|
|
|
func (client *clientConn) handlePTYRequest(request *ssh.Request) (ok bool, err error) {
|
|
client.requestMutex.Lock()
|
|
defer client.requestMutex.Unlock()
|
|
|
|
if client.ptyRequest != nil {
|
|
err = errors.New("duplicate pty request")
|
|
return
|
|
}
|
|
|
|
client.ptyRequest = new(PTY)
|
|
if err = ssh.Unmarshal(request.Payload, client.ptyRequest); err != nil {
|
|
return
|
|
}
|
|
|
|
return true, nil
|
|
}
|
|
|
|
// sessionRequest is the payload decoded for a "session" request. Its fields
// mirror the sender-channel, initial-window-size and maximum-packet-size
// fields of a channel-open message (RFC 4254 Sections 5.1 and 6.1).
// ssh.Unmarshal decodes the fields in declaration order, so the order must
// match the wire layout.
type sessionRequest struct {
	SenderChannel uint32 // sender's channel number
	WindowSize    uint32 // initial window size
	MaxPacketSize uint32 // maximum packet size
}
|
|
|
|
func (client *clientConn) handleSessionRequest(request *ssh.Request) (ok bool, err error) {
|
|
client.requestMutex.Lock()
|
|
defer client.requestMutex.Unlock()
|
|
|
|
if client.sessionRequest != nil {
|
|
// Duplicate session request, only one is allowed.
|
|
err = errors.New("duplicate session request")
|
|
return
|
|
}
|
|
|
|
client.sessionRequest = new(sessionRequest)
|
|
if err = ssh.Unmarshal(request.Payload, client.sessionRequest); err != nil {
|
|
return
|
|
}
|
|
|
|
return true, nil
|
|
}
|
|
|
|
// forwardTCPIPRequest is the payload of a "tcpip-forward" global request
// (RFC 4254 Section 7.1), where the fields are called "address to bind" and
// "port number to bind". ssh.Unmarshal decodes the fields in declaration
// order, so the order must match the wire layout.
type forwardTCPIPRequest struct {
	TargetHost string // address the client asks the server to bind
	TargetPort uint32 // port the client asks the server to bind
}
|
|
|
|
// forwardedTCPIPResponse is the extra data sent when opening a
// "forwarded-tcpip" channel to the client (RFC 4254 Section 7.2). It holds
// the address and port that were connected, followed by the originator's
// address and port. It is encoded with ssh.Marshal in declaration order, so
// the field order must match the wire layout.
type forwardedTCPIPResponse struct {
	TargetHost string // address that was connected
	TargetPort uint32 // port that was connected
	OriginAddr string // originator IP address
	OriginPort uint32 // originator port
}
|
|
|
|
func (client *clientConn) handleTCPIPForwardRequest(request *ssh.Request) (ok bool, err error) {
|
|
var payload forwardTCPIPRequest
|
|
if err = ssh.Unmarshal(request.Payload, &payload); err != nil {
|
|
return
|
|
}
|
|
|
|
var (
|
|
remoteAddr = net.JoinHostPort(payload.TargetHost, strconv.FormatUint(uint64(payload.TargetPort), 10))
|
|
log = client.log().WithField("dst", remoteAddr)
|
|
)
|
|
log.Info("TCP/IP forward request")
|
|
|
|
var (
|
|
response = forwardedTCPIPResponse{
|
|
TargetHost: payload.TargetHost,
|
|
TargetPort: payload.TargetPort,
|
|
OriginAddr: "",
|
|
OriginPort: 0,
|
|
}
|
|
channel ssh.Channel
|
|
requests <-chan *ssh.Request
|
|
)
|
|
if channel, requests, err = client.sshConn.OpenChannel(channelTypeForwardedTCPIP, ssh.Marshal(&response)); err != nil {
|
|
return
|
|
}
|
|
_ = channel
|
|
go ssh.DiscardRequests(requests)
|
|
return true, nil
|
|
}
|
|
|
|
func (client *clientConn) handleNewChannels(channels <-chan ssh.NewChannel) error {
|
|
log := client.log()
|
|
for {
|
|
select {
|
|
case <-client.closed:
|
|
return io.EOF
|
|
|
|
case newChannel := <-channels:
|
|
if newChannel == nil {
|
|
close(client.closed)
|
|
return io.EOF
|
|
}
|
|
|
|
var (
|
|
channelType = newChannel.ChannelType()
|
|
log = log.WithField("channel", channelType)
|
|
)
|
|
log.Debug("new channel requested by client")
|
|
|
|
var channelErr error
|
|
switch channelType {
|
|
case channelTypeDirectTCPIP:
|
|
channelErr = client.handleDirectTCPIPChannel(newChannel)
|
|
|
|
case channelTypeSession:
|
|
channelErr = client.handleSessionChannelRequest(newChannel)
|
|
|
|
default:
|
|
log.Debug("unknown/unhandled channel type")
|
|
if err := newChannel.Reject(ssh.UnknownChannelType, "unknown channel type"); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
if channelErr == io.EOF {
|
|
log.Debug("end client channel")
|
|
} else if channelErr != nil {
|
|
log.WithError(channelErr).Warn("failed to handle new channel")
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// directTCPIP is the extra data of a "direct-tcpip" channel-open request
// (RFC 4254 Section 7.2). It holds the host and port to connect to,
// followed by the originator's address and port. ssh.Unmarshal decodes the
// fields in declaration order, so the order must match the wire layout.
type directTCPIP struct {
	TargetHost string // host to connect to
	TargetPort uint32 // port to connect to
	SourceIP   string // originator IP address
	SourcePort uint32 // originator port
}
|
|
|
|
func (client *clientConn) handleDirectTCPIPChannel(newChannel ssh.NewChannel) (err error) {
|
|
var (
|
|
log = client.log()
|
|
payload directTCPIP
|
|
request = newChannel.ExtraData()
|
|
)
|
|
if err = ssh.Unmarshal(request, &payload); err != nil {
|
|
log.WithError(err).Warn("invalid direct TCP/IP payload")
|
|
_ = newChannel.Reject(ssh.ConnectionFailed, "invalid direct TCP/IP payload")
|
|
return
|
|
}
|
|
|
|
if client.agent == nil {
|
|
var (
|
|
agentChannel ssh.Channel
|
|
agentRequests <-chan *ssh.Request
|
|
)
|
|
log.Debug("connecting to ssh-agent")
|
|
if agentChannel, agentRequests, err = client.sshConn.OpenChannel(channelTypeAgent, nil); err != nil {
|
|
log.WithError(err).Warn("error opening channel to ssh-agent")
|
|
return newChannel.Reject(ssh.ConnectionFailed, err.Error())
|
|
}
|
|
go ssh.DiscardRequests(agentRequests)
|
|
|
|
log.Debug("connected to ssh-agent")
|
|
client.agent = agent.NewClient(agentChannel)
|
|
}
|
|
|
|
// Connect as an SSH client to the requested endpoint. We do this first so
|
|
// that if the connecting fails, we can supply a reason in rejecting the
|
|
// new channel.
|
|
// Generate a host key for the target.
|
|
var serverKeys []ssh.Signer
|
|
if serverKeys, err = client.CertificateAuthority.HostCertificate(payload.TargetHost); err != nil {
|
|
return newChannel.Reject(ssh.ConnectionFailed, err.Error())
|
|
}
|
|
|
|
remote := newSSHClient(client.sshConn.User())
|
|
if err = remote.ConnectAgent(client.agent); err != nil {
|
|
return newChannel.Reject(ssh.Prohibited, err.Error())
|
|
}
|
|
addr := net.JoinHostPort(payload.TargetHost, strconv.FormatUint(uint64(payload.TargetPort), 10))
|
|
if err = remote.DialTimeout(addr, connectTimeout); err != nil {
|
|
return newChannel.Reject(ssh.ConnectionFailed, err.Error())
|
|
}
|
|
|
|
// Now that we are connected as an SSH client, accept the new channel.
|
|
var (
|
|
channel ssh.Channel
|
|
channelRequests <-chan *ssh.Request
|
|
)
|
|
if channel, channelRequests, err = newChannel.Accept(); err != nil {
|
|
return err
|
|
}
|
|
go ssh.DiscardRequests(channelRequests)
|
|
|
|
// Present a (fake) SSH server on the newly accepted channel.
|
|
server := newSSHInterceptingServer(client.recordDir)
|
|
server.AddHostKeys(serverKeys...)
|
|
|
|
// Negotiate SSH with the client (over the channel).
|
|
if err = server.Handshake(channelConn{Channel: channel, remoteAddr: client.netConn.RemoteAddr()}); err != nil {
|
|
_ = remote.Close()
|
|
_ = channel.Close()
|
|
return err
|
|
}
|
|
|
|
// Serve the proxied SSH traffic.
|
|
go server.ProxyClient(remote)
|
|
|
|
return
|
|
}
|
|
|
|
func (client *clientConn) handleSessionChannelRequest(newChannel ssh.NewChannel) error {
|
|
channel, requests, err := newChannel.Accept()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
go client.handleSession(channel, requests)
|
|
return nil
|
|
}
|
|
|
|
func (client *clientConn) handleSession(channel ssh.Channel, requests <-chan *ssh.Request) {
|
|
defer channel.Close()
|
|
|
|
var (
|
|
log = client.log()
|
|
winch = make(chan *WindowChange)
|
|
waiting sync.Mutex
|
|
)
|
|
waiting.Lock()
|
|
|
|
go func(requests <-chan *ssh.Request, winch chan<- *WindowChange, mutex *sync.Mutex) {
|
|
defer close(winch)
|
|
timeout := time.AfterFunc(connectTimeout, func() { mutex.Unlock() })
|
|
for {
|
|
select {
|
|
case request := <-requests:
|
|
if request == nil {
|
|
log.Debug("session request channel closed")
|
|
return
|
|
}
|
|
|
|
log.WithField("request", request.Type).Debug("new session OOB request")
|
|
var ok bool
|
|
switch request.Type {
|
|
case requestTypeShell:
|
|
ok = true
|
|
if timeout.Stop() {
|
|
mutex.Unlock()
|
|
}
|
|
|
|
case requestTypePTY:
|
|
ok, _ = client.handlePTYRequest(request)
|
|
if timeout.Stop() {
|
|
mutex.Unlock()
|
|
}
|
|
|
|
case requestTypeWindowChange:
|
|
var payload = new(WindowChange)
|
|
if err := ssh.Unmarshal(request.Payload, payload); err == nil {
|
|
ok = true
|
|
select {
|
|
case winch <- payload:
|
|
default:
|
|
}
|
|
}
|
|
|
|
default:
|
|
log.Debug("unknown session OOB request")
|
|
}
|
|
|
|
if request.WantReply {
|
|
_ = request.Reply(ok, nil)
|
|
}
|
|
|
|
case <-timeout.C:
|
|
log.Debug("timeout waiting for a pty request")
|
|
return
|
|
}
|
|
}
|
|
}(requests, winch, &waiting)
|
|
|
|
waiting.Lock()
|
|
|
|
if client.SessionHandler != nil {
|
|
log.Debug("starting session handler")
|
|
if err := client.SessionHandler(client.sshConn, channel, client.ptyRequest, winch); err != nil && err != io.EOF {
|
|
log.WithError(err).Error("session terminated")
|
|
}
|
|
} else {
|
|
log.Debug("no session handler, sending default shell")
|
|
term := terminal.NewTerminal(channel, "")
|
|
_, _ = fmt.Fprintln(term, "This SSH server does not support interactive sessions.")
|
|
}
|
|
}
|