Skip to content

Commit

Permalink
Probe IMDS /metadata/instance endpoint during discovery in MsiAccessT…
Browse files Browse the repository at this point in the history
…okenProvider (#14631)

* Call imds instance endpoint instead of token endpoint to probe for imds, extend timeout

* Update msi token tests to distinguish between which imds endpoint is being called

* Fix comment and visibility in test

* Change per stpetrov

* re-trigger checks
  • Loading branch information
isaacbanner authored Sep 12, 2020
1 parent 6ae13e6 commit eae6c13
Show file tree
Hide file tree
Showing 5 changed files with 92 additions and 50 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,28 +33,31 @@ internal enum MsiTestType
MsiAppJsonParseFailure,
MsiMissingToken,
MsiAppServicesIncorrectRequest,
MsiAzureVmTimeout,
MsiAzureVmImdsTimeout,
MsiUnresponsive,
MsiThrottled,
MsiTransientServerError
}

private readonly MsiTestType _msiTestType;

private const string _azureVmImdsInstanceEndpoint = "http://169.254.169.254/metadata/instance";
private const string _azureVmImdsTokenEndpoint = "http://169.254.169.254/metadata/identity/oauth2/token";

internal MockMsi(MsiTestType msiTestType)
{
_msiTestType = msiTestType;
}

