Skip to content

Commit

Permalink
DefaultTokenAcquirerFactoryImplementation use ConcurrentDictionary (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
sammurrayms authored Apr 16, 2024
1 parent 0c3b9d5 commit 3083a3c
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 46 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
// Licensed under the MIT License.

using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Options;
using Microsoft.Identity.Abstractions;

namespace Microsoft.Identity.Web
Expand All @@ -17,7 +17,7 @@ public DefaultTokenAcquirerFactoryImplementation(IServiceProvider serviceProvide
}
private IServiceProvider ServiceProvider { get; set; }

readonly Dictionary<string, ITokenAcquirer> _authSchemes = new Dictionary<string, ITokenAcquirer>();
readonly ConcurrentDictionary<string, ITokenAcquirer> _authSchemes = new();

/// <inheritdoc/>
public ITokenAcquirer GetTokenAcquirer(
Expand All @@ -26,40 +26,38 @@ public ITokenAcquirer GetTokenAcquirer(
IEnumerable<CredentialDescription> clientCredentials,
string? region = null)
{
CheckServiceProviderNotNull();
string key = GetKey(authority, clientId, region);

ITokenAcquirer? tokenAcquirer;
// Compute the key
string key = GetKey(authority, clientId);
if (!_authSchemes.TryGetValue(key, out tokenAcquirer))
{
MicrosoftIdentityApplicationOptions MicrosoftIdentityApplicationOptions = new MicrosoftIdentityApplicationOptions
{
ClientId = clientId,
Authority = authority,
ClientCredentials = clientCredentials,
SendX5C = true
};
if (region != null)
// GetOrAdd ONLY synchronizes the outcome. So, the factory might still be invoked multiple times.
// Therefore, all side-effects within this block must remain idempotent.
return _authSchemes.GetOrAdd(key, (key) =>
{
MicrosoftIdentityApplicationOptions.AzureRegion = region;
}
MicrosoftIdentityApplicationOptions MicrosoftIdentityApplicationOptions = new()
{
ClientId = clientId,
Authority = authority,
ClientCredentials = clientCredentials,
SendX5C = true
};

var optionsMonitor = ServiceProvider.GetRequiredService<IMergedOptionsStore>();
var mergedOptions = optionsMonitor.Get(key);
MergedOptions.UpdateMergedOptionsFromMicrosoftIdentityApplicationOptions(MicrosoftIdentityApplicationOptions, mergedOptions);
tokenAcquirer = GetTokenAcquirer(key);
}
return tokenAcquirer;
if (region != null)
{
MicrosoftIdentityApplicationOptions.AzureRegion = region;
}

IMergedOptionsStore optionsMonitor = ServiceProvider.GetRequiredService<IMergedOptionsStore>();
MergedOptions mergedOptions = optionsMonitor.Get(key);
MergedOptions.UpdateMergedOptionsFromMicrosoftIdentityApplicationOptions(MicrosoftIdentityApplicationOptions, mergedOptions);

return MakeTokenAcquirer(key);
});
}

/// <inheritdoc/>
public ITokenAcquirer GetTokenAcquirer(IdentityApplicationOptions IdentityApplicationOptions)
{
_ = Throws.IfNull(IdentityApplicationOptions);

CheckServiceProviderNotNull();

// Compute the Azure region if the option is a MicrosoftIdentityApplicationOptions.
MicrosoftIdentityApplicationOptions? MicrosoftIdentityApplicationOptions = IdentityApplicationOptions as MicrosoftIdentityApplicationOptions;
if (MicrosoftIdentityApplicationOptions == null)
Expand All @@ -77,33 +75,36 @@ public ITokenAcquirer GetTokenAcquirer(IdentityApplicationOptions IdentityApplic
};
}

// Compute the key
ITokenAcquirer? tokenAcquirer;
string key = GetKey(IdentityApplicationOptions.Authority, IdentityApplicationOptions.ClientId);
if (!_authSchemes.TryGetValue(key, out tokenAcquirer))
string key = GetKey(IdentityApplicationOptions.Authority, IdentityApplicationOptions.ClientId, MicrosoftIdentityApplicationOptions.AzureRegion);

return _authSchemes.GetOrAdd(key, (key) =>
{
var optionsMonitor = ServiceProvider!.GetRequiredService<IMergedOptionsStore>();
var mergedOptions = optionsMonitor.Get(key);
IMergedOptionsStore optionsMonitor = ServiceProvider!.GetRequiredService<IMergedOptionsStore>();
MergedOptions mergedOptions = optionsMonitor.Get(key);


MergedOptions.UpdateMergedOptionsFromMicrosoftIdentityApplicationOptions(MicrosoftIdentityApplicationOptions, mergedOptions);
tokenAcquirer = GetTokenAcquirer(key);
}
return tokenAcquirer;
return MakeTokenAcquirer(key);
});
}

