Skip to content
This repository has been archived by the owner on Dec 18, 2018. It is now read-only.

Call OnStarting before verifying response length (#1289) #1302

Merged
merged 2 commits into from
Jan 13, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 25 additions & 12 deletions src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/Frame.cs
Original file line number Diff line number Diff line change
Expand Up @@ -533,23 +533,29 @@ protected async Task FireOnCompleted()

public void Flush()
{
ProduceStartAndFireOnStarting().GetAwaiter().GetResult();
InitializeResponse(0).GetAwaiter().GetResult();
Output.Flush();
}

public async Task FlushAsync(CancellationToken cancellationToken)
{
await ProduceStartAndFireOnStarting();
await InitializeResponse(0);
await Output.FlushAsync(cancellationToken);
}

public void Write(ArraySegment<byte> data)
{
// For the first write, ensure headers are flushed if Write(Chunked)isn't called.
// For the first write, ensure headers are flushed if Write(Chunked) isn't called.
var firstWrite = !HasResponseStarted;

VerifyAndUpdateWrite(data.Count);
ProduceStartAndFireOnStarting().GetAwaiter().GetResult();
if (firstWrite)
{
InitializeResponse(data.Count).GetAwaiter().GetResult();
}
else
{
VerifyAndUpdateWrite(data.Count);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we change InitializeResponse to not call VerifyAndUpdateWrite, and just call VerifyAndUpdateWrite afterwards?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need the check in InitializeResponse so we can return a 500 if the first write exceeds Content-Length. Without the check there a 200 response still started, despite the code failing later.

}

if (_canHaveBody)
{
Expand Down Expand Up @@ -616,9 +622,7 @@ public Task WriteAsync(ArraySegment<byte> data, CancellationToken cancellationTo

public async Task WriteAsyncAwaited(ArraySegment<byte> data, CancellationToken cancellationToken)
{
VerifyAndUpdateWrite(data.Count);

await ProduceStartAndFireOnStarting();
await InitializeResponseAwaited(data.Count);

// WriteAsyncAwaited is only called for the first write to the body.
// Ensure headers are flushed if Write(Chunked)Async isn't called.
Expand Down Expand Up @@ -691,7 +695,13 @@ protected void VerifyResponseContentLength()
responseHeaders.HeaderContentLengthValue.HasValue &&
_responseBytesWritten < responseHeaders.HeaderContentLengthValue.Value)
{
_keepAlive = false;
// We need to close the connection if any bytes were written since the client
// cannot be certain of how many bytes it will receive.
if (_responseBytesWritten > 0)
{
_keepAlive = false;
}

ReportApplicationError(new InvalidOperationException(
$"Response Content-Length mismatch: too few bytes written ({_responseBytesWritten} of {responseHeaders.HeaderContentLengthValue.Value})."));
}
Expand Down Expand Up @@ -734,7 +744,7 @@ public void ProduceContinue()
}
}

public Task ProduceStartAndFireOnStarting()
public Task InitializeResponse(int firstWriteByteCount)
{
if (HasResponseStarted)
{
Expand All @@ -743,19 +753,21 @@ public Task ProduceStartAndFireOnStarting()

if (_onStarting != null)
{
return ProduceStartAndFireOnStartingAwaited();
return InitializeResponseAwaited(firstWriteByteCount);
}

if (_applicationException != null)
{
ThrowResponseAbortedException();
}

VerifyAndUpdateWrite(firstWriteByteCount);
ProduceStart(appCompleted: false);

return TaskCache.CompletedTask;
}

private async Task ProduceStartAndFireOnStartingAwaited()
private async Task InitializeResponseAwaited(int firstWriteByteCount)
{
await FireOnStarting();

Expand All @@ -764,6 +776,7 @@ private async Task ProduceStartAndFireOnStartingAwaited()
ThrowResponseAbortedException();
}

VerifyAndUpdateWrite(firstWriteByteCount);
ProduceStart(appCompleted: false);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -485,7 +485,7 @@ await connection.Receive(
}

[Fact]
public async Task WhenAppWritesMoreThanContentLengthWriteThrowsAndConnectionCloses()
public async Task ThrowsAndClosesConnectionWhenAppWritesMoreThanContentLengthWrite()
{
var testLogger = new TestApplicationErrorLogger();
var serviceContext = new TestServiceContext { Log = new TestKestrelTrace(testLogger) };
Expand Down Expand Up @@ -520,7 +520,7 @@ await connection.ReceiveEnd(
}

[Fact]
public async Task WhenAppWritesMoreThanContentLengthWriteAsyncThrowsAndConnectionCloses()
public async Task ThrowsAndClosesConnectionWhenAppWritesMoreThanContentLengthWriteAsync()
{
var testLogger = new TestApplicationErrorLogger();
var serviceContext = new TestServiceContext { Log = new TestKestrelTrace(testLogger) };
Expand Down Expand Up @@ -554,15 +554,52 @@ await connection.ReceiveForcedEnd(
}

[Fact]
public async Task WhenAppWritesMoreThanContentLengthAndResponseNotStarted500ResponseSentAndConnectionCloses()
public async Task InternalServerErrorAndConnectionClosedOnWriteWithMoreThanContentLengthAndResponseNotStarted()
{
var testLogger = new TestApplicationErrorLogger();
var serviceContext = new TestServiceContext { Log = new TestKestrelTrace(testLogger) };

using (var server = new TestServer(async httpContext =>
using (var server = new TestServer(httpContext =>
{
var response = Encoding.ASCII.GetBytes("hello, world");
httpContext.Response.ContentLength = 5;
await httpContext.Response.WriteAsync("hello, world");
httpContext.Response.Body.Write(response, 0, response.Length);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why no longer async?

Copy link
Contributor Author

@cesarblum cesarblum Jan 11, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The WriteAsync version of this test is just below. Need to verify both calls.

return TaskCache.CompletedTask;
}, serviceContext))
{
using (var connection = server.CreateConnection())
{
await connection.Send(
"GET / HTTP/1.1",
"",
"");
await connection.ReceiveForcedEnd(
$"HTTP/1.1 500 Internal Server Error",
"Connection: close",
$"Date: {server.Context.DateHeaderValue}",
"Content-Length: 0",
"",
"");
}
}

var logMessage = Assert.Single(testLogger.Messages, message => message.LogLevel == LogLevel.Error);
Assert.Equal(
$"Response Content-Length mismatch: too many bytes written (12 of 5).",
logMessage.Exception.Message);
}

[Fact]
public async Task InternalServerErrorAndConnectionClosedOnWriteAsyncWithMoreThanContentLengthAndResponseNotStarted()
{
var testLogger = new TestApplicationErrorLogger();
var serviceContext = new TestServiceContext { Log = new TestKestrelTrace(testLogger) };

using (var server = new TestServer(httpContext =>
{
var response = Encoding.ASCII.GetBytes("hello, world");
httpContext.Response.ContentLength = 5;
return httpContext.Response.Body.WriteAsync(response, 0, response.Length);
}, serviceContext))
{
using (var connection = server.CreateConnection())
Expand Down Expand Up @@ -631,7 +668,7 @@ await connection.ReceiveEnd(
}

[Fact]
public async Task WhenAppSetsContentLengthButDoesNotWriteBody500ResponseSentAndConnectionCloses()
public async Task WhenAppSetsContentLengthButDoesNotWriteBody500ResponseSentAndConnectionDoesNotClose()
{
var testLogger = new TestApplicationErrorLogger();
var serviceContext = new TestServiceContext { Log = new TestKestrelTrace(testLogger) };
Expand All @@ -645,23 +682,27 @@ public async Task WhenAppSetsContentLengthButDoesNotWriteBody500ResponseSentAndC
using (var connection = server.CreateConnection())
{
await connection.Send(
"GET / HTTP/1.1",
"",
"GET / HTTP/1.1",
"",
"");
await connection.ReceiveForcedEnd(
$"HTTP/1.1 500 Internal Server Error",
"Connection: close",
await connection.Receive(
"HTTP/1.1 500 Internal Server Error",
$"Date: {server.Context.DateHeaderValue}",
"Content-Length: 0",
"",
"HTTP/1.1 500 Internal Server Error",
$"Date: {server.Context.DateHeaderValue}",
"Content-Length: 0",
"",
"");
}
}

var errorMessage = Assert.Single(testLogger.Messages, message => message.LogLevel == LogLevel.Error);
Assert.Equal(
$"Response Content-Length mismatch: too few bytes written (0 of 5).",
errorMessage.Exception.Message);
var error = testLogger.Messages.Where(message => message.LogLevel == LogLevel.Error);
Assert.Equal(2, error.Count());
Assert.All(error, message => message.Equals("Response Content-Length mismatch: too few bytes written (0 of 5)."));
}

[Theory]
Expand Down Expand Up @@ -1065,6 +1106,170 @@ await connection.ReceiveEnd(
}
}

[Fact]
public async Task FirstWriteVerifiedAfterOnStarting()
{
using (var server = new TestServer(httpContext =>
{
httpContext.Response.OnStarting(() =>
{
// Change response to chunked
httpContext.Response.ContentLength = null;
return TaskCache.CompletedTask;
});

var response = Encoding.ASCII.GetBytes("hello, world");
httpContext.Response.ContentLength = response.Length - 1;

// If OnStarting is not run before verifying writes, an error response will be sent.
httpContext.Response.Body.Write(response, 0, response.Length);
return TaskCache.CompletedTask;
}, new TestServiceContext()))
{
using (var connection = server.CreateConnection())
{
await connection.Send(
"GET / HTTP/1.1",
"",
"");
await connection.Receive(
"HTTP/1.1 200 OK",
$"Date: {server.Context.DateHeaderValue}",
$"Transfer-Encoding: chunked",
"",
"c",
"hello, world",
"0",
"",
"");
}
}
}

[Fact]
public async Task SubsequentWriteVerifiedAfterOnStarting()
{
using (var server = new TestServer(httpContext =>
{
httpContext.Response.OnStarting(() =>
{
// Change response to chunked
httpContext.Response.ContentLength = null;
return TaskCache.CompletedTask;
});

var response = Encoding.ASCII.GetBytes("hello, world");
httpContext.Response.ContentLength = response.Length - 1;

// If OnStarting is not run before verifying writes, an error response will be sent.
httpContext.Response.Body.Write(response, 0, response.Length / 2);
httpContext.Response.Body.Write(response, response.Length / 2, response.Length - response.Length / 2);
return TaskCache.CompletedTask;
}, new TestServiceContext()))
{
using (var connection = server.CreateConnection())
{
await connection.Send(
"GET / HTTP/1.1",
"",
"");
await connection.Receive(
"HTTP/1.1 200 OK",
$"Date: {server.Context.DateHeaderValue}",
$"Transfer-Encoding: chunked",
"",
"6",
"hello,",
"6",
" world",
"0",
"",
"");
}
}
}

[Fact]
public async Task FirstWriteAsyncVerifiedAfterOnStarting()
{
using (var server = new TestServer(httpContext =>
{
httpContext.Response.OnStarting(() =>
{
// Change response to chunked
httpContext.Response.ContentLength = null;
return TaskCache.CompletedTask;
});

var response = Encoding.ASCII.GetBytes("hello, world");
httpContext.Response.ContentLength = response.Length - 1;

// If OnStarting is not run before verifying writes, an error response will be sent.
return httpContext.Response.Body.WriteAsync(response, 0, response.Length);
}, new TestServiceContext()))
{
using (var connection = server.CreateConnection())
{
await connection.Send(
"GET / HTTP/1.1",
"",
"");
await connection.Receive(
"HTTP/1.1 200 OK",
$"Date: {server.Context.DateHeaderValue}",
$"Transfer-Encoding: chunked",
"",
"c",
"hello, world",
"0",
"",
"");
}
}
}

[Fact]
public async Task SubsequentWriteAsyncVerifiedAfterOnStarting()
{
using (var server = new TestServer(async httpContext =>
{
httpContext.Response.OnStarting(() =>
{
// Change response to chunked
httpContext.Response.ContentLength = null;
return TaskCache.CompletedTask;
});

var response = Encoding.ASCII.GetBytes("hello, world");
httpContext.Response.ContentLength = response.Length - 1;

// If OnStarting is not run before verifying writes, an error response will be sent.
await httpContext.Response.Body.WriteAsync(response, 0, response.Length / 2);
await httpContext.Response.Body.WriteAsync(response, response.Length / 2, response.Length - response.Length / 2);
}, new TestServiceContext()))
{
using (var connection = server.CreateConnection())
{
await connection.Send(
"GET / HTTP/1.1",
"",
"");
await connection.Receive(
"HTTP/1.1 200 OK",
$"Date: {server.Context.DateHeaderValue}",
$"Transfer-Encoding: chunked",
"",
"6",
"hello,",
"6",
" world",
"0",
"",
"");
}
}
}

public static TheoryData<string, StringValues, string> NullHeaderData
{
get
Expand Down
Loading