From bc9670b8776bea0cfa97747236b96ff6e91aa49c Mon Sep 17 00:00:00 2001 From: Robbie Ginsburg Date: Fri, 5 Sep 2025 11:32:09 -0400 Subject: [PATCH 1/5] Added headers --- .../V2/ImdsV2ManagedIdentitySource.cs | 22 ++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs index d329bfbfa8..2c90eab268 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs @@ -10,6 +10,7 @@ using Microsoft.Identity.Client.Http; using Microsoft.Identity.Client.Http.Retry; using Microsoft.Identity.Client.Internal; +using Microsoft.Identity.Client.OAuth2.Throttling; using Microsoft.Identity.Client.PlatformsCommon.Shared; using Microsoft.Identity.Client.Utils; @@ -143,7 +144,7 @@ private static bool ValidateCsrMetadataResponse( $"ImdsV2ManagedIdentitySource.GetCsrMetadataAsync failed because response doesn't have server header. Status code: {response.StatusCode} Body: {response.Body}", null, (int)response.StatusCode); - } + } } var match = System.Text.RegularExpressions.Regex.Match( @@ -206,7 +207,7 @@ private async Task ExecuteCertificateRequestAsync(st { "Metadata", "true" }, { "x-ms-client-request-id", _requestContext.CorrelationId.ToString() } }; - + var body = $"{{\"csr\":\"{csr}\"}}"; IRetryPolicyFactory retryPolicyFactory = _requestContext.ServiceBundle.Config.RetryPolicyFactory; @@ -261,18 +262,29 @@ protected override async Task CreateRequestAsync(string var (csr, privateKey) = _requestContext.ServiceBundle.Config.CsrFactory.Generate(csrMetadata.ClientId, csrMetadata.TenantId, csrMetadata.CuId); var certificateRequestResponse = await ExecuteCertificateRequestAsync(csr).ConfigureAwait(false); - + // transform certificateRequestResponse.Certificate to x509 with private key var mtlsCertificate = CommonCryptographyManager.AttachPrivateKeyToCert( certificateRequestResponse.Certificate, privateKey); ManagedIdentityRequest request = new(HttpMethod.Post, new Uri($"{certificateRequestResponse.MtlsAuthenticationEndpoint}/{certificateRequestResponse.TenantId}{AcquireEntraTokenPath}")); + + var idParams = MsalIdHelper.GetMsalIdParameters(_requestContext.Logger); + foreach (var idParam in idParams) + { + request.Headers[idParam.Key] = idParam.Value; + } request.Headers.Add("x-ms-client-request-id", _requestContext.CorrelationId.ToString()); + request.Headers.Add(ThrottleCommon.ThrottleRetryAfterHeaderName, ThrottleCommon.ThrottleRetryAfterHeaderValue); + request.BodyParameters.Add("client_id", certificateRequestResponse.ClientId); - request.BodyParameters.Add("grant_type", certificateRequestResponse.Certificate); + request.BodyParameters.Add("grant_type", "client_credentials"); request.BodyParameters.Add("scope", "https://management.azure.com/.default"); - request.RequestType = RequestType.Imds; + request.BodyParameters.Add("token_type", "bearer"); + + request.RequestType = RequestType.STS; + request.MtlsCertificate = mtlsCertificate; return request; From 11a846379434941f7450ba656eed211ef8d03bd2 Mon Sep 17 00:00:00 2001 From: Robbie Ginsburg Date: Fri, 5 Sep 2025 14:09:26 -0400 Subject: [PATCH 2/5] Checking headers via unit tests --- .../V2/ImdsV2ManagedIdentitySource.cs | 3 +- .../Core/Mocks/MockHelpers.cs | 30 +++++++++++++++++++ .../ManagedIdentityTests/ImdsV2Tests.cs | 23 +++++++------- 3 files changed, 44 insertions(+), 12 deletions(-) diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs index 2c90eab268..907c680e5a 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs @@ -194,7 +194,8 @@ public static AbstractManagedIdentity Create(RequestContext requestContext) } internal ImdsV2ManagedIdentitySource(RequestContext requestContext) : - base(requestContext, ManagedIdentitySource.ImdsV2) { } + base(requestContext, ManagedIdentitySource.ImdsV2) + { } private async Task ExecuteCertificateRequestAsync(string csr) { diff --git a/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHelpers.cs b/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHelpers.cs index 04665fc0dd..f3dbdd3f52 100644 --- a/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHelpers.cs +++ b/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHelpers.cs @@ -11,9 +11,12 @@ using Castle.Core.Logging; using Microsoft.Identity.Client; using Microsoft.Identity.Client.AppConfig; +using Microsoft.Identity.Client.Internal; +using Microsoft.Identity.Client.Internal.Logger; using Microsoft.Identity.Client.ManagedIdentity; using Microsoft.Identity.Client.ManagedIdentity.V2; using Microsoft.Identity.Client.OAuth2; +using Microsoft.Identity.Client.OAuth2.Throttling; using Microsoft.Identity.Client.Utils; using Microsoft.Identity.Test.Unit; using Microsoft.VisualStudio.TestTools.UnitTesting.Logging; @@ -675,5 +678,32 @@ public static MockHttpMessageHandler MockCertificateRequestResponse( return handler; } + + public static MockHttpMessageHandler MockImdsV2EntraTokenRequestResponse( + IdentityLoggerAdapter identityLoggerAdapter) + { + IDictionary expectedRequestHeaders = new Dictionary + { + { ThrottleCommon.ThrottleRetryAfterHeaderName, ThrottleCommon.ThrottleRetryAfterHeaderValue } + }; + var idParams = MsalIdHelper.GetMsalIdParameters(identityLoggerAdapter); + foreach (var idParam in idParams) + { + expectedRequestHeaders[idParam.Key] = idParam.Value; + } + + var handler = new MockHttpMessageHandler() + { + ExpectedUrl = $"{TestConstants.MtlsAuthenticationEndpoint}/{TestConstants.TenantId}{ImdsV2ManagedIdentitySource.AcquireEntraTokenPath}", + ExpectedMethod = HttpMethod.Post, + ExpectedRequestHeaders = expectedRequestHeaders, + ResponseMessage = new HttpResponseMessage(HttpStatusCode.OK) + { + Content = new StringContent(GetMsiSuccessfulResponse()), + } + }; + + return handler; + } } } diff --git a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs index 25322ec08b..a569b7b043 100644 --- a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs +++ b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs @@ -8,11 +8,13 @@ using Microsoft.Identity.Client; using Microsoft.Identity.Client.AppConfig; using Microsoft.Identity.Client.Internal; +using Microsoft.Identity.Client.Internal.Logger; using Microsoft.Identity.Client.ManagedIdentity; using Microsoft.Identity.Client.ManagedIdentity.V2; using Microsoft.Identity.Client.PlatformsCommon.Shared; using Microsoft.Identity.Test.Common.Core.Mocks; using Microsoft.Identity.Test.Unit.Helpers; +using Microsoft.Identity.Test.Unit.PublicApiTests; using Microsoft.VisualStudio.TestTools.UnitTesting; using static Microsoft.Identity.Test.Common.Core.Helpers.ManagedIdentityTestUtil; @@ -23,6 +25,13 @@ public class ImdsV2Tests : TestBase { private readonly TestRetryPolicyFactory _testRetryPolicyFactory = new TestRetryPolicyFactory(); private readonly TestCsrFactory _testCsrFactory = new TestCsrFactory(); + private readonly IdentityLoggerAdapter _identityLoggerAdapter = new IdentityLoggerAdapter( + new TestIdentityLogger(), + Guid.Empty, + "TestClient", + "1.0.0", + enablePiiLogging: false + ); [TestMethod] public async Task ImdsV2SAMIHappyPathAsync() @@ -42,12 +51,8 @@ public async Task ImdsV2SAMIHappyPathAsync() httpManager.AddMockHandler(MockHelpers.MockCsrResponse()); // initial probe httpManager.AddMockHandler(MockHelpers.MockCsrResponse()); // do it again, since CsrMetadata from initial probe is not cached httpManager.AddMockHandler(MockHelpers.MockCertificateRequestResponse()); - httpManager.AddManagedIdentityMockHandler( - $"{TestConstants.MtlsAuthenticationEndpoint}/{TestConstants.TenantId}{ImdsV2ManagedIdentitySource.AcquireEntraTokenPath}", - ManagedIdentityTests.Resource, - MockHelpers.GetMsiSuccessfulResponse(), - ManagedIdentitySource.ImdsV2); - + httpManager.AddMockHandler(MockHelpers.MockImdsV2EntraTokenRequestResponse(_identityLoggerAdapter)); + var result = await mi.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) .ExecuteAsync().ConfigureAwait(false); @@ -88,11 +93,7 @@ public async Task ImdsV2UAMIHappyPathAsync( httpManager.AddMockHandler(MockHelpers.MockCsrResponse(idType: userAssignedIdentityId, userAssignedId: userAssignedId)); // initial probe httpManager.AddMockHandler(MockHelpers.MockCsrResponse(idType: userAssignedIdentityId, userAssignedId: userAssignedId)); // do it again, since CsrMetadata from initial probe is not cached httpManager.AddMockHandler(MockHelpers.MockCertificateRequestResponse(userAssignedIdentityId, userAssignedId)); - httpManager.AddManagedIdentityMockHandler( - $"{TestConstants.MtlsAuthenticationEndpoint}/{TestConstants.TenantId}{ImdsV2ManagedIdentitySource.AcquireEntraTokenPath}", - ManagedIdentityTests.Resource, - MockHelpers.GetMsiSuccessfulResponse(), - ManagedIdentitySource.ImdsV2); + httpManager.AddMockHandler(MockHelpers.MockImdsV2EntraTokenRequestResponse(_identityLoggerAdapter)); var result = await mi.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) .ExecuteAsync().ConfigureAwait(false); From 7aa4c9db5a8d5eb14069d6a3e037e9e8d34c3ab5 Mon Sep 17 00:00:00 2001 From: Robbie Ginsburg Date: Mon, 8 Sep 2025 18:16:46 -0400 Subject: [PATCH 3/5] Deleted dead code --- .../Core/Mocks/MockHttpManagerExtensions.cs | 9 --------- 1 file changed, 9 deletions(-) diff --git a/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHttpManagerExtensions.cs b/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHttpManagerExtensions.cs index 017213d275..7f8667d93f 100644 --- a/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHttpManagerExtensions.cs +++ b/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHttpManagerExtensions.cs @@ -460,15 +460,6 @@ private static MockHttpMessageHandler BuildMockHandlerForManagedIdentitySource( expectedQueryParams.Add("resource", resource); expectedRequestHeaders.Add("Metadata", "true"); break; - case ManagedIdentitySource.ImdsV2: - httpMessageHandler.ExpectedMethod = HttpMethod.Post; - expectedPostData = new Dictionary - { - { "client_id", TestConstants.ClientId }, - { "grant_type", TestConstants.ValidPemCertificate }, - { "scope", resource } - }; - break; case ManagedIdentitySource.CloudShell: httpMessageHandler.ExpectedMethod = HttpMethod.Post; expectedRequestHeaders.Add("Metadata", "true"); From 818276239ec36a4b15a14756b93f3baf0ea94d6f Mon Sep 17 00:00:00 2001 From: Robbie-Microsoft <87724641+Robbie-Microsoft@users.noreply.github.com> Date: Thu, 11 Sep 2025 15:35:39 -0400 Subject: [PATCH 4/5] ImdsV2: Additional Acceptance Tests (#5465) --- .../Core/Mocks/MockHelpers.cs | 28 +- .../TestConstants.cs | 31 +- .../ManagedIdentityTests/ImdsV2Tests.cs | 474 ++++++++++++------ 3 files changed, 371 insertions(+), 162 deletions(-) diff --git a/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHelpers.cs b/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHelpers.cs index f3dbdd3f52..df4f8b3a56 100644 --- a/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHelpers.cs +++ b/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHelpers.cs @@ -595,14 +595,15 @@ public static MsalTokenResponse CreateMsalRunTimeBrokerTokenResponse(string acce public static MockHttpMessageHandler MockCsrResponse( HttpStatusCode statusCode = HttpStatusCode.OK, string responseServerHeader = "IMDS/150.870.65.1854", - UserAssignedIdentityId idType = UserAssignedIdentityId.None, + UserAssignedIdentityId userAssignedIdentityId = UserAssignedIdentityId.None, string userAssignedId = null) { IDictionary expectedQueryParams = new Dictionary(); IDictionary expectedRequestHeaders = new Dictionary(); - if (idType != UserAssignedIdentityId.None && userAssignedId != null) + + if (userAssignedIdentityId != UserAssignedIdentityId.None && userAssignedId != null) { - var userAssignedIdQueryParam = ImdsManagedIdentitySource.GetUserAssignedIdQueryParam((ManagedIdentityIdType)idType, userAssignedId, null); + var userAssignedIdQueryParam = ImdsManagedIdentitySource.GetUserAssignedIdQueryParam((ManagedIdentityIdType)userAssignedIdentityId, userAssignedId, null); expectedQueryParams.Add(userAssignedIdQueryParam.Value.Key, userAssignedIdQueryParam.Value.Value); } expectedQueryParams.Add("cred-api-version", "2.0"); @@ -642,14 +643,16 @@ public static MockHttpMessageHandler MockCsrResponseFailure() } public static MockHttpMessageHandler MockCertificateRequestResponse( - UserAssignedIdentityId idType = UserAssignedIdentityId.None, - string userAssignedId = null) + UserAssignedIdentityId userAssignedIdentityId = UserAssignedIdentityId.None, + string userAssignedId = null, + string certificate = TestConstants.ValidPemCertificate) { IDictionary expectedQueryParams = new Dictionary(); IDictionary expectedRequestHeaders = new Dictionary(); - if (idType != UserAssignedIdentityId.None && userAssignedId != null) + + if (userAssignedIdentityId != UserAssignedIdentityId.None && userAssignedId != null) { - var userAssignedIdQueryParam = ImdsManagedIdentitySource.GetUserAssignedIdQueryParam((ManagedIdentityIdType)idType, userAssignedId, null); + var userAssignedIdQueryParam = ImdsManagedIdentitySource.GetUserAssignedIdQueryParam((ManagedIdentityIdType)userAssignedIdentityId, userAssignedId, null); expectedQueryParams.Add(userAssignedIdQueryParam.Value.Key, userAssignedIdQueryParam.Value.Value); } expectedQueryParams.Add("cred-api-version", ImdsV2ManagedIdentitySource.ImdsV2ApiVersion); @@ -659,7 +662,7 @@ public static MockHttpMessageHandler MockCertificateRequestResponse( "{" + "\"client_id\": \"" + TestConstants.ClientId + "\"," + "\"tenant_id\": \"" + TestConstants.TenantId + "\"," + - "\"certificate\": \"" + TestConstants.ValidPemCertificate + "\"," + + "\"certificate\": \"" + certificate + "\"," + "\"identity_type\": \"fake_identity_type\"," + // "SystemAssigned" or "UserAssigned", it doesn't matter for these tests "\"mtls_authentication_endpoint\": \"" + TestConstants.MtlsAuthenticationEndpoint + "\"," + "}"; @@ -680,22 +683,29 @@ public static MockHttpMessageHandler MockCertificateRequestResponse( } public static MockHttpMessageHandler MockImdsV2EntraTokenRequestResponse( - IdentityLoggerAdapter identityLoggerAdapter) + IdentityLoggerAdapter identityLoggerAdapter, + bool mTLSPop = false) { + IDictionary expectedPostData = new Dictionary(); IDictionary expectedRequestHeaders = new Dictionary { { ThrottleCommon.ThrottleRetryAfterHeaderName, ThrottleCommon.ThrottleRetryAfterHeaderValue } }; + var idParams = MsalIdHelper.GetMsalIdParameters(identityLoggerAdapter); foreach (var idParam in idParams) { expectedRequestHeaders[idParam.Key] = idParam.Value; } + var tokenType = mTLSPop ? "mtls_pop" : "bearer"; + expectedPostData.Add("token_type", tokenType); + var handler = new MockHttpMessageHandler() { ExpectedUrl = $"{TestConstants.MtlsAuthenticationEndpoint}/{TestConstants.TenantId}{ImdsV2ManagedIdentitySource.AcquireEntraTokenPath}", ExpectedMethod = HttpMethod.Post, + ExpectedPostData = expectedPostData, ExpectedRequestHeaders = expectedRequestHeaders, ResponseMessage = new HttpResponseMessage(HttpStatusCode.OK) { diff --git a/tests/Microsoft.Identity.Test.Common/TestConstants.cs b/tests/Microsoft.Identity.Test.Common/TestConstants.cs index 35107976c7..7ab49c10e2 100644 --- a/tests/Microsoft.Identity.Test.Common/TestConstants.cs +++ b/tests/Microsoft.Identity.Test.Common/TestConstants.cs @@ -585,15 +585,34 @@ public static MsalTokenResponse CreateAadTestTokenResponseWithFoci() internal const string RefreshToken = "mhJDJ8wtjA3KxpRtuPAreZnMcJ2yKC2JUbpOGbRTdOCImLyQ2B4EIhv8AiA2cCEylZZfZsOsZrNsMBZZAAU9TQYYEO72QcdfnIWpAOeKkud5W2L8nMq6i9dx1EVIl09zFXhOJ79BdFbU0Eb5aUHlcqPCQjec62UKBLkZJmtMnoAa8cjvgIuxTdVM8FNdghe5nlCNTEVooKleTTEHNl2BrdyitLaWTKSP0lRqnFxriG0xWcJoSMsdS7Vt6HZd1TkwHIXycNMlCcCdUh5tOgqx1M8y8uoXK4OJ1LQmtkZvcQWcycvOCPACYakKM1pUQqwTxI6Y4HrL38sqQaSNxpF9OcFxOQWpuGodRekCbxXVbWclttIpvSOLaBhZ2ZBpcCBEeEMSmhqqYgajNwwwe9w88u0UsYKe6PBbaI48ENr02u2qBeLsIQ2HUyKlN3iVmX7u7MhgDWA3NNavMtlLmWd63NfuDgXpLI0O4cLhjAx8uoBIK8LntXPHPTxJ28o0yrszvD4gf7RdhuTq5VE15zne6iAJgIGfy7latGFzxuDMcML9OoXURHnNEHBgS9ZQCfNzYZ2O9flF1UjGpcBLEi7hHVHnrQb4y7c98dz9p62cvEMhorGx9kCwSIkOae5LheXPQkFIbsGyomNEwz3HZvR131VGAwdfmUUodvPr6LAAtmjl4sZ72PRqAo8EdQ0IFsWoypXVv51IooR87tO3uiG2DkxhIAwumOQdaJNxw1a0WS9mpQOmwFlvfbZkaIoUKgagHc8fVa1aHZntLGwH0S1iYixJiIrMnPYAeRdSp9mlHllrMX8xUIznobcZ5i8MpUYCKlUXMZ82S3XUJ5dJxARNRPxXlLJ5LPYBhUNkBLQen9Qmq3VZEV1RDJyhbGp6GAo14KsMtVAVYNmYPIgo85pCZgOwVEOBUycszu4AD3p4PT2ella4LVoqmTTMSA5GEWoeWb5JvEo222Z0oKr7UK8dGwpWRSbg8TNeODihJaTUDfErvbgaZnjIRpqfgtM5i1HfQbD7Yyft5PqyygUra7GYy7pjRrEvq95XQD8sAZ32ku9AqCo5qOB584iX881WErOoheQZokt1txqwuIMUyhVuMKNEXy70CeNTsb30ghQMZpZcXIkrLYyQCZ0gNmARhMKagCSdrpUtxudLk44yfmuwSQzBN3ifWfLZiFpU53qdPLZoTw5"; internal const string IdToken = "6GwdM7f6hHXfivavPozhaRqrbxvEysfXSMQyEKBwVgivPZTtmowsmYygchhIuxjeFFeq1ZPHjhxKFnulrvoY6TDerZY5xyOlg45bToI9Bu95qFvUrrt5r17UJcXdw4YkvEt10CcDDcLcEYw704RpVefvbpjbF24pOgIuafcAkDnbDA0Qea4ePuSC45Lw7zpJhbo9Gh8IfMX597fayBvMs3fh7frrm9KpWMCeKY3h99YSaCYjZFKp1ppvXXPE9bc4sh4pRDOfnv0Yr9J8u4elZevEE4qGddfgd3hYb18XPGRjPEMlWsh7tnwxwUm6OSZlMTHYuvwBENNMx7SUQmMeg4rCfgnbcNDkWpXCiSDVt1lLLv8F2GjYnM6De3v1Ks5lhBWx3grLggcN9LnXz92eJ1l5lTB2v0y9MgmFZ4gY43oIOW5n8G5HOx3bGOyjTw0TKKbyVa3mDj0A3QqW8eLTUJz42BNiGOf5m9prMSlpAW59CHCMJLatsj3IvGeCITsGAr3sUZEytORWUdxCfuIPwecQgU6bO7pNqNvZc1tJHHNwJlfS23ZkiFuEXqEThHYfxBCFxAzMDlzO0TOdWhvrb8hlNeAOcNhoAKxu7HXsePajKs4fU1rcdSxzNKwtASEla3p6jfJnnDtKf38RJZPaRRYMviqqWEMhjmqIvBm7sMaf8RyNNuYl7otZwmwNVCR1hzzmaTAy4kQce67FJqFba7uizrgwp9zsvK8muCHKKPvNthy7fHsxKmrBIm0bLcoePKK3wAID4kFvNQcxXp6rAOr8bLFF3bLEoYdzmF2QJz1frVZZHHPy90Cmlhw48EQN8NE2OllpdaykKt5k4rPcZQyitayNNhism30qh7eCBhcA7mm5Ja0S8X4VPlkwvgwg0mQuul6gakmja8xpnTrwiOdtao320GDmJaJA6zf3UTpNZTq9tdfBtUrjAD8RS0tNUBT3Ko8N2Lfh9ry8y9ESmRVIhch3rKY7UeefFAnkiwH2WwC57ZEsHtMP0SwKYtYKHZW9HkERCCyqOT1Mw0IavsLGFvchzMAvTnz4RwRBk6IrWgANvqT3F3Vexc2K0poKb71XZ4aMXxjqAzydGQAKpKJEJcqEvX9RD8nL76TF2LZIepiaZ3dbQImkqSjbF7aaY2JFoN9ZWlcSQKe8zdO8TIG16bF8W9R4ldDyzV39L33KcweG"; - #region Test Certificate and Private Key (ValidPemCertificate & XmlPrivateKey) + #region Test Certificate and Private Key (ExpiredPemCertificate, ValidPemCertificate & XmlPrivateKey) /// - /// A test PEM-encoded X.509 certificate and its matching RSA private key. + /// Test (expired and valid) PEM-encoded X.509 certificate and their matching RSA private key. /// These are used together in unit tests that require both a certificate and its private key. - /// The and are a matched pair: - /// - is a PEM-encoded certificate. - /// - is the corresponding RSA private key in XML format. - /// The certificate is valid for 100 years, ensuring it will not expire during the lifetime of the tests. + /// The / and are a matched pair: + /// - is an expired PEM-encoded certificate. The certificate is valid for 1 day and was created on September 8 2025, ensuring it will always be expired. + /// - is a valid PEM-encoded certificate. The certificate is valid for 100 years and expires on August 4, 2125, ensuring it will not expire during the lifetime of the tests. + /// - is their corresponding RSA private key in XML format. /// + internal const string ExpiredPemCertificate = @"-----BEGIN CERTIFICATE----- +MIIC/zCCAeegAwIBAgIUGSVU23Wc0+QtCbUTjsyPOrc0XpEwDQYJKoZIhvcNAQEL +BQAwDzENMAsGA1UEAwwEVGVzdDAeFw0yNTA5MDgyMjAxMTdaFw0yNTA5MDkyMjAx +MTdaMA8xDTALBgNVBAMMBFRlc3QwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEK +AoIBAQC5XNEuk3cIEChkZd2P/bljUaVqNVh4mbXdWHYAgbdK48U6rG0FLq1NAfSn +ZO0EPbK8Zo4psRh2lBcqW29/WsKiHUEHLkLyFI+frEIfc8wskd+WxkKfL8G52uRp +YQCG87FIv8uZBBlDG7kDdOV36CUkK1N+V2fHbkEgx+YfWg6+pLi3KQx6Pf/b2YqL +D36hj8WRrVYzL6yXVUBiyRd+cQ9y5V/MRtoiX1Sv8WEFYtzIG0TUGi9pR7WWhgHN +Qk6DFDzutMV62ZEBNPIQvdO2EwXGr1FUIOL6zmj6bArPhY+hCXGrAAwCXodZhgZ9 +5BxTwsQWtjCha2hT6ed8zmoE72FdAgMBAAGjUzBRMB0GA1UdDgQWBBQPYq0Efzuv +1diVcgxBxTnVA4wLMjAfBgNVHSMEGDAWgBQPYq0Efzuv1diVcgxBxTnVA4wLMjAP +BgNVHRMBAf8EBTADAQH/MA0GCSqGSIb3DQEBCwUAA4IBAQCXAD7cjWmmTqP0NX4M +qwO0AHtO+KGVtfxF8aI21Ty/nHh2SAODzsemP3NBBvoEvllwtcVyutPqvUiAflML +Nbp0ucTu+aWE14s1V9Bnt6++5g7gtXItsNV3F/ymYKsyfhDvJbWCOv5qYeJMQ+jt +ODHN9qnATODT5voULTwEVSYQXtutwRxR8e70Cvok+F+4I6Ni49DJ8DmcYzvB94ut +hqpDsygY1vYzpRbB5hpW0/D7kgVVWyWoOWiE1mV7Fry7tUWQw7EqnX89kMLMy4g6 +UfOv4gtam8RBa9dLyMW1rCHRxOulP47joI10g9JoJ9DssiQTUojJgQXOSBBXdD20 +H+zl +-----END CERTIFICATE-----"; internal const string ValidPemCertificate = @"-----BEGIN CERTIFICATE----- MIIDATCCAemgAwIBAgIUSfjghyQB4FIS41rWfNcZHTLE/R4wDQYJKoZIhvcNAQEL BQAwDzENMAsGA1UEAwwEVGVzdDAgFw0yNTA4MjgyMDIxMDBaGA8yMTI1MDgwNDIw diff --git a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs index a569b7b043..9dc5ced4c2 100644 --- a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs +++ b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs @@ -7,7 +7,6 @@ using System.Threading.Tasks; using Microsoft.Identity.Client; using Microsoft.Identity.Client.AppConfig; -using Microsoft.Identity.Client.Internal; using Microsoft.Identity.Client.Internal.Logger; using Microsoft.Identity.Client.ManagedIdentity; using Microsoft.Identity.Client.ManagedIdentity.V2; @@ -32,85 +31,365 @@ public class ImdsV2Tests : TestBase "1.0.0", enablePiiLogging: false ); + public const string Bearer = "Bearer"; + public const string MTLSPoP = "MTLSPoP"; + + private void AddMocksToGetEntraToken( + MockHttpManager httpManager, + UserAssignedIdentityId userAssignedIdentityId = UserAssignedIdentityId.None, + string userAssignedId = null, + string certificateRequestCertificate = TestConstants.ValidPemCertificate, + bool mTLSPop = false) + { + if (userAssignedIdentityId != UserAssignedIdentityId.None && userAssignedId != null) + { + httpManager.AddMockHandler(MockHelpers.MockCsrResponse(userAssignedIdentityId: userAssignedIdentityId, userAssignedId: userAssignedId)); + httpManager.AddMockHandler(MockHelpers.MockCertificateRequestResponse(userAssignedIdentityId, userAssignedId, certificateRequestCertificate)); + } + else + { + httpManager.AddMockHandler(MockHelpers.MockCsrResponse()); + httpManager.AddMockHandler(MockHelpers.MockCertificateRequestResponse(certificate: certificateRequestCertificate)); + } + + httpManager.AddMockHandler(MockHelpers.MockImdsV2EntraTokenRequestResponse(_identityLoggerAdapter, mTLSPop)); + } - [TestMethod] - public async Task ImdsV2SAMIHappyPathAsync() + private async Task CreateManagedIdentityAsync( + MockHttpManager httpManager, + UserAssignedIdentityId userAssignedIdentityId = UserAssignedIdentityId.None, + string userAssignedId = null, + bool addProbeMock = true, + bool addSourceCheck = true) + { + ManagedIdentityApplicationBuilder miBuilder = null; + + var uami = userAssignedIdentityId != UserAssignedIdentityId.None && userAssignedId != null; + if (uami) + { + miBuilder = CreateMIABuilder(userAssignedId, userAssignedIdentityId); + } + else + { + miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned); + } + + miBuilder + .WithHttpManager(httpManager) + .WithRetryPolicyFactory(_testRetryPolicyFactory) + .WithCsrFactory(_testCsrFactory); + + // Disabling shared cache options to avoid cross test pollution. + miBuilder.Config.AccessorOptions = null; + + var managedIdentityApp = miBuilder.Build(); + + if (addProbeMock) + { + if (uami) + { + httpManager.AddMockHandler(MockHelpers.MockCsrResponse(userAssignedIdentityId: userAssignedIdentityId, userAssignedId: userAssignedId)); + } + else + { + httpManager.AddMockHandler(MockHelpers.MockCsrResponse()); + } + } + + if (addSourceCheck) + { + var miSource = await (managedIdentityApp as ManagedIdentityApplication).GetManagedIdentitySourceAsync().ConfigureAwait(false); + Assert.AreEqual(ManagedIdentitySource.ImdsV2, miSource); + } + + return managedIdentityApp; + } + + #region Acceptance Tests + #region Bearer Token Tests + [DataTestMethod] + [DataRow(UserAssignedIdentityId.None, null)] // SAMI + [DataRow(UserAssignedIdentityId.ClientId, TestConstants.ClientId)] // UAMI + [DataRow(UserAssignedIdentityId.ResourceId, TestConstants.MiResourceId)] // UAMI + [DataRow(UserAssignedIdentityId.ObjectId, TestConstants.ObjectId)] // UAMI + public async Task BearerTokenHappyPath( + UserAssignedIdentityId userAssignedIdentityId, + string userAssignedId) + { + using (var httpManager = new MockHttpManager()) + { + var managedIdentityApp = await CreateManagedIdentityAsync(httpManager, userAssignedIdentityId, userAssignedId).ConfigureAwait(false); + + AddMocksToGetEntraToken(httpManager, userAssignedIdentityId, userAssignedId); + + var result = await managedIdentityApp.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + .ExecuteAsync().ConfigureAwait(false); + + Assert.IsNotNull(result); + Assert.IsNotNull(result.AccessToken); + Assert.AreEqual(result.TokenType, Bearer); + Assert.AreEqual(TokenSource.IdentityProvider, result.AuthenticationResultMetadata.TokenSource); + + result = await managedIdentityApp.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + .ExecuteAsync().ConfigureAwait(false); + + Assert.IsNotNull(result); + Assert.IsNotNull(result.AccessToken); + Assert.AreEqual(result.TokenType, Bearer); + Assert.AreEqual(TokenSource.Cache, result.AuthenticationResultMetadata.TokenSource); + } + } + + [DataTestMethod] + [DataRow(UserAssignedIdentityId.None, null)] // SAMI + [DataRow(UserAssignedIdentityId.ClientId, TestConstants.ClientId)] // UAMI + [DataRow(UserAssignedIdentityId.ResourceId, TestConstants.MiResourceId)] // UAMI + [DataRow(UserAssignedIdentityId.ObjectId, TestConstants.ObjectId)] // UAMI + public async Task BearerTokenTokenIsPerIdentity( + UserAssignedIdentityId userAssignedIdentityId, + string userAssignedId) { using (var httpManager = new MockHttpManager()) { - var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) - .WithHttpManager(httpManager) - .WithRetryPolicyFactory(_testRetryPolicyFactory) - .WithCsrFactory(_testCsrFactory); - - // Disabling shared cache options to avoid cross test pollution. - miBuilder.Config.AccessorOptions = null; - - var mi = miBuilder.Build(); - - httpManager.AddMockHandler(MockHelpers.MockCsrResponse()); // initial probe - httpManager.AddMockHandler(MockHelpers.MockCsrResponse()); // do it again, since CsrMetadata from initial probe is not cached - httpManager.AddMockHandler(MockHelpers.MockCertificateRequestResponse()); - httpManager.AddMockHandler(MockHelpers.MockImdsV2EntraTokenRequestResponse(_identityLoggerAdapter)); - - var result = await mi.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + #region Identity 1 + var managedIdentityApp = await CreateManagedIdentityAsync(httpManager, userAssignedIdentityId, userAssignedId).ConfigureAwait(false); + + AddMocksToGetEntraToken(httpManager, userAssignedIdentityId, userAssignedId); + + var result = await managedIdentityApp.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) .ExecuteAsync().ConfigureAwait(false); Assert.IsNotNull(result); Assert.IsNotNull(result.AccessToken); + Assert.AreEqual(result.TokenType, Bearer); Assert.AreEqual(TokenSource.IdentityProvider, result.AuthenticationResultMetadata.TokenSource); - result = await mi.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + result = await managedIdentityApp.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) .ExecuteAsync().ConfigureAwait(false); Assert.IsNotNull(result); Assert.IsNotNull(result.AccessToken); + Assert.AreEqual(result.TokenType, Bearer); Assert.AreEqual(TokenSource.Cache, result.AuthenticationResultMetadata.TokenSource); + #endregion Identity 1 + + #region Identity 2 + var managedIdentityApp2 = await CreateManagedIdentityAsync(httpManager, userAssignedIdentityId, userAssignedId, addProbeMock: false, addSourceCheck: false).ConfigureAwait(false); // source is already cached + + AddMocksToGetEntraToken(httpManager, userAssignedIdentityId, userAssignedId); + + var result2 = await managedIdentityApp2.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + .ExecuteAsync().ConfigureAwait(false); + + Assert.IsNotNull(result2); + Assert.IsNotNull(result2.AccessToken); + Assert.AreEqual(result.TokenType, Bearer); + Assert.AreEqual(TokenSource.IdentityProvider, result2.AuthenticationResultMetadata.TokenSource); + + result2 = await managedIdentityApp2.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + .ExecuteAsync().ConfigureAwait(false); + + Assert.IsNotNull(result2); + Assert.IsNotNull(result2.AccessToken); + Assert.AreEqual(result.TokenType, Bearer); + Assert.AreEqual(TokenSource.Cache, result2.AuthenticationResultMetadata.TokenSource); + #endregion Identity 2 + + // TODO: Assert.AreEqual(CertificateCache.Count, 2); } } [DataTestMethod] - [DataRow(UserAssignedIdentityId.ClientId, TestConstants.ClientId)] - [DataRow(UserAssignedIdentityId.ResourceId, TestConstants.MiResourceId)] - [DataRow(UserAssignedIdentityId.ObjectId, TestConstants.ObjectId)] - public async Task ImdsV2UAMIHappyPathAsync( + [DataRow(UserAssignedIdentityId.None, null)] // SAMI + [DataRow(UserAssignedIdentityId.ClientId, TestConstants.ClientId)] // UAMI + [DataRow(UserAssignedIdentityId.ResourceId, TestConstants.MiResourceId)] // UAMI + [DataRow(UserAssignedIdentityId.ObjectId, TestConstants.ObjectId)] // UAMI + public async Task BearerTokenIsReAcquiredWhenCertificatIsExpired( UserAssignedIdentityId userAssignedIdentityId, string userAssignedId) { using (var httpManager = new MockHttpManager()) { - var miBuilder = CreateMIABuilder(userAssignedId, userAssignedIdentityId); - miBuilder - .WithHttpManager(httpManager) - .WithRetryPolicyFactory(_testRetryPolicyFactory) - .WithCsrFactory(_testCsrFactory); + var managedIdentityApp = await CreateManagedIdentityAsync(httpManager, userAssignedIdentityId, userAssignedId).ConfigureAwait(false); + + AddMocksToGetEntraToken(httpManager, userAssignedIdentityId, userAssignedId, TestConstants.ExpiredPemCertificate); // cert will be expired on second request - // Disabling shared cache options to avoid cross test pollution. - miBuilder.Config.AccessorOptions = null; + var result = await managedIdentityApp.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + .ExecuteAsync().ConfigureAwait(false); - var mi = miBuilder.Build(); + Assert.IsNotNull(result); + Assert.IsNotNull(result.AccessToken); + Assert.AreEqual(result.TokenType, Bearer); + Assert.AreEqual(TokenSource.IdentityProvider, result.AuthenticationResultMetadata.TokenSource); - httpManager.AddMockHandler(MockHelpers.MockCsrResponse(idType: userAssignedIdentityId, userAssignedId: userAssignedId)); // initial probe - httpManager.AddMockHandler(MockHelpers.MockCsrResponse(idType: userAssignedIdentityId, userAssignedId: userAssignedId)); // do it again, since CsrMetadata from initial probe is not cached - httpManager.AddMockHandler(MockHelpers.MockCertificateRequestResponse(userAssignedIdentityId, userAssignedId)); - httpManager.AddMockHandler(MockHelpers.MockImdsV2EntraTokenRequestResponse(_identityLoggerAdapter)); + // TODO: Add functionality to check cert expiration in the cache + /** + AddMocksToGetEntraToken(httpManager, userAssignedIdentityId, userAssignedId); - var result = await mi.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + result = await managedIdentityApp.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) .ExecuteAsync().ConfigureAwait(false); Assert.IsNotNull(result); Assert.IsNotNull(result.AccessToken); + Assert.AreEqual(result.TokenType, Bearer); Assert.AreEqual(TokenSource.IdentityProvider, result.AuthenticationResultMetadata.TokenSource); - result = await mi.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + Assert.AreEqual(CertificateCache.Count, 1); // expired cert was removed from the cache + */ + } + } + #endregion Bearer Token Tests + + #region mTLS PoP Token Tests + [DataTestMethod] + [DataRow(UserAssignedIdentityId.None, null)] // SAMI + [DataRow(UserAssignedIdentityId.ClientId, TestConstants.ClientId)] // UAMI + [DataRow(UserAssignedIdentityId.ResourceId, TestConstants.MiResourceId)] // UAMI + [DataRow(UserAssignedIdentityId.ObjectId, TestConstants.ObjectId)] // UAMI + public async Task mTLSPopTokenHappyPath( + UserAssignedIdentityId userAssignedIdentityId, + string userAssignedId) + { + using (var httpManager = new MockHttpManager()) + { + var managedIdentityApp = await CreateManagedIdentityAsync(httpManager, userAssignedIdentityId, userAssignedId).ConfigureAwait(false); + + AddMocksToGetEntraToken(httpManager, userAssignedIdentityId, userAssignedId/*, mTLSPop: true*/); // TODO: implement mTLS Pop + + var result = await managedIdentityApp.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + // .WithMtlsProofOfPossession() // TODO: implement mTLS Pop .ExecuteAsync().ConfigureAwait(false); Assert.IsNotNull(result); Assert.IsNotNull(result.AccessToken); + // Assert.AreEqual(result.TokenType, MTLSPoP); // TODO: implement mTLS Pop + // Assert.IsNotNull(result.BindingCertificate); // TODO: implement mTLS Pop + Assert.AreEqual(TokenSource.IdentityProvider, result.AuthenticationResultMetadata.TokenSource); + + result = await managedIdentityApp.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + .ExecuteAsync().ConfigureAwait(false); + + Assert.IsNotNull(result); + Assert.IsNotNull(result.AccessToken); + // Assert.AreEqual(result.TokenType, MTLSPoP); // TODO: implement mTLS Pop + // Assert.IsNotNull(result.BindingCertificate); // TODO: implement mTLS Pop Assert.AreEqual(TokenSource.Cache, result.AuthenticationResultMetadata.TokenSource); } } + [DataTestMethod] + [DataRow(UserAssignedIdentityId.None, null)] // SAMI + [DataRow(UserAssignedIdentityId.ClientId, TestConstants.ClientId)] // UAMI + [DataRow(UserAssignedIdentityId.ResourceId, TestConstants.MiResourceId)] // UAMI + [DataRow(UserAssignedIdentityId.ObjectId, TestConstants.ObjectId)] // UAMI + public async Task mTLSPopTokenTokenIsPerIdentity( + UserAssignedIdentityId userAssignedIdentityId, + string userAssignedId) + { + using (var httpManager = new MockHttpManager()) + { + #region Identity 1 + var managedIdentityApp = await CreateManagedIdentityAsync(httpManager, userAssignedIdentityId, userAssignedId).ConfigureAwait(false); + + AddMocksToGetEntraToken(httpManager, userAssignedIdentityId, userAssignedId/*, mTLSPop: true*/); // TODO: implement mTLS Pop + + var result = await managedIdentityApp.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + // .WithMtlsProofOfPossession() // TODO: implement mTLS Pop + .ExecuteAsync().ConfigureAwait(false); + + Assert.IsNotNull(result); + Assert.IsNotNull(result.AccessToken); + // Assert.AreEqual(result.TokenType, MTLSPoP); // TODO: implement mTLS Pop + // Assert.IsNotNull(result.BindingCertificate); // TODO: implement mTLS Pop + Assert.AreEqual(TokenSource.IdentityProvider, result.AuthenticationResultMetadata.TokenSource); + + result = await managedIdentityApp.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + // .WithMtlsProofOfPossession() // TODO: implement mTLS Pop + .ExecuteAsync().ConfigureAwait(false); + + Assert.IsNotNull(result); + Assert.IsNotNull(result.AccessToken); + // Assert.AreEqual(result.TokenType, MTLSPoP); // TODO: implement mTLS Pop + // Assert.IsNotNull(result.BindingCertificate); // TODO: implement mTLS Pop + Assert.AreEqual(TokenSource.Cache, result.AuthenticationResultMetadata.TokenSource); + #endregion Identity 1 + + #region Identity 2 + var managedIdentityApp2 = await CreateManagedIdentityAsync(httpManager, userAssignedIdentityId, userAssignedId, addProbeMock: false, addSourceCheck: false).ConfigureAwait(false); // source is already cached + + AddMocksToGetEntraToken(httpManager, userAssignedIdentityId, userAssignedId/*, mTLSPop: true*/); // TODO: implement mTLS Pop + + var result2 = await managedIdentityApp2.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + // .WithMtlsProofOfPossession() // TODO: implement mTLS Pop + .ExecuteAsync().ConfigureAwait(false); + + Assert.IsNotNull(result2); + Assert.IsNotNull(result2.AccessToken); + // Assert.AreEqual(result.TokenType, MTLSPoP); // TODO: implement mTLS Pop + // Assert.IsNotNull(result.BindingCertificate); // TODO: implement mTLS Pop + Assert.AreEqual(TokenSource.IdentityProvider, result2.AuthenticationResultMetadata.TokenSource); + + result2 = await managedIdentityApp2.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + // .WithMtlsProofOfPossession() // TODO: implement mTLS Pop + .ExecuteAsync().ConfigureAwait(false); + + Assert.IsNotNull(result2); + Assert.IsNotNull(result2.AccessToken); + // Assert.AreEqual(result.TokenType, MTLSPoP); // TODO: implement mTLS Pop + // Assert.IsNotNull(result.BindingCertificate); // TODO: implement mTLS Pop + Assert.AreEqual(TokenSource.Cache, result2.AuthenticationResultMetadata.TokenSource); + #endregion Identity 2 + + // TODO: Assert.AreEqual(CertificateCache.Count, 2); + } + } + + [DataTestMethod] + [DataRow(UserAssignedIdentityId.None, null)] // SAMI + [DataRow(UserAssignedIdentityId.ClientId, TestConstants.ClientId)] // UAMI + [DataRow(UserAssignedIdentityId.ResourceId, TestConstants.MiResourceId)] // UAMI + [DataRow(UserAssignedIdentityId.ObjectId, TestConstants.ObjectId)] // UAMI + public async Task mTLSPopTokenIsReAcquiredWhenCertificatIsExpired( + UserAssignedIdentityId userAssignedIdentityId, + string userAssignedId) + { + using (var httpManager = new MockHttpManager()) + { + var managedIdentityApp = await CreateManagedIdentityAsync(httpManager, userAssignedIdentityId, userAssignedId).ConfigureAwait(false); + + AddMocksToGetEntraToken(httpManager, userAssignedIdentityId, userAssignedId, TestConstants.ExpiredPemCertificate/*, mTLSPop: true*/); // TODO: implement mTLS Pop + + var result = await managedIdentityApp.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + // .WithMtlsProofOfPossession() // TODO: implement mTLS Pop + .ExecuteAsync().ConfigureAwait(false); + + Assert.IsNotNull(result); + Assert.IsNotNull(result.AccessToken); + // Assert.AreEqual(result.TokenType, MTLSPoP); // TODO: implement mTLS Pop + // Assert.IsNotNull(result.BindingCertificate); // TODO: implement mTLS Pop + Assert.AreEqual(TokenSource.IdentityProvider, result.AuthenticationResultMetadata.TokenSource); + + // TODO: Add functionality to check cert expiration in the cache + /** + AddMocksToGetEntraToken(httpManager, userAssignedIdentityId, userAssignedId, // mTLSPop: true); // TODO: implement mTLS Pop + + result = await managedIdentityApp.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + // .WithMtlsProofOfPossession() // TODO: implement mTLS Pop + .ExecuteAsync().ConfigureAwait(false); + + Assert.IsNotNull(result); + Assert.IsNotNull(result.AccessToken); + // Assert.AreEqual(result.TokenType, MTLSPoP); // TODO: implement mTLS Pop + // Assert.IsNotNull(result.BindingCertificate); // TODO: implement mTLS Pop + Assert.AreEqual(TokenSource.IdentityProvider, result.AuthenticationResultMetadata.TokenSource); + + Assert.AreEqual(CertificateCache.Count, 1); // expired cert was removed from the cache + */ + } + } + #endregion mTLS Pop Token Tests + #endregion Acceptance Tests + [TestMethod] public async Task GetCsrMetadataAsyncSucceeds() { @@ -118,13 +397,7 @@ public async Task GetCsrMetadataAsyncSucceeds() { var handler = httpManager.AddMockHandler(MockHelpers.MockCsrResponse()); - var managedIdentityApp = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) - .WithHttpManager(httpManager) - .WithRetryPolicyFactory(_testRetryPolicyFactory) - .Build(); - - var miSource = await (managedIdentityApp as ManagedIdentityApplication).GetManagedIdentitySourceAsync().ConfigureAwait(false); - Assert.AreEqual(ManagedIdentitySource.ImdsV2, miSource); + await CreateManagedIdentityAsync(httpManager, addProbeMock: false).ConfigureAwait(false); Assert.IsTrue(handler.ActualRequestHeaders.Contains("Metadata")); Assert.IsTrue(handler.ActualRequestHeaders.Contains("x-ms-client-request-id")); @@ -140,16 +413,8 @@ public async Task GetCsrMetadataAsyncSucceedsAfterRetry() // First attempt fails with INTERNAL_SERVER_ERROR (500) httpManager.AddMockHandler(MockHelpers.MockCsrResponse(HttpStatusCode.InternalServerError)); - // Second attempt succeeds - httpManager.AddMockHandler(MockHelpers.MockCsrResponse()); - - var managedIdentityApp = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) - .WithHttpManager(httpManager) - .WithRetryPolicyFactory(_testRetryPolicyFactory) - .Build(); - - var miSource = await (managedIdentityApp as ManagedIdentityApplication).GetManagedIdentitySourceAsync().ConfigureAwait(false); - Assert.AreEqual(ManagedIdentitySource.ImdsV2, miSource); + // Second attempt succeeds (defined inside of CreateSAMIAsync) + await CreateManagedIdentityAsync(httpManager).ConfigureAwait(false); } } @@ -160,10 +425,7 @@ public async Task GetCsrMetadataAsyncFailsWithMissingServerHeader() { httpManager.AddMockHandler(MockHelpers.MockCsrResponse(responseServerHeader: null)); - var managedIdentityApp = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) - .WithHttpManager(httpManager) - .WithRetryPolicyFactory(_testRetryPolicyFactory) - .Build(); + var managedIdentityApp = await CreateManagedIdentityAsync(httpManager, addProbeMock: false, addSourceCheck: false).ConfigureAwait(false); var miSource = await (managedIdentityApp as ManagedIdentityApplication).GetManagedIdentitySourceAsync().ConfigureAwait(false); Assert.AreEqual(ManagedIdentitySource.DefaultToImds, miSource); @@ -177,10 +439,7 @@ public async Task GetCsrMetadataAsyncFailsWithInvalidVersion() { httpManager.AddMockHandler(MockHelpers.MockCsrResponse(responseServerHeader: "IMDS/150.870.65.1853")); // min version is 1854 - var managedIdentityApp = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) - .WithHttpManager(httpManager) - .WithRetryPolicyFactory(_testRetryPolicyFactory) - .Build(); + var managedIdentityApp = await CreateManagedIdentityAsync(httpManager, addProbeMock: false, addSourceCheck: false).ConfigureAwait(false); var miSource = await (managedIdentityApp as ManagedIdentityApplication).GetManagedIdentitySourceAsync().ConfigureAwait(false); Assert.AreEqual(ManagedIdentitySource.DefaultToImds, miSource); @@ -192,17 +451,14 @@ public async Task GetCsrMetadataAsyncFailsAfterMaxRetries() { using (var httpManager = new MockHttpManager()) { - var managedIdentityApp = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) - .WithHttpManager(httpManager) - .WithRetryPolicyFactory(_testRetryPolicyFactory) - .Build(); - const int Num500Errors = 1 + TestCsrMetadataProbeRetryPolicy.ExponentialStrategyNumRetries; for (int i = 0; i < Num500Errors; i++) { httpManager.AddMockHandler(MockHelpers.MockCsrResponse(HttpStatusCode.InternalServerError)); } + var managedIdentityApp = await CreateManagedIdentityAsync(httpManager, addProbeMock: false, addSourceCheck: false).ConfigureAwait(false); + var miSource = await (managedIdentityApp as ManagedIdentityApplication).GetManagedIdentitySourceAsync().ConfigureAwait(false); Assert.AreEqual(ManagedIdentitySource.DefaultToImds, miSource); } @@ -213,13 +469,10 @@ public async Task GetCsrMetadataAsyncFails404WhichIsNonRetriableAndRetryPolicyIs { using (var httpManager = new MockHttpManager()) { - var managedIdentityApp = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) - .WithHttpManager(httpManager) - .WithRetryPolicyFactory(_testRetryPolicyFactory) - .Build(); - httpManager.AddMockHandler(MockHelpers.MockCsrResponse(HttpStatusCode.NotFound)); + var managedIdentityApp = await CreateManagedIdentityAsync(httpManager, addProbeMock: false, addSourceCheck: false).ConfigureAwait(false); + var miSource = await (managedIdentityApp as ManagedIdentityApplication).GetManagedIdentitySourceAsync().ConfigureAwait(false); Assert.AreEqual(ManagedIdentitySource.DefaultToImds, miSource); } @@ -269,19 +522,9 @@ public void TestCsrGeneration_MalformedPem_ArgumentException(string malformedPem } #region AttachPrivateKeyToCert Tests - [TestMethod] public void AttachPrivateKeyToCert_ValidInputs_ReturnsValidCertificate() { - using var httpManager = new MockHttpManager(); - var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) - .WithHttpManager(httpManager) - .WithRetryPolicyFactory(_testRetryPolicyFactory); - var managedIdentityApp = miBuilder.BuildConcrete(); - - var requestContext = new RequestContext(managedIdentityApp.ServiceBundle, Guid.NewGuid(), null); - var imdsV2Source = new ImdsV2ManagedIdentitySource(requestContext); - using (RSA rsa = RSA.Create()) { // For this test, we just want to verify that the method doesn't crash @@ -298,15 +541,6 @@ public void AttachPrivateKeyToCert_ValidInputs_ReturnsValidCertificate() [TestMethod] public void AttachPrivateKeyToCert_NullCertificatePem_ThrowsArgumentNullException() { - using var httpManager = new MockHttpManager(); - var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) - .WithHttpManager(httpManager) - .WithRetryPolicyFactory(_testRetryPolicyFactory); - var managedIdentityApp = miBuilder.BuildConcrete(); - - var requestContext = new RequestContext(managedIdentityApp.ServiceBundle, Guid.NewGuid(), null); - var imdsV2Source = new ImdsV2ManagedIdentitySource(requestContext); - using (RSA rsa = RSA.Create()) { Assert.ThrowsException(() => @@ -317,15 +551,6 @@ public void AttachPrivateKeyToCert_NullCertificatePem_ThrowsArgumentNullExceptio [TestMethod] public void AttachPrivateKeyToCert_EmptyCertificatePem_ThrowsArgumentNullException() { - using var httpManager = new MockHttpManager(); - var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) - .WithHttpManager(httpManager) - .WithRetryPolicyFactory(_testRetryPolicyFactory); - var managedIdentityApp = miBuilder.BuildConcrete(); - - var requestContext = new RequestContext(managedIdentityApp.ServiceBundle, Guid.NewGuid(), null); - var imdsV2Source = new ImdsV2ManagedIdentitySource(requestContext); - using (RSA rsa = RSA.Create()) { Assert.ThrowsException(() => @@ -336,15 +561,6 @@ public void AttachPrivateKeyToCert_EmptyCertificatePem_ThrowsArgumentNullExcepti [TestMethod] public void AttachPrivateKeyToCert_NullPrivateKey_ThrowsArgumentNullException() { - using var httpManager = new MockHttpManager(); - var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) - .WithHttpManager(httpManager) - .WithRetryPolicyFactory(_testRetryPolicyFactory); - var managedIdentityApp = miBuilder.BuildConcrete(); - - var requestContext = new RequestContext(managedIdentityApp.ServiceBundle, Guid.NewGuid(), null); - var imdsV2Source = new ImdsV2ManagedIdentitySource(requestContext); - Assert.ThrowsException(() => CommonCryptographyManager.AttachPrivateKeyToCert(TestConstants.ValidPemCertificate, null)); } @@ -354,15 +570,6 @@ public void AttachPrivateKeyToCert_InvalidPemFormat_ThrowsArgumentException() { const string InvalidPemNoCertMarker = @"This is not a valid PEM certificate"; - using var httpManager = new MockHttpManager(); - var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) - .WithHttpManager(httpManager) - .WithRetryPolicyFactory(_testRetryPolicyFactory); - var managedIdentityApp = miBuilder.BuildConcrete(); - - var requestContext = new RequestContext(managedIdentityApp.ServiceBundle, Guid.NewGuid(), null); - var imdsV2Source = new ImdsV2ManagedIdentitySource(requestContext); - using (RSA rsa = RSA.Create()) { Assert.ThrowsException(() => @@ -376,15 +583,6 @@ public void AttachPrivateKeyToCert_MissingBeginMarker_ThrowsArgumentException() const string InvalidPemMissingBeginMarker = @"MIICXTCCAUWgAwIBAgIJAKPiQh26MIuPMA0GCSqGSIb3DQEBCwUAMEUxCzAJBgNV -----END CERTIFICATE-----"; - using var httpManager = new MockHttpManager(); - var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) - .WithHttpManager(httpManager) - .WithRetryPolicyFactory(_testRetryPolicyFactory); - var managedIdentityApp = miBuilder.BuildConcrete(); - - var requestContext = new RequestContext(managedIdentityApp.ServiceBundle, Guid.NewGuid(), null); - var imdsV2Source = new ImdsV2ManagedIdentitySource(requestContext); - using (RSA rsa = RSA.Create()) { Assert.ThrowsException(() => @@ -397,15 +595,7 @@ public void AttachPrivateKeyToCert_MissingEndMarker_ThrowsArgumentException() { const string InvalidPemMissingEndMarker = @"-----BEGIN CERTIFICATE----- MIICXTCCAUWgAwIBAgIJAKPiQh26MIuPMA0GCSqGSIb3DQEBCwUAMEUxCzAJBgNV"; - using var httpManager = new MockHttpManager(); - var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) - .WithHttpManager(httpManager) - .WithRetryPolicyFactory(_testRetryPolicyFactory); - var managedIdentityApp = miBuilder.BuildConcrete(); - - var requestContext = new RequestContext(managedIdentityApp.ServiceBundle, Guid.NewGuid(), null); - var imdsV2Source = new ImdsV2ManagedIdentitySource(requestContext); - + using (RSA rsa = RSA.Create()) { Assert.ThrowsException(() => @@ -420,22 +610,12 @@ public void AttachPrivateKeyToCert_BadBase64Content_ThrowsFormatException() Invalid@#$%Base64Content! -----END CERTIFICATE-----"; - using var httpManager = new MockHttpManager(); - var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) - .WithHttpManager(httpManager) - .WithRetryPolicyFactory(_testRetryPolicyFactory); - var managedIdentityApp = miBuilder.BuildConcrete(); - - var requestContext = new RequestContext(managedIdentityApp.ServiceBundle, Guid.NewGuid(), null); - var imdsV2Source = new ImdsV2ManagedIdentitySource(requestContext); - using (RSA rsa = RSA.Create()) { Assert.ThrowsException(() => CommonCryptographyManager.AttachPrivateKeyToCert(InvalidPemBadBase64, rsa)); } } - #endregion } } From 22b244215a55566572e78ed16ee5951c334a97d2 Mon Sep 17 00:00:00 2001 From: Robbie Ginsburg Date: Mon, 15 Sep 2025 18:07:27 -0400 Subject: [PATCH 5/5] Implemented feedback --- .../V2/ImdsV2ManagedIdentitySource.cs | 10 ++++++---- .../OAuth2/OAuthConstants.cs | 1 + .../Core/Mocks/MockHelpers.cs | 15 +++++++++++++++ .../Core/Mocks/MockHttpMessageHandler.cs | 11 +++++++++++ .../ManagedIdentityTests/ImdsV2Tests.cs | 4 ---- 5 files changed, 33 insertions(+), 8 deletions(-) diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs index 6b7f2d8272..ccb4602695 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs @@ -11,6 +11,7 @@ using Microsoft.Identity.Client.Http; using Microsoft.Identity.Client.Http.Retry; using Microsoft.Identity.Client.Internal; +using Microsoft.Identity.Client.OAuth2; using Microsoft.Identity.Client.OAuth2.Throttling; using Microsoft.Identity.Client.PlatformsCommon.Shared; using Microsoft.Identity.Client.Utils; @@ -38,7 +39,7 @@ public static async Task GetCsrMetadataAsync( var headers = new Dictionary { { "Metadata", "true" }, - { "x-ms-client-request-id", requestContext.CorrelationId.ToString() } + { OAuth2Header.XMsCorrelationId, requestContext.CorrelationId.ToString() } }; IRetryPolicyFactory retryPolicyFactory = requestContext.ServiceBundle.Config.RetryPolicyFactory; @@ -195,7 +196,7 @@ private async Task ExecuteCertificateRequestAsync(st var headers = new Dictionary { { "Metadata", "true" }, - { "x-ms-client-request-id", _requestContext.CorrelationId.ToString() } + { OAuth2Header.XMsCorrelationId, _requestContext.CorrelationId.ToString() } }; var certificateRequestBody = new CertificateRequestBody() @@ -271,11 +272,12 @@ protected override async Task CreateRequestAsync(string { request.Headers[idParam.Key] = idParam.Value; } - request.Headers.Add("x-ms-client-request-id", _requestContext.CorrelationId.ToString()); + request.Headers.Add(OAuth2Header.XMsCorrelationId, _requestContext.CorrelationId.ToString()); request.Headers.Add(ThrottleCommon.ThrottleRetryAfterHeaderName, ThrottleCommon.ThrottleRetryAfterHeaderValue); + request.Headers.Add(OAuth2Header.RequestCorrelationIdInResponse, "true"); request.BodyParameters.Add("client_id", certificateRequestResponse.ClientId); - request.BodyParameters.Add("grant_type", "client_credentials"); + request.BodyParameters.Add("grant_type", OAuth2GrantType.ClientCredentials); request.BodyParameters.Add("scope", "https://management.azure.com/.default"); request.BodyParameters.Add("token_type", "bearer"); diff --git a/src/client/Microsoft.Identity.Client/OAuth2/OAuthConstants.cs b/src/client/Microsoft.Identity.Client/OAuth2/OAuthConstants.cs index 7aa76e3cd1..a6003809d9 100644 --- a/src/client/Microsoft.Identity.Client/OAuth2/OAuthConstants.cs +++ b/src/client/Microsoft.Identity.Client/OAuth2/OAuthConstants.cs @@ -77,6 +77,7 @@ internal static class OAuth2RequestedTokenUse internal static class OAuth2Header { public const string CorrelationId = "client-request-id"; + public const string XMsCorrelationId = $"x-ms-{CorrelationId}"; public const string RequestCorrelationIdInResponse = "return-client-request-id"; public const string AppName = "x-app-name"; public const string AppVer = "x-app-ver"; diff --git a/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHelpers.cs b/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHelpers.cs index df4f8b3a56..1afc958e91 100644 --- a/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHelpers.cs +++ b/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHelpers.cs @@ -600,6 +600,10 @@ public static MockHttpMessageHandler MockCsrResponse( { IDictionary expectedQueryParams = new Dictionary(); IDictionary expectedRequestHeaders = new Dictionary(); + IList presentRequestHeaders = new List + { + OAuth2Header.XMsCorrelationId + }; if (userAssignedIdentityId != UserAssignedIdentityId.None && userAssignedId != null) { @@ -623,6 +627,7 @@ public static MockHttpMessageHandler MockCsrResponse( ExpectedMethod = HttpMethod.Get, ExpectedQueryParams = expectedQueryParams, ExpectedRequestHeaders = expectedRequestHeaders, + PresentRequestHeaders = presentRequestHeaders, ResponseMessage = new HttpResponseMessage(statusCode) { Content = new StringContent(content), @@ -649,6 +654,10 @@ public static MockHttpMessageHandler MockCertificateRequestResponse( { IDictionary expectedQueryParams = new Dictionary(); IDictionary expectedRequestHeaders = new Dictionary(); + IList presentRequestHeaders = new List + { + OAuth2Header.XMsCorrelationId + }; if (userAssignedIdentityId != UserAssignedIdentityId.None && userAssignedId != null) { @@ -673,6 +682,7 @@ public static MockHttpMessageHandler MockCertificateRequestResponse( ExpectedMethod = HttpMethod.Post, ExpectedQueryParams = expectedQueryParams, ExpectedRequestHeaders = expectedRequestHeaders, + PresentRequestHeaders = presentRequestHeaders, ResponseMessage = new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(content), @@ -691,6 +701,10 @@ public static MockHttpMessageHandler MockImdsV2EntraTokenRequestResponse( { { ThrottleCommon.ThrottleRetryAfterHeaderName, ThrottleCommon.ThrottleRetryAfterHeaderValue } }; + IList presentRequestHeaders = new List + { + OAuth2Header.XMsCorrelationId + }; var idParams = MsalIdHelper.GetMsalIdParameters(identityLoggerAdapter); foreach (var idParam in idParams) @@ -707,6 +721,7 @@ public static MockHttpMessageHandler MockImdsV2EntraTokenRequestResponse( ExpectedMethod = HttpMethod.Post, ExpectedPostData = expectedPostData, ExpectedRequestHeaders = expectedRequestHeaders, + PresentRequestHeaders = presentRequestHeaders, ResponseMessage = new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(GetMsiSuccessfulResponse()), diff --git a/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHttpMessageHandler.cs b/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHttpMessageHandler.cs index e067d640b8..cdfbb5432b 100644 --- a/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHttpMessageHandler.cs +++ b/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHttpMessageHandler.cs @@ -37,6 +37,8 @@ internal class MockHttpMessageHandler : HttpClientHandler public HttpRequestMessage ActualRequestMessage { get; private set; } public Dictionary ActualRequestPostData { get; private set; } public HttpRequestHeaders ActualRequestHeaders { get; private set; } + public IList PresentRequestHeaders { get; set; } + public X509Certificate2 ExpectedMtlsBindingCertificate { get; set; } protected override async Task SendAsync(HttpRequestMessage request, CancellationToken cancellationToken) @@ -174,6 +176,15 @@ private void ValidateNotExpectedPostData() private void ValidateHeaders(HttpRequestMessage request) { + if (PresentRequestHeaders != null) + { + foreach (var headerName in PresentRequestHeaders) + { + Assert.IsTrue(request.Headers.Contains(headerName), + $"Expected request header to be present: {headerName}."); + } + } + ActualRequestHeaders = request.Headers; if (ExpectedRequestHeaders != null) { diff --git a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs index d0c4c8b5a1..abfd88ddce 100644 --- a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs +++ b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs @@ -398,10 +398,6 @@ public async Task GetCsrMetadataAsyncSucceeds() var handler = httpManager.AddMockHandler(MockHelpers.MockCsrResponse()); await CreateManagedIdentityAsync(httpManager, addProbeMock: false).ConfigureAwait(false); - - Assert.IsTrue(handler.ActualRequestHeaders.Contains("Metadata")); - Assert.IsTrue(handler.ActualRequestHeaders.Contains("x-ms-client-request-id")); - Assert.IsTrue(handler.ActualRequestMessage.RequestUri.Query.Contains("api-version")); } }