Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

API | AccessTokenCallback support #1260

Merged
merged 41 commits into from
Jun 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
b974ffb
POC for TokenCredential support
christothes Sep 10, 2021
8f76542
POC 2 - callback abstraction
christothes Oct 18, 2021
7ffc10c
merge
christothes Oct 18, 2021
9ed36d7
POC 3
christothes Oct 22, 2021
8321090
fix
christothes Oct 25, 2021
2aee2b2
Merge remote-tracking branch 'upstream/main' into chriss/ADCreds
christothes Oct 27, 2021
ea47be2
Merge remote-tracking branch 'upstream/main' into chriss/ADCreds
christothes Nov 8, 2021
10f5a95
merge
christothes Nov 8, 2021
d281983
cleanups
christothes Nov 12, 2021
fcfb66e
formatting
christothes Nov 12, 2021
4c718b5
netfx consistency
christothes Nov 12, 2021
4ccc412
fix
christothes Nov 12, 2021
68bf511
cleanup
christothes Nov 15, 2021
0671d58
merge
christothes Mar 8, 2023
d9570b3
nuget
christothes Mar 8, 2023
aa517b6
source ref
christothes Mar 8, 2023
c5def10
revert nuget.config change
christothes Mar 10, 2023
727406f
Update src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlCli…
christothes Mar 10, 2023
5c1b5ef
cast timeout to int
christothes Mar 13, 2023
d82f3bb
PR feedback
christothes Apr 14, 2023
9a9cd87
tests
christothes Apr 18, 2023
677f729
fix resources
christothes Apr 19, 2023
afc5fff
fix error messages for fx
christothes Apr 19, 2023
2673bd4
fixes
christothes Apr 19, 2023
b08a251
rename AzureADTokenRequestContext
christothes Apr 20, 2023
be1263e
use SqlAuthenticationParameters in callback
christothes May 2, 2023
52a84f9
docs and simple sample
christothes May 2, 2023
9be9391
tests for password and userId with callback
christothes May 3, 2023
a2233f5
add to SqlClient.cs API listing
christothes May 3, 2023
c1df1f4
Allow credential with callback and pass to the callback, if available
David-Engel May 12, 2023
a4aca8f
Merge pull request #1 from David-Engel/ADCreds
christothes Jun 5, 2023
fc57e98
Merge remote-tracking branch 'upstream/main' into chriss/ADCreds
DavoudEshtehari Jun 5, 2023
2ab55b8
fb
christothes Jun 6, 2023
ada78c2
Apply suggestions from code review
christothes Jun 6, 2023
1e531dc
Merge branch 'chriss/ADCreds' of https://github.com/christothes/SqlCl…
christothes Jun 6, 2023
caa6c3e
fb
christothes Jun 6, 2023
36a59e8
Apply suggestions from code review
christothes Jun 6, 2023
cd922db
Merge branch 'chriss/ADCreds' of https://github.com/christothes/SqlCl…
christothes Jun 6, 2023
16693f6
fb
christothes Jun 6, 2023
78b9e4a
fb
christothes Jun 13, 2023
67bfd01
Add tests
DavoudEshtehari Jun 17, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions doc/samples/SqlConnection_AccessTokenCallback.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
using System;
using System.Data;
// <Snippet1>
using Microsoft.Data.SqlClient;
using Azure.Identity;

