Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix IAsyncEnumerable controller methods to allow setting headers #57924

Merged
merged 4 commits into from
Sep 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 22 additions & 80 deletions src/Http/Http.Extensions/src/HttpResponseJsonExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -90,22 +90,12 @@ public static Task WriteAsJsonAsync<TValue>(

response.ContentType = contentType ?? ContentTypeConstants.JsonContentTypeWithCharset;

var startTask = Task.CompletedTask;
if (!response.HasStarted)
{
// Flush headers before starting Json serialization. This avoids an extra layer of buffering before the first flush.
startTask = response.StartAsync(cancellationToken);
}

// if no user provided token, pass the RequestAborted token and ignore OperationCanceledException
if (!startTask.IsCompleted || !cancellationToken.CanBeCanceled)
if (!cancellationToken.CanBeCanceled)
{
return WriteAsJsonAsyncSlow(startTask, response.BodyWriter, value, options,
ignoreOCE: !cancellationToken.CanBeCanceled,
cancellationToken.CanBeCanceled ? cancellationToken : response.HttpContext.RequestAborted);
return WriteAsJsonAsyncSlow(response.BodyWriter, value, options, response.HttpContext.RequestAborted);
}

startTask.GetAwaiter().GetResult();
return JsonSerializer.SerializeAsync(response.BodyWriter, value, options, cancellationToken);
}

Expand All @@ -131,33 +121,22 @@ public static Task WriteAsJsonAsync<TValue>(

response.ContentType = contentType ?? ContentTypeConstants.JsonContentTypeWithCharset;

var startTask = Task.CompletedTask;
if (!response.HasStarted)
{
// Flush headers before starting Json serialization. This avoids an extra layer of buffering before the first flush.
startTask = response.StartAsync(cancellationToken);
}

// if no user provided token, pass the RequestAborted token and ignore OperationCanceledException
if (!startTask.IsCompleted || !cancellationToken.CanBeCanceled)
if (!cancellationToken.CanBeCanceled)
{
return WriteAsJsonAsyncSlow(startTask, response, value, jsonTypeInfo,
ignoreOCE: !cancellationToken.CanBeCanceled,
cancellationToken.CanBeCanceled ? cancellationToken : response.HttpContext.RequestAborted);
return WriteAsJsonAsyncSlow(response, value, jsonTypeInfo, response.HttpContext.RequestAborted);
}

startTask.GetAwaiter().GetResult();
return JsonSerializer.SerializeAsync(response.BodyWriter, value, jsonTypeInfo, cancellationToken);

static async Task WriteAsJsonAsyncSlow(Task startTask, HttpResponse response, TValue value, JsonTypeInfo<TValue> jsonTypeInfo,
bool ignoreOCE, CancellationToken cancellationToken)
static async Task WriteAsJsonAsyncSlow(HttpResponse response, TValue value, JsonTypeInfo<TValue> jsonTypeInfo,
CancellationToken cancellationToken)
{
try
{
await startTask;
await JsonSerializer.SerializeAsync(response.BodyWriter, value, jsonTypeInfo, cancellationToken);
}
catch (OperationCanceledException) when (ignoreOCE) { }
catch (OperationCanceledException) { }
}
}

Expand All @@ -184,52 +163,38 @@ public static Task WriteAsJsonAsync(

response.ContentType = contentType ?? ContentTypeConstants.JsonContentTypeWithCharset;

var startTask = Task.CompletedTask;
if (!response.HasStarted)
{
// Flush headers before starting Json serialization. This avoids an extra layer of buffering before the first flush.
startTask = response.StartAsync(cancellationToken);
}

// if no user provided token, pass the RequestAborted token and ignore OperationCanceledException
if (!startTask.IsCompleted || !cancellationToken.CanBeCanceled)
if (!cancellationToken.CanBeCanceled)
{
return WriteAsJsonAsyncSlow(startTask, response, value, jsonTypeInfo,
ignoreOCE: !cancellationToken.CanBeCanceled,
cancellationToken.CanBeCanceled ? cancellationToken : response.HttpContext.RequestAborted);
return WriteAsJsonAsyncSlow(response, value, jsonTypeInfo, response.HttpContext.RequestAborted);
}

startTask.GetAwaiter().GetResult();
return JsonSerializer.SerializeAsync(response.BodyWriter, value, jsonTypeInfo, cancellationToken);

static async Task WriteAsJsonAsyncSlow(Task startTask, HttpResponse response, object? value, JsonTypeInfo jsonTypeInfo,
bool ignoreOCE, CancellationToken cancellationToken)
static async Task WriteAsJsonAsyncSlow(HttpResponse response, object? value, JsonTypeInfo jsonTypeInfo,
CancellationToken cancellationToken)
{
try
{
await startTask;
await JsonSerializer.SerializeAsync(response.BodyWriter, value, jsonTypeInfo, cancellationToken);
}
catch (OperationCanceledException) when (ignoreOCE) { }
catch (OperationCanceledException) { }
}
}

