// Files
// styx/proxy/mitm/cache.go
// 2025-09-26 08:49:53 +02:00
//
// 234 lines
// 5.5 KiB
// Go

package mitm
import (
"crypto/tls"
"fmt"
"io/fs"
"os"
"path/filepath"
"slices"
"strings"
"time"
"github.com/hashicorp/golang-lru/v2/expirable"
"github.com/hashicorp/hcl/v2/gohcl"
"github.com/miekg/dns"
"git.maze.io/maze/styx/internal/cryptutil"
"git.maze.io/maze/styx/internal/log"
)
// Cache stores TLS certificates keyed by host name.
//
// The implementations in this package may answer a lookup with a certificate
// cached for the parent domain (see baseDomain) when the exact name is absent.
type Cache interface {
	// Certificate returns the cached certificate for name, or nil on a miss.
	Certificate(name string) *tls.Certificate
	// SaveCertificate stores cert under name.
	SaveCertificate(name string, cert *tls.Certificate) error
	// RemoveCertificate drops the certificate cached under name, if any.
	RemoveCertificate(name string)
}
// NewCache builds the certificate Cache selected by config.Type.
//
// A nil config selects an in-memory cache with zero-value settings. For
// "memory" the body is decoded into MemoryCacheConfig. For "disk" it is
// decoded into DiskCacheConfig, whose Expire is interpreted as seconds.
// Decoding errors and unsupported types are returned to the caller.
func NewCache(config *CacheConfig) (Cache, error) {
	if config == nil {
		config = &CacheConfig{Type: "memory"}
	}

	switch config.Type {
	case "memory":
		cacheConfig := new(MemoryCacheConfig)
		// The default config above carries no Body. gohcl.DecodeBody calls
		// methods on the body and would panic on nil, so in that case keep the
		// zero-value settings (the same result as decoding an empty body).
		if config.Body != nil {
			if err := gohcl.DecodeBody(config.Body, nil, cacheConfig); err != nil {
				return nil, err
			}
		}
		return NewMemoryCache(cacheConfig.Size), nil

	case "disk":
		// A disk cache needs at least a path. A missing body is therefore an
		// error rather than a silent fallback to the working directory.
		if config.Body == nil {
			return nil, fmt.Errorf("mitm: cache type %q requires a configuration body", config.Type)
		}
		cacheConfig := new(DiskCacheConfig)
		if err := gohcl.DecodeBody(config.Body, nil, cacheConfig); err != nil {
			return nil, err
		}
		expire := time.Duration(cacheConfig.Expire * float64(time.Second))
		return NewDiskCache(cacheConfig.Path, expire)

	default:
		return nil, fmt.Errorf("mitm: cache type %q is not supported", config.Type)
	}
}
// memoryCache is a Cache backed by an in-process, size-bounded LRU. Entries
// are also dropped 24 hours after they were added. Keys are lowercased host
// names, so lookups are case-insensitive.
type memoryCache struct {
	cache *expirable.LRU[string, *tls.Certificate]
}

// NewMemoryCache returns an in-memory Cache that holds up to size
// certificates for at most 24 hours each. size is passed straight to
// expirable.NewLRU; per that library's documentation, 0 means "no size limit".
func NewMemoryCache(size int) Cache {
	onEvict := func(key string, _ *tls.Certificate) {
		log.Debug().Str("name", key).Msg("certificate evicted from cache")
	}
	return memoryCache{cache: expirable.NewLRU(size, onEvict, 24*time.Hour)}
}

// Certificate returns the certificate cached for name or, failing that, for
// baseDomain(name). It returns nil on a miss.
//
// An entry whose parsed leaf has already expired is evicted and treated as a
// miss, so an invalid certificate is never handed out.
func (c memoryCache) Certificate(name string) *tls.Certificate {
	name = strings.ToLower(name)
	now := time.Now()
	for _, key := range []string{name, baseDomain(name)} {
		cert, ok := c.cache.Get(key)
		if !ok {
			continue
		}
		if cert != nil && cert.Leaf != nil && now.After(cert.Leaf.NotAfter) {
			c.cache.Remove(key)
			continue
		}
		return cert
	}
	return nil
}

// SaveCertificate stores cert under the lowercased name. It never fails.
func (c memoryCache) SaveCertificate(name string, cert *tls.Certificate) error {
	name = strings.ToLower(name)
	c.cache.Add(name, cert)
	log.Debug().Str("name", name).Msg("certificate added to cache")
	return nil
}

// RemoveCertificate drops the entry cached under the lowercased name, if any.
func (c memoryCache) RemoveCertificate(name string) {
	c.cache.Remove(strings.ToLower(name))
}
// diskCache is a Cache that stores each certificate and its private key as a
// file below the directory named by its string value.
type diskCache string

