33// See the LICENSE file in the project root for more information.
44
55using System ;
6- using System . IdentityModel . Tokens . Jwt ;
6+ using System . Collections . Concurrent ;
77using System . Linq ;
88using System . Net . Http ;
99using System . Threading ;
1010using System . Threading . Tasks ;
1111using Azure . Core ;
12- using Microsoft . IdentityModel . Clients . ActiveDirectory ;
13- using Newtonsoft . Json ;
14- using Newtonsoft . Json . Linq ;
12+ using Azure . Identity ;
1513
1614namespace Microsoft . Data . SqlClient . ManualTesting . Tests
1715{
1816 public class SqlClientCustomTokenCredential : TokenCredential
1917 {
18+ private const string DEFAULT_PREFIX = "/.default" ;
19+ private static readonly ConcurrentDictionary < string , ClientSecretCredential > s_clientSecretCredentials = new ( ) ;
20+
2021 string _authority = "" ;
2122 string _resource = "" ;
2223 string _akvUrl = "" ;
@@ -70,40 +71,8 @@ private async Task<AccessToken> AcquireTokenAsync()
7071 _akvUrl = DataTestUtility . AKVUrl ;
7172 }
7273
73- string strAccessToken = await AzureActiveDirectoryAuthenticationCallback ( _authority , _resource ) ;
74- DateTime expiryTime = InterceptAccessTokenForExpiry ( strAccessToken ) ;
75- return new AccessToken ( strAccessToken , new DateTimeOffset ( expiryTime ) ) ;
76- }
77-
78- private DateTime InterceptAccessTokenForExpiry ( string accessToken )
79- {
80- if ( null == accessToken )
81- {
82- throw new ArgumentNullException ( accessToken ) ;
83- }
84-
85- var jwtHandler = new JwtSecurityTokenHandler ( ) ;
86- var jwtOutput = string . Empty ;
87-
88- // Check Token Format
89- if ( ! jwtHandler . CanReadToken ( accessToken ) )
90- throw new FormatException ( accessToken ) ;
91-
92- JwtSecurityToken token = jwtHandler . ReadJwtToken ( accessToken ) ;
93-
94- // Re-serialize the Token Headers to just Key and Values
95- var jwtHeader = JsonConvert . SerializeObject ( token . Header . Select ( h => new { h . Key , h . Value } ) ) ;
96- jwtOutput = $ "{{\r \n \" Header\" :\r \n { JToken . Parse ( jwtHeader ) } ,";
97-
98- // Re-serialize the Token Claims to just Type and Values
99- var jwtPayload = JsonConvert . SerializeObject ( token . Claims . Select ( c => new { c . Type , c . Value } ) ) ;
100- jwtOutput += $ "\r \n \" Payload\" :\r \n { JToken . Parse ( jwtPayload ) } \r \n }}";
101-
102- // Output the whole thing to pretty JSON object formatted.
103- string jToken = JToken . Parse ( jwtOutput ) . ToString ( Formatting . Indented ) ;
104- JToken payload = JObject . Parse ( jToken ) . GetValue ( "Payload" ) ;
105-
106- return new DateTime ( 1970 , 1 , 1 ) . AddSeconds ( ( long ) payload [ 4 ] [ "Value" ] ) ;
74+ AccessToken accessToken = await AzureActiveDirectoryAuthenticationCallback ( _authority , _resource ) ;
75+ return accessToken ;
10776 }
10877
10978 private static string ValidateChallenge ( string challenge )
@@ -127,16 +96,20 @@ private static string ValidateChallenge(string challenge)
12796 /// <param name="authority">Authorization URL</param>
12897 /// <param name="resource">Resource</param>
12998 /// <returns></returns>
130- public static async Task < string > AzureActiveDirectoryAuthenticationCallback ( string authority , string resource )
99+ public static async Task < AccessToken > AzureActiveDirectoryAuthenticationCallback ( string authority , string resource )
131100 {
132- var authContext = new AuthenticationContext ( authority ) ;
133- ClientCredential clientCred = new ClientCredential ( DataTestUtility . AKVClientId , DataTestUtility . AKVClientSecret ) ;
134- AuthenticationResult result = await authContext . AcquireTokenAsync ( resource , clientCred ) ;
135- if ( result == null )
136- {
137- throw new InvalidOperationException ( $ "Failed to retrieve an access token for { resource } ") ;
138- }
139- return result . AccessToken ;
101+ using CancellationTokenSource cts = new ( ) ;
102+ cts . CancelAfter ( 30000 ) ; // Hard coded for tests
103+ string [ ] scopes = new string [ ] { resource + DEFAULT_PREFIX } ;
104+ TokenRequestContext tokenRequestContext = new ( scopes ) ;
105+ int separatorIndex = authority . LastIndexOf ( '/' ) ;
106+ string authorityHost = authority . Remove ( separatorIndex + 1 ) ;
107+ string audience = authority . Substring ( separatorIndex + 1 ) ;
108+ TokenCredentialOptions tokenCredentialOptions = new TokenCredentialOptions ( ) { AuthorityHost = new Uri ( authorityHost ) } ;
109+ ClientSecretCredential clientSecretCredential = s_clientSecretCredentials . GetOrAdd ( authority + "|--|" + resource ,
110+ new ClientSecretCredential ( audience , DataTestUtility . AKVClientId , DataTestUtility . AKVClientSecret , tokenCredentialOptions ) ) ;
111+ AccessToken accessToken = await clientSecretCredential . GetTokenAsync ( tokenRequestContext , cts . Token ) . ConfigureAwait ( false ) ;
112+ return accessToken ;
140113 }
141114 }
142115}
0 commit comments