Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ internal class AcquireTokenForManagedIdentityParameters : IAcquireTokenParameter

public string Resource { get; set; }

public string Claims { get; set; }

public string RevokedTokenHash { get; set; }

public void LogParameters(ILoggerAdapter logger)
{
if (logger.IsLoggingEnabled(LogLevel.Info))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
using Microsoft.Identity.Client.Core;
using Microsoft.Identity.Client.ManagedIdentity;
using Microsoft.Identity.Client.OAuth2;
using Microsoft.Identity.Client.PlatformsCommon.Interfaces;
using Microsoft.Identity.Client.Utils;

namespace Microsoft.Identity.Client.Internal.Requests
Expand All @@ -18,6 +19,7 @@ internal class ManagedIdentityAuthRequest : RequestBase
{
private readonly AcquireTokenForManagedIdentityParameters _managedIdentityParameters;
private static readonly SemaphoreSlim s_semaphoreSlim = new SemaphoreSlim(1, 1);
private readonly ICryptographyManager _cryptoManager;

public ManagedIdentityAuthRequest(
IServiceBundle serviceBundle,
Expand All @@ -26,71 +28,105 @@ public ManagedIdentityAuthRequest(
: base(serviceBundle, authenticationRequestParameters, managedIdentityParameters)
{
_managedIdentityParameters = managedIdentityParameters;
_cryptoManager = serviceBundle.PlatformProxy.CryptographyManager;
}

protected override async Task<AuthenticationResult> ExecuteAsync(CancellationToken cancellationToken)
{
AuthenticationResult authResult = null;
ILoggerAdapter logger = AuthenticationRequestParameters.RequestContext.Logger;

// Skip checking cache when force refresh or claims is specified
if (_managedIdentityParameters.ForceRefresh || !string.IsNullOrEmpty(AuthenticationRequestParameters.Claims))
// 1. FIRST, handle ForceRefresh
if (_managedIdentityParameters.ForceRefresh)
{
AuthenticationRequestParameters.RequestContext.ApiEvent.CacheInfo = CacheRefreshReason.ForceRefreshOrClaims;

logger.Info("[ManagedIdentityRequest] Skipped looking for a cached access token because ForceRefresh or Claims were set. " +
"This means either a force refresh was requested or claims were present.");
logger.Info("[ManagedIdentityRequest] Skipped using the cache because ForceRefresh was set.");

// We still respect claims if present
_managedIdentityParameters.Claims = AuthenticationRequestParameters.Claims;

// Straight to the MI endpoint
authResult = await GetAccessTokenAsync(cancellationToken, logger).ConfigureAwait(false);
return authResult;
}

// 2. Otherwise, look for a cached token
MsalAccessTokenCacheItem cachedAccessTokenItem = await GetCachedAccessTokenAsync().ConfigureAwait(false);

// No access token or cached access token needs to be refreshed
// If we have claims, we do NOT use the cached token (but we still need it to compute the hash).
if (!string.IsNullOrEmpty(AuthenticationRequestParameters.Claims))
{
_managedIdentityParameters.Claims = AuthenticationRequestParameters.Claims;
AuthenticationRequestParameters.RequestContext.ApiEvent.CacheInfo = CacheRefreshReason.ForceRefreshOrClaims;

// If there is a cached token, compute its hash for the “bad token” scenario
if (cachedAccessTokenItem != null)
{
string cachedTokenHash = _cryptoManager.CreateSha256HashHex(cachedAccessTokenItem.Secret);
_managedIdentityParameters.RevokedTokenHash = cachedTokenHash;

logger.Info("[ManagedIdentityRequest] Claims are present. Computed hash of the cached (bad) token. " +
"Will now request a fresh token from the MI endpoint.");
}
else
{
logger.Info("[ManagedIdentityRequest] Claims are present, but no cached token was found. " +
"Requesting a fresh token from the MI endpoint without a bad-token hash.");
}

// In both cases, we skip using the cached token and get a new one
authResult = await GetAccessTokenAsync(cancellationToken, logger).ConfigureAwait(false);
return authResult;
}

// 3. If we have no ForceRefresh and no claims, we can use the cache
if (cachedAccessTokenItem != null)
{
// Found a valid token in cache
authResult = CreateAuthenticationResultFromCache(cachedAccessTokenItem);

logger.Info("[ManagedIdentityRequest] Access token retrieved from cache.");

try
{
var proactivelyRefresh = SilentRequestHelper.NeedsRefresh(cachedAccessTokenItem);

// If needed, refreshes token in the background
{
// If token is close to expiry, proactively refresh it in the background
bool proactivelyRefresh = SilentRequestHelper.NeedsRefresh(cachedAccessTokenItem);
if (proactivelyRefresh)
{
logger.Info("[ManagedIdentityRequest] Initiating a proactive refresh.");

AuthenticationRequestParameters.RequestContext.ApiEvent.CacheInfo = CacheRefreshReason.ProactivelyRefreshed;

SilentRequestHelper.ProcessFetchInBackground(
cachedAccessTokenItem,
() =>
{
// Use a linked token source, in case the original cancellation token source is disposed before this background task completes.
using var tokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
return GetAccessTokenAsync(tokenSource.Token, logger);
}, logger, ServiceBundle, AuthenticationRequestParameters.RequestContext.ApiEvent,
AuthenticationRequestParameters.RequestContext.ApiEvent.CallerSdkApiId,
AuthenticationRequestParameters.RequestContext.ApiEvent.CallerSdkVersion);
cachedAccessTokenItem,
() =>
{
// Use a linked token source, in case the original cts is disposed
using var tokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
return GetAccessTokenAsync(tokenSource.Token, logger);
},
logger,
ServiceBundle,
AuthenticationRequestParameters.RequestContext.ApiEvent,
AuthenticationRequestParameters.RequestContext.ApiEvent.CallerSdkApiId,
AuthenticationRequestParameters.RequestContext.ApiEvent.CallerSdkVersion);
}
}
catch (MsalServiceException e)
{
// If background refresh fails, we handle the exception
return await HandleTokenRefreshErrorAsync(e, cachedAccessTokenItem).ConfigureAwait(false);
}
}
else
{
// No AT in the cache
// No cached token
if (AuthenticationRequestParameters.RequestContext.ApiEvent.CacheInfo != CacheRefreshReason.Expired)
{
AuthenticationRequestParameters.RequestContext.ApiEvent.CacheInfo = CacheRefreshReason.NoCachedAccessToken;
}

logger.Info("[ManagedIdentityRequest] No cached access token. Getting a token from the managed identity endpoint.");
logger.Info("[ManagedIdentityRequest] No cached access token found. " +
"Getting a token from the managed identity endpoint.");
authResult = await GetAccessTokenAsync(cancellationToken, logger).ConfigureAwait(false);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
using System.Net;
using Microsoft.Identity.Client.ApiConfig.Parameters;
using System.Text;
using System.Collections.Generic;
using System.Linq;
#if SUPPORTS_SYSTEM_TEXT_JSON
using System.Text.Json;
#else
Expand Down Expand Up @@ -48,7 +50,7 @@ public virtual async Task<ManagedIdentityResponse> AuthenticateAsync(
// Convert the scopes to a resource string.
string resource = parameters.Resource;

ManagedIdentityRequest request = CreateRequest(resource);
ManagedIdentityRequest request = CreateRequest(resource, parameters);

_requestContext.Logger.Info("[Managed Identity] Sending request to managed identity endpoints.");

Expand Down Expand Up @@ -130,7 +132,7 @@ protected virtual Task<ManagedIdentityResponse> HandleResponseAsync(
throw exception;
}

protected abstract ManagedIdentityRequest CreateRequest(string resource);
protected abstract ManagedIdentityRequest CreateRequest(string resource, AcquireTokenForManagedIdentityParameters parameters);

protected ManagedIdentityResponse GetSuccessfulResponse(HttpResponse response)
{
Expand Down Expand Up @@ -298,5 +300,52 @@ private static void CreateAndThrowException(string errorCode,

throw exception;
}

/// <summary>
/// Sets the claims and capabilities in the request.
/// </summary>
/// <param name="request"></param>
/// <param name="parameters"></param>
protected virtual void ApplyClaimsAndCapabilities(
ManagedIdentityRequest request,
AcquireTokenForManagedIdentityParameters parameters)
{
IEnumerable<string> clientCapabilities = _requestContext.ServiceBundle.Config.ClientCapabilities;

// Set xms_cc only if clientCapabilities exist
if (clientCapabilities != null && clientCapabilities.Any())
{
SetRequestParameter(request, "xms_cc", string.Join(",", clientCapabilities));
_requestContext.Logger.Info("[Managed Identity] Adding client capabilities (xms_cc) to Managed Identity request.");
}

// Only include 'token_sha256_to_refresh' if we have both Claims and the old token's hash
if (!string.IsNullOrEmpty(parameters.Claims) &&
!string.IsNullOrEmpty(parameters.RevokedTokenHash))
{
SetRequestParameter(request, "token_sha256_to_refresh", parameters.RevokedTokenHash);
_requestContext.Logger.Info(
"[Managed Identity] Passing SHA-256 of the 'bad' token to Managed Identity endpoint."
);
}
}

/// <summary>
/// Sets the request parameter in either the query or body based on the request method.
/// </summary>
/// <param name="request"></param>
/// <param name="key"></param>
/// <param name="value"></param>
protected void SetRequestParameter(ManagedIdentityRequest request, string key, string value)
{
if (request.Method == HttpMethod.Post)
{
request.BodyParameters[key] = value;
}
else
{
request.QueryParameters[key] = value;
}
}
}
}
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using System;
using System.Collections.Generic;
using System.Globalization;
using System.Linq;
using Microsoft.Identity.Client.ApiConfig.Parameters;
using Microsoft.Identity.Client.Core;
using Microsoft.Identity.Client.Internal;
using Microsoft.Identity.Client.Utils;
Expand All @@ -13,7 +15,7 @@ namespace Microsoft.Identity.Client.ManagedIdentity
internal class AppServiceManagedIdentitySource : AbstractManagedIdentity
{
// MSI Constants. Docs for MSI are available here https://docs.microsoft.com/azure/app-service/overview-managed-identity
private const string AppServiceMsiApiVersion = "2019-08-01";
private const string AppServiceMsiApiVersion = "2025-03-30";
private const string SecretHeaderName = "X-IDENTITY-HEADER";

private readonly Uri _endpoint;
Expand Down Expand Up @@ -65,14 +67,16 @@ private static bool TryValidateEnvVars(string msiEndpoint, ILoggerAdapter logger
return true;
}

protected override ManagedIdentityRequest CreateRequest(string resource)
protected override ManagedIdentityRequest CreateRequest(string resource, AcquireTokenForManagedIdentityParameters parameters)
{
ManagedIdentityRequest request = new(System.Net.Http.HttpMethod.Get, _endpoint);

request.Headers.Add(SecretHeaderName, _secret);
request.QueryParameters["api-version"] = AppServiceMsiApiVersion;
request.QueryParameters["resource"] = resource;

ApplyClaimsAndCapabilities(request, parameters);

switch (_requestContext.ServiceBundle.Config.ManagedIdentityId.IdType)
{
case AppConfig.ManagedIdentityIdType.ClientId:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ private AzureArcManagedIdentitySource(Uri endpoint, RequestContext requestContex
}
}

protected override ManagedIdentityRequest CreateRequest(string resource)
protected override ManagedIdentityRequest CreateRequest(string resource, AcquireTokenForManagedIdentityParameters parameters)
{
ManagedIdentityRequest request = new ManagedIdentityRequest(System.Net.Http.HttpMethod.Get, _endpoint);

Expand Down Expand Up @@ -118,7 +118,7 @@ protected override async Task<ManagedIdentityResponse> HandleResponseAsync(

var authHeaderValue = "Basic " + File.ReadAllText(splitChallenge[1]);

ManagedIdentityRequest request = CreateRequest(parameters.Resource);
ManagedIdentityRequest request = CreateRequest(parameters.Resource, parameters);

_requestContext.Logger.Verbose(() => "[Managed Identity] Adding authorization header to the request.");
request.Headers.Add("Authorization", authHeaderValue);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System;
using System.Globalization;
using System.Net.Http;
using Microsoft.Identity.Client.ApiConfig.Parameters;
using Microsoft.Identity.Client.Core;
using Microsoft.Identity.Client.Internal;

Expand Down Expand Up @@ -73,15 +74,15 @@ private CloudShellManagedIdentitySource(Uri endpoint, RequestContext requestCont
}
}

protected override ManagedIdentityRequest CreateRequest(string resource)
protected override ManagedIdentityRequest CreateRequest(string resource, AcquireTokenForManagedIdentityParameters parameters)
{
ManagedIdentityRequest request = new ManagedIdentityRequest(HttpMethod.Post, _endpoint);

request.Headers.Add("ContentType", "application/x-www-form-urlencoded");
request.Headers.Add("Metadata", "true");

request.BodyParameters.Add("resource", resource);

return request;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ internal ImdsManagedIdentitySource(RequestContext requestContext) :
requestContext.Logger.Verbose(() => "[Managed Identity] Creating IMDS managed identity source. Endpoint URI: " + _imdsEndpoint);
}

protected override ManagedIdentityRequest CreateRequest(string resource)
protected override ManagedIdentityRequest CreateRequest(string resource, AcquireTokenForManagedIdentityParameters parameters)
{
ManagedIdentityRequest request = new(HttpMethod.Get, _imdsEndpoint);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

using System;
using System.Globalization;
using Microsoft.Identity.Client.ApiConfig.Parameters;
using Microsoft.Identity.Client.Core;
using Microsoft.Identity.Client.Internal;

Expand Down Expand Up @@ -62,7 +63,7 @@ private static bool TryValidateEnvVars(string msiEndpoint, ILoggerAdapter logger
return true;
}

protected override ManagedIdentityRequest CreateRequest(string resource)
protected override ManagedIdentityRequest CreateRequest(string resource, AcquireTokenForManagedIdentityParameters parameters)
{
ManagedIdentityRequest request = new(System.Net.Http.HttpMethod.Get, _endpoint);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System.Globalization;
using System.Net.Http;
using System.Net.Security;
using Microsoft.Identity.Client.ApiConfig.Parameters;
using Microsoft.Identity.Client.Core;
using Microsoft.Identity.Client.Internal;

Expand Down Expand Up @@ -32,9 +33,9 @@ public static AbstractManagedIdentity Create(RequestContext requestContext)
var exception = MsalServiceExceptionFactory.CreateManagedIdentityException(
MsalError.InvalidManagedIdentityEndpoint,
errorMessage,
null,
null,
ManagedIdentitySource.ServiceFabric,
null);
null);

throw exception;
}
Expand All @@ -54,7 +55,7 @@ internal override bool ValidateServerCertificate(HttpRequestMessage message, Sys
return string.Equals(certificate.GetCertHashString(), EnvironmentVariables.IdentityServerThumbprint, StringComparison.OrdinalIgnoreCase);
}

private ServiceFabricManagedIdentitySource(RequestContext requestContext, Uri endpoint, string identityHeaderValue) :
private ServiceFabricManagedIdentitySource(RequestContext requestContext, Uri endpoint, string identityHeaderValue) :
base(requestContext, ManagedIdentitySource.ServiceFabric)
{
_endpoint = endpoint;
Expand All @@ -66,7 +67,7 @@ private ServiceFabricManagedIdentitySource(RequestContext requestContext, Uri en
}
}

protected override ManagedIdentityRequest CreateRequest(string resource)
protected override ManagedIdentityRequest CreateRequest(string resource, AcquireTokenForManagedIdentityParameters parameters)
{
ManagedIdentityRequest request = new ManagedIdentityRequest(HttpMethod.Get, _endpoint);

Expand All @@ -75,6 +76,8 @@ protected override ManagedIdentityRequest CreateRequest(string resource)
request.QueryParameters["api-version"] = ServiceFabricMsiApiVersion;
request.QueryParameters["resource"] = resource;

ApplyClaimsAndCapabilities(request, parameters);

switch (_requestContext.ServiceBundle.Config.ManagedIdentityId.IdType)
{
case AppConfig.ManagedIdentityIdType.ClientId:
Expand Down
Loading