Skip to content

Commit

Permalink
Keep Kestrel's connection PipeReader in a consistent state (#16725)
Browse files Browse the repository at this point in the history
- When the request body PipeReader.ReadAsync throws, the connection-level
pipe should be advanced, so subsequent attempts to read from the
connection-level pipe don't fail unnecessarily
  • Loading branch information
halter73 authored Nov 4, 2019
1 parent da20a12 commit e3b971a
Show file tree
Hide file tree
Showing 6 changed files with 169 additions and 82 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,10 @@ private async Task PumpAsync()
}

// Read() will have already have greedily consumed the entire request body if able.
CheckCompletedReadResult(result);
if (result.IsCompleted)
{
ThrowUnexpectedEndOfRequestContent();
}
}
finally
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ public override async ValueTask<ReadResult> ReadAsyncInternal(CancellationToken
if (_readCompleted)
{
_isReading = true;
return _readResult;
return new ReadResult(_readResult.Buffer, Interlocked.Exchange(ref _userCanceled, 0) == 1, _readResult.IsCompleted);
}

TryStart();
Expand All @@ -70,44 +70,47 @@ public override async ValueTask<ReadResult> ReadAsyncInternal(CancellationToken
}
catch (ConnectionAbortedException ex)
{
_isReading = false;
throw new TaskCanceledException("The request was aborted", ex);
}

void ResetReadingState()
{
_isReading = false;
// Reset the timing read here for the next call to read.
StopTimingRead(0);
_context.Input.AdvanceTo(_readResult.Buffer.Start);
}

if (_context.RequestTimedOut)
{
ResetReadingState();
BadHttpRequestException.Throw(RequestRejectionReason.RequestBodyTimeout);
}

// Make sure to handle when this is canceled here.
if (_readResult.IsCanceled)
if (_readResult.IsCompleted)
{
if (Interlocked.Exchange(ref _userCanceled, 0) == 1)
{
// Ignore the readResult if it wasn't by the user.
CreateReadResultFromConnectionReadResult();

break;
}
else
{
// Reset the timing read here for the next call to read.
StopTimingRead(0);
continue;
}
ResetReadingState();
ThrowUnexpectedEndOfRequestContent();
}

var readableBuffer = _readResult.Buffer;
var readableBufferLength = readableBuffer.Length;
StopTimingRead(readableBufferLength);
// Ignore the canceled readResult if it wasn't canceled by the user.
if (!_readResult.IsCanceled || Interlocked.Exchange(ref _userCanceled, 0) == 1)
{
var returnedReadResultLength = CreateReadResultFromConnectionReadResult();

CheckCompletedReadResult(_readResult);
// Don't count bytes belonging to the next request, since read rate timeouts are done on a per-request basis.
StopTimingRead(returnedReadResultLength);

if (readableBufferLength > 0)
{
CreateReadResultFromConnectionReadResult();
if (_readResult.IsCompleted)
{
TryStop();
}

break;
}

ResetReadingState();
}

return _readResult;
Expand All @@ -129,66 +132,69 @@ public override bool TryReadInternal(out ReadResult readResult)
if (_readCompleted)
{
_isReading = true;
readResult = _readResult;
readResult = new ReadResult(_readResult.Buffer, Interlocked.Exchange(ref _userCanceled, 0) == 1, _readResult.IsCompleted);
return true;
}

TryStart();

if (!_context.Input.TryRead(out _readResult))
{
readResult = default;
return false;
}

if (_readResult.IsCanceled)
// The while(true) because we don't want to return a canceled ReadResult if the user themselves didn't cancel it.
while (true)
{
if (Interlocked.Exchange(ref _userCanceled, 0) == 0)
if (!_context.Input.TryRead(out _readResult))
{
// Cancellation wasn't by the user, return default ReadResult
readResult = default;
return false;
}
}

// Only set _isReading if we are returing true.
_isReading = true;
if (!_readResult.IsCanceled || Interlocked.Exchange(ref _userCanceled, 0) == 1)
{
break;
}

CreateReadResultFromConnectionReadResult();
_context.Input.AdvanceTo(_readResult.Buffer.Start);
}

readResult = _readResult;
CountBytesRead(readResult.Buffer.Length);
if (_readResult.IsCompleted)
{
_context.Input.AdvanceTo(_readResult.Buffer.Start);
ThrowUnexpectedEndOfRequestContent();
}

return true;
}
var returnedReadResultLength = CreateReadResultFromConnectionReadResult();

