Checkpoint
This commit is contained in:
		
							
								
								
									
										145
									
								
								proxy/admin.go
									
									
									
									
									
								
							
							
						
						
									
										145
									
								
								proxy/admin.go
									
									
									
									
									
								
							@@ -1,145 +0,0 @@
 | 
			
		||||
package proxy
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"encoding/pem"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"os"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"git.maze.io/maze/styx/internal/log"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Admin struct {
 | 
			
		||||
	*Proxy
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewAdmin(proxy *Proxy) *Admin {
 | 
			
		||||
	a := &Admin{
 | 
			
		||||
		Proxy: proxy,
 | 
			
		||||
	}
 | 
			
		||||
	return a
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Admin) handleRequest(ses *Session) error {
 | 
			
		||||
	var (
 | 
			
		||||
		logger = ses.log()
 | 
			
		||||
		err    error
 | 
			
		||||
	)
 | 
			
		||||
	switch ses.request.URL.Path {
 | 
			
		||||
	case "/ca.crt":
 | 
			
		||||
		err = a.handleCACert(ses)
 | 
			
		||||
	case "/api/v1/policy":
 | 
			
		||||
		err = a.apiPolicy(ses)
 | 
			
		||||
	case "/api/v1/policy/matcher":
 | 
			
		||||
		err = a.apiPolicyMatcher(ses)
 | 
			
		||||
	case "/api/v1/stats/log":
 | 
			
		||||
		err = a.apiStatsLog(ses)
 | 
			
		||||
	case "/api/v1/stats/status":
 | 
			
		||||
		err = a.apiStatsStatus(ses)
 | 
			
		||||
	default:
 | 
			
		||||
		if strings.HasPrefix(ses.request.URL.Path, "/api") {
 | 
			
		||||
			err = errors.New("invalid endpoint")
 | 
			
		||||
		} else {
 | 
			
		||||
			err = os.ErrNotExist
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logger.Warn().Err(err).Msg("admin error")
 | 
			
		||||
		ses.response = ErrorResponse(ses.request, err)
 | 
			
		||||
		defer log.OnCloseError(logger.Debug(), ses.response.Body)
 | 
			
		||||
		ses.response.Close = true
 | 
			
		||||
		return a.writeResponse(ses)
 | 
			
		||||
	}
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Admin) handleCACert(ses *Session) error {
 | 
			
		||||
	b := pem.EncodeToMemory(&pem.Block{
 | 
			
		||||
		Type:  "CERTIFICATE",
 | 
			
		||||
		Bytes: a.authority.Certificate().Raw,
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	ses.response = NewResponse(http.StatusOK, bytes.NewReader(b), ses.request)
 | 
			
		||||
	defer log.OnCloseError(log.Debug(), ses.response.Body)
 | 
			
		||||
 | 
			
		||||
	ses.response.Close = true
 | 
			
		||||
	ses.response.Header.Set("Content-Type", "application/x-x509-ca-cert")
 | 
			
		||||
	ses.response.ContentLength = int64(len(b))
 | 
			
		||||
	return a.writeResponse(ses)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Admin) apiPolicy(ses *Session) error {
 | 
			
		||||
	var (
 | 
			
		||||
		b = new(bytes.Buffer)
 | 
			
		||||
		e = json.NewEncoder(b)
 | 
			
		||||
	)
 | 
			
		||||
	e.SetIndent("", "  ")
 | 
			
		||||
	if err := e.Encode(a.config.Policy); err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	ses.response = NewJSONResponse(http.StatusOK, b, ses.request)
 | 
			
		||||
	defer log.OnCloseError(log.Debug(), ses.response.Body)
 | 
			
		||||
	ses.response.Close = true
 | 
			
		||||
	return a.writeResponse(ses)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Admin) apiPolicyMatcher(ses *Session) error {
 | 
			
		||||
	var (
 | 
			
		||||
		b = new(bytes.Buffer)
 | 
			
		||||
		e = json.NewEncoder(b)
 | 
			
		||||
	)
 | 
			
		||||
	e.SetIndent("", "  ")
 | 
			
		||||
	if err := e.Encode(a.config.Policy.Matchers); err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	ses.response = NewJSONResponse(http.StatusOK, b, ses.request)
 | 
			
		||||
	defer log.OnCloseError(log.Debug(), ses.response.Body)
 | 
			
		||||
	ses.response.Close = true
 | 
			
		||||
	return a.writeResponse(ses)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Admin) apiResponse(ses *Session, v any, err error) error {
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	var (
 | 
			
		||||
		b = new(bytes.Buffer)
 | 
			
		||||
		e = json.NewEncoder(b)
 | 
			
		||||
	)
 | 
			
		||||
	e.SetIndent("", "  ")
 | 
			
		||||
	if err := e.Encode(v); err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	ses.response = NewJSONResponse(http.StatusOK, b, ses.request)
 | 
			
		||||
	defer log.OnCloseError(log.Debug(), ses.response.Body)
 | 
			
		||||
	ses.response.Close = true
 | 
			
		||||
	return a.writeResponse(ses)
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Admin) apiStatsLog(ses *Session) error {
 | 
			
		||||
	var (
 | 
			
		||||
		query     = ses.request.URL.Query()
 | 
			
		||||
		offset, _ = strconv.Atoi(query.Get("offset"))
 | 
			
		||||
		limit, _  = strconv.Atoi(query.Get("limit"))
 | 
			
		||||
	)
 | 
			
		||||
	if limit > 100 {
 | 
			
		||||
		limit = 100
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	s, err := a.stats.QueryLog(offset, limit)
 | 
			
		||||
	return a.apiResponse(ses, s, err)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Admin) apiStatsStatus(ses *Session) error {
 | 
			
		||||
	s, err := a.stats.QueryStatus(time.Time{})
 | 
			
		||||
	return a.apiResponse(ses, s, err)
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										8
									
								
								proxy/cache/config.go
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										8
									
								
								proxy/cache/config.go
									
									
									
									
										vendored
									
									
								
							@@ -1,8 +0,0 @@
 | 
			
		||||
package cache
 | 
			
		||||
 | 
			
		||||
import "github.com/hashicorp/hcl/v2"
 | 
			
		||||
 | 
			
		||||
type Config struct {
 | 
			
		||||
	Type string   `hcl:"type"`
 | 
			
		||||
	Body hcl.Body `hcl:",remain"`
 | 
			
		||||
}
 | 
			
		||||
@@ -1,88 +0,0 @@
 | 
			
		||||
package proxy
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"net"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"git.maze.io/maze/styx/proxy/policy"
 | 
			
		||||
	"git.maze.io/maze/styx/proxy/resolver"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type ConnectHandler interface {
 | 
			
		||||
	HandleConnect(session *Session, network, address string) net.Conn
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ConnectHandlerFunc is called when the proxy receives a new HTTP CONNECT request.
 | 
			
		||||
type ConnectHandlerFunc func(session *Session, network, address string) net.Conn
 | 
			
		||||
 | 
			
		||||
func (f ConnectHandlerFunc) HandleConnect(session *Session, network, address string) net.Conn {
 | 
			
		||||
	return f(session, network, address)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type RequestHandler interface {
 | 
			
		||||
	HandleRequest(session *Session) (*http.Request, *http.Response)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// RequestHandlerFunc is called when the proxy receives a new request.
 | 
			
		||||
type RequestHandlerFunc func(session *Session) (*http.Request, *http.Response)
 | 
			
		||||
 | 
			
		||||
func (f RequestHandlerFunc) HandleRequest(session *Session) (*http.Request, *http.Response) {
 | 
			
		||||
	return f(session)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type ResponseHandler interface {
 | 
			
		||||
	HandleResponse(session *Session) *http.Response
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ResponseHandler is called when the proxy receives a response.
 | 
			
		||||
type ResponseHandlerFunc func(session *Session) *http.Response
 | 
			
		||||
 | 
			
		||||
func (f ResponseHandlerFunc) HandleResponse(session *Session) *http.Response {
 | 
			
		||||
	return f(session)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type ErrorHandler interface {
 | 
			
		||||
	HandleError(session *Session, err error)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type ErrorHandlerFunc func(session *Session, err error)
 | 
			
		||||
 | 
			
		||||
func (f ErrorHandlerFunc) HandleError(session *Session, err error) {
 | 
			
		||||
	f(session, err)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Config struct {
 | 
			
		||||
	// Listen address.
 | 
			
		||||
	Listen string `hcl:"listen,optional"`
 | 
			
		||||
 | 
			
		||||
	// Bind address for outgoing connections.
 | 
			
		||||
	Bind string `hcl:"bind,optional"`
 | 
			
		||||
 | 
			
		||||
	// Interface for outgoing connections.
 | 
			
		||||
	Interface string `hcl:"interface,optional"`
 | 
			
		||||
 | 
			
		||||
	// Upstream proxy servers.
 | 
			
		||||
	Upstream []string `hcl:"upstream,optional"`
 | 
			
		||||
 | 
			
		||||
	// DialTimeout for establishing new connections.
 | 
			
		||||
	DialTimeout time.Duration `hcl:"dial_timeout,optional"`
 | 
			
		||||
 | 
			
		||||
	// Policy for the proxy.
 | 
			
		||||
	Policy *policy.Policy `hcl:"policy,block"`
 | 
			
		||||
 | 
			
		||||
	// Resolver for the proxy.
 | 
			
		||||
	Resolver resolver.Resolver
 | 
			
		||||
 | 
			
		||||
	ConnectHandler  ConnectHandler
 | 
			
		||||
	RequestHandler  RequestHandler
 | 
			
		||||
	ResponseHandler ResponseHandler
 | 
			
		||||
	ErrorHandler    ErrorHandler
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var (
 | 
			
		||||
	_ ConnectHandler  = (ConnectHandlerFunc)(nil)
 | 
			
		||||
	_ RequestHandler  = (RequestHandlerFunc)(nil)
 | 
			
		||||
	_ ResponseHandler = (ResponseHandlerFunc)(nil)
 | 
			
		||||
	_ ErrorHandler    = (ErrorHandlerFunc)(nil)
 | 
			
		||||
)
 | 
			
		||||
							
								
								
									
										227
									
								
								proxy/context.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										227
									
								
								proxy/context.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,227 @@
 | 
			
		||||
package proxy
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bufio"
 | 
			
		||||
	"crypto/rand"
 | 
			
		||||
	"crypto/tls"
 | 
			
		||||
	"encoding/binary"
 | 
			
		||||
	"encoding/hex"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"sync/atomic"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"git.maze.io/maze/styx/internal/netutil/arp"
 | 
			
		||||
	"github.com/sirupsen/logrus"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// Context provides convenience functions for the current ongoing HTTP proxy transaction (request).
 | 
			
		||||
type Context interface {
 | 
			
		||||
	// Conn is the backing connection for this context.
 | 
			
		||||
	net.Conn
 | 
			
		||||
 | 
			
		||||
	// ID is a unique connection identifier.
 | 
			
		||||
	ID() uint64
 | 
			
		||||
 | 
			
		||||
	// Reader returns a buffered reader on top of the [net.Conn].
 | 
			
		||||
	Reader() *bufio.Reader
 | 
			
		||||
 | 
			
		||||
	// BytesRead returns the number of bytes read.
 | 
			
		||||
	BytesRead() int64
 | 
			
		||||
 | 
			
		||||
	// BytesSent returns the number of bytes written.
 | 
			
		||||
	BytesSent() int64
 | 
			
		||||
 | 
			
		||||
	// TLSState returns the TLS connection state, it returns nil if the connection is not a TLS connection.
 | 
			
		||||
	TLSState() *tls.ConnectionState
 | 
			
		||||
 | 
			
		||||
	// Request is the request made to the proxy.
 | 
			
		||||
	Request() *http.Request
 | 
			
		||||
 | 
			
		||||
	// Response is the response that will be sent back to the client.
 | 
			
		||||
	Response() *http.Response
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type countingReader struct {
 | 
			
		||||
	reader io.Reader
 | 
			
		||||
	bytes  int64
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *countingReader) Read(p []byte) (n int, err error) {
 | 
			
		||||
	if n, err = r.reader.Read(p); n > 0 {
 | 
			
		||||
		atomic.AddInt64(&r.bytes, int64(n))
 | 
			
		||||
	}
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type countingWriter struct {
 | 
			
		||||
	writer io.Writer
 | 
			
		||||
	bytes  int64
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (w *countingWriter) Write(p []byte) (n int, err error) {
 | 
			
		||||
	if n, err = w.writer.Write(p); n > 0 {
 | 
			
		||||
		atomic.AddInt64(&w.bytes, int64(n))
 | 
			
		||||
	}
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type proxyContext struct {
 | 
			
		||||
	net.Conn
 | 
			
		||||
	id               uint64
 | 
			
		||||
	mac              net.HardwareAddr
 | 
			
		||||
	cr               *countingReader
 | 
			
		||||
	br               *bufio.Reader
 | 
			
		||||
	cw               *countingWriter
 | 
			
		||||
	isTransparent    bool
 | 
			
		||||
	isTransparentTLS bool
 | 
			
		||||
	serverName       string
 | 
			
		||||
	req              *http.Request
 | 
			
		||||
	res              *http.Response
 | 
			
		||||
	idleTimeout      time.Duration
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// NewContext returns an initialized context for the provided [net.Conn].
 | 
			
		||||
func NewContext(c net.Conn) Context {
 | 
			
		||||
	if c, ok := c.(*proxyContext); ok {
 | 
			
		||||
		return c
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	b := make([]byte, 8)
 | 
			
		||||
	io.ReadFull(rand.Reader, b)
 | 
			
		||||
 | 
			
		||||
	cr := &countingReader{reader: c}
 | 
			
		||||
	cw := &countingWriter{writer: c}
 | 
			
		||||
	return &proxyContext{
 | 
			
		||||
		Conn: c,
 | 
			
		||||
		id:   binary.BigEndian.Uint64(b),
 | 
			
		||||
		mac:  arp.Get(c.RemoteAddr()),
 | 
			
		||||
		cr:   cr,
 | 
			
		||||
		br:   bufio.NewReader(cr),
 | 
			
		||||
		cw:   cw,
 | 
			
		||||
		res:  &http.Response{StatusCode: 200},
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *proxyContext) AccessLogEntry() *logrus.Entry {
 | 
			
		||||
	var id [8]byte
 | 
			
		||||
	binary.BigEndian.PutUint64(id[:], c.id)
 | 
			
		||||
	entry := AccessLog.WithFields(logrus.Fields{
 | 
			
		||||
		"client":   c.RemoteAddr().String(),
 | 
			
		||||
		"server":   c.LocalAddr().String(),
 | 
			
		||||
		"id":       hex.EncodeToString(id[:]),
 | 
			
		||||
		"bytes_rx": c.BytesRead(),
 | 
			
		||||
		"bytes_tx": c.BytesSent(),
 | 
			
		||||
	})
 | 
			
		||||
	if c.mac != nil {
 | 
			
		||||
		return entry.WithField("client_mac", c.mac.String())
 | 
			
		||||
	}
 | 
			
		||||
	return entry
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *proxyContext) LogEntry() *logrus.Entry {
 | 
			
		||||
	var id [8]byte
 | 
			
		||||
	binary.BigEndian.PutUint64(id[:], c.id)
 | 
			
		||||
	return ServerLog.WithFields(logrus.Fields{
 | 
			
		||||
		"client": c.RemoteAddr().String(),
 | 
			
		||||
		"server": c.LocalAddr().String(),
 | 
			
		||||
		"id":     hex.EncodeToString(id[:]),
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *proxyContext) String() string {
 | 
			
		||||
	return fmt.Sprintf("client=%s server=%s id=%#08x",
 | 
			
		||||
		c.RemoteAddr().String(),
 | 
			
		||||
		c.LocalAddr().String(),
 | 
			
		||||
		c.id)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *proxyContext) ID() uint64 {
 | 
			
		||||
	return c.id
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *proxyContext) BytesRead() int64 {
 | 
			
		||||
	return atomic.LoadInt64(&c.cr.bytes)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *proxyContext) BytesSent() int64 {
 | 
			
		||||
	return atomic.LoadInt64(&c.cw.bytes)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *proxyContext) Read(p []byte) (n int, err error) {
 | 
			
		||||
	if c.idleTimeout > 0 {
 | 
			
		||||
		if err = c.SetReadDeadline(time.Now().Add(c.idleTimeout)); err != nil {
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return c.br.Read(p)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *proxyContext) Write(p []byte) (n int, err error) {
 | 
			
		||||
	if c.idleTimeout > 0 {
 | 
			
		||||
		if err = c.SetWriteDeadline(time.Now().Add(c.idleTimeout)); err != nil {
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return c.cw.Write(p)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *proxyContext) Reader() *bufio.Reader {
 | 
			
		||||
	return c.br
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *proxyContext) Request() *http.Request {
 | 
			
		||||
	return c.req
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *proxyContext) SetRequest(req *http.Request) {
 | 
			
		||||
	c.req = req
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *proxyContext) Response() *http.Response {
 | 
			
		||||
	return c.res
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *proxyContext) SetIdleTimeout(t time.Duration) {
 | 
			
		||||
	c.idleTimeout = t
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type connectionStater interface {
 | 
			
		||||
	ConnectionState() tls.ConnectionState
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *proxyContext) TLSState() *tls.ConnectionState {
 | 
			
		||||
	if s, ok := c.Conn.(connectionStater); ok {
 | 
			
		||||
		state := s.ConnectionState()
 | 
			
		||||
		return &state
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// http.ResponseWriter interface:
 | 
			
		||||
 | 
			
		||||
func (c *proxyContext) Header() http.Header {
 | 
			
		||||
	if c.res == nil {
 | 
			
		||||
		c.res = NewResponse(http.StatusOK, nil, c.req)
 | 
			
		||||
	}
 | 
			
		||||
	return c.res.Header
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *proxyContext) WriteHeader(code int) {
 | 
			
		||||
	if c.res == nil {
 | 
			
		||||
		c.res = NewResponse(code, nil, c.req)
 | 
			
		||||
	} else {
 | 
			
		||||
		if text := http.StatusText(code); text != "" {
 | 
			
		||||
			c.res.Status = strconv.Itoa(code) + " " + text
 | 
			
		||||
		} else {
 | 
			
		||||
			c.res.Status = strconv.Itoa(code)
 | 
			
		||||
		}
 | 
			
		||||
		c.res.StatusCode = code
 | 
			
		||||
	}
 | 
			
		||||
	//return c.res.Header.Write(c)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var _ Context = (*proxyContext)(nil)
 | 
			
		||||
							
								
								
									
										2
									
								
								proxy/doc.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										2
									
								
								proxy/doc.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,2 @@
 | 
			
		||||
// Package proxy contains a HTTP(s) (transparent) proxy server.
 | 
			
		||||
package proxy
 | 
			
		||||
							
								
								
									
										309
									
								
								proxy/handler.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										309
									
								
								proxy/handler.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,309 @@
 | 
			
		||||
package proxy
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"context"
 | 
			
		||||
	"crypto/tls"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net"
 | 
			
		||||
	"net/http"
 | 
			
		||||
 | 
			
		||||
	"git.maze.io/maze/styx/ca"
 | 
			
		||||
	"git.maze.io/maze/styx/internal/cryptutil"
 | 
			
		||||
	"git.maze.io/maze/styx/internal/netutil"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Dialer interface {
 | 
			
		||||
	DialContext(context.Context, *http.Request) (net.Conn, error)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type defaultDialer struct{}
 | 
			
		||||
 | 
			
		||||
func (defaultDialer) DialContext(ctx context.Context, req *http.Request) (net.Conn, error) {
 | 
			
		||||
	if host := netutil.Host(req.URL.Host); host == "" {
 | 
			
		||||
		return nil, errors.New("proxy: host missing in address")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var d = net.Dialer{
 | 
			
		||||
		Resolver:      net.DefaultResolver,
 | 
			
		||||
		FallbackDelay: -1,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Ensure we have a port.
 | 
			
		||||
	switch req.URL.Scheme {
 | 
			
		||||
	case "http", "ws":
 | 
			
		||||
		req.URL.Host = netutil.EnsurePort(req.URL.Host, "80")
 | 
			
		||||
	case "https", "wss":
 | 
			
		||||
		req.URL.Host = netutil.EnsurePort(req.URL.Host, "443")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Resolve the host.
 | 
			
		||||
	if ips, err := d.Resolver.LookupIP(ctx, "ip", netutil.Host(req.URL.Host)); err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	} else {
 | 
			
		||||
		for _, ip := range ips {
 | 
			
		||||
			switch {
 | 
			
		||||
			case ip.IsUnspecified():
 | 
			
		||||
				return nil, fmt.Errorf("proxy: host %s resolves to unspecified address (blocked by DNS?)", netutil.Host(req.URL.Host))
 | 
			
		||||
			case ip.IsLoopback():
 | 
			
		||||
				return nil, fmt.Errorf("proxy: host %s resolves to loopback address (blocked by DNS?)", netutil.Host(req.URL.Host))
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Make the connection.
 | 
			
		||||
	switch req.URL.Scheme {
 | 
			
		||||
	case "tcp", "http", "ws":
 | 
			
		||||
		// Plain TCP client connection.
 | 
			
		||||
		return d.DialContext(ctx, "tcp", req.URL.Host)
 | 
			
		||||
	case "https", "wss":
 | 
			
		||||
		// Secure TLS client connection.
 | 
			
		||||
		c, err := d.DialContext(ctx, "tcp", req.URL.Host)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		s := tls.Client(c, new(tls.Config))
 | 
			
		||||
		return s, s.Handshake()
 | 
			
		||||
	default:
 | 
			
		||||
		return nil, fmt.Errorf("proxy: can't dial %s protocol", req.URL.Scheme)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ConnFilter is called when a new connection has been accepted by the proxy.
 | 
			
		||||
type ConnFilter interface {
 | 
			
		||||
	FilterConn(Context) (net.Conn, error)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ConnFilterFunc is a function that implements the [ConnFilter] interface.
 | 
			
		||||
type ConnFilterFunc func(Context) (net.Conn, error)
 | 
			
		||||
 | 
			
		||||
func (f ConnFilterFunc) FilterConn(ctx Context) (net.Conn, error) {
 | 
			
		||||
	return f(ctx)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// TLS starts a TLS handshake on the accepted connection.
 | 
			
		||||
func TLS(certs []tls.Certificate) ConnFilter {
 | 
			
		||||
	return ConnFilterFunc(func(ctx Context) (net.Conn, error) {
 | 
			
		||||
		s := tls.Server(ctx, &tls.Config{
 | 
			
		||||
			Certificates: certs,
 | 
			
		||||
			NextProtos:   []string{"http/1.1"},
 | 
			
		||||
		})
 | 
			
		||||
		if err := s.Handshake(); err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		return s, nil
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// TLSInterceptor can generate certificates on-the-fly for clients that use a compatible TLS version.
 | 
			
		||||
func TLSInterceptor(ca ca.CertificateAuthority) ConnFilter {
 | 
			
		||||
	return ConnFilterFunc(func(ctx Context) (net.Conn, error) {
 | 
			
		||||
		s := tls.Server(ctx, &tls.Config{
 | 
			
		||||
			GetCertificate: func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
 | 
			
		||||
				ips := []net.IP{net.ParseIP(netutil.Host(ctx.RemoteAddr().String()))}
 | 
			
		||||
				return ca.GetCertificate(hello.ServerName, []string{hello.ServerName}, ips)
 | 
			
		||||
			},
 | 
			
		||||
			NextProtos: []string{"http/1.1"},
 | 
			
		||||
		})
 | 
			
		||||
		if err := s.Handshake(); err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		return s, nil
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Transparent can handle transparent HTTP(S) requests on the port.
 | 
			
		||||
//
 | 
			
		||||
// When a new [net.Conn] is made, this function will inspect the initial request packet for a
 | 
			
		||||
// TLS handshake. If a TLS handshake is detected, the connection will make a feaux HTTP CONNECT
 | 
			
		||||
// request using TLS, if no handshake is detected, it will make a feaux plain HTTP CONNECT request.
 | 
			
		||||
func Transparent() ConnFilter {
 | 
			
		||||
	return ConnFilterFunc(func(nctx Context) (net.Conn, error) {
 | 
			
		||||
		ctx, ok := nctx.(*proxyContext)
 | 
			
		||||
		if !ok {
 | 
			
		||||
			return nctx, nil
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		b := new(bytes.Buffer)
 | 
			
		||||
		hello, err := cryptutil.ReadClientHello(io.TeeReader(ctx, b))
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			if _, ok := err.(tls.RecordHeaderError); !ok {
 | 
			
		||||
				ctx.LogEntry().WithError(err).WithField("error_type", fmt.Sprintf("%T", err)).Warn("TLS sniff error")
 | 
			
		||||
				return nil, err
 | 
			
		||||
			}
 | 
			
		||||
			// Not a TLS connection, moving on to regular HTTP request handling...
 | 
			
		||||
			ctx.LogEntry().Debug("HTTP connection on transparent port")
 | 
			
		||||
			ctx.isTransparent = true
 | 
			
		||||
		} else {
 | 
			
		||||
			ctx.LogEntry().WithField("target", hello.ServerName).Debug("TLS connection on transparent port")
 | 
			
		||||
			ctx.isTransparentTLS = true
 | 
			
		||||
			ctx.serverName = hello.ServerName
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		return netutil.ReaderConn{
 | 
			
		||||
			Conn:   ctx.Conn,
 | 
			
		||||
			Reader: io.MultiReader(bytes.NewReader(b.Bytes()), ctx.Conn),
 | 
			
		||||
		}, nil
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// RequestFilter can filter HTTP requests coming to the proxy.
 | 
			
		||||
type RequestFilter interface {
 | 
			
		||||
	// FilterRequest filters a HTTP request made to the proxy. The current request may be obtained
 | 
			
		||||
	// from [Context.Request]. If a previous RequestFilter provided a HTTP response, it is available
 | 
			
		||||
	// from [Context.Response].
 | 
			
		||||
	//
 | 
			
		||||
	// Modifications to the current request can be made to the Request returned by [Context.Request]
 | 
			
		||||
	// and do not require returning a new [http.Request].
 | 
			
		||||
	//
 | 
			
		||||
	// If the filter returns a non-nil [http.Response], then the [Request] will not be proxied to
 | 
			
		||||
	// any upstream server(s).
 | 
			
		||||
	FilterRequest(Context) (*http.Request, *http.Response)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// RequestFilterFunc is a function that implements the [RequestFilter] interface.
 | 
			
		||||
type RequestFilterFunc func(Context) (*http.Request, *http.Response)
 | 
			
		||||
 | 
			
		||||
func (f RequestFilterFunc) FilterRequest(ctx Context) (*http.Request, *http.Response) {
 | 
			
		||||
	return f(ctx)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ResponseFilter can filter HTTP responses coming from the proxy.
 | 
			
		||||
type ResponseFilter interface {
 | 
			
		||||
	// FilterResponse filters a HTTP response coming from the proxy. The current response may be
 | 
			
		||||
	// obtained from [Context.Response].
 | 
			
		||||
	//
 | 
			
		||||
	// Modifications to the current response can be made to the [Response] returned by [Context.Response].
 | 
			
		||||
	FilterResponse(Context) *http.Response
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ResponseFilterFunc is a function that implements the [ResponseFilter] interface.
 | 
			
		||||
type ResponseFilterFunc func(Context) *http.Response
 | 
			
		||||
 | 
			
		||||
func (f ResponseFilterFunc) FilterResponse(ctx Context) *http.Response {
 | 
			
		||||
	return f(ctx)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// CleanRequestProxyHeaders removes all headers added by downstream proxies from the [http.Request].
 | 
			
		||||
func CleanRequestProxyHeaders() RequestFilter {
 | 
			
		||||
	return RequestFilterFunc(func(ctx Context) (*http.Request, *http.Response) {
 | 
			
		||||
		if req := ctx.Request(); req != nil {
 | 
			
		||||
			cleanProxyHeaders(req.Header)
 | 
			
		||||
		}
 | 
			
		||||
		return nil, nil
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// CleanRequestProxyHeaders removes all headers for upstream proxies from the [http.Response].
 | 
			
		||||
func CleanResponseProxyHeaders() ResponseFilter {
 | 
			
		||||
	return ResponseFilterFunc(func(ctx Context) *http.Response {
 | 
			
		||||
		if res := ctx.Response(); res != nil {
 | 
			
		||||
			cleanProxyHeaders(res.Header)
 | 
			
		||||
		}
 | 
			
		||||
		return nil
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// AddRequestHeaders adds headers to the [http.Request]. Any existing headers with the same
 | 
			
		||||
// key will remain intact.
 | 
			
		||||
func AddRequestHeaders(h http.Header) RequestFilter {
 | 
			
		||||
	return RequestFilterFunc(func(ctx Context) (*http.Request, *http.Response) {
 | 
			
		||||
		if req := ctx.Request(); req != nil {
 | 
			
		||||
			if req.Header == nil {
 | 
			
		||||
				req.Header = make(http.Header)
 | 
			
		||||
			}
 | 
			
		||||
			addHeaders(req.Header, h)
 | 
			
		||||
		}
 | 
			
		||||
		return nil, nil
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// SetRequestHeaders sets headers to the [http.Request]. Any existing headers with the same
 | 
			
		||||
// key will be removed.
 | 
			
		||||
func SetRequestHeaders(h http.Header) RequestFilter {
 | 
			
		||||
	return RequestFilterFunc(func(ctx Context) (*http.Request, *http.Response) {
 | 
			
		||||
		if req := ctx.Request(); req != nil {
 | 
			
		||||
			if req.Header == nil {
 | 
			
		||||
				req.Header = make(http.Header)
 | 
			
		||||
			}
 | 
			
		||||
			setHeaders(req.Header, h)
 | 
			
		||||
		}
 | 
			
		||||
		return nil, nil
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// AddResponseHeaders adds headers to the [http.Response]. Any existing headers with the same
 | 
			
		||||
// key will remain intact.
 | 
			
		||||
func AddResponseHeaders(h http.Header) ResponseFilter {
 | 
			
		||||
	return ResponseFilterFunc(func(ctx Context) *http.Response {
 | 
			
		||||
		if res := ctx.Response(); res != nil {
 | 
			
		||||
			if res.Header == nil {
 | 
			
		||||
				res.Header = make(http.Header)
 | 
			
		||||
			}
 | 
			
		||||
			addHeaders(res.Header, h)
 | 
			
		||||
		}
 | 
			
		||||
		return nil
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// SetResponseHeaders sets headers to the [http.Response]. Any existing headers with the same
 | 
			
		||||
// key will be removed.
 | 
			
		||||
func SetResponseHeaders(h http.Header) ResponseFilter {
 | 
			
		||||
	return ResponseFilterFunc(func(ctx Context) *http.Response {
 | 
			
		||||
		if res := ctx.Response(); res != nil {
 | 
			
		||||
			if res.Header == nil {
 | 
			
		||||
				res.Header = make(http.Header)
 | 
			
		||||
			}
 | 
			
		||||
			setHeaders(res.Header, h)
 | 
			
		||||
		}
 | 
			
		||||
		return nil
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// cleanProxyHeaders removes all headers commonly used by (reverse) HTTP proxies.
 | 
			
		||||
func cleanProxyHeaders(h http.Header) {
 | 
			
		||||
	if h == nil {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, k := range []string{
 | 
			
		||||
		HeaderForwarded,
 | 
			
		||||
		HeaderForwardedFor,
 | 
			
		||||
		HeaderForwardedHost,
 | 
			
		||||
		HeaderForwardedPort,
 | 
			
		||||
		HeaderForwardedProto,
 | 
			
		||||
		HeaderRealIP,
 | 
			
		||||
		HeaderVia,
 | 
			
		||||
	} {
 | 
			
		||||
		h.Del(k)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// addHeaders adds to the current existing headers.
 | 
			
		||||
func addHeaders(dst, src http.Header) {
 | 
			
		||||
	if src == nil {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for k, vv := range src {
 | 
			
		||||
		for _, v := range vv {
 | 
			
		||||
			dst.Add(k, v)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// setHeaders replaces all previous values.
 | 
			
		||||
func setHeaders(dst, src http.Header) {
 | 
			
		||||
	if src == nil {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for k, vv := range src {
 | 
			
		||||
		dst.Del(k)
 | 
			
		||||
		for _, v := range vv {
 | 
			
		||||
			dst.Add(k, v)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
@@ -1,324 +0,0 @@
 | 
			
		||||
package match
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"net"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"os"
 | 
			
		||||
	"regexp"
 | 
			
		||||
	"slices"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"git.maze.io/maze/styx/internal/log"
 | 
			
		||||
	"git.maze.io/maze/styx/internal/netutil"
 | 
			
		||||
	"github.com/hashicorp/hcl/v2"
 | 
			
		||||
	"github.com/hashicorp/hcl/v2/gohcl"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Config struct {
 | 
			
		||||
	Path    string        `hcl:"path,optional"`
 | 
			
		||||
	Refresh time.Duration `hcl:"refresh,optional"`
 | 
			
		||||
	Domain  []*Domain     `hcl:"domain,block"`
 | 
			
		||||
	Network []*Network    `hcl:"network,block"`
 | 
			
		||||
	Content []*Content    `hcl:"content,block"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (config Config) Matchers() (Matchers, error) {
 | 
			
		||||
	all := make(Matchers)
 | 
			
		||||
	if config.Domain != nil {
 | 
			
		||||
		all["domain"] = make(map[string]Matcher)
 | 
			
		||||
		for _, domain := range config.Domain {
 | 
			
		||||
			m, err := domain.Matcher()
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return nil, fmt.Errorf("matcher domain %q invalid: %w", domain.Name, err)
 | 
			
		||||
			}
 | 
			
		||||
			all["domain"][domain.Name] = m
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	if config.Network != nil {
 | 
			
		||||
		all["network"] = make(map[string]Matcher)
 | 
			
		||||
		for _, network := range config.Network {
 | 
			
		||||
			m, err := network.Matcher(true)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return nil, fmt.Errorf("matcher network %q invalid: %w", network.Name, err)
 | 
			
		||||
			}
 | 
			
		||||
			all["network"][network.Name] = m
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return all, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Content struct {
 | 
			
		||||
	Name string   `hcl:"name,label"`
 | 
			
		||||
	Type string   `hcl:"type"`
 | 
			
		||||
	Body hcl.Body `hcl:",remain"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type contentHeader struct {
 | 
			
		||||
	Key     string   `hcl:"name"`
 | 
			
		||||
	Value   string   `hcl:"value,optional"`
 | 
			
		||||
	List    []string `hcl:"list,optional"`
 | 
			
		||||
	name    string
 | 
			
		||||
	keyRe   *regexp.Regexp
 | 
			
		||||
	valueRe *regexp.Regexp
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m contentHeader) Name() string { return m.name }
 | 
			
		||||
func (m contentHeader) MatchesResponse(r *http.Response) bool {
 | 
			
		||||
	for k, vv := range r.Header {
 | 
			
		||||
		if m.keyRe.MatchString(k) {
 | 
			
		||||
			for _, v := range vv {
 | 
			
		||||
				if slices.Contains(m.List, v) {
 | 
			
		||||
					return true
 | 
			
		||||
				}
 | 
			
		||||
				if m.valueRe != nil && m.valueRe.MatchString(v) {
 | 
			
		||||
					return true
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return false
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type contentType struct {
 | 
			
		||||
	List []string `hcl:"list"`
 | 
			
		||||
	name string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m contentType) Name() string { return m.name }
 | 
			
		||||
func (m contentType) MatchesResponse(r *http.Response) bool {
 | 
			
		||||
	return slices.Contains(m.List, r.Header.Get("Content-Type"))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type contentSizeLargerThan struct {
 | 
			
		||||
	Size int64 `hcl:"size"`
 | 
			
		||||
	name string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m contentSizeLargerThan) Name() string { return m.name }
 | 
			
		||||
func (m contentSizeLargerThan) MatchesResponse(r *http.Response) bool {
 | 
			
		||||
	size, err := strconv.ParseInt(r.Header.Get("Content-Length"), 10, 64)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
	return size >= m.Size
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type contentStatus struct {
 | 
			
		||||
	Code []int `hcl:"code"`
 | 
			
		||||
	name string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m contentStatus) Name() string { return m.name }
 | 
			
		||||
func (m contentStatus) MatchesResponse(r *http.Response) bool {
 | 
			
		||||
	return slices.Contains(m.Code, r.StatusCode)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (config Content) Matcher() (Response, error) {
 | 
			
		||||
	switch strings.ToLower(config.Type) {
 | 
			
		||||
	case "content", "contenttype", "content-type", "type":
 | 
			
		||||
		var matcher = contentType{name: config.Name}
 | 
			
		||||
		if err := gohcl.DecodeBody(config.Body, nil, &matcher); err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		return matcher, nil
 | 
			
		||||
 | 
			
		||||
	case "header":
 | 
			
		||||
		var (
 | 
			
		||||
			matcher = contentHeader{name: config.Name}
 | 
			
		||||
			err     error
 | 
			
		||||
		)
 | 
			
		||||
		if err = gohcl.DecodeBody(config.Body, nil, &matcher); err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		if matcher.Value == "" && len(matcher.List) == 0 {
 | 
			
		||||
			return nil, fmt.Errorf("invalid content %q: must contain either list or value", config.Name)
 | 
			
		||||
		}
 | 
			
		||||
		if matcher.keyRe, err = regexp.Compile(matcher.Key); err != nil {
 | 
			
		||||
			return nil, fmt.Errorf("invalid regular expression on content %q key: %w", config.Name, err)
 | 
			
		||||
		}
 | 
			
		||||
		if matcher.Value != "" {
 | 
			
		||||
			if matcher.valueRe, err = regexp.Compile(matcher.Value); err != nil {
 | 
			
		||||
				return nil, fmt.Errorf("invalid regular expression on content %q value: %w", config.Name, err)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		return matcher, nil
 | 
			
		||||
 | 
			
		||||
	case "size":
 | 
			
		||||
		var matcher = contentSizeLargerThan{name: config.Name}
 | 
			
		||||
		if err := gohcl.DecodeBody(config.Body, nil, &matcher); err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		return matcher, nil
 | 
			
		||||
 | 
			
		||||
	case "status":
 | 
			
		||||
		var matcher = contentStatus{name: config.Name}
 | 
			
		||||
		if err := gohcl.DecodeBody(config.Body, nil, &matcher); err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		return matcher, nil
 | 
			
		||||
 | 
			
		||||
	default:
 | 
			
		||||
		return nil, fmt.Errorf("unknown content matcher type %q", config.Type)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Domain struct {
 | 
			
		||||
	Name string   `hcl:"name,label"`
 | 
			
		||||
	Type string   `hcl:"type"`
 | 
			
		||||
	Body hcl.Body `hcl:",remain"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (config Domain) Matcher() (Request, error) {
 | 
			
		||||
	switch config.Type {
 | 
			
		||||
	case "list":
 | 
			
		||||
		var matcher = domainList{Title: config.Name}
 | 
			
		||||
		if err := gohcl.DecodeBody(config.Body, nil, &matcher); err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		matcher.list = netutil.NewDomainList(matcher.List...)
 | 
			
		||||
		return matcher, nil
 | 
			
		||||
 | 
			
		||||
	case "adblock", "dnsmasq", "hosts", "detect", "domains":
 | 
			
		||||
		var matcher = DomainFile{
 | 
			
		||||
			Title: config.Name,
 | 
			
		||||
			Type:  config.Type,
 | 
			
		||||
		}
 | 
			
		||||
		if err := gohcl.DecodeBody(config.Body, nil, &matcher); err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		if matcher.Path == "" && matcher.From == "" {
 | 
			
		||||
			return nil, fmt.Errorf("matcher: domain %q must have either file or from configured", config.Name)
 | 
			
		||||
		}
 | 
			
		||||
		if err := matcher.Update(); err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		return matcher, nil
 | 
			
		||||
 | 
			
		||||
	default:
 | 
			
		||||
		return nil, fmt.Errorf("unknown domain matcher type %q", config.Type)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type domainList struct {
 | 
			
		||||
	Title string   `json:"title"`
 | 
			
		||||
	List  []string `hcl:"list" json:"list"`
 | 
			
		||||
	list  *netutil.DomainTree
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m domainList) Name() string {
 | 
			
		||||
	return m.Title
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m domainList) MatchesRequest(r *http.Request) bool {
 | 
			
		||||
	host := netutil.Host(r.URL.Host)
 | 
			
		||||
	log.Debug().Str("host", host).Msgf("match domain list (%d domains)", len(m.List))
 | 
			
		||||
	return m.list.Contains(host)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type DomainFile struct {
 | 
			
		||||
	Title   string        `json:"name"`
 | 
			
		||||
	Type    string        `json:"type"`
 | 
			
		||||
	Path    string        `hcl:"path,optional" json:"path,omitempty"`
 | 
			
		||||
	From    string        `hcl:"from,optional" json:"from,omitempty"`
 | 
			
		||||
	Refresh time.Duration `hcl:"refresh,optional" json:"refresh"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m DomainFile) Name() string {
 | 
			
		||||
	return m.Title
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m DomainFile) MatchesRequest(_ *http.Request) bool {
 | 
			
		||||
	return false
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m *DomainFile) Update() (err error) {
 | 
			
		||||
	var data []byte
 | 
			
		||||
	if m.Path != "" {
 | 
			
		||||
		if data, err = os.ReadFile(m.Path); err != nil {
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
	} else {
 | 
			
		||||
		/*
 | 
			
		||||
			var response *http.Response
 | 
			
		||||
			if response, err = http.DefaultClient.Get(m.From); err != nil {
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
			defer func() { _ = response.Body.Close() }()
 | 
			
		||||
			if response.StatusCode != http.StatusOK {
 | 
			
		||||
				return fmt.Errorf("match: domain %q update failed: %s", m.name, response.Status)
 | 
			
		||||
			}
 | 
			
		||||
			if data, err = io.ReadAll(response.Body); err != nil {
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
		*/
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	switch m.Type {
 | 
			
		||||
	case "hosts":
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	_ = data
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Network struct {
 | 
			
		||||
	Name string   `hcl:"name,label"`
 | 
			
		||||
	Type string   `hcl:"type"`
 | 
			
		||||
	Body hcl.Body `hcl:",remain"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (config *Network) Matcher(target bool) (Matcher, error) {
 | 
			
		||||
	switch config.Type {
 | 
			
		||||
	case "list":
 | 
			
		||||
		var (
 | 
			
		||||
			matcher = networkList{Title: config.Name}
 | 
			
		||||
			err     error
 | 
			
		||||
		)
 | 
			
		||||
		if diag := gohcl.DecodeBody(config.Body, nil, &matcher); diag.HasErrors() {
 | 
			
		||||
			return nil, diag
 | 
			
		||||
		}
 | 
			
		||||
		if matcher.tree, err = netutil.NewNetworkTree(matcher.List...); err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		return &matcher, nil
 | 
			
		||||
 | 
			
		||||
	default:
 | 
			
		||||
		return nil, fmt.Errorf("unknown network matcher type %q", config.Type)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type networkList struct {
 | 
			
		||||
	Title  string   `json:"name"`
 | 
			
		||||
	List   []string `hcl:"list" json:"list"`
 | 
			
		||||
	tree   *netutil.NetworkTree
 | 
			
		||||
	target bool
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m *networkList) Name() string {
 | 
			
		||||
	return m.Title
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m *networkList) MatchesIP(ip net.IP) bool {
 | 
			
		||||
	return m.tree.Contains(ip)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m *networkList) MatchesRequest(r *http.Request) bool {
 | 
			
		||||
	var (
 | 
			
		||||
		host string
 | 
			
		||||
		err  error
 | 
			
		||||
	)
 | 
			
		||||
	if m.target {
 | 
			
		||||
		host, _, err = net.SplitHostPort(r.URL.Host)
 | 
			
		||||
	} else {
 | 
			
		||||
		host, _, err = net.SplitHostPort(r.RemoteAddr)
 | 
			
		||||
	}
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
	ip := net.ParseIP(host)
 | 
			
		||||
	return m.MatchesIP(ip)
 | 
			
		||||
}
 | 
			
		||||
@@ -1,45 +0,0 @@
 | 
			
		||||
package match
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"net"
 | 
			
		||||
	"net/http"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Matchers map[string]map[string]Matcher
 | 
			
		||||
 | 
			
		||||
func (all Matchers) Get(kind, name string) (m Matcher, err error) {
 | 
			
		||||
	if typeMatchers, ok := all[kind]; ok {
 | 
			
		||||
		if m, ok = typeMatchers[name]; ok {
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		return nil, fmt.Errorf("no %s matcher named %q found", kind, name)
 | 
			
		||||
	}
 | 
			
		||||
	return nil, fmt.Errorf("no %s matcher found", kind)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Matcher interface {
 | 
			
		||||
	Name() string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Updater interface {
 | 
			
		||||
	Update() error
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type IP interface {
 | 
			
		||||
	Matcher
 | 
			
		||||
 | 
			
		||||
	MatchesIP(net.IP) bool
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Request interface {
 | 
			
		||||
	Matcher
 | 
			
		||||
 | 
			
		||||
	MatchesRequest(*http.Request) bool
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Response interface {
 | 
			
		||||
	Matcher
 | 
			
		||||
 | 
			
		||||
	MatchesResponse(*http.Response) bool
 | 
			
		||||
}
 | 
			
		||||
@@ -1,11 +0,0 @@
 | 
			
		||||
package match
 | 
			
		||||
 | 
			
		||||
import "net"
 | 
			
		||||
 | 
			
		||||
func onlyHost(name string) string {
 | 
			
		||||
	host, _, err := net.SplitHostPort(name)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return name
 | 
			
		||||
	}
 | 
			
		||||
	return host
 | 
			
		||||
}
 | 
			
		||||
@@ -1,231 +0,0 @@
 | 
			
		||||
package mitm
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"crypto"
 | 
			
		||||
	"crypto/rand"
 | 
			
		||||
	"crypto/tls"
 | 
			
		||||
	"crypto/x509"
 | 
			
		||||
	"crypto/x509/pkix"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"math/big"
 | 
			
		||||
	"os"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"git.maze.io/maze/styx/internal/cryptutil"
 | 
			
		||||
	"git.maze.io/maze/styx/internal/log"
 | 
			
		||||
	"github.com/miekg/dns"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const DefaultValidity = 24 * time.Hour
 | 
			
		||||
 | 
			
		||||
type Authority interface {
 | 
			
		||||
	Certificate() *x509.Certificate
 | 
			
		||||
	TLSConfig(name string) *tls.Config
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type authority struct {
 | 
			
		||||
	pool    *x509.CertPool
 | 
			
		||||
	cert    *x509.Certificate
 | 
			
		||||
	key     crypto.PrivateKey
 | 
			
		||||
	keyID   []byte
 | 
			
		||||
	keyPool chan crypto.PrivateKey
 | 
			
		||||
	cache   Cache
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func New(config *Config) (Authority, error) {
 | 
			
		||||
	cache, err := NewCache(config.Cache)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	caConfig := config.CA
 | 
			
		||||
	if caConfig == nil {
 | 
			
		||||
		caConfig = new(CAConfig)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	cert, key, err := cryptutil.LoadKeyPair(caConfig.Cert, caConfig.Key)
 | 
			
		||||
	if os.IsNotExist(err) {
 | 
			
		||||
		days := caConfig.Days
 | 
			
		||||
		if days == 0 {
 | 
			
		||||
			days = DefaultDays
 | 
			
		||||
		}
 | 
			
		||||
		if cert, key, err = cryptutil.GenerateKeyPair(caConfig.DN(), days, caConfig.KeyType, caConfig.Bits); err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		if strings.ContainsRune(caConfig.Cert, os.PathSeparator) {
 | 
			
		||||
			if err = cryptutil.SaveKeyPair(cert, key, caConfig.Cert, caConfig.Key); err != nil {
 | 
			
		||||
				return nil, err
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	} else if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	pool := x509.NewCertPool()
 | 
			
		||||
	pool.AddCert(cert)
 | 
			
		||||
 | 
			
		||||
	keyConfig := config.Key
 | 
			
		||||
	if keyConfig == nil {
 | 
			
		||||
		keyConfig = &defaultKeyConfig
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	keyPoolSize := defaultKeyConfig.Pool
 | 
			
		||||
	if keyConfig.Pool > 0 {
 | 
			
		||||
		keyPoolSize = keyConfig.Pool
 | 
			
		||||
	}
 | 
			
		||||
	keyPool := make(chan crypto.PrivateKey, keyPoolSize)
 | 
			
		||||
	if key, err := cryptutil.GeneratePrivateKey(keyConfig.Type, keyConfig.Bits); err != nil {
 | 
			
		||||
		return nil, fmt.Errorf("mitm: invalid key configuration: %w", err)
 | 
			
		||||
	} else {
 | 
			
		||||
		keyPool <- key
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	go func(pool chan<- crypto.PrivateKey) {
 | 
			
		||||
		for {
 | 
			
		||||
			key, err := cryptutil.GeneratePrivateKey(keyConfig.Type, keyConfig.Bits)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				log.Panic().Err(err).Msg("error generating private key")
 | 
			
		||||
			}
 | 
			
		||||
			pool <- key
 | 
			
		||||
		}
 | 
			
		||||
	}(keyPool)
 | 
			
		||||
 | 
			
		||||
	return &authority{
 | 
			
		||||
		pool:    pool,
 | 
			
		||||
		cert:    cert,
 | 
			
		||||
		key:     key,
 | 
			
		||||
		keyID:   cryptutil.GenerateKeyID(cryptutil.PublicKey(key)),
 | 
			
		||||
		keyPool: keyPool,
 | 
			
		||||
		cache:   cache,
 | 
			
		||||
	}, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ca *authority) log() log.Logger {
 | 
			
		||||
	return log.Console.With().
 | 
			
		||||
		Str("ca", ca.cert.Subject.String()).
 | 
			
		||||
		Logger()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ca *authority) Certificate() *x509.Certificate {
 | 
			
		||||
	return ca.cert
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ca *authority) TLSConfig(name string) *tls.Config {
 | 
			
		||||
	return &tls.Config{
 | 
			
		||||
		GetCertificate: func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
 | 
			
		||||
			log := ca.log()
 | 
			
		||||
			if hello.ServerName != "" {
 | 
			
		||||
				name = strings.ToLower(hello.ServerName)
 | 
			
		||||
				log.Debug().Msg("requesting certificate for server name (SNI)")
 | 
			
		||||
			} else {
 | 
			
		||||
				log.Debug().Msg("requesting certificate for hostname")
 | 
			
		||||
			}
 | 
			
		||||
			if cert, ok := ca.getCached(name); ok {
 | 
			
		||||
				log.Debug().
 | 
			
		||||
					Str("subject", cert.Leaf.Subject.String()).
 | 
			
		||||
					Str("serial", cert.Leaf.SerialNumber.String()).
 | 
			
		||||
					Time("valid", cert.Leaf.NotAfter).
 | 
			
		||||
					Msg("using cached certificate")
 | 
			
		||||
				return cert, nil
 | 
			
		||||
			}
 | 
			
		||||
			return ca.issueFor(name)
 | 
			
		||||
		},
 | 
			
		||||
		NextProtos: []string{"http/1.1"},
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ca *authority) getCached(name string) (cert *tls.Certificate, ok bool) {
 | 
			
		||||
	log := ca.log()
 | 
			
		||||
 | 
			
		||||
	if cert = ca.cache.Certificate(name); cert == nil {
 | 
			
		||||
		if baseDomain(name) != name {
 | 
			
		||||
			cert = ca.cache.Certificate(baseDomain(name))
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	if cert != nil {
 | 
			
		||||
		if _, err := cert.Leaf.Verify(x509.VerifyOptions{
 | 
			
		||||
			DNSName: name,
 | 
			
		||||
			Roots:   ca.pool,
 | 
			
		||||
		}); err != nil {
 | 
			
		||||
			log.Debug().Err(err).Str("name", name).Msg("deleting invalid certificate from cache")
 | 
			
		||||
		} else {
 | 
			
		||||
			ok = true
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ca *authority) issueFor(name string) (*tls.Certificate, error) {
 | 
			
		||||
	var (
 | 
			
		||||
		log = ca.log().With().Str("name", name).Logger()
 | 
			
		||||
		key crypto.PrivateKey
 | 
			
		||||
	)
 | 
			
		||||
	select {
 | 
			
		||||
	case key = <-ca.keyPool:
 | 
			
		||||
	case <-time.After(5 * time.Second):
 | 
			
		||||
		return nil, errors.New("mitm: timeout waiting for private key generator to catch up")
 | 
			
		||||
	}
 | 
			
		||||
	if key == nil {
 | 
			
		||||
		panic("key pool returned nil key")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
 | 
			
		||||
	serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, fmt.Errorf("mtim: failed to generate serial number: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if part := dns.SplitDomainName(name); len(part) > 2 {
 | 
			
		||||
		name = strings.Join(part[1:], ".")
 | 
			
		||||
		log.Debug().Msgf("abbreviated name to %s (*.%s)", name, name)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	now := time.Now()
 | 
			
		||||
	template := &x509.Certificate{
 | 
			
		||||
		SerialNumber:          serialNumber,
 | 
			
		||||
		Subject:               pkix.Name{CommonName: name},
 | 
			
		||||
		KeyUsage:              x509.KeyUsageDigitalSignature | x509.KeyUsageKeyAgreement,
 | 
			
		||||
		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
 | 
			
		||||
		DNSNames:              []string{name, "*." + name},
 | 
			
		||||
		BasicConstraintsValid: true,
 | 
			
		||||
		NotBefore:             now.Add(-DefaultValidity),
 | 
			
		||||
		NotAfter:              now.Add(+DefaultValidity),
 | 
			
		||||
	}
 | 
			
		||||
	der, err := x509.CreateCertificate(rand.Reader, template, ca.cert, cryptutil.PublicKey(key), ca.key)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	cert, err := x509.ParseCertificate(der)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	log.Debug().Str("serial", serialNumber.String()).Msg("generated certificate")
 | 
			
		||||
	out := &tls.Certificate{
 | 
			
		||||
		Certificate: [][]byte{der},
 | 
			
		||||
		Leaf:        cert,
 | 
			
		||||
		PrivateKey:  key,
 | 
			
		||||
	}
 | 
			
		||||
	//ca.cache[name] = out
 | 
			
		||||
	ca.cache.SaveCertificate(name, out)
 | 
			
		||||
	return out, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func containsValidCertificate(cert *tls.Certificate) bool {
 | 
			
		||||
	if cert == nil || len(cert.Certificate) == 0 {
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if cert.Leaf == nil {
 | 
			
		||||
		var err error
 | 
			
		||||
		if cert.Leaf, err = x509.ParseCertificate(cert.Certificate[0]); err != nil {
 | 
			
		||||
			return false
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	now := time.Now()
 | 
			
		||||
 | 
			
		||||
	return !(cert.Leaf.NotBefore.Before(now) || cert.Leaf.NotAfter.After(now))
 | 
			
		||||
}
 | 
			
		||||
@@ -1,233 +0,0 @@
 | 
			
		||||
package mitm
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"crypto/tls"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"io/fs"
 | 
			
		||||
	"os"
 | 
			
		||||
	"path/filepath"
 | 
			
		||||
	"slices"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/hashicorp/golang-lru/v2/expirable"
 | 
			
		||||
	"github.com/hashicorp/hcl/v2/gohcl"
 | 
			
		||||
	"github.com/miekg/dns"
 | 
			
		||||
 | 
			
		||||
	"git.maze.io/maze/styx/internal/cryptutil"
 | 
			
		||||
	"git.maze.io/maze/styx/internal/log"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Cache interface {
 | 
			
		||||
	Certificate(name string) *tls.Certificate
 | 
			
		||||
	SaveCertificate(name string, cert *tls.Certificate) error
 | 
			
		||||
	RemoveCertificate(name string)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewCache(config *CacheConfig) (Cache, error) {
 | 
			
		||||
	if config == nil {
 | 
			
		||||
		return NewCache(&CacheConfig{Type: "memory"})
 | 
			
		||||
	}
 | 
			
		||||
	switch config.Type {
 | 
			
		||||
	case "memory":
 | 
			
		||||
		var cacheConfig = new(MemoryCacheConfig)
 | 
			
		||||
		if err := gohcl.DecodeBody(config.Body, nil, cacheConfig); err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		return NewMemoryCache(cacheConfig.Size), nil
 | 
			
		||||
	case "disk":
 | 
			
		||||
		var cacheConfig = new(DiskCacheConfig)
 | 
			
		||||
		if err := gohcl.DecodeBody(config.Body, nil, cacheConfig); err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		return NewDiskCache(cacheConfig.Path, time.Duration(cacheConfig.Expire*float64(time.Second)))
 | 
			
		||||
	default:
 | 
			
		||||
		return nil, fmt.Errorf("mitm: cache type %q is not supported", config.Type)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type memoryCache struct {
 | 
			
		||||
	cache *expirable.LRU[string, *tls.Certificate]
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewMemoryCache(size int) Cache {
 | 
			
		||||
	return memoryCache{
 | 
			
		||||
		cache: expirable.NewLRU(size, func(key string, value *tls.Certificate) {
 | 
			
		||||
			log.Debug().Str("name", key).Msg("certificate evicted from cache")
 | 
			
		||||
		}, time.Hour*24),
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c memoryCache) Certificate(name string) (cert *tls.Certificate) {
 | 
			
		||||
	var ok bool
 | 
			
		||||
	if cert, ok = c.cache.Get(name); !ok {
 | 
			
		||||
		cert, _ = c.cache.Get(baseDomain(name))
 | 
			
		||||
	}
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c memoryCache) SaveCertificate(name string, cert *tls.Certificate) error {
 | 
			
		||||
	c.cache.Add(name, cert)
 | 
			
		||||
	log.Debug().Str("name", name).Msg("certificate added to cache")
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c memoryCache) RemoveCertificate(name string) {
 | 
			
		||||
	c.cache.Remove(name)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type diskCache string
 | 
			
		||||
 | 
			
		||||
func NewDiskCache(dir string, expire time.Duration) (Cache, error) {
 | 
			
		||||
	if !filepath.IsAbs(dir) {
 | 
			
		||||
		var err error
 | 
			
		||||
		if dir, err = filepath.Abs(dir); err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	if err := os.MkdirAll(dir, 0o750); err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	info, err := os.Stat(dir)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	if info.Mode()&os.ModePerm|0o057 != 0 {
 | 
			
		||||
		if err := os.Chmod(dir, 0o750); err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if expire > 0 {
 | 
			
		||||
		go expireDiskCache(dir, expire)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return diskCache(dir), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func expireDiskCache(root string, expire time.Duration) {
 | 
			
		||||
	log.Debug().Str("path", root).Dur("expire", expire).Msg("disk cache expire loop starting")
 | 
			
		||||
	ticker := time.NewTicker(expire)
 | 
			
		||||
	defer ticker.Stop()
 | 
			
		||||
	for {
 | 
			
		||||
		now := <-ticker.C
 | 
			
		||||
		log.Debug().Str("path", root).Dur("expire", expire).Msg("expire disk cache")
 | 
			
		||||
		filepath.WalkDir(root, func(path string, d fs.DirEntry, err error) error {
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return err
 | 
			
		||||
			}
 | 
			
		||||
			if d.IsDir() {
 | 
			
		||||
				// Remove the directory; this will fail if it's not empty, which is fine.
 | 
			
		||||
				_ = os.Remove(path)
 | 
			
		||||
				return nil
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			cert, err := cryptutil.LoadCertificate(path)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				log.Debug().Str("path", path).Err(err).Msg("expire removing invalid certificate file")
 | 
			
		||||
				_ = os.Remove(path)
 | 
			
		||||
				return nil
 | 
			
		||||
			} else if cert.NotAfter.Before(now) {
 | 
			
		||||
				log.Debug().Str("path", path).Dur("expired", now.Sub(cert.NotAfter)).Msg("expire removing expired certificate")
 | 
			
		||||
				_ = os.Remove(path)
 | 
			
		||||
				return nil
 | 
			
		||||
			}
 | 
			
		||||
			return nil
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c diskCache) path(name string) string {
 | 
			
		||||
	part := dns.SplitDomainName(strings.ToLower(name))
 | 
			
		||||
	// x,com -> com,x
 | 
			
		||||
	// www,maze,io -> io,maze,www
 | 
			
		||||
	slices.Reverse(part)
 | 
			
		||||
	// com,x -> com,x,x.com
 | 
			
		||||
	// io,maze,www -> io,m,ma,maze,www.maze.io
 | 
			
		||||
	if len(part) > 2 {
 | 
			
		||||
		if len(part[1]) > 1 {
 | 
			
		||||
			part = []string{
 | 
			
		||||
				part[0],
 | 
			
		||||
				part[1][:1],
 | 
			
		||||
				part[1][:2],
 | 
			
		||||
				part[1],
 | 
			
		||||
				name,
 | 
			
		||||
			}
 | 
			
		||||
		} else {
 | 
			
		||||
			part = []string{
 | 
			
		||||
				part[0],
 | 
			
		||||
				part[1][:1],
 | 
			
		||||
				part[1],
 | 
			
		||||
				name,
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	} else if len(part) > 1 {
 | 
			
		||||
		if len(part[1]) > 1 {
 | 
			
		||||
			part = []string{
 | 
			
		||||
				part[0],
 | 
			
		||||
				part[1][:1],
 | 
			
		||||
				part[1][:2],
 | 
			
		||||
				name,
 | 
			
		||||
			}
 | 
			
		||||
		} else {
 | 
			
		||||
			part = []string{
 | 
			
		||||
				part[0],
 | 
			
		||||
				part[1][:1],
 | 
			
		||||
				name,
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	part[len(part)-1] += ".crt"
 | 
			
		||||
	return filepath.Join(append([]string{string(c)}, part...)...)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c diskCache) Certificate(name string) (cert *tls.Certificate) {
 | 
			
		||||
	if cert, key, err := cryptutil.LoadKeyPair(c.path(name), ""); err == nil {
 | 
			
		||||
		return &tls.Certificate{
 | 
			
		||||
			Certificate: [][]byte{cert.Raw},
 | 
			
		||||
			Leaf:        cert,
 | 
			
		||||
			PrivateKey:  key,
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	if cert, key, err := cryptutil.LoadKeyPair(c.path(baseDomain(name)), ""); err == nil {
 | 
			
		||||
		return &tls.Certificate{
 | 
			
		||||
			Certificate: [][]byte{cert.Raw},
 | 
			
		||||
			Leaf:        cert,
 | 
			
		||||
			PrivateKey:  key,
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	log.Debug().Str("path", string(c)).Str("name", name).Msg("cache miss")
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c diskCache) SaveCertificate(name string, cert *tls.Certificate) error {
 | 
			
		||||
	dir, name := filepath.Split(c.path(name))
 | 
			
		||||
	if err := os.MkdirAll(dir, 0o750); err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	if err := cryptutil.SaveKeyPair(cert.Leaf, cert.PrivateKey, filepath.Join(dir, name), ""); err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	log.Debug().Str("name", name).Msg("certificate added to cache")
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c diskCache) RemoveCertificate(name string) {
 | 
			
		||||
	path := c.path(name)
 | 
			
		||||
	if err := os.Remove(path); err != nil {
 | 
			
		||||
		if os.IsNotExist(err) {
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		log.Error().Err(err).Msg("certificate remove from cache failed")
 | 
			
		||||
	}
 | 
			
		||||
	_ = os.Remove(filepath.Dir(path))
 | 
			
		||||
	log.Debug().Str("name", name).Msg("certificate removed from cache")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func baseDomain(name string) string {
 | 
			
		||||
	name = strings.ToLower(name)
 | 
			
		||||
	if part := dns.SplitDomainName(name); len(part) > 2 {
 | 
			
		||||
		return strings.Join(part[1:], ".")
 | 
			
		||||
	}
 | 
			
		||||
	return name
 | 
			
		||||
}
 | 
			
		||||
@@ -1,25 +0,0 @@
 | 
			
		||||
package mitm
 | 
			
		||||
 | 
			
		||||
import "testing"
 | 
			
		||||
 | 
			
		||||
func TestDiskCachePath(t *testing.T) {
 | 
			
		||||
	cache := diskCache("testdata")
 | 
			
		||||
	tests := []struct {
 | 
			
		||||
		test string
 | 
			
		||||
		want string
 | 
			
		||||
	}{
 | 
			
		||||
		{"x.com", "testdata/com/x/x.com.crt"},
 | 
			
		||||
		{"feed.x.com", "testdata/com/x/x/feed.x.com.crt"},
 | 
			
		||||
		{"nu.nl", "testdata/nl/n/nu/nu.nl.crt"},
 | 
			
		||||
		{"maze.io", "testdata/io/m/ma/maze.io.crt"},
 | 
			
		||||
		{"lab.maze.io", "testdata/io/m/ma/maze/lab.maze.io.crt"},
 | 
			
		||||
		{"dev.lab.maze.io", "testdata/io/m/ma/maze/dev.lab.maze.io.crt"},
 | 
			
		||||
	}
 | 
			
		||||
	for _, test := range tests {
 | 
			
		||||
		t.Run(test.test, func(it *testing.T) {
 | 
			
		||||
			if v := cache.path(test.test); v != test.want {
 | 
			
		||||
				it.Errorf("expected %q to resolve to %q, got %q", test.test, test.want, v)
 | 
			
		||||
			}
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
@@ -1,89 +0,0 @@
 | 
			
		||||
package mitm
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"crypto/x509/pkix"
 | 
			
		||||
 | 
			
		||||
	"github.com/hashicorp/hcl/v2"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	DefaultCommonName = "Styx Certificate Authority"
 | 
			
		||||
	DefaultDays       = 3
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Config struct {
 | 
			
		||||
	CA    *CAConfig    `hcl:"ca,block"`
 | 
			
		||||
	Key   *KeyConfig   `hcl:"key,block"`
 | 
			
		||||
	Cache *CacheConfig `hcl:"cache,block"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type CAConfig struct {
 | 
			
		||||
	Cert         string   `hcl:"cert"`
 | 
			
		||||
	Key          string   `hcl:"key,optional"`
 | 
			
		||||
	Days         int      `hcl:"days,optional"`
 | 
			
		||||
	KeyType      string   `hcl:"key_type,optional"`
 | 
			
		||||
	Bits         int      `hcl:"bits,optional"`
 | 
			
		||||
	Name         string   `hcl:"name,optional"`
 | 
			
		||||
	Country      string   `hcl:"country,optional"`
 | 
			
		||||
	Organization string   `hcl:"organization,optional"`
 | 
			
		||||
	Unit         string   `hcl:"unit,optional"`
 | 
			
		||||
	Locality     string   `hcl:"locality,optional"`
 | 
			
		||||
	Province     string   `hcl:"province,optional"`
 | 
			
		||||
	Address      []string `hcl:"address,optional"`
 | 
			
		||||
	PostalCode   string   `hcl:"postal_code,optional"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (config CAConfig) DN() pkix.Name {
 | 
			
		||||
	var name = pkix.Name{
 | 
			
		||||
		CommonName:    config.Name,
 | 
			
		||||
		StreetAddress: config.Address,
 | 
			
		||||
	}
 | 
			
		||||
	if config.Name == "" {
 | 
			
		||||
		name.CommonName = DefaultCommonName
 | 
			
		||||
	}
 | 
			
		||||
	if config.Country != "" {
 | 
			
		||||
		name.Country = append(name.Country, config.Country)
 | 
			
		||||
	}
 | 
			
		||||
	if config.Organization != "" {
 | 
			
		||||
		name.Organization = append(name.Organization, config.Organization)
 | 
			
		||||
	}
 | 
			
		||||
	if config.Unit != "" {
 | 
			
		||||
		name.OrganizationalUnit = append(name.OrganizationalUnit, config.Unit)
 | 
			
		||||
	}
 | 
			
		||||
	if config.Locality != "" {
 | 
			
		||||
		name.Locality = append(name.Locality, config.Locality)
 | 
			
		||||
	}
 | 
			
		||||
	if config.Province != "" {
 | 
			
		||||
		name.Province = append(name.Province, config.Province)
 | 
			
		||||
	}
 | 
			
		||||
	if config.PostalCode != "" {
 | 
			
		||||
		name.PostalCode = append(name.PostalCode, config.PostalCode)
 | 
			
		||||
	}
 | 
			
		||||
	return name
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type KeyConfig struct {
 | 
			
		||||
	Type string `hcl:"type,optional"`
 | 
			
		||||
	Bits int    `hcl:"bits,optional"`
 | 
			
		||||
	Pool int    `hcl:"pool,optional"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var defaultKeyConfig = KeyConfig{
 | 
			
		||||
	Type: "rsa",
 | 
			
		||||
	Bits: 2048,
 | 
			
		||||
	Pool: 5,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type CacheConfig struct {
 | 
			
		||||
	Type string   `hcl:"type"`
 | 
			
		||||
	Body hcl.Body `hcl:",remain"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type MemoryCacheConfig struct {
 | 
			
		||||
	Size int `hcl:"size,optional"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type DiskCacheConfig struct {
 | 
			
		||||
	Path   string  `hcl:"path"`
 | 
			
		||||
	Expire float64 `hcl:"expire,optional"`
 | 
			
		||||
}
 | 
			
		||||
@@ -1,53 +0,0 @@
 | 
			
		||||
package policy
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"net/http"
 | 
			
		||||
 | 
			
		||||
	"git.maze.io/maze/styx/proxy/match"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// Policy contains rules that make up the policy.
 | 
			
		||||
//
 | 
			
		||||
// Some policy rules contain nested policies.
 | 
			
		||||
type Policy struct {
 | 
			
		||||
	Rules    []*rawRule     `hcl:"on,block" json:"rules"`
 | 
			
		||||
	Permit   *bool          `hcl:"permit" json:"permit"`
 | 
			
		||||
	Matchers match.Matchers `json:"matchers"` // Matchers for the policy
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (p *Policy) Configure(matchers match.Matchers) (err error) {
 | 
			
		||||
	for _, r := range p.Rules {
 | 
			
		||||
		if err = r.Configure(matchers); err != nil {
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	p.Matchers = matchers
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (p *Policy) PermitIntercept(r *http.Request) *bool {
 | 
			
		||||
	if p != nil {
 | 
			
		||||
		for _, rule := range p.Rules {
 | 
			
		||||
			if rule, ok := rule.Rule.(InterceptRule); ok {
 | 
			
		||||
				if permit := rule.PermitIntercept(r); permit != nil {
 | 
			
		||||
					return permit
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return p.Permit
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (p *Policy) PermitRequest(r *http.Request) *bool {
 | 
			
		||||
	if p != nil {
 | 
			
		||||
		for _, rule := range p.Rules {
 | 
			
		||||
			if rule, ok := rule.Rule.(RequestRule); ok {
 | 
			
		||||
				if permit := rule.PermitRequest(r); permit != nil {
 | 
			
		||||
					return permit
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return p.Permit
 | 
			
		||||
}
 | 
			
		||||
@@ -1,139 +0,0 @@
 | 
			
		||||
package policy
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"net"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"net/url"
 | 
			
		||||
	"testing"
 | 
			
		||||
 | 
			
		||||
	"git.maze.io/maze/styx/internal/netutil"
 | 
			
		||||
	"git.maze.io/maze/styx/proxy/match"
 | 
			
		||||
	"github.com/miekg/dns"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type testInDomainList struct {
 | 
			
		||||
	t    *testing.T
 | 
			
		||||
	list []string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (testInDomainList) Name() string { return "testInDomainList" }
 | 
			
		||||
func (l testInDomainList) MatchesRequest(r *http.Request) bool {
 | 
			
		||||
	for _, domain := range l.list {
 | 
			
		||||
		if dns.IsSubDomain(domain, netutil.Host(r.URL.Host)) {
 | 
			
		||||
			l.t.Logf("domain %s contains %s", domain, r.URL.Host)
 | 
			
		||||
			return true
 | 
			
		||||
		}
 | 
			
		||||
		l.t.Logf("domain %s does not contain %s", domain, r.URL.Host)
 | 
			
		||||
	}
 | 
			
		||||
	return false
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func testInDomain(t *testing.T, domains ...string) match.Matcher {
 | 
			
		||||
	return &testInDomainList{t: t, list: domains}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type testInNetworkList struct {
 | 
			
		||||
	t    *testing.T
 | 
			
		||||
	list []*net.IPNet
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (testInNetworkList) Name() string { return "testInNetworkList" }
 | 
			
		||||
func (l testInNetworkList) MatchesIP(ip net.IP) bool {
 | 
			
		||||
	for _, ipnet := range l.list {
 | 
			
		||||
		if ipnet.Contains(ip) {
 | 
			
		||||
			l.t.Logf("network %s contains %s", ipnet, ip)
 | 
			
		||||
			return true
 | 
			
		||||
		}
 | 
			
		||||
		l.t.Logf("network %s does not contain %s", ipnet, ip)
 | 
			
		||||
	}
 | 
			
		||||
	return false
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func testInNetwork(t *testing.T, cidr string) match.Matcher {
 | 
			
		||||
	t.Helper()
 | 
			
		||||
	_, ipnet, err := net.ParseCIDR(cidr)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		panic(err)
 | 
			
		||||
	}
 | 
			
		||||
	return testInNetworkList{t: t, list: []*net.IPNet{ipnet}}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestPolicy(t *testing.T) {
 | 
			
		||||
	var (
 | 
			
		||||
		yes  = true
 | 
			
		||||
		nope = false
 | 
			
		||||
	)
 | 
			
		||||
	p := &Policy{
 | 
			
		||||
		Rules: []*rawRule{
 | 
			
		||||
			{
 | 
			
		||||
				Rule: &requestRule{
 | 
			
		||||
					domainOrNetworkRule: domainOrNetworkRule{
 | 
			
		||||
						matchers: []match.Matcher{testInNetwork(t, "127.0.0.0/8")},
 | 
			
		||||
						isSource: []bool{true},
 | 
			
		||||
					},
 | 
			
		||||
				},
 | 
			
		||||
			},
 | 
			
		||||
			{
 | 
			
		||||
				Rule: &requestRule{
 | 
			
		||||
					domainOrNetworkRule: domainOrNetworkRule{
 | 
			
		||||
						matchers: []match.Matcher{testInNetwork(t, "127.0.0.0/8")},
 | 
			
		||||
						isSource: []bool{false},
 | 
			
		||||
					},
 | 
			
		||||
					Permit: &yes,
 | 
			
		||||
				},
 | 
			
		||||
			},
 | 
			
		||||
			{
 | 
			
		||||
				Rule: &requestRule{
 | 
			
		||||
					domainOrNetworkRule: domainOrNetworkRule{
 | 
			
		||||
						matchers: []match.Matcher{testInDomain(t, "maze.io", "maze.engineering")},
 | 
			
		||||
					},
 | 
			
		||||
					Permit: &yes,
 | 
			
		||||
				},
 | 
			
		||||
			},
 | 
			
		||||
			{
 | 
			
		||||
				Rule: &requestRule{
 | 
			
		||||
					domainOrNetworkRule: domainOrNetworkRule{
 | 
			
		||||
						matchers: []match.Matcher{testInDomain(t, "google.com")},
 | 
			
		||||
					},
 | 
			
		||||
					Permit: &nope,
 | 
			
		||||
				},
 | 
			
		||||
			},
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	r := &http.Request{
 | 
			
		||||
		URL:        &url.URL{Scheme: "http", Host: "golang.org:80"},
 | 
			
		||||
		RemoteAddr: "127.0.0.1:1234",
 | 
			
		||||
	}
 | 
			
		||||
	if v := p.PermitRequest(r); v != nil {
 | 
			
		||||
		t.Errorf("expected request to return no verdict, got %t", *v)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	p.Rules[0].Rule.(*requestRule).Permit = &yes
 | 
			
		||||
	if v := p.PermitRequest(r); v == nil || *v != yes {
 | 
			
		||||
		t.Errorf("expected request to return %t, %v", yes, v)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	r.RemoteAddr = "192.168.1.2:3456"
 | 
			
		||||
	if v := p.PermitRequest(r); v != nil {
 | 
			
		||||
		t.Errorf("expected request to return no verdict, got %t", *v)
 | 
			
		||||
	}
 | 
			
		||||
	if v := p.PermitIntercept(r); v != nil {
 | 
			
		||||
		t.Errorf("expected request to return no verdict, got %t", *v)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	r.URL.Host = "maze.io"
 | 
			
		||||
	if v := p.PermitRequest(r); v == nil || *v != yes {
 | 
			
		||||
		t.Errorf("expected request to return %t, %v", yes, v)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	r.URL.Host = "google.com"
 | 
			
		||||
	if v := p.PermitRequest(r); v == nil || *v != nope {
 | 
			
		||||
		t.Errorf("expected request to return %t, %v", nope, v)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	r.URL.Host = "localhost:80"
 | 
			
		||||
	if v := p.PermitRequest(r); v == nil || *v != yes {
 | 
			
		||||
		t.Errorf("expected request to return %t, %v", yes, v)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
@@ -1,368 +0,0 @@
 | 
			
		||||
package policy
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"net"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"git.maze.io/maze/styx/internal/netutil"
 | 
			
		||||
	"git.maze.io/maze/styx/proxy/match"
 | 
			
		||||
	"github.com/google/uuid"
 | 
			
		||||
	"github.com/hashicorp/hcl/v2"
 | 
			
		||||
	"github.com/hashicorp/hcl/v2/gohcl"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// Rule is a policy rule.
 | 
			
		||||
type Rule interface {
 | 
			
		||||
	Configure(match.Matchers) error
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// InterceptRule can make policy rule decisions on intercept requests.
 | 
			
		||||
type InterceptRule interface {
 | 
			
		||||
	PermitIntercept(r *http.Request) *bool
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// RequestRule can make policy rule decisions on HTTP CONNECT requests.
 | 
			
		||||
type RequestRule interface {
 | 
			
		||||
	PermitRequest(r *http.Request) *bool
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type rawRule struct {
 | 
			
		||||
	Type string   `hcl:"type,label" json:"type"`
 | 
			
		||||
	Body hcl.Body `hcl:",remain" json:"-"`
 | 
			
		||||
	Rule `json:"rule"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *rawRule) Configure(matchers match.Matchers) (err error) {
 | 
			
		||||
	switch r.Type {
 | 
			
		||||
	case "intercept":
 | 
			
		||||
		r.Rule = new(interceptRule)
 | 
			
		||||
	case "request":
 | 
			
		||||
		r.Rule = new(requestRule)
 | 
			
		||||
	case "days":
 | 
			
		||||
		r.Rule = new(daysRule)
 | 
			
		||||
	case "time":
 | 
			
		||||
		r.Rule = new(timeRule)
 | 
			
		||||
	case "all":
 | 
			
		||||
		r.Rule = new(allRule)
 | 
			
		||||
	default:
 | 
			
		||||
		return fmt.Errorf("policy: invalid event type %q", r.Type)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if diag := gohcl.DecodeBody(r.Body, nil, r.Rule); diag.HasErrors() {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return r.Rule.Configure(matchers)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type allRule struct {
 | 
			
		||||
	Rules  []*rawRule `hcl:"on,block"`
 | 
			
		||||
	Permit *bool      `hcl:"permit"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *allRule) Configure(matchers match.Matchers) (err error) {
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type domainOrNetworkRule struct {
 | 
			
		||||
	matchers []match.Matcher
 | 
			
		||||
	isSource []bool
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *domainOrNetworkRule) configure(kind string, matchers match.Matchers, domains, sources, targets []string, v any, id *string) (err error) {
 | 
			
		||||
	var m match.Matcher
 | 
			
		||||
	for _, domain := range domains {
 | 
			
		||||
		if m, err = matchers.Get("domain", domain); err != nil {
 | 
			
		||||
			return fmt.Errorf("%s: unknown domain %q", kind, domain)
 | 
			
		||||
		}
 | 
			
		||||
		r.matchers = append(r.matchers, m)
 | 
			
		||||
		r.isSource = append(r.isSource, false)
 | 
			
		||||
	}
 | 
			
		||||
	for _, network := range sources {
 | 
			
		||||
		if m, err = matchers.Get("network", network); err != nil {
 | 
			
		||||
			return fmt.Errorf("%s: unknown source network %q", kind, network)
 | 
			
		||||
		}
 | 
			
		||||
		r.matchers = append(r.matchers, m)
 | 
			
		||||
		r.isSource = append(r.isSource, true)
 | 
			
		||||
	}
 | 
			
		||||
	for _, network := range targets {
 | 
			
		||||
		if m, err = matchers.Get("network", network); err != nil {
 | 
			
		||||
			return fmt.Errorf("%s: unknown target network %q", kind, network)
 | 
			
		||||
		}
 | 
			
		||||
		r.matchers = append(r.matchers, m)
 | 
			
		||||
		r.isSource = append(r.isSource, false)
 | 
			
		||||
	}
 | 
			
		||||
	if len(r.matchers) == 0 {
 | 
			
		||||
		return fmt.Errorf("%s: missing any of domain, source, target", kind)
 | 
			
		||||
	}
 | 
			
		||||
	if id != nil {
 | 
			
		||||
		*id = uuid.NewString()
 | 
			
		||||
	}
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *domainOrNetworkRule) matchesRequest(q *http.Request) bool {
 | 
			
		||||
	for i, m := range r.matchers {
 | 
			
		||||
		if m, ok := m.(match.Request); ok {
 | 
			
		||||
			if m.MatchesRequest(q) {
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		if m, ok := m.(match.IP); ok {
 | 
			
		||||
			if r.isSource[i] {
 | 
			
		||||
				if m.MatchesIP(net.ParseIP(netutil.Host(q.RemoteAddr))) {
 | 
			
		||||
					return true
 | 
			
		||||
				}
 | 
			
		||||
			} else {
 | 
			
		||||
				var (
 | 
			
		||||
					host = netutil.Host(q.URL.Host)
 | 
			
		||||
					ips  []net.IP
 | 
			
		||||
				)
 | 
			
		||||
				if ip := net.ParseIP(host); ip != nil {
 | 
			
		||||
					ips = append(ips, ip)
 | 
			
		||||
				} else {
 | 
			
		||||
					ips, _ = net.LookupIP(host)
 | 
			
		||||
				}
 | 
			
		||||
				for _, ip := range ips {
 | 
			
		||||
					if m.MatchesIP(ip) {
 | 
			
		||||
						return true
 | 
			
		||||
					}
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return false
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type interceptRule struct {
 | 
			
		||||
	ID                  string   `json:"id,omitempty"`
 | 
			
		||||
	Domain              []string `hcl:"domain,optional" json:"domain,omitempty"`
 | 
			
		||||
	Source              []string `hcl:"source,optional" json:"source,omitempty"`
 | 
			
		||||
	Target              []string `hcl:"target,optional" json:"target,omitempty"`
 | 
			
		||||
	Permit              *bool    `hcl:"permit" json:"permit"`
 | 
			
		||||
	domainOrNetworkRule `json:"-"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *interceptRule) Configure(matchers match.Matchers) (err error) {
 | 
			
		||||
	return r.configure("intercept", matchers, r.Domain, r.Source, r.Target, r, &r.ID)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *interceptRule) PermitIntercept(q *http.Request) *bool {
 | 
			
		||||
	if r.matchesRequest(q) {
 | 
			
		||||
		return r.Permit
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type requestRule struct {
 | 
			
		||||
	ID                  string   `json:"id,omitempty"`
 | 
			
		||||
	Domain              []string `hcl:"domain,optional" json:"domain,omitempty"`
 | 
			
		||||
	Source              []string `hcl:"source,optional" json:"source,omitempty"`
 | 
			
		||||
	Target              []string `hcl:"target,optional" json:"target,omitempty"`
 | 
			
		||||
	Permit              *bool    `hcl:"permit" json:"permit"`
 | 
			
		||||
	domainOrNetworkRule `json:"-"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *requestRule) Configure(matchers match.Matchers) (err error) {
 | 
			
		||||
	return r.configure("request", matchers, r.Domain, r.Source, r.Target, r, &r.ID)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *requestRule) PermitRequest(q *http.Request) *bool {
 | 
			
		||||
	if r.matchesRequest(q) {
 | 
			
		||||
		return r.Permit
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type timeRule struct {
 | 
			
		||||
	ID     string   `json:"id,omitempty"`
 | 
			
		||||
	Time   []string `hcl:"time" json:"time"`
 | 
			
		||||
	Permit *bool    `hcl:"permit" json:"permit"`
 | 
			
		||||
	Body   hcl.Body `hcl:",remain" json:"-"`
 | 
			
		||||
	Rules  *Policy  `json:"rules"`
 | 
			
		||||
	Start  Time     `json:"start"`
 | 
			
		||||
	End    Time     `json:"end"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *timeRule) isActive() bool {
 | 
			
		||||
	if r == nil {
 | 
			
		||||
		return true
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	now := Now()
 | 
			
		||||
	if r.Start.After(r.End) { // ie: 18:00-06:00
 | 
			
		||||
		return now.After(r.Start) || now.Before(r.End)
 | 
			
		||||
	}
 | 
			
		||||
	return now.After(r.Start) && now.Before(r.End)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *timeRule) Configure(matchers match.Matchers) (err error) {
 | 
			
		||||
	if len(r.Time) != 2 {
 | 
			
		||||
		return fmt.Errorf("invalid time %s, need [start, stop]", r.Time)
 | 
			
		||||
	}
 | 
			
		||||
	if r.Start, err = ParseTime(r.Time[0]); err != nil {
 | 
			
		||||
		return fmt.Errorf("invalid start %q: %w", r.Time[0], err)
 | 
			
		||||
	}
 | 
			
		||||
	if r.End, err = ParseTime(r.Time[1]); err != nil {
 | 
			
		||||
		return fmt.Errorf("invalid end %q: %w", r.Time[1], err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	r.Rules = new(Policy)
 | 
			
		||||
	if diag := gohcl.DecodeBody(r.Body, nil, r.Rules); diag.HasErrors() {
 | 
			
		||||
		return diag
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err = r.Rules.Configure(matchers); err != nil {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	r.Rules.Matchers = nil
 | 
			
		||||
 | 
			
		||||
	if r.ID == "" {
 | 
			
		||||
		r.ID = uuid.NewString()
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *timeRule) PermitIntercept(q *http.Request) *bool {
 | 
			
		||||
	if !r.isActive() {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
	return r.Rules.PermitIntercept(q)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *timeRule) PermitRequest(q *http.Request) *bool {
 | 
			
		||||
	if !r.isActive() {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
	return r.Rules.PermitRequest(q)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type daysRule struct {
 | 
			
		||||
	ID     string   `json:"id,omitempty"`
 | 
			
		||||
	Days   string   `hcl:"days" json:"days"`
 | 
			
		||||
	Permit *bool    `hcl:"permit" json:"permit"`
 | 
			
		||||
	Body   hcl.Body `hcl:",remain" json:"-"`
 | 
			
		||||
	Rules  *Policy  `json:"rules"`
 | 
			
		||||
	cond   []onCond
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *daysRule) isActive() bool {
 | 
			
		||||
	if r == nil || len(r.cond) == 0 {
 | 
			
		||||
		return true
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	now := time.Now()
 | 
			
		||||
	for _, cond := range r.cond {
 | 
			
		||||
		if cond(now) {
 | 
			
		||||
			return true
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return false
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *daysRule) Configure(matchers match.Matchers) (err error) {
 | 
			
		||||
	if r.cond, err = parseOnCond(r.Days); err != nil {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	r.Rules = new(Policy)
 | 
			
		||||
	if diag := gohcl.DecodeBody(r.Body, nil, r.Rules); diag.HasErrors() {
 | 
			
		||||
		return diag
 | 
			
		||||
	}
 | 
			
		||||
	if err = r.Rules.Configure(matchers); err != nil {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	r.Rules.Matchers = nil
 | 
			
		||||
 | 
			
		||||
	if r.ID == "" {
 | 
			
		||||
		r.ID = uuid.NewString()
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *daysRule) PermitIntercept(q *http.Request) *bool {
 | 
			
		||||
	if !r.isActive() {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
	return r.Rules.PermitIntercept(q)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *daysRule) PermitRequest(q *http.Request) *bool {
 | 
			
		||||
	if !r.isActive() {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
	return r.Rules.PermitRequest(q)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type onCond func(time.Time) bool
 | 
			
		||||
 | 
			
		||||
var weekdays = map[string]time.Weekday{
 | 
			
		||||
	"sun": time.Sunday,
 | 
			
		||||
	"mon": time.Monday,
 | 
			
		||||
	"tue": time.Tuesday,
 | 
			
		||||
	"wed": time.Wednesday,
 | 
			
		||||
	"thu": time.Thursday,
 | 
			
		||||
	"fri": time.Friday,
 | 
			
		||||
	"sat": time.Saturday,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func parseOnCond(when string) (conds []onCond, err error) {
 | 
			
		||||
	for _, spec := range strings.Split(when, ",") {
 | 
			
		||||
		spec = strings.ToLower(strings.TrimSpace(spec))
 | 
			
		||||
		if d, ok := weekdays[spec]; ok {
 | 
			
		||||
			conds = append(conds, onWeekday(d))
 | 
			
		||||
		} else if spec == "weekend" || spec == "weekends" {
 | 
			
		||||
			conds = append(conds, onWeekend)
 | 
			
		||||
		} else if spec == "workday" || spec == "workdays" {
 | 
			
		||||
			conds = append(conds, onWorkday)
 | 
			
		||||
		} else if strings.ContainsRune(spec, '-') {
 | 
			
		||||
			var (
 | 
			
		||||
				part       = strings.SplitN(spec, "-", 2)
 | 
			
		||||
				from, upto time.Weekday
 | 
			
		||||
				ok         bool
 | 
			
		||||
			)
 | 
			
		||||
			if from, ok = weekdays[part[0]]; !ok {
 | 
			
		||||
				return nil, fmt.Errorf("on %q: invalid weekday %q", spec, part[0])
 | 
			
		||||
			}
 | 
			
		||||
			if upto, ok = weekdays[part[1]]; !ok {
 | 
			
		||||
				return nil, fmt.Errorf("on %q: invalid weekday %q", spec, part[1])
 | 
			
		||||
			}
 | 
			
		||||
			if from < upto {
 | 
			
		||||
				for d := from; d < upto; d++ {
 | 
			
		||||
					conds = append(conds, onWeekday(d))
 | 
			
		||||
				}
 | 
			
		||||
			} else {
 | 
			
		||||
				for d := time.Sunday; d < from; d++ {
 | 
			
		||||
					conds = append(conds, onWeekday(d))
 | 
			
		||||
				}
 | 
			
		||||
				for d := upto; d <= time.Saturday; d++ {
 | 
			
		||||
					conds = append(conds, onWeekday(d))
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		} else {
 | 
			
		||||
			return nil, fmt.Errorf("on %q: invalid condition", spec)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func onWeekday(weekday time.Weekday) onCond {
 | 
			
		||||
	return func(t time.Time) bool {
 | 
			
		||||
		return t.Weekday() == weekday
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func onWeekend(t time.Time) bool {
 | 
			
		||||
	d := t.Weekday()
 | 
			
		||||
	return d == time.Saturday || d == time.Sunday
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func onWorkday(t time.Time) bool {
 | 
			
		||||
	d := t.Weekday()
 | 
			
		||||
	return !(d == time.Saturday || d == time.Sunday)
 | 
			
		||||
}
 | 
			
		||||
@@ -1,53 +0,0 @@
 | 
			
		||||
package policy
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Time struct {
 | 
			
		||||
	Hour   int
 | 
			
		||||
	Minute int
 | 
			
		||||
	Second int
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (t Time) Eq(other Time) bool {
 | 
			
		||||
	return t.Hour == other.Hour && t.Minute == other.Minute && t.Second == other.Second
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (t Time) After(other Time) bool {
 | 
			
		||||
	return t.Seconds() > other.Seconds()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (t Time) Before(other Time) bool {
 | 
			
		||||
	return t.Seconds() < other.Seconds()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (t Time) Seconds() int {
 | 
			
		||||
	return t.Hour*3600 + t.Minute*60 + t.Second
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (t Time) MarshalJSON() ([]byte, error) {
 | 
			
		||||
	return []byte(fmt.Sprintf(`"%02d:%02d:%02d"`, t.Hour, t.Minute, t.Second)), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var timeFormats = []string{
 | 
			
		||||
	time.TimeOnly,
 | 
			
		||||
	"15:04",
 | 
			
		||||
	time.Kitchen,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func Now() Time {
 | 
			
		||||
	now := time.Now()
 | 
			
		||||
	return Time{now.Hour(), now.Minute(), now.Second()}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func ParseTime(s string) (t Time, err error) {
 | 
			
		||||
	var tt time.Time
 | 
			
		||||
	for _, layout := range timeFormats {
 | 
			
		||||
		if tt, err = time.Parse(layout, s); err == nil {
 | 
			
		||||
			return Time{tt.Hour(), tt.Minute(), tt.Second()}, nil
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return Time{}, fmt.Errorf("time: invalid time %q", s)
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										985
									
								
								proxy/proxy.go
									
									
									
									
									
								
							
							
						
						
									
										985
									
								
								proxy/proxy.go
									
									
									
									
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							@@ -1,148 +0,0 @@
 | 
			
		||||
// Package resolver implements a caching DNS resolver
 | 
			
		||||
package resolver
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"math/rand/v2"
 | 
			
		||||
	"net"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"git.maze.io/maze/styx/internal/netutil"
 | 
			
		||||
	"github.com/hashicorp/golang-lru/v2/expirable"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	DefaultSize    = 1024
 | 
			
		||||
	DefaultTTL     = 5 * time.Minute
 | 
			
		||||
	DefaultTimeout = 10 * time.Second
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var (
 | 
			
		||||
	// DefaultConfig are the defaults for the Default resolver.
 | 
			
		||||
	DefaultConfig = Config{
 | 
			
		||||
		Size:    DefaultSize,
 | 
			
		||||
		TTL:     DefaultTTL.Seconds(),
 | 
			
		||||
		Timeout: DefaultTimeout.Seconds(),
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Default resolver.
 | 
			
		||||
	Default = New(DefaultConfig)
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Resolver interface {
 | 
			
		||||
	// Lookup returns resolved IPs for given hostname/ips.
 | 
			
		||||
	Lookup(context.Context, string) ([]string, error)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type netResolver struct {
 | 
			
		||||
	resolver *net.Resolver
 | 
			
		||||
	timeout  time.Duration
 | 
			
		||||
	noIPv6   bool
 | 
			
		||||
	cache    *expirable.LRU[string, []string]
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Config struct {
 | 
			
		||||
	// Size is our cache size in number of entries.
 | 
			
		||||
	Size int `hcl:"size,optional"`
 | 
			
		||||
 | 
			
		||||
	// TTL is the cache time to live in seconds.
 | 
			
		||||
	TTL float64 `hcl:"ttl,optional"`
 | 
			
		||||
 | 
			
		||||
	// Timeout is the cache timeout in seconds.
 | 
			
		||||
	Timeout float64 `hcl:"timeout,optional"`
 | 
			
		||||
 | 
			
		||||
	// Server are alternative DNS servers.
 | 
			
		||||
	Server []string `hcl:"server,optional"`
 | 
			
		||||
 | 
			
		||||
	// NoIPv6 disables IPv6 DNS resolution.
 | 
			
		||||
	NoIPv6 bool `hcl:"noipv6,optional"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func New(config Config) Resolver {
 | 
			
		||||
	var (
 | 
			
		||||
		size    = config.Size
 | 
			
		||||
		ttl     = time.Duration(float64(time.Second) * config.TTL)
 | 
			
		||||
		timeout = time.Duration(float64(time.Second) * config.Timeout)
 | 
			
		||||
	)
 | 
			
		||||
	if size <= 0 {
 | 
			
		||||
		size = DefaultSize
 | 
			
		||||
	}
 | 
			
		||||
	if ttl <= 0 {
 | 
			
		||||
		ttl = DefaultTTL
 | 
			
		||||
	}
 | 
			
		||||
	if timeout <= 0 {
 | 
			
		||||
		timeout = 0
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var resolver = new(net.Resolver)
 | 
			
		||||
	if len(config.Server) > 0 {
 | 
			
		||||
		var dialer net.Dialer
 | 
			
		||||
		resolver.Dial = func(ctx context.Context, network, address string) (net.Conn, error) {
 | 
			
		||||
			server := netutil.EnsurePort(config.Server[rand.IntN(len(config.Server))], "53")
 | 
			
		||||
			return dialer.DialContext(ctx, network, server)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return &netResolver{
 | 
			
		||||
		resolver: resolver,
 | 
			
		||||
		timeout:  timeout,
 | 
			
		||||
		noIPv6:   config.NoIPv6,
 | 
			
		||||
		cache:    expirable.NewLRU[string, []string](size, nil, ttl),
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *netResolver) Lookup(ctx context.Context, host string) ([]string, error) {
 | 
			
		||||
	host = strings.ToLower(strings.TrimSpace(host))
 | 
			
		||||
	if hosts, ok := r.cache.Get(host); ok {
 | 
			
		||||
		rand.Shuffle(len(hosts), func(i, j int) {
 | 
			
		||||
			hosts[i], hosts[j] = hosts[j], hosts[i]
 | 
			
		||||
		})
 | 
			
		||||
		return hosts, nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	hosts, err := r.lookup(ctx, host)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	r.cache.Add(host, hosts)
 | 
			
		||||
	return hosts, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *netResolver) lookup(ctx context.Context, host string) ([]string, error) {
 | 
			
		||||
	if r.timeout > 0 {
 | 
			
		||||
		var cancel func()
 | 
			
		||||
		ctx, cancel = context.WithTimeout(ctx, r.timeout)
 | 
			
		||||
		defer cancel()
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if net.ParseIP(host) == nil {
 | 
			
		||||
		addrs, err := r.resolver.LookupHost(ctx, host)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		if r.noIPv6 {
 | 
			
		||||
			var addrs4 []string
 | 
			
		||||
			for _, addr := range addrs {
 | 
			
		||||
				if net.ParseIP(addr).To4() != nil {
 | 
			
		||||
					addrs4 = append(addrs4, addr)
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
			return addrs4, nil
 | 
			
		||||
		}
 | 
			
		||||
		return addrs, nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	addrs, err := r.resolver.LookupIPAddr(ctx, host)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	hosts := make([]string, len(addrs))
 | 
			
		||||
	for i, addr := range addrs {
 | 
			
		||||
		if !r.noIPv6 || addr.IP.To4() != nil {
 | 
			
		||||
			hosts[i] = addr.IP.String()
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return hosts, nil
 | 
			
		||||
}
 | 
			
		||||
@@ -2,77 +2,58 @@ package proxy
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"os"
 | 
			
		||||
	"strconv"
 | 
			
		||||
 | 
			
		||||
	"git.maze.io/maze/styx/internal/log"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func NewResponse(code int, body io.Reader, request *http.Request) *http.Response {
 | 
			
		||||
	if body == nil {
 | 
			
		||||
		body = new(bytes.Buffer)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	rc, ok := body.(io.ReadCloser)
 | 
			
		||||
	if !ok {
 | 
			
		||||
		rc = io.NopCloser(body)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	response := &http.Response{
 | 
			
		||||
		Status:     strconv.Itoa(code) + " " + http.StatusText(code),
 | 
			
		||||
		StatusCode: code,
 | 
			
		||||
		Proto:      "HTTP/1.1",
 | 
			
		||||
		ProtoMajor: 1,
 | 
			
		||||
		ProtoMinor: 1,
 | 
			
		||||
		Header:     make(http.Header),
 | 
			
		||||
		Body:       rc,
 | 
			
		||||
		Request:    request,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if request != nil {
 | 
			
		||||
		response.Close = request.Close
 | 
			
		||||
		response.Proto = request.Proto
 | 
			
		||||
		response.ProtoMajor = request.ProtoMajor
 | 
			
		||||
		response.ProtoMinor = request.ProtoMinor
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return response
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type withLen interface {
 | 
			
		||||
	Len() int
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type withSize interface {
 | 
			
		||||
type sizer interface {
 | 
			
		||||
	Size() int64
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewJSONResponse(code int, body io.Reader, request *http.Request) *http.Response {
 | 
			
		||||
	response := NewResponse(code, body, request)
 | 
			
		||||
	response.Header.Set(HeaderContentType, "application/json")
 | 
			
		||||
	if s, ok := body.(withLen); ok {
 | 
			
		||||
		response.Header.Set(HeaderContentLength, strconv.Itoa(s.Len()))
 | 
			
		||||
	} else if s, ok := body.(withSize); ok {
 | 
			
		||||
		response.Header.Set(HeaderContentLength, strconv.FormatInt(s.Size(), 10))
 | 
			
		||||
	} else {
 | 
			
		||||
		log.Trace().Str("type", fmt.Sprintf("%T", body)).Msg("can't detemine body size")
 | 
			
		||||
// NewResponse prepares a net [http.Response], based on the status code, optional body and
 | 
			
		||||
// optional [http.Request].
 | 
			
		||||
func NewResponse(code int, body io.ReadCloser, req *http.Request) *http.Response {
 | 
			
		||||
	res := &http.Response{
 | 
			
		||||
		StatusCode: code,
 | 
			
		||||
		Header:     make(http.Header),
 | 
			
		||||
		Proto:      "HTTP/1.1",
 | 
			
		||||
		ProtoMajor: 1,
 | 
			
		||||
		ProtoMinor: 1,
 | 
			
		||||
	}
 | 
			
		||||
	response.Close = true
 | 
			
		||||
	return response
 | 
			
		||||
 | 
			
		||||
	if text := http.StatusText(code); text != "" {
 | 
			
		||||
		res.Status = strconv.Itoa(code) + " " + text
 | 
			
		||||
	} else {
 | 
			
		||||
		res.Status = strconv.Itoa(code)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if body == nil && code >= 400 {
 | 
			
		||||
		body = io.NopCloser(bytes.NewBufferString(http.StatusText(code)))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	res.Body = body
 | 
			
		||||
 | 
			
		||||
	if s, ok := body.(sizer); ok {
 | 
			
		||||
		res.ContentLength = s.Size()
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if req != nil {
 | 
			
		||||
		res.Close = req.Close
 | 
			
		||||
		res.Proto = req.Proto
 | 
			
		||||
		res.ProtoMajor = req.ProtoMajor
 | 
			
		||||
		res.ProtoMinor = req.ProtoMinor
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return res
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func ErrorResponse(request *http.Request, err error) *http.Response {
 | 
			
		||||
	response := NewResponse(http.StatusBadGateway, nil, request)
 | 
			
		||||
func NewErrorResponse(err error, req *http.Request) *http.Response {
 | 
			
		||||
	switch {
 | 
			
		||||
	case os.IsNotExist(err):
 | 
			
		||||
		response.StatusCode = http.StatusNotFound
 | 
			
		||||
	case os.IsPermission(err):
 | 
			
		||||
		response.StatusCode = http.StatusForbidden
 | 
			
		||||
	case os.IsTimeout(err):
 | 
			
		||||
		return NewResponse(http.StatusGatewayTimeout, nil, req)
 | 
			
		||||
	default:
 | 
			
		||||
		return NewResponse(http.StatusBadGateway, nil, req)
 | 
			
		||||
	}
 | 
			
		||||
	response.Status = http.StatusText(response.StatusCode)
 | 
			
		||||
	response.Close = true
 | 
			
		||||
	return response
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										151
									
								
								proxy/session.go
									
									
									
									
									
								
							
							
						
						
									
										151
									
								
								proxy/session.go
									
									
									
									
									
								
							@@ -1,151 +0,0 @@
 | 
			
		||||
package proxy
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bufio"
 | 
			
		||||
	"crypto/tls"
 | 
			
		||||
	"encoding/binary"
 | 
			
		||||
	"encoding/hex"
 | 
			
		||||
	"math/rand"
 | 
			
		||||
	"net"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"sync/atomic"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"git.maze.io/maze/styx/internal/log"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var seed = rand.NewSource(time.Now().UnixNano())
 | 
			
		||||
 | 
			
		||||
type Context struct {
 | 
			
		||||
	id     int64
 | 
			
		||||
	conn   *wrappedConn
 | 
			
		||||
	rw     *bufio.ReadWriter
 | 
			
		||||
	parent *Session
 | 
			
		||||
	data   map[string]any
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func newContext(conn net.Conn, rw *bufio.ReadWriter, parent *Session) *Context {
 | 
			
		||||
	if wrapped, ok := conn.(*wrappedConn); ok {
 | 
			
		||||
		conn = wrapped.Conn
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	ctx := &Context{
 | 
			
		||||
		id:     seed.Int63(),
 | 
			
		||||
		conn:   &wrappedConn{Conn: conn},
 | 
			
		||||
		rw:     rw,
 | 
			
		||||
		parent: parent,
 | 
			
		||||
		data:   make(map[string]any),
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return ctx
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ctx *Context) log() log.Logger {
 | 
			
		||||
	return log.Console.With().
 | 
			
		||||
		Str("context", ctx.ID()).
 | 
			
		||||
		Str("addr", ctx.RemoteAddr().String()).
 | 
			
		||||
		Logger()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ctx *Context) ID() string {
 | 
			
		||||
	var b [8]byte
 | 
			
		||||
	binary.BigEndian.PutUint64(b[:], uint64(ctx.id))
 | 
			
		||||
	if ctx.parent != nil {
 | 
			
		||||
		return ctx.parent.ID() + "-" + hex.EncodeToString(b[:])
 | 
			
		||||
	}
 | 
			
		||||
	return hex.EncodeToString(b[:])
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ctx *Context) IsTLS() bool {
 | 
			
		||||
	_, ok := ctx.conn.Conn.(*tls.Conn)
 | 
			
		||||
	return ok && ctx.parent != nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ctx *Context) RemoteAddr() net.Addr {
 | 
			
		||||
	if ctx.parent != nil {
 | 
			
		||||
		return ctx.parent.ctx.RemoteAddr()
 | 
			
		||||
	}
 | 
			
		||||
	return ctx.conn.RemoteAddr()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ctx *Context) SetDeadline(t time.Time) error {
 | 
			
		||||
	if ctx.parent != nil {
 | 
			
		||||
		return ctx.parent.ctx.SetDeadline(t)
 | 
			
		||||
	}
 | 
			
		||||
	return ctx.conn.SetDeadline(t)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ctx *Context) Set(key string, value any) {
 | 
			
		||||
	ctx.data[key] = value
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ctx *Context) Get(key string) (value any, ok bool) {
 | 
			
		||||
	value, ok = ctx.data[key]
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ctx *Context) Flush() error {
 | 
			
		||||
	return ctx.rw.Flush()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ctx *Context) Write(p []byte) (n int, err error) {
 | 
			
		||||
	if n, err = ctx.rw.Write(p); n > 0 {
 | 
			
		||||
		atomic.AddInt64(&ctx.conn.bytes, int64(n))
 | 
			
		||||
	}
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Session struct {
 | 
			
		||||
	id       int64
 | 
			
		||||
	ctx      *Context
 | 
			
		||||
	request  *http.Request
 | 
			
		||||
	response *http.Response
 | 
			
		||||
	data     map[string]any
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func newSession(ctx *Context, request *http.Request) *Session {
 | 
			
		||||
	return &Session{
 | 
			
		||||
		id:      seed.Int63(),
 | 
			
		||||
		ctx:     ctx,
 | 
			
		||||
		request: request,
 | 
			
		||||
		data:    make(map[string]any),
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ses *Session) log() log.Logger {
 | 
			
		||||
	return log.Console.With().
 | 
			
		||||
		Str("context", ses.ctx.ID()).
 | 
			
		||||
		Str("session", ses.ID()).
 | 
			
		||||
		Str("addr", ses.ctx.RemoteAddr().String()).
 | 
			
		||||
		Logger()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ses *Session) ID() string {
 | 
			
		||||
	var b [8]byte
 | 
			
		||||
	binary.BigEndian.PutUint64(b[:], uint64(ses.id))
 | 
			
		||||
	return hex.EncodeToString(b[:])
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ses *Session) Context() *Context {
 | 
			
		||||
	return ses.ctx
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ses *Session) Request() *http.Request {
 | 
			
		||||
	return ses.request
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ses *Session) Response() *http.Response {
 | 
			
		||||
	return ses.response
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type wrappedConn struct {
 | 
			
		||||
	net.Conn
 | 
			
		||||
	bytes int64
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *wrappedConn) Write(p []byte) (n int, err error) {
 | 
			
		||||
	if n, err = c.Conn.Write(p); n > 0 {
 | 
			
		||||
		atomic.AddInt64(&c.bytes, int64(n))
 | 
			
		||||
	}
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										19
									
								
								proxy/stats.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										19
									
								
								proxy/stats.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,19 @@
 | 
			
		||||
package proxy
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"expvar"
 | 
			
		||||
	"strconv"
 | 
			
		||||
 | 
			
		||||
	"git.maze.io/maze/styx/db/stats"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func countStatus(code int) {
 | 
			
		||||
	k := "http:status:" + strconv.Itoa(code)
 | 
			
		||||
	v := expvar.Get(k)
 | 
			
		||||
	if v == nil {
 | 
			
		||||
		//v = stats.NewCounter("120s1s", "15m10s", "1h1m", "4w1d", "1y4w")
 | 
			
		||||
		v = stats.NewCounter(k, stats.Minutely, stats.Hourly, stats.Daily, stats.Yearly)
 | 
			
		||||
		expvar.Publish(k, v)
 | 
			
		||||
	}
 | 
			
		||||
	v.(stats.Metric).Add(1)
 | 
			
		||||
}
 | 
			
		||||
@@ -1,225 +0,0 @@
 | 
			
		||||
package stats
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"database/sql"
 | 
			
		||||
	"database/sql/driver"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"os"
 | 
			
		||||
	"os/user"
 | 
			
		||||
	"path/filepath"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"git.maze.io/maze/styx/internal/log"
 | 
			
		||||
	_ "github.com/mattn/go-sqlite3"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Stats struct {
 | 
			
		||||
	db *sql.DB
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func New() (*Stats, error) {
 | 
			
		||||
	u, err := user.Current()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	path := filepath.Join(u.HomeDir, ".styx", "stats.db")
 | 
			
		||||
	if err = os.MkdirAll(filepath.Dir(path), 0o750); err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	db, err := sql.Open("sqlite3", path+"?_journal_mode=WAL")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, table := range []string{
 | 
			
		||||
		createLog,
 | 
			
		||||
		createDomainStat,
 | 
			
		||||
		createStatusStat,
 | 
			
		||||
	} {
 | 
			
		||||
		if _, err = db.Exec(table); err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return &Stats{db: db}, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Stats) AddLog(entry *Log) error {
 | 
			
		||||
	var (
 | 
			
		||||
		request  []byte
 | 
			
		||||
		response []byte
 | 
			
		||||
		err      error
 | 
			
		||||
	)
 | 
			
		||||
	if request, err = json.Marshal(entry.Request); err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	if response, err = json.Marshal(entry.Response); err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	tx, err := s.db.Begin()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	stmt, err := tx.Prepare("insert into styx_log(client_ip, request, response) values(?, ?, ?)")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	defer stmt.Close()
 | 
			
		||||
	if _, err = stmt.Exec(entry.ClientIP, request, response); err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	return tx.Commit()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Stats) QueryLog(offset, limit int) ([]*Log, error) {
 | 
			
		||||
	if limit == 0 {
 | 
			
		||||
		limit = 50
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	rows, err := s.db.Query("select dt, client_ip, request, response from styx_log limit ?, ?", offset, limit)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	defer rows.Close()
 | 
			
		||||
 | 
			
		||||
	var logs []*Log
 | 
			
		||||
	for rows.Next() {
 | 
			
		||||
		var entry = new(Log)
 | 
			
		||||
		if err = rows.Scan(&entry.Time, &entry.ClientIP, &entry.Request, &entry.Response); err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		logs = append(logs, entry)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return logs, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Status struct {
 | 
			
		||||
	Code  int `json:"code"`
 | 
			
		||||
	Count int `json:"count"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var timeZero time.Time
 | 
			
		||||
 | 
			
		||||
func (s *Stats) QueryStatus(since time.Time) ([]*Status, error) {
 | 
			
		||||
	if since.Equal(timeZero) {
 | 
			
		||||
		since = time.Now().Add(-24 * time.Hour)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	rows, err := s.db.Query("select response->'status', count(*) from styx_log where dt >= ? group by response->'status' order by response->'status'", since)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var stats []*Status
 | 
			
		||||
	for rows.Next() {
 | 
			
		||||
		var entry = new(Status)
 | 
			
		||||
		if err = rows.Scan(&entry.Code, &entry.Count); err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		stats = append(stats, entry)
 | 
			
		||||
	}
 | 
			
		||||
	return stats, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
const createLog = `CREATE TABLE IF NOT EXISTS styx_log (
 | 
			
		||||
	id        INT PRIMARY KEY,
 | 
			
		||||
	dt        DATETIME DEFAULT CURRENT_TIMESTAMP,
 | 
			
		||||
	client_ip TEXT NOT NULL,
 | 
			
		||||
	request   JSONB NOT NULL,
 | 
			
		||||
	response  JSONB NOT NULL
 | 
			
		||||
);`
 | 
			
		||||
 | 
			
		||||
type Log struct {
 | 
			
		||||
	Time     time.Time `json:"time"`
 | 
			
		||||
	ClientIP string    `json:"client_ip"`
 | 
			
		||||
	Request  *Request  `json:"request"`
 | 
			
		||||
	Response *Response `json:"response"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Request struct {
 | 
			
		||||
	URL    string      `json:"url"`
 | 
			
		||||
	Host   string      `json:"host"`
 | 
			
		||||
	Method string      `json:"method"`
 | 
			
		||||
	Proto  string      `json:"proto"`
 | 
			
		||||
	Header http.Header `json:"header"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *Request) Scan(value any) error {
 | 
			
		||||
	switch v := value.(type) {
 | 
			
		||||
	case string:
 | 
			
		||||
		return json.Unmarshal([]byte(v), r)
 | 
			
		||||
	case []byte:
 | 
			
		||||
		return json.Unmarshal(v, r)
 | 
			
		||||
	default:
 | 
			
		||||
		log.Error().Str("type", fmt.Sprintf("%T", value)).Msg("scan request unknown type")
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *Request) Value() (driver.Value, error) {
 | 
			
		||||
	b, err := json.Marshal(r)
 | 
			
		||||
	return string(b), err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func FromRequest(r *http.Request) *Request {
 | 
			
		||||
	return &Request{
 | 
			
		||||
		URL:    r.URL.String(),
 | 
			
		||||
		Host:   r.Host,
 | 
			
		||||
		Method: r.Method,
 | 
			
		||||
		Proto:  r.Proto,
 | 
			
		||||
		Header: r.Header,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Response struct {
 | 
			
		||||
	Status int         `json:"status"`
 | 
			
		||||
	Size   int64       `json:"size"`
 | 
			
		||||
	Header http.Header `json:"header"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *Response) Scan(value any) error {
 | 
			
		||||
	switch v := value.(type) {
 | 
			
		||||
	case string:
 | 
			
		||||
		return json.Unmarshal([]byte(v), r)
 | 
			
		||||
	case []byte:
 | 
			
		||||
		return json.Unmarshal(v, r)
 | 
			
		||||
	default:
 | 
			
		||||
		log.Error().Str("type", fmt.Sprintf("%T", value)).Msg("scan response unknown type")
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *Response) Value() (driver.Value, error) {
 | 
			
		||||
	b, err := json.Marshal(r)
 | 
			
		||||
	return string(b), err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *Response) SetSize(size int64) *Response {
 | 
			
		||||
	r.Size = size
 | 
			
		||||
	return r
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func FromResponse(r *http.Response) *Response {
 | 
			
		||||
	return &Response{
 | 
			
		||||
		Status: r.StatusCode,
 | 
			
		||||
		Header: r.Header,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
const createStatusStat = `CREATE TABLE IF NOT EXISTS styx_stat_status (
 | 
			
		||||
	id     INT PRIMARY KEY,
 | 
			
		||||
	dt     DATETIME DEFAULT CURRENT_TIMESTAMP,
 | 
			
		||||
	status INT NOT NULL
 | 
			
		||||
);`
 | 
			
		||||
 | 
			
		||||
const createDomainStat = `CREATE TABLE IF NOT EXISTS styx_stat_domain (
 | 
			
		||||
	id     INT PRIMARY KEY,
 | 
			
		||||
	dt     DATETIME DEFAULT CURRENT_TIMESTAMP,
 | 
			
		||||
	domain TEXT NOT NULL
 | 
			
		||||
);`
 | 
			
		||||
@@ -1,16 +0,0 @@
 | 
			
		||||
package proxy
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"io"
 | 
			
		||||
	"net"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// connReader is a net.Conn with a separate reader.
 | 
			
		||||
type connReader struct {
 | 
			
		||||
	net.Conn
 | 
			
		||||
	io.Reader
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c connReader) Read(p []byte) (int, error) {
 | 
			
		||||
	return c.Reader.Read(p)
 | 
			
		||||
}
 | 
			
		||||
		Reference in New Issue
	
	Block a user