Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
using Microsoft.Identity.Client.Http;
using Microsoft.Identity.Client.Http.Retry;
using Microsoft.Identity.Client.Internal;
using Microsoft.Identity.Client.OAuth2;
using Microsoft.Identity.Client.OAuth2.Throttling;
using Microsoft.Identity.Client.PlatformsCommon.Shared;
using Microsoft.Identity.Client.Utils;

Expand All @@ -37,7 +39,7 @@ public static async Task<CsrMetadata> GetCsrMetadataAsync(
var headers = new Dictionary<string, string>
{
{ "Metadata", "true" },
{ "x-ms-client-request-id", requestContext.CorrelationId.ToString() }
{ OAuth2Header.XMsCorrelationId, requestContext.CorrelationId.ToString() }
};

IRetryPolicyFactory retryPolicyFactory = requestContext.ServiceBundle.Config.RetryPolicyFactory;
Expand Down Expand Up @@ -194,7 +196,7 @@ private async Task<CertificateRequestResponse> ExecuteCertificateRequestAsync(st
var headers = new Dictionary<string, string>
{
{ "Metadata", "true" },
{ "x-ms-client-request-id", _requestContext.CorrelationId.ToString() }
{ OAuth2Header.XMsCorrelationId, _requestContext.CorrelationId.ToString() }
};

var certificateRequestBody = new CertificateRequestBody()
Expand Down Expand Up @@ -264,11 +266,23 @@ protected override async Task<ManagedIdentityRequest> CreateRequestAsync(string
privateKey);

ManagedIdentityRequest request = new(HttpMethod.Post, new Uri($"{certificateRequestResponse.MtlsAuthenticationEndpoint}/{certificateRequestResponse.TenantId}{AcquireEntraTokenPath}"));
request.Headers.Add("x-ms-client-request-id", _requestContext.CorrelationId.ToString());

var idParams = MsalIdHelper.GetMsalIdParameters(_requestContext.Logger);
foreach (var idParam in idParams)
{
request.Headers[idParam.Key] = idParam.Value;
}
request.Headers.Add(OAuth2Header.XMsCorrelationId, _requestContext.CorrelationId.ToString());
request.Headers.Add(ThrottleCommon.ThrottleRetryAfterHeaderName, ThrottleCommon.ThrottleRetryAfterHeaderValue);
request.Headers.Add(OAuth2Header.RequestCorrelationIdInResponse, "true");

request.BodyParameters.Add("client_id", certificateRequestResponse.ClientId);
request.BodyParameters.Add("grant_type", certificateRequestResponse.Certificate);
request.BodyParameters.Add("grant_type", OAuth2GrantType.ClientCredentials);
request.BodyParameters.Add("scope", "https://management.azure.com/.default");
request.RequestType = RequestType.Imds;
request.BodyParameters.Add("token_type", "bearer");

request.RequestType = RequestType.STS;

request.MtlsCertificate = mtlsCertificate;

return request;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ internal static class OAuth2RequestedTokenUse
internal static class OAuth2Header
{
public const string CorrelationId = "client-request-id";
public const string XMsCorrelationId = $"x-ms-{CorrelationId}";
public const string RequestCorrelationIdInResponse = "return-client-request-id";
public const string AppName = "x-app-name";
public const string AppVer = "x-app-ver";
Expand Down
71 changes: 63 additions & 8 deletions tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHelpers.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,12 @@
using Castle.Core.Logging;
using Microsoft.Identity.Client;
using Microsoft.Identity.Client.AppConfig;
using Microsoft.Identity.Client.Internal;
using Microsoft.Identity.Client.Internal.Logger;
using Microsoft.Identity.Client.ManagedIdentity;
using Microsoft.Identity.Client.ManagedIdentity.V2;
using Microsoft.Identity.Client.OAuth2;
using Microsoft.Identity.Client.OAuth2.Throttling;
using Microsoft.Identity.Client.Utils;
using Microsoft.Identity.Test.Unit;
using Microsoft.VisualStudio.TestTools.UnitTesting.Logging;
Expand Down Expand Up @@ -592,14 +595,19 @@ public static MsalTokenResponse CreateMsalRunTimeBrokerTokenResponse(string acce
public static MockHttpMessageHandler MockCsrResponse(
HttpStatusCode statusCode = HttpStatusCode.OK,
string responseServerHeader = "IMDS/150.870.65.1854",
UserAssignedIdentityId idType = UserAssignedIdentityId.None,
UserAssignedIdentityId userAssignedIdentityId = UserAssignedIdentityId.None,
string userAssignedId = null)
{
IDictionary<string, string> expectedQueryParams = new Dictionary<string, string>();
IDictionary<string, string> expectedRequestHeaders = new Dictionary<string, string>();
if (idType != UserAssignedIdentityId.None && userAssignedId != null)
IList<string> presentRequestHeaders = new List<string>
{
OAuth2Header.XMsCorrelationId
};

if (userAssignedIdentityId != UserAssignedIdentityId.None && userAssignedId != null)
{
var userAssignedIdQueryParam = ImdsManagedIdentitySource.GetUserAssignedIdQueryParam((ManagedIdentityIdType)idType, userAssignedId, null);
var userAssignedIdQueryParam = ImdsManagedIdentitySource.GetUserAssignedIdQueryParam((ManagedIdentityIdType)userAssignedIdentityId, userAssignedId, null);
expectedQueryParams.Add(userAssignedIdQueryParam.Value.Key, userAssignedIdQueryParam.Value.Value);
}
expectedQueryParams.Add("cred-api-version", "2.0");
Expand All @@ -619,6 +627,7 @@ public static MockHttpMessageHandler MockCsrResponse(
ExpectedMethod = HttpMethod.Get,
ExpectedQueryParams = expectedQueryParams,
ExpectedRequestHeaders = expectedRequestHeaders,
PresentRequestHeaders = presentRequestHeaders,
ResponseMessage = new HttpResponseMessage(statusCode)
{
Content = new StringContent(content),
Expand All @@ -639,14 +648,20 @@ public static MockHttpMessageHandler MockCsrResponseFailure()
}

public static MockHttpMessageHandler MockCertificateRequestResponse(
UserAssignedIdentityId idType = UserAssignedIdentityId.None,
string userAssignedId = null)
UserAssignedIdentityId userAssignedIdentityId = UserAssignedIdentityId.None,
string userAssignedId = null,
string certificate = TestConstants.ValidPemCertificate)
{
IDictionary<string, string> expectedQueryParams = new Dictionary<string, string>();
IDictionary<string, string> expectedRequestHeaders = new Dictionary<string, string>();
if (idType != UserAssignedIdentityId.None && userAssignedId != null)
IList<string> presentRequestHeaders = new List<string>
{
OAuth2Header.XMsCorrelationId
};

if (userAssignedIdentityId != UserAssignedIdentityId.None && userAssignedId != null)
{
var userAssignedIdQueryParam = ImdsManagedIdentitySource.GetUserAssignedIdQueryParam((ManagedIdentityIdType)idType, userAssignedId, null);
var userAssignedIdQueryParam = ImdsManagedIdentitySource.GetUserAssignedIdQueryParam((ManagedIdentityIdType)userAssignedIdentityId, userAssignedId, null);
expectedQueryParams.Add(userAssignedIdQueryParam.Value.Key, userAssignedIdQueryParam.Value.Value);
}
expectedQueryParams.Add("cred-api-version", ImdsV2ManagedIdentitySource.ImdsV2ApiVersion);
Expand All @@ -656,7 +671,7 @@ public static MockHttpMessageHandler MockCertificateRequestResponse(
"{" +
"\"client_id\": \"" + TestConstants.ClientId + "\"," +
"\"tenant_id\": \"" + TestConstants.TenantId + "\"," +
"\"certificate\": \"" + TestConstants.ValidPemCertificate + "\"," +
"\"certificate\": \"" + certificate + "\"," +
"\"identity_type\": \"fake_identity_type\"," + // "SystemAssigned" or "UserAssigned", it doesn't matter for these tests
"\"mtls_authentication_endpoint\": \"" + TestConstants.MtlsAuthenticationEndpoint + "\"," +
"}";
Expand All @@ -667,6 +682,7 @@ public static MockHttpMessageHandler MockCertificateRequestResponse(
ExpectedMethod = HttpMethod.Post,
ExpectedQueryParams = expectedQueryParams,
ExpectedRequestHeaders = expectedRequestHeaders,
PresentRequestHeaders = presentRequestHeaders,
ResponseMessage = new HttpResponseMessage(HttpStatusCode.OK)
{
Content = new StringContent(content),
Expand All @@ -675,5 +691,44 @@ public static MockHttpMessageHandler MockCertificateRequestResponse(

return handler;
}

public static MockHttpMessageHandler MockImdsV2EntraTokenRequestResponse(
IdentityLoggerAdapter identityLoggerAdapter,
bool mTLSPop = false)
{
IDictionary<string, string> expectedPostData = new Dictionary<string, string>();
IDictionary<string, string> expectedRequestHeaders = new Dictionary<string, string>
{
{ ThrottleCommon.ThrottleRetryAfterHeaderName, ThrottleCommon.ThrottleRetryAfterHeaderValue }
};
IList<string> presentRequestHeaders = new List<string>
{
OAuth2Header.XMsCorrelationId
};

var idParams = MsalIdHelper.GetMsalIdParameters(identityLoggerAdapter);
foreach (var idParam in idParams)
{
expectedRequestHeaders[idParam.Key] = idParam.Value;
}

var tokenType = mTLSPop ? "mtls_pop" : "bearer";
expectedPostData.Add("token_type", tokenType);

var handler = new MockHttpMessageHandler()
{
ExpectedUrl = $"{TestConstants.MtlsAuthenticationEndpoint}/{TestConstants.TenantId}{ImdsV2ManagedIdentitySource.AcquireEntraTokenPath}",
ExpectedMethod = HttpMethod.Post,
ExpectedPostData = expectedPostData,
ExpectedRequestHeaders = expectedRequestHeaders,
PresentRequestHeaders = presentRequestHeaders,
ResponseMessage = new HttpResponseMessage(HttpStatusCode.OK)
{
Content = new StringContent(GetMsiSuccessfulResponse()),
}
};

return handler;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -460,15 +460,6 @@ private static MockHttpMessageHandler BuildMockHandlerForManagedIdentitySource(
expectedQueryParams.Add("resource", resource);
expectedRequestHeaders.Add("Metadata", "true");
break;
case ManagedIdentitySource.ImdsV2:
httpMessageHandler.ExpectedMethod = HttpMethod.Post;
expectedPostData = new Dictionary<string, string>
{
{ "client_id", TestConstants.ClientId },
{ "grant_type", TestConstants.ValidPemCertificate },
{ "scope", resource }
};
break;
case ManagedIdentitySource.CloudShell:
httpMessageHandler.ExpectedMethod = HttpMethod.Post;
expectedRequestHeaders.Add("Metadata", "true");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ internal class MockHttpMessageHandler : HttpClientHandler
public HttpRequestMessage ActualRequestMessage { get; private set; }
public Dictionary<string, string> ActualRequestPostData { get; private set; }
public HttpRequestHeaders ActualRequestHeaders { get; private set; }
public IList<string> PresentRequestHeaders { get; set; }

public X509Certificate2 ExpectedMtlsBindingCertificate { get; set; }

protected override async Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, CancellationToken cancellationToken)
Expand Down Expand Up @@ -174,6 +176,15 @@ private void ValidateNotExpectedPostData()

private void ValidateHeaders(HttpRequestMessage request)
{
if (PresentRequestHeaders != null)
{
foreach (var headerName in PresentRequestHeaders)
{
Assert.IsTrue(request.Headers.Contains(headerName),
$"Expected request header to be present: {headerName}.");
}
}

ActualRequestHeaders = request.Headers;
if (ExpectedRequestHeaders != null)
{
Expand Down
31 changes: 25 additions & 6 deletions tests/Microsoft.Identity.Test.Common/TestConstants.cs
Original file line number Diff line number Diff line change
Expand Up @@ -585,15 +585,34 @@ public static MsalTokenResponse CreateAadTestTokenResponseWithFoci()
internal const string RefreshToken = "mhJDJ8wtjA3KxpRtuPAreZnMcJ2yKC2JUbpOGbRTdOCImLyQ2B4EIhv8AiA2cCEylZZfZsOsZrNsMBZZAAU9TQYYEO72QcdfnIWpAOeKkud5W2L8nMq6i9dx1EVIl09zFXhOJ79BdFbU0Eb5aUHlcqPCQjec62UKBLkZJmtMnoAa8cjvgIuxTdVM8FNdghe5nlCNTEVooKleTTEHNl2BrdyitLaWTKSP0lRqnFxriG0xWcJoSMsdS7Vt6HZd1TkwHIXycNMlCcCdUh5tOgqx1M8y8uoXK4OJ1LQmtkZvcQWcycvOCPACYakKM1pUQqwTxI6Y4HrL38sqQaSNxpF9OcFxOQWpuGodRekCbxXVbWclttIpvSOLaBhZ2ZBpcCBEeEMSmhqqYgajNwwwe9w88u0UsYKe6PBbaI48ENr02u2qBeLsIQ2HUyKlN3iVmX7u7MhgDWA3NNavMtlLmWd63NfuDgXpLI0O4cLhjAx8uoBIK8LntXPHPTxJ28o0yrszvD4gf7RdhuTq5VE15zne6iAJgIGfy7latGFzxuDMcML9OoXURHnNEHBgS9ZQCfNzYZ2O9flF1UjGpcBLEi7hHVHnrQb4y7c98dz9p62cvEMhorGx9kCwSIkOae5LheXPQkFIbsGyomNEwz3HZvR131VGAwdfmUUodvPr6LAAtmjl4sZ72PRqAo8EdQ0IFsWoypXVv51IooR87tO3uiG2DkxhIAwumOQdaJNxw1a0WS9mpQOmwFlvfbZkaIoUKgagHc8fVa1aHZntLGwH0S1iYixJiIrMnPYAeRdSp9mlHllrMX8xUIznobcZ5i8MpUYCKlUXMZ82S3XUJ5dJxARNRPxXlLJ5LPYBhUNkBLQen9Qmq3VZEV1RDJyhbGp6GAo14KsMtVAVYNmYPIgo85pCZgOwVEOBUycszu4AD3p4PT2ella4LVoqmTTMSA5GEWoeWb5JvEo222Z0oKr7UK8dGwpWRSbg8TNeODihJaTUDfErvbgaZnjIRpqfgtM5i1HfQbD7Yyft5PqyygUra7GYy7pjRrEvq95XQD8sAZ32ku9AqCo5qOB584iX881WErOoheQZokt1txqwuIMUyhVuMKNEXy70CeNTsb30ghQMZpZcXIkrLYyQCZ0gNmARhMKagCSdrpUtxudLk44yfmuwSQzBN3ifWfLZiFpU53qdPLZoTw5";
internal const string IdToken = "6GwdM7f6hHXfivavPozhaRqrbxvEysfXSMQyEKBwVgivPZTtmowsmYygchhIuxjeFFeq1ZPHjhxKFnulrvoY6TDerZY5xyOlg45bToI9Bu95qFvUrrt5r17UJcXdw4YkvEt10CcDDcLcEYw704RpVefvbpjbF24pOgIuafcAkDnbDA0Qea4ePuSC45Lw7zpJhbo9Gh8IfMX597fayBvMs3fh7frrm9KpWMCeKY3h99YSaCYjZFKp1ppvXXPE9bc4sh4pRDOfnv0Yr9J8u4elZevEE4qGddfgd3hYb18XPGRjPEMlWsh7tnwxwUm6OSZlMTHYuvwBENNMx7SUQmMeg4rCfgnbcNDkWpXCiSDVt1lLLv8F2GjYnM6De3v1Ks5lhBWx3grLggcN9LnXz92eJ1l5lTB2v0y9MgmFZ4gY43oIOW5n8G5HOx3bGOyjTw0TKKbyVa3mDj0A3QqW8eLTUJz42BNiGOf5m9prMSlpAW59CHCMJLatsj3IvGeCITsGAr3sUZEytORWUdxCfuIPwecQgU6bO7pNqNvZc1tJHHNwJlfS23ZkiFuEXqEThHYfxBCFxAzMDlzO0TOdWhvrb8hlNeAOcNhoAKxu7HXsePajKs4fU1rcdSxzNKwtASEla3p6jfJnnDtKf38RJZPaRRYMviqqWEMhjmqIvBm7sMaf8RyNNuYl7otZwmwNVCR1hzzmaTAy4kQce67FJqFba7uizrgwp9zsvK8muCHKKPvNthy7fHsxKmrBIm0bLcoePKK3wAID4kFvNQcxXp6rAOr8bLFF3bLEoYdzmF2QJz1frVZZHHPy90Cmlhw48EQN8NE2OllpdaykKt5k4rPcZQyitayNNhism30qh7eCBhcA7mm5Ja0S8X4VPlkwvgwg0mQuul6gakmja8xpnTrwiOdtao320GDmJaJA6zf3UTpNZTq9tdfBtUrjAD8RS0tNUBT3Ko8N2Lfh9ry8y9ESmRVIhch3rKY7UeefFAnkiwH2WwC57ZEsHtMP0SwKYtYKHZW9HkERCCyqOT1Mw0IavsLGFvchzMAvTnz4RwRBk6IrWgANvqT3F3Vexc2K0poKb71XZ4aMXxjqAzydGQAKpKJEJcqEvX9RD8nL76TF2LZIepiaZ3dbQImkqSjbF7aaY2JFoN9ZWlcSQKe8zdO8TIG16bF8W9R4ldDyzV39L33KcweG";

#region Test Certificate and Private Key (ValidPemCertificate & XmlPrivateKey)
#region Test Certificate and Private Key (ExpiredPemCertificate, ValidPemCertificate & XmlPrivateKey)
/// <summary>
/// A test PEM-encoded X.509 certificate and its matching RSA private key.
/// Test (expired and valid) PEM-encoded X.509 certificate and their matching RSA private key.
/// These are used together in unit tests that require both a certificate and its private key.
/// The <see cref="ValidPemCertificate"/> and <see cref="XmlPrivateKey"/> are a matched pair:
/// - <see cref="ValidPemCertificate"/> is a PEM-encoded certificate.
/// - <see cref="XmlPrivateKey"/> is the corresponding RSA private key in XML format.
/// The certificate is valid for 100 years, ensuring it will not expire during the lifetime of the tests.
/// The <see cref="ExpiredPemCertificate"/>/<see cref="ValidPemCertificate"/> and <see cref="XmlPrivateKey"/> are a matched pair:
/// - <see cref="ExpiredPemCertificate"/> is an expired PEM-encoded certificate. The certificate is valid for 1 day and was created on September 8 2025, ensuring it will always be expired.
/// - <see cref="ValidPemCertificate"/> is a valid PEM-encoded certificate. The certificate is valid for 100 years and expires on August 4, 2125, ensuring it will not expire during the lifetime of the tests.
/// - <see cref="XmlPrivateKey"/> is their corresponding RSA private key in XML format.
/// </summary>
internal const string ExpiredPemCertificate = @"-----BEGIN CERTIFICATE-----
MIIC/zCCAeegAwIBAgIUGSVU23Wc0+QtCbUTjsyPOrc0XpEwDQYJKoZIhvcNAQEL
BQAwDzENMAsGA1UEAwwEVGVzdDAeFw0yNTA5MDgyMjAxMTdaFw0yNTA5MDkyMjAx
MTdaMA8xDTALBgNVBAMMBFRlc3QwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEK
AoIBAQC5XNEuk3cIEChkZd2P/bljUaVqNVh4mbXdWHYAgbdK48U6rG0FLq1NAfSn
ZO0EPbK8Zo4psRh2lBcqW29/WsKiHUEHLkLyFI+frEIfc8wskd+WxkKfL8G52uRp
YQCG87FIv8uZBBlDG7kDdOV36CUkK1N+V2fHbkEgx+YfWg6+pLi3KQx6Pf/b2YqL
D36hj8WRrVYzL6yXVUBiyRd+cQ9y5V/MRtoiX1Sv8WEFYtzIG0TUGi9pR7WWhgHN
Qk6DFDzutMV62ZEBNPIQvdO2EwXGr1FUIOL6zmj6bArPhY+hCXGrAAwCXodZhgZ9
5BxTwsQWtjCha2hT6ed8zmoE72FdAgMBAAGjUzBRMB0GA1UdDgQWBBQPYq0Efzuv
1diVcgxBxTnVA4wLMjAfBgNVHSMEGDAWgBQPYq0Efzuv1diVcgxBxTnVA4wLMjAP
BgNVHRMBAf8EBTADAQH/MA0GCSqGSIb3DQEBCwUAA4IBAQCXAD7cjWmmTqP0NX4M
qwO0AHtO+KGVtfxF8aI21Ty/nHh2SAODzsemP3NBBvoEvllwtcVyutPqvUiAflML
Nbp0ucTu+aWE14s1V9Bnt6++5g7gtXItsNV3F/ymYKsyfhDvJbWCOv5qYeJMQ+jt
ODHN9qnATODT5voULTwEVSYQXtutwRxR8e70Cvok+F+4I6Ni49DJ8DmcYzvB94ut
hqpDsygY1vYzpRbB5hpW0/D7kgVVWyWoOWiE1mV7Fry7tUWQw7EqnX89kMLMy4g6
UfOv4gtam8RBa9dLyMW1rCHRxOulP47joI10g9JoJ9DssiQTUojJgQXOSBBXdD20
H+zl
-----END CERTIFICATE-----";
internal const string ValidPemCertificate = @"-----BEGIN CERTIFICATE-----
MIIDATCCAemgAwIBAgIUSfjghyQB4FIS41rWfNcZHTLE/R4wDQYJKoZIhvcNAQEL
BQAwDzENMAsGA1UEAwwEVGVzdDAgFw0yNTA4MjgyMDIxMDBaGA8yMTI1MDgwNDIw
Expand Down
Loading