/// <inheritdoc/>
public ITokenAcquirer GetTokenAcquirer(string authenticationScheme = "")
{
return _authSchemes.GetOrAdd(authenticationScheme, (key) =>
{
return MakeTokenAcquirer(authenticationScheme);
});
}

private ITokenAcquirer MakeTokenAcquirer(string authenticationScheme = "")
{
CheckServiceProviderNotNull();

ITokenAcquirer? acquirer;
if (!_authSchemes.TryGetValue(authenticationScheme, out acquirer))
{
var tokenAcquisition = ServiceProvider!.GetRequiredService<ITokenAcquisition>();
acquirer = new TokenAcquirer(tokenAcquisition, authenticationScheme);
_authSchemes.Add(authenticationScheme, acquirer);
}
return acquirer;
ITokenAcquisition tokenAcquisition = ServiceProvider!.GetRequiredService<ITokenAcquisition>();
return new TokenAcquirer(tokenAcquisition, authenticationScheme);
}

private void CheckServiceProviderNotNull()
{
if (ServiceProvider == null)
Expand All @@ -112,10 +113,9 @@ private void CheckServiceProviderNotNull()
}
}


private static string GetKey(string? authority, string? clientId)
public static string GetKey(string? authority, string? clientId, string? region)
{
return $"{authority}{clientId}";
return $"{authority}{clientId}{region}";
}
}
}
77 changes: 77 additions & 0 deletions tests/E2E Tests/TokenAcquirerTests/TokenAcquirer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
// Licensed under the MIT License.

using System;
using System.Collections.Concurrent;
using System.Linq;
using System.Net.Http;
using System.Security.Cryptography;
using System.Security.Cryptography.X509Certificates;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.DependencyInjection;
Expand Down Expand Up @@ -46,6 +48,81 @@ public void TokenAcquirerFactoryDoesNotUseAspNetCoreHost()
Assert.Equal("Microsoft.Identity.Web.Hosts.DefaultTokenAcquisitionHost", service.GetType().FullName);
}

[Fact]
public void DefaultTokenAcquirer_GetKeyHandlesNulls()
{
var res = DefaultTokenAcquirerFactoryImplementation.GetKey("1", "2", "3");
Assert.Equal("123", res);

var no_region = DefaultTokenAcquirerFactoryImplementation.GetKey("1", "2", null);
Assert.Equal("12", no_region);
}

[Fact]
public void AcquireToken_WithMultipleRegions()
{
var tokenAcquirerFactory = TokenAcquirerFactory.GetDefaultInstance();
_ = tokenAcquirerFactory.Build();

ITokenAcquirer tokenAcquirerA = tokenAcquirerFactory.GetTokenAcquirer(
authority: "https://login.microsoftonline.com/msidentitysamplestesting.onmicrosoft.com",
clientId: "6af093f3-b445-4b7a-beae-046864468ad6",
clientCredentials: s_clientCredentials,
"US");

ITokenAcquirer tokenAcquirerB = tokenAcquirerFactory.GetTokenAcquirer(
authority: "https://login.microsoftonline.com/msidentitysamplestesting.onmicrosoft.com",
clientId: "6af093f3-b445-4b7a-beae-046864468ad6",
clientCredentials: s_clientCredentials,
"US");

ITokenAcquirer tokenAcquirerC = tokenAcquirerFactory.GetTokenAcquirer(
authority: "https://login.microsoftonline.com/msidentitysamplestesting.onmicrosoft.com",
clientId: "6af093f3-b445-4b7a-beae-046864468ad6",
clientCredentials: s_clientCredentials,
"EU");

Assert.Equal(tokenAcquirerA, tokenAcquirerB);
Assert.NotEqual(tokenAcquirerA, tokenAcquirerC);
}

[Fact]
public void AcquireToken_SafeFromMultipleThreads()
{
var tokenAcquirerFactory = TokenAcquirerFactory.GetDefaultInstance();
_ = tokenAcquirerFactory.Build();

var count = new ConcurrentDictionary<ITokenAcquirer, bool>();

var action = () =>
{
for (int i = 0; i < 1000; i++)
{
ITokenAcquirer res = tokenAcquirerFactory.GetTokenAcquirer(
authority: "https://login.microsoftonline.com/msidentitysamplestesting.onmicrosoft.com",
clientId: "6af093f3-b445-4b7a-beae-046864468ad6",
clientCredentials: s_clientCredentials,
"" + (i%11));

count.TryAdd(res, true);
}
};

Thread[] threads = new Thread[16];
for (int i = 0; i < 16; i++)
{
threads[i] = new Thread(() => action());
threads[i].Start();
}

foreach (Thread thread in threads)
{
thread.Join();
}

Assert.Equal(11, count.Count);
}

[IgnoreOnAzureDevopsFact]
//[Theory]
//[InlineData(false)]
Expand Down

0 comments on commit 3083a3c

Please sign in to comment.