Files
styx/proxy/mitm/authority.go
2025-09-26 08:49:53 +02:00

232 lines
5.6 KiB
Go

package mitm
import (
"crypto"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"errors"
"fmt"
"math/big"
"os"
"strings"
"time"
"git.maze.io/maze/styx/internal/cryptutil"
"git.maze.io/maze/styx/internal/log"
"github.com/miekg/dns"
)
// DefaultValidity is the lifetime margin applied to issued leaf certificates:
// each leaf is backdated by this amount and expires this long after issuance,
// giving a total validity window of twice this duration.
const DefaultValidity = time.Hour * 24
// Authority is a certificate authority that mints TLS server certificates on
// demand and exposes the CA certificate that signs them.
type Authority interface {
	// Certificate returns the CA certificate that signs every issued leaf.
	Certificate() *x509.Certificate

	// TLSConfig returns a server TLS configuration that presents a
	// certificate for name (or for the client's SNI server name, when sent).
	TLSConfig(name string) *tls.Config
}
// authority is the Authority implementation returned by New.
type authority struct {
	// pool contains only the CA certificate; it is the root set used when
	// re-verifying cached leaf certificates.
	pool *x509.CertPool

	// cert is the CA certificate, used as the parent of every issued leaf.
	cert *x509.Certificate

	// key is the CA private key that signs issued leaves.
	key crypto.PrivateKey

	// keyID is derived from the CA public key via cryptutil.GenerateKeyID.
	// NOTE(review): not read in the code visible here — confirm it is needed.
	keyID []byte

	// keyPool is a buffered channel of pre-generated leaf private keys,
	// refilled by a background goroutine started in New.
	keyPool chan crypto.PrivateKey

	// cache stores issued leaf certificates keyed by host name.
	cache Cache
}
// New builds an Authority from config, which must be non-nil.
//
// The CA key pair is loaded from the Cert and Key paths of config.CA (an empty
// CAConfig is used when config.CA is nil). If loading fails with a not-exist
// error, a fresh CA is generated with config.CA.Days of validity (DefaultDays
// when zero) and, when the Cert setting contains a path separator, saved.
//
// Leaf private keys are pre-generated by a background goroutine into a
// buffered pool sized by the key configuration's Pool (falling back to the
// default key configuration). One key is generated synchronously first so an
// invalid key configuration is reported as an error here.
//
// NOTE(review): the generator goroutine has no stop mechanism; each call to
// New starts one that runs for the remaining lifetime of the process.
func New(config *Config) (Authority, error) {
	cache, err := NewCache(config.Cache)
	if err != nil {
		return nil, err
	}

	caConfig := config.CA
	if caConfig == nil {
		caConfig = new(CAConfig)
	}

	cert, key, err := cryptutil.LoadKeyPair(caConfig.Cert, caConfig.Key)
	// errors.Is also recognises not-exist errors that were wrapped with %w,
	// which os.IsNotExist does not unwrap.
	if errors.Is(err, os.ErrNotExist) {
		days := caConfig.Days
		if days == 0 {
			days = DefaultDays
		}
		if cert, key, err = cryptutil.GenerateKeyPair(caConfig.DN(), days, caConfig.KeyType, caConfig.Bits); err != nil {
			return nil, err
		}
		// Persist only when Cert contains a path separator.
		// NOTE(review): presumably this distinguishes real file paths from
		// other values — confirm.
		if strings.ContainsRune(caConfig.Cert, os.PathSeparator) {
			if err = cryptutil.SaveKeyPair(cert, key, caConfig.Cert, caConfig.Key); err != nil {
				return nil, err
			}
		}
	} else if err != nil {
		return nil, err
	}

	pool := x509.NewCertPool()
	pool.AddCert(cert)

	keyConfig := config.Key
	if keyConfig == nil {
		keyConfig = &defaultKeyConfig
	}
	keyPoolSize := defaultKeyConfig.Pool
	if keyConfig.Pool > 0 {
		keyPoolSize = keyConfig.Pool
	}
	// The first key is sent below before the generator goroutine exists; on an
	// unbuffered channel that send would block forever.
	if keyPoolSize < 1 {
		keyPoolSize = 1
	}
	keyPool := make(chan crypto.PrivateKey, keyPoolSize)

	first, err := cryptutil.GeneratePrivateKey(keyConfig.Type, keyConfig.Bits)
	if err != nil {
		return nil, fmt.Errorf("mitm: invalid key configuration: %w", err)
	}
	keyPool <- first

	// Keep the pool topped up; the send blocks whenever the pool is full.
	go func(out chan<- crypto.PrivateKey) {
		for {
			k, err := cryptutil.GeneratePrivateKey(keyConfig.Type, keyConfig.Bits)
			if err != nil {
				// NOTE(review): log.Panic presumably panics after logging,
				// which takes down the whole process from this goroutine.
				log.Panic().Err(err).Msg("error generating private key")
			}
			out <- k
		}
	}(keyPool)

	return &authority{
		pool:    pool,
		cert:    cert,
		key:     key,
		keyID:   cryptutil.GenerateKeyID(cryptutil.PublicKey(key)),
		keyPool: keyPool,
		cache:   cache,
	}, nil
}
// log returns the console logger annotated with this CA's subject under the
// "ca" key.
func (ca *authority) log() log.Logger {
	subject := ca.cert.Subject.String()
	return log.Console.With().Str("ca", subject).Logger()
}
// Certificate returns the CA certificate that is used as the issuer (parent)
// of every leaf this authority mints.
func (ca *authority) Certificate() *x509.Certificate {
	return ca.cert
}
// TLSConfig returns a server-side TLS configuration whose GetCertificate
// callback presents a certificate issued by this authority.
//
// The certificate name is chosen per handshake: the client's SNI server name
// when present, otherwise name; both are lower-cased. A cached certificate
// that still verifies for that name is reused, otherwise a new one is issued.
// Only "http/1.1" is offered via ALPN.
func (ca *authority) TLSConfig(name string) *tls.Config {
	fallback := strings.ToLower(name)
	return &tls.Config{
		GetCertificate: func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
			log := ca.log()
			// Resolve the target per handshake rather than overwriting the
			// captured parameter: the callback can run for several (possibly
			// concurrent) handshakes, and one client's SNI must not leak into
			// the next handshake or race on a shared variable.
			target := fallback
			if hello.ServerName != "" {
				target = strings.ToLower(hello.ServerName)
				log.Debug().Str("name", target).Msg("requesting certificate for server name (SNI)")
			} else {
				log.Debug().Str("name", target).Msg("requesting certificate for hostname")
			}
			if target == "" {
				return nil, errors.New("mitm: no server name available to issue a certificate for")
			}
			if cert, ok := ca.getCached(target); ok {
				log.Debug().
					Str("subject", cert.Leaf.Subject.String()).
					Str("serial", cert.Leaf.SerialNumber.String()).
					Time("valid", cert.Leaf.NotAfter).
					Msg("using cached certificate")
				return cert, nil
			}
			return ca.issueFor(target)
		},
		NextProtos: []string{"http/1.1"},
	}
}
// getCached looks up a previously issued certificate for name, falling back
// to the entry stored under baseDomain(name) when name itself has none.
//
// A cached certificate is returned with ok == true only if its leaf verifies
// for name against the CA pool (x509 verification also checks the validity
// period). On a miss or a failed verification it returns (nil, false).
func (ca *authority) getCached(name string) (*tls.Certificate, bool) {
	cert := ca.cache.Certificate(name)
	if cert == nil {
		if base := baseDomain(name); base != name {
			cert = ca.cache.Certificate(base)
		}
	}
	if cert == nil {
		return nil, false
	}

	// Guard against cache entries without a parsed Leaf, which would
	// otherwise cause a nil dereference below. Parse into a copy so the
	// shared cache entry is never mutated from concurrent handshakes.
	if cert.Leaf == nil {
		if len(cert.Certificate) == 0 {
			return nil, false
		}
		leaf, err := x509.ParseCertificate(cert.Certificate[0])
		if err != nil {
			ca.log().Debug().Err(err).Str("name", name).Msg("ignoring unparsable certificate from cache")
			return nil, false
		}
		withLeaf := *cert
		withLeaf.Leaf = leaf
		cert = &withLeaf
	}

	if _, err := cert.Leaf.Verify(x509.VerifyOptions{
		DNSName: name,
		Roots:   ca.pool,
	}); err != nil {
		// NOTE(review): despite the message, the entry is not removed from
		// the cache here; it is simply treated as a miss.
		ca.log().Debug().Err(err).Str("name", name).Msg("deleting invalid certificate from cache")
		return nil, false
	}
	return cert, true
}
// issueFor mints a new leaf certificate for name, signed by the CA, and
// stores it in the cache.
//
// The leaf private key is taken from the pre-generated key pool; if none is
// available within 5 seconds an error is returned. When name has more than
// two DNS labels, the leftmost label is dropped and the certificate is issued
// for the parent domain plus its wildcard (for example "a.b.example.com"
// yields the SANs "b.example.com" and "*.b.example.com"). The certificate is
// cached under that possibly abbreviated name. It is valid from
// DefaultValidity before issuance until DefaultValidity after it.
//
// NOTE(review): IP literals are treated as DNS names here ("10.0.0.1" is
// abbreviated to "0.0.1"), so clients connecting by IP will not receive a
// matching certificate; supporting them needs an IP SAN
// (x509.Certificate.IPAddresses).
func (ca *authority) issueFor(name string) (*tls.Certificate, error) {
	log := ca.log().With().Str("name", name).Logger()

	var key crypto.PrivateKey
	select {
	case key = <-ca.keyPool:
	case <-time.After(5 * time.Second):
		return nil, errors.New("mitm: timeout waiting for private key generator to catch up")
	}
	if key == nil {
		// Invariant violation: the generator only sends keys returned without
		// error. Report it to the handshake instead of panicking.
		return nil, errors.New("mitm: key pool returned nil key")
	}

	serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
	serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
	if err != nil {
		return nil, fmt.Errorf("mitm: failed to generate serial number: %w", err)
	}

	if part := dns.SplitDomainName(name); len(part) > 2 {
		name = strings.Join(part[1:], ".")
		log.Debug().Msgf("abbreviated name to %s (*.%s)", name, name)
	}

	// The certificate is backdated as well as post-dated by DefaultValidity
	// (presumably to tolerate client clock skew).
	now := time.Now()
	template := &x509.Certificate{
		SerialNumber:          serialNumber,
		Subject:               pkix.Name{CommonName: name},
		KeyUsage:              x509.KeyUsageDigitalSignature | x509.KeyUsageKeyAgreement,
		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
		DNSNames:              []string{name, "*." + name},
		BasicConstraintsValid: true,
		NotBefore:             now.Add(-DefaultValidity),
		NotAfter:              now.Add(+DefaultValidity),
	}

	der, err := x509.CreateCertificate(rand.Reader, template, ca.cert, cryptutil.PublicKey(key), ca.key)
	if err != nil {
		return nil, fmt.Errorf("mitm: failed to create certificate for %q: %w", name, err)
	}
	leaf, err := x509.ParseCertificate(der)
	if err != nil {
		return nil, fmt.Errorf("mitm: failed to parse issued certificate for %q: %w", name, err)
	}
	log.Debug().Str("serial", serialNumber.String()).Msg("generated certificate")

	out := &tls.Certificate{
		Certificate: [][]byte{der},
		Leaf:        leaf,
		PrivateKey:  key,
	}
	ca.cache.SaveCertificate(name, out)
	return out, nil
}
// containsValidCertificate reports whether cert holds at least one DER
// certificate whose leaf is currently inside its validity window
// (NotBefore <= now <= NotAfter).
//
// If cert.Leaf is nil it is parsed from cert.Certificate[0] and stored back
// into cert.Leaf as a side effect; an unparsable certificate is reported as
// invalid. Only the validity period is checked, not the chain or host name.
func containsValidCertificate(cert *tls.Certificate) bool {
	if cert == nil || len(cert.Certificate) == 0 {
		return false
	}
	if cert.Leaf == nil {
		leaf, err := x509.ParseCertificate(cert.Certificate[0])
		if err != nil {
			return false
		}
		cert.Leaf = leaf
	}
	now := time.Now()
	// Valid means now is neither before NotBefore nor after NotAfter, which
	// mirrors crypto/x509's own expiry check.
	return !now.Before(cert.Leaf.NotBefore) && !now.After(cert.Leaf.NotAfter)
}