Initial import

This commit is contained in:
2025-09-26 08:49:53 +02:00
commit a76650da35
35 changed files with 4660 additions and 0 deletions

231
proxy/mitm/authority.go Normal file
View File

@@ -0,0 +1,231 @@
package mitm
import (
"crypto"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"errors"
"fmt"
"math/big"
"os"
"strings"
"time"
"git.maze.io/maze/styx/internal/cryptutil"
"git.maze.io/maze/styx/internal/log"
"github.com/miekg/dns"
)
const DefaultValidity = 24 * time.Hour
type Authority interface {
Certificate() *x509.Certificate
TLSConfig(name string) *tls.Config
}
type authority struct {
pool *x509.CertPool
cert *x509.Certificate
key crypto.PrivateKey
keyID []byte
keyPool chan crypto.PrivateKey
cache Cache
}
func New(config *Config) (Authority, error) {
cache, err := NewCache(config.Cache)
if err != nil {
return nil, err
}
caConfig := config.CA
if caConfig == nil {
caConfig = new(CAConfig)
}
cert, key, err := cryptutil.LoadKeyPair(caConfig.Cert, caConfig.Key)
if os.IsNotExist(err) {
days := caConfig.Days
if days == 0 {
days = DefaultDays
}
if cert, key, err = cryptutil.GenerateKeyPair(caConfig.DN(), days, caConfig.KeyType, caConfig.Bits); err != nil {
return nil, err
}
if strings.ContainsRune(caConfig.Cert, os.PathSeparator) {
if err = cryptutil.SaveKeyPair(cert, key, caConfig.Cert, caConfig.Key); err != nil {
return nil, err
}
}
} else if err != nil {
return nil, err
}
pool := x509.NewCertPool()
pool.AddCert(cert)
keyConfig := config.Key
if keyConfig == nil {
keyConfig = &defaultKeyConfig
}
keyPoolSize := defaultKeyConfig.Pool
if keyConfig.Pool > 0 {
keyPoolSize = keyConfig.Pool
}
keyPool := make(chan crypto.PrivateKey, keyPoolSize)
if key, err := cryptutil.GeneratePrivateKey(keyConfig.Type, keyConfig.Bits); err != nil {
return nil, fmt.Errorf("mitm: invalid key configuration: %w", err)
} else {
keyPool <- key
}
go func(pool chan<- crypto.PrivateKey) {
for {
key, err := cryptutil.GeneratePrivateKey(keyConfig.Type, keyConfig.Bits)
if err != nil {
log.Panic().Err(err).Msg("error generating private key")
}
pool <- key
}
}(keyPool)
return &authority{
pool: pool,
cert: cert,
key: key,
keyID: cryptutil.GenerateKeyID(cryptutil.PublicKey(key)),
keyPool: keyPool,
cache: cache,
}, nil
}
func (ca *authority) log() log.Logger {
return log.Console.With().
Str("ca", ca.cert.Subject.String()).
Logger()
}
func (ca *authority) Certificate() *x509.Certificate {
return ca.cert
}
func (ca *authority) TLSConfig(name string) *tls.Config {
return &tls.Config{
GetCertificate: func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
log := ca.log()
if hello.ServerName != "" {
name = strings.ToLower(hello.ServerName)
log.Debug().Msg("requesting certificate for server name (SNI)")
} else {
log.Debug().Msg("requesting certificate for hostname")
}
if cert, ok := ca.getCached(name); ok {
log.Debug().
Str("subject", cert.Leaf.Subject.String()).
Str("serial", cert.Leaf.SerialNumber.String()).
Time("valid", cert.Leaf.NotAfter).
Msg("using cached certificate")
return cert, nil
}
return ca.issueFor(name)
},
NextProtos: []string{"http/1.1"},
}
}
func (ca *authority) getCached(name string) (cert *tls.Certificate, ok bool) {
log := ca.log()
if cert = ca.cache.Certificate(name); cert == nil {
if baseDomain(name) != name {
cert = ca.cache.Certificate(baseDomain(name))
}
}
if cert != nil {
if _, err := cert.Leaf.Verify(x509.VerifyOptions{
DNSName: name,
Roots: ca.pool,
}); err != nil {
log.Debug().Err(err).Str("name", name).Msg("deleting invalid certificate from cache")
} else {
ok = true
}
}
return
}
func (ca *authority) issueFor(name string) (*tls.Certificate, error) {
var (
log = ca.log().With().Str("name", name).Logger()
key crypto.PrivateKey
)
select {
case key = <-ca.keyPool:
case <-time.After(5 * time.Second):
return nil, errors.New("mitm: timeout waiting for private key generator to catch up")
}
if key == nil {
panic("key pool returned nil key")
}
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
if err != nil {
return nil, fmt.Errorf("mtim: failed to generate serial number: %w", err)
}
if part := dns.SplitDomainName(name); len(part) > 2 {
name = strings.Join(part[1:], ".")
log.Debug().Msgf("abbreviated name to %s (*.%s)", name, name)
}
now := time.Now()
template := &x509.Certificate{
SerialNumber: serialNumber,
Subject: pkix.Name{CommonName: name},
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyAgreement,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
DNSNames: []string{name, "*." + name},
BasicConstraintsValid: true,
NotBefore: now.Add(-DefaultValidity),
NotAfter: now.Add(+DefaultValidity),
}
der, err := x509.CreateCertificate(rand.Reader, template, ca.cert, cryptutil.PublicKey(key), ca.key)
if err != nil {
return nil, err
}
cert, err := x509.ParseCertificate(der)
if err != nil {
return nil, err
}
log.Debug().Str("serial", serialNumber.String()).Msg("generated certificate")
out := &tls.Certificate{
Certificate: [][]byte{der},
Leaf: cert,
PrivateKey: key,
}
//ca.cache[name] = out
ca.cache.SaveCertificate(name, out)
return out, nil
}
func containsValidCertificate(cert *tls.Certificate) bool {
if cert == nil || len(cert.Certificate) == 0 {
return false
}
if cert.Leaf == nil {
var err error
if cert.Leaf, err = x509.ParseCertificate(cert.Certificate[0]); err != nil {
return false
}
}
now := time.Now()
return !(cert.Leaf.NotBefore.Before(now) || cert.Leaf.NotAfter.After(now))
}

