diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SafeSocketHandle.Unix.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SafeSocketHandle.Unix.cs index 0993a0088b4041..5231124073f595 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SafeSocketHandle.Unix.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SafeSocketHandle.Unix.cs @@ -130,6 +130,8 @@ internal bool IsNonBlocking // If transitioning from non-blocking to blocking, we keep the native socket in non-blocking mode, and emulate // blocking operations within SocketAsyncContext on top of epoll/kqueue. // This avoids problems with switching to native blocking while there are pending operations. + // Note: After ConnectAsync completes, we may restore the native socket to blocking mode + // to optimize subsequent synchronous operations (see RestoreBlocking/SetHandleBlocking). if (value) { AsyncContext.SetHandleNonBlocking(); @@ -139,6 +141,20 @@ internal bool IsNonBlocking internal bool IsUnderlyingHandleBlocking => !AsyncContext.IsHandleNonBlocking; + /// + /// Restores the underlying socket to blocking mode after ConnectAsync completes. + /// Only restores blocking if the user hasn't explicitly set Blocking = false (i.e., IsNonBlocking is false). + /// This is only safe to call when the socket is guaranteed by construction to not be used concurrently + /// with any other operation, such as at the completion of ConnectAsync. + /// + internal void RestoreBlocking() + { + if (!IsNonBlocking && !IsClosed) + { + AsyncContext.SetHandleBlocking(); + } + } + internal int ReceiveTimeout { get diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncContext.Unix.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncContext.Unix.cs index 4e2e117984084c..770d46c8967b18 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncContext.Unix.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncContext.Unix.cs @@ -678,6 +678,8 @@ public override void InvokeCallback(bool allowPooling) if (buffer.Length == 0) { + AssociatedContext._socket.RestoreBlocking(); + // Invoke callback only when we are completely done. // In case data were provided for Connect we may or may not send them all. // If we did not we will need follow-up with Send operation @@ -1350,8 +1352,9 @@ public void SetHandleNonBlocking() // // Our sockets may start as blocking, and later transition to non-blocking, either because the user // explicitly requested non-blocking mode, or because we need non-blocking mode to support async - // operations. We never transition back to blocking mode, to avoid problems synchronizing that - // transition with the async infrastructure. + // operations. After a successful ConnectAsync, we may transition back to blocking mode to optimize + // subsequent synchronous operations (see SetHandleBlocking). The socket will be set back to + // non-blocking when another async operation is performed. // // Note that there's no synchronization here, so we may set the non-blocking option multiple times // in a race. This should be fine. @@ -1369,6 +1372,23 @@ public void SetHandleNonBlocking() public bool IsHandleNonBlocking => _isHandleNonBlocking; + public void SetHandleBlocking() + { + if (OperatingSystem.IsWasi()) + { + // WASI sockets are always non-blocking + return; + } + + if (_isHandleNonBlocking) + { + if (Interop.Sys.Fcntl.SetIsNonBlocking(_socket, 0) == 0) + { + _isHandleNonBlocking = false; + } + } + } + private void PerformSyncOperation(ref OperationQueue queue, TOperation operation, int timeout, int observedSequenceNumber) where TOperation : AsyncOperation { @@ -1563,6 +1583,12 @@ public SocketError ConnectAsync(Memory socketAddress, Action.Empty, ref sentBytes, callback!, default); } + + // Only restore blocking when there's no follow-up async send. + if (buffer.Length == 0) + { + _socket.RestoreBlocking(); + } return errorCode; } @@ -1580,6 +1606,11 @@ public SocketError ConnectAsync(Memory socketAddress, Action(async () => + await client.ConnectAsync(new IPEndPoint(IPAddress.Loopback, 1))); + + Assert.False(IsSocketNonBlocking(client)); + } + + [Fact] + public async Task AcceptAsync_AcceptedSocketIsBlockingByDefault() + { + using Socket listener = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + listener.Bind(new IPEndPoint(IPAddress.Loopback, 0)); + listener.Listen(1); + + using Socket client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + client.Connect((IPEndPoint)listener.LocalEndPoint!); + + using Socket accepted = await listener.AcceptAsync(); + + Assert.True(accepted.Blocking); + Assert.False(IsSocketNonBlocking(accepted)); + } + + [Fact] + public async Task AcceptAsync_AcceptedSocketSyncReceiveWorks() + { + using Socket listener = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + listener.Bind(new IPEndPoint(IPAddress.Loopback, 0)); + listener.Listen(1); + + using Socket client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + client.Connect((IPEndPoint)listener.LocalEndPoint!); + + using Socket accepted = await listener.AcceptAsync(); + + client.Send(new byte[] { 1, 2, 3 }); + + byte[] buffer = new byte[10]; + int received = accepted.Receive(buffer); + + Assert.Equal(3, received); + Assert.True(accepted.Blocking); + Assert.False(IsSocketNonBlocking(accepted)); + } + + [Fact] + public async Task AcceptAsync_ConcurrentAccepts_DoNotCorruptListenerState() + { + using Socket listener = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + listener.Bind(new IPEndPoint(IPAddress.Loopback, 0)); + listener.Listen(5); + + Task accept1 = listener.AcceptAsync(); + Task accept2 = listener.AcceptAsync(); + + using Socket client1 = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + using Socket client2 = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + client1.Connect((IPEndPoint)listener.LocalEndPoint!); + client2.Connect((IPEndPoint)listener.LocalEndPoint!); + + using Socket accepted1 = await accept1; + using Socket accepted2 = await accept2; + + Assert.True(accepted1.Blocking); + Assert.False(IsSocketNonBlocking(accepted1)); + Assert.True(accepted2.Blocking); + Assert.False(IsSocketNonBlocking(accepted2)); + } + + [Fact] + public async Task ConnectAsync_WithBuffer_SocketStaysNonBlocking() + { + using Socket listener = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + listener.Bind(new IPEndPoint(IPAddress.Loopback, 0)); + listener.Listen(1); + + using Socket client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + + var saea = new SocketAsyncEventArgs(); + saea.RemoteEndPoint = (IPEndPoint)listener.LocalEndPoint!; + saea.SetBuffer(new byte[] { 1, 2, 3 }, 0, 3); + + var tcs = new TaskCompletionSource(); + saea.Completed += (_, _) => tcs.SetResult(); + + if (!client.ConnectAsync(saea)) + { + tcs.SetResult(); + } + + await tcs.Task; + + Assert.Equal(SocketError.Success, saea.SocketError); + + // When buffer > 0, the socket stays non-blocking because SendToAsync + // may have been used to send the initial data after connect. + Assert.True(IsSocketNonBlocking(client)); + + saea.Dispose(); + } + } +} diff --git a/src/libraries/System.Net.Sockets/tests/FunctionalTests/System.Net.Sockets.Tests.csproj b/src/libraries/System.Net.Sockets/tests/FunctionalTests/System.Net.Sockets.Tests.csproj index 43844aea397681..6fdc8aa36a6974 100644 --- a/src/libraries/System.Net.Sockets/tests/FunctionalTests/System.Net.Sockets.Tests.csproj +++ b/src/libraries/System.Net.Sockets/tests/FunctionalTests/System.Net.Sockets.Tests.csproj @@ -22,6 +22,7 @@ + @@ -105,6 +106,12 @@ + + + +