Checkpoint
This commit is contained in:
145
proxy/admin.go
145
proxy/admin.go
@@ -1,145 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"net/http"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.maze.io/maze/styx/internal/log"
|
||||
)
|
||||
|
||||
type Admin struct {
|
||||
*Proxy
|
||||
}
|
||||
|
||||
func NewAdmin(proxy *Proxy) *Admin {
|
||||
a := &Admin{
|
||||
Proxy: proxy,
|
||||
}
|
||||
return a
|
||||
}
|
||||
|
||||
func (a *Admin) handleRequest(ses *Session) error {
|
||||
var (
|
||||
logger = ses.log()
|
||||
err error
|
||||
)
|
||||
switch ses.request.URL.Path {
|
||||
case "/ca.crt":
|
||||
err = a.handleCACert(ses)
|
||||
case "/api/v1/policy":
|
||||
err = a.apiPolicy(ses)
|
||||
case "/api/v1/policy/matcher":
|
||||
err = a.apiPolicyMatcher(ses)
|
||||
case "/api/v1/stats/log":
|
||||
err = a.apiStatsLog(ses)
|
||||
case "/api/v1/stats/status":
|
||||
err = a.apiStatsStatus(ses)
|
||||
default:
|
||||
if strings.HasPrefix(ses.request.URL.Path, "/api") {
|
||||
err = errors.New("invalid endpoint")
|
||||
} else {
|
||||
err = os.ErrNotExist
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
logger.Warn().Err(err).Msg("admin error")
|
||||
ses.response = ErrorResponse(ses.request, err)
|
||||
defer log.OnCloseError(logger.Debug(), ses.response.Body)
|
||||
ses.response.Close = true
|
||||
return a.writeResponse(ses)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (a *Admin) handleCACert(ses *Session) error {
|
||||
b := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: a.authority.Certificate().Raw,
|
||||
})
|
||||
|
||||
ses.response = NewResponse(http.StatusOK, bytes.NewReader(b), ses.request)
|
||||
defer log.OnCloseError(log.Debug(), ses.response.Body)
|
||||
|
||||
ses.response.Close = true
|
||||
ses.response.Header.Set("Content-Type", "application/x-x509-ca-cert")
|
||||
ses.response.ContentLength = int64(len(b))
|
||||
return a.writeResponse(ses)
|
||||
}
|
||||
|
||||
func (a *Admin) apiPolicy(ses *Session) error {
|
||||
var (
|
||||
b = new(bytes.Buffer)
|
||||
e = json.NewEncoder(b)
|
||||
)
|
||||
e.SetIndent("", " ")
|
||||
if err := e.Encode(a.config.Policy); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ses.response = NewJSONResponse(http.StatusOK, b, ses.request)
|
||||
defer log.OnCloseError(log.Debug(), ses.response.Body)
|
||||
ses.response.Close = true
|
||||
return a.writeResponse(ses)
|
||||
}
|
||||
|
||||
func (a *Admin) apiPolicyMatcher(ses *Session) error {
|
||||
var (
|
||||
b = new(bytes.Buffer)
|
||||
e = json.NewEncoder(b)
|
||||
)
|
||||
e.SetIndent("", " ")
|
||||
if err := e.Encode(a.config.Policy.Matchers); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ses.response = NewJSONResponse(http.StatusOK, b, ses.request)
|
||||
defer log.OnCloseError(log.Debug(), ses.response.Body)
|
||||
ses.response.Close = true
|
||||
return a.writeResponse(ses)
|
||||
}
|
||||
|
||||
func (a *Admin) apiResponse(ses *Session, v any, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var (
|
||||
b = new(bytes.Buffer)
|
||||
e = json.NewEncoder(b)
|
||||
)
|
||||
e.SetIndent("", " ")
|
||||
if err := e.Encode(v); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ses.response = NewJSONResponse(http.StatusOK, b, ses.request)
|
||||
defer log.OnCloseError(log.Debug(), ses.response.Body)
|
||||
ses.response.Close = true
|
||||
return a.writeResponse(ses)
|
||||
|
||||
}
|
||||
|
||||
func (a *Admin) apiStatsLog(ses *Session) error {
|
||||
var (
|
||||
query = ses.request.URL.Query()
|
||||
offset, _ = strconv.Atoi(query.Get("offset"))
|
||||
limit, _ = strconv.Atoi(query.Get("limit"))
|
||||
)
|
||||
if limit > 100 {
|
||||
limit = 100
|
||||
}
|
||||
|
||||
s, err := a.stats.QueryLog(offset, limit)
|
||||
return a.apiResponse(ses, s, err)
|
||||
}
|
||||
|
||||
func (a *Admin) apiStatsStatus(ses *Session) error {
|
||||
s, err := a.stats.QueryStatus(time.Time{})
|
||||
return a.apiResponse(ses, s, err)
|
||||
}
|
8
proxy/cache/config.go
vendored
8
proxy/cache/config.go
vendored
@@ -1,8 +0,0 @@
|
||||
package cache
|
||||
|
||||
import "github.com/hashicorp/hcl/v2"
|
||||
|
||||
type Config struct {
|
||||
Type string `hcl:"type"`
|
||||
Body hcl.Body `hcl:",remain"`
|
||||
}
|
@@ -1,88 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"git.maze.io/maze/styx/proxy/policy"
|
||||
"git.maze.io/maze/styx/proxy/resolver"
|
||||
)
|
||||
|
||||
type ConnectHandler interface {
|
||||
HandleConnect(session *Session, network, address string) net.Conn
|
||||
}
|
||||
|
||||
// ConnectHandlerFunc is called when the proxy receives a new HTTP CONNECT request.
|
||||
type ConnectHandlerFunc func(session *Session, network, address string) net.Conn
|
||||
|
||||
func (f ConnectHandlerFunc) HandleConnect(session *Session, network, address string) net.Conn {
|
||||
return f(session, network, address)
|
||||
}
|
||||
|
||||
type RequestHandler interface {
|
||||
HandleRequest(session *Session) (*http.Request, *http.Response)
|
||||
}
|
||||
|
||||
// RequestHandlerFunc is called when the proxy receives a new request.
|
||||
type RequestHandlerFunc func(session *Session) (*http.Request, *http.Response)
|
||||
|
||||
func (f RequestHandlerFunc) HandleRequest(session *Session) (*http.Request, *http.Response) {
|
||||
return f(session)
|
||||
}
|
||||
|
||||
type ResponseHandler interface {
|
||||
HandleResponse(session *Session) *http.Response
|
||||
}
|
||||
|
||||
// ResponseHandler is called when the proxy receives a response.
|
||||
type ResponseHandlerFunc func(session *Session) *http.Response
|
||||
|
||||
func (f ResponseHandlerFunc) HandleResponse(session *Session) *http.Response {
|
||||
return f(session)
|
||||
}
|
||||
|
||||
type ErrorHandler interface {
|
||||
HandleError(session *Session, err error)
|
||||
}
|
||||
|
||||
type ErrorHandlerFunc func(session *Session, err error)
|
||||
|
||||
func (f ErrorHandlerFunc) HandleError(session *Session, err error) {
|
||||
f(session, err)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
// Listen address.
|
||||
Listen string `hcl:"listen,optional"`
|
||||
|
||||
// Bind address for outgoing connections.
|
||||
Bind string `hcl:"bind,optional"`
|
||||
|
||||
// Interface for outgoing connections.
|
||||
Interface string `hcl:"interface,optional"`
|
||||
|
||||
// Upstream proxy servers.
|
||||
Upstream []string `hcl:"upstream,optional"`
|
||||
|
||||
// DialTimeout for establishing new connections.
|
||||
DialTimeout time.Duration `hcl:"dial_timeout,optional"`
|
||||
|
||||
// Policy for the proxy.
|
||||
Policy *policy.Policy `hcl:"policy,block"`
|
||||
|
||||
// Resolver for the proxy.
|
||||
Resolver resolver.Resolver
|
||||
|
||||
ConnectHandler ConnectHandler
|
||||
RequestHandler RequestHandler
|
||||
ResponseHandler ResponseHandler
|
||||
ErrorHandler ErrorHandler
|
||||
}
|
||||
|
||||
var (
|
||||
_ ConnectHandler = (ConnectHandlerFunc)(nil)
|
||||
_ RequestHandler = (RequestHandlerFunc)(nil)
|
||||
_ ResponseHandler = (ResponseHandlerFunc)(nil)
|
||||
_ ErrorHandler = (ErrorHandlerFunc)(nil)
|
||||
)
|
227
proxy/context.go
Normal file
227
proxy/context.go
Normal file
@@ -0,0 +1,227 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"git.maze.io/maze/styx/internal/netutil/arp"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// Context provides convenience functions for the current ongoing HTTP proxy transaction (request).
|
||||
type Context interface {
|
||||
// Conn is the backing connection for this context.
|
||||
net.Conn
|
||||
|
||||
// ID is a unique connection identifier.
|
||||
ID() uint64
|
||||
|
||||
// Reader returns a buffered reader on top of the [net.Conn].
|
||||
Reader() *bufio.Reader
|
||||
|
||||
// BytesRead returns the number of bytes read.
|
||||
BytesRead() int64
|
||||
|
||||
// BytesSent returns the number of bytes written.
|
||||
BytesSent() int64
|
||||
|
||||
// TLSState returns the TLS connection state, it returns nil if the connection is not a TLS connection.
|
||||
TLSState() *tls.ConnectionState
|
||||
|
||||
// Request is the request made to the proxy.
|
||||
Request() *http.Request
|
||||
|
||||
// Response is the response that will be sent back to the client.
|
||||
Response() *http.Response
|
||||
}
|
||||
|
||||
type countingReader struct {
|
||||
reader io.Reader
|
||||
bytes int64
|
||||
}
|
||||
|
||||
func (r *countingReader) Read(p []byte) (n int, err error) {
|
||||
if n, err = r.reader.Read(p); n > 0 {
|
||||
atomic.AddInt64(&r.bytes, int64(n))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
type countingWriter struct {
|
||||
writer io.Writer
|
||||
bytes int64
|
||||
}
|
||||
|
||||
func (w *countingWriter) Write(p []byte) (n int, err error) {
|
||||
if n, err = w.writer.Write(p); n > 0 {
|
||||
atomic.AddInt64(&w.bytes, int64(n))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
type proxyContext struct {
|
||||
net.Conn
|
||||
id uint64
|
||||
mac net.HardwareAddr
|
||||
cr *countingReader
|
||||
br *bufio.Reader
|
||||
cw *countingWriter
|
||||
isTransparent bool
|
||||
isTransparentTLS bool
|
||||
serverName string
|
||||
req *http.Request
|
||||
res *http.Response
|
||||
idleTimeout time.Duration
|
||||
}
|
||||
|
||||
// NewContext returns an initialized context for the provided [net.Conn].
|
||||
func NewContext(c net.Conn) Context {
|
||||
if c, ok := c.(*proxyContext); ok {
|
||||
return c
|
||||
}
|
||||
|
||||
b := make([]byte, 8)
|
||||
io.ReadFull(rand.Reader, b)
|
||||
|
||||
cr := &countingReader{reader: c}
|
||||
cw := &countingWriter{writer: c}
|
||||
return &proxyContext{
|
||||
Conn: c,
|
||||
id: binary.BigEndian.Uint64(b),
|
||||
mac: arp.Get(c.RemoteAddr()),
|
||||
cr: cr,
|
||||
br: bufio.NewReader(cr),
|
||||
cw: cw,
|
||||
res: &http.Response{StatusCode: 200},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *proxyContext) AccessLogEntry() *logrus.Entry {
|
||||
var id [8]byte
|
||||
binary.BigEndian.PutUint64(id[:], c.id)
|
||||
entry := AccessLog.WithFields(logrus.Fields{
|
||||
"client": c.RemoteAddr().String(),
|
||||
"server": c.LocalAddr().String(),
|
||||
"id": hex.EncodeToString(id[:]),
|
||||
"bytes_rx": c.BytesRead(),
|
||||
"bytes_tx": c.BytesSent(),
|
||||
})
|
||||
if c.mac != nil {
|
||||
return entry.WithField("client_mac", c.mac.String())
|
||||
}
|
||||
return entry
|
||||
}
|
||||
|
||||
func (c *proxyContext) LogEntry() *logrus.Entry {
|
||||
var id [8]byte
|
||||
binary.BigEndian.PutUint64(id[:], c.id)
|
||||
return ServerLog.WithFields(logrus.Fields{
|
||||
"client": c.RemoteAddr().String(),
|
||||
"server": c.LocalAddr().String(),
|
||||
"id": hex.EncodeToString(id[:]),
|
||||
})
|
||||
}
|
||||
|
||||
func (c *proxyContext) String() string {
|
||||
return fmt.Sprintf("client=%s server=%s id=%#08x",
|
||||
c.RemoteAddr().String(),
|
||||
c.LocalAddr().String(),
|
||||
c.id)
|
||||
}
|
||||
|
||||
func (c *proxyContext) ID() uint64 {
|
||||
return c.id
|
||||
}
|
||||
|
||||
func (c *proxyContext) BytesRead() int64 {
|
||||
return atomic.LoadInt64(&c.cr.bytes)
|
||||
}
|
||||
|
||||
func (c *proxyContext) BytesSent() int64 {
|
||||
return atomic.LoadInt64(&c.cw.bytes)
|
||||
}
|
||||
|
||||
func (c *proxyContext) Read(p []byte) (n int, err error) {
|
||||
if c.idleTimeout > 0 {
|
||||
if err = c.SetReadDeadline(time.Now().Add(c.idleTimeout)); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
return c.br.Read(p)
|
||||
}
|
||||
|
||||
func (c *proxyContext) Write(p []byte) (n int, err error) {
|
||||
if c.idleTimeout > 0 {
|
||||
if err = c.SetWriteDeadline(time.Now().Add(c.idleTimeout)); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
return c.cw.Write(p)
|
||||
}
|
||||
|
||||
func (c *proxyContext) Reader() *bufio.Reader {
|
||||
return c.br
|
||||
}
|
||||
|
||||
func (c *proxyContext) Request() *http.Request {
|
||||
return c.req
|
||||
}
|
||||
|
||||
func (c *proxyContext) SetRequest(req *http.Request) {
|
||||
c.req = req
|
||||
}
|
||||
|
||||
func (c *proxyContext) Response() *http.Response {
|
||||
return c.res
|
||||
}
|
||||
|
||||
func (c *proxyContext) SetIdleTimeout(t time.Duration) {
|
||||
c.idleTimeout = t
|
||||
}
|
||||
|
||||
type connectionStater interface {
|
||||
ConnectionState() tls.ConnectionState
|
||||
}
|
||||
|
||||
func (c *proxyContext) TLSState() *tls.ConnectionState {
|
||||
if s, ok := c.Conn.(connectionStater); ok {
|
||||
state := s.ConnectionState()
|
||||
return &state
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// http.ResponseWriter interface:
|
||||
|
||||
func (c *proxyContext) Header() http.Header {
|
||||
if c.res == nil {
|
||||
c.res = NewResponse(http.StatusOK, nil, c.req)
|
||||
}
|
||||
return c.res.Header
|
||||
}
|
||||
|
||||
func (c *proxyContext) WriteHeader(code int) {
|
||||
if c.res == nil {
|
||||
c.res = NewResponse(code, nil, c.req)
|
||||
} else {
|
||||
if text := http.StatusText(code); text != "" {
|
||||
c.res.Status = strconv.Itoa(code) + " " + text
|
||||
} else {
|
||||
c.res.Status = strconv.Itoa(code)
|
||||
}
|
||||
c.res.StatusCode = code
|
||||
}
|
||||
//return c.res.Header.Write(c)
|
||||
}
|
||||
|
||||
var _ Context = (*proxyContext)(nil)
|
2
proxy/doc.go
Normal file
2
proxy/doc.go
Normal file
@@ -0,0 +1,2 @@
|
||||
// Package proxy contains a HTTP(s) (transparent) proxy server.
|
||||
package proxy
|
309
proxy/handler.go
Normal file
309
proxy/handler.go
Normal file
@@ -0,0 +1,309 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
|
||||
"git.maze.io/maze/styx/ca"
|
||||
"git.maze.io/maze/styx/internal/cryptutil"
|
||||
"git.maze.io/maze/styx/internal/netutil"
|
||||
)
|
||||
|
||||
type Dialer interface {
|
||||
DialContext(context.Context, *http.Request) (net.Conn, error)
|
||||
}
|
||||
|
||||
type defaultDialer struct{}
|
||||
|
||||
func (defaultDialer) DialContext(ctx context.Context, req *http.Request) (net.Conn, error) {
|
||||
if host := netutil.Host(req.URL.Host); host == "" {
|
||||
return nil, errors.New("proxy: host missing in address")
|
||||
}
|
||||
|
||||
var d = net.Dialer{
|
||||
Resolver: net.DefaultResolver,
|
||||
FallbackDelay: -1,
|
||||
}
|
||||
|
||||
// Ensure we have a port.
|
||||
switch req.URL.Scheme {
|
||||
case "http", "ws":
|
||||
req.URL.Host = netutil.EnsurePort(req.URL.Host, "80")
|
||||
case "https", "wss":
|
||||
req.URL.Host = netutil.EnsurePort(req.URL.Host, "443")
|
||||
}
|
||||
|
||||
// Resolve the host.
|
||||
if ips, err := d.Resolver.LookupIP(ctx, "ip", netutil.Host(req.URL.Host)); err != nil {
|
||||
return nil, err
|
||||
} else {
|
||||
for _, ip := range ips {
|
||||
switch {
|
||||
case ip.IsUnspecified():
|
||||
return nil, fmt.Errorf("proxy: host %s resolves to unspecified address (blocked by DNS?)", netutil.Host(req.URL.Host))
|
||||
case ip.IsLoopback():
|
||||
return nil, fmt.Errorf("proxy: host %s resolves to loopback address (blocked by DNS?)", netutil.Host(req.URL.Host))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Make the connection.
|
||||
switch req.URL.Scheme {
|
||||
case "tcp", "http", "ws":
|
||||
// Plain TCP client connection.
|
||||
return d.DialContext(ctx, "tcp", req.URL.Host)
|
||||
case "https", "wss":
|
||||
// Secure TLS client connection.
|
||||
c, err := d.DialContext(ctx, "tcp", req.URL.Host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s := tls.Client(c, new(tls.Config))
|
||||
return s, s.Handshake()
|
||||
default:
|
||||
return nil, fmt.Errorf("proxy: can't dial %s protocol", req.URL.Scheme)
|
||||
}
|
||||
}
|
||||
|
||||
// ConnFilter is called when a new connection has been accepted by the proxy.
|
||||
type ConnFilter interface {
|
||||
FilterConn(Context) (net.Conn, error)
|
||||
}
|
||||
|
||||
// ConnFilterFunc is a function that implements the [ConnFilter] interface.
|
||||
type ConnFilterFunc func(Context) (net.Conn, error)
|
||||
|
||||
func (f ConnFilterFunc) FilterConn(ctx Context) (net.Conn, error) {
|
||||
return f(ctx)
|
||||
}
|
||||
|
||||
// TLS starts a TLS handshake on the accepted connection.
|
||||
func TLS(certs []tls.Certificate) ConnFilter {
|
||||
return ConnFilterFunc(func(ctx Context) (net.Conn, error) {
|
||||
s := tls.Server(ctx, &tls.Config{
|
||||
Certificates: certs,
|
||||
NextProtos: []string{"http/1.1"},
|
||||
})
|
||||
if err := s.Handshake(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, nil
|
||||
})
|
||||
}
|
||||
|
||||
// TLSInterceptor can generate certificates on-the-fly for clients that use a compatible TLS version.
|
||||
func TLSInterceptor(ca ca.CertificateAuthority) ConnFilter {
|
||||
return ConnFilterFunc(func(ctx Context) (net.Conn, error) {
|
||||
s := tls.Server(ctx, &tls.Config{
|
||||
GetCertificate: func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
ips := []net.IP{net.ParseIP(netutil.Host(ctx.RemoteAddr().String()))}
|
||||
return ca.GetCertificate(hello.ServerName, []string{hello.ServerName}, ips)
|
||||
},
|
||||
NextProtos: []string{"http/1.1"},
|
||||
})
|
||||
if err := s.Handshake(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, nil
|
||||
})
|
||||
}
|
||||
|
||||
// Transparent can handle transparent HTTP(S) requests on the port.
|
||||
//
|
||||
// When a new [net.Conn] is made, this function will inspect the initial request packet for a
|
||||
// TLS handshake. If a TLS handshake is detected, the connection will make a feaux HTTP CONNECT
|
||||
// request using TLS, if no handshake is detected, it will make a feaux plain HTTP CONNECT request.
|
||||
func Transparent() ConnFilter {
|
||||
return ConnFilterFunc(func(nctx Context) (net.Conn, error) {
|
||||
ctx, ok := nctx.(*proxyContext)
|
||||
if !ok {
|
||||
return nctx, nil
|
||||
}
|
||||
|
||||
b := new(bytes.Buffer)
|
||||
hello, err := cryptutil.ReadClientHello(io.TeeReader(ctx, b))
|
||||
if err != nil {
|
||||
if _, ok := err.(tls.RecordHeaderError); !ok {
|
||||
ctx.LogEntry().WithError(err).WithField("error_type", fmt.Sprintf("%T", err)).Warn("TLS sniff error")
|
||||
return nil, err
|
||||
}
|
||||
// Not a TLS connection, moving on to regular HTTP request handling...
|
||||
ctx.LogEntry().Debug("HTTP connection on transparent port")
|
||||
ctx.isTransparent = true
|
||||
} else {
|
||||
ctx.LogEntry().WithField("target", hello.ServerName).Debug("TLS connection on transparent port")
|
||||
ctx.isTransparentTLS = true
|
||||
ctx.serverName = hello.ServerName
|
||||
}
|
||||
|
||||
return netutil.ReaderConn{
|
||||
Conn: ctx.Conn,
|
||||
Reader: io.MultiReader(bytes.NewReader(b.Bytes()), ctx.Conn),
|
||||
}, nil
|
||||
})
|
||||
}
|
||||
|
||||
// RequestFilter can filter HTTP requests coming to the proxy.
|
||||
type RequestFilter interface {
|
||||
// FilterRequest filters a HTTP request made to the proxy. The current request may be obtained
|
||||
// from [Context.Request]. If a previous RequestFilter provided a HTTP response, it is available
|
||||
// from [Context.Response].
|
||||
//
|
||||
// Modifications to the current request can be made to the Request returned by [Context.Request]
|
||||
// and do not require returning a new [http.Request].
|
||||
//
|
||||
// If the filter returns a non-nil [http.Response], then the [Request] will not be proxied to
|
||||
// any upstream server(s).
|
||||
FilterRequest(Context) (*http.Request, *http.Response)
|
||||
}
|
||||
|
||||
// RequestFilterFunc is a function that implements the [RequestFilter] interface.
|
||||
type RequestFilterFunc func(Context) (*http.Request, *http.Response)
|
||||
|
||||
func (f RequestFilterFunc) FilterRequest(ctx Context) (*http.Request, *http.Response) {
|
||||
return f(ctx)
|
||||
}
|
||||
|
||||
// ResponseFilter can filter HTTP responses coming from the proxy.
|
||||
type ResponseFilter interface {
|
||||
// FilterResponse filters a HTTP response coming from the proxy. The current response may be
|
||||
// obtained from [Context.Response].
|
||||
//
|
||||
// Modifications to the current response can be made to the [Response] returned by [Context.Response].
|
||||
FilterResponse(Context) *http.Response
|
||||
}
|
||||
|
||||
// ResponseFilterFunc is a function that implements the [ResponseFilter] interface.
|
||||
type ResponseFilterFunc func(Context) *http.Response
|
||||
|
||||
func (f ResponseFilterFunc) FilterResponse(ctx Context) *http.Response {
|
||||
return f(ctx)
|
||||
}
|
||||
|
||||
// CleanRequestProxyHeaders removes all headers added by downstream proxies from the [http.Request].
|
||||
func CleanRequestProxyHeaders() RequestFilter {
|
||||
return RequestFilterFunc(func(ctx Context) (*http.Request, *http.Response) {
|
||||
if req := ctx.Request(); req != nil {
|
||||
cleanProxyHeaders(req.Header)
|
||||
}
|
||||
return nil, nil
|
||||
})
|
||||
}
|
||||
|
||||
// CleanRequestProxyHeaders removes all headers for upstream proxies from the [http.Response].
|
||||
func CleanResponseProxyHeaders() ResponseFilter {
|
||||
return ResponseFilterFunc(func(ctx Context) *http.Response {
|
||||
if res := ctx.Response(); res != nil {
|
||||
cleanProxyHeaders(res.Header)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// AddRequestHeaders adds headers to the [http.Request]. Any existing headers with the same
|
||||
// key will remain intact.
|
||||
func AddRequestHeaders(h http.Header) RequestFilter {
|
||||
return RequestFilterFunc(func(ctx Context) (*http.Request, *http.Response) {
|
||||
if req := ctx.Request(); req != nil {
|
||||
if req.Header == nil {
|
||||
req.Header = make(http.Header)
|
||||
}
|
||||
addHeaders(req.Header, h)
|
||||
}
|
||||
return nil, nil
|
||||
})
|
||||
}
|
||||
|
||||
// SetRequestHeaders sets headers to the [http.Request]. Any existing headers with the same
|
||||
// key will be removed.
|
||||
func SetRequestHeaders(h http.Header) RequestFilter {
|
||||
return RequestFilterFunc(func(ctx Context) (*http.Request, *http.Response) {
|
||||
if req := ctx.Request(); req != nil {
|
||||
if req.Header == nil {
|
||||
req.Header = make(http.Header)
|
||||
}
|
||||
setHeaders(req.Header, h)
|
||||
}
|
||||
return nil, nil
|
||||
})
|
||||
}
|
||||
|
||||
// AddResponseHeaders adds headers to the [http.Response]. Any existing headers with the same
|
||||
// key will remain intact.
|
||||
func AddResponseHeaders(h http.Header) ResponseFilter {
|
||||
return ResponseFilterFunc(func(ctx Context) *http.Response {
|
||||
if res := ctx.Response(); res != nil {
|
||||
if res.Header == nil {
|
||||
res.Header = make(http.Header)
|
||||
}
|
||||
addHeaders(res.Header, h)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// SetResponseHeaders sets headers to the [http.Response]. Any existing headers with the same
|
||||
// key will be removed.
|
||||
func SetResponseHeaders(h http.Header) ResponseFilter {
|
||||
return ResponseFilterFunc(func(ctx Context) *http.Response {
|
||||
if res := ctx.Response(); res != nil {
|
||||
if res.Header == nil {
|
||||
res.Header = make(http.Header)
|
||||
}
|
||||
setHeaders(res.Header, h)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// cleanProxyHeaders removes all headers commonly used by (reverse) HTTP proxies.
|
||||
func cleanProxyHeaders(h http.Header) {
|
||||
if h == nil {
|
||||
return
|
||||
}
|
||||
|
||||
for _, k := range []string{
|
||||
HeaderForwarded,
|
||||
HeaderForwardedFor,
|
||||
HeaderForwardedHost,
|
||||
HeaderForwardedPort,
|
||||
HeaderForwardedProto,
|
||||
HeaderRealIP,
|
||||
HeaderVia,
|
||||
} {
|
||||
h.Del(k)
|
||||
}
|
||||
}
|
||||
|
||||
// addHeaders adds to the current existing headers.
|
||||
func addHeaders(dst, src http.Header) {
|
||||
if src == nil {
|
||||
return
|
||||
}
|
||||
|
||||
for k, vv := range src {
|
||||
for _, v := range vv {
|
||||
dst.Add(k, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// setHeaders replaces all previous values.
|
||||
func setHeaders(dst, src http.Header) {
|
||||
if src == nil {
|
||||
return
|
||||
}
|
||||
|
||||
for k, vv := range src {
|
||||
dst.Del(k)
|
||||
for _, v := range vv {
|
||||
dst.Add(k, v)
|
||||
}
|
||||
}
|
||||
}
|
@@ -1,324 +0,0 @@
|
||||
package match
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"regexp"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.maze.io/maze/styx/internal/log"
|
||||
"git.maze.io/maze/styx/internal/netutil"
|
||||
"github.com/hashicorp/hcl/v2"
|
||||
"github.com/hashicorp/hcl/v2/gohcl"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Path string `hcl:"path,optional"`
|
||||
Refresh time.Duration `hcl:"refresh,optional"`
|
||||
Domain []*Domain `hcl:"domain,block"`
|
||||
Network []*Network `hcl:"network,block"`
|
||||
Content []*Content `hcl:"content,block"`
|
||||
}
|
||||
|
||||
func (config Config) Matchers() (Matchers, error) {
|
||||
all := make(Matchers)
|
||||
if config.Domain != nil {
|
||||
all["domain"] = make(map[string]Matcher)
|
||||
for _, domain := range config.Domain {
|
||||
m, err := domain.Matcher()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("matcher domain %q invalid: %w", domain.Name, err)
|
||||
}
|
||||
all["domain"][domain.Name] = m
|
||||
}
|
||||
}
|
||||
if config.Network != nil {
|
||||
all["network"] = make(map[string]Matcher)
|
||||
for _, network := range config.Network {
|
||||
m, err := network.Matcher(true)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("matcher network %q invalid: %w", network.Name, err)
|
||||
}
|
||||
all["network"][network.Name] = m
|
||||
}
|
||||
}
|
||||
return all, nil
|
||||
}
|
||||
|
||||
type Content struct {
|
||||
Name string `hcl:"name,label"`
|
||||
Type string `hcl:"type"`
|
||||
Body hcl.Body `hcl:",remain"`
|
||||
}
|
||||
|
||||
type contentHeader struct {
|
||||
Key string `hcl:"name"`
|
||||
Value string `hcl:"value,optional"`
|
||||
List []string `hcl:"list,optional"`
|
||||
name string
|
||||
keyRe *regexp.Regexp
|
||||
valueRe *regexp.Regexp
|
||||
}
|
||||
|
||||
func (m contentHeader) Name() string { return m.name }
|
||||
func (m contentHeader) MatchesResponse(r *http.Response) bool {
|
||||
for k, vv := range r.Header {
|
||||
if m.keyRe.MatchString(k) {
|
||||
for _, v := range vv {
|
||||
if slices.Contains(m.List, v) {
|
||||
return true
|
||||
}
|
||||
if m.valueRe != nil && m.valueRe.MatchString(v) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type contentType struct {
|
||||
List []string `hcl:"list"`
|
||||
name string
|
||||
}
|
||||
|
||||
func (m contentType) Name() string { return m.name }
|
||||
func (m contentType) MatchesResponse(r *http.Response) bool {
|
||||
return slices.Contains(m.List, r.Header.Get("Content-Type"))
|
||||
}
|
||||
|
||||
type contentSizeLargerThan struct {
|
||||
Size int64 `hcl:"size"`
|
||||
name string
|
||||
}
|
||||
|
||||
func (m contentSizeLargerThan) Name() string { return m.name }
|
||||
func (m contentSizeLargerThan) MatchesResponse(r *http.Response) bool {
|
||||
size, err := strconv.ParseInt(r.Header.Get("Content-Length"), 10, 64)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return size >= m.Size
|
||||
}
|
||||
|
||||
type contentStatus struct {
|
||||
Code []int `hcl:"code"`
|
||||
name string
|
||||
}
|
||||
|
||||
func (m contentStatus) Name() string { return m.name }
|
||||
func (m contentStatus) MatchesResponse(r *http.Response) bool {
|
||||
return slices.Contains(m.Code, r.StatusCode)
|
||||
}
|
||||
|
||||
func (config Content) Matcher() (Response, error) {
|
||||
switch strings.ToLower(config.Type) {
|
||||
case "content", "contenttype", "content-type", "type":
|
||||
var matcher = contentType{name: config.Name}
|
||||
if err := gohcl.DecodeBody(config.Body, nil, &matcher); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return matcher, nil
|
||||
|
||||
case "header":
|
||||
var (
|
||||
matcher = contentHeader{name: config.Name}
|
||||
err error
|
||||
)
|
||||
if err = gohcl.DecodeBody(config.Body, nil, &matcher); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if matcher.Value == "" && len(matcher.List) == 0 {
|
||||
return nil, fmt.Errorf("invalid content %q: must contain either list or value", config.Name)
|
||||
}
|
||||
if matcher.keyRe, err = regexp.Compile(matcher.Key); err != nil {
|
||||
return nil, fmt.Errorf("invalid regular expression on content %q key: %w", config.Name, err)
|
||||
}
|
||||
if matcher.Value != "" {
|
||||
if matcher.valueRe, err = regexp.Compile(matcher.Value); err != nil {
|
||||
return nil, fmt.Errorf("invalid regular expression on content %q value: %w", config.Name, err)
|
||||
}
|
||||
}
|
||||
return matcher, nil
|
||||
|
||||
case "size":
|
||||
var matcher = contentSizeLargerThan{name: config.Name}
|
||||
if err := gohcl.DecodeBody(config.Body, nil, &matcher); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return matcher, nil
|
||||
|
||||
case "status":
|
||||
var matcher = contentStatus{name: config.Name}
|
||||
if err := gohcl.DecodeBody(config.Body, nil, &matcher); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return matcher, nil
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown content matcher type %q", config.Type)
|
||||
}
|
||||
}
|
||||
|
||||
type Domain struct {
|
||||
Name string `hcl:"name,label"`
|
||||
Type string `hcl:"type"`
|
||||
Body hcl.Body `hcl:",remain"`
|
||||
}
|
||||
|
||||
func (config Domain) Matcher() (Request, error) {
|
||||
switch config.Type {
|
||||
case "list":
|
||||
var matcher = domainList{Title: config.Name}
|
||||
if err := gohcl.DecodeBody(config.Body, nil, &matcher); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
matcher.list = netutil.NewDomainList(matcher.List...)
|
||||
return matcher, nil
|
||||
|
||||
case "adblock", "dnsmasq", "hosts", "detect", "domains":
|
||||
var matcher = DomainFile{
|
||||
Title: config.Name,
|
||||
Type: config.Type,
|
||||
}
|
||||
if err := gohcl.DecodeBody(config.Body, nil, &matcher); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if matcher.Path == "" && matcher.From == "" {
|
||||
return nil, fmt.Errorf("matcher: domain %q must have either file or from configured", config.Name)
|
||||
}
|
||||
if err := matcher.Update(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return matcher, nil
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown domain matcher type %q", config.Type)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
type domainList struct {
|
||||
Title string `json:"title"`
|
||||
List []string `hcl:"list" json:"list"`
|
||||
list *netutil.DomainTree
|
||||
}
|
||||
|
||||
func (m domainList) Name() string {
|
||||
return m.Title
|
||||
}
|
||||
|
||||
func (m domainList) MatchesRequest(r *http.Request) bool {
|
||||
host := netutil.Host(r.URL.Host)
|
||||
log.Debug().Str("host", host).Msgf("match domain list (%d domains)", len(m.List))
|
||||
return m.list.Contains(host)
|
||||
}
|
||||
|
||||
type DomainFile struct {
|
||||
Title string `json:"name"`
|
||||
Type string `json:"type"`
|
||||
Path string `hcl:"path,optional" json:"path,omitempty"`
|
||||
From string `hcl:"from,optional" json:"from,omitempty"`
|
||||
Refresh time.Duration `hcl:"refresh,optional" json:"refresh"`
|
||||
}
|
||||
|
||||
func (m DomainFile) Name() string {
|
||||
return m.Title
|
||||
}
|
||||
|
||||
func (m DomainFile) MatchesRequest(_ *http.Request) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *DomainFile) Update() (err error) {
|
||||
var data []byte
|
||||
if m.Path != "" {
|
||||
if data, err = os.ReadFile(m.Path); err != nil {
|
||||
return
|
||||
}
|
||||
} else {
|
||||
/*
|
||||
var response *http.Response
|
||||
if response, err = http.DefaultClient.Get(m.From); err != nil {
|
||||
return
|
||||
}
|
||||
defer func() { _ = response.Body.Close() }()
|
||||
if response.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("match: domain %q update failed: %s", m.name, response.Status)
|
||||
}
|
||||
if data, err = io.ReadAll(response.Body); err != nil {
|
||||
return
|
||||
}
|
||||
*/
|
||||
}
|
||||
|
||||
switch m.Type {
|
||||
case "hosts":
|
||||
}
|
||||
|
||||
_ = data
|
||||
return nil
|
||||
}
|
||||
|
||||
type Network struct {
|
||||
Name string `hcl:"name,label"`
|
||||
Type string `hcl:"type"`
|
||||
Body hcl.Body `hcl:",remain"`
|
||||
}
|
||||
|
||||
func (config *Network) Matcher(target bool) (Matcher, error) {
|
||||
switch config.Type {
|
||||
case "list":
|
||||
var (
|
||||
matcher = networkList{Title: config.Name}
|
||||
err error
|
||||
)
|
||||
if diag := gohcl.DecodeBody(config.Body, nil, &matcher); diag.HasErrors() {
|
||||
return nil, diag
|
||||
}
|
||||
if matcher.tree, err = netutil.NewNetworkTree(matcher.List...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &matcher, nil
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown network matcher type %q", config.Type)
|
||||
}
|
||||
}
|
||||
|
||||
type networkList struct {
|
||||
Title string `json:"name"`
|
||||
List []string `hcl:"list" json:"list"`
|
||||
tree *netutil.NetworkTree
|
||||
target bool
|
||||
}
|
||||
|
||||
func (m *networkList) Name() string {
|
||||
return m.Title
|
||||
}
|
||||
|
||||
func (m *networkList) MatchesIP(ip net.IP) bool {
|
||||
return m.tree.Contains(ip)
|
||||
}
|
||||
|
||||
func (m *networkList) MatchesRequest(r *http.Request) bool {
|
||||
var (
|
||||
host string
|
||||
err error
|
||||
)
|
||||
if m.target {
|
||||
host, _, err = net.SplitHostPort(r.URL.Host)
|
||||
} else {
|
||||
host, _, err = net.SplitHostPort(r.RemoteAddr)
|
||||
}
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
ip := net.ParseIP(host)
|
||||
return m.MatchesIP(ip)
|
||||
}
|
@@ -1,45 +0,0 @@
|
||||
package match
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type Matchers map[string]map[string]Matcher
|
||||
|
||||
func (all Matchers) Get(kind, name string) (m Matcher, err error) {
|
||||
if typeMatchers, ok := all[kind]; ok {
|
||||
if m, ok = typeMatchers[name]; ok {
|
||||
return
|
||||
}
|
||||
return nil, fmt.Errorf("no %s matcher named %q found", kind, name)
|
||||
}
|
||||
return nil, fmt.Errorf("no %s matcher found", kind)
|
||||
}
|
||||
|
||||
type Matcher interface {
|
||||
Name() string
|
||||
}
|
||||
|
||||
type Updater interface {
|
||||
Update() error
|
||||
}
|
||||
|
||||
type IP interface {
|
||||
Matcher
|
||||
|
||||
MatchesIP(net.IP) bool
|
||||
}
|
||||
|
||||
type Request interface {
|
||||
Matcher
|
||||
|
||||
MatchesRequest(*http.Request) bool
|
||||
}
|
||||
|
||||
type Response interface {
|
||||
Matcher
|
||||
|
||||
MatchesResponse(*http.Response) bool
|
||||
}
|
@@ -1,11 +0,0 @@
|
||||
package match
|
||||
|
||||
import "net"
|
||||
|
||||
func onlyHost(name string) string {
|
||||
host, _, err := net.SplitHostPort(name)
|
||||
if err != nil {
|
||||
return name
|
||||
}
|
||||
return host
|
||||
}
|
@@ -1,231 +0,0 @@
|
||||
package mitm
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.maze.io/maze/styx/internal/cryptutil"
|
||||
"git.maze.io/maze/styx/internal/log"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
const DefaultValidity = 24 * time.Hour
|
||||
|
||||
type Authority interface {
|
||||
Certificate() *x509.Certificate
|
||||
TLSConfig(name string) *tls.Config
|
||||
}
|
||||
|
||||
type authority struct {
|
||||
pool *x509.CertPool
|
||||
cert *x509.Certificate
|
||||
key crypto.PrivateKey
|
||||
keyID []byte
|
||||
keyPool chan crypto.PrivateKey
|
||||
cache Cache
|
||||
}
|
||||
|
||||
func New(config *Config) (Authority, error) {
|
||||
cache, err := NewCache(config.Cache)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
caConfig := config.CA
|
||||
if caConfig == nil {
|
||||
caConfig = new(CAConfig)
|
||||
}
|
||||
|
||||
cert, key, err := cryptutil.LoadKeyPair(caConfig.Cert, caConfig.Key)
|
||||
if os.IsNotExist(err) {
|
||||
days := caConfig.Days
|
||||
if days == 0 {
|
||||
days = DefaultDays
|
||||
}
|
||||
if cert, key, err = cryptutil.GenerateKeyPair(caConfig.DN(), days, caConfig.KeyType, caConfig.Bits); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if strings.ContainsRune(caConfig.Cert, os.PathSeparator) {
|
||||
if err = cryptutil.SaveKeyPair(cert, key, caConfig.Cert, caConfig.Key); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pool := x509.NewCertPool()
|
||||
pool.AddCert(cert)
|
||||
|
||||
keyConfig := config.Key
|
||||
if keyConfig == nil {
|
||||
keyConfig = &defaultKeyConfig
|
||||
}
|
||||
|
||||
keyPoolSize := defaultKeyConfig.Pool
|
||||
if keyConfig.Pool > 0 {
|
||||
keyPoolSize = keyConfig.Pool
|
||||
}
|
||||
keyPool := make(chan crypto.PrivateKey, keyPoolSize)
|
||||
if key, err := cryptutil.GeneratePrivateKey(keyConfig.Type, keyConfig.Bits); err != nil {
|
||||
return nil, fmt.Errorf("mitm: invalid key configuration: %w", err)
|
||||
} else {
|
||||
keyPool <- key
|
||||
}
|
||||
|
||||
go func(pool chan<- crypto.PrivateKey) {
|
||||
for {
|
||||
key, err := cryptutil.GeneratePrivateKey(keyConfig.Type, keyConfig.Bits)
|
||||
if err != nil {
|
||||
log.Panic().Err(err).Msg("error generating private key")
|
||||
}
|
||||
pool <- key
|
||||
}
|
||||
}(keyPool)
|
||||
|
||||
return &authority{
|
||||
pool: pool,
|
||||
cert: cert,
|
||||
key: key,
|
||||
keyID: cryptutil.GenerateKeyID(cryptutil.PublicKey(key)),
|
||||
keyPool: keyPool,
|
||||
cache: cache,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (ca *authority) log() log.Logger {
|
||||
return log.Console.With().
|
||||
Str("ca", ca.cert.Subject.String()).
|
||||
Logger()
|
||||
}
|
||||
|
||||
func (ca *authority) Certificate() *x509.Certificate {
|
||||
return ca.cert
|
||||
}
|
||||
|
||||
func (ca *authority) TLSConfig(name string) *tls.Config {
|
||||
return &tls.Config{
|
||||
GetCertificate: func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
log := ca.log()
|
||||
if hello.ServerName != "" {
|
||||
name = strings.ToLower(hello.ServerName)
|
||||
log.Debug().Msg("requesting certificate for server name (SNI)")
|
||||
} else {
|
||||
log.Debug().Msg("requesting certificate for hostname")
|
||||
}
|
||||
if cert, ok := ca.getCached(name); ok {
|
||||
log.Debug().
|
||||
Str("subject", cert.Leaf.Subject.String()).
|
||||
Str("serial", cert.Leaf.SerialNumber.String()).
|
||||
Time("valid", cert.Leaf.NotAfter).
|
||||
Msg("using cached certificate")
|
||||
return cert, nil
|
||||
}
|
||||
return ca.issueFor(name)
|
||||
},
|
||||
NextProtos: []string{"http/1.1"},
|
||||
}
|
||||
}
|
||||
|
||||
func (ca *authority) getCached(name string) (cert *tls.Certificate, ok bool) {
|
||||
log := ca.log()
|
||||
|
||||
if cert = ca.cache.Certificate(name); cert == nil {
|
||||
if baseDomain(name) != name {
|
||||
cert = ca.cache.Certificate(baseDomain(name))
|
||||
}
|
||||
}
|
||||
if cert != nil {
|
||||
if _, err := cert.Leaf.Verify(x509.VerifyOptions{
|
||||
DNSName: name,
|
||||
Roots: ca.pool,
|
||||
}); err != nil {
|
||||
log.Debug().Err(err).Str("name", name).Msg("deleting invalid certificate from cache")
|
||||
} else {
|
||||
ok = true
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (ca *authority) issueFor(name string) (*tls.Certificate, error) {
|
||||
var (
|
||||
log = ca.log().With().Str("name", name).Logger()
|
||||
key crypto.PrivateKey
|
||||
)
|
||||
select {
|
||||
case key = <-ca.keyPool:
|
||||
case <-time.After(5 * time.Second):
|
||||
return nil, errors.New("mitm: timeout waiting for private key generator to catch up")
|
||||
}
|
||||
if key == nil {
|
||||
panic("key pool returned nil key")
|
||||
}
|
||||
|
||||
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
|
||||
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("mtim: failed to generate serial number: %w", err)
|
||||
}
|
||||
|
||||
if part := dns.SplitDomainName(name); len(part) > 2 {
|
||||
name = strings.Join(part[1:], ".")
|
||||
log.Debug().Msgf("abbreviated name to %s (*.%s)", name, name)
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
template := &x509.Certificate{
|
||||
SerialNumber: serialNumber,
|
||||
Subject: pkix.Name{CommonName: name},
|
||||
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyAgreement,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
DNSNames: []string{name, "*." + name},
|
||||
BasicConstraintsValid: true,
|
||||
NotBefore: now.Add(-DefaultValidity),
|
||||
NotAfter: now.Add(+DefaultValidity),
|
||||
}
|
||||
der, err := x509.CreateCertificate(rand.Reader, template, ca.cert, cryptutil.PublicKey(key), ca.key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cert, err := x509.ParseCertificate(der)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Debug().Str("serial", serialNumber.String()).Msg("generated certificate")
|
||||
out := &tls.Certificate{
|
||||
Certificate: [][]byte{der},
|
||||
Leaf: cert,
|
||||
PrivateKey: key,
|
||||
}
|
||||
//ca.cache[name] = out
|
||||
ca.cache.SaveCertificate(name, out)
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func containsValidCertificate(cert *tls.Certificate) bool {
|
||||
if cert == nil || len(cert.Certificate) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
if cert.Leaf == nil {
|
||||
var err error
|
||||
if cert.Leaf, err = x509.ParseCertificate(cert.Certificate[0]); err != nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
|
||||
return !(cert.Leaf.NotBefore.Before(now) || cert.Leaf.NotAfter.After(now))
|
||||
}
|
@@ -1,233 +0,0 @@
|
||||
package mitm
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/golang-lru/v2/expirable"
|
||||
"github.com/hashicorp/hcl/v2/gohcl"
|
||||
"github.com/miekg/dns"
|
||||
|
||||
"git.maze.io/maze/styx/internal/cryptutil"
|
||||
"git.maze.io/maze/styx/internal/log"
|
||||
)
|
||||
|
||||
type Cache interface {
|
||||
Certificate(name string) *tls.Certificate
|
||||
SaveCertificate(name string, cert *tls.Certificate) error
|
||||
RemoveCertificate(name string)
|
||||
}
|
||||
|
||||
func NewCache(config *CacheConfig) (Cache, error) {
|
||||
if config == nil {
|
||||
return NewCache(&CacheConfig{Type: "memory"})
|
||||
}
|
||||
switch config.Type {
|
||||
case "memory":
|
||||
var cacheConfig = new(MemoryCacheConfig)
|
||||
if err := gohcl.DecodeBody(config.Body, nil, cacheConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewMemoryCache(cacheConfig.Size), nil
|
||||
case "disk":
|
||||
var cacheConfig = new(DiskCacheConfig)
|
||||
if err := gohcl.DecodeBody(config.Body, nil, cacheConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewDiskCache(cacheConfig.Path, time.Duration(cacheConfig.Expire*float64(time.Second)))
|
||||
default:
|
||||
return nil, fmt.Errorf("mitm: cache type %q is not supported", config.Type)
|
||||
}
|
||||
}
|
||||
|
||||
type memoryCache struct {
|
||||
cache *expirable.LRU[string, *tls.Certificate]
|
||||
}
|
||||
|
||||
func NewMemoryCache(size int) Cache {
|
||||
return memoryCache{
|
||||
cache: expirable.NewLRU(size, func(key string, value *tls.Certificate) {
|
||||
log.Debug().Str("name", key).Msg("certificate evicted from cache")
|
||||
}, time.Hour*24),
|
||||
}
|
||||
}
|
||||
|
||||
func (c memoryCache) Certificate(name string) (cert *tls.Certificate) {
|
||||
var ok bool
|
||||
if cert, ok = c.cache.Get(name); !ok {
|
||||
cert, _ = c.cache.Get(baseDomain(name))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (c memoryCache) SaveCertificate(name string, cert *tls.Certificate) error {
|
||||
c.cache.Add(name, cert)
|
||||
log.Debug().Str("name", name).Msg("certificate added to cache")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c memoryCache) RemoveCertificate(name string) {
|
||||
c.cache.Remove(name)
|
||||
}
|
||||
|
||||
type diskCache string
|
||||
|
||||
func NewDiskCache(dir string, expire time.Duration) (Cache, error) {
|
||||
if !filepath.IsAbs(dir) {
|
||||
var err error
|
||||
if dir, err = filepath.Abs(dir); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if err := os.MkdirAll(dir, 0o750); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
info, err := os.Stat(dir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if info.Mode()&os.ModePerm|0o057 != 0 {
|
||||
if err := os.Chmod(dir, 0o750); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if expire > 0 {
|
||||
go expireDiskCache(dir, expire)
|
||||
}
|
||||
|
||||
return diskCache(dir), nil
|
||||
}
|
||||
|
||||
func expireDiskCache(root string, expire time.Duration) {
|
||||
log.Debug().Str("path", root).Dur("expire", expire).Msg("disk cache expire loop starting")
|
||||
ticker := time.NewTicker(expire)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
now := <-ticker.C
|
||||
log.Debug().Str("path", root).Dur("expire", expire).Msg("expire disk cache")
|
||||
filepath.WalkDir(root, func(path string, d fs.DirEntry, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if d.IsDir() {
|
||||
// Remove the directory; this will fail if it's not empty, which is fine.
|
||||
_ = os.Remove(path)
|
||||
return nil
|
||||
}
|
||||
|
||||
cert, err := cryptutil.LoadCertificate(path)
|
||||
if err != nil {
|
||||
log.Debug().Str("path", path).Err(err).Msg("expire removing invalid certificate file")
|
||||
_ = os.Remove(path)
|
||||
return nil
|
||||
} else if cert.NotAfter.Before(now) {
|
||||
log.Debug().Str("path", path).Dur("expired", now.Sub(cert.NotAfter)).Msg("expire removing expired certificate")
|
||||
_ = os.Remove(path)
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (c diskCache) path(name string) string {
|
||||
part := dns.SplitDomainName(strings.ToLower(name))
|
||||
// x,com -> com,x
|
||||
// www,maze,io -> io,maze,www
|
||||
slices.Reverse(part)
|
||||
// com,x -> com,x,x.com
|
||||
// io,maze,www -> io,m,ma,maze,www.maze.io
|
||||
if len(part) > 2 {
|
||||
if len(part[1]) > 1 {
|
||||
part = []string{
|
||||
part[0],
|
||||
part[1][:1],
|
||||
part[1][:2],
|
||||
part[1],
|
||||
name,
|
||||
}
|
||||
} else {
|
||||
part = []string{
|
||||
part[0],
|
||||
part[1][:1],
|
||||
part[1],
|
||||
name,
|
||||
}
|
||||
}
|
||||
} else if len(part) > 1 {
|
||||
if len(part[1]) > 1 {
|
||||
part = []string{
|
||||
part[0],
|
||||
part[1][:1],
|
||||
part[1][:2],
|
||||
name,
|
||||
}
|
||||
} else {
|
||||
part = []string{
|
||||
part[0],
|
||||
part[1][:1],
|
||||
name,
|
||||
}
|
||||
}
|
||||
}
|
||||
part[len(part)-1] += ".crt"
|
||||
return filepath.Join(append([]string{string(c)}, part...)...)
|
||||
}
|
||||
|
||||
func (c diskCache) Certificate(name string) (cert *tls.Certificate) {
|
||||
if cert, key, err := cryptutil.LoadKeyPair(c.path(name), ""); err == nil {
|
||||
return &tls.Certificate{
|
||||
Certificate: [][]byte{cert.Raw},
|
||||
Leaf: cert,
|
||||
PrivateKey: key,
|
||||
}
|
||||
}
|
||||
if cert, key, err := cryptutil.LoadKeyPair(c.path(baseDomain(name)), ""); err == nil {
|
||||
return &tls.Certificate{
|
||||
Certificate: [][]byte{cert.Raw},
|
||||
Leaf: cert,
|
||||
PrivateKey: key,
|
||||
}
|
||||
}
|
||||
log.Debug().Str("path", string(c)).Str("name", name).Msg("cache miss")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c diskCache) SaveCertificate(name string, cert *tls.Certificate) error {
|
||||
dir, name := filepath.Split(c.path(name))
|
||||
if err := os.MkdirAll(dir, 0o750); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := cryptutil.SaveKeyPair(cert.Leaf, cert.PrivateKey, filepath.Join(dir, name), ""); err != nil {
|
||||
return err
|
||||
}
|
||||
log.Debug().Str("name", name).Msg("certificate added to cache")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c diskCache) RemoveCertificate(name string) {
|
||||
path := c.path(name)
|
||||
if err := os.Remove(path); err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return
|
||||
}
|
||||
log.Error().Err(err).Msg("certificate remove from cache failed")
|
||||
}
|
||||
_ = os.Remove(filepath.Dir(path))
|
||||
log.Debug().Str("name", name).Msg("certificate removed from cache")
|
||||
}
|
||||
|
||||
func baseDomain(name string) string {
|
||||
name = strings.ToLower(name)
|
||||
if part := dns.SplitDomainName(name); len(part) > 2 {
|
||||
return strings.Join(part[1:], ".")
|
||||
}
|
||||
return name
|
||||
}
|
@@ -1,25 +0,0 @@
|
||||
package mitm
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestDiskCachePath(t *testing.T) {
|
||||
cache := diskCache("testdata")
|
||||
tests := []struct {
|
||||
test string
|
||||
want string
|
||||
}{
|
||||
{"x.com", "testdata/com/x/x.com.crt"},
|
||||
{"feed.x.com", "testdata/com/x/x/feed.x.com.crt"},
|
||||
{"nu.nl", "testdata/nl/n/nu/nu.nl.crt"},
|
||||
{"maze.io", "testdata/io/m/ma/maze.io.crt"},
|
||||
{"lab.maze.io", "testdata/io/m/ma/maze/lab.maze.io.crt"},
|
||||
{"dev.lab.maze.io", "testdata/io/m/ma/maze/dev.lab.maze.io.crt"},
|
||||
}
|
||||
for _, test := range tests {
|
||||
t.Run(test.test, func(it *testing.T) {
|
||||
if v := cache.path(test.test); v != test.want {
|
||||
it.Errorf("expected %q to resolve to %q, got %q", test.test, test.want, v)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@@ -1,89 +0,0 @@
|
||||
package mitm
|
||||
|
||||
import (
|
||||
"crypto/x509/pkix"
|
||||
|
||||
"github.com/hashicorp/hcl/v2"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultCommonName = "Styx Certificate Authority"
|
||||
DefaultDays = 3
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
CA *CAConfig `hcl:"ca,block"`
|
||||
Key *KeyConfig `hcl:"key,block"`
|
||||
Cache *CacheConfig `hcl:"cache,block"`
|
||||
}
|
||||
|
||||
type CAConfig struct {
|
||||
Cert string `hcl:"cert"`
|
||||
Key string `hcl:"key,optional"`
|
||||
Days int `hcl:"days,optional"`
|
||||
KeyType string `hcl:"key_type,optional"`
|
||||
Bits int `hcl:"bits,optional"`
|
||||
Name string `hcl:"name,optional"`
|
||||
Country string `hcl:"country,optional"`
|
||||
Organization string `hcl:"organization,optional"`
|
||||
Unit string `hcl:"unit,optional"`
|
||||
Locality string `hcl:"locality,optional"`
|
||||
Province string `hcl:"province,optional"`
|
||||
Address []string `hcl:"address,optional"`
|
||||
PostalCode string `hcl:"postal_code,optional"`
|
||||
}
|
||||
|
||||
func (config CAConfig) DN() pkix.Name {
|
||||
var name = pkix.Name{
|
||||
CommonName: config.Name,
|
||||
StreetAddress: config.Address,
|
||||
}
|
||||
if config.Name == "" {
|
||||
name.CommonName = DefaultCommonName
|
||||
}
|
||||
if config.Country != "" {
|
||||
name.Country = append(name.Country, config.Country)
|
||||
}
|
||||
if config.Organization != "" {
|
||||
name.Organization = append(name.Organization, config.Organization)
|
||||
}
|
||||
if config.Unit != "" {
|
||||
name.OrganizationalUnit = append(name.OrganizationalUnit, config.Unit)
|
||||
}
|
||||
if config.Locality != "" {
|
||||
name.Locality = append(name.Locality, config.Locality)
|
||||
}
|
||||
if config.Province != "" {
|
||||
name.Province = append(name.Province, config.Province)
|
||||
}
|
||||
if config.PostalCode != "" {
|
||||
name.PostalCode = append(name.PostalCode, config.PostalCode)
|
||||
}
|
||||
return name
|
||||
}
|
||||
|
||||
type KeyConfig struct {
|
||||
Type string `hcl:"type,optional"`
|
||||
Bits int `hcl:"bits,optional"`
|
||||
Pool int `hcl:"pool,optional"`
|
||||
}
|
||||
|
||||
var defaultKeyConfig = KeyConfig{
|
||||
Type: "rsa",
|
||||
Bits: 2048,
|
||||
Pool: 5,
|
||||
}
|
||||
|
||||
type CacheConfig struct {
|
||||
Type string `hcl:"type"`
|
||||
Body hcl.Body `hcl:",remain"`
|
||||
}
|
||||
|
||||
type MemoryCacheConfig struct {
|
||||
Size int `hcl:"size,optional"`
|
||||
}
|
||||
|
||||
type DiskCacheConfig struct {
|
||||
Path string `hcl:"path"`
|
||||
Expire float64 `hcl:"expire,optional"`
|
||||
}
|
@@ -1,53 +0,0 @@
|
||||
package policy
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"git.maze.io/maze/styx/proxy/match"
|
||||
)
|
||||
|
||||
// Policy contains rules that make up the policy.
|
||||
//
|
||||
// Some policy rules contain nested policies.
|
||||
type Policy struct {
|
||||
Rules []*rawRule `hcl:"on,block" json:"rules"`
|
||||
Permit *bool `hcl:"permit" json:"permit"`
|
||||
Matchers match.Matchers `json:"matchers"` // Matchers for the policy
|
||||
|
||||
}
|
||||
|
||||
func (p *Policy) Configure(matchers match.Matchers) (err error) {
|
||||
for _, r := range p.Rules {
|
||||
if err = r.Configure(matchers); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
p.Matchers = matchers
|
||||
return
|
||||
}
|
||||
|
||||
func (p *Policy) PermitIntercept(r *http.Request) *bool {
|
||||
if p != nil {
|
||||
for _, rule := range p.Rules {
|
||||
if rule, ok := rule.Rule.(InterceptRule); ok {
|
||||
if permit := rule.PermitIntercept(r); permit != nil {
|
||||
return permit
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return p.Permit
|
||||
}
|
||||
|
||||
func (p *Policy) PermitRequest(r *http.Request) *bool {
|
||||
if p != nil {
|
||||
for _, rule := range p.Rules {
|
||||
if rule, ok := rule.Rule.(RequestRule); ok {
|
||||
if permit := rule.PermitRequest(r); permit != nil {
|
||||
return permit
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return p.Permit
|
||||
}
|
@@ -1,139 +0,0 @@
|
||||
package policy
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"git.maze.io/maze/styx/internal/netutil"
|
||||
"git.maze.io/maze/styx/proxy/match"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
type testInDomainList struct {
|
||||
t *testing.T
|
||||
list []string
|
||||
}
|
||||
|
||||
func (testInDomainList) Name() string { return "testInDomainList" }
|
||||
func (l testInDomainList) MatchesRequest(r *http.Request) bool {
|
||||
for _, domain := range l.list {
|
||||
if dns.IsSubDomain(domain, netutil.Host(r.URL.Host)) {
|
||||
l.t.Logf("domain %s contains %s", domain, r.URL.Host)
|
||||
return true
|
||||
}
|
||||
l.t.Logf("domain %s does not contain %s", domain, r.URL.Host)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func testInDomain(t *testing.T, domains ...string) match.Matcher {
|
||||
return &testInDomainList{t: t, list: domains}
|
||||
}
|
||||
|
||||
type testInNetworkList struct {
|
||||
t *testing.T
|
||||
list []*net.IPNet
|
||||
}
|
||||
|
||||
func (testInNetworkList) Name() string { return "testInNetworkList" }
|
||||
func (l testInNetworkList) MatchesIP(ip net.IP) bool {
|
||||
for _, ipnet := range l.list {
|
||||
if ipnet.Contains(ip) {
|
||||
l.t.Logf("network %s contains %s", ipnet, ip)
|
||||
return true
|
||||
}
|
||||
l.t.Logf("network %s does not contain %s", ipnet, ip)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func testInNetwork(t *testing.T, cidr string) match.Matcher {
|
||||
t.Helper()
|
||||
_, ipnet, err := net.ParseCIDR(cidr)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return testInNetworkList{t: t, list: []*net.IPNet{ipnet}}
|
||||
}
|
||||
|
||||
func TestPolicy(t *testing.T) {
|
||||
var (
|
||||
yes = true
|
||||
nope = false
|
||||
)
|
||||
p := &Policy{
|
||||
Rules: []*rawRule{
|
||||
{
|
||||
Rule: &requestRule{
|
||||
domainOrNetworkRule: domainOrNetworkRule{
|
||||
matchers: []match.Matcher{testInNetwork(t, "127.0.0.0/8")},
|
||||
isSource: []bool{true},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Rule: &requestRule{
|
||||
domainOrNetworkRule: domainOrNetworkRule{
|
||||
matchers: []match.Matcher{testInNetwork(t, "127.0.0.0/8")},
|
||||
isSource: []bool{false},
|
||||
},
|
||||
Permit: &yes,
|
||||
},
|
||||
},
|
||||
{
|
||||
Rule: &requestRule{
|
||||
domainOrNetworkRule: domainOrNetworkRule{
|
||||
matchers: []match.Matcher{testInDomain(t, "maze.io", "maze.engineering")},
|
||||
},
|
||||
Permit: &yes,
|
||||
},
|
||||
},
|
||||
{
|
||||
Rule: &requestRule{
|
||||
domainOrNetworkRule: domainOrNetworkRule{
|
||||
matchers: []match.Matcher{testInDomain(t, "google.com")},
|
||||
},
|
||||
Permit: &nope,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
r := &http.Request{
|
||||
URL: &url.URL{Scheme: "http", Host: "golang.org:80"},
|
||||
RemoteAddr: "127.0.0.1:1234",
|
||||
}
|
||||
if v := p.PermitRequest(r); v != nil {
|
||||
t.Errorf("expected request to return no verdict, got %t", *v)
|
||||
}
|
||||
|
||||
p.Rules[0].Rule.(*requestRule).Permit = &yes
|
||||
if v := p.PermitRequest(r); v == nil || *v != yes {
|
||||
t.Errorf("expected request to return %t, %v", yes, v)
|
||||
}
|
||||
|
||||
r.RemoteAddr = "192.168.1.2:3456"
|
||||
if v := p.PermitRequest(r); v != nil {
|
||||
t.Errorf("expected request to return no verdict, got %t", *v)
|
||||
}
|
||||
if v := p.PermitIntercept(r); v != nil {
|
||||
t.Errorf("expected request to return no verdict, got %t", *v)
|
||||
}
|
||||
|
||||
r.URL.Host = "maze.io"
|
||||
if v := p.PermitRequest(r); v == nil || *v != yes {
|
||||
t.Errorf("expected request to return %t, %v", yes, v)
|
||||
}
|
||||
|
||||
r.URL.Host = "google.com"
|
||||
if v := p.PermitRequest(r); v == nil || *v != nope {
|
||||
t.Errorf("expected request to return %t, %v", nope, v)
|
||||
}
|
||||
|
||||
r.URL.Host = "localhost:80"
|
||||
if v := p.PermitRequest(r); v == nil || *v != yes {
|
||||
t.Errorf("expected request to return %t, %v", yes, v)
|
||||
}
|
||||
}
|
@@ -1,368 +0,0 @@
|
||||
package policy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.maze.io/maze/styx/internal/netutil"
|
||||
"git.maze.io/maze/styx/proxy/match"
|
||||
"github.com/google/uuid"
|
||||
"github.com/hashicorp/hcl/v2"
|
||||
"github.com/hashicorp/hcl/v2/gohcl"
|
||||
)
|
||||
|
||||
// Rule is a policy rule.
|
||||
type Rule interface {
|
||||
Configure(match.Matchers) error
|
||||
}
|
||||
|
||||
// InterceptRule can make policy rule decisions on intercept requests.
|
||||
type InterceptRule interface {
|
||||
PermitIntercept(r *http.Request) *bool
|
||||
}
|
||||
|
||||
// RequestRule can make policy rule decisions on HTTP CONNECT requests.
|
||||
type RequestRule interface {
|
||||
PermitRequest(r *http.Request) *bool
|
||||
}
|
||||
|
||||
type rawRule struct {
|
||||
Type string `hcl:"type,label" json:"type"`
|
||||
Body hcl.Body `hcl:",remain" json:"-"`
|
||||
Rule `json:"rule"`
|
||||
}
|
||||
|
||||
func (r *rawRule) Configure(matchers match.Matchers) (err error) {
|
||||
switch r.Type {
|
||||
case "intercept":
|
||||
r.Rule = new(interceptRule)
|
||||
case "request":
|
||||
r.Rule = new(requestRule)
|
||||
case "days":
|
||||
r.Rule = new(daysRule)
|
||||
case "time":
|
||||
r.Rule = new(timeRule)
|
||||
case "all":
|
||||
r.Rule = new(allRule)
|
||||
default:
|
||||
return fmt.Errorf("policy: invalid event type %q", r.Type)
|
||||
}
|
||||
|
||||
if diag := gohcl.DecodeBody(r.Body, nil, r.Rule); diag.HasErrors() {
|
||||
return err
|
||||
}
|
||||
|
||||
return r.Rule.Configure(matchers)
|
||||
}
|
||||
|
||||
type allRule struct {
|
||||
Rules []*rawRule `hcl:"on,block"`
|
||||
Permit *bool `hcl:"permit"`
|
||||
}
|
||||
|
||||
func (r *allRule) Configure(matchers match.Matchers) (err error) {
|
||||
return
|
||||
}
|
||||
|
||||
type domainOrNetworkRule struct {
|
||||
matchers []match.Matcher
|
||||
isSource []bool
|
||||
}
|
||||
|
||||
func (r *domainOrNetworkRule) configure(kind string, matchers match.Matchers, domains, sources, targets []string, v any, id *string) (err error) {
|
||||
var m match.Matcher
|
||||
for _, domain := range domains {
|
||||
if m, err = matchers.Get("domain", domain); err != nil {
|
||||
return fmt.Errorf("%s: unknown domain %q", kind, domain)
|
||||
}
|
||||
r.matchers = append(r.matchers, m)
|
||||
r.isSource = append(r.isSource, false)
|
||||
}
|
||||
for _, network := range sources {
|
||||
if m, err = matchers.Get("network", network); err != nil {
|
||||
return fmt.Errorf("%s: unknown source network %q", kind, network)
|
||||
}
|
||||
r.matchers = append(r.matchers, m)
|
||||
r.isSource = append(r.isSource, true)
|
||||
}
|
||||
for _, network := range targets {
|
||||
if m, err = matchers.Get("network", network); err != nil {
|
||||
return fmt.Errorf("%s: unknown target network %q", kind, network)
|
||||
}
|
||||
r.matchers = append(r.matchers, m)
|
||||
r.isSource = append(r.isSource, false)
|
||||
}
|
||||
if len(r.matchers) == 0 {
|
||||
return fmt.Errorf("%s: missing any of domain, source, target", kind)
|
||||
}
|
||||
if id != nil {
|
||||
*id = uuid.NewString()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (r *domainOrNetworkRule) matchesRequest(q *http.Request) bool {
|
||||
for i, m := range r.matchers {
|
||||
if m, ok := m.(match.Request); ok {
|
||||
if m.MatchesRequest(q) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
if m, ok := m.(match.IP); ok {
|
||||
if r.isSource[i] {
|
||||
if m.MatchesIP(net.ParseIP(netutil.Host(q.RemoteAddr))) {
|
||||
return true
|
||||
}
|
||||
} else {
|
||||
var (
|
||||
host = netutil.Host(q.URL.Host)
|
||||
ips []net.IP
|
||||
)
|
||||
if ip := net.ParseIP(host); ip != nil {
|
||||
ips = append(ips, ip)
|
||||
} else {
|
||||
ips, _ = net.LookupIP(host)
|
||||
}
|
||||
for _, ip := range ips {
|
||||
if m.MatchesIP(ip) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type interceptRule struct {
|
||||
ID string `json:"id,omitempty"`
|
||||
Domain []string `hcl:"domain,optional" json:"domain,omitempty"`
|
||||
Source []string `hcl:"source,optional" json:"source,omitempty"`
|
||||
Target []string `hcl:"target,optional" json:"target,omitempty"`
|
||||
Permit *bool `hcl:"permit" json:"permit"`
|
||||
domainOrNetworkRule `json:"-"`
|
||||
}
|
||||
|
||||
func (r *interceptRule) Configure(matchers match.Matchers) (err error) {
|
||||
return r.configure("intercept", matchers, r.Domain, r.Source, r.Target, r, &r.ID)
|
||||
}
|
||||
|
||||
func (r *interceptRule) PermitIntercept(q *http.Request) *bool {
|
||||
if r.matchesRequest(q) {
|
||||
return r.Permit
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type requestRule struct {
|
||||
ID string `json:"id,omitempty"`
|
||||
Domain []string `hcl:"domain,optional" json:"domain,omitempty"`
|
||||
Source []string `hcl:"source,optional" json:"source,omitempty"`
|
||||
Target []string `hcl:"target,optional" json:"target,omitempty"`
|
||||
Permit *bool `hcl:"permit" json:"permit"`
|
||||
domainOrNetworkRule `json:"-"`
|
||||
}
|
||||
|
||||
func (r *requestRule) Configure(matchers match.Matchers) (err error) {
|
||||
return r.configure("request", matchers, r.Domain, r.Source, r.Target, r, &r.ID)
|
||||
}
|
||||
|
||||
func (r *requestRule) PermitRequest(q *http.Request) *bool {
|
||||
if r.matchesRequest(q) {
|
||||
return r.Permit
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type timeRule struct {
|
||||
ID string `json:"id,omitempty"`
|
||||
Time []string `hcl:"time" json:"time"`
|
||||
Permit *bool `hcl:"permit" json:"permit"`
|
||||
Body hcl.Body `hcl:",remain" json:"-"`
|
||||
Rules *Policy `json:"rules"`
|
||||
Start Time `json:"start"`
|
||||
End Time `json:"end"`
|
||||
}
|
||||
|
||||
func (r *timeRule) isActive() bool {
|
||||
if r == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
now := Now()
|
||||
if r.Start.After(r.End) { // ie: 18:00-06:00
|
||||
return now.After(r.Start) || now.Before(r.End)
|
||||
}
|
||||
return now.After(r.Start) && now.Before(r.End)
|
||||
}
|
||||
|
||||
func (r *timeRule) Configure(matchers match.Matchers) (err error) {
|
||||
if len(r.Time) != 2 {
|
||||
return fmt.Errorf("invalid time %s, need [start, stop]", r.Time)
|
||||
}
|
||||
if r.Start, err = ParseTime(r.Time[0]); err != nil {
|
||||
return fmt.Errorf("invalid start %q: %w", r.Time[0], err)
|
||||
}
|
||||
if r.End, err = ParseTime(r.Time[1]); err != nil {
|
||||
return fmt.Errorf("invalid end %q: %w", r.Time[1], err)
|
||||
}
|
||||
|
||||
r.Rules = new(Policy)
|
||||
if diag := gohcl.DecodeBody(r.Body, nil, r.Rules); diag.HasErrors() {
|
||||
return diag
|
||||
}
|
||||
|
||||
if err = r.Rules.Configure(matchers); err != nil {
|
||||
return
|
||||
}
|
||||
r.Rules.Matchers = nil
|
||||
|
||||
if r.ID == "" {
|
||||
r.ID = uuid.NewString()
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (r *timeRule) PermitIntercept(q *http.Request) *bool {
|
||||
if !r.isActive() {
|
||||
return nil
|
||||
}
|
||||
return r.Rules.PermitIntercept(q)
|
||||
}
|
||||
|
||||
func (r *timeRule) PermitRequest(q *http.Request) *bool {
|
||||
if !r.isActive() {
|
||||
return nil
|
||||
}
|
||||
return r.Rules.PermitRequest(q)
|
||||
}
|
||||
|
||||
type daysRule struct {
|
||||
ID string `json:"id,omitempty"`
|
||||
Days string `hcl:"days" json:"days"`
|
||||
Permit *bool `hcl:"permit" json:"permit"`
|
||||
Body hcl.Body `hcl:",remain" json:"-"`
|
||||
Rules *Policy `json:"rules"`
|
||||
cond []onCond
|
||||
}
|
||||
|
||||
func (r *daysRule) isActive() bool {
|
||||
if r == nil || len(r.cond) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
for _, cond := range r.cond {
|
||||
if cond(now) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (r *daysRule) Configure(matchers match.Matchers) (err error) {
|
||||
if r.cond, err = parseOnCond(r.Days); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
r.Rules = new(Policy)
|
||||
if diag := gohcl.DecodeBody(r.Body, nil, r.Rules); diag.HasErrors() {
|
||||
return diag
|
||||
}
|
||||
if err = r.Rules.Configure(matchers); err != nil {
|
||||
return
|
||||
}
|
||||
r.Rules.Matchers = nil
|
||||
|
||||
if r.ID == "" {
|
||||
r.ID = uuid.NewString()
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (r *daysRule) PermitIntercept(q *http.Request) *bool {
|
||||
if !r.isActive() {
|
||||
return nil
|
||||
}
|
||||
return r.Rules.PermitIntercept(q)
|
||||
}
|
||||
|
||||
func (r *daysRule) PermitRequest(q *http.Request) *bool {
|
||||
if !r.isActive() {
|
||||
return nil
|
||||
}
|
||||
return r.Rules.PermitRequest(q)
|
||||
}
|
||||
|
||||
type onCond func(time.Time) bool
|
||||
|
||||
var weekdays = map[string]time.Weekday{
|
||||
"sun": time.Sunday,
|
||||
"mon": time.Monday,
|
||||
"tue": time.Tuesday,
|
||||
"wed": time.Wednesday,
|
||||
"thu": time.Thursday,
|
||||
"fri": time.Friday,
|
||||
"sat": time.Saturday,
|
||||
}
|
||||
|
||||
func parseOnCond(when string) (conds []onCond, err error) {
|
||||
for _, spec := range strings.Split(when, ",") {
|
||||
spec = strings.ToLower(strings.TrimSpace(spec))
|
||||
if d, ok := weekdays[spec]; ok {
|
||||
conds = append(conds, onWeekday(d))
|
||||
} else if spec == "weekend" || spec == "weekends" {
|
||||
conds = append(conds, onWeekend)
|
||||
} else if spec == "workday" || spec == "workdays" {
|
||||
conds = append(conds, onWorkday)
|
||||
} else if strings.ContainsRune(spec, '-') {
|
||||
var (
|
||||
part = strings.SplitN(spec, "-", 2)
|
||||
from, upto time.Weekday
|
||||
ok bool
|
||||
)
|
||||
if from, ok = weekdays[part[0]]; !ok {
|
||||
return nil, fmt.Errorf("on %q: invalid weekday %q", spec, part[0])
|
||||
}
|
||||
if upto, ok = weekdays[part[1]]; !ok {
|
||||
return nil, fmt.Errorf("on %q: invalid weekday %q", spec, part[1])
|
||||
}
|
||||
if from < upto {
|
||||
for d := from; d < upto; d++ {
|
||||
conds = append(conds, onWeekday(d))
|
||||
}
|
||||
} else {
|
||||
for d := time.Sunday; d < from; d++ {
|
||||
conds = append(conds, onWeekday(d))
|
||||
}
|
||||
for d := upto; d <= time.Saturday; d++ {
|
||||
conds = append(conds, onWeekday(d))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return nil, fmt.Errorf("on %q: invalid condition", spec)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func onWeekday(weekday time.Weekday) onCond {
|
||||
return func(t time.Time) bool {
|
||||
return t.Weekday() == weekday
|
||||
}
|
||||
}
|
||||
|
||||
func onWeekend(t time.Time) bool {
|
||||
d := t.Weekday()
|
||||
return d == time.Saturday || d == time.Sunday
|
||||
}
|
||||
|
||||
func onWorkday(t time.Time) bool {
|
||||
d := t.Weekday()
|
||||
return !(d == time.Saturday || d == time.Sunday)
|
||||
}
|
@@ -1,53 +0,0 @@
|
||||
package policy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Time struct {
|
||||
Hour int
|
||||
Minute int
|
||||
Second int
|
||||
}
|
||||
|
||||
func (t Time) Eq(other Time) bool {
|
||||
return t.Hour == other.Hour && t.Minute == other.Minute && t.Second == other.Second
|
||||
}
|
||||
|
||||
func (t Time) After(other Time) bool {
|
||||
return t.Seconds() > other.Seconds()
|
||||
}
|
||||
|
||||
func (t Time) Before(other Time) bool {
|
||||
return t.Seconds() < other.Seconds()
|
||||
}
|
||||
|
||||
func (t Time) Seconds() int {
|
||||
return t.Hour*3600 + t.Minute*60 + t.Second
|
||||
}
|
||||
|
||||
func (t Time) MarshalJSON() ([]byte, error) {
|
||||
return []byte(fmt.Sprintf(`"%02d:%02d:%02d"`, t.Hour, t.Minute, t.Second)), nil
|
||||
}
|
||||
|
||||
var timeFormats = []string{
|
||||
time.TimeOnly,
|
||||
"15:04",
|
||||
time.Kitchen,
|
||||
}
|
||||
|
||||
func Now() Time {
|
||||
now := time.Now()
|
||||
return Time{now.Hour(), now.Minute(), now.Second()}
|
||||
}
|
||||
|
||||
func ParseTime(s string) (t Time, err error) {
|
||||
var tt time.Time
|
||||
for _, layout := range timeFormats {
|
||||
if tt, err = time.Parse(layout, s); err == nil {
|
||||
return Time{tt.Hour(), tt.Minute(), tt.Second()}, nil
|
||||
}
|
||||
}
|
||||
return Time{}, fmt.Errorf("time: invalid time %q", s)
|
||||
}
|
993
proxy/proxy.go
993
proxy/proxy.go
File diff suppressed because it is too large
Load Diff
@@ -1,148 +0,0 @@
|
||||
// Package resolver implements a caching DNS resolver
|
||||
package resolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"math/rand/v2"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.maze.io/maze/styx/internal/netutil"
|
||||
"github.com/hashicorp/golang-lru/v2/expirable"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultSize = 1024
|
||||
DefaultTTL = 5 * time.Minute
|
||||
DefaultTimeout = 10 * time.Second
|
||||
)
|
||||
|
||||
var (
|
||||
// DefaultConfig are the defaults for the Default resolver.
|
||||
DefaultConfig = Config{
|
||||
Size: DefaultSize,
|
||||
TTL: DefaultTTL.Seconds(),
|
||||
Timeout: DefaultTimeout.Seconds(),
|
||||
}
|
||||
|
||||
// Default resolver.
|
||||
Default = New(DefaultConfig)
|
||||
)
|
||||
|
||||
type Resolver interface {
|
||||
// Lookup returns resolved IPs for given hostname/ips.
|
||||
Lookup(context.Context, string) ([]string, error)
|
||||
}
|
||||
|
||||
type netResolver struct {
|
||||
resolver *net.Resolver
|
||||
timeout time.Duration
|
||||
noIPv6 bool
|
||||
cache *expirable.LRU[string, []string]
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
// Size is our cache size in number of entries.
|
||||
Size int `hcl:"size,optional"`
|
||||
|
||||
// TTL is the cache time to live in seconds.
|
||||
TTL float64 `hcl:"ttl,optional"`
|
||||
|
||||
// Timeout is the cache timeout in seconds.
|
||||
Timeout float64 `hcl:"timeout,optional"`
|
||||
|
||||
// Server are alternative DNS servers.
|
||||
Server []string `hcl:"server,optional"`
|
||||
|
||||
// NoIPv6 disables IPv6 DNS resolution.
|
||||
NoIPv6 bool `hcl:"noipv6,optional"`
|
||||
}
|
||||
|
||||
func New(config Config) Resolver {
|
||||
var (
|
||||
size = config.Size
|
||||
ttl = time.Duration(float64(time.Second) * config.TTL)
|
||||
timeout = time.Duration(float64(time.Second) * config.Timeout)
|
||||
)
|
||||
if size <= 0 {
|
||||
size = DefaultSize
|
||||
}
|
||||
if ttl <= 0 {
|
||||
ttl = DefaultTTL
|
||||
}
|
||||
if timeout <= 0 {
|
||||
timeout = 0
|
||||
}
|
||||
|
||||
var resolver = new(net.Resolver)
|
||||
if len(config.Server) > 0 {
|
||||
var dialer net.Dialer
|
||||
resolver.Dial = func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
server := netutil.EnsurePort(config.Server[rand.IntN(len(config.Server))], "53")
|
||||
return dialer.DialContext(ctx, network, server)
|
||||
}
|
||||
}
|
||||
|
||||
return &netResolver{
|
||||
resolver: resolver,
|
||||
timeout: timeout,
|
||||
noIPv6: config.NoIPv6,
|
||||
cache: expirable.NewLRU[string, []string](size, nil, ttl),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *netResolver) Lookup(ctx context.Context, host string) ([]string, error) {
|
||||
host = strings.ToLower(strings.TrimSpace(host))
|
||||
if hosts, ok := r.cache.Get(host); ok {
|
||||
rand.Shuffle(len(hosts), func(i, j int) {
|
||||
hosts[i], hosts[j] = hosts[j], hosts[i]
|
||||
})
|
||||
return hosts, nil
|
||||
}
|
||||
|
||||
hosts, err := r.lookup(ctx, host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
r.cache.Add(host, hosts)
|
||||
return hosts, nil
|
||||
}
|
||||
|
||||
func (r *netResolver) lookup(ctx context.Context, host string) ([]string, error) {
|
||||
if r.timeout > 0 {
|
||||
var cancel func()
|
||||
ctx, cancel = context.WithTimeout(ctx, r.timeout)
|
||||
defer cancel()
|
||||
}
|
||||
|
||||
if net.ParseIP(host) == nil {
|
||||
addrs, err := r.resolver.LookupHost(ctx, host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if r.noIPv6 {
|
||||
var addrs4 []string
|
||||
for _, addr := range addrs {
|
||||
if net.ParseIP(addr).To4() != nil {
|
||||
addrs4 = append(addrs4, addr)
|
||||
}
|
||||
}
|
||||
return addrs4, nil
|
||||
}
|
||||
return addrs, nil
|
||||
}
|
||||
|
||||
addrs, err := r.resolver.LookupIPAddr(ctx, host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
hosts := make([]string, len(addrs))
|
||||
for i, addr := range addrs {
|
||||
if !r.noIPv6 || addr.IP.To4() != nil {
|
||||
hosts[i] = addr.IP.String()
|
||||
}
|
||||
}
|
||||
return hosts, nil
|
||||
}
|
@@ -2,77 +2,58 @@ package proxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"strconv"
|
||||
|
||||
"git.maze.io/maze/styx/internal/log"
|
||||
)
|
||||
|
||||
func NewResponse(code int, body io.Reader, request *http.Request) *http.Response {
|
||||
if body == nil {
|
||||
body = new(bytes.Buffer)
|
||||
}
|
||||
|
||||
rc, ok := body.(io.ReadCloser)
|
||||
if !ok {
|
||||
rc = io.NopCloser(body)
|
||||
}
|
||||
|
||||
response := &http.Response{
|
||||
Status: strconv.Itoa(code) + " " + http.StatusText(code),
|
||||
StatusCode: code,
|
||||
Proto: "HTTP/1.1",
|
||||
ProtoMajor: 1,
|
||||
ProtoMinor: 1,
|
||||
Header: make(http.Header),
|
||||
Body: rc,
|
||||
Request: request,
|
||||
}
|
||||
|
||||
if request != nil {
|
||||
response.Close = request.Close
|
||||
response.Proto = request.Proto
|
||||
response.ProtoMajor = request.ProtoMajor
|
||||
response.ProtoMinor = request.ProtoMinor
|
||||
}
|
||||
|
||||
return response
|
||||
}
|
||||
|
||||
type withLen interface {
|
||||
Len() int
|
||||
}
|
||||
|
||||
type withSize interface {
|
||||
type sizer interface {
|
||||
Size() int64
|
||||
}
|
||||
|
||||
func NewJSONResponse(code int, body io.Reader, request *http.Request) *http.Response {
|
||||
response := NewResponse(code, body, request)
|
||||
response.Header.Set(HeaderContentType, "application/json")
|
||||
if s, ok := body.(withLen); ok {
|
||||
response.Header.Set(HeaderContentLength, strconv.Itoa(s.Len()))
|
||||
} else if s, ok := body.(withSize); ok {
|
||||
response.Header.Set(HeaderContentLength, strconv.FormatInt(s.Size(), 10))
|
||||
} else {
|
||||
log.Trace().Str("type", fmt.Sprintf("%T", body)).Msg("can't detemine body size")
|
||||
// NewResponse prepares a net [http.Response], based on the status code, optional body and
|
||||
// optional [http.Request].
|
||||
func NewResponse(code int, body io.ReadCloser, req *http.Request) *http.Response {
|
||||
res := &http.Response{
|
||||
StatusCode: code,
|
||||
Header: make(http.Header),
|
||||
Proto: "HTTP/1.1",
|
||||
ProtoMajor: 1,
|
||||
ProtoMinor: 1,
|
||||
}
|
||||
response.Close = true
|
||||
return response
|
||||
|
||||
if text := http.StatusText(code); text != "" {
|
||||
res.Status = strconv.Itoa(code) + " " + text
|
||||
} else {
|
||||
res.Status = strconv.Itoa(code)
|
||||
}
|
||||
|
||||
if body == nil && code >= 400 {
|
||||
body = io.NopCloser(bytes.NewBufferString(http.StatusText(code)))
|
||||
}
|
||||
|
||||
res.Body = body
|
||||
|
||||
if s, ok := body.(sizer); ok {
|
||||
res.ContentLength = s.Size()
|
||||
}
|
||||
|
||||
if req != nil {
|
||||
res.Close = req.Close
|
||||
res.Proto = req.Proto
|
||||
res.ProtoMajor = req.ProtoMajor
|
||||
res.ProtoMinor = req.ProtoMinor
|
||||
}
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
func ErrorResponse(request *http.Request, err error) *http.Response {
|
||||
response := NewResponse(http.StatusBadGateway, nil, request)
|
||||
func NewErrorResponse(err error, req *http.Request) *http.Response {
|
||||
switch {
|
||||
case os.IsNotExist(err):
|
||||
response.StatusCode = http.StatusNotFound
|
||||
case os.IsPermission(err):
|
||||
response.StatusCode = http.StatusForbidden
|
||||
case os.IsTimeout(err):
|
||||
return NewResponse(http.StatusGatewayTimeout, nil, req)
|
||||
default:
|
||||
return NewResponse(http.StatusBadGateway, nil, req)
|
||||
}
|
||||
response.Status = http.StatusText(response.StatusCode)
|
||||
response.Close = true
|
||||
return response
|
||||
}
|
||||
|
151
proxy/session.go
151
proxy/session.go
@@ -1,151 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"crypto/tls"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"git.maze.io/maze/styx/internal/log"
|
||||
)
|
||||
|
||||
var seed = rand.NewSource(time.Now().UnixNano())
|
||||
|
||||
type Context struct {
|
||||
id int64
|
||||
conn *wrappedConn
|
||||
rw *bufio.ReadWriter
|
||||
parent *Session
|
||||
data map[string]any
|
||||
}
|
||||
|
||||
func newContext(conn net.Conn, rw *bufio.ReadWriter, parent *Session) *Context {
|
||||
if wrapped, ok := conn.(*wrappedConn); ok {
|
||||
conn = wrapped.Conn
|
||||
}
|
||||
|
||||
ctx := &Context{
|
||||
id: seed.Int63(),
|
||||
conn: &wrappedConn{Conn: conn},
|
||||
rw: rw,
|
||||
parent: parent,
|
||||
data: make(map[string]any),
|
||||
}
|
||||
|
||||
return ctx
|
||||
}
|
||||
|
||||
func (ctx *Context) log() log.Logger {
|
||||
return log.Console.With().
|
||||
Str("context", ctx.ID()).
|
||||
Str("addr", ctx.RemoteAddr().String()).
|
||||
Logger()
|
||||
}
|
||||
|
||||
func (ctx *Context) ID() string {
|
||||
var b [8]byte
|
||||
binary.BigEndian.PutUint64(b[:], uint64(ctx.id))
|
||||
if ctx.parent != nil {
|
||||
return ctx.parent.ID() + "-" + hex.EncodeToString(b[:])
|
||||
}
|
||||
return hex.EncodeToString(b[:])
|
||||
}
|
||||
|
||||
func (ctx *Context) IsTLS() bool {
|
||||
_, ok := ctx.conn.Conn.(*tls.Conn)
|
||||
return ok && ctx.parent != nil
|
||||
}
|
||||
|
||||
func (ctx *Context) RemoteAddr() net.Addr {
|
||||
if ctx.parent != nil {
|
||||
return ctx.parent.ctx.RemoteAddr()
|
||||
}
|
||||
return ctx.conn.RemoteAddr()
|
||||
}
|
||||
|
||||
func (ctx *Context) SetDeadline(t time.Time) error {
|
||||
if ctx.parent != nil {
|
||||
return ctx.parent.ctx.SetDeadline(t)
|
||||
}
|
||||
return ctx.conn.SetDeadline(t)
|
||||
}
|
||||
|
||||
func (ctx *Context) Set(key string, value any) {
|
||||
ctx.data[key] = value
|
||||
}
|
||||
|
||||
func (ctx *Context) Get(key string) (value any, ok bool) {
|
||||
value, ok = ctx.data[key]
|
||||
return
|
||||
}
|
||||
|
||||
func (ctx *Context) Flush() error {
|
||||
return ctx.rw.Flush()
|
||||
}
|
||||
|
||||
func (ctx *Context) Write(p []byte) (n int, err error) {
|
||||
if n, err = ctx.rw.Write(p); n > 0 {
|
||||
atomic.AddInt64(&ctx.conn.bytes, int64(n))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
type Session struct {
|
||||
id int64
|
||||
ctx *Context
|
||||
request *http.Request
|
||||
response *http.Response
|
||||
data map[string]any
|
||||
}
|
||||
|
||||
func newSession(ctx *Context, request *http.Request) *Session {
|
||||
return &Session{
|
||||
id: seed.Int63(),
|
||||
ctx: ctx,
|
||||
request: request,
|
||||
data: make(map[string]any),
|
||||
}
|
||||
}
|
||||
|
||||
func (ses *Session) log() log.Logger {
|
||||
return log.Console.With().
|
||||
Str("context", ses.ctx.ID()).
|
||||
Str("session", ses.ID()).
|
||||
Str("addr", ses.ctx.RemoteAddr().String()).
|
||||
Logger()
|
||||
}
|
||||
|
||||
func (ses *Session) ID() string {
|
||||
var b [8]byte
|
||||
binary.BigEndian.PutUint64(b[:], uint64(ses.id))
|
||||
return hex.EncodeToString(b[:])
|
||||
}
|
||||
|
||||
func (ses *Session) Context() *Context {
|
||||
return ses.ctx
|
||||
}
|
||||
|
||||
func (ses *Session) Request() *http.Request {
|
||||
return ses.request
|
||||
}
|
||||
|
||||
func (ses *Session) Response() *http.Response {
|
||||
return ses.response
|
||||
}
|
||||
|
||||
type wrappedConn struct {
|
||||
net.Conn
|
||||
bytes int64
|
||||
}
|
||||
|
||||
func (c *wrappedConn) Write(p []byte) (n int, err error) {
|
||||
if n, err = c.Conn.Write(p); n > 0 {
|
||||
atomic.AddInt64(&c.bytes, int64(n))
|
||||
}
|
||||
return
|
||||
}
|
19
proxy/stats.go
Normal file
19
proxy/stats.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"expvar"
|
||||
"strconv"
|
||||
|
||||
"git.maze.io/maze/styx/db/stats"
|
||||
)
|
||||
|
||||
func countStatus(code int) {
|
||||
k := "http:status:" + strconv.Itoa(code)
|
||||
v := expvar.Get(k)
|
||||
if v == nil {
|
||||
//v = stats.NewCounter("120s1s", "15m10s", "1h1m", "4w1d", "1y4w")
|
||||
v = stats.NewCounter(k, stats.Minutely, stats.Hourly, stats.Daily, stats.Yearly)
|
||||
expvar.Publish(k, v)
|
||||
}
|
||||
v.(stats.Metric).Add(1)
|
||||
}
|
@@ -1,225 +0,0 @@
|
||||
package stats
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"git.maze.io/maze/styx/internal/log"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
type Stats struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
func New() (*Stats, error) {
|
||||
u, err := user.Current()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
path := filepath.Join(u.HomeDir, ".styx", "stats.db")
|
||||
if err = os.MkdirAll(filepath.Dir(path), 0o750); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
db, err := sql.Open("sqlite3", path+"?_journal_mode=WAL")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, table := range []string{
|
||||
createLog,
|
||||
createDomainStat,
|
||||
createStatusStat,
|
||||
} {
|
||||
if _, err = db.Exec(table); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &Stats{db: db}, nil
|
||||
}
|
||||
|
||||
func (s *Stats) AddLog(entry *Log) error {
|
||||
var (
|
||||
request []byte
|
||||
response []byte
|
||||
err error
|
||||
)
|
||||
if request, err = json.Marshal(entry.Request); err != nil {
|
||||
return err
|
||||
}
|
||||
if response, err = json.Marshal(entry.Response); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tx, err := s.db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
stmt, err := tx.Prepare("insert into styx_log(client_ip, request, response) values(?, ?, ?)")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
if _, err = stmt.Exec(entry.ClientIP, request, response); err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func (s *Stats) QueryLog(offset, limit int) ([]*Log, error) {
|
||||
if limit == 0 {
|
||||
limit = 50
|
||||
}
|
||||
|
||||
rows, err := s.db.Query("select dt, client_ip, request, response from styx_log limit ?, ?", offset, limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var logs []*Log
|
||||
for rows.Next() {
|
||||
var entry = new(Log)
|
||||
if err = rows.Scan(&entry.Time, &entry.ClientIP, &entry.Request, &entry.Response); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
logs = append(logs, entry)
|
||||
}
|
||||
|
||||
return logs, nil
|
||||
}
|
||||
|
||||
type Status struct {
|
||||
Code int `json:"code"`
|
||||
Count int `json:"count"`
|
||||
}
|
||||
|
||||
var timeZero time.Time
|
||||
|
||||
func (s *Stats) QueryStatus(since time.Time) ([]*Status, error) {
|
||||
if since.Equal(timeZero) {
|
||||
since = time.Now().Add(-24 * time.Hour)
|
||||
}
|
||||
|
||||
rows, err := s.db.Query("select response->'status', count(*) from styx_log where dt >= ? group by response->'status' order by response->'status'", since)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var stats []*Status
|
||||
for rows.Next() {
|
||||
var entry = new(Status)
|
||||
if err = rows.Scan(&entry.Code, &entry.Count); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stats = append(stats, entry)
|
||||
}
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
const createLog = `CREATE TABLE IF NOT EXISTS styx_log (
|
||||
id INT PRIMARY KEY,
|
||||
dt DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
client_ip TEXT NOT NULL,
|
||||
request JSONB NOT NULL,
|
||||
response JSONB NOT NULL
|
||||
);`
|
||||
|
||||
type Log struct {
|
||||
Time time.Time `json:"time"`
|
||||
ClientIP string `json:"client_ip"`
|
||||
Request *Request `json:"request"`
|
||||
Response *Response `json:"response"`
|
||||
}
|
||||
|
||||
type Request struct {
|
||||
URL string `json:"url"`
|
||||
Host string `json:"host"`
|
||||
Method string `json:"method"`
|
||||
Proto string `json:"proto"`
|
||||
Header http.Header `json:"header"`
|
||||
}
|
||||
|
||||
func (r *Request) Scan(value any) error {
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
return json.Unmarshal([]byte(v), r)
|
||||
case []byte:
|
||||
return json.Unmarshal(v, r)
|
||||
default:
|
||||
log.Error().Str("type", fmt.Sprintf("%T", value)).Msg("scan request unknown type")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Request) Value() (driver.Value, error) {
|
||||
b, err := json.Marshal(r)
|
||||
return string(b), err
|
||||
}
|
||||
|
||||
func FromRequest(r *http.Request) *Request {
|
||||
return &Request{
|
||||
URL: r.URL.String(),
|
||||
Host: r.Host,
|
||||
Method: r.Method,
|
||||
Proto: r.Proto,
|
||||
Header: r.Header,
|
||||
}
|
||||
}
|
||||
|
||||
type Response struct {
|
||||
Status int `json:"status"`
|
||||
Size int64 `json:"size"`
|
||||
Header http.Header `json:"header"`
|
||||
}
|
||||
|
||||
func (r *Response) Scan(value any) error {
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
return json.Unmarshal([]byte(v), r)
|
||||
case []byte:
|
||||
return json.Unmarshal(v, r)
|
||||
default:
|
||||
log.Error().Str("type", fmt.Sprintf("%T", value)).Msg("scan response unknown type")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Response) Value() (driver.Value, error) {
|
||||
b, err := json.Marshal(r)
|
||||
return string(b), err
|
||||
}
|
||||
|
||||
func (r *Response) SetSize(size int64) *Response {
|
||||
r.Size = size
|
||||
return r
|
||||
}
|
||||
|
||||
func FromResponse(r *http.Response) *Response {
|
||||
return &Response{
|
||||
Status: r.StatusCode,
|
||||
Header: r.Header,
|
||||
}
|
||||
}
|
||||
|
||||
const createStatusStat = `CREATE TABLE IF NOT EXISTS styx_stat_status (
|
||||
id INT PRIMARY KEY,
|
||||
dt DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
status INT NOT NULL
|
||||
);`
|
||||
|
||||
const createDomainStat = `CREATE TABLE IF NOT EXISTS styx_stat_domain (
|
||||
id INT PRIMARY KEY,
|
||||
dt DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
domain TEXT NOT NULL
|
||||
);`
|
@@ -1,16 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
)
|
||||
|
||||
// connReader is a net.Conn with a separate reader.
|
||||
type connReader struct {
|
||||
net.Conn
|
||||
io.Reader
|
||||
}
|
||||
|
||||
func (c connReader) Read(p []byte) (int, error) {
|
||||
return c.Reader.Read(p)
|
||||
}
|
Reference in New Issue
Block a user