diff --git a/src/libraries/Common/src/System/Net/StreamBuffer.cs b/src/libraries/Common/src/System/Net/StreamBuffer.cs index 6759fcdd8e20b..5ccab63f1f2e1 100644 --- a/src/libraries/Common/src/System/Net/StreamBuffer.cs +++ b/src/libraries/Common/src/System/Net/StreamBuffer.cs @@ -18,6 +18,7 @@ internal sealed class StreamBuffer : IDisposable private bool _readAborted; private readonly ResettableValueTaskSource _readTaskSource; private readonly ResettableValueTaskSource _writeTaskSource; + private readonly TaskCompletionSource _shutdownTaskSource; public const int DefaultInitialBufferSize = 4 * 1024; public const int DefaultMaxBufferSize = 32 * 1024; @@ -28,10 +29,13 @@ public StreamBuffer(int initialBufferSize = DefaultInitialBufferSize, int maxBuf _maxBufferSize = maxBufferSize; _readTaskSource = new ResettableValueTaskSource(); _writeTaskSource = new ResettableValueTaskSource(); + _shutdownTaskSource = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); } private object SyncObject => _readTaskSource; + public Task Completed => _shutdownTaskSource.Task; + public bool IsComplete { get @@ -187,6 +191,11 @@ public void EndWrite() _writeEnded = true; _readTaskSource.SignalWaiter(); + + if (_buffer.IsEmpty) + { + _shutdownTaskSource.TrySetResult(); + } } } @@ -210,10 +219,16 @@ public void EndWrite() _writeTaskSource.SignalWaiter(); + if (_buffer.IsEmpty && _writeEnded) + { + _shutdownTaskSource.TrySetResult(); + } + return (false, bytesRead); } else if (_writeEnded) { + _shutdownTaskSource.TrySetResult(); return (false, 0); } @@ -280,6 +295,7 @@ public void AbortRead() _readTaskSource.SignalWaiter(); _writeTaskSource.SignalWaiter(); + _shutdownTaskSource.TrySetResult(); } } diff --git a/src/libraries/Common/tests/System/Net/Http/Http3LoopbackStream.cs b/src/libraries/Common/tests/System/Net/Http/Http3LoopbackStream.cs index faccfab64b27e..6c6a06c85291c 100644 --- a/src/libraries/Common/tests/System/Net/Http/Http3LoopbackStream.cs +++ b/src/libraries/Common/tests/System/Net/Http/Http3LoopbackStream.cs @@ -15,7 +15,7 @@ namespace System.Net.Test.Common { - internal sealed class Http3LoopbackStream : IDisposable + internal sealed class Http3LoopbackStream : IDisposable, IAsyncDisposable { private const int MaximumVarIntBytes = 8; private const long VarIntMax = (1L << 62) - 1; @@ -43,6 +43,10 @@ public void Dispose() { _stream.Dispose(); } + + public ValueTask DisposeAsync() => + _stream.DisposeAsync(); + public async Task HandleRequestAsync(HttpStatusCode statusCode = HttpStatusCode.OK, IList headers = null, string content = "") { HttpRequestData request = await ReadRequestDataAsync().ConfigureAwait(false); @@ -116,12 +120,6 @@ public async Task SendFrameAsync(long frameType, ReadOnlyMemory framePaylo await _stream.WriteAsync(framePayload).ConfigureAwait(false); } - public async Task ShutdownSendAsync() - { - _stream.Shutdown(); - await _stream.ShutdownWriteCompleted().ConfigureAwait(false); - } - static int EncodeHttpInteger(long longToEncode, Span buffer) { Debug.Assert(longToEncode >= 0); @@ -226,9 +224,8 @@ public async Task SendResponseBodyAsync(byte[] content, bool isFinal = true) if (isFinal) { - await ShutdownSendAsync().ConfigureAwait(false); - await _stream.ShutdownCompleted().ConfigureAwait(false); - Dispose(); + _stream.CompleteWrites(); + await DisposeAsync(); } } diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3Connection.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3Connection.cs index c091127350068..d2cc4d5879f43 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3Connection.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3Connection.cs @@ -519,7 +519,7 @@ private async Task ProcessServerStreamAsync(QuicStream stream) NetEventSource.Info(this, $"Ignoring server-initiated stream of unknown type {unknownStreamType}."); } - stream.AbortWrite((long)Http3ErrorCode.StreamCreationError); + stream.Abort((long)Http3ErrorCode.StreamCreationError, QuicAbortDirection.Read); return; } } diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3RequestStream.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3RequestStream.cs index 3d469f06c3b5f..d93d200f586b3 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3RequestStream.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3RequestStream.cs @@ -91,7 +91,10 @@ public async ValueTask DisposeAsync() if (!_disposed) { _disposed = true; + + // TODO: use CloseAsync() with a cancellation token to prevent a DoS await _stream.DisposeAsync().ConfigureAwait(false); + DisposeSyncHelper(); } } @@ -151,7 +154,7 @@ public async Task SendAsync(CancellationToken cancellationT } else { - _stream.Shutdown(); + _stream.CompleteWrites(); } } @@ -262,7 +265,7 @@ public async Task SendAsync(CancellationToken cancellationT if (cancellationToken.IsCancellationRequested) { - _stream.AbortWrite((long)Http3ErrorCode.RequestCancelled); + _stream.Abort((long)Http3ErrorCode.RequestCancelled); throw new OperationCanceledException(ex.Message, ex, cancellationToken); } else @@ -279,7 +282,7 @@ public async Task SendAsync(CancellationToken cancellationT } catch (Exception ex) { - _stream.AbortWrite((long)Http3ErrorCode.InternalError); + _stream.Abort((long)Http3ErrorCode.InternalError); if (ex is HttpRequestException) { throw; @@ -371,7 +374,7 @@ private async Task SendContentAsync(HttpContent content, CancellationToken cance _sendBuffer.Discard(_sendBuffer.ActiveLength); } - _stream.Shutdown(); + _stream.CompleteWrites(); } private async ValueTask WriteRequestContentAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken) @@ -776,7 +779,7 @@ private async ValueTask ReadHeadersAsync(long headersLength, CancellationToken c // https://tools.ietf.org/html/draft-ietf-quic-http-24#section-4.1.1 if (headersLength > _headerBudgetRemaining) { - _stream.AbortWrite((long)Http3ErrorCode.ExcessiveLoad); + _stream.Abort((long)Http3ErrorCode.ExcessiveLoad); throw new HttpRequestException(SR.Format(SR.net_http_response_headers_exceeded_length, _connection.Pool.Settings._maxResponseHeadersLength * 1024L)); } @@ -1113,11 +1116,11 @@ private void HandleReadResponseContentException(Exception ex, CancellationToken _connection.Abort(ex); throw new IOException(SR.net_http_client_execution_error, new HttpRequestException(SR.net_http_client_execution_error, ex)); case OperationCanceledException oce when oce.CancellationToken == cancellationToken: - _stream.AbortWrite((long)Http3ErrorCode.RequestCancelled); + _stream.Abort((long)Http3ErrorCode.RequestCancelled); ExceptionDispatchInfo.Throw(ex); // Rethrow. return; // Never reached. default: - _stream.AbortWrite((long)Http3ErrorCode.InternalError); + _stream.Abort((long)Http3ErrorCode.InternalError); throw new IOException(SR.net_http_client_execution_error, new HttpRequestException(SR.net_http_client_execution_error, ex)); } } diff --git a/src/libraries/System.Net.Quic/System.Net.Quic.sln b/src/libraries/System.Net.Quic/System.Net.Quic.sln index 3d6fa4fc85246..3ad2d96bdcfbc 100644 --- a/src/libraries/System.Net.Quic/System.Net.Quic.sln +++ b/src/libraries/System.Net.Quic/System.Net.Quic.sln @@ -1,4 +1,8 @@ -Microsoft Visual Studio Solution File, Format Version 12.00 + +Microsoft Visual Studio Solution File, Format Version 12.00 +# Visual Studio Version 16 +VisualStudioVersion = 16.0.31220.234 +MinimumVisualStudioVersion = 10.0.40219.1 Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TestUtilities", "..\Common\tests\TestUtilities\TestUtilities.csproj", "{55C933AA-2735-4B38-A1DD-01A27467AB18}" EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.Win32.Registry", "..\Microsoft.Win32.Registry\ref\Microsoft.Win32.Registry.csproj", "{69CDCFD5-AA35-40D8-A437-ED1C06E9CA95}" @@ -23,18 +27,9 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "ref", "ref", "{4BABFE90-C81 EndProject Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "src", "src", "{DAC0D00A-6EB0-4A72-94BB-EB90B3EE72A9}" EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "StreamConformanceTests", "..\Common\tests\StreamConformanceTests\StreamConformanceTests.csproj", "{CCE2D0B0-BDBE-4750-B215-2517286510EB}" +EndProject Global - GlobalSection(NestedProjects) = preSolution - {55C933AA-2735-4B38-A1DD-01A27467AB18} = {BDA10542-BE94-4A73-9B5B-6BE5CE57F883} - {E8E7DD3A-EC3F-4472-9F70-B515A3D11038} = {BDA10542-BE94-4A73-9B5B-6BE5CE57F883} - {69CDCFD5-AA35-40D8-A437-ED1C06E9CA95} = {4BABFE90-C818-4772-9D2E-B92F69E1FCDF} - {D7A52855-C6DE-4FD0-9CAF-E55F292C69E5} = {4BABFE90-C818-4772-9D2E-B92F69E1FCDF} - {7BB8C50D-4770-42CB-BE15-76AD623A5AE8} = {4BABFE90-C818-4772-9D2E-B92F69E1FCDF} - {833418C5-FEC9-482F-A0D6-69DFC332C1B6} = {4BABFE90-C818-4772-9D2E-B92F69E1FCDF} - {E1CABA2F-48AD-49FA-B872-BEED78C51980} = {4BABFE90-C818-4772-9D2E-B92F69E1FCDF} - {4F87758B-D1AF-4DE3-A9A2-68B1558C02B7} = {DAC0D00A-6EB0-4A72-94BB-EB90B3EE72A9} - {9D56BA9E-1B0D-4320-9FE9-A2D326A32BE0} = {DAC0D00A-6EB0-4A72-94BB-EB90B3EE72A9} - EndGlobalSection GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU Release|Any CPU = Release|Any CPU @@ -76,10 +71,26 @@ Global {E1CABA2F-48AD-49FA-B872-BEED78C51980}.Debug|Any CPU.Build.0 = Debug|Any CPU {E1CABA2F-48AD-49FA-B872-BEED78C51980}.Release|Any CPU.ActiveCfg = Release|Any CPU {E1CABA2F-48AD-49FA-B872-BEED78C51980}.Release|Any CPU.Build.0 = Release|Any CPU + {CCE2D0B0-BDBE-4750-B215-2517286510EB}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {CCE2D0B0-BDBE-4750-B215-2517286510EB}.Debug|Any CPU.Build.0 = Debug|Any CPU + {CCE2D0B0-BDBE-4750-B215-2517286510EB}.Release|Any CPU.ActiveCfg = Release|Any CPU + {CCE2D0B0-BDBE-4750-B215-2517286510EB}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE EndGlobalSection + GlobalSection(NestedProjects) = preSolution + {55C933AA-2735-4B38-A1DD-01A27467AB18} = {BDA10542-BE94-4A73-9B5B-6BE5CE57F883} + {69CDCFD5-AA35-40D8-A437-ED1C06E9CA95} = {4BABFE90-C818-4772-9D2E-B92F69E1FCDF} + {D7A52855-C6DE-4FD0-9CAF-E55F292C69E5} = {4BABFE90-C818-4772-9D2E-B92F69E1FCDF} + {4F87758B-D1AF-4DE3-A9A2-68B1558C02B7} = {DAC0D00A-6EB0-4A72-94BB-EB90B3EE72A9} + {E8E7DD3A-EC3F-4472-9F70-B515A3D11038} = {BDA10542-BE94-4A73-9B5B-6BE5CE57F883} + {7BB8C50D-4770-42CB-BE15-76AD623A5AE8} = {4BABFE90-C818-4772-9D2E-B92F69E1FCDF} + {9D56BA9E-1B0D-4320-9FE9-A2D326A32BE0} = {DAC0D00A-6EB0-4A72-94BB-EB90B3EE72A9} + {833418C5-FEC9-482F-A0D6-69DFC332C1B6} = {4BABFE90-C818-4772-9D2E-B92F69E1FCDF} + {E1CABA2F-48AD-49FA-B872-BEED78C51980} = {4BABFE90-C818-4772-9D2E-B92F69E1FCDF} + {CCE2D0B0-BDBE-4750-B215-2517286510EB} = {BDA10542-BE94-4A73-9B5B-6BE5CE57F883} + EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {4B59ACCA-7F0C-4062-AA79-B3D75EFACCCD} EndGlobalSection diff --git a/src/libraries/System.Net.Quic/ref/System.Net.Quic.cs b/src/libraries/System.Net.Quic/ref/System.Net.Quic.cs index be6df9a7c479b..5dc88d442442e 100644 --- a/src/libraries/System.Net.Quic/ref/System.Net.Quic.cs +++ b/src/libraries/System.Net.Quic/ref/System.Net.Quic.cs @@ -6,6 +6,14 @@ namespace System.Net.Quic { + [System.FlagsAttribute] + public enum QuicAbortDirection + { + Read = 1, + Write = 2, + Both = 3, + Immediate = 7 + } public partial class QuicClientConnectionOptions : System.Net.Quic.QuicOptions { public QuicClientConnectionOptions() { } @@ -87,11 +95,13 @@ internal QuicStream() { } public override long Length { get { throw null; } } public override long Position { get { throw null; } set { } } public long StreamId { get { throw null; } } - public void AbortRead(long errorCode) { } - public void AbortWrite(long errorCode) { } + public void Abort(long errorCode, System.Net.Quic.QuicAbortDirection abortDirection = System.Net.Quic.QuicAbortDirection.Immediate) { } public override System.IAsyncResult BeginRead(byte[] buffer, int offset, int count, System.AsyncCallback? callback, object? state) { throw null; } public override System.IAsyncResult BeginWrite(byte[] buffer, int offset, int count, System.AsyncCallback? callback, object? state) { throw null; } + public System.Threading.Tasks.ValueTask CloseAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public void CompleteWrites() { } protected override void Dispose(bool disposing) { } + public override System.Threading.Tasks.ValueTask DisposeAsync() { throw null; } public override int EndRead(System.IAsyncResult asyncResult) { throw null; } public override void EndWrite(System.IAsyncResult asyncResult) { } public override void Flush() { } @@ -102,9 +112,6 @@ public override void Flush() { } public override System.Threading.Tasks.ValueTask ReadAsync(System.Memory buffer, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } public override long Seek(long offset, System.IO.SeekOrigin origin) { throw null; } public override void SetLength(long value) { } - public void Shutdown() { } - public System.Threading.Tasks.ValueTask ShutdownCompleted(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } - public System.Threading.Tasks.ValueTask ShutdownWriteCompleted(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } public override void Write(byte[] buffer, int offset, int count) { } public override void Write(System.ReadOnlySpan buffer) { } public System.Threading.Tasks.ValueTask WriteAsync(System.Buffers.ReadOnlySequence buffers, bool endStream, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } diff --git a/src/libraries/System.Net.Quic/src/Resources/Strings.resx b/src/libraries/System.Net.Quic/src/Resources/Strings.resx index a29352a0578f5..a702b67aeb9ea 100644 --- a/src/libraries/System.Net.Quic/src/Resources/Strings.resx +++ b/src/libraries/System.Net.Quic/src/Resources/Strings.resx @@ -1,17 +1,17 @@  - @@ -150,5 +150,4 @@ Writing is not allowed on stream. - - + \ No newline at end of file diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/Mock/MockStream.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/Mock/MockStream.cs index bd814f690d952..860de214b4309 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/Mock/MockStream.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/Mock/MockStream.cs @@ -169,45 +169,28 @@ internal override Task FlushAsync(CancellationToken cancellationToken) return Task.CompletedTask; } - internal override void AbortRead(long errorCode) + internal override void Abort(long errorCode, QuicAbortDirection abortDirection = QuicAbortDirection.Both) { - throw new NotImplementedException(); - } + // TODO: support abort read direction. - internal override void AbortWrite(long errorCode) - { - if (_isInitiator) - { - _streamState._outboundErrorCode = errorCode; - } - else + if (abortDirection.HasFlag(QuicAbortDirection.Write)) { - _streamState._inboundErrorCode = errorCode; - } - - WriteStreamBuffer?.EndWrite(); - } - - internal override ValueTask ShutdownWriteCompleted(CancellationToken cancellationToken = default) - { - CheckDisposed(); - - return default; - } - - - internal override ValueTask ShutdownCompleted(CancellationToken cancellationToken = default) - { - CheckDisposed(); + if (_isInitiator) + { + _streamState._outboundErrorCode = errorCode; + } + else + { + _streamState._inboundErrorCode = errorCode; + } - return default; + WriteStreamBuffer?.EndWrite(); + } } - internal override void Shutdown() + public override void CompleteWrites() { CheckDisposed(); - - // This seems to mean shutdown send, in particular, not both. WriteStreamBuffer?.EndWrite(); if (_streamState._inboundStreamBuffer is null) // unidirectional stream @@ -232,29 +215,38 @@ public override void Dispose() { if (!_disposed) { - Shutdown(); + CompleteWrites(); + + _streamState._outboundStreamBuffer.Completed.GetAwaiter().GetResult(); _disposed = true; } } - public override ValueTask DisposeAsync() + public override async ValueTask DisposeAsync(CancellationToken cancellationToken) { if (!_disposed) { - Shutdown(); + CompleteWrites(); + + if (ReadStreamBuffer is StreamBuffer readStreamBuffer) + { + await ReadStreamBuffer.Completed.WaitAsync(cancellationToken).ConfigureAwait(false); + } + else + { + cancellationToken.ThrowIfCancellationRequested(); + } _disposed = true; } - - return default; } internal sealed class StreamState { public readonly long _streamId; - public StreamBuffer _outboundStreamBuffer; - public StreamBuffer? _inboundStreamBuffer; + public readonly StreamBuffer _outboundStreamBuffer; + public readonly StreamBuffer? _inboundStreamBuffer; public long _outboundErrorCode; public long _inboundErrorCode; diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/Interop/MsQuicStatusCodes.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/Interop/MsQuicStatusCodes.cs index 50bc0612ae3c7..e75096e4e9483 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/Interop/MsQuicStatusCodes.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/Interop/MsQuicStatusCodes.cs @@ -6,35 +6,35 @@ namespace System.Net.Quic.Implementations.MsQuic.Internal internal static partial class MsQuicStatusCodes { // TODO return better error messages here. - public static string GetError(uint status) - { - return status switch + public static string GetError(uint status) { - Success => "SUCCESS", - Pending => "PENDING", - Continue => "CONTINUE", - OutOfMemory => "OUT_OF_MEMORY", - InvalidParameter => "INVALID_PARAMETER", - InvalidState => "INVALID_STATE", - NotSupported => "NOT_SUPPORTED", - NotFound => "NOT_FOUND", - BufferTooSmall => "BUFFER_TOO_SMALL", - HandshakeFailure => "HANDSHAKE_FAILURE", - Aborted => "ABORTED", - AddressInUse => "ADDRESS_IN_USE", - ConnectionTimeout => "CONNECTION_TIMEOUT", - ConnectionIdle => "CONNECTION_IDLE", - HostUnreachable => "UNREACHABLE", - InternalError => "INTERNAL_ERROR", - ConnectionRefused => "CONNECTION_REFUSED", - ProtocolError => "PROTOCOL_ERROR", - VerNegError => "VER_NEG_ERROR", - TlsError => "TLS_ERROR", - UserCanceled => "USER_CANCELED", - AlpnNegotiationFailure => "ALPN_NEG_FAILURE", + return status switch + { + Success => "SUCCESS", + Pending => "PENDING", + Continue => "CONTINUE", + OutOfMemory => "OUT_OF_MEMORY", + InvalidParameter => "INVALID_PARAMETER", + InvalidState => "INVALID_STATE", + NotSupported => "NOT_SUPPORTED", + NotFound => "NOT_FOUND", + BufferTooSmall => "BUFFER_TOO_SMALL", + HandshakeFailure => "HANDSHAKE_FAILURE", + Aborted => "ABORTED", + AddressInUse => "ADDRESS_IN_USE", + ConnectionTimeout => "CONNECTION_TIMEOUT", + ConnectionIdle => "CONNECTION_IDLE", + HostUnreachable => "UNREACHABLE", + InternalError => "INTERNAL_ERROR", + ConnectionRefused => "CONNECTION_REFUSED", + ProtocolError => "PROTOCOL_ERROR", + VerNegError => "VER_NEG_ERROR", + TlsError => "TLS_ERROR", + UserCanceled => "USER_CANCELED", + AlpnNegotiationFailure => "ALPN_NEG_FAILURE", StreamLimit => "STREAM_LIMIT_REACHED", - _ => $"0x{status:X8}" - }; + _ => $"0x{status:X8}" + }; + } } - } } diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicStream.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicStream.cs index e796496bcf1f5..8ece072ec156e 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicStream.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicStream.cs @@ -2,9 +2,9 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Buffers; -using System.Collections.Generic; using System.Diagnostics; using System.Net.Quic.Implementations.MsQuic.Internal; +using System.Reflection; using System.Runtime.ExceptionServices; using System.Runtime.InteropServices; using System.Threading; @@ -41,11 +41,20 @@ private sealed class State public MsQuicConnection.State ConnectionState = null!; // set in ctor. public ReadState ReadState; + + // set when ReadState.Aborted: public long ReadErrorCode = -1; - public readonly List ReceiveQuicBuffers = new List(); - // Resettable completions to be used for multiple calls to receive. - public readonly ResettableCompletionSource ReceiveResettableCompletionSource = new ResettableCompletionSource(); + // filled when ReadState.BuffersAvailable: + public QuicBuffer[] ReceiveQuicBuffers = Array.Empty(); + public int ReceiveQuicBuffersCount; + public int ReceiveQuicBuffersTotalBytes; + + // set when ReadState.PendingRead: + public Memory ReceiveUserBuffer; + public CancellationTokenRegistration ReceiveCancellationRegistration; + public MsQuicStream? RootedReceiveStream; // roots the stream in the pinned state to prevent GC during an async read I/O. + public readonly ResettableCompletionSource ReceiveResettableCompletionSource = new ResettableCompletionSource(); public SendState SendState; public long SendErrorCode = -1; @@ -56,17 +65,13 @@ private sealed class State public int SendBufferMaxCount; public int SendBufferCount; - // Resettable completions to be used for multiple calls to send, start, and shutdown. - public readonly ResettableCompletionSource SendResettableCompletionSource = new ResettableCompletionSource(); - - public ShutdownWriteState ShutdownWriteState; + // Roots the stream in the pinned state to prevent GC during an async dispose. + public MsQuicStream? RootedDisposeStream; - // Set once writes have been shutdown. - public readonly TaskCompletionSource ShutdownWriteCompletionSource = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - - public ShutdownState ShutdownState; + // Resettable completions to be used for multiple calls to send, start. + public readonly ResettableCompletionSource SendResettableCompletionSource = new ResettableCompletionSource(); - // Set once stream have been shutdown. + // Set once both peers have fully shut down their side of the stream. public readonly TaskCompletionSource ShutdownCompletionSource = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); } @@ -96,7 +101,7 @@ internal MsQuicStream(MsQuicConnection.State connectionState, SafeMsQuicStreamHa { _stateHandle.Free(); throw new ObjectDisposedException(nameof(QuicConnection)); - } + } _state.ConnectionState = connectionState; @@ -128,6 +133,9 @@ internal MsQuicStream(MsQuicConnection.State connectionState, QUIC_STREAM_OPEN_F QuicExceptionHelpers.ThrowIfFailed(status, "Failed to open stream to peer."); + // TODO: StreamStart is blocking on another thread here. + // We should refactor this to use the ASYNC flag. + status = MsQuicApi.Api.StreamStartDelegate(_state.Handle, QUIC_STREAM_START_FLAGS.FAIL_BLOCKED); QuicExceptionHelpers.ThrowIfFailed(status, "Could not start stream."); } @@ -138,6 +146,9 @@ internal MsQuicStream(MsQuicConnection.State connectionState, QUIC_STREAM_OPEN_F throw; } + // TODO: our callback starts getting called as soon as we call StreamStart. + // Should this stuff be moved before that call? + if (!connectionState.TryAddStream(this)) { _state.Handle?.Dispose(); @@ -234,7 +245,7 @@ private async ValueTask HandleWriteStartState(Can { await _state.SendResettableCompletionSource.GetTypelessValueTask().ConfigureAwait(false); _started = true; - } + } // if token was already cancelled, this would execute syncronously CancellationTokenRegistration registration = cancellationToken.UnsafeRegister(static (s, token) => @@ -263,7 +274,7 @@ private async ValueTask HandleWriteStartState(Can { cancellationToken.ThrowIfCancellationRequested(); throw new OperationCanceledException(SR.net_quic_sending_aborted); - } + } else if (_state.SendState == SendState.ConnectionClosed) { throw GetConnectionAbortedException(_state); @@ -295,7 +306,7 @@ private void HandleWriteFailedState() } } - internal override async ValueTask ReadAsync(Memory destination, CancellationToken cancellationToken = default) + internal override ValueTask ReadAsync(Memory destination, CancellationToken cancellationToken = default) { ThrowIfDisposed(); @@ -309,204 +320,186 @@ internal override async ValueTask ReadAsync(Memory destination, Cance NetEventSource.Info(_state, $"[Stream#{_state.GetHashCode()}] reading into Memory of '{destination.Length}' bytes."); } + ReadState readState; + long abortError = -1; + bool canceledSynchronously = false; + lock (_state) { - if (_state.ReadState == ReadState.ReadsCompleted) - { - return 0; - } - else if (_state.ReadState == ReadState.Aborted) + readState = _state.ReadState; + abortError = _state.ReadErrorCode; + + if (readState != ReadState.PendingRead && cancellationToken.IsCancellationRequested) { - throw ThrowHelper.GetStreamAbortedException(_state.ReadErrorCode); + readState = ReadState.StreamAborted; + _state.ReadState = ReadState.StreamAborted; + canceledSynchronously = true; } - else if (_state.ReadState == ReadState.ConnectionClosed) + else if (readState == ReadState.None) { - throw GetConnectionAbortedException(_state); - } - } + Debug.Assert(_state.RootedReceiveStream is null); - using CancellationTokenRegistration registration = cancellationToken.UnsafeRegister(static (s, token) => - { - var state = (State)s!; - bool shouldComplete = false; - lock (state) - { - if (state.ReadState == ReadState.None) + _state.ReceiveUserBuffer = destination; + _state.RootedReceiveStream = this; + _state.ReadState = ReadState.PendingRead; + + if (cancellationToken.CanBeCanceled) { - shouldComplete = true; + _state.ReceiveCancellationRegistration = cancellationToken.UnsafeRegister(static (obj, token) => + { + var state = (State)obj!; + bool completePendingRead; + + lock (state) + { + completePendingRead = state.ReadState == ReadState.PendingRead; + state.RootedReceiveStream = null; + state.ReadState = ReadState.StreamAborted; + } + + if (completePendingRead) + { + state.ReceiveResettableCompletionSource.CompleteException(ExceptionDispatchInfo.SetCurrentStackTrace(new OperationCanceledException(token))); + } + }, _state); + } + else + { + _state.ReceiveCancellationRegistration = default; } - state.ReadState = ReadState.Aborted; - } - if (shouldComplete) - { - state.ReceiveResettableCompletionSource.CompleteException( - ExceptionDispatchInfo.SetCurrentStackTrace(new OperationCanceledException("Read was canceled", token))); + return _state.ReceiveResettableCompletionSource.GetValueTask(); } - }, _state); - - // TODO there could potentially be a perf gain by storing the buffer from the initial read - // This reduces the amount of async calls, however it makes it so MsQuic holds onto the buffers - // longer than it needs to. We will need to benchmark this. - int length = (int)await _state.ReceiveResettableCompletionSource.GetValueTask().ConfigureAwait(false); + else if (readState == ReadState.BuffersAvailable) + { + _state.ReadState = ReadState.None; - int actual = Math.Min(length, destination.Length); + int taken = CopyMsQuicBuffersToUserBuffer(_state.ReceiveQuicBuffers.AsSpan(0, _state.ReceiveQuicBuffersCount), destination.Span); + ReceiveComplete(taken); - static unsafe void CopyToBuffer(Span destinationBuffer, List sourceBuffers) - { - Span slicedBuffer = destinationBuffer; - for (int i = 0; i < sourceBuffers.Count; i++) - { - QuicBuffer nativeBuffer = sourceBuffers[i]; - int length = Math.Min((int)nativeBuffer.Length, slicedBuffer.Length); - new Span(nativeBuffer.Buffer, length).CopyTo(slicedBuffer); - if (length < nativeBuffer.Length) + if (taken != _state.ReceiveQuicBuffersTotalBytes) { - // The buffer passed in was larger that the received data, return - return; + // Need to re-enable receives because MsQuic will pause them when we don't consume the entire buffer. + EnableReceive(); } - slicedBuffer = slicedBuffer.Slice(length); - } - } - CopyToBuffer(destination.Span, _state.ReceiveQuicBuffers); - - lock (_state) - { - if (_state.ReadState == ReadState.IndividualReadComplete) - { - _state.ReceiveQuicBuffers.Clear(); - ReceiveComplete(actual); - EnableReceive(); - _state.ReadState = ReadState.None; + return new ValueTask(taken); } } - return actual; - } + Exception? ex = null; - // TODO do we want this to be a synchronization mechanism to cancel a pending read - // If so, we need to complete the read here as well. - internal override void AbortRead(long errorCode) - { - ThrowIfDisposed(); - - lock (_state) + switch (readState) { - _state.ReadState = ReadState.Aborted; + case ReadState.EndOfReadStream: + return new ValueTask(0); + case ReadState.PendingRead: + ex = new InvalidOperationException("Only one read is supported at a time."); + break; + case ReadState.StreamAborted: + ex = + canceledSynchronously ? new OperationCanceledException(cancellationToken) : // aborted by token being canceled before the async op started. + abortError == -1 ? new QuicOperationAbortedException() : // aborted by user via some other operation. + new QuicStreamAbortedException(abortError); // aborted by peer. + + break; + case ReadState.ConnectionAborted: + default: + Debug.Assert(readState == ReadState.ConnectionAborted, $"{nameof(ReadState)} of '{readState}' is unaccounted for in {nameof(ReadAsync)}."); + ex = GetConnectionAbortedException(_state); + break; } - StartShutdown(QUIC_STREAM_SHUTDOWN_FLAGS.ABORT_RECEIVE, errorCode); + return ValueTask.FromException(ExceptionDispatchInfo.SetCurrentStackTrace(ex!)); } - internal override void AbortWrite(long errorCode) + /// The number of bytes copied. + private static unsafe int CopyMsQuicBuffersToUserBuffer(ReadOnlySpan sourceBuffers, Span destinationBuffer) { - ThrowIfDisposed(); + Debug.Assert(sourceBuffers.Length != 0); - bool shouldComplete = false; + int originalDestinationLength = destinationBuffer.Length; + QuicBuffer nativeBuffer; + int takeLength = 0; + int i = 0; - lock (_state) + do { - if (_state.ShutdownWriteState == ShutdownWriteState.None) - { - _state.ShutdownWriteState = ShutdownWriteState.Canceled; - shouldComplete = true; - } - } + nativeBuffer = sourceBuffers[i]; + takeLength = Math.Min((int)nativeBuffer.Length, destinationBuffer.Length); - if (shouldComplete) - { - _state.ShutdownWriteCompletionSource.SetException( - ExceptionDispatchInfo.SetCurrentStackTrace(new QuicStreamAbortedException("Shutdown was aborted.", errorCode))); + new Span(nativeBuffer.Buffer, takeLength).CopyTo(destinationBuffer); + destinationBuffer = destinationBuffer.Slice(takeLength); } + while (destinationBuffer.Length != 0 && ++i < sourceBuffers.Length); - StartShutdown(QUIC_STREAM_SHUTDOWN_FLAGS.ABORT_SEND, errorCode); + return originalDestinationLength - destinationBuffer.Length; } - private void StartShutdown(QUIC_STREAM_SHUTDOWN_FLAGS flags, long errorCode) + // We don't wait for QUIC_STREAM_EVENT_SEND_SHUTDOWN_COMPLETE event here, + // because it is only sent to us once the peer has acknowledged the shutdown. + // Instead, this method acts more like shutdown(SD_SEND) in that it only "queues" + // the shutdown packet to be sent without any waiting for completion. + public override void CompleteWrites() { - uint status = MsQuicApi.Api.StreamShutdownDelegate(_state.Handle, flags, errorCode); - QuicExceptionHelpers.ThrowIfFailed(status, "StreamShutdown failed."); + ThrowIfDisposed(); + + // Error code is ignored for graceful shutdown. + StartShutdownOrAbort(QUIC_STREAM_SHUTDOWN_FLAGS.GRACEFUL, errorCode: 0); } - internal override async ValueTask ShutdownWriteCompleted(CancellationToken cancellationToken = default) + internal override void Abort(long errorCode, QuicAbortDirection abortDirection = QuicAbortDirection.Both) { ThrowIfDisposed(); + QUIC_STREAM_SHUTDOWN_FLAGS flags = QUIC_STREAM_SHUTDOWN_FLAGS.NONE; + bool completeWrites = false; + bool completeReads = false; + lock (_state) { - if (_state.ShutdownWriteState == ShutdownWriteState.ConnectionClosed) + if (abortDirection.HasFlag(QuicAbortDirection.Write)) { - throw GetConnectionAbortedException(_state); + completeWrites = _state.SendState is SendState.None or SendState.Pending; + _state.SendState = SendState.Aborted; + flags |= QUIC_STREAM_SHUTDOWN_FLAGS.ABORT_SEND; } - } - // TODO do anything to stop writes? - using CancellationTokenRegistration registration = cancellationToken.UnsafeRegister(static (s, token) => - { - var state = (State)s!; - bool shouldComplete = false; - lock (state) + if (abortDirection.HasFlag(QuicAbortDirection.Read)) { - if (state.ShutdownWriteState == ShutdownWriteState.None) - { - state.ShutdownWriteState = ShutdownWriteState.Canceled; // TODO: should we separate states for cancelling here vs calling Abort? - shouldComplete = true; - } + completeReads = _state.ReadState == ReadState.PendingRead; + _state.RootedReceiveStream = null; + _state.ReadState = ReadState.StreamAborted; + flags |= QUIC_STREAM_SHUTDOWN_FLAGS.ABORT_RECEIVE; } + } - if (shouldComplete) - { - state.ShutdownWriteCompletionSource.SetException( - ExceptionDispatchInfo.SetCurrentStackTrace(new OperationCanceledException("Wait for shutdown write was canceled", token))); - } - }, _state); - - await _state.ShutdownWriteCompletionSource.Task.ConfigureAwait(false); - } + if ((abortDirection & QuicAbortDirection.Immediate) == QuicAbortDirection.Immediate) + { + flags |= QUIC_STREAM_SHUTDOWN_FLAGS.IMMEDIATE; + } - internal override async ValueTask ShutdownCompleted(CancellationToken cancellationToken = default) - { - ThrowIfDisposed(); + StartShutdownOrAbort(flags, errorCode); - lock (_state) + if (completeWrites) { - if (_state.ShutdownState == ShutdownState.ConnectionClosed) - { - throw GetConnectionAbortedException(_state); - } + _state.SendResettableCompletionSource.CompleteException(ExceptionDispatchInfo.SetCurrentStackTrace(new QuicOperationAbortedException())); } - // TODO do anything to stop writes? - using CancellationTokenRegistration registration = cancellationToken.UnsafeRegister(static (s, token) => + if (completeReads) { - var state = (State)s!; - bool shouldComplete = false; - lock (state) - { - if (state.ShutdownState == ShutdownState.None) - { - state.ShutdownState = ShutdownState.Canceled; - shouldComplete = true; - } - } - - if (shouldComplete) - { - state.ShutdownWriteCompletionSource.SetException( - ExceptionDispatchInfo.SetCurrentStackTrace(new OperationCanceledException("Wait for shutdown was canceled", token))); - } - }, _state); - - await _state.ShutdownCompletionSource.Task.ConfigureAwait(false); + _state.ReceiveResettableCompletionSource.CompleteException(ExceptionDispatchInfo.SetCurrentStackTrace(new QuicOperationAbortedException())); + } } - internal override void Shutdown() + /// + /// For abortive flags, the error code sent to peer. Otherwise, ignored. + private void StartShutdownOrAbort(QUIC_STREAM_SHUTDOWN_FLAGS flags, long errorCode) { - ThrowIfDisposed(); + Debug.Assert(!_disposed); - // it is ok to send shutdown several times, MsQuic will ignore it - StartShutdown(QUIC_STREAM_SHUTDOWN_FLAGS.GRACEFUL, errorCode: 0); + uint status = MsQuicApi.Api.StreamShutdownDelegate(_state.Handle, flags, errorCode); + QuicExceptionHelpers.ThrowIfFailed(status, "StreamShutdown failed."); } // TODO consider removing sync-over-async with blocking calls. @@ -519,7 +512,7 @@ internal override int Read(Span buffer) int readLength = ReadAsync(new Memory(rentedBuffer, 0, buffer.Length)).AsTask().GetAwaiter().GetResult(); rentedBuffer.AsSpan(0, readLength).CopyTo(buffer); return readLength; - } + } finally { ArrayPool.Shared.Return(rentedBuffer); @@ -548,32 +541,61 @@ internal override Task FlushAsync(CancellationToken cancellationToken = default) return Task.CompletedTask; } - public override ValueTask DisposeAsync() - { - // TODO: perform a graceful shutdown and wait for completion? - - Dispose(true); - return default; - } + public override ValueTask DisposeAsync(CancellationToken cancellationToken) => + DisposeAsync(cancellationToken, async: true); public override void Dispose() { - Dispose(true); - GC.SuppressFinalize(this); + ValueTask t = DisposeAsync(cancellationToken: default, async: false); + Debug.Assert(t.IsCompleted); + t.GetAwaiter().GetResult(); } + // TODO: there's a bug here where the safe handle is no longer valid. + // This shouldn't happen because the safe handle *should be* rooted + // until after our disposal completes. ~MsQuicStream() { - Dispose(false); + DisposeAsyncThrowaway(this); + + static async void DisposeAsyncThrowaway(MsQuicStream stream) + { + await stream.DisposeAsync(cancellationToken: default, async: true).ConfigureAwait(false); + } } - private void Dispose(bool disposing) + private async ValueTask DisposeAsync(CancellationToken cancellationToken, bool async) { if (_disposed) { return; } + // MsQuic will ignore this call if it was already shutdown elsewhere. + // PERF TODO: update write loop to make it so we don't need to call this. it queues an event to the MsQuic thread pool. + StartShutdownOrAbort(QUIC_STREAM_SHUTDOWN_FLAGS.GRACEFUL, errorCode: 0); + + // MsQuic will continue sending us events, so we need to wait for shutdown + // completion (the final event) before freeing _stateHandle's GCHandle. + // If Abort() wasn't called with "immediate", this will wait for peer to shut down their write side. + + if (async) + { + _state.RootedDisposeStream = this; + try + { + await _state.ShutdownCompletionSource.Task.WaitAsync(cancellationToken).ConfigureAwait(false); + } + finally + { + _state.RootedDisposeStream = null; + } + } + else + { + _state.ShutdownCompletionSource.Task.GetAwaiter().GetResult(); + } + _disposed = true; _state.Handle.Dispose(); Marshal.FreeHGlobal(_state.SendQuicBuffers); @@ -581,6 +603,8 @@ private void Dispose(bool disposing) CleanupSendState(_state); _state.ConnectionState?.RemoveStream(this); + GC.SuppressFinalize(this); + if (NetEventSource.Log.IsEnabled()) { NetEventSource.Info(_state, $"[Stream#{_state.GetHashCode()}] disposed"); @@ -589,7 +613,8 @@ private void Dispose(bool disposing) private void EnableReceive() { - MsQuicApi.Api.StreamReceiveSetEnabledDelegate(_state.Handle, enabled: true); + uint status = MsQuicApi.Api.StreamReceiveSetEnabledDelegate(_state.Handle, enabled: true); + QuicExceptionHelpers.ThrowIfFailed(status, "StreamReceiveSetEnabled failed."); } private static uint NativeCallbackHandler( @@ -632,11 +657,6 @@ private static uint HandleEvent(State state, ref StreamEvent evt) // Peer has stopped receiving data, don't send anymore. case QUIC_STREAM_EVENT_TYPE.PEER_RECEIVE_ABORTED: return HandleEventPeerRecvAborted(state, ref evt); - // Occurs when shutdown is completed for the send side. - // This only happens for shutdown on sending, not receiving - // Receive shutdown can only be abortive. - case QUIC_STREAM_EVENT_TYPE.SEND_SHUTDOWN_COMPLETE: - return HandleEventSendShutdownComplete(state, ref evt); // Shutdown for both sending and receiving is completed. case QUIC_STREAM_EVENT_TYPE.SHUTDOWN_COMPLETE: return HandleEventShutdownComplete(state, ref evt); @@ -652,50 +672,91 @@ private static uint HandleEvent(State state, ref StreamEvent evt) private static unsafe uint HandleEventRecv(State state, ref StreamEvent evt) { - StreamEventDataReceive receiveEvent = evt.Data.Receive; - for (int i = 0; i < receiveEvent.BufferCount; i++) + ref StreamEventDataReceive receiveEvent = ref evt.Data.Receive; + + if (receiveEvent.BufferCount == 0) { - state.ReceiveQuicBuffers.Add(receiveEvent.Buffers[i]); + // This is a 0-length receive that happens once reads are finished (via abort or otherwise). + // State changes for this are handled elsewhere. + return MsQuicStatusCodes.Success; } - bool shouldComplete = false; + int readLength; + lock (state) { - if (state.ReadState == ReadState.None) - { - shouldComplete = true; - } - if (state.ReadState != ReadState.ConnectionClosed) + switch (state.ReadState) { - state.ReadState = ReadState.IndividualReadComplete; + case ReadState.None: + // ReadAsync() hasn't been called yet. Stash the buffer so the next ReadAsync call completes synchronously. + + if ((uint)state.ReceiveQuicBuffers.Length < receiveEvent.BufferCount) + { + QuicBuffer[] oldReceiveBuffers = state.ReceiveQuicBuffers; + state.ReceiveQuicBuffers = ArrayPool.Shared.Rent((int)receiveEvent.BufferCount); + + if (oldReceiveBuffers.Length != 0) // don't return Array.Empty. + { + ArrayPool.Shared.Return(oldReceiveBuffers); + } + } + + for (uint i = 0; i < receiveEvent.BufferCount; ++i) + { + state.ReceiveQuicBuffers[i] = receiveEvent.Buffers[i]; + } + + state.ReceiveQuicBuffersCount = (int)receiveEvent.BufferCount; + state.ReceiveQuicBuffersTotalBytes = checked((int)receiveEvent.TotalBufferLength); + state.ReadState = ReadState.BuffersAvailable; + return MsQuicStatusCodes.Pending; + case ReadState.PendingRead: + // There is a pending ReadAsync(). + + state.ReceiveCancellationRegistration.Unregister(); + state.RootedReceiveStream = null; + state.ReadState = ReadState.None; + + readLength = CopyMsQuicBuffersToUserBuffer(new ReadOnlySpan(receiveEvent.Buffers, (int)receiveEvent.BufferCount), state.ReceiveUserBuffer.Span); + break; + default: + Debug.Assert(state.ReadState is ReadState.StreamAborted or ReadState.ConnectionAborted, $"Unexpected {nameof(ReadState)} '{state.ReadState}' in {nameof(HandleEventRecv)}."); + + // There was a race between a user aborting the read stream and the callback being ran. + // This will eat any received data. + return MsQuicStatusCodes.Success; } } - if (shouldComplete) - { - state.ReceiveResettableCompletionSource.Complete((uint)receiveEvent.TotalBufferLength); - } + // We're completing a pending read. - return MsQuicStatusCodes.Pending; + state.ReceiveResettableCompletionSource.Complete(readLength); + + // Returning Success when the entire buffer hasn't been consumed will cause MsQuic to disable further receive events until EnableReceive() is called. + // Returning Continue will cause a second receive event to fire immediately after this returns, but allows MsQuic to clean up its buffers. + + uint ret = (uint)readLength == receiveEvent.TotalBufferLength + ? MsQuicStatusCodes.Success + : MsQuicStatusCodes.Continue; + + receiveEvent.TotalBufferLength = (uint)readLength; + return ret; } private static uint HandleEventPeerRecvAborted(State state, ref StreamEvent evt) { - bool shouldComplete = false; + bool shouldComplete; + lock (state) { - if (state.SendState == SendState.None || state.SendState == SendState.Pending) - { - shouldComplete = true; - } + shouldComplete = state.SendState == SendState.None || state.SendState == SendState.Pending; state.SendState = SendState.Aborted; - state.SendErrorCode = (long)evt.Data.PeerSendAborted.ErrorCode; + state.SendErrorCode = evt.Data.PeerSendAborted.ErrorCode; } if (shouldComplete) { - state.SendResettableCompletionSource.CompleteException( - ExceptionDispatchInfo.SetCurrentStackTrace(new QuicStreamAbortedException(state.SendErrorCode))); + state.SendResettableCompletionSource.CompleteException(new QuicStreamAbortedException(state.SendErrorCode)); } return MsQuicStatusCodes.Success; @@ -721,82 +782,9 @@ private static uint HandleEventStartComplete(State state) return MsQuicStatusCodes.Success; } - private static uint HandleEventSendShutdownComplete(State state, ref StreamEvent evt) - { - bool shouldComplete = false; - lock (state) - { - if (state.ShutdownWriteState == ShutdownWriteState.None) - { - state.ShutdownWriteState = ShutdownWriteState.Finished; - shouldComplete = true; - } - } - - if (shouldComplete) - { - state.ShutdownWriteCompletionSource.SetResult(); - } - - return MsQuicStatusCodes.Success; - } - private static uint HandleEventShutdownComplete(State state, ref StreamEvent evt) { - StreamEventDataShutdownComplete shutdownCompleteEvent = evt.Data.ShutdownComplete; - - if (shutdownCompleteEvent.ConnectionShutdown != 0) - { - return HandleEventConnectionClose(state); - } - - bool shouldReadComplete = false; - bool shouldShutdownWriteComplete = false; - bool shouldShutdownComplete = false; - - lock (state) - { - // This event won't occur within the middle of a receive. - if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(state, $"[Stream#{state.GetHashCode()}] completing resettable event source."); - - if (state.ReadState == ReadState.None) - { - shouldReadComplete = true; - } - - if (state.ReadState != ReadState.ConnectionClosed) - { - state.ReadState = ReadState.ReadsCompleted; - } - - if (state.ShutdownWriteState == ShutdownWriteState.None) - { - state.ShutdownWriteState = ShutdownWriteState.Finished; - shouldShutdownWriteComplete = true; - } - - if (state.ShutdownState == ShutdownState.None) - { - state.ShutdownState = ShutdownState.Finished; - shouldShutdownComplete = true; - } - } - - if (shouldReadComplete) - { - state.ReceiveResettableCompletionSource.Complete(0); - } - - if (shouldShutdownWriteComplete) - { - state.ShutdownWriteCompletionSource.SetResult(); - } - - if (shouldShutdownComplete) - { - state.ShutdownCompletionSource.SetResult(); - } - + state.ShutdownCompletionSource.TrySetResult(); return MsQuicStatusCodes.Success; } @@ -809,7 +797,7 @@ private static uint HandleEventPeerSendAborted(State state, ref StreamEvent evt) { shouldComplete = true; } - state.ReadState = ReadState.Aborted; + state.ReadState = ReadState.StreamAborted; state.ReadErrorCode = (long)evt.Data.PeerSendAborted.ErrorCode; } @@ -824,25 +812,27 @@ private static uint HandleEventPeerSendAborted(State state, ref StreamEvent evt) private static uint HandleEventPeerSendShutdown(State state) { - bool shouldComplete = false; + bool completePendingRead = false; lock (state) { // This event won't occur within the middle of a receive. if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(state, $"[Stream#{state.GetHashCode()}] completing resettable event source."); - if (state.ReadState == ReadState.None) + + if (state.ReadState == ReadState.PendingRead) { - shouldComplete = true; + completePendingRead = true; + state.RootedReceiveStream = null; + state.ReadState = ReadState.EndOfReadStream; } - - if (state.ReadState != ReadState.ConnectionClosed) + else if (state.ReadState == ReadState.None) { - state.ReadState = ReadState.ReadsCompleted; + state.ReadState = ReadState.EndOfReadStream; } } - if (shouldComplete) + if (completePendingRead) { state.ReceiveResettableCompletionSource.Complete(0); } @@ -868,7 +858,7 @@ private static uint HandleEventSendComplete(State state, ref StreamEvent evt) if (canceled) { state.SendState = SendState.Aborted; - } + } } if (complete) @@ -877,8 +867,8 @@ private static uint HandleEventSendComplete(State state, ref StreamEvent evt) if (!canceled) { - state.SendResettableCompletionSource.Complete(MsQuicStatusCodes.Success); - } + state.SendResettableCompletionSource.Complete(MsQuicStatusCodes.Success); + } else { state.SendResettableCompletionSource.CompleteException( @@ -919,7 +909,7 @@ private unsafe ValueTask SendReadOnlyMemoryAsync( if ((flags & QUIC_SEND_FLAGS.FIN) == QUIC_SEND_FLAGS.FIN) { // Start graceful shutdown sequence if passed in the fin flag and there is an empty buffer. - StartShutdown(QUIC_STREAM_SHUTDOWN_FLAGS.GRACEFUL, errorCode: 0); + StartShutdownOrAbort(QUIC_STREAM_SHUTDOWN_FLAGS.GRACEFUL, errorCode: 0); } return default; } @@ -974,7 +964,7 @@ private unsafe ValueTask SendReadOnlySequenceAsync( if ((flags & QUIC_SEND_FLAGS.FIN) == QUIC_SEND_FLAGS.FIN) { // Start graceful shutdown sequence if passed in the fin flag and there is an empty buffer. - StartShutdown(QUIC_STREAM_SHUTDOWN_FLAGS.GRACEFUL, errorCode: 0); + StartShutdownOrAbort(QUIC_STREAM_SHUTDOWN_FLAGS.GRACEFUL, errorCode: 0); } return default; } @@ -1043,7 +1033,7 @@ private unsafe ValueTask SendReadOnlyMemoryListAsync( if ((flags & QUIC_SEND_FLAGS.FIN) == QUIC_SEND_FLAGS.FIN) { // Start graceful shutdown sequence if passed in the fin flag and there is an empty buffer. - StartShutdown(QUIC_STREAM_SHUTDOWN_FLAGS.GRACEFUL, errorCode: 0); + StartShutdownOrAbort(QUIC_STREAM_SHUTDOWN_FLAGS.GRACEFUL, errorCode: 0); } return default; } @@ -1117,6 +1107,7 @@ private void ThrowIfDisposed() private static uint HandleEventConnectionClose(State state) { long errorCode = state.ConnectionState.AbortErrorCode; + if (NetEventSource.Log.IsEnabled()) { NetEventSource.Info(state, $"[Stream#{state.GetHashCode()}] handling Connection#{state.ConnectionState.GetHashCode()} close" + @@ -1125,34 +1116,19 @@ private static uint HandleEventConnectionClose(State state) bool shouldCompleteRead = false; bool shouldCompleteSend = false; - bool shouldCompleteShutdownWrite = false; bool shouldCompleteShutdown = false; lock (state) { - if (state.ReadState == ReadState.None) - { - shouldCompleteRead = true; - } - state.ReadState = ReadState.ConnectionClosed; - - if (state.SendState == SendState.None || state.SendState == SendState.Pending) - { - shouldCompleteSend = true; - } - state.SendState = SendState.ConnectionClosed; + shouldCompleteRead = state.ReadState == ReadState.PendingRead; + shouldCompleteSend = state.SendState is SendState.None or SendState.Pending; - if (state.ShutdownWriteState == ShutdownWriteState.None) + if (state.ReadState is not ReadState.EndOfReadStream or ReadState.StreamAborted) { - shouldCompleteShutdownWrite = true; + state.ReadState = ReadState.ConnectionAborted; } - state.ShutdownWriteState = ShutdownWriteState.ConnectionClosed; - if (state.ShutdownState == ShutdownState.None) - { - shouldCompleteShutdown = true; - } - state.ShutdownState = ShutdownState.ConnectionClosed; + state.SendState = SendState.ConnectionClosed; } if (shouldCompleteRead) @@ -1167,12 +1143,6 @@ private static uint HandleEventConnectionClose(State state) ExceptionDispatchInfo.SetCurrentStackTrace(GetConnectionAbortedException(state))); } - if (shouldCompleteShutdownWrite) - { - state.ShutdownWriteCompletionSource.SetException( - ExceptionDispatchInfo.SetCurrentStackTrace(GetConnectionAbortedException(state))); - } - if (shouldCompleteShutdown) { state.ShutdownCompletionSource.SetException( @@ -1188,45 +1158,34 @@ private static Exception GetConnectionAbortedException(State state) => private enum ReadState { /// - /// The stream is open, but there is no data available. + /// The stream is open, but there is no pending operation and no data available. /// None, /// - /// Data is available in . + /// There is a pending operation on the stream. /// - IndividualReadComplete, + PendingRead, /// - /// The peer has gracefully shutdown their sends / our receives; the stream's reads are complete. + /// There is data available. /// - ReadsCompleted, + BuffersAvailable, /// - /// User has aborted the stream, either via a cancellation token on ReadAsync(), or via AbortRead(). + /// The peer has gracefully shutdown their sends / our receives; the stream's reads are complete. /// - Aborted, + EndOfReadStream, /// - /// Connection was closed, either by user or by the peer. + /// The stream has been aborted, either by user or by peer. /// - ConnectionClosed - } + StreamAborted, - private enum ShutdownWriteState - { - None, - Canceled, - Finished, - ConnectionClosed - } - - private enum ShutdownState - { - None, - Canceled, - Finished, - ConnectionClosed + /// + /// The connection has been aborted, either by user or by peer. + /// + ConnectionAborted } private enum SendState diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/QuicStreamProvider.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/QuicStreamProvider.cs index 2be277a61252a..e97cfc1613cbe 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/QuicStreamProvider.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/QuicStreamProvider.cs @@ -7,7 +7,7 @@ namespace System.Net.Quic.Implementations { - internal abstract class QuicStreamProvider : IDisposable, IAsyncDisposable + internal abstract class QuicStreamProvider { internal abstract long StreamId { get; } @@ -17,9 +17,7 @@ internal abstract class QuicStreamProvider : IDisposable, IAsyncDisposable internal abstract ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default); - internal abstract void AbortRead(long errorCode); - - internal abstract void AbortWrite(long errorCode); + internal abstract void Abort(long errorCode, QuicAbortDirection abortDirection = QuicAbortDirection.Both); internal abstract bool CanWrite { get; } @@ -37,18 +35,14 @@ internal abstract class QuicStreamProvider : IDisposable, IAsyncDisposable internal abstract ValueTask WriteAsync(ReadOnlyMemory> buffers, bool endStream, CancellationToken cancellationToken = default); - internal abstract ValueTask ShutdownWriteCompleted(CancellationToken cancellationToken = default); - - internal abstract ValueTask ShutdownCompleted(CancellationToken cancellationToken = default); - - internal abstract void Shutdown(); - internal abstract void Flush(); internal abstract Task FlushAsync(CancellationToken cancellationToken); + public abstract void CompleteWrites(); + public abstract void Dispose(); - public abstract ValueTask DisposeAsync(); + public abstract ValueTask DisposeAsync(CancellationToken cancellationToken); } } diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicAbortDirection.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicAbortDirection.cs new file mode 100644 index 0000000000000..788f6db7d01ff --- /dev/null +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicAbortDirection.cs @@ -0,0 +1,29 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace System.Net.Quic +{ + [Flags] + public enum QuicAbortDirection + { + /// + /// Aborts the read direction of the stream. + /// + Read = 1, + + /// + /// Aborts the write direction of the stream. + /// + Write = 2, + + /// + /// Aborts both the read and write direction of the stream. + /// + Both = Read | Write, + + /// + /// Aborts both the read and write direction of the stream, without waiting for the peer to acknowledge the shutdown. + /// + Immediate = Both | 4 + } +} diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicStream.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicStream.cs index e1724eee53575..a7ddd4d181205 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicStream.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicStream.cs @@ -85,9 +85,17 @@ public override Task WriteAsync(byte[] buffer, int offset, int count, Cancellati public override Task FlushAsync(CancellationToken cancellationToken) => _provider.FlushAsync(cancellationToken); - public void AbortRead(long errorCode) => _provider.AbortRead(errorCode); + /// + /// Completes the write direction of the stream, notifying the peer of end-of-stream. + /// + public void CompleteWrites() => _provider.CompleteWrites(); - public void AbortWrite(long errorCode) => _provider.AbortWrite(errorCode); + /// + /// Aborts the . + /// + /// The error code to abort with. + /// The direction of the abort. + public void Abort(long errorCode, QuicAbortDirection abortDirection = QuicAbortDirection.Immediate) => _provider.Abort(errorCode, abortDirection); public ValueTask WriteAsync(ReadOnlyMemory buffer, bool endStream, CancellationToken cancellationToken = default) => _provider.WriteAsync(buffer, endStream, cancellationToken); @@ -99,12 +107,7 @@ public override Task WriteAsync(byte[] buffer, int offset, int count, Cancellati public ValueTask WriteAsync(ReadOnlyMemory> buffers, bool endStream, CancellationToken cancellationToken = default) => _provider.WriteAsync(buffers, endStream, cancellationToken); - public ValueTask ShutdownWriteCompleted(CancellationToken cancellationToken = default) => _provider.ShutdownWriteCompleted(cancellationToken); - - public ValueTask ShutdownCompleted(CancellationToken cancellationToken = default) => _provider.ShutdownCompleted(cancellationToken); - - public void Shutdown() => _provider.Shutdown(); - + /// protected override void Dispose(bool disposing) { if (disposing) @@ -112,5 +115,19 @@ protected override void Dispose(bool disposing) _provider.Dispose(); } } + + /// + public override ValueTask DisposeAsync() => CloseAsync(); + + /// + /// Shuts down and closes the , leaving it in a disposed state. + /// + /// If triggered, an will be thrown and the stream will be left undisposed. + /// A representing the asynchronous closure of the . + /// + /// When the stream has been been aborted with , this will complete independent of the peer. + /// Otherwise, this will wait for the peer to complete their write side (gracefully or abortive) and drain any bytes received in the mean time. + /// + public ValueTask CloseAsync(CancellationToken cancellationToken = default) => _provider.DisposeAsync(cancellationToken); } } diff --git a/src/libraries/System.Net.Quic/tests/FunctionalTests/MsQuicTests.cs b/src/libraries/System.Net.Quic/tests/FunctionalTests/MsQuicTests.cs index 277265a17a53a..d19e339d58792 100644 --- a/src/libraries/System.Net.Quic/tests/FunctionalTests/MsQuicTests.cs +++ b/src/libraries/System.Net.Quic/tests/FunctionalTests/MsQuicTests.cs @@ -230,9 +230,6 @@ await RunClientServer( break; } } - - stream.Shutdown(); - await stream.ShutdownCompleted(); }, async serverConnection => { @@ -248,9 +245,6 @@ await RunClientServer( int expectedTotalBytes = writes.SelectMany(x => x).Sum(); Assert.Equal(expectedTotalBytes, totalBytes); - - stream.Shutdown(); - await stream.ShutdownCompleted(); }); } @@ -434,8 +428,6 @@ await RunClientServer( await stream.WriteAsync(data[pos..(pos + writeSize)]); } await stream.WriteAsync(Memory.Empty, endStream: true); - - await stream.ShutdownCompleted(); }, clientFunction: async connection => { @@ -451,8 +443,6 @@ await RunClientServer( int bytesRead = await ReadAll(stream, buffer); Assert.Equal(data.Length, bytesRead); AssertArrayEqual(data, buffer); - - await stream.ShutdownCompleted(); } ); } diff --git a/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicStreamTests.cs b/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicStreamTests.cs index 4eee9b459d9fb..e1693565d32d9 100644 --- a/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicStreamTests.cs +++ b/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicStreamTests.cs @@ -1,9 +1,9 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. -using System; using System.Buffers; using System.Collections.Generic; +using System.Diagnostics; using System.Linq; using System.Text; using System.Threading; @@ -20,12 +20,10 @@ public abstract class QuicStreamTests : QuicTestBase [Fact] public async Task BasicTest() { - await RunClientServer( + await RunBidirectionalClientServer( iterations: 100, - serverFunction: async connection => + serverFunction: async stream => { - await using QuicStream stream = await connection.AcceptStreamAsync(); - byte[] buffer = new byte[s_data.Length]; int bytesRead = await ReadAll(stream, buffer); @@ -33,12 +31,9 @@ await RunClientServer( Assert.Equal(s_data, buffer); await stream.WriteAsync(s_data, endStream: true); - await stream.ShutdownCompleted(); }, - clientFunction: async connection => + clientFunction: async stream => { - await using QuicStream stream = connection.OpenBidirectionalStream(); - await stream.WriteAsync(s_data, endStream: true); byte[] buffer = new byte[s_data.Length]; @@ -46,8 +41,6 @@ await RunClientServer( Assert.Equal(s_data.Length, bytesRead); Assert.Equal(s_data, buffer); - - await stream.ShutdownCompleted(); } ); } @@ -65,12 +58,10 @@ public async Task MultipleReadsAndWrites() m = m[s_data.Length..]; } - await RunClientServer( + await RunBidirectionalClientServer( iterations: 100, - serverFunction: async connection => + serverFunction: async stream => { - await using QuicStream stream = await connection.AcceptStreamAsync(); - byte[] buffer = new byte[expectedBytesCount]; int bytesRead = await ReadAll(stream, buffer); Assert.Equal(expectedBytesCount, bytesRead); @@ -81,13 +72,9 @@ await RunClientServer( await stream.WriteAsync(s_data); } await stream.WriteAsync(Memory.Empty, endStream: true); - - await stream.ShutdownCompleted(); }, - clientFunction: async connection => + clientFunction: async stream => { - await using QuicStream stream = connection.OpenBidirectionalStream(); - for (int i = 0; i < sendCount; i++) { await stream.WriteAsync(s_data); @@ -98,8 +85,6 @@ await RunClientServer( int bytesRead = await ReadAll(stream, buffer); Assert.Equal(expectedBytesCount, bytesRead); Assert.Equal(expected, buffer); - - await stream.ShutdownCompleted(); } ); } @@ -126,9 +111,6 @@ await RunClientServer( await stream.WriteAsync(s_data, endStream: true); await stream2.WriteAsync(s_data, endStream: true); - - await stream.ShutdownCompleted(); - await stream2.ShutdownCompleted(); }, clientFunction: async connection => { @@ -148,9 +130,6 @@ await RunClientServer( int bytesRead2 = await ReadAll(stream2, buffer2); Assert.Equal(s_data.Length, bytesRead2); Assert.Equal(s_data, buffer2); - - await stream.ShutdownCompleted(); - await stream2.ShutdownCompleted(); } ); } @@ -158,20 +137,24 @@ await RunClientServer( [Fact] public async Task GetStreamIdWithoutStartWorks() { - using QuicListener listener = CreateQuicListener(); - using QuicConnection clientConnection = CreateQuicConnection(listener.ListenEndPoint); - - ValueTask clientTask = clientConnection.ConnectAsync(); - using QuicConnection serverConnection = await listener.AcceptConnectionAsync(); - await clientTask; + using SemaphoreSlim sem = new SemaphoreSlim(0); + await RunClientServer( + async clientConnection => + { + await using QuicStream clientStream = clientConnection.OpenBidirectionalStream(); + Assert.Equal(0, clientStream.StreamId); + sem.Release(); - using QuicStream clientStream = clientConnection.OpenBidirectionalStream(); - Assert.Equal(0, clientStream.StreamId); + }, + async serverConnection => + { + await sem.WaitAsync(); - // TODO: stream that is opened by client but left unaccepted by server may cause AccessViolationException in its Finalizer - // explicitly closing the connections seems to help, but the problem should still be investigated, we should have a meaningful - // exception instead of AccessViolationException - await clientConnection.CloseAsync(0); + // TODO: stream that is opened by client but left unaccepted by server may cause AccessViolationException in its Finalizer + // explicitly closing the connections seems to help, but the problem should still be investigated, we should have a meaningful + // exception instead of AccessViolationException + await serverConnection.CloseAsync(0); + }); } [Fact] @@ -181,12 +164,10 @@ public async Task LargeDataSentAndReceived() const int NumberOfWrites = 256; // total sent = 16M byte[] data = Enumerable.Range(0, writeSize * NumberOfWrites).Select(x => (byte)x).ToArray(); - await RunClientServer( + await RunBidirectionalClientServer( iterations: 5, - serverFunction: async connection => + serverFunction: async stream => { - await using QuicStream stream = await connection.AcceptStreamAsync(); - byte[] buffer = new byte[data.Length]; int bytesRead = await ReadAll(stream, buffer); Assert.Equal(data.Length, bytesRead); @@ -197,13 +178,9 @@ await RunClientServer( await stream.WriteAsync(data[pos..(pos + writeSize)]); } await stream.WriteAsync(Memory.Empty, endStream: true); - - await stream.ShutdownCompleted(); }, - clientFunction: async connection => + clientFunction: async stream => { - await using QuicStream stream = connection.OpenBidirectionalStream(); - for (int pos = 0; pos < data.Length; pos += writeSize) { await stream.WriteAsync(data[pos..(pos + writeSize)]); @@ -214,8 +191,6 @@ await RunClientServer( int bytesRead = await ReadAll(stream, buffer); Assert.Equal(data.Length, bytesRead); AssertArrayEqual(data, buffer); - - await stream.ShutdownCompleted(); } ); } @@ -292,9 +267,6 @@ private static async Task TestBidirectionalStream(QuicStream s1, QuicStream s2) await SendAndReceiveEOFAsync(s1, s2); await SendAndReceiveEOFAsync(s2, s1); - - await s1.ShutdownCompleted(); - await s2.ShutdownCompleted(); } private static async Task TestUnidirectionalStream(QuicStream s1, QuicStream s2) @@ -309,9 +281,6 @@ private static async Task TestUnidirectionalStream(QuicStream s1, QuicStream s2) await SendAndReceiveDataAsync(s_data, s1, s2); await SendAndReceiveEOFAsync(s1, s2); - - await s1.ShutdownCompleted(); - await s2.ShutdownCompleted(); } private static async Task SendAndReceiveDataAsync(byte[] data, QuicStream s1, QuicStream s2) @@ -355,11 +324,9 @@ public async Task ReadWrite_Random_Success(int readSize, int writeSize) byte[] testBuffer = new byte[8192]; Random.Shared.NextBytes(testBuffer); - await RunClientServer( - async clientConnection => + await RunUnidirectionalClientServer( + async clientStream => { - await using QuicStream clientStream = clientConnection.OpenUnidirectionalStream(); - ReadOnlyMemory sendBuffer = testBuffer; while (sendBuffer.Length != 0) { @@ -368,17 +335,14 @@ await RunClientServer( sendBuffer = sendBuffer.Slice(chunk.Length); } - await clientStream.WriteAsync(Memory.Empty, endStream: true); - await clientStream.ShutdownCompleted(); + clientStream.CompleteWrites(); }, - async serverConnection => + async serverStream => { - await using QuicStream serverStream = await serverConnection.AcceptStreamAsync(); - byte[] receiveBuffer = new byte[testBuffer.Length]; int totalBytesRead = 0; - while (true) // TODO: if you don't read until 0-byte read, ShutdownCompleted sometimes may not trigger - why? + while (true) { Memory recieveChunkBuffer = receiveBuffer.AsMemory(totalBytesRead, Math.Min(receiveBuffer.Length - totalBytesRead, readSize)); int bytesRead = await serverStream.ReadAsync(recieveChunkBuffer); @@ -392,8 +356,6 @@ await RunClientServer( Assert.Equal(testBuffer.Length, totalBytesRead); AssertArrayEqual(testBuffer, receiveBuffer); - - await serverStream.ShutdownCompleted(); }); } @@ -408,32 +370,116 @@ from writeSize in sizes } [Fact] - public async Task Read_StreamAborted_Throws() + public async Task Read_WriteAborted_Throws() { const int ExpectedErrorCode = 0xfffffff; - await Task.Run(async () => - { - using QuicListener listener = CreateQuicListener(); - ValueTask serverConnectionTask = listener.AcceptConnectionAsync(); + using SemaphoreSlim sem = new SemaphoreSlim(0); + + await RunBidirectionalClientServer( + async clientStream => + { + await clientStream.WriteAsync(new byte[1]); + + await sem.WaitAsync(); + clientStream.Abort(ExpectedErrorCode, QuicAbortDirection.Write); + }, + async serverStream => + { + int received = await serverStream.ReadAsync(new byte[1]); + Assert.Equal(1, received); + + sem.Release(); + + byte[] buffer = new byte[100]; + QuicStreamAbortedException ex = await Assert.ThrowsAsync(() => serverStream.ReadAsync(buffer).AsTask()); + Assert.Equal(ExpectedErrorCode, ex.ErrorCode); + }); + } + + [Fact] + public async Task Read_SynchronousCompletion_Success() + { + using SemaphoreSlim sem = new SemaphoreSlim(0); + + await RunBidirectionalClientServer( + async clientStream => + { + await clientStream.WriteAsync(new byte[1]); + sem.Release(); + clientStream.CompleteWrites(); + sem.Release(); + }, + async serverStream => + { + await sem.WaitAsync(); + await Task.Delay(1000); + + ValueTask task = serverStream.ReadAsync(new byte[1]); + Assert.True(task.IsCompleted); + + int received = await task; + Assert.Equal(1, received); + + await sem.WaitAsync(); + await Task.Delay(1000); + + task = serverStream.ReadAsync(new byte[1]); + Assert.True(task.IsCompleted); + + received = await task; + Assert.Equal(0, received); + }); + } + + [Fact] + public async Task ReadOutstanding_ReadAborted_Throws() + { + const int ExpectedErrorCode = 0xfffffff; + + using SemaphoreSlim sem = new SemaphoreSlim(0); + + await RunBidirectionalClientServer( + async clientStream => + { + await sem.WaitAsync(); + }, + async serverStream => + { + Task exTask = Assert.ThrowsAsync(() => serverStream.ReadAsync(new byte[1]).AsTask()); + + Assert.False(exTask.IsCompleted); + + serverStream.Abort(ExpectedErrorCode, QuicAbortDirection.Read); + + await exTask; - using QuicConnection clientConnection = CreateQuicConnection(listener.ListenEndPoint); - await clientConnection.ConnectAsync(); + sem.Release(); + }); + } - using QuicConnection serverConnection = await serverConnectionTask; + [Fact] + public async Task Read_ConcurrentReads_Throws() + { + using SemaphoreSlim sem = new SemaphoreSlim(0); - await using QuicStream clientStream = clientConnection.OpenBidirectionalStream(); - await clientStream.WriteAsync(new byte[1]); + await RunBidirectionalClientServer( + async clientStream => + { + await sem.WaitAsync(); + }, + async serverStream => + { + ValueTask readTask = serverStream.ReadAsync(new byte[1]); + Assert.False(readTask.IsCompleted); - await using QuicStream serverStream = await serverConnection.AcceptStreamAsync(); - await serverStream.ReadAsync(new byte[1]); + await Assert.ThrowsAsync(async () => await serverStream.ReadAsync(new byte[1])); - clientStream.AbortWrite(ExpectedErrorCode); + sem.Release(); - byte[] buffer = new byte[100]; - QuicStreamAbortedException ex = await Assert.ThrowsAsync(() => serverStream.ReadAsync(buffer).AsTask()); - Assert.Equal(ExpectedErrorCode, ex.ErrorCode); - }).WaitAsync(TimeSpan.FromSeconds(15)); + int res = await readTask; + Assert.Equal(0, res); + }); } [ActiveIssue("https://github.com/dotnet/runtime/issues/53530")] @@ -442,24 +488,18 @@ public async Task StreamAbortedWithoutWriting_ReadThrows() { long expectedErrorCode = 1234; - await RunClientServer( - clientFunction: async connection => + await RunUnidirectionalClientServer( + clientStream => { - await using QuicStream stream = connection.OpenUnidirectionalStream(); - stream.AbortWrite(expectedErrorCode); - - await stream.ShutdownCompleted(); + clientStream.Abort(expectedErrorCode, QuicAbortDirection.Write); + return Task.CompletedTask; }, - serverFunction: async connection => + async serverStream => { - await using QuicStream stream = await connection.AcceptStreamAsync(); - byte[] buffer = new byte[1]; - QuicStreamAbortedException ex = await Assert.ThrowsAsync(() => ReadAll(stream, buffer)); + QuicStreamAbortedException ex = await Assert.ThrowsAsync(() => ReadAll(serverStream, buffer)); Assert.Equal(expectedErrorCode, ex.ErrorCode); - - await stream.ShutdownCompleted(); } ); } @@ -469,39 +509,31 @@ public async Task WritePreCanceled_Throws() { long expectedErrorCode = 1234; - await RunClientServer( - clientFunction: async connection => + await RunUnidirectionalClientServer( + async clientStream => { - await using QuicStream stream = connection.OpenUnidirectionalStream(); - CancellationTokenSource cts = new CancellationTokenSource(); cts.Cancel(); - await Assert.ThrowsAsync(() => stream.WriteAsync(new byte[1], cts.Token).AsTask()); + await Assert.ThrowsAsync(() => clientStream.WriteAsync(new byte[1], cts.Token).AsTask()); // next write would also throw - await Assert.ThrowsAsync(() => stream.WriteAsync(new byte[1]).AsTask()); + await Assert.ThrowsAsync(() => clientStream.WriteAsync(new byte[1]).AsTask()); // manual write abort is still required - stream.AbortWrite(expectedErrorCode); - - await stream.ShutdownCompleted(); + clientStream.Abort(expectedErrorCode, QuicAbortDirection.Write); }, - serverFunction: async connection => + async serverStream => { - await using QuicStream stream = await connection.AcceptStreamAsync(); - byte[] buffer = new byte[1024 * 1024]; // TODO: it should always throw QuicStreamAbortedException, but sometimes it does not https://github.com/dotnet/runtime/issues/53530 //QuicStreamAbortedException ex = await Assert.ThrowsAsync(() => ReadAll(stream, buffer)); try { - await ReadAll(stream, buffer); + await ReadAll(serverStream, buffer); } catch (QuicStreamAbortedException) { } - - await stream.ShutdownCompleted(); } ); } @@ -511,11 +543,9 @@ public async Task WriteCanceled_NextWriteThrows() { long expectedErrorCode = 1234; - await RunClientServer( - clientFunction: async connection => + await RunUnidirectionalClientServer( + async clientStream => { - await using QuicStream stream = connection.OpenUnidirectionalStream(); - CancellationTokenSource cts = new CancellationTokenSource(500); async Task WriteUntilCanceled() @@ -523,7 +553,7 @@ async Task WriteUntilCanceled() var buffer = new byte[64 * 1024]; while (true) { - await stream.WriteAsync(buffer, cancellationToken: cts.Token); + await clientStream.WriteAsync(buffer, cancellationToken: cts.Token); } } @@ -531,23 +561,19 @@ async Task WriteUntilCanceled() await Assert.ThrowsAsync(() => WriteUntilCanceled().WaitAsync(TimeSpan.FromSeconds(3))); // next write would also throw - await Assert.ThrowsAsync(() => stream.WriteAsync(new byte[1]).AsTask()); + await Assert.ThrowsAsync(() => clientStream.WriteAsync(new byte[1]).AsTask()); // manual write abort is still required - stream.AbortWrite(expectedErrorCode); - - await stream.ShutdownCompleted(); + clientStream.Abort(expectedErrorCode, QuicAbortDirection.Write); }, - serverFunction: async connection => + async serverStream => { - await using QuicStream stream = await connection.AcceptStreamAsync(); - async Task ReadUntilAborted() { var buffer = new byte[1024]; while (true) { - int res = await stream.ReadAsync(buffer); + int res = await serverStream.ReadAsync(buffer); if (res == 0) { break; @@ -562,11 +588,123 @@ async Task ReadUntilAborted() await ReadUntilAborted().WaitAsync(TimeSpan.FromSeconds(3)); } catch (QuicStreamAbortedException) { } - - await stream.ShutdownCompleted(); } ); } + + [Fact] + public async Task CloseAsync_Cancelled_Then_CloseAsync_Success() + { + using SemaphoreSlim sem = new SemaphoreSlim(0); + + await RunBidirectionalClientServer( + async clientStream => + { + // Make sure the first task throws an OCE. + + using var cts = new CancellationTokenSource(500); + + OperationCanceledException oce = await Assert.ThrowsAnyAsync(async () => + { + await clientStream.CloseAsync(cts.Token); + }); + + Assert.Equal(cts.Token, oce.CancellationToken); + + // Release before closing the stream, to allow the server to close its write stream. + + sem.Release(); + }, + async serverStream => + { + // Wait before closing the stream, which would otherwise cause the client's CloseAsync to finish. + + await sem.WaitAsync(); + }); + } + + // This tests the pattern needed to safely control shutdown of a QuicStream. + // 1. Normal stream usage happens inside try. + // 2. Call Abort(Both) in the catch. + // 3. Call Close() with a cancellation token in the finally. + // 4. If that Close() fails, call Abort(Immediate). + // + // This is important to avoid a DoS if the peer doesn't shutdown their sends but otherwise leaves the connection open. + // TODO: we should rework the API to make this a lot more foolproof. + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task QuicStream_ClosePattern_Success(bool abortive) + { + const int ExpectedErrorCode = 0xfffffff; + + using SemaphoreSlim sem = new SemaphoreSlim(0); + + await RunBidirectionalClientServer( + async clientStream => + { + try + { + // All the usual stream usage happens inside a try block. + // Just a dummy throw here to demonstrate the pattern... + + if (abortive) + { + throw new Exception(); + } + } + catch + { + // Abort here. The CloseAsync that follows will still wait for an ACK of the shutdown. + clientStream.Abort(ExpectedErrorCode, QuicAbortDirection.Both); + } + finally + { + // Call CloseAsync() with a cancellation token to allow it to time out when peer doesn't shutdown. + + using var shutdownCts = new CancellationTokenSource(500); + try + { + await clientStream.CloseAsync(shutdownCts.Token); + } + catch + { + // Abort (possibly again, which will ignore error code and not queue any new I/O). + // This time, Immediate is used which will cause CloseAsync() to not wait for a shutdown ACK. + clientStream.Abort(ExpectedErrorCode, QuicAbortDirection.Immediate); + } + } + + // Either the CloseAsync above worked, in which case this is a no-op, + // or the stream has been re-aborted with Immediate, in which case this will complete "immediately" but not synchronously. + await clientStream.CloseAsync(); + + // Only allow the other side to close its stream after the dispose completes. + sem.Release(); + }, + async serverStream => + { + // Don't shutdown client side until server side has 100% completed. + await sem.WaitAsync(); + + // Wait for server's abort to reach us. + await Task.Delay(500); + + QuicStreamAbortedException ex = await Assert.ThrowsAsync(async () => + { + await serverStream.WriteAsync(new byte[1]); + }); + + Assert.Equal(ExpectedErrorCode, ex.ErrorCode); + + ex = await Assert.ThrowsAsync(async () => + { + await serverStream.ReadAsync(new byte[1]); + }); + + Assert.Equal(ExpectedErrorCode, ex.ErrorCode); + }); + } } public sealed class QuicStreamTests_MockProvider : QuicStreamTests { } diff --git a/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicTestBase.cs b/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicTestBase.cs index ee7501868beba..9f2e367e56fc5 100644 --- a/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicTestBase.cs +++ b/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicTestBase.cs @@ -77,26 +77,80 @@ internal QuicListener CreateQuicListener(IPEndPoint endpoint) private QuicListener CreateQuicListener(QuicListenerOptions options) => new QuicListener(ImplementationProvider, options); + internal Task RunUnidirectionalClientServer(Func clientFunction, Func serverFunction, int iterations = 1, int millisecondsTimeout = 10_000) + => RunClientServerStream(clientFunction, serverFunction, iterations, millisecondsTimeout, bidi: false); + + internal Task RunBidirectionalClientServer(Func clientFunction, Func serverFunction, int iterations = 1, int millisecondsTimeout = 10_000) + => RunClientServerStream(clientFunction, serverFunction, iterations, millisecondsTimeout, bidi: true); + + private async Task RunClientServerStream(Func clientFunction, Func serverFunction, int iterations, int millisecondsTimeout, bool bidi) + { + const long ClientThrewAbortCode = 1234567890; + const long ServerThrewAbortCode = 2345678901; + + await RunClientServer( + async clientConnection => + { + await using QuicStream clientStream = bidi ? clientConnection.OpenBidirectionalStream() : clientConnection.OpenUnidirectionalStream(); + try + { + await clientFunction(clientStream); + } + catch + { + try + { + // abort the stream to give the peer a chance to tear down. + clientStream.Abort(ClientThrewAbortCode); + } + catch(ObjectDisposedException) + { + // do nothing. + } + + throw; + } + }, + async serverConnection => + { + await using QuicStream serverStream = await serverConnection.AcceptStreamAsync(); + try + { + await serverFunction(serverStream); + } + catch + { + try + { + // abort the stream to give the peer a chance to tear down. + serverStream.Abort(ServerThrewAbortCode); + } + catch (ObjectDisposedException) + { + // do nothing. + } + throw; + } + }, iterations, millisecondsTimeout); + } + internal async Task RunClientServer(Func clientFunction, Func serverFunction, int iterations = 1, int millisecondsTimeout = 10_000) { using QuicListener listener = CreateQuicListener(); - var serverFinished = new ManualResetEventSlim(); - var clientFinished = new ManualResetEventSlim(); + using var serverFinished = new SemaphoreSlim(0); + using var clientFinished = new SemaphoreSlim(0); for (int i = 0; i < iterations; ++i) { - serverFinished.Reset(); - clientFinished.Reset(); - await new[] { Task.Run(async () => { using QuicConnection serverConnection = await listener.AcceptConnectionAsync(); await serverFunction(serverConnection); - serverFinished.Set(); - clientFinished.Wait(); + serverFinished.Release(); + await clientFinished.WaitAsync(); await serverConnection.CloseAsync(0); }), Task.Run(async () => @@ -104,8 +158,8 @@ await new[] using QuicConnection clientConnection = CreateQuicConnection(listener.ListenEndPoint); await clientConnection.ConnectAsync(); await clientFunction(clientConnection); - clientFinished.Set(); - serverFinished.Wait(); + clientFinished.Release(); + await serverFinished.WaitAsync(); await clientConnection.CloseAsync(0); }) }.WhenAllOrAnyFailed(millisecondsTimeout);