From 5d8dc5a0c65dbdde74917ee9e2ed5e8a95e121a0 Mon Sep 17 00:00:00 2001 From: mapogolions Date: Sat, 10 Jun 2023 03:42:52 +0500 Subject: [PATCH 1/6] reproduce inconsistent behaviour --- .../ServiceProviderValidationTests.cs | 54 +++++++++++++++++++ 1 file changed, 54 insertions(+) diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/ServiceProviderValidationTests.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/ServiceProviderValidationTests.cs index c26e8d65fbfce8..c6f1834888f0fa 100644 --- a/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/ServiceProviderValidationTests.cs +++ b/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/ServiceProviderValidationTests.cs @@ -2,6 +2,8 @@ // The .NET Foundation licenses this file to you under the MIT license. using System; +using System.Collections.Generic; +using System.Linq; using Microsoft.Extensions.DependencyInjection.Specification.Fakes; using Xunit; @@ -97,6 +99,49 @@ public void GetService_Throws_WhenGetServiceForScopedServiceIsCalledOnRootViaTra Assert.Equal($"Cannot resolve '{typeof(IFoo)}' from root provider because it requires scoped service '{typeof(IBar)}'.", exception.Message); } + [Theory] + [InlineData(true)] + [InlineData(false)] + public void GetService_DoesNotThrow_WhenGetServiceForPolymorphicServiceIsCalledOnRoot_AndTheLastOneIsNotScoped(bool validateOnBuild) + { + // Arrange + var serviceCollection = new ServiceCollection(); + serviceCollection.AddScoped(); + serviceCollection.AddTransient(); + using var serviceProvider = serviceCollection.BuildServiceProvider(new ServiceProviderOptions + { + ValidateScopes = true, + ValidateOnBuild = validateOnBuild + }); + + // Act + var actual = serviceProvider.GetService(); + + // Assert + Assert.IsType(actual); + } + + [Fact] + public void ScopeValidation_ShouldBeAbleToDistingushGenericCollections_WhenGetServiceIsCalledOnRoot() + { + // Arrange + var serviceCollection = new ServiceCollection(); + serviceCollection.AddTransient(); + serviceCollection.AddScoped(); + + serviceCollection.AddTransient(); + serviceCollection.AddTransient(); + + // Act + using var serviceProvider = serviceCollection.BuildServiceProvider(validateScopes: true); + Assert.Throws(() => serviceProvider.GetService>()); + var actual = serviceProvider.GetService>(); + + // Assert + Assert.IsType(actual.First()); + Assert.IsType(actual.Last()); + } + [Fact] public void GetService_DoesNotThrow_WhenScopeFactoryIsInjectedIntoSingleton() { @@ -206,6 +251,7 @@ private class Bar : IBar { } + private class Bar2 : IBar { public Bar2(IBaz baz) @@ -213,6 +259,10 @@ public Bar2(IBaz baz) } } + private class Bar3 : IBar + { + } + private interface IBaz { } @@ -221,6 +271,10 @@ private class Baz : IBaz { } + private class Baz2 : IBaz + { + } + private class BazRecursive : IBaz { public BazRecursive(IBaz baz) From a0c95a0c05a233670013c4d77ae7bb1eff1c084d Mon Sep 17 00:00:00 2001 From: mapogolions Date: Sat, 10 Jun 2023 03:43:55 +0500 Subject: [PATCH 2/6] use ServiceCacheKey instead of Type to keep track scoped services --- .../src/ServiceLookup/CallSiteValidator.cs | 15 +++++++++++---- .../src/ServiceProvider.cs | 12 ++++++++---- 2 files changed, 19 insertions(+), 8 deletions(-) diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/CallSiteValidator.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/CallSiteValidator.cs index 17c5d34068c64c..f129b96a18d1e1 100644 --- a/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/CallSiteValidator.cs +++ b/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/CallSiteValidator.cs @@ -10,21 +10,22 @@ namespace Microsoft.Extensions.DependencyInjection.ServiceLookup internal sealed class CallSiteValidator: CallSiteVisitor { // Keys are services being resolved via GetService, values - first scoped service in their call site tree - private readonly ConcurrentDictionary _scopedServices = new ConcurrentDictionary(); + private readonly ConcurrentDictionary _scopedServices = new ConcurrentDictionary(); public void ValidateCallSite(ServiceCallSite callSite) { Type? scoped = VisitCallSite(callSite, default); if (scoped != null) { - _scopedServices[callSite.ServiceType] = scoped; + _scopedServices[GetCacheKey(callSite)] = scoped; } } - public void ValidateResolution(Type serviceType, IServiceScope scope, IServiceScope rootScope) + public void ValidateResolution(ServiceCallSite callSite, IServiceScope scope, IServiceScope rootScope) { + Type serviceType = callSite.ServiceType; if (ReferenceEquals(scope, rootScope) - && _scopedServices.TryGetValue(serviceType, out Type? scopedService)) + && _scopedServices.TryGetValue(GetCacheKey(callSite), out Type? scopedService)) { if (serviceType == scopedService) { @@ -97,6 +98,12 @@ public void ValidateResolution(Type serviceType, IServiceScope scope, IServiceSc protected override Type? VisitFactory(FactoryCallSite factoryCallSite, CallSiteValidatorState state) => null; + private static ServiceCacheKey GetCacheKey(ServiceCallSite callSite) + { + return callSite.Cache.Key.Equals(ServiceCacheKey.Empty) + ? new ServiceCacheKey(callSite.ServiceType, 0) : callSite.Cache.Key; + } + internal struct CallSiteValidatorState { [DisallowNull] diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceProvider.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceProvider.cs index b776b78b835ff7..badd51f62d0f7b 100644 --- a/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceProvider.cs +++ b/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceProvider.cs @@ -120,9 +120,9 @@ private void OnCreate(ServiceCallSite callSite) _callSiteValidator?.ValidateCallSite(callSite); } - private void OnResolve(Type serviceType, IServiceScope scope) + private void OnResolve(ServiceCallSite callSite, IServiceScope scope) { - _callSiteValidator?.ValidateResolution(serviceType, scope, Root); + _callSiteValidator?.ValidateResolution(callSite, scope, Root); } internal object? GetService(Type serviceType, ServiceProviderEngineScope serviceProviderEngineScope) @@ -133,7 +133,6 @@ private void OnResolve(Type serviceType, IServiceScope scope) } Func realizedService = _realizedServices.GetOrAdd(serviceType, _createServiceAccessor); - OnResolve(serviceType, serviceProviderEngineScope); DependencyInjectionEventSource.Log.ServiceResolved(this, serviceType); var result = realizedService.Invoke(serviceProviderEngineScope); System.Diagnostics.Debug.Assert(result is null || CallSiteFactory.IsService(serviceType)); @@ -176,7 +175,12 @@ private void ValidateService(ServiceDescriptor descriptor) return scope => value; } - return _engine.RealizeService(callSite); + Func realizedService = _engine.RealizeService(callSite); + return scope => + { + OnResolve(callSite, scope); + return realizedService.Invoke(scope); + }; } return _ => null; From dc51a2da04650abfbc990161a6f47de1bf97bf27 Mon Sep 17 00:00:00 2001 From: mapogolions Date: Sun, 11 Jun 2023 17:53:44 +0500 Subject: [PATCH 3/6] log after OnResolve --- .../src/ServiceProvider.cs | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceProvider.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceProvider.cs index badd51f62d0f7b..4816f32d1c6b26 100644 --- a/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceProvider.cs +++ b/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceProvider.cs @@ -133,7 +133,6 @@ private void OnResolve(ServiceCallSite callSite, IServiceScope scope) } Func realizedService = _realizedServices.GetOrAdd(serviceType, _createServiceAccessor); - DependencyInjectionEventSource.Log.ServiceResolved(this, serviceType); var result = realizedService.Invoke(serviceProviderEngineScope); System.Diagnostics.Debug.Assert(result is null || CallSiteFactory.IsService(serviceType)); return result; @@ -172,13 +171,18 @@ private void ValidateService(ServiceDescriptor descriptor) if (callSite.Cache.Location == CallSiteResultCacheLocation.Root) { object? value = CallSiteRuntimeResolver.Instance.Resolve(callSite, Root); - return scope => value; + return scope => + { + DependencyInjectionEventSource.Log.ServiceResolved(this, serviceType); + return value; + }; } Func realizedService = _engine.RealizeService(callSite); return scope => { OnResolve(callSite, scope); + DependencyInjectionEventSource.Log.ServiceResolved(this, serviceType); return realizedService.Invoke(scope); }; } From 09249bf0e019f7a2f00055c7a60826f7e4e5bce8 Mon Sep 17 00:00:00 2001 From: mapogolions Date: Tue, 13 Jun 2023 01:40:25 +0500 Subject: [PATCH 4/6] eval ServiceCacheKey only once --- .../src/ServiceLookup/CallSiteFactory.cs | 10 +++------- .../src/ServiceLookup/CallSiteValidator.cs | 10 ++-------- .../src/ServiceLookup/ConstantCallSite.cs | 2 +- .../src/ServiceLookup/ResultCache.cs | 6 +++++- .../src/ServiceLookup/ServiceCacheKey.cs | 2 -- .../src/ServiceLookup/ServiceProviderCallSite.cs | 2 +- .../DI.Tests/ServiceLookup/CallSiteFactoryTest.cs | 13 +++---------- 7 files changed, 15 insertions(+), 30 deletions(-) diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/CallSiteFactory.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/CallSiteFactory.cs index d660eb104007d9..908d46787f13ad 100644 --- a/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/CallSiteFactory.cs +++ b/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/CallSiteFactory.cs @@ -316,13 +316,9 @@ void AddCallSite(ServiceCallSite callSite, int index) callSitesByIndex.Add(new(index, callSite)); } } - - ResultCache resultCache = ResultCache.None; - if (cacheLocation == CallSiteResultCacheLocation.Scope || cacheLocation == CallSiteResultCacheLocation.Root) - { - resultCache = new ResultCache(cacheLocation, callSiteKey); - } - + ResultCache resultCache = (cacheLocation == CallSiteResultCacheLocation.Scope || cacheLocation == CallSiteResultCacheLocation.Root) + ? new ResultCache(cacheLocation, callSiteKey) + : new ResultCache(CallSiteResultCacheLocation.None, callSiteKey); return _callSiteCache[callSiteKey] = new IEnumerableCallSite(resultCache, itemType, callSites); } finally diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/CallSiteValidator.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/CallSiteValidator.cs index f129b96a18d1e1..ff4c3a8807cc2e 100644 --- a/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/CallSiteValidator.cs +++ b/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/CallSiteValidator.cs @@ -17,7 +17,7 @@ public void ValidateCallSite(ServiceCallSite callSite) Type? scoped = VisitCallSite(callSite, default); if (scoped != null) { - _scopedServices[GetCacheKey(callSite)] = scoped; + _scopedServices[callSite.Cache.Key] = scoped; } } @@ -25,7 +25,7 @@ public void ValidateResolution(ServiceCallSite callSite, IServiceScope scope, IS { Type serviceType = callSite.ServiceType; if (ReferenceEquals(scope, rootScope) - && _scopedServices.TryGetValue(GetCacheKey(callSite), out Type? scopedService)) + && _scopedServices.TryGetValue(callSite.Cache.Key, out Type? scopedService)) { if (serviceType == scopedService) { @@ -98,12 +98,6 @@ public void ValidateResolution(ServiceCallSite callSite, IServiceScope scope, IS protected override Type? VisitFactory(FactoryCallSite factoryCallSite, CallSiteValidatorState state) => null; - private static ServiceCacheKey GetCacheKey(ServiceCallSite callSite) - { - return callSite.Cache.Key.Equals(ServiceCacheKey.Empty) - ? new ServiceCacheKey(callSite.ServiceType, 0) : callSite.Cache.Key; - } - internal struct CallSiteValidatorState { [DisallowNull] diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/ConstantCallSite.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/ConstantCallSite.cs index 8071c67013352d..66984648134553 100644 --- a/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/ConstantCallSite.cs +++ b/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/ConstantCallSite.cs @@ -10,7 +10,7 @@ internal sealed class ConstantCallSite : ServiceCallSite private readonly Type _serviceType; internal object? DefaultValue => Value; - public ConstantCallSite(Type serviceType, object? defaultValue): base(ResultCache.None) + public ConstantCallSite(Type serviceType, object? defaultValue) : base(ResultCache.None(serviceType)) { _serviceType = serviceType ?? throw new ArgumentNullException(nameof(serviceType)); if (defaultValue != null && !serviceType.IsInstanceOfType(defaultValue)) diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/ResultCache.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/ResultCache.cs index 65b1c799b6f3a2..5b4da78aaec11f 100644 --- a/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/ResultCache.cs +++ b/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/ResultCache.cs @@ -8,7 +8,11 @@ namespace Microsoft.Extensions.DependencyInjection.ServiceLookup { internal struct ResultCache { - public static ResultCache None { get; } = new ResultCache(CallSiteResultCacheLocation.None, ServiceCacheKey.Empty); + public static ResultCache None(Type serviceType) + { + var cacheKey = new ServiceCacheKey(serviceType, 0); + return new ResultCache(CallSiteResultCacheLocation.None, cacheKey); + } internal ResultCache(CallSiteResultCacheLocation lifetime, ServiceCacheKey cacheKey) { diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/ServiceCacheKey.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/ServiceCacheKey.cs index 737c23d7f4445a..569fbef9de9bbb 100644 --- a/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/ServiceCacheKey.cs +++ b/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/ServiceCacheKey.cs @@ -8,8 +8,6 @@ namespace Microsoft.Extensions.DependencyInjection.ServiceLookup { internal readonly struct ServiceCacheKey : IEquatable { - public static ServiceCacheKey Empty { get; } = new ServiceCacheKey(null, 0); - /// /// Type of service being cached /// diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/ServiceProviderCallSite.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/ServiceProviderCallSite.cs index 3db2f7f0723e00..6271473505c29f 100644 --- a/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/ServiceProviderCallSite.cs +++ b/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/ServiceProviderCallSite.cs @@ -7,7 +7,7 @@ namespace Microsoft.Extensions.DependencyInjection.ServiceLookup { internal sealed class ServiceProviderCallSite : ServiceCallSite { - public ServiceProviderCallSite() : base(ResultCache.None) + public ServiceProviderCallSite() : base(ResultCache.None(typeof(IServiceProvider))) { } diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/ServiceLookup/CallSiteFactoryTest.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/ServiceLookup/CallSiteFactoryTest.cs index f1abdcb46b0c21..dbf4e7b5619189 100644 --- a/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/ServiceLookup/CallSiteFactoryTest.cs +++ b/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/ServiceLookup/CallSiteFactoryTest.cs @@ -786,17 +786,10 @@ public void CreateCallSite_EnumberableCachedAtLowestLevel(ServiceDescriptor[] de var callSite = factory(typeof(IEnumerable)); var expectedLocation = (CallSiteResultCacheLocation)expectedCacheLocation; - Assert.Equal(expectedLocation, callSite.Cache.Location); - if (expectedLocation != CallSiteResultCacheLocation.None) - { - Assert.Equal(0, callSite.Cache.Key.Slot); - Assert.Equal(typeof(IEnumerable), callSite.Cache.Key.Type); - } - else - { - Assert.Equal(ResultCache.None, callSite.Cache); - } + Assert.Equal(expectedLocation, callSite.Cache.Location); + Assert.Equal(0, callSite.Cache.Key.Slot); + Assert.Equal(typeof(IEnumerable), callSite.Cache.Key.Type); } [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))] From 7cfab3211f8968b66652b02937b0c678b5636dda Mon Sep 17 00:00:00 2001 From: mapogolions Date: Tue, 13 Jun 2023 01:41:36 +0500 Subject: [PATCH 5/6] refactoring --- .../src/ServiceProvider.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceProvider.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceProvider.cs index 4816f32d1c6b26..0610d2096f273e 100644 --- a/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceProvider.cs +++ b/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceProvider.cs @@ -183,7 +183,7 @@ private void ValidateService(ServiceDescriptor descriptor) { OnResolve(callSite, scope); DependencyInjectionEventSource.Log.ServiceResolved(this, serviceType); - return realizedService.Invoke(scope); + return realizedService(scope); }; } From a020ab841f26fc2c1afdc65b80f9465420661a82 Mon Sep 17 00:00:00 2001 From: mapogolions Date: Tue, 13 Jun 2023 03:01:15 +0500 Subject: [PATCH 6/6] read service type only when needed --- .../src/ServiceLookup/CallSiteValidator.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/CallSiteValidator.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/CallSiteValidator.cs index ff4c3a8807cc2e..f228507a6b8ab5 100644 --- a/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/CallSiteValidator.cs +++ b/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/CallSiteValidator.cs @@ -23,10 +23,10 @@ public void ValidateCallSite(ServiceCallSite callSite) public void ValidateResolution(ServiceCallSite callSite, IServiceScope scope, IServiceScope rootScope) { - Type serviceType = callSite.ServiceType; if (ReferenceEquals(scope, rootScope) && _scopedServices.TryGetValue(callSite.Cache.Key, out Type? scopedService)) { + Type serviceType = callSite.ServiceType; if (serviceType == scopedService) { throw new InvalidOperationException(