diff --git a/src/Microsoft.Identity.Web.TokenAcquisition/DefaultTokenAcquirerFactoryImplementation.cs b/src/Microsoft.Identity.Web.TokenAcquisition/DefaultTokenAcquirerFactoryImplementation.cs index 35c870d09..c4c76af60 100644 --- a/src/Microsoft.Identity.Web.TokenAcquisition/DefaultTokenAcquirerFactoryImplementation.cs +++ b/src/Microsoft.Identity.Web.TokenAcquisition/DefaultTokenAcquirerFactoryImplementation.cs @@ -2,9 +2,9 @@ // Licensed under the MIT License. using System; +using System.Collections.Concurrent; using System.Collections.Generic; using Microsoft.Extensions.DependencyInjection; -using Microsoft.Extensions.Options; using Microsoft.Identity.Abstractions; namespace Microsoft.Identity.Web @@ -17,7 +17,7 @@ public DefaultTokenAcquirerFactoryImplementation(IServiceProvider serviceProvide } private IServiceProvider ServiceProvider { get; set; } - readonly Dictionary _authSchemes = new Dictionary(); + readonly ConcurrentDictionary _authSchemes = new(); /// public ITokenAcquirer GetTokenAcquirer( @@ -26,31 +26,31 @@ public ITokenAcquirer GetTokenAcquirer( IEnumerable clientCredentials, string? region = null) { - CheckServiceProviderNotNull(); + string key = GetKey(authority, clientId, region); - ITokenAcquirer? tokenAcquirer; - // Compute the key - string key = GetKey(authority, clientId); - if (!_authSchemes.TryGetValue(key, out tokenAcquirer)) - { - MicrosoftIdentityApplicationOptions MicrosoftIdentityApplicationOptions = new MicrosoftIdentityApplicationOptions - { - ClientId = clientId, - Authority = authority, - ClientCredentials = clientCredentials, - SendX5C = true - }; - if (region != null) + // GetOrAdd ONLY synchronizes the outcome. So, the factory might still be invoked multiple times. + // Therefore, all side-effects within this block must remain idempotent. + return _authSchemes.GetOrAdd(key, (key) => { - MicrosoftIdentityApplicationOptions.AzureRegion = region; - } + MicrosoftIdentityApplicationOptions MicrosoftIdentityApplicationOptions = new() + { + ClientId = clientId, + Authority = authority, + ClientCredentials = clientCredentials, + SendX5C = true + }; - var optionsMonitor = ServiceProvider.GetRequiredService(); - var mergedOptions = optionsMonitor.Get(key); - MergedOptions.UpdateMergedOptionsFromMicrosoftIdentityApplicationOptions(MicrosoftIdentityApplicationOptions, mergedOptions); - tokenAcquirer = GetTokenAcquirer(key); - } - return tokenAcquirer; + if (region != null) + { + MicrosoftIdentityApplicationOptions.AzureRegion = region; + } + + IMergedOptionsStore optionsMonitor = ServiceProvider.GetRequiredService(); + MergedOptions mergedOptions = optionsMonitor.Get(key); + MergedOptions.UpdateMergedOptionsFromMicrosoftIdentityApplicationOptions(MicrosoftIdentityApplicationOptions, mergedOptions); + + return MakeTokenAcquirer(key); + }); } /// @@ -58,8 +58,6 @@ public ITokenAcquirer GetTokenAcquirer(IdentityApplicationOptions IdentityApplic { _ = Throws.IfNull(IdentityApplicationOptions); - CheckServiceProviderNotNull(); - // Compute the Azure region if the option is a MicrosoftIdentityApplicationOptions. MicrosoftIdentityApplicationOptions? MicrosoftIdentityApplicationOptions = IdentityApplicationOptions as MicrosoftIdentityApplicationOptions; if (MicrosoftIdentityApplicationOptions == null) @@ -77,33 +75,36 @@ public ITokenAcquirer GetTokenAcquirer(IdentityApplicationOptions IdentityApplic }; } - // Compute the key - ITokenAcquirer? tokenAcquirer; - string key = GetKey(IdentityApplicationOptions.Authority, IdentityApplicationOptions.ClientId); - if (!_authSchemes.TryGetValue(key, out tokenAcquirer)) + string key = GetKey(IdentityApplicationOptions.Authority, IdentityApplicationOptions.ClientId, MicrosoftIdentityApplicationOptions.AzureRegion); + + return _authSchemes.GetOrAdd(key, (key) => { - var optionsMonitor = ServiceProvider!.GetRequiredService(); - var mergedOptions = optionsMonitor.Get(key); + IMergedOptionsStore optionsMonitor = ServiceProvider!.GetRequiredService(); + MergedOptions mergedOptions = optionsMonitor.Get(key); + + MergedOptions.UpdateMergedOptionsFromMicrosoftIdentityApplicationOptions(MicrosoftIdentityApplicationOptions, mergedOptions); - tokenAcquirer = GetTokenAcquirer(key); - } - return tokenAcquirer; + return MakeTokenAcquirer(key); + }); } /// public ITokenAcquirer GetTokenAcquirer(string authenticationScheme = "") + { + return _authSchemes.GetOrAdd(authenticationScheme, (key) => + { + return MakeTokenAcquirer(authenticationScheme); + }); + } + + private ITokenAcquirer MakeTokenAcquirer(string authenticationScheme = "") { CheckServiceProviderNotNull(); - ITokenAcquirer? acquirer; - if (!_authSchemes.TryGetValue(authenticationScheme, out acquirer)) - { - var tokenAcquisition = ServiceProvider!.GetRequiredService(); - acquirer = new TokenAcquirer(tokenAcquisition, authenticationScheme); - _authSchemes.Add(authenticationScheme, acquirer); - } - return acquirer; + ITokenAcquisition tokenAcquisition = ServiceProvider!.GetRequiredService(); + return new TokenAcquirer(tokenAcquisition, authenticationScheme); } + private void CheckServiceProviderNotNull() { if (ServiceProvider == null) @@ -112,10 +113,9 @@ private void CheckServiceProviderNotNull() } } - - private static string GetKey(string? authority, string? clientId) + public static string GetKey(string? authority, string? clientId, string? region) { - return $"{authority}{clientId}"; + return $"{authority}{clientId}{region}"; } } } diff --git a/tests/E2E Tests/TokenAcquirerTests/TokenAcquirer.cs b/tests/E2E Tests/TokenAcquirerTests/TokenAcquirer.cs index 15f66e00e..e10c2d805 100644 --- a/tests/E2E Tests/TokenAcquirerTests/TokenAcquirer.cs +++ b/tests/E2E Tests/TokenAcquirerTests/TokenAcquirer.cs @@ -2,10 +2,12 @@ // Licensed under the MIT License. using System; +using System.Collections.Concurrent; using System.Linq; using System.Net.Http; using System.Security.Cryptography; using System.Security.Cryptography.X509Certificates; +using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Http; using Microsoft.Extensions.DependencyInjection; @@ -46,6 +48,81 @@ public void TokenAcquirerFactoryDoesNotUseAspNetCoreHost() Assert.Equal("Microsoft.Identity.Web.Hosts.DefaultTokenAcquisitionHost", service.GetType().FullName); } + [Fact] + public void DefaultTokenAcquirer_GetKeyHandlesNulls() + { + var res = DefaultTokenAcquirerFactoryImplementation.GetKey("1", "2", "3"); + Assert.Equal("123", res); + + var no_region = DefaultTokenAcquirerFactoryImplementation.GetKey("1", "2", null); + Assert.Equal("12", no_region); + } + + [Fact] + public void AcquireToken_WithMultipleRegions() + { + var tokenAcquirerFactory = TokenAcquirerFactory.GetDefaultInstance(); + _ = tokenAcquirerFactory.Build(); + + ITokenAcquirer tokenAcquirerA = tokenAcquirerFactory.GetTokenAcquirer( + authority: "https://login.microsoftonline.com/msidentitysamplestesting.onmicrosoft.com", + clientId: "6af093f3-b445-4b7a-beae-046864468ad6", + clientCredentials: s_clientCredentials, + "US"); + + ITokenAcquirer tokenAcquirerB = tokenAcquirerFactory.GetTokenAcquirer( + authority: "https://login.microsoftonline.com/msidentitysamplestesting.onmicrosoft.com", + clientId: "6af093f3-b445-4b7a-beae-046864468ad6", + clientCredentials: s_clientCredentials, + "US"); + + ITokenAcquirer tokenAcquirerC = tokenAcquirerFactory.GetTokenAcquirer( + authority: "https://login.microsoftonline.com/msidentitysamplestesting.onmicrosoft.com", + clientId: "6af093f3-b445-4b7a-beae-046864468ad6", + clientCredentials: s_clientCredentials, + "EU"); + + Assert.Equal(tokenAcquirerA, tokenAcquirerB); + Assert.NotEqual(tokenAcquirerA, tokenAcquirerC); + } + + [Fact] + public void AcquireToken_SafeFromMultipleThreads() + { + var tokenAcquirerFactory = TokenAcquirerFactory.GetDefaultInstance(); + _ = tokenAcquirerFactory.Build(); + + var count = new ConcurrentDictionary(); + + var action = () => + { + for (int i = 0; i < 1000; i++) + { + ITokenAcquirer res = tokenAcquirerFactory.GetTokenAcquirer( + authority: "https://login.microsoftonline.com/msidentitysamplestesting.onmicrosoft.com", + clientId: "6af093f3-b445-4b7a-beae-046864468ad6", + clientCredentials: s_clientCredentials, + "" + (i%11)); + + count.TryAdd(res, true); + } + }; + + Thread[] threads = new Thread[16]; + for (int i = 0; i < 16; i++) + { + threads[i] = new Thread(() => action()); + threads[i].Start(); + } + + foreach (Thread thread in threads) + { + thread.Join(); + } + + Assert.Equal(11, count.Count); + } + [IgnoreOnAzureDevopsFact] //[Theory] //[InlineData(false)]