package cryptutil

import (
	"crypto"
	"crypto/ecdsa"
	"crypto/ed25519"
	"crypto/elliptic"
	"crypto/rand"
	"crypto/rsa"
	"crypto/x509"
	"crypto/x509/pkix"
	"encoding/pem"
	"errors"
	"fmt"
	"math/big"
	"os"
	"path/filepath"
	"strings"
	"time"

	"git.maze.io/maze/styx/logger"
	"github.com/rs/zerolog/log"
)

// Supported key types.
const (
	TypeRSA     = "rsa"
	TypeECDSA   = "ecdsa"
	TypeED25519 = "ed25519"
)

// Supported PEM block types.
const (
	pemTypeCert  = "CERTIFICATE"
	pemTypeRSA   = "RSA PRIVATE KEY"
	pemTypeECDSA = "EC PRIVATE KEY"
	pemTypeAny   = "PRIVATE KEY"
)

// DecodeRoots loads all PEM encoded certificates from b.
//
// Only blocks of type CERTIFICATE are considered, and only certificates
// flagged as a certificate authority are added to the returned pool. A
// certificate that fails to parse aborts decoding with the parse error.
func DecodeRoots(b []byte) (*x509.CertPool, error) {
	pool := x509.NewCertPool()
	remaining := b
	for {
		block, rest := pem.Decode(remaining)
		if block == nil {
			break
		}
		remaining = rest

		if block.Type != pemTypeCert {
			continue
		}
		cert, err := x509.ParseCertificate(block.Bytes)
		if err != nil {
			return nil, err
		}
		if cert.IsCA {
			pool.AddCert(cert)
		}
	}
	return pool, nil
}

// LoadRoots loads a certificate authority bundle.
//
// The roots argument is one of:
//   - inline PEM data (any string containing "-----BEGIN CERTIFICATE"),
//   - the path of a PEM encoded file, or
//   - the path of a directory, in which case every *.crt and *.pem file in
//     that directory is loaded.
//
// The collected data is handed to [DecodeRoots].
func LoadRoots(roots string) (*x509.CertPool, error) {
	if strings.Contains(roots, "-----BEGIN CERTIFICATE") {
		logger.StandardLog.Trace("Parsing X.509 certificates")
		return DecodeRoots([]byte(roots))
	}

	info, err := os.Stat(roots)
	if err != nil {
		return nil, err
	}

	if !info.IsDir() {
		logger.StandardLog.Value("path", roots).Trace("Loading X.509 certificates")
		data, err := os.ReadFile(roots)
		if err != nil {
			return nil, err
		}
		return DecodeRoots(data)
	}

	logger.StandardLog.Value("path", roots).Trace("Loading X.509 certificates from *.crt *.pem")
	var bundle []byte
	for _, pattern := range []string{"*.crt", "*.pem"} {
		matches, err := filepath.Glob(filepath.Join(roots, pattern))
		if err != nil {
			return nil, err
		}
		for _, name := range matches {
			data, err := os.ReadFile(name)
			if err != nil {
				return nil, err
			}
			bundle = append(bundle, data...)
			// Keep files apart: without a line break the next file's BEGIN
			// marker lands on the previous END line, which the PEM decoder
			// rejects, silently dropping certificates.
			if n := len(data); n > 0 && data[n-1] != '\n' {
				bundle = append(bundle, '\n')
			}
		}
	}
	return DecodeRoots(bundle)
}

