Checkpoint
This commit is contained in:
145
proxy/admin.go
145
proxy/admin.go
@@ -1,145 +0,0 @@
|
|||||||
package proxy
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"encoding/json"
|
|
||||||
"encoding/pem"
|
|
||||||
"errors"
|
|
||||||
"net/http"
|
|
||||||
"os"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"git.maze.io/maze/styx/internal/log"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Admin struct {
|
|
||||||
*Proxy
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewAdmin(proxy *Proxy) *Admin {
|
|
||||||
a := &Admin{
|
|
||||||
Proxy: proxy,
|
|
||||||
}
|
|
||||||
return a
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Admin) handleRequest(ses *Session) error {
|
|
||||||
var (
|
|
||||||
logger = ses.log()
|
|
||||||
err error
|
|
||||||
)
|
|
||||||
switch ses.request.URL.Path {
|
|
||||||
case "/ca.crt":
|
|
||||||
err = a.handleCACert(ses)
|
|
||||||
case "/api/v1/policy":
|
|
||||||
err = a.apiPolicy(ses)
|
|
||||||
case "/api/v1/policy/matcher":
|
|
||||||
err = a.apiPolicyMatcher(ses)
|
|
||||||
case "/api/v1/stats/log":
|
|
||||||
err = a.apiStatsLog(ses)
|
|
||||||
case "/api/v1/stats/status":
|
|
||||||
err = a.apiStatsStatus(ses)
|
|
||||||
default:
|
|
||||||
if strings.HasPrefix(ses.request.URL.Path, "/api") {
|
|
||||||
err = errors.New("invalid endpoint")
|
|
||||||
} else {
|
|
||||||
err = os.ErrNotExist
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
logger.Warn().Err(err).Msg("admin error")
|
|
||||||
ses.response = ErrorResponse(ses.request, err)
|
|
||||||
defer log.OnCloseError(logger.Debug(), ses.response.Body)
|
|
||||||
ses.response.Close = true
|
|
||||||
return a.writeResponse(ses)
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Admin) handleCACert(ses *Session) error {
|
|
||||||
b := pem.EncodeToMemory(&pem.Block{
|
|
||||||
Type: "CERTIFICATE",
|
|
||||||
Bytes: a.authority.Certificate().Raw,
|
|
||||||
})
|
|
||||||
|
|
||||||
ses.response = NewResponse(http.StatusOK, bytes.NewReader(b), ses.request)
|
|
||||||
defer log.OnCloseError(log.Debug(), ses.response.Body)
|
|
||||||
|
|
||||||
ses.response.Close = true
|
|
||||||
ses.response.Header.Set("Content-Type", "application/x-x509-ca-cert")
|
|
||||||
ses.response.ContentLength = int64(len(b))
|
|
||||||
return a.writeResponse(ses)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Admin) apiPolicy(ses *Session) error {
|
|
||||||
var (
|
|
||||||
b = new(bytes.Buffer)
|
|
||||||
e = json.NewEncoder(b)
|
|
||||||
)
|
|
||||||
e.SetIndent("", " ")
|
|
||||||
if err := e.Encode(a.config.Policy); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
ses.response = NewJSONResponse(http.StatusOK, b, ses.request)
|
|
||||||
defer log.OnCloseError(log.Debug(), ses.response.Body)
|
|
||||||
ses.response.Close = true
|
|
||||||
return a.writeResponse(ses)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Admin) apiPolicyMatcher(ses *Session) error {
|
|
||||||
var (
|
|
||||||
b = new(bytes.Buffer)
|
|
||||||
e = json.NewEncoder(b)
|
|
||||||
)
|
|
||||||
e.SetIndent("", " ")
|
|
||||||
if err := e.Encode(a.config.Policy.Matchers); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
ses.response = NewJSONResponse(http.StatusOK, b, ses.request)
|
|
||||||
defer log.OnCloseError(log.Debug(), ses.response.Body)
|
|
||||||
ses.response.Close = true
|
|
||||||
return a.writeResponse(ses)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Admin) apiResponse(ses *Session, v any, err error) error {
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
var (
|
|
||||||
b = new(bytes.Buffer)
|
|
||||||
e = json.NewEncoder(b)
|
|
||||||
)
|
|
||||||
e.SetIndent("", " ")
|
|
||||||
if err := e.Encode(v); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
ses.response = NewJSONResponse(http.StatusOK, b, ses.request)
|
|
||||||
defer log.OnCloseError(log.Debug(), ses.response.Body)
|
|
||||||
ses.response.Close = true
|
|
||||||
return a.writeResponse(ses)
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Admin) apiStatsLog(ses *Session) error {
|
|
||||||
var (
|
|
||||||
query = ses.request.URL.Query()
|
|
||||||
offset, _ = strconv.Atoi(query.Get("offset"))
|
|
||||||
limit, _ = strconv.Atoi(query.Get("limit"))
|
|
||||||
)
|
|
||||||
if limit > 100 {
|
|
||||||
limit = 100
|
|
||||||
}
|
|
||||||
|
|
||||||
s, err := a.stats.QueryLog(offset, limit)
|
|
||||||
return a.apiResponse(ses, s, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Admin) apiStatsStatus(ses *Session) error {
|
|
||||||
s, err := a.stats.QueryStatus(time.Time{})
|
|
||||||
return a.apiResponse(ses, s, err)
|
|
||||||
}
|
|
8
proxy/cache/config.go
vendored
8
proxy/cache/config.go
vendored
@@ -1,8 +0,0 @@
|
|||||||
package cache
|
|
||||||
|
|
||||||
import "github.com/hashicorp/hcl/v2"
|
|
||||||
|
|
||||||
type Config struct {
|
|
||||||
Type string `hcl:"type"`
|
|
||||||
Body hcl.Body `hcl:",remain"`
|
|
||||||
}
|
|
@@ -1,88 +0,0 @@
|
|||||||
package proxy
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net"
|
|
||||||
"net/http"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"git.maze.io/maze/styx/proxy/policy"
|
|
||||||
"git.maze.io/maze/styx/proxy/resolver"
|
|
||||||
)
|
|
||||||
|
|
||||||
type ConnectHandler interface {
|
|
||||||
HandleConnect(session *Session, network, address string) net.Conn
|
|
||||||
}
|
|
||||||
|
|
||||||
// ConnectHandlerFunc is called when the proxy receives a new HTTP CONNECT request.
|
|
||||||
type ConnectHandlerFunc func(session *Session, network, address string) net.Conn
|
|
||||||
|
|
||||||
func (f ConnectHandlerFunc) HandleConnect(session *Session, network, address string) net.Conn {
|
|
||||||
return f(session, network, address)
|
|
||||||
}
|
|
||||||
|
|
||||||
type RequestHandler interface {
|
|
||||||
HandleRequest(session *Session) (*http.Request, *http.Response)
|
|
||||||
}
|
|
||||||
|
|
||||||
// RequestHandlerFunc is called when the proxy receives a new request.
|
|
||||||
type RequestHandlerFunc func(session *Session) (*http.Request, *http.Response)
|
|
||||||
|
|
||||||
func (f RequestHandlerFunc) HandleRequest(session *Session) (*http.Request, *http.Response) {
|
|
||||||
return f(session)
|
|
||||||
}
|
|
||||||
|
|
||||||
type ResponseHandler interface {
|
|
||||||
HandleResponse(session *Session) *http.Response
|
|
||||||
}
|
|
||||||
|
|
||||||
// ResponseHandler is called when the proxy receives a response.
|
|
||||||
type ResponseHandlerFunc func(session *Session) *http.Response
|
|
||||||
|
|
||||||
func (f ResponseHandlerFunc) HandleResponse(session *Session) *http.Response {
|
|
||||||
return f(session)
|
|
||||||
}
|
|
||||||
|
|
||||||
type ErrorHandler interface {
|
|
||||||
HandleError(session *Session, err error)
|
|
||||||
}
|
|
||||||
|
|
||||||
type ErrorHandlerFunc func(session *Session, err error)
|
|
||||||
|
|
||||||
func (f ErrorHandlerFunc) HandleError(session *Session, err error) {
|
|
||||||
f(session, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
type Config struct {
|
|
||||||
// Listen address.
|
|
||||||
Listen string `hcl:"listen,optional"`
|
|
||||||
|
|
||||||
// Bind address for outgoing connections.
|
|
||||||
Bind string `hcl:"bind,optional"`
|
|
||||||
|
|
||||||
// Interface for outgoing connections.
|
|
||||||
Interface string `hcl:"interface,optional"`
|
|
||||||
|
|
||||||
// Upstream proxy servers.
|
|
||||||
Upstream []string `hcl:"upstream,optional"`
|
|
||||||
|
|
||||||
// DialTimeout for establishing new connections.
|
|
||||||
DialTimeout time.Duration `hcl:"dial_timeout,optional"`
|
|
||||||
|
|
||||||
// Policy for the proxy.
|
|
||||||
Policy *policy.Policy `hcl:"policy,block"`
|
|
||||||
|
|
||||||
// Resolver for the proxy.
|
|
||||||
Resolver resolver.Resolver
|
|
||||||
|
|
||||||
ConnectHandler ConnectHandler
|
|
||||||
RequestHandler RequestHandler
|
|
||||||
ResponseHandler ResponseHandler
|
|
||||||
ErrorHandler ErrorHandler
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
_ ConnectHandler = (ConnectHandlerFunc)(nil)
|
|
||||||
_ RequestHandler = (RequestHandlerFunc)(nil)
|
|
||||||
_ ResponseHandler = (ResponseHandlerFunc)(nil)
|
|
||||||
_ ErrorHandler = (ErrorHandlerFunc)(nil)
|
|
||||||
)
|
|
227
proxy/context.go
Normal file
227
proxy/context.go
Normal file
@@ -0,0 +1,227 @@
|
|||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/tls"
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/hex"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.maze.io/maze/styx/internal/netutil/arp"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Context provides convenience functions for the current ongoing HTTP proxy transaction (request).
|
||||||
|
type Context interface {
|
||||||
|
// Conn is the backing connection for this context.
|
||||||
|
net.Conn
|
||||||
|
|
||||||
|
// ID is a unique connection identifier.
|
||||||
|
ID() uint64
|
||||||
|
|
||||||
|
// Reader returns a buffered reader on top of the [net.Conn].
|
||||||
|
Reader() *bufio.Reader
|
||||||
|
|
||||||
|
// BytesRead returns the number of bytes read.
|
||||||
|
BytesRead() int64
|
||||||
|
|
||||||
|
// BytesSent returns the number of bytes written.
|
||||||
|
BytesSent() int64
|
||||||
|
|
||||||
|
// TLSState returns the TLS connection state, it returns nil if the connection is not a TLS connection.
|
||||||
|
TLSState() *tls.ConnectionState
|
||||||
|
|
||||||
|
// Request is the request made to the proxy.
|
||||||
|
Request() *http.Request
|
||||||
|
|
||||||
|
// Response is the response that will be sent back to the client.
|
||||||
|
Response() *http.Response
|
||||||
|
}
|
||||||
|
|
||||||
|
type countingReader struct {
|
||||||
|
reader io.Reader
|
||||||
|
bytes int64
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *countingReader) Read(p []byte) (n int, err error) {
|
||||||
|
if n, err = r.reader.Read(p); n > 0 {
|
||||||
|
atomic.AddInt64(&r.bytes, int64(n))
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
type countingWriter struct {
|
||||||
|
writer io.Writer
|
||||||
|
bytes int64
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *countingWriter) Write(p []byte) (n int, err error) {
|
||||||
|
if n, err = w.writer.Write(p); n > 0 {
|
||||||
|
atomic.AddInt64(&w.bytes, int64(n))
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
type proxyContext struct {
|
||||||
|
net.Conn
|
||||||
|
id uint64
|
||||||
|
mac net.HardwareAddr
|
||||||
|
cr *countingReader
|
||||||
|
br *bufio.Reader
|
||||||
|
cw *countingWriter
|
||||||
|
isTransparent bool
|
||||||
|
isTransparentTLS bool
|
||||||
|
serverName string
|
||||||
|
req *http.Request
|
||||||
|
res *http.Response
|
||||||
|
idleTimeout time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewContext returns an initialized context for the provided [net.Conn].
|
||||||
|
func NewContext(c net.Conn) Context {
|
||||||
|
if c, ok := c.(*proxyContext); ok {
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
b := make([]byte, 8)
|
||||||
|
io.ReadFull(rand.Reader, b)
|
||||||
|
|
||||||
|
cr := &countingReader{reader: c}
|
||||||
|
cw := &countingWriter{writer: c}
|
||||||
|
return &proxyContext{
|
||||||
|
Conn: c,
|
||||||
|
id: binary.BigEndian.Uint64(b),
|
||||||
|
mac: arp.Get(c.RemoteAddr()),
|
||||||
|
cr: cr,
|
||||||
|
br: bufio.NewReader(cr),
|
||||||
|
cw: cw,
|
||||||
|
res: &http.Response{StatusCode: 200},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *proxyContext) AccessLogEntry() *logrus.Entry {
|
||||||
|
var id [8]byte
|
||||||
|
binary.BigEndian.PutUint64(id[:], c.id)
|
||||||
|
entry := AccessLog.WithFields(logrus.Fields{
|
||||||
|
"client": c.RemoteAddr().String(),
|
||||||
|
"server": c.LocalAddr().String(),
|
||||||
|
"id": hex.EncodeToString(id[:]),
|
||||||
|
"bytes_rx": c.BytesRead(),
|
||||||
|
"bytes_tx": c.BytesSent(),
|
||||||
|
})
|
||||||
|
if c.mac != nil {
|
||||||
|
return entry.WithField("client_mac", c.mac.String())
|
||||||
|
}
|
||||||
|
return entry
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *proxyContext) LogEntry() *logrus.Entry {
|
||||||
|
var id [8]byte
|
||||||
|
binary.BigEndian.PutUint64(id[:], c.id)
|
||||||
|
return ServerLog.WithFields(logrus.Fields{
|
||||||
|
"client": c.RemoteAddr().String(),
|
||||||
|
"server": c.LocalAddr().String(),
|
||||||
|
"id": hex.EncodeToString(id[:]),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *proxyContext) String() string {
|
||||||
|
return fmt.Sprintf("client=%s server=%s id=%#08x",
|
||||||
|
c.RemoteAddr().String(),
|
||||||
|
c.LocalAddr().String(),
|
||||||
|
c.id)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *proxyContext) ID() uint64 {
|
||||||
|
return c.id
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *proxyContext) BytesRead() int64 {
|
||||||
|
return atomic.LoadInt64(&c.cr.bytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *proxyContext) BytesSent() int64 {
|
||||||
|
return atomic.LoadInt64(&c.cw.bytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *proxyContext) Read(p []byte) (n int, err error) {
|
||||||
|
if c.idleTimeout > 0 {
|
||||||
|
if err = c.SetReadDeadline(time.Now().Add(c.idleTimeout)); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return c.br.Read(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *proxyContext) Write(p []byte) (n int, err error) {
|
||||||
|
if c.idleTimeout > 0 {
|
||||||
|
if err = c.SetWriteDeadline(time.Now().Add(c.idleTimeout)); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return c.cw.Write(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *proxyContext) Reader() *bufio.Reader {
|
||||||
|
return c.br
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *proxyContext) Request() *http.Request {
|
||||||
|
return c.req
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *proxyContext) SetRequest(req *http.Request) {
|
||||||
|
c.req = req
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *proxyContext) Response() *http.Response {
|
||||||
|
return c.res
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *proxyContext) SetIdleTimeout(t time.Duration) {
|
||||||
|
c.idleTimeout = t
|
||||||
|
}
|
||||||
|
|
||||||
|
type connectionStater interface {
|
||||||
|
ConnectionState() tls.ConnectionState
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *proxyContext) TLSState() *tls.ConnectionState {
|
||||||
|
if s, ok := c.Conn.(connectionStater); ok {
|
||||||
|
state := s.ConnectionState()
|
||||||
|
return &state
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// http.ResponseWriter interface:
|
||||||
|
|
||||||
|
func (c *proxyContext) Header() http.Header {
|
||||||
|
if c.res == nil {
|
||||||
|
c.res = NewResponse(http.StatusOK, nil, c.req)
|
||||||
|
}
|
||||||
|
return c.res.Header
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *proxyContext) WriteHeader(code int) {
|
||||||
|
if c.res == nil {
|
||||||
|
c.res = NewResponse(code, nil, c.req)
|
||||||
|
} else {
|
||||||
|
if text := http.StatusText(code); text != "" {
|
||||||
|
c.res.Status = strconv.Itoa(code) + " " + text
|
||||||
|
} else {
|
||||||
|
c.res.Status = strconv.Itoa(code)
|
||||||
|
}
|
||||||
|
c.res.StatusCode = code
|
||||||
|
}
|
||||||
|
//return c.res.Header.Write(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ Context = (*proxyContext)(nil)
|
2
proxy/doc.go
Normal file
2
proxy/doc.go
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
// Package proxy contains a HTTP(s) (transparent) proxy server.
|
||||||
|
package proxy
|
309
proxy/handler.go
Normal file
309
proxy/handler.go
Normal file
@@ -0,0 +1,309 @@
|
|||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"git.maze.io/maze/styx/ca"
|
||||||
|
"git.maze.io/maze/styx/internal/cryptutil"
|
||||||
|
"git.maze.io/maze/styx/internal/netutil"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Dialer interface {
|
||||||
|
DialContext(context.Context, *http.Request) (net.Conn, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type defaultDialer struct{}
|
||||||
|
|
||||||
|
func (defaultDialer) DialContext(ctx context.Context, req *http.Request) (net.Conn, error) {
|
||||||
|
if host := netutil.Host(req.URL.Host); host == "" {
|
||||||
|
return nil, errors.New("proxy: host missing in address")
|
||||||
|
}
|
||||||
|
|
||||||
|
var d = net.Dialer{
|
||||||
|
Resolver: net.DefaultResolver,
|
||||||
|
FallbackDelay: -1,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure we have a port.
|
||||||
|
switch req.URL.Scheme {
|
||||||
|
case "http", "ws":
|
||||||
|
req.URL.Host = netutil.EnsurePort(req.URL.Host, "80")
|
||||||
|
case "https", "wss":
|
||||||
|
req.URL.Host = netutil.EnsurePort(req.URL.Host, "443")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Resolve the host.
|
||||||
|
if ips, err := d.Resolver.LookupIP(ctx, "ip", netutil.Host(req.URL.Host)); err != nil {
|
||||||
|
return nil, err
|
||||||
|
} else {
|
||||||
|
for _, ip := range ips {
|
||||||
|
switch {
|
||||||
|
case ip.IsUnspecified():
|
||||||
|
return nil, fmt.Errorf("proxy: host %s resolves to unspecified address (blocked by DNS?)", netutil.Host(req.URL.Host))
|
||||||
|
case ip.IsLoopback():
|
||||||
|
return nil, fmt.Errorf("proxy: host %s resolves to loopback address (blocked by DNS?)", netutil.Host(req.URL.Host))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Make the connection.
|
||||||
|
switch req.URL.Scheme {
|
||||||
|
case "tcp", "http", "ws":
|
||||||
|
// Plain TCP client connection.
|
||||||
|
return d.DialContext(ctx, "tcp", req.URL.Host)
|
||||||
|
case "https", "wss":
|
||||||
|
// Secure TLS client connection.
|
||||||
|
c, err := d.DialContext(ctx, "tcp", req.URL.Host)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
s := tls.Client(c, new(tls.Config))
|
||||||
|
return s, s.Handshake()
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("proxy: can't dial %s protocol", req.URL.Scheme)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConnFilter is called when a new connection has been accepted by the proxy.
|
||||||
|
type ConnFilter interface {
|
||||||
|
FilterConn(Context) (net.Conn, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConnFilterFunc is a function that implements the [ConnFilter] interface.
|
||||||
|
type ConnFilterFunc func(Context) (net.Conn, error)
|
||||||
|
|
||||||
|
func (f ConnFilterFunc) FilterConn(ctx Context) (net.Conn, error) {
|
||||||
|
return f(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TLS starts a TLS handshake on the accepted connection.
|
||||||
|
func TLS(certs []tls.Certificate) ConnFilter {
|
||||||
|
return ConnFilterFunc(func(ctx Context) (net.Conn, error) {
|
||||||
|
s := tls.Server(ctx, &tls.Config{
|
||||||
|
Certificates: certs,
|
||||||
|
NextProtos: []string{"http/1.1"},
|
||||||
|
})
|
||||||
|
if err := s.Handshake(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return s, nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// TLSInterceptor can generate certificates on-the-fly for clients that use a compatible TLS version.
|
||||||
|
func TLSInterceptor(ca ca.CertificateAuthority) ConnFilter {
|
||||||
|
return ConnFilterFunc(func(ctx Context) (net.Conn, error) {
|
||||||
|
s := tls.Server(ctx, &tls.Config{
|
||||||
|
GetCertificate: func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||||
|
ips := []net.IP{net.ParseIP(netutil.Host(ctx.RemoteAddr().String()))}
|
||||||
|
return ca.GetCertificate(hello.ServerName, []string{hello.ServerName}, ips)
|
||||||
|
},
|
||||||
|
NextProtos: []string{"http/1.1"},
|
||||||
|
})
|
||||||
|
if err := s.Handshake(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return s, nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Transparent can handle transparent HTTP(S) requests on the port.
|
||||||
|
//
|
||||||
|
// When a new [net.Conn] is made, this function will inspect the initial request packet for a
|
||||||
|
// TLS handshake. If a TLS handshake is detected, the connection will make a feaux HTTP CONNECT
|
||||||
|
// request using TLS, if no handshake is detected, it will make a feaux plain HTTP CONNECT request.
|
||||||
|
func Transparent() ConnFilter {
|
||||||
|
return ConnFilterFunc(func(nctx Context) (net.Conn, error) {
|
||||||
|
ctx, ok := nctx.(*proxyContext)
|
||||||
|
if !ok {
|
||||||
|
return nctx, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
b := new(bytes.Buffer)
|
||||||
|
hello, err := cryptutil.ReadClientHello(io.TeeReader(ctx, b))
|
||||||
|
if err != nil {
|
||||||
|
if _, ok := err.(tls.RecordHeaderError); !ok {
|
||||||
|
ctx.LogEntry().WithError(err).WithField("error_type", fmt.Sprintf("%T", err)).Warn("TLS sniff error")
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
// Not a TLS connection, moving on to regular HTTP request handling...
|
||||||
|
ctx.LogEntry().Debug("HTTP connection on transparent port")
|
||||||
|
ctx.isTransparent = true
|
||||||
|
} else {
|
||||||
|
ctx.LogEntry().WithField("target", hello.ServerName).Debug("TLS connection on transparent port")
|
||||||
|
ctx.isTransparentTLS = true
|
||||||
|
ctx.serverName = hello.ServerName
|
||||||
|
}
|
||||||
|
|
||||||
|
return netutil.ReaderConn{
|
||||||
|
Conn: ctx.Conn,
|
||||||
|
Reader: io.MultiReader(bytes.NewReader(b.Bytes()), ctx.Conn),
|
||||||
|
}, nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// RequestFilter can filter HTTP requests coming to the proxy.
|
||||||
|
type RequestFilter interface {
|
||||||
|
// FilterRequest filters a HTTP request made to the proxy. The current request may be obtained
|
||||||
|
// from [Context.Request]. If a previous RequestFilter provided a HTTP response, it is available
|
||||||
|
// from [Context.Response].
|
||||||
|
//
|
||||||
|
// Modifications to the current request can be made to the Request returned by [Context.Request]
|
||||||
|
// and do not require returning a new [http.Request].
|
||||||
|
//
|
||||||
|
// If the filter returns a non-nil [http.Response], then the [Request] will not be proxied to
|
||||||
|
// any upstream server(s).
|
||||||
|
FilterRequest(Context) (*http.Request, *http.Response)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RequestFilterFunc is a function that implements the [RequestFilter] interface.
|
||||||
|
type RequestFilterFunc func(Context) (*http.Request, *http.Response)
|
||||||
|
|
||||||
|
func (f RequestFilterFunc) FilterRequest(ctx Context) (*http.Request, *http.Response) {
|
||||||
|
return f(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResponseFilter can filter HTTP responses coming from the proxy.
|
||||||
|
type ResponseFilter interface {
|
||||||
|
// FilterResponse filters a HTTP response coming from the proxy. The current response may be
|
||||||
|
// obtained from [Context.Response].
|
||||||
|
//
|
||||||
|
// Modifications to the current response can be made to the [Response] returned by [Context.Response].
|
||||||
|
FilterResponse(Context) *http.Response
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResponseFilterFunc is a function that implements the [ResponseFilter] interface.
|
||||||
|
type ResponseFilterFunc func(Context) *http.Response
|
||||||
|
|
||||||
|
func (f ResponseFilterFunc) FilterResponse(ctx Context) *http.Response {
|
||||||
|
return f(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CleanRequestProxyHeaders removes all headers added by downstream proxies from the [http.Request].
|
||||||
|
func CleanRequestProxyHeaders() RequestFilter {
|
||||||
|
return RequestFilterFunc(func(ctx Context) (*http.Request, *http.Response) {
|
||||||
|
if req := ctx.Request(); req != nil {
|
||||||
|
cleanProxyHeaders(req.Header)
|
||||||
|
}
|
||||||
|
return nil, nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// CleanRequestProxyHeaders removes all headers for upstream proxies from the [http.Response].
|
||||||
|
func CleanResponseProxyHeaders() ResponseFilter {
|
||||||
|
return ResponseFilterFunc(func(ctx Context) *http.Response {
|
||||||
|
if res := ctx.Response(); res != nil {
|
||||||
|
cleanProxyHeaders(res.Header)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddRequestHeaders adds headers to the [http.Request]. Any existing headers with the same
|
||||||
|
// key will remain intact.
|
||||||
|
func AddRequestHeaders(h http.Header) RequestFilter {
|
||||||
|
return RequestFilterFunc(func(ctx Context) (*http.Request, *http.Response) {
|
||||||
|
if req := ctx.Request(); req != nil {
|
||||||
|
if req.Header == nil {
|
||||||
|
req.Header = make(http.Header)
|
||||||
|
}
|
||||||
|
addHeaders(req.Header, h)
|
||||||
|
}
|
||||||
|
return nil, nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetRequestHeaders sets headers to the [http.Request]. Any existing headers with the same
|
||||||
|
// key will be removed.
|
||||||
|
func SetRequestHeaders(h http.Header) RequestFilter {
|
||||||
|
return RequestFilterFunc(func(ctx Context) (*http.Request, *http.Response) {
|
||||||
|
if req := ctx.Request(); req != nil {
|
||||||
|
if req.Header == nil {
|
||||||
|
req.Header = make(http.Header)
|
||||||
|
}
|
||||||
|
setHeaders(req.Header, h)
|
||||||
|
}
|
||||||
|
return nil, nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddResponseHeaders adds headers to the [http.Response]. Any existing headers with the same
|
||||||
|
// key will remain intact.
|
||||||
|
func AddResponseHeaders(h http.Header) ResponseFilter {
|
||||||
|
return ResponseFilterFunc(func(ctx Context) *http.Response {
|
||||||
|
if res := ctx.Response(); res != nil {
|
||||||
|
if res.Header == nil {
|
||||||
|
res.Header = make(http.Header)
|
||||||
|
}
|
||||||
|
addHeaders(res.Header, h)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetResponseHeaders sets headers to the [http.Response]. Any existing headers with the same
|
||||||
|
// key will be removed.
|
||||||
|
func SetResponseHeaders(h http.Header) ResponseFilter {
|
||||||
|
return ResponseFilterFunc(func(ctx Context) *http.Response {
|
||||||
|
if res := ctx.Response(); res != nil {
|
||||||
|
if res.Header == nil {
|
||||||
|
res.Header = make(http.Header)
|
||||||
|
}
|
||||||
|
setHeaders(res.Header, h)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// cleanProxyHeaders removes all headers commonly used by (reverse) HTTP proxies.
|
||||||
|
func cleanProxyHeaders(h http.Header) {
|
||||||
|
if h == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, k := range []string{
|
||||||
|
HeaderForwarded,
|
||||||
|
HeaderForwardedFor,
|
||||||
|
HeaderForwardedHost,
|
||||||
|
HeaderForwardedPort,
|
||||||
|
HeaderForwardedProto,
|
||||||
|
HeaderRealIP,
|
||||||
|
HeaderVia,
|
||||||
|
} {
|
||||||
|
h.Del(k)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// addHeaders adds to the current existing headers.
|
||||||
|
func addHeaders(dst, src http.Header) {
|
||||||
|
if src == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for k, vv := range src {
|
||||||
|
for _, v := range vv {
|
||||||
|
dst.Add(k, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// setHeaders replaces all previous values.
|
||||||
|
func setHeaders(dst, src http.Header) {
|
||||||
|
if src == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for k, vv := range src {
|
||||||
|
dst.Del(k)
|
||||||
|
for _, v := range vv {
|
||||||
|
dst.Add(k, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@@ -1,324 +0,0 @@
|
|||||||
package match
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"net/http"
|
|
||||||
"os"
|
|
||||||
"regexp"
|
|
||||||
"slices"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"git.maze.io/maze/styx/internal/log"
|
|
||||||
"git.maze.io/maze/styx/internal/netutil"
|
|
||||||
"github.com/hashicorp/hcl/v2"
|
|
||||||
"github.com/hashicorp/hcl/v2/gohcl"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Config struct {
|
|
||||||
Path string `hcl:"path,optional"`
|
|
||||||
Refresh time.Duration `hcl:"refresh,optional"`
|
|
||||||
Domain []*Domain `hcl:"domain,block"`
|
|
||||||
Network []*Network `hcl:"network,block"`
|
|
||||||
Content []*Content `hcl:"content,block"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (config Config) Matchers() (Matchers, error) {
|
|
||||||
all := make(Matchers)
|
|
||||||
if config.Domain != nil {
|
|
||||||
all["domain"] = make(map[string]Matcher)
|
|
||||||
for _, domain := range config.Domain {
|
|
||||||
m, err := domain.Matcher()
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("matcher domain %q invalid: %w", domain.Name, err)
|
|
||||||
}
|
|
||||||
all["domain"][domain.Name] = m
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if config.Network != nil {
|
|
||||||
all["network"] = make(map[string]Matcher)
|
|
||||||
for _, network := range config.Network {
|
|
||||||
m, err := network.Matcher(true)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("matcher network %q invalid: %w", network.Name, err)
|
|
||||||
}
|
|
||||||
all["network"][network.Name] = m
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return all, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type Content struct {
|
|
||||||
Name string `hcl:"name,label"`
|
|
||||||
Type string `hcl:"type"`
|
|
||||||
Body hcl.Body `hcl:",remain"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type contentHeader struct {
|
|
||||||
Key string `hcl:"name"`
|
|
||||||
Value string `hcl:"value,optional"`
|
|
||||||
List []string `hcl:"list,optional"`
|
|
||||||
name string
|
|
||||||
keyRe *regexp.Regexp
|
|
||||||
valueRe *regexp.Regexp
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m contentHeader) Name() string { return m.name }
|
|
||||||
func (m contentHeader) MatchesResponse(r *http.Response) bool {
|
|
||||||
for k, vv := range r.Header {
|
|
||||||
if m.keyRe.MatchString(k) {
|
|
||||||
for _, v := range vv {
|
|
||||||
if slices.Contains(m.List, v) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
if m.valueRe != nil && m.valueRe.MatchString(v) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
type contentType struct {
|
|
||||||
List []string `hcl:"list"`
|
|
||||||
name string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m contentType) Name() string { return m.name }
|
|
||||||
func (m contentType) MatchesResponse(r *http.Response) bool {
|
|
||||||
return slices.Contains(m.List, r.Header.Get("Content-Type"))
|
|
||||||
}
|
|
||||||
|
|
||||||
type contentSizeLargerThan struct {
|
|
||||||
Size int64 `hcl:"size"`
|
|
||||||
name string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m contentSizeLargerThan) Name() string { return m.name }
|
|
||||||
func (m contentSizeLargerThan) MatchesResponse(r *http.Response) bool {
|
|
||||||
size, err := strconv.ParseInt(r.Header.Get("Content-Length"), 10, 64)
|
|
||||||
if err != nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return size >= m.Size
|
|
||||||
}
|
|
||||||
|
|
||||||
type contentStatus struct {
|
|
||||||
Code []int `hcl:"code"`
|
|
||||||
name string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m contentStatus) Name() string { return m.name }
|
|
||||||
func (m contentStatus) MatchesResponse(r *http.Response) bool {
|
|
||||||
return slices.Contains(m.Code, r.StatusCode)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (config Content) Matcher() (Response, error) {
|
|
||||||
switch strings.ToLower(config.Type) {
|
|
||||||
case "content", "contenttype", "content-type", "type":
|
|
||||||
var matcher = contentType{name: config.Name}
|
|
||||||
if err := gohcl.DecodeBody(config.Body, nil, &matcher); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return matcher, nil
|
|
||||||
|
|
||||||
case "header":
|
|
||||||
var (
|
|
||||||
matcher = contentHeader{name: config.Name}
|
|
||||||
err error
|
|
||||||
)
|
|
||||||
if err = gohcl.DecodeBody(config.Body, nil, &matcher); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if matcher.Value == "" && len(matcher.List) == 0 {
|
|
||||||
return nil, fmt.Errorf("invalid content %q: must contain either list or value", config.Name)
|
|
||||||
}
|
|
||||||
if matcher.keyRe, err = regexp.Compile(matcher.Key); err != nil {
|
|
||||||
return nil, fmt.Errorf("invalid regular expression on content %q key: %w", config.Name, err)
|
|
||||||
}
|
|
||||||
if matcher.Value != "" {
|
|
||||||
if matcher.valueRe, err = regexp.Compile(matcher.Value); err != nil {
|
|
||||||
return nil, fmt.Errorf("invalid regular expression on content %q value: %w", config.Name, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return matcher, nil
|
|
||||||
|
|
||||||
case "size":
|
|
||||||
var matcher = contentSizeLargerThan{name: config.Name}
|
|
||||||
if err := gohcl.DecodeBody(config.Body, nil, &matcher); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return matcher, nil
|
|
||||||
|
|
||||||
case "status":
|
|
||||||
var matcher = contentStatus{name: config.Name}
|
|
||||||
if err := gohcl.DecodeBody(config.Body, nil, &matcher); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return matcher, nil
|
|
||||||
|
|
||||||
default:
|
|
||||||
return nil, fmt.Errorf("unknown content matcher type %q", config.Type)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type Domain struct {
|
|
||||||
Name string `hcl:"name,label"`
|
|
||||||
Type string `hcl:"type"`
|
|
||||||
Body hcl.Body `hcl:",remain"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (config Domain) Matcher() (Request, error) {
|
|
||||||
switch config.Type {
|
|
||||||
case "list":
|
|
||||||
var matcher = domainList{Title: config.Name}
|
|
||||||
if err := gohcl.DecodeBody(config.Body, nil, &matcher); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
matcher.list = netutil.NewDomainList(matcher.List...)
|
|
||||||
return matcher, nil
|
|
||||||
|
|
||||||
case "adblock", "dnsmasq", "hosts", "detect", "domains":
|
|
||||||
var matcher = DomainFile{
|
|
||||||
Title: config.Name,
|
|
||||||
Type: config.Type,
|
|
||||||
}
|
|
||||||
if err := gohcl.DecodeBody(config.Body, nil, &matcher); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if matcher.Path == "" && matcher.From == "" {
|
|
||||||
return nil, fmt.Errorf("matcher: domain %q must have either file or from configured", config.Name)
|
|
||||||
}
|
|
||||||
if err := matcher.Update(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return matcher, nil
|
|
||||||
|
|
||||||
default:
|
|
||||||
return nil, fmt.Errorf("unknown domain matcher type %q", config.Type)
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
type domainList struct {
|
|
||||||
Title string `json:"title"`
|
|
||||||
List []string `hcl:"list" json:"list"`
|
|
||||||
list *netutil.DomainTree
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m domainList) Name() string {
|
|
||||||
return m.Title
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m domainList) MatchesRequest(r *http.Request) bool {
|
|
||||||
host := netutil.Host(r.URL.Host)
|
|
||||||
log.Debug().Str("host", host).Msgf("match domain list (%d domains)", len(m.List))
|
|
||||||
return m.list.Contains(host)
|
|
||||||
}
|
|
||||||
|
|
||||||
type DomainFile struct {
|
|
||||||
Title string `json:"name"`
|
|
||||||
Type string `json:"type"`
|
|
||||||
Path string `hcl:"path,optional" json:"path,omitempty"`
|
|
||||||
From string `hcl:"from,optional" json:"from,omitempty"`
|
|
||||||
Refresh time.Duration `hcl:"refresh,optional" json:"refresh"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m DomainFile) Name() string {
|
|
||||||
return m.Title
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m DomainFile) MatchesRequest(_ *http.Request) bool {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *DomainFile) Update() (err error) {
|
|
||||||
var data []byte
|
|
||||||
if m.Path != "" {
|
|
||||||
if data, err = os.ReadFile(m.Path); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
/*
|
|
||||||
var response *http.Response
|
|
||||||
if response, err = http.DefaultClient.Get(m.From); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer func() { _ = response.Body.Close() }()
|
|
||||||
if response.StatusCode != http.StatusOK {
|
|
||||||
return fmt.Errorf("match: domain %q update failed: %s", m.name, response.Status)
|
|
||||||
}
|
|
||||||
if data, err = io.ReadAll(response.Body); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
*/
|
|
||||||
}
|
|
||||||
|
|
||||||
switch m.Type {
|
|
||||||
case "hosts":
|
|
||||||
}
|
|
||||||
|
|
||||||
_ = data
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type Network struct {
|
|
||||||
Name string `hcl:"name,label"`
|
|
||||||
Type string `hcl:"type"`
|
|
||||||
Body hcl.Body `hcl:",remain"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (config *Network) Matcher(target bool) (Matcher, error) {
|
|
||||||
switch config.Type {
|
|
||||||
case "list":
|
|
||||||
var (
|
|
||||||
matcher = networkList{Title: config.Name}
|
|
||||||
err error
|
|
||||||
)
|
|
||||||
if diag := gohcl.DecodeBody(config.Body, nil, &matcher); diag.HasErrors() {
|
|
||||||
return nil, diag
|
|
||||||
}
|
|
||||||
if matcher.tree, err = netutil.NewNetworkTree(matcher.List...); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &matcher, nil
|
|
||||||
|
|
||||||
default:
|
|
||||||
return nil, fmt.Errorf("unknown network matcher type %q", config.Type)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type networkList struct {
|
|
||||||
Title string `json:"name"`
|
|
||||||
List []string `hcl:"list" json:"list"`
|
|
||||||
tree *netutil.NetworkTree
|
|
||||||
target bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *networkList) Name() string {
|
|
||||||
return m.Title
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *networkList) MatchesIP(ip net.IP) bool {
|
|
||||||
return m.tree.Contains(ip)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *networkList) MatchesRequest(r *http.Request) bool {
|
|
||||||
var (
|
|
||||||
host string
|
|
||||||
err error
|
|
||||||
)
|
|
||||||
if m.target {
|
|
||||||
host, _, err = net.SplitHostPort(r.URL.Host)
|
|
||||||
} else {
|
|
||||||
host, _, err = net.SplitHostPort(r.RemoteAddr)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
ip := net.ParseIP(host)
|
|
||||||
return m.MatchesIP(ip)
|
|
||||||
}
|
|
@@ -1,45 +0,0 @@
|
|||||||
package match
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"net/http"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Matchers map[string]map[string]Matcher
|
|
||||||
|
|
||||||
func (all Matchers) Get(kind, name string) (m Matcher, err error) {
|
|
||||||
if typeMatchers, ok := all[kind]; ok {
|
|
||||||
if m, ok = typeMatchers[name]; ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("no %s matcher named %q found", kind, name)
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("no %s matcher found", kind)
|
|
||||||
}
|
|
||||||
|
|
||||||
type Matcher interface {
|
|
||||||
Name() string
|
|
||||||
}
|
|
||||||
|
|
||||||
type Updater interface {
|
|
||||||
Update() error
|
|
||||||
}
|
|
||||||
|
|
||||||
type IP interface {
|
|
||||||
Matcher
|
|
||||||
|
|
||||||
MatchesIP(net.IP) bool
|
|
||||||
}
|
|
||||||
|
|
||||||
type Request interface {
|
|
||||||
Matcher
|
|
||||||
|
|
||||||
MatchesRequest(*http.Request) bool
|
|
||||||
}
|
|
||||||
|
|
||||||
type Response interface {
|
|
||||||
Matcher
|
|
||||||
|
|
||||||
MatchesResponse(*http.Response) bool
|
|
||||||
}
|
|
@@ -1,11 +0,0 @@
|
|||||||
package match
|
|
||||||
|
|
||||||
import "net"
|
|
||||||
|
|
||||||
func onlyHost(name string) string {
|
|
||||||
host, _, err := net.SplitHostPort(name)
|
|
||||||
if err != nil {
|
|
||||||
return name
|
|
||||||
}
|
|
||||||
return host
|
|
||||||
}
|
|
@@ -1,231 +0,0 @@
|
|||||||
package mitm
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto"
|
|
||||||
"crypto/rand"
|
|
||||||
"crypto/tls"
|
|
||||||
"crypto/x509"
|
|
||||||
"crypto/x509/pkix"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"math/big"
|
|
||||||
"os"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"git.maze.io/maze/styx/internal/cryptutil"
|
|
||||||
"git.maze.io/maze/styx/internal/log"
|
|
||||||
"github.com/miekg/dns"
|
|
||||||
)
|
|
||||||
|
|
||||||
const DefaultValidity = 24 * time.Hour
|
|
||||||
|
|
||||||
type Authority interface {
|
|
||||||
Certificate() *x509.Certificate
|
|
||||||
TLSConfig(name string) *tls.Config
|
|
||||||
}
|
|
||||||
|
|
||||||
type authority struct {
|
|
||||||
pool *x509.CertPool
|
|
||||||
cert *x509.Certificate
|
|
||||||
key crypto.PrivateKey
|
|
||||||
keyID []byte
|
|
||||||
keyPool chan crypto.PrivateKey
|
|
||||||
cache Cache
|
|
||||||
}
|
|
||||||
|
|
||||||
func New(config *Config) (Authority, error) {
|
|
||||||
cache, err := NewCache(config.Cache)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
caConfig := config.CA
|
|
||||||
if caConfig == nil {
|
|
||||||
caConfig = new(CAConfig)
|
|
||||||
}
|
|
||||||
|
|
||||||
cert, key, err := cryptutil.LoadKeyPair(caConfig.Cert, caConfig.Key)
|
|
||||||
if os.IsNotExist(err) {
|
|
||||||
days := caConfig.Days
|
|
||||||
if days == 0 {
|
|
||||||
days = DefaultDays
|
|
||||||
}
|
|
||||||
if cert, key, err = cryptutil.GenerateKeyPair(caConfig.DN(), days, caConfig.KeyType, caConfig.Bits); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if strings.ContainsRune(caConfig.Cert, os.PathSeparator) {
|
|
||||||
if err = cryptutil.SaveKeyPair(cert, key, caConfig.Cert, caConfig.Key); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
pool := x509.NewCertPool()
|
|
||||||
pool.AddCert(cert)
|
|
||||||
|
|
||||||
keyConfig := config.Key
|
|
||||||
if keyConfig == nil {
|
|
||||||
keyConfig = &defaultKeyConfig
|
|
||||||
}
|
|
||||||
|
|
||||||
keyPoolSize := defaultKeyConfig.Pool
|
|
||||||
if keyConfig.Pool > 0 {
|
|
||||||
keyPoolSize = keyConfig.Pool
|
|
||||||
}
|
|
||||||
keyPool := make(chan crypto.PrivateKey, keyPoolSize)
|
|
||||||
if key, err := cryptutil.GeneratePrivateKey(keyConfig.Type, keyConfig.Bits); err != nil {
|
|
||||||
return nil, fmt.Errorf("mitm: invalid key configuration: %w", err)
|
|
||||||
} else {
|
|
||||||
keyPool <- key
|
|
||||||
}
|
|
||||||
|
|
||||||
go func(pool chan<- crypto.PrivateKey) {
|
|
||||||
for {
|
|
||||||
key, err := cryptutil.GeneratePrivateKey(keyConfig.Type, keyConfig.Bits)
|
|
||||||
if err != nil {
|
|
||||||
log.Panic().Err(err).Msg("error generating private key")
|
|
||||||
}
|
|
||||||
pool <- key
|
|
||||||
}
|
|
||||||
}(keyPool)
|
|
||||||
|
|
||||||
return &authority{
|
|
||||||
pool: pool,
|
|
||||||
cert: cert,
|
|
||||||
key: key,
|
|
||||||
keyID: cryptutil.GenerateKeyID(cryptutil.PublicKey(key)),
|
|
||||||
keyPool: keyPool,
|
|
||||||
cache: cache,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ca *authority) log() log.Logger {
|
|
||||||
return log.Console.With().
|
|
||||||
Str("ca", ca.cert.Subject.String()).
|
|
||||||
Logger()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ca *authority) Certificate() *x509.Certificate {
|
|
||||||
return ca.cert
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ca *authority) TLSConfig(name string) *tls.Config {
|
|
||||||
return &tls.Config{
|
|
||||||
GetCertificate: func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
|
||||||
log := ca.log()
|
|
||||||
if hello.ServerName != "" {
|
|
||||||
name = strings.ToLower(hello.ServerName)
|
|
||||||
log.Debug().Msg("requesting certificate for server name (SNI)")
|
|
||||||
} else {
|
|
||||||
log.Debug().Msg("requesting certificate for hostname")
|
|
||||||
}
|
|
||||||
if cert, ok := ca.getCached(name); ok {
|
|
||||||
log.Debug().
|
|
||||||
Str("subject", cert.Leaf.Subject.String()).
|
|
||||||
Str("serial", cert.Leaf.SerialNumber.String()).
|
|
||||||
Time("valid", cert.Leaf.NotAfter).
|
|
||||||
Msg("using cached certificate")
|
|
||||||
return cert, nil
|
|
||||||
}
|
|
||||||
return ca.issueFor(name)
|
|
||||||
},
|
|
||||||
NextProtos: []string{"http/1.1"},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ca *authority) getCached(name string) (cert *tls.Certificate, ok bool) {
|
|
||||||
log := ca.log()
|
|
||||||
|
|
||||||
if cert = ca.cache.Certificate(name); cert == nil {
|
|
||||||
if baseDomain(name) != name {
|
|
||||||
cert = ca.cache.Certificate(baseDomain(name))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if cert != nil {
|
|
||||||
if _, err := cert.Leaf.Verify(x509.VerifyOptions{
|
|
||||||
DNSName: name,
|
|
||||||
Roots: ca.pool,
|
|
||||||
}); err != nil {
|
|
||||||
log.Debug().Err(err).Str("name", name).Msg("deleting invalid certificate from cache")
|
|
||||||
} else {
|
|
||||||
ok = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ca *authority) issueFor(name string) (*tls.Certificate, error) {
|
|
||||||
var (
|
|
||||||
log = ca.log().With().Str("name", name).Logger()
|
|
||||||
key crypto.PrivateKey
|
|
||||||
)
|
|
||||||
select {
|
|
||||||
case key = <-ca.keyPool:
|
|
||||||
case <-time.After(5 * time.Second):
|
|
||||||
return nil, errors.New("mitm: timeout waiting for private key generator to catch up")
|
|
||||||
}
|
|
||||||
if key == nil {
|
|
||||||
panic("key pool returned nil key")
|
|
||||||
}
|
|
||||||
|
|
||||||
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
|
|
||||||
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("mtim: failed to generate serial number: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if part := dns.SplitDomainName(name); len(part) > 2 {
|
|
||||||
name = strings.Join(part[1:], ".")
|
|
||||||
log.Debug().Msgf("abbreviated name to %s (*.%s)", name, name)
|
|
||||||
}
|
|
||||||
|
|
||||||
now := time.Now()
|
|
||||||
template := &x509.Certificate{
|
|
||||||
SerialNumber: serialNumber,
|
|
||||||
Subject: pkix.Name{CommonName: name},
|
|
||||||
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyAgreement,
|
|
||||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
|
||||||
DNSNames: []string{name, "*." + name},
|
|
||||||
BasicConstraintsValid: true,
|
|
||||||
NotBefore: now.Add(-DefaultValidity),
|
|
||||||
NotAfter: now.Add(+DefaultValidity),
|
|
||||||
}
|
|
||||||
der, err := x509.CreateCertificate(rand.Reader, template, ca.cert, cryptutil.PublicKey(key), ca.key)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
cert, err := x509.ParseCertificate(der)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debug().Str("serial", serialNumber.String()).Msg("generated certificate")
|
|
||||||
out := &tls.Certificate{
|
|
||||||
Certificate: [][]byte{der},
|
|
||||||
Leaf: cert,
|
|
||||||
PrivateKey: key,
|
|
||||||
}
|
|
||||||
//ca.cache[name] = out
|
|
||||||
ca.cache.SaveCertificate(name, out)
|
|
||||||
return out, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func containsValidCertificate(cert *tls.Certificate) bool {
|
|
||||||
if cert == nil || len(cert.Certificate) == 0 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
if cert.Leaf == nil {
|
|
||||||
var err error
|
|
||||||
if cert.Leaf, err = x509.ParseCertificate(cert.Certificate[0]); err != nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
now := time.Now()
|
|
||||||
|
|
||||||
return !(cert.Leaf.NotBefore.Before(now) || cert.Leaf.NotAfter.After(now))
|
|
||||||
}
|
|
@@ -1,233 +0,0 @@
|
|||||||
package mitm
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/tls"
|
|
||||||
"fmt"
|
|
||||||
"io/fs"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"slices"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/hashicorp/golang-lru/v2/expirable"
|
|
||||||
"github.com/hashicorp/hcl/v2/gohcl"
|
|
||||||
"github.com/miekg/dns"
|
|
||||||
|
|
||||||
"git.maze.io/maze/styx/internal/cryptutil"
|
|
||||||
"git.maze.io/maze/styx/internal/log"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Cache interface {
|
|
||||||
Certificate(name string) *tls.Certificate
|
|
||||||
SaveCertificate(name string, cert *tls.Certificate) error
|
|
||||||
RemoveCertificate(name string)
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewCache(config *CacheConfig) (Cache, error) {
|
|
||||||
if config == nil {
|
|
||||||
return NewCache(&CacheConfig{Type: "memory"})
|
|
||||||
}
|
|
||||||
switch config.Type {
|
|
||||||
case "memory":
|
|
||||||
var cacheConfig = new(MemoryCacheConfig)
|
|
||||||
if err := gohcl.DecodeBody(config.Body, nil, cacheConfig); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return NewMemoryCache(cacheConfig.Size), nil
|
|
||||||
case "disk":
|
|
||||||
var cacheConfig = new(DiskCacheConfig)
|
|
||||||
if err := gohcl.DecodeBody(config.Body, nil, cacheConfig); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return NewDiskCache(cacheConfig.Path, time.Duration(cacheConfig.Expire*float64(time.Second)))
|
|
||||||
default:
|
|
||||||
return nil, fmt.Errorf("mitm: cache type %q is not supported", config.Type)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type memoryCache struct {
|
|
||||||
cache *expirable.LRU[string, *tls.Certificate]
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewMemoryCache(size int) Cache {
|
|
||||||
return memoryCache{
|
|
||||||
cache: expirable.NewLRU(size, func(key string, value *tls.Certificate) {
|
|
||||||
log.Debug().Str("name", key).Msg("certificate evicted from cache")
|
|
||||||
}, time.Hour*24),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c memoryCache) Certificate(name string) (cert *tls.Certificate) {
|
|
||||||
var ok bool
|
|
||||||
if cert, ok = c.cache.Get(name); !ok {
|
|
||||||
cert, _ = c.cache.Get(baseDomain(name))
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c memoryCache) SaveCertificate(name string, cert *tls.Certificate) error {
|
|
||||||
c.cache.Add(name, cert)
|
|
||||||
log.Debug().Str("name", name).Msg("certificate added to cache")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c memoryCache) RemoveCertificate(name string) {
|
|
||||||
c.cache.Remove(name)
|
|
||||||
}
|
|
||||||
|
|
||||||
type diskCache string
|
|
||||||
|
|
||||||
func NewDiskCache(dir string, expire time.Duration) (Cache, error) {
|
|
||||||
if !filepath.IsAbs(dir) {
|
|
||||||
var err error
|
|
||||||
if dir, err = filepath.Abs(dir); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if err := os.MkdirAll(dir, 0o750); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
info, err := os.Stat(dir)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if info.Mode()&os.ModePerm|0o057 != 0 {
|
|
||||||
if err := os.Chmod(dir, 0o750); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if expire > 0 {
|
|
||||||
go expireDiskCache(dir, expire)
|
|
||||||
}
|
|
||||||
|
|
||||||
return diskCache(dir), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func expireDiskCache(root string, expire time.Duration) {
|
|
||||||
log.Debug().Str("path", root).Dur("expire", expire).Msg("disk cache expire loop starting")
|
|
||||||
ticker := time.NewTicker(expire)
|
|
||||||
defer ticker.Stop()
|
|
||||||
for {
|
|
||||||
now := <-ticker.C
|
|
||||||
log.Debug().Str("path", root).Dur("expire", expire).Msg("expire disk cache")
|
|
||||||
filepath.WalkDir(root, func(path string, d fs.DirEntry, err error) error {
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if d.IsDir() {
|
|
||||||
// Remove the directory; this will fail if it's not empty, which is fine.
|
|
||||||
_ = os.Remove(path)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
cert, err := cryptutil.LoadCertificate(path)
|
|
||||||
if err != nil {
|
|
||||||
log.Debug().Str("path", path).Err(err).Msg("expire removing invalid certificate file")
|
|
||||||
_ = os.Remove(path)
|
|
||||||
return nil
|
|
||||||
} else if cert.NotAfter.Before(now) {
|
|
||||||
log.Debug().Str("path", path).Dur("expired", now.Sub(cert.NotAfter)).Msg("expire removing expired certificate")
|
|
||||||
_ = os.Remove(path)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c diskCache) path(name string) string {
|
|
||||||
part := dns.SplitDomainName(strings.ToLower(name))
|
|
||||||
// x,com -> com,x
|
|
||||||
// www,maze,io -> io,maze,www
|
|
||||||
slices.Reverse(part)
|
|
||||||
// com,x -> com,x,x.com
|
|
||||||
// io,maze,www -> io,m,ma,maze,www.maze.io
|
|
||||||
if len(part) > 2 {
|
|
||||||
if len(part[1]) > 1 {
|
|
||||||
part = []string{
|
|
||||||
part[0],
|
|
||||||
part[1][:1],
|
|
||||||
part[1][:2],
|
|
||||||
part[1],
|
|
||||||
name,
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
part = []string{
|
|
||||||
part[0],
|
|
||||||
part[1][:1],
|
|
||||||
part[1],
|
|
||||||
name,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else if len(part) > 1 {
|
|
||||||
if len(part[1]) > 1 {
|
|
||||||
part = []string{
|
|
||||||
part[0],
|
|
||||||
part[1][:1],
|
|
||||||
part[1][:2],
|
|
||||||
name,
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
part = []string{
|
|
||||||
part[0],
|
|
||||||
part[1][:1],
|
|
||||||
name,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
part[len(part)-1] += ".crt"
|
|
||||||
return filepath.Join(append([]string{string(c)}, part...)...)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c diskCache) Certificate(name string) (cert *tls.Certificate) {
|
|
||||||
if cert, key, err := cryptutil.LoadKeyPair(c.path(name), ""); err == nil {
|
|
||||||
return &tls.Certificate{
|
|
||||||
Certificate: [][]byte{cert.Raw},
|
|
||||||
Leaf: cert,
|
|
||||||
PrivateKey: key,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if cert, key, err := cryptutil.LoadKeyPair(c.path(baseDomain(name)), ""); err == nil {
|
|
||||||
return &tls.Certificate{
|
|
||||||
Certificate: [][]byte{cert.Raw},
|
|
||||||
Leaf: cert,
|
|
||||||
PrivateKey: key,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
log.Debug().Str("path", string(c)).Str("name", name).Msg("cache miss")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c diskCache) SaveCertificate(name string, cert *tls.Certificate) error {
|
|
||||||
dir, name := filepath.Split(c.path(name))
|
|
||||||
if err := os.MkdirAll(dir, 0o750); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if err := cryptutil.SaveKeyPair(cert.Leaf, cert.PrivateKey, filepath.Join(dir, name), ""); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
log.Debug().Str("name", name).Msg("certificate added to cache")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c diskCache) RemoveCertificate(name string) {
|
|
||||||
path := c.path(name)
|
|
||||||
if err := os.Remove(path); err != nil {
|
|
||||||
if os.IsNotExist(err) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
log.Error().Err(err).Msg("certificate remove from cache failed")
|
|
||||||
}
|
|
||||||
_ = os.Remove(filepath.Dir(path))
|
|
||||||
log.Debug().Str("name", name).Msg("certificate removed from cache")
|
|
||||||
}
|
|
||||||
|
|
||||||
func baseDomain(name string) string {
|
|
||||||
name = strings.ToLower(name)
|
|
||||||
if part := dns.SplitDomainName(name); len(part) > 2 {
|
|
||||||
return strings.Join(part[1:], ".")
|
|
||||||
}
|
|
||||||
return name
|
|
||||||
}
|
|
@@ -1,25 +0,0 @@
|
|||||||
package mitm
|
|
||||||
|
|
||||||
import "testing"
|
|
||||||
|
|
||||||
func TestDiskCachePath(t *testing.T) {
|
|
||||||
cache := diskCache("testdata")
|
|
||||||
tests := []struct {
|
|
||||||
test string
|
|
||||||
want string
|
|
||||||
}{
|
|
||||||
{"x.com", "testdata/com/x/x.com.crt"},
|
|
||||||
{"feed.x.com", "testdata/com/x/x/feed.x.com.crt"},
|
|
||||||
{"nu.nl", "testdata/nl/n/nu/nu.nl.crt"},
|
|
||||||
{"maze.io", "testdata/io/m/ma/maze.io.crt"},
|
|
||||||
{"lab.maze.io", "testdata/io/m/ma/maze/lab.maze.io.crt"},
|
|
||||||
{"dev.lab.maze.io", "testdata/io/m/ma/maze/dev.lab.maze.io.crt"},
|
|
||||||
}
|
|
||||||
for _, test := range tests {
|
|
||||||
t.Run(test.test, func(it *testing.T) {
|
|
||||||
if v := cache.path(test.test); v != test.want {
|
|
||||||
it.Errorf("expected %q to resolve to %q, got %q", test.test, test.want, v)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
@@ -1,89 +0,0 @@
|
|||||||
package mitm
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/x509/pkix"
|
|
||||||
|
|
||||||
"github.com/hashicorp/hcl/v2"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
DefaultCommonName = "Styx Certificate Authority"
|
|
||||||
DefaultDays = 3
|
|
||||||
)
|
|
||||||
|
|
||||||
type Config struct {
|
|
||||||
CA *CAConfig `hcl:"ca,block"`
|
|
||||||
Key *KeyConfig `hcl:"key,block"`
|
|
||||||
Cache *CacheConfig `hcl:"cache,block"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type CAConfig struct {
|
|
||||||
Cert string `hcl:"cert"`
|
|
||||||
Key string `hcl:"key,optional"`
|
|
||||||
Days int `hcl:"days,optional"`
|
|
||||||
KeyType string `hcl:"key_type,optional"`
|
|
||||||
Bits int `hcl:"bits,optional"`
|
|
||||||
Name string `hcl:"name,optional"`
|
|
||||||
Country string `hcl:"country,optional"`
|
|
||||||
Organization string `hcl:"organization,optional"`
|
|
||||||
Unit string `hcl:"unit,optional"`
|
|
||||||
Locality string `hcl:"locality,optional"`
|
|
||||||
Province string `hcl:"province,optional"`
|
|
||||||
Address []string `hcl:"address,optional"`
|
|
||||||
PostalCode string `hcl:"postal_code,optional"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (config CAConfig) DN() pkix.Name {
|
|
||||||
var name = pkix.Name{
|
|
||||||
CommonName: config.Name,
|
|
||||||
StreetAddress: config.Address,
|
|
||||||
}
|
|
||||||
if config.Name == "" {
|
|
||||||
name.CommonName = DefaultCommonName
|
|
||||||
}
|
|
||||||
if config.Country != "" {
|
|
||||||
name.Country = append(name.Country, config.Country)
|
|
||||||
}
|
|
||||||
if config.Organization != "" {
|
|
||||||
name.Organization = append(name.Organization, config.Organization)
|
|
||||||
}
|
|
||||||
if config.Unit != "" {
|
|
||||||
name.OrganizationalUnit = append(name.OrganizationalUnit, config.Unit)
|
|
||||||
}
|
|
||||||
if config.Locality != "" {
|
|
||||||
name.Locality = append(name.Locality, config.Locality)
|
|
||||||
}
|
|
||||||
if config.Province != "" {
|
|
||||||
name.Province = append(name.Province, config.Province)
|
|
||||||
}
|
|
||||||
if config.PostalCode != "" {
|
|
||||||
name.PostalCode = append(name.PostalCode, config.PostalCode)
|
|
||||||
}
|
|
||||||
return name
|
|
||||||
}
|
|
||||||
|
|
||||||
type KeyConfig struct {
|
|
||||||
Type string `hcl:"type,optional"`
|
|
||||||
Bits int `hcl:"bits,optional"`
|
|
||||||
Pool int `hcl:"pool,optional"`
|
|
||||||
}
|
|
||||||
|
|
||||||
var defaultKeyConfig = KeyConfig{
|
|
||||||
Type: "rsa",
|
|
||||||
Bits: 2048,
|
|
||||||
Pool: 5,
|
|
||||||
}
|
|
||||||
|
|
||||||
type CacheConfig struct {
|
|
||||||
Type string `hcl:"type"`
|
|
||||||
Body hcl.Body `hcl:",remain"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type MemoryCacheConfig struct {
|
|
||||||
Size int `hcl:"size,optional"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type DiskCacheConfig struct {
|
|
||||||
Path string `hcl:"path"`
|
|
||||||
Expire float64 `hcl:"expire,optional"`
|
|
||||||
}
|
|
@@ -1,53 +0,0 @@
|
|||||||
package policy
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/http"
|
|
||||||
|
|
||||||
"git.maze.io/maze/styx/proxy/match"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Policy contains rules that make up the policy.
|
|
||||||
//
|
|
||||||
// Some policy rules contain nested policies.
|
|
||||||
type Policy struct {
|
|
||||||
Rules []*rawRule `hcl:"on,block" json:"rules"`
|
|
||||||
Permit *bool `hcl:"permit" json:"permit"`
|
|
||||||
Matchers match.Matchers `json:"matchers"` // Matchers for the policy
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *Policy) Configure(matchers match.Matchers) (err error) {
|
|
||||||
for _, r := range p.Rules {
|
|
||||||
if err = r.Configure(matchers); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
p.Matchers = matchers
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *Policy) PermitIntercept(r *http.Request) *bool {
|
|
||||||
if p != nil {
|
|
||||||
for _, rule := range p.Rules {
|
|
||||||
if rule, ok := rule.Rule.(InterceptRule); ok {
|
|
||||||
if permit := rule.PermitIntercept(r); permit != nil {
|
|
||||||
return permit
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return p.Permit
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *Policy) PermitRequest(r *http.Request) *bool {
|
|
||||||
if p != nil {
|
|
||||||
for _, rule := range p.Rules {
|
|
||||||
if rule, ok := rule.Rule.(RequestRule); ok {
|
|
||||||
if permit := rule.PermitRequest(r); permit != nil {
|
|
||||||
return permit
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return p.Permit
|
|
||||||
}
|
|
@@ -1,139 +0,0 @@
|
|||||||
package policy
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net"
|
|
||||||
"net/http"
|
|
||||||
"net/url"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"git.maze.io/maze/styx/internal/netutil"
|
|
||||||
"git.maze.io/maze/styx/proxy/match"
|
|
||||||
"github.com/miekg/dns"
|
|
||||||
)
|
|
||||||
|
|
||||||
type testInDomainList struct {
|
|
||||||
t *testing.T
|
|
||||||
list []string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (testInDomainList) Name() string { return "testInDomainList" }
|
|
||||||
func (l testInDomainList) MatchesRequest(r *http.Request) bool {
|
|
||||||
for _, domain := range l.list {
|
|
||||||
if dns.IsSubDomain(domain, netutil.Host(r.URL.Host)) {
|
|
||||||
l.t.Logf("domain %s contains %s", domain, r.URL.Host)
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
l.t.Logf("domain %s does not contain %s", domain, r.URL.Host)
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func testInDomain(t *testing.T, domains ...string) match.Matcher {
|
|
||||||
return &testInDomainList{t: t, list: domains}
|
|
||||||
}
|
|
||||||
|
|
||||||
type testInNetworkList struct {
|
|
||||||
t *testing.T
|
|
||||||
list []*net.IPNet
|
|
||||||
}
|
|
||||||
|
|
||||||
func (testInNetworkList) Name() string { return "testInNetworkList" }
|
|
||||||
func (l testInNetworkList) MatchesIP(ip net.IP) bool {
|
|
||||||
for _, ipnet := range l.list {
|
|
||||||
if ipnet.Contains(ip) {
|
|
||||||
l.t.Logf("network %s contains %s", ipnet, ip)
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
l.t.Logf("network %s does not contain %s", ipnet, ip)
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func testInNetwork(t *testing.T, cidr string) match.Matcher {
|
|
||||||
t.Helper()
|
|
||||||
_, ipnet, err := net.ParseCIDR(cidr)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
return testInNetworkList{t: t, list: []*net.IPNet{ipnet}}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPolicy(t *testing.T) {
|
|
||||||
var (
|
|
||||||
yes = true
|
|
||||||
nope = false
|
|
||||||
)
|
|
||||||
p := &Policy{
|
|
||||||
Rules: []*rawRule{
|
|
||||||
{
|
|
||||||
Rule: &requestRule{
|
|
||||||
domainOrNetworkRule: domainOrNetworkRule{
|
|
||||||
matchers: []match.Matcher{testInNetwork(t, "127.0.0.0/8")},
|
|
||||||
isSource: []bool{true},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Rule: &requestRule{
|
|
||||||
domainOrNetworkRule: domainOrNetworkRule{
|
|
||||||
matchers: []match.Matcher{testInNetwork(t, "127.0.0.0/8")},
|
|
||||||
isSource: []bool{false},
|
|
||||||
},
|
|
||||||
Permit: &yes,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Rule: &requestRule{
|
|
||||||
domainOrNetworkRule: domainOrNetworkRule{
|
|
||||||
matchers: []match.Matcher{testInDomain(t, "maze.io", "maze.engineering")},
|
|
||||||
},
|
|
||||||
Permit: &yes,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Rule: &requestRule{
|
|
||||||
domainOrNetworkRule: domainOrNetworkRule{
|
|
||||||
matchers: []match.Matcher{testInDomain(t, "google.com")},
|
|
||||||
},
|
|
||||||
Permit: &nope,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
r := &http.Request{
|
|
||||||
URL: &url.URL{Scheme: "http", Host: "golang.org:80"},
|
|
||||||
RemoteAddr: "127.0.0.1:1234",
|
|
||||||
}
|
|
||||||
if v := p.PermitRequest(r); v != nil {
|
|
||||||
t.Errorf("expected request to return no verdict, got %t", *v)
|
|
||||||
}
|
|
||||||
|
|
||||||
p.Rules[0].Rule.(*requestRule).Permit = &yes
|
|
||||||
if v := p.PermitRequest(r); v == nil || *v != yes {
|
|
||||||
t.Errorf("expected request to return %t, %v", yes, v)
|
|
||||||
}
|
|
||||||
|
|
||||||
r.RemoteAddr = "192.168.1.2:3456"
|
|
||||||
if v := p.PermitRequest(r); v != nil {
|
|
||||||
t.Errorf("expected request to return no verdict, got %t", *v)
|
|
||||||
}
|
|
||||||
if v := p.PermitIntercept(r); v != nil {
|
|
||||||
t.Errorf("expected request to return no verdict, got %t", *v)
|
|
||||||
}
|
|
||||||
|
|
||||||
r.URL.Host = "maze.io"
|
|
||||||
if v := p.PermitRequest(r); v == nil || *v != yes {
|
|
||||||
t.Errorf("expected request to return %t, %v", yes, v)
|
|
||||||
}
|
|
||||||
|
|
||||||
r.URL.Host = "google.com"
|
|
||||||
if v := p.PermitRequest(r); v == nil || *v != nope {
|
|
||||||
t.Errorf("expected request to return %t, %v", nope, v)
|
|
||||||
}
|
|
||||||
|
|
||||||
r.URL.Host = "localhost:80"
|
|
||||||
if v := p.PermitRequest(r); v == nil || *v != yes {
|
|
||||||
t.Errorf("expected request to return %t, %v", yes, v)
|
|
||||||
}
|
|
||||||
}
|
|
@@ -1,368 +0,0 @@
|
|||||||
package policy
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"net/http"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"git.maze.io/maze/styx/internal/netutil"
|
|
||||||
"git.maze.io/maze/styx/proxy/match"
|
|
||||||
"github.com/google/uuid"
|
|
||||||
"github.com/hashicorp/hcl/v2"
|
|
||||||
"github.com/hashicorp/hcl/v2/gohcl"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Rule is a policy rule.
|
|
||||||
type Rule interface {
|
|
||||||
Configure(match.Matchers) error
|
|
||||||
}
|
|
||||||
|
|
||||||
// InterceptRule can make policy rule decisions on intercept requests.
|
|
||||||
type InterceptRule interface {
|
|
||||||
PermitIntercept(r *http.Request) *bool
|
|
||||||
}
|
|
||||||
|
|
||||||
// RequestRule can make policy rule decisions on HTTP CONNECT requests.
|
|
||||||
type RequestRule interface {
|
|
||||||
PermitRequest(r *http.Request) *bool
|
|
||||||
}
|
|
||||||
|
|
||||||
type rawRule struct {
|
|
||||||
Type string `hcl:"type,label" json:"type"`
|
|
||||||
Body hcl.Body `hcl:",remain" json:"-"`
|
|
||||||
Rule `json:"rule"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *rawRule) Configure(matchers match.Matchers) (err error) {
|
|
||||||
switch r.Type {
|
|
||||||
case "intercept":
|
|
||||||
r.Rule = new(interceptRule)
|
|
||||||
case "request":
|
|
||||||
r.Rule = new(requestRule)
|
|
||||||
case "days":
|
|
||||||
r.Rule = new(daysRule)
|
|
||||||
case "time":
|
|
||||||
r.Rule = new(timeRule)
|
|
||||||
case "all":
|
|
||||||
r.Rule = new(allRule)
|
|
||||||
default:
|
|
||||||
return fmt.Errorf("policy: invalid event type %q", r.Type)
|
|
||||||
}
|
|
||||||
|
|
||||||
if diag := gohcl.DecodeBody(r.Body, nil, r.Rule); diag.HasErrors() {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return r.Rule.Configure(matchers)
|
|
||||||
}
|
|
||||||
|
|
||||||
type allRule struct {
|
|
||||||
Rules []*rawRule `hcl:"on,block"`
|
|
||||||
Permit *bool `hcl:"permit"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *allRule) Configure(matchers match.Matchers) (err error) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
type domainOrNetworkRule struct {
|
|
||||||
matchers []match.Matcher
|
|
||||||
isSource []bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *domainOrNetworkRule) configure(kind string, matchers match.Matchers, domains, sources, targets []string, v any, id *string) (err error) {
|
|
||||||
var m match.Matcher
|
|
||||||
for _, domain := range domains {
|
|
||||||
if m, err = matchers.Get("domain", domain); err != nil {
|
|
||||||
return fmt.Errorf("%s: unknown domain %q", kind, domain)
|
|
||||||
}
|
|
||||||
r.matchers = append(r.matchers, m)
|
|
||||||
r.isSource = append(r.isSource, false)
|
|
||||||
}
|
|
||||||
for _, network := range sources {
|
|
||||||
if m, err = matchers.Get("network", network); err != nil {
|
|
||||||
return fmt.Errorf("%s: unknown source network %q", kind, network)
|
|
||||||
}
|
|
||||||
r.matchers = append(r.matchers, m)
|
|
||||||
r.isSource = append(r.isSource, true)
|
|
||||||
}
|
|
||||||
for _, network := range targets {
|
|
||||||
if m, err = matchers.Get("network", network); err != nil {
|
|
||||||
return fmt.Errorf("%s: unknown target network %q", kind, network)
|
|
||||||
}
|
|
||||||
r.matchers = append(r.matchers, m)
|
|
||||||
r.isSource = append(r.isSource, false)
|
|
||||||
}
|
|
||||||
if len(r.matchers) == 0 {
|
|
||||||
return fmt.Errorf("%s: missing any of domain, source, target", kind)
|
|
||||||
}
|
|
||||||
if id != nil {
|
|
||||||
*id = uuid.NewString()
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *domainOrNetworkRule) matchesRequest(q *http.Request) bool {
|
|
||||||
for i, m := range r.matchers {
|
|
||||||
if m, ok := m.(match.Request); ok {
|
|
||||||
if m.MatchesRequest(q) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if m, ok := m.(match.IP); ok {
|
|
||||||
if r.isSource[i] {
|
|
||||||
if m.MatchesIP(net.ParseIP(netutil.Host(q.RemoteAddr))) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
var (
|
|
||||||
host = netutil.Host(q.URL.Host)
|
|
||||||
ips []net.IP
|
|
||||||
)
|
|
||||||
if ip := net.ParseIP(host); ip != nil {
|
|
||||||
ips = append(ips, ip)
|
|
||||||
} else {
|
|
||||||
ips, _ = net.LookupIP(host)
|
|
||||||
}
|
|
||||||
for _, ip := range ips {
|
|
||||||
if m.MatchesIP(ip) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
type interceptRule struct {
|
|
||||||
ID string `json:"id,omitempty"`
|
|
||||||
Domain []string `hcl:"domain,optional" json:"domain,omitempty"`
|
|
||||||
Source []string `hcl:"source,optional" json:"source,omitempty"`
|
|
||||||
Target []string `hcl:"target,optional" json:"target,omitempty"`
|
|
||||||
Permit *bool `hcl:"permit" json:"permit"`
|
|
||||||
domainOrNetworkRule `json:"-"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *interceptRule) Configure(matchers match.Matchers) (err error) {
|
|
||||||
return r.configure("intercept", matchers, r.Domain, r.Source, r.Target, r, &r.ID)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *interceptRule) PermitIntercept(q *http.Request) *bool {
|
|
||||||
if r.matchesRequest(q) {
|
|
||||||
return r.Permit
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type requestRule struct {
|
|
||||||
ID string `json:"id,omitempty"`
|
|
||||||
Domain []string `hcl:"domain,optional" json:"domain,omitempty"`
|
|
||||||
Source []string `hcl:"source,optional" json:"source,omitempty"`
|
|
||||||
Target []string `hcl:"target,optional" json:"target,omitempty"`
|
|
||||||
Permit *bool `hcl:"permit" json:"permit"`
|
|
||||||
domainOrNetworkRule `json:"-"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *requestRule) Configure(matchers match.Matchers) (err error) {
|
|
||||||
return r.configure("request", matchers, r.Domain, r.Source, r.Target, r, &r.ID)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *requestRule) PermitRequest(q *http.Request) *bool {
|
|
||||||
if r.matchesRequest(q) {
|
|
||||||
return r.Permit
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type timeRule struct {
|
|
||||||
ID string `json:"id,omitempty"`
|
|
||||||
Time []string `hcl:"time" json:"time"`
|
|
||||||
Permit *bool `hcl:"permit" json:"permit"`
|
|
||||||
Body hcl.Body `hcl:",remain" json:"-"`
|
|
||||||
Rules *Policy `json:"rules"`
|
|
||||||
Start Time `json:"start"`
|
|
||||||
End Time `json:"end"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *timeRule) isActive() bool {
|
|
||||||
if r == nil {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
now := Now()
|
|
||||||
if r.Start.After(r.End) { // ie: 18:00-06:00
|
|
||||||
return now.After(r.Start) || now.Before(r.End)
|
|
||||||
}
|
|
||||||
return now.After(r.Start) && now.Before(r.End)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *timeRule) Configure(matchers match.Matchers) (err error) {
|
|
||||||
if len(r.Time) != 2 {
|
|
||||||
return fmt.Errorf("invalid time %s, need [start, stop]", r.Time)
|
|
||||||
}
|
|
||||||
if r.Start, err = ParseTime(r.Time[0]); err != nil {
|
|
||||||
return fmt.Errorf("invalid start %q: %w", r.Time[0], err)
|
|
||||||
}
|
|
||||||
if r.End, err = ParseTime(r.Time[1]); err != nil {
|
|
||||||
return fmt.Errorf("invalid end %q: %w", r.Time[1], err)
|
|
||||||
}
|
|
||||||
|
|
||||||
r.Rules = new(Policy)
|
|
||||||
if diag := gohcl.DecodeBody(r.Body, nil, r.Rules); diag.HasErrors() {
|
|
||||||
return diag
|
|
||||||
}
|
|
||||||
|
|
||||||
if err = r.Rules.Configure(matchers); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
r.Rules.Matchers = nil
|
|
||||||
|
|
||||||
if r.ID == "" {
|
|
||||||
r.ID = uuid.NewString()
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *timeRule) PermitIntercept(q *http.Request) *bool {
|
|
||||||
if !r.isActive() {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return r.Rules.PermitIntercept(q)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *timeRule) PermitRequest(q *http.Request) *bool {
|
|
||||||
if !r.isActive() {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return r.Rules.PermitRequest(q)
|
|
||||||
}
|
|
||||||
|
|
||||||
type daysRule struct {
|
|
||||||
ID string `json:"id,omitempty"`
|
|
||||||
Days string `hcl:"days" json:"days"`
|
|
||||||
Permit *bool `hcl:"permit" json:"permit"`
|
|
||||||
Body hcl.Body `hcl:",remain" json:"-"`
|
|
||||||
Rules *Policy `json:"rules"`
|
|
||||||
cond []onCond
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *daysRule) isActive() bool {
|
|
||||||
if r == nil || len(r.cond) == 0 {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
now := time.Now()
|
|
||||||
for _, cond := range r.cond {
|
|
||||||
if cond(now) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *daysRule) Configure(matchers match.Matchers) (err error) {
|
|
||||||
if r.cond, err = parseOnCond(r.Days); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
r.Rules = new(Policy)
|
|
||||||
if diag := gohcl.DecodeBody(r.Body, nil, r.Rules); diag.HasErrors() {
|
|
||||||
return diag
|
|
||||||
}
|
|
||||||
if err = r.Rules.Configure(matchers); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
r.Rules.Matchers = nil
|
|
||||||
|
|
||||||
if r.ID == "" {
|
|
||||||
r.ID = uuid.NewString()
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *daysRule) PermitIntercept(q *http.Request) *bool {
|
|
||||||
if !r.isActive() {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return r.Rules.PermitIntercept(q)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *daysRule) PermitRequest(q *http.Request) *bool {
|
|
||||||
if !r.isActive() {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return r.Rules.PermitRequest(q)
|
|
||||||
}
|
|
||||||
|
|
||||||
type onCond func(time.Time) bool
|
|
||||||
|
|
||||||
var weekdays = map[string]time.Weekday{
|
|
||||||
"sun": time.Sunday,
|
|
||||||
"mon": time.Monday,
|
|
||||||
"tue": time.Tuesday,
|
|
||||||
"wed": time.Wednesday,
|
|
||||||
"thu": time.Thursday,
|
|
||||||
"fri": time.Friday,
|
|
||||||
"sat": time.Saturday,
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseOnCond(when string) (conds []onCond, err error) {
|
|
||||||
for _, spec := range strings.Split(when, ",") {
|
|
||||||
spec = strings.ToLower(strings.TrimSpace(spec))
|
|
||||||
if d, ok := weekdays[spec]; ok {
|
|
||||||
conds = append(conds, onWeekday(d))
|
|
||||||
} else if spec == "weekend" || spec == "weekends" {
|
|
||||||
conds = append(conds, onWeekend)
|
|
||||||
} else if spec == "workday" || spec == "workdays" {
|
|
||||||
conds = append(conds, onWorkday)
|
|
||||||
} else if strings.ContainsRune(spec, '-') {
|
|
||||||
var (
|
|
||||||
part = strings.SplitN(spec, "-", 2)
|
|
||||||
from, upto time.Weekday
|
|
||||||
ok bool
|
|
||||||
)
|
|
||||||
if from, ok = weekdays[part[0]]; !ok {
|
|
||||||
return nil, fmt.Errorf("on %q: invalid weekday %q", spec, part[0])
|
|
||||||
}
|
|
||||||
if upto, ok = weekdays[part[1]]; !ok {
|
|
||||||
return nil, fmt.Errorf("on %q: invalid weekday %q", spec, part[1])
|
|
||||||
}
|
|
||||||
if from < upto {
|
|
||||||
for d := from; d < upto; d++ {
|
|
||||||
conds = append(conds, onWeekday(d))
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for d := time.Sunday; d < from; d++ {
|
|
||||||
conds = append(conds, onWeekday(d))
|
|
||||||
}
|
|
||||||
for d := upto; d <= time.Saturday; d++ {
|
|
||||||
conds = append(conds, onWeekday(d))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
return nil, fmt.Errorf("on %q: invalid condition", spec)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func onWeekday(weekday time.Weekday) onCond {
|
|
||||||
return func(t time.Time) bool {
|
|
||||||
return t.Weekday() == weekday
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func onWeekend(t time.Time) bool {
|
|
||||||
d := t.Weekday()
|
|
||||||
return d == time.Saturday || d == time.Sunday
|
|
||||||
}
|
|
||||||
|
|
||||||
func onWorkday(t time.Time) bool {
|
|
||||||
d := t.Weekday()
|
|
||||||
return !(d == time.Saturday || d == time.Sunday)
|
|
||||||
}
|
|
@@ -1,53 +0,0 @@
|
|||||||
package policy
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Time struct {
|
|
||||||
Hour int
|
|
||||||
Minute int
|
|
||||||
Second int
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t Time) Eq(other Time) bool {
|
|
||||||
return t.Hour == other.Hour && t.Minute == other.Minute && t.Second == other.Second
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t Time) After(other Time) bool {
|
|
||||||
return t.Seconds() > other.Seconds()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t Time) Before(other Time) bool {
|
|
||||||
return t.Seconds() < other.Seconds()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t Time) Seconds() int {
|
|
||||||
return t.Hour*3600 + t.Minute*60 + t.Second
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t Time) MarshalJSON() ([]byte, error) {
|
|
||||||
return []byte(fmt.Sprintf(`"%02d:%02d:%02d"`, t.Hour, t.Minute, t.Second)), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var timeFormats = []string{
|
|
||||||
time.TimeOnly,
|
|
||||||
"15:04",
|
|
||||||
time.Kitchen,
|
|
||||||
}
|
|
||||||
|
|
||||||
func Now() Time {
|
|
||||||
now := time.Now()
|
|
||||||
return Time{now.Hour(), now.Minute(), now.Second()}
|
|
||||||
}
|
|
||||||
|
|
||||||
func ParseTime(s string) (t Time, err error) {
|
|
||||||
var tt time.Time
|
|
||||||
for _, layout := range timeFormats {
|
|
||||||
if tt, err = time.Parse(layout, s); err == nil {
|
|
||||||
return Time{tt.Hour(), tt.Minute(), tt.Second()}, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return Time{}, fmt.Errorf("time: invalid time %q", s)
|
|
||||||
}
|
|
999
proxy/proxy.go
999
proxy/proxy.go
File diff suppressed because it is too large
Load Diff
@@ -1,148 +0,0 @@
|
|||||||
// Package resolver implements a caching DNS resolver
|
|
||||||
package resolver
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"math/rand/v2"
|
|
||||||
"net"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"git.maze.io/maze/styx/internal/netutil"
|
|
||||||
"github.com/hashicorp/golang-lru/v2/expirable"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
DefaultSize = 1024
|
|
||||||
DefaultTTL = 5 * time.Minute
|
|
||||||
DefaultTimeout = 10 * time.Second
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
// DefaultConfig are the defaults for the Default resolver.
|
|
||||||
DefaultConfig = Config{
|
|
||||||
Size: DefaultSize,
|
|
||||||
TTL: DefaultTTL.Seconds(),
|
|
||||||
Timeout: DefaultTimeout.Seconds(),
|
|
||||||
}
|
|
||||||
|
|
||||||
// Default resolver.
|
|
||||||
Default = New(DefaultConfig)
|
|
||||||
)
|
|
||||||
|
|
||||||
type Resolver interface {
|
|
||||||
// Lookup returns resolved IPs for given hostname/ips.
|
|
||||||
Lookup(context.Context, string) ([]string, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
type netResolver struct {
|
|
||||||
resolver *net.Resolver
|
|
||||||
timeout time.Duration
|
|
||||||
noIPv6 bool
|
|
||||||
cache *expirable.LRU[string, []string]
|
|
||||||
}
|
|
||||||
|
|
||||||
type Config struct {
|
|
||||||
// Size is our cache size in number of entries.
|
|
||||||
Size int `hcl:"size,optional"`
|
|
||||||
|
|
||||||
// TTL is the cache time to live in seconds.
|
|
||||||
TTL float64 `hcl:"ttl,optional"`
|
|
||||||
|
|
||||||
// Timeout is the cache timeout in seconds.
|
|
||||||
Timeout float64 `hcl:"timeout,optional"`
|
|
||||||
|
|
||||||
// Server are alternative DNS servers.
|
|
||||||
Server []string `hcl:"server,optional"`
|
|
||||||
|
|
||||||
// NoIPv6 disables IPv6 DNS resolution.
|
|
||||||
NoIPv6 bool `hcl:"noipv6,optional"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func New(config Config) Resolver {
|
|
||||||
var (
|
|
||||||
size = config.Size
|
|
||||||
ttl = time.Duration(float64(time.Second) * config.TTL)
|
|
||||||
timeout = time.Duration(float64(time.Second) * config.Timeout)
|
|
||||||
)
|
|
||||||
if size <= 0 {
|
|
||||||
size = DefaultSize
|
|
||||||
}
|
|
||||||
if ttl <= 0 {
|
|
||||||
ttl = DefaultTTL
|
|
||||||
}
|
|
||||||
if timeout <= 0 {
|
|
||||||
timeout = 0
|
|
||||||
}
|
|
||||||
|
|
||||||
var resolver = new(net.Resolver)
|
|
||||||
if len(config.Server) > 0 {
|
|
||||||
var dialer net.Dialer
|
|
||||||
resolver.Dial = func(ctx context.Context, network, address string) (net.Conn, error) {
|
|
||||||
server := netutil.EnsurePort(config.Server[rand.IntN(len(config.Server))], "53")
|
|
||||||
return dialer.DialContext(ctx, network, server)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return &netResolver{
|
|
||||||
resolver: resolver,
|
|
||||||
timeout: timeout,
|
|
||||||
noIPv6: config.NoIPv6,
|
|
||||||
cache: expirable.NewLRU[string, []string](size, nil, ttl),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *netResolver) Lookup(ctx context.Context, host string) ([]string, error) {
|
|
||||||
host = strings.ToLower(strings.TrimSpace(host))
|
|
||||||
if hosts, ok := r.cache.Get(host); ok {
|
|
||||||
rand.Shuffle(len(hosts), func(i, j int) {
|
|
||||||
hosts[i], hosts[j] = hosts[j], hosts[i]
|
|
||||||
})
|
|
||||||
return hosts, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
hosts, err := r.lookup(ctx, host)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
r.cache.Add(host, hosts)
|
|
||||||
return hosts, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *netResolver) lookup(ctx context.Context, host string) ([]string, error) {
|
|
||||||
if r.timeout > 0 {
|
|
||||||
var cancel func()
|
|
||||||
ctx, cancel = context.WithTimeout(ctx, r.timeout)
|
|
||||||
defer cancel()
|
|
||||||
}
|
|
||||||
|
|
||||||
if net.ParseIP(host) == nil {
|
|
||||||
addrs, err := r.resolver.LookupHost(ctx, host)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if r.noIPv6 {
|
|
||||||
var addrs4 []string
|
|
||||||
for _, addr := range addrs {
|
|
||||||
if net.ParseIP(addr).To4() != nil {
|
|
||||||
addrs4 = append(addrs4, addr)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return addrs4, nil
|
|
||||||
}
|
|
||||||
return addrs, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
addrs, err := r.resolver.LookupIPAddr(ctx, host)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
hosts := make([]string, len(addrs))
|
|
||||||
for i, addr := range addrs {
|
|
||||||
if !r.noIPv6 || addr.IP.To4() != nil {
|
|
||||||
hosts[i] = addr.IP.String()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return hosts, nil
|
|
||||||
}
|
|
@@ -2,77 +2,58 @@ package proxy
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"fmt"
|
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
"git.maze.io/maze/styx/internal/log"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func NewResponse(code int, body io.Reader, request *http.Request) *http.Response {
|
type sizer interface {
|
||||||
if body == nil {
|
|
||||||
body = new(bytes.Buffer)
|
|
||||||
}
|
|
||||||
|
|
||||||
rc, ok := body.(io.ReadCloser)
|
|
||||||
if !ok {
|
|
||||||
rc = io.NopCloser(body)
|
|
||||||
}
|
|
||||||
|
|
||||||
response := &http.Response{
|
|
||||||
Status: strconv.Itoa(code) + " " + http.StatusText(code),
|
|
||||||
StatusCode: code,
|
|
||||||
Proto: "HTTP/1.1",
|
|
||||||
ProtoMajor: 1,
|
|
||||||
ProtoMinor: 1,
|
|
||||||
Header: make(http.Header),
|
|
||||||
Body: rc,
|
|
||||||
Request: request,
|
|
||||||
}
|
|
||||||
|
|
||||||
if request != nil {
|
|
||||||
response.Close = request.Close
|
|
||||||
response.Proto = request.Proto
|
|
||||||
response.ProtoMajor = request.ProtoMajor
|
|
||||||
response.ProtoMinor = request.ProtoMinor
|
|
||||||
}
|
|
||||||
|
|
||||||
return response
|
|
||||||
}
|
|
||||||
|
|
||||||
type withLen interface {
|
|
||||||
Len() int
|
|
||||||
}
|
|
||||||
|
|
||||||
type withSize interface {
|
|
||||||
Size() int64
|
Size() int64
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewJSONResponse(code int, body io.Reader, request *http.Request) *http.Response {
|
// NewResponse prepares a net [http.Response], based on the status code, optional body and
|
||||||
response := NewResponse(code, body, request)
|
// optional [http.Request].
|
||||||
response.Header.Set(HeaderContentType, "application/json")
|
func NewResponse(code int, body io.ReadCloser, req *http.Request) *http.Response {
|
||||||
if s, ok := body.(withLen); ok {
|
res := &http.Response{
|
||||||
response.Header.Set(HeaderContentLength, strconv.Itoa(s.Len()))
|
StatusCode: code,
|
||||||
} else if s, ok := body.(withSize); ok {
|
Header: make(http.Header),
|
||||||
response.Header.Set(HeaderContentLength, strconv.FormatInt(s.Size(), 10))
|
Proto: "HTTP/1.1",
|
||||||
} else {
|
ProtoMajor: 1,
|
||||||
log.Trace().Str("type", fmt.Sprintf("%T", body)).Msg("can't detemine body size")
|
ProtoMinor: 1,
|
||||||
}
|
|
||||||
response.Close = true
|
|
||||||
return response
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func ErrorResponse(request *http.Request, err error) *http.Response {
|
if text := http.StatusText(code); text != "" {
|
||||||
response := NewResponse(http.StatusBadGateway, nil, request)
|
res.Status = strconv.Itoa(code) + " " + text
|
||||||
|
} else {
|
||||||
|
res.Status = strconv.Itoa(code)
|
||||||
|
}
|
||||||
|
|
||||||
|
if body == nil && code >= 400 {
|
||||||
|
body = io.NopCloser(bytes.NewBufferString(http.StatusText(code)))
|
||||||
|
}
|
||||||
|
|
||||||
|
res.Body = body
|
||||||
|
|
||||||
|
if s, ok := body.(sizer); ok {
|
||||||
|
res.ContentLength = s.Size()
|
||||||
|
}
|
||||||
|
|
||||||
|
if req != nil {
|
||||||
|
res.Close = req.Close
|
||||||
|
res.Proto = req.Proto
|
||||||
|
res.ProtoMajor = req.ProtoMajor
|
||||||
|
res.ProtoMinor = req.ProtoMinor
|
||||||
|
}
|
||||||
|
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewErrorResponse(err error, req *http.Request) *http.Response {
|
||||||
switch {
|
switch {
|
||||||
case os.IsNotExist(err):
|
case os.IsTimeout(err):
|
||||||
response.StatusCode = http.StatusNotFound
|
return NewResponse(http.StatusGatewayTimeout, nil, req)
|
||||||
case os.IsPermission(err):
|
default:
|
||||||
response.StatusCode = http.StatusForbidden
|
return NewResponse(http.StatusBadGateway, nil, req)
|
||||||
}
|
}
|
||||||
response.Status = http.StatusText(response.StatusCode)
|
|
||||||
response.Close = true
|
|
||||||
return response
|
|
||||||
}
|
}
|
||||||
|
151
proxy/session.go
151
proxy/session.go
@@ -1,151 +0,0 @@
|
|||||||
package proxy
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bufio"
|
|
||||||
"crypto/tls"
|
|
||||||
"encoding/binary"
|
|
||||||
"encoding/hex"
|
|
||||||
"math/rand"
|
|
||||||
"net"
|
|
||||||
"net/http"
|
|
||||||
"sync/atomic"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"git.maze.io/maze/styx/internal/log"
|
|
||||||
)
|
|
||||||
|
|
||||||
var seed = rand.NewSource(time.Now().UnixNano())
|
|
||||||
|
|
||||||
type Context struct {
|
|
||||||
id int64
|
|
||||||
conn *wrappedConn
|
|
||||||
rw *bufio.ReadWriter
|
|
||||||
parent *Session
|
|
||||||
data map[string]any
|
|
||||||
}
|
|
||||||
|
|
||||||
func newContext(conn net.Conn, rw *bufio.ReadWriter, parent *Session) *Context {
|
|
||||||
if wrapped, ok := conn.(*wrappedConn); ok {
|
|
||||||
conn = wrapped.Conn
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx := &Context{
|
|
||||||
id: seed.Int63(),
|
|
||||||
conn: &wrappedConn{Conn: conn},
|
|
||||||
rw: rw,
|
|
||||||
parent: parent,
|
|
||||||
data: make(map[string]any),
|
|
||||||
}
|
|
||||||
|
|
||||||
return ctx
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ctx *Context) log() log.Logger {
|
|
||||||
return log.Console.With().
|
|
||||||
Str("context", ctx.ID()).
|
|
||||||
Str("addr", ctx.RemoteAddr().String()).
|
|
||||||
Logger()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ctx *Context) ID() string {
|
|
||||||
var b [8]byte
|
|
||||||
binary.BigEndian.PutUint64(b[:], uint64(ctx.id))
|
|
||||||
if ctx.parent != nil {
|
|
||||||
return ctx.parent.ID() + "-" + hex.EncodeToString(b[:])
|
|
||||||
}
|
|
||||||
return hex.EncodeToString(b[:])
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ctx *Context) IsTLS() bool {
|
|
||||||
_, ok := ctx.conn.Conn.(*tls.Conn)
|
|
||||||
return ok && ctx.parent != nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ctx *Context) RemoteAddr() net.Addr {
|
|
||||||
if ctx.parent != nil {
|
|
||||||
return ctx.parent.ctx.RemoteAddr()
|
|
||||||
}
|
|
||||||
return ctx.conn.RemoteAddr()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ctx *Context) SetDeadline(t time.Time) error {
|
|
||||||
if ctx.parent != nil {
|
|
||||||
return ctx.parent.ctx.SetDeadline(t)
|
|
||||||
}
|
|
||||||
return ctx.conn.SetDeadline(t)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ctx *Context) Set(key string, value any) {
|
|
||||||
ctx.data[key] = value
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ctx *Context) Get(key string) (value any, ok bool) {
|
|
||||||
value, ok = ctx.data[key]
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ctx *Context) Flush() error {
|
|
||||||
return ctx.rw.Flush()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ctx *Context) Write(p []byte) (n int, err error) {
|
|
||||||
if n, err = ctx.rw.Write(p); n > 0 {
|
|
||||||
atomic.AddInt64(&ctx.conn.bytes, int64(n))
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
type Session struct {
|
|
||||||
id int64
|
|
||||||
ctx *Context
|
|
||||||
request *http.Request
|
|
||||||
response *http.Response
|
|
||||||
data map[string]any
|
|
||||||
}
|
|
||||||
|
|
||||||
func newSession(ctx *Context, request *http.Request) *Session {
|
|
||||||
return &Session{
|
|
||||||
id: seed.Int63(),
|
|
||||||
ctx: ctx,
|
|
||||||
request: request,
|
|
||||||
data: make(map[string]any),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ses *Session) log() log.Logger {
|
|
||||||
return log.Console.With().
|
|
||||||
Str("context", ses.ctx.ID()).
|
|
||||||
Str("session", ses.ID()).
|
|
||||||
Str("addr", ses.ctx.RemoteAddr().String()).
|
|
||||||
Logger()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ses *Session) ID() string {
|
|
||||||
var b [8]byte
|
|
||||||
binary.BigEndian.PutUint64(b[:], uint64(ses.id))
|
|
||||||
return hex.EncodeToString(b[:])
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ses *Session) Context() *Context {
|
|
||||||
return ses.ctx
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ses *Session) Request() *http.Request {
|
|
||||||
return ses.request
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ses *Session) Response() *http.Response {
|
|
||||||
return ses.response
|
|
||||||
}
|
|
||||||
|
|
||||||
type wrappedConn struct {
|
|
||||||
net.Conn
|
|
||||||
bytes int64
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *wrappedConn) Write(p []byte) (n int, err error) {
|
|
||||||
if n, err = c.Conn.Write(p); n > 0 {
|
|
||||||
atomic.AddInt64(&c.bytes, int64(n))
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
19
proxy/stats.go
Normal file
19
proxy/stats.go
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"expvar"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"git.maze.io/maze/styx/db/stats"
|
||||||
|
)
|
||||||
|
|
||||||
|
func countStatus(code int) {
|
||||||
|
k := "http:status:" + strconv.Itoa(code)
|
||||||
|
v := expvar.Get(k)
|
||||||
|
if v == nil {
|
||||||
|
//v = stats.NewCounter("120s1s", "15m10s", "1h1m", "4w1d", "1y4w")
|
||||||
|
v = stats.NewCounter(k, stats.Minutely, stats.Hourly, stats.Daily, stats.Yearly)
|
||||||
|
expvar.Publish(k, v)
|
||||||
|
}
|
||||||
|
v.(stats.Metric).Add(1)
|
||||||
|
}
|
@@ -1,225 +0,0 @@
|
|||||||
package stats
|
|
||||||
|
|
||||||
import (
|
|
||||||
"database/sql"
|
|
||||||
"database/sql/driver"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"net/http"
|
|
||||||
"os"
|
|
||||||
"os/user"
|
|
||||||
"path/filepath"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"git.maze.io/maze/styx/internal/log"
|
|
||||||
_ "github.com/mattn/go-sqlite3"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Stats struct {
|
|
||||||
db *sql.DB
|
|
||||||
}
|
|
||||||
|
|
||||||
func New() (*Stats, error) {
|
|
||||||
u, err := user.Current()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
path := filepath.Join(u.HomeDir, ".styx", "stats.db")
|
|
||||||
if err = os.MkdirAll(filepath.Dir(path), 0o750); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
db, err := sql.Open("sqlite3", path+"?_journal_mode=WAL")
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, table := range []string{
|
|
||||||
createLog,
|
|
||||||
createDomainStat,
|
|
||||||
createStatusStat,
|
|
||||||
} {
|
|
||||||
if _, err = db.Exec(table); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return &Stats{db: db}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Stats) AddLog(entry *Log) error {
|
|
||||||
var (
|
|
||||||
request []byte
|
|
||||||
response []byte
|
|
||||||
err error
|
|
||||||
)
|
|
||||||
if request, err = json.Marshal(entry.Request); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if response, err = json.Marshal(entry.Response); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
tx, err := s.db.Begin()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
stmt, err := tx.Prepare("insert into styx_log(client_ip, request, response) values(?, ?, ?)")
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer stmt.Close()
|
|
||||||
if _, err = stmt.Exec(entry.ClientIP, request, response); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return tx.Commit()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Stats) QueryLog(offset, limit int) ([]*Log, error) {
|
|
||||||
if limit == 0 {
|
|
||||||
limit = 50
|
|
||||||
}
|
|
||||||
|
|
||||||
rows, err := s.db.Query("select dt, client_ip, request, response from styx_log limit ?, ?", offset, limit)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
var logs []*Log
|
|
||||||
for rows.Next() {
|
|
||||||
var entry = new(Log)
|
|
||||||
if err = rows.Scan(&entry.Time, &entry.ClientIP, &entry.Request, &entry.Response); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
logs = append(logs, entry)
|
|
||||||
}
|
|
||||||
|
|
||||||
return logs, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type Status struct {
|
|
||||||
Code int `json:"code"`
|
|
||||||
Count int `json:"count"`
|
|
||||||
}
|
|
||||||
|
|
||||||
var timeZero time.Time
|
|
||||||
|
|
||||||
func (s *Stats) QueryStatus(since time.Time) ([]*Status, error) {
|
|
||||||
if since.Equal(timeZero) {
|
|
||||||
since = time.Now().Add(-24 * time.Hour)
|
|
||||||
}
|
|
||||||
|
|
||||||
rows, err := s.db.Query("select response->'status', count(*) from styx_log where dt >= ? group by response->'status' order by response->'status'", since)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var stats []*Status
|
|
||||||
for rows.Next() {
|
|
||||||
var entry = new(Status)
|
|
||||||
if err = rows.Scan(&entry.Code, &entry.Count); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
stats = append(stats, entry)
|
|
||||||
}
|
|
||||||
return stats, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
const createLog = `CREATE TABLE IF NOT EXISTS styx_log (
|
|
||||||
id INT PRIMARY KEY,
|
|
||||||
dt DATETIME DEFAULT CURRENT_TIMESTAMP,
|
|
||||||
client_ip TEXT NOT NULL,
|
|
||||||
request JSONB NOT NULL,
|
|
||||||
response JSONB NOT NULL
|
|
||||||
);`
|
|
||||||
|
|
||||||
type Log struct {
|
|
||||||
Time time.Time `json:"time"`
|
|
||||||
ClientIP string `json:"client_ip"`
|
|
||||||
Request *Request `json:"request"`
|
|
||||||
Response *Response `json:"response"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type Request struct {
|
|
||||||
URL string `json:"url"`
|
|
||||||
Host string `json:"host"`
|
|
||||||
Method string `json:"method"`
|
|
||||||
Proto string `json:"proto"`
|
|
||||||
Header http.Header `json:"header"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *Request) Scan(value any) error {
|
|
||||||
switch v := value.(type) {
|
|
||||||
case string:
|
|
||||||
return json.Unmarshal([]byte(v), r)
|
|
||||||
case []byte:
|
|
||||||
return json.Unmarshal(v, r)
|
|
||||||
default:
|
|
||||||
log.Error().Str("type", fmt.Sprintf("%T", value)).Msg("scan request unknown type")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *Request) Value() (driver.Value, error) {
|
|
||||||
b, err := json.Marshal(r)
|
|
||||||
return string(b), err
|
|
||||||
}
|
|
||||||
|
|
||||||
func FromRequest(r *http.Request) *Request {
|
|
||||||
return &Request{
|
|
||||||
URL: r.URL.String(),
|
|
||||||
Host: r.Host,
|
|
||||||
Method: r.Method,
|
|
||||||
Proto: r.Proto,
|
|
||||||
Header: r.Header,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type Response struct {
|
|
||||||
Status int `json:"status"`
|
|
||||||
Size int64 `json:"size"`
|
|
||||||
Header http.Header `json:"header"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *Response) Scan(value any) error {
|
|
||||||
switch v := value.(type) {
|
|
||||||
case string:
|
|
||||||
return json.Unmarshal([]byte(v), r)
|
|
||||||
case []byte:
|
|
||||||
return json.Unmarshal(v, r)
|
|
||||||
default:
|
|
||||||
log.Error().Str("type", fmt.Sprintf("%T", value)).Msg("scan response unknown type")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *Response) Value() (driver.Value, error) {
|
|
||||||
b, err := json.Marshal(r)
|
|
||||||
return string(b), err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *Response) SetSize(size int64) *Response {
|
|
||||||
r.Size = size
|
|
||||||
return r
|
|
||||||
}
|
|
||||||
|
|
||||||
func FromResponse(r *http.Response) *Response {
|
|
||||||
return &Response{
|
|
||||||
Status: r.StatusCode,
|
|
||||||
Header: r.Header,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const createStatusStat = `CREATE TABLE IF NOT EXISTS styx_stat_status (
|
|
||||||
id INT PRIMARY KEY,
|
|
||||||
dt DATETIME DEFAULT CURRENT_TIMESTAMP,
|
|
||||||
status INT NOT NULL
|
|
||||||
);`
|
|
||||||
|
|
||||||
const createDomainStat = `CREATE TABLE IF NOT EXISTS styx_stat_domain (
|
|
||||||
id INT PRIMARY KEY,
|
|
||||||
dt DATETIME DEFAULT CURRENT_TIMESTAMP,
|
|
||||||
domain TEXT NOT NULL
|
|
||||||
);`
|
|
@@ -1,16 +0,0 @@
|
|||||||
package proxy
|
|
||||||
|
|
||||||
import (
|
|
||||||
"io"
|
|
||||||
"net"
|
|
||||||
)
|
|
||||||
|
|
||||||
// connReader is a net.Conn with a separate reader.
|
|
||||||
type connReader struct {
|
|
||||||
net.Conn
|
|
||||||
io.Reader
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c connReader) Read(p []byte) (int, error) {
|
|
||||||
return c.Reader.Read(p)
|
|
||||||
}
|
|
Reference in New Issue
Block a user