diff --git a/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http2/Http2StreamTests.cs b/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http2/Http2StreamTests.cs index 7614f84016da..7bdaa893a90b 100644 --- a/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http2/Http2StreamTests.cs +++ b/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http2/Http2StreamTests.cs @@ -3712,6 +3712,87 @@ await ExpectAsync(Http2FrameType.DATA, Assert.Equal("Hello World", Encoding.UTF8.GetString(bodyFrame.Payload.Span)); } + [Fact] + public async Task CompleteAsync_AfterPipeWrite_WithTrailers_SendsBodyAndTrailersWithEndStream() + { + var startingTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var appTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var clientTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var headers = new[] + { + new KeyValuePair(HeaderNames.Method, "GET"), + new KeyValuePair(HeaderNames.Path, "/"), + new KeyValuePair(HeaderNames.Scheme, "http"), + }; + await InitializeConnectionAsync(async context => + { + try + { + context.Response.OnStarting(() => { startingTcs.SetResult(0); return Task.CompletedTask; }); + var completionFeature = context.Features.Get(); + Assert.NotNull(completionFeature); + + var buffer = context.Response.BodyWriter.GetMemory(); + var length = Encoding.UTF8.GetBytes("Hello World", buffer.Span); + context.Response.BodyWriter.Advance(length); + + Assert.False(startingTcs.Task.IsCompletedSuccessfully); // OnStarting did not get called. + Assert.False(context.Response.Headers.IsReadOnly); + + context.Response.AppendTrailer("CustomName", "Custom Value"); + + await completionFeature.CompleteAsync().DefaultTimeout(); + Assert.True(startingTcs.Task.IsCompletedSuccessfully); // OnStarting got called. + Assert.True(context.Response.Headers.IsReadOnly); + + Assert.True(context.Features.Get().Trailers.IsReadOnly); + + // Make sure the client gets our results from CompleteAsync instead of from the request delegate exiting. + await clientTcs.Task.DefaultTimeout(); + appTcs.SetResult(0); + } + catch (Exception ex) + { + appTcs.SetException(ex); + } + }); + + await StartStreamAsync(1, headers, endStream: true); + + var headersFrame = await ExpectAsync(Http2FrameType.HEADERS, + withLength: 37, + withFlags: (byte)(Http2HeadersFrameFlags.END_HEADERS), + withStreamId: 1); + var bodyFrame = await ExpectAsync(Http2FrameType.DATA, + withLength: 11, + withFlags: (byte)(Http2HeadersFrameFlags.NONE), + withStreamId: 1); + var trailersFrame = await ExpectAsync(Http2FrameType.HEADERS, + withLength: 25, + withFlags: (byte)(Http2HeadersFrameFlags.END_HEADERS | Http2HeadersFrameFlags.END_STREAM), + withStreamId: 1); + + clientTcs.SetResult(0); + await appTcs.Task; + + await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false); + + _hpackDecoder.Decode(headersFrame.PayloadSequence, endHeaders: false, handler: this); + + Assert.Equal(2, _decodedHeaders.Count); + Assert.Contains("date", _decodedHeaders.Keys, StringComparer.OrdinalIgnoreCase); + Assert.Equal("200", _decodedHeaders[HeaderNames.Status]); + + Assert.Equal("Hello World", Encoding.UTF8.GetString(bodyFrame.Payload.Span)); + + _decodedHeaders.Clear(); + + _hpackDecoder.Decode(trailersFrame.PayloadSequence, endHeaders: true, handler: this); + + Assert.Single(_decodedHeaders); + Assert.Equal("Custom Value", _decodedHeaders["CustomName"]); + } + [Fact] public async Task CompleteAsync_AfterBodyStarted_WithTrailers_SendsBodyAndTrailersWithEndStream() {