diff --git a/src/libraries/System.Net.Sockets/src/System.Net.Sockets.csproj b/src/libraries/System.Net.Sockets/src/System.Net.Sockets.csproj index 1757202c4a178..c6421ec4e1986 100644 --- a/src/libraries/System.Net.Sockets/src/System.Net.Sockets.csproj +++ b/src/libraries/System.Net.Sockets/src/System.Net.Sockets.csproj @@ -1,4 +1,4 @@ - + true $(NetCoreAppCurrent)-windows;$(NetCoreAppCurrent)-Unix;$(NetCoreAppCurrent) @@ -46,8 +46,6 @@ - - - - @@ -192,8 +188,6 @@ - - diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/BaseOverlappedAsyncResult.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/BaseOverlappedAsyncResult.cs index 1cb1f60789995..a60cbb975dafe 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/BaseOverlappedAsyncResult.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/BaseOverlappedAsyncResult.cs @@ -21,12 +21,5 @@ internal partial class BaseOverlappedAsyncResult : ContextAwareResult _numBytes = numBytes; return s_resultObjectSentinel; // return sentinel rather than boxing numBytes } - - // Used instead of the base InternalWaitForCompletion when storing an Int32 result - internal int InternalWaitForCompletionInt32Result() - { - base.InternalWaitForCompletion(); - return _numBytes; - } } } diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/OverlappedAsyncResult.Unix.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/OverlappedAsyncResult.Unix.cs deleted file mode 100644 index 49cb1262b3d1b..0000000000000 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/OverlappedAsyncResult.Unix.cs +++ /dev/null @@ -1,32 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -using System.Diagnostics; - -namespace System.Net.Sockets -{ - // OverlappedAsyncResult - // - // This class is used to take care of storage for async Socket operation - // from the BeginSend, BeginSendTo, BeginReceive, BeginReceiveFrom calls. - internal partial class OverlappedAsyncResult : BaseOverlappedAsyncResult - { - private int _socketAddressSize; - - internal int GetSocketAddressSize() - { - return _socketAddressSize; - } - - public void CompletionCallback(int numBytes, byte[]? socketAddress, int socketAddressSize, SocketFlags receivedFlags, SocketError errorCode) - { - if (_socketAddress != null) - { - Debug.Assert(socketAddress == null || _socketAddress.Buffer == socketAddress, $"Unexpected socket address: {socketAddress}"); - _socketAddressSize = socketAddressSize; - } - - base.CompletionCallback(numBytes, errorCode); - } - } -} diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/OverlappedAsyncResult.Windows.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/OverlappedAsyncResult.Windows.cs deleted file mode 100644 index dced861318f9f..0000000000000 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/OverlappedAsyncResult.Windows.cs +++ /dev/null @@ -1,136 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -using System.Collections.Generic; -using System.Diagnostics; -using System.Runtime.InteropServices; - -namespace System.Net.Sockets -{ - // OverlappedAsyncResult - // - // This class is used to take care of storage for async Socket operation - // from the BeginSend, BeginSendTo, BeginReceive, BeginReceiveFrom calls. - internal partial class OverlappedAsyncResult : BaseOverlappedAsyncResult - { - internal WSABuffer _singleBuffer; - internal WSABuffer[]? _wsaBuffers; - - internal IntPtr GetSocketAddressPtr() - { - return Marshal.UnsafeAddrOfPinnedArrayElement(_socketAddress!.Buffer, 0); - } - - internal IntPtr GetSocketAddressSizePtr() - { - return Marshal.UnsafeAddrOfPinnedArrayElement(_socketAddress!.Buffer, _socketAddress.GetAddressSizeOffset()); - } - - internal unsafe int GetSocketAddressSize() - { - return *(int*)GetSocketAddressSizePtr(); - } - - // SetUnmanagedStructures - // - // Fills in overlapped structures used in an async overlapped Winsock call. - // These calls are outside the runtime and are unmanaged code, so we need - // to prepare specific structures and ints that lie in unmanaged memory - // since the overlapped calls may complete asynchronously. - internal void SetUnmanagedStructures(byte[] buffer, int offset, int size, Internals.SocketAddress? socketAddress) - { - // Fill in Buffer Array structure that will be used for our send/recv Buffer - _socketAddress = socketAddress; - if (_socketAddress != null) - { - object[] objectsToPin = new object[2]; - objectsToPin[0] = buffer; - - _socketAddress.CopyAddressSizeIntoBuffer(); - objectsToPin[1] = _socketAddress.Buffer; - - base.SetUnmanagedStructures(objectsToPin); - } - else - { - base.SetUnmanagedStructures(buffer); - } - - _singleBuffer.Length = size; - _singleBuffer.Pointer = Marshal.UnsafeAddrOfPinnedArrayElement(buffer, offset); - } - - internal void SetUnmanagedStructures(IList> buffers) - { - // Fill in Buffer Array structure that will be used for our send/recv Buffer. - // Make sure we don't let the app mess up the buffer array enough to cause - // corruption. - int count = buffers.Count; - ArraySegment[] buffersCopy = new ArraySegment[count]; - - for (int i = 0; i < count; i++) - { - buffersCopy[i] = buffers[i]; - RangeValidationHelpers.ValidateSegment(buffersCopy[i]); - } - - _wsaBuffers = new WSABuffer[count]; - - object[] objectsToPin = new object[count]; - for (int i = 0; i < count; i++) - { - objectsToPin[i] = buffersCopy[i].Array!; - } - - base.SetUnmanagedStructures(objectsToPin); - - for (int i = 0; i < count; i++) - { - _wsaBuffers[i].Length = buffersCopy[i].Count; - _wsaBuffers[i].Pointer = Marshal.UnsafeAddrOfPinnedArrayElement(buffersCopy[i].Array!, buffersCopy[i].Offset); - } - } - - // This method is called after an asynchronous call is made for the user. - // It checks and acts accordingly if the IO: - // 1) completed synchronously. - // 2) was pended. - // 3) failed. - internal override object? PostCompletion(int numBytes) - { - if (ErrorCode == 0 && NetEventSource.Log.IsEnabled()) - { - LogBuffer(numBytes); - } - - return base.PostCompletion(numBytes); - } - - private void LogBuffer(int size) - { - // This should only be called if tracing is enabled. However, there is the potential for a race - // condition where tracing is disabled between a calling check and here, in which case the assert - // may fire erroneously. - Debug.Assert(NetEventSource.Log.IsEnabled()); - - if (size > -1) - { - if (_wsaBuffers != null) - { - foreach (WSABuffer wsaBuffer in _wsaBuffers) - { - NetEventSource.DumpBuffer(this, wsaBuffer.Pointer, Math.Min(wsaBuffer.Length, size)); - if ((size -= wsaBuffer.Length) <= 0) - { - break; - } - } - } - else - { - NetEventSource.DumpBuffer(this, _singleBuffer.Pointer, Math.Min(_singleBuffer.Length, size)); - } - } - } - } -} diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/OverlappedAsyncResult.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/OverlappedAsyncResult.cs deleted file mode 100644 index 8ca9bca32929f..0000000000000 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/OverlappedAsyncResult.cs +++ /dev/null @@ -1,35 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -namespace System.Net.Sockets -{ - // OverlappedAsyncResult - // - // This class is used to take care of storage for async Socket operation - // from the BeginSend, BeginSendTo, BeginReceive, BeginReceiveFrom calls. - internal partial class OverlappedAsyncResult : BaseOverlappedAsyncResult - { - private Internals.SocketAddress? _socketAddress; - - internal OverlappedAsyncResult(Socket socket, object? asyncState, AsyncCallback? asyncCallback) : - base(socket, asyncState, asyncCallback) - { - } - - internal Internals.SocketAddress? SocketAddress - { - get { return _socketAddress; } - set { _socketAddress = value; } - } - } - - internal sealed class OriginalAddressOverlappedAsyncResult : OverlappedAsyncResult - { - internal OriginalAddressOverlappedAsyncResult(Socket socket, object? asyncState, AsyncCallback? asyncCallback) : - base(socket, asyncState, asyncCallback) - { - } - - internal Internals.SocketAddress? SocketAddressOriginal { get; set; } - } -} diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/ReceiveMessageOverlappedAsyncResult.Unix.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/ReceiveMessageOverlappedAsyncResult.Unix.cs deleted file mode 100644 index aa596b291b015..0000000000000 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/ReceiveMessageOverlappedAsyncResult.Unix.cs +++ /dev/null @@ -1,29 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -using System.Diagnostics; - -namespace System.Net.Sockets -{ - internal unsafe sealed partial class ReceiveMessageOverlappedAsyncResult : BaseOverlappedAsyncResult - { - private int _socketAddressSize; - - internal int GetSocketAddressSize() - { - return _socketAddressSize; - } - - public void CompletionCallback(int numBytes, byte[] socketAddress, int socketAddressSize, SocketFlags receivedFlags, IPPacketInformation ipPacketInformation, SocketError errorCode) - { - Debug.Assert(_socketAddress != null, "_socketAddress was null"); - Debug.Assert(socketAddress == null || _socketAddress.Buffer == socketAddress, $"Unexpected socketAddress: {socketAddress}"); - - _socketAddressSize = socketAddressSize; - _socketFlags = receivedFlags; - _ipPacketInformation = ipPacketInformation; - - base.CompletionCallback(numBytes, errorCode); - } - } -} diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/ReceiveMessageOverlappedAsyncResult.Windows.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/ReceiveMessageOverlappedAsyncResult.Windows.cs deleted file mode 100644 index 5aa6bf3289876..0000000000000 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/ReceiveMessageOverlappedAsyncResult.Windows.cs +++ /dev/null @@ -1,138 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -using System.Diagnostics; -using System.Runtime.InteropServices; - -namespace System.Net.Sockets -{ - internal unsafe sealed partial class ReceiveMessageOverlappedAsyncResult : BaseOverlappedAsyncResult - { - private Interop.Winsock.WSAMsg* _message; - private WSABuffer* _wsaBuffer; - private byte[]? _wsaBufferArray; - private byte[]? _controlBuffer; - internal byte[]? _messageBuffer; - - private IntPtr GetSocketAddressSizePtr() - { - return Marshal.UnsafeAddrOfPinnedArrayElement(_socketAddress!.Buffer, _socketAddress.GetAddressSizeOffset()); - } - - internal unsafe int GetSocketAddressSize() - { - return *(int*)GetSocketAddressSizePtr(); - } - - // SetUnmanagedStructures - // - // Fills in overlapped Structures used in an async overlapped Winsock call. - // These calls are outside the runtime and are unmanaged code, so we need - // to prepare specific structures and ints that lie in unmanaged memory - // since the overlapped calls may complete asynchronously. - internal unsafe void SetUnmanagedStructures(byte[] buffer, int offset, int size, Internals.SocketAddress socketAddress, SocketFlags socketFlags) - { - _messageBuffer = new byte[sizeof(Interop.Winsock.WSAMsg)]; - _wsaBufferArray = new byte[sizeof(WSABuffer)]; - - bool ipv4, ipv6; - Socket.GetIPProtocolInformation(((Socket)AsyncObject!).AddressFamily, socketAddress, out ipv4, out ipv6); - - // Prepare control buffer. - if (ipv4) - { - _controlBuffer = new byte[sizeof(Interop.Winsock.ControlData)]; - } - else if (ipv6) - { - _controlBuffer = new byte[sizeof(Interop.Winsock.ControlDataIPv6)]; - } - - // Pin buffers. - object[] objectsToPin = new object[(_controlBuffer != null) ? 5 : 4]; - objectsToPin[0] = buffer; - objectsToPin[1] = _messageBuffer; - objectsToPin[2] = _wsaBufferArray; - - // Prepare socketaddress buffer. - _socketAddress = socketAddress; - _socketAddress.CopyAddressSizeIntoBuffer(); - objectsToPin[3] = _socketAddress.Buffer; - - if (_controlBuffer != null) - { - objectsToPin[4] = _controlBuffer; - } - - base.SetUnmanagedStructures(objectsToPin); - - // Prepare data buffer. - _wsaBuffer = (WSABuffer*)Marshal.UnsafeAddrOfPinnedArrayElement(_wsaBufferArray, 0); - _wsaBuffer->Length = size; - _wsaBuffer->Pointer = Marshal.UnsafeAddrOfPinnedArrayElement(buffer, offset); - - - // Setup structure. - _message = (Interop.Winsock.WSAMsg*)Marshal.UnsafeAddrOfPinnedArrayElement(_messageBuffer, 0); - _message->socketAddress = Marshal.UnsafeAddrOfPinnedArrayElement(_socketAddress.Buffer, 0); - _message->addressLength = (uint)_socketAddress.Size; - _message->buffers = Marshal.UnsafeAddrOfPinnedArrayElement(_wsaBufferArray, 0); - _message->count = 1; - - if (_controlBuffer != null) - { - _message->controlBuffer.Pointer = Marshal.UnsafeAddrOfPinnedArrayElement(_controlBuffer, 0); - _message->controlBuffer.Length = _controlBuffer.Length; - } - - _message->flags = socketFlags; - } - - private unsafe void InitIPPacketInformation() - { - int? controlBufferLength = _controlBuffer?.Length; - if (controlBufferLength == sizeof(Interop.Winsock.ControlData)) - { - // IPv4 - _ipPacketInformation = SocketPal.GetIPPacketInformation((Interop.Winsock.ControlData*)_message->controlBuffer.Pointer); - } - else if (controlBufferLength == sizeof(Interop.Winsock.ControlDataIPv6)) - { - // IPv6 - _ipPacketInformation = SocketPal.GetIPPacketInformation((Interop.Winsock.ControlDataIPv6*)_message->controlBuffer.Pointer); - } - else - { - // Other - _ipPacketInformation = default; - } - } - - protected override void ForceReleaseUnmanagedStructures() - { - _socketFlags = _message->flags; - base.ForceReleaseUnmanagedStructures(); - } - - internal override object? PostCompletion(int numBytes) - { - InitIPPacketInformation(); - if (ErrorCode == 0 && NetEventSource.Log.IsEnabled()) - { - LogBuffer(numBytes); - } - - return base.PostCompletion(numBytes); - } - - private void LogBuffer(int size) - { - // This should only be called if tracing is enabled. However, there is the potential for a race - // condition where tracing is disabled between a calling check and here, in which case the assert - // may fire erroneously. - Debug.Assert(NetEventSource.Log.IsEnabled()); - - NetEventSource.DumpBuffer(this, _wsaBuffer->Pointer, Math.Min(_wsaBuffer->Length, size)); - } - } -} diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/ReceiveMessageOverlappedAsyncResult.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/ReceiveMessageOverlappedAsyncResult.cs deleted file mode 100644 index 7f330a6f726bc..0000000000000 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/ReceiveMessageOverlappedAsyncResult.cs +++ /dev/null @@ -1,58 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -namespace System.Net.Sockets -{ - internal unsafe sealed partial class ReceiveMessageOverlappedAsyncResult : BaseOverlappedAsyncResult - { - private Internals.SocketAddress? _socketAddressOriginal; - private Internals.SocketAddress? _socketAddress; - - private SocketFlags _socketFlags; - private IPPacketInformation _ipPacketInformation; - - internal ReceiveMessageOverlappedAsyncResult(Socket socket, object? asyncState, AsyncCallback? asyncCallback) : - base(socket, asyncState, asyncCallback) - { } - - internal Internals.SocketAddress? SocketAddress - { - get - { - return _socketAddress; - } - set - { - _socketAddress = value; - } - } - - internal Internals.SocketAddress? SocketAddressOriginal - { - get - { - return _socketAddressOriginal; - } - set - { - _socketAddressOriginal = value; - } - } - - internal SocketFlags SocketFlags - { - get - { - return _socketFlags; - } - } - - internal IPPacketInformation IPPacketInformation - { - get - { - return _ipPacketInformation; - } - } - } -} diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Tasks.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Tasks.cs index f5432d56e0009..461122cfd09a6 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Tasks.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Tasks.cs @@ -350,18 +350,7 @@ public Task ReceiveFromAsync(ArraySegment buffer, /// An asynchronous task that completes with a containing the number of bytes received and the endpoint of the sending host. public ValueTask ReceiveFromAsync(Memory buffer, SocketFlags socketFlags, EndPoint remoteEndPoint, CancellationToken cancellationToken = default) { - if (remoteEndPoint is null) - { - throw new ArgumentNullException(nameof(remoteEndPoint)); - } - if (!CanTryAddressFamily(remoteEndPoint.AddressFamily)) - { - throw new ArgumentException(SR.Format(SR.net_InvalidEndPointAddressFamily, remoteEndPoint.AddressFamily, _addressFamily), nameof(remoteEndPoint)); - } - if (_rightEndPoint == null) - { - throw new InvalidOperationException(SR.net_sockets_mustbind); - } + ValidateReceiveFromEndpointAndState(remoteEndPoint, nameof(remoteEndPoint)); if (cancellationToken.IsCancellationRequested) { @@ -403,18 +392,7 @@ public Task ReceiveMessageFromAsync(ArraySegment /// An asynchronous task that completes with a containing the number of bytes received and additional information about the sending host. public ValueTask ReceiveMessageFromAsync(Memory buffer, SocketFlags socketFlags, EndPoint remoteEndPoint, CancellationToken cancellationToken = default) { - if (remoteEndPoint is null) - { - throw new ArgumentNullException(nameof(remoteEndPoint)); - } - if (!CanTryAddressFamily(remoteEndPoint.AddressFamily)) - { - throw new ArgumentException(SR.Format(SR.net_InvalidEndPointAddressFamily, remoteEndPoint.AddressFamily, _addressFamily), nameof(remoteEndPoint)); - } - if (_rightEndPoint == null) - { - throw new InvalidOperationException(SR.net_sockets_mustbind); - } + ValidateReceiveFromEndpointAndState(remoteEndPoint, nameof(remoteEndPoint)); if (cancellationToken.IsCancellationRequested) { return ValueTask.FromCanceled(cancellationToken); diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs index d82934aadc5cd..eb5277e8e01a2 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs @@ -1540,18 +1540,7 @@ public int ReceiveMessageFrom(byte[] buffer, int offset, int size, ref SocketFla { ThrowIfDisposed(); ValidateBufferArguments(buffer, offset, size); - if (remoteEP == null) - { - throw new ArgumentNullException(nameof(remoteEP)); - } - if (!CanTryAddressFamily(remoteEP.AddressFamily)) - { - throw new ArgumentException(SR.Format(SR.net_InvalidEndPointAddressFamily, remoteEP.AddressFamily, _addressFamily), nameof(remoteEP)); - } - if (_rightEndPoint == null) - { - throw new InvalidOperationException(SR.net_sockets_mustbind); - } + ValidateReceiveFromEndpointAndState(remoteEP, nameof(remoteEP)); SocketPal.CheckDualModeReceiveSupport(this); ValidateBlockingMode(); @@ -1703,19 +1692,7 @@ public int ReceiveFrom(byte[] buffer, int offset, int size, SocketFlags socketFl { ThrowIfDisposed(); ValidateBufferArguments(buffer, offset, size); - if (remoteEP == null) - { - throw new ArgumentNullException(nameof(remoteEP)); - } - if (!CanTryAddressFamily(remoteEP.AddressFamily)) - { - throw new ArgumentException(SR.Format(SR.net_InvalidEndPointAddressFamily, - remoteEP.AddressFamily, _addressFamily), nameof(remoteEP)); - } - if (_rightEndPoint == null) - { - throw new InvalidOperationException(SR.net_sockets_mustbind); - } + ValidateReceiveFromEndpointAndState(remoteEP, nameof(remoteEP)); SocketPal.CheckDualModeReceiveSupport(this); @@ -2324,26 +2301,6 @@ public void EndSendFile(IAsyncResult asyncResult) EndSendFileInternal(asyncResult); } - // Routine Description: - // - // BeginSendTo - Async implementation of SendTo, - // - // This routine may go pending at which time, - // but any case the callback Delegate will be called upon completion - // - // Arguments: - // - // WriteBuffer - Buffer to transmit - // Index - Offset into WriteBuffer to begin sending from - // Size - Size of Buffer to transmit - // Flags - Specific Socket flags to pass to winsock - // remoteEP - EndPoint to transmit To - // Callback - Delegate function that holds callback, called on completion of I/O - // State - State used to track callback, set by caller, not required - // - // Return Value: - // - // IAsyncResult - Async result used to retrieve result public IAsyncResult BeginSendTo(byte[] buffer, int offset, int size, SocketFlags socketFlags, EndPoint remoteEP, AsyncCallback? callback, object? state) { ThrowIfDisposed(); @@ -2353,113 +2310,14 @@ public IAsyncResult BeginSendTo(byte[] buffer, int offset, int size, SocketFlags throw new ArgumentNullException(nameof(remoteEP)); } - Internals.SocketAddress socketAddress = Serialize(ref remoteEP); - - // Set up the async result and indicate to flow the context. - OverlappedAsyncResult asyncResult = new OverlappedAsyncResult(this, state, callback); - asyncResult.StartPostingAsyncOp(false); - - // Post the send. - DoBeginSendTo(buffer, offset, size, socketFlags, remoteEP, socketAddress, asyncResult); - - // Finish, possibly posting the callback. The callback won't be posted before this point is reached. - asyncResult.FinishPostingAsyncOp(ref Caches.SendClosureCache); - - return asyncResult; - } - - private void DoBeginSendTo(byte[] buffer, int offset, int size, SocketFlags socketFlags, EndPoint endPointSnapshot, Internals.SocketAddress socketAddress, OverlappedAsyncResult asyncResult) - { - if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(this, $"size:{size}"); - - EndPoint? oldEndPoint = _rightEndPoint; - - // Guarantee to call CheckAsyncCallOverlappedResult if we call SetUnamangedStructures with a cache in order to - // avoid a Socket leak in case of error. - SocketError errorCode = SocketError.SocketError; - try - { - if (_rightEndPoint == null) - { - _rightEndPoint = endPointSnapshot; - } - - errorCode = SocketPal.SendToAsync(_handle, buffer, offset, size, socketFlags, socketAddress, asyncResult); - - if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(this, $"SendToAsync returns:{errorCode} size:{size} returning AsyncResult:{asyncResult}"); - } - catch (ObjectDisposedException) - { - _rightEndPoint = oldEndPoint; - _localEndPoint = null; - throw; - } - - // Throw an appropriate SocketException if the native call fails synchronously. - if (!CheckErrorAndUpdateStatus(errorCode)) - { - UpdateSendSocketErrorForDisposed(ref errorCode); - // Update the internal state of this socket according to the error before throwing. - _rightEndPoint = oldEndPoint; - _localEndPoint = null; - - throw new SocketException((int)errorCode); - } - - if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(this, $"size:{size} returning AsyncResult:{asyncResult}"); + Task t = SendToAsync(buffer.AsMemory(offset, size), socketFlags, remoteEP).AsTask(); + return TaskToApm.Begin(t, callback, state); } - // Routine Description: - // - // EndSendTo - Called by user code after I/O is done or the user wants to wait. - // until Async completion, needed to retrieve error result from call - // - // Arguments: - // - // AsyncResult - the AsyncResult Returned from BeginSend call - // - // Return Value: - // - // int - Number of bytes transferred public int EndSendTo(IAsyncResult asyncResult) { ThrowIfDisposed(); - - // Validate input parameters. - if (asyncResult == null) - { - throw new ArgumentNullException(nameof(asyncResult)); - } - - OverlappedAsyncResult? castedAsyncResult = asyncResult as OverlappedAsyncResult; - if (castedAsyncResult == null || castedAsyncResult.AsyncObject != this) - { - throw new ArgumentException(SR.net_io_invalidasyncresult, nameof(asyncResult)); - } - if (castedAsyncResult.EndCalled) - { - throw new InvalidOperationException(SR.Format(SR.net_io_invalidendcall, "EndSendTo")); - } - - int bytesTransferred = castedAsyncResult.InternalWaitForCompletionInt32Result(); - castedAsyncResult.EndCalled = true; - - if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(this, $"bytesTransferred:{bytesTransferred}"); - - // Throw an appropriate SocketException if the native call failed asynchronously. - SocketError errorCode = (SocketError)castedAsyncResult.ErrorCode; - if (errorCode != SocketError.Success) - { - UpdateSendSocketErrorForDisposed(ref errorCode); - UpdateStatusAfterSocketErrorAndThrowException(errorCode); - } - else if (SocketsTelemetry.Log.IsEnabled()) - { - SocketsTelemetry.Log.BytesSent(bytesTransferred); - if (SocketType == SocketType.Dgram) SocketsTelemetry.Log.DatagramSent(); - } - - return bytesTransferred; + return TaskToApm.End(asyncResult); } public IAsyncResult BeginReceive(byte[] buffer, int offset, int size, SocketFlags socketFlags, AsyncCallback? callback, object? state) @@ -2546,97 +2404,17 @@ public IAsyncResult BeginReceiveMessageFrom(byte[] buffer, int offset, int size, ThrowIfDisposed(); ValidateBufferArguments(buffer, offset, size); - if (remoteEP == null) - { - throw new ArgumentNullException(nameof(remoteEP)); - } - if (!CanTryAddressFamily(remoteEP.AddressFamily)) - { - throw new ArgumentException(SR.Format(SR.net_InvalidEndPointAddressFamily, remoteEP.AddressFamily, _addressFamily), nameof(remoteEP)); - } - if (_rightEndPoint == null) - { - throw new InvalidOperationException(SR.net_sockets_mustbind); - } + ValidateReceiveFromEndpointAndState(remoteEP, nameof(remoteEP)); - SocketPal.CheckDualModeReceiveSupport(this); - - // Set up the result and set it to collect the context. - ReceiveMessageOverlappedAsyncResult asyncResult = new ReceiveMessageOverlappedAsyncResult(this, state, callback); - asyncResult.StartPostingAsyncOp(false); - - // Start the ReceiveFrom. - EndPoint oldEndPoint = _rightEndPoint; - - // We don't do a CAS demand here because the contents of remoteEP aren't used by - // WSARecvMsg; all that matters is that we generate a unique-to-this-call SocketAddress - // with the right address family - Internals.SocketAddress socketAddress = Serialize(ref remoteEP); - - // Guarantee to call CheckAsyncCallOverlappedResult if we call SetUnamangedStructures with a cache in order to - // avoid a Socket leak in case of error. - SocketError errorCode = SocketError.SocketError; - try + Task t = ReceiveMessageFromAsync(buffer.AsMemory(offset, size), socketFlags, remoteEP).AsTask(); + // In case of synchronous completion, ReceiveMessageFromAsync() returns a completed task. + // When this happens, we need to update 'remoteEP' in order to conform to the historical behavior of BeginReceiveMessageFrom(). + if (t.IsCompletedSuccessfully) { - // Save a copy of the original EndPoint in the asyncResult. - asyncResult.SocketAddressOriginal = IPEndPointExtensions.Serialize(remoteEP); - - SetReceivingPacketInformation(); - - if (_rightEndPoint == null) - { - _rightEndPoint = remoteEP; - } - - errorCode = SocketPal.ReceiveMessageFromAsync(this, _handle, buffer, offset, size, socketFlags, socketAddress, asyncResult); - - if (errorCode != SocketError.Success) - { - // WSARecvMsg() will never return WSAEMSGSIZE directly, since a completion is queued in this case. We wouldn't be able - // to handle this easily because of assumptions OverlappedAsyncResult makes about whether there would be a completion - // or not depending on the error code. If WSAEMSGSIZE would have been normally returned, it returns WSA_IO_PENDING instead. - // That same map is implemented here just in case. - if (errorCode == SocketError.MessageSize) - { - Debug.Fail("Returned WSAEMSGSIZE!"); - errorCode = SocketError.IOPending; - } - } - - if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(this, $"ReceiveMessageFromAsync returns:{errorCode} size:{size} returning AsyncResult:{asyncResult}"); + EndPoint resultEp = t.Result.RemoteEndPoint; + if (!remoteEP.Equals(resultEp)) remoteEP = resultEp; } - catch (ObjectDisposedException) - { - _rightEndPoint = oldEndPoint; - _localEndPoint = null; - throw; - } - - // Throw an appropriate SocketException if the native call fails synchronously. - UpdateReceiveSocketErrorForDisposed(ref errorCode, bytesTransferred: 0); - if (!CheckErrorAndUpdateStatus(errorCode)) - { - // Update the internal state of this socket according to the error before throwing. - _rightEndPoint = oldEndPoint; - _localEndPoint = null; - - throw new SocketException((int)errorCode); - } - - // Capture the context, maybe call the callback, and return. - asyncResult.FinishPostingAsyncOp(ref Caches.ReceiveClosureCache); - - if (asyncResult.CompletedSynchronously && !asyncResult.SocketAddressOriginal.Equals(asyncResult.SocketAddress)) - { - try - { - remoteEP = remoteEP.Create(asyncResult.SocketAddress!); - } - catch - { - } - } - + IAsyncResult asyncResult = TaskToApm.Begin(t, callback, state); if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(this, $"size:{size} returning AsyncResult:{asyncResult}"); return asyncResult; } @@ -2652,190 +2430,35 @@ public int EndReceiveMessageFrom(IAsyncResult asyncResult, ref SocketFlags socke { throw new ArgumentException(SR.Format(SR.net_InvalidEndPointAddressFamily, endPoint.AddressFamily, _addressFamily), nameof(endPoint)); } - if (asyncResult == null) - { - throw new ArgumentNullException(nameof(asyncResult)); - } - - ReceiveMessageOverlappedAsyncResult? castedAsyncResult = asyncResult as ReceiveMessageOverlappedAsyncResult; - if (castedAsyncResult == null || castedAsyncResult.AsyncObject != this) - { - throw new ArgumentException(SR.net_io_invalidasyncresult, nameof(asyncResult)); - } - if (castedAsyncResult.EndCalled) - { - throw new InvalidOperationException(SR.Format(SR.net_io_invalidendcall, "EndReceiveMessageFrom")); - } - - Internals.SocketAddress socketAddressOriginal = Serialize(ref endPoint); - - int bytesTransferred = castedAsyncResult.InternalWaitForCompletionInt32Result(); - castedAsyncResult.EndCalled = true; - - // Update socket address size. - castedAsyncResult.SocketAddress!.InternalSize = castedAsyncResult.GetSocketAddressSize(); - - if (!socketAddressOriginal.Equals(castedAsyncResult.SocketAddress)) - { - try - { - endPoint = endPoint.Create(castedAsyncResult.SocketAddress); - } - catch - { - } - } - - if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(this, $"bytesTransferred:{bytesTransferred}"); - SocketError errorCode = (SocketError)castedAsyncResult.ErrorCode; - UpdateReceiveSocketErrorForDisposed(ref errorCode, bytesTransferred); - // Throw an appropriate SocketException if the native call failed asynchronously. - if (errorCode != SocketError.Success && errorCode != SocketError.MessageSize) + SocketReceiveMessageFromResult result = TaskToApm.End(asyncResult); + if (!endPoint.Equals(result.RemoteEndPoint)) { - UpdateStatusAfterSocketErrorAndThrowException(errorCode); + endPoint = result.RemoteEndPoint; } - else if (SocketsTelemetry.Log.IsEnabled()) - { - SocketsTelemetry.Log.BytesReceived(bytesTransferred); - if (errorCode == SocketError.Success && SocketType == SocketType.Dgram) SocketsTelemetry.Log.DatagramReceived(); - } - - socketFlags = castedAsyncResult.SocketFlags; - ipPacketInformation = castedAsyncResult.IPPacketInformation; - - return bytesTransferred; + socketFlags = result.SocketFlags; + ipPacketInformation = result.PacketInformation; + return result.ReceivedBytes; } - // Routine Description: - // - // BeginReceiveFrom - Async implementation of RecvFrom call, - // - // Called when we want to start an async receive. - // We kick off the receive, and if it completes synchronously we'll - // call the callback. Otherwise we'll return an IASyncResult, which - // the caller can use to wait on or retrieve the final status, as needed. - // - // Uses Winsock 2 overlapped I/O. - // - // Arguments: - // - // ReadBuffer - status line that we wish to parse - // Index - Offset into ReadBuffer to begin reading from - // Request - Size of Buffer to recv - // Flags - Additional Flags that may be passed to the underlying winsock call - // remoteEP - EndPoint that are to receive from - // Callback - Delegate function that holds callback, called on completion of I/O - // State - State used to track callback, set by caller, not required - // - // Return Value: - // - // IAsyncResult - Async result used to retrieve result public IAsyncResult BeginReceiveFrom(byte[] buffer, int offset, int size, SocketFlags socketFlags, ref EndPoint remoteEP, AsyncCallback? callback, object? state) { ThrowIfDisposed(); ValidateBufferArguments(buffer, offset, size); - if (remoteEP == null) - { - throw new ArgumentNullException(nameof(remoteEP)); - } - if (!CanTryAddressFamily(remoteEP.AddressFamily)) - { - throw new ArgumentException(SR.Format(SR.net_InvalidEndPointAddressFamily, remoteEP.AddressFamily, _addressFamily), nameof(remoteEP)); - } - if (_rightEndPoint == null) - { - throw new InvalidOperationException(SR.net_sockets_mustbind); - } - - SocketPal.CheckDualModeReceiveSupport(this); - - // We don't do a CAS demand here because the contents of remoteEP aren't used by - // WSARecvFrom; all that matters is that we generate a unique-to-this-call SocketAddress - // with the right address family - Internals.SocketAddress socketAddress = Serialize(ref remoteEP); - - // Set up the result and set it to collect the context. - var asyncResult = new OriginalAddressOverlappedAsyncResult(this, state, callback); - asyncResult.StartPostingAsyncOp(false); - - // Start the ReceiveFrom. - DoBeginReceiveFrom(buffer, offset, size, socketFlags, remoteEP, socketAddress, asyncResult); - - // Capture the context, maybe call the callback, and return. - asyncResult.FinishPostingAsyncOp(ref Caches.ReceiveClosureCache); - - if (asyncResult.CompletedSynchronously && !asyncResult.SocketAddressOriginal!.Equals(asyncResult.SocketAddress)) - { - try - { - remoteEP = remoteEP.Create(asyncResult.SocketAddress!); - } - catch - { - } - } - - return asyncResult; - } - - private void DoBeginReceiveFrom(byte[] buffer, int offset, int size, SocketFlags socketFlags, EndPoint endPointSnapshot, Internals.SocketAddress socketAddress, OriginalAddressOverlappedAsyncResult asyncResult) - { - EndPoint? oldEndPoint = _rightEndPoint; - - if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(this, $"size:{size}"); - - // Guarantee to call CheckAsyncCallOverlappedResult if we call SetUnamangedStructures with a cache in order to - // avoid a Socket leak in case of error. - SocketError errorCode = SocketError.SocketError; - try - { - // Save a copy of the original EndPoint in the asyncResult. - asyncResult.SocketAddressOriginal = IPEndPointExtensions.Serialize(endPointSnapshot); - - if (_rightEndPoint == null) - { - _rightEndPoint = endPointSnapshot; - } - - errorCode = SocketPal.ReceiveFromAsync(_handle, buffer, offset, size, socketFlags, socketAddress, asyncResult); - - if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(this, $"ReceiveFromAsync returns:{errorCode} size:{size} returning AsyncResult:{asyncResult}"); - } - catch (ObjectDisposedException) - { - _rightEndPoint = oldEndPoint; - _localEndPoint = null; - throw; - } + ValidateReceiveFromEndpointAndState(remoteEP, nameof(remoteEP)); - // Throw an appropriate SocketException if the native call fails synchronously. - UpdateReceiveSocketErrorForDisposed(ref errorCode, bytesTransferred: 0); - if (!CheckErrorAndUpdateStatus(errorCode)) + Task t = ReceiveFromAsync(buffer.AsMemory(offset, size), socketFlags, remoteEP).AsTask(); + // In case of synchronous completion, ReceiveFromAsync() returns a completed task. + // When this happens, we need to update 'remoteEP' in order to conform to the historical behavior of BeginReceiveFrom(). + if (t.IsCompletedSuccessfully) { - // Update the internal state of this socket according to the error before throwing. - _rightEndPoint = oldEndPoint; - _localEndPoint = null; - - throw new SocketException((int)errorCode); + EndPoint resultEp = t.Result.RemoteEndPoint; + if (!remoteEP.Equals(resultEp)) remoteEP = resultEp; } - if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(this, $"size:{size} return AsyncResult:{asyncResult}"); + return TaskToApm.Begin(t, callback, state); } - // Routine Description: - // - // EndReceiveFrom - Called when I/O is done or the user wants to wait. If - // the I/O isn't done, we'll wait for it to complete, and then we'll return - // the bytes of I/O done. - // - // Arguments: - // - // AsyncResult - the AsyncResult Returned from BeginReceiveFrom call - // - // Return Value: - // - // int - Number of bytes transferred public int EndReceiveFrom(IAsyncResult asyncResult, ref EndPoint endPoint) { ThrowIfDisposed(); @@ -2849,55 +2472,13 @@ public int EndReceiveFrom(IAsyncResult asyncResult, ref EndPoint endPoint) { throw new ArgumentException(SR.Format(SR.net_InvalidEndPointAddressFamily, endPoint.AddressFamily, _addressFamily), nameof(endPoint)); } - if (asyncResult == null) - { - throw new ArgumentNullException(nameof(asyncResult)); - } - OverlappedAsyncResult? castedAsyncResult = asyncResult as OverlappedAsyncResult; - if (castedAsyncResult == null || castedAsyncResult.AsyncObject != this) + SocketReceiveFromResult result = TaskToApm.End(asyncResult); + if (!endPoint.Equals(result.RemoteEndPoint)) { - throw new ArgumentException(SR.net_io_invalidasyncresult, nameof(asyncResult)); - } - if (castedAsyncResult.EndCalled) - { - throw new InvalidOperationException(SR.Format(SR.net_io_invalidendcall, "EndReceiveFrom")); - } - - Internals.SocketAddress socketAddressOriginal = Serialize(ref endPoint); - - int bytesTransferred = castedAsyncResult.InternalWaitForCompletionInt32Result(); - castedAsyncResult.EndCalled = true; - - // Update socket address size. - castedAsyncResult.SocketAddress!.InternalSize = castedAsyncResult.GetSocketAddressSize(); - - if (!socketAddressOriginal.Equals(castedAsyncResult.SocketAddress)) - { - try - { - endPoint = endPoint.Create(castedAsyncResult.SocketAddress); - } - catch - { - } + endPoint = result.RemoteEndPoint; } - - if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(this, $"bytesTransferred:{bytesTransferred}"); - - // Throw an appropriate SocketException if the native call failed asynchronously. - SocketError errorCode = (SocketError)castedAsyncResult.ErrorCode; - UpdateReceiveSocketErrorForDisposed(ref errorCode, bytesTransferred); - if (errorCode != SocketError.Success) - { - UpdateStatusAfterSocketErrorAndThrowException(errorCode); - } - else if (SocketsTelemetry.Log.IsEnabled()) - { - SocketsTelemetry.Log.BytesReceived(bytesTransferred); - if (SocketType == SocketType.Dgram) SocketsTelemetry.Log.DatagramReceived(); - } - return bytesTransferred; + return result.ReceivedBytes; } // Routine Description: @@ -4185,6 +3766,24 @@ private bool CheckErrorAndUpdateStatus(SocketError errorCode) return false; } + // Called in Receive(Message)From variants to validate 'remoteEndPoint', + // and check whether the socket is bound. + private void ValidateReceiveFromEndpointAndState(EndPoint remoteEndPoint, string remoteEndPointArgumentName) + { + if (remoteEndPoint == null) + { + throw new ArgumentNullException(remoteEndPointArgumentName); + } + if (!CanTryAddressFamily(remoteEndPoint.AddressFamily)) + { + throw new ArgumentException(SR.Format(SR.net_InvalidEndPointAddressFamily, remoteEndPoint.AddressFamily, _addressFamily), remoteEndPointArgumentName); + } + if (_rightEndPoint == null) + { + throw new InvalidOperationException(SR.net_sockets_mustbind); + } + } + // ValidateBlockingMode - called before synchronous calls to validate // the fact that we are in blocking mode (not in non-blocking mode) so the // call will actually be synchronous. diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Unix.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Unix.cs index 8b1f87071194a..6f800064abad7 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Unix.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Unix.cs @@ -1884,54 +1884,6 @@ public static async void SendPacketsAsync( } } - public static SocketError SendToAsync(SafeSocketHandle handle, byte[] buffer, int offset, int count, SocketFlags socketFlags, Internals.SocketAddress socketAddress, OverlappedAsyncResult asyncResult) - { - asyncResult.SocketAddress = socketAddress; - - int bytesSent; - int socketAddressLen = socketAddress.Size; - SocketError socketError = handle.AsyncContext.SendToAsync(buffer, offset, count, socketFlags, socketAddress.Buffer, ref socketAddressLen, out bytesSent, asyncResult.CompletionCallback); - if (socketError == SocketError.Success) - { - asyncResult.CompletionCallback(bytesSent, socketAddress.Buffer, socketAddressLen, SocketFlags.None, SocketError.Success); - } - return socketError; - } - - public static SocketError ReceiveFromAsync(SafeSocketHandle handle, byte[] buffer, int offset, int count, SocketFlags socketFlags, Internals.SocketAddress socketAddress, OverlappedAsyncResult asyncResult) - { - asyncResult.SocketAddress = socketAddress; - - int socketAddressSize = socketAddress.InternalSize; - int bytesReceived; - SocketFlags receivedFlags; - SocketError socketError = handle.AsyncContext.ReceiveFromAsync(new Memory(buffer, offset, count), socketFlags, socketAddress.Buffer, ref socketAddressSize, out bytesReceived, out receivedFlags, asyncResult.CompletionCallback); - if (socketError == SocketError.Success) - { - asyncResult.CompletionCallback(bytesReceived, socketAddress.Buffer, socketAddressSize, receivedFlags, SocketError.Success); - } - return socketError; - } - - public static SocketError ReceiveMessageFromAsync(Socket socket, SafeSocketHandle handle, byte[] buffer, int offset, int count, SocketFlags socketFlags, Internals.SocketAddress socketAddress, ReceiveMessageOverlappedAsyncResult asyncResult) - { - asyncResult.SocketAddress = socketAddress; - - bool isIPv4, isIPv6; - Socket.GetIPProtocolInformation(((Socket)asyncResult.AsyncObject!).AddressFamily, socketAddress, out isIPv4, out isIPv6); - - int socketAddressSize = socketAddress.InternalSize; - int bytesReceived; - SocketFlags receivedFlags; - IPPacketInformation ipPacketInformation; - SocketError socketError = handle.AsyncContext.ReceiveMessageFromAsync(new Memory(buffer, offset, count), null, socketFlags, socketAddress.Buffer, ref socketAddressSize, isIPv4, isIPv6, out bytesReceived, out receivedFlags, out ipPacketInformation, asyncResult.CompletionCallback); - if (socketError == SocketError.Success) - { - asyncResult.CompletionCallback(bytesReceived, socketAddress.Buffer, socketAddressSize, receivedFlags, ipPacketInformation, SocketError.Success); - } - return socketError; - } - public static SocketError AcceptAsync(Socket socket, SafeSocketHandle handle, SafeSocketHandle? acceptHandle, int receiveSize, int socketAddressSize, AcceptOverlappedAsyncResult asyncResult) { Debug.Assert(acceptHandle == null, $"Unexpected acceptHandle: {acceptHandle}"); diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Windows.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Windows.cs index 09858566bc28e..cf6e12e66dbba 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Windows.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Windows.cs @@ -1099,107 +1099,6 @@ public static unsafe SocketError SendFileAsync(SafeSocketHandle handle, FileStre } } - public static unsafe SocketError SendToAsync(SafeSocketHandle handle, byte[] buffer, int offset, int count, SocketFlags socketFlags, Internals.SocketAddress socketAddress, OverlappedAsyncResult asyncResult) - { - // Set up asyncResult for overlapped WSASendTo. - asyncResult.SetUnmanagedStructures(buffer, offset, count, socketAddress); - try - { - int bytesTransferred; - SocketError errorCode = Interop.Winsock.WSASendTo( - handle, - ref asyncResult._singleBuffer, - 1, // There is only ever 1 buffer being sent. - out bytesTransferred, - socketFlags, - asyncResult.GetSocketAddressPtr(), - asyncResult.SocketAddress!.Size, - asyncResult.DangerousOverlappedPointer, // SafeHandle was just created in SetUnmanagedStructures - IntPtr.Zero); - - return asyncResult.ProcessOverlappedResult(errorCode == SocketError.Success, bytesTransferred); - } - catch - { - asyncResult.ReleaseUnmanagedStructures(); - throw; - } - } - - public static unsafe SocketError ReceiveAsync(SafeSocketHandle handle, IList> buffers, SocketFlags socketFlags, OverlappedAsyncResult asyncResult) - { - // Set up asyncResult for overlapped WSASend. - asyncResult.SetUnmanagedStructures(buffers); - try - { - int bytesTransferred; - SocketError errorCode = Interop.Winsock.WSARecv( - handle, - asyncResult._wsaBuffers, - asyncResult._wsaBuffers!.Length, - out bytesTransferred, - ref socketFlags, - asyncResult.DangerousOverlappedPointer, // SafeHandle was just created in SetUnmanagedStructures - IntPtr.Zero); - - return asyncResult.ProcessOverlappedResult(errorCode == SocketError.Success, bytesTransferred); - } - catch - { - asyncResult.ReleaseUnmanagedStructures(); - throw; - } - } - - public static unsafe SocketError ReceiveFromAsync(SafeSocketHandle handle, byte[] buffer, int offset, int count, SocketFlags socketFlags, Internals.SocketAddress socketAddress, OverlappedAsyncResult asyncResult) - { - // Set up asyncResult for overlapped WSARecvFrom. - asyncResult.SetUnmanagedStructures(buffer, offset, count, socketAddress); - try - { - int bytesTransferred; - SocketError errorCode = Interop.Winsock.WSARecvFrom( - handle, - ref asyncResult._singleBuffer, - 1, - out bytesTransferred, - ref socketFlags, - asyncResult.GetSocketAddressPtr(), - asyncResult.GetSocketAddressSizePtr(), - asyncResult.DangerousOverlappedPointer, // SafeHandle was just created in SetUnmanagedStructures - IntPtr.Zero); - - return asyncResult.ProcessOverlappedResult(errorCode == SocketError.Success, bytesTransferred); - } - catch - { - asyncResult.ReleaseUnmanagedStructures(); - throw; - } - } - - public static unsafe SocketError ReceiveMessageFromAsync(Socket socket, SafeSocketHandle handle, byte[] buffer, int offset, int count, SocketFlags socketFlags, Internals.SocketAddress socketAddress, ReceiveMessageOverlappedAsyncResult asyncResult) - { - asyncResult.SetUnmanagedStructures(buffer, offset, count, socketAddress, socketFlags); - try - { - int bytesTransfered; - SocketError errorCode = (SocketError)socket.WSARecvMsg( - handle, - Marshal.UnsafeAddrOfPinnedArrayElement(asyncResult._messageBuffer!, 0), - out bytesTransfered, - asyncResult.DangerousOverlappedPointer, // SafeHandle was just created in SetUnmanagedStructures - IntPtr.Zero); - - return asyncResult.ProcessOverlappedResult(errorCode == SocketError.Success, bytesTransfered); - } - catch - { - asyncResult.ReleaseUnmanagedStructures(); - throw; - } - } - public static unsafe SocketError AcceptAsync(Socket socket, SafeSocketHandle handle, SafeSocketHandle acceptHandle, int receiveSize, int socketAddressSize, AcceptOverlappedAsyncResult asyncResult) { // The buffer needs to contain the requested data plus room for two sockaddrs and 16 bytes diff --git a/src/libraries/System.Net.Sockets/tests/FunctionalTests/ArgumentValidationTests.cs b/src/libraries/System.Net.Sockets/tests/FunctionalTests/ArgumentValidationTests.cs index 470ff065748bf..7cbb830e2f528 100644 --- a/src/libraries/System.Net.Sockets/tests/FunctionalTests/ArgumentValidationTests.cs +++ b/src/libraries/System.Net.Sockets/tests/FunctionalTests/ArgumentValidationTests.cs @@ -263,35 +263,6 @@ public void Send_Buffers_EmptyBuffers_Throws_Argument() AssertExtensions.Throws("buffers", () => GetSocket().Send(new List>(), SocketFlags.None, out errorCode)); } - [Fact] - public void SendTo_NullBuffer_Throws_ArgumentNull() - { - Assert.Throws(() => GetSocket().SendTo(null, 0, 0, SocketFlags.None, new IPEndPoint(IPAddress.Loopback, 1))); - } - - [Fact] - public void SendTo_NullEndPoint_Throws_ArgumentNull() - { - Assert.Throws(() => GetSocket().SendTo(s_buffer, 0, 0, SocketFlags.None, null)); - } - - [Fact] - public void SendTo_InvalidOffset_Throws_ArgumentOutOfRange() - { - EndPoint endpoint = new IPEndPoint(IPAddress.Loopback, 1); - Assert.Throws(() => GetSocket().SendTo(s_buffer, -1, s_buffer.Length, SocketFlags.None, endpoint)); - Assert.Throws(() => GetSocket().SendTo(s_buffer, s_buffer.Length + 1, s_buffer.Length, SocketFlags.None, endpoint)); - } - - [Fact] - public void SendTo_InvalidSize_Throws_ArgumentOutOfRange() - { - EndPoint endpoint = new IPEndPoint(IPAddress.Loopback, 1); - Assert.Throws(() => GetSocket().SendTo(s_buffer, 0, -1, SocketFlags.None, endpoint)); - Assert.Throws(() => GetSocket().SendTo(s_buffer, 0, s_buffer.Length + 1, SocketFlags.None, endpoint)); - Assert.Throws(() => GetSocket().SendTo(s_buffer, s_buffer.Length, 1, SocketFlags.None, endpoint)); - } - [Fact] public void Receive_Buffer_NullBuffer_Throws_ArgumentNull() { @@ -330,114 +301,6 @@ public void Receive_Buffers_EmptyBuffers_Throws_Argument() AssertExtensions.Throws("buffers", () => GetSocket().Receive(new List>(), SocketFlags.None, out errorCode)); } - [Fact] - public void ReceiveFrom_NullBuffer_Throws_ArgumentNull() - { - EndPoint endpoint = new IPEndPoint(IPAddress.Loopback, 1); - Assert.Throws(() => GetSocket().ReceiveFrom(null, 0, 0, SocketFlags.None, ref endpoint)); - } - - [Fact] - public void ReceiveFrom_NullEndPoint_Throws_ArgumentNull() - { - EndPoint endpoint = null; - Assert.Throws(() => GetSocket().ReceiveFrom(s_buffer, 0, 0, SocketFlags.None, ref endpoint)); - } - - [Fact] - public void ReceiveFrom_AddressFamily_Throws_Argument() - { - EndPoint endpoint = new IPEndPoint(IPAddress.IPv6Loopback, 1); - AssertExtensions.Throws("remoteEP", () => GetSocket(AddressFamily.InterNetwork).ReceiveFrom(s_buffer, 0, 0, SocketFlags.None, ref endpoint)); - } - - [Fact] - public void ReceiveFrom_InvalidOffset_Throws_ArgumentOutOfRange() - { - EndPoint endpoint = new IPEndPoint(IPAddress.Loopback, 1); - Assert.Throws(() => GetSocket().ReceiveFrom(s_buffer, -1, s_buffer.Length, SocketFlags.None, ref endpoint)); - Assert.Throws(() => GetSocket().ReceiveFrom(s_buffer, s_buffer.Length + 1, s_buffer.Length, SocketFlags.None, ref endpoint)); - } - - [Fact] - public void ReceiveFrom_InvalidSize_Throws_ArgumentOutOfRange() - { - EndPoint endpoint = new IPEndPoint(IPAddress.Loopback, 1); - Assert.Throws(() => GetSocket().ReceiveFrom(s_buffer, 0, -1, SocketFlags.None, ref endpoint)); - Assert.Throws(() => GetSocket().ReceiveFrom(s_buffer, 0, s_buffer.Length + 1, SocketFlags.None, ref endpoint)); - Assert.Throws(() => GetSocket().ReceiveFrom(s_buffer, s_buffer.Length, 1, SocketFlags.None, ref endpoint)); - } - - [Fact] - public void ReceiveFrom_NotBound_Throws_InvalidOperation() - { - EndPoint endpoint = new IPEndPoint(IPAddress.Loopback, 1); - Assert.Throws(() => GetSocket().ReceiveFrom(s_buffer, 0, 0, SocketFlags.None, ref endpoint)); - } - - [Fact] - public void ReceiveMessageFrom_NullBuffer_Throws_ArgumentNull() - { - SocketFlags flags = SocketFlags.None; - EndPoint remote = new IPEndPoint(IPAddress.Loopback, 1); - IPPacketInformation packetInfo; - - Assert.Throws(() => GetSocket().ReceiveMessageFrom(null, 0, 0, ref flags, ref remote, out packetInfo)); - } - - [Fact] - public void ReceiveMessageFrom_NullEndPoint_Throws_ArgumentNull() - { - SocketFlags flags = SocketFlags.None; - EndPoint remote = null; - IPPacketInformation packetInfo; - - Assert.Throws(() => GetSocket().ReceiveMessageFrom(s_buffer, 0, 0, ref flags, ref remote, out packetInfo)); - } - - [Fact] - public void ReceiveMessageFrom_AddressFamily_Throws_Argument() - { - SocketFlags flags = SocketFlags.None; - EndPoint remote = new IPEndPoint(IPAddress.IPv6Loopback, 1); - IPPacketInformation packetInfo; - - AssertExtensions.Throws("remoteEP", () => GetSocket(AddressFamily.InterNetwork).ReceiveMessageFrom(s_buffer, 0, 0, ref flags, ref remote, out packetInfo)); - } - - [Fact] - public void ReceiveMessageFrom_InvalidOffset_Throws_ArgumentOutOfRange() - { - SocketFlags flags = SocketFlags.None; - EndPoint remote = new IPEndPoint(IPAddress.Loopback, 1); - IPPacketInformation packetInfo; - - Assert.Throws(() => GetSocket().ReceiveMessageFrom(s_buffer, -1, s_buffer.Length, ref flags, ref remote, out packetInfo)); - Assert.Throws(() => GetSocket().ReceiveMessageFrom(s_buffer, s_buffer.Length + 1, s_buffer.Length, ref flags, ref remote, out packetInfo)); - } - - [Fact] - public void ReceiveMessageFrom_InvalidSize_Throws_ArgumentOutOfRange() - { - SocketFlags flags = SocketFlags.None; - EndPoint remote = new IPEndPoint(IPAddress.Loopback, 1); - IPPacketInformation packetInfo; - - Assert.Throws(() => GetSocket().ReceiveMessageFrom(s_buffer, 0, -1, ref flags, ref remote, out packetInfo)); - Assert.Throws(() => GetSocket().ReceiveMessageFrom(s_buffer, 0, s_buffer.Length + 1, ref flags, ref remote, out packetInfo)); - Assert.Throws(() => GetSocket().ReceiveMessageFrom(s_buffer, s_buffer.Length, 1, ref flags, ref remote, out packetInfo)); - } - - [Fact] - public void ReceiveMessageFrom_NotBound_Throws_InvalidOperation() - { - SocketFlags flags = SocketFlags.None; - EndPoint remote = new IPEndPoint(IPAddress.Loopback, 1); - IPPacketInformation packetInfo; - - Assert.Throws(() => GetSocket().ReceiveMessageFrom(s_buffer, 0, 0, ref flags, ref remote, out packetInfo)); - } - [Fact] public void SetSocketOption_Object_ObjectNull_Throws_ArgumentNull() { @@ -612,50 +475,6 @@ public void ReceiveAsync_NullAsyncEventArgs_Throws_ArgumentNull() Assert.Throws(() => GetSocket().ReceiveAsync(null)); } - [Fact] - public void ReceiveFromAsync_NullAsyncEventArgs_Throws_ArgumentNull() - { - Assert.Throws(() => GetSocket().ReceiveFromAsync(null)); - } - - [Fact] - public void ReceiveFromAsync_NullRemoteEndPoint_Throws_ArgumentException() - { - Assert.Throws("e", () => GetSocket().ReceiveFromAsync(s_eventArgs)); - } - - [Fact] - public void ReceiveFromAsync_AddressFamily_Throws_Argument() - { - var eventArgs = new SocketAsyncEventArgs { - RemoteEndPoint = new IPEndPoint(IPAddress.IPv6Loopback, 1) - }; - - AssertExtensions.Throws("e", () => GetSocket(AddressFamily.InterNetwork).ReceiveFromAsync(eventArgs)); - } - - [Fact] - public void ReceiveMessageFromAsync_NullAsyncEventArgs_Throws_ArgumentNull() - { - Assert.Throws(() => GetSocket().ReceiveMessageFromAsync(null)); - } - - [Fact] - public void ReceiveMessageFromAsync_NullRemoteEndPoint_Throws_ArgumentException() - { - Assert.Throws("e", () => GetSocket().ReceiveMessageFromAsync(s_eventArgs)); - } - - [Fact] - public void ReceiveMessageFromAsync_AddressFamily_Throws_Argument() - { - var eventArgs = new SocketAsyncEventArgs { - RemoteEndPoint = new IPEndPoint(IPAddress.IPv6Loopback, 1) - }; - - AssertExtensions.Throws("e", () => GetSocket(AddressFamily.InterNetwork).ReceiveMessageFromAsync(eventArgs)); - } - [Fact] public void SendAsync_NullAsyncEventArgs_Throws_ArgumentNull() { @@ -684,18 +503,6 @@ public void SendPacketsAsync_NotConnected_Throws_NotSupported() Assert.Throws(() => GetSocket().SendPacketsAsync(eventArgs)); } - [Fact] - public void SendToAsync_NullAsyncEventArgs_Throws_ArgumentNull() - { - Assert.Throws(() => GetSocket().SendToAsync(null)); - } - - [Fact] - public void SendToAsync_NullRemoteEndPoint_Throws_ArgumentNull() - { - Assert.Throws(() => GetSocket().SendToAsync(s_eventArgs)); - } - [Theory] [InlineData(true)] [InlineData(false)] @@ -1212,58 +1019,6 @@ public void EndSend_UnrelatedAsyncResult_Throws_Argument() AssertExtensions.Throws("asyncResult", () => GetSocket().EndSend(Task.CompletedTask)); } - [Fact] - public void BeginSendTo_NullBuffer_Throws_ArgumentNull() - { - Assert.Throws(() => GetSocket().BeginSendTo(null, 0, 0, SocketFlags.None, new IPEndPoint(IPAddress.Loopback, 1), TheAsyncCallback, null)); - Assert.Throws(() => { GetSocket().SendToAsync(new ArraySegment(null, 0, 0), SocketFlags.None, new IPEndPoint(IPAddress.Loopback, 1)); }); - } - - [Fact] - public void BeginSendTo_NullEndPoint_Throws_ArgumentNull() - { - Assert.Throws(() => GetSocket().BeginSendTo(s_buffer, 0, 0, SocketFlags.None, null, TheAsyncCallback, null)); - Assert.Throws(() => { GetSocket().SendToAsync(new ArraySegment(s_buffer, 0, 0), SocketFlags.None, null); }); - } - - [Fact] - public void BeginSendTo_InvalidOffset_Throws_ArgumentOutOfRange() - { - EndPoint endpoint = new IPEndPoint(IPAddress.Loopback, 1); - - Assert.Throws(() => GetSocket().BeginSendTo(s_buffer, -1, s_buffer.Length, SocketFlags.None, endpoint, TheAsyncCallback, null)); - Assert.Throws(() => GetSocket().BeginSendTo(s_buffer, s_buffer.Length + 1, s_buffer.Length, SocketFlags.None, endpoint, TheAsyncCallback, null)); - - Assert.Throws(() => { GetSocket().SendToAsync(new ArraySegment(s_buffer, -1, s_buffer.Length), SocketFlags.None, endpoint); }); - Assert.ThrowsAny(() => { GetSocket().SendToAsync(new ArraySegment(s_buffer, s_buffer.Length + 1, s_buffer.Length), SocketFlags.None, endpoint); }); - } - - [Fact] - public void BeginSendTo_InvalidSize_Throws_ArgumentOutOfRange() - { - EndPoint endpoint = new IPEndPoint(IPAddress.Loopback, 1); - - Assert.Throws(() => GetSocket().BeginSendTo(s_buffer, 0, -1, SocketFlags.None, endpoint, TheAsyncCallback, null)); - Assert.Throws(() => GetSocket().BeginSendTo(s_buffer, 0, s_buffer.Length + 1, SocketFlags.None, endpoint, TheAsyncCallback, null)); - Assert.Throws(() => GetSocket().BeginSendTo(s_buffer, s_buffer.Length, 1, SocketFlags.None, endpoint, TheAsyncCallback, null)); - - Assert.Throws(() => { GetSocket().SendToAsync(new ArraySegment(s_buffer, 0, -1), SocketFlags.None, endpoint); }); - Assert.ThrowsAny(() => { GetSocket().SendToAsync(new ArraySegment(s_buffer, 0, s_buffer.Length + 1), SocketFlags.None, endpoint); }); - Assert.ThrowsAny(() => { GetSocket().SendToAsync(new ArraySegment(s_buffer, s_buffer.Length, 1), SocketFlags.None, endpoint); }); - } - - [Fact] - public void EndSendTo_NullAsyncResult_Throws_ArgumentNull() - { - Assert.Throws(() => GetSocket().EndSendTo(null)); - } - - [Fact] - public void EndSendto_UnrelatedAsyncResult_Throws_Argument() - { - AssertExtensions.Throws("asyncResult", () => GetSocket().EndSendTo(Task.CompletedTask)); - } - [Fact] public void BeginReceive_Buffer_NullBuffer_Throws_ArgumentNull() { @@ -1313,163 +1068,6 @@ public void EndReceive_NullAsyncResult_Throws_ArgumentNull() Assert.Throws(() => GetSocket().EndReceive(null)); } - [Fact] - public void BeginReceiveFrom_NullBuffer_Throws_ArgumentNull() - { - EndPoint endpoint = new IPEndPoint(IPAddress.Loopback, 1); - Assert.Throws(() => GetSocket().BeginReceiveFrom(null, 0, 0, SocketFlags.None, ref endpoint, TheAsyncCallback, null)); - Assert.Throws(() => { GetSocket().ReceiveFromAsync(new ArraySegment(null, 0, 0), SocketFlags.None, endpoint); }); - } - - [Fact] - public void BeginReceiveFrom_NullEndPoint_Throws_ArgumentNull() - { - EndPoint endpoint = null; - Assert.Throws(() => GetSocket().BeginReceiveFrom(s_buffer, 0, 0, SocketFlags.None, ref endpoint, TheAsyncCallback, null)); - Assert.Throws(() => { GetSocket().ReceiveFromAsync(new ArraySegment(s_buffer, 0, 0), SocketFlags.None, endpoint); }); - } - - [Fact] - public void BeginReceiveFrom_AddressFamily_Throws_Argument() - { - EndPoint endpoint = new IPEndPoint(IPAddress.IPv6Loopback, 1); - AssertExtensions.Throws("remoteEP", () => GetSocket(AddressFamily.InterNetwork).BeginReceiveFrom(s_buffer, 0, 0, SocketFlags.None, ref endpoint, TheAsyncCallback, null)); - AssertExtensions.Throws("remoteEndPoint", () => { GetSocket(AddressFamily.InterNetwork).ReceiveFromAsync(new ArraySegment(s_buffer, 0, 0), SocketFlags.None, endpoint); }); - } - - [Fact] - public void BeginReceiveFrom_InvalidOffset_Throws_ArgumentOutOfRange() - { - EndPoint endpoint = new IPEndPoint(IPAddress.Loopback, 1); - - Assert.Throws(() => GetSocket().BeginReceiveFrom(s_buffer, -1, s_buffer.Length, SocketFlags.None, ref endpoint, TheAsyncCallback, null)); - Assert.Throws(() => GetSocket().BeginReceiveFrom(s_buffer, s_buffer.Length + 1, s_buffer.Length, SocketFlags.None, ref endpoint, TheAsyncCallback, null)); - - Assert.Throws(() => { GetSocket().ReceiveFromAsync(new ArraySegment(s_buffer, -1, s_buffer.Length), SocketFlags.None, endpoint); }); - Assert.ThrowsAny(() => { GetSocket().ReceiveFromAsync(new ArraySegment(s_buffer, s_buffer.Length + 1, s_buffer.Length), SocketFlags.None, endpoint); }); - } - - [Fact] - public void BeginReceiveFrom_InvalidSize_Throws_ArgumentOutOfRange() - { - EndPoint endpoint = new IPEndPoint(IPAddress.Loopback, 1); - - Assert.Throws(() => GetSocket().BeginReceiveFrom(s_buffer, 0, -1, SocketFlags.None, ref endpoint, TheAsyncCallback, null)); - Assert.Throws(() => GetSocket().BeginReceiveFrom(s_buffer, 0, s_buffer.Length + 1, SocketFlags.None, ref endpoint, TheAsyncCallback, null)); - Assert.Throws(() => GetSocket().BeginReceiveFrom(s_buffer, s_buffer.Length, 1, SocketFlags.None, ref endpoint, TheAsyncCallback, null)); - - Assert.Throws(() => { GetSocket().ReceiveFromAsync(new ArraySegment(s_buffer, 0, -1), SocketFlags.None, endpoint); }); - Assert.ThrowsAny(() => { GetSocket().ReceiveFromAsync(new ArraySegment(s_buffer, 0, s_buffer.Length + 1), SocketFlags.None, endpoint); }); - Assert.ThrowsAny(() => { GetSocket().ReceiveFromAsync(new ArraySegment(s_buffer, s_buffer.Length, 1), SocketFlags.None, endpoint); }); - } - - [Fact] - public void BeginReceiveFrom_NotBound_Throws_InvalidOperation() - { - EndPoint endpoint = new IPEndPoint(IPAddress.Loopback, 1); - Assert.Throws(() => GetSocket().BeginReceiveFrom(s_buffer, 0, 0, SocketFlags.None, ref endpoint, TheAsyncCallback, null)); - Assert.Throws(() => { GetSocket().ReceiveFromAsync(new ArraySegment(s_buffer, 0, 0), SocketFlags.None, endpoint); }); - } - - [Fact] - public void EndReceiveFrom_NullAsyncResult_Throws_ArgumentNull() - { - EndPoint endpoint = new IPEndPoint(IPAddress.Loopback, 1); - Assert.Throws(() => GetSocket().EndReceiveFrom(null, ref endpoint)); - } - - [Fact] - public void BeginReceiveMessageFrom_NullBuffer_Throws_ArgumentNull() - { - EndPoint remote = new IPEndPoint(IPAddress.Loopback, 1); - - Assert.Throws(() => GetSocket().BeginReceiveMessageFrom(null, 0, 0, SocketFlags.None, ref remote, TheAsyncCallback, null)); - Assert.Throws(() => { GetSocket().ReceiveMessageFromAsync(new ArraySegment(null, 0, 0), SocketFlags.None, remote); }); - } - - [Fact] - public void BeginReceiveMessageFrom_NullEndPoint_Throws_ArgumentNull() - { - EndPoint remote = null; - - Assert.Throws(() => GetSocket().BeginReceiveMessageFrom(s_buffer, 0, 0, SocketFlags.None, ref remote, TheAsyncCallback, null)); - Assert.Throws(() => { GetSocket().ReceiveMessageFromAsync(new ArraySegment(s_buffer, 0, 0), SocketFlags.None, remote); }); - } - - [Fact] - public void BeginReceiveMessageFrom_AddressFamily_Throws_Argument() - { - EndPoint remote = new IPEndPoint(IPAddress.IPv6Loopback, 1); - - AssertExtensions.Throws("remoteEP", () => GetSocket(AddressFamily.InterNetwork).BeginReceiveMessageFrom(s_buffer, 0, 0, SocketFlags.None, ref remote, TheAsyncCallback, null)); - AssertExtensions.Throws("remoteEndPoint", () => { GetSocket(AddressFamily.InterNetwork).ReceiveMessageFromAsync(new ArraySegment(s_buffer, 0, 0), SocketFlags.None, remote); }); - } - - [Fact] - public void BeginReceiveMessageFrom_InvalidOffset_Throws_ArgumentOutOfRange() - { - EndPoint remote = new IPEndPoint(IPAddress.Loopback, 1); - - Assert.Throws(() => GetSocket().BeginReceiveMessageFrom(s_buffer, -1, s_buffer.Length, SocketFlags.None, ref remote, TheAsyncCallback, null)); - Assert.Throws(() => GetSocket().BeginReceiveMessageFrom(s_buffer, s_buffer.Length + 1, s_buffer.Length, SocketFlags.None, ref remote, TheAsyncCallback, null)); - - Assert.Throws(() => { GetSocket().ReceiveMessageFromAsync(new ArraySegment(s_buffer, -1, s_buffer.Length), SocketFlags.None, remote); }); - Assert.ThrowsAny(() => { GetSocket().ReceiveMessageFromAsync(new ArraySegment(s_buffer, s_buffer.Length + 1, s_buffer.Length), SocketFlags.None, remote); }); - } - - [Fact] - public void BeginReceiveMessageFrom_InvalidSize_Throws_ArgumentOutOfRange() - { - EndPoint remote = new IPEndPoint(IPAddress.Loopback, 1); - - Assert.Throws(() => GetSocket().BeginReceiveMessageFrom(s_buffer, 0, -1, SocketFlags.None, ref remote, TheAsyncCallback, null)); - Assert.Throws(() => GetSocket().BeginReceiveMessageFrom(s_buffer, 0, s_buffer.Length + 1, SocketFlags.None, ref remote, TheAsyncCallback, null)); - Assert.Throws(() => GetSocket().BeginReceiveMessageFrom(s_buffer, s_buffer.Length, 1, SocketFlags.None, ref remote, TheAsyncCallback, null)); - - Assert.Throws(() => { GetSocket().ReceiveMessageFromAsync(new ArraySegment(s_buffer, 0, -1), SocketFlags.None, remote); }); - Assert.ThrowsAny(() => { GetSocket().ReceiveMessageFromAsync(new ArraySegment(s_buffer, 0, s_buffer.Length + 1), SocketFlags.None, remote); }); - Assert.ThrowsAny(() => { GetSocket().ReceiveMessageFromAsync(new ArraySegment(s_buffer, s_buffer.Length, 1), SocketFlags.None, remote); }); - } - - [Fact] - public void BeginReceiveMessageFrom_NotBound_Throws_InvalidOperation() - { - EndPoint remote = new IPEndPoint(IPAddress.Loopback, 1); - - Assert.Throws(() => GetSocket().BeginReceiveMessageFrom(s_buffer, 0, 0, SocketFlags.None, ref remote, TheAsyncCallback, null)); - Assert.Throws(() => { GetSocket().ReceiveMessageFromAsync(new ArraySegment(s_buffer, 0, 0), SocketFlags.None, remote); }); - } - - [Fact] - public void EndReceiveMessageFrom_NullEndPoint_Throws_ArgumentNull() - { - SocketFlags flags = SocketFlags.None; - EndPoint remote = null; - IPPacketInformation packetInfo; - - Assert.Throws(() => GetSocket().EndReceiveMessageFrom(null, ref flags, ref remote, out packetInfo)); - } - - [Fact] - public void EndReceiveMessageFrom_AddressFamily_Throws_Argument() - { - SocketFlags flags = SocketFlags.None; - EndPoint remote = new IPEndPoint(IPAddress.IPv6Loopback, 1); - IPPacketInformation packetInfo; - - AssertExtensions.Throws("endPoint", () => GetSocket(AddressFamily.InterNetwork).EndReceiveMessageFrom(null, ref flags, ref remote, out packetInfo)); - } - - [Fact] - public void EndReceiveMessageFrom_NullAsyncResult_Throws_ArgumentNull() - { - SocketFlags flags = SocketFlags.None; - EndPoint remote = new IPEndPoint(IPAddress.Loopback, 1); - IPPacketInformation packetInfo; - - Assert.Throws(() => GetSocket().EndReceiveMessageFrom(null, ref flags, ref remote, out packetInfo)); - } - [Fact] public void CancelConnectAsync_NullEventArgs_Throws_ArgumentNull() { diff --git a/src/libraries/System.Net.Sockets/tests/FunctionalTests/DualModeSocketTest.cs b/src/libraries/System.Net.Sockets/tests/FunctionalTests/DualModeSocketTest.cs index e957790e8c590..d7d037abba193 100644 --- a/src/libraries/System.Net.Sockets/tests/FunctionalTests/DualModeSocketTest.cs +++ b/src/libraries/System.Net.Sockets/tests/FunctionalTests/DualModeSocketTest.cs @@ -1121,7 +1121,11 @@ public void Socket_BeginSendToV4IPEndPointToV4Host_Throws() Assert.Throws(() => { - socket.BeginSendTo(new byte[1], 0, 1, SocketFlags.None, new IPEndPoint(IPAddress.Loopback, UnusedPort), null, null); + // [ActiveIssue("https://github.com/dotnet/runtime/issues/47905")] + // TODO: When fixing the issue above, revert this test to check that the exception is being thrown in BeginSendTo + // without the need to call EndSendTo. + IAsyncResult result = socket.BeginSendTo(new byte[1], 0, 1, SocketFlags.None, new IPEndPoint(IPAddress.Loopback, UnusedPort), null, null); + socket.EndSendTo(result); }); } } diff --git a/src/libraries/System.Net.Sockets/tests/FunctionalTests/ReceiveFrom.cs b/src/libraries/System.Net.Sockets/tests/FunctionalTests/ReceiveFrom.cs index e686dd48f79d7..524b81f67d2d5 100644 --- a/src/libraries/System.Net.Sockets/tests/FunctionalTests/ReceiveFrom.cs +++ b/src/libraries/System.Net.Sockets/tests/FunctionalTests/ReceiveFrom.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Collections.Generic; +using System.Diagnostics; using System.Threading; using System.Threading.Tasks; using Xunit; @@ -11,6 +12,8 @@ namespace System.Net.Sockets.Tests { public abstract class ReceiveFrom : SocketTestHelperBase where T : SocketHelperBase, new() { + protected static Socket CreateSocket(AddressFamily addressFamily = AddressFamily.InterNetwork) => new Socket(addressFamily, SocketType.Dgram, ProtocolType.Udp); + protected static IPEndPoint GetGetDummyTestEndpoint(AddressFamily addressFamily = AddressFamily.InterNetwork) => addressFamily == AddressFamily.InterNetwork ? new IPEndPoint(IPAddress.Parse("1.2.3.4"), 1234) : new IPEndPoint(IPAddress.Parse("1:2:3::4"), 1234); @@ -22,10 +25,11 @@ protected ReceiveFrom(ITestOutputHelper output) : base(output) { } [InlineData(1, -1, 0)] // offset low [InlineData(1, 2, 0)] // offset high [InlineData(1, 0, -1)] // count low - [InlineData(1, 1, 2)] // count high - public async Task OutOfRange_Throws(int length, int offset, int count) + [InlineData(1, 0, 2)] // count high + [InlineData(1, 1, 1)] // count high + public async Task OutOfRange_Throws_ArgumentOutOfRangeException(int length, int offset, int count) { - using var socket = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp); + using Socket socket = CreateSocket(); ArraySegment buffer = new FakeArraySegment { @@ -34,24 +38,48 @@ public async Task OutOfRange_Throws(int length, int offset, int count) Offset = offset }.ToActual(); - await Assert.ThrowsAnyAsync(() => ReceiveFromAsync(socket, buffer, GetGetDummyTestEndpoint())); + await AssertThrowsSynchronously(() => ReceiveFromAsync(socket, buffer, GetGetDummyTestEndpoint())); } [Fact] - public async Task NullBuffer_Throws() + public async Task NullBuffer_Throws_ArgumentNullException() { if (!ValidatesArrayArguments) return; - using var socket = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp); + using Socket socket = CreateSocket(); + await AssertThrowsSynchronously(() => ReceiveFromAsync(socket, null, GetGetDummyTestEndpoint())); + } - await Assert.ThrowsAsync(() => ReceiveFromAsync(socket, null, GetGetDummyTestEndpoint())); + [Fact] + public async Task NullEndpoint_Throws_ArgumentException() + { + using Socket socket = CreateSocket(); + if (UsesEap) + { + await AssertThrowsSynchronously(() => ReceiveFromAsync(socket, new byte[1], null)); + } + else + { + await AssertThrowsSynchronously(() => ReceiveFromAsync(socket, new byte[1], null)); + } } [Fact] - public async Task NullEndpoint_Throws() + public async Task AddressFamilyDoesNotMatch_Throws_ArgumentException() { - using var socket = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp); + using var ipv4Socket = CreateSocket(); + EndPoint ipV6Endpoint = GetGetDummyTestEndpoint(AddressFamily.InterNetworkV6); + await AssertThrowsSynchronously(() => ReceiveFromAsync(ipv4Socket, new byte[1], ipV6Endpoint)); + } - await Assert.ThrowsAnyAsync(() => ReceiveFromAsync(socket, new byte[1], null)); + [Fact] + public async Task NotBound_Throws_InvalidOperationException() + { + // ReceiveFromAsync(saea) does not throw. + // [ActiveIssue("https://github.com/dotnet/runtime/issues/47714")] + if (UsesEap) return; + + using Socket socket = CreateSocket(); + await AssertThrowsSynchronously(() => ReceiveFromAsync(socket, new byte[1], GetGetDummyTestEndpoint())); } [Theory] @@ -223,6 +251,70 @@ public ReceiveFrom_SyncForceNonBlocking(ITestOutputHelper output) : base(output) public sealed class ReceiveFrom_Apm : ReceiveFrom { public ReceiveFrom_Apm(ITestOutputHelper output) : base(output) { } + + [Fact] + public void EndReceiveFrom_NullAsyncResult_Throws_ArgumentNullException() + { + EndPoint endpoint = new IPEndPoint(IPAddress.Loopback, 1); + using Socket socket = CreateSocket(); + Assert.Throws(() => socket.EndReceiveFrom(null, ref endpoint)); + } + + [Fact] + public void EndReceiveFrom_UnrelatedAsyncResult_Throws_ArgumentException() + { + EndPoint endpoint = new IPEndPoint(IPAddress.Loopback, 1); + using Socket socket = CreateSocket(); + + Assert.Throws(() => socket.EndReceiveFrom(Task.CompletedTask, ref endpoint)); + } + + [Fact] + public void EndReceiveFrom_NullEndPoint_Throws_ArgumentNullException() + { + EndPoint validEndPoint = new IPEndPoint(IPAddress.Loopback, 1); + EndPoint invalidEndPoint = null; + using Socket socket = CreateSocket(); + socket.BindToAnonymousPort(IPAddress.Loopback); + IAsyncResult iar = socket.BeginReceiveFrom(new byte[1], 0, 1, SocketFlags.None, ref validEndPoint, null, null); + Assert.Throws("endPoint", () => socket.EndReceiveFrom(iar, ref invalidEndPoint)); + } + + [Fact] + public void EndReceiveFrom_AddressFamilyDoesNotMatch_Throws_ArgumentException() + { + EndPoint validEndPoint = new IPEndPoint(IPAddress.Loopback, 1); + EndPoint invalidEndPoint = new IPEndPoint(IPAddress.IPv6Loopback, 1); + using Socket socket = CreateSocket(); + socket.BindToAnonymousPort(IPAddress.Loopback); + IAsyncResult iar = socket.BeginReceiveFrom(new byte[1], 0, 1, SocketFlags.None, ref validEndPoint, null, null); + Assert.Throws("endPoint", () => socket.EndReceiveFrom(iar, ref invalidEndPoint)); + } + + [Fact] + public void BeginReceiveFrom_RemoteEpIsReturnedWhenCompletedSynchronously() + { + EndPoint anyEp = new IPEndPoint(IPAddress.Any, 0); + EndPoint remoteEp = anyEp; + using Socket receiver = CreateSocket(); + receiver.BindToAnonymousPort(IPAddress.Loopback); + using Socket sender = CreateSocket(); + sender.BindToAnonymousPort(IPAddress.Loopback); + + sender.SendTo(new byte[1], receiver.LocalEndPoint); + + IAsyncResult iar = receiver.BeginReceiveFrom(new byte[1], 0, 1, SocketFlags.None, ref remoteEp, null, null); + if (iar.CompletedSynchronously) + { + _output.WriteLine("Completed synchronously, updated endpoint."); + Assert.Equal(sender.LocalEndPoint, remoteEp); + } + else + { + _output.WriteLine("Completed asynchronously, did not update endPoint"); + Assert.Equal(anyEp, remoteEp); + } + } } public sealed class ReceiveFrom_Task : ReceiveFrom @@ -238,7 +330,7 @@ public ReceiveFrom_CancellableTask(ITestOutputHelper output) : base(output) { } [MemberData(nameof(LoopbacksAndBuffers))] public async Task WhenCanceled_Throws(IPAddress loopback, bool precanceled) { - using var socket = new Socket(loopback.AddressFamily, SocketType.Dgram, ProtocolType.Udp); + using Socket socket = new Socket(loopback.AddressFamily, SocketType.Dgram, ProtocolType.Udp); using var dummy = new Socket(loopback.AddressFamily, SocketType.Dgram, ProtocolType.Udp); socket.BindToAnonymousPort(loopback); dummy.BindToAnonymousPort(loopback); @@ -258,6 +350,13 @@ public async Task WhenCanceled_Throws(IPAddress loopback, bool precanceled) public sealed class ReceiveFrom_Eap : ReceiveFrom { public ReceiveFrom_Eap(ITestOutputHelper output) : base(output) { } + + [Fact] + public void ReceiveFromAsync_NullAsyncEventArgs_Throws_ArgumentNullException() + { + using Socket socket = CreateSocket(); + Assert.Throws(() => socket.ReceiveFromAsync(null)); + } } public sealed class ReceiveFrom_SpanSync : ReceiveFrom diff --git a/src/libraries/System.Net.Sockets/tests/FunctionalTests/ReceiveMessageFrom.cs b/src/libraries/System.Net.Sockets/tests/FunctionalTests/ReceiveMessageFrom.cs index a9443803773d4..0384715f16e75 100644 --- a/src/libraries/System.Net.Sockets/tests/FunctionalTests/ReceiveMessageFrom.cs +++ b/src/libraries/System.Net.Sockets/tests/FunctionalTests/ReceiveMessageFrom.cs @@ -12,6 +12,8 @@ namespace System.Net.Sockets.Tests { public abstract class ReceiveMessageFrom : SocketTestHelperBase where T : SocketHelperBase, new() { + protected static Socket CreateSocket(AddressFamily addressFamily = AddressFamily.InterNetwork) => new Socket(addressFamily, SocketType.Dgram, ProtocolType.Udp); + protected static IPEndPoint GetGetDummyTestEndpoint(AddressFamily addressFamily = AddressFamily.InterNetwork) => addressFamily == AddressFamily.InterNetwork ? new IPEndPoint(IPAddress.Parse("1.2.3.4"), 1234) : new IPEndPoint(IPAddress.Parse("1:2:3::4"), 1234); @@ -19,6 +21,67 @@ protected static IPEndPoint GetGetDummyTestEndpoint(AddressFamily addressFamily protected ReceiveMessageFrom(ITestOutputHelper output) : base(output) { } + [Theory] + [InlineData(1, -1, 0)] // offset low + [InlineData(1, 2, 0)] // offset high + [InlineData(1, 0, -1)] // count low + [InlineData(1, 0, 2)] // count high + [InlineData(1, 1, 1)] // count high + public async Task OutOfRange_Throws_ArgumentOutOfRangeException(int length, int offset, int count) + { + using Socket socket = CreateSocket(); + + ArraySegment buffer = new FakeArraySegment + { + Array = new byte[length], + Count = count, + Offset = offset + }.ToActual(); + + await AssertThrowsSynchronously(() => ReceiveMessageFromAsync(socket, buffer, GetGetDummyTestEndpoint())); + } + + [Fact] + public async Task NullBuffer_Throws_ArgumentNullException() + { + if (!ValidatesArrayArguments) return; + using Socket socket = CreateSocket(); + await AssertThrowsSynchronously(() => ReceiveMessageFromAsync(socket, null, GetGetDummyTestEndpoint())); + } + + [Fact] + public async Task NullEndpoint_Throws_ArgumentException() + { + using Socket socket = CreateSocket(); + if (UsesEap) + { + await AssertThrowsSynchronously(() => ReceiveMessageFromAsync(socket, new byte[1], null)); + } + else + { + await AssertThrowsSynchronously(() => ReceiveMessageFromAsync(socket, new byte[1], null)); + } + } + + [Fact] + public async Task AddressFamilyDoesNotMatch_Throws_ArgumentException() + { + using var ipv4Socket = CreateSocket(); + EndPoint ipV6Endpoint = GetGetDummyTestEndpoint(AddressFamily.InterNetworkV6); + await AssertThrowsSynchronously(() => ReceiveMessageFromAsync(ipv4Socket, new byte[1], ipV6Endpoint)); + } + + [Fact] + public async Task NotBound_Throws_InvalidOperationException() + { + // ReceiveFromAsync(saea) fails on a Debug.Assert(): + // [ActiveIssue("https://github.com/dotnet/runtime/issues/47714")] + if (UsesEap) return; + + using Socket socket = CreateSocket(); + await AssertThrowsSynchronously(() => ReceiveMessageFromAsync(socket, new byte[1], GetGetDummyTestEndpoint())); + } + [PlatformSpecific(TestPlatforms.AnyUnix)] [Theory] [InlineData(false)] @@ -192,6 +255,77 @@ public ReceiveMessageFrom_SyncForceNonBlocking(ITestOutputHelper output) : base( public sealed class ReceiveMessageFrom_Apm : ReceiveMessageFrom { public ReceiveMessageFrom_Apm(ITestOutputHelper output) : base(output) { } + + [Fact] + public void EndReceiveMessageFrom_NullAsyncResult_Throws_ArgumentNullException() + { + SocketFlags socketFlags = SocketFlags.None; + EndPoint endpoint = new IPEndPoint(IPAddress.Loopback, 1); + using Socket socket = CreateSocket(); + + Assert.Throws(() => socket.EndReceiveMessageFrom(null, ref socketFlags, ref endpoint, out _)); + } + + [Fact] + public void EndReceiveMessageFrom_UnrelatedAsyncResult_Throws_ArgumentException() + { + SocketFlags socketFlags = SocketFlags.None; + EndPoint endpoint = new IPEndPoint(IPAddress.Loopback, 1); + using Socket socket = CreateSocket(); + + Assert.Throws(() => socket.EndReceiveMessageFrom(Task.CompletedTask, ref socketFlags, ref endpoint, out _)); + } + + [Fact] + public void EndReceiveMessageFrom_NullEndPoint_Throws_ArgumentNullException() + { + SocketFlags socketFlags = SocketFlags.None; + EndPoint validEndPoint = new IPEndPoint(IPAddress.Loopback, 1); + EndPoint invalidEndPoint = null; + using Socket socket = CreateSocket(); + socket.BindToAnonymousPort(IPAddress.Loopback); + IAsyncResult iar = socket.BeginReceiveMessageFrom(new byte[1], 0, 1, SocketFlags.None, ref validEndPoint, null, null); + + Assert.Throws("endPoint", () => socket.EndReceiveMessageFrom(iar, ref socketFlags, ref invalidEndPoint, out _)); + } + + [Fact] + public void EndReceiveMessageFrom_AddressFamilyDoesNotMatch_Throws_ArgumentException() + { + SocketFlags socketFlags = SocketFlags.None; + EndPoint validEndPoint = new IPEndPoint(IPAddress.Loopback, 1); + EndPoint invalidEndPoint = new IPEndPoint(IPAddress.IPv6Loopback, 1); + using Socket socket = CreateSocket(); + socket.BindToAnonymousPort(IPAddress.Loopback); + IAsyncResult iar = socket.BeginReceiveMessageFrom(new byte[1], 0, 1, SocketFlags.None, ref validEndPoint, null, null); + + Assert.Throws("endPoint", () => socket.EndReceiveMessageFrom(iar, ref socketFlags, ref invalidEndPoint, out _)); + } + + [Fact] + public void BeginReceiveMessageFrom_RemoteEpIsReturnedWhenCompletedSynchronously() + { + EndPoint anyEp = new IPEndPoint(IPAddress.Any, 0); + EndPoint remoteEp = anyEp; + using Socket receiver = CreateSocket(); + receiver.BindToAnonymousPort(IPAddress.Loopback); + using Socket sender = CreateSocket(); + sender.BindToAnonymousPort(IPAddress.Loopback); + + sender.SendTo(new byte[1], receiver.LocalEndPoint); + + IAsyncResult iar = receiver.BeginReceiveMessageFrom(new byte[1], 0, 1, SocketFlags.None, ref remoteEp, null, null); + if (iar.CompletedSynchronously) + { + _output.WriteLine("Completed synchronously, updated endpoint."); + Assert.Equal(sender.LocalEndPoint, remoteEp); + } + else + { + _output.WriteLine("Completed asynchronously, did not update endPoint"); + Assert.Equal(anyEp, remoteEp); + } + } } public sealed class ReceiveMessageFrom_Task : ReceiveMessageFrom @@ -228,6 +362,13 @@ public sealed class ReceiveMessageFrom_Eap : ReceiveMessageFrom { public ReceiveMessageFrom_Eap(ITestOutputHelper output) : base(output) { } + [Fact] + public void ReceiveFromAsync_NullAsyncEventArgs_Throws_ArgumentNullException() + { + using Socket socket = CreateSocket(); + Assert.Throws(() => socket.ReceiveMessageFromAsync(null)); + } + [Theory] [InlineData(false, 0)] [InlineData(false, 1)] diff --git a/src/libraries/System.Net.Sockets/tests/FunctionalTests/SendTo.cs b/src/libraries/System.Net.Sockets/tests/FunctionalTests/SendTo.cs index 1159d40f2e4a3..70205f8033cc4 100644 --- a/src/libraries/System.Net.Sockets/tests/FunctionalTests/SendTo.cs +++ b/src/libraries/System.Net.Sockets/tests/FunctionalTests/SendTo.cs @@ -13,7 +13,10 @@ namespace System.Net.Sockets.Tests { public abstract class SendTo : SocketTestHelperBase where T : SocketHelperBase, new() { - protected static readonly IPEndPoint ValidUdpRemoteEndpoint = new IPEndPoint(IPAddress.Parse("10.20.30.40"), 1234); + protected static Socket CreateSocket(AddressFamily addressFamily = AddressFamily.InterNetwork) => new Socket(addressFamily, SocketType.Dgram, ProtocolType.Udp); + + protected static IPEndPoint GetGetDummyTestEndpoint(AddressFamily addressFamily = AddressFamily.InterNetwork) => + addressFamily == AddressFamily.InterNetwork ? new IPEndPoint(IPAddress.Parse("1.2.3.4"), 1234) : new IPEndPoint(IPAddress.Parse("1:2:3::4"), 1234); protected SendTo(ITestOutputHelper output) : base(output) { @@ -23,34 +26,41 @@ protected SendTo(ITestOutputHelper output) : base(output) [InlineData(1, -1, 0)] // offset low [InlineData(1, 2, 0)] // offset high [InlineData(1, 0, -1)] // count low - [InlineData(1, 1, 2)] // count high - public async Task OutOfRange_Throws(int length, int offset, int count) + [InlineData(1, 0, 2)] // count high + [InlineData(1, 1, 1)] // count high + public async Task OutOfRange_Throws_ArgumentOutOfRangeException(int length, int offset, int count) { - using var socket = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp); + using var socket = CreateSocket(); ArraySegment buffer = new FakeArraySegment { Array = new byte[length], Count = count, Offset = offset }.ToActual(); - await Assert.ThrowsAnyAsync(() => SendToAsync(socket, buffer, ValidUdpRemoteEndpoint)); + await AssertThrowsSynchronously(() => SendToAsync(socket, buffer, GetGetDummyTestEndpoint())); } [Fact] - public async Task NullBuffer_Throws() + public async Task NullBuffer_Throws_ArgumentNullException() { if (!ValidatesArrayArguments) return; - using var socket = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp); + using var socket = CreateSocket(); - await Assert.ThrowsAsync(() => SendToAsync(socket, null, ValidUdpRemoteEndpoint)); + await AssertThrowsSynchronously(() => SendToAsync(socket, null, GetGetDummyTestEndpoint())); } [Fact] - public async Task NullEndpoint_Throws() + public async Task NullEndpoint_Throws_ArgumentException() { - using var socket = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp); - - await Assert.ThrowsAnyAsync(() => SendToAsync(socket, new byte[1], null)); + using Socket socket = CreateSocket(); + if (UsesEap) + { + await AssertThrowsSynchronously(() => SendToAsync(socket, new byte[1], null)); + } + else + { + await AssertThrowsSynchronously(() => SendToAsync(socket, new byte[1], null)); + } } [Fact] @@ -59,7 +69,7 @@ public async Task Datagram_UDP_ShouldImplicitlyBindLocalEndpoint() using var socket = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp); byte[] buffer = new byte[32]; - Task sendTask = SendToAsync(socket, new ArraySegment(buffer), ValidUdpRemoteEndpoint); + Task sendTask = SendToAsync(socket, new ArraySegment(buffer), GetGetDummyTestEndpoint()); // Asynchronous calls shall alter the property immediately: if (!UsesSync) @@ -91,7 +101,7 @@ public async Task Disposed_Throws() using var socket = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp); socket.Dispose(); - await Assert.ThrowsAsync(() => SendToAsync(socket, new byte[1], ValidUdpRemoteEndpoint)); + await Assert.ThrowsAsync(() => SendToAsync(socket, new byte[1], GetGetDummyTestEndpoint())); } } @@ -118,11 +128,35 @@ public SendTo_SyncForceNonBlocking(ITestOutputHelper output) : base(output) {} public sealed class SendTo_Apm : SendTo { public SendTo_Apm(ITestOutputHelper output) : base(output) {} + + [Fact] + public void EndSendTo_NullAsyncResult_Throws_ArgumentNullException() + { + EndPoint endpoint = new IPEndPoint(IPAddress.Loopback, 1); + using Socket socket = CreateSocket(); + Assert.Throws(() => socket.EndSendTo(null)); + } + + [Fact] + public void EndSendTo_UnrelatedAsyncResult_Throws_ArgumentException() + { + EndPoint endpoint = new IPEndPoint(IPAddress.Loopback, 1); + using Socket socket = CreateSocket(); + + Assert.Throws(() => socket.EndSendTo(Task.CompletedTask)); + } } public sealed class SendTo_Eap : SendTo { public SendTo_Eap(ITestOutputHelper output) : base(output) {} + + [Fact] + public void SendToAsync_NullAsyncEventArgs_Throws_ArgumentNullException() + { + using Socket socket = CreateSocket(); + Assert.Throws(() => socket.SendToAsync(null)); + } } public sealed class SendTo_Task : SendTo @@ -142,7 +176,7 @@ public async Task PreCanceled_Throws() cts.Cancel(); OperationCanceledException ex = await Assert.ThrowsAnyAsync( - () => sender.SendToAsync(new byte[1], SocketFlags.None, ValidUdpRemoteEndpoint, cts.Token).AsTask()); + () => sender.SendToAsync(new byte[1], SocketFlags.None, GetGetDummyTestEndpoint(), cts.Token).AsTask()); Assert.Equal(cts.Token, ex.CancellationToken); } diff --git a/src/libraries/System.Net.Sockets/tests/FunctionalTests/SocketTestHelper.cs b/src/libraries/System.Net.Sockets/tests/FunctionalTests/SocketTestHelper.cs index 615cca73aca0f..bdaf68ffb497e 100644 --- a/src/libraries/System.Net.Sockets/tests/FunctionalTests/SocketTestHelper.cs +++ b/src/libraries/System.Net.Sockets/tests/FunctionalTests/SocketTestHelper.cs @@ -39,6 +39,7 @@ public abstract Task ReceiveMessageFromAsync( public virtual bool ValidatesArrayArguments => true; public virtual bool UsesSync => false; public virtual bool UsesApm => false; + public virtual bool UsesEap => false; public virtual bool DisposeDuringOperationResultsInDisposedException => false; public virtual bool ConnectAfterDisconnectResultsInInvalidOperationException => false; public virtual bool SupportsMultiConnect => true; @@ -275,6 +276,7 @@ public override Task SendToAsync(Socket s, ArraySegment buffer, EndPo public sealed class SocketHelperEap : SocketHelperBase { + public override bool UsesEap => true; public override bool ValidatesArrayArguments => false; public override bool SupportsAcceptReceive => true; @@ -423,6 +425,7 @@ public Task SendFileAsync(Socket s, string fileName, ArraySegment preBuffe public bool ValidatesArrayArguments => _socketHelper.ValidatesArrayArguments; public bool UsesSync => _socketHelper.UsesSync; public bool UsesApm => _socketHelper.UsesApm; + public bool UsesEap => _socketHelper.UsesEap; public bool DisposeDuringOperationResultsInDisposedException => _socketHelper.DisposeDuringOperationResultsInDisposedException; public bool ConnectAfterDisconnectResultsInInvalidOperationException => _socketHelper.ConnectAfterDisconnectResultsInInvalidOperationException; public bool SupportsMultiConnect => _socketHelper.SupportsMultiConnect; @@ -431,6 +434,22 @@ public Task SendFileAsync(Socket s, string fileName, ArraySegment preBuffe public bool SupportsSendFileSlicing => _socketHelper.SupportsSendFileSlicing; public void Listen(Socket s, int backlog) => _socketHelper.Listen(s, backlog); public void ConfigureNonBlocking(Socket s) => _socketHelper.ConfigureNonBlocking(s); + + // A helper method to observe exceptions on sync paths of async variants. + // In that case, exceptions should be seen without awaiting completion. + // Synchronous variants are started on a separate thread using Task.Run(), therefore we should await the task. + protected async Task AssertThrowsSynchronously(Func testCode) + where TException : Exception + { + if (UsesSync) + { + return await Assert.ThrowsAsync(testCode); + } + else + { + return Assert.Throws(() => { _ = testCode(); }); + } + } } public class SocketHelperSpanSync : SocketHelperArraySync