package policy

import (
	"crypto/tls"
	"crypto/x509"
	"crypto/x509/pkix"
	"fmt"
	"net"
	"net/http"
	"net/url"
	"strconv"

	"git.maze.io/maze/styx/internal/netutil"
	"git.maze.io/maze/styx/logger"
	proxy "git.maze.io/maze/styx/proxy"
)

// Input represents the input to the policy query.
//
// Every field is optional; which ones are populated depends on the
// constructor that was used (connection only, request, or response).
type Input struct {
	Context  map[string]any `json:"context"`
	Client   *Client        `json:"client"`
	Groups   []*Group       `json:"groups"`
	TLS      *TLS           `json:"tls"`
	Request  *Request       `json:"request"`
	Response *Response      `json:"response"`
}

// logValues attaches the context, client, TLS, request and response details
// of the input to log and returns the resulting logger.
//
// It is safe to call on a nil *Input, in which case log is returned
// unchanged. NewInputFromRequest and NewInputFromResponse can return nil, and
// the logValues methods of the nested types tolerate nil receivers as well.
func (i *Input) logValues(log logger.Structured) logger.Structured {
	if i == nil {
		return log
	}
	if i.Context != nil {
		log = log.Values(i.Context)
	}
	log = i.Client.logValues(log)
	log = i.TLS.logValues(log)
	log = i.Request.logValues(log)
	log = i.Response.logValues(log)
	return log
}

// NewInputFromConn builds an Input describing the connection c.
//
// A nil connection yields an empty, non-nil Input. Otherwise the client is
// taken from the remote address of c and TLS details are included when c
// exposes a TLS connection state. If c also implements proxy.Context, the
// local address and the byte counters are stored in Context under the keys
// "local", "bytes_rx" and "bytes_tx".
func NewInputFromConn(c net.Conn) *Input {
	if c == nil {
		return new(Input)
	}

	input := &Input{
		Context: make(map[string]any),
		Client:  NewClientFromConn(c),
		TLS:     NewTLSFromConn(c),
	}

	/*
		if wcl, ok := c.(dataset.WithClient); ok {
			client, err := wcl.Client()
			if err == nil {
				input.Context["client_id"] = client.ID
				input.Context["client_description"] = client.Description
				input.Context["groups"] = client.Groups
			}
		}
	*/

	if ctx, ok := c.(proxy.Context); ok {
		input.Context["local"] = NewClientFromAddr(ctx.LocalAddr())
		input.Context["bytes_rx"] = ctx.BytesRead()
		input.Context["bytes_tx"] = ctx.BytesSent()
	}

	return input
}

// NewInputFromRequest builds an Input for the request r received on c.
//
// It returns nil when r is nil; note that this differs from NewInputFromConn,
// which returns an empty Input for a nil connection.
func NewInputFromRequest(c net.Conn, r *http.Request) *Input {
	if r == nil {
		return nil
	}
	input := NewInputFromConn(c)
	input.Request = NewRequest(r)
	return input
}

// NewInputFromResponse builds an Input for the response r sent over c.
//
// It returns nil when r is nil.
func NewInputFromResponse(c net.Conn, r *http.Response) *Input {
	if r == nil {
		return nil
	}
	input := NewInputFromConn(c)
	input.Response = NewResponse(r)
	return input
}

// Client describes a network endpoint by network name, IP (or host) and port.
//
// NOTE(review): Port is serialized under the JSON key "int", which looks like
// a typo for "port". It is left unchanged because existing policies may
// depend on the emitted key; confirm before renaming.
type Client struct {
	Network string `json:"network"`
	IP      string `json:"ip"`
	Port    int    `json:"int"`
}

