Initial import
This commit is contained in:
231
proxy/mitm/authority.go
Normal file
231
proxy/mitm/authority.go
Normal file
@@ -0,0 +1,231 @@
|
||||
package mitm
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.maze.io/maze/styx/internal/cryptutil"
|
||||
"git.maze.io/maze/styx/internal/log"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
const DefaultValidity = 24 * time.Hour
|
||||
|
||||
type Authority interface {
|
||||
Certificate() *x509.Certificate
|
||||
TLSConfig(name string) *tls.Config
|
||||
}
|
||||
|
||||
type authority struct {
|
||||
pool *x509.CertPool
|
||||
cert *x509.Certificate
|
||||
key crypto.PrivateKey
|
||||
keyID []byte
|
||||
keyPool chan crypto.PrivateKey
|
||||
cache Cache
|
||||
}
|
||||
|
||||
func New(config *Config) (Authority, error) {
|
||||
cache, err := NewCache(config.Cache)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
caConfig := config.CA
|
||||
if caConfig == nil {
|
||||
caConfig = new(CAConfig)
|
||||
}
|
||||
|
||||
cert, key, err := cryptutil.LoadKeyPair(caConfig.Cert, caConfig.Key)
|
||||
if os.IsNotExist(err) {
|
||||
days := caConfig.Days
|
||||
if days == 0 {
|
||||
days = DefaultDays
|
||||
}
|
||||
if cert, key, err = cryptutil.GenerateKeyPair(caConfig.DN(), days, caConfig.KeyType, caConfig.Bits); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if strings.ContainsRune(caConfig.Cert, os.PathSeparator) {
|
||||
if err = cryptutil.SaveKeyPair(cert, key, caConfig.Cert, caConfig.Key); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pool := x509.NewCertPool()
|
||||
pool.AddCert(cert)
|
||||
|
||||
keyConfig := config.Key
|
||||
if keyConfig == nil {
|
||||
keyConfig = &defaultKeyConfig
|
||||
}
|
||||
|
||||
keyPoolSize := defaultKeyConfig.Pool
|
||||
if keyConfig.Pool > 0 {
|
||||
keyPoolSize = keyConfig.Pool
|
||||
}
|
||||
keyPool := make(chan crypto.PrivateKey, keyPoolSize)
|
||||
if key, err := cryptutil.GeneratePrivateKey(keyConfig.Type, keyConfig.Bits); err != nil {
|
||||
return nil, fmt.Errorf("mitm: invalid key configuration: %w", err)
|
||||
} else {
|
||||
keyPool <- key
|
||||
}
|
||||
|
||||
go func(pool chan<- crypto.PrivateKey) {
|
||||
for {
|
||||
key, err := cryptutil.GeneratePrivateKey(keyConfig.Type, keyConfig.Bits)
|
||||
if err != nil {
|
||||
log.Panic().Err(err).Msg("error generating private key")
|
||||
}
|
||||
pool <- key
|
||||
}
|
||||
}(keyPool)
|
||||
|
||||
return &authority{
|
||||
pool: pool,
|
||||
cert: cert,
|
||||
key: key,
|
||||
keyID: cryptutil.GenerateKeyID(cryptutil.PublicKey(key)),
|
||||
keyPool: keyPool,
|
||||
cache: cache,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (ca *authority) log() log.Logger {
|
||||
return log.Console.With().
|
||||
Str("ca", ca.cert.Subject.String()).
|
||||
Logger()
|
||||
}
|
||||
|
||||
func (ca *authority) Certificate() *x509.Certificate {
|
||||
return ca.cert
|
||||
}
|
||||
|
||||
func (ca *authority) TLSConfig(name string) *tls.Config {
|
||||
return &tls.Config{
|
||||
GetCertificate: func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
log := ca.log()
|
||||
if hello.ServerName != "" {
|
||||
name = strings.ToLower(hello.ServerName)
|
||||
log.Debug().Msg("requesting certificate for server name (SNI)")
|
||||
} else {
|
||||
log.Debug().Msg("requesting certificate for hostname")
|
||||
}
|
||||
if cert, ok := ca.getCached(name); ok {
|
||||
log.Debug().
|
||||
Str("subject", cert.Leaf.Subject.String()).
|
||||
Str("serial", cert.Leaf.SerialNumber.String()).
|
||||
Time("valid", cert.Leaf.NotAfter).
|
||||
Msg("using cached certificate")
|
||||
return cert, nil
|
||||
}
|
||||
return ca.issueFor(name)
|
||||
},
|
||||
NextProtos: []string{"http/1.1"},
|
||||
}
|
||||
}
|
||||
|
||||
func (ca *authority) getCached(name string) (cert *tls.Certificate, ok bool) {
|
||||
log := ca.log()
|
||||
|
||||
if cert = ca.cache.Certificate(name); cert == nil {
|
||||
if baseDomain(name) != name {
|
||||
cert = ca.cache.Certificate(baseDomain(name))
|
||||
}
|
||||
}
|
||||
if cert != nil {
|
||||
if _, err := cert.Leaf.Verify(x509.VerifyOptions{
|
||||
DNSName: name,
|
||||
Roots: ca.pool,
|
||||
}); err != nil {
|
||||
log.Debug().Err(err).Str("name", name).Msg("deleting invalid certificate from cache")
|
||||
} else {
|
||||
ok = true
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (ca *authority) issueFor(name string) (*tls.Certificate, error) {
|
||||
var (
|
||||
log = ca.log().With().Str("name", name).Logger()
|
||||
key crypto.PrivateKey
|
||||
)
|
||||
select {
|
||||
case key = <-ca.keyPool:
|
||||
case <-time.After(5 * time.Second):
|
||||
return nil, errors.New("mitm: timeout waiting for private key generator to catch up")
|
||||
}
|
||||
if key == nil {
|
||||
panic("key pool returned nil key")
|
||||
}
|
||||
|
||||
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
|
||||
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("mtim: failed to generate serial number: %w", err)
|
||||
}
|
||||
|
||||
if part := dns.SplitDomainName(name); len(part) > 2 {
|
||||
name = strings.Join(part[1:], ".")
|
||||
log.Debug().Msgf("abbreviated name to %s (*.%s)", name, name)
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
template := &x509.Certificate{
|
||||
SerialNumber: serialNumber,
|
||||
Subject: pkix.Name{CommonName: name},
|
||||
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyAgreement,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
DNSNames: []string{name, "*." + name},
|
||||
BasicConstraintsValid: true,
|
||||
NotBefore: now.Add(-DefaultValidity),
|
||||
NotAfter: now.Add(+DefaultValidity),
|
||||
}
|
||||
der, err := x509.CreateCertificate(rand.Reader, template, ca.cert, cryptutil.PublicKey(key), ca.key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cert, err := x509.ParseCertificate(der)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Debug().Str("serial", serialNumber.String()).Msg("generated certificate")
|
||||
out := &tls.Certificate{
|
||||
Certificate: [][]byte{der},
|
||||
Leaf: cert,
|
||||
PrivateKey: key,
|
||||
}
|
||||
//ca.cache[name] = out
|
||||
ca.cache.SaveCertificate(name, out)
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func containsValidCertificate(cert *tls.Certificate) bool {
|
||||
if cert == nil || len(cert.Certificate) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
if cert.Leaf == nil {
|
||||
var err error
|
||||
if cert.Leaf, err = x509.ParseCertificate(cert.Certificate[0]); err != nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
|
||||
return !(cert.Leaf.NotBefore.Before(now) || cert.Leaf.NotAfter.After(now))
|
||||
}
|
Reference in New Issue
Block a user