Skip to content
This repository has been archived by the owner on Jan 23, 2023. It is now read-only.
/ corefx Public archive

Fix lifetime handling of ReceiveMessageFromAsync buffer on Windows #22012

Merged
merged 3 commits into from
Jul 10, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -409,13 +409,26 @@ private unsafe void InnerStartOperationReceiveMessageFrom()
// WSAMsg also contains a single WSABuffer describing a control buffer.
PinSocketAddressBuffer();

// Create and pin a WSAMessageBuffer if none already.
// Create a WSAMessageBuffer if none exists yet.
if (_wsaMessageBuffer == null)
{
Debug.Assert(!_wsaMessageBufferGCHandle.IsAllocated);
Debug.Assert(_ptrWSAMessageBuffer == IntPtr.Zero);
_wsaMessageBuffer = new byte[sizeof(Interop.Winsock.WSAMsg)];
}

// And ensure the WSAMessageBuffer is appropriately pinned.
if (_ptrWSAMessageBuffer == IntPtr.Zero)
{
Debug.Assert(!_wsaMessageBufferGCHandle.IsAllocated);
_wsaMessageBufferGCHandle = GCHandle.Alloc(_wsaMessageBuffer, GCHandleType.Pinned);
_ptrWSAMessageBuffer = Marshal.UnsafeAddrOfPinnedArrayElement(_wsaMessageBuffer, 0);
}
else
{
Debug.Assert(_wsaMessageBufferGCHandle.IsAllocated);
Debug.Assert(_ptrWSAMessageBuffer == Marshal.UnsafeAddrOfPinnedArrayElement(_wsaMessageBuffer, 0));
}

