Skip to content

Commit

Permalink
Add null check on Task, Task<IResult>, ValueTask, ValueTask<IResult> (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
wcontayon authored Jun 9, 2021
1 parent 43241a3 commit 0ecf4c1
Show file tree
Hide file tree
Showing 2 changed files with 135 additions and 14 deletions.
66 changes: 52 additions & 14 deletions src/Http/Http.Extensions/src/RequestDelegateFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Globalization;
using System.IO;
using System.Linq;
using System.Linq.Expressions;
Expand All @@ -30,7 +31,7 @@ public static class RequestDelegateFactory
private static readonly MethodInfo ExecuteTaskResultOfTMethod = typeof(RequestDelegateFactory).GetMethod(nameof(ExecuteTaskResult), BindingFlags.NonPublic | BindingFlags.Static)!;
private static readonly MethodInfo ExecuteValueResultTaskOfTMethod = typeof(RequestDelegateFactory).GetMethod(nameof(ExecuteValueTaskResult), BindingFlags.NonPublic | BindingFlags.Static)!;
private static readonly MethodInfo GetRequiredServiceMethod = typeof(ServiceProviderServiceExtensions).GetMethod(nameof(ServiceProviderServiceExtensions.GetRequiredService), BindingFlags.Public | BindingFlags.Static, new Type[] { typeof(IServiceProvider) })!;
private static readonly MethodInfo ResultWriteResponseAsyncMethod = typeof(IResult).GetMethod(nameof(IResult.ExecuteAsync), BindingFlags.Public | BindingFlags.Instance)!;
private static readonly MethodInfo ResultWriteResponseAsyncMethod = typeof(RequestDelegateFactory).GetMethod(nameof(ExecuteResultWriteResponse), BindingFlags.NonPublic | BindingFlags.Static)!;
private static readonly MethodInfo StringResultWriteResponseAsyncMethod = GetMethodInfo<Func<HttpResponse, string, Task>>((response, text) => HttpResponseWritingExtensions.WriteAsync(response, text, default));
private static readonly MethodInfo JsonResultWriteResponseAsyncMethod = GetMethodInfo<Func<HttpResponse, object, Task>>((response, value) => HttpResponseJsonExtensions.WriteAsJsonAsync(response, value, default));
private static readonly MethodInfo EnumTryParseMethod = GetEnumTryParseMethod();
Expand Down Expand Up @@ -393,7 +394,7 @@ private static Expression AddResponseWritingToMethodCall(Expression methodCall,
}
else if (typeof(IResult).IsAssignableFrom(returnType))
{
return Expression.Call(methodCall, ResultWriteResponseAsyncMethod, HttpContextExpr);
return Expression.Call(ResultWriteResponseAsyncMethod, methodCall, HttpContextExpr);
}
else if (returnType == typeof(string))
{
Expand Down Expand Up @@ -679,6 +680,8 @@ private static MemberInfo GetMemberInfo<T>(Expression<T> expr)

private static Task ExecuteTask<T>(Task<T> task, HttpContext httpContext)
{
EnsureRequestTaskNotNull(task);

static async Task ExecuteAwaited(Task<T> task, HttpContext httpContext)
{
await httpContext.Response.WriteAsJsonAsync(await task);
Expand All @@ -692,19 +695,21 @@ static async Task ExecuteAwaited(Task<T> task, HttpContext httpContext)
return ExecuteAwaited(task, httpContext);
}

private static Task ExecuteTaskOfString(Task<string> task, HttpContext httpContext)
private static Task ExecuteTaskOfString(Task<string?> task, HttpContext httpContext)
{
EnsureRequestTaskNotNull(task);

static async Task ExecuteAwaited(Task<string> task, HttpContext httpContext)
{
await httpContext.Response.WriteAsync(await task);
}

if (task.IsCompletedSuccessfully)
{
return httpContext.Response.WriteAsync(task.GetAwaiter().GetResult());
return httpContext.Response.WriteAsync(task.GetAwaiter().GetResult()!);
}

return ExecuteAwaited(task, httpContext);
return ExecuteAwaited(task!, httpContext);
}

private static Task ExecuteValueTask(ValueTask task)
Expand Down Expand Up @@ -737,7 +742,7 @@ static async Task ExecuteAwaited(ValueTask<T> task, HttpContext httpContext)
return ExecuteAwaited(task, httpContext);
}

private static Task ExecuteValueTaskOfString(ValueTask<string> task, HttpContext httpContext)
private static Task ExecuteValueTaskOfString(ValueTask<string?> task, HttpContext httpContext)
{
static async Task ExecuteAwaited(ValueTask<string> task, HttpContext httpContext)
{
Expand All @@ -746,30 +751,37 @@ static async Task ExecuteAwaited(ValueTask<string> task, HttpContext httpContext

if (task.IsCompletedSuccessfully)
{
return httpContext.Response.WriteAsync(task.GetAwaiter().GetResult());
return httpContext.Response.WriteAsync(task.GetAwaiter().GetResult()!);
}

return ExecuteAwaited(task, httpContext);
return ExecuteAwaited(task!, httpContext);
}

private static Task ExecuteValueTaskResult<T>(ValueTask<T> task, HttpContext httpContext) where T : IResult
private static Task ExecuteValueTaskResult<T>(ValueTask<T?> task, HttpContext httpContext) where T : IResult
{
static async Task ExecuteAwaited(ValueTask<T> task, HttpContext httpContext)
{
await (await task).ExecuteAsync(httpContext);
await EnsureRequestResultNotNull(await task)!.ExecuteAsync(httpContext);
}

if (task.IsCompletedSuccessfully)
{
return task.GetAwaiter().GetResult().ExecuteAsync(httpContext);
return EnsureRequestResultNotNull(task.GetAwaiter().GetResult())!.ExecuteAsync(httpContext);
}

return ExecuteAwaited(task, httpContext);
return ExecuteAwaited(task!, httpContext);
}

private static async Task ExecuteTaskResult<T>(Task<T> task, HttpContext httpContext) where T : IResult
private static async Task ExecuteTaskResult<T>(Task<T?> task, HttpContext httpContext) where T : IResult
{
await (await task).ExecuteAsync(httpContext);
EnsureRequestTaskOfNotNull(task);

await EnsureRequestResultNotNull(await task)!.ExecuteAsync(httpContext);
}

private static async Task ExecuteResultWriteResponse(IResult result, HttpContext httpContext)
{
await EnsureRequestResultNotNull(result)!.ExecuteAsync(httpContext);
}

private class FactoryContext
Expand Down Expand Up @@ -819,5 +831,31 @@ private static ILogger GetLogger(HttpContext httpContext)
return loggerFactory.CreateLogger(typeof(RequestDelegateFactory));
}
}

private static void EnsureRequestTaskOfNotNull<T>(Task<T?> task) where T : IResult
{
if (task is null)
{
throw new InvalidOperationException("The IResult in Task<IResult> response must not be null.");
}
}

private static void EnsureRequestTaskNotNull(Task? task)
{
if (task is null)
{
throw new InvalidOperationException("The Task returned by the Delegate must not be null.");
}
}

private static IResult EnsureRequestResultNotNull(IResult? result)
{
if (result is null)
{
throw new InvalidOperationException("The IResult returned by the Delegate must not be null.");
}

return result;
}
}
}
83 changes: 83 additions & 0 deletions src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1081,6 +1081,89 @@ public async Task RequestDelegateWritesBoolReturnValue(Delegate @delegate)
Assert.Equal("true", responseBody);
}

