diff --git a/integration/hsm/helpers.go b/integration/hsm/helpers.go
index bf2589a98c8cc..d531c7e090bc2 100644
--- a/integration/hsm/helpers.go
+++ b/integration/hsm/helpers.go
@@ -26,6 +26,7 @@ import (
 	"time"
 
 	"github.com/gravitational/trace"
+	"github.com/jonboulle/clockwork"
 	"github.com/stretchr/testify/require"
 
 	"github.com/gravitational/teleport/api/types"
@@ -60,9 +61,6 @@ func newTeleportService(t *testing.T, config *servicecfg.Config, name string) *t
 		serviceChannel: make(chan *service.TeleportProcess, 1),
 		errorChannel:   make(chan error, 1),
 	}
-	t.Cleanup(func() {
-		require.NoError(t, s.close(), "error while closing %s during test cleanup", name)
-	})
 	return s
 }
 
@@ -111,17 +109,43 @@ func (t *teleportService) waitForNewProcess(ctx context.Context) error {
 	return nil
 }
 
+func (t *teleportService) waitForEvent(ctx context.Context, event string) error {
+	ctx, cancel := context.WithCancel(ctx)
+	defer cancel()
+	waitForEventErr := make(chan error)
+	go func() {
+		_, err := t.process.WaitForEvent(ctx, event)
+		select {
+		case waitForEventErr <- err:
+		case <-ctx.Done():
+		}
+	}()
+
+	select {
+	case err := <-waitForEventErr:
+		return trace.Wrap(err)
+	case err := <-t.errorChannel:
+		if err != nil {
+			return trace.Wrap(err, "process unexpectedly exited while waiting for event %s", event)
+		}
+		return trace.Errorf("process unexpectedly exited while waiting for event %s", event)
+	case <-t.serviceChannel:
+		return trace.Errorf("process unexpectedly reloaded while waiting for event %s", event)
+	case <-ctx.Done():
+		return trace.Wrap(ctx.Err())
+	}
+}
+
 func (t *teleportService) waitForReady(ctx context.Context) error {
 	t.log.Debugf("%s gen %d: waiting for TeleportReadyEvent", t.name, t.processGeneration)
-	if _, err := t.process.WaitForEvent(ctx, service.TeleportReadyEvent); err != nil {
-		return trace.Wrap(err, "timed out waiting for %s gen %d to be ready", t.name, t.processGeneration)
+	if err := t.waitForEvent(ctx, service.TeleportReadyEvent); err != nil {
+		return trace.Wrap(err, "waiting for %s gen %d to be ready", t.name, t.processGeneration)
 	}
 	t.log.Debugf("%s gen %d: got TeleportReadyEvent", t.name, t.processGeneration)
 	// If this is an Auth server, also wait for AuthIdentityEvent so that we
 	// can safely read the admin credentials and create a test client.
 	if t.process.GetAuthServer() != nil {
 		t.log.Debugf("%s gen %d: waiting for AuthIdentityEvent", t.name, t.processGeneration)
-		if _, err := t.process.WaitForEvent(ctx, service.AuthIdentityEvent); err != nil {
+		if err := t.waitForEvent(ctx, service.AuthIdentityEvent); err != nil {
 			return trace.Wrap(err, "%s gen %d: timed out waiting AuthIdentityEvent", t.name, t.processGeneration)
 		}
 		t.log.Debugf("%s gen %d: got AuthIdentityEvent", t.name, t.processGeneration)
@@ -170,7 +194,7 @@ func (t *teleportService) waitForLocalAdditionalKeys(ctx context.Context) error
 		if err != nil {
 			return trace.Wrap(err)
 		}
-		if usableKeysResult.CAHasUsableKeys {
+		if usableKeysResult.CAHasPreferredKeyType {
 			break
 		}
 	}
@@ -180,7 +204,7 @@ func (t *teleportService) waitForPhaseChange(ctx context.Context) error {
 	t.log.Debugf("%s gen %d: waiting for phase change", t.name, t.processGeneration)
-	if _, err := t.process.WaitForEvent(ctx, service.TeleportPhaseChangeEvent); err != nil {
+	if err := t.waitForEvent(ctx, service.TeleportPhaseChangeEvent); err != nil {
 		return trace.Wrap(err, "%s gen %d: timed out waiting for phase change", t.name, t.processGeneration)
 	}
 	t.log.Debugf("%s gen %d: changed phase", t.name, t.processGeneration)
@@ -237,6 +261,7 @@ func newAuthConfig(t *testing.T, log utils.Logger) *servicecfg.Config {
 	config.InstanceMetadataClient = cloud.NewDisabledIMDSClient()
 	config.MaxRetryPeriod = 25 * time.Millisecond
 	config.PollingPeriod = 2 * time.Second
+	config.Clock = fastClock(t)
 
 	config.Auth.Enabled = true
 	config.Auth.NoAudit = true
@@ -268,6 +293,7 @@ func newAuthConfig(t *testing.T, log utils.Logger) *servicecfg.Config {
 
 func newProxyConfig(t *testing.T, authAddr utils.NetAddr, log utils.Logger) *servicecfg.Config {
 	config := servicecfg.MakeDefaultConfig()
+	config.Version = defaults.TeleportConfigVersionV3
 	config.DataDir = t.TempDir()
 	config.CachePolicy.Enabled = true
 	config.Auth.Enabled = false
@@ -278,6 +304,7 @@ func newProxyConfig(t *testing.T, authAddr utils.NetAddr, log utils.Logger) *ser
 	config.InstanceMetadataClient = cloud.NewDisabledIMDSClient()
 	config.MaxRetryPeriod = 25 * time.Millisecond
 	config.PollingPeriod = 2 * time.Second
+	config.Clock = fastClock(t)
 
 	config.Proxy.Enabled = true
 	config.Proxy.DisableWebInterface = true
@@ -288,3 +315,24 @@ func newProxyConfig(t *testing.T, authAddr utils.NetAddr, log utils.Logger) *ser
 
 	return config
 }
+
+// fastClock returns a clock that runs at ~20x realtime.
+func fastClock(t *testing.T) clockwork.FakeClock {
+	// Start in the past to avoid cert not yet valid errors
+	clock := clockwork.NewFakeClockAt(time.Now().Add(-12 * time.Hour))
+	done := make(chan struct{})
+	t.Cleanup(func() { close(done) })
+	go func() {
+		for {
+			select {
+			case <-done:
+				return
+			default:
+			}
+			clock.BlockUntil(1)
+			clock.Advance(time.Second)
+			time.Sleep(50 * time.Millisecond)
+		}
+	}()
+	return clock
+}
diff --git a/integration/hsm/hsm_test.go b/integration/hsm/hsm_test.go
index c5ce0045272d0..c198a60a60e13 100644
--- a/integration/hsm/hsm_test.go
+++ b/integration/hsm/hsm_test.go
@@ -112,12 +112,6 @@ func liteBackendConfig(t *testing.T) *backend.Config {
 	}
 }
 
-func requireHSMAvailable(t *testing.T) {
-	if os.Getenv("SOFTHSM2_PATH") == "" && os.Getenv("TEST_GCP_KMS_KEYRING") == "" {
-		t.Skip("Skipping test because neither SOFTHSM2_PATH or TEST_GCP_KMS_KEYRING are set")
-	}
-}
-
 func requireETCDAvailable(t *testing.T) {
 	if os.Getenv("TELEPORT_ETCD_TEST") == "" {
 		t.Skip("Skipping test because TELEPORT_ETCD_TEST is not set")
@@ -126,8 +120,6 @@ func requireETCDAvailable(t *testing.T) {
 
 // Tests a single CA rotation with a single HSM auth server
 func TestHSMRotation(t *testing.T) {
-	requireHSMAvailable(t)
-
 	ctx, cancel := context.WithCancel(context.Background())
 	t.Cleanup(cancel)
 	log := utils.NewLoggerForTests()
@@ -229,11 +221,6 @@ func testAdminClient(t *testing.T, authDataDir string, authAddr string) {
 
 // Tests multiple CA rotations and rollbacks with 2 HSM auth servers in an HA configuration
 func TestHSMDualAuthRotation(t *testing.T) {
-	// TODO(nklaassen): fix this test and re-enable it.
-	// https://github.com/gravitational/teleport/issues/20217
-	t.Skip("TestHSMDualAuthRotation is temporarily disabled due to flakiness")
-
-	requireHSMAvailable(t)
 	requireETCDAvailable(t)
 
 	ctx, cancel := context.WithCancel(context.Background())
@@ -241,7 +228,7 @@ func TestHSMDualAuthRotation(t *testing.T) {
 	log := utils.NewLoggerForTests()
 	storageConfig := etcdBackendConfig(t)
 
-	// start a cluster with 1 auth server and a proxy
+	// start a cluster with 1 auth server
 	log.Debug("TestHSMDualAuthRotation: Starting auth server 1")
 	auth1Config := newHSMAuthConfig(t, storageConfig, log)
 	auth1 := newTeleportService(t, auth1Config, "auth1")
@@ -250,7 +237,6 @@ func TestHSMDualAuthRotation(t *testing.T) {
 			"failed to delete hsm keys during test cleanup")
 	})
 	authServices := teleportServices{auth1}
-	allServices := append(teleportServices{}, authServices...)
 	require.NoError(t, authServices.start(ctx), "auth service failed initial startup")
 
 	log.Debug("TestHSMDualAuthRotation: Starting load balancer")
@@ -264,23 +250,16 @@ func TestHSMDualAuthRotation(t *testing.T) {
 	go lb.Serve()
 	t.Cleanup(func() { require.NoError(t, lb.Close()) })
 
-	// start a proxy to make sure it can get creds at each stage of rotation
-	log.Debug("TestHSMDualAuthRotation: Starting proxy")
-	proxyConfig := newProxyConfig(t, utils.FromAddr(lb.Addr()), log)
-	proxy := newTeleportService(t, proxyConfig, "proxy")
-	require.NoError(t, proxy.start(ctx), "proxy failed initial startup")
-	allServices = append(allServices, proxy)
-
 	// add a new auth server
 	log.Debug("TestHSMDualAuthRotation: Starting auth server 2")
 	auth2Config := newHSMAuthConfig(t, storageConfig, log)
 	auth2 := newTeleportService(t, auth2Config, "auth2")
-	require.NoError(t, auth2.start(ctx))
+	err = auth2.start(ctx)
+	require.NoError(t, err, trace.DebugReport(err))
 	t.Cleanup(func() {
 		require.NoError(t, auth2.process.GetAuthServer().GetKeyStore().DeleteUnusedKeys(ctx, nil))
 	})
 	authServices = append(authServices, auth2)
-	allServices = append(allServices, auth2)
 
 	testAuth2Client := func(t *testing.T) {
 		testAdminClient(t, auth2Config.DataDir, auth2.authAddrString(t))
@@ -294,7 +273,7 @@ func TestHSMDualAuthRotation(t *testing.T) {
 		{
 			targetPhase: types.RotationPhaseInit,
 			verify: func(t *testing.T) {
-				require.NoError(t, allServices.waitForPhaseChange(ctx))
+				require.NoError(t, authServices.waitForPhaseChange(ctx))
 				require.NoError(t, authServices.waitForLocalAdditionalKeys(ctx))
 				testAuth2Client(t)
 			},
@@ -302,21 +281,21 @@ func TestHSMDualAuthRotation(t *testing.T) {
 		{
 			targetPhase: types.RotationPhaseUpdateClients,
 			verify: func(t *testing.T) {
-				require.NoError(t, allServices.waitForRestart(ctx))
+				require.NoError(t, authServices.waitForRestart(ctx))
 				testAuth2Client(t)
 			},
 		},
 		{
 			targetPhase: types.RotationPhaseUpdateServers,
 			verify: func(t *testing.T) {
-				require.NoError(t, allServices.waitForRestart(ctx))
+				require.NoError(t, authServices.waitForRestart(ctx))
 				testAuth2Client(t)
 			},
 		},
 		{
 			targetPhase: types.RotationPhaseStandby,
 			verify: func(t *testing.T) {
-				require.NoError(t, allServices.waitForRestart(ctx))
+				require.NoError(t, authServices.waitForRestart(ctx))
 				testAuth2Client(t)
 			},
 		},
@@ -360,7 +339,7 @@ func TestHSMDualAuthRotation(t *testing.T) {
 		{
 			targetPhase: types.RotationPhaseInit,
 			verify: func(t *testing.T) {
-				require.NoError(t, allServices.waitForPhaseChange(ctx))
+				require.NoError(t, authServices.waitForPhaseChange(ctx))
 				require.NoError(t, authServices.waitForLocalAdditionalKeys(ctx))
 				testLoadBalancedClient(t)
 			},
@@ -368,21 +347,21 @@ func TestHSMDualAuthRotation(t *testing.T) {
 		{
 			targetPhase: types.RotationPhaseRollback,
 			verify: func(t *testing.T) {
-				require.NoError(t, allServices.waitForRestart(ctx))
+				require.NoError(t, authServices.waitForRestart(ctx))
 				testLoadBalancedClient(t)
 			},
 		},
 		{
 			targetPhase: types.RotationPhaseStandby,
 			verify: func(t *testing.T) {
-				require.NoError(t, allServices.waitForRestart(ctx))
+				require.NoError(t, authServices.waitForRestart(ctx))
 				testLoadBalancedClient(t)
 			},
 		},
 		{
 			targetPhase: types.RotationPhaseInit,
 			verify: func(t *testing.T) {
-				require.NoError(t, allServices.waitForPhaseChange(ctx))
+				require.NoError(t, authServices.waitForPhaseChange(ctx))
 				require.NoError(t, authServices.waitForLocalAdditionalKeys(ctx))
 				testLoadBalancedClient(t)
 			},
@@ -390,28 +369,28 @@ func TestHSMDualAuthRotation(t *testing.T) {
 		{
 			targetPhase: types.RotationPhaseUpdateClients,
 			verify: func(t *testing.T) {
-				require.NoError(t, allServices.waitForRestart(ctx))
+				require.NoError(t, authServices.waitForRestart(ctx))
 				testLoadBalancedClient(t)
 			},
 		},
 		{
 			targetPhase: types.RotationPhaseRollback,
 			verify: func(t *testing.T) {
-				require.NoError(t, allServices.waitForRestart(ctx))
+				require.NoError(t, authServices.waitForRestart(ctx))
 				testLoadBalancedClient(t)
 			},
 		},
 		{
 			targetPhase: types.RotationPhaseStandby,
 			verify: func(t *testing.T) {
-				require.NoError(t, allServices.waitForRestart(ctx))
+				require.NoError(t, authServices.waitForRestart(ctx))
 				testLoadBalancedClient(t)
 			},
 		},
 		{
 			targetPhase: types.RotationPhaseInit,
 			verify: func(t *testing.T) {
-				require.NoError(t, allServices.waitForPhaseChange(ctx))
+				require.NoError(t, authServices.waitForPhaseChange(ctx))
 				require.NoError(t, authServices.waitForLocalAdditionalKeys(ctx))
 				testLoadBalancedClient(t)
 			},
@@ -419,28 +398,28 @@ func TestHSMDualAuthRotation(t *testing.T) {
 		{
 			targetPhase: types.RotationPhaseUpdateClients,
 			verify: func(t *testing.T) {
-				require.NoError(t, allServices.waitForRestart(ctx))
+				require.NoError(t, authServices.waitForRestart(ctx))
 				testLoadBalancedClient(t)
 			},
 		},
 		{
 			targetPhase: types.RotationPhaseUpdateServers,
 			verify: func(t *testing.T) {
-				require.NoError(t, allServices.waitForRestart(ctx))
+				require.NoError(t, authServices.waitForRestart(ctx))
 				testLoadBalancedClient(t)
 			},
 		},
 		{
 			targetPhase: types.RotationPhaseRollback,
 			verify: func(t *testing.T) {
-				require.NoError(t, allServices.waitForRestart(ctx))
+				require.NoError(t, authServices.waitForRestart(ctx))
 				testLoadBalancedClient(t)
 			},
 		},
 		{
 			targetPhase: types.RotationPhaseStandby,
 			verify: func(t *testing.T) {
-				require.NoError(t, allServices.waitForRestart(ctx))
+				require.NoError(t, authServices.waitForRestart(ctx))
 				testLoadBalancedClient(t)
 			},
 		},
@@ -458,7 +437,6 @@ func TestHSMDualAuthRotation(t *testing.T) {
 
 // Tests a dual-auth server migration from raw keys to HSM keys
 func TestHSMMigrate(t *testing.T) {
-	requireHSMAvailable(t)
 	requireETCDAvailable(t)
 
 	ctx, cancel := context.WithCancel(context.Background())
@@ -494,12 +472,6 @@ func TestHSMMigrate(t *testing.T) {
 	go lb.Serve()
 	t.Cleanup(func() { require.NoError(t, lb.Close()) })
 
-	// start a proxy to make sure it can get creds at each stage of migration
-	log.Debug("TestHSMMigrate: Starting proxy")
-	proxyConfig := newProxyConfig(t, utils.FromAddr(lb.Addr()), log)
-	proxy := newTeleportService(t, proxyConfig, "proxy")
-	require.NoError(t, proxy.start(ctx))
-
 	testClient := func(t *testing.T) {
 		testAdminClient(t, auth1Config.DataDir, lb.Addr().String())
 	}
@@ -525,7 +497,6 @@ func TestHSMMigrate(t *testing.T) {
 	assert.Contains(t, alert.Spec.Message, "host")
 
 	authServices := teleportServices{auth1, auth2}
-	allServices := teleportServices{auth1, auth2, proxy}
 
 	stages := []struct {
 		targetPhase string
@@ -534,7 +505,7 @@ func TestHSMMigrate(t *testing.T) {
 		{
 			targetPhase: types.RotationPhaseInit,
 			verify: func(t *testing.T) {
-				require.NoError(t, allServices.waitForPhaseChange(ctx))
+				require.NoError(t, authServices.waitForPhaseChange(ctx))
 				require.NoError(t, authServices.waitForLocalAdditionalKeys(ctx))
 				testClient(t)
 			},
@@ -542,21 +513,24 @@ func TestHSMMigrate(t *testing.T) {
 		{
 			targetPhase: types.RotationPhaseUpdateClients,
 			verify: func(t *testing.T) {
-				require.NoError(t, allServices.waitForRestart(ctx))
+				err := authServices.waitForRestart(ctx)
+				require.NoError(t, err, trace.DebugReport(err))
 				testClient(t)
 			},
 		},
 		{
 			targetPhase: types.RotationPhaseUpdateServers,
 			verify: func(t *testing.T) {
-				require.NoError(t, allServices.waitForRestart(ctx))
+				err := authServices.waitForRestart(ctx)
+				require.NoError(t, err, trace.DebugReport(err))
 				testClient(t)
 			},
 		},
 		{
 			targetPhase: types.RotationPhaseStandby,
 			verify: func(t *testing.T) {
-				require.NoError(t, allServices.waitForRestart(ctx))
+				err := authServices.waitForRestart(ctx)
+				require.NoError(t, err, trace.DebugReport(err))
 				testClient(t)
 			},
 		},
@@ -586,9 +560,7 @@ func TestHSMMigrate(t *testing.T) {
 	auth2Config.Auth.KeyStore = keystore.HSMTestConfig(t)
 	auth2 = newTeleportService(t, auth2Config, "auth2")
 	require.NoError(t, auth2.start(ctx))
-
 	authServices = teleportServices{auth1, auth2}
-	allServices = teleportServices{auth1, auth2, proxy}
 
 	testClient(t)
 
@@ -614,8 +586,6 @@ func TestHSMMigrate(t *testing.T) {
 // TestHSMRevert tests a single-auth server migration from HSM keys back to
 // software keys.
 func TestHSMRevert(t *testing.T) {
-	requireHSMAvailable(t)
-
 	clock := clockwork.NewFakeClock()
 	ctx, cancel := context.WithCancel(context.Background())
 	t.Cleanup(cancel)
diff --git a/lib/auth/init.go b/lib/auth/init.go
index 5e74cc6e0c870..ae6b8b496b6c5 100644
--- a/lib/auth/init.go
+++ b/lib/auth/init.go
@@ -611,13 +611,22 @@ func initializeAuthority(ctx context.Context, asrv *Server, caID types.CertAuthI
 			if err := asrv.ensureLocalAdditionalKeys(ctx, ca); err != nil {
 				return nil, nil, trace.Wrap(err)
 			}
+			ca, err = asrv.Services.GetCertAuthority(ctx, caID, true)
+			if err != nil {
+				return nil, nil, trace.Wrap(err)
+			}
+			usableKeysResult, err = asrv.keyStore.HasUsableActiveKeys(ctx, ca)
+			if err != nil {
+				return nil, nil, trace.Wrap(err)
+			}
+		} else {
+			log.Warnf("This Auth Service is configured to use %s but the %s CA contains only %s. "+
+				"No new certificates can be signed with the existing keys. "+
+				"You must perform a CA rotation to generate new keys, or adjust your configuration to use the existing keys.",
+				usableKeysResult.PreferredKeyType,
+				caID.Type,
+				strings.Join(usableKeysResult.CAKeyTypes, " and "))
 		}
-		log.Warnf("This Auth Service is configured to use %s but the %s CA contains only %s. "+
-			"No new certificates can be signed with the existing keys. "+
-			"You must perform a CA rotation to generate new keys, or adjust your configuration to use the existing keys.",
-			usableKeysResult.PreferredKeyType,
-			caID.Type,
-			strings.Join(usableKeysResult.CAKeyTypes, " and "))
 	} else if !usableKeysResult.CAHasPreferredKeyType {
 		log.Warnf("This Auth Service is configured to use %s but the %s CA contains only %s. "+
 			"New certificates will continue to be signed with raw software keys but you must perform a CA rotation to begin using %s.",
diff --git a/lib/auth/keystore/pkcs11.go b/lib/auth/keystore/pkcs11.go
index d845a00265094..bfc1c32048f57 100644
--- a/lib/auth/keystore/pkcs11.go
+++ b/lib/auth/keystore/pkcs11.go
@@ -153,13 +153,13 @@ func (p *pkcs11KeyStore) generateRSA(ctx context.Context, options ...RSAKeyOptio
 		<-p.semaphore
 	}()
 
-	p.log.Debug("Creating new HSM keypair")
-
 	id, err := p.findUnusedID()
 	if err != nil {
 		return nil, nil, trace.Wrap(err)
 	}
 
+	p.log.Debugf("Creating new HSM keypair %v", id)
+
 	ckaID, err := id.pkcs11Key(p.isYubiHSM)
 	if err != nil {
 		return nil, nil, trace.Wrap(err)
 	}
@@ -201,7 +201,7 @@ func (p *pkcs11KeyStore) getSignerWithoutPublicKey(ctx context.Context, rawKey [
 		return nil, trace.Wrap(err)
 	}
 	if signer == nil {
-		return nil, trace.NotFound("failed to find keypair for given id")
+		return nil, trace.NotFound("failed to find keypair with id %v", keyID)
 	}
 	return signer, nil
 }
@@ -308,6 +308,7 @@ func (p *pkcs11KeyStore) deleteUnusedKeys(ctx context.Context, activeKeys [][]by
 		if keyIsActive(signer) {
 			continue
 		}
+		p.log.Infof("Deleting unused key from HSM")
 		if err := signer.Delete(); err != nil {
 			// Key deletion is best-effort, log a warning on errors, and
 			// continue trying to delete other keys. Errors have been observed
diff --git a/lib/auth/keystore/testhelpers.go b/lib/auth/keystore/testhelpers.go
index 4cbe6b63d6141..a4f2c850e87ec 100644
--- a/lib/auth/keystore/testhelpers.go
+++ b/lib/auth/keystore/testhelpers.go
@@ -51,7 +51,7 @@ func HSMTestConfig(t *testing.T) Config {
 		t.Log("Running test with SoftHSM")
 		return cfg
 	}
-	t.Fatal("No HSM available for test")
+	t.Skip("No HSM available for test")
 	return Config{}
 }
diff --git a/lib/service/service.go b/lib/service/service.go
index 0562ffb81dd97..4fde5af3145cd 100644
--- a/lib/service/service.go
+++ b/lib/service/service.go
@@ -743,16 +743,24 @@ func waitAndReload(ctx context.Context, cfg servicecfg.Config, srv Process, newT
 		warnOnErr(srv.Close(), cfg.Log)
 		return nil, trace.Wrap(err, "failed to start a new service")
 	}
+
 	// Wait for the new server to report that it has started
 	// before shutting down the old one.
 	startTimeoutCtx, startCancel := context.WithTimeout(ctx, signalPipeTimeout)
 	defer startCancel()
+	go func() {
+		// Avoid waiting for TeleportReadyEvent if it will never fire.
+		newSrv.WaitForEvent(startTimeoutCtx, ServiceExitedWithErrorEvent)
+		startCancel()
+	}()
 	if _, err := newSrv.WaitForEvent(startTimeoutCtx, TeleportReadyEvent); err != nil {
 		warnOnErr(newSrv.Close(), cfg.Log)
 		warnOnErr(srv.Close(), cfg.Log)
 		return nil, trace.BadParameter("the new service has failed to start")
 	}
 	cfg.Log.Infof("New service has started successfully.")
+	startCancel()
+
 	shutdownTimeout := cfg.Testing.ShutdownTimeout
 	if shutdownTimeout == 0 {
 		// The default shutdown timeout is very generous to avoid disrupting
@@ -786,6 +794,7 @@ func waitAndReload(ctx context.Context, cfg servicecfg.Config, srv Process, newT
 	} else {
 		cfg.Log.Infof("The old service was successfully shut down gracefully.")
 	}
+
 	return newSrv, nil
 }
diff --git a/lib/utils/listener.go b/lib/utils/listener.go
index e12d51bede342..42922c59fe9c1 100644
--- a/lib/utils/listener.go
+++ b/lib/utils/listener.go
@@ -29,9 +29,11 @@ import (
 func GetListenerFile(listener net.Listener) (*os.File, error) {
 	switch t := listener.(type) {
 	case *net.TCPListener:
-		return t.File()
+		f, err := t.File()
+		return f, trace.Wrap(err)
 	case *net.UnixListener:
-		return t.File()
+		f, err := t.File()
+		return f, trace.Wrap(err)
 	}
 	return nil, trace.BadParameter("unsupported listener: %T", listener)
 }