public override Task ConsumeAsync()
{
TryStart();
// Don't count bytes belonging to the next request, since read rate timeouts are done on a per-request basis.
CountBytesRead(returnedReadResultLength);

if (!_readResult.Buffer.IsEmpty && _inputLength == 0)
// Only set _isReading if we are returning true.
_isReading = true;
readResult = _readResult;

if (readResult.IsCompleted)
{
_context.Input.AdvanceTo(_readResult.Buffer.End);
TryStop();
}

return OnConsumeAsync();
return true;
}

private void CreateReadResultFromConnectionReadResult()
private long CreateReadResultFromConnectionReadResult()
{
if (_readResult.Buffer.Length >= _inputLength + _examinedUnconsumedBytes)
{
_readCompleted = true;
_readResult = new ReadResult(
_readResult.Buffer.Slice(0, _inputLength + _examinedUnconsumedBytes),
_readResult.IsCanceled && Interlocked.Exchange(ref _userCanceled, 0) == 1,
_readCompleted);
}
var initialLength = _readResult.Buffer.Length;
var maxLength = _inputLength + _examinedUnconsumedBytes;

if (_readResult.IsCompleted)
if (initialLength < maxLength)
{
TryStop();
return initialLength;
}

_readCompleted = true;
_readResult = new ReadResult(
_readResult.Buffer.Slice(0, maxLength),
_readResult.IsCanceled,
isCompleted: true);

return maxLength;
}

public override void AdvanceTo(SequencePosition consumed)
Expand All @@ -207,9 +213,10 @@ public override void AdvanceTo(SequencePosition consumed, SequencePosition exami

if (_readCompleted)
{
_readResult = new ReadResult(_readResult.Buffer.Slice(consumed, _readResult.Buffer.End), Interlocked.Exchange(ref _userCanceled, 0) == 1, _readCompleted);
// If the old stored _readResult was canceled, it's already been observed. Do not store a canceled read result permanently.
_readResult = new ReadResult(_readResult.Buffer.Slice(consumed, _readResult.Buffer.End), isCanceled: false, _readCompleted);

if (_readResult.Buffer.Length == 0 && !_finalAdvanceCalled)
if (!_finalAdvanceCalled && _readResult.Buffer.Length == 0)
{
_context.Input.AdvanceTo(consumed);
_finalAdvanceCalled = true;
Expand Down
23 changes: 11 additions & 12 deletions src/Servers/Kestrel/Core/src/Internal/Http/Http1MessageBody.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Diagnostics;
using System.IO.Pipelines;
using System.Threading;
using System.Threading.Tasks;
Expand All @@ -21,19 +22,17 @@ protected Http1MessageBody(Http1Connection context)
_context = context;
}

protected void CheckCompletedReadResult(ReadResult result)
[StackTraceHidden]
protected void ThrowUnexpectedEndOfRequestContent()
{
if (result.IsCompleted)
{
// OnInputOrOutputCompleted() is an idempotent method that closes the connection. Sometimes
// input completion is observed here before the Input.OnWriterCompleted() callback is fired,
// so we call OnInputOrOutputCompleted() now to prevent a race in our tests where a 400
// response is written after observing the unexpected end of request content instead of just
// closing the connection without a response as expected.
_context.OnInputOrOutputCompleted();

BadHttpRequestException.Throw(RequestRejectionReason.UnexpectedEndOfRequestContent);
}
// OnInputOrOutputCompleted() is an idempotent method that closes the connection. Sometimes
// input completion is observed here before the Input.OnWriterCompleted() callback is fired,
// so we call OnInputOrOutputCompleted() now to prevent a race in our tests where a 400
// response is written after observing the unexpected end of request content instead of just
// closing the connection without a response as expected.
_context.OnInputOrOutputCompleted();

BadHttpRequestException.Throw(RequestRejectionReason.UnexpectedEndOfRequestContent);
}

public abstract bool TryReadInternal(out ReadResult readResult);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,19 +25,13 @@ public Http1UpgradeMessageBody(Http1Connection context)

public override ValueTask<ReadResult> ReadAsync(CancellationToken cancellationToken = default)
{
if (_completed)
{
throw new InvalidOperationException("Reading is not allowed after the reader was completed.");
}
ThrowIfCompleted();
return _context.Input.ReadAsync(cancellationToken);
}

public override bool TryRead(out ReadResult result)
{
if (_completed)
{
throw new InvalidOperationException("Reading is not allowed after the reader was completed.");
}
ThrowIfCompleted();
return _context.Input.TryRead(out result);
}

Expand Down
50 changes: 50 additions & 0 deletions src/Servers/Kestrel/Core/test/MessageBodyTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1189,6 +1189,56 @@ public async Task CompleteForContentLengthDoesNotCompleteConnectionPipeMakesRead
}
}

