Skip to content

Commit 569470a

Browse files
Tests for Managed Identity - All pipelines updated
1 parent 2e815e4 commit 569470a

File tree

6 files changed

+336
-3
lines changed

6 files changed

+336
-3
lines changed

src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/AzureManagedIdentityAuthenticationProvider.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,7 @@ public override bool IsSupported(SqlAuthenticationMethod authentication)
168168
}
169169
}
170170

171+
#region IMDS Retry Helper
171172
internal static class SqlManagedIdentityRetryHelper
172173
{
173174
internal const int DeltaBackOffInSeconds = 2;
@@ -246,4 +247,5 @@ internal static async Task<HttpResponseMessage> SendAsyncWithRetry(this HttpClie
246247
}
247248
}
248249
}
250+
#endregion
249251
}

src/Microsoft.Data.SqlClient/tests/ManualTests/DataCommon/AADUtility.cs

Lines changed: 181 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,10 @@
33
// See the LICENSE file in the project root for more information.
44

55
using System;
6+
using System.Diagnostics;
7+
using System.Net.Http;
8+
using System.Text.RegularExpressions;
9+
using System.Threading;
610
using System.Threading.Tasks;
711
using Microsoft.IdentityModel.Clients.ActiveDirectory;
812

@@ -17,10 +21,186 @@ public static async Task<string> AzureActiveDirectoryAuthenticationCallback(stri
1721
AuthenticationResult result = await authContext.AcquireTokenAsync(resource, clientCred);
1822
if (result == null)
1923
{
20-
throw new InvalidOperationException($"Failed to retrieve an access token for {resource}");
24+
throw new Exception($"Failed to retrieve an access token for {resource}");
2125
}
2226

2327
return result.AccessToken;
2428
}
29+
30+
public static async Task<string> GetManagedIdentityToken(string objectId) =>
31+
await new MockManagedIdentityTokenProvider().AcquireTokenAsync(objectId).ConfigureAwait(false);
32+
33+
}
34+
35+
#region Mock Managed Identity Token Provider
36+
internal class MockManagedIdentityTokenProvider
37+
{
38+
// HttpClient is intended to be instantiated once and re-used throughout the life of an application.
39+
#if NETFRAMEWORK
40+
private static readonly HttpClient s_defaultHttpClient = new HttpClient();
41+
#else
42+
private static readonly HttpClient s_defaultHttpClient = new HttpClient(new HttpClientHandler() { CheckCertificateRevocationList = true });
43+
#endif
44+
45+
private const string AzureVmImdsApiVersion = "&api-version=2018-02-01";
46+
private const string AccessToken = "access_token";
47+
private const string Resource = "https://database.windows.net";
48+
49+
50+
private const int DefaultRetryTimeout = 0;
51+
private const int DefaultMaxRetryCount = 5;
52+
53+
// Azure Instance Metadata Service (IMDS) endpoint
54+
private const string AzureVmImdsEndpoint = "http://169.254.169.254/metadata/identity/oauth2/token";
55+
56+
// Timeout for Azure IMDS probe request
57+
internal const int AzureVmImdsProbeTimeoutInSeconds = 2;
58+
internal readonly TimeSpan _azureVmImdsProbeTimeout = TimeSpan.FromSeconds(AzureVmImdsProbeTimeoutInSeconds);
59+
60+
// Configurable timeout for MSI retry logic
61+
internal readonly int _retryTimeoutInSeconds = DefaultRetryTimeout;
62+
internal readonly int _maxRetryCount = DefaultMaxRetryCount;
63+
64+
public async Task<string> AcquireTokenAsync(string objectId = null)
65+
{
66+
// Use the httpClient specified in the constructor. If it was not specified in the constructor, use the default httpClient.
67+
HttpClient httpClient = s_defaultHttpClient;
68+
69+
try
70+
{
71+
// If user assigned managed identity is specified, include object ID parameter in request
72+
string objectIdParameter = objectId != default
73+
? $"&object_id={objectId}"
74+
: string.Empty;
75+
76+
// Craft request as per the MSI protocol
77+
var requestUrl = $"{AzureVmImdsEndpoint}?resource={Resource}{objectIdParameter}{AzureVmImdsApiVersion}";
78+
79+
HttpResponseMessage response = null;
80+
81+
try
82+
{
83+
response = await httpClient.SendAsyncWithRetry(getRequestMessage, _retryTimeoutInSeconds, _maxRetryCount, default).ConfigureAwait(false);
84+
HttpRequestMessage getRequestMessage()
85+
{
86+
HttpRequestMessage request = new HttpRequestMessage(HttpMethod.Get, requestUrl);
87+
request.Headers.Add("Metadata", "true");
88+
return request;
89+
}
90+
}
91+
catch (HttpRequestException)
92+
{
93+
// Not throwing exception if Access Token cannot be fetched. Tests will be disabled.
94+
return null;
95+
}
96+
97+
// If the response is successful, it should have JSON response with an access_token field
98+
if (response.IsSuccessStatusCode)
99+
{
100+
string jsonResponse = await response.Content.ReadAsStringAsync().ConfigureAwait(false);
101+
int accessTokenStartIndex = jsonResponse.IndexOf(AccessToken) + AccessToken.Length + 3;
102+
return jsonResponse.Substring(accessTokenStartIndex, jsonResponse.IndexOf('"', accessTokenStartIndex) - accessTokenStartIndex);
103+
}
104+
105+
// RetryFailure : Failed after 5 retries.
106+
// NonRetryableError : Received a non-retryable error.
107+
string errorStatusDetail = response.IsRetryableStatusCode()
108+
? "Failed after 5 retries"
109+
: "Received a non-retryable error.";
110+
111+
string errorText = await response.Content.ReadAsStringAsync().ConfigureAwait(false);
112+
113+
// Not throwing exception if Access Token cannot be fetched. Tests will be disabled.
114+
return null;
115+
}
116+
catch (Exception)
117+
{
118+
// Not throwing exception if Access Token cannot be fetched. Tests will be disabled.
119+
return null;
120+
}
121+
}
25122
}
123+
124+
#region IMDS Retry Helper
125+
internal static class SqlManagedIdentityRetryHelper
126+
{
127+
internal const int DeltaBackOffInSeconds = 2;
128+
internal const string RetryTimeoutError = "Reached retry timeout limit set by MsiRetryTimeout parameter in connection string.";
129+
130+
// for unit test purposes
131+
internal static bool s_waitBeforeRetry = true;
132+
133+
internal static bool IsRetryableStatusCode(this HttpResponseMessage response)
134+
{
135+
// 404 NotFound, 429 TooManyRequests, and 5XX server error status codes are retryable
136+
return Regex.IsMatch(((int)response.StatusCode).ToString(), @"404|429|5\d{2}");
137+
}
138+
139+
/// <summary>
140+
/// Implements recommended retry guidance here: https://docs.microsoft.com/en-us/azure/active-directory/managed-identities-azure-resources/how-to-use-vm-token#retry-guidance
141+
/// </summary>
142+
internal static async Task<HttpResponseMessage> SendAsyncWithRetry(this HttpClient httpClient, Func<HttpRequestMessage> getRequest, int retryTimeoutInSeconds, int maxRetryCount, CancellationToken cancellationToken)
143+
{
144+
using (var timeoutTokenSource = new CancellationTokenSource())
145+
using (var linkedTokenSource = CancellationTokenSource.CreateLinkedTokenSource(timeoutTokenSource.Token, cancellationToken))
146+
{
147+
try
148+
{
149+
// if retry timeout is configured, configure cancellation after timeout period elapses
150+
if (retryTimeoutInSeconds > 0)
151+
{
152+
timeoutTokenSource.CancelAfter(TimeSpan.FromSeconds(retryTimeoutInSeconds));
153+
}
154+
155+
var attempts = 0;
156+
var backoffTimeInSecs = 0;
157+
HttpResponseMessage response;
158+
159+
while (true)
160+
{
161+
attempts++;
162+
163+
try
164+
{
165+
response = await httpClient.SendAsync(getRequest(), linkedTokenSource.Token).ConfigureAwait(false);
166+
167+
if (response.IsSuccessStatusCode || !response.IsRetryableStatusCode() || attempts == maxRetryCount)
168+
{
169+
break;
170+
}
171+
}
172+
catch (HttpRequestException)
173+
{
174+
if (attempts == maxRetryCount)
175+
{
176+
throw;
177+
}
178+
}
179+
180+
if (s_waitBeforeRetry)
181+
{
182+
// use recommended exponential backoff strategy, and use linked token wait handle so caller or retry timeout is still able to cancel
183+
backoffTimeInSecs += (int)Math.Pow(DeltaBackOffInSeconds, attempts);
184+
linkedTokenSource.Token.WaitHandle.WaitOne(TimeSpan.FromSeconds(backoffTimeInSecs));
185+
linkedTokenSource.Token.ThrowIfCancellationRequested();
186+
}
187+
}
188+
189+
return response;
190+
}
191+
catch (OperationCanceledException)
192+
{
193+
if (timeoutTokenSource.IsCancellationRequested)
194+
{
195+
throw new TimeoutException(RetryTimeoutError);
196+
}
197+
198+
throw;
199+
}
200+
}
201+
}
202+
}
203+
#endregion
204+
#endregion
26205
}
206+