[RequiresUnreferencedCode(RequiresUnreferencedCodeMessage)]
[RequiresDynamicCode(RequiresDynamicCodeMessage)]
private static async Task WriteAsJsonAsyncSlow<TValue>(
Task startTask,
PipeWriter body,
TValue value,
JsonSerializerOptions? options,
bool ignoreOCE,
CancellationToken cancellationToken)
{
try
{
await startTask;
await JsonSerializer.SerializeAsync(body, value, options, cancellationToken);
}
catch (OperationCanceledException) when (ignoreOCE) { }
catch (OperationCanceledException) { }
}

/// <summary>
Expand Down Expand Up @@ -304,42 +269,30 @@ public static Task WriteAsJsonAsync(

response.ContentType = contentType ?? ContentTypeConstants.JsonContentTypeWithCharset;

var startTask = Task.CompletedTask;
if (!response.HasStarted)
{
// Flush headers before starting Json serialization. This avoids an extra layer of buffering before the first flush.
startTask = response.StartAsync(cancellationToken);
}

// if no user provided token, pass the RequestAborted token and ignore OperationCanceledException
if (!startTask.IsCompleted || !cancellationToken.CanBeCanceled)
if (!cancellationToken.CanBeCanceled)
{
return WriteAsJsonAsyncSlow(startTask, response.BodyWriter, value, type, options,
ignoreOCE: !cancellationToken.CanBeCanceled,
cancellationToken.CanBeCanceled ? cancellationToken : response.HttpContext.RequestAborted);
return WriteAsJsonAsyncSlow(response.BodyWriter, value, type, options,
response.HttpContext.RequestAborted);
}

startTask.GetAwaiter().GetResult();
return JsonSerializer.SerializeAsync(response.BodyWriter, value, type, options, cancellationToken);
}

[RequiresUnreferencedCode(RequiresUnreferencedCodeMessage)]
[RequiresDynamicCode(RequiresDynamicCodeMessage)]
private static async Task WriteAsJsonAsyncSlow(
Task startTask,
PipeWriter body,
object? value,
Type type,
JsonSerializerOptions? options,
bool ignoreOCE,
CancellationToken cancellationToken)
{
try
{
await startTask;
await JsonSerializer.SerializeAsync(body, value, type, options, cancellationToken);
}
catch (OperationCanceledException) when (ignoreOCE) { }
catch (OperationCanceledException) { }
}

/// <summary>
Expand Down Expand Up @@ -367,33 +320,22 @@ public static Task WriteAsJsonAsync(

response.ContentType = contentType ?? ContentTypeConstants.JsonContentTypeWithCharset;

var startTask = Task.CompletedTask;
if (!response.HasStarted)
{
// Flush headers before starting Json serialization. This avoids an extra layer of buffering before the first flush.
startTask = response.StartAsync(cancellationToken);
}

// if no user provided token, pass the RequestAborted token and ignore OperationCanceledException
if (!startTask.IsCompleted || !cancellationToken.CanBeCanceled)
if (!cancellationToken.CanBeCanceled)
{
return WriteAsJsonAsyncSlow(startTask, response.BodyWriter, value, type, context,
ignoreOCE: !cancellationToken.CanBeCanceled,
cancellationToken.CanBeCanceled ? cancellationToken : response.HttpContext.RequestAborted);
return WriteAsJsonAsyncSlow(response.BodyWriter, value, type, context, response.HttpContext.RequestAborted);
}

startTask.GetAwaiter().GetResult();
return JsonSerializer.SerializeAsync(response.BodyWriter, value, type, context, cancellationToken);

static async Task WriteAsJsonAsyncSlow(Task startTask, PipeWriter body, object? value, Type type, JsonSerializerContext context,
bool ignoreOCE, CancellationToken cancellationToken)
static async Task WriteAsJsonAsyncSlow(PipeWriter body, object? value, Type type, JsonSerializerContext context,
CancellationToken cancellationToken)
{
try
{
await startTask;
await JsonSerializer.SerializeAsync(body, value, type, context, cancellationToken);
}
catch (OperationCanceledException) when (ignoreOCE) { }
catch (OperationCanceledException) { }
}
}

Expand Down
70 changes: 69 additions & 1 deletion src/Http/Http.Extensions/test/HttpResponseJsonExtensionsTests.cs
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.IO.Pipelines;
using System.Runtime.CompilerServices;
using System.Text;
using System.Text.Json;
using System.Text.Json.Serialization;
using System.Text.Json.Serialization.Metadata;
using Microsoft.AspNetCore.InternalTesting;
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Http.Features;
using Microsoft.AspNetCore.TestHost;

