Initial import
This commit is contained in:
231
proxy/mitm/authority.go
Normal file
231
proxy/mitm/authority.go
Normal file
@@ -0,0 +1,231 @@
|
||||
package mitm
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.maze.io/maze/styx/internal/cryptutil"
|
||||
"git.maze.io/maze/styx/internal/log"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
const DefaultValidity = 24 * time.Hour
|
||||
|
||||
type Authority interface {
|
||||
Certificate() *x509.Certificate
|
||||
TLSConfig(name string) *tls.Config
|
||||
}
|
||||
|
||||
type authority struct {
|
||||
pool *x509.CertPool
|
||||
cert *x509.Certificate
|
||||
key crypto.PrivateKey
|
||||
keyID []byte
|
||||
keyPool chan crypto.PrivateKey
|
||||
cache Cache
|
||||
}
|
||||
|
||||
func New(config *Config) (Authority, error) {
|
||||
cache, err := NewCache(config.Cache)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
caConfig := config.CA
|
||||
if caConfig == nil {
|
||||
caConfig = new(CAConfig)
|
||||
}
|
||||
|
||||
cert, key, err := cryptutil.LoadKeyPair(caConfig.Cert, caConfig.Key)
|
||||
if os.IsNotExist(err) {
|
||||
days := caConfig.Days
|
||||
if days == 0 {
|
||||
days = DefaultDays
|
||||
}
|
||||
if cert, key, err = cryptutil.GenerateKeyPair(caConfig.DN(), days, caConfig.KeyType, caConfig.Bits); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if strings.ContainsRune(caConfig.Cert, os.PathSeparator) {
|
||||
if err = cryptutil.SaveKeyPair(cert, key, caConfig.Cert, caConfig.Key); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pool := x509.NewCertPool()
|
||||
pool.AddCert(cert)
|
||||
|
||||
keyConfig := config.Key
|
||||
if keyConfig == nil {
|
||||
keyConfig = &defaultKeyConfig
|
||||
}
|
||||
|
||||
keyPoolSize := defaultKeyConfig.Pool
|
||||
if keyConfig.Pool > 0 {
|
||||
keyPoolSize = keyConfig.Pool
|
||||
}
|
||||
keyPool := make(chan crypto.PrivateKey, keyPoolSize)
|
||||
if key, err := cryptutil.GeneratePrivateKey(keyConfig.Type, keyConfig.Bits); err != nil {
|
||||
return nil, fmt.Errorf("mitm: invalid key configuration: %w", err)
|
||||
} else {
|
||||
keyPool <- key
|
||||
}
|
||||
|
||||
go func(pool chan<- crypto.PrivateKey) {
|
||||
for {
|
||||
key, err := cryptutil.GeneratePrivateKey(keyConfig.Type, keyConfig.Bits)
|
||||
if err != nil {
|
||||
log.Panic().Err(err).Msg("error generating private key")
|
||||
}
|
||||
pool <- key
|
||||
}
|
||||
}(keyPool)
|
||||
|
||||
return &authority{
|
||||
pool: pool,
|
||||
cert: cert,
|
||||
key: key,
|
||||
keyID: cryptutil.GenerateKeyID(cryptutil.PublicKey(key)),
|
||||
keyPool: keyPool,
|
||||
cache: cache,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (ca *authority) log() log.Logger {
|
||||
return log.Console.With().
|
||||
Str("ca", ca.cert.Subject.String()).
|
||||
Logger()
|
||||
}
|
||||
|
||||
func (ca *authority) Certificate() *x509.Certificate {
|
||||
return ca.cert
|
||||
}
|
||||
|
||||
func (ca *authority) TLSConfig(name string) *tls.Config {
|
||||
return &tls.Config{
|
||||
GetCertificate: func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
log := ca.log()
|
||||
if hello.ServerName != "" {
|
||||
name = strings.ToLower(hello.ServerName)
|
||||
log.Debug().Msg("requesting certificate for server name (SNI)")
|
||||
} else {
|
||||
log.Debug().Msg("requesting certificate for hostname")
|
||||
}
|
||||
if cert, ok := ca.getCached(name); ok {
|
||||
log.Debug().
|
||||
Str("subject", cert.Leaf.Subject.String()).
|
||||
Str("serial", cert.Leaf.SerialNumber.String()).
|
||||
Time("valid", cert.Leaf.NotAfter).
|
||||
Msg("using cached certificate")
|
||||
return cert, nil
|
||||
}
|
||||
return ca.issueFor(name)
|
||||
},
|
||||
NextProtos: []string{"http/1.1"},
|
||||
}
|
||||
}
|
||||
|
||||
func (ca *authority) getCached(name string) (cert *tls.Certificate, ok bool) {
|
||||
log := ca.log()
|
||||
|
||||
if cert = ca.cache.Certificate(name); cert == nil {
|
||||
if baseDomain(name) != name {
|
||||
cert = ca.cache.Certificate(baseDomain(name))
|
||||
}
|
||||
}
|
||||
if cert != nil {
|
||||
if _, err := cert.Leaf.Verify(x509.VerifyOptions{
|
||||
DNSName: name,
|
||||
Roots: ca.pool,
|
||||
}); err != nil {
|
||||
log.Debug().Err(err).Str("name", name).Msg("deleting invalid certificate from cache")
|
||||
} else {
|
||||
ok = true
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (ca *authority) issueFor(name string) (*tls.Certificate, error) {
|
||||
var (
|
||||
log = ca.log().With().Str("name", name).Logger()
|
||||
key crypto.PrivateKey
|
||||
)
|
||||
select {
|
||||
case key = <-ca.keyPool:
|
||||
case <-time.After(5 * time.Second):
|
||||
return nil, errors.New("mitm: timeout waiting for private key generator to catch up")
|
||||
}
|
||||
if key == nil {
|
||||
panic("key pool returned nil key")
|
||||
}
|
||||
|
||||
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
|
||||
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("mtim: failed to generate serial number: %w", err)
|
||||
}
|
||||
|
||||
if part := dns.SplitDomainName(name); len(part) > 2 {
|
||||
name = strings.Join(part[1:], ".")
|
||||
log.Debug().Msgf("abbreviated name to %s (*.%s)", name, name)
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
template := &x509.Certificate{
|
||||
SerialNumber: serialNumber,
|
||||
Subject: pkix.Name{CommonName: name},
|
||||
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyAgreement,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
DNSNames: []string{name, "*." + name},
|
||||
BasicConstraintsValid: true,
|
||||
NotBefore: now.Add(-DefaultValidity),
|
||||
NotAfter: now.Add(+DefaultValidity),
|
||||
}
|
||||
der, err := x509.CreateCertificate(rand.Reader, template, ca.cert, cryptutil.PublicKey(key), ca.key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cert, err := x509.ParseCertificate(der)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Debug().Str("serial", serialNumber.String()).Msg("generated certificate")
|
||||
out := &tls.Certificate{
|
||||
Certificate: [][]byte{der},
|
||||
Leaf: cert,
|
||||
PrivateKey: key,
|
||||
}
|
||||
//ca.cache[name] = out
|
||||
ca.cache.SaveCertificate(name, out)
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func containsValidCertificate(cert *tls.Certificate) bool {
|
||||
if cert == nil || len(cert.Certificate) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
if cert.Leaf == nil {
|
||||
var err error
|
||||
if cert.Leaf, err = x509.ParseCertificate(cert.Certificate[0]); err != nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
|
||||
return !(cert.Leaf.NotBefore.Before(now) || cert.Leaf.NotAfter.After(now))
|
||||
}
|
233
proxy/mitm/cache.go
Normal file
233
proxy/mitm/cache.go
Normal file
@@ -0,0 +1,233 @@
|
||||
package mitm
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/golang-lru/v2/expirable"
|
||||
"github.com/hashicorp/hcl/v2/gohcl"
|
||||
"github.com/miekg/dns"
|
||||
|
||||
"git.maze.io/maze/styx/internal/cryptutil"
|
||||
"git.maze.io/maze/styx/internal/log"
|
||||
)
|
||||
|
||||
type Cache interface {
|
||||
Certificate(name string) *tls.Certificate
|
||||
SaveCertificate(name string, cert *tls.Certificate) error
|
||||
RemoveCertificate(name string)
|
||||
}
|
||||
|
||||
func NewCache(config *CacheConfig) (Cache, error) {
|
||||
if config == nil {
|
||||
return NewCache(&CacheConfig{Type: "memory"})
|
||||
}
|
||||
switch config.Type {
|
||||
case "memory":
|
||||
var cacheConfig = new(MemoryCacheConfig)
|
||||
if err := gohcl.DecodeBody(config.Body, nil, cacheConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewMemoryCache(cacheConfig.Size), nil
|
||||
case "disk":
|
||||
var cacheConfig = new(DiskCacheConfig)
|
||||
if err := gohcl.DecodeBody(config.Body, nil, cacheConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewDiskCache(cacheConfig.Path, time.Duration(cacheConfig.Expire*float64(time.Second)))
|
||||
default:
|
||||
return nil, fmt.Errorf("mitm: cache type %q is not supported", config.Type)
|
||||
}
|
||||
}
|
||||
|
||||
type memoryCache struct {
|
||||
cache *expirable.LRU[string, *tls.Certificate]
|
||||
}
|
||||
|
||||
func NewMemoryCache(size int) Cache {
|
||||
return memoryCache{
|
||||
cache: expirable.NewLRU(size, func(key string, value *tls.Certificate) {
|
||||
log.Debug().Str("name", key).Msg("certificate evicted from cache")
|
||||
}, time.Hour*24),
|
||||
}
|
||||
}
|
||||
|
||||
func (c memoryCache) Certificate(name string) (cert *tls.Certificate) {
|
||||
var ok bool
|
||||
if cert, ok = c.cache.Get(name); !ok {
|
||||
cert, _ = c.cache.Get(baseDomain(name))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (c memoryCache) SaveCertificate(name string, cert *tls.Certificate) error {
|
||||
c.cache.Add(name, cert)
|
||||
log.Debug().Str("name", name).Msg("certificate added to cache")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c memoryCache) RemoveCertificate(name string) {
|
||||
c.cache.Remove(name)
|
||||
}
|
||||
|
||||
type diskCache string
|
||||
|
||||
func NewDiskCache(dir string, expire time.Duration) (Cache, error) {
|
||||
if !filepath.IsAbs(dir) {
|
||||
var err error
|
||||
if dir, err = filepath.Abs(dir); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if err := os.MkdirAll(dir, 0o750); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
info, err := os.Stat(dir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if info.Mode()&os.ModePerm|0o057 != 0 {
|
||||
if err := os.Chmod(dir, 0o750); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if expire > 0 {
|
||||
go expireDiskCache(dir, expire)
|
||||
}
|
||||
|
||||
return diskCache(dir), nil
|
||||
}
|
||||
|
||||
func expireDiskCache(root string, expire time.Duration) {
|
||||
log.Debug().Str("path", root).Dur("expire", expire).Msg("disk cache expire loop starting")
|
||||
ticker := time.NewTicker(expire)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
now := <-ticker.C
|
||||
log.Debug().Str("path", root).Dur("expire", expire).Msg("expire disk cache")
|
||||
filepath.WalkDir(root, func(path string, d fs.DirEntry, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if d.IsDir() {
|
||||
// Remove the directory; this will fail if it's not empty, which is fine.
|
||||
_ = os.Remove(path)
|
||||
return nil
|
||||
}
|
||||
|
||||
cert, err := cryptutil.LoadCertificate(path)
|
||||
if err != nil {
|
||||
log.Debug().Str("path", path).Err(err).Msg("expire removing invalid certificate file")
|
||||
_ = os.Remove(path)
|
||||
return nil
|
||||
} else if cert.NotAfter.Before(now) {
|
||||
log.Debug().Str("path", path).Dur("expired", now.Sub(cert.NotAfter)).Msg("expire removing expired certificate")
|
||||
_ = os.Remove(path)
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (c diskCache) path(name string) string {
|
||||
part := dns.SplitDomainName(strings.ToLower(name))
|
||||
// x,com -> com,x
|
||||
// www,maze,io -> io,maze,www
|
||||
slices.Reverse(part)
|
||||
// com,x -> com,x,x.com
|
||||
// io,maze,www -> io,m,ma,maze,www.maze.io
|
||||
if len(part) > 2 {
|
||||
if len(part[1]) > 1 {
|
||||
part = []string{
|
||||
part[0],
|
||||
part[1][:1],
|
||||
part[1][:2],
|
||||
part[1],
|
||||
name,
|
||||
}
|
||||
} else {
|
||||
part = []string{
|
||||
part[0],
|
||||
part[1][:1],
|
||||
part[1],
|
||||
name,
|
||||
}
|
||||
}
|
||||
} else if len(part) > 1 {
|
||||
if len(part[1]) > 1 {
|
||||
part = []string{
|
||||
part[0],
|
||||
part[1][:1],
|
||||
part[1][:2],
|
||||
name,
|
||||
}
|
||||
} else {
|
||||
part = []string{
|
||||
part[0],
|
||||
part[1][:1],
|
||||
name,
|
||||
}
|
||||
}
|
||||
}
|
||||
part[len(part)-1] += ".crt"
|
||||
return filepath.Join(append([]string{string(c)}, part...)...)
|
||||
}
|
||||
|
||||
func (c diskCache) Certificate(name string) (cert *tls.Certificate) {
|
||||
if cert, key, err := cryptutil.LoadKeyPair(c.path(name), ""); err == nil {
|
||||
return &tls.Certificate{
|
||||
Certificate: [][]byte{cert.Raw},
|
||||
Leaf: cert,
|
||||
PrivateKey: key,
|
||||
}
|
||||
}
|
||||
if cert, key, err := cryptutil.LoadKeyPair(c.path(baseDomain(name)), ""); err == nil {
|
||||
return &tls.Certificate{
|
||||
Certificate: [][]byte{cert.Raw},
|
||||
Leaf: cert,
|
||||
PrivateKey: key,
|
||||
}
|
||||
}
|
||||
log.Debug().Str("path", string(c)).Str("name", name).Msg("cache miss")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c diskCache) SaveCertificate(name string, cert *tls.Certificate) error {
|
||||
dir, name := filepath.Split(c.path(name))
|
||||
if err := os.MkdirAll(dir, 0o750); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := cryptutil.SaveKeyPair(cert.Leaf, cert.PrivateKey, filepath.Join(dir, name), ""); err != nil {
|
||||
return err
|
||||
}
|
||||
log.Debug().Str("name", name).Msg("certificate added to cache")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c diskCache) RemoveCertificate(name string) {
|
||||
path := c.path(name)
|
||||
if err := os.Remove(path); err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return
|
||||
}
|
||||
log.Error().Err(err).Msg("certificate remove from cache failed")
|
||||
}
|
||||
_ = os.Remove(filepath.Dir(path))
|
||||
log.Debug().Str("name", name).Msg("certificate removed from cache")
|
||||
}
|
||||
|
||||
func baseDomain(name string) string {
|
||||
name = strings.ToLower(name)
|
||||
if part := dns.SplitDomainName(name); len(part) > 2 {
|
||||
return strings.Join(part[1:], ".")
|
||||
}
|
||||
return name
|
||||
}
|
25
proxy/mitm/cache_test.go
Normal file
25
proxy/mitm/cache_test.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package mitm
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestDiskCachePath(t *testing.T) {
|
||||
cache := diskCache("testdata")
|
||||
tests := []struct {
|
||||
test string
|
||||
want string
|
||||
}{
|
||||
{"x.com", "testdata/com/x/x.com.crt"},
|
||||
{"feed.x.com", "testdata/com/x/x/feed.x.com.crt"},
|
||||
{"nu.nl", "testdata/nl/n/nu/nu.nl.crt"},
|
||||
{"maze.io", "testdata/io/m/ma/maze.io.crt"},
|
||||
{"lab.maze.io", "testdata/io/m/ma/maze/lab.maze.io.crt"},
|
||||
{"dev.lab.maze.io", "testdata/io/m/ma/maze/dev.lab.maze.io.crt"},
|
||||
}
|
||||
for _, test := range tests {
|
||||
t.Run(test.test, func(it *testing.T) {
|
||||
if v := cache.path(test.test); v != test.want {
|
||||
it.Errorf("expected %q to resolve to %q, got %q", test.test, test.want, v)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
89
proxy/mitm/config.go
Normal file
89
proxy/mitm/config.go
Normal file
@@ -0,0 +1,89 @@
|
||||
package mitm
|
||||
|
||||
import (
|
||||
"crypto/x509/pkix"
|
||||
|
||||
"github.com/hashicorp/hcl/v2"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultCommonName = "Styx Certificate Authority"
|
||||
DefaultDays = 3
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
CA *CAConfig `hcl:"ca,block"`
|
||||
Key *KeyConfig `hcl:"key,block"`
|
||||
Cache *CacheConfig `hcl:"cache,block"`
|
||||
}
|
||||
|
||||
type CAConfig struct {
|
||||
Cert string `hcl:"cert"`
|
||||
Key string `hcl:"key,optional"`
|
||||
Days int `hcl:"days,optional"`
|
||||
KeyType string `hcl:"key_type,optional"`
|
||||
Bits int `hcl:"bits,optional"`
|
||||
Name string `hcl:"name,optional"`
|
||||
Country string `hcl:"country,optional"`
|
||||
Organization string `hcl:"organization,optional"`
|
||||
Unit string `hcl:"unit,optional"`
|
||||
Locality string `hcl:"locality,optional"`
|
||||
Province string `hcl:"province,optional"`
|
||||
Address []string `hcl:"address,optional"`
|
||||
PostalCode string `hcl:"postal_code,optional"`
|
||||
}
|
||||
|
||||
func (config CAConfig) DN() pkix.Name {
|
||||
var name = pkix.Name{
|
||||
CommonName: config.Name,
|
||||
StreetAddress: config.Address,
|
||||
}
|
||||
if config.Name == "" {
|
||||
name.CommonName = DefaultCommonName
|
||||
}
|
||||
if config.Country != "" {
|
||||
name.Country = append(name.Country, config.Country)
|
||||
}
|
||||
if config.Organization != "" {
|
||||
name.Organization = append(name.Organization, config.Organization)
|
||||
}
|
||||
if config.Unit != "" {
|
||||
name.OrganizationalUnit = append(name.OrganizationalUnit, config.Unit)
|
||||
}
|
||||
if config.Locality != "" {
|
||||
name.Locality = append(name.Locality, config.Locality)
|
||||
}
|
||||
if config.Province != "" {
|
||||
name.Province = append(name.Province, config.Province)
|
||||
}
|
||||
if config.PostalCode != "" {
|
||||
name.PostalCode = append(name.PostalCode, config.PostalCode)
|
||||
}
|
||||
return name
|
||||
}
|
||||
|
||||
type KeyConfig struct {
|
||||
Type string `hcl:"type,optional"`
|
||||
Bits int `hcl:"bits,optional"`
|
||||
Pool int `hcl:"pool,optional"`
|
||||
}
|
||||
|
||||
var defaultKeyConfig = KeyConfig{
|
||||
Type: "rsa",
|
||||
Bits: 2048,
|
||||
Pool: 5,
|
||||
}
|
||||
|
||||
type CacheConfig struct {
|
||||
Type string `hcl:"type"`
|
||||
Body hcl.Body `hcl:",remain"`
|
||||
}
|
||||
|
||||
type MemoryCacheConfig struct {
|
||||
Size int `hcl:"size,optional"`
|
||||
}
|
||||
|
||||
type DiskCacheConfig struct {
|
||||
Path string `hcl:"path"`
|
||||
Expire float64 `hcl:"expire,optional"`
|
||||
}
|
Reference in New Issue
Block a user