Initial import

This commit is contained in:
2025-09-26 08:49:53 +02:00
commit a76650da35
35 changed files with 4660 additions and 0 deletions

145
proxy/admin.go Normal file
View File

@@ -0,0 +1,145 @@
package proxy
import (
"bytes"
"encoding/json"
"encoding/pem"
"errors"
"net/http"
"os"
"strconv"
"strings"
"time"
"git.maze.io/maze/styx/internal/log"
)
type Admin struct {
*Proxy
}
func NewAdmin(proxy *Proxy) *Admin {
a := &Admin{
Proxy: proxy,
}
return a
}
func (a *Admin) handleRequest(ses *Session) error {
var (
logger = ses.log()
err error
)
switch ses.request.URL.Path {
case "/ca.crt":
err = a.handleCACert(ses)
case "/api/v1/policy":
err = a.apiPolicy(ses)
case "/api/v1/policy/matcher":
err = a.apiPolicyMatcher(ses)
case "/api/v1/stats/log":
err = a.apiStatsLog(ses)
case "/api/v1/stats/status":
err = a.apiStatsStatus(ses)
default:
if strings.HasPrefix(ses.request.URL.Path, "/api") {
err = errors.New("invalid endpoint")
} else {
err = os.ErrNotExist
}
}
if err != nil {
logger.Warn().Err(err).Msg("admin error")
ses.response = ErrorResponse(ses.request, err)
defer log.OnCloseError(logger.Debug(), ses.response.Body)
ses.response.Close = true
return a.writeResponse(ses)
}
return err
}
func (a *Admin) handleCACert(ses *Session) error {
b := pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE",
Bytes: a.authority.Certificate().Raw,
})
ses.response = NewResponse(http.StatusOK, bytes.NewReader(b), ses.request)
defer log.OnCloseError(log.Debug(), ses.response.Body)
ses.response.Close = true
ses.response.Header.Set("Content-Type", "application/x-x509-ca-cert")
ses.response.ContentLength = int64(len(b))
return a.writeResponse(ses)
}
func (a *Admin) apiPolicy(ses *Session) error {
var (
b = new(bytes.Buffer)
e = json.NewEncoder(b)
)
e.SetIndent("", " ")
if err := e.Encode(a.config.Policy); err != nil {
return err
}
ses.response = NewJSONResponse(http.StatusOK, b, ses.request)
defer log.OnCloseError(log.Debug(), ses.response.Body)
ses.response.Close = true
return a.writeResponse(ses)
}
func (a *Admin) apiPolicyMatcher(ses *Session) error {
var (
b = new(bytes.Buffer)
e = json.NewEncoder(b)
)
e.SetIndent("", " ")
if err := e.Encode(a.config.Policy.Matchers); err != nil {
return err
}
ses.response = NewJSONResponse(http.StatusOK, b, ses.request)
defer log.OnCloseError(log.Debug(), ses.response.Body)
ses.response.Close = true
return a.writeResponse(ses)
}
func (a *Admin) apiResponse(ses *Session, v any, err error) error {
if err != nil {
return err
}
var (
b = new(bytes.Buffer)
e = json.NewEncoder(b)
)
e.SetIndent("", " ")
if err := e.Encode(v); err != nil {
return err
}
ses.response = NewJSONResponse(http.StatusOK, b, ses.request)
defer log.OnCloseError(log.Debug(), ses.response.Body)
ses.response.Close = true
return a.writeResponse(ses)
}
func (a *Admin) apiStatsLog(ses *Session) error {
var (
query = ses.request.URL.Query()
offset, _ = strconv.Atoi(query.Get("offset"))
limit, _ = strconv.Atoi(query.Get("limit"))
)
if limit > 100 {
limit = 100
}
s, err := a.stats.QueryLog(offset, limit)
return a.apiResponse(ses, s, err)
}
func (a *Admin) apiStatsStatus(ses *Session) error {
s, err := a.stats.QueryStatus(time.Time{})
return a.apiResponse(ses, s, err)
}

8
proxy/cache/config.go vendored Normal file
View File

@@ -0,0 +1,8 @@
package cache
import "github.com/hashicorp/hcl/v2"
type Config struct {
Type string `hcl:"type"`
Body hcl.Body `hcl:",remain"`
}

88
proxy/config.go Normal file
View File

@@ -0,0 +1,88 @@
package proxy
import (
"net"
"net/http"
"time"
"git.maze.io/maze/styx/proxy/policy"
"git.maze.io/maze/styx/proxy/resolver"
)
type ConnectHandler interface {
HandleConnect(session *Session, network, address string) net.Conn
}
// ConnectHandlerFunc is called when the proxy receives a new HTTP CONNECT request.
type ConnectHandlerFunc func(session *Session, network, address string) net.Conn
func (f ConnectHandlerFunc) HandleConnect(session *Session, network, address string) net.Conn {
return f(session, network, address)
}
type RequestHandler interface {
HandleRequest(session *Session) (*http.Request, *http.Response)
}
// RequestHandlerFunc is called when the proxy receives a new request.
type RequestHandlerFunc func(session *Session) (*http.Request, *http.Response)
func (f RequestHandlerFunc) HandleRequest(session *Session) (*http.Request, *http.Response) {
return f(session)
}
type ResponseHandler interface {
HandleResponse(session *Session) *http.Response
}
// ResponseHandler is called when the proxy receives a response.
type ResponseHandlerFunc func(session *Session) *http.Response
func (f ResponseHandlerFunc) HandleResponse(session *Session) *http.Response {
return f(session)
}
type ErrorHandler interface {
HandleError(session *Session, err error)
}
type ErrorHandlerFunc func(session *Session, err error)
func (f ErrorHandlerFunc) HandleError(session *Session, err error) {
f(session, err)
}
type Config struct {
// Listen address.
Listen string `hcl:"listen,optional"`
// Bind address for outgoing connections.
Bind string `hcl:"bind,optional"`
// Interface for outgoing connections.
Interface string `hcl:"interface,optional"`
// Upstream proxy servers.
Upstream []string `hcl:"upstream,optional"`
// DialTimeout for establishing new connections.
DialTimeout time.Duration `hcl:"dial_timeout,optional"`
// Policy for the proxy.
Policy *policy.Policy `hcl:"policy,block"`
// Resolver for the proxy.
Resolver resolver.Resolver
ConnectHandler ConnectHandler
RequestHandler RequestHandler
ResponseHandler ResponseHandler
ErrorHandler ErrorHandler
}
var (
_ ConnectHandler = (ConnectHandlerFunc)(nil)
_ RequestHandler = (RequestHandlerFunc)(nil)
_ ResponseHandler = (ResponseHandlerFunc)(nil)
_ ErrorHandler = (ErrorHandlerFunc)(nil)
)

324
proxy/match/config.go Normal file
View File

