Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix broken Select with error list on macOS #104915

Merged
merged 12 commits into from
Jul 28, 2024
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Runtime.InteropServices;

internal static partial class Interop
{
internal static partial class Sys
{
[LibraryImport(Libraries.SystemNative, EntryPoint = "SystemNative_Select")]
internal static unsafe partial Error Select(Span<int> readFDs, int readFDsLength, Span<int> writeFDs, int writeFDsLength, Span<int> checkError, int checkErrorLength, int timeout, int maxFd, out int triggered);
stephentoub marked this conversation as resolved.
Show resolved Hide resolved
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,8 @@
Link="Common\Interop\Unix\System.Native\Interop.ReceiveMessage.cs" />
<Compile Include="$(CommonPath)Interop\Unix\System.Native\Interop.Send.cs"
Link="Common\Interop\Unix\System.Native\Interop.Send.cs" />
<Compile Include="$(CommonPath)Interop\Unix\System.Native\Interop.Select.cs"
Link="Common\Interop\Unix\System.Native\Interop.Select.cs" />
<Compile Include="$(CommonPath)Interop\Unix\System.Native\Interop.SendMessage.cs"
Link="Common\Interop\Unix\System.Native\Interop.SendMessage.cs" />
<Compile Include="$(CommonPath)Interop\Unix\System.Native\Interop.SetSockOpt.cs"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ internal static partial class SocketPal
public static readonly int MaximumAddressSize = Interop.Sys.GetMaximumAddressSize();
private static readonly bool SupportsDualModeIPv4PacketInfo = GetPlatformSupportsDualModeIPv4PacketInfo();

private static readonly bool SelectOverPollIsBroken = OperatingSystem.IsMacOS() || OperatingSystem.IsIOS() || OperatingSystem.IsTvOS() || OperatingSystem.IsMacCatalyst();

// IovStackThreshold matches Linux's UIO_FASTIOV, which is the number of 'struct iovec'
// that get stackalloced in the Linux kernel.
private const int IovStackThreshold = 8;
Expand Down Expand Up @@ -1782,6 +1784,10 @@ public static unsafe SocketError Select(IList? checkRead, IList? checkWrite, ILi
// by the system. Since poll then expects an array of entries, we try to allocate the array on the stack,
// only falling back to allocating it on the heap if it's deemed too big.

if (SelectOverPollIsBroken)
{
return SelectViaSelect(checkRead, checkWrite, checkError, microseconds);
}
wfurt marked this conversation as resolved.
Show resolved Hide resolved
const int StackThreshold = 80; // arbitrary limit to avoid too much space on stack
if (count < StackThreshold)
{
Expand All @@ -1806,6 +1812,115 @@ public static unsafe SocketError Select(IList? checkRead, IList? checkWrite, ILi
}
}

private static SocketError SelectViaSelect(IList? checkRead, IList? checkWrite, IList? checkError, int microseconds)
{
Span<int> readFDs = checkRead?.Count > 20 ? new int[checkRead.Count] : checkRead?.Count > 0 ? stackalloc int[checkRead.Count] : Span<int>.Empty;
Span<int> writeFDs = checkWrite?.Count > 20 ? new int[checkWrite.Count] : checkWrite?.Count > 0 ? stackalloc int[checkWrite.Count] : Span<int>.Empty;;
Span<int> errorFDs = checkError?.Count > 20 ? new int[checkError.Count] : checkError?.Count > 0 ? stackalloc int[checkError.Count] : Span<int>.Empty;
wfurt marked this conversation as resolved.
Show resolved Hide resolved

int refsAdded = 0;
int maxFd = 0;
try
{
AddDesriptors(ref readFDs, checkRead, ref refsAdded, ref maxFd);
AddDesriptors(ref writeFDs, checkWrite, ref refsAdded, ref maxFd);
AddDesriptors(ref errorFDs, checkError, ref refsAdded, ref maxFd);

int triggered = 0;
Interop.Error err = Interop.Sys.Select(readFDs, readFDs.Length, writeFDs, writeFDs.Length, errorFDs, errorFDs.Length, microseconds, maxFd, out triggered);
if (err != Interop.Error.SUCCESS)
{
return GetSocketErrorForErrorCode(err);
}

if (triggered == 0)
{
Socket.SocketListDangerousReleaseRefs(checkRead, ref refsAdded);
Socket.SocketListDangerousReleaseRefs(checkWrite, ref refsAdded);
Socket.SocketListDangerousReleaseRefs(checkError, ref refsAdded);

checkRead?.Clear();
checkWrite?.Clear();
checkError?.Clear();
}
else
{
Socket.SocketListDangerousReleaseRefs(checkRead, ref refsAdded);
Socket.SocketListDangerousReleaseRefs(checkWrite, ref refsAdded);
Socket.SocketListDangerousReleaseRefs(checkError, ref refsAdded);

FilterSelectList(checkRead, readFDs);
FilterSelectList(checkWrite, writeFDs);
FilterSelectList(checkError, errorFDs);
}
wfurt marked this conversation as resolved.
Show resolved Hide resolved
}
finally
{
// This order matches with the AddToPollArray calls
// to release only the handles that were ref'd.
Socket.SocketListDangerousReleaseRefs(checkRead, ref refsAdded);
Socket.SocketListDangerousReleaseRefs(checkWrite, ref refsAdded);
Socket.SocketListDangerousReleaseRefs(checkError, ref refsAdded);
Debug.Assert(refsAdded == 0);
}

return (SocketError)0;
}

private static void AddDesriptors(ref Span<int> buffer, IList? socketList, ref int refsAdded, ref int maxFd)
wfurt marked this conversation as resolved.
Show resolved Hide resolved
{
if (socketList == null || socketList.Count == 0 )
{
return;
}

Debug.Assert(buffer.Length == socketList.Count);
for (int i = 0; i < socketList.Count; i++)
{
Socket? socket = socketList[i] as Socket;
if (socket == null)
{
throw new ArgumentException(SR.Format(SR.net_sockets_select, socket?.GetType().FullName ?? "null", typeof(Socket).FullName), nameof(socketList));
}

if (socket.Handle > maxFd)
{
maxFd = (int)socket.Handle;
}

bool success = false;
socket.InternalSafeHandle.DangerousAddRef(ref success);
buffer[i] = (int)socket.InternalSafeHandle.DangerousGetHandle();

refsAdded++;
}
}

private static void FilterSelectList(IList? socketList, Span<int> results)
{
if (socketList == null)
return;

// The Select API requires leaving in the input lists only those sockets that were ready. As such, we need to loop
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like this is the same large comment that it's FilterPollList later in this file. Maybe just condense this comment to referring folks to see the comment in FilterPollList?

// through each poll event, and for each that wasn't ready, remove the corresponding Socket from its list. Technically
// this is O(n^2), due to removing from the list requiring shifting down all elements after it. However, this doesn't
// happen with the most common cases. If very few sockets were ready, then as we iterate from the end of the list, each
// removal will typically be O(1) rather than O(n). If most sockets were ready, then we only need to remove a few, in
// which case we're only doing a small number of O(n) shifts. It's only for the intermediate case, where a non-trivial
// number of sockets are ready and a non-trivial number of sockets are not ready that we end up paying the most. We could
// avoid these costs by, for example, allocating a side list that we fill with the sockets that should remain, clearing
// the original list, and then populating the original list with the contents of the side list. That of course has its
// own costs, and so for now we do the "simple" thing. This can be changed in the future as needed.

for (int i = socketList.Count - 1; i >= 0; --i)
{
if (results[i] == 0)
{
socketList.RemoveAt(i);
}
}
}

private static unsafe SocketError SelectViaPoll(
IList? checkRead, int checkReadInitialCount,
IList? checkWrite, int checkWriteInitialCount,
Expand All @@ -1824,13 +1939,15 @@ private static unsafe SocketError SelectViaPoll(
AddToPollArray(events, eventsLength, checkRead, ref offset, Interop.PollEvents.POLLIN | Interop.PollEvents.POLLHUP, ref refsAdded);
AddToPollArray(events, eventsLength, checkWrite, ref offset, Interop.PollEvents.POLLOUT, ref refsAdded);
AddToPollArray(events, eventsLength, checkError, ref offset, Interop.PollEvents.POLLPRI, ref refsAdded);
Debug.Assert(offset == eventsLength, $"Invalid adds. offset={offset}, eventsLength={eventsLength}.");
Debug.Assert(refsAdded == eventsLength, $"Invalid ref adds. refsAdded={refsAdded}, eventsLength={eventsLength}.");

// Console.WriteLine("Fixup ??? {0}", PollNeedsErrorListFixup);
wfurt marked this conversation as resolved.
Show resolved Hide resolved
Debug.Assert(offset <= eventsLength, $"Invalid adds. offset={offset}, eventsLength={eventsLength}.");
Debug.Assert(refsAdded <= eventsLength, $"Invalid ref adds. refsAdded={refsAdded}, eventsLength={eventsLength}.");

// Do the poll
uint triggered = 0;
int milliseconds = microseconds == -1 ? -1 : microseconds / 1000;
Interop.Error err = Interop.Sys.Poll(events, (uint)eventsLength, milliseconds, &triggered);
Interop.Error err = Interop.Sys.Poll(events, (uint)refsAdded, milliseconds, &triggered);
if (err != Interop.Error.SUCCESS)
{
return GetSocketErrorForErrorCode(err);
Expand Down Expand Up @@ -1867,7 +1984,7 @@ private static unsafe SocketError SelectViaPoll(
}
}

private static unsafe void AddToPollArray(Interop.PollEvent* arr, int arrLength, IList? socketList, ref int arrOffset, Interop.PollEvents events, ref int refsAdded)
private static unsafe void AddToPollArray(Interop.PollEvent* arr, int arrLength, IList? socketList, ref int arrOffset, Interop.PollEvents events, ref int refsAdded, int readCount = 0)
wfurt marked this conversation as resolved.
Show resolved Hide resolved
{
if (socketList == null)
return;
Expand All @@ -1887,6 +2004,29 @@ private static unsafe void AddToPollArray(Interop.PollEvent* arr, int arrLength,
bool success = false;
socket.InternalSafeHandle.DangerousAddRef(ref success);
int fd = (int)socket.InternalSafeHandle.DangerousGetHandle();

if (readCount > 0)
{
// some platfoms like macOS do not like if there is duplication between real and error list.
wfurt marked this conversation as resolved.
Show resolved Hide resolved
// To fix that we will search read list and if macthing descriptor exiost we will add events flags
wfurt marked this conversation as resolved.
Show resolved Hide resolved
// instead of adding new entry to error list.
int readIndex = 0;
while (readIndex < readCount)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this turning a linear operation into an N^2 operation?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes. This is where the assumption comes in that this is used from small number of sockets where it does not matter as much .... and the cost is only on platforms that are currently broken. (and use read and error list together)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@GrabYourPitchforks, any concerns?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I put up different implementation that avoids the N^2 problem. Please take another look @stephentoub

{
if (arr[readIndex].FileDescriptor == fd)
{
arr[i].Events |= events;
socket.InternalSafeHandle.DangerousRelease();
break;
}
readIndex++;
}
if (readIndex != readCount)
{
continue;
}
}

arr[arrOffset++] = new Interop.PollEvent { Events = events, FileDescriptor = fd };
refsAdded++;
}
Expand Down
107 changes: 103 additions & 4 deletions src/libraries/System.Net.Sockets/tests/FunctionalTests/SelectTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
using System.Linq;
using System.Threading;
using System.Threading.Tasks;

using Microsoft.DotNet.XUnitExtensions;
using Xunit;
using Xunit.Abstractions;

Expand All @@ -21,7 +21,7 @@ public SelectTest(ITestOutputHelper output)
}

private const int SmallTimeoutMicroseconds = 10 * 1000;
private const int FailTimeoutMicroseconds = 30 * 1000 * 1000;
internal const int FailTimeoutMicroseconds = 30 * 1000 * 1000;

[SkipOnPlatform(TestPlatforms.OSX, "typical OSX install has very low max open file descriptors value")]
[Theory]
Expand Down Expand Up @@ -78,6 +78,82 @@ public void Select_ReadWrite_AllReady(int reads, int writes)
}
}

[Theory]
[InlineData(true)]
[InlineData(false)]
public void Select_ReadError_Success(bool dispose)
{
using Socket listener = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Unspecified);
using Socket sender = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Unspecified);

listener.Bind(new IPEndPoint(IPAddress.Loopback, 0));
listener.Listen(1);
sender.Connect(listener.LocalEndPoint);
using Socket receiver = listener.Accept();

if (dispose)
{
sender.Dispose();
}
else
{
sender.Send(new byte[] { 1 });
}

var readList = new List<Socket> { receiver };
var errorList = new List<Socket> { receiver };
Socket.Select(readList, null, errorList, -1);
if (dispose)
{
Assert.True(readList.Count == 1 || errorList.Count == 1);
}
else
{
Assert.Equal(1, readList.Count);
Assert.Equal(0, errorList.Count);
}
}

[Fact]
public void Select_WriteError_Success()
{
using Socket listener = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Unspecified);
using Socket sender = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Unspecified);

