Initial import
This commit is contained in:
52
ssh/channel.go
Normal file
52
ssh/channel.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
type DupeChannel struct {
|
||||
ssh.Channel
|
||||
|
||||
// Reads writes read actions.
|
||||
Reads io.WriteCloser
|
||||
|
||||
// Writer writes write actions.
|
||||
Writes io.WriteCloser
|
||||
}
|
||||
|
||||
func (c DupeChannel) Close() error {
|
||||
var errs []error
|
||||
for _, closer := range []io.Closer{
|
||||
c.Channel,
|
||||
c.Reads,
|
||||
c.Writes,
|
||||
} {
|
||||
if closer == nil {
|
||||
continue
|
||||
}
|
||||
if cerr := closer.Close(); cerr != nil {
|
||||
errs = append(errs, cerr)
|
||||
}
|
||||
}
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
|
||||
func (c DupeChannel) Read(p []byte) (n int, err error) {
|
||||
if c.Reads == nil {
|
||||
return c.Channel.Read(p)
|
||||
}
|
||||
return io.TeeReader(c.Channel, c.Reads).Read(p)
|
||||
}
|
||||
|
||||
func (c DupeChannel) Write(p []byte) (n int, err error) {
|
||||
if c.Writes == nil {
|
||||
return c.Channel.Write(p)
|
||||
}
|
||||
if n, err = c.Channel.Write(p); n > 0 {
|
||||
_, _ = c.Writes.Write(p[:])
|
||||
}
|
||||
return
|
||||
}
|
28
ssh/client.go
Normal file
28
ssh/client.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
type Client struct {
|
||||
netConn net.Conn
|
||||
client *ssh.Client
|
||||
}
|
||||
|
||||
type ClientConfig struct {
|
||||
ssh.ClientConfig
|
||||
}
|
||||
|
||||
func NewClient(conn net.Conn, config *ClientConfig) (*Client, error) {
|
||||
sshConn, channels, requests, err := ssh.NewClientConn(conn, conn.RemoteAddr().String(), &config.ClientConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Client{
|
||||
netConn: conn,
|
||||
client: ssh.NewClient(sshConn, channels, requests),
|
||||
}, nil
|
||||
}
|
46
ssh/compat.go
Normal file
46
ssh/compat.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
const (
|
||||
ChannelTypeDefault = ""
|
||||
ChannelTypeAgent = "auth-agent@openssh.com"
|
||||
ChannelTypeDirectTCPIP = "direct-tcpip"
|
||||
ChannelTypeSession = "session"
|
||||
)
|
||||
|
||||
const (
|
||||
RequestTypeAgent = "auth-agent-req@openssh.com"
|
||||
RequestTypeEnv = "env"
|
||||
RequestTypeExec = "exec"
|
||||
RequestTypePTY = "pty-req"
|
||||
RequestTypeShell = "shell"
|
||||
RequestTypeWindowChange = "window-change"
|
||||
)
|
||||
|
||||
// Type aliases for convenience.
|
||||
type (
|
||||
CertChecker = ssh.CertChecker
|
||||
Certificate = ssh.Certificate
|
||||
Conn = ssh.Conn
|
||||
ConnMetadata = ssh.ConnMetadata
|
||||
Permissions = ssh.Permissions
|
||||
PublicKey = ssh.PublicKey
|
||||
ServerConfig = ssh.ServerConfig
|
||||
Signer = ssh.Signer
|
||||
)
|
||||
|
||||
func MarshalAuthorizedKey(in PublicKey) []byte {
|
||||
return ssh.MarshalAuthorizedKey(in)
|
||||
}
|
||||
|
||||
func ParseAuthorizedKey(in []byte) (out PublicKey, options []string, err error) {
|
||||
out, _, options, _, err = ssh.ParseAuthorizedKey(in)
|
||||
return
|
||||
}
|
||||
|
||||
func ParsePublicKey(in []byte) (out PublicKey, err error) {
|
||||
return ssh.ParsePublicKey(in)
|
||||
}
|
204
ssh/context.go
Normal file
204
ssh/context.go
Normal file
@@ -0,0 +1,204 @@
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
|
||||
"git.maze.io/maze/conduit/logger"
|
||||
)
|
||||
|
||||
var seed = rand.NewSource(time.Now().UnixNano())
|
||||
|
||||
type Context interface {
|
||||
ssh.ConnMetadata
|
||||
|
||||
// ID is the unique identifier.
|
||||
ID() string
|
||||
|
||||
// Conn is the [ssh.Conn].
|
||||
Conn() ssh.Conn
|
||||
|
||||
// NetConn is the underlying [net.Conn].
|
||||
NetConn() net.Conn
|
||||
|
||||
// Close the client connection.
|
||||
Close() error
|
||||
}
|
||||
|
||||
type sshContext struct {
|
||||
id uint64
|
||||
server *Server
|
||||
netConn net.Conn
|
||||
conn *ssh.ServerConn
|
||||
log logger.Structured
|
||||
}
|
||||
|
||||
func newSSHContext(server *Server, netConn net.Conn, conn *ssh.ServerConn, log logger.Structured) *sshContext {
|
||||
ctx := &sshContext{
|
||||
id: uint64(seed.Int63()),
|
||||
server: server,
|
||||
netConn: netConn,
|
||||
conn: conn,
|
||||
}
|
||||
ctx.log = log.Value("context", ctx.ID())
|
||||
return ctx
|
||||
}
|
||||
|
||||
// User returns the user ID for this connection.
|
||||
func (ctx *sshContext) User() string {
|
||||
return ctx.conn.User()
|
||||
}
|
||||
|
||||
// SessionID returns the session hash, also denoted by H.
|
||||
func (ctx *sshContext) SessionID() []byte {
|
||||
return ctx.conn.SessionID()
|
||||
}
|
||||
|
||||
// ClientVersion returns the client's version string as hashed
|
||||
// into the session ID.
|
||||
func (ctx *sshContext) ClientVersion() []byte {
|
||||
return ctx.conn.ClientVersion()
|
||||
}
|
||||
|
||||
// ServerVersion returns the server's version string as hashed
|
||||
// into the session ID.
|
||||
func (ctx *sshContext) ServerVersion() []byte {
|
||||
return ctx.conn.ServerVersion()
|
||||
}
|
||||
|
||||
// RemoteAddr returns the remote address for this connection.
|
||||
func (ctx *sshContext) RemoteAddr() net.Addr {
|
||||
return ctx.netConn.RemoteAddr()
|
||||
}
|
||||
|
||||
// LocalAddr returns the local address for this connection.
|
||||
func (ctx *sshContext) LocalAddr() net.Addr {
|
||||
return ctx.netConn.LocalAddr()
|
||||
}
|
||||
|
||||
func (ctx *sshContext) handleChannels(channels <-chan ssh.NewChannel) (err error) {
|
||||
for newChan := range channels {
|
||||
var (
|
||||
kind = newChan.ChannelType()
|
||||
log = ctx.log.Value("channel", kind)
|
||||
)
|
||||
log.Trace("Client requested new channel")
|
||||
|
||||
handler, ok := ctx.server.ChannelHandler[kind]
|
||||
if !ok {
|
||||
handler = ctx.server.ChannelHandler[ChannelTypeDefault]
|
||||
}
|
||||
|
||||
if handler != nil {
|
||||
var (
|
||||
channel ssh.Channel
|
||||
requests <-chan *ssh.Request
|
||||
)
|
||||
if channel, requests, err = newChan.Accept(); err != nil {
|
||||
return
|
||||
}
|
||||
if err = handler.HandleChannel(ctx, channel, requests, newChan.ExtraData()); err != nil {
|
||||
return
|
||||
}
|
||||
} else if kind == ChannelTypeDirectTCPIP && ctx.server.PortForwardHandler != nil {
|
||||
if err = ctx.handleDirectTCPIP(newChan); err != nil {
|
||||
return
|
||||
}
|
||||
} else {
|
||||
ctx.log.Debug("Rejecting unsupported channel type")
|
||||
if err = newChan.Reject(ssh.Prohibited, ""); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Our client hang up.
|
||||
return io.EOF
|
||||
}
|
||||
|
||||
func (ctx *sshContext) handleDirectTCPIP(newChan ssh.NewChannel) (err error) {
|
||||
var payload struct {
|
||||
Host string
|
||||
Port uint32
|
||||
OriginAddr string
|
||||
OriginPort uint32
|
||||
}
|
||||
if err = ssh.Unmarshal(newChan.ExtraData(), &payload); err != nil {
|
||||
_ = newChan.Reject(ssh.Prohibited, "")
|
||||
return
|
||||
}
|
||||
|
||||
var ip net.IP
|
||||
if ip = net.ParseIP(payload.Host); ip == nil {
|
||||
// Not an IP
|
||||
var ips []net.IP
|
||||
if ips, err = net.LookupIP(payload.Host); err != nil {
|
||||
_ = newChan.Reject(ssh.ConnectionFailed, err.Error())
|
||||
return
|
||||
} else if len(ips) == 0 {
|
||||
_ = newChan.Reject(ssh.ConnectionFailed, "")
|
||||
return
|
||||
}
|
||||
ip = ips[0]
|
||||
}
|
||||
|
||||
var (
|
||||
raddr = &net.TCPAddr{
|
||||
IP: ip,
|
||||
Port: int(payload.Port),
|
||||
}
|
||||
laddr = &net.TCPAddr{
|
||||
IP: net.ParseIP(payload.OriginAddr),
|
||||
Port: int(payload.OriginPort),
|
||||
}
|
||||
)
|
||||
if payload.OriginAddr == "" && payload.OriginPort == 0 {
|
||||
laddr = nil
|
||||
}
|
||||
|
||||
var conn net.Conn
|
||||
if conn, err = ctx.server.PortForwardHandler.HandlePortForwardRequest(ctx, raddr, laddr); err != nil {
|
||||
_ = newChan.Reject(ssh.ConnectionFailed, err.Error())
|
||||
return
|
||||
}
|
||||
defer func() { _ = conn.Close() }()
|
||||
|
||||
var (
|
||||
channel ssh.Channel
|
||||
requests <-chan *ssh.Request
|
||||
)
|
||||
if channel, requests, err = newChan.Accept(); err != nil {
|
||||
return
|
||||
}
|
||||
defer func() { _ = channel.Close() }()
|
||||
go ssh.DiscardRequests(requests)
|
||||
|
||||
go io.Copy(channel, conn)
|
||||
_, err = io.Copy(conn, channel)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (ctx *sshContext) Conn() ssh.Conn {
|
||||
return ctx.conn
|
||||
}
|
||||
|
||||
func (ctx *sshContext) NetConn() net.Conn {
|
||||
return ctx.netConn
|
||||
}
|
||||
|
||||
func (ctx *sshContext) Close() error {
|
||||
return ctx.conn.Close()
|
||||
}
|
||||
|
||||
func (ctx *sshContext) ID() string {
|
||||
var b [8]byte
|
||||
binary.BigEndian.PutUint64(b[:], ctx.id)
|
||||
return hex.EncodeToString(b[:])
|
||||
}
|
352
ssh/handler.go
Normal file
352
ssh/handler.go
Normal file
@@ -0,0 +1,352 @@
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sort"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
"golang.org/x/crypto/ssh/agent"
|
||||
|
||||
"git.maze.io/maze/conduit/internal/netutil"
|
||||
"git.maze.io/maze/conduit/internal/stringutil"
|
||||
"git.maze.io/maze/conduit/logger"
|
||||
"git.maze.io/maze/conduit/ssh/sshutil"
|
||||
)
|
||||
|
||||
type ConnectHandler interface {
|
||||
HandleConnect(net.Conn) (net.Conn, error)
|
||||
}
|
||||
|
||||
type ConnectHandlerFunc func(net.Conn) (net.Conn, error)
|
||||
|
||||
func (f ConnectHandlerFunc) HandleConnect(c net.Conn) (net.Conn, error) {
|
||||
return f(c)
|
||||
}
|
||||
|
||||
type AcceptHandler interface {
|
||||
HandleAccept(net.Conn, *ssh.ServerConfig) (*ssh.ServerConn, <-chan ssh.NewChannel, <-chan *ssh.Request, error)
|
||||
}
|
||||
|
||||
// MaxConnections limits the maximum number of connections.
|
||||
func MaxConnections(max int) ConnectHandler {
|
||||
var connections atomic.Int64
|
||||
return ConnectHandlerFunc(func(c net.Conn) (net.Conn, error) {
|
||||
if max <= 0 {
|
||||
return nil, nil
|
||||
} else if connections.Load() >= int64(max) {
|
||||
return nil, errors.New("server: maximum number of connections reached")
|
||||
}
|
||||
|
||||
var (
|
||||
once sync.Once
|
||||
cc = &netutil.ConnCloser{
|
||||
Conn: c,
|
||||
Closer: func() error {
|
||||
once.Do(func() { connections.Add(-1) })
|
||||
return c.Close()
|
||||
},
|
||||
}
|
||||
)
|
||||
connections.Add(1)
|
||||
return cc, nil
|
||||
})
|
||||
}
|
||||
|
||||
type ChannelHandler interface {
|
||||
HandleChannel(Context, ssh.Channel, <-chan *ssh.Request, []byte) error
|
||||
}
|
||||
|
||||
type ChannelHandlerFunc func(Context, ssh.Channel, <-chan *ssh.Request, []byte) error
|
||||
|
||||
func (f ChannelHandlerFunc) HandleChannel(ctx Context, channel ssh.Channel, requests <-chan *ssh.Request, extra []byte) error {
|
||||
return f(ctx, channel, requests, extra)
|
||||
}
|
||||
|
||||
type debugSessionInfo struct {
|
||||
start time.Time
|
||||
duration time.Duration
|
||||
method string
|
||||
agent agent.Agent
|
||||
agentRequest bool
|
||||
agentChannel ssh.Channel
|
||||
agentError error
|
||||
env map[string]string
|
||||
pty *sshutil.PTYRequest
|
||||
windowChange *sshutil.WindowChangeRequest
|
||||
unsupported []string
|
||||
}
|
||||
|
||||
func newDebugSessionInfo() *debugSessionInfo {
|
||||
return &debugSessionInfo{
|
||||
start: time.Now(),
|
||||
env: make(map[string]string),
|
||||
}
|
||||
}
|
||||
|
||||
// DebugSession is a session channel handler that will print debug information to the client, which
|
||||
// may aid with troubleshooting SSH connectivity or client issues.
|
||||
func DebugSession() ChannelHandler {
|
||||
debugPrivateKey, _ := rsa.GenerateKey(rand.Reader, 1024)
|
||||
|
||||
return ChannelHandlerFunc(func(ctx Context, channel ssh.Channel, requests <-chan *ssh.Request, _ []byte) error {
|
||||
defer channel.Close()
|
||||
|
||||
var (
|
||||
log = ctx.(*sshContext).log
|
||||
done = make(chan struct{}, 1)
|
||||
info = newDebugSessionInfo()
|
||||
conn = ctx.Conn()
|
||||
err error
|
||||
)
|
||||
|
||||
go func() {
|
||||
var (
|
||||
reply bool
|
||||
reesponse []byte
|
||||
)
|
||||
for request := range requests {
|
||||
log.Values(logger.Values{
|
||||
"request": request.Type,
|
||||
}).Trace("New session channel request")
|
||||
switch request.Type {
|
||||
case RequestTypeAgent:
|
||||
info.agentRequest = true
|
||||
agentChannel, agentRequests, err := conn.OpenChannel(ChannelTypeAgent, nil)
|
||||
if err != nil {
|
||||
info.agentError = err
|
||||
} else {
|
||||
go ssh.DiscardRequests(agentRequests)
|
||||
info.agent = agent.NewClient(agentChannel)
|
||||
info.agentChannel = agentChannel
|
||||
}
|
||||
reply = true
|
||||
|
||||
case RequestTypeEnv:
|
||||
var payload *sshutil.EnvRequest
|
||||
if payload, err = sshutil.ParseEnvRequest(request.Payload); err != nil {
|
||||
log.Err(err).Debug("Corrupted env request payload, discarding")
|
||||
} else {
|
||||
log.Values(logger.Values{
|
||||
"key": payload.Key,
|
||||
"value": payload.Value,
|
||||
}).Trace("Client requested env variable")
|
||||
info.env[payload.Key] = payload.Value
|
||||
}
|
||||
reply = true
|
||||
|
||||
case RequestTypePTY:
|
||||
if info.pty, err = sshutil.ParsePTYRequest(request.Payload); err != nil {
|
||||
log.Err(err).Debug("Corrupted pty request payload, discarding")
|
||||
} else {
|
||||
log.Values(logger.Values{
|
||||
"term": info.pty.Term,
|
||||
"size": fmt.Sprintf("%dx%d", info.pty.Columns, info.pty.Rows),
|
||||
}).Trace("Client requested PTY")
|
||||
}
|
||||
reply = true
|
||||
|
||||
case RequestTypeExec, RequestTypeShell:
|
||||
info.method = request.Type
|
||||
select {
|
||||
case <-done:
|
||||
default:
|
||||
close(done)
|
||||
}
|
||||
//return
|
||||
|
||||
case RequestTypeWindowChange:
|
||||
if info.windowChange, err = sshutil.ParseWindowChangeRequest(request.Payload); err != nil {
|
||||
log.Err(err).Debug("Corrupted window change request payload, discarding")
|
||||
} else {
|
||||
log.Values(logger.Values{
|
||||
"size": fmt.Sprintf("%dx%d", info.windowChange.Columns, info.windowChange.Rows),
|
||||
}).Trace("Client requested window change")
|
||||
}
|
||||
reply = true
|
||||
|
||||
default:
|
||||
log.Values(logger.Values{
|
||||
"request": request.Type,
|
||||
"payload": hex.EncodeToString(request.Payload),
|
||||
}).Trace("Client requested something we don't understand, ignored")
|
||||
info.unsupported = append(info.unsupported, request.Type)
|
||||
}
|
||||
|
||||
if request.WantReply {
|
||||
if err := request.Reply(reply, reesponse); err != nil {
|
||||
log.Err(err).Debug("Error sending session channel request reply: terminating")
|
||||
_ = channel.Close()
|
||||
select {
|
||||
case <-done:
|
||||
default:
|
||||
close(done)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
select {
|
||||
case <-done:
|
||||
default:
|
||||
close(done)
|
||||
}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(10 * time.Second):
|
||||
io.WriteString(channel, "Timeout waiting for your client to send either an exec or shell request\r\n")
|
||||
}
|
||||
|
||||
// Any requests that follow are ignored from hereon forward.
|
||||
go ssh.DiscardRequests(requests)
|
||||
|
||||
// Attempt to request agent forwarding
|
||||
if info.agent == nil && !info.agentRequest {
|
||||
agentChannel, agentRequests, err := conn.OpenChannel(ChannelTypeAgent, nil)
|
||||
if err != nil {
|
||||
info.agentError = err
|
||||
} else {
|
||||
go ssh.DiscardRequests(agentRequests)
|
||||
info.agent = agent.NewClient(agentChannel)
|
||||
info.agentChannel = agentChannel
|
||||
}
|
||||
}
|
||||
|
||||
info.duration = time.Since(info.start)
|
||||
printSessionInfo(ctx, channel, info, debugPrivateKey)
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func printSessionInfo(ctx Context, channel ssh.Channel, info *debugSessionInfo, key crypto.PrivateKey) {
|
||||
if info.agentChannel != nil {
|
||||
defer info.agentChannel.Close()
|
||||
}
|
||||
|
||||
conn := ctx.Conn().(*ssh.ServerConn)
|
||||
|
||||
fmt.Fprintf(channel, "It took your client %s to request %s\r\n\r\n", info.duration, info.method)
|
||||
fmt.Fprintf(channel, "SSH connection information:\r\n")
|
||||
fmt.Fprintf(channel, " Server:\r\n")
|
||||
fmt.Fprintf(channel, " Version: \x1b[1m%s\x1b[0m\r\n", conn.ServerVersion())
|
||||
fmt.Fprintf(channel, " Client:\r\n")
|
||||
fmt.Fprintf(channel, " Version: \x1b[1m%s\x1b[0m\r\n", conn.ClientVersion())
|
||||
fmt.Fprintf(channel, " Username: \x1b[1m%s\x1b[0m\r\n", conn.User())
|
||||
|
||||
if conn.Permissions != nil {
|
||||
if conn.Permissions.CriticalOptions != nil {
|
||||
fmt.Fprint(channel, " Options:\r\n")
|
||||
for k := range stringutil.MapKeys(conn.Permissions.CriticalOptions) {
|
||||
fmt.Fprintf(channel, " ✅ %s=%s\r\n", k, conn.Permissions.CriticalOptions[k])
|
||||
}
|
||||
}
|
||||
if conn.Permissions.Extensions != nil {
|
||||
fmt.Fprint(channel, " Extensions:\r\n")
|
||||
for k := range stringutil.MapKeys(conn.Permissions.Extensions) {
|
||||
fmt.Fprintf(channel, " ✅ %s=%s\r\n", k, conn.Permissions.Extensions[k])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Fprintf(channel, " Environment: (%d variables sent)\r\n", len(info.env))
|
||||
for k := range stringutil.MapKeys(info.env) {
|
||||
fmt.Fprintf(channel, " ✅ %s=%s\r\n", k, info.env[k])
|
||||
}
|
||||
|
||||
if info.pty == nil {
|
||||
fmt.Fprint(channel, " PTY: not requested\r\n")
|
||||
} else {
|
||||
fmt.Fprintf(channel, " PTY: %d cols, %d rows, %dx%d, %q\r\n",
|
||||
info.pty.Columns, info.pty.Rows, info.pty.Width, info.pty.Height, info.pty.Term)
|
||||
}
|
||||
|
||||
if info.windowChange == nil {
|
||||
fmt.Fprint(channel, " Window: not requested\r\n")
|
||||
} else {
|
||||
fmt.Fprintf(channel, " Window: %d cols, %d rows, %dx%d",
|
||||
info.windowChange.Columns, info.windowChange.Rows,
|
||||
info.windowChange.Width, info.windowChange.Height)
|
||||
}
|
||||
|
||||
if len(info.unsupported) > 0 {
|
||||
fmt.Fprintf(channel, " Requests: (%d unsupported):\r\n", len(info.unsupported))
|
||||
sort.Strings(info.unsupported)
|
||||
for _, v := range info.unsupported {
|
||||
fmt.Fprintf(channel, " ❌ %s\r\n", v)
|
||||
}
|
||||
} else {
|
||||
fmt.Fprint(channel, " Requests: no unsupported requests\r\n")
|
||||
}
|
||||
|
||||
if info.agentRequest {
|
||||
if info.agentError != nil {
|
||||
fmt.Fprint(channel, " Agent: requested but unavailable:\r\n")
|
||||
fmt.Fprintf(channel, " ❌ %v\r\n", info.agentError)
|
||||
}
|
||||
} else {
|
||||
if info.agentError == nil {
|
||||
fmt.Fprint(channel, " Agent: not requested by client, but accepted, upgrade your client!\r\n")
|
||||
} else {
|
||||
fmt.Fprint(channel, " Agent: not requested by client, and client refused our attempt (good!):\r\n")
|
||||
fmt.Fprintf(channel, " ❌ %v\r\n", info.agentError)
|
||||
}
|
||||
}
|
||||
|
||||
if info.agent != nil {
|
||||
fmt.Fprint(channel, " Agent: available, checking keys:\r\n")
|
||||
keys, err := info.agent.List()
|
||||
if err != nil {
|
||||
fmt.Fprintf(channel, " ❌ %v\r\n", err)
|
||||
} else {
|
||||
for _, key := range keys {
|
||||
blob := make([]byte, 8)
|
||||
rand.Reader.Read(blob)
|
||||
if _, err := info.agent.Sign(key, blob); err != nil {
|
||||
fmt.Fprint(channel, " ❌ ")
|
||||
} else {
|
||||
fmt.Fprint(channel, " ✅ ")
|
||||
}
|
||||
fmt.Fprintf(channel, "%-4d %s %s (%s)\r\n",
|
||||
sshutil.KeyBits(key), sshutil.KeyFingerprint(key),
|
||||
key.Comment, sshutil.KeyType(key))
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Fprint(channel, " Agent: available, checking add/remove:\r\n")
|
||||
if err = info.agent.Add(agent.AddedKey{
|
||||
PrivateKey: key,
|
||||
LifetimeSecs: 5,
|
||||
Comment: "Test by conduit",
|
||||
}); err != nil {
|
||||
fmt.Fprintf(channel, " ❌ Add: %v\r\n", err)
|
||||
} else {
|
||||
fmt.Fprint(channel, " ✅ Add\r\n")
|
||||
}
|
||||
pk, _ := ssh.NewPublicKey(key.(*rsa.PrivateKey).Public())
|
||||
if err = info.agent.Remove(pk); err != nil {
|
||||
fmt.Fprintf(channel, " ❌ Remove: %v\r\n", err)
|
||||
} else {
|
||||
fmt.Fprint(channel, " ✅ Remove\r\n")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func StaticSession(message string) ChannelHandler {
|
||||
return ChannelHandlerFunc(func(ctx Context, channel ssh.Channel, requests <-chan *ssh.Request, _ []byte) error {
|
||||
go ssh.DiscardRequests(requests)
|
||||
io.WriteString(channel, message+"\r\n")
|
||||
return channel.Close()
|
||||
})
|
||||
}
|
47
ssh/handler_tunnel.go
Normal file
47
ssh/handler_tunnel.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
|
||||
"git.maze.io/maze/conduit/logger"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
type Dialer interface {
|
||||
DialContext(ctx context.Context, network, address string) (net.Conn, error)
|
||||
}
|
||||
|
||||
func ForwardTunnel(dialer Dialer) ChannelHandler {
|
||||
if dialer == nil {
|
||||
dialer = new(net.Dialer)
|
||||
}
|
||||
return ChannelHandlerFunc(func(ctx Context, channel ssh.Channel, requests <-chan *ssh.Request, _ []byte) error {
|
||||
return errors.New("byez!")
|
||||
})
|
||||
}
|
||||
|
||||
type PortForwardRequestHandler interface {
|
||||
HandlePortForwardRequest(ctx Context, raddr, laddr net.Addr) (net.Conn, error)
|
||||
}
|
||||
|
||||
type PortForwardRequestHandlerFunc func(Context, net.Addr, net.Addr) (net.Conn, error)
|
||||
|
||||
func (f PortForwardRequestHandlerFunc) HandlePortForwardRequest(ctx Context, raddr, laddr net.Addr) (net.Conn, error) {
|
||||
return f(ctx, raddr, laddr)
|
||||
}
|
||||
|
||||
func PortForwardDialer(dialer Dialer) PortForwardRequestHandler {
|
||||
if dialer == nil {
|
||||
dialer = new(net.Dialer)
|
||||
}
|
||||
return PortForwardRequestHandlerFunc(func(ctx Context, raddr, laddr net.Addr) (net.Conn, error) {
|
||||
log := ctx.(*sshContext).log.Values(logger.Values{
|
||||
"laddr": laddr.String(),
|
||||
"raddr": raddr.String(),
|
||||
})
|
||||
log.Debug("Dialing port forwarding request")
|
||||
return dialer.DialContext(context.Background(), raddr.Network(), raddr.String())
|
||||
})
|
||||
}
|
37
ssh/keys.go
Normal file
37
ssh/keys.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
|
||||
"git.maze.io/maze/conduit/logger"
|
||||
)
|
||||
|
||||
func LoadPrivateKey(name string) (ssh.Signer, error) {
|
||||
if strings.Contains(name, "-----BEGIN") && strings.Contains(name, "PRIVATE KEY-----") {
|
||||
logger.StandardLog.Debug("Loading private key from string")
|
||||
return ssh.ParsePrivateKey([]byte(name))
|
||||
}
|
||||
logger.StandardLog.Value("path", name).Debug("Loading private key")
|
||||
b, err := os.ReadFile(name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ssh.ParsePrivateKey(b)
|
||||
}
|
||||
|
||||
func LoadPrivateKeyWithPassphrase(name string, passphrase []byte) (ssh.Signer, error) {
|
||||
if strings.Contains(name, "-----BEGIN") && strings.Contains(name, "PRIVATE KEY-----") {
|
||||
logger.StandardLog.Debug("Loading private key from string (with passphrase)")
|
||||
return ssh.ParsePrivateKeyWithPassphrase([]byte(name), passphrase)
|
||||
}
|
||||
|
||||
logger.StandardLog.Value("path", name).Debug("Loading private key (with passphrase)")
|
||||
b, err := os.ReadFile(name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ssh.ParsePrivateKeyWithPassphrase(b, passphrase)
|
||||
}
|
115
ssh/server.go
Normal file
115
ssh/server.go
Normal file
@@ -0,0 +1,115 @@
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
|
||||
"git.maze.io/maze/conduit/auth"
|
||||
"git.maze.io/maze/conduit/internal/netutil"
|
||||
"git.maze.io/maze/conduit/logger"
|
||||
)
|
||||
|
||||
type Server struct {
|
||||
// ConnectHandler gets called before accepting new connections.
|
||||
ConnectHandler []ConnectHandler
|
||||
|
||||
// AcceptHandler accepts new SSH connections, it takes care of authenticating, etc.
|
||||
AcceptHandler AcceptHandler
|
||||
|
||||
// Handler per channel type.
|
||||
ChannelHandler map[string]ChannelHandler
|
||||
|
||||
// PortForwardHandler
|
||||
PortForwardHandler PortForwardRequestHandler
|
||||
|
||||
// FIPSMode enables FIPS 140-2 compatible ciphers, key exchanges, etc.
|
||||
FIPSMode bool // TODO(maze): implement
|
||||
|
||||
// Logger for our server.
|
||||
Logger logger.Structured
|
||||
|
||||
// serverConfig is our SSH server configuration.
|
||||
serverConfig *ssh.ServerConfig
|
||||
}
|
||||
|
||||
func NewServer(keys []ssh.Signer) *Server {
|
||||
config := new(ssh.ServerConfig)
|
||||
config.SetDefaults()
|
||||
config.ServerVersion = "SSH-2.0-conduit"
|
||||
|
||||
for _, key := range keys {
|
||||
config.AddHostKey(key)
|
||||
}
|
||||
|
||||
return &Server{
|
||||
ChannelHandler: make(map[string]ChannelHandler),
|
||||
Logger: logger.StandardLog,
|
||||
serverConfig: config,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) Serve(l net.Listener) error {
|
||||
for {
|
||||
c, err := l.Accept()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
go s.handle(c)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handle(c net.Conn) {
|
||||
log := s.Logger.Values(logger.Values{
|
||||
"client": c.RemoteAddr().String(),
|
||||
"server": c.LocalAddr().String(),
|
||||
})
|
||||
log.Debug("New client connection")
|
||||
|
||||
defer func() {
|
||||
if err := c.Close(); err != nil && !netutil.IsClosing(err) {
|
||||
log = log.Err(err)
|
||||
}
|
||||
log.Debug("Closing client connection")
|
||||
}()
|
||||
|
||||
for _, h := range s.ConnectHandler {
|
||||
n, err := h.HandleConnect(c)
|
||||
if err != nil {
|
||||
log.Err(err).Warn("Error from connect handler, closing client connection")
|
||||
return
|
||||
} else if n != nil {
|
||||
log.Debugf("Replacing client connection with %T", n)
|
||||
c = n
|
||||
}
|
||||
}
|
||||
|
||||
// Configure our SSH server.
|
||||
handler := s.AcceptHandler
|
||||
if handler == nil {
|
||||
log.Warn("No accept handler configured, using NO AUTHENTICATION")
|
||||
handler = auth.None{}
|
||||
}
|
||||
|
||||
// We made it, now let's talk some SSH.
|
||||
sshConn, channels, requests, err := handler.HandleAccept(c, s.serverConfig)
|
||||
if err != nil {
|
||||
log.Err(err).Warn("Error establishing SSH session with client")
|
||||
return
|
||||
}
|
||||
go ssh.DiscardRequests(requests)
|
||||
|
||||
log = log.Value("user", sshConn.User())
|
||||
ctx := newSSHContext(s, c, sshConn, log)
|
||||
log = log.Value("context", ctx.ID())
|
||||
log.Value("version", string(sshConn.ClientVersion())).Info("New SSH client")
|
||||
|
||||
if err = ctx.handleChannels(channels); err != nil {
|
||||
if netutil.IsClosing(err) {
|
||||
log.Err(err).Debug("Client handler terminated")
|
||||
} else {
|
||||
log.Err(err).Warn("Error handling channel requests")
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
62
ssh/sshutil/key.go
Normal file
62
ssh/sshutil/key.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package sshutil
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"math/big"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
func KeyBits(key ssh.PublicKey) int {
|
||||
if key == nil {
|
||||
return 0
|
||||
}
|
||||
switch key.Type() {
|
||||
case ssh.KeyAlgoECDSA256:
|
||||
return 256
|
||||
case ssh.KeyAlgoSKECDSA256:
|
||||
return 256
|
||||
case ssh.KeyAlgoECDSA384:
|
||||
return 384
|
||||
case ssh.KeyAlgoECDSA521:
|
||||
return 521
|
||||
case ssh.KeyAlgoED25519:
|
||||
return 256
|
||||
case ssh.KeyAlgoSKED25519:
|
||||
return 256
|
||||
case ssh.KeyAlgoRSA:
|
||||
var w struct {
|
||||
Name string
|
||||
E *big.Int
|
||||
N *big.Int
|
||||
Rest []byte `ssh:"rest"`
|
||||
}
|
||||
_ = ssh.Unmarshal(key.Marshal(), &w)
|
||||
return w.N.BitLen()
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
func KeyType(key ssh.PublicKey) string {
|
||||
if key == nil {
|
||||
return "<nil>"
|
||||
}
|
||||
switch key.Type() {
|
||||
case ssh.KeyAlgoECDSA256, ssh.KeyAlgoSKECDSA256, ssh.KeyAlgoECDSA384, ssh.KeyAlgoECDSA521:
|
||||
return "ECDSA"
|
||||
case ssh.KeyAlgoED25519, ssh.KeyAlgoSKED25519:
|
||||
return "ED25519"
|
||||
case ssh.KeyAlgoRSA:
|
||||
return "RSA"
|
||||
default:
|
||||
return key.Type()
|
||||
}
|
||||
}
|
||||
|
||||
func KeyFingerprint(key ssh.PublicKey) string {
|
||||
h := sha256.New()
|
||||
h.Write(key.Marshal())
|
||||
return "SHA256:" + base64.RawStdEncoding.EncodeToString(h.Sum(nil))
|
||||
}
|
47
ssh/sshutil/request.go
Normal file
47
ssh/sshutil/request.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package sshutil
|
||||
|
||||
import "golang.org/x/crypto/ssh"
|
||||
|
||||
type EnvRequest struct {
|
||||
Key, Value string
|
||||
}
|
||||
|
||||
func ParseEnvRequest(data []byte) (*EnvRequest, error) {
|
||||
r := new(EnvRequest)
|
||||
if err := ssh.Unmarshal(data, r); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return r, nil
|
||||
}
|
||||
|
||||
type PTYRequest struct {
|
||||
Term string
|
||||
Columns uint32
|
||||
Rows uint32
|
||||
Width uint32
|
||||
Height uint32
|
||||
ModeList []byte
|
||||
}
|
||||
|
||||
func ParsePTYRequest(data []byte) (*PTYRequest, error) {
|
||||
r := new(PTYRequest)
|
||||
if err := ssh.Unmarshal(data, r); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return r, nil
|
||||
}
|
||||
|
||||
type WindowChangeRequest struct {
|
||||
Columns uint32
|
||||
Rows uint32
|
||||
Width uint32
|
||||
Height uint32
|
||||
}
|
||||
|
||||
func ParseWindowChangeRequest(data []byte) (*WindowChangeRequest, error) {
|
||||
r := new(WindowChangeRequest)
|
||||
if err := ssh.Unmarshal(data, r); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return r, nil
|
||||
}
|
Reference in New Issue
Block a user