package mitm

import (
	"crypto/tls"
	"fmt"
	"io/fs"
	"os"
	"path/filepath"
	"slices"
	"strings"
	"time"

	"github.com/hashicorp/golang-lru/v2/expirable"
	"github.com/hashicorp/hcl/v2/gohcl"
	"github.com/miekg/dns"

	"git.maze.io/maze/styx/internal/cryptutil"
	"git.maze.io/maze/styx/internal/log"
)

// diskCacheDirPerm is the permission mode enforced on directories created by
// the disk cache.
const diskCacheDirPerm = 0o750

// Cache stores TLS certificates keyed by host name.
type Cache interface {
	// Certificate returns the certificate cached for name, or nil when there
	// is none.
	Certificate(name string) *tls.Certificate

	// SaveCertificate stores cert under name.
	SaveCertificate(name string, cert *tls.Certificate) error

	// RemoveCertificate drops the certificate stored under name, if any.
	RemoveCertificate(name string)
}

// NewCache builds the Cache described by config.
//
// A nil config selects an in-memory cache. Supported values for config.Type
// are "memory" and "disk"; any other type yields an error.
func NewCache(config *CacheConfig) (Cache, error) {
	if config == nil {
		// There is no body to decode, so build the default directly instead
		// of handing a nil body to gohcl.DecodeBody. A size of 0 is documented
		// by expirable.NewLRU as "no size limit".
		return NewMemoryCache(0), nil
	}

	switch config.Type {
	case "memory":
		cacheConfig := new(MemoryCacheConfig)
		if err := gohcl.DecodeBody(config.Body, nil, cacheConfig); err != nil {
			return nil, err
		}
		return NewMemoryCache(cacheConfig.Size), nil

	case "disk":
		cacheConfig := new(DiskCacheConfig)
		if err := gohcl.DecodeBody(config.Body, nil, cacheConfig); err != nil {
			return nil, err
		}
		// Expire is expressed in (fractional) seconds.
		expire := time.Duration(cacheConfig.Expire * float64(time.Second))
		return NewDiskCache(cacheConfig.Path, expire)

	default:
		return nil, fmt.Errorf("mitm: cache type %q is not supported", config.Type)
	}
}

// memoryCache is a Cache backed by an expiring in-memory LRU.
type memoryCache struct {
	cache *expirable.LRU[string, *tls.Certificate]
}

// NewMemoryCache returns a Cache that keeps up to size certificates in memory.
// Entries are created with a time-to-live of 24 hours and every eviction is
// logged at debug level.
func NewMemoryCache(size int) Cache {
	onEvict := func(key string, value *tls.Certificate) {
		log.Debug().Str("name", key).Msg("certificate evicted from cache")
	}
	return memoryCache{
		cache: expirable.NewLRU(size, onEvict, time.Hour*24),
	}
}

// Certificate looks up name and, when that misses, the base domain of name.
// It returns nil when neither is cached.
func (c memoryCache) Certificate(name string) *tls.Certificate {
	if cert, ok := c.cache.Get(name); ok {
		return cert
	}
	cert, _ := c.cache.Get(baseDomain(name))
	return cert
}

// SaveCertificate stores cert under name. It always returns nil.
func (c memoryCache) SaveCertificate(name string, cert *tls.Certificate) error {
	c.cache.Add(name, cert)
	log.Debug().Str("name", name).Msg("certificate added to cache")
	return nil
}

// RemoveCertificate drops the entry stored under name.
func (c memoryCache) RemoveCertificate(name string) {
	c.cache.Remove(name)
}

// diskCache is a Cache that keeps certificates as files below the directory
// named by its string value.
type diskCache string

// NewDiskCache returns a Cache storing certificates below dir.
//
// dir is made absolute, created if needed and forced to mode 0750. When
// expire is positive a background goroutine is started that sweeps the
// directory every expire interval; that goroutine runs for the remaining
// lifetime of the process.
func NewDiskCache(dir string, expire time.Duration) (Cache, error) {
	if !filepath.IsAbs(dir) {
		var err error
		if dir, err = filepath.Abs(dir); err != nil {
			return nil, err
		}
	}

	if err := os.MkdirAll(dir, diskCacheDirPerm); err != nil {
		return nil, err
	}

	info, err := os.Stat(dir)
	if err != nil {
		return nil, err
	}
	// MkdirAll leaves a pre-existing directory untouched (and is subject to
	// the umask), so tighten or fix up the mode only when it actually differs.
	if info.Mode().Perm() != diskCacheDirPerm {
		if err := os.Chmod(dir, diskCacheDirPerm); err != nil {
			return nil, err
		}
	}

	if expire > 0 {
		go expireDiskCache(dir, expire)
	}
	return diskCache(dir), nil
}

// expireDiskCache sweeps the cache below root once every expire interval.
// It never returns; expire must be positive.
func expireDiskCache(root string, expire time.Duration) {
	log.Debug().Str("path", root).Dur("expire", expire).Msg("disk cache expire loop starting")

	ticker := time.NewTicker(expire)
	defer ticker.Stop()

	for now := range ticker.C {
		log.Debug().Str("path", root).Dur("expire", expire).Msg("expire disk cache")
		sweepDiskCache(root, now)
	}
}

