diff --git a/src/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Windows.cs b/src/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Windows.cs index 7998edc9ff40..3dc73574414f 100644 --- a/src/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Windows.cs +++ b/src/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Windows.cs @@ -409,13 +409,26 @@ private unsafe void InnerStartOperationReceiveMessageFrom() // WSAMsg also contains a single WSABuffer describing a control buffer. PinSocketAddressBuffer(); - // Create and pin a WSAMessageBuffer if none already. + // Create a WSAMessageBuffer if none exists yet. if (_wsaMessageBuffer == null) { + Debug.Assert(!_wsaMessageBufferGCHandle.IsAllocated); + Debug.Assert(_ptrWSAMessageBuffer == IntPtr.Zero); _wsaMessageBuffer = new byte[sizeof(Interop.Winsock.WSAMsg)]; + } + + // And ensure the WSAMessageBuffer is appropriately pinned. + if (_ptrWSAMessageBuffer == IntPtr.Zero) + { + Debug.Assert(!_wsaMessageBufferGCHandle.IsAllocated); _wsaMessageBufferGCHandle = GCHandle.Alloc(_wsaMessageBuffer, GCHandleType.Pinned); _ptrWSAMessageBuffer = Marshal.UnsafeAddrOfPinnedArrayElement(_wsaMessageBuffer, 0); } + else + { + Debug.Assert(_wsaMessageBufferGCHandle.IsAllocated); + Debug.Assert(_ptrWSAMessageBuffer == Marshal.UnsafeAddrOfPinnedArrayElement(_wsaMessageBuffer, 0)); + } // Create and pin an appropriately sized control buffer if none already IPAddress ipAddress = (_socketAddress.Family == AddressFamily.InterNetworkV6 ? _socketAddress.GetIPAddress() : null); diff --git a/src/System.Net.Sockets/tests/FunctionalTests/ReceiveMessageFromAsync.cs b/src/System.Net.Sockets/tests/FunctionalTests/ReceiveMessageFromAsync.cs index 248569ebab9e..64b6d193639e 100644 --- a/src/System.Net.Sockets/tests/FunctionalTests/ReceiveMessageFromAsync.cs +++ b/src/System.Net.Sockets/tests/FunctionalTests/ReceiveMessageFromAsync.cs @@ -10,96 +10,52 @@ namespace System.Net.Sockets.Tests { public class ReceiveMessageFromAsync { - public void OnCompleted(object sender, SocketAsyncEventArgs args) - { - EventWaitHandle handle = (EventWaitHandle)args.UserToken; - handle.Set(); - } - [OuterLoop] // TODO: Issue #11345 - [Fact] - public void Success_IPv4() + [Theory] + [InlineData(false, false)] + [InlineData(false, true)] + [InlineData(true, false)] + [InlineData(true, true)] + public void ReceiveSentMessages_SocketAsyncEventArgs_Success(bool ipv4, bool changeReceiveBufferEachCall) { - ManualResetEvent completed = new ManualResetEvent(false); + const int DataLength = 1024; + AddressFamily family = ipv4 ? AddressFamily.InterNetwork : AddressFamily.InterNetworkV6; + IPAddress loopback = ipv4 ? IPAddress.Loopback : IPAddress.IPv6Loopback; - if (Socket.OSSupportsIPv4) + var completed = new ManualResetEventSlim(false); + using (var sender = new Socket(family, SocketType.Dgram, ProtocolType.Udp)) + using (var receiver = new Socket(family, SocketType.Dgram, ProtocolType.Udp)) { - using (Socket receiver = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp)) - { - int port = receiver.BindToAnonymousPort(IPAddress.Loopback); - receiver.SetSocketOption(SocketOptionLevel.IP, SocketOptionName.PacketInformation, true); + sender.Bind(new IPEndPoint(loopback, 0)); + receiver.SetSocketOption(ipv4 ? SocketOptionLevel.IP : SocketOptionLevel.IPv6, SocketOptionName.PacketInformation, true); + int port = receiver.BindToAnonymousPort(loopback); - Socket sender = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp); - sender.Bind(new IPEndPoint(IPAddress.Loopback, 0)); + var args = new SocketAsyncEventArgs() { RemoteEndPoint = new IPEndPoint(ipv4 ? IPAddress.Any : IPAddress.IPv6Any, 0) }; + args.Completed += (s,e) => completed.Set(); + args.SetBuffer(new byte[DataLength], 0, DataLength); + for (int iters = 0; iters < 5; iters++) + { for (int i = 0; i < TestSettings.UDPRedundancy; i++) { - sender.SendTo(new byte[1024], new IPEndPoint(IPAddress.Loopback, port)); + sender.SendTo(new byte[DataLength], new IPEndPoint(loopback, port)); } - SocketAsyncEventArgs args = new SocketAsyncEventArgs(); - args.RemoteEndPoint = new IPEndPoint(IPAddress.Any, 0); - args.SetBuffer(new byte[1024], 0, 1024); - args.Completed += OnCompleted; - args.UserToken = completed; - - bool pending = receiver.ReceiveMessageFromAsync(args); - if (!pending) - { - OnCompleted(null, args); - } - - Assert.True(completed.WaitOne(TestSettings.PassingTestTimeout), "Timeout while waiting for connection"); - - Assert.Equal(1024, args.BytesTransferred); - Assert.Equal(sender.LocalEndPoint, args.RemoteEndPoint); - Assert.Equal(((IPEndPoint)sender.LocalEndPoint).Address, args.ReceiveMessageFromPacketInfo.Address); - - sender.Dispose(); - } - } - } - - [OuterLoop] // TODO: Issue #11345 - [Fact] - public void Success_IPv6() - { - ManualResetEvent completed = new ManualResetEvent(false); - - if (Socket.OSSupportsIPv6) - { - using (Socket receiver = new Socket(AddressFamily.InterNetworkV6, SocketType.Dgram, ProtocolType.Udp)) - { - int port = receiver.BindToAnonymousPort(IPAddress.IPv6Loopback); - receiver.SetSocketOption(SocketOptionLevel.IPv6, SocketOptionName.PacketInformation, true); - - Socket sender = new Socket(AddressFamily.InterNetworkV6, SocketType.Dgram, ProtocolType.Udp); - sender.Bind(new IPEndPoint(IPAddress.IPv6Loopback, 0)); - - for (int i = 0; i < TestSettings.UDPRedundancy; i++) + if (changeReceiveBufferEachCall) { - sender.SendTo(new byte[1024], new IPEndPoint(IPAddress.IPv6Loopback, port)); + args.SetBuffer(new byte[DataLength], 0, DataLength); } - SocketAsyncEventArgs args = new SocketAsyncEventArgs(); - args.RemoteEndPoint = new IPEndPoint(IPAddress.IPv6Any, 0); - args.SetBuffer(new byte[1024], 0, 1024); - args.Completed += OnCompleted; - args.UserToken = completed; - - bool pending = receiver.ReceiveMessageFromAsync(args); - if (!pending) + if (!receiver.ReceiveMessageFromAsync(args)) { - OnCompleted(null, args); + completed.Set(); } + Assert.True(completed.Wait(TestSettings.PassingTestTimeout), "Timeout while waiting for connection"); + completed.Reset(); - Assert.True(completed.WaitOne(TestSettings.PassingTestTimeout), "Timeout while waiting for connection"); - - Assert.Equal(1024, args.BytesTransferred); + Assert.Equal(DataLength, args.BytesTransferred); Assert.Equal(sender.LocalEndPoint, args.RemoteEndPoint); Assert.Equal(((IPEndPoint)sender.LocalEndPoint).Address, args.ReceiveMessageFromPacketInfo.Address); - - sender.Dispose(); } } } @@ -108,30 +64,33 @@ public void Success_IPv6() [Theory] [InlineData(false)] [InlineData(true)] - public async Task Task_Success(bool ipv4) + public async Task ReceiveSentMessages_Tasks_Success(bool ipv4) { + const int DataLength = 1024; AddressFamily family = ipv4 ? AddressFamily.InterNetwork : AddressFamily.InterNetworkV6; IPAddress loopback = ipv4 ? IPAddress.Loopback : IPAddress.IPv6Loopback; - using (Socket receiver = new Socket(family, SocketType.Dgram, ProtocolType.Udp)) - using (Socket sender = new Socket(family, SocketType.Dgram, ProtocolType.Udp)) + using (var receiver = new Socket(family, SocketType.Dgram, ProtocolType.Udp)) + using (var sender = new Socket(family, SocketType.Dgram, ProtocolType.Udp)) { - int port = receiver.BindToAnonymousPort(loopback); - receiver.SetSocketOption(ipv4 ? SocketOptionLevel.IP : SocketOptionLevel.IPv6, SocketOptionName.PacketInformation, true); - sender.Bind(new IPEndPoint(loopback, 0)); + receiver.SetSocketOption(ipv4 ? SocketOptionLevel.IP : SocketOptionLevel.IPv6, SocketOptionName.PacketInformation, true); + int port = receiver.BindToAnonymousPort(loopback); - for (int i = 0; i < TestSettings.UDPRedundancy; i++) + for (int iters = 0; iters < 5; iters++) { - sender.SendTo(new byte[1024], new IPEndPoint(loopback, port)); - } + for (int i = 0; i < TestSettings.UDPRedundancy; i++) + { + sender.SendTo(new byte[DataLength], new IPEndPoint(loopback, port)); + } - SocketReceiveMessageFromResult result = await receiver.ReceiveMessageFromAsync( - new ArraySegment(new byte[1024], 0, 1024), SocketFlags.None, - new IPEndPoint(ipv4 ? IPAddress.Any : IPAddress.IPv6Any, 0)); - Assert.Equal(1024, result.ReceivedBytes); - Assert.Equal(sender.LocalEndPoint, result.RemoteEndPoint); - Assert.Equal(((IPEndPoint)sender.LocalEndPoint).Address, result.PacketInformation.Address); + SocketReceiveMessageFromResult result = await receiver.ReceiveMessageFromAsync( + new ArraySegment(new byte[DataLength], 0, DataLength), SocketFlags.None, + new IPEndPoint(ipv4 ? IPAddress.Any : IPAddress.IPv6Any, 0)); + Assert.Equal(DataLength, result.ReceivedBytes); + Assert.Equal(sender.LocalEndPoint, result.RemoteEndPoint); + Assert.Equal(((IPEndPoint)sender.LocalEndPoint).Address, result.PacketInformation.Address); + } } } }