Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use approved SocketAddress API instead of direct internal access #89841

Merged
merged 4 commits into from
Aug 4, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,6 @@ internal static partial class Interop
internal static partial class Sys
{
[LibraryImport(Libraries.SystemNative, EntryPoint = "SystemNative_Bind")]
internal static unsafe partial Error Bind(SafeHandle socket, ProtocolType socketProtocolType, byte* socketAddress, int socketAddressLen);
internal static partial Error Bind(SafeHandle socket, ProtocolType socketProtocolType, ReadOnlySpan<byte> socketAddress, int socketAddressLen);
wfurt marked this conversation as resolved.
Show resolved Hide resolved
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Runtime.InteropServices;
using System.Net.Sockets;

Expand All @@ -9,9 +10,13 @@ internal static partial class Interop
internal static partial class Winsock
{
[LibraryImport(Interop.Libraries.Ws2_32, SetLastError = true)]
internal static partial SocketError bind(
private static partial SocketError bind(
SafeSocketHandle socketHandle,
byte[] socketAddress,
ReadOnlySpan<byte> socketAddress,
int socketAddressSize);

internal static SocketError bind(
SafeSocketHandle socketHandle,
ReadOnlySpan<byte> socketAddress) => bind(socketHandle, socketAddress, socketAddress.Length);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

namespace System.Net.Sockets
{
internal static class IPEndPointExtensions
internal static partial class IPEndPointExtensions
{
public static IPAddress GetIPAddress(ReadOnlySpan<byte> socketAddressBuffer)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

namespace System.Net.Sockets
{
internal static class IPEndPointExtensions
internal static partial class IPEndPointExtensions
{
public static Internals.SocketAddress Serialize(EndPoint endpoint)
{
Expand Down
50 changes: 25 additions & 25 deletions src/libraries/Common/src/System/Net/SocketAddress.cs
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ class SocketAddress : System.IEquatable<SocketAddress>
internal static readonly int MaxAddressSize = SocketAddressPal.MaxAddressSize;
#pragma warning restore CA1802

internal int InternalSize;
internal byte[] InternalBuffer;
private int _size;
private byte[] _buffer;

private const int MinSize = 2;
private const int DataOffset = 2;
Expand All @@ -39,21 +39,21 @@ public AddressFamily Family
{
get
{
return SocketAddressPal.GetAddressFamily(InternalBuffer);
return SocketAddressPal.GetAddressFamily(_buffer);
}
}

public int Size
{
get
{
return InternalSize;
return _size;
}
set
{
ArgumentOutOfRangeException.ThrowIfGreaterThan(value, InternalBuffer.Length);
ArgumentOutOfRangeException.ThrowIfGreaterThan(value, _buffer.Length);
ArgumentOutOfRangeException.ThrowIfLessThan(value, MinSize);
InternalSize = value;
_size = value;
}
}

Expand All @@ -69,15 +69,15 @@ public byte this[int offset]
{
throw new IndexOutOfRangeException();
}
return InternalBuffer[offset];
return _buffer[offset];
}
set
{
if ((uint)offset >= (uint)Size)
{
throw new IndexOutOfRangeException();
}
InternalBuffer[offset] = value;
_buffer[offset] = value;
}
}

Expand All @@ -97,11 +97,11 @@ public SocketAddress(AddressFamily family, int size)
{
ArgumentOutOfRangeException.ThrowIfLessThan(size, MinSize);

InternalSize = size;
InternalBuffer = new byte[size];
InternalBuffer[0] = (byte)InternalSize;
_size = size;
_buffer = new byte[size];
_buffer[0] = (byte)_size;

SocketAddressPal.SetAddressFamily(InternalBuffer, family);
SocketAddressPal.SetAddressFamily(_buffer, family);
}

internal SocketAddress(IPAddress ipAddress)
Expand All @@ -110,15 +110,15 @@ internal SocketAddress(IPAddress ipAddress)
{

// No Port.
SocketAddressPal.SetPort(InternalBuffer, 0);
SocketAddressPal.SetPort(_buffer, 0);

if (ipAddress.AddressFamily == AddressFamily.InterNetworkV6)
{
Span<byte> addressBytes = stackalloc byte[IPAddressParserStatics.IPv6AddressBytes];
ipAddress.TryWriteBytes(addressBytes, out int bytesWritten);
Debug.Assert(bytesWritten == IPAddressParserStatics.IPv6AddressBytes);

SocketAddressPal.SetIPv6Address(InternalBuffer, addressBytes, (uint)ipAddress.ScopeId);
SocketAddressPal.SetIPv6Address(_buffer, addressBytes, (uint)ipAddress.ScopeId);
}
else
{
Expand All @@ -127,21 +127,21 @@ internal SocketAddress(IPAddress ipAddress)
#pragma warning restore CS0618

Debug.Assert(ipAddress.AddressFamily == AddressFamily.InterNetwork);
SocketAddressPal.SetIPv4Address(InternalBuffer, address);
SocketAddressPal.SetIPv4Address(_buffer, address);
}
}

internal SocketAddress(IPAddress ipaddress, int port)
: this(ipaddress)
{
SocketAddressPal.SetPort(InternalBuffer, unchecked((ushort)port));
SocketAddressPal.SetPort(_buffer, unchecked((ushort)port));
}

internal SocketAddress(AddressFamily addressFamily, ReadOnlySpan<byte> buffer)
{
InternalBuffer = buffer.ToArray();
InternalSize = InternalBuffer.Length;
SocketAddressPal.SetAddressFamily(InternalBuffer, addressFamily);
_buffer = buffer.ToArray();
_size = _buffer.Length;
SocketAddressPal.SetAddressFamily(_buffer, addressFamily);
}

/// <summary>This represents underlying memory that can be passed to native OS calls.</summary>
Expand All @@ -152,7 +152,7 @@ public Memory<byte> Buffer
{
get
{
return new Memory<byte>(InternalBuffer, 0, InternalSize);
return new Memory<byte>(_buffer, 0, _size);
}
}

Expand All @@ -164,14 +164,14 @@ internal IPAddress GetIPAddress()

Span<byte> address = stackalloc byte[IPAddressParserStatics.IPv6AddressBytes];
uint scope;
SocketAddressPal.GetIPv6Address(InternalBuffer, address, out scope);
SocketAddressPal.GetIPv6Address(_buffer, address, out scope);

return new IPAddress(address, (long)scope);
}
else if (Family == AddressFamily.InterNetwork)
{
Debug.Assert(Size >= IPv4AddressSize);
long address = (long)SocketAddressPal.GetIPv4Address(InternalBuffer) & 0x0FFFFFFFF;
long address = (long)SocketAddressPal.GetIPv4Address(_buffer) & 0x0FFFFFFFF;
return new IPAddress(address);
}
else
Expand All @@ -184,7 +184,7 @@ internal IPAddress GetIPAddress()
}
}

internal int GetPort() => (int)SocketAddressPal.GetPort(InternalBuffer);
internal int GetPort() => (int)SocketAddressPal.GetPort(_buffer);

internal IPEndPoint GetIPEndPoint()
{
Expand All @@ -199,7 +199,7 @@ public override bool Equals(object? comparand) =>
public override int GetHashCode()
{
HashCode hash = default;
hash.AddBytes(new ReadOnlySpan<byte>(InternalBuffer, 0, InternalSize));
hash.AddBytes(new ReadOnlySpan<byte>(_buffer, 0, _size));
return hash.ToHashCode();
}

Expand Down Expand Up @@ -234,7 +234,7 @@ public override string ToString()
result[length++] = ':';
result[length++] = '{';

byte[] buffer = InternalBuffer;
byte[] buffer = _buffer;
for (int i = DataOffset; i < Size; i++)
{
if (i > DataOffset)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@
<Compile Include="$(CommonPath)System\Net\DebugSafeHandleMinusOneIsInvalid.cs"
Link="Common\System\Net\DebugSafeHandleMinusOneIsInvalid.cs" />
<!-- System.Net common -->
<Compile Include="$(CommonPath)System\Net\IPEndPointExtensions.cs"
stephentoub marked this conversation as resolved.
Show resolved Hide resolved
Link="Common\System\Net\IPEndPointExtensions.cs" />
<Compile Include="$(CommonPath)System\Net\IPEndPointStatics.cs"
Link="Common\System\Net\IPEndPointStatics.cs" />
<Compile Include="$(CommonPath)System\Net\IPAddressParserStatics.cs"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,17 +71,18 @@ public Socket(SocketInformation socketInformation)
IPEndPoint ep = new IPEndPoint(tempAddress, 0);

Internals.SocketAddress socketAddress = IPEndPointExtensions.Serialize(ep);
int size = socketAddress.Size;
wfurt marked this conversation as resolved.
Show resolved Hide resolved
unsafe
{
fixed (byte* bufferPtr = socketAddress.InternalBuffer)
fixed (int* sizePtr = &socketAddress.InternalSize)
fixed (byte* bufferPtr = socketAddress.Buffer.Span)
{
errorCode = SocketPal.GetSockName(_handle, bufferPtr, sizePtr);
errorCode = SocketPal.GetSockName(_handle, bufferPtr, &size);
}
}

if (errorCode == SocketError.Success)
{
socketAddress.Size = size;
_rightEndPoint = ep.Create(socketAddress);
}
else if (errorCode == SocketError.InvalidArgument)
Expand Down
49 changes: 32 additions & 17 deletions src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
using System.Net.Internals;
using System.Runtime.CompilerServices;
using System.Runtime.ExceptionServices;
using System.Runtime.InteropServices;
using System.Runtime.Versioning;
using System.Threading;
using System.Threading.Tasks;
Expand Down Expand Up @@ -300,22 +301,32 @@ public EndPoint? LocalEndPoint

if (_localEndPoint == null)
{
Internals.SocketAddress socketAddress = IPEndPointExtensions.Serialize(_rightEndPoint);
Span<byte> buffer = stackalloc byte[SocketAddress.GetMaximumAddressSize(_addressFamily)];
int size = buffer.Length;
// This may throw ObjectDisposedException.
wfurt marked this conversation as resolved.
Show resolved Hide resolved

unsafe
{
fixed (byte* buffer = socketAddress.InternalBuffer)
fixed (int* bufferSize = &socketAddress.InternalSize)
fixed (byte* ptr = &MemoryMarshal.GetReference(buffer))
{
// This may throw ObjectDisposedException.
SocketError errorCode = SocketPal.GetSockName(_handle, buffer, bufferSize);
SocketError errorCode = SocketPal.GetSockName(_handle, ptr, &size);
if (errorCode != SocketError.Success)
{
UpdateStatusAfterSocketErrorAndThrowException(errorCode);
}
}
}
_localEndPoint = _rightEndPoint.Create(socketAddress);

if (_addressFamily == AddressFamily.InterNetwork || _addressFamily == AddressFamily.InterNetworkV6)
{
_localEndPoint = IPEndPointExtensions.CreateIPEndPoint(buffer.Slice(0, size));
}
else
{
SocketAddress socketAddress = new SocketAddress(_rightEndPoint.AddressFamily, size);
buffer.Slice(0, size).CopyTo(socketAddress.Buffer.Span);
_localEndPoint = _rightEndPoint.Create(socketAddress);
stephentoub marked this conversation as resolved.
Show resolved Hide resolved
}
}

return _localEndPoint;
Expand All @@ -338,25 +349,30 @@ public EndPoint? RemoteEndPoint
return null;
}

Internals.SocketAddress socketAddress =
_addressFamily == AddressFamily.InterNetwork || _addressFamily == AddressFamily.InterNetworkV6 ?
IPEndPointExtensions.Serialize(_rightEndPoint) :
new Internals.SocketAddress(_addressFamily, SocketPal.MaximumAddressSize); // may be different size than _rightEndPoint.

Span<byte> buffer = stackalloc byte[SocketAddress.GetMaximumAddressSize(_addressFamily)];
int size = buffer.Length;
// This may throw ObjectDisposedException.
SocketError errorCode = SocketPal.GetPeerName(
_handle,
socketAddress.InternalBuffer,
ref socketAddress.InternalSize);

buffer,
ref size);
if (errorCode != SocketError.Success)
{
UpdateStatusAfterSocketErrorAndThrowException(errorCode);
}

try
{
_remoteEndPoint = _rightEndPoint.Create(socketAddress);
if (_addressFamily == AddressFamily.InterNetwork || _addressFamily == AddressFamily.InterNetworkV6)
{
_remoteEndPoint = IPEndPointExtensions.CreateIPEndPoint(buffer.Slice(0, size));
}
else
{
SocketAddress socketAddress = new SocketAddress(_rightEndPoint.AddressFamily, size);
buffer.Slice(0, size).CopyTo(socketAddress.Buffer.Span);
_remoteEndPoint = _rightEndPoint.Create(socketAddress);
}
}
catch
{
Expand Down Expand Up @@ -765,8 +781,7 @@ private void DoBind(EndPoint endPointSnapshot, Internals.SocketAddress socketAdd
SocketError errorCode = SocketPal.Bind(
_handle,
_protocolType,
socketAddress.InternalBuffer,
socketAddress.Size);
socketAddress.Buffer.Span);

// Throw an appropriate SocketException if the native call fails.
if (errorCode != SocketError.Success)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,9 @@ internal void LogBuffer(int size)

private SocketError FinishOperationAccept(Internals.SocketAddress remoteSocketAddress)
{
System.Buffer.BlockCopy(_acceptBuffer!, 0, remoteSocketAddress.InternalBuffer, 0, _acceptAddressBufferCount);
new ReadOnlySpan<byte>(_acceptBuffer!, 0, _acceptAddressBufferCount).CopyTo(remoteSocketAddress.Buffer.Span);
wfurt marked this conversation as resolved.
Show resolved Hide resolved
remoteSocketAddress.Size = _acceptAddressBufferCount;

Socket acceptedSocket = _currentSocket!.CreateAcceptSocket(
SocketPal.CreateSocket(_acceptedFileDescriptor),
_currentSocket._rightEndPoint!.Create(remoteSocketAddress));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ internal unsafe SocketError DoOperationConnectEx(Socket socket, SafeSocketHandle
{
bool success = socket.ConnectEx(
handle,
_socketAddress!.InternalBuffer.AsSpan(),
_socketAddress!.Buffer.Span,
(IntPtr)(bufferPtr + _offset),
_count,
out int bytesTransferred,
Expand Down Expand Up @@ -762,7 +762,7 @@ internal unsafe SocketError DoOperationSendToSingleBuffer(SafeSocketHandle handl
1,
out int bytesTransferred,
_socketFlags,
_socketAddress!.InternalBuffer.AsSpan(),
_socketAddress!.Buffer.Span,
overlapped,
IntPtr.Zero);

Expand All @@ -789,7 +789,7 @@ internal unsafe SocketError DoOperationSendToMultiBuffer(SafeSocketHandle handle
_bufferListInternal!.Count,
out int bytesTransferred,
_socketFlags,
_socketAddress!.InternalBuffer.AsSpan(),
_socketAddress!.Buffer.Span,
overlapped,
IntPtr.Zero);

Expand Down Expand Up @@ -1058,10 +1058,11 @@ private unsafe SocketError FinishOperationAccept(Internals.SocketAddress remoteS
out localAddr,
out localAddrLength,
out remoteAddr,
out remoteSocketAddress.InternalSize
out int size
);

Marshal.Copy(remoteAddr, remoteSocketAddress.InternalBuffer, 0, remoteSocketAddress.Size);
new ReadOnlySpan<byte>((void*)remoteAddr, size).CopyTo(remoteSocketAddress.Buffer.Span);
remoteSocketAddress.Size = size;
}

socketError = Interop.Winsock.setsockopt(
Expand Down
Loading
Loading