diff --git a/.gitignore b/.gitignore index f478d67ba..25da21fa9 100644 --- a/.gitignore +++ b/.gitignore @@ -350,3 +350,7 @@ MigrationBackup/ # Ionide (cross platform F# VS Code tools) working folder .ionide/ /tools/app-provisioning-tool/testwebapp + +# Playwright e2e testing trace files +/tests/E2E Tests/PlaywrightTraces +/tests/IntegrationTests/PlaywrightTraces \ No newline at end of file diff --git a/src/Microsoft.Identity.Web.TokenAcquisition/DefaultAuthorizationHeaderProvider.cs b/src/Microsoft.Identity.Web.TokenAcquisition/DefaultAuthorizationHeaderProvider.cs index 65295491d..a630a859f 100644 --- a/src/Microsoft.Identity.Web.TokenAcquisition/DefaultAuthorizationHeaderProvider.cs +++ b/src/Microsoft.Identity.Web.TokenAcquisition/DefaultAuthorizationHeaderProvider.cs @@ -37,7 +37,10 @@ public async Task CreateAuthorizationHeaderForUserAsync( } /// - public async Task CreateAuthorizationHeaderForAppAsync(string scopes, AuthorizationHeaderProviderOptions? downstreamApiOptions = null, CancellationToken cancellationToken = default) + public async Task CreateAuthorizationHeaderForAppAsync( + string scopes, + AuthorizationHeaderProviderOptions? downstreamApiOptions = null, + CancellationToken cancellationToken = default) { var result = await _tokenAcquisition.GetAuthenticationResultForAppAsync( scopes, @@ -47,7 +50,9 @@ public async Task CreateAuthorizationHeaderForAppAsync(string scopes, Au return result.CreateAuthorizationHeader(); } - private static TokenAcquisitionOptions CreateTokenAcquisitionOptionsFromApiOptions(AuthorizationHeaderProviderOptions? downstreamApiOptions, CancellationToken cancellationToken) + private static TokenAcquisitionOptions CreateTokenAcquisitionOptionsFromApiOptions( + AuthorizationHeaderProviderOptions? downstreamApiOptions, + CancellationToken cancellationToken) { return new TokenAcquisitionOptions() { @@ -58,6 +63,7 @@ private static TokenAcquisitionOptions CreateTokenAcquisitionOptionsFromApiOptio ExtraQueryParameters = downstreamApiOptions?.AcquireTokenOptions.ExtraQueryParameters, ForceRefresh = downstreamApiOptions?.AcquireTokenOptions.ForceRefresh ?? false, LongRunningWebApiSessionKey = downstreamApiOptions?.AcquireTokenOptions.LongRunningWebApiSessionKey, + ManagedIdentity = downstreamApiOptions?.AcquireTokenOptions.ManagedIdentity, Tenant = downstreamApiOptions?.AcquireTokenOptions.Tenant, UserFlow = downstreamApiOptions?.AcquireTokenOptions.UserFlow, PopPublicKey = downstreamApiOptions?.AcquireTokenOptions.PopPublicKey, diff --git a/src/Microsoft.Identity.Web.TokenAcquisition/TokenAcquisition.ManagedIdentity.cs b/src/Microsoft.Identity.Web.TokenAcquisition/TokenAcquisition.ManagedIdentity.cs new file mode 100644 index 000000000..a6a323213 --- /dev/null +++ b/src/Microsoft.Identity.Web.TokenAcquisition/TokenAcquisition.ManagedIdentity.cs @@ -0,0 +1,114 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.Collections.Concurrent; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Identity.Abstractions; +using Microsoft.Identity.Client; +using Microsoft.Identity.Client.AppConfig; +using Microsoft.IdentityModel.Tokens; + +namespace Microsoft.Identity.Web +{ + /// + /// Portion of the TokenAcquisition class that handles logic unique to managed identity. + /// + internal partial class TokenAcquisition + { + private readonly ConcurrentDictionary _managedIdentityApplicationsByClientId = new(); + private readonly SemaphoreSlim _managedIdSemaphore = new(1, 1); + private const string SystemAssignedManagedIdentityKey = "SYSTEM"; + + /// + /// Gets a cached ManagedIdentityApplication object or builds a new one if not found. + /// + /// The configuration options for the app. + /// The configuration specific to managed identity. + /// The application object used to request a token with managed identity. + internal async Task GetOrBuildManagedIdentityApplication( + MergedOptions mergedOptions, + ManagedIdentityOptions managedIdentityOptions) + { + string key = GetCacheKeyForManagedId(managedIdentityOptions); + + // Check if the application is already built, if so return it without grabbing the lock + if (_managedIdentityApplicationsByClientId.TryGetValue(key, out IManagedIdentityApplication? application)) + { + return application; + } + + // Lock the potential write of the dictionary to prevent multiple threads from creating the same application. + await _managedIdSemaphore.WaitAsync(); + try + { + // Check if the application is already built (could happen between previous check and obtaining the key) + if (_managedIdentityApplicationsByClientId.TryGetValue(key, out application)) + { + return application; + } + + // Set managedIdentityId to the correct value for either system or user assigned + ManagedIdentityId managedIdentityId; + if (key == SystemAssignedManagedIdentityKey) + { + managedIdentityId = ManagedIdentityId.SystemAssigned; + } + else + { + managedIdentityId = ManagedIdentityId.WithUserAssignedClientId(key); + } + + // Build the application + application = BuildManagedIdentityApplication( + managedIdentityId, + mergedOptions.ConfidentialClientApplicationOptions.EnablePiiLogging + ); + + // Add the application to the cache + _managedIdentityApplicationsByClientId.TryAdd(key, application); + } + finally + { + // Now that the dictionary is updated, release the semaphore + _managedIdSemaphore.Release(); + } + return application; + } + + /// + /// Creates a managed identity client application. + /// + /// Indicates if system-assigned or user-assigned managed identity is used. + /// Indicates if logging that may contain personally identifiable information is enabled. + /// A managed identity application. + private IManagedIdentityApplication BuildManagedIdentityApplication(ManagedIdentityId managedIdentityId, bool enablePiiLogging) + { + return ManagedIdentityApplicationBuilder + .Create(managedIdentityId) + .WithLogging( + Log, + ConvertMicrosoftExtensionsLogLevelToMsal(_logger), + enablePiiLogging: enablePiiLogging) + .Build(); + } + + /// + /// Gets the key value for the Managed Identity cache, the default key for system-assigned identity is used if there is + /// no clientId for a user-assigned identity specified. The method is internal rather than private for testing purposes. + /// + /// Holds the clientId for managed identity if none is present. + /// A key value for the Managed Identity cache. + internal static string GetCacheKeyForManagedId(ManagedIdentityOptions managedIdOptions) + { + if (managedIdOptions.UserAssignedClientId.IsNullOrEmpty()) + { + return SystemAssignedManagedIdentityKey; + } + else + { + return managedIdOptions.UserAssignedClientId!; + } + } + } +} diff --git a/src/Microsoft.Identity.Web.TokenAcquisition/TokenAcquisition.cs b/src/Microsoft.Identity.Web.TokenAcquisition/TokenAcquisition.cs index 493f44bbf..1ad176692 100644 --- a/src/Microsoft.Identity.Web.TokenAcquisition/TokenAcquisition.cs +++ b/src/Microsoft.Identity.Web.TokenAcquisition/TokenAcquisition.cs @@ -19,11 +19,11 @@ using Microsoft.Identity.Client; using Microsoft.Identity.Client.Advanced; using Microsoft.Identity.Client.Extensibility; +using Microsoft.Identity.Web.Experimental; using Microsoft.Identity.Web.TokenCacheProviders; using Microsoft.Identity.Web.TokenCacheProviders.InMemory; using Microsoft.IdentityModel.JsonWebTokens; using Microsoft.IdentityModel.Tokens; -using Microsoft.Identity.Web.Experimental; namespace Microsoft.Identity.Web { @@ -47,9 +47,9 @@ class OAuthConstants private readonly object _applicationSyncObj = new(); /// - /// Please call GetOrBuildConfidentialClientApplication instead of accessing this field directly. + /// Please call GetOrBuildConfidentialClientApplication instead of accessing _applicationsByAuthorityClientId directly. /// - private readonly ConcurrentDictionary _applicationsByAuthorityClientId = new ConcurrentDictionary(); + private readonly ConcurrentDictionary _applicationsByAuthorityClientId = new(); private bool _retryClientCertificate; protected readonly IMsalHttpClientFactory _httpClientFactory; protected readonly ILogger _logger; @@ -115,7 +115,7 @@ public async Task AddAccountToCacheFromAuthorizationCodeAsyn _ = Throws.IfNull(authCodeRedemptionParameters.Scopes); MergedOptions mergedOptions = _tokenAcquisitionHost.GetOptions(authCodeRedemptionParameters.AuthenticationScheme, out string effectiveAuthenticationScheme); - IConfidentialClientApplication? application=null; + IConfidentialClientApplication? application = null; try { application = GetOrBuildConfidentialClientApplication(mergedOptions); @@ -321,9 +321,10 @@ private void LogAuthResult(AuthenticationResult? authenticationResult) /// /// Acquires an authentication result from the authority configured in the app, for the confidential client itself (not on behalf of a user) - /// using the client credentials flow. See https://aka.ms/msal-net-client-credentials. + /// using either a client credentials or managed identity flow. See https://aka.ms/msal-net-client-credentials for client credentials or + /// https://aka.ms/Entra/ManagedIdentityOverview for managed identity. /// - /// The scope requested to access a protected API. For this flow (client credentials), the scope + /// The scope requested to access a protected API. For these flows (client credentials or managed identity), the scope /// should be of the form "{ResourceIdUri/.default}" for instance https://management.azure.net/.default or, for Microsoft /// Graph, https://graph.microsoft.com/.default as the requested scopes are defined statically with the application registration /// in the portal, and cannot be overridden in the application, as you can request a token for only one resource at a time (use @@ -358,10 +359,28 @@ public async Task GetAuthenticationResultForAppAsync( throw new ArgumentException(IDWebErrorMessage.ClientCredentialTenantShouldBeTenanted, nameof(tenant)); } + // If using managed identity + if (tokenAcquisitionOptions != null && tokenAcquisitionOptions.ManagedIdentity != null) + { + try + { + IManagedIdentityApplication managedIdApp = await GetOrBuildManagedIdentityApplication( + mergedOptions, + tokenAcquisitionOptions.ManagedIdentity + ); + return await managedIdApp.AcquireTokenForManagedIdentity(scope).ExecuteAsync().ConfigureAwait(false); + } + catch (Exception ex) + { + Logger.TokenAcquisitionError(_logger, ex.Message, ex); + throw; + } + } + // Use MSAL to get the right token to call the API var application = GetOrBuildConfidentialClientApplication(mergedOptions); - var builder = application + AcquireTokenForClientParameterBuilder builder = application .AcquireTokenForClient(new[] { scope }.Except(_scopesRequestedByMsal)) .WithSendX5C(mergedOptions.SendX5C); @@ -585,7 +604,6 @@ private bool IsInvalidClientCertificateOrSignedAssertionError(MsalServiceExcepti _applicationsByAuthorityClientId.TryAdd(GetApplicationKey(mergedOptions), application); } } - return application; } @@ -599,7 +617,7 @@ private IConfidentialClientApplication BuildConfidentialClientApplication(Merged try { - var builder = ConfidentialClientApplicationBuilder + ConfidentialClientApplicationBuilder builder = ConfidentialClientApplicationBuilder .CreateWithApplicationOptions(mergedOptions.ConfidentialClientApplicationOptions) .WithHttpClientFactory(_httpClientFactory) .WithLogging( @@ -848,8 +866,10 @@ private static void CheckAssertionsForInjectionAttempt(string assertion, string if (!assertion.IsNullOrEmpty() && assertion.Contains('&')) throw new ArgumentException(IDWebErrorMessage.InvalidAssertion, nameof(assertion)); if (!subAssertion.IsNullOrEmpty() && subAssertion.Contains('&')) throw new ArgumentException(IDWebErrorMessage.InvalidSubAssertion, nameof(subAssertion)); #else - if (!assertion.IsNullOrEmpty() && assertion.Contains('&', StringComparison.InvariantCultureIgnoreCase)) throw new ArgumentException(IDWebErrorMessage.InvalidAssertion, nameof(assertion)); - if (!subAssertion.IsNullOrEmpty() && subAssertion.Contains('&', StringComparison.InvariantCultureIgnoreCase)) throw new ArgumentException(IDWebErrorMessage.InvalidSubAssertion, nameof(subAssertion)); + if (!assertion.IsNullOrEmpty() && assertion.Contains('&', StringComparison.InvariantCultureIgnoreCase)) + throw new ArgumentException(IDWebErrorMessage.InvalidAssertion, nameof(assertion)); + if (!subAssertion.IsNullOrEmpty() && subAssertion.Contains('&', StringComparison.InvariantCultureIgnoreCase)) + throw new ArgumentException(IDWebErrorMessage.InvalidSubAssertion, nameof(subAssertion)); #endif } } diff --git a/src/Microsoft.Identity.Web.TokenAcquisition/TokenAcquisitionOptions.cs b/src/Microsoft.Identity.Web.TokenAcquisition/TokenAcquisitionOptions.cs index 60663a1f2..6c158af82 100644 --- a/src/Microsoft.Identity.Web.TokenAcquisition/TokenAcquisitionOptions.cs +++ b/src/Microsoft.Identity.Web.TokenAcquisition/TokenAcquisitionOptions.cs @@ -45,6 +45,7 @@ public class TokenAcquisitionOptions : AcquireTokenOptions PopClaim = PopClaim, CancellationToken = CancellationToken, LongRunningWebApiSessionKey = LongRunningWebApiSessionKey, + ManagedIdentity = ManagedIdentity, }; } } diff --git a/tests/E2E Tests/TokenAcquirerTests/OnlyOnAzureDevopsFactAttribute.cs b/tests/E2E Tests/TokenAcquirerTests/OnlyOnAzureDevopsFactAttribute.cs new file mode 100644 index 000000000..80f98ec0f --- /dev/null +++ b/tests/E2E Tests/TokenAcquirerTests/OnlyOnAzureDevopsFactAttribute.cs @@ -0,0 +1,20 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using Xunit; + +namespace TokenAcquirerTests +{ + public sealed class OnlyOnAzureDevopsFactAttribute : FactAttribute + { + public OnlyOnAzureDevopsFactAttribute() + { + if (IgnoreOnAzureDevopsFactAttribute.IsRunningOnAzureDevOps()) + { + return; + } + Skip = "Ignored when not on Azure DevOps"; + } + } +} diff --git a/tests/E2E Tests/TokenAcquirerTests/TokenAcquirer.cs b/tests/E2E Tests/TokenAcquirerTests/TokenAcquirer.cs index 2d7545cfc..f91813011 100644 --- a/tests/E2E Tests/TokenAcquirerTests/TokenAcquirer.cs +++ b/tests/E2E Tests/TokenAcquirerTests/TokenAcquirer.cs @@ -3,9 +3,11 @@ using System; using System.Linq; +using System.Net.Http; using System.Security.Cryptography; using System.Security.Cryptography.X509Certificates; using System.Threading.Tasks; +using Microsoft.AspNetCore.Http; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Options; using Microsoft.Graph; @@ -299,5 +301,46 @@ private static async Task CreateGraphClientAndAssert(TokenAcquirerFactory tokenA Assert.NotNull(result.AccessToken); } } + + public class AcquireTokenManagedIdentity + { + [OnlyOnAzureDevopsFact] + //[Fact] + public async Task AcquireTokenWithManagedIdentity_UserAssigned() + { + // Arrange + const string scope = "https://vault.azure.net/.default"; + const string baseUrl = "https://vault.azure.net"; + const string clientId = "9c5896db-a74a-4b1a-a259-74c5080a3a6a"; + TokenAcquirerFactory tokenAcquirerFactory = TokenAcquirerFactory.GetDefaultInstance(); + _ = tokenAcquirerFactory.Services; + IServiceProvider serviceProvider = tokenAcquirerFactory.Build(); + + // Act: Get the authorization header provider and add the options to tell it to use Managed Identity + IAuthorizationHeaderProvider? api = serviceProvider.GetRequiredService(); + Assert.NotNull(api); + string result = await api.CreateAuthorizationHeaderForAppAsync(scope, GetAuthHeaderOptions_ManagedId(baseUrl, clientId)); + + // Assert: Make sure we got a token + Assert.False(string.IsNullOrEmpty(result)); + } + + private static AuthorizationHeaderProviderOptions GetAuthHeaderOptions_ManagedId(string baseUrl, string? userAssignedClientId=null) + { + ManagedIdentityOptions managedIdentityOptions = new() + { + UserAssignedClientId = userAssignedClientId + }; + AcquireTokenOptions aquireTokenOptions = new() + { + ManagedIdentity = managedIdentityOptions + }; + return new AuthorizationHeaderProviderOptions() + { + BaseUrl = baseUrl, + AcquireTokenOptions = aquireTokenOptions + }; + } + } #endif //FROM_GITHUB_ACTION } diff --git a/tests/Microsoft.Identity.Web.Test.Common/TestConstants.cs b/tests/Microsoft.Identity.Web.Test.Common/TestConstants.cs index cd27d0565..e9d11a6ad 100644 --- a/tests/Microsoft.Identity.Web.Test.Common/TestConstants.cs +++ b/tests/Microsoft.Identity.Web.Test.Common/TestConstants.cs @@ -4,6 +4,7 @@ using System; using System.Collections.Generic; using System.IO; +using System.Reflection; namespace Microsoft.Identity.Web.Test.Common { @@ -186,8 +187,14 @@ public static class TestConstants public static readonly string s_todoListServicePath = Path.DirectorySeparatorChar.ToString() + "TodoListService"; - // TokenAcqusitionOptions + // TokenAcqusitionOptions and ManagedIdentityOptions public static Guid s_correlationId = new Guid("6347d33d-941a-4c35-9912-a9cf54fb1b3e"); + public const string UserAssignedManagedIdentityClientId = "3b57c42c-3201-4295-ae27-d6baec5b7027"; + public const string UserAssignedManagedIdentityResourceId = "/subscriptions/c1686c51-b717-4fe0-9af3-24a20a41fb0c/" + + "resourcegroups/MSAL_MSI/providers/Microsoft.ManagedIdentity/userAssignedIdentities/" + "MSAL_MSI_USERID"; + public const BindingFlags StaticPrivateFieldFlags = BindingFlags.GetField | BindingFlags.Static | BindingFlags.NonPublic; + public const BindingFlags InstancePrivateFieldFlags = BindingFlags.GetField | BindingFlags.Instance | BindingFlags.NonPublic; + public const BindingFlags StaticPrivateMethodFlags = BindingFlags.InvokeMethod | BindingFlags.Static | BindingFlags.NonPublic; // AadIssuerValidation public const string AadAuthority = "aadAuthority"; diff --git a/tests/Microsoft.Identity.Web.Test/TokenAcquisitionAuthorityTests.cs b/tests/Microsoft.Identity.Web.Test/TokenAcquisitionAuthorityTests.cs index c41c95a29..8190f65c4 100644 --- a/tests/Microsoft.Identity.Web.Test/TokenAcquisitionAuthorityTests.cs +++ b/tests/Microsoft.Identity.Web.Test/TokenAcquisitionAuthorityTests.cs @@ -1,10 +1,11 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +using System.Collections.Concurrent; using System.Collections.Generic; using System.Globalization; using System.Net.Http; -using Microsoft.AspNetCore.Authentication; +using System.Threading; using Microsoft.AspNetCore.Authentication.JwtBearer; using Microsoft.AspNetCore.Authentication.OpenIdConnect; using Microsoft.Extensions.Caching.Memory; @@ -13,11 +14,11 @@ using Microsoft.Extensions.Options; using Microsoft.Identity.Abstractions; using Microsoft.Identity.Client; -using Microsoft.Identity.Web.Test.Common; using Microsoft.Identity.Web.Test.Common.Mocks; using Microsoft.Identity.Web.Test.Common.TestHelpers; using Microsoft.Identity.Web.TokenCacheProviders.InMemory; using Xunit; +using TC = Microsoft.Identity.Web.Test.Common.TestConstants; namespace Microsoft.Identity.Web.Test { @@ -84,18 +85,18 @@ public void VerifyCorrectSchemeTests(string scheme) } [Theory] - [InlineData(TestConstants.B2CInstance)] - [InlineData(TestConstants.B2CLoginMicrosoft)] - [InlineData(TestConstants.B2CInstance, true)] - [InlineData(TestConstants.B2CLoginMicrosoft, true)] + [InlineData(TC.B2CInstance)] + [InlineData(TC.B2CLoginMicrosoft)] + [InlineData(TC.B2CInstance, true)] + [InlineData(TC.B2CLoginMicrosoft, true)] public void VerifyCorrectAuthorityUsedInTokenAcquisition_B2CAuthorityTests( string authorityInstance, bool withTfp = false) { _microsoftIdentityOptionsMonitor = new TestOptionsMonitor(new MicrosoftIdentityOptions { - SignUpSignInPolicyId = TestConstants.B2CSignUpSignInUserFlow, - Domain = TestConstants.B2CTenant, + SignUpSignInPolicyId = TC.B2CSignUpSignInUserFlow, + Domain = TC.B2CTenant, }); if (withTfp) @@ -103,23 +104,21 @@ public void VerifyCorrectAuthorityUsedInTokenAcquisition_B2CAuthorityTests( _applicationOptionsMonitor = new TestOptionsMonitor(new ConfidentialClientApplicationOptions { Instance = authorityInstance + "/tfp/", - ClientId = TestConstants.ConfidentialClientId, - ClientSecret = TestConstants.ClientSecret, + ClientId = TC.ConfidentialClientId, + ClientSecret = TC.ClientSecret, }); - BuildTheRequiredServices(); } else { _applicationOptionsMonitor = new TestOptionsMonitor(new ConfidentialClientApplicationOptions { Instance = authorityInstance, - ClientId = TestConstants.ConfidentialClientId, - ClientSecret = TestConstants.ClientSecret, + ClientId = TC.ConfidentialClientId, + ClientSecret = TC.ClientSecret, }); - - BuildTheRequiredServices(); } + BuildTheRequiredServices(); MergedOptions mergedOptions = _provider.GetRequiredService().Get(OpenIdConnectDefaults.AuthenticationScheme); MergedOptions.UpdateMergedOptionsFromMicrosoftIdentityOptions(_microsoftIdentityOptionsMonitor.Get(OpenIdConnectDefaults.AuthenticationScheme), mergedOptions); MergedOptions.UpdateMergedOptionsFromConfidentialClientApplicationOptions(_applicationOptionsMonitor.Get(OpenIdConnectDefaults.AuthenticationScheme), mergedOptions); @@ -132,8 +131,8 @@ public void VerifyCorrectAuthorityUsedInTokenAcquisition_B2CAuthorityTests( CultureInfo.InvariantCulture, "{0}/tfp/{1}/{2}/", authorityInstance, - TestConstants.B2CTenant, - TestConstants.B2CSignUpSignInUserFlow); + TC.B2CTenant, + TC.B2CSignUpSignInUserFlow); Assert.Equal(expectedAuthority, app.Authority); } @@ -146,16 +145,16 @@ public void VerifyCorrectRedirectUriAsync( { _microsoftIdentityOptionsMonitor = new TestOptionsMonitor(new MicrosoftIdentityOptions { - Authority = TestConstants.AuthorityCommonTenant, - ClientId = TestConstants.ConfidentialClientId, + Authority = TC.AuthorityCommonTenant, + ClientId = TC.ConfidentialClientId, CallbackPath = string.Empty, }); _applicationOptionsMonitor = new TestOptionsMonitor(new ConfidentialClientApplicationOptions { - Instance = TestConstants.AadInstance, + Instance = TC.AadInstance, RedirectUri = redirectUri, - ClientSecret = TestConstants.ClientSecret, + ClientSecret = TC.ClientSecret, }); BuildTheRequiredServices(); @@ -185,15 +184,15 @@ public void VerifyCorrectBooleansAsync( { _microsoftIdentityOptionsMonitor = new TestOptionsMonitor(new MicrosoftIdentityOptions { - Authority = TestConstants.AuthorityCommonTenant, - ClientId = TestConstants.ConfidentialClientId, + Authority = TC.AuthorityCommonTenant, + ClientId = TC.ConfidentialClientId, SendX5C = sendx5c, }); _applicationOptionsMonitor = new TestOptionsMonitor(new ConfidentialClientApplicationOptions { - Instance = TestConstants.AadInstance, - ClientSecret = TestConstants.ClientSecret, + Instance = TC.AadInstance, + ClientSecret = TC.ClientSecret, }); BuildTheRequiredServices(); @@ -219,18 +218,18 @@ public void TestParseAuthorityIfNecessary() // Arrange MergedOptions mergedOptions = new() { - Authority = TestConstants.AuthorityWithTenantSpecified, - TenantId = TestConstants.TenantIdAsGuid, - Instance = TestConstants.AadInstance + Authority = TC.AuthorityWithTenantSpecified, + TenantId = TC.TenantIdAsGuid, + Instance = TC.AadInstance }; // Act MergedOptions.ParseAuthorityIfNecessary(mergedOptions); // Assert - Assert.Equal(TestConstants.AuthorityWithTenantSpecified, mergedOptions.Authority); - Assert.Equal(TestConstants.AadInstance, mergedOptions.Instance); - Assert.Equal(TestConstants.TenantIdAsGuid, mergedOptions.TenantId); + Assert.Equal(TC.AuthorityWithTenantSpecified, mergedOptions.Authority); + Assert.Equal(TC.AadInstance, mergedOptions.Instance); + Assert.Equal(TC.TenantIdAsGuid, mergedOptions.TenantId); } [Fact] @@ -309,5 +308,109 @@ public void MergeExtraQueryParameters_MergedOptionsNull_Test() // Assert Assert.Null(mergedDict); } + + [Theory] + [InlineData("https://localhost:1234")] + [InlineData("")] + [InlineData(null)] + public void ManagedIdCacheKey_Test(string? clientId) + { + // Arrange + string defaultKey = "SYSTEM"; + ManagedIdentityOptions managedIdentityOptions = new() + { + UserAssignedClientId = clientId + }; + + // Act + string key = TokenAcquisition.GetCacheKeyForManagedId(managedIdentityOptions); + + // Assert + if (string.IsNullOrEmpty(clientId)) + { + Assert.Equal(defaultKey, key); + } + else + { + Assert.Equal(clientId, key); + } + } + + [Theory] + [InlineData("https://localhost:1234")] + [InlineData("")] + [InlineData(null)] + public async void GetOrBuildManagedIdentity_TestAsync(string? clientId) + { + // Arrange + ManagedIdentityOptions managedIdentityOptions = new() + { + UserAssignedClientId = clientId + }; + MergedOptions mergedOptions = new(); + BuildTheRequiredServices(); + InitializeTokenAcquisitionObjects(); + + // Act + var app1 = + await _tokenAcquisition.GetOrBuildManagedIdentityApplication(mergedOptions, managedIdentityOptions); + var app2 = + await _tokenAcquisition.GetOrBuildManagedIdentityApplication(mergedOptions, managedIdentityOptions); + + // Assert + Assert.Same(app1, app2); + } + + [Theory] + [InlineData("https://localhost:1234")] + [InlineData(null)] + public async void GetOrBuildManagedIdentity_TestConcurrencyAsync(string? clientId) + { + // Arrange + ThreadPool.GetMaxThreads(out int maxThreads, out int _); + ConcurrentBag appsBag = []; + CountdownEvent taskStartGate = new(maxThreads); + CountdownEvent threadsDone = new(maxThreads); + ManagedIdentityOptions managedIdentityOptions = new() + { + UserAssignedClientId = clientId + }; + MergedOptions mergedOptions = new(); + BuildTheRequiredServices(); + InitializeTokenAcquisitionObjects(); + + // Act + for (int i = 0; i < maxThreads; i++) + { + Thread thread = new(async () => + { + try + { + // Signal that the thread is ready to start and wait for the other threads to be ready. + taskStartGate.Signal(); + taskStartGate.Wait(); + + // Add the application to the bag + appsBag.Add(await _tokenAcquisition.GetOrBuildManagedIdentityApplication(mergedOptions, managedIdentityOptions)); + } + finally + { + // No matter what happens, signal that the thread is done so the test doesn't get stuck. + threadsDone.Signal(); + } + } + ); + thread.Start(); + } + threadsDone.Wait(); + var testApp = await _tokenAcquisition.GetOrBuildManagedIdentityApplication(mergedOptions, managedIdentityOptions); + + // Assert + Assert.True(appsBag.Count == maxThreads, "Not all threads put objects in the concurrent bag"); + foreach (IManagedIdentityApplication app in appsBag) + { + Assert.Same(testApp, app); + } + } } }