package cryptutil

import (
	"crypto"
	"crypto/ecdsa"
	"crypto/ed25519"
	"crypto/elliptic"
	"crypto/rsa"
	"crypto/sha1"
	"crypto/x509"
	"encoding/pem"
	"errors"
	"fmt"
	"os"
)

// publicKeyer is implemented by key types that expose their public half via a
// Public method returning any.
type publicKeyer interface {
	Public() any
}

// PublicKey returns the public part of a crypto.PrivateKey.
//
// Public keys are passed through (ecdsa and rsa values are returned as
// pointers). For private keys the matching public key is returned. Any other
// type must provide a Public method; if it does not, PublicKey panics.
func PublicKey(key crypto.PrivateKey) any {
	switch key := key.(type) {
	case ed25519.PublicKey:
		return key
	case ed25519.PrivateKey:
		return key.Public()
	case ecdsa.PublicKey:
		return &key
	case *ecdsa.PublicKey:
		return key
	case *ecdsa.PrivateKey:
		return &key.PublicKey
	case rsa.PublicKey:
		return &key
	case *rsa.PublicKey:
		return key
	case *rsa.PrivateKey:
		return &key.PublicKey
	case publicKeyer:
		return key.Public()
	case interface{ Public() crypto.PublicKey }:
		// crypto.PublicKey is a defined type distinct from any, so
		// crypto.Signer-style implementations do not satisfy publicKeyer and
		// need their own case.
		return key.Public()
	default:
		panic(fmt.Sprintf("don't know how to extract a public key from %T", key))
	}
}

// LoadPrivateKey loads a private key from disk.
//
// The file called name is read in full and decoded as PEM; see
// decodePEMPrivateKey for the accepted block types.
func LoadPrivateKey(name string) (crypto.PrivateKey, error) {
	b, err := os.ReadFile(name)
	if err != nil {
		return nil, err
	}
	return decodePEMPrivateKey(b)
}

// decodePEMPrivateKey parses the first private key PEM block found in b.
//
// Blocks of type "EC PRIVATE KEY" (SEC 1), "RSA PRIVATE KEY" (PKCS #1) and
// "PRIVATE KEY" (PKCS #8) are recognized; blocks of any other type are
// skipped. An error is returned when no recognized block is present or when
// the first recognized block fails to parse. The returned key is always nil
// when the error is non-nil.
func decodePEMPrivateKey(b []byte) (crypto.PrivateKey, error) {
	var (
		rest  = b
		block *pem.Block
	)
	for {
		if block, rest = pem.Decode(rest); block == nil {
			return nil, errors.New("mitm: no private key PEM block could be decoded")
		}
		switch block.Type {
		case "EC PRIVATE KEY":
			// Return a literal nil on failure: passing the nil
			// *ecdsa.PrivateKey through would yield a non-nil interface.
			ecKey, err := x509.ParseECPrivateKey(block.Bytes)
			if err != nil {
				return nil, err
			}
			return ecKey, nil
		case "RSA PRIVATE KEY":
			// Same typed-nil concern as the EC case above.
			rsaKey, err := x509.ParsePKCS1PrivateKey(block.Bytes)
			if err != nil {
				return nil, err
			}
			return rsaKey, nil
		case "PRIVATE KEY":
			return x509.ParsePKCS8PrivateKey(block.Bytes)
		}
	}
}

// GenerateKeyID generates the PKIX public key ID.
// The ID is the SHA-1 digest of the key's PKIX DER encoding as produced by
// x509.MarshalPKIXPublicKey, so it is sha1.Size (20) bytes long. It returns
// nil when the key cannot be marshaled.
func GenerateKeyID(key crypto.PublicKey) []byte {
	der, err := x509.MarshalPKIXPublicKey(key)
	if err != nil {
		return nil
	}
	// Use sha1.Sum rather than sha1.New().Sum(der): the latter appends the
	// digest of the empty input to der instead of hashing der.
	sum := sha1.Sum(der)
	return sum[:]
}

// keyType returns a short, human-readable name for the algorithm of a private
// key. ECDSA keys include the curve name; unrecognized values are reported by
// their Go type.
func keyType(key any) string {
	switch key := key.(type) {
	case ed25519.PrivateKey:
		return "ed25519"
	case *ecdsa.PrivateKey:
		return "ecdsa (" + curveType(key.Curve) + ")"
	case *rsa.PrivateKey:
		return "rsa"
	default:
		return fmt.Sprintf("%T", key)
	}
}

// curveType returns a short name for the standard NIST curves and falls back
// to the curve's Go type for anything else.
func curveType(c elliptic.Curve) string {
	switch c {
	case elliptic.P224():
		return "p224"
	case elliptic.P256():
		return "p256"
	case elliptic.P384():
		return "p384"
	case elliptic.P521():
		return "p521"
	default:
		return fmt.Sprintf("%T", c)
	}
}