Skip to content

Commit

Permalink
CHERRY-PICK: Refactoring: never assign unacceptable TLS versions
Browse files Browse the repository at this point in the history
This commit makes security linting easier by never setting a TLS version
outside v1.2 or v1.3, even in case of an unacceptable user input.

Upstream PR: kubernetes-sigs#2037
(cherry picked from commit 27526d5)
  • Loading branch information
pierreprinetti authored and mdbooth committed Apr 25, 2024
1 parent 9753c5c commit 49c1705
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 33 deletions.
36 changes: 13 additions & 23 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -372,14 +372,19 @@ func concurrency(c int) controller.Options {
func GetTLSOptionOverrideFuncs(options TLSOptions) ([]func(*tls.Config), error) {
var tlsOptions []func(config *tls.Config)

tlsMinVersion, err := GetTLSVersion(options.TLSMinVersion)
if err != nil {
return nil, err
}

tlsMaxVersion, err := GetTLSVersion(options.TLSMaxVersion)
if err != nil {
return nil, err
// To make a static analyzer happy, this block ensures there is no code
// path that sets a TLS version outside the acceptable values, even in
// case of unexpected user input.
var tlsMinVersion, tlsMaxVersion uint16
for version, option := range map[*uint16]string{&tlsMinVersion: options.TLSMinVersion, &tlsMaxVersion: options.TLSMaxVersion} {
switch option {
case TLSVersion12:
*version = tls.VersionTLS12
case TLSVersion13:
*version = tls.VersionTLS13
default:
return nil, fmt.Errorf("unexpected TLS version %q (must be one of: %s)", option, strings.Join(tlsSupportedVersions, ", "))
}
}

if tlsMaxVersion != 0 && tlsMinVersion > tlsMaxVersion {
Expand Down Expand Up @@ -421,18 +426,3 @@ func GetTLSOptionOverrideFuncs(options TLSOptions) ([]func(*tls.Config), error)

return tlsOptions, nil
}

// GetTLSVersion returns the corresponding tls.Version or error.
func GetTLSVersion(version string) (uint16, error) {
var v uint16

switch version {
case TLSVersion12:
v = tls.VersionTLS12
case TLSVersion13:
v = tls.VersionTLS13
default:
return 0, fmt.Errorf("unexpected TLS version %q (must be one of: %s)", version, strings.Join(tlsSupportedVersions, ", "))
}
return v, nil
}
55 changes: 45 additions & 10 deletions main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package main

import (
"bytes"
"crypto/tls"
"testing"

. "github.com/onsi/gomega"
Expand Down Expand Up @@ -75,25 +76,59 @@ func Test13CipherSuite(t *testing.T) {
klog.SetOutput(bufWriter)
klog.LogToStderr(false) // this is important, because klog by default logs to stderr only
_, err := GetTLSOptionOverrideFuncs(tlsMockOptions)
g.Expect(bufWriter.String()).Should(ContainSubstring("warning: Cipher suites should not be set for TLS version 1.3. Ignoring ciphers"))
g.Expect(err).Should(BeNil())
g.Expect(bufWriter.String()).Should(ContainSubstring("warning: Cipher suites should not be set for TLS version 1.3. Ignoring ciphers"))
})
}

func TestGetTLSVersion(t *testing.T) {
t.Run("should error out when incorrect tls version passed", func(t *testing.T) {
func TestGetTLSOverrideFuncs(t *testing.T) {
t.Run("should error out when incorrect min tls version passed", func(t *testing.T) {
g := NewWithT(t)
_, err := GetTLSOptionOverrideFuncs(TLSOptions{
TLSMinVersion: "TLS11",
TLSMaxVersion: "TLS12",
})
g.Expect(err.Error()).Should(Equal("unexpected TLS version \"TLS11\" (must be one of: TLS12, TLS13)"))
})
t.Run("should error out when incorrect max tls version passed", func(t *testing.T) {
g := NewWithT(t)
tlsVersion := "TLS11"
_, err := GetTLSVersion(tlsVersion)
_, err := GetTLSOptionOverrideFuncs(TLSOptions{
TLSMinVersion: "TLS12",
TLSMaxVersion: "TLS11",
})
g.Expect(err.Error()).Should(Equal("unexpected TLS version \"TLS11\" (must be one of: TLS12, TLS13)"))
})
t.Run("should pass and output correct tls version", func(t *testing.T) {
const VersionTLS12 uint16 = 771
t.Run("should apply the requested TLS versions", func(t *testing.T) {
g := NewWithT(t)
tlsOptionOverrides, err := GetTLSOptionOverrideFuncs(TLSOptions{
TLSMinVersion: "TLS12",
TLSMaxVersion: "TLS13",
})

var tlsConfig tls.Config
for _, apply := range tlsOptionOverrides {
apply(&tlsConfig)
}

g.Expect(err).Should(BeNil())
g.Expect(tlsConfig.MinVersion).To(Equal(uint16(tls.VersionTLS12)))
g.Expect(tlsConfig.MaxVersion).To(Equal(uint16(tls.VersionTLS13)))
})
t.Run("should apply the requested non-default TLS versions", func(t *testing.T) {
g := NewWithT(t)
tlsVersion := "TLS12"
version, err := GetTLSVersion(tlsVersion)
g.Expect(version).To(Equal(VersionTLS12))
tlsOptionOverrides, err := GetTLSOptionOverrideFuncs(TLSOptions{
TLSMinVersion: "TLS13",
TLSMaxVersion: "TLS13",
})

var tlsConfig tls.Config
for _, apply := range tlsOptionOverrides {
apply(&tlsConfig)
}

g.Expect(err).Should(BeNil())
g.Expect(tlsConfig.MinVersion).To(Equal(uint16(tls.VersionTLS13)))
g.Expect(tlsConfig.MaxVersion).To(Equal(uint16(tls.VersionTLS13)))
})
}

Expand Down

0 comments on commit 49c1705

Please sign in to comment.