class Program
{
static void Main()
{
OpenSqlConnection();
Console.ReadLine();
}

private static void OpenSqlConnection()
{
string connectionString = GetConnectionString();
using (SqlConnection connection = new SqlConnection("Data Source=contoso.database.windows.net;Initial Catalog=AdventureWorks;")
{
AccessTokenCallback = async (authParams, cancellationToken) =>
{
var cred = new DefaultAzureCredential();
string scope = authParams.Resource.EndsWith(s_defaultScopeSuffix) ? authParams.Resource : authParams.Resource + s_defaultScopeSuffix;
var token = await cred.GetTokenAsync(new TokenRequestContext(new[] { scope }), cancellationToken);
return new SqlAuthenticationToken(token.Token, token.ExpiresOn);
}
})
{
connection.Open();
Console.WriteLine("ServerVersion: {0}", connection.ServerVersion);
Console.WriteLine("State: {0}", connection.State);
}
}
}
// </Snippet1>
16 changes: 16 additions & 0 deletions doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,22 @@ using (SqlConnection connection = new SqlConnection(connectionString))
<value>The access token for the connection.</value>
<remarks>To be added.</remarks>
</AccessToken>
<AccessTokenCallback>
christothes marked this conversation as resolved.
Show resolved Hide resolved
<summary>Gets or sets the access token callback for the connection.</summary>
<value>
The Func that takes a <see cref="SqlAuthenticationParameters" /> and <see cref="System.Threading.CancellationToken" /> and returns a <see cref="SqlAuthenticationToken" />.</value>
<remarks>
<format type="text/markdown"><![CDATA[

## Examples
The following example demonstrates how to define and set an <xref:Microsoft.Data.SqlClient.AccessTokenCallback>.

[!code-csharp[SqlConnection_AccessTokenCallback Example#1](~/../sqlclient/doc/samples/SqlConnection_AccessTokenCallback.cs#1)]

]]></format>
</remarks>
<exception cref="T:System.InvalidOperationException">The AccessTokenCallback is combined with other conflicting authentication configurations.</exception>
</AccessTokenCallback>
<BeginDbTransaction>
<param name="isolationLevel">To be added.</param>
<summary>To be added.</summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

// NOTE: The current Microsoft.VSDesigner editor attributes are implemented for System.Data.SqlClient, and are not publicly available.
// New attributes that are designed to work with Microsoft.Data.SqlClient and are publicly documented should be included in future.

[assembly: System.CLSCompliant(true)]
namespace Microsoft.Data
{
Expand Down Expand Up @@ -839,6 +840,8 @@ public void RegisterColumnEncryptionKeyStoreProvidersOnConnection(System.Collect
/// <include file='../../../../doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml' path='docs/members[@name="SqlConnection"]/ClientConnectionId/*'/>
[System.ComponentModel.DesignerSerializationVisibilityAttribute(0)]
public System.Guid ClientConnectionId { get { throw null; } }
/// <include file='../../../../doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml' path='docs/members[@name="SqlConnection"]/AccessTokenCallback/*' />
public System.Func<SqlAuthenticationParameters, System.Threading.CancellationToken, System.Threading.Tasks.Task<SqlAuthenticationToken>> AccessTokenCallback { get { throw null; } set { } }

///
/// for internal test only
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ private static readonly Dictionary<string, SqlColumnEncryptionKeyStoreProvider>
/// Instance-level list of custom key store providers. It can be set more than once by the user.
private IReadOnlyDictionary<string, SqlColumnEncryptionKeyStoreProvider> _customColumnEncryptionKeyStoreProviders;

private Func<SqlAuthenticationParameters, CancellationToken, Task<SqlAuthenticationToken>> _accessTokenCallback;

internal bool HasColumnEncryptionKeyStoreProvidersRegistered =>
_customColumnEncryptionKeyStoreProviders is not null && _customColumnEncryptionKeyStoreProviders.Count > 0;

Expand Down Expand Up @@ -272,7 +274,7 @@ internal static List<string> GetColumnEncryptionSystemKeyStoreProvidersNames()
}

/// <summary>
/// This function returns a list of the names of the custom providers currently registered. If the
/// This function returns a list of the names of the custom providers currently registered. If the
/// instance-level cache is not empty, that cache is used, else the global cache is used.
/// </summary>
/// <returns>Combined list of provider names</returns>
Expand Down Expand Up @@ -344,7 +346,7 @@ public void RegisterColumnEncryptionKeyStoreProvidersOnConnection(IDictionary<st
new(customProviders, StringComparer.OrdinalIgnoreCase);

// Set the dictionary to the ReadOnly dictionary.
// This method can be called more than once. Re-registering a new collection will replace the
// This method can be called more than once. Re-registering a new collection will replace the
// old collection of providers.
_customColumnEncryptionKeyStoreProviders = customColumnEncryptionKeyStoreProviders;
}
Expand Down Expand Up @@ -584,7 +586,7 @@ public override string ConnectionString
}
set
{
if (_credential != null || _accessToken != null)
if (_credential != null || _accessToken != null || _accessTokenCallback != null)
{
SqlConnectionString connectionOptions = new SqlConnectionString(value);
if (_credential != null)
Expand Down Expand Up @@ -620,12 +622,18 @@ public override string ConnectionString

CheckAndThrowOnInvalidCombinationOfConnectionStringAndSqlCredential(connectionOptions);
}
else if (_accessToken != null)

if (_accessToken != null)
{
CheckAndThrowOnInvalidCombinationOfConnectionOptionAndAccessToken(connectionOptions);
}

if (_accessTokenCallback != null)
{
CheckAndThrowOnInvalidCombinationOfConnectionOptionAndAccessTokenCallback(connectionOptions);
}
}
ConnectionString_Set(new SqlConnectionPoolKey(value, _credential, _accessToken));
ConnectionString_Set(new SqlConnectionPoolKey(value, _credential, _accessToken, _accessTokenCallback));
_connectionString = value; // Change _connectionString value only after value is validated
CacheConnectionStringProperties();
}
Expand Down Expand Up @@ -685,11 +693,34 @@ public string AccessToken
}

