diff --git a/pkg/crypto/missing_san.go b/pkg/crypto/missing_san.go new file mode 100644 index 0000000000..bf17008a7f --- /dev/null +++ b/pkg/crypto/missing_san.go @@ -0,0 +1,40 @@ +package crypto + +import ( + "crypto/x509" + "errors" + "strings" +) + +// CertHasSAN returns true if the given certificate includes a SAN field, else false. +func CertHasSAN(c *x509.Certificate) bool { + if c == nil { + return false + } + + sanOID := []int{2, 5, 29, 17} + + for i := range c.Extensions { + if c.Extensions[i].Id.Equal(sanOID) { + return true + } + } + return false +} + +// IsHostnameError returns true if the error indicates a host name error about legacy CN fields +// else false as a result of `x509.Certificate#VerifyHostname`. +// +// For Golang <1.17: If GODEBUG=x509ignoreCN=0 is set this will always return false. +// In this case, use `crypto.CertHasSAN` to assert validity of the certificate directly. +// +// See https://github.com/golang/go/blob/go1.16.12/src/crypto/x509/verify.go#L119 +func IsHostnameError(err error) bool { + if err != nil && + errors.As(err, &x509.HostnameError{}) && + strings.Contains(err.Error(), "x509: certificate relies on legacy Common Name field") { + return true + } + + return false +} diff --git a/pkg/crypto/missing_san_test.go b/pkg/crypto/missing_san_test.go new file mode 100644 index 0000000000..b816b03d93 --- /dev/null +++ b/pkg/crypto/missing_san_test.go @@ -0,0 +1,104 @@ +package crypto + +import ( + "crypto/x509" + "crypto/x509/pkix" + "encoding/asn1" + "errors" + "math/big" + "testing" +) + +func TestCertHasSAN(t *testing.T) { + for _, tc := range []struct { + name string + cert *x509.Certificate + expectedHasSAN bool + }{ + { + name: "nil cert", + expectedHasSAN: false, + }, + { + name: "sole identifier", + cert: &x509.Certificate{ + Extensions: []pkix.Extension{ + {Id: asn1.ObjectIdentifier{2, 5, 29, 17}}, + }, + }, + expectedHasSAN: true, + }, + { + name: "last identifier", + cert: &x509.Certificate{ + Extensions: []pkix.Extension{ + {Id: asn1.ObjectIdentifier{1}}, + {Id: asn1.ObjectIdentifier{2, 5, 29, 17}}, + }, + }, + expectedHasSAN: true, + }, + { + name: "first identifier", + cert: &x509.Certificate{ + Extensions: []pkix.Extension{ + {Id: asn1.ObjectIdentifier{2, 5, 29, 17}}, + {Id: asn1.ObjectIdentifier{1}}, + }, + }, + expectedHasSAN: true, + }, + { + name: "no identifier", + cert: &x509.Certificate{ + Extensions: []pkix.Extension{ + {Id: asn1.ObjectIdentifier{1}}, + {Id: asn1.ObjectIdentifier{2}}, + }, + }, + expectedHasSAN: false, + }, + } { + t.Run(tc.name, func(t *testing.T) { + if got := CertHasSAN(tc.cert); got != tc.expectedHasSAN { + t.Errorf("expected result %t, got %t", tc.expectedHasSAN, got) + } + }) + } +} + +func TestIsHostnameError(t *testing.T) { + for _, tc := range []struct { + name string + err error + expected bool + }{ + { + name: "invalid hostname error", + err: x509.HostnameError{ + Certificate: &x509.Certificate{ + Subject: pkix.Name{CommonName: "foo.bar"}, + SerialNumber: big.NewInt(1), + }, + Host: "foo.bar", + }, + expected: true, + }, + { + name: "other error", + err: errors.New("boom"), + expected: false, + }, + { + name: "nil error", + err: nil, + expected: false, + }, + } { + t.Run(tc.name, func(t *testing.T) { + if got := IsHostnameError(tc.err); got != tc.expected { + t.Errorf("expected %t, got %t", tc.expected, got) + } + }) + } +} diff --git a/pkg/transport/missing_san_roundtripper.go b/pkg/transport/missing_san_roundtripper.go new file mode 100644 index 0000000000..8e6fad4010 --- /dev/null +++ b/pkg/transport/missing_san_roundtripper.go @@ -0,0 +1,43 @@ +package transport + +import ( + "net/http" + + "github.com/openshift/library-go/pkg/crypto" +) + +type roundTripperFunc func(req *http.Request) (*http.Response, error) + +func (rt roundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) { + return rt(r) +} + +type counter interface { + Inc() +} + +// NewMissingSANRoundTripper returns a round tripper that increases the given counter +// if the wrapped roundtripper returns an error indicating usage of legacy CN fields +// or if the wrapped roundtripper response leaf certificate includes no SAN field. +// +// The counter is compatible both with native Prometheus and k8s metrics types. +func NewMissingSANRoundTripper(rt http.RoundTripper, c counter) http.RoundTripper { + return roundTripperFunc(func(req *http.Request) (resp *http.Response, err error) { + resp, err = rt.RoundTrip(req) + if crypto.IsHostnameError(err) { + c.Inc() + return + } + + if resp == nil || resp.TLS == nil || len(resp.TLS.PeerCertificates) == 0 { + return + } + + // The first element is the leaf certificate. + if !crypto.CertHasSAN(resp.TLS.PeerCertificates[0]) { + c.Inc() + } + + return + }) +} diff --git a/pkg/transport/missing_san_roundtripper_test.go b/pkg/transport/missing_san_roundtripper_test.go new file mode 100644 index 0000000000..465324f151 --- /dev/null +++ b/pkg/transport/missing_san_roundtripper_test.go @@ -0,0 +1,128 @@ +package transport + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "math/big" + "net/http" + "testing" +) + +type mockCounter int + +func (c *mockCounter) Inc() { + *c = *c + 1 +} + +func TestMissingSANRoundTripper(t *testing.T) { + for _, tc := range []struct { + name string + + resp *http.Response + respErr error + + expectedCounts int + expectedErr string + }{ + { + name: "non tls response", + resp: &http.Response{}, + }, + { + name: "valid cert", + resp: &http.Response{ + TLS: &tls.ConnectionState{ + PeerCertificates: []*x509.Certificate{ + newCert(t, &x509.Certificate{ + Subject: pkix.Name{CommonName: "foo.bar"}, + SerialNumber: big.NewInt(1), + DNSNames: []string{"foo.bar"}, + })}, + }, + }, + expectedCounts: 0, + }, + { + name: "go 1.16: legacy cert verification error", + respErr: x509.HostnameError{ + Certificate: newCert(t, &x509.Certificate{ + Subject: pkix.Name{CommonName: "foo.bar"}, + SerialNumber: big.NewInt(1), + }), + Host: "foo.bar", + }, + expectedErr: "x509: certificate relies on legacy Common Name field, use SANs or temporarily enable Common Name matching with GODEBUG=x509ignoreCN=0", + expectedCounts: 1, + }, + { + name: "go 1.16: invalid cert", + resp: &http.Response{ + TLS: &tls.ConnectionState{ + PeerCertificates: []*x509.Certificate{ + newCert(t, &x509.Certificate{ + Subject: pkix.Name{CommonName: "foo.bar"}, + SerialNumber: big.NewInt(1), + })}, + }, + }, + expectedCounts: 1, + }, + { + name: "invalid hostname", + respErr: x509.HostnameError{ + Certificate: newCert(t, &x509.Certificate{ + Subject: pkix.Name{CommonName: "foo.bar"}, + SerialNumber: big.NewInt(1), + }), + Host: "some.host", + }, + expectedErr: "x509: certificate is not valid for any names, but wanted to match some.host", + expectedCounts: 0, + }, + } { + t.Run(tc.name, func(t *testing.T) { + rt := roundTripperFunc(func(_ *http.Request) (*http.Response, error) { + return tc.resp, tc.respErr + }) + + var ( + cnt mockCounter + gotErr string + ) + _, err := NewMissingSANRoundTripper(rt, &cnt).RoundTrip(nil) + if err != nil { + gotErr = err.Error() + } + + if tc.expectedErr != gotErr { + t.Errorf("expected error %q, got %q", tc.expectedErr, gotErr) + } + + if tc.expectedCounts != int(cnt) { + t.Errorf("expected %v counts, got %v", tc.expectedCounts, int(cnt)) + } + }) + } +} + +func newCert(t *testing.T, template *x509.Certificate) *x509.Certificate { + pk, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatal(err) + } + + certBytes, err := x509.CreateCertificate(rand.Reader, template, template, &pk.PublicKey, pk) + if err != nil { + t.Fatal(err) + } + + certs, err := x509.ParseCertificates(certBytes) + if err != nil { + t.Fatal(err) + } + + return certs[0] +}