Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Keep Kestrel's connection PipeReader in a consistent state #16725

Merged
merged 3 commits into from
Nov 4, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,10 @@ private async Task PumpAsync()
}

// Read() will have already have greedily consumed the entire request body if able.
CheckCompletedReadResult(result);
if (result.IsCompleted)
{
ThrowUnexpectedEndOfRequestContent();
}
}
finally
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ public override async ValueTask<ReadResult> ReadAsyncInternal(CancellationToken
if (_readCompleted)
{
_isReading = true;
return _readResult;
return new ReadResult(_readResult.Buffer, Interlocked.Exchange(ref _userCanceled, 0) == 1, _readResult.IsCompleted);
}

TryStart();
Expand All @@ -70,44 +70,47 @@ public override async ValueTask<ReadResult> ReadAsyncInternal(CancellationToken
}
catch (ConnectionAbortedException ex)
{
_isReading = false;
throw new TaskCanceledException("The request was aborted", ex);
}

void ResetReadingState()
{
_isReading = false;
// Reset the timing read here for the next call to read.
StopTimingRead(0);
_context.Input.AdvanceTo(_readResult.Buffer.Start);
}

if (_context.RequestTimedOut)
{
ResetReadingState();
BadHttpRequestException.Throw(RequestRejectionReason.RequestBodyTimeout);
}

// Make sure to handle when this is canceled here.
if (_readResult.IsCanceled)
if (_readResult.IsCompleted)
{
if (Interlocked.Exchange(ref _userCanceled, 0) == 1)
{
// Ignore the readResult if it wasn't by the user.
CreateReadResultFromConnectionReadResult();

break;
}
else
{
// Reset the timing read here for the next call to read.
StopTimingRead(0);
continue;
}
ResetReadingState();
ThrowUnexpectedEndOfRequestContent();
}

var readableBuffer = _readResult.Buffer;
var readableBufferLength = readableBuffer.Length;
StopTimingRead(readableBufferLength);
// Ignore the canceled readResult if it wasn't canceled by the user.
if (!_readResult.IsCanceled || Interlocked.Exchange(ref _userCanceled, 0) == 1)
{
var returnedReadResultLength = CreateReadResultFromConnectionReadResult();

CheckCompletedReadResult(_readResult);
// Don't count bytes belonging to the next request, since read rate timeouts are done on a per-request basis.
StopTimingRead(returnedReadResultLength);

if (readableBufferLength > 0)
{
CreateReadResultFromConnectionReadResult();
if (_readResult.IsCompleted)
{
TryStop();
}

break;
}

ResetReadingState();
}

return _readResult;
Expand All @@ -129,66 +132,69 @@ public override bool TryReadInternal(out ReadResult readResult)
if (_readCompleted)
{
_isReading = true;
readResult = _readResult;
readResult = new ReadResult(_readResult.Buffer, Interlocked.Exchange(ref _userCanceled, 0) == 1, _readResult.IsCompleted);
return true;
}

TryStart();

if (!_context.Input.TryRead(out _readResult))
{
readResult = default;
return false;
}

if (_readResult.IsCanceled)
// The while(true) because we don't want to return a canceled ReadResult if the user themselves didn't cancel it.
while (true)
{
if (Interlocked.Exchange(ref _userCanceled, 0) == 0)
if (!_context.Input.TryRead(out _readResult))
{
// Cancellation wasn't by the user, return default ReadResult
readResult = default;
return false;
}
}

// Only set _isReading if we are returing true.
_isReading = true;
if (!_readResult.IsCanceled || Interlocked.Exchange(ref _userCanceled, 0) == 1)
{
break;
}

CreateReadResultFromConnectionReadResult();
_context.Input.AdvanceTo(_readResult.Buffer.Start);
}

readResult = _readResult;
CountBytesRead(readResult.Buffer.Length);
if (_readResult.IsCompleted)
{
_context.Input.AdvanceTo(_readResult.Buffer.Start);
ThrowUnexpectedEndOfRequestContent();
}

return true;
}
var returnedReadResultLength = CreateReadResultFromConnectionReadResult();

public override Task ConsumeAsync()
{
TryStart();
// Don't count bytes belonging to the next request, since read rate timeouts are done on a per-request basis.
CountBytesRead(returnedReadResultLength);

if (!_readResult.Buffer.IsEmpty && _inputLength == 0)
// Only set _isReading if we are returning true.
_isReading = true;
readResult = _readResult;

if (readResult.IsCompleted)
{
_context.Input.AdvanceTo(_readResult.Buffer.End);
TryStop();
}

return OnConsumeAsync();
return true;
}