@@ -0,0 +1,324 @@
package match
import (
"fmt"
"net"
"net/http"
"os"
"regexp"
"slices"
"strconv"
"strings"
"time"
"git.maze.io/maze/styx/internal/log"
"git.maze.io/maze/styx/internal/netutil"
"github.com/hashicorp/hcl/v2"
"github.com/hashicorp/hcl/v2/gohcl"
)
type Config struct {
Path string `hcl:"path,optional"`
Refresh time.Duration `hcl:"refresh,optional"`
Domain []*Domain `hcl:"domain,block"`
Network []*Network `hcl:"network,block"`
Content []*Content `hcl:"content,block"`
}
func (config Config) Matchers() (Matchers, error) {
all := make(Matchers)
if config.Domain != nil {
all["domain"] = make(map[string]Matcher)
for _, domain := range config.Domain {
m, err := domain.Matcher()
if err != nil {
return nil, fmt.Errorf("matcher domain %q invalid: %w", domain.Name, err)
}
all["domain"][domain.Name] = m
}
}
if config.Network != nil {
all["network"] = make(map[string]Matcher)
for _, network := range config.Network {
m, err := network.Matcher(true)
if err != nil {
return nil, fmt.Errorf("matcher network %q invalid: %w", network.Name, err)
}
all["network"][network.Name] = m
}
}
return all, nil
}
type Content struct {
Name string `hcl:"name,label"`
Type string `hcl:"type"`
Body hcl.Body `hcl:",remain"`
}
type contentHeader struct {
Key string `hcl:"name"`
Value string `hcl:"value,optional"`
List []string `hcl:"list,optional"`
name string
keyRe *regexp.Regexp
valueRe *regexp.Regexp
}
func (m contentHeader) Name() string { return m.name }
func (m contentHeader) MatchesResponse(r *http.Response) bool {
for k, vv := range r.Header {
if m.keyRe.MatchString(k) {
for _, v := range vv {
if slices.Contains(m.List, v) {
return true
}
if m.valueRe != nil && m.valueRe.MatchString(v) {
return true
}
}
}
}
return false
}
type contentType struct {
List []string `hcl:"list"`
name string
}
func (m contentType) Name() string { return m.name }
func (m contentType) MatchesResponse(r *http.Response) bool {
return slices.Contains(m.List, r.Header.Get("Content-Type"))
}
type contentSizeLargerThan struct {
Size int64 `hcl:"size"`
name string
}
func (m contentSizeLargerThan) Name() string { return m.name }
func (m contentSizeLargerThan) MatchesResponse(r *http.Response) bool {
size, err := strconv.ParseInt(r.Header.Get("Content-Length"), 10, 64)
if err != nil {
return false
}
return size >= m.Size
}
type contentStatus struct {
Code []int `hcl:"code"`
name string
}
func (m contentStatus) Name() string { return m.name }
func (m contentStatus) MatchesResponse(r *http.Response) bool {
return slices.Contains(m.Code, r.StatusCode)
}
func (config Content) Matcher() (Response, error) {
switch strings.ToLower(config.Type) {
case "content", "contenttype", "content-type", "type":
var matcher = contentType{name: config.Name}
if err := gohcl.DecodeBody(config.Body, nil, &matcher); err != nil {
return nil, err
}
return matcher, nil
case "header":
var (
matcher = contentHeader{name: config.Name}
err error
)
if err = gohcl.DecodeBody(config.Body, nil, &matcher); err != nil {
return nil, err
}
if matcher.Value == "" && len(matcher.List) == 0 {
return nil, fmt.Errorf("invalid content %q: must contain either list or value", config.Name)
}
if matcher.keyRe, err = regexp.Compile(matcher.Key); err != nil {
return nil, fmt.Errorf("invalid regular expression on content %q key: %w", config.Name, err)
}
if matcher.Value != "" {
if matcher.valueRe, err = regexp.Compile(matcher.Value); err != nil {
return nil, fmt.Errorf("invalid regular expression on content %q value: %w", config.Name, err)
}
}
return matcher, nil
case "size":
var matcher = contentSizeLargerThan{name: config.Name}
if err := gohcl.DecodeBody(config.Body, nil, &matcher); err != nil {
return nil, err
}
return matcher, nil
case "status":
var matcher = contentStatus{name: config.Name}
if err := gohcl.DecodeBody(config.Body, nil, &matcher); err != nil {
return nil, err
}
return matcher, nil
default:
return nil, fmt.Errorf("unknown content matcher type %q", config.Type)
}
}
type Domain struct {
Name string `hcl:"name,label"`
Type string `hcl:"type"`
Body hcl.Body `hcl:",remain"`
}
func (config Domain) Matcher() (Request, error) {
switch config.Type {
case "list":
var matcher = domainList{Title: config.Name}
if err := gohcl.DecodeBody(config.Body, nil, &matcher); err != nil {
return nil, err
}
matcher.list = netutil.NewDomainList(matcher.List...)
return matcher, nil
case "adblock", "dnsmasq", "hosts", "detect", "domains":
var matcher = DomainFile{
Title: config.Name,
Type: config.Type,
}
if err := gohcl.DecodeBody(config.Body, nil, &matcher); err != nil {
return nil, err
}
if matcher.Path == "" && matcher.From == "" {
return nil, fmt.Errorf("matcher: domain %q must have either file or from configured", config.Name)
}
if err := matcher.Update(); err != nil {
return nil, err
}
return matcher, nil
default:
return nil, fmt.Errorf("unknown domain matcher type %q", config.Type)
}
}
type domainList struct {
Title string `json:"title"`
List []string `hcl:"list" json:"list"`
list *netutil.DomainTree
}
func (m domainList) Name() string {
return m.Title
}
func (m domainList) MatchesRequest(r *http.Request) bool {
host := netutil.Host(r.URL.Host)
log.Debug().Str("host", host).Msgf("match domain list (%d domains)", len(m.List))
return m.list.Contains(host)
}
type DomainFile struct {
Title string `json:"name"`
Type string `json:"type"`
Path string `hcl:"path,optional" json:"path,omitempty"`
From string `hcl:"from,optional" json:"from,omitempty"`
Refresh time.Duration `hcl:"refresh,optional" json:"refresh"`
}
func (m DomainFile) Name() string {
return m.Title
}
func (m DomainFile) MatchesRequest(_ *http.Request) bool {
return false
}
func (m *DomainFile) Update() (err error) {
var data []byte
if m.Path != "" {
if data, err = os.ReadFile(m.Path); err != nil {
return
}
} else {
/*
var response *http.Response
if response, err = http.DefaultClient.Get(m.From); err != nil {
return
}
defer func() { _ = response.Body.Close() }()
if response.StatusCode != http.StatusOK {
return fmt.Errorf("match: domain %q update failed: %s", m.name, response.Status)
}
if data, err = io.ReadAll(response.Body); err != nil {
return
}
*/
}
switch m.Type {
case "hosts":
}
_ = data
return nil
}
type Network struct {
Name string `hcl:"name,label"`
Type string `hcl:"type"`
Body hcl.Body `hcl:",remain"`
}
func (config *Network) Matcher(target bool) (Matcher, error) {
switch config.Type {
case "list":
var (
matcher = networkList{Title: config.Name}
err error
)
if diag := gohcl.DecodeBody(config.Body, nil, &matcher); diag.HasErrors() {
return nil, diag
}
if matcher.tree, err = netutil.NewNetworkTree(matcher.List...); err != nil {
return nil, err
}
return &matcher, nil
default:
return nil, fmt.Errorf("unknown network matcher type %q", config.Type)
}
}
type networkList struct {
Title string `json:"name"`
List []string `hcl:"list" json:"list"`
tree *netutil.NetworkTree
target bool
}
func (m *networkList) Name() string {
return m.Title
}
func (m *networkList) MatchesIP(ip net.IP) bool {
return m.tree.Contains(ip)
}
func (m *networkList) MatchesRequest(r *http.Request) bool {
var (
host string
err error
)
if m.target {
host, _, err = net.SplitHostPort(r.URL.Host)
} else {
host, _, err = net.SplitHostPort(r.RemoteAddr)
}
if err != nil {
return false
}
ip := net.ParseIP(host)
return m.MatchesIP(ip)
}

45
proxy/match/match.go Normal file
View File

@@ -0,0 +1,45 @@
package match
import (
"fmt"
"net"
"net/http"
)
type Matchers map[string]map[string]Matcher
func (all Matchers) Get(kind, name string) (m Matcher, err error) {
if typeMatchers, ok := all[kind]; ok {
if m, ok = typeMatchers[name]; ok {
return
}
return nil, fmt.Errorf("no %s matcher named %q found", kind, name)
}
return nil, fmt.Errorf("no %s matcher found", kind)
}
type Matcher interface {
Name() string
}
type Updater interface {
Update() error
}
type IP interface {
Matcher
MatchesIP(net.IP) bool
}
type Request interface {
Matcher
MatchesRequest(*http.Request) bool
}
type Response interface {
Matcher
MatchesResponse(*http.Response) bool
}

11
proxy/match/util.go Normal file
View File

@@ -0,0 +1,11 @@
package match
import "net"
func onlyHost(name string) string {
host, _, err := net.SplitHostPort(name)
if err != nil {
return name
}
return host
}

231
proxy/mitm/authority.go Normal file
View File

@@ -0,0 +1,231 @@
package mitm
import (
"crypto"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"errors"
"fmt"
"math/big"
"os"
"strings"
"time"
"git.maze.io/maze/styx/internal/cryptutil"
"git.maze.io/maze/styx/internal/log"
"github.com/miekg/dns"
)
const DefaultValidity = 24 * time.Hour
type Authority interface {
Certificate() *x509.Certificate
TLSConfig(name string) *tls.Config
}
type authority struct {
pool *x509.CertPool
cert *x509.Certificate
key crypto.PrivateKey
keyID []byte
keyPool chan crypto.PrivateKey
cache Cache
}
func New(config *Config) (Authority, error) {
cache, err := NewCache(config.Cache)
if err != nil {
return nil, err
}
caConfig := config.CA
if caConfig == nil {
caConfig = new(CAConfig)
}
cert, key, err := cryptutil.LoadKeyPair(caConfig.Cert, caConfig.Key)
if os.IsNotExist(err) {
days := caConfig.Days
if days == 0 {
days = DefaultDays
}
if cert, key, err = cryptutil.GenerateKeyPair(caConfig.DN(), days, caConfig.KeyType, caConfig.Bits); err != nil {
return nil, err
}
if strings.ContainsRune(caConfig.Cert, os.PathSeparator) {
if err = cryptutil.SaveKeyPair(cert, key, caConfig.Cert, caConfig.Key); err != nil {
return nil, err
}
}
} else if err != nil {
return nil, err
}
pool := x509.NewCertPool()
pool.AddCert(cert)
keyConfig := config.Key
if keyConfig == nil {
keyConfig = &defaultKeyConfig
}
keyPoolSize := defaultKeyConfig.Pool
if keyConfig.Pool > 0 {
keyPoolSize = keyConfig.Pool
}
keyPool := make(chan crypto.PrivateKey, keyPoolSize)
if key, err := cryptutil.GeneratePrivateKey(keyConfig.Type, keyConfig.Bits); err != nil {
return nil, fmt.Errorf("mitm: invalid key configuration: %w", err)
} else {
keyPool <- key
}
go func(pool chan<- crypto.PrivateKey) {
for {
key, err := cryptutil.GeneratePrivateKey(keyConfig.Type, keyConfig.Bits)
if err != nil {
log.Panic().Err(err).Msg("error generating private key")
}
pool <- key
}
}(keyPool)
return &authority{
pool: pool,
cert: cert,
key: key,
keyID: cryptutil.GenerateKeyID(cryptutil.PublicKey(key)),
keyPool: keyPool,
cache: cache,
}, nil
}
func (ca *authority) log() log.Logger {
return log.Console.With().
Str("ca", ca.cert.Subject.String()).
Logger()
}
func (ca *authority) Certificate() *x509.Certificate {
return ca.cert
}
func (ca *authority) TLSConfig(name string) *tls.Config {
return &tls.Config{
GetCertificate: func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
log := ca.log()
if hello.ServerName != "" {
name = strings.ToLower(hello.ServerName)
log.Debug().Msg("requesting certificate for server name (SNI)")
} else {
log.Debug().Msg("requesting certificate for hostname")
}
if cert, ok := ca.getCached(name); ok {
log.Debug().
Str("subject", cert.Leaf.Subject.String()).
Str("serial", cert.Leaf.SerialNumber.String()).
Time("valid", cert.Leaf.NotAfter).
Msg("using cached certificate")
return cert, nil
}
return ca.issueFor(name)
},
NextProtos: []string{"http/1.1"},
}
}
func (ca *authority) getCached(name string) (cert *tls.Certificate, ok bool) {
log := ca.log()
if cert = ca.cache.Certificate(name); cert == nil {
if baseDomain(name) != name {
cert = ca.cache.Certificate(baseDomain(name))
}
}
if cert != nil {
if _, err := cert.Leaf.Verify(x509.VerifyOptions{
DNSName: name,
Roots: ca.pool,
}); err != nil {
log.Debug().Err(err).Str("name", name).Msg("deleting invalid certificate from cache")
} else {
ok = true
}
}
return
}
func (ca *authority) issueFor(name string) (*tls.Certificate, error) {
var (
log = ca.log().With().Str("name", name).Logger()
key crypto.PrivateKey
)
select {
case key = <-ca.keyPool:
case <-time.After(5 * time.Second):
return nil, errors.New("mitm: timeout waiting for private key generator to catch up")
}
if key == nil {
panic("key pool returned nil key")
}
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
if err != nil {
return nil, fmt.Errorf("mtim: failed to generate serial number: %w", err)
}
if part := dns.SplitDomainName(name); len(part) > 2 {
name = strings.Join(part[1:], ".")
log.Debug().Msgf("abbreviated name to %s (*.%s)", name, name)
}
now := time.Now()
template := &x509.Certificate{
SerialNumber: serialNumber,
Subject: pkix.Name{CommonName: name},
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyAgreement,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
DNSNames: []string{name, "*." + name},
BasicConstraintsValid: true,
NotBefore: now.Add(-DefaultValidity),
NotAfter: now.Add(+DefaultValidity),
}
der, err := x509.CreateCertificate(rand.Reader, template, ca.cert, cryptutil.PublicKey(key), ca.key)
if err != nil {
return nil, err
}
cert, err := x509.ParseCertificate(der)
if err != nil {
return nil, err
}
log.Debug().Str("serial", serialNumber.String()).Msg("generated certificate")
out := &tls.Certificate{
Certificate: [][]byte{der},
Leaf: cert,
PrivateKey: key,
}
//ca.cache[name] = out
ca.cache.SaveCertificate(name, out)
return out, nil
}
func containsValidCertificate(cert *tls.Certificate) bool {
if cert == nil || len(cert.Certificate) == 0 {
return false
}
if cert.Leaf == nil {
var err error
if cert.Leaf, err = x509.ParseCertificate(cert.Certificate[0]); err != nil {
return false
}
}
now := time.Now()
return !(cert.Leaf.NotBefore.Before(now) || cert.Leaf.NotAfter.After(now))
}

