diff --git a/sdk/mgmtcommon/AppAuthentication/Azure.Services.AppAuthentication.Unit.Tests/Mocks/MockMsi.cs b/sdk/mgmtcommon/AppAuthentication/Azure.Services.AppAuthentication.Unit.Tests/Mocks/MockMsi.cs index 30c60f7b57d7e..342450e042244 100644 --- a/sdk/mgmtcommon/AppAuthentication/Azure.Services.AppAuthentication.Unit.Tests/Mocks/MockMsi.cs +++ b/sdk/mgmtcommon/AppAuthentication/Azure.Services.AppAuthentication.Unit.Tests/Mocks/MockMsi.cs @@ -33,7 +33,7 @@ internal enum MsiTestType MsiAppJsonParseFailure, MsiMissingToken, MsiAppServicesIncorrectRequest, - MsiAzureVmTimeout, + MsiAzureVmImdsTimeout, MsiUnresponsive, MsiThrottled, MsiTransientServerError @@ -41,20 +41,23 @@ internal enum MsiTestType private readonly MsiTestType _msiTestType; + private const string _azureVmImdsInstanceEndpoint = "http://169.254.169.254/metadata/instance"; + private const string _azureVmImdsTokenEndpoint = "http://169.254.169.254/metadata/identity/oauth2/token"; + internal MockMsi(MsiTestType msiTestType) { _msiTestType = msiTestType; } /// - /// Returns a response based on the response type. + /// Returns a response based on the response type. /// /// /// /// protected override Task SendAsync(HttpRequestMessage request, CancellationToken cancellationToken) { - // HitCount is updated when this method gets called. This allows for testing of cache and retry logic. + // HitCount is updated when this method gets called. This allows for testing of cache and retry logic. HitCount++; HttpResponseMessage responseMessage = null; @@ -138,16 +141,20 @@ protected override Task SendAsync(HttpRequestMessage reques }; break; - case MsiTestType.MsiAzureVmTimeout: - var start = DateTime.Now; - while(DateTime.Now - start < TimeSpan.FromSeconds(MsiAccessTokenProvider.AzureVmImdsProbeTimeoutInSeconds + 10)) + case MsiTestType.MsiAzureVmImdsTimeout: + if (request.RequestUri.AbsoluteUri.StartsWith(_azureVmImdsInstanceEndpoint)) { - if (cancellationToken.IsCancellationRequested) + var start = DateTime.Now; + while (DateTime.Now - start < TimeSpan.FromSeconds(MsiAccessTokenProvider.AzureVmImdsProbeTimeoutInSeconds + 10)) { - throw new TaskCanceledException(); + if (cancellationToken.IsCancellationRequested) + { + throw new TaskCanceledException(); + } } + throw new Exception("Test fail"); } - throw new Exception("Test fail"); + break; case MsiTestType.MsiUnresponsive: case MsiTestType.MsiThrottled: @@ -167,7 +174,20 @@ protected override Task SendAsync(HttpRequestMessage reques // give error based on test type if (_msiTestType == MsiTestType.MsiUnresponsive) { - throw new HttpRequestException(); + if (request.RequestUri.AbsoluteUri.StartsWith(_azureVmImdsInstanceEndpoint)) + { + responseMessage = new HttpResponseMessage + { + Content = new StringContent(TokenHelper.GetInstanceMetadataResponse(), + Encoding.UTF8, + Constants.JsonContentType) + }; + } + else if (Environment.GetEnvironmentVariable(Constants.MsiAppServiceEndpointEnv) != null + || request.RequestUri.AbsoluteUri.StartsWith(_azureVmImdsTokenEndpoint)) + { + throw new HttpRequestException(); + } } else { diff --git a/sdk/mgmtcommon/AppAuthentication/Azure.Services.AppAuthentication.Unit.Tests/MsiAccessTokenProviderTests.cs b/sdk/mgmtcommon/AppAuthentication/Azure.Services.AppAuthentication.Unit.Tests/MsiAccessTokenProviderTests.cs index 060e2bc243d11..b49df6af416d7 100644 --- a/sdk/mgmtcommon/AppAuthentication/Azure.Services.AppAuthentication.Unit.Tests/MsiAccessTokenProviderTests.cs +++ b/sdk/mgmtcommon/AppAuthentication/Azure.Services.AppAuthentication.Unit.Tests/MsiAccessTokenProviderTests.cs @@ -14,7 +14,7 @@ namespace Microsoft.Azure.Services.AppAuthentication.Unit.Tests { /// - /// Test cases for MsiAccessTokenProvider class. MsiAccessTokenProvider is an internal class. + /// Test cases for MsiAccessTokenProvider class. MsiAccessTokenProvider is an internal class. /// public class MsiAccessTokenProviderTests : IDisposable { @@ -48,7 +48,7 @@ public async Task GetTokenUsingManagedIdentityAzureVm(bool specifyUserAssignedMa expectedAppId = Constants.TestAppId; } - // MockMsi is being asked to act like response from Azure VM MSI succeeded. + // MockMsi is being asked to act like response from Azure VM MSI succeeded. MockMsi mockMsi = new MockMsi(msiTestType); HttpClient httpClient = new HttpClient(mockMsi); MsiAccessTokenProvider msiAccessTokenProvider = new MsiAccessTokenProvider(httpClient, managedIdentityClientId: managedIdentityArgument); @@ -56,18 +56,18 @@ public async Task GetTokenUsingManagedIdentityAzureVm(bool specifyUserAssignedMa // Get token. var authResult = await msiAccessTokenProvider.GetAuthResultAsync(Constants.KeyVaultResourceId, Constants.TenantId).ConfigureAwait(false); - // Check if the principalused and type are as expected. + // Check if the principalused and type are as expected. Validator.ValidateToken(authResult.AccessToken, msiAccessTokenProvider.PrincipalUsed, Constants.AppType, Constants.TenantId, expectedAppId, expiresOn: authResult.ExpiresOn); } /// - /// If json parse error when aquiring token, an exception should be thrown. + /// If json parse error when aquiring token, an exception should be thrown. /// /// [Fact] public async Task ParseErrorMsiGetTokenTest() { - // MockMsi is being asked to act like response from Azure VM MSI suceeded. + // MockMsi is being asked to act like response from Azure VM MSI suceeded. MockMsi mockMsi = new MockMsi(MockMsi.MsiTestType.MsiAppJsonParseFailure); HttpClient httpClient = new HttpClient(mockMsi); MsiAccessTokenProvider msiAccessTokenProvider = new MsiAccessTokenProvider(httpClient); @@ -80,13 +80,13 @@ public async Task ParseErrorMsiGetTokenTest() } /// - /// If MSI response if missing the token, an exception should be thrown. + /// If MSI response if missing the token, an exception should be thrown. /// /// [Fact] public async Task MsiResponseMissingTokenTest() { - // MockMsi is being asked to act like response from Azure VM MSI failed. + // MockMsi is being asked to act like response from Azure VM MSI failed. MockMsi mockMsi = new MockMsi(MockMsi.MsiTestType.MsiMissingToken); HttpClient httpClient = new HttpClient(mockMsi); MsiAccessTokenProvider msiAccessTokenProvider = new MsiAccessTokenProvider(httpClient); @@ -103,7 +103,7 @@ public async Task MsiResponseMissingTokenTest() [InlineData(false)] public async Task GetTokenUsingManagedIdentityAppServices(bool specifyUserAssignedManagedIdentity) { - // Setup the environment variables that App Service MSI would setup. + // Setup the environment variables that App Service MSI would setup. Environment.SetEnvironmentVariable(Constants.MsiAppServiceEndpointEnv, Constants.MsiEndpoint); Environment.SetEnvironmentVariable(Constants.MsiAppServiceHeaderEnv, Constants.ClientSecret); @@ -125,19 +125,19 @@ public async Task GetTokenUsingManagedIdentityAppServices(bool specifyUserAssign expectedAppId = Constants.TestAppId; } - // MockMsi is being asked to act like response from App Service MSI suceeded. + // MockMsi is being asked to act like response from App Service MSI suceeded. MockMsi mockMsi = new MockMsi(msiTestType); HttpClient httpClient = new HttpClient(mockMsi); MsiAccessTokenProvider msiAccessTokenProvider = new MsiAccessTokenProvider(httpClient, managedIdentityClientId: managedIdentityArgument); - // Get token. This confirms that the environment variables are being read. + // Get token. This confirms that the environment variables are being read. var authResult = await msiAccessTokenProvider.GetAuthResultAsync(Constants.KeyVaultResourceId, Constants.TenantId).ConfigureAwait(false); Validator.ValidateToken(authResult.AccessToken, msiAccessTokenProvider.PrincipalUsed, Constants.AppType, Constants.TenantId, expectedAppId, expiresOn: authResult.ExpiresOn); } /// - /// Test response when IDENTITY_HEADER in AppServices MSI is invalid. + /// Test response when IDENTITY_HEADER in AppServices MSI is invalid. /// /// [Fact] @@ -147,7 +147,7 @@ public async Task UnauthorizedTest() Environment.SetEnvironmentVariable(Constants.MsiAppServiceEndpointEnv, Constants.MsiEndpoint); Environment.SetEnvironmentVariable(Constants.MsiAppServiceHeaderEnv, Constants.ClientSecret); - // MockMsi is being asked to act like response from App Service MSI failed (unauthorized). + // MockMsi is being asked to act like response from App Service MSI failed (unauthorized). MockMsi mockMsi = new MockMsi(MockMsi.MsiTestType.MsiAppServicesUnauthorized); HttpClient httpClient = new HttpClient(mockMsi); MsiAccessTokenProvider msiAccessTokenProvider = new MsiAccessTokenProvider(httpClient); @@ -159,7 +159,7 @@ public async Task UnauthorizedTest() } /// - /// Test that response when MSI request is not valid is as expected. + /// Test that response when MSI request is not valid is as expected. /// /// [Fact] @@ -180,7 +180,7 @@ public async Task IncorrectFormatTest() } /// - /// If an unexpected http response has been received, ensure exception is thrown. + /// If an unexpected http response has been received, ensure exception is thrown. /// /// [Fact] @@ -208,13 +208,13 @@ public async Task AzureVmImdsTimeoutTest() Environment.SetEnvironmentVariable(Constants.MsiAppServiceEndpointEnv, null); Environment.SetEnvironmentVariable(Constants.MsiAppServiceHeaderEnv, null); - MockMsi mockMsi = new MockMsi(MockMsi.MsiTestType.MsiAzureVmTimeout); + MockMsi mockMsi = new MockMsi(MockMsi.MsiTestType.MsiAzureVmImdsTimeout); HttpClient httpClient = new HttpClient(mockMsi); MsiAccessTokenProvider msiAccessTokenProvider = new MsiAccessTokenProvider(httpClient); var exception = await Assert.ThrowsAsync(() => Task.Run(() => msiAccessTokenProvider.GetAuthResultAsync(Constants.KeyVaultResourceId, Constants.TenantId))); - Assert.Contains(AzureServiceTokenProviderException.MsiEndpointNotListening, exception.Message); + Assert.Contains(AzureServiceTokenProviderException.MetadataEndpointNotListening, exception.Message); Assert.DoesNotContain(AzureServiceTokenProviderException.RetryFailure, exception.Message); } @@ -224,7 +224,7 @@ public async Task AzureVmImdsTimeoutTest() [InlineData(MockMsi.MsiTestType.MsiTransientServerError)] internal async Task TransientErrorRetryTest(MockMsi.MsiTestType testType) { - // To simplify tests, mock as MSI App Services to skip Azure VM IDMS probe request by + // To simplify tests, mock as MSI App Services to skip Azure VM IDMS probe request by Environment.SetEnvironmentVariable(Constants.MsiAppServiceEndpointEnv, Constants.MsiEndpoint); Environment.SetEnvironmentVariable(Constants.MsiAppServiceHeaderEnv, Constants.ClientSecret); @@ -254,12 +254,17 @@ internal async Task TransientErrorRetryTest(MockMsi.MsiTestType testType) } } - [Fact] - private async Task MsiRetryTimeoutTest() + [Theory] + [InlineData(false)] + [InlineData(true)] + internal async Task MsiRetryTimeoutTest(bool isAppServices) { - // To simplify tests, mock as MSI App Services to skip Azure VM IDMS probe request - Environment.SetEnvironmentVariable(Constants.MsiAppServiceEndpointEnv, Constants.MsiEndpoint); - Environment.SetEnvironmentVariable(Constants.MsiAppServiceHeaderEnv, Constants.ClientSecret); + if (isAppServices) + { + // Mock as MSI App Services to skip Azure VM IDMS probe request + Environment.SetEnvironmentVariable(Constants.MsiAppServiceEndpointEnv, Constants.MsiEndpoint); + Environment.SetEnvironmentVariable(Constants.MsiAppServiceHeaderEnv, Constants.ClientSecret); + } int timeoutInSeconds = (new Random()).Next(1, 4); @@ -291,11 +296,11 @@ private async Task AppServicesDifferentCultureTest() // ensure thread culture is NOT using en-US culture (App Services MSI endpoint always uses en-US DateTime format) Thread.CurrentThread.CurrentCulture = new CultureInfo("en-GB"); - // Setup the environment variables that App Service MSI would setup. + // Setup the environment variables that App Service MSI would setup. Environment.SetEnvironmentVariable(Constants.MsiAppServiceEndpointEnv, Constants.MsiEndpoint); Environment.SetEnvironmentVariable(Constants.MsiAppServiceHeaderEnv, Constants.ClientSecret); - // MockMsi is being asked to act like response from App Service MSI suceeded. + // MockMsi is being asked to act like response from App Service MSI suceeded. MockMsi mockMsi = new MockMsi(MockMsi.MsiTestType.MsiAppServicesSuccess); HttpClient httpClient = new HttpClient(mockMsi); MsiAccessTokenProvider msiAccessTokenProvider = new MsiAccessTokenProvider(httpClient); diff --git a/sdk/mgmtcommon/AppAuthentication/Azure.Services.AppAuthentication.Unit.Tests/TokenHelper.cs b/sdk/mgmtcommon/AppAuthentication/Azure.Services.AppAuthentication.Unit.Tests/TokenHelper.cs index 7e3f5ec2b7082..49f4add3e2d9a 100644 --- a/sdk/mgmtcommon/AppAuthentication/Azure.Services.AppAuthentication.Unit.Tests/TokenHelper.cs +++ b/sdk/mgmtcommon/AppAuthentication/Azure.Services.AppAuthentication.Unit.Tests/TokenHelper.cs @@ -12,7 +12,7 @@ namespace Microsoft.Azure.Services.AppAuthentication.Unit.Tests public class TokenHelper { /// - /// The hardcoded user token has expiry replaced by [exp], so we can replace it with some value to test functionality. + /// The hardcoded user token has expiry replaced by [exp], so we can replace it with some value to test functionality. /// /// /// @@ -30,7 +30,7 @@ private static string UpdateTokenTime(string accessToken, long secondsFromCurren internal static string GetUserToken() { - // Gets a user token that will expire in 10 seconds from now. + // Gets a user token that will expire in 10 seconds from now. return GetUserToken(10); } @@ -63,6 +63,16 @@ internal static string GetUserTokenResponse(long secondsFromCurrent, bool format return tokenResult; } + /// + /// Sample IMDS /instance response + /// + /// + internal static string GetInstanceMetadataResponse() + { + return + "{\"compute\":{\"location\":\"westus\",\"name\":\"TestBedVm\",\"resourceGroupName\":\"testbed\",\"subscriptionId\":\"bdd789f3-d9d1-4bea-ac14-30a39ed66d33\"}}"; + } + /// /// The response has claims as expected from App Service MSI response /// @@ -128,7 +138,7 @@ internal static string GetInvalidMsiTokenResponse() } /// - /// The response has claims as expected from Client Credentials flow response. + /// The response has claims as expected from Client Credentials flow response. /// /// internal static string GetAppToken() @@ -139,7 +149,7 @@ internal static string GetAppToken() } /// - /// Invalid AppToken. + /// Invalid AppToken. /// /// internal static string GetInvalidAppToken() diff --git a/sdk/mgmtcommon/AppAuthentication/Azure.Services.AppAuthentication/AzureServiceTokenProviderException.cs b/sdk/mgmtcommon/AppAuthentication/Azure.Services.AppAuthentication/AzureServiceTokenProviderException.cs index 5f563a24a21c7..9b1ade57a66c2 100644 --- a/sdk/mgmtcommon/AppAuthentication/Azure.Services.AppAuthentication/AzureServiceTokenProviderException.cs +++ b/sdk/mgmtcommon/AppAuthentication/Azure.Services.AppAuthentication/AzureServiceTokenProviderException.cs @@ -9,7 +9,7 @@ namespace Microsoft.Azure.Services.AppAuthentication { /// - /// Instance of this exception is thrown if access token cannot be acquired. + /// Instance of this exception is thrown if access token cannot be acquired. /// #if FullNetFx || NETSTANDARD2_0 [Serializable] @@ -17,6 +17,8 @@ namespace Microsoft.Azure.Services.AppAuthentication public class AzureServiceTokenProviderException : Exception { + internal const string MetadataEndpointNotListening = "Unable to connect to the Instance Metadata Service (IMDS). Skipping request to the Managed Service Identity (MSI) token endpoint."; + internal const string MsiEndpointNotListening = "Unable to connect to the Managed Service Identity (MSI) endpoint. Please check that you are running on an Azure resource that has MSI setup."; internal const string UnableToParseMsiTokenResponse = "A successful response was received from Managed Service Identity, but it could not be parsed."; @@ -42,7 +44,7 @@ public class AzureServiceTokenProviderException : Exception internal const string NonRetryableError = "Received a non-retryable error."; /// - /// Creates an instance of AzureServiceTokenProviderException. + /// Creates an instance of AzureServiceTokenProviderException. /// /// Connection string used. /// Resource for which token was expected. diff --git a/sdk/mgmtcommon/AppAuthentication/Azure.Services.AppAuthentication/TokenProviders/MsiAccessTokenProvider.cs b/sdk/mgmtcommon/AppAuthentication/Azure.Services.AppAuthentication/TokenProviders/MsiAccessTokenProvider.cs index 6a3fd99a18865..adb65c9546047 100644 --- a/sdk/mgmtcommon/AppAuthentication/Azure.Services.AppAuthentication/TokenProviders/MsiAccessTokenProvider.cs +++ b/sdk/mgmtcommon/AppAuthentication/Azure.Services.AppAuthentication/TokenProviders/MsiAccessTokenProvider.cs @@ -10,7 +10,7 @@ namespace Microsoft.Azure.Services.AppAuthentication { /// - /// Gets a token using Azure VM or App Services MSI. + /// Gets a token using Azure VM or App Services MSI. /// https://docs.microsoft.com/en-us/azure/active-directory/msi-overview /// internal class MsiAccessTokenProvider : NonInteractiveAzureServiceTokenProviderBase @@ -21,7 +21,7 @@ internal class MsiAccessTokenProvider : NonInteractiveAzureServiceTokenProviderB // This client ID can be specified in the constructor to specify a specific managed identity to use (e.g. user-assigned identity) private readonly string _managedIdentityClientId; - // HttpClient is intended to be instantiated once and re-used throughout the life of an application. + // HttpClient is intended to be instantiated once and re-used throughout the life of an application. #if NETSTANDARD1_4 || net452 || net461 private static readonly HttpClient DefaultHttpClient = new HttpClient(); #else @@ -29,10 +29,14 @@ internal class MsiAccessTokenProvider : NonInteractiveAzureServiceTokenProviderB #endif // Azure Instance Metadata Service (IMDS) endpoint - private const string AzureVmImdsEndpoint = "http://169.254.169.254/metadata/identity/oauth2/token"; + private const string AzureVmImdsEndpoint = "http://169.254.169.254"; + private const string ImdsInstanceRoute = "/metadata/instance"; + private const string ImdsTokenRoute = "/metadata/identity/oauth2/token"; + private const string ImdsInstanceApiVersion = "2020-06-01"; + private const string ImdsTokenApiVersion = "2019-11-01"; // Timeout for Azure IMDS probe request - internal const int AzureVmImdsProbeTimeoutInSeconds = 2; + internal const int AzureVmImdsProbeTimeoutInSeconds = 3; internal readonly TimeSpan AzureVmImdsProbeTimeout = TimeSpan.FromSeconds(AzureVmImdsProbeTimeoutInSeconds); // Configurable timeout for MSI retry logic @@ -61,12 +65,12 @@ internal MsiAccessTokenProvider(HttpClient httpClient, int retryTimeoutInSeconds public override async Task GetAuthResultAsync(string resource, string authority, CancellationToken cancellationToken = default) { - // Use the httpClient specified in the constructor. If it was not specified in the constructor, use the default httpClient. + // Use the httpClient specified in the constructor. If it was not specified in the constructor, use the default httpClient. HttpClient httpClient = _httpClient ?? DefaultHttpClient; try { - // Check if App Services MSI is available. If both these environment variables are set, then it is. + // Check if App Services MSI is available. If both these environment variables are set, then it is. string msiEndpoint = Environment.GetEnvironmentVariable("IDENTITY_ENDPOINT"); string msiHeader = Environment.GetEnvironmentVariable("IDENTITY_HEADER"); var isAppServicesMsiAvailable = !string.IsNullOrWhiteSpace(msiEndpoint) && !string.IsNullOrWhiteSpace(msiHeader); @@ -77,7 +81,8 @@ public override async Task GetAuthResultAsync(string re using (var internalTokenSource = new CancellationTokenSource()) using (var linkedTokenSource = CancellationTokenSource.CreateLinkedTokenSource(internalTokenSource.Token, cancellationToken)) { - HttpRequestMessage imdsProbeRequest = new HttpRequestMessage(HttpMethod.Get, AzureVmImdsEndpoint); + string probeRequestUrl = $"{AzureVmImdsEndpoint}{ImdsInstanceRoute}?api-version={ImdsInstanceApiVersion}"; + HttpRequestMessage imdsProbeRequest = new HttpRequestMessage(HttpMethod.Get, probeRequestUrl); try { @@ -90,7 +95,7 @@ public override async Task GetAuthResultAsync(string re if (internalTokenSource.Token.IsCancellationRequested) { throw new AzureServiceTokenProviderException(ConnectionString, resource, authority, - $"{AzureServiceTokenProviderException.ManagedServiceIdentityUsed} {AzureServiceTokenProviderException.MsiEndpointNotListening}"); + $"{AzureServiceTokenProviderException.ManagedServiceIdentityUsed} {AzureServiceTokenProviderException.MetadataEndpointNotListening}"); } throw; @@ -106,7 +111,7 @@ public override async Task GetAuthResultAsync(string re // Craft request as per the MSI protocol var requestUrl = isAppServicesMsiAvailable ? $"{msiEndpoint}?resource={resource}{clientIdParameter}&api-version=2019-08-01" - : $"{AzureVmImdsEndpoint}?resource={resource}{clientIdParameter}&api-version=2018-02-01"; + : $"{AzureVmImdsEndpoint}{ImdsTokenRoute}?resource={resource}{clientIdParameter}&api-version={ImdsTokenApiVersion}"; Func getRequestMessage = () => { @@ -153,7 +158,7 @@ public override async Task GetAuthResultAsync(string re PrincipalUsed.AppId = token.AppId; PrincipalUsed.TenantId = token.TenantId; } - + return AppAuthenticationResult.Create(tokenResponse); }