// NewDiskCache returns a Cache that persists certificates below dir.
//
// A relative dir is resolved to an absolute path. The directory is created
// (mode 0o750) if needed. If it grants any permission bit outside 0o750,
// those extra bits are stripped.
//
// When expire > 0, a background goroutine runs expireDiskCache(dir, expire)
// to sweep the cache every expire interval.
func NewDiskCache(dir string, expire time.Duration) (Cache, error) {
	if !filepath.IsAbs(dir) {
		abs, err := filepath.Abs(dir)
		if err != nil {
			return nil, err
		}
		dir = abs
	}
	if err := os.MkdirAll(dir, 0o750); err != nil {
		return nil, err
	}

	info, err := os.Stat(dir)
	if err != nil {
		return nil, err
	}
	// The cache holds private keys, so strip group-write and every "other"
	// bit. An already stricter mode (e.g. 0o700) is left untouched.
	if perm := info.Mode().Perm(); perm&^0o750 != 0 {
		if err := os.Chmod(dir, perm&0o750); err != nil {
			return nil, err
		}
	}

	if expire > 0 {
		go expireDiskCache(dir, expire)
	}
	return diskCache(dir), nil
}
// expireDiskCache sweeps the disk cache rooted at root once every expire
// interval (see sweepDiskCache). expire must be > 0, because time.NewTicker
// panics otherwise.
//
// NOTE(review): the loop never terminates. The Cache interface has no Close,
// so the goroutine started by NewDiskCache lives for the rest of the process.
func expireDiskCache(root string, expire time.Duration) {
	log.Debug().Str("path", root).Dur("expire", expire).Msg("disk cache expire loop starting")
	ticker := time.NewTicker(expire)
	defer ticker.Stop()
	for now := range ticker.C {
		log.Debug().Str("path", root).Dur("expire", expire).Msg("expire disk cache")
		sweepDiskCache(root, now)
	}
}

// sweepDiskCache performs one expiry pass over root at time now.
//
// The pass removes every file that cryptutil.LoadCertificate cannot load and
// every certificate whose NotAfter lies before now. It then removes the
// subdirectories of root that have become empty. root itself is never removed.
func sweepDiskCache(root string, now time.Time) {
	var dirs []string

	// The callback always returns nil, so WalkDir cannot fail and its result
	// is deliberately ignored.
	_ = filepath.WalkDir(root, func(path string, d fs.DirEntry, err error) error {
		if err != nil {
			// Keep sweeping the rest of the tree. An unreadable entry, or one
			// removed concurrently, must not abort the whole pass.
			log.Debug().Str("path", path).Err(err).Msg("expire skipping unreadable path")
			return nil
		}
		if d.IsDir() {
			// Directories are pruned after the walk. Removing them here
			// (pre-order) would make WalkDir's ReadDir fail on them.
			if path != root {
				dirs = append(dirs, path)
			}
			return nil
		}

		cert, err := cryptutil.LoadCertificate(path)
		switch {
		case err != nil:
			log.Debug().Str("path", path).Err(err).Msg("expire removing invalid certificate file")
			_ = os.Remove(path)
		case cert.NotAfter.Before(now):
			log.Debug().Str("path", path).Dur("expired", now.Sub(cert.NotAfter)).Msg("expire removing expired certificate")
			_ = os.Remove(path)
		}
		return nil
	})

	// WalkDir reports a directory before its contents, so walking the list
	// backwards removes children first. A parent emptied that way can then go
	// in the same pass. os.Remove refuses non-empty directories, which keeps
	// live entries safe.
	for i := len(dirs) - 1; i >= 0; i-- {
		_ = os.Remove(dirs[i])
	}
}
// path maps a host name to the file that holds its cached key pair.
//
// The name is lowercased and its labels are reversed to build a shallow
// fan-out tree under the cache root, for example:
//
//	x.com       -> <root>/com/x/x.com.crt
//	xy.com      -> <root>/com/x/xy/xy.com.crt
//	www.maze.io -> <root>/io/m/ma/maze/www.maze.io.crt
//	localhost   -> <root>/localhost.crt
//
// Each path element goes through diskCachePathElem, so a hostile name (one
// containing "/" or "..") cannot escape the root. An empty name no longer
// panics; it maps to "<root>/_.crt".
func (c diskCache) path(name string) string {
	name = strings.ToLower(name)
	labels := dns.SplitDomainName(name)
	slices.Reverse(labels)

	var elems []string
	switch {
	case len(labels) >= 2:
		tld, second := labels[0], labels[1]
		elems = append(elems, tld)
		if len(second) > 0 {
			elems = append(elems, second[:1])
		}
		if len(second) > 1 {
			elems = append(elems, second[:2])
		}
		if len(labels) > 2 {
			elems = append(elems, second)
		}
		elems = append(elems, name)
	case len(labels) == 1:
		elems = append(elems, labels[0])
	default:
		// No labels (e.g. "" or "."): fall back to the raw name, which
		// diskCachePathElem turns into a safe placeholder.
		elems = append(elems, name)
	}

	for i, elem := range elems {
		elems[i] = diskCachePathElem(elem)
	}
	elems[len(elems)-1] += ".crt"
	return filepath.Join(append([]string{string(c)}, elems...)...)
}

