Skip to content

Commit

Permalink
Add | Cache TokenCredential objects to take advantage of token caching (
Browse files Browse the repository at this point in the history
  • Loading branch information
dauinsight committed Aug 14, 2024
1 parent 6fe8e21 commit 9c0391b
Showing 1 changed file with 202 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,16 @@ public sealed class ActiveDirectoryAuthenticationProvider : SqlAuthenticationPro
/// The purpose of this cache is to allow re-use of Access Tokens fetched for a user interactively or with any other mode
/// to avoid interactive authentication request every-time, within application scope making use of MSAL's userTokenCache.
/// </summary>
private static ConcurrentDictionary<PublicClientAppKey, IPublicClientApplication> s_pcaMap
= new ConcurrentDictionary<PublicClientAppKey, IPublicClientApplication>();
private static readonly MemoryCache s_accountPwCache = new(nameof(ActiveDirectoryAuthenticationProvider));
private static readonly ConcurrentDictionary<PublicClientAppKey, IPublicClientApplication> s_pcaMap = new();
private static readonly ConcurrentDictionary<TokenCredentialKey, TokenCredentialData> s_tokenCredentialMap = new();
private static SemaphoreSlim s_pcaMapModifierSemaphore = new(1, 1);
private static SemaphoreSlim s_tokenCredentialMapModifierSemaphore = new(1, 1);
private static readonly int s_accountPwCacheTtlInHours = 2;
private static readonly string s_nativeClientRedirectUri = "https://login.microsoftonline.com/common/oauth2/nativeclient";
private static readonly string s_defaultScopeSuffix = "/.default";
private readonly string _type = typeof(ActiveDirectoryAuthenticationProvider).Name;
private readonly SqlClientLogger _logger = new SqlClientLogger();
private readonly SqlClientLogger _logger = new();
private Func<DeviceCodeResult, Task> _deviceCodeFlowCallback;
private ICustomWebUi _customWebUI = null;
private readonly string _applicationClientId = ActiveDirectoryAuthentication.AdoClientId;
Expand Down Expand Up @@ -66,6 +68,11 @@ public static void ClearUserTokenCache()
{
s_pcaMap.Clear();
}

if (!s_tokenCredentialMap.IsEmpty)
{
s_tokenCredentialMap.Clear();
}
}

/// <include file='../../../../../../doc/snippets/Microsoft.Data.SqlClient/ActiveDirectoryAuthenticationProvider.xml' path='docs/members[@name="ActiveDirectoryAuthenticationProvider"]/SetDeviceCodeFlowCallback/*'/>
Expand Down Expand Up @@ -145,50 +152,40 @@ public override async Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthenti
* More information: https://docs.microsoft.com/azure/active-directory/develop/msal-client-application-configuration
**/

int seperatorIndex = parameters.Authority.LastIndexOf('/');
string authority = parameters.Authority.Remove(seperatorIndex + 1);
string audience = parameters.Authority.Substring(seperatorIndex + 1);
int separatorIndex = parameters.Authority.LastIndexOf('/');
string authority = parameters.Authority.Remove(separatorIndex + 1);
string audience = parameters.Authority.Substring(separatorIndex + 1);
string clientId = string.IsNullOrWhiteSpace(parameters.UserId) ? null : parameters.UserId;

if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryDefault)
{
DefaultAzureCredentialOptions defaultAzureCredentialOptions = new()
{
AuthorityHost = new Uri(authority),
SharedTokenCacheTenantId = audience,
VisualStudioCodeTenantId = audience,
VisualStudioTenantId = audience,
ExcludeInteractiveBrowserCredential = true // Force disabled, even though it's disabled by default to respect driver specifications.
};

// Optionally set clientId when available
if (clientId is not null)
{
defaultAzureCredentialOptions.ManagedIdentityClientId = clientId;
defaultAzureCredentialOptions.SharedTokenCacheUsername = clientId;
}
AccessToken accessToken = await new DefaultAzureCredential(defaultAzureCredentialOptions).GetTokenAsync(tokenRequestContext, cts.Token).ConfigureAwait(false);
// Cache DefaultAzureCredenial based on scope, authority, audience, and clientId
TokenCredentialKey tokenCredentialKey = new(typeof(DefaultAzureCredential), authority, scope, audience, clientId);
AccessToken accessToken = await GetTokenAsync(tokenCredentialKey, string.Empty, tokenRequestContext, cts.Token).ConfigureAwait(false);
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Default auth mode. Expiry Time: {0}", accessToken.ExpiresOn);
return new SqlAuthenticationToken(accessToken.Token, accessToken.ExpiresOn);
}

