Skip to content

Commit

Permalink
Merge pull request #828 from jbogard/consolidate-registration
Browse files Browse the repository at this point in the history
Consolidate registration to single configuration object and optimize registration
  • Loading branch information
jbogard authored Feb 6, 2023
2 parents eb6d1f0 + 667d601 commit db8f2f5
Show file tree
Hide file tree
Showing 15 changed files with 148 additions and 116 deletions.
5 changes: 4 additions & 1 deletion samples/MediatR.Examples.AspNetCore/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@ private static IMediator BuildMediator(WrappingWriter writer)

services.AddSingleton<TextWriter>(writer);

services.AddMediatR(typeof(Ping), typeof(Sing));
services.AddMediatR(cfg =>
{
cfg.RegisterServicesFromAssemblies(typeof(Ping).Assembly, typeof(Sing).Assembly);
});

services.AddScoped(typeof(IStreamRequestHandler<Sing, Song>), typeof(SingHandler));

Expand Down
57 changes: 36 additions & 21 deletions src/MediatR/MicrosoftExtensionsDI/MediatrServiceConfiguration.cs
Original file line number Diff line number Diff line change
@@ -1,48 +1,63 @@
using System;
using System.Collections.Generic;
using System.Reflection;
using MediatR;

namespace Microsoft.Extensions.DependencyInjection;

public class MediatRServiceConfiguration
{
public Func<Type, bool> TypeEvaluator { get; private set; } = t => true;
public Type MediatorImplementationType { get; private set; }
public ServiceLifetime Lifetime { get; private set; }
public Func<Type, bool> TypeEvaluator { get; set; } = t => true;
public Type MediatorImplementationType { get; set; } = typeof(Mediator);
public ServiceLifetime Lifetime { get; set; } = ServiceLifetime.Transient;
public RequestExceptionActionProcessorStrategy RequestExceptionActionProcessorStrategy { get; set; }
= RequestExceptionActionProcessorStrategy.ApplyForUnhandledExceptions;

public MediatRServiceConfiguration()
{
MediatorImplementationType = typeof(Mediator);
Lifetime = ServiceLifetime.Transient;
}
internal List<Assembly> AssembliesToRegister { get; } = new();

public MediatRServiceConfiguration Using<TMediator>() where TMediator : IMediator
{
MediatorImplementationType = typeof(TMediator);
return this;
}
public List<ServiceDescriptor> BehaviorsToRegister { get; } = new();

public MediatRServiceConfiguration RegisterServicesFromAssemblyContaining<T>()
=> RegisterServicesFromAssemblyContaining(typeof(T));

public MediatRServiceConfiguration RegisterServicesFromAssemblyContaining(Type type)
=> RegisterServicesFromAssembly(type.Assembly);

public MediatRServiceConfiguration AsSingleton()
public MediatRServiceConfiguration RegisterServicesFromAssembly(Assembly assembly)
{
Lifetime = ServiceLifetime.Singleton;
AssembliesToRegister.Add(assembly);

return this;
}

public MediatRServiceConfiguration AsScoped()
public MediatRServiceConfiguration RegisterServicesFromAssemblies(
params Assembly[] assemblies)
{
Lifetime = ServiceLifetime.Scoped;
AssembliesToRegister.AddRange(assemblies);

return this;
}

public MediatRServiceConfiguration AsTransient()
public MediatRServiceConfiguration AddBehavior<TServiceType, TImplementationType>(
ServiceLifetime serviceLifetime = ServiceLifetime.Transient) =>
AddBehavior(typeof(TServiceType), typeof(TImplementationType), serviceLifetime);

public MediatRServiceConfiguration AddBehavior(
Type serviceType,
Type implementationType,
ServiceLifetime serviceLifetime = ServiceLifetime.Transient)
{
Lifetime = ServiceLifetime.Transient;
BehaviorsToRegister.Add(new ServiceDescriptor(serviceType, implementationType, serviceLifetime));

return this;
}

public MediatRServiceConfiguration WithEvaluator(Func<Type, bool> evaluator)
public MediatRServiceConfiguration AddOpenBehavior(Type openBehaviorType, ServiceLifetime serviceLifetime = ServiceLifetime.Transient)
{
TypeEvaluator = evaluator;
var serviceType = typeof(IPipelineBehavior<,>);

BehaviorsToRegister.Add(new ServiceDescriptor(serviceType, openBehaviorType, serviceLifetime));

return this;
}
}
63 changes: 8 additions & 55 deletions src/MediatR/MicrosoftExtensionsDI/ServiceCollectionExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,71 +23,24 @@ public static class ServiceCollectionExtensions
/// Registers handlers and mediator types from the specified assemblies
/// </summary>
/// <param name="services">Service collection</param>
/// <param name="assemblies">Assemblies to scan</param>
/// <returns>Service collection</returns>
public static IServiceCollection AddMediatR(this IServiceCollection services, params Assembly[] assemblies)
=> services.AddMediatR(assemblies, configuration: null);

