Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[2.1.x] Fix | Throttling of token requests by calling AcquireTokenSilent (#1925) #2021

Closed
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
using System.Collections.Concurrent;
using System.Linq;
using System.Security;
using System.Runtime.Caching;
using System.Security.Cryptography;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Identity.Client;
Expand All @@ -23,6 +26,8 @@ public sealed class ActiveDirectoryAuthenticationProvider : SqlAuthenticationPro
/// </summary>
private static ConcurrentDictionary<PublicClientAppKey, IPublicClientApplication> s_pcaMap
= new ConcurrentDictionary<PublicClientAppKey, IPublicClientApplication>();
private static readonly MemoryCache s_accountPwCache = new(nameof(ActiveDirectoryAuthenticationProvider));
private static readonly int s_accountPwCacheTtlInHours = 2;
private static readonly string s_nativeClientRedirectUri = "https://login.microsoftonline.com/common/oauth2/nativeclient";
private static readonly string s_defaultScopeSuffix = "/.default";
private readonly string _type = typeof(ActiveDirectoryAuthenticationProvider).Name;
Expand Down Expand Up @@ -101,7 +106,9 @@ public override void BeforeUnload(SqlAuthenticationMethod authentication)
/// <include file='../../../../../../doc/snippets/Microsoft.Data.SqlClient/ActiveDirectoryAuthenticationProvider.xml' path='docs/members[@name="ActiveDirectoryAuthenticationProvider"]/AcquireTokenAsync/*'/>
public override Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthenticationParameters parameters) => Task.Run(async () =>
{
AuthenticationResult result;
CancellationTokenSource cts = new();

AuthenticationResult result = null;
string scope = parameters.Resource.EndsWith(s_defaultScopeSuffix) ? parameters.Resource : parameters.Resource + s_defaultScopeSuffix;
string[] scopes = new string[] { scope };

Expand Down Expand Up @@ -147,69 +154,84 @@ public override Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthentication

if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryIntegrated)
{
if (!string.IsNullOrEmpty(parameters.UserId))
{
result = app.AcquireTokenByIntegratedWindowsAuth(scopes)
.WithCorrelationId(parameters.ConnectionId)
.WithUsername(parameters.UserId)
.ExecuteAsync().Result;
}
else
result = await TryAcquireTokenSilent(app, parameters, scopes, cts).ConfigureAwait(false);

if (result == null)
{
result = app.AcquireTokenByIntegratedWindowsAuth(scopes)
.WithCorrelationId(parameters.ConnectionId)
.ExecuteAsync().Result;
if (!string.IsNullOrEmpty(parameters.UserId))
{
result = app.AcquireTokenByIntegratedWindowsAuth(scopes)
.WithCorrelationId(parameters.ConnectionId)
.WithUsername(parameters.UserId)
.ExecuteAsync(cancellationToken: cts.Token).Result;
}
else
{
result = app.AcquireTokenByIntegratedWindowsAuth(scopes)
.WithCorrelationId(parameters.ConnectionId)
.ExecuteAsync(cancellationToken: cts.Token).Result;
}
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Active Directory Integrated auth mode. Expiry Time: {0}", result.ExpiresOn);
}
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Active Directory Integrated auth mode. Expiry Time: {0}", result.ExpiresOn);
}
else if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryPassword)
{
SecureString password = new SecureString();
foreach (char c in parameters.Password)
password.AppendChar(c);
password.MakeReadOnly();
result = app.AcquireTokenByUsernamePassword(scopes, parameters.UserId, password)
.WithCorrelationId(parameters.ConnectionId)
.ExecuteAsync().Result;
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Active Directory Password auth mode. Expiry Time: {0}", result.ExpiresOn);
string pwCacheKey = GetAccountPwCacheKey(parameters);
object previousPw = s_accountPwCache.Get(pwCacheKey);
byte[] currPwHash = GetHash(parameters.Password);

if (null != previousPw &&
previousPw is byte[] previousPwBytes &&
// Only get the cached token if the current password hash matches the previously used password hash
currPwHash.SequenceEqual(previousPwBytes))
{
result = await TryAcquireTokenSilent(app, parameters, scopes, cts).ConfigureAwait(false);
}

if (result == null)
{
SecureString password = new SecureString();
foreach (char c in parameters.Password)
password.AppendChar(c);
password.MakeReadOnly();
result = await app.AcquireTokenByUsernamePassword(scopes, parameters.UserId, password)
.WithCorrelationId(parameters.ConnectionId)
.ExecuteAsync()
.ConfigureAwait(false);

// We cache the password hash to ensure future connection requests include a validated password
// when we check for a cached MSAL account. Otherwise, a connection request with the same username
// against the same tenant could succeed with an invalid password when we re-use the cached token.
if (!s_accountPwCache.Add(pwCacheKey, GetHash(parameters.Password), DateTime.UtcNow.AddHours(s_accountPwCacheTtlInHours)))
{
s_accountPwCache.Remove(pwCacheKey);
s_accountPwCache.Add(pwCacheKey, GetHash(parameters.Password), DateTime.UtcNow.AddHours(s_accountPwCacheTtlInHours));
}

SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Active Directory Password auth mode. Expiry Time: {0}", result.ExpiresOn);
}
}
else if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryInteractive ||
parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryDeviceCodeFlow)
{
// Fetch available accounts from 'app' instance
System.Collections.Generic.IEnumerable<IAccount> accounts = await app.GetAccountsAsync();
IAccount account;
if (!string.IsNullOrEmpty(parameters.UserId))
try
{
account = accounts.FirstOrDefault(a => parameters.UserId.Equals(a.Username, System.StringComparison.InvariantCultureIgnoreCase));
result = await TryAcquireTokenSilent(app, parameters, scopes, cts).ConfigureAwait(false);
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (silent) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result.ExpiresOn);
}
else
catch (MsalUiRequiredException)
{
account = accounts.FirstOrDefault();
// An 'MsalUiRequiredException' is thrown in the case where an interaction is required with the end user of the application,
// for instance, if no refresh token was in the cache, or the user needs to consent, or re-sign-in (for instance if the password expired),
// or the user needs to perform two factor authentication.
result = await AcquireTokenInteractiveDeviceFlowAsync(app, scopes, parameters.ConnectionId, parameters.UserId, parameters.AuthenticationMethod, cts, _customWebUI, _deviceCodeFlowCallback);
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (interactive) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result.ExpiresOn);
}