private void CreateReadResultFromConnectionReadResult()
private long CreateReadResultFromConnectionReadResult()
{
if (_readResult.Buffer.Length >= _inputLength + _examinedUnconsumedBytes)
{
_readCompleted = true;
_readResult = new ReadResult(
_readResult.Buffer.Slice(0, _inputLength + _examinedUnconsumedBytes),
_readResult.IsCanceled && Interlocked.Exchange(ref _userCanceled, 0) == 1,
_readCompleted);
}
var initialLength = _readResult.Buffer.Length;
var maxLength = _inputLength + _examinedUnconsumedBytes;

if (_readResult.IsCompleted)
if (initialLength < maxLength)
{
TryStop();
return initialLength;
}

_readCompleted = true;
_readResult = new ReadResult(
_readResult.Buffer.Slice(0, maxLength),
_readResult.IsCanceled,
isCompleted: true);

return maxLength;
}

public override void AdvanceTo(SequencePosition consumed)
Expand All @@ -207,9 +213,10 @@ public override void AdvanceTo(SequencePosition consumed, SequencePosition exami

if (_readCompleted)
{
_readResult = new ReadResult(_readResult.Buffer.Slice(consumed, _readResult.Buffer.End), Interlocked.Exchange(ref _userCanceled, 0) == 1, _readCompleted);
// If the old stored _readResult was canceled, it's already been observed. Do not store a canceled read result permanently.
_readResult = new ReadResult(_readResult.Buffer.Slice(consumed, _readResult.Buffer.End), isCanceled: false, _readCompleted);

if (_readResult.Buffer.Length == 0 && !_finalAdvanceCalled)
if (!_finalAdvanceCalled && _readResult.Buffer.Length == 0)
{
_context.Input.AdvanceTo(consumed);
_finalAdvanceCalled = true;
Expand Down
23 changes: 11 additions & 12 deletions src/Servers/Kestrel/Core/src/Internal/Http/Http1MessageBody.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Diagnostics;
using System.IO.Pipelines;
using System.Threading;
using System.Threading.Tasks;
Expand All @@ -21,19 +22,17 @@ protected Http1MessageBody(Http1Connection context)
_context = context;
}

protected void CheckCompletedReadResult(ReadResult result)
[StackTraceHidden]
protected void ThrowUnexpectedEndOfRequestContent()
{
if (result.IsCompleted)
{
// OnInputOrOutputCompleted() is an idempotent method that closes the connection. Sometimes
// input completion is observed here before the Input.OnWriterCompleted() callback is fired,
// so we call OnInputOrOutputCompleted() now to prevent a race in our tests where a 400
// response is written after observing the unexpected end of request content instead of just
// closing the connection without a response as expected.
_context.OnInputOrOutputCompleted();

BadHttpRequestException.Throw(RequestRejectionReason.UnexpectedEndOfRequestContent);
}
// OnInputOrOutputCompleted() is an idempotent method that closes the connection. Sometimes
// input completion is observed here before the Input.OnWriterCompleted() callback is fired,
// so we call OnInputOrOutputCompleted() now to prevent a race in our tests where a 400
// response is written after observing the unexpected end of request content instead of just
// closing the connection without a response as expected.
_context.OnInputOrOutputCompleted();

BadHttpRequestException.Throw(RequestRejectionReason.UnexpectedEndOfRequestContent);
}

public abstract bool TryReadInternal(out ReadResult readResult);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,19 +25,13 @@ public Http1UpgradeMessageBody(Http1Connection context)

public override ValueTask<ReadResult> ReadAsync(CancellationToken cancellationToken = default)
{
if (_completed)
{
throw new InvalidOperationException("Reading is not allowed after the reader was completed.");
}
ThrowIfCompleted();
return _context.Input.ReadAsync(cancellationToken);
}

public override bool TryRead(out ReadResult result)
{
if (_completed)
{
throw new InvalidOperationException("Reading is not allowed after the reader was completed.");
}
ThrowIfCompleted();
return _context.Input.TryRead(out result);
}

Expand Down
50 changes: 50 additions & 0 deletions src/Servers/Kestrel/Core/test/MessageBodyTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1189,6 +1189,56 @@ public async Task CompleteForContentLengthDoesNotCompleteConnectionPipeMakesRead
}
}

