Files
styx/internal/cryptutil/key.go
2025-09-26 08:49:53 +02:00

115 lines
2.2 KiB
Go

package cryptutil
import (
"crypto"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/elliptic"
"crypto/rsa"
"crypto/sha1"
"crypto/x509"
"encoding/pem"
"errors"
"fmt"
"os"
)
// publicKeyer is implemented by key types whose Public method returns a
// plain any.
//
// This is NOT the method set of crypto.Signer: crypto.PublicKey is a defined
// type rather than an alias of any, so a key whose method is declared as
// Public() crypto.PublicKey does not satisfy this interface. PublicKey
// checks for both signatures.
type publicKeyer interface {
	Public() any
}

// PublicKey returns the public part of a crypto.PrivateKey.
//
// Public keys (ed25519.PublicKey, ecdsa.PublicKey, rsa.PublicKey and the
// pointer forms of the latter two) are accepted as well; value forms of the
// ECDSA and RSA public keys are returned as pointers to a copy. For any
// other type, the key's Public method is used if it has one, declared with
// either a crypto.PublicKey or an any result.
//
// PublicKey panics if key is of none of the types above.
func PublicKey(key crypto.PrivateKey) any {
	switch key := key.(type) {
	case ed25519.PublicKey:
		return key
	case ed25519.PrivateKey:
		return key.Public()
	case ecdsa.PublicKey:
		return &key
	case *ecdsa.PublicKey:
		return key
	case *ecdsa.PrivateKey:
		return &key.PublicKey
	case rsa.PublicKey:
		return &key
	case *rsa.PublicKey:
		return key
	case *rsa.PrivateKey:
		return &key.PublicKey
	case interface{ Public() crypto.PublicKey }:
		// The crypto.Signer / crypto.Decrypter method signature. Without this
		// case such keys fall through to the panic below, because they do not
		// satisfy publicKeyer.
		return key.Public()
	case publicKeyer:
		return key.Public()
	default:
		panic(fmt.Sprintf("don't know how to extract a public key from %T", key))
	}
}
// LoadPrivateKey reads the file called name and decodes a private key from
// its PEM contents via decodePEMPrivateKey.
//
// A read failure is returned unchanged; otherwise the result of the decode
// is returned as is.
func LoadPrivateKey(name string) (crypto.PrivateKey, error) {
	pemData, readErr := os.ReadFile(name)
	if readErr != nil {
		return nil, readErr
	}

	return decodePEMPrivateKey(pemData)
}
// decodePEMPrivateKey walks the PEM blocks in b in order and parses the first
// one whose type is "EC PRIVATE KEY", "RSA PRIVATE KEY" or "PRIVATE KEY"
// (PKCS #8). Blocks of any other type are skipped.
//
// The parser's result, including its error, is returned directly for the
// first matching block. If the input is exhausted without a matching block,
// an error is returned.
func decodePEMPrivateKey(b []byte) (key crypto.PrivateKey, err error) {
	for blk, tail := pem.Decode(b); blk != nil; blk, tail = pem.Decode(tail) {
		if blk.Type == "EC PRIVATE KEY" {
			return x509.ParseECPrivateKey(blk.Bytes)
		}
		if blk.Type == "RSA PRIVATE KEY" {
			return x509.ParsePKCS1PrivateKey(blk.Bytes)
		}
		if blk.Type == "PRIVATE KEY" {
			return x509.ParsePKCS8PrivateKey(blk.Bytes)
		}
	}

	return nil, errors.New("mitm: no private key PEM block could be decoded")
}
// GenerateKeyID generates the PKIX public key ID: the 20-byte SHA-1 digest of
// the key's PKIX (ASN.1 DER) encoding as produced by
// x509.MarshalPKIXPublicKey.
//
// It returns nil if key cannot be marshaled.
func GenerateKeyID(key crypto.PublicKey) []byte {
	der, err := x509.MarshalPKIXPublicKey(key)
	if err != nil {
		return nil
	}
	// sha1.Sum hashes its argument. The previous sha1.New().Sum(der) did not:
	// hash.Hash.Sum appends the digest of the data written so far (none) to
	// its argument, so the result was the DER bytes plus a constant suffix.
	digest := sha1.Sum(der)
	return digest[:]
}
// keyType returns a short label for a private key: "ed25519", "rsa", or
// "ecdsa (<curve>)" with the curve named by curveType. For any other value
// it returns the value's Go type as formatted by %T.
func keyType(key any) string {
	if _, isEd := key.(ed25519.PrivateKey); isEd {
		return "ed25519"
	}
	if ec, isEC := key.(*ecdsa.PrivateKey); isEC {
		return "ecdsa (" + curveType(ec.Curve) + ")"
	}
	if _, isRSA := key.(*rsa.PrivateKey); isRSA {
		return "rsa"
	}
	return fmt.Sprintf("%T", key)
}
// curveType returns a short lowercase label ("p224", "p256", "p384", "p521")
// for the curves returned by the crypto/elliptic constructors. For any other
// curve it returns the curve's Go type as formatted by %T.
func curveType(c elliptic.Curve) string {
	// Checked in order; the first curve equal to c wins.
	known := []struct {
		curve elliptic.Curve
		label string
	}{
		{elliptic.P224(), "p224"},
		{elliptic.P256(), "p256"},
		{elliptic.P384(), "p384"},
		{elliptic.P521(), "p521"},
	}
	for _, k := range known {
		if c == k.curve {
			return k.label
		}
	}
	return fmt.Sprintf("%T", c)
}