From 0ecf4c1bd9526c36356a0f259d90aafdf0756fcd Mon Sep 17 00:00:00 2001 From: wcontayon Date: Wed, 9 Jun 2021 20:27:37 +0200 Subject: [PATCH] Add null check on Task, Task, ValueTask, ValueTask (#33079) --- .../src/RequestDelegateFactory.cs | 66 +++++++++++---- .../test/RequestDelegateFactoryTests.cs | 83 +++++++++++++++++++ 2 files changed, 135 insertions(+), 14 deletions(-) diff --git a/src/Http/Http.Extensions/src/RequestDelegateFactory.cs b/src/Http/Http.Extensions/src/RequestDelegateFactory.cs index 7762e2e1467f..6232a5479a00 100644 --- a/src/Http/Http.Extensions/src/RequestDelegateFactory.cs +++ b/src/Http/Http.Extensions/src/RequestDelegateFactory.cs @@ -4,6 +4,7 @@ using System; using System.Collections.Concurrent; using System.Collections.Generic; +using System.Globalization; using System.IO; using System.Linq; using System.Linq.Expressions; @@ -30,7 +31,7 @@ public static class RequestDelegateFactory private static readonly MethodInfo ExecuteTaskResultOfTMethod = typeof(RequestDelegateFactory).GetMethod(nameof(ExecuteTaskResult), BindingFlags.NonPublic | BindingFlags.Static)!; private static readonly MethodInfo ExecuteValueResultTaskOfTMethod = typeof(RequestDelegateFactory).GetMethod(nameof(ExecuteValueTaskResult), BindingFlags.NonPublic | BindingFlags.Static)!; private static readonly MethodInfo GetRequiredServiceMethod = typeof(ServiceProviderServiceExtensions).GetMethod(nameof(ServiceProviderServiceExtensions.GetRequiredService), BindingFlags.Public | BindingFlags.Static, new Type[] { typeof(IServiceProvider) })!; - private static readonly MethodInfo ResultWriteResponseAsyncMethod = typeof(IResult).GetMethod(nameof(IResult.ExecuteAsync), BindingFlags.Public | BindingFlags.Instance)!; + private static readonly MethodInfo ResultWriteResponseAsyncMethod = typeof(RequestDelegateFactory).GetMethod(nameof(ExecuteResultWriteResponse), BindingFlags.NonPublic | BindingFlags.Static)!; private static readonly MethodInfo StringResultWriteResponseAsyncMethod = GetMethodInfo>((response, text) => HttpResponseWritingExtensions.WriteAsync(response, text, default)); private static readonly MethodInfo JsonResultWriteResponseAsyncMethod = GetMethodInfo>((response, value) => HttpResponseJsonExtensions.WriteAsJsonAsync(response, value, default)); private static readonly MethodInfo EnumTryParseMethod = GetEnumTryParseMethod(); @@ -393,7 +394,7 @@ private static Expression AddResponseWritingToMethodCall(Expression methodCall, } else if (typeof(IResult).IsAssignableFrom(returnType)) { - return Expression.Call(methodCall, ResultWriteResponseAsyncMethod, HttpContextExpr); + return Expression.Call(ResultWriteResponseAsyncMethod, methodCall, HttpContextExpr); } else if (returnType == typeof(string)) { @@ -679,6 +680,8 @@ private static MemberInfo GetMemberInfo(Expression expr) private static Task ExecuteTask(Task task, HttpContext httpContext) { + EnsureRequestTaskNotNull(task); + static async Task ExecuteAwaited(Task task, HttpContext httpContext) { await httpContext.Response.WriteAsJsonAsync(await task); @@ -692,8 +695,10 @@ static async Task ExecuteAwaited(Task task, HttpContext httpContext) return ExecuteAwaited(task, httpContext); } - private static Task ExecuteTaskOfString(Task task, HttpContext httpContext) + private static Task ExecuteTaskOfString(Task task, HttpContext httpContext) { + EnsureRequestTaskNotNull(task); + static async Task ExecuteAwaited(Task task, HttpContext httpContext) { await httpContext.Response.WriteAsync(await task); @@ -701,10 +706,10 @@ static async Task ExecuteAwaited(Task task, HttpContext httpContext) if (task.IsCompletedSuccessfully) { - return httpContext.Response.WriteAsync(task.GetAwaiter().GetResult()); + return httpContext.Response.WriteAsync(task.GetAwaiter().GetResult()!); } - return ExecuteAwaited(task, httpContext); + return ExecuteAwaited(task!, httpContext); } private static Task ExecuteValueTask(ValueTask task) @@ -737,7 +742,7 @@ static async Task ExecuteAwaited(ValueTask task, HttpContext httpContext) return ExecuteAwaited(task, httpContext); } - private static Task ExecuteValueTaskOfString(ValueTask task, HttpContext httpContext) + private static Task ExecuteValueTaskOfString(ValueTask task, HttpContext httpContext) { static async Task ExecuteAwaited(ValueTask task, HttpContext httpContext) { @@ -746,30 +751,37 @@ static async Task ExecuteAwaited(ValueTask task, HttpContext httpContext if (task.IsCompletedSuccessfully) { - return httpContext.Response.WriteAsync(task.GetAwaiter().GetResult()); + return httpContext.Response.WriteAsync(task.GetAwaiter().GetResult()!); } - return ExecuteAwaited(task, httpContext); + return ExecuteAwaited(task!, httpContext); } - private static Task ExecuteValueTaskResult(ValueTask task, HttpContext httpContext) where T : IResult + private static Task ExecuteValueTaskResult(ValueTask task, HttpContext httpContext) where T : IResult { static async Task ExecuteAwaited(ValueTask task, HttpContext httpContext) { - await (await task).ExecuteAsync(httpContext); + await EnsureRequestResultNotNull(await task)!.ExecuteAsync(httpContext); } if (task.IsCompletedSuccessfully) { - return task.GetAwaiter().GetResult().ExecuteAsync(httpContext); + return EnsureRequestResultNotNull(task.GetAwaiter().GetResult())!.ExecuteAsync(httpContext); } - return ExecuteAwaited(task, httpContext); + return ExecuteAwaited(task!, httpContext); } - private static async Task ExecuteTaskResult(Task task, HttpContext httpContext) where T : IResult + private static async Task ExecuteTaskResult(Task task, HttpContext httpContext) where T : IResult { - await (await task).ExecuteAsync(httpContext); + EnsureRequestTaskOfNotNull(task); + + await EnsureRequestResultNotNull(await task)!.ExecuteAsync(httpContext); + } + + private static async Task ExecuteResultWriteResponse(IResult result, HttpContext httpContext) + { + await EnsureRequestResultNotNull(result)!.ExecuteAsync(httpContext); } private class FactoryContext @@ -819,5 +831,31 @@ private static ILogger GetLogger(HttpContext httpContext) return loggerFactory.CreateLogger(typeof(RequestDelegateFactory)); } } + + private static void EnsureRequestTaskOfNotNull(Task task) where T : IResult + { + if (task is null) + { + throw new InvalidOperationException("The IResult in Task response must not be null."); + } + } + + private static void EnsureRequestTaskNotNull(Task? task) + { + if (task is null) + { + throw new InvalidOperationException("The Task returned by the Delegate must not be null."); + } + } + + private static IResult EnsureRequestResultNotNull(IResult? result) + { + if (result is null) + { + throw new InvalidOperationException("The IResult returned by the Delegate must not be null."); + } + + return result; + } } } diff --git a/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs b/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs index 32e115a80038..ae85b27e3ee3 100644 --- a/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs +++ b/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs @@ -1081,6 +1081,89 @@ public async Task RequestDelegateWritesBoolReturnValue(Delegate @delegate) Assert.Equal("true", responseBody); } + public static IEnumerable NullResult + { + get + { + IResult? TestAction() => null; + Task? TaskBoolAction() => null; + Task? TaskNullAction() => null; + Task TaskTestAction() => Task.FromResult(null); + ValueTask ValueTaskTestAction() => ValueTask.FromResult(null); + + return new List + { + new object[] { (Func)TestAction, "The IResult returned by the Delegate must not be null." }, + new object[] { (Func?>)TaskNullAction, "The IResult in Task response must not be null." }, + new object[] { (Func?>)TaskBoolAction, "The Task returned by the Delegate must not be null." }, + new object[] { (Func>)TaskTestAction, "The IResult returned by the Delegate must not be null." }, + new object[] { (Func>)ValueTaskTestAction, "The IResult returned by the Delegate must not be null." }, + }; + } + } + + [Theory] + [MemberData(nameof(NullResult))] + public async Task RequestDelegateThrowsInvalidOperationExceptionOnNullDelegate(Delegate @delegate, string message) + { + var httpContext = new DefaultHttpContext(); + var responseBodyStream = new MemoryStream(); + httpContext.Response.Body = responseBodyStream; + + var requestDelegate = RequestDelegateFactory.Create(@delegate); + + var exception = await Assert.ThrowsAnyAsync(async () => await requestDelegate(httpContext)); + Assert.Contains(message, exception.Message); + } + + public static IEnumerable NullContentResult + { + get + { + bool? TestBoolAction() => null; + Task TaskTestBoolAction() => Task.FromResult(null); + ValueTask ValueTaskTestBoolAction() => ValueTask.FromResult(null); + + int? TestIntAction() => null; + Task TaskTestIntAction() => Task.FromResult(null); + ValueTask ValueTaskTestIntAction() => ValueTask.FromResult(null); + + Todo? TestTodoAction() => null; + Task TaskTestTodoAction() => Task.FromResult(null); + ValueTask ValueTaskTestTodoAction() => ValueTask.FromResult(null); + + return new List + { + new object[] { (Func)TestBoolAction }, + new object[] { (Func>)TaskTestBoolAction }, + new object[] { (Func>)ValueTaskTestBoolAction }, + new object[] { (Func)TestIntAction }, + new object[] { (Func>)TaskTestIntAction }, + new object[] { (Func>)ValueTaskTestIntAction }, + new object[] { (Func)TestTodoAction }, + new object[] { (Func>)TaskTestTodoAction }, + new object[] { (Func>)ValueTaskTestTodoAction }, + }; + } + } + + [Theory] + [MemberData(nameof(NullContentResult))] + public async Task RequestDelegateWritesNullReturnNullValue(Delegate @delegate) + { + var httpContext = new DefaultHttpContext(); + var responseBodyStream = new MemoryStream(); + httpContext.Response.Body = responseBodyStream; + + var requestDelegate = RequestDelegateFactory.Create(@delegate); + + await requestDelegate(httpContext); + + var responseBody = Encoding.UTF8.GetString(responseBodyStream.ToArray()); + + Assert.Equal("null", responseBody); + } + private class Todo { public int Id { get; set; }