src/Microsoft.Data.SqlClient/tests/ManualTests/DataCommon/DataTestUtility.cs

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,12 +47,16 @@ public static class DataTestUtility
4747
public static readonly string DNSCachingServerTR = null; // this is for the tenant ring
4848
public static readonly bool IsDNSCachingSupportedCR = false; // this is for the control ring
4949
public static readonly bool IsDNSCachingSupportedTR = false; // this is for the tenant ring
50+
public static readonly string UserManagedIdentityObjectId = null;
5051

5152
public static readonly string EnclaveAzureDatabaseConnString = null;
52-
53+
public static bool ManagedIdentity = true;
5354
public static string AADAccessToken = null;
55+
public static string AADSystemIdentityAccessToken = null;
56+
public static string AADUserIdentityAccessToken = null;
5457
public const string UdtTestDbName = "UdtTestDb";
5558
public const string AKVKeyName = "TestSqlClientAzureKeyVaultProvider";
59+
5660
private const string ManagedNetworkingAppContextSwitch = "Switch.Microsoft.Data.SqlClient.UseManagedNetworkingOnWindows";
5761

5862
private static Dictionary<string, bool> AvailableDatabases;
@@ -83,6 +87,7 @@ static DataTestUtility()
8387
IsDNSCachingSupportedCR = c.IsDNSCachingSupportedCR;
8488
IsDNSCachingSupportedTR = c.IsDNSCachingSupportedTR;
8589
EnclaveAzureDatabaseConnString = c.EnclaveAzureDatabaseConnString;
90+
UserManagedIdentityObjectId = c.UserManagedIdentityObjectId;
8691