233
proxy/mitm/cache.go Normal file
View File

@@ -0,0 +1,233 @@
package mitm
import (
"crypto/tls"
"fmt"
"io/fs"
"os"
"path/filepath"
"slices"
"strings"
"time"
"github.com/hashicorp/golang-lru/v2/expirable"
"github.com/hashicorp/hcl/v2/gohcl"
"github.com/miekg/dns"
"git.maze.io/maze/styx/internal/cryptutil"
"git.maze.io/maze/styx/internal/log"
)
type Cache interface {
Certificate(name string) *tls.Certificate
SaveCertificate(name string, cert *tls.Certificate) error
RemoveCertificate(name string)
}
func NewCache(config *CacheConfig) (Cache, error) {
if config == nil {
return NewCache(&CacheConfig{Type: "memory"})
}
switch config.Type {
case "memory":
var cacheConfig = new(MemoryCacheConfig)
if err := gohcl.DecodeBody(config.Body, nil, cacheConfig); err != nil {
return nil, err
}
return NewMemoryCache(cacheConfig.Size), nil
case "disk":
var cacheConfig = new(DiskCacheConfig)
if err := gohcl.DecodeBody(config.Body, nil, cacheConfig); err != nil {
return nil, err
}
return NewDiskCache(cacheConfig.Path, time.Duration(cacheConfig.Expire*float64(time.Second)))
default:
return nil, fmt.Errorf("mitm: cache type %q is not supported", config.Type)
}
}
type memoryCache struct {
cache *expirable.LRU[string, *tls.Certificate]
}
func NewMemoryCache(size int) Cache {
return memoryCache{
cache: expirable.NewLRU(size, func(key string, value *tls.Certificate) {
log.Debug().Str("name", key).Msg("certificate evicted from cache")
}, time.Hour*24),
}
}
func (c memoryCache) Certificate(name string) (cert *tls.Certificate) {
var ok bool
if cert, ok = c.cache.Get(name); !ok {
cert, _ = c.cache.Get(baseDomain(name))
}
return
}
func (c memoryCache) SaveCertificate(name string, cert *tls.Certificate) error {
c.cache.Add(name, cert)
log.Debug().Str("name", name).Msg("certificate added to cache")
return nil
}
func (c memoryCache) RemoveCertificate(name string) {
c.cache.Remove(name)
}
type diskCache string
func NewDiskCache(dir string, expire time.Duration) (Cache, error) {
if !filepath.IsAbs(dir) {
var err error
if dir, err = filepath.Abs(dir); err != nil {
return nil, err
}
}
if err := os.MkdirAll(dir, 0o750); err != nil {
return nil, err
}
info, err := os.Stat(dir)
if err != nil {
return nil, err
}
if info.Mode()&os.ModePerm|0o057 != 0 {
if err := os.Chmod(dir, 0o750); err != nil {
return nil, err
}
}
if expire > 0 {
go expireDiskCache(dir, expire)
}
return diskCache(dir), nil
}
func expireDiskCache(root string, expire time.Duration) {
log.Debug().Str("path", root).Dur("expire", expire).Msg("disk cache expire loop starting")
ticker := time.NewTicker(expire)
defer ticker.Stop()
for {
now := <-ticker.C
log.Debug().Str("path", root).Dur("expire", expire).Msg("expire disk cache")
filepath.WalkDir(root, func(path string, d fs.DirEntry, err error) error {
if err != nil {
return err
}
if d.IsDir() {
// Remove the directory; this will fail if it's not empty, which is fine.
_ = os.Remove(path)
return nil
}
cert, err := cryptutil.LoadCertificate(path)
if err != nil {
log.Debug().Str("path", path).Err(err).Msg("expire removing invalid certificate file")
_ = os.Remove(path)
return nil
} else if cert.NotAfter.Before(now) {
log.Debug().Str("path", path).Dur("expired", now.Sub(cert.NotAfter)).Msg("expire removing expired certificate")
_ = os.Remove(path)
return nil
}
return nil
})
}
}
func (c diskCache) path(name string) string {
part := dns.SplitDomainName(strings.ToLower(name))
// x,com -> com,x
// www,maze,io -> io,maze,www
slices.Reverse(part)
// com,x -> com,x,x.com
// io,maze,www -> io,m,ma,maze,www.maze.io
if len(part) > 2 {
if len(part[1]) > 1 {
part = []string{
part[0],
part[1][:1],
part[1][:2],
part[1],
name,
}
} else {
part = []string{
part[0],
part[1][:1],
part[1],
name,
}
}
} else if len(part) > 1 {
if len(part[1]) > 1 {
part = []string{
part[0],
part[1][:1],
part[1][:2],
name,
}
} else {
part = []string{
part[0],
part[1][:1],
name,
}
}
}
part[len(part)-1] += ".crt"
return filepath.Join(append([]string{string(c)}, part...)...)
}
func (c diskCache) Certificate(name string) (cert *tls.Certificate) {
if cert, key, err := cryptutil.LoadKeyPair(c.path(name), ""); err == nil {
return &tls.Certificate{
Certificate: [][]byte{cert.Raw},
Leaf: cert,
PrivateKey: key,
}
}
if cert, key, err := cryptutil.LoadKeyPair(c.path(baseDomain(name)), ""); err == nil {
return &tls.Certificate{
Certificate: [][]byte{cert.Raw},
Leaf: cert,
PrivateKey: key,
}
}
log.Debug().Str("path", string(c)).Str("name", name).Msg("cache miss")
return nil
}
func (c diskCache) SaveCertificate(name string, cert *tls.Certificate) error {
dir, name := filepath.Split(c.path(name))
if err := os.MkdirAll(dir, 0o750); err != nil {
return err
}
if err := cryptutil.SaveKeyPair(cert.Leaf, cert.PrivateKey, filepath.Join(dir, name), ""); err != nil {
return err
}
log.Debug().Str("name", name).Msg("certificate added to cache")
return nil
}
func (c diskCache) RemoveCertificate(name string) {
path := c.path(name)
if err := os.Remove(path); err != nil {
if os.IsNotExist(err) {
return
}
log.Error().Err(err).Msg("certificate remove from cache failed")
}
_ = os.Remove(filepath.Dir(path))
log.Debug().Str("name", name).Msg("certificate removed from cache")
}
func baseDomain(name string) string {
name = strings.ToLower(name)
if part := dns.SplitDomainName(name); len(part) > 2 {
return strings.Join(part[1:], ".")
}
return name
}

25
proxy/mitm/cache_test.go Normal file
View File

@@ -0,0 +1,25 @@
package mitm
import "testing"
func TestDiskCachePath(t *testing.T) {
cache := diskCache("testdata")
tests := []struct {
test string
want string
}{
{"x.com", "testdata/com/x/x.com.crt"},
{"feed.x.com", "testdata/com/x/x/feed.x.com.crt"},
{"nu.nl", "testdata/nl/n/nu/nu.nl.crt"},
{"maze.io", "testdata/io/m/ma/maze.io.crt"},
{"lab.maze.io", "testdata/io/m/ma/maze/lab.maze.io.crt"},
{"dev.lab.maze.io", "testdata/io/m/ma/maze/dev.lab.maze.io.crt"},
}
for _, test := range tests {
t.Run(test.test, func(it *testing.T) {
if v := cache.path(test.test); v != test.want {
it.Errorf("expected %q to resolve to %q, got %q", test.test, test.want, v)
}
})
}
}

89
proxy/mitm/config.go Normal file
View File

