diff --git a/src/client/Microsoft.Identity.Client/ApiConfig/Parameters/AcquireTokenForManagedIdentityParameters.cs b/src/client/Microsoft.Identity.Client/ApiConfig/Parameters/AcquireTokenForManagedIdentityParameters.cs index 7c471fea59..3ef28d1c19 100644 --- a/src/client/Microsoft.Identity.Client/ApiConfig/Parameters/AcquireTokenForManagedIdentityParameters.cs +++ b/src/client/Microsoft.Identity.Client/ApiConfig/Parameters/AcquireTokenForManagedIdentityParameters.cs @@ -16,6 +16,10 @@ internal class AcquireTokenForManagedIdentityParameters : IAcquireTokenParameter public string Resource { get; set; } + public string Claims { get; set; } + + public string RevokedTokenHash { get; set; } + public void LogParameters(ILoggerAdapter logger) { if (logger.IsLoggingEnabled(LogLevel.Info)) diff --git a/src/client/Microsoft.Identity.Client/Internal/Requests/ManagedIdentityAuthRequest.cs b/src/client/Microsoft.Identity.Client/Internal/Requests/ManagedIdentityAuthRequest.cs index cb441baa80..5e0b3daaf4 100644 --- a/src/client/Microsoft.Identity.Client/Internal/Requests/ManagedIdentityAuthRequest.cs +++ b/src/client/Microsoft.Identity.Client/Internal/Requests/ManagedIdentityAuthRequest.cs @@ -10,6 +10,7 @@ using Microsoft.Identity.Client.Core; using Microsoft.Identity.Client.ManagedIdentity; using Microsoft.Identity.Client.OAuth2; +using Microsoft.Identity.Client.PlatformsCommon.Interfaces; using Microsoft.Identity.Client.Utils; namespace Microsoft.Identity.Client.Internal.Requests @@ -18,6 +19,7 @@ internal class ManagedIdentityAuthRequest : RequestBase { private readonly AcquireTokenForManagedIdentityParameters _managedIdentityParameters; private static readonly SemaphoreSlim s_semaphoreSlim = new SemaphoreSlim(1, 1); + private readonly ICryptographyManager _cryptoManager; public ManagedIdentityAuthRequest( IServiceBundle serviceBundle, @@ -26,6 +28,7 @@ public ManagedIdentityAuthRequest( : base(serviceBundle, authenticationRequestParameters, managedIdentityParameters) { _managedIdentityParameters = managedIdentityParameters; + _cryptoManager = serviceBundle.PlatformProxy.CryptographyManager; } protected override async Task ExecuteAsync(CancellationToken cancellationToken) @@ -33,32 +36,60 @@ protected override async Task ExecuteAsync(CancellationTok AuthenticationResult authResult = null; ILoggerAdapter logger = AuthenticationRequestParameters.RequestContext.Logger; - // Skip checking cache when force refresh or claims is specified - if (_managedIdentityParameters.ForceRefresh || !string.IsNullOrEmpty(AuthenticationRequestParameters.Claims)) + // 1. FIRST, handle ForceRefresh + if (_managedIdentityParameters.ForceRefresh) { AuthenticationRequestParameters.RequestContext.ApiEvent.CacheInfo = CacheRefreshReason.ForceRefreshOrClaims; - - logger.Info("[ManagedIdentityRequest] Skipped looking for a cached access token because ForceRefresh or Claims were set. " + - "This means either a force refresh was requested or claims were present."); + logger.Info("[ManagedIdentityRequest] Skipped using the cache because ForceRefresh was set."); + // We still respect claims if present + _managedIdentityParameters.Claims = AuthenticationRequestParameters.Claims; + + // Straight to the MI endpoint authResult = await GetAccessTokenAsync(cancellationToken, logger).ConfigureAwait(false); return authResult; } + // 2. Otherwise, look for a cached token MsalAccessTokenCacheItem cachedAccessTokenItem = await GetCachedAccessTokenAsync().ConfigureAwait(false); - // No access token or cached access token needs to be refreshed + // If we have claims, we do NOT use the cached token (but we still need it to compute the hash). + if (!string.IsNullOrEmpty(AuthenticationRequestParameters.Claims)) + { + _managedIdentityParameters.Claims = AuthenticationRequestParameters.Claims; + AuthenticationRequestParameters.RequestContext.ApiEvent.CacheInfo = CacheRefreshReason.ForceRefreshOrClaims; + + // If there is a cached token, compute its hash for the “bad token” scenario + if (cachedAccessTokenItem != null) + { + string cachedTokenHash = _cryptoManager.CreateSha256HashHex(cachedAccessTokenItem.Secret); + _managedIdentityParameters.RevokedTokenHash = cachedTokenHash; + + logger.Info("[ManagedIdentityRequest] Claims are present. Computed hash of the cached (bad) token. " + + "Will now request a fresh token from the MI endpoint."); + } + else + { + logger.Info("[ManagedIdentityRequest] Claims are present, but no cached token was found. " + + "Requesting a fresh token from the MI endpoint without a bad-token hash."); + } + + // In both cases, we skip using the cached token and get a new one + authResult = await GetAccessTokenAsync(cancellationToken, logger).ConfigureAwait(false); + return authResult; + } + + // 3. If we have no ForceRefresh and no claims, we can use the cache if (cachedAccessTokenItem != null) { + // Found a valid token in cache authResult = CreateAuthenticationResultFromCache(cachedAccessTokenItem); - logger.Info("[ManagedIdentityRequest] Access token retrieved from cache."); try - { - var proactivelyRefresh = SilentRequestHelper.NeedsRefresh(cachedAccessTokenItem); - - // If needed, refreshes token in the background + { + // If token is close to expiry, proactively refresh it in the background + bool proactivelyRefresh = SilentRequestHelper.NeedsRefresh(cachedAccessTokenItem); if (proactivelyRefresh) { logger.Info("[ManagedIdentityRequest] Initiating a proactive refresh."); @@ -66,31 +97,36 @@ protected override async Task ExecuteAsync(CancellationTok AuthenticationRequestParameters.RequestContext.ApiEvent.CacheInfo = CacheRefreshReason.ProactivelyRefreshed; SilentRequestHelper.ProcessFetchInBackground( - cachedAccessTokenItem, - () => - { - // Use a linked token source, in case the original cancellation token source is disposed before this background task completes. - using var tokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); - return GetAccessTokenAsync(tokenSource.Token, logger); - }, logger, ServiceBundle, AuthenticationRequestParameters.RequestContext.ApiEvent, - AuthenticationRequestParameters.RequestContext.ApiEvent.CallerSdkApiId, - AuthenticationRequestParameters.RequestContext.ApiEvent.CallerSdkVersion); + cachedAccessTokenItem, + () => + { + // Use a linked token source, in case the original cts is disposed + using var tokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); + return GetAccessTokenAsync(tokenSource.Token, logger); + }, + logger, + ServiceBundle, + AuthenticationRequestParameters.RequestContext.ApiEvent, + AuthenticationRequestParameters.RequestContext.ApiEvent.CallerSdkApiId, + AuthenticationRequestParameters.RequestContext.ApiEvent.CallerSdkVersion); } } catch (MsalServiceException e) { + // If background refresh fails, we handle the exception return await HandleTokenRefreshErrorAsync(e, cachedAccessTokenItem).ConfigureAwait(false); } } else { - // No AT in the cache + // No cached token if (AuthenticationRequestParameters.RequestContext.ApiEvent.CacheInfo != CacheRefreshReason.Expired) { AuthenticationRequestParameters.RequestContext.ApiEvent.CacheInfo = CacheRefreshReason.NoCachedAccessToken; } - logger.Info("[ManagedIdentityRequest] No cached access token. Getting a token from the managed identity endpoint."); + logger.Info("[ManagedIdentityRequest] No cached access token found. " + + "Getting a token from the managed identity endpoint."); authResult = await GetAccessTokenAsync(cancellationToken, logger).ConfigureAwait(false); } diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/AbstractManagedIdentity.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/AbstractManagedIdentity.cs index 65a4f56597..9f4bd50fc0 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/AbstractManagedIdentity.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/AbstractManagedIdentity.cs @@ -12,6 +12,8 @@ using System.Net; using Microsoft.Identity.Client.ApiConfig.Parameters; using System.Text; +using System.Collections.Generic; +using System.Linq; #if SUPPORTS_SYSTEM_TEXT_JSON using System.Text.Json; #else @@ -48,7 +50,7 @@ public virtual async Task AuthenticateAsync( // Convert the scopes to a resource string. string resource = parameters.Resource; - ManagedIdentityRequest request = CreateRequest(resource); + ManagedIdentityRequest request = CreateRequest(resource, parameters); _requestContext.Logger.Info("[Managed Identity] Sending request to managed identity endpoints."); @@ -130,7 +132,7 @@ protected virtual Task HandleResponseAsync( throw exception; } - protected abstract ManagedIdentityRequest CreateRequest(string resource); + protected abstract ManagedIdentityRequest CreateRequest(string resource, AcquireTokenForManagedIdentityParameters parameters); protected ManagedIdentityResponse GetSuccessfulResponse(HttpResponse response) { @@ -298,5 +300,52 @@ private static void CreateAndThrowException(string errorCode, throw exception; } + + /// + /// Sets the claims and capabilities in the request. + /// + /// + /// + protected virtual void ApplyClaimsAndCapabilities( + ManagedIdentityRequest request, + AcquireTokenForManagedIdentityParameters parameters) + { + IEnumerable clientCapabilities = _requestContext.ServiceBundle.Config.ClientCapabilities; + + // Set xms_cc only if clientCapabilities exist + if (clientCapabilities != null && clientCapabilities.Any()) + { + SetRequestParameter(request, "xms_cc", string.Join(",", clientCapabilities)); + _requestContext.Logger.Info("[Managed Identity] Adding client capabilities (xms_cc) to Managed Identity request."); + } + + // Only include 'token_sha256_to_refresh' if we have both Claims and the old token's hash + if (!string.IsNullOrEmpty(parameters.Claims) && + !string.IsNullOrEmpty(parameters.RevokedTokenHash)) + { + SetRequestParameter(request, "token_sha256_to_refresh", parameters.RevokedTokenHash); + _requestContext.Logger.Info( + "[Managed Identity] Passing SHA-256 of the 'bad' token to Managed Identity endpoint." + ); + } + } + + /// + /// Sets the request parameter in either the query or body based on the request method. + /// + /// + /// + /// + protected void SetRequestParameter(ManagedIdentityRequest request, string key, string value) + { + if (request.Method == HttpMethod.Post) + { + request.BodyParameters[key] = value; + } + else + { + request.QueryParameters[key] = value; + } + } } } diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/AppServiceManagedIdentitySource.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/AppServiceManagedIdentitySource.cs index e0fad1eac1..20eb5f9e92 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/AppServiceManagedIdentitySource.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/AppServiceManagedIdentitySource.cs @@ -1,9 +1,11 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. using System; using System.Collections.Generic; using System.Globalization; +using System.Linq; +using Microsoft.Identity.Client.ApiConfig.Parameters; using Microsoft.Identity.Client.Core; using Microsoft.Identity.Client.Internal; using Microsoft.Identity.Client.Utils; @@ -13,7 +15,7 @@ namespace Microsoft.Identity.Client.ManagedIdentity internal class AppServiceManagedIdentitySource : AbstractManagedIdentity { // MSI Constants. Docs for MSI are available here https://docs.microsoft.com/azure/app-service/overview-managed-identity - private const string AppServiceMsiApiVersion = "2019-08-01"; + private const string AppServiceMsiApiVersion = "2025-03-30"; private const string SecretHeaderName = "X-IDENTITY-HEADER"; private readonly Uri _endpoint; @@ -65,7 +67,7 @@ private static bool TryValidateEnvVars(string msiEndpoint, ILoggerAdapter logger return true; } - protected override ManagedIdentityRequest CreateRequest(string resource) + protected override ManagedIdentityRequest CreateRequest(string resource, AcquireTokenForManagedIdentityParameters parameters) { ManagedIdentityRequest request = new(System.Net.Http.HttpMethod.Get, _endpoint); @@ -73,6 +75,8 @@ protected override ManagedIdentityRequest CreateRequest(string resource) request.QueryParameters["api-version"] = AppServiceMsiApiVersion; request.QueryParameters["resource"] = resource; + ApplyClaimsAndCapabilities(request, parameters); + switch (_requestContext.ServiceBundle.Config.ManagedIdentityId.IdType) { case AppConfig.ManagedIdentityIdType.ClientId: diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/AzureArcManagedIdentitySource.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/AzureArcManagedIdentitySource.cs index a643ce7880..1e85f25445 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/AzureArcManagedIdentitySource.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/AzureArcManagedIdentitySource.cs @@ -78,7 +78,7 @@ private AzureArcManagedIdentitySource(Uri endpoint, RequestContext requestContex } } - protected override ManagedIdentityRequest CreateRequest(string resource) + protected override ManagedIdentityRequest CreateRequest(string resource, AcquireTokenForManagedIdentityParameters parameters) { ManagedIdentityRequest request = new ManagedIdentityRequest(System.Net.Http.HttpMethod.Get, _endpoint); @@ -118,7 +118,7 @@ protected override async Task HandleResponseAsync( var authHeaderValue = "Basic " + File.ReadAllText(splitChallenge[1]); - ManagedIdentityRequest request = CreateRequest(parameters.Resource); + ManagedIdentityRequest request = CreateRequest(parameters.Resource, parameters); _requestContext.Logger.Verbose(() => "[Managed Identity] Adding authorization header to the request."); request.Headers.Add("Authorization", authHeaderValue); diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/CloudShellManagedIdentitySource.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/CloudShellManagedIdentitySource.cs index 47006d6727..8e64030b07 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/CloudShellManagedIdentitySource.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/CloudShellManagedIdentitySource.cs @@ -4,6 +4,7 @@ using System; using System.Globalization; using System.Net.Http; +using Microsoft.Identity.Client.ApiConfig.Parameters; using Microsoft.Identity.Client.Core; using Microsoft.Identity.Client.Internal; @@ -73,7 +74,7 @@ private CloudShellManagedIdentitySource(Uri endpoint, RequestContext requestCont } } - protected override ManagedIdentityRequest CreateRequest(string resource) + protected override ManagedIdentityRequest CreateRequest(string resource, AcquireTokenForManagedIdentityParameters parameters) { ManagedIdentityRequest request = new ManagedIdentityRequest(HttpMethod.Post, _endpoint); @@ -81,7 +82,7 @@ protected override ManagedIdentityRequest CreateRequest(string resource) request.Headers.Add("Metadata", "true"); request.BodyParameters.Add("resource", resource); - + return request; } } diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsManagedIdentitySource.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsManagedIdentitySource.cs index 6cfb8854e6..e0875c75aa 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsManagedIdentitySource.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsManagedIdentitySource.cs @@ -56,7 +56,7 @@ internal ImdsManagedIdentitySource(RequestContext requestContext) : requestContext.Logger.Verbose(() => "[Managed Identity] Creating IMDS managed identity source. Endpoint URI: " + _imdsEndpoint); } - protected override ManagedIdentityRequest CreateRequest(string resource) + protected override ManagedIdentityRequest CreateRequest(string resource, AcquireTokenForManagedIdentityParameters parameters) { ManagedIdentityRequest request = new(HttpMethod.Get, _imdsEndpoint); diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/MachineLearningManagedIdentitySource.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/MachineLearningManagedIdentitySource.cs index f69f34de7a..46919a61e0 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/MachineLearningManagedIdentitySource.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/MachineLearningManagedIdentitySource.cs @@ -3,6 +3,7 @@ using System; using System.Globalization; +using Microsoft.Identity.Client.ApiConfig.Parameters; using Microsoft.Identity.Client.Core; using Microsoft.Identity.Client.Internal; @@ -62,7 +63,7 @@ private static bool TryValidateEnvVars(string msiEndpoint, ILoggerAdapter logger return true; } - protected override ManagedIdentityRequest CreateRequest(string resource) + protected override ManagedIdentityRequest CreateRequest(string resource, AcquireTokenForManagedIdentityParameters parameters) { ManagedIdentityRequest request = new(System.Net.Http.HttpMethod.Get, _endpoint); diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/ServiceFabricManagedIdentitySource.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/ServiceFabricManagedIdentitySource.cs index a8fb2379fd..04b555f145 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/ServiceFabricManagedIdentitySource.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/ServiceFabricManagedIdentitySource.cs @@ -5,6 +5,7 @@ using System.Globalization; using System.Net.Http; using System.Net.Security; +using Microsoft.Identity.Client.ApiConfig.Parameters; using Microsoft.Identity.Client.Core; using Microsoft.Identity.Client.Internal; @@ -32,9 +33,9 @@ public static AbstractManagedIdentity Create(RequestContext requestContext) var exception = MsalServiceExceptionFactory.CreateManagedIdentityException( MsalError.InvalidManagedIdentityEndpoint, errorMessage, - null, + null, ManagedIdentitySource.ServiceFabric, - null); + null); throw exception; } @@ -54,7 +55,7 @@ internal override bool ValidateServerCertificate(HttpRequestMessage message, Sys return string.Equals(certificate.GetCertHashString(), EnvironmentVariables.IdentityServerThumbprint, StringComparison.OrdinalIgnoreCase); } - private ServiceFabricManagedIdentitySource(RequestContext requestContext, Uri endpoint, string identityHeaderValue) : + private ServiceFabricManagedIdentitySource(RequestContext requestContext, Uri endpoint, string identityHeaderValue) : base(requestContext, ManagedIdentitySource.ServiceFabric) { _endpoint = endpoint; @@ -66,7 +67,7 @@ private ServiceFabricManagedIdentitySource(RequestContext requestContext, Uri en } } - protected override ManagedIdentityRequest CreateRequest(string resource) + protected override ManagedIdentityRequest CreateRequest(string resource, AcquireTokenForManagedIdentityParameters parameters) { ManagedIdentityRequest request = new ManagedIdentityRequest(HttpMethod.Get, _endpoint); @@ -75,6 +76,8 @@ protected override ManagedIdentityRequest CreateRequest(string resource) request.QueryParameters["api-version"] = ServiceFabricMsiApiVersion; request.QueryParameters["resource"] = resource; + ApplyClaimsAndCapabilities(request, parameters); + switch (_requestContext.ServiceBundle.Config.ManagedIdentityId.IdType) { case AppConfig.ManagedIdentityIdType.ClientId: diff --git a/src/client/Microsoft.Identity.Client/Utils/CoreHelpers.cs b/src/client/Microsoft.Identity.Client/Utils/CoreHelpers.cs index 9adb92a15f..2b5ca56db8 100644 --- a/src/client/Microsoft.Identity.Client/Utils/CoreHelpers.cs +++ b/src/client/Microsoft.Identity.Client/Utils/CoreHelpers.cs @@ -75,40 +75,54 @@ public static string ToQueryParameter(this IDictionary input) return builder.ToString(); } - public static Dictionary ParseKeyValueList(string input, char delimiter, bool urlDecode, - bool lowercaseKeys, + public static Dictionary ParseKeyValueList( + string input, + char delimiter, + bool urlDecode, + bool lowercaseKeys, RequestContext requestContext) { var response = new Dictionary(); + // Split the full query string on & (or any provided delimiter) to get individual k=v pairs. var queryPairs = SplitWithQuotes(input, delimiter); foreach (string queryPair in queryPairs) { - var pair = SplitWithQuotes(queryPair, '='); + // Instead of splitting on *all* '=' characters, find only the first one. + // This ensures that if the value itself contains '=', such as a trailing '=' in Base64, + // we do not accidentally split the base64 value into extra parts and lose the padding. + int idx = queryPair.IndexOf('='); - if (pair.Count == 2 && !string.IsNullOrWhiteSpace(pair[0]) && !string.IsNullOrWhiteSpace(pair[1])) + // idx > 0 means we found an '=' and have a valid key substring before it + if (idx > 0) { - string key = pair[0]; - string value = pair[1]; + // The key is everything before the first '=' + string key = queryPair.Substring(0, idx); + + // The value is everything after the first '=' (including any trailing '=') + string value = queryPair.Substring(idx + 1); - // Url decoding is needed for parsing OAuth response, but not for parsing WWW-Authenticate header in 401 challenge + // If urlDecode == true, decode both key and value if (urlDecode) { key = UrlDecode(key); value = UrlDecode(value); } + // Optionally convert key to lowercase if (lowercaseKeys) { key = key.Trim().ToLowerInvariant(); } + // Trim quotes and whitespace around the value value = value.Trim().Trim('\"').Trim(); if (response.ContainsKey(key)) { - requestContext?.Logger.Warning(string.Format(CultureInfo.InvariantCulture, + requestContext?.Logger.Warning( + string.Format(CultureInfo.InvariantCulture, "Key/value pair list contains redundant key '{0}'.", key)); } diff --git a/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHelpers.cs b/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHelpers.cs index 91e5c3d268..4db94cbb36 100644 --- a/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHelpers.cs +++ b/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHelpers.cs @@ -370,7 +370,7 @@ public static HttpResponseMessage CreateSuccessTokenResponseMessage( string[] scope, bool foci = false, string utid = TestConstants.Utid, - string accessToken = "some-access-token", + string accessToken = TestConstants.ATSecret, string refreshToken = "OAAsomethingencrypedQwgAA") { HttpResponseMessage responseMessage = new HttpResponseMessage(HttpStatusCode.OK); @@ -385,7 +385,7 @@ public static string CreateSuccessTokenResponseString(string uniqueId, string[] scope, bool foci = false, string utid = TestConstants.Utid, - string accessToken = "some-access-token", + string accessToken = TestConstants.ATSecret, string refreshToken = "OAAsomethingencrypedQwgAA") { string idToken = CreateIdToken(uniqueId, displayableId, TestConstants.Utid); diff --git a/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHttpManagerExtensions.cs b/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHttpManagerExtensions.cs index 5a04bedf5b..4455d4636b 100644 --- a/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHttpManagerExtensions.cs +++ b/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHttpManagerExtensions.cs @@ -7,12 +7,15 @@ using System.Linq; using System.Net; using System.Net.Http; +using System.Security.Cryptography; +using System.Text; using System.Threading.Tasks; using Microsoft.Identity.Client.AppConfig; using Microsoft.Identity.Client.Instance; using Microsoft.Identity.Client.Instance.Discovery; using Microsoft.Identity.Client.Internal; using Microsoft.Identity.Client.ManagedIdentity; +using Microsoft.Identity.Client.PlatformsCommon.Shared; using Microsoft.Identity.Client.Utils; using Microsoft.Identity.Test.Common.Core.Helpers; using Microsoft.Identity.Test.Unit; @@ -367,14 +370,21 @@ public static void AddManagedIdentityMockHandler( ManagedIdentitySource managedIdentitySourceType, string userAssignedId = null, UserAssignedIdentityId userAssignedIdentityId = UserAssignedIdentityId.None, - HttpStatusCode statusCode = HttpStatusCode.OK + HttpStatusCode statusCode = HttpStatusCode.OK, + bool capabilityEnabled = false, + bool claimsEnabled = false ) { HttpResponseMessage responseMessage = new HttpResponseMessage(statusCode); HttpContent content = new StringContent(response); responseMessage.Content = content; - MockHttpMessageHandler httpMessageHandler = BuildMockHandlerForManagedIdentitySource(managedIdentitySourceType, resource); + MockHttpMessageHandler httpMessageHandler = BuildMockHandlerForManagedIdentitySource( + managedIdentitySourceType, + resource, + capabilityEnabled, + claimsEnabled + ); if (userAssignedIdentityId == UserAssignedIdentityId.ClientId) { @@ -408,18 +418,24 @@ public static void AddManagedIdentityMockHandler( httpManager.AddMockHandler(httpMessageHandler); } - - private static MockHttpMessageHandler BuildMockHandlerForManagedIdentitySource(ManagedIdentitySource managedIdentitySourceType, string resource) + + private static MockHttpMessageHandler BuildMockHandlerForManagedIdentitySource( + ManagedIdentitySource managedIdentitySourceType, + string resource, + bool capabilityEnabled = false, + bool claimsEnabled = false) { MockHttpMessageHandler httpMessageHandler = new MockHttpMessageHandler(); IDictionary expectedQueryParams = new Dictionary(); + IDictionary notExpectedQueryParams = new Dictionary(); IDictionary expectedRequestHeaders = new Dictionary(); + IDictionary expectedPostData = null; // Only used for Cloud Shell switch (managedIdentitySourceType) { case ManagedIdentitySource.AppService: httpMessageHandler.ExpectedMethod = HttpMethod.Get; - expectedQueryParams.Add("api-version", "2019-08-01"); + expectedQueryParams.Add("api-version", "2025-03-30"); expectedQueryParams.Add("resource", resource); expectedRequestHeaders.Add("X-IDENTITY-HEADER", "secret"); break; @@ -439,7 +455,10 @@ private static MockHttpMessageHandler BuildMockHandlerForManagedIdentitySource(M httpMessageHandler.ExpectedMethod = HttpMethod.Post; expectedRequestHeaders.Add("Metadata", "true"); expectedRequestHeaders.Add("ContentType", "application/x-www-form-urlencoded"); - httpMessageHandler.ExpectedPostData = new Dictionary { { "resource", resource } }; + expectedPostData = new Dictionary + { + { "resource", resource } + }; break; case ManagedIdentitySource.ServiceFabric: httpMessageHandler.ExpectedMethod = HttpMethod.Get; @@ -456,11 +475,45 @@ private static MockHttpMessageHandler BuildMockHandlerForManagedIdentitySource(M break; } + var manager = new CommonCryptographyManager(); + var value = manager.CreateSha256HashHex(TestConstants.ATSecret); + + // If capabilityEnabled, add "xms_cc": "cp1" + if (capabilityEnabled) + { + if (managedIdentitySourceType == ManagedIdentitySource.AppService + || managedIdentitySourceType == ManagedIdentitySource.ServiceFabric) + { + expectedQueryParams.Add("xms_cc", "cp1,cp2"); + } + } + else + { + notExpectedQueryParams.Add("xms_cc", "cp1,cp2"); + } + + if (claimsEnabled) + { + if (managedIdentitySourceType == ManagedIdentitySource.AppService + || managedIdentitySourceType == ManagedIdentitySource.ServiceFabric) + { + expectedQueryParams.Add("token_sha256_to_refresh", manager.CreateSha256HashHex(TestConstants.ATSecret)); + } + } + else + { + notExpectedQueryParams.Add("token_sha256_to_refresh", manager.CreateSha256HashHex(TestConstants.ATSecret)); + } + if (managedIdentitySourceType != ManagedIdentitySource.CloudShell) { httpMessageHandler.ExpectedQueryParams = expectedQueryParams; } - + else + { + httpMessageHandler.ExpectedPostData = expectedPostData; + } + httpMessageHandler.ExpectedRequestHeaders = expectedRequestHeaders; return httpMessageHandler; diff --git a/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHttpMessageHandler.cs b/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHttpMessageHandler.cs index 5b1fb2293b..e067d640b8 100644 --- a/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHttpMessageHandler.cs +++ b/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHttpMessageHandler.cs @@ -25,6 +25,7 @@ internal class MockHttpMessageHandler : HttpClientHandler public IDictionary ExpectedRequestHeaders { get; set; } public IList UnexpectedRequestHeaders { get; set; } public IDictionary UnExpectedPostData { get; set; } + public IDictionary NotExpectedQueryParams { get; set; } public HttpMethod ExpectedMethod { get; set; } public Exception ExceptionToThrow { get; set; } @@ -65,7 +66,9 @@ protected override async Task SendAsync(HttpRequestMessage Assert.AreEqual(ExpectedMethod, request.Method); - ValidateQueryParams(uri); + ValidateExpectedQueryParams(uri); + + ValidateNotExpectedQueryParams(uri); await ValidatePostDataAsync(request).ConfigureAwait(false); @@ -80,12 +83,12 @@ protected override async Task SendAsync(HttpRequestMessage return ResponseMessage; } - private void ValidateQueryParams(Uri uri) + private void ValidateExpectedQueryParams(Uri uri) { if (ExpectedQueryParams != null && ExpectedQueryParams.Any()) { Assert.IsFalse(string.IsNullOrEmpty(uri.Query), $"Provided url ({uri.AbsoluteUri}) does not contain query parameters as expected."); - var inputQp = CoreHelpers.ParseKeyValueList(uri.Query.Substring(1), '&', false, null); + Dictionary inputQp = CoreHelpers.ParseKeyValueList(uri.Query.Substring(1), '&', false, null); Assert.AreEqual(ExpectedQueryParams.Count, inputQp.Count, "Different number of query params."); foreach (var key in ExpectedQueryParams.Keys) { @@ -95,6 +98,35 @@ private void ValidateQueryParams(Uri uri) } } + private void ValidateNotExpectedQueryParams(Uri uri) + { + if (NotExpectedQueryParams != null && NotExpectedQueryParams.Any()) + { + // Parse actual query params again (or reuse inputQp if you like) + Dictionary actualQueryParams = CoreHelpers.ParseKeyValueList(uri.Query.Substring(1), '&', false, null); + List unexpectedKeysFound = new List(); + + foreach (KeyValuePair kvp in NotExpectedQueryParams) + { + // Check if the request's query has this key + if (actualQueryParams.TryGetValue(kvp.Key, out string value)) + { + // Optionally, also check if we care about matching the *value*: + if (string.Equals(value, kvp.Value, StringComparison.OrdinalIgnoreCase)) + { + unexpectedKeysFound.Add(kvp.Key); + } + } + } + + // Fail if any "not expected" key/value pairs were found + Assert.IsTrue( + unexpectedKeysFound.Count == 0, + $"Did not expect to find these query parameter keys/values: {string.Join(", ", unexpectedKeysFound)}" + ); + } + } + private async Task ValidatePostDataAsync(HttpRequestMessage request) { if (request.Method != HttpMethod.Get && request.Content != null) diff --git a/tests/Microsoft.Identity.Test.Integration.netcore/HeadlessTests/ManagedIdentityTests.NetFwk.cs b/tests/Microsoft.Identity.Test.Integration.netcore/HeadlessTests/ManagedIdentityTests.NetFwk.cs index 70b7caf3bf..0b5972429d 100644 --- a/tests/Microsoft.Identity.Test.Integration.netcore/HeadlessTests/ManagedIdentityTests.NetFwk.cs +++ b/tests/Microsoft.Identity.Test.Integration.netcore/HeadlessTests/ManagedIdentityTests.NetFwk.cs @@ -128,6 +128,68 @@ public async Task AcquireMSITokenAsync(MsiAzureResource azureResource, string us } } + [DataTestMethod] + [DataRow(MsiAzureResource.WebApp, "", DisplayName = "System_Identity_Web_App")] + [DataRow(MsiAzureResource.WebApp, UserAssignedClientID, UserAssignedIdentityId.ClientId, DisplayName = "ClientId_Web_App")] + [DataRow(MsiAzureResource.WebApp, UamiResourceId, UserAssignedIdentityId.ResourceId, DisplayName = "ResourceID_Web_App")] + [DataRow(MsiAzureResource.WebApp, UserAssignedObjectID, UserAssignedIdentityId.ObjectId, DisplayName = "ObjectID_Web_App")] + public async Task AcquireMSITokenWithClaimsAsync( + MsiAzureResource azureResource, + string userIdentity, + UserAssignedIdentityId userAssignedIdentityId = UserAssignedIdentityId.None) + { + using (new EnvVariableContext()) + { + // ---------- Arrange ---------- + var envVariables = await GetEnvironmentVariablesAsync(azureResource).ConfigureAwait(false); + SetEnvironmentVariables(envVariables); + + string uri = s_baseURL + $"MSIToken?azureresource={azureResource}&uri="; + + IManagedIdentityApplication mia = + CreateMIAWithProxy(uri, userIdentity, userAssignedIdentityId); + + // ---------- Act & Assert 1 ---------- + AuthenticationResult result1 = await mia + .AcquireTokenForManagedIdentity(s_msi_scopes) + .ExecuteAsync() + .ConfigureAwait(false); + + Assert.AreEqual("Bearer", result1.TokenType); + Assert.AreEqual(TokenSource.IdentityProvider, + result1.AuthenticationResultMetadata.TokenSource); + CoreAssert.IsWithinRange( + DateTimeOffset.UtcNow, + result1.ExpiresOn, + TimeSpan.FromHours(24)); + + // ---------- Act & Assert 2 (cache hit) ---------- + AuthenticationResult result2 = await mia + .AcquireTokenForManagedIdentity(s_msi_scopes) + .ExecuteAsync() + .ConfigureAwait(false); + + Assert.IsTrue(result2.Scopes.All(s_msi_scopes.Contains)); + Assert.AreEqual(TokenSource.Cache, + result2.AuthenticationResultMetadata.TokenSource); + Assert.AreEqual(result1.AccessToken, result2.AccessToken, // sanity + "Second call should come from cache"); + + // ---------- Act & Assert 3 (claims → bypass_cache) ---------- + const string claimsJson = TestConstants.Claims; + + AuthenticationResult result3 = await mia + .AcquireTokenForManagedIdentity(s_msi_scopes) + .WithClaims(claimsJson) + .ExecuteAsync() + .ConfigureAwait(false); + + // Token source should now be IdentityProvider again + Assert.AreEqual(TokenSource.IdentityProvider, + result3.AuthenticationResultMetadata.TokenSource); + } + } + [TestMethod] public async Task AcquireMsiToken_ForTokenExchangeResource_Successfully() { @@ -449,7 +511,7 @@ private IManagedIdentityApplication CreateMIAWithProxy(string url, string userAs // Disabling shared cache options to avoid cross test pollution. builder.Config.AccessorOptions = null; - IManagedIdentityApplication mia = builder + IManagedIdentityApplication mia = builder.WithClientCapabilities(new[] { "cp1" }) .WithHttpManager(proxyHttpManager).Build(); return mia; diff --git a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ManagedIdentityTests.cs b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ManagedIdentityTests.cs index 84d399a51d..0a71102a32 100644 --- a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ManagedIdentityTests.cs +++ b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ManagedIdentityTests.cs @@ -18,6 +18,8 @@ using Microsoft.Identity.Test.Common.Core.Helpers; using Microsoft.Identity.Test.Common.Core.Mocks; using Microsoft.VisualStudio.TestTools.UnitTesting; +using OpenTelemetry.Resources; +using Microsoft.Identity.Client.PlatformsCommon.Shared; using static Microsoft.Identity.Test.Common.Core.Helpers.ManagedIdentityTestUtil; namespace Microsoft.Identity.Test.Unit.ManagedIdentityTests @@ -318,7 +320,9 @@ public async Task ManagedIdentityWithClaimsAndCapabilitiesTestAsync( endpoint, Resource, MockHelpers.GetMsiSuccessfulResponse(), - managedIdentitySource); + managedIdentitySource, + claimsEnabled: false, + capabilityEnabled: true); var result = await mi.AcquireTokenForManagedIdentity(scope).ExecuteAsync().ConfigureAwait(false); @@ -338,7 +342,9 @@ public async Task ManagedIdentityWithClaimsAndCapabilitiesTestAsync( endpoint, scope, MockHelpers.GetMsiSuccessfulResponse(), - managedIdentitySource); + managedIdentitySource, + claimsEnabled: true, + capabilityEnabled: true); // Acquire token with force refresh result = await mi.AcquireTokenForManagedIdentity(scope).WithClaims(TestConstants.Claims) @@ -400,7 +406,8 @@ public async Task ManagedIdentityWithClaimsTestAsync( endpoint, scope, MockHelpers.GetMsiSuccessfulResponse(), - managedIdentitySource); + managedIdentitySource, + claimsEnabled: true); // Acquire token with force refresh result = await mi.AcquireTokenForManagedIdentity(scope).WithClaims(TestConstants.Claims) @@ -412,6 +419,56 @@ public async Task ManagedIdentityWithClaimsTestAsync( } } + [DataTestMethod] + [DataRow(AppServiceEndpoint, Resource, ManagedIdentitySource.AppService)] + [DataRow(ImdsEndpoint, Resource, ManagedIdentitySource.Imds)] + [DataRow(AzureArcEndpoint, Resource, ManagedIdentitySource.AzureArc)] + [DataRow(CloudShellEndpoint, Resource, ManagedIdentitySource.CloudShell)] + [DataRow(ServiceFabricEndpoint, Resource, ManagedIdentitySource.ServiceFabric)] + [DataRow(MachineLearningEndpoint, Resource, ManagedIdentitySource.MachineLearning)] + public async Task ManagedIdentityWithCapabilitiesTestAsync( + string endpoint, + string scope, + ManagedIdentitySource managedIdentitySource) + { + using (new EnvVariableContext()) + using (var httpManager = new MockHttpManager()) + { + SetEnvironmentVariables(managedIdentitySource, endpoint); + + var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) + .WithClientCapabilities(TestConstants.ClientCapabilities) + .WithHttpManager(httpManager); + + // Disabling shared cache options to avoid cross test pollution. + miBuilder.Config.AccessorOptions = null; + + var mi = miBuilder.Build(); + + httpManager.AddManagedIdentityMockHandler( + endpoint, + Resource, + MockHelpers.GetMsiSuccessfulResponse(), + managedIdentitySource, + claimsEnabled: false, + capabilityEnabled: true); + + var result = await mi.AcquireTokenForManagedIdentity(scope).ExecuteAsync().ConfigureAwait(false); + + Assert.IsNotNull(result); + Assert.IsNotNull(result.AccessToken); + Assert.AreEqual(TokenSource.IdentityProvider, result.AuthenticationResultMetadata.TokenSource); + + // Acquire token from cache + result = await mi.AcquireTokenForManagedIdentity(scope) + .ExecuteAsync().ConfigureAwait(false); + + Assert.IsNotNull(result); + Assert.IsNotNull(result.AccessToken); + Assert.AreEqual(TokenSource.Cache, result.AuthenticationResultMetadata.TokenSource); + } + } + [DataTestMethod] [DataRow("user.read", ManagedIdentitySource.AppService, AppServiceEndpoint)] [DataRow("https://management.core.windows.net//user_impersonation", ManagedIdentitySource.AppService, AppServiceEndpoint)]