diff --git a/tls.go b/tls.go index 0c19303..9426925 100644 --- a/tls.go +++ b/tls.go @@ -344,7 +344,8 @@ func NewKeypairReloader(logger FieldLogger, certPath, keyPath string, clientCert } // maybeReload reloads TLS cert and updates client certificates. -// Client certificates are used to conenct to gubernator peers. +// Client certificates are used to conenct to gubernator peers. Note that +// maybeReload is triggered upon SIGHUP func (kpr *keypairReloader) maybeReload() error { newCert, err := tls.LoadX509KeyPair(kpr.certPath, kpr.keyPath) if err != nil { @@ -357,6 +358,11 @@ func (kpr *keypairReloader) maybeReload() error { return nil } +func (kpr *keypairReloader) UpdatePath(certPath, keyPath string) { + kpr.certPath = certPath + kpr.keyPath = keyPath +} + func (kpr *keypairReloader) GetCertificateFunc() func(*tls.ClientHelloInfo) (*tls.Certificate, error) { return func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) { kpr.certMu.RLock() diff --git a/tls_test.go b/tls_test.go index f6b56ba..aac48fc 100644 --- a/tls_test.go +++ b/tls_test.go @@ -22,11 +22,19 @@ import ( "fmt" "io" "net/http" + "os" + "os/signal" "strings" + "syscall" "testing" + "time" + + // "time" "github.com/gubernator-io/gubernator/v2" "github.com/mailgun/holster/v4/clock" + "github.com/mailgun/holster/v4/retry" + "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/net/http2" @@ -340,4 +348,42 @@ func TestHTTPSClientAuth(t *testing.T) { b, err = io.ReadAll(resp2.Body) require.NoError(t, err) assert.Equal(t, `{"status":"healthy","message":"","peer_count":1}`, strings.ReplaceAll(string(b), " ", "")) + +} + +func TestReloadTLS(t *testing.T) { + reloader, err := gubernator.NewKeypairReloader(logrus.WithField("category", "gubernator"), + "contrib/certs/gubernator.pem", + "contrib/certs/gubernator.key", + []tls.Certificate{}) + require.NoError(t, err) + + // Get the cert for the first time + certFn := reloader.GetCertificateFunc() + cert1, err := certFn(&tls.ClientHelloInfo{}) + require.NoError(t, err) + + // update the cert file and first a SIGHUP signal + reloader.UpdatePath("contrib/certs/client-auth.pem", "contrib/certs/client-auth.key") + c := make(chan os.Signal, 10) + signal.Notify(c, syscall.SIGHUP) + syscall.Kill(syscall.Getpid(), syscall.SIGHUP) + <-c + + // Wait until cert is reloaded + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + err = retry.Until(ctx, retry.Interval(clock.Millisecond*300), + func(ctx context.Context, i int) error { + cert2, err := certFn(&tls.ClientHelloInfo{}) + if err != nil { + return err + } + if cert1 == cert2 { + return fmt.Errorf("cert not updated") + } + t.Logf("tls reloaded successfully after retry %d times", i) + return nil + }) + require.NoError(t, err) }