@@ -0,0 +1,89 @@
package mitm
import (
"crypto/x509/pkix"
"github.com/hashicorp/hcl/v2"
)
const (
DefaultCommonName = "Styx Certificate Authority"
DefaultDays = 3
)
type Config struct {
CA *CAConfig `hcl:"ca,block"`
Key *KeyConfig `hcl:"key,block"`
Cache *CacheConfig `hcl:"cache,block"`
}
type CAConfig struct {
Cert string `hcl:"cert"`
Key string `hcl:"key,optional"`
Days int `hcl:"days,optional"`
KeyType string `hcl:"key_type,optional"`
Bits int `hcl:"bits,optional"`
Name string `hcl:"name,optional"`
Country string `hcl:"country,optional"`
Organization string `hcl:"organization,optional"`
Unit string `hcl:"unit,optional"`
Locality string `hcl:"locality,optional"`
Province string `hcl:"province,optional"`
Address []string `hcl:"address,optional"`
PostalCode string `hcl:"postal_code,optional"`
}
func (config CAConfig) DN() pkix.Name {
var name = pkix.Name{
CommonName: config.Name,
StreetAddress: config.Address,
}
if config.Name == "" {
name.CommonName = DefaultCommonName
}
if config.Country != "" {
name.Country = append(name.Country, config.Country)
}
if config.Organization != "" {
name.Organization = append(name.Organization, config.Organization)
}
if config.Unit != "" {
name.OrganizationalUnit = append(name.OrganizationalUnit, config.Unit)
}
if config.Locality != "" {
name.Locality = append(name.Locality, config.Locality)
}
if config.Province != "" {
name.Province = append(name.Province, config.Province)
}
if config.PostalCode != "" {
name.PostalCode = append(name.PostalCode, config.PostalCode)
}
return name
}
type KeyConfig struct {
Type string `hcl:"type,optional"`
Bits int `hcl:"bits,optional"`
Pool int `hcl:"pool,optional"`
}
var defaultKeyConfig = KeyConfig{
Type: "rsa",
Bits: 2048,
Pool: 5,
}
type CacheConfig struct {
Type string `hcl:"type"`
Body hcl.Body `hcl:",remain"`
}
type MemoryCacheConfig struct {
Size int `hcl:"size,optional"`
}
type DiskCacheConfig struct {
Path string `hcl:"path"`
Expire float64 `hcl:"expire,optional"`
}

53
proxy/policy/policy.go Normal file
View File

@@ -0,0 +1,53 @@
package policy
import (
"net/http"
"git.maze.io/maze/styx/proxy/match"
)
// Policy contains rules that make up the policy.
//
// Some policy rules contain nested policies.
type Policy struct {
Rules []*rawRule `hcl:"on,block" json:"rules"`
Permit *bool `hcl:"permit" json:"permit"`
Matchers match.Matchers `json:"matchers"` // Matchers for the policy
}
func (p *Policy) Configure(matchers match.Matchers) (err error) {
for _, r := range p.Rules {
if err = r.Configure(matchers); err != nil {
return
}
}
p.Matchers = matchers
return
}
func (p *Policy) PermitIntercept(r *http.Request) *bool {
if p != nil {
for _, rule := range p.Rules {
if rule, ok := rule.Rule.(InterceptRule); ok {
if permit := rule.PermitIntercept(r); permit != nil {
return permit
}
}
}
}
return p.Permit
}
func (p *Policy) PermitRequest(r *http.Request) *bool {
if p != nil {
for _, rule := range p.Rules {
if rule, ok := rule.Rule.(RequestRule); ok {
if permit := rule.PermitRequest(r); permit != nil {
return permit
}
}
}
}
return p.Permit
}

139
proxy/policy/policy_test.go Normal file
View File

@@ -0,0 +1,139 @@
package policy
import (
"net"
"net/http"
"net/url"
"testing"
"git.maze.io/maze/styx/internal/netutil"
"git.maze.io/maze/styx/proxy/match"
"github.com/miekg/dns"
)
type testInDomainList struct {
t *testing.T
list []string
}
func (testInDomainList) Name() string { return "testInDomainList" }
func (l testInDomainList) MatchesRequest(r *http.Request) bool {
for _, domain := range l.list {
if dns.IsSubDomain(domain, netutil.Host(r.URL.Host)) {
l.t.Logf("domain %s contains %s", domain, r.URL.Host)
return true
}
l.t.Logf("domain %s does not contain %s", domain, r.URL.Host)
}
return false
}
func testInDomain(t *testing.T, domains ...string) match.Matcher {
return &testInDomainList{t: t, list: domains}
}
type testInNetworkList struct {
t *testing.T
list []*net.IPNet
}
func (testInNetworkList) Name() string { return "testInNetworkList" }
func (l testInNetworkList) MatchesIP(ip net.IP) bool {
for _, ipnet := range l.list {
if ipnet.Contains(ip) {
l.t.Logf("network %s contains %s", ipnet, ip)
return true
}
l.t.Logf("network %s does not contain %s", ipnet, ip)
}
return false
}
func testInNetwork(t *testing.T, cidr string) match.Matcher {
t.Helper()
_, ipnet, err := net.ParseCIDR(cidr)
if err != nil {
panic(err)
}
return testInNetworkList{t: t, list: []*net.IPNet{ipnet}}
}
func TestPolicy(t *testing.T) {
var (
yes = true
nope = false
)
p := &Policy{
Rules: []*rawRule{
{
Rule: &requestRule{
domainOrNetworkRule: domainOrNetworkRule{
matchers: []match.Matcher{testInNetwork(t, "127.0.0.0/8")},
isSource: []bool{true},
},
},
},
{
Rule: &requestRule{
domainOrNetworkRule: domainOrNetworkRule{
matchers: []match.Matcher{testInNetwork(t, "127.0.0.0/8")},
isSource: []bool{false},
},
Permit: &yes,
},
},
{
Rule: &requestRule{
domainOrNetworkRule: domainOrNetworkRule{
matchers: []match.Matcher{testInDomain(t, "maze.io", "maze.engineering")},
},
Permit: &yes,
},
},
{
Rule: &requestRule{
domainOrNetworkRule: domainOrNetworkRule{
matchers: []match.Matcher{testInDomain(t, "google.com")},
},
Permit: &nope,
},
},
},
}
r := &http.Request{
URL: &url.URL{Scheme: "http", Host: "golang.org:80"},
RemoteAddr: "127.0.0.1:1234",
}
if v := p.PermitRequest(r); v != nil {
t.Errorf("expected request to return no verdict, got %t", *v)
}
p.Rules[0].Rule.(*requestRule).Permit = &yes
if v := p.PermitRequest(r); v == nil || *v != yes {
t.Errorf("expected request to return %t, %v", yes, v)
}
r.RemoteAddr = "192.168.1.2:3456"
if v := p.PermitRequest(r); v != nil {
t.Errorf("expected request to return no verdict, got %t", *v)
}
if v := p.PermitIntercept(r); v != nil {
t.Errorf("expected request to return no verdict, got %t", *v)
}
r.URL.Host = "maze.io"
if v := p.PermitRequest(r); v == nil || *v != yes {
t.Errorf("expected request to return %t, %v", yes, v)
}
r.URL.Host = "google.com"
if v := p.PermitRequest(r); v == nil || *v != nope {
t.Errorf("expected request to return %t, %v", nope, v)
}
r.URL.Host = "localhost:80"
if v := p.PermitRequest(r); v == nil || *v != yes {
t.Errorf("expected request to return %t, %v", yes, v)
}
}

368
proxy/policy/rule.go Normal file
View File

