diff --git a/src/OrchardCore.Modules/OrchardCore.OpenId/Startup.cs b/src/OrchardCore.Modules/OrchardCore.OpenId/Startup.cs index ab729b5115f..1c9139f22cf 100644 --- a/src/OrchardCore.Modules/OrchardCore.OpenId/Startup.cs +++ b/src/OrchardCore.Modules/OrchardCore.OpenId/Startup.cs @@ -20,6 +20,7 @@ using OrchardCore.BackgroundTasks; using OrchardCore.Deployment; using OrchardCore.DisplayManagement.Handlers; +using OrchardCore.Environment.Shell.Builders; using OrchardCore.Modules; using OrchardCore.Mvc.Core.Utilities; using OrchardCore.Navigation; @@ -355,7 +356,7 @@ public static IServiceCollection RemoveAll(this IServiceCollection services, Typ for (var index = services.Count - 1; index >= 0; index--) { var descriptor = services[index]; - if (descriptor.ServiceType == serviceType && descriptor.ImplementationType == implementationType) + if (descriptor.ServiceType == serviceType && descriptor.GetImplementationType() == implementationType) { services.RemoveAt(index); } diff --git a/src/OrchardCore.Modules/OrchardCore.Search.Elasticsearch/Startup.cs b/src/OrchardCore.Modules/OrchardCore.Search.Elasticsearch/Startup.cs index b11bb653db8..ffc8825f1e3 100644 --- a/src/OrchardCore.Modules/OrchardCore.Search.Elasticsearch/Startup.cs +++ b/src/OrchardCore.Modules/OrchardCore.Search.Elasticsearch/Startup.cs @@ -20,6 +20,7 @@ using OrchardCore.Deployment; using OrchardCore.DisplayManagement.Descriptors; using OrchardCore.DisplayManagement.Handlers; +using OrchardCore.Environment.Shell.Builders; using OrchardCore.Environment.Shell.Configuration; using OrchardCore.Modules; using OrchardCore.Mvc.Core.Utilities; @@ -268,7 +269,7 @@ public class DeploymentStartup : StartupBase { public override void ConfigureServices(IServiceCollection services) { - if (services.Any(d => d.ImplementationType == typeof(ElasticsearchService))) + if (services.Any(d => d.GetImplementationType() == typeof(ElasticsearchService))) { services.AddTransient(); services.AddSingleton(new DeploymentStepFactory()); @@ -294,7 +295,7 @@ public class ElasticWorkerStartup : StartupBase { public override void ConfigureServices(IServiceCollection services) { - if (services.Any(d => d.ImplementationType == typeof(ElasticsearchService))) + if (services.Any(d => d.GetImplementationType() == typeof(ElasticsearchService))) { services.AddSingleton(); } @@ -306,7 +307,7 @@ public class ElasticContentPickerStartup : StartupBase { public override void ConfigureServices(IServiceCollection services) { - if (services.Any(d => d.ImplementationType == typeof(ElasticsearchService))) + if (services.Any(d => d.GetImplementationType() == typeof(ElasticsearchService))) { services.AddScoped(); services.AddScoped(); diff --git a/src/OrchardCore/OrchardCore.Abstractions/Modules/Builder/OrchardCoreBuilder.cs b/src/OrchardCore/OrchardCore.Abstractions/Modules/Builder/OrchardCoreBuilder.cs index d97260ad266..ba41209dfc6 100644 --- a/src/OrchardCore/OrchardCore.Abstractions/Modules/Builder/OrchardCoreBuilder.cs +++ b/src/OrchardCore/OrchardCore.Abstractions/Modules/Builder/OrchardCoreBuilder.cs @@ -2,6 +2,7 @@ using System.Collections.Generic; using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.Routing; +using OrchardCore.Environment.Shell.Builders; using OrchardCore.Environment.Shell.Descriptor.Models; using OrchardCore.Modules; @@ -103,7 +104,7 @@ public OrchardCoreBuilder EnableFeature(string id) for (var index = 0; index < services.Count; index++) { var service = services[index]; - if (service.ImplementationInstance is ShellFeature feature && + if (service.GetImplementationInstance() is ShellFeature feature && string.Equals(feature.Id, id, StringComparison.OrdinalIgnoreCase)) { return; diff --git a/src/OrchardCore/OrchardCore.Abstractions/Shell/Builders/Extensions/ClonedSingletonDescriptor.cs b/src/OrchardCore/OrchardCore.Abstractions/Shell/Builders/Extensions/ClonedSingletonDescriptor.cs index f074a06cc65..3df5058b2e3 100644 --- a/src/OrchardCore/OrchardCore.Abstractions/Shell/Builders/Extensions/ClonedSingletonDescriptor.cs +++ b/src/OrchardCore/OrchardCore.Abstractions/Shell/Builders/Extensions/ClonedSingletonDescriptor.cs @@ -1,6 +1,8 @@ using System; using Microsoft.Extensions.DependencyInjection; +#nullable enable + namespace OrchardCore.Environment.Shell.Builders { public class ClonedSingletonDescriptor : ServiceDescriptor @@ -11,12 +13,24 @@ public ClonedSingletonDescriptor(ServiceDescriptor parent, object implementation Parent = parent; } + public ClonedSingletonDescriptor(ServiceDescriptor parent, object? serviceKey, object implementationInstance) + : base(parent.ServiceType, serviceKey, implementationInstance) + { + Parent = parent; + } + public ClonedSingletonDescriptor(ServiceDescriptor parent, Func implementationFactory) : base(parent.ServiceType, implementationFactory, ServiceLifetime.Singleton) { Parent = parent; } + public ClonedSingletonDescriptor(ServiceDescriptor parent, object? serviceKey, Func implementationFactory) + : base(parent.ServiceType, serviceKey, implementationFactory, ServiceLifetime.Singleton) + { + Parent = parent; + } + public ServiceDescriptor Parent { get; } } } diff --git a/src/OrchardCore/OrchardCore.Abstractions/Shell/Builders/Extensions/ServiceDescriptorExtensions.cs b/src/OrchardCore/OrchardCore.Abstractions/Shell/Builders/Extensions/ServiceDescriptorExtensions.cs index 225f83f1e64..0ccaa55b924 100644 --- a/src/OrchardCore/OrchardCore.Abstractions/Shell/Builders/Extensions/ServiceDescriptorExtensions.cs +++ b/src/OrchardCore/OrchardCore.Abstractions/Shell/Builders/Extensions/ServiceDescriptorExtensions.cs @@ -1,11 +1,13 @@ using System; using Microsoft.Extensions.DependencyInjection; +#nullable enable + namespace OrchardCore.Environment.Shell.Builders { public static class ServiceDescriptorExtensions { - public static Type GetImplementationType(this ServiceDescriptor descriptor) + public static Type? GetImplementationType(this ServiceDescriptor descriptor) { if (descriptor is ClonedSingletonDescriptor cloned) { @@ -13,22 +15,62 @@ public static Type GetImplementationType(this ServiceDescriptor descriptor) return cloned.Parent.GetImplementationType(); } - if (descriptor.ImplementationType != null) + if (descriptor.TryGetImplementationTypeInternal(out var implementationType)) { - return descriptor.ImplementationType; + return implementationType; } - if (descriptor.ImplementationInstance != null) + if (descriptor.TryGetImplementationInstance(out var implementationInstance)) { - return descriptor.ImplementationInstance.GetType(); + return implementationInstance?.GetType(); } - if (descriptor.ImplementationFactory != null) + if (descriptor.TryGetImplementationFactory(out var implementationFactory)) { - return descriptor.ImplementationFactory.GetType().GenericTypeArguments[1]; + return implementationFactory?.GetType().GenericTypeArguments[1]; } return null; } + + public static object? GetImplementationInstance(this ServiceDescriptor serviceDescriptor) => serviceDescriptor.IsKeyedService + ? serviceDescriptor.KeyedImplementationInstance + : serviceDescriptor.ImplementationInstance; + + public static object? GetImplementationFactory(this ServiceDescriptor serviceDescriptor) => serviceDescriptor.IsKeyedService + ? serviceDescriptor.KeyedImplementationFactory + : serviceDescriptor.ImplementationFactory; + + public static bool TryGetImplementationType(this ServiceDescriptor serviceDescriptor, out Type? type) + { + type = serviceDescriptor.GetImplementationType(); + + return type is not null; + } + + public static bool TryGetImplementationInstance(this ServiceDescriptor serviceDescriptor, out object? instance) + { + instance = serviceDescriptor.GetImplementationInstance(); + + return instance is not null; + } + + public static bool TryGetImplementationFactory(this ServiceDescriptor serviceDescriptor, out object? factory) + { + factory = serviceDescriptor.GetImplementationFactory(); + + return factory is not null; + } + + internal static Type? GetImplementationTypeInternal(this ServiceDescriptor serviceDescriptor) => serviceDescriptor.IsKeyedService + ? serviceDescriptor.KeyedImplementationType + : serviceDescriptor.ImplementationType; + + internal static bool TryGetImplementationTypeInternal(this ServiceDescriptor serviceDescriptor, out Type? type) + { + type = serviceDescriptor.GetImplementationTypeInternal(); + + return type is not null; + } } } diff --git a/src/OrchardCore/OrchardCore.Abstractions/Shell/Builders/Extensions/ServiceProviderExtensions.cs b/src/OrchardCore/OrchardCore.Abstractions/Shell/Builders/Extensions/ServiceProviderExtensions.cs index c7b4c736461..e4e22aee29b 100644 --- a/src/OrchardCore/OrchardCore.Abstractions/Shell/Builders/Extensions/ServiceProviderExtensions.cs +++ b/src/OrchardCore/OrchardCore.Abstractions/Shell/Builders/Extensions/ServiceProviderExtensions.cs @@ -1,6 +1,8 @@ using System; using Microsoft.Extensions.DependencyInjection; +#nullable enable + namespace OrchardCore.Environment.Shell.Builders { public static class ServiceProviderExtensions @@ -20,5 +22,19 @@ public static TResult CreateInstance(this IServiceProvider provider, Ty { return (TResult)ActivatorUtilities.CreateInstance(provider, type); } + + /// + /// Gets the service object of the specified type with the specified key. + /// + public static object? GetKeyedService(this IServiceProvider provider, Type serviceType, object? serviceKey) + { + ArgumentNullException.ThrowIfNull(provider); + if (provider is IKeyedServiceProvider keyedServiceProvider) + { + return keyedServiceProvider.GetKeyedService(serviceType, serviceKey); + } + + throw new InvalidOperationException("This service provider doesn't support keyed services."); + } } } diff --git a/src/OrchardCore/OrchardCore.Data.Abstractions/ServiceCollectionExtensions.cs b/src/OrchardCore/OrchardCore.Data.Abstractions/ServiceCollectionExtensions.cs index 3c9926f1224..400d3491199 100644 --- a/src/OrchardCore/OrchardCore.Data.Abstractions/ServiceCollectionExtensions.cs +++ b/src/OrchardCore/OrchardCore.Data.Abstractions/ServiceCollectionExtensions.cs @@ -1,5 +1,6 @@ using System; using Microsoft.Extensions.DependencyInjection; +using OrchardCore.Environment.Shell.Builders; namespace OrchardCore.Data { @@ -24,10 +25,11 @@ public static IServiceCollection TryAddDataProvider(this IServiceCollection serv for (var i = services.Count - 1; i >= 0; i--) { var entry = services[i]; - if (entry.ImplementationInstance != null) + var implementationInstance = entry.GetImplementationInstance(); + if (implementationInstance is not null) { - var databaseProvider = entry.ImplementationInstance as DatabaseProvider; - if (databaseProvider != null && string.Equals(databaseProvider.Name, name, StringComparison.OrdinalIgnoreCase)) + var databaseProvider = implementationInstance as DatabaseProvider; + if (databaseProvider is not null && string.Equals(databaseProvider.Name, name, StringComparison.OrdinalIgnoreCase)) { services.RemoveAt(i); } diff --git a/src/OrchardCore/OrchardCore/Shell/Builders/Extensions/ServiceCollectionExtensions.cs b/src/OrchardCore/OrchardCore/Shell/Builders/Extensions/ServiceCollectionExtensions.cs index e9d39fac611..eb726a706d9 100644 --- a/src/OrchardCore/OrchardCore/Shell/Builders/Extensions/ServiceCollectionExtensions.cs +++ b/src/OrchardCore/OrchardCore/Shell/Builders/Extensions/ServiceCollectionExtensions.cs @@ -1,6 +1,8 @@ using System; using Microsoft.Extensions.DependencyInjection; +#nullable enable + namespace OrchardCore.Environment.Shell.Builders { internal static class ServiceCollectionExtensions @@ -10,8 +12,12 @@ public static IServiceCollection CloneSingleton( ServiceDescriptor parent, object implementationInstance) { - var cloned = new ClonedSingletonDescriptor(parent, implementationInstance); + var cloned = parent.ServiceKey is not null + ? new ClonedSingletonDescriptor(parent, parent.ServiceKey, implementationInstance) + : new ClonedSingletonDescriptor(parent, implementationInstance); + services.Add(cloned); + return services; } @@ -22,6 +28,18 @@ public static IServiceCollection CloneSingleton( { var cloned = new ClonedSingletonDescriptor(parent, implementationFactory); collection.Add(cloned); + + return collection; + } + + public static IServiceCollection CloneSingleton( + this IServiceCollection collection, + ServiceDescriptor parent, + Func implementationFactory) + { + var cloned = new ClonedSingletonDescriptor(parent, parent.ServiceKey, implementationFactory); + collection.Add(cloned); + return collection; } } diff --git a/src/OrchardCore/OrchardCore/Shell/Builders/Extensions/ServiceProviderExtensions.cs b/src/OrchardCore/OrchardCore/Shell/Builders/Extensions/ServiceProviderExtensions.cs index 1c00e97d94e..2f2c7c44442 100644 --- a/src/OrchardCore/OrchardCore/Shell/Builders/Extensions/ServiceProviderExtensions.cs +++ b/src/OrchardCore/OrchardCore/Shell/Builders/Extensions/ServiceProviderExtensions.cs @@ -3,6 +3,8 @@ using Microsoft.AspNetCore.Hosting; using Microsoft.Extensions.DependencyInjection; +#nullable enable + namespace OrchardCore.Environment.Shell.Builders { public static class ServiceProviderExtensions @@ -15,17 +17,17 @@ public static class ServiceProviderExtensions public static IServiceCollection CreateChildContainer(this IServiceProvider serviceProvider, IServiceCollection serviceCollection) { IServiceCollection clonedCollection = new ServiceCollection(); - var servicesByType = serviceCollection.GroupBy(s => s.ServiceType); + var servicesByType = serviceCollection.GroupBy(s => (s.ServiceType, s.ServiceKey)); foreach (var services in servicesByType) { - // Prevent hosting 'IStartupFilter' to re-add middlewares to the tenant pipeline. - if (services.Key == typeof(IStartupFilter)) + // Prevent hosting 'IStartupFilter' to re-add middleware to the tenant pipeline. + if (services.Key.ServiceType == typeof(IStartupFilter)) { } // A generic type definition is rather used to create other constructed generic types. - else if (services.Key.IsGenericTypeDefinition) + else if (services.Key.ServiceType.IsGenericTypeDefinition) { // So, we just need to pass the descriptor. foreach (var service in services) @@ -38,25 +40,34 @@ public static IServiceCollection CreateChildContainer(this IServiceProvider serv else if (services.Count() == 1) { var service = services.First(); - if (service.Lifetime == ServiceLifetime.Singleton) { // An host singleton is shared across tenant containers but only registered instances are not disposed // by the DI, so we check if it is disposable or if it uses a factory which may return a different type. - - if (typeof(IDisposable).IsAssignableFrom(service.GetImplementationType()) || service.ImplementationFactory != null) + if (typeof(IDisposable).IsAssignableFrom(service.GetImplementationType()) || + service.GetImplementationFactory() is not null) { // If disposable, register an instance that we resolve immediately from the main container. - clonedCollection.CloneSingleton(service, serviceProvider.GetService(service.ServiceType)); + var instance = service.IsKeyedService + ? serviceProvider.GetRequiredKeyedService(services.Key.ServiceType, services.Key.ServiceKey) + : serviceProvider.GetRequiredService(services.Key.ServiceType); + + clonedCollection.CloneSingleton(service, instance); } - else + else if (!service.IsKeyedService) { // If not disposable, the singleton can be resolved through a factory when first requested. - clonedCollection.CloneSingleton(service, sp => serviceProvider.GetService(service.ServiceType)); + clonedCollection.CloneSingleton(service, sp => + serviceProvider.GetRequiredService(service.ServiceType)); // Note: Most of the time a singleton of a given type is unique and not disposable. So, // most of the time it will be resolved when first requested through a tenant container. } + else + { + clonedCollection.CloneSingleton(service, (sp, key) => + serviceProvider.GetRequiredKeyedService(service.ServiceType, key)); + } } else { @@ -78,11 +89,19 @@ public static IServiceCollection CreateChildContainer(this IServiceProvider serv else if (services.All(s => s.Lifetime == ServiceLifetime.Singleton)) { // We can resolve them from the main container. - var instances = serviceProvider.GetServices(services.Key); + var instances = services.Key.ServiceKey is not null + ? serviceProvider.GetKeyedServices(services.Key.ServiceType, services.Key.ServiceKey) + : serviceProvider.GetServices(services.Key.ServiceType); for (var i = 0; i < services.Count(); i++) { - clonedCollection.CloneSingleton(services.ElementAt(i), instances.ElementAt(i)); + var instance = instances.ElementAt(i); + if (instance is null) + { + continue; + } + + clonedCollection.CloneSingleton(services.ElementAt(i), instance); } } @@ -91,14 +110,23 @@ public static IServiceCollection CreateChildContainer(this IServiceProvider serv { // We need a service scope to resolve them. using var scope = serviceProvider.CreateScope(); - var instances = scope.ServiceProvider.GetServices(services.Key); + + var instances = services.Key.ServiceKey is not null + ? serviceProvider.GetKeyedServices(services.Key.ServiceType, services.Key.ServiceKey) + : serviceProvider.GetServices(services.Key.ServiceType); // Then we only keep singleton instances. for (var i = 0; i < services.Count(); i++) { if (services.ElementAt(i).Lifetime == ServiceLifetime.Singleton) { - clonedCollection.CloneSingleton(services.ElementAt(i), instances.ElementAt(i)); + var instance = instances.ElementAt(i); + if (instance is null) + { + continue; + } + + clonedCollection.CloneSingleton(services.ElementAt(i), instance); } else {