TokenCredentialOptions tokenCredentialOptions = new TokenCredentialOptions() { AuthorityHost = new Uri(authority) };
TokenCredentialOptions tokenCredentialOptions = new() { AuthorityHost = new Uri(authority) };

if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryManagedIdentity || parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryMSI)
{
AccessToken accessToken = await new ManagedIdentityCredential(clientId, tokenCredentialOptions).GetTokenAsync(tokenRequestContext, cts.Token).ConfigureAwait(false);
// Cache ManagedIdentityCredential based on scope, authority, and clientId
TokenCredentialKey tokenCredentialKey = new(typeof(ManagedIdentityCredential), authority, scope, string.Empty, clientId);
AccessToken accessToken = await GetTokenAsync(tokenCredentialKey, string.Empty, tokenRequestContext, cts.Token).ConfigureAwait(false);
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Managed Identity auth mode. Expiry Time: {0}", accessToken.ExpiresOn);
return new SqlAuthenticationToken(accessToken.Token, accessToken.ExpiresOn);
}

AuthenticationResult result = null;
if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryServicePrincipal)
{
AccessToken accessToken = await new ClientSecretCredential(audience, parameters.UserId, parameters.Password, tokenCredentialOptions).GetTokenAsync(tokenRequestContext, cts.Token).ConfigureAwait(false);
// Cache ClientSecretCredential based on scope, authority, audience, and clientId
TokenCredentialKey tokenCredentialKey = new(typeof(ClientSecretCredential), authority, scope, audience, clientId);
AccessToken accessToken = await GetTokenAsync(tokenCredentialKey, parameters.Password, tokenRequestContext, cts.Token).ConfigureAwait(false);
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Active Directory Service Principal auth mode. Expiry Time: {0}", accessToken.ExpiresOn);
return new SqlAuthenticationToken(accessToken.Token, accessToken.ExpiresOn);
}

