Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Hot reloading of TLS key stores #1230

Merged
merged 18 commits into from
Mar 11, 2024
Merged
6 changes: 5 additions & 1 deletion internal/cache/redis/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import (
"github.com/dadrus/heimdall/internal/heimdall"
"github.com/dadrus/heimdall/internal/watcher"
"github.com/dadrus/heimdall/internal/x/errorchain"
"github.com/dadrus/heimdall/internal/x/tlsx"
)

// for test purposes only.
Expand Down Expand Up @@ -144,7 +145,10 @@ func (c baseConfig) clientOptions(cw watcher.Watcher) (rueidis.ClientOption, err
)

if !c.TLS.Disabled {
tlsCfg, err = c.TLS.TLSConfig()
tlsCfg, err = tlsx.ToTLSConfig(&c.TLS.TLS,
tlsx.WithClientAuthentication(len(c.TLS.KeyStore.Path) != 0),
tlsx.WithSecretsWatcher(cw),
)
if err != nil {
return rueidis.ClientOption{}, errorchain.NewWithMessage(heimdall.ErrInternal,
"failed creating tls configuration for Redis client").CausedBy(err)
Expand Down
4 changes: 3 additions & 1 deletion internal/cache/redis/standalone_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -293,9 +293,11 @@ func TestNewStandaloneCache(t *testing.T) {
},
{
uc: "successful cache creation with mutual TLS",
config: func(t *testing.T, _ *mocks.WatcherMock) []byte {
config: func(t *testing.T, wm *mocks.WatcherMock) []byte {
t.Helper()

wm.EXPECT().Add(mock.Anything, mock.Anything).Return(nil)

rootCertPool = x509.NewCertPool()
rootCertPool.AddCert(cert)

Expand Down
49 changes: 0 additions & 49 deletions internal/config/tls.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,6 @@ package config

import (
"crypto/tls"

"github.com/dadrus/heimdall/internal/heimdall"
"github.com/dadrus/heimdall/internal/keystore"
"github.com/dadrus/heimdall/internal/x/errorchain"
)

type TLSCipherSuites []uint16
Expand Down Expand Up @@ -66,48 +62,3 @@ type TLS struct {
CipherSuites TLSCipherSuites `koanf:"cipher_suites" mapstructure:"cipher_suites"`
MinVersion TLSMinVersion `koanf:"min_version" mapstructure:"min_version"`
}

func (t *TLS) TLSConfig() (*tls.Config, error) {
var eeCerts []tls.Certificate

if len(t.KeyStore.Path) != 0 { //nolint:nestif
ks, err := keystore.NewKeyStoreFromPEMFile(t.KeyStore.Path, t.KeyStore.Password)
if err != nil {
return nil, errorchain.NewWithMessage(heimdall.ErrInternal, "failed loading keystore").
CausedBy(err)
}

var entry *keystore.Entry

if len(t.KeyID) != 0 {
if entry, err = ks.GetKey(t.KeyID); err != nil {
return nil, errorchain.NewWithMessage(heimdall.ErrConfiguration,
"failed retrieving key from key store").CausedBy(err)
}
} else {
entry = ks.Entries()[0]
}

cert, err := entry.TLSCertificate()
if err != nil {
return nil, errorchain.NewWithMessage(heimdall.ErrConfiguration,
"key store entry is not suitable for TLS").CausedBy(err)
}

eeCerts = []tls.Certificate{cert}
}

// nolint:gosec
// configuration ensures, TLS versions below 1.2 are not possible
cfg := &tls.Config{
Certificates: eeCerts,
MinVersion: t.MinVersion.OrDefault(),
NextProtos: []string{"h2", "http/1.1"},
}

if cfg.MinVersion != tls.VersionTLS13 {
cfg.CipherSuites = t.CipherSuites.OrDefault()
}

return cfg, nil
}
164 changes: 0 additions & 164 deletions internal/config/tls_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,24 +17,10 @@
package config

import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"math/big"
"os"
"path/filepath"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/dadrus/heimdall/internal/heimdall"
"github.com/dadrus/heimdall/internal/x/pkix/pemx"
"github.com/dadrus/heimdall/internal/x/testsupport"
)

func TestTLSMinVersionOrDefault(t *testing.T) {
Expand Down Expand Up @@ -86,153 +72,3 @@ func TestTLSCipherSuitesOrDefault(t *testing.T) {
})
}
}

func TestTLSConfig(t *testing.T) {
t.Parallel()

testDir := t.TempDir()

privKey1, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
require.NoError(t, err)

privKey2, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
require.NoError(t, err)

cert, err := testsupport.NewCertificateBuilder(testsupport.WithValidity(time.Now(), 10*time.Hour),
testsupport.WithSerialNumber(big.NewInt(1)),
testsupport.WithSubject(pkix.Name{
CommonName: "test cert",
Organization: []string{"Test"},
Country: []string{"EU"},
}),
testsupport.WithSubjectPubKey(&privKey1.PublicKey, x509.ECDSAWithSHA384),
testsupport.WithSelfSigned(),
testsupport.WithSignaturePrivKey(privKey1)).
Build()
require.NoError(t, err)

pemBytes, err := pemx.BuildPEM(
pemx.WithECDSAPrivateKey(privKey1, pemx.WithHeader("X-Key-ID", "key1")),
pemx.WithX509Certificate(cert),
pemx.WithECDSAPrivateKey(privKey2, pemx.WithHeader("X-Key-ID", "key2")),
)
require.NoError(t, err)

pemFile, err := os.Create(filepath.Join(testDir, "keystore.pem"))
require.NoError(t, err)

_, err = pemFile.Write(pemBytes)
require.NoError(t, err)

for _, tc := range []struct {
uc string
conf TLS
assert func(t *testing.T, err error, conf *tls.Config)
}{
{
uc: "empty config",
assert: func(t *testing.T, err error, conf *tls.Config) {
t.Helper()

require.NoError(t, err)
require.NotNil(t, conf)

assert.Empty(t, conf.Certificates)
assert.Equal(t, uint16(tls.VersionTLS13), conf.MinVersion)
assert.Len(t, conf.NextProtos, 2)
assert.Contains(t, conf.NextProtos, "h2")
assert.Contains(t, conf.NextProtos, "http/1.1")
},
},
{
uc: "fails due to not existent key store for TLS usage",
conf: TLS{KeyStore: KeyStore{Path: "/no/such/file"}},
assert: func(t *testing.T, err error, _ *tls.Config) {
t.Helper()

require.Error(t, err)
require.ErrorIs(t, err, heimdall.ErrInternal)
assert.Contains(t, err.Error(), "failed loading")
},
},
{
uc: "fails due to not existent key for the given key id for TLS usage",
conf: TLS{
KeyStore: KeyStore{Path: pemFile.Name()},
KeyID: "foo",
MinVersion: tls.VersionTLS12,
},
assert: func(t *testing.T, err error, _ *tls.Config) {
t.Helper()

require.Error(t, err)
require.ErrorIs(t, err, heimdall.ErrConfiguration)
assert.Contains(t, err.Error(), "no such key")
},
},
{
uc: "fails due to not present certificates for the given key id",
conf: TLS{
KeyStore: KeyStore{Path: pemFile.Name()},
KeyID: "key2",
MinVersion: tls.VersionTLS12,
},
assert: func(t *testing.T, err error, _ *tls.Config) {
t.Helper()

require.Error(t, err)
require.ErrorIs(t, err, heimdall.ErrConfiguration)
assert.Contains(t, err.Error(), "no certificate present")
},
},
{
uc: "successful with default key",
conf: TLS{
KeyStore: KeyStore{Path: pemFile.Name()},
MinVersion: tls.VersionTLS12,
},
assert: func(t *testing.T, err error, conf *tls.Config) {
t.Helper()

require.NoError(t, err)
require.NotNil(t, conf)

assert.Len(t, conf.Certificates, 1)
assert.Equal(t, cert, conf.Certificates[0].Leaf)
assert.Equal(t, uint16(tls.VersionTLS12), conf.MinVersion)
assert.Len(t, conf.NextProtos, 2)
assert.Contains(t, conf.NextProtos, "h2")
assert.Contains(t, conf.NextProtos, "http/1.1")
},
},
{
uc: "successful with specified key id",
conf: TLS{
KeyStore: KeyStore{Path: pemFile.Name()},
KeyID: "key1",
MinVersion: tls.VersionTLS12,
},
assert: func(t *testing.T, err error, conf *tls.Config) {
t.Helper()

require.NoError(t, err)
require.NotNil(t, conf)

assert.Len(t, conf.Certificates, 1)
assert.Equal(t, cert, conf.Certificates[0].Leaf)
assert.Equal(t, uint16(tls.VersionTLS12), conf.MinVersion)
assert.Len(t, conf.NextProtos, 2)
assert.Contains(t, conf.NextProtos, "h2")
assert.Contains(t, conf.NextProtos, "http/1.1")
},
},
} {
t.Run(tc.uc, func(t *testing.T) {
// WHEN
conf, err := tc.conf.TLSConfig()

// THEN
tc.assert(t, err, conf)
})
}
}
3 changes: 3 additions & 0 deletions internal/handler/decision/module.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"github.com/dadrus/heimdall/internal/handler/fxlcm"
"github.com/dadrus/heimdall/internal/heimdall"
"github.com/dadrus/heimdall/internal/rules/rule"
"github.com/dadrus/heimdall/internal/watcher"
)

var Module = fx.Invoke( // nolint: gochecknoglobals
Expand All @@ -43,6 +44,7 @@ func newLifecycleManager(
cch cache.Cache,
exec rule.Executor,
signer heimdall.JWTSigner,
cw watcher.Watcher,
) *fxlcm.LifecycleManager {
cfg := conf.Serve.Decision

Expand All @@ -52,5 +54,6 @@ func newLifecycleManager(
Server: newService(conf, cch, logger, exec, signer),
Logger: logger,
TLSConf: cfg.TLS,
FileWatcher: cw,
}
}
2 changes: 1 addition & 1 deletion internal/handler/decision/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -554,7 +554,7 @@ func TestHandleDecisionEndpointRequest(t *testing.T) {
srvConf.Host = "127.0.0.1"
srvConf.Port = port

listener, err := listener.New("tcp", srvConf.Address(), srvConf.TLS)
listener, err := listener.New("tcp", srvConf.Address(), srvConf.TLS, nil)
require.NoError(t, err)

conf := &config.Configuration{Serve: config.ServeConfig{Decision: srvConf}}
Expand Down
7 changes: 5 additions & 2 deletions internal/handler/envoyextauth/grpcv3/module.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"github.com/dadrus/heimdall/internal/handler/fxlcm"
"github.com/dadrus/heimdall/internal/heimdall"
"github.com/dadrus/heimdall/internal/rules/rule"
"github.com/dadrus/heimdall/internal/watcher"
)

var Module = fx.Invoke( // nolint: gochecknoglobals
Expand All @@ -43,6 +44,7 @@ func newLifecycleManager(
exec rule.Executor,
signer heimdall.JWTSigner,
cch cache.Cache,
cw watcher.Watcher,
) *fxlcm.LifecycleManager {
cfg := conf.Serve.Decision

Expand All @@ -52,7 +54,8 @@ func newLifecycleManager(
Server: &adapter{
s: newService(conf, cch, logger, exec, signer),
},
Logger: logger,
TLSConf: cfg.TLS,
Logger: logger,
TLSConf: cfg.TLS,
FileWatcher: cw,
}
}
4 changes: 3 additions & 1 deletion internal/handler/fxlcm/lifecycle_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (

"github.com/dadrus/heimdall/internal/config"
"github.com/dadrus/heimdall/internal/handler/listener"
"github.com/dadrus/heimdall/internal/watcher"
)

//go:generate mockery --name Server --structname ServerMock
Expand All @@ -41,10 +42,11 @@ type LifecycleManager struct {
Server Server
Logger zerolog.Logger
TLSConf *config.TLS
FileWatcher watcher.Watcher
}

func (m *LifecycleManager) Start(_ context.Context) error {
ln, err := listener.New("tcp", m.ServiceAddress, m.TLSConf)
ln, err := listener.New("tcp", m.ServiceAddress, m.TLSConf, m.FileWatcher)
if err != nil {
m.Logger.Fatal().Err(err).Str("_service", m.ServiceName).Msg("Could not create listener")

Expand Down
18 changes: 9 additions & 9 deletions internal/handler/listener/listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ import (

"github.com/dadrus/heimdall/internal/config"
"github.com/dadrus/heimdall/internal/heimdall"
"github.com/dadrus/heimdall/internal/watcher"
"github.com/dadrus/heimdall/internal/x/errorchain"
"github.com/dadrus/heimdall/internal/x/tlsx"
)

type conn struct {
Expand Down Expand Up @@ -89,7 +91,7 @@ func (l *listener) Accept() (net.Conn, error) {
return &conn{Conn: con}, nil
}

func New(network, address string, tlsConf *config.TLS) (net.Listener, error) {
func New(network, address string, tlsConf *config.TLS, cw watcher.Watcher) (net.Listener, error) {
listnr, err := net.Listen(network, address)
if err != nil {
return nil, errorchain.NewWithMessage(heimdall.ErrInternal, "failed creating listener").
Expand All @@ -99,22 +101,20 @@ func New(network, address string, tlsConf *config.TLS) (net.Listener, error) {
wrapped := &listener{Listener: listnr}

if tlsConf != nil {
return newTLSListener(tlsConf, wrapped)
return newTLSListener(tlsConf, wrapped, cw)
}

return wrapped, nil
}

func newTLSListener(tlsConf *config.TLS, listener net.Listener) (net.Listener, error) {
cfg, err := tlsConf.TLSConfig()
func newTLSListener(tlsConf *config.TLS, listener net.Listener, cw watcher.Watcher) (net.Listener, error) {
cfg, err := tlsx.ToTLSConfig(tlsConf,
tlsx.WithServerAuthentication(true),
tlsx.WithSecretsWatcher(cw),
)
if err != nil {
return nil, err
}

if len(cfg.Certificates) == 0 {
return nil, errorchain.NewWithMessage(heimdall.ErrConfiguration,
"no tls server key and certificate available")
}

return tls.NewListener(listener, cfg), nil
}
Loading
Loading