// logValues attaches the client network, IP and port to log. A nil receiver
// leaves log unchanged.
func (i *Client) logValues(log logger.Structured) logger.Structured {
	if i == nil {
		return log
	}
	return log.Values(logger.Values{
		"client_network": i.Network,
		"client_ip":      i.IP,
		"client_port":    i.Port,
	})
}
func NewClient(network, address string) *Client { if host, port, err := net.SplitHostPort(address); err == nil { p, _ := net.LookupPort(network, port) return &Client{ Network: network, IP: host, Port: p, } } return &Client{ Network: network, IP: address, } } func NewClientFromConn(c net.Conn) *Client { if c == nil { return nil } return NewClientFromAddr(c.RemoteAddr()) } func NewClientFromAddr(addr net.Addr) *Client { switch addr := addr.(type) { case *net.TCPAddr: return &Client{ Network: addr.Network(), IP: addr.IP.String(), Port: addr.Port, } case *net.UDPAddr: return &Client{ Network: addr.Network(), IP: addr.IP.String(), Port: addr.Port, } case *net.IPAddr: return &Client{ Network: addr.Network(), IP: addr.IP.String(), } default: if host, port, err := net.SplitHostPort(addr.String()); err == nil { return &Client{ Network: addr.Network(), IP: host, Port: func() int { p, _ := net.LookupPort(addr.Network(), port); return p }(), } } return &Client{ Network: addr.Network(), IP: addr.String(), } } } type Group struct { Name string `json:"name"` } type TLS struct { Version string `json:"version"` CipherSuite string `json:"cipher_suite"` ServerName string `json:"server_name"` Certificates []*Certificate `json:"certificates"` } func (i *TLS) logValues(log logger.Structured) logger.Structured { if i != nil { cns := make([]string, len(i.Certificates)) for j, cert := range i.Certificates { cns[j] = cert.Subject.CommonName } log = log.Values(logger.Values{ "tls_version": i.Version, "tls_cipher": i.CipherSuite, "tls_server_name": i.ServerName, "tls_certificates": cns, }) } return log } func NewTLS(state *tls.ConnectionState) *TLS { if state == nil { return nil } tls := &TLS{ Version: tls.VersionName(state.Version), CipherSuite: tls.CipherSuiteName(state.CipherSuite), ServerName: state.ServerName, } for _, cert := range state.PeerCertificates { if cert := NewCertificate(cert); cert != nil { tls.Certificates = append(tls.Certificates, cert) } } return tls } type 
tlsConnectionStater interface { ConnectionState() tls.ConnectionState } func NewTLSFromConn(c net.Conn) *TLS { if c == nil { return nil } if s, ok := c.(tlsConnectionStater); ok { cs := s.ConnectionState() return NewTLS(&cs) } return nil } type Certificate struct { SerialNumber string `json:"serial_number"` Subject PKIXName `json:"subject"` Issuer PKIXName `json:"issuer"` NotBefore int64 `json:"not_before"` NotAfter int64 `json:"not_after"` } func NewCertificate(cert *x509.Certificate) *Certificate { if cert == nil { return nil } return &Certificate{ SerialNumber: cert.SerialNumber.String(), Subject: MakePKIXName(cert.Subject), Issuer: MakePKIXName(cert.Issuer), NotBefore: cert.NotBefore.UnixNano(), NotAfter: cert.NotAfter.UnixNano(), } } type PKIXName struct { CommonName string `json:"cn,omitempty"` Country string `json:"country,omitempty"` Organization string `json:"organization,omitempty"` OrganizationalUnit string `json:"ou,omitempty"` Locality string `json:"locality,omitempty"` Province string `json:"province,omitempty"` StreetAddress string `json:"address,omitempty"` PostalCode string `json:"postalcode,omitempty"` } func MakePKIXName(name pkix.Name) PKIXName { return PKIXName{ CommonName: name.CommonName, Country: pick(name.Country...), Organization: pick(name.Organization...), OrganizationalUnit: pick(name.OrganizationalUnit...), Locality: pick(name.Locality...), Province: pick(name.Province...), StreetAddress: pick(name.StreetAddress...), PostalCode: pick(name.PostalCode...), } } // Request represents an HTTP request. 
type Request struct { Method string `json:"method"` URL *URL `json:"url"` Proto string `json:"proto"` Header map[string]string `json:"header"` Host string `json:"host"` Port int `json:"port"` RequestURI string `json:"request_uri"` } func (i *Request) logValues(log logger.Structured) logger.Structured { if i != nil { log = log.Values(logger.Values{ "request_method": i.Method, "request_url": i.URL.String(), "request_proto": i.Proto, "request_header": i.Header, "request_host": i.Host, "request_port": i.Port, }) } return log } func NewRequest(r *http.Request) *Request { if r == nil { return nil } header := make(map[string]string) for key := range r.Header { header[key] = r.Header.Get(key) } host, portName, err := net.SplitHostPort(r.URL.Host) if err != nil { host = netutil.Host(r.URL.Host) portName = "80" if r.URL.Scheme == "https" || r.URL.Scheme == "wss" || r.TLS != nil { portName = "443" } } var port int if port, err = strconv.Atoi(portName); err != nil { port, _ = net.LookupPort("tcp", portName) } return &Request{ Method: r.Method, URL: NewURL(r.URL), Proto: r.Proto, Header: header, Host: host, Port: port, RequestURI: r.RequestURI, } } // Response represents an HTTP response. 
type Response struct { Status string `json:"status"` StatusCode int `json:"status_code"` Proto string `json:"proto"` Header map[string]string `json:"header"` ContentLength int64 `json:"content_length"` Close bool `json:"close"` Request *Request `json:"request"` TLS *TLS `json:"tls"` } func (i *Response) logValues(log logger.Structured) logger.Structured { if i != nil { log = log.Values(logger.Values{ "response_status": i.StatusCode, "response_proto": i.Proto, "response_header": i.Header, "response_close": i.Close, "response_tls": i.TLS != nil, }) } return log } func NewResponse(r *http.Response) *Response { if r == nil { return nil } header := make(map[string]string) for key := range r.Header { header[key] = r.Header.Get(key) } return &Response{ Status: r.Status, StatusCode: r.StatusCode, Proto: r.Proto, Header: header, ContentLength: r.ContentLength, Close: r.Close, Request: NewRequest(r.Request), TLS: NewTLS(r.TLS), } } type URL struct { Scheme string `json:"scheme"` Host string `json:"host"` Path string `json:"path"` Query map[string]string `json:"query"` } func (i *URL) String() string { if i == nil { return "" } s := fmt.Sprintf("%s://%s%s", i.Scheme, i.Host, i.Path) if len(i.Query) > 0 { s += "?" for k, v := range i.Query { s += k + "=" + url.QueryEscape(v) } } return s } func ParseURL(rawurl string) (*URL, error) { parsed, err := url.Parse(rawurl) if err != nil { return nil, err } return NewURL(parsed), nil } func NewURL(url *url.URL) *URL { if url == nil { return nil } query := make(map[string]string) for key, values := range url.Query() { if len(values) > 0 { query[key] = values[0] } } return &URL{ Scheme: url.Scheme, Host: url.Host, Path: url.Path, Query: query, } } func pick(values ...string) string { for _, v := range values { if v != "" { return v } } return "" }