Initial import
This commit is contained in:
145
proxy/admin.go
Normal file
145
proxy/admin.go
Normal file
@@ -0,0 +1,145 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"net/http"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.maze.io/maze/styx/internal/log"
|
||||
)
|
||||
|
||||
type Admin struct {
|
||||
*Proxy
|
||||
}
|
||||
|
||||
func NewAdmin(proxy *Proxy) *Admin {
|
||||
a := &Admin{
|
||||
Proxy: proxy,
|
||||
}
|
||||
return a
|
||||
}
|
||||
|
||||
func (a *Admin) handleRequest(ses *Session) error {
|
||||
var (
|
||||
logger = ses.log()
|
||||
err error
|
||||
)
|
||||
switch ses.request.URL.Path {
|
||||
case "/ca.crt":
|
||||
err = a.handleCACert(ses)
|
||||
case "/api/v1/policy":
|
||||
err = a.apiPolicy(ses)
|
||||
case "/api/v1/policy/matcher":
|
||||
err = a.apiPolicyMatcher(ses)
|
||||
case "/api/v1/stats/log":
|
||||
err = a.apiStatsLog(ses)
|
||||
case "/api/v1/stats/status":
|
||||
err = a.apiStatsStatus(ses)
|
||||
default:
|
||||
if strings.HasPrefix(ses.request.URL.Path, "/api") {
|
||||
err = errors.New("invalid endpoint")
|
||||
} else {
|
||||
err = os.ErrNotExist
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
logger.Warn().Err(err).Msg("admin error")
|
||||
ses.response = ErrorResponse(ses.request, err)
|
||||
defer log.OnCloseError(logger.Debug(), ses.response.Body)
|
||||
ses.response.Close = true
|
||||
return a.writeResponse(ses)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (a *Admin) handleCACert(ses *Session) error {
|
||||
b := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: a.authority.Certificate().Raw,
|
||||
})
|
||||
|
||||
ses.response = NewResponse(http.StatusOK, bytes.NewReader(b), ses.request)
|
||||
defer log.OnCloseError(log.Debug(), ses.response.Body)
|
||||
|
||||
ses.response.Close = true
|
||||
ses.response.Header.Set("Content-Type", "application/x-x509-ca-cert")
|
||||
ses.response.ContentLength = int64(len(b))
|
||||
return a.writeResponse(ses)
|
||||
}
|
||||
|
||||
func (a *Admin) apiPolicy(ses *Session) error {
|
||||
var (
|
||||
b = new(bytes.Buffer)
|
||||
e = json.NewEncoder(b)
|
||||
)
|
||||
e.SetIndent("", " ")
|
||||
if err := e.Encode(a.config.Policy); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ses.response = NewJSONResponse(http.StatusOK, b, ses.request)
|
||||
defer log.OnCloseError(log.Debug(), ses.response.Body)
|
||||
ses.response.Close = true
|
||||
return a.writeResponse(ses)
|
||||
}
|
||||
|
||||
func (a *Admin) apiPolicyMatcher(ses *Session) error {
|
||||
var (
|
||||
b = new(bytes.Buffer)
|
||||
e = json.NewEncoder(b)
|
||||
)
|
||||
e.SetIndent("", " ")
|
||||
if err := e.Encode(a.config.Policy.Matchers); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ses.response = NewJSONResponse(http.StatusOK, b, ses.request)
|
||||
defer log.OnCloseError(log.Debug(), ses.response.Body)
|
||||
ses.response.Close = true
|
||||
return a.writeResponse(ses)
|
||||
}
|
||||
|
||||
func (a *Admin) apiResponse(ses *Session, v any, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var (
|
||||
b = new(bytes.Buffer)
|
||||
e = json.NewEncoder(b)
|
||||
)
|
||||
e.SetIndent("", " ")
|
||||
if err := e.Encode(v); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ses.response = NewJSONResponse(http.StatusOK, b, ses.request)
|
||||
defer log.OnCloseError(log.Debug(), ses.response.Body)
|
||||
ses.response.Close = true
|
||||
return a.writeResponse(ses)
|
||||
|
||||
}
|
||||
|
||||
func (a *Admin) apiStatsLog(ses *Session) error {
|
||||
var (
|
||||
query = ses.request.URL.Query()
|
||||
offset, _ = strconv.Atoi(query.Get("offset"))
|
||||
limit, _ = strconv.Atoi(query.Get("limit"))
|
||||
)
|
||||
if limit > 100 {
|
||||
limit = 100
|
||||
}
|
||||
|
||||
s, err := a.stats.QueryLog(offset, limit)
|
||||
return a.apiResponse(ses, s, err)
|
||||
}
|
||||
|
||||
func (a *Admin) apiStatsStatus(ses *Session) error {
|
||||
s, err := a.stats.QueryStatus(time.Time{})
|
||||
return a.apiResponse(ses, s, err)
|
||||
}
|
8
proxy/cache/config.go
vendored
Normal file
8
proxy/cache/config.go
vendored
Normal file
@@ -0,0 +1,8 @@
|
||||
package cache
|
||||
|
||||
import "github.com/hashicorp/hcl/v2"
|
||||
|
||||
type Config struct {
|
||||
Type string `hcl:"type"`
|
||||
Body hcl.Body `hcl:",remain"`
|
||||
}
|
88
proxy/config.go
Normal file
88
proxy/config.go
Normal file
@@ -0,0 +1,88 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"git.maze.io/maze/styx/proxy/policy"
|
||||
"git.maze.io/maze/styx/proxy/resolver"
|
||||
)
|
||||
|
||||
type ConnectHandler interface {
|
||||
HandleConnect(session *Session, network, address string) net.Conn
|
||||
}
|
||||
|
||||
// ConnectHandlerFunc is called when the proxy receives a new HTTP CONNECT request.
|
||||
type ConnectHandlerFunc func(session *Session, network, address string) net.Conn
|
||||
|
||||
func (f ConnectHandlerFunc) HandleConnect(session *Session, network, address string) net.Conn {
|
||||
return f(session, network, address)
|
||||
}
|
||||
|
||||
type RequestHandler interface {
|
||||
HandleRequest(session *Session) (*http.Request, *http.Response)
|
||||
}
|
||||
|
||||
// RequestHandlerFunc is called when the proxy receives a new request.
|
||||
type RequestHandlerFunc func(session *Session) (*http.Request, *http.Response)
|
||||
|
||||
func (f RequestHandlerFunc) HandleRequest(session *Session) (*http.Request, *http.Response) {
|
||||
return f(session)
|
||||
}
|
||||
|
||||
type ResponseHandler interface {
|
||||
HandleResponse(session *Session) *http.Response
|
||||
}
|
||||
|
||||
// ResponseHandler is called when the proxy receives a response.
|
||||
type ResponseHandlerFunc func(session *Session) *http.Response
|
||||
|
||||
func (f ResponseHandlerFunc) HandleResponse(session *Session) *http.Response {
|
||||
return f(session)
|
||||
}
|
||||
|
||||
type ErrorHandler interface {
|
||||
HandleError(session *Session, err error)
|
||||
}
|
||||
|
||||
type ErrorHandlerFunc func(session *Session, err error)
|
||||
|
||||
func (f ErrorHandlerFunc) HandleError(session *Session, err error) {
|
||||
f(session, err)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
// Listen address.
|
||||
Listen string `hcl:"listen,optional"`
|
||||
|
||||
// Bind address for outgoing connections.
|
||||
Bind string `hcl:"bind,optional"`
|
||||
|
||||
// Interface for outgoing connections.
|
||||
Interface string `hcl:"interface,optional"`
|
||||
|
||||
// Upstream proxy servers.
|
||||
Upstream []string `hcl:"upstream,optional"`
|
||||
|
||||
// DialTimeout for establishing new connections.
|
||||
DialTimeout time.Duration `hcl:"dial_timeout,optional"`
|
||||
|
||||
// Policy for the proxy.
|
||||
Policy *policy.Policy `hcl:"policy,block"`
|
||||
|
||||
// Resolver for the proxy.
|
||||
Resolver resolver.Resolver
|
||||
|
||||
ConnectHandler ConnectHandler
|
||||
RequestHandler RequestHandler
|
||||
ResponseHandler ResponseHandler
|
||||
ErrorHandler ErrorHandler
|
||||
}
|
||||
|
||||
var (
|
||||
_ ConnectHandler = (ConnectHandlerFunc)(nil)
|
||||
_ RequestHandler = (RequestHandlerFunc)(nil)
|
||||
_ ResponseHandler = (ResponseHandlerFunc)(nil)
|
||||
_ ErrorHandler = (ErrorHandlerFunc)(nil)
|
||||
)
|
324
proxy/match/config.go
Normal file
324
proxy/match/config.go
Normal file
@@ -0,0 +1,324 @@
|
||||
package match
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"regexp"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.maze.io/maze/styx/internal/log"
|
||||
"git.maze.io/maze/styx/internal/netutil"
|
||||
"github.com/hashicorp/hcl/v2"
|
||||
"github.com/hashicorp/hcl/v2/gohcl"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Path string `hcl:"path,optional"`
|
||||
Refresh time.Duration `hcl:"refresh,optional"`
|
||||
Domain []*Domain `hcl:"domain,block"`
|
||||
Network []*Network `hcl:"network,block"`
|
||||
Content []*Content `hcl:"content,block"`
|
||||
}
|
||||
|
||||
func (config Config) Matchers() (Matchers, error) {
|
||||
all := make(Matchers)
|
||||
if config.Domain != nil {
|
||||
all["domain"] = make(map[string]Matcher)
|
||||
for _, domain := range config.Domain {
|
||||
m, err := domain.Matcher()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("matcher domain %q invalid: %w", domain.Name, err)
|
||||
}
|
||||
all["domain"][domain.Name] = m
|
||||
}
|
||||
}
|
||||
if config.Network != nil {
|
||||
all["network"] = make(map[string]Matcher)
|
||||
for _, network := range config.Network {
|
||||
m, err := network.Matcher(true)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("matcher network %q invalid: %w", network.Name, err)
|
||||
}
|
||||
all["network"][network.Name] = m
|
||||
}
|
||||
}
|
||||
return all, nil
|
||||
}
|
||||
|
||||
type Content struct {
|
||||
Name string `hcl:"name,label"`
|
||||
Type string `hcl:"type"`
|
||||
Body hcl.Body `hcl:",remain"`
|
||||
}
|
||||
|
||||
type contentHeader struct {
|
||||
Key string `hcl:"name"`
|
||||
Value string `hcl:"value,optional"`
|
||||
List []string `hcl:"list,optional"`
|
||||
name string
|
||||
keyRe *regexp.Regexp
|
||||
valueRe *regexp.Regexp
|
||||
}
|
||||
|
||||
func (m contentHeader) Name() string { return m.name }
|
||||
func (m contentHeader) MatchesResponse(r *http.Response) bool {
|
||||
for k, vv := range r.Header {
|
||||
if m.keyRe.MatchString(k) {
|
||||
for _, v := range vv {
|
||||
if slices.Contains(m.List, v) {
|
||||
return true
|
||||
}
|
||||
if m.valueRe != nil && m.valueRe.MatchString(v) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type contentType struct {
|
||||
List []string `hcl:"list"`
|
||||
name string
|
||||
}
|
||||
|
||||
func (m contentType) Name() string { return m.name }
|
||||
func (m contentType) MatchesResponse(r *http.Response) bool {
|
||||
return slices.Contains(m.List, r.Header.Get("Content-Type"))
|
||||
}
|
||||
|
||||
type contentSizeLargerThan struct {
|
||||
Size int64 `hcl:"size"`
|
||||
name string
|
||||
}
|
||||
|
||||
func (m contentSizeLargerThan) Name() string { return m.name }
|
||||
func (m contentSizeLargerThan) MatchesResponse(r *http.Response) bool {
|
||||
size, err := strconv.ParseInt(r.Header.Get("Content-Length"), 10, 64)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return size >= m.Size
|
||||
}
|
||||
|
||||
type contentStatus struct {
|
||||
Code []int `hcl:"code"`
|
||||
name string
|
||||
}
|
||||
|
||||
func (m contentStatus) Name() string { return m.name }
|
||||
func (m contentStatus) MatchesResponse(r *http.Response) bool {
|
||||
return slices.Contains(m.Code, r.StatusCode)
|
||||
}
|
||||
|
||||
func (config Content) Matcher() (Response, error) {
|
||||
switch strings.ToLower(config.Type) {
|
||||
case "content", "contenttype", "content-type", "type":
|
||||
var matcher = contentType{name: config.Name}
|
||||
if err := gohcl.DecodeBody(config.Body, nil, &matcher); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return matcher, nil
|
||||
|
||||
case "header":
|
||||
var (
|
||||
matcher = contentHeader{name: config.Name}
|
||||
err error
|
||||
)
|
||||
if err = gohcl.DecodeBody(config.Body, nil, &matcher); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if matcher.Value == "" && len(matcher.List) == 0 {
|
||||
return nil, fmt.Errorf("invalid content %q: must contain either list or value", config.Name)
|
||||
}
|
||||
if matcher.keyRe, err = regexp.Compile(matcher.Key); err != nil {
|
||||
return nil, fmt.Errorf("invalid regular expression on content %q key: %w", config.Name, err)
|
||||
}
|
||||
if matcher.Value != "" {
|
||||
if matcher.valueRe, err = regexp.Compile(matcher.Value); err != nil {
|
||||
return nil, fmt.Errorf("invalid regular expression on content %q value: %w", config.Name, err)
|
||||
}
|
||||
}
|
||||
return matcher, nil
|
||||
|
||||
case "size":
|
||||
var matcher = contentSizeLargerThan{name: config.Name}
|
||||
if err := gohcl.DecodeBody(config.Body, nil, &matcher); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return matcher, nil
|
||||
|
||||
case "status":
|
||||
var matcher = contentStatus{name: config.Name}
|
||||
if err := gohcl.DecodeBody(config.Body, nil, &matcher); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return matcher, nil
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown content matcher type %q", config.Type)
|
||||
}
|
||||
}
|
||||
|
||||
type Domain struct {
|
||||
Name string `hcl:"name,label"`
|
||||
Type string `hcl:"type"`
|
||||
Body hcl.Body `hcl:",remain"`
|
||||
}
|
||||
|
||||
func (config Domain) Matcher() (Request, error) {
|
||||
switch config.Type {
|
||||
case "list":
|
||||
var matcher = domainList{Title: config.Name}
|
||||
if err := gohcl.DecodeBody(config.Body, nil, &matcher); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
matcher.list = netutil.NewDomainList(matcher.List...)
|
||||
return matcher, nil
|
||||
|
||||
case "adblock", "dnsmasq", "hosts", "detect", "domains":
|
||||
var matcher = DomainFile{
|
||||
Title: config.Name,
|
||||
Type: config.Type,
|
||||
}
|
||||
if err := gohcl.DecodeBody(config.Body, nil, &matcher); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if matcher.Path == "" && matcher.From == "" {
|
||||
return nil, fmt.Errorf("matcher: domain %q must have either file or from configured", config.Name)
|
||||
}
|
||||
if err := matcher.Update(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return matcher, nil
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown domain matcher type %q", config.Type)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
type domainList struct {
|
||||
Title string `json:"title"`
|
||||
List []string `hcl:"list" json:"list"`
|
||||
list *netutil.DomainTree
|
||||
}
|
||||
|
||||
func (m domainList) Name() string {
|
||||
return m.Title
|
||||
}
|
||||
|
||||
func (m domainList) MatchesRequest(r *http.Request) bool {
|
||||
host := netutil.Host(r.URL.Host)
|
||||
log.Debug().Str("host", host).Msgf("match domain list (%d domains)", len(m.List))
|
||||
return m.list.Contains(host)
|
||||
}
|
||||
|
||||
type DomainFile struct {
|
||||
Title string `json:"name"`
|
||||
Type string `json:"type"`
|
||||
Path string `hcl:"path,optional" json:"path,omitempty"`
|
||||
From string `hcl:"from,optional" json:"from,omitempty"`
|
||||
Refresh time.Duration `hcl:"refresh,optional" json:"refresh"`
|
||||
}
|
||||
|
||||
func (m DomainFile) Name() string {
|
||||
return m.Title
|
||||
}
|
||||
|
||||
func (m DomainFile) MatchesRequest(_ *http.Request) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *DomainFile) Update() (err error) {
|
||||
var data []byte
|
||||
if m.Path != "" {
|
||||
if data, err = os.ReadFile(m.Path); err != nil {
|
||||
return
|
||||
}
|
||||
} else {
|
||||
/*
|
||||
var response *http.Response
|
||||
if response, err = http.DefaultClient.Get(m.From); err != nil {
|
||||
return
|
||||
}
|
||||
defer func() { _ = response.Body.Close() }()
|
||||
if response.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("match: domain %q update failed: %s", m.name, response.Status)
|
||||
}
|
||||
if data, err = io.ReadAll(response.Body); err != nil {
|
||||
return
|
||||
}
|
||||
*/
|
||||
}
|
||||
|
||||
switch m.Type {
|
||||
case "hosts":
|
||||
}
|
||||
|
||||
_ = data
|
||||
return nil
|
||||
}
|
||||
|
||||
type Network struct {
|
||||
Name string `hcl:"name,label"`
|
||||
Type string `hcl:"type"`
|
||||
Body hcl.Body `hcl:",remain"`
|
||||
}
|
||||
|
||||
func (config *Network) Matcher(target bool) (Matcher, error) {
|
||||
switch config.Type {
|
||||
case "list":
|
||||
var (
|
||||
matcher = networkList{Title: config.Name}
|
||||
err error
|
||||
)
|
||||
if diag := gohcl.DecodeBody(config.Body, nil, &matcher); diag.HasErrors() {
|
||||
return nil, diag
|
||||
}
|
||||
if matcher.tree, err = netutil.NewNetworkTree(matcher.List...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &matcher, nil
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown network matcher type %q", config.Type)
|
||||
}
|
||||
}
|
||||
|
||||
type networkList struct {
|
||||
Title string `json:"name"`
|
||||
List []string `hcl:"list" json:"list"`
|
||||
tree *netutil.NetworkTree
|
||||
target bool
|
||||
}
|
||||
|
||||
func (m *networkList) Name() string {
|
||||
return m.Title
|
||||
}
|
||||
|
||||
func (m *networkList) MatchesIP(ip net.IP) bool {
|
||||
return m.tree.Contains(ip)
|
||||
}
|
||||
|
||||
func (m *networkList) MatchesRequest(r *http.Request) bool {
|
||||
var (
|
||||
host string
|
||||
err error
|
||||
)
|
||||
if m.target {
|
||||
host, _, err = net.SplitHostPort(r.URL.Host)
|
||||
} else {
|
||||
host, _, err = net.SplitHostPort(r.RemoteAddr)
|
||||
}
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
ip := net.ParseIP(host)
|
||||
return m.MatchesIP(ip)
|
||||
}
|
45
proxy/match/match.go
Normal file
45
proxy/match/match.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package match
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type Matchers map[string]map[string]Matcher
|
||||
|
||||
func (all Matchers) Get(kind, name string) (m Matcher, err error) {
|
||||
if typeMatchers, ok := all[kind]; ok {
|
||||
if m, ok = typeMatchers[name]; ok {
|
||||
return
|
||||
}
|
||||
return nil, fmt.Errorf("no %s matcher named %q found", kind, name)
|
||||
}
|
||||
return nil, fmt.Errorf("no %s matcher found", kind)
|
||||
}
|
||||
|
||||
type Matcher interface {
|
||||
Name() string
|
||||
}
|
||||
|
||||
type Updater interface {
|
||||
Update() error
|
||||
}
|
||||
|
||||
type IP interface {
|
||||
Matcher
|
||||
|
||||
MatchesIP(net.IP) bool
|
||||
}
|
||||
|
||||
type Request interface {
|
||||
Matcher
|
||||
|
||||
MatchesRequest(*http.Request) bool
|
||||
}
|
||||
|
||||
type Response interface {
|
||||
Matcher
|
||||
|
||||
MatchesResponse(*http.Response) bool
|
||||
}
|
11
proxy/match/util.go
Normal file
11
proxy/match/util.go
Normal file
@@ -0,0 +1,11 @@
|
||||
package match
|
||||
|
||||
import "net"
|
||||
|
||||
func onlyHost(name string) string {
|
||||
host, _, err := net.SplitHostPort(name)
|
||||
if err != nil {
|
||||
return name
|
||||
}
|
||||
return host
|
||||
}
|
231
proxy/mitm/authority.go
Normal file
231
proxy/mitm/authority.go
Normal file
@@ -0,0 +1,231 @@
|
||||
package mitm
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.maze.io/maze/styx/internal/cryptutil"
|
||||
"git.maze.io/maze/styx/internal/log"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
const DefaultValidity = 24 * time.Hour
|
||||
|
||||
type Authority interface {
|
||||
Certificate() *x509.Certificate
|
||||
TLSConfig(name string) *tls.Config
|
||||
}
|
||||
|
||||
type authority struct {
|
||||
pool *x509.CertPool
|
||||
cert *x509.Certificate
|
||||
key crypto.PrivateKey
|
||||
keyID []byte
|
||||
keyPool chan crypto.PrivateKey
|
||||
cache Cache
|
||||
}
|
||||
|
||||
func New(config *Config) (Authority, error) {
|
||||
cache, err := NewCache(config.Cache)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
caConfig := config.CA
|
||||
if caConfig == nil {
|
||||
caConfig = new(CAConfig)
|
||||
}
|
||||
|
||||
cert, key, err := cryptutil.LoadKeyPair(caConfig.Cert, caConfig.Key)
|
||||
if os.IsNotExist(err) {
|
||||
days := caConfig.Days
|
||||
if days == 0 {
|
||||
days = DefaultDays
|
||||
}
|
||||
if cert, key, err = cryptutil.GenerateKeyPair(caConfig.DN(), days, caConfig.KeyType, caConfig.Bits); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if strings.ContainsRune(caConfig.Cert, os.PathSeparator) {
|
||||
if err = cryptutil.SaveKeyPair(cert, key, caConfig.Cert, caConfig.Key); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pool := x509.NewCertPool()
|
||||
pool.AddCert(cert)
|
||||
|
||||
keyConfig := config.Key
|
||||
if keyConfig == nil {
|
||||
keyConfig = &defaultKeyConfig
|
||||
}
|
||||
|
||||
keyPoolSize := defaultKeyConfig.Pool
|
||||
if keyConfig.Pool > 0 {
|
||||
keyPoolSize = keyConfig.Pool
|
||||
}
|
||||
keyPool := make(chan crypto.PrivateKey, keyPoolSize)
|
||||
if key, err := cryptutil.GeneratePrivateKey(keyConfig.Type, keyConfig.Bits); err != nil {
|
||||
return nil, fmt.Errorf("mitm: invalid key configuration: %w", err)
|
||||
} else {
|
||||
keyPool <- key
|
||||
}
|
||||
|
||||
go func(pool chan<- crypto.PrivateKey) {
|
||||
for {
|
||||
key, err := cryptutil.GeneratePrivateKey(keyConfig.Type, keyConfig.Bits)
|
||||
if err != nil {
|
||||
log.Panic().Err(err).Msg("error generating private key")
|
||||
}
|
||||
pool <- key
|
||||
}
|
||||
}(keyPool)
|
||||
|
||||
return &authority{
|
||||
pool: pool,
|
||||
cert: cert,
|
||||
key: key,
|
||||
keyID: cryptutil.GenerateKeyID(cryptutil.PublicKey(key)),
|
||||
keyPool: keyPool,
|
||||
cache: cache,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (ca *authority) log() log.Logger {
|
||||
return log.Console.With().
|
||||
Str("ca", ca.cert.Subject.String()).
|
||||
Logger()
|
||||
}
|
||||
|
||||
func (ca *authority) Certificate() *x509.Certificate {
|
||||
return ca.cert
|
||||
}
|
||||
|
||||
func (ca *authority) TLSConfig(name string) *tls.Config {
|
||||
return &tls.Config{
|
||||
GetCertificate: func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
log := ca.log()
|
||||
if hello.ServerName != "" {
|
||||
name = strings.ToLower(hello.ServerName)
|
||||
log.Debug().Msg("requesting certificate for server name (SNI)")
|
||||
} else {
|
||||
log.Debug().Msg("requesting certificate for hostname")
|
||||
}
|
||||
if cert, ok := ca.getCached(name); ok {
|
||||
log.Debug().
|
||||
Str("subject", cert.Leaf.Subject.String()).
|
||||
Str("serial", cert.Leaf.SerialNumber.String()).
|
||||
Time("valid", cert.Leaf.NotAfter).
|
||||
Msg("using cached certificate")
|
||||
return cert, nil
|
||||
}
|
||||
return ca.issueFor(name)
|
||||
},
|
||||
NextProtos: []string{"http/1.1"},
|
||||
}
|
||||
}
|
||||
|
||||
func (ca *authority) getCached(name string) (cert *tls.Certificate, ok bool) {
|
||||
log := ca.log()
|
||||
|
||||
if cert = ca.cache.Certificate(name); cert == nil {
|
||||
if baseDomain(name) != name {
|
||||
cert = ca.cache.Certificate(baseDomain(name))
|
||||
}
|
||||
}
|
||||
if cert != nil {
|
||||
if _, err := cert.Leaf.Verify(x509.VerifyOptions{
|
||||
DNSName: name,
|
||||
Roots: ca.pool,
|
||||
}); err != nil {
|
||||
log.Debug().Err(err).Str("name", name).Msg("deleting invalid certificate from cache")
|
||||
} else {
|
||||
ok = true
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (ca *authority) issueFor(name string) (*tls.Certificate, error) {
|
||||
var (
|
||||
log = ca.log().With().Str("name", name).Logger()
|
||||
key crypto.PrivateKey
|
||||
)
|
||||
select {
|
||||
case key = <-ca.keyPool:
|
||||
case <-time.After(5 * time.Second):
|
||||
return nil, errors.New("mitm: timeout waiting for private key generator to catch up")
|
||||
}
|
||||
if key == nil {
|
||||
panic("key pool returned nil key")
|
||||
}
|
||||
|
||||
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
|
||||
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("mtim: failed to generate serial number: %w", err)
|
||||
}
|
||||
|
||||
if part := dns.SplitDomainName(name); len(part) > 2 {
|
||||
name = strings.Join(part[1:], ".")
|
||||
log.Debug().Msgf("abbreviated name to %s (*.%s)", name, name)
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
template := &x509.Certificate{
|
||||
SerialNumber: serialNumber,
|
||||
Subject: pkix.Name{CommonName: name},
|
||||
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyAgreement,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
DNSNames: []string{name, "*." + name},
|
||||
BasicConstraintsValid: true,
|
||||
NotBefore: now.Add(-DefaultValidity),
|
||||
NotAfter: now.Add(+DefaultValidity),
|
||||
}
|
||||
der, err := x509.CreateCertificate(rand.Reader, template, ca.cert, cryptutil.PublicKey(key), ca.key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cert, err := x509.ParseCertificate(der)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Debug().Str("serial", serialNumber.String()).Msg("generated certificate")
|
||||
out := &tls.Certificate{
|
||||
Certificate: [][]byte{der},
|
||||
Leaf: cert,
|
||||
PrivateKey: key,
|
||||
}
|
||||
//ca.cache[name] = out
|
||||
ca.cache.SaveCertificate(name, out)
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func containsValidCertificate(cert *tls.Certificate) bool {
|
||||
if cert == nil || len(cert.Certificate) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
if cert.Leaf == nil {
|
||||
var err error
|
||||
if cert.Leaf, err = x509.ParseCertificate(cert.Certificate[0]); err != nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
|
||||
return !(cert.Leaf.NotBefore.Before(now) || cert.Leaf.NotAfter.After(now))
|
||||
}
|
233
proxy/mitm/cache.go
Normal file
233
proxy/mitm/cache.go
Normal file
@@ -0,0 +1,233 @@
|
||||
package mitm
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/golang-lru/v2/expirable"
|
||||
"github.com/hashicorp/hcl/v2/gohcl"
|
||||
"github.com/miekg/dns"
|
||||
|
||||
"git.maze.io/maze/styx/internal/cryptutil"
|
||||
"git.maze.io/maze/styx/internal/log"
|
||||
)
|
||||
|
||||
type Cache interface {
|
||||
Certificate(name string) *tls.Certificate
|
||||
SaveCertificate(name string, cert *tls.Certificate) error
|
||||
RemoveCertificate(name string)
|
||||
}
|
||||
|
||||
func NewCache(config *CacheConfig) (Cache, error) {
|
||||
if config == nil {
|
||||
return NewCache(&CacheConfig{Type: "memory"})
|
||||
}
|
||||
switch config.Type {
|
||||
case "memory":
|
||||
var cacheConfig = new(MemoryCacheConfig)
|
||||
if err := gohcl.DecodeBody(config.Body, nil, cacheConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewMemoryCache(cacheConfig.Size), nil
|
||||
case "disk":
|
||||
var cacheConfig = new(DiskCacheConfig)
|
||||
if err := gohcl.DecodeBody(config.Body, nil, cacheConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewDiskCache(cacheConfig.Path, time.Duration(cacheConfig.Expire*float64(time.Second)))
|
||||
default:
|
||||
return nil, fmt.Errorf("mitm: cache type %q is not supported", config.Type)
|
||||
}
|
||||
}
|
||||
|
||||
type memoryCache struct {
|
||||
cache *expirable.LRU[string, *tls.Certificate]
|
||||
}
|
||||
|
||||
func NewMemoryCache(size int) Cache {
|
||||
return memoryCache{
|
||||
cache: expirable.NewLRU(size, func(key string, value *tls.Certificate) {
|
||||
log.Debug().Str("name", key).Msg("certificate evicted from cache")
|
||||
}, time.Hour*24),
|
||||
}
|
||||
}
|
||||
|
||||
func (c memoryCache) Certificate(name string) (cert *tls.Certificate) {
|
||||
var ok bool
|
||||
if cert, ok = c.cache.Get(name); !ok {
|
||||
cert, _ = c.cache.Get(baseDomain(name))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (c memoryCache) SaveCertificate(name string, cert *tls.Certificate) error {
|
||||
c.cache.Add(name, cert)
|
||||
log.Debug().Str("name", name).Msg("certificate added to cache")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c memoryCache) RemoveCertificate(name string) {
|
||||
c.cache.Remove(name)
|
||||
}
|
||||
|
||||
type diskCache string
|
||||
|
||||
func NewDiskCache(dir string, expire time.Duration) (Cache, error) {
|
||||
if !filepath.IsAbs(dir) {
|
||||
var err error
|
||||
if dir, err = filepath.Abs(dir); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if err := os.MkdirAll(dir, 0o750); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
info, err := os.Stat(dir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if info.Mode()&os.ModePerm|0o057 != 0 {
|
||||
if err := os.Chmod(dir, 0o750); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if expire > 0 {
|
||||
go expireDiskCache(dir, expire)
|
||||
}
|
||||
|
||||
return diskCache(dir), nil
|
||||
}
|
||||
|
||||
func expireDiskCache(root string, expire time.Duration) {
|
||||
log.Debug().Str("path", root).Dur("expire", expire).Msg("disk cache expire loop starting")
|
||||
ticker := time.NewTicker(expire)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
now := <-ticker.C
|
||||
log.Debug().Str("path", root).Dur("expire", expire).Msg("expire disk cache")
|
||||
filepath.WalkDir(root, func(path string, d fs.DirEntry, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if d.IsDir() {
|
||||
// Remove the directory; this will fail if it's not empty, which is fine.
|
||||
_ = os.Remove(path)
|
||||
return nil
|
||||
}
|
||||
|
||||
cert, err := cryptutil.LoadCertificate(path)
|
||||
if err != nil {
|
||||
log.Debug().Str("path", path).Err(err).Msg("expire removing invalid certificate file")
|
||||
_ = os.Remove(path)
|
||||
return nil
|
||||
} else if cert.NotAfter.Before(now) {
|
||||
log.Debug().Str("path", path).Dur("expired", now.Sub(cert.NotAfter)).Msg("expire removing expired certificate")
|
||||
_ = os.Remove(path)
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (c diskCache) path(name string) string {
|
||||
part := dns.SplitDomainName(strings.ToLower(name))
|
||||
// x,com -> com,x
|
||||
// www,maze,io -> io,maze,www
|
||||
slices.Reverse(part)
|
||||
// com,x -> com,x,x.com
|
||||
// io,maze,www -> io,m,ma,maze,www.maze.io
|
||||
if len(part) > 2 {
|
||||
if len(part[1]) > 1 {
|
||||
part = []string{
|
||||
part[0],
|
||||
part[1][:1],
|
||||
part[1][:2],
|
||||
part[1],
|
||||
name,
|
||||
}
|
||||
} else {
|
||||
part = []string{
|
||||
part[0],
|
||||
part[1][:1],
|
||||
part[1],
|
||||
name,
|
||||
}
|
||||
}
|
||||
} else if len(part) > 1 {
|
||||
if len(part[1]) > 1 {
|
||||
part = []string{
|
||||
part[0],
|
||||
part[1][:1],
|
||||
part[1][:2],
|
||||
name,
|
||||
}
|
||||
} else {
|
||||
part = []string{
|
||||
part[0],
|
||||
part[1][:1],
|
||||
name,
|
||||
}
|
||||
}
|
||||
}
|
||||
part[len(part)-1] += ".crt"
|
||||
return filepath.Join(append([]string{string(c)}, part...)...)
|
||||
}
|
||||
|
||||
func (c diskCache) Certificate(name string) (cert *tls.Certificate) {
|
||||
if cert, key, err := cryptutil.LoadKeyPair(c.path(name), ""); err == nil {
|
||||
return &tls.Certificate{
|
||||
Certificate: [][]byte{cert.Raw},
|
||||
Leaf: cert,
|
||||
PrivateKey: key,
|
||||
}
|
||||
}
|
||||
if cert, key, err := cryptutil.LoadKeyPair(c.path(baseDomain(name)), ""); err == nil {
|
||||
return &tls.Certificate{
|
||||
Certificate: [][]byte{cert.Raw},
|
||||
Leaf: cert,
|
||||
PrivateKey: key,
|
||||
}
|
||||
}
|
||||
log.Debug().Str("path", string(c)).Str("name", name).Msg("cache miss")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c diskCache) SaveCertificate(name string, cert *tls.Certificate) error {
|
||||
dir, name := filepath.Split(c.path(name))
|
||||
if err := os.MkdirAll(dir, 0o750); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := cryptutil.SaveKeyPair(cert.Leaf, cert.PrivateKey, filepath.Join(dir, name), ""); err != nil {
|
||||
return err
|
||||
}
|
||||
log.Debug().Str("name", name).Msg("certificate added to cache")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c diskCache) RemoveCertificate(name string) {
|
||||
path := c.path(name)
|
||||
if err := os.Remove(path); err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return
|
||||
}
|
||||
log.Error().Err(err).Msg("certificate remove from cache failed")
|
||||
}
|
||||
_ = os.Remove(filepath.Dir(path))
|
||||
log.Debug().Str("name", name).Msg("certificate removed from cache")
|
||||
}
|
||||
|
||||
func baseDomain(name string) string {
|
||||
name = strings.ToLower(name)
|
||||
if part := dns.SplitDomainName(name); len(part) > 2 {
|
||||
return strings.Join(part[1:], ".")
|
||||
}
|
||||
return name
|
||||
}
|
25
proxy/mitm/cache_test.go
Normal file
25
proxy/mitm/cache_test.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package mitm
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestDiskCachePath(t *testing.T) {
|
||||
cache := diskCache("testdata")
|
||||
tests := []struct {
|
||||
test string
|
||||
want string
|
||||
}{
|
||||
{"x.com", "testdata/com/x/x.com.crt"},
|
||||
{"feed.x.com", "testdata/com/x/x/feed.x.com.crt"},
|
||||
{"nu.nl", "testdata/nl/n/nu/nu.nl.crt"},
|
||||
{"maze.io", "testdata/io/m/ma/maze.io.crt"},
|
||||
{"lab.maze.io", "testdata/io/m/ma/maze/lab.maze.io.crt"},
|
||||
{"dev.lab.maze.io", "testdata/io/m/ma/maze/dev.lab.maze.io.crt"},
|
||||
}
|
||||
for _, test := range tests {
|
||||
t.Run(test.test, func(it *testing.T) {
|
||||
if v := cache.path(test.test); v != test.want {
|
||||
it.Errorf("expected %q to resolve to %q, got %q", test.test, test.want, v)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
89
proxy/mitm/config.go
Normal file
89
proxy/mitm/config.go
Normal file
@@ -0,0 +1,89 @@
|
||||
package mitm
|
||||
|
||||
import (
|
||||
"crypto/x509/pkix"
|
||||
|
||||
"github.com/hashicorp/hcl/v2"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultCommonName = "Styx Certificate Authority"
|
||||
DefaultDays = 3
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
CA *CAConfig `hcl:"ca,block"`
|
||||
Key *KeyConfig `hcl:"key,block"`
|
||||
Cache *CacheConfig `hcl:"cache,block"`
|
||||
}
|
||||
|
||||
type CAConfig struct {
|
||||
Cert string `hcl:"cert"`
|
||||
Key string `hcl:"key,optional"`
|
||||
Days int `hcl:"days,optional"`
|
||||
KeyType string `hcl:"key_type,optional"`
|
||||
Bits int `hcl:"bits,optional"`
|
||||
Name string `hcl:"name,optional"`
|
||||
Country string `hcl:"country,optional"`
|
||||
Organization string `hcl:"organization,optional"`
|
||||
Unit string `hcl:"unit,optional"`
|
||||
Locality string `hcl:"locality,optional"`
|
||||
Province string `hcl:"province,optional"`
|
||||
Address []string `hcl:"address,optional"`
|
||||
PostalCode string `hcl:"postal_code,optional"`
|
||||
}
|
||||
|
||||
func (config CAConfig) DN() pkix.Name {
|
||||
var name = pkix.Name{
|
||||
CommonName: config.Name,
|
||||
StreetAddress: config.Address,
|
||||
}
|
||||
if config.Name == "" {
|
||||
name.CommonName = DefaultCommonName
|
||||
}
|
||||
if config.Country != "" {
|
||||
name.Country = append(name.Country, config.Country)
|
||||
}
|
||||
if config.Organization != "" {
|
||||
name.Organization = append(name.Organization, config.Organization)
|
||||
}
|
||||
if config.Unit != "" {
|
||||
name.OrganizationalUnit = append(name.OrganizationalUnit, config.Unit)
|
||||
}
|
||||
if config.Locality != "" {
|
||||
name.Locality = append(name.Locality, config.Locality)
|
||||
}
|
||||
if config.Province != "" {
|
||||
name.Province = append(name.Province, config.Province)
|
||||
}
|
||||
if config.PostalCode != "" {
|
||||
name.PostalCode = append(name.PostalCode, config.PostalCode)
|
||||
}
|
||||
return name
|
||||
}
|
||||
|
||||
type KeyConfig struct {
|
||||
Type string `hcl:"type,optional"`
|
||||
Bits int `hcl:"bits,optional"`
|
||||
Pool int `hcl:"pool,optional"`
|
||||
}
|
||||
|
||||
var defaultKeyConfig = KeyConfig{
|
||||
Type: "rsa",
|
||||
Bits: 2048,
|
||||
Pool: 5,
|
||||
}
|
||||
|
||||
type CacheConfig struct {
|
||||
Type string `hcl:"type"`
|
||||
Body hcl.Body `hcl:",remain"`
|
||||
}
|
||||
|
||||
type MemoryCacheConfig struct {
|
||||
Size int `hcl:"size,optional"`
|
||||
}
|
||||
|
||||
type DiskCacheConfig struct {
|
||||
Path string `hcl:"path"`
|
||||
Expire float64 `hcl:"expire,optional"`
|
||||
}
|
53
proxy/policy/policy.go
Normal file
53
proxy/policy/policy.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package policy
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"git.maze.io/maze/styx/proxy/match"
|
||||
)
|
||||
|
||||
// Policy contains rules that make up the policy.
|
||||
//
|
||||
// Some policy rules contain nested policies.
|
||||
type Policy struct {
|
||||
Rules []*rawRule `hcl:"on,block" json:"rules"`
|
||||
Permit *bool `hcl:"permit" json:"permit"`
|
||||
Matchers match.Matchers `json:"matchers"` // Matchers for the policy
|
||||
|
||||
}
|
||||
|
||||
func (p *Policy) Configure(matchers match.Matchers) (err error) {
|
||||
for _, r := range p.Rules {
|
||||
if err = r.Configure(matchers); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
p.Matchers = matchers
|
||||
return
|
||||
}
|
||||
|
||||
func (p *Policy) PermitIntercept(r *http.Request) *bool {
|
||||
if p != nil {
|
||||
for _, rule := range p.Rules {
|
||||
if rule, ok := rule.Rule.(InterceptRule); ok {
|
||||
if permit := rule.PermitIntercept(r); permit != nil {
|
||||
return permit
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return p.Permit
|
||||
}
|
||||
|
||||
func (p *Policy) PermitRequest(r *http.Request) *bool {
|
||||
if p != nil {
|
||||
for _, rule := range p.Rules {
|
||||
if rule, ok := rule.Rule.(RequestRule); ok {
|
||||
if permit := rule.PermitRequest(r); permit != nil {
|
||||
return permit
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return p.Permit
|
||||
}
|
139
proxy/policy/policy_test.go
Normal file
139
proxy/policy/policy_test.go
Normal file
@@ -0,0 +1,139 @@
|
||||
package policy
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"git.maze.io/maze/styx/internal/netutil"
|
||||
"git.maze.io/maze/styx/proxy/match"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
type testInDomainList struct {
|
||||
t *testing.T
|
||||
list []string
|
||||
}
|
||||
|
||||
func (testInDomainList) Name() string { return "testInDomainList" }
|
||||
func (l testInDomainList) MatchesRequest(r *http.Request) bool {
|
||||
for _, domain := range l.list {
|
||||
if dns.IsSubDomain(domain, netutil.Host(r.URL.Host)) {
|
||||
l.t.Logf("domain %s contains %s", domain, r.URL.Host)
|
||||
return true
|
||||
}
|
||||
l.t.Logf("domain %s does not contain %s", domain, r.URL.Host)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func testInDomain(t *testing.T, domains ...string) match.Matcher {
|
||||
return &testInDomainList{t: t, list: domains}
|
||||
}
|
||||
|
||||
type testInNetworkList struct {
|
||||
t *testing.T
|
||||
list []*net.IPNet
|
||||
}
|
||||
|
||||
func (testInNetworkList) Name() string { return "testInNetworkList" }
|
||||
func (l testInNetworkList) MatchesIP(ip net.IP) bool {
|
||||
for _, ipnet := range l.list {
|
||||
if ipnet.Contains(ip) {
|
||||
l.t.Logf("network %s contains %s", ipnet, ip)
|
||||
return true
|
||||
}
|
||||
l.t.Logf("network %s does not contain %s", ipnet, ip)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func testInNetwork(t *testing.T, cidr string) match.Matcher {
|
||||
t.Helper()
|
||||
_, ipnet, err := net.ParseCIDR(cidr)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return testInNetworkList{t: t, list: []*net.IPNet{ipnet}}
|
||||
}
|
||||
|
||||
func TestPolicy(t *testing.T) {
|
||||
var (
|
||||
yes = true
|
||||
nope = false
|
||||
)
|
||||
p := &Policy{
|
||||
Rules: []*rawRule{
|
||||
{
|
||||
Rule: &requestRule{
|
||||
domainOrNetworkRule: domainOrNetworkRule{
|
||||
matchers: []match.Matcher{testInNetwork(t, "127.0.0.0/8")},
|
||||
isSource: []bool{true},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Rule: &requestRule{
|
||||
domainOrNetworkRule: domainOrNetworkRule{
|
||||
matchers: []match.Matcher{testInNetwork(t, "127.0.0.0/8")},
|
||||
isSource: []bool{false},
|
||||
},
|
||||
Permit: &yes,
|
||||
},
|
||||
},
|
||||
{
|
||||
Rule: &requestRule{
|
||||
domainOrNetworkRule: domainOrNetworkRule{
|
||||
matchers: []match.Matcher{testInDomain(t, "maze.io", "maze.engineering")},
|
||||
},
|
||||
Permit: &yes,
|
||||
},
|
||||
},
|
||||
{
|
||||
Rule: &requestRule{
|
||||
domainOrNetworkRule: domainOrNetworkRule{
|
||||
matchers: []match.Matcher{testInDomain(t, "google.com")},
|
||||
},
|
||||
Permit: &nope,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
r := &http.Request{
|
||||
URL: &url.URL{Scheme: "http", Host: "golang.org:80"},
|
||||
RemoteAddr: "127.0.0.1:1234",
|
||||
}
|
||||
if v := p.PermitRequest(r); v != nil {
|
||||
t.Errorf("expected request to return no verdict, got %t", *v)
|
||||
}
|
||||
|
||||
p.Rules[0].Rule.(*requestRule).Permit = &yes
|
||||
if v := p.PermitRequest(r); v == nil || *v != yes {
|
||||
t.Errorf("expected request to return %t, %v", yes, v)
|
||||
}
|
||||
|
||||
r.RemoteAddr = "192.168.1.2:3456"
|
||||
if v := p.PermitRequest(r); v != nil {
|
||||
t.Errorf("expected request to return no verdict, got %t", *v)
|
||||
}
|
||||
if v := p.PermitIntercept(r); v != nil {
|
||||
t.Errorf("expected request to return no verdict, got %t", *v)
|
||||
}
|
||||
|
||||
r.URL.Host = "maze.io"
|
||||
if v := p.PermitRequest(r); v == nil || *v != yes {
|
||||
t.Errorf("expected request to return %t, %v", yes, v)
|
||||
}
|
||||
|
||||
r.URL.Host = "google.com"
|
||||
if v := p.PermitRequest(r); v == nil || *v != nope {
|
||||
t.Errorf("expected request to return %t, %v", nope, v)
|
||||
}
|
||||
|
||||
r.URL.Host = "localhost:80"
|
||||
if v := p.PermitRequest(r); v == nil || *v != yes {
|
||||
t.Errorf("expected request to return %t, %v", yes, v)
|
||||
}
|
||||
}
|
368
proxy/policy/rule.go
Normal file
368
proxy/policy/rule.go
Normal file
@@ -0,0 +1,368 @@
|
||||
package policy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.maze.io/maze/styx/internal/netutil"
|
||||
"git.maze.io/maze/styx/proxy/match"
|
||||
"github.com/google/uuid"
|
||||
"github.com/hashicorp/hcl/v2"
|
||||
"github.com/hashicorp/hcl/v2/gohcl"
|
||||
)
|
||||
|
||||
// Rule is a policy rule.
|
||||
type Rule interface {
|
||||
Configure(match.Matchers) error
|
||||
}
|
||||
|
||||
// InterceptRule can make policy rule decisions on intercept requests.
|
||||
type InterceptRule interface {
|
||||
PermitIntercept(r *http.Request) *bool
|
||||
}
|
||||
|
||||
// RequestRule can make policy rule decisions on HTTP CONNECT requests.
|
||||
type RequestRule interface {
|
||||
PermitRequest(r *http.Request) *bool
|
||||
}
|
||||
|
||||
type rawRule struct {
|
||||
Type string `hcl:"type,label" json:"type"`
|
||||
Body hcl.Body `hcl:",remain" json:"-"`
|
||||
Rule `json:"rule"`
|
||||
}
|
||||
|
||||
func (r *rawRule) Configure(matchers match.Matchers) (err error) {
|
||||
switch r.Type {
|
||||
case "intercept":
|
||||
r.Rule = new(interceptRule)
|
||||
case "request":
|
||||
r.Rule = new(requestRule)
|
||||
case "days":
|
||||
r.Rule = new(daysRule)
|
||||
case "time":
|
||||
r.Rule = new(timeRule)
|
||||
case "all":
|
||||
r.Rule = new(allRule)
|
||||
default:
|
||||
return fmt.Errorf("policy: invalid event type %q", r.Type)
|
||||
}
|
||||
|
||||
if diag := gohcl.DecodeBody(r.Body, nil, r.Rule); diag.HasErrors() {
|
||||
return err
|
||||
}
|
||||
|
||||
return r.Rule.Configure(matchers)
|
||||
}
|
||||
|
||||
type allRule struct {
|
||||
Rules []*rawRule `hcl:"on,block"`
|
||||
Permit *bool `hcl:"permit"`
|
||||
}
|
||||
|
||||
func (r *allRule) Configure(matchers match.Matchers) (err error) {
|
||||
return
|
||||
}
|
||||
|
||||
type domainOrNetworkRule struct {
|
||||
matchers []match.Matcher
|
||||
isSource []bool
|
||||
}
|
||||
|
||||
func (r *domainOrNetworkRule) configure(kind string, matchers match.Matchers, domains, sources, targets []string, v any, id *string) (err error) {
|
||||
var m match.Matcher
|
||||
for _, domain := range domains {
|
||||
if m, err = matchers.Get("domain", domain); err != nil {
|
||||
return fmt.Errorf("%s: unknown domain %q", kind, domain)
|
||||
}
|
||||
r.matchers = append(r.matchers, m)
|
||||
r.isSource = append(r.isSource, false)
|
||||
}
|
||||
for _, network := range sources {
|
||||
if m, err = matchers.Get("network", network); err != nil {
|
||||
return fmt.Errorf("%s: unknown source network %q", kind, network)
|
||||
}
|
||||
r.matchers = append(r.matchers, m)
|
||||
r.isSource = append(r.isSource, true)
|
||||
}
|
||||
for _, network := range targets {
|
||||
if m, err = matchers.Get("network", network); err != nil {
|
||||
return fmt.Errorf("%s: unknown target network %q", kind, network)
|
||||
}
|
||||
r.matchers = append(r.matchers, m)
|
||||
r.isSource = append(r.isSource, false)
|
||||
}
|
||||
if len(r.matchers) == 0 {
|
||||
return fmt.Errorf("%s: missing any of domain, source, target", kind)
|
||||
}
|
||||
if id != nil {
|
||||
*id = uuid.NewString()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (r *domainOrNetworkRule) matchesRequest(q *http.Request) bool {
|
||||
for i, m := range r.matchers {
|
||||
if m, ok := m.(match.Request); ok {
|
||||
if m.MatchesRequest(q) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
if m, ok := m.(match.IP); ok {
|
||||
if r.isSource[i] {
|
||||
if m.MatchesIP(net.ParseIP(netutil.Host(q.RemoteAddr))) {
|
||||
return true
|
||||
}
|
||||
} else {
|
||||
var (
|
||||
host = netutil.Host(q.URL.Host)
|
||||
ips []net.IP
|
||||
)
|
||||
if ip := net.ParseIP(host); ip != nil {
|
||||
ips = append(ips, ip)
|
||||
} else {
|
||||
ips, _ = net.LookupIP(host)
|
||||
}
|
||||
for _, ip := range ips {
|
||||
if m.MatchesIP(ip) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type interceptRule struct {
|
||||
ID string `json:"id,omitempty"`
|
||||
Domain []string `hcl:"domain,optional" json:"domain,omitempty"`
|
||||
Source []string `hcl:"source,optional" json:"source,omitempty"`
|
||||
Target []string `hcl:"target,optional" json:"target,omitempty"`
|
||||
Permit *bool `hcl:"permit" json:"permit"`
|
||||
domainOrNetworkRule `json:"-"`
|
||||
}
|
||||
|
||||
func (r *interceptRule) Configure(matchers match.Matchers) (err error) {
|
||||
return r.configure("intercept", matchers, r.Domain, r.Source, r.Target, r, &r.ID)
|
||||
}
|
||||
|
||||
func (r *interceptRule) PermitIntercept(q *http.Request) *bool {
|
||||
if r.matchesRequest(q) {
|
||||
return r.Permit
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type requestRule struct {
|
||||
ID string `json:"id,omitempty"`
|
||||
Domain []string `hcl:"domain,optional" json:"domain,omitempty"`
|
||||
Source []string `hcl:"source,optional" json:"source,omitempty"`
|
||||
Target []string `hcl:"target,optional" json:"target,omitempty"`
|
||||
Permit *bool `hcl:"permit" json:"permit"`
|
||||
domainOrNetworkRule `json:"-"`
|
||||
}
|
||||
|
||||
func (r *requestRule) Configure(matchers match.Matchers) (err error) {
|
||||
return r.configure("request", matchers, r.Domain, r.Source, r.Target, r, &r.ID)
|
||||
}
|
||||
|
||||
func (r *requestRule) PermitRequest(q *http.Request) *bool {
|
||||
if r.matchesRequest(q) {
|
||||
return r.Permit
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type timeRule struct {
|
||||
ID string `json:"id,omitempty"`
|
||||
Time []string `hcl:"time" json:"time"`
|
||||
Permit *bool `hcl:"permit" json:"permit"`
|
||||
Body hcl.Body `hcl:",remain" json:"-"`
|
||||
Rules *Policy `json:"rules"`
|
||||
Start Time `json:"start"`
|
||||
End Time `json:"end"`
|
||||
}
|
||||
|
||||
func (r *timeRule) isActive() bool {
|
||||
if r == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
now := Now()
|
||||
if r.Start.After(r.End) { // ie: 18:00-06:00
|
||||
return now.After(r.Start) || now.Before(r.End)
|
||||
}
|
||||
return now.After(r.Start) && now.Before(r.End)
|
||||
}
|
||||
|
||||
func (r *timeRule) Configure(matchers match.Matchers) (err error) {
|
||||
if len(r.Time) != 2 {
|
||||
return fmt.Errorf("invalid time %s, need [start, stop]", r.Time)
|
||||
}
|
||||
if r.Start, err = ParseTime(r.Time[0]); err != nil {
|
||||
return fmt.Errorf("invalid start %q: %w", r.Time[0], err)
|
||||
}
|
||||
if r.End, err = ParseTime(r.Time[1]); err != nil {
|
||||
return fmt.Errorf("invalid end %q: %w", r.Time[1], err)
|
||||
}
|
||||
|
||||
r.Rules = new(Policy)
|
||||
if diag := gohcl.DecodeBody(r.Body, nil, r.Rules); diag.HasErrors() {
|
||||
return diag
|
||||
}
|
||||
|
||||
if err = r.Rules.Configure(matchers); err != nil {
|
||||
return
|
||||
}
|
||||
r.Rules.Matchers = nil
|
||||
|
||||
if r.ID == "" {
|
||||
r.ID = uuid.NewString()
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (r *timeRule) PermitIntercept(q *http.Request) *bool {
|
||||
if !r.isActive() {
|
||||
return nil
|
||||
}
|
||||
return r.Rules.PermitIntercept(q)
|
||||
}
|
||||
|
||||
func (r *timeRule) PermitRequest(q *http.Request) *bool {
|
||||
if !r.isActive() {
|
||||
return nil
|
||||
}
|
||||
return r.Rules.PermitRequest(q)
|
||||
}
|
||||
|
||||
type daysRule struct {
|
||||
ID string `json:"id,omitempty"`
|
||||
Days string `hcl:"days" json:"days"`
|
||||
Permit *bool `hcl:"permit" json:"permit"`
|
||||
Body hcl.Body `hcl:",remain" json:"-"`
|
||||
Rules *Policy `json:"rules"`
|
||||
cond []onCond
|
||||
}
|
||||
|
||||
func (r *daysRule) isActive() bool {
|
||||
if r == nil || len(r.cond) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
for _, cond := range r.cond {
|
||||
if cond(now) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (r *daysRule) Configure(matchers match.Matchers) (err error) {
|
||||
if r.cond, err = parseOnCond(r.Days); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
r.Rules = new(Policy)
|
||||
if diag := gohcl.DecodeBody(r.Body, nil, r.Rules); diag.HasErrors() {
|
||||
return diag
|
||||
}
|
||||
if err = r.Rules.Configure(matchers); err != nil {
|
||||
return
|
||||
}
|
||||
r.Rules.Matchers = nil
|
||||
|
||||
if r.ID == "" {
|
||||
r.ID = uuid.NewString()
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (r *daysRule) PermitIntercept(q *http.Request) *bool {
|
||||
if !r.isActive() {
|
||||
return nil
|
||||
}
|
||||
return r.Rules.PermitIntercept(q)
|
||||
}
|
||||
|
||||
func (r *daysRule) PermitRequest(q *http.Request) *bool {
|
||||
if !r.isActive() {
|
||||
return nil
|
||||
}
|
||||
return r.Rules.PermitRequest(q)
|
||||
}
|
||||
|
||||
type onCond func(time.Time) bool
|
||||
|
||||
var weekdays = map[string]time.Weekday{
|
||||
"sun": time.Sunday,
|
||||
"mon": time.Monday,
|
||||
"tue": time.Tuesday,
|
||||
"wed": time.Wednesday,
|
||||
"thu": time.Thursday,
|
||||
"fri": time.Friday,
|
||||
"sat": time.Saturday,
|
||||
}
|
||||
|
||||
func parseOnCond(when string) (conds []onCond, err error) {
|
||||
for _, spec := range strings.Split(when, ",") {
|
||||
spec = strings.ToLower(strings.TrimSpace(spec))
|
||||
if d, ok := weekdays[spec]; ok {
|
||||
conds = append(conds, onWeekday(d))
|
||||
} else if spec == "weekend" || spec == "weekends" {
|
||||
conds = append(conds, onWeekend)
|
||||
} else if spec == "workday" || spec == "workdays" {
|
||||
conds = append(conds, onWorkday)
|
||||
} else if strings.ContainsRune(spec, '-') {
|
||||
var (
|
||||
part = strings.SplitN(spec, "-", 2)
|
||||
from, upto time.Weekday
|
||||
ok bool
|
||||
)
|
||||
if from, ok = weekdays[part[0]]; !ok {
|
||||
return nil, fmt.Errorf("on %q: invalid weekday %q", spec, part[0])
|
||||
}
|
||||
if upto, ok = weekdays[part[1]]; !ok {
|
||||
return nil, fmt.Errorf("on %q: invalid weekday %q", spec, part[1])
|
||||
}
|
||||
if from < upto {
|
||||
for d := from; d < upto; d++ {
|
||||
conds = append(conds, onWeekday(d))
|
||||
}
|
||||
} else {
|
||||
for d := time.Sunday; d < from; d++ {
|
||||
conds = append(conds, onWeekday(d))
|
||||
}
|
||||
for d := upto; d <= time.Saturday; d++ {
|
||||
conds = append(conds, onWeekday(d))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return nil, fmt.Errorf("on %q: invalid condition", spec)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func onWeekday(weekday time.Weekday) onCond {
|
||||
return func(t time.Time) bool {
|
||||
return t.Weekday() == weekday
|
||||
}
|
||||
}
|
||||
|
||||
func onWeekend(t time.Time) bool {
|
||||
d := t.Weekday()
|
||||
return d == time.Saturday || d == time.Sunday
|
||||
}
|
||||
|
||||
func onWorkday(t time.Time) bool {
|
||||
d := t.Weekday()
|
||||
return !(d == time.Saturday || d == time.Sunday)
|
||||
}
|
53
proxy/policy/time.go
Normal file
53
proxy/policy/time.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package policy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Time struct {
|
||||
Hour int
|
||||
Minute int
|
||||
Second int
|
||||
}
|
||||
|
||||
func (t Time) Eq(other Time) bool {
|
||||
return t.Hour == other.Hour && t.Minute == other.Minute && t.Second == other.Second
|
||||
}
|
||||
|
||||
func (t Time) After(other Time) bool {
|
||||
return t.Seconds() > other.Seconds()
|
||||
}
|
||||
|
||||
func (t Time) Before(other Time) bool {
|
||||
return t.Seconds() < other.Seconds()
|
||||
}
|
||||
|
||||
func (t Time) Seconds() int {
|
||||
return t.Hour*3600 + t.Minute*60 + t.Second
|
||||
}
|
||||
|
||||
func (t Time) MarshalJSON() ([]byte, error) {
|
||||
return []byte(fmt.Sprintf(`"%02d:%02d:%02d"`, t.Hour, t.Minute, t.Second)), nil
|
||||
}
|
||||
|
||||
var timeFormats = []string{
|
||||
time.TimeOnly,
|
||||
"15:04",
|
||||
time.Kitchen,
|
||||
}
|
||||
|
||||
func Now() Time {
|
||||
now := time.Now()
|
||||
return Time{now.Hour(), now.Minute(), now.Second()}
|
||||
}
|
||||
|
||||
func ParseTime(s string) (t Time, err error) {
|
||||
var tt time.Time
|
||||
for _, layout := range timeFormats {
|
||||
if tt, err = time.Parse(layout, s); err == nil {
|
||||
return Time{tt.Hour(), tt.Minute(), tt.Second()}, nil
|
||||
}
|
||||
}
|
||||
return Time{}, fmt.Errorf("time: invalid time %q", s)
|
||||
}
|
616
proxy/proxy.go
Normal file
616
proxy/proxy.go
Normal file
@@ -0,0 +1,616 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"git.maze.io/maze/styx/internal/log"
|
||||
"git.maze.io/maze/styx/internal/netutil"
|
||||
"git.maze.io/maze/styx/proxy/mitm"
|
||||
"git.maze.io/maze/styx/proxy/policy"
|
||||
"git.maze.io/maze/styx/proxy/resolver"
|
||||
"git.maze.io/maze/styx/proxy/stats"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultListenAddr = ":3128"
|
||||
DefaultBindAddr = ""
|
||||
DefaultDialTimeout = 30 * time.Second
|
||||
DefaultKeepAlivePeriod = 1 * time.Minute
|
||||
)
|
||||
|
||||
const (
|
||||
HeaderAcceptEncoding = "Accept-Encoding"
|
||||
HeaderConnection = "Connection"
|
||||
HeaderContentLength = "Content-Length"
|
||||
HeaderContentType = "Content-Type"
|
||||
HeaderUpgrade = "Upgrade"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrClosed = errors.New("proxy: shutdown")
|
||||
ErrClientCert = errors.New("tls: client certificate requested")
|
||||
)
|
||||
|
||||
type Proxy struct {
|
||||
addr *net.TCPAddr
|
||||
bind *net.TCPAddr
|
||||
resolver resolver.Resolver
|
||||
transport *http.Transport
|
||||
dial func(network, address string) (net.Conn, error)
|
||||
config *Config
|
||||
authority mitm.Authority
|
||||
policy *policy.Policy
|
||||
admin *Admin
|
||||
stats *stats.Stats
|
||||
closed chan struct{}
|
||||
onConnect ConnectHandler
|
||||
onRequest RequestHandler
|
||||
onResponse ResponseHandler
|
||||
onError ErrorHandler
|
||||
}
|
||||
|
||||
func New(config *Config, ca mitm.Authority) (*Proxy, error) {
|
||||
if config == nil {
|
||||
return nil, errors.New("proxy: config can't be nil")
|
||||
}
|
||||
|
||||
p := &Proxy{
|
||||
transport: newTransport(),
|
||||
config: config,
|
||||
resolver: resolver.Default,
|
||||
authority: ca,
|
||||
policy: config.Policy,
|
||||
closed: make(chan struct{}),
|
||||
onConnect: config.ConnectHandler,
|
||||
onRequest: config.RequestHandler,
|
||||
onResponse: config.ResponseHandler,
|
||||
onError: config.ErrorHandler,
|
||||
}
|
||||
|
||||
var err error
|
||||
if config.Listen == "" {
|
||||
p.addr, err = net.ResolveTCPAddr("tcp", DefaultBindAddr)
|
||||
} else {
|
||||
p.addr, err = net.ResolveTCPAddr("tcp", config.Listen)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("proxy: invalid listen addres: %w", err)
|
||||
}
|
||||
if config.Bind != "" {
|
||||
if p.bind, err = net.ResolveTCPAddr("tcp", config.Bind+":0"); err != nil {
|
||||
return nil, fmt.Errorf("proxy: invalid bind address: %w", err)
|
||||
}
|
||||
} else if config.Interface != "" {
|
||||
if err = resolveInterfaceAddr(config.Interface); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if p.bind != nil {
|
||||
/* FIXME
|
||||
var c *net.TCPConn
|
||||
if c, err = net.DialTCP("tcp", p.bind, p.bind); err != nil && errors.Is(err, syscall.EADDRNOTAVAIL) {
|
||||
return nil, fmt.Errorf("proxy: invalid bind address: %w", syscall.EADDRNOTAVAIL)
|
||||
} else if c != nil {
|
||||
_ = c.Close()
|
||||
}
|
||||
*/
|
||||
}
|
||||
if config.Resolver != nil {
|
||||
p.resolver = config.Resolver
|
||||
}
|
||||
|
||||
dialTimeout := DefaultDialTimeout
|
||||
if config.DialTimeout > 0 {
|
||||
dialTimeout = config.DialTimeout
|
||||
}
|
||||
p.dial = (&net.Dialer{
|
||||
Timeout: dialTimeout,
|
||||
KeepAlive: dialTimeout,
|
||||
LocalAddr: p.bind,
|
||||
}).Dial
|
||||
|
||||
p.admin = NewAdmin(p)
|
||||
|
||||
if p.stats, err = stats.New(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func newTransport() *http.Transport {
|
||||
return &http.Transport{
|
||||
TLSNextProto: make(map[string]func(authority string, c *tls.Conn) http.RoundTripper),
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
TLSHandshakeTimeout: 15 * time.Second,
|
||||
ExpectContinueTimeout: 5 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Proxy) Close() error {
|
||||
select {
|
||||
case <-p.closed:
|
||||
return ErrClosed
|
||||
default:
|
||||
close(p.closed)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Proxy) Start() error {
|
||||
l, err := net.ListenTCP("tcp", p.addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
go p.Serve(l)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Proxy) Serve(listener net.Listener) error {
|
||||
defer func() { _ = listener.Close() }()
|
||||
|
||||
log.Info().Str("addr", listener.Addr().String()).Msg("proxy server listening")
|
||||
for {
|
||||
select {
|
||||
case <-p.closed:
|
||||
return nil
|
||||
default:
|
||||
}
|
||||
|
||||
c, err := listener.Accept()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rw := bufio.NewReadWriter(bufio.NewReader(c), bufio.NewWriter(c))
|
||||
ctx := newContext(c, rw, nil)
|
||||
|
||||
if c, ok := c.(*net.TCPConn); ok {
|
||||
_ = c.SetKeepAlive(true)
|
||||
_ = c.SetKeepAlivePeriod(DefaultKeepAlivePeriod)
|
||||
}
|
||||
|
||||
go p.handle(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Proxy) handle(ctx *Context) {
|
||||
logger := ctx.log()
|
||||
defer log.OnCloseError(logger.Debug(), ctx.conn)
|
||||
logger.Info().Str("client", ctx.RemoteAddr().String()).Msg("new client connection")
|
||||
|
||||
last := int64(0)
|
||||
for {
|
||||
select {
|
||||
case <-p.closed:
|
||||
return
|
||||
|
||||
default:
|
||||
ses, err := p.handleRequest(ctx)
|
||||
if ses != nil {
|
||||
log := ses.log()
|
||||
log.Info().
|
||||
Str("method", ses.request.Method).
|
||||
Str("url", ses.request.URL.String()).
|
||||
Str("status", ses.response.Status).
|
||||
Int64("size", ctx.conn.bytes-last).
|
||||
Msg("handled request")
|
||||
|
||||
p.stats.AddLog(&stats.Log{
|
||||
ClientIP: netutil.Host(ses.request.RemoteAddr),
|
||||
Request: stats.FromRequest(ses.request),
|
||||
Response: stats.FromResponse(ses.response).SetSize(ctx.conn.bytes - last),
|
||||
})
|
||||
|
||||
last = ctx.conn.bytes
|
||||
}
|
||||
if err != nil && !isClosing(err) || (ses != nil && ses.response != nil && ses.response.Close) {
|
||||
event := logger.Debug()
|
||||
if ctx.conn.bytes > 0 {
|
||||
event = event.Int64("size", ctx.conn.bytes)
|
||||
}
|
||||
event.Msg("closing client connection")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Proxy) handleRequest(ctx *Context) (ses *Session, err error) {
|
||||
logger := ctx.log()
|
||||
|
||||
var request *http.Request
|
||||
if request, err = p.readRequest(ctx); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
ses = newSession(ctx, request)
|
||||
p.cleanRequest(ses, request)
|
||||
|
||||
logger.Debug().Str("method", request.Method).Str("url", request.URL.String()).Msg("handle request")
|
||||
|
||||
if p.onRequest != nil {
|
||||
newRequest, newResponse := p.onRequest.HandleRequest(ses)
|
||||
if newRequest != nil {
|
||||
logger.Debug().Str("method", newRequest.Method).Str("url", newRequest.URL.String()).Msg("request override")
|
||||
ses.request = newRequest
|
||||
}
|
||||
if newResponse != nil {
|
||||
logger.Debug().Str("status", newResponse.Status).Msg("response override")
|
||||
ses.response = newResponse
|
||||
}
|
||||
}
|
||||
|
||||
if ses.response == nil {
|
||||
// WebSocket request
|
||||
if ses.request.Header.Get(HeaderUpgrade) == "websocket" {
|
||||
return ses, p.handleTunnel(ses)
|
||||
}
|
||||
|
||||
cleanHopByHopHeaders(ses.request.Header)
|
||||
|
||||
// Proxy CONNECT request
|
||||
if ses.request.Method == http.MethodConnect {
|
||||
return p.handleConnect(ses)
|
||||
}
|
||||
|
||||
if netutil.Port(ses.request.URL.Host) == p.addr.Port {
|
||||
// Plain API request
|
||||
ses.request.URL.Host = ses.request.Host
|
||||
return ses, p.admin.handleRequest(ses)
|
||||
|
||||
} else if ses.response, err = p.transport.RoundTrip(ses.request); err != nil {
|
||||
// Plain HTTP request
|
||||
if p.config.ErrorHandler != nil {
|
||||
p.config.ErrorHandler.HandleError(ses, err)
|
||||
}
|
||||
ses.response = ErrorResponse(ses.request, err)
|
||||
}
|
||||
|
||||
logger.Debug().Str("status", ses.response.Status).Msg("received response")
|
||||
cleanHopByHopHeaders(ses.response.Header)
|
||||
}
|
||||
|
||||
ses.response.Close = true
|
||||
defer log.OnCloseError(logger.Debug(), ses.response.Body)
|
||||
return ses, p.writeResponse(ses)
|
||||
}
|
||||
|
||||
func (p *Proxy) handleConnect(ses *Session) (next *Session, err error) {
|
||||
next = ses
|
||||
|
||||
logger := ses.log()
|
||||
logger.Debug().Msgf("connecting to %s", ses.request.URL.Host)
|
||||
|
||||
var c net.Conn
|
||||
if c, err = p.connect(ses, "tcp", ses.request.URL.Host); err != nil {
|
||||
logger.Error().Err(err).Msg("connect failed")
|
||||
if p.onError != nil {
|
||||
p.onError.HandleError(ses, err)
|
||||
}
|
||||
|
||||
ses.response = ErrorResponse(ses.request, err)
|
||||
defer log.OnCloseError(logger.Debug(), ses.response.Body)
|
||||
_ = p.writeResponse(ses)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err := c.Close(); err != nil {
|
||||
if p.onError != nil {
|
||||
p.onError.HandleError(ses, err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
if p.canIntercept(ses.request) {
|
||||
logger.Debug().Msg("intercepting connection")
|
||||
ses.response = NewResponse(http.StatusOK, nil, ses.request)
|
||||
err = p.writeResponse(ses)
|
||||
log.OnCloseError(logger.Debug(), ses.response.Body)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Peek first byte
|
||||
b := make([]byte, 1)
|
||||
if _, err = io.ReadFull(ses.ctx.rw, b); err != nil {
|
||||
logger.Error().Err(err).Msg("error peeking CONNECT byte")
|
||||
return
|
||||
}
|
||||
|
||||
// Drain buffered bytes
|
||||
b = append(b, make([]byte, ses.ctx.rw.Reader.Buffered())...)
|
||||
ses.ctx.rw.Reader.Read(b[1:])
|
||||
|
||||
r := &connReader{
|
||||
Conn: ses.ctx.conn,
|
||||
Reader: io.MultiReader(bytes.NewBuffer(b), ses.ctx.conn),
|
||||
}
|
||||
if b[0] == 22 { // TLS handshake: https://tools.ietf.org/html/rfc5246#section-6.2.1
|
||||
secure := tls.Server(r, p.authority.TLSConfig(ses.request.URL.Host))
|
||||
if err = secure.Handshake(); err != nil {
|
||||
logger.Error().Err(err).Msg("error intercepting TLS connection: client handshake failed")
|
||||
return
|
||||
}
|
||||
|
||||
rw := bufio.NewReadWriter(bufio.NewReader(secure), bufio.NewWriter(secure))
|
||||
ctx := newContext(secure, rw, ses)
|
||||
return p.handleRequest(ctx)
|
||||
}
|
||||
|
||||
rw := bufio.NewReadWriter(bufio.NewReader(r), bufio.NewWriter(r))
|
||||
ctx := newContext(r, rw, ses)
|
||||
return p.handleRequest(ctx)
|
||||
}
|
||||
|
||||
ses.response = NewResponse(http.StatusOK, nil, ses.request)
|
||||
defer log.OnCloseError(logger.Debug(), ses.response.Body)
|
||||
ses.response.ContentLength = -1
|
||||
if err = p.writeResponse(ses); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
logger.Debug().Msg("established CONNECT tunnel, proxying traffic")
|
||||
var wait sync.WaitGroup
|
||||
wait.Go(func() { copyStream(ses, c, ses.ctx.conn) })
|
||||
wait.Go(func() { copyStream(ses, ses.ctx.conn, c) })
|
||||
wait.Wait()
|
||||
logger.Debug().Msg("closed CONNECT tunnel")
|
||||
return
|
||||
}
|
||||
|
||||
func (p *Proxy) handleTunnel(ses *Session) (err error) {
|
||||
logger := ses.log()
|
||||
logger.Debug().Msgf("connecting to %s", ses.request.URL.Host)
|
||||
|
||||
var c net.Conn
|
||||
if c, err = p.connect(ses, "tcp", ses.request.URL.Host); err != nil {
|
||||
logger.Error().Err(err).Msg("connect failed")
|
||||
if p.onError != nil {
|
||||
p.onError.HandleError(ses, err)
|
||||
}
|
||||
|
||||
ses.response = ErrorResponse(ses.request, err)
|
||||
defer log.OnCloseError(logger.Debug(), ses.response.Body)
|
||||
_ = p.writeResponse(ses)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
defer log.OnCloseError(logger.Debug(), c)
|
||||
|
||||
if ses.ctx.IsTLS() {
|
||||
// Open a TLS client connection
|
||||
secure := tls.Client(c, &tls.Config{
|
||||
ServerName: ses.request.URL.Host,
|
||||
GetClientCertificate: func(cri *tls.CertificateRequestInfo) (*tls.Certificate, error) {
|
||||
return nil, ErrClientCert
|
||||
},
|
||||
})
|
||||
if err = secure.Handshake(); err != nil {
|
||||
logger.Error().Err(err).Msg("TLS handshake failed")
|
||||
return
|
||||
}
|
||||
c = secure
|
||||
}
|
||||
|
||||
if err = ses.request.Write(c); err != nil {
|
||||
logger.Error().Err(err).Msg("failed to write request")
|
||||
return
|
||||
}
|
||||
|
||||
logger.Debug().Msg("established tunnel, proxying traffic")
|
||||
var wait sync.WaitGroup
|
||||
wait.Go(func() { copyStream(ses, c, ses.ctx.conn) })
|
||||
wait.Go(func() { copyStream(ses, ses.ctx.conn, c) })
|
||||
wait.Wait()
|
||||
logger.Debug().Msg("closed tunnel")
|
||||
return
|
||||
}
|
||||
|
||||
func (p *Proxy) canIntercept(request *http.Request) bool {
|
||||
if permit := p.policy.PermitIntercept(request); permit != nil {
|
||||
return *permit
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
/*
|
||||
func (p *Proxy) handleAPIRequest(ses *Session) error {
|
||||
if ses.request.URL.Path == "/ca.crt" && p.authority != nil {
|
||||
b := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: p.authority.Certificate().Raw,
|
||||
})
|
||||
|
||||
ses.response = NewResponse(http.StatusOK, bytes.NewReader(b), ses.request)
|
||||
defer log.OnCloseError(logger.Debug(), ses.response.Body)
|
||||
|
||||
ses.response.Close = true
|
||||
ses.response.Header.Set("Content-Type", "application/x-x509-ca-cert")
|
||||
ses.response.ContentLength = int64(len(b))
|
||||
return p.writeResponse(ses)
|
||||
}
|
||||
|
||||
ses.response = ErrorResponse(ses.request, errors.New("invalid API endpoint"))
|
||||
defer log.OnCloseError(logger.Debug(), ses.response.Body)
|
||||
ses.response.Close = true
|
||||
return p.writeResponse(ses)
|
||||
}
|
||||
*/
|
||||
|
||||
func (p *Proxy) readRequest(ctx *Context) (request *http.Request, err error) {
|
||||
var (
|
||||
done = make(chan *http.Request, 1)
|
||||
errs = make(chan error, 1)
|
||||
)
|
||||
|
||||
go func() {
|
||||
r, err := http.ReadRequest(ctx.rw.Reader)
|
||||
if err != nil {
|
||||
errs <- err
|
||||
} else {
|
||||
done <- r
|
||||
}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-p.closed:
|
||||
return nil, ErrClosed
|
||||
case request = <-done:
|
||||
return
|
||||
case err = <-errs:
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Proxy) cleanRequest(ses *Session, request *http.Request) {
|
||||
if request.URL.Host == "" {
|
||||
request.URL.Host = request.Host
|
||||
}
|
||||
|
||||
// Ensure proper URL scheme
|
||||
if !strings.HasPrefix(request.URL.Scheme, "http") {
|
||||
request.URL.Scheme = "http"
|
||||
}
|
||||
if ses.ctx.IsTLS() {
|
||||
state := ses.ctx.conn.Conn.(*tls.Conn).ConnectionState()
|
||||
request.TLS = &state
|
||||
request.URL.Scheme = "https"
|
||||
}
|
||||
|
||||
// Ensure proper RemoteAddr
|
||||
request.RemoteAddr = ses.ctx.RemoteAddr().String()
|
||||
|
||||
// Ensure proper encoding
|
||||
if request.Header.Get(HeaderAcceptEncoding) != "" {
|
||||
// We only support gzip
|
||||
request.Header.Set(HeaderAcceptEncoding, "gzip")
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Proxy) writeResponse(ses *Session) (err error) {
|
||||
log := ses.log()
|
||||
|
||||
if p.onResponse != nil {
|
||||
response := p.onResponse.HandleResponse(ses)
|
||||
if response != nil {
|
||||
log.Debug().Str("status", response.Status).Msg("response override")
|
||||
ses.response = response
|
||||
}
|
||||
}
|
||||
|
||||
if err = ses.response.Write(ses.ctx); err != nil {
|
||||
log.Error().Err(err).Msg("error writing response back to client")
|
||||
} else if err = ses.ctx.Flush(); err != nil {
|
||||
log.Error().Err(err).Msg("error flushing response back to client")
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (p *Proxy) connect(ses *Session, network, address string) (c net.Conn, err error) {
|
||||
log := ses.log()
|
||||
log.Debug().Msgf("connect to %s://%s", network, address)
|
||||
|
||||
if p.onConnect != nil {
|
||||
if c = p.onConnect.HandleConnect(ses, network, address); c != nil {
|
||||
log.Debug().Msg("connect override")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
var host, port string
|
||||
if host, port, err = net.SplitHostPort(address); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var hosts []string
|
||||
if hosts, err = p.resolver.Lookup(context.Background(), host); err != nil {
|
||||
log.Warn().Err(err).Msg("connect failed: DNS lookup error")
|
||||
return
|
||||
}
|
||||
|
||||
log.Debug().Str("address", hosts[0]).Msg("connect resolved address")
|
||||
return p.dial(network, net.JoinHostPort(hosts[0], port))
|
||||
}
|
||||
|
||||
var hopByHopHeaders = []string{
|
||||
HeaderConnection,
|
||||
"Keep-Alive",
|
||||
"Proxy-Authenticate",
|
||||
"Proxy-Authorization",
|
||||
"Proxy-Connection", // Non-standard, but required for HTTP/2.
|
||||
"Te",
|
||||
"Trailer",
|
||||
"Transfer-Encoding",
|
||||
HeaderUpgrade,
|
||||
}
|
||||
|
||||
func cleanHopByHopHeaders(header http.Header) {
|
||||
// Additional hop-by-hop headers may be specified in `Connection` headers.
|
||||
// http://tools.ietf.org/html/draft-ietf-httpbis-p1-messaging-14#section-9.1
|
||||
for _, values := range header[HeaderConnection] {
|
||||
for _, key := range strings.Split(values, ",") {
|
||||
header.Del(key)
|
||||
}
|
||||
}
|
||||
for _, key := range hopByHopHeaders {
|
||||
header.Del(key)
|
||||
}
|
||||
}
|
||||
|
||||
// copyStream copies data from reader to writer
|
||||
func copyStream(ses *Session, w io.Writer, r io.Reader) {
|
||||
log := ses.log()
|
||||
if _, err := io.Copy(w, r); err != nil && !isClosing(err) {
|
||||
log.Error().Err(err).Msg("failed CONNECT tunnel")
|
||||
} else {
|
||||
log.Debug().Msg("finished copying CONNECT tunnel")
|
||||
}
|
||||
}
|
||||
|
||||
func isClosing(err error) bool {
|
||||
if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) || errors.Is(err, syscall.ECONNRESET) || err == ErrClosed {
|
||||
return true
|
||||
}
|
||||
if err, ok := err.(net.Error); ok && err.Timeout() {
|
||||
return true
|
||||
}
|
||||
// log.Debug().Msgf("not a closing error %T: %#+v", err, err)
|
||||
return false
|
||||
}
|
||||
|
||||
func resolveInterfaceAddr(name string) (err error) {
|
||||
var iface *net.Interface
|
||||
if iface, err = net.InterfaceByName(name); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var addrs []net.Addr
|
||||
if addrs, err = iface.Addrs(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
for _, addr := range addrs {
|
||||
if addr, ok := addr.(*net.IPNet); ok && !addr.IP.IsUnspecified() {
|
||||
log.Warn().Msgf("addr %T: %s", addr, addr)
|
||||
}
|
||||
}
|
||||
return errors.New("nope; TODO")
|
||||
}
|
148
proxy/resolver/resolver.go
Normal file
148
proxy/resolver/resolver.go
Normal file
@@ -0,0 +1,148 @@
|
||||
// Package resolver implements a caching DNS resolver
|
||||
package resolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"math/rand/v2"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.maze.io/maze/styx/internal/netutil"
|
||||
"github.com/hashicorp/golang-lru/v2/expirable"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultSize = 1024
|
||||
DefaultTTL = 5 * time.Minute
|
||||
DefaultTimeout = 10 * time.Second
|
||||
)
|
||||
|
||||
var (
|
||||
// DefaultConfig are the defaults for the Default resolver.
|
||||
DefaultConfig = Config{
|
||||
Size: DefaultSize,
|
||||
TTL: DefaultTTL.Seconds(),
|
||||
Timeout: DefaultTimeout.Seconds(),
|
||||
}
|
||||
|
||||
// Default resolver.
|
||||
Default = New(DefaultConfig)
|
||||
)
|
||||
|
||||
type Resolver interface {
|
||||
// Lookup returns resolved IPs for given hostname/ips.
|
||||
Lookup(context.Context, string) ([]string, error)
|
||||
}
|
||||
|
||||
type netResolver struct {
|
||||
resolver *net.Resolver
|
||||
timeout time.Duration
|
||||
noIPv6 bool
|
||||
cache *expirable.LRU[string, []string]
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
// Size is our cache size in number of entries.
|
||||
Size int `hcl:"size,optional"`
|
||||
|
||||
// TTL is the cache time to live in seconds.
|
||||
TTL float64 `hcl:"ttl,optional"`
|
||||
|
||||
// Timeout is the cache timeout in seconds.
|
||||
Timeout float64 `hcl:"timeout,optional"`
|
||||
|
||||
// Server are alternative DNS servers.
|
||||
Server []string `hcl:"server,optional"`
|
||||
|
||||
// NoIPv6 disables IPv6 DNS resolution.
|
||||
NoIPv6 bool `hcl:"noipv6,optional"`
|
||||
}
|
||||
|
||||
func New(config Config) Resolver {
|
||||
var (
|
||||
size = config.Size
|
||||
ttl = time.Duration(float64(time.Second) * config.TTL)
|
||||
timeout = time.Duration(float64(time.Second) * config.Timeout)
|
||||
)
|
||||
if size <= 0 {
|
||||
size = DefaultSize
|
||||
}
|
||||
if ttl <= 0 {
|
||||
ttl = DefaultTTL
|
||||
}
|
||||
if timeout <= 0 {
|
||||
timeout = 0
|
||||
}
|
||||
|
||||
var resolver = new(net.Resolver)
|
||||
if len(config.Server) > 0 {
|
||||
var dialer net.Dialer
|
||||
resolver.Dial = func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
server := netutil.EnsurePort(config.Server[rand.IntN(len(config.Server))], "53")
|
||||
return dialer.DialContext(ctx, network, server)
|
||||
}
|
||||
}
|
||||
|
||||
return &netResolver{
|
||||
resolver: resolver,
|
||||
timeout: timeout,
|
||||
noIPv6: config.NoIPv6,
|
||||
cache: expirable.NewLRU[string, []string](size, nil, ttl),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *netResolver) Lookup(ctx context.Context, host string) ([]string, error) {
|
||||
host = strings.ToLower(strings.TrimSpace(host))
|
||||
if hosts, ok := r.cache.Get(host); ok {
|
||||
rand.Shuffle(len(hosts), func(i, j int) {
|
||||
hosts[i], hosts[j] = hosts[j], hosts[i]
|
||||
})
|
||||
return hosts, nil
|
||||
}
|
||||
|
||||
hosts, err := r.lookup(ctx, host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
r.cache.Add(host, hosts)
|
||||
return hosts, nil
|
||||
}
|
||||
|
||||
func (r *netResolver) lookup(ctx context.Context, host string) ([]string, error) {
|
||||
if r.timeout > 0 {
|
||||
var cancel func()
|
||||
ctx, cancel = context.WithTimeout(ctx, r.timeout)
|
||||
defer cancel()
|
||||
}
|
||||
|
||||
if net.ParseIP(host) == nil {
|
||||
addrs, err := r.resolver.LookupHost(ctx, host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if r.noIPv6 {
|
||||
var addrs4 []string
|
||||
for _, addr := range addrs {
|
||||
if net.ParseIP(addr).To4() != nil {
|
||||
addrs4 = append(addrs4, addr)
|
||||
}
|
||||
}
|
||||
return addrs4, nil
|
||||
}
|
||||
return addrs, nil
|
||||
}
|
||||
|
||||
addrs, err := r.resolver.LookupIPAddr(ctx, host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
hosts := make([]string, len(addrs))
|
||||
for i, addr := range addrs {
|
||||
if !r.noIPv6 || addr.IP.To4() != nil {
|
||||
hosts[i] = addr.IP.String()
|
||||
}
|
||||
}
|
||||
return hosts, nil
|
||||
}
|
78
proxy/response.go
Normal file
78
proxy/response.go
Normal file
@@ -0,0 +1,78 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"strconv"
|
||||
|
||||
"git.maze.io/maze/styx/internal/log"
|
||||
)
|
||||
|
||||
func NewResponse(code int, body io.Reader, request *http.Request) *http.Response {
|
||||
if body == nil {
|
||||
body = new(bytes.Buffer)
|
||||
}
|
||||
|
||||
rc, ok := body.(io.ReadCloser)
|
||||
if !ok {
|
||||
rc = io.NopCloser(body)
|
||||
}
|
||||
|
||||
response := &http.Response{
|
||||
Status: strconv.Itoa(code) + " " + http.StatusText(code),
|
||||
StatusCode: code,
|
||||
Proto: "HTTP/1.1",
|
||||
ProtoMajor: 1,
|
||||
ProtoMinor: 1,
|
||||
Header: make(http.Header),
|
||||
Body: rc,
|
||||
Request: request,
|
||||
}
|
||||
|
||||
if request != nil {
|
||||
response.Close = request.Close
|
||||
response.Proto = request.Proto
|
||||
response.ProtoMajor = request.ProtoMajor
|
||||
response.ProtoMinor = request.ProtoMinor
|
||||
}
|
||||
|
||||
return response
|
||||
}
|
||||
|
||||
type withLen interface {
|
||||
Len() int
|
||||
}
|
||||
|
||||
type withSize interface {
|
||||
Size() int64
|
||||
}
|
||||
|
||||
func NewJSONResponse(code int, body io.Reader, request *http.Request) *http.Response {
|
||||
response := NewResponse(code, body, request)
|
||||
response.Header.Set(HeaderContentType, "application/json")
|
||||
if s, ok := body.(withLen); ok {
|
||||
response.Header.Set(HeaderContentLength, strconv.Itoa(s.Len()))
|
||||
} else if s, ok := body.(withSize); ok {
|
||||
response.Header.Set(HeaderContentLength, strconv.FormatInt(s.Size(), 10))
|
||||
} else {
|
||||
log.Trace().Str("type", fmt.Sprintf("%T", body)).Msg("can't detemine body size")
|
||||
}
|
||||
response.Close = true
|
||||
return response
|
||||
}
|
||||
|
||||
func ErrorResponse(request *http.Request, err error) *http.Response {
|
||||
response := NewResponse(http.StatusBadGateway, nil, request)
|
||||
switch {
|
||||
case os.IsNotExist(err):
|
||||
response.StatusCode = http.StatusNotFound
|
||||
case os.IsPermission(err):
|
||||
response.StatusCode = http.StatusForbidden
|
||||
}
|
||||
response.Status = http.StatusText(response.StatusCode)
|
||||
response.Close = true
|
||||
return response
|
||||
}
|
151
proxy/session.go
Normal file
151
proxy/session.go
Normal file
@@ -0,0 +1,151 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"crypto/tls"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"git.maze.io/maze/styx/internal/log"
|
||||
)
|
||||
|
||||
var seed = rand.NewSource(time.Now().UnixNano())
|
||||
|
||||
type Context struct {
|
||||
id int64
|
||||
conn *wrappedConn
|
||||
rw *bufio.ReadWriter
|
||||
parent *Session
|
||||
data map[string]any
|
||||
}
|
||||
|
||||
func newContext(conn net.Conn, rw *bufio.ReadWriter, parent *Session) *Context {
|
||||
if wrapped, ok := conn.(*wrappedConn); ok {
|
||||
conn = wrapped.Conn
|
||||
}
|
||||
|
||||
ctx := &Context{
|
||||
id: seed.Int63(),
|
||||
conn: &wrappedConn{Conn: conn},
|
||||
rw: rw,
|
||||
parent: parent,
|
||||
data: make(map[string]any),
|
||||
}
|
||||
|
||||
return ctx
|
||||
}
|
||||
|
||||
func (ctx *Context) log() log.Logger {
|
||||
return log.Console.With().
|
||||
Str("context", ctx.ID()).
|
||||
Str("addr", ctx.RemoteAddr().String()).
|
||||
Logger()
|
||||
}
|
||||
|
||||
func (ctx *Context) ID() string {
|
||||
var b [8]byte
|
||||
binary.BigEndian.PutUint64(b[:], uint64(ctx.id))
|
||||
if ctx.parent != nil {
|
||||
return ctx.parent.ID() + "-" + hex.EncodeToString(b[:])
|
||||
}
|
||||
return hex.EncodeToString(b[:])
|
||||
}
|
||||
|
||||
func (ctx *Context) IsTLS() bool {
|
||||
_, ok := ctx.conn.Conn.(*tls.Conn)
|
||||
return ok && ctx.parent != nil
|
||||
}
|
||||
|
||||
func (ctx *Context) RemoteAddr() net.Addr {
|
||||
if ctx.parent != nil {
|
||||
return ctx.parent.ctx.RemoteAddr()
|
||||
}
|
||||
return ctx.conn.RemoteAddr()
|
||||
}
|
||||
|
||||
func (ctx *Context) SetDeadline(t time.Time) error {
|
||||
if ctx.parent != nil {
|
||||
return ctx.parent.ctx.SetDeadline(t)
|
||||
}
|
||||
return ctx.conn.SetDeadline(t)
|
||||
}
|
||||
|
||||
func (ctx *Context) Set(key string, value any) {
|
||||
ctx.data[key] = value
|
||||
}
|
||||
|
||||
func (ctx *Context) Get(key string) (value any, ok bool) {
|
||||
value, ok = ctx.data[key]
|
||||
return
|
||||
}
|
||||
|
||||
func (ctx *Context) Flush() error {
|
||||
return ctx.rw.Flush()
|
||||
}
|
||||
|
||||
func (ctx *Context) Write(p []byte) (n int, err error) {
|
||||
if n, err = ctx.rw.Write(p); n > 0 {
|
||||
atomic.AddInt64(&ctx.conn.bytes, int64(n))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
type Session struct {
|
||||
id int64
|
||||
ctx *Context
|
||||
request *http.Request
|
||||
response *http.Response
|
||||
data map[string]any
|
||||
}
|
||||
|
||||
func newSession(ctx *Context, request *http.Request) *Session {
|
||||
return &Session{
|
||||
id: seed.Int63(),
|
||||
ctx: ctx,
|
||||
request: request,
|
||||
data: make(map[string]any),
|
||||
}
|
||||
}
|
||||
|
||||
func (ses *Session) log() log.Logger {
|
||||
return log.Console.With().
|
||||
Str("context", ses.ctx.ID()).
|
||||
Str("session", ses.ID()).
|
||||
Str("addr", ses.ctx.RemoteAddr().String()).
|
||||
Logger()
|
||||
}
|
||||
|
||||
func (ses *Session) ID() string {
|
||||
var b [8]byte
|
||||
binary.BigEndian.PutUint64(b[:], uint64(ses.id))
|
||||
return hex.EncodeToString(b[:])
|
||||
}
|
||||
|
||||
func (ses *Session) Context() *Context {
|
||||
return ses.ctx
|
||||
}
|
||||
|
||||
func (ses *Session) Request() *http.Request {
|
||||
return ses.request
|
||||
}
|
||||
|
||||
func (ses *Session) Response() *http.Response {
|
||||
return ses.response
|
||||
}
|
||||
|
||||
type wrappedConn struct {
|
||||
net.Conn
|
||||
bytes int64
|
||||
}
|
||||
|
||||
func (c *wrappedConn) Write(p []byte) (n int, err error) {
|
||||
if n, err = c.Conn.Write(p); n > 0 {
|
||||
atomic.AddInt64(&c.bytes, int64(n))
|
||||
}
|
||||
return
|
||||
}
|
225
proxy/stats/stats.go
Normal file
225
proxy/stats/stats.go
Normal file
@@ -0,0 +1,225 @@
|
||||
package stats
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"git.maze.io/maze/styx/internal/log"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
type Stats struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
func New() (*Stats, error) {
|
||||
u, err := user.Current()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
path := filepath.Join(u.HomeDir, ".styx", "stats.db")
|
||||
if err = os.MkdirAll(filepath.Dir(path), 0o750); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
db, err := sql.Open("sqlite3", path+"?_journal_mode=WAL")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, table := range []string{
|
||||
createLog,
|
||||
createDomainStat,
|
||||
createStatusStat,
|
||||
} {
|
||||
if _, err = db.Exec(table); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &Stats{db: db}, nil
|
||||
}
|
||||
|
||||
func (s *Stats) AddLog(entry *Log) error {
|
||||
var (
|
||||
request []byte
|
||||
response []byte
|
||||
err error
|
||||
)
|
||||
if request, err = json.Marshal(entry.Request); err != nil {
|
||||
return err
|
||||
}
|
||||
if response, err = json.Marshal(entry.Response); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tx, err := s.db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
stmt, err := tx.Prepare("insert into styx_log(client_ip, request, response) values(?, ?, ?)")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
if _, err = stmt.Exec(entry.ClientIP, request, response); err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func (s *Stats) QueryLog(offset, limit int) ([]*Log, error) {
|
||||
if limit == 0 {
|
||||
limit = 50
|
||||
}
|
||||
|
||||
rows, err := s.db.Query("select dt, client_ip, request, response from styx_log limit ?, ?", offset, limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var logs []*Log
|
||||
for rows.Next() {
|
||||
var entry = new(Log)
|
||||
if err = rows.Scan(&entry.Time, &entry.ClientIP, &entry.Request, &entry.Response); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
logs = append(logs, entry)
|
||||
}
|
||||
|
||||
return logs, nil
|
||||
}
|
||||
|
||||
type Status struct {
|
||||
Code int `json:"code"`
|
||||
Count int `json:"count"`
|
||||
}
|
||||
|
||||
var timeZero time.Time
|
||||
|
||||
func (s *Stats) QueryStatus(since time.Time) ([]*Status, error) {
|
||||
if since.Equal(timeZero) {
|
||||
since = time.Now().Add(-24 * time.Hour)
|
||||
}
|
||||
|
||||
rows, err := s.db.Query("select response->'status', count(*) from styx_log where dt >= ? group by response->'status' order by response->'status'", since)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var stats []*Status
|
||||
for rows.Next() {
|
||||
var entry = new(Status)
|
||||
if err = rows.Scan(&entry.Code, &entry.Count); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stats = append(stats, entry)
|
||||
}
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
const createLog = `CREATE TABLE IF NOT EXISTS styx_log (
|
||||
id INT PRIMARY KEY,
|
||||
dt DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
client_ip TEXT NOT NULL,
|
||||
request JSONB NOT NULL,
|
||||
response JSONB NOT NULL
|
||||
);`
|
||||
|
||||
type Log struct {
|
||||
Time time.Time `json:"time"`
|
||||
ClientIP string `json:"client_ip"`
|
||||
Request *Request `json:"request"`
|
||||
Response *Response `json:"response"`
|
||||
}
|
||||
|
||||
type Request struct {
|
||||
URL string `json:"url"`
|
||||
Host string `json:"host"`
|
||||
Method string `json:"method"`
|
||||
Proto string `json:"proto"`
|
||||
Header http.Header `json:"header"`
|
||||
}
|
||||
|
||||
func (r *Request) Scan(value any) error {
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
return json.Unmarshal([]byte(v), r)
|
||||
case []byte:
|
||||
return json.Unmarshal(v, r)
|
||||
default:
|
||||
log.Error().Str("type", fmt.Sprintf("%T", value)).Msg("scan request unknown type")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Request) Value() (driver.Value, error) {
|
||||
b, err := json.Marshal(r)
|
||||
return string(b), err
|
||||
}
|
||||
|
||||
func FromRequest(r *http.Request) *Request {
|
||||
return &Request{
|
||||
URL: r.URL.String(),
|
||||
Host: r.Host,
|
||||
Method: r.Method,
|
||||
Proto: r.Proto,
|
||||
Header: r.Header,
|
||||
}
|
||||
}
|
||||
|
||||
type Response struct {
|
||||
Status int `json:"status"`
|
||||
Size int64 `json:"size"`
|
||||
Header http.Header `json:"header"`
|
||||
}
|
||||
|
||||
func (r *Response) Scan(value any) error {
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
return json.Unmarshal([]byte(v), r)
|
||||
case []byte:
|
||||
return json.Unmarshal(v, r)
|
||||
default:
|
||||
log.Error().Str("type", fmt.Sprintf("%T", value)).Msg("scan response unknown type")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Response) Value() (driver.Value, error) {
|
||||
b, err := json.Marshal(r)
|
||||
return string(b), err
|
||||
}
|
||||
|
||||
func (r *Response) SetSize(size int64) *Response {
|
||||
r.Size = size
|
||||
return r
|
||||
}
|
||||
|
||||
func FromResponse(r *http.Response) *Response {
|
||||
return &Response{
|
||||
Status: r.StatusCode,
|
||||
Header: r.Header,
|
||||
}
|
||||
}
|
||||
|
||||
const createStatusStat = `CREATE TABLE IF NOT EXISTS styx_stat_status (
|
||||
id INT PRIMARY KEY,
|
||||
dt DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
status INT NOT NULL
|
||||
);`
|
||||
|
||||
const createDomainStat = `CREATE TABLE IF NOT EXISTS styx_stat_domain (
|
||||
id INT PRIMARY KEY,
|
||||
dt DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
domain TEXT NOT NULL
|
||||
);`
|
16
proxy/util.go
Normal file
16
proxy/util.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
)
|
||||
|
||||
// connReader is a net.Conn with a separate reader.
|
||||
type connReader struct {
|
||||
net.Conn
|
||||
io.Reader
|
||||
}
|
||||
|
||||
func (c connReader) Read(p []byte) (int, error) {
|
||||
return c.Reader.Read(p)
|
||||
}
|
Reference in New Issue
Block a user