Files
styx/policy/input.go
2025-10-08 20:57:13 +02:00

426 lines
9.6 KiB
Go

package policy
import (
	"crypto/tls"
	"crypto/x509"
	"crypto/x509/pkix"
	"fmt"
	"net"
	"net/http"
	"net/url"
	"strconv"
	"time"

	"git.maze.io/maze/styx/internal/netutil"
	"git.maze.io/maze/styx/logger"
	proxy "git.maze.io/maze/styx/proxy"
)
// Input represents the input to the policy query.
//
// It bundles free-form connection context, the client endpoint, groups, the
// TLS state and, when available, the HTTP request and/or response. Sections
// that are not populated encode as JSON null.
type Input struct {
	Context  map[string]any `json:"context"`
	Client   *Client        `json:"client"`
	Groups   []*Group       `json:"groups"`
	TLS      *TLS           `json:"tls"`
	Request  *Request       `json:"request"`
	Response *Response      `json:"response"`
}

// logValues attaches the input's context map and the structured values of
// each section (client, TLS, request, response) to log and returns the
// resulting logger.
//
// It is safe to call on a nil *Input (NewInputFromRequest and
// NewInputFromResponse return nil for a nil message); log is then returned
// unchanged, matching the nil-safe logValues of the individual sections.
func (i *Input) logValues(log logger.Structured) logger.Structured {
	if i == nil {
		return log
	}
	if i.Context != nil {
		log = log.Values(i.Context)
	}
	log = i.Client.logValues(log)
	log = i.TLS.logValues(log)
	log = i.Request.logValues(log)
	log = i.Response.logValues(log)
	return log
}
// NewInputFromConn builds a policy Input describing the connection c.
//
// The returned Input always has a non-nil Context map, so callers can add
// entries without checking first; when c is nil that empty map is the only
// populated field. Otherwise Client is derived from c's remote address and
// TLS from c's TLS state (nil if c does not expose one). When c also
// implements proxy.Context, its local address and byte counters are stored
// under the "local", "bytes_rx" and "bytes_tx" context keys.
func NewInputFromConn(c net.Conn) *Input {
	// Allocate Context up front so it is never nil, even without a connection.
	input := &Input{Context: make(map[string]any)}
	if c == nil {
		return input
	}
	input.Client = NewClientFromConn(c)
	input.TLS = NewTLSFromConn(c)
	// NOTE(review): disabled lookup of dataset client metadata; re-enable or remove.
	/*
		if wcl, ok := c.(dataset.WithClient); ok {
			client, err := wcl.Client()
			if err == nil {
				input.Context["client_id"] = client.ID
				input.Context["client_description"] = client.Description
				input.Context["groups"] = client.Groups
			}
		}
	*/
	if pc, ok := c.(proxy.Context); ok {
		input.Context["local"] = NewClientFromAddr(pc.LocalAddr())
		input.Context["bytes_rx"] = pc.BytesRead()
		input.Context["bytes_tx"] = pc.BytesSent()
	}
	return input
}
// NewInputFromRequest builds the policy Input for an HTTP request received
// on c: the connection-derived fields plus the converted request. It returns
// nil when r is nil.
func NewInputFromRequest(c net.Conn, r *http.Request) *Input {
	var in *Input
	if r != nil {
		in = NewInputFromConn(c)
		in.Request = NewRequest(r)
	}
	return in
}
// NewInputFromResponse builds the policy Input for an HTTP response seen on
// c: the connection-derived fields plus the converted response. It returns
// nil when r is nil.
func NewInputFromResponse(c net.Conn, r *http.Response) *Input {
	var in *Input
	if r != nil {
		in = NewInputFromConn(c)
		in.Response = NewResponse(r)
	}
	return in
}
// Client describes a network endpoint: its network name, host and port.
//
// IP holds the host part of the address, which is not necessarily an IP
// literal (for address types without a host:port form it is the whole
// address string). Port is 0 when the address has no port or it could not
// be resolved.
type Client struct {
	Network string `json:"network"`
	IP      string `json:"ip"`
	// Port was previously tagged `json:"int"`, exposing it to policies under
	// the wrong key.
	Port int `json:"port"`
}
// logValues adds the client's network, IP and port to log under the
// "client_network", "client_ip" and "client_port" keys. A nil receiver
// leaves log untouched.
func (i *Client) logValues(log logger.Structured) logger.Structured {
	if i == nil {
		return log
	}
	return log.Values(logger.Values{
		"client_network": i.Network,
		"client_ip":      i.IP,
		"client_port":    i.Port,
	})
}
// NewClient builds a Client from a network name and an address string.
//
// If address has the form "host:port", IP is set to host and the port is
// resolved with net.LookupPort (numeric ports or service names). The lookup
// error is deliberately ignored, so whatever port value LookupPort returns
// on failure is stored. Otherwise the whole address is stored in IP and
// Port stays 0.
func NewClient(network, address string) *Client {
	c := &Client{Network: network, IP: address}
	host, port, err := net.SplitHostPort(address)
	if err != nil {
		return c
	}
	c.IP = host
	c.Port, _ = net.LookupPort(network, port)
	return c
}
// NewClientFromConn describes the remote end of c via NewClientFromAddr,
// or returns nil when c is nil.
func NewClientFromConn(c net.Conn) *Client {
	if c != nil {
		return NewClientFromAddr(c.RemoteAddr())
	}
	return nil
}
// NewClientFromAddr builds a Client from addr.
//
// TCP and UDP addresses contribute their IP and port, and IP addresses only
// their IP. Any other address type is parsed from its String form exactly as
// NewClient does. A nil addr, either untyped or a typed nil *net.TCPAddr,
// *net.UDPAddr or *net.IPAddr, yields nil instead of panicking.
func NewClientFromAddr(addr net.Addr) *Client {
	switch a := addr.(type) {
	case nil:
		return nil
	case *net.TCPAddr:
		if a == nil {
			return nil
		}
		return &Client{Network: a.Network(), IP: a.IP.String(), Port: a.Port}
	case *net.UDPAddr:
		if a == nil {
			return nil
		}
		return &Client{Network: a.Network(), IP: a.IP.String(), Port: a.Port}
	case *net.IPAddr:
		if a == nil {
			return nil
		}
		return &Client{Network: a.Network(), IP: a.IP.String()}
	default:
		// Same host:port split and port lookup as NewClient; delegate rather
		// than duplicate it.
		return NewClient(a.Network(), a.String())
	}
}
// Group identifies a group by name; it is the element type of Input.Groups.
type Group struct {
Name string `json:"name"`
}
// TLS summarizes the state of a TLS connection for policy evaluation.
type TLS struct {
// Version is the protocol version name; NewTLS fills it via tls.VersionName (e.g. "TLS 1.3").
Version string `json:"version"`
// CipherSuite is the cipher suite name; NewTLS fills it via tls.CipherSuiteName.
CipherSuite string `json:"cipher_suite"`
// ServerName is the server name taken from the connection state.
ServerName string `json:"server_name"`
// Certificates holds the peer certificates; NewTLS keeps the order of the connection state.
Certificates []*Certificate `json:"certificates"`
}
// logValues adds the TLS version, cipher suite, server name and the subject
// common names of the peer certificates to log under "tls_*" keys. A nil
// receiver leaves log untouched.
func (i *TLS) logValues(log logger.Structured) logger.Structured {
	if i == nil {
		return log
	}
	commonNames := make([]string, 0, len(i.Certificates))
	for _, c := range i.Certificates {
		commonNames = append(commonNames, c.Subject.CommonName)
	}
	return log.Values(logger.Values{
		"tls_version":      i.Version,
		"tls_cipher":       i.CipherSuite,
		"tls_server_name":  i.ServerName,
		"tls_certificates": commonNames,
	})
}
// NewTLS converts a crypto/tls connection state into a TLS summary, or
// returns nil when state is nil.
//
// Version and cipher suite are rendered with tls.VersionName and
// tls.CipherSuiteName. Peer certificates are converted with NewCertificate
// in the order they appear in state, and any that convert to nil are skipped.
func NewTLS(state *tls.ConnectionState) *TLS {
	if state == nil {
		return nil
	}
	// Named "info" rather than "tls" so the crypto/tls package stays visible.
	info := &TLS{
		Version:     tls.VersionName(state.Version),
		CipherSuite: tls.CipherSuiteName(state.CipherSuite),
		ServerName:  state.ServerName,
	}
	for _, peer := range state.PeerCertificates {
		converted := NewCertificate(peer)
		if converted == nil {
			continue
		}
		info.Certificates = append(info.Certificates, converted)
	}
	return info
}
// tlsConnectionStater matches connections that can report their TLS state
// through a ConnectionState method.
type tlsConnectionStater interface {
	ConnectionState() tls.ConnectionState
}