@@ -0,0 +1,368 @@
package policy
import (
"fmt"
"net"
"net/http"
"strings"
"time"
"git.maze.io/maze/styx/internal/netutil"
"git.maze.io/maze/styx/proxy/match"
"github.com/google/uuid"
"github.com/hashicorp/hcl/v2"
"github.com/hashicorp/hcl/v2/gohcl"
)
// Rule is a policy rule.
type Rule interface {
Configure(match.Matchers) error
}
// InterceptRule can make policy rule decisions on intercept requests.
type InterceptRule interface {
PermitIntercept(r *http.Request) *bool
}
// RequestRule can make policy rule decisions on HTTP CONNECT requests.
type RequestRule interface {
PermitRequest(r *http.Request) *bool
}
type rawRule struct {
Type string `hcl:"type,label" json:"type"`
Body hcl.Body `hcl:",remain" json:"-"`
Rule `json:"rule"`
}
func (r *rawRule) Configure(matchers match.Matchers) (err error) {
switch r.Type {
case "intercept":
r.Rule = new(interceptRule)
case "request":
r.Rule = new(requestRule)
case "days":
r.Rule = new(daysRule)
case "time":
r.Rule = new(timeRule)
case "all":
r.Rule = new(allRule)
default:
return fmt.Errorf("policy: invalid event type %q", r.Type)
}
if diag := gohcl.DecodeBody(r.Body, nil, r.Rule); diag.HasErrors() {
return err
}
return r.Rule.Configure(matchers)
}
type allRule struct {
Rules []*rawRule `hcl:"on,block"`
Permit *bool `hcl:"permit"`
}
func (r *allRule) Configure(matchers match.Matchers) (err error) {
return
}
type domainOrNetworkRule struct {
matchers []match.Matcher
isSource []bool
}
func (r *domainOrNetworkRule) configure(kind string, matchers match.Matchers, domains, sources, targets []string, v any, id *string) (err error) {
var m match.Matcher
for _, domain := range domains {
if m, err = matchers.Get("domain", domain); err != nil {
return fmt.Errorf("%s: unknown domain %q", kind, domain)
}
r.matchers = append(r.matchers, m)
r.isSource = append(r.isSource, false)
}
for _, network := range sources {
if m, err = matchers.Get("network", network); err != nil {
return fmt.Errorf("%s: unknown source network %q", kind, network)
}
r.matchers = append(r.matchers, m)
r.isSource = append(r.isSource, true)
}
for _, network := range targets {
if m, err = matchers.Get("network", network); err != nil {
return fmt.Errorf("%s: unknown target network %q", kind, network)
}
r.matchers = append(r.matchers, m)
r.isSource = append(r.isSource, false)
}
if len(r.matchers) == 0 {
return fmt.Errorf("%s: missing any of domain, source, target", kind)
}
if id != nil {
*id = uuid.NewString()
}
return
}
func (r *domainOrNetworkRule) matchesRequest(q *http.Request) bool {
for i, m := range r.matchers {
if m, ok := m.(match.Request); ok {
if m.MatchesRequest(q) {
return true
}
}
if m, ok := m.(match.IP); ok {
if r.isSource[i] {
if m.MatchesIP(net.ParseIP(netutil.Host(q.RemoteAddr))) {
return true
}
} else {
var (
host = netutil.Host(q.URL.Host)
ips []net.IP
)
if ip := net.ParseIP(host); ip != nil {
ips = append(ips, ip)
} else {
ips, _ = net.LookupIP(host)
}
for _, ip := range ips {
if m.MatchesIP(ip) {
return true
}
}
}
}
}
return false
}
type interceptRule struct {
ID string `json:"id,omitempty"`
Domain []string `hcl:"domain,optional" json:"domain,omitempty"`
Source []string `hcl:"source,optional" json:"source,omitempty"`
Target []string `hcl:"target,optional" json:"target,omitempty"`
Permit *bool `hcl:"permit" json:"permit"`
domainOrNetworkRule `json:"-"`
}
func (r *interceptRule) Configure(matchers match.Matchers) (err error) {
return r.configure("intercept", matchers, r.Domain, r.Source, r.Target, r, &r.ID)
}
func (r *interceptRule) PermitIntercept(q *http.Request) *bool {
if r.matchesRequest(q) {
return r.Permit
}
return nil
}
type requestRule struct {
ID string `json:"id,omitempty"`
Domain []string `hcl:"domain,optional" json:"domain,omitempty"`
Source []string `hcl:"source,optional" json:"source,omitempty"`
Target []string `hcl:"target,optional" json:"target,omitempty"`
Permit *bool `hcl:"permit" json:"permit"`
domainOrNetworkRule `json:"-"`
}
func (r *requestRule) Configure(matchers match.Matchers) (err error) {
return r.configure("request", matchers, r.Domain, r.Source, r.Target, r, &r.ID)
}
func (r *requestRule) PermitRequest(q *http.Request) *bool {
if r.matchesRequest(q) {
return r.Permit
}
return nil
}
type timeRule struct {
ID string `json:"id,omitempty"`
Time []string `hcl:"time" json:"time"`
Permit *bool `hcl:"permit" json:"permit"`
Body hcl.Body `hcl:",remain" json:"-"`
Rules *Policy `json:"rules"`
Start Time `json:"start"`
End Time `json:"end"`
}
func (r *timeRule) isActive() bool {
if r == nil {
return true
}
now := Now()
if r.Start.After(r.End) { // ie: 18:00-06:00
return now.After(r.Start) || now.Before(r.End)
}
return now.After(r.Start) && now.Before(r.End)
}
func (r *timeRule) Configure(matchers match.Matchers) (err error) {
if len(r.Time) != 2 {
return fmt.Errorf("invalid time %s, need [start, stop]", r.Time)
}
if r.Start, err = ParseTime(r.Time[0]); err != nil {
return fmt.Errorf("invalid start %q: %w", r.Time[0], err)
}
if r.End, err = ParseTime(r.Time[1]); err != nil {
return fmt.Errorf("invalid end %q: %w", r.Time[1], err)
}
r.Rules = new(Policy)
if diag := gohcl.DecodeBody(r.Body, nil, r.Rules); diag.HasErrors() {
return diag
}
if err = r.Rules.Configure(matchers); err != nil {
return
}
r.Rules.Matchers = nil
if r.ID == "" {
r.ID = uuid.NewString()
}
return
}
func (r *timeRule) PermitIntercept(q *http.Request) *bool {
if !r.isActive() {
return nil
}
return r.Rules.PermitIntercept(q)
}
func (r *timeRule) PermitRequest(q *http.Request) *bool {
if !r.isActive() {
return nil
}
return r.Rules.PermitRequest(q)
}
type daysRule struct {
ID string `json:"id,omitempty"`
Days string `hcl:"days" json:"days"`
Permit *bool `hcl:"permit" json:"permit"`
Body hcl.Body `hcl:",remain" json:"-"`
Rules *Policy `json:"rules"`
cond []onCond
}
func (r *daysRule) isActive() bool {
if r == nil || len(r.cond) == 0 {
return true
}
now := time.Now()
for _, cond := range r.cond {
if cond(now) {
return true
}
}
return false
}
func (r *daysRule) Configure(matchers match.Matchers) (err error) {
if r.cond, err = parseOnCond(r.Days); err != nil {
return
}
r.Rules = new(Policy)
if diag := gohcl.DecodeBody(r.Body, nil, r.Rules); diag.HasErrors() {
return diag
}
if err = r.Rules.Configure(matchers); err != nil {
return
}
r.Rules.Matchers = nil
if r.ID == "" {
r.ID = uuid.NewString()
}
return
}
func (r *daysRule) PermitIntercept(q *http.Request) *bool {
if !r.isActive() {
return nil
}
return r.Rules.PermitIntercept(q)
}
func (r *daysRule) PermitRequest(q *http.Request) *bool {
if !r.isActive() {
return nil
}
return r.Rules.PermitRequest(q)
}
type onCond func(time.Time) bool
var weekdays = map[string]time.Weekday{
"sun": time.Sunday,
"mon": time.Monday,
"tue": time.Tuesday,
"wed": time.Wednesday,
"thu": time.Thursday,
"fri": time.Friday,
"sat": time.Saturday,
}
func parseOnCond(when string) (conds []onCond, err error) {
for _, spec := range strings.Split(when, ",") {
spec = strings.ToLower(strings.TrimSpace(spec))
if d, ok := weekdays[spec]; ok {
conds = append(conds, onWeekday(d))
} else if spec == "weekend" || spec == "weekends" {
conds = append(conds, onWeekend)
} else if spec == "workday" || spec == "workdays" {
conds = append(conds, onWorkday)
} else if strings.ContainsRune(spec, '-') {
var (
part = strings.SplitN(spec, "-", 2)
from, upto time.Weekday
ok bool
)
if from, ok = weekdays[part[0]]; !ok {
return nil, fmt.Errorf("on %q: invalid weekday %q", spec, part[0])
}
if upto, ok = weekdays[part[1]]; !ok {
return nil, fmt.Errorf("on %q: invalid weekday %q", spec, part[1])
}
if from < upto {
for d := from; d < upto; d++ {
conds = append(conds, onWeekday(d))
}
} else {
for d := time.Sunday; d < from; d++ {
conds = append(conds, onWeekday(d))
}
for d := upto; d <= time.Saturday; d++ {
conds = append(conds, onWeekday(d))
}
}
} else {
return nil, fmt.Errorf("on %q: invalid condition", spec)
}
}
return
}
func onWeekday(weekday time.Weekday) onCond {
return func(t time.Time) bool {
return t.Weekday() == weekday
}
}
func onWeekend(t time.Time) bool {
d := t.Weekday()
return d == time.Saturday || d == time.Sunday
}
func onWorkday(t time.Time) bool {
d := t.Weekday()
return !(d == time.Saturday || d == time.Sunday)
}

53
proxy/policy/time.go Normal file
View File

@@ -0,0 +1,53 @@
package policy
import (
"fmt"
"time"
)
type Time struct {
Hour int
Minute int
Second int
}
func (t Time) Eq(other Time) bool {
return t.Hour == other.Hour && t.Minute == other.Minute && t.Second == other.Second
}
func (t Time) After(other Time) bool {
return t.Seconds() > other.Seconds()
}
func (t Time) Before(other Time) bool {
return t.Seconds() < other.Seconds()
}
func (t Time) Seconds() int {
return t.Hour*3600 + t.Minute*60 + t.Second
}
func (t Time) MarshalJSON() ([]byte, error) {
return []byte(fmt.Sprintf(`"%02d:%02d:%02d"`, t.Hour, t.Minute, t.Second)), nil
}
var timeFormats = []string{
time.TimeOnly,
"15:04",
time.Kitchen,
}
func Now() Time {
now := time.Now()
return Time{now.Hour(), now.Minute(), now.Second()}
}
func ParseTime(s string) (t Time, err error) {
var tt time.Time
for _, layout := range timeFormats {
if tt, err = time.Parse(layout, s); err == nil {
return Time{tt.Hour(), tt.Minute(), tt.Second()}, nil
}
}
return Time{}, fmt.Errorf("time: invalid time %q", s)
}

616
proxy/proxy.go Normal file
View File