8792
System.Net.ServicePointManager.SecurityProtocol |= System.Net.SecurityProtocolType.Tls12;
8893

@@ -403,8 +408,39 @@ public static string GetAccessToken()
403408
return (null != AADAccessToken) ? new string(AADAccessToken.ToCharArray()) : null;
404409
}
405410

411+
public static string GetSystemIdentityAccessToken()
412+
{
413+
if (true == ManagedIdentity && null == AADSystemIdentityAccessToken && IsAADPasswordConnStrSetup())
414+
{
415+
AADSystemIdentityAccessToken = AADUtility.GetManagedIdentityToken(null).GetAwaiter().GetResult();
416+
if (AADSystemIdentityAccessToken == null)
417+
{
418+
ManagedIdentity = false;
419+
}
420+
}
421+
return (null != AADSystemIdentityAccessToken) ? new string(AADSystemIdentityAccessToken.ToCharArray()) : null;
422+
}
423+
424+
public static string GetUserIdentityAccessToken()
425+
{
426+
if (true == ManagedIdentity && null == AADUserIdentityAccessToken && IsAADPasswordConnStrSetup())
427+
{
428+
// Pass User Assigned Managed Identity Object Id here.
429+
AADUserIdentityAccessToken = AADUtility.GetManagedIdentityToken(UserManagedIdentityObjectId).GetAwaiter().GetResult();
430+
if (AADSystemIdentityAccessToken == null)
431+
{
432+
ManagedIdentity = false;
433+
}
434+
}
435+
return (null != AADUserIdentityAccessToken) ? new string(AADUserIdentityAccessToken.ToCharArray()) : null;
436+
}
437+
406438
public static bool IsAccessTokenSetup() => !string.IsNullOrEmpty(GetAccessToken());
407439

440+
public static bool IsSystemIdentityTokenSetup() => !string.IsNullOrEmpty(GetSystemIdentityAccessToken());
441+
442+
public static bool IsUserIdentityTokenSetup() => !string.IsNullOrEmpty(GetUserIdentityAccessToken());
443+
408444
public static bool IsFileStreamSetup() => SupportsFileStream;
409445

410446
private static bool CheckException<TException>(Exception ex, string exceptionMessage, bool innerExceptionMustBeNull) where TException : Exception

0 commit comments

Comments
 (0)