Skip to content

Commit

Permalink
Lozensky/enable managed identity (#2650)
Browse files Browse the repository at this point in the history
* Add logic to default to using managed identity if provided.

* remove blank line

* Updated with caching and new design

* rearranging methods

* made GetOrBuildManagedIdentityApplication async

* added unit test for application caching

* finished unit test first draft

* minor changes

* changed according to PR feedback

* Add logic to default to using managed identity if provided.

* remove blank line

* Updated with caching and new design

* rearranging methods

* made GetOrBuildManagedIdentityApplication async

* added unit test for application caching

* finished unit test first draft

* minor changes

* changed according to PR feedback

* Rebase onto main

* added system-assigned managed identity e2e test

* Implemented PR feedback

* changing test to use user-assigned managed identity

* fixing tests

* Added configuration to e2e test

* moved build to after identity options config

* moving builder back

* fixed bug with TokenAcquisitionOptions/DefaultAuthorizationHeaderProvider

* simplified e2e test

* added concurrency test and removed reflection

* addressed PR comments and removed unnecessary code

* removed extra space

* addressed PR feedback

* making changes per PR comments

* removing test traces

---------

Co-authored-by: Jean-Marc Prieur <jmprieur@microsoft.com>
  • Loading branch information
JoshLozensky and jmprieur authored Jan 28, 2024
1 parent 59ce6ad commit 8909553
Show file tree
Hide file tree
Showing 9 changed files with 363 additions and 45 deletions.
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -350,3 +350,7 @@ MigrationBackup/
# Ionide (cross platform F# VS Code tools) working folder
.ionide/
/tools/app-provisioning-tool/testwebapp

# Playwright e2e testing trace files
/tests/E2E Tests/PlaywrightTraces
/tests/IntegrationTests/PlaywrightTraces
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,10 @@ public async Task<string> CreateAuthorizationHeaderForUserAsync(
}

/// <inheritdoc/>
public async Task<string> CreateAuthorizationHeaderForAppAsync(string scopes, AuthorizationHeaderProviderOptions? downstreamApiOptions = null, CancellationToken cancellationToken = default)
public async Task<string> CreateAuthorizationHeaderForAppAsync(
string scopes,
AuthorizationHeaderProviderOptions? downstreamApiOptions = null,
CancellationToken cancellationToken = default)
{
var result = await _tokenAcquisition.GetAuthenticationResultForAppAsync(
scopes,
Expand All @@ -47,7 +50,9 @@ public async Task<string> CreateAuthorizationHeaderForAppAsync(string scopes, Au
return result.CreateAuthorizationHeader();
}

private static TokenAcquisitionOptions CreateTokenAcquisitionOptionsFromApiOptions(AuthorizationHeaderProviderOptions? downstreamApiOptions, CancellationToken cancellationToken)
private static TokenAcquisitionOptions CreateTokenAcquisitionOptionsFromApiOptions(
AuthorizationHeaderProviderOptions? downstreamApiOptions,
CancellationToken cancellationToken)
{
return new TokenAcquisitionOptions()
{
Expand All @@ -58,6 +63,7 @@ private static TokenAcquisitionOptions CreateTokenAcquisitionOptionsFromApiOptio
ExtraQueryParameters = downstreamApiOptions?.AcquireTokenOptions.ExtraQueryParameters,
ForceRefresh = downstreamApiOptions?.AcquireTokenOptions.ForceRefresh ?? false,
LongRunningWebApiSessionKey = downstreamApiOptions?.AcquireTokenOptions.LongRunningWebApiSessionKey,
ManagedIdentity = downstreamApiOptions?.AcquireTokenOptions.ManagedIdentity,
Tenant = downstreamApiOptions?.AcquireTokenOptions.Tenant,
UserFlow = downstreamApiOptions?.AcquireTokenOptions.UserFlow,
PopPublicKey = downstreamApiOptions?.AcquireTokenOptions.PopPublicKey,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using System.Collections.Concurrent;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Identity.Abstractions;
using Microsoft.Identity.Client;
using Microsoft.Identity.Client.AppConfig;
using Microsoft.IdentityModel.Tokens;

namespace Microsoft.Identity.Web
{
/// <summary>
/// Portion of the TokenAcquisition class that handles logic unique to managed identity.
/// </summary>
internal partial class TokenAcquisition
{
private readonly ConcurrentDictionary<string, IManagedIdentityApplication> _managedIdentityApplicationsByClientId = new();
private readonly SemaphoreSlim _managedIdSemaphore = new(1, 1);
private const string SystemAssignedManagedIdentityKey = "SYSTEM";

/// <summary>
/// Gets a cached ManagedIdentityApplication object or builds a new one if not found.
/// </summary>
/// <param name="mergedOptions">The configuration options for the app.</param>
/// <param name="managedIdentityOptions">The configuration specific to managed identity.</param>
/// <returns>The application object used to request a token with managed identity.</returns>
internal async Task<IManagedIdentityApplication> GetOrBuildManagedIdentityApplication(
MergedOptions mergedOptions,
ManagedIdentityOptions managedIdentityOptions)
{
string key = GetCacheKeyForManagedId(managedIdentityOptions);

// Check if the application is already built, if so return it without grabbing the lock
if (_managedIdentityApplicationsByClientId.TryGetValue(key, out IManagedIdentityApplication? application))
{
return application;
}

// Lock the potential write of the dictionary to prevent multiple threads from creating the same application.
await _managedIdSemaphore.WaitAsync();
try
{
// Check if the application is already built (could happen between previous check and obtaining the key)
if (_managedIdentityApplicationsByClientId.TryGetValue(key, out application))
{
return application;
}

// Set managedIdentityId to the correct value for either system or user assigned
ManagedIdentityId managedIdentityId;
if (key == SystemAssignedManagedIdentityKey)
{
managedIdentityId = ManagedIdentityId.SystemAssigned;
}
else
{
managedIdentityId = ManagedIdentityId.WithUserAssignedClientId(key);
}

// Build the application
application = BuildManagedIdentityApplication(
managedIdentityId,
mergedOptions.ConfidentialClientApplicationOptions.EnablePiiLogging
);

// Add the application to the cache
_managedIdentityApplicationsByClientId.TryAdd(key, application);
}
finally
{
// Now that the dictionary is updated, release the semaphore
_managedIdSemaphore.Release();
}
return application;
}

/// <summary>
/// Creates a managed identity client application.
/// </summary>
/// <param name="managedIdentityId">Indicates if system-assigned or user-assigned managed identity is used.</param>
/// <param name="enablePiiLogging">Indicates if logging that may contain personally identifiable information is enabled.</param>
/// <returns>A managed identity application.</returns>
private IManagedIdentityApplication BuildManagedIdentityApplication(ManagedIdentityId managedIdentityId, bool enablePiiLogging)
{
return ManagedIdentityApplicationBuilder
.Create(managedIdentityId)
.WithLogging(
Log,
ConvertMicrosoftExtensionsLogLevelToMsal(_logger),
enablePiiLogging: enablePiiLogging)
.Build();
}

/// <summary>
/// Gets the key value for the Managed Identity cache, the default key for system-assigned identity is used if there is
/// no clientId for a user-assigned identity specified. The method is internal rather than private for testing purposes.
/// </summary>
/// <param name="managedIdOptions">Holds the clientId for managed identity if none is present.</param>
/// <returns>A key value for the Managed Identity cache.</returns>
internal static string GetCacheKeyForManagedId(ManagedIdentityOptions managedIdOptions)
{
if (managedIdOptions.UserAssignedClientId.IsNullOrEmpty())
{
return SystemAssignedManagedIdentityKey;
}
else
{
return managedIdOptions.UserAssignedClientId!;
}
}
}
}
42 changes: 31 additions & 11 deletions src/Microsoft.Identity.Web.TokenAcquisition/TokenAcquisition.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@
using Microsoft.Identity.Client;
using Microsoft.Identity.Client.Advanced;
using Microsoft.Identity.Client.Extensibility;
using Microsoft.Identity.Web.Experimental;
using Microsoft.Identity.Web.TokenCacheProviders;
using Microsoft.Identity.Web.TokenCacheProviders.InMemory;
using Microsoft.IdentityModel.JsonWebTokens;
using Microsoft.IdentityModel.Tokens;
using Microsoft.Identity.Web.Experimental;

namespace Microsoft.Identity.Web
{
Expand All @@ -47,9 +47,9 @@ class OAuthConstants
private readonly object _applicationSyncObj = new();

/// <summary>
/// Please call GetOrBuildConfidentialClientApplication instead of accessing this field directly.
/// Please call GetOrBuildConfidentialClientApplication instead of accessing _applicationsByAuthorityClientId directly.
/// </summary>
private readonly ConcurrentDictionary<string, IConfidentialClientApplication?> _applicationsByAuthorityClientId = new ConcurrentDictionary<string, IConfidentialClientApplication?>();
private readonly ConcurrentDictionary<string, IConfidentialClientApplication?> _applicationsByAuthorityClientId = new();
private bool _retryClientCertificate;
protected readonly IMsalHttpClientFactory _httpClientFactory;
protected readonly ILogger _logger;
Expand Down Expand Up @@ -115,7 +115,7 @@ public async Task<AcquireTokenResult> AddAccountToCacheFromAuthorizationCodeAsyn
_ = Throws.IfNull(authCodeRedemptionParameters.Scopes);
MergedOptions mergedOptions = _tokenAcquisitionHost.GetOptions(authCodeRedemptionParameters.AuthenticationScheme, out string effectiveAuthenticationScheme);

IConfidentialClientApplication? application=null;
IConfidentialClientApplication? application = null;
try
{
application = GetOrBuildConfidentialClientApplication(mergedOptions);
Expand Down Expand Up @@ -321,9 +321,10 @@ private void LogAuthResult(AuthenticationResult? authenticationResult)

/// <summary>
/// Acquires an authentication result from the authority configured in the app, for the confidential client itself (not on behalf of a user)
/// using the client credentials flow. See https://aka.ms/msal-net-client-credentials.
/// using either a client credentials or managed identity flow. See https://aka.ms/msal-net-client-credentials for client credentials or
/// https://aka.ms/Entra/ManagedIdentityOverview for managed identity.
/// </summary>
/// <param name="scope">The scope requested to access a protected API. For this flow (client credentials), the scope
/// <param name="scope">The scope requested to access a protected API. For these flows (client credentials or managed identity), the scope
/// should be of the form "{ResourceIdUri/.default}" for instance <c>https://management.azure.net/.default</c> or, for Microsoft
/// Graph, <c>https://graph.microsoft.com/.default</c> as the requested scopes are defined statically with the application registration
/// in the portal, and cannot be overridden in the application, as you can request a token for only one resource at a time (use
Expand Down Expand Up @@ -358,10 +359,28 @@ public async Task<AuthenticationResult> GetAuthenticationResultForAppAsync(
throw new ArgumentException(IDWebErrorMessage.ClientCredentialTenantShouldBeTenanted, nameof(tenant));
}

// If using managed identity
if (tokenAcquisitionOptions != null && tokenAcquisitionOptions.ManagedIdentity != null)
{
try
{
IManagedIdentityApplication managedIdApp = await GetOrBuildManagedIdentityApplication(
mergedOptions,
tokenAcquisitionOptions.ManagedIdentity
);
return await managedIdApp.AcquireTokenForManagedIdentity(scope).ExecuteAsync().ConfigureAwait(false);
}
catch (Exception ex)
{
Logger.TokenAcquisitionError(_logger, ex.Message, ex);
throw;
}
}

// Use MSAL to get the right token to call the API
var application = GetOrBuildConfidentialClientApplication(mergedOptions);

var builder = application
AcquireTokenForClientParameterBuilder builder = application
.AcquireTokenForClient(new[] { scope }.Except(_scopesRequestedByMsal))
.WithSendX5C(mergedOptions.SendX5C);

Expand Down Expand Up @@ -585,7 +604,6 @@ private bool IsInvalidClientCertificateOrSignedAssertionError(MsalServiceExcepti
_applicationsByAuthorityClientId.TryAdd(GetApplicationKey(mergedOptions), application);
}
}

return application;
}

Expand All @@ -599,7 +617,7 @@ private IConfidentialClientApplication BuildConfidentialClientApplication(Merged

try
{
var builder = ConfidentialClientApplicationBuilder
ConfidentialClientApplicationBuilder builder = ConfidentialClientApplicationBuilder
.CreateWithApplicationOptions(mergedOptions.ConfidentialClientApplicationOptions)
.WithHttpClientFactory(_httpClientFactory)
.WithLogging(
Expand Down Expand Up @@ -848,8 +866,10 @@ private static void CheckAssertionsForInjectionAttempt(string assertion, string
if (!assertion.IsNullOrEmpty() && assertion.Contains('&')) throw new ArgumentException(IDWebErrorMessage.InvalidAssertion, nameof(assertion));
if (!subAssertion.IsNullOrEmpty() && subAssertion.Contains('&')) throw new ArgumentException(IDWebErrorMessage.InvalidSubAssertion, nameof(subAssertion));
#else
if (!assertion.IsNullOrEmpty() && assertion.Contains('&', StringComparison.InvariantCultureIgnoreCase)) throw new ArgumentException(IDWebErrorMessage.InvalidAssertion, nameof(assertion));
if (!subAssertion.IsNullOrEmpty() && subAssertion.Contains('&', StringComparison.InvariantCultureIgnoreCase)) throw new ArgumentException(IDWebErrorMessage.InvalidSubAssertion, nameof(subAssertion));
if (!assertion.IsNullOrEmpty() && assertion.Contains('&', StringComparison.InvariantCultureIgnoreCase))
throw new ArgumentException(IDWebErrorMessage.InvalidAssertion, nameof(assertion));
if (!subAssertion.IsNullOrEmpty() && subAssertion.Contains('&', StringComparison.InvariantCultureIgnoreCase))
throw new ArgumentException(IDWebErrorMessage.InvalidSubAssertion, nameof(subAssertion));
#endif
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ public class TokenAcquisitionOptions : AcquireTokenOptions
PopClaim = PopClaim,
CancellationToken = CancellationToken,
LongRunningWebApiSessionKey = LongRunningWebApiSessionKey,
ManagedIdentity = ManagedIdentity,
};
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using System;
using Xunit;

namespace TokenAcquirerTests
{
public sealed class OnlyOnAzureDevopsFactAttribute : FactAttribute
{
public OnlyOnAzureDevopsFactAttribute()
{
if (IgnoreOnAzureDevopsFactAttribute.IsRunningOnAzureDevOps())
{
return;
}
Skip = "Ignored when not on Azure DevOps";
}
}
}
43 changes: 43 additions & 0 deletions tests/E2E Tests/TokenAcquirerTests/TokenAcquirer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@

using System;
using System.Linq;
using System.Net.Http;
using System.Security.Cryptography;
using System.Security.Cryptography.X509Certificates;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Options;
using Microsoft.Graph;
Expand Down Expand Up @@ -299,5 +301,46 @@ private static async Task CreateGraphClientAndAssert(TokenAcquirerFactory tokenA
Assert.NotNull(result.AccessToken);
}
}

public class AcquireTokenManagedIdentity
{
[OnlyOnAzureDevopsFact]
//[Fact]
public async Task AcquireTokenWithManagedIdentity_UserAssigned()
{
// Arrange
const string scope = "https://vault.azure.net/.default";
const string baseUrl = "https://vault.azure.net";
const string clientId = "9c5896db-a74a-4b1a-a259-74c5080a3a6a";
TokenAcquirerFactory tokenAcquirerFactory = TokenAcquirerFactory.GetDefaultInstance();
_ = tokenAcquirerFactory.Services;
IServiceProvider serviceProvider = tokenAcquirerFactory.Build();

// Act: Get the authorization header provider and add the options to tell it to use Managed Identity
IAuthorizationHeaderProvider? api = serviceProvider.GetRequiredService<IAuthorizationHeaderProvider>();
Assert.NotNull(api);
string result = await api.CreateAuthorizationHeaderForAppAsync(scope, GetAuthHeaderOptions_ManagedId(baseUrl, clientId));

// Assert: Make sure we got a token
Assert.False(string.IsNullOrEmpty(result));
}

private static AuthorizationHeaderProviderOptions GetAuthHeaderOptions_ManagedId(string baseUrl, string? userAssignedClientId=null)
{
ManagedIdentityOptions managedIdentityOptions = new()
{
UserAssignedClientId = userAssignedClientId
};
AcquireTokenOptions aquireTokenOptions = new()
{
ManagedIdentity = managedIdentityOptions
};
return new AuthorizationHeaderProviderOptions()
{
BaseUrl = baseUrl,
AcquireTokenOptions = aquireTokenOptions
};
}
}
#endif //FROM_GITHUB_ACTION
}
9 changes: 8 additions & 1 deletion tests/Microsoft.Identity.Web.Test.Common/TestConstants.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.Reflection;

namespace Microsoft.Identity.Web.Test.Common
{
Expand Down Expand Up @@ -186,8 +187,14 @@ public static class TestConstants
public static readonly string s_todoListServicePath = Path.DirectorySeparatorChar.ToString() + "TodoListService";


// TokenAcqusitionOptions
// TokenAcqusitionOptions and ManagedIdentityOptions
public static Guid s_correlationId = new Guid("6347d33d-941a-4c35-9912-a9cf54fb1b3e");
public const string UserAssignedManagedIdentityClientId = "3b57c42c-3201-4295-ae27-d6baec5b7027";
public const string UserAssignedManagedIdentityResourceId = "/subscriptions/c1686c51-b717-4fe0-9af3-24a20a41fb0c/" +
"resourcegroups/MSAL_MSI/providers/Microsoft.ManagedIdentity/userAssignedIdentities/" + "MSAL_MSI_USERID";
public const BindingFlags StaticPrivateFieldFlags = BindingFlags.GetField | BindingFlags.Static | BindingFlags.NonPublic;
public const BindingFlags InstancePrivateFieldFlags = BindingFlags.GetField | BindingFlags.Instance | BindingFlags.NonPublic;
public const BindingFlags StaticPrivateMethodFlags = BindingFlags.InvokeMethod | BindingFlags.Static | BindingFlags.NonPublic;

// AadIssuerValidation
public const string AadAuthority = "aadAuthority";
Expand Down
Loading

0 comments on commit 8909553

Please sign in to comment.