From 14eed746457d124968fbbd5b6af442468892d83c Mon Sep 17 00:00:00 2001 From: Safia Abdalla Date: Mon, 14 Feb 2022 16:52:24 +0000 Subject: [PATCH 1/5] Add support for endpoint filters in minimal APIs --- .../src/IRouteHandlerFilter.cs | 19 ++ .../src/PublicAPI.Unshipped.txt | 6 + .../src/RouteHandlerFilterContext.cs | 35 +++ .../src/PublicAPI.Unshipped.txt | 4 + .../src/RequestDelegateFactory.cs | 131 ++++++++-- .../test/RequestDelegateFactoryTests.cs | 244 ++++++++++++++++++ .../src/Builder/DelegateRouteHandlerFilter.cs | 19 ++ .../Builder/EndpointRouteBuilderExtensions.cs | 23 +- .../src/Builder/RouteHandlerBuilder.cs | 4 + .../Builder/RouteHandlerFilterExtensions.cs | 49 ++++ src/Http/Routing/src/PublicAPI.Unshipped.txt | 5 + src/Http/Routing/src/RouteEndpointBuilder.cs | 18 ++ 12 files changed, 529 insertions(+), 28 deletions(-) create mode 100644 src/Http/Http.Abstractions/src/IRouteHandlerFilter.cs create mode 100644 src/Http/Http.Abstractions/src/RouteHandlerFilterContext.cs create mode 100644 src/Http/Routing/src/Builder/DelegateRouteHandlerFilter.cs create mode 100644 src/Http/Routing/src/Builder/RouteHandlerFilterExtensions.cs diff --git a/src/Http/Http.Abstractions/src/IRouteHandlerFilter.cs b/src/Http/Http.Abstractions/src/IRouteHandlerFilter.cs new file mode 100644 index 000000000000..14fbc7af950d --- /dev/null +++ b/src/Http/Http.Abstractions/src/IRouteHandlerFilter.cs @@ -0,0 +1,19 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.AspNetCore.Http; + +/// +/// Provides an interface for implementing a filter targetting a route handler. +/// +public interface IRouteHandlerFilter +{ + /// + /// Implements the core logic associated with the filter given a + /// and the next filter to call in the pipeline. + /// + /// The associated with the current request/response. + /// The next filter in the pipeline. + /// The result of calling the current filter. + abstract ValueTask InvokeAsync(RouteHandlerFilterContext context, Func> next); +} diff --git a/src/Http/Http.Abstractions/src/PublicAPI.Unshipped.txt b/src/Http/Http.Abstractions/src/PublicAPI.Unshipped.txt index b7318f416a16..bee1f68cba6d 100644 --- a/src/Http/Http.Abstractions/src/PublicAPI.Unshipped.txt +++ b/src/Http/Http.Abstractions/src/PublicAPI.Unshipped.txt @@ -1,6 +1,8 @@ #nullable enable *REMOVED*abstract Microsoft.AspNetCore.Http.HttpResponse.ContentType.get -> string! Microsoft.AspNetCore.Http.EndpointMetadataCollection.GetRequiredMetadata() -> T! +Microsoft.AspNetCore.Http.RouteHandlerFilterContext.RouteHandlerFilterContext(Microsoft.AspNetCore.Http.HttpContext! httpContext, params object![]! parameters) -> void +Microsoft.AspNetCore.Http.IRouteHandlerFilter.InvokeAsync(Microsoft.AspNetCore.Http.RouteHandlerFilterContext! context, System.Func>! next) -> System.Threading.Tasks.ValueTask Microsoft.AspNetCore.Http.Metadata.IFromFormMetadata Microsoft.AspNetCore.Http.Metadata.IFromFormMetadata.Name.get -> string? Microsoft.AspNetCore.Routing.RouteValueDictionary.RouteValueDictionary(Microsoft.AspNetCore.Routing.RouteValueDictionary? dictionary) -> void @@ -8,3 +10,7 @@ Microsoft.AspNetCore.Routing.RouteValueDictionary.RouteValueDictionary(System.Co Microsoft.AspNetCore.Routing.RouteValueDictionary.RouteValueDictionary(System.Collections.Generic.IEnumerable>? values) -> void abstract Microsoft.AspNetCore.Http.HttpResponse.ContentType.get -> string? Microsoft.AspNetCore.Http.Metadata.ISkipStatusCodePagesMetadata +Microsoft.AspNetCore.Http.RouteHandlerFilterContext +Microsoft.AspNetCore.Http.RouteHandlerFilterContext.HttpContext.get -> Microsoft.AspNetCore.Http.HttpContext! +Microsoft.AspNetCore.Http.RouteHandlerFilterContext.Parameters.get -> System.Collections.Generic.IList! +Microsoft.AspNetCore.Http.IRouteHandlerFilter diff --git a/src/Http/Http.Abstractions/src/RouteHandlerFilterContext.cs b/src/Http/Http.Abstractions/src/RouteHandlerFilterContext.cs new file mode 100644 index 000000000000..558d97cbd06b --- /dev/null +++ b/src/Http/Http.Abstractions/src/RouteHandlerFilterContext.cs @@ -0,0 +1,35 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.AspNetCore.Http; + +/// +/// Provides an abstraction for wrapping the and parameters +/// provided to a route handler. +/// +public class RouteHandlerFilterContext +{ + /// + /// Creates a new instance of the for a given request. + /// + /// The associated with the current request. + /// A list of parameters provided in the current request. + public RouteHandlerFilterContext(HttpContext httpContext, params object[] parameters) + { + HttpContext = httpContext; + Parameters = parameters; + } + + /// + /// The associated with the current request being processed by the filter. + /// + public HttpContext HttpContext { get; } + + /// + /// A list of parameters provided in the current request to the filter. + /// + /// This list is not read-only to premit modifying of existing parameters by filters. + /// + /// + public IList Parameters { get; } +} diff --git a/src/Http/Http.Extensions/src/PublicAPI.Unshipped.txt b/src/Http/Http.Extensions/src/PublicAPI.Unshipped.txt index 1030d0f0793e..a32059ca51d5 100644 --- a/src/Http/Http.Extensions/src/PublicAPI.Unshipped.txt +++ b/src/Http/Http.Extensions/src/PublicAPI.Unshipped.txt @@ -1,3 +1,7 @@ #nullable enable Microsoft.Extensions.DependencyInjection.RouteHandlerJsonServiceExtensions +*REMOVED*static Microsoft.AspNetCore.Http.RequestDelegateFactory.Create(System.Delegate! handler, Microsoft.AspNetCore.Http.RequestDelegateFactoryOptions? options = null) -> Microsoft.AspNetCore.Http.RequestDelegateResult! +*REMOVED*static Microsoft.AspNetCore.Http.RequestDelegateFactory.Create(System.Reflection.MethodInfo! methodInfo, System.Func? targetFactory = null, Microsoft.AspNetCore.Http.RequestDelegateFactoryOptions? options = null) -> Microsoft.AspNetCore.Http.RequestDelegateResult! +static Microsoft.AspNetCore.Http.RequestDelegateFactory.Create(System.Delegate! handler, Microsoft.AspNetCore.Http.RequestDelegateFactoryOptions? options = null, System.Collections.Generic.IEnumerable? filters = null) -> Microsoft.AspNetCore.Http.RequestDelegateResult! +static Microsoft.AspNetCore.Http.RequestDelegateFactory.Create(System.Reflection.MethodInfo! methodInfo, System.Func? targetFactory = null, Microsoft.AspNetCore.Http.RequestDelegateFactoryOptions? options = null, System.Collections.Generic.IEnumerable? filters = null) -> Microsoft.AspNetCore.Http.RequestDelegateResult! static Microsoft.Extensions.DependencyInjection.RouteHandlerJsonServiceExtensions.ConfigureRouteHandlerJsonOptions(this Microsoft.Extensions.DependencyInjection.IServiceCollection! services, System.Action! configureOptions) -> Microsoft.Extensions.DependencyInjection.IServiceCollection! diff --git a/src/Http/Http.Extensions/src/RequestDelegateFactory.cs b/src/Http/Http.Extensions/src/RequestDelegateFactory.cs index 954e8c94a082..7fa29589b641 100644 --- a/src/Http/Http.Extensions/src/RequestDelegateFactory.cs +++ b/src/Http/Http.Extensions/src/RequestDelegateFactory.cs @@ -39,6 +39,7 @@ public static partial class RequestDelegateFactory private static readonly MethodInfo ResultWriteResponseAsyncMethod = typeof(RequestDelegateFactory).GetMethod(nameof(ExecuteResultWriteResponse), BindingFlags.NonPublic | BindingFlags.Static)!; private static readonly MethodInfo StringResultWriteResponseAsyncMethod = typeof(RequestDelegateFactory).GetMethod(nameof(ExecuteWriteStringResponseAsync), BindingFlags.NonPublic | BindingFlags.Static)!; private static readonly MethodInfo StringIsNullOrEmptyMethod = typeof(string).GetMethod(nameof(string.IsNullOrEmpty), BindingFlags.Static | BindingFlags.Public)!; + private static readonly MethodInfo WrapObjectAsValueTaskMethod = typeof(RequestDelegateFactory).GetMethod(nameof(WrapObjectAsValueTask), BindingFlags.NonPublic | BindingFlags.Static)!; // Call WriteAsJsonAsync() to serialize the runtime return type rather than the declared return type. // https://docs.microsoft.com/en-us/dotnet/standard/serialization/system-text-json-polymorphism @@ -71,12 +72,21 @@ public static partial class RequestDelegateFactory private static readonly MemberExpression FormFilesExpr = Expression.Property(FormExpr, typeof(IFormCollection).GetProperty(nameof(IFormCollection.Files))!); private static readonly MemberExpression StatusCodeExpr = Expression.Property(HttpResponseExpr, typeof(HttpResponse).GetProperty(nameof(HttpResponse.StatusCode))!); private static readonly MemberExpression CompletedTaskExpr = Expression.Property(null, (PropertyInfo)GetMemberInfo>(() => Task.CompletedTask)); + private static readonly NewExpression CompletedValueTaskExpr = Expression.New(typeof(ValueTask).GetConstructor(new[] { typeof(Task) })!, CompletedTaskExpr); private static readonly ParameterExpression TempSourceStringExpr = ParameterBindingMethodCache.TempSourceStringExpr; private static readonly BinaryExpression TempSourceStringNotNullExpr = Expression.NotEqual(TempSourceStringExpr, Expression.Constant(null)); private static readonly BinaryExpression TempSourceStringNullExpr = Expression.Equal(TempSourceStringExpr, Expression.Constant(null)); private static readonly UnaryExpression TempSourceStringIsNotNullOrEmptyExpr = Expression.Not(Expression.Call(StringIsNullOrEmptyMethod, TempSourceStringExpr)); + private static readonly ConstructorInfo RouteHandlerFilterContextConstructor = typeof(RouteHandlerFilterContext).GetConstructors().Single(); + private static readonly ParameterExpression FilterContextExpr = Expression.Parameter(typeof(RouteHandlerFilterContext), "context"); + private static readonly MemberExpression FilterContextParametersExpr = Expression.Property(FilterContextExpr, typeof(RouteHandlerFilterContext).GetProperty(nameof(RouteHandlerFilterContext.Parameters))!); + private static readonly MemberExpression FilterContextHttpContextExpr = Expression.Property(FilterContextExpr, typeof(RouteHandlerFilterContext).GetProperty(nameof(RouteHandlerFilterContext.HttpContext))!); + private static readonly MemberExpression FilterContextHttpContextResponseExpr = Expression.Property(FilterContextHttpContextExpr, typeof(HttpContext).GetProperty(nameof(HttpContext.Response))!); + private static readonly MemberExpression FilterContextHttpContextStatusCodeExpr = Expression.Property(FilterContextHttpContextResponseExpr, typeof(HttpResponse).GetProperty(nameof(HttpResponse.StatusCode))!); + private static readonly ParameterExpression InvokedFilterContextExpr = Expression.Parameter(typeof(RouteHandlerFilterContext), "filterContext"); + private static readonly string[] DefaultAcceptsContentType = new[] { "application/json" }; private static readonly string[] FormFileContentType = new[] { "multipart/form-data" }; @@ -85,9 +95,10 @@ public static partial class RequestDelegateFactory /// /// A request handler with any number of custom parameters that often produces a response with its return value. /// The used to configure the behavior of the handler. + /// A collection of s invoked in the endpoint associated with this handler. /// The . #pragma warning disable RS0026 // Do not add multiple public overloads with optional parameters - public static RequestDelegateResult Create(Delegate handler, RequestDelegateFactoryOptions? options = null) + public static RequestDelegateResult Create(Delegate handler, RequestDelegateFactoryOptions? options = null, IEnumerable? filters = null) #pragma warning restore RS0026 // Do not add multiple public overloads with optional parameters { if (handler is null) @@ -102,6 +113,11 @@ public static RequestDelegateResult Create(Delegate handler, RequestDelegateFact }; var factoryContext = CreateFactoryContext(options); + if (filters is not null) + { + factoryContext.Filters.AddRange(filters); + } + var targetableRequestDelegate = CreateTargetableRequestDelegate(handler.Method, targetExpression, factoryContext); return new RequestDelegateResult(httpContext => targetableRequestDelegate(handler.Target, httpContext), factoryContext.Metadata); @@ -113,9 +129,10 @@ public static RequestDelegateResult Create(Delegate handler, RequestDelegateFact /// A request handler with any number of custom parameters that often produces a response with its return value. /// Creates the for the non-static method. /// The used to configure the behavior of the handler. + /// A collection of s invoked in the endpoint associated with this handler. /// The . #pragma warning disable RS0026 // Do not add multiple public overloads with optional parameters - public static RequestDelegateResult Create(MethodInfo methodInfo, Func? targetFactory = null, RequestDelegateFactoryOptions? options = null) + public static RequestDelegateResult Create(MethodInfo methodInfo, Func? targetFactory = null, RequestDelegateFactoryOptions? options = null, IEnumerable? filters = null) #pragma warning restore RS0026 // Do not add multiple public overloads with optional parameters { if (methodInfo is null) @@ -129,6 +146,10 @@ public static RequestDelegateResult Create(MethodInfo methodInfo, Func 0) + { + var filterPipeline = CreateFilterPipeline(methodInfo, targetExpression, factoryContext); + Expression>> invokePipeline = (context) => filterPipeline(context); + returnType = typeof(ValueTask); + // var filterContext = new RouteHandlerFilterContext(httpContext, new[] { (object)name_local, (object)int_local }); + // invokePipeline.Invoke(filterContext); + factoryContext.MethodCall = Expression.Block( + new[] { InvokedFilterContextExpr }, + Expression.Assign( + InvokedFilterContextExpr, + Expression.New(RouteHandlerFilterContextConstructor, + new Expression[] { HttpContextExpr, Expression.NewArrayInit(typeof(object), factoryContext.BoxedArgs) })), + Expression.Invoke(invokePipeline, InvokedFilterContextExpr) + ); + } var responseWritingMethodCall = factoryContext.ParamCheckExpressions.Count > 0 ? - CreateParamCheckingResponseWritingMethodCall(methodInfo, targetExpression, arguments, factoryContext) : - CreateResponseWritingMethodCall(methodInfo, targetExpression, arguments); + CreateParamCheckingResponseWritingMethodCall(returnType, targetExpression, arguments, factoryContext) : + AddResponseWritingToMethodCall(factoryContext.MethodCall, returnType); if (factoryContext.UsingTempSourceString) { @@ -189,6 +231,34 @@ private static FactoryContext CreateFactoryContext(RequestDelegateFactoryOptions return HandleRequestBodyAndCompileRequestDelegate(responseWritingMethodCall, factoryContext); } + private static Func> CreateFilterPipeline(MethodInfo methodInfo, Expression? target, FactoryContext factoryContext) + { + // httpContext.Response.StatusCode == 400 + // ? Task.CompletedTask + // : handler((string)context.Parameters[0], (int)context.Parameters[1]) + var filteredInvocation = Expression.Lambda>>( + Expression.Condition( + Expression.Equal(FilterContextHttpContextStatusCodeExpr, Expression.Constant(400)), + CompletedValueTaskExpr, + Expression.Block( + new[] { TargetExpr }, + Expression.Call(WrapObjectAsValueTaskMethod, + target is null + ? Expression.Call(methodInfo, factoryContext.ContextArgAccess) + : Expression.Call(target, methodInfo, factoryContext.ContextArgAccess)) + )), + FilterContextExpr).Compile(); + + for (int i = factoryContext.Filters.Count - 1; i >= 0; i--) + { + var currentFilter = factoryContext.Filters[i]; + var nextFilter = filteredInvocation; + filteredInvocation = (RouteHandlerFilterContext context) => currentFilter.InvokeAsync(context, nextFilter); + + } + return filteredInvocation; + } + private static Expression[] CreateArguments(ParameterInfo[]? parameters, FactoryContext factoryContext) { if (parameters is null || parameters.Length == 0) @@ -201,6 +271,16 @@ private static Expression[] CreateArguments(ParameterInfo[]? parameters, Factory for (var i = 0; i < parameters.Length; i++) { args[i] = CreateArgument(parameters[i], factoryContext); + // Register expressions containing the boxed and unboxed variants + // of the route handler's arguments for use in RouteHandlerFilterContext + // construction and route handler invocation. + // (string)context.Parameters[0]; + factoryContext.ContextArgAccess.Add( + Expression.Convert( + Expression.Property(FilterContextParametersExpr, "Item", Expression.Constant(i)), + parameters[i].ParameterType)); + // (object)name_local + factoryContext.BoxedArgs.Add(Expression.Convert(args[i], typeof(object))); } if (factoryContext.HasInferredBody && factoryContext.DisableInferredFromBody) @@ -381,16 +461,15 @@ target is null ? Expression.Call(methodInfo, arguments) : Expression.Call(target, methodInfo, arguments); - private static Expression CreateResponseWritingMethodCall(MethodInfo methodInfo, Expression? target, Expression[] arguments) + private static ValueTask WrapObjectAsValueTask(object? obj) { - var callMethod = CreateMethodCall(methodInfo, target, arguments); - return AddResponseWritingToMethodCall(callMethod, methodInfo.ReturnType); + return ValueTask.FromResult(obj); } // If we're calling TryParse or validating parameter optionality and // wasParamCheckFailure indicates it failed, set a 400 StatusCode instead of calling the method. private static Expression CreateParamCheckingResponseWritingMethodCall( - MethodInfo methodInfo, Expression? target, Expression[] arguments, FactoryContext factoryContext) + Type returnType, Expression? target, Expression[] arguments, FactoryContext factoryContext) { // { // string tempSourceString; @@ -440,17 +519,27 @@ private static Expression CreateParamCheckingResponseWritingMethodCall( localVariables[factoryContext.ExtraLocals.Count] = WasParamCheckFailureExpr; - var set400StatusAndReturnCompletedTask = Expression.Block( - Expression.Assign(StatusCodeExpr, Expression.Constant(400)), - CompletedTaskExpr); - - var methodCall = CreateMethodCall(methodInfo, target, arguments); - - var checkWasParamCheckFailure = Expression.Condition(WasParamCheckFailureExpr, - set400StatusAndReturnCompletedTask, - AddResponseWritingToMethodCall(methodCall, methodInfo.ReturnType)); + if (factoryContext.Filters.Count > 0) + { + var checkWasParamCheckFailureWithFilters = Expression.Block( + Expression.IfThen( + WasParamCheckFailureExpr, + Expression.Assign(StatusCodeExpr, Expression.Constant(400))), + AddResponseWritingToMethodCall(factoryContext.MethodCall!, returnType) + ); - checkParamAndCallMethod[factoryContext.ParamCheckExpressions.Count] = checkWasParamCheckFailure; + checkParamAndCallMethod[factoryContext.ParamCheckExpressions.Count] = checkWasParamCheckFailureWithFilters; + } + else + { + var checkWasParamCheckFailure = Expression.Condition( + WasParamCheckFailureExpr, + Expression.Block( + Expression.Assign(StatusCodeExpr, Expression.Constant(400)), + CompletedTaskExpr), + AddResponseWritingToMethodCall(factoryContext.MethodCall!, returnType)); + checkParamAndCallMethod[factoryContext.ParamCheckExpressions.Count] = checkWasParamCheckFailure; + } return Expression.Block(localVariables, checkParamAndCallMethod); } @@ -1596,6 +1685,12 @@ private class FactoryContext public bool ReadForm { get; set; } public ParameterInfo? FirstFormRequestBodyParameter { get; set; } + // Properties for constructing and managing filters + public List ContextArgAccess { get; } = new(); + public Expression? MethodCall { get; set; } + public List BoxedArgs { get; } = new(); + public List Filters { get; set; } = new(); + } private static class RequestDelegateFactoryConstants diff --git a/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs b/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs index 44f88b26e236..bf8792b5a8b7 100644 --- a/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs +++ b/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs @@ -4200,6 +4200,178 @@ void TestAction(IFormFile file) Assert.Equal(400, badHttpRequestException.StatusCode); } + [Fact] + public async Task RequestDelegateFactory_InvokesFiltersButNotHandler_OnArgumentError() + { + var invoked = false; + // Arrange + string HelloName(string name) + { + invoked = true; + return $"Hello, {name}!"; + }; + + var httpContext = CreateHttpContext(); + + // Act + var factoryResult = RequestDelegateFactory.Create(HelloName, null, new List() { new ModifyStringArgumentFilter() }); + var requestDelegate = factoryResult.RequestDelegate; + await requestDelegate(httpContext); + + // Assert + Assert.False(invoked); + Assert.Equal(400, httpContext.Response.StatusCode); + } + + [Fact] + public async Task RequestDelegateFactory_CanInvokeSingleEndpointFilter_ThatProvidesCustomErrorMessage() + { + // Arrange + string HelloName(string name) + { + return $"Hello, {name}!"; + }; + + var httpContext = CreateHttpContext(); + + var responseBodyStream = new MemoryStream(); + httpContext.Response.Body = responseBodyStream; + + // Act + var factoryResult = RequestDelegateFactory.Create(HelloName, null, new List() { new ProvideCustomErrorMessageFilter() }); + var requestDelegate = factoryResult.RequestDelegate; + await requestDelegate(httpContext); + + // Assert + var decodedResponseBody = JsonSerializer.Deserialize(responseBodyStream.ToArray()); + Assert.Equal(400, httpContext.Response.StatusCode); + Assert.Equal("New response", decodedResponseBody!.Detail); + } + + [Fact] + public async Task RequestDelegateFactory_CanInvokeMultipleEndpointFilters_ThatTouchArguments() + { + // Arrange + string HelloName(string name, int age) + { + return $"Hello, {name}! You are {age} years old."; + }; + + var loggerInvoked = 0; + void Log(string arg) => loggerInvoked++; + + var httpContext = CreateHttpContext(); + + var responseBodyStream = new MemoryStream(); + httpContext.Response.Body = responseBodyStream; + + httpContext.Request.Query = new QueryCollection(new Dictionary + { + ["name"] = "TestName", + ["age"] = "25" + }); + + // Act + var factoryResult = RequestDelegateFactory.Create(HelloName, null, new List() { new ModifyIntArgumentFilter(), new LogArgumentsFilter(Log) }); + var requestDelegate = factoryResult.RequestDelegate; + await requestDelegate(httpContext); + + // Assert + var responseBody = Encoding.UTF8.GetString(responseBodyStream.ToArray()); + Assert.Equal("Hello, TestName! You are 27 years old.", responseBody); + Assert.Equal(2, loggerInvoked); + } + + [Fact] + public async Task RequestDelegateFactory_CanInvokeSingleEndpointFilter_ThatModifiesBodyParameter() + { + // Arrange + Todo todo = new Todo() { Name = "Write tests", IsComplete = true }; + string PrintTodo(Todo todo) + { + return $"{todo.Name} is {(todo.IsComplete ? "done" : "not done")}."; + }; + + var httpContext = CreateHttpContext(); + + var requestBodyBytes = JsonSerializer.SerializeToUtf8Bytes(todo); + var stream = new MemoryStream(requestBodyBytes); + httpContext.Request.Body = stream; + httpContext.Request.Headers["Content-Type"] = "application/json"; + httpContext.Request.Headers["Content-Length"] = stream.Length.ToString(CultureInfo.InvariantCulture); + httpContext.Features.Set(new RequestBodyDetectionFeature(true)); + + var responseBodyStream = new MemoryStream(); + httpContext.Response.Body = responseBodyStream; + + // Act + var factoryResult = RequestDelegateFactory.Create(PrintTodo, null, new List() { new ModifyTodoArgumentFilter() }); + var requestDelegate = factoryResult.RequestDelegate; + await requestDelegate(httpContext); + + // Assert + var responseBody = Encoding.UTF8.GetString(responseBodyStream.ToArray()); + Assert.Equal("Write tests is not done.", responseBody); + } + + [Fact] + public async Task RequestDelegateFactory_CanInvokeSingleEndpointFilter_ThatModifiesResult() + { + // Arrange + string HelloName(string name) + { + return $"Hello, {name}!"; + }; + + var httpContext = CreateHttpContext(); + + var responseBodyStream = new MemoryStream(); + httpContext.Response.Body = responseBodyStream; + + httpContext.Request.Query = new QueryCollection(new Dictionary + { + ["name"] = "TestName" + }); + + // Act + var factoryResult = RequestDelegateFactory.Create(HelloName, null, new List() { new ModifyStringResultFilter() }); + var requestDelegate = factoryResult.RequestDelegate; + await requestDelegate(httpContext); + + // Assert + var responseBody = Encoding.UTF8.GetString(responseBodyStream.ToArray()); + Assert.Equal("HELLO, TESTNAME!", responseBody); + } + + [Fact] + public async Task RequestDelegateFactory_CanInvokeMultipleEndpointFilters_ThatModifyArgumentsAndResult() + { + // Arrange + string HelloName(string name) + { + return $"Hello, {name}!"; + }; + + var httpContext = CreateHttpContext(); + + var responseBodyStream = new MemoryStream(); + httpContext.Response.Body = responseBodyStream; + + httpContext.Request.Query = new QueryCollection(new Dictionary + { + ["name"] = "TestName" + }); + + // Act + var factoryResult = RequestDelegateFactory.Create(HelloName, null, new List() { new ModifyStringResultFilter(), new ModifyStringArgumentFilter() }); + var requestDelegate = factoryResult.RequestDelegate; + await requestDelegate(httpContext); + + // Assert + var responseBody = Encoding.UTF8.GetString(responseBodyStream.ToArray()); + Assert.Equal("HELLO, TESTNAMEPREFIX!", responseBody); + } + private DefaultHttpContext CreateHttpContext() { var responseFeature = new TestHttpResponseFeature(); @@ -4559,6 +4731,78 @@ public TlsConnectionFeature(X509Certificate2 clientCertificate) throw new NotImplementedException(); } } + + private class ModifyStringArgumentFilter : IRouteHandlerFilter + { + public async ValueTask InvokeAsync(RouteHandlerFilterContext context, Func> next) + { + context.Parameters[0] = context.Parameters[0] != null ? $"{((string)context.Parameters[0]!)}Prefix" : "NULL"; + return await next(context); + } + } + + private class ModifyIntArgumentFilter : IRouteHandlerFilter + { + public async ValueTask InvokeAsync(RouteHandlerFilterContext context, Func> next) + { + context.Parameters[1] = ((int)context.Parameters[1]!) + 2; + return await next(context); + } + } + + private class ModifyTodoArgumentFilter : IRouteHandlerFilter + { + public async ValueTask InvokeAsync(RouteHandlerFilterContext context, Func> next) + { + Todo originalTodo = (Todo)context.Parameters[0]!; + originalTodo!.IsComplete = !originalTodo.IsComplete; + context.Parameters[0] = originalTodo; + return await next(context); + } + } + + private class ProvideCustomErrorMessageFilter : IRouteHandlerFilter + { + public async ValueTask InvokeAsync(RouteHandlerFilterContext context, Func> next) + { + if (context.HttpContext.Response.StatusCode == 400) + { + return Results.Problem("New response", statusCode: 400); + } + return await next(context); + } + } + + private class LogArgumentsFilter : IRouteHandlerFilter + { + private Action _logger; + + public LogArgumentsFilter(Action logger) + { + _logger = logger; + } + public async ValueTask InvokeAsync(RouteHandlerFilterContext context, Func> next) + { + foreach (var parameter in context.Parameters) + { + _logger(parameter!.ToString() ?? "no arg"); + } + return await next(context); + } + } + + private class ModifyStringResultFilter : IRouteHandlerFilter + { + public async ValueTask InvokeAsync(RouteHandlerFilterContext context, Func> next) + { + var previousResult = await next(context); + if (previousResult is string stringResult) + { + return stringResult.ToUpperInvariant(); + } + return previousResult; + } + } } internal static class TestExtensionResults diff --git a/src/Http/Routing/src/Builder/DelegateRouteHandlerFilter.cs b/src/Http/Routing/src/Builder/DelegateRouteHandlerFilter.cs new file mode 100644 index 000000000000..9872915904a8 --- /dev/null +++ b/src/Http/Routing/src/Builder/DelegateRouteHandlerFilter.cs @@ -0,0 +1,19 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.AspNetCore.Http; + +internal class DelegateRouteHandlerFilter : IRouteHandlerFilter +{ + private readonly Func>, ValueTask> _routeHandlerFilter; + + internal DelegateRouteHandlerFilter(Func>, ValueTask> routeHandlerFilter) + { + _routeHandlerFilter = routeHandlerFilter; + } + + public ValueTask InvokeAsync(RouteHandlerFilterContext context, Func> next) + { + return _routeHandlerFilter(context, next); + } +} diff --git a/src/Http/Routing/src/Builder/EndpointRouteBuilderExtensions.cs b/src/Http/Routing/src/Builder/EndpointRouteBuilderExtensions.cs index 4c1cb9b09904..e43bf6145869 100644 --- a/src/Http/Routing/src/Builder/EndpointRouteBuilderExtensions.cs +++ b/src/Http/Routing/src/Builder/EndpointRouteBuilderExtensions.cs @@ -492,10 +492,7 @@ private static RouteHandlerBuilder Map( DisableInferBodyFromParameters = disableInferBodyFromParameters, }; - var requestDelegateResult = RequestDelegateFactory.Create(handler, options); - var builder = new RouteEndpointBuilder( - requestDelegateResult.RequestDelegate, pattern, defaultOrder) { @@ -521,12 +518,6 @@ private static RouteHandlerBuilder Map( // Add delegate attributes as metadata var attributes = handler.Method.GetCustomAttributes(); - // Add add request delegate metadata - foreach (var metadata in requestDelegateResult.EndpointMetadata) - { - builder.Metadata.Add(metadata); - } - // This can be null if the delegate is a dynamic method or compiled from an expression tree if (attributes is not null) { @@ -543,6 +534,18 @@ private static RouteHandlerBuilder Map( endpoints.DataSources.Add(dataSource); } - return new RouteHandlerBuilder(dataSource.AddEndpointBuilder(builder)); + var routeHandlerBuilder = new RouteHandlerBuilder(dataSource.AddEndpointBuilder(builder)); + routeHandlerBuilder.Add(endpointBuilder => + { + var filteredRequestDelegateResult = RequestDelegateFactory.Create(handler, options, routeHandlerBuilder.RouteHandlerFilters); + // Add add request delegate metadata + foreach (var metadata in filteredRequestDelegateResult.EndpointMetadata) + { + endpointBuilder.Metadata.Add(metadata); + } + endpointBuilder.RequestDelegate = filteredRequestDelegateResult.RequestDelegate; + }); + + return routeHandlerBuilder; } } diff --git a/src/Http/Routing/src/Builder/RouteHandlerBuilder.cs b/src/Http/Routing/src/Builder/RouteHandlerBuilder.cs index 40b896f2db24..b42e22cc3d8d 100644 --- a/src/Http/Routing/src/Builder/RouteHandlerBuilder.cs +++ b/src/Http/Routing/src/Builder/RouteHandlerBuilder.cs @@ -1,6 +1,8 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using Microsoft.AspNetCore.Http; + namespace Microsoft.AspNetCore.Builder; /// @@ -11,6 +13,8 @@ public sealed class RouteHandlerBuilder : IEndpointConventionBuilder private readonly IEnumerable? _endpointConventionBuilders; private readonly IEndpointConventionBuilder? _endpointConventionBuilder; + internal List RouteHandlerFilters { get; } = new(); + /// /// Instantiates a new given a single /// . diff --git a/src/Http/Routing/src/Builder/RouteHandlerFilterExtensions.cs b/src/Http/Routing/src/Builder/RouteHandlerFilterExtensions.cs new file mode 100644 index 000000000000..ffec088f3e73 --- /dev/null +++ b/src/Http/Routing/src/Builder/RouteHandlerFilterExtensions.cs @@ -0,0 +1,49 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Diagnostics.CodeAnalysis; +using Microsoft.AspNetCore.Builder; + +namespace Microsoft.AspNetCore.Http; + +/// +/// Extension methods for adding to a route handler. +/// +public static class RouteHandlerFilterExtensions +{ + /// + /// Registers a filter onto the route handler. + /// + /// The . + /// The to register. + /// A that can be used to further customize the route handler. + public static RouteHandlerBuilder AddFilter(this RouteHandlerBuilder builder, IRouteHandlerFilter filter) + { + builder.RouteHandlerFilters.Add(filter); + return builder; + } + + /// + /// Registers a filter of type onto the route handler. + /// + /// The type of the to register. + /// The . + /// A that can be used to further customize the route handler. + public static RouteHandlerBuilder AddFilter<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)] TFilterType>(this RouteHandlerBuilder builder) where TFilterType : IRouteHandlerFilter, new() + { + builder.RouteHandlerFilters.Add(new TFilterType()); + return builder; + } + + /// + /// Registers a filter given a delegate onto the route handler. + /// + /// The . + /// A representing the core logic of the filter. + /// A that can be used to further customize the route handler. + public static RouteHandlerBuilder AddFilter(this RouteHandlerBuilder builder, Func>, ValueTask> routeHandlerFilter) + { + builder.RouteHandlerFilters.Add(new DelegateRouteHandlerFilter(routeHandlerFilter)); + return builder; + } +} diff --git a/src/Http/Routing/src/PublicAPI.Unshipped.txt b/src/Http/Routing/src/PublicAPI.Unshipped.txt index 341f2446a1cf..1aa4f87f6a32 100644 --- a/src/Http/Routing/src/PublicAPI.Unshipped.txt +++ b/src/Http/Routing/src/PublicAPI.Unshipped.txt @@ -1,4 +1,6 @@ #nullable enable +Microsoft.AspNetCore.Http.RouteHandlerFilterExtensions +Microsoft.AspNetCore.Routing.RouteEndpointBuilder.RouteEndpointBuilder(Microsoft.AspNetCore.Routing.Patterns.RoutePattern! routePattern, int order) -> void Microsoft.AspNetCore.Routing.RouteOptions.SetParameterPolicy(string! token, System.Type! type) -> void Microsoft.AspNetCore.Routing.RouteOptions.SetParameterPolicy(string! token) -> void static Microsoft.AspNetCore.Builder.EndpointRouteBuilderExtensions.MapPatch(this Microsoft.AspNetCore.Routing.IEndpointRouteBuilder! endpoints, string! pattern, System.Delegate! handler) -> Microsoft.AspNetCore.Builder.RouteHandlerBuilder! @@ -6,3 +8,6 @@ static Microsoft.AspNetCore.Builder.EndpointRouteBuilderExtensions.MapPatch(this override Microsoft.AspNetCore.Routing.RouteValuesAddress.ToString() -> string? *REMOVED*~Microsoft.AspNetCore.Routing.DefaultInlineConstraintResolver.DefaultInlineConstraintResolver(Microsoft.Extensions.Options.IOptions! routeOptions, System.IServiceProvider! serviceProvider) -> void Microsoft.AspNetCore.Routing.DefaultInlineConstraintResolver.DefaultInlineConstraintResolver(Microsoft.Extensions.Options.IOptions! routeOptions, System.IServiceProvider! serviceProvider) -> void +static Microsoft.AspNetCore.Http.RouteHandlerFilterExtensions.AddFilter(this Microsoft.AspNetCore.Builder.RouteHandlerBuilder! builder, Microsoft.AspNetCore.Http.IRouteHandlerFilter! filter) -> Microsoft.AspNetCore.Builder.RouteHandlerBuilder! +static Microsoft.AspNetCore.Http.RouteHandlerFilterExtensions.AddFilter(this Microsoft.AspNetCore.Builder.RouteHandlerBuilder! builder, System.Func>!, System.Threading.Tasks.ValueTask>! routeHandlerFilter) -> Microsoft.AspNetCore.Builder.RouteHandlerBuilder! +static Microsoft.AspNetCore.Http.RouteHandlerFilterExtensions.AddFilter(this Microsoft.AspNetCore.Builder.RouteHandlerBuilder! builder) -> Microsoft.AspNetCore.Builder.RouteHandlerBuilder! diff --git a/src/Http/Routing/src/RouteEndpointBuilder.cs b/src/Http/Routing/src/RouteEndpointBuilder.cs index bb27282fa97e..f58c3906aaa4 100644 --- a/src/Http/Routing/src/RouteEndpointBuilder.cs +++ b/src/Http/Routing/src/RouteEndpointBuilder.cs @@ -38,6 +38,24 @@ public RouteEndpointBuilder( Order = order; } + /// + /// Constructs a new instance. + /// + /// The to use in URL matching. + /// The order assigned to the endpoint. + /// + /// This constructor allows the to be added to the + /// after construction but before + /// is invoked. + /// + public RouteEndpointBuilder( + RoutePattern routePattern, + int order) + { + RoutePattern = routePattern; + Order = order; + } + /// public override Endpoint Build() { From c981930119732cf098a8416ed545076894d6e313 Mon Sep 17 00:00:00 2001 From: Safia Abdalla Date: Wed, 2 Mar 2022 18:54:49 +0000 Subject: [PATCH 2/5] Fix build and react to peer review --- .../src/IRouteHandlerFilter.cs | 2 +- .../src/PublicAPI.Unshipped.txt | 6 +-- .../src/RequestDelegateFactory.cs | 44 ++++++++++--------- .../src/RequestDelegateFactoryOptions.cs | 5 +++ .../test/RequestDelegateFactoryTests.cs | 30 ++++++++++--- .../Builder/EndpointRouteBuilderExtensions.cs | 18 ++++---- src/Http/Routing/src/PublicAPI.Unshipped.txt | 1 - src/Http/Routing/src/RouteEndpointBuilder.cs | 2 +- 8 files changed, 65 insertions(+), 43 deletions(-) diff --git a/src/Http/Http.Abstractions/src/IRouteHandlerFilter.cs b/src/Http/Http.Abstractions/src/IRouteHandlerFilter.cs index 14fbc7af950d..83d0173b756f 100644 --- a/src/Http/Http.Abstractions/src/IRouteHandlerFilter.cs +++ b/src/Http/Http.Abstractions/src/IRouteHandlerFilter.cs @@ -15,5 +15,5 @@ public interface IRouteHandlerFilter /// The associated with the current request/response. /// The next filter in the pipeline. /// The result of calling the current filter. - abstract ValueTask InvokeAsync(RouteHandlerFilterContext context, Func> next); + ValueTask InvokeAsync(RouteHandlerFilterContext context, Func> next); } diff --git a/src/Http/Http.Extensions/src/PublicAPI.Unshipped.txt b/src/Http/Http.Extensions/src/PublicAPI.Unshipped.txt index a32059ca51d5..8385166aa4d8 100644 --- a/src/Http/Http.Extensions/src/PublicAPI.Unshipped.txt +++ b/src/Http/Http.Extensions/src/PublicAPI.Unshipped.txt @@ -1,7 +1,5 @@ #nullable enable Microsoft.Extensions.DependencyInjection.RouteHandlerJsonServiceExtensions -*REMOVED*static Microsoft.AspNetCore.Http.RequestDelegateFactory.Create(System.Delegate! handler, Microsoft.AspNetCore.Http.RequestDelegateFactoryOptions? options = null) -> Microsoft.AspNetCore.Http.RequestDelegateResult! -*REMOVED*static Microsoft.AspNetCore.Http.RequestDelegateFactory.Create(System.Reflection.MethodInfo! methodInfo, System.Func? targetFactory = null, Microsoft.AspNetCore.Http.RequestDelegateFactoryOptions? options = null) -> Microsoft.AspNetCore.Http.RequestDelegateResult! -static Microsoft.AspNetCore.Http.RequestDelegateFactory.Create(System.Delegate! handler, Microsoft.AspNetCore.Http.RequestDelegateFactoryOptions? options = null, System.Collections.Generic.IEnumerable? filters = null) -> Microsoft.AspNetCore.Http.RequestDelegateResult! -static Microsoft.AspNetCore.Http.RequestDelegateFactory.Create(System.Reflection.MethodInfo! methodInfo, System.Func? targetFactory = null, Microsoft.AspNetCore.Http.RequestDelegateFactoryOptions? options = null, System.Collections.Generic.IEnumerable? filters = null) -> Microsoft.AspNetCore.Http.RequestDelegateResult! static Microsoft.Extensions.DependencyInjection.RouteHandlerJsonServiceExtensions.ConfigureRouteHandlerJsonOptions(this Microsoft.Extensions.DependencyInjection.IServiceCollection! services, System.Action! configureOptions) -> Microsoft.Extensions.DependencyInjection.IServiceCollection! +Microsoft.AspNetCore.Http.RequestDelegateFactoryOptions.RouteHandlerFilters.get -> System.Collections.Generic.IEnumerable? +Microsoft.AspNetCore.Http.RequestDelegateFactoryOptions.RouteHandlerFilters.init -> void diff --git a/src/Http/Http.Extensions/src/RequestDelegateFactory.cs b/src/Http/Http.Extensions/src/RequestDelegateFactory.cs index 7fa29589b641..0bd9d34c3f27 100644 --- a/src/Http/Http.Extensions/src/RequestDelegateFactory.cs +++ b/src/Http/Http.Extensions/src/RequestDelegateFactory.cs @@ -95,10 +95,9 @@ public static partial class RequestDelegateFactory /// /// A request handler with any number of custom parameters that often produces a response with its return value. /// The used to configure the behavior of the handler. - /// A collection of s invoked in the endpoint associated with this handler. /// The . #pragma warning disable RS0026 // Do not add multiple public overloads with optional parameters - public static RequestDelegateResult Create(Delegate handler, RequestDelegateFactoryOptions? options = null, IEnumerable? filters = null) + public static RequestDelegateResult Create(Delegate handler, RequestDelegateFactoryOptions? options = null) #pragma warning restore RS0026 // Do not add multiple public overloads with optional parameters { if (handler is null) @@ -113,10 +112,6 @@ public static RequestDelegateResult Create(Delegate handler, RequestDelegateFact }; var factoryContext = CreateFactoryContext(options); - if (filters is not null) - { - factoryContext.Filters.AddRange(filters); - } var targetableRequestDelegate = CreateTargetableRequestDelegate(handler.Method, targetExpression, factoryContext); @@ -129,10 +124,9 @@ public static RequestDelegateResult Create(Delegate handler, RequestDelegateFact /// A request handler with any number of custom parameters that often produces a response with its return value. /// Creates the for the non-static method. /// The used to configure the behavior of the handler. - /// A collection of s invoked in the endpoint associated with this handler. /// The . #pragma warning disable RS0026 // Do not add multiple public overloads with optional parameters - public static RequestDelegateResult Create(MethodInfo methodInfo, Func? targetFactory = null, RequestDelegateFactoryOptions? options = null, IEnumerable? filters = null) + public static RequestDelegateResult Create(MethodInfo methodInfo, Func? targetFactory = null, RequestDelegateFactoryOptions? options = null) #pragma warning restore RS0026 // Do not add multiple public overloads with optional parameters { if (methodInfo is null) @@ -146,10 +140,6 @@ public static RequestDelegateResult Create(MethodInfo methodInfo, Func CreateTargetableRequestDelegate(MethodInfo methodInfo, Expression? targetExpression, FactoryContext factoryContext) @@ -202,7 +193,7 @@ private static FactoryContext CreateFactoryContext(RequestDelegateFactoryOptions // If there are filters registered on the route handler, then we update the method call and // return type associated with the request to allow for the filter invocation pipeline. - if (factoryContext.Filters.Count > 0) + if (factoryContext.Filters is { Count: > 0 }) { var filterPipeline = CreateFilterPipeline(methodInfo, targetExpression, factoryContext); Expression>> invokePipeline = (context) => filterPipeline(context); @@ -220,7 +211,7 @@ private static FactoryContext CreateFactoryContext(RequestDelegateFactoryOptions } var responseWritingMethodCall = factoryContext.ParamCheckExpressions.Count > 0 ? - CreateParamCheckingResponseWritingMethodCall(returnType, targetExpression, arguments, factoryContext) : + CreateParamCheckingResponseWritingMethodCall(returnType, factoryContext) : AddResponseWritingToMethodCall(factoryContext.MethodCall, returnType); if (factoryContext.UsingTempSourceString) @@ -249,9 +240,9 @@ target is null )), FilterContextExpr).Compile(); - for (int i = factoryContext.Filters.Count - 1; i >= 0; i--) + for (int i = factoryContext.Filters!.Count - 1; i >= 0; i--) { - var currentFilter = factoryContext.Filters[i]; + var currentFilter = factoryContext.Filters![i]; var nextFilter = filteredInvocation; filteredInvocation = (RouteHandlerFilterContext context) => currentFilter.InvokeAsync(context, nextFilter); @@ -468,8 +459,7 @@ target is null ? // If we're calling TryParse or validating parameter optionality and // wasParamCheckFailure indicates it failed, set a 400 StatusCode instead of calling the method. - private static Expression CreateParamCheckingResponseWritingMethodCall( - Type returnType, Expression? target, Expression[] arguments, FactoryContext factoryContext) + private static Expression CreateParamCheckingResponseWritingMethodCall(Type returnType, FactoryContext factoryContext) { // { // string tempSourceString; @@ -519,8 +509,15 @@ private static Expression CreateParamCheckingResponseWritingMethodCall( localVariables[factoryContext.ExtraLocals.Count] = WasParamCheckFailureExpr; - if (factoryContext.Filters.Count > 0) + // If filters have been registered, we set the `wasParamCheckFailure` property + // but do not return from the invocation to allow the filters to run. + if (factoryContext.Filters is { Count: > 0 }) { + // if (wasParamCheckFailure) + // { + // httpContext.Response.StatusCode = 400; + // } + // return RequestDelegateFactory.ExecuteObjectReturn(invocationPipeline.Invoke(context) as object); var checkWasParamCheckFailureWithFilters = Expression.Block( Expression.IfThen( WasParamCheckFailureExpr, @@ -532,6 +529,12 @@ private static Expression CreateParamCheckingResponseWritingMethodCall( } else { + // wasParamCheckFailure ? { + // httpContext.Response.StatusCode = 400; + // return Task.CompletedTask; + // } : { + // return RequestDelegateFactory.ExecuteObjectReturn(invocationPipeline.Invoke(context) as object); + // } var checkWasParamCheckFailure = Expression.Condition( WasParamCheckFailureExpr, Expression.Block( @@ -1689,8 +1692,7 @@ private class FactoryContext public List ContextArgAccess { get; } = new(); public Expression? MethodCall { get; set; } public List BoxedArgs { get; } = new(); - public List Filters { get; set; } = new(); - + public List? Filters { get; init; } } private static class RequestDelegateFactoryConstants diff --git a/src/Http/Http.Extensions/src/RequestDelegateFactoryOptions.cs b/src/Http/Http.Extensions/src/RequestDelegateFactoryOptions.cs index 892cbd2c7efe..35f88f4db8cd 100644 --- a/src/Http/Http.Extensions/src/RequestDelegateFactoryOptions.cs +++ b/src/Http/Http.Extensions/src/RequestDelegateFactoryOptions.cs @@ -31,4 +31,9 @@ public sealed class RequestDelegateFactoryOptions /// Prevent the from inferring a parameter should be bound from the request body without an attribute that implements . /// public bool DisableInferBodyFromParameters { get; init; } + + /// + /// The list of filters that must run in the pipeline for a given route handler. + /// + public IEnumerable? RouteHandlerFilters { get; init; } } diff --git a/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs b/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs index bf8792b5a8b7..1af56e02660e 100644 --- a/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs +++ b/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs @@ -4214,7 +4214,10 @@ string HelloName(string name) var httpContext = CreateHttpContext(); // Act - var factoryResult = RequestDelegateFactory.Create(HelloName, null, new List() { new ModifyStringArgumentFilter() }); + var factoryResult = RequestDelegateFactory.Create(HelloName, new RequestDelegateFactoryOptions() + { + RouteHandlerFilters = new List() { new ModifyStringArgumentFilter() } + }); var requestDelegate = factoryResult.RequestDelegate; await requestDelegate(httpContext); @@ -4238,7 +4241,10 @@ string HelloName(string name) httpContext.Response.Body = responseBodyStream; // Act - var factoryResult = RequestDelegateFactory.Create(HelloName, null, new List() { new ProvideCustomErrorMessageFilter() }); + var factoryResult = RequestDelegateFactory.Create(HelloName, new RequestDelegateFactoryOptions() + { + RouteHandlerFilters = new List() { new ProvideCustomErrorMessageFilter() } + }); var requestDelegate = factoryResult.RequestDelegate; await requestDelegate(httpContext); @@ -4272,7 +4278,10 @@ string HelloName(string name, int age) }); // Act - var factoryResult = RequestDelegateFactory.Create(HelloName, null, new List() { new ModifyIntArgumentFilter(), new LogArgumentsFilter(Log) }); + var factoryResult = RequestDelegateFactory.Create(HelloName, new RequestDelegateFactoryOptions() + { + RouteHandlerFilters = new List() { new ModifyIntArgumentFilter(), new LogArgumentsFilter(Log) } + }); var requestDelegate = factoryResult.RequestDelegate; await requestDelegate(httpContext); @@ -4305,7 +4314,10 @@ string PrintTodo(Todo todo) httpContext.Response.Body = responseBodyStream; // Act - var factoryResult = RequestDelegateFactory.Create(PrintTodo, null, new List() { new ModifyTodoArgumentFilter() }); + var factoryResult = RequestDelegateFactory.Create(PrintTodo, new RequestDelegateFactoryOptions() + { + RouteHandlerFilters = new List() { new ModifyTodoArgumentFilter() } + }); var requestDelegate = factoryResult.RequestDelegate; await requestDelegate(httpContext); @@ -4334,7 +4346,10 @@ string HelloName(string name) }); // Act - var factoryResult = RequestDelegateFactory.Create(HelloName, null, new List() { new ModifyStringResultFilter() }); + var factoryResult = RequestDelegateFactory.Create(HelloName, new RequestDelegateFactoryOptions() + { + RouteHandlerFilters = new List() { new ModifyStringResultFilter() } + }); var requestDelegate = factoryResult.RequestDelegate; await requestDelegate(httpContext); @@ -4363,7 +4378,10 @@ string HelloName(string name) }); // Act - var factoryResult = RequestDelegateFactory.Create(HelloName, null, new List() { new ModifyStringResultFilter(), new ModifyStringArgumentFilter() }); + var factoryResult = RequestDelegateFactory.Create(HelloName, new RequestDelegateFactoryOptions() + { + RouteHandlerFilters = new List() { new ModifyStringResultFilter(), new ModifyStringArgumentFilter() } + }); var requestDelegate = factoryResult.RequestDelegate; await requestDelegate(httpContext); diff --git a/src/Http/Routing/src/Builder/EndpointRouteBuilderExtensions.cs b/src/Http/Routing/src/Builder/EndpointRouteBuilderExtensions.cs index e43bf6145869..6f896ce0394b 100644 --- a/src/Http/Routing/src/Builder/EndpointRouteBuilderExtensions.cs +++ b/src/Http/Routing/src/Builder/EndpointRouteBuilderExtensions.cs @@ -484,14 +484,6 @@ private static RouteHandlerBuilder Map( var routeHandlerOptions = endpoints.ServiceProvider?.GetService>(); - var options = new RequestDelegateFactoryOptions - { - ServiceProvider = endpoints.ServiceProvider, - RouteParameterNames = routeParams, - ThrowOnBadRequest = routeHandlerOptions?.Value.ThrowOnBadRequest ?? false, - DisableInferBodyFromParameters = disableInferBodyFromParameters, - }; - var builder = new RouteEndpointBuilder( pattern, defaultOrder) @@ -537,7 +529,15 @@ private static RouteHandlerBuilder Map( var routeHandlerBuilder = new RouteHandlerBuilder(dataSource.AddEndpointBuilder(builder)); routeHandlerBuilder.Add(endpointBuilder => { - var filteredRequestDelegateResult = RequestDelegateFactory.Create(handler, options, routeHandlerBuilder.RouteHandlerFilters); + var options = new RequestDelegateFactoryOptions + { + ServiceProvider = endpoints.ServiceProvider, + RouteParameterNames = routeParams, + ThrowOnBadRequest = routeHandlerOptions?.Value.ThrowOnBadRequest ?? false, + DisableInferBodyFromParameters = disableInferBodyFromParameters, + RouteHandlerFilters = routeHandlerBuilder.RouteHandlerFilters + }; + var filteredRequestDelegateResult = RequestDelegateFactory.Create(handler, options); // Add add request delegate metadata foreach (var metadata in filteredRequestDelegateResult.EndpointMetadata) { diff --git a/src/Http/Routing/src/PublicAPI.Unshipped.txt b/src/Http/Routing/src/PublicAPI.Unshipped.txt index 1aa4f87f6a32..4cf74e9056fd 100644 --- a/src/Http/Routing/src/PublicAPI.Unshipped.txt +++ b/src/Http/Routing/src/PublicAPI.Unshipped.txt @@ -1,6 +1,5 @@ #nullable enable Microsoft.AspNetCore.Http.RouteHandlerFilterExtensions -Microsoft.AspNetCore.Routing.RouteEndpointBuilder.RouteEndpointBuilder(Microsoft.AspNetCore.Routing.Patterns.RoutePattern! routePattern, int order) -> void Microsoft.AspNetCore.Routing.RouteOptions.SetParameterPolicy(string! token, System.Type! type) -> void Microsoft.AspNetCore.Routing.RouteOptions.SetParameterPolicy(string! token) -> void static Microsoft.AspNetCore.Builder.EndpointRouteBuilderExtensions.MapPatch(this Microsoft.AspNetCore.Routing.IEndpointRouteBuilder! endpoints, string! pattern, System.Delegate! handler) -> Microsoft.AspNetCore.Builder.RouteHandlerBuilder! diff --git a/src/Http/Routing/src/RouteEndpointBuilder.cs b/src/Http/Routing/src/RouteEndpointBuilder.cs index f58c3906aaa4..add1f849a4a4 100644 --- a/src/Http/Routing/src/RouteEndpointBuilder.cs +++ b/src/Http/Routing/src/RouteEndpointBuilder.cs @@ -48,7 +48,7 @@ public RouteEndpointBuilder( /// after construction but before /// is invoked. /// - public RouteEndpointBuilder( + internal RouteEndpointBuilder( RoutePattern routePattern, int order) { From be5a2ac8e1a658aa654c1758e9269b294ba44cbc Mon Sep 17 00:00:00 2001 From: Safia Abdalla Date: Wed, 2 Mar 2022 21:49:44 +0000 Subject: [PATCH 3/5] Fixing failing tests and address feedback --- .../src/PublicAPI.Unshipped.txt | 2 +- .../src/RequestDelegateFactory.cs | 5 ++-- .../src/RequestDelegateFactoryOptions.cs | 2 +- .../Builder/EndpointRouteBuilderExtensions.cs | 27 ++++++++++--------- ...ndlerEndpointRouteBuilderExtensionsTest.cs | 21 ++++++++++----- 5 files changed, 34 insertions(+), 23 deletions(-) diff --git a/src/Http/Http.Extensions/src/PublicAPI.Unshipped.txt b/src/Http/Http.Extensions/src/PublicAPI.Unshipped.txt index 8385166aa4d8..1d4c624f9113 100644 --- a/src/Http/Http.Extensions/src/PublicAPI.Unshipped.txt +++ b/src/Http/Http.Extensions/src/PublicAPI.Unshipped.txt @@ -1,5 +1,5 @@ #nullable enable Microsoft.Extensions.DependencyInjection.RouteHandlerJsonServiceExtensions static Microsoft.Extensions.DependencyInjection.RouteHandlerJsonServiceExtensions.ConfigureRouteHandlerJsonOptions(this Microsoft.Extensions.DependencyInjection.IServiceCollection! services, System.Action! configureOptions) -> Microsoft.Extensions.DependencyInjection.IServiceCollection! -Microsoft.AspNetCore.Http.RequestDelegateFactoryOptions.RouteHandlerFilters.get -> System.Collections.Generic.IEnumerable? +Microsoft.AspNetCore.Http.RequestDelegateFactoryOptions.RouteHandlerFilters.get -> System.Collections.Generic.IReadOnlyList? Microsoft.AspNetCore.Http.RequestDelegateFactoryOptions.RouteHandlerFilters.init -> void diff --git a/src/Http/Http.Extensions/src/RequestDelegateFactory.cs b/src/Http/Http.Extensions/src/RequestDelegateFactory.cs index 0bd9d34c3f27..ee194aac8435 100644 --- a/src/Http/Http.Extensions/src/RequestDelegateFactory.cs +++ b/src/Http/Http.Extensions/src/RequestDelegateFactory.cs @@ -79,7 +79,7 @@ public static partial class RequestDelegateFactory private static readonly BinaryExpression TempSourceStringNullExpr = Expression.Equal(TempSourceStringExpr, Expression.Constant(null)); private static readonly UnaryExpression TempSourceStringIsNotNullOrEmptyExpr = Expression.Not(Expression.Call(StringIsNullOrEmptyMethod, TempSourceStringExpr)); - private static readonly ConstructorInfo RouteHandlerFilterContextConstructor = typeof(RouteHandlerFilterContext).GetConstructors().Single(); + private static readonly ConstructorInfo RouteHandlerFilterContextConstructor = typeof(RouteHandlerFilterContext).GetConstructor(new[] { typeof(HttpContext), typeof(object[]) })!; private static readonly ParameterExpression FilterContextExpr = Expression.Parameter(typeof(RouteHandlerFilterContext), "context"); private static readonly MemberExpression FilterContextParametersExpr = Expression.Property(FilterContextExpr, typeof(RouteHandlerFilterContext).GetProperty(nameof(RouteHandlerFilterContext.Parameters))!); private static readonly MemberExpression FilterContextHttpContextExpr = Expression.Property(FilterContextExpr, typeof(RouteHandlerFilterContext).GetProperty(nameof(RouteHandlerFilterContext.HttpContext))!); @@ -224,6 +224,7 @@ private static FactoryContext CreateFactoryContext(RequestDelegateFactoryOptions private static Func> CreateFilterPipeline(MethodInfo methodInfo, Expression? target, FactoryContext factoryContext) { + Debug.Assert(factoryContext.Filters is not null); // httpContext.Response.StatusCode == 400 // ? Task.CompletedTask // : handler((string)context.Parameters[0], (int)context.Parameters[1]) @@ -240,7 +241,7 @@ target is null )), FilterContextExpr).Compile(); - for (int i = factoryContext.Filters!.Count - 1; i >= 0; i--) + for (var i = factoryContext.Filters.Count - 1; i >= 0; i--) { var currentFilter = factoryContext.Filters![i]; var nextFilter = filteredInvocation; diff --git a/src/Http/Http.Extensions/src/RequestDelegateFactoryOptions.cs b/src/Http/Http.Extensions/src/RequestDelegateFactoryOptions.cs index 35f88f4db8cd..870c2a06158e 100644 --- a/src/Http/Http.Extensions/src/RequestDelegateFactoryOptions.cs +++ b/src/Http/Http.Extensions/src/RequestDelegateFactoryOptions.cs @@ -35,5 +35,5 @@ public sealed class RequestDelegateFactoryOptions /// /// The list of filters that must run in the pipeline for a given route handler. /// - public IEnumerable? RouteHandlerFilters { get; init; } + public IReadOnlyList? RouteHandlerFilters { get; init; } } diff --git a/src/Http/Routing/src/Builder/EndpointRouteBuilderExtensions.cs b/src/Http/Routing/src/Builder/EndpointRouteBuilderExtensions.cs index 6f896ce0394b..6ce2d6c2c7ea 100644 --- a/src/Http/Routing/src/Builder/EndpointRouteBuilderExtensions.cs +++ b/src/Http/Routing/src/Builder/EndpointRouteBuilderExtensions.cs @@ -507,18 +507,6 @@ private static RouteHandlerBuilder Map( builder.DisplayName = $"{builder.DisplayName} => {endpointName}"; } - // Add delegate attributes as metadata - var attributes = handler.Method.GetCustomAttributes(); - - // This can be null if the delegate is a dynamic method or compiled from an expression tree - if (attributes is not null) - { - foreach (var attribute in attributes) - { - builder.Metadata.Add(attribute); - } - } - var dataSource = endpoints.DataSources.OfType().FirstOrDefault(); if (dataSource is null) { @@ -538,11 +526,24 @@ private static RouteHandlerBuilder Map( RouteHandlerFilters = routeHandlerBuilder.RouteHandlerFilters }; var filteredRequestDelegateResult = RequestDelegateFactory.Create(handler, options); - // Add add request delegate metadata + // Add request delegate metadata foreach (var metadata in filteredRequestDelegateResult.EndpointMetadata) { endpointBuilder.Metadata.Add(metadata); } + + // We add attributes on the handler after those automatically generated by the + // RDF since they have a higher specificity. + var attributes = handler.Method.GetCustomAttributes(); + + // This can be null if the delegate is a dynamic method or compiled from an expression tree + if (attributes is not null) + { + foreach (var attribute in attributes) + { + endpointBuilder.Metadata.Add(attribute); + } + } endpointBuilder.RequestDelegate = filteredRequestDelegateResult.RequestDelegate; }); diff --git a/src/Http/Routing/test/UnitTests/Builder/RouteHandlerEndpointRouteBuilderExtensionsTest.cs b/src/Http/Routing/test/UnitTests/Builder/RouteHandlerEndpointRouteBuilderExtensionsTest.cs index 5636c10cd940..1c3450159801 100644 --- a/src/Http/Routing/test/UnitTests/Builder/RouteHandlerEndpointRouteBuilderExtensionsTest.cs +++ b/src/Http/Routing/test/UnitTests/Builder/RouteHandlerEndpointRouteBuilderExtensionsTest.cs @@ -208,7 +208,9 @@ public async Task MapGetWithoutRouteParameter_BuildsEndpointWithQuerySpecificBin public void MapGet_ThrowsWithImplicitFromBody() { var builder = new DefaultEndpointRouteBuilder(new ApplicationBuilder(new EmptyServiceProvider())); - var ex = Assert.Throws(() => builder.MapGet("/", (Todo todo) => { })); + _ = builder.MapGet("/", (Todo todo) => { }); + var dataSource = GetBuilderEndpointDataSource(builder); + var ex = Assert.Throws(() => dataSource.Endpoints); Assert.Contains("Body was inferred but the method does not allow inferred body parameters.", ex.Message); Assert.Contains("Did you mean to register the \"Body (Inferred)\" parameter(s) as a Service or apply the [FromServices] or [FromBody] attribute?", ex.Message); } @@ -217,7 +219,9 @@ public void MapGet_ThrowsWithImplicitFromBody() public void MapDelete_ThrowsWithImplicitFromBody() { var builder = new DefaultEndpointRouteBuilder(new ApplicationBuilder(new EmptyServiceProvider())); - var ex = Assert.Throws(() => builder.MapDelete("/", (Todo todo) => { })); + _ = builder.MapDelete("/", (Todo todo) => { }); + var dataSource = GetBuilderEndpointDataSource(builder); + var ex = Assert.Throws(() => dataSource.Endpoints); Assert.Contains("Body was inferred but the method does not allow inferred body parameters.", ex.Message); Assert.Contains("Did you mean to register the \"Body (Inferred)\" parameter(s) as a Service or apply the [FromServices] or [FromBody] attribute?", ex.Message); } @@ -243,7 +247,9 @@ public static object[][] NonImplicitFromBodyMethods public void MapVerb_ThrowsWithImplicitFromBody(string method) { var builder = new DefaultEndpointRouteBuilder(new ApplicationBuilder(new EmptyServiceProvider())); - var ex = Assert.Throws(() => builder.MapMethods("/", new[] { method }, (Todo todo) => { })); + _ = builder.MapMethods("/", new[] { method }, (Todo todo) => { }); + var dataSource = GetBuilderEndpointDataSource(builder); + var ex = Assert.Throws(() => dataSource.Endpoints); Assert.Contains("Body was inferred but the method does not allow inferred body parameters.", ex.Message); Assert.Contains("Did you mean to register the \"Body (Inferred)\" parameter(s) as a Service or apply the [FromServices] or [FromBody] attribute?", ex.Message); } @@ -581,7 +587,9 @@ public async Task MapVerbWithRouteParameterDoesNotFallbackToQuery(Func(() => builder.MapGet("/", ([FromRoute] int id) => { })); + _ = builder.MapGet("/", ([FromRoute] int id) => { }); + var dataSource = GetBuilderEndpointDataSource(builder); + var ex = Assert.Throws(() => dataSource.Endpoints); Assert.Equal("'id' is not a route parameter.", ex.Message); } @@ -637,7 +645,9 @@ public async Task MapGetWithNamedFromRouteParameter_FailsForParameterName() public void MapGetWithNamedFromRouteParameter_ThrowsForMismatchedPattern() { var builder = new DefaultEndpointRouteBuilder(new ApplicationBuilder(new EmptyServiceProvider())); - var ex = Assert.Throws(() => builder.MapGet("/{id}", ([FromRoute(Name = "value")] int id, HttpContext httpContext) => { })); + _ = builder.MapGet("/{id}", ([FromRoute(Name = "value")] int id, HttpContext httpContext) => { }); + var dataSource = GetBuilderEndpointDataSource(builder); + var ex = Assert.Throws(() => dataSource.Endpoints); Assert.Equal("'value' is not a route parameter.", ex.Message); } @@ -677,7 +687,6 @@ public void MapPost_BuildsEndpointWithCorrectEndpointMetadata() Assert.False(endpointMetadata!.IsOptional); Assert.Equal(typeof(Todo), endpointMetadata.RequestType); Assert.Equal(new[] { "application/xml" }, endpointMetadata.ContentTypes); - } [Fact] From 2f8a5b6a3d9a27998ae382f716b7e75aac399956 Mon Sep 17 00:00:00 2001 From: Safia Abdalla Date: Thu, 3 Mar 2022 06:12:03 +0000 Subject: [PATCH 4/5] Don't execute handler after filter if StatusCode >= 400 --- src/Http/Http.Abstractions/src/IRouteHandlerFilter.cs | 3 ++- src/Http/Http.Extensions/src/RequestDelegateFactory.cs | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/Http/Http.Abstractions/src/IRouteHandlerFilter.cs b/src/Http/Http.Abstractions/src/IRouteHandlerFilter.cs index 83d0173b756f..4d3e583eaa5d 100644 --- a/src/Http/Http.Abstractions/src/IRouteHandlerFilter.cs +++ b/src/Http/Http.Abstractions/src/IRouteHandlerFilter.cs @@ -14,6 +14,7 @@ public interface IRouteHandlerFilter /// /// The associated with the current request/response. /// The next filter in the pipeline. - /// The result of calling the current filter. + /// An awaitable result of calling the handler and apply + /// any modifications made by filters in the pipeline. ValueTask InvokeAsync(RouteHandlerFilterContext context, Func> next); } diff --git a/src/Http/Http.Extensions/src/RequestDelegateFactory.cs b/src/Http/Http.Extensions/src/RequestDelegateFactory.cs index ee194aac8435..5b79735f356c 100644 --- a/src/Http/Http.Extensions/src/RequestDelegateFactory.cs +++ b/src/Http/Http.Extensions/src/RequestDelegateFactory.cs @@ -225,12 +225,12 @@ private static FactoryContext CreateFactoryContext(RequestDelegateFactoryOptions private static Func> CreateFilterPipeline(MethodInfo methodInfo, Expression? target, FactoryContext factoryContext) { Debug.Assert(factoryContext.Filters is not null); - // httpContext.Response.StatusCode == 400 + // httpContext.Response.StatusCode >= 400 // ? Task.CompletedTask // : handler((string)context.Parameters[0], (int)context.Parameters[1]) var filteredInvocation = Expression.Lambda>>( Expression.Condition( - Expression.Equal(FilterContextHttpContextStatusCodeExpr, Expression.Constant(400)), + Expression.GreaterThanOrEqual(FilterContextHttpContextStatusCodeExpr, Expression.Constant(400)), CompletedValueTaskExpr, Expression.Block( new[] { TargetExpr }, From 70fbdabdfe467265775cbd93e6b3811ae92d8039 Mon Sep 17 00:00:00 2001 From: Safia Abdalla Date: Thu, 3 Mar 2022 19:22:34 +0000 Subject: [PATCH 5/5] Rebase and fix new test --- .../RequestDelegateEndpointRouteBuilderExtensionsTest.cs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/Http/Routing/test/UnitTests/Builder/RequestDelegateEndpointRouteBuilderExtensionsTest.cs b/src/Http/Routing/test/UnitTests/Builder/RequestDelegateEndpointRouteBuilderExtensionsTest.cs index a76b96a19be5..df21fbb1d7f1 100644 --- a/src/Http/Routing/test/UnitTests/Builder/RequestDelegateEndpointRouteBuilderExtensionsTest.cs +++ b/src/Http/Routing/test/UnitTests/Builder/RequestDelegateEndpointRouteBuilderExtensionsTest.cs @@ -89,8 +89,9 @@ public async Task MapEndpoint_ReturnGenericTypeTask_GeneratedDelegate() var endpointBuilder = builder.MapGet("/", GenericTypeTaskDelegate); // Assert - var endpointBuilder1 = GetRouteEndpointBuilder(builder); - var requestDelegate = endpointBuilder1.RequestDelegate; + var dataSource = GetBuilderEndpointDataSource(builder); + var endpoint = Assert.Single(dataSource.Endpoints); // Triggers build and construction of delegate + var requestDelegate = endpoint.RequestDelegate; await requestDelegate(httpContext); var responseBody = Encoding.UTF8.GetString(responseBodyStream.ToArray());