diff --git a/src/libraries/Common/tests/System/Net/Sockets/SocketTestExtensions.cs b/src/libraries/Common/tests/System/Net/Sockets/SocketTestExtensions.cs index 1e252f0d6d6dbb..8315b8ce990ae2 100644 --- a/src/libraries/Common/tests/System/Net/Sockets/SocketTestExtensions.cs +++ b/src/libraries/Common/tests/System/Net/Sockets/SocketTestExtensions.cs @@ -102,6 +102,9 @@ internal class PortBlocker : IDisposable private const int MaxAttempts = 16; private Socket _shadowSocket; public Socket MainSocket { get; } + public Socket SecondarySocket => _shadowSocket; + + public int Port; public PortBlocker(Func socketFactory) { @@ -126,7 +129,11 @@ public PortBlocker(Func socketFactory) _shadowSocket = new Socket(shadowAddress.AddressFamily, MainSocket.SocketType, MainSocket.ProtocolType); success = TryBindWithoutReuseAddress(_shadowSocket, shadowEndPoint, out _); - if (success) break; + if (success) + { + Port = port; + break; + } } catch (SocketException) { diff --git a/src/libraries/System.Net.Sockets/ref/System.Net.Sockets.cs b/src/libraries/System.Net.Sockets/ref/System.Net.Sockets.cs index cac94cac46a8b1..30e9a0cd47db46 100644 --- a/src/libraries/System.Net.Sockets/ref/System.Net.Sockets.cs +++ b/src/libraries/System.Net.Sockets/ref/System.Net.Sockets.cs @@ -6,6 +6,11 @@ namespace System.Net.Sockets { + public enum ConnectAlgorithm + { + Default = 0, + Parallel = 1, + } public enum IOControlCode : long { [System.Runtime.Versioning.SupportedOSPlatformAttribute("windows")] @@ -343,6 +348,7 @@ public void Connect(string host, int port) { } public System.Threading.Tasks.ValueTask ConnectAsync(System.Net.IPAddress[] addresses, int port, System.Threading.CancellationToken cancellationToken) { throw null; } public bool ConnectAsync(System.Net.Sockets.SocketAsyncEventArgs e) { throw null; } public static bool ConnectAsync(System.Net.Sockets.SocketType socketType, System.Net.Sockets.ProtocolType protocolType, System.Net.Sockets.SocketAsyncEventArgs e) { throw null; } + public static bool ConnectAsync(System.Net.Sockets.SocketType socketType, System.Net.Sockets.ProtocolType protocolType, System.Net.Sockets.SocketAsyncEventArgs e, System.Net.Sockets.ConnectAlgorithm connectAlgorithm) { throw null; } public System.Threading.Tasks.Task ConnectAsync(string host, int port) { throw null; } public System.Threading.Tasks.ValueTask ConnectAsync(string host, int port, System.Threading.CancellationToken cancellationToken) { throw null; } public void Disconnect(bool reuseSocket) { } diff --git a/src/libraries/System.Net.Sockets/src/Resources/Strings.resx b/src/libraries/System.Net.Sockets/src/Resources/Strings.resx index 15eb73c240c87c..ea60531946954e 100644 --- a/src/libraries/System.Net.Sockets/src/Resources/Strings.resx +++ b/src/libraries/System.Net.Sockets/src/Resources/Strings.resx @@ -315,4 +315,7 @@ Provided SocketAddress is too small for given AddressFamily. + + Provided ConnectAlgorithm {0} is not valid. + diff --git a/src/libraries/System.Net.Sockets/src/System.Net.Sockets.csproj b/src/libraries/System.Net.Sockets/src/System.Net.Sockets.csproj index 283d47ee4feda7..2426a84e8a225c 100644 --- a/src/libraries/System.Net.Sockets/src/System.Net.Sockets.csproj +++ b/src/libraries/System.Net.Sockets/src/System.Net.Sockets.csproj @@ -19,6 +19,7 @@ + diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/ConnectAlgorithm.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/ConnectAlgorithm.cs new file mode 100644 index 00000000000000..e0db2090c47676 --- /dev/null +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/ConnectAlgorithm.cs @@ -0,0 +1,21 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace System.Net.Sockets +{ + /// + /// Specifies the algorithm used to establish a socket connection. + /// + public enum ConnectAlgorithm + { + /// + /// The default connection mechanism, typically sequential processing. + /// + Default = 0, + + /// + /// Uses a Happy Eyeballs-like algorithm to connect, attempting connections in parallel to improve speed and reliability. + /// + Parallel = 1, + } +} diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs index 4f5598eac020d0..7835529bd74585 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs @@ -2918,7 +2918,7 @@ internal bool ConnectAsync(SocketAsyncEventArgs e, bool userSocket, bool saeaMul e.StartOperationConnect(saeaMultiConnectCancelable, userSocket); try { - pending = e.DnsConnectAsync(dnsEP, default, default, cancellationToken); + pending = e.DnsConnectAsync(dnsEP, default, default, default, cancellationToken); } catch { @@ -2981,9 +2981,16 @@ internal bool ConnectAsync(SocketAsyncEventArgs e, bool userSocket, bool saeaMul return pending; } - public static bool ConnectAsync(SocketType socketType, ProtocolType protocolType, SocketAsyncEventArgs e) + public static bool ConnectAsync(SocketType socketType, ProtocolType protocolType, SocketAsyncEventArgs e) => + ConnectAsync(socketType, protocolType, e, ConnectAlgorithm.Default); + public static bool ConnectAsync(SocketType socketType, ProtocolType protocolType, SocketAsyncEventArgs e, ConnectAlgorithm connectAlgorithm) { ArgumentNullException.ThrowIfNull(e); + if (connectAlgorithm != ConnectAlgorithm.Default && + connectAlgorithm != ConnectAlgorithm.Parallel) + { + throw new ArgumentException(SR.Format(SR.net_sockets_invalid_connect_algorithm, connectAlgorithm), nameof(connectAlgorithm)); + } if (e.HasMultipleBuffers) { @@ -3005,7 +3012,7 @@ public static bool ConnectAsync(SocketType socketType, ProtocolType protocolType e.StartOperationConnect(saeaMultiConnectCancelable: true, userSocket: false); try { - pending = e.DnsConnectAsync(dnsEP, socketType, protocolType, cancellationToken: default); + pending = e.DnsConnectAsync(dnsEP, socketType, protocolType, connectAlgorithm, cancellationToken: default); } catch { diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.cs index ac90675f515aa2..5658a14a37d88c 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.cs @@ -676,9 +676,10 @@ internal void FinishOperationAsyncFailure(SocketError socketError, int bytesTran /// The DNS end point to which to connect. /// The SocketType to use to construct new sockets, if necessary. /// The ProtocolType to use to construct new sockets, if necessary. + /// Connect strategy. /// The CancellationToken. /// true if the operation is pending; otherwise, false if it's already completed. - internal bool DnsConnectAsync(DnsEndPoint endPoint, SocketType socketType, ProtocolType protocolType, CancellationToken cancellationToken) + internal bool DnsConnectAsync(DnsEndPoint endPoint, SocketType socketType, ProtocolType protocolType, ConnectAlgorithm connectAlgorithm, CancellationToken cancellationToken) { Debug.Assert(endPoint.AddressFamily == AddressFamily.Unspecified || endPoint.AddressFamily == AddressFamily.InterNetwork || @@ -691,9 +692,15 @@ internal bool DnsConnectAsync(DnsEndPoint endPoint, SocketType socketType, Proto cancellationToken = _multipleConnectCancellation.Token; } + // We can do parallel connect only if socket was not specified and when there is at least one address of each AF. + bool parallelConnect = connectAlgorithm == ConnectAlgorithm.Parallel && + _currentSocket == null && + endPoint.AddressFamily == AddressFamily.Unspecified && + Socket.OSSupportsIPv6 && Socket.OSSupportsIPv4; + // In .NET 5 and earlier, the APM implementation allowed for synchronous exceptions from this to propagate // synchronously. This call is made here rather than in the Core async method below to preserve that behavior. - Task addressesTask = Dns.GetHostAddressesAsync(endPoint.Host, endPoint.AddressFamily, cancellationToken); + Task addressesTask = Dns.GetHostAddressesAsync(endPoint.Host, parallelConnect ? AddressFamily.InterNetwork : endPoint.AddressFamily, cancellationToken); // Initialize the internal event args instance. It needs to be initialized with `this` instance's buffer // so that it may be used as part of receives during a connect. @@ -705,7 +712,21 @@ internal bool DnsConnectAsync(DnsEndPoint endPoint, SocketType socketType, Proto // by a try/catch. Thus we ignore the result. We avoid an "async void" method so as to skip the implicit SynchronizationContext // interactions async void methods entail. #pragma warning disable CA2025 - _ = Core(internalArgs, addressesTask, endPoint.Port, socketType, protocolType, cancellationToken); + if (parallelConnect) + { + var state = new ParallelMultiConnectSocketState(this); + var internalArgsV6 = new MultiConnectSocketAsyncEventArgs(); + internalArgsV6.CopyBufferFrom(this); + + Task addressesTask6 = Dns.GetHostAddressesAsync(endPoint.Host, AddressFamily.InterNetworkV6, cancellationToken); + _ = Core(internalArgs, addressesTask, endPoint.Port, socketType, protocolType, state, cancellationToken); + _ = Core(internalArgsV6, addressesTask6, endPoint.Port, socketType, protocolType, state, cancellationToken); + return true; + } + else + { + _ = Core(internalArgs, addressesTask, endPoint.Port, socketType, protocolType, null, cancellationToken); + } #pragma warning restore // Determine whether the async operation already completed and stored the results into `this`. @@ -714,7 +735,7 @@ internal bool DnsConnectAsync(DnsEndPoint endPoint, SocketType socketType, Proto // The callback won't invoke the Completed event if it gets there first. return internalArgs.ReachedCoordinationPointFirst(); - async Task Core(MultiConnectSocketAsyncEventArgs internalArgs, Task addressesTask, int port, SocketType socketType, ProtocolType protocolType, CancellationToken cancellationToken) + async Task Core(MultiConnectSocketAsyncEventArgs internalArgs, Task addressesTask, int port, SocketType socketType, ProtocolType protocolType, ParallelMultiConnectSocketState? parallelState, CancellationToken cancellationToken) { Socket? tempSocketIPv4 = null, tempSocketIPv6 = null; Exception? caughtException = null; @@ -843,35 +864,51 @@ caughtException is OperationCanceledException || } } - // Store the results. - if (caughtException != null) + if (parallelState != null) { - SetResults(caughtException, 0, SocketFlags.None); - _currentSocket?.UpdateStatusAfterSocketError(_socketError); + // If we do parallel connect use SetResults from there to arbiter competing results. + if (caughtException != null) + { + parallelState.SetResults(null, internalArgs.SocketError, 0, SocketFlags.None, caughtException); + } + else + { + parallelState.SetResults(internalArgs.ConnectSocket, internalArgs.SocketError, internalArgs.BytesTransferred, internalArgs.SocketFlags, null); + } + internalArgs.Dispose(); } else { - SetResults(SocketError.Success, internalArgs.BytesTransferred, internalArgs.SocketFlags); - _connectSocket = _currentSocket = internalArgs.ConnectSocket!; - } + // Store the results. + if (caughtException != null) + { + SetResults(caughtException, 0, SocketFlags.None); + _currentSocket?.UpdateStatusAfterSocketError(_socketError); + } + else + { + SetResults(SocketError.Success, internalArgs.BytesTransferred, internalArgs.SocketFlags); + _connectSocket = _currentSocket = internalArgs.ConnectSocket!; + } - // Complete the operation. - if (SocketsTelemetry.Log.IsEnabled()) LogBytesTransferEvents(_connectSocket?.SocketType, SocketAsyncOperation.Connect, internalArgs.BytesTransferred); + // Complete the operation. + if (SocketsTelemetry.Log.IsEnabled()) LogBytesTransferEvents(_connectSocket?.SocketType, SocketAsyncOperation.Connect, internalArgs.BytesTransferred); - Complete(); + Complete(); - // Clean up after our temporary arguments. - internalArgs.Dispose(); + // Clean up after our temporary arguments. + internalArgs.Dispose(); - // If the caller is treating this operation as pending, own the completion. - if (!internalArgs.ReachedCoordinationPointFirst()) - { - // Regardless of _flowExecutionContext, context will have been flown through this async method, as that's part - // of what async methods do. As such, we're already on whatever ExecutionContext is the right one to invoke - // the completion callback. This method may have even mutated the ExecutionContext, in which case for telemetry - // we need those mutations to be surfaced as part of this callback, so that logging performed here sees those - // mutations (e.g. to the current Activity). - OnCompleted(this); + // If the caller is treating this operation as pending, own the completion. + if (!internalArgs.ReachedCoordinationPointFirst()) + { + // Regardless of _flowExecutionContext, context will have been flown through this async method, as that's part + // of what async methods do. As such, we're already on whatever ExecutionContext is the right one to invoke + // the completion callback. This method may have even mutated the ExecutionContext, in which case for telemetry + // we need those mutations to be surfaced as part of this callback, so that logging performed here sees those + // mutations (e.g. to the current Activity). + OnCompleted(this); + } } } } @@ -891,11 +928,59 @@ public MultiConnectSocketAsyncEventArgs() : base(unsafeSuppressExecutionContextF public short Version => _mrvtsc.Version; public void Reset() => _mrvtsc.Reset(); - protected override void OnCompleted(SocketAsyncEventArgs e) => _mrvtsc.SetResult(true); + protected override void OnCompleted(SocketAsyncEventArgs e) =>_mrvtsc.SetResult(true); public bool ReachedCoordinationPointFirst() => !Interlocked.Exchange(ref _isCompleted, true); } + private sealed class ParallelMultiConnectSocketState + { + private bool _isCompleted; + private int _count; + private SocketAsyncEventArgs _saea; + + public ParallelMultiConnectSocketState(SocketAsyncEventArgs saea) + { + _saea = saea; + } + public bool Finished() => Interlocked.Exchange(ref _isCompleted, true); + + public void SetResults(Socket? socket, SocketError socketError, int bytesTransferred, SocketFlags flags, Exception? exception) + { + int count = Interlocked.Increment(ref _count); + bool shouldComplete = false; + + if (socketError == SocketError.Success && exception == null) + { + shouldComplete = !Finished(); + if (shouldComplete) + { + _saea._connectSocket = _saea._currentSocket = socket; + _saea.SetResults(SocketError.Success, bytesTransferred, flags); + } + } + else if (count == 2) // We ignore failures on first socket since we have one more pending. + { + shouldComplete = !Finished(); + if (shouldComplete) + { + _saea.SetResults(exception!, 0, SocketFlags.None); + _saea._currentSocket?.UpdateStatusAfterSocketError(_saea._socketError); + } + } + + if (shouldComplete) + { + // If this is the first final result, we need to complete the operation and release underlying SocketAsyncEventArgs + _saea.Complete(); + if (SocketsTelemetry.Log.IsEnabled()) LogBytesTransferEvents(socket?.SocketType, SocketAsyncOperation.Connect, bytesTransferred); + // signal caller we are done. + _saea.OnCompleted(_saea); + } + } + } + + internal void FinishOperationSyncSuccess(int bytesTransferred, SocketFlags flags) { SetResults(SocketError.Success, bytesTransferred, flags); diff --git a/src/libraries/System.Net.Sockets/tests/FunctionalTests/SocketAsyncEventArgsTest.cs b/src/libraries/System.Net.Sockets/tests/FunctionalTests/SocketAsyncEventArgsTest.cs index 73c4ee74d18e21..811b6ff2b18bbe 100644 --- a/src/libraries/System.Net.Sockets/tests/FunctionalTests/SocketAsyncEventArgsTest.cs +++ b/src/libraries/System.Net.Sockets/tests/FunctionalTests/SocketAsyncEventArgsTest.cs @@ -1089,6 +1089,118 @@ public async Task SendTo_DifferentEP_Success(bool ipv4) result = await receiver2.ReceiveFromAsync(receiveBuffer, remoteEp).WaitAsync(TestSettings.PassingTestTimeout); Assert.Equal(sendBuffer.Length, result.ReceivedBytes); } + + [ConditionalFact(typeof(DualModeBase), nameof(DualModeBase.LocalhostIsBothIPv4AndIPv6))] + public void Connect_Parallel_Success() + { + using PortBlocker portBlocker = new PortBlocker(() => + { + Socket socket = new Socket(AddressFamily.InterNetworkV6, SocketType.Stream, ProtocolType.Tcp); + socket.DualMode = false; + socket.BindToAnonymousPort(IPAddress.IPv6Loopback); + return socket; + }); + Socket a = portBlocker.MainSocket; + // the port blocker did not call Socket.Bind so we called bind() but we did not update properties on Socket + Socket b = new Socket(portBlocker.SecondarySocket.SafeHandle); + + a.Listen(1); + b.Listen(1); + Task t1 = a.AcceptAsync(); + Task t2 = b.AcceptAsync(); + + var mres = new ManualResetEventSlim(); + SocketAsyncEventArgs saea = new SocketAsyncEventArgs(); + saea.RemoteEndPoint = new DnsEndPoint("localhost", portBlocker.Port); + saea.Completed += (_, _) => mres.Set(); + if (Socket.ConnectAsync(a.SocketType, a.ProtocolType, saea, ConnectAlgorithm.Parallel)) + { + mres.Wait(TestSettings.PassingTestTimeout); + } + // we should see attemopt to both sockets + Task.WaitAll(new Task[] { t1, t2 }, TestSettings.PassingTestTimeout); + Assert.True(saea.ConnectSocket.Connected); + } + + [ConditionalFact(typeof(DualModeBase), nameof(DualModeBase.LocalhostIsBothIPv4AndIPv6))] + public void Connect_Parallel_Fails() + { + using PortBlocker portBlocker = new PortBlocker(() => + { + Socket socket = new Socket(AddressFamily.InterNetworkV6, SocketType.Stream, ProtocolType.Tcp); + socket.DualMode = false; + socket.BindToAnonymousPort(IPAddress.IPv6Loopback); + return socket; + }); + Socket a = portBlocker.MainSocket; + // the port blocker did not call Socket.Bind so we called bind() but we did not update properties on Socket + Socket b = new Socket(portBlocker.SecondarySocket.SafeHandle); + + // do NOT a.Listen(1); + // do NOT b.Listen(1); + // Do NOT a.AcceptAsync(); + // Do NOT b.AcceptAsync(); + + var mres = new ManualResetEventSlim(); + SocketAsyncEventArgs saea = new SocketAsyncEventArgs(); + saea.RemoteEndPoint = new DnsEndPoint("localhost", portBlocker.Port); + saea.Completed += (_, _) => mres.Set(); + if (Socket.ConnectAsync(a.SocketType, a.ProtocolType, saea, ConnectAlgorithm.Parallel)) + { + Assert.True(mres.Wait(TestSettings.PassingTestLongTimeout), "Completed did not get called in time"); + } + // we should see attemopt to both sockets + Assert.Null(saea.ConnectSocket); + Assert.NotEqual(SocketError.Success, saea.SocketError); + } + + [ConditionalTheory(typeof(DualModeBase), nameof(DualModeBase.LocalhostIsBothIPv4AndIPv6))] + [InlineData(true)] + [InlineData(false)] + public void Connect_Parallel_FailsOver(bool preferIPv6) + { + using PortBlocker portBlocker = new PortBlocker(() => + { + Socket socket = new Socket(AddressFamily.InterNetworkV6, SocketType.Stream, ProtocolType.Tcp); + socket.DualMode = false; + socket.BindToAnonymousPort(IPAddress.IPv6Loopback); + return socket; + }); + Socket a = portBlocker.MainSocket; + Socket b = new Socket(portBlocker.SecondarySocket.SafeHandle); + + if (preferIPv6) + { + a.Listen(1); + } + else + { + b.Listen(1); + } + + var mres = new ManualResetEventSlim(); + SocketAsyncEventArgs saea = new SocketAsyncEventArgs(); + saea.RemoteEndPoint = new DnsEndPoint("localhost", portBlocker.Port); + saea.Completed += (_, _) => mres.Set(); + + if (Socket.ConnectAsync(a.SocketType, a.ProtocolType, saea, ConnectAlgorithm.Parallel)) + { + mres.Wait(TestSettings.PassingTestTimeout); + } + // we should see attempt to both sockets + Assert.NotNull(saea.ConnectSocket); + Assert.True(saea.ConnectSocket.Connected); + if (preferIPv6) + { + Assert.Equal(AddressFamily.InterNetworkV6, saea.ConnectSocket.AddressFamily); + Assert.Equal(a.LocalEndPoint, saea.ConnectSocket.RemoteEndPoint); + } + else + { + Assert.Equal(AddressFamily.InterNetwork, saea.ConnectSocket.AddressFamily); + Assert.Equal(b.LocalEndPoint, saea.ConnectSocket.RemoteEndPoint); + } + } } internal static class ConnectExtensions