Skip to content
This repository was archived by the owner on Nov 2, 2018. It is now read-only.

Validate scope into singleton sevice injection #430

Merged
merged 4 commits into from
Jul 15, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 9 additions & 0 deletions src/Microsoft.Extensions.DependencyInjection/Resources.resx
Original file line number Diff line number Diff line change
Expand Up @@ -141,4 +141,13 @@
<value>A suitable constructor for type '{0}' could not be located. Ensure the type is concrete and services are registered for all parameters of a public constructor.</value>
<comment>{0} = service type</comment>
</data>
<data name="ScopedInSingletonException" xml:space="preserve">
<value>Cannot consume {2} service '{0}' from {3} '{1}'.</value>
</data>
<data name="ScopedResolvedFromRootException" xml:space="preserve">
<value>Cannot resolve '{0}' from root provider because it requires {2} service '{1}'.</value>
</data>
<data name="DirectScopedResolvedFromRootException" xml:space="preserve">
<value>Cannot resolve {1} service '{0}' from root provider.</value>
</data>
</root>
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,29 @@ namespace Microsoft.Extensions.DependencyInjection
{
public static class ServiceCollectionContainerBuilderExtensions
{
/// <summary>
/// Creates an <see cref="IServiceProvider"/> containing services from the provided <see cref="IServiceCollection"/>.
/// </summary>
/// <param name="services">The <see cref="IServiceCollection"/> containing service descriptors.</param>
/// <returns>The<see cref="IServiceProvider"/>.</returns>

public static IServiceProvider BuildServiceProvider(this IServiceCollection services)
{
return new ServiceProvider(services);
return BuildServiceProvider(services, validateScopes: false);
}

/// <summary>
/// Creates an <see cref="IServiceProvider"/> containing services from the provided <see cref="IServiceCollection"/>
/// optionaly enabling scope validation.
/// </summary>
/// <param name="services">The <see cref="IServiceCollection"/> containing service descriptors.</param>
/// <param name="validateScopes">
/// <c>true</c> to perform check verifying that scoped services never gets resolved from root provider; otherwise <c>false</c>.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might as well point out false is the default behavior.

/// </param>
/// <returns>The<see cref="IServiceProvider"/>.</returns>
public static IServiceProvider BuildServiceProvider(this IServiceCollection services, bool validateScopes)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Time for doc comments? On one hand, I kind of hope no one ever uses this (except for may us). On the other hand, without context, I would have no idea what validateScopes is even about.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the other method + type too since they're all public.

{
return new ServiceProvider(services, validateScopes);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Collections.Generic;

namespace Microsoft.Extensions.DependencyInjection.ServiceLookup
{
internal class CallSiteValidator: CallSiteVisitor<CallSiteValidator.CallSiteValidatorState, Type>
{
// Keys are services being resolved via GetService, values - first scoped service in their call site tree
private readonly Dictionary<Type, Type> _scopedServices = new Dictionary<Type, Type>();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comment to indicate the key is the service being resolved via GetService and the value is a single scoped service that it depends on (but there may be more.


public void ValidateCallSite(Type serviceType, IServiceCallSite callSite)
{
var scoped = VisitCallSite(callSite, default(CallSiteValidatorState));
if (scoped != null)
{
_scopedServices.Add(serviceType, scoped);
}
}

public void ValidateResolution(Type serviceType, ServiceProvider serviceProvider)
{
Type scopedService;
if (ReferenceEquals(serviceProvider, serviceProvider.Root)
&& _scopedServices.TryGetValue(serviceType, out scopedService))
{
if (serviceType == scopedService)
{
throw new InvalidOperationException(
Resources.FormatDirectScopedResolvedFromRootException(serviceType,
nameof(ServiceLifetime.Scoped).ToLowerInvariant()));
}

throw new InvalidOperationException(
Resources.FormatScopedResolvedFromRootException(
serviceType,
scopedService,
nameof(ServiceLifetime.Scoped).ToLowerInvariant()));
}
}

protected override Type VisitTransient(TransientCallSite transientCallSite, CallSiteValidatorState state)
{
return VisitCallSite(transientCallSite.Service, state);
}

protected override Type VisitConstructor(ConstructorCallSite constructorCallSite, CallSiteValidatorState state)
{
Type result = null;
foreach (var parameterCallSite in constructorCallSite.ParameterCallSites)
{
var scoped = VisitCallSite(parameterCallSite, state);
if (result == null)
{
result = scoped;
}
}
return result;
}

protected override Type VisitClosedIEnumerable(ClosedIEnumerableCallSite closedIEnumerableCallSite,
CallSiteValidatorState state)
{
Type result = null;
foreach (var serviceCallSite in closedIEnumerableCallSite.ServiceCallSites)
{
var scoped = VisitCallSite(serviceCallSite, state);
if (result == null)
{
result = scoped;
}
}
return result;
}

protected override Type VisitSingleton(SingletonCallSite singletonCallSite, CallSiteValidatorState state)
{
state.Singleton = singletonCallSite;
return VisitCallSite(singletonCallSite.ServiceCallSite, state);
}

protected override Type VisitScoped(ScopedCallSite scopedCallSite, CallSiteValidatorState state)
{
// We are fine with having ServiceScopeService requested by singletons
if (scopedCallSite.ServiceCallSite is ServiceScopeService)
{
return null;
}
if (state.Singleton != null)
{
throw new InvalidOperationException(Resources.FormatScopedInSingletonException(
scopedCallSite.Key.ServiceType,
state.Singleton.Key.ServiceType,
nameof(ServiceLifetime.Scoped).ToLowerInvariant(),
nameof(ServiceLifetime.Singleton).ToLowerInvariant()
));
}
return scopedCallSite.Key.ServiceType;
}

protected override Type VisitConstant(ConstantCallSite constantCallSite, CallSiteValidatorState state) => null;

protected override Type VisitCreateInstance(CreateInstanceCallSite createInstanceCallSite, CallSiteValidatorState state) => null;

protected override Type VisitInstanceService(InstanceService instanceCallSite, CallSiteValidatorState state) => null;

protected override Type VisitServiceProviderService(ServiceProviderService serviceProviderService, CallSiteValidatorState state) => null;

protected override Type VisitEmptyIEnumerable(EmptyIEnumerableCallSite emptyIEnumerableCallSite, CallSiteValidatorState state) => null;

protected override Type VisitServiceScopeService(ServiceScopeService serviceScopeService, CallSiteValidatorState state) => null;

protected override Type VisitFactoryService(FactoryService factoryService, CallSiteValidatorState state) => null;

internal struct CallSiteValidatorState
{
public SingletonCallSite Singleton { get; set; }
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ public ServiceLifetime Lifetime
get { return ServiceLifetime.Transient; }
}

public Type ServiceType => _itemType;

public IServiceCallSite CreateCallSite(ServiceProvider provider, ISet<Type> callSiteChain)
{
var list = new List<IServiceCallSite>();
Expand All @@ -36,6 +38,5 @@ public IServiceCallSite CreateCallSite(ServiceProvider provider, ISet<Type> call
}
return new ClosedIEnumerableCallSite(_itemType, list.ToArray());
}

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ public ServiceLifetime Lifetime
get { return Descriptor.Lifetime; }
}

public Type ServiceType => Descriptor.ServiceType;

public IServiceCallSite CreateCallSite(ServiceProvider provider, ISet<Type> callSiteChain)
{
return this;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,7 @@ internal interface IService
ServiceLifetime Lifetime { get; }

IServiceCallSite CreateCallSite(ServiceProvider provider, ISet<Type> callSiteChain);

Type ServiceType { get; }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ public ServiceLifetime Lifetime
get { return Descriptor.Lifetime; }
}

public Type ServiceType => Descriptor.ServiceType;

public IServiceCallSite CreateCallSite(ServiceProvider provider, ISet<Type> callSiteChain)
{
return this;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,8 @@ public IServiceCallSite CreateCallSite(ServiceProvider provider, ISet<Type> call
}
}

public Type ServiceType => _descriptor.ServiceType;

private bool IsSuperset(IEnumerable<Type> left, IEnumerable<Type> right)
{
return new HashSet<Type>(left).IsSupersetOf(right);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@ internal class ServiceProviderService : IService, IServiceCallSite

public ServiceLifetime Lifetime
{
get { return ServiceLifetime.Scoped; }
get { return ServiceLifetime.Transient; }
}

public Type ServiceType => typeof(IServiceProvider);

public IServiceCallSite CreateCallSite(ServiceProvider provider, ISet<Type> callSiteChain)
{
return this;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ public ServiceLifetime Lifetime
get { return ServiceLifetime.Scoped; }
}

public Type ServiceType => typeof(IServiceScopeFactory);

public IServiceCallSite CreateCallSite(ServiceProvider provider, ISet<Type> callSiteChain)
{
return this;
Expand Down
14 changes: 13 additions & 1 deletion src/Microsoft.Extensions.DependencyInjection/ServiceProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ namespace Microsoft.Extensions.DependencyInjection
/// </summary>
internal class ServiceProvider : IServiceProvider, IDisposable
{
private readonly CallSiteValidator _callSiteValidator;
private readonly ServiceTable _table;
private bool _disposeCalled;
private List<IDisposable> _transientDisposables;
Expand All @@ -29,9 +30,15 @@ internal class ServiceProvider : IServiceProvider, IDisposable
// CallSiteRuntimeResolver is stateless so can be shared between all instances
private static readonly CallSiteRuntimeResolver _callSiteRuntimeResolver = new CallSiteRuntimeResolver();

public ServiceProvider(IEnumerable<ServiceDescriptor> serviceDescriptors)
public ServiceProvider(IEnumerable<ServiceDescriptor> serviceDescriptors, bool validateScopes)
{
Root = this;

if (validateScopes)
{
_callSiteValidator = new CallSiteValidator();
}

_table = new ServiceTable(serviceDescriptors);

_table.Add(typeof(IServiceProvider), new ServiceProviderService());
Expand All @@ -44,6 +51,7 @@ internal ServiceProvider(ServiceProvider parent)
{
Root = parent.Root;
_table = parent._table;
_callSiteValidator = parent._callSiteValidator;
}

/// <summary>
Expand All @@ -54,6 +62,9 @@ internal ServiceProvider(ServiceProvider parent)
public object GetService(Type serviceType)
{
var realizedService = _table.RealizedServices.GetOrAdd(serviceType, _createServiceAccessor, this);

_callSiteValidator?.ValidateResolution(serviceType, this);

return realizedService.Invoke(this);
}

Expand All @@ -62,6 +73,7 @@ private static Func<ServiceProvider, object> CreateServiceAccessor(Type serviceT
var callSite = serviceProvider.GetServiceCallSite(serviceType, new HashSet<Type>());
if (callSite != null)
{
serviceProvider._callSiteValidator?.ValidateCallSite(serviceType, callSite);
return RealizeService(serviceProvider._table, serviceType, callSite);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ public static IEnumerable<object[]> TestServiceDescriptors(ServiceLifetime lifet
public void BuiltExpressionWillReturnResolvedServiceWhenAppropriate(
ServiceDescriptor[] desciptors, Type serviceType, Func<object, object, bool> compare)
{
var provider = new ServiceProvider(desciptors);
var provider = new ServiceProvider(desciptors, validateScopes: true);

var callSite = provider.GetServiceCallSite(serviceType, new HashSet<Type>());
var collectionCallSite = provider.GetServiceCallSite(typeof(IEnumerable<>).MakeGenericType(serviceType), new HashSet<Type>());
Expand Down Expand Up @@ -111,7 +111,7 @@ public void BuiltExpressionCanResolveNestedScopedService()
descriptors.AddScoped<ServiceB>();
descriptors.AddScoped<ServiceC>();

var provider = new ServiceProvider(descriptors);
var provider = new ServiceProvider(descriptors, validateScopes: true);
var callSite = provider.GetServiceCallSite(typeof(ServiceC), new HashSet<Type>());
var compiledCallSite = CompileCallSite(callSite);

Expand All @@ -129,7 +129,7 @@ public void BuiltExpressionRethrowsOriginalExceptionFromConstructor()
descriptors.AddTransient<ClassWithThrowingCtor>();
descriptors.AddTransient<IFakeService, FakeService>();

var provider = new ServiceProvider(descriptors);
var provider = new ServiceProvider(descriptors, validateScopes: true);

var callSite1 = provider.GetServiceCallSite(typeof(ClassWithThrowingEmptyCtor), new HashSet<Type>());
var compiledCallSite1 = CompileCallSite(callSite1);
Expand Down
Loading