Files
styx/policy/handler.go
2025-10-08 20:57:13 +02:00

143 lines
4.0 KiB
Go

package policy
import (
"bufio"
"crypto/tls"
"net"
"net/http"
"git.maze.io/maze/styx/ca"
"git.maze.io/maze/styx/internal/netutil"
"git.maze.io/maze/styx/logger"
proxy "git.maze.io/maze/styx/proxy"
)
// NewRequestHandler returns a proxy.RequestHandler that evaluates policy
// against the request held by the proxy context.
//
// The handler never rewrites the request: its first return value is always
// nil. When the policy result yields a non-nil response, that response is
// returned as the replacement HTTP response. Errors from evaluating the policy
// or from generating the response are logged, and the handler then returns
// (nil, nil).
func NewRequestHandler(policy *Policy) proxy.RequestHandler {
	log := logger.StandardLog.Value("policy", policy.name)
	return proxy.RequestHandlerFunc(func(ctx proxy.Context) (*http.Request, *http.Response) {
		input := NewInputFromRequest(ctx, ctx.Request())
		input.logValues(log).Trace("Running request handler")

		result, err := policy.Query(input, proxy.PolicyQueryOptions(ctx)...)
		if err != nil {
			log.Err(err).Error("Error evaluating policy")
			return nil, nil
		}

		r, err := result.Response(ctx)
		if err != nil {
			log.Err(err).Error("Error generating response")
			return nil, nil
		}

		// Only announce a replacement when the policy actually produced one.
		if r != nil {
			log.Debug("Replacing HTTP response from policy")
		}
		return nil, r
	})
}
// NewDialHandler returns a proxy.DialHandler that evaluates policy before the
// proxy dials an upstream connection for req.
//
// When policy evaluation fails, response generation fails, or the result
// carries no response, the handler returns (nil, nil). Otherwise it returns
// the client end of a netutil loopback connection. A goroutine serves the
// other end. For https/wss URLs or port 443, it may first terminate TLS via
// maybeUpgradeToTLS. It then reads one HTTP request and writes the policy
// response before closing the connection.
func NewDialHandler(policy *Policy) proxy.DialHandler {
	log := logger.StandardLog.Value("policy", policy.name)
	return proxy.DialHandlerFunc(func(ctx proxy.Context, req *http.Request) (net.Conn, error) {
		input := NewInputFromRequest(ctx, req)
		input.logValues(log).Trace("Running dial handler")

		result, err := policy.Query(input, proxy.PolicyQueryOptions(ctx)...)
		if err != nil {
			log.Err(err).Error("Error evaluating policy")
			return nil, nil
		}
		r, err := result.Response(ctx)
		if err != nil {
			log.Err(err).Error("Error generating response")
			return nil, nil
		}
		if r == nil {
			return nil, nil
		}

		// Create a fake loopback connection: the client end is returned in
		// place of a real upstream connection, and the server end answers with
		// the policy response.
		pipe := netutil.NewLoopback()
		go func(c net.Conn) {
			// The closure captures the variable c, so if c is replaced by a TLS
			// wrapper below, the wrapper is what gets closed.
			defer func() { _ = c.Close() }()
			if req.URL.Scheme == "https" || req.URL.Scheme == "wss" || netutil.Port(req.URL.Host) == 443 {
				c = maybeUpgradeToTLS(c, ctx, req, log)
			}
			if _, err := http.ReadRequest(bufio.NewReader(c)); err != nil {
				// The response is still written below, as before.
				log.Err(err).Warn("Malformed HTTP request in MITM connection")
			}
			if err := r.Write(c); err != nil {
				log.Err(err).Warn("Failed to write policy response to MITM connection")
			}
		}(pipe.Server)
		return pipe.Client, nil
	})
}
// maybeUpgradeToTLS terminates TLS on c with a certificate obtained from the
// certificate authority exposed by ctx, so the caller can read plaintext HTTP
// from the returned connection.
//
// If ctx does not implement proxy.WithCertificateAuthority, or the authority it
// returns is nil, c is returned unchanged. If the TLS handshake fails, the
// failure is logged and the original c is returned.
func maybeUpgradeToTLS(c net.Conn, ctx proxy.Context, req *http.Request, log logger.Structured) net.Conn {
	// Named "authority" rather than "ca" so the imported ca package is not shadowed.
	var authority ca.CertificateAuthority
	if caCtx, ok := ctx.(proxy.WithCertificateAuthority); ok {
		authority = caCtx.CertificateAuthority()
	}
	if authority == nil {
		return c
	}

	commonName := netutil.Host(req.URL.Host)
	secure := tls.Server(c, &tls.Config{
		GetCertificate: func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
			// ServerName is only set when the client uses SNI; do not ask the CA
			// to issue a certificate for an empty DNS name.
			var names []string
			if hello.ServerName != "" {
				names = []string{hello.ServerName}
			}
			log.Values(logger.Values{
				"cn":    commonName,
				"names": hello.ServerName,
			}).Debug("Requesting certificate from CA")
			return authority.GetCertificate(commonName, names, nil)
		},
		// Only HTTP/1.1 is offered; the decrypted stream is parsed with
		// http.ReadRequest by the dial handler.
		NextProtos: []string{"http/1.1"},
	})
	if err := secure.Handshake(); err != nil {
		log.Err(err).Warn("Failed to pretend secure HTTP")
		return c
	}
	return secure
}
// NewForwardHandler returns a proxy.ForwardHandler that evaluates p against
// the request about to be forwarded.
//
// It returns the response produced by the policy result, or nil when the
// result carries none. A policy evaluation error is logged and the handler
// returns (nil, nil). An error while generating the response is logged and
// returned to the caller.
func NewForwardHandler(p *Policy) proxy.ForwardHandler {
	log := logger.StandardLog.Value("policy", p.name)
	return proxy.ForwardHandlerFunc(func(ctx proxy.Context, req *http.Request) (*http.Response, error) {
		input := NewInputFromRequest(ctx, req)
		input.logValues(log).Trace("Running forward handler")

		// Query with the same context-derived options that the request and
		// dial handlers use, so every stage queries the policy the same way.
		result, err := p.Query(input, proxy.PolicyQueryOptions(ctx)...)
		if err != nil {
			log.Err(err).Error("Error evaluating policy")
			return nil, nil
		}

		r, err := result.Response(ctx)
		if err != nil {
			log.Err(err).Error("Error generating response")
			return nil, err
		}
		if r != nil {
			log.Debug("Replacing HTTP response from policy")
		}
		return r, nil
	})
}
// NewResponseHandler returns a proxy.ResponseHandler that evaluates p against
// the response held by the proxy context.
//
// It returns the response produced by the policy result as the replacement
// response, or nil when the result carries none. Errors from evaluating the
// policy or from generating the response are logged, and the handler then
// returns nil.
func NewResponseHandler(p *Policy) proxy.ResponseHandler {
	log := logger.StandardLog.Value("policy", p.name)
	return proxy.ResponseHandlerFunc(func(ctx proxy.Context) *http.Response {
		input := NewInputFromResponse(ctx, ctx.Response())
		input.logValues(log).Trace("Running response handler")

		// Query with the same context-derived options that the request and
		// dial handlers use, so every stage queries the policy the same way.
		result, err := p.Query(input, proxy.PolicyQueryOptions(ctx)...)
		if err != nil {
			log.Err(err).Error("Error evaluating policy")
			return nil
		}

		r, err := result.Response(ctx)
		if err != nil {
			log.Err(err).Error("Error generating response")
			return nil
		}
		if r != nil {
			log.Debug("Replacing HTTP response from policy")
		}
		return r
	})
}