/// <summary>
/// Returns a response based on the response type.
/// Returns a response based on the response type.
/// </summary>
/// <param name="request"></param>
/// <param name="cancellationToken"></param>
/// <returns></returns>
protected override Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, CancellationToken cancellationToken)
{
// HitCount is updated when this method gets called. This allows for testing of cache and retry logic.
// HitCount is updated when this method gets called. This allows for testing of cache and retry logic.
HitCount++;

HttpResponseMessage responseMessage = null;
Expand Down Expand Up @@ -138,16 +141,20 @@ protected override Task<HttpResponseMessage> SendAsync(HttpRequestMessage reques
};
break;

case MsiTestType.MsiAzureVmTimeout:
var start = DateTime.Now;
while(DateTime.Now - start < TimeSpan.FromSeconds(MsiAccessTokenProvider.AzureVmImdsProbeTimeoutInSeconds + 10))
case MsiTestType.MsiAzureVmImdsTimeout:
if (request.RequestUri.AbsoluteUri.StartsWith(_azureVmImdsInstanceEndpoint))
{
if (cancellationToken.IsCancellationRequested)
var start = DateTime.Now;
while (DateTime.Now - start < TimeSpan.FromSeconds(MsiAccessTokenProvider.AzureVmImdsProbeTimeoutInSeconds + 10))
{
throw new TaskCanceledException();
if (cancellationToken.IsCancellationRequested)
{
throw new TaskCanceledException();
}
}
throw new Exception("Test fail");
}
throw new Exception("Test fail");
break;

case MsiTestType.MsiUnresponsive:
case MsiTestType.MsiThrottled:
Expand All @@ -167,7 +174,20 @@ protected override Task<HttpResponseMessage> SendAsync(HttpRequestMessage reques
// give error based on test type
if (_msiTestType == MsiTestType.MsiUnresponsive)
{
throw new HttpRequestException();
if (request.RequestUri.AbsoluteUri.StartsWith(_azureVmImdsInstanceEndpoint))
{
responseMessage = new HttpResponseMessage
{
Content = new StringContent(TokenHelper.GetInstanceMetadataResponse(),
Encoding.UTF8,
Constants.JsonContentType)
};
}
else if (Environment.GetEnvironmentVariable(Constants.MsiAppServiceEndpointEnv) != null
|| request.RequestUri.AbsoluteUri.StartsWith(_azureVmImdsTokenEndpoint))
{
throw new HttpRequestException();
}
}
else
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
namespace Microsoft.Azure.Services.AppAuthentication.Unit.Tests
{
/// <summary>
/// Test cases for MsiAccessTokenProvider class. MsiAccessTokenProvider is an internal class.
/// Test cases for MsiAccessTokenProvider class. MsiAccessTokenProvider is an internal class.
/// </summary>
public class MsiAccessTokenProviderTests : IDisposable
{
Expand Down Expand Up @@ -48,26 +48,26 @@ public async Task GetTokenUsingManagedIdentityAzureVm(bool specifyUserAssignedMa
expectedAppId = Constants.TestAppId;
}

// MockMsi is being asked to act like response from Azure VM MSI succeeded.
// MockMsi is being asked to act like response from Azure VM MSI succeeded.
MockMsi mockMsi = new MockMsi(msiTestType);
HttpClient httpClient = new HttpClient(mockMsi);
MsiAccessTokenProvider msiAccessTokenProvider = new MsiAccessTokenProvider(httpClient, managedIdentityClientId: managedIdentityArgument);

// Get token.
var authResult = await msiAccessTokenProvider.GetAuthResultAsync(Constants.KeyVaultResourceId, Constants.TenantId).ConfigureAwait(false);

// Check if the principalused and type are as expected.
// Check if the principalused and type are as expected.
Validator.ValidateToken(authResult.AccessToken, msiAccessTokenProvider.PrincipalUsed, Constants.AppType, Constants.TenantId, expectedAppId, expiresOn: authResult.ExpiresOn);
}

/// <summary>
/// If json parse error when aquiring token, an exception should be thrown.
/// If json parse error when aquiring token, an exception should be thrown.
/// </summary>
/// <returns></returns>
[Fact]
public async Task ParseErrorMsiGetTokenTest()
{
// MockMsi is being asked to act like response from Azure VM MSI suceeded.
// MockMsi is being asked to act like response from Azure VM MSI suceeded.
MockMsi mockMsi = new MockMsi(MockMsi.MsiTestType.MsiAppJsonParseFailure);
HttpClient httpClient = new HttpClient(mockMsi);
MsiAccessTokenProvider msiAccessTokenProvider = new MsiAccessTokenProvider(httpClient);
Expand All @@ -80,13 +80,13 @@ public async Task ParseErrorMsiGetTokenTest()
}

/// <summary>
/// If MSI response if missing the token, an exception should be thrown.
/// If MSI response if missing the token, an exception should be thrown.
/// </summary>
/// <returns></returns>
[Fact]
public async Task MsiResponseMissingTokenTest()
{
// MockMsi is being asked to act like response from Azure VM MSI failed.
// MockMsi is being asked to act like response from Azure VM MSI failed.
MockMsi mockMsi = new MockMsi(MockMsi.MsiTestType.MsiMissingToken);
HttpClient httpClient = new HttpClient(mockMsi);
MsiAccessTokenProvider msiAccessTokenProvider = new MsiAccessTokenProvider(httpClient);
Expand All @@ -103,7 +103,7 @@ public async Task MsiResponseMissingTokenTest()
[InlineData(false)]
public async Task GetTokenUsingManagedIdentityAppServices(bool specifyUserAssignedManagedIdentity)
{
// Setup the environment variables that App Service MSI would setup.
// Setup the environment variables that App Service MSI would setup.
Environment.SetEnvironmentVariable(Constants.MsiAppServiceEndpointEnv, Constants.MsiEndpoint);
Environment.SetEnvironmentVariable(Constants.MsiAppServiceHeaderEnv, Constants.ClientSecret);

Expand All @@ -125,19 +125,19 @@ public async Task GetTokenUsingManagedIdentityAppServices(bool specifyUserAssign
expectedAppId = Constants.TestAppId;
}

// MockMsi is being asked to act like response from App Service MSI suceeded.
// MockMsi is being asked to act like response from App Service MSI suceeded.
MockMsi mockMsi = new MockMsi(msiTestType);
HttpClient httpClient = new HttpClient(mockMsi);
MsiAccessTokenProvider msiAccessTokenProvider = new MsiAccessTokenProvider(httpClient, managedIdentityClientId: managedIdentityArgument);

// Get token. This confirms that the environment variables are being read.
// Get token. This confirms that the environment variables are being read.
var authResult = await msiAccessTokenProvider.GetAuthResultAsync(Constants.KeyVaultResourceId, Constants.TenantId).ConfigureAwait(false);

Validator.ValidateToken(authResult.AccessToken, msiAccessTokenProvider.PrincipalUsed, Constants.AppType, Constants.TenantId, expectedAppId, expiresOn: authResult.ExpiresOn);
}

/// <summary>
/// Test response when IDENTITY_HEADER in AppServices MSI is invalid.
/// Test response when IDENTITY_HEADER in AppServices MSI is invalid.
/// </summary>
/// <returns></returns>
[Fact]
Expand All @@ -147,7 +147,7 @@ public async Task UnauthorizedTest()
Environment.SetEnvironmentVariable(Constants.MsiAppServiceEndpointEnv, Constants.MsiEndpoint);
Environment.SetEnvironmentVariable(Constants.MsiAppServiceHeaderEnv, Constants.ClientSecret);

// MockMsi is being asked to act like response from App Service MSI failed (unauthorized).
// MockMsi is being asked to act like response from App Service MSI failed (unauthorized).
MockMsi mockMsi = new MockMsi(MockMsi.MsiTestType.MsiAppServicesUnauthorized);
HttpClient httpClient = new HttpClient(mockMsi);
MsiAccessTokenProvider msiAccessTokenProvider = new MsiAccessTokenProvider(httpClient);
Expand All @@ -159,7 +159,7 @@ public async Task UnauthorizedTest()
}

/// <summary>
/// Test that response when MSI request is not valid is as expected.
/// Test that response when MSI request is not valid is as expected.
/// </summary>
/// <returns></returns>
[Fact]
Expand All @@ -180,7 +180,7 @@ public async Task IncorrectFormatTest()
}

/// <summary>
/// If an unexpected http response has been received, ensure exception is thrown.
/// If an unexpected http response has been received, ensure exception is thrown.
/// </summary>
/// <returns></returns>
[Fact]
Expand Down Expand Up @@ -208,13 +208,13 @@ public async Task AzureVmImdsTimeoutTest()
Environment.SetEnvironmentVariable(Constants.MsiAppServiceEndpointEnv, null);
Environment.SetEnvironmentVariable(Constants.MsiAppServiceHeaderEnv, null);

MockMsi mockMsi = new MockMsi(MockMsi.MsiTestType.MsiAzureVmTimeout);
MockMsi mockMsi = new MockMsi(MockMsi.MsiTestType.MsiAzureVmImdsTimeout);
HttpClient httpClient = new HttpClient(mockMsi);
MsiAccessTokenProvider msiAccessTokenProvider = new MsiAccessTokenProvider(httpClient);

var exception = await Assert.ThrowsAsync<AzureServiceTokenProviderException>(() => Task.Run(() => msiAccessTokenProvider.GetAuthResultAsync(Constants.KeyVaultResourceId, Constants.TenantId)));

Assert.Contains(AzureServiceTokenProviderException.MsiEndpointNotListening, exception.Message);
Assert.Contains(AzureServiceTokenProviderException.MetadataEndpointNotListening, exception.Message);
Assert.DoesNotContain(AzureServiceTokenProviderException.RetryFailure, exception.Message);
}

Expand All @@ -224,7 +224,7 @@ public async Task AzureVmImdsTimeoutTest()
[InlineData(MockMsi.MsiTestType.MsiTransientServerError)]
internal async Task TransientErrorRetryTest(MockMsi.MsiTestType testType)
{
// To simplify tests, mock as MSI App Services to skip Azure VM IDMS probe request by
// To simplify tests, mock as MSI App Services to skip Azure VM IDMS probe request by
Environment.SetEnvironmentVariable(Constants.MsiAppServiceEndpointEnv, Constants.MsiEndpoint);
Environment.SetEnvironmentVariable(Constants.MsiAppServiceHeaderEnv, Constants.ClientSecret);

Expand Down Expand Up @@ -254,12 +254,17 @@ internal async Task TransientErrorRetryTest(MockMsi.MsiTestType testType)
}
}

