232 lines
5.6 KiB
Go
232 lines
5.6 KiB
Go
package mitm
|
|
|
|
import (
	"crypto"
	"crypto/rand"
	"crypto/tls"
	"crypto/x509"
	"crypto/x509/pkix"
	"errors"
	"fmt"
	"math/big"
	"net"
	"os"
	"strings"
	"time"

	"git.maze.io/maze/styx/internal/cryptutil"
	"git.maze.io/maze/styx/internal/log"
	"github.com/miekg/dns"
)
|
|
|
|
// DefaultValidity is the margin applied on both sides of the issuing time when
// leaf certificates are minted: NotBefore is now minus DefaultValidity and
// NotAfter is now plus DefaultValidity.
const DefaultValidity time.Duration = time.Hour * 24
|
|
|
|
// Authority is a certificate authority that hands out TLS server certificates
// on demand, for use when intercepting TLS connections.
type Authority interface {
	// Certificate returns the CA certificate that signs the issued leaves.
	Certificate() *x509.Certificate

	// TLSConfig returns a server-side TLS configuration backed by this
	// authority. In this package's implementation, name is the host used when
	// the client sends no SNI server name.
	TLSConfig(name string) *tls.Config
}
|
|
|
|
type authority struct {
|
|
pool *x509.CertPool
|
|
cert *x509.Certificate
|
|
key crypto.PrivateKey
|
|
keyID []byte
|
|
keyPool chan crypto.PrivateKey
|
|
cache Cache
|
|
}
|
|
|
|
func New(config *Config) (Authority, error) {
|
|
cache, err := NewCache(config.Cache)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
caConfig := config.CA
|
|
if caConfig == nil {
|
|
caConfig = new(CAConfig)
|
|
}
|
|
|
|
cert, key, err := cryptutil.LoadKeyPair(caConfig.Cert, caConfig.Key)
|
|
if os.IsNotExist(err) {
|
|
days := caConfig.Days
|
|
if days == 0 {
|
|
days = DefaultDays
|
|
}
|
|
if cert, key, err = cryptutil.GenerateKeyPair(caConfig.DN(), days, caConfig.KeyType, caConfig.Bits); err != nil {
|
|
return nil, err
|
|
}
|
|
if strings.ContainsRune(caConfig.Cert, os.PathSeparator) {
|
|
if err = cryptutil.SaveKeyPair(cert, key, caConfig.Cert, caConfig.Key); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
} else if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
pool := x509.NewCertPool()
|
|
pool.AddCert(cert)
|
|
|
|
keyConfig := config.Key
|
|
if keyConfig == nil {
|
|
keyConfig = &defaultKeyConfig
|
|
}
|
|
|
|
keyPoolSize := defaultKeyConfig.Pool
|
|
if keyConfig.Pool > 0 {
|
|
keyPoolSize = keyConfig.Pool
|
|
}
|
|
keyPool := make(chan crypto.PrivateKey, keyPoolSize)
|
|
if key, err := cryptutil.GeneratePrivateKey(keyConfig.Type, keyConfig.Bits); err != nil {
|
|
return nil, fmt.Errorf("mitm: invalid key configuration: %w", err)
|
|
} else {
|
|
keyPool <- key
|
|
}
|
|
|
|
go func(pool chan<- crypto.PrivateKey) {
|
|
for {
|
|
key, err := cryptutil.GeneratePrivateKey(keyConfig.Type, keyConfig.Bits)
|
|
if err != nil {
|
|
log.Panic().Err(err).Msg("error generating private key")
|
|
}
|
|
pool <- key
|
|
}
|
|
}(keyPool)
|
|
|
|
return &authority{
|
|
pool: pool,
|
|
cert: cert,
|
|
key: key,
|
|
keyID: cryptutil.GenerateKeyID(cryptutil.PublicKey(key)),
|
|
keyPool: keyPool,
|
|
cache: cache,
|
|
}, nil
|
|
}
|
|
|
|
func (ca *authority) log() log.Logger {
|
|
return log.Console.With().
|
|
Str("ca", ca.cert.Subject.String()).
|
|
Logger()
|
|
}
|
|
|
|
func (ca *authority) Certificate() *x509.Certificate {
|
|
return ca.cert
|
|
}
|
|
|
|
func (ca *authority) TLSConfig(name string) *tls.Config {
|
|
return &tls.Config{
|
|
GetCertificate: func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
|
log := ca.log()
|
|
if hello.ServerName != "" {
|
|
name = strings.ToLower(hello.ServerName)
|
|
log.Debug().Msg("requesting certificate for server name (SNI)")
|
|
} else {
|
|
log.Debug().Msg("requesting certificate for hostname")
|
|
}
|
|
if cert, ok := ca.getCached(name); ok {
|
|
log.Debug().
|
|
Str("subject", cert.Leaf.Subject.String()).
|
|
Str("serial", cert.Leaf.SerialNumber.String()).
|
|
Time("valid", cert.Leaf.NotAfter).
|
|
Msg("using cached certificate")
|
|
return cert, nil
|
|
}
|
|
return ca.issueFor(name)
|
|
},
|
|
NextProtos: []string{"http/1.1"},
|
|
}
|
|
}
|
|
|
|
func (ca *authority) getCached(name string) (cert *tls.Certificate, ok bool) {
|
|
log := ca.log()
|
|
|
|
if cert = ca.cache.Certificate(name); cert == nil {
|
|
if baseDomain(name) != name {
|
|
cert = ca.cache.Certificate(baseDomain(name))
|
|
}
|
|
}
|
|
if cert != nil {
|
|
if _, err := cert.Leaf.Verify(x509.VerifyOptions{
|
|
DNSName: name,
|
|
Roots: ca.pool,
|
|
}); err != nil {
|
|
log.Debug().Err(err).Str("name", name).Msg("deleting invalid certificate from cache")
|
|
} else {
|
|
ok = true
|
|
}
|
|
}
|
|
return
|
|
}
|
|
|
|
func (ca *authority) issueFor(name string) (*tls.Certificate, error) {
|
|
var (
|
|
log = ca.log().With().Str("name", name).Logger()
|
|
key crypto.PrivateKey
|
|
)
|
|
select {
|
|
case key = <-ca.keyPool:
|
|
case <-time.After(5 * time.Second):
|
|
return nil, errors.New("mitm: timeout waiting for private key generator to catch up")
|
|
}
|
|
if key == nil {
|
|
panic("key pool returned nil key")
|
|
}
|
|
|
|
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
|
|
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("mtim: failed to generate serial number: %w", err)
|
|
}
|
|
|
|
if part := dns.SplitDomainName(name); len(part) > 2 {
|
|
name = strings.Join(part[1:], ".")
|
|
log.Debug().Msgf("abbreviated name to %s (*.%s)", name, name)
|
|
}
|
|
|
|
now := time.Now()
|
|
template := &x509.Certificate{
|
|
SerialNumber: serialNumber,
|
|
Subject: pkix.Name{CommonName: name},
|
|
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyAgreement,
|
|
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
|
DNSNames: []string{name, "*." + name},
|
|
BasicConstraintsValid: true,
|
|
NotBefore: now.Add(-DefaultValidity),
|
|
NotAfter: now.Add(+DefaultValidity),
|
|
}
|
|
der, err := x509.CreateCertificate(rand.Reader, template, ca.cert, cryptutil.PublicKey(key), ca.key)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
cert, err := x509.ParseCertificate(der)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
log.Debug().Str("serial", serialNumber.String()).Msg("generated certificate")
|
|
out := &tls.Certificate{
|
|
Certificate: [][]byte{der},
|
|
Leaf: cert,
|
|
PrivateKey: key,
|
|
}
|
|
//ca.cache[name] = out
|
|
ca.cache.SaveCertificate(name, out)
|
|
return out, nil
|
|
}
|
|
|
|
// containsValidCertificate reports whether cert holds at least one DER
// certificate and its leaf is currently inside the validity window
// NotBefore <= now <= NotAfter (both bounds inclusive).
//
// When cert.Leaf is nil, the first DER certificate is parsed and stored in
// cert.Leaf as a side effect. An unparsable certificate is reported as invalid.
func containsValidCertificate(cert *tls.Certificate) bool {
	if cert == nil || len(cert.Certificate) == 0 {
		return false
	}

	if cert.Leaf == nil {
		leaf, err := x509.ParseCertificate(cert.Certificate[0])
		if err != nil {
			return false
		}
		cert.Leaf = leaf
	}

	now := time.Now()
	return !now.Before(cert.Leaf.NotBefore) && !now.After(cert.Leaf.NotAfter)
}
|