From 2e977ea84b6fe48f14e2d7e09aa52cc157be77d6 Mon Sep 17 00:00:00 2001 From: Natalia Kondratyeva Date: Tue, 14 Jul 2020 16:06:50 +0300 Subject: [PATCH 01/16] Add cache for LocalEndPoint --- .../src/System/Net/Sockets/Socket.Windows.cs | 1 + .../src/System/Net/Sockets/Socket.cs | 64 +++++++++++++++---- 2 files changed, 52 insertions(+), 13 deletions(-) diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Windows.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Windows.cs index d20fbea8a1b21..fe47f20624129 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Windows.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Windows.cs @@ -66,6 +66,7 @@ public Socket(SocketInformation socketInformation) if (errorCode == SocketError.Success) { _rightEndPoint = ep.Create(socketAddress); + _localEndPoint = null; } else if (errorCode == SocketError.InvalidArgument) { diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs index 6bddfe34d96ab..788e5e4bb143a 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs @@ -25,6 +25,7 @@ public partial class Socket : IDisposable // _rightEndPoint is null if the socket has not been bound. Otherwise, it is any EndPoint of the // correct type (IPEndPoint, etc). internal EndPoint? _rightEndPoint; + private EndPoint? _localEndPoint; // Cached LocalEndPoint value. Should be cleared on any _rightEndPoint change internal EndPoint? _remoteEndPoint; // These flags monitor if the socket was ever connected at any time and if it still is. @@ -178,6 +179,7 @@ private unsafe Socket(SafeSocketHandle handle, bool loadPropertiesFromHandle) _rightEndPoint = new UnixDomainSocketEndPoint(IPEndPointExtensions.GetNetSocketAddress(socketAddress)); break; } + _localEndPoint = null; // clear cached value, if any, after _rightEndPoint change // Try to determine if we're connected, based on querying for a peer, just as we would in RemoteEndPoint, // but ignoring any failures; this is best-effort (RemoteEndPoint also does a catch-all around the Create call). @@ -315,31 +317,36 @@ public EndPoint? LocalEndPoint // Update the state if we've become connected after a non-blocking connect. _isConnected = true; _rightEndPoint = _nonBlockingConnectRightEndPoint; + _localEndPoint = null; // clear cached value, if any, after _rightEndPoint change _nonBlockingConnectInProgress = false; } - if (_rightEndPoint == null) + if (_localEndPoint == null) { - return null; - } + if (_rightEndPoint == null) + { + return null; + } - Internals.SocketAddress socketAddress = IPEndPointExtensions.Serialize(_rightEndPoint); + Internals.SocketAddress socketAddress = IPEndPointExtensions.Serialize(_rightEndPoint); - unsafe - { - fixed (byte* buffer = socketAddress.Buffer) - fixed (int* bufferSize = &socketAddress.InternalSize) + unsafe { - // This may throw ObjectDisposedException. - SocketError errorCode = SocketPal.GetSockName(_handle, buffer, bufferSize); - if (errorCode != SocketError.Success) + fixed (byte* buffer = socketAddress.Buffer) + fixed (int* bufferSize = &socketAddress.InternalSize) { - UpdateStatusAfterSocketErrorAndThrowException(errorCode); + // This may throw ObjectDisposedException. + SocketError errorCode = SocketPal.GetSockName(_handle, buffer, bufferSize); + if (errorCode != SocketError.Success) + { + UpdateStatusAfterSocketErrorAndThrowException(errorCode); + } } } + _localEndPoint = _rightEndPoint.Create(socketAddress); } - return _rightEndPoint.Create(socketAddress); + return _localEndPoint; } } @@ -357,6 +364,7 @@ public EndPoint? RemoteEndPoint // Update the state if we've become connected after a non-blocking connect. _isConnected = true; _rightEndPoint = _nonBlockingConnectRightEndPoint; + _localEndPoint = null; // clear cached value, if any, after _rightEndPoint change _nonBlockingConnectInProgress = false; } @@ -468,6 +476,7 @@ public bool Connected // Update the state if we've become connected after a non-blocking connect. _isConnected = true; _rightEndPoint = _nonBlockingConnectRightEndPoint; + _localEndPoint = null; // clear cached value, if any, after _rightEndPoint change _nonBlockingConnectInProgress = false; } @@ -830,6 +839,7 @@ private void DoBind(EndPoint endPointSnapshot, Internals.SocketAddress socketAdd { // Save a copy of the EndPoint so we can use it for Create(). _rightEndPoint = endPointSnapshot; + _localEndPoint = null; // clear cached value, if any, after _rightEndPoint change } } @@ -1325,6 +1335,7 @@ public int SendTo(byte[] buffer, int offset, int size, SocketFlags socketFlags, { // Save a copy of the EndPoint so we can use it for Create(). _rightEndPoint = remoteEP; + _localEndPoint = null; // clear cached value, if any, after _rightEndPoint change } if (NetEventSource.Log.IsEnabled()) NetEventSource.DumpBuffer(this, buffer, offset, size); @@ -1562,6 +1573,7 @@ public int ReceiveMessageFrom(byte[] buffer, int offset, int size, ref SocketFla { // Save a copy of the EndPoint so we can use it for Create(). _rightEndPoint = endPointSnapshot; + _localEndPoint = null; // clear cached value, if any, after _rightEndPoint change } } @@ -1646,6 +1658,7 @@ public int ReceiveFrom(byte[] buffer, int offset, int size, SocketFlags socketFl { // Save a copy of the EndPoint so we can use it for Create(). _rightEndPoint = endPointSnapshot; + _localEndPoint = null; // clear cached value, if any, after _rightEndPoint change } } @@ -2674,6 +2687,7 @@ private void DoBeginSendTo(byte[] buffer, int offset, int size, SocketFlags sock if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(this, $"size:{size}"); EndPoint? oldEndPoint = _rightEndPoint; + EndPoint? oldLocalEndPoint = _localEndPoint; // Guarantee to call CheckAsyncCallOverlappedResult if we call SetUnamangedStructures with a cache in order to // avoid a Socket leak in case of error. @@ -2683,6 +2697,7 @@ private void DoBeginSendTo(byte[] buffer, int offset, int size, SocketFlags sock if (_rightEndPoint == null) { _rightEndPoint = endPointSnapshot; + _localEndPoint = null; // clear cached value, if any, after _rightEndPoint change } errorCode = SocketPal.SendToAsync(_handle, buffer, offset, size, socketFlags, socketAddress, asyncResult); @@ -2692,6 +2707,7 @@ private void DoBeginSendTo(byte[] buffer, int offset, int size, SocketFlags sock catch (ObjectDisposedException) { _rightEndPoint = oldEndPoint; + _localEndPoint = oldLocalEndPoint; throw; } @@ -2701,6 +2717,7 @@ private void DoBeginSendTo(byte[] buffer, int offset, int size, SocketFlags sock UpdateSendSocketErrorForDisposed(ref errorCode); // Update the internal state of this socket according to the error before throwing. _rightEndPoint = oldEndPoint; + _localEndPoint = oldLocalEndPoint; throw new SocketException((int)errorCode); } @@ -3029,6 +3046,7 @@ public IAsyncResult BeginReceiveMessageFrom(byte[] buffer, int offset, int size, // Start the ReceiveFrom. EndPoint oldEndPoint = _rightEndPoint; + EndPoint? oldLocalEndPoint = _localEndPoint; // We don't do a CAS demand here because the contents of remoteEP aren't used by // WSARecvMsg; all that matters is that we generate a unique-to-this-call SocketAddress @@ -3048,6 +3066,7 @@ public IAsyncResult BeginReceiveMessageFrom(byte[] buffer, int offset, int size, if (_rightEndPoint == null) { _rightEndPoint = remoteEP; + _localEndPoint = null; // clear cached value, if any, after _rightEndPoint change } errorCode = SocketPal.ReceiveMessageFromAsync(this, _handle, buffer, offset, size, socketFlags, socketAddress, asyncResult); @@ -3070,6 +3089,7 @@ public IAsyncResult BeginReceiveMessageFrom(byte[] buffer, int offset, int size, catch (ObjectDisposedException) { _rightEndPoint = oldEndPoint; + _localEndPoint = oldLocalEndPoint; throw; } @@ -3079,6 +3099,7 @@ public IAsyncResult BeginReceiveMessageFrom(byte[] buffer, int offset, int size, { // Update the internal state of this socket according to the error before throwing. _rightEndPoint = oldEndPoint; + _localEndPoint = oldLocalEndPoint; throw new SocketException((int)errorCode); } @@ -3251,6 +3272,7 @@ public IAsyncResult BeginReceiveFrom(byte[] buffer, int offset, int size, Socket private void DoBeginReceiveFrom(byte[] buffer, int offset, int size, SocketFlags socketFlags, EndPoint endPointSnapshot, Internals.SocketAddress socketAddress, OriginalAddressOverlappedAsyncResult asyncResult) { EndPoint? oldEndPoint = _rightEndPoint; + EndPoint? oldLocalEndPoint = _localEndPoint; if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(this, $"size:{size}"); @@ -3265,6 +3287,7 @@ private void DoBeginReceiveFrom(byte[] buffer, int offset, int size, SocketFlags if (_rightEndPoint == null) { _rightEndPoint = endPointSnapshot; + _localEndPoint = null; // clear cached value, if any, after _rightEndPoint change } errorCode = SocketPal.ReceiveFromAsync(_handle, buffer, offset, size, socketFlags, socketAddress, asyncResult); @@ -3274,6 +3297,7 @@ private void DoBeginReceiveFrom(byte[] buffer, int offset, int size, SocketFlags catch (ObjectDisposedException) { _rightEndPoint = oldEndPoint; + _localEndPoint = oldLocalEndPoint; throw; } @@ -3283,6 +3307,7 @@ private void DoBeginReceiveFrom(byte[] buffer, int offset, int size, SocketFlags { // Update the internal state of this socket according to the error before throwing. _rightEndPoint = oldEndPoint; + _localEndPoint = oldLocalEndPoint; throw new SocketException((int)errorCode); } @@ -3663,9 +3688,11 @@ private bool ConnectAsync(SocketAsyncEventArgs e, bool userSocket) // Save the old RightEndPoint and prep new RightEndPoint. EndPoint? oldEndPoint = _rightEndPoint; + EndPoint? oldLocalEndPoint = _localEndPoint; if (_rightEndPoint == null) { _rightEndPoint = endPointSnapshot; + _localEndPoint = null; // clear cached value, if any, after _rightEndPoint change } // Prepare for the native call. @@ -3689,6 +3716,7 @@ private bool ConnectAsync(SocketAsyncEventArgs e, bool userSocket) catch { _rightEndPoint = oldEndPoint; + _localEndPoint = oldLocalEndPoint; // Clear in-use flag on event args object. e.Complete(); @@ -4005,9 +4033,11 @@ public bool SendToAsync(SocketAsyncEventArgs e) e.StartOperationCommon(this, SocketAsyncOperation.SendTo); EndPoint? oldEndPoint = _rightEndPoint; + EndPoint? oldLocalEndPoint = _localEndPoint; if (_rightEndPoint == null) { _rightEndPoint = endPointSnapshot; + _localEndPoint = null; // clear cached value, if any, after _rightEndPoint change } SocketError socketError; @@ -4018,6 +4048,7 @@ public bool SendToAsync(SocketAsyncEventArgs e) catch { _rightEndPoint = null; + _localEndPoint = null; // clear cached value, if any, after _rightEndPoint change // Clear in-use flag on event args object. e.Complete(); throw; @@ -4026,6 +4057,7 @@ public bool SendToAsync(SocketAsyncEventArgs e) if (!CheckErrorAndUpdateStatus(socketError)) { _rightEndPoint = oldEndPoint; + _localEndPoint = oldLocalEndPoint; } return socketError == SocketError.IOPending; @@ -4138,6 +4170,7 @@ private void DoConnect(EndPoint endPointSnapshot, Internals.SocketAddress socket { // Save a copy of the EndPoint so we can use it for Create(). _rightEndPoint = endPointSnapshot; + _localEndPoint = null; // clear cached value, if any, after _rightEndPoint change } if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(this, $"connection to:{endPointSnapshot}"); @@ -4516,9 +4549,11 @@ private IAsyncResult BeginConnectEx(EndPoint remoteEP, bool flowContext, AsyncCa } EndPoint? oldEndPoint = _rightEndPoint; + EndPoint? oldLocalEndPoint = _localEndPoint; if (_rightEndPoint == null) { _rightEndPoint = endPointSnapshot; + _localEndPoint = null; // clear cached value, if any, after _rightEndPoint change } SocketError errorCode; @@ -4530,6 +4565,7 @@ private IAsyncResult BeginConnectEx(EndPoint remoteEP, bool flowContext, AsyncCa { // _rightEndPoint will always equal oldEndPoint. _rightEndPoint = oldEndPoint; + _localEndPoint = oldLocalEndPoint; throw; } @@ -4546,6 +4582,7 @@ private IAsyncResult BeginConnectEx(EndPoint remoteEP, bool flowContext, AsyncCa UpdateConnectSocketErrorForDisposed(ref errorCode); // Update the internal state of this socket according to the error before throwing. _rightEndPoint = oldEndPoint; + _localEndPoint = oldLocalEndPoint; throw new SocketException((int)errorCode); } @@ -4767,6 +4804,7 @@ internal Socket UpdateAcceptSocket(Socket socket, EndPoint remoteEP) socket._socketType = _socketType; socket._protocolType = _protocolType; socket._rightEndPoint = _rightEndPoint; + socket._localEndPoint = null; socket._remoteEndPoint = remoteEP; // The socket is connected. From 579f788a54470cdc3f9dbfe9f99e75a317d3f2c9 Mon Sep 17 00:00:00 2001 From: Natalia Kondratyeva Date: Tue, 14 Jul 2020 16:34:19 +0300 Subject: [PATCH 02/16] Add RightEndPoint property, clear _localEndPoint in RightEndPoint setter --- .../AcceptOverlappedAsyncResult.Unix.cs | 6 +- .../AcceptOverlappedAsyncResult.Windows.cs | 4 +- .../src/System/Net/Sockets/Socket.Windows.cs | 7 +- .../src/System/Net/Sockets/Socket.cs | 170 +++++++++--------- .../Net/Sockets/SocketAsyncEventArgs.Unix.cs | 2 +- .../Net/Sockets/SocketAsyncEventArgs.cs | 6 +- 6 files changed, 93 insertions(+), 102 deletions(-) diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/AcceptOverlappedAsyncResult.Unix.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/AcceptOverlappedAsyncResult.Unix.cs index 657544b27c319..77c09deb8f102 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/AcceptOverlappedAsyncResult.Unix.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/AcceptOverlappedAsyncResult.Unix.cs @@ -32,14 +32,14 @@ public void CompletionCallback(IntPtr acceptedFileDescriptor, byte[] socketAddre if (errorCode == SocketError.Success) { - Debug.Assert(_listenSocket._rightEndPoint != null); + Debug.Assert(_listenSocket.RightEndPoint != null); - Internals.SocketAddress remoteSocketAddress = IPEndPointExtensions.Serialize(_listenSocket._rightEndPoint); + Internals.SocketAddress remoteSocketAddress = IPEndPointExtensions.Serialize(_listenSocket.RightEndPoint); System.Buffer.BlockCopy(socketAddress, 0, remoteSocketAddress.Buffer, 0, socketAddressLen); _acceptedSocket = _listenSocket.CreateAcceptSocket( SocketPal.CreateSocket(acceptedFileDescriptor), - _listenSocket._rightEndPoint.Create(remoteSocketAddress)); + _listenSocket.RightEndPoint.Create(remoteSocketAddress)); } base.CompletionCallback(0, errorCode); diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/AcceptOverlappedAsyncResult.Windows.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/AcceptOverlappedAsyncResult.Windows.cs index d85b2bfc8c4af..b3d009474d254 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/AcceptOverlappedAsyncResult.Windows.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/AcceptOverlappedAsyncResult.Windows.cs @@ -25,7 +25,7 @@ internal sealed partial class AcceptOverlappedAsyncResult : BaseOverlappedAsyncR if (NetEventSource.Log.IsEnabled()) LogBuffer(numBytes); // get the endpoint - remoteSocketAddress = IPEndPointExtensions.Serialize(_listenSocket._rightEndPoint!); + remoteSocketAddress = IPEndPointExtensions.Serialize(_listenSocket.RightEndPoint!); IntPtr localAddr; int localAddrLength; @@ -86,7 +86,7 @@ internal sealed partial class AcceptOverlappedAsyncResult : BaseOverlappedAsyncR return null; } - return _listenSocket.UpdateAcceptSocket(_acceptSocket!, _listenSocket._rightEndPoint!.Create(remoteSocketAddress!)); + return _listenSocket.UpdateAcceptSocket(_acceptSocket!, _listenSocket.RightEndPoint!.Create(remoteSocketAddress!)); } // SetUnmanagedStructures diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Windows.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Windows.cs index fe47f20624129..cd466fbda014b 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Windows.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Windows.cs @@ -65,8 +65,7 @@ public Socket(SocketInformation socketInformation) if (errorCode == SocketError.Success) { - _rightEndPoint = ep.Create(socketAddress); - _localEndPoint = null; + RightEndPoint = ep.Create(socketAddress); } else if (errorCode == SocketError.InvalidArgument) { @@ -196,7 +195,7 @@ internal bool DisconnectExBlocking(SafeSocketHandle socketHandle, IntPtr overlap partial void WildcardBindForConnectIfNecessary(AddressFamily addressFamily) { - if (_rightEndPoint != null) + if (RightEndPoint != null) { return; } @@ -351,7 +350,7 @@ private Socket GetOrCreateAcceptSocket(Socket? acceptSocket, bool checkDisconnec { acceptSocket = new Socket(_addressFamily, _socketType, _protocolType); } - else if (acceptSocket._rightEndPoint != null && (!checkDisconnected || !acceptSocket._isDisconnected)) + else if (acceptSocket.RightEndPoint != null && (!checkDisconnected || !acceptSocket._isDisconnected)) { throw new InvalidOperationException(SR.Format(SR.net_sockets_namedmustnotbebound, propertyName)); } diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs index 788e5e4bb143a..b693855c2154d 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs @@ -22,10 +22,19 @@ public partial class Socket : IDisposable private SafeSocketHandle _handle; - // _rightEndPoint is null if the socket has not been bound. Otherwise, it is any EndPoint of the + // RightEndPoint is null if the socket has not been bound. Otherwise, it is any EndPoint of the // correct type (IPEndPoint, etc). - internal EndPoint? _rightEndPoint; - private EndPoint? _localEndPoint; // Cached LocalEndPoint value. Should be cleared on any _rightEndPoint change + private EndPoint? _rightEndPoint; + internal EndPoint? RightEndPoint + { + get { return _rightEndPoint; } + private set + { + _rightEndPoint = value; + _localEndPoint = null; + } + } + private EndPoint? _localEndPoint; // Cached LocalEndPoint value. Should be cleared on any RightEndPoint change internal EndPoint? _remoteEndPoint; // These flags monitor if the socket was ever connected at any time and if it still is. @@ -44,7 +53,7 @@ public partial class Socket : IDisposable private bool _nonBlockingConnectInProgress; // Keep track of the kind of endpoint used to do a non-blocking connect, so we can set - // it to _rightEndPoint when we discover we're connected. + // it to RightEndPoint when we discover we're connected. private EndPoint? _nonBlockingConnectRightEndPoint; // These are constants initialized by constructor. @@ -161,7 +170,7 @@ private unsafe Socket(SafeSocketHandle handle, bool loadPropertiesFromHandle) switch (_addressFamily) { case AddressFamily.InterNetwork: - _rightEndPoint = new IPEndPoint( + RightEndPoint = new IPEndPoint( new IPAddress((long)SocketAddressPal.GetIPv4Address(buffer.Slice(0, bufferLength)) & 0x0FFFFFFFF), SocketAddressPal.GetPort(buffer)); break; @@ -169,21 +178,20 @@ private unsafe Socket(SafeSocketHandle handle, bool loadPropertiesFromHandle) case AddressFamily.InterNetworkV6: Span address = stackalloc byte[IPAddressParserStatics.IPv6AddressBytes]; SocketAddressPal.GetIPv6Address(buffer.Slice(0, bufferLength), address, out uint scope); - _rightEndPoint = new IPEndPoint( + RightEndPoint = new IPEndPoint( new IPAddress(address, scope), SocketAddressPal.GetPort(buffer)); break; case AddressFamily.Unix: socketAddress = new Internals.SocketAddress(_addressFamily, buffer.Slice(0, bufferLength)); - _rightEndPoint = new UnixDomainSocketEndPoint(IPEndPointExtensions.GetNetSocketAddress(socketAddress)); + RightEndPoint = new UnixDomainSocketEndPoint(IPEndPointExtensions.GetNetSocketAddress(socketAddress)); break; } - _localEndPoint = null; // clear cached value, if any, after _rightEndPoint change // Try to determine if we're connected, based on querying for a peer, just as we would in RemoteEndPoint, // but ignoring any failures; this is best-effort (RemoteEndPoint also does a catch-all around the Create call). - if (_rightEndPoint != null) + if (RightEndPoint != null) { try { @@ -316,19 +324,18 @@ public EndPoint? LocalEndPoint { // Update the state if we've become connected after a non-blocking connect. _isConnected = true; - _rightEndPoint = _nonBlockingConnectRightEndPoint; - _localEndPoint = null; // clear cached value, if any, after _rightEndPoint change + RightEndPoint = _nonBlockingConnectRightEndPoint; _nonBlockingConnectInProgress = false; } if (_localEndPoint == null) { - if (_rightEndPoint == null) + if (RightEndPoint == null) { return null; } - Internals.SocketAddress socketAddress = IPEndPointExtensions.Serialize(_rightEndPoint); + Internals.SocketAddress socketAddress = IPEndPointExtensions.Serialize(RightEndPoint); unsafe { @@ -343,7 +350,7 @@ public EndPoint? LocalEndPoint } } } - _localEndPoint = _rightEndPoint.Create(socketAddress); + _localEndPoint = RightEndPoint.Create(socketAddress); } return _localEndPoint; @@ -363,20 +370,19 @@ public EndPoint? RemoteEndPoint { // Update the state if we've become connected after a non-blocking connect. _isConnected = true; - _rightEndPoint = _nonBlockingConnectRightEndPoint; - _localEndPoint = null; // clear cached value, if any, after _rightEndPoint change + RightEndPoint = _nonBlockingConnectRightEndPoint; _nonBlockingConnectInProgress = false; } - if (_rightEndPoint == null || !_isConnected) + if (RightEndPoint == null || !_isConnected) { return null; } Internals.SocketAddress socketAddress = _addressFamily == AddressFamily.InterNetwork || _addressFamily == AddressFamily.InterNetworkV6 ? - IPEndPointExtensions.Serialize(_rightEndPoint) : - new Internals.SocketAddress(_addressFamily, SocketPal.MaximumAddressSize); // may be different size than _rightEndPoint. + IPEndPointExtensions.Serialize(RightEndPoint) : + new Internals.SocketAddress(_addressFamily, SocketPal.MaximumAddressSize); // may be different size than RightEndPoint. // This may throw ObjectDisposedException. SocketError errorCode = SocketPal.GetPeerName( @@ -391,7 +397,7 @@ public EndPoint? RemoteEndPoint try { - _remoteEndPoint = _rightEndPoint.Create(socketAddress); + _remoteEndPoint = RightEndPoint.Create(socketAddress); } catch { @@ -475,8 +481,7 @@ public bool Connected { // Update the state if we've become connected after a non-blocking connect. _isConnected = true; - _rightEndPoint = _nonBlockingConnectRightEndPoint; - _localEndPoint = null; // clear cached value, if any, after _rightEndPoint change + RightEndPoint = _nonBlockingConnectRightEndPoint; _nonBlockingConnectInProgress = false; } @@ -515,7 +520,7 @@ public bool IsBound { get { - return (_rightEndPoint != null); + return (RightEndPoint != null); } } @@ -835,11 +840,10 @@ private void DoBind(EndPoint endPointSnapshot, Internals.SocketAddress socketAdd UpdateStatusAfterSocketErrorAndThrowException(errorCode); } - if (_rightEndPoint == null) + if (RightEndPoint == null) { // Save a copy of the EndPoint so we can use it for Create(). - _rightEndPoint = endPointSnapshot; - _localEndPoint = null; // clear cached value, if any, after _rightEndPoint change + RightEndPoint = endPointSnapshot; } } @@ -1072,7 +1076,7 @@ public Socket Accept() ThrowIfDisposed(); - if (_rightEndPoint == null) + if (RightEndPoint == null) { throw new InvalidOperationException(SR.net_sockets_mustbind); } @@ -1092,7 +1096,7 @@ public Socket Accept() Internals.SocketAddress socketAddress = _addressFamily == AddressFamily.InterNetwork || _addressFamily == AddressFamily.InterNetworkV6 ? - IPEndPointExtensions.Serialize(_rightEndPoint) : + IPEndPointExtensions.Serialize(RightEndPoint) : new Internals.SocketAddress(_addressFamily, SocketPal.MaximumAddressSize); // may be different size. // This may throw ObjectDisposedException. @@ -1113,7 +1117,7 @@ public Socket Accept() Debug.Assert(!acceptedSocketHandle.IsInvalid); - Socket socket = CreateAcceptSocket(acceptedSocketHandle, _rightEndPoint.Create(socketAddress)); + Socket socket = CreateAcceptSocket(acceptedSocketHandle, RightEndPoint.Create(socketAddress)); if (NetEventSource.Log.IsEnabled()) NetEventSource.Accepted(socket, socket.RemoteEndPoint!, socket.LocalEndPoint); return socket; } @@ -1331,11 +1335,10 @@ public int SendTo(byte[] buffer, int offset, int size, SocketFlags socketFlags, UpdateStatusAfterSocketErrorAndThrowException(errorCode); } - if (_rightEndPoint == null) + if (RightEndPoint == null) { // Save a copy of the EndPoint so we can use it for Create(). - _rightEndPoint = remoteEP; - _localEndPoint = null; // clear cached value, if any, after _rightEndPoint change + RightEndPoint = remoteEP; } if (NetEventSource.Log.IsEnabled()) NetEventSource.DumpBuffer(this, buffer, offset, size); @@ -1529,7 +1532,7 @@ public int ReceiveMessageFrom(byte[] buffer, int offset, int size, ref SocketFla { throw new ArgumentOutOfRangeException(nameof(size)); } - if (_rightEndPoint == null) + if (RightEndPoint == null) { throw new InvalidOperationException(SR.net_sockets_mustbind); } @@ -1569,11 +1572,10 @@ public int ReceiveMessageFrom(byte[] buffer, int offset, int size, ref SocketFla catch { } - if (_rightEndPoint == null) + if (RightEndPoint == null) { // Save a copy of the EndPoint so we can use it for Create(). - _rightEndPoint = endPointSnapshot; - _localEndPoint = null; // clear cached value, if any, after _rightEndPoint change + RightEndPoint = endPointSnapshot; } } @@ -1609,7 +1611,7 @@ public int ReceiveFrom(byte[] buffer, int offset, int size, SocketFlags socketFl { throw new ArgumentOutOfRangeException(nameof(size)); } - if (_rightEndPoint == null) + if (RightEndPoint == null) { throw new InvalidOperationException(SR.net_sockets_mustbind); } @@ -1654,11 +1656,10 @@ public int ReceiveFrom(byte[] buffer, int offset, int size, SocketFlags socketFl catch { } - if (_rightEndPoint == null) + if (RightEndPoint == null) { // Save a copy of the EndPoint so we can use it for Create(). - _rightEndPoint = endPointSnapshot; - _localEndPoint = null; // clear cached value, if any, after _rightEndPoint change + RightEndPoint = endPointSnapshot; } } @@ -2078,7 +2079,7 @@ private bool CanUseConnectEx(EndPoint remoteEP) // Unix sockets are not supported by ConnectEx. return (_socketType == SocketType.Stream) && - (_rightEndPoint != null || remoteEP.GetType() == typeof(IPEndPoint)) && + (RightEndPoint != null || remoteEP.GetType() == typeof(IPEndPoint)) && (remoteEP.AddressFamily != AddressFamily.Unix); } @@ -2686,7 +2687,7 @@ private void DoBeginSendTo(byte[] buffer, int offset, int size, SocketFlags sock { if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(this, $"size:{size}"); - EndPoint? oldEndPoint = _rightEndPoint; + EndPoint? oldEndPoint = RightEndPoint; EndPoint? oldLocalEndPoint = _localEndPoint; // Guarantee to call CheckAsyncCallOverlappedResult if we call SetUnamangedStructures with a cache in order to @@ -2694,10 +2695,9 @@ private void DoBeginSendTo(byte[] buffer, int offset, int size, SocketFlags sock SocketError errorCode = SocketError.SocketError; try { - if (_rightEndPoint == null) + if (RightEndPoint == null) { - _rightEndPoint = endPointSnapshot; - _localEndPoint = null; // clear cached value, if any, after _rightEndPoint change + RightEndPoint = endPointSnapshot; } errorCode = SocketPal.SendToAsync(_handle, buffer, offset, size, socketFlags, socketAddress, asyncResult); @@ -2706,7 +2706,7 @@ private void DoBeginSendTo(byte[] buffer, int offset, int size, SocketFlags sock } catch (ObjectDisposedException) { - _rightEndPoint = oldEndPoint; + RightEndPoint = oldEndPoint; _localEndPoint = oldLocalEndPoint; throw; } @@ -2716,7 +2716,7 @@ private void DoBeginSendTo(byte[] buffer, int offset, int size, SocketFlags sock { UpdateSendSocketErrorForDisposed(ref errorCode); // Update the internal state of this socket according to the error before throwing. - _rightEndPoint = oldEndPoint; + RightEndPoint = oldEndPoint; _localEndPoint = oldLocalEndPoint; throw new SocketException((int)errorCode); @@ -3033,7 +3033,7 @@ public IAsyncResult BeginReceiveMessageFrom(byte[] buffer, int offset, int size, { throw new ArgumentOutOfRangeException(nameof(size)); } - if (_rightEndPoint == null) + if (RightEndPoint == null) { throw new InvalidOperationException(SR.net_sockets_mustbind); } @@ -3045,7 +3045,7 @@ public IAsyncResult BeginReceiveMessageFrom(byte[] buffer, int offset, int size, asyncResult.StartPostingAsyncOp(false); // Start the ReceiveFrom. - EndPoint oldEndPoint = _rightEndPoint; + EndPoint oldEndPoint = RightEndPoint; EndPoint? oldLocalEndPoint = _localEndPoint; // We don't do a CAS demand here because the contents of remoteEP aren't used by @@ -3063,10 +3063,9 @@ public IAsyncResult BeginReceiveMessageFrom(byte[] buffer, int offset, int size, SetReceivingPacketInformation(); - if (_rightEndPoint == null) + if (RightEndPoint == null) { - _rightEndPoint = remoteEP; - _localEndPoint = null; // clear cached value, if any, after _rightEndPoint change + RightEndPoint = remoteEP; } errorCode = SocketPal.ReceiveMessageFromAsync(this, _handle, buffer, offset, size, socketFlags, socketAddress, asyncResult); @@ -3088,7 +3087,7 @@ public IAsyncResult BeginReceiveMessageFrom(byte[] buffer, int offset, int size, } catch (ObjectDisposedException) { - _rightEndPoint = oldEndPoint; + RightEndPoint = oldEndPoint; _localEndPoint = oldLocalEndPoint; throw; } @@ -3098,7 +3097,7 @@ public IAsyncResult BeginReceiveMessageFrom(byte[] buffer, int offset, int size, if (!CheckErrorAndUpdateStatus(errorCode)) { // Update the internal state of this socket according to the error before throwing. - _rightEndPoint = oldEndPoint; + RightEndPoint = oldEndPoint; _localEndPoint = oldLocalEndPoint; throw new SocketException((int)errorCode); @@ -3233,7 +3232,7 @@ public IAsyncResult BeginReceiveFrom(byte[] buffer, int offset, int size, Socket { throw new ArgumentOutOfRangeException(nameof(size)); } - if (_rightEndPoint == null) + if (RightEndPoint == null) { throw new InvalidOperationException(SR.net_sockets_mustbind); } @@ -3271,7 +3270,7 @@ public IAsyncResult BeginReceiveFrom(byte[] buffer, int offset, int size, Socket private void DoBeginReceiveFrom(byte[] buffer, int offset, int size, SocketFlags socketFlags, EndPoint endPointSnapshot, Internals.SocketAddress socketAddress, OriginalAddressOverlappedAsyncResult asyncResult) { - EndPoint? oldEndPoint = _rightEndPoint; + EndPoint? oldEndPoint = RightEndPoint; EndPoint? oldLocalEndPoint = _localEndPoint; if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(this, $"size:{size}"); @@ -3284,10 +3283,9 @@ private void DoBeginReceiveFrom(byte[] buffer, int offset, int size, SocketFlags // Save a copy of the original EndPoint in the asyncResult. asyncResult.SocketAddressOriginal = IPEndPointExtensions.Serialize(endPointSnapshot); - if (_rightEndPoint == null) + if (RightEndPoint == null) { - _rightEndPoint = endPointSnapshot; - _localEndPoint = null; // clear cached value, if any, after _rightEndPoint change + RightEndPoint = endPointSnapshot; } errorCode = SocketPal.ReceiveFromAsync(_handle, buffer, offset, size, socketFlags, socketAddress, asyncResult); @@ -3296,7 +3294,7 @@ private void DoBeginReceiveFrom(byte[] buffer, int offset, int size, SocketFlags } catch (ObjectDisposedException) { - _rightEndPoint = oldEndPoint; + RightEndPoint = oldEndPoint; _localEndPoint = oldLocalEndPoint; throw; } @@ -3306,7 +3304,7 @@ private void DoBeginReceiveFrom(byte[] buffer, int offset, int size, SocketFlags if (!CheckErrorAndUpdateStatus(errorCode)) { // Update the internal state of this socket according to the error before throwing. - _rightEndPoint = oldEndPoint; + RightEndPoint = oldEndPoint; _localEndPoint = oldLocalEndPoint; throw new SocketException((int)errorCode); @@ -3448,7 +3446,7 @@ public IAsyncResult BeginAccept(Socket? acceptSocket, int receiveSize, AsyncCall private void DoBeginAccept(Socket? acceptSocket, int receiveSize, AcceptOverlappedAsyncResult asyncResult) { - if (_rightEndPoint == null) + if (RightEndPoint == null) { throw new InvalidOperationException(SR.net_sockets_mustbind); } @@ -3463,7 +3461,7 @@ private void DoBeginAccept(Socket? acceptSocket, int receiveSize, AcceptOverlapp if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(this, $"AcceptSocket:{acceptSocket}"); - int socketAddressSize = GetAddressSize(_rightEndPoint); + int socketAddressSize = GetAddressSize(RightEndPoint); SocketError errorCode = SocketPal.AcceptAsync(this, _handle, acceptHandle, receiveSize, socketAddressSize, asyncResult); if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(this, $"AcceptAsync returns:{errorCode} {asyncResult}"); @@ -3580,7 +3578,7 @@ public bool AcceptAsync(SocketAsyncEventArgs e) { throw new ArgumentException(SR.net_multibuffernotsupported, nameof(e)); } - if (_rightEndPoint == null) + if (RightEndPoint == null) { throw new InvalidOperationException(SR.net_sockets_mustbind); } @@ -3687,12 +3685,11 @@ private bool ConnectAsync(SocketAsyncEventArgs e, bool userSocket) WildcardBindForConnectIfNecessary(endPointSnapshot.AddressFamily); // Save the old RightEndPoint and prep new RightEndPoint. - EndPoint? oldEndPoint = _rightEndPoint; + EndPoint? oldEndPoint = RightEndPoint; EndPoint? oldLocalEndPoint = _localEndPoint; - if (_rightEndPoint == null) + if (RightEndPoint == null) { - _rightEndPoint = endPointSnapshot; - _localEndPoint = null; // clear cached value, if any, after _rightEndPoint change + RightEndPoint = endPointSnapshot; } // Prepare for the native call. @@ -3715,7 +3712,7 @@ private bool ConnectAsync(SocketAsyncEventArgs e, bool userSocket) } catch { - _rightEndPoint = oldEndPoint; + RightEndPoint = oldEndPoint; _localEndPoint = oldLocalEndPoint; // Clear in-use flag on event args object. @@ -4032,12 +4029,11 @@ public bool SendToAsync(SocketAsyncEventArgs e) // Prepare for and make the native call. e.StartOperationCommon(this, SocketAsyncOperation.SendTo); - EndPoint? oldEndPoint = _rightEndPoint; + EndPoint? oldEndPoint = RightEndPoint; EndPoint? oldLocalEndPoint = _localEndPoint; - if (_rightEndPoint == null) + if (RightEndPoint == null) { - _rightEndPoint = endPointSnapshot; - _localEndPoint = null; // clear cached value, if any, after _rightEndPoint change + RightEndPoint = endPointSnapshot; } SocketError socketError; @@ -4047,8 +4043,7 @@ public bool SendToAsync(SocketAsyncEventArgs e) } catch { - _rightEndPoint = null; - _localEndPoint = null; // clear cached value, if any, after _rightEndPoint change + RightEndPoint = null; // Clear in-use flag on event args object. e.Complete(); throw; @@ -4056,7 +4051,7 @@ public bool SendToAsync(SocketAsyncEventArgs e) if (!CheckErrorAndUpdateStatus(socketError)) { - _rightEndPoint = oldEndPoint; + RightEndPoint = oldEndPoint; _localEndPoint = oldLocalEndPoint; } @@ -4166,11 +4161,10 @@ private void DoConnect(EndPoint endPointSnapshot, Internals.SocketAddress socket throw socketException; } - if (_rightEndPoint == null) + if (RightEndPoint == null) { // Save a copy of the EndPoint so we can use it for Create(). - _rightEndPoint = endPointSnapshot; - _localEndPoint = null; // clear cached value, if any, after _rightEndPoint change + RightEndPoint = endPointSnapshot; } if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(this, $"connection to:{endPointSnapshot}"); @@ -4333,7 +4327,7 @@ internal void SetReceivingPacketInformation() { // DualMode: When bound to IPv6Any you must enable both socket options. // When bound to an IPv4 mapped IPv6 address you must enable the IPv4 socket option. - IPEndPoint? ipEndPoint = _rightEndPoint as IPEndPoint; + IPEndPoint? ipEndPoint = RightEndPoint as IPEndPoint; IPAddress? boundAddress = (ipEndPoint != null ? ipEndPoint.Address : null); Debug.Assert(boundAddress != null, "Not Bound"); if (_addressFamily == AddressFamily.InterNetwork) @@ -4548,12 +4542,11 @@ private IAsyncResult BeginConnectEx(EndPoint remoteEP, bool flowContext, AsyncCa asyncResult.StartPostingAsyncOp(false); } - EndPoint? oldEndPoint = _rightEndPoint; + EndPoint? oldEndPoint = RightEndPoint; EndPoint? oldLocalEndPoint = _localEndPoint; - if (_rightEndPoint == null) + if (RightEndPoint == null) { - _rightEndPoint = endPointSnapshot; - _localEndPoint = null; // clear cached value, if any, after _rightEndPoint change + RightEndPoint = endPointSnapshot; } SocketError errorCode; @@ -4563,8 +4556,8 @@ private IAsyncResult BeginConnectEx(EndPoint remoteEP, bool flowContext, AsyncCa } catch { - // _rightEndPoint will always equal oldEndPoint. - _rightEndPoint = oldEndPoint; + // RightEndPoint will always equal oldEndPoint. + RightEndPoint = oldEndPoint; _localEndPoint = oldLocalEndPoint; throw; } @@ -4581,7 +4574,7 @@ private IAsyncResult BeginConnectEx(EndPoint remoteEP, bool flowContext, AsyncCa { UpdateConnectSocketErrorForDisposed(ref errorCode); // Update the internal state of this socket according to the error before throwing. - _rightEndPoint = oldEndPoint; + RightEndPoint = oldEndPoint; _localEndPoint = oldLocalEndPoint; throw new SocketException((int)errorCode); @@ -4803,8 +4796,7 @@ internal Socket UpdateAcceptSocket(Socket socket, EndPoint remoteEP) socket._addressFamily = _addressFamily; socket._socketType = _socketType; socket._protocolType = _protocolType; - socket._rightEndPoint = _rightEndPoint; - socket._localEndPoint = null; + socket.RightEndPoint = RightEndPoint; socket._remoteEndPoint = remoteEP; // The socket is connected. diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Unix.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Unix.cs index 09d688d9e0c70..f4e7c1c9381b2 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Unix.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Unix.cs @@ -351,7 +351,7 @@ private SocketError FinishOperationAccept(Internals.SocketAddress remoteSocketAd System.Buffer.BlockCopy(_acceptBuffer!, 0, remoteSocketAddress.Buffer, 0, _acceptAddressBufferCount); _acceptSocket = _currentSocket!.CreateAcceptSocket( SocketPal.CreateSocket(_acceptedFileDescriptor), - _currentSocket._rightEndPoint!.Create(remoteSocketAddress)); + _currentSocket.RightEndPoint!.Create(remoteSocketAddress)); return SocketError.Success; } diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.cs index 118371df52f16..6ad1e78432078 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.cs @@ -523,7 +523,7 @@ internal void StartOperationAccept() // AcceptEx needs a single buffer that's the size of two native sockaddr buffers with 16 // extra bytes each. It can also take additional buffer space in front of those special // sockaddr structures that can be filled in with initial data coming in on a connection. - _acceptAddressBufferCount = 2 * (Socket.GetAddressSize(_currentSocket!._rightEndPoint!) + 16); + _acceptAddressBufferCount = 2 * (Socket.GetAddressSize(_currentSocket!.RightEndPoint!) + 16); // If our caller specified a buffer (willing to get received data with the Accept) then // it needs to be large enough for the two special sockaddr buffers that AcceptEx requires. @@ -681,13 +681,13 @@ internal void FinishOperationSyncSuccess(int bytesTransferred, SocketFlags flags { case SocketAsyncOperation.Accept: // Get the endpoint. - Internals.SocketAddress remoteSocketAddress = IPEndPointExtensions.Serialize(_currentSocket!._rightEndPoint!); + Internals.SocketAddress remoteSocketAddress = IPEndPointExtensions.Serialize(_currentSocket!.RightEndPoint!); socketError = FinishOperationAccept(remoteSocketAddress); if (socketError == SocketError.Success) { - _acceptSocket = _currentSocket.UpdateAcceptSocket(_acceptSocket!, _currentSocket._rightEndPoint!.Create(remoteSocketAddress)); + _acceptSocket = _currentSocket.UpdateAcceptSocket(_acceptSocket!, _currentSocket.RightEndPoint!.Create(remoteSocketAddress)); if (NetEventSource.Log.IsEnabled()) { From f0a5ec96d5f76095f8badf6baffaaf16a2c7d31a Mon Sep 17 00:00:00 2001 From: Natalia Kondratyeva Date: Thu, 23 Jul 2020 18:11:03 +0300 Subject: [PATCH 03/16] Roll back property changes --- .../AcceptOverlappedAsyncResult.Unix.cs | 6 +- .../AcceptOverlappedAsyncResult.Windows.cs | 4 +- .../src/System/Net/Sockets/Socket.Windows.cs | 6 +- .../src/System/Net/Sockets/Socket.cs | 183 ++++++++---------- .../Net/Sockets/SocketAsyncEventArgs.Unix.cs | 2 +- .../Net/Sockets/SocketAsyncEventArgs.cs | 6 +- 6 files changed, 91 insertions(+), 116 deletions(-) diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/AcceptOverlappedAsyncResult.Unix.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/AcceptOverlappedAsyncResult.Unix.cs index 77c09deb8f102..657544b27c319 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/AcceptOverlappedAsyncResult.Unix.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/AcceptOverlappedAsyncResult.Unix.cs @@ -32,14 +32,14 @@ public void CompletionCallback(IntPtr acceptedFileDescriptor, byte[] socketAddre if (errorCode == SocketError.Success) { - Debug.Assert(_listenSocket.RightEndPoint != null); + Debug.Assert(_listenSocket._rightEndPoint != null); - Internals.SocketAddress remoteSocketAddress = IPEndPointExtensions.Serialize(_listenSocket.RightEndPoint); + Internals.SocketAddress remoteSocketAddress = IPEndPointExtensions.Serialize(_listenSocket._rightEndPoint); System.Buffer.BlockCopy(socketAddress, 0, remoteSocketAddress.Buffer, 0, socketAddressLen); _acceptedSocket = _listenSocket.CreateAcceptSocket( SocketPal.CreateSocket(acceptedFileDescriptor), - _listenSocket.RightEndPoint.Create(remoteSocketAddress)); + _listenSocket._rightEndPoint.Create(remoteSocketAddress)); } base.CompletionCallback(0, errorCode); diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/AcceptOverlappedAsyncResult.Windows.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/AcceptOverlappedAsyncResult.Windows.cs index b3d009474d254..d85b2bfc8c4af 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/AcceptOverlappedAsyncResult.Windows.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/AcceptOverlappedAsyncResult.Windows.cs @@ -25,7 +25,7 @@ internal sealed partial class AcceptOverlappedAsyncResult : BaseOverlappedAsyncR if (NetEventSource.Log.IsEnabled()) LogBuffer(numBytes); // get the endpoint - remoteSocketAddress = IPEndPointExtensions.Serialize(_listenSocket.RightEndPoint!); + remoteSocketAddress = IPEndPointExtensions.Serialize(_listenSocket._rightEndPoint!); IntPtr localAddr; int localAddrLength; @@ -86,7 +86,7 @@ internal sealed partial class AcceptOverlappedAsyncResult : BaseOverlappedAsyncR return null; } - return _listenSocket.UpdateAcceptSocket(_acceptSocket!, _listenSocket.RightEndPoint!.Create(remoteSocketAddress!)); + return _listenSocket.UpdateAcceptSocket(_acceptSocket!, _listenSocket._rightEndPoint!.Create(remoteSocketAddress!)); } // SetUnmanagedStructures diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Windows.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Windows.cs index 28396fc6ba356..d9c64a5471278 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Windows.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Windows.cs @@ -66,7 +66,7 @@ public Socket(SocketInformation socketInformation) if (errorCode == SocketError.Success) { - RightEndPoint = ep.Create(socketAddress); + _rightEndPoint = ep.Create(socketAddress); } else if (errorCode == SocketError.InvalidArgument) { @@ -197,7 +197,7 @@ internal bool DisconnectExBlocking(SafeSocketHandle socketHandle, IntPtr overlap partial void WildcardBindForConnectIfNecessary(AddressFamily addressFamily) { - if (RightEndPoint != null) + if (_rightEndPoint != null) { return; } @@ -352,7 +352,7 @@ private Socket GetOrCreateAcceptSocket(Socket? acceptSocket, bool checkDisconnec { acceptSocket = new Socket(_addressFamily, _socketType, _protocolType); } - else if (acceptSocket.RightEndPoint != null && (!checkDisconnected || !acceptSocket._isDisconnected)) + else if (acceptSocket._rightEndPoint != null && (!checkDisconnected || !acceptSocket._isDisconnected)) { throw new InvalidOperationException(SR.Format(SR.net_sockets_namedmustnotbebound, propertyName)); } diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs index 8e0140ffee499..fae55d3061fb5 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs @@ -24,19 +24,10 @@ public partial class Socket : IDisposable private SafeSocketHandle _handle; - // RightEndPoint is null if the socket has not been bound. Otherwise, it is any EndPoint of the + // _rightEndPoint is null if the socket has not been bound. Otherwise, it is any EndPoint of the // correct type (IPEndPoint, etc). - private EndPoint? _rightEndPoint; - internal EndPoint? RightEndPoint - { - get { return _rightEndPoint; } - private set - { - _rightEndPoint = value; - _localEndPoint = null; - } - } - private EndPoint? _localEndPoint; // Cached LocalEndPoint value. Should be cleared on any RightEndPoint change + internal EndPoint? _rightEndPoint; + private EndPoint? _localEndPoint; // Cached LocalEndPoint value internal EndPoint? _remoteEndPoint; // These flags monitor if the socket was ever connected at any time and if it still is. @@ -55,8 +46,8 @@ private set private bool _nonBlockingConnectInProgress; // Keep track of the kind of endpoint used to do a non-blocking connect, so we can set - // it to RightEndPoint when we discover we're connected. - private EndPoint? _nonBlockingConnectRightEndPoint; + // it to _rightEndPoint when we discover we're connected. + private EndPoint? _nonBlockingConnect_rightEndPoint; // These are constants initialized by constructor. private AddressFamily _addressFamily; @@ -172,7 +163,7 @@ private unsafe Socket(SafeSocketHandle handle, bool loadPropertiesFromHandle) switch (_addressFamily) { case AddressFamily.InterNetwork: - RightEndPoint = new IPEndPoint( + _rightEndPoint = new IPEndPoint( new IPAddress((long)SocketAddressPal.GetIPv4Address(buffer.Slice(0, bufferLength)) & 0x0FFFFFFFF), SocketAddressPal.GetPort(buffer)); break; @@ -180,20 +171,20 @@ private unsafe Socket(SafeSocketHandle handle, bool loadPropertiesFromHandle) case AddressFamily.InterNetworkV6: Span address = stackalloc byte[IPAddressParserStatics.IPv6AddressBytes]; SocketAddressPal.GetIPv6Address(buffer.Slice(0, bufferLength), address, out uint scope); - RightEndPoint = new IPEndPoint( + _rightEndPoint = new IPEndPoint( new IPAddress(address, scope), SocketAddressPal.GetPort(buffer)); break; case AddressFamily.Unix: socketAddress = new Internals.SocketAddress(_addressFamily, buffer.Slice(0, bufferLength)); - RightEndPoint = new UnixDomainSocketEndPoint(IPEndPointExtensions.GetNetSocketAddress(socketAddress)); + _rightEndPoint = new UnixDomainSocketEndPoint(IPEndPointExtensions.GetNetSocketAddress(socketAddress)); break; } // Try to determine if we're connected, based on querying for a peer, just as we would in RemoteEndPoint, // but ignoring any failures; this is best-effort (RemoteEndPoint also does a catch-all around the Create call). - if (RightEndPoint != null) + if (_rightEndPoint != null) { try { @@ -326,18 +317,18 @@ public EndPoint? LocalEndPoint { // Update the state if we've become connected after a non-blocking connect. _isConnected = true; - RightEndPoint = _nonBlockingConnectRightEndPoint; + _rightEndPoint = _nonBlockingConnect_rightEndPoint; _nonBlockingConnectInProgress = false; } - if (_localEndPoint == null) + if (_rightEndPoint == null) { - if (RightEndPoint == null) - { - return null; - } + return null; + } - Internals.SocketAddress socketAddress = IPEndPointExtensions.Serialize(RightEndPoint); + if (_localEndPoint == null) + { + Internals.SocketAddress socketAddress = IPEndPointExtensions.Serialize(_rightEndPoint); unsafe { @@ -352,7 +343,7 @@ public EndPoint? LocalEndPoint } } } - _localEndPoint = RightEndPoint.Create(socketAddress); + _localEndPoint = _rightEndPoint.Create(socketAddress); } return _localEndPoint; @@ -372,19 +363,19 @@ public EndPoint? RemoteEndPoint { // Update the state if we've become connected after a non-blocking connect. _isConnected = true; - RightEndPoint = _nonBlockingConnectRightEndPoint; + _rightEndPoint = _nonBlockingConnect_rightEndPoint; _nonBlockingConnectInProgress = false; } - if (RightEndPoint == null || !_isConnected) + if (_rightEndPoint == null || !_isConnected) { return null; } Internals.SocketAddress socketAddress = _addressFamily == AddressFamily.InterNetwork || _addressFamily == AddressFamily.InterNetworkV6 ? - IPEndPointExtensions.Serialize(RightEndPoint) : - new Internals.SocketAddress(_addressFamily, SocketPal.MaximumAddressSize); // may be different size than RightEndPoint. + IPEndPointExtensions.Serialize(_rightEndPoint) : + new Internals.SocketAddress(_addressFamily, SocketPal.MaximumAddressSize); // may be different size than _rightEndPoint. // This may throw ObjectDisposedException. SocketError errorCode = SocketPal.GetPeerName( @@ -399,7 +390,7 @@ public EndPoint? RemoteEndPoint try { - _remoteEndPoint = RightEndPoint.Create(socketAddress); + _remoteEndPoint = _rightEndPoint.Create(socketAddress); } catch { @@ -483,7 +474,7 @@ public bool Connected { // Update the state if we've become connected after a non-blocking connect. _isConnected = true; - RightEndPoint = _nonBlockingConnectRightEndPoint; + _rightEndPoint = _nonBlockingConnect_rightEndPoint; _nonBlockingConnectInProgress = false; } @@ -522,7 +513,7 @@ public bool IsBound { get { - return (RightEndPoint != null); + return (_rightEndPoint != null); } } @@ -842,10 +833,10 @@ private void DoBind(EndPoint endPointSnapshot, Internals.SocketAddress socketAdd UpdateStatusAfterSocketErrorAndThrowException(errorCode); } - if (RightEndPoint == null) + if (_rightEndPoint == null) { // Save a copy of the EndPoint so we can use it for Create(). - RightEndPoint = endPointSnapshot; + _rightEndPoint = endPointSnapshot; } } @@ -899,7 +890,7 @@ public void Connect(EndPoint remoteEP) if (!Blocking) { - _nonBlockingConnectRightEndPoint = remoteEP; + _nonBlockingConnect_rightEndPoint = remoteEP; _nonBlockingConnectInProgress = true; } @@ -1078,7 +1069,7 @@ public Socket Accept() ThrowIfDisposed(); - if (RightEndPoint == null) + if (_rightEndPoint == null) { throw new InvalidOperationException(SR.net_sockets_mustbind); } @@ -1098,7 +1089,7 @@ public Socket Accept() Internals.SocketAddress socketAddress = _addressFamily == AddressFamily.InterNetwork || _addressFamily == AddressFamily.InterNetworkV6 ? - IPEndPointExtensions.Serialize(RightEndPoint) : + IPEndPointExtensions.Serialize(_rightEndPoint) : new Internals.SocketAddress(_addressFamily, SocketPal.MaximumAddressSize); // may be different size. // This may throw ObjectDisposedException. @@ -1119,7 +1110,7 @@ public Socket Accept() Debug.Assert(!acceptedSocketHandle.IsInvalid); - Socket socket = CreateAcceptSocket(acceptedSocketHandle, RightEndPoint.Create(socketAddress)); + Socket socket = CreateAcceptSocket(acceptedSocketHandle, _rightEndPoint.Create(socketAddress)); if (NetEventSource.Log.IsEnabled()) NetEventSource.Accepted(socket, socket.RemoteEndPoint!, socket.LocalEndPoint); return socket; } @@ -1337,10 +1328,10 @@ public int SendTo(byte[] buffer, int offset, int size, SocketFlags socketFlags, UpdateStatusAfterSocketErrorAndThrowException(errorCode); } - if (RightEndPoint == null) + if (_rightEndPoint == null) { // Save a copy of the EndPoint so we can use it for Create(). - RightEndPoint = remoteEP; + _rightEndPoint = remoteEP; } if (NetEventSource.Log.IsEnabled()) NetEventSource.DumpBuffer(this, buffer, offset, size); @@ -1534,7 +1525,7 @@ public int ReceiveMessageFrom(byte[] buffer, int offset, int size, ref SocketFla { throw new ArgumentOutOfRangeException(nameof(size)); } - if (RightEndPoint == null) + if (_rightEndPoint == null) { throw new InvalidOperationException(SR.net_sockets_mustbind); } @@ -1574,10 +1565,10 @@ public int ReceiveMessageFrom(byte[] buffer, int offset, int size, ref SocketFla catch { } - if (RightEndPoint == null) + if (_rightEndPoint == null) { // Save a copy of the EndPoint so we can use it for Create(). - RightEndPoint = endPointSnapshot; + _rightEndPoint = endPointSnapshot; } } @@ -1613,7 +1604,7 @@ public int ReceiveFrom(byte[] buffer, int offset, int size, SocketFlags socketFl { throw new ArgumentOutOfRangeException(nameof(size)); } - if (RightEndPoint == null) + if (_rightEndPoint == null) { throw new InvalidOperationException(SR.net_sockets_mustbind); } @@ -1658,10 +1649,10 @@ public int ReceiveFrom(byte[] buffer, int offset, int size, SocketFlags socketFl catch { } - if (RightEndPoint == null) + if (_rightEndPoint == null) { // Save a copy of the EndPoint so we can use it for Create(). - RightEndPoint = endPointSnapshot; + _rightEndPoint = endPointSnapshot; } } @@ -2082,7 +2073,7 @@ private bool CanUseConnectEx(EndPoint remoteEP) // Unix sockets are not supported by ConnectEx. return (_socketType == SocketType.Stream) && - (RightEndPoint != null || remoteEP.GetType() == typeof(IPEndPoint)) && + (_rightEndPoint != null || remoteEP.GetType() == typeof(IPEndPoint)) && (remoteEP.AddressFamily != AddressFamily.Unix); } @@ -2696,17 +2687,16 @@ private void DoBeginSendTo(byte[] buffer, int offset, int size, SocketFlags sock { if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(this, $"size:{size}"); - EndPoint? oldEndPoint = RightEndPoint; - EndPoint? oldLocalEndPoint = _localEndPoint; + EndPoint? oldEndPoint = _rightEndPoint; // Guarantee to call CheckAsyncCallOverlappedResult if we call SetUnamangedStructures with a cache in order to // avoid a Socket leak in case of error. SocketError errorCode = SocketError.SocketError; try { - if (RightEndPoint == null) + if (_rightEndPoint == null) { - RightEndPoint = endPointSnapshot; + _rightEndPoint = endPointSnapshot; } errorCode = SocketPal.SendToAsync(_handle, buffer, offset, size, socketFlags, socketAddress, asyncResult); @@ -2715,8 +2705,7 @@ private void DoBeginSendTo(byte[] buffer, int offset, int size, SocketFlags sock } catch (ObjectDisposedException) { - RightEndPoint = oldEndPoint; - _localEndPoint = oldLocalEndPoint; + _rightEndPoint = oldEndPoint; throw; } @@ -2725,8 +2714,7 @@ private void DoBeginSendTo(byte[] buffer, int offset, int size, SocketFlags sock { UpdateSendSocketErrorForDisposed(ref errorCode); // Update the internal state of this socket according to the error before throwing. - RightEndPoint = oldEndPoint; - _localEndPoint = oldLocalEndPoint; + _rightEndPoint = oldEndPoint; throw new SocketException((int)errorCode); } @@ -3042,7 +3030,7 @@ public IAsyncResult BeginReceiveMessageFrom(byte[] buffer, int offset, int size, { throw new ArgumentOutOfRangeException(nameof(size)); } - if (RightEndPoint == null) + if (_rightEndPoint == null) { throw new InvalidOperationException(SR.net_sockets_mustbind); } @@ -3054,8 +3042,7 @@ public IAsyncResult BeginReceiveMessageFrom(byte[] buffer, int offset, int size, asyncResult.StartPostingAsyncOp(false); // Start the ReceiveFrom. - EndPoint oldEndPoint = RightEndPoint; - EndPoint? oldLocalEndPoint = _localEndPoint; + EndPoint oldEndPoint = _rightEndPoint; // We don't do a CAS demand here because the contents of remoteEP aren't used by // WSARecvMsg; all that matters is that we generate a unique-to-this-call SocketAddress @@ -3072,9 +3059,9 @@ public IAsyncResult BeginReceiveMessageFrom(byte[] buffer, int offset, int size, SetReceivingPacketInformation(); - if (RightEndPoint == null) + if (_rightEndPoint == null) { - RightEndPoint = remoteEP; + _rightEndPoint = remoteEP; } errorCode = SocketPal.ReceiveMessageFromAsync(this, _handle, buffer, offset, size, socketFlags, socketAddress, asyncResult); @@ -3096,8 +3083,7 @@ public IAsyncResult BeginReceiveMessageFrom(byte[] buffer, int offset, int size, } catch (ObjectDisposedException) { - RightEndPoint = oldEndPoint; - _localEndPoint = oldLocalEndPoint; + _rightEndPoint = oldEndPoint; throw; } @@ -3106,8 +3092,7 @@ public IAsyncResult BeginReceiveMessageFrom(byte[] buffer, int offset, int size, if (!CheckErrorAndUpdateStatus(errorCode)) { // Update the internal state of this socket according to the error before throwing. - RightEndPoint = oldEndPoint; - _localEndPoint = oldLocalEndPoint; + _rightEndPoint = oldEndPoint; throw new SocketException((int)errorCode); } @@ -3241,7 +3226,7 @@ public IAsyncResult BeginReceiveFrom(byte[] buffer, int offset, int size, Socket { throw new ArgumentOutOfRangeException(nameof(size)); } - if (RightEndPoint == null) + if (_rightEndPoint == null) { throw new InvalidOperationException(SR.net_sockets_mustbind); } @@ -3279,8 +3264,7 @@ public IAsyncResult BeginReceiveFrom(byte[] buffer, int offset, int size, Socket private void DoBeginReceiveFrom(byte[] buffer, int offset, int size, SocketFlags socketFlags, EndPoint endPointSnapshot, Internals.SocketAddress socketAddress, OriginalAddressOverlappedAsyncResult asyncResult) { - EndPoint? oldEndPoint = RightEndPoint; - EndPoint? oldLocalEndPoint = _localEndPoint; + EndPoint? oldEndPoint = _rightEndPoint; if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(this, $"size:{size}"); @@ -3292,9 +3276,9 @@ private void DoBeginReceiveFrom(byte[] buffer, int offset, int size, SocketFlags // Save a copy of the original EndPoint in the asyncResult. asyncResult.SocketAddressOriginal = IPEndPointExtensions.Serialize(endPointSnapshot); - if (RightEndPoint == null) + if (_rightEndPoint == null) { - RightEndPoint = endPointSnapshot; + _rightEndPoint = endPointSnapshot; } errorCode = SocketPal.ReceiveFromAsync(_handle, buffer, offset, size, socketFlags, socketAddress, asyncResult); @@ -3303,8 +3287,7 @@ private void DoBeginReceiveFrom(byte[] buffer, int offset, int size, SocketFlags } catch (ObjectDisposedException) { - RightEndPoint = oldEndPoint; - _localEndPoint = oldLocalEndPoint; + _rightEndPoint = oldEndPoint; throw; } @@ -3313,8 +3296,7 @@ private void DoBeginReceiveFrom(byte[] buffer, int offset, int size, SocketFlags if (!CheckErrorAndUpdateStatus(errorCode)) { // Update the internal state of this socket according to the error before throwing. - RightEndPoint = oldEndPoint; - _localEndPoint = oldLocalEndPoint; + _rightEndPoint = oldEndPoint; throw new SocketException((int)errorCode); } @@ -3455,7 +3437,7 @@ public IAsyncResult BeginAccept(Socket? acceptSocket, int receiveSize, AsyncCall private void DoBeginAccept(Socket? acceptSocket, int receiveSize, AcceptOverlappedAsyncResult asyncResult) { - if (RightEndPoint == null) + if (_rightEndPoint == null) { throw new InvalidOperationException(SR.net_sockets_mustbind); } @@ -3470,7 +3452,7 @@ private void DoBeginAccept(Socket? acceptSocket, int receiveSize, AcceptOverlapp if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(this, $"AcceptSocket:{acceptSocket}"); - int socketAddressSize = GetAddressSize(RightEndPoint); + int socketAddressSize = GetAddressSize(_rightEndPoint); SocketError errorCode = SocketPal.AcceptAsync(this, _handle, acceptHandle, receiveSize, socketAddressSize, asyncResult); if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(this, $"AcceptAsync returns:{errorCode} {asyncResult}"); @@ -3587,7 +3569,7 @@ public bool AcceptAsync(SocketAsyncEventArgs e) { throw new ArgumentException(SR.net_multibuffernotsupported, nameof(e)); } - if (RightEndPoint == null) + if (_rightEndPoint == null) { throw new InvalidOperationException(SR.net_sockets_mustbind); } @@ -3693,12 +3675,11 @@ private bool ConnectAsync(SocketAsyncEventArgs e, bool userSocket) WildcardBindForConnectIfNecessary(endPointSnapshot.AddressFamily); - // Save the old RightEndPoint and prep new RightEndPoint. - EndPoint? oldEndPoint = RightEndPoint; - EndPoint? oldLocalEndPoint = _localEndPoint; - if (RightEndPoint == null) + // Save the old _rightEndPoint and prep new _rightEndPoint. + EndPoint? oldEndPoint = _rightEndPoint; + if (_rightEndPoint == null) { - RightEndPoint = endPointSnapshot; + _rightEndPoint = endPointSnapshot; } // Prepare for the native call. @@ -3721,8 +3702,7 @@ private bool ConnectAsync(SocketAsyncEventArgs e, bool userSocket) } catch { - RightEndPoint = oldEndPoint; - _localEndPoint = oldLocalEndPoint; + _rightEndPoint = oldEndPoint; // Clear in-use flag on event args object. e.Complete(); @@ -4038,11 +4018,10 @@ public bool SendToAsync(SocketAsyncEventArgs e) // Prepare for and make the native call. e.StartOperationCommon(this, SocketAsyncOperation.SendTo); - EndPoint? oldEndPoint = RightEndPoint; - EndPoint? oldLocalEndPoint = _localEndPoint; - if (RightEndPoint == null) + EndPoint? oldEndPoint = _rightEndPoint; + if (_rightEndPoint == null) { - RightEndPoint = endPointSnapshot; + _rightEndPoint = endPointSnapshot; } SocketError socketError; @@ -4052,7 +4031,7 @@ public bool SendToAsync(SocketAsyncEventArgs e) } catch { - RightEndPoint = null; + _rightEndPoint = null; // Clear in-use flag on event args object. e.Complete(); throw; @@ -4060,8 +4039,7 @@ public bool SendToAsync(SocketAsyncEventArgs e) if (!CheckErrorAndUpdateStatus(socketError)) { - RightEndPoint = oldEndPoint; - _localEndPoint = oldLocalEndPoint; + _rightEndPoint = oldEndPoint; } return socketError == SocketError.IOPending; @@ -4176,10 +4154,10 @@ private void DoConnect(EndPoint endPointSnapshot, Internals.SocketAddress socket if (SocketsTelemetry.Log.IsEnabled()) SocketsTelemetry.Log.ConnectStop(); - if (RightEndPoint == null) + if (_rightEndPoint == null) { // Save a copy of the EndPoint so we can use it for Create(). - RightEndPoint = endPointSnapshot; + _rightEndPoint = endPointSnapshot; } if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(this, $"connection to:{endPointSnapshot}"); @@ -4342,7 +4320,7 @@ internal void SetReceivingPacketInformation() { // DualMode: When bound to IPv6Any you must enable both socket options. // When bound to an IPv4 mapped IPv6 address you must enable the IPv4 socket option. - IPEndPoint? ipEndPoint = RightEndPoint as IPEndPoint; + IPEndPoint? ipEndPoint = _rightEndPoint as IPEndPoint; IPAddress? boundAddress = (ipEndPoint != null ? ipEndPoint.Address : null); Debug.Assert(boundAddress != null, "Not Bound"); if (_addressFamily == AddressFamily.InterNetwork) @@ -4557,11 +4535,10 @@ private IAsyncResult BeginConnectEx(EndPoint remoteEP, bool flowContext, AsyncCa asyncResult.StartPostingAsyncOp(false); } - EndPoint? oldEndPoint = RightEndPoint; - EndPoint? oldLocalEndPoint = _localEndPoint; - if (RightEndPoint == null) + EndPoint? oldEndPoint = _rightEndPoint; + if (_rightEndPoint == null) { - RightEndPoint = endPointSnapshot; + _rightEndPoint = endPointSnapshot; } SocketError errorCode; @@ -4571,9 +4548,8 @@ private IAsyncResult BeginConnectEx(EndPoint remoteEP, bool flowContext, AsyncCa } catch { - // RightEndPoint will always equal oldEndPoint. - RightEndPoint = oldEndPoint; - _localEndPoint = oldLocalEndPoint; + // _rightEndPoint will always equal oldEndPoint. + _rightEndPoint = oldEndPoint; throw; } @@ -4589,8 +4565,7 @@ private IAsyncResult BeginConnectEx(EndPoint remoteEP, bool flowContext, AsyncCa { UpdateConnectSocketErrorForDisposed(ref errorCode); // Update the internal state of this socket according to the error before throwing. - RightEndPoint = oldEndPoint; - _localEndPoint = oldLocalEndPoint; + _rightEndPoint = oldEndPoint; throw new SocketException((int)errorCode); } @@ -4811,7 +4786,7 @@ internal Socket UpdateAcceptSocket(Socket socket, EndPoint remoteEP) socket._addressFamily = _addressFamily; socket._socketType = _socketType; socket._protocolType = _protocolType; - socket.RightEndPoint = RightEndPoint; + socket._rightEndPoint = _rightEndPoint; socket._remoteEndPoint = remoteEP; // The socket is connected. diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Unix.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Unix.cs index f4e7c1c9381b2..09d688d9e0c70 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Unix.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Unix.cs @@ -351,7 +351,7 @@ private SocketError FinishOperationAccept(Internals.SocketAddress remoteSocketAd System.Buffer.BlockCopy(_acceptBuffer!, 0, remoteSocketAddress.Buffer, 0, _acceptAddressBufferCount); _acceptSocket = _currentSocket!.CreateAcceptSocket( SocketPal.CreateSocket(_acceptedFileDescriptor), - _currentSocket.RightEndPoint!.Create(remoteSocketAddress)); + _currentSocket._rightEndPoint!.Create(remoteSocketAddress)); return SocketError.Success; } diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.cs index 46efa4fedbdfd..213cb30909b99 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.cs @@ -524,7 +524,7 @@ internal void StartOperationAccept() // AcceptEx needs a single buffer that's the size of two native sockaddr buffers with 16 // extra bytes each. It can also take additional buffer space in front of those special // sockaddr structures that can be filled in with initial data coming in on a connection. - _acceptAddressBufferCount = 2 * (Socket.GetAddressSize(_currentSocket!.RightEndPoint!) + 16); + _acceptAddressBufferCount = 2 * (Socket.GetAddressSize(_currentSocket!._rightEndPoint!) + 16); // If our caller specified a buffer (willing to get received data with the Accept) then // it needs to be large enough for the two special sockaddr buffers that AcceptEx requires. @@ -692,13 +692,13 @@ internal void FinishOperationSyncSuccess(int bytesTransferred, SocketFlags flags { case SocketAsyncOperation.Accept: // Get the endpoint. - Internals.SocketAddress remoteSocketAddress = IPEndPointExtensions.Serialize(_currentSocket!.RightEndPoint!); + Internals.SocketAddress remoteSocketAddress = IPEndPointExtensions.Serialize(_currentSocket!._rightEndPoint!); socketError = FinishOperationAccept(remoteSocketAddress); if (socketError == SocketError.Success) { - _acceptSocket = _currentSocket.UpdateAcceptSocket(_acceptSocket!, _currentSocket.RightEndPoint!.Create(remoteSocketAddress)); + _acceptSocket = _currentSocket.UpdateAcceptSocket(_acceptSocket!, _currentSocket._rightEndPoint!.Create(remoteSocketAddress)); if (NetEventSource.Log.IsEnabled()) { From 999d59690b78df209eed1445ec5a91d0b64ceb3b Mon Sep 17 00:00:00 2001 From: Natalia Kondratyeva Date: Thu, 23 Jul 2020 18:18:31 +0300 Subject: [PATCH 04/16] Fix _nonBlockingConnectRightEndPoint --- .../src/System/Net/Sockets/Socket.cs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs index fae55d3061fb5..60305b4e4723b 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs @@ -47,7 +47,7 @@ public partial class Socket : IDisposable // Keep track of the kind of endpoint used to do a non-blocking connect, so we can set // it to _rightEndPoint when we discover we're connected. - private EndPoint? _nonBlockingConnect_rightEndPoint; + private EndPoint? _nonBlockingConnectRightEndPoint; // These are constants initialized by constructor. private AddressFamily _addressFamily; @@ -317,7 +317,7 @@ public EndPoint? LocalEndPoint { // Update the state if we've become connected after a non-blocking connect. _isConnected = true; - _rightEndPoint = _nonBlockingConnect_rightEndPoint; + _rightEndPoint = _nonBlockingConnectRightEndPoint; _nonBlockingConnectInProgress = false; } @@ -363,7 +363,7 @@ public EndPoint? RemoteEndPoint { // Update the state if we've become connected after a non-blocking connect. _isConnected = true; - _rightEndPoint = _nonBlockingConnect_rightEndPoint; + _rightEndPoint = _nonBlockingConnectRightEndPoint; _nonBlockingConnectInProgress = false; } @@ -474,7 +474,7 @@ public bool Connected { // Update the state if we've become connected after a non-blocking connect. _isConnected = true; - _rightEndPoint = _nonBlockingConnect_rightEndPoint; + _rightEndPoint = _nonBlockingConnectRightEndPoint; _nonBlockingConnectInProgress = false; } @@ -890,7 +890,7 @@ public void Connect(EndPoint remoteEP) if (!Blocking) { - _nonBlockingConnect_rightEndPoint = remoteEP; + _nonBlockingConnectRightEndPoint = remoteEP; _nonBlockingConnectInProgress = true; } @@ -3675,7 +3675,7 @@ private bool ConnectAsync(SocketAsyncEventArgs e, bool userSocket) WildcardBindForConnectIfNecessary(endPointSnapshot.AddressFamily); - // Save the old _rightEndPoint and prep new _rightEndPoint. + // Save the old RightEndPoint and prep new RightEndPoint. EndPoint? oldEndPoint = _rightEndPoint; if (_rightEndPoint == null) { From d715a3114abf354dbb2ba591027db2d76ace27e9 Mon Sep 17 00:00:00 2001 From: Natalia Kondratyeva Date: Tue, 4 Aug 2020 16:09:13 +0300 Subject: [PATCH 05/16] Add clear on error --- .../src/System/Net/Sockets/Socket.cs | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs index 89881379f5df5..b79ec393773aa 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs @@ -27,7 +27,7 @@ public partial class Socket : IDisposable // _rightEndPoint is null if the socket has not been bound. Otherwise, it is any EndPoint of the // correct type (IPEndPoint, etc). internal EndPoint? _rightEndPoint; - private EndPoint? _localEndPoint; // Cached LocalEndPoint value + private EndPoint? _localEndPoint; // Cached LocalEndPoint value. Will clear on error and on disconnect internal EndPoint? _remoteEndPoint; // These flags monitor if the socket was ever connected at any time and if it still is. @@ -2308,6 +2308,7 @@ private void DoBeginDisconnect(bool reuseSocket, DisconnectOverlappedAsyncResult { SetToDisconnected(); _remoteEndPoint = null; + _localEndPoint = null; } if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(this, $"UnsafeNclNativeMethods.OSSOCK.DisConnectEx returns:{errorCode}"); @@ -2337,6 +2338,7 @@ public void Disconnect(bool reuseSocket) SetToDisconnected(); _remoteEndPoint = null; + _localEndPoint = null; } // Routine Description: @@ -2765,6 +2767,7 @@ private void DoBeginSendTo(byte[] buffer, int offset, int size, SocketFlags sock catch (ObjectDisposedException) { _rightEndPoint = oldEndPoint; + _localEndPoint = null; throw; } @@ -2774,6 +2777,7 @@ private void DoBeginSendTo(byte[] buffer, int offset, int size, SocketFlags sock UpdateSendSocketErrorForDisposed(ref errorCode); // Update the internal state of this socket according to the error before throwing. _rightEndPoint = oldEndPoint; + _localEndPoint = null; throw new SocketException((int)errorCode); } @@ -3153,6 +3157,7 @@ public IAsyncResult BeginReceiveMessageFrom(byte[] buffer, int offset, int size, catch (ObjectDisposedException) { _rightEndPoint = oldEndPoint; + _localEndPoint = null; throw; } @@ -3162,6 +3167,7 @@ public IAsyncResult BeginReceiveMessageFrom(byte[] buffer, int offset, int size, { // Update the internal state of this socket according to the error before throwing. _rightEndPoint = oldEndPoint; + _localEndPoint = null; throw new SocketException((int)errorCode); } @@ -3362,6 +3368,7 @@ private void DoBeginReceiveFrom(byte[] buffer, int offset, int size, SocketFlags catch (ObjectDisposedException) { _rightEndPoint = oldEndPoint; + _localEndPoint = null; throw; } @@ -3371,6 +3378,7 @@ private void DoBeginReceiveFrom(byte[] buffer, int offset, int size, SocketFlags { // Update the internal state of this socket according to the error before throwing. _rightEndPoint = oldEndPoint; + _localEndPoint = null; throw new SocketException((int)errorCode); } @@ -3793,6 +3801,7 @@ private bool ConnectAsync(SocketAsyncEventArgs e, bool userSocket) catch { _rightEndPoint = oldEndPoint; + _localEndPoint = null; // Clear in-use flag on event args object. e.Complete(); @@ -4122,6 +4131,7 @@ public bool SendToAsync(SocketAsyncEventArgs e) catch { _rightEndPoint = null; + _localEndPoint = null; // Clear in-use flag on event args object. e.Complete(); throw; @@ -4130,6 +4140,7 @@ public bool SendToAsync(SocketAsyncEventArgs e) if (!CheckErrorAndUpdateStatus(socketError)) { _rightEndPoint = oldEndPoint; + _localEndPoint = null; } return socketError == SocketError.IOPending; @@ -4640,6 +4651,7 @@ private IAsyncResult BeginConnectEx(EndPoint remoteEP, bool flowContext, AsyncCa { // _rightEndPoint will always equal oldEndPoint. _rightEndPoint = oldEndPoint; + _localEndPoint = null; throw; } @@ -4656,6 +4668,7 @@ private IAsyncResult BeginConnectEx(EndPoint remoteEP, bool flowContext, AsyncCa UpdateConnectSocketErrorForDisposed(ref errorCode); // Update the internal state of this socket according to the error before throwing. _rightEndPoint = oldEndPoint; + _localEndPoint = null; throw new SocketException((int)errorCode); } @@ -4878,6 +4891,7 @@ internal Socket UpdateAcceptSocket(Socket socket, EndPoint remoteEP) socket._protocolType = _protocolType; socket._rightEndPoint = _rightEndPoint; socket._remoteEndPoint = remoteEP; + socket._localEndPoint = null; // The socket is connected. socket.SetToConnected(); From c3eb4ec65e87dc31b3f9f193c521ad6f24bed676 Mon Sep 17 00:00:00 2001 From: Natalia Kondratyeva Date: Wed, 5 Aug 2020 20:32:00 +0300 Subject: [PATCH 06/16] Add clear on connect and tests --- .../src/System/Net/Sockets/Socket.cs | 8 +- .../FunctionalTests/LocalEndPointTest.cs | 140 ++++++++++++++++++ .../System.Net.Sockets.Tests.csproj | 1 + 3 files changed, 148 insertions(+), 1 deletion(-) create mode 100644 src/libraries/System.Net.Sockets/tests/FunctionalTests/LocalEndPointTest.cs diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs index b79ec393773aa..18b88de30759c 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs @@ -27,7 +27,7 @@ public partial class Socket : IDisposable // _rightEndPoint is null if the socket has not been bound. Otherwise, it is any EndPoint of the // correct type (IPEndPoint, etc). internal EndPoint? _rightEndPoint; - private EndPoint? _localEndPoint; // Cached LocalEndPoint value. Will clear on error and on disconnect + private EndPoint? _localEndPoint; // Cached LocalEndPoint value. Will clear on connect, error and disconnect internal EndPoint? _remoteEndPoint; // These flags monitor if the socket was ever connected at any time and if it still is. @@ -216,6 +216,7 @@ private unsafe Socket(SafeSocketHandle handle, bool loadPropertiesFromHandle) } _isConnected = true; + _localEndPoint = null; break; case SocketError.InvalidArgument: @@ -225,6 +226,7 @@ private unsafe Socket(SafeSocketHandle handle, bool loadPropertiesFromHandle) // whether we're actually connected or not, err on the side of saying // we're connected. _isConnected = true; + _localEndPoint = null; break; } } @@ -318,6 +320,7 @@ public EndPoint? LocalEndPoint // Update the state if we've become connected after a non-blocking connect. _isConnected = true; _rightEndPoint = _nonBlockingConnectRightEndPoint; + _localEndPoint = null; _nonBlockingConnectInProgress = false; } @@ -364,6 +367,7 @@ public EndPoint? RemoteEndPoint // Update the state if we've become connected after a non-blocking connect. _isConnected = true; _rightEndPoint = _nonBlockingConnectRightEndPoint; + _localEndPoint = null; _nonBlockingConnectInProgress = false; } @@ -475,6 +479,7 @@ public bool Connected // Update the state if we've become connected after a non-blocking connect. _isConnected = true; _rightEndPoint = _nonBlockingConnectRightEndPoint; + _localEndPoint = null; _nonBlockingConnectInProgress = false; } @@ -4924,6 +4929,7 @@ internal void SetToConnected() // some point in time update the perf counter as well. _isConnected = true; _isDisconnected = false; + _localEndPoint = null; if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(this, "now connected"); } diff --git a/src/libraries/System.Net.Sockets/tests/FunctionalTests/LocalEndPointTest.cs b/src/libraries/System.Net.Sockets/tests/FunctionalTests/LocalEndPointTest.cs new file mode 100644 index 0000000000000..c97808d39b44b --- /dev/null +++ b/src/libraries/System.Net.Sockets/tests/FunctionalTests/LocalEndPointTest.cs @@ -0,0 +1,140 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Threading.Tasks; +using Xunit; + +namespace System.Net.Sockets.Tests +{ + public class LocalEndPointTest + { + [Fact] + public void UdpSocket_BoundToWildcardAddress_LocalEPDoesNotChangeOnSendTo() + { + using (Socket server = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp)) + using (Socket client = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp)) + { + server.Bind(new IPEndPoint(IPAddress.Any, 0)); + + Assert.Null(client.LocalEndPoint); + + client.Bind(new IPEndPoint(IPAddress.Any, 0)); + + IPEndPoint localEPAfterBind = (IPEndPoint) client.LocalEndPoint; + Assert.Equal(IPAddress.Any, localEPAfterBind.Address); + int portAfterBind = localEPAfterBind.Port; + + var sendToEP = new IPEndPoint(IPAddress.Loopback, ((IPEndPoint)server.LocalEndPoint).Port); + + client.SendTo(new byte[] { 1, 2, 3 }, sendToEP); + + Assert.Equal(IPAddress.Any, ((IPEndPoint)client.LocalEndPoint).Address); + Assert.Equal(portAfterBind, ((IPEndPoint)client.LocalEndPoint).Port); + + byte[] buf = new byte[3]; + EndPoint receiveFromEP = new IPEndPoint(IPAddress.Any, 0); + server.ReceiveFrom(buf, ref receiveFromEP); + + Assert.Equal(new byte[] { 1, 2, 3 }, buf); + Assert.Equal(portAfterBind, ((IPEndPoint)receiveFromEP).Port); + + IAsyncResult sendToResult = client.BeginSendTo(new byte[] { 4, 5, 6 }, 0, 3, SocketFlags.None, sendToEP, null, null); + sendToResult.AsyncWaitHandle.WaitOne(); + client.EndSendTo(sendToResult); + + Assert.Equal(IPAddress.Any, ((IPEndPoint)client.LocalEndPoint).Address); + Assert.Equal(portAfterBind, ((IPEndPoint)client.LocalEndPoint).Port); + + buf = new byte[3]; + receiveFromEP = new IPEndPoint(IPAddress.Any, 0); + server.ReceiveFrom(buf, ref receiveFromEP); + + Assert.Equal(new byte[] { 4, 5, 6 }, buf); + Assert.Equal(portAfterBind, ((IPEndPoint)receiveFromEP).Port); + } + } + + [Fact] + public void UdpSocket_NotBound_LocalEPBecomesWildcardAddressOnSendTo() + { + using (Socket server = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp)) + using (Socket client = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp)) + { + server.Bind(new IPEndPoint(IPAddress.Any, 0)); + + Assert.Null(client.LocalEndPoint); + + var sendToEP = new IPEndPoint(IPAddress.Loopback, ((IPEndPoint)server.LocalEndPoint).Port); + + client.SendTo(new byte[] { 1, 2, 3 }, sendToEP); + + Assert.Equal(IPAddress.Any, ((IPEndPoint)client.LocalEndPoint).Address); + + byte[] buf = new byte[3]; + EndPoint receiveFromEP = new IPEndPoint(IPAddress.Any, 0); + server.ReceiveFrom(buf, ref receiveFromEP); + + Assert.Equal(new byte[] { 1, 2, 3 }, buf); + + IAsyncResult sendToResult = client.BeginSendTo(new byte[] { 4, 5, 6 }, 0, 3, SocketFlags.None, sendToEP, null, null); + sendToResult.AsyncWaitHandle.WaitOne(); + client.EndSendTo(sendToResult); + + Assert.Equal(IPAddress.Any, ((IPEndPoint)client.LocalEndPoint).Address); + + buf = new byte[3]; + receiveFromEP = new IPEndPoint(IPAddress.Any, 0); + server.ReceiveFrom(buf, ref receiveFromEP); + + Assert.Equal(new byte[] { 4, 5, 6 }, buf); + } + } + + [Fact] + public async Task TcpSocket_BoundToWildcardAddress_LocalEPChangeToSpecificOnConnnect() + { + using (Socket server = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) + using (Socket client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) + { + server.Bind(new IPEndPoint(IPAddress.Any, 0)); + client.Bind(new IPEndPoint(IPAddress.Any, 0)); + + IPEndPoint localEPAfterBind = (IPEndPoint)client.LocalEndPoint; + Assert.Equal(IPAddress.Any, localEPAfterBind.Address); + int portAfterBind = localEPAfterBind.Port; + + server.Listen(); + Task acceptTask = server.AcceptAsync(); + + client.Connect(new IPEndPoint(IPAddress.Loopback, ((IPEndPoint)server.LocalEndPoint).Port)); + + Assert.Equal(IPAddress.Loopback, ((IPEndPoint)client.LocalEndPoint).Address); + Assert.Equal(portAfterBind, ((IPEndPoint)client.LocalEndPoint).Port); + + Socket accept = await acceptTask; + Assert.Equal(accept.RemoteEndPoint, client.LocalEndPoint); + } + } + + [Fact] + public async Task TcpSocket_NotBound_LocalEPChangeToSpecificOnConnnect() + { + using (Socket server = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) + using (Socket client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) + { + server.Bind(new IPEndPoint(IPAddress.Any, 0)); + server.Listen(); + Task acceptTask = server.AcceptAsync(); + + Assert.Null(client.LocalEndPoint); + + client.Connect(new IPEndPoint(IPAddress.Loopback, ((IPEndPoint)server.LocalEndPoint).Port)); + + Assert.Equal(IPAddress.Loopback, ((IPEndPoint)client.LocalEndPoint).Address); + + Socket accept = await acceptTask; + Assert.Equal(accept.RemoteEndPoint, client.LocalEndPoint); + } + } + } +} diff --git a/src/libraries/System.Net.Sockets/tests/FunctionalTests/System.Net.Sockets.Tests.csproj b/src/libraries/System.Net.Sockets/tests/FunctionalTests/System.Net.Sockets.Tests.csproj index ef710218ea2ad..c31b56acf40ec 100644 --- a/src/libraries/System.Net.Sockets/tests/FunctionalTests/System.Net.Sockets.Tests.csproj +++ b/src/libraries/System.Net.Sockets/tests/FunctionalTests/System.Net.Sockets.Tests.csproj @@ -23,6 +23,7 @@ + From 324e06062959cca5272735d74e2e7b6a53dea90e Mon Sep 17 00:00:00 2001 From: Natalia Kondratyeva Date: Wed, 5 Aug 2020 21:29:16 +0300 Subject: [PATCH 07/16] Add caching test --- .../tests/FunctionalTests/LocalEndPointTest.cs | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/src/libraries/System.Net.Sockets/tests/FunctionalTests/LocalEndPointTest.cs b/src/libraries/System.Net.Sockets/tests/FunctionalTests/LocalEndPointTest.cs index c97808d39b44b..bca436fb4b089 100644 --- a/src/libraries/System.Net.Sockets/tests/FunctionalTests/LocalEndPointTest.cs +++ b/src/libraries/System.Net.Sockets/tests/FunctionalTests/LocalEndPointTest.cs @@ -136,5 +136,19 @@ public async Task TcpSocket_NotBound_LocalEPChangeToSpecificOnConnnect() Assert.Equal(accept.RemoteEndPoint, client.LocalEndPoint); } } + + [Fact] + public void LocalEndPoint_IsCached() + { + using (Socket socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) + { + socket.Bind(new IPEndPoint(IPAddress.Any, 0)); + + EndPoint localEndPointCall1 = socket.LocalEndPoint; + EndPoint localEndPointCall2 = socket.LocalEndPoint; + + Assert.Same(localEndPointCall1, localEndPointCall2); + } + } } } From 283f11b5794e08819c6c60bbe53f5dbe1cada5eb Mon Sep 17 00:00:00 2001 From: Natalia Kondratyeva Date: Tue, 11 Aug 2020 15:57:29 +0300 Subject: [PATCH 08/16] Use BindToAnonymousPort --- .../FunctionalTests/LocalEndPointTest.cs | 40 +++++++++---------- 1 file changed, 18 insertions(+), 22 deletions(-) diff --git a/src/libraries/System.Net.Sockets/tests/FunctionalTests/LocalEndPointTest.cs b/src/libraries/System.Net.Sockets/tests/FunctionalTests/LocalEndPointTest.cs index bca436fb4b089..30c4af46ca723 100644 --- a/src/libraries/System.Net.Sockets/tests/FunctionalTests/LocalEndPointTest.cs +++ b/src/libraries/System.Net.Sockets/tests/FunctionalTests/LocalEndPointTest.cs @@ -14,43 +14,41 @@ public void UdpSocket_BoundToWildcardAddress_LocalEPDoesNotChangeOnSendTo() using (Socket server = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp)) using (Socket client = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp)) { - server.Bind(new IPEndPoint(IPAddress.Any, 0)); + int serverPort = server.BindToAnonymousPort(IPAddress.Any); Assert.Null(client.LocalEndPoint); - client.Bind(new IPEndPoint(IPAddress.Any, 0)); + int clientPortAfterBind = client.BindToAnonymousPort(IPAddress.Any); - IPEndPoint localEPAfterBind = (IPEndPoint) client.LocalEndPoint; - Assert.Equal(IPAddress.Any, localEPAfterBind.Address); - int portAfterBind = localEPAfterBind.Port; + Assert.Equal(IPAddress.Any, ((IPEndPoint)client.LocalEndPoint).Address); - var sendToEP = new IPEndPoint(IPAddress.Loopback, ((IPEndPoint)server.LocalEndPoint).Port); + var sendToEP = new IPEndPoint(IPAddress.Loopback, serverPort); client.SendTo(new byte[] { 1, 2, 3 }, sendToEP); Assert.Equal(IPAddress.Any, ((IPEndPoint)client.LocalEndPoint).Address); - Assert.Equal(portAfterBind, ((IPEndPoint)client.LocalEndPoint).Port); + Assert.Equal(clientPortAfterBind, ((IPEndPoint)client.LocalEndPoint).Port); byte[] buf = new byte[3]; EndPoint receiveFromEP = new IPEndPoint(IPAddress.Any, 0); server.ReceiveFrom(buf, ref receiveFromEP); Assert.Equal(new byte[] { 1, 2, 3 }, buf); - Assert.Equal(portAfterBind, ((IPEndPoint)receiveFromEP).Port); + Assert.Equal(clientPortAfterBind, ((IPEndPoint)receiveFromEP).Port); IAsyncResult sendToResult = client.BeginSendTo(new byte[] { 4, 5, 6 }, 0, 3, SocketFlags.None, sendToEP, null, null); sendToResult.AsyncWaitHandle.WaitOne(); client.EndSendTo(sendToResult); Assert.Equal(IPAddress.Any, ((IPEndPoint)client.LocalEndPoint).Address); - Assert.Equal(portAfterBind, ((IPEndPoint)client.LocalEndPoint).Port); + Assert.Equal(clientPortAfterBind, ((IPEndPoint)client.LocalEndPoint).Port); buf = new byte[3]; receiveFromEP = new IPEndPoint(IPAddress.Any, 0); server.ReceiveFrom(buf, ref receiveFromEP); Assert.Equal(new byte[] { 4, 5, 6 }, buf); - Assert.Equal(portAfterBind, ((IPEndPoint)receiveFromEP).Port); + Assert.Equal(clientPortAfterBind, ((IPEndPoint)receiveFromEP).Port); } } @@ -60,11 +58,11 @@ public void UdpSocket_NotBound_LocalEPBecomesWildcardAddressOnSendTo() using (Socket server = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp)) using (Socket client = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp)) { - server.Bind(new IPEndPoint(IPAddress.Any, 0)); + int serverPort = server.BindToAnonymousPort(IPAddress.Any); Assert.Null(client.LocalEndPoint); - var sendToEP = new IPEndPoint(IPAddress.Loopback, ((IPEndPoint)server.LocalEndPoint).Port); + var sendToEP = new IPEndPoint(IPAddress.Loopback, serverPort); client.SendTo(new byte[] { 1, 2, 3 }, sendToEP); @@ -96,20 +94,18 @@ public async Task TcpSocket_BoundToWildcardAddress_LocalEPChangeToSpecificOnConn using (Socket server = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) using (Socket client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) { - server.Bind(new IPEndPoint(IPAddress.Any, 0)); - client.Bind(new IPEndPoint(IPAddress.Any, 0)); + int serverPort = server.BindToAnonymousPort(IPAddress.Any); + int clientPortAfterBind = client.BindToAnonymousPort(IPAddress.Any); - IPEndPoint localEPAfterBind = (IPEndPoint)client.LocalEndPoint; - Assert.Equal(IPAddress.Any, localEPAfterBind.Address); - int portAfterBind = localEPAfterBind.Port; + Assert.Equal(IPAddress.Any, ((IPEndPoint)client.LocalEndPoint).Address); server.Listen(); Task acceptTask = server.AcceptAsync(); - client.Connect(new IPEndPoint(IPAddress.Loopback, ((IPEndPoint)server.LocalEndPoint).Port)); + client.Connect(new IPEndPoint(IPAddress.Loopback, serverPort)); Assert.Equal(IPAddress.Loopback, ((IPEndPoint)client.LocalEndPoint).Address); - Assert.Equal(portAfterBind, ((IPEndPoint)client.LocalEndPoint).Port); + Assert.Equal(clientPortAfterBind, ((IPEndPoint)client.LocalEndPoint).Port); Socket accept = await acceptTask; Assert.Equal(accept.RemoteEndPoint, client.LocalEndPoint); @@ -122,13 +118,13 @@ public async Task TcpSocket_NotBound_LocalEPChangeToSpecificOnConnnect() using (Socket server = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) using (Socket client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) { - server.Bind(new IPEndPoint(IPAddress.Any, 0)); + int serverPort = server.BindToAnonymousPort(IPAddress.Any); server.Listen(); Task acceptTask = server.AcceptAsync(); Assert.Null(client.LocalEndPoint); - client.Connect(new IPEndPoint(IPAddress.Loopback, ((IPEndPoint)server.LocalEndPoint).Port)); + client.Connect(new IPEndPoint(IPAddress.Loopback, serverPort); Assert.Equal(IPAddress.Loopback, ((IPEndPoint)client.LocalEndPoint).Address); @@ -142,7 +138,7 @@ public void LocalEndPoint_IsCached() { using (Socket socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) { - socket.Bind(new IPEndPoint(IPAddress.Any, 0)); + socket.BindToAnonymousPort(IPAddress.Any); EndPoint localEndPointCall1 = socket.LocalEndPoint; EndPoint localEndPointCall2 = socket.LocalEndPoint; From 1ff3e1413d1890703ffba1e2baea4b752234d80a Mon Sep 17 00:00:00 2001 From: Natalia Kondratyeva Date: Tue, 11 Aug 2020 17:19:15 +0300 Subject: [PATCH 09/16] Fix typo --- .../tests/FunctionalTests/LocalEndPointTest.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/libraries/System.Net.Sockets/tests/FunctionalTests/LocalEndPointTest.cs b/src/libraries/System.Net.Sockets/tests/FunctionalTests/LocalEndPointTest.cs index 30c4af46ca723..f67e1adeea434 100644 --- a/src/libraries/System.Net.Sockets/tests/FunctionalTests/LocalEndPointTest.cs +++ b/src/libraries/System.Net.Sockets/tests/FunctionalTests/LocalEndPointTest.cs @@ -124,7 +124,7 @@ public async Task TcpSocket_NotBound_LocalEPChangeToSpecificOnConnnect() Assert.Null(client.LocalEndPoint); - client.Connect(new IPEndPoint(IPAddress.Loopback, serverPort); + client.Connect(new IPEndPoint(IPAddress.Loopback, serverPort)); Assert.Equal(IPAddress.Loopback, ((IPEndPoint)client.LocalEndPoint).Address); From ec8e54e792966bf2dd6b8c27e9b7729b3374ce82 Mon Sep 17 00:00:00 2001 From: Natalia Kondratyeva Date: Tue, 18 Aug 2020 21:56:00 +0300 Subject: [PATCH 10/16] Assign _localEndPoint from the listener --- .../System.Net.Sockets/src/System/Net/Sockets/Socket.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs index 370fc71f14d43..e89be393ce91e 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs @@ -4871,7 +4871,7 @@ internal Socket UpdateAcceptSocket(Socket socket, EndPoint remoteEP) socket._protocolType = _protocolType; socket._rightEndPoint = _rightEndPoint; socket._remoteEndPoint = remoteEP; - socket._localEndPoint = null; + socket._localEndPoint = _localEndPoint; // The socket is connected. socket.SetToConnected(); From 3d9385b9d3f1600328d0d8053fcfb9bb61c08b32 Mon Sep 17 00:00:00 2001 From: Natalia Kondratyeva Date: Mon, 24 Aug 2020 12:37:28 +0300 Subject: [PATCH 11/16] tmp commit to check on mac --- .../src/System/Net/Sockets/Socket.cs | 47 ++++- .../FunctionalTests/LocalEndPointTest.cs | 195 ++++++++++++++---- 2 files changed, 193 insertions(+), 49 deletions(-) diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs index e89be393ce91e..1316de8734d13 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs @@ -216,7 +216,7 @@ private unsafe Socket(SafeSocketHandle handle, bool loadPropertiesFromHandle) } _isConnected = true; - _localEndPoint = null; + HandleLocalEndPointOnConnect(); break; case SocketError.InvalidArgument: @@ -226,7 +226,7 @@ private unsafe Socket(SafeSocketHandle handle, bool loadPropertiesFromHandle) // whether we're actually connected or not, err on the side of saying // we're connected. _isConnected = true; - _localEndPoint = null; + HandleLocalEndPointOnConnect(); break; } } @@ -320,7 +320,7 @@ public EndPoint? LocalEndPoint // Update the state if we've become connected after a non-blocking connect. _isConnected = true; _rightEndPoint = _nonBlockingConnectRightEndPoint; - _localEndPoint = null; + HandleLocalEndPointOnConnect(); _nonBlockingConnectInProgress = false; } @@ -367,7 +367,7 @@ public EndPoint? RemoteEndPoint // Update the state if we've become connected after a non-blocking connect. _isConnected = true; _rightEndPoint = _nonBlockingConnectRightEndPoint; - _localEndPoint = null; + HandleLocalEndPointOnConnect(); _nonBlockingConnectInProgress = false; } @@ -479,7 +479,7 @@ public bool Connected // Update the state if we've become connected after a non-blocking connect. _isConnected = true; _rightEndPoint = _nonBlockingConnectRightEndPoint; - _localEndPoint = null; + HandleLocalEndPointOnConnect(); _nonBlockingConnectInProgress = false; } @@ -4904,10 +4904,45 @@ internal void SetToConnected() // some point in time update the perf counter as well. _isConnected = true; _isDisconnected = false; - _localEndPoint = null; + HandleLocalEndPointOnConnect(); if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(this, "now connected"); } + private void HandleLocalEndPointOnConnect() + { + //_localEndPoint = null; + if (_localEndPoint == null) + { + return; + } + + if (_localEndPoint is IPEndPoint ipLocalEndpoint) + { + // If a client socket was bound to a wildcard address, then the `connect` system call + // will assign a specific address to the client socket's local endpoint instead of a + // wildcard address. In that case we should clear the cached wildcard local endpoint. + + // If a listener socket was bound to a wildcard address, then the `accept` system call + // will assign a specific address the accept socket's local endpoint instead of a + // wildcard address. In that case we should clear the accept socket's cached wildcard + // local endpoint copied from listener. + + if (IsWildcardAddress(ipLocalEndpoint.Address)) + { + _localEndPoint = null; + } + } + } + + private bool IsWildcardAddress(IPAddress address) + { + return address.ToString() == IPAddress.Any.ToString() + || address.ToString() == IPAddress.IPv6Any.ToString() + || address.ToString() == IPAddress.Any.MapToIPv6().ToString(); + + //return address == IPAddress.Any || address == IPAddress.IPv6Any || address == IPAddress.Any.MapToIPv6(); + } + internal void SetToDisconnected() { if (!_isConnected) diff --git a/src/libraries/System.Net.Sockets/tests/FunctionalTests/LocalEndPointTest.cs b/src/libraries/System.Net.Sockets/tests/FunctionalTests/LocalEndPointTest.cs index f67e1adeea434..74e133df973f6 100644 --- a/src/libraries/System.Net.Sockets/tests/FunctionalTests/LocalEndPointTest.cs +++ b/src/libraries/System.Net.Sockets/tests/FunctionalTests/LocalEndPointTest.cs @@ -6,82 +6,104 @@ namespace System.Net.Sockets.Tests { - public class LocalEndPointTest + public abstract class LocalEndPointTest { + protected abstract bool IPv6 { get; } + + private IPAddress Wildcard => IPv6 ? IPAddress.IPv6Any : IPAddress.Any; + + private IPAddress Loopback => IPv6 ? IPAddress.IPv6Loopback : IPAddress.Loopback; + [Fact] - public void UdpSocket_BoundToWildcardAddress_LocalEPDoesNotChangeOnSendTo() + public void UdpSocket_ClientBoundToWildcardAddress_LocalEPDoesNotChangeOnSendTo() { - using (Socket server = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp)) - using (Socket client = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp)) + using (Socket server = CreateUdpSocket()) + using (Socket client = CreateUdpSocket()) { - int serverPort = server.BindToAnonymousPort(IPAddress.Any); + int serverPort = server.BindToAnonymousPort(Wildcard); Assert.Null(client.LocalEndPoint); - int clientPortAfterBind = client.BindToAnonymousPort(IPAddress.Any); + int clientPortAfterBind = client.BindToAnonymousPort(Wildcard); - Assert.Equal(IPAddress.Any, ((IPEndPoint)client.LocalEndPoint).Address); + Assert.Equal(Wildcard, GetLocalEPAddress(client)); // wildcard before sendto - var sendToEP = new IPEndPoint(IPAddress.Loopback, serverPort); + var sendToEP = new IPEndPoint(Loopback, serverPort); client.SendTo(new byte[] { 1, 2, 3 }, sendToEP); - Assert.Equal(IPAddress.Any, ((IPEndPoint)client.LocalEndPoint).Address); - Assert.Equal(clientPortAfterBind, ((IPEndPoint)client.LocalEndPoint).Port); + Assert.Equal(Wildcard, GetLocalEPAddress(client)); // stays as wildcard after sendto + Assert.Equal(clientPortAfterBind, GetLocalEPPort(client)); byte[] buf = new byte[3]; - EndPoint receiveFromEP = new IPEndPoint(IPAddress.Any, 0); + EndPoint receiveFromEP = new IPEndPoint(Wildcard, 0); server.ReceiveFrom(buf, ref receiveFromEP); Assert.Equal(new byte[] { 1, 2, 3 }, buf); + Assert.Equal(Loopback, ((IPEndPoint)receiveFromEP).Address); // received from specific address Assert.Equal(clientPortAfterBind, ((IPEndPoint)receiveFromEP).Port); IAsyncResult sendToResult = client.BeginSendTo(new byte[] { 4, 5, 6 }, 0, 3, SocketFlags.None, sendToEP, null, null); sendToResult.AsyncWaitHandle.WaitOne(); client.EndSendTo(sendToResult); - Assert.Equal(IPAddress.Any, ((IPEndPoint)client.LocalEndPoint).Address); - Assert.Equal(clientPortAfterBind, ((IPEndPoint)client.LocalEndPoint).Port); + Assert.Equal(Wildcard, GetLocalEPAddress(client)); // stays as wildcard after async WSASendTo + Assert.Equal(clientPortAfterBind, GetLocalEPPort(client)); buf = new byte[3]; - receiveFromEP = new IPEndPoint(IPAddress.Any, 0); + receiveFromEP = new IPEndPoint(Wildcard, 0); server.ReceiveFrom(buf, ref receiveFromEP); Assert.Equal(new byte[] { 4, 5, 6 }, buf); + Assert.Equal(Loopback, ((IPEndPoint)receiveFromEP).Address); // received from specific address Assert.Equal(clientPortAfterBind, ((IPEndPoint)receiveFromEP).Port); } } [Fact] - public void UdpSocket_NotBound_LocalEPBecomesWildcardAddressOnSendTo() + public void UdpSocket_ClientNotBound_LocalEPBecomesWildcardOnSendTo() { - using (Socket server = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp)) - using (Socket client = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp)) + using (Socket server = CreateUdpSocket()) + using (Socket client = CreateUdpSocket()) { - int serverPort = server.BindToAnonymousPort(IPAddress.Any); + int serverPort = server.BindToAnonymousPort(Wildcard); - Assert.Null(client.LocalEndPoint); + Assert.Null(client.LocalEndPoint); // null before sendto - var sendToEP = new IPEndPoint(IPAddress.Loopback, serverPort); + var sendToEP = new IPEndPoint(Loopback, serverPort); client.SendTo(new byte[] { 1, 2, 3 }, sendToEP); - Assert.Equal(IPAddress.Any, ((IPEndPoint)client.LocalEndPoint).Address); + Assert.Equal(Wildcard, GetLocalEPAddress(client)); // wildcard after sendto byte[] buf = new byte[3]; - EndPoint receiveFromEP = new IPEndPoint(IPAddress.Any, 0); + EndPoint receiveFromEP = new IPEndPoint(Wildcard, 0); server.ReceiveFrom(buf, ref receiveFromEP); Assert.Equal(new byte[] { 1, 2, 3 }, buf); + } + } + + [Fact] + public void UdpSocket_ClientNotBound_LocalEPBecomesWildcardOnAsyncSendTo() + { + using (Socket server = CreateUdpSocket()) + using (Socket client = CreateUdpSocket()) + { + int serverPort = server.BindToAnonymousPort(Wildcard); + + Assert.Null(client.LocalEndPoint); // null before async WSASendTo + + var sendToEP = new IPEndPoint(Loopback, serverPort); IAsyncResult sendToResult = client.BeginSendTo(new byte[] { 4, 5, 6 }, 0, 3, SocketFlags.None, sendToEP, null, null); sendToResult.AsyncWaitHandle.WaitOne(); client.EndSendTo(sendToResult); - Assert.Equal(IPAddress.Any, ((IPEndPoint)client.LocalEndPoint).Address); + Assert.Equal(Wildcard, GetLocalEPAddress(client)); // wildcard after async WSASendTo - buf = new byte[3]; - receiveFromEP = new IPEndPoint(IPAddress.Any, 0); + byte[] buf = new byte[3]; + EndPoint receiveFromEP = new IPEndPoint(Wildcard, 0); server.ReceiveFrom(buf, ref receiveFromEP); Assert.Equal(new byte[] { 4, 5, 6 }, buf); @@ -89,23 +111,23 @@ public void UdpSocket_NotBound_LocalEPBecomesWildcardAddressOnSendTo() } [Fact] - public async Task TcpSocket_BoundToWildcardAddress_LocalEPChangeToSpecificOnConnnect() + public async Task TcpSocket_ClientBoundToWildcardAddress_LocalEPChangeToSpecificOnConnnect() { - using (Socket server = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) - using (Socket client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) + using (Socket server = CreateTcpSocket()) + using (Socket client = CreateTcpSocket()) { - int serverPort = server.BindToAnonymousPort(IPAddress.Any); - int clientPortAfterBind = client.BindToAnonymousPort(IPAddress.Any); + int serverPort = server.BindToAnonymousPort(Wildcard); + int clientPortAfterBind = client.BindToAnonymousPort(Wildcard); - Assert.Equal(IPAddress.Any, ((IPEndPoint)client.LocalEndPoint).Address); + Assert.Equal(Wildcard, GetLocalEPAddress(client)); // wildcard before connect server.Listen(); Task acceptTask = server.AcceptAsync(); - client.Connect(new IPEndPoint(IPAddress.Loopback, serverPort)); + client.Connect(new IPEndPoint(Loopback, serverPort)); - Assert.Equal(IPAddress.Loopback, ((IPEndPoint)client.LocalEndPoint).Address); - Assert.Equal(clientPortAfterBind, ((IPEndPoint)client.LocalEndPoint).Port); + Assert.Equal(Loopback, GetLocalEPAddress(client)); // specific after connect + Assert.Equal(clientPortAfterBind, GetLocalEPPort(client)); Socket accept = await acceptTask; Assert.Equal(accept.RemoteEndPoint, client.LocalEndPoint); @@ -113,32 +135,79 @@ public async Task TcpSocket_BoundToWildcardAddress_LocalEPChangeToSpecificOnConn } [Fact] - public async Task TcpSocket_NotBound_LocalEPChangeToSpecificOnConnnect() + public async Task TcpSocket_ClientNotBound_LocalEPChangeToSpecificOnConnnect() { - using (Socket server = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) - using (Socket client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) + using (Socket server = CreateTcpSocket()) + using (Socket client = CreateTcpSocket()) { - int serverPort = server.BindToAnonymousPort(IPAddress.Any); + int serverPort = server.BindToAnonymousPort(Loopback); server.Listen(); Task acceptTask = server.AcceptAsync(); - Assert.Null(client.LocalEndPoint); + Assert.Null(client.LocalEndPoint); // null before connect + + client.Connect(new IPEndPoint(Loopback, serverPort)); - client.Connect(new IPEndPoint(IPAddress.Loopback, serverPort)); + Assert.Equal(Loopback, GetLocalEPAddress(client)); // specific after connect - Assert.Equal(IPAddress.Loopback, ((IPEndPoint)client.LocalEndPoint).Address); + Socket accept = await acceptTask; + Assert.Equal(accept.RemoteEndPoint, client.LocalEndPoint); + } + } + + [Fact] + public async Task TcpSocket_ServerBoundToWildcardAddress_AcceptSocketLocalEPIsSpecific() + { + using (Socket server = CreateTcpSocket()) + using (Socket client = CreateTcpSocket()) + { + int serverPort = server.BindToAnonymousPort(Wildcard); + + Assert.Equal(Wildcard, GetLocalEPAddress(server)); // server -> wildcard before accept + + server.Listen(); + Task acceptTask = server.AcceptAsync(); + + client.Connect(new IPEndPoint(Loopback, serverPort)); Socket accept = await acceptTask; Assert.Equal(accept.RemoteEndPoint, client.LocalEndPoint); + + Assert.Equal(Wildcard, GetLocalEPAddress(server)); // server -> stays as wildcard + Assert.Equal(Loopback, GetLocalEPAddress(accept)); // accept -> specific + Assert.Equal(serverPort, GetLocalEPPort(accept)); + } + } + + [Fact] + public async Task TcpSocket_ServerBoundToSpecificAddress_AcceptSocketLocalEPIsSame() + { + using (Socket server = CreateTcpSocket()) + using (Socket client = CreateTcpSocket()) + { + int serverPort = server.BindToAnonymousPort(Loopback); + + Assert.Equal(Loopback, GetLocalEPAddress(server)); // server -> specific before accept + + server.Listen(); + Task acceptTask = server.AcceptAsync(); + + client.Connect(new IPEndPoint(Loopback, serverPort)); + + Socket accept = await acceptTask; + Assert.Equal(accept.RemoteEndPoint, client.LocalEndPoint); + + Assert.Equal(GetLocalEPAddress(server), GetLocalEPAddress(accept)); // accept -> same address + Assert.Equal(serverPort, GetLocalEPPort(accept)); } } [Fact] public void LocalEndPoint_IsCached() { - using (Socket socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) + using (Socket socket = CreateTcpSocket()) { - socket.BindToAnonymousPort(IPAddress.Any); + socket.BindToAnonymousPort(Loopback); EndPoint localEndPointCall1 = socket.LocalEndPoint; EndPoint localEndPointCall2 = socket.LocalEndPoint; @@ -146,5 +215,45 @@ public void LocalEndPoint_IsCached() Assert.Same(localEndPointCall1, localEndPointCall2); } } + + private Socket CreateUdpSocket() + { + return new Socket( + IPv6 ? AddressFamily.InterNetworkV6 : AddressFamily.InterNetwork, + SocketType.Dgram, + ProtocolType.Udp + ); + } + + private Socket CreateTcpSocket() + { + return new Socket( + IPv6 ? AddressFamily.InterNetworkV6 : AddressFamily.InterNetwork, + SocketType.Stream, + ProtocolType.Tcp + ); + } + + private IPAddress GetLocalEPAddress(Socket socket) + { + return ((IPEndPoint)socket.LocalEndPoint).Address; + } + + private int GetLocalEPPort(Socket socket) + { + return ((IPEndPoint)socket.LocalEndPoint).Port; + } + } + + [Trait("IPv4", "true")] + public class LocalEndPointIPv4Test : LocalEndPointTest + { + protected override bool IPv6 => false; + } + + [Trait("IPv6", "true")] + public class LocalEndPointIPv6Test : LocalEndPointTest + { + protected override bool IPv6 => true; } } From 4e183edb49ab3bb7174af364c4f21d1bbd7067d9 Mon Sep 17 00:00:00 2001 From: Natalia Kondratyeva Date: Mon, 24 Aug 2020 14:20:49 +0300 Subject: [PATCH 12/16] tmp --- .../System.Net.Sockets/src/System/Net/Sockets/Socket.cs | 7 ++++--- .../tests/FunctionalTests/LocalEndPointTest.cs | 2 +- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs index 1316de8734d13..054770820ba81 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs @@ -329,7 +329,7 @@ public EndPoint? LocalEndPoint return null; } - if (_localEndPoint == null) + //if (_localEndPoint == null) { Internals.SocketAddress socketAddress = IPEndPointExtensions.Serialize(_rightEndPoint); @@ -4936,9 +4936,10 @@ private void HandleLocalEndPointOnConnect() private bool IsWildcardAddress(IPAddress address) { - return address.ToString() == IPAddress.Any.ToString() + return true; + /*return address.ToString() == IPAddress.Any.ToString() || address.ToString() == IPAddress.IPv6Any.ToString() - || address.ToString() == IPAddress.Any.MapToIPv6().ToString(); + || address.ToString() == IPAddress.Any.MapToIPv6().ToString();*/ //return address == IPAddress.Any || address == IPAddress.IPv6Any || address == IPAddress.Any.MapToIPv6(); } diff --git a/src/libraries/System.Net.Sockets/tests/FunctionalTests/LocalEndPointTest.cs b/src/libraries/System.Net.Sockets/tests/FunctionalTests/LocalEndPointTest.cs index 74e133df973f6..c262c62384eb2 100644 --- a/src/libraries/System.Net.Sockets/tests/FunctionalTests/LocalEndPointTest.cs +++ b/src/libraries/System.Net.Sockets/tests/FunctionalTests/LocalEndPointTest.cs @@ -202,7 +202,7 @@ public async Task TcpSocket_ServerBoundToSpecificAddress_AcceptSocketLocalEPIsSa } } - [Fact] + //[Fact] public void LocalEndPoint_IsCached() { using (Socket socket = CreateTcpSocket()) From eeb057daa7b7a14231fe0f3f64f60852fabd85d7 Mon Sep 17 00:00:00 2001 From: Natalia Kondratyeva Date: Tue, 25 Aug 2020 17:13:54 +0300 Subject: [PATCH 13/16] Fix _localEndPoint clearing on mac --- .../src/System/Net/Sockets/Socket.cs | 74 ++++++++++--------- .../FunctionalTests/LocalEndPointTest.cs | 3 +- 2 files changed, 41 insertions(+), 36 deletions(-) diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs index 054770820ba81..d9f425715d205 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs @@ -22,14 +22,22 @@ public partial class Socket : IDisposable { internal const int DefaultCloseTimeout = -1; // NOTE: changing this default is a breaking change. + private static readonly HashSet s_wildcardAddresses = new HashSet + { + IPAddress.Any, IPAddress.IPv6Any, IPAddress.Any.MapToIPv6() + }; + private SafeSocketHandle _handle; // _rightEndPoint is null if the socket has not been bound. Otherwise, it is any EndPoint of the // correct type (IPEndPoint, etc). internal EndPoint? _rightEndPoint; - private EndPoint? _localEndPoint; // Cached LocalEndPoint value. Will clear on connect, error and disconnect internal EndPoint? _remoteEndPoint; + // Cached LocalEndPoint value. Cleared on disconnect and error. Cached wildcard addresses are + // also cleared on connect and accept. + private EndPoint? _localEndPoint; + // These flags monitor if the socket was ever connected at any time and if it still is. private bool _isConnected; private bool _isDisconnected; @@ -216,7 +224,7 @@ private unsafe Socket(SafeSocketHandle handle, bool loadPropertiesFromHandle) } _isConnected = true; - HandleLocalEndPointOnConnect(); + UpdateLocalEndPointOnConnect(); break; case SocketError.InvalidArgument: @@ -226,7 +234,7 @@ private unsafe Socket(SafeSocketHandle handle, bool loadPropertiesFromHandle) // whether we're actually connected or not, err on the side of saying // we're connected. _isConnected = true; - HandleLocalEndPointOnConnect(); + UpdateLocalEndPointOnConnect(); break; } } @@ -320,7 +328,7 @@ public EndPoint? LocalEndPoint // Update the state if we've become connected after a non-blocking connect. _isConnected = true; _rightEndPoint = _nonBlockingConnectRightEndPoint; - HandleLocalEndPointOnConnect(); + UpdateLocalEndPointOnConnect(); _nonBlockingConnectInProgress = false; } @@ -329,7 +337,7 @@ public EndPoint? LocalEndPoint return null; } - //if (_localEndPoint == null) + if (_localEndPoint == null) { Internals.SocketAddress socketAddress = IPEndPointExtensions.Serialize(_rightEndPoint); @@ -367,7 +375,7 @@ public EndPoint? RemoteEndPoint // Update the state if we've become connected after a non-blocking connect. _isConnected = true; _rightEndPoint = _nonBlockingConnectRightEndPoint; - HandleLocalEndPointOnConnect(); + UpdateLocalEndPointOnConnect(); _nonBlockingConnectInProgress = false; } @@ -479,7 +487,7 @@ public bool Connected // Update the state if we've become connected after a non-blocking connect. _isConnected = true; _rightEndPoint = _nonBlockingConnectRightEndPoint; - HandleLocalEndPointOnConnect(); + UpdateLocalEndPointOnConnect(); _nonBlockingConnectInProgress = false; } @@ -4871,7 +4879,12 @@ internal Socket UpdateAcceptSocket(Socket socket, EndPoint remoteEP) socket._protocolType = _protocolType; socket._rightEndPoint = _rightEndPoint; socket._remoteEndPoint = remoteEP; - socket._localEndPoint = _localEndPoint; + + // If the listener socket was bound to a wildcard address, then the `accept` system call + // will assign a specific address to the accept socket's local endpoint instead of a + // wildcard address. In that case we should not copy listener's wildcard local endpoint. + + socket._localEndPoint = !IsWildcardEndPoint(_localEndPoint) ? _localEndPoint : null; // The socket is connected. socket.SetToConnected(); @@ -4904,44 +4917,35 @@ internal void SetToConnected() // some point in time update the perf counter as well. _isConnected = true; _isDisconnected = false; - HandleLocalEndPointOnConnect(); + UpdateLocalEndPointOnConnect(); if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(this, "now connected"); } - private void HandleLocalEndPointOnConnect() + private void UpdateLocalEndPointOnConnect() { - //_localEndPoint = null; - if (_localEndPoint == null) - { - return; - } + // If the client socket was bound to a wildcard address, then the `connect` system call + // will assign a specific address to the client socket's local endpoint instead of a + // wildcard address. In that case we should clear the cached wildcard local endpoint. - if (_localEndPoint is IPEndPoint ipLocalEndpoint) + if (IsWildcardEndPoint(_localEndPoint)) { - // If a client socket was bound to a wildcard address, then the `connect` system call - // will assign a specific address to the client socket's local endpoint instead of a - // wildcard address. In that case we should clear the cached wildcard local endpoint. - - // If a listener socket was bound to a wildcard address, then the `accept` system call - // will assign a specific address the accept socket's local endpoint instead of a - // wildcard address. In that case we should clear the accept socket's cached wildcard - // local endpoint copied from listener. - - if (IsWildcardAddress(ipLocalEndpoint.Address)) - { - _localEndPoint = null; - } + _localEndPoint = null; } } - private bool IsWildcardAddress(IPAddress address) + private bool IsWildcardEndPoint(EndPoint? endPoint) { - return true; - /*return address.ToString() == IPAddress.Any.ToString() - || address.ToString() == IPAddress.IPv6Any.ToString() - || address.ToString() == IPAddress.Any.MapToIPv6().ToString();*/ + if (endPoint == null) + { + return false; + } - //return address == IPAddress.Any || address == IPAddress.IPv6Any || address == IPAddress.Any.MapToIPv6(); + if (endPoint is IPEndPoint ipEndpoint) + { + return s_wildcardAddresses.Contains(ipEndpoint.Address); + } + + return false; } internal void SetToDisconnected() diff --git a/src/libraries/System.Net.Sockets/tests/FunctionalTests/LocalEndPointTest.cs b/src/libraries/System.Net.Sockets/tests/FunctionalTests/LocalEndPointTest.cs index c262c62384eb2..582b315e69553 100644 --- a/src/libraries/System.Net.Sockets/tests/FunctionalTests/LocalEndPointTest.cs +++ b/src/libraries/System.Net.Sockets/tests/FunctionalTests/LocalEndPointTest.cs @@ -172,6 +172,7 @@ public async Task TcpSocket_ServerBoundToWildcardAddress_AcceptSocketLocalEPIsSp Socket accept = await acceptTask; Assert.Equal(accept.RemoteEndPoint, client.LocalEndPoint); + Assert.Equal(accept.LocalEndPoint, client.RemoteEndPoint); Assert.Equal(Wildcard, GetLocalEPAddress(server)); // server -> stays as wildcard Assert.Equal(Loopback, GetLocalEPAddress(accept)); // accept -> specific @@ -202,7 +203,7 @@ public async Task TcpSocket_ServerBoundToSpecificAddress_AcceptSocketLocalEPIsSa } } - //[Fact] + [Fact] public void LocalEndPoint_IsCached() { using (Socket socket = CreateTcpSocket()) From 85de50d903dc7a71a03b0ce0d17d8c2c658ff380 Mon Sep 17 00:00:00 2001 From: Natalia Kondratyeva Date: Wed, 26 Aug 2020 12:57:25 +0300 Subject: [PATCH 14/16] Inline wildcard addresses --- .../System.Net.Sockets/src/System/Net/Sockets/Socket.cs | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs index d9f425715d205..565b9eb067760 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs @@ -22,11 +22,6 @@ public partial class Socket : IDisposable { internal const int DefaultCloseTimeout = -1; // NOTE: changing this default is a breaking change. - private static readonly HashSet s_wildcardAddresses = new HashSet - { - IPAddress.Any, IPAddress.IPv6Any, IPAddress.Any.MapToIPv6() - }; - private SafeSocketHandle _handle; // _rightEndPoint is null if the socket has not been bound. Otherwise, it is any EndPoint of the @@ -4942,7 +4937,8 @@ private bool IsWildcardEndPoint(EndPoint? endPoint) if (endPoint is IPEndPoint ipEndpoint) { - return s_wildcardAddresses.Contains(ipEndpoint.Address); + IPAddress address = ipEndpoint.Address; + return IPAddress.Any.Equals(address) || IPAddress.IPv6Any.Equals(address) || IPAddress.Any.MapToIPv6().Equals(address); } return false; From cbaed88282f467fc67badaca1fbdde92b08cee47 Mon Sep 17 00:00:00 2001 From: Natalia Kondratyeva Date: Thu, 27 Aug 2020 17:08:59 +0300 Subject: [PATCH 15/16] PR fixes --- .../src/System/Net/Sockets/Socket.Windows.cs | 2 +- .../System.Net.Sockets/src/System/Net/Sockets/Socket.cs | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Windows.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Windows.cs index a8cb83dc651f8..20639f78c97e2 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Windows.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Windows.cs @@ -232,7 +232,7 @@ partial void WildcardBindForConnectIfNecessary(AddressFamily addressFamily) switch (addressFamily) { case AddressFamily.InterNetwork: - address = IsDualMode ? IPAddress.Any.MapToIPv6() : IPAddress.Any; + address = IsDualMode ? s_IPAddressAnyMapToIPv6 : IPAddress.Any; break; case AddressFamily.InterNetworkV6: diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs index 565b9eb067760..479ba522ace6b 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs @@ -22,6 +22,8 @@ public partial class Socket : IDisposable { internal const int DefaultCloseTimeout = -1; // NOTE: changing this default is a breaking change. + private static readonly IPAddress s_IPAddressAnyMapToIPv6 = IPAddress.Any.MapToIPv6(); + private SafeSocketHandle _handle; // _rightEndPoint is null if the socket has not been bound. Otherwise, it is any EndPoint of the @@ -219,7 +221,6 @@ private unsafe Socket(SafeSocketHandle handle, bool loadPropertiesFromHandle) } _isConnected = true; - UpdateLocalEndPointOnConnect(); break; case SocketError.InvalidArgument: @@ -229,7 +230,6 @@ private unsafe Socket(SafeSocketHandle handle, bool loadPropertiesFromHandle) // whether we're actually connected or not, err on the side of saying // we're connected. _isConnected = true; - UpdateLocalEndPointOnConnect(); break; } } @@ -4938,7 +4938,7 @@ private bool IsWildcardEndPoint(EndPoint? endPoint) if (endPoint is IPEndPoint ipEndpoint) { IPAddress address = ipEndpoint.Address; - return IPAddress.Any.Equals(address) || IPAddress.IPv6Any.Equals(address) || IPAddress.Any.MapToIPv6().Equals(address); + return IPAddress.Any.Equals(address) || IPAddress.IPv6Any.Equals(address) || s_IPAddressAnyMapToIPv6.Equals(address); } return false; From d0fa7599c69fb94b1346c96f59527d98e565e31b Mon Sep 17 00:00:00 2001 From: Natalia Kondratyeva Date: Tue, 1 Sep 2020 18:20:13 +0100 Subject: [PATCH 16/16] Rewrite tests with SocketTestHelperBase --- .../FunctionalTests/LocalEndPointTest.cs | 188 ++++++++++-------- 1 file changed, 106 insertions(+), 82 deletions(-) diff --git a/src/libraries/System.Net.Sockets/tests/FunctionalTests/LocalEndPointTest.cs b/src/libraries/System.Net.Sockets/tests/FunctionalTests/LocalEndPointTest.cs index 582b315e69553..ac2ce8c37e8d0 100644 --- a/src/libraries/System.Net.Sockets/tests/FunctionalTests/LocalEndPointTest.cs +++ b/src/libraries/System.Net.Sockets/tests/FunctionalTests/LocalEndPointTest.cs @@ -3,10 +3,11 @@ using System.Threading.Tasks; using Xunit; +using Xunit.Abstractions; namespace System.Net.Sockets.Tests { - public abstract class LocalEndPointTest + public abstract class LocalEndPointTest : SocketTestHelperBase where T : SocketHelperBase, new() { protected abstract bool IPv6 { get; } @@ -14,104 +15,66 @@ public abstract class LocalEndPointTest private IPAddress Loopback => IPv6 ? IPAddress.IPv6Loopback : IPAddress.Loopback; + public LocalEndPointTest(ITestOutputHelper output) : base(output) { } + [Fact] - public void UdpSocket_ClientBoundToWildcardAddress_LocalEPDoesNotChangeOnSendTo() + public async Task UdpSocket_WhenBoundToWildcardAddress_LocalEPDoesNotChangeOnSendTo() { - using (Socket server = CreateUdpSocket()) - using (Socket client = CreateUdpSocket()) + using (Socket receiver = CreateUdpSocket()) + using (Socket sender = CreateUdpSocket()) { - int serverPort = server.BindToAnonymousPort(Wildcard); + int receiverPort = receiver.BindToAnonymousPort(Wildcard); - Assert.Null(client.LocalEndPoint); + Assert.Null(sender.LocalEndPoint); - int clientPortAfterBind = client.BindToAnonymousPort(Wildcard); + int senderPortAfterBind = sender.BindToAnonymousPort(Wildcard); - Assert.Equal(Wildcard, GetLocalEPAddress(client)); // wildcard before sendto + Assert.Equal(Wildcard, GetLocalEPAddress(sender)); // wildcard before sendto - var sendToEP = new IPEndPoint(Loopback, serverPort); + var sendToEP = new IPEndPoint(Loopback, receiverPort); - client.SendTo(new byte[] { 1, 2, 3 }, sendToEP); + await SendToAsync(sender, new byte[] { 1, 2, 3 }, sendToEP); - Assert.Equal(Wildcard, GetLocalEPAddress(client)); // stays as wildcard after sendto - Assert.Equal(clientPortAfterBind, GetLocalEPPort(client)); + Assert.Equal(Wildcard, GetLocalEPAddress(sender)); // stays as wildcard after sendto + Assert.Equal(senderPortAfterBind, GetLocalEPPort(sender)); byte[] buf = new byte[3]; EndPoint receiveFromEP = new IPEndPoint(Wildcard, 0); - server.ReceiveFrom(buf, ref receiveFromEP); + receiver.ReceiveFrom(buf, ref receiveFromEP); Assert.Equal(new byte[] { 1, 2, 3 }, buf); Assert.Equal(Loopback, ((IPEndPoint)receiveFromEP).Address); // received from specific address - Assert.Equal(clientPortAfterBind, ((IPEndPoint)receiveFromEP).Port); - - IAsyncResult sendToResult = client.BeginSendTo(new byte[] { 4, 5, 6 }, 0, 3, SocketFlags.None, sendToEP, null, null); - sendToResult.AsyncWaitHandle.WaitOne(); - client.EndSendTo(sendToResult); - - Assert.Equal(Wildcard, GetLocalEPAddress(client)); // stays as wildcard after async WSASendTo - Assert.Equal(clientPortAfterBind, GetLocalEPPort(client)); - - buf = new byte[3]; - receiveFromEP = new IPEndPoint(Wildcard, 0); - server.ReceiveFrom(buf, ref receiveFromEP); - - Assert.Equal(new byte[] { 4, 5, 6 }, buf); - Assert.Equal(Loopback, ((IPEndPoint)receiveFromEP).Address); // received from specific address - Assert.Equal(clientPortAfterBind, ((IPEndPoint)receiveFromEP).Port); + Assert.Equal(senderPortAfterBind, ((IPEndPoint)receiveFromEP).Port); } } [Fact] - public void UdpSocket_ClientNotBound_LocalEPBecomesWildcardOnSendTo() + public async Task UdpSocket_WhenNotBound_LocalEPChangeToWildcardOnSendTo() { - using (Socket server = CreateUdpSocket()) - using (Socket client = CreateUdpSocket()) + using (Socket receiver = CreateUdpSocket()) + using (Socket sender = CreateUdpSocket()) { - int serverPort = server.BindToAnonymousPort(Wildcard); + int receiverPort = receiver.BindToAnonymousPort(Wildcard); - Assert.Null(client.LocalEndPoint); // null before sendto + Assert.Null(sender.LocalEndPoint); // null before sendto - var sendToEP = new IPEndPoint(Loopback, serverPort); + var sendToEP = new IPEndPoint(Loopback, receiverPort); - client.SendTo(new byte[] { 1, 2, 3 }, sendToEP); + await SendToAsync(sender, new byte[] { 1, 2, 3 }, sendToEP); - Assert.Equal(Wildcard, GetLocalEPAddress(client)); // wildcard after sendto + Assert.Equal(Wildcard, GetLocalEPAddress(sender)); // changes to wildcard after sendto byte[] buf = new byte[3]; EndPoint receiveFromEP = new IPEndPoint(Wildcard, 0); - server.ReceiveFrom(buf, ref receiveFromEP); + receiver.ReceiveFrom(buf, ref receiveFromEP); Assert.Equal(new byte[] { 1, 2, 3 }, buf); + Assert.Equal(Loopback, ((IPEndPoint)receiveFromEP).Address); // received from specific address } } [Fact] - public void UdpSocket_ClientNotBound_LocalEPBecomesWildcardOnAsyncSendTo() - { - using (Socket server = CreateUdpSocket()) - using (Socket client = CreateUdpSocket()) - { - int serverPort = server.BindToAnonymousPort(Wildcard); - - Assert.Null(client.LocalEndPoint); // null before async WSASendTo - - var sendToEP = new IPEndPoint(Loopback, serverPort); - - IAsyncResult sendToResult = client.BeginSendTo(new byte[] { 4, 5, 6 }, 0, 3, SocketFlags.None, sendToEP, null, null); - sendToResult.AsyncWaitHandle.WaitOne(); - client.EndSendTo(sendToResult); - - Assert.Equal(Wildcard, GetLocalEPAddress(client)); // wildcard after async WSASendTo - - byte[] buf = new byte[3]; - EndPoint receiveFromEP = new IPEndPoint(Wildcard, 0); - server.ReceiveFrom(buf, ref receiveFromEP); - - Assert.Equal(new byte[] { 4, 5, 6 }, buf); - } - } - - [Fact] - public async Task TcpSocket_ClientBoundToWildcardAddress_LocalEPChangeToSpecificOnConnnect() + public async Task TcpClientSocket_WhenBoundToWildcardAddress_LocalEPChangeToSpecificOnConnnect() { using (Socket server = CreateTcpSocket()) using (Socket client = CreateTcpSocket()) @@ -122,11 +85,11 @@ public async Task TcpSocket_ClientBoundToWildcardAddress_LocalEPChangeToSpecific Assert.Equal(Wildcard, GetLocalEPAddress(client)); // wildcard before connect server.Listen(); - Task acceptTask = server.AcceptAsync(); + Task acceptTask = AcceptAsync(server); - client.Connect(new IPEndPoint(Loopback, serverPort)); + await ConnectAsync(client, new IPEndPoint(Loopback, serverPort)); - Assert.Equal(Loopback, GetLocalEPAddress(client)); // specific after connect + Assert.Equal(Loopback, GetLocalEPAddress(client)); // changes to specific after connect Assert.Equal(clientPortAfterBind, GetLocalEPPort(client)); Socket accept = await acceptTask; @@ -135,20 +98,20 @@ public async Task TcpSocket_ClientBoundToWildcardAddress_LocalEPChangeToSpecific } [Fact] - public async Task TcpSocket_ClientNotBound_LocalEPChangeToSpecificOnConnnect() + public async Task TcpClientSocket_WhenNotBound_LocalEPChangeToSpecificOnConnnect() { using (Socket server = CreateTcpSocket()) using (Socket client = CreateTcpSocket()) { int serverPort = server.BindToAnonymousPort(Loopback); server.Listen(); - Task acceptTask = server.AcceptAsync(); + Task acceptTask = AcceptAsync(server); Assert.Null(client.LocalEndPoint); // null before connect - client.Connect(new IPEndPoint(Loopback, serverPort)); + await ConnectAsync(client, new IPEndPoint(Loopback, serverPort)); - Assert.Equal(Loopback, GetLocalEPAddress(client)); // specific after connect + Assert.Equal(Loopback, GetLocalEPAddress(client)); // changes to specific after connect Socket accept = await acceptTask; Assert.Equal(accept.RemoteEndPoint, client.LocalEndPoint); @@ -156,7 +119,7 @@ public async Task TcpSocket_ClientNotBound_LocalEPChangeToSpecificOnConnnect() } [Fact] - public async Task TcpSocket_ServerBoundToWildcardAddress_AcceptSocketLocalEPIsSpecific() + public async Task TcpAcceptSocket_WhenServerBoundToWildcardAddress_LocalEPIsSpecific() { using (Socket server = CreateTcpSocket()) using (Socket client = CreateTcpSocket()) @@ -166,9 +129,9 @@ public async Task TcpSocket_ServerBoundToWildcardAddress_AcceptSocketLocalEPIsSp Assert.Equal(Wildcard, GetLocalEPAddress(server)); // server -> wildcard before accept server.Listen(); - Task acceptTask = server.AcceptAsync(); + Task acceptTask = AcceptAsync(server); - client.Connect(new IPEndPoint(Loopback, serverPort)); + await ConnectAsync(client, new IPEndPoint(Loopback, serverPort)); Socket accept = await acceptTask; Assert.Equal(accept.RemoteEndPoint, client.LocalEndPoint); @@ -181,7 +144,7 @@ public async Task TcpSocket_ServerBoundToWildcardAddress_AcceptSocketLocalEPIsSp } [Fact] - public async Task TcpSocket_ServerBoundToSpecificAddress_AcceptSocketLocalEPIsSame() + public async Task TcpAcceptSocket_WhenServerBoundToSpecificAddress_LocalEPIsSame() { using (Socket server = CreateTcpSocket()) using (Socket client = CreateTcpSocket()) @@ -191,9 +154,9 @@ public async Task TcpSocket_ServerBoundToSpecificAddress_AcceptSocketLocalEPIsSa Assert.Equal(Loopback, GetLocalEPAddress(server)); // server -> specific before accept server.Listen(); - Task acceptTask = server.AcceptAsync(); + Task acceptTask = AcceptAsync(server); - client.Connect(new IPEndPoint(Loopback, serverPort)); + await ConnectAsync(client, new IPEndPoint(Loopback, serverPort)); Socket accept = await acceptTask; Assert.Equal(accept.RemoteEndPoint, client.LocalEndPoint); @@ -245,16 +208,77 @@ private int GetLocalEPPort(Socket socket) return ((IPEndPoint)socket.LocalEndPoint).Port; } } + public abstract class LocalEndPointTestIPv4 : LocalEndPointTest where T : SocketHelperBase, new() + { + protected override bool IPv6 => false; + + public LocalEndPointTestIPv4(ITestOutputHelper output) : base(output) { } + } + + public abstract class LocalEndPointTestIPv6 : LocalEndPointTest where T : SocketHelperBase, new() + { + protected override bool IPv6 => true; + + public LocalEndPointTestIPv6(ITestOutputHelper output) : base(output) { } + } [Trait("IPv4", "true")] - public class LocalEndPointIPv4Test : LocalEndPointTest + public sealed class LocalEndPointTestIPv4Sync : LocalEndPointTestIPv4 { - protected override bool IPv6 => false; + public LocalEndPointTestIPv4Sync(ITestOutputHelper output) : base(output) { } + } + + [Trait("IPv4", "true")] + public sealed class LocalEndPointTestIPv4SyncForceNonBlocking : LocalEndPointTestIPv4 + { + public LocalEndPointTestIPv4SyncForceNonBlocking(ITestOutputHelper output) : base(output) { } + } + + [Trait("IPv4", "true")] + public sealed class LocalEndPointTestIPv4Apm : LocalEndPointTestIPv4 + { + public LocalEndPointTestIPv4Apm(ITestOutputHelper output) : base(output) { } + } + + [Trait("IPv4", "true")] + public sealed class LocalEndPointTestIPv4Task : LocalEndPointTestIPv4 + { + public LocalEndPointTestIPv4Task(ITestOutputHelper output) : base(output) { } + } + + [Trait("IPv4", "true")] + public sealed class LocalEndPointTestIPv4Eap : LocalEndPointTestIPv4 + { + public LocalEndPointTestIPv4Eap(ITestOutputHelper output) : base(output) { } } [Trait("IPv6", "true")] - public class LocalEndPointIPv6Test : LocalEndPointTest + public sealed class LocalEndPointTestIPv6Sync : LocalEndPointTestIPv6 { - protected override bool IPv6 => true; + public LocalEndPointTestIPv6Sync(ITestOutputHelper output) : base(output) { } + } + + [Trait("IPv6", "true")] + public sealed class LocalEndPointTestIPv6SyncForceNonBlocking : LocalEndPointTestIPv6 + { + public LocalEndPointTestIPv6SyncForceNonBlocking(ITestOutputHelper output) : base(output) { } + } + + [Trait("IPv6", "true")] + public sealed class LocalEndPointTestIPv6Apm : LocalEndPointTestIPv6 + { + public LocalEndPointTestIPv6Apm(ITestOutputHelper output) : base(output) { } + } + + [Trait("IPv6", "true")] + public sealed class LocalEndPointTestIPv6Task : LocalEndPointTestIPv6 + { + public LocalEndPointTestIPv6Task(ITestOutputHelper output) : base(output) { } + } + + [Trait("IPv6", "true")] + public sealed class LocalEndPointTestIPv6Eap : LocalEndPointTestIPv6 + { + public LocalEndPointTestIPv6Eap(ITestOutputHelper output) : base(output) { } } }