Skip to content

Commit

Permalink
Support optional input for MapAction (#30434)
Browse files Browse the repository at this point in the history
* Support optional input for MapAction

* Revert passing HttpContext.RequestAborted

* Fix tests
  • Loading branch information
Kahbazi authored Feb 26, 2021
1 parent 1256a3b commit 1735db4
Show file tree
Hide file tree
Showing 2 changed files with 218 additions and 44 deletions.
88 changes: 60 additions & 28 deletions src/Http/Routing/src/Internal/MapActionExpressionTreeBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ internal static class MapActionExpressionTreeBuilder
private static readonly MethodInfo ChangeTypeMethodInfo = GetMethodInfo<Func<object, Type, object>>((value, type) => Convert.ChangeType(value, type, CultureInfo.InvariantCulture));
private static readonly MethodInfo ExecuteTaskOfTMethodInfo = typeof(MapActionExpressionTreeBuilder).GetMethod(nameof(ExecuteTask), BindingFlags.NonPublic | BindingFlags.Static)!;
private static readonly MethodInfo ExecuteTaskOfStringMethodInfo = typeof(MapActionExpressionTreeBuilder).GetMethod(nameof(ExecuteTaskOfString), BindingFlags.NonPublic | BindingFlags.Static)!;
private static readonly MethodInfo ExecuteValueTaskOfTMethodInfo = typeof(MapActionExpressionTreeBuilder).GetMethod(nameof(ExecuteValueTask), BindingFlags.NonPublic | BindingFlags.Static)!;
private static readonly MethodInfo ExecuteValueTaskOfTMethodInfo = typeof(MapActionExpressionTreeBuilder).GetMethod(nameof(ExecuteValueTaskOfT), BindingFlags.NonPublic | BindingFlags.Static)!;
private static readonly MethodInfo ExecuteValueTaskMethodInfo = typeof(MapActionExpressionTreeBuilder).GetMethod(nameof(ExecuteValueTask), BindingFlags.NonPublic | BindingFlags.Static)!;
private static readonly MethodInfo ExecuteValueTaskOfStringMethodInfo = typeof(MapActionExpressionTreeBuilder).GetMethod(nameof(ExecuteValueTaskOfString), BindingFlags.NonPublic | BindingFlags.Static)!;
private static readonly MethodInfo ExecuteTaskResultOfTMethodInfo = typeof(MapActionExpressionTreeBuilder).GetMethod(nameof(ExecuteTaskResult), BindingFlags.NonPublic | BindingFlags.Static)!;
private static readonly MethodInfo ExecuteValueResultTaskOfTMethodInfo = typeof(MapActionExpressionTreeBuilder).GetMethod(nameof(ExecuteValueTaskResult), BindingFlags.NonPublic | BindingFlags.Static)!;
Expand Down Expand Up @@ -71,28 +72,31 @@ public static RequestDelegate BuildRequestDelegate(Delegate action)
// This argument represents the deserialized body returned from IHttpRequestReader
// when the method has a FromBody attribute declared

var args = new List<Expression>();
var methodParameters = method.GetParameters();
var args = new List<Expression>(methodParameters.Length);

foreach (var parameter in method.GetParameters())
foreach (var parameter in methodParameters)
{
Expression paramterExpression = Expression.Default(parameter.ParameterType);

if (parameter.GetCustomAttributes().OfType<IFromRouteMetadata>().FirstOrDefault() is { } routeAttribute)
var parameterCustomAttributes = parameter.GetCustomAttributes();

if (parameterCustomAttributes.OfType<IFromRouteMetadata>().FirstOrDefault() is { } routeAttribute)
{
var routeValuesProperty = Expression.Property(HttpRequestExpr, nameof(HttpRequest.RouteValues));
paramterExpression = BindParamenter(routeValuesProperty, parameter, routeAttribute.Name);
}
else if (parameter.GetCustomAttributes().OfType<IFromQueryMetadata>().FirstOrDefault() is { } queryAttribute)
else if (parameterCustomAttributes.OfType<IFromQueryMetadata>().FirstOrDefault() is { } queryAttribute)
{
var queryProperty = Expression.Property(HttpRequestExpr, nameof(HttpRequest.Query));
paramterExpression = BindParamenter(queryProperty, parameter, queryAttribute.Name);
}
else if (parameter.GetCustomAttributes().OfType<IFromHeaderMetadata>().FirstOrDefault() is { } headerAttribute)
else if (parameterCustomAttributes.OfType<IFromHeaderMetadata>().FirstOrDefault() is { } headerAttribute)
{
var headersProperty = Expression.Property(HttpRequestExpr, nameof(HttpRequest.Headers));
paramterExpression = BindParamenter(headersProperty, parameter, headerAttribute.Name);
}
else if (parameter.GetCustomAttributes().OfType<IFromBodyMetadata>().FirstOrDefault() is { } bodyAttribute)
else if (parameterCustomAttributes.OfType<IFromBodyMetadata>().FirstOrDefault() is { } bodyAttribute)
{
if (consumeBodyDirectly)
{
Expand All @@ -109,7 +113,7 @@ public static RequestDelegate BuildRequestDelegate(Delegate action)
bodyType = parameter.ParameterType;
paramterExpression = Expression.Convert(DeserializedBodyArg, bodyType);
}
else if (parameter.GetCustomAttributes().OfType<IFromFormMetadata>().FirstOrDefault() is { } formAttribute)
else if (parameterCustomAttributes.OfType<IFromFormMetadata>().FirstOrDefault() is { } formAttribute)
{
if (consumeBodyDirectly)
{
Expand All @@ -125,27 +129,24 @@ public static RequestDelegate BuildRequestDelegate(Delegate action)
{
paramterExpression = Expression.Call(GetRequiredServiceMethodInfo.MakeGenericMethod(parameter.ParameterType), RequestServicesExpr);
}
else
else if (parameter.ParameterType == typeof(IFormCollection))
{
if (parameter.ParameterType == typeof(IFormCollection))
if (consumeBodyDirectly)
{
if (consumeBodyDirectly)
{
ThrowCannotReadBodyDirectlyAndAsForm();
}
ThrowCannotReadBodyDirectlyAndAsForm();
}

consumeBodyAsForm = true;
consumeBodyAsForm = true;

paramterExpression = Expression.Property(HttpRequestExpr, nameof(HttpRequest.Form));
}
else if (parameter.ParameterType == typeof(HttpContext))
{
paramterExpression = HttpContextParameter;
}
else if (parameter.ParameterType == typeof(CancellationToken))
{
paramterExpression = RequestAbortedExpr;
}
paramterExpression = Expression.Property(HttpRequestExpr, nameof(HttpRequest.Form));
}
else if (parameter.ParameterType == typeof(HttpContext))
{
paramterExpression = HttpContextParameter;
}
else if (parameter.ParameterType == typeof(CancellationToken))
{
paramterExpression = RequestAbortedExpr;
}

args.Add(paramterExpression);
Expand Down Expand Up @@ -182,6 +183,12 @@ public static RequestDelegate BuildRequestDelegate(Delegate action)
{
body = methodCall;
}
else if (method.ReturnType == typeof(ValueTask))
{
body = Expression.Call(
ExecuteValueTaskMethodInfo,
methodCall);
}
else if (method.ReturnType.IsGenericType &&
method.ReturnType.GetGenericTypeDefinition() == typeof(Task<>))
{
Expand Down Expand Up @@ -263,7 +270,7 @@ public static RequestDelegate BuildRequestDelegate(Delegate action)
var box = Expression.TypeAs(methodCall, typeof(object));
body = Expression.Call(JsonResultWriteResponseAsync, HttpResponseExpr, box, Expression.Constant(CancellationToken.None));
}
else
else
{
body = Expression.Call(JsonResultWriteResponseAsync, HttpResponseExpr, methodCall, Expression.Constant(CancellationToken.None));
}
Expand Down Expand Up @@ -398,10 +405,20 @@ private static Expression BindParamenter(Expression sourceExpression, ParameterI
expr = Expression.Convert(expr, parameter.ParameterType);
}

Expression defaultExpression;
if (parameter.HasDefaultValue)
{
defaultExpression = Expression.Constant(parameter.DefaultValue);
}
else
{
defaultExpression = Expression.Default(parameter.ParameterType);
}

// property[key] == null ? default : (ParameterType){Type}.Parse(property[key]);
expr = Expression.Condition(
Expression.Equal(valueArg, Expression.Constant(null)),
Expression.Default(parameter.ParameterType),
defaultExpression,
expr);

return expr;
Expand Down Expand Up @@ -449,7 +466,22 @@ static async Task ExecuteAwaited(Task<string> task, HttpContext httpContext)
return ExecuteAwaited(task, httpContext);
}

private static Task ExecuteValueTask<T>(ValueTask<T> task, HttpContext httpContext)
private static Task ExecuteValueTask(ValueTask task)
{
static async Task ExecuteAwaited(ValueTask task)
{
await task;
}

if (task.IsCompletedSuccessfully)
{
task.GetAwaiter().GetResult();
}

return ExecuteAwaited(task);
}

private static Task ExecuteValueTaskOfT<T>(ValueTask<T> task, HttpContext httpContext)
{
static async Task ExecuteAwaited(ValueTask<T> task, HttpContext httpContext)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,44 +24,186 @@ namespace Microsoft.AspNetCore.Routing.Internal
{
public class MapActionExpressionTreeBuilderTest
{
[Fact]
public async Task RequestDelegateInvokesAction()
public static IEnumerable<object[]> NoResult
{
var invoked = false;

void TestAction()
get
{
invoked = true;
void TestAction(HttpContext httpContext)
{
MarkAsInvoked(httpContext);
}

Task TaskTestAction(HttpContext httpContext)
{
MarkAsInvoked(httpContext);
return Task.CompletedTask;
}

ValueTask ValueTaskTestAction(HttpContext httpContext)
{
MarkAsInvoked(httpContext);
return ValueTask.CompletedTask;
}

void StaticTestAction(HttpContext httpContext)
{
MarkAsInvoked(httpContext);
}

Task StaticTaskTestAction(HttpContext httpContext)
{
MarkAsInvoked(httpContext);
return Task.CompletedTask;
}

ValueTask StaticValueTaskTestAction(HttpContext httpContext)
{
MarkAsInvoked(httpContext);
return ValueTask.CompletedTask;
}

void MarkAsInvoked(HttpContext httpContext)
{
httpContext.Items.Add("invoked", true);
}

return new List<object[]>
{
new object[] { (Action<HttpContext>)TestAction },
new object[] { (Func<HttpContext, Task>)TaskTestAction },
new object[] { (Func<HttpContext, ValueTask>)ValueTaskTestAction },
new object[] { (Action<HttpContext>)StaticTestAction },
new object[] { (Func<HttpContext, Task>)StaticTaskTestAction },
new object[] { (Func<HttpContext, ValueTask>)StaticValueTaskTestAction },
};
}
}

var requestDelegate = MapActionExpressionTreeBuilder.BuildRequestDelegate((Action)TestAction);
[Theory]
[MemberData(nameof(NoResult))]
public async Task RequestDelegateInvokesAction(Delegate @delegate)
{
var httpContext = new DefaultHttpContext();

var requestDelegate = MapActionExpressionTreeBuilder.BuildRequestDelegate(@delegate);

await requestDelegate(null!);
await requestDelegate(httpContext);

Assert.True(invoked);
Assert.True(httpContext.Items["invoked"] as bool?);
}

[Fact]
public async Task RequestDelegatePopulatesFromRouteParameterBasedOnParameterName()
public static IEnumerable<object[]> FromRouteResult
{
get
{
void TestAction(HttpContext httpContext, [FromRoute] int value)
{
StoreInput(httpContext, value);
};

Task TaskTestAction(HttpContext httpContext, [FromRoute] int value)
{
StoreInput(httpContext, value);
return Task.CompletedTask;
}

ValueTask ValueTaskTestAction(HttpContext httpContext, [FromRoute] int value)
{
StoreInput(httpContext, value);
return ValueTask.CompletedTask;
}



return new List<object[]>
{
new object[] { (Action<HttpContext, int>)TestAction },
new object[] { (Func<HttpContext, int, Task>)TaskTestAction },
new object[] { (Func<HttpContext, int, ValueTask>)ValueTaskTestAction },
};
}
}
private static void StoreInput(HttpContext httpContext, object value)
{
httpContext.Items.Add("input", value);
}

[Theory]
[MemberData(nameof(FromRouteResult))]
public async Task RequestDelegatePopulatesFromRouteParameterBasedOnParameterName(Delegate @delegate)
{
const string paramName = "value";
const int originalRouteParam = 42;

int? deserializedRouteParam = null;
var httpContext = new DefaultHttpContext();
httpContext.Request.RouteValues[paramName] = originalRouteParam.ToString(NumberFormatInfo.InvariantInfo);

var requestDelegate = MapActionExpressionTreeBuilder.BuildRequestDelegate(@delegate);

await requestDelegate(httpContext);

Assert.Equal(originalRouteParam, httpContext.Items["input"] as int?);
}

void TestAction([FromRoute] int value)
public static IEnumerable<object[]> FromRouteOptionalResult
{
get
{
deserializedRouteParam = value;
return new List<object[]>
{
new object[] { (Action<HttpContext, int>)TestAction },
new object[] { (Func<HttpContext, int, Task>)TaskTestAction },
new object[] { (Func<HttpContext, int, ValueTask>)ValueTaskTestAction }
};
}
}

private static void TestAction(HttpContext httpContext, [FromRoute] int value = 42)
{
StoreInput(httpContext, value);
}

private static Task TaskTestAction(HttpContext httpContext, [FromRoute] int value = 42)
{
StoreInput(httpContext, value);
return Task.CompletedTask;
}

private static ValueTask ValueTaskTestAction(HttpContext httpContext, [FromRoute] int value = 42)
{
StoreInput(httpContext, value);
return ValueTask.CompletedTask;
}

[Theory]
[MemberData(nameof(FromRouteOptionalResult))]
public async Task RequestDelegatePopulatesFromRouteOptionalParameter(Delegate @delegate)
{
var httpContext = new DefaultHttpContext();

var requestDelegate = MapActionExpressionTreeBuilder.BuildRequestDelegate(@delegate);

await requestDelegate(httpContext);

Assert.Equal(42, httpContext.Items["input"] as int?);
}

[Theory]
[MemberData(nameof(FromRouteOptionalResult))]
public async Task RequestDelegatePopulatesFromRouteOptionalParameterBasedOnParameterName(Delegate @delegate)
{
const string paramName = "value";
const int originalRouteParam = 47;

var httpContext = new DefaultHttpContext();

httpContext.Request.RouteValues[paramName] = originalRouteParam.ToString(NumberFormatInfo.InvariantInfo);

var requestDelegate = MapActionExpressionTreeBuilder.BuildRequestDelegate((Action<int>)TestAction);
var requestDelegate = MapActionExpressionTreeBuilder.BuildRequestDelegate(@delegate);

await requestDelegate(httpContext);

Assert.Equal(originalRouteParam, deserializedRouteParam);
Assert.Equal(47, httpContext.Items["input"] as int?);
}

[Fact]
Expand Down

0 comments on commit 1735db4

Please sign in to comment.