commit a76650da35800b6bba34da21fda7611d6c5c82c7 Author: maze Date: Fri Sep 26 08:49:53 2025 +0200 Initial import diff --git a/README.md b/README.md new file mode 100644 index 0000000..3920f8b --- /dev/null +++ b/README.md @@ -0,0 +1,5 @@ +# Styx + +![Logo](styx.png "Styx proxy") + +Styx is a filtering HTTP proxy. diff --git a/cmd/styx/main.go b/cmd/styx/main.go new file mode 100644 index 0000000..3103aff --- /dev/null +++ b/cmd/styx/main.go @@ -0,0 +1,85 @@ +package main + +import ( + "flag" + "os" + "os/signal" + "syscall" + + "github.com/hashicorp/hcl/v2/hclsimple" + + "git.maze.io/maze/styx/internal/log" + "git.maze.io/maze/styx/proxy" + "git.maze.io/maze/styx/proxy/cache" + "git.maze.io/maze/styx/proxy/match" + "git.maze.io/maze/styx/proxy/mitm" + "git.maze.io/maze/styx/proxy/resolver" +) + +func main() { + configFlag := flag.String("config", "styx.hcl", "Configuration file") + traceFlag := flag.Bool("T", false, "Enable trace level logging") + debugFlag := flag.Bool("D", false, "Enable debug level logging") + flag.Parse() + + if *traceFlag { + log.SetLevel(log.TraceLevel) + } else if *debugFlag { + log.SetLevel(log.DebugLevel) + } + + config, err := load(*configFlag) + if err != nil { + log.Fatal().Err(err).Msg("") + } + + matchers, err := config.Match.Matchers() + if err != nil { + log.Fatal().Err(err).Msg("") + } else if err = config.Proxy.Policy.Configure(matchers); err != nil { + log.Fatal().Err(err).Msg("") + } + + var ca mitm.Authority + if config.MITM != nil { + if ca, err = mitm.New(config.MITM); err != nil { + log.Fatal().Err(err).Msg("error configuring mitm") + } + } + + server, err := proxy.New(&config.Proxy, ca) + if err != nil { + log.Fatal().Err(err).Msg("") + } + + if err = server.Start(); err != nil { + log.Fatal().Err(err).Msg("") + } + + signalChannel := make(chan os.Signal, 1) + signal.Notify(signalChannel, syscall.SIGINT, syscall.SIGTERM) + <-signalChannel + + server.Close() +} + +type Config struct { + DNS *resolver.Config `hcl:"dns,block"` + Proxy proxy.Config `hcl:"proxy,block"` + MITM *mitm.Config `hcl:"mitm,block"` + Cache *cache.Config `hcl:"cache,block"` + Match *match.Config `hcl:"match,block"` +} + +func load(name string) (*Config, error) { + config := new(Config) + if err := hclsimple.DecodeFile(name, nil, config); err != nil { + return nil, err + } + + if config.DNS != nil { + config.Proxy.Resolver = resolver.New(*config.DNS) + } + + return config, nil +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..533e20d --- /dev/null +++ b/go.mod @@ -0,0 +1,29 @@ +module git.maze.io/maze/styx + +go 1.25.0 + +require ( + github.com/hashicorp/golang-lru/v2 v2.0.7 + github.com/hashicorp/hcl/v2 v2.24.0 + github.com/miekg/dns v1.1.68 + github.com/rs/zerolog v1.34.0 +) + +require ( + github.com/agext/levenshtein v1.2.1 // indirect + github.com/apparentlymart/go-textseg/v15 v15.0.0 // indirect + github.com/google/go-cmp v0.6.0 // indirect + github.com/google/uuid v1.6.0 // indirect + github.com/mattn/go-colorable v0.1.13 // indirect + github.com/mattn/go-isatty v0.0.19 // indirect + github.com/mattn/go-sqlite3 v1.14.32 // indirect + github.com/mitchellh/go-wordwrap v1.0.1 // indirect + github.com/yl2chen/cidranger v1.0.2 // indirect + github.com/zclconf/go-cty v1.16.3 // indirect + golang.org/x/mod v0.24.0 // indirect + golang.org/x/net v0.40.0 // indirect + golang.org/x/sync v0.14.0 // indirect + golang.org/x/sys v0.33.0 // indirect + golang.org/x/text v0.25.0 // indirect + golang.org/x/tools v0.33.0 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..5cf0546 --- /dev/null +++ b/go.sum @@ -0,0 +1,60 @@ +github.com/agext/levenshtein v1.2.1 h1:QmvMAjj2aEICytGiWzmxoE0x2KZvE0fvmqMOfy2tjT8= +github.com/agext/levenshtein v1.2.1/go.mod h1:JEDfjyjHDjOF/1e4FlBE/PkbqA9OfWu2ki2W0IB5558= +github.com/apparentlymart/go-textseg/v15 v15.0.0 h1:uYvfpb3DyLSCGWnctWKGj857c6ew1u1fNQOlOtuGxQY= +github.com/apparentlymart/go-textseg/v15 v15.0.0/go.mod h1:K8XmNZdhEBkdlyDdvbmmsvpAG721bKi0joRfFdHIWJ4= +github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-test/deep v1.0.3 h1:ZrJSEWsXzPOxaZnFteGEfooLba+ju3FYIbOrS+rQd68= +github.com/go-test/deep v1.0.3/go.mod h1:wGDj63lr65AM2AQyKZd/NYHGb0R+1RLqB8NKt3aSFNA= +github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k= +github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= +github.com/hashicorp/hcl/v2 v2.24.0 h1:2QJdZ454DSsYGoaE6QheQZjtKZSUs9Nh2izTWiwQxvE= +github.com/hashicorp/hcl/v2 v2.24.0/go.mod h1:oGoO1FIQYfn/AgyOhlg9qLC6/nOJPX3qGbkZpYAcqfM= +github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= +github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= +github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-sqlite3 v1.14.32 h1:JD12Ag3oLy1zQA+BNn74xRgaBbdhbNIDYvQUEuuErjs= +github.com/mattn/go-sqlite3 v1.14.32/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= +github.com/miekg/dns v1.1.68 h1:jsSRkNozw7G/mnmXULynzMNIsgY2dHC8LO6U6Ij2JEA= +github.com/miekg/dns v1.1.68/go.mod h1:fujopn7TB3Pu3JM69XaawiU0wqjpL9/8xGop5UrTPps= +github.com/mitchellh/go-wordwrap v1.0.1 h1:TLuKupo69TCn6TQSyGxwI1EblZZEsQ0vMlAFQflz0v0= +github.com/mitchellh/go-wordwrap v1.0.1/go.mod h1:R62XHJLzvMFRBbcrT7m7WgmE1eOyTSsCt+hzestvNj0= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= +github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY= +github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/yl2chen/cidranger v1.0.2 h1:lbOWZVCG1tCRX4u24kuM1Tb4nHqWkDxwLdoS+SevawU= +github.com/yl2chen/cidranger v1.0.2/go.mod h1:9U1yz7WPYDwf0vpNWFaeRh0bjwz5RVgRy/9UEQfHl0g= +github.com/zclconf/go-cty v1.16.3 h1:osr++gw2T61A8KVYHoQiFbFd1Lh3JOCXc/jFLJXKTxk= +github.com/zclconf/go-cty v1.16.3/go.mod h1:VvMs5i0vgZdhYawQNq5kePSpLAoz8u1xvZgrPIxfnZE= +github.com/zclconf/go-cty-debug v0.0.0-20240509010212-0d6042c53940 h1:4r45xpDWB6ZMSMNJFMOjqrGHynW3DIBuR2H9j0ug+Mo= +github.com/zclconf/go-cty-debug v0.0.0-20240509010212-0d6042c53940/go.mod h1:CmBdvvj3nqzfzJ6nTCIwDTPZ56aVGvDrmztiO5g3qrM= +golang.org/x/mod v0.24.0 h1:ZfthKaKaT4NrhGVZHO1/WDTwGES4De8KtWO0SIbNJMU= +golang.org/x/mod v0.24.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww= +golang.org/x/net v0.40.0 h1:79Xs7wF06Gbdcg4kdCCIQArK11Z1hr5POQ6+fIYHNuY= +golang.org/x/net v0.40.0/go.mod h1:y0hY0exeL2Pku80/zKK7tpntoX23cqL3Oa6njdgRtds= +golang.org/x/sync v0.14.0 h1:woo0S4Yywslg6hp4eUFjTVOyKt0RookbpAHG4c1HmhQ= +golang.org/x/sync v0.14.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw= +golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/text v0.25.0 h1:qVyWApTSYLk/drJRO5mDlNYskwQznZmkpV2c8q9zls4= +golang.org/x/text v0.25.0/go.mod h1:WEdwpYrmk1qmdHvhkSTNPm3app7v4rsT8F2UD6+VHIA= +golang.org/x/tools v0.33.0 h1:4qz2S3zmRxbGIhDIAgjxvFutSvH5EfnsYrRBj0UI0bc= +golang.org/x/tools v0.33.0/go.mod h1:CIJMaWEY88juyUfo7UbgPqbC8rU2OqfAV1h2Qp0oMYI= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/internal/cryptutil/key.go b/internal/cryptutil/key.go new file mode 100644 index 0000000..2432eea --- /dev/null +++ b/internal/cryptutil/key.go @@ -0,0 +1,114 @@ +package cryptutil + +import ( + "crypto" + "crypto/ecdsa" + "crypto/ed25519" + "crypto/elliptic" + "crypto/rsa" + "crypto/sha1" + "crypto/x509" + "encoding/pem" + "errors" + "fmt" + "os" +) + +type publicKeyer interface { + Public() any +} + +// PublicKey returns the public part of a crypto.PrivateKey. +func PublicKey(key crypto.PrivateKey) any { + switch key := key.(type) { + case ed25519.PublicKey: + return key + case ed25519.PrivateKey: + return key.Public() + case ecdsa.PublicKey: + return &key + case *ecdsa.PublicKey: + return key + case *ecdsa.PrivateKey: + return &key.PublicKey + case rsa.PublicKey: + return &key + case *rsa.PublicKey: + return key + case *rsa.PrivateKey: + return &key.PublicKey + default: + if p, ok := key.(publicKeyer); ok { + return p.Public() + } + panic(fmt.Sprintf("don't know how to extract a public key from %T", key)) + } +} + +// LoadPrivateKey loads a private key from disk. +func LoadPrivateKey(name string) (crypto.PrivateKey, error) { + b, err := os.ReadFile(name) + if err != nil { + return nil, err + } + return decodePEMPrivateKey(b) +} + +func decodePEMPrivateKey(b []byte) (key crypto.PrivateKey, err error) { + var ( + rest = b + block *pem.Block + ) + for { + if block, rest = pem.Decode(rest); block == nil { + return nil, errors.New("mitm: no private key PEM block could be decoded") + } + switch block.Type { + case "EC PRIVATE KEY": + return x509.ParseECPrivateKey(block.Bytes) + + case "RSA PRIVATE KEY": + return x509.ParsePKCS1PrivateKey(block.Bytes) + + case "PRIVATE KEY": + return x509.ParsePKCS8PrivateKey(block.Bytes) + } + } +} + +// GenerateKeyID generates the PKIX public key ID. +func GenerateKeyID(key crypto.PublicKey) []byte { + b, err := x509.MarshalPKIXPublicKey(key) + if err != nil { + return nil + } + return sha1.New().Sum(b) +} + +func keyType(key any) string { + switch key := key.(type) { + case ed25519.PrivateKey: + return "ed25519" + case *ecdsa.PrivateKey: + return "ecdsa (" + curveType(key.Curve) + ")" + case *rsa.PrivateKey: + return "rsa" + default: + return fmt.Sprintf("%T", key) + } +} + +func curveType(c elliptic.Curve) string { + switch c { + case elliptic.P224(): + return "p224" + case elliptic.P256(): + return "p256" + case elliptic.P384(): + return "p384" + case elliptic.P521(): + return "p521" + default: + return fmt.Sprintf("%T", c) + } +} diff --git a/internal/cryptutil/x509.go b/internal/cryptutil/x509.go new file mode 100644 index 0000000..6192fda --- /dev/null +++ b/internal/cryptutil/x509.go @@ -0,0 +1,242 @@ +package cryptutil + +import ( + "crypto" + "crypto/ecdsa" + "crypto/ed25519" + "crypto/elliptic" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "errors" + "fmt" + "math/big" + "os" + "path/filepath" + "strings" + "time" + + "git.maze.io/maze/styx/internal/log" +) + +// Supported key types. +const ( + TypeRSA = "rsa" + TypeECDSA = "ecdsa" + TypeED25519 = "ed25519" +) + +// Supported PEM block types. +const ( + pemTypeCert = "CERTIFICATE" + pemTypeRSA = "RSA PRIVATE KEY" + pemTypeECDSA = "EC PRIVATE KEY" + pemTypeAny = "PRIVATE KEY" +) + +// LoadKeyPair loads a certificate and private key, certdata and keydata can be a PEM encoded block or a file. +// +// If [keydata] is empty, then the private key is assumed to be contained in [certdata]. +func LoadKeyPair(certdata, keydata string) (cert *x509.Certificate, key crypto.PrivateKey, err error) { + if keydata == "" { + keydata = certdata + } + if strings.Contains(certdata, "-----BEGIN "+pemTypeCert) { + log.Trace().Msg("parsing X.509 certificate") + if cert, err = decodePEMBCertificate([]byte(certdata)); err != nil { + return + } + } else { + log.Trace().Str("name", certdata).Msg("loading X.509 certificate") + if cert, err = LoadCertificate(certdata); err != nil { + return + } + } + if strings.Contains(keydata, pemTypeAny+"-----") { + log.Trace().Msg("parsing private key") + if key, err = decodePEMPrivateKey([]byte(keydata)); err != nil { + return + } + } else if key, err = LoadPrivateKey(keydata); err != nil { + log.Trace().Str("name", keydata).Msg("loading private key") + return + } + return +} + +// SaveKeyPair saves a certificate and private key in PEM encoding. +// +// If [keyFile] is empty, then the private key is stored in [certFile] alongside the certificate. +// +// Attempts are made to use secure file modes for files that contains private keys. +func SaveKeyPair(cert *x509.Certificate, key crypto.PrivateKey, certFile, keyFile string) (err error) { + var ( + keyDER []byte + keyPEMType = pemTypeAny + ) + switch key := key.(type) { + case *ecdsa.PrivateKey: + if keyDER, err = x509.MarshalECPrivateKey(key); err != nil { + return + } + keyPEMType = pemTypeECDSA + case ed25519.PrivateKey: + if keyDER, err = x509.MarshalPKCS8PrivateKey(key); err != nil { + return + } + case *rsa.PrivateKey: + keyDER = x509.MarshalPKCS1PrivateKey(key) + keyPEMType = pemTypeRSA + default: + return fmt.Errorf("mitm: don't know how to marshal %T", key) + } + + var certf, keyf *os.File + if certf, err = os.OpenFile(certFile, os.O_CREATE|os.O_TRUNC|os.O_RDWR, 0o644); err != nil { + return + } + defer func() { _ = certf.Close() }() + + if filepath.Clean(certFile) == filepath.Clean(keyFile) || keyFile == "" { + if err = certf.Chmod(0o600); err != nil { + return + } + keyf, keyFile = certf, certFile + } else { + if keyf, err = os.OpenFile(keyFile, os.O_CREATE|os.O_TRUNC|os.O_RDWR, 0o600); err != nil { + return + } + defer func() { _ = keyf.Close() }() + } + + log.Debug().Str("file", certFile).Msg("saving X.509 certificate") + if err = pem.Encode(certf, &pem.Block{Type: pemTypeCert, Bytes: cert.Raw}); err != nil { + return + } + + log.Debug().Str("fiile", keyFile).Msg("saving private key") + if err = pem.Encode(keyf, &pem.Block{Type: keyPEMType, Bytes: keyDER}); err != nil { + return + } + return +} + +// GenerateKeyPair generates a private key and self-signed certificate. +func GenerateKeyPair(name pkix.Name, days int, keyType string, keyBits int) (cert *x509.Certificate, key crypto.PrivateKey, err error) { + if key, err = GeneratePrivateKey(keyType, keyBits); err != nil { + return + } + if cert, err = GenerateCertificateAuthority(name, days, key); err != nil { + return + } + return +} + +func GenerateCertificateAuthority(name pkix.Name, days int, key crypto.PrivateKey) (cert *x509.Certificate, err error) { + serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) + serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) + if err != nil { + return nil, fmt.Errorf("mtim: failed to generate serial number: %w", err) + } + + keyUsage := x509.KeyUsageCertSign + if _, ok := key.(*rsa.PrivateKey); ok { + keyUsage |= x509.KeyUsageDigitalSignature + } + + notBefore := roundToDay(time.Now()) + notAfter := notBefore.Add(time.Duration(days) * 24 * time.Hour) + + template := &x509.Certificate{ + Subject: name, + SerialNumber: serialNumber, + KeyUsage: keyUsage, + SubjectKeyId: GenerateKeyID(key), + IsCA: true, + BasicConstraintsValid: true, + NotBefore: notBefore, + NotAfter: notAfter, + } + + log.Info(). + Str("name", name.CommonName). + Int("days", days). + Str("key", keyType(key)). + Str("serial", serialNumber.String()). + Msg("generating self-signed CA certificate") + + var der []byte + if der, err = x509.CreateCertificate(rand.Reader, template, template, PublicKey(key), key); err != nil { + return + } + + return x509.ParseCertificate(der) +} + +func GeneratePrivateKey(kind string, bits int) (key crypto.PrivateKey, err error) { + switch strings.ToLower(kind) { + case TypeRSA, "": + if bits == 0 { + bits = 2048 + } + log.Trace().Int("bits", bits).Str("type", TypeRSA).Msg("generating private key") + return rsa.GenerateKey(rand.Reader, bits) + + case TypeECDSA, "ec", "ecc": + if bits == 0 { + bits = 256 + } + + var curve elliptic.Curve + switch bits { + case 224: + curve = elliptic.P224() + case 256: + curve = elliptic.P256() + case 384: + curve = elliptic.P384() + case 521: + curve = elliptic.P521() + default: + return nil, fmt.Errorf("mitm: elliptic curve %d bits not supported", bits) + } + log.Trace().Int("bits", bits).Str("type", TypeECDSA).Msg("generating private key") + return ecdsa.GenerateKey(curve, rand.Reader) + + case TypeED25519: + log.Trace().Str("type", TypeED25519).Msg("generating ED25519 private key") + _, key, err = ed25519.GenerateKey(rand.Reader) + return + + default: + return nil, fmt.Errorf("mitm: don't know how to generate %s private key", kind) + } +} + +func decodePEMBCertificate(b []byte) (cert *x509.Certificate, err error) { + var ( + rest = b + block *pem.Block + ) + for { + if block, rest = pem.Decode(rest); block == nil { + return nil, errors.New("mitm: no CERTIFICATE PEM block could be decoded") + } else if block.Type == "CERTIFICATE" { + return x509.ParseCertificate(block.Bytes) + } + } +} + +func LoadCertificate(name string) (*x509.Certificate, error) { + b, err := os.ReadFile(name) + if err != nil { + return nil, err + } + return decodePEMBCertificate(b) +} + +func roundToDay(t time.Time) time.Time { + return time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, time.UTC) +} diff --git a/internal/log/log.go b/internal/log/log.go new file mode 100644 index 0000000..11da6f8 --- /dev/null +++ b/internal/log/log.go @@ -0,0 +1,44 @@ +package log + +import ( + "io" + + "github.com/rs/zerolog" +) + +// Aliases +const ( + TraceLevel = zerolog.TraceLevel + DebugLevel = zerolog.DebugLevel + InfoLevel = zerolog.InfoLevel + WarnLevel = zerolog.WarnLevel + FatalLevel = zerolog.FatalLevel +) + +// Aliases +type ( + Event = zerolog.Event + Logger = zerolog.Logger +) + +// Console logger. +var Console = zerolog.New(zerolog.NewConsoleWriter()).With().Timestamp().Logger() + +func SetLevel(level zerolog.Level) { + zerolog.SetGlobalLevel(level) + //Console = Console.Level(level) +} + +func Trace() *Event { return Console.Trace() } +func Debug() *Event { return Console.Debug() } +func Info() *Event { return Console.Info() } +func Warn() *Event { return Console.Warn() } +func Error() *Event { return Console.Error() } +func Fatal() *Event { return Console.Fatal() } +func Panic() *Event { return Console.Panic() } + +func OnCloseError(event *Event, closer io.Closer) { + if err := closer.Close(); err != nil { + event.Err(err).Msg("close failed") + } +} diff --git a/internal/netutil/addr.go b/internal/netutil/addr.go new file mode 100644 index 0000000..44a96e5 --- /dev/null +++ b/internal/netutil/addr.go @@ -0,0 +1,35 @@ +package netutil + +import ( + "net" + "strconv" +) + +// EnsurePort makes sure the address in [host] contains a port. +func EnsurePort(host, port string) string { + if _, _, err := net.SplitHostPort(host); err == nil { + return host + } + return net.JoinHostPort(host, port) +} + +// Host returns the bare host (without port). +func Host(name string) string { + host, _, err := net.SplitHostPort(name) + if err == nil { + return host + } + return name +} + +// Port returns the port number. +func Port(name string) int { + _, port, err := net.SplitHostPort(name) + if err != nil { + return 0 + } + + // TODO: name resolution for ports? + i, _ := strconv.Atoi(port) + return i +} diff --git a/internal/netutil/domain.go b/internal/netutil/domain.go new file mode 100644 index 0000000..f7e7668 --- /dev/null +++ b/internal/netutil/domain.go @@ -0,0 +1,99 @@ +package netutil + +import ( + "strings" + + "github.com/miekg/dns" +) + +type DomainTree struct { + root *domainTreeNode +} + +type domainTreeNode struct { + leaf map[string]*domainTreeNode + isEnd bool +} + +func NewDomainList(domains ...string) *DomainTree { + tree := &DomainTree{ + root: &domainTreeNode{leaf: make(map[string]*domainTreeNode)}, + } + for _, domain := range domains { + tree.Add(domain) + } + return tree +} + +func (tree *DomainTree) Add(domain string) { + domain = normalizeDomain(domain) + if domain == "" { + return + } + + labels := dns.SplitDomainName(domain) + if len(labels) == 0 { + return + } + + node := tree.root + for i := len(labels) - 1; i >= 0; i-- { + label := labels[i] + if label == "" { + continue + } + if node.leaf == nil { + node.leaf = make(map[string]*domainTreeNode) + } + if node.leaf[label] == nil { + node.leaf[label] = &domainTreeNode{} + } + node = node.leaf[label] + } + node.isEnd = true +} + +func (tree *DomainTree) Contains(domain string) bool { + domain = normalizeDomain(domain) + if domain == "" { + return false + } + + labels := dns.SplitDomainName(domain) + if len(labels) == 0 { + return false + } + + node := tree.root + for i := len(labels) - 1; i >= 0; i-- { + if node.isEnd { + return true + } + + if node.leaf == nil { + return false + } + + label := labels[i] + if node = node.leaf[label]; node == nil { + return false + } + } + return node.isEnd +} + +func normalizeDomain(domain string) string { + domain = strings.ToLower(strings.TrimSpace(domain)) + if domain == "" { + return "" + } + + // Remove trailing dot if present, dns.Fqdn will add it back properly + domain = strings.TrimSuffix(domain, ".") + + if domain == "" { + return "" + } + + return dns.Fqdn(domain) +} diff --git a/internal/netutil/domain_test.go b/internal/netutil/domain_test.go new file mode 100644 index 0000000..52582f8 --- /dev/null +++ b/internal/netutil/domain_test.go @@ -0,0 +1,276 @@ +package netutil + +import ( + "testing" +) + +func TestDomainList(t *testing.T) { + tests := []struct { + name string + domains []string + hostname string + expected bool + }{ + // Basic exact matches + { + name: "exact match", + domains: []string{"example.com"}, + hostname: "example.com", + expected: true, + }, + { + name: "exact match with subdomain in list", + domains: []string{"api.example.com"}, + hostname: "api.example.com", + expected: true, + }, + + // Suffix matching - if domain is in list, all subdomains should match + { + name: "subdomain matches parent domain", + domains: []string{"example.com"}, + hostname: "sub.example.com", + expected: true, + }, + { + name: "multiple subdomain levels match", + domains: []string{"example.com"}, + hostname: "deep.nested.sub.example.com", + expected: true, + }, + { + name: "subdomain matches intermediate domain", + domains: []string{"api.example.com", "example.com"}, + hostname: "sub.api.example.com", + expected: true, + }, + + // Multi-level TLDs + { + name: "co.uk domain exact match", + domains: []string{"domain.co.uk"}, + hostname: "domain.co.uk", + expected: true, + }, + { + name: "subdomain of co.uk domain", + domains: []string{"domain.co.uk"}, + hostname: "sub.domain.co.uk", + expected: true, + }, + + // Case sensitivity + { + name: "case insensitive match", + domains: []string{"Example.COM"}, + hostname: "example.com", + expected: true, + }, + { + name: "case insensitive hostname", + domains: []string{"example.com"}, + hostname: "EXAMPLE.COM", + expected: true, + }, + + // Trailing dots + { + name: "domain with trailing dot", + domains: []string{"example.com."}, + hostname: "example.com", + expected: true, + }, + { + name: "hostname with trailing dot", + domains: []string{"example.com"}, + hostname: "example.com.", + expected: true, + }, + + // Non-matches + { + name: "different TLD", + domains: []string{"example.com"}, + hostname: "example.org", + expected: false, + }, + { + name: "different domain", + domains: []string{"example.com"}, + hostname: "test.com", + expected: false, + }, + { + name: "partial match but not suffix", + domains: []string{"example.com"}, + hostname: "com", + expected: false, + }, + { + name: "empty hostname", + domains: []string{"example.com"}, + hostname: "", + expected: false, + }, + + // Multiple domains in list + { + name: "matches first domain in list", + domains: []string{"test.org", "example.com"}, + hostname: "example.com", + expected: true, + }, + { + name: "matches second domain in list", + domains: []string{"test.org", "example.com"}, + hostname: "test.org", + expected: true, + }, + { + name: "subdomain matches any domain in list", + domains: []string{"test.org", "example.com"}, + hostname: "sub.example.com", + expected: true, + }, + + // Edge cases + { + name: "empty domain list", + domains: []string{}, + hostname: "example.com", + expected: false, + }, + { + name: "invalid domain in list", + domains: []string{""}, + hostname: "example.com", + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + list := NewDomainList(tt.domains...) + result := list.Contains(tt.hostname) + + if result != tt.expected { + t.Errorf("Contains(%q) = %v, expected %v (domains: %v)", + tt.hostname, result, tt.expected, tt.domains) + } + }) + } +} + +func TestDomainList_Performance(t *testing.T) { + // Test with a large number of domains to ensure performance + domains := make([]string, 1000) + for i := 0; i < 1000; i++ { + domains[i] = string(rune('a'+(i%26))) + ".com" + } + domains = append(domains, "example.com") // Add our test domain + + list := NewDomainList(domains...) + + // These should be fast even with many domains + if !list.Contains("example.com") { + t.Error("Should match exact domain") + } + if !list.Contains("sub.example.com") { + t.Error("Should match subdomain") + } + if list.Contains("notfound.com") { + t.Error("Should not match unrelated domain") + } +} + +func TestDomainList_ComplexDomains(t *testing.T) { + domains := []string{ + "very.long.domain.name.with.many.labels.com", + "example.co.uk", + "sub.domain.example.com", + "a.b.c.d.e.f.com", + } + + list := NewDomainList(domains...) + + tests := []struct { + hostname string + expected bool + }{ + {"very.long.domain.name.with.many.labels.com", true}, + {"sub.very.long.domain.name.with.many.labels.com", true}, + {"example.co.uk", true}, + {"www.example.co.uk", true}, + {"sub.domain.example.com", true}, + {"another.sub.domain.example.com", true}, + {"a.b.c.d.e.f.com", true}, + {"x.a.b.c.d.e.f.com", true}, + {"not.matching.com", false}, + {"com", false}, + {"uk", false}, + } + + for _, tt := range tests { + t.Run(tt.hostname, func(t *testing.T) { + result := list.Contains(tt.hostname) + if result != tt.expected { + t.Errorf("Contains(%q) = %v, expected %v", tt.hostname, result, tt.expected) + } + }) + } +} + +func TestDomainList_SpecialCases(t *testing.T) { + t.Run("domain with asterisk treated literally", func(t *testing.T) { + list := NewDomainList("*.example.com") + + // The asterisk should be treated as a literal label, not a wildcard + if !list.Contains("*.example.com") { + t.Error("Asterisk should be treated literally, not as wildcard") + } + if list.Contains("test.example.com") { + t.Error("Should not match subdomain with literal asterisk domain") + } + }) + + t.Run("domains with hyphens and numbers", func(t *testing.T) { + list := NewDomainList("test-123.example.com", "123abc.org") + + if !list.Contains("test-123.example.com") { + t.Error("Should match domain with hyphens and numbers") + } + if !list.Contains("sub.test-123.example.com") { + t.Error("Should match subdomain of hyphenated domain") + } + if !list.Contains("123abc.org") { + t.Error("Should match domain starting with numbers") + } + if !list.Contains("www.123abc.org") { + t.Error("Should match subdomain of numeric domain") + } + }) +} + +func BenchmarkDomainList(b *testing.B) { + // Benchmark with realistic domain list + domains := []string{ + "google.com", + "github.com", + "example.org", + "sub.domain.com", + "api.service.co.uk", + "very.long.domain.name.example.com", + } + + list := NewDomainList(domains...) + + b.ResetTimer() + for b.Loop() { + // Mix of matches and non-matches + list.Contains("sub.example.org") + list.Contains("api.github.com") + list.Contains("nonexistent.com") + list.Contains("deep.nested.sub.domain.com") + list.Contains("service.co.uk") + } +} diff --git a/internal/netutil/network.go b/internal/netutil/network.go new file mode 100644 index 0000000..a6ea455 --- /dev/null +++ b/internal/netutil/network.go @@ -0,0 +1,44 @@ +package netutil + +import ( + "net" + + "github.com/yl2chen/cidranger" +) + +type NetworkTree struct { + ranger cidranger.Ranger +} + +func NewNetworkTree(networks ...string) (*NetworkTree, error) { + tree := &NetworkTree{ + ranger: cidranger.NewPCTrieRanger(), + } + for _, cidr := range networks { + if err := tree.AddCIDR(cidr); err != nil { + return nil, err + } + } + return tree, nil +} + +func (tree *NetworkTree) Add(ipnet *net.IPNet) { + if ipnet == nil { + return + } + tree.ranger.Insert(cidranger.NewBasicRangerEntry(*ipnet)) +} + +func (tree *NetworkTree) AddCIDR(cidr string) error { + _, ipnet, err := net.ParseCIDR(cidr) + if err != nil { + return err + } + tree.ranger.Insert(cidranger.NewBasicRangerEntry(*ipnet)) + return nil +} + +func (tree *NetworkTree) Contains(ip net.IP) bool { + contains, _ := tree.ranger.Contains(ip) + return contains +} diff --git a/internal/netutil/network_test.go b/internal/netutil/network_test.go new file mode 100644 index 0000000..63d7572 --- /dev/null +++ b/internal/netutil/network_test.go @@ -0,0 +1,410 @@ +package netutil + +import ( + "net" + "testing" +) + +func TestNewNetworkTree(t *testing.T) { + // Test empty creation + nl, err := NewNetworkTree() + if err != nil { + t.Fatalf("NewNetworkTree() failed: %v", err) + } + if nl == nil { + t.Fatal("NewNetworkTree() returned nil") + } + if nl.ranger == nil { + t.Error("NetworkTree ranger should not be nil") + } + + // Test creation with networks + nl, err = NewNetworkTree("192.168.1.0/24", "10.0.0.0/8") + if err != nil { + t.Fatalf("NewNetworkTree() with networks failed: %v", err) + } + if nl == nil { + t.Fatal("NewNetworkTree() with networks returned nil") + } +} + +func TestNewNetworkTree_InvalidNetworks(t *testing.T) { + // Test with invalid network + _, err := NewNetworkTree("invalid-cidr") + if err == nil { + t.Error("NewNetworkTree() with invalid CIDR should have failed") + } + + // Test with mix of valid and invalid networks + _, err = NewNetworkTree("192.168.1.0/24", "invalid-cidr", "10.0.0.0/8") + if err == nil { + t.Error("NewNetworkTree() with mixed valid/invalid CIDRs should have failed") + } +} + +func TestNetworkTree_AddCIDR_Valid(t *testing.T) { + nl, err := NewNetworkTree() + if err != nil { + t.Fatalf("NewNetworkTree() failed: %v", err) + } + + tests := []struct { + cidr string + desc string + }{ + {"192.168.1.0/24", "IPv4 CIDR"}, + {"10.0.0.0/8", "IPv4 large range"}, + {"2001:db8::/32", "IPv6 CIDR"}, + {"::1/128", "IPv6 localhost"}, + {"0.0.0.0/0", "IPv4 entire internet"}, + {"::/0", "IPv6 entire internet"}, + } + + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + if err := nl.AddCIDR(tt.cidr); err != nil { + t.Errorf("AddCIDR(%q) failed: %v", tt.cidr, err) + } + }) + } +} + +func TestNetworkTree_AddCIDR_Invalid(t *testing.T) { + nl, err := NewNetworkTree() + if err != nil { + t.Fatalf("NewNetworkTree() failed: %v", err) + } + + invalidCIDRs := []string{ + "invalid-cidr", + "192.168.1.1", // missing mask + "192.168.1.0/33", // invalid mask for IPv4 + "2001:db8::/129", // invalid mask for IPv6 + "", + "not-an-ip/24", + } + + for _, cidr := range invalidCIDRs { + t.Run(cidr, func(t *testing.T) { + if err := nl.AddCIDR(cidr); err == nil { + t.Errorf("AddCIDR(%q) should have failed but didn't", cidr) + } + }) + } +} + +func TestNetworkTree_Add(t *testing.T) { + nl, err := NewNetworkTree() + if err != nil { + t.Fatalf("NewNetworkTree() failed: %v", err) + } + + tests := []struct { + cidr string + desc string + }{ + {"192.168.1.0/24", "IPv4 network"}, + {"2001:db8::/32", "IPv6 network"}, + } + + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + _, ipNet, err := net.ParseCIDR(tt.cidr) + if err != nil { + t.Fatalf("ParseCIDR failed: %v", err) + } + + // Should not panic + nl.Add(ipNet) + }) + } +} + +func TestNetworkTree_Contains_IPv4(t *testing.T) { + nl, err := NewNetworkTree("192.168.1.0/24", "10.0.0.0/8", "172.16.0.0/12") + if err != nil { + t.Fatalf("NewNetworkTree() failed: %v", err) + } + + tests := []struct { + ip string + want bool + desc string + }{ + // IPs that should match + {"192.168.1.1", true, "in 192.168.1.0/24"}, + {"192.168.1.255", true, "broadcast in 192.168.1.0/24"}, + {"10.0.0.1", true, "in 10.0.0.0/8"}, + {"10.255.255.255", true, "max in 10.0.0.0/8"}, + {"172.16.0.1", true, "in 172.16.0.0/12"}, + {"172.31.255.255", true, "max in 172.16.0.0/12"}, + + // IPs that should not match + {"192.168.2.1", false, "outside 192.168.1.0/24"}, + {"11.0.0.1", false, "outside 10.0.0.0/8"}, + {"172.32.0.1", false, "outside 172.16.0.0/12"}, + {"8.8.8.8", false, "public DNS"}, + } + + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + ip := net.ParseIP(tt.ip) + if ip == nil { + t.Fatalf("ParseIP(%q) returned nil", tt.ip) + } + + got := nl.Contains(ip) + if got != tt.want { + t.Errorf("Contains(%q) = %v, want %v", tt.ip, got, tt.want) + } + }) + } +} + +func TestNetworkTree_Contains_IPv6(t *testing.T) { + nl, err := NewNetworkTree("2001:db8::/32", "2001:db8:abcd::/48", "::1/128") + if err != nil { + t.Fatalf("NewNetworkTree() failed: %v", err) + } + + tests := []struct { + ip string + want bool + desc string + }{ + // IPs that should match + {"2001:db8::1", true, "in 2001:db8::/32"}, + {"2001:db8:ffff:ffff:ffff:ffff:ffff:ffff", true, "max in 2001:db8::/32"}, + {"2001:db8:abcd::1", true, "in 2001:db8:abcd::/48"}, + {"::1", true, "localhost"}, + + // IPs that should not match + {"2001:db9::1", false, "outside 2001:db8::/32"}, + {"2001:db9:abcd::1", false, "outside 2001:db8:abcd::/48"}, + {"::2", false, "outside ::1/128"}, + {"2001:4860:4860::8888", false, "public DNS"}, + } + + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + ip := net.ParseIP(tt.ip) + if ip == nil { + t.Fatalf("ParseIP(%q) returned nil", tt.ip) + } + + got := nl.Contains(ip) + if got != tt.want { + t.Errorf("Contains(%q) = %v, want %v", tt.ip, got, tt.want) + } + }) + } +} + +func TestNetworkTree_Contains_EdgeCases(t *testing.T) { + nl, err := NewNetworkTree() + if err != nil { + t.Fatalf("NewNetworkTree() failed: %v", err) + } + + // Test with nil IP + if nl.Contains(nil) != false { + t.Error("Contains(nil) should return false") + } + + // Test with empty list + ip := net.ParseIP("192.168.1.1") + if nl.Contains(ip) != false { + t.Error("Contains() on empty list should return false") + } +} + +func TestNetworkTree_Contains_OverlappingRanges(t *testing.T) { + nl, err := NewNetworkTree("192.168.0.0/16", "192.168.1.0/24", "192.168.1.128/25") + if err != nil { + t.Fatalf("NewNetworkTree() failed: %v", err) + } + + // All these should match because we have overlapping ranges + tests := []string{ + "192.168.1.1", + "192.168.1.129", + "192.168.2.1", + } + + for _, ipStr := range tests { + t.Run(ipStr, func(t *testing.T) { + ip := net.ParseIP(ipStr) + if !nl.Contains(ip) { + t.Errorf("Contains(%q) should return true for overlapping ranges", ipStr) + } + }) + } +} + +func TestNetworkTree_Contains_EntireInternet(t *testing.T) { + nl, err := NewNetworkTree("0.0.0.0/0", "::/0") + if err != nil { + t.Fatalf("NewNetworkTree() failed: %v", err) + } + + tests := []struct { + ip string + desc string + }{ + {"192.168.1.1", "IPv4 private"}, + {"8.8.8.8", "IPv4 public"}, + {"2001:db8::1", "IPv6"}, + {"::1", "IPv6 localhost"}, + } + + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + ip := net.ParseIP(tt.ip) + if !nl.Contains(ip) { + t.Errorf("Contains(%q) should return true for entire internet range", tt.ip) + } + }) + } +} + +func TestNetworkTree_MixedIPv4AndIPv6(t *testing.T) { + nl, err := NewNetworkTree("192.168.1.0/24", "2001:db8::/32") + if err != nil { + t.Fatalf("NewNetworkTree() failed: %v", err) + } + + // Test IPv4 in IPv6 format (should still work due to normalization) + ipv4InIPv6 := net.ParseIP("::ffff:192.168.1.1") // IPv4-mapped IPv6 + if !nl.Contains(ipv4InIPv6) { + t.Error("Contains() should handle IPv4-mapped IPv6 addresses") + } + + // Regular IPv4 should work + ipv4 := net.ParseIP("192.168.1.1") + if !nl.Contains(ipv4) { + t.Error("Contains() should handle regular IPv4 addresses") + } + + // IPv6 should work + ipv6 := net.ParseIP("2001:db8::1") + if !nl.Contains(ipv6) { + t.Error("Contains() should handle IPv6 addresses") + } +} + +func TestNetworkTree_Add_InvalidIPNet(t *testing.T) { + nl, err := NewNetworkTree() + if err != nil { + t.Fatalf("NewNetworkTree() failed: %v", err) + } + + // Create an invalid IPNet (nil IP) + invalidIPNet := &net.IPNet{ + IP: nil, + Mask: net.CIDRMask(24, 32), + } + + // This should not panic + nl.Add(invalidIPNet) + + // Verify that it doesn't affect Contains results + ip := net.ParseIP("192.168.1.1") + if nl.Contains(ip) { + t.Error("Contains() should return false after adding invalid IPNet") + } +} + +func TestNetworkTree_InitializationWithNetworks(t *testing.T) { + networks := []string{ + "10.0.0.0/8", + "172.16.0.0/12", + "192.168.0.0/16", + "2001:db8::/32", + } + + nl, err := NewNetworkTree(networks...) + if err != nil { + t.Fatalf("NewNetworkTree() with multiple networks failed: %v", err) + } + + // Test that all networks were added correctly + testCases := []struct { + ip string + want bool + }{ + {"10.1.2.3", true}, + {"172.16.1.1", true}, + {"192.168.1.1", true}, + {"2001:db8::1", true}, + {"8.8.8.8", false}, + } + + for _, tc := range testCases { + ip := net.ParseIP(tc.ip) + if got := nl.Contains(ip); got != tc.want { + t.Errorf("Contains(%q) = %v, want %v", tc.ip, got, tc.want) + } + } +} + +func BenchmarkNetworkTree_Contains(b *testing.B) { + nl, err := NewNetworkTree( + "10.0.0.0/8", + "172.16.0.0/12", + "192.168.0.0/16", + "2001:db8::/32", + ) + if err != nil { + b.Fatalf("NewNetworkTree() failed: %v", err) + } + + testIPs := []net.IP{ + net.ParseIP("10.1.2.3"), + net.ParseIP("192.168.1.1"), + net.ParseIP("2001:db8::1"), + net.ParseIP("8.8.8.8"), + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + ip := testIPs[i%len(testIPs)] + nl.Contains(ip) + } +} + +func BenchmarkNetworkTree_NewNetworkTree(b *testing.B) { + cidrs := []string{ + "10.0.0.0/8", + "172.16.0.0/12", + "192.168.0.0/16", + "2001:db8::/32", + } + + b.ResetTimer() + for b.Loop() { + _, err := NewNetworkTree(cidrs...) + if err != nil { + b.Fatalf("NewNetworkTree() failed: %v", err) + } + } +} + +func BenchmarkNetworkTree_AddCIDR(b *testing.B) { + cidrs := []string{ + "10.0.0.0/8", + "172.16.0.0/12", + "192.168.0.0/16", + "2001:db8::/32", + } + + b.ResetTimer() + for b.Loop() { + nl, err := NewNetworkTree() + if err != nil { + b.Fatalf("NewNetworkTree() failed: %v", err) + } + for _, cidr := range cidrs { + nl.AddCIDR(cidr) + } + } +} diff --git a/proxy/admin.go b/proxy/admin.go new file mode 100644 index 0000000..b04f652 --- /dev/null +++ b/proxy/admin.go @@ -0,0 +1,145 @@ +package proxy + +import ( + "bytes" + "encoding/json" + "encoding/pem" + "errors" + "net/http" + "os" + "strconv" + "strings" + "time" + + "git.maze.io/maze/styx/internal/log" +) + +type Admin struct { + *Proxy +} + +func NewAdmin(proxy *Proxy) *Admin { + a := &Admin{ + Proxy: proxy, + } + return a +} + +func (a *Admin) handleRequest(ses *Session) error { + var ( + logger = ses.log() + err error + ) + switch ses.request.URL.Path { + case "/ca.crt": + err = a.handleCACert(ses) + case "/api/v1/policy": + err = a.apiPolicy(ses) + case "/api/v1/policy/matcher": + err = a.apiPolicyMatcher(ses) + case "/api/v1/stats/log": + err = a.apiStatsLog(ses) + case "/api/v1/stats/status": + err = a.apiStatsStatus(ses) + default: + if strings.HasPrefix(ses.request.URL.Path, "/api") { + err = errors.New("invalid endpoint") + } else { + err = os.ErrNotExist + } + } + if err != nil { + logger.Warn().Err(err).Msg("admin error") + ses.response = ErrorResponse(ses.request, err) + defer log.OnCloseError(logger.Debug(), ses.response.Body) + ses.response.Close = true + return a.writeResponse(ses) + } + return err +} + +func (a *Admin) handleCACert(ses *Session) error { + b := pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: a.authority.Certificate().Raw, + }) + + ses.response = NewResponse(http.StatusOK, bytes.NewReader(b), ses.request) + defer log.OnCloseError(log.Debug(), ses.response.Body) + + ses.response.Close = true + ses.response.Header.Set("Content-Type", "application/x-x509-ca-cert") + ses.response.ContentLength = int64(len(b)) + return a.writeResponse(ses) +} + +func (a *Admin) apiPolicy(ses *Session) error { + var ( + b = new(bytes.Buffer) + e = json.NewEncoder(b) + ) + e.SetIndent("", " ") + if err := e.Encode(a.config.Policy); err != nil { + return err + } + + ses.response = NewJSONResponse(http.StatusOK, b, ses.request) + defer log.OnCloseError(log.Debug(), ses.response.Body) + ses.response.Close = true + return a.writeResponse(ses) +} + +func (a *Admin) apiPolicyMatcher(ses *Session) error { + var ( + b = new(bytes.Buffer) + e = json.NewEncoder(b) + ) + e.SetIndent("", " ") + if err := e.Encode(a.config.Policy.Matchers); err != nil { + return err + } + + ses.response = NewJSONResponse(http.StatusOK, b, ses.request) + defer log.OnCloseError(log.Debug(), ses.response.Body) + ses.response.Close = true + return a.writeResponse(ses) +} + +func (a *Admin) apiResponse(ses *Session, v any, err error) error { + if err != nil { + return err + } + var ( + b = new(bytes.Buffer) + e = json.NewEncoder(b) + ) + e.SetIndent("", " ") + if err := e.Encode(v); err != nil { + return err + } + + ses.response = NewJSONResponse(http.StatusOK, b, ses.request) + defer log.OnCloseError(log.Debug(), ses.response.Body) + ses.response.Close = true + return a.writeResponse(ses) + +} + +func (a *Admin) apiStatsLog(ses *Session) error { + var ( + query = ses.request.URL.Query() + offset, _ = strconv.Atoi(query.Get("offset")) + limit, _ = strconv.Atoi(query.Get("limit")) + ) + if limit > 100 { + limit = 100 + } + + s, err := a.stats.QueryLog(offset, limit) + return a.apiResponse(ses, s, err) +} + +func (a *Admin) apiStatsStatus(ses *Session) error { + s, err := a.stats.QueryStatus(time.Time{}) + return a.apiResponse(ses, s, err) +} diff --git a/proxy/cache/config.go b/proxy/cache/config.go new file mode 100644 index 0000000..cb46c72 --- /dev/null +++ b/proxy/cache/config.go @@ -0,0 +1,8 @@ +package cache + +import "github.com/hashicorp/hcl/v2" + +type Config struct { + Type string `hcl:"type"` + Body hcl.Body `hcl:",remain"` +} diff --git a/proxy/config.go b/proxy/config.go new file mode 100644 index 0000000..4e5c62c --- /dev/null +++ b/proxy/config.go @@ -0,0 +1,88 @@ +package proxy + +import ( + "net" + "net/http" + "time" + + "git.maze.io/maze/styx/proxy/policy" + "git.maze.io/maze/styx/proxy/resolver" +) + +type ConnectHandler interface { + HandleConnect(session *Session, network, address string) net.Conn +} + +// ConnectHandlerFunc is called when the proxy receives a new HTTP CONNECT request. +type ConnectHandlerFunc func(session *Session, network, address string) net.Conn + +func (f ConnectHandlerFunc) HandleConnect(session *Session, network, address string) net.Conn { + return f(session, network, address) +} + +type RequestHandler interface { + HandleRequest(session *Session) (*http.Request, *http.Response) +} + +// RequestHandlerFunc is called when the proxy receives a new request. +type RequestHandlerFunc func(session *Session) (*http.Request, *http.Response) + +func (f RequestHandlerFunc) HandleRequest(session *Session) (*http.Request, *http.Response) { + return f(session) +} + +type ResponseHandler interface { + HandleResponse(session *Session) *http.Response +} + +// ResponseHandler is called when the proxy receives a response. +type ResponseHandlerFunc func(session *Session) *http.Response + +func (f ResponseHandlerFunc) HandleResponse(session *Session) *http.Response { + return f(session) +} + +type ErrorHandler interface { + HandleError(session *Session, err error) +} + +type ErrorHandlerFunc func(session *Session, err error) + +func (f ErrorHandlerFunc) HandleError(session *Session, err error) { + f(session, err) +} + +type Config struct { + // Listen address. + Listen string `hcl:"listen,optional"` + + // Bind address for outgoing connections. + Bind string `hcl:"bind,optional"` + + // Interface for outgoing connections. + Interface string `hcl:"interface,optional"` + + // Upstream proxy servers. + Upstream []string `hcl:"upstream,optional"` + + // DialTimeout for establishing new connections. + DialTimeout time.Duration `hcl:"dial_timeout,optional"` + + // Policy for the proxy. + Policy *policy.Policy `hcl:"policy,block"` + + // Resolver for the proxy. + Resolver resolver.Resolver + + ConnectHandler ConnectHandler + RequestHandler RequestHandler + ResponseHandler ResponseHandler + ErrorHandler ErrorHandler +} + +var ( + _ ConnectHandler = (ConnectHandlerFunc)(nil) + _ RequestHandler = (RequestHandlerFunc)(nil) + _ ResponseHandler = (ResponseHandlerFunc)(nil) + _ ErrorHandler = (ErrorHandlerFunc)(nil) +) diff --git a/proxy/match/config.go b/proxy/match/config.go new file mode 100644 index 0000000..832a207 --- /dev/null +++ b/proxy/match/config.go @@ -0,0 +1,324 @@ +package match + +import ( + "fmt" + "net" + "net/http" + "os" + "regexp" + "slices" + "strconv" + "strings" + "time" + + "git.maze.io/maze/styx/internal/log" + "git.maze.io/maze/styx/internal/netutil" + "github.com/hashicorp/hcl/v2" + "github.com/hashicorp/hcl/v2/gohcl" +) + +type Config struct { + Path string `hcl:"path,optional"` + Refresh time.Duration `hcl:"refresh,optional"` + Domain []*Domain `hcl:"domain,block"` + Network []*Network `hcl:"network,block"` + Content []*Content `hcl:"content,block"` +} + +func (config Config) Matchers() (Matchers, error) { + all := make(Matchers) + if config.Domain != nil { + all["domain"] = make(map[string]Matcher) + for _, domain := range config.Domain { + m, err := domain.Matcher() + if err != nil { + return nil, fmt.Errorf("matcher domain %q invalid: %w", domain.Name, err) + } + all["domain"][domain.Name] = m + } + } + if config.Network != nil { + all["network"] = make(map[string]Matcher) + for _, network := range config.Network { + m, err := network.Matcher(true) + if err != nil { + return nil, fmt.Errorf("matcher network %q invalid: %w", network.Name, err) + } + all["network"][network.Name] = m + } + } + return all, nil +} + +type Content struct { + Name string `hcl:"name,label"` + Type string `hcl:"type"` + Body hcl.Body `hcl:",remain"` +} + +type contentHeader struct { + Key string `hcl:"name"` + Value string `hcl:"value,optional"` + List []string `hcl:"list,optional"` + name string + keyRe *regexp.Regexp + valueRe *regexp.Regexp +} + +func (m contentHeader) Name() string { return m.name } +func (m contentHeader) MatchesResponse(r *http.Response) bool { + for k, vv := range r.Header { + if m.keyRe.MatchString(k) { + for _, v := range vv { + if slices.Contains(m.List, v) { + return true + } + if m.valueRe != nil && m.valueRe.MatchString(v) { + return true + } + } + } + } + return false +} + +type contentType struct { + List []string `hcl:"list"` + name string +} + +func (m contentType) Name() string { return m.name } +func (m contentType) MatchesResponse(r *http.Response) bool { + return slices.Contains(m.List, r.Header.Get("Content-Type")) +} + +type contentSizeLargerThan struct { + Size int64 `hcl:"size"` + name string +} + +func (m contentSizeLargerThan) Name() string { return m.name } +func (m contentSizeLargerThan) MatchesResponse(r *http.Response) bool { + size, err := strconv.ParseInt(r.Header.Get("Content-Length"), 10, 64) + if err != nil { + return false + } + return size >= m.Size +} + +type contentStatus struct { + Code []int `hcl:"code"` + name string +} + +func (m contentStatus) Name() string { return m.name } +func (m contentStatus) MatchesResponse(r *http.Response) bool { + return slices.Contains(m.Code, r.StatusCode) +} + +func (config Content) Matcher() (Response, error) { + switch strings.ToLower(config.Type) { + case "content", "contenttype", "content-type", "type": + var matcher = contentType{name: config.Name} + if err := gohcl.DecodeBody(config.Body, nil, &matcher); err != nil { + return nil, err + } + return matcher, nil + + case "header": + var ( + matcher = contentHeader{name: config.Name} + err error + ) + if err = gohcl.DecodeBody(config.Body, nil, &matcher); err != nil { + return nil, err + } + if matcher.Value == "" && len(matcher.List) == 0 { + return nil, fmt.Errorf("invalid content %q: must contain either list or value", config.Name) + } + if matcher.keyRe, err = regexp.Compile(matcher.Key); err != nil { + return nil, fmt.Errorf("invalid regular expression on content %q key: %w", config.Name, err) + } + if matcher.Value != "" { + if matcher.valueRe, err = regexp.Compile(matcher.Value); err != nil { + return nil, fmt.Errorf("invalid regular expression on content %q value: %w", config.Name, err) + } + } + return matcher, nil + + case "size": + var matcher = contentSizeLargerThan{name: config.Name} + if err := gohcl.DecodeBody(config.Body, nil, &matcher); err != nil { + return nil, err + } + return matcher, nil + + case "status": + var matcher = contentStatus{name: config.Name} + if err := gohcl.DecodeBody(config.Body, nil, &matcher); err != nil { + return nil, err + } + return matcher, nil + + default: + return nil, fmt.Errorf("unknown content matcher type %q", config.Type) + } +} + +type Domain struct { + Name string `hcl:"name,label"` + Type string `hcl:"type"` + Body hcl.Body `hcl:",remain"` +} + +func (config Domain) Matcher() (Request, error) { + switch config.Type { + case "list": + var matcher = domainList{Title: config.Name} + if err := gohcl.DecodeBody(config.Body, nil, &matcher); err != nil { + return nil, err + } + matcher.list = netutil.NewDomainList(matcher.List...) + return matcher, nil + + case "adblock", "dnsmasq", "hosts", "detect", "domains": + var matcher = DomainFile{ + Title: config.Name, + Type: config.Type, + } + if err := gohcl.DecodeBody(config.Body, nil, &matcher); err != nil { + return nil, err + } + if matcher.Path == "" && matcher.From == "" { + return nil, fmt.Errorf("matcher: domain %q must have either file or from configured", config.Name) + } + if err := matcher.Update(); err != nil { + return nil, err + } + return matcher, nil + + default: + return nil, fmt.Errorf("unknown domain matcher type %q", config.Type) + } + +} + +type domainList struct { + Title string `json:"title"` + List []string `hcl:"list" json:"list"` + list *netutil.DomainTree +} + +func (m domainList) Name() string { + return m.Title +} + +func (m domainList) MatchesRequest(r *http.Request) bool { + host := netutil.Host(r.URL.Host) + log.Debug().Str("host", host).Msgf("match domain list (%d domains)", len(m.List)) + return m.list.Contains(host) +} + +type DomainFile struct { + Title string `json:"name"` + Type string `json:"type"` + Path string `hcl:"path,optional" json:"path,omitempty"` + From string `hcl:"from,optional" json:"from,omitempty"` + Refresh time.Duration `hcl:"refresh,optional" json:"refresh"` +} + +func (m DomainFile) Name() string { + return m.Title +} + +func (m DomainFile) MatchesRequest(_ *http.Request) bool { + return false +} + +func (m *DomainFile) Update() (err error) { + var data []byte + if m.Path != "" { + if data, err = os.ReadFile(m.Path); err != nil { + return + } + } else { + /* + var response *http.Response + if response, err = http.DefaultClient.Get(m.From); err != nil { + return + } + defer func() { _ = response.Body.Close() }() + if response.StatusCode != http.StatusOK { + return fmt.Errorf("match: domain %q update failed: %s", m.name, response.Status) + } + if data, err = io.ReadAll(response.Body); err != nil { + return + } + */ + } + + switch m.Type { + case "hosts": + } + + _ = data + return nil +} + +type Network struct { + Name string `hcl:"name,label"` + Type string `hcl:"type"` + Body hcl.Body `hcl:",remain"` +} + +func (config *Network) Matcher(target bool) (Matcher, error) { + switch config.Type { + case "list": + var ( + matcher = networkList{Title: config.Name} + err error + ) + if diag := gohcl.DecodeBody(config.Body, nil, &matcher); diag.HasErrors() { + return nil, diag + } + if matcher.tree, err = netutil.NewNetworkTree(matcher.List...); err != nil { + return nil, err + } + return &matcher, nil + + default: + return nil, fmt.Errorf("unknown network matcher type %q", config.Type) + } +} + +type networkList struct { + Title string `json:"name"` + List []string `hcl:"list" json:"list"` + tree *netutil.NetworkTree + target bool +} + +func (m *networkList) Name() string { + return m.Title +} + +func (m *networkList) MatchesIP(ip net.IP) bool { + return m.tree.Contains(ip) +} + +func (m *networkList) MatchesRequest(r *http.Request) bool { + var ( + host string + err error + ) + if m.target { + host, _, err = net.SplitHostPort(r.URL.Host) + } else { + host, _, err = net.SplitHostPort(r.RemoteAddr) + } + if err != nil { + return false + } + ip := net.ParseIP(host) + return m.MatchesIP(ip) +} diff --git a/proxy/match/match.go b/proxy/match/match.go new file mode 100644 index 0000000..98a5e17 --- /dev/null +++ b/proxy/match/match.go @@ -0,0 +1,45 @@ +package match + +import ( + "fmt" + "net" + "net/http" +) + +type Matchers map[string]map[string]Matcher + +func (all Matchers) Get(kind, name string) (m Matcher, err error) { + if typeMatchers, ok := all[kind]; ok { + if m, ok = typeMatchers[name]; ok { + return + } + return nil, fmt.Errorf("no %s matcher named %q found", kind, name) + } + return nil, fmt.Errorf("no %s matcher found", kind) +} + +type Matcher interface { + Name() string +} + +type Updater interface { + Update() error +} + +type IP interface { + Matcher + + MatchesIP(net.IP) bool +} + +type Request interface { + Matcher + + MatchesRequest(*http.Request) bool +} + +type Response interface { + Matcher + + MatchesResponse(*http.Response) bool +} diff --git a/proxy/match/util.go b/proxy/match/util.go new file mode 100644 index 0000000..210b1e5 --- /dev/null +++ b/proxy/match/util.go @@ -0,0 +1,11 @@ +package match + +import "net" + +func onlyHost(name string) string { + host, _, err := net.SplitHostPort(name) + if err != nil { + return name + } + return host +} diff --git a/proxy/mitm/authority.go b/proxy/mitm/authority.go new file mode 100644 index 0000000..37b55ca --- /dev/null +++ b/proxy/mitm/authority.go @@ -0,0 +1,231 @@ +package mitm + +import ( + "crypto" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "errors" + "fmt" + "math/big" + "os" + "strings" + "time" + + "git.maze.io/maze/styx/internal/cryptutil" + "git.maze.io/maze/styx/internal/log" + "github.com/miekg/dns" +) + +const DefaultValidity = 24 * time.Hour + +type Authority interface { + Certificate() *x509.Certificate + TLSConfig(name string) *tls.Config +} + +type authority struct { + pool *x509.CertPool + cert *x509.Certificate + key crypto.PrivateKey + keyID []byte + keyPool chan crypto.PrivateKey + cache Cache +} + +func New(config *Config) (Authority, error) { + cache, err := NewCache(config.Cache) + if err != nil { + return nil, err + } + + caConfig := config.CA + if caConfig == nil { + caConfig = new(CAConfig) + } + + cert, key, err := cryptutil.LoadKeyPair(caConfig.Cert, caConfig.Key) + if os.IsNotExist(err) { + days := caConfig.Days + if days == 0 { + days = DefaultDays + } + if cert, key, err = cryptutil.GenerateKeyPair(caConfig.DN(), days, caConfig.KeyType, caConfig.Bits); err != nil { + return nil, err + } + if strings.ContainsRune(caConfig.Cert, os.PathSeparator) { + if err = cryptutil.SaveKeyPair(cert, key, caConfig.Cert, caConfig.Key); err != nil { + return nil, err + } + } + } else if err != nil { + return nil, err + } + + pool := x509.NewCertPool() + pool.AddCert(cert) + + keyConfig := config.Key + if keyConfig == nil { + keyConfig = &defaultKeyConfig + } + + keyPoolSize := defaultKeyConfig.Pool + if keyConfig.Pool > 0 { + keyPoolSize = keyConfig.Pool + } + keyPool := make(chan crypto.PrivateKey, keyPoolSize) + if key, err := cryptutil.GeneratePrivateKey(keyConfig.Type, keyConfig.Bits); err != nil { + return nil, fmt.Errorf("mitm: invalid key configuration: %w", err) + } else { + keyPool <- key + } + + go func(pool chan<- crypto.PrivateKey) { + for { + key, err := cryptutil.GeneratePrivateKey(keyConfig.Type, keyConfig.Bits) + if err != nil { + log.Panic().Err(err).Msg("error generating private key") + } + pool <- key + } + }(keyPool) + + return &authority{ + pool: pool, + cert: cert, + key: key, + keyID: cryptutil.GenerateKeyID(cryptutil.PublicKey(key)), + keyPool: keyPool, + cache: cache, + }, nil +} + +func (ca *authority) log() log.Logger { + return log.Console.With(). + Str("ca", ca.cert.Subject.String()). + Logger() +} + +func (ca *authority) Certificate() *x509.Certificate { + return ca.cert +} + +func (ca *authority) TLSConfig(name string) *tls.Config { + return &tls.Config{ + GetCertificate: func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { + log := ca.log() + if hello.ServerName != "" { + name = strings.ToLower(hello.ServerName) + log.Debug().Msg("requesting certificate for server name (SNI)") + } else { + log.Debug().Msg("requesting certificate for hostname") + } + if cert, ok := ca.getCached(name); ok { + log.Debug(). + Str("subject", cert.Leaf.Subject.String()). + Str("serial", cert.Leaf.SerialNumber.String()). + Time("valid", cert.Leaf.NotAfter). + Msg("using cached certificate") + return cert, nil + } + return ca.issueFor(name) + }, + NextProtos: []string{"http/1.1"}, + } +} + +func (ca *authority) getCached(name string) (cert *tls.Certificate, ok bool) { + log := ca.log() + + if cert = ca.cache.Certificate(name); cert == nil { + if baseDomain(name) != name { + cert = ca.cache.Certificate(baseDomain(name)) + } + } + if cert != nil { + if _, err := cert.Leaf.Verify(x509.VerifyOptions{ + DNSName: name, + Roots: ca.pool, + }); err != nil { + log.Debug().Err(err).Str("name", name).Msg("deleting invalid certificate from cache") + } else { + ok = true + } + } + return +} + +func (ca *authority) issueFor(name string) (*tls.Certificate, error) { + var ( + log = ca.log().With().Str("name", name).Logger() + key crypto.PrivateKey + ) + select { + case key = <-ca.keyPool: + case <-time.After(5 * time.Second): + return nil, errors.New("mitm: timeout waiting for private key generator to catch up") + } + if key == nil { + panic("key pool returned nil key") + } + + serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) + serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) + if err != nil { + return nil, fmt.Errorf("mtim: failed to generate serial number: %w", err) + } + + if part := dns.SplitDomainName(name); len(part) > 2 { + name = strings.Join(part[1:], ".") + log.Debug().Msgf("abbreviated name to %s (*.%s)", name, name) + } + + now := time.Now() + template := &x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{CommonName: name}, + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyAgreement, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + DNSNames: []string{name, "*." + name}, + BasicConstraintsValid: true, + NotBefore: now.Add(-DefaultValidity), + NotAfter: now.Add(+DefaultValidity), + } + der, err := x509.CreateCertificate(rand.Reader, template, ca.cert, cryptutil.PublicKey(key), ca.key) + if err != nil { + return nil, err + } + cert, err := x509.ParseCertificate(der) + if err != nil { + return nil, err + } + + log.Debug().Str("serial", serialNumber.String()).Msg("generated certificate") + out := &tls.Certificate{ + Certificate: [][]byte{der}, + Leaf: cert, + PrivateKey: key, + } + //ca.cache[name] = out + ca.cache.SaveCertificate(name, out) + return out, nil +} + +func containsValidCertificate(cert *tls.Certificate) bool { + if cert == nil || len(cert.Certificate) == 0 { + return false + } + + if cert.Leaf == nil { + var err error + if cert.Leaf, err = x509.ParseCertificate(cert.Certificate[0]); err != nil { + return false + } + } + + now := time.Now() + + return !(cert.Leaf.NotBefore.Before(now) || cert.Leaf.NotAfter.After(now)) +} diff --git a/proxy/mitm/cache.go b/proxy/mitm/cache.go new file mode 100644 index 0000000..e0221ad --- /dev/null +++ b/proxy/mitm/cache.go @@ -0,0 +1,233 @@ +package mitm + +import ( + "crypto/tls" + "fmt" + "io/fs" + "os" + "path/filepath" + "slices" + "strings" + "time" + + "github.com/hashicorp/golang-lru/v2/expirable" + "github.com/hashicorp/hcl/v2/gohcl" + "github.com/miekg/dns" + + "git.maze.io/maze/styx/internal/cryptutil" + "git.maze.io/maze/styx/internal/log" +) + +type Cache interface { + Certificate(name string) *tls.Certificate + SaveCertificate(name string, cert *tls.Certificate) error + RemoveCertificate(name string) +} + +func NewCache(config *CacheConfig) (Cache, error) { + if config == nil { + return NewCache(&CacheConfig{Type: "memory"}) + } + switch config.Type { + case "memory": + var cacheConfig = new(MemoryCacheConfig) + if err := gohcl.DecodeBody(config.Body, nil, cacheConfig); err != nil { + return nil, err + } + return NewMemoryCache(cacheConfig.Size), nil + case "disk": + var cacheConfig = new(DiskCacheConfig) + if err := gohcl.DecodeBody(config.Body, nil, cacheConfig); err != nil { + return nil, err + } + return NewDiskCache(cacheConfig.Path, time.Duration(cacheConfig.Expire*float64(time.Second))) + default: + return nil, fmt.Errorf("mitm: cache type %q is not supported", config.Type) + } +} + +type memoryCache struct { + cache *expirable.LRU[string, *tls.Certificate] +} + +func NewMemoryCache(size int) Cache { + return memoryCache{ + cache: expirable.NewLRU(size, func(key string, value *tls.Certificate) { + log.Debug().Str("name", key).Msg("certificate evicted from cache") + }, time.Hour*24), + } +} + +func (c memoryCache) Certificate(name string) (cert *tls.Certificate) { + var ok bool + if cert, ok = c.cache.Get(name); !ok { + cert, _ = c.cache.Get(baseDomain(name)) + } + return +} + +func (c memoryCache) SaveCertificate(name string, cert *tls.Certificate) error { + c.cache.Add(name, cert) + log.Debug().Str("name", name).Msg("certificate added to cache") + return nil +} + +func (c memoryCache) RemoveCertificate(name string) { + c.cache.Remove(name) +} + +type diskCache string + +func NewDiskCache(dir string, expire time.Duration) (Cache, error) { + if !filepath.IsAbs(dir) { + var err error + if dir, err = filepath.Abs(dir); err != nil { + return nil, err + } + } + if err := os.MkdirAll(dir, 0o750); err != nil { + return nil, err + } + info, err := os.Stat(dir) + if err != nil { + return nil, err + } + if info.Mode()&os.ModePerm|0o057 != 0 { + if err := os.Chmod(dir, 0o750); err != nil { + return nil, err + } + } + + if expire > 0 { + go expireDiskCache(dir, expire) + } + + return diskCache(dir), nil +} + +func expireDiskCache(root string, expire time.Duration) { + log.Debug().Str("path", root).Dur("expire", expire).Msg("disk cache expire loop starting") + ticker := time.NewTicker(expire) + defer ticker.Stop() + for { + now := <-ticker.C + log.Debug().Str("path", root).Dur("expire", expire).Msg("expire disk cache") + filepath.WalkDir(root, func(path string, d fs.DirEntry, err error) error { + if err != nil { + return err + } + if d.IsDir() { + // Remove the directory; this will fail if it's not empty, which is fine. + _ = os.Remove(path) + return nil + } + + cert, err := cryptutil.LoadCertificate(path) + if err != nil { + log.Debug().Str("path", path).Err(err).Msg("expire removing invalid certificate file") + _ = os.Remove(path) + return nil + } else if cert.NotAfter.Before(now) { + log.Debug().Str("path", path).Dur("expired", now.Sub(cert.NotAfter)).Msg("expire removing expired certificate") + _ = os.Remove(path) + return nil + } + return nil + }) + } +} + +func (c diskCache) path(name string) string { + part := dns.SplitDomainName(strings.ToLower(name)) + // x,com -> com,x + // www,maze,io -> io,maze,www + slices.Reverse(part) + // com,x -> com,x,x.com + // io,maze,www -> io,m,ma,maze,www.maze.io + if len(part) > 2 { + if len(part[1]) > 1 { + part = []string{ + part[0], + part[1][:1], + part[1][:2], + part[1], + name, + } + } else { + part = []string{ + part[0], + part[1][:1], + part[1], + name, + } + } + } else if len(part) > 1 { + if len(part[1]) > 1 { + part = []string{ + part[0], + part[1][:1], + part[1][:2], + name, + } + } else { + part = []string{ + part[0], + part[1][:1], + name, + } + } + } + part[len(part)-1] += ".crt" + return filepath.Join(append([]string{string(c)}, part...)...) +} + +func (c diskCache) Certificate(name string) (cert *tls.Certificate) { + if cert, key, err := cryptutil.LoadKeyPair(c.path(name), ""); err == nil { + return &tls.Certificate{ + Certificate: [][]byte{cert.Raw}, + Leaf: cert, + PrivateKey: key, + } + } + if cert, key, err := cryptutil.LoadKeyPair(c.path(baseDomain(name)), ""); err == nil { + return &tls.Certificate{ + Certificate: [][]byte{cert.Raw}, + Leaf: cert, + PrivateKey: key, + } + } + log.Debug().Str("path", string(c)).Str("name", name).Msg("cache miss") + return nil +} + +func (c diskCache) SaveCertificate(name string, cert *tls.Certificate) error { + dir, name := filepath.Split(c.path(name)) + if err := os.MkdirAll(dir, 0o750); err != nil { + return err + } + if err := cryptutil.SaveKeyPair(cert.Leaf, cert.PrivateKey, filepath.Join(dir, name), ""); err != nil { + return err + } + log.Debug().Str("name", name).Msg("certificate added to cache") + return nil +} + +func (c diskCache) RemoveCertificate(name string) { + path := c.path(name) + if err := os.Remove(path); err != nil { + if os.IsNotExist(err) { + return + } + log.Error().Err(err).Msg("certificate remove from cache failed") + } + _ = os.Remove(filepath.Dir(path)) + log.Debug().Str("name", name).Msg("certificate removed from cache") +} + +func baseDomain(name string) string { + name = strings.ToLower(name) + if part := dns.SplitDomainName(name); len(part) > 2 { + return strings.Join(part[1:], ".") + } + return name +} diff --git a/proxy/mitm/cache_test.go b/proxy/mitm/cache_test.go new file mode 100644 index 0000000..6f6d364 --- /dev/null +++ b/proxy/mitm/cache_test.go @@ -0,0 +1,25 @@ +package mitm + +import "testing" + +func TestDiskCachePath(t *testing.T) { + cache := diskCache("testdata") + tests := []struct { + test string + want string + }{ + {"x.com", "testdata/com/x/x.com.crt"}, + {"feed.x.com", "testdata/com/x/x/feed.x.com.crt"}, + {"nu.nl", "testdata/nl/n/nu/nu.nl.crt"}, + {"maze.io", "testdata/io/m/ma/maze.io.crt"}, + {"lab.maze.io", "testdata/io/m/ma/maze/lab.maze.io.crt"}, + {"dev.lab.maze.io", "testdata/io/m/ma/maze/dev.lab.maze.io.crt"}, + } + for _, test := range tests { + t.Run(test.test, func(it *testing.T) { + if v := cache.path(test.test); v != test.want { + it.Errorf("expected %q to resolve to %q, got %q", test.test, test.want, v) + } + }) + } +} diff --git a/proxy/mitm/config.go b/proxy/mitm/config.go new file mode 100644 index 0000000..0c67650 --- /dev/null +++ b/proxy/mitm/config.go @@ -0,0 +1,89 @@ +package mitm + +import ( + "crypto/x509/pkix" + + "github.com/hashicorp/hcl/v2" +) + +const ( + DefaultCommonName = "Styx Certificate Authority" + DefaultDays = 3 +) + +type Config struct { + CA *CAConfig `hcl:"ca,block"` + Key *KeyConfig `hcl:"key,block"` + Cache *CacheConfig `hcl:"cache,block"` +} + +type CAConfig struct { + Cert string `hcl:"cert"` + Key string `hcl:"key,optional"` + Days int `hcl:"days,optional"` + KeyType string `hcl:"key_type,optional"` + Bits int `hcl:"bits,optional"` + Name string `hcl:"name,optional"` + Country string `hcl:"country,optional"` + Organization string `hcl:"organization,optional"` + Unit string `hcl:"unit,optional"` + Locality string `hcl:"locality,optional"` + Province string `hcl:"province,optional"` + Address []string `hcl:"address,optional"` + PostalCode string `hcl:"postal_code,optional"` +} + +func (config CAConfig) DN() pkix.Name { + var name = pkix.Name{ + CommonName: config.Name, + StreetAddress: config.Address, + } + if config.Name == "" { + name.CommonName = DefaultCommonName + } + if config.Country != "" { + name.Country = append(name.Country, config.Country) + } + if config.Organization != "" { + name.Organization = append(name.Organization, config.Organization) + } + if config.Unit != "" { + name.OrganizationalUnit = append(name.OrganizationalUnit, config.Unit) + } + if config.Locality != "" { + name.Locality = append(name.Locality, config.Locality) + } + if config.Province != "" { + name.Province = append(name.Province, config.Province) + } + if config.PostalCode != "" { + name.PostalCode = append(name.PostalCode, config.PostalCode) + } + return name +} + +type KeyConfig struct { + Type string `hcl:"type,optional"` + Bits int `hcl:"bits,optional"` + Pool int `hcl:"pool,optional"` +} + +var defaultKeyConfig = KeyConfig{ + Type: "rsa", + Bits: 2048, + Pool: 5, +} + +type CacheConfig struct { + Type string `hcl:"type"` + Body hcl.Body `hcl:",remain"` +} + +type MemoryCacheConfig struct { + Size int `hcl:"size,optional"` +} + +type DiskCacheConfig struct { + Path string `hcl:"path"` + Expire float64 `hcl:"expire,optional"` +} diff --git a/proxy/policy/policy.go b/proxy/policy/policy.go new file mode 100644 index 0000000..539b47e --- /dev/null +++ b/proxy/policy/policy.go @@ -0,0 +1,53 @@ +package policy + +import ( + "net/http" + + "git.maze.io/maze/styx/proxy/match" +) + +// Policy contains rules that make up the policy. +// +// Some policy rules contain nested policies. +type Policy struct { + Rules []*rawRule `hcl:"on,block" json:"rules"` + Permit *bool `hcl:"permit" json:"permit"` + Matchers match.Matchers `json:"matchers"` // Matchers for the policy + +} + +func (p *Policy) Configure(matchers match.Matchers) (err error) { + for _, r := range p.Rules { + if err = r.Configure(matchers); err != nil { + return + } + } + p.Matchers = matchers + return +} + +func (p *Policy) PermitIntercept(r *http.Request) *bool { + if p != nil { + for _, rule := range p.Rules { + if rule, ok := rule.Rule.(InterceptRule); ok { + if permit := rule.PermitIntercept(r); permit != nil { + return permit + } + } + } + } + return p.Permit +} + +func (p *Policy) PermitRequest(r *http.Request) *bool { + if p != nil { + for _, rule := range p.Rules { + if rule, ok := rule.Rule.(RequestRule); ok { + if permit := rule.PermitRequest(r); permit != nil { + return permit + } + } + } + } + return p.Permit +} diff --git a/proxy/policy/policy_test.go b/proxy/policy/policy_test.go new file mode 100644 index 0000000..307bab0 --- /dev/null +++ b/proxy/policy/policy_test.go @@ -0,0 +1,139 @@ +package policy + +import ( + "net" + "net/http" + "net/url" + "testing" + + "git.maze.io/maze/styx/internal/netutil" + "git.maze.io/maze/styx/proxy/match" + "github.com/miekg/dns" +) + +type testInDomainList struct { + t *testing.T + list []string +} + +func (testInDomainList) Name() string { return "testInDomainList" } +func (l testInDomainList) MatchesRequest(r *http.Request) bool { + for _, domain := range l.list { + if dns.IsSubDomain(domain, netutil.Host(r.URL.Host)) { + l.t.Logf("domain %s contains %s", domain, r.URL.Host) + return true + } + l.t.Logf("domain %s does not contain %s", domain, r.URL.Host) + } + return false +} + +func testInDomain(t *testing.T, domains ...string) match.Matcher { + return &testInDomainList{t: t, list: domains} +} + +type testInNetworkList struct { + t *testing.T + list []*net.IPNet +} + +func (testInNetworkList) Name() string { return "testInNetworkList" } +func (l testInNetworkList) MatchesIP(ip net.IP) bool { + for _, ipnet := range l.list { + if ipnet.Contains(ip) { + l.t.Logf("network %s contains %s", ipnet, ip) + return true + } + l.t.Logf("network %s does not contain %s", ipnet, ip) + } + return false +} + +func testInNetwork(t *testing.T, cidr string) match.Matcher { + t.Helper() + _, ipnet, err := net.ParseCIDR(cidr) + if err != nil { + panic(err) + } + return testInNetworkList{t: t, list: []*net.IPNet{ipnet}} +} + +func TestPolicy(t *testing.T) { + var ( + yes = true + nope = false + ) + p := &Policy{ + Rules: []*rawRule{ + { + Rule: &requestRule{ + domainOrNetworkRule: domainOrNetworkRule{ + matchers: []match.Matcher{testInNetwork(t, "127.0.0.0/8")}, + isSource: []bool{true}, + }, + }, + }, + { + Rule: &requestRule{ + domainOrNetworkRule: domainOrNetworkRule{ + matchers: []match.Matcher{testInNetwork(t, "127.0.0.0/8")}, + isSource: []bool{false}, + }, + Permit: &yes, + }, + }, + { + Rule: &requestRule{ + domainOrNetworkRule: domainOrNetworkRule{ + matchers: []match.Matcher{testInDomain(t, "maze.io", "maze.engineering")}, + }, + Permit: &yes, + }, + }, + { + Rule: &requestRule{ + domainOrNetworkRule: domainOrNetworkRule{ + matchers: []match.Matcher{testInDomain(t, "google.com")}, + }, + Permit: &nope, + }, + }, + }, + } + + r := &http.Request{ + URL: &url.URL{Scheme: "http", Host: "golang.org:80"}, + RemoteAddr: "127.0.0.1:1234", + } + if v := p.PermitRequest(r); v != nil { + t.Errorf("expected request to return no verdict, got %t", *v) + } + + p.Rules[0].Rule.(*requestRule).Permit = &yes + if v := p.PermitRequest(r); v == nil || *v != yes { + t.Errorf("expected request to return %t, %v", yes, v) + } + + r.RemoteAddr = "192.168.1.2:3456" + if v := p.PermitRequest(r); v != nil { + t.Errorf("expected request to return no verdict, got %t", *v) + } + if v := p.PermitIntercept(r); v != nil { + t.Errorf("expected request to return no verdict, got %t", *v) + } + + r.URL.Host = "maze.io" + if v := p.PermitRequest(r); v == nil || *v != yes { + t.Errorf("expected request to return %t, %v", yes, v) + } + + r.URL.Host = "google.com" + if v := p.PermitRequest(r); v == nil || *v != nope { + t.Errorf("expected request to return %t, %v", nope, v) + } + + r.URL.Host = "localhost:80" + if v := p.PermitRequest(r); v == nil || *v != yes { + t.Errorf("expected request to return %t, %v", yes, v) + } +} diff --git a/proxy/policy/rule.go b/proxy/policy/rule.go new file mode 100644 index 0000000..7cbe559 --- /dev/null +++ b/proxy/policy/rule.go @@ -0,0 +1,368 @@ +package policy + +import ( + "fmt" + "net" + "net/http" + "strings" + "time" + + "git.maze.io/maze/styx/internal/netutil" + "git.maze.io/maze/styx/proxy/match" + "github.com/google/uuid" + "github.com/hashicorp/hcl/v2" + "github.com/hashicorp/hcl/v2/gohcl" +) + +// Rule is a policy rule. +type Rule interface { + Configure(match.Matchers) error +} + +// InterceptRule can make policy rule decisions on intercept requests. +type InterceptRule interface { + PermitIntercept(r *http.Request) *bool +} + +// RequestRule can make policy rule decisions on HTTP CONNECT requests. +type RequestRule interface { + PermitRequest(r *http.Request) *bool +} + +type rawRule struct { + Type string `hcl:"type,label" json:"type"` + Body hcl.Body `hcl:",remain" json:"-"` + Rule `json:"rule"` +} + +func (r *rawRule) Configure(matchers match.Matchers) (err error) { + switch r.Type { + case "intercept": + r.Rule = new(interceptRule) + case "request": + r.Rule = new(requestRule) + case "days": + r.Rule = new(daysRule) + case "time": + r.Rule = new(timeRule) + case "all": + r.Rule = new(allRule) + default: + return fmt.Errorf("policy: invalid event type %q", r.Type) + } + + if diag := gohcl.DecodeBody(r.Body, nil, r.Rule); diag.HasErrors() { + return err + } + + return r.Rule.Configure(matchers) +} + +type allRule struct { + Rules []*rawRule `hcl:"on,block"` + Permit *bool `hcl:"permit"` +} + +func (r *allRule) Configure(matchers match.Matchers) (err error) { + return +} + +type domainOrNetworkRule struct { + matchers []match.Matcher + isSource []bool +} + +func (r *domainOrNetworkRule) configure(kind string, matchers match.Matchers, domains, sources, targets []string, v any, id *string) (err error) { + var m match.Matcher + for _, domain := range domains { + if m, err = matchers.Get("domain", domain); err != nil { + return fmt.Errorf("%s: unknown domain %q", kind, domain) + } + r.matchers = append(r.matchers, m) + r.isSource = append(r.isSource, false) + } + for _, network := range sources { + if m, err = matchers.Get("network", network); err != nil { + return fmt.Errorf("%s: unknown source network %q", kind, network) + } + r.matchers = append(r.matchers, m) + r.isSource = append(r.isSource, true) + } + for _, network := range targets { + if m, err = matchers.Get("network", network); err != nil { + return fmt.Errorf("%s: unknown target network %q", kind, network) + } + r.matchers = append(r.matchers, m) + r.isSource = append(r.isSource, false) + } + if len(r.matchers) == 0 { + return fmt.Errorf("%s: missing any of domain, source, target", kind) + } + if id != nil { + *id = uuid.NewString() + } + return +} + +func (r *domainOrNetworkRule) matchesRequest(q *http.Request) bool { + for i, m := range r.matchers { + if m, ok := m.(match.Request); ok { + if m.MatchesRequest(q) { + return true + } + } + if m, ok := m.(match.IP); ok { + if r.isSource[i] { + if m.MatchesIP(net.ParseIP(netutil.Host(q.RemoteAddr))) { + return true + } + } else { + var ( + host = netutil.Host(q.URL.Host) + ips []net.IP + ) + if ip := net.ParseIP(host); ip != nil { + ips = append(ips, ip) + } else { + ips, _ = net.LookupIP(host) + } + for _, ip := range ips { + if m.MatchesIP(ip) { + return true + } + } + } + } + } + return false +} + +type interceptRule struct { + ID string `json:"id,omitempty"` + Domain []string `hcl:"domain,optional" json:"domain,omitempty"` + Source []string `hcl:"source,optional" json:"source,omitempty"` + Target []string `hcl:"target,optional" json:"target,omitempty"` + Permit *bool `hcl:"permit" json:"permit"` + domainOrNetworkRule `json:"-"` +} + +func (r *interceptRule) Configure(matchers match.Matchers) (err error) { + return r.configure("intercept", matchers, r.Domain, r.Source, r.Target, r, &r.ID) +} + +func (r *interceptRule) PermitIntercept(q *http.Request) *bool { + if r.matchesRequest(q) { + return r.Permit + } + return nil +} + +type requestRule struct { + ID string `json:"id,omitempty"` + Domain []string `hcl:"domain,optional" json:"domain,omitempty"` + Source []string `hcl:"source,optional" json:"source,omitempty"` + Target []string `hcl:"target,optional" json:"target,omitempty"` + Permit *bool `hcl:"permit" json:"permit"` + domainOrNetworkRule `json:"-"` +} + +func (r *requestRule) Configure(matchers match.Matchers) (err error) { + return r.configure("request", matchers, r.Domain, r.Source, r.Target, r, &r.ID) +} + +func (r *requestRule) PermitRequest(q *http.Request) *bool { + if r.matchesRequest(q) { + return r.Permit + } + return nil +} + +type timeRule struct { + ID string `json:"id,omitempty"` + Time []string `hcl:"time" json:"time"` + Permit *bool `hcl:"permit" json:"permit"` + Body hcl.Body `hcl:",remain" json:"-"` + Rules *Policy `json:"rules"` + Start Time `json:"start"` + End Time `json:"end"` +} + +func (r *timeRule) isActive() bool { + if r == nil { + return true + } + + now := Now() + if r.Start.After(r.End) { // ie: 18:00-06:00 + return now.After(r.Start) || now.Before(r.End) + } + return now.After(r.Start) && now.Before(r.End) +} + +func (r *timeRule) Configure(matchers match.Matchers) (err error) { + if len(r.Time) != 2 { + return fmt.Errorf("invalid time %s, need [start, stop]", r.Time) + } + if r.Start, err = ParseTime(r.Time[0]); err != nil { + return fmt.Errorf("invalid start %q: %w", r.Time[0], err) + } + if r.End, err = ParseTime(r.Time[1]); err != nil { + return fmt.Errorf("invalid end %q: %w", r.Time[1], err) + } + + r.Rules = new(Policy) + if diag := gohcl.DecodeBody(r.Body, nil, r.Rules); diag.HasErrors() { + return diag + } + + if err = r.Rules.Configure(matchers); err != nil { + return + } + r.Rules.Matchers = nil + + if r.ID == "" { + r.ID = uuid.NewString() + } + + return +} + +func (r *timeRule) PermitIntercept(q *http.Request) *bool { + if !r.isActive() { + return nil + } + return r.Rules.PermitIntercept(q) +} + +func (r *timeRule) PermitRequest(q *http.Request) *bool { + if !r.isActive() { + return nil + } + return r.Rules.PermitRequest(q) +} + +type daysRule struct { + ID string `json:"id,omitempty"` + Days string `hcl:"days" json:"days"` + Permit *bool `hcl:"permit" json:"permit"` + Body hcl.Body `hcl:",remain" json:"-"` + Rules *Policy `json:"rules"` + cond []onCond +} + +func (r *daysRule) isActive() bool { + if r == nil || len(r.cond) == 0 { + return true + } + + now := time.Now() + for _, cond := range r.cond { + if cond(now) { + return true + } + } + return false +} + +func (r *daysRule) Configure(matchers match.Matchers) (err error) { + if r.cond, err = parseOnCond(r.Days); err != nil { + return + } + + r.Rules = new(Policy) + if diag := gohcl.DecodeBody(r.Body, nil, r.Rules); diag.HasErrors() { + return diag + } + if err = r.Rules.Configure(matchers); err != nil { + return + } + r.Rules.Matchers = nil + + if r.ID == "" { + r.ID = uuid.NewString() + } + + return +} + +func (r *daysRule) PermitIntercept(q *http.Request) *bool { + if !r.isActive() { + return nil + } + return r.Rules.PermitIntercept(q) +} + +func (r *daysRule) PermitRequest(q *http.Request) *bool { + if !r.isActive() { + return nil + } + return r.Rules.PermitRequest(q) +} + +type onCond func(time.Time) bool + +var weekdays = map[string]time.Weekday{ + "sun": time.Sunday, + "mon": time.Monday, + "tue": time.Tuesday, + "wed": time.Wednesday, + "thu": time.Thursday, + "fri": time.Friday, + "sat": time.Saturday, +} + +func parseOnCond(when string) (conds []onCond, err error) { + for _, spec := range strings.Split(when, ",") { + spec = strings.ToLower(strings.TrimSpace(spec)) + if d, ok := weekdays[spec]; ok { + conds = append(conds, onWeekday(d)) + } else if spec == "weekend" || spec == "weekends" { + conds = append(conds, onWeekend) + } else if spec == "workday" || spec == "workdays" { + conds = append(conds, onWorkday) + } else if strings.ContainsRune(spec, '-') { + var ( + part = strings.SplitN(spec, "-", 2) + from, upto time.Weekday + ok bool + ) + if from, ok = weekdays[part[0]]; !ok { + return nil, fmt.Errorf("on %q: invalid weekday %q", spec, part[0]) + } + if upto, ok = weekdays[part[1]]; !ok { + return nil, fmt.Errorf("on %q: invalid weekday %q", spec, part[1]) + } + if from < upto { + for d := from; d < upto; d++ { + conds = append(conds, onWeekday(d)) + } + } else { + for d := time.Sunday; d < from; d++ { + conds = append(conds, onWeekday(d)) + } + for d := upto; d <= time.Saturday; d++ { + conds = append(conds, onWeekday(d)) + } + } + } else { + return nil, fmt.Errorf("on %q: invalid condition", spec) + } + } + return +} + +func onWeekday(weekday time.Weekday) onCond { + return func(t time.Time) bool { + return t.Weekday() == weekday + } +} + +func onWeekend(t time.Time) bool { + d := t.Weekday() + return d == time.Saturday || d == time.Sunday +} + +func onWorkday(t time.Time) bool { + d := t.Weekday() + return !(d == time.Saturday || d == time.Sunday) +} diff --git a/proxy/policy/time.go b/proxy/policy/time.go new file mode 100644 index 0000000..957092b --- /dev/null +++ b/proxy/policy/time.go @@ -0,0 +1,53 @@ +package policy + +import ( + "fmt" + "time" +) + +type Time struct { + Hour int + Minute int + Second int +} + +func (t Time) Eq(other Time) bool { + return t.Hour == other.Hour && t.Minute == other.Minute && t.Second == other.Second +} + +func (t Time) After(other Time) bool { + return t.Seconds() > other.Seconds() +} + +func (t Time) Before(other Time) bool { + return t.Seconds() < other.Seconds() +} + +func (t Time) Seconds() int { + return t.Hour*3600 + t.Minute*60 + t.Second +} + +func (t Time) MarshalJSON() ([]byte, error) { + return []byte(fmt.Sprintf(`"%02d:%02d:%02d"`, t.Hour, t.Minute, t.Second)), nil +} + +var timeFormats = []string{ + time.TimeOnly, + "15:04", + time.Kitchen, +} + +func Now() Time { + now := time.Now() + return Time{now.Hour(), now.Minute(), now.Second()} +} + +func ParseTime(s string) (t Time, err error) { + var tt time.Time + for _, layout := range timeFormats { + if tt, err = time.Parse(layout, s); err == nil { + return Time{tt.Hour(), tt.Minute(), tt.Second()}, nil + } + } + return Time{}, fmt.Errorf("time: invalid time %q", s) +} diff --git a/proxy/proxy.go b/proxy/proxy.go new file mode 100644 index 0000000..91cdc2d --- /dev/null +++ b/proxy/proxy.go @@ -0,0 +1,616 @@ +package proxy + +import ( + "bufio" + "bytes" + "context" + "crypto/tls" + "errors" + "fmt" + "io" + "net" + "net/http" + "strings" + "sync" + "syscall" + "time" + + "git.maze.io/maze/styx/internal/log" + "git.maze.io/maze/styx/internal/netutil" + "git.maze.io/maze/styx/proxy/mitm" + "git.maze.io/maze/styx/proxy/policy" + "git.maze.io/maze/styx/proxy/resolver" + "git.maze.io/maze/styx/proxy/stats" +) + +const ( + DefaultListenAddr = ":3128" + DefaultBindAddr = "" + DefaultDialTimeout = 30 * time.Second + DefaultKeepAlivePeriod = 1 * time.Minute +) + +const ( + HeaderAcceptEncoding = "Accept-Encoding" + HeaderConnection = "Connection" + HeaderContentLength = "Content-Length" + HeaderContentType = "Content-Type" + HeaderUpgrade = "Upgrade" +) + +var ( + ErrClosed = errors.New("proxy: shutdown") + ErrClientCert = errors.New("tls: client certificate requested") +) + +type Proxy struct { + addr *net.TCPAddr + bind *net.TCPAddr + resolver resolver.Resolver + transport *http.Transport + dial func(network, address string) (net.Conn, error) + config *Config + authority mitm.Authority + policy *policy.Policy + admin *Admin + stats *stats.Stats + closed chan struct{} + onConnect ConnectHandler + onRequest RequestHandler + onResponse ResponseHandler + onError ErrorHandler +} + +func New(config *Config, ca mitm.Authority) (*Proxy, error) { + if config == nil { + return nil, errors.New("proxy: config can't be nil") + } + + p := &Proxy{ + transport: newTransport(), + config: config, + resolver: resolver.Default, + authority: ca, + policy: config.Policy, + closed: make(chan struct{}), + onConnect: config.ConnectHandler, + onRequest: config.RequestHandler, + onResponse: config.ResponseHandler, + onError: config.ErrorHandler, + } + + var err error + if config.Listen == "" { + p.addr, err = net.ResolveTCPAddr("tcp", DefaultBindAddr) + } else { + p.addr, err = net.ResolveTCPAddr("tcp", config.Listen) + } + if err != nil { + return nil, fmt.Errorf("proxy: invalid listen addres: %w", err) + } + if config.Bind != "" { + if p.bind, err = net.ResolveTCPAddr("tcp", config.Bind+":0"); err != nil { + return nil, fmt.Errorf("proxy: invalid bind address: %w", err) + } + } else if config.Interface != "" { + if err = resolveInterfaceAddr(config.Interface); err != nil { + return nil, err + } + } + if p.bind != nil { + /* FIXME + var c *net.TCPConn + if c, err = net.DialTCP("tcp", p.bind, p.bind); err != nil && errors.Is(err, syscall.EADDRNOTAVAIL) { + return nil, fmt.Errorf("proxy: invalid bind address: %w", syscall.EADDRNOTAVAIL) + } else if c != nil { + _ = c.Close() + } + */ + } + if config.Resolver != nil { + p.resolver = config.Resolver + } + + dialTimeout := DefaultDialTimeout + if config.DialTimeout > 0 { + dialTimeout = config.DialTimeout + } + p.dial = (&net.Dialer{ + Timeout: dialTimeout, + KeepAlive: dialTimeout, + LocalAddr: p.bind, + }).Dial + + p.admin = NewAdmin(p) + + if p.stats, err = stats.New(); err != nil { + return nil, err + } + + return p, nil +} + +func newTransport() *http.Transport { + return &http.Transport{ + TLSNextProto: make(map[string]func(authority string, c *tls.Conn) http.RoundTripper), + Proxy: http.ProxyFromEnvironment, + TLSHandshakeTimeout: 15 * time.Second, + ExpectContinueTimeout: 5 * time.Second, + } +} + +func (p *Proxy) Close() error { + select { + case <-p.closed: + return ErrClosed + default: + close(p.closed) + return nil + } +} + +func (p *Proxy) Start() error { + l, err := net.ListenTCP("tcp", p.addr) + if err != nil { + return err + } + + go p.Serve(l) + return nil +} + +func (p *Proxy) Serve(listener net.Listener) error { + defer func() { _ = listener.Close() }() + + log.Info().Str("addr", listener.Addr().String()).Msg("proxy server listening") + for { + select { + case <-p.closed: + return nil + default: + } + + c, err := listener.Accept() + if err != nil { + return err + } + + rw := bufio.NewReadWriter(bufio.NewReader(c), bufio.NewWriter(c)) + ctx := newContext(c, rw, nil) + + if c, ok := c.(*net.TCPConn); ok { + _ = c.SetKeepAlive(true) + _ = c.SetKeepAlivePeriod(DefaultKeepAlivePeriod) + } + + go p.handle(ctx) + } +} + +func (p *Proxy) handle(ctx *Context) { + logger := ctx.log() + defer log.OnCloseError(logger.Debug(), ctx.conn) + logger.Info().Str("client", ctx.RemoteAddr().String()).Msg("new client connection") + + last := int64(0) + for { + select { + case <-p.closed: + return + + default: + ses, err := p.handleRequest(ctx) + if ses != nil { + log := ses.log() + log.Info(). + Str("method", ses.request.Method). + Str("url", ses.request.URL.String()). + Str("status", ses.response.Status). + Int64("size", ctx.conn.bytes-last). + Msg("handled request") + + p.stats.AddLog(&stats.Log{ + ClientIP: netutil.Host(ses.request.RemoteAddr), + Request: stats.FromRequest(ses.request), + Response: stats.FromResponse(ses.response).SetSize(ctx.conn.bytes - last), + }) + + last = ctx.conn.bytes + } + if err != nil && !isClosing(err) || (ses != nil && ses.response != nil && ses.response.Close) { + event := logger.Debug() + if ctx.conn.bytes > 0 { + event = event.Int64("size", ctx.conn.bytes) + } + event.Msg("closing client connection") + return + } + } + } +} + +func (p *Proxy) handleRequest(ctx *Context) (ses *Session, err error) { + logger := ctx.log() + + var request *http.Request + if request, err = p.readRequest(ctx); err != nil { + return + } + + ses = newSession(ctx, request) + p.cleanRequest(ses, request) + + logger.Debug().Str("method", request.Method).Str("url", request.URL.String()).Msg("handle request") + + if p.onRequest != nil { + newRequest, newResponse := p.onRequest.HandleRequest(ses) + if newRequest != nil { + logger.Debug().Str("method", newRequest.Method).Str("url", newRequest.URL.String()).Msg("request override") + ses.request = newRequest + } + if newResponse != nil { + logger.Debug().Str("status", newResponse.Status).Msg("response override") + ses.response = newResponse + } + } + + if ses.response == nil { + // WebSocket request + if ses.request.Header.Get(HeaderUpgrade) == "websocket" { + return ses, p.handleTunnel(ses) + } + + cleanHopByHopHeaders(ses.request.Header) + + // Proxy CONNECT request + if ses.request.Method == http.MethodConnect { + return p.handleConnect(ses) + } + + if netutil.Port(ses.request.URL.Host) == p.addr.Port { + // Plain API request + ses.request.URL.Host = ses.request.Host + return ses, p.admin.handleRequest(ses) + + } else if ses.response, err = p.transport.RoundTrip(ses.request); err != nil { + // Plain HTTP request + if p.config.ErrorHandler != nil { + p.config.ErrorHandler.HandleError(ses, err) + } + ses.response = ErrorResponse(ses.request, err) + } + + logger.Debug().Str("status", ses.response.Status).Msg("received response") + cleanHopByHopHeaders(ses.response.Header) + } + + ses.response.Close = true + defer log.OnCloseError(logger.Debug(), ses.response.Body) + return ses, p.writeResponse(ses) +} + +func (p *Proxy) handleConnect(ses *Session) (next *Session, err error) { + next = ses + + logger := ses.log() + logger.Debug().Msgf("connecting to %s", ses.request.URL.Host) + + var c net.Conn + if c, err = p.connect(ses, "tcp", ses.request.URL.Host); err != nil { + logger.Error().Err(err).Msg("connect failed") + if p.onError != nil { + p.onError.HandleError(ses, err) + } + + ses.response = ErrorResponse(ses.request, err) + defer log.OnCloseError(logger.Debug(), ses.response.Body) + _ = p.writeResponse(ses) + + return + } + + defer func() { + if err := c.Close(); err != nil { + if p.onError != nil { + p.onError.HandleError(ses, err) + } + } + }() + + if p.canIntercept(ses.request) { + logger.Debug().Msg("intercepting connection") + ses.response = NewResponse(http.StatusOK, nil, ses.request) + err = p.writeResponse(ses) + log.OnCloseError(logger.Debug(), ses.response.Body) + if err != nil { + return + } + + // Peek first byte + b := make([]byte, 1) + if _, err = io.ReadFull(ses.ctx.rw, b); err != nil { + logger.Error().Err(err).Msg("error peeking CONNECT byte") + return + } + + // Drain buffered bytes + b = append(b, make([]byte, ses.ctx.rw.Reader.Buffered())...) + ses.ctx.rw.Reader.Read(b[1:]) + + r := &connReader{ + Conn: ses.ctx.conn, + Reader: io.MultiReader(bytes.NewBuffer(b), ses.ctx.conn), + } + if b[0] == 22 { // TLS handshake: https://tools.ietf.org/html/rfc5246#section-6.2.1 + secure := tls.Server(r, p.authority.TLSConfig(ses.request.URL.Host)) + if err = secure.Handshake(); err != nil { + logger.Error().Err(err).Msg("error intercepting TLS connection: client handshake failed") + return + } + + rw := bufio.NewReadWriter(bufio.NewReader(secure), bufio.NewWriter(secure)) + ctx := newContext(secure, rw, ses) + return p.handleRequest(ctx) + } + + rw := bufio.NewReadWriter(bufio.NewReader(r), bufio.NewWriter(r)) + ctx := newContext(r, rw, ses) + return p.handleRequest(ctx) + } + + ses.response = NewResponse(http.StatusOK, nil, ses.request) + defer log.OnCloseError(logger.Debug(), ses.response.Body) + ses.response.ContentLength = -1 + if err = p.writeResponse(ses); err != nil { + return + } + + logger.Debug().Msg("established CONNECT tunnel, proxying traffic") + var wait sync.WaitGroup + wait.Go(func() { copyStream(ses, c, ses.ctx.conn) }) + wait.Go(func() { copyStream(ses, ses.ctx.conn, c) }) + wait.Wait() + logger.Debug().Msg("closed CONNECT tunnel") + return +} + +func (p *Proxy) handleTunnel(ses *Session) (err error) { + logger := ses.log() + logger.Debug().Msgf("connecting to %s", ses.request.URL.Host) + + var c net.Conn + if c, err = p.connect(ses, "tcp", ses.request.URL.Host); err != nil { + logger.Error().Err(err).Msg("connect failed") + if p.onError != nil { + p.onError.HandleError(ses, err) + } + + ses.response = ErrorResponse(ses.request, err) + defer log.OnCloseError(logger.Debug(), ses.response.Body) + _ = p.writeResponse(ses) + + return + } + + defer log.OnCloseError(logger.Debug(), c) + + if ses.ctx.IsTLS() { + // Open a TLS client connection + secure := tls.Client(c, &tls.Config{ + ServerName: ses.request.URL.Host, + GetClientCertificate: func(cri *tls.CertificateRequestInfo) (*tls.Certificate, error) { + return nil, ErrClientCert + }, + }) + if err = secure.Handshake(); err != nil { + logger.Error().Err(err).Msg("TLS handshake failed") + return + } + c = secure + } + + if err = ses.request.Write(c); err != nil { + logger.Error().Err(err).Msg("failed to write request") + return + } + + logger.Debug().Msg("established tunnel, proxying traffic") + var wait sync.WaitGroup + wait.Go(func() { copyStream(ses, c, ses.ctx.conn) }) + wait.Go(func() { copyStream(ses, ses.ctx.conn, c) }) + wait.Wait() + logger.Debug().Msg("closed tunnel") + return +} + +func (p *Proxy) canIntercept(request *http.Request) bool { + if permit := p.policy.PermitIntercept(request); permit != nil { + return *permit + } + return true +} + +/* +func (p *Proxy) handleAPIRequest(ses *Session) error { + if ses.request.URL.Path == "/ca.crt" && p.authority != nil { + b := pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: p.authority.Certificate().Raw, + }) + + ses.response = NewResponse(http.StatusOK, bytes.NewReader(b), ses.request) + defer log.OnCloseError(logger.Debug(), ses.response.Body) + + ses.response.Close = true + ses.response.Header.Set("Content-Type", "application/x-x509-ca-cert") + ses.response.ContentLength = int64(len(b)) + return p.writeResponse(ses) + } + + ses.response = ErrorResponse(ses.request, errors.New("invalid API endpoint")) + defer log.OnCloseError(logger.Debug(), ses.response.Body) + ses.response.Close = true + return p.writeResponse(ses) +} +*/ + +func (p *Proxy) readRequest(ctx *Context) (request *http.Request, err error) { + var ( + done = make(chan *http.Request, 1) + errs = make(chan error, 1) + ) + + go func() { + r, err := http.ReadRequest(ctx.rw.Reader) + if err != nil { + errs <- err + } else { + done <- r + } + }() + + select { + case <-p.closed: + return nil, ErrClosed + case request = <-done: + return + case err = <-errs: + return + } +} + +func (p *Proxy) cleanRequest(ses *Session, request *http.Request) { + if request.URL.Host == "" { + request.URL.Host = request.Host + } + + // Ensure proper URL scheme + if !strings.HasPrefix(request.URL.Scheme, "http") { + request.URL.Scheme = "http" + } + if ses.ctx.IsTLS() { + state := ses.ctx.conn.Conn.(*tls.Conn).ConnectionState() + request.TLS = &state + request.URL.Scheme = "https" + } + + // Ensure proper RemoteAddr + request.RemoteAddr = ses.ctx.RemoteAddr().String() + + // Ensure proper encoding + if request.Header.Get(HeaderAcceptEncoding) != "" { + // We only support gzip + request.Header.Set(HeaderAcceptEncoding, "gzip") + } +} + +func (p *Proxy) writeResponse(ses *Session) (err error) { + log := ses.log() + + if p.onResponse != nil { + response := p.onResponse.HandleResponse(ses) + if response != nil { + log.Debug().Str("status", response.Status).Msg("response override") + ses.response = response + } + } + + if err = ses.response.Write(ses.ctx); err != nil { + log.Error().Err(err).Msg("error writing response back to client") + } else if err = ses.ctx.Flush(); err != nil { + log.Error().Err(err).Msg("error flushing response back to client") + } + + return +} + +func (p *Proxy) connect(ses *Session, network, address string) (c net.Conn, err error) { + log := ses.log() + log.Debug().Msgf("connect to %s://%s", network, address) + + if p.onConnect != nil { + if c = p.onConnect.HandleConnect(ses, network, address); c != nil { + log.Debug().Msg("connect override") + return + } + } + + var host, port string + if host, port, err = net.SplitHostPort(address); err != nil { + return + } + + var hosts []string + if hosts, err = p.resolver.Lookup(context.Background(), host); err != nil { + log.Warn().Err(err).Msg("connect failed: DNS lookup error") + return + } + + log.Debug().Str("address", hosts[0]).Msg("connect resolved address") + return p.dial(network, net.JoinHostPort(hosts[0], port)) +} + +var hopByHopHeaders = []string{ + HeaderConnection, + "Keep-Alive", + "Proxy-Authenticate", + "Proxy-Authorization", + "Proxy-Connection", // Non-standard, but required for HTTP/2. + "Te", + "Trailer", + "Transfer-Encoding", + HeaderUpgrade, +} + +func cleanHopByHopHeaders(header http.Header) { + // Additional hop-by-hop headers may be specified in `Connection` headers. + // http://tools.ietf.org/html/draft-ietf-httpbis-p1-messaging-14#section-9.1 + for _, values := range header[HeaderConnection] { + for _, key := range strings.Split(values, ",") { + header.Del(key) + } + } + for _, key := range hopByHopHeaders { + header.Del(key) + } +} + +// copyStream copies data from reader to writer +func copyStream(ses *Session, w io.Writer, r io.Reader) { + log := ses.log() + if _, err := io.Copy(w, r); err != nil && !isClosing(err) { + log.Error().Err(err).Msg("failed CONNECT tunnel") + } else { + log.Debug().Msg("finished copying CONNECT tunnel") + } +} + +func isClosing(err error) bool { + if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) || errors.Is(err, syscall.ECONNRESET) || err == ErrClosed { + return true + } + if err, ok := err.(net.Error); ok && err.Timeout() { + return true + } + // log.Debug().Msgf("not a closing error %T: %#+v", err, err) + return false +} + +func resolveInterfaceAddr(name string) (err error) { + var iface *net.Interface + if iface, err = net.InterfaceByName(name); err != nil { + return + } + + var addrs []net.Addr + if addrs, err = iface.Addrs(); err != nil { + return + } + + for _, addr := range addrs { + if addr, ok := addr.(*net.IPNet); ok && !addr.IP.IsUnspecified() { + log.Warn().Msgf("addr %T: %s", addr, addr) + } + } + return errors.New("nope; TODO") +} diff --git a/proxy/resolver/resolver.go b/proxy/resolver/resolver.go new file mode 100644 index 0000000..9d43c30 --- /dev/null +++ b/proxy/resolver/resolver.go @@ -0,0 +1,148 @@ +// Package resolver implements a caching DNS resolver +package resolver + +import ( + "context" + "math/rand/v2" + "net" + "strings" + "time" + + "git.maze.io/maze/styx/internal/netutil" + "github.com/hashicorp/golang-lru/v2/expirable" +) + +const ( + DefaultSize = 1024 + DefaultTTL = 5 * time.Minute + DefaultTimeout = 10 * time.Second +) + +var ( + // DefaultConfig are the defaults for the Default resolver. + DefaultConfig = Config{ + Size: DefaultSize, + TTL: DefaultTTL.Seconds(), + Timeout: DefaultTimeout.Seconds(), + } + + // Default resolver. + Default = New(DefaultConfig) +) + +type Resolver interface { + // Lookup returns resolved IPs for given hostname/ips. + Lookup(context.Context, string) ([]string, error) +} + +type netResolver struct { + resolver *net.Resolver + timeout time.Duration + noIPv6 bool + cache *expirable.LRU[string, []string] +} + +type Config struct { + // Size is our cache size in number of entries. + Size int `hcl:"size,optional"` + + // TTL is the cache time to live in seconds. + TTL float64 `hcl:"ttl,optional"` + + // Timeout is the cache timeout in seconds. + Timeout float64 `hcl:"timeout,optional"` + + // Server are alternative DNS servers. + Server []string `hcl:"server,optional"` + + // NoIPv6 disables IPv6 DNS resolution. + NoIPv6 bool `hcl:"noipv6,optional"` +} + +func New(config Config) Resolver { + var ( + size = config.Size + ttl = time.Duration(float64(time.Second) * config.TTL) + timeout = time.Duration(float64(time.Second) * config.Timeout) + ) + if size <= 0 { + size = DefaultSize + } + if ttl <= 0 { + ttl = DefaultTTL + } + if timeout <= 0 { + timeout = 0 + } + + var resolver = new(net.Resolver) + if len(config.Server) > 0 { + var dialer net.Dialer + resolver.Dial = func(ctx context.Context, network, address string) (net.Conn, error) { + server := netutil.EnsurePort(config.Server[rand.IntN(len(config.Server))], "53") + return dialer.DialContext(ctx, network, server) + } + } + + return &netResolver{ + resolver: resolver, + timeout: timeout, + noIPv6: config.NoIPv6, + cache: expirable.NewLRU[string, []string](size, nil, ttl), + } +} + +func (r *netResolver) Lookup(ctx context.Context, host string) ([]string, error) { + host = strings.ToLower(strings.TrimSpace(host)) + if hosts, ok := r.cache.Get(host); ok { + rand.Shuffle(len(hosts), func(i, j int) { + hosts[i], hosts[j] = hosts[j], hosts[i] + }) + return hosts, nil + } + + hosts, err := r.lookup(ctx, host) + if err != nil { + return nil, err + } + r.cache.Add(host, hosts) + return hosts, nil +} + +func (r *netResolver) lookup(ctx context.Context, host string) ([]string, error) { + if r.timeout > 0 { + var cancel func() + ctx, cancel = context.WithTimeout(ctx, r.timeout) + defer cancel() + } + + if net.ParseIP(host) == nil { + addrs, err := r.resolver.LookupHost(ctx, host) + if err != nil { + return nil, err + } + if r.noIPv6 { + var addrs4 []string + for _, addr := range addrs { + if net.ParseIP(addr).To4() != nil { + addrs4 = append(addrs4, addr) + } + } + return addrs4, nil + } + return addrs, nil + } + + addrs, err := r.resolver.LookupIPAddr(ctx, host) + if err != nil { + return nil, err + } + + hosts := make([]string, len(addrs)) + for i, addr := range addrs { + if !r.noIPv6 || addr.IP.To4() != nil { + hosts[i] = addr.IP.String() + } + } + return hosts, nil +} diff --git a/proxy/response.go b/proxy/response.go new file mode 100644 index 0000000..a4168d0 --- /dev/null +++ b/proxy/response.go @@ -0,0 +1,78 @@ +package proxy + +import ( + "bytes" + "fmt" + "io" + "net/http" + "os" + "strconv" + + "git.maze.io/maze/styx/internal/log" +) + +func NewResponse(code int, body io.Reader, request *http.Request) *http.Response { + if body == nil { + body = new(bytes.Buffer) + } + + rc, ok := body.(io.ReadCloser) + if !ok { + rc = io.NopCloser(body) + } + + response := &http.Response{ + Status: strconv.Itoa(code) + " " + http.StatusText(code), + StatusCode: code, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: make(http.Header), + Body: rc, + Request: request, + } + + if request != nil { + response.Close = request.Close + response.Proto = request.Proto + response.ProtoMajor = request.ProtoMajor + response.ProtoMinor = request.ProtoMinor + } + + return response +} + +type withLen interface { + Len() int +} + +type withSize interface { + Size() int64 +} + +func NewJSONResponse(code int, body io.Reader, request *http.Request) *http.Response { + response := NewResponse(code, body, request) + response.Header.Set(HeaderContentType, "application/json") + if s, ok := body.(withLen); ok { + response.Header.Set(HeaderContentLength, strconv.Itoa(s.Len())) + } else if s, ok := body.(withSize); ok { + response.Header.Set(HeaderContentLength, strconv.FormatInt(s.Size(), 10)) + } else { + log.Trace().Str("type", fmt.Sprintf("%T", body)).Msg("can't detemine body size") + } + response.Close = true + return response +} + +func ErrorResponse(request *http.Request, err error) *http.Response { + response := NewResponse(http.StatusBadGateway, nil, request) + switch { + case os.IsNotExist(err): + response.StatusCode = http.StatusNotFound + case os.IsPermission(err): + response.StatusCode = http.StatusForbidden + } + response.Status = http.StatusText(response.StatusCode) + response.Close = true + return response +} diff --git a/proxy/session.go b/proxy/session.go new file mode 100644 index 0000000..de7ab6d --- /dev/null +++ b/proxy/session.go @@ -0,0 +1,151 @@ +package proxy + +import ( + "bufio" + "crypto/tls" + "encoding/binary" + "encoding/hex" + "math/rand" + "net" + "net/http" + "sync/atomic" + "time" + + "git.maze.io/maze/styx/internal/log" +) + +var seed = rand.NewSource(time.Now().UnixNano()) + +type Context struct { + id int64 + conn *wrappedConn + rw *bufio.ReadWriter + parent *Session + data map[string]any +} + +func newContext(conn net.Conn, rw *bufio.ReadWriter, parent *Session) *Context { + if wrapped, ok := conn.(*wrappedConn); ok { + conn = wrapped.Conn + } + + ctx := &Context{ + id: seed.Int63(), + conn: &wrappedConn{Conn: conn}, + rw: rw, + parent: parent, + data: make(map[string]any), + } + + return ctx +} + +func (ctx *Context) log() log.Logger { + return log.Console.With(). + Str("context", ctx.ID()). + Str("addr", ctx.RemoteAddr().String()). + Logger() +} + +func (ctx *Context) ID() string { + var b [8]byte + binary.BigEndian.PutUint64(b[:], uint64(ctx.id)) + if ctx.parent != nil { + return ctx.parent.ID() + "-" + hex.EncodeToString(b[:]) + } + return hex.EncodeToString(b[:]) +} + +func (ctx *Context) IsTLS() bool { + _, ok := ctx.conn.Conn.(*tls.Conn) + return ok && ctx.parent != nil +} + +func (ctx *Context) RemoteAddr() net.Addr { + if ctx.parent != nil { + return ctx.parent.ctx.RemoteAddr() + } + return ctx.conn.RemoteAddr() +} + +func (ctx *Context) SetDeadline(t time.Time) error { + if ctx.parent != nil { + return ctx.parent.ctx.SetDeadline(t) + } + return ctx.conn.SetDeadline(t) +} + +func (ctx *Context) Set(key string, value any) { + ctx.data[key] = value +} + +func (ctx *Context) Get(key string) (value any, ok bool) { + value, ok = ctx.data[key] + return +} + +func (ctx *Context) Flush() error { + return ctx.rw.Flush() +} + +func (ctx *Context) Write(p []byte) (n int, err error) { + if n, err = ctx.rw.Write(p); n > 0 { + atomic.AddInt64(&ctx.conn.bytes, int64(n)) + } + return +} + +type Session struct { + id int64 + ctx *Context + request *http.Request + response *http.Response + data map[string]any +} + +func newSession(ctx *Context, request *http.Request) *Session { + return &Session{ + id: seed.Int63(), + ctx: ctx, + request: request, + data: make(map[string]any), + } +} + +func (ses *Session) log() log.Logger { + return log.Console.With(). + Str("context", ses.ctx.ID()). + Str("session", ses.ID()). + Str("addr", ses.ctx.RemoteAddr().String()). + Logger() +} + +func (ses *Session) ID() string { + var b [8]byte + binary.BigEndian.PutUint64(b[:], uint64(ses.id)) + return hex.EncodeToString(b[:]) +} + +func (ses *Session) Context() *Context { + return ses.ctx +} + +func (ses *Session) Request() *http.Request { + return ses.request +} + +func (ses *Session) Response() *http.Response { + return ses.response +} + +type wrappedConn struct { + net.Conn + bytes int64 +} + +func (c *wrappedConn) Write(p []byte) (n int, err error) { + if n, err = c.Conn.Write(p); n > 0 { + atomic.AddInt64(&c.bytes, int64(n)) + } + return +} diff --git a/proxy/stats/stats.go b/proxy/stats/stats.go new file mode 100644 index 0000000..183a898 --- /dev/null +++ b/proxy/stats/stats.go @@ -0,0 +1,225 @@ +package stats + +import ( + "database/sql" + "database/sql/driver" + "encoding/json" + "fmt" + "net/http" + "os" + "os/user" + "path/filepath" + "time" + + "git.maze.io/maze/styx/internal/log" + _ "github.com/mattn/go-sqlite3" +) + +type Stats struct { + db *sql.DB +} + +func New() (*Stats, error) { + u, err := user.Current() + if err != nil { + return nil, err + } + + path := filepath.Join(u.HomeDir, ".styx", "stats.db") + if err = os.MkdirAll(filepath.Dir(path), 0o750); err != nil { + return nil, err + } + + db, err := sql.Open("sqlite3", path+"?_journal_mode=WAL") + if err != nil { + return nil, err + } + + for _, table := range []string{ + createLog, + createDomainStat, + createStatusStat, + } { + if _, err = db.Exec(table); err != nil { + return nil, err + } + } + + return &Stats{db: db}, nil +} + +func (s *Stats) AddLog(entry *Log) error { + var ( + request []byte + response []byte + err error + ) + if request, err = json.Marshal(entry.Request); err != nil { + return err + } + if response, err = json.Marshal(entry.Response); err != nil { + return err + } + + tx, err := s.db.Begin() + if err != nil { + return err + } + stmt, err := tx.Prepare("insert into styx_log(client_ip, request, response) values(?, ?, ?)") + if err != nil { + return err + } + defer stmt.Close() + if _, err = stmt.Exec(entry.ClientIP, request, response); err != nil { + return err + } + return tx.Commit() +} + +func (s *Stats) QueryLog(offset, limit int) ([]*Log, error) { + if limit == 0 { + limit = 50 + } + + rows, err := s.db.Query("select dt, client_ip, request, response from styx_log limit ?, ?", offset, limit) + if err != nil { + return nil, err + } + defer rows.Close() + + var logs []*Log + for rows.Next() { + var entry = new(Log) + if err = rows.Scan(&entry.Time, &entry.ClientIP, &entry.Request, &entry.Response); err != nil { + return nil, err + } + logs = append(logs, entry) + } + + return logs, nil +} + +type Status struct { + Code int `json:"code"` + Count int `json:"count"` +} + +var timeZero time.Time + +func (s *Stats) QueryStatus(since time.Time) ([]*Status, error) { + if since.Equal(timeZero) { + since = time.Now().Add(-24 * time.Hour) + } + + rows, err := s.db.Query("select response->'status', count(*) from styx_log where dt >= ? group by response->'status' order by response->'status'", since) + if err != nil { + return nil, err + } + + var stats []*Status + for rows.Next() { + var entry = new(Status) + if err = rows.Scan(&entry.Code, &entry.Count); err != nil { + return nil, err + } + stats = append(stats, entry) + } + return stats, nil +} + +const createLog = `CREATE TABLE IF NOT EXISTS styx_log ( + id INT PRIMARY KEY, + dt DATETIME DEFAULT CURRENT_TIMESTAMP, + client_ip TEXT NOT NULL, + request JSONB NOT NULL, + response JSONB NOT NULL +);` + +type Log struct { + Time time.Time `json:"time"` + ClientIP string `json:"client_ip"` + Request *Request `json:"request"` + Response *Response `json:"response"` +} + +type Request struct { + URL string `json:"url"` + Host string `json:"host"` + Method string `json:"method"` + Proto string `json:"proto"` + Header http.Header `json:"header"` +} + +func (r *Request) Scan(value any) error { + switch v := value.(type) { + case string: + return json.Unmarshal([]byte(v), r) + case []byte: + return json.Unmarshal(v, r) + default: + log.Error().Str("type", fmt.Sprintf("%T", value)).Msg("scan request unknown type") + return nil + } +} + +func (r *Request) Value() (driver.Value, error) { + b, err := json.Marshal(r) + return string(b), err +} + +func FromRequest(r *http.Request) *Request { + return &Request{ + URL: r.URL.String(), + Host: r.Host, + Method: r.Method, + Proto: r.Proto, + Header: r.Header, + } +} + +type Response struct { + Status int `json:"status"` + Size int64 `json:"size"` + Header http.Header `json:"header"` +} + +func (r *Response) Scan(value any) error { + switch v := value.(type) { + case string: + return json.Unmarshal([]byte(v), r) + case []byte: + return json.Unmarshal(v, r) + default: + log.Error().Str("type", fmt.Sprintf("%T", value)).Msg("scan response unknown type") + return nil + } +} + +func (r *Response) Value() (driver.Value, error) { + b, err := json.Marshal(r) + return string(b), err +} + +func (r *Response) SetSize(size int64) *Response { + r.Size = size + return r +} + +func FromResponse(r *http.Response) *Response { + return &Response{ + Status: r.StatusCode, + Header: r.Header, + } +} + +const createStatusStat = `CREATE TABLE IF NOT EXISTS styx_stat_status ( + id INT PRIMARY KEY, + dt DATETIME DEFAULT CURRENT_TIMESTAMP, + status INT NOT NULL +);` + +const createDomainStat = `CREATE TABLE IF NOT EXISTS styx_stat_domain ( + id INT PRIMARY KEY, + dt DATETIME DEFAULT CURRENT_TIMESTAMP, + domain TEXT NOT NULL +);` diff --git a/proxy/util.go b/proxy/util.go new file mode 100644 index 0000000..8f1a4ae --- /dev/null +++ b/proxy/util.go @@ -0,0 +1,16 @@ +package proxy + +import ( + "io" + "net" +) + +// connReader is a net.Conn with a separate reader. +type connReader struct { + net.Conn + io.Reader +} + +func (c connReader) Read(p []byte) (int, error) { + return c.Reader.Read(p) +} diff --git a/styx.hcl b/styx.hcl new file mode 100644 index 0000000..4dbd57d --- /dev/null +++ b/styx.hcl @@ -0,0 +1,154 @@ + +proxy { + # TCP listen address + listen = ":3128" + + # TCP bind address for outgoing connections + #bind = "10.42.42.215" + # Interface for outgoign connections + #interface = "en1" + + # Upstream proxies + upstream = [] + + + policy { + on intercept { + domain = ["sensitive"] + permit = false + } + + on request { + source = ["kids"] + domain = ["nsfw"] + permit = false + } + + on request { + source = ["kids"] + domain = ["nsfw"] + permit = false + } + + on days { + days = "mon-thu,sun" + on time { + time = ["22:00", "06:00"] + on request { + source = ["kids"] + domain = ["social"] + permit = false + } + } + } + } +} + +dns { + # Set the cache size + #size = 1024 + + # Set the time to live for positive responses (in seconds) + #ttl = 300 + + # Set the resolve timeout (in seconds) + #timeout = 10 + + # Set the DNS servers + #servers = ["1.1.1.1", "8.8.8.8"] + + # Disable IPv6 + noipv6 = true +} + +mitm { + ca { + cert = "testdata/ca.crt" + key = "testdata/ca.key" + key_type = "ecc" + days = 1825 + organization = "maze.io" + } + + key { + type = "rsa" + bits = 2048 + } + + cache { + #type = "memory" + type = "disk" + path = "testdata/mitm" + expire = 10 + } +} + +cache { + type = "memory" + size = 10485760 +} + +match { + path = "testdata/match" + + network "internal" { + type = "list" + list = [ + "0.0.0.0/32", + "127.0.0.0/8", + "169.254.0.0/16", + "fe80::/10", + ] + } + + network "kids" { + type = "list" + list = ["10.42.66.0/24"] + } + + domain "sensitive" { + type = "list" + list = [ + # Banking + "abnamro.nl", + "knab.nl", + "rabobank.nl", + + # Government + "belastingdienst.nl", + "digid.nl", + + # Messaging + "signal.org", + "telegram.org", + "whatsapp.net", + "whatsapp.com", + ] + } + + domain "social" { + type = "list" + list = [ + "pinterest.com", + "reddit.com", + "x.com", + # YouTube + "googlevideo.com", + "youtube.com", + "youtu.be", + "ytimg.com", + ] + } + + domain "nsfw" { + type = "domains" + from = "https://energized.pro/nsfw/domains.txt" + refresh = 43200 # 12h + } + + domain "ads" { + type = "detect" + from = "https://small.oisd.nl/dnsmasq" + refresh = 12 + } +} diff --git a/testdata/ca.crt b/testdata/ca.crt new file mode 100644 index 0000000..f308dae --- /dev/null +++ b/testdata/ca.crt @@ -0,0 +1,12 @@ +-----BEGIN CERTIFICATE----- +MIIBrjCCAVSgAwIBAgIQYEfQcIZJ90sXXLyE1F0gpzAKBggqhkjOPQQDAjA3MRAw +DgYDVQQKEwdtYXplLmlvMSMwIQYDVQQDExpTdHl4IENlcnRpZmljYXRlIEF1dGhv +cml0eTAeFw0yNTA5MjQwMDAwMDBaFw0zMDA5MjMwMDAwMDBaMDcxEDAOBgNVBAoT +B21hemUuaW8xIzAhBgNVBAMTGlN0eXggQ2VydGlmaWNhdGUgQXV0aG9yaXR5MFkw +EwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEMS3tcysM9OjDLrZNTp2Nw5jqsPcrfaGW +jBsPACynhhNx8oKYrRjabbbZsqXQiBbEeFw75U+CS82WGS+c7DpttaNCMEAwDgYD +VR0PAQH/BAQDAgIEMA8GA1UdEwEB/wQFMAMBAf8wHQYDVR0OBBYEFKHEYd+Lckg0 +ywh26MypID6hLse2MAoGCCqGSM49BAMCA0gAMEUCIQCwNrBAa0W9lHIQ9xy0+402 +QH/xlaz1xDDFwMINQ54r0AIgDp7E2jmbwa45zC1DJVXVJuHS+8XGcgP+LdvzhPV2 +J70= +-----END CERTIFICATE----- diff --git a/testdata/ca.key b/testdata/ca.key new file mode 100644 index 0000000..2d880cb --- /dev/null +++ b/testdata/ca.key @@ -0,0 +1,5 @@ +-----BEGIN EC PRIVATE KEY----- +MHcCAQEEIL/DOgsInoOhgVZ24VIf7dfHSyyuj57KQw8vPl1Gs2imoAoGCCqGSM49 +AwEHoUQDQgAEMS3tcysM9OjDLrZNTp2Nw5jqsPcrfaGWjBsPACynhhNx8oKYrRja +bbbZsqXQiBbEeFw75U+CS82WGS+c7DpttQ== +-----END EC PRIVATE KEY-----