/// <summary>
/// Registers handlers and mediator types from the specified assemblies
/// </summary>
/// <param name="services">Service collection</param>
/// <param name="assemblies">Assemblies to scan</param>
/// <param name="configuration">The action used to configure the options</param>
/// <returns>Service collection</returns>
public static IServiceCollection AddMediatR(this IServiceCollection services, Action<MediatRServiceConfiguration>? configuration, params Assembly[] assemblies)
=> services.AddMediatR(assemblies, configuration);

/// <summary>
/// Registers handlers and mediator types from the specified assemblies
/// </summary>
/// <param name="services">Service collection</param>
/// <param name="assemblies">Assemblies to scan</param>
/// <param name="configuration">The action used to configure the options</param>
/// <returns>Service collection</returns>
public static IServiceCollection AddMediatR(this IServiceCollection services, IEnumerable<Assembly> assemblies, Action<MediatRServiceConfiguration>? configuration)
public static IServiceCollection AddMediatR(this IServiceCollection services,
Action<MediatRServiceConfiguration> configuration)
{
if (!assemblies.Any())
var serviceConfig = new MediatRServiceConfiguration();

configuration.Invoke(serviceConfig);

if (!serviceConfig.AssembliesToRegister.Any())
{
throw new ArgumentException("No assemblies found to scan. Supply at least one assembly to scan for handlers.");
}
var serviceConfig = new MediatRServiceConfiguration();

configuration?.Invoke(serviceConfig);
ServiceRegistrar.AddMediatRClasses(services, serviceConfig);

ServiceRegistrar.AddRequiredServices(services, serviceConfig);

ServiceRegistrar.AddMediatRClasses(services, assemblies, serviceConfig);

return services;
}

/// <summary>
/// Registers handlers and mediator types from the assemblies that contain the specified types
/// </summary>
/// <param name="services"></param>
/// <param name="handlerAssemblyMarkerTypes"></param>
/// <returns>Service collection</returns>
public static IServiceCollection AddMediatR(this IServiceCollection services, params Type[] handlerAssemblyMarkerTypes)
=> services.AddMediatR(handlerAssemblyMarkerTypes, configuration: null);

/// <summary>
/// Registers handlers and mediator types from the assemblies that contain the specified types
/// </summary>
/// <param name="services"></param>
/// <param name="handlerAssemblyMarkerTypes"></param>
/// <param name="configuration">The action used to configure the options</param>
/// <returns>Service collection</returns>
public static IServiceCollection AddMediatR(this IServiceCollection services, Action<MediatRServiceConfiguration>? configuration, params Type[] handlerAssemblyMarkerTypes)
=> services.AddMediatR(handlerAssemblyMarkerTypes, configuration);

/// <summary>
/// Registers handlers and mediator types from the assemblies that contain the specified types
/// </summary>
/// <param name="services"></param>
/// <param name="handlerAssemblyMarkerTypes"></param>
/// <param name="configuration">The action used to configure the options</param>
/// <returns>Service collection</returns>
public static IServiceCollection AddMediatR(this IServiceCollection services, IEnumerable<Type> handlerAssemblyMarkerTypes, Action<MediatRServiceConfiguration>? configuration)
=> services.AddMediatR(handlerAssemblyMarkerTypes.Select(t => t.GetTypeInfo().Assembly), configuration);
}
2 changes: 1 addition & 1 deletion src/MediatR/Pipeline/RequestExceptionHandlerState.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ public class RequestExceptionHandlerState<TResponse>
public bool Handled { get; private set; }

/// <summary>
/// The response that is returned if <see cref="Handled"/> is <code>true</code>.
/// The response that is returned if <see cref="Handled"/> is <code>true</code>.
/// </summary>
public TResponse? Response { get; private set; }

