package mitm import ( "crypto" "crypto/rand" "crypto/tls" "crypto/x509" "crypto/x509/pkix" "errors" "fmt" "math/big" "os" "strings" "time" "git.maze.io/maze/styx/internal/cryptutil" "git.maze.io/maze/styx/internal/log" "github.com/miekg/dns" ) const DefaultValidity = 24 * time.Hour type Authority interface { Certificate() *x509.Certificate TLSConfig(name string) *tls.Config } type authority struct { pool *x509.CertPool cert *x509.Certificate key crypto.PrivateKey keyID []byte keyPool chan crypto.PrivateKey cache Cache } func New(config *Config) (Authority, error) { cache, err := NewCache(config.Cache) if err != nil { return nil, err } caConfig := config.CA if caConfig == nil { caConfig = new(CAConfig) } cert, key, err := cryptutil.LoadKeyPair(caConfig.Cert, caConfig.Key) if os.IsNotExist(err) { days := caConfig.Days if days == 0 { days = DefaultDays } if cert, key, err = cryptutil.GenerateKeyPair(caConfig.DN(), days, caConfig.KeyType, caConfig.Bits); err != nil { return nil, err } if strings.ContainsRune(caConfig.Cert, os.PathSeparator) { if err = cryptutil.SaveKeyPair(cert, key, caConfig.Cert, caConfig.Key); err != nil { return nil, err } } } else if err != nil { return nil, err } pool := x509.NewCertPool() pool.AddCert(cert) keyConfig := config.Key if keyConfig == nil { keyConfig = &defaultKeyConfig } keyPoolSize := defaultKeyConfig.Pool if keyConfig.Pool > 0 { keyPoolSize = keyConfig.Pool } keyPool := make(chan crypto.PrivateKey, keyPoolSize) if key, err := cryptutil.GeneratePrivateKey(keyConfig.Type, keyConfig.Bits); err != nil { return nil, fmt.Errorf("mitm: invalid key configuration: %w", err) } else { keyPool <- key } go func(pool chan<- crypto.PrivateKey) { for { key, err := cryptutil.GeneratePrivateKey(keyConfig.Type, keyConfig.Bits) if err != nil { log.Panic().Err(err).Msg("error generating private key") } pool <- key } }(keyPool) return &authority{ pool: pool, cert: cert, key: key, keyID: cryptutil.GenerateKeyID(cryptutil.PublicKey(key)), keyPool: keyPool, cache: cache, }, nil } func (ca *authority) log() log.Logger { return log.Console.With(). Str("ca", ca.cert.Subject.String()). Logger() } func (ca *authority) Certificate() *x509.Certificate { return ca.cert } func (ca *authority) TLSConfig(name string) *tls.Config { return &tls.Config{ GetCertificate: func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { log := ca.log() if hello.ServerName != "" { name = strings.ToLower(hello.ServerName) log.Debug().Msg("requesting certificate for server name (SNI)") } else { log.Debug().Msg("requesting certificate for hostname") } if cert, ok := ca.getCached(name); ok { log.Debug(). Str("subject", cert.Leaf.Subject.String()). Str("serial", cert.Leaf.SerialNumber.String()). Time("valid", cert.Leaf.NotAfter). Msg("using cached certificate") return cert, nil } return ca.issueFor(name) }, NextProtos: []string{"http/1.1"}, } } func (ca *authority) getCached(name string) (cert *tls.Certificate, ok bool) { log := ca.log() if cert = ca.cache.Certificate(name); cert == nil { if baseDomain(name) != name { cert = ca.cache.Certificate(baseDomain(name)) } } if cert != nil { if _, err := cert.Leaf.Verify(x509.VerifyOptions{ DNSName: name, Roots: ca.pool, }); err != nil { log.Debug().Err(err).Str("name", name).Msg("deleting invalid certificate from cache") } else { ok = true } } return } func (ca *authority) issueFor(name string) (*tls.Certificate, error) { var ( log = ca.log().With().Str("name", name).Logger() key crypto.PrivateKey ) select { case key = <-ca.keyPool: case <-time.After(5 * time.Second): return nil, errors.New("mitm: timeout waiting for private key generator to catch up") } if key == nil { panic("key pool returned nil key") } serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) if err != nil { return nil, fmt.Errorf("mtim: failed to generate serial number: %w", err) } if part := dns.SplitDomainName(name); len(part) > 2 { name = strings.Join(part[1:], ".") log.Debug().Msgf("abbreviated name to %s (*.%s)", name, name) } now := time.Now() template := &x509.Certificate{ SerialNumber: serialNumber, Subject: pkix.Name{CommonName: name}, KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyAgreement, ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, DNSNames: []string{name, "*." + name}, BasicConstraintsValid: true, NotBefore: now.Add(-DefaultValidity), NotAfter: now.Add(+DefaultValidity), } der, err := x509.CreateCertificate(rand.Reader, template, ca.cert, cryptutil.PublicKey(key), ca.key) if err != nil { return nil, err } cert, err := x509.ParseCertificate(der) if err != nil { return nil, err } log.Debug().Str("serial", serialNumber.String()).Msg("generated certificate") out := &tls.Certificate{ Certificate: [][]byte{der}, Leaf: cert, PrivateKey: key, } //ca.cache[name] = out ca.cache.SaveCertificate(name, out) return out, nil } func containsValidCertificate(cert *tls.Certificate) bool { if cert == nil || len(cert.Certificate) == 0 { return false } if cert.Leaf == nil { var err error if cert.Leaf, err = x509.ParseCertificate(cert.Certificate[0]); err != nil { return false } } now := time.Now() return !(cert.Leaf.NotBefore.Before(now) || cert.Leaf.NotAfter.After(now)) }