From 49c17051bcb8eae4b6fada28111ac4c47e4a4067 Mon Sep 17 00:00:00 2001 From: Pierre Prinetti Date: Thu, 25 Apr 2024 15:14:29 +0200 Subject: [PATCH] CHERRY-PICK: Refactoring: never assign unacceptable TLS versions This commit makes security linting easier by never setting a TLS version outside v1.2 or v1.3, even in case of an unacceptable user input. Upstream PR: https://github.com/kubernetes-sigs/cluster-api-provider-openstack/pull/2037 (cherry picked from commit 27526d5f37d843c6b9ba15d302ec026c0e7da227) --- main.go | 36 +++++++++++++--------------------- main_test.go | 55 ++++++++++++++++++++++++++++++++++++++++++---------- 2 files changed, 58 insertions(+), 33 deletions(-) diff --git a/main.go b/main.go index c170553eeb..4cb02726bc 100644 --- a/main.go +++ b/main.go @@ -372,14 +372,19 @@ func concurrency(c int) controller.Options { func GetTLSOptionOverrideFuncs(options TLSOptions) ([]func(*tls.Config), error) { var tlsOptions []func(config *tls.Config) - tlsMinVersion, err := GetTLSVersion(options.TLSMinVersion) - if err != nil { - return nil, err - } - - tlsMaxVersion, err := GetTLSVersion(options.TLSMaxVersion) - if err != nil { - return nil, err + // To make a static analyzer happy, this block ensures there is no code + // path that sets a TLS version outside the acceptable values, even in + // case of unexpected user input. + var tlsMinVersion, tlsMaxVersion uint16 + for version, option := range map[*uint16]string{&tlsMinVersion: options.TLSMinVersion, &tlsMaxVersion: options.TLSMaxVersion} { + switch option { + case TLSVersion12: + *version = tls.VersionTLS12 + case TLSVersion13: + *version = tls.VersionTLS13 + default: + return nil, fmt.Errorf("unexpected TLS version %q (must be one of: %s)", option, strings.Join(tlsSupportedVersions, ", ")) + } } if tlsMaxVersion != 0 && tlsMinVersion > tlsMaxVersion { @@ -421,18 +426,3 @@ func GetTLSOptionOverrideFuncs(options TLSOptions) ([]func(*tls.Config), error) return tlsOptions, nil } - -// GetTLSVersion returns the corresponding tls.Version or error. -func GetTLSVersion(version string) (uint16, error) { - var v uint16 - - switch version { - case TLSVersion12: - v = tls.VersionTLS12 - case TLSVersion13: - v = tls.VersionTLS13 - default: - return 0, fmt.Errorf("unexpected TLS version %q (must be one of: %s)", version, strings.Join(tlsSupportedVersions, ", ")) - } - return v, nil -} diff --git a/main_test.go b/main_test.go index b4139280c8..fd1281f179 100644 --- a/main_test.go +++ b/main_test.go @@ -18,6 +18,7 @@ package main import ( "bytes" + "crypto/tls" "testing" . "github.com/onsi/gomega" @@ -75,25 +76,59 @@ func Test13CipherSuite(t *testing.T) { klog.SetOutput(bufWriter) klog.LogToStderr(false) // this is important, because klog by default logs to stderr only _, err := GetTLSOptionOverrideFuncs(tlsMockOptions) - g.Expect(bufWriter.String()).Should(ContainSubstring("warning: Cipher suites should not be set for TLS version 1.3. Ignoring ciphers")) g.Expect(err).Should(BeNil()) + g.Expect(bufWriter.String()).Should(ContainSubstring("warning: Cipher suites should not be set for TLS version 1.3. Ignoring ciphers")) }) } -func TestGetTLSVersion(t *testing.T) { - t.Run("should error out when incorrect tls version passed", func(t *testing.T) { +func TestGetTLSOverrideFuncs(t *testing.T) { + t.Run("should error out when incorrect min tls version passed", func(t *testing.T) { + g := NewWithT(t) + _, err := GetTLSOptionOverrideFuncs(TLSOptions{ + TLSMinVersion: "TLS11", + TLSMaxVersion: "TLS12", + }) + g.Expect(err.Error()).Should(Equal("unexpected TLS version \"TLS11\" (must be one of: TLS12, TLS13)")) + }) + t.Run("should error out when incorrect max tls version passed", func(t *testing.T) { g := NewWithT(t) - tlsVersion := "TLS11" - _, err := GetTLSVersion(tlsVersion) + _, err := GetTLSOptionOverrideFuncs(TLSOptions{ + TLSMinVersion: "TLS12", + TLSMaxVersion: "TLS11", + }) g.Expect(err.Error()).Should(Equal("unexpected TLS version \"TLS11\" (must be one of: TLS12, TLS13)")) }) - t.Run("should pass and output correct tls version", func(t *testing.T) { - const VersionTLS12 uint16 = 771 + t.Run("should apply the requested TLS versions", func(t *testing.T) { + g := NewWithT(t) + tlsOptionOverrides, err := GetTLSOptionOverrideFuncs(TLSOptions{ + TLSMinVersion: "TLS12", + TLSMaxVersion: "TLS13", + }) + + var tlsConfig tls.Config + for _, apply := range tlsOptionOverrides { + apply(&tlsConfig) + } + + g.Expect(err).Should(BeNil()) + g.Expect(tlsConfig.MinVersion).To(Equal(uint16(tls.VersionTLS12))) + g.Expect(tlsConfig.MaxVersion).To(Equal(uint16(tls.VersionTLS13))) + }) + t.Run("should apply the requested non-default TLS versions", func(t *testing.T) { g := NewWithT(t) - tlsVersion := "TLS12" - version, err := GetTLSVersion(tlsVersion) - g.Expect(version).To(Equal(VersionTLS12)) + tlsOptionOverrides, err := GetTLSOptionOverrideFuncs(TLSOptions{ + TLSMinVersion: "TLS13", + TLSMaxVersion: "TLS13", + }) + + var tlsConfig tls.Config + for _, apply := range tlsOptionOverrides { + apply(&tlsConfig) + } + g.Expect(err).Should(BeNil()) + g.Expect(tlsConfig.MinVersion).To(Equal(uint16(tls.VersionTLS13))) + g.Expect(tlsConfig.MaxVersion).To(Equal(uint16(tls.VersionTLS13))) }) }