234 lines
5.5 KiB
Go
234 lines
5.5 KiB
Go
package mitm
|
|
|
|
import (
|
|
"crypto/tls"
|
|
"fmt"
|
|
"io/fs"
|
|
"os"
|
|
"path/filepath"
|
|
"slices"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/hashicorp/golang-lru/v2/expirable"
|
|
"github.com/hashicorp/hcl/v2/gohcl"
|
|
"github.com/miekg/dns"
|
|
|
|
"git.maze.io/maze/styx/internal/cryptutil"
|
|
"git.maze.io/maze/styx/internal/log"
|
|
)
|
|
|
|
// Cache stores TLS certificates keyed by name.
type Cache interface {
	// Certificate returns the certificate cached for name, or nil when the
	// cache holds no usable entry.
	Certificate(name string) *tls.Certificate

	// SaveCertificate stores cert under name.
	SaveCertificate(name string, cert *tls.Certificate) error

	// RemoveCertificate drops the entry stored under name, if any.
	RemoveCertificate(name string)
}
|
|
|
|
func NewCache(config *CacheConfig) (Cache, error) {
|
|
if config == nil {
|
|
return NewCache(&CacheConfig{Type: "memory"})
|
|
}
|
|
switch config.Type {
|
|
case "memory":
|
|
var cacheConfig = new(MemoryCacheConfig)
|
|
if err := gohcl.DecodeBody(config.Body, nil, cacheConfig); err != nil {
|
|
return nil, err
|
|
}
|
|
return NewMemoryCache(cacheConfig.Size), nil
|
|
case "disk":
|
|
var cacheConfig = new(DiskCacheConfig)
|
|
if err := gohcl.DecodeBody(config.Body, nil, cacheConfig); err != nil {
|
|
return nil, err
|
|
}
|
|
return NewDiskCache(cacheConfig.Path, time.Duration(cacheConfig.Expire*float64(time.Second)))
|
|
default:
|
|
return nil, fmt.Errorf("mitm: cache type %q is not supported", config.Type)
|
|
}
|
|
}
|
|
|
|
type memoryCache struct {
|
|
cache *expirable.LRU[string, *tls.Certificate]
|
|
}
|
|
|
|
func NewMemoryCache(size int) Cache {
|
|
return memoryCache{
|
|
cache: expirable.NewLRU(size, func(key string, value *tls.Certificate) {
|
|
log.Debug().Str("name", key).Msg("certificate evicted from cache")
|
|
}, time.Hour*24),
|
|
}
|
|
}
|
|
|
|
func (c memoryCache) Certificate(name string) (cert *tls.Certificate) {
|
|
var ok bool
|
|
if cert, ok = c.cache.Get(name); !ok {
|
|
cert, _ = c.cache.Get(baseDomain(name))
|
|
}
|
|
return
|
|
}
|
|
|
|
func (c memoryCache) SaveCertificate(name string, cert *tls.Certificate) error {
|
|
c.cache.Add(name, cert)
|
|
log.Debug().Str("name", name).Msg("certificate added to cache")
|
|
return nil
|
|
}
|
|
|
|
func (c memoryCache) RemoveCertificate(name string) {
|
|
c.cache.Remove(name)
|
|
}
|
|
|
|
// diskCache is a Cache that keeps certificates as files on disk. Its string
// value is the root directory under which the certificate files are stored.
type diskCache string
|
|
|
|
func NewDiskCache(dir string, expire time.Duration) (Cache, error) {
|
|
if !filepath.IsAbs(dir) {
|
|
var err error
|
|
if dir, err = filepath.Abs(dir); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
if err := os.MkdirAll(dir, 0o750); err != nil {
|
|
return nil, err
|
|
}
|
|
info, err := os.Stat(dir)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if info.Mode()&os.ModePerm|0o057 != 0 {
|
|
if err := os.Chmod(dir, 0o750); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
if expire > 0 {
|
|
go expireDiskCache(dir, expire)
|
|
}
|
|
|
|
return diskCache(dir), nil
|
|
}
|
|
|
|
func expireDiskCache(root string, expire time.Duration) {
|
|
log.Debug().Str("path", root).Dur("expire", expire).Msg("disk cache expire loop starting")
|
|
ticker := time.NewTicker(expire)
|
|
defer ticker.Stop()
|
|
for {
|
|
now := <-ticker.C
|
|
log.Debug().Str("path", root).Dur("expire", expire).Msg("expire disk cache")
|
|
filepath.WalkDir(root, func(path string, d fs.DirEntry, err error) error {
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if d.IsDir() {
|
|
// Remove the directory; this will fail if it's not empty, which is fine.
|
|
_ = os.Remove(path)
|
|
return nil
|
|
}
|
|
|
|
cert, err := cryptutil.LoadCertificate(path)
|
|
if err != nil {
|
|
log.Debug().Str("path", path).Err(err).Msg("expire removing invalid certificate file")
|
|
_ = os.Remove(path)
|
|
return nil
|
|
} else if cert.NotAfter.Before(now) {
|
|
log.Debug().Str("path", path).Dur("expired", now.Sub(cert.NotAfter)).Msg("expire removing expired certificate")
|
|
_ = os.Remove(path)
|
|
return nil
|
|
}
|
|
return nil
|
|
})
|
|
}
|
|
}
|
|
|
|
func (c diskCache) path(name string) string {
|
|
part := dns.SplitDomainName(strings.ToLower(name))
|
|
// x,com -> com,x
|
|
// www,maze,io -> io,maze,www
|
|
slices.Reverse(part)
|
|
// com,x -> com,x,x.com
|
|
// io,maze,www -> io,m,ma,maze,www.maze.io
|
|
if len(part) > 2 {
|
|
if len(part[1]) > 1 {
|
|
part = []string{
|
|
part[0],
|
|
part[1][:1],
|
|
part[1][:2],
|
|
part[1],
|
|
name,
|
|
}
|
|
} else {
|
|
part = []string{
|
|
part[0],
|
|
part[1][:1],
|
|
part[1],
|
|
name,
|
|
}
|
|
}
|
|
} else if len(part) > 1 {
|
|
if len(part[1]) > 1 {
|
|
part = []string{
|
|
part[0],
|
|
part[1][:1],
|
|
part[1][:2],
|
|
name,
|
|
}
|
|
} else {
|
|
part = []string{
|
|
part[0],
|
|
part[1][:1],
|
|
name,
|
|
}
|
|
}
|
|
}
|
|
part[len(part)-1] += ".crt"
|
|
return filepath.Join(append([]string{string(c)}, part...)...)
|
|
}
|
|
|
|
func (c diskCache) Certificate(name string) (cert *tls.Certificate) {
|
|
if cert, key, err := cryptutil.LoadKeyPair(c.path(name), ""); err == nil {
|
|
return &tls.Certificate{
|
|
Certificate: [][]byte{cert.Raw},
|
|
Leaf: cert,
|
|
PrivateKey: key,
|
|
}
|
|
}
|
|
if cert, key, err := cryptutil.LoadKeyPair(c.path(baseDomain(name)), ""); err == nil {
|
|
return &tls.Certificate{
|
|
Certificate: [][]byte{cert.Raw},
|
|
Leaf: cert,
|
|
PrivateKey: key,
|
|
}
|
|
}
|
|
log.Debug().Str("path", string(c)).Str("name", name).Msg("cache miss")
|
|
return nil
|
|
}
|
|
|
|
func (c diskCache) SaveCertificate(name string, cert *tls.Certificate) error {
|
|
dir, name := filepath.Split(c.path(name))
|
|
if err := os.MkdirAll(dir, 0o750); err != nil {
|
|
return err
|
|
}
|
|
if err := cryptutil.SaveKeyPair(cert.Leaf, cert.PrivateKey, filepath.Join(dir, name), ""); err != nil {
|
|
return err
|
|
}
|
|
log.Debug().Str("name", name).Msg("certificate added to cache")
|
|
return nil
|
|
}
|
|
|
|
func (c diskCache) RemoveCertificate(name string) {
|
|
path := c.path(name)
|
|
if err := os.Remove(path); err != nil {
|
|
if os.IsNotExist(err) {
|
|
return
|
|
}
|
|
log.Error().Err(err).Msg("certificate remove from cache failed")
|
|
}
|
|
_ = os.Remove(filepath.Dir(path))
|
|
log.Debug().Str("name", name).Msg("certificate removed from cache")
|
|
}
|
|
|
|
func baseDomain(name string) string {
|
|
name = strings.ToLower(name)
|
|
if part := dns.SplitDomainName(name); len(part) > 2 {
|
|
return strings.Join(part[1:], ".")
|
|
}
|
|
return name
|
|
}
|