Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Changing IRequest to just be a normal Task instead of Task<Unit> #835

Merged
merged 1 commit into from
Feb 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions samples/MediatR.Examples/JingHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

namespace MediatR.Examples;

public class JingHandler : AsyncRequestHandler<Jing>
public class JingHandler : IRequestHandler<Jing>
{
private readonly TextWriter _writer;

Expand All @@ -13,7 +13,7 @@ public JingHandler(TextWriter writer)
_writer = writer;
}

protected override Task Handle(Jing request, CancellationToken cancellationToken)
public Task Handle(Jing request, CancellationToken cancellationToken)
{
return _writer.WriteLineAsync($"--- Handled Jing: {request.Message}, no Jong");
}
Expand Down
2 changes: 1 addition & 1 deletion src/MediatR.Contracts/IRequest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ namespace MediatR;
/// <summary>
/// Marker interface to represent a request with a void response
/// </summary>
public interface IRequest : IRequest<Unit> { }
public interface IRequest : IBaseRequest { }

/// <summary>
/// Marker interface to represent a request with a response
Expand Down
2 changes: 1 addition & 1 deletion src/MediatR.Contracts/MediatR.Contracts.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
<SymbolPackageFormat>snupkg</SymbolPackageFormat>
<EmbedUntrackedSources>true</EmbedUntrackedSources>
<Deterministic>true</Deterministic>
<Version>1.0.1</Version>
<Version>2.0.0</Version>
<RootNamespace>MediatR</RootNamespace>

</PropertyGroup>
Expand Down
65 changes: 7 additions & 58 deletions src/MediatR/IRequestHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,68 +21,17 @@ public interface IRequestHandler<in TRequest, TResponse>
}

/// <summary>
/// Defines a handler for a request with a void (<see cref="Unit" />) response.
/// You do not need to register this interface explicitly with a container as it inherits from the base <see cref="IRequestHandler{TRequest, TResponse}" /> interface.
/// Defines a handler for a request with a void response.
/// </summary>
/// <typeparam name="TRequest">The type of request being handled</typeparam>
public interface IRequestHandler<in TRequest> : IRequestHandler<TRequest, Unit>
where TRequest : IRequest<Unit>
{
}

/// <summary>
/// Wrapper class for a handler that asynchronously handles a request and does not return a response
/// </summary>
/// <typeparam name="TRequest">The type of request being handled</typeparam>
public abstract class AsyncRequestHandler<TRequest> : IRequestHandler<TRequest>
public interface IRequestHandler<in TRequest>
where TRequest : IRequest
{
async Task<Unit> IRequestHandler<TRequest, Unit>.Handle(TRequest request, CancellationToken cancellationToken)
{
await Handle(request, cancellationToken).ConfigureAwait(false);
return Unit.Value;
}

/// <summary>
/// Override in a derived class for the handler logic
/// </summary>
/// <param name="request">Request</param>
/// <param name="cancellationToken"></param>
/// <returns>Response</returns>
protected abstract Task Handle(TRequest request, CancellationToken cancellationToken);
}

/// <summary>
/// Wrapper class for a handler that synchronously handles a request and returns a response
/// </summary>
/// <typeparam name="TRequest">The type of request being handled</typeparam>
/// <typeparam name="TResponse">The type of response from the handler</typeparam>
public abstract class RequestHandler<TRequest, TResponse> : IRequestHandler<TRequest, TResponse>
where TRequest : IRequest<TResponse>
{
Task<TResponse> IRequestHandler<TRequest, TResponse>.Handle(TRequest request, CancellationToken cancellationToken)
=> Task.FromResult(Handle(request));

/// <summary>
/// Override in a derived class for the handler logic
/// Handles a request
/// </summary>
/// <param name="request">Request</param>
/// <returns>Response</returns>
protected abstract TResponse Handle(TRequest request);
}