233
proxy/mitm/cache.go Normal file
View File

@@ -0,0 +1,233 @@
package mitm
import (
"crypto/tls"
"fmt"
"io/fs"
"os"
"path/filepath"
"slices"
"strings"
"time"
"github.com/hashicorp/golang-lru/v2/expirable"
"github.com/hashicorp/hcl/v2/gohcl"
"github.com/miekg/dns"
"git.maze.io/maze/styx/internal/cryptutil"
"git.maze.io/maze/styx/internal/log"
)
type Cache interface {
Certificate(name string) *tls.Certificate
SaveCertificate(name string, cert *tls.Certificate) error
RemoveCertificate(name string)
}
func NewCache(config *CacheConfig) (Cache, error) {
if config == nil {
return NewCache(&CacheConfig{Type: "memory"})
}
switch config.Type {
case "memory":
var cacheConfig = new(MemoryCacheConfig)
if err := gohcl.DecodeBody(config.Body, nil, cacheConfig); err != nil {
return nil, err
}
return NewMemoryCache(cacheConfig.Size), nil
case "disk":
var cacheConfig = new(DiskCacheConfig)
if err := gohcl.DecodeBody(config.Body, nil, cacheConfig); err != nil {
return nil, err
}
return NewDiskCache(cacheConfig.Path, time.Duration(cacheConfig.Expire*float64(time.Second)))
default:
return nil, fmt.Errorf("mitm: cache type %q is not supported", config.Type)
}
}
type memoryCache struct {
cache *expirable.LRU[string, *tls.Certificate]
}
func NewMemoryCache(size int) Cache {
return memoryCache{
cache: expirable.NewLRU(size, func(key string, value *tls.Certificate) {
log.Debug().Str("name", key).Msg("certificate evicted from cache")
}, time.Hour*24),
}
}
func (c memoryCache) Certificate(name string) (cert *tls.Certificate) {
var ok bool
if cert, ok = c.cache.Get(name); !ok {
cert, _ = c.cache.Get(baseDomain(name))
}
return
}
func (c memoryCache) SaveCertificate(name string, cert *tls.Certificate) error {
c.cache.Add(name, cert)
log.Debug().Str("name", name).Msg("certificate added to cache")
return nil
}
func (c memoryCache) RemoveCertificate(name string) {
c.cache.Remove(name)
}
type diskCache string
func NewDiskCache(dir string, expire time.Duration) (Cache, error) {
if !filepath.IsAbs(dir) {
var err error
if dir, err = filepath.Abs(dir); err != nil {
return nil, err
}
}
if err := os.MkdirAll(dir, 0o750); err != nil {
return nil, err
}
info, err := os.Stat(dir)
if err != nil {
return nil, err
}
if info.Mode()&os.ModePerm|0o057 != 0 {
if err := os.Chmod(dir, 0o750); err != nil {
return nil, err
}
}
if expire > 0 {
go expireDiskCache(dir, expire)
}
return diskCache(dir), nil
}
func expireDiskCache(root string, expire time.Duration) {
log.Debug().Str("path", root).Dur("expire", expire).Msg("disk cache expire loop starting")
ticker := time.NewTicker(expire)
defer ticker.Stop()
for {
now := <-ticker.C
log.Debug().Str("path", root).Dur("expire", expire).Msg("expire disk cache")
filepath.WalkDir(root, func(path string, d fs.DirEntry, err error) error {
if err != nil {
return err
}
if d.IsDir() {
// Remove the directory; this will fail if it's not empty, which is fine.
_ = os.Remove(path)
return nil
}
cert, err := cryptutil.LoadCertificate(path)
if err != nil {
log.Debug().Str("path", path).Err(err).Msg("expire removing invalid certificate file")
_ = os.Remove(path)
return nil
} else if cert.NotAfter.Before(now) {
log.Debug().Str("path", path).Dur("expired", now.Sub(cert.NotAfter)).Msg("expire removing expired certificate")
_ = os.Remove(path)
return nil
}
return nil
})
}
}
func (c diskCache) path(name string) string {
part := dns.SplitDomainName(strings.ToLower(name))
// x,com -> com,x
// www,maze,io -> io,maze,www
slices.Reverse(part)
// com,x -> com,x,x.com
// io,maze,www -> io,m,ma,maze,www.maze.io
if len(part) > 2 {
if len(part[1]) > 1 {
part = []string{
part[0],
part[1][:1],
part[1][:2],
part[1],
name,
}
} else {
part = []string{
part[0],
part[1][:1],
part[1],
name,
}
}
} else if len(part) > 1 {
if len(part[1]) > 1 {
part = []string{
part[0],
part[1][:1],
part[1][:2],
name,
}
} else {
part = []string{
part[0],
part[1][:1],
name,
}
}
}
part[len(part)-1] += ".crt"
return filepath.Join(append([]string{string(c)}, part...)...)
}
func (c diskCache) Certificate(name string) (cert *tls.Certificate) {
if cert, key, err := cryptutil.LoadKeyPair(c.path(name), ""); err == nil {
return &tls.Certificate{
Certificate: [][]byte{cert.Raw},
Leaf: cert,
PrivateKey: key,
}
}
if cert, key, err := cryptutil.LoadKeyPair(c.path(baseDomain(name)), ""); err == nil {
return &tls.Certificate{
Certificate: [][]byte{cert.Raw},
Leaf: cert,
PrivateKey: key,
}
}
log.Debug().Str("path", string(c)).Str("name", name).Msg("cache miss")
return nil
}
func (c diskCache) SaveCertificate(name string, cert *tls.Certificate) error {
dir, name := filepath.Split(c.path(name))
if err := os.MkdirAll(dir, 0o750); err != nil {
return err
}
if err := cryptutil.SaveKeyPair(cert.Leaf, cert.PrivateKey, filepath.Join(dir, name), ""); err != nil {
return err
}
log.Debug().Str("name", name).Msg("certificate added to cache")
return nil
}
func (c diskCache) RemoveCertificate(name string) {
path := c.path(name)
if err := os.Remove(path); err != nil {
if os.IsNotExist(err) {
return
}
log.Error().Err(err).Msg("certificate remove from cache failed")
}
_ = os.Remove(filepath.Dir(path))
log.Debug().Str("name", name).Msg("certificate removed from cache")
}
func baseDomain(name string) string {
name = strings.ToLower(name)
if part := dns.SplitDomainName(name); len(part) > 2 {
return strings.Join(part[1:], ".")
}
return name
}