// Need to call ConnectionString_Set to do proper pool group check
ConnectionString_Set(new SqlConnectionPoolKey(_connectionString, credential: _credential, accessToken: value));
ConnectionString_Set(new SqlConnectionPoolKey(_connectionString, credential: _credential, accessToken: value, accessTokenCallback: null));
_accessToken = value;
}
}

/// <include file='../../../../../../../doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml' path='docs/members[@name="SqlConnection"]/AccessTokenCallback/*' />
public Func<SqlAuthenticationParameters, CancellationToken, Task<SqlAuthenticationToken>> AccessTokenCallback
{
get { return _accessTokenCallback; }
set
{
// If a connection is connecting or is ever opened, AccessToken callback cannot be set
if (!InnerConnection.AllowSetConnectionString)
{
throw ADP.OpenConnectionPropertySet(nameof(AccessTokenCallback), InnerConnection.State);
}

if (value != null)
{
// Check if the usage of AccessToken has any conflict with the keys used in connection string and credential
CheckAndThrowOnInvalidCombinationOfConnectionOptionAndAccessTokenCallback((SqlConnectionString)ConnectionOptions);
}

ConnectionString_Set(new SqlConnectionPoolKey(_connectionString, credential: _credential, accessToken: null, accessTokenCallback: value));
_accessTokenCallback = value;
}
}

/// <include file='../../../../../../../doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml' path='docs/members[@name="SqlConnection"]/Database/*' />
[ResDescription(StringsHelper.ResourceNames.SqlConnection_Database)]
[ResCategory(StringsHelper.ResourceNames.SqlConnection_DataSource)]
Expand Down Expand Up @@ -970,6 +1001,7 @@ public SqlCredential Credential
}

CheckAndThrowOnInvalidCombinationOfConnectionStringAndSqlCredential(connectionOptions);

if (_accessToken != null)
{
throw ADP.InvalidMixedUsageOfCredentialAndAccessToken();
Expand All @@ -979,7 +1011,7 @@ public SqlCredential Credential
_credential = value;

// Need to call ConnectionString_Set to do proper pool group check
ConnectionString_Set(new SqlConnectionPoolKey(_connectionString, _credential, accessToken: _accessToken));
ConnectionString_Set(new SqlConnectionPoolKey(_connectionString, _credential, accessToken: _accessToken, accessTokenCallback: _accessTokenCallback));
}
}

Expand Down Expand Up @@ -1026,6 +1058,33 @@ private void CheckAndThrowOnInvalidCombinationOfConnectionOptionAndAccessToken(S
{
throw ADP.InvalidMixedUsageOfCredentialAndAccessToken();
}

if(_accessTokenCallback != null)
{
throw ADP.InvalidMixedUsageOfAccessTokenAndTokenCallback();
}
}