/// <summary>
/// Wrapper class for a handler that synchronously handles a request and does not return a response
/// </summary>
/// <typeparam name="TRequest">The type of request being handled</typeparam>
public abstract class RequestHandler<TRequest> : IRequestHandler<TRequest>
where TRequest : IRequest
{
Task<Unit> IRequestHandler<TRequest, Unit>.Handle(TRequest request, CancellationToken cancellationToken)
{
Handle(request);
return Unit.Task;
}

protected abstract void Handle(TRequest request);
/// <param name="request">The request</param>
/// <param name="cancellationToken">Cancellation token</param>
/// <returns>Response from the request</returns>
Task Handle(TRequest request, CancellationToken cancellationToken);
}
9 changes: 9 additions & 0 deletions src/MediatR/ISender.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,15 @@ public interface ISender
/// <returns>A task that represents the send operation. The task result contains the handler response</returns>
Task<TResponse> Send<TResponse>(IRequest<TResponse> request, CancellationToken cancellationToken = default);

/// <summary>
/// Asynchronously send a request to a single handler with no response
/// </summary>
/// <param name="request">Request object</param>
/// <param name="cancellationToken">Optional cancellation token</param>
/// <returns>A task that represents the send operation.</returns>
Task Send<TRequest>(TRequest request, CancellationToken cancellationToken = default)
where TRequest : IRequest;

/// <summary>
/// Asynchronously send an object request to a single handler via dynamic dispatch
/// </summary>
Expand Down
3 changes: 2 additions & 1 deletion src/MediatR/MediatR.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,11 @@
</ItemGroup>

<ItemGroup>
<PackageReference Include="MediatR.Contracts" Version="1.0.1" />
<PackageReference Include="MediatR.Contracts" Version="[2.0.0, 3.0.0)" />
<PackageReference Include="Microsoft.Bcl.AsyncInterfaces" Version="7.0.0" />
<PackageReference Include="Microsoft.Extensions.DependencyInjection.Abstractions" Version="7.0.0" />
<PackageReference Include="Microsoft.SourceLink.GitHub" Version="1.1.1" PrivateAssets="All" />
<PackageReference Include="MinVer" Version="4.2.0" PrivateAssets="All" />
</ItemGroup>

</Project>
41 changes: 37 additions & 4 deletions src/MediatR/Mediator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,23 @@ public Task<TResponse> Send<TResponse>(IRequest<TResponse> request, Cancellation
return handler.Handle(request, _serviceProvider, cancellationToken);
}

public Task Send<TRequest>(TRequest request, CancellationToken cancellationToken = default)
where TRequest : IRequest
{
if (request == null)
{
throw new ArgumentNullException(nameof(request));
}

var requestType = typeof(TRequest);

var handler = (RequestHandlerWrapper)_requestHandlers.GetOrAdd(requestType,
static t => (RequestHandlerBase)(Activator.CreateInstance(typeof(RequestHandlerWrapperImpl<>).MakeGenericType(t))
?? throw new InvalidOperationException($"Could not create wrapper type for {t}")));

return handler.Handle(request, _serviceProvider, cancellationToken);
}

