Initial import
This commit is contained in:
		
							
								
								
									
										5
									
								
								README.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										5
									
								
								README.md
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,5 @@
 | 
			
		||||
# Styx
 | 
			
		||||
 | 
			
		||||

 | 
			
		||||
 | 
			
		||||
Styx is a filtering HTTP proxy.
 | 
			
		||||
							
								
								
									
										85
									
								
								cmd/styx/main.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										85
									
								
								cmd/styx/main.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,85 @@
 | 
			
		||||
package main
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"flag"
 | 
			
		||||
	"os"
 | 
			
		||||
	"os/signal"
 | 
			
		||||
	"syscall"
 | 
			
		||||
 | 
			
		||||
	"github.com/hashicorp/hcl/v2/hclsimple"
 | 
			
		||||
 | 
			
		||||
	"git.maze.io/maze/styx/internal/log"
 | 
			
		||||
	"git.maze.io/maze/styx/proxy"
 | 
			
		||||
	"git.maze.io/maze/styx/proxy/cache"
 | 
			
		||||
	"git.maze.io/maze/styx/proxy/match"
 | 
			
		||||
	"git.maze.io/maze/styx/proxy/mitm"
 | 
			
		||||
	"git.maze.io/maze/styx/proxy/resolver"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func main() {
 | 
			
		||||
	configFlag := flag.String("config", "styx.hcl", "Configuration file")
 | 
			
		||||
	traceFlag := flag.Bool("T", false, "Enable trace level logging")
 | 
			
		||||
	debugFlag := flag.Bool("D", false, "Enable debug level logging")
 | 
			
		||||
	flag.Parse()
 | 
			
		||||
 | 
			
		||||
	if *traceFlag {
 | 
			
		||||
		log.SetLevel(log.TraceLevel)
 | 
			
		||||
	} else if *debugFlag {
 | 
			
		||||
		log.SetLevel(log.DebugLevel)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	config, err := load(*configFlag)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		log.Fatal().Err(err).Msg("")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	matchers, err := config.Match.Matchers()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		log.Fatal().Err(err).Msg("")
 | 
			
		||||
	} else if err = config.Proxy.Policy.Configure(matchers); err != nil {
 | 
			
		||||
		log.Fatal().Err(err).Msg("")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var ca mitm.Authority
 | 
			
		||||
	if config.MITM != nil {
 | 
			
		||||
		if ca, err = mitm.New(config.MITM); err != nil {
 | 
			
		||||
			log.Fatal().Err(err).Msg("error configuring mitm")
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	server, err := proxy.New(&config.Proxy, ca)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		log.Fatal().Err(err).Msg("")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err = server.Start(); err != nil {
 | 
			
		||||
		log.Fatal().Err(err).Msg("")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	signalChannel := make(chan os.Signal, 1)
 | 
			
		||||
	signal.Notify(signalChannel, syscall.SIGINT, syscall.SIGTERM)
 | 
			
		||||
	<-signalChannel
 | 
			
		||||
 | 
			
		||||
	server.Close()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Config struct {
 | 
			
		||||
	DNS   *resolver.Config `hcl:"dns,block"`
 | 
			
		||||
	Proxy proxy.Config     `hcl:"proxy,block"`
 | 
			
		||||
	MITM  *mitm.Config     `hcl:"mitm,block"`
 | 
			
		||||
	Cache *cache.Config    `hcl:"cache,block"`
 | 
			
		||||
	Match *match.Config    `hcl:"match,block"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func load(name string) (*Config, error) {
 | 
			
		||||
	config := new(Config)
 | 
			
		||||
	if err := hclsimple.DecodeFile(name, nil, config); err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if config.DNS != nil {
 | 
			
		||||
		config.Proxy.Resolver = resolver.New(*config.DNS)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return config, nil
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										29
									
								
								go.mod
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										29
									
								
								go.mod
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,29 @@
 | 
			
		||||
module git.maze.io/maze/styx
 | 
			
		||||
 | 
			
		||||
go 1.25.0
 | 
			
		||||
 | 
			
		||||
require (
 | 
			
		||||
	github.com/hashicorp/golang-lru/v2 v2.0.7
 | 
			
		||||
	github.com/hashicorp/hcl/v2 v2.24.0
 | 
			
		||||
	github.com/miekg/dns v1.1.68
 | 
			
		||||
	github.com/rs/zerolog v1.34.0
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
require (
 | 
			
		||||
	github.com/agext/levenshtein v1.2.1 // indirect
 | 
			
		||||
	github.com/apparentlymart/go-textseg/v15 v15.0.0 // indirect
 | 
			
		||||
	github.com/google/go-cmp v0.6.0 // indirect
 | 
			
		||||
	github.com/google/uuid v1.6.0 // indirect
 | 
			
		||||
	github.com/mattn/go-colorable v0.1.13 // indirect
 | 
			
		||||
	github.com/mattn/go-isatty v0.0.19 // indirect
 | 
			
		||||
	github.com/mattn/go-sqlite3 v1.14.32 // indirect
 | 
			
		||||
	github.com/mitchellh/go-wordwrap v1.0.1 // indirect
 | 
			
		||||
	github.com/yl2chen/cidranger v1.0.2 // indirect
 | 
			
		||||
	github.com/zclconf/go-cty v1.16.3 // indirect
 | 
			
		||||
	golang.org/x/mod v0.24.0 // indirect
 | 
			
		||||
	golang.org/x/net v0.40.0 // indirect
 | 
			
		||||
	golang.org/x/sync v0.14.0 // indirect
 | 
			
		||||
	golang.org/x/sys v0.33.0 // indirect
 | 
			
		||||
	golang.org/x/text v0.25.0 // indirect
 | 
			
		||||
	golang.org/x/tools v0.33.0 // indirect
 | 
			
		||||
)
 | 
			
		||||
							
								
								
									
										60
									
								
								go.sum
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										60
									
								
								go.sum
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,60 @@
 | 
			
		||||
github.com/agext/levenshtein v1.2.1 h1:QmvMAjj2aEICytGiWzmxoE0x2KZvE0fvmqMOfy2tjT8=
 | 
			
		||||
github.com/agext/levenshtein v1.2.1/go.mod h1:JEDfjyjHDjOF/1e4FlBE/PkbqA9OfWu2ki2W0IB5558=
 | 
			
		||||
github.com/apparentlymart/go-textseg/v15 v15.0.0 h1:uYvfpb3DyLSCGWnctWKGj857c6ew1u1fNQOlOtuGxQY=
 | 
			
		||||
github.com/apparentlymart/go-textseg/v15 v15.0.0/go.mod h1:K8XmNZdhEBkdlyDdvbmmsvpAG721bKi0joRfFdHIWJ4=
 | 
			
		||||
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
 | 
			
		||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
 | 
			
		||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
 | 
			
		||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
 | 
			
		||||
github.com/go-test/deep v1.0.3 h1:ZrJSEWsXzPOxaZnFteGEfooLba+ju3FYIbOrS+rQd68=
 | 
			
		||||
github.com/go-test/deep v1.0.3/go.mod h1:wGDj63lr65AM2AQyKZd/NYHGb0R+1RLqB8NKt3aSFNA=
 | 
			
		||||
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
 | 
			
		||||
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
 | 
			
		||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
 | 
			
		||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
 | 
			
		||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
 | 
			
		||||
github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k=
 | 
			
		||||
github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM=
 | 
			
		||||
github.com/hashicorp/hcl/v2 v2.24.0 h1:2QJdZ454DSsYGoaE6QheQZjtKZSUs9Nh2izTWiwQxvE=
 | 
			
		||||
github.com/hashicorp/hcl/v2 v2.24.0/go.mod h1:oGoO1FIQYfn/AgyOhlg9qLC6/nOJPX3qGbkZpYAcqfM=
 | 
			
		||||
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
 | 
			
		||||
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
 | 
			
		||||
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
 | 
			
		||||
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
 | 
			
		||||
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
 | 
			
		||||
github.com/mattn/go-sqlite3 v1.14.32 h1:JD12Ag3oLy1zQA+BNn74xRgaBbdhbNIDYvQUEuuErjs=
 | 
			
		||||
github.com/mattn/go-sqlite3 v1.14.32/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
 | 
			
		||||
github.com/miekg/dns v1.1.68 h1:jsSRkNozw7G/mnmXULynzMNIsgY2dHC8LO6U6Ij2JEA=
 | 
			
		||||
github.com/miekg/dns v1.1.68/go.mod h1:fujopn7TB3Pu3JM69XaawiU0wqjpL9/8xGop5UrTPps=
 | 
			
		||||
github.com/mitchellh/go-wordwrap v1.0.1 h1:TLuKupo69TCn6TQSyGxwI1EblZZEsQ0vMlAFQflz0v0=
 | 
			
		||||
github.com/mitchellh/go-wordwrap v1.0.1/go.mod h1:R62XHJLzvMFRBbcrT7m7WgmE1eOyTSsCt+hzestvNj0=
 | 
			
		||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
 | 
			
		||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
 | 
			
		||||
github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0=
 | 
			
		||||
github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY=
 | 
			
		||||
github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ=
 | 
			
		||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
 | 
			
		||||
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
 | 
			
		||||
github.com/yl2chen/cidranger v1.0.2 h1:lbOWZVCG1tCRX4u24kuM1Tb4nHqWkDxwLdoS+SevawU=
 | 
			
		||||
github.com/yl2chen/cidranger v1.0.2/go.mod h1:9U1yz7WPYDwf0vpNWFaeRh0bjwz5RVgRy/9UEQfHl0g=
 | 
			
		||||
github.com/zclconf/go-cty v1.16.3 h1:osr++gw2T61A8KVYHoQiFbFd1Lh3JOCXc/jFLJXKTxk=
 | 
			
		||||
github.com/zclconf/go-cty v1.16.3/go.mod h1:VvMs5i0vgZdhYawQNq5kePSpLAoz8u1xvZgrPIxfnZE=
 | 
			
		||||
github.com/zclconf/go-cty-debug v0.0.0-20240509010212-0d6042c53940 h1:4r45xpDWB6ZMSMNJFMOjqrGHynW3DIBuR2H9j0ug+Mo=
 | 
			
		||||
github.com/zclconf/go-cty-debug v0.0.0-20240509010212-0d6042c53940/go.mod h1:CmBdvvj3nqzfzJ6nTCIwDTPZ56aVGvDrmztiO5g3qrM=
 | 
			
		||||
golang.org/x/mod v0.24.0 h1:ZfthKaKaT4NrhGVZHO1/WDTwGES4De8KtWO0SIbNJMU=
 | 
			
		||||
golang.org/x/mod v0.24.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww=
 | 
			
		||||
golang.org/x/net v0.40.0 h1:79Xs7wF06Gbdcg4kdCCIQArK11Z1hr5POQ6+fIYHNuY=
 | 
			
		||||
golang.org/x/net v0.40.0/go.mod h1:y0hY0exeL2Pku80/zKK7tpntoX23cqL3Oa6njdgRtds=
 | 
			
		||||
golang.org/x/sync v0.14.0 h1:woo0S4Yywslg6hp4eUFjTVOyKt0RookbpAHG4c1HmhQ=
 | 
			
		||||
golang.org/x/sync v0.14.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
 | 
			
		||||
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 | 
			
		||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 | 
			
		||||
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 | 
			
		||||
golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw=
 | 
			
		||||
golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
 | 
			
		||||
golang.org/x/text v0.25.0 h1:qVyWApTSYLk/drJRO5mDlNYskwQznZmkpV2c8q9zls4=
 | 
			
		||||
golang.org/x/text v0.25.0/go.mod h1:WEdwpYrmk1qmdHvhkSTNPm3app7v4rsT8F2UD6+VHIA=
 | 
			
		||||
golang.org/x/tools v0.33.0 h1:4qz2S3zmRxbGIhDIAgjxvFutSvH5EfnsYrRBj0UI0bc=
 | 
			
		||||
golang.org/x/tools v0.33.0/go.mod h1:CIJMaWEY88juyUfo7UbgPqbC8rU2OqfAV1h2Qp0oMYI=
 | 
			
		||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
 | 
			
		||||
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
 | 
			
		||||
							
								
								
									
										114
									
								
								internal/cryptutil/key.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										114
									
								
								internal/cryptutil/key.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,114 @@
 | 
			
		||||
package cryptutil
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"crypto"
 | 
			
		||||
	"crypto/ecdsa"
 | 
			
		||||
	"crypto/ed25519"
 | 
			
		||||
	"crypto/elliptic"
 | 
			
		||||
	"crypto/rsa"
 | 
			
		||||
	"crypto/sha1"
 | 
			
		||||
	"crypto/x509"
 | 
			
		||||
	"encoding/pem"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"os"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type publicKeyer interface {
 | 
			
		||||
	Public() any
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// PublicKey returns the public part of a crypto.PrivateKey.
 | 
			
		||||
func PublicKey(key crypto.PrivateKey) any {
 | 
			
		||||
	switch key := key.(type) {
 | 
			
		||||
	case ed25519.PublicKey:
 | 
			
		||||
		return key
 | 
			
		||||
	case ed25519.PrivateKey:
 | 
			
		||||
		return key.Public()
 | 
			
		||||
	case ecdsa.PublicKey:
 | 
			
		||||
		return &key
 | 
			
		||||
	case *ecdsa.PublicKey:
 | 
			
		||||
		return key
 | 
			
		||||
	case *ecdsa.PrivateKey:
 | 
			
		||||
		return &key.PublicKey
 | 
			
		||||
	case rsa.PublicKey:
 | 
			
		||||
		return &key
 | 
			
		||||
	case *rsa.PublicKey:
 | 
			
		||||
		return key
 | 
			
		||||
	case *rsa.PrivateKey:
 | 
			
		||||
		return &key.PublicKey
 | 
			
		||||
	default:
 | 
			
		||||
		if p, ok := key.(publicKeyer); ok {
 | 
			
		||||
			return p.Public()
 | 
			
		||||
		}
 | 
			
		||||
		panic(fmt.Sprintf("don't know how to extract a public key from %T", key))
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// LoadPrivateKey loads a private key from disk.
 | 
			
		||||
func LoadPrivateKey(name string) (crypto.PrivateKey, error) {
 | 
			
		||||
	b, err := os.ReadFile(name)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	return decodePEMPrivateKey(b)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func decodePEMPrivateKey(b []byte) (key crypto.PrivateKey, err error) {
 | 
			
		||||
	var (
 | 
			
		||||
		rest  = b
 | 
			
		||||
		block *pem.Block
 | 
			
		||||
	)
 | 
			
		||||
	for {
 | 
			
		||||
		if block, rest = pem.Decode(rest); block == nil {
 | 
			
		||||
			return nil, errors.New("mitm: no private key PEM block could be decoded")
 | 
			
		||||
		}
 | 
			
		||||
		switch block.Type {
 | 
			
		||||
		case "EC PRIVATE KEY":
 | 
			
		||||
			return x509.ParseECPrivateKey(block.Bytes)
 | 
			
		||||
 | 
			
		||||
		case "RSA PRIVATE KEY":
 | 
			
		||||
			return x509.ParsePKCS1PrivateKey(block.Bytes)
 | 
			
		||||
 | 
			
		||||
		case "PRIVATE KEY":
 | 
			
		||||
			return x509.ParsePKCS8PrivateKey(block.Bytes)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GenerateKeyID generates the PKIX public key ID.
 | 
			
		||||
func GenerateKeyID(key crypto.PublicKey) []byte {
 | 
			
		||||
	b, err := x509.MarshalPKIXPublicKey(key)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
	return sha1.New().Sum(b)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func keyType(key any) string {
 | 
			
		||||
	switch key := key.(type) {
 | 
			
		||||
	case ed25519.PrivateKey:
 | 
			
		||||
		return "ed25519"
 | 
			
		||||
	case *ecdsa.PrivateKey:
 | 
			
		||||
		return "ecdsa (" + curveType(key.Curve) + ")"
 | 
			
		||||
	case *rsa.PrivateKey:
 | 
			
		||||
		return "rsa"
 | 
			
		||||
	default:
 | 
			
		||||
		return fmt.Sprintf("%T", key)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func curveType(c elliptic.Curve) string {
 | 
			
		||||
	switch c {
 | 
			
		||||
	case elliptic.P224():
 | 
			
		||||
		return "p224"
 | 
			
		||||
	case elliptic.P256():
 | 
			
		||||
		return "p256"
 | 
			
		||||
	case elliptic.P384():
 | 
			
		||||
		return "p384"
 | 
			
		||||
	case elliptic.P521():
 | 
			
		||||
		return "p521"
 | 
			
		||||
	default:
 | 
			
		||||
		return fmt.Sprintf("%T", c)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										242
									
								
								internal/cryptutil/x509.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										242
									
								
								internal/cryptutil/x509.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,242 @@
 | 
			
		||||
package cryptutil
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"crypto"
 | 
			
		||||
	"crypto/ecdsa"
 | 
			
		||||
	"crypto/ed25519"
 | 
			
		||||
	"crypto/elliptic"
 | 
			
		||||
	"crypto/rand"
 | 
			
		||||
	"crypto/rsa"
 | 
			
		||||
	"crypto/x509"
 | 
			
		||||
	"crypto/x509/pkix"
 | 
			
		||||
	"encoding/pem"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"math/big"
 | 
			
		||||
	"os"
 | 
			
		||||
	"path/filepath"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"git.maze.io/maze/styx/internal/log"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// Supported key types.
 | 
			
		||||
const (
 | 
			
		||||
	TypeRSA     = "rsa"
 | 
			
		||||
	TypeECDSA   = "ecdsa"
 | 
			
		||||
	TypeED25519 = "ed25519"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// Supported PEM block types.
 | 
			
		||||
const (
 | 
			
		||||
	pemTypeCert  = "CERTIFICATE"
 | 
			
		||||
	pemTypeRSA   = "RSA PRIVATE KEY"
 | 
			
		||||
	pemTypeECDSA = "EC PRIVATE KEY"
 | 
			
		||||
	pemTypeAny   = "PRIVATE KEY"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// LoadKeyPair loads a certificate and private key, certdata and keydata can be a PEM encoded block or a file.
 | 
			
		||||
//
 | 
			
		||||
// If [keydata] is empty, then the private key is assumed to be contained in [certdata].
 | 
			
		||||
func LoadKeyPair(certdata, keydata string) (cert *x509.Certificate, key crypto.PrivateKey, err error) {
 | 
			
		||||
	if keydata == "" {
 | 
			
		||||
		keydata = certdata
 | 
			
		||||
	}
 | 
			
		||||
	if strings.Contains(certdata, "-----BEGIN "+pemTypeCert) {
 | 
			
		||||
		log.Trace().Msg("parsing X.509 certificate")
 | 
			
		||||
		if cert, err = decodePEMBCertificate([]byte(certdata)); err != nil {
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
	} else {
 | 
			
		||||
		log.Trace().Str("name", certdata).Msg("loading X.509 certificate")
 | 
			
		||||
		if cert, err = LoadCertificate(certdata); err != nil {
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	if strings.Contains(keydata, pemTypeAny+"-----") {
 | 
			
		||||
		log.Trace().Msg("parsing private key")
 | 
			
		||||
		if key, err = decodePEMPrivateKey([]byte(keydata)); err != nil {
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
	} else if key, err = LoadPrivateKey(keydata); err != nil {
 | 
			
		||||
		log.Trace().Str("name", keydata).Msg("loading private key")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// SaveKeyPair saves a certificate and private key in PEM encoding.
 | 
			
		||||
//
 | 
			
		||||
// If [keyFile] is empty, then the private key is stored in [certFile] alongside the certificate.
 | 
			
		||||
//
 | 
			
		||||
// Attempts are made to use secure file modes for files that contains private keys.
 | 
			
		||||
func SaveKeyPair(cert *x509.Certificate, key crypto.PrivateKey, certFile, keyFile string) (err error) {
 | 
			
		||||
	var (
 | 
			
		||||
		keyDER     []byte
 | 
			
		||||
		keyPEMType = pemTypeAny
 | 
			
		||||
	)
 | 
			
		||||
	switch key := key.(type) {
 | 
			
		||||
	case *ecdsa.PrivateKey:
 | 
			
		||||
		if keyDER, err = x509.MarshalECPrivateKey(key); err != nil {
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		keyPEMType = pemTypeECDSA
 | 
			
		||||
	case ed25519.PrivateKey:
 | 
			
		||||
		if keyDER, err = x509.MarshalPKCS8PrivateKey(key); err != nil {
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
	case *rsa.PrivateKey:
 | 
			
		||||
		keyDER = x509.MarshalPKCS1PrivateKey(key)
 | 
			
		||||
		keyPEMType = pemTypeRSA
 | 
			
		||||
	default:
 | 
			
		||||
		return fmt.Errorf("mitm: don't know how to marshal %T", key)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var certf, keyf *os.File
 | 
			
		||||
	if certf, err = os.OpenFile(certFile, os.O_CREATE|os.O_TRUNC|os.O_RDWR, 0o644); err != nil {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	defer func() { _ = certf.Close() }()
 | 
			
		||||
 | 
			
		||||
	if filepath.Clean(certFile) == filepath.Clean(keyFile) || keyFile == "" {
 | 
			
		||||
		if err = certf.Chmod(0o600); err != nil {
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		keyf, keyFile = certf, certFile
 | 
			
		||||
	} else {
 | 
			
		||||
		if keyf, err = os.OpenFile(keyFile, os.O_CREATE|os.O_TRUNC|os.O_RDWR, 0o600); err != nil {
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		defer func() { _ = keyf.Close() }()
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	log.Debug().Str("file", certFile).Msg("saving X.509 certificate")
 | 
			
		||||
	if err = pem.Encode(certf, &pem.Block{Type: pemTypeCert, Bytes: cert.Raw}); err != nil {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	log.Debug().Str("fiile", keyFile).Msg("saving private key")
 | 
			
		||||
	if err = pem.Encode(keyf, &pem.Block{Type: keyPEMType, Bytes: keyDER}); err != nil {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GenerateKeyPair generates a private key and self-signed certificate.
 | 
			
		||||
func GenerateKeyPair(name pkix.Name, days int, keyType string, keyBits int) (cert *x509.Certificate, key crypto.PrivateKey, err error) {
 | 
			
		||||
	if key, err = GeneratePrivateKey(keyType, keyBits); err != nil {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	if cert, err = GenerateCertificateAuthority(name, days, key); err != nil {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GenerateCertificateAuthority(name pkix.Name, days int, key crypto.PrivateKey) (cert *x509.Certificate, err error) {
 | 
			
		||||
	serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
 | 
			
		||||
	serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, fmt.Errorf("mtim: failed to generate serial number: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	keyUsage := x509.KeyUsageCertSign
 | 
			
		||||
	if _, ok := key.(*rsa.PrivateKey); ok {
 | 
			
		||||
		keyUsage |= x509.KeyUsageDigitalSignature
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	notBefore := roundToDay(time.Now())
 | 
			
		||||
	notAfter := notBefore.Add(time.Duration(days) * 24 * time.Hour)
 | 
			
		||||
 | 
			
		||||
	template := &x509.Certificate{
 | 
			
		||||
		Subject:               name,
 | 
			
		||||
		SerialNumber:          serialNumber,
 | 
			
		||||
		KeyUsage:              keyUsage,
 | 
			
		||||
		SubjectKeyId:          GenerateKeyID(key),
 | 
			
		||||
		IsCA:                  true,
 | 
			
		||||
		BasicConstraintsValid: true,
 | 
			
		||||
		NotBefore:             notBefore,
 | 
			
		||||
		NotAfter:              notAfter,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	log.Info().
 | 
			
		||||
		Str("name", name.CommonName).
 | 
			
		||||
		Int("days", days).
 | 
			
		||||
		Str("key", keyType(key)).
 | 
			
		||||
		Str("serial", serialNumber.String()).
 | 
			
		||||
		Msg("generating self-signed CA certificate")
 | 
			
		||||
 | 
			
		||||
	var der []byte
 | 
			
		||||
	if der, err = x509.CreateCertificate(rand.Reader, template, template, PublicKey(key), key); err != nil {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return x509.ParseCertificate(der)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GeneratePrivateKey(kind string, bits int) (key crypto.PrivateKey, err error) {
 | 
			
		||||
	switch strings.ToLower(kind) {
 | 
			
		||||
	case TypeRSA, "":
 | 
			
		||||
		if bits == 0 {
 | 
			
		||||
			bits = 2048
 | 
			
		||||
		}
 | 
			
		||||
		log.Trace().Int("bits", bits).Str("type", TypeRSA).Msg("generating private key")
 | 
			
		||||
		return rsa.GenerateKey(rand.Reader, bits)
 | 
			
		||||
 | 
			
		||||
	case TypeECDSA, "ec", "ecc":
 | 
			
		||||
		if bits == 0 {
 | 
			
		||||
			bits = 256
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		var curve elliptic.Curve
 | 
			
		||||
		switch bits {
 | 
			
		||||
		case 224:
 | 
			
		||||
			curve = elliptic.P224()
 | 
			
		||||
		case 256:
 | 
			
		||||
			curve = elliptic.P256()
 | 
			
		||||
		case 384:
 | 
			
		||||
			curve = elliptic.P384()
 | 
			
		||||
		case 521:
 | 
			
		||||
			curve = elliptic.P521()
 | 
			
		||||
		default:
 | 
			
		||||
			return nil, fmt.Errorf("mitm: elliptic curve %d bits not supported", bits)
 | 
			
		||||
		}
 | 
			
		||||
		log.Trace().Int("bits", bits).Str("type", TypeECDSA).Msg("generating private key")
 | 
			
		||||
		return ecdsa.GenerateKey(curve, rand.Reader)
 | 
			
		||||
 | 
			
		||||
	case TypeED25519:
 | 
			
		||||
		log.Trace().Str("type", TypeED25519).Msg("generating ED25519 private key")
 | 
			
		||||
		_, key, err = ed25519.GenerateKey(rand.Reader)
 | 
			
		||||
		return
 | 
			
		||||
 | 
			
		||||
	default:
 | 
			
		||||
		return nil, fmt.Errorf("mitm: don't know how to generate %s private key", kind)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func decodePEMBCertificate(b []byte) (cert *x509.Certificate, err error) {
 | 
			
		||||
	var (
 | 
			
		||||
		rest  = b
 | 
			
		||||
		block *pem.Block
 | 
			
		||||
	)
 | 
			
		||||
	for {
 | 
			
		||||
		if block, rest = pem.Decode(rest); block == nil {
 | 
			
		||||
			return nil, errors.New("mitm: no CERTIFICATE PEM block could be decoded")
 | 
			
		||||
		} else if block.Type == "CERTIFICATE" {
 | 
			
		||||
			return x509.ParseCertificate(block.Bytes)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func LoadCertificate(name string) (*x509.Certificate, error) {
 | 
			
		||||
	b, err := os.ReadFile(name)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	return decodePEMBCertificate(b)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func roundToDay(t time.Time) time.Time {
 | 
			
		||||
	return time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, time.UTC)
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										44
									
								
								internal/log/log.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										44
									
								
								internal/log/log.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,44 @@
 | 
			
		||||
package log
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"io"
 | 
			
		||||
 | 
			
		||||
	"github.com/rs/zerolog"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// Aliases
 | 
			
		||||
const (
 | 
			
		||||
	TraceLevel = zerolog.TraceLevel
 | 
			
		||||
	DebugLevel = zerolog.DebugLevel
 | 
			
		||||
	InfoLevel  = zerolog.InfoLevel
 | 
			
		||||
	WarnLevel  = zerolog.WarnLevel
 | 
			
		||||
	FatalLevel = zerolog.FatalLevel
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// Aliases
 | 
			
		||||
type (
 | 
			
		||||
	Event  = zerolog.Event
 | 
			
		||||
	Logger = zerolog.Logger
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// Console logger.
 | 
			
		||||
var Console = zerolog.New(zerolog.NewConsoleWriter()).With().Timestamp().Logger()
 | 
			
		||||
 | 
			
		||||
func SetLevel(level zerolog.Level) {
 | 
			
		||||
	zerolog.SetGlobalLevel(level)
 | 
			
		||||
	//Console = Console.Level(level)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func Trace() *Event { return Console.Trace() }
 | 
			
		||||
func Debug() *Event { return Console.Debug() }
 | 
			
		||||
func Info() *Event  { return Console.Info() }
 | 
			
		||||
func Warn() *Event  { return Console.Warn() }
 | 
			
		||||
func Error() *Event { return Console.Error() }
 | 
			
		||||
func Fatal() *Event { return Console.Fatal() }
 | 
			
		||||
func Panic() *Event { return Console.Panic() }
 | 
			
		||||
 | 
			
		||||
func OnCloseError(event *Event, closer io.Closer) {
 | 
			
		||||
	if err := closer.Close(); err != nil {
 | 
			
		||||
		event.Err(err).Msg("close failed")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										35
									
								
								internal/netutil/addr.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										35
									
								
								internal/netutil/addr.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,35 @@
 | 
			
		||||
package netutil
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"net"
 | 
			
		||||
	"strconv"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// EnsurePort makes sure the address in [host] contains a port.
 | 
			
		||||
func EnsurePort(host, port string) string {
 | 
			
		||||
	if _, _, err := net.SplitHostPort(host); err == nil {
 | 
			
		||||
		return host
 | 
			
		||||
	}
 | 
			
		||||
	return net.JoinHostPort(host, port)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Host returns the bare host (without port).
 | 
			
		||||
func Host(name string) string {
 | 
			
		||||
	host, _, err := net.SplitHostPort(name)
 | 
			
		||||
	if err == nil {
 | 
			
		||||
		return host
 | 
			
		||||
	}
 | 
			
		||||
	return name
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Port returns the port number.
 | 
			
		||||
func Port(name string) int {
 | 
			
		||||
	_, port, err := net.SplitHostPort(name)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return 0
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// TODO: name resolution for ports?
 | 
			
		||||
	i, _ := strconv.Atoi(port)
 | 
			
		||||
	return i
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										99
									
								
								internal/netutil/domain.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										99
									
								
								internal/netutil/domain.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,99 @@
 | 
			
		||||
package netutil
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"strings"
 | 
			
		||||
 | 
			
		||||
	"github.com/miekg/dns"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type DomainTree struct {
 | 
			
		||||
	root *domainTreeNode
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type domainTreeNode struct {
 | 
			
		||||
	leaf  map[string]*domainTreeNode
 | 
			
		||||
	isEnd bool
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewDomainList(domains ...string) *DomainTree {
 | 
			
		||||
	tree := &DomainTree{
 | 
			
		||||
		root: &domainTreeNode{leaf: make(map[string]*domainTreeNode)},
 | 
			
		||||
	}
 | 
			
		||||
	for _, domain := range domains {
 | 
			
		||||
		tree.Add(domain)
 | 
			
		||||
	}
 | 
			
		||||
	return tree
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (tree *DomainTree) Add(domain string) {
 | 
			
		||||
	domain = normalizeDomain(domain)
 | 
			
		||||
	if domain == "" {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	labels := dns.SplitDomainName(domain)
 | 
			
		||||
	if len(labels) == 0 {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	node := tree.root
 | 
			
		||||
	for i := len(labels) - 1; i >= 0; i-- {
 | 
			
		||||
		label := labels[i]
 | 
			
		||||
		if label == "" {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		if node.leaf == nil {
 | 
			
		||||
			node.leaf = make(map[string]*domainTreeNode)
 | 
			
		||||
		}
 | 
			
		||||
		if node.leaf[label] == nil {
 | 
			
		||||
			node.leaf[label] = &domainTreeNode{}
 | 
			
		||||
		}
 | 
			
		||||
		node = node.leaf[label]
 | 
			
		||||
	}
 | 
			
		||||
	node.isEnd = true
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (tree *DomainTree) Contains(domain string) bool {
 | 
			
		||||
	domain = normalizeDomain(domain)
 | 
			
		||||
	if domain == "" {
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	labels := dns.SplitDomainName(domain)
 | 
			
		||||
	if len(labels) == 0 {
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	node := tree.root
 | 
			
		||||
	for i := len(labels) - 1; i >= 0; i-- {
 | 
			
		||||
		if node.isEnd {
 | 
			
		||||
			return true
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if node.leaf == nil {
 | 
			
		||||
			return false
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		label := labels[i]
 | 
			
		||||
		if node = node.leaf[label]; node == nil {
 | 
			
		||||
			return false
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return node.isEnd
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func normalizeDomain(domain string) string {
 | 
			
		||||
	domain = strings.ToLower(strings.TrimSpace(domain))
 | 
			
		||||
	if domain == "" {
 | 
			
		||||
		return ""
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Remove trailing dot if present, dns.Fqdn will add it back properly
 | 
			
		||||
	domain = strings.TrimSuffix(domain, ".")
 | 
			
		||||
 | 
			
		||||
	if domain == "" {
 | 
			
		||||
		return ""
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return dns.Fqdn(domain)
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										276
									
								
								internal/netutil/domain_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										276
									
								
								internal/netutil/domain_test.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,276 @@
 | 
			
		||||
package netutil
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"testing"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func TestDomainList(t *testing.T) {
 | 
			
		||||
	tests := []struct {
 | 
			
		||||
		name     string
 | 
			
		||||
		domains  []string
 | 
			
		||||
		hostname string
 | 
			
		||||
		expected bool
 | 
			
		||||
	}{
 | 
			
		||||
		// Basic exact matches
 | 
			
		||||
		{
 | 
			
		||||
			name:     "exact match",
 | 
			
		||||
			domains:  []string{"example.com"},
 | 
			
		||||
			hostname: "example.com",
 | 
			
		||||
			expected: true,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:     "exact match with subdomain in list",
 | 
			
		||||
			domains:  []string{"api.example.com"},
 | 
			
		||||
			hostname: "api.example.com",
 | 
			
		||||
			expected: true,
 | 
			
		||||
		},
 | 
			
		||||
 | 
			
		||||
		// Suffix matching - if domain is in list, all subdomains should match
 | 
			
		||||
		{
 | 
			
		||||
			name:     "subdomain matches parent domain",
 | 
			
		||||
			domains:  []string{"example.com"},
 | 
			
		||||
			hostname: "sub.example.com",
 | 
			
		||||
			expected: true,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:     "multiple subdomain levels match",
 | 
			
		||||
			domains:  []string{"example.com"},
 | 
			
		||||
			hostname: "deep.nested.sub.example.com",
 | 
			
		||||
			expected: true,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:     "subdomain matches intermediate domain",
 | 
			
		||||
			domains:  []string{"api.example.com", "example.com"},
 | 
			
		||||
			hostname: "sub.api.example.com",
 | 
			
		||||
			expected: true,
 | 
			
		||||
		},
 | 
			
		||||
 | 
			
		||||
		// Multi-level TLDs
 | 
			
		||||
		{
 | 
			
		||||
			name:     "co.uk domain exact match",
 | 
			
		||||
			domains:  []string{"domain.co.uk"},
 | 
			
		||||
			hostname: "domain.co.uk",
 | 
			
		||||
			expected: true,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:     "subdomain of co.uk domain",
 | 
			
		||||
			domains:  []string{"domain.co.uk"},
 | 
			
		||||
			hostname: "sub.domain.co.uk",
 | 
			
		||||
			expected: true,
 | 
			
		||||
		},
 | 
			
		||||
 | 
			
		||||
		// Case sensitivity
 | 
			
		||||
		{
 | 
			
		||||
			name:     "case insensitive match",
 | 
			
		||||
			domains:  []string{"Example.COM"},
 | 
			
		||||
			hostname: "example.com",
 | 
			
		||||
			expected: true,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:     "case insensitive hostname",
 | 
			
		||||
			domains:  []string{"example.com"},
 | 
			
		||||
			hostname: "EXAMPLE.COM",
 | 
			
		||||
			expected: true,
 | 
			
		||||
		},
 | 
			
		||||
 | 
			
		||||
		// Trailing dots
 | 
			
		||||
		{
 | 
			
		||||
			name:     "domain with trailing dot",
 | 
			
		||||
			domains:  []string{"example.com."},
 | 
			
		||||
			hostname: "example.com",
 | 
			
		||||
			expected: true,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:     "hostname with trailing dot",
 | 
			
		||||
			domains:  []string{"example.com"},
 | 
			
		||||
			hostname: "example.com.",
 | 
			
		||||
			expected: true,
 | 
			
		||||
		},
 | 
			
		||||
 | 
			
		||||
		// Non-matches
 | 
			
		||||
		{
 | 
			
		||||
			name:     "different TLD",
 | 
			
		||||
			domains:  []string{"example.com"},
 | 
			
		||||
			hostname: "example.org",
 | 
			
		||||
			expected: false,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:     "different domain",
 | 
			
		||||
			domains:  []string{"example.com"},
 | 
			
		||||
			hostname: "test.com",
 | 
			
		||||
			expected: false,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:     "partial match but not suffix",
 | 
			
		||||
			domains:  []string{"example.com"},
 | 
			
		||||
			hostname: "com",
 | 
			
		||||
			expected: false,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:     "empty hostname",
 | 
			
		||||
			domains:  []string{"example.com"},
 | 
			
		||||
			hostname: "",
 | 
			
		||||
			expected: false,
 | 
			
		||||
		},
 | 
			
		||||
 | 
			
		||||
		// Multiple domains in list
 | 
			
		||||
		{
 | 
			
		||||
			name:     "matches first domain in list",
 | 
			
		||||
			domains:  []string{"test.org", "example.com"},
 | 
			
		||||
			hostname: "example.com",
 | 
			
		||||
			expected: true,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:     "matches second domain in list",
 | 
			
		||||
			domains:  []string{"test.org", "example.com"},
 | 
			
		||||
			hostname: "test.org",
 | 
			
		||||
			expected: true,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:     "subdomain matches any domain in list",
 | 
			
		||||
			domains:  []string{"test.org", "example.com"},
 | 
			
		||||
			hostname: "sub.example.com",
 | 
			
		||||
			expected: true,
 | 
			
		||||
		},
 | 
			
		||||
 | 
			
		||||
		// Edge cases
 | 
			
		||||
		{
 | 
			
		||||
			name:     "empty domain list",
 | 
			
		||||
			domains:  []string{},
 | 
			
		||||
			hostname: "example.com",
 | 
			
		||||
			expected: false,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:     "invalid domain in list",
 | 
			
		||||
			domains:  []string{""},
 | 
			
		||||
			hostname: "example.com",
 | 
			
		||||
			expected: false,
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, tt := range tests {
 | 
			
		||||
		t.Run(tt.name, func(t *testing.T) {
 | 
			
		||||
			list := NewDomainList(tt.domains...)
 | 
			
		||||
			result := list.Contains(tt.hostname)
 | 
			
		||||
 | 
			
		||||
			if result != tt.expected {
 | 
			
		||||
				t.Errorf("Contains(%q) = %v, expected %v (domains: %v)",
 | 
			
		||||
					tt.hostname, result, tt.expected, tt.domains)
 | 
			
		||||
			}
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestDomainList_Performance(t *testing.T) {
 | 
			
		||||
	// Test with a large number of domains to ensure performance
 | 
			
		||||
	domains := make([]string, 1000)
 | 
			
		||||
	for i := 0; i < 1000; i++ {
 | 
			
		||||
		domains[i] = string(rune('a'+(i%26))) + ".com"
 | 
			
		||||
	}
 | 
			
		||||
	domains = append(domains, "example.com") // Add our test domain
 | 
			
		||||
 | 
			
		||||
	list := NewDomainList(domains...)
 | 
			
		||||
 | 
			
		||||
	// These should be fast even with many domains
 | 
			
		||||
	if !list.Contains("example.com") {
 | 
			
		||||
		t.Error("Should match exact domain")
 | 
			
		||||
	}
 | 
			
		||||
	if !list.Contains("sub.example.com") {
 | 
			
		||||
		t.Error("Should match subdomain")
 | 
			
		||||
	}
 | 
			
		||||
	if list.Contains("notfound.com") {
 | 
			
		||||
		t.Error("Should not match unrelated domain")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestDomainList_ComplexDomains(t *testing.T) {
 | 
			
		||||
	domains := []string{
 | 
			
		||||
		"very.long.domain.name.with.many.labels.com",
 | 
			
		||||
		"example.co.uk",
 | 
			
		||||
		"sub.domain.example.com",
 | 
			
		||||
		"a.b.c.d.e.f.com",
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	list := NewDomainList(domains...)
 | 
			
		||||
 | 
			
		||||
	tests := []struct {
 | 
			
		||||
		hostname string
 | 
			
		||||
		expected bool
 | 
			
		||||
	}{
 | 
			
		||||
		{"very.long.domain.name.with.many.labels.com", true},
 | 
			
		||||
		{"sub.very.long.domain.name.with.many.labels.com", true},
 | 
			
		||||
		{"example.co.uk", true},
 | 
			
		||||
		{"www.example.co.uk", true},
 | 
			
		||||
		{"sub.domain.example.com", true},
 | 
			
		||||
		{"another.sub.domain.example.com", true},
 | 
			
		||||
		{"a.b.c.d.e.f.com", true},
 | 
			
		||||
		{"x.a.b.c.d.e.f.com", true},
 | 
			
		||||
		{"not.matching.com", false},
 | 
			
		||||
		{"com", false},
 | 
			
		||||
		{"uk", false},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, tt := range tests {
 | 
			
		||||
		t.Run(tt.hostname, func(t *testing.T) {
 | 
			
		||||
			result := list.Contains(tt.hostname)
 | 
			
		||||
			if result != tt.expected {
 | 
			
		||||
				t.Errorf("Contains(%q) = %v, expected %v", tt.hostname, result, tt.expected)
 | 
			
		||||
			}
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestDomainList_SpecialCases(t *testing.T) {
 | 
			
		||||
	t.Run("domain with asterisk treated literally", func(t *testing.T) {
 | 
			
		||||
		list := NewDomainList("*.example.com")
 | 
			
		||||
 | 
			
		||||
		// The asterisk should be treated as a literal label, not a wildcard
 | 
			
		||||
		if !list.Contains("*.example.com") {
 | 
			
		||||
			t.Error("Asterisk should be treated literally, not as wildcard")
 | 
			
		||||
		}
 | 
			
		||||
		if list.Contains("test.example.com") {
 | 
			
		||||
			t.Error("Should not match subdomain with literal asterisk domain")
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	t.Run("domains with hyphens and numbers", func(t *testing.T) {
 | 
			
		||||
		list := NewDomainList("test-123.example.com", "123abc.org")
 | 
			
		||||
 | 
			
		||||
		if !list.Contains("test-123.example.com") {
 | 
			
		||||
			t.Error("Should match domain with hyphens and numbers")
 | 
			
		||||
		}
 | 
			
		||||
		if !list.Contains("sub.test-123.example.com") {
 | 
			
		||||
			t.Error("Should match subdomain of hyphenated domain")
 | 
			
		||||
		}
 | 
			
		||||
		if !list.Contains("123abc.org") {
 | 
			
		||||
			t.Error("Should match domain starting with numbers")
 | 
			
		||||
		}
 | 
			
		||||
		if !list.Contains("www.123abc.org") {
 | 
			
		||||
			t.Error("Should match subdomain of numeric domain")
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func BenchmarkDomainList(b *testing.B) {
 | 
			
		||||
	// Benchmark with realistic domain list
 | 
			
		||||
	domains := []string{
 | 
			
		||||
		"google.com",
 | 
			
		||||
		"github.com",
 | 
			
		||||
		"example.org",
 | 
			
		||||
		"sub.domain.com",
 | 
			
		||||
		"api.service.co.uk",
 | 
			
		||||
		"very.long.domain.name.example.com",
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	list := NewDomainList(domains...)
 | 
			
		||||
 | 
			
		||||
	b.ResetTimer()
 | 
			
		||||
	for b.Loop() {
 | 
			
		||||
		// Mix of matches and non-matches
 | 
			
		||||
		list.Contains("sub.example.org")
 | 
			
		||||
		list.Contains("api.github.com")
 | 
			
		||||
		list.Contains("nonexistent.com")
 | 
			
		||||
		list.Contains("deep.nested.sub.domain.com")
 | 
			
		||||
		list.Contains("service.co.uk")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										44
									
								
								internal/netutil/network.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										44
									
								
								internal/netutil/network.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,44 @@
 | 
			
		||||
package netutil
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"net"
 | 
			
		||||
 | 
			
		||||
	"github.com/yl2chen/cidranger"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type NetworkTree struct {
 | 
			
		||||
	ranger cidranger.Ranger
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewNetworkTree(networks ...string) (*NetworkTree, error) {
 | 
			
		||||
	tree := &NetworkTree{
 | 
			
		||||
		ranger: cidranger.NewPCTrieRanger(),
 | 
			
		||||
	}
 | 
			
		||||
	for _, cidr := range networks {
 | 
			
		||||
		if err := tree.AddCIDR(cidr); err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return tree, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (tree *NetworkTree) Add(ipnet *net.IPNet) {
 | 
			
		||||
	if ipnet == nil {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	tree.ranger.Insert(cidranger.NewBasicRangerEntry(*ipnet))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (tree *NetworkTree) AddCIDR(cidr string) error {
 | 
			
		||||
	_, ipnet, err := net.ParseCIDR(cidr)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	tree.ranger.Insert(cidranger.NewBasicRangerEntry(*ipnet))
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (tree *NetworkTree) Contains(ip net.IP) bool {
 | 
			
		||||
	contains, _ := tree.ranger.Contains(ip)
 | 
			
		||||
	return contains
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										410
									
								
								internal/netutil/network_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										410
									
								
								internal/netutil/network_test.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,410 @@
 | 
			
		||||
package netutil
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"net"
 | 
			
		||||
	"testing"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func TestNewNetworkTree(t *testing.T) {
 | 
			
		||||
	// Test empty creation
 | 
			
		||||
	nl, err := NewNetworkTree()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatalf("NewNetworkTree() failed: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
	if nl == nil {
 | 
			
		||||
		t.Fatal("NewNetworkTree() returned nil")
 | 
			
		||||
	}
 | 
			
		||||
	if nl.ranger == nil {
 | 
			
		||||
		t.Error("NetworkTree ranger should not be nil")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Test creation with networks
 | 
			
		||||
	nl, err = NewNetworkTree("192.168.1.0/24", "10.0.0.0/8")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatalf("NewNetworkTree() with networks failed: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
	if nl == nil {
 | 
			
		||||
		t.Fatal("NewNetworkTree() with networks returned nil")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestNewNetworkTree_InvalidNetworks(t *testing.T) {
 | 
			
		||||
	// Test with invalid network
 | 
			
		||||
	_, err := NewNetworkTree("invalid-cidr")
 | 
			
		||||
	if err == nil {
 | 
			
		||||
		t.Error("NewNetworkTree() with invalid CIDR should have failed")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Test with mix of valid and invalid networks
 | 
			
		||||
	_, err = NewNetworkTree("192.168.1.0/24", "invalid-cidr", "10.0.0.0/8")
 | 
			
		||||
	if err == nil {
 | 
			
		||||
		t.Error("NewNetworkTree() with mixed valid/invalid CIDRs should have failed")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestNetworkTree_AddCIDR_Valid(t *testing.T) {
 | 
			
		||||
	nl, err := NewNetworkTree()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatalf("NewNetworkTree() failed: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	tests := []struct {
 | 
			
		||||
		cidr string
 | 
			
		||||
		desc string
 | 
			
		||||
	}{
 | 
			
		||||
		{"192.168.1.0/24", "IPv4 CIDR"},
 | 
			
		||||
		{"10.0.0.0/8", "IPv4 large range"},
 | 
			
		||||
		{"2001:db8::/32", "IPv6 CIDR"},
 | 
			
		||||
		{"::1/128", "IPv6 localhost"},
 | 
			
		||||
		{"0.0.0.0/0", "IPv4 entire internet"},
 | 
			
		||||
		{"::/0", "IPv6 entire internet"},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, tt := range tests {
 | 
			
		||||
		t.Run(tt.desc, func(t *testing.T) {
 | 
			
		||||
			if err := nl.AddCIDR(tt.cidr); err != nil {
 | 
			
		||||
				t.Errorf("AddCIDR(%q) failed: %v", tt.cidr, err)
 | 
			
		||||
			}
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestNetworkTree_AddCIDR_Invalid(t *testing.T) {
 | 
			
		||||
	nl, err := NewNetworkTree()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatalf("NewNetworkTree() failed: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	invalidCIDRs := []string{
 | 
			
		||||
		"invalid-cidr",
 | 
			
		||||
		"192.168.1.1",    // missing mask
 | 
			
		||||
		"192.168.1.0/33", // invalid mask for IPv4
 | 
			
		||||
		"2001:db8::/129", // invalid mask for IPv6
 | 
			
		||||
		"",
 | 
			
		||||
		"not-an-ip/24",
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, cidr := range invalidCIDRs {
 | 
			
		||||
		t.Run(cidr, func(t *testing.T) {
 | 
			
		||||
			if err := nl.AddCIDR(cidr); err == nil {
 | 
			
		||||
				t.Errorf("AddCIDR(%q) should have failed but didn't", cidr)
 | 
			
		||||
			}
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestNetworkTree_Add(t *testing.T) {
 | 
			
		||||
	nl, err := NewNetworkTree()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatalf("NewNetworkTree() failed: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	tests := []struct {
 | 
			
		||||
		cidr string
 | 
			
		||||
		desc string
 | 
			
		||||
	}{
 | 
			
		||||
		{"192.168.1.0/24", "IPv4 network"},
 | 
			
		||||
		{"2001:db8::/32", "IPv6 network"},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, tt := range tests {
 | 
			
		||||
		t.Run(tt.desc, func(t *testing.T) {
 | 
			
		||||
			_, ipNet, err := net.ParseCIDR(tt.cidr)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				t.Fatalf("ParseCIDR failed: %v", err)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// Should not panic
 | 
			
		||||
			nl.Add(ipNet)
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestNetworkTree_Contains_IPv4(t *testing.T) {
 | 
			
		||||
	nl, err := NewNetworkTree("192.168.1.0/24", "10.0.0.0/8", "172.16.0.0/12")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatalf("NewNetworkTree() failed: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	tests := []struct {
 | 
			
		||||
		ip   string
 | 
			
		||||
		want bool
 | 
			
		||||
		desc string
 | 
			
		||||
	}{
 | 
			
		||||
		// IPs that should match
 | 
			
		||||
		{"192.168.1.1", true, "in 192.168.1.0/24"},
 | 
			
		||||
		{"192.168.1.255", true, "broadcast in 192.168.1.0/24"},
 | 
			
		||||
		{"10.0.0.1", true, "in 10.0.0.0/8"},
 | 
			
		||||
		{"10.255.255.255", true, "max in 10.0.0.0/8"},
 | 
			
		||||
		{"172.16.0.1", true, "in 172.16.0.0/12"},
 | 
			
		||||
		{"172.31.255.255", true, "max in 172.16.0.0/12"},
 | 
			
		||||
 | 
			
		||||
		// IPs that should not match
 | 
			
		||||
		{"192.168.2.1", false, "outside 192.168.1.0/24"},
 | 
			
		||||
		{"11.0.0.1", false, "outside 10.0.0.0/8"},
 | 
			
		||||
		{"172.32.0.1", false, "outside 172.16.0.0/12"},
 | 
			
		||||
		{"8.8.8.8", false, "public DNS"},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, tt := range tests {
 | 
			
		||||
		t.Run(tt.desc, func(t *testing.T) {
 | 
			
		||||
			ip := net.ParseIP(tt.ip)
 | 
			
		||||
			if ip == nil {
 | 
			
		||||
				t.Fatalf("ParseIP(%q) returned nil", tt.ip)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			got := nl.Contains(ip)
 | 
			
		||||
			if got != tt.want {
 | 
			
		||||
				t.Errorf("Contains(%q) = %v, want %v", tt.ip, got, tt.want)
 | 
			
		||||
			}
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestNetworkTree_Contains_IPv6(t *testing.T) {
 | 
			
		||||
	nl, err := NewNetworkTree("2001:db8::/32", "2001:db8:abcd::/48", "::1/128")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatalf("NewNetworkTree() failed: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	tests := []struct {
 | 
			
		||||
		ip   string
 | 
			
		||||
		want bool
 | 
			
		||||
		desc string
 | 
			
		||||
	}{
 | 
			
		||||
		// IPs that should match
 | 
			
		||||
		{"2001:db8::1", true, "in 2001:db8::/32"},
 | 
			
		||||
		{"2001:db8:ffff:ffff:ffff:ffff:ffff:ffff", true, "max in 2001:db8::/32"},
 | 
			
		||||
		{"2001:db8:abcd::1", true, "in 2001:db8:abcd::/48"},
 | 
			
		||||
		{"::1", true, "localhost"},
 | 
			
		||||
 | 
			
		||||
		// IPs that should not match
 | 
			
		||||
		{"2001:db9::1", false, "outside 2001:db8::/32"},
 | 
			
		||||
		{"2001:db9:abcd::1", false, "outside 2001:db8:abcd::/48"},
 | 
			
		||||
		{"::2", false, "outside ::1/128"},
 | 
			
		||||
		{"2001:4860:4860::8888", false, "public DNS"},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, tt := range tests {
 | 
			
		||||
		t.Run(tt.desc, func(t *testing.T) {
 | 
			
		||||
			ip := net.ParseIP(tt.ip)
 | 
			
		||||
			if ip == nil {
 | 
			
		||||
				t.Fatalf("ParseIP(%q) returned nil", tt.ip)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			got := nl.Contains(ip)
 | 
			
		||||
			if got != tt.want {
 | 
			
		||||
				t.Errorf("Contains(%q) = %v, want %v", tt.ip, got, tt.want)
 | 
			
		||||
			}
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestNetworkTree_Contains_EdgeCases(t *testing.T) {
 | 
			
		||||
	nl, err := NewNetworkTree()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatalf("NewNetworkTree() failed: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Test with nil IP
 | 
			
		||||
	if nl.Contains(nil) != false {
 | 
			
		||||
		t.Error("Contains(nil) should return false")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Test with empty list
 | 
			
		||||
	ip := net.ParseIP("192.168.1.1")
 | 
			
		||||
	if nl.Contains(ip) != false {
 | 
			
		||||
		t.Error("Contains() on empty list should return false")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestNetworkTree_Contains_OverlappingRanges(t *testing.T) {
 | 
			
		||||
	nl, err := NewNetworkTree("192.168.0.0/16", "192.168.1.0/24", "192.168.1.128/25")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatalf("NewNetworkTree() failed: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// All these should match because we have overlapping ranges
 | 
			
		||||
	tests := []string{
 | 
			
		||||
		"192.168.1.1",
 | 
			
		||||
		"192.168.1.129",
 | 
			
		||||
		"192.168.2.1",
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, ipStr := range tests {
 | 
			
		||||
		t.Run(ipStr, func(t *testing.T) {
 | 
			
		||||
			ip := net.ParseIP(ipStr)
 | 
			
		||||
			if !nl.Contains(ip) {
 | 
			
		||||
				t.Errorf("Contains(%q) should return true for overlapping ranges", ipStr)
 | 
			
		||||
			}
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestNetworkTree_Contains_EntireInternet(t *testing.T) {
 | 
			
		||||
	nl, err := NewNetworkTree("0.0.0.0/0", "::/0")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatalf("NewNetworkTree() failed: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	tests := []struct {
 | 
			
		||||
		ip   string
 | 
			
		||||
		desc string
 | 
			
		||||
	}{
 | 
			
		||||
		{"192.168.1.1", "IPv4 private"},
 | 
			
		||||
		{"8.8.8.8", "IPv4 public"},
 | 
			
		||||
		{"2001:db8::1", "IPv6"},
 | 
			
		||||
		{"::1", "IPv6 localhost"},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, tt := range tests {
 | 
			
		||||
		t.Run(tt.desc, func(t *testing.T) {
 | 
			
		||||
			ip := net.ParseIP(tt.ip)
 | 
			
		||||
			if !nl.Contains(ip) {
 | 
			
		||||
				t.Errorf("Contains(%q) should return true for entire internet range", tt.ip)
 | 
			
		||||
			}
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestNetworkTree_MixedIPv4AndIPv6(t *testing.T) {
 | 
			
		||||
	nl, err := NewNetworkTree("192.168.1.0/24", "2001:db8::/32")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatalf("NewNetworkTree() failed: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Test IPv4 in IPv6 format (should still work due to normalization)
 | 
			
		||||
	ipv4InIPv6 := net.ParseIP("::ffff:192.168.1.1") // IPv4-mapped IPv6
 | 
			
		||||
	if !nl.Contains(ipv4InIPv6) {
 | 
			
		||||
		t.Error("Contains() should handle IPv4-mapped IPv6 addresses")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Regular IPv4 should work
 | 
			
		||||
	ipv4 := net.ParseIP("192.168.1.1")
 | 
			
		||||
	if !nl.Contains(ipv4) {
 | 
			
		||||
		t.Error("Contains() should handle regular IPv4 addresses")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// IPv6 should work
 | 
			
		||||
	ipv6 := net.ParseIP("2001:db8::1")
 | 
			
		||||
	if !nl.Contains(ipv6) {
 | 
			
		||||
		t.Error("Contains() should handle IPv6 addresses")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestNetworkTree_Add_InvalidIPNet(t *testing.T) {
 | 
			
		||||
	nl, err := NewNetworkTree()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatalf("NewNetworkTree() failed: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Create an invalid IPNet (nil IP)
 | 
			
		||||
	invalidIPNet := &net.IPNet{
 | 
			
		||||
		IP:   nil,
 | 
			
		||||
		Mask: net.CIDRMask(24, 32),
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// This should not panic
 | 
			
		||||
	nl.Add(invalidIPNet)
 | 
			
		||||
 | 
			
		||||
	// Verify that it doesn't affect Contains results
 | 
			
		||||
	ip := net.ParseIP("192.168.1.1")
 | 
			
		||||
	if nl.Contains(ip) {
 | 
			
		||||
		t.Error("Contains() should return false after adding invalid IPNet")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestNetworkTree_InitializationWithNetworks(t *testing.T) {
 | 
			
		||||
	networks := []string{
 | 
			
		||||
		"10.0.0.0/8",
 | 
			
		||||
		"172.16.0.0/12",
 | 
			
		||||
		"192.168.0.0/16",
 | 
			
		||||
		"2001:db8::/32",
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	nl, err := NewNetworkTree(networks...)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatalf("NewNetworkTree() with multiple networks failed: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Test that all networks were added correctly
 | 
			
		||||
	testCases := []struct {
 | 
			
		||||
		ip   string
 | 
			
		||||
		want bool
 | 
			
		||||
	}{
 | 
			
		||||
		{"10.1.2.3", true},
 | 
			
		||||
		{"172.16.1.1", true},
 | 
			
		||||
		{"192.168.1.1", true},
 | 
			
		||||
		{"2001:db8::1", true},
 | 
			
		||||
		{"8.8.8.8", false},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, tc := range testCases {
 | 
			
		||||
		ip := net.ParseIP(tc.ip)
 | 
			
		||||
		if got := nl.Contains(ip); got != tc.want {
 | 
			
		||||
			t.Errorf("Contains(%q) = %v, want %v", tc.ip, got, tc.want)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func BenchmarkNetworkTree_Contains(b *testing.B) {
 | 
			
		||||
	nl, err := NewNetworkTree(
 | 
			
		||||
		"10.0.0.0/8",
 | 
			
		||||
		"172.16.0.0/12",
 | 
			
		||||
		"192.168.0.0/16",
 | 
			
		||||
		"2001:db8::/32",
 | 
			
		||||
	)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		b.Fatalf("NewNetworkTree() failed: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	testIPs := []net.IP{
 | 
			
		||||
		net.ParseIP("10.1.2.3"),
 | 
			
		||||
		net.ParseIP("192.168.1.1"),
 | 
			
		||||
		net.ParseIP("2001:db8::1"),
 | 
			
		||||
		net.ParseIP("8.8.8.8"),
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	b.ResetTimer()
 | 
			
		||||
	for i := 0; i < b.N; i++ {
 | 
			
		||||
		ip := testIPs[i%len(testIPs)]
 | 
			
		||||
		nl.Contains(ip)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func BenchmarkNetworkTree_NewNetworkTree(b *testing.B) {
 | 
			
		||||
	cidrs := []string{
 | 
			
		||||
		"10.0.0.0/8",
 | 
			
		||||
		"172.16.0.0/12",
 | 
			
		||||
		"192.168.0.0/16",
 | 
			
		||||
		"2001:db8::/32",
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	b.ResetTimer()
 | 
			
		||||
	for b.Loop() {
 | 
			
		||||
		_, err := NewNetworkTree(cidrs...)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			b.Fatalf("NewNetworkTree() failed: %v", err)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func BenchmarkNetworkTree_AddCIDR(b *testing.B) {
 | 
			
		||||
	cidrs := []string{
 | 
			
		||||
		"10.0.0.0/8",
 | 
			
		||||
		"172.16.0.0/12",
 | 
			
		||||
		"192.168.0.0/16",
 | 
			
		||||
		"2001:db8::/32",
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	b.ResetTimer()
 | 
			
		||||
	for b.Loop() {
 | 
			
		||||
		nl, err := NewNetworkTree()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			b.Fatalf("NewNetworkTree() failed: %v", err)
 | 
			
		||||
		}
 | 
			
		||||
		for _, cidr := range cidrs {
 | 
			
		||||
			nl.AddCIDR(cidr)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										145
									
								
								proxy/admin.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										145
									
								
								proxy/admin.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,145 @@
 | 
			
		||||
package proxy
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"encoding/pem"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"os"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"git.maze.io/maze/styx/internal/log"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Admin struct {
 | 
			
		||||
	*Proxy
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewAdmin(proxy *Proxy) *Admin {
 | 
			
		||||
	a := &Admin{
 | 
			
		||||
		Proxy: proxy,
 | 
			
		||||
	}
 | 
			
		||||
	return a
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Admin) handleRequest(ses *Session) error {
 | 
			
		||||
	var (
 | 
			
		||||
		logger = ses.log()
 | 
			
		||||
		err    error
 | 
			
		||||
	)
 | 
			
		||||
	switch ses.request.URL.Path {
 | 
			
		||||
	case "/ca.crt":
 | 
			
		||||
		err = a.handleCACert(ses)
 | 
			
		||||
	case "/api/v1/policy":
 | 
			
		||||
		err = a.apiPolicy(ses)
 | 
			
		||||
	case "/api/v1/policy/matcher":
 | 
			
		||||
		err = a.apiPolicyMatcher(ses)
 | 
			
		||||
	case "/api/v1/stats/log":
 | 
			
		||||
		err = a.apiStatsLog(ses)
 | 
			
		||||
	case "/api/v1/stats/status":
 | 
			
		||||
		err = a.apiStatsStatus(ses)
 | 
			
		||||
	default:
 | 
			
		||||
		if strings.HasPrefix(ses.request.URL.Path, "/api") {
 | 
			
		||||
			err = errors.New("invalid endpoint")
 | 
			
		||||
		} else {
 | 
			
		||||
			err = os.ErrNotExist
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logger.Warn().Err(err).Msg("admin error")
 | 
			
		||||
		ses.response = ErrorResponse(ses.request, err)
 | 
			
		||||
		defer log.OnCloseError(logger.Debug(), ses.response.Body)
 | 
			
		||||
		ses.response.Close = true
 | 
			
		||||
		return a.writeResponse(ses)
 | 
			
		||||
	}
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Admin) handleCACert(ses *Session) error {
 | 
			
		||||
	b := pem.EncodeToMemory(&pem.Block{
 | 
			
		||||
		Type:  "CERTIFICATE",
 | 
			
		||||
		Bytes: a.authority.Certificate().Raw,
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	ses.response = NewResponse(http.StatusOK, bytes.NewReader(b), ses.request)
 | 
			
		||||
	defer log.OnCloseError(log.Debug(), ses.response.Body)
 | 
			
		||||
 | 
			
		||||
	ses.response.Close = true
 | 
			
		||||
	ses.response.Header.Set("Content-Type", "application/x-x509-ca-cert")
 | 
			
		||||
	ses.response.ContentLength = int64(len(b))
 | 
			
		||||
	return a.writeResponse(ses)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Admin) apiPolicy(ses *Session) error {
 | 
			
		||||
	var (
 | 
			
		||||
		b = new(bytes.Buffer)
 | 
			
		||||
		e = json.NewEncoder(b)
 | 
			
		||||
	)
 | 
			
		||||
	e.SetIndent("", "  ")
 | 
			
		||||
	if err := e.Encode(a.config.Policy); err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	ses.response = NewJSONResponse(http.StatusOK, b, ses.request)
 | 
			
		||||
	defer log.OnCloseError(log.Debug(), ses.response.Body)
 | 
			
		||||
	ses.response.Close = true
 | 
			
		||||
	return a.writeResponse(ses)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Admin) apiPolicyMatcher(ses *Session) error {
 | 
			
		||||
	var (
 | 
			
		||||
		b = new(bytes.Buffer)
 | 
			
		||||
		e = json.NewEncoder(b)
 | 
			
		||||
	)
 | 
			
		||||
	e.SetIndent("", "  ")
 | 
			
		||||
	if err := e.Encode(a.config.Policy.Matchers); err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	ses.response = NewJSONResponse(http.StatusOK, b, ses.request)
 | 
			
		||||
	defer log.OnCloseError(log.Debug(), ses.response.Body)
 | 
			
		||||
	ses.response.Close = true
 | 
			
		||||
	return a.writeResponse(ses)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Admin) apiResponse(ses *Session, v any, err error) error {
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	var (
 | 
			
		||||
		b = new(bytes.Buffer)
 | 
			
		||||
		e = json.NewEncoder(b)
 | 
			
		||||
	)
 | 
			
		||||
	e.SetIndent("", "  ")
 | 
			
		||||
	if err := e.Encode(v); err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	ses.response = NewJSONResponse(http.StatusOK, b, ses.request)
 | 
			
		||||
	defer log.OnCloseError(log.Debug(), ses.response.Body)
 | 
			
		||||
	ses.response.Close = true
 | 
			
		||||
	return a.writeResponse(ses)
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Admin) apiStatsLog(ses *Session) error {
 | 
			
		||||
	var (
 | 
			
		||||
		query     = ses.request.URL.Query()
 | 
			
		||||
		offset, _ = strconv.Atoi(query.Get("offset"))
 | 
			
		||||
		limit, _  = strconv.Atoi(query.Get("limit"))
 | 
			
		||||
	)
 | 
			
		||||
	if limit > 100 {
 | 
			
		||||
		limit = 100
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	s, err := a.stats.QueryLog(offset, limit)
 | 
			
		||||
	return a.apiResponse(ses, s, err)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Admin) apiStatsStatus(ses *Session) error {
 | 
			
		||||
	s, err := a.stats.QueryStatus(time.Time{})
 | 
			
		||||
	return a.apiResponse(ses, s, err)
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										8
									
								
								proxy/cache/config.go
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										8
									
								
								proxy/cache/config.go
									
									
									
									
										vendored
									
									
										Normal file
									
								
							@@ -0,0 +1,8 @@
 | 
			
		||||
package cache
 | 
			
		||||
 | 
			
		||||
import "github.com/hashicorp/hcl/v2"
 | 
			
		||||
 | 
			
		||||
type Config struct {
 | 
			
		||||
	Type string   `hcl:"type"`
 | 
			
		||||
	Body hcl.Body `hcl:",remain"`
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										88
									
								
								proxy/config.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										88
									
								
								proxy/config.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,88 @@
 | 
			
		||||
package proxy
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"net"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"git.maze.io/maze/styx/proxy/policy"
 | 
			
		||||
	"git.maze.io/maze/styx/proxy/resolver"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type ConnectHandler interface {
 | 
			
		||||
	HandleConnect(session *Session, network, address string) net.Conn
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ConnectHandlerFunc is called when the proxy receives a new HTTP CONNECT request.
 | 
			
		||||
type ConnectHandlerFunc func(session *Session, network, address string) net.Conn
 | 
			
		||||
 | 
			
		||||
func (f ConnectHandlerFunc) HandleConnect(session *Session, network, address string) net.Conn {
 | 
			
		||||
	return f(session, network, address)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type RequestHandler interface {
 | 
			
		||||
	HandleRequest(session *Session) (*http.Request, *http.Response)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// RequestHandlerFunc is called when the proxy receives a new request.
 | 
			
		||||
type RequestHandlerFunc func(session *Session) (*http.Request, *http.Response)
 | 
			
		||||
 | 
			
		||||
func (f RequestHandlerFunc) HandleRequest(session *Session) (*http.Request, *http.Response) {
 | 
			
		||||
	return f(session)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type ResponseHandler interface {
 | 
			
		||||
	HandleResponse(session *Session) *http.Response
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ResponseHandler is called when the proxy receives a response.
 | 
			
		||||
type ResponseHandlerFunc func(session *Session) *http.Response
 | 
			
		||||
 | 
			
		||||
func (f ResponseHandlerFunc) HandleResponse(session *Session) *http.Response {
 | 
			
		||||
	return f(session)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type ErrorHandler interface {
 | 
			
		||||
	HandleError(session *Session, err error)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type ErrorHandlerFunc func(session *Session, err error)
 | 
			
		||||
 | 
			
		||||
func (f ErrorHandlerFunc) HandleError(session *Session, err error) {
 | 
			
		||||
	f(session, err)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Config struct {
 | 
			
		||||
	// Listen address.
 | 
			
		||||
	Listen string `hcl:"listen,optional"`
 | 
			
		||||
 | 
			
		||||
	// Bind address for outgoing connections.
 | 
			
		||||
	Bind string `hcl:"bind,optional"`
 | 
			
		||||
 | 
			
		||||
	// Interface for outgoing connections.
 | 
			
		||||
	Interface string `hcl:"interface,optional"`
 | 
			
		||||
 | 
			
		||||
	// Upstream proxy servers.
 | 
			
		||||
	Upstream []string `hcl:"upstream,optional"`
 | 
			
		||||
 | 
			
		||||
	// DialTimeout for establishing new connections.
 | 
			
		||||
	DialTimeout time.Duration `hcl:"dial_timeout,optional"`
 | 
			
		||||
 | 
			
		||||
	// Policy for the proxy.
 | 
			
		||||
	Policy *policy.Policy `hcl:"policy,block"`
 | 
			
		||||
 | 
			
		||||
	// Resolver for the proxy.
 | 
			
		||||
	Resolver resolver.Resolver
 | 
			
		||||
 | 
			
		||||
	ConnectHandler  ConnectHandler
 | 
			
		||||
	RequestHandler  RequestHandler
 | 
			
		||||
	ResponseHandler ResponseHandler
 | 
			
		||||
	ErrorHandler    ErrorHandler
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var (
 | 
			
		||||
	_ ConnectHandler  = (ConnectHandlerFunc)(nil)
 | 
			
		||||
	_ RequestHandler  = (RequestHandlerFunc)(nil)
 | 
			
		||||
	_ ResponseHandler = (ResponseHandlerFunc)(nil)
 | 
			
		||||
	_ ErrorHandler    = (ErrorHandlerFunc)(nil)
 | 
			
		||||
)
 | 
			
		||||
							
								
								
									
										324
									
								
								proxy/match/config.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										324
									
								
								proxy/match/config.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,324 @@
 | 
			
		||||
package match
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"net"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"os"
 | 
			
		||||
	"regexp"
 | 
			
		||||
	"slices"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"git.maze.io/maze/styx/internal/log"
 | 
			
		||||
	"git.maze.io/maze/styx/internal/netutil"
 | 
			
		||||
	"github.com/hashicorp/hcl/v2"
 | 
			
		||||
	"github.com/hashicorp/hcl/v2/gohcl"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Config struct {
 | 
			
		||||
	Path    string        `hcl:"path,optional"`
 | 
			
		||||
	Refresh time.Duration `hcl:"refresh,optional"`
 | 
			
		||||
	Domain  []*Domain     `hcl:"domain,block"`
 | 
			
		||||
	Network []*Network    `hcl:"network,block"`
 | 
			
		||||
	Content []*Content    `hcl:"content,block"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (config Config) Matchers() (Matchers, error) {
 | 
			
		||||
	all := make(Matchers)
 | 
			
		||||
	if config.Domain != nil {
 | 
			
		||||
		all["domain"] = make(map[string]Matcher)
 | 
			
		||||
		for _, domain := range config.Domain {
 | 
			
		||||
			m, err := domain.Matcher()
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return nil, fmt.Errorf("matcher domain %q invalid: %w", domain.Name, err)
 | 
			
		||||
			}
 | 
			
		||||
			all["domain"][domain.Name] = m
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	if config.Network != nil {
 | 
			
		||||
		all["network"] = make(map[string]Matcher)
 | 
			
		||||
		for _, network := range config.Network {
 | 
			
		||||
			m, err := network.Matcher(true)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return nil, fmt.Errorf("matcher network %q invalid: %w", network.Name, err)
 | 
			
		||||
			}
 | 
			
		||||
			all["network"][network.Name] = m
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return all, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Content struct {
 | 
			
		||||
	Name string   `hcl:"name,label"`
 | 
			
		||||
	Type string   `hcl:"type"`
 | 
			
		||||
	Body hcl.Body `hcl:",remain"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type contentHeader struct {
 | 
			
		||||
	Key     string   `hcl:"name"`
 | 
			
		||||
	Value   string   `hcl:"value,optional"`
 | 
			
		||||
	List    []string `hcl:"list,optional"`
 | 
			
		||||
	name    string
 | 
			
		||||
	keyRe   *regexp.Regexp
 | 
			
		||||
	valueRe *regexp.Regexp
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m contentHeader) Name() string { return m.name }
 | 
			
		||||
func (m contentHeader) MatchesResponse(r *http.Response) bool {
 | 
			
		||||
	for k, vv := range r.Header {
 | 
			
		||||
		if m.keyRe.MatchString(k) {
 | 
			
		||||
			for _, v := range vv {
 | 
			
		||||
				if slices.Contains(m.List, v) {
 | 
			
		||||
					return true
 | 
			
		||||
				}
 | 
			
		||||
				if m.valueRe != nil && m.valueRe.MatchString(v) {
 | 
			
		||||
					return true
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return false
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type contentType struct {
 | 
			
		||||
	List []string `hcl:"list"`
 | 
			
		||||
	name string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m contentType) Name() string { return m.name }
 | 
			
		||||
func (m contentType) MatchesResponse(r *http.Response) bool {
 | 
			
		||||
	return slices.Contains(m.List, r.Header.Get("Content-Type"))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type contentSizeLargerThan struct {
 | 
			
		||||
	Size int64 `hcl:"size"`
 | 
			
		||||
	name string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m contentSizeLargerThan) Name() string { return m.name }
 | 
			
		||||
func (m contentSizeLargerThan) MatchesResponse(r *http.Response) bool {
 | 
			
		||||
	size, err := strconv.ParseInt(r.Header.Get("Content-Length"), 10, 64)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
	return size >= m.Size
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type contentStatus struct {
 | 
			
		||||
	Code []int `hcl:"code"`
 | 
			
		||||
	name string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m contentStatus) Name() string { return m.name }
 | 
			
		||||
func (m contentStatus) MatchesResponse(r *http.Response) bool {
 | 
			
		||||
	return slices.Contains(m.Code, r.StatusCode)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (config Content) Matcher() (Response, error) {
 | 
			
		||||
	switch strings.ToLower(config.Type) {
 | 
			
		||||
	case "content", "contenttype", "content-type", "type":
 | 
			
		||||
		var matcher = contentType{name: config.Name}
 | 
			
		||||
		if err := gohcl.DecodeBody(config.Body, nil, &matcher); err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		return matcher, nil
 | 
			
		||||
 | 
			
		||||
	case "header":
 | 
			
		||||
		var (
 | 
			
		||||
			matcher = contentHeader{name: config.Name}
 | 
			
		||||
			err     error
 | 
			
		||||
		)
 | 
			
		||||
		if err = gohcl.DecodeBody(config.Body, nil, &matcher); err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		if matcher.Value == "" && len(matcher.List) == 0 {
 | 
			
		||||
			return nil, fmt.Errorf("invalid content %q: must contain either list or value", config.Name)
 | 
			
		||||
		}
 | 
			
		||||
		if matcher.keyRe, err = regexp.Compile(matcher.Key); err != nil {
 | 
			
		||||
			return nil, fmt.Errorf("invalid regular expression on content %q key: %w", config.Name, err)
 | 
			
		||||
		}
 | 
			
		||||
		if matcher.Value != "" {
 | 
			
		||||
			if matcher.valueRe, err = regexp.Compile(matcher.Value); err != nil {
 | 
			
		||||
				return nil, fmt.Errorf("invalid regular expression on content %q value: %w", config.Name, err)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		return matcher, nil
 | 
			
		||||
 | 
			
		||||
	case "size":
 | 
			
		||||
		var matcher = contentSizeLargerThan{name: config.Name}
 | 
			
		||||
		if err := gohcl.DecodeBody(config.Body, nil, &matcher); err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		return matcher, nil
 | 
			
		||||
 | 
			
		||||
	case "status":
 | 
			
		||||
		var matcher = contentStatus{name: config.Name}
 | 
			
		||||
		if err := gohcl.DecodeBody(config.Body, nil, &matcher); err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		return matcher, nil
 | 
			
		||||
 | 
			
		||||
	default:
 | 
			
		||||
		return nil, fmt.Errorf("unknown content matcher type %q", config.Type)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Domain struct {
 | 
			
		||||
	Name string   `hcl:"name,label"`
 | 
			
		||||
	Type string   `hcl:"type"`
 | 
			
		||||
	Body hcl.Body `hcl:",remain"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (config Domain) Matcher() (Request, error) {
 | 
			
		||||
	switch config.Type {
 | 
			
		||||
	case "list":
 | 
			
		||||
		var matcher = domainList{Title: config.Name}
 | 
			
		||||
		if err := gohcl.DecodeBody(config.Body, nil, &matcher); err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		matcher.list = netutil.NewDomainList(matcher.List...)
 | 
			
		||||
		return matcher, nil
 | 
			
		||||
 | 
			
		||||
	case "adblock", "dnsmasq", "hosts", "detect", "domains":
 | 
			
		||||
		var matcher = DomainFile{
 | 
			
		||||
			Title: config.Name,
 | 
			
		||||
			Type:  config.Type,
 | 
			
		||||
		}
 | 
			
		||||
		if err := gohcl.DecodeBody(config.Body, nil, &matcher); err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		if matcher.Path == "" && matcher.From == "" {
 | 
			
		||||
			return nil, fmt.Errorf("matcher: domain %q must have either file or from configured", config.Name)
 | 
			
		||||
		}
 | 
			
		||||
		if err := matcher.Update(); err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		return matcher, nil
 | 
			
		||||
 | 
			
		||||
	default:
 | 
			
		||||
		return nil, fmt.Errorf("unknown domain matcher type %q", config.Type)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type domainList struct {
 | 
			
		||||
	Title string   `json:"title"`
 | 
			
		||||
	List  []string `hcl:"list" json:"list"`
 | 
			
		||||
	list  *netutil.DomainTree
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m domainList) Name() string {
 | 
			
		||||
	return m.Title
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m domainList) MatchesRequest(r *http.Request) bool {
 | 
			
		||||
	host := netutil.Host(r.URL.Host)
 | 
			
		||||
	log.Debug().Str("host", host).Msgf("match domain list (%d domains)", len(m.List))
 | 
			
		||||
	return m.list.Contains(host)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type DomainFile struct {
 | 
			
		||||
	Title   string        `json:"name"`
 | 
			
		||||
	Type    string        `json:"type"`
 | 
			
		||||
	Path    string        `hcl:"path,optional" json:"path,omitempty"`
 | 
			
		||||
	From    string        `hcl:"from,optional" json:"from,omitempty"`
 | 
			
		||||
	Refresh time.Duration `hcl:"refresh,optional" json:"refresh"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m DomainFile) Name() string {
 | 
			
		||||
	return m.Title
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m DomainFile) MatchesRequest(_ *http.Request) bool {
 | 
			
		||||
	return false
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m *DomainFile) Update() (err error) {
 | 
			
		||||
	var data []byte
 | 
			
		||||
	if m.Path != "" {
 | 
			
		||||
		if data, err = os.ReadFile(m.Path); err != nil {
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
	} else {
 | 
			
		||||
		/*
 | 
			
		||||
			var response *http.Response
 | 
			
		||||
			if response, err = http.DefaultClient.Get(m.From); err != nil {
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
			defer func() { _ = response.Body.Close() }()
 | 
			
		||||
			if response.StatusCode != http.StatusOK {
 | 
			
		||||
				return fmt.Errorf("match: domain %q update failed: %s", m.name, response.Status)
 | 
			
		||||
			}
 | 
			
		||||
			if data, err = io.ReadAll(response.Body); err != nil {
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
		*/
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	switch m.Type {
 | 
			
		||||
	case "hosts":
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	_ = data
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Network struct {
 | 
			
		||||
	Name string   `hcl:"name,label"`
 | 
			
		||||
	Type string   `hcl:"type"`
 | 
			
		||||
	Body hcl.Body `hcl:",remain"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (config *Network) Matcher(target bool) (Matcher, error) {
 | 
			
		||||
	switch config.Type {
 | 
			
		||||
	case "list":
 | 
			
		||||
		var (
 | 
			
		||||
			matcher = networkList{Title: config.Name}
 | 
			
		||||
			err     error
 | 
			
		||||
		)
 | 
			
		||||
		if diag := gohcl.DecodeBody(config.Body, nil, &matcher); diag.HasErrors() {
 | 
			
		||||
			return nil, diag
 | 
			
		||||
		}
 | 
			
		||||
		if matcher.tree, err = netutil.NewNetworkTree(matcher.List...); err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		return &matcher, nil
 | 
			
		||||
 | 
			
		||||
	default:
 | 
			
		||||
		return nil, fmt.Errorf("unknown network matcher type %q", config.Type)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type networkList struct {
 | 
			
		||||
	Title  string   `json:"name"`
 | 
			
		||||
	List   []string `hcl:"list" json:"list"`
 | 
			
		||||
	tree   *netutil.NetworkTree
 | 
			
		||||
	target bool
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m *networkList) Name() string {
 | 
			
		||||
	return m.Title
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m *networkList) MatchesIP(ip net.IP) bool {
 | 
			
		||||
	return m.tree.Contains(ip)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m *networkList) MatchesRequest(r *http.Request) bool {
 | 
			
		||||
	var (
 | 
			
		||||
		host string
 | 
			
		||||
		err  error
 | 
			
		||||
	)
 | 
			
		||||
	if m.target {
 | 
			
		||||
		host, _, err = net.SplitHostPort(r.URL.Host)
 | 
			
		||||
	} else {
 | 
			
		||||
		host, _, err = net.SplitHostPort(r.RemoteAddr)
 | 
			
		||||
	}
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
	ip := net.ParseIP(host)
 | 
			
		||||
	return m.MatchesIP(ip)
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										45
									
								
								proxy/match/match.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										45
									
								
								proxy/match/match.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,45 @@
 | 
			
		||||
package match
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"net"
 | 
			
		||||
	"net/http"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Matchers map[string]map[string]Matcher
 | 
			
		||||
 | 
			
		||||
func (all Matchers) Get(kind, name string) (m Matcher, err error) {
 | 
			
		||||
	if typeMatchers, ok := all[kind]; ok {
 | 
			
		||||
		if m, ok = typeMatchers[name]; ok {
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		return nil, fmt.Errorf("no %s matcher named %q found", kind, name)
 | 
			
		||||
	}
 | 
			
		||||
	return nil, fmt.Errorf("no %s matcher found", kind)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Matcher interface {
 | 
			
		||||
	Name() string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Updater interface {
 | 
			
		||||
	Update() error
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type IP interface {
 | 
			
		||||
	Matcher
 | 
			
		||||
 | 
			
		||||
	MatchesIP(net.IP) bool
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Request interface {
 | 
			
		||||
	Matcher
 | 
			
		||||
 | 
			
		||||
	MatchesRequest(*http.Request) bool
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Response interface {
 | 
			
		||||
	Matcher
 | 
			
		||||
 | 
			
		||||
	MatchesResponse(*http.Response) bool
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										11
									
								
								proxy/match/util.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										11
									
								
								proxy/match/util.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,11 @@
 | 
			
		||||
package match
 | 
			
		||||
 | 
			
		||||
import "net"
 | 
			
		||||
 | 
			
		||||
func onlyHost(name string) string {
 | 
			
		||||
	host, _, err := net.SplitHostPort(name)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return name
 | 
			
		||||
	}
 | 
			
		||||
	return host
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										231
									
								
								proxy/mitm/authority.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										231
									
								
								proxy/mitm/authority.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,231 @@
 | 
			
		||||
package mitm
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"crypto"
 | 
			
		||||
	"crypto/rand"
 | 
			
		||||
	"crypto/tls"
 | 
			
		||||
	"crypto/x509"
 | 
			
		||||
	"crypto/x509/pkix"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"math/big"
 | 
			
		||||
	"os"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"git.maze.io/maze/styx/internal/cryptutil"
 | 
			
		||||
	"git.maze.io/maze/styx/internal/log"
 | 
			
		||||
	"github.com/miekg/dns"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const DefaultValidity = 24 * time.Hour
 | 
			
		||||
 | 
			
		||||
type Authority interface {
 | 
			
		||||
	Certificate() *x509.Certificate
 | 
			
		||||
	TLSConfig(name string) *tls.Config
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type authority struct {
 | 
			
		||||
	pool    *x509.CertPool
 | 
			
		||||
	cert    *x509.Certificate
 | 
			
		||||
	key     crypto.PrivateKey
 | 
			
		||||
	keyID   []byte
 | 
			
		||||
	keyPool chan crypto.PrivateKey
 | 
			
		||||
	cache   Cache
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func New(config *Config) (Authority, error) {
 | 
			
		||||
	cache, err := NewCache(config.Cache)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	caConfig := config.CA
 | 
			
		||||
	if caConfig == nil {
 | 
			
		||||
		caConfig = new(CAConfig)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	cert, key, err := cryptutil.LoadKeyPair(caConfig.Cert, caConfig.Key)
 | 
			
		||||
	if os.IsNotExist(err) {
 | 
			
		||||
		days := caConfig.Days
 | 
			
		||||
		if days == 0 {
 | 
			
		||||
			days = DefaultDays
 | 
			
		||||
		}
 | 
			
		||||
		if cert, key, err = cryptutil.GenerateKeyPair(caConfig.DN(), days, caConfig.KeyType, caConfig.Bits); err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		if strings.ContainsRune(caConfig.Cert, os.PathSeparator) {
 | 
			
		||||
			if err = cryptutil.SaveKeyPair(cert, key, caConfig.Cert, caConfig.Key); err != nil {
 | 
			
		||||
				return nil, err
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	} else if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	pool := x509.NewCertPool()
 | 
			
		||||
	pool.AddCert(cert)
 | 
			
		||||
 | 
			
		||||
	keyConfig := config.Key
 | 
			
		||||
	if keyConfig == nil {
 | 
			
		||||
		keyConfig = &defaultKeyConfig
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	keyPoolSize := defaultKeyConfig.Pool
 | 
			
		||||
	if keyConfig.Pool > 0 {
 | 
			
		||||
		keyPoolSize = keyConfig.Pool
 | 
			
		||||
	}
 | 
			
		||||
	keyPool := make(chan crypto.PrivateKey, keyPoolSize)
 | 
			
		||||
	if key, err := cryptutil.GeneratePrivateKey(keyConfig.Type, keyConfig.Bits); err != nil {
 | 
			
		||||
		return nil, fmt.Errorf("mitm: invalid key configuration: %w", err)
 | 
			
		||||
	} else {
 | 
			
		||||
		keyPool <- key
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	go func(pool chan<- crypto.PrivateKey) {
 | 
			
		||||
		for {
 | 
			
		||||
			key, err := cryptutil.GeneratePrivateKey(keyConfig.Type, keyConfig.Bits)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				log.Panic().Err(err).Msg("error generating private key")
 | 
			
		||||
			}
 | 
			
		||||
			pool <- key
 | 
			
		||||
		}
 | 
			
		||||
	}(keyPool)
 | 
			
		||||
 | 
			
		||||
	return &authority{
 | 
			
		||||
		pool:    pool,
 | 
			
		||||
		cert:    cert,
 | 
			
		||||
		key:     key,
 | 
			
		||||
		keyID:   cryptutil.GenerateKeyID(cryptutil.PublicKey(key)),
 | 
			
		||||
		keyPool: keyPool,
 | 
			
		||||
		cache:   cache,
 | 
			
		||||
	}, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ca *authority) log() log.Logger {
 | 
			
		||||
	return log.Console.With().
 | 
			
		||||
		Str("ca", ca.cert.Subject.String()).
 | 
			
		||||
		Logger()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ca *authority) Certificate() *x509.Certificate {
 | 
			
		||||
	return ca.cert
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ca *authority) TLSConfig(name string) *tls.Config {
 | 
			
		||||
	return &tls.Config{
 | 
			
		||||
		GetCertificate: func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
 | 
			
		||||
			log := ca.log()
 | 
			
		||||
			if hello.ServerName != "" {
 | 
			
		||||
				name = strings.ToLower(hello.ServerName)
 | 
			
		||||
				log.Debug().Msg("requesting certificate for server name (SNI)")
 | 
			
		||||
			} else {
 | 
			
		||||
				log.Debug().Msg("requesting certificate for hostname")
 | 
			
		||||
			}
 | 
			
		||||
			if cert, ok := ca.getCached(name); ok {
 | 
			
		||||
				log.Debug().
 | 
			
		||||
					Str("subject", cert.Leaf.Subject.String()).
 | 
			
		||||
					Str("serial", cert.Leaf.SerialNumber.String()).
 | 
			
		||||
					Time("valid", cert.Leaf.NotAfter).
 | 
			
		||||
					Msg("using cached certificate")
 | 
			
		||||
				return cert, nil
 | 
			
		||||
			}
 | 
			
		||||
			return ca.issueFor(name)
 | 
			
		||||
		},
 | 
			
		||||
		NextProtos: []string{"http/1.1"},
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ca *authority) getCached(name string) (cert *tls.Certificate, ok bool) {
 | 
			
		||||
	log := ca.log()
 | 
			
		||||
 | 
			
		||||
	if cert = ca.cache.Certificate(name); cert == nil {
 | 
			
		||||
		if baseDomain(name) != name {
 | 
			
		||||
			cert = ca.cache.Certificate(baseDomain(name))
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	if cert != nil {
 | 
			
		||||
		if _, err := cert.Leaf.Verify(x509.VerifyOptions{
 | 
			
		||||
			DNSName: name,
 | 
			
		||||
			Roots:   ca.pool,
 | 
			
		||||
		}); err != nil {
 | 
			
		||||
			log.Debug().Err(err).Str("name", name).Msg("deleting invalid certificate from cache")
 | 
			
		||||
		} else {
 | 
			
		||||
			ok = true
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ca *authority) issueFor(name string) (*tls.Certificate, error) {
 | 
			
		||||
	var (
 | 
			
		||||
		log = ca.log().With().Str("name", name).Logger()
 | 
			
		||||
		key crypto.PrivateKey
 | 
			
		||||
	)
 | 
			
		||||
	select {
 | 
			
		||||
	case key = <-ca.keyPool:
 | 
			
		||||
	case <-time.After(5 * time.Second):
 | 
			
		||||
		return nil, errors.New("mitm: timeout waiting for private key generator to catch up")
 | 
			
		||||
	}
 | 
			
		||||
	if key == nil {
 | 
			
		||||
		panic("key pool returned nil key")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
 | 
			
		||||
	serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, fmt.Errorf("mtim: failed to generate serial number: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if part := dns.SplitDomainName(name); len(part) > 2 {
 | 
			
		||||
		name = strings.Join(part[1:], ".")
 | 
			
		||||
		log.Debug().Msgf("abbreviated name to %s (*.%s)", name, name)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	now := time.Now()
 | 
			
		||||
	template := &x509.Certificate{
 | 
			
		||||
		SerialNumber:          serialNumber,
 | 
			
		||||
		Subject:               pkix.Name{CommonName: name},
 | 
			
		||||
		KeyUsage:              x509.KeyUsageDigitalSignature | x509.KeyUsageKeyAgreement,
 | 
			
		||||
		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
 | 
			
		||||
		DNSNames:              []string{name, "*." + name},
 | 
			
		||||
		BasicConstraintsValid: true,
 | 
			
		||||
		NotBefore:             now.Add(-DefaultValidity),
 | 
			
		||||
		NotAfter:              now.Add(+DefaultValidity),
 | 
			
		||||
	}
 | 
			
		||||
	der, err := x509.CreateCertificate(rand.Reader, template, ca.cert, cryptutil.PublicKey(key), ca.key)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	cert, err := x509.ParseCertificate(der)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	log.Debug().Str("serial", serialNumber.String()).Msg("generated certificate")
 | 
			
		||||
	out := &tls.Certificate{
 | 
			
		||||
		Certificate: [][]byte{der},
 | 
			
		||||
		Leaf:        cert,
 | 
			
		||||
		PrivateKey:  key,
 | 
			
		||||
	}
 | 
			
		||||
	//ca.cache[name] = out
 | 
			
		||||
	ca.cache.SaveCertificate(name, out)
 | 
			
		||||
	return out, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func containsValidCertificate(cert *tls.Certificate) bool {
 | 
			
		||||
	if cert == nil || len(cert.Certificate) == 0 {
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if cert.Leaf == nil {
 | 
			
		||||
		var err error
 | 
			
		||||
		if cert.Leaf, err = x509.ParseCertificate(cert.Certificate[0]); err != nil {
 | 
			
		||||
			return false
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	now := time.Now()
 | 
			
		||||
 | 
			
		||||
	return !(cert.Leaf.NotBefore.Before(now) || cert.Leaf.NotAfter.After(now))
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										233
									
								
								proxy/mitm/cache.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										233
									
								
								proxy/mitm/cache.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,233 @@
 | 
			
		||||
package mitm
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"crypto/tls"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"io/fs"
 | 
			
		||||
	"os"
 | 
			
		||||
	"path/filepath"
 | 
			
		||||
	"slices"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/hashicorp/golang-lru/v2/expirable"
 | 
			
		||||
	"github.com/hashicorp/hcl/v2/gohcl"
 | 
			
		||||
	"github.com/miekg/dns"
 | 
			
		||||
 | 
			
		||||
	"git.maze.io/maze/styx/internal/cryptutil"
 | 
			
		||||
	"git.maze.io/maze/styx/internal/log"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Cache interface {
 | 
			
		||||
	Certificate(name string) *tls.Certificate
 | 
			
		||||
	SaveCertificate(name string, cert *tls.Certificate) error
 | 
			
		||||
	RemoveCertificate(name string)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewCache(config *CacheConfig) (Cache, error) {
 | 
			
		||||
	if config == nil {
 | 
			
		||||
		return NewCache(&CacheConfig{Type: "memory"})
 | 
			
		||||
	}
 | 
			
		||||
	switch config.Type {
 | 
			
		||||
	case "memory":
 | 
			
		||||
		var cacheConfig = new(MemoryCacheConfig)
 | 
			
		||||
		if err := gohcl.DecodeBody(config.Body, nil, cacheConfig); err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		return NewMemoryCache(cacheConfig.Size), nil
 | 
			
		||||
	case "disk":
 | 
			
		||||
		var cacheConfig = new(DiskCacheConfig)
 | 
			
		||||
		if err := gohcl.DecodeBody(config.Body, nil, cacheConfig); err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		return NewDiskCache(cacheConfig.Path, time.Duration(cacheConfig.Expire*float64(time.Second)))
 | 
			
		||||
	default:
 | 
			
		||||
		return nil, fmt.Errorf("mitm: cache type %q is not supported", config.Type)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type memoryCache struct {
 | 
			
		||||
	cache *expirable.LRU[string, *tls.Certificate]
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewMemoryCache(size int) Cache {
 | 
			
		||||
	return memoryCache{
 | 
			
		||||
		cache: expirable.NewLRU(size, func(key string, value *tls.Certificate) {
 | 
			
		||||
			log.Debug().Str("name", key).Msg("certificate evicted from cache")
 | 
			
		||||
		}, time.Hour*24),
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c memoryCache) Certificate(name string) (cert *tls.Certificate) {
 | 
			
		||||
	var ok bool
 | 
			
		||||
	if cert, ok = c.cache.Get(name); !ok {
 | 
			
		||||
		cert, _ = c.cache.Get(baseDomain(name))
 | 
			
		||||
	}
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c memoryCache) SaveCertificate(name string, cert *tls.Certificate) error {
 | 
			
		||||
	c.cache.Add(name, cert)
 | 
			
		||||
	log.Debug().Str("name", name).Msg("certificate added to cache")
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c memoryCache) RemoveCertificate(name string) {
 | 
			
		||||
	c.cache.Remove(name)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type diskCache string
 | 
			
		||||
 | 
			
		||||
func NewDiskCache(dir string, expire time.Duration) (Cache, error) {
 | 
			
		||||
	if !filepath.IsAbs(dir) {
 | 
			
		||||
		var err error
 | 
			
		||||
		if dir, err = filepath.Abs(dir); err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	if err := os.MkdirAll(dir, 0o750); err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	info, err := os.Stat(dir)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	if info.Mode()&os.ModePerm|0o057 != 0 {
 | 
			
		||||
		if err := os.Chmod(dir, 0o750); err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if expire > 0 {
 | 
			
		||||
		go expireDiskCache(dir, expire)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return diskCache(dir), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func expireDiskCache(root string, expire time.Duration) {
 | 
			
		||||
	log.Debug().Str("path", root).Dur("expire", expire).Msg("disk cache expire loop starting")
 | 
			
		||||
	ticker := time.NewTicker(expire)
 | 
			
		||||
	defer ticker.Stop()
 | 
			
		||||
	for {
 | 
			
		||||
		now := <-ticker.C
 | 
			
		||||
		log.Debug().Str("path", root).Dur("expire", expire).Msg("expire disk cache")
 | 
			
		||||
		filepath.WalkDir(root, func(path string, d fs.DirEntry, err error) error {
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return err
 | 
			
		||||
			}
 | 
			
		||||
			if d.IsDir() {
 | 
			
		||||
				// Remove the directory; this will fail if it's not empty, which is fine.
 | 
			
		||||
				_ = os.Remove(path)
 | 
			
		||||
				return nil
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			cert, err := cryptutil.LoadCertificate(path)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				log.Debug().Str("path", path).Err(err).Msg("expire removing invalid certificate file")
 | 
			
		||||
				_ = os.Remove(path)
 | 
			
		||||
				return nil
 | 
			
		||||
			} else if cert.NotAfter.Before(now) {
 | 
			
		||||
				log.Debug().Str("path", path).Dur("expired", now.Sub(cert.NotAfter)).Msg("expire removing expired certificate")
 | 
			
		||||
				_ = os.Remove(path)
 | 
			
		||||
				return nil
 | 
			
		||||
			}
 | 
			
		||||
			return nil
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c diskCache) path(name string) string {
 | 
			
		||||
	part := dns.SplitDomainName(strings.ToLower(name))
 | 
			
		||||
	// x,com -> com,x
 | 
			
		||||
	// www,maze,io -> io,maze,www
 | 
			
		||||
	slices.Reverse(part)
 | 
			
		||||
	// com,x -> com,x,x.com
 | 
			
		||||
	// io,maze,www -> io,m,ma,maze,www.maze.io
 | 
			
		||||
	if len(part) > 2 {
 | 
			
		||||
		if len(part[1]) > 1 {
 | 
			
		||||
			part = []string{
 | 
			
		||||
				part[0],
 | 
			
		||||
				part[1][:1],
 | 
			
		||||
				part[1][:2],
 | 
			
		||||
				part[1],
 | 
			
		||||
				name,
 | 
			
		||||
			}
 | 
			
		||||
		} else {
 | 
			
		||||
			part = []string{
 | 
			
		||||
				part[0],
 | 
			
		||||
				part[1][:1],
 | 
			
		||||
				part[1],
 | 
			
		||||
				name,
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	} else if len(part) > 1 {
 | 
			
		||||
		if len(part[1]) > 1 {
 | 
			
		||||
			part = []string{
 | 
			
		||||
				part[0],
 | 
			
		||||
				part[1][:1],
 | 
			
		||||
				part[1][:2],
 | 
			
		||||
				name,
 | 
			
		||||
			}
 | 
			
		||||
		} else {
 | 
			
		||||
			part = []string{
 | 
			
		||||
				part[0],
 | 
			
		||||
				part[1][:1],
 | 
			
		||||
				name,
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	part[len(part)-1] += ".crt"
 | 
			
		||||
	return filepath.Join(append([]string{string(c)}, part...)...)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c diskCache) Certificate(name string) (cert *tls.Certificate) {
 | 
			
		||||
	if cert, key, err := cryptutil.LoadKeyPair(c.path(name), ""); err == nil {
 | 
			
		||||
		return &tls.Certificate{
 | 
			
		||||
			Certificate: [][]byte{cert.Raw},
 | 
			
		||||
			Leaf:        cert,
 | 
			
		||||
			PrivateKey:  key,
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	if cert, key, err := cryptutil.LoadKeyPair(c.path(baseDomain(name)), ""); err == nil {
 | 
			
		||||
		return &tls.Certificate{
 | 
			
		||||
			Certificate: [][]byte{cert.Raw},
 | 
			
		||||
			Leaf:        cert,
 | 
			
		||||
			PrivateKey:  key,
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	log.Debug().Str("path", string(c)).Str("name", name).Msg("cache miss")
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c diskCache) SaveCertificate(name string, cert *tls.Certificate) error {
 | 
			
		||||
	dir, name := filepath.Split(c.path(name))
 | 
			
		||||
	if err := os.MkdirAll(dir, 0o750); err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	if err := cryptutil.SaveKeyPair(cert.Leaf, cert.PrivateKey, filepath.Join(dir, name), ""); err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	log.Debug().Str("name", name).Msg("certificate added to cache")
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c diskCache) RemoveCertificate(name string) {
 | 
			
		||||
	path := c.path(name)
 | 
			
		||||
	if err := os.Remove(path); err != nil {
 | 
			
		||||
		if os.IsNotExist(err) {
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		log.Error().Err(err).Msg("certificate remove from cache failed")
 | 
			
		||||
	}
 | 
			
		||||
	_ = os.Remove(filepath.Dir(path))
 | 
			
		||||
	log.Debug().Str("name", name).Msg("certificate removed from cache")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func baseDomain(name string) string {
 | 
			
		||||
	name = strings.ToLower(name)
 | 
			
		||||
	if part := dns.SplitDomainName(name); len(part) > 2 {
 | 
			
		||||
		return strings.Join(part[1:], ".")
 | 
			
		||||
	}
 | 
			
		||||
	return name
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										25
									
								
								proxy/mitm/cache_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										25
									
								
								proxy/mitm/cache_test.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,25 @@
 | 
			
		||||
package mitm
 | 
			
		||||
 | 
			
		||||
import "testing"
 | 
			
		||||
 | 
			
		||||
func TestDiskCachePath(t *testing.T) {
 | 
			
		||||
	cache := diskCache("testdata")
 | 
			
		||||
	tests := []struct {
 | 
			
		||||
		test string
 | 
			
		||||
		want string
 | 
			
		||||
	}{
 | 
			
		||||
		{"x.com", "testdata/com/x/x.com.crt"},
 | 
			
		||||
		{"feed.x.com", "testdata/com/x/x/feed.x.com.crt"},
 | 
			
		||||
		{"nu.nl", "testdata/nl/n/nu/nu.nl.crt"},
 | 
			
		||||
		{"maze.io", "testdata/io/m/ma/maze.io.crt"},
 | 
			
		||||
		{"lab.maze.io", "testdata/io/m/ma/maze/lab.maze.io.crt"},
 | 
			
		||||
		{"dev.lab.maze.io", "testdata/io/m/ma/maze/dev.lab.maze.io.crt"},
 | 
			
		||||
	}
 | 
			
		||||
	for _, test := range tests {
 | 
			
		||||
		t.Run(test.test, func(it *testing.T) {
 | 
			
		||||
			if v := cache.path(test.test); v != test.want {
 | 
			
		||||
				it.Errorf("expected %q to resolve to %q, got %q", test.test, test.want, v)
 | 
			
		||||
			}
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										89
									
								
								proxy/mitm/config.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										89
									
								
								proxy/mitm/config.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,89 @@
 | 
			
		||||
package mitm
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"crypto/x509/pkix"
 | 
			
		||||
 | 
			
		||||
	"github.com/hashicorp/hcl/v2"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	DefaultCommonName = "Styx Certificate Authority"
 | 
			
		||||
	DefaultDays       = 3
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Config struct {
 | 
			
		||||
	CA    *CAConfig    `hcl:"ca,block"`
 | 
			
		||||
	Key   *KeyConfig   `hcl:"key,block"`
 | 
			
		||||
	Cache *CacheConfig `hcl:"cache,block"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type CAConfig struct {
 | 
			
		||||
	Cert         string   `hcl:"cert"`
 | 
			
		||||
	Key          string   `hcl:"key,optional"`
 | 
			
		||||
	Days         int      `hcl:"days,optional"`
 | 
			
		||||
	KeyType      string   `hcl:"key_type,optional"`
 | 
			
		||||
	Bits         int      `hcl:"bits,optional"`
 | 
			
		||||
	Name         string   `hcl:"name,optional"`
 | 
			
		||||
	Country      string   `hcl:"country,optional"`
 | 
			
		||||
	Organization string   `hcl:"organization,optional"`
 | 
			
		||||
	Unit         string   `hcl:"unit,optional"`
 | 
			
		||||
	Locality     string   `hcl:"locality,optional"`
 | 
			
		||||
	Province     string   `hcl:"province,optional"`
 | 
			
		||||
	Address      []string `hcl:"address,optional"`
 | 
			
		||||
	PostalCode   string   `hcl:"postal_code,optional"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (config CAConfig) DN() pkix.Name {
 | 
			
		||||
	var name = pkix.Name{
 | 
			
		||||
		CommonName:    config.Name,
 | 
			
		||||
		StreetAddress: config.Address,
 | 
			
		||||
	}
 | 
			
		||||
	if config.Name == "" {
 | 
			
		||||
		name.CommonName = DefaultCommonName
 | 
			
		||||
	}
 | 
			
		||||
	if config.Country != "" {
 | 
			
		||||
		name.Country = append(name.Country, config.Country)
 | 
			
		||||
	}
 | 
			
		||||
	if config.Organization != "" {
 | 
			
		||||
		name.Organization = append(name.Organization, config.Organization)
 | 
			
		||||
	}
 | 
			
		||||
	if config.Unit != "" {
 | 
			
		||||
		name.OrganizationalUnit = append(name.OrganizationalUnit, config.Unit)
 | 
			
		||||
	}
 | 
			
		||||
	if config.Locality != "" {
 | 
			
		||||
		name.Locality = append(name.Locality, config.Locality)
 | 
			
		||||
	}
 | 
			
		||||
	if config.Province != "" {
 | 
			
		||||
		name.Province = append(name.Province, config.Province)
 | 
			
		||||
	}
 | 
			
		||||
	if config.PostalCode != "" {
 | 
			
		||||
		name.PostalCode = append(name.PostalCode, config.PostalCode)
 | 
			
		||||
	}
 | 
			
		||||
	return name
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type KeyConfig struct {
 | 
			
		||||
	Type string `hcl:"type,optional"`
 | 
			
		||||
	Bits int    `hcl:"bits,optional"`
 | 
			
		||||
	Pool int    `hcl:"pool,optional"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var defaultKeyConfig = KeyConfig{
 | 
			
		||||
	Type: "rsa",
 | 
			
		||||
	Bits: 2048,
 | 
			
		||||
	Pool: 5,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type CacheConfig struct {
 | 
			
		||||
	Type string   `hcl:"type"`
 | 
			
		||||
	Body hcl.Body `hcl:",remain"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type MemoryCacheConfig struct {
 | 
			
		||||
	Size int `hcl:"size,optional"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type DiskCacheConfig struct {
 | 
			
		||||
	Path   string  `hcl:"path"`
 | 
			
		||||
	Expire float64 `hcl:"expire,optional"`
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										53
									
								
								proxy/policy/policy.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										53
									
								
								proxy/policy/policy.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,53 @@
 | 
			
		||||
package policy
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"net/http"
 | 
			
		||||
 | 
			
		||||
	"git.maze.io/maze/styx/proxy/match"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// Policy contains rules that make up the policy.
 | 
			
		||||
//
 | 
			
		||||
// Some policy rules contain nested policies.
 | 
			
		||||
type Policy struct {
 | 
			
		||||
	Rules    []*rawRule     `hcl:"on,block" json:"rules"`
 | 
			
		||||
	Permit   *bool          `hcl:"permit" json:"permit"`
 | 
			
		||||
	Matchers match.Matchers `json:"matchers"` // Matchers for the policy
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (p *Policy) Configure(matchers match.Matchers) (err error) {
 | 
			
		||||
	for _, r := range p.Rules {
 | 
			
		||||
		if err = r.Configure(matchers); err != nil {
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	p.Matchers = matchers
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (p *Policy) PermitIntercept(r *http.Request) *bool {
 | 
			
		||||
	if p != nil {
 | 
			
		||||
		for _, rule := range p.Rules {
 | 
			
		||||
			if rule, ok := rule.Rule.(InterceptRule); ok {
 | 
			
		||||
				if permit := rule.PermitIntercept(r); permit != nil {
 | 
			
		||||
					return permit
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return p.Permit
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (p *Policy) PermitRequest(r *http.Request) *bool {
 | 
			
		||||
	if p != nil {
 | 
			
		||||
		for _, rule := range p.Rules {
 | 
			
		||||
			if rule, ok := rule.Rule.(RequestRule); ok {
 | 
			
		||||
				if permit := rule.PermitRequest(r); permit != nil {
 | 
			
		||||
					return permit
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return p.Permit
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										139
									
								
								proxy/policy/policy_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										139
									
								
								proxy/policy/policy_test.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,139 @@
 | 
			
		||||
package policy
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"net"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"net/url"
 | 
			
		||||
	"testing"
 | 
			
		||||
 | 
			
		||||
	"git.maze.io/maze/styx/internal/netutil"
 | 
			
		||||
	"git.maze.io/maze/styx/proxy/match"
 | 
			
		||||
	"github.com/miekg/dns"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type testInDomainList struct {
 | 
			
		||||
	t    *testing.T
 | 
			
		||||
	list []string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (testInDomainList) Name() string { return "testInDomainList" }
 | 
			
		||||
func (l testInDomainList) MatchesRequest(r *http.Request) bool {
 | 
			
		||||
	for _, domain := range l.list {
 | 
			
		||||
		if dns.IsSubDomain(domain, netutil.Host(r.URL.Host)) {
 | 
			
		||||
			l.t.Logf("domain %s contains %s", domain, r.URL.Host)
 | 
			
		||||
			return true
 | 
			
		||||
		}
 | 
			
		||||
		l.t.Logf("domain %s does not contain %s", domain, r.URL.Host)
 | 
			
		||||
	}
 | 
			
		||||
	return false
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func testInDomain(t *testing.T, domains ...string) match.Matcher {
 | 
			
		||||
	return &testInDomainList{t: t, list: domains}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type testInNetworkList struct {
 | 
			
		||||
	t    *testing.T
 | 
			
		||||
	list []*net.IPNet
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (testInNetworkList) Name() string { return "testInNetworkList" }
 | 
			
		||||
func (l testInNetworkList) MatchesIP(ip net.IP) bool {
 | 
			
		||||
	for _, ipnet := range l.list {
 | 
			
		||||
		if ipnet.Contains(ip) {
 | 
			
		||||
			l.t.Logf("network %s contains %s", ipnet, ip)
 | 
			
		||||
			return true
 | 
			
		||||
		}
 | 
			
		||||
		l.t.Logf("network %s does not contain %s", ipnet, ip)
 | 
			
		||||
	}
 | 
			
		||||
	return false
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func testInNetwork(t *testing.T, cidr string) match.Matcher {
 | 
			
		||||
	t.Helper()
 | 
			
		||||
	_, ipnet, err := net.ParseCIDR(cidr)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		panic(err)
 | 
			
		||||
	}
 | 
			
		||||
	return testInNetworkList{t: t, list: []*net.IPNet{ipnet}}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestPolicy(t *testing.T) {
 | 
			
		||||
	var (
 | 
			
		||||
		yes  = true
 | 
			
		||||
		nope = false
 | 
			
		||||
	)
 | 
			
		||||
	p := &Policy{
 | 
			
		||||
		Rules: []*rawRule{
 | 
			
		||||
			{
 | 
			
		||||
				Rule: &requestRule{
 | 
			
		||||
					domainOrNetworkRule: domainOrNetworkRule{
 | 
			
		||||
						matchers: []match.Matcher{testInNetwork(t, "127.0.0.0/8")},
 | 
			
		||||
						isSource: []bool{true},
 | 
			
		||||
					},
 | 
			
		||||
				},
 | 
			
		||||
			},
 | 
			
		||||
			{
 | 
			
		||||
				Rule: &requestRule{
 | 
			
		||||
					domainOrNetworkRule: domainOrNetworkRule{
 | 
			
		||||
						matchers: []match.Matcher{testInNetwork(t, "127.0.0.0/8")},
 | 
			
		||||
						isSource: []bool{false},
 | 
			
		||||
					},
 | 
			
		||||
					Permit: &yes,
 | 
			
		||||
				},
 | 
			
		||||
			},
 | 
			
		||||
			{
 | 
			
		||||
				Rule: &requestRule{
 | 
			
		||||
					domainOrNetworkRule: domainOrNetworkRule{
 | 
			
		||||
						matchers: []match.Matcher{testInDomain(t, "maze.io", "maze.engineering")},
 | 
			
		||||
					},
 | 
			
		||||
					Permit: &yes,
 | 
			
		||||
				},
 | 
			
		||||
			},
 | 
			
		||||
			{
 | 
			
		||||
				Rule: &requestRule{
 | 
			
		||||
					domainOrNetworkRule: domainOrNetworkRule{
 | 
			
		||||
						matchers: []match.Matcher{testInDomain(t, "google.com")},
 | 
			
		||||
					},
 | 
			
		||||
					Permit: &nope,
 | 
			
		||||
				},
 | 
			
		||||
			},
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	r := &http.Request{
 | 
			
		||||
		URL:        &url.URL{Scheme: "http", Host: "golang.org:80"},
 | 
			
		||||
		RemoteAddr: "127.0.0.1:1234",
 | 
			
		||||
	}
 | 
			
		||||
	if v := p.PermitRequest(r); v != nil {
 | 
			
		||||
		t.Errorf("expected request to return no verdict, got %t", *v)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	p.Rules[0].Rule.(*requestRule).Permit = &yes
 | 
			
		||||
	if v := p.PermitRequest(r); v == nil || *v != yes {
 | 
			
		||||
		t.Errorf("expected request to return %t, %v", yes, v)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	r.RemoteAddr = "192.168.1.2:3456"
 | 
			
		||||
	if v := p.PermitRequest(r); v != nil {
 | 
			
		||||
		t.Errorf("expected request to return no verdict, got %t", *v)
 | 
			
		||||
	}
 | 
			
		||||
	if v := p.PermitIntercept(r); v != nil {
 | 
			
		||||
		t.Errorf("expected request to return no verdict, got %t", *v)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	r.URL.Host = "maze.io"
 | 
			
		||||
	if v := p.PermitRequest(r); v == nil || *v != yes {
 | 
			
		||||
		t.Errorf("expected request to return %t, %v", yes, v)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	r.URL.Host = "google.com"
 | 
			
		||||
	if v := p.PermitRequest(r); v == nil || *v != nope {
 | 
			
		||||
		t.Errorf("expected request to return %t, %v", nope, v)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	r.URL.Host = "localhost:80"
 | 
			
		||||
	if v := p.PermitRequest(r); v == nil || *v != yes {
 | 
			
		||||
		t.Errorf("expected request to return %t, %v", yes, v)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										368
									
								
								proxy/policy/rule.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										368
									
								
								proxy/policy/rule.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,368 @@
 | 
			
		||||
package policy
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"net"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"git.maze.io/maze/styx/internal/netutil"
 | 
			
		||||
	"git.maze.io/maze/styx/proxy/match"
 | 
			
		||||
	"github.com/google/uuid"
 | 
			
		||||
	"github.com/hashicorp/hcl/v2"
 | 
			
		||||
	"github.com/hashicorp/hcl/v2/gohcl"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// Rule is a policy rule.
 | 
			
		||||
type Rule interface {
 | 
			
		||||
	Configure(match.Matchers) error
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// InterceptRule can make policy rule decisions on intercept requests.
 | 
			
		||||
type InterceptRule interface {
 | 
			
		||||
	PermitIntercept(r *http.Request) *bool
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// RequestRule can make policy rule decisions on HTTP CONNECT requests.
 | 
			
		||||
type RequestRule interface {
 | 
			
		||||
	PermitRequest(r *http.Request) *bool
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type rawRule struct {
 | 
			
		||||
	Type string   `hcl:"type,label" json:"type"`
 | 
			
		||||
	Body hcl.Body `hcl:",remain" json:"-"`
 | 
			
		||||
	Rule `json:"rule"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *rawRule) Configure(matchers match.Matchers) (err error) {
 | 
			
		||||
	switch r.Type {
 | 
			
		||||
	case "intercept":
 | 
			
		||||
		r.Rule = new(interceptRule)
 | 
			
		||||
	case "request":
 | 
			
		||||
		r.Rule = new(requestRule)
 | 
			
		||||
	case "days":
 | 
			
		||||
		r.Rule = new(daysRule)
 | 
			
		||||
	case "time":
 | 
			
		||||
		r.Rule = new(timeRule)
 | 
			
		||||
	case "all":
 | 
			
		||||
		r.Rule = new(allRule)
 | 
			
		||||
	default:
 | 
			
		||||
		return fmt.Errorf("policy: invalid event type %q", r.Type)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if diag := gohcl.DecodeBody(r.Body, nil, r.Rule); diag.HasErrors() {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return r.Rule.Configure(matchers)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type allRule struct {
 | 
			
		||||
	Rules  []*rawRule `hcl:"on,block"`
 | 
			
		||||
	Permit *bool      `hcl:"permit"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *allRule) Configure(matchers match.Matchers) (err error) {
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type domainOrNetworkRule struct {
 | 
			
		||||
	matchers []match.Matcher
 | 
			
		||||
	isSource []bool
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *domainOrNetworkRule) configure(kind string, matchers match.Matchers, domains, sources, targets []string, v any, id *string) (err error) {
 | 
			
		||||
	var m match.Matcher
 | 
			
		||||
	for _, domain := range domains {
 | 
			
		||||
		if m, err = matchers.Get("domain", domain); err != nil {
 | 
			
		||||
			return fmt.Errorf("%s: unknown domain %q", kind, domain)
 | 
			
		||||
		}
 | 
			
		||||
		r.matchers = append(r.matchers, m)
 | 
			
		||||
		r.isSource = append(r.isSource, false)
 | 
			
		||||
	}
 | 
			
		||||
	for _, network := range sources {
 | 
			
		||||
		if m, err = matchers.Get("network", network); err != nil {
 | 
			
		||||
			return fmt.Errorf("%s: unknown source network %q", kind, network)
 | 
			
		||||
		}
 | 
			
		||||
		r.matchers = append(r.matchers, m)
 | 
			
		||||
		r.isSource = append(r.isSource, true)
 | 
			
		||||
	}
 | 
			
		||||
	for _, network := range targets {
 | 
			
		||||
		if m, err = matchers.Get("network", network); err != nil {
 | 
			
		||||
			return fmt.Errorf("%s: unknown target network %q", kind, network)
 | 
			
		||||
		}
 | 
			
		||||
		r.matchers = append(r.matchers, m)
 | 
			
		||||
		r.isSource = append(r.isSource, false)
 | 
			
		||||
	}
 | 
			
		||||
	if len(r.matchers) == 0 {
 | 
			
		||||
		return fmt.Errorf("%s: missing any of domain, source, target", kind)
 | 
			
		||||
	}
 | 
			
		||||
	if id != nil {
 | 
			
		||||
		*id = uuid.NewString()
 | 
			
		||||
	}
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *domainOrNetworkRule) matchesRequest(q *http.Request) bool {
 | 
			
		||||
	for i, m := range r.matchers {
 | 
			
		||||
		if m, ok := m.(match.Request); ok {
 | 
			
		||||
			if m.MatchesRequest(q) {
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		if m, ok := m.(match.IP); ok {
 | 
			
		||||
			if r.isSource[i] {
 | 
			
		||||
				if m.MatchesIP(net.ParseIP(netutil.Host(q.RemoteAddr))) {
 | 
			
		||||
					return true
 | 
			
		||||
				}
 | 
			
		||||
			} else {
 | 
			
		||||
				var (
 | 
			
		||||
					host = netutil.Host(q.URL.Host)
 | 
			
		||||
					ips  []net.IP
 | 
			
		||||
				)
 | 
			
		||||
				if ip := net.ParseIP(host); ip != nil {
 | 
			
		||||
					ips = append(ips, ip)
 | 
			
		||||
				} else {
 | 
			
		||||
					ips, _ = net.LookupIP(host)
 | 
			
		||||
				}
 | 
			
		||||
				for _, ip := range ips {
 | 
			
		||||
					if m.MatchesIP(ip) {
 | 
			
		||||
						return true
 | 
			
		||||
					}
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return false
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type interceptRule struct {
 | 
			
		||||
	ID                  string   `json:"id,omitempty"`
 | 
			
		||||
	Domain              []string `hcl:"domain,optional" json:"domain,omitempty"`
 | 
			
		||||
	Source              []string `hcl:"source,optional" json:"source,omitempty"`
 | 
			
		||||
	Target              []string `hcl:"target,optional" json:"target,omitempty"`
 | 
			
		||||
	Permit              *bool    `hcl:"permit" json:"permit"`
 | 
			
		||||
	domainOrNetworkRule `json:"-"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *interceptRule) Configure(matchers match.Matchers) (err error) {
 | 
			
		||||
	return r.configure("intercept", matchers, r.Domain, r.Source, r.Target, r, &r.ID)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *interceptRule) PermitIntercept(q *http.Request) *bool {
 | 
			
		||||
	if r.matchesRequest(q) {
 | 
			
		||||
		return r.Permit
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type requestRule struct {
 | 
			
		||||
	ID                  string   `json:"id,omitempty"`
 | 
			
		||||
	Domain              []string `hcl:"domain,optional" json:"domain,omitempty"`
 | 
			
		||||
	Source              []string `hcl:"source,optional" json:"source,omitempty"`
 | 
			
		||||
	Target              []string `hcl:"target,optional" json:"target,omitempty"`
 | 
			
		||||
	Permit              *bool    `hcl:"permit" json:"permit"`
 | 
			
		||||
	domainOrNetworkRule `json:"-"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *requestRule) Configure(matchers match.Matchers) (err error) {
 | 
			
		||||
	return r.configure("request", matchers, r.Domain, r.Source, r.Target, r, &r.ID)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *requestRule) PermitRequest(q *http.Request) *bool {
 | 
			
		||||
	if r.matchesRequest(q) {
 | 
			
		||||
		return r.Permit
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type timeRule struct {
 | 
			
		||||
	ID     string   `json:"id,omitempty"`
 | 
			
		||||
	Time   []string `hcl:"time" json:"time"`
 | 
			
		||||
	Permit *bool    `hcl:"permit" json:"permit"`
 | 
			
		||||
	Body   hcl.Body `hcl:",remain" json:"-"`
 | 
			
		||||
	Rules  *Policy  `json:"rules"`
 | 
			
		||||
	Start  Time     `json:"start"`
 | 
			
		||||
	End    Time     `json:"end"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *timeRule) isActive() bool {
 | 
			
		||||
	if r == nil {
 | 
			
		||||
		return true
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	now := Now()
 | 
			
		||||
	if r.Start.After(r.End) { // ie: 18:00-06:00
 | 
			
		||||
		return now.After(r.Start) || now.Before(r.End)
 | 
			
		||||
	}
 | 
			
		||||
	return now.After(r.Start) && now.Before(r.End)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *timeRule) Configure(matchers match.Matchers) (err error) {
 | 
			
		||||
	if len(r.Time) != 2 {
 | 
			
		||||
		return fmt.Errorf("invalid time %s, need [start, stop]", r.Time)
 | 
			
		||||
	}
 | 
			
		||||
	if r.Start, err = ParseTime(r.Time[0]); err != nil {
 | 
			
		||||
		return fmt.Errorf("invalid start %q: %w", r.Time[0], err)
 | 
			
		||||
	}
 | 
			
		||||
	if r.End, err = ParseTime(r.Time[1]); err != nil {
 | 
			
		||||
		return fmt.Errorf("invalid end %q: %w", r.Time[1], err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	r.Rules = new(Policy)
 | 
			
		||||
	if diag := gohcl.DecodeBody(r.Body, nil, r.Rules); diag.HasErrors() {
 | 
			
		||||
		return diag
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err = r.Rules.Configure(matchers); err != nil {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	r.Rules.Matchers = nil
 | 
			
		||||
 | 
			
		||||
	if r.ID == "" {
 | 
			
		||||
		r.ID = uuid.NewString()
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *timeRule) PermitIntercept(q *http.Request) *bool {
 | 
			
		||||
	if !r.isActive() {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
	return r.Rules.PermitIntercept(q)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *timeRule) PermitRequest(q *http.Request) *bool {
 | 
			
		||||
	if !r.isActive() {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
	return r.Rules.PermitRequest(q)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type daysRule struct {
 | 
			
		||||
	ID     string   `json:"id,omitempty"`
 | 
			
		||||
	Days   string   `hcl:"days" json:"days"`
 | 
			
		||||
	Permit *bool    `hcl:"permit" json:"permit"`
 | 
			
		||||
	Body   hcl.Body `hcl:",remain" json:"-"`
 | 
			
		||||
	Rules  *Policy  `json:"rules"`
 | 
			
		||||
	cond   []onCond
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *daysRule) isActive() bool {
 | 
			
		||||
	if r == nil || len(r.cond) == 0 {
 | 
			
		||||
		return true
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	now := time.Now()
 | 
			
		||||
	for _, cond := range r.cond {
 | 
			
		||||
		if cond(now) {
 | 
			
		||||
			return true
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return false
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *daysRule) Configure(matchers match.Matchers) (err error) {
 | 
			
		||||
	if r.cond, err = parseOnCond(r.Days); err != nil {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	r.Rules = new(Policy)
 | 
			
		||||
	if diag := gohcl.DecodeBody(r.Body, nil, r.Rules); diag.HasErrors() {
 | 
			
		||||
		return diag
 | 
			
		||||
	}
 | 
			
		||||
	if err = r.Rules.Configure(matchers); err != nil {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	r.Rules.Matchers = nil
 | 
			
		||||
 | 
			
		||||
	if r.ID == "" {
 | 
			
		||||
		r.ID = uuid.NewString()
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *daysRule) PermitIntercept(q *http.Request) *bool {
 | 
			
		||||
	if !r.isActive() {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
	return r.Rules.PermitIntercept(q)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *daysRule) PermitRequest(q *http.Request) *bool {
 | 
			
		||||
	if !r.isActive() {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
	return r.Rules.PermitRequest(q)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type onCond func(time.Time) bool
 | 
			
		||||
 | 
			
		||||
var weekdays = map[string]time.Weekday{
 | 
			
		||||
	"sun": time.Sunday,
 | 
			
		||||
	"mon": time.Monday,
 | 
			
		||||
	"tue": time.Tuesday,
 | 
			
		||||
	"wed": time.Wednesday,
 | 
			
		||||
	"thu": time.Thursday,
 | 
			
		||||
	"fri": time.Friday,
 | 
			
		||||
	"sat": time.Saturday,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func parseOnCond(when string) (conds []onCond, err error) {
 | 
			
		||||
	for _, spec := range strings.Split(when, ",") {
 | 
			
		||||
		spec = strings.ToLower(strings.TrimSpace(spec))
 | 
			
		||||
		if d, ok := weekdays[spec]; ok {
 | 
			
		||||
			conds = append(conds, onWeekday(d))
 | 
			
		||||
		} else if spec == "weekend" || spec == "weekends" {
 | 
			
		||||
			conds = append(conds, onWeekend)
 | 
			
		||||
		} else if spec == "workday" || spec == "workdays" {
 | 
			
		||||
			conds = append(conds, onWorkday)
 | 
			
		||||
		} else if strings.ContainsRune(spec, '-') {
 | 
			
		||||
			var (
 | 
			
		||||
				part       = strings.SplitN(spec, "-", 2)
 | 
			
		||||
				from, upto time.Weekday
 | 
			
		||||
				ok         bool
 | 
			
		||||
			)
 | 
			
		||||
			if from, ok = weekdays[part[0]]; !ok {
 | 
			
		||||
				return nil, fmt.Errorf("on %q: invalid weekday %q", spec, part[0])
 | 
			
		||||
			}
 | 
			
		||||
			if upto, ok = weekdays[part[1]]; !ok {
 | 
			
		||||
				return nil, fmt.Errorf("on %q: invalid weekday %q", spec, part[1])
 | 
			
		||||
			}
 | 
			
		||||
			if from < upto {
 | 
			
		||||
				for d := from; d < upto; d++ {
 | 
			
		||||
					conds = append(conds, onWeekday(d))
 | 
			
		||||
				}
 | 
			
		||||
			} else {
 | 
			
		||||
				for d := time.Sunday; d < from; d++ {
 | 
			
		||||
					conds = append(conds, onWeekday(d))
 | 
			
		||||
				}
 | 
			
		||||
				for d := upto; d <= time.Saturday; d++ {
 | 
			
		||||
					conds = append(conds, onWeekday(d))
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		} else {
 | 
			
		||||
			return nil, fmt.Errorf("on %q: invalid condition", spec)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func onWeekday(weekday time.Weekday) onCond {
 | 
			
		||||
	return func(t time.Time) bool {
 | 
			
		||||
		return t.Weekday() == weekday
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func onWeekend(t time.Time) bool {
 | 
			
		||||
	d := t.Weekday()
 | 
			
		||||
	return d == time.Saturday || d == time.Sunday
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func onWorkday(t time.Time) bool {
 | 
			
		||||
	d := t.Weekday()
 | 
			
		||||
	return !(d == time.Saturday || d == time.Sunday)
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										53
									
								
								proxy/policy/time.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										53
									
								
								proxy/policy/time.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,53 @@
 | 
			
		||||
package policy
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Time struct {
 | 
			
		||||
	Hour   int
 | 
			
		||||
	Minute int
 | 
			
		||||
	Second int
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (t Time) Eq(other Time) bool {
 | 
			
		||||
	return t.Hour == other.Hour && t.Minute == other.Minute && t.Second == other.Second
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (t Time) After(other Time) bool {
 | 
			
		||||
	return t.Seconds() > other.Seconds()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (t Time) Before(other Time) bool {
 | 
			
		||||
	return t.Seconds() < other.Seconds()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (t Time) Seconds() int {
 | 
			
		||||
	return t.Hour*3600 + t.Minute*60 + t.Second
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (t Time) MarshalJSON() ([]byte, error) {
 | 
			
		||||
	return []byte(fmt.Sprintf(`"%02d:%02d:%02d"`, t.Hour, t.Minute, t.Second)), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var timeFormats = []string{
 | 
			
		||||
	time.TimeOnly,
 | 
			
		||||
	"15:04",
 | 
			
		||||
	time.Kitchen,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func Now() Time {
 | 
			
		||||
	now := time.Now()
 | 
			
		||||
	return Time{now.Hour(), now.Minute(), now.Second()}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func ParseTime(s string) (t Time, err error) {
 | 
			
		||||
	var tt time.Time
 | 
			
		||||
	for _, layout := range timeFormats {
 | 
			
		||||
		if tt, err = time.Parse(layout, s); err == nil {
 | 
			
		||||
			return Time{tt.Hour(), tt.Minute(), tt.Second()}, nil
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return Time{}, fmt.Errorf("time: invalid time %q", s)
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										616
									
								
								proxy/proxy.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										616
									
								
								proxy/proxy.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,616 @@
 | 
			
		||||
package proxy
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bufio"
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"context"
 | 
			
		||||
	"crypto/tls"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"syscall"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"git.maze.io/maze/styx/internal/log"
 | 
			
		||||
	"git.maze.io/maze/styx/internal/netutil"
 | 
			
		||||
	"git.maze.io/maze/styx/proxy/mitm"
 | 
			
		||||
	"git.maze.io/maze/styx/proxy/policy"
 | 
			
		||||
	"git.maze.io/maze/styx/proxy/resolver"
 | 
			
		||||
	"git.maze.io/maze/styx/proxy/stats"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	DefaultListenAddr      = ":3128"
 | 
			
		||||
	DefaultBindAddr        = ""
 | 
			
		||||
	DefaultDialTimeout     = 30 * time.Second
 | 
			
		||||
	DefaultKeepAlivePeriod = 1 * time.Minute
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	HeaderAcceptEncoding = "Accept-Encoding"
 | 
			
		||||
	HeaderConnection     = "Connection"
 | 
			
		||||
	HeaderContentLength  = "Content-Length"
 | 
			
		||||
	HeaderContentType    = "Content-Type"
 | 
			
		||||
	HeaderUpgrade        = "Upgrade"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var (
 | 
			
		||||
	ErrClosed     = errors.New("proxy: shutdown")
 | 
			
		||||
	ErrClientCert = errors.New("tls: client certificate requested")
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Proxy struct {
 | 
			
		||||
	addr       *net.TCPAddr
 | 
			
		||||
	bind       *net.TCPAddr
 | 
			
		||||
	resolver   resolver.Resolver
 | 
			
		||||
	transport  *http.Transport
 | 
			
		||||
	dial       func(network, address string) (net.Conn, error)
 | 
			
		||||
	config     *Config
 | 
			
		||||
	authority  mitm.Authority
 | 
			
		||||
	policy     *policy.Policy
 | 
			
		||||
	admin      *Admin
 | 
			
		||||
	stats      *stats.Stats
 | 
			
		||||
	closed     chan struct{}
 | 
			
		||||
	onConnect  ConnectHandler
 | 
			
		||||
	onRequest  RequestHandler
 | 
			
		||||
	onResponse ResponseHandler
 | 
			
		||||
	onError    ErrorHandler
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func New(config *Config, ca mitm.Authority) (*Proxy, error) {
 | 
			
		||||
	if config == nil {
 | 
			
		||||
		return nil, errors.New("proxy: config can't be nil")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	p := &Proxy{
 | 
			
		||||
		transport:  newTransport(),
 | 
			
		||||
		config:     config,
 | 
			
		||||
		resolver:   resolver.Default,
 | 
			
		||||
		authority:  ca,
 | 
			
		||||
		policy:     config.Policy,
 | 
			
		||||
		closed:     make(chan struct{}),
 | 
			
		||||
		onConnect:  config.ConnectHandler,
 | 
			
		||||
		onRequest:  config.RequestHandler,
 | 
			
		||||
		onResponse: config.ResponseHandler,
 | 
			
		||||
		onError:    config.ErrorHandler,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var err error
 | 
			
		||||
	if config.Listen == "" {
 | 
			
		||||
		p.addr, err = net.ResolveTCPAddr("tcp", DefaultBindAddr)
 | 
			
		||||
	} else {
 | 
			
		||||
		p.addr, err = net.ResolveTCPAddr("tcp", config.Listen)
 | 
			
		||||
	}
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, fmt.Errorf("proxy: invalid listen addres: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
	if config.Bind != "" {
 | 
			
		||||
		if p.bind, err = net.ResolveTCPAddr("tcp", config.Bind+":0"); err != nil {
 | 
			
		||||
			return nil, fmt.Errorf("proxy: invalid bind address: %w", err)
 | 
			
		||||
		}
 | 
			
		||||
	} else if config.Interface != "" {
 | 
			
		||||
		if err = resolveInterfaceAddr(config.Interface); err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	if p.bind != nil {
 | 
			
		||||
		/* FIXME
 | 
			
		||||
		var c *net.TCPConn
 | 
			
		||||
		if c, err = net.DialTCP("tcp", p.bind, p.bind); err != nil && errors.Is(err, syscall.EADDRNOTAVAIL) {
 | 
			
		||||
			return nil, fmt.Errorf("proxy: invalid bind address: %w", syscall.EADDRNOTAVAIL)
 | 
			
		||||
		} else if c != nil {
 | 
			
		||||
			_ = c.Close()
 | 
			
		||||
		}
 | 
			
		||||
		*/
 | 
			
		||||
	}
 | 
			
		||||
	if config.Resolver != nil {
 | 
			
		||||
		p.resolver = config.Resolver
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	dialTimeout := DefaultDialTimeout
 | 
			
		||||
	if config.DialTimeout > 0 {
 | 
			
		||||
		dialTimeout = config.DialTimeout
 | 
			
		||||
	}
 | 
			
		||||
	p.dial = (&net.Dialer{
 | 
			
		||||
		Timeout:   dialTimeout,
 | 
			
		||||
		KeepAlive: dialTimeout,
 | 
			
		||||
		LocalAddr: p.bind,
 | 
			
		||||
	}).Dial
 | 
			
		||||
 | 
			
		||||
	p.admin = NewAdmin(p)
 | 
			
		||||
 | 
			
		||||
	if p.stats, err = stats.New(); err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return p, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func newTransport() *http.Transport {
 | 
			
		||||
	return &http.Transport{
 | 
			
		||||
		TLSNextProto:          make(map[string]func(authority string, c *tls.Conn) http.RoundTripper),
 | 
			
		||||
		Proxy:                 http.ProxyFromEnvironment,
 | 
			
		||||
		TLSHandshakeTimeout:   15 * time.Second,
 | 
			
		||||
		ExpectContinueTimeout: 5 * time.Second,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (p *Proxy) Close() error {
 | 
			
		||||
	select {
 | 
			
		||||
	case <-p.closed:
 | 
			
		||||
		return ErrClosed
 | 
			
		||||
	default:
 | 
			
		||||
		close(p.closed)
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (p *Proxy) Start() error {
 | 
			
		||||
	l, err := net.ListenTCP("tcp", p.addr)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	go p.Serve(l)
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (p *Proxy) Serve(listener net.Listener) error {
 | 
			
		||||
	defer func() { _ = listener.Close() }()
 | 
			
		||||
 | 
			
		||||
	log.Info().Str("addr", listener.Addr().String()).Msg("proxy server listening")
 | 
			
		||||
	for {
 | 
			
		||||
		select {
 | 
			
		||||
		case <-p.closed:
 | 
			
		||||
			return nil
 | 
			
		||||
		default:
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		c, err := listener.Accept()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		rw := bufio.NewReadWriter(bufio.NewReader(c), bufio.NewWriter(c))
 | 
			
		||||
		ctx := newContext(c, rw, nil)
 | 
			
		||||
 | 
			
		||||
		if c, ok := c.(*net.TCPConn); ok {
 | 
			
		||||
			_ = c.SetKeepAlive(true)
 | 
			
		||||
			_ = c.SetKeepAlivePeriod(DefaultKeepAlivePeriod)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		go p.handle(ctx)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (p *Proxy) handle(ctx *Context) {
 | 
			
		||||
	logger := ctx.log()
 | 
			
		||||
	defer log.OnCloseError(logger.Debug(), ctx.conn)
 | 
			
		||||
	logger.Info().Str("client", ctx.RemoteAddr().String()).Msg("new client connection")
 | 
			
		||||
 | 
			
		||||
	last := int64(0)
 | 
			
		||||
	for {
 | 
			
		||||
		select {
 | 
			
		||||
		case <-p.closed:
 | 
			
		||||
			return
 | 
			
		||||
 | 
			
		||||
		default:
 | 
			
		||||
			ses, err := p.handleRequest(ctx)
 | 
			
		||||
			if ses != nil {
 | 
			
		||||
				log := ses.log()
 | 
			
		||||
				log.Info().
 | 
			
		||||
					Str("method", ses.request.Method).
 | 
			
		||||
					Str("url", ses.request.URL.String()).
 | 
			
		||||
					Str("status", ses.response.Status).
 | 
			
		||||
					Int64("size", ctx.conn.bytes-last).
 | 
			
		||||
					Msg("handled request")
 | 
			
		||||
 | 
			
		||||
				p.stats.AddLog(&stats.Log{
 | 
			
		||||
					ClientIP: netutil.Host(ses.request.RemoteAddr),
 | 
			
		||||
					Request:  stats.FromRequest(ses.request),
 | 
			
		||||
					Response: stats.FromResponse(ses.response).SetSize(ctx.conn.bytes - last),
 | 
			
		||||
				})
 | 
			
		||||
 | 
			
		||||
				last = ctx.conn.bytes
 | 
			
		||||
			}
 | 
			
		||||
			if err != nil && !isClosing(err) || (ses != nil && ses.response != nil && ses.response.Close) {
 | 
			
		||||
				event := logger.Debug()
 | 
			
		||||
				if ctx.conn.bytes > 0 {
 | 
			
		||||
					event = event.Int64("size", ctx.conn.bytes)
 | 
			
		||||
				}
 | 
			
		||||
				event.Msg("closing client connection")
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (p *Proxy) handleRequest(ctx *Context) (ses *Session, err error) {
 | 
			
		||||
	logger := ctx.log()
 | 
			
		||||
 | 
			
		||||
	var request *http.Request
 | 
			
		||||
	if request, err = p.readRequest(ctx); err != nil {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	ses = newSession(ctx, request)
 | 
			
		||||
	p.cleanRequest(ses, request)
 | 
			
		||||
 | 
			
		||||
	logger.Debug().Str("method", request.Method).Str("url", request.URL.String()).Msg("handle request")
 | 
			
		||||
 | 
			
		||||
	if p.onRequest != nil {
 | 
			
		||||
		newRequest, newResponse := p.onRequest.HandleRequest(ses)
 | 
			
		||||
		if newRequest != nil {
 | 
			
		||||
			logger.Debug().Str("method", newRequest.Method).Str("url", newRequest.URL.String()).Msg("request override")
 | 
			
		||||
			ses.request = newRequest
 | 
			
		||||
		}
 | 
			
		||||
		if newResponse != nil {
 | 
			
		||||
			logger.Debug().Str("status", newResponse.Status).Msg("response override")
 | 
			
		||||
			ses.response = newResponse
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if ses.response == nil {
 | 
			
		||||
		// WebSocket request
 | 
			
		||||
		if ses.request.Header.Get(HeaderUpgrade) == "websocket" {
 | 
			
		||||
			return ses, p.handleTunnel(ses)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		cleanHopByHopHeaders(ses.request.Header)
 | 
			
		||||
 | 
			
		||||
		// Proxy CONNECT request
 | 
			
		||||
		if ses.request.Method == http.MethodConnect {
 | 
			
		||||
			return p.handleConnect(ses)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if netutil.Port(ses.request.URL.Host) == p.addr.Port {
 | 
			
		||||
			// Plain API request
 | 
			
		||||
			ses.request.URL.Host = ses.request.Host
 | 
			
		||||
			return ses, p.admin.handleRequest(ses)
 | 
			
		||||
 | 
			
		||||
		} else if ses.response, err = p.transport.RoundTrip(ses.request); err != nil {
 | 
			
		||||
			// Plain HTTP request
 | 
			
		||||
			if p.config.ErrorHandler != nil {
 | 
			
		||||
				p.config.ErrorHandler.HandleError(ses, err)
 | 
			
		||||
			}
 | 
			
		||||
			ses.response = ErrorResponse(ses.request, err)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		logger.Debug().Str("status", ses.response.Status).Msg("received response")
 | 
			
		||||
		cleanHopByHopHeaders(ses.response.Header)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	ses.response.Close = true
 | 
			
		||||
	defer log.OnCloseError(logger.Debug(), ses.response.Body)
 | 
			
		||||
	return ses, p.writeResponse(ses)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (p *Proxy) handleConnect(ses *Session) (next *Session, err error) {
 | 
			
		||||
	next = ses
 | 
			
		||||
 | 
			
		||||
	logger := ses.log()
 | 
			
		||||
	logger.Debug().Msgf("connecting to %s", ses.request.URL.Host)
 | 
			
		||||
 | 
			
		||||
	var c net.Conn
 | 
			
		||||
	if c, err = p.connect(ses, "tcp", ses.request.URL.Host); err != nil {
 | 
			
		||||
		logger.Error().Err(err).Msg("connect failed")
 | 
			
		||||
		if p.onError != nil {
 | 
			
		||||
			p.onError.HandleError(ses, err)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		ses.response = ErrorResponse(ses.request, err)
 | 
			
		||||
		defer log.OnCloseError(logger.Debug(), ses.response.Body)
 | 
			
		||||
		_ = p.writeResponse(ses)
 | 
			
		||||
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	defer func() {
 | 
			
		||||
		if err := c.Close(); err != nil {
 | 
			
		||||
			if p.onError != nil {
 | 
			
		||||
				p.onError.HandleError(ses, err)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	if p.canIntercept(ses.request) {
 | 
			
		||||
		logger.Debug().Msg("intercepting connection")
 | 
			
		||||
		ses.response = NewResponse(http.StatusOK, nil, ses.request)
 | 
			
		||||
		err = p.writeResponse(ses)
 | 
			
		||||
		log.OnCloseError(logger.Debug(), ses.response.Body)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// Peek first byte
 | 
			
		||||
		b := make([]byte, 1)
 | 
			
		||||
		if _, err = io.ReadFull(ses.ctx.rw, b); err != nil {
 | 
			
		||||
			logger.Error().Err(err).Msg("error peeking CONNECT byte")
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// Drain buffered bytes
 | 
			
		||||
		b = append(b, make([]byte, ses.ctx.rw.Reader.Buffered())...)
 | 
			
		||||
		ses.ctx.rw.Reader.Read(b[1:])
 | 
			
		||||
 | 
			
		||||
		r := &connReader{
 | 
			
		||||
			Conn:   ses.ctx.conn,
 | 
			
		||||
			Reader: io.MultiReader(bytes.NewBuffer(b), ses.ctx.conn),
 | 
			
		||||
		}
 | 
			
		||||
		if b[0] == 22 { // TLS handshake: https://tools.ietf.org/html/rfc5246#section-6.2.1
 | 
			
		||||
			secure := tls.Server(r, p.authority.TLSConfig(ses.request.URL.Host))
 | 
			
		||||
			if err = secure.Handshake(); err != nil {
 | 
			
		||||
				logger.Error().Err(err).Msg("error intercepting TLS connection: client handshake failed")
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			rw := bufio.NewReadWriter(bufio.NewReader(secure), bufio.NewWriter(secure))
 | 
			
		||||
			ctx := newContext(secure, rw, ses)
 | 
			
		||||
			return p.handleRequest(ctx)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		rw := bufio.NewReadWriter(bufio.NewReader(r), bufio.NewWriter(r))
 | 
			
		||||
		ctx := newContext(r, rw, ses)
 | 
			
		||||
		return p.handleRequest(ctx)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	ses.response = NewResponse(http.StatusOK, nil, ses.request)
 | 
			
		||||
	defer log.OnCloseError(logger.Debug(), ses.response.Body)
 | 
			
		||||
	ses.response.ContentLength = -1
 | 
			
		||||
	if err = p.writeResponse(ses); err != nil {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	logger.Debug().Msg("established CONNECT tunnel, proxying traffic")
 | 
			
		||||
	var wait sync.WaitGroup
 | 
			
		||||
	wait.Go(func() { copyStream(ses, c, ses.ctx.conn) })
 | 
			
		||||
	wait.Go(func() { copyStream(ses, ses.ctx.conn, c) })
 | 
			
		||||
	wait.Wait()
 | 
			
		||||
	logger.Debug().Msg("closed CONNECT tunnel")
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (p *Proxy) handleTunnel(ses *Session) (err error) {
 | 
			
		||||
	logger := ses.log()
 | 
			
		||||
	logger.Debug().Msgf("connecting to %s", ses.request.URL.Host)
 | 
			
		||||
 | 
			
		||||
	var c net.Conn
 | 
			
		||||
	if c, err = p.connect(ses, "tcp", ses.request.URL.Host); err != nil {
 | 
			
		||||
		logger.Error().Err(err).Msg("connect failed")
 | 
			
		||||
		if p.onError != nil {
 | 
			
		||||
			p.onError.HandleError(ses, err)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		ses.response = ErrorResponse(ses.request, err)
 | 
			
		||||
		defer log.OnCloseError(logger.Debug(), ses.response.Body)
 | 
			
		||||
		_ = p.writeResponse(ses)
 | 
			
		||||
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	defer log.OnCloseError(logger.Debug(), c)
 | 
			
		||||
 | 
			
		||||
	if ses.ctx.IsTLS() {
 | 
			
		||||
		// Open a TLS client connection
 | 
			
		||||
		secure := tls.Client(c, &tls.Config{
 | 
			
		||||
			ServerName: ses.request.URL.Host,
 | 
			
		||||
			GetClientCertificate: func(cri *tls.CertificateRequestInfo) (*tls.Certificate, error) {
 | 
			
		||||
				return nil, ErrClientCert
 | 
			
		||||
			},
 | 
			
		||||
		})
 | 
			
		||||
		if err = secure.Handshake(); err != nil {
 | 
			
		||||
			logger.Error().Err(err).Msg("TLS handshake failed")
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		c = secure
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err = ses.request.Write(c); err != nil {
 | 
			
		||||
		logger.Error().Err(err).Msg("failed to write request")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	logger.Debug().Msg("established tunnel, proxying traffic")
 | 
			
		||||
	var wait sync.WaitGroup
 | 
			
		||||
	wait.Go(func() { copyStream(ses, c, ses.ctx.conn) })
 | 
			
		||||
	wait.Go(func() { copyStream(ses, ses.ctx.conn, c) })
 | 
			
		||||
	wait.Wait()
 | 
			
		||||
	logger.Debug().Msg("closed tunnel")
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (p *Proxy) canIntercept(request *http.Request) bool {
 | 
			
		||||
	if permit := p.policy.PermitIntercept(request); permit != nil {
 | 
			
		||||
		return *permit
 | 
			
		||||
	}
 | 
			
		||||
	return true
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/*
 | 
			
		||||
func (p *Proxy) handleAPIRequest(ses *Session) error {
 | 
			
		||||
		if ses.request.URL.Path == "/ca.crt" && p.authority != nil {
 | 
			
		||||
			b := pem.EncodeToMemory(&pem.Block{
 | 
			
		||||
				Type:  "CERTIFICATE",
 | 
			
		||||
				Bytes: p.authority.Certificate().Raw,
 | 
			
		||||
			})
 | 
			
		||||
 | 
			
		||||
			ses.response = NewResponse(http.StatusOK, bytes.NewReader(b), ses.request)
 | 
			
		||||
			defer log.OnCloseError(logger.Debug(), ses.response.Body)
 | 
			
		||||
 | 
			
		||||
			ses.response.Close = true
 | 
			
		||||
			ses.response.Header.Set("Content-Type", "application/x-x509-ca-cert")
 | 
			
		||||
			ses.response.ContentLength = int64(len(b))
 | 
			
		||||
			return p.writeResponse(ses)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		ses.response = ErrorResponse(ses.request, errors.New("invalid API endpoint"))
 | 
			
		||||
		defer log.OnCloseError(logger.Debug(), ses.response.Body)
 | 
			
		||||
		ses.response.Close = true
 | 
			
		||||
		return p.writeResponse(ses)
 | 
			
		||||
}
 | 
			
		||||
*/
 | 
			
		||||
 | 
			
		||||
func (p *Proxy) readRequest(ctx *Context) (request *http.Request, err error) {
 | 
			
		||||
	var (
 | 
			
		||||
		done = make(chan *http.Request, 1)
 | 
			
		||||
		errs = make(chan error, 1)
 | 
			
		||||
	)
 | 
			
		||||
 | 
			
		||||
	go func() {
 | 
			
		||||
		r, err := http.ReadRequest(ctx.rw.Reader)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			errs <- err
 | 
			
		||||
		} else {
 | 
			
		||||
			done <- r
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	select {
 | 
			
		||||
	case <-p.closed:
 | 
			
		||||
		return nil, ErrClosed
 | 
			
		||||
	case request = <-done:
 | 
			
		||||
		return
 | 
			
		||||
	case err = <-errs:
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (p *Proxy) cleanRequest(ses *Session, request *http.Request) {
 | 
			
		||||
	if request.URL.Host == "" {
 | 
			
		||||
		request.URL.Host = request.Host
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Ensure proper URL scheme
 | 
			
		||||
	if !strings.HasPrefix(request.URL.Scheme, "http") {
 | 
			
		||||
		request.URL.Scheme = "http"
 | 
			
		||||
	}
 | 
			
		||||
	if ses.ctx.IsTLS() {
 | 
			
		||||
		state := ses.ctx.conn.Conn.(*tls.Conn).ConnectionState()
 | 
			
		||||
		request.TLS = &state
 | 
			
		||||
		request.URL.Scheme = "https"
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Ensure proper RemoteAddr
 | 
			
		||||
	request.RemoteAddr = ses.ctx.RemoteAddr().String()
 | 
			
		||||
 | 
			
		||||
	// Ensure proper encoding
 | 
			
		||||
	if request.Header.Get(HeaderAcceptEncoding) != "" {
 | 
			
		||||
		// We only support gzip
 | 
			
		||||
		request.Header.Set(HeaderAcceptEncoding, "gzip")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (p *Proxy) writeResponse(ses *Session) (err error) {
 | 
			
		||||
	log := ses.log()
 | 
			
		||||
 | 
			
		||||
	if p.onResponse != nil {
 | 
			
		||||
		response := p.onResponse.HandleResponse(ses)
 | 
			
		||||
		if response != nil {
 | 
			
		||||
			log.Debug().Str("status", response.Status).Msg("response override")
 | 
			
		||||
			ses.response = response
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err = ses.response.Write(ses.ctx); err != nil {
 | 
			
		||||
		log.Error().Err(err).Msg("error writing response back to client")
 | 
			
		||||
	} else if err = ses.ctx.Flush(); err != nil {
 | 
			
		||||
		log.Error().Err(err).Msg("error flushing response back to client")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (p *Proxy) connect(ses *Session, network, address string) (c net.Conn, err error) {
 | 
			
		||||
	log := ses.log()
 | 
			
		||||
	log.Debug().Msgf("connect to %s://%s", network, address)
 | 
			
		||||
 | 
			
		||||
	if p.onConnect != nil {
 | 
			
		||||
		if c = p.onConnect.HandleConnect(ses, network, address); c != nil {
 | 
			
		||||
			log.Debug().Msg("connect override")
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var host, port string
 | 
			
		||||
	if host, port, err = net.SplitHostPort(address); err != nil {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var hosts []string
 | 
			
		||||
	if hosts, err = p.resolver.Lookup(context.Background(), host); err != nil {
 | 
			
		||||
		log.Warn().Err(err).Msg("connect failed: DNS lookup error")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	log.Debug().Str("address", hosts[0]).Msg("connect resolved address")
 | 
			
		||||
	return p.dial(network, net.JoinHostPort(hosts[0], port))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var hopByHopHeaders = []string{
 | 
			
		||||
	HeaderConnection,
 | 
			
		||||
	"Keep-Alive",
 | 
			
		||||
	"Proxy-Authenticate",
 | 
			
		||||
	"Proxy-Authorization",
 | 
			
		||||
	"Proxy-Connection", // Non-standard, but required for HTTP/2.
 | 
			
		||||
	"Te",
 | 
			
		||||
	"Trailer",
 | 
			
		||||
	"Transfer-Encoding",
 | 
			
		||||
	HeaderUpgrade,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func cleanHopByHopHeaders(header http.Header) {
 | 
			
		||||
	// Additional hop-by-hop headers may be specified in `Connection` headers.
 | 
			
		||||
	// http://tools.ietf.org/html/draft-ietf-httpbis-p1-messaging-14#section-9.1
 | 
			
		||||
	for _, values := range header[HeaderConnection] {
 | 
			
		||||
		for _, key := range strings.Split(values, ",") {
 | 
			
		||||
			header.Del(key)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	for _, key := range hopByHopHeaders {
 | 
			
		||||
		header.Del(key)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// copyStream copies data from reader to writer
 | 
			
		||||
func copyStream(ses *Session, w io.Writer, r io.Reader) {
 | 
			
		||||
	log := ses.log()
 | 
			
		||||
	if _, err := io.Copy(w, r); err != nil && !isClosing(err) {
 | 
			
		||||
		log.Error().Err(err).Msg("failed CONNECT tunnel")
 | 
			
		||||
	} else {
 | 
			
		||||
		log.Debug().Msg("finished copying CONNECT tunnel")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func isClosing(err error) bool {
 | 
			
		||||
	if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) || errors.Is(err, syscall.ECONNRESET) || err == ErrClosed {
 | 
			
		||||
		return true
 | 
			
		||||
	}
 | 
			
		||||
	if err, ok := err.(net.Error); ok && err.Timeout() {
 | 
			
		||||
		return true
 | 
			
		||||
	}
 | 
			
		||||
	// log.Debug().Msgf("not a closing error %T: %#+v", err, err)
 | 
			
		||||
	return false
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func resolveInterfaceAddr(name string) (err error) {
 | 
			
		||||
	var iface *net.Interface
 | 
			
		||||
	if iface, err = net.InterfaceByName(name); err != nil {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var addrs []net.Addr
 | 
			
		||||
	if addrs, err = iface.Addrs(); err != nil {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, addr := range addrs {
 | 
			
		||||
		if addr, ok := addr.(*net.IPNet); ok && !addr.IP.IsUnspecified() {
 | 
			
		||||
			log.Warn().Msgf("addr %T: %s", addr, addr)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return errors.New("nope; TODO")
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										148
									
								
								proxy/resolver/resolver.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										148
									
								
								proxy/resolver/resolver.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,148 @@
 | 
			
		||||
// Package resolver implements a caching DNS resolver
 | 
			
		||||
package resolver
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"math/rand/v2"
 | 
			
		||||
	"net"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"git.maze.io/maze/styx/internal/netutil"
 | 
			
		||||
	"github.com/hashicorp/golang-lru/v2/expirable"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	DefaultSize    = 1024
 | 
			
		||||
	DefaultTTL     = 5 * time.Minute
 | 
			
		||||
	DefaultTimeout = 10 * time.Second
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var (
 | 
			
		||||
	// DefaultConfig are the defaults for the Default resolver.
 | 
			
		||||
	DefaultConfig = Config{
 | 
			
		||||
		Size:    DefaultSize,
 | 
			
		||||
		TTL:     DefaultTTL.Seconds(),
 | 
			
		||||
		Timeout: DefaultTimeout.Seconds(),
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Default resolver.
 | 
			
		||||
	Default = New(DefaultConfig)
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Resolver interface {
 | 
			
		||||
	// Lookup returns resolved IPs for given hostname/ips.
 | 
			
		||||
	Lookup(context.Context, string) ([]string, error)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type netResolver struct {
 | 
			
		||||
	resolver *net.Resolver
 | 
			
		||||
	timeout  time.Duration
 | 
			
		||||
	noIPv6   bool
 | 
			
		||||
	cache    *expirable.LRU[string, []string]
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Config struct {
 | 
			
		||||
	// Size is our cache size in number of entries.
 | 
			
		||||
	Size int `hcl:"size,optional"`
 | 
			
		||||
 | 
			
		||||
	// TTL is the cache time to live in seconds.
 | 
			
		||||
	TTL float64 `hcl:"ttl,optional"`
 | 
			
		||||
 | 
			
		||||
	// Timeout is the cache timeout in seconds.
 | 
			
		||||
	Timeout float64 `hcl:"timeout,optional"`
 | 
			
		||||
 | 
			
		||||
	// Server are alternative DNS servers.
 | 
			
		||||
	Server []string `hcl:"server,optional"`
 | 
			
		||||
 | 
			
		||||
	// NoIPv6 disables IPv6 DNS resolution.
 | 
			
		||||
	NoIPv6 bool `hcl:"noipv6,optional"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func New(config Config) Resolver {
 | 
			
		||||
	var (
 | 
			
		||||
		size    = config.Size
 | 
			
		||||
		ttl     = time.Duration(float64(time.Second) * config.TTL)
 | 
			
		||||
		timeout = time.Duration(float64(time.Second) * config.Timeout)
 | 
			
		||||
	)
 | 
			
		||||
	if size <= 0 {
 | 
			
		||||
		size = DefaultSize
 | 
			
		||||
	}
 | 
			
		||||
	if ttl <= 0 {
 | 
			
		||||
		ttl = DefaultTTL
 | 
			
		||||
	}
 | 
			
		||||
	if timeout <= 0 {
 | 
			
		||||
		timeout = 0
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var resolver = new(net.Resolver)
 | 
			
		||||
	if len(config.Server) > 0 {
 | 
			
		||||
		var dialer net.Dialer
 | 
			
		||||
		resolver.Dial = func(ctx context.Context, network, address string) (net.Conn, error) {
 | 
			
		||||
			server := netutil.EnsurePort(config.Server[rand.IntN(len(config.Server))], "53")
 | 
			
		||||
			return dialer.DialContext(ctx, network, server)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return &netResolver{
 | 
			
		||||
		resolver: resolver,
 | 
			
		||||
		timeout:  timeout,
 | 
			
		||||
		noIPv6:   config.NoIPv6,
 | 
			
		||||
		cache:    expirable.NewLRU[string, []string](size, nil, ttl),
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *netResolver) Lookup(ctx context.Context, host string) ([]string, error) {
 | 
			
		||||
	host = strings.ToLower(strings.TrimSpace(host))
 | 
			
		||||
	if hosts, ok := r.cache.Get(host); ok {
 | 
			
		||||
		rand.Shuffle(len(hosts), func(i, j int) {
 | 
			
		||||
			hosts[i], hosts[j] = hosts[j], hosts[i]
 | 
			
		||||
		})
 | 
			
		||||
		return hosts, nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	hosts, err := r.lookup(ctx, host)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	r.cache.Add(host, hosts)
 | 
			
		||||
	return hosts, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *netResolver) lookup(ctx context.Context, host string) ([]string, error) {
 | 
			
		||||
	if r.timeout > 0 {
 | 
			
		||||
		var cancel func()
 | 
			
		||||
		ctx, cancel = context.WithTimeout(ctx, r.timeout)
 | 
			
		||||
		defer cancel()
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if net.ParseIP(host) == nil {
 | 
			
		||||
		addrs, err := r.resolver.LookupHost(ctx, host)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		if r.noIPv6 {
 | 
			
		||||
			var addrs4 []string
 | 
			
		||||
			for _, addr := range addrs {
 | 
			
		||||
				if net.ParseIP(addr).To4() != nil {
 | 
			
		||||
					addrs4 = append(addrs4, addr)
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
			return addrs4, nil
 | 
			
		||||
		}
 | 
			
		||||
		return addrs, nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	addrs, err := r.resolver.LookupIPAddr(ctx, host)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	hosts := make([]string, len(addrs))
 | 
			
		||||
	for i, addr := range addrs {
 | 
			
		||||
		if !r.noIPv6 || addr.IP.To4() != nil {
 | 
			
		||||
			hosts[i] = addr.IP.String()
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return hosts, nil
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										78
									
								
								proxy/response.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										78
									
								
								proxy/response.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,78 @@
 | 
			
		||||
package proxy
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"os"
 | 
			
		||||
	"strconv"
 | 
			
		||||
 | 
			
		||||
	"git.maze.io/maze/styx/internal/log"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func NewResponse(code int, body io.Reader, request *http.Request) *http.Response {
 | 
			
		||||
	if body == nil {
 | 
			
		||||
		body = new(bytes.Buffer)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	rc, ok := body.(io.ReadCloser)
 | 
			
		||||
	if !ok {
 | 
			
		||||
		rc = io.NopCloser(body)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	response := &http.Response{
 | 
			
		||||
		Status:     strconv.Itoa(code) + " " + http.StatusText(code),
 | 
			
		||||
		StatusCode: code,
 | 
			
		||||
		Proto:      "HTTP/1.1",
 | 
			
		||||
		ProtoMajor: 1,
 | 
			
		||||
		ProtoMinor: 1,
 | 
			
		||||
		Header:     make(http.Header),
 | 
			
		||||
		Body:       rc,
 | 
			
		||||
		Request:    request,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if request != nil {
 | 
			
		||||
		response.Close = request.Close
 | 
			
		||||
		response.Proto = request.Proto
 | 
			
		||||
		response.ProtoMajor = request.ProtoMajor
 | 
			
		||||
		response.ProtoMinor = request.ProtoMinor
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return response
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type withLen interface {
 | 
			
		||||
	Len() int
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type withSize interface {
 | 
			
		||||
	Size() int64
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewJSONResponse(code int, body io.Reader, request *http.Request) *http.Response {
 | 
			
		||||
	response := NewResponse(code, body, request)
 | 
			
		||||
	response.Header.Set(HeaderContentType, "application/json")
 | 
			
		||||
	if s, ok := body.(withLen); ok {
 | 
			
		||||
		response.Header.Set(HeaderContentLength, strconv.Itoa(s.Len()))
 | 
			
		||||
	} else if s, ok := body.(withSize); ok {
 | 
			
		||||
		response.Header.Set(HeaderContentLength, strconv.FormatInt(s.Size(), 10))
 | 
			
		||||
	} else {
 | 
			
		||||
		log.Trace().Str("type", fmt.Sprintf("%T", body)).Msg("can't detemine body size")
 | 
			
		||||
	}
 | 
			
		||||
	response.Close = true
 | 
			
		||||
	return response
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func ErrorResponse(request *http.Request, err error) *http.Response {
 | 
			
		||||
	response := NewResponse(http.StatusBadGateway, nil, request)
 | 
			
		||||
	switch {
 | 
			
		||||
	case os.IsNotExist(err):
 | 
			
		||||
		response.StatusCode = http.StatusNotFound
 | 
			
		||||
	case os.IsPermission(err):
 | 
			
		||||
		response.StatusCode = http.StatusForbidden
 | 
			
		||||
	}
 | 
			
		||||
	response.Status = http.StatusText(response.StatusCode)
 | 
			
		||||
	response.Close = true
 | 
			
		||||
	return response
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										151
									
								
								proxy/session.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										151
									
								
								proxy/session.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,151 @@
 | 
			
		||||
package proxy
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bufio"
 | 
			
		||||
	"crypto/tls"
 | 
			
		||||
	"encoding/binary"
 | 
			
		||||
	"encoding/hex"
 | 
			
		||||
	"math/rand"
 | 
			
		||||
	"net"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"sync/atomic"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"git.maze.io/maze/styx/internal/log"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var seed = rand.NewSource(time.Now().UnixNano())
 | 
			
		||||
 | 
			
		||||
type Context struct {
 | 
			
		||||
	id     int64
 | 
			
		||||
	conn   *wrappedConn
 | 
			
		||||
	rw     *bufio.ReadWriter
 | 
			
		||||
	parent *Session
 | 
			
		||||
	data   map[string]any
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func newContext(conn net.Conn, rw *bufio.ReadWriter, parent *Session) *Context {
 | 
			
		||||
	if wrapped, ok := conn.(*wrappedConn); ok {
 | 
			
		||||
		conn = wrapped.Conn
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	ctx := &Context{
 | 
			
		||||
		id:     seed.Int63(),
 | 
			
		||||
		conn:   &wrappedConn{Conn: conn},
 | 
			
		||||
		rw:     rw,
 | 
			
		||||
		parent: parent,
 | 
			
		||||
		data:   make(map[string]any),
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return ctx
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ctx *Context) log() log.Logger {
 | 
			
		||||
	return log.Console.With().
 | 
			
		||||
		Str("context", ctx.ID()).
 | 
			
		||||
		Str("addr", ctx.RemoteAddr().String()).
 | 
			
		||||
		Logger()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ctx *Context) ID() string {
 | 
			
		||||
	var b [8]byte
 | 
			
		||||
	binary.BigEndian.PutUint64(b[:], uint64(ctx.id))
 | 
			
		||||
	if ctx.parent != nil {
 | 
			
		||||
		return ctx.parent.ID() + "-" + hex.EncodeToString(b[:])
 | 
			
		||||
	}
 | 
			
		||||
	return hex.EncodeToString(b[:])
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ctx *Context) IsTLS() bool {
 | 
			
		||||
	_, ok := ctx.conn.Conn.(*tls.Conn)
 | 
			
		||||
	return ok && ctx.parent != nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ctx *Context) RemoteAddr() net.Addr {
 | 
			
		||||
	if ctx.parent != nil {
 | 
			
		||||
		return ctx.parent.ctx.RemoteAddr()
 | 
			
		||||
	}
 | 
			
		||||
	return ctx.conn.RemoteAddr()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ctx *Context) SetDeadline(t time.Time) error {
 | 
			
		||||
	if ctx.parent != nil {
 | 
			
		||||
		return ctx.parent.ctx.SetDeadline(t)
 | 
			
		||||
	}
 | 
			
		||||
	return ctx.conn.SetDeadline(t)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ctx *Context) Set(key string, value any) {
 | 
			
		||||
	ctx.data[key] = value
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ctx *Context) Get(key string) (value any, ok bool) {
 | 
			
		||||
	value, ok = ctx.data[key]
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ctx *Context) Flush() error {
 | 
			
		||||
	return ctx.rw.Flush()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ctx *Context) Write(p []byte) (n int, err error) {
 | 
			
		||||
	if n, err = ctx.rw.Write(p); n > 0 {
 | 
			
		||||
		atomic.AddInt64(&ctx.conn.bytes, int64(n))
 | 
			
		||||
	}
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Session struct {
 | 
			
		||||
	id       int64
 | 
			
		||||
	ctx      *Context
 | 
			
		||||
	request  *http.Request
 | 
			
		||||
	response *http.Response
 | 
			
		||||
	data     map[string]any
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func newSession(ctx *Context, request *http.Request) *Session {
 | 
			
		||||
	return &Session{
 | 
			
		||||
		id:      seed.Int63(),
 | 
			
		||||
		ctx:     ctx,
 | 
			
		||||
		request: request,
 | 
			
		||||
		data:    make(map[string]any),
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ses *Session) log() log.Logger {
 | 
			
		||||
	return log.Console.With().
 | 
			
		||||
		Str("context", ses.ctx.ID()).
 | 
			
		||||
		Str("session", ses.ID()).
 | 
			
		||||
		Str("addr", ses.ctx.RemoteAddr().String()).
 | 
			
		||||
		Logger()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ses *Session) ID() string {
 | 
			
		||||
	var b [8]byte
 | 
			
		||||
	binary.BigEndian.PutUint64(b[:], uint64(ses.id))
 | 
			
		||||
	return hex.EncodeToString(b[:])
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ses *Session) Context() *Context {
 | 
			
		||||
	return ses.ctx
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ses *Session) Request() *http.Request {
 | 
			
		||||
	return ses.request
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ses *Session) Response() *http.Response {
 | 
			
		||||
	return ses.response
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type wrappedConn struct {
 | 
			
		||||
	net.Conn
 | 
			
		||||
	bytes int64
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *wrappedConn) Write(p []byte) (n int, err error) {
 | 
			
		||||
	if n, err = c.Conn.Write(p); n > 0 {
 | 
			
		||||
		atomic.AddInt64(&c.bytes, int64(n))
 | 
			
		||||
	}
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										225
									
								
								proxy/stats/stats.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										225
									
								
								proxy/stats/stats.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,225 @@
 | 
			
		||||
package stats
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"database/sql"
 | 
			
		||||
	"database/sql/driver"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"os"
 | 
			
		||||
	"os/user"
 | 
			
		||||
	"path/filepath"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"git.maze.io/maze/styx/internal/log"
 | 
			
		||||
	_ "github.com/mattn/go-sqlite3"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Stats struct {
 | 
			
		||||
	db *sql.DB
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func New() (*Stats, error) {
 | 
			
		||||
	u, err := user.Current()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	path := filepath.Join(u.HomeDir, ".styx", "stats.db")
 | 
			
		||||
	if err = os.MkdirAll(filepath.Dir(path), 0o750); err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	db, err := sql.Open("sqlite3", path+"?_journal_mode=WAL")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, table := range []string{
 | 
			
		||||
		createLog,
 | 
			
		||||
		createDomainStat,
 | 
			
		||||
		createStatusStat,
 | 
			
		||||
	} {
 | 
			
		||||
		if _, err = db.Exec(table); err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return &Stats{db: db}, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Stats) AddLog(entry *Log) error {
 | 
			
		||||
	var (
 | 
			
		||||
		request  []byte
 | 
			
		||||
		response []byte
 | 
			
		||||
		err      error
 | 
			
		||||
	)
 | 
			
		||||
	if request, err = json.Marshal(entry.Request); err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	if response, err = json.Marshal(entry.Response); err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	tx, err := s.db.Begin()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	stmt, err := tx.Prepare("insert into styx_log(client_ip, request, response) values(?, ?, ?)")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	defer stmt.Close()
 | 
			
		||||
	if _, err = stmt.Exec(entry.ClientIP, request, response); err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	return tx.Commit()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Stats) QueryLog(offset, limit int) ([]*Log, error) {
 | 
			
		||||
	if limit == 0 {
 | 
			
		||||
		limit = 50
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	rows, err := s.db.Query("select dt, client_ip, request, response from styx_log limit ?, ?", offset, limit)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	defer rows.Close()
 | 
			
		||||
 | 
			
		||||
	var logs []*Log
 | 
			
		||||
	for rows.Next() {
 | 
			
		||||
		var entry = new(Log)
 | 
			
		||||
		if err = rows.Scan(&entry.Time, &entry.ClientIP, &entry.Request, &entry.Response); err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		logs = append(logs, entry)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return logs, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Status struct {
 | 
			
		||||
	Code  int `json:"code"`
 | 
			
		||||
	Count int `json:"count"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var timeZero time.Time
 | 
			
		||||
 | 
			
		||||
func (s *Stats) QueryStatus(since time.Time) ([]*Status, error) {
 | 
			
		||||
	if since.Equal(timeZero) {
 | 
			
		||||
		since = time.Now().Add(-24 * time.Hour)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	rows, err := s.db.Query("select response->'status', count(*) from styx_log where dt >= ? group by response->'status' order by response->'status'", since)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var stats []*Status
 | 
			
		||||
	for rows.Next() {
 | 
			
		||||
		var entry = new(Status)
 | 
			
		||||
		if err = rows.Scan(&entry.Code, &entry.Count); err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		stats = append(stats, entry)
 | 
			
		||||
	}
 | 
			
		||||
	return stats, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
const createLog = `CREATE TABLE IF NOT EXISTS styx_log (
 | 
			
		||||
	id        INT PRIMARY KEY,
 | 
			
		||||
	dt        DATETIME DEFAULT CURRENT_TIMESTAMP,
 | 
			
		||||
	client_ip TEXT NOT NULL,
 | 
			
		||||
	request   JSONB NOT NULL,
 | 
			
		||||
	response  JSONB NOT NULL
 | 
			
		||||
);`
 | 
			
		||||
 | 
			
		||||
type Log struct {
 | 
			
		||||
	Time     time.Time `json:"time"`
 | 
			
		||||
	ClientIP string    `json:"client_ip"`
 | 
			
		||||
	Request  *Request  `json:"request"`
 | 
			
		||||
	Response *Response `json:"response"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Request struct {
 | 
			
		||||
	URL    string      `json:"url"`
 | 
			
		||||
	Host   string      `json:"host"`
 | 
			
		||||
	Method string      `json:"method"`
 | 
			
		||||
	Proto  string      `json:"proto"`
 | 
			
		||||
	Header http.Header `json:"header"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *Request) Scan(value any) error {
 | 
			
		||||
	switch v := value.(type) {
 | 
			
		||||
	case string:
 | 
			
		||||
		return json.Unmarshal([]byte(v), r)
 | 
			
		||||
	case []byte:
 | 
			
		||||
		return json.Unmarshal(v, r)
 | 
			
		||||
	default:
 | 
			
		||||
		log.Error().Str("type", fmt.Sprintf("%T", value)).Msg("scan request unknown type")
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *Request) Value() (driver.Value, error) {
 | 
			
		||||
	b, err := json.Marshal(r)
 | 
			
		||||
	return string(b), err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func FromRequest(r *http.Request) *Request {
 | 
			
		||||
	return &Request{
 | 
			
		||||
		URL:    r.URL.String(),
 | 
			
		||||
		Host:   r.Host,
 | 
			
		||||
		Method: r.Method,
 | 
			
		||||
		Proto:  r.Proto,
 | 
			
		||||
		Header: r.Header,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Response struct {
 | 
			
		||||
	Status int         `json:"status"`
 | 
			
		||||
	Size   int64       `json:"size"`
 | 
			
		||||
	Header http.Header `json:"header"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *Response) Scan(value any) error {
 | 
			
		||||
	switch v := value.(type) {
 | 
			
		||||
	case string:
 | 
			
		||||
		return json.Unmarshal([]byte(v), r)
 | 
			
		||||
	case []byte:
 | 
			
		||||
		return json.Unmarshal(v, r)
 | 
			
		||||
	default:
 | 
			
		||||
		log.Error().Str("type", fmt.Sprintf("%T", value)).Msg("scan response unknown type")
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *Response) Value() (driver.Value, error) {
 | 
			
		||||
	b, err := json.Marshal(r)
 | 
			
		||||
	return string(b), err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *Response) SetSize(size int64) *Response {
 | 
			
		||||
	r.Size = size
 | 
			
		||||
	return r
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func FromResponse(r *http.Response) *Response {
 | 
			
		||||
	return &Response{
 | 
			
		||||
		Status: r.StatusCode,
 | 
			
		||||
		Header: r.Header,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
const createStatusStat = `CREATE TABLE IF NOT EXISTS styx_stat_status (
 | 
			
		||||
	id     INT PRIMARY KEY,
 | 
			
		||||
	dt     DATETIME DEFAULT CURRENT_TIMESTAMP,
 | 
			
		||||
	status INT NOT NULL
 | 
			
		||||
);`
 | 
			
		||||
 | 
			
		||||
const createDomainStat = `CREATE TABLE IF NOT EXISTS styx_stat_domain (
 | 
			
		||||
	id     INT PRIMARY KEY,
 | 
			
		||||
	dt     DATETIME DEFAULT CURRENT_TIMESTAMP,
 | 
			
		||||
	domain TEXT NOT NULL
 | 
			
		||||
);`
 | 
			
		||||
							
								
								
									
										16
									
								
								proxy/util.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										16
									
								
								proxy/util.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,16 @@
 | 
			
		||||
package proxy
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"io"
 | 
			
		||||
	"net"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// connReader is a net.Conn with a separate reader.
 | 
			
		||||
type connReader struct {
 | 
			
		||||
	net.Conn
 | 
			
		||||
	io.Reader
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c connReader) Read(p []byte) (int, error) {
 | 
			
		||||
	return c.Reader.Read(p)
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										154
									
								
								styx.hcl
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										154
									
								
								styx.hcl
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,154 @@
 | 
			
		||||
 | 
			
		||||
proxy {
 | 
			
		||||
    # TCP listen address
 | 
			
		||||
    listen = ":3128"
 | 
			
		||||
 | 
			
		||||
    # TCP bind address for outgoing connections
 | 
			
		||||
    #bind = "10.42.42.215"
 | 
			
		||||
    # Interface for outgoign connections
 | 
			
		||||
    #interface = "en1"
 | 
			
		||||
 | 
			
		||||
    # Upstream proxies
 | 
			
		||||
    upstream = []   
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    policy {
 | 
			
		||||
        on intercept {
 | 
			
		||||
            domain = ["sensitive"]
 | 
			
		||||
            permit = false
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        on request {
 | 
			
		||||
            source = ["kids"]
 | 
			
		||||
            domain = ["nsfw"]
 | 
			
		||||
            permit = false
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        on request {
 | 
			
		||||
            source = ["kids"]
 | 
			
		||||
            domain = ["nsfw"]
 | 
			
		||||
            permit = false
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        on days {
 | 
			
		||||
            days = "mon-thu,sun"
 | 
			
		||||
            on time {
 | 
			
		||||
                time = ["22:00", "06:00"]
 | 
			
		||||
                on request {
 | 
			
		||||
                    source = ["kids"]
 | 
			
		||||
                    domain = ["social"]
 | 
			
		||||
                    permit = false
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
dns {
 | 
			
		||||
    # Set the cache size
 | 
			
		||||
    #size = 1024
 | 
			
		||||
 | 
			
		||||
    # Set the time to live for positive responses (in seconds)
 | 
			
		||||
    #ttl = 300
 | 
			
		||||
 | 
			
		||||
    # Set the resolve timeout (in seconds)
 | 
			
		||||
    #timeout = 10
 | 
			
		||||
 | 
			
		||||
    # Set the DNS servers
 | 
			
		||||
    #servers = ["1.1.1.1", "8.8.8.8"]
 | 
			
		||||
 | 
			
		||||
    # Disable IPv6
 | 
			
		||||
    noipv6 = true
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
mitm {
 | 
			
		||||
    ca {
 | 
			
		||||
        cert     = "testdata/ca.crt"
 | 
			
		||||
        key      = "testdata/ca.key"
 | 
			
		||||
        key_type = "ecc"
 | 
			
		||||
        days     = 1825
 | 
			
		||||
        organization = "maze.io"
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    key {
 | 
			
		||||
        type = "rsa"
 | 
			
		||||
        bits = 2048
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    cache {
 | 
			
		||||
        #type = "memory" 
 | 
			
		||||
        type = "disk"
 | 
			
		||||
        path = "testdata/mitm"
 | 
			
		||||
        expire = 10
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
cache {
 | 
			
		||||
    type = "memory"
 | 
			
		||||
    size = 10485760
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
match {
 | 
			
		||||
    path = "testdata/match"
 | 
			
		||||
 | 
			
		||||
    network "internal" {
 | 
			
		||||
        type = "list"
 | 
			
		||||
        list = [
 | 
			
		||||
            "0.0.0.0/32",
 | 
			
		||||
            "127.0.0.0/8",
 | 
			
		||||
            "169.254.0.0/16",
 | 
			
		||||
            "fe80::/10",
 | 
			
		||||
        ]
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    network "kids" {
 | 
			
		||||
        type = "list"
 | 
			
		||||
        list = ["10.42.66.0/24"]
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    domain "sensitive" {
 | 
			
		||||
        type = "list"
 | 
			
		||||
        list = [
 | 
			
		||||
            # Banking
 | 
			
		||||
            "abnamro.nl",
 | 
			
		||||
            "knab.nl",
 | 
			
		||||
            "rabobank.nl",
 | 
			
		||||
            
 | 
			
		||||
            # Government
 | 
			
		||||
            "belastingdienst.nl",
 | 
			
		||||
            "digid.nl",
 | 
			
		||||
 | 
			
		||||
            # Messaging
 | 
			
		||||
            "signal.org",
 | 
			
		||||
            "telegram.org",
 | 
			
		||||
            "whatsapp.net",
 | 
			
		||||
            "whatsapp.com",
 | 
			
		||||
        ]
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    domain "social" {
 | 
			
		||||
        type = "list"
 | 
			
		||||
        list = [
 | 
			
		||||
            "pinterest.com",
 | 
			
		||||
            "reddit.com",
 | 
			
		||||
            "x.com",
 | 
			
		||||
            # YouTube
 | 
			
		||||
            "googlevideo.com",
 | 
			
		||||
            "youtube.com",
 | 
			
		||||
            "youtu.be",
 | 
			
		||||
            "ytimg.com",
 | 
			
		||||
        ]
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    domain "nsfw" {
 | 
			
		||||
        type = "domains"
 | 
			
		||||
        from = "https://energized.pro/nsfw/domains.txt"
 | 
			
		||||
        refresh = 43200 # 12h
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    domain "ads" {
 | 
			
		||||
        type = "detect"
 | 
			
		||||
        from = "https://small.oisd.nl/dnsmasq"
 | 
			
		||||
        refresh = 12
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										12
									
								
								testdata/ca.crt
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										12
									
								
								testdata/ca.crt
									
									
									
									
										vendored
									
									
										Normal file
									
								
							@@ -0,0 +1,12 @@
 | 
			
		||||
-----BEGIN CERTIFICATE-----
 | 
			
		||||
MIIBrjCCAVSgAwIBAgIQYEfQcIZJ90sXXLyE1F0gpzAKBggqhkjOPQQDAjA3MRAw
 | 
			
		||||
DgYDVQQKEwdtYXplLmlvMSMwIQYDVQQDExpTdHl4IENlcnRpZmljYXRlIEF1dGhv
 | 
			
		||||
cml0eTAeFw0yNTA5MjQwMDAwMDBaFw0zMDA5MjMwMDAwMDBaMDcxEDAOBgNVBAoT
 | 
			
		||||
B21hemUuaW8xIzAhBgNVBAMTGlN0eXggQ2VydGlmaWNhdGUgQXV0aG9yaXR5MFkw
 | 
			
		||||
EwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEMS3tcysM9OjDLrZNTp2Nw5jqsPcrfaGW
 | 
			
		||||
jBsPACynhhNx8oKYrRjabbbZsqXQiBbEeFw75U+CS82WGS+c7DpttaNCMEAwDgYD
 | 
			
		||||
VR0PAQH/BAQDAgIEMA8GA1UdEwEB/wQFMAMBAf8wHQYDVR0OBBYEFKHEYd+Lckg0
 | 
			
		||||
ywh26MypID6hLse2MAoGCCqGSM49BAMCA0gAMEUCIQCwNrBAa0W9lHIQ9xy0+402
 | 
			
		||||
QH/xlaz1xDDFwMINQ54r0AIgDp7E2jmbwa45zC1DJVXVJuHS+8XGcgP+LdvzhPV2
 | 
			
		||||
J70=
 | 
			
		||||
-----END CERTIFICATE-----
 | 
			
		||||
							
								
								
									
										5
									
								
								testdata/ca.key
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										5
									
								
								testdata/ca.key
									
									
									
									
										vendored
									
									
										Normal file
									
								
							@@ -0,0 +1,5 @@
 | 
			
		||||
-----BEGIN EC PRIVATE KEY-----
 | 
			
		||||
MHcCAQEEIL/DOgsInoOhgVZ24VIf7dfHSyyuj57KQw8vPl1Gs2imoAoGCCqGSM49
 | 
			
		||||
AwEHoUQDQgAEMS3tcysM9OjDLrZNTp2Nw5jqsPcrfaGWjBsPACynhhNx8oKYrRja
 | 
			
		||||
bbbZsqXQiBbEeFw75U+CS82WGS+c7DpttQ==
 | 
			
		||||
-----END EC PRIVATE KEY-----
 | 
			
		||||
		Reference in New Issue
	
	Block a user