25
proxy/mitm/cache_test.go Normal file
View File

@@ -0,0 +1,25 @@
package mitm
import "testing"
func TestDiskCachePath(t *testing.T) {
cache := diskCache("testdata")
tests := []struct {
test string
want string
}{
{"x.com", "testdata/com/x/x.com.crt"},
{"feed.x.com", "testdata/com/x/x/feed.x.com.crt"},
{"nu.nl", "testdata/nl/n/nu/nu.nl.crt"},
{"maze.io", "testdata/io/m/ma/maze.io.crt"},
{"lab.maze.io", "testdata/io/m/ma/maze/lab.maze.io.crt"},
{"dev.lab.maze.io", "testdata/io/m/ma/maze/dev.lab.maze.io.crt"},
}
for _, test := range tests {
t.Run(test.test, func(it *testing.T) {
if v := cache.path(test.test); v != test.want {
it.Errorf("expected %q to resolve to %q, got %q", test.test, test.want, v)
}
})
}
}

89
proxy/mitm/config.go Normal file
View File

@@ -0,0 +1,89 @@
package mitm
import (
"crypto/x509/pkix"
"github.com/hashicorp/hcl/v2"
)
const (
DefaultCommonName = "Styx Certificate Authority"
DefaultDays = 3
)
type Config struct {
CA *CAConfig `hcl:"ca,block"`
Key *KeyConfig `hcl:"key,block"`
Cache *CacheConfig `hcl:"cache,block"`
}
type CAConfig struct {
Cert string `hcl:"cert"`
Key string `hcl:"key,optional"`
Days int `hcl:"days,optional"`
KeyType string `hcl:"key_type,optional"`
Bits int `hcl:"bits,optional"`
Name string `hcl:"name,optional"`
Country string `hcl:"country,optional"`
Organization string `hcl:"organization,optional"`
Unit string `hcl:"unit,optional"`
Locality string `hcl:"locality,optional"`
Province string `hcl:"province,optional"`
Address []string `hcl:"address,optional"`
PostalCode string `hcl:"postal_code,optional"`
}
func (config CAConfig) DN() pkix.Name {
var name = pkix.Name{
CommonName: config.Name,
StreetAddress: config.Address,
}
if config.Name == "" {
name.CommonName = DefaultCommonName
}
if config.Country != "" {
name.Country = append(name.Country, config.Country)
}
if config.Organization != "" {
name.Organization = append(name.Organization, config.Organization)
}
if config.Unit != "" {
name.OrganizationalUnit = append(name.OrganizationalUnit, config.Unit)
}
if config.Locality != "" {
name.Locality = append(name.Locality, config.Locality)
}
if config.Province != "" {
name.Province = append(name.Province, config.Province)
}
if config.PostalCode != "" {
name.PostalCode = append(name.PostalCode, config.PostalCode)
}
return name
}
type KeyConfig struct {
Type string `hcl:"type,optional"`
Bits int `hcl:"bits,optional"`
Pool int `hcl:"pool,optional"`
}
var defaultKeyConfig = KeyConfig{
Type: "rsa",
Bits: 2048,
Pool: 5,
}
type CacheConfig struct {
Type string `hcl:"type"`
Body hcl.Body `hcl:",remain"`
}
type MemoryCacheConfig struct {
Size int `hcl:"size,optional"`
}
type DiskCacheConfig struct {
Path string `hcl:"path"`
Expire float64 `hcl:"expire,optional"`
}