From f27a2336e06957a82e15e47a889f774efc34f262 Mon Sep 17 00:00:00 2001 From: Chris R Date: Thu, 13 Jun 2019 14:23:07 -0700 Subject: [PATCH 1/4] Implement Http/2 CompleteAsync #10886 --- .../src/IHttpResponseCompletionFeature.cs | 20 + .../Http/HttpProtocol.FeatureCollection.cs | 1 + .../Internal/Http/HttpProtocol.Generated.cs | 23 + .../Core/src/Internal/Http/HttpProtocol.cs | 35 +- .../Internal/Http/RequestProcessingStatus.cs | 3 +- .../src/Internal/Http2/Http2OutputProducer.cs | 7 +- .../Http2/Http2Stream.FeatureCollection.cs | 38 +- .../Http2/Http2StreamTests.cs | 719 ++++++++++++++++++ .../HttpProtocolFeatureCollection.cs | 1 + 9 files changed, 829 insertions(+), 18 deletions(-) create mode 100644 src/Http/Http.Features/src/IHttpResponseCompletionFeature.cs diff --git a/src/Http/Http.Features/src/IHttpResponseCompletionFeature.cs b/src/Http/Http.Features/src/IHttpResponseCompletionFeature.cs new file mode 100644 index 000000000000..eed45e40364b --- /dev/null +++ b/src/Http/Http.Features/src/IHttpResponseCompletionFeature.cs @@ -0,0 +1,20 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Threading.Tasks; + +namespace Microsoft.AspNetCore.Http.Features +{ + /// + /// A feature to gracefully end a response. + /// + public interface IHttpResponseCompletionFeature + { + /// + /// Flush any remaining response headers, data, or trailers. + /// This may throw if the response is in an invalid state such as a Content-Length mismatch. + /// + /// + Task CompleteAsync(); + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.FeatureCollection.cs b/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.FeatureCollection.cs index 82b1ebfe6b1a..f922de8a995f 100644 --- a/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.FeatureCollection.cs +++ b/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.FeatureCollection.cs @@ -277,6 +277,7 @@ protected void ResetHttp1Features() protected void ResetHttp2Features() { _currentIHttp2StreamIdFeature = this; + _currentIHttpResponseCompletionFeature = this; _currentIHttpResponseTrailersFeature = this; } diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.Generated.cs b/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.Generated.cs index b9b3e26905e7..f594feed0fa3 100644 --- a/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.Generated.cs +++ b/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.Generated.cs @@ -29,6 +29,7 @@ internal partial class HttpProtocol : IFeatureCollection private static readonly Type IFormFeatureType = typeof(IFormFeature); private static readonly Type IHttpUpgradeFeatureType = typeof(IHttpUpgradeFeature); private static readonly Type IHttp2StreamIdFeatureType = typeof(IHttp2StreamIdFeature); + private static readonly Type IHttpResponseCompletionFeatureType = typeof(IHttpResponseCompletionFeature); private static readonly Type IHttpResponseTrailersFeatureType = typeof(IHttpResponseTrailersFeature); private static readonly Type IResponseCookiesFeatureType = typeof(IResponseCookiesFeature); private static readonly Type IItemsFeatureType = typeof(IItemsFeature); @@ -58,6 +59,7 @@ internal partial class HttpProtocol : IFeatureCollection private object _currentIFormFeature; private object _currentIHttpUpgradeFeature; private object _currentIHttp2StreamIdFeature; + private object _currentIHttpResponseCompletionFeature; private object _currentIHttpResponseTrailersFeature; private object _currentIResponseCookiesFeature; private object _currentIItemsFeature; @@ -98,6 +100,7 @@ private void FastReset() _currentIQueryFeature = null; _currentIFormFeature = null; _currentIHttp2StreamIdFeature = null; + _currentIHttpResponseCompletionFeature = null; _currentIHttpResponseTrailersFeature = null; _currentIResponseCookiesFeature = null; _currentIItemsFeature = null; @@ -224,6 +227,10 @@ object IFeatureCollection.this[Type key] { feature = _currentIHttp2StreamIdFeature; } + else if (key == IHttpResponseCompletionFeatureType) + { + feature = _currentIHttpResponseCompletionFeature; + } else if (key == IHttpResponseTrailersFeatureType) { feature = _currentIHttpResponseTrailersFeature; @@ -348,6 +355,10 @@ object IFeatureCollection.this[Type key] { _currentIHttp2StreamIdFeature = value; } + else if (key == IHttpResponseCompletionFeatureType) + { + _currentIHttpResponseCompletionFeature = value; + } else if (key == IHttpResponseTrailersFeatureType) { _currentIHttpResponseTrailersFeature = value; @@ -470,6 +481,10 @@ TFeature IFeatureCollection.Get() { feature = (TFeature)_currentIHttp2StreamIdFeature; } + else if (typeof(TFeature) == typeof(IHttpResponseCompletionFeature)) + { + feature = (TFeature)_currentIHttpResponseCompletionFeature; + } else if (typeof(TFeature) == typeof(IHttpResponseTrailersFeature)) { feature = (TFeature)_currentIHttpResponseTrailersFeature; @@ -598,6 +613,10 @@ void IFeatureCollection.Set(TFeature feature) { _currentIHttp2StreamIdFeature = feature; } + else if (typeof(TFeature) == typeof(IHttpResponseCompletionFeature)) + { + _currentIHttpResponseCompletionFeature = feature; + } else if (typeof(TFeature) == typeof(IHttpResponseTrailersFeature)) { _currentIHttpResponseTrailersFeature = feature; @@ -718,6 +737,10 @@ private IEnumerable> FastEnumerable() { yield return new KeyValuePair(IHttp2StreamIdFeatureType, _currentIHttp2StreamIdFeature); } + if (_currentIHttpResponseCompletionFeature != null) + { + yield return new KeyValuePair(IHttpResponseCompletionFeatureType, _currentIHttpResponseCompletionFeature); + } if (_currentIHttpResponseTrailersFeature != null) { yield return new KeyValuePair(IHttpResponseTrailersFeatureType, _currentIHttpResponseTrailersFeature); diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.cs b/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.cs index 41a4f2e6aa54..37ad4d6dc4bd 100644 --- a/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.cs +++ b/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.cs @@ -210,6 +210,7 @@ private void HttpVersionSetSlow(string value) public bool RequestTrailersAvailable { get; set; } public Stream RequestBody { get; set; } public PipeReader RequestBodyPipeReader { get; set; } + public HttpResponseTrailers ResponseTrailers { get; set; } private int _statusCode; public int StatusCode @@ -287,7 +288,9 @@ public CancellationToken RequestAborted public bool HasResponseStarted => _requestProcessingStatus >= RequestProcessingStatus.HeadersCommitted; - public bool HasFlushedHeaders => _requestProcessingStatus == RequestProcessingStatus.HeadersFlushed; + public bool HasFlushedHeaders => _requestProcessingStatus >= RequestProcessingStatus.HeadersFlushed; + + public bool HasResponseCompleted => _requestProcessingStatus == RequestProcessingStatus.ResponseCompleted; protected HttpRequestHeaders HttpRequestHeaders { get; } @@ -632,9 +635,9 @@ private async Task ProcessRequests(IHttpApplication applicat // Run the application code for this request await application.ProcessRequestAsync(context); - if (!_connectionAborted) + if (!_connectionAborted && !VerifyResponseContentLength(out var lengthException)) { - VerifyResponseContentLength(); + ReportApplicationError(lengthException); } } catch (BadHttpRequestException ex) @@ -898,7 +901,7 @@ private void CheckLastWrite() } } - protected void VerifyResponseContentLength() + protected bool VerifyResponseContentLength(out Exception ex) { var responseHeaders = HttpResponseHeaders; @@ -915,9 +918,13 @@ protected void VerifyResponseContentLength() _keepAlive = false; } - ReportApplicationError(new InvalidOperationException( - CoreStrings.FormatTooFewBytesWritten(_responseBytesWritten, responseHeaders.ContentLength.Value))); + ex = new InvalidOperationException( + CoreStrings.FormatTooFewBytesWritten(_responseBytesWritten, responseHeaders.ContentLength.Value)); + return false; } + + ex = null; + return true; } public void ProduceContinue() @@ -935,7 +942,7 @@ public void ProduceContinue() } } - public Task InitializeResponseAsync(int firstWriteByteCount) + public Task InitializeResponseAsync(int firstWriteByteCount, bool appCompleted = false) { var startingTask = FireOnStarting(); // If return is Task.CompletedTask no awaiting is required @@ -946,7 +953,7 @@ public Task InitializeResponseAsync(int firstWriteByteCount) VerifyInitializeState(firstWriteByteCount); - ProduceStart(appCompleted: false); + ProduceStart(appCompleted: appCompleted); return Task.CompletedTask; } @@ -1043,8 +1050,13 @@ protected Task ProduceEnd() return WriteSuffix(); } - private Task WriteSuffix() + protected Task WriteSuffix() { + if (HasResponseCompleted) + { + return Task.CompletedTask; + } + // _autoChunk should be checked after we are sure ProduceStart() has been called // since ProduceStart() may set _autoChunk to true. if (_autoChunk || _httpVersion == Http.HttpVersion.Http2) @@ -1064,7 +1076,7 @@ private Task WriteSuffix() if (!HasFlushedHeaders) { - _requestProcessingStatus = RequestProcessingStatus.HeadersFlushed; + _requestProcessingStatus = RequestProcessingStatus.ResponseCompleted; return FlushAsyncInternal(); } @@ -1080,6 +1092,8 @@ private async Task WriteSuffixAwaited() await Output.WriteStreamSuffixAsync(); + _requestProcessingStatus = RequestProcessingStatus.ResponseCompleted; + if (_keepAlive) { Log.ConnectionKeepAlive(ConnectionId); @@ -1244,6 +1258,7 @@ private void SetErrorResponseHeaders(int statusCode) var responseHeaders = HttpResponseHeaders; responseHeaders.Reset(); + ResponseTrailers?.Reset(); var dateHeaderValues = DateHeaderValueManager.GetDateHeaderValues(); responseHeaders.SetRawDate(dateHeaderValues.String, dateHeaderValues.Bytes); diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/RequestProcessingStatus.cs b/src/Servers/Kestrel/Core/src/Internal/Http/RequestProcessingStatus.cs index 61832dc34bdf..6e27fb5dc807 100644 --- a/src/Servers/Kestrel/Core/src/Internal/Http/RequestProcessingStatus.cs +++ b/src/Servers/Kestrel/Core/src/Internal/Http/RequestProcessingStatus.cs @@ -10,6 +10,7 @@ internal enum RequestProcessingStatus ParsingHeaders, AppStarted, HeadersCommitted, - HeadersFlushed + HeadersFlushed, + ResponseCompleted } } diff --git a/src/Servers/Kestrel/Core/src/Internal/Http2/Http2OutputProducer.cs b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2OutputProducer.cs index 4f481850d703..5f0d80a37217 100644 --- a/src/Servers/Kestrel/Core/src/Internal/Http2/Http2OutputProducer.cs +++ b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2OutputProducer.cs @@ -151,7 +151,7 @@ public void WriteResponseHeaders(int statusCode, string ReasonPhrase, HttpRespon // 2. There is no trailing HEADERS frame. Http2HeadersFrameFlags http2HeadersFrame; - if (appCompleted && !_startedWritingDataFrames && (_stream.Trailers == null || _stream.Trailers.Count == 0)) + if (appCompleted && !_startedWritingDataFrames && (_stream.ResponseTrailers == null || _stream.ResponseTrailers.Count == 0)) { _streamEnded = true; http2HeadersFrame = Http2HeadersFrameFlags.END_STREAM; @@ -313,7 +313,7 @@ private async ValueTask ProcessDataWrites() { readResult = await _dataPipe.Reader.ReadAsync(); - if (readResult.IsCompleted && _stream.Trailers?.Count > 0) + if (readResult.IsCompleted && _stream.ResponseTrailers?.Count > 0) { // Output is ending and there are trailers to write // Write any remaining content then write trailers @@ -322,7 +322,8 @@ private async ValueTask ProcessDataWrites() flushResult = await _frameWriter.WriteDataAsync(_streamId, _flowControl, readResult.Buffer, endStream: false); } - flushResult = await _frameWriter.WriteResponseTrailers(_streamId, _stream.Trailers); + _stream.ResponseTrailers.SetReadOnly(); + flushResult = await _frameWriter.WriteResponseTrailers(_streamId, _stream.ResponseTrailers); } else if (readResult.IsCompleted && _streamEnded) { diff --git a/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Stream.FeatureCollection.cs b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Stream.FeatureCollection.cs index 7187fc846240..7ff3c4308b52 100644 --- a/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Stream.FeatureCollection.cs +++ b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Stream.FeatureCollection.cs @@ -2,6 +2,7 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Threading.Tasks; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.Server.Kestrel.Core.Features; @@ -11,21 +12,25 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 { internal partial class Http2Stream : IHttp2StreamIdFeature, IHttpMinRequestBodyDataRateFeature, + IHttpResponseCompletionFeature, IHttpResponseTrailersFeature { - internal HttpResponseTrailers Trailers { get; set; } private IHeaderDictionary _userTrailers; IHeaderDictionary IHttpResponseTrailersFeature.Trailers { get { - if (Trailers == null) + if (ResponseTrailers == null) { - Trailers = new HttpResponseTrailers(); + ResponseTrailers = new HttpResponseTrailers(); + if (HasResponseCompleted) + { + ResponseTrailers.SetReadOnly(); + } } - return _userTrailers ?? Trailers; + return _userTrailers ?? ResponseTrailers; } set { @@ -48,5 +53,30 @@ MinDataRate IHttpMinRequestBodyDataRateFeature.MinDataRate MinRequestBodyDataRate = value; } } + + async Task IHttpResponseCompletionFeature.CompleteAsync() + { + // Finalize headers + if (!HasResponseStarted) + { + if (!VerifyResponseContentLength(out var lengthException)) + { + throw lengthException; + } + + await InitializeResponseAsync(0, appCompleted: true); + } + + // Flush headers, body, trailers... + if (!HasResponseCompleted) + { + if (!VerifyResponseContentLength(out var lengthException)) + { + throw lengthException; + } + + await WriteSuffix(); + } + } } } diff --git a/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http2/Http2StreamTests.cs b/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http2/Http2StreamTests.cs index 09c9dafae6ff..810b3359c796 100644 --- a/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http2/Http2StreamTests.cs +++ b/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http2/Http2StreamTests.cs @@ -1839,6 +1839,32 @@ await InitializeConnectionAsync(context => Assert.Equal("Custom Value", _decodedHeaders["CustomName"]); } + [Fact] + public async Task ResponseTrailers_WithExeption500_Cleared() + { + await InitializeConnectionAsync(context => + { + context.Response.AppendTrailer("CustomName", "Custom Value"); + throw new NotImplementedException("Test Exception"); + }); + + await StartStreamAsync(1, _browserRequestHeaders, endStream: true); + + var headersFrame = await ExpectAsync(Http2FrameType.HEADERS, + withLength: 55, + withFlags: (byte)(Http2HeadersFrameFlags.END_STREAM | Http2HeadersFrameFlags.END_HEADERS), + withStreamId: 1); + + await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false); + + _hpackDecoder.Decode(headersFrame.PayloadSequence, endHeaders: true, handler: this); + + Assert.Equal(3, _decodedHeaders.Count); + Assert.Contains("date", _decodedHeaders.Keys, StringComparer.OrdinalIgnoreCase); + Assert.Equal("500", _decodedHeaders[HeaderNames.Status]); + Assert.Equal("0", _decodedHeaders[HeaderNames.ContentLength]); + } + [Fact] public async Task ResponseTrailers_WithData_Sent() { @@ -3307,5 +3333,698 @@ await InitializeConnectionAsync(async context => Assert.Contains(TestSink.Writes, w => w.EventId.Id == 13 && w.LogLevel == LogLevel.Error && w.Exception is ConnectionAbortedException && w.Exception.InnerException == expectedException); } + + [Fact] + public async Task CompleteAsync_BeforeBodyStarted_SendsHeadersWithEndStream() + { + var startingTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var appTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var clientTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var headers = new[] + { + new KeyValuePair(HeaderNames.Method, "GET"), + new KeyValuePair(HeaderNames.Path, "/"), + new KeyValuePair(HeaderNames.Scheme, "http"), + }; + await InitializeConnectionAsync(async context => + { + try + { + context.Response.OnStarting(() => { startingTcs.SetResult(0); return Task.CompletedTask; }); + var completionFeature = context.Features.Get(); + Assert.NotNull(completionFeature); + + await completionFeature.CompleteAsync().DefaultTimeout(); + + Assert.True(startingTcs.Task.IsCompletedSuccessfully); // OnStarting got called. + Assert.True(context.Response.Headers.IsReadOnly); + Assert.True(context.Features.Get().Trailers.IsReadOnly); + + // Make sure the client gets our results from CompleteAsync instead of from the request delegate exiting. + await clientTcs.Task.DefaultTimeout(); + appTcs.SetResult(0); + } + catch (Exception ex) + { + appTcs.SetException(ex); + } + }); + + await StartStreamAsync(1, headers, endStream: true); + + var headersFrame = await ExpectAsync(Http2FrameType.HEADERS, + withLength: 55, + withFlags: (byte)(Http2HeadersFrameFlags.END_HEADERS | Http2HeadersFrameFlags.END_STREAM), + withStreamId: 1); + + clientTcs.SetResult(0); + await appTcs.Task; + + await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false); + + _hpackDecoder.Decode(headersFrame.PayloadSequence, endHeaders: false, handler: this); + + Assert.Equal(3, _decodedHeaders.Count); + Assert.Contains("date", _decodedHeaders.Keys, StringComparer.OrdinalIgnoreCase); + Assert.Equal("200", _decodedHeaders[HeaderNames.Status]); + Assert.Equal("0", _decodedHeaders["content-length"]); + } + + [Fact] + public async Task CompleteAsync_BeforeBodyStarted_WithTrailers_SendsHeadersAndTrailersWithEndStream() + { + var startingTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var appTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var clientTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var headers = new[] + { + new KeyValuePair(HeaderNames.Method, "GET"), + new KeyValuePair(HeaderNames.Path, "/"), + new KeyValuePair(HeaderNames.Scheme, "http"), + }; + await InitializeConnectionAsync(async context => + { + try + { + context.Response.OnStarting(() => { startingTcs.SetResult(0); return Task.CompletedTask; }); + var completionFeature = context.Features.Get(); + Assert.NotNull(completionFeature); + context.Response.AppendTrailer("CustomName", "Custom Value"); + + await completionFeature.CompleteAsync().DefaultTimeout(); + await completionFeature.CompleteAsync().DefaultTimeout(); // Can be called twice, no-ops + + Assert.True(startingTcs.Task.IsCompletedSuccessfully); // OnStarting got called. + Assert.True(context.Response.Headers.IsReadOnly); + Assert.True(context.Features.Get().Trailers.IsReadOnly); + + // Make sure the client gets our results from CompleteAsync instead of from the request delegate exiting. + await clientTcs.Task.DefaultTimeout(); + appTcs.SetResult(0); + } + catch (Exception ex) + { + appTcs.SetException(ex); + } + }); + + await StartStreamAsync(1, headers, endStream: true); + + var headersFrame = await ExpectAsync(Http2FrameType.HEADERS, + withLength: 55, + withFlags: (byte)(Http2HeadersFrameFlags.END_HEADERS), + withStreamId: 1); + var trailersFrame = await ExpectAsync(Http2FrameType.HEADERS, + withLength: 25, + withFlags: (byte)(Http2HeadersFrameFlags.END_HEADERS | Http2HeadersFrameFlags.END_STREAM), + withStreamId: 1); + + clientTcs.SetResult(0); + await appTcs.Task; + + await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false); + + _hpackDecoder.Decode(headersFrame.PayloadSequence, endHeaders: false, handler: this); + + Assert.Equal(3, _decodedHeaders.Count); + Assert.Contains("date", _decodedHeaders.Keys, StringComparer.OrdinalIgnoreCase); + Assert.Equal("200", _decodedHeaders[HeaderNames.Status]); + Assert.Equal("0", _decodedHeaders["content-length"]); + + _decodedHeaders.Clear(); + + _hpackDecoder.Decode(trailersFrame.PayloadSequence, endHeaders: true, handler: this); + + Assert.Single(_decodedHeaders); + Assert.Equal("Custom Value", _decodedHeaders["CustomName"]); + } + + [Fact] + public async Task CompleteAsync_BeforeBodyStarted_WithTrailers_TruncatedContentLength_ThrowsAnd500() + { + var startingTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var appTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var headers = new[] + { + new KeyValuePair(HeaderNames.Method, "GET"), + new KeyValuePair(HeaderNames.Path, "/"), + new KeyValuePair(HeaderNames.Scheme, "http"), + }; + await InitializeConnectionAsync(async context => + { + try + { + context.Response.OnStarting(() => { startingTcs.SetResult(0); return Task.CompletedTask; }); + var completionFeature = context.Features.Get(); + Assert.NotNull(completionFeature); + + context.Response.ContentLength = 25; + context.Response.AppendTrailer("CustomName", "Custom Value"); + + var ex = await Assert.ThrowsAsync(() => completionFeature.CompleteAsync().DefaultTimeout()); + Assert.Equal(CoreStrings.FormatTooFewBytesWritten(0, 25), ex.Message); + + Assert.False(startingTcs.Task.IsCompleted); // OnStarting did not get called. + Assert.False(context.Response.Headers.IsReadOnly); + Assert.False(context.Features.Get().Trailers.IsReadOnly); + + appTcs.SetResult(0); + } + catch (Exception ex) + { + appTcs.SetException(ex); + } + }); + + await StartStreamAsync(1, headers, endStream: true); + + var headersFrame = await ExpectAsync(Http2FrameType.HEADERS, + withLength: 55, + withFlags: (byte)(Http2HeadersFrameFlags.END_HEADERS | Http2HeadersFrameFlags.END_STREAM), + withStreamId: 1); + + await appTcs.Task; + + await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false); + + _hpackDecoder.Decode(headersFrame.PayloadSequence, endHeaders: false, handler: this); + + Assert.Equal(3, _decodedHeaders.Count); + Assert.Contains("date", _decodedHeaders.Keys, StringComparer.OrdinalIgnoreCase); + Assert.Equal("500", _decodedHeaders[HeaderNames.Status]); + Assert.Equal("0", _decodedHeaders[HeaderNames.ContentLength]); + } + + [Fact] + public async Task CompleteAsync_AfterBodyStarted_SendsBodyWithEndStream() + { + var startingTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var appTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var clientTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var headers = new[] + { + new KeyValuePair(HeaderNames.Method, "GET"), + new KeyValuePair(HeaderNames.Path, "/"), + new KeyValuePair(HeaderNames.Scheme, "http"), + }; + await InitializeConnectionAsync(async context => + { + try + { + context.Response.OnStarting(() => { startingTcs.SetResult(0); return Task.CompletedTask; }); + var completionFeature = context.Features.Get(); + Assert.NotNull(completionFeature); + + await context.Response.WriteAsync("Hello World"); + Assert.True(startingTcs.Task.IsCompletedSuccessfully); // OnStarting got called. + Assert.True(context.Response.Headers.IsReadOnly); + + await completionFeature.CompleteAsync().DefaultTimeout(); + await completionFeature.CompleteAsync().DefaultTimeout(); // Can be called twice, no-ops + + Assert.True(context.Features.Get().Trailers.IsReadOnly); + + // Make sure the client gets our results from CompleteAsync instead of from the request delegate exiting. + await clientTcs.Task.DefaultTimeout(); + appTcs.SetResult(0); + } + catch (Exception ex) + { + appTcs.SetException(ex); + } + }); + + await StartStreamAsync(1, headers, endStream: true); + + var headersFrame = await ExpectAsync(Http2FrameType.HEADERS, + withLength: 37, + withFlags: (byte)(Http2HeadersFrameFlags.END_HEADERS), + withStreamId: 1); + var bodyFrame = await ExpectAsync(Http2FrameType.DATA, + withLength: 11, + withFlags: (byte)(Http2HeadersFrameFlags.NONE), + withStreamId: 1); + await ExpectAsync(Http2FrameType.DATA, + withLength: 0, + withFlags: (byte)(Http2HeadersFrameFlags.END_STREAM), + withStreamId: 1); + + clientTcs.SetResult(0); + await appTcs.Task; + + await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false); + + _hpackDecoder.Decode(headersFrame.PayloadSequence, endHeaders: false, handler: this); + + Assert.Equal(2, _decodedHeaders.Count); + Assert.Contains("date", _decodedHeaders.Keys, StringComparer.OrdinalIgnoreCase); + Assert.Equal("200", _decodedHeaders[HeaderNames.Status]); + + Assert.Equal("Hello World", Encoding.UTF8.GetString(bodyFrame.Payload.Span)); + } + + [Fact] + public async Task CompleteAsync_WriteAfterComplete_Throws() + { + var startingTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var appTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var clientTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var headers = new[] + { + new KeyValuePair(HeaderNames.Method, "GET"), + new KeyValuePair(HeaderNames.Path, "/"), + new KeyValuePair(HeaderNames.Scheme, "http"), + }; + await InitializeConnectionAsync(async context => + { + try + { + context.Response.OnStarting(() => { startingTcs.SetResult(0); return Task.CompletedTask; }); + var completionFeature = context.Features.Get(); + Assert.NotNull(completionFeature); + + await completionFeature.CompleteAsync().DefaultTimeout(); + + Assert.True(startingTcs.Task.IsCompletedSuccessfully); // OnStarting got called. + Assert.True(context.Response.Headers.IsReadOnly); + Assert.True(context.Features.Get().Trailers.IsReadOnly); + + var ex = await Assert.ThrowsAsync(() => context.Response.WriteAsync("2 Hello World").DefaultTimeout()); + Assert.Equal("Writing is not allowed after writer was completed.", ex.Message); + + // Make sure the client gets our results from CompleteAsync instead of from the request delegate exiting. + await clientTcs.Task.DefaultTimeout(); + appTcs.SetResult(0); + } + catch (Exception ex) + { + appTcs.SetException(ex); + } + }); + + await StartStreamAsync(1, headers, endStream: true); + + var headersFrame = await ExpectAsync(Http2FrameType.HEADERS, + withLength: 55, + withFlags: (byte)(Http2HeadersFrameFlags.END_HEADERS | Http2HeadersFrameFlags.END_STREAM), + withStreamId: 1); + + clientTcs.SetResult(0); + await appTcs.Task; + + await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false); + + _hpackDecoder.Decode(headersFrame.PayloadSequence, endHeaders: false, handler: this); + + Assert.Equal(3, _decodedHeaders.Count); + Assert.Contains("date", _decodedHeaders.Keys, StringComparer.OrdinalIgnoreCase); + Assert.Equal("200", _decodedHeaders[HeaderNames.Status]); + Assert.Equal("0", _decodedHeaders[HeaderNames.ContentLength]); + } + + [Fact] + public async Task CompleteAsync_WriteAgainAfterComplete_Throws() + { + var startingTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var appTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var clientTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var headers = new[] + { + new KeyValuePair(HeaderNames.Method, "GET"), + new KeyValuePair(HeaderNames.Path, "/"), + new KeyValuePair(HeaderNames.Scheme, "http"), + }; + await InitializeConnectionAsync(async context => + { + try + { + context.Response.OnStarting(() => { startingTcs.SetResult(0); return Task.CompletedTask; }); + var completionFeature = context.Features.Get(); + Assert.NotNull(completionFeature); + + await context.Response.WriteAsync("Hello World").DefaultTimeout(); + Assert.True(startingTcs.Task.IsCompletedSuccessfully); // OnStarting got called. + Assert.True(context.Response.Headers.IsReadOnly); + + await completionFeature.CompleteAsync().DefaultTimeout(); + + Assert.True(context.Features.Get().Trailers.IsReadOnly); + + var ex = await Assert.ThrowsAsync(() => context.Response.WriteAsync("2 Hello World").DefaultTimeout()); + Assert.Equal("Writing is not allowed after writer was completed.", ex.Message); + + // Make sure the client gets our results from CompleteAsync instead of from the request delegate exiting. + await clientTcs.Task.DefaultTimeout(); + appTcs.SetResult(0); + } + catch (Exception ex) + { + appTcs.SetException(ex); + } + }); + + await StartStreamAsync(1, headers, endStream: true); + + var headersFrame = await ExpectAsync(Http2FrameType.HEADERS, + withLength: 37, + withFlags: (byte)(Http2HeadersFrameFlags.END_HEADERS), + withStreamId: 1); + var bodyFrame = await ExpectAsync(Http2FrameType.DATA, + withLength: 11, + withFlags: (byte)(Http2HeadersFrameFlags.NONE), + withStreamId: 1); + await ExpectAsync(Http2FrameType.DATA, + withLength: 0, + withFlags: (byte)(Http2HeadersFrameFlags.END_STREAM), + withStreamId: 1); + + clientTcs.SetResult(0); + await appTcs.Task; + + await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false); + + _hpackDecoder.Decode(headersFrame.PayloadSequence, endHeaders: false, handler: this); + + Assert.Equal(2, _decodedHeaders.Count); + Assert.Contains("date", _decodedHeaders.Keys, StringComparer.OrdinalIgnoreCase); + Assert.Equal("200", _decodedHeaders[HeaderNames.Status]); + + Assert.Equal("Hello World", Encoding.UTF8.GetString(bodyFrame.Payload.Span)); + } + + [Fact] + public async Task CompleteAsync_AfterBodyStarted_WithTrailers_SendsBodyAndTrailersWithEndStream() + { + var startingTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var appTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var clientTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var headers = new[] + { + new KeyValuePair(HeaderNames.Method, "GET"), + new KeyValuePair(HeaderNames.Path, "/"), + new KeyValuePair(HeaderNames.Scheme, "http"), + }; + await InitializeConnectionAsync(async context => + { + try + { + context.Response.OnStarting(() => { startingTcs.SetResult(0); return Task.CompletedTask; }); + var completionFeature = context.Features.Get(); + Assert.NotNull(completionFeature); + + await context.Response.WriteAsync("Hello World"); + Assert.True(startingTcs.Task.IsCompletedSuccessfully); // OnStarting got called. + Assert.True(context.Response.Headers.IsReadOnly); + + context.Response.AppendTrailer("CustomName", "Custom Value"); + + await completionFeature.CompleteAsync().DefaultTimeout(); + + Assert.True(context.Features.Get().Trailers.IsReadOnly); + + // Make sure the client gets our results from CompleteAsync instead of from the request delegate exiting. + await clientTcs.Task.DefaultTimeout(); + appTcs.SetResult(0); + } + catch (Exception ex) + { + appTcs.SetException(ex); + } + }); + + await StartStreamAsync(1, headers, endStream: true); + + var headersFrame = await ExpectAsync(Http2FrameType.HEADERS, + withLength: 37, + withFlags: (byte)(Http2HeadersFrameFlags.END_HEADERS), + withStreamId: 1); + var bodyFrame = await ExpectAsync(Http2FrameType.DATA, + withLength: 11, + withFlags: (byte)(Http2HeadersFrameFlags.NONE), + withStreamId: 1); + var trailersFrame = await ExpectAsync(Http2FrameType.HEADERS, + withLength: 25, + withFlags: (byte)(Http2HeadersFrameFlags.END_HEADERS | Http2HeadersFrameFlags.END_STREAM), + withStreamId: 1); + + clientTcs.SetResult(0); + await appTcs.Task; + + await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false); + + _hpackDecoder.Decode(headersFrame.PayloadSequence, endHeaders: false, handler: this); + + Assert.Equal(2, _decodedHeaders.Count); + Assert.Contains("date", _decodedHeaders.Keys, StringComparer.OrdinalIgnoreCase); + Assert.Equal("200", _decodedHeaders[HeaderNames.Status]); + + Assert.Equal("Hello World", Encoding.UTF8.GetString(bodyFrame.Payload.Span)); + + _decodedHeaders.Clear(); + + _hpackDecoder.Decode(trailersFrame.PayloadSequence, endHeaders: true, handler: this); + + Assert.Single(_decodedHeaders); + Assert.Equal("Custom Value", _decodedHeaders["CustomName"]); + } + + [Fact] + public async Task CompleteAsync_AfterBodyStarted_WithTrailers_TruncatedContentLength_ThrowsAndReset() + { + var startingTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var appTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var clientTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var headers = new[] + { + new KeyValuePair(HeaderNames.Method, "GET"), + new KeyValuePair(HeaderNames.Path, "/"), + new KeyValuePair(HeaderNames.Scheme, "http"), + }; + await InitializeConnectionAsync(async context => + { + try + { + context.Response.OnStarting(() => { startingTcs.SetResult(0); return Task.CompletedTask; }); + var completionFeature = context.Features.Get(); + Assert.NotNull(completionFeature); + + context.Response.ContentLength = 25; + await context.Response.WriteAsync("Hello World"); + Assert.True(startingTcs.Task.IsCompletedSuccessfully); // OnStarting got called. + Assert.True(context.Response.Headers.IsReadOnly); + + context.Response.AppendTrailer("CustomName", "Custom Value"); + + var ex = await Assert.ThrowsAsync(() => completionFeature.CompleteAsync().DefaultTimeout()); + Assert.Equal(CoreStrings.FormatTooFewBytesWritten(11, 25), ex.Message); + + Assert.False(context.Features.Get().Trailers.IsReadOnly); + + // Make sure the client gets our results from CompleteAsync instead of from the request delegate exiting. + await clientTcs.Task.DefaultTimeout(); + appTcs.SetResult(0); + } + catch (Exception ex) + { + appTcs.SetException(ex); + } + }); + + await StartStreamAsync(1, headers, endStream: true); + + var headersFrame = await ExpectAsync(Http2FrameType.HEADERS, + withLength: 56, + withFlags: (byte)(Http2HeadersFrameFlags.END_HEADERS), + withStreamId: 1); + var bodyFrame = await ExpectAsync(Http2FrameType.DATA, + withLength: 11, + withFlags: (byte)(Http2HeadersFrameFlags.NONE), + withStreamId: 1); + + clientTcs.SetResult(0); + + await WaitForStreamErrorAsync(1, Http2ErrorCode.INTERNAL_ERROR, + expectedErrorMessage: CoreStrings.FormatTooFewBytesWritten(11, 25)); + + await appTcs.Task; + + await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false); + + _hpackDecoder.Decode(headersFrame.PayloadSequence, endHeaders: false, handler: this); + + Assert.Equal(3, _decodedHeaders.Count); + Assert.Contains("date", _decodedHeaders.Keys, StringComparer.OrdinalIgnoreCase); + Assert.Equal("200", _decodedHeaders[HeaderNames.Status]); + Assert.Equal("25", _decodedHeaders[HeaderNames.ContentLength]); + + Assert.Equal("Hello World", Encoding.UTF8.GetString(bodyFrame.Payload.Span)); + } + + [Fact] + public async Task AbortAfterCompleteAsync_GETWithResponseBodyAndTrailers_ResetsAfterResponse() + { + var startingTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var appTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var clientTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var headers = new[] + { + new KeyValuePair(HeaderNames.Method, "GET"), + new KeyValuePair(HeaderNames.Path, "/"), + new KeyValuePair(HeaderNames.Scheme, "http"), + }; + await InitializeConnectionAsync(async context => + { + try + { + context.Response.OnStarting(() => { startingTcs.SetResult(0); return Task.CompletedTask; }); + var completionFeature = context.Features.Get(); + Assert.NotNull(completionFeature); + + await context.Response.WriteAsync("Hello World"); + Assert.True(startingTcs.Task.IsCompletedSuccessfully); // OnStarting got called. + Assert.True(context.Response.Headers.IsReadOnly); + + context.Response.AppendTrailer("CustomName", "Custom Value"); + + await completionFeature.CompleteAsync().DefaultTimeout(); + + Assert.True(context.Features.Get().Trailers.IsReadOnly); + + // RequestAborted will no longer fire after CompleteAsync. + Assert.False(context.RequestAborted.CanBeCanceled); + context.Abort(); + + // Make sure the client gets our results from CompleteAsync instead of from the request delegate exiting. + await clientTcs.Task.DefaultTimeout(); + appTcs.SetResult(0); + } + catch (Exception ex) + { + appTcs.SetException(ex); + } + }); + + await StartStreamAsync(1, headers, endStream: true); + + var headersFrame = await ExpectAsync(Http2FrameType.HEADERS, + withLength: 37, + withFlags: (byte)(Http2HeadersFrameFlags.END_HEADERS), + withStreamId: 1); + var bodyFrame = await ExpectAsync(Http2FrameType.DATA, + withLength: 11, + withFlags: (byte)(Http2HeadersFrameFlags.NONE), + withStreamId: 1); + var trailersFrame = await ExpectAsync(Http2FrameType.HEADERS, + withLength: 25, + withFlags: (byte)(Http2HeadersFrameFlags.END_HEADERS | Http2HeadersFrameFlags.END_STREAM), + withStreamId: 1); + await WaitForStreamErrorAsync(1, Http2ErrorCode.INTERNAL_ERROR, expectedErrorMessage: null); + + clientTcs.SetResult(0); + await appTcs.Task; + + await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false); + + _hpackDecoder.Decode(headersFrame.PayloadSequence, endHeaders: false, handler: this); + + Assert.Equal(2, _decodedHeaders.Count); + Assert.Contains("date", _decodedHeaders.Keys, StringComparer.OrdinalIgnoreCase); + Assert.Equal("200", _decodedHeaders[HeaderNames.Status]); + + Assert.Equal("Hello World", Encoding.UTF8.GetString(bodyFrame.Payload.Span)); + + _decodedHeaders.Clear(); + + _hpackDecoder.Decode(trailersFrame.PayloadSequence, endHeaders: true, handler: this); + + Assert.Single(_decodedHeaders); + Assert.Equal("Custom Value", _decodedHeaders["CustomName"]); + } + + [Fact] + public async Task AbortAfterCompleteAsync_POSTWithResponseBodyAndTrailers_RequestBodyThrows() + { + var startingTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var appTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var clientTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var headers = new[] + { + new KeyValuePair(HeaderNames.Method, "POST"), + new KeyValuePair(HeaderNames.Path, "/"), + new KeyValuePair(HeaderNames.Scheme, "http"), + }; + await InitializeConnectionAsync(async context => + { + try + { + var requestBodyTask = context.Request.BodyReader.ReadAsync(); + + context.Response.OnStarting(() => { startingTcs.SetResult(0); return Task.CompletedTask; }); + var completionFeature = context.Features.Get(); + Assert.NotNull(completionFeature); + + await context.Response.WriteAsync("Hello World"); + Assert.True(startingTcs.Task.IsCompletedSuccessfully); // OnStarting got called. + Assert.True(context.Response.Headers.IsReadOnly); + + context.Response.AppendTrailer("CustomName", "Custom Value"); + + await completionFeature.CompleteAsync().DefaultTimeout(); + + Assert.True(context.Features.Get().Trailers.IsReadOnly); + + // RequestAborted will no longer fire after CompleteAsync. + Assert.False(context.RequestAborted.CanBeCanceled); + context.Abort(); + + await Assert.ThrowsAsync(async () => await requestBodyTask); + await Assert.ThrowsAsync(async () => await context.Request.BodyReader.ReadAsync()); + + // Make sure the client gets our results from CompleteAsync instead of from the request delegate exiting. + await clientTcs.Task.DefaultTimeout(); + appTcs.SetResult(0); + } + catch (Exception ex) + { + appTcs.SetException(ex); + } + }); + + await StartStreamAsync(1, headers, endStream: false); + + var headersFrame = await ExpectAsync(Http2FrameType.HEADERS, + withLength: 37, + withFlags: (byte)(Http2HeadersFrameFlags.END_HEADERS), + withStreamId: 1); + var bodyFrame = await ExpectAsync(Http2FrameType.DATA, + withLength: 11, + withFlags: (byte)(Http2HeadersFrameFlags.NONE), + withStreamId: 1); + var trailersFrame = await ExpectAsync(Http2FrameType.HEADERS, + withLength: 25, + withFlags: (byte)(Http2HeadersFrameFlags.END_HEADERS | Http2HeadersFrameFlags.END_STREAM), + withStreamId: 1); + await WaitForStreamErrorAsync(1, Http2ErrorCode.INTERNAL_ERROR, expectedErrorMessage: null); + + clientTcs.SetResult(0); + await appTcs.Task; + + await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false); + + _hpackDecoder.Decode(headersFrame.PayloadSequence, endHeaders: false, handler: this); + + Assert.Equal(2, _decodedHeaders.Count); + Assert.Contains("date", _decodedHeaders.Keys, StringComparer.OrdinalIgnoreCase); + Assert.Equal("200", _decodedHeaders[HeaderNames.Status]); + + Assert.Equal("Hello World", Encoding.UTF8.GetString(bodyFrame.Payload.Span)); + + _decodedHeaders.Clear(); + + _hpackDecoder.Decode(trailersFrame.PayloadSequence, endHeaders: true, handler: this); + + Assert.Single(_decodedHeaders); + Assert.Equal("Custom Value", _decodedHeaders["CustomName"]); + } } } diff --git a/src/Servers/Kestrel/tools/CodeGenerator/HttpProtocolFeatureCollection.cs b/src/Servers/Kestrel/tools/CodeGenerator/HttpProtocolFeatureCollection.cs index 826677071924..d30a30a9eb34 100644 --- a/src/Servers/Kestrel/tools/CodeGenerator/HttpProtocolFeatureCollection.cs +++ b/src/Servers/Kestrel/tools/CodeGenerator/HttpProtocolFeatureCollection.cs @@ -35,6 +35,7 @@ public static string GenerateFile() { "IHttpUpgradeFeature", "IHttp2StreamIdFeature", + "IHttpResponseCompletionFeature", "IHttpResponseTrailersFeature", "IResponseCookiesFeature", "IItemsFeature", From 99cde9133385a3270ae382dd0a028c65c9149a67 Mon Sep 17 00:00:00 2001 From: Chris R Date: Thu, 13 Jun 2019 15:36:05 -0700 Subject: [PATCH 2/4] Use ProduceEnd --- .../Core/src/Internal/Http/HttpProtocol.cs | 24 +++++++++---------- .../Http2/Http2Stream.FeatureCollection.cs | 9 ++----- .../Http2/Http2StreamTests.cs | 2 +- 3 files changed, 15 insertions(+), 20 deletions(-) diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.cs b/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.cs index 37ad4d6dc4bd..bc053825c40c 100644 --- a/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.cs +++ b/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.cs @@ -635,6 +635,15 @@ private async Task ProcessRequests(IHttpApplication applicat // Run the application code for this request await application.ProcessRequestAsync(context); + // Trigger OnStarting if it hasn't been called yet and the app hasn't + // already failed. If an OnStarting callback throws we can go through + // our normal error handling in ProduceEnd. + // https://github.com/aspnet/KestrelHttpServer/issues/43 + if (!HasResponseStarted && _applicationException == null && _onStarting?.Count > 0) + { + await FireOnStarting(); + } + if (!_connectionAborted && !VerifyResponseContentLength(out var lengthException)) { ReportApplicationError(lengthException); @@ -655,15 +664,6 @@ private async Task ProcessRequests(IHttpApplication applicat KestrelEventSource.Log.RequestStop(this); - // Trigger OnStarting if it hasn't been called yet and the app hasn't - // already failed. If an OnStarting callback throws we can go through - // our normal error handling in ProduceEnd. - // https://github.com/aspnet/KestrelHttpServer/issues/43 - if (!HasResponseStarted && _applicationException == null && _onStarting?.Count > 0) - { - await FireOnStarting(); - } - // At this point all user code that needs use to the request or response streams has completed. // Using these streams in the OnCompleted callback is not allowed. StopBodies(); @@ -942,7 +942,7 @@ public void ProduceContinue() } } - public Task InitializeResponseAsync(int firstWriteByteCount, bool appCompleted = false) + public Task InitializeResponseAsync(int firstWriteByteCount) { var startingTask = FireOnStarting(); // If return is Task.CompletedTask no awaiting is required @@ -953,7 +953,7 @@ public Task InitializeResponseAsync(int firstWriteByteCount, bool appCompleted = VerifyInitializeState(firstWriteByteCount); - ProduceStart(appCompleted: appCompleted); + ProduceStart(appCompleted: false); return Task.CompletedTask; } @@ -1050,7 +1050,7 @@ protected Task ProduceEnd() return WriteSuffix(); } - protected Task WriteSuffix() + private Task WriteSuffix() { if (HasResponseCompleted) { diff --git a/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Stream.FeatureCollection.cs b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Stream.FeatureCollection.cs index 7ff3c4308b52..fb27e387d3a4 100644 --- a/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Stream.FeatureCollection.cs +++ b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Stream.FeatureCollection.cs @@ -59,12 +59,7 @@ async Task IHttpResponseCompletionFeature.CompleteAsync() // Finalize headers if (!HasResponseStarted) { - if (!VerifyResponseContentLength(out var lengthException)) - { - throw lengthException; - } - - await InitializeResponseAsync(0, appCompleted: true); + await FireOnStarting(); } // Flush headers, body, trailers... @@ -75,7 +70,7 @@ async Task IHttpResponseCompletionFeature.CompleteAsync() throw lengthException; } - await WriteSuffix(); + await ProduceEnd(); } } } diff --git a/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http2/Http2StreamTests.cs b/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http2/Http2StreamTests.cs index 810b3359c796..7614f84016da 100644 --- a/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http2/Http2StreamTests.cs +++ b/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http2/Http2StreamTests.cs @@ -3484,7 +3484,7 @@ await InitializeConnectionAsync(async context => var ex = await Assert.ThrowsAsync(() => completionFeature.CompleteAsync().DefaultTimeout()); Assert.Equal(CoreStrings.FormatTooFewBytesWritten(0, 25), ex.Message); - Assert.False(startingTcs.Task.IsCompleted); // OnStarting did not get called. + Assert.True(startingTcs.Task.IsCompletedSuccessfully); Assert.False(context.Response.Headers.IsReadOnly); Assert.False(context.Features.Get().Trailers.IsReadOnly); From 992dacf8e917b1ef19449b5cdafd295cf0235abc Mon Sep 17 00:00:00 2001 From: Chris R Date: Thu, 13 Jun 2019 16:35:45 -0700 Subject: [PATCH 3/4] Refs ergriosergiuha4g3qai9u43089a34g --- .../ref/Microsoft.AspNetCore.Http.Features.netstandard2.0.cs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/Http/Http.Features/ref/Microsoft.AspNetCore.Http.Features.netstandard2.0.cs b/src/Http/Http.Features/ref/Microsoft.AspNetCore.Http.Features.netstandard2.0.cs index 30f567851c78..6f53a07297ef 100644 --- a/src/Http/Http.Features/ref/Microsoft.AspNetCore.Http.Features.netstandard2.0.cs +++ b/src/Http/Http.Features/ref/Microsoft.AspNetCore.Http.Features.netstandard2.0.cs @@ -200,6 +200,10 @@ public partial interface IHttpRequestTrailersFeature bool Available { get; } Microsoft.AspNetCore.Http.IHeaderDictionary Trailers { get; } } + public partial interface IHttpResponseCompletionFeature + { + System.Threading.Tasks.Task CompleteAsync(); + } public partial interface IHttpResponseFeature { System.IO.Stream Body { get; set; } From 8ca6917725d9365c93034682b603b40635e8bd9a Mon Sep 17 00:00:00 2001 From: Chris R Date: Fri, 14 Jun 2019 12:52:49 -0700 Subject: [PATCH 4/4] Add pipe test --- .../Http2/Http2StreamTests.cs | 81 +++++++++++++++++++ 1 file changed, 81 insertions(+) diff --git a/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http2/Http2StreamTests.cs b/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http2/Http2StreamTests.cs index 7614f84016da..7bdaa893a90b 100644 --- a/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http2/Http2StreamTests.cs +++ b/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http2/Http2StreamTests.cs @@ -3712,6 +3712,87 @@ await ExpectAsync(Http2FrameType.DATA, Assert.Equal("Hello World", Encoding.UTF8.GetString(bodyFrame.Payload.Span)); } + [Fact] + public async Task CompleteAsync_AfterPipeWrite_WithTrailers_SendsBodyAndTrailersWithEndStream() + { + var startingTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var appTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var clientTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var headers = new[] + { + new KeyValuePair(HeaderNames.Method, "GET"), + new KeyValuePair(HeaderNames.Path, "/"), + new KeyValuePair(HeaderNames.Scheme, "http"), + }; + await InitializeConnectionAsync(async context => + { + try + { + context.Response.OnStarting(() => { startingTcs.SetResult(0); return Task.CompletedTask; }); + var completionFeature = context.Features.Get(); + Assert.NotNull(completionFeature); + + var buffer = context.Response.BodyWriter.GetMemory(); + var length = Encoding.UTF8.GetBytes("Hello World", buffer.Span); + context.Response.BodyWriter.Advance(length); + + Assert.False(startingTcs.Task.IsCompletedSuccessfully); // OnStarting did not get called. + Assert.False(context.Response.Headers.IsReadOnly); + + context.Response.AppendTrailer("CustomName", "Custom Value"); + + await completionFeature.CompleteAsync().DefaultTimeout(); + Assert.True(startingTcs.Task.IsCompletedSuccessfully); // OnStarting got called. + Assert.True(context.Response.Headers.IsReadOnly); + + Assert.True(context.Features.Get().Trailers.IsReadOnly); + + // Make sure the client gets our results from CompleteAsync instead of from the request delegate exiting. + await clientTcs.Task.DefaultTimeout(); + appTcs.SetResult(0); + } + catch (Exception ex) + { + appTcs.SetException(ex); + } + }); + + await StartStreamAsync(1, headers, endStream: true); + + var headersFrame = await ExpectAsync(Http2FrameType.HEADERS, + withLength: 37, + withFlags: (byte)(Http2HeadersFrameFlags.END_HEADERS), + withStreamId: 1); + var bodyFrame = await ExpectAsync(Http2FrameType.DATA, + withLength: 11, + withFlags: (byte)(Http2HeadersFrameFlags.NONE), + withStreamId: 1); + var trailersFrame = await ExpectAsync(Http2FrameType.HEADERS, + withLength: 25, + withFlags: (byte)(Http2HeadersFrameFlags.END_HEADERS | Http2HeadersFrameFlags.END_STREAM), + withStreamId: 1); + + clientTcs.SetResult(0); + await appTcs.Task; + + await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false); + + _hpackDecoder.Decode(headersFrame.PayloadSequence, endHeaders: false, handler: this); + + Assert.Equal(2, _decodedHeaders.Count); + Assert.Contains("date", _decodedHeaders.Keys, StringComparer.OrdinalIgnoreCase); + Assert.Equal("200", _decodedHeaders[HeaderNames.Status]); + + Assert.Equal("Hello World", Encoding.UTF8.GetString(bodyFrame.Payload.Span)); + + _decodedHeaders.Clear(); + + _hpackDecoder.Decode(trailersFrame.PayloadSequence, endHeaders: true, handler: this); + + Assert.Single(_decodedHeaders); + Assert.Equal("Custom Value", _decodedHeaders["CustomName"]); + } + [Fact] public async Task CompleteAsync_AfterBodyStarted_WithTrailers_SendsBodyAndTrailersWithEndStream() {