Skip to content

Commit

Permalink
Merge pull request #6214 from zarvd/azclient/kv-credential
Browse files Browse the repository at this point in the history
[pkg/azclients] Add mutex guard for thread-safe token refresh
  • Loading branch information
k8s-ci-robot authored May 17, 2024
2 parents c01bb44 + 0556803 commit d7c516f
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 34 deletions.
123 changes: 93 additions & 30 deletions pkg/azclient/armauth/keyvault_credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,34 @@ import (
"context"
"encoding/json"
"fmt"
"sync"
"time"

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/azure-sdk-for-go/sdk/keyvault/azsecrets"

"sigs.k8s.io/cloud-provider-azure/pkg/azclient/utils"
"sigs.k8s.io/cloud-provider-azure/pkg/azclient/vaultclient"
)

type SecretResourceID struct {
SubscriptionID string
ResourceGroup string
VaultName string
SecretName string
}

func (s SecretResourceID) String() string {
return fmt.Sprintf("/subscriptions/%s/resourceGroups/%s/providers/Microsoft.KeyVault/vaults/%s/secrets/%s", s.SubscriptionID, s.ResourceGroup, s.VaultName, s.SecretName)
}

type KeyVaultCredential struct {
secretClient *azsecrets.Client
secretPath string
secretClient *azsecrets.Client
vaultURI string
secretResourceID SecretResourceID

mtx sync.RWMutex
token *azcore.AccessToken
}

