Files
styx/internal/cryptutil/x509.go
2025-10-01 15:37:55 +02:00

300 lines
7.9 KiB
Go

package cryptutil
import (
"crypto"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"errors"
"fmt"
"math/big"
"os"
"path/filepath"
"strings"
"time"
"git.maze.io/maze/styx/logger"
"github.com/rs/zerolog/log"
)
// Supported key types: the canonical names understood by GeneratePrivateKey
// (its kind argument) and GenerateKeyPair (its keyType argument).
const (
	TypeRSA     = "rsa"     // RSA keys
	TypeECDSA   = "ecdsa"   // ECDSA keys on one of the NIST P curves
	TypeED25519 = "ed25519" // Ed25519 keys
)

// Supported PEM block types, used as the Type of encoded/decoded pem.Blocks.
const (
	pemTypeCert  = "CERTIFICATE"     // DER-encoded X.509 certificate
	pemTypeRSA   = "RSA PRIVATE KEY" // PKCS #1 RSA private key
	pemTypeECDSA = "EC PRIVATE KEY"  // SEC 1 EC private key
	pemTypeAny   = "PRIVATE KEY"     // PKCS #8 private key of any algorithm
)
// DecodeRoots builds a certificate pool from the PEM data in b.
//
// Every CERTIFICATE block is parsed. Certificates that are not marked as a CA
// are skipped, as are PEM blocks of any other type. Decoding stops at the first
// certificate that fails to parse, and that parse error is returned.
func DecodeRoots(b []byte) (*x509.CertPool, error) {
	roots := x509.NewCertPool()
	for block, rest := pem.Decode(b); block != nil; block, rest = pem.Decode(rest) {
		if block.Type != pemTypeCert {
			continue
		}
		cert, err := x509.ParseCertificate(block.Bytes)
		if err != nil {
			return nil, err
		}
		if !cert.IsCA {
			continue
		}
		roots.AddCert(cert)
	}
	return roots, nil
}
// LoadRoots loads a certificate authority bundle into a certificate pool.
//
// roots is interpreted as:
//   - inline PEM data, when it contains a "-----BEGIN CERTIFICATE" marker;
//   - a directory, in which case every *.crt and then every *.pem file directly
//     inside it is read and all of them are decoded together;
//   - otherwise, the path of a single PEM file.
//
// Only certificates marked as CA end up in the pool (see DecodeRoots).
func LoadRoots(roots string) (*x509.CertPool, error) {
	if strings.Contains(roots, "-----BEGIN CERTIFICATE") {
		logger.StandardLog.Trace("Parsing X.509 certificates")
		return DecodeRoots([]byte(roots))
	}

	info, err := os.Stat(roots)
	if err != nil {
		return nil, err
	}

	if !info.IsDir() {
		logger.StandardLog.Value("path", roots).Trace("Loading X.509 certificates")
		b, err := os.ReadFile(roots)
		if err != nil {
			return nil, err
		}
		return DecodeRoots(b)
	}

	logger.StandardLog.Value("path", roots).Trace("Loading X.509 certificates from *.crt *.pem")
	var bundle []byte
	for _, ext := range []string{"*.crt", "*.pem"} {
		files, err := filepath.Glob(filepath.Join(roots, ext))
		if err != nil {
			return nil, err
		}
		for _, file := range files {
			b, err := os.ReadFile(file)
			if err != nil {
				return nil, err
			}
			bundle = append(bundle, b...)
			// pem.Decode only recognises "-----BEGIN" at the start of the input
			// or of a line, and rejects an END line followed by other text.
			// Without a separator, a file lacking a trailing newline would lose
			// its last certificate and the next file's first one.
			bundle = append(bundle, '\n')
		}
	}
	return DecodeRoots(bundle)
}
// LoadKeyPair loads a certificate and private key. certdata and keydata may
// each be either an inline PEM encoded block or the name of a file to read.
//
// If [keydata] is empty, then the private key is assumed to be contained in [certdata].
//
// On error both cert and key are nil.
func LoadKeyPair(certdata, keydata string) (cert *x509.Certificate, key crypto.PrivateKey, err error) {
	if keydata == "" {
		keydata = certdata
	}

	if strings.Contains(certdata, "-----BEGIN "+pemTypeCert) {
		logger.StandardLog.Trace("Parsing X.509 certificate")
		cert, err = decodePEMBCertificate([]byte(certdata))
	} else {
		logger.StandardLog.Value("name", certdata).Trace("Loading X.509 certificate")
		cert, err = LoadCertificate(certdata)
	}
	if err != nil {
		return nil, nil, err
	}

	// "PRIVATE KEY-----" also matches the "RSA PRIVATE KEY" and
	// "EC PRIVATE KEY" header/trailer lines.
	if strings.Contains(keydata, pemTypeAny+"-----") {
		logger.StandardLog.Trace("Parsing private key")
		key, err = decodePEMPrivateKey([]byte(keydata))
	} else {
		// Trace before loading so the message is emitted for every attempt,
		// not only after a failure.
		logger.StandardLog.Value("name", keydata).Trace("Loading private key")
		key, err = LoadPrivateKey(keydata)
	}
	if err != nil {
		return nil, nil, err
	}
	return cert, key, nil
}
// SaveKeyPair saves a certificate and private key in PEM encoding.
//
// If [keyFile] is empty, or cleans to the same path as [certFile], then the
// private key is stored in [certFile] directly after the certificate.
//
// ECDSA keys are written as "EC PRIVATE KEY" (SEC 1), RSA keys as
// "RSA PRIVATE KEY" (PKCS #1) and Ed25519 keys as "PRIVATE KEY" (PKCS #8). Any
// other key type is rejected before a file is opened.
//
// A file that will hold the private key is created with mode 0600, and a
// pre-existing one is chmod-ed to 0600 as well. Errors from closing the files
// are returned, because a failed close can mean the data never reached disk.
func SaveKeyPair(cert *x509.Certificate, key crypto.PrivateKey, certFile, keyFile string) (err error) {
	if cert == nil {
		return errors.New("mitm: no certificate to save")
	}

	var (
		keyDER     []byte
		keyPEMType = pemTypeAny
	)
	switch key := key.(type) {
	case *ecdsa.PrivateKey:
		if keyDER, err = x509.MarshalECPrivateKey(key); err != nil {
			return err
		}
		keyPEMType = pemTypeECDSA
	case ed25519.PrivateKey:
		if keyDER, err = x509.MarshalPKCS8PrivateKey(key); err != nil {
			return err
		}
	case *rsa.PrivateKey:
		keyDER = x509.MarshalPKCS1PrivateKey(key)
		keyPEMType = pemTypeRSA
	default:
		return fmt.Errorf("mitm: don't know how to marshal %T", key)
	}

	sameFile := keyFile == "" || filepath.Clean(certFile) == filepath.Clean(keyFile)

	// When the certificate file also receives the key, create it 0600 up front
	// so it is never world-readable, not even briefly.
	certMode := os.FileMode(0o644)
	if sameFile {
		certMode = 0o600
	}
	var certf *os.File
	if certf, err = os.OpenFile(certFile, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, certMode); err != nil {
		return err
	}
	defer func() {
		if cerr := certf.Close(); cerr != nil && err == nil {
			err = cerr
		}
	}()

	keyf := certf
	if sameFile {
		keyFile = certFile
	} else {
		if keyf, err = os.OpenFile(keyFile, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0o600); err != nil {
			return err
		}
		defer func() {
			if cerr := keyf.Close(); cerr != nil && err == nil {
				err = cerr
			}
		}()
	}
	// OpenFile only applies the mode when it creates the file, so tighten the
	// permissions of a pre-existing key file explicitly.
	if err = keyf.Chmod(0o600); err != nil {
		return err
	}

	log.Debug().Str("file", certFile).Msg("saving X.509 certificate")
	if err = pem.Encode(certf, &pem.Block{Type: pemTypeCert, Bytes: cert.Raw}); err != nil {
		return err
	}
	log.Debug().Str("file", keyFile).Msg("saving private key")
	if err = pem.Encode(keyf, &pem.Block{Type: keyPEMType, Bytes: keyDER}); err != nil {
		return err
	}
	return nil
}
// GenerateKeyPair creates a private key of the requested type and size (see
// GeneratePrivateKey) and then a self-signed CA certificate for that key,
// valid for the given number of days (see GenerateCertificateAuthority).
//
// If certificate creation fails, the freshly generated key is still returned
// alongside the error.
func GenerateKeyPair(name pkix.Name, days int, keyType string, keyBits int) (cert *x509.Certificate, key crypto.PrivateKey, err error) {
	if key, err = GeneratePrivateKey(keyType, keyBits); err == nil {
		cert, err = GenerateCertificateAuthority(name, days, key)
	}
	return cert, key, err
}
// GenerateCertificateAuthority creates a self-signed CA certificate for key.
//
// The certificate has:
//   - a random serial number below 2^128;
//   - the given subject name;
//   - a validity window from roundToDay(time.Now()) lasting days days;
//   - key usage CertSign, plus DigitalSignature for RSA keys;
//   - valid basic constraints marking it as a CA.
//
// days must be positive; otherwise the certificate would already be expired.
func GenerateCertificateAuthority(name pkix.Name, days int, key crypto.PrivateKey) (cert *x509.Certificate, err error) {
	if days <= 0 {
		return nil, fmt.Errorf("mitm: CA validity must be at least one day, got %d", days)
	}

	serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
	serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
	if err != nil {
		return nil, fmt.Errorf("mitm: failed to generate serial number: %w", err)
	}

	keyUsage := x509.KeyUsageCertSign
	if _, ok := key.(*rsa.PrivateKey); ok {
		keyUsage |= x509.KeyUsageDigitalSignature
	}

	notBefore := roundToDay(time.Now())
	notAfter := notBefore.Add(time.Duration(days) * 24 * time.Hour)
	template := &x509.Certificate{
		Subject:               name,
		SerialNumber:          serialNumber,
		KeyUsage:              keyUsage,
		SubjectKeyId:          GenerateKeyID(key),
		IsCA:                  true,
		BasicConstraintsValid: true,
		NotBefore:             notBefore,
		NotAfter:              notAfter,
	}

	log.Info().
		Str("name", name.CommonName).
		Int("days", days).
		Str("key", keyType(key)).
		Str("serial", serialNumber.String()).
		Msg("generating self-signed CA certificate")

	// Self-signed: the template acts as both subject and issuer.
	der, err := x509.CreateCertificate(rand.Reader, template, template, PublicKey(key), key)
	if err != nil {
		return nil, fmt.Errorf("mitm: failed to create CA certificate: %w", err)
	}
	if cert, err = x509.ParseCertificate(der); err != nil {
		return nil, fmt.Errorf("mitm: failed to parse generated CA certificate: %w", err)
	}
	return cert, nil
}
// GeneratePrivateKey generates a new private key of the given kind.
//
// kind is matched case-insensitively:
//   - "rsa" or "" (the default): an RSA key of bits bits, 2048 when bits is 0;
//   - "ecdsa", "ec" or "ecc": an ECDSA key on P-224, P-256, P-384 or P-521,
//     selected by bits (224, 256, 384 or 521), P-256 when bits is 0;
//   - "ed25519": an Ed25519 key; bits is ignored.
//
// Any other kind, or an unsupported curve size, yields an error. On error the
// returned key is always a nil interface value.
func GeneratePrivateKey(kind string, bits int) (crypto.PrivateKey, error) {
	switch strings.ToLower(kind) {
	case TypeRSA, "":
		if bits == 0 {
			bits = 2048
		}
		log.Trace().Int("bits", bits).Str("type", TypeRSA).Msg("generating private key")
		key, err := rsa.GenerateKey(rand.Reader, bits)
		if err != nil {
			// Return a literal nil: a nil *rsa.PrivateKey stored in the
			// crypto.PrivateKey interface would compare != nil.
			return nil, err
		}
		return key, nil
	case TypeECDSA, "ec", "ecc":
		if bits == 0 {
			bits = 256
		}
		var curve elliptic.Curve
		switch bits {
		case 224:
			curve = elliptic.P224()
		case 256:
			curve = elliptic.P256()
		case 384:
			curve = elliptic.P384()
		case 521:
			curve = elliptic.P521()
		default:
			return nil, fmt.Errorf("mitm: elliptic curve %d bits not supported", bits)
		}
		log.Trace().Int("bits", bits).Str("type", TypeECDSA).Msg("generating private key")
		key, err := ecdsa.GenerateKey(curve, rand.Reader)
		if err != nil {
			return nil, err
		}
		return key, nil
	case TypeED25519:
		log.Trace().Str("type", TypeED25519).Msg("generating ED25519 private key")
		_, key, err := ed25519.GenerateKey(rand.Reader)
		if err != nil {
			return nil, err
		}
		return key, nil
	default:
		return nil, fmt.Errorf("mitm: don't know how to generate %s private key", kind)
	}
}
// decodePEMBCertificate walks the PEM blocks in b and parses the first block
// of type CERTIFICATE, skipping blocks of any other type. It returns an error
// when b holds no CERTIFICATE block at all, or the x509 parse error when that
// first certificate is malformed (later certificates are not examined).
func decodePEMBCertificate(b []byte) (cert *x509.Certificate, err error) {
	for block, rest := pem.Decode(b); block != nil; block, rest = pem.Decode(rest) {
		if block.Type != "CERTIFICATE" {
			continue
		}
		return x509.ParseCertificate(block.Bytes)
	}
	return nil, errors.New("mitm: no CERTIFICATE PEM block could be decoded")
}
// LoadCertificate reads the file called name and returns the first X.509
// certificate found in its PEM contents. Errors from reading the file or from
// decoding the certificate are returned unchanged.
func LoadCertificate(name string) (*x509.Certificate, error) {
	pemData, readErr := os.ReadFile(name)
	if readErr == nil {
		return decodePEMBCertificate(pemData)
	}
	return nil, readErr
}
// roundToDay returns midnight UTC at the start of the UTC calendar day that
// contains t. The result is never later than t.
//
// t is converted to UTC before its date is taken. Using the date from t's own
// location and then building a UTC time would be wrong for zones east of UTC:
// 01:00 at UTC+2 is 23:00 UTC on the previous day. That approach would return
// a time up to the zone offset after t, e.g. a certificate NotBefore in the future.
func roundToDay(t time.Time) time.Time {
	year, month, day := t.UTC().Date()
	return time.Date(year, month, day, 0, 0, 0, 0, time.UTC)
}