diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/CallSiteFactory.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/CallSiteFactory.cs index d660eb104007d9..908d46787f13ad 100644 --- a/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/CallSiteFactory.cs +++ b/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/CallSiteFactory.cs @@ -316,13 +316,9 @@ void AddCallSite(ServiceCallSite callSite, int index) callSitesByIndex.Add(new(index, callSite)); } } - - ResultCache resultCache = ResultCache.None; - if (cacheLocation == CallSiteResultCacheLocation.Scope || cacheLocation == CallSiteResultCacheLocation.Root) - { - resultCache = new ResultCache(cacheLocation, callSiteKey); - } - + ResultCache resultCache = (cacheLocation == CallSiteResultCacheLocation.Scope || cacheLocation == CallSiteResultCacheLocation.Root) + ? new ResultCache(cacheLocation, callSiteKey) + : new ResultCache(CallSiteResultCacheLocation.None, callSiteKey); return _callSiteCache[callSiteKey] = new IEnumerableCallSite(resultCache, itemType, callSites); } finally diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/CallSiteValidator.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/CallSiteValidator.cs index 17c5d34068c64c..f228507a6b8ab5 100644 --- a/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/CallSiteValidator.cs +++ b/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/CallSiteValidator.cs @@ -10,22 +10,23 @@ namespace Microsoft.Extensions.DependencyInjection.ServiceLookup internal sealed class CallSiteValidator: CallSiteVisitor { // Keys are services being resolved via GetService, values - first scoped service in their call site tree - private readonly ConcurrentDictionary _scopedServices = new ConcurrentDictionary(); + private readonly ConcurrentDictionary _scopedServices = new ConcurrentDictionary(); public void ValidateCallSite(ServiceCallSite callSite) { Type? scoped = VisitCallSite(callSite, default); if (scoped != null) { - _scopedServices[callSite.ServiceType] = scoped; + _scopedServices[callSite.Cache.Key] = scoped; } } - public void ValidateResolution(Type serviceType, IServiceScope scope, IServiceScope rootScope) + public void ValidateResolution(ServiceCallSite callSite, IServiceScope scope, IServiceScope rootScope) { if (ReferenceEquals(scope, rootScope) - && _scopedServices.TryGetValue(serviceType, out Type? scopedService)) + && _scopedServices.TryGetValue(callSite.Cache.Key, out Type? scopedService)) { + Type serviceType = callSite.ServiceType; if (serviceType == scopedService) { throw new InvalidOperationException( diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/ConstantCallSite.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/ConstantCallSite.cs index 8071c67013352d..66984648134553 100644 --- a/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/ConstantCallSite.cs +++ b/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/ConstantCallSite.cs @@ -10,7 +10,7 @@ internal sealed class ConstantCallSite : ServiceCallSite private readonly Type _serviceType; internal object? DefaultValue => Value; - public ConstantCallSite(Type serviceType, object? defaultValue): base(ResultCache.None) + public ConstantCallSite(Type serviceType, object? defaultValue) : base(ResultCache.None(serviceType)) { _serviceType = serviceType ?? throw new ArgumentNullException(nameof(serviceType)); if (defaultValue != null && !serviceType.IsInstanceOfType(defaultValue)) diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/ResultCache.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/ResultCache.cs index 65b1c799b6f3a2..5b4da78aaec11f 100644 --- a/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/ResultCache.cs +++ b/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/ResultCache.cs @@ -8,7 +8,11 @@ namespace Microsoft.Extensions.DependencyInjection.ServiceLookup { internal struct ResultCache { - public static ResultCache None { get; } = new ResultCache(CallSiteResultCacheLocation.None, ServiceCacheKey.Empty); + public static ResultCache None(Type serviceType) + { + var cacheKey = new ServiceCacheKey(serviceType, 0); + return new ResultCache(CallSiteResultCacheLocation.None, cacheKey); + } internal ResultCache(CallSiteResultCacheLocation lifetime, ServiceCacheKey cacheKey) { diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/ServiceCacheKey.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/ServiceCacheKey.cs index 737c23d7f4445a..569fbef9de9bbb 100644 --- a/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/ServiceCacheKey.cs +++ b/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/ServiceCacheKey.cs @@ -8,8 +8,6 @@ namespace Microsoft.Extensions.DependencyInjection.ServiceLookup { internal readonly struct ServiceCacheKey : IEquatable { - public static ServiceCacheKey Empty { get; } = new ServiceCacheKey(null, 0); - /// /// Type of service being cached /// diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/ServiceProviderCallSite.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/ServiceProviderCallSite.cs index 3db2f7f0723e00..6271473505c29f 100644 --- a/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/ServiceProviderCallSite.cs +++ b/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/ServiceProviderCallSite.cs @@ -7,7 +7,7 @@ namespace Microsoft.Extensions.DependencyInjection.ServiceLookup { internal sealed class ServiceProviderCallSite : ServiceCallSite { - public ServiceProviderCallSite() : base(ResultCache.None) + public ServiceProviderCallSite() : base(ResultCache.None(typeof(IServiceProvider))) { } diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceProvider.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceProvider.cs index b776b78b835ff7..fcc44c8f4f8f4c 100644 --- a/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceProvider.cs +++ b/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceProvider.cs @@ -18,14 +18,14 @@ public sealed class ServiceProvider : IServiceProvider, IDisposable, IAsyncDispo { private readonly CallSiteValidator? _callSiteValidator; - private readonly Func> _createServiceAccessor; + private readonly Func _createServiceAccessor; // Internal for testing internal ServiceProviderEngine _engine; private bool _disposed; - private readonly ConcurrentDictionary> _realizedServices; + private readonly ConcurrentDictionary _realizedServices; internal CallSiteFactory CallSiteFactory { get; } @@ -47,7 +47,7 @@ internal ServiceProvider(ICollection serviceDescriptors, Serv Root = new ServiceProviderEngineScope(this, isRootScope: true); _engine = GetEngine(); _createServiceAccessor = CreateServiceAccessor; - _realizedServices = new ConcurrentDictionary>(); + _realizedServices = new ConcurrentDictionary(); CallSiteFactory = new CallSiteFactory(serviceDescriptors); // The list of built in services that aren't part of the list of service descriptors @@ -120,9 +120,12 @@ private void OnCreate(ServiceCallSite callSite) _callSiteValidator?.ValidateCallSite(callSite); } - private void OnResolve(Type serviceType, IServiceScope scope) + private void OnResolve(ServiceCallSite? callSite, IServiceScope scope) { - _callSiteValidator?.ValidateResolution(serviceType, scope, Root); + if (callSite != null) + { + _callSiteValidator?.ValidateResolution(callSite, scope, Root); + } } internal object? GetService(Type serviceType, ServiceProviderEngineScope serviceProviderEngineScope) @@ -132,10 +135,10 @@ private void OnResolve(Type serviceType, IServiceScope scope) ThrowHelper.ThrowObjectDisposedException(); } - Func realizedService = _realizedServices.GetOrAdd(serviceType, _createServiceAccessor); - OnResolve(serviceType, serviceProviderEngineScope); + ServiceAccessor realizedService = _realizedServices.GetOrAdd(serviceType, _createServiceAccessor); + OnResolve(realizedService.CallSite, serviceProviderEngineScope); DependencyInjectionEventSource.Log.ServiceResolved(this, serviceType); - var result = realizedService.Invoke(serviceProviderEngineScope); + var result = realizedService.RealizedService.Invoke(serviceProviderEngineScope); System.Diagnostics.Debug.Assert(result is null || CallSiteFactory.IsService(serviceType)); return result; } @@ -161,7 +164,7 @@ private void ValidateService(ServiceDescriptor descriptor) } } - private Func CreateServiceAccessor(Type serviceType) + private ServiceAccessor CreateServiceAccessor(Type serviceType) { ServiceCallSite? callSite = CallSiteFactory.GetCallSite(serviceType, new CallSiteChain()); if (callSite != null) @@ -173,18 +176,17 @@ private void ValidateService(ServiceDescriptor descriptor) if (callSite.Cache.Location == CallSiteResultCacheLocation.Root) { object? value = CallSiteRuntimeResolver.Instance.Resolve(callSite, Root); - return scope => value; + return new ServiceAccessor { CallSite = callSite, RealizedService = scope => value }; } - return _engine.RealizeService(callSite); + return new ServiceAccessor { CallSite = callSite, RealizedService = _engine.RealizeService(callSite) }; } - - return _ => null; + return new ServiceAccessor { CallSite = callSite, RealizedService = _ => null }; } internal void ReplaceServiceAccessor(ServiceCallSite callSite, Func accessor) { - _realizedServices[callSite.ServiceType] = accessor; + _realizedServices[callSite.ServiceType] = new ServiceAccessor { CallSite = callSite, RealizedService = accessor }; } internal IServiceScope CreateScope() @@ -220,5 +222,11 @@ private ServiceProviderEngine GetEngine() Justification = "CreateDynamicEngine won't be called when using NativeAOT.")] // see also https://github.com/dotnet/linker/issues/2715 ServiceProviderEngine CreateDynamicEngine() => new DynamicServiceProviderEngine(this); } + + private struct ServiceAccessor + { + public ServiceCallSite? CallSite { get; set; } + public Func RealizedService { get; set; } + } } } diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/ServiceLookup/CallSiteFactoryTest.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/ServiceLookup/CallSiteFactoryTest.cs index f1abdcb46b0c21..dbf4e7b5619189 100644 --- a/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/ServiceLookup/CallSiteFactoryTest.cs +++ b/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/ServiceLookup/CallSiteFactoryTest.cs @@ -786,17 +786,10 @@ public void CreateCallSite_EnumberableCachedAtLowestLevel(ServiceDescriptor[] de var callSite = factory(typeof(IEnumerable)); var expectedLocation = (CallSiteResultCacheLocation)expectedCacheLocation; - Assert.Equal(expectedLocation, callSite.Cache.Location); - if (expectedLocation != CallSiteResultCacheLocation.None) - { - Assert.Equal(0, callSite.Cache.Key.Slot); - Assert.Equal(typeof(IEnumerable), callSite.Cache.Key.Type); - } - else - { - Assert.Equal(ResultCache.None, callSite.Cache); - } + Assert.Equal(expectedLocation, callSite.Cache.Location); + Assert.Equal(0, callSite.Cache.Key.Slot); + Assert.Equal(typeof(IEnumerable), callSite.Cache.Key.Type); } [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))] diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/ServiceProviderValidationTests.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/ServiceProviderValidationTests.cs index c26e8d65fbfce8..c6f1834888f0fa 100644 --- a/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/ServiceProviderValidationTests.cs +++ b/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/ServiceProviderValidationTests.cs @@ -2,6 +2,8 @@ // The .NET Foundation licenses this file to you under the MIT license. using System; +using System.Collections.Generic; +using System.Linq; using Microsoft.Extensions.DependencyInjection.Specification.Fakes; using Xunit; @@ -97,6 +99,49 @@ public void GetService_Throws_WhenGetServiceForScopedServiceIsCalledOnRootViaTra Assert.Equal($"Cannot resolve '{typeof(IFoo)}' from root provider because it requires scoped service '{typeof(IBar)}'.", exception.Message); } + [Theory] + [InlineData(true)] + [InlineData(false)] + public void GetService_DoesNotThrow_WhenGetServiceForPolymorphicServiceIsCalledOnRoot_AndTheLastOneIsNotScoped(bool validateOnBuild) + { + // Arrange + var serviceCollection = new ServiceCollection(); + serviceCollection.AddScoped(); + serviceCollection.AddTransient(); + using var serviceProvider = serviceCollection.BuildServiceProvider(new ServiceProviderOptions + { + ValidateScopes = true, + ValidateOnBuild = validateOnBuild + }); + + // Act + var actual = serviceProvider.GetService(); + + // Assert + Assert.IsType(actual); + } + + [Fact] + public void ScopeValidation_ShouldBeAbleToDistingushGenericCollections_WhenGetServiceIsCalledOnRoot() + { + // Arrange + var serviceCollection = new ServiceCollection(); + serviceCollection.AddTransient(); + serviceCollection.AddScoped(); + + serviceCollection.AddTransient(); + serviceCollection.AddTransient(); + + // Act + using var serviceProvider = serviceCollection.BuildServiceProvider(validateScopes: true); + Assert.Throws(() => serviceProvider.GetService>()); + var actual = serviceProvider.GetService>(); + + // Assert + Assert.IsType(actual.First()); + Assert.IsType(actual.Last()); + } + [Fact] public void GetService_DoesNotThrow_WhenScopeFactoryIsInjectedIntoSingleton() { @@ -206,6 +251,7 @@ private class Bar : IBar { } + private class Bar2 : IBar { public Bar2(IBaz baz) @@ -213,6 +259,10 @@ public Bar2(IBaz baz) } } + private class Bar3 : IBar + { + } + private interface IBaz { } @@ -221,6 +271,10 @@ private class Baz : IBaz { } + private class Baz2 : IBaz + { + } + private class BazRecursive : IBaz { public BazRecursive(IBaz baz)