Expand All @@ -41,62 +58,108 @@ type KeyVaultCredentialSecret struct {

func NewKeyVaultCredential(
msiCredential azcore.TokenCredential,
keyVaultURL string,
secretName string,
secretResourceID SecretResourceID,
) (*KeyVaultCredential, error) {
cli, err := azsecrets.NewClient(keyVaultURL, msiCredential, nil)
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()

// Get KeyVault URI
var vaultURI string
{
vaultCli, err := vaultclient.New(secretResourceID.SubscriptionID, msiCredential, utils.GetDefaultOption())
if err != nil {
return nil, fmt.Errorf("create KeyVault client: %w", err)
}

vault, err := vaultCli.Get(ctx, secretResourceID.ResourceGroup, secretResourceID.VaultName)
if err != nil {
return nil, fmt.Errorf("get vault %s: %w", secretResourceID.VaultName, err)
}

if vault.Properties == nil || vault.Properties.VaultURI == nil {
return nil, fmt.Errorf("vault uri is nil")
}
vaultURI = *vault.Properties.VaultURI
}

cli, err := azsecrets.NewClient(vaultURI, msiCredential, nil)
if err != nil {
return nil, fmt.Errorf("create KeyVault client: %w", err)
return nil, fmt.Errorf("create secret client: %w", err)
}

rv := &KeyVaultCredential{
secretClient: cli,
secretPath: secretName,
secretClient: cli,
mtx: sync.RWMutex{},
secretResourceID: secretResourceID,
}

ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := rv.refreshToken(ctx); err != nil {
return nil, fmt.Errorf("refresh token: %w", err)
if _, err := rv.refreshToken(ctx); err != nil {
return nil, fmt.Errorf("refresh token from %s: %w", secretResourceID, err)
}

return rv, nil
}

func (c *KeyVaultCredential) refreshToken(ctx context.Context) error {
const LatestVersion = ""
func (c *KeyVaultCredential) refreshToken(ctx context.Context) (*azcore.AccessToken, error) {
const (
LatestVersion = ""
RefreshTokenOffset = 5 * time.Minute
)

cloneAccessToken := func(token *azcore.AccessToken) *azcore.AccessToken {
return &azcore.AccessToken{
Token: token.Token,
ExpiresOn: token.ExpiresOn,
}
}

resp, err := c.secretClient.GetSecret(ctx, c.secretPath, LatestVersion, nil)
if err != nil {
return err
{
c.mtx.RLock()
if c.token != nil && c.token.ExpiresOn.Add(RefreshTokenOffset).Before(time.Now()) {
c.mtx.RUnlock()
return cloneAccessToken(c.token), nil
}
c.mtx.RUnlock()
}
if resp.Value == nil {
return fmt.Errorf("secret value is nil")

c.mtx.Lock()
defer c.mtx.Unlock()

if c.token != nil && c.token.ExpiresOn.Add(RefreshTokenOffset).Before(time.Now()) {
return cloneAccessToken(c.token), nil
}

var secret KeyVaultCredentialSecret
if err := json.Unmarshal([]byte(*resp.Value), &secret); err != nil {
return fmt.Errorf("unmarshal secret value `%s`: %w", *resp.Value, err)
{
resp, err := c.secretClient.GetSecret(ctx, c.secretResourceID.SecretName, LatestVersion, nil)
if err != nil {
return nil, err
} else if resp.Value == nil {
return nil, fmt.Errorf("secret value is nil")
}

// Parse secret value
if err := json.Unmarshal([]byte(*resp.Value), &secret); err != nil {
return nil, fmt.Errorf("unmarshal secret value `%s`: %w", *resp.Value, err)
} else if secret.AccessToken == "" {
return nil, fmt.Errorf("access token is empty")
}
}

c.token = &azcore.AccessToken{
Token: secret.AccessToken,
ExpiresOn: secret.ExpiresOn,
}

return nil
// Return a copy of the token to avoid concurrent modification
return cloneAccessToken(c.token), nil
}

func (c *KeyVaultCredential) GetToken(ctx context.Context, opts policy.TokenRequestOptions) (azcore.AccessToken, error) {
const RefreshTokenOffset = 5 * time.Minute

if c.token != nil && c.token.ExpiresOn.Add(RefreshTokenOffset).Before(time.Now()) {
return *c.token, nil
}

if err := c.refreshToken(ctx); err != nil {
token, err := c.refreshToken(ctx)
if err != nil {
return azcore.AccessToken{}, fmt.Errorf("refresh token: %w", err)
}

return *c.token, nil
return *token, nil
}
11 changes: 9 additions & 2 deletions pkg/azclient/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"strings"

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"

Expand All @@ -39,6 +40,8 @@ type AuthProvider struct {
NetworkClientSecretCredential azcore.TokenCredential

MultiTenantCredential azcore.TokenCredential

ClientOptions *policy.ClientOptions
}

func NewAuthProvider(armConfig *ARMClientConfig, config *AzureAuthConfig, clientOptionsMutFn ...func(option *policy.ClientOptions)) (*AuthProvider, error) {
Expand Down Expand Up @@ -88,8 +91,7 @@ func NewAuthProvider(armConfig *ARMClientConfig, config *AzureAuthConfig, client
if config.UseManagedIdentityExtension && config.AuxiliaryTokenProvider != nil && IsMultiTenant(armConfig) {
networkTokenCredential, err = armauth.NewKeyVaultCredential(
managedIdentityCredential,
config.AuxiliaryTokenProvider.KeyVaultURL,
config.AuxiliaryTokenProvider.SecretName,
config.AuxiliaryTokenProvider.SecretResourceID(),
)
if err != nil {
return nil, fmt.Errorf("create KeyVaultCredential for auxiliary token provider: %w", err)
Expand Down Expand Up @@ -210,3 +212,8 @@ func (factory *AuthProvider) GetMultiTenantIdentity() azcore.TokenCredential {
func (factory *AuthProvider) IsMultiTenantModeEnabled() bool {
return factory.MultiTenantCredential != nil
}

func (factory *AuthProvider) TokenScope() string {
audience := factory.ClientOptions.Cloud.Services[cloud.ResourceManager].Audience
return fmt.Sprintf("https://%s/.default", audience)
}
16 changes: 14 additions & 2 deletions pkg/azclient/auth_conf.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package azclient
import (
"os"

"sigs.k8s.io/cloud-provider-azure/pkg/azclient/armauth"
"sigs.k8s.io/cloud-provider-azure/pkg/azclient/utils"
)

Expand Down Expand Up @@ -48,8 +49,10 @@ type AzureAuthConfig struct {
}

type AzureAuthAuxiliaryTokenProvider struct {
KeyVaultURL string `json:"keyVaultURL,omitempty" yaml:"keyVaultURL,omitempty"`
SecretName string `json:"secretName" yaml:"secretName"`
SubscriptionID string `json:"subscriptionID,omitempty"`
ResourceGroup string `json:"resourceGroup,omitempty"`
VaultName string `json:"vaultName,omitempty"`
SecretName string `json:"secretName,omitempty"`
}

func (config *AzureAuthConfig) GetAADClientID() string {
Expand All @@ -75,3 +78,12 @@ func (config *AzureAuthConfig) GetAzureFederatedTokenFile() (string, bool) {
}
return config.AADFederatedTokenFile, config.UseFederatedWorkloadIdentityExtension
}

func (config *AzureAuthAuxiliaryTokenProvider) SecretResourceID() armauth.SecretResourceID {
return armauth.SecretResourceID{
SubscriptionID: config.SubscriptionID,
ResourceGroup: config.ResourceGroup,
VaultName: config.VaultName,
SecretName: config.SecretName,
}
}

0 comments on commit d7c516f

Please sign in to comment.