Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Efcore issue 31178 #89181

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -316,13 +316,9 @@ void AddCallSite(ServiceCallSite callSite, int index)
callSitesByIndex.Add(new(index, callSite));
}
}

ResultCache resultCache = ResultCache.None;
if (cacheLocation == CallSiteResultCacheLocation.Scope || cacheLocation == CallSiteResultCacheLocation.Root)
{
resultCache = new ResultCache(cacheLocation, callSiteKey);
}

ResultCache resultCache = (cacheLocation == CallSiteResultCacheLocation.Scope || cacheLocation == CallSiteResultCacheLocation.Root)
? new ResultCache(cacheLocation, callSiteKey)
: new ResultCache(CallSiteResultCacheLocation.None, callSiteKey);
return _callSiteCache[callSiteKey] = new IEnumerableCallSite(resultCache, itemType, callSites);
}
finally
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,22 +10,23 @@ namespace Microsoft.Extensions.DependencyInjection.ServiceLookup
internal sealed class CallSiteValidator: CallSiteVisitor<CallSiteValidator.CallSiteValidatorState, Type?>
{
// Keys are services being resolved via GetService, values - first scoped service in their call site tree
private readonly ConcurrentDictionary<Type, Type> _scopedServices = new ConcurrentDictionary<Type, Type>();
private readonly ConcurrentDictionary<ServiceCacheKey, Type> _scopedServices = new ConcurrentDictionary<ServiceCacheKey, Type>();

public void ValidateCallSite(ServiceCallSite callSite)
{
Type? scoped = VisitCallSite(callSite, default);
if (scoped != null)
{
_scopedServices[callSite.ServiceType] = scoped;
_scopedServices[callSite.Cache.Key] = scoped;
}
}

public void ValidateResolution(Type serviceType, IServiceScope scope, IServiceScope rootScope)
public void ValidateResolution(ServiceCallSite callSite, IServiceScope scope, IServiceScope rootScope)
{
if (ReferenceEquals(scope, rootScope)
&& _scopedServices.TryGetValue(serviceType, out Type? scopedService))
&& _scopedServices.TryGetValue(callSite.Cache.Key, out Type? scopedService))
{
Type serviceType = callSite.ServiceType;
if (serviceType == scopedService)
{
throw new InvalidOperationException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ internal sealed class ConstantCallSite : ServiceCallSite
private readonly Type _serviceType;
internal object? DefaultValue => Value;

public ConstantCallSite(Type serviceType, object? defaultValue): base(ResultCache.None)
public ConstantCallSite(Type serviceType, object? defaultValue) : base(ResultCache.None(serviceType))
{
_serviceType = serviceType ?? throw new ArgumentNullException(nameof(serviceType));
if (defaultValue != null && !serviceType.IsInstanceOfType(defaultValue))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,11 @@ namespace Microsoft.Extensions.DependencyInjection.ServiceLookup
{
internal struct ResultCache
{
public static ResultCache None { get; } = new ResultCache(CallSiteResultCacheLocation.None, ServiceCacheKey.Empty);
public static ResultCache None(Type serviceType)
{
var cacheKey = new ServiceCacheKey(serviceType, 0);
return new ResultCache(CallSiteResultCacheLocation.None, cacheKey);
}

internal ResultCache(CallSiteResultCacheLocation lifetime, ServiceCacheKey cacheKey)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@ namespace Microsoft.Extensions.DependencyInjection.ServiceLookup
{
internal readonly struct ServiceCacheKey : IEquatable<ServiceCacheKey>
{
public static ServiceCacheKey Empty { get; } = new ServiceCacheKey(null, 0);

/// <summary>
/// Type of service being cached
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ namespace Microsoft.Extensions.DependencyInjection.ServiceLookup
{
internal sealed class ServiceProviderCallSite : ServiceCallSite
{
public ServiceProviderCallSite() : base(ResultCache.None)
public ServiceProviderCallSite() : base(ResultCache.None(typeof(IServiceProvider)))
{
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@ public sealed class ServiceProvider : IServiceProvider, IDisposable, IAsyncDispo
{
private readonly CallSiteValidator? _callSiteValidator;

private readonly Func<Type, Func<ServiceProviderEngineScope, object?>> _createServiceAccessor;
private readonly Func<Type, ServiceAccessor> _createServiceAccessor;

// Internal for testing
internal ServiceProviderEngine _engine;

private bool _disposed;

private readonly ConcurrentDictionary<Type, Func<ServiceProviderEngineScope, object?>> _realizedServices;
private readonly ConcurrentDictionary<Type, ServiceAccessor> _realizedServices;

internal CallSiteFactory CallSiteFactory { get; }

Expand All @@ -47,7 +47,7 @@ internal ServiceProvider(ICollection<ServiceDescriptor> serviceDescriptors, Serv
Root = new ServiceProviderEngineScope(this, isRootScope: true);
_engine = GetEngine();
_createServiceAccessor = CreateServiceAccessor;
_realizedServices = new ConcurrentDictionary<Type, Func<ServiceProviderEngineScope, object?>>();
_realizedServices = new ConcurrentDictionary<Type, ServiceAccessor>();

CallSiteFactory = new CallSiteFactory(serviceDescriptors);
// The list of built in services that aren't part of the list of service descriptors
Expand Down Expand Up @@ -120,9 +120,12 @@ private void OnCreate(ServiceCallSite callSite)
_callSiteValidator?.ValidateCallSite(callSite);
}

private void OnResolve(Type serviceType, IServiceScope scope)
private void OnResolve(ServiceCallSite? callSite, IServiceScope scope)
{
_callSiteValidator?.ValidateResolution(serviceType, scope, Root);
if (callSite != null)
{
_callSiteValidator?.ValidateResolution(callSite, scope, Root);
}
}

internal object? GetService(Type serviceType, ServiceProviderEngineScope serviceProviderEngineScope)
Expand All @@ -132,10 +135,10 @@ private void OnResolve(Type serviceType, IServiceScope scope)
ThrowHelper.ThrowObjectDisposedException();
}

Func<ServiceProviderEngineScope, object?> realizedService = _realizedServices.GetOrAdd(serviceType, _createServiceAccessor);
OnResolve(serviceType, serviceProviderEngineScope);
ServiceAccessor realizedService = _realizedServices.GetOrAdd(serviceType, _createServiceAccessor);
OnResolve(realizedService.CallSite, serviceProviderEngineScope);
DependencyInjectionEventSource.Log.ServiceResolved(this, serviceType);
var result = realizedService.Invoke(serviceProviderEngineScope);
var result = realizedService.RealizedService.Invoke(serviceProviderEngineScope);
System.Diagnostics.Debug.Assert(result is null || CallSiteFactory.IsService(serviceType));
return result;
}
Expand All @@ -161,7 +164,7 @@ private void ValidateService(ServiceDescriptor descriptor)
}
}

private Func<ServiceProviderEngineScope, object?> CreateServiceAccessor(Type serviceType)
private ServiceAccessor CreateServiceAccessor(Type serviceType)
{
ServiceCallSite? callSite = CallSiteFactory.GetCallSite(serviceType, new CallSiteChain());
if (callSite != null)
Expand All @@ -173,18 +176,17 @@ private void ValidateService(ServiceDescriptor descriptor)
if (callSite.Cache.Location == CallSiteResultCacheLocation.Root)
{
object? value = CallSiteRuntimeResolver.Instance.Resolve(callSite, Root);
return scope => value;
return new ServiceAccessor { CallSite = callSite, RealizedService = scope => value };
}

return _engine.RealizeService(callSite);
return new ServiceAccessor { CallSite = callSite, RealizedService = _engine.RealizeService(callSite) };
}

return _ => null;
return new ServiceAccessor { CallSite = callSite, RealizedService = _ => null };
}

internal void ReplaceServiceAccessor(ServiceCallSite callSite, Func<ServiceProviderEngineScope, object?> accessor)
{
_realizedServices[callSite.ServiceType] = accessor;
_realizedServices[callSite.ServiceType] = new ServiceAccessor { CallSite = callSite, RealizedService = accessor };
}

internal IServiceScope CreateScope()
Expand Down Expand Up @@ -220,5 +222,11 @@ private ServiceProviderEngine GetEngine()
Justification = "CreateDynamicEngine won't be called when using NativeAOT.")] // see also https://github.com/dotnet/linker/issues/2715
ServiceProviderEngine CreateDynamicEngine() => new DynamicServiceProviderEngine(this);
}

private struct ServiceAccessor
{
public ServiceCallSite? CallSite { get; set; }
public Func<ServiceProviderEngineScope, object?> RealizedService { get; set; }
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -786,17 +786,10 @@ public void CreateCallSite_EnumberableCachedAtLowestLevel(ServiceDescriptor[] de
var callSite = factory(typeof(IEnumerable<FakeService>));

var expectedLocation = (CallSiteResultCacheLocation)expectedCacheLocation;
Assert.Equal(expectedLocation, callSite.Cache.Location);

if (expectedLocation != CallSiteResultCacheLocation.None)
{
Assert.Equal(0, callSite.Cache.Key.Slot);
Assert.Equal(typeof(IEnumerable<FakeService>), callSite.Cache.Key.Type);
}
else
{
Assert.Equal(ResultCache.None, callSite.Cache);
}
Assert.Equal(expectedLocation, callSite.Cache.Location);
Assert.Equal(0, callSite.Cache.Key.Slot);
Assert.Equal(typeof(IEnumerable<FakeService>), callSite.Cache.Key.Type);
}

[ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.Extensions.DependencyInjection.Specification.Fakes;
using Xunit;

Expand Down Expand Up @@ -97,6 +99,49 @@ public void GetService_Throws_WhenGetServiceForScopedServiceIsCalledOnRootViaTra
Assert.Equal($"Cannot resolve '{typeof(IFoo)}' from root provider because it requires scoped service '{typeof(IBar)}'.", exception.Message);
}

[Theory]
[InlineData(true)]
[InlineData(false)]
public void GetService_DoesNotThrow_WhenGetServiceForPolymorphicServiceIsCalledOnRoot_AndTheLastOneIsNotScoped(bool validateOnBuild)
{
// Arrange
var serviceCollection = new ServiceCollection();
serviceCollection.AddScoped<IBar, Bar>();
serviceCollection.AddTransient<IBar, Bar3>();
using var serviceProvider = serviceCollection.BuildServiceProvider(new ServiceProviderOptions
{
ValidateScopes = true,
ValidateOnBuild = validateOnBuild
});

// Act
var actual = serviceProvider.GetService<IBar>();

// Assert
Assert.IsType<Bar3>(actual);
}

[Fact]
public void ScopeValidation_ShouldBeAbleToDistingushGenericCollections_WhenGetServiceIsCalledOnRoot()
{
// Arrange
var serviceCollection = new ServiceCollection();
serviceCollection.AddTransient<IBar, Bar>();
serviceCollection.AddScoped<IBar, Bar3>();

serviceCollection.AddTransient<IBaz, Baz>();
serviceCollection.AddTransient<IBaz, Baz2>();

// Act
using var serviceProvider = serviceCollection.BuildServiceProvider(validateScopes: true);
Assert.Throws<InvalidOperationException>(() => serviceProvider.GetService<IEnumerable<IBar>>());
var actual = serviceProvider.GetService<IEnumerable<IBaz>>();

// Assert
Assert.IsType<Baz>(actual.First());
Assert.IsType<Baz2>(actual.Last());
}

[Fact]
public void GetService_DoesNotThrow_WhenScopeFactoryIsInjectedIntoSingleton()
{
Expand Down Expand Up @@ -206,13 +251,18 @@ private class Bar : IBar
{
}


private class Bar2 : IBar
{
public Bar2(IBaz baz)
{
}
}

private class Bar3 : IBar
{
}

private interface IBaz
{
}
Expand All @@ -221,6 +271,10 @@ private class Baz : IBaz
{
}

private class Baz2 : IBaz
{
}

private class BazRecursive : IBaz
{
public BazRecursive(IBaz baz)
Expand Down