// Create and pin an appropriately sized control buffer if none already
IPAddress ipAddress = (_socketAddress.Family == AddressFamily.InterNetworkV6 ? _socketAddress.GetIPAddress() : null);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,96 +10,52 @@ namespace System.Net.Sockets.Tests
{
public class ReceiveMessageFromAsync
{
public void OnCompleted(object sender, SocketAsyncEventArgs args)
{
EventWaitHandle handle = (EventWaitHandle)args.UserToken;
handle.Set();
}

[OuterLoop] // TODO: Issue #11345
[Fact]
public void Success_IPv4()
[Theory]
[InlineData(false, false)]
[InlineData(false, true)]
[InlineData(true, false)]
[InlineData(true, true)]
public void ReceiveSentMessages_SocketAsyncEventArgs_Success(bool ipv4, bool changeReceiveBufferEachCall)
{
ManualResetEvent completed = new ManualResetEvent(false);
const int DataLength = 1024;
AddressFamily family = ipv4 ? AddressFamily.InterNetwork : AddressFamily.InterNetworkV6;
IPAddress loopback = ipv4 ? IPAddress.Loopback : IPAddress.IPv6Loopback;

if (Socket.OSSupportsIPv4)
var completed = new ManualResetEventSlim(false);
using (var sender = new Socket(family, SocketType.Dgram, ProtocolType.Udp))
using (var receiver = new Socket(family, SocketType.Dgram, ProtocolType.Udp))
{
using (Socket receiver = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp))
{
int port = receiver.BindToAnonymousPort(IPAddress.Loopback);
receiver.SetSocketOption(SocketOptionLevel.IP, SocketOptionName.PacketInformation, true);
sender.Bind(new IPEndPoint(loopback, 0));
receiver.SetSocketOption(ipv4 ? SocketOptionLevel.IP : SocketOptionLevel.IPv6, SocketOptionName.PacketInformation, true);
int port = receiver.BindToAnonymousPort(loopback);

Socket sender = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp);
sender.Bind(new IPEndPoint(IPAddress.Loopback, 0));
var args = new SocketAsyncEventArgs() { RemoteEndPoint = new IPEndPoint(ipv4 ? IPAddress.Any : IPAddress.IPv6Any, 0) };
args.Completed += (s,e) => completed.Set();
args.SetBuffer(new byte[DataLength], 0, DataLength);

for (int iters = 0; iters < 5; iters++)
{
for (int i = 0; i < TestSettings.UDPRedundancy; i++)
{
sender.SendTo(new byte[1024], new IPEndPoint(IPAddress.Loopback, port));
sender.SendTo(new byte[DataLength], new IPEndPoint(loopback, port));
}

SocketAsyncEventArgs args = new SocketAsyncEventArgs();
args.RemoteEndPoint = new IPEndPoint(IPAddress.Any, 0);
args.SetBuffer(new byte[1024], 0, 1024);
args.Completed += OnCompleted;
args.UserToken = completed;

bool pending = receiver.ReceiveMessageFromAsync(args);
if (!pending)
{
OnCompleted(null, args);
}

Assert.True(completed.WaitOne(TestSettings.PassingTestTimeout), "Timeout while waiting for connection");

Assert.Equal(1024, args.BytesTransferred);
Assert.Equal(sender.LocalEndPoint, args.RemoteEndPoint);
Assert.Equal(((IPEndPoint)sender.LocalEndPoint).Address, args.ReceiveMessageFromPacketInfo.Address);

sender.Dispose();
}
}
}

[OuterLoop] // TODO: Issue #11345
[Fact]
public void Success_IPv6()
{
ManualResetEvent completed = new ManualResetEvent(false);

if (Socket.OSSupportsIPv6)
{
using (Socket receiver = new Socket(AddressFamily.InterNetworkV6, SocketType.Dgram, ProtocolType.Udp))
{
int port = receiver.BindToAnonymousPort(IPAddress.IPv6Loopback);
receiver.SetSocketOption(SocketOptionLevel.IPv6, SocketOptionName.PacketInformation, true);

Socket sender = new Socket(AddressFamily.InterNetworkV6, SocketType.Dgram, ProtocolType.Udp);
sender.Bind(new IPEndPoint(IPAddress.IPv6Loopback, 0));

for (int i = 0; i < TestSettings.UDPRedundancy; i++)
if (changeReceiveBufferEachCall)
{
sender.SendTo(new byte[1024], new IPEndPoint(IPAddress.IPv6Loopback, port));
args.SetBuffer(new byte[DataLength], 0, DataLength);
}

SocketAsyncEventArgs args = new SocketAsyncEventArgs();
args.RemoteEndPoint = new IPEndPoint(IPAddress.IPv6Any, 0);
args.SetBuffer(new byte[1024], 0, 1024);
args.Completed += OnCompleted;
args.UserToken = completed;

bool pending = receiver.ReceiveMessageFromAsync(args);
if (!pending)
if (!receiver.ReceiveMessageFromAsync(args))
{
OnCompleted(null, args);
completed.Set();
}
Assert.True(completed.Wait(TestSettings.PassingTestTimeout), "Timeout while waiting for connection");
completed.Reset();

Assert.True(completed.WaitOne(TestSettings.PassingTestTimeout), "Timeout while waiting for connection");

Assert.Equal(1024, args.BytesTransferred);
Assert.Equal(DataLength, args.BytesTransferred);
Assert.Equal(sender.LocalEndPoint, args.RemoteEndPoint);
Assert.Equal(((IPEndPoint)sender.LocalEndPoint).Address, args.ReceiveMessageFromPacketInfo.Address);

sender.Dispose();
}
}
}
Expand All @@ -108,30 +64,33 @@ public void Success_IPv6()
[Theory]
[InlineData(false)]
[InlineData(true)]
public async Task Task_Success(bool ipv4)
public async Task ReceiveSentMessages_Tasks_Success(bool ipv4)
{
const int DataLength = 1024;
AddressFamily family = ipv4 ? AddressFamily.InterNetwork : AddressFamily.InterNetworkV6;
IPAddress loopback = ipv4 ? IPAddress.Loopback : IPAddress.IPv6Loopback;

using (Socket receiver = new Socket(family, SocketType.Dgram, ProtocolType.Udp))
using (Socket sender = new Socket(family, SocketType.Dgram, ProtocolType.Udp))
using (var receiver = new Socket(family, SocketType.Dgram, ProtocolType.Udp))
using (var sender = new Socket(family, SocketType.Dgram, ProtocolType.Udp))
{
int port = receiver.BindToAnonymousPort(loopback);
receiver.SetSocketOption(ipv4 ? SocketOptionLevel.IP : SocketOptionLevel.IPv6, SocketOptionName.PacketInformation, true);

sender.Bind(new IPEndPoint(loopback, 0));
receiver.SetSocketOption(ipv4 ? SocketOptionLevel.IP : SocketOptionLevel.IPv6, SocketOptionName.PacketInformation, true);
int port = receiver.BindToAnonymousPort(loopback);

for (int i = 0; i < TestSettings.UDPRedundancy; i++)
for (int iters = 0; iters < 5; iters++)
{
sender.SendTo(new byte[1024], new IPEndPoint(loopback, port));
}
for (int i = 0; i < TestSettings.UDPRedundancy; i++)
{
sender.SendTo(new byte[DataLength], new IPEndPoint(loopback, port));
}

SocketReceiveMessageFromResult result = await receiver.ReceiveMessageFromAsync(
new ArraySegment<byte>(new byte[1024], 0, 1024), SocketFlags.None,
new IPEndPoint(ipv4 ? IPAddress.Any : IPAddress.IPv6Any, 0));
Assert.Equal(1024, result.ReceivedBytes);
Assert.Equal(sender.LocalEndPoint, result.RemoteEndPoint);
Assert.Equal(((IPEndPoint)sender.LocalEndPoint).Address, result.PacketInformation.Address);
SocketReceiveMessageFromResult result = await receiver.ReceiveMessageFromAsync(
new ArraySegment<byte>(new byte[DataLength], 0, DataLength), SocketFlags.None,
new IPEndPoint(ipv4 ? IPAddress.Any : IPAddress.IPv6Any, 0));
Assert.Equal(DataLength, result.ReceivedBytes);
Assert.Equal(sender.LocalEndPoint, result.RemoteEndPoint);
Assert.Equal(((IPEndPoint)sender.LocalEndPoint).Address, result.PacketInformation.Address);
}
}
}
}
Expand Down