diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs index 5bfcf9829d..ccb4602695 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs @@ -11,6 +11,8 @@ using Microsoft.Identity.Client.Http; using Microsoft.Identity.Client.Http.Retry; using Microsoft.Identity.Client.Internal; +using Microsoft.Identity.Client.OAuth2; +using Microsoft.Identity.Client.OAuth2.Throttling; using Microsoft.Identity.Client.PlatformsCommon.Shared; using Microsoft.Identity.Client.Utils; @@ -37,7 +39,7 @@ public static async Task GetCsrMetadataAsync( var headers = new Dictionary { { "Metadata", "true" }, - { "x-ms-client-request-id", requestContext.CorrelationId.ToString() } + { OAuth2Header.XMsCorrelationId, requestContext.CorrelationId.ToString() } }; IRetryPolicyFactory retryPolicyFactory = requestContext.ServiceBundle.Config.RetryPolicyFactory; @@ -194,7 +196,7 @@ private async Task ExecuteCertificateRequestAsync(st var headers = new Dictionary { { "Metadata", "true" }, - { "x-ms-client-request-id", _requestContext.CorrelationId.ToString() } + { OAuth2Header.XMsCorrelationId, _requestContext.CorrelationId.ToString() } }; var certificateRequestBody = new CertificateRequestBody() @@ -264,11 +266,23 @@ protected override async Task CreateRequestAsync(string privateKey); ManagedIdentityRequest request = new(HttpMethod.Post, new Uri($"{certificateRequestResponse.MtlsAuthenticationEndpoint}/{certificateRequestResponse.TenantId}{AcquireEntraTokenPath}")); - request.Headers.Add("x-ms-client-request-id", _requestContext.CorrelationId.ToString()); + + var idParams = MsalIdHelper.GetMsalIdParameters(_requestContext.Logger); + foreach (var idParam in idParams) + { + request.Headers[idParam.Key] = idParam.Value; + } + request.Headers.Add(OAuth2Header.XMsCorrelationId, _requestContext.CorrelationId.ToString()); + request.Headers.Add(ThrottleCommon.ThrottleRetryAfterHeaderName, ThrottleCommon.ThrottleRetryAfterHeaderValue); + request.Headers.Add(OAuth2Header.RequestCorrelationIdInResponse, "true"); + request.BodyParameters.Add("client_id", certificateRequestResponse.ClientId); - request.BodyParameters.Add("grant_type", certificateRequestResponse.Certificate); + request.BodyParameters.Add("grant_type", OAuth2GrantType.ClientCredentials); request.BodyParameters.Add("scope", "https://management.azure.com/.default"); - request.RequestType = RequestType.Imds; + request.BodyParameters.Add("token_type", "bearer"); + + request.RequestType = RequestType.STS; + request.MtlsCertificate = mtlsCertificate; return request; diff --git a/src/client/Microsoft.Identity.Client/OAuth2/OAuthConstants.cs b/src/client/Microsoft.Identity.Client/OAuth2/OAuthConstants.cs index 7aa76e3cd1..a6003809d9 100644 --- a/src/client/Microsoft.Identity.Client/OAuth2/OAuthConstants.cs +++ b/src/client/Microsoft.Identity.Client/OAuth2/OAuthConstants.cs @@ -77,6 +77,7 @@ internal static class OAuth2RequestedTokenUse internal static class OAuth2Header { public const string CorrelationId = "client-request-id"; + public const string XMsCorrelationId = $"x-ms-{CorrelationId}"; public const string RequestCorrelationIdInResponse = "return-client-request-id"; public const string AppName = "x-app-name"; public const string AppVer = "x-app-ver"; diff --git a/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHelpers.cs b/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHelpers.cs index 04665fc0dd..1afc958e91 100644 --- a/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHelpers.cs +++ b/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHelpers.cs @@ -11,9 +11,12 @@ using Castle.Core.Logging; using Microsoft.Identity.Client; using Microsoft.Identity.Client.AppConfig; +using Microsoft.Identity.Client.Internal; +using Microsoft.Identity.Client.Internal.Logger; using Microsoft.Identity.Client.ManagedIdentity; using Microsoft.Identity.Client.ManagedIdentity.V2; using Microsoft.Identity.Client.OAuth2; +using Microsoft.Identity.Client.OAuth2.Throttling; using Microsoft.Identity.Client.Utils; using Microsoft.Identity.Test.Unit; using Microsoft.VisualStudio.TestTools.UnitTesting.Logging; @@ -592,14 +595,19 @@ public static MsalTokenResponse CreateMsalRunTimeBrokerTokenResponse(string acce public static MockHttpMessageHandler MockCsrResponse( HttpStatusCode statusCode = HttpStatusCode.OK, string responseServerHeader = "IMDS/150.870.65.1854", - UserAssignedIdentityId idType = UserAssignedIdentityId.None, + UserAssignedIdentityId userAssignedIdentityId = UserAssignedIdentityId.None, string userAssignedId = null) { IDictionary expectedQueryParams = new Dictionary(); IDictionary expectedRequestHeaders = new Dictionary(); - if (idType != UserAssignedIdentityId.None && userAssignedId != null) + IList presentRequestHeaders = new List + { + OAuth2Header.XMsCorrelationId + }; + + if (userAssignedIdentityId != UserAssignedIdentityId.None && userAssignedId != null) { - var userAssignedIdQueryParam = ImdsManagedIdentitySource.GetUserAssignedIdQueryParam((ManagedIdentityIdType)idType, userAssignedId, null); + var userAssignedIdQueryParam = ImdsManagedIdentitySource.GetUserAssignedIdQueryParam((ManagedIdentityIdType)userAssignedIdentityId, userAssignedId, null); expectedQueryParams.Add(userAssignedIdQueryParam.Value.Key, userAssignedIdQueryParam.Value.Value); } expectedQueryParams.Add("cred-api-version", "2.0"); @@ -619,6 +627,7 @@ public static MockHttpMessageHandler MockCsrResponse( ExpectedMethod = HttpMethod.Get, ExpectedQueryParams = expectedQueryParams, ExpectedRequestHeaders = expectedRequestHeaders, + PresentRequestHeaders = presentRequestHeaders, ResponseMessage = new HttpResponseMessage(statusCode) { Content = new StringContent(content), @@ -639,14 +648,20 @@ public static MockHttpMessageHandler MockCsrResponseFailure() } public static MockHttpMessageHandler MockCertificateRequestResponse( - UserAssignedIdentityId idType = UserAssignedIdentityId.None, - string userAssignedId = null) + UserAssignedIdentityId userAssignedIdentityId = UserAssignedIdentityId.None, + string userAssignedId = null, + string certificate = TestConstants.ValidPemCertificate) { IDictionary expectedQueryParams = new Dictionary(); IDictionary expectedRequestHeaders = new Dictionary(); - if (idType != UserAssignedIdentityId.None && userAssignedId != null) + IList presentRequestHeaders = new List + { + OAuth2Header.XMsCorrelationId + }; + + if (userAssignedIdentityId != UserAssignedIdentityId.None && userAssignedId != null) { - var userAssignedIdQueryParam = ImdsManagedIdentitySource.GetUserAssignedIdQueryParam((ManagedIdentityIdType)idType, userAssignedId, null); + var userAssignedIdQueryParam = ImdsManagedIdentitySource.GetUserAssignedIdQueryParam((ManagedIdentityIdType)userAssignedIdentityId, userAssignedId, null); expectedQueryParams.Add(userAssignedIdQueryParam.Value.Key, userAssignedIdQueryParam.Value.Value); } expectedQueryParams.Add("cred-api-version", ImdsV2ManagedIdentitySource.ImdsV2ApiVersion); @@ -656,7 +671,7 @@ public static MockHttpMessageHandler MockCertificateRequestResponse( "{" + "\"client_id\": \"" + TestConstants.ClientId + "\"," + "\"tenant_id\": \"" + TestConstants.TenantId + "\"," + - "\"certificate\": \"" + TestConstants.ValidPemCertificate + "\"," + + "\"certificate\": \"" + certificate + "\"," + "\"identity_type\": \"fake_identity_type\"," + // "SystemAssigned" or "UserAssigned", it doesn't matter for these tests "\"mtls_authentication_endpoint\": \"" + TestConstants.MtlsAuthenticationEndpoint + "\"," + "}"; @@ -667,6 +682,7 @@ public static MockHttpMessageHandler MockCertificateRequestResponse( ExpectedMethod = HttpMethod.Post, ExpectedQueryParams = expectedQueryParams, ExpectedRequestHeaders = expectedRequestHeaders, + PresentRequestHeaders = presentRequestHeaders, ResponseMessage = new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(content), @@ -675,5 +691,44 @@ public static MockHttpMessageHandler MockCertificateRequestResponse( return handler; } + + public static MockHttpMessageHandler MockImdsV2EntraTokenRequestResponse( + IdentityLoggerAdapter identityLoggerAdapter, + bool mTLSPop = false) + { + IDictionary expectedPostData = new Dictionary(); + IDictionary expectedRequestHeaders = new Dictionary + { + { ThrottleCommon.ThrottleRetryAfterHeaderName, ThrottleCommon.ThrottleRetryAfterHeaderValue } + }; + IList presentRequestHeaders = new List + { + OAuth2Header.XMsCorrelationId + }; + + var idParams = MsalIdHelper.GetMsalIdParameters(identityLoggerAdapter); + foreach (var idParam in idParams) + { + expectedRequestHeaders[idParam.Key] = idParam.Value; + } + + var tokenType = mTLSPop ? "mtls_pop" : "bearer"; + expectedPostData.Add("token_type", tokenType); + + var handler = new MockHttpMessageHandler() + { + ExpectedUrl = $"{TestConstants.MtlsAuthenticationEndpoint}/{TestConstants.TenantId}{ImdsV2ManagedIdentitySource.AcquireEntraTokenPath}", + ExpectedMethod = HttpMethod.Post, + ExpectedPostData = expectedPostData, + ExpectedRequestHeaders = expectedRequestHeaders, + PresentRequestHeaders = presentRequestHeaders, + ResponseMessage = new HttpResponseMessage(HttpStatusCode.OK) + { + Content = new StringContent(GetMsiSuccessfulResponse()), + } + }; + + return handler; + } } } diff --git a/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHttpManagerExtensions.cs b/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHttpManagerExtensions.cs index 017213d275..7f8667d93f 100644 --- a/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHttpManagerExtensions.cs +++ b/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHttpManagerExtensions.cs @@ -460,15 +460,6 @@ private static MockHttpMessageHandler BuildMockHandlerForManagedIdentitySource( expectedQueryParams.Add("resource", resource); expectedRequestHeaders.Add("Metadata", "true"); break; - case ManagedIdentitySource.ImdsV2: - httpMessageHandler.ExpectedMethod = HttpMethod.Post; - expectedPostData = new Dictionary - { - { "client_id", TestConstants.ClientId }, - { "grant_type", TestConstants.ValidPemCertificate }, - { "scope", resource } - }; - break; case ManagedIdentitySource.CloudShell: httpMessageHandler.ExpectedMethod = HttpMethod.Post; expectedRequestHeaders.Add("Metadata", "true"); diff --git a/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHttpMessageHandler.cs b/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHttpMessageHandler.cs index e067d640b8..cdfbb5432b 100644 --- a/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHttpMessageHandler.cs +++ b/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHttpMessageHandler.cs @@ -37,6 +37,8 @@ internal class MockHttpMessageHandler : HttpClientHandler public HttpRequestMessage ActualRequestMessage { get; private set; } public Dictionary ActualRequestPostData { get; private set; } public HttpRequestHeaders ActualRequestHeaders { get; private set; } + public IList PresentRequestHeaders { get; set; } + public X509Certificate2 ExpectedMtlsBindingCertificate { get; set; } protected override async Task SendAsync(HttpRequestMessage request, CancellationToken cancellationToken) @@ -174,6 +176,15 @@ private void ValidateNotExpectedPostData() private void ValidateHeaders(HttpRequestMessage request) { + if (PresentRequestHeaders != null) + { + foreach (var headerName in PresentRequestHeaders) + { + Assert.IsTrue(request.Headers.Contains(headerName), + $"Expected request header to be present: {headerName}."); + } + } + ActualRequestHeaders = request.Headers; if (ExpectedRequestHeaders != null) { diff --git a/tests/Microsoft.Identity.Test.Common/TestConstants.cs b/tests/Microsoft.Identity.Test.Common/TestConstants.cs index 35107976c7..7ab49c10e2 100644 --- a/tests/Microsoft.Identity.Test.Common/TestConstants.cs +++ b/tests/Microsoft.Identity.Test.Common/TestConstants.cs @@ -585,15 +585,34 @@ public static MsalTokenResponse CreateAadTestTokenResponseWithFoci() internal const string RefreshToken = "mhJDJ8wtjA3KxpRtuPAreZnMcJ2yKC2JUbpOGbRTdOCImLyQ2B4EIhv8AiA2cCEylZZfZsOsZrNsMBZZAAU9TQYYEO72QcdfnIWpAOeKkud5W2L8nMq6i9dx1EVIl09zFXhOJ79BdFbU0Eb5aUHlcqPCQjec62UKBLkZJmtMnoAa8cjvgIuxTdVM8FNdghe5nlCNTEVooKleTTEHNl2BrdyitLaWTKSP0lRqnFxriG0xWcJoSMsdS7Vt6HZd1TkwHIXycNMlCcCdUh5tOgqx1M8y8uoXK4OJ1LQmtkZvcQWcycvOCPACYakKM1pUQqwTxI6Y4HrL38sqQaSNxpF9OcFxOQWpuGodRekCbxXVbWclttIpvSOLaBhZ2ZBpcCBEeEMSmhqqYgajNwwwe9w88u0UsYKe6PBbaI48ENr02u2qBeLsIQ2HUyKlN3iVmX7u7MhgDWA3NNavMtlLmWd63NfuDgXpLI0O4cLhjAx8uoBIK8LntXPHPTxJ28o0yrszvD4gf7RdhuTq5VE15zne6iAJgIGfy7latGFzxuDMcML9OoXURHnNEHBgS9ZQCfNzYZ2O9flF1UjGpcBLEi7hHVHnrQb4y7c98dz9p62cvEMhorGx9kCwSIkOae5LheXPQkFIbsGyomNEwz3HZvR131VGAwdfmUUodvPr6LAAtmjl4sZ72PRqAo8EdQ0IFsWoypXVv51IooR87tO3uiG2DkxhIAwumOQdaJNxw1a0WS9mpQOmwFlvfbZkaIoUKgagHc8fVa1aHZntLGwH0S1iYixJiIrMnPYAeRdSp9mlHllrMX8xUIznobcZ5i8MpUYCKlUXMZ82S3XUJ5dJxARNRPxXlLJ5LPYBhUNkBLQen9Qmq3VZEV1RDJyhbGp6GAo14KsMtVAVYNmYPIgo85pCZgOwVEOBUycszu4AD3p4PT2ella4LVoqmTTMSA5GEWoeWb5JvEo222Z0oKr7UK8dGwpWRSbg8TNeODihJaTUDfErvbgaZnjIRpqfgtM5i1HfQbD7Yyft5PqyygUra7GYy7pjRrEvq95XQD8sAZ32ku9AqCo5qOB584iX881WErOoheQZokt1txqwuIMUyhVuMKNEXy70CeNTsb30ghQMZpZcXIkrLYyQCZ0gNmARhMKagCSdrpUtxudLk44yfmuwSQzBN3ifWfLZiFpU53qdPLZoTw5"; internal const string IdToken = "6GwdM7f6hHXfivavPozhaRqrbxvEysfXSMQyEKBwVgivPZTtmowsmYygchhIuxjeFFeq1ZPHjhxKFnulrvoY6TDerZY5xyOlg45bToI9Bu95qFvUrrt5r17UJcXdw4YkvEt10CcDDcLcEYw704RpVefvbpjbF24pOgIuafcAkDnbDA0Qea4ePuSC45Lw7zpJhbo9Gh8IfMX597fayBvMs3fh7frrm9KpWMCeKY3h99YSaCYjZFKp1ppvXXPE9bc4sh4pRDOfnv0Yr9J8u4elZevEE4qGddfgd3hYb18XPGRjPEMlWsh7tnwxwUm6OSZlMTHYuvwBENNMx7SUQmMeg4rCfgnbcNDkWpXCiSDVt1lLLv8F2GjYnM6De3v1Ks5lhBWx3grLggcN9LnXz92eJ1l5lTB2v0y9MgmFZ4gY43oIOW5n8G5HOx3bGOyjTw0TKKbyVa3mDj0A3QqW8eLTUJz42BNiGOf5m9prMSlpAW59CHCMJLatsj3IvGeCITsGAr3sUZEytORWUdxCfuIPwecQgU6bO7pNqNvZc1tJHHNwJlfS23ZkiFuEXqEThHYfxBCFxAzMDlzO0TOdWhvrb8hlNeAOcNhoAKxu7HXsePajKs4fU1rcdSxzNKwtASEla3p6jfJnnDtKf38RJZPaRRYMviqqWEMhjmqIvBm7sMaf8RyNNuYl7otZwmwNVCR1hzzmaTAy4kQce67FJqFba7uizrgwp9zsvK8muCHKKPvNthy7fHsxKmrBIm0bLcoePKK3wAID4kFvNQcxXp6rAOr8bLFF3bLEoYdzmF2QJz1frVZZHHPy90Cmlhw48EQN8NE2OllpdaykKt5k4rPcZQyitayNNhism30qh7eCBhcA7mm5Ja0S8X4VPlkwvgwg0mQuul6gakmja8xpnTrwiOdtao320GDmJaJA6zf3UTpNZTq9tdfBtUrjAD8RS0tNUBT3Ko8N2Lfh9ry8y9ESmRVIhch3rKY7UeefFAnkiwH2WwC57ZEsHtMP0SwKYtYKHZW9HkERCCyqOT1Mw0IavsLGFvchzMAvTnz4RwRBk6IrWgANvqT3F3Vexc2K0poKb71XZ4aMXxjqAzydGQAKpKJEJcqEvX9RD8nL76TF2LZIepiaZ3dbQImkqSjbF7aaY2JFoN9ZWlcSQKe8zdO8TIG16bF8W9R4ldDyzV39L33KcweG"; - #region Test Certificate and Private Key (ValidPemCertificate & XmlPrivateKey) + #region Test Certificate and Private Key (ExpiredPemCertificate, ValidPemCertificate & XmlPrivateKey) /// - /// A test PEM-encoded X.509 certificate and its matching RSA private key. + /// Test (expired and valid) PEM-encoded X.509 certificate and their matching RSA private key. /// These are used together in unit tests that require both a certificate and its private key. - /// The and are a matched pair: - /// - is a PEM-encoded certificate. - /// - is the corresponding RSA private key in XML format. - /// The certificate is valid for 100 years, ensuring it will not expire during the lifetime of the tests. + /// The / and are a matched pair: + /// - is an expired PEM-encoded certificate. The certificate is valid for 1 day and was created on September 8 2025, ensuring it will always be expired. + /// - is a valid PEM-encoded certificate. The certificate is valid for 100 years and expires on August 4, 2125, ensuring it will not expire during the lifetime of the tests. + /// - is their corresponding RSA private key in XML format. /// + internal const string ExpiredPemCertificate = @"-----BEGIN CERTIFICATE----- +MIIC/zCCAeegAwIBAgIUGSVU23Wc0+QtCbUTjsyPOrc0XpEwDQYJKoZIhvcNAQEL +BQAwDzENMAsGA1UEAwwEVGVzdDAeFw0yNTA5MDgyMjAxMTdaFw0yNTA5MDkyMjAx +MTdaMA8xDTALBgNVBAMMBFRlc3QwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEK +AoIBAQC5XNEuk3cIEChkZd2P/bljUaVqNVh4mbXdWHYAgbdK48U6rG0FLq1NAfSn +ZO0EPbK8Zo4psRh2lBcqW29/WsKiHUEHLkLyFI+frEIfc8wskd+WxkKfL8G52uRp +YQCG87FIv8uZBBlDG7kDdOV36CUkK1N+V2fHbkEgx+YfWg6+pLi3KQx6Pf/b2YqL +D36hj8WRrVYzL6yXVUBiyRd+cQ9y5V/MRtoiX1Sv8WEFYtzIG0TUGi9pR7WWhgHN +Qk6DFDzutMV62ZEBNPIQvdO2EwXGr1FUIOL6zmj6bArPhY+hCXGrAAwCXodZhgZ9 +5BxTwsQWtjCha2hT6ed8zmoE72FdAgMBAAGjUzBRMB0GA1UdDgQWBBQPYq0Efzuv +1diVcgxBxTnVA4wLMjAfBgNVHSMEGDAWgBQPYq0Efzuv1diVcgxBxTnVA4wLMjAP +BgNVHRMBAf8EBTADAQH/MA0GCSqGSIb3DQEBCwUAA4IBAQCXAD7cjWmmTqP0NX4M +qwO0AHtO+KGVtfxF8aI21Ty/nHh2SAODzsemP3NBBvoEvllwtcVyutPqvUiAflML +Nbp0ucTu+aWE14s1V9Bnt6++5g7gtXItsNV3F/ymYKsyfhDvJbWCOv5qYeJMQ+jt +ODHN9qnATODT5voULTwEVSYQXtutwRxR8e70Cvok+F+4I6Ni49DJ8DmcYzvB94ut +hqpDsygY1vYzpRbB5hpW0/D7kgVVWyWoOWiE1mV7Fry7tUWQw7EqnX89kMLMy4g6 +UfOv4gtam8RBa9dLyMW1rCHRxOulP47joI10g9JoJ9DssiQTUojJgQXOSBBXdD20 +H+zl +-----END CERTIFICATE-----"; internal const string ValidPemCertificate = @"-----BEGIN CERTIFICATE----- MIIDATCCAemgAwIBAgIUSfjghyQB4FIS41rWfNcZHTLE/R4wDQYJKoZIhvcNAQEL BQAwDzENMAsGA1UEAwwEVGVzdDAgFw0yNTA4MjgyMDIxMDBaGA8yMTI1MDgwNDIw diff --git a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs index 62b8a8a7c9..abfd88ddce 100644 --- a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs +++ b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs @@ -7,12 +7,13 @@ using System.Threading.Tasks; using Microsoft.Identity.Client; using Microsoft.Identity.Client.AppConfig; -using Microsoft.Identity.Client.Internal; +using Microsoft.Identity.Client.Internal.Logger; using Microsoft.Identity.Client.ManagedIdentity; using Microsoft.Identity.Client.ManagedIdentity.V2; using Microsoft.Identity.Client.PlatformsCommon.Shared; using Microsoft.Identity.Test.Common.Core.Mocks; using Microsoft.Identity.Test.Unit.Helpers; +using Microsoft.Identity.Test.Unit.PublicApiTests; using Microsoft.VisualStudio.TestTools.UnitTesting; using static Microsoft.Identity.Test.Common.Core.Helpers.ManagedIdentityTestUtil; @@ -23,92 +24,371 @@ public class ImdsV2Tests : TestBase { private readonly TestRetryPolicyFactory _testRetryPolicyFactory = new TestRetryPolicyFactory(); private readonly TestCsrFactory _testCsrFactory = new TestCsrFactory(); + private readonly IdentityLoggerAdapter _identityLoggerAdapter = new IdentityLoggerAdapter( + new TestIdentityLogger(), + Guid.Empty, + "TestClient", + "1.0.0", + enablePiiLogging: false + ); + public const string Bearer = "Bearer"; + public const string MTLSPoP = "MTLSPoP"; + + private void AddMocksToGetEntraToken( + MockHttpManager httpManager, + UserAssignedIdentityId userAssignedIdentityId = UserAssignedIdentityId.None, + string userAssignedId = null, + string certificateRequestCertificate = TestConstants.ValidPemCertificate, + bool mTLSPop = false) + { + if (userAssignedIdentityId != UserAssignedIdentityId.None && userAssignedId != null) + { + httpManager.AddMockHandler(MockHelpers.MockCsrResponse(userAssignedIdentityId: userAssignedIdentityId, userAssignedId: userAssignedId)); + httpManager.AddMockHandler(MockHelpers.MockCertificateRequestResponse(userAssignedIdentityId, userAssignedId, certificateRequestCertificate)); + } + else + { + httpManager.AddMockHandler(MockHelpers.MockCsrResponse()); + httpManager.AddMockHandler(MockHelpers.MockCertificateRequestResponse(certificate: certificateRequestCertificate)); + } + + httpManager.AddMockHandler(MockHelpers.MockImdsV2EntraTokenRequestResponse(_identityLoggerAdapter, mTLSPop)); + } - [TestMethod] - public async Task ImdsV2SAMIHappyPathAsync() + private async Task CreateManagedIdentityAsync( + MockHttpManager httpManager, + UserAssignedIdentityId userAssignedIdentityId = UserAssignedIdentityId.None, + string userAssignedId = null, + bool addProbeMock = true, + bool addSourceCheck = true) + { + ManagedIdentityApplicationBuilder miBuilder = null; + + var uami = userAssignedIdentityId != UserAssignedIdentityId.None && userAssignedId != null; + if (uami) + { + miBuilder = CreateMIABuilder(userAssignedId, userAssignedIdentityId); + } + else + { + miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned); + } + + miBuilder + .WithHttpManager(httpManager) + .WithRetryPolicyFactory(_testRetryPolicyFactory) + .WithCsrFactory(_testCsrFactory); + + // Disabling shared cache options to avoid cross test pollution. + miBuilder.Config.AccessorOptions = null; + + var managedIdentityApp = miBuilder.Build(); + + if (addProbeMock) + { + if (uami) + { + httpManager.AddMockHandler(MockHelpers.MockCsrResponse(userAssignedIdentityId: userAssignedIdentityId, userAssignedId: userAssignedId)); + } + else + { + httpManager.AddMockHandler(MockHelpers.MockCsrResponse()); + } + } + + if (addSourceCheck) + { + var miSource = await (managedIdentityApp as ManagedIdentityApplication).GetManagedIdentitySourceAsync().ConfigureAwait(false); + Assert.AreEqual(ManagedIdentitySource.ImdsV2, miSource); + } + + return managedIdentityApp; + } + + #region Acceptance Tests + #region Bearer Token Tests + [DataTestMethod] + [DataRow(UserAssignedIdentityId.None, null)] // SAMI + [DataRow(UserAssignedIdentityId.ClientId, TestConstants.ClientId)] // UAMI + [DataRow(UserAssignedIdentityId.ResourceId, TestConstants.MiResourceId)] // UAMI + [DataRow(UserAssignedIdentityId.ObjectId, TestConstants.ObjectId)] // UAMI + public async Task BearerTokenHappyPath( + UserAssignedIdentityId userAssignedIdentityId, + string userAssignedId) + { + using (var httpManager = new MockHttpManager()) + { + var managedIdentityApp = await CreateManagedIdentityAsync(httpManager, userAssignedIdentityId, userAssignedId).ConfigureAwait(false); + + AddMocksToGetEntraToken(httpManager, userAssignedIdentityId, userAssignedId); + + var result = await managedIdentityApp.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + .ExecuteAsync().ConfigureAwait(false); + + Assert.IsNotNull(result); + Assert.IsNotNull(result.AccessToken); + Assert.AreEqual(result.TokenType, Bearer); + Assert.AreEqual(TokenSource.IdentityProvider, result.AuthenticationResultMetadata.TokenSource); + + result = await managedIdentityApp.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + .ExecuteAsync().ConfigureAwait(false); + + Assert.IsNotNull(result); + Assert.IsNotNull(result.AccessToken); + Assert.AreEqual(result.TokenType, Bearer); + Assert.AreEqual(TokenSource.Cache, result.AuthenticationResultMetadata.TokenSource); + } + } + + [DataTestMethod] + [DataRow(UserAssignedIdentityId.None, null)] // SAMI + [DataRow(UserAssignedIdentityId.ClientId, TestConstants.ClientId)] // UAMI + [DataRow(UserAssignedIdentityId.ResourceId, TestConstants.MiResourceId)] // UAMI + [DataRow(UserAssignedIdentityId.ObjectId, TestConstants.ObjectId)] // UAMI + public async Task BearerTokenTokenIsPerIdentity( + UserAssignedIdentityId userAssignedIdentityId, + string userAssignedId) + { + using (var httpManager = new MockHttpManager()) + { + #region Identity 1 + var managedIdentityApp = await CreateManagedIdentityAsync(httpManager, userAssignedIdentityId, userAssignedId).ConfigureAwait(false); + + AddMocksToGetEntraToken(httpManager, userAssignedIdentityId, userAssignedId); + + var result = await managedIdentityApp.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + .ExecuteAsync().ConfigureAwait(false); + + Assert.IsNotNull(result); + Assert.IsNotNull(result.AccessToken); + Assert.AreEqual(result.TokenType, Bearer); + Assert.AreEqual(TokenSource.IdentityProvider, result.AuthenticationResultMetadata.TokenSource); + + result = await managedIdentityApp.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + .ExecuteAsync().ConfigureAwait(false); + + Assert.IsNotNull(result); + Assert.IsNotNull(result.AccessToken); + Assert.AreEqual(result.TokenType, Bearer); + Assert.AreEqual(TokenSource.Cache, result.AuthenticationResultMetadata.TokenSource); + #endregion Identity 1 + + #region Identity 2 + var managedIdentityApp2 = await CreateManagedIdentityAsync(httpManager, userAssignedIdentityId, userAssignedId, addProbeMock: false, addSourceCheck: false).ConfigureAwait(false); // source is already cached + + AddMocksToGetEntraToken(httpManager, userAssignedIdentityId, userAssignedId); + + var result2 = await managedIdentityApp2.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + .ExecuteAsync().ConfigureAwait(false); + + Assert.IsNotNull(result2); + Assert.IsNotNull(result2.AccessToken); + Assert.AreEqual(result.TokenType, Bearer); + Assert.AreEqual(TokenSource.IdentityProvider, result2.AuthenticationResultMetadata.TokenSource); + + result2 = await managedIdentityApp2.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + .ExecuteAsync().ConfigureAwait(false); + + Assert.IsNotNull(result2); + Assert.IsNotNull(result2.AccessToken); + Assert.AreEqual(result.TokenType, Bearer); + Assert.AreEqual(TokenSource.Cache, result2.AuthenticationResultMetadata.TokenSource); + #endregion Identity 2 + + // TODO: Assert.AreEqual(CertificateCache.Count, 2); + } + } + + [DataTestMethod] + [DataRow(UserAssignedIdentityId.None, null)] // SAMI + [DataRow(UserAssignedIdentityId.ClientId, TestConstants.ClientId)] // UAMI + [DataRow(UserAssignedIdentityId.ResourceId, TestConstants.MiResourceId)] // UAMI + [DataRow(UserAssignedIdentityId.ObjectId, TestConstants.ObjectId)] // UAMI + public async Task BearerTokenIsReAcquiredWhenCertificatIsExpired( + UserAssignedIdentityId userAssignedIdentityId, + string userAssignedId) + { + using (var httpManager = new MockHttpManager()) + { + var managedIdentityApp = await CreateManagedIdentityAsync(httpManager, userAssignedIdentityId, userAssignedId).ConfigureAwait(false); + + AddMocksToGetEntraToken(httpManager, userAssignedIdentityId, userAssignedId, TestConstants.ExpiredPemCertificate); // cert will be expired on second request + + var result = await managedIdentityApp.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + .ExecuteAsync().ConfigureAwait(false); + + Assert.IsNotNull(result); + Assert.IsNotNull(result.AccessToken); + Assert.AreEqual(result.TokenType, Bearer); + Assert.AreEqual(TokenSource.IdentityProvider, result.AuthenticationResultMetadata.TokenSource); + + // TODO: Add functionality to check cert expiration in the cache + /** + AddMocksToGetEntraToken(httpManager, userAssignedIdentityId, userAssignedId); + + result = await managedIdentityApp.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + .ExecuteAsync().ConfigureAwait(false); + + Assert.IsNotNull(result); + Assert.IsNotNull(result.AccessToken); + Assert.AreEqual(result.TokenType, Bearer); + Assert.AreEqual(TokenSource.IdentityProvider, result.AuthenticationResultMetadata.TokenSource); + + Assert.AreEqual(CertificateCache.Count, 1); // expired cert was removed from the cache + */ + } + } + #endregion Bearer Token Tests + + #region mTLS PoP Token Tests + [DataTestMethod] + [DataRow(UserAssignedIdentityId.None, null)] // SAMI + [DataRow(UserAssignedIdentityId.ClientId, TestConstants.ClientId)] // UAMI + [DataRow(UserAssignedIdentityId.ResourceId, TestConstants.MiResourceId)] // UAMI + [DataRow(UserAssignedIdentityId.ObjectId, TestConstants.ObjectId)] // UAMI + public async Task mTLSPopTokenHappyPath( + UserAssignedIdentityId userAssignedIdentityId, + string userAssignedId) { using (var httpManager = new MockHttpManager()) { - var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) - .WithHttpManager(httpManager) - .WithRetryPolicyFactory(_testRetryPolicyFactory) - .WithCsrFactory(_testCsrFactory); - - // Disabling shared cache options to avoid cross test pollution. - miBuilder.Config.AccessorOptions = null; - - var mi = miBuilder.Build(); - - httpManager.AddMockHandler(MockHelpers.MockCsrResponse()); // initial probe - httpManager.AddMockHandler(MockHelpers.MockCsrResponse()); // do it again, since CsrMetadata from initial probe is not cached - httpManager.AddMockHandler(MockHelpers.MockCertificateRequestResponse()); - httpManager.AddManagedIdentityMockHandler( - $"{TestConstants.MtlsAuthenticationEndpoint}/{TestConstants.TenantId}{ImdsV2ManagedIdentitySource.AcquireEntraTokenPath}", - ManagedIdentityTests.Resource, - MockHelpers.GetMsiSuccessfulResponse(), - ManagedIdentitySource.ImdsV2); - - var result = await mi.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + var managedIdentityApp = await CreateManagedIdentityAsync(httpManager, userAssignedIdentityId, userAssignedId).ConfigureAwait(false); + + AddMocksToGetEntraToken(httpManager, userAssignedIdentityId, userAssignedId/*, mTLSPop: true*/); // TODO: implement mTLS Pop + + var result = await managedIdentityApp.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + // .WithMtlsProofOfPossession() // TODO: implement mTLS Pop .ExecuteAsync().ConfigureAwait(false); Assert.IsNotNull(result); Assert.IsNotNull(result.AccessToken); + // Assert.AreEqual(result.TokenType, MTLSPoP); // TODO: implement mTLS Pop + // Assert.IsNotNull(result.BindingCertificate); // TODO: implement mTLS Pop Assert.AreEqual(TokenSource.IdentityProvider, result.AuthenticationResultMetadata.TokenSource); - result = await mi.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + result = await managedIdentityApp.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) .ExecuteAsync().ConfigureAwait(false); Assert.IsNotNull(result); Assert.IsNotNull(result.AccessToken); + // Assert.AreEqual(result.TokenType, MTLSPoP); // TODO: implement mTLS Pop + // Assert.IsNotNull(result.BindingCertificate); // TODO: implement mTLS Pop Assert.AreEqual(TokenSource.Cache, result.AuthenticationResultMetadata.TokenSource); } } [DataTestMethod] - [DataRow(UserAssignedIdentityId.ClientId, TestConstants.ClientId)] - [DataRow(UserAssignedIdentityId.ResourceId, TestConstants.MiResourceId)] - [DataRow(UserAssignedIdentityId.ObjectId, TestConstants.ObjectId)] - public async Task ImdsV2UAMIHappyPathAsync( + [DataRow(UserAssignedIdentityId.None, null)] // SAMI + [DataRow(UserAssignedIdentityId.ClientId, TestConstants.ClientId)] // UAMI + [DataRow(UserAssignedIdentityId.ResourceId, TestConstants.MiResourceId)] // UAMI + [DataRow(UserAssignedIdentityId.ObjectId, TestConstants.ObjectId)] // UAMI + public async Task mTLSPopTokenTokenIsPerIdentity( UserAssignedIdentityId userAssignedIdentityId, string userAssignedId) { using (var httpManager = new MockHttpManager()) { - var miBuilder = CreateMIABuilder(userAssignedId, userAssignedIdentityId); - miBuilder - .WithHttpManager(httpManager) - .WithRetryPolicyFactory(_testRetryPolicyFactory) - .WithCsrFactory(_testCsrFactory); - - // Disabling shared cache options to avoid cross test pollution. - miBuilder.Config.AccessorOptions = null; - - var mi = miBuilder.Build(); - - httpManager.AddMockHandler(MockHelpers.MockCsrResponse(idType: userAssignedIdentityId, userAssignedId: userAssignedId)); // initial probe - httpManager.AddMockHandler(MockHelpers.MockCsrResponse(idType: userAssignedIdentityId, userAssignedId: userAssignedId)); // do it again, since CsrMetadata from initial probe is not cached - httpManager.AddMockHandler(MockHelpers.MockCertificateRequestResponse(userAssignedIdentityId, userAssignedId)); - httpManager.AddManagedIdentityMockHandler( - $"{TestConstants.MtlsAuthenticationEndpoint}/{TestConstants.TenantId}{ImdsV2ManagedIdentitySource.AcquireEntraTokenPath}", - ManagedIdentityTests.Resource, - MockHelpers.GetMsiSuccessfulResponse(), - ManagedIdentitySource.ImdsV2); - - var result = await mi.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + #region Identity 1 + var managedIdentityApp = await CreateManagedIdentityAsync(httpManager, userAssignedIdentityId, userAssignedId).ConfigureAwait(false); + + AddMocksToGetEntraToken(httpManager, userAssignedIdentityId, userAssignedId/*, mTLSPop: true*/); // TODO: implement mTLS Pop + + var result = await managedIdentityApp.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + // .WithMtlsProofOfPossession() // TODO: implement mTLS Pop .ExecuteAsync().ConfigureAwait(false); Assert.IsNotNull(result); Assert.IsNotNull(result.AccessToken); + // Assert.AreEqual(result.TokenType, MTLSPoP); // TODO: implement mTLS Pop + // Assert.IsNotNull(result.BindingCertificate); // TODO: implement mTLS Pop Assert.AreEqual(TokenSource.IdentityProvider, result.AuthenticationResultMetadata.TokenSource); - result = await mi.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + result = await managedIdentityApp.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + // .WithMtlsProofOfPossession() // TODO: implement mTLS Pop .ExecuteAsync().ConfigureAwait(false); Assert.IsNotNull(result); Assert.IsNotNull(result.AccessToken); + // Assert.AreEqual(result.TokenType, MTLSPoP); // TODO: implement mTLS Pop + // Assert.IsNotNull(result.BindingCertificate); // TODO: implement mTLS Pop Assert.AreEqual(TokenSource.Cache, result.AuthenticationResultMetadata.TokenSource); + #endregion Identity 1 + + #region Identity 2 + var managedIdentityApp2 = await CreateManagedIdentityAsync(httpManager, userAssignedIdentityId, userAssignedId, addProbeMock: false, addSourceCheck: false).ConfigureAwait(false); // source is already cached + + AddMocksToGetEntraToken(httpManager, userAssignedIdentityId, userAssignedId/*, mTLSPop: true*/); // TODO: implement mTLS Pop + + var result2 = await managedIdentityApp2.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + // .WithMtlsProofOfPossession() // TODO: implement mTLS Pop + .ExecuteAsync().ConfigureAwait(false); + + Assert.IsNotNull(result2); + Assert.IsNotNull(result2.AccessToken); + // Assert.AreEqual(result.TokenType, MTLSPoP); // TODO: implement mTLS Pop + // Assert.IsNotNull(result.BindingCertificate); // TODO: implement mTLS Pop + Assert.AreEqual(TokenSource.IdentityProvider, result2.AuthenticationResultMetadata.TokenSource); + + result2 = await managedIdentityApp2.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + // .WithMtlsProofOfPossession() // TODO: implement mTLS Pop + .ExecuteAsync().ConfigureAwait(false); + + Assert.IsNotNull(result2); + Assert.IsNotNull(result2.AccessToken); + // Assert.AreEqual(result.TokenType, MTLSPoP); // TODO: implement mTLS Pop + // Assert.IsNotNull(result.BindingCertificate); // TODO: implement mTLS Pop + Assert.AreEqual(TokenSource.Cache, result2.AuthenticationResultMetadata.TokenSource); + #endregion Identity 2 + + // TODO: Assert.AreEqual(CertificateCache.Count, 2); + } + } + + [DataTestMethod] + [DataRow(UserAssignedIdentityId.None, null)] // SAMI + [DataRow(UserAssignedIdentityId.ClientId, TestConstants.ClientId)] // UAMI + [DataRow(UserAssignedIdentityId.ResourceId, TestConstants.MiResourceId)] // UAMI + [DataRow(UserAssignedIdentityId.ObjectId, TestConstants.ObjectId)] // UAMI + public async Task mTLSPopTokenIsReAcquiredWhenCertificatIsExpired( + UserAssignedIdentityId userAssignedIdentityId, + string userAssignedId) + { + using (var httpManager = new MockHttpManager()) + { + var managedIdentityApp = await CreateManagedIdentityAsync(httpManager, userAssignedIdentityId, userAssignedId).ConfigureAwait(false); + + AddMocksToGetEntraToken(httpManager, userAssignedIdentityId, userAssignedId, TestConstants.ExpiredPemCertificate/*, mTLSPop: true*/); // TODO: implement mTLS Pop + + var result = await managedIdentityApp.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + // .WithMtlsProofOfPossession() // TODO: implement mTLS Pop + .ExecuteAsync().ConfigureAwait(false); + + Assert.IsNotNull(result); + Assert.IsNotNull(result.AccessToken); + // Assert.AreEqual(result.TokenType, MTLSPoP); // TODO: implement mTLS Pop + // Assert.IsNotNull(result.BindingCertificate); // TODO: implement mTLS Pop + Assert.AreEqual(TokenSource.IdentityProvider, result.AuthenticationResultMetadata.TokenSource); + + // TODO: Add functionality to check cert expiration in the cache + /** + AddMocksToGetEntraToken(httpManager, userAssignedIdentityId, userAssignedId, // mTLSPop: true); // TODO: implement mTLS Pop + + result = await managedIdentityApp.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + // .WithMtlsProofOfPossession() // TODO: implement mTLS Pop + .ExecuteAsync().ConfigureAwait(false); + + Assert.IsNotNull(result); + Assert.IsNotNull(result.AccessToken); + // Assert.AreEqual(result.TokenType, MTLSPoP); // TODO: implement mTLS Pop + // Assert.IsNotNull(result.BindingCertificate); // TODO: implement mTLS Pop + Assert.AreEqual(TokenSource.IdentityProvider, result.AuthenticationResultMetadata.TokenSource); + + Assert.AreEqual(CertificateCache.Count, 1); // expired cert was removed from the cache + */ } } + #endregion mTLS Pop Token Tests + #endregion Acceptance Tests [TestMethod] public async Task GetCsrMetadataAsyncSucceeds() @@ -117,17 +397,7 @@ public async Task GetCsrMetadataAsyncSucceeds() { var handler = httpManager.AddMockHandler(MockHelpers.MockCsrResponse()); - var managedIdentityApp = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) - .WithHttpManager(httpManager) - .WithRetryPolicyFactory(_testRetryPolicyFactory) - .Build(); - - var miSource = await (managedIdentityApp as ManagedIdentityApplication).GetManagedIdentitySourceAsync().ConfigureAwait(false); - Assert.AreEqual(ManagedIdentitySource.ImdsV2, miSource); - - Assert.IsTrue(handler.ActualRequestHeaders.Contains("Metadata")); - Assert.IsTrue(handler.ActualRequestHeaders.Contains("x-ms-client-request-id")); - Assert.IsTrue(handler.ActualRequestMessage.RequestUri.Query.Contains("api-version")); + await CreateManagedIdentityAsync(httpManager, addProbeMock: false).ConfigureAwait(false); } } @@ -139,16 +409,8 @@ public async Task GetCsrMetadataAsyncSucceedsAfterRetry() // First attempt fails with INTERNAL_SERVER_ERROR (500) httpManager.AddMockHandler(MockHelpers.MockCsrResponse(HttpStatusCode.InternalServerError)); - // Second attempt succeeds - httpManager.AddMockHandler(MockHelpers.MockCsrResponse()); - - var managedIdentityApp = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) - .WithHttpManager(httpManager) - .WithRetryPolicyFactory(_testRetryPolicyFactory) - .Build(); - - var miSource = await (managedIdentityApp as ManagedIdentityApplication).GetManagedIdentitySourceAsync().ConfigureAwait(false); - Assert.AreEqual(ManagedIdentitySource.ImdsV2, miSource); + // Second attempt succeeds (defined inside of CreateSAMIAsync) + await CreateManagedIdentityAsync(httpManager).ConfigureAwait(false); } } @@ -159,10 +421,7 @@ public async Task GetCsrMetadataAsyncFailsWithMissingServerHeader() { httpManager.AddMockHandler(MockHelpers.MockCsrResponse(responseServerHeader: null)); - var managedIdentityApp = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) - .WithHttpManager(httpManager) - .WithRetryPolicyFactory(_testRetryPolicyFactory) - .Build(); + var managedIdentityApp = await CreateManagedIdentityAsync(httpManager, addProbeMock: false, addSourceCheck: false).ConfigureAwait(false); var miSource = await (managedIdentityApp as ManagedIdentityApplication).GetManagedIdentitySourceAsync().ConfigureAwait(false); Assert.AreEqual(ManagedIdentitySource.DefaultToImds, miSource); @@ -176,10 +435,7 @@ public async Task GetCsrMetadataAsyncFailsWithInvalidFormat() { httpManager.AddMockHandler(MockHelpers.MockCsrResponse(responseServerHeader: "I_MDS/150.870.65.1854")); - var managedIdentityApp = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) - .WithHttpManager(httpManager) - .WithRetryPolicyFactory(_testRetryPolicyFactory) - .Build(); + var managedIdentityApp = await CreateManagedIdentityAsync(httpManager, addProbeMock: false, addSourceCheck: false).ConfigureAwait(false); var miSource = await (managedIdentityApp as ManagedIdentityApplication).GetManagedIdentitySourceAsync().ConfigureAwait(false); Assert.AreEqual(ManagedIdentitySource.DefaultToImds, miSource); @@ -191,17 +447,14 @@ public async Task GetCsrMetadataAsyncFailsAfterMaxRetries() { using (var httpManager = new MockHttpManager()) { - var managedIdentityApp = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) - .WithHttpManager(httpManager) - .WithRetryPolicyFactory(_testRetryPolicyFactory) - .Build(); - const int Num500Errors = 1 + TestCsrMetadataProbeRetryPolicy.ExponentialStrategyNumRetries; for (int i = 0; i < Num500Errors; i++) { httpManager.AddMockHandler(MockHelpers.MockCsrResponse(HttpStatusCode.InternalServerError)); } + var managedIdentityApp = await CreateManagedIdentityAsync(httpManager, addProbeMock: false, addSourceCheck: false).ConfigureAwait(false); + var miSource = await (managedIdentityApp as ManagedIdentityApplication).GetManagedIdentitySourceAsync().ConfigureAwait(false); Assert.AreEqual(ManagedIdentitySource.DefaultToImds, miSource); } @@ -212,13 +465,10 @@ public async Task GetCsrMetadataAsyncFails404WhichIsNonRetriableAndRetryPolicyIs { using (var httpManager = new MockHttpManager()) { - var managedIdentityApp = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) - .WithHttpManager(httpManager) - .WithRetryPolicyFactory(_testRetryPolicyFactory) - .Build(); - httpManager.AddMockHandler(MockHelpers.MockCsrResponse(HttpStatusCode.NotFound)); + var managedIdentityApp = await CreateManagedIdentityAsync(httpManager, addProbeMock: false, addSourceCheck: false).ConfigureAwait(false); + var miSource = await (managedIdentityApp as ManagedIdentityApplication).GetManagedIdentitySourceAsync().ConfigureAwait(false); Assert.AreEqual(ManagedIdentitySource.DefaultToImds, miSource); } @@ -268,19 +518,9 @@ public void TestCsrGeneration_MalformedPem_ArgumentException(string malformedPem } #region AttachPrivateKeyToCert Tests - [TestMethod] public void AttachPrivateKeyToCert_ValidInputs_ReturnsValidCertificate() { - using var httpManager = new MockHttpManager(); - var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) - .WithHttpManager(httpManager) - .WithRetryPolicyFactory(_testRetryPolicyFactory); - var managedIdentityApp = miBuilder.BuildConcrete(); - - var requestContext = new RequestContext(managedIdentityApp.ServiceBundle, Guid.NewGuid(), null); - var imdsV2Source = new ImdsV2ManagedIdentitySource(requestContext); - using (RSA rsa = RSA.Create()) { // For this test, we just want to verify that the method doesn't crash @@ -297,15 +537,6 @@ public void AttachPrivateKeyToCert_ValidInputs_ReturnsValidCertificate() [TestMethod] public void AttachPrivateKeyToCert_NullCertificatePem_ThrowsArgumentNullException() { - using var httpManager = new MockHttpManager(); - var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) - .WithHttpManager(httpManager) - .WithRetryPolicyFactory(_testRetryPolicyFactory); - var managedIdentityApp = miBuilder.BuildConcrete(); - - var requestContext = new RequestContext(managedIdentityApp.ServiceBundle, Guid.NewGuid(), null); - var imdsV2Source = new ImdsV2ManagedIdentitySource(requestContext); - using (RSA rsa = RSA.Create()) { Assert.ThrowsException(() => @@ -316,15 +547,6 @@ public void AttachPrivateKeyToCert_NullCertificatePem_ThrowsArgumentNullExceptio [TestMethod] public void AttachPrivateKeyToCert_EmptyCertificatePem_ThrowsArgumentNullException() { - using var httpManager = new MockHttpManager(); - var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) - .WithHttpManager(httpManager) - .WithRetryPolicyFactory(_testRetryPolicyFactory); - var managedIdentityApp = miBuilder.BuildConcrete(); - - var requestContext = new RequestContext(managedIdentityApp.ServiceBundle, Guid.NewGuid(), null); - var imdsV2Source = new ImdsV2ManagedIdentitySource(requestContext); - using (RSA rsa = RSA.Create()) { Assert.ThrowsException(() => @@ -335,15 +557,6 @@ public void AttachPrivateKeyToCert_EmptyCertificatePem_ThrowsArgumentNullExcepti [TestMethod] public void AttachPrivateKeyToCert_NullPrivateKey_ThrowsArgumentNullException() { - using var httpManager = new MockHttpManager(); - var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) - .WithHttpManager(httpManager) - .WithRetryPolicyFactory(_testRetryPolicyFactory); - var managedIdentityApp = miBuilder.BuildConcrete(); - - var requestContext = new RequestContext(managedIdentityApp.ServiceBundle, Guid.NewGuid(), null); - var imdsV2Source = new ImdsV2ManagedIdentitySource(requestContext); - Assert.ThrowsException(() => CommonCryptographyManager.AttachPrivateKeyToCert(TestConstants.ValidPemCertificate, null)); } @@ -353,15 +566,6 @@ public void AttachPrivateKeyToCert_InvalidPemFormat_ThrowsArgumentException() { const string InvalidPemNoCertMarker = @"This is not a valid PEM certificate"; - using var httpManager = new MockHttpManager(); - var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) - .WithHttpManager(httpManager) - .WithRetryPolicyFactory(_testRetryPolicyFactory); - var managedIdentityApp = miBuilder.BuildConcrete(); - - var requestContext = new RequestContext(managedIdentityApp.ServiceBundle, Guid.NewGuid(), null); - var imdsV2Source = new ImdsV2ManagedIdentitySource(requestContext); - using (RSA rsa = RSA.Create()) { Assert.ThrowsException(() => @@ -375,15 +579,6 @@ public void AttachPrivateKeyToCert_MissingBeginMarker_ThrowsArgumentException() const string InvalidPemMissingBeginMarker = @"MIICXTCCAUWgAwIBAgIJAKPiQh26MIuPMA0GCSqGSIb3DQEBCwUAMEUxCzAJBgNV -----END CERTIFICATE-----"; - using var httpManager = new MockHttpManager(); - var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) - .WithHttpManager(httpManager) - .WithRetryPolicyFactory(_testRetryPolicyFactory); - var managedIdentityApp = miBuilder.BuildConcrete(); - - var requestContext = new RequestContext(managedIdentityApp.ServiceBundle, Guid.NewGuid(), null); - var imdsV2Source = new ImdsV2ManagedIdentitySource(requestContext); - using (RSA rsa = RSA.Create()) { Assert.ThrowsException(() => @@ -396,15 +591,7 @@ public void AttachPrivateKeyToCert_MissingEndMarker_ThrowsArgumentException() { const string InvalidPemMissingEndMarker = @"-----BEGIN CERTIFICATE----- MIICXTCCAUWgAwIBAgIJAKPiQh26MIuPMA0GCSqGSIb3DQEBCwUAMEUxCzAJBgNV"; - using var httpManager = new MockHttpManager(); - var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) - .WithHttpManager(httpManager) - .WithRetryPolicyFactory(_testRetryPolicyFactory); - var managedIdentityApp = miBuilder.BuildConcrete(); - - var requestContext = new RequestContext(managedIdentityApp.ServiceBundle, Guid.NewGuid(), null); - var imdsV2Source = new ImdsV2ManagedIdentitySource(requestContext); - + using (RSA rsa = RSA.Create()) { Assert.ThrowsException(() => @@ -419,22 +606,12 @@ public void AttachPrivateKeyToCert_BadBase64Content_ThrowsFormatException() Invalid@#$%Base64Content! -----END CERTIFICATE-----"; - using var httpManager = new MockHttpManager(); - var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) - .WithHttpManager(httpManager) - .WithRetryPolicyFactory(_testRetryPolicyFactory); - var managedIdentityApp = miBuilder.BuildConcrete(); - - var requestContext = new RequestContext(managedIdentityApp.ServiceBundle, Guid.NewGuid(), null); - var imdsV2Source = new ImdsV2ManagedIdentitySource(requestContext); - using (RSA rsa = RSA.Create()) { Assert.ThrowsException(() => CommonCryptographyManager.AttachPrivateKeyToCert(InvalidPemBadBase64, rsa)); } } - #endregion } }