listener.Bind(new IPEndPoint(IPAddress.Loopback, 0));
listener.Listen(1);
sender.Connect(listener.LocalEndPoint);
using Socket receiver = listener.Accept();

var writeList = new List<Socket> { receiver };
var errorList = new List<Socket> { receiver };
Socket.Select(null, writeList, errorList, -1);
Assert.Equal(1, writeList.Count);
Assert.Equal(0, errorList.Count);
}

[Fact]
public void Select_ReadWriteError_Success()
{
using Socket listener = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Unspecified);
using Socket sender = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Unspecified);

listener.Bind(new IPEndPoint(IPAddress.Loopback, 0));
listener.Listen(1);
sender.Connect(listener.LocalEndPoint);
using Socket receiver = listener.Accept();

sender.Send(new byte[] { 1 });
receiver.Poll(FailTimeoutMicroseconds, SelectMode.SelectRead);
var readList = new List<Socket> { receiver };
var writeList = new List<Socket> { receiver };
var errorList = new List<Socket> { receiver };
Socket.Select(readList, writeList, errorList, -1);
Assert.Equal(1, readList.Count);
Assert.Equal(1, writeList.Count);
Assert.Equal(0, errorList.Count);
}

[Theory]
[InlineData(2, 0)]
[InlineData(2, 1)]
Expand Down Expand Up @@ -109,7 +185,6 @@ public void Select_SocketAlreadyClosed_AllSocketsClosableAfterException(int sock
}
}