// CheckAndThrowOnInvalidCombinationOfConnectionOptionAndAccessTokenCallback: check if the usage of AccessTokenCallback has any conflict
// with the keys used in connection string and credential
// If there is any conflict, it throws InvalidOperationException
// This is to be used setter of ConnectionString and AccessTokenCallback properties
private void CheckAndThrowOnInvalidCombinationOfConnectionOptionAndAccessTokenCallback(SqlConnectionString connectionOptions)
christothes marked this conversation as resolved.
Show resolved Hide resolved
{
if (UsesIntegratedSecurity(connectionOptions))
{
throw ADP.InvalidMixedUsageOfAccessTokenCallbackAndIntegratedSecurity();
}

if (UsesAuthentication(connectionOptions))
{
throw ADP.InvalidMixedUsageOfAccessTokenCallbackAndAuthentication();
}

if(_accessToken != null)
{
throw ADP.InvalidMixedUsageOfAccessTokenAndTokenCallback();
}
}

/// <include file='../../../../../../../doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml' path='docs/members[@name="SqlConnection"]/DbProviderFactory/*' />
Expand Down Expand Up @@ -2128,7 +2187,7 @@ public static void ChangePassword(string connectionString, string newPassword)
throw ADP.InvalidArgumentLength(nameof(newPassword), TdsEnums.MAXLEN_NEWPASSWORD);
}

SqlConnectionPoolKey key = new SqlConnectionPoolKey(connectionString, credential: null, accessToken: null);
SqlConnectionPoolKey key = new SqlConnectionPoolKey(connectionString, credential: null, accessToken: null, accessTokenCallback: null);

SqlConnectionString connectionOptions = SqlConnectionFactory.FindSqlConnectionOptions(key);
if (connectionOptions.IntegratedSecurity)
Expand Down Expand Up @@ -2177,7 +2236,7 @@ public static void ChangePassword(string connectionString, SqlCredential credent
throw ADP.InvalidArgumentLength(nameof(newSecurePassword), TdsEnums.MAXLEN_NEWPASSWORD);
}

SqlConnectionPoolKey key = new SqlConnectionPoolKey(connectionString, credential, accessToken: null);
SqlConnectionPoolKey key = new SqlConnectionPoolKey(connectionString, credential, accessToken: null, accessTokenCallback: null);

SqlConnectionString connectionOptions = SqlConnectionFactory.FindSqlConnectionOptions(key);

Expand Down Expand Up @@ -2216,7 +2275,7 @@ private static void ChangePassword(string connectionString, SqlConnectionString
if (con != null)
con.Dispose();
}
SqlConnectionPoolKey key = new SqlConnectionPoolKey(connectionString, credential, accessToken: null);
SqlConnectionPoolKey key = new SqlConnectionPoolKey(connectionString, credential, accessToken: null, accessTokenCallback: null);

SqlConnectionFactory.SingletonInstance.ClearPool(key);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ override protected DbConnectionInternal CreateConnection(DbConnectionOptions opt
opt = new SqlConnectionString(opt, instanceName, userInstance: false, setEnlistValue: null);
poolGroupProviderInfo = null; // null so we do not pass to constructor below...
}
return new SqlInternalConnectionTds(identity, opt, key.Credential, poolGroupProviderInfo, "", null, redirectedUserInstance, userOpt, recoverySessionData, applyTransientFaultHandling: applyTransientFaultHandling, key.AccessToken, pool);
return new SqlInternalConnectionTds(identity, opt, key.Credential, poolGroupProviderInfo, "", null, redirectedUserInstance, userOpt, recoverySessionData, applyTransientFaultHandling: applyTransientFaultHandling, key.AccessToken, pool, key.AccessTokenCallback);
}

protected override DbConnectionOptions CreateConnectionOptions(string connectionString, DbConnectionOptions previous)
Expand Down
Loading