// NewTLSFromConn returns the TLS summary of c when c has a ConnectionState
// method, and nil otherwise (including when c is nil). The state is used
// as-is.
// NOTE(review): if the handshake has not completed yet, the state's fields
// may be zero-valued — confirm callers only use this after the handshake.
func NewTLSFromConn(c net.Conn) *TLS {
	stater, ok := c.(tlsConnectionStater)
	if !ok {
		return nil
	}
	state := stater.ConnectionState()
	return NewTLS(&state)
}
// Certificate holds the fields of an X.509 certificate exposed to policies.
type Certificate struct {
// SerialNumber is the serial number rendered as a decimal string.
SerialNumber string `json:"serial_number"`
Subject PKIXName `json:"subject"`
Issuer PKIXName `json:"issuer"`
// NotBefore and NotAfter bound the validity period, as Unix time in nanoseconds.
NotBefore int64 `json:"not_before"`
NotAfter int64 `json:"not_after"`
}
// NewCertificate extracts the policy-relevant fields of an X.509
// certificate, or returns nil when cert is nil.
//
// NotBefore and NotAfter are Unix nanoseconds. Validity bounds outside the
// range that fits in an int64 (roughly years 1678–2262) are clamped to the
// int64 limits. This matters because RFC 5280 certificates without a
// well-defined expiry use 9999-12-31, for which time.Time.UnixNano is
// undefined.
func NewCertificate(cert *x509.Certificate) *Certificate {
	if cert == nil {
		return nil
	}
	return &Certificate{
		SerialNumber: cert.SerialNumber.String(),
		Subject:      MakePKIXName(cert.Subject),
		Issuer:       MakePKIXName(cert.Issuer),
		NotBefore:    certUnixNano(cert.NotBefore),
		NotAfter:     certUnixNano(cert.NotAfter),
	}
}

