Skip to content

Commit

Permalink
Socket: don't assign right endpoint until the connect is successful. (#…
Browse files Browse the repository at this point in the history
…53581)

* Socket: don't assign right endpoint until the connect is successful.

'Right endpoint' must match the address family of the Socket or
we can't serialize the LocalEndPoint and RemoteEndPoint.

When multiple connect attempts are made against a DualMode Socket with
both IPv4 and IPv6 addresses, a failed attempt must not set 'right
endpoint'.

* SocketTaskExtensionsTest.EnsureMethodsAreCallable: update expected exceptions

* PR feedback

* EnsureMethodsAreCallable: move ReceiveFromAsync before ConnectAsync to avoid wildcard bind on Windows that leads to a different exception
  • Loading branch information
tmds authored Jun 5, 2021
1 parent bc44b09 commit 0d31ddb
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 31 deletions.
47 changes: 18 additions & 29 deletions src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,9 @@ public partial class Socket : IDisposable
// to poll for the real state until we're done connecting.
private bool _nonBlockingConnectInProgress;

// Keep track of the kind of endpoint used to do a non-blocking connect, so we can set
// it to _rightEndPoint when we discover we're connected.
private EndPoint? _nonBlockingConnectRightEndPoint;
// Keep track of the kind of endpoint used to do a connect, so we can set
// it to _rightEndPoint when we're connected.
private EndPoint? _pendingConnectRightEndPoint;

// These are constants initialized by constructor.
private AddressFamily _addressFamily;
Expand Down Expand Up @@ -285,11 +285,8 @@ public EndPoint? LocalEndPoint

if (_nonBlockingConnectInProgress && Poll(0, SelectMode.SelectWrite))
{
// Update the state if we've become connected after a non-blocking connect.
_isConnected = true;
_rightEndPoint ??= _nonBlockingConnectRightEndPoint;
UpdateLocalEndPointOnConnect();
_nonBlockingConnectInProgress = false;
SetToConnected();
}

if (_rightEndPoint == null)
Expand Down Expand Up @@ -332,11 +329,9 @@ public EndPoint? RemoteEndPoint
{
if (_nonBlockingConnectInProgress && Poll(0, SelectMode.SelectWrite))
{
// Update the state if we've become connected after a non-blocking connect.
_isConnected = true;
_rightEndPoint ??= _nonBlockingConnectRightEndPoint;
UpdateLocalEndPointOnConnect();
_nonBlockingConnectInProgress = false;
// Update the state if we've become connected after a non-blocking connect.
SetToConnected();
}

if (_rightEndPoint == null || !_isConnected)
Expand Down Expand Up @@ -439,11 +434,9 @@ public bool Connected

if (_nonBlockingConnectInProgress && Poll(0, SelectMode.SelectWrite))
{
// Update the state if we've become connected after a non-blocking connect.
_isConnected = true;
_rightEndPoint ??= _nonBlockingConnectRightEndPoint;
UpdateLocalEndPointOnConnect();
_nonBlockingConnectInProgress = false;
// Update the state if we've become connected after a non-blocking connect.
SetToConnected();
}

return _isConnected;
Expand Down Expand Up @@ -856,12 +849,8 @@ public void Connect(EndPoint remoteEP)
ValidateForMultiConnect(isMultiEndpoint: false);

Internals.SocketAddress socketAddress = Serialize(ref remoteEP);

if (!Blocking)
{
_nonBlockingConnectRightEndPoint = remoteEP;
_nonBlockingConnectInProgress = true;
}
_pendingConnectRightEndPoint = remoteEP;
_nonBlockingConnectInProgress = !Blocking;

DoConnect(remoteEP, socketAddress);
}
Expand Down Expand Up @@ -2768,13 +2757,11 @@ internal bool ConnectAsync(SocketAsyncEventArgs e, bool userSocket, bool saeaCan
}

e._socketAddress = Serialize(ref endPointSnapshot);
_pendingConnectRightEndPoint = endPointSnapshot;
_nonBlockingConnectInProgress = false;

WildcardBindForConnectIfNecessary(endPointSnapshot.AddressFamily);

// Save the old RightEndPoint and prep new RightEndPoint.
EndPoint? oldEndPoint = _rightEndPoint;
_rightEndPoint ??= endPointSnapshot;

if (SocketsTelemetry.Log.IsEnabled())
{
SocketsTelemetry.Log.ConnectStart(e._socketAddress!);
Expand All @@ -2801,7 +2788,6 @@ internal bool ConnectAsync(SocketAsyncEventArgs e, bool userSocket, bool saeaCan
SocketsTelemetry.Log.AfterConnect(SocketError.NotSocket, ex.Message);
}

_rightEndPoint = oldEndPoint;
_localEndPoint = null;

// Clear in-use flag on event args object.
Expand Down Expand Up @@ -3217,12 +3203,11 @@ private void DoConnect(EndPoint endPointSnapshot, Internals.SocketAddress socket

if (SocketsTelemetry.Log.IsEnabled()) SocketsTelemetry.Log.AfterConnect(SocketError.Success);

// Save a copy of the EndPoint so we can use it for Create().
_rightEndPoint ??= endPointSnapshot;

if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(this, $"connection to:{endPointSnapshot}");

// Update state and performance counters.
_pendingConnectRightEndPoint = endPointSnapshot;
_nonBlockingConnectInProgress = false;
SetToConnected();
if (NetEventSource.Log.IsEnabled()) NetEventSource.Connected(this, LocalEndPoint, RemoteEndPoint);
}
Expand Down Expand Up @@ -3659,10 +3644,14 @@ internal void SetToConnected()
return;
}