[Fact]
private async Task MsiRetryTimeoutTest()
[Theory]
[InlineData(false)]
[InlineData(true)]
internal async Task MsiRetryTimeoutTest(bool isAppServices)
{
// To simplify tests, mock as MSI App Services to skip Azure VM IDMS probe request
Environment.SetEnvironmentVariable(Constants.MsiAppServiceEndpointEnv, Constants.MsiEndpoint);
Environment.SetEnvironmentVariable(Constants.MsiAppServiceHeaderEnv, Constants.ClientSecret);
if (isAppServices)
{
// Mock as MSI App Services to skip Azure VM IDMS probe request
Environment.SetEnvironmentVariable(Constants.MsiAppServiceEndpointEnv, Constants.MsiEndpoint);
Environment.SetEnvironmentVariable(Constants.MsiAppServiceHeaderEnv, Constants.ClientSecret);
}

int timeoutInSeconds = (new Random()).Next(1, 4);

Expand Down Expand Up @@ -291,11 +296,11 @@ private async Task AppServicesDifferentCultureTest()
// ensure thread culture is NOT using en-US culture (App Services MSI endpoint always uses en-US DateTime format)
Thread.CurrentThread.CurrentCulture = new CultureInfo("en-GB");

// Setup the environment variables that App Service MSI would setup.
// Setup the environment variables that App Service MSI would setup.
Environment.SetEnvironmentVariable(Constants.MsiAppServiceEndpointEnv, Constants.MsiEndpoint);
Environment.SetEnvironmentVariable(Constants.MsiAppServiceHeaderEnv, Constants.ClientSecret);

// MockMsi is being asked to act like response from App Service MSI suceeded.
// MockMsi is being asked to act like response from App Service MSI suceeded.
MockMsi mockMsi = new MockMsi(MockMsi.MsiTestType.MsiAppServicesSuccess);
HttpClient httpClient = new HttpClient(mockMsi);
MsiAccessTokenProvider msiAccessTokenProvider = new MsiAccessTokenProvider(httpClient);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ namespace Microsoft.Azure.Services.AppAuthentication.Unit.Tests
public class TokenHelper
{
/// <summary>
/// The hardcoded user token has expiry replaced by [exp], so we can replace it with some value to test functionality.
/// The hardcoded user token has expiry replaced by [exp], so we can replace it with some value to test functionality.
/// </summary>
/// <param name="accessToken"></param>
/// <param name="secondsFromCurrent"></param>
Expand All @@ -30,7 +30,7 @@ private static string UpdateTokenTime(string accessToken, long secondsFromCurren

internal static string GetUserToken()
{
// Gets a user token that will expire in 10 seconds from now.
// Gets a user token that will expire in 10 seconds from now.
return GetUserToken(10);
}

Expand Down Expand Up @@ -63,6 +63,16 @@ internal static string GetUserTokenResponse(long secondsFromCurrent, bool format
return tokenResult;
}

/// <summary>
/// Sample IMDS /instance response
/// </summary>
/// <returns></returns>
internal static string GetInstanceMetadataResponse()
{
return
"{\"compute\":{\"location\":\"westus\",\"name\":\"TestBedVm\",\"resourceGroupName\":\"testbed\",\"subscriptionId\":\"bdd789f3-d9d1-4bea-ac14-30a39ed66d33\"}}";
}

/// <summary>
/// The response has claims as expected from App Service MSI response
/// </summary>
Expand Down Expand Up @@ -128,7 +138,7 @@ internal static string GetInvalidMsiTokenResponse()
}

/// <summary>
/// The response has claims as expected from Client Credentials flow response.
/// The response has claims as expected from Client Credentials flow response.
/// </summary>
/// <returns></returns>
internal static string GetAppToken()
Expand All @@ -139,7 +149,7 @@ internal static string GetAppToken()
}

/// <summary>
/// Invalid AppToken.
/// Invalid AppToken.
/// </summary>
/// <returns></returns>
internal static string GetInvalidAppToken()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,16 @@
namespace Microsoft.Azure.Services.AppAuthentication
{
/// <summary>
/// Instance of this exception is thrown if access token cannot be acquired.
/// Instance of this exception is thrown if access token cannot be acquired.
/// </summary>
#if FullNetFx || NETSTANDARD2_0
[Serializable]
#endif

public class AzureServiceTokenProviderException : Exception
{
internal const string MetadataEndpointNotListening = "Unable to connect to the Instance Metadata Service (IMDS). Skipping request to the Managed Service Identity (MSI) token endpoint.";

internal const string MsiEndpointNotListening = "Unable to connect to the Managed Service Identity (MSI) endpoint. Please check that you are running on an Azure resource that has MSI setup.";

internal const string UnableToParseMsiTokenResponse = "A successful response was received from Managed Service Identity, but it could not be parsed.";
Expand All @@ -42,7 +44,7 @@ public class AzureServiceTokenProviderException : Exception
internal const string NonRetryableError = "Received a non-retryable error.";

/// <summary>
/// Creates an instance of AzureServiceTokenProviderException.
/// Creates an instance of AzureServiceTokenProviderException.
/// </summary>
/// <param name="connectionString">Connection string used.</param>
/// <param name="resource">Resource for which token was expected.</param>
Expand Down
Loading

0 comments on commit eae6c13

Please sign in to comment.