[Fact]
public async Task UnexpectedEndOfRequestContentIsRepeatedlyThrownForContentLengthBody()
{
using (var input = new TestInput())
{
var body = Http1MessageBody.For(HttpVersion.Http11, new HttpRequestHeaders { HeaderContentLength = "5" }, input.Http1Connection);
var reader = new HttpRequestPipeReader();
reader.StartAcceptingReads(body);

input.Application.Output.Complete();

var ex0 = Assert.Throws<BadHttpRequestException>(() => reader.TryRead(out var readResult));
var ex1 = Assert.Throws<BadHttpRequestException>(() => reader.TryRead(out var readResult));
var ex2 = await Assert.ThrowsAsync<BadHttpRequestException>(() => reader.ReadAsync().AsTask());
var ex3 = await Assert.ThrowsAsync<BadHttpRequestException>(() => reader.ReadAsync().AsTask());

Assert.Equal(RequestRejectionReason.UnexpectedEndOfRequestContent, ex0.Reason);
Assert.Equal(RequestRejectionReason.UnexpectedEndOfRequestContent, ex1.Reason);
Assert.Equal(RequestRejectionReason.UnexpectedEndOfRequestContent, ex2.Reason);
Assert.Equal(RequestRejectionReason.UnexpectedEndOfRequestContent, ex3.Reason);

await body.StopAsync();
}
}

[Fact]
public async Task UnexpectedEndOfRequestContentIsRepeatedlyThrownForChunkedBody()
{
using (var input = new TestInput())
{
var body = Http1MessageBody.For(HttpVersion.Http11, new HttpRequestHeaders { HeaderTransferEncoding = "chunked" }, input.Http1Connection);
var reader = new HttpRequestPipeReader();
reader.StartAcceptingReads(body);

input.Application.Output.Complete();

var ex0 = Assert.Throws<BadHttpRequestException>(() => reader.TryRead(out var readResult));
var ex1 = Assert.Throws<BadHttpRequestException>(() => reader.TryRead(out var readResult));
var ex2 = await Assert.ThrowsAsync<BadHttpRequestException>(() => reader.ReadAsync().AsTask());
var ex3 = await Assert.ThrowsAsync<BadHttpRequestException>(() => reader.ReadAsync().AsTask());

Assert.Equal(RequestRejectionReason.UnexpectedEndOfRequestContent, ex0.Reason);
Assert.Equal(RequestRejectionReason.UnexpectedEndOfRequestContent, ex1.Reason);
Assert.Equal(RequestRejectionReason.UnexpectedEndOfRequestContent, ex2.Reason);
Assert.Equal(RequestRejectionReason.UnexpectedEndOfRequestContent, ex3.Reason);

await body.StopAsync();
}
}

[Fact]
public async Task CompleteForChunkedAllowsConsumeToWork()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,12 @@
using Microsoft.AspNetCore.Testing;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Testing;
using Serilog;
using Xunit;

namespace Microsoft.AspNetCore.Server.Kestrel.InMemory.FunctionalTests
{
public class RequestTests : LoggedTest
public class RequestTests : TestApplicationErrorLoggerLoggedTest
{
[Fact]
public async Task StreamsAreNotPersistedAcrossRequests()
Expand Down Expand Up @@ -1440,6 +1441,39 @@ await connection.Receive(
}
}

[Fact]
public async Task ContentLengthSwallowedUnexpectedEndOfRequestContentDoesNotResultInWarnings()
{
var testContext = new TestServiceContext(LoggerFactory);

await using (var server = new TestServer(async httpContext =>
{
try
{
await httpContext.Request.Body.ReadAsync(new byte[1], 0, 1);
}
catch
{
}
}, testContext))
{
using (var connection = server.CreateConnection())
{
await connection.Send(
"POST / HTTP/1.1",
"Host:",
"Content-Length: 5",
"",
"");
connection.ShutdownSend();

await connection.ReceiveEnd();
}
}

Assert.Empty(TestApplicationErrorLogger.Messages.Where(m => m.LogLevel >= LogLevel.Warning));
}

[Fact]
public async Task ContentLengthRequestCallCancelPendingReadWorks()
{
Expand Down

0 comments on commit e3b971a

Please sign in to comment.