// certUnixNano returns t as Unix nanoseconds, saturating at the int64
// limits instead of relying on time.Time.UnixNano, whose result is undefined
// outside that range.
func certUnixNano(t time.Time) int64 {
	const (
		maxNano int64 = 1<<63 - 1
		minNano int64 = -1 << 63
	)
	switch {
	case t.After(time.Unix(0, maxNano)):
		return maxNano
	case t.Before(time.Unix(0, minNano)):
		return minNano
	default:
		return t.UnixNano()
	}
}
// PKIXName is a flattened X.509 distinguished name: each pkix.Name attribute
// that may hold several values is reduced to a single string by MakePKIXName.
// Empty fields are omitted from the JSON encoding.
type PKIXName struct {
CommonName string `json:"cn,omitempty"`
Country string `json:"country,omitempty"`
Organization string `json:"organization,omitempty"`
OrganizationalUnit string `json:"ou,omitempty"`
Locality string `json:"locality,omitempty"`
Province string `json:"province,omitempty"`
StreetAddress string `json:"address,omitempty"`
PostalCode string `json:"postalcode,omitempty"`
}
// MakePKIXName flattens a pkix.Name into a PKIXName. CommonName is copied
// as-is. For every multi-valued attribute (country, organization, unit,
// locality, province, street address, postal code), only the first non-empty
// value is kept.
func MakePKIXName(name pkix.Name) PKIXName {
	var flat PKIXName
	flat.CommonName = name.CommonName
	flat.Country = pick(name.Country...)
	flat.Organization = pick(name.Organization...)
	flat.OrganizationalUnit = pick(name.OrganizationalUnit...)
	flat.Locality = pick(name.Locality...)
	flat.Province = pick(name.Province...)
	flat.StreetAddress = pick(name.StreetAddress...)
	flat.PostalCode = pick(name.PostalCode...)
	return flat
}
// Request represents an HTTP request.
//
// Header keeps a single value per header name. Host and Port are the host
// and numeric port of the request target, as derived by NewRequest.
type Request struct {
Method string `json:"method"`
URL *URL `json:"url"`
Proto string `json:"proto"`
Header map[string]string `json:"header"`
Host string `json:"host"`
Port int `json:"port"`
RequestURI string `json:"request_uri"`
}
// logValues adds the request's method, URL, protocol, headers, host and port
// to log under "request_*" keys. A nil receiver leaves log untouched.
func (i *Request) logValues(log logger.Structured) logger.Structured {
	if i == nil {
		return log
	}
	fields := logger.Values{
		"request_method": i.Method,
		"request_url":    i.URL.String(),
		"request_proto":  i.Proto,
		"request_header": i.Header,
		"request_host":   i.Host,
		"request_port":   i.Port,
	}
	return log.Values(fields)
}
// NewRequest converts an *http.Request into a policy Request, or returns nil
// when r is nil.
//
// Only the first value of each header is kept. Host and Port come from
// r.URL.Host. When the URL carries no host, as for origin-form requests
// received by a server, they fall back to r.Host (the Host header). Without
// an explicit port, 443 is assumed for "https"/"wss" URLs and TLS
// connections, and 80 otherwise. Named ports are resolved with
// net.LookupPort, and the lookup error is deliberately ignored.
func NewRequest(r *http.Request) *Request {
	if r == nil {
		return nil
	}
	header := make(map[string]string, len(r.Header))
	for key := range r.Header {
		header[key] = r.Header.Get(key)
	}

	var scheme, hostport string
	if r.URL != nil {
		scheme, hostport = r.URL.Scheme, r.URL.Host
	}
	if hostport == "" {
		// Server-side requests usually only carry the target in r.Host.
		hostport = r.Host
	}
	host, portName, err := net.SplitHostPort(hostport)
	if err != nil {
		// No explicit port: derive the default from the scheme / transport.
		host = netutil.Host(hostport)
		portName = "80"
		if scheme == "https" || scheme == "wss" || r.TLS != nil {
			portName = "443"
		}
	}
	port, err := strconv.Atoi(portName)
	if err != nil {
		port, _ = net.LookupPort("tcp", portName)
	}

	return &Request{
		Method:     r.Method,
		URL:        NewURL(r.URL),
		Proto:      r.Proto,
		Header:     header,
		Host:       host,
		Port:       port,
		RequestURI: r.RequestURI,
	}
}
// Response represents an HTTP response.
//
// Header keeps a single value per header name. Request and TLS describe the
// originating request and the TLS state the response was received over,
// and are nil when unknown.
type Response struct {
Status string `json:"status"`
StatusCode int `json:"status_code"`
Proto string `json:"proto"`
Header map[string]string `json:"header"`
ContentLength int64 `json:"content_length"`
Close bool `json:"close"`
Request *Request `json:"request"`
TLS *TLS `json:"tls"`
}
// logValues adds the response's status code, protocol, headers, close flag
// and whether TLS state is present to log under "response_*" keys. A nil
// receiver leaves log untouched.
func (i *Response) logValues(log logger.Structured) logger.Structured {
	if i == nil {
		return log
	}
	hasTLS := i.TLS != nil
	fields := logger.Values{
		"response_status": i.StatusCode,
		"response_proto":  i.Proto,
		"response_header": i.Header,
		"response_close":  i.Close,
		"response_tls":    hasTLS,
	}
	return log.Values(fields)
}
// NewResponse converts an *http.Response into a policy Response, or returns
// nil when r is nil. Only the first value of each header is kept. The
// originating request and TLS state are converted with NewRequest and NewTLS.
func NewResponse(r *http.Response) *Response {
	if r == nil {
		return nil
	}
	out := &Response{
		Status:        r.Status,
		StatusCode:    r.StatusCode,
		Proto:         r.Proto,
		Header:        make(map[string]string, len(r.Header)),
		ContentLength: r.ContentLength,
		Close:         r.Close,
		Request:       NewRequest(r.Request),
		TLS:           NewTLS(r.TLS),
	}
	for name := range r.Header {
		out.Header[name] = r.Header.Get(name)
	}
	return out
}
// URL is a simplified view of a URL for policy evaluation.
// Query holds a single value per query parameter.
type URL struct {
	Scheme string            `json:"scheme"`
	Host   string            `json:"host"`
	Path   string            `json:"path"`
	Query  map[string]string `json:"query"`
}

