@ -0,0 +1,195 @@ | |||
package authority | |||
import ( | |||
"crypto/rand" | |||
"io/ioutil" | |||
"net" | |||
"os" | |||
"path/filepath" | |||
"strings" | |||
"time" | |||
"golang.org/x/crypto/ssh" | |||
) | |||
const ( | |||
defaultCertificateLifetime = 5 * time.Minute | |||
defaultLifetimeSplay = 5 * time.Second | |||
defaultKeyPoolSize = 4 | |||
) | |||
// Issuer for SSH keys. | |||
type Issuer interface { | |||
HostIssuer | |||
UserIssuer | |||
} | |||
// HostIssuer can issue host SSH keys. | |||
type HostIssuer interface { | |||
// IssueHostKey returns the host key for the given host. | |||
IssueHostKey(name string) (ssh.Signer, error) | |||
} | |||
// UserIssuer can issue user SSH keys. | |||
type UserIssuer interface { | |||
// IssueUserKey returns the user key for the given user. | |||
IssueUserKey(name string) (ssh.Signer, error) | |||
} | |||
type ephemeralKeyIssuer struct { | |||
pool *keyPool | |||
} | |||
// EphemeralKeyIssuer can issue ephemeral 2048-bit RSA keys. | |||
func EphemeralKeyIssuer() Issuer { | |||
return ephemeralKeyIssuer{pool: newHostKeyPool(defaultKeyPoolSize)} | |||
} | |||
func (issuer ephemeralKeyIssuer) IssueHostKey(_ string) (ssh.Signer, error) { | |||
return issuer.pool.Signer(), nil | |||
} | |||
func (issuer ephemeralKeyIssuer) IssueUserKey(_ string) (ssh.Signer, error) { | |||
return issuer.pool.Signer(), nil | |||
} | |||
type cachedKeyIssuer struct { | |||
root string | |||
pool *keyPool | |||
} | |||
// CachedKeyIssuer can issue keys from cache, or create a new 2048-bit RSA key. | |||
func CachedKeyIssuer(root string) (Issuer, error) { | |||
for _, dir := range []string{ | |||
root, | |||
filepath.Join(root, "host"), | |||
filepath.Join(root, "user"), | |||
} { | |||
if err := os.MkdirAll(dir, 0700); err != nil && !os.IsExist(err) { | |||
return nil, err | |||
} | |||
} | |||
return &cachedKeyIssuer{ | |||
root: root, | |||
pool: newHostKeyPool(defaultKeyPoolSize), | |||
}, nil | |||
} | |||
func (issuer cachedKeyIssuer) IssueHostKey(name string) (ssh.Signer, error) { | |||
return issuer.getOrCreateKey(filepath.Join(issuer.root, "host", strings.ToLower(name)+".key")) | |||
} | |||
func (issuer cachedKeyIssuer) IssueUserKey(name string) (ssh.Signer, error) { | |||
return issuer.getOrCreateKey(filepath.Join(issuer.root, "user", strings.ToLower(name)+".key")) | |||
} | |||
func (issuer cachedKeyIssuer) getOrCreateKey(name string) (ssh.Signer, error) { | |||
if signer, err := loadSigner(name); err != nil && !os.IsNotExist(err) { | |||
return nil, err | |||
} else if err == nil { | |||
return signer, nil | |||
} | |||
privateKey := issuer.pool.Key() | |||
if err := savePrivateKey(name, privateKey); err != nil { | |||
return nil, err | |||
} | |||
return ssh.NewSignerFromKey(privateKey) | |||
} | |||
// CertificateAuthority authority can issue host and user certificates. | |||
type CertificateAuthority struct { | |||
signer ssh.Signer | |||
issuer Issuer | |||
lifetime time.Duration | |||
} | |||
func NewCertificateAuthority(caKeyFile string, issuer Issuer) (*CertificateAuthority, error) { | |||
if issuer == nil { | |||
issuer = EphemeralKeyIssuer() | |||
} | |||
pemBytes, err := ioutil.ReadFile(caKeyFile) | |||
if err != nil { | |||
return nil, err | |||
} | |||
var signer ssh.Signer | |||
if signer, err = ssh.ParsePrivateKey(pemBytes); err != nil { | |||
return nil, err | |||
} | |||
return &CertificateAuthority{ | |||
signer: signer, | |||
lifetime: defaultCertificateLifetime, | |||
issuer: issuer, | |||
}, nil | |||
} | |||
func (ca *CertificateAuthority) SignCertificate(cert *ssh.Certificate) error { | |||
return cert.SignCert(rand.Reader, ca.signer) | |||
} | |||
func (ca *CertificateAuthority) SignUserCertificate(cert *ssh.Certificate) error { | |||
cert.CertType = ssh.UserCert | |||
return ca.SignCertificate(cert) | |||
} | |||
func (ca *CertificateAuthority) SignHostCertificate(cert *ssh.Certificate) error { | |||
cert.CertType = ssh.HostCert | |||
return ca.SignCertificate(cert) | |||
} | |||
func (ca *CertificateAuthority) HostCertificate(name string) ([]ssh.Signer, error) { | |||
ips, err := net.LookupIP(name) | |||
if err != nil { | |||
return nil, err | |||
} | |||
principals := []string{name} | |||
for _, ip := range ips { | |||
principals = append(principals, ip.String()) | |||
} | |||
var signer ssh.Signer | |||
if signer, err = ca.issuer.IssueHostKey(name); err != nil { | |||
return nil, err | |||
} | |||
var ( | |||
now = time.Now() | |||
validAfter = now.Add(-defaultLifetimeSplay) | |||
validBefore = now.Add(defaultLifetimeSplay + ca.lifetime) | |||
cert = &ssh.Certificate{ | |||
Nonce: nil, | |||
Key: signer.PublicKey(), | |||
Serial: uint64(now.UnixNano()), | |||
CertType: ssh.HostCert, | |||
KeyId: name, | |||
ValidPrincipals: uniqueStrings(principals), | |||
ValidAfter: uint64(validAfter.Unix()), | |||
ValidBefore: uint64(validBefore.Unix()), | |||
} | |||
certSigner ssh.Signer | |||
) | |||
if err = cert.SignCert(rand.Reader, ca.signer); err != nil { | |||
return nil, err | |||
} | |||
if certSigner, err = ssh.NewCertSigner(cert, signer); err != nil { | |||
return nil, err | |||
} | |||
return []ssh.Signer{certSigner, signer}, nil | |||
} | |||
func uniqueStrings(values []string) []string { | |||
var ( | |||
keys = make(map[string]bool) | |||
uniq []string | |||
) | |||
for _, key := range values { | |||
if !keys[key] { | |||
keys[key] = true | |||
uniq = append(uniq, key) | |||
} | |||
} | |||
return uniq | |||
} |
@ -0,0 +1,124 @@ | |||
package authority | |||
import ( | |||
"crypto" | |||
"crypto/dsa" | |||
"crypto/ecdsa" | |||
"crypto/rand" | |||
"crypto/rsa" | |||
"crypto/x509" | |||
"encoding/asn1" | |||
"encoding/pem" | |||
"fmt" | |||
"io/ioutil" | |||
"math/big" | |||
"github.com/sirupsen/logrus" | |||
"golang.org/x/crypto/ssh" | |||
) | |||
type keyPool struct { | |||
keys chan *rsa.PrivateKey | |||
} | |||
func newHostKeyPool(size int) *keyPool { | |||
pool := &keyPool{keys: make(chan *rsa.PrivateKey, size)} | |||
go pool.fill() | |||
return pool | |||
} | |||
func (pool *keyPool) fill() { | |||
log := logrus.WithField("tag", "key_pool") | |||
for { | |||
log.Debug("generating RSA key") | |||
k, err := rsa.GenerateKey(rand.Reader, 2048) | |||
if err == nil { | |||
pool.keys <- k | |||
} | |||
} | |||
} | |||
func (pool *keyPool) Key() *rsa.PrivateKey { | |||
return <-pool.keys | |||
} | |||
func (pool *keyPool) Signer() ssh.Signer { | |||
signer, _ := ssh.NewSignerFromKey(<-pool.keys) | |||
return signer | |||
} | |||
func loadPrivateKey(name string) (key crypto.PrivateKey, err error) { | |||
var pemBytes []byte | |||
if pemBytes, err = ioutil.ReadFile(name); err != nil { | |||
return | |||
} | |||
var block *pem.Block | |||
for { | |||
if block, pemBytes = pem.Decode(pemBytes); block == nil { | |||
return nil, fmt.Errorf("authority: no private key found in %s", name) | |||
} | |||
switch block.Type { | |||
case "DSA PRIVATE KEY": | |||
return ssh.ParseDSAPrivateKey(block.Bytes) | |||
case "RSA PRIVATE KEY": | |||
return x509.ParsePKCS1PrivateKey(block.Bytes) | |||
case "EC PRIVATE KEY": | |||
return x509.ParseECPrivateKey(block.Bytes) | |||
case "PRIVATE KEY": | |||
return x509.ParsePKCS8PrivateKey(block.Bytes) | |||
} | |||
} | |||
} | |||
func loadSigner(name string) (signer ssh.Signer, err error) { | |||
var privateKey crypto.PrivateKey | |||
if privateKey, err = loadPrivateKey(name); err != nil { | |||
return | |||
} | |||
return ssh.NewSignerFromKey(privateKey) | |||
} | |||
func savePrivateKey(name string, key crypto.PrivateKey) (err error) { | |||
var pemBytes []byte | |||
switch key := key.(type) { | |||
case *dsa.PrivateKey: | |||
var k struct { | |||
Version int | |||
P *big.Int | |||
Q *big.Int | |||
G *big.Int | |||
Pub *big.Int | |||
Priv *big.Int | |||
} | |||
k.Version = 1 | |||
k.P = key.P | |||
k.Q = key.Q | |||
k.G = key.G | |||
k.Pub = key.PublicKey.Y | |||
k.Priv = key.X | |||
var derBytes []byte | |||
if derBytes, err = asn1.Marshal(&k); err != nil { | |||
return | |||
} | |||
pemBytes = pem.EncodeToMemory(&pem.Block{ | |||
Type: "DSA PRIVATE KEY", | |||
Bytes: derBytes, | |||
}) | |||
case *ecdsa.PrivateKey: | |||
var derBytes []byte | |||
if derBytes, err = x509.MarshalECPrivateKey(key); err != nil { | |||
return | |||
} | |||
pemBytes = pem.EncodeToMemory(&pem.Block{ | |||
Type: "EC PRIVATE KEY", | |||
Bytes: derBytes, | |||
}) | |||
case *rsa.PrivateKey: | |||
pemBytes = pem.EncodeToMemory(&pem.Block{ | |||
Type: "RSA PRIVATE KEY", | |||
Bytes: x509.MarshalPKCS1PrivateKey(key), | |||
}) | |||
} | |||
return ioutil.WriteFile(name, pemBytes, 0600) | |||
} |
@ -1,4 +1,4 @@ | |||
package server | |||
package stronghold | |||
import ( | |||
"io/ioutil" |
@ -1,4 +1,4 @@ | |||
package server | |||
package stronghold | |||
import ( | |||
"io" |
@ -1,4 +1,4 @@ | |||
package server | |||
package stronghold | |||
import ( | |||
"encoding/hex" |
@ -1,123 +0,0 @@ | |||
package hostkey | |||
import ( | |||
"crypto/rand" | |||
"crypto/rsa" | |||
"io" | |||
"io/ioutil" | |||
"net" | |||
"time" | |||
"github.com/sirupsen/logrus" | |||
"golang.org/x/crypto/ssh" | |||
) | |||
const ( | |||
defaultCertificateLifetime = 5 * time.Minute | |||
defaultLifetimeSplay = 5 * time.Second | |||
defaultHostKeyPoolSize = 4 | |||
) | |||
// Signer for SSH host keys. | |||
type Signer interface { | |||
// HostKeys for the given address. | |||
HostKeys(address string) ([]ssh.Signer, error) | |||
} | |||
type hostKeyPool struct { | |||
keys chan ssh.Signer | |||
} | |||
func newHostKeyPool(size int) *hostKeyPool { | |||
pool := &hostKeyPool{keys: make(chan ssh.Signer, size)} | |||
go pool.fill() | |||
return pool | |||
} | |||
func (pool *hostKeyPool) fill() { | |||
log := logrus.WithField("tag", "hostkey_pool") | |||
for { | |||
log.Debug("generating RSA key") | |||
k, err := rsa.GenerateKey(rand.Reader, 2048) | |||
if err == nil { | |||
if s, err := ssh.NewSignerFromKey(k); err == nil { | |||
pool.keys <- s | |||
} | |||
} | |||
} | |||
} | |||
func (pool *hostKeyPool) Key() ssh.Signer { | |||
return <-pool.keys | |||
} | |||
type CertificateSigner struct { | |||
ca ssh.Signer | |||
lifetime time.Duration | |||
pool *hostKeyPool | |||
} | |||
func NewCerticateSigner(caKeyFile string) (*CertificateSigner, error) { | |||
b, err := ioutil.ReadFile(caKeyFile) | |||
if err != nil { | |||
return nil, err | |||
} | |||
ca, err := ssh.ParsePrivateKey(b) | |||
if err != nil { | |||
return nil, err | |||
} | |||
return &CertificateSigner{ | |||
ca: ca, | |||
lifetime: defaultCertificateLifetime, | |||
pool: newHostKeyPool(defaultHostKeyPoolSize), | |||
}, nil | |||
} | |||
func (signer *CertificateSigner) HostKeys(address string) ([]ssh.Signer, error) { | |||
host, _, err := net.SplitHostPort(address) | |||
if err != nil { | |||
return nil, err | |||
} | |||
ips, err := net.LookupIP(host) | |||
if err != nil { | |||
return nil, err | |||
} | |||
principals := []string{host} | |||
for _, ip := range ips { | |||
principals = append(principals, ip.String()) | |||
} | |||
var ( | |||
key = signer.pool.Key() | |||
now = time.Now() | |||
validAfter = now.Add(-defaultLifetimeSplay) | |||
validBefore = now.Add(defaultLifetimeSplay + signer.lifetime) | |||
cert = &ssh.Certificate{ | |||
Nonce: nil, | |||
Key: key.PublicKey(), | |||
Serial: 0, | |||
CertType: ssh.HostCert, | |||
KeyId: host, | |||
ValidPrincipals: principals, | |||
ValidAfter: uint64(validAfter.Unix()), | |||
ValidBefore: uint64(validBefore.Unix()), | |||
} | |||
) | |||
cert.Nonce = make([]byte, 16) | |||
_, _ = io.ReadFull(rand.Reader, cert.Nonce) | |||
if err = cert.SignCert(rand.Reader, signer.ca); err != nil { | |||
return nil, err | |||
} | |||
logrus.Printf("cert: %#+v", cert) | |||
certSigner, err := ssh.NewCertSigner(cert, key) | |||
if err != nil { | |||
return nil, err | |||
} | |||
return []ssh.Signer{certSigner, key}, nil | |||
} |