@@ -0,0 +1,616 @@
package proxy
import (
"bufio"
"bytes"
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"net/http"
"strings"
"sync"
"syscall"
"time"
"git.maze.io/maze/styx/internal/log"
"git.maze.io/maze/styx/internal/netutil"
"git.maze.io/maze/styx/proxy/mitm"
"git.maze.io/maze/styx/proxy/policy"
"git.maze.io/maze/styx/proxy/resolver"
"git.maze.io/maze/styx/proxy/stats"
)
const (
DefaultListenAddr = ":3128"
DefaultBindAddr = ""
DefaultDialTimeout = 30 * time.Second
DefaultKeepAlivePeriod = 1 * time.Minute
)
const (
HeaderAcceptEncoding = "Accept-Encoding"
HeaderConnection = "Connection"
HeaderContentLength = "Content-Length"
HeaderContentType = "Content-Type"
HeaderUpgrade = "Upgrade"
)
var (
ErrClosed = errors.New("proxy: shutdown")
ErrClientCert = errors.New("tls: client certificate requested")
)
type Proxy struct {
addr *net.TCPAddr
bind *net.TCPAddr
resolver resolver.Resolver
transport *http.Transport
dial func(network, address string) (net.Conn, error)
config *Config
authority mitm.Authority
policy *policy.Policy
admin *Admin
stats *stats.Stats
closed chan struct{}
onConnect ConnectHandler
onRequest RequestHandler
onResponse ResponseHandler
onError ErrorHandler
}
func New(config *Config, ca mitm.Authority) (*Proxy, error) {
if config == nil {
return nil, errors.New("proxy: config can't be nil")
}
p := &Proxy{
transport: newTransport(),
config: config,
resolver: resolver.Default,
authority: ca,
policy: config.Policy,
closed: make(chan struct{}),
onConnect: config.ConnectHandler,
onRequest: config.RequestHandler,
onResponse: config.ResponseHandler,
onError: config.ErrorHandler,
}
var err error
if config.Listen == "" {
p.addr, err = net.ResolveTCPAddr("tcp", DefaultBindAddr)
} else {
p.addr, err = net.ResolveTCPAddr("tcp", config.Listen)
}
if err != nil {
return nil, fmt.Errorf("proxy: invalid listen addres: %w", err)
}
if config.Bind != "" {
if p.bind, err = net.ResolveTCPAddr("tcp", config.Bind+":0"); err != nil {
return nil, fmt.Errorf("proxy: invalid bind address: %w", err)
}
} else if config.Interface != "" {
if err = resolveInterfaceAddr(config.Interface); err != nil {
return nil, err
}
}
if p.bind != nil {
/* FIXME
var c *net.TCPConn
if c, err = net.DialTCP("tcp", p.bind, p.bind); err != nil && errors.Is(err, syscall.EADDRNOTAVAIL) {
return nil, fmt.Errorf("proxy: invalid bind address: %w", syscall.EADDRNOTAVAIL)
} else if c != nil {
_ = c.Close()
}
*/
}
if config.Resolver != nil {
p.resolver = config.Resolver
}
dialTimeout := DefaultDialTimeout
if config.DialTimeout > 0 {
dialTimeout = config.DialTimeout
}
p.dial = (&net.Dialer{
Timeout: dialTimeout,
KeepAlive: dialTimeout,
LocalAddr: p.bind,
}).Dial
p.admin = NewAdmin(p)
if p.stats, err = stats.New(); err != nil {
return nil, err
}
return p, nil
}
func newTransport() *http.Transport {
return &http.Transport{
TLSNextProto: make(map[string]func(authority string, c *tls.Conn) http.RoundTripper),
Proxy: http.ProxyFromEnvironment,
TLSHandshakeTimeout: 15 * time.Second,
ExpectContinueTimeout: 5 * time.Second,
}
}
func (p *Proxy) Close() error {
select {
case <-p.closed:
return ErrClosed
default:
close(p.closed)
return nil
}
}
func (p *Proxy) Start() error {
l, err := net.ListenTCP("tcp", p.addr)
if err != nil {
return err
}
go p.Serve(l)
return nil
}
func (p *Proxy) Serve(listener net.Listener) error {
defer func() { _ = listener.Close() }()
log.Info().Str("addr", listener.Addr().String()).Msg("proxy server listening")
for {
select {
case <-p.closed:
return nil
default:
}
c, err := listener.Accept()
if err != nil {
return err
}
rw := bufio.NewReadWriter(bufio.NewReader(c), bufio.NewWriter(c))
ctx := newContext(c, rw, nil)
if c, ok := c.(*net.TCPConn); ok {
_ = c.SetKeepAlive(true)
_ = c.SetKeepAlivePeriod(DefaultKeepAlivePeriod)
}
go p.handle(ctx)
}
}
func (p *Proxy) handle(ctx *Context) {
logger := ctx.log()
defer log.OnCloseError(logger.Debug(), ctx.conn)
logger.Info().Str("client", ctx.RemoteAddr().String()).Msg("new client connection")
last := int64(0)
for {
select {
case <-p.closed:
return
default:
ses, err := p.handleRequest(ctx)
if ses != nil {
log := ses.log()
log.Info().
Str("method", ses.request.Method).
Str("url", ses.request.URL.String()).
Str("status", ses.response.Status).
Int64("size", ctx.conn.bytes-last).
Msg("handled request")
p.stats.AddLog(&stats.Log{
ClientIP: netutil.Host(ses.request.RemoteAddr),
Request: stats.FromRequest(ses.request),
Response: stats.FromResponse(ses.response).SetSize(ctx.conn.bytes - last),
})
last = ctx.conn.bytes
}
if err != nil && !isClosing(err) || (ses != nil && ses.response != nil && ses.response.Close) {
event := logger.Debug()
if ctx.conn.bytes > 0 {
event = event.Int64("size", ctx.conn.bytes)
}
event.Msg("closing client connection")
return
}
}
}
}
func (p *Proxy) handleRequest(ctx *Context) (ses *Session, err error) {
logger := ctx.log()
var request *http.Request
if request, err = p.readRequest(ctx); err != nil {
return
}
ses = newSession(ctx, request)
p.cleanRequest(ses, request)
logger.Debug().Str("method", request.Method).Str("url", request.URL.String()).Msg("handle request")
if p.onRequest != nil {
newRequest, newResponse := p.onRequest.HandleRequest(ses)
if newRequest != nil {
logger.Debug().Str("method", newRequest.Method).Str("url", newRequest.URL.String()).Msg("request override")
ses.request = newRequest
}
if newResponse != nil {
logger.Debug().Str("status", newResponse.Status).Msg("response override")
ses.response = newResponse
}
}
if ses.response == nil {
// WebSocket request
if ses.request.Header.Get(HeaderUpgrade) == "websocket" {
return ses, p.handleTunnel(ses)
}
cleanHopByHopHeaders(ses.request.Header)
// Proxy CONNECT request
if ses.request.Method == http.MethodConnect {
return p.handleConnect(ses)
}
if netutil.Port(ses.request.URL.Host) == p.addr.Port {
// Plain API request
ses.request.URL.Host = ses.request.Host
return ses, p.admin.handleRequest(ses)
} else if ses.response, err = p.transport.RoundTrip(ses.request); err != nil {
// Plain HTTP request
if p.config.ErrorHandler != nil {
p.config.ErrorHandler.HandleError(ses, err)
}
ses.response = ErrorResponse(ses.request, err)
}
logger.Debug().Str("status", ses.response.Status).Msg("received response")
cleanHopByHopHeaders(ses.response.Header)
}
ses.response.Close = true
defer log.OnCloseError(logger.Debug(), ses.response.Body)
return ses, p.writeResponse(ses)
}
func (p *Proxy) handleConnect(ses *Session) (next *Session, err error) {
next = ses
logger := ses.log()
logger.Debug().Msgf("connecting to %s", ses.request.URL.Host)
var c net.Conn
if c, err = p.connect(ses, "tcp", ses.request.URL.Host); err != nil {
logger.Error().Err(err).Msg("connect failed")
if p.onError != nil {
p.onError.HandleError(ses, err)
}
ses.response = ErrorResponse(ses.request, err)
defer log.OnCloseError(logger.Debug(), ses.response.Body)
_ = p.writeResponse(ses)
return
}
defer func() {
if err := c.Close(); err != nil {
if p.onError != nil {
p.onError.HandleError(ses, err)
}
}
}()
if p.canIntercept(ses.request) {
logger.Debug().Msg("intercepting connection")
ses.response = NewResponse(http.StatusOK, nil, ses.request)
err = p.writeResponse(ses)
log.OnCloseError(logger.Debug(), ses.response.Body)
if err != nil {
return
}
// Peek first byte
b := make([]byte, 1)
if _, err = io.ReadFull(ses.ctx.rw, b); err != nil {
logger.Error().Err(err).Msg("error peeking CONNECT byte")
return
}
// Drain buffered bytes
b = append(b, make([]byte, ses.ctx.rw.Reader.Buffered())...)
ses.ctx.rw.Reader.Read(b[1:])
r := &connReader{
Conn: ses.ctx.conn,
Reader: io.MultiReader(bytes.NewBuffer(b), ses.ctx.conn),
}
if b[0] == 22 { // TLS handshake: https://tools.ietf.org/html/rfc5246#section-6.2.1
secure := tls.Server(r, p.authority.TLSConfig(ses.request.URL.Host))
if err = secure.Handshake(); err != nil {
logger.Error().Err(err).Msg("error intercepting TLS connection: client handshake failed")
return
}
rw := bufio.NewReadWriter(bufio.NewReader(secure), bufio.NewWriter(secure))
ctx := newContext(secure, rw, ses)
return p.handleRequest(ctx)
}
rw := bufio.NewReadWriter(bufio.NewReader(r), bufio.NewWriter(r))
ctx := newContext(r, rw, ses)
return p.handleRequest(ctx)
}
ses.response = NewResponse(http.StatusOK, nil, ses.request)
defer log.OnCloseError(logger.Debug(), ses.response.Body)
ses.response.ContentLength = -1
if err = p.writeResponse(ses); err != nil {
return
}
logger.Debug().Msg("established CONNECT tunnel, proxying traffic")
var wait sync.WaitGroup
wait.Go(func() { copyStream(ses, c, ses.ctx.conn) })
wait.Go(func() { copyStream(ses, ses.ctx.conn, c) })
wait.Wait()
logger.Debug().Msg("closed CONNECT tunnel")
return
}
func (p *Proxy) handleTunnel(ses *Session) (err error) {
logger := ses.log()
logger.Debug().Msgf("connecting to %s", ses.request.URL.Host)
var c net.Conn
if c, err = p.connect(ses, "tcp", ses.request.URL.Host); err != nil {
logger.Error().Err(err).Msg("connect failed")
if p.onError != nil {
p.onError.HandleError(ses, err)
}
ses.response = ErrorResponse(ses.request, err)
defer log.OnCloseError(logger.Debug(), ses.response.Body)
_ = p.writeResponse(ses)
return
}
defer log.OnCloseError(logger.Debug(), c)
if ses.ctx.IsTLS() {
// Open a TLS client connection
secure := tls.Client(c, &tls.Config{
ServerName: ses.request.URL.Host,
GetClientCertificate: func(cri *tls.CertificateRequestInfo) (*tls.Certificate, error) {
return nil, ErrClientCert
},
})
if err = secure.Handshake(); err != nil {
logger.Error().Err(err).Msg("TLS handshake failed")
return
}
c = secure
}
if err = ses.request.Write(c); err != nil {
logger.Error().Err(err).Msg("failed to write request")
return
}
logger.Debug().Msg("established tunnel, proxying traffic")
var wait sync.WaitGroup
wait.Go(func() { copyStream(ses, c, ses.ctx.conn) })
wait.Go(func() { copyStream(ses, ses.ctx.conn, c) })
wait.Wait()
logger.Debug().Msg("closed tunnel")
return
}
func (p *Proxy) canIntercept(request *http.Request) bool {
if permit := p.policy.PermitIntercept(request); permit != nil {
return *permit
}
return true
}
/*
func (p *Proxy) handleAPIRequest(ses *Session) error {
if ses.request.URL.Path == "/ca.crt" && p.authority != nil {
b := pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE",
Bytes: p.authority.Certificate().Raw,
})
ses.response = NewResponse(http.StatusOK, bytes.NewReader(b), ses.request)
defer log.OnCloseError(logger.Debug(), ses.response.Body)
ses.response.Close = true
ses.response.Header.Set("Content-Type", "application/x-x509-ca-cert")
ses.response.ContentLength = int64(len(b))
return p.writeResponse(ses)
}
ses.response = ErrorResponse(ses.request, errors.New("invalid API endpoint"))
defer log.OnCloseError(logger.Debug(), ses.response.Body)
ses.response.Close = true
return p.writeResponse(ses)
}
*/
func (p *Proxy) readRequest(ctx *Context) (request *http.Request, err error) {
var (
done = make(chan *http.Request, 1)
errs = make(chan error, 1)
)
go func() {
r, err := http.ReadRequest(ctx.rw.Reader)
if err != nil {
errs <- err
} else {
done <- r
}
}()
select {
case <-p.closed:
return nil, ErrClosed
case request = <-done:
return
case err = <-errs:
return
}
}
func (p *Proxy) cleanRequest(ses *Session, request *http.Request) {
if request.URL.Host == "" {
request.URL.Host = request.Host
}
// Ensure proper URL scheme
if !strings.HasPrefix(request.URL.Scheme, "http") {
request.URL.Scheme = "http"
}
if ses.ctx.IsTLS() {
state := ses.ctx.conn.Conn.(*tls.Conn).ConnectionState()
request.TLS = &state
request.URL.Scheme = "https"
}
// Ensure proper RemoteAddr
request.RemoteAddr = ses.ctx.RemoteAddr().String()
// Ensure proper encoding
if request.Header.Get(HeaderAcceptEncoding) != "" {
// We only support gzip
request.Header.Set(HeaderAcceptEncoding, "gzip")
}
}
func (p *Proxy) writeResponse(ses *Session) (err error) {
log := ses.log()
if p.onResponse != nil {
response := p.onResponse.HandleResponse(ses)
if response != nil {
log.Debug().Str("status", response.Status).Msg("response override")
ses.response = response
}
}
if err = ses.response.Write(ses.ctx); err != nil {
log.Error().Err(err).Msg("error writing response back to client")
} else if err = ses.ctx.Flush(); err != nil {
log.Error().Err(err).Msg("error flushing response back to client")
}
return
}
func (p *Proxy) connect(ses *Session, network, address string) (c net.Conn, err error) {
log := ses.log()
log.Debug().Msgf("connect to %s://%s", network, address)
if p.onConnect != nil {
if c = p.onConnect.HandleConnect(ses, network, address); c != nil {
log.Debug().Msg("connect override")
return
}
}
var host, port string
if host, port, err = net.SplitHostPort(address); err != nil {
return
}
var hosts []string
if hosts, err = p.resolver.Lookup(context.Background(), host); err != nil {
log.Warn().Err(err).Msg("connect failed: DNS lookup error")
return
}
log.Debug().Str("address", hosts[0]).Msg("connect resolved address")
return p.dial(network, net.JoinHostPort(hosts[0], port))
}
var hopByHopHeaders = []string{
HeaderConnection,
"Keep-Alive",
"Proxy-Authenticate",
"Proxy-Authorization",
"Proxy-Connection", // Non-standard, but required for HTTP/2.
"Te",
"Trailer",
"Transfer-Encoding",
HeaderUpgrade,
}
func cleanHopByHopHeaders(header http.Header) {
// Additional hop-by-hop headers may be specified in `Connection` headers.
// http://tools.ietf.org/html/draft-ietf-httpbis-p1-messaging-14#section-9.1
for _, values := range header[HeaderConnection] {
for _, key := range strings.Split(values, ",") {
header.Del(key)
}
}
for _, key := range hopByHopHeaders {
header.Del(key)
}
}
// copyStream copies data from reader to writer
func copyStream(ses *Session, w io.Writer, r io.Reader) {
log := ses.log()
if _, err := io.Copy(w, r); err != nil && !isClosing(err) {
log.Error().Err(err).Msg("failed CONNECT tunnel")
} else {
log.Debug().Msg("finished copying CONNECT tunnel")
}
}
func isClosing(err error) bool {
if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) || errors.Is(err, syscall.ECONNRESET) || err == ErrClosed {
return true
}
if err, ok := err.(net.Error); ok && err.Timeout() {
return true
}
// log.Debug().Msgf("not a closing error %T: %#+v", err, err)
return false
}
func resolveInterfaceAddr(name string) (err error) {
var iface *net.Interface
if iface, err = net.InterfaceByName(name); err != nil {
return
}
var addrs []net.Addr
if addrs, err = iface.Addrs(); err != nil {
return
}
for _, addr := range addrs {
if addr, ok := addr.(*net.IPNet); ok && !addr.IP.IsUnspecified() {
log.Warn().Msgf("addr %T: %s", addr, addr)
}
}
return errors.New("nope; TODO")
}

