diff --git a/Directory.Build.props b/Directory.Build.props index 152e03d2f..d5d5e5653 100644 --- a/Directory.Build.props +++ b/Directory.Build.props @@ -80,7 +80,7 @@ 4.34.0 4.50.0-preview 3.1.3 - 4.0.0 + 4.1.0 4.7.2 diff --git a/src/Microsoft.Identity.Web.Certificate/DefaultCredentialsLoader.cs b/src/Microsoft.Identity.Web.Certificate/DefaultCredentialsLoader.cs index e94b14ec4..118beb256 100644 --- a/src/Microsoft.Identity.Web.Certificate/DefaultCredentialsLoader.cs +++ b/src/Microsoft.Identity.Web.Certificate/DefaultCredentialsLoader.cs @@ -1,8 +1,9 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -using System; +using System.Collections.Concurrent; using System.Collections.Generic; +using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.Logging; using Microsoft.Identity.Abstractions; @@ -15,6 +16,7 @@ namespace Microsoft.Identity.Web public class DefaultCredentialsLoader : ICredentialsLoader { ILogger? _logger; + private readonly ConcurrentDictionary _loadingSemaphores = new ConcurrentDictionary(); /// /// Constructor with a logger @@ -56,9 +58,24 @@ public async Task LoadCredentialsIfNeededAsync(CredentialDescription credentialD if (credentialDescription.CachedValue == null) { - if (CredentialSourceLoaders.TryGetValue(credentialDescription.SourceType, out ICredentialSourceLoader? loader)) + // Get or create a semaphore for this credentialDescription + var semaphore = _loadingSemaphores.GetOrAdd(credentialDescription.Id, (v) => new SemaphoreSlim(1)); + + // Wait to acquire the semaphore + await semaphore.WaitAsync(); + + try + { + if (credentialDescription.CachedValue == null) + { + if (CredentialSourceLoaders.TryGetValue(credentialDescription.SourceType, out ICredentialSourceLoader? loader)) + await loader.LoadIfNeededAsync(credentialDescription, parameters); + } + } + finally { - await loader.LoadIfNeededAsync(credentialDescription, parameters); + // Release the semaphore + semaphore.Release(); } } } diff --git a/tests/IntegrationTests/TokenAcquirerTests/TokenAcquirer.cs b/tests/IntegrationTests/TokenAcquirerTests/TokenAcquirer.cs index 0b9b29995..5ab4c4a2a 100644 --- a/tests/IntegrationTests/TokenAcquirerTests/TokenAcquirer.cs +++ b/tests/IntegrationTests/TokenAcquirerTests/TokenAcquirer.cs @@ -14,6 +14,7 @@ using Microsoft.Identity.Web.TokenCacheProviders.InMemory; using Microsoft.IdentityModel.Tokens; using Xunit; +using TaskStatus = System.Threading.Tasks.TaskStatus; namespace TokenAcquirerTests { @@ -126,6 +127,46 @@ public async Task AcquireToken_WithFactoryAndAuthorityClientIdCert_ClientCredent Assert.False(string.IsNullOrEmpty(result.AccessToken)); } + [IgnoreOnAzureDevopsFact] + //[Fact] + public async Task LoadCredentialsIfNeededAsync_MultipleThreads_WaitsForSemaphore() + { + TokenAcquirerFactory tokenAcquirerFactory = TokenAcquirerFactory.GetDefaultInstance(); + IServiceCollection services = tokenAcquirerFactory.Services; + + services.Configure(s_optionName, option => + { + option.Instance = "https://login.microsoftonline.com/"; + option.TenantId = "msidentitysamplestesting.onmicrosoft.com"; + option.ClientId = "6af093f3-b445-4b7a-beae-046864468ad6"; + option.ClientCredentials = s_clientCredentials; + }); + + services.AddInMemoryTokenCaches(); + var serviceProvider = tokenAcquirerFactory.Build(); + var options = serviceProvider.GetRequiredService>().Get(s_optionName); + var credentialsLoader = serviceProvider.GetRequiredService(); + + var task1 = Task.Run(async () => + { + await credentialsLoader.LoadCredentialsIfNeededAsync(options.ClientCredentials!.First()); + }); + + var task2 = Task.Run(async () => + { + await credentialsLoader.LoadCredentialsIfNeededAsync(options.ClientCredentials!.First()); + }); + + // Run task1 and task2 concurrently + await Task.WhenAll(task1, task2); + + var cert = options.ClientCredentials!.First().Certificate; + + Assert.NotNull(cert); + Assert.Equal(TaskStatus.RanToCompletion, task1.Status); + Assert.Equal(TaskStatus.RanToCompletion, task2.Status); + } + [IgnoreOnAzureDevopsFact] //[Fact] public async Task AcquireTokenWithPop_ClientCredentialsAsync()