diff --git a/lib/client/api.go b/lib/client/api.go index bf331ff1b8774..9ca921cb28d5d 100644 --- a/lib/client/api.go +++ b/lib/client/api.go @@ -469,15 +469,6 @@ func RetryWithRelogin(ctx context.Context, tc *TeleportClient, fn func() error) log.Warningf("Failed to save profile: %v", err) return trace.Wrap(err) } - // Override client's auth methods, current cluster and user name - authMethod, err := key.AsAuthMethod() - if err != nil { - return trace.Wrap(err) - } - // After successful login we have local agent updated with latest - // and greatest auth information, setup client to try only this new - // method fetched from key, to isolate the retry - tc.Config.AuthMethods = []ssh.AuthMethod{authMethod} return fn() } @@ -1066,7 +1057,11 @@ func (tc *TeleportClient) LoadKeyForClusterWithReissue(ctx context.Context, clus return trace.Wrap(err) } // Reissuing also loads the new key. - return tc.ReissueUserCerts(ctx, ReissueParams{RouteToCluster: clusterName}) + err = tc.ReissueUserCerts(ctx, ReissueParams{RouteToCluster: clusterName}) + if err != nil { + return trace.Wrap(err) + } + return nil } // accessPoint returns access point based on the cache policy @@ -2208,7 +2203,18 @@ func (tc *TeleportClient) ActivateKey(ctx context.Context, key *Key) error { // Connect to the Auth Server of the root cluster and fetch the known hosts. rootClusterName := key.TrustedCA[0].ClusterName if err := tc.UpdateTrustedCA(ctx, rootClusterName); err != nil { - return trace.Wrap(err) + if len(tc.JumpHosts) == 0 { + return trace.Wrap(err) + } + errViaJumphost := err + // If JumpHosts was pointing at the leaf cluster (e.g. during 'tsh ssh + // -J leaf.example.com'), this could've caused the above error. Try to + // fetch CAs without JumpHosts to force it to use the root cluster. + if err := tc.WithoutJumpHosts(func(tc *TeleportClient) error { + return tc.UpdateTrustedCA(ctx, rootClusterName) + }); err != nil { + return trace.NewAggregate(errViaJumphost, err) + } } return nil diff --git a/lib/client/interfaces.go b/lib/client/interfaces.go index 84911a812c996..552d34335fa2a 100644 --- a/lib/client/interfaces.go +++ b/lib/client/interfaces.go @@ -370,6 +370,16 @@ func (k *Key) CheckCert() error { return trace.Wrap(err) } + // Check that the certificate was for the current public key. If not, the + // public/private key pair may have been rotated. + pub, _, _, _, err := ssh.ParseAuthorizedKey(k.Pub) + if err != nil { + return trace.Wrap(err) + } + if !sshutils.KeysEqual(cert.Key, pub) { + return trace.CompareFailed("public key in profile does not match the public key in SSH certificate") + } + // A valid principal is always passed in because the principals are not being // checked here, but rather the validity period, signature, and algorithms. certChecker := utils.CertChecker{