Debug.Assert(_nonBlockingConnectInProgress == false);

// Update the status: this socket was indeed connected at
// some point in time update the perf counter as well.
_isConnected = true;
_isDisconnected = false;
_rightEndPoint ??= _pendingConnectRightEndPoint;
_pendingConnectRightEndPoint = null;
UpdateLocalEndPointOnConnect();
if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(this, "now connected");
}
Expand Down
54 changes: 54 additions & 0 deletions src/libraries/System.Net.Sockets/tests/FunctionalTests/Connect.cs
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,60 @@ public async Task Connect_MultipleIPAddresses_Success(IPAddress listenAt)
}
}

[Fact]
public async Task Connect_DualMode_MultiAddressFamilyConnect_RetrievedEndPoints_Success()
{
if (!SupportsMultiConnect)
return;

int port;
using (SocketTestServer.SocketTestServerFactory(SocketImplementationType.Async, IPAddress.Loopback, out port))
using (Socket client = new Socket(SocketType.Stream, ProtocolType.Tcp))
{
Assert.True(client.DualMode);

Task connectTask = MultiConnectAsync(client, new IPAddress[] { IPAddress.IPv6Loopback, IPAddress.Loopback }, port);
await connectTask;

var localEndPoint = client.LocalEndPoint as IPEndPoint;
Assert.NotNull(localEndPoint);
Assert.Equal(IPAddress.Loopback.MapToIPv6(), localEndPoint.Address);

var remoteEndPoint = client.RemoteEndPoint as IPEndPoint;
Assert.NotNull(remoteEndPoint);
Assert.Equal(IPAddress.Loopback.MapToIPv6(), remoteEndPoint.Address);
}
}

[Fact]
public async Task Connect_DualMode_DnsConnect_RetrievedEndPoints_Success()
{
var localhostAddresses = Dns.GetHostAddresses("localhost");
if (Array.IndexOf(localhostAddresses, IPAddress.Loopback) == -1 ||
Array.IndexOf(localhostAddresses, IPAddress.IPv6Loopback) == -1)
{
return;
}

int port;
using (SocketTestServer.SocketTestServerFactory(SocketImplementationType.Async, IPAddress.Loopback, out port))
using (Socket client = new Socket(SocketType.Stream, ProtocolType.Tcp))
{
Assert.True(client.DualMode);

Task connectTask = ConnectAsync(client, new DnsEndPoint("localhost", port));
await connectTask;

var localEndPoint = client.LocalEndPoint as IPEndPoint;
Assert.NotNull(localEndPoint);
Assert.Equal(IPAddress.Loopback.MapToIPv6(), localEndPoint.Address);

var remoteEndPoint = client.RemoteEndPoint as IPEndPoint;
Assert.NotNull(remoteEndPoint);
Assert.Equal(IPAddress.Loopback.MapToIPv6(), remoteEndPoint.Address);
}
}

[Fact]
public async Task Connect_OnConnectedSocket_Fails()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ public async Task EnsureMethodsAreCallable()
await Assert.ThrowsAsync<InvalidOperationException>(async () => await SocketTaskExtensions.AcceptAsync(s));
await Assert.ThrowsAsync<InvalidOperationException>(async () => await SocketTaskExtensions.AcceptAsync(s, null));

await Assert.ThrowsAsync<InvalidOperationException>(async () => await SocketTaskExtensions.ReceiveFromAsync(s, new ArraySegment<byte>(buffer), SocketFlags.None, badEndPoint));
await Assert.ThrowsAsync<InvalidOperationException>(async () => await SocketTaskExtensions.ReceiveMessageFromAsync(s, new ArraySegment<byte>(buffer), SocketFlags.None, badEndPoint));

await Assert.ThrowsAsync<SocketException>(async () => await SocketTaskExtensions.ConnectAsync(s, badEndPoint));
await Assert.ThrowsAsync<SocketException>(async () => await SocketTaskExtensions.ConnectAsync(s, badEndPoint, CancellationToken.None));
await Assert.ThrowsAsync<SocketException>(async () => await SocketTaskExtensions.ConnectAsync(s, badEndPoint.Address, badEndPoint.Port));
Expand All @@ -35,8 +38,6 @@ public async Task EnsureMethodsAreCallable()
await Assert.ThrowsAsync<SocketException>(async () => await SocketTaskExtensions.ReceiveAsync(s, new ArraySegment<byte>(buffer), SocketFlags.None));
await Assert.ThrowsAsync<SocketException>(async () => await SocketTaskExtensions.ReceiveAsync(s, buffer.AsMemory(), SocketFlags.None));
await Assert.ThrowsAsync<SocketException>(async () => await SocketTaskExtensions.ReceiveAsync(s, new ArraySegment<byte>[] { new ArraySegment<byte>(buffer) }, SocketFlags.None));
await Assert.ThrowsAsync<SocketException>(async () => await SocketTaskExtensions.ReceiveFromAsync(s, new ArraySegment<byte>(buffer), SocketFlags.None, badEndPoint));
await Assert.ThrowsAsync<SocketException>(async () => await SocketTaskExtensions.ReceiveMessageFromAsync(s, new ArraySegment<byte>(buffer), SocketFlags.None, badEndPoint));

await Assert.ThrowsAsync<SocketException>(async () => await SocketTaskExtensions.SendAsync(s, new ArraySegment<byte>(buffer), SocketFlags.None));
await Assert.ThrowsAsync<SocketException>(async () => await SocketTaskExtensions.SendAsync(s, buffer.AsMemory(), SocketFlags.None));
Expand Down

0 comments on commit 0d31ddb

Please sign in to comment.