Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Support For Generic Handlers With Multiple Generic Type Parameters #1048

Merged
merged 6 commits into from
Jul 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
935 changes: 480 additions & 455 deletions src/MediatR/MicrosoftExtensionsDI/MediatrServiceConfiguration.cs

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,9 @@ public static IServiceCollection AddMediatR(this IServiceCollection services,
throw new ArgumentException("No assemblies found to scan. Supply at least one assembly to scan for handlers.");
}

ServiceRegistrar.AddMediatRClasses(services, configuration);
ServiceRegistrar.SetGenericRequestHandlerRegistrationLimitations(configuration);

ServiceRegistrar.AddMediatRClassesWithTimeout(services, configuration);

ServiceRegistrar.AddRequiredServices(services, configuration);

Expand Down
132 changes: 111 additions & 21 deletions src/MediatR/Registration/ServiceRegistrar.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,50 @@
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using System.Threading;
using MediatR.Pipeline;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.DependencyInjection.Extensions;

namespace MediatR.Registration;

public static class ServiceRegistrar
{
public static void AddMediatRClasses(IServiceCollection services, MediatRServiceConfiguration configuration)
{
{
private static int MaxGenericTypeParameters;
private static int MaxTypesClosing;
private static int MaxGenericTypeRegistrations;
private static int RegistrationTimeout;

public static void SetGenericRequestHandlerRegistrationLimitations(MediatRServiceConfiguration configuration)
{
MaxGenericTypeParameters = configuration.MaxGenericTypeParameters;
MaxTypesClosing = configuration.MaxTypesClosing;
MaxGenericTypeRegistrations = configuration.MaxGenericTypeRegistrations;
RegistrationTimeout = configuration.RegistrationTimeout;
}

public static void AddMediatRClassesWithTimeout(IServiceCollection services, MediatRServiceConfiguration configuration)
{
using(var cts = new CancellationTokenSource(RegistrationTimeout))
{
try
{
AddMediatRClasses(services, configuration, cts.Token);
}
catch (OperationCanceledException)
{
throw new TimeoutException("The generic handler registration process timed out.");
}
}
}

public static void AddMediatRClasses(IServiceCollection services, MediatRServiceConfiguration configuration, CancellationToken cancellationToken = default)
{

var assembliesToScan = configuration.AssembliesToRegister.Distinct().ToArray();

ConnectImplementationsToTypesClosing(typeof(IRequestHandler<,>), services, assembliesToScan, false, configuration);
ConnectImplementationsToTypesClosing(typeof(IRequestHandler<>), services, assembliesToScan, false, configuration);
ConnectImplementationsToTypesClosing(typeof(IRequestHandler<,>), services, assembliesToScan, false, configuration, cancellationToken);
ConnectImplementationsToTypesClosing(typeof(IRequestHandler<>), services, assembliesToScan, false, configuration, cancellationToken);
ConnectImplementationsToTypesClosing(typeof(INotificationHandler<>), services, assembliesToScan, true, configuration);
ConnectImplementationsToTypesClosing(typeof(IStreamRequestHandler<,>), services, assembliesToScan, false, configuration);
ConnectImplementationsToTypesClosing(typeof(IRequestExceptionHandler<,,>), services, assembliesToScan, true, configuration);
Expand Down Expand Up @@ -63,7 +93,8 @@ private static void ConnectImplementationsToTypesClosing(Type openRequestInterfa
IServiceCollection services,
IEnumerable<Assembly> assembliesToScan,
bool addIfAlreadyExists,
MediatRServiceConfiguration configuration)
MediatRServiceConfiguration configuration,
CancellationToken cancellationToken = default)
{
var concretions = new List<Type>();
var interfaces = new List<Type>();
Expand All @@ -72,9 +103,10 @@ private static void ConnectImplementationsToTypesClosing(Type openRequestInterfa

var types = assembliesToScan
.SelectMany(a => a.DefinedTypes)
.Where(t => !t.ContainsGenericParameters || configuration.RegisterGenericHandlers)
.Where(t => t.IsConcrete() && t.FindInterfacesThatClose(openRequestInterface).Any())
.Where(configuration.TypeEvaluator)
.ToList();
.ToList();

foreach (var type in types)
{
Expand Down Expand Up @@ -131,7 +163,7 @@ private static void ConnectImplementationsToTypesClosing(Type openRequestInterfa
foreach (var @interface in genericInterfaces)
{
var exactMatches = genericConcretions.Where(x => x.CanBeCastTo(@interface)).ToList();
AddAllConcretionsThatClose(@interface, exactMatches, services, assembliesToScan);
AddAllConcretionsThatClose(@interface, exactMatches, services, assembliesToScan, cancellationToken);
}
}

Expand Down Expand Up @@ -174,7 +206,7 @@ private static void AddConcretionsThatCouldBeClosed(Type @interface, List<Type>

private static (Type Service, Type Implementation) GetConcreteRegistrationTypes(Type openRequestHandlerInterface, Type concreteGenericTRequest, Type openRequestHandlerImplementation)
{
var closingType = concreteGenericTRequest.GetGenericArguments().First();
var closingTypes = concreteGenericTRequest.GetGenericArguments();

var concreteTResponse = concreteGenericTRequest.GetInterfaces()
.FirstOrDefault(x => x.IsGenericType && x.GetGenericTypeDefinition() == typeof(IRequest<>))
Expand All @@ -187,33 +219,90 @@ private static (Type Service, Type Implementation) GetConcreteRegistrationTypes(
typeDefinition.MakeGenericType(concreteGenericTRequest, concreteTResponse) :
typeDefinition.MakeGenericType(concreteGenericTRequest);

return (serviceType, openRequestHandlerImplementation.MakeGenericType(closingType));
return (serviceType, openRequestHandlerImplementation.MakeGenericType(closingTypes));
}

private static List<Type>? GetConcreteRequestTypes(Type openRequestHandlerInterface, Type openRequestHandlerImplementation, IEnumerable<Assembly> assembliesToScan)
private static List<Type>? GetConcreteRequestTypes(Type openRequestHandlerInterface, Type openRequestHandlerImplementation, IEnumerable<Assembly> assembliesToScan, CancellationToken cancellationToken)
{
var constraints = openRequestHandlerImplementation.GetGenericArguments().First().GetGenericParameterConstraints();

var typesThatCanClose = assembliesToScan
.SelectMany(assembly => assembly.GetTypes())
.Where(type => type.IsClass && !type.IsAbstract && constraints.All(constraint => constraint.IsAssignableFrom(type)))
.ToList();
//request generic type constraints
var constraintsForEachParameter = openRequestHandlerImplementation
.GetGenericArguments()
.Select(x => x.GetGenericParameterConstraints())
.ToList();

if (constraintsForEachParameter.Count > 2 && constraintsForEachParameter.Any(constraints => !constraints.Where(x => x.IsInterface || x.IsClass).Any()))
throw new ArgumentException($"Error registering the generic handler type: {openRequestHandlerImplementation.FullName}. When registering generic requests with more than two type parameters, each type parameter must have at least one constraint of type interface or class.");

var typesThatCanCloseForEachParameter = constraintsForEachParameter
.Select(constraints => assembliesToScan
.SelectMany(assembly => assembly.GetTypes())
.Where(type => type.IsClass && !type.IsAbstract && constraints.All(constraint => constraint.IsAssignableFrom(type))).ToList()
).ToList();

var requestType = openRequestHandlerInterface.GenericTypeArguments.First();

if (requestType.IsGenericParameter)
return null;

var requestGenericTypeDefinition = requestType.GetGenericTypeDefinition();

var combinations = GenerateCombinations(requestType, typesThatCanCloseForEachParameter, 0, cancellationToken);

return combinations.Select(types => requestGenericTypeDefinition.MakeGenericType(types.ToArray())).ToList();
}

// Method to generate combinations recursively
public static List<List<Type>> GenerateCombinations(Type requestType, List<List<Type>> lists, int depth = 0, CancellationToken cancellationToken = default)
{
if (depth == 0)
{
// Initial checks
if (MaxGenericTypeParameters > 0 && lists.Count > MaxGenericTypeParameters)
throw new ArgumentException($"Error registering the generic type: {requestType.FullName}. The number of generic type parameters exceeds the maximum allowed ({MaxGenericTypeParameters}).");

foreach (var list in lists)
{
if (MaxTypesClosing > 0 && list.Count > MaxTypesClosing)
throw new ArgumentException($"Error registering the generic type: {requestType.FullName}. One of the generic type parameter's count of types that can close exceeds the maximum length allowed ({MaxTypesClosing}).");
}

// Calculate the total number of combinations
long totalCombinations = 1;
foreach (var list in lists)
{
totalCombinations *= list.Count;
if (MaxGenericTypeParameters > 0 && totalCombinations > MaxGenericTypeRegistrations)
throw new ArgumentException($"Error registering the generic type: {requestType.FullName}. The total number of generic type registrations exceeds the maximum allowed ({MaxGenericTypeRegistrations}).");
}
}

if (depth >= lists.Count)
return new List<List<Type>> { new List<Type>() };

cancellationToken.ThrowIfCancellationRequested();

return typesThatCanClose.Select(type => requestGenericTypeDefinition.MakeGenericType(type)).ToList();
var currentList = lists[depth];
var childCombinations = GenerateCombinations(requestType, lists, depth + 1, cancellationToken);
var combinations = new List<List<Type>>();

foreach (var item in currentList)
{
foreach (var childCombination in childCombinations)
{
var currentCombination = new List<Type> { item };
currentCombination.AddRange(childCombination);
combinations.Add(currentCombination);
}
}

return combinations;
}

private static void AddAllConcretionsThatClose(Type openRequestInterface, List<Type> concretions, IServiceCollection services, IEnumerable<Assembly> assembliesToScan)
private static void AddAllConcretionsThatClose(Type openRequestInterface, List<Type> concretions, IServiceCollection services, IEnumerable<Assembly> assembliesToScan, CancellationToken cancellationToken)
{
foreach (var concretion in concretions)
{
var concreteRequests = GetConcreteRequestTypes(openRequestInterface, concretion, assembliesToScan);
{
var concreteRequests = GetConcreteRequestTypes(openRequestInterface, concretion, assembliesToScan, cancellationToken);

if (concreteRequests is null)
continue;
Expand All @@ -223,6 +312,7 @@ private static void AddAllConcretionsThatClose(Type openRequestInterface, List<T

foreach (var (Service, Implementation) in registrationTypes)
{
cancellationToken.ThrowIfCancellationRequested();
services.AddTransient(Service, Implementation);
}
}
Expand Down
Loading
Loading