Skip to content

Commit deee462

Browse files
authored
Fix issues in GetKeyedService() and GetKeyedServices() with AnyKey (#113137)
1 parent f0f1457 commit deee462

File tree

5 files changed

+230
-30
lines changed

5 files changed

+230
-30
lines changed

src/libraries/Microsoft.Extensions.DependencyInjection.Specification.Tests/src/KeyedDependencyInjectionSpecificationTests.cs

+153-4
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,84 @@ public abstract partial class KeyedDependencyInjectionSpecificationTests
1313
{
1414
protected abstract IServiceProvider CreateServiceProvider(IServiceCollection collection);
1515

16+
[Fact]
17+
public void CombinationalRegistration()
18+
{
19+
Service service1 = new();
20+
Service service2 = new();
21+
Service keyedService1 = new();
22+
Service keyedService2 = new();
23+
Service anykeyService1 = new();
24+
Service anykeyService2 = new();
25+
Service nullkeyService1 = new();
26+
Service nullkeyService2 = new();
27+
28+
ServiceCollection serviceCollection = new();
29+
serviceCollection.AddSingleton<IService>(service1);
30+
serviceCollection.AddSingleton<IService>(service2);
31+
serviceCollection.AddKeyedSingleton<IService>(null, nullkeyService1);
32+
serviceCollection.AddKeyedSingleton<IService>(null, nullkeyService2);
33+
serviceCollection.AddKeyedSingleton<IService>(KeyedService.AnyKey, anykeyService1);
34+
serviceCollection.AddKeyedSingleton<IService>(KeyedService.AnyKey, anykeyService2);
35+
serviceCollection.AddKeyedSingleton<IService>("keyedService", keyedService1);
36+
serviceCollection.AddKeyedSingleton<IService>("keyedService", keyedService2);
37+
38+
IServiceProvider provider = CreateServiceProvider(serviceCollection);
39+
40+
/*
41+
* Table for what results are included:
42+
*
43+
* Query | Keyed? | Unkeyed? | AnyKey? | null key?
44+
* -------------------------------------------------------------------
45+
* GetServices(Type) | no | yes | no | yes
46+
* GetService(Type) | no | yes | no | yes
47+
*
48+
* GetKeyedServices(null) | no | yes | no | yes
49+
* GetKeyedService(null) | no | yes | no | yes
50+
*
51+
* GetKeyedServices(AnyKey) | yes | no | no | no
52+
* GetKeyedService(AnyKey) | throw | throw | throw | throw
53+
*
54+
* GetKeyedServices(key) | yes | no | no | no
55+
* GetKeyedService(key) | yes | no | yes | no
56+
*
57+
* Summary:
58+
* - A null key is the same as unkeyed. This allows the KeyServices APIs to support both keyed and unkeyed.
59+
* - AnyKey is a special case of Keyed.
60+
* - AnyKey registrations are not returned with GetKeyedServices(AnyKey) and GetKeyedService(AnyKey) always throws.
61+
* - For IEnumerable, the ordering of the results are in registration order.
62+
* - For a singleton resolve, the last match wins.
63+
*/
64+
65+
// Unkeyed (which is really keyed by Type).
66+
Assert.Equal(
67+
new[] { service1, service2, nullkeyService1, nullkeyService2 },
68+
provider.GetServices<IService>());
69+
70+
Assert.Equal(nullkeyService2, provider.GetService<IService>());
71+
72+
// Null key.
73+
Assert.Equal(
74+
new[] { service1, service2, nullkeyService1, nullkeyService2 },
75+
provider.GetKeyedServices<IService>(null));
76+
77+
Assert.Equal(nullkeyService2, provider.GetKeyedService<IService>(null));
78+
79+
// AnyKey.
80+
Assert.Equal(
81+
new[] { keyedService1, keyedService2 },
82+
provider.GetKeyedServices<IService>(KeyedService.AnyKey));
83+
84+
Assert.Throws<InvalidOperationException>(() => provider.GetKeyedService<IService>(KeyedService.AnyKey));
85+
86+
// Keyed.
87+
Assert.Equal(
88+
new[] { keyedService1, keyedService2 },
89+
provider.GetKeyedServices<IService>("keyedService"));
90+
91+
Assert.Equal(keyedService2, provider.GetKeyedService<IService>("keyedService"));
92+
}
93+
1694
[Fact]
1795
public void ResolveKeyedService()
1896
{
@@ -158,10 +236,75 @@ public void ResolveKeyedServicesAnyKeyWithAnyKeyRegistration()
158236
_ = provider.GetKeyedService<IService>("something-else");
159237
_ = provider.GetKeyedService<IService>("something-else-again");
160238

161-
// Return all services registered with a non null key, but not the one "created" with KeyedService.AnyKey
239+
// Return all services registered with a non null key, but not the one "created" with KeyedService.AnyKey,
240+
// nor the KeyedService.AnyKey registration
162241
var allServices = provider.GetKeyedServices<IService>(KeyedService.AnyKey).ToList();
163-
Assert.Equal(5, allServices.Count);
164-
Assert.Equal(new[] { service1, service2, service3, service4 }, allServices.Skip(1));
242+
Assert.Equal(4, allServices.Count);
243+
Assert.Equal(new[] { service1, service2, service3, service4 }, allServices);
244+
245+
var someKeyedServices = provider.GetKeyedServices<IService>("service").ToList();
246+
Assert.Equal(new[] { service2, service3, service4 }, someKeyedServices);
247+
248+
var unkeyedServices = provider.GetServices<IService>().ToList();
249+
Assert.Equal(new[] { service5, service6 }, unkeyedServices);
250+
}
251+
252+
[Fact]
253+
public void ResolveKeyedServicesAnyKeyConsistency()
254+
{
255+
var serviceCollection = new ServiceCollection();
256+
var service = new Service("first-service");
257+
serviceCollection.AddKeyedSingleton<IService>("first-service", service);
258+
259+
var provider1 = CreateServiceProvider(serviceCollection);
260+
Assert.Throws<InvalidOperationException>(() => provider1.GetKeyedService<IService>(KeyedService.AnyKey));
261+
// We don't return KeyedService.AnyKey registration when listing services
262+
Assert.Equal(new[] { service }, provider1.GetKeyedServices<IService>(KeyedService.AnyKey));
263+
264+
var provider2 = CreateServiceProvider(serviceCollection);
265+
Assert.Equal(new[] { service }, provider2.GetKeyedServices<IService>(KeyedService.AnyKey));
266+
Assert.Throws<InvalidOperationException>(() => provider2.GetKeyedService<IService>(KeyedService.AnyKey));
267+
}
268+
269+
[Fact]
270+
public void ResolveKeyedServicesAnyKeyConsistencyWithAnyKeyRegistration()
271+
{
272+
var serviceCollection = new ServiceCollection();
273+
var service = new Service("first-service");
274+
var any = new Service("any");
275+
serviceCollection.AddKeyedSingleton<IService>("first-service", service);
276+
serviceCollection.AddKeyedSingleton<IService>(KeyedService.AnyKey, (sp, key) => any);
277+
278+
var provider1 = CreateServiceProvider(serviceCollection);
279+
Assert.Equal(new[] { service }, provider1.GetKeyedServices<IService>(KeyedService.AnyKey));
280+
281+
// Check twice in different order to check caching
282+
var provider2 = CreateServiceProvider(serviceCollection);
283+
Assert.Equal(new[] { service }, provider2.GetKeyedServices<IService>(KeyedService.AnyKey));
284+
Assert.Same(any, provider2.GetKeyedService<IService>(new object()));
285+
286+
Assert.Throws<InvalidOperationException>(() => provider2.GetKeyedService<IService>(KeyedService.AnyKey));
287+
}
288+
289+
[Fact]
290+
public void ResolveKeyedServicesAnyKeyOrdering()
291+
{
292+
var serviceCollection = new ServiceCollection();
293+
var service1 = new Service();
294+
var service2 = new Service();
295+
var service3 = new Service();
296+
297+
serviceCollection.AddKeyedSingleton<IService>("A-service", service1);
298+
serviceCollection.AddKeyedSingleton<IService>("B-service", service2);
299+
serviceCollection.AddKeyedSingleton<IService>("A-service", service3);
300+
301+
var provider = CreateServiceProvider(serviceCollection);
302+
303+
// The order should be in registration order, and not grouped by key for example.
304+
// Although this isn't necessarily a requirement, it is the current behavior.
305+
Assert.Equal(
306+
new[] { service1, service2, service3 },
307+
provider.GetKeyedServices<IService>(KeyedService.AnyKey));
165308
}
166309

167310
[Fact]
@@ -250,7 +393,7 @@ public void ResolveKeyedServicesSingletonInstanceWithAnyKey()
250393
var provider = CreateServiceProvider(serviceCollection);
251394

252395
var services = provider.GetKeyedServices<IFakeOpenGenericService<PocoClass>>("some-key").ToList();
253-
Assert.Equal(new[] { service1, service2 }, services);
396+
Assert.Equal(new[] { service2 }, services);
254397
}
255398

256399
[Fact]
@@ -504,6 +647,9 @@ public void ResolveKeyedSingletonFromScopeServiceProvider()
504647
Assert.Null(scopeA.ServiceProvider.GetService<IService>());
505648
Assert.Null(scopeB.ServiceProvider.GetService<IService>());
506649

650+
Assert.Throws<InvalidOperationException>(() => scopeA.ServiceProvider.GetKeyedService<IService>(KeyedService.AnyKey));
651+
Assert.Throws<InvalidOperationException>(() => scopeB.ServiceProvider.GetKeyedService<IService>(KeyedService.AnyKey));
652+
507653
var serviceA1 = scopeA.ServiceProvider.GetKeyedService<IService>("key");
508654
var serviceA2 = scopeA.ServiceProvider.GetKeyedService<IService>("key");
509655

@@ -528,6 +674,9 @@ public void ResolveKeyedScopedFromScopeServiceProvider()
528674
Assert.Null(scopeA.ServiceProvider.GetService<IService>());
529675
Assert.Null(scopeB.ServiceProvider.GetService<IService>());
530676

677+
Assert.Throws<InvalidOperationException>(() => scopeA.ServiceProvider.GetKeyedService<IService>(KeyedService.AnyKey));
678+
Assert.Throws<InvalidOperationException>(() => scopeB.ServiceProvider.GetKeyedService<IService>(KeyedService.AnyKey));
679+
531680
var serviceA1 = scopeA.ServiceProvider.GetKeyedService<IService>("key");
532681
var serviceA2 = scopeA.ServiceProvider.GetKeyedService<IService>("key");
533682

src/libraries/Microsoft.Extensions.DependencyInjection/src/Resources/Strings.resx

+3
Original file line numberDiff line numberDiff line change
@@ -192,4 +192,7 @@
192192
<data name="InvalidServiceKeyType" xml:space="preserve">
193193
<value>The type of the key used for lookup doesn't match the type in the constructor parameter with the ServiceKey attribute.</value>
194194
</data>
195+
<data name="KeyedServiceAnyKeyUsedToResolveService" xml:space="preserve">
196+
<value>KeyedService.AnyKey cannot be used to resolve a single service.</value>
197+
</data>
195198
</root>

src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/CallSiteFactory.cs

+39-23
Original file line numberDiff line numberDiff line change
@@ -282,11 +282,13 @@ private static bool AreCompatible(DynamicallyAccessedMemberTypes serviceDynamica
282282
CallSiteResultCacheLocation cacheLocation = CallSiteResultCacheLocation.Root;
283283
ServiceCallSite[] callSites;
284284

285+
var isAnyKeyLookup = serviceIdentifier.ServiceKey == KeyedService.AnyKey;
286+
285287
// If item type is not generic we can safely use descriptor cache
286288
// Special case for KeyedService.AnyKey, we don't want to check the cache because a KeyedService.AnyKey registration
287289
// will "hide" all the other service registration
288290
if (!itemType.IsConstructedGenericType &&
289-
!KeyedService.AnyKey.Equals(cacheKey.ServiceKey) &&
291+
!isAnyKeyLookup &&
290292
_descriptorLookup.TryGetValue(cacheKey, out ServiceDescriptorCacheItem descriptors))
291293
{
292294
callSites = new ServiceCallSite[descriptors.Count];
@@ -317,19 +319,25 @@ private static bool AreCompatible(DynamicallyAccessedMemberTypes serviceDynamica
317319
int slot = 0;
318320
for (int i = _descriptors.Length - 1; i >= 0; i--)
319321
{
320-
if (KeysMatch(_descriptors[i].ServiceKey, cacheKey.ServiceKey))
322+
if (KeysMatch(cacheKey.ServiceKey, _descriptors[i].ServiceKey))
321323
{
322-
if (TryCreateExact(_descriptors[i], cacheKey, callSiteChain, slot) is { } callSite)
324+
// Special case for AnyKey: we don't want to add in cache a mapping AnyKey -> specific type,
325+
// so we need to ask creation with the original identity of the descriptor
326+
var registrationKey = isAnyKeyLookup ? ServiceIdentifier.FromDescriptor(_descriptors[i]) : cacheKey;
327+
if (TryCreateExact(_descriptors[i], registrationKey, callSiteChain, slot) is { } callSite)
323328
{
324329
AddCallSite(callSite, i);
325330
}
326331
}
327332
}
328333
for (int i = _descriptors.Length - 1; i >= 0; i--)
329334
{
330-
if (KeysMatch(_descriptors[i].ServiceKey, cacheKey.ServiceKey))
335+
if (KeysMatch(cacheKey.ServiceKey, _descriptors[i].ServiceKey))
331336
{
332-
if (TryCreateOpenGeneric(_descriptors[i], cacheKey, callSiteChain, slot, throwOnConstraintViolation: false) is { } callSite)
337+
// Special case for AnyKey: we don't want to add in cache a mapping AnyKey -> specific type,
338+
// so we need to ask creation with the original identity of the descriptor
339+
var registrationKey = isAnyKeyLookup ? ServiceIdentifier.FromDescriptor(_descriptors[i]) : cacheKey;
340+
if (TryCreateOpenGeneric(_descriptors[i], registrationKey, callSiteChain, slot, throwOnConstraintViolation: false) is { } callSite)
333341
{
334342
AddCallSite(callSite, i);
335343
}
@@ -360,6 +368,32 @@ void AddCallSite(ServiceCallSite callSite, int index)
360368
{
361369
callSiteChain.Remove(serviceIdentifier);
362370
}
371+
372+
static bool KeysMatch(object? lookupKey, object? descriptorKey)
373+
{
374+
if (lookupKey == null && descriptorKey == null)
375+
{
376+
// Both are non keyed services
377+
return true;
378+
}
379+
380+
if (lookupKey != null && descriptorKey != null)
381+
{
382+
// Both are keyed services
383+
384+
// We don't want to return AnyKey registration, so ignore it
385+
if (descriptorKey.Equals(KeyedService.AnyKey))
386+
return false;
387+
388+
// Check if both keys are equal, or if the lookup key
389+
// should matches all keys (except AnyKey)
390+
return lookupKey.Equals(descriptorKey)
391+
|| lookupKey.Equals(KeyedService.AnyKey);
392+
}
393+
394+
// One is a keyed service, one is not
395+
return false;
396+
}
363397
}
364398

365399
private static CallSiteResultCacheLocation GetCommonCacheLocation(CallSiteResultCacheLocation locationA, CallSiteResultCacheLocation locationB)
@@ -693,24 +727,6 @@ internal bool IsService(ServiceIdentifier serviceIdentifier)
693727
serviceType == typeof(IServiceProviderIsKeyedService);
694728
}
695729

696-
/// <summary>
697-
/// Returns true if both keys are null or equals, or if key1 is KeyedService.AnyKey and key2 is not null
698-
/// </summary>
699-
private static bool KeysMatch(object? key1, object? key2)
700-
{
701-
if (key1 == null && key2 == null)
702-
return true;
703-
704-
if (key1 != null && key2 != null)
705-
{
706-
return key1.Equals(key2)
707-
|| key1.Equals(KeyedService.AnyKey)
708-
|| key2.Equals(KeyedService.AnyKey);
709-
}
710-
711-
return false;
712-
}
713-
714730
private struct ServiceDescriptorCacheItem
715731
{
716732
[DisallowNull]

src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/ThrowHelper.cs

+7
Original file line numberDiff line numberDiff line change
@@ -15,5 +15,12 @@ internal static void ThrowObjectDisposedException()
1515
{
1616
throw new ObjectDisposedException(nameof(IServiceProvider));
1717
}
18+
19+
[DoesNotReturn]
20+
[MethodImpl(MethodImplOptions.NoInlining)]
21+
internal static void ThrowInvalidOperationException_KeyedServiceAnyKeyUsedToResolveService()
22+
{
23+
throw new InvalidOperationException(SR.Format(SR.KeyedServiceAnyKeyUsedToResolveService, nameof(IServiceProvider), nameof(IServiceScopeFactory)));
24+
}
1825
}
1926
}

src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceProvider.cs

+28-3
Original file line numberDiff line numberDiff line change
@@ -108,21 +108,46 @@ internal ServiceProvider(ICollection<ServiceDescriptor> serviceDescriptors, Serv
108108
/// <param name="serviceType">The type of the service to get.</param>
109109
/// <param name="serviceKey">The key of the service to get.</param>
110110
/// <returns>The keyed service.</returns>
111+
/// <exception cref="InvalidOperationException">The <see cref="KeyedService.AnyKey"/> value is used for <paramref name="serviceKey"/>
112+
/// when <paramref name="serviceType"/> is not an enumerable based on <see cref="IEnumerable{T}"/>.
113+
/// </exception>
111114
public object? GetKeyedService(Type serviceType, object? serviceKey)
112115
=> GetKeyedService(serviceType, serviceKey, Root);
113116

114117
internal object? GetKeyedService(Type serviceType, object? serviceKey, ServiceProviderEngineScope serviceProviderEngineScope)
115-
=> GetService(new ServiceIdentifier(serviceKey, serviceType), serviceProviderEngineScope);
118+
{
119+
if (serviceKey == KeyedService.AnyKey)
120+
{
121+
if (!serviceType.IsGenericType || serviceType.GetGenericTypeDefinition() != typeof(IEnumerable<>))
122+
{
123+
ThrowHelper.ThrowInvalidOperationException_KeyedServiceAnyKeyUsedToResolveService();
124+
}
125+
}
126+
127+
return GetService(new ServiceIdentifier(serviceKey, serviceType), serviceProviderEngineScope);
128+
}
116129

117130
/// <summary>
118131
/// Gets the service object of the specified type.
119132
/// </summary>
120133
/// <param name="serviceType">The type of the service to get.</param>
121134
/// <param name="serviceKey">The key of the service to get.</param>
122135
/// <returns>The keyed service.</returns>
123-
/// <exception cref="InvalidOperationException">The service wasn't found.</exception>
136+
/// <exception cref="InvalidOperationException">The service wasn't found or the <see cref="KeyedService.AnyKey"/> value is used
137+
/// for <paramref name="serviceKey"/> when <paramref name="serviceType"/> is not an enumerable based on <see cref="IEnumerable{T}"/>.
138+
/// </exception>
124139
public object GetRequiredKeyedService(Type serviceType, object? serviceKey)
125-
=> GetRequiredKeyedService(serviceType, serviceKey, Root);
140+
{
141+
if (serviceKey == KeyedService.AnyKey)
142+
{
143+
if (!serviceType.IsGenericType || serviceType.GetGenericTypeDefinition() != typeof(IEnumerable<>))
144+
{
145+
ThrowHelper.ThrowInvalidOperationException_KeyedServiceAnyKeyUsedToResolveService();
146+
}
147+
}
148+
149+
return GetRequiredKeyedService(serviceType, serviceKey, Root);
150+
}
126151

127152
internal object GetRequiredKeyedService(Type serviceType, object? serviceKey, ServiceProviderEngineScope serviceProviderEngineScope)
128153
{

0 commit comments

Comments
 (0)