Skip to content

Commit

Permalink
Active Directory Default authentication improvements (#1360)
Browse files Browse the repository at this point in the history
* Active Directory Default improvements

* Disable test on Linux

* Feedback applied

* Update src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs

Co-authored-by: Javad <v-jarahn@microsoft.com>

* Add configureawaits

Co-authored-by: Javad <v-jarahn@microsoft.com>
  • Loading branch information
cheenamalhotra and Javad authored Oct 20, 2021
1 parent aa95223 commit 3c8ef33
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -118,28 +118,46 @@ public override async Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthenti

string scope = parameters.Resource.EndsWith(s_defaultScopeSuffix) ? parameters.Resource : parameters.Resource + s_defaultScopeSuffix;
string[] scopes = new string[] { scope };
TokenRequestContext tokenRequestContext = new(scopes);

/* We split audience from Authority URL here. Audience can be one of the following:
* The Azure AD authority audience enumeration
* The tenant ID, which can be:
* - A GUID (the ID of your Azure AD instance), for single-tenant applications
* - A domain name associated with your Azure AD instance (also for single-tenant applications)
* One of these placeholders as a tenant ID in place of the Azure AD authority audience enumeration:
* - `organizations` for a multitenant application
* - `consumers` to sign in users only with their personal accounts
* - `common` to sign in users with their work and school accounts or their personal Microsoft accounts
*
* MSAL will throw a meaningful exception if you specify both the Azure AD authority audience and the tenant ID.
* If you don't specify an audience, your app will target Azure AD and personal Microsoft accounts as an audience. (That is, it will behave as though `common` were specified.)
* More information: https://docs.microsoft.com/azure/active-directory/develop/msal-client-application-configuration
**/

int seperatorIndex = parameters.Authority.LastIndexOf('/');
string tenantId = parameters.Authority.Substring(seperatorIndex + 1);
string authority = parameters.Authority.Remove(seperatorIndex + 1);

TokenRequestContext tokenRequestContext = new TokenRequestContext(scopes);
string audience = parameters.Authority.Substring(seperatorIndex + 1);
string clientId = string.IsNullOrWhiteSpace(parameters.UserId) ? null : parameters.UserId;

if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryDefault)
{
DefaultAzureCredentialOptions defaultAzureCredentialOptions = new DefaultAzureCredentialOptions()
DefaultAzureCredentialOptions defaultAzureCredentialOptions = new()
{
AuthorityHost = new Uri(authority),
ManagedIdentityClientId = clientId,
InteractiveBrowserTenantId = tenantId,
SharedTokenCacheTenantId = tenantId,
SharedTokenCacheUsername = clientId,
VisualStudioCodeTenantId = tenantId,
VisualStudioTenantId = tenantId,
SharedTokenCacheTenantId = audience,
VisualStudioCodeTenantId = audience,
VisualStudioTenantId = audience,
ExcludeInteractiveBrowserCredential = true // Force disabled, even though it's disabled by default to respect driver specifications.
};
AccessToken accessToken = await new DefaultAzureCredential(defaultAzureCredentialOptions).GetTokenAsync(tokenRequestContext, cts.Token);

// Optionally set clientId when available
if (clientId is not null)
{
defaultAzureCredentialOptions.ManagedIdentityClientId = clientId;
defaultAzureCredentialOptions.SharedTokenCacheUsername = clientId;
}
AccessToken accessToken = await new DefaultAzureCredential(defaultAzureCredentialOptions).GetTokenAsync(tokenRequestContext, cts.Token).ConfigureAwait(false);
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Default auth mode. Expiry Time: {0}", accessToken.ExpiresOn);
return new SqlAuthenticationToken(accessToken.Token, accessToken.ExpiresOn);
}
Expand All @@ -148,15 +166,15 @@ public override async Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthenti

if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryManagedIdentity || parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryMSI)
{
AccessToken accessToken = await new ManagedIdentityCredential(clientId, tokenCredentialOptions).GetTokenAsync(tokenRequestContext, cts.Token);
AccessToken accessToken = await new ManagedIdentityCredential(clientId, tokenCredentialOptions).GetTokenAsync(tokenRequestContext, cts.Token).ConfigureAwait(false);
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Managed Identity auth mode. Expiry Time: {0}", accessToken.ExpiresOn);
return new SqlAuthenticationToken(accessToken.Token, accessToken.ExpiresOn);
}

AuthenticationResult result;
if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryServicePrincipal)
{
AccessToken accessToken = await new ClientSecretCredential(tenantId, parameters.UserId, parameters.Password, tokenCredentialOptions).GetTokenAsync(tokenRequestContext, cts.Token);
AccessToken accessToken = await new ClientSecretCredential(audience, parameters.UserId, parameters.Password, tokenCredentialOptions).GetTokenAsync(tokenRequestContext, cts.Token).ConfigureAwait(false);
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Active Directory Service Principal auth mode. Expiry Time: {0}", accessToken.ExpiresOn);
return new SqlAuthenticationToken(accessToken.Token, accessToken.ExpiresOn);
}
Expand Down Expand Up @@ -194,13 +212,15 @@ public override async Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthenti
result = await app.AcquireTokenByIntegratedWindowsAuth(scopes)
.WithCorrelationId(parameters.ConnectionId)
.WithUsername(parameters.UserId)
.ExecuteAsync(cancellationToken: cts.Token);
.ExecuteAsync(cancellationToken: cts.Token)
.ConfigureAwait(false);
}
else
{
result = await app.AcquireTokenByIntegratedWindowsAuth(scopes)
.WithCorrelationId(parameters.ConnectionId)
.ExecuteAsync(cancellationToken: cts.Token);
.ExecuteAsync(cancellationToken: cts.Token)
.ConfigureAwait(false);
}
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Active Directory Integrated auth mode. Expiry Time: {0}", result?.ExpiresOn);
}
Expand All @@ -213,14 +233,15 @@ public override async Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthenti

