395 lines
8.8 KiB
Go
395 lines
8.8 KiB
Go
package policy
|
|
|
|
import (
|
|
"crypto/tls"
|
|
"crypto/x509"
|
|
"crypto/x509/pkix"
|
|
"fmt"
|
|
"net"
|
|
"net/http"
|
|
"net/url"
|
|
"strconv"
|
|
|
|
"git.maze.io/maze/styx/internal/netutil"
|
|
"git.maze.io/maze/styx/logger"
|
|
)
|
|
|
|
// Input represents the input to the policy query.
|
|
type Input struct {
|
|
Client *Client `json:"client"`
|
|
TLS *TLS `json:"tls"`
|
|
Request *Request `json:"request"`
|
|
Response *Response `json:"response"`
|
|
}
|
|
|
|
func (i *Input) logValues(log logger.Structured) logger.Structured {
|
|
log = i.Client.logValues(log)
|
|
log = i.TLS.logValues(log)
|
|
log = i.Request.logValues(log)
|
|
log = i.Response.logValues(log)
|
|
return log
|
|
}
|
|
|
|
func NewInputFromConn(c net.Conn) *Input {
|
|
if c == nil {
|
|
return new(Input)
|
|
}
|
|
return &Input{
|
|
Client: NewClientFromConn(c),
|
|
TLS: NewTLSFromConn(c),
|
|
}
|
|
}
|
|
|
|
func NewInputFromRequest(c net.Conn, r *http.Request) *Input {
|
|
if r == nil {
|
|
return nil
|
|
}
|
|
input := NewInputFromConn(c)
|
|
input.Request = NewRequest(r)
|
|
return input
|
|
}
|
|
|
|
func NewInputFromResponse(c net.Conn, r *http.Response) *Input {
|
|
if r == nil {
|
|
return nil
|
|
}
|
|
input := NewInputFromConn(c)
|
|
input.Response = NewResponse(r)
|
|
return input
|
|
}
|
|
|
|
// Client describes the remote peer of a connection.
type Client struct {
	// Network is the address network as reported by net.Addr.Network
	// (for example "tcp" or "udp").
	Network string `json:"network"`

	// IP is the peer host. When an address cannot be split into host and
	// port it holds the whole address string instead.
	IP string `json:"ip"`

	// Port is the peer port, or 0 when unknown. It was previously serialized
	// under the key "int"; it is now "port", matching Request.Port.
	Port int `json:"port"`
}
|
|
|
|
func (i *Client) logValues(log logger.Structured) logger.Structured {
|
|
if i != nil {
|
|
log = log.Values(logger.Values{
|
|
"client_network": i.Network,
|
|
"client_ip": i.IP,
|
|
"client_port": i.Port,
|
|
})
|
|
}
|
|
return log
|
|
}
|
|
|
|
func NewClient(network, address string) *Client {
|
|
if host, port, err := net.SplitHostPort(address); err == nil {
|
|
p, _ := net.LookupPort(network, port)
|
|
return &Client{
|
|
Network: network,
|
|
IP: host,
|
|
Port: p,
|
|
}
|
|
}
|
|
return &Client{
|
|
Network: network,
|
|
IP: address,
|
|
}
|
|
}
|
|
|
|
func NewClientFromConn(c net.Conn) *Client {
|
|
if c == nil {
|
|
return nil
|
|
}
|
|
return NewClientFromAddr(c.RemoteAddr())
|
|
}
|
|
|
|
func NewClientFromAddr(addr net.Addr) *Client {
|
|
switch addr := addr.(type) {
|
|
case *net.TCPAddr:
|
|
return &Client{
|
|
Network: addr.Network(),
|
|
IP: addr.IP.String(),
|
|
Port: addr.Port,
|
|
}
|
|
case *net.UDPAddr:
|
|
return &Client{
|
|
Network: addr.Network(),
|
|
IP: addr.IP.String(),
|
|
Port: addr.Port,
|
|
}
|
|
case *net.IPAddr:
|
|
return &Client{
|
|
Network: addr.Network(),
|
|
IP: addr.IP.String(),
|
|
}
|
|
default:
|
|
if host, port, err := net.SplitHostPort(addr.String()); err == nil {
|
|
return &Client{
|
|
Network: addr.Network(),
|
|
IP: host,
|
|
Port: func() int { p, _ := net.LookupPort(addr.Network(), port); return p }(),
|
|
}
|
|
}
|
|
return &Client{
|
|
Network: addr.Network(),
|
|
IP: addr.String(),
|
|
}
|
|
}
|
|
}
|
|
|
|
type TLS struct {
|
|
Version string `json:"version"`
|
|
CipherSuite string `json:"cipher_suite"`
|
|
ServerName string `json:"server_name"`
|
|
Certificates []*Certificate `json:"certificates"`
|
|
}
|
|
|
|
func (i *TLS) logValues(log logger.Structured) logger.Structured {
|
|
if i != nil {
|
|
cns := make([]string, len(i.Certificates))
|
|
for j, cert := range i.Certificates {
|
|
cns[j] = cert.Subject.CommonName
|
|
}
|
|
log = log.Values(logger.Values{
|
|
"tls_version": i.Version,
|
|
"tls_cipher": i.CipherSuite,
|
|
"tls_server_name": i.ServerName,
|
|
"tls_certificates": cns,
|
|
})
|
|
}
|
|
return log
|
|
}
|
|
|
|
func NewTLS(state *tls.ConnectionState) *TLS {
|
|
if state == nil {
|
|
return nil
|
|
}
|
|
tls := &TLS{
|
|
Version: tls.VersionName(state.Version),
|
|
CipherSuite: tls.CipherSuiteName(state.CipherSuite),
|
|
ServerName: state.ServerName,
|
|
}
|
|
for _, cert := range state.PeerCertificates {
|
|
if cert := NewCertificate(cert); cert != nil {
|
|
tls.Certificates = append(tls.Certificates, cert)
|
|
}
|
|
}
|
|
return tls
|
|
}
|
|
|
|
// tlsConnectionStater is satisfied by connections that can report their TLS
// state, such as *tls.Conn.
type tlsConnectionStater interface {
	ConnectionState() tls.ConnectionState
}
|
|
|
|
func NewTLSFromConn(c net.Conn) *TLS {
|
|
if c == nil {
|
|
return nil
|
|
}
|
|
if s, ok := c.(tlsConnectionStater); ok {
|
|
cs := s.ConnectionState()
|
|
return NewTLS(&cs)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
type Certificate struct {
|
|
SerialNumber string `json:"serial_number"`
|
|
Subject PKIXName `json:"subject"`
|
|
Issuer PKIXName `json:"issuer"`
|
|
NotBefore int64 `json:"not_before"`
|
|
NotAfter int64 `json:"not_after"`
|
|
}
|
|
|
|
func NewCertificate(cert *x509.Certificate) *Certificate {
|
|
if cert == nil {
|
|
return nil
|
|
}
|
|
return &Certificate{
|
|
SerialNumber: cert.SerialNumber.String(),
|
|
Subject: MakePKIXName(cert.Subject),
|
|
Issuer: MakePKIXName(cert.Issuer),
|
|
NotBefore: cert.NotBefore.UnixNano(),
|
|
NotAfter: cert.NotAfter.UnixNano(),
|
|
}
|
|
}
|
|
|
|
// PKIXName is a flattened X.509 distinguished name. Multi-valued attributes
// hold a single value; MakePKIXName fills them with the first non-empty one.
// Empty attributes are omitted from the JSON encoding.
type PKIXName struct {
	CommonName         string `json:"cn,omitempty"`
	Country            string `json:"country,omitempty"`
	Organization       string `json:"organization,omitempty"`
	OrganizationalUnit string `json:"ou,omitempty"`
	Locality           string `json:"locality,omitempty"`
	Province           string `json:"province,omitempty"`
	StreetAddress      string `json:"address,omitempty"`
	PostalCode         string `json:"postalcode,omitempty"`
}
|
|
|
|
func MakePKIXName(name pkix.Name) PKIXName {
|
|
return PKIXName{
|
|
CommonName: name.CommonName,
|
|
Country: pick(name.Country...),
|
|
Organization: pick(name.Organization...),
|
|
OrganizationalUnit: pick(name.OrganizationalUnit...),
|
|
Locality: pick(name.Locality...),
|
|
Province: pick(name.Province...),
|
|
StreetAddress: pick(name.StreetAddress...),
|
|
PostalCode: pick(name.PostalCode...),
|
|
}
|
|
}
|
|
|
|
// Request represents an HTTP request.
|
|
type Request struct {
|
|
Method string `json:"method"`
|
|
URL *URL `json:"url"`
|
|
Proto string `json:"proto"`
|
|
Header map[string]string `json:"header"`
|
|
Host string `json:"host"`
|
|
Port int `json:"port"`
|
|
RequestURI string `json:"request_uri"`
|
|
}
|
|
|
|
func (i *Request) logValues(log logger.Structured) logger.Structured {
|
|
if i != nil {
|
|
log = log.Values(logger.Values{
|
|
"request_method": i.Method,
|
|
"request_url": i.URL.String(),
|
|
"request_proto": i.Proto,
|
|
"request_header": i.Header,
|
|
"request_host": i.Host,
|
|
"request_port": i.Port,
|
|
})
|
|
}
|
|
return log
|
|
}
|
|
|
|
func NewRequest(r *http.Request) *Request {
|
|
if r == nil {
|
|
return nil
|
|
}
|
|
|
|
header := make(map[string]string)
|
|
for key := range r.Header {
|
|
header[key] = r.Header.Get(key)
|
|
}
|
|
|
|
host, portName, err := net.SplitHostPort(r.URL.Host)
|
|
if err != nil {
|
|
host = netutil.Host(r.URL.Host)
|
|
portName = "80"
|
|
if r.URL.Scheme == "https" || r.URL.Scheme == "wss" || r.TLS != nil {
|
|
portName = "443"
|
|
}
|
|
}
|
|
|
|
var port int
|
|
if port, err = strconv.Atoi(portName); err != nil {
|
|
port, _ = net.LookupPort("tcp", portName)
|
|
}
|
|
|
|
return &Request{
|
|
Method: r.Method,
|
|
URL: NewURL(r.URL),
|
|
Proto: r.Proto,
|
|
Header: header,
|
|
Host: host,
|
|
Port: port,
|
|
RequestURI: r.RequestURI,
|
|
}
|
|
}
|
|
|
|
// Response represents an HTTP response.
|
|
type Response struct {
|
|
Status string `json:"status"`
|
|
StatusCode int `json:"status_code"`
|
|
Proto string `json:"proto"`
|
|
Header map[string]string `json:"header"`
|
|
ContentLength int64 `json:"content_length"`
|
|
Close bool `json:"close"`
|
|
Request *Request `json:"request"`
|
|
TLS *TLS `json:"tls"`
|
|
}
|
|
|
|
func (i *Response) logValues(log logger.Structured) logger.Structured {
|
|
if i != nil {
|
|
log = log.Values(logger.Values{
|
|
"response_status": i.StatusCode,
|
|
"response_proto": i.Proto,
|
|
"response_header": i.Header,
|
|
"response_close": i.Close,
|
|
"response_tls": i.TLS != nil,
|
|
})
|
|
}
|
|
return log
|
|
}
|
|
|
|
func NewResponse(r *http.Response) *Response {
|
|
if r == nil {
|
|
return nil
|
|
}
|
|
header := make(map[string]string)
|
|
for key := range r.Header {
|
|
header[key] = r.Header.Get(key)
|
|
}
|
|
return &Response{
|
|
Status: r.Status,
|
|
StatusCode: r.StatusCode,
|
|
Proto: r.Proto,
|
|
Header: header,
|
|
ContentLength: r.ContentLength,
|
|
Close: r.Close,
|
|
Request: NewRequest(r.Request),
|
|
TLS: NewTLS(r.TLS),
|
|
}
|
|
}
|
|
|
|
// URL is a simplified, policy-friendly view of a URL.
type URL struct {
	// Scheme is the URL scheme, e.g. "https".
	Scheme string `json:"scheme"`

	// Host is the authority as it appears in the URL, possibly with a port.
	Host string `json:"host"`

	// Path is the decoded URL path.
	Path string `json:"path"`

	// Query maps each query parameter to its first value.
	Query map[string]string `json:"query"`
}
|
|
|
|
func (i *URL) String() string {
|
|
if i == nil {
|
|
return "<nil>"
|
|
}
|
|
|
|
s := fmt.Sprintf("%s://%s%s", i.Scheme, i.Host, i.Path)
|
|
if len(i.Query) > 0 {
|
|
s += "?"
|
|
for k, v := range i.Query {
|
|
s += k + "=" + url.QueryEscape(v)
|
|
}
|
|
}
|
|
return s
|
|
}
|
|
|
|
func ParseURL(rawurl string) (*URL, error) {
|
|
parsed, err := url.Parse(rawurl)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return NewURL(parsed), nil
|
|
}
|
|
|
|
func NewURL(url *url.URL) *URL {
|
|
if url == nil {
|
|
return nil
|
|
}
|
|
query := make(map[string]string)
|
|
for key, values := range url.Query() {
|
|
if len(values) > 0 {
|
|
query[key] = values[0]
|
|
}
|
|
}
|
|
return &URL{
|
|
Scheme: url.Scheme,
|
|
Host: url.Host,
|
|
Path: url.Path,
|
|
Query: query,
|
|
}
|
|
}
|
|
|
|
// pick returns the first non-empty string among values, or "" if there is
// none.
func pick(values ...string) string {
	for i := range values {
		if values[i] != "" {
			return values[i]
		}
	}
	return ""
}
|