public Task<object?> Send(object request, CancellationToken cancellationToken = default)
{
if (request == null)
Expand All @@ -55,13 +72,29 @@ public Task<TResponse> Send<TResponse>(IRequest<TResponse> request, Cancellation
.GetInterfaces()
.FirstOrDefault(static i => i.IsGenericType && i.GetGenericTypeDefinition() == typeof(IRequest<>));

Type wrapperType;

if (requestInterfaceType is null)
{
throw new ArgumentException($"{requestTypeKey.Name} does not implement {nameof(IRequest)}", nameof(request));
requestInterfaceType = requestTypeKey
.GetInterfaces()
.FirstOrDefault(static i => i == typeof(IRequest));

if (requestInterfaceType is null)
{
throw new ArgumentException($"{requestTypeKey.Name} does not implement {nameof(IRequest)}",
nameof(request));
}

wrapperType =
typeof(RequestHandlerWrapperImpl<>).MakeGenericType(requestTypeKey);
}
else
{
var responseType = requestInterfaceType.GetGenericArguments()[0];
wrapperType =
typeof(RequestHandlerWrapperImpl<,>).MakeGenericType(requestTypeKey, responseType);
}

var responseType = requestInterfaceType.GetGenericArguments()[0];
var wrapperType = typeof(RequestHandlerWrapperImpl<,>).MakeGenericType(requestTypeKey, responseType);

return (RequestHandlerBase)(Activator.CreateInstance(wrapperType)
?? throw new InvalidOperationException($"Could not create wrapper for type {wrapperType}"));
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using MediatR;
using MediatR.Pipeline;
using MediatR.Registration;
Expand Down
1 change: 1 addition & 0 deletions src/MediatR/Registration/ServiceRegistrar.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ public static void AddMediatRClasses(IServiceCollection services, MediatRService
var assembliesToScan = configuration.AssembliesToRegister.Distinct().ToArray();

ConnectImplementationsToTypesClosing(typeof(IRequestHandler<,>), services, assembliesToScan, false, configuration);
ConnectImplementationsToTypesClosing(typeof(IRequestHandler<>), services, assembliesToScan, false, configuration);
ConnectImplementationsToTypesClosing(typeof(INotificationHandler<>), services, assembliesToScan, true, configuration);
ConnectImplementationsToTypesClosing(typeof(IStreamRequestHandler<,>), services, assembliesToScan, false, configuration);
ConnectImplementationsToTypesClosing(typeof(IRequestPreProcessor<>), services, assembliesToScan, true, configuration);
Expand Down
48 changes: 40 additions & 8 deletions src/MediatR/Wrappers/RequestHandlerWrapper.cs
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
using System;
using Microsoft.Extensions.DependencyInjection;

namespace MediatR.Wrappers;

using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.DependencyInjection;

namespace MediatR.Wrappers;

public abstract class RequestHandlerBase
{
public abstract Task<object?> Handle(object request, IServiceProvider serviceProvider,
CancellationToken cancellationToken);

}

public abstract class RequestHandlerWrapper<TResponse> : RequestHandlerBase
Expand All @@ -20,21 +18,55 @@ public abstract Task<TResponse> Handle(IRequest<TResponse> request, IServiceProv
CancellationToken cancellationToken);
}

public abstract class RequestHandlerWrapper : RequestHandlerBase
{
public abstract Task<Unit> Handle(IRequest request, IServiceProvider serviceProvider,
CancellationToken cancellationToken);
}

public class RequestHandlerWrapperImpl<TRequest, TResponse> : RequestHandlerWrapper<TResponse>
where TRequest : IRequest<TResponse>
{
public override async Task<object?> Handle(object request, IServiceProvider serviceProvider,
CancellationToken cancellationToken) =>
await Handle((IRequest<TResponse>)request, serviceProvider, cancellationToken).ConfigureAwait(false);
await Handle((IRequest<TResponse>) request, serviceProvider, cancellationToken).ConfigureAwait(false);

public override Task<TResponse> Handle(IRequest<TResponse> request, IServiceProvider serviceProvider,
CancellationToken cancellationToken)
{
Task<TResponse> Handler() => serviceProvider.GetRequiredService<IRequestHandler<TRequest, TResponse>>().Handle((TRequest) request, cancellationToken);
Task<TResponse> Handler() => serviceProvider.GetRequiredService<IRequestHandler<TRequest, TResponse>>()
.Handle((TRequest) request, cancellationToken);

return serviceProvider
.GetServices<IPipelineBehavior<TRequest, TResponse>>()
.Reverse()
.Aggregate((RequestHandlerDelegate<TResponse>) Handler, (next, pipeline) => () => pipeline.Handle((TRequest)request, next, cancellationToken))();
.Aggregate((RequestHandlerDelegate<TResponse>) Handler,
(next, pipeline) => () => pipeline.Handle((TRequest) request, next, cancellationToken))();
}
}

public class RequestHandlerWrapperImpl<TRequest> : RequestHandlerWrapper
where TRequest : IRequest
{
public override async Task<object?> Handle(object request, IServiceProvider serviceProvider,
CancellationToken cancellationToken) =>
await Handle((IRequest) request, serviceProvider, cancellationToken).ConfigureAwait(false);

public override Task<Unit> Handle(IRequest request, IServiceProvider serviceProvider,
CancellationToken cancellationToken)
{
async Task<Unit> Handler()
{
await serviceProvider.GetRequiredService<IRequestHandler<TRequest>>()
.Handle((TRequest) request, cancellationToken);

return Unit.Value;
}

return serviceProvider
.GetServices<IPipelineBehavior<TRequest, Unit>>()
.Reverse()
.Aggregate((RequestHandlerDelegate<Unit>) Handler,
(next, pipeline) => () => pipeline.Handle((TRequest) request, next, cancellationToken))();
}
}
10 changes: 6 additions & 4 deletions test/MediatR.Tests/ExceptionTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,11 @@ public Task<Pong> Handle(NullPing request, CancellationToken cancellationToken)
}
}

