package ssh

import (
	"encoding/binary"
	"encoding/hex"
	"io"
	"math/rand"
	"net"
	"sync"
	"time"

	"golang.org/x/crypto/ssh"

	"git.maze.io/maze/conduit/logger"
)

// seed is the random source for connection IDs. A source returned by
// rand.NewSource is not safe for concurrent use by multiple goroutines, so
// every call on it must be made while holding seedMu.
var (
	seed   = rand.NewSource(time.Now().UnixNano())
	seedMu sync.Mutex
)

// Compile-time check that sshContext satisfies Context.
var _ Context = (*sshContext)(nil)

// Context is the per-connection value handed to channel and port-forward
// handlers. It exposes the SSH connection metadata plus access to the
// underlying connections.
type Context interface {
	ssh.ConnMetadata

	// ID returns an identifier for this connection, e.g. for correlating
	// log entries.
	ID() string

	// Conn is the [ssh.Conn].
	Conn() ssh.Conn

	// NetConn is the underlying [net.Conn].
	NetConn() net.Conn

	// Close the client connection.
	Close() error
}

// sshContext is the Context implementation for a connection accepted by
// Server.
type sshContext struct {
	// id is a random value drawn from seed; see ID for its string form.
	id uint64
	// server is the Server that accepted the connection; its handlers are
	// consulted when the client opens channels.
	server *Server
	// netConn is the raw network connection underneath the SSH session.
	netConn net.Conn
	// conn is the SSH server-side connection.
	conn *ssh.ServerConn
	// log is the caller's logger annotated with this context's ID.
	log logger.Structured
}

// newSSHContext wraps an accepted connection in an sshContext with a fresh
// random ID and a logger tagged with that ID ("context" key). It is safe to
// call from multiple goroutines. IDs are random, not guaranteed unique, and
// must not be relied upon for anything security-sensitive.
func newSSHContext(server *Server, netConn net.Conn, conn *ssh.ServerConn, log logger.Structured) *sshContext {
	// The shared source is not goroutine-safe; serialize access to it.
	seedMu.Lock()
	id := uint64(seed.Int63())
	seedMu.Unlock()

	ctx := &sshContext{
		id:      id,
		server:  server,
		netConn: netConn,
		conn:    conn,
	}
	ctx.log = log.Value("context", ctx.ID())
	return ctx
}

// User returns the user ID for this connection.
func (ctx *sshContext) User() string {
	return ctx.conn.User()
}

// SessionID returns the session hash, also denoted by H.
func (ctx *sshContext) SessionID() []byte {
	return ctx.conn.SessionID()
}

// ClientVersion returns the client's version string as hashed
// into the session ID.
func (ctx *sshContext) ClientVersion() []byte {
	return ctx.conn.ClientVersion()
}

// ServerVersion returns the server's version string as hashed
// into the session ID.
func (ctx *sshContext) ServerVersion() []byte {
	return ctx.conn.ServerVersion()
}

// RemoteAddr returns the remote address for this connection, as reported by
// the underlying network connection.
func (ctx *sshContext) RemoteAddr() net.Addr {
	return ctx.netConn.RemoteAddr()
}

// LocalAddr returns the local address for this connection.
func (ctx *sshContext) LocalAddr() net.Addr { return ctx.netConn.LocalAddr() } func (ctx *sshContext) handleChannels(channels <-chan ssh.NewChannel) (err error) { for newChan := range channels { var ( kind = newChan.ChannelType() log = ctx.log.Value("channel", kind) ) log.Trace("Client requested new channel") handler, ok := ctx.server.ChannelHandler[kind] if !ok { handler = ctx.server.ChannelHandler[ChannelTypeDefault] } if handler != nil { var ( channel ssh.Channel requests <-chan *ssh.Request ) if channel, requests, err = newChan.Accept(); err != nil { return } if err = handler.HandleChannel(ctx, channel, requests, newChan.ExtraData()); err != nil { return } } else if kind == ChannelTypeDirectTCPIP && ctx.server.PortForwardHandler != nil { if err = ctx.handleDirectTCPIP(newChan); err != nil { return } } else { ctx.log.Debug("Rejecting unsupported channel type") if err = newChan.Reject(ssh.Prohibited, ""); err != nil { return } } } // Our client hang up. return io.EOF } func (ctx *sshContext) handleDirectTCPIP(newChan ssh.NewChannel) (err error) { var payload struct { Host string Port uint32 OriginAddr string OriginPort uint32 } if err = ssh.Unmarshal(newChan.ExtraData(), &payload); err != nil { _ = newChan.Reject(ssh.Prohibited, "") return } var ip net.IP if ip = net.ParseIP(payload.Host); ip == nil { // Not an IP var ips []net.IP if ips, err = net.LookupIP(payload.Host); err != nil { _ = newChan.Reject(ssh.ConnectionFailed, err.Error()) return } else if len(ips) == 0 { _ = newChan.Reject(ssh.ConnectionFailed, "") return } ip = ips[0] } var ( raddr = &net.TCPAddr{ IP: ip, Port: int(payload.Port), } laddr = &net.TCPAddr{ IP: net.ParseIP(payload.OriginAddr), Port: int(payload.OriginPort), } ) if payload.OriginAddr == "" && payload.OriginPort == 0 { laddr = nil } var conn net.Conn if conn, err = ctx.server.PortForwardHandler.HandlePortForwardRequest(ctx, raddr, laddr); err != nil { _ = newChan.Reject(ssh.ConnectionFailed, err.Error()) return } defer func() { _ = 
conn.Close() }() var ( channel ssh.Channel requests <-chan *ssh.Request ) if channel, requests, err = newChan.Accept(); err != nil { return } defer func() { _ = channel.Close() }() go ssh.DiscardRequests(requests) go io.Copy(channel, conn) _, err = io.Copy(conn, channel) return } func (ctx *sshContext) Conn() ssh.Conn { return ctx.conn } func (ctx *sshContext) NetConn() net.Conn { return ctx.netConn } func (ctx *sshContext) Close() error { return ctx.conn.Close() } func (ctx *sshContext) ID() string { var b [8]byte binary.BigEndian.PutUint64(b[:], ctx.id) return hex.EncodeToString(b[:]) }