// sweepDiskCache removes every file below root that cannot be loaded as a
// certificate or whose certificate expired before now, then prunes the
// directories that were left empty. The root directory itself is kept.
func sweepDiskCache(root string, now time.Time) {
	var dirs []string

	// The callback always returns nil so that a single unreadable entry does
	// not abort the sweep; WalkDir therefore has no error to report.
	_ = filepath.WalkDir(root, func(path string, d fs.DirEntry, err error) error {
		if err != nil {
			log.Debug().Str("path", path).Err(err).Msg("expire skipping unreadable path")
			return nil
		}

		if d.IsDir() {
			// Directories are pruned after the walk: removing one here would
			// make WalkDir fail when it subsequently tries to read it.
			if path != root {
				dirs = append(dirs, path)
			}
			return nil
		}

		cert, err := cryptutil.LoadCertificate(path)
		switch {
		case err != nil:
			log.Debug().Str("path", path).Err(err).Msg("expire removing invalid certificate file")
			_ = os.Remove(path)
		case cert.NotAfter.Before(now):
			log.Debug().Str("path", path).Dur("expired", now.Sub(cert.NotAfter)).Msg("expire removing expired certificate")
			_ = os.Remove(path)
		}
		return nil
	})

	// WalkDir reports parents before their children, so walking the list
	// backwards removes the deepest directories first. Remove fails on
	// directories that are not empty, which is exactly what we want.
	for i := len(dirs) - 1; i >= 0; i-- {
		_ = os.Remove(dirs[i])
	}
}

// path returns the file name used to store the certificate for name.
//
// The labels of the lowercased name are used to fan the files out over
// subdirectories:
//
//	localhost   -> localhost.crt
//	x.com       -> com/x/x.com.crt
//	maze.io     -> io/m/ma/maze.io.crt
//	www.maze.io -> io/m/ma/maze/www.maze.io.crt
//
// NOTE(review): name is embedded verbatim (original case, no sanitising) as
// the final path element; confirm callers only pass validated host names.
func (c diskCache) path(name string) string {
	labels := dns.SplitDomainName(strings.ToLower(name))
	// www,maze,io -> io,maze,www
	slices.Reverse(labels)

	var elems []string
	switch len(labels) {
	case 0:
		// The name yields no labels (such as an empty name); fall back to the
		// raw name rather than indexing an empty slice.
		elems = []string{name}
	case 1:
		elems = []string{labels[0]}
	default:
		tld, domain := labels[0], labels[1]
		elems = []string{tld}
		if len(domain) > 0 {
			elems = append(elems, domain[:1])
		}
		if len(domain) > 1 {
			elems = append(elems, domain[:2])
		}
		if len(labels) > 2 {
			// Names below the registered domain get a directory of their own.
			elems = append(elems, domain)
		}
		elems = append(elems, name)
	}

	elems[len(elems)-1] += ".crt"
	return filepath.Join(append([]string{string(c)}, elems...)...)
}

// Certificate loads the key pair stored for name and, when that fails, the
// one stored for the base domain of name. It returns nil when neither can be
// loaded.
func (c diskCache) Certificate(name string) *tls.Certificate {
	var tried string
	for _, candidate := range [...]string{name, baseDomain(name)} {
		file := c.path(candidate)
		if file == tried {
			// The base domain maps to the file that just failed to load.
			continue
		}
		tried = file

		leaf, key, err := cryptutil.LoadKeyPair(file, "")
		if err != nil {
			continue
		}
		return &tls.Certificate{
			Certificate: [][]byte{leaf.Raw},
			Leaf:        leaf,
			PrivateKey:  key,
		}
	}

	log.Debug().Str("path", string(c)).Str("name", name).Msg("cache miss")
	return nil
}

// SaveCertificate writes the leaf certificate and private key of cert to the
// file for name, creating parent directories as needed. It returns an error
// when cert carries no parsed leaf or when writing fails.
func (c diskCache) SaveCertificate(name string, cert *tls.Certificate) error {
	if cert == nil || cert.Leaf == nil {
		return fmt.Errorf("mitm: certificate for %q has no parsed leaf to save", name)
	}

	file := c.path(name)
	if err := os.MkdirAll(filepath.Dir(file), diskCacheDirPerm); err != nil {
		return err
	}
	if err := cryptutil.SaveKeyPair(cert.Leaf, cert.PrivateKey, file, ""); err != nil {
		return err
	}

	log.Debug().Str("name", name).Msg("certificate added to cache")
	return nil
}

// RemoveCertificate deletes the file stored for name and then tries to remove
// its parent directory, which only succeeds once that directory is empty.
// A missing file is not an error; any other failure is logged.
func (c diskCache) RemoveCertificate(name string) {
	path := c.path(name)
	if err := os.Remove(path); err != nil {
		if !os.IsNotExist(err) {
			log.Error().Err(err).Msg("certificate remove from cache failed")
		}
		return
	}

	// Best effort: fails harmlessly while other certificates remain.
	_ = os.Remove(filepath.Dir(path))
	log.Debug().Str("name", name).Msg("certificate removed from cache")
}

// baseDomain lowercases name and, when it has more than two labels, strips
// the leftmost one ("www.maze.io" -> "maze.io").
func baseDomain(name string) string {
	name = strings.ToLower(name)
	if part := dns.SplitDomainName(name); len(part) > 2 {
		return strings.Join(part[1:], ".")
	}
	return name
}