From 72a4e466ce30d9751ce89dc350e00f5067646e8a Mon Sep 17 00:00:00 2001 From: Natalia Kondratyeva Date: Tue, 6 Sep 2022 20:09:20 +0200 Subject: [PATCH] [QUIC] Fix native crashes and heap corruption via "generated-like" interop (#74669) * Send buffers and handles crash fixes * Add generated-like interop * Apply PR feedback from #74611 * Change asserts * Feedback + moved native methods to their own file * PR feedback Co-authored-by: ManickaP --- .../Quic/Internal/MsQuicApi.NativeMethods.cs | 378 ++++++++++++++++++ .../src/System/Net/Quic/Internal/MsQuicApi.cs | 2 +- .../Net/Quic/Internal/MsQuicConfiguration.cs | 12 +- .../System/Net/Quic/Internal/MsQuicHelpers.cs | 8 +- .../Net/Quic/Internal/MsQuicSafeHandle.cs | 21 +- .../src/System/Net/Quic/QuicConnection.cs | 34 +- .../src/System/Net/Quic/QuicListener.cs | 12 +- .../src/System/Net/Quic/QuicStream.cs | 84 ++-- 8 files changed, 484 insertions(+), 67 deletions(-) create mode 100644 src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicApi.NativeMethods.cs diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicApi.NativeMethods.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicApi.NativeMethods.cs new file mode 100644 index 0000000000000..206eac76ac787 --- /dev/null +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicApi.NativeMethods.cs @@ -0,0 +1,378 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.Quic; + +namespace System.Net.Quic; + +internal sealed unsafe partial class MsQuicApi +{ + public void SetContext(MsQuicSafeHandle handle, void* context) + { + bool success = false; + try + { + handle.DangerousAddRef(ref success); + ApiTable->SetContext(handle.QuicHandle, context); + } + finally + { + if (success) + { + handle.DangerousRelease(); + } + } + } + + public void* GetContext(MsQuicSafeHandle handle) + { + bool success = false; + try + { + handle.DangerousAddRef(ref success); + return ApiTable->GetContext(handle.QuicHandle); + } + finally + { + if (success) + { + handle.DangerousRelease(); + } + } + } + + public void SetCallbackHandler(MsQuicSafeHandle handle, void* callback, void* context) + { + bool success = false; + try + { + handle.DangerousAddRef(ref success); + ApiTable->SetCallbackHandler(handle.QuicHandle, callback, context); + } + finally + { + if (success) + { + handle.DangerousRelease(); + } + } + } + + public int SetParam(MsQuicSafeHandle handle, uint param, uint bufferLength, void* buffer) + { + bool success = false; + try + { + handle.DangerousAddRef(ref success); + return ApiTable->SetParam(handle.QuicHandle, param, bufferLength, buffer); + } + finally + { + if (success) + { + handle.DangerousRelease(); + } + } + } + + public int GetParam(MsQuicSafeHandle handle, uint param, uint* bufferLength, void* buffer) + { + bool success = false; + try + { + handle.DangerousAddRef(ref success); + return ApiTable->GetParam(handle.QuicHandle, param, bufferLength, buffer); + } + finally + { + if (success) + { + handle.DangerousRelease(); + } + } + } + + public void RegistrationShutdown(MsQuicSafeHandle registration, QUIC_CONNECTION_SHUTDOWN_FLAGS flags, ulong code) + { + bool success = false; + try + { + registration.DangerousAddRef(ref success); + ApiTable->RegistrationShutdown(registration.QuicHandle, flags, code); + } + finally + { + if (success) + { + registration.DangerousRelease(); + } + } + } + + public int ConfigurationOpen(MsQuicSafeHandle registration, QUIC_BUFFER* alpnBuffers, uint alpnBuffersCount, QUIC_SETTINGS* settings, uint settingsSize, void* context, QUIC_HANDLE** configuration) + { + bool success = false; + try + { + registration.DangerousAddRef(ref success); + return ApiTable->ConfigurationOpen(registration.QuicHandle, alpnBuffers, alpnBuffersCount, settings, settingsSize, context, configuration); + } + finally + { + if (success) + { + registration.DangerousRelease(); + } + } + } + + public int ConfigurationLoadCredential(MsQuicSafeHandle configuration, QUIC_CREDENTIAL_CONFIG* config) + { + bool success = false; + try + { + configuration.DangerousAddRef(ref success); + return ApiTable->ConfigurationLoadCredential(configuration.QuicHandle, config); + } + finally + { + if (success) + { + configuration.DangerousRelease(); + } + } + } + + public int ListenerOpen(MsQuicSafeHandle registration, delegate* unmanaged[Cdecl] callback, void* context, QUIC_HANDLE** listener) + { + bool success = false; + try + { + registration.DangerousAddRef(ref success); + return ApiTable->ListenerOpen(registration.QuicHandle, callback, context, listener); + } + finally + { + if (success) + { + registration.DangerousRelease(); + } + } + } + + public int ListenerStart(MsQuicSafeHandle listener, QUIC_BUFFER* alpnBuffers, uint alpnBuffersCount, QuicAddr* localAddress) + { + bool success = false; + try + { + listener.DangerousAddRef(ref success); + return ApiTable->ListenerStart(listener.QuicHandle, alpnBuffers, alpnBuffersCount, localAddress); + } + finally + { + if (success) + { + listener.DangerousRelease(); + } + } + } + + public void ListenerStop(MsQuicSafeHandle listener) + { + bool success = false; + try + { + listener.DangerousAddRef(ref success); + ApiTable->ListenerStop(listener.QuicHandle); + } + finally + { + if (success) + { + listener.DangerousRelease(); + } + } + } + + public int ConnectionOpen(MsQuicSafeHandle registration, delegate* unmanaged[Cdecl] callback, void* context, QUIC_HANDLE** connection) + { + bool success = false; + try + { + registration.DangerousAddRef(ref success); + return ApiTable->ConnectionOpen(registration.QuicHandle, callback, context, connection); + } + finally + { + if (success) + { + registration.DangerousRelease(); + } + } + } + + public void ConnectionShutdown(MsQuicSafeHandle connection, QUIC_CONNECTION_SHUTDOWN_FLAGS flags, ulong code) + { + bool success = false; + try + { + connection.DangerousAddRef(ref success); + ApiTable->ConnectionShutdown(connection.QuicHandle, flags, code); + } + finally + { + if (success) + { + connection.DangerousRelease(); + } + } + } + + public int ConnectionStart(MsQuicSafeHandle connection, MsQuicSafeHandle configuration, ushort family, sbyte* serverName, ushort serverPort) + { + bool connectionSuccess = false; + bool configurationSuccess = false; + try + { + connection.DangerousAddRef(ref connectionSuccess); + configuration.DangerousAddRef(ref configurationSuccess); + return ApiTable->ConnectionStart(connection.QuicHandle, configuration.QuicHandle, family, serverName, serverPort); + } + finally + { + if (connectionSuccess) + { + connection.DangerousRelease(); + } + if (configurationSuccess) + { + configuration.DangerousRelease(); + } + } + } + + public int ConnectionSetConfiguration(MsQuicSafeHandle connection, MsQuicSafeHandle configuration) + { + bool connectionSuccess = false; + bool configurationSuccess = false; + try + { + connection.DangerousAddRef(ref connectionSuccess); + configuration.DangerousAddRef(ref configurationSuccess); + return ApiTable->ConnectionSetConfiguration(connection.QuicHandle, configuration.QuicHandle); + } + finally + { + if (connectionSuccess) + { + connection.DangerousRelease(); + } + if (configurationSuccess) + { + configuration.DangerousRelease(); + } + } + } + + public int StreamOpen(MsQuicSafeHandle connection, QUIC_STREAM_OPEN_FLAGS flags, delegate* unmanaged[Cdecl] callback, void* context, QUIC_HANDLE** stream) + { + bool success = false; + try + { + connection.DangerousAddRef(ref success); + return ApiTable->StreamOpen(connection.QuicHandle, flags, callback, context, stream); + } + finally + { + if (success) + { + connection.DangerousRelease(); + } + } + } + + public int StreamStart(MsQuicSafeHandle stream, QUIC_STREAM_START_FLAGS flags) + { + bool success = false; + try + { + stream.DangerousAddRef(ref success); + return ApiTable->StreamStart(stream.QuicHandle, flags); + } + finally + { + if (success) + { + stream.DangerousRelease(); + } + } + } + + public int StreamShutdown(MsQuicSafeHandle stream, QUIC_STREAM_SHUTDOWN_FLAGS flags, ulong code) + { + bool success = false; + try + { + stream.DangerousAddRef(ref success); + return ApiTable->StreamShutdown(stream.QuicHandle, flags, code); + } + finally + { + if (success) + { + stream.DangerousRelease(); + } + } + } + + public int StreamSend(MsQuicSafeHandle stream, QUIC_BUFFER* buffers, uint buffersCount, QUIC_SEND_FLAGS flags, void* context) + { + bool success = false; + try + { + stream.DangerousAddRef(ref success); + return ApiTable->StreamSend(stream.QuicHandle, buffers, buffersCount, flags, context); + } + finally + { + if (success) + { + stream.DangerousRelease(); + } + } + } + + public void StreamReceiveComplete(MsQuicSafeHandle stream, ulong length) + { + bool success = false; + try + { + stream.DangerousAddRef(ref success); + ApiTable->StreamReceiveComplete(stream.QuicHandle, length); + } + finally + { + if (success) + { + stream.DangerousRelease(); + } + } + } + + public int StreamReceiveSetEnabled(MsQuicSafeHandle stream, byte enabled) + { + bool success = false; + try + { + stream.DangerousAddRef(ref success); + return ApiTable->StreamReceiveSetEnabled(stream.QuicHandle, enabled); + } + finally + { + if (success) + { + stream.DangerousRelease(); + } + } + } +} diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicApi.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicApi.cs index e2866454356dd..e28134ea4b6f5 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicApi.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicApi.cs @@ -13,7 +13,7 @@ namespace System.Net.Quic; -internal sealed unsafe class MsQuicApi +internal sealed unsafe partial class MsQuicApi { private static readonly Version MinWindowsVersion = new Version(10, 0, 20145, 1000); diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicConfiguration.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicConfiguration.cs index b84cb6cce4267..ddec979aade1b 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicConfiguration.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicConfiguration.cs @@ -131,8 +131,8 @@ private static unsafe MsQuicSafeHandle Create(QuicConnectionOptions options, QUI using MsQuicBuffers msquicBuffers = new MsQuicBuffers(); msquicBuffers.Initialize(alpnProtocols, alpnProtocol => alpnProtocol.Protocol); - ThrowHelper.ThrowIfMsQuicError(MsQuicApi.Api.ApiTable->ConfigurationOpen( - MsQuicApi.Api.Registration.QuicHandle, + ThrowHelper.ThrowIfMsQuicError(MsQuicApi.Api.ConfigurationOpen( + MsQuicApi.Api.Registration, msquicBuffers.Buffers, (uint)alpnProtocols.Count, &settings, @@ -140,7 +140,7 @@ private static unsafe MsQuicSafeHandle Create(QuicConnectionOptions options, QUI (void*)IntPtr.Zero, &handle), "ConfigurationOpen failed"); - MsQuicSafeHandle configurationHandle = new MsQuicSafeHandle(handle, MsQuicApi.Api.ApiTable->ConfigurationClose, SafeHandleType.Configuration); + MsQuicSafeHandle configurationHandle = new MsQuicSafeHandle(handle, SafeHandleType.Configuration); try { @@ -157,13 +157,13 @@ private static unsafe MsQuicSafeHandle Create(QuicConnectionOptions options, QUI if (certificate is null) { config.Type = QUIC_CREDENTIAL_TYPE.NONE; - status = MsQuicApi.Api.ApiTable->ConfigurationLoadCredential(configurationHandle.QuicHandle, &config); + status = MsQuicApi.Api.ConfigurationLoadCredential(configurationHandle, &config); } else if (MsQuicApi.UsesSChannelBackend) { config.Type = QUIC_CREDENTIAL_TYPE.CERTIFICATE_CONTEXT; config.CertificateContext = (void*)certificate.Handle; - status = MsQuicApi.Api.ApiTable->ConfigurationLoadCredential(configurationHandle.QuicHandle, &config); + status = MsQuicApi.Api.ConfigurationLoadCredential(configurationHandle, &config); } else { @@ -192,7 +192,7 @@ private static unsafe MsQuicSafeHandle Create(QuicConnectionOptions options, QUI PrivateKeyPassword = (sbyte*)IntPtr.Zero }; config.CertificatePkcs12 = &pkcs12Certificate; - status = MsQuicApi.Api.ApiTable->ConfigurationLoadCredential(configurationHandle.QuicHandle, &config); + status = MsQuicApi.Api.ConfigurationLoadCredential(configurationHandle, &config); } } diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicHelpers.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicHelpers.cs index 04beebc6a2fc1..683e8bb62473e 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicHelpers.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicHelpers.cs @@ -61,8 +61,8 @@ internal static unsafe T GetMsQuicParameter(MsQuicSafeHandle handle, uint par T value; uint length = (uint)sizeof(T); - int status = MsQuicApi.Api.ApiTable->GetParam( - handle.QuicHandle, + int status = MsQuicApi.Api.GetParam( + handle, parameter, &length, (byte*)&value); @@ -78,8 +78,8 @@ internal static unsafe T GetMsQuicParameter(MsQuicSafeHandle handle, uint par internal static unsafe void SetMsQuicParameter(MsQuicSafeHandle handle, uint parameter, T value) where T : unmanaged { - int status = MsQuicApi.Api.ApiTable->SetParam( - handle.QuicHandle, + int status = MsQuicApi.Api.SetParam( + handle, parameter, (uint)sizeof(T), (byte*)&value); diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicSafeHandle.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicSafeHandle.cs index 6e247ef993737..58b41443694a1 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicSafeHandle.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicSafeHandle.cs @@ -39,10 +39,25 @@ public MsQuicSafeHandle(QUIC_HANDLE* handle, delegate* unmanaged[Cdecl] MsQuicApi.Api.ApiTable->RegistrationClose, + SafeHandleType.Configuration => MsQuicApi.Api.ApiTable->ConfigurationClose, + SafeHandleType.Listener => MsQuicApi.Api.ApiTable->ListenerClose, + SafeHandleType.Connection => MsQuicApi.Api.ApiTable->ConnectionClose, + SafeHandleType.Stream => MsQuicApi.Api.ApiTable->StreamClose, + _ => throw new ArgumentException($"Unexpected value: {safeHandleType}", nameof(safeHandleType)) + }, + safeHandleType) { } + protected override bool ReleaseHandle() { - _releaseAction(QuicHandle); + QUIC_HANDLE* quicHandle = QuicHandle; SetHandle(IntPtr.Zero); + _releaseAction(quicHandle); if (NetEventSource.Log.IsEnabled()) { @@ -77,8 +92,8 @@ internal sealed class MsQuicContextSafeHandle : MsQuicSafeHandle /// private readonly MsQuicSafeHandle? _parent; - public unsafe MsQuicContextSafeHandle(QUIC_HANDLE* handle, GCHandle context, delegate* unmanaged[Cdecl] releaseAction, SafeHandleType safeHandleType, MsQuicSafeHandle? parent = null) - : base(handle, releaseAction, safeHandleType) + public unsafe MsQuicContextSafeHandle(QUIC_HANDLE* handle, GCHandle context, SafeHandleType safeHandleType, MsQuicSafeHandle? parent = null) + : base(handle, safeHandleType) { _context = context; if (parent is not null) diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicConnection.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicConnection.cs index cd23a611ff06f..83073bcfab502 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicConnection.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicConnection.cs @@ -179,13 +179,13 @@ private unsafe QuicConnection() try { QUIC_HANDLE* handle; - ThrowHelper.ThrowIfMsQuicError(MsQuicApi.Api.ApiTable->ConnectionOpen( - MsQuicApi.Api.Registration.QuicHandle, + ThrowHelper.ThrowIfMsQuicError(MsQuicApi.Api.ConnectionOpen( + MsQuicApi.Api.Registration, &NativeCallback, (void*)GCHandle.ToIntPtr(context), &handle), "ConnectionOpen failed"); - _handle = new MsQuicContextSafeHandle(handle, context, MsQuicApi.Api.ApiTable->ConnectionClose, SafeHandleType.Connection); + _handle = new MsQuicContextSafeHandle(handle, context, SafeHandleType.Connection); } catch { @@ -204,12 +204,12 @@ internal unsafe QuicConnection(QUIC_HANDLE* handle, QUIC_NEW_CONNECTION_INFO* in GCHandle context = GCHandle.Alloc(this, GCHandleType.Weak); try { + _handle = new MsQuicContextSafeHandle(handle, context, SafeHandleType.Connection); delegate* unmanaged[Cdecl] nativeCallback = &NativeCallback; - MsQuicApi.Api.ApiTable->SetCallbackHandler( - handle, + MsQuicApi.Api.SetCallbackHandler( + _handle, nativeCallback, (void*)GCHandle.ToIntPtr(context)); - _handle = new MsQuicContextSafeHandle(handle, context, MsQuicApi.Api.ApiTable->ConnectionClose, SafeHandleType.Connection); } catch { @@ -294,9 +294,9 @@ private async ValueTask FinishConnectAsync(QuicClientConnectionOptions options, { unsafe { - ThrowHelper.ThrowIfMsQuicError(MsQuicApi.Api.ApiTable->ConnectionStart( - _handle.QuicHandle, - _configuration.QuicHandle, + ThrowHelper.ThrowIfMsQuicError(MsQuicApi.Api.ConnectionStart( + _handle, + _configuration, (ushort)addressFamily, (sbyte*)targetHostPtr, (ushort)port), @@ -334,9 +334,9 @@ internal ValueTask FinishHandshakeAsync(QuicServerConnectionOptions options, str unsafe { - ThrowHelper.ThrowIfMsQuicError(MsQuicApi.Api.ApiTable->ConnectionSetConfiguration( - _handle.QuicHandle, - _configuration.QuicHandle), + ThrowHelper.ThrowIfMsQuicError(MsQuicApi.Api.ConnectionSetConfiguration( + _handle, + _configuration), "ConnectionSetConfiguration failed"); } } @@ -430,8 +430,8 @@ public ValueTask CloseAsync(long errorCode, CancellationToken cancellationToken { unsafe { - MsQuicApi.Api.ApiTable->ConnectionShutdown( - _handle.QuicHandle, + MsQuicApi.Api.ConnectionShutdown( + _handle, QUIC_CONNECTION_SHUTDOWN_FLAGS.NONE, (ulong)errorCode); } @@ -474,8 +474,8 @@ private unsafe int HandleEventShutdownInitiatedByPeer(ref SHUTDOWN_INITIATED_BY_ } private unsafe int HandleEventShutdownComplete(ref SHUTDOWN_COMPLETE_DATA data) { - _shutdownTcs.TrySetResult(); _acceptQueue.Writer.TryComplete(ExceptionDispatchInfo.SetCurrentStackTrace(ThrowHelper.GetOperationAbortedException())); + _shutdownTcs.TrySetResult(); return QUIC_STATUS_SUCCESS; } private unsafe int HandleEventLocalAddressChanged(ref LOCAL_ADDRESS_CHANGED_DATA data) @@ -582,8 +582,8 @@ public async ValueTask DisposeAsync() { unsafe { - MsQuicApi.Api.ApiTable->ConnectionShutdown( - _handle.QuicHandle, + MsQuicApi.Api.ConnectionShutdown( + _handle, QUIC_CONNECTION_SHUTDOWN_FLAGS.NONE, (ulong)_defaultCloseErrorCode); } diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicListener.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicListener.cs index 37d0e0d2079d7..a99e82159eceb 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicListener.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicListener.cs @@ -106,13 +106,13 @@ private unsafe QuicListener(QuicListenerOptions options) try { QUIC_HANDLE* handle; - ThrowHelper.ThrowIfMsQuicError(MsQuicApi.Api.ApiTable->ListenerOpen( - MsQuicApi.Api.Registration.QuicHandle, + ThrowHelper.ThrowIfMsQuicError(MsQuicApi.Api.ListenerOpen( + MsQuicApi.Api.Registration, &NativeCallback, (void*)GCHandle.ToIntPtr(context), &handle), "ListenerOpen failed"); - _handle = new MsQuicContextSafeHandle(handle, context, MsQuicApi.Api.ApiTable->ListenerClose, SafeHandleType.Listener); + _handle = new MsQuicContextSafeHandle(handle, context, SafeHandleType.Listener); } catch { @@ -135,8 +135,8 @@ private unsafe QuicListener(QuicListenerOptions options) // Using the Unspecified family makes MsQuic handle connections from all IP addresses. address.Family = QUIC_ADDRESS_FAMILY_UNSPEC; } - ThrowHelper.ThrowIfMsQuicError(MsQuicApi.Api.ApiTable->ListenerStart( - _handle.QuicHandle, + ThrowHelper.ThrowIfMsQuicError(MsQuicApi.Api.ListenerStart( + _handle, alpnBuffers.Buffers, (uint)alpnBuffers.Count, &address), @@ -266,7 +266,7 @@ public async ValueTask DisposeAsync() { unsafe { - MsQuicApi.Api.ApiTable->ListenerStop(_handle.QuicHandle); + MsQuicApi.Api.ListenerStop(_handle); } } diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicStream.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicStream.cs index c11f3029a2136..c2cf3f23bf4de 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicStream.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicStream.cs @@ -90,6 +90,7 @@ public sealed partial class QuicStream } }; private MsQuicBuffers _sendBuffers = new MsQuicBuffers(); + private object _sendBuffersLock = new object(); private readonly long _defaultErrorCode; @@ -141,14 +142,14 @@ internal unsafe QuicStream(MsQuicContextSafeHandle connectionHandle, QuicStreamT try { QUIC_HANDLE* handle; - ThrowHelper.ThrowIfMsQuicError(MsQuicApi.Api.ApiTable->StreamOpen( - connectionHandle.QuicHandle, + ThrowHelper.ThrowIfMsQuicError(MsQuicApi.Api.StreamOpen( + connectionHandle, type == QuicStreamType.Unidirectional ? QUIC_STREAM_OPEN_FLAGS.UNIDIRECTIONAL : QUIC_STREAM_OPEN_FLAGS.NONE, &NativeCallback, (void*)GCHandle.ToIntPtr(context), &handle), "StreamOpen failed"); - _handle = new MsQuicContextSafeHandle(handle, context, MsQuicApi.Api.ApiTable->StreamClose, SafeHandleType.Stream, connectionHandle); + _handle = new MsQuicContextSafeHandle(handle, context, SafeHandleType.Stream, connectionHandle); } catch { @@ -179,12 +180,12 @@ internal unsafe QuicStream(MsQuicContextSafeHandle connectionHandle, QUIC_HANDLE GCHandle context = GCHandle.Alloc(this, GCHandleType.Weak); try { + _handle = new MsQuicContextSafeHandle(handle, context, SafeHandleType.Stream, connectionHandle); delegate* unmanaged[Cdecl] nativeCallback = &NativeCallback; - MsQuicApi.Api.ApiTable->SetCallbackHandler( - handle, + MsQuicApi.Api.SetCallbackHandler( + _handle, nativeCallback, (void*)GCHandle.ToIntPtr(context)); - _handle = new MsQuicContextSafeHandle(handle, context, MsQuicApi.Api.ApiTable->StreamClose, SafeHandleType.Stream, connectionHandle); } catch { @@ -220,8 +221,8 @@ internal ValueTask StartAsync(CancellationToken cancellationToken = default) { unsafe { - int status = MsQuicApi.Api.ApiTable->StreamStart( - _handle.QuicHandle, + int status = MsQuicApi.Api.StreamStart( + _handle, QUIC_STREAM_START_FLAGS.SHUTDOWN_ON_FAIL | QUIC_STREAM_START_FLAGS.INDICATE_PEER_ACCEPT); if (ThrowHelper.TryGetStreamExceptionForMsQuicStatus(status, out Exception? exception)) { @@ -297,8 +298,8 @@ public override async ValueTask ReadAsync(Memory buffer, Cancellation { unsafe { - ThrowHelper.ThrowIfMsQuicError(MsQuicApi.Api.ApiTable->StreamReceiveSetEnabled( - _handle.QuicHandle, + ThrowHelper.ThrowIfMsQuicError(MsQuicApi.Api.StreamReceiveSetEnabled( + _handle, 1), "StreamReceivedSetEnabled failed"); } @@ -360,19 +361,34 @@ public ValueTask WriteAsync(ReadOnlyMemory buffer, bool completeWrites, Ca return valueTask; } - _sendBuffers.Initialize(buffer); - unsafe + lock (_sendBuffersLock) { - int status = MsQuicApi.Api.ApiTable->StreamSend( - _handle.QuicHandle, - _sendBuffers.Buffers, - (uint)_sendBuffers.Count, - completeWrites ? QUIC_SEND_FLAGS.FIN : QUIC_SEND_FLAGS.NONE, - null); - if (ThrowHelper.TryGetStreamExceptionForMsQuicStatus(status, out Exception? exception)) + ObjectDisposedException.ThrowIf(_disposed == 1, this); // TODO: valueTask is left unobserved + unsafe { - _sendBuffers.Reset(); - _sendTcs.TrySetException(exception, final: true); + if (_sendBuffers.Count > 0 && _sendBuffers.Buffers[0].Buffer != null) + { + // _sendBuffers are not reset, meaning SendComplete for the previous WriteAsync call didn't arrive yet. + // In case of cancellation, the task from _sendTcs is finished before the aborting. It is technically possible for subsequent + // WriteAsync to grab the next task from _sendTcs and start executing before SendComplete event occurs for the previous (canceled) write. + // This is not an "invalid nested call", because the previous task has finished. Best guess is to mimic OperationAborted as it will be from Abort + // that would execute soon enough, if not already. Not final, because Abort should be the one to set final exception. + _sendTcs.TrySetException(ThrowHelper.GetOperationAbortedException(SR.net_quic_writing_aborted), final: false); + return valueTask; + } + + _sendBuffers.Initialize(buffer); + int status = MsQuicApi.Api.StreamSend( + _handle, + _sendBuffers.Buffers, + (uint)_sendBuffers.Count, + completeWrites ? QUIC_SEND_FLAGS.FIN : QUIC_SEND_FLAGS.NONE, + null); + if (ThrowHelper.TryGetStreamExceptionForMsQuicStatus(status, out Exception? exception)) + { + _sendBuffers.Reset(); + _sendTcs.TrySetException(exception, final: true); + } } } @@ -419,8 +435,8 @@ public void Abort(QuicAbortDirection abortDirection, long errorCode) unsafe { - ThrowHelper.ThrowIfMsQuicError(MsQuicApi.Api.ApiTable->StreamShutdown( - _handle.QuicHandle, + ThrowHelper.ThrowIfMsQuicError(MsQuicApi.Api.StreamShutdown( + _handle, flags, (ulong)errorCode), "StreamShutdown failed"); @@ -442,8 +458,8 @@ public void CompleteWrites() { unsafe { - ThrowHelper.ThrowIfMsQuicError(MsQuicApi.Api.ApiTable->StreamShutdown( - _handle.QuicHandle, + ThrowHelper.ThrowIfMsQuicError(MsQuicApi.Api.StreamShutdown( + _handle, QUIC_STREAM_SHUTDOWN_FLAGS.GRACEFUL, default), "StreamShutdown failed"); @@ -490,7 +506,12 @@ private unsafe int HandleEventReceive(ref RECEIVE data) } private unsafe int HandleEventSendComplete(ref SEND_COMPLETE data) { - _sendBuffers.Reset(); + // In case of cancellation, the task from _sendTcs is finished before the aborting. It is technically possible for subsequent WriteAsync to grab the next task + // from _sendTcs and start executing before SendComplete event occurs for the previous (canceled) write + lock (_sendBuffersLock) + { + _sendBuffers.Reset(); + } if (data.Canceled == 0) { _sendTcs.TrySetResult(); @@ -653,13 +674,16 @@ public override async ValueTask DisposeAsync() await valueTask.ConfigureAwait(false); _handle.Dispose(); - // TODO: memory leak if not disposed - _sendBuffers.Dispose(); + lock (_sendBuffersLock) + { + // TODO: memory leak if not disposed + _sendBuffers.Dispose(); + } unsafe void StreamShutdown(QUIC_STREAM_SHUTDOWN_FLAGS flags, long errorCode) { - int status = MsQuicApi.Api.ApiTable->StreamShutdown( - _handle.QuicHandle, + int status = MsQuicApi.Api.StreamShutdown( + _handle, flags, (ulong)errorCode); if (StatusFailed(status))