if (null != account)
{
try
{
// If 'account' is available in 'app', we use the same to acquire token silently.
// Read More on API docs: https://docs.microsoft.com/dotnet/api/microsoft.identity.client.clientapplicationbase.acquiretokensilent
result = await app.AcquireTokenSilent(scopes, account).ExecuteAsync();
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (silent) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result.ExpiresOn);
}
catch (MsalUiRequiredException)
{
// An 'MsalUiRequiredException' is thrown in the case where an interaction is required with the end user of the application,
// for instance, if no refresh token was in the cache, or the user needs to consent, or re-sign-in (for instance if the password expired),
// or the user needs to perform two factor authentication.
result = await AcquireTokenInteractiveDeviceFlowAsync(app, scopes, parameters.ConnectionId, parameters.UserId, parameters.AuthenticationMethod);
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (interactive) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result.ExpiresOn);
}
}
else
if (result == null)
{
// If no existing 'account' is found, we request user to sign in interactively.
result = await AcquireTokenInteractiveDeviceFlowAsync(app, scopes, parameters.ConnectionId, parameters.UserId, parameters.AuthenticationMethod);
result = await AcquireTokenInteractiveDeviceFlowAsync(app, scopes, parameters.ConnectionId, parameters.UserId, parameters.AuthenticationMethod, cts, _customWebUI, _deviceCodeFlowCallback);
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (interactive) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result.ExpiresOn);
}
}
Expand All @@ -222,11 +244,58 @@ public override Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthentication
return new SqlAuthenticationToken(result.AccessToken, result.ExpiresOn);
});