148
proxy/resolver/resolver.go Normal file
View File

@@ -0,0 +1,148 @@
// Package resolver implements a caching DNS resolver
package resolver
import (
"context"
"math/rand/v2"
"net"
"strings"
"time"
"git.maze.io/maze/styx/internal/netutil"
"github.com/hashicorp/golang-lru/v2/expirable"
)
const (
DefaultSize = 1024
DefaultTTL = 5 * time.Minute
DefaultTimeout = 10 * time.Second
)
var (
// DefaultConfig are the defaults for the Default resolver.
DefaultConfig = Config{
Size: DefaultSize,
TTL: DefaultTTL.Seconds(),
Timeout: DefaultTimeout.Seconds(),
}
// Default resolver.
Default = New(DefaultConfig)
)
type Resolver interface {
// Lookup returns resolved IPs for given hostname/ips.
Lookup(context.Context, string) ([]string, error)
}
type netResolver struct {
resolver *net.Resolver
timeout time.Duration
noIPv6 bool
cache *expirable.LRU[string, []string]
}
type Config struct {
// Size is our cache size in number of entries.
Size int `hcl:"size,optional"`
// TTL is the cache time to live in seconds.
TTL float64 `hcl:"ttl,optional"`
// Timeout is the cache timeout in seconds.
Timeout float64 `hcl:"timeout,optional"`
// Server are alternative DNS servers.
Server []string `hcl:"server,optional"`
// NoIPv6 disables IPv6 DNS resolution.
NoIPv6 bool `hcl:"noipv6,optional"`
}
func New(config Config) Resolver {
var (
size = config.Size
ttl = time.Duration(float64(time.Second) * config.TTL)
timeout = time.Duration(float64(time.Second) * config.Timeout)
)
if size <= 0 {
size = DefaultSize
}
if ttl <= 0 {
ttl = DefaultTTL
}
if timeout <= 0 {
timeout = 0
}
var resolver = new(net.Resolver)
if len(config.Server) > 0 {
var dialer net.Dialer
resolver.Dial = func(ctx context.Context, network, address string) (net.Conn, error) {
server := netutil.EnsurePort(config.Server[rand.IntN(len(config.Server))], "53")
return dialer.DialContext(ctx, network, server)
}
}
return &netResolver{
resolver: resolver,
timeout: timeout,
noIPv6: config.NoIPv6,
cache: expirable.NewLRU[string, []string](size, nil, ttl),
}
}
func (r *netResolver) Lookup(ctx context.Context, host string) ([]string, error) {
host = strings.ToLower(strings.TrimSpace(host))
if hosts, ok := r.cache.Get(host); ok {
rand.Shuffle(len(hosts), func(i, j int) {
hosts[i], hosts[j] = hosts[j], hosts[i]
})
return hosts, nil
}
hosts, err := r.lookup(ctx, host)
if err != nil {
return nil, err
}
r.cache.Add(host, hosts)
return hosts, nil
}
func (r *netResolver) lookup(ctx context.Context, host string) ([]string, error) {
if r.timeout > 0 {
var cancel func()
ctx, cancel = context.WithTimeout(ctx, r.timeout)
defer cancel()
}
if net.ParseIP(host) == nil {
addrs, err := r.resolver.LookupHost(ctx, host)
if err != nil {
return nil, err
}
if r.noIPv6 {
var addrs4 []string
for _, addr := range addrs {
if net.ParseIP(addr).To4() != nil {
addrs4 = append(addrs4, addr)
}
}
return addrs4, nil
}
return addrs, nil
}
addrs, err := r.resolver.LookupIPAddr(ctx, host)
if err != nil {
return nil, err
}
hosts := make([]string, len(addrs))
for i, addr := range addrs {
if !r.noIPv6 || addr.IP.To4() != nil {
hosts[i] = addr.IP.String()
}
}
return hosts, nil
}

78
proxy/response.go Normal file
View File