[Fact]
public async Task UnexpectedEndOfRequestContentIsRepeatedlyThrownForContentLengthBody()
{
using (var input = new TestInput())
{
var body = Http1MessageBody.For(HttpVersion.Http11, new HttpRequestHeaders { HeaderContentLength = "5" }, input.Http1Connection);
var reader = new HttpRequestPipeReader();
reader.StartAcceptingReads(body);

input.Application.Output.Complete();

var ex0 = Assert.Throws<BadHttpRequestException>(() => reader.TryRead(out var readResult));
var ex1 = Assert.Throws<BadHttpRequestException>(() => reader.TryRead(out var readResult));
var ex2 = await Assert.ThrowsAsync<BadHttpRequestException>(() => reader.ReadAsync().AsTask());
var ex3 = await Assert.ThrowsAsync<BadHttpRequestException>(() => reader.ReadAsync().AsTask());

Assert.Equal(RequestRejectionReason.UnexpectedEndOfRequestContent, ex0.Reason);
Assert.Equal(RequestRejectionReason.UnexpectedEndOfRequestContent, ex1.Reason);
Assert.Equal(RequestRejectionReason.UnexpectedEndOfRequestContent, ex2.Reason);
Assert.Equal(RequestRejectionReason.UnexpectedEndOfRequestContent, ex3.Reason);

await body.StopAsync();
}
}

[Fact]
public async Task UnexpectedEndOfRequestContentIsRepeatedlyThrownForChunkedBody()
{
using (var input = new TestInput())
{
var body = Http1MessageBody.For(HttpVersion.Http11, new HttpRequestHeaders { HeaderTransferEncoding = "chunked" }, input.Http1Connection);
var reader = new HttpRequestPipeReader();
reader.StartAcceptingReads(body);

input.Application.Output.Complete();

var ex0 = Assert.Throws<BadHttpRequestException>(() => reader.TryRead(out var readResult));
var ex1 = Assert.Throws<BadHttpRequestException>(() => reader.TryRead(out var readResult));
var ex2 = await Assert.ThrowsAsync<BadHttpRequestException>(() => reader.ReadAsync().AsTask());
var ex3 = await Assert.ThrowsAsync<BadHttpRequestException>(() => reader.ReadAsync().AsTask());

Assert.Equal(RequestRejectionReason.UnexpectedEndOfRequestContent, ex0.Reason);
Assert.Equal(RequestRejectionReason.UnexpectedEndOfRequestContent, ex1.Reason);
Assert.Equal(RequestRejectionReason.UnexpectedEndOfRequestContent, ex2.Reason);
Assert.Equal(RequestRejectionReason.UnexpectedEndOfRequestContent, ex3.Reason);

await body.StopAsync();
}
}

[Fact]
public async Task CompleteForChunkedAllowsConsumeToWork()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,12 @@
using Microsoft.AspNetCore.Testing;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Testing;
using Serilog;
using Xunit;

namespace Microsoft.AspNetCore.Server.Kestrel.InMemory.FunctionalTests
{
public class RequestTests : LoggedTest
public class RequestTests : TestApplicationErrorLoggerLoggedTest
{
[Fact]
public async Task StreamsAreNotPersistedAcrossRequests()
Expand Down Expand Up @@ -1440,6 +1441,39 @@ await connection.Receive(
}
}

[Fact]
public async Task ContentLengthSwallowedUnexpectedEndOfRequestContentDoesNotResultInWarnings()
{
var testContext = new TestServiceContext(LoggerFactory);

await using (var server = new TestServer(async httpContext =>
{
try
{
await httpContext.Request.Body.ReadAsync(new byte[1], 0, 1);
}
catch
{
}
}, testContext))
{
using (var connection = server.CreateConnection())
{
await connection.Send(
"POST / HTTP/1.1",
"Host:",
"Content-Length: 5",
"",
"");
connection.ShutdownSend();

await connection.ReceiveEnd();
}
}

Assert.Empty(TestApplicationErrorLogger.Messages.Where(m => m.LogLevel >= LogLevel.Warning));
}

[Fact]
public async Task ContentLengthRequestCallCancelPendingReadWorks()
{
Expand Down