Initial import
This commit is contained in:
233
proxy/mitm/cache.go
Normal file
233
proxy/mitm/cache.go
Normal file
@@ -0,0 +1,233 @@
|
||||
package mitm
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/golang-lru/v2/expirable"
|
||||
"github.com/hashicorp/hcl/v2/gohcl"
|
||||
"github.com/miekg/dns"
|
||||
|
||||
"git.maze.io/maze/styx/internal/cryptutil"
|
||||
"git.maze.io/maze/styx/internal/log"
|
||||
)
|
||||
|
||||
type Cache interface {
|
||||
Certificate(name string) *tls.Certificate
|
||||
SaveCertificate(name string, cert *tls.Certificate) error
|
||||
RemoveCertificate(name string)
|
||||
}
|
||||
|
||||
func NewCache(config *CacheConfig) (Cache, error) {
|
||||
if config == nil {
|
||||
return NewCache(&CacheConfig{Type: "memory"})
|
||||
}
|
||||
switch config.Type {
|
||||
case "memory":
|
||||
var cacheConfig = new(MemoryCacheConfig)
|
||||
if err := gohcl.DecodeBody(config.Body, nil, cacheConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewMemoryCache(cacheConfig.Size), nil
|
||||
case "disk":
|
||||
var cacheConfig = new(DiskCacheConfig)
|
||||
if err := gohcl.DecodeBody(config.Body, nil, cacheConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewDiskCache(cacheConfig.Path, time.Duration(cacheConfig.Expire*float64(time.Second)))
|
||||
default:
|
||||
return nil, fmt.Errorf("mitm: cache type %q is not supported", config.Type)
|
||||
}
|
||||
}
|
||||
|
||||
type memoryCache struct {
|
||||
cache *expirable.LRU[string, *tls.Certificate]
|
||||
}
|
||||
|
||||
func NewMemoryCache(size int) Cache {
|
||||
return memoryCache{
|
||||
cache: expirable.NewLRU(size, func(key string, value *tls.Certificate) {
|
||||
log.Debug().Str("name", key).Msg("certificate evicted from cache")
|
||||
}, time.Hour*24),
|
||||
}
|
||||
}
|
||||
|
||||
func (c memoryCache) Certificate(name string) (cert *tls.Certificate) {
|
||||
var ok bool
|
||||
if cert, ok = c.cache.Get(name); !ok {
|
||||
cert, _ = c.cache.Get(baseDomain(name))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (c memoryCache) SaveCertificate(name string, cert *tls.Certificate) error {
|
||||
c.cache.Add(name, cert)
|
||||
log.Debug().Str("name", name).Msg("certificate added to cache")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c memoryCache) RemoveCertificate(name string) {
|
||||
c.cache.Remove(name)
|
||||
}
|
||||
|
||||
type diskCache string
|
||||
|
||||
func NewDiskCache(dir string, expire time.Duration) (Cache, error) {
|
||||
if !filepath.IsAbs(dir) {
|
||||
var err error
|
||||
if dir, err = filepath.Abs(dir); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if err := os.MkdirAll(dir, 0o750); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
info, err := os.Stat(dir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if info.Mode()&os.ModePerm|0o057 != 0 {
|
||||
if err := os.Chmod(dir, 0o750); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if expire > 0 {
|
||||
go expireDiskCache(dir, expire)
|
||||
}
|
||||
|
||||
return diskCache(dir), nil
|
||||
}
|
||||
|
||||
func expireDiskCache(root string, expire time.Duration) {
|
||||
log.Debug().Str("path", root).Dur("expire", expire).Msg("disk cache expire loop starting")
|
||||
ticker := time.NewTicker(expire)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
now := <-ticker.C
|
||||
log.Debug().Str("path", root).Dur("expire", expire).Msg("expire disk cache")
|
||||
filepath.WalkDir(root, func(path string, d fs.DirEntry, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if d.IsDir() {
|
||||
// Remove the directory; this will fail if it's not empty, which is fine.
|
||||
_ = os.Remove(path)
|
||||
return nil
|
||||
}
|
||||
|
||||
cert, err := cryptutil.LoadCertificate(path)
|
||||
if err != nil {
|
||||
log.Debug().Str("path", path).Err(err).Msg("expire removing invalid certificate file")
|
||||
_ = os.Remove(path)
|
||||
return nil
|
||||
} else if cert.NotAfter.Before(now) {
|
||||
log.Debug().Str("path", path).Dur("expired", now.Sub(cert.NotAfter)).Msg("expire removing expired certificate")
|
||||
_ = os.Remove(path)
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (c diskCache) path(name string) string {
|
||||
part := dns.SplitDomainName(strings.ToLower(name))
|
||||
// x,com -> com,x
|
||||
// www,maze,io -> io,maze,www
|
||||
slices.Reverse(part)
|
||||
// com,x -> com,x,x.com
|
||||
// io,maze,www -> io,m,ma,maze,www.maze.io
|
||||
if len(part) > 2 {
|
||||
if len(part[1]) > 1 {
|
||||
part = []string{
|
||||
part[0],
|
||||
part[1][:1],
|
||||
part[1][:2],
|
||||
part[1],
|
||||
name,
|
||||
}
|
||||
} else {
|
||||
part = []string{
|
||||
part[0],
|
||||
part[1][:1],
|
||||
part[1],
|
||||
name,
|
||||
}
|
||||
}
|
||||
} else if len(part) > 1 {
|
||||
if len(part[1]) > 1 {
|
||||
part = []string{
|
||||
part[0],
|
||||
part[1][:1],
|
||||
part[1][:2],
|
||||
name,
|
||||
}
|
||||
} else {
|
||||
part = []string{
|
||||
part[0],
|
||||
part[1][:1],
|
||||
name,
|
||||
}
|
||||
}
|
||||
}
|
||||
part[len(part)-1] += ".crt"
|
||||
return filepath.Join(append([]string{string(c)}, part...)...)
|
||||
}
|
||||
|
||||
func (c diskCache) Certificate(name string) (cert *tls.Certificate) {
|
||||
if cert, key, err := cryptutil.LoadKeyPair(c.path(name), ""); err == nil {
|
||||
return &tls.Certificate{
|
||||
Certificate: [][]byte{cert.Raw},
|
||||
Leaf: cert,
|
||||
PrivateKey: key,
|
||||
}
|
||||
}
|
||||
if cert, key, err := cryptutil.LoadKeyPair(c.path(baseDomain(name)), ""); err == nil {
|
||||
return &tls.Certificate{
|
||||
Certificate: [][]byte{cert.Raw},
|
||||
Leaf: cert,
|
||||
PrivateKey: key,
|
||||
}
|
||||
}
|
||||
log.Debug().Str("path", string(c)).Str("name", name).Msg("cache miss")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c diskCache) SaveCertificate(name string, cert *tls.Certificate) error {
|
||||
dir, name := filepath.Split(c.path(name))
|
||||
if err := os.MkdirAll(dir, 0o750); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := cryptutil.SaveKeyPair(cert.Leaf, cert.PrivateKey, filepath.Join(dir, name), ""); err != nil {
|
||||
return err
|
||||
}
|
||||
log.Debug().Str("name", name).Msg("certificate added to cache")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c diskCache) RemoveCertificate(name string) {
|
||||
path := c.path(name)
|
||||
if err := os.Remove(path); err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return
|
||||
}
|
||||
log.Error().Err(err).Msg("certificate remove from cache failed")
|
||||
}
|
||||
_ = os.Remove(filepath.Dir(path))
|
||||
log.Debug().Str("name", name).Msg("certificate removed from cache")
|
||||
}
|
||||
|
||||
func baseDomain(name string) string {
|
||||
name = strings.ToLower(name)
|
||||
if part := dns.SplitDomainName(name); len(part) > 2 {
|
||||
return strings.Join(part[1:], ".")
|
||||
}
|
||||
return name
|
||||
}
|
Reference in New Issue
Block a user