@@ -0,0 +1,78 @@
package proxy
import (
"bytes"
"fmt"
"io"
"net/http"
"os"
"strconv"
"git.maze.io/maze/styx/internal/log"
)
func NewResponse(code int, body io.Reader, request *http.Request) *http.Response {
if body == nil {
body = new(bytes.Buffer)
}
rc, ok := body.(io.ReadCloser)
if !ok {
rc = io.NopCloser(body)
}
response := &http.Response{
Status: strconv.Itoa(code) + " " + http.StatusText(code),
StatusCode: code,
Proto: "HTTP/1.1",
ProtoMajor: 1,
ProtoMinor: 1,
Header: make(http.Header),
Body: rc,
Request: request,
}
if request != nil {
response.Close = request.Close
response.Proto = request.Proto
response.ProtoMajor = request.ProtoMajor
response.ProtoMinor = request.ProtoMinor
}
return response
}
type withLen interface {
Len() int
}
type withSize interface {
Size() int64
}
func NewJSONResponse(code int, body io.Reader, request *http.Request) *http.Response {
response := NewResponse(code, body, request)
response.Header.Set(HeaderContentType, "application/json")
if s, ok := body.(withLen); ok {
response.Header.Set(HeaderContentLength, strconv.Itoa(s.Len()))
} else if s, ok := body.(withSize); ok {
response.Header.Set(HeaderContentLength, strconv.FormatInt(s.Size(), 10))
} else {
log.Trace().Str("type", fmt.Sprintf("%T", body)).Msg("can't detemine body size")
}
response.Close = true
return response
}
func ErrorResponse(request *http.Request, err error) *http.Response {
response := NewResponse(http.StatusBadGateway, nil, request)
switch {
case os.IsNotExist(err):
response.StatusCode = http.StatusNotFound
case os.IsPermission(err):
response.StatusCode = http.StatusForbidden
}
response.Status = http.StatusText(response.StatusCode)
response.Close = true
return response
}

151
proxy/session.go Normal file
View File

@@ -0,0 +1,151 @@
package proxy
import (
"bufio"
"crypto/tls"
"encoding/binary"
"encoding/hex"
"math/rand"
"net"
"net/http"
"sync/atomic"
"time"
"git.maze.io/maze/styx/internal/log"
)
var seed = rand.NewSource(time.Now().UnixNano())
type Context struct {
id int64
conn *wrappedConn
rw *bufio.ReadWriter
parent *Session
data map[string]any
}
func newContext(conn net.Conn, rw *bufio.ReadWriter, parent *Session) *Context {
if wrapped, ok := conn.(*wrappedConn); ok {
conn = wrapped.Conn
}
ctx := &Context{
id: seed.Int63(),
conn: &wrappedConn{Conn: conn},
rw: rw,
parent: parent,
data: make(map[string]any),
}
return ctx
}
func (ctx *Context) log() log.Logger {
return log.Console.With().
Str("context", ctx.ID()).
Str("addr", ctx.RemoteAddr().String()).
Logger()
}
func (ctx *Context) ID() string {
var b [8]byte
binary.BigEndian.PutUint64(b[:], uint64(ctx.id))
if ctx.parent != nil {
return ctx.parent.ID() + "-" + hex.EncodeToString(b[:])
}
return hex.EncodeToString(b[:])
}
func (ctx *Context) IsTLS() bool {
_, ok := ctx.conn.Conn.(*tls.Conn)
return ok && ctx.parent != nil
}
func (ctx *Context) RemoteAddr() net.Addr {
if ctx.parent != nil {
return ctx.parent.ctx.RemoteAddr()
}
return ctx.conn.RemoteAddr()
}
func (ctx *Context) SetDeadline(t time.Time) error {
if ctx.parent != nil {
return ctx.parent.ctx.SetDeadline(t)
}
return ctx.conn.SetDeadline(t)
}
func (ctx *Context) Set(key string, value any) {
ctx.data[key] = value
}
func (ctx *Context) Get(key string) (value any, ok bool) {
value, ok = ctx.data[key]
return
}
func (ctx *Context) Flush() error {
return ctx.rw.Flush()
}
func (ctx *Context) Write(p []byte) (n int, err error) {
if n, err = ctx.rw.Write(p); n > 0 {
atomic.AddInt64(&ctx.conn.bytes, int64(n))
}
return
}
type Session struct {
id int64
ctx *Context
request *http.Request
response *http.Response
data map[string]any
}
func newSession(ctx *Context, request *http.Request) *Session {
return &Session{
id: seed.Int63(),
ctx: ctx,
request: request,
data: make(map[string]any),
}
}
func (ses *Session) log() log.Logger {
return log.Console.With().
Str("context", ses.ctx.ID()).
Str("session", ses.ID()).
Str("addr", ses.ctx.RemoteAddr().String()).
Logger()
}
func (ses *Session) ID() string {
var b [8]byte
binary.BigEndian.PutUint64(b[:], uint64(ses.id))
return hex.EncodeToString(b[:])
}
func (ses *Session) Context() *Context {
return ses.ctx
}
func (ses *Session) Request() *http.Request {
return ses.request
}
func (ses *Session) Response() *http.Response {
return ses.response
}
type wrappedConn struct {
net.Conn
bytes int64
}
func (c *wrappedConn) Write(p []byte) (n int, err error) {
if n, err = c.Conn.Write(p); n > 0 {
atomic.AddInt64(&c.bytes, int64(n))
}
return
}

225
proxy/stats/stats.go Normal file
View File

@@ -0,0 +1,225 @@
package stats
import (
"database/sql"
"database/sql/driver"
"encoding/json"
"fmt"
"net/http"
"os"
"os/user"
"path/filepath"
"time"
"git.maze.io/maze/styx/internal/log"
_ "github.com/mattn/go-sqlite3"
)
type Stats struct {
db *sql.DB
}
func New() (*Stats, error) {
u, err := user.Current()
if err != nil {
return nil, err
}
path := filepath.Join(u.HomeDir, ".styx", "stats.db")
if err = os.MkdirAll(filepath.Dir(path), 0o750); err != nil {
return nil, err
}
db, err := sql.Open("sqlite3", path+"?_journal_mode=WAL")
if err != nil {
return nil, err
}
for _, table := range []string{
createLog,
createDomainStat,
createStatusStat,
} {
if _, err = db.Exec(table); err != nil {
return nil, err
}
}
return &Stats{db: db}, nil
}
func (s *Stats) AddLog(entry *Log) error {
var (
request []byte
response []byte
err error
)
if request, err = json.Marshal(entry.Request); err != nil {
return err
}
if response, err = json.Marshal(entry.Response); err != nil {
return err
}
tx, err := s.db.Begin()
if err != nil {
return err
}
stmt, err := tx.Prepare("insert into styx_log(client_ip, request, response) values(?, ?, ?)")
if err != nil {
return err
}
defer stmt.Close()
if _, err = stmt.Exec(entry.ClientIP, request, response); err != nil {
return err
}
return tx.Commit()
}
func (s *Stats) QueryLog(offset, limit int) ([]*Log, error) {
if limit == 0 {
limit = 50
}
rows, err := s.db.Query("select dt, client_ip, request, response from styx_log limit ?, ?", offset, limit)
if err != nil {
return nil, err
}
defer rows.Close()
var logs []*Log
for rows.Next() {
var entry = new(Log)
if err = rows.Scan(&entry.Time, &entry.ClientIP, &entry.Request, &entry.Response); err != nil {
return nil, err
}
logs = append(logs, entry)
}
return logs, nil
}
type Status struct {
Code int `json:"code"`
Count int `json:"count"`
}
var timeZero time.Time
func (s *Stats) QueryStatus(since time.Time) ([]*Status, error) {
if since.Equal(timeZero) {
since = time.Now().Add(-24 * time.Hour)
}
rows, err := s.db.Query("select response->'status', count(*) from styx_log where dt >= ? group by response->'status' order by response->'status'", since)
if err != nil {
return nil, err
}
var stats []*Status
for rows.Next() {
var entry = new(Status)
if err = rows.Scan(&entry.Code, &entry.Count); err != nil {
return nil, err
}
stats = append(stats, entry)
}
return stats, nil
}
const createLog = `CREATE TABLE IF NOT EXISTS styx_log (
id INT PRIMARY KEY,
dt DATETIME DEFAULT CURRENT_TIMESTAMP,
client_ip TEXT NOT NULL,
request JSONB NOT NULL,
response JSONB NOT NULL
);`
type Log struct {
Time time.Time `json:"time"`
ClientIP string `json:"client_ip"`
Request *Request `json:"request"`
Response *Response `json:"response"`
}
type Request struct {
URL string `json:"url"`
Host string `json:"host"`
Method string `json:"method"`
Proto string `json:"proto"`
Header http.Header `json:"header"`
}
func (r *Request) Scan(value any) error {
switch v := value.(type) {
case string:
return json.Unmarshal([]byte(v), r)
case []byte:
return json.Unmarshal(v, r)
default:
log.Error().Str("type", fmt.Sprintf("%T", value)).Msg("scan request unknown type")
return nil
}
}
func (r *Request) Value() (driver.Value, error) {
b, err := json.Marshal(r)
return string(b), err
}
func FromRequest(r *http.Request) *Request {
return &Request{
URL: r.URL.String(),
Host: r.Host,
Method: r.Method,
Proto: r.Proto,
Header: r.Header,
}
}
type Response struct {
Status int `json:"status"`
Size int64 `json:"size"`
Header http.Header `json:"header"`
}
func (r *Response) Scan(value any) error {
switch v := value.(type) {
case string:
return json.Unmarshal([]byte(v), r)
case []byte:
return json.Unmarshal(v, r)
default:
log.Error().Str("type", fmt.Sprintf("%T", value)).Msg("scan response unknown type")
return nil
}
}
func (r *Response) Value() (driver.Value, error) {
b, err := json.Marshal(r)
return string(b), err
}
func (r *Response) SetSize(size int64) *Response {
r.Size = size
return r
}
func FromResponse(r *http.Response) *Response {
return &Response{
Status: r.StatusCode,
Header: r.Header,
}
}
const createStatusStat = `CREATE TABLE IF NOT EXISTS styx_stat_status (
id INT PRIMARY KEY,
dt DATETIME DEFAULT CURRENT_TIMESTAMP,
status INT NOT NULL
);`
const createDomainStat = `CREATE TABLE IF NOT EXISTS styx_stat_domain (
id INT PRIMARY KEY,
dt DATETIME DEFAULT CURRENT_TIMESTAMP,
domain TEXT NOT NULL
);`

16
proxy/util.go Normal file
View File

@@ -0,0 +1,16 @@
package proxy
import (
"io"
"net"
)
// connReader is a net.Conn with a separate reader.
type connReader struct {
net.Conn
io.Reader
}
func (c connReader) Read(p []byte) (int, error) {
return c.Reader.Read(p)
}