Skip to content

Commit

Permalink
[HTTP/3] Flush sends buffered data (#57234)
Browse files Browse the repository at this point in the history
HTTP/3 write stream will send buffered headers if flush is called.

Fixes #56969
  • Loading branch information
ManickaP authored Aug 12, 2021
1 parent 60f1105 commit dcd1e03
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -147,15 +147,12 @@ public async Task<HttpResponseMessage> SendAsync(CancellationToken cancellationT
// Ideally, headers will be sent out in a gathered write inside of SendContentAsync().
// If we don't have content, or we are doing Expect 100 Continue, then we can't rely on
// this and must send our headers immediately.
await FlushSendBufferAsync(requestCancellationSource.Token).ConfigureAwait(false);

await _stream.WriteAsync(_sendBuffer.ActiveMemory, endStream: _expect100ContinueCompletionSource == null, requestCancellationSource.Token).ConfigureAwait(false);
_sendBuffer.Discard(_sendBuffer.ActiveLength);

if (_expect100ContinueCompletionSource != null)
// End the stream writing if there's no content to send.
if (_request.Content == null)
{
// Flush to ensure we get a response.
// TODO: MsQuic may not need any flushing.
await _stream.FlushAsync(cancellationToken).ConfigureAwait(false);
_stream.Shutdown();
}
}

Expand Down Expand Up @@ -375,8 +372,7 @@ private async Task SendContentAsync(HttpContent content, CancellationToken cance
{
// Our initial send buffer, which has our headers, are normally sent out on the first write to the Http3WriteStream.
// If we get here, it means the content didn't actually do any writing. Send out the headers now.
await _stream.WriteAsync(_sendBuffer.ActiveMemory, cancellationToken).ConfigureAwait(false);
_sendBuffer.Discard(_sendBuffer.ActiveLength);
await FlushSendBufferAsync(cancellationToken).ConfigureAwait(false);
}

_stream.Shutdown();
Expand Down Expand Up @@ -436,6 +432,14 @@ private async ValueTask WriteRequestContentAsync(ReadOnlyMemory<byte> buffer, Ca
}
}

private async ValueTask FlushSendBufferAsync(CancellationToken cancellationToken)
{
await _stream.WriteAsync(_sendBuffer.ActiveMemory, cancellationToken).ConfigureAwait(false);
_sendBuffer.Discard(_sendBuffer.ActiveLength);

await _stream.FlushAsync(cancellationToken).ConfigureAwait(false);
}

private async ValueTask DrainContentLength0Frames(CancellationToken cancellationToken)
{
Http3FrameType? frameType;
Expand Down Expand Up @@ -1317,6 +1321,16 @@ public override ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationTo

return _stream.WriteRequestContentAsync(buffer, cancellationToken);
}

public override Task FlushAsync(CancellationToken cancellationToken)
{
if (_stream == null)
{
return Task.FromException(new ObjectDisposedException(nameof(Http3WriteStream)));
}

return _stream.FlushSendBufferAsync(cancellationToken).AsTask();
}
}

private enum HeaderState
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,56 @@ public async Task ServerCertificateCustomValidationCallback_Succeeds()
await serverTask;
}

[Fact]
public async Task EmptyCustomContent_FlushHeaders()
{
using Http3LoopbackServer server = CreateHttp3LoopbackServer();
TaskCompletionSource headersReceived = new TaskCompletionSource();

Task serverTask = Task.Run(async () =>
{
using Http3LoopbackConnection connection = (Http3LoopbackConnection)await server.EstablishGenericConnectionAsync();
using Http3LoopbackStream stream = await connection.AcceptRequestStreamAsync();
// Receive headers and unblock the client.
await stream.ReadRequestDataAsync(false);
headersReceived.SetResult();
await stream.ReadRequestBodyAsync();
await stream.SendResponseAsync();
});

Task clientTask = Task.Run(async () =>
{
StreamingHttpContent requestContent = new StreamingHttpContent();
using HttpClient client = CreateHttpClient();
using HttpRequestMessage request = new()
{
Method = HttpMethod.Post,
RequestUri = server.Address,
Version = HttpVersion30,
VersionPolicy = HttpVersionPolicy.RequestVersionExact,
Content = requestContent
};
Task<HttpResponseMessage> responseTask = client.SendAsync(request);
Stream requestStream = await requestContent.GetStreamAsync();
await requestStream.FlushAsync();
await headersReceived.Task;
requestContent.CompleteStream();
using HttpResponseMessage response = await responseTask;
Assert.Equal(HttpStatusCode.OK, response.StatusCode);
});

await new[] { clientTask, serverTask }.WhenAllOrAnyFailed(20_000);
}

[Fact]
public async Task DisposeHttpClient_Http3ConnectionIsClosed()
{
Expand Down Expand Up @@ -826,4 +876,41 @@ public static TheoryData<string> InteropUrisWithContent() =>
{ "https://pgjones.dev/" }, // aioquic with content
};
}

internal class StreamingHttpContent : HttpContent
{
private readonly TaskCompletionSource _completeTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
private readonly TaskCompletionSource<Stream> _getStreamTcs = new TaskCompletionSource<Stream>(TaskCreationOptions.RunContinuationsAsynchronously);

protected override Task SerializeToStreamAsync(Stream stream, TransportContext context)
{
throw new NotSupportedException();
}

protected override async Task SerializeToStreamAsync(Stream stream, TransportContext context, CancellationToken cancellationToken)
{
_getStreamTcs.TrySetResult(stream);

var cancellationTcs = new TaskCompletionSource();
cancellationToken.Register(() => cancellationTcs.TrySetCanceled());

await Task.WhenAny(_completeTcs.Task, cancellationTcs.Task);
}

protected override bool TryComputeLength(out long length)
{
length = -1;
return false;
}

public Task<Stream> GetStreamAsync()
{
return _getStreamTcs.Task;
}

public void CompleteStream()
{
_completeTcs.TrySetResult();
}
}
}

0 comments on commit dcd1e03

Please sign in to comment.