Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(wait): tls strategy #2896

Merged
merged 5 commits into from
Nov 26, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions docs/features/wait/tls.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# TLS Strategy

TLS Strategy waits for one or more files to exist in the container and uses them
and other details to construct a `tls.Config` which can be used to create secure
connections.

It supports:

- x509 PEM Certificate loaded from a certificate / key file pair.
- Root Certificate Authorities aka RootCAs loaded from PEM encoded files.
- Server name.
- Startup timeout to be used in seconds, default is 60 seconds.
- Poll interval to be used in milliseconds, default is 100 milliseconds.

## Waiting for certificate pair to exist and construct a tls.Config

<!--codeinclude-->
[Waiting for certificate pair to exist and construct a tls.Config](../../../wait/tls_test.go) inside_block:waitForTLSCert
<!--/codeinclude-->
2 changes: 1 addition & 1 deletion wait/file_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import (

const testFilename = "/tmp/file"

var anyContext = mock.AnythingOfType("*context.timerCtx")
var anyContext = mock.MatchedBy(func(_ context.Context) bool { return true })

// newRunningTarget creates a new mockStrategyTarget that is running.
func newRunningTarget() *mockStrategyTarget {
Expand Down
39 changes: 11 additions & 28 deletions wait/http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"context"
"crypto/tls"
"crypto/x509"
_ "embed"
"fmt"
"io"
"log"
Expand All @@ -23,6 +24,9 @@ import (
"github.com/testcontainers/testcontainers-go/wait"
)

//go:embed testdata/root.pem
var caBytes []byte

// https://github.com/testcontainers/testcontainers-go/issues/183
func ExampleHTTPStrategy() {
// waitForHTTPWithDefaultPort {
Expand Down Expand Up @@ -80,7 +84,7 @@ func ExampleHTTPStrategy_WithHeaders() {
tlsconfig := &tls.Config{RootCAs: certpool, ServerName: "testcontainer.go.test"}
req := testcontainers.ContainerRequest{
FromDockerfile: testcontainers.FromDockerfile{
Context: "testdata",
Context: "testdata/http",
},
ExposedPorts: []string{"6443/tcp"},
WaitingFor: wait.ForHTTP("/headers").
Expand Down Expand Up @@ -227,20 +231,13 @@ func ExampleHTTPStrategy_WithBasicAuth() {
}

func TestHTTPStrategyWaitUntilReady(t *testing.T) {
workdir, err := os.Getwd()
require.NoError(t, err)

capath := filepath.Join(workdir, "testdata", "root.pem")
cafile, err := os.ReadFile(capath)
require.NoError(t, err)

certpool := x509.NewCertPool()
require.Truef(t, certpool.AppendCertsFromPEM(cafile), "the ca file isn't valid")
require.Truef(t, certpool.AppendCertsFromPEM(caBytes), "the ca file isn't valid")

tlsconfig := &tls.Config{RootCAs: certpool, ServerName: "testcontainer.go.test"}
dockerReq := testcontainers.ContainerRequest{
FromDockerfile: testcontainers.FromDockerfile{
Context: filepath.Join(workdir, "testdata"),
Context: "testdata/http",
},
ExposedPorts: []string{"6443/tcp"},
WaitingFor: wait.NewHTTPStrategy("/auth-ping").WithTLS(true, tlsconfig).
Expand Down Expand Up @@ -288,20 +285,13 @@ func TestHTTPStrategyWaitUntilReady(t *testing.T) {
}

func TestHTTPStrategyWaitUntilReadyWithQueryString(t *testing.T) {
workdir, err := os.Getwd()
require.NoError(t, err)

capath := filepath.Join(workdir, "testdata", "root.pem")
cafile, err := os.ReadFile(capath)
require.NoError(t, err)

certpool := x509.NewCertPool()
require.Truef(t, certpool.AppendCertsFromPEM(cafile), "the ca file isn't valid")
require.Truef(t, certpool.AppendCertsFromPEM(caBytes), "the ca file isn't valid")

tlsconfig := &tls.Config{RootCAs: certpool, ServerName: "testcontainer.go.test"}
dockerReq := testcontainers.ContainerRequest{
FromDockerfile: testcontainers.FromDockerfile{
Context: filepath.Join(workdir, "testdata"),
Context: "testdata/http",
},

ExposedPorts: []string{"6443/tcp"},
Expand Down Expand Up @@ -348,22 +338,15 @@ func TestHTTPStrategyWaitUntilReadyWithQueryString(t *testing.T) {
}

func TestHTTPStrategyWaitUntilReadyNoBasicAuth(t *testing.T) {
workdir, err := os.Getwd()
require.NoError(t, err)

capath := filepath.Join(workdir, "testdata", "root.pem")
cafile, err := os.ReadFile(capath)
require.NoError(t, err)

certpool := x509.NewCertPool()
require.Truef(t, certpool.AppendCertsFromPEM(cafile), "the ca file isn't valid")
require.Truef(t, certpool.AppendCertsFromPEM(caBytes), "the ca file isn't valid")

// waitForHTTPStatusCode {
tlsconfig := &tls.Config{RootCAs: certpool, ServerName: "testcontainer.go.test"}
var i int
dockerReq := testcontainers.ContainerRequest{
FromDockerfile: testcontainers.FromDockerfile{
Context: filepath.Join(workdir, "testdata"),
Context: "testdata/http",
},
ExposedPorts: []string{"6443/tcp"},
WaitingFor: wait.NewHTTPStrategy("/ping").WithTLS(true, tlsconfig).
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
5 changes: 5 additions & 0 deletions wait/testdata/http/tls-key.pem
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
-----BEGIN EC PRIVATE KEY-----
MHcCAQEEIM8HuDwcZyVqBBy2C6db6zNb/dAJ69bq5ejAEz7qGOIQoAoGCCqGSM49
AwEHoUQDQgAEBL2ioRmfTc70WT0vyx+amSQOGbMeoMRAfF2qaPzpzOqpKTk0aLOG
0735iy9Fz16PX4vqnLMiM/ZupugAhB//yA==
-----END EC PRIVATE KEY-----
12 changes: 12 additions & 0 deletions wait/testdata/http/tls.pem
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
-----BEGIN CERTIFICATE-----
MIIBxTCCAWugAwIBAgIUWBLNpiF1o4r+5ZXwawzPOfBM1F8wCgYIKoZIzj0EAwIw
ADAeFw0yMDA4MTkxMzM4MDBaFw0zMDA4MTcxMzM4MDBaMBkxFzAVBgNVBAMTDnRl
c3Rjb250YWluZXJzMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEBL2ioRmfTc70
WT0vyx+amSQOGbMeoMRAfF2qaPzpzOqpKTk0aLOG0735iy9Fz16PX4vqnLMiM/Zu
pugAhB//yKOBqTCBpjAOBgNVHQ8BAf8EBAMCBaAwEwYDVR0lBAwwCgYIKwYBBQUH
AwEwDAYDVR0TAQH/BAIwADAdBgNVHQ4EFgQUTMdz5PIZ+Gix4jYUzRIHfByrW+Yw
HwYDVR0jBBgwFoAUFdfV6PSYUlHs+lSQNouRwSfR2ZgwMQYDVR0RBCowKIIVdGVz
dGNvbnRhaW5lci5nby50ZXN0gglsb2NhbGhvc3SHBH8AAAEwCgYIKoZIzj0EAwID
SAAwRQIhAJznPNumi2Plf0GsP9DpC+8WukT/jUhnhcDWCfZ6Ini2AiBLhnhFebZX
XWfSsdSNxIo20OWvy6z3wqdybZtRUfdU+g==
-----END CERTIFICATE-----
167 changes: 167 additions & 0 deletions wait/tls.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
package wait

import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"io"
"time"
)

// Validate we implement interface.
var _ Strategy = (*TLSStrategy)(nil)

// TLSStrategy is a strategy for handling TLS.
type TLSStrategy struct {
// General Settings.
timeout *time.Duration
pollInterval time.Duration

// Custom Settings.
certFiles *x509KeyPair
rootFiles []string

// State.
tlsConfig *tls.Config
}

// x509KeyPair is a pair of certificate and key files.
type x509KeyPair struct {
certPEMFile string
keyPEMFile string
}

// ForTLSCert returns a CertStrategy that will add a Certificate to the [tls.Config]
// constructed from PEM formatted certificate key file pair in the container.
func ForTLSCert(certPEMFile, keyPEMFile string) *TLSStrategy {
Copy link
Member

@mdelapenya mdelapenya Nov 26, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

question: do you think it would be interesting to provide a way for consumers to forget about generating certs and the library build them on the fly? Something like wait.ForTLSTestCert? The library would generate the file under the hood.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep love that idea, something for a follow up PR.

return &TLSStrategy{
certFiles: &x509KeyPair{
certPEMFile: certPEMFile,
keyPEMFile: keyPEMFile,
},
tlsConfig: &tls.Config{},
pollInterval: defaultPollInterval(),
}
}

// ForTLSRootCAs returns a CertStrategy that sets the root CAs for the [tls.Config]
// using the given PEM formatted files from the container.
func ForTLSRootCAs(pemFiles ...string) *TLSStrategy {
return &TLSStrategy{
rootFiles: pemFiles,
tlsConfig: &tls.Config{},
pollInterval: defaultPollInterval(),
}
}

// WithRootCAs sets the root CAs for the [tls.Config] using the given files from
// the container.
func (ws *TLSStrategy) WithRootCAs(files ...string) *TLSStrategy {
ws.rootFiles = files
return ws
}

// WithCert sets the [tls.Config] Certificates using the given files from the container.
func (ws *TLSStrategy) WithCert(certPEMFile, keyPEMFile string) *TLSStrategy {
ws.certFiles = &x509KeyPair{
certPEMFile: certPEMFile,
keyPEMFile: keyPEMFile,
}
return ws
}

// WithServerName sets the server for the [tls.Config].
func (ws *TLSStrategy) WithServerName(serverName string) *TLSStrategy {
ws.tlsConfig.ServerName = serverName
return ws
}

// WithStartupTimeout can be used to change the default startup timeout.
func (ws *TLSStrategy) WithStartupTimeout(startupTimeout time.Duration) *TLSStrategy {
ws.timeout = &startupTimeout
return ws
}

// WithPollInterval can be used to override the default polling interval of 100 milliseconds.
func (ws *TLSStrategy) WithPollInterval(pollInterval time.Duration) *TLSStrategy {
ws.pollInterval = pollInterval
return ws
}

// TLSConfig returns the TLS config once the strategy is ready.
// If the strategy is nil, it returns nil.
func (ws *TLSStrategy) TLSConfig() *tls.Config {
if ws == nil {
return nil
}

return ws.tlsConfig
}

// WaitUntilReady implements the [Strategy] interface.
// It waits for the CA, client cert and key files to be available in the container and
// uses them to setup the TLS config.
func (ws *TLSStrategy) WaitUntilReady(ctx context.Context, target StrategyTarget) error {
size := len(ws.rootFiles)
if ws.certFiles != nil {
size += 2
}
strategies := make([]Strategy, 0, size)
for _, file := range ws.rootFiles {
strategies = append(strategies,
ForFile(file).WithMatcher(func(r io.Reader) error {
buf, err := io.ReadAll(r)
if err != nil {
return fmt.Errorf("read CA cert file %q: %w", file, err)
}

if ws.tlsConfig.RootCAs == nil {
ws.tlsConfig.RootCAs = x509.NewCertPool()
}

if !ws.tlsConfig.RootCAs.AppendCertsFromPEM(buf) {
return fmt.Errorf("invalid CA cert file %q", file)
}

return nil
}).WithPollInterval(ws.pollInterval),
)
}

if ws.certFiles != nil {
var certPEMBlock []byte
strategies = append(strategies,
ForFile(ws.certFiles.certPEMFile).WithMatcher(func(r io.Reader) error {
var err error
if certPEMBlock, err = io.ReadAll(r); err != nil {
return fmt.Errorf("read certificate cert %q: %w", ws.certFiles.certPEMFile, err)
}

return nil
}).WithPollInterval(ws.pollInterval),
ForFile(ws.certFiles.keyPEMFile).WithMatcher(func(r io.Reader) error {
keyPEMBlock, err := io.ReadAll(r)
if err != nil {
return fmt.Errorf("read certificate key %q: %w", ws.certFiles.keyPEMFile, err)
}

cert, err := tls.X509KeyPair(certPEMBlock, keyPEMBlock)
if err != nil {
return fmt.Errorf("x509 key pair %q %q: %w", ws.certFiles.certPEMFile, ws.certFiles.keyPEMFile, err)
}

ws.tlsConfig.Certificates = []tls.Certificate{cert}

return nil
}).WithPollInterval(ws.pollInterval),
)
}

strategy := ForAll(strategies...)
if ws.timeout != nil {
strategy.WithStartupTimeout(*ws.timeout)
}

return strategy.WaitUntilReady(ctx, target)
}
Loading