result = await app.AcquireTokenByUsernamePassword(scopes, parameters.UserId, password)
.WithCorrelationId(parameters.ConnectionId)
.ExecuteAsync(cancellationToken: cts.Token);
.ExecuteAsync(cancellationToken: cts.Token)
.ConfigureAwait(false);
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Active Directory Password auth mode. Expiry Time: {0}", result?.ExpiresOn);
}
else if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryInteractive ||
parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryDeviceCodeFlow)
{
// Fetch available accounts from 'app' instance
System.Collections.Generic.IEnumerator<IAccount> accounts = (await app.GetAccountsAsync()).GetEnumerator();
System.Collections.Generic.IEnumerator<IAccount> accounts = (await app.GetAccountsAsync().ConfigureAwait(false)).GetEnumerator();

IAccount account = default;
if (accounts.MoveNext())
Expand Down Expand Up @@ -250,22 +271,22 @@ public override async Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthenti
{
// If 'account' is available in 'app', we use the same to acquire token silently.
// Read More on API docs: https://docs.microsoft.com/dotnet/api/microsoft.identity.client.clientapplicationbase.acquiretokensilent
result = await app.AcquireTokenSilent(scopes, account).ExecuteAsync(cancellationToken: cts.Token);
result = await app.AcquireTokenSilent(scopes, account).ExecuteAsync(cancellationToken: cts.Token).ConfigureAwait(false);
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (silent) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result?.ExpiresOn);
}
catch (MsalUiRequiredException)
{
// An 'MsalUiRequiredException' is thrown in the case where an interaction is required with the end user of the application,
// for instance, if no refresh token was in the cache, or the user needs to consent, or re-sign-in (for instance if the password expired),
// or the user needs to perform two factor authentication.
result = await AcquireTokenInteractiveDeviceFlowAsync(app, scopes, parameters.ConnectionId, parameters.UserId, parameters.AuthenticationMethod, cts);
result = await AcquireTokenInteractiveDeviceFlowAsync(app, scopes, parameters.ConnectionId, parameters.UserId, parameters.AuthenticationMethod, cts).ConfigureAwait(false);
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (interactive) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result?.ExpiresOn);
}
}
else
{
// If no existing 'account' is found, we request user to sign in interactively.
result = await AcquireTokenInteractiveDeviceFlowAsync(app, scopes, parameters.ConnectionId, parameters.UserId, parameters.AuthenticationMethod, cts);
result = await AcquireTokenInteractiveDeviceFlowAsync(app, scopes, parameters.ConnectionId, parameters.UserId, parameters.AuthenticationMethod, cts).ConfigureAwait(false);
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (interactive) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result?.ExpiresOn);
}
}
Expand Down Expand Up @@ -304,7 +325,8 @@ private async Task<AuthenticationResult> AcquireTokenInteractiveDeviceFlowAsync(
.WithCorrelationId(connectionId)
.WithCustomWebUi(_customWebUI)
.WithLoginHint(userId)
.ExecuteAsync(ctsInteractive.Token);
.ExecuteAsync(ctsInteractive.Token)
.ConfigureAwait(false);
}
else
{
Expand All @@ -328,15 +350,17 @@ private async Task<AuthenticationResult> AcquireTokenInteractiveDeviceFlowAsync(
return await app.AcquireTokenInteractive(scopes)
.WithCorrelationId(connectionId)
.WithLoginHint(userId)
.ExecuteAsync(ctsInteractive.Token);
.ExecuteAsync(ctsInteractive.Token)
.ConfigureAwait(false);
}
}
else
{
AuthenticationResult result = await app.AcquireTokenWithDeviceCode(scopes,
deviceCodeResult => _deviceCodeFlowCallback(deviceCodeResult))
.WithCorrelationId(connectionId)
.ExecuteAsync(cancellationToken: cts.Token);
.ExecuteAsync(cancellationToken: cts.Token)
.ConfigureAwait(false);
return result;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,16 @@ private static void ConnectAndDisconnect(string connectionString, SqlCredential
private static bool IsAADConnStringsSetup() => DataTestUtility.IsAADPasswordConnStrSetup();
private static bool IsManagedIdentitySetup() => DataTestUtility.ManagedIdentitySupported;

[PlatformSpecific(TestPlatforms.Windows)]
[ConditionalFact(nameof(IsAccessTokenSetup), nameof(IsAADConnStringsSetup))]
public static void KustoDatabaseTest()
{
// This is a sample Kusto database that can be connected by any AD account.
using SqlConnection connection = new SqlConnection("Data Source=help.kusto.windows.net; Authentication=Active Directory Default;Trust Server Certificate=True;");
connection.Open();
Assert.True(connection.State == System.Data.ConnectionState.Open);
}

[ConditionalFact(nameof(IsAccessTokenSetup), nameof(IsAADConnStringsSetup))]
public static void AccessTokenTest()
{
Expand Down

0 comments on commit 3c8ef33

Please sign in to comment.