package ssh import ( "crypto" "crypto/rand" "crypto/rsa" "encoding/hex" "errors" "fmt" "io" "net" "sort" "sync" "sync/atomic" "time" "golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh/agent" "git.maze.io/maze/conduit/internal/netutil" "git.maze.io/maze/conduit/internal/stringutil" "git.maze.io/maze/conduit/logger" "git.maze.io/maze/conduit/ssh/sshutil" ) type ConnectHandler interface { HandleConnect(net.Conn) (net.Conn, error) } type ConnectHandlerFunc func(net.Conn) (net.Conn, error) func (f ConnectHandlerFunc) HandleConnect(c net.Conn) (net.Conn, error) { return f(c) } type AcceptHandler interface { HandleAccept(net.Conn, *ssh.ServerConfig) (*ssh.ServerConn, <-chan ssh.NewChannel, <-chan *ssh.Request, error) } // MaxConnections limits the maximum number of connections. func MaxConnections(max int) ConnectHandler { var connections atomic.Int64 return ConnectHandlerFunc(func(c net.Conn) (net.Conn, error) { if max <= 0 { return nil, nil } else if connections.Load() >= int64(max) { return nil, errors.New("server: maximum number of connections reached") } var ( once sync.Once cc = &netutil.ConnCloser{ Conn: c, Closer: func() error { once.Do(func() { connections.Add(-1) }) return c.Close() }, } ) connections.Add(1) return cc, nil }) } type ChannelHandler interface { HandleChannel(Context, ssh.Channel, <-chan *ssh.Request, []byte) error } type ChannelHandlerFunc func(Context, ssh.Channel, <-chan *ssh.Request, []byte) error func (f ChannelHandlerFunc) HandleChannel(ctx Context, channel ssh.Channel, requests <-chan *ssh.Request, extra []byte) error { return f(ctx, channel, requests, extra) } type debugSessionInfo struct { start time.Time duration time.Duration method string agent agent.Agent agentRequest bool agentChannel ssh.Channel agentError error env map[string]string pty *sshutil.PTYRequest windowChange *sshutil.WindowChangeRequest unsupported []string } func newDebugSessionInfo() *debugSessionInfo { return &debugSessionInfo{ start: time.Now(), env: make(map[string]string), } } // DebugSession is a session channel handler that will print debug information to the client, which // may aid with troubleshooting SSH connectivity or client issues. func DebugSession() ChannelHandler { debugPrivateKey, _ := rsa.GenerateKey(rand.Reader, 1024) return ChannelHandlerFunc(func(ctx Context, channel ssh.Channel, requests <-chan *ssh.Request, _ []byte) error { defer channel.Close() var ( log = ctx.(*sshContext).log done = make(chan struct{}, 1) info = newDebugSessionInfo() conn = ctx.Conn() err error ) go func() { var ( reply bool reesponse []byte ) for request := range requests { log.Values(logger.Values{ "request": request.Type, }).Trace("New session channel request") switch request.Type { case RequestTypeAgent: info.agentRequest = true agentChannel, agentRequests, err := conn.OpenChannel(ChannelTypeAgent, nil) if err != nil { info.agentError = err } else { go ssh.DiscardRequests(agentRequests) info.agent = agent.NewClient(agentChannel) info.agentChannel = agentChannel } reply = true case RequestTypeEnv: var payload *sshutil.EnvRequest if payload, err = sshutil.ParseEnvRequest(request.Payload); err != nil { log.Err(err).Debug("Corrupted env request payload, discarding") } else { log.Values(logger.Values{ "key": payload.Key, "value": payload.Value, }).Trace("Client requested env variable") info.env[payload.Key] = payload.Value } reply = true case RequestTypePTY: if info.pty, err = sshutil.ParsePTYRequest(request.Payload); err != nil { log.Err(err).Debug("Corrupted pty request payload, discarding") } else { log.Values(logger.Values{ "term": info.pty.Term, "size": fmt.Sprintf("%dx%d", info.pty.Columns, info.pty.Rows), }).Trace("Client requested PTY") } reply = true case RequestTypeExec, RequestTypeShell: info.method = request.Type select { case <-done: default: close(done) } //return case RequestTypeWindowChange: if info.windowChange, err = sshutil.ParseWindowChangeRequest(request.Payload); err != nil { log.Err(err).Debug("Corrupted window change request payload, discarding") } else { log.Values(logger.Values{ "size": fmt.Sprintf("%dx%d", info.windowChange.Columns, info.windowChange.Rows), }).Trace("Client requested window change") } reply = true default: log.Values(logger.Values{ "request": request.Type, "payload": hex.EncodeToString(request.Payload), }).Trace("Client requested something we don't understand, ignored") info.unsupported = append(info.unsupported, request.Type) } if request.WantReply { if err := request.Reply(reply, reesponse); err != nil { log.Err(err).Debug("Error sending session channel request reply: terminating") _ = channel.Close() select { case <-done: default: close(done) } return } } } select { case <-done: default: close(done) } }() select { case <-done: case <-time.After(10 * time.Second): io.WriteString(channel, "Timeout waiting for your client to send either an exec or shell request\r\n") } // Any requests that follow are ignored from hereon forward. go ssh.DiscardRequests(requests) // Attempt to request agent forwarding if info.agent == nil && !info.agentRequest { agentChannel, agentRequests, err := conn.OpenChannel(ChannelTypeAgent, nil) if err != nil { info.agentError = err } else { go ssh.DiscardRequests(agentRequests) info.agent = agent.NewClient(agentChannel) info.agentChannel = agentChannel } } info.duration = time.Since(info.start) printSessionInfo(ctx, channel, info, debugPrivateKey) return nil }) } func printSessionInfo(ctx Context, channel ssh.Channel, info *debugSessionInfo, key crypto.PrivateKey) { if info.agentChannel != nil { defer info.agentChannel.Close() } conn := ctx.Conn().(*ssh.ServerConn) fmt.Fprintf(channel, "It took your client %s to request %s\r\n\r\n", info.duration, info.method) fmt.Fprintf(channel, "SSH connection information:\r\n") fmt.Fprintf(channel, " Server:\r\n") fmt.Fprintf(channel, " Version: \x1b[1m%s\x1b[0m\r\n", conn.ServerVersion()) fmt.Fprintf(channel, " Client:\r\n") fmt.Fprintf(channel, " Version: \x1b[1m%s\x1b[0m\r\n", conn.ClientVersion()) fmt.Fprintf(channel, " Username: \x1b[1m%s\x1b[0m\r\n", conn.User()) if conn.Permissions != nil { if conn.Permissions.CriticalOptions != nil { fmt.Fprint(channel, " Options:\r\n") for k := range stringutil.MapKeys(conn.Permissions.CriticalOptions) { fmt.Fprintf(channel, " ✅ %s=%s\r\n", k, conn.Permissions.CriticalOptions[k]) } } if conn.Permissions.Extensions != nil { fmt.Fprint(channel, " Extensions:\r\n") for k := range stringutil.MapKeys(conn.Permissions.Extensions) { fmt.Fprintf(channel, " ✅ %s=%s\r\n", k, conn.Permissions.Extensions[k]) } } } fmt.Fprintf(channel, " Environment: (%d variables sent)\r\n", len(info.env)) for k := range stringutil.MapKeys(info.env) { fmt.Fprintf(channel, " ✅ %s=%s\r\n", k, info.env[k]) } if info.pty == nil { fmt.Fprint(channel, " PTY: not requested\r\n") } else { fmt.Fprintf(channel, " PTY: %d cols, %d rows, %dx%d, %q\r\n", info.pty.Columns, info.pty.Rows, info.pty.Width, info.pty.Height, info.pty.Term) } if info.windowChange == nil { fmt.Fprint(channel, " Window: not requested\r\n") } else { fmt.Fprintf(channel, " Window: %d cols, %d rows, %dx%d", info.windowChange.Columns, info.windowChange.Rows, info.windowChange.Width, info.windowChange.Height) } if len(info.unsupported) > 0 { fmt.Fprintf(channel, " Requests: (%d unsupported):\r\n", len(info.unsupported)) sort.Strings(info.unsupported) for _, v := range info.unsupported { fmt.Fprintf(channel, " ❌ %s\r\n", v) } } else { fmt.Fprint(channel, " Requests: no unsupported requests\r\n") } if info.agentRequest { if info.agentError != nil { fmt.Fprint(channel, " Agent: requested but unavailable:\r\n") fmt.Fprintf(channel, " ❌ %v\r\n", info.agentError) } } else { if info.agentError == nil { fmt.Fprint(channel, " Agent: not requested by client, but accepted, upgrade your client!\r\n") } else { fmt.Fprint(channel, " Agent: not requested by client, and client refused our attempt (good!):\r\n") fmt.Fprintf(channel, " ❌ %v\r\n", info.agentError) } } if info.agent != nil { fmt.Fprint(channel, " Agent: available, checking keys:\r\n") keys, err := info.agent.List() if err != nil { fmt.Fprintf(channel, " ❌ %v\r\n", err) } else { for _, key := range keys { blob := make([]byte, 8) rand.Reader.Read(blob) if _, err := info.agent.Sign(key, blob); err != nil { fmt.Fprint(channel, " ❌ ") } else { fmt.Fprint(channel, " ✅ ") } fmt.Fprintf(channel, "%-4d %s %s (%s)\r\n", sshutil.KeyBits(key), sshutil.KeyFingerprint(key), key.Comment, sshutil.KeyType(key)) } } fmt.Fprint(channel, " Agent: available, checking add/remove:\r\n") if err = info.agent.Add(agent.AddedKey{ PrivateKey: key, LifetimeSecs: 5, Comment: "Test by conduit", }); err != nil { fmt.Fprintf(channel, " ❌ Add: %v\r\n", err) } else { fmt.Fprint(channel, " ✅ Add\r\n") } pk, _ := ssh.NewPublicKey(key.(*rsa.PrivateKey).Public()) if err = info.agent.Remove(pk); err != nil { fmt.Fprintf(channel, " ❌ Remove: %v\r\n", err) } else { fmt.Fprint(channel, " ✅ Remove\r\n") } } } func StaticSession(message string) ChannelHandler { return ChannelHandlerFunc(func(ctx Context, channel ssh.Channel, requests <-chan *ssh.Request, _ []byte) error { go ssh.DiscardRequests(requests) io.WriteString(channel, message+"\r\n") return channel.Close() }) }