/*
* Today, MSAL.NET uses another redirect URI by default in desktop applications that run on Windows
* (urn:ietf:wg:oauth:2.0:oob). In the future, we'll want to change this default, so we recommend
Expand All @@ -204,7 +201,7 @@ public override async Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthenti
redirectUri = "http://localhost";
}
#endif
PublicClientAppKey pcaKey = new PublicClientAppKey(parameters.Authority, redirectUri, _applicationClientId
PublicClientAppKey pcaKey = new(parameters.Authority, redirectUri, _applicationClientId
#if NETFRAMEWORK
, _iWin32WindowFunc
#endif
Expand All @@ -213,7 +210,8 @@ public override async Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthenti
#endif
);

IPublicClientApplication app = GetPublicClientAppInstance(pcaKey);
AuthenticationResult result = null;
IPublicClientApplication app = await GetPublicClientAppInstanceAsync(pcaKey, cts.Token).ConfigureAwait(false);

if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryIntegrated)
{
Expand Down Expand Up @@ -248,7 +246,7 @@ public override async Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthenti
if (null != previousPw &&
previousPw is byte[] previousPwBytes &&
// Only get the cached token if the current password hash matches the previously used password hash
currPwHash.SequenceEqual(previousPwBytes))
AreEqual(currPwHash, previousPwBytes))
{
result = await TryAcquireTokenSilent(app, parameters, scopes, cts).ConfigureAwait(false);
}
Expand Down Expand Up @@ -353,7 +351,7 @@ private static async Task<AuthenticationResult> AcquireTokenInteractiveDeviceFlo
{
if (authenticationMethod == SqlAuthenticationMethod.ActiveDirectoryInteractive)
{
CancellationTokenSource ctsInteractive = new CancellationTokenSource();
CancellationTokenSource ctsInteractive = new();
#if NETCOREAPP
/*
* On .NET Core, MSAL will start the system browser as a separate process. MSAL does not have control over this browser,
Expand Down Expand Up @@ -447,16 +445,69 @@ public Task<Uri> AcquireAuthorizationCodeAsync(Uri authorizationUri, Uri redirec
=> _acquireAuthorizationCodeAsyncCallback.Invoke(authorizationUri, redirectUri, cancellationToken);
}

private IPublicClientApplication GetPublicClientAppInstance(PublicClientAppKey publicClientAppKey)
private async Task<IPublicClientApplication> GetPublicClientAppInstanceAsync(PublicClientAppKey publicClientAppKey, CancellationToken cancellationToken)
{
if (!s_pcaMap.TryGetValue(publicClientAppKey, out IPublicClientApplication clientApplicationInstance))
{
clientApplicationInstance = CreateClientAppInstance(publicClientAppKey);
s_pcaMap.TryAdd(publicClientAppKey, clientApplicationInstance);
await s_pcaMapModifierSemaphore.WaitAsync(cancellationToken);
try
{
// Double-check in case another thread added it while we waited for the semaphore
if (!s_pcaMap.TryGetValue(publicClientAppKey, out clientApplicationInstance))
{
clientApplicationInstance = CreateClientAppInstance(publicClientAppKey);
s_pcaMap.TryAdd(publicClientAppKey, clientApplicationInstance);
}
}
finally
{
s_pcaMapModifierSemaphore.Release();
}
}

return clientApplicationInstance;
}

private static async Task<AccessToken> GetTokenAsync(TokenCredentialKey tokenCredentialKey, string secret,
TokenRequestContext tokenRequestContext, CancellationToken cancellationToken)
{
if (!s_tokenCredentialMap.TryGetValue(tokenCredentialKey, out TokenCredentialData tokenCredentialInstance))
{
await s_tokenCredentialMapModifierSemaphore.WaitAsync(cancellationToken);
try
{
// Double-check in case another thread added it while we waited for the semaphore
if (!s_tokenCredentialMap.TryGetValue(tokenCredentialKey, out tokenCredentialInstance))
{
tokenCredentialInstance = CreateTokenCredentialInstance(tokenCredentialKey, secret);
s_tokenCredentialMap.TryAdd(tokenCredentialKey, tokenCredentialInstance);
}
}
finally
{
s_tokenCredentialMapModifierSemaphore.Release();
}
}

if (!AreEqual(tokenCredentialInstance._secretHash, GetHash(secret)))
{
// If the secret hash has changed, we need to remove the old token credential instance and create a new one.
await s_tokenCredentialMapModifierSemaphore.WaitAsync(cancellationToken);
try
{
s_tokenCredentialMap.TryRemove(tokenCredentialKey, out _);
tokenCredentialInstance = CreateTokenCredentialInstance(tokenCredentialKey, secret);
s_tokenCredentialMap.TryAdd(tokenCredentialKey, tokenCredentialInstance);
}
finally
{
s_tokenCredentialMapModifierSemaphore.Release();
}
}

return await tokenCredentialInstance._tokenCredential.GetTokenAsync(tokenRequestContext, cancellationToken);
}

private static string GetAccountPwCacheKey(SqlAuthenticationParameters parameters)
{
return parameters.Authority + "+" + parameters.UserId;
Expand All @@ -470,6 +521,24 @@ private static byte[] GetHash(string input)
return hashedBytes;
}

private static bool AreEqual(byte[] a1, byte[] a2)
{
if (ReferenceEquals(a1, a2))
{
return true;
}
else if (a1 is null || a2 is null)
{
return false;
}
else if (a1.Length != a2.Length)
{
return false;
}

return a1.AsSpan().SequenceEqual(a2.AsSpan());
}

private IPublicClientApplication CreateClientAppInstance(PublicClientAppKey publicClientAppKey)
{
IPublicClientApplication publicClientApplication;
Expand Down Expand Up @@ -513,6 +582,59 @@ private IPublicClientApplication CreateClientAppInstance(PublicClientAppKey publ
return publicClientApplication;
}

private static TokenCredentialData CreateTokenCredentialInstance(TokenCredentialKey tokenCredentialKey, string secret)
{
if (tokenCredentialKey._tokenCredentialType == typeof(DefaultAzureCredential))
{
DefaultAzureCredentialOptions defaultAzureCredentialOptions = new()
{
AuthorityHost = new Uri(tokenCredentialKey._authority),
SharedTokenCacheTenantId = tokenCredentialKey._audience,
VisualStudioCodeTenantId = tokenCredentialKey._audience,
VisualStudioTenantId = tokenCredentialKey._audience,
ExcludeInteractiveBrowserCredential = true // Force disabled, even though it's disabled by default to respect driver specifications.
};

// Optionally set clientId when available
if (tokenCredentialKey._clientId is not null)
{
defaultAzureCredentialOptions.ManagedIdentityClientId = tokenCredentialKey._clientId;
defaultAzureCredentialOptions.SharedTokenCacheUsername = tokenCredentialKey._clientId;
defaultAzureCredentialOptions.WorkloadIdentityClientId = tokenCredentialKey._clientId;
}

return new TokenCredentialData(new DefaultAzureCredential(defaultAzureCredentialOptions), GetHash(secret));
}

TokenCredentialOptions tokenCredentialOptions = new() { AuthorityHost = new Uri(tokenCredentialKey._authority) };

if (tokenCredentialKey._tokenCredentialType == typeof(ManagedIdentityCredential))
{
return new TokenCredentialData(new ManagedIdentityCredential(tokenCredentialKey._clientId, tokenCredentialOptions), GetHash(secret));
}
else if (tokenCredentialKey._tokenCredentialType == typeof(ClientSecretCredential))
{
return new TokenCredentialData(new ClientSecretCredential(tokenCredentialKey._audience, tokenCredentialKey._clientId, secret, tokenCredentialOptions), GetHash(secret));
}
else if (tokenCredentialKey._tokenCredentialType == typeof(WorkloadIdentityCredential))
{
// The WorkloadIdentityCredentialOptions object initialization populates its instance members
// from the environment variables AZURE_TENANT_ID, AZURE_CLIENT_ID, AZURE_FEDERATED_TOKEN_FILE,
// and AZURE_ADDITIONALLY_ALLOWED_TENANTS. AZURE_CLIENT_ID may be overridden by the User Id.
WorkloadIdentityCredentialOptions options = new() { AuthorityHost = new Uri(tokenCredentialKey._authority) };

if (tokenCredentialKey._clientId is not null)
{
options.ClientId = tokenCredentialKey._clientId;
}

return new TokenCredentialData(new WorkloadIdentityCredential(options), GetHash(secret));
}

// This should never be reached, but if it is, throw an exception that will be noticed during development
throw new ArgumentException(nameof(ActiveDirectoryAuthenticationProvider));
}

internal class PublicClientAppKey
{
public readonly string _authority;
Expand Down Expand Up @@ -572,5 +694,52 @@ public override int GetHashCode() => Tuple.Create(_authority, _redirectUri, _app
#endif
).GetHashCode();
}

internal class TokenCredentialData
{
public TokenCredential _tokenCredential;
public byte[] _secretHash;

public TokenCredentialData(TokenCredential tokenCredential, byte[] secretHash)
{
_tokenCredential = tokenCredential;
_secretHash = secretHash;
}
}

internal class TokenCredentialKey
{
public readonly Type _tokenCredentialType;
public readonly string _authority;
public readonly string _scope;
public readonly string _audience;
public readonly string _clientId;

public TokenCredentialKey(Type tokenCredentialType, string authority, string scope, string audience, string clientId)
{
_tokenCredentialType = tokenCredentialType;
_authority = authority;
_scope = scope;
_audience = audience;
_clientId = clientId;
}

public override bool Equals(object obj)
{
if (obj != null && obj is TokenCredentialKey tcKey)
{
return string.CompareOrdinal(nameof(_tokenCredentialType), nameof(tcKey._tokenCredentialType)) == 0
&& string.CompareOrdinal(_authority, tcKey._authority) == 0
&& string.CompareOrdinal(_scope, tcKey._scope) == 0
&& string.CompareOrdinal(_audience, tcKey._audience) == 0
&& string.CompareOrdinal(_clientId, tcKey._clientId) == 0
;
}
return false;
}

public override int GetHashCode() => Tuple.Create(_tokenCredentialType, _authority, _scope, _audience, _clientId).GetHashCode();
}

}
}

0 comments on commit 9c0391b

Please sign in to comment.