You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 

380 lines
9.1 KiB

package secureshell
import (
"errors"
"fmt"
"io"
"net"
"strings"
"sync"
"golang.org/x/crypto/ssh"
"maze.io/x/secureshell"
"maze.io/gate/pkg/core"
"maze.io/gate/pkg/core/logger"
"maze.io/gate/pkg/util/compact"
)
// DefaultServer is a Server value carrying the package defaults: the embedded
// secureshell.Server has its Addr set to ":4222" and every other field is left
// at its zero value.
var DefaultServer = Server{
Server: secureshell.Server{
Addr: ":4222",
},
}
// Server can accept incoming SSH client connections.
//
// It embeds secureshell.Server and adds identity lookup, tracking of one
// transport per SSH connection, and the session and tunnel handlers that Start
// installs on the embedded server.
type Server struct {
secureshell.Server
// IdentityProvider is used to authenticate and resolve users and groups.
// It is consulted by the password and public key callbacks and when a new
// transport is registered.
IdentityProvider core.IdentityProvider
// TunnelDialer is used to establish new tunneled connections.
TunnelDialer func(transport core.Transport, addr string) (net.Conn, error)
// Banner is displayed after the key exchange.
// Start only installs a banner callback when Banner is non-empty.
Banner string
// Message is displayed after a session has been established.
// It is written, with surrounding whitespace trimmed, to interactive session
// terminals only. The HCL configuration key is "motd".
Message string `hcl:"motd"`
// Prompt is the prompt format string.
// When non-empty it is applied to every new session terminal.
Prompt string
// log is the component-scoped logger; it is assigned by Start.
log *logger.Logger
// keys and listener are not referenced in the visible part of this file.
// NOTE(review): presumably the loaded host keys and the active listener —
// confirm where they are populated.
keys []ssh.Signer
listener net.Listener
// mutex guards conns.
mutex sync.RWMutex
// conns holds the live transports, keyed by the compact form of the SSH
// session identifier. It is (re)allocated by Start.
conns map[compact.ID]*transport
}
// Component returns the component name of the SSH server, "sshd". Start uses
// this value for the logger.Component field of the server's logger.
func (sshd *Server) Component() string {
return "sshd"
}
// Setup stores provider as the server's IdentityProvider, replacing any
// previously configured one. It always returns nil.
func (sshd *Server) Setup(provider core.IdentityProvider) error {
sshd.IdentityProvider = provider
return nil
}
// Start prepares the server and begins serving in a background goroutine.
//
// It derives a component-scoped logger from log, resets the transport table,
// installs the SSH server configuration together with the connection, session
// and tunnel callbacks, and finally launches serve, which reports a
// ListenAndServe failure on errs.
//
// An error is returned only when no key files are configured; in that case
// nothing is served.
func (sshd *Server) Start(log *logger.Logger, errs chan<- error) error {
	sshd.log = log.WithField(logger.Component, sshd.Component())
	sshd.conns = map[compact.ID]*transport{}

	if len(sshd.KeyFiles) < 1 {
		return errors.New("sshd: no keys configured")
	}

	// Assemble the SSH configuration locally, then publish it in one step.
	config := &ssh.ServerConfig{
		AuthLogCallback:   sshd.authLog,
		PasswordCallback:  sshd.wrapPasswordCallback,
		PublicKeyCallback: sshd.wrapPublicKeyCallback,
	}
	if sshd.Banner != "" {
		// The banner is read when the callback fires, not captured here.
		config.BannerCallback = func(ssh.ConnMetadata) string {
			// TODO(maze): format strings?
			return sshd.Banner
		}
	}
	sshd.Config = config

	// Hooks on the embedded secureshell.Server.
	sshd.ConnectCallback = sshd.handleConnect
	sshd.DisconnectCallback = sshd.handleDisconnect
	sshd.SessionHandler = sshd.handleSession
	sshd.TunnelHandler = sshd.handleTunnel
	sshd.ReverseTunnelHandler = sshd.handleReverseTunnel

	go sshd.serve(errs)
	return nil
}
// wrappedMetadata decorates an ssh.ConnMetadata and overrides the user: its
// User method drops everything from the first '@' onward. All other methods
// are promoted unchanged from the embedded value.
type wrappedMetadata struct {
ssh.ConnMetadata
}
// wrapMetadata returns metadata decorated as a wrappedMetadata. A value that
// is already a wrappedMetadata is returned as-is, so wrapping is idempotent
// and never nests.
func wrapMetadata(metadata ssh.ConnMetadata) ssh.ConnMetadata {
	switch m := metadata.(type) {
	case wrappedMetadata:
		return m
	default:
		return wrappedMetadata{metadata}
	}
}
// User returns the user name reported by the wrapped metadata, truncated just
// before the first '@'. Names without an '@' are returned unchanged.
func (meta wrappedMetadata) User() string {
	name := meta.ConnMetadata.User()
	at := strings.IndexByte(name, '@')
	if at < 0 {
		return name
	}
	return name[:at]
}
// wrapPasswordCallback forwards a password authentication attempt to the
// IdentityProvider, first wrapping the metadata so the provider sees the user
// name as reported by wrappedMetadata.User.
func (sshd *Server) wrapPasswordCallback(metadata ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
	return sshd.IdentityProvider.PasswordCallback(wrapMetadata(metadata), password)
}
// wrapPublicKeyCallback forwards a public key authentication attempt to the
// IdentityProvider, first wrapping the metadata so the provider sees the user
// name as reported by wrappedMetadata.User.
func (sshd *Server) wrapPublicKeyCallback(metadata ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
	return sshd.IdentityProvider.PublicKeyCallback(wrapMetadata(metadata), key)
}
// serve runs ListenAndServe on the embedded server until it returns. A non-nil
// result is sent on errs; the send blocks until it is accepted. The start and
// stop of the server are both logged.
func (sshd *Server) serve(errs chan<- error) {
	sshd.log.Info("start server")
	defer sshd.log.Warn("server stopped")

	err := sshd.ListenAndServe()
	if err == nil {
		return
	}
	errs <- err
}
// authLog records the outcome of an authentication attempt. Attempts that
// ended with ssh.ErrNoAuth are not logged at all. Everything else is logged
// with the session, user, method and source address: failures as warnings,
// successes at info level.
func (sshd *Server) authLog(metadata ssh.ConnMetadata, method string, err error) {
	if err == ssh.ErrNoAuth {
		return
	}

	meta := wrapMetadata(metadata)
	fields := logger.Fields{
		"session":   compact.Bytes(meta.SessionID()),
		logger.User: meta.User(),
		"method":    method,
	}
	entry := sshd.log.WithFields(fields).WithSourceAddr(meta.RemoteAddr())

	if err == nil {
		entry.Info("authenticated")
		return
	}
	entry.WithError(err).Warn("authentication failure")
}
// addTransport registers a transport for conn under its session identifier and
// returns it. If a transport is already registered for that session, the
// existing one is returned and nothing is created. Otherwise the connection's
// user is resolved through the IdentityProvider; a lookup failure is returned
// and nothing is registered.
//
// The server mutex is held for the whole call, including the user lookup.
func (sshd *Server) addTransport(conn ssh.Conn) (tr *transport, err error) {
	meta := wrapMetadata(conn)
	id := compact.Bytes(meta.SessionID())

	sshd.mutex.Lock()
	defer sshd.mutex.Unlock()

	if existing, found := sshd.conns[id]; found {
		return existing, nil
	}

	user, err := sshd.IdentityProvider.LookupUser(meta.User())
	if err != nil {
		return nil, err
	}

	created := newTransport(conn, user)
	created.setupLogger(sshd.log)
	sshd.conns[id] = created
	return created, nil
}
// getTransport looks up the transport registered for conn's session
// identifier. ok is false when none is registered.
func (sshd *Server) getTransport(conn ssh.Conn) (tr *transport, ok bool) {
	id := compact.Bytes(conn.SessionID())
	tr, ok = sshd.getTransportByID(id)
	return tr, ok
}
// getTransportByID looks up the transport registered under id while holding
// the read lock. ok is false when none is registered.
func (sshd *Server) getTransportByID(id compact.ID) (tr *transport, ok bool) {
	sshd.mutex.RLock()
	defer sshd.mutex.RUnlock()

	tr, ok = sshd.conns[id]
	return tr, ok
}
// deleteTransport unregisters the transport belonging to conn and closes every
// tunnel attached to it, ignoring close errors. It returns the removed
// transport, or nil when none was registered for the connection.
func (sshd *Server) deleteTransport(conn ssh.Conn) (tr *transport) {
	id := compact.Bytes(conn.SessionID())

	sshd.mutex.Lock()
	defer sshd.mutex.Unlock()

	removed, found := sshd.conns[id]
	if !found {
		return nil
	}
	delete(sshd.conns, id)

	// Best-effort cleanup: the connection is gone, close errors are moot.
	for _, tunnel := range removed.tunnels {
		_ = tunnel.Close()
	}
	return removed
}
// handleConnect is the connect callback installed by Start. It registers a
// transport for conn and logs the new connection. If registration fails, the
// failure is logged and returned.
func (sshd *Server) handleConnect(conn ssh.Conn) error {
	// TODO(maze): connection has been authenticated here, but we may want to do additional policies here?
	tr, err := sshd.addTransport(conn)
	if err == nil {
		tr.Info("new connection")
		return nil
	}

	sshd.log.WithError(err).Warn("add transport failed")
	return err
}
// handleDisconnect is the disconnect callback installed by Start. It
// unregisters the transport for conn and logs the disconnect: at info level
// for a nil error or io.EOF, as a warning with the error attached otherwise.
// Connections without a registered transport are ignored.
func (sshd *Server) handleDisconnect(conn ssh.Conn, err error) {
	tr := sshd.deleteTransport(conn)

	switch {
	case tr == nil:
		// Never registered (or already removed): nothing to report.
	case err == nil || err == io.EOF:
		tr.Info("connection lost")
	default:
		tr.WithError(err).Warn("connection lost")
	}
}
// handleSession is the session handler installed by Start. It attaches a
// session terminal to the transport the session belongs to, applies the
// configured prompt, prints the message of the day on interactive terminals
// and runs the terminal until it ends. Sessions whose transport cannot be
// found are logged as errors and dropped. Start and end of the session are
// always logged.
func (sshd *Server) handleSession(session secureshell.Session) {
	meta := wrapMetadata(session)
	sid := compact.Bytes(session.SessionID())

	entry := sshd.log.WithFields(logger.Fields{
		"session":   sid,
		logger.User: meta.User(),
	})
	entry.Info("session start")
	defer entry.Info("session ended")

	tr, found := sshd.getTransportByID(sid)
	if !found {
		entry.Error("transport not found for session")
		return
	}

	term := NewSessionTerminal(tr, session)
	if sshd.Prompt != "" {
		term.SetPrompt(sshd.Prompt)
	}
	if term.IsInteractive() && sshd.Message != "" {
		// TODO(maze): string formatting?
		motd := strings.TrimSpace(sshd.Message)
		fmt.Fprintln(term, motd)
	}

	// io.EOF is the regular way for a shell to end; only report the rest.
	if err := term.Run(); err != nil && err != io.EOF {
		entry.WithError(err).Warn("shell terminated")
	}
}
// handleTunnel is used for:
// * local port forwarding (ssh -L)
// * dynamic application-level port forwarding (ssh -D)
//
// It is the tunnel handler installed by Start. The request is rejected when no
// transport is registered for conn. Otherwise a connection to address is
// established, attached to the transport so it is closed together with it, and
// returned.
//
// The connection is opened with the configured TunnelDialer when one is set;
// without one the handler falls back to a plain TCP dial. Dial failures are
// logged on the transport and returned unchanged.
func (sshd *Server) handleTunnel(conn ssh.Conn, address string) (net.Conn, error) {
	// TODO(maze): policies for tunnel requests
	// TODO(maze): protocol detection for established connection (see pkg/net/sniff)
	// TODO(maze): session logging for tunnel connections
	tr, ok := sshd.getTransport(conn)
	if !ok {
		return nil, errors.New("unauthorized tunnel attempt")
	}

	log := tr.WithFields(logger.Fields{
		logger.TargetAddr: address,
	})
	log.Info("tunnel request")

	var (
		c   net.Conn
		err error
	)
	if sshd.TunnelDialer != nil {
		// Honor the configured dialer instead of silently bypassing it.
		c, err = sshd.TunnelDialer(tr, address)
	} else {
		c, err = net.Dial("tcp", address)
	}
	if err != nil {
		log.WithError(err).Warn("tunnel connect failed")
		return nil, err
	}

	tr.AddTunnel(c)
	return c, nil
}
// handleReverseTunnel is used for:
// * remote port forwarding (ssh -R)
//
// It is the reverse tunnel handler installed by Start. Reverse tunnels are not
// supported: every request is refused with an error and no listener is ever
// created, regardless of the connection or the requested address.
func (sshd *Server) handleReverseTunnel(conn ssh.Conn, _, address string) (net.Listener, error) {
return nil, errors.New("reverse tunnels are not permitted")
}
// Transports returns connected SSH clients.
//
// The result is a snapshot of the registered transports taken under the read
// lock. Its order is unspecified, and it is empty but non-nil when no client
// is connected.
func (sshd *Server) Transports() []core.Transport {
	sshd.mutex.RLock()
	defer sshd.mutex.RUnlock()

	snapshot := make([]core.Transport, len(sshd.conns))
	i := 0
	for _, tr := range sshd.conns {
		snapshot[i] = tr
		i++
	}
	return snapshot
}
// ptyModes renders a set of pseudo-terminal modes (opcode -> argument) as a
// human-readable, ";"-separated string.
//
// Control characters are printed as hex bytes, the echo flag as a boolean, the
// character size flags as a fixed label and the baud rates as decimal numbers.
// A number of well-known opcodes are recognized but deliberately left out of
// the output. Any opcode that is not recognized at all is printed as
// "opcode=argument" in decimal.
//
// Entries appear in ascending opcode order, so the same set of modes always
// yields the same string. An empty or nil map yields "".
func ptyModes(modes map[byte]uint32) string {
	s := make([]string, 0, len(modes))

	// Walk the opcode space instead of ranging over the map: map iteration
	// order is randomized, which would shuffle the output on every call.
	for op := 0; op < 256; op++ {
		k := byte(op)
		v, ok := modes[k]
		if !ok {
			continue
		}

		switch k {
		case ssh.VINTR:
			s = append(s, fmt.Sprintf("interrupt=%#02x", byte(v)))
		case ssh.VQUIT:
			s = append(s, fmt.Sprintf("quit=%#02x", byte(v)))
		case ssh.VEOF:
			s = append(s, fmt.Sprintf("eof=%#02x", byte(v)))
		case ssh.VEOL:
			s = append(s, fmt.Sprintf("eol=%#02x", byte(v)))
		case ssh.VEOL2:
			s = append(s, fmt.Sprintf("eol2=%#02x", byte(v)))
		case ssh.VSTART:
			s = append(s, fmt.Sprintf("continue=%#02x", byte(v)))
		case ssh.VSTOP:
			s = append(s, fmt.Sprintf("pause=%#02x", byte(v)))
		case ssh.VSUSP:
			s = append(s, fmt.Sprintf("suspend=%#02x", byte(v)))
		case ssh.VDSUSP:
			s = append(s, fmt.Sprintf("suspend2=%#02x", byte(v)))
		case ssh.VREPRINT:
			s = append(s, fmt.Sprintf("reprint=%#02x", byte(v)))
		case ssh.VWERASE:
			s = append(s, fmt.Sprintf("erase word left=%#02x", byte(v)))
		case 42:
			// Opcode 42 has no named constant here; it is reported as the
			// UTF-8 input flag.
			s = append(s, "UTF-8")
		case ssh.ECHO:
			s = append(s, fmt.Sprintf("echo=%t", v != 0))
		case ssh.CS7:
			s = append(s, "7-bit mode")
		case ssh.CS8:
			s = append(s, "8-bit mode")
		case ssh.TTY_OP_ISPEED:
			s = append(s, fmt.Sprintf("input baud=%d", v))
		case ssh.TTY_OP_OSPEED:
			s = append(s, fmt.Sprintf("output baud=%d", v))
		case ssh.VERASE, ssh.VKILL,
			ssh.VLNEXT, ssh.VFLUSH, ssh.VSWTCH, ssh.VSTATUS, ssh.VDISCARD,
			ssh.IGNPAR, ssh.PARMRK, ssh.INPCK, ssh.ISTRIP, ssh.INLCR, ssh.IGNCR,
			ssh.ICRNL, ssh.IUCLC, ssh.IXON, ssh.IXANY, ssh.IXOFF, ssh.IMAXBEL,
			ssh.ISIG, ssh.ICANON, ssh.XCASE,
			ssh.ECHOE, ssh.ECHOK, ssh.ECHONL, ssh.NOFLSH, ssh.TOSTOP, ssh.IEXTEN,
			ssh.ECHOCTL, ssh.ECHOKE, ssh.PENDIN,
			ssh.OPOST, ssh.OLCUC, ssh.ONLCR, ssh.OCRNL, ssh.ONOCR, ssh.ONLRET,
			ssh.PARENB, ssh.PARODD:
			// Recognized, but intentionally not reported.
		default:
			s = append(s, fmt.Sprintf("%d=%d", k, v))
		}
	}
	return strings.Join(s, ";")
}