diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Windows.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Windows.cs index a8cb83dc651f8..20639f78c97e2 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Windows.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Windows.cs @@ -232,7 +232,7 @@ partial void WildcardBindForConnectIfNecessary(AddressFamily addressFamily) switch (addressFamily) { case AddressFamily.InterNetwork: - address = IsDualMode ? IPAddress.Any.MapToIPv6() : IPAddress.Any; + address = IsDualMode ? s_IPAddressAnyMapToIPv6 : IPAddress.Any; break; case AddressFamily.InterNetworkV6: diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs index 9fbc3c4c9a28b..479ba522ace6b 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs @@ -22,6 +22,8 @@ public partial class Socket : IDisposable { internal const int DefaultCloseTimeout = -1; // NOTE: changing this default is a breaking change. + private static readonly IPAddress s_IPAddressAnyMapToIPv6 = IPAddress.Any.MapToIPv6(); + private SafeSocketHandle _handle; // _rightEndPoint is null if the socket has not been bound. Otherwise, it is any EndPoint of the @@ -29,6 +31,10 @@ public partial class Socket : IDisposable internal EndPoint? _rightEndPoint; internal EndPoint? _remoteEndPoint; + // Cached LocalEndPoint value. Cleared on disconnect and error. Cached wildcard addresses are + // also cleared on connect and accept. + private EndPoint? _localEndPoint; + // These flags monitor if the socket was ever connected at any time and if it still is. private bool _isConnected; private bool _isDisconnected; @@ -317,6 +323,7 @@ public EndPoint? LocalEndPoint // Update the state if we've become connected after a non-blocking connect. _isConnected = true; _rightEndPoint = _nonBlockingConnectRightEndPoint; + UpdateLocalEndPointOnConnect(); _nonBlockingConnectInProgress = false; } @@ -325,23 +332,27 @@ public EndPoint? LocalEndPoint return null; } - Internals.SocketAddress socketAddress = IPEndPointExtensions.Serialize(_rightEndPoint); - - unsafe + if (_localEndPoint == null) { - fixed (byte* buffer = socketAddress.Buffer) - fixed (int* bufferSize = &socketAddress.InternalSize) + Internals.SocketAddress socketAddress = IPEndPointExtensions.Serialize(_rightEndPoint); + + unsafe { - // This may throw ObjectDisposedException. - SocketError errorCode = SocketPal.GetSockName(_handle, buffer, bufferSize); - if (errorCode != SocketError.Success) + fixed (byte* buffer = socketAddress.Buffer) + fixed (int* bufferSize = &socketAddress.InternalSize) { - UpdateStatusAfterSocketErrorAndThrowException(errorCode); + // This may throw ObjectDisposedException. + SocketError errorCode = SocketPal.GetSockName(_handle, buffer, bufferSize); + if (errorCode != SocketError.Success) + { + UpdateStatusAfterSocketErrorAndThrowException(errorCode); + } } } + _localEndPoint = _rightEndPoint.Create(socketAddress); } - return _rightEndPoint.Create(socketAddress); + return _localEndPoint; } } @@ -359,6 +370,7 @@ public EndPoint? RemoteEndPoint // Update the state if we've become connected after a non-blocking connect. _isConnected = true; _rightEndPoint = _nonBlockingConnectRightEndPoint; + UpdateLocalEndPointOnConnect(); _nonBlockingConnectInProgress = false; } @@ -470,6 +482,7 @@ public bool Connected // Update the state if we've become connected after a non-blocking connect. _isConnected = true; _rightEndPoint = _nonBlockingConnectRightEndPoint; + UpdateLocalEndPointOnConnect(); _nonBlockingConnectInProgress = false; } @@ -2303,6 +2316,7 @@ private void DoBeginDisconnect(bool reuseSocket, DisconnectOverlappedAsyncResult { SetToDisconnected(); _remoteEndPoint = null; + _localEndPoint = null; } if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(this, $"UnsafeNclNativeMethods.OSSOCK.DisConnectEx returns:{errorCode}"); @@ -2332,6 +2346,7 @@ public void Disconnect(bool reuseSocket) SetToDisconnected(); _remoteEndPoint = null; + _localEndPoint = null; } // Routine Description: @@ -2760,6 +2775,7 @@ private void DoBeginSendTo(byte[] buffer, int offset, int size, SocketFlags sock catch (ObjectDisposedException) { _rightEndPoint = oldEndPoint; + _localEndPoint = null; throw; } @@ -2769,6 +2785,7 @@ private void DoBeginSendTo(byte[] buffer, int offset, int size, SocketFlags sock UpdateSendSocketErrorForDisposed(ref errorCode); // Update the internal state of this socket according to the error before throwing. _rightEndPoint = oldEndPoint; + _localEndPoint = null; throw new SocketException((int)errorCode); } @@ -3148,6 +3165,7 @@ public IAsyncResult BeginReceiveMessageFrom(byte[] buffer, int offset, int size, catch (ObjectDisposedException) { _rightEndPoint = oldEndPoint; + _localEndPoint = null; throw; } @@ -3157,6 +3175,7 @@ public IAsyncResult BeginReceiveMessageFrom(byte[] buffer, int offset, int size, { // Update the internal state of this socket according to the error before throwing. _rightEndPoint = oldEndPoint; + _localEndPoint = null; throw new SocketException((int)errorCode); } @@ -3357,6 +3376,7 @@ private void DoBeginReceiveFrom(byte[] buffer, int offset, int size, SocketFlags catch (ObjectDisposedException) { _rightEndPoint = oldEndPoint; + _localEndPoint = null; throw; } @@ -3366,6 +3386,7 @@ private void DoBeginReceiveFrom(byte[] buffer, int offset, int size, SocketFlags { // Update the internal state of this socket according to the error before throwing. _rightEndPoint = oldEndPoint; + _localEndPoint = null; throw new SocketException((int)errorCode); } @@ -3762,6 +3783,7 @@ private bool ConnectAsync(SocketAsyncEventArgs e, bool userSocket) catch { _rightEndPoint = oldEndPoint; + _localEndPoint = null; // Clear in-use flag on event args object. e.Complete(); @@ -4091,6 +4113,7 @@ public bool SendToAsync(SocketAsyncEventArgs e) catch { _rightEndPoint = null; + _localEndPoint = null; // Clear in-use flag on event args object. e.Complete(); throw; @@ -4099,6 +4122,7 @@ public bool SendToAsync(SocketAsyncEventArgs e) if (!CheckErrorAndUpdateStatus(socketError)) { _rightEndPoint = oldEndPoint; + _localEndPoint = null; } return socketError == SocketError.IOPending; @@ -4610,6 +4634,7 @@ private IAsyncResult BeginConnectEx(EndPoint remoteEP, bool flowContext, AsyncCa { // _rightEndPoint will always equal oldEndPoint. _rightEndPoint = oldEndPoint; + _localEndPoint = null; throw; } @@ -4626,6 +4651,7 @@ private IAsyncResult BeginConnectEx(EndPoint remoteEP, bool flowContext, AsyncCa UpdateConnectSocketErrorForDisposed(ref errorCode); // Update the internal state of this socket according to the error before throwing. _rightEndPoint = oldEndPoint; + _localEndPoint = null; throw new SocketException((int)errorCode); } @@ -4849,6 +4875,12 @@ internal Socket UpdateAcceptSocket(Socket socket, EndPoint remoteEP) socket._rightEndPoint = _rightEndPoint; socket._remoteEndPoint = remoteEP; + // If the listener socket was bound to a wildcard address, then the `accept` system call + // will assign a specific address to the accept socket's local endpoint instead of a + // wildcard address. In that case we should not copy listener's wildcard local endpoint. + + socket._localEndPoint = !IsWildcardEndPoint(_localEndPoint) ? _localEndPoint : null; + // The socket is connected. socket.SetToConnected(); @@ -4880,9 +4912,38 @@ internal void SetToConnected() // some point in time update the perf counter as well. _isConnected = true; _isDisconnected = false; + UpdateLocalEndPointOnConnect(); if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(this, "now connected"); } + private void UpdateLocalEndPointOnConnect() + { + // If the client socket was bound to a wildcard address, then the `connect` system call + // will assign a specific address to the client socket's local endpoint instead of a + // wildcard address. In that case we should clear the cached wildcard local endpoint. + + if (IsWildcardEndPoint(_localEndPoint)) + { + _localEndPoint = null; + } + } + + private bool IsWildcardEndPoint(EndPoint? endPoint) + { + if (endPoint == null) + { + return false; + } + + if (endPoint is IPEndPoint ipEndpoint) + { + IPAddress address = ipEndpoint.Address; + return IPAddress.Any.Equals(address) || IPAddress.IPv6Any.Equals(address) || s_IPAddressAnyMapToIPv6.Equals(address); + } + + return false; + } + internal void SetToDisconnected() { if (!_isConnected) diff --git a/src/libraries/System.Net.Sockets/tests/FunctionalTests/LocalEndPointTest.cs b/src/libraries/System.Net.Sockets/tests/FunctionalTests/LocalEndPointTest.cs new file mode 100644 index 0000000000000..ac2ce8c37e8d0 --- /dev/null +++ b/src/libraries/System.Net.Sockets/tests/FunctionalTests/LocalEndPointTest.cs @@ -0,0 +1,284 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Threading.Tasks; +using Xunit; +using Xunit.Abstractions; + +namespace System.Net.Sockets.Tests +{ + public abstract class LocalEndPointTest : SocketTestHelperBase where T : SocketHelperBase, new() + { + protected abstract bool IPv6 { get; } + + private IPAddress Wildcard => IPv6 ? IPAddress.IPv6Any : IPAddress.Any; + + private IPAddress Loopback => IPv6 ? IPAddress.IPv6Loopback : IPAddress.Loopback; + + public LocalEndPointTest(ITestOutputHelper output) : base(output) { } + + [Fact] + public async Task UdpSocket_WhenBoundToWildcardAddress_LocalEPDoesNotChangeOnSendTo() + { + using (Socket receiver = CreateUdpSocket()) + using (Socket sender = CreateUdpSocket()) + { + int receiverPort = receiver.BindToAnonymousPort(Wildcard); + + Assert.Null(sender.LocalEndPoint); + + int senderPortAfterBind = sender.BindToAnonymousPort(Wildcard); + + Assert.Equal(Wildcard, GetLocalEPAddress(sender)); // wildcard before sendto + + var sendToEP = new IPEndPoint(Loopback, receiverPort); + + await SendToAsync(sender, new byte[] { 1, 2, 3 }, sendToEP); + + Assert.Equal(Wildcard, GetLocalEPAddress(sender)); // stays as wildcard after sendto + Assert.Equal(senderPortAfterBind, GetLocalEPPort(sender)); + + byte[] buf = new byte[3]; + EndPoint receiveFromEP = new IPEndPoint(Wildcard, 0); + receiver.ReceiveFrom(buf, ref receiveFromEP); + + Assert.Equal(new byte[] { 1, 2, 3 }, buf); + Assert.Equal(Loopback, ((IPEndPoint)receiveFromEP).Address); // received from specific address + Assert.Equal(senderPortAfterBind, ((IPEndPoint)receiveFromEP).Port); + } + } + + [Fact] + public async Task UdpSocket_WhenNotBound_LocalEPChangeToWildcardOnSendTo() + { + using (Socket receiver = CreateUdpSocket()) + using (Socket sender = CreateUdpSocket()) + { + int receiverPort = receiver.BindToAnonymousPort(Wildcard); + + Assert.Null(sender.LocalEndPoint); // null before sendto + + var sendToEP = new IPEndPoint(Loopback, receiverPort); + + await SendToAsync(sender, new byte[] { 1, 2, 3 }, sendToEP); + + Assert.Equal(Wildcard, GetLocalEPAddress(sender)); // changes to wildcard after sendto + + byte[] buf = new byte[3]; + EndPoint receiveFromEP = new IPEndPoint(Wildcard, 0); + receiver.ReceiveFrom(buf, ref receiveFromEP); + + Assert.Equal(new byte[] { 1, 2, 3 }, buf); + Assert.Equal(Loopback, ((IPEndPoint)receiveFromEP).Address); // received from specific address + } + } + + [Fact] + public async Task TcpClientSocket_WhenBoundToWildcardAddress_LocalEPChangeToSpecificOnConnnect() + { + using (Socket server = CreateTcpSocket()) + using (Socket client = CreateTcpSocket()) + { + int serverPort = server.BindToAnonymousPort(Wildcard); + int clientPortAfterBind = client.BindToAnonymousPort(Wildcard); + + Assert.Equal(Wildcard, GetLocalEPAddress(client)); // wildcard before connect + + server.Listen(); + Task acceptTask = AcceptAsync(server); + + await ConnectAsync(client, new IPEndPoint(Loopback, serverPort)); + + Assert.Equal(Loopback, GetLocalEPAddress(client)); // changes to specific after connect + Assert.Equal(clientPortAfterBind, GetLocalEPPort(client)); + + Socket accept = await acceptTask; + Assert.Equal(accept.RemoteEndPoint, client.LocalEndPoint); + } + } + + [Fact] + public async Task TcpClientSocket_WhenNotBound_LocalEPChangeToSpecificOnConnnect() + { + using (Socket server = CreateTcpSocket()) + using (Socket client = CreateTcpSocket()) + { + int serverPort = server.BindToAnonymousPort(Loopback); + server.Listen(); + Task acceptTask = AcceptAsync(server); + + Assert.Null(client.LocalEndPoint); // null before connect + + await ConnectAsync(client, new IPEndPoint(Loopback, serverPort)); + + Assert.Equal(Loopback, GetLocalEPAddress(client)); // changes to specific after connect + + Socket accept = await acceptTask; + Assert.Equal(accept.RemoteEndPoint, client.LocalEndPoint); + } + } + + [Fact] + public async Task TcpAcceptSocket_WhenServerBoundToWildcardAddress_LocalEPIsSpecific() + { + using (Socket server = CreateTcpSocket()) + using (Socket client = CreateTcpSocket()) + { + int serverPort = server.BindToAnonymousPort(Wildcard); + + Assert.Equal(Wildcard, GetLocalEPAddress(server)); // server -> wildcard before accept + + server.Listen(); + Task acceptTask = AcceptAsync(server); + + await ConnectAsync(client, new IPEndPoint(Loopback, serverPort)); + + Socket accept = await acceptTask; + Assert.Equal(accept.RemoteEndPoint, client.LocalEndPoint); + Assert.Equal(accept.LocalEndPoint, client.RemoteEndPoint); + + Assert.Equal(Wildcard, GetLocalEPAddress(server)); // server -> stays as wildcard + Assert.Equal(Loopback, GetLocalEPAddress(accept)); // accept -> specific + Assert.Equal(serverPort, GetLocalEPPort(accept)); + } + } + + [Fact] + public async Task TcpAcceptSocket_WhenServerBoundToSpecificAddress_LocalEPIsSame() + { + using (Socket server = CreateTcpSocket()) + using (Socket client = CreateTcpSocket()) + { + int serverPort = server.BindToAnonymousPort(Loopback); + + Assert.Equal(Loopback, GetLocalEPAddress(server)); // server -> specific before accept + + server.Listen(); + Task acceptTask = AcceptAsync(server); + + await ConnectAsync(client, new IPEndPoint(Loopback, serverPort)); + + Socket accept = await acceptTask; + Assert.Equal(accept.RemoteEndPoint, client.LocalEndPoint); + + Assert.Equal(GetLocalEPAddress(server), GetLocalEPAddress(accept)); // accept -> same address + Assert.Equal(serverPort, GetLocalEPPort(accept)); + } + } + + [Fact] + public void LocalEndPoint_IsCached() + { + using (Socket socket = CreateTcpSocket()) + { + socket.BindToAnonymousPort(Loopback); + + EndPoint localEndPointCall1 = socket.LocalEndPoint; + EndPoint localEndPointCall2 = socket.LocalEndPoint; + + Assert.Same(localEndPointCall1, localEndPointCall2); + } + } + + private Socket CreateUdpSocket() + { + return new Socket( + IPv6 ? AddressFamily.InterNetworkV6 : AddressFamily.InterNetwork, + SocketType.Dgram, + ProtocolType.Udp + ); + } + + private Socket CreateTcpSocket() + { + return new Socket( + IPv6 ? AddressFamily.InterNetworkV6 : AddressFamily.InterNetwork, + SocketType.Stream, + ProtocolType.Tcp + ); + } + + private IPAddress GetLocalEPAddress(Socket socket) + { + return ((IPEndPoint)socket.LocalEndPoint).Address; + } + + private int GetLocalEPPort(Socket socket) + { + return ((IPEndPoint)socket.LocalEndPoint).Port; + } + } + public abstract class LocalEndPointTestIPv4 : LocalEndPointTest where T : SocketHelperBase, new() + { + protected override bool IPv6 => false; + + public LocalEndPointTestIPv4(ITestOutputHelper output) : base(output) { } + } + + public abstract class LocalEndPointTestIPv6 : LocalEndPointTest where T : SocketHelperBase, new() + { + protected override bool IPv6 => true; + + public LocalEndPointTestIPv6(ITestOutputHelper output) : base(output) { } + } + + [Trait("IPv4", "true")] + public sealed class LocalEndPointTestIPv4Sync : LocalEndPointTestIPv4 + { + public LocalEndPointTestIPv4Sync(ITestOutputHelper output) : base(output) { } + } + + [Trait("IPv4", "true")] + public sealed class LocalEndPointTestIPv4SyncForceNonBlocking : LocalEndPointTestIPv4 + { + public LocalEndPointTestIPv4SyncForceNonBlocking(ITestOutputHelper output) : base(output) { } + } + + [Trait("IPv4", "true")] + public sealed class LocalEndPointTestIPv4Apm : LocalEndPointTestIPv4 + { + public LocalEndPointTestIPv4Apm(ITestOutputHelper output) : base(output) { } + } + + [Trait("IPv4", "true")] + public sealed class LocalEndPointTestIPv4Task : LocalEndPointTestIPv4 + { + public LocalEndPointTestIPv4Task(ITestOutputHelper output) : base(output) { } + } + + [Trait("IPv4", "true")] + public sealed class LocalEndPointTestIPv4Eap : LocalEndPointTestIPv4 + { + public LocalEndPointTestIPv4Eap(ITestOutputHelper output) : base(output) { } + } + + [Trait("IPv6", "true")] + public sealed class LocalEndPointTestIPv6Sync : LocalEndPointTestIPv6 + { + public LocalEndPointTestIPv6Sync(ITestOutputHelper output) : base(output) { } + } + + [Trait("IPv6", "true")] + public sealed class LocalEndPointTestIPv6SyncForceNonBlocking : LocalEndPointTestIPv6 + { + public LocalEndPointTestIPv6SyncForceNonBlocking(ITestOutputHelper output) : base(output) { } + } + + [Trait("IPv6", "true")] + public sealed class LocalEndPointTestIPv6Apm : LocalEndPointTestIPv6 + { + public LocalEndPointTestIPv6Apm(ITestOutputHelper output) : base(output) { } + } + + [Trait("IPv6", "true")] + public sealed class LocalEndPointTestIPv6Task : LocalEndPointTestIPv6 + { + public LocalEndPointTestIPv6Task(ITestOutputHelper output) : base(output) { } + } + + [Trait("IPv6", "true")] + public sealed class LocalEndPointTestIPv6Eap : LocalEndPointTestIPv6 + { + public LocalEndPointTestIPv6Eap(ITestOutputHelper output) : base(output) { } + } +} diff --git a/src/libraries/System.Net.Sockets/tests/FunctionalTests/System.Net.Sockets.Tests.csproj b/src/libraries/System.Net.Sockets/tests/FunctionalTests/System.Net.Sockets.Tests.csproj index 43662fa98f8df..1e51b4f1099b3 100644 --- a/src/libraries/System.Net.Sockets/tests/FunctionalTests/System.Net.Sockets.Tests.csproj +++ b/src/libraries/System.Net.Sockets/tests/FunctionalTests/System.Net.Sockets.Tests.csproj @@ -23,6 +23,7 @@ +