diff --git a/src/Microsoft.Extensions.DependencyInjection/Properties/Resources.Designer.cs b/src/Microsoft.Extensions.DependencyInjection/Properties/Resources.Designer.cs index 7750f25b..2721925e 100644 --- a/src/Microsoft.Extensions.DependencyInjection/Properties/Resources.Designer.cs +++ b/src/Microsoft.Extensions.DependencyInjection/Properties/Resources.Designer.cs @@ -122,6 +122,54 @@ internal static string FormatNoConstructorMatch(object p0) return string.Format(CultureInfo.CurrentCulture, GetString("NoConstructorMatch"), p0); } + /// + /// Cannot consume {2} service '{0}' from {3} '{1}'. + /// + internal static string ScopedInSingletonException + { + get { return GetString("ScopedInSingletonException"); } + } + + /// + /// Cannot consume {2} service '{0}' from {3} '{1}'. + /// + internal static string FormatScopedInSingletonException(object p0, object p1, object p2, object p3) + { + return string.Format(CultureInfo.CurrentCulture, GetString("ScopedInSingletonException"), p0, p1, p2, p3); + } + + /// + /// Cannot resolve '{0}' from root provider because it requires {2} service '{1}'. + /// + internal static string ScopedResolvedFromRootException + { + get { return GetString("ScopedResolvedFromRootException"); } + } + + /// + /// Cannot resolve '{0}' from root provider because it requires {2} service '{1}'. + /// + internal static string FormatScopedResolvedFromRootException(object p0, object p1, object p2) + { + return string.Format(CultureInfo.CurrentCulture, GetString("ScopedResolvedFromRootException"), p0, p1, p2); + } + + /// + /// Cannot resolve {1} service '{0}' from root provider. + /// + internal static string DirectScopedResolvedFromRootException + { + get { return GetString("DirectScopedResolvedFromRootException"); } + } + + /// + /// Cannot resolve {1} service '{0}' from root provider. + /// + internal static string FormatDirectScopedResolvedFromRootException(object p0, object p1) + { + return string.Format(CultureInfo.CurrentCulture, GetString("DirectScopedResolvedFromRootException"), p0, p1); + } + private static string GetString(string name, params string[] formatterNames) { var value = _resourceManager.GetString(name); diff --git a/src/Microsoft.Extensions.DependencyInjection/Resources.resx b/src/Microsoft.Extensions.DependencyInjection/Resources.resx index c7618e0c..259f794b 100644 --- a/src/Microsoft.Extensions.DependencyInjection/Resources.resx +++ b/src/Microsoft.Extensions.DependencyInjection/Resources.resx @@ -141,4 +141,13 @@ A suitable constructor for type '{0}' could not be located. Ensure the type is concrete and services are registered for all parameters of a public constructor. {0} = service type + + Cannot consume {2} service '{0}' from {3} '{1}'. + + + Cannot resolve '{0}' from root provider because it requires {2} service '{1}'. + + + Cannot resolve {1} service '{0}' from root provider. + \ No newline at end of file diff --git a/src/Microsoft.Extensions.DependencyInjection/ServiceCollectionContainerBuilderExtensions.cs b/src/Microsoft.Extensions.DependencyInjection/ServiceCollectionContainerBuilderExtensions.cs index 824b7ee4..f5945376 100644 --- a/src/Microsoft.Extensions.DependencyInjection/ServiceCollectionContainerBuilderExtensions.cs +++ b/src/Microsoft.Extensions.DependencyInjection/ServiceCollectionContainerBuilderExtensions.cs @@ -7,9 +7,29 @@ namespace Microsoft.Extensions.DependencyInjection { public static class ServiceCollectionContainerBuilderExtensions { + /// + /// Creates an containing services from the provided . + /// + /// The containing service descriptors. + /// The. + public static IServiceProvider BuildServiceProvider(this IServiceCollection services) { - return new ServiceProvider(services); + return BuildServiceProvider(services, validateScopes: false); + } + + /// + /// Creates an containing services from the provided + /// optionaly enabling scope validation. + /// + /// The containing service descriptors. + /// + /// true to perform check verifying that scoped services never gets resolved from root provider; otherwise false. + /// + /// The. + public static IServiceProvider BuildServiceProvider(this IServiceCollection services, bool validateScopes) + { + return new ServiceProvider(services, validateScopes); } } } diff --git a/src/Microsoft.Extensions.DependencyInjection/ServiceLookup/CallSiteValidator.cs b/src/Microsoft.Extensions.DependencyInjection/ServiceLookup/CallSiteValidator.cs new file mode 100644 index 00000000..3f6c8b22 --- /dev/null +++ b/src/Microsoft.Extensions.DependencyInjection/ServiceLookup/CallSiteValidator.cs @@ -0,0 +1,122 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; + +namespace Microsoft.Extensions.DependencyInjection.ServiceLookup +{ + internal class CallSiteValidator: CallSiteVisitor + { + // Keys are services being resolved via GetService, values - first scoped service in their call site tree + private readonly Dictionary _scopedServices = new Dictionary(); + + public void ValidateCallSite(Type serviceType, IServiceCallSite callSite) + { + var scoped = VisitCallSite(callSite, default(CallSiteValidatorState)); + if (scoped != null) + { + _scopedServices.Add(serviceType, scoped); + } + } + + public void ValidateResolution(Type serviceType, ServiceProvider serviceProvider) + { + Type scopedService; + if (ReferenceEquals(serviceProvider, serviceProvider.Root) + && _scopedServices.TryGetValue(serviceType, out scopedService)) + { + if (serviceType == scopedService) + { + throw new InvalidOperationException( + Resources.FormatDirectScopedResolvedFromRootException(serviceType, + nameof(ServiceLifetime.Scoped).ToLowerInvariant())); + } + + throw new InvalidOperationException( + Resources.FormatScopedResolvedFromRootException( + serviceType, + scopedService, + nameof(ServiceLifetime.Scoped).ToLowerInvariant())); + } + } + + protected override Type VisitTransient(TransientCallSite transientCallSite, CallSiteValidatorState state) + { + return VisitCallSite(transientCallSite.Service, state); + } + + protected override Type VisitConstructor(ConstructorCallSite constructorCallSite, CallSiteValidatorState state) + { + Type result = null; + foreach (var parameterCallSite in constructorCallSite.ParameterCallSites) + { + var scoped = VisitCallSite(parameterCallSite, state); + if (result == null) + { + result = scoped; + } + } + return result; + } + + protected override Type VisitClosedIEnumerable(ClosedIEnumerableCallSite closedIEnumerableCallSite, + CallSiteValidatorState state) + { + Type result = null; + foreach (var serviceCallSite in closedIEnumerableCallSite.ServiceCallSites) + { + var scoped = VisitCallSite(serviceCallSite, state); + if (result == null) + { + result = scoped; + } + } + return result; + } + + protected override Type VisitSingleton(SingletonCallSite singletonCallSite, CallSiteValidatorState state) + { + state.Singleton = singletonCallSite; + return VisitCallSite(singletonCallSite.ServiceCallSite, state); + } + + protected override Type VisitScoped(ScopedCallSite scopedCallSite, CallSiteValidatorState state) + { + // We are fine with having ServiceScopeService requested by singletons + if (scopedCallSite.ServiceCallSite is ServiceScopeService) + { + return null; + } + if (state.Singleton != null) + { + throw new InvalidOperationException(Resources.FormatScopedInSingletonException( + scopedCallSite.Key.ServiceType, + state.Singleton.Key.ServiceType, + nameof(ServiceLifetime.Scoped).ToLowerInvariant(), + nameof(ServiceLifetime.Singleton).ToLowerInvariant() + )); + } + return scopedCallSite.Key.ServiceType; + } + + protected override Type VisitConstant(ConstantCallSite constantCallSite, CallSiteValidatorState state) => null; + + protected override Type VisitCreateInstance(CreateInstanceCallSite createInstanceCallSite, CallSiteValidatorState state) => null; + + protected override Type VisitInstanceService(InstanceService instanceCallSite, CallSiteValidatorState state) => null; + + protected override Type VisitServiceProviderService(ServiceProviderService serviceProviderService, CallSiteValidatorState state) => null; + + protected override Type VisitEmptyIEnumerable(EmptyIEnumerableCallSite emptyIEnumerableCallSite, CallSiteValidatorState state) => null; + + protected override Type VisitServiceScopeService(ServiceScopeService serviceScopeService, CallSiteValidatorState state) => null; + + protected override Type VisitFactoryService(FactoryService factoryService, CallSiteValidatorState state) => null; + + internal struct CallSiteValidatorState + { + public SingletonCallSite Singleton { get; set; } + } + } +} \ No newline at end of file diff --git a/src/Microsoft.Extensions.DependencyInjection/ServiceLookup/ClosedIEnumerableService.cs b/src/Microsoft.Extensions.DependencyInjection/ServiceLookup/ClosedIEnumerableService.cs index 9ccc9422..d7215472 100644 --- a/src/Microsoft.Extensions.DependencyInjection/ServiceLookup/ClosedIEnumerableService.cs +++ b/src/Microsoft.Extensions.DependencyInjection/ServiceLookup/ClosedIEnumerableService.cs @@ -25,6 +25,8 @@ public ServiceLifetime Lifetime get { return ServiceLifetime.Transient; } } + public Type ServiceType => _itemType; + public IServiceCallSite CreateCallSite(ServiceProvider provider, ISet callSiteChain) { var list = new List(); @@ -36,6 +38,5 @@ public IServiceCallSite CreateCallSite(ServiceProvider provider, ISet call } return new ClosedIEnumerableCallSite(_itemType, list.ToArray()); } - } } diff --git a/src/Microsoft.Extensions.DependencyInjection/ServiceLookup/FactoryService.cs b/src/Microsoft.Extensions.DependencyInjection/ServiceLookup/FactoryService.cs index 7e9e9e76..e65e4144 100644 --- a/src/Microsoft.Extensions.DependencyInjection/ServiceLookup/FactoryService.cs +++ b/src/Microsoft.Extensions.DependencyInjection/ServiceLookup/FactoryService.cs @@ -22,6 +22,8 @@ public ServiceLifetime Lifetime get { return Descriptor.Lifetime; } } + public Type ServiceType => Descriptor.ServiceType; + public IServiceCallSite CreateCallSite(ServiceProvider provider, ISet callSiteChain) { return this; diff --git a/src/Microsoft.Extensions.DependencyInjection/ServiceLookup/IService.cs b/src/Microsoft.Extensions.DependencyInjection/ServiceLookup/IService.cs index fc5bcadf..d76ee067 100644 --- a/src/Microsoft.Extensions.DependencyInjection/ServiceLookup/IService.cs +++ b/src/Microsoft.Extensions.DependencyInjection/ServiceLookup/IService.cs @@ -13,5 +13,7 @@ internal interface IService ServiceLifetime Lifetime { get; } IServiceCallSite CreateCallSite(ServiceProvider provider, ISet callSiteChain); + + Type ServiceType { get; } } } diff --git a/src/Microsoft.Extensions.DependencyInjection/ServiceLookup/InstanceService.cs b/src/Microsoft.Extensions.DependencyInjection/ServiceLookup/InstanceService.cs index 46e63a02..1d1c76ea 100644 --- a/src/Microsoft.Extensions.DependencyInjection/ServiceLookup/InstanceService.cs +++ b/src/Microsoft.Extensions.DependencyInjection/ServiceLookup/InstanceService.cs @@ -25,6 +25,8 @@ public ServiceLifetime Lifetime get { return Descriptor.Lifetime; } } + public Type ServiceType => Descriptor.ServiceType; + public IServiceCallSite CreateCallSite(ServiceProvider provider, ISet callSiteChain) { return this; diff --git a/src/Microsoft.Extensions.DependencyInjection/ServiceLookup/Service.cs b/src/Microsoft.Extensions.DependencyInjection/ServiceLookup/Service.cs index 9c4a43e0..4cdccd06 100644 --- a/src/Microsoft.Extensions.DependencyInjection/ServiceLookup/Service.cs +++ b/src/Microsoft.Extensions.DependencyInjection/ServiceLookup/Service.cs @@ -117,6 +117,8 @@ public IServiceCallSite CreateCallSite(ServiceProvider provider, ISet call } } + public Type ServiceType => _descriptor.ServiceType; + private bool IsSuperset(IEnumerable left, IEnumerable right) { return new HashSet(left).IsSupersetOf(right); diff --git a/src/Microsoft.Extensions.DependencyInjection/ServiceLookup/ServiceProviderService.cs b/src/Microsoft.Extensions.DependencyInjection/ServiceLookup/ServiceProviderService.cs index 60ad8f97..e894b60c 100644 --- a/src/Microsoft.Extensions.DependencyInjection/ServiceLookup/ServiceProviderService.cs +++ b/src/Microsoft.Extensions.DependencyInjection/ServiceLookup/ServiceProviderService.cs @@ -12,9 +12,11 @@ internal class ServiceProviderService : IService, IServiceCallSite public ServiceLifetime Lifetime { - get { return ServiceLifetime.Scoped; } + get { return ServiceLifetime.Transient; } } + public Type ServiceType => typeof(IServiceProvider); + public IServiceCallSite CreateCallSite(ServiceProvider provider, ISet callSiteChain) { return this; diff --git a/src/Microsoft.Extensions.DependencyInjection/ServiceLookup/ServiceScopeService.cs b/src/Microsoft.Extensions.DependencyInjection/ServiceLookup/ServiceScopeService.cs index 012a90ad..bcb006cd 100644 --- a/src/Microsoft.Extensions.DependencyInjection/ServiceLookup/ServiceScopeService.cs +++ b/src/Microsoft.Extensions.DependencyInjection/ServiceLookup/ServiceScopeService.cs @@ -15,6 +15,8 @@ public ServiceLifetime Lifetime get { return ServiceLifetime.Scoped; } } + public Type ServiceType => typeof(IServiceScopeFactory); + public IServiceCallSite CreateCallSite(ServiceProvider provider, ISet callSiteChain) { return this; diff --git a/src/Microsoft.Extensions.DependencyInjection/ServiceProvider.cs b/src/Microsoft.Extensions.DependencyInjection/ServiceProvider.cs index c7790d77..4ce8d813 100644 --- a/src/Microsoft.Extensions.DependencyInjection/ServiceProvider.cs +++ b/src/Microsoft.Extensions.DependencyInjection/ServiceProvider.cs @@ -17,6 +17,7 @@ namespace Microsoft.Extensions.DependencyInjection /// internal class ServiceProvider : IServiceProvider, IDisposable { + private readonly CallSiteValidator _callSiteValidator; private readonly ServiceTable _table; private bool _disposeCalled; private List _transientDisposables; @@ -29,9 +30,15 @@ internal class ServiceProvider : IServiceProvider, IDisposable // CallSiteRuntimeResolver is stateless so can be shared between all instances private static readonly CallSiteRuntimeResolver _callSiteRuntimeResolver = new CallSiteRuntimeResolver(); - public ServiceProvider(IEnumerable serviceDescriptors) + public ServiceProvider(IEnumerable serviceDescriptors, bool validateScopes) { Root = this; + + if (validateScopes) + { + _callSiteValidator = new CallSiteValidator(); + } + _table = new ServiceTable(serviceDescriptors); _table.Add(typeof(IServiceProvider), new ServiceProviderService()); @@ -44,6 +51,7 @@ internal ServiceProvider(ServiceProvider parent) { Root = parent.Root; _table = parent._table; + _callSiteValidator = parent._callSiteValidator; } /// @@ -54,6 +62,9 @@ internal ServiceProvider(ServiceProvider parent) public object GetService(Type serviceType) { var realizedService = _table.RealizedServices.GetOrAdd(serviceType, _createServiceAccessor, this); + + _callSiteValidator?.ValidateResolution(serviceType, this); + return realizedService.Invoke(this); } @@ -62,6 +73,7 @@ private static Func CreateServiceAccessor(Type serviceT var callSite = serviceProvider.GetServiceCallSite(serviceType, new HashSet()); if (callSite != null) { + serviceProvider._callSiteValidator?.ValidateCallSite(serviceType, callSite); return RealizeService(serviceProvider._table, serviceType, callSite); } diff --git a/test/Microsoft.Extensions.DependencyInjection.Tests/CallSiteTests.cs b/test/Microsoft.Extensions.DependencyInjection.Tests/CallSiteTests.cs index 1f261f11..0abac19b 100644 --- a/test/Microsoft.Extensions.DependencyInjection.Tests/CallSiteTests.cs +++ b/test/Microsoft.Extensions.DependencyInjection.Tests/CallSiteTests.cs @@ -82,7 +82,7 @@ public static IEnumerable TestServiceDescriptors(ServiceLifetime lifet public void BuiltExpressionWillReturnResolvedServiceWhenAppropriate( ServiceDescriptor[] desciptors, Type serviceType, Func compare) { - var provider = new ServiceProvider(desciptors); + var provider = new ServiceProvider(desciptors, validateScopes: true); var callSite = provider.GetServiceCallSite(serviceType, new HashSet()); var collectionCallSite = provider.GetServiceCallSite(typeof(IEnumerable<>).MakeGenericType(serviceType), new HashSet()); @@ -111,7 +111,7 @@ public void BuiltExpressionCanResolveNestedScopedService() descriptors.AddScoped(); descriptors.AddScoped(); - var provider = new ServiceProvider(descriptors); + var provider = new ServiceProvider(descriptors, validateScopes: true); var callSite = provider.GetServiceCallSite(typeof(ServiceC), new HashSet()); var compiledCallSite = CompileCallSite(callSite); @@ -129,7 +129,7 @@ public void BuiltExpressionRethrowsOriginalExceptionFromConstructor() descriptors.AddTransient(); descriptors.AddTransient(); - var provider = new ServiceProvider(descriptors); + var provider = new ServiceProvider(descriptors, validateScopes: true); var callSite1 = provider.GetServiceCallSite(typeof(ClassWithThrowingEmptyCtor), new HashSet()); var compiledCallSite1 = CompileCallSite(callSite1); diff --git a/test/Microsoft.Extensions.DependencyInjection.Tests/ServiceLookup/ServiceTest.cs b/test/Microsoft.Extensions.DependencyInjection.Tests/ServiceLookup/ServiceTest.cs index 23d41ea2..2036b0ac 100644 --- a/test/Microsoft.Extensions.DependencyInjection.Tests/ServiceLookup/ServiceTest.cs +++ b/test/Microsoft.Extensions.DependencyInjection.Tests/ServiceLookup/ServiceTest.cs @@ -22,7 +22,7 @@ public void CreateCallSite_Throws_IfTypeHasNoPublicConstructors() "Ensure the type is concrete and services are registered for all parameters of a public constructor."; var descriptor = new ServiceDescriptor(type, type, ServiceLifetime.Transient); var service = new Service(descriptor); - var serviceProvider = new ServiceProvider(new[] { descriptor }); + var serviceProvider = new ServiceProvider(new[] { descriptor }, validateScopes: true); // Act and Assert var ex = Assert.Throws(() => service.CreateCallSite(serviceProvider, new HashSet())); @@ -38,7 +38,7 @@ public void CreateCallSite_CreatesInstanceCallSite_IfTypeHasDefaultOrPublicParam // Arrange var descriptor = new ServiceDescriptor(type, type, ServiceLifetime.Transient); var service = new Service(descriptor); - var serviceProvider = new ServiceProvider(new[] { descriptor }); + var serviceProvider = new ServiceProvider(new[] { descriptor }, validateScopes: true); // Act var callSite = service.CreateCallSite(serviceProvider, new HashSet()); @@ -99,7 +99,7 @@ public void CreateCallSite_UsesNullaryConstructorIfServicesCannotBeInjectedIntoO var type = typeof(TypeWithParameterizedAndNullaryConstructor); var descriptor = new ServiceDescriptor(type, type, ServiceLifetime.Transient); var service = new Service(descriptor); - var serviceProvider = new ServiceProvider(new[] { descriptor }); + var serviceProvider = new ServiceProvider(new[] { descriptor }, validateScopes: true); // Act var callSite = service.CreateCallSite(serviceProvider, new HashSet()); diff --git a/test/Microsoft.Extensions.DependencyInjection.Tests/ServiceProviderValidationTests.cs b/test/Microsoft.Extensions.DependencyInjection.Tests/ServiceProviderValidationTests.cs new file mode 100644 index 00000000..552235cd --- /dev/null +++ b/test/Microsoft.Extensions.DependencyInjection.Tests/ServiceProviderValidationTests.cs @@ -0,0 +1,140 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using Xunit; + +namespace Microsoft.Extensions.DependencyInjection.Tests +{ + public class ServiceProviderValidationTests + { + [Fact] + public void GetService_Throws_WhenScopedIsInjectedIntoSingleton() + { + // Arrange + var serviceCollection = new ServiceCollection(); + serviceCollection.AddSingleton(); + serviceCollection.AddScoped(); + var serviceProvider = serviceCollection.BuildServiceProvider(validateScopes: true); + + // Act + Assert + var exception = Assert.Throws(() => serviceProvider.GetService(typeof(IFoo))); + Assert.Equal($"Cannot consume scoped service '{typeof(IBar)}' from singleton '{typeof(IFoo)}'.", exception.Message); + } + + [Fact] + public void GetService_Throws_WhenScopedIsInjectedIntoSingletonThroughTransient() + { + // Arrange + var serviceCollection = new ServiceCollection(); + serviceCollection.AddSingleton(); + serviceCollection.AddTransient(); + serviceCollection.AddScoped(); + var serviceProvider = serviceCollection.BuildServiceProvider(validateScopes: true); + + // Act + Assert + var exception = Assert.Throws(() => serviceProvider.GetService(typeof(IFoo))); + Assert.Equal($"Cannot consume scoped service '{typeof(IBaz)}' from singleton '{typeof(IFoo)}'.", exception.Message); + } + + [Fact] + public void GetService_Throws_WhenScopedIsInjectedIntoSingletonThroughSingleton() + { + // Arrange + var serviceCollection = new ServiceCollection(); + serviceCollection.AddSingleton(); + serviceCollection.AddSingleton(); + serviceCollection.AddScoped(); + var serviceProvider = serviceCollection.BuildServiceProvider(validateScopes: true); + + // Act + Assert + var exception = Assert.Throws(() => serviceProvider.GetService(typeof(IFoo))); + Assert.Equal($"Cannot consume scoped service '{typeof(IBaz)}' from singleton '{typeof(IBar)}'.", exception.Message); + } + + [Fact] + public void GetService_Throws_WhenGetServiceForScopedServiceIsCalledOnRoot() + { + // Arrange + var serviceCollection = new ServiceCollection(); + serviceCollection.AddScoped(); + var serviceProvider = serviceCollection.BuildServiceProvider(validateScopes: true); + + // Act + Assert + var exception = Assert.Throws(() => serviceProvider.GetService(typeof(IBar))); + Assert.Equal($"Cannot resolve scoped service '{typeof(IBar)}' from root provider.", exception.Message); + } + + [Fact] + public void GetService_Throws_WhenGetServiceForScopedServiceIsCalledOnRootViaTransient() + { + // Arrange + var serviceCollection = new ServiceCollection(); + serviceCollection.AddTransient(); + serviceCollection.AddScoped(); + var serviceProvider = serviceCollection.BuildServiceProvider(validateScopes: true); + + // Act + Assert + var exception = Assert.Throws(() => serviceProvider.GetService(typeof(IFoo))); + Assert.Equal($"Cannot resolve '{typeof(IFoo)}' from root provider because it requires scoped service '{typeof(IBar)}'.", exception.Message); + } + + [Fact] + public void GetService_DoesNotThrow_WhenScopeFactoryIsInjectedIntoSingleton() + { + // Arrange + var serviceCollection = new ServiceCollection(); + serviceCollection.AddSingleton(); + var serviceProvider = serviceCollection.BuildServiceProvider(true); + + // Act + Assert + var result = serviceProvider.GetService(typeof(IBoo)); + Assert.NotNull(result); + } + + private interface IFoo + { + } + + private class Foo : IFoo + { + public Foo(IBar bar) + { + } + } + + private interface IBar + { + } + + private class Bar : IBar + { + } + + private class Bar2 : IBar + { + public Bar2(IBaz baz) + { + } + } + + private interface IBaz + { + } + + private class Baz : IBaz + { + } + + private interface IBoo + { + } + + private class Boo : IBoo + { + public Boo(IServiceScopeFactory scopeFactory) + { + } + } + } +}