// LoadKeyPair loads a certificate and private key, certdata and keydata can be a PEM encoded block or a file.
// // If [keydata] is empty, then the private key is assumed to be contained in [certdata]. func LoadKeyPair(certdata, keydata string) (cert *x509.Certificate, key crypto.PrivateKey, err error) { if keydata == "" { keydata = certdata } if strings.Contains(certdata, "-----BEGIN "+pemTypeCert) { logger.StandardLog.Trace("Parsing X.509 certificate") if cert, err = decodePEMBCertificate([]byte(certdata)); err != nil { return } } else { logger.StandardLog.Value("name", certdata).Trace("Loading X.509 certificate") if cert, err = LoadCertificate(certdata); err != nil { return } } if strings.Contains(keydata, pemTypeAny+"-----") { logger.StandardLog.Trace("Parsing private key") if key, err = decodePEMPrivateKey([]byte(keydata)); err != nil { return } } else if key, err = LoadPrivateKey(keydata); err != nil { logger.StandardLog.Value("name", keydata).Trace("Loading private key") return } return } // SaveKeyPair saves a certificate and private key in PEM encoding. // // If [keyFile] is empty, then the private key is stored in [certFile] alongside the certificate. // // Attempts are made to use secure file modes for files that contains private keys. 
func SaveKeyPair(cert *x509.Certificate, key crypto.PrivateKey, certFile, keyFile string) (err error) { var ( keyDER []byte keyPEMType = pemTypeAny ) switch key := key.(type) { case *ecdsa.PrivateKey: if keyDER, err = x509.MarshalECPrivateKey(key); err != nil { return } keyPEMType = pemTypeECDSA case ed25519.PrivateKey: if keyDER, err = x509.MarshalPKCS8PrivateKey(key); err != nil { return } case *rsa.PrivateKey: keyDER = x509.MarshalPKCS1PrivateKey(key) keyPEMType = pemTypeRSA default: return fmt.Errorf("mitm: don't know how to marshal %T", key) } var certf, keyf *os.File if certf, err = os.OpenFile(certFile, os.O_CREATE|os.O_TRUNC|os.O_RDWR, 0o644); err != nil { return } defer func() { _ = certf.Close() }() if filepath.Clean(certFile) == filepath.Clean(keyFile) || keyFile == "" { if err = certf.Chmod(0o600); err != nil { return } keyf, keyFile = certf, certFile } else { if keyf, err = os.OpenFile(keyFile, os.O_CREATE|os.O_TRUNC|os.O_RDWR, 0o600); err != nil { return } defer func() { _ = keyf.Close() }() } log.Debug().Str("file", certFile).Msg("saving X.509 certificate") if err = pem.Encode(certf, &pem.Block{Type: pemTypeCert, Bytes: cert.Raw}); err != nil { return } log.Debug().Str("fiile", keyFile).Msg("saving private key") if err = pem.Encode(keyf, &pem.Block{Type: keyPEMType, Bytes: keyDER}); err != nil { return } return } // GenerateKeyPair generates a private key and self-signed certificate. 
func GenerateKeyPair(name pkix.Name, days int, keyType string, keyBits int) (cert *x509.Certificate, key crypto.PrivateKey, err error) { if key, err = GeneratePrivateKey(keyType, keyBits); err != nil { return } if cert, err = GenerateCertificateAuthority(name, days, key); err != nil { return } return } func GenerateCertificateAuthority(name pkix.Name, days int, key crypto.PrivateKey) (cert *x509.Certificate, err error) { serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) if err != nil { return nil, fmt.Errorf("mtim: failed to generate serial number: %w", err) } keyUsage := x509.KeyUsageCertSign if _, ok := key.(*rsa.PrivateKey); ok { keyUsage |= x509.KeyUsageDigitalSignature } notBefore := roundToDay(time.Now()) notAfter := notBefore.Add(time.Duration(days) * 24 * time.Hour) template := &x509.Certificate{ Subject: name, SerialNumber: serialNumber, KeyUsage: keyUsage, SubjectKeyId: GenerateKeyID(key), IsCA: true, BasicConstraintsValid: true, NotBefore: notBefore, NotAfter: notAfter, } log.Info(). Str("name", name.CommonName). Int("days", days). Str("key", keyType(key)). Str("serial", serialNumber.String()). 
Msg("generating self-signed CA certificate") var der []byte if der, err = x509.CreateCertificate(rand.Reader, template, template, PublicKey(key), key); err != nil { return } return x509.ParseCertificate(der) } func GeneratePrivateKey(kind string, bits int) (key crypto.PrivateKey, err error) { switch strings.ToLower(kind) { case TypeRSA, "": if bits == 0 { bits = 2048 } log.Trace().Int("bits", bits).Str("type", TypeRSA).Msg("generating private key") return rsa.GenerateKey(rand.Reader, bits) case TypeECDSA, "ec", "ecc": if bits == 0 { bits = 256 } var curve elliptic.Curve switch bits { case 224: curve = elliptic.P224() case 256: curve = elliptic.P256() case 384: curve = elliptic.P384() case 521: curve = elliptic.P521() default: return nil, fmt.Errorf("mitm: elliptic curve %d bits not supported", bits) } log.Trace().Int("bits", bits).Str("type", TypeECDSA).Msg("generating private key") return ecdsa.GenerateKey(curve, rand.Reader) case TypeED25519: log.Trace().Str("type", TypeED25519).Msg("generating ED25519 private key") _, key, err = ed25519.GenerateKey(rand.Reader) return default: return nil, fmt.Errorf("mitm: don't know how to generate %s private key", kind) } } func decodePEMBCertificate(b []byte) (cert *x509.Certificate, err error) { var ( rest = b block *pem.Block ) for { if block, rest = pem.Decode(rest); block == nil { return nil, errors.New("mitm: no CERTIFICATE PEM block could be decoded") } else if block.Type == "CERTIFICATE" { return x509.ParseCertificate(block.Bytes) } } } func LoadCertificate(name string) (*x509.Certificate, error) { b, err := os.ReadFile(name) if err != nil { return nil, err } return decodePEMBCertificate(b) } func roundToDay(t time.Time) time.Time { return time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, time.UTC) }