public class VoidNullPingHandler : IRequestHandler<VoidNullPing, Unit>
public class VoidNullPingHandler : IRequestHandler<VoidNullPing>
{
public Task<Unit> Handle(VoidNullPing request, CancellationToken cancellationToken)
public Task Handle(VoidNullPing request, CancellationToken cancellationToken)
{
return Unit.Task;
return Task.CompletedTask;
}
}

Expand Down Expand Up @@ -244,7 +244,7 @@ public class PingException : IRequest

public class PingExceptionHandler : IRequestHandler<PingException>
{
public Task<Unit> Handle(PingException request, CancellationToken cancellationToken)
public Task Handle(PingException request, CancellationToken cancellationToken)
{
throw new NotImplementedException();
}
Expand All @@ -261,6 +261,7 @@ public async Task Should_throw_exception_for_non_generic_send_when_exception_occ
scanner.IncludeNamespaceContainingType<Ping>();
scanner.WithDefaultConventions();
scanner.AddAllTypesOf(typeof(IRequestHandler<,>));
scanner.AddAllTypesOf(typeof(IRequestHandler<>));
});
cfg.For<IMediator>().Use<Mediator>();
});
Expand Down Expand Up @@ -309,6 +310,7 @@ public async Task Should_throw_exception_for_generic_send_when_exception_occurs(
scanner.IncludeNamespaceContainingType<Ping>();
scanner.WithDefaultConventions();
scanner.AddAllTypesOf(typeof(IRequestHandler<,>));
scanner.AddAllTypesOf(typeof(IRequestHandler<>));
});
cfg.For<IMediator>().Use<Mediator>();
});
Expand Down
15 changes: 8 additions & 7 deletions test/MediatR.Tests/GenericTypeConstraintsTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,12 @@ public class Jing : IRequest
public string? Message { get; set; }
}

public class JingHandler : IRequestHandler<Jing, Unit>
public class JingHandler : IRequestHandler<Jing>
{
public Task<Unit> Handle(Jing request, CancellationToken cancellationToken)
public Task Handle(Jing request, CancellationToken cancellationToken)
{
// empty handle
return Unit.Task;
return Task.CompletedTask;
}
}

Expand Down Expand Up @@ -98,6 +98,7 @@ public GenericTypeConstraintsTests()
scanner.IncludeNamespaceContainingType<Jing>();
scanner.WithDefaultConventions();
scanner.AddAllTypesOf(typeof(IRequestHandler<,>));
scanner.AddAllTypesOf(typeof(IRequestHandler<>));
});
cfg.For<IMediator>().Use<Mediator>();
});
Expand All @@ -119,15 +120,15 @@ public async Task Should_Resolve_Void_Return_Request()

// Assert it is of type IRequest and IRequest<T>
Assert.True(genericTypeConstraintsVoidReturn.IsIRequest);
Assert.True(genericTypeConstraintsVoidReturn.IsIRequestT);
Assert.False(genericTypeConstraintsVoidReturn.IsIRequestT);
Assert.True(genericTypeConstraintsVoidReturn.IsIBaseRequest);

// Verify it is of IRequest and IBaseRequest and IRequest<Unit>
// Verify it is of IRequest and IBaseRequest
var results = genericTypeConstraintsVoidReturn.Handle(jing);

Assert.Equal(3, results.Length);
Assert.Equal(2, results.Length);

results.ShouldContain(typeof(IRequest<Unit>));
results.ShouldNotContain(typeof(IRequest<Unit>));
results.ShouldContain(typeof(IBaseRequest));
results.ShouldContain(typeof(IRequest));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ public void ShouldResolveRequestHandler()
[Fact]
public void ShouldResolveInternalHandler()
{
_provider.GetService<IRequestHandler<InternalPing, Unit>>().ShouldNotBeNull();
_provider.GetService<IRequestHandler<InternalPing>>().ShouldNotBeNull();
}

[Fact]
Expand Down
Loading