package cryptutil import ( "bufio" "bytes" "crypto/tls" "crypto/x509" "encoding/pem" "io" "net" "os" "slices" "strings" "git.maze.io/maze/styx/internal/netutil" "git.maze.io/maze/styx/internal/sliceutil" "git.maze.io/maze/styx/logger" ) var ( supportedCipherSuites = tls.CipherSuites() supportedCipherSuite = make(map[uint16]bool) supportedVersions = []uint16{ tls.VersionTLS13, tls.VersionTLS12, tls.VersionTLS11, tls.VersionTLS10, } ) func init() { for _, suite := range supportedCipherSuites { supportedCipherSuite[suite.ID] = true } } func DecodeTLSCertificate(b []byte) (tls.Certificate, error) { var ( cert tls.Certificate chain []*x509.Certificate rest = b block *pem.Block err error ) for { if block, rest = pem.Decode(rest); block == nil { break } switch block.Type { case "CERTIFICATE": cert.Certificate = append(cert.Certificate, block.Bytes) if x509Cert, err := x509.ParseCertificate(block.Bytes); err != nil { return tls.Certificate{}, err } else { chain = append(chain, x509Cert) cert.Leaf = x509Cert } case "PRIVATE KEY": if cert.PrivateKey, err = x509.ParsePKCS8PrivateKey(block.Bytes); err != nil { return tls.Certificate{}, err } case "RSA PRIVATE KEY": if cert.PrivateKey, err = x509.ParsePKCS1PrivateKey(block.Bytes); err != nil { return tls.Certificate{}, err } case "EC PRIVATE KEY": if cert.PrivateKey, err = x509.ParseECPrivateKey(block.Bytes); err != nil { return tls.Certificate{}, err } } } return cert, nil } func LoadTLSCertificate(certFile, keyFile string) (tls.Certificate, error) { var ( b []byte err error ) if strings.Contains(certFile, "-----BEGIN") { logger.StandardLog.Trace("Loading X.509 certificate") b = []byte(certFile) } else { logger.StandardLog.Value("name", certFile).Trace("Loading X.509 certificate") if b, err = os.ReadFile(certFile); err != nil { return tls.Certificate{}, err } } if strings.Contains(keyFile, "-----BEGIN") { logger.StandardLog.Trace("Loading private key") b = append(b, []byte(keyFile)...) 
} else if keyFile != "" { logger.StandardLog.Value("name", keyFile).Trace("Loading private key") var k []byte if k, err = os.ReadFile(keyFile); err != nil { return tls.Certificate{}, err } b = append(b, k...) } return DecodeTLSCertificate(b) } // CheckTLSBuffer is like [CheckTLSHandshake] but restores the original buffered reader. func CheckTLSBuffer(r *bufio.Reader) (bool, error) { b, err := r.ReadByte() if err != nil { return false, err } if err = r.UnreadByte(); err != nil { return false, err } return b == 0x16, nil } // CheckTLSHandshake checks if the next byte available in r looks like a TLS handshake. func CheckTLSHandshake(r io.Reader) (bool, error) { // Peek first byte received in tunneled connection, client initiates the TLS connection or plain HTTP request b := make([]byte, 1) if _, err := io.ReadFull(r, b); err != nil { return false, err } // TLS handshake: https://tools.ietf.org/html/rfc5246#section-6.2.1 return b[0] == 0x16, nil } // SniffClientHello uses ReadClientHello to sniff the TLS handshake and returns a new [net.Conn] that // contains the original byte sequences. func SniffClientHello(c net.Conn) (net.Conn, *tls.ClientHelloInfo, error) { b := new(bytes.Buffer) h, err := ReadClientHello(io.TeeReader(c, b)) return netutil.ReaderConn{ Conn: c, Reader: io.MultiReader(b, c), }, h, err } // ReadClientHello reads a TLS client hello message from the TLS handshake. func ReadClientHello(r io.Reader) (*tls.ClientHelloInfo, error) { var hello *tls.ClientHelloInfo err := tls.Server(netutil.ReadOnlyConn{Reader: r}, &tls.Config{ GetConfigForClient: func(clientHello *tls.ClientHelloInfo) (*tls.Config, error) { hello = new(tls.ClientHelloInfo) *hello = *clientHello return nil, nil }, }).Handshake() if hello == nil { return nil, err } return hello, nil } // IsSupportedCipherSuite checks if Go can support the cipher suite. 
func IsSupportedCipherSuite(id uint16) bool {
	supported := supportedCipherSuite[id]
	return supported
}

// IsSupportedVersion reports whether version is one of the TLS protocol
// versions listed in supportedVersions.
func IsSupportedVersion(version uint16) bool {
	return slices.Index(supportedVersions, version) >= 0
}

// OnlySecureCipherSuites returns ids with every cipher suite removed for which
// IsSupportedCipherSuite reports false; filtering is done by sliceutil.Filter.
func OnlySecureCipherSuites(ids []uint16) []uint16 {
	keep := IsSupportedCipherSuite
	return sliceutil.Filter(ids, keep)
}