public static IEnumerable<object[]> NullResult
{
get
{
IResult? TestAction() => null;
Task<bool?>? TaskBoolAction() => null;
Task<IResult?>? TaskNullAction() => null;
Task<IResult?> TaskTestAction() => Task.FromResult<IResult?>(null);
ValueTask<IResult?> ValueTaskTestAction() => ValueTask.FromResult<IResult?>(null);

return new List<object[]>
{
new object[] { (Func<IResult?>)TestAction, "The IResult returned by the Delegate must not be null." },
new object[] { (Func<Task<IResult?>?>)TaskNullAction, "The IResult in Task<IResult> response must not be null." },
new object[] { (Func<Task<bool?>?>)TaskBoolAction, "The Task returned by the Delegate must not be null." },
new object[] { (Func<Task<IResult?>>)TaskTestAction, "The IResult returned by the Delegate must not be null." },
new object[] { (Func<ValueTask<IResult?>>)ValueTaskTestAction, "The IResult returned by the Delegate must not be null." },
};
}
}

[Theory]
[MemberData(nameof(NullResult))]
public async Task RequestDelegateThrowsInvalidOperationExceptionOnNullDelegate(Delegate @delegate, string message)
{
var httpContext = new DefaultHttpContext();
var responseBodyStream = new MemoryStream();
httpContext.Response.Body = responseBodyStream;

var requestDelegate = RequestDelegateFactory.Create(@delegate);

var exception = await Assert.ThrowsAnyAsync<InvalidOperationException>(async () => await requestDelegate(httpContext));
Assert.Contains(message, exception.Message);
}

public static IEnumerable<object[]> NullContentResult
{
get
{
bool? TestBoolAction() => null;
Task<bool?> TaskTestBoolAction() => Task.FromResult<bool?>(null);
ValueTask<bool?> ValueTaskTestBoolAction() => ValueTask.FromResult<bool?>(null);

int? TestIntAction() => null;
Task<int?> TaskTestIntAction() => Task.FromResult<int?>(null);
ValueTask<int?> ValueTaskTestIntAction() => ValueTask.FromResult<int?>(null);

Todo? TestTodoAction() => null;
Task<Todo?> TaskTestTodoAction() => Task.FromResult<Todo?>(null);
ValueTask<Todo?> ValueTaskTestTodoAction() => ValueTask.FromResult<Todo?>(null);

return new List<object[]>
{
new object[] { (Func<bool?>)TestBoolAction },
new object[] { (Func<Task<bool?>>)TaskTestBoolAction },
new object[] { (Func<ValueTask<bool?>>)ValueTaskTestBoolAction },
new object[] { (Func<int?>)TestIntAction },
new object[] { (Func<Task<int?>>)TaskTestIntAction },
new object[] { (Func<ValueTask<int?>>)ValueTaskTestIntAction },
new object[] { (Func<Todo?>)TestTodoAction },
new object[] { (Func<Task<Todo?>>)TaskTestTodoAction },
new object[] { (Func<ValueTask<Todo?>>)ValueTaskTestTodoAction },
};
}
}

[Theory]
[MemberData(nameof(NullContentResult))]
public async Task RequestDelegateWritesNullReturnNullValue(Delegate @delegate)
{
var httpContext = new DefaultHttpContext();
var responseBodyStream = new MemoryStream();
httpContext.Response.Body = responseBodyStream;

var requestDelegate = RequestDelegateFactory.Create(@delegate);

await requestDelegate(httpContext);

var responseBody = Encoding.UTF8.GetString(responseBodyStream.ToArray());

Assert.Equal("null", responseBody);
}

private class Todo
{
public int Id { get; set; }
Expand Down

0 comments on commit 0ecf4c1

Please sign in to comment.