Browse Source

Cleanup

master
parent
commit
407f0800af
4 changed files with 47 additions and 49 deletions
  1. +13
    -10
      client.go
  2. +4
    -4
      cmd/stronghold/main.go
  3. +16
    -16
      proxy.go
  4. +14
    -19
      server.go

+ 13
- 10
client.go View File

@ -43,11 +43,11 @@ type clientConn struct {
// SessionHandler is a callback for SSH shell sessions.
SessionHandler func(conn ssh.ConnMetadata, channel ssh.Channel, pty *PTY, windowChanges <-chan *WindowChange) error
netConn net.Conn
sshConn ssh.Conn
id string
closed chan struct{}
root string
netConn net.Conn
sshConn ssh.Conn
id string
closed chan struct{}
recordDir string
requestMutex sync.RWMutex
ptyRequest *PTY
@ -56,13 +56,13 @@ type clientConn struct {
agent agent.Agent
}
func newClient(netConn net.Conn, sshConn ssh.Conn, root string) *clientConn {
func newClient(netConn net.Conn, sshConn ssh.Conn, recordDir string) *clientConn {
return &clientConn{
netConn: netConn,
sshConn: sshConn,
id: sessionID(sshConn),
closed: make(chan struct{}),
root: root,
recordDir: recordDir,
envRequest: make(map[string]string),
}
}
@ -362,14 +362,17 @@ func (client *clientConn) handleDirectTCPIPChannel(newChannel ssh.NewChannel) (e
}
// Now that we are connected as an SSH client, accept the new channel.
channel, channelRequests, err := newChannel.Accept()
if err != nil {
var (
channel ssh.Channel
channelRequests <-chan *ssh.Request
)
if channel, channelRequests, err = newChannel.Accept(); err != nil {
return err
}
go ssh.DiscardRequests(channelRequests)
// Present a (fake) SSH server on the newly accepted channel.
server := newSSHInterceptingServer(client.root)
server := newSSHInterceptingServer(client.recordDir)
server.AddHostKeys(serverKeys...)
// Negotiate SSH with the client (over the channel).


+ 4
- 4
cmd/stronghold/main.go View File

@ -15,7 +15,7 @@ import (
var (
defaultHostKeyFile = "testdata/ssh_host_rsa_key"
defaultCAKeyFile = "testdata/ssh_host_rsa_key"
defaultRootPath = "testdata/session"
defaultRecordDir = "testdata/session"
defaultKeyCacheDir = ""
)
@ -25,7 +25,7 @@ func main() {
hostKey = flag.String("key", defaultHostKeyFile, "server host key")
caKey = flag.String("ca-key", defaultCAKeyFile, "server certificate authority key")
keyCacheDir = flag.String("key-cache", defaultKeyCacheDir, "issuer key cache directory")
root = flag.String("root", defaultRootPath, "server recording root")
recordDir = flag.String("record", defaultRecordDir, "server recording recordDir")
debug = flag.Bool("debug", false, "enable debug messages")
)
flag.Parse()
@ -42,7 +42,7 @@ func main() {
issuer = authority.EphemeralKeyIssuer()
} else {
if issuer, err = authority.CachedKeyIssuer(*keyCacheDir); err != nil {
logrus.WithField("root", *keyCacheDir).Fatalln(err)
logrus.WithField("recordDir", *keyCacheDir).Fatalln(err)
}
}
@ -52,7 +52,7 @@ func main() {
}
var server *stronghold.Server
if server, err = stronghold.New(*root, *hostKey); err != nil {
if server, err = stronghold.New(*recordDir, *hostKey); err != nil {
logrus.Fatalln(err)
}


+ 16
- 16
proxy.go View File

@ -98,7 +98,7 @@ type sshInterceptingServer struct {
config *ssh.ServerConfig
netConn net.Conn
root string
recorder io.WriteCloser
recorder io.Writer
}
func newSSHInterceptingServer(root string) *sshInterceptingServer {
@ -156,10 +156,10 @@ func (server *sshInterceptingServer) ProxyClient(client *sshClient) {
})
transcript := newSafeFileWriter(filepath.Join(server.root, hex.EncodeToString(client.SessionID()))+".ttyrec", 0600)
server.recorder = writeCloser{
Writer: ttyrec.NewEncoder(transcript),
Closer: transcript,
}
server.recorder = ttyrec.NewEncoder(transcript)
defer func() {
_ = transcript.Close()
}()
for {
select {
@ -209,8 +209,12 @@ func (server *sshInterceptingServer) ProxyClient(client *sshClient) {
func (server *sshInterceptingServer) multiplexNewChannel(log *logrus.Entry, conn ssh.Conn, newChannel ssh.NewChannel) {
log.Debug("new channel")
clientChannel, clientRequests, err := conn.OpenChannel(newChannel.ChannelType(), newChannel.ExtraData())
if err != nil {
var (
clientChannel, serverChannel ssh.Channel
clientRequests, serverRequests <-chan *ssh.Request
err error
)
if clientChannel, clientRequests, err = conn.OpenChannel(newChannel.ChannelType(), newChannel.ExtraData()); err != nil {
log.WithError(err).Warn("failed to open channel")
if err, ok := err.(*ssh.OpenChannelError); ok {
_ = newChannel.Reject(err.Reason, err.Message)
@ -220,8 +224,7 @@ func (server *sshInterceptingServer) multiplexNewChannel(log *logrus.Entry, conn
return
}
serverChannel, serverRequests, err := newChannel.Accept()
if err != nil {
if serverChannel, serverRequests, err = newChannel.Accept(); err != nil {
log.WithError(err).Warn("failed to accept channel")
go ssh.DiscardRequests(clientRequests)
_ = clientChannel.Close()
@ -232,27 +235,24 @@ func (server *sshInterceptingServer) multiplexNewChannel(log *logrus.Entry, conn
}
func (server *sshInterceptingServer) multiplex(log *logrus.Entry, channelType string, clientChannel, serverChannel ssh.Channel, clientRequests, serverRequests <-chan *ssh.Request) {
var serverWriter io.Writer
var serverWriters io.Writer
switch channelType {
case "session":
serverWriter = io.MultiWriter(serverChannel, server.recorder)
serverWriters = io.MultiWriter(serverChannel, server.recorder)
default:
serverWriter = serverChannel
serverWriters = serverChannel
}
go func() {
_, _ = io.Copy(clientChannel, serverChannel)
}()
go func() {
_, _ = io.Copy(serverWriter, clientChannel)
_, _ = io.Copy(serverWriters, clientChannel)
}()
defer func() {
_ = clientChannel.Close()
_ = serverChannel.Close()
if server.recorder != nil {
_ = server.recorder.Close()
}
}()
for {


+ 14
- 19
server.go View File

@ -43,31 +43,25 @@ type Server struct {
// SessionHandler is a callback for SSH shell sessions.
SessionHandler func(conn ssh.ConnMetadata, channel ssh.Channel, pty *PTY, windowChanges <-chan *WindowChange) error
root string
config *ssh.ServerConfig
closed chan struct{}
recordDir string
config *ssh.ServerConfig
closed chan struct{}
}
func New(root string, hostKeyFiles ...string) (*Server, error) {
// New server
func New(recordDir string, hostKeyFiles ...string) (*Server, error) {
if len(hostKeyFiles) == 0 {
return nil, errors.New("server: at least one key file must be specified")
}
if !filepath.IsAbs(root) {
if !filepath.IsAbs(recordDir) {
var err error
if root, err = filepath.Abs(root); err != nil {
if recordDir, err = filepath.Abs(recordDir); err != nil {
return nil, err
}
}
if info, err := os.Stat(root); err != nil {
if !os.IsNotExist(err) {
return nil, err
}
if err = os.MkdirAll(root, 0700); err != nil {
return nil, err
}
} else if !info.IsDir() {
return nil, fmt.Errorf("server: root %s exists but is not a directory", root)
if err := os.MkdirAll(recordDir, 0700); err != nil && !os.IsExist(err) {
return nil, err
}
server := &Server{
@ -77,7 +71,6 @@ func New(root string, hostKeyFiles ...string) (*Server, error) {
ServerVersion: version,
},
closed: make(chan struct{}),
root: root,
}
for _, keyFile := range hostKeyFiles {
@ -91,6 +84,7 @@ func New(root string, hostKeyFiles ...string) (*Server, error) {
return server, nil
}
// ListenAndServe starts the SSH server on the given address.
func (server *Server) ListenAndServe(addr string) error {
l, err := net.Listen("tcp", addr)
if err != nil {
@ -111,8 +105,8 @@ func (server *Server) ListenAndServe(addr string) error {
}
for {
conn, err := l.Accept()
if err != nil {
var conn net.Conn
if conn, err = l.Accept(); err != nil {
logrus.WithError(err).Error("error accepting client")
continue
}
@ -136,7 +130,7 @@ func (server *Server) handleConn(netConn net.Conn) {
return
}
client := newClient(netConn, sshConn, server.root)
client := newClient(netConn, sshConn, server.recordDir)
client.CertificateAuthority = server.CertificateAuthority
client.SessionHandler = server.SessionHandler
go client.handleRequests(requests)
@ -185,6 +179,7 @@ func PermitAllPublicKeys(metadata ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Per
return new(ssh.Permissions), nil
}
// PermitAllHostKeys logs and accepts all server keys.
func PermitAllHostKeys(hostname string, remote net.Addr, key ssh.PublicKey) error {
log := logrus.WithFields(logrus.Fields{
"dst": remote.String(),


Loading…
Cancel
Save