diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection.Abstractions/src/ActivatorUtilities.cs b/src/libraries/Microsoft.Extensions.DependencyInjection.Abstractions/src/ActivatorUtilities.cs index eb7931489b557..6a42373ccc77d 100644 --- a/src/libraries/Microsoft.Extensions.DependencyInjection.Abstractions/src/ActivatorUtilities.cs +++ b/src/libraries/Microsoft.Extensions.DependencyInjection.Abstractions/src/ActivatorUtilities.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. using System; +using System.Collections.Concurrent; using System.Diagnostics; using System.Diagnostics.CodeAnalysis; using System.Linq.Expressions; @@ -10,6 +11,10 @@ using System.Runtime.ExceptionServices; using Microsoft.Extensions.Internal; +#if NETCOREAPP +[assembly: System.Reflection.Metadata.MetadataUpdateHandler(typeof(Microsoft.Extensions.DependencyInjection.ActivatorUtilities.ActivatorUtilitiesUpdateHandler))] +#endif + namespace Microsoft.Extensions.DependencyInjection { /// @@ -17,6 +22,14 @@ namespace Microsoft.Extensions.DependencyInjection /// public static class ActivatorUtilities { +#if NETCOREAPP + // Support caching of constructor metadata for the common case of types in non-collectible assemblies. + private static readonly ConcurrentDictionary s_constructorInfos = new(); + + // Support caching of constructor metadata for types in collectible assemblies. + private static readonly Lazy> s_collectibleConstructorInfos = new(); +#endif + #if NET8_0_OR_GREATER // Maximum number of fixed arguments for ConstructorInvoker.Invoke(arg1, etc). private const int FixedArgumentThreshold = 4; @@ -47,6 +60,17 @@ public static object CreateInstance( throw new InvalidOperationException(SR.CannotCreateAbstractClasses); } + ConstructorInfoEx[]? constructors; +#if NETCOREAPP + if (!s_constructorInfos.TryGetValue(instanceType, out constructors)) + { + constructors = GetOrAddConstructors(instanceType); + } +#else + constructors = CreateConstructorInfoExs(instanceType); +#endif + + ConstructorInfoEx? constructor; IServiceProviderIsService? serviceProviderIsService = provider.GetService(); // if container supports using IServiceProviderIsService, we try to find the longest ctor that // (a) matches all parameters given to CreateInstance @@ -61,10 +85,11 @@ public static object CreateInstance( ConstructorMatcher bestMatcher = default; bool multipleBestLengthFound = false; - foreach (ConstructorInfo? constructor in instanceType.GetConstructors()) + for (int i = 0; i < constructors.Length; i++) { - var matcher = new ConstructorMatcher(constructor); - bool isPreferred = constructor.IsDefined(typeof(ActivatorUtilitiesConstructorAttribute), false); + constructor = constructors[i]; + ConstructorMatcher matcher = new(constructor); + bool isPreferred = constructor.IsPreferred; int length = matcher.Match(parameters, serviceProviderIsService); if (isPreferred) @@ -105,18 +130,79 @@ public static object CreateInstance( } } - Type?[] argumentTypes = new Type[parameters.Length]; - for (int i = 0; i < argumentTypes.Length; i++) + Type?[] argumentTypes; + if (parameters.Length == 0) { - argumentTypes[i] = parameters[i]?.GetType(); + argumentTypes = Type.EmptyTypes; + } + else + { + argumentTypes = new Type[parameters.Length]; + for (int i = 0; i < argumentTypes.Length; i++) + { + argumentTypes[i] = parameters[i]?.GetType(); + } } FindApplicableConstructor(instanceType, argumentTypes, out ConstructorInfo constructorInfo, out int?[] parameterMap); - var constructorMatcher = new ConstructorMatcher(constructorInfo); + + // Find the ConstructorInfoEx from the given constructorInfo. + constructor = null; + foreach (ConstructorInfoEx ctor in constructors) + { + if (ReferenceEquals(ctor.Info, constructorInfo)) + { + constructor = ctor; + break; + } + } + + Debug.Assert(constructor != null); + + var constructorMatcher = new ConstructorMatcher(constructor); constructorMatcher.MapParameters(parameterMap, parameters); return constructorMatcher.CreateInstance(provider); } +#if NETCOREAPP + private static ConstructorInfoEx[] GetOrAddConstructors( + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)] Type type) + { + // Not found. Do the slower work of checking for the value in the correct cache. + // Null and non-collectible load contexts use the default cache. + if (!type.Assembly.IsCollectible) + { + return s_constructorInfos.GetOrAdd(type, CreateConstructorInfoExs(type)); + } + + // Collectible load contexts should use the ConditionalWeakTable so they can be unloaded. + if (s_collectibleConstructorInfos.Value.TryGetValue(type, out ConstructorInfoEx[]? value)) + { + return value; + } + + value = CreateConstructorInfoExs(type); + + // ConditionalWeakTable doesn't support GetOrAdd() so use AddOrUpdate(). This means threads + // can have different instances for the same type, but that is OK since they are equivalent. + s_collectibleConstructorInfos.Value.AddOrUpdate(type, value); + return value; + } +#endif // NETCOREAPP + + private static ConstructorInfoEx[] CreateConstructorInfoExs( + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)] Type type) + { + ConstructorInfo[] constructors = type.GetConstructors(); + ConstructorInfoEx[]? value = new ConstructorInfoEx[constructors.Length]; + for (int i = 0; i < constructors.Length; i++) + { + value[i] = new ConstructorInfoEx(constructors[i]); + } + + return value; + } + /// /// Create a delegate that will instantiate a type with constructor arguments provided directly /// and/or from an . @@ -551,58 +637,82 @@ private static bool TryCreateParameterMap(ParameterInfo[] constructorParameters, return true; } - private static object? GetService(IServiceProvider serviceProvider, ParameterInfo parameterInfo) + private sealed class ConstructorInfoEx { - // Handle keyed service - if (TryGetServiceKey(parameterInfo, out object? key)) + public readonly ConstructorInfo Info; + public readonly ParameterInfo[] Parameters; + public readonly bool IsPreferred; + private readonly object?[]? _parameterKeys; + + public ConstructorInfoEx(ConstructorInfo constructor) { - if (serviceProvider is IKeyedServiceProvider keyedServiceProvider) + Info = constructor; + Parameters = constructor.GetParameters(); + IsPreferred = constructor.IsDefined(typeof(ActivatorUtilitiesConstructorAttribute), inherit: false); + + for (int i = 0; i < Parameters.Length; i++) { - return keyedServiceProvider.GetKeyedService(parameterInfo.ParameterType, key); + FromKeyedServicesAttribute? attr = (FromKeyedServicesAttribute?) + Attribute.GetCustomAttribute(Parameters[i], typeof(FromKeyedServicesAttribute), inherit: false); + + if (attr is not null) + { + _parameterKeys ??= new object?[Parameters.Length]; + _parameterKeys[i] = attr.Key; + } } - throw new InvalidOperationException(SR.KeyedServicesNotSupported); } - // Try non keyed service - return serviceProvider.GetService(parameterInfo.ParameterType); - } - private static bool IsService(IServiceProviderIsService serviceProviderIsService, ParameterInfo parameterInfo) - { - // Handle keyed service - if (TryGetServiceKey(parameterInfo, out object? key)) + public bool IsService(IServiceProviderIsService serviceProviderIsService, int parameterIndex) { - if (serviceProviderIsService is IServiceProviderIsKeyedService serviceProviderIsKeyedService) + ParameterInfo parameterInfo = Parameters[parameterIndex]; + + // Handle keyed service + object? key = _parameterKeys?[parameterIndex]; + if (key is not null) { - return serviceProviderIsKeyedService.IsKeyedService(parameterInfo.ParameterType, key); + if (serviceProviderIsService is IServiceProviderIsKeyedService serviceProviderIsKeyedService) + { + return serviceProviderIsKeyedService.IsKeyedService(parameterInfo.ParameterType, key); + } + + throw new InvalidOperationException(SR.KeyedServicesNotSupported); } - throw new InvalidOperationException(SR.KeyedServicesNotSupported); + + // Use non-keyed service + return serviceProviderIsService.IsService(parameterInfo.ParameterType); } - // Try non keyed service - return serviceProviderIsService.IsService(parameterInfo.ParameterType); - } - private static bool TryGetServiceKey(ParameterInfo parameterInfo, out object? key) - { - foreach (var attribute in parameterInfo.GetCustomAttributes(false)) + public object? GetService(IServiceProvider serviceProvider, int parameterIndex) { - key = attribute.Key; - return true; + ParameterInfo parameterInfo = Parameters[parameterIndex]; + + // Handle keyed service + object? key = _parameterKeys?[parameterIndex]; + if (key is not null) + { + if (serviceProvider is IKeyedServiceProvider keyedServiceProvider) + { + return keyedServiceProvider.GetKeyedService(parameterInfo.ParameterType, key); + } + + throw new InvalidOperationException(SR.KeyedServicesNotSupported); + } + + // Use non-keyed service + return serviceProvider.GetService(parameterInfo.ParameterType); } - key = null; - return false; } private readonly struct ConstructorMatcher { - private readonly ConstructorInfo _constructor; - private readonly ParameterInfo[] _parameters; + private readonly ConstructorInfoEx _constructor; private readonly object?[] _parameterValues; - public ConstructorMatcher(ConstructorInfo constructor) + public ConstructorMatcher(ConstructorInfoEx constructor) { _constructor = constructor; - _parameters = _constructor.GetParameters(); - _parameterValues = new object?[_parameters.Length]; + _parameterValues = new object[constructor.Parameters.Length]; } public int Match(object[] givenParameters, IServiceProviderIsService serviceProviderIsService) @@ -612,10 +722,10 @@ public int Match(object[] givenParameters, IServiceProviderIsService serviceProv Type? givenType = givenParameters[givenIndex]?.GetType(); bool givenMatched = false; - for (int applyIndex = 0; applyIndex < _parameters.Length; applyIndex++) + for (int applyIndex = 0; applyIndex < _constructor.Parameters.Length; applyIndex++) { if (_parameterValues[applyIndex] == null && - _parameters[applyIndex].ParameterType.IsAssignableFrom(givenType)) + _constructor.Parameters[applyIndex].ParameterType.IsAssignableFrom(givenType)) { givenMatched = true; _parameterValues[applyIndex] = givenParameters[givenIndex]; @@ -630,12 +740,12 @@ public int Match(object[] givenParameters, IServiceProviderIsService serviceProv } // confirms the rest of ctor arguments match either as a parameter with a default value or as a service registered - for (int i = 0; i < _parameters.Length; i++) + for (int i = 0; i < _constructor.Parameters.Length; i++) { if (_parameterValues[i] == null && - !IsService(serviceProviderIsService, _parameters[i])) + !_constructor.IsService(serviceProviderIsService, i)) { - if (ParameterDefaultValue.TryGetDefaultValue(_parameters[i], out object? defaultValue)) + if (ParameterDefaultValue.TryGetDefaultValue(_constructor.Parameters[i], out object? defaultValue)) { _parameterValues[i] = defaultValue; } @@ -646,21 +756,21 @@ public int Match(object[] givenParameters, IServiceProviderIsService serviceProv } } - return _parameters.Length; + return _constructor.Parameters.Length; } public object CreateInstance(IServiceProvider provider) { - for (int index = 0; index < _parameters.Length; index++) + for (int index = 0; index < _constructor.Parameters.Length; index++) { if (_parameterValues[index] == null) { - object? value = GetService(provider, _parameters[index]); + object? value = _constructor.GetService(provider, index); if (value == null) { - if (!ParameterDefaultValue.TryGetDefaultValue(_parameters[index], out object? defaultValue)) + if (!ParameterDefaultValue.TryGetDefaultValue(_constructor.Parameters[index], out object? defaultValue)) { - throw new InvalidOperationException(SR.Format(SR.UnableToResolveService, _parameters[index].ParameterType, _constructor.DeclaringType)); + throw new InvalidOperationException(SR.Format(SR.UnableToResolveService, _constructor.Parameters[index].ParameterType, _constructor.Info.DeclaringType)); } else { @@ -677,7 +787,7 @@ public object CreateInstance(IServiceProvider provider) #if NETFRAMEWORK || NETSTANDARD2_0 try { - return _constructor.Invoke(_parameterValues); + return _constructor.Info.Invoke(_parameterValues); } catch (TargetInvocationException ex) when (ex.InnerException != null) { @@ -686,13 +796,13 @@ public object CreateInstance(IServiceProvider provider) throw; } #else - return _constructor.Invoke(BindingFlags.DoNotWrapExceptions, binder: null, parameters: _parameterValues, culture: null); + return _constructor.Info.Invoke(BindingFlags.DoNotWrapExceptions, binder: null, parameters: _parameterValues, culture: null); #endif } public void MapParameters(int?[] parameterMap, object[] givenParameters) { - for (int i = 0; i < _parameters.Length; i++) + for (int i = 0; i < _constructor.Parameters.Length; i++) { if (parameterMap[i] != null) { @@ -974,5 +1084,20 @@ private static object ReflectionFactoryCanonical( return constructor.Invoke(BindingFlags.DoNotWrapExceptions, binder: null, constructorArguments, culture: null); } #endif // NET8_0_OR_GREATER + +#if NETCOREAPP + internal static class ActivatorUtilitiesUpdateHandler + { + public static void ClearCache(Type[]? _) + { + // Ignore the Type[] argument; just clear the caches. + s_constructorInfos.Clear(); + if (s_collectibleConstructorInfos.IsValueCreated) + { + s_collectibleConstructorInfos.Value.Clear(); + } + } + } +#endif } } diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/tests/CollectibleAssembly/CollectableClasses.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/tests/CollectibleAssembly/CollectableClasses.cs new file mode 100644 index 0000000000000..cc9e925ae0ad8 --- /dev/null +++ b/src/libraries/Microsoft.Extensions.DependencyInjection/tests/CollectibleAssembly/CollectableClasses.cs @@ -0,0 +1,25 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.Extensions.DependencyInjection; + +namespace CollectibleAssembly +{ + public class ClassToCreate + { + public object ClassAsCtorArgument { get; set; } + + public ClassToCreate(ClassAsCtorArgument obj) { ClassAsCtorArgument = obj; } + + public static object Create(ServiceProvider provider) + { + // Both the type to create (ClassToCreate) and the ctor's arg type (ClassAsCtorArgument) are + // located in this assembly, so both types need to be GC'd for this assembly to be collected. + return ActivatorUtilities.CreateInstance(provider, new ClassAsCtorArgument()); + } + } + + public class ClassAsCtorArgument + { + } +} diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/tests/CollectibleAssembly/CollectibleAssembly.csproj b/src/libraries/Microsoft.Extensions.DependencyInjection/tests/CollectibleAssembly/CollectibleAssembly.csproj new file mode 100644 index 0000000000000..82159cece2822 --- /dev/null +++ b/src/libraries/Microsoft.Extensions.DependencyInjection/tests/CollectibleAssembly/CollectibleAssembly.csproj @@ -0,0 +1,11 @@ + + + $(NetCoreAppCurrent);$(NetFrameworkMinimum) + true + + + + + + + diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/ActivatorUtilitiesTests.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/ActivatorUtilitiesTests.cs index 7572e6977a4c4..dda3cafa442eb 100644 --- a/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/ActivatorUtilitiesTests.cs +++ b/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/ActivatorUtilitiesTests.cs @@ -2,8 +2,15 @@ // The .NET Foundation licenses this file to you under the MIT license. using System; +using System.IO; +using System.Reflection; using Microsoft.DotNet.RemoteExecutor; using Xunit; +using System.Runtime.CompilerServices; + +#if NETCOREAPP +using System.Runtime.Loader; +#endif namespace Microsoft.Extensions.DependencyInjection.Tests { @@ -386,6 +393,125 @@ public void CreateFactory_RemoteExecutor_NoParameters_Success(bool useDynamicCod }, options); } +#if NETCOREAPP + [ActiveIssue("https://github.com/dotnet/runtime/issues/34072", TestRuntimes.Mono)] + [ConditionalTheory(typeof(RemoteExecutor), nameof(RemoteExecutor.IsSupported))] + [InlineData(true)] + [InlineData(false)] + public void CreateInstance_CollectibleAssembly(bool useDynamicCode) + { + if (PlatformDetection.IsNonBundledAssemblyLoadingSupported) + { + RemoteInvokeOptions options = new(); + if (!useDynamicCode) + { + DisableDynamicCode(options); + } + + using var remoteHandle = RemoteExecutor.Invoke(static () => + { + Assert.False(Collectible_IsAssemblyLoaded()); + Collectible_LoadAndCreate(useCollectibleAssembly : true, out WeakReference asmWeakRef, out WeakReference typeWeakRef); + + for (int i = 0; (typeWeakRef.IsAlive || asmWeakRef.IsAlive) && (i < 10); i++) + { + GC.Collect(); + GC.WaitForPendingFinalizers(); + } + + // These should be GC'd. + Assert.False(asmWeakRef.IsAlive, "asmWeakRef.IsAlive"); + Assert.False(typeWeakRef.IsAlive, "typeWeakRef.IsAlive"); + Assert.False(Collectible_IsAssemblyLoaded()); + }, options); + } + } + + [ConditionalTheory(typeof(RemoteExecutor), nameof(RemoteExecutor.IsSupported))] + [InlineData(true)] + [InlineData(false)] + public void CreateInstance_NormalAssembly(bool useDynamicCode) + { + RemoteInvokeOptions options = new(); + if (!useDynamicCode) + { + DisableDynamicCode(options); + } + + using var remoteHandle = RemoteExecutor.Invoke(static () => + { + Assert.False(Collectible_IsAssemblyLoaded()); + Collectible_LoadAndCreate(useCollectibleAssembly: false, out WeakReference asmWeakRef, out WeakReference typeWeakRef); + + for (int i = 0; (typeWeakRef.IsAlive || asmWeakRef.IsAlive) && (i < 10); i++) + { + GC.Collect(); + GC.WaitForPendingFinalizers(); + } + + // These will not be GC'd. + Assert.True(asmWeakRef.IsAlive, "alcWeakRef.IsAlive"); + Assert.True(typeWeakRef.IsAlive, "typeWeakRef.IsAlive"); + Assert.True(Collectible_IsAssemblyLoaded()); + }, options); + } + + [MethodImpl(MethodImplOptions.NoInlining)] + static void Collectible_LoadAndCreate(bool useCollectibleAssembly, out WeakReference asmWeakRef, out WeakReference typeWeakRef) + { + Assembly asm; + object obj; + + if (useCollectibleAssembly) + { + asm = MyLoadContext.LoadAsCollectable(); + obj = CreateWithActivator(asm); + Assert.True(obj.GetType().Assembly.IsCollectible); + } + else + { + asm = MyLoadContext.LoadNormal(); + obj = CreateWithActivator(asm); + Assert.False(obj.GetType().Assembly.IsCollectible); + } + + Assert.True(Collectible_IsAssemblyLoaded()); + asmWeakRef = new WeakReference(asm); + typeWeakRef = new WeakReference(obj.GetType()); + + static object CreateWithActivator(Assembly asm) + { + Type t = asm.GetType("CollectibleAssembly.ClassToCreate"); + MethodInfo mi = t.GetMethod("Create", BindingFlags.Static | BindingFlags.Public, new Type[] { typeof(ServiceProvider) }); + + object instance; + ServiceCollection services = new(); + using (ServiceProvider provider = services.BuildServiceProvider()) + { + instance = mi.Invoke(null, new object[] { provider }); + } + + return instance; + } + } + + static bool Collectible_IsAssemblyLoaded() + { + Assembly[] assemblies = AppDomain.CurrentDomain.GetAssemblies(); + for (int i = 0; i < assemblies.Length; i++) + { + Assembly asm = assemblies[i]; + string asmName = Path.GetFileName(asm.Location); + if (asmName == "CollectibleAssembly.dll") + { + return true; + } + } + + return false; + } +#endif + private static void DisableDynamicCode(RemoteInvokeOptions options) { // We probably only need to set 'IsDynamicCodeCompiled' since only that is checked, @@ -581,5 +707,36 @@ public ClassWithStringDefaultValue(string text = "DEFAULT") Text = text; } } -} +#if NETCOREAPP + internal class MyLoadContext : AssemblyLoadContext + { + private MyLoadContext() : base(isCollectible: true) + { + } + + public Assembly LoadAssembly() + { + Assembly asm = LoadFromAssemblyPath(GetPath()); + Assert.Equal(GetLoadContext(asm), this); + return asm; + } + + public static Assembly LoadAsCollectable() + { + MyLoadContext alc = new MyLoadContext(); + return alc.LoadAssembly(); + } + + public static Assembly LoadNormal() + { + return Assembly.LoadFrom(GetPath()); + } + + private static string GetPath() + { + return Path.Combine(Directory.GetCurrentDirectory(), "CollectibleAssembly.dll"); + } + } +#endif +} diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/Microsoft.Extensions.DependencyInjection.Tests.csproj b/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/Microsoft.Extensions.DependencyInjection.Tests.csproj index 4ac3c02d7157a..067508506a82f 100644 --- a/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/Microsoft.Extensions.DependencyInjection.Tests.csproj +++ b/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/Microsoft.Extensions.DependencyInjection.Tests.csproj @@ -1,4 +1,4 @@ - + $(NetCoreAppCurrent);$(NetFrameworkMinimum) @@ -24,6 +24,7 @@ +