diff --git a/Directory.Build.props b/Directory.Build.props
index 152e03d2f..d5d5e5653 100644
--- a/Directory.Build.props
+++ b/Directory.Build.props
@@ -80,7 +80,7 @@
4.34.0
4.50.0-preview
3.1.3
- 4.0.0
+ 4.1.0
4.7.2
diff --git a/src/Microsoft.Identity.Web.Certificate/DefaultCredentialsLoader.cs b/src/Microsoft.Identity.Web.Certificate/DefaultCredentialsLoader.cs
index e94b14ec4..118beb256 100644
--- a/src/Microsoft.Identity.Web.Certificate/DefaultCredentialsLoader.cs
+++ b/src/Microsoft.Identity.Web.Certificate/DefaultCredentialsLoader.cs
@@ -1,8 +1,9 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
-using System;
+using System.Collections.Concurrent;
using System.Collections.Generic;
+using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
using Microsoft.Identity.Abstractions;
@@ -15,6 +16,7 @@ namespace Microsoft.Identity.Web
public class DefaultCredentialsLoader : ICredentialsLoader
{
ILogger? _logger;
+ private readonly ConcurrentDictionary _loadingSemaphores = new ConcurrentDictionary();
///
/// Constructor with a logger
@@ -56,9 +58,24 @@ public async Task LoadCredentialsIfNeededAsync(CredentialDescription credentialD
if (credentialDescription.CachedValue == null)
{
- if (CredentialSourceLoaders.TryGetValue(credentialDescription.SourceType, out ICredentialSourceLoader? loader))
+ // Get or create a semaphore for this credentialDescription
+ var semaphore = _loadingSemaphores.GetOrAdd(credentialDescription.Id, (v) => new SemaphoreSlim(1));
+
+ // Wait to acquire the semaphore
+ await semaphore.WaitAsync();
+
+ try
+ {
+ if (credentialDescription.CachedValue == null)
+ {
+ if (CredentialSourceLoaders.TryGetValue(credentialDescription.SourceType, out ICredentialSourceLoader? loader))
+ await loader.LoadIfNeededAsync(credentialDescription, parameters);
+ }
+ }
+ finally
{
- await loader.LoadIfNeededAsync(credentialDescription, parameters);
+ // Release the semaphore
+ semaphore.Release();
}
}
}
diff --git a/tests/IntegrationTests/TokenAcquirerTests/TokenAcquirer.cs b/tests/IntegrationTests/TokenAcquirerTests/TokenAcquirer.cs
index 0b9b29995..5ab4c4a2a 100644
--- a/tests/IntegrationTests/TokenAcquirerTests/TokenAcquirer.cs
+++ b/tests/IntegrationTests/TokenAcquirerTests/TokenAcquirer.cs
@@ -14,6 +14,7 @@
using Microsoft.Identity.Web.TokenCacheProviders.InMemory;
using Microsoft.IdentityModel.Tokens;
using Xunit;
+using TaskStatus = System.Threading.Tasks.TaskStatus;
namespace TokenAcquirerTests
{
@@ -126,6 +127,46 @@ public async Task AcquireToken_WithFactoryAndAuthorityClientIdCert_ClientCredent
Assert.False(string.IsNullOrEmpty(result.AccessToken));
}
+ [IgnoreOnAzureDevopsFact]
+ //[Fact]
+ public async Task LoadCredentialsIfNeededAsync_MultipleThreads_WaitsForSemaphore()
+ {
+ TokenAcquirerFactory tokenAcquirerFactory = TokenAcquirerFactory.GetDefaultInstance();
+ IServiceCollection services = tokenAcquirerFactory.Services;
+
+ services.Configure(s_optionName, option =>
+ {
+ option.Instance = "https://login.microsoftonline.com/";
+ option.TenantId = "msidentitysamplestesting.onmicrosoft.com";
+ option.ClientId = "6af093f3-b445-4b7a-beae-046864468ad6";
+ option.ClientCredentials = s_clientCredentials;
+ });
+
+ services.AddInMemoryTokenCaches();
+ var serviceProvider = tokenAcquirerFactory.Build();
+ var options = serviceProvider.GetRequiredService>().Get(s_optionName);
+ var credentialsLoader = serviceProvider.GetRequiredService();
+
+ var task1 = Task.Run(async () =>
+ {
+ await credentialsLoader.LoadCredentialsIfNeededAsync(options.ClientCredentials!.First());
+ });
+
+ var task2 = Task.Run(async () =>
+ {
+ await credentialsLoader.LoadCredentialsIfNeededAsync(options.ClientCredentials!.First());
+ });
+
+ // Run task1 and task2 concurrently
+ await Task.WhenAll(task1, task2);
+
+ var cert = options.ClientCredentials!.First().Certificate;
+
+ Assert.NotNull(cert);
+ Assert.Equal(TaskStatus.RanToCompletion, task1.Status);
+ Assert.Equal(TaskStatus.RanToCompletion, task2.Status);
+ }
+
[IgnoreOnAzureDevopsFact]
//[Fact]
public async Task AcquireTokenWithPop_ClientCredentialsAsync()