private static async Task<AuthenticationResult> TryAcquireTokenSilent(IPublicClientApplication app,
SqlAuthenticationParameters parameters,
string[] scopes,
CancellationTokenSource cts)
{
AuthenticationResult result = null;

// Fetch available accounts from 'app' instance
System.Collections.Generic.IEnumerator<IAccount> accounts = (await app.GetAccountsAsync().ConfigureAwait(false)).GetEnumerator();

IAccount account = default;
if (accounts.MoveNext())
{
if (!string.IsNullOrEmpty(parameters.UserId))
{
do
{
IAccount currentVal = accounts.Current;
if (string.Compare(parameters.UserId, currentVal.Username, StringComparison.InvariantCultureIgnoreCase) == 0)
{
account = currentVal;
break;
}
}
while (accounts.MoveNext());
}
else
{
account = accounts.Current;
}
}

if (null != account)
{
// If 'account' is available in 'app', we use the same to acquire token silently.
// Read More on API docs: https://docs.microsoft.com/dotnet/api/microsoft.identity.client.clientapplicationbase.acquiretokensilent
result = await app.AcquireTokenSilent(scopes, account).ExecuteAsync(cancellationToken: cts.Token).ConfigureAwait(false);
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (silent) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result?.ExpiresOn);
}

return result;
}

private async Task<AuthenticationResult> AcquireTokenInteractiveDeviceFlowAsync(IPublicClientApplication app, string[] scopes, Guid connectionId, string userId,
SqlAuthenticationMethod authenticationMethod)
private static async Task<AuthenticationResult> AcquireTokenInteractiveDeviceFlowAsync(IPublicClientApplication app,
string[] scopes,
Guid connectionId,
string userId,
SqlAuthenticationMethod authenticationMethod,
CancellationTokenSource cts,
ICustomWebUi customWebUI,
Func<DeviceCodeResult, Task> deviceCodeFlowCallback)
{
CancellationTokenSource cts = new CancellationTokenSource();
#if NETCOREAPP
/*
* On .NET Core, MSAL will start the system browser as a separate process. MSAL does not have control over this browser,
Expand All @@ -243,11 +312,11 @@ private async Task<AuthenticationResult> AcquireTokenInteractiveDeviceFlowAsync(
{
if (authenticationMethod == SqlAuthenticationMethod.ActiveDirectoryInteractive)
{
if (_customWebUI != null)
if (customWebUI != null)
{
return await app.AcquireTokenInteractive(scopes)
.WithCorrelationId(connectionId)
.WithCustomWebUi(_customWebUI)
.WithCustomWebUi(customWebUI)
.WithLoginHint(userId)
.ExecuteAsync(cts.Token);
}
Expand Down Expand Up @@ -279,7 +348,7 @@ private async Task<AuthenticationResult> AcquireTokenInteractiveDeviceFlowAsync(
else
{
AuthenticationResult result = await app.AcquireTokenWithDeviceCode(scopes,
deviceCodeResult => _deviceCodeFlowCallback(deviceCodeResult)).ExecuteAsync();
deviceCodeResult => deviceCodeFlowCallback(deviceCodeResult)).ExecuteAsync();
return result;
}
}
Expand Down Expand Up @@ -329,6 +398,19 @@ private IPublicClientApplication GetPublicClientAppInstance(PublicClientAppKey p
return clientApplicationInstance;
}

private static string GetAccountPwCacheKey(SqlAuthenticationParameters parameters)
{
return parameters.Authority + "+" + parameters.UserId;
}

private static byte[] GetHash(string input)
{
byte[] unhashedBytes = Encoding.Unicode.GetBytes(input);
SHA256 sha256 = SHA256.Create();
byte[] hashedBytes = sha256.ComputeHash(unhashedBytes);
return hashedBytes;
}

private IPublicClientApplication CreateClientAppInstance(PublicClientAppKey publicClientAppKey)
{
IPublicClientApplication publicClientApplication;
Expand Down