// diskCachePathElem makes s safe to use as a single path element below the
// cache root. Path separators become '_'. The elements "", "." and "..",
// which filepath.Join would drop or resolve upwards, become "_".
func diskCachePathElem(s string) string {
	s = strings.Map(func(r rune) rune {
		if r == '/' || r == '\\' {
			return '_'
		}
		return r
	}, s)
	switch s {
	case "", ".", "..":
		return "_"
	}
	return s
}
// Certificate loads the key pair cached for name or, failing that, for
// baseDomain(name). It returns nil on a miss, when the file cannot be loaded,
// or when the cached leaf has already expired.
//
// Expired entries are reported as misses because the expiry sweep may be
// disabled (expire == 0) or not have run yet. Presumably the caller then
// issues and saves a fresh certificate over the stale file.
func (c diskCache) Certificate(name string) *tls.Certificate {
	now := time.Now()
	for _, candidate := range []string{name, baseDomain(name)} {
		path := c.path(candidate)
		leaf, key, err := cryptutil.LoadKeyPair(path, "")
		if err != nil {
			continue
		}
		if now.After(leaf.NotAfter) {
			log.Debug().Str("path", path).Str("name", name).Msg("cache entry expired")
			continue
		}
		return &tls.Certificate{
			Certificate: [][]byte{leaf.Raw},
			Leaf:        leaf,
			PrivateKey:  key,
		}
	}
	log.Debug().Str("path", string(c)).Str("name", name).Msg("cache miss")
	return nil
}
// SaveCertificate writes cert's parsed leaf and private key to the file
// chosen by path(name), creating intermediate directories (mode 0o750) as
// needed.
//
// It returns an error in any of these cases:
//   - cert is nil or has no parsed Leaf;
//   - the directories cannot be created;
//   - cryptutil.SaveKeyPair fails.
func (c diskCache) SaveCertificate(name string, cert *tls.Certificate) error {
	// Only cert.Leaf (not cert.Certificate) is handed to SaveKeyPair, so
	// reject a certificate without one instead of passing nil through.
	if cert == nil || cert.Leaf == nil {
		return fmt.Errorf("mitm: certificate for %q has no parsed leaf", name)
	}
	file := c.path(name)
	if err := os.MkdirAll(filepath.Dir(file), 0o750); err != nil {
		return err
	}
	if err := cryptutil.SaveKeyPair(cert.Leaf, cert.PrivateKey, file, ""); err != nil {
		return err
	}
	// Log the host name, not the file name (the original shadowed name).
	log.Debug().Str("name", name).Str("path", file).Msg("certificate added to cache")
	return nil
}
// RemoveCertificate deletes the cached file for name, if any. It then tries
// to remove the file's parent directory, which only succeeds if that
// directory is now empty. A missing file is not an error; other failures are
// logged.
func (c diskCache) RemoveCertificate(name string) {
	file := c.path(name)
	if err := os.Remove(file); err != nil {
		if !os.IsNotExist(err) {
			log.Error().Err(err).Str("name", name).Str("path", file).Msg("certificate remove from cache failed")
		}
		// Nothing was removed, so don't prune the directory or report success.
		return
	}
	// Best effort: os.Remove refuses non-empty directories, so sibling
	// entries are safe.
	_ = os.Remove(filepath.Dir(file))
	log.Debug().Str("name", name).Msg("certificate removed from cache")
}
// baseDomain returns the parent of name: the lowercased name with its first
// label dropped and the remaining labels joined with dots. Names of one or
// two labels (such as "localhost" or "maze.io") come back lowercased but
// otherwise unchanged.
func baseDomain(name string) string {
	lowered := strings.ToLower(name)
	labels := dns.SplitDomainName(lowered)
	if len(labels) <= 2 {
		return lowered
	}
	parent := labels[1:]
	return strings.Join(parent, ".")
}