353 lines
10 KiB
Go
353 lines
10 KiB
Go
package ssh
|
|
|
|
import (
|
|
"crypto"
|
|
"crypto/rand"
|
|
"crypto/rsa"
|
|
"encoding/hex"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"sort"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"golang.org/x/crypto/ssh"
|
|
"golang.org/x/crypto/ssh/agent"
|
|
|
|
"git.maze.io/maze/conduit/internal/netutil"
|
|
"git.maze.io/maze/conduit/internal/stringutil"
|
|
"git.maze.io/maze/conduit/logger"
|
|
"git.maze.io/maze/conduit/ssh/sshutil"
|
|
)
|
|
|
|
// ConnectHandler processes a network connection and hands back the connection
// that should be used from then on, or an error.
//
// NOTE(review): presumably invoked for every freshly accepted connection
// before the SSH handshake starts — confirm against the server code.
type ConnectHandler interface {
	// HandleConnect receives conn and returns a (possibly wrapped)
	// connection, or an error.
	HandleConnect(conn net.Conn) (net.Conn, error)
}
|
|
|
|
// ConnectHandlerFunc adapts a plain function to the ConnectHandler interface.
type ConnectHandlerFunc func(net.Conn) (net.Conn, error)

// HandleConnect calls f(c) and returns its results unchanged.
func (f ConnectHandlerFunc) HandleConnect(c net.Conn) (net.Conn, error) {
	conn, err := f(c)
	return conn, err
}
|
|
|
|
type AcceptHandler interface {
|
|
HandleAccept(net.Conn, *ssh.ServerConfig) (*ssh.ServerConn, <-chan ssh.NewChannel, <-chan *ssh.Request, error)
|
|
}
|
|
|
|
// MaxConnections limits the maximum number of connections.
|
|
func MaxConnections(max int) ConnectHandler {
|
|
var connections atomic.Int64
|
|
return ConnectHandlerFunc(func(c net.Conn) (net.Conn, error) {
|
|
if max <= 0 {
|
|
return nil, nil
|
|
} else if connections.Load() >= int64(max) {
|
|
return nil, errors.New("server: maximum number of connections reached")
|
|
}
|
|
|
|
var (
|
|
once sync.Once
|
|
cc = &netutil.ConnCloser{
|
|
Conn: c,
|
|
Closer: func() error {
|
|
once.Do(func() { connections.Add(-1) })
|
|
return c.Close()
|
|
},
|
|
}
|
|
)
|
|
connections.Add(1)
|
|
return cc, nil
|
|
})
|
|
}
|
|
|
|
type ChannelHandler interface {
|
|
HandleChannel(Context, ssh.Channel, <-chan *ssh.Request, []byte) error
|
|
}
|
|
|
|
type ChannelHandlerFunc func(Context, ssh.Channel, <-chan *ssh.Request, []byte) error
|
|
|
|
func (f ChannelHandlerFunc) HandleChannel(ctx Context, channel ssh.Channel, requests <-chan *ssh.Request, extra []byte) error {
|
|
return f(ctx, channel, requests, extra)
|
|
}
|
|
|
|
type debugSessionInfo struct {
|
|
start time.Time
|
|
duration time.Duration
|
|
method string
|
|
agent agent.Agent
|
|
agentRequest bool
|
|
agentChannel ssh.Channel
|
|
agentError error
|
|
env map[string]string
|
|
pty *sshutil.PTYRequest
|
|
windowChange *sshutil.WindowChangeRequest
|
|
unsupported []string
|
|
}
|
|
|
|
func newDebugSessionInfo() *debugSessionInfo {
|
|
return &debugSessionInfo{
|
|
start: time.Now(),
|
|
env: make(map[string]string),
|
|
}
|
|
}
|
|
|
|
// DebugSession is a session channel handler that will print debug information to the client, which
|
|
// may aid with troubleshooting SSH connectivity or client issues.
|
|
func DebugSession() ChannelHandler {
|
|
debugPrivateKey, _ := rsa.GenerateKey(rand.Reader, 1024)
|
|
|
|
return ChannelHandlerFunc(func(ctx Context, channel ssh.Channel, requests <-chan *ssh.Request, _ []byte) error {
|
|
defer channel.Close()
|
|
|
|
var (
|
|
log = ctx.(*sshContext).log
|
|
done = make(chan struct{}, 1)
|
|
info = newDebugSessionInfo()
|
|
conn = ctx.Conn()
|
|
err error
|
|
)
|
|
|
|
go func() {
|
|
var (
|
|
reply bool
|
|
reesponse []byte
|
|
)
|
|
for request := range requests {
|
|
log.Values(logger.Values{
|
|
"request": request.Type,
|
|
}).Trace("New session channel request")
|
|
switch request.Type {
|
|
case RequestTypeAgent:
|
|
info.agentRequest = true
|
|
agentChannel, agentRequests, err := conn.OpenChannel(ChannelTypeAgent, nil)
|
|
if err != nil {
|
|
info.agentError = err
|
|
} else {
|
|
go ssh.DiscardRequests(agentRequests)
|
|
info.agent = agent.NewClient(agentChannel)
|
|
info.agentChannel = agentChannel
|
|
}
|
|
reply = true
|
|
|
|
case RequestTypeEnv:
|
|
var payload *sshutil.EnvRequest
|
|
if payload, err = sshutil.ParseEnvRequest(request.Payload); err != nil {
|
|
log.Err(err).Debug("Corrupted env request payload, discarding")
|
|
} else {
|
|
log.Values(logger.Values{
|
|
"key": payload.Key,
|
|
"value": payload.Value,
|
|
}).Trace("Client requested env variable")
|
|
info.env[payload.Key] = payload.Value
|
|
}
|
|
reply = true
|
|
|
|
case RequestTypePTY:
|
|
if info.pty, err = sshutil.ParsePTYRequest(request.Payload); err != nil {
|
|
log.Err(err).Debug("Corrupted pty request payload, discarding")
|
|
} else {
|
|
log.Values(logger.Values{
|
|
"term": info.pty.Term,
|
|
"size": fmt.Sprintf("%dx%d", info.pty.Columns, info.pty.Rows),
|
|
}).Trace("Client requested PTY")
|
|
}
|
|
reply = true
|
|
|
|
case RequestTypeExec, RequestTypeShell:
|
|
info.method = request.Type
|
|
select {
|
|
case <-done:
|
|
default:
|
|
close(done)
|
|
}
|
|
//return
|
|
|
|
case RequestTypeWindowChange:
|
|
if info.windowChange, err = sshutil.ParseWindowChangeRequest(request.Payload); err != nil {
|
|
log.Err(err).Debug("Corrupted window change request payload, discarding")
|
|
} else {
|
|
log.Values(logger.Values{
|
|
"size": fmt.Sprintf("%dx%d", info.windowChange.Columns, info.windowChange.Rows),
|
|
}).Trace("Client requested window change")
|
|
}
|
|
reply = true
|
|
|
|
default:
|
|
log.Values(logger.Values{
|
|
"request": request.Type,
|
|
"payload": hex.EncodeToString(request.Payload),
|
|
}).Trace("Client requested something we don't understand, ignored")
|
|
info.unsupported = append(info.unsupported, request.Type)
|
|
}
|
|
|
|
if request.WantReply {
|
|
if err := request.Reply(reply, reesponse); err != nil {
|
|
log.Err(err).Debug("Error sending session channel request reply: terminating")
|
|
_ = channel.Close()
|
|
select {
|
|
case <-done:
|
|
default:
|
|
close(done)
|
|
}
|
|
return
|
|
}
|
|
}
|
|
}
|
|
select {
|
|
case <-done:
|
|
default:
|
|
close(done)
|
|
}
|
|
}()
|
|
|
|
select {
|
|
case <-done:
|
|
case <-time.After(10 * time.Second):
|
|
io.WriteString(channel, "Timeout waiting for your client to send either an exec or shell request\r\n")
|
|
}
|
|
|
|
// Any requests that follow are ignored from hereon forward.
|
|
go ssh.DiscardRequests(requests)
|
|
|
|
// Attempt to request agent forwarding
|
|
if info.agent == nil && !info.agentRequest {
|
|
agentChannel, agentRequests, err := conn.OpenChannel(ChannelTypeAgent, nil)
|
|
if err != nil {
|
|
info.agentError = err
|
|
} else {
|
|
go ssh.DiscardRequests(agentRequests)
|
|
info.agent = agent.NewClient(agentChannel)
|
|
info.agentChannel = agentChannel
|
|
}
|
|
}
|
|
|
|
info.duration = time.Since(info.start)
|
|
printSessionInfo(ctx, channel, info, debugPrivateKey)
|
|
|
|
return nil
|
|
})
|
|
}
|
|
|
|
func printSessionInfo(ctx Context, channel ssh.Channel, info *debugSessionInfo, key crypto.PrivateKey) {
|
|
if info.agentChannel != nil {
|
|
defer info.agentChannel.Close()
|
|
}
|
|
|
|
conn := ctx.Conn().(*ssh.ServerConn)
|
|
|
|
fmt.Fprintf(channel, "It took your client %s to request %s\r\n\r\n", info.duration, info.method)
|
|
fmt.Fprintf(channel, "SSH connection information:\r\n")
|
|
fmt.Fprintf(channel, " Server:\r\n")
|
|
fmt.Fprintf(channel, " Version: \x1b[1m%s\x1b[0m\r\n", conn.ServerVersion())
|
|
fmt.Fprintf(channel, " Client:\r\n")
|
|
fmt.Fprintf(channel, " Version: \x1b[1m%s\x1b[0m\r\n", conn.ClientVersion())
|
|
fmt.Fprintf(channel, " Username: \x1b[1m%s\x1b[0m\r\n", conn.User())
|
|
|
|
if conn.Permissions != nil {
|
|
if conn.Permissions.CriticalOptions != nil {
|
|
fmt.Fprint(channel, " Options:\r\n")
|
|
for k := range stringutil.MapKeys(conn.Permissions.CriticalOptions) {
|
|
fmt.Fprintf(channel, " ✅ %s=%s\r\n", k, conn.Permissions.CriticalOptions[k])
|
|
}
|
|
}
|
|
if conn.Permissions.Extensions != nil {
|
|
fmt.Fprint(channel, " Extensions:\r\n")
|
|
for k := range stringutil.MapKeys(conn.Permissions.Extensions) {
|
|
fmt.Fprintf(channel, " ✅ %s=%s\r\n", k, conn.Permissions.Extensions[k])
|
|
}
|
|
}
|
|
}
|
|
|
|
fmt.Fprintf(channel, " Environment: (%d variables sent)\r\n", len(info.env))
|
|
for k := range stringutil.MapKeys(info.env) {
|
|
fmt.Fprintf(channel, " ✅ %s=%s\r\n", k, info.env[k])
|
|
}
|
|
|
|
if info.pty == nil {
|
|
fmt.Fprint(channel, " PTY: not requested\r\n")
|
|
} else {
|
|
fmt.Fprintf(channel, " PTY: %d cols, %d rows, %dx%d, %q\r\n",
|
|
info.pty.Columns, info.pty.Rows, info.pty.Width, info.pty.Height, info.pty.Term)
|
|
}
|
|
|
|
if info.windowChange == nil {
|
|
fmt.Fprint(channel, " Window: not requested\r\n")
|
|
} else {
|
|
fmt.Fprintf(channel, " Window: %d cols, %d rows, %dx%d",
|
|
info.windowChange.Columns, info.windowChange.Rows,
|
|
info.windowChange.Width, info.windowChange.Height)
|
|
}
|
|
|
|
if len(info.unsupported) > 0 {
|
|
fmt.Fprintf(channel, " Requests: (%d unsupported):\r\n", len(info.unsupported))
|
|
sort.Strings(info.unsupported)
|
|
for _, v := range info.unsupported {
|
|
fmt.Fprintf(channel, " ❌ %s\r\n", v)
|
|
}
|
|
} else {
|
|
fmt.Fprint(channel, " Requests: no unsupported requests\r\n")
|
|
}
|
|
|
|
if info.agentRequest {
|
|
if info.agentError != nil {
|
|
fmt.Fprint(channel, " Agent: requested but unavailable:\r\n")
|
|
fmt.Fprintf(channel, " ❌ %v\r\n", info.agentError)
|
|
}
|
|
} else {
|
|
if info.agentError == nil {
|
|
fmt.Fprint(channel, " Agent: not requested by client, but accepted, upgrade your client!\r\n")
|
|
} else {
|
|
fmt.Fprint(channel, " Agent: not requested by client, and client refused our attempt (good!):\r\n")
|
|
fmt.Fprintf(channel, " ❌ %v\r\n", info.agentError)
|
|
}
|
|
}
|
|
|
|
if info.agent != nil {
|
|
fmt.Fprint(channel, " Agent: available, checking keys:\r\n")
|
|
keys, err := info.agent.List()
|
|
if err != nil {
|
|
fmt.Fprintf(channel, " ❌ %v\r\n", err)
|
|
} else {
|
|
for _, key := range keys {
|
|
blob := make([]byte, 8)
|
|
rand.Reader.Read(blob)
|
|
if _, err := info.agent.Sign(key, blob); err != nil {
|
|
fmt.Fprint(channel, " ❌ ")
|
|
} else {
|
|
fmt.Fprint(channel, " ✅ ")
|
|
}
|
|
fmt.Fprintf(channel, "%-4d %s %s (%s)\r\n",
|
|
sshutil.KeyBits(key), sshutil.KeyFingerprint(key),
|
|
key.Comment, sshutil.KeyType(key))
|
|
}
|
|
}
|
|
|
|
fmt.Fprint(channel, " Agent: available, checking add/remove:\r\n")
|
|
if err = info.agent.Add(agent.AddedKey{
|
|
PrivateKey: key,
|
|
LifetimeSecs: 5,
|
|
Comment: "Test by conduit",
|
|
}); err != nil {
|
|
fmt.Fprintf(channel, " ❌ Add: %v\r\n", err)
|
|
} else {
|
|
fmt.Fprint(channel, " ✅ Add\r\n")
|
|
}
|
|
pk, _ := ssh.NewPublicKey(key.(*rsa.PrivateKey).Public())
|
|
if err = info.agent.Remove(pk); err != nil {
|
|
fmt.Fprintf(channel, " ❌ Remove: %v\r\n", err)
|
|
} else {
|
|
fmt.Fprint(channel, " ✅ Remove\r\n")
|
|
}
|
|
}
|
|
}
|
|
|
|
func StaticSession(message string) ChannelHandler {
|
|
return ChannelHandlerFunc(func(ctx Context, channel ssh.Channel, requests <-chan *ssh.Request, _ []byte) error {
|
|
go ssh.DiscardRequests(requests)
|
|
io.WriteString(channel, message+"\r\n")
|
|
return channel.Close()
|
|
})
|
|
}
|