Files
styx/cmd/styx/config.go
2025-10-06 22:25:23 +02:00

274 lines
6.7 KiB
Go

package main
import (
"crypto/tls"
"fmt"
"github.com/hashicorp/hcl/v2"
"github.com/hashicorp/hcl/v2/gohcl"
"github.com/hashicorp/hcl/v2/hclsimple"
"git.maze.io/maze/styx/ca"
"git.maze.io/maze/styx/dataset"
"git.maze.io/maze/styx/internal/cryptutil"
"git.maze.io/maze/styx/logger"
"git.maze.io/maze/styx/policy"
"git.maze.io/maze/styx/proxy"
)
// Config is the top-level styx configuration, decoded from an HCL file by
// Load. Each field maps to a top-level block through its hcl struct tag;
// under gohcl rules the non-pointer struct blocks (proxy, data) must appear
// exactly once, the pointer block (ca) may be omitted, and the slice block
// (policy) may appear any number of times.
type Config struct {
Proxy ProxyConfig `hcl:"proxy,block"`
// Policy holds every `policy "<name>" { ... }` block; the names are
// referenced from the proxy "on" lists (see Config.Proxies).
Policy []PolicyConfig `hcl:"policy,block"`
// CA is nil when no ca block is present.
CA *CAConfig `hcl:"ca,block"`
Data DataConfig `hcl:"data,block"`
}
// Proxies builds one *proxy.Proxy per configured port and attaches the
// policy-backed handlers named in the proxy "on" block to every port.
//
// Each policy block is loaded once with policy.New and referenced by its
// label; the resulting handler values are shared by all returned proxies.
// An error is returned when a policy fails to load, two policy blocks use
// the same label, an "on" list names an unknown policy, or a port cannot be
// configured (that error is prefixed with the port's listen label).
func (c Config) Proxies(log logger.Structured) ([]*proxy.Proxy, error) {
	policies := make(map[string]*policy.Policy, len(c.Policy))
	for _, pc := range c.Policy {
		// Without this check a later block with the same label would
		// silently replace an earlier one.
		if _, dup := policies[pc.Name]; dup {
			return nil, fmt.Errorf("policy %s: defined more than once", pc.Name)
		}
		p, err := policy.New(pc.Path, pc.Package)
		if err != nil {
			return nil, fmt.Errorf("policy %s: %w", pc.Name, err)
		}
		policies[pc.Name] = p
	}

	// lookup resolves a policy name referenced from the "on <event>" list.
	lookup := func(event, name string) (*policy.Policy, error) {
		log.Value("policy", name).Debug("Resolving " + event + " policy")
		p, ok := policies[name]
		if !ok {
			return nil, fmt.Errorf("on %s: no policy named %q", event, name)
		}
		return p, nil
	}

	// NOTE(review): c.Proxy.On.Intercept is decoded but not attached to any
	// hook here — confirm whether it should be.
	var (
		onRequest  []proxy.RequestHandler
		onDial     []proxy.DialHandler
		onForward  []proxy.ForwardHandler
		onResponse []proxy.ResponseHandler
	)
	for _, name := range c.Proxy.On.Request {
		p, err := lookup("request", name)
		if err != nil {
			return nil, err
		}
		onRequest = append(onRequest, policy.NewRequestHandler(p))
	}
	for _, name := range c.Proxy.On.Dial {
		p, err := lookup("dial", name)
		if err != nil {
			return nil, err
		}
		onDial = append(onDial, policy.NewDialHandler(p))
	}
	for _, name := range c.Proxy.On.Forward {
		p, err := lookup("forward", name)
		if err != nil {
			return nil, err
		}
		onForward = append(onForward, policy.NewForwardHandler(p))
	}
	for _, name := range c.Proxy.On.Response {
		p, err := lookup("response", name)
		if err != nil {
			return nil, err
		}
		onResponse = append(onResponse, policy.NewResponseHandler(p))
	}

	var proxies []*proxy.Proxy
	for _, pc := range c.Proxy.Port {
		log.Value("port", pc.Listen).Debug("Configuring proxy port")
		p, err := pc.Proxy()
		if err != nil {
			return nil, fmt.Errorf("port %s: %w", pc.Listen, err)
		}
		// Handlers from the "on" block run after any the port itself installed.
		p.OnRequest = append(p.OnRequest, onRequest...)
		p.OnDial = append(p.OnDial, onDial...)
		p.OnForward = append(p.OnForward, onForward...)
		p.OnResponse = append(p.OnResponse, onResponse...)
		proxies = append(proxies, p)
	}
	return proxies, nil
}
// ProxyConfig is the `proxy { ... }` block: the listening ports, the
// upstream list and the policy hooks shared by every port.
type ProxyConfig struct {
// Port holds one entry per `port "<listen>" { ... }` block; Config.Proxies
// builds one proxy per entry.
Port []PortConfig `hcl:"port,block"`
// Upstream is a required attribute (its tag has no ",optional").
// NOTE(review): it is not read anywhere in this file — confirm its consumer.
Upstream []string `hcl:"upstream"`
// On names the policies attached to each proxy hook.
On ProxyPolicyConfig `hcl:"on,block"`
}
// PortConfig is a single `port "<listen>" { ... }` block.
type PortConfig struct {
// Listen is the block label.
Listen string `hcl:"port,label"`
// TLS is nil when no tls block is present; it is ignored by
// PortConfig.Proxy when Transparent is positive.
TLS *PortTLSConfig `hcl:"tls,block"`
// Transparent, when > 0, is passed to proxy.Transparent by PortConfig.Proxy.
Transparent int `hcl:"transparent,optional"`
// Name is optional and not read anywhere in this file.
Name string `hcl:"name,optional"`
}
// PortTLSConfig is the optional `tls { ... }` block of a port.
type PortTLSConfig struct {
// Cert and Key are passed to cryptutil.LoadTLSCertificate.
// NOTE(review): presumably file paths — confirm against cryptutil.
Cert string `hcl:"cert"`
Key string `hcl:"key,optional"`
// CA, when non-empty, is loaded with cryptutil.LoadRoots into tls.Config.RootCAs.
CA string `hcl:"ca,optional"`
}
// Proxy creates the proxy.Proxy for this listening port.
//
// A positive Transparent value installs proxy.Transparent and takes
// precedence over any tls block. Otherwise, when a tls block is present,
// its certificate/key pair (and optional CA bundle) are loaded and the
// connection hook proxy.TLS is installed. With neither, a plain proxy
// from proxy.New is returned. Load errors are wrapped with what was being
// loaded.
func (c PortConfig) Proxy() (*proxy.Proxy, error) {
	p := proxy.New()
	switch {
	case c.Transparent > 0:
		p.OnConnect = append(p.OnConnect, proxy.Transparent(c.Transparent))
	case c.TLS != nil:
		cert, err := cryptutil.LoadTLSCertificate(c.TLS.Cert, c.TLS.Key)
		if err != nil {
			return nil, fmt.Errorf("tls certificate: %w", err)
		}
		tlsConfig := &tls.Config{Certificates: []tls.Certificate{cert}}
		if c.TLS.CA != "" {
			// NOTE(review): RootCAs is used when verifying a peer as a TLS
			// client. If proxy.TLS terminates TLS as a server, ClientCAs plus
			// ClientAuth may be what is intended — confirm.
			roots, err := cryptutil.LoadRoots(c.TLS.CA)
			if err != nil {
				return nil, fmt.Errorf("tls ca: %w", err)
			}
			tlsConfig.RootCAs = roots
		}
		p.OnConnect = append(p.OnConnect, proxy.TLS(tlsConfig))
	}
	return p, nil
}
// ProxyPolicyConfig is the `on { ... }` block inside `proxy`. Each list
// holds policy labels (see PolicyConfig) to attach to the matching hook.
type ProxyPolicyConfig struct {
// NOTE(review): Intercept is decoded but Config.Proxies does not consume
// it — confirm whether it should be wired to a hook.
Intercept []string `hcl:"intercept,optional"`
// Request, Dial, Forward and Response are resolved by Config.Proxies into
// policy.NewRequestHandler/NewDialHandler/NewForwardHandler/NewResponseHandler.
Request []string `hcl:"request,optional"`
Dial []string `hcl:"dial,optional"`
Forward []string `hcl:"forward,optional"`
Response []string `hcl:"response,optional"`
}
// PolicyConfig is a `policy "<name>" { ... }` block.
type PolicyConfig struct {
// Name is the block label used to reference the policy from the "on" lists.
Name string `hcl:"name,label"`
// Path and Package are passed to policy.New.
Path string `hcl:"path"`
Package string `hcl:"package,optional"`
}
// CAConfig is the optional `ca { ... }` block (Config.CA is nil without it).
type CAConfig struct {
// Cert and Key are passed unchanged to ca.Open; Key may be omitted.
// NOTE(review): presumably file paths — confirm against ca.Open.
Cert string `hcl:"cert"`
Key string `hcl:"key,optional"`
}
// CertificateAuthority opens the certificate authority described by this
// block by delegating to ca.Open(c.Cert, c.Key).
func (c CAConfig) CertificateAuthority() (ca.CertificateAuthority, error) {
return ca.Open(c.Cert, c.Key)
}
// DataConfig is the `data { ... }` block: the storage backend plus named
// domain and network data sets.
type DataConfig struct {
// Path is optional. NOTE(review): it is not read anywhere in this file
// (OpenStorage takes its path from the storage block) — confirm its use.
Path string `hcl:"path,optional"`
// Storage is consumed by DataConfig.OpenStorage.
Storage DataStorageConfig `hcl:"storage,block"`
// Domains and Networks hold the `domain "<name>"` and `network "<name>"`
// blocks registered by DataConfig.Configure.
Domains []DomainDataConfig `hcl:"domain,block"`
Networks []NetworkDataConfig `hcl:"network,block"`
}
// Configure registers every declared data set: all domain blocks first,
// then all network blocks, each in declaration order. It stops at the first
// failure and wraps the error with the offending block's label.
func (c DataConfig) Configure() error {
	for i := range c.Domains {
		if err := c.Domains[i].Configure(); err != nil {
			return fmt.Errorf("error setting up domain data %q: %w", c.Domains[i].Name, err)
		}
	}
	for i := range c.Networks {
		if err := c.Networks[i].Configure(); err != nil {
			return fmt.Errorf("error setting up network data %q: %w", c.Networks[i].Name, err)
		}
	}
	return nil
}
// OpenStorage opens the dataset storage selected by the storage block's
// type attribute.
//
// Only "" , "bolt" and "boltdb" are accepted; all three decode a required
// `path` attribute from the storage block and open it with
// dataset.OpenBStore (a dataset.OpenBolt call was previously used here).
// A SQLite driver (dataset.OpenSQLite) was once wired up as well and is
// currently disabled. Any other type yields an error.
func (c DataConfig) OpenStorage() (dataset.Storage, error) {
	kind := c.Storage.Type
	if kind != "" && kind != "bolt" && kind != "boltdb" {
		return nil, fmt.Errorf("storage: no %q driver", kind)
	}

	var settings struct {
		Path string `hcl:"path"`
	}
	if diags := gohcl.DecodeBody(c.Storage.Body, nil, &settings); diags.HasErrors() {
		return nil, diags
	}
	return dataset.OpenBStore(settings.Path)
}
// DataStorageConfig is the `storage { ... }` block inside `data`.
//
// Type selects the driver in DataConfig.OpenStorage. It is optional so that
// omitting it reaches the "" (default) case OpenStorage already handles;
// previously the attribute was required, so that case was only reachable by
// writing `type = ""` explicitly. Every other attribute stays in Body and is
// decoded by the selected driver.
type DataStorageConfig struct {
	Type string   `hcl:"type,optional"`
	Body hcl.Body `hcl:",remain"`
}
// DomainDataConfig is a `domain "<name>" { ... }` block inside `data`.
//
// Name is the block label and becomes the key in dataset.Domains. Type is
// optional so that omitting it reaches the "" (default list) case handled by
// DomainDataConfig.Configure. Remaining attributes stay in Body for the
// selected type to decode.
type DomainDataConfig struct {
	Name string   `hcl:"name,label"`
	Type string   `hcl:"type,optional"`
	Body hcl.Body `hcl:",remain"`
}
// Configure builds the domain data set described by this block and stores
// it in dataset.Domains under the block label, replacing any previous entry
// with that name.
//
// Only the "list" type (also selected by an empty type) is supported: it
// requires a `list` attribute of strings passed to dataset.NewDomainList.
// Decode diagnostics are returned as the error.
func (c DomainDataConfig) Configure() error {
	if c.Type != "" && c.Type != "list" {
		return fmt.Errorf("unknown type %q", c.Type)
	}

	var decoded struct {
		List []string `hcl:"list"`
	}
	if diags := gohcl.DecodeBody(c.Body, nil, &decoded); diags.HasErrors() {
		return diags
	}
	dataset.Domains[c.Name] = dataset.NewDomainList(decoded.List...)
	return nil
}
// NetworkDataConfig is a `network "<name>" { ... }` block inside `data`.
//
// Name is the block label and becomes the key in dataset.Networks. Type is
// optional so that omitting it reaches the "" (default list) case handled by
// NetworkDataConfig.Configure. Remaining attributes stay in Body for the
// selected type to decode.
type NetworkDataConfig struct {
	Name string   `hcl:"name,label"`
	Type string   `hcl:"type,optional"`
	Body hcl.Body `hcl:",remain"`
}
// Configure builds the network data set described by this block and stores
// it in dataset.Networks under the block label, replacing any previous entry
// with that name.
//
// Only the "list" type (also selected by an empty type) is supported: it
// requires a `list` attribute of strings passed to dataset.NewNetworkTree,
// whose error is returned unchanged. Decode diagnostics are returned as the
// error.
func (c NetworkDataConfig) Configure() error {
	if c.Type != "" && c.Type != "list" {
		return fmt.Errorf("unknown type %q", c.Type)
	}

	var decoded struct {
		List []string `hcl:"list"`
	}
	if diags := gohcl.DecodeBody(c.Body, nil, &decoded); diags.HasErrors() {
		return diags
	}
	tree, err := dataset.NewNetworkTree(decoded.List...)
	if err != nil {
		return err
	}
	dataset.Networks[c.Name] = tree
	return nil
}
// Load decodes the HCL configuration file at name into a Config using
// hclsimple.DecodeFile with no evaluation context (no variables or
// functions). Any read, parse or decode error is returned as-is.
func Load(name string) (*Config, error) {
	var config Config
	if err := hclsimple.DecodeFile(name, nil, &config); err != nil {
		return nil, err
	}
	return &config, nil
}