diff --git a/src/Servers/Kestrel/Core/src/Internal/Http3/Http3OutputProducer.cs b/src/Servers/Kestrel/Core/src/Internal/Http3/Http3OutputProducer.cs index ac8307de6e46..0ac56eed4b4f 100644 --- a/src/Servers/Kestrel/Core/src/Internal/Http3/Http3OutputProducer.cs +++ b/src/Servers/Kestrel/Core/src/Internal/Http3/Http3OutputProducer.cs @@ -68,25 +68,21 @@ public void StreamReset() _dataWriteProcessingTask = ProcessDataWrites().Preserve(); } - public void Dispose() + // Called once Application code has exited + // Or on Dispose which also would occur after Application code finished + public void Complete() { lock (_dataWriterLock) { - if (_disposed) - { - return; - } - - _disposed = true; - Stop(); + _pipeWriter.Complete(); + if (_fakeMemoryOwner != null) { _fakeMemoryOwner.Dispose(); _fakeMemoryOwner = null; } - if (_fakeMemory != null) { ArrayPool.Shared.Return(_fakeMemory); @@ -95,6 +91,21 @@ public void Dispose() } } + public void Dispose() + { + lock (_dataWriterLock) + { + if (_disposed) + { + return; + } + + _disposed = true; + + Complete(); + } + } + // In HTTP/1.x, this aborts the entire connection. For HTTP/3 we abort the stream. void IHttpOutputAborter.Abort(ConnectionAbortedException abortReason, ConnectionEndReason reason) { @@ -288,7 +299,9 @@ public void Stop() _streamCompleted = true; - _pipeWriter.Complete(new OperationCanceledException()); + // Application code could be using this PipeWriter, we cancel the next (or in progress) flush so they can observe this Stop + // Additionally, _streamCompleted will cause any future PipeWriter operations to noop + _pipeWriter.CancelPendingFlush(); } } diff --git a/src/Servers/Kestrel/Core/src/Internal/Http3/Http3Stream.cs b/src/Servers/Kestrel/Core/src/Internal/Http3/Http3Stream.cs index ccf595d7f89f..795fdc42b521 100644 --- a/src/Servers/Kestrel/Core/src/Internal/Http3/Http3Stream.cs +++ b/src/Servers/Kestrel/Core/src/Internal/Http3/Http3Stream.cs @@ -561,6 +561,8 @@ private void CompleteStream(bool errored) TryClose(); } + _http3Output.Complete(); + // Stream will be pooled after app completed. // Wait to signal app completed after any potential aborts on the stream. _appCompletedTaskSource.SetResult(null); diff --git a/src/Servers/Kestrel/Transport.Sockets/src/Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.csproj b/src/Servers/Kestrel/Transport.Sockets/src/Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.csproj index 9258e26fcba1..055f5f8e297e 100644 --- a/src/Servers/Kestrel/Transport.Sockets/src/Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.csproj +++ b/src/Servers/Kestrel/Transport.Sockets/src/Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.csproj @@ -44,5 +44,6 @@ + diff --git a/src/Servers/Kestrel/test/Interop.FunctionalTests/Http3/Http3RequestTests.cs b/src/Servers/Kestrel/test/Interop.FunctionalTests/Http3/Http3RequestTests.cs index 1b7ae1ff0132..8fe339e066d5 100644 --- a/src/Servers/Kestrel/test/Interop.FunctionalTests/Http3/Http3RequestTests.cs +++ b/src/Servers/Kestrel/test/Interop.FunctionalTests/Http3/Http3RequestTests.cs @@ -1,6 +1,7 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Buffers; using System.Diagnostics; using System.Diagnostics.Metrics; using System.Net; @@ -1145,6 +1146,137 @@ public async Task POST_Bidirectional_LargeData_Cancellation_Error(HttpProtocols } } + internal class MemoryPoolFeature : IMemoryPoolFeature + { + public MemoryPool MemoryPool { get; set; } + } + + [ConditionalTheory] + [MsQuicSupported] + [InlineData(HttpProtocols.Http3)] + [InlineData(HttpProtocols.Http2)] + public async Task ApplicationWriteWhenConnectionClosesPreservesMemory(HttpProtocols protocol) + { + // Arrange + var memoryPool = new DiagnosticMemoryPool(new PinnedBlockMemoryPool(), allowLateReturn: true); + + var writingTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var cancelTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var completionTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + var builder = CreateHostBuilder(async context => + { + try + { + var requestBody = context.Request.Body; + + await context.Response.BodyWriter.FlushAsync(); + + // Test relies on Htt2Stream/Http3Stream aborting the token after stopping Http2OutputProducer/Http3OutputProducer + // It's very fragile but it is sort of a best effort test anyways + // Additionally, Http2 schedules it's stopping, so doesn't directly do anything to the PipeWriter when calling stop on Http2OutputProducer + context.RequestAborted.Register(() => + { + cancelTcs.SetResult(); + }); + + while (true) + { + var memory = context.Response.BodyWriter.GetMemory(); + + // Unblock client-side to close the connection + writingTcs.TrySetResult(); + + await cancelTcs.Task; + + // Verify memory is still rented from the memory pool after the producer has been stopped + Assert.True(memoryPool.ContainsMemory(memory)); + + context.Response.BodyWriter.Advance(memory.Length); + var flushResult = await context.Response.BodyWriter.FlushAsync(); + + if (flushResult.IsCanceled || flushResult.IsCompleted) + { + break; + } + } + + completionTcs.SetResult(); + } + catch (Exception ex) + { + writingTcs.TrySetException(ex); + // Exceptions annoyingly don't show up on the client side when doing E2E + cancellation testing + // so we need to use a TCS to observe any unexpected errors + completionTcs.TrySetException(ex); + throw; + } + }, protocol: protocol, + configureKestrel: o => + { + o.Listen(IPAddress.Parse("127.0.0.1"), 0, listenOptions => + { + listenOptions.Protocols = protocol; + listenOptions.UseHttps(TestResources.GetTestCertificate()).Use(@delegate => + { + // Connection middleware for Http/1.1 and Http/2 + return (context) => + { + // Set the memory pool used by the connection so we can observe if memory from the PipeWriter is still rented from the pool + context.Features.Set(new MemoryPoolFeature() { MemoryPool = memoryPool }); + return @delegate(context); + }; + }); + + IMultiplexedConnectionBuilder multiplexedConnectionBuilder = listenOptions; + multiplexedConnectionBuilder.Use(@delegate => + { + // Connection middleware for Http/3 + return (context) => + { + // Set the memory pool used by the connection so we can observe if memory from the PipeWriter is still rented from the pool + context.Features.Set(new MemoryPoolFeature() { MemoryPool = memoryPool }); + return @delegate(context); + }; + }); + }); + }); + + var httpClientHandler = new HttpClientHandler(); + httpClientHandler.ServerCertificateCustomValidationCallback = HttpClientHandler.DangerousAcceptAnyServerCertificateValidator; + + using (var host = builder.Build()) + using (var client = new HttpClient(httpClientHandler)) + { + await host.StartAsync().DefaultTimeout(); + + var cts = new CancellationTokenSource(); + + var request = new HttpRequestMessage(HttpMethod.Post, $"https://127.0.0.1:{host.GetPort()}/"); + request.Version = GetProtocol(protocol); + request.VersionPolicy = HttpVersionPolicy.RequestVersionExact; + + // Act + var responseTask = client.SendAsync(request, HttpCompletionOption.ResponseHeadersRead); + + Logger.LogInformation("Client waiting for headers."); + var response = await responseTask.DefaultTimeout(); + await writingTcs.Task; + + Logger.LogInformation("Client canceled request."); + response.Dispose(); + + // Assert + await host.StopAsync().DefaultTimeout(); + + await completionTcs.Task; + + memoryPool.Dispose(); + + await memoryPool.WhenAllBlocksReturnedAsync(TimeSpan.FromSeconds(15)); + } + } + // Verify HTTP/2 and HTTP/3 match behavior [ConditionalTheory] [MsQuicSupported] diff --git a/src/Shared/Buffers.MemoryPool/DiagnosticMemoryPool.cs b/src/Shared/Buffers.MemoryPool/DiagnosticMemoryPool.cs index 2a35f095e630..2250dd045427 100644 --- a/src/Shared/Buffers.MemoryPool/DiagnosticMemoryPool.cs +++ b/src/Shared/Buffers.MemoryPool/DiagnosticMemoryPool.cs @@ -160,4 +160,27 @@ public async Task WhenAllBlocksReturnedAsync(TimeSpan timeout) await task; } + + public bool ContainsMemory(Memory memory) + { + lock (_syncObj) + { + foreach (var block in _blocks) + { + unsafe + { + fixed (byte* inUseMemoryPtr = memory.Span) + fixed (byte* beginPooledMemoryPtr = block.Memory.Span) + { + byte* endPooledMemoryPtr = beginPooledMemoryPtr + block.Memory.Length; + if (inUseMemoryPtr >= beginPooledMemoryPtr && inUseMemoryPtr < endPooledMemoryPtr) + { + return true; + } + } + } + } + return false; + } + } }