#nullable enable

Expand Down Expand Up @@ -481,6 +484,71 @@ public async Task WriteAsJsonAsync_NullValue_WithJsonTypeInfo_JsonResponse()
Assert.Equal("null", data);
}

// Regression test: https://github.com/dotnet/aspnetcore/issues/57895
[Fact]
public async Task AsyncEnumerableCanSetHeader()
{
var builder = WebApplication.CreateBuilder();
builder.WebHost.UseTestServer();

await using var app = builder.Build();

app.MapGet("/", IAsyncEnumerable<int> (HttpContext httpContext) =>
{
return AsyncEnum();

async IAsyncEnumerable<int> AsyncEnum()
{
await Task.Yield();
httpContext.Response.Headers["Test"] = "t";
yield return 1;
}
});

await app.StartAsync();

var client = app.GetTestClient();

var result = await client.GetAsync("/");
result.EnsureSuccessStatusCode();
var headerValue = Assert.Single(result.Headers.GetValues("Test"));
Assert.Equal("t", headerValue);

await app.StopAsync();
}

// Regression test: https://github.com/dotnet/aspnetcore/issues/57895
[Fact]
public async Task EnumerableCanSetHeader()
{
var builder = WebApplication.CreateBuilder();
builder.WebHost.UseTestServer();

await using var app = builder.Build();

app.MapGet("/", IEnumerable<int> (HttpContext httpContext) =>
{
return Enum();

IEnumerable<int> Enum()
{
httpContext.Response.Headers["Test"] = "t";
yield return 1;
}
});

await app.StartAsync();

var client = app.GetTestClient();

var result = await client.GetAsync("/");
result.EnsureSuccessStatusCode();
var headerValue = Assert.Single(result.Headers.GetValues("Test"));
Assert.Equal("t", headerValue);

await app.StopAsync();
}

public class TestObject
{
public string? StringProperty { get; set; }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
<Reference Include="Microsoft.AspNetCore.Http.Results" />
<Reference Include="Microsoft.AspNetCore.Http.Extensions" />
<Reference Include="Microsoft.AspNetCore.Mvc.Core" />
<Reference Include="Microsoft.AspNetCore.TestHost" />
<Reference Include="Microsoft.Extensions.DependencyInjection" />
<Reference Include="Microsoft.Extensions.DependencyModel" />
</ItemGroup>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,6 @@ public sealed override async Task WriteResponseBodyAsync(OutputFormatterWriteCon
try
{
var responseWriter = httpContext.Response.BodyWriter;
if (!httpContext.Response.HasStarted)
{
// Flush headers before starting Json serialization. This avoids an extra layer of buffering before the first flush.
await httpContext.Response.StartAsync();
}

if (jsonTypeInfo is not null)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,6 @@ public async Task ExecuteAsync(ActionContext context, JsonResult result)
try
{
var responseWriter = response.BodyWriter;
if (!response.HasStarted)
{
// Flush headers before starting Json serialization. This avoids an extra layer of buffering before the first flush.
await response.StartAsync();
}

await JsonSerializer.SerializeAsync(responseWriter, value, objectType, jsonSerializerOptions, context.HttpContext.RequestAborted);
}
catch (OperationCanceledException) when (context.HttpContext.RequestAborted.IsCancellationRequested) { }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,4 +65,21 @@ public async Task Formatting_PolymorphicModel_WithJsonPolymorphism()
await response.AssertStatusCodeAsync(HttpStatusCode.OK);
Assert.Equal(expected, await response.Content.ReadAsStringAsync());
}

// Regression test: https://github.com/dotnet/aspnetcore/issues/57895
[Fact]
public async Task CanSetHeaderWithAsyncEnumerable()
{
// Arrange
var expected = "[1]";

// Act
var response = await Client.GetAsync($"/SystemTextJsonOutputFormatter/{nameof(SystemTextJsonOutputFormatterController.AsyncEnumerable)}");

// Assert
await response.AssertStatusCodeAsync(HttpStatusCode.OK);
Assert.Equal(expected, await response.Content.ReadAsStringAsync());
var headerValue = Assert.Single(response.Headers.GetValues("Test"));
Assert.Equal("t", headerValue);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,14 @@ public class SystemTextJsonOutputFormatterController : ControllerBase
Address = "Some address",
};

[HttpGet]
public async IAsyncEnumerable<int> AsyncEnumerable()
{
await Task.Yield();
HttpContext.Response.Headers["Test"] = "t";
yield return 1;
}

[JsonPolymorphic]
[JsonDerivedType(typeof(DerivedModel), nameof(DerivedModel))]
public class SimpleModel
Expand Down
Loading