[SkipOnPlatform(TestPlatforms.OSX, "typical OSX install has very low max open file descriptors value")]
[Fact]
[ActiveIssue("https://github.com/dotnet/runtime/issues/51392", TestPlatforms.iOS | TestPlatforms.tvOS | TestPlatforms.MacCatalyst)]
public void Select_ReadError_NoneReady_ManySockets()
Expand Down Expand Up @@ -245,7 +320,7 @@ public void Poll_ReadReady_LongTimeouts(int microsecondsTimeout)
}
}

private static KeyValuePair<Socket, Socket> CreateConnectedSockets()
internal static KeyValuePair<Socket, Socket> CreateConnectedSockets()
{
using (Socket listener = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp))
{
Expand Down Expand Up @@ -342,5 +417,29 @@ private static void DoAccept(Socket listenSocket, int connectionsToAccept)
}
}
}

[ConditionalFact]
public void Slect_LargeNumber_Succcess()
wfurt marked this conversation as resolved.
Show resolved Hide resolved
{
const int MaxSockets = 1025;
KeyValuePair<Socket, Socket>[] socketPairs;
try
{
// we try to shoot for more socket than FD_SETSIZE (that is typically 1024
wfurt marked this conversation as resolved.
Show resolved Hide resolved
socketPairs = Enumerable.Range(0, MaxSockets).Select(_ => SelectTest.CreateConnectedSockets()).ToArray();
}
catch
{
throw new SkipTestException("Unable to open large count number of socket");
}

var readList = new List<Socket>(socketPairs.Select(p => p.Key).ToArray());

// Try to write and read on last sockets
(Socket reader, Socket writer) = socketPairs[MaxSockets - 1];
writer.Send(new byte[1]);
Socket.Select(readList, null, null, SelectTest.FailTimeoutMicroseconds);
Assert.Equal(1, readList.Count);
}
}
}
1 change: 1 addition & 0 deletions src/native/libs/System.Native/entrypoints.c
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,7 @@ static const Entry s_sysNative[] =
DllImportEntry(SystemNative_GetGroupName)
DllImportEntry(SystemNative_GetUInt64OSThreadId)
DllImportEntry(SystemNative_TryGetUInt32OSThreadId)
DllImportEntry(SystemNative_Select)
};

EXTERN_C const void* SystemResolveDllImport(const char* name);
Expand Down
Loading
Loading