// String renders the URL as scheme://host/path?query, mainly for logging.
//
// Query parameters are encoded with url.Values.Encode: keys and values are
// query-escaped, joined with "&" and sorted by key, so the output is
// deterministic. Previously parameters were concatenated with no separator,
// in random map order. A nil *URL renders as "<nil>".
func (i *URL) String() string {
	if i == nil {
		return "<nil>"
	}
	s := fmt.Sprintf("%s://%s%s", i.Scheme, i.Host, i.Path)
	if len(i.Query) == 0 {
		return s
	}
	query := make(url.Values, len(i.Query))
	for key, value := range i.Query {
		query.Set(key, value)
	}
	return s + "?" + query.Encode()
}

// ParseURL parses rawURL with url.Parse and converts the result with
// NewURL. Parse errors are returned unchanged.
func ParseURL(rawURL string) (*URL, error) {
	parsed, err := url.Parse(rawURL)
	if err != nil {
		return nil, err
	}
	return NewURL(parsed), nil
}

// NewURL converts a *url.URL into a URL, or returns nil when u is nil.
// Only the first value of each query parameter is retained.
func NewURL(u *url.URL) *URL {
	if u == nil {
		return nil
	}
	query := make(map[string]string)
	for key, values := range u.Query() {
		if len(values) > 0 {
			query[key] = values[0]
		}
	}
	return &URL{
		Scheme: u.Scheme,
		Host:   u.Host,
		Path:   u.Path,
		Query:  query,
	}
}
// pick returns the first non-empty string in values, or "" when there is
// none.
func pick(values ...string) string {
	for i := range values {
		if len(values[i]) > 0 {
			return values[i]
		}
	}
	return ""
}