33// See the LICENSE file in the project root for more information.
44
55using System ;
6+ using System . Diagnostics ;
7+ using System . Net . Http ;
8+ using System . Text . RegularExpressions ;
9+ using System . Threading ;
610using System . Threading . Tasks ;
711using Microsoft . IdentityModel . Clients . ActiveDirectory ;
812
@@ -17,10 +21,186 @@ public static async Task<string> AzureActiveDirectoryAuthenticationCallback(stri
1721 AuthenticationResult result = await authContext . AcquireTokenAsync ( resource , clientCred ) ;
1822 if ( result == null )
1923 {
20- throw new InvalidOperationException ( $ "Failed to retrieve an access token for { resource } ") ;
24+ throw new Exception ( $ "Failed to retrieve an access token for { resource } ") ;
2125 }
2226
2327 return result . AccessToken ;
2428 }
29+
30+ public static async Task < string > GetManagedIdentityToken ( string objectId ) =>
31+ await new MockManagedIdentityTokenProvider ( ) . AcquireTokenAsync ( objectId ) . ConfigureAwait ( false ) ;
32+
33+ }
34+
35+ #region Mock Managed Identity Token Provider
36+ internal class MockManagedIdentityTokenProvider
37+ {
38+ // HttpClient is intended to be instantiated once and re-used throughout the life of an application.
39+ #if NETFRAMEWORK
40+ private static readonly HttpClient s_defaultHttpClient = new HttpClient ( ) ;
41+ #else
42+ private static readonly HttpClient s_defaultHttpClient = new HttpClient ( new HttpClientHandler ( ) { CheckCertificateRevocationList = true } ) ;
43+ #endif
44+
45+ private const string AzureVmImdsApiVersion = "&api-version=2018-02-01" ;
46+ private const string AccessToken = "access_token" ;
47+ private const string Resource = "https://database.windows.net" ;
48+
49+
50+ private const int DefaultRetryTimeout = 0 ;
51+ private const int DefaultMaxRetryCount = 5 ;
52+
53+ // Azure Instance Metadata Service (IMDS) endpoint
54+ private const string AzureVmImdsEndpoint = "http://169.254.169.254/metadata/identity/oauth2/token" ;
55+
56+ // Timeout for Azure IMDS probe request
57+ internal const int AzureVmImdsProbeTimeoutInSeconds = 2 ;
58+ internal readonly TimeSpan _azureVmImdsProbeTimeout = TimeSpan . FromSeconds ( AzureVmImdsProbeTimeoutInSeconds ) ;
59+
60+ // Configurable timeout for MSI retry logic
61+ internal readonly int _retryTimeoutInSeconds = DefaultRetryTimeout ;
62+ internal readonly int _maxRetryCount = DefaultMaxRetryCount ;
63+
64+ public async Task < string > AcquireTokenAsync ( string objectId = null )
65+ {
66+ // Use the httpClient specified in the constructor. If it was not specified in the constructor, use the default httpClient.
67+ HttpClient httpClient = s_defaultHttpClient ;
68+
69+ try
70+ {
71+ // If user assigned managed identity is specified, include object ID parameter in request
72+ string objectIdParameter = objectId != default
73+ ? $ "&object_id={ objectId } "
74+ : string . Empty ;
75+
76+ // Craft request as per the MSI protocol
77+ var requestUrl = $ "{ AzureVmImdsEndpoint } ?resource={ Resource } { objectIdParameter } { AzureVmImdsApiVersion } ";
78+
79+ HttpResponseMessage response = null ;
80+
81+ try
82+ {
83+ response = await httpClient . SendAsyncWithRetry ( getRequestMessage , _retryTimeoutInSeconds , _maxRetryCount , default ) . ConfigureAwait ( false ) ;
84+ HttpRequestMessage getRequestMessage ( )
85+ {
86+ HttpRequestMessage request = new HttpRequestMessage ( HttpMethod . Get , requestUrl ) ;
87+ request . Headers . Add ( "Metadata" , "true" ) ;
88+ return request ;
89+ }
90+ }
91+ catch ( HttpRequestException )
92+ {
93+ // Not throwing exception if Access Token cannot be fetched. Tests will be disabled.
94+ return null ;
95+ }
96+
97+ // If the response is successful, it should have JSON response with an access_token field
98+ if ( response . IsSuccessStatusCode )
99+ {
100+ string jsonResponse = await response . Content . ReadAsStringAsync ( ) . ConfigureAwait ( false ) ;
101+ int accessTokenStartIndex = jsonResponse . IndexOf ( AccessToken ) + AccessToken . Length + 3 ;
102+ return jsonResponse . Substring ( accessTokenStartIndex , jsonResponse . IndexOf ( '"' , accessTokenStartIndex ) - accessTokenStartIndex ) ;
103+ }
104+
105+ // RetryFailure : Failed after 5 retries.
106+ // NonRetryableError : Received a non-retryable error.
107+ string errorStatusDetail = response . IsRetryableStatusCode ( )
108+ ? "Failed after 5 retries"
109+ : "Received a non-retryable error." ;
110+
111+ string errorText = await response . Content . ReadAsStringAsync ( ) . ConfigureAwait ( false ) ;
112+
113+ // Not throwing exception if Access Token cannot be fetched. Tests will be disabled.
114+ return null ;
115+ }
116+ catch ( Exception )
117+ {
118+ // Not throwing exception if Access Token cannot be fetched. Tests will be disabled.
119+ return null ;
120+ }
121+ }
25122 }
123+
124+ #region IMDS Retry Helper
125+ internal static class SqlManagedIdentityRetryHelper
126+ {
127+ internal const int DeltaBackOffInSeconds = 2 ;
128+ internal const string RetryTimeoutError = "Reached retry timeout limit set by MsiRetryTimeout parameter in connection string." ;
129+
130+ // for unit test purposes
131+ internal static bool s_waitBeforeRetry = true ;
132+
133+ internal static bool IsRetryableStatusCode ( this HttpResponseMessage response )
134+ {
135+ // 404 NotFound, 429 TooManyRequests, and 5XX server error status codes are retryable
136+ return Regex . IsMatch ( ( ( int ) response . StatusCode ) . ToString ( ) , @"404|429|5\d{2}" ) ;
137+ }
138+
139+ /// <summary>
140+ /// Implements recommended retry guidance here: https://docs.microsoft.com/en-us/azure/active-directory/managed-identities-azure-resources/how-to-use-vm-token#retry-guidance
141+ /// </summary>
142+ internal static async Task < HttpResponseMessage > SendAsyncWithRetry ( this HttpClient httpClient , Func < HttpRequestMessage > getRequest , int retryTimeoutInSeconds , int maxRetryCount , CancellationToken cancellationToken )
143+ {
144+ using ( var timeoutTokenSource = new CancellationTokenSource ( ) )
145+ using ( var linkedTokenSource = CancellationTokenSource . CreateLinkedTokenSource ( timeoutTokenSource . Token , cancellationToken ) )
146+ {
147+ try
148+ {
149+ // if retry timeout is configured, configure cancellation after timeout period elapses
150+ if ( retryTimeoutInSeconds > 0 )
151+ {
152+ timeoutTokenSource . CancelAfter ( TimeSpan . FromSeconds ( retryTimeoutInSeconds ) ) ;
153+ }
154+
155+ var attempts = 0 ;
156+ var backoffTimeInSecs = 0 ;
157+ HttpResponseMessage response ;
158+
159+ while ( true )
160+ {
161+ attempts ++ ;
162+
163+ try
164+ {
165+ response = await httpClient . SendAsync ( getRequest ( ) , linkedTokenSource . Token ) . ConfigureAwait ( false ) ;
166+
167+ if ( response . IsSuccessStatusCode || ! response . IsRetryableStatusCode ( ) || attempts == maxRetryCount )
168+ {
169+ break ;
170+ }
171+ }
172+ catch ( HttpRequestException )
173+ {
174+ if ( attempts == maxRetryCount )
175+ {
176+ throw ;
177+ }
178+ }
179+
180+ if ( s_waitBeforeRetry )
181+ {
182+ // use recommended exponential backoff strategy, and use linked token wait handle so caller or retry timeout is still able to cancel
183+ backoffTimeInSecs += ( int ) Math . Pow ( DeltaBackOffInSeconds , attempts ) ;
184+ linkedTokenSource . Token . WaitHandle . WaitOne ( TimeSpan . FromSeconds ( backoffTimeInSecs ) ) ;
185+ linkedTokenSource . Token . ThrowIfCancellationRequested ( ) ;
186+ }
187+ }
188+
189+ return response ;
190+ }
191+ catch ( OperationCanceledException )
192+ {
193+ if ( timeoutTokenSource . IsCancellationRequested )
194+ {
195+ throw new TimeoutException ( RetryTimeoutError ) ;
196+ }
197+
198+ throw ;
199+ }
200+ }
201+ }
202+ }
203+ #endregion
204+ #endregion
26205}
206+
0 commit comments