diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3RequestStream.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3RequestStream.cs index 3aa5b989baf8a..7c3e4d4cbb9e8 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3RequestStream.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3RequestStream.cs @@ -147,15 +147,12 @@ public async Task SendAsync(CancellationToken cancellationT // Ideally, headers will be sent out in a gathered write inside of SendContentAsync(). // If we don't have content, or we are doing Expect 100 Continue, then we can't rely on // this and must send our headers immediately. + await FlushSendBufferAsync(requestCancellationSource.Token).ConfigureAwait(false); - await _stream.WriteAsync(_sendBuffer.ActiveMemory, endStream: _expect100ContinueCompletionSource == null, requestCancellationSource.Token).ConfigureAwait(false); - _sendBuffer.Discard(_sendBuffer.ActiveLength); - - if (_expect100ContinueCompletionSource != null) + // End the stream writing if there's no content to send. + if (_request.Content == null) { - // Flush to ensure we get a response. - // TODO: MsQuic may not need any flushing. - await _stream.FlushAsync(cancellationToken).ConfigureAwait(false); + _stream.Shutdown(); } } @@ -375,8 +372,7 @@ private async Task SendContentAsync(HttpContent content, CancellationToken cance { // Our initial send buffer, which has our headers, are normally sent out on the first write to the Http3WriteStream. // If we get here, it means the content didn't actually do any writing. Send out the headers now. - await _stream.WriteAsync(_sendBuffer.ActiveMemory, cancellationToken).ConfigureAwait(false); - _sendBuffer.Discard(_sendBuffer.ActiveLength); + await FlushSendBufferAsync(cancellationToken).ConfigureAwait(false); } _stream.Shutdown(); @@ -436,6 +432,14 @@ private async ValueTask WriteRequestContentAsync(ReadOnlyMemory buffer, Ca } } + private async ValueTask FlushSendBufferAsync(CancellationToken cancellationToken) + { + await _stream.WriteAsync(_sendBuffer.ActiveMemory, cancellationToken).ConfigureAwait(false); + _sendBuffer.Discard(_sendBuffer.ActiveLength); + + await _stream.FlushAsync(cancellationToken).ConfigureAwait(false); + } + private async ValueTask DrainContentLength0Frames(CancellationToken cancellationToken) { Http3FrameType? frameType; @@ -1317,6 +1321,16 @@ public override ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationTo return _stream.WriteRequestContentAsync(buffer, cancellationToken); } + + public override Task FlushAsync(CancellationToken cancellationToken) + { + if (_stream == null) + { + return Task.FromException(new ObjectDisposedException(nameof(Http3WriteStream))); + } + + return _stream.FlushSendBufferAsync(cancellationToken).AsTask(); + } } private enum HeaderState diff --git a/src/libraries/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.Http3.cs b/src/libraries/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.Http3.cs index 0e1274d502f5f..ec8dad21da959 100644 --- a/src/libraries/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.Http3.cs +++ b/src/libraries/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.Http3.cs @@ -379,6 +379,56 @@ public async Task ServerCertificateCustomValidationCallback_Succeeds() await serverTask; } + [Fact] + public async Task EmptyCustomContent_FlushHeaders() + { + using Http3LoopbackServer server = CreateHttp3LoopbackServer(); + TaskCompletionSource headersReceived = new TaskCompletionSource(); + + Task serverTask = Task.Run(async () => + { + using Http3LoopbackConnection connection = (Http3LoopbackConnection)await server.EstablishGenericConnectionAsync(); + using Http3LoopbackStream stream = await connection.AcceptRequestStreamAsync(); + + // Receive headers and unblock the client. + await stream.ReadRequestDataAsync(false); + headersReceived.SetResult(); + + await stream.ReadRequestBodyAsync(); + await stream.SendResponseAsync(); + }); + + Task clientTask = Task.Run(async () => + { + StreamingHttpContent requestContent = new StreamingHttpContent(); + + using HttpClient client = CreateHttpClient(); + using HttpRequestMessage request = new() + { + Method = HttpMethod.Post, + RequestUri = server.Address, + Version = HttpVersion30, + VersionPolicy = HttpVersionPolicy.RequestVersionExact, + Content = requestContent + }; + + Task responseTask = client.SendAsync(request); + + Stream requestStream = await requestContent.GetStreamAsync(); + await requestStream.FlushAsync(); + + await headersReceived.Task; + + requestContent.CompleteStream(); + + using HttpResponseMessage response = await responseTask; + + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + }); + + await new[] { clientTask, serverTask }.WhenAllOrAnyFailed(20_000); + } + [Fact] public async Task DisposeHttpClient_Http3ConnectionIsClosed() { @@ -826,4 +876,41 @@ public static TheoryData InteropUrisWithContent() => { "https://pgjones.dev/" }, // aioquic with content }; } + + internal class StreamingHttpContent : HttpContent + { + private readonly TaskCompletionSource _completeTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + private readonly TaskCompletionSource _getStreamTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + protected override Task SerializeToStreamAsync(Stream stream, TransportContext context) + { + throw new NotSupportedException(); + } + + protected override async Task SerializeToStreamAsync(Stream stream, TransportContext context, CancellationToken cancellationToken) + { + _getStreamTcs.TrySetResult(stream); + + var cancellationTcs = new TaskCompletionSource(); + cancellationToken.Register(() => cancellationTcs.TrySetCanceled()); + + await Task.WhenAny(_completeTcs.Task, cancellationTcs.Task); + } + + protected override bool TryComputeLength(out long length) + { + length = -1; + return false; + } + + public Task GetStreamAsync() + { + return _getStreamTcs.Task; + } + + public void CompleteStream() + { + _completeTcs.TrySetResult(); + } + } }