Initial import
This commit is contained in:
352
ssh/handler.go
Normal file
352
ssh/handler.go
Normal file
@@ -0,0 +1,352 @@
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sort"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
"golang.org/x/crypto/ssh/agent"
|
||||
|
||||
"git.maze.io/maze/conduit/internal/netutil"
|
||||
"git.maze.io/maze/conduit/internal/stringutil"
|
||||
"git.maze.io/maze/conduit/logger"
|
||||
"git.maze.io/maze/conduit/ssh/sshutil"
|
||||
)
|
||||
|
||||
type ConnectHandler interface {
|
||||
HandleConnect(net.Conn) (net.Conn, error)
|
||||
}
|
||||
|
||||
type ConnectHandlerFunc func(net.Conn) (net.Conn, error)
|
||||
|
||||
func (f ConnectHandlerFunc) HandleConnect(c net.Conn) (net.Conn, error) {
|
||||
return f(c)
|
||||
}
|
||||
|
||||
type AcceptHandler interface {
|
||||
HandleAccept(net.Conn, *ssh.ServerConfig) (*ssh.ServerConn, <-chan ssh.NewChannel, <-chan *ssh.Request, error)
|
||||
}
|
||||
|
||||
// MaxConnections limits the maximum number of connections.
|
||||
func MaxConnections(max int) ConnectHandler {
|
||||
var connections atomic.Int64
|
||||
return ConnectHandlerFunc(func(c net.Conn) (net.Conn, error) {
|
||||
if max <= 0 {
|
||||
return nil, nil
|
||||
} else if connections.Load() >= int64(max) {
|
||||
return nil, errors.New("server: maximum number of connections reached")
|
||||
}
|
||||
|
||||
var (
|
||||
once sync.Once
|
||||
cc = &netutil.ConnCloser{
|
||||
Conn: c,
|
||||
Closer: func() error {
|
||||
once.Do(func() { connections.Add(-1) })
|
||||
return c.Close()
|
||||
},
|
||||
}
|
||||
)
|
||||
connections.Add(1)
|
||||
return cc, nil
|
||||
})
|
||||
}
|
||||
|
||||
type ChannelHandler interface {
|
||||
HandleChannel(Context, ssh.Channel, <-chan *ssh.Request, []byte) error
|
||||
}
|
||||
|
||||
type ChannelHandlerFunc func(Context, ssh.Channel, <-chan *ssh.Request, []byte) error
|
||||
|
||||
func (f ChannelHandlerFunc) HandleChannel(ctx Context, channel ssh.Channel, requests <-chan *ssh.Request, extra []byte) error {
|
||||
return f(ctx, channel, requests, extra)
|
||||
}
|
||||
|
||||
type debugSessionInfo struct {
|
||||
start time.Time
|
||||
duration time.Duration
|
||||
method string
|
||||
agent agent.Agent
|
||||
agentRequest bool
|
||||
agentChannel ssh.Channel
|
||||
agentError error
|
||||
env map[string]string
|
||||
pty *sshutil.PTYRequest
|
||||
windowChange *sshutil.WindowChangeRequest
|
||||
unsupported []string
|
||||
}
|
||||
|
||||
func newDebugSessionInfo() *debugSessionInfo {
|
||||
return &debugSessionInfo{
|
||||
start: time.Now(),
|
||||
env: make(map[string]string),
|
||||
}
|
||||
}
|
||||
|
||||
// DebugSession is a session channel handler that will print debug information to the client, which
|
||||
// may aid with troubleshooting SSH connectivity or client issues.
|
||||
func DebugSession() ChannelHandler {
|
||||
debugPrivateKey, _ := rsa.GenerateKey(rand.Reader, 1024)
|
||||
|
||||
return ChannelHandlerFunc(func(ctx Context, channel ssh.Channel, requests <-chan *ssh.Request, _ []byte) error {
|
||||
defer channel.Close()
|
||||
|
||||
var (
|
||||
log = ctx.(*sshContext).log
|
||||
done = make(chan struct{}, 1)
|
||||
info = newDebugSessionInfo()
|
||||
conn = ctx.Conn()
|
||||
err error
|
||||
)
|
||||
|
||||
go func() {
|
||||
var (
|
||||
reply bool
|
||||
reesponse []byte
|
||||
)
|
||||
for request := range requests {
|
||||
log.Values(logger.Values{
|
||||
"request": request.Type,
|
||||
}).Trace("New session channel request")
|
||||
switch request.Type {
|
||||
case RequestTypeAgent:
|
||||
info.agentRequest = true
|
||||
agentChannel, agentRequests, err := conn.OpenChannel(ChannelTypeAgent, nil)
|
||||
if err != nil {
|
||||
info.agentError = err
|
||||
} else {
|
||||
go ssh.DiscardRequests(agentRequests)
|
||||
info.agent = agent.NewClient(agentChannel)
|
||||
info.agentChannel = agentChannel
|
||||
}
|
||||
reply = true
|
||||
|
||||
case RequestTypeEnv:
|
||||
var payload *sshutil.EnvRequest
|
||||
if payload, err = sshutil.ParseEnvRequest(request.Payload); err != nil {
|
||||
log.Err(err).Debug("Corrupted env request payload, discarding")
|
||||
} else {
|
||||
log.Values(logger.Values{
|
||||
"key": payload.Key,
|
||||
"value": payload.Value,
|
||||
}).Trace("Client requested env variable")
|
||||
info.env[payload.Key] = payload.Value
|
||||
}
|
||||
reply = true
|
||||
|
||||
case RequestTypePTY:
|
||||
if info.pty, err = sshutil.ParsePTYRequest(request.Payload); err != nil {
|
||||
log.Err(err).Debug("Corrupted pty request payload, discarding")
|
||||
} else {
|
||||
log.Values(logger.Values{
|
||||
"term": info.pty.Term,
|
||||
"size": fmt.Sprintf("%dx%d", info.pty.Columns, info.pty.Rows),
|
||||
}).Trace("Client requested PTY")
|
||||
}
|
||||
reply = true
|
||||
|
||||
case RequestTypeExec, RequestTypeShell:
|
||||
info.method = request.Type
|
||||
select {
|
||||
case <-done:
|
||||
default:
|
||||
close(done)
|
||||
}
|
||||
//return
|
||||
|
||||
case RequestTypeWindowChange:
|
||||
if info.windowChange, err = sshutil.ParseWindowChangeRequest(request.Payload); err != nil {
|
||||
log.Err(err).Debug("Corrupted window change request payload, discarding")
|
||||
} else {
|
||||
log.Values(logger.Values{
|
||||
"size": fmt.Sprintf("%dx%d", info.windowChange.Columns, info.windowChange.Rows),
|
||||
}).Trace("Client requested window change")
|
||||
}
|
||||
reply = true
|
||||
|
||||
default:
|
||||
log.Values(logger.Values{
|
||||
"request": request.Type,
|
||||
"payload": hex.EncodeToString(request.Payload),
|
||||
}).Trace("Client requested something we don't understand, ignored")
|
||||
info.unsupported = append(info.unsupported, request.Type)
|
||||
}
|
||||
|
||||
if request.WantReply {
|
||||
if err := request.Reply(reply, reesponse); err != nil {
|
||||
log.Err(err).Debug("Error sending session channel request reply: terminating")
|
||||
_ = channel.Close()
|
||||
select {
|
||||
case <-done:
|
||||
default:
|
||||
close(done)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
select {
|
||||
case <-done:
|
||||
default:
|
||||
close(done)
|
||||
}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(10 * time.Second):
|
||||
io.WriteString(channel, "Timeout waiting for your client to send either an exec or shell request\r\n")
|
||||
}
|
||||
|
||||
// Any requests that follow are ignored from hereon forward.
|
||||
go ssh.DiscardRequests(requests)
|
||||
|
||||
// Attempt to request agent forwarding
|
||||
if info.agent == nil && !info.agentRequest {
|
||||
agentChannel, agentRequests, err := conn.OpenChannel(ChannelTypeAgent, nil)
|
||||
if err != nil {
|
||||
info.agentError = err
|
||||
} else {
|
||||
go ssh.DiscardRequests(agentRequests)
|
||||
info.agent = agent.NewClient(agentChannel)
|
||||
info.agentChannel = agentChannel
|
||||
}
|
||||
}
|
||||
|
||||
info.duration = time.Since(info.start)
|
||||
printSessionInfo(ctx, channel, info, debugPrivateKey)
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func printSessionInfo(ctx Context, channel ssh.Channel, info *debugSessionInfo, key crypto.PrivateKey) {
|
||||
if info.agentChannel != nil {
|
||||
defer info.agentChannel.Close()
|
||||
}
|
||||
|
||||
conn := ctx.Conn().(*ssh.ServerConn)
|
||||
|
||||
fmt.Fprintf(channel, "It took your client %s to request %s\r\n\r\n", info.duration, info.method)
|
||||
fmt.Fprintf(channel, "SSH connection information:\r\n")
|
||||
fmt.Fprintf(channel, " Server:\r\n")
|
||||
fmt.Fprintf(channel, " Version: \x1b[1m%s\x1b[0m\r\n", conn.ServerVersion())
|
||||
fmt.Fprintf(channel, " Client:\r\n")
|
||||
fmt.Fprintf(channel, " Version: \x1b[1m%s\x1b[0m\r\n", conn.ClientVersion())
|
||||
fmt.Fprintf(channel, " Username: \x1b[1m%s\x1b[0m\r\n", conn.User())
|
||||
|
||||
if conn.Permissions != nil {
|
||||
if conn.Permissions.CriticalOptions != nil {
|
||||
fmt.Fprint(channel, " Options:\r\n")
|
||||
for k := range stringutil.MapKeys(conn.Permissions.CriticalOptions) {
|
||||
fmt.Fprintf(channel, " ✅ %s=%s\r\n", k, conn.Permissions.CriticalOptions[k])
|
||||
}
|
||||
}
|
||||
if conn.Permissions.Extensions != nil {
|
||||
fmt.Fprint(channel, " Extensions:\r\n")
|
||||
for k := range stringutil.MapKeys(conn.Permissions.Extensions) {
|
||||
fmt.Fprintf(channel, " ✅ %s=%s\r\n", k, conn.Permissions.Extensions[k])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Fprintf(channel, " Environment: (%d variables sent)\r\n", len(info.env))
|
||||
for k := range stringutil.MapKeys(info.env) {
|
||||
fmt.Fprintf(channel, " ✅ %s=%s\r\n", k, info.env[k])
|
||||
}
|
||||
|
||||
if info.pty == nil {
|
||||
fmt.Fprint(channel, " PTY: not requested\r\n")
|
||||
} else {
|
||||
fmt.Fprintf(channel, " PTY: %d cols, %d rows, %dx%d, %q\r\n",
|
||||
info.pty.Columns, info.pty.Rows, info.pty.Width, info.pty.Height, info.pty.Term)
|
||||
}
|
||||
|
||||
if info.windowChange == nil {
|
||||
fmt.Fprint(channel, " Window: not requested\r\n")
|
||||
} else {
|
||||
fmt.Fprintf(channel, " Window: %d cols, %d rows, %dx%d",
|
||||
info.windowChange.Columns, info.windowChange.Rows,
|
||||
info.windowChange.Width, info.windowChange.Height)
|
||||
}
|
||||
|
||||
if len(info.unsupported) > 0 {
|
||||
fmt.Fprintf(channel, " Requests: (%d unsupported):\r\n", len(info.unsupported))
|
||||
sort.Strings(info.unsupported)
|
||||
for _, v := range info.unsupported {
|
||||
fmt.Fprintf(channel, " ❌ %s\r\n", v)
|
||||
}
|
||||
} else {
|
||||
fmt.Fprint(channel, " Requests: no unsupported requests\r\n")
|
||||
}
|
||||
|
||||
if info.agentRequest {
|
||||
if info.agentError != nil {
|
||||
fmt.Fprint(channel, " Agent: requested but unavailable:\r\n")
|
||||
fmt.Fprintf(channel, " ❌ %v\r\n", info.agentError)
|
||||
}
|
||||
} else {
|
||||
if info.agentError == nil {
|
||||
fmt.Fprint(channel, " Agent: not requested by client, but accepted, upgrade your client!\r\n")
|
||||
} else {
|
||||
fmt.Fprint(channel, " Agent: not requested by client, and client refused our attempt (good!):\r\n")
|
||||
fmt.Fprintf(channel, " ❌ %v\r\n", info.agentError)
|
||||
}
|
||||
}
|
||||
|
||||
if info.agent != nil {
|
||||
fmt.Fprint(channel, " Agent: available, checking keys:\r\n")
|
||||
keys, err := info.agent.List()
|
||||
if err != nil {
|
||||
fmt.Fprintf(channel, " ❌ %v\r\n", err)
|
||||
} else {
|
||||
for _, key := range keys {
|
||||
blob := make([]byte, 8)
|
||||
rand.Reader.Read(blob)
|
||||
if _, err := info.agent.Sign(key, blob); err != nil {
|
||||
fmt.Fprint(channel, " ❌ ")
|
||||
} else {
|
||||
fmt.Fprint(channel, " ✅ ")
|
||||
}
|
||||
fmt.Fprintf(channel, "%-4d %s %s (%s)\r\n",
|
||||
sshutil.KeyBits(key), sshutil.KeyFingerprint(key),
|
||||
key.Comment, sshutil.KeyType(key))
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Fprint(channel, " Agent: available, checking add/remove:\r\n")
|
||||
if err = info.agent.Add(agent.AddedKey{
|
||||
PrivateKey: key,
|
||||
LifetimeSecs: 5,
|
||||
Comment: "Test by conduit",
|
||||
}); err != nil {
|
||||
fmt.Fprintf(channel, " ❌ Add: %v\r\n", err)
|
||||
} else {
|
||||
fmt.Fprint(channel, " ✅ Add\r\n")
|
||||
}
|
||||
pk, _ := ssh.NewPublicKey(key.(*rsa.PrivateKey).Public())
|
||||
if err = info.agent.Remove(pk); err != nil {
|
||||
fmt.Fprintf(channel, " ❌ Remove: %v\r\n", err)
|
||||
} else {
|
||||
fmt.Fprint(channel, " ✅ Remove\r\n")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func StaticSession(message string) ChannelHandler {
|
||||
return ChannelHandlerFunc(func(ctx Context, channel ssh.Channel, requests <-chan *ssh.Request, _ []byte) error {
|
||||
go ssh.DiscardRequests(requests)
|
||||
io.WriteString(channel, message+"\r\n")
|
||||
return channel.Close()
|
||||
})
|
||||
}
|
Reference in New Issue
Block a user