Expand Down
53 changes: 43 additions & 10 deletions src/MediatR/Registration/ServiceRegistrar.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ namespace MediatR.Registration;

public static class ServiceRegistrar
{
public static void AddMediatRClasses(IServiceCollection services, IEnumerable<Assembly> assembliesToScan, MediatRServiceConfiguration configuration)
public static void AddMediatRClasses(IServiceCollection services, MediatRServiceConfiguration configuration)
{
assembliesToScan = assembliesToScan.Distinct().ToArray();
var assembliesToScan = configuration.AssembliesToRegister.Distinct().ToArray();

ConnectImplementationsToTypesClosing(typeof(IRequestHandler<,>), services, assembliesToScan, false, configuration);
ConnectImplementationsToTypesClosing(typeof(INotificationHandler<>), services, assembliesToScan, true, configuration);
Expand Down Expand Up @@ -217,20 +217,53 @@ public static void AddRequiredServices(IServiceCollection services, MediatRServi
services.TryAdd(new ServiceDescriptor(typeof(ISender), sp => sp.GetRequiredService<IMediator>(), serviceConfiguration.Lifetime));
services.TryAdd(new ServiceDescriptor(typeof(IPublisher), sp => sp.GetRequiredService<IMediator>(), serviceConfiguration.Lifetime));

// Use TryAddTransientExact (see below), we dó want to register our Pre/Post processor behavior, even if (a more concrete)
foreach (var serviceDescriptor in serviceConfiguration.BehaviorsToRegister)
{
services.Add(serviceDescriptor);
}

// Use TryAddTransientExact (see below), we do want to register our Pre/Post processor behavior, even if (a more concrete)
// registration for IPipelineBehavior<,> already exists. But only once.
services.TryAddTransientExact(typeof(IPipelineBehavior<,>), typeof(RequestPreProcessorBehavior<,>));
services.TryAddTransientExact(typeof(IPipelineBehavior<,>), typeof(RequestPostProcessorBehavior<,>));
RegisterBehaviorIfImplementationsExist(services, typeof(RequestPreProcessorBehavior<,>),
typeof(IRequestPreProcessor<>));
RegisterBehaviorIfImplementationsExist(services, typeof(RequestPostProcessorBehavior<,>),
typeof(IRequestPostProcessor<,>));

if (serviceConfiguration.RequestExceptionActionProcessorStrategy == RequestExceptionActionProcessorStrategy.ApplyForUnhandledExceptions)
if (serviceConfiguration.RequestExceptionActionProcessorStrategy ==
RequestExceptionActionProcessorStrategy.ApplyForUnhandledExceptions)
{
services.TryAddTransientExact(typeof(IPipelineBehavior<,>), typeof(RequestExceptionActionProcessorBehavior<,>));
services.TryAddTransientExact(typeof(IPipelineBehavior<,>), typeof(RequestExceptionProcessorBehavior<,>));
RegisterBehaviorIfImplementationsExist(services, typeof(RequestExceptionActionProcessorBehavior<,>),
typeof(IRequestExceptionAction<,>));
RegisterBehaviorIfImplementationsExist(services, typeof(RequestExceptionProcessorBehavior<,>),
typeof(IRequestExceptionHandler<,,>));
}
else
{
services.TryAddTransientExact(typeof(IPipelineBehavior<,>), typeof(RequestExceptionProcessorBehavior<,>));
services.TryAddTransientExact(typeof(IPipelineBehavior<,>), typeof(RequestExceptionActionProcessorBehavior<,>));
RegisterBehaviorIfImplementationsExist(services, typeof(RequestExceptionProcessorBehavior<,>),
typeof(IRequestExceptionHandler<,,>));
RegisterBehaviorIfImplementationsExist(services, typeof(RequestExceptionActionProcessorBehavior<,>),
typeof(IRequestExceptionAction<,>));
}
}

private static void RegisterBehaviorIfImplementationsExist(
IServiceCollection services,
Type behaviorType,
Type subBehaviorType
)
{
var hasAnyRegistrationsOfSubBehaviorType = services
.Select(service => service.ImplementationType)
.Where(type => type != null)
.SelectMany(type => type!.GetInterfaces())
.Where(type => type.IsGenericType)
.Select(type => type.GetGenericTypeDefinition())
.Where(type => type != null)
.Any(type => type == subBehaviorType);

if (hasAnyRegistrationsOfSubBehaviorType)
{
services.TryAddTransientExact(typeof(IPipelineBehavior<,>), behaviorType);
}
}

Expand Down
2 changes: 1 addition & 1 deletion test/MediatR.Benchmarks/Benchmarks.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ public void GlobalSetup()

services.AddSingleton(TextWriter.Null);

services.AddMediatR(typeof(Ping));
services.AddMediatR(cfg => cfg.RegisterServicesFromAssemblyContaining(typeof(Ping)));

services.AddScoped(typeof(IPipelineBehavior<,>), typeof(GenericPipelineBehavior<,>));
services.AddScoped(typeof(IRequestPreProcessor<>), typeof(GenericRequestPreProcessor<>));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ public AssemblyResolutionTests()
{
IServiceCollection services = new ServiceCollection();
services.AddSingleton(new Logger());
services.AddMediatR(typeof(Ping).GetTypeInfo().Assembly);
services.AddMediatR(cfg => cfg.RegisterServicesFromAssembly(typeof(Ping).Assembly));
_provider = services.BuildServiceProvider();
}

Expand Down Expand Up @@ -55,7 +55,7 @@ public void ShouldRequireAtLeastOneAssembly()
{
var services = new ServiceCollection();

Action registration = () => services.AddMediatR(new Type[0]);
Action registration = () => services.AddMediatR(_ => {});

registration.ShouldThrow<ArgumentException>();
}
Expand Down
14 changes: 11 additions & 3 deletions test/MediatR.Tests/MicrosoftExtensionsDI/CustomMediatorTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,11 @@ public CustomMediatorTests()
{
IServiceCollection services = new ServiceCollection();
services.AddSingleton(new Logger());
services.AddMediatR(cfg => cfg.Using<MyCustomMediator>(), typeof(CustomMediatorTests));
services.AddMediatR(cfg =>
{
cfg.MediatorImplementationType = typeof(MyCustomMediator);
cfg.RegisterServicesFromAssemblyContaining(typeof(CustomMediatorTests));
});
_provider = services.BuildServiceProvider();
}

Expand Down Expand Up @@ -43,10 +47,14 @@ public void Can_Call_AddMediatr_multiple_times()
{
IServiceCollection services = new ServiceCollection();
services.AddSingleton(new Logger());
services.AddMediatR(cfg => cfg.Using<MyCustomMediator>(), typeof(CustomMediatorTests));
services.AddMediatR(cfg =>
{
cfg.MediatorImplementationType = typeof(MyCustomMediator);
cfg.RegisterServicesFromAssemblyContaining(typeof(CustomMediatorTests));
});

// Call AddMediatr again, this should NOT override our custom mediatr (With MS DI, last registration wins)
services.AddMediatR(typeof(CustomMediatorTests));
services.AddMediatR(cfg => cfg.RegisterServicesFromAssemblyContaining(typeof(CustomMediatorTests)));

var provider = services.BuildServiceProvider();
var mediator = provider.GetRequiredService<IMediator>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ public DerivingRequestsTests()
{
IServiceCollection services = new ServiceCollection();
services.AddSingleton(new Logger());
services.AddMediatR(typeof(Ping));
services.AddMediatR(cfg => cfg.RegisterServicesFromAssemblyContaining(typeof(Ping)));
_provider = services.BuildServiceProvider();
_mediator = _provider.GetRequiredService<IMediator>();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ public DuplicateAssemblyResolutionTests()
{
IServiceCollection services = new ServiceCollection();
services.AddSingleton(new Logger());
services.AddMediatR(typeof(Ping), typeof(Ping));
services.AddMediatR(cfg => cfg.RegisterServicesFromAssemblies(typeof(Ping).Assembly, typeof(Ping).Assembly));
_provider = services.BuildServiceProvider();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ public async Task Should_not_call_constructor_multiple_times_when_using_a_pipeli

services.AddSingleton(output);
services.AddTransient(typeof(IPipelineBehavior<,>), typeof(ConstructorTestBehavior<,>));
services.AddMediatR(typeof(Ping).GetTypeInfo().Assembly);
services.AddMediatR(cfg => cfg.RegisterServicesFromAssembly(typeof(Ping).Assembly));
var provider = services.BuildServiceProvider();

var mediator = provider.GetRequiredService<IMediator>();
Expand Down
Loading

0 comments on commit db8f2f5

Please sign in to comment.