Skip to content

Commit

Permalink
feat: tls strategy and strategy walk
Browse files Browse the repository at this point in the history
Extract TLS certificate wait strategy into a dedicated wait type so it
can be reused.

Implement walk method which can be used to identify wait strategies in
a request chain.

Use embed to simplify wait test loading of certs.
  • Loading branch information
stevenh committed Nov 25, 2024
1 parent 892cb08 commit 95f93fa
Show file tree
Hide file tree
Showing 14 changed files with 521 additions and 29 deletions.
19 changes: 19 additions & 0 deletions docs/features/wait/tls.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# TLS Strategy

TLS Strategy waits for one or more files to exist in the container and uses them
and other details to construct a `tls.Config` which can be used to create secure
connections.

It supports:

- x509 PEM Certificate loaded from a certificate / key file pair.
- Root Certificate Authorities aka RootCAs loaded from PEM encoded files.
- Server name.
- Startup timeout to be used in seconds, default is 60 seconds.
- Poll interval to be used in milliseconds, default is 100 milliseconds.

## Waiting for certificate pair to exist and construct a tls.Config

<!--codeinclude-->
[Waiting for certificate pair to exist and construct a tls.Config](../../../wait/tls_test.go) inside_block:waitForTLSCert
<!--/codeinclude-->
2 changes: 1 addition & 1 deletion wait/file_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import (

const testFilename = "/tmp/file"

var anyContext = mock.AnythingOfType("*context.timerCtx")
var anyContext = mock.MatchedBy(func(_ context.Context) bool { return true })

// newRunningTarget creates a new mockStrategyTarget that is running.
func newRunningTarget() *mockStrategyTarget {
Expand Down
39 changes: 11 additions & 28 deletions wait/http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"context"
"crypto/tls"
"crypto/x509"
_ "embed"
"fmt"
"io"
"log"
Expand All @@ -23,6 +24,9 @@ import (
"github.com/testcontainers/testcontainers-go/wait"
)

//go:embed testdata/root.pem
var caBytes []byte

// https://github.com/testcontainers/testcontainers-go/issues/183
func ExampleHTTPStrategy() {
// waitForHTTPWithDefaultPort {
Expand Down Expand Up @@ -80,7 +84,7 @@ func ExampleHTTPStrategy_WithHeaders() {
tlsconfig := &tls.Config{RootCAs: certpool, ServerName: "testcontainer.go.test"}
req := testcontainers.ContainerRequest{
FromDockerfile: testcontainers.FromDockerfile{
Context: "testdata",
Context: "testdata/http",
},
ExposedPorts: []string{"6443/tcp"},
WaitingFor: wait.ForHTTP("/headers").
Expand Down Expand Up @@ -227,20 +231,13 @@ func ExampleHTTPStrategy_WithBasicAuth() {
}

func TestHTTPStrategyWaitUntilReady(t *testing.T) {
workdir, err := os.Getwd()
require.NoError(t, err)

capath := filepath.Join(workdir, "testdata", "root.pem")
cafile, err := os.ReadFile(capath)
require.NoError(t, err)

certpool := x509.NewCertPool()
require.Truef(t, certpool.AppendCertsFromPEM(cafile), "the ca file isn't valid")
require.Truef(t, certpool.AppendCertsFromPEM(caBytes), "the ca file isn't valid")

tlsconfig := &tls.Config{RootCAs: certpool, ServerName: "testcontainer.go.test"}
dockerReq := testcontainers.ContainerRequest{
FromDockerfile: testcontainers.FromDockerfile{
Context: filepath.Join(workdir, "testdata"),
Context: "testdata/http",
},
ExposedPorts: []string{"6443/tcp"},
WaitingFor: wait.NewHTTPStrategy("/auth-ping").WithTLS(true, tlsconfig).
Expand Down Expand Up @@ -288,20 +285,13 @@ func TestHTTPStrategyWaitUntilReady(t *testing.T) {
}

func TestHTTPStrategyWaitUntilReadyWithQueryString(t *testing.T) {
workdir, err := os.Getwd()
require.NoError(t, err)

capath := filepath.Join(workdir, "testdata", "root.pem")
cafile, err := os.ReadFile(capath)
require.NoError(t, err)

certpool := x509.NewCertPool()
require.Truef(t, certpool.AppendCertsFromPEM(cafile), "the ca file isn't valid")
require.Truef(t, certpool.AppendCertsFromPEM(caBytes), "the ca file isn't valid")

tlsconfig := &tls.Config{RootCAs: certpool, ServerName: "testcontainer.go.test"}
dockerReq := testcontainers.ContainerRequest{
FromDockerfile: testcontainers.FromDockerfile{
Context: filepath.Join(workdir, "testdata"),
Context: "testdata/http",
},

ExposedPorts: []string{"6443/tcp"},
Expand Down Expand Up @@ -348,22 +338,15 @@ func TestHTTPStrategyWaitUntilReadyWithQueryString(t *testing.T) {
}

func TestHTTPStrategyWaitUntilReadyNoBasicAuth(t *testing.T) {
workdir, err := os.Getwd()
require.NoError(t, err)

capath := filepath.Join(workdir, "testdata", "root.pem")
cafile, err := os.ReadFile(capath)
require.NoError(t, err)

certpool := x509.NewCertPool()
require.Truef(t, certpool.AppendCertsFromPEM(cafile), "the ca file isn't valid")
require.Truef(t, certpool.AppendCertsFromPEM(caBytes), "the ca file isn't valid")

// waitForHTTPStatusCode {
tlsconfig := &tls.Config{RootCAs: certpool, ServerName: "testcontainer.go.test"}
var i int
dockerReq := testcontainers.ContainerRequest{
FromDockerfile: testcontainers.FromDockerfile{
Context: filepath.Join(workdir, "testdata"),
Context: "testdata/http",
},
ExposedPorts: []string{"6443/tcp"},
WaitingFor: wait.NewHTTPStrategy("/ping").WithTLS(true, tlsconfig).
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
5 changes: 5 additions & 0 deletions wait/testdata/http/tls-key.pem
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
-----BEGIN EC PRIVATE KEY-----
MHcCAQEEIM8HuDwcZyVqBBy2C6db6zNb/dAJ69bq5ejAEz7qGOIQoAoGCCqGSM49
AwEHoUQDQgAEBL2ioRmfTc70WT0vyx+amSQOGbMeoMRAfF2qaPzpzOqpKTk0aLOG
0735iy9Fz16PX4vqnLMiM/ZupugAhB//yA==
-----END EC PRIVATE KEY-----
12 changes: 12 additions & 0 deletions wait/testdata/http/tls.pem
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
-----BEGIN CERTIFICATE-----
MIIBxTCCAWugAwIBAgIUWBLNpiF1o4r+5ZXwawzPOfBM1F8wCgYIKoZIzj0EAwIw
ADAeFw0yMDA4MTkxMzM4MDBaFw0zMDA4MTcxMzM4MDBaMBkxFzAVBgNVBAMTDnRl
c3Rjb250YWluZXJzMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEBL2ioRmfTc70
WT0vyx+amSQOGbMeoMRAfF2qaPzpzOqpKTk0aLOG0735iy9Fz16PX4vqnLMiM/Zu
pugAhB//yKOBqTCBpjAOBgNVHQ8BAf8EBAMCBaAwEwYDVR0lBAwwCgYIKwYBBQUH
AwEwDAYDVR0TAQH/BAIwADAdBgNVHQ4EFgQUTMdz5PIZ+Gix4jYUzRIHfByrW+Yw
HwYDVR0jBBgwFoAUFdfV6PSYUlHs+lSQNouRwSfR2ZgwMQYDVR0RBCowKIIVdGVz
dGNvbnRhaW5lci5nby50ZXN0gglsb2NhbGhvc3SHBH8AAAEwCgYIKoZIzj0EAwID
SAAwRQIhAJznPNumi2Plf0GsP9DpC+8WukT/jUhnhcDWCfZ6Ini2AiBLhnhFebZX
XWfSsdSNxIo20OWvy6z3wqdybZtRUfdU+g==
-----END CERTIFICATE-----
167 changes: 167 additions & 0 deletions wait/tls.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
package wait

import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"io"
"time"
)

// Validate we implement interface.
var _ Strategy = (*TLSStrategy)(nil)

// TLSStrategy is a strategy for handling TLS.
type TLSStrategy struct {
// General Settings.
timeout *time.Duration
pollInterval time.Duration

// Custom Settings.
certFiles *x509KeyPair
rootFiles []string

// State.
tlsConfig *tls.Config
}

// x509KeyPair is a pair of certificate and key files.
type x509KeyPair struct {
certPEMFile string
keyPEMFile string
}

// ForTLSCert returns a CertStrategy that will add a Certificate to the [tls.Config]
// constructed from PEM formatted certificate key file pair in the container.
func ForTLSCert(certPEMFile, keyPEMFile string) *TLSStrategy {
return &TLSStrategy{
certFiles: &x509KeyPair{
certPEMFile: certPEMFile,
keyPEMFile: keyPEMFile,
},
tlsConfig: &tls.Config{},
pollInterval: defaultPollInterval(),
}
}

// ForTLSRootCAs returns a CertStrategy that sets the root CAs for the [tls.Config]
// using the given PEM formatted files from the container.
func ForTLSRootCAs(pemFiles ...string) *TLSStrategy {
return &TLSStrategy{
rootFiles: pemFiles,
tlsConfig: &tls.Config{},
pollInterval: defaultPollInterval(),
}
}

// WithRootCAs sets the root CAs for the [tls.Config] using the given files from
// the container.
func (ws *TLSStrategy) WithRootCAs(files ...string) *TLSStrategy {
ws.rootFiles = files
return ws
}

// WithCert sets the [tls.Config] Certificates using the given files from the container.
func (ws *TLSStrategy) WithCert(certPEMFile, keyPEMFile string) *TLSStrategy {
ws.certFiles = &x509KeyPair{
certPEMFile: certPEMFile,
keyPEMFile: keyPEMFile,
}
return ws
}

// WithServerName sets the server for the [tls.Config].
func (ws *TLSStrategy) WithServerName(serverName string) *TLSStrategy {
ws.tlsConfig.ServerName = serverName
return ws
}

// WithStartupTimeout can be used to change the default startup timeout.
func (ws *TLSStrategy) WithStartupTimeout(startupTimeout time.Duration) *TLSStrategy {
ws.timeout = &startupTimeout
return ws
}

// WithPollInterval can be used to override the default polling interval of 100 milliseconds.
func (ws *TLSStrategy) WithPollInterval(pollInterval time.Duration) *TLSStrategy {
ws.pollInterval = pollInterval
return ws
}

// TLSConfig returns the TLS config once the strategy is ready.
// If the strategy is nil, it returns nil.
func (ws *TLSStrategy) TLSConfig() *tls.Config {
if ws == nil {
return nil
}

return ws.tlsConfig
}

// WaitUntilReady implements the [Strategy] interface.
// It waits for the CA, client cert and key files to be available in the container and
// uses them to setup the TLS config.
func (ws *TLSStrategy) WaitUntilReady(ctx context.Context, target StrategyTarget) error {
size := len(ws.rootFiles)
if ws.certFiles != nil {
size += 2
}
strategies := make([]Strategy, 0, size)
for _, file := range ws.rootFiles {
strategies = append(strategies,
ForFile(file).WithMatcher(func(r io.Reader) error {
buf, err := io.ReadAll(r)
if err != nil {
return fmt.Errorf("read CA cert file %q: %w", file, err)
}

if ws.tlsConfig.RootCAs == nil {
ws.tlsConfig.RootCAs = x509.NewCertPool()
}

if !ws.tlsConfig.RootCAs.AppendCertsFromPEM(buf) {
return fmt.Errorf("invalid CA cert file %q", file)
}

return nil
}).WithPollInterval(ws.pollInterval),
)
}

if ws.certFiles != nil {
var certPEMBlock []byte
strategies = append(strategies,
ForFile(ws.certFiles.certPEMFile).WithMatcher(func(r io.Reader) error {
var err error
if certPEMBlock, err = io.ReadAll(r); err != nil {
return fmt.Errorf("read certificate cert %q: %w", ws.certFiles.certPEMFile, err)
}

return nil
}).WithPollInterval(ws.pollInterval),
ForFile(ws.certFiles.keyPEMFile).WithMatcher(func(r io.Reader) error {
keyPEMBlock, err := io.ReadAll(r)
if err != nil {
return fmt.Errorf("read certificate key %q: %w", ws.certFiles.keyPEMFile, err)
}

cert, err := tls.X509KeyPair(certPEMBlock, keyPEMBlock)
if err != nil {
return fmt.Errorf("x509 key pair %q %q: %w", ws.certFiles.certPEMFile, ws.certFiles.keyPEMFile, err)
}

ws.tlsConfig.Certificates = []tls.Certificate{cert}

return nil
}).WithPollInterval(ws.pollInterval),
)
}

strategy := ForAll(strategies...)
if ws.timeout != nil {
strategy.WithStartupTimeout(*ws.timeout)
}

return strategy.WaitUntilReady(ctx, target)
}
Loading

0 comments on commit 95f93fa

Please sign in to comment.