Skip to content
This repository was archived by the owner on Jan 23, 2023. It is now read-only.

Intrinsicify SpanHelpers.IndexOf(char) #22505

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
336 changes: 273 additions & 63 deletions src/System.Private.CoreLib/shared/System/SpanHelpers.Char.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,17 @@
using System.Diagnostics;
using System.Numerics;
using System.Runtime.CompilerServices;
using System.Runtime.Intrinsics;
using System.Runtime.Intrinsics.X86;

using Internal.Runtime.CompilerServices;

#if BIT64
using nuint = System.UInt64;
using nint = System.Int64;
#else
using nuint = System.UInt32;
using nint = System.Int32;
#endif

namespace System
Expand Down Expand Up @@ -218,93 +221,243 @@ public static unsafe int IndexOf(ref char searchSpace, char value, int length)
{
Debug.Assert(length >= 0);

fixed (char* pChars = &searchSpace)
{
char* pCh = pChars;
char* pEndCh = pCh + length;
nint offset = 0;
nint lengthToExamine = length;

if (Vector.IsHardwareAccelerated && length >= Vector<ushort>.Count * 2)
if (((int)Unsafe.AsPointer(ref searchSpace) & 1) != 0)
{
// Input isn't char aligned, we won't be able to align it to a Vector
}
else if (Sse2.IsSupported)
{
// Avx2 branch also operates on Sse2 sizes, so check is combined.
// Needs to be double length to allow us to align the data first.
if (length >= Vector128<ushort>.Count * 2)
{
// Figure out how many characters to read sequentially until we are vector aligned
// This is equivalent to:
// unaligned = ((int)pCh % Unsafe.SizeOf<Vector<ushort>>()) / elementsPerByte
// length = (Vector<ushort>.Count - unaligned) % Vector<ushort>.Count
const int elementsPerByte = sizeof(ushort) / sizeof(byte);
int unaligned = ((int)pCh & (Unsafe.SizeOf<Vector<ushort>>() - 1)) / elementsPerByte;
length = (Vector<ushort>.Count - unaligned) & (Vector<ushort>.Count - 1);
lengthToExamine = UnalignedCountVector128(ref searchSpace);
}

SequentialScan:
while (length >= 4)
}
else if (Vector.IsHardwareAccelerated)
{
// Needs to be double length to allow us to align the data first.
if (length >= Vector<ushort>.Count * 2)
{
length -= 4;
lengthToExamine = UnalignedCountVector(ref searchSpace);
}
}

if (pCh[0] == value)
goto Found;
if (pCh[1] == value)
goto Found1;
if (pCh[2] == value)
goto Found2;
if (pCh[3] == value)
goto Found3;
SequentialScan:
// In the non-vector case lengthToExamine is the total length.
// In the vector case lengthToExamine first aligns to Vector,
// then in a second pass after the Vector lengths is the
// remaining data that is shorter than a Vector length.
while (lengthToExamine >= 4)
{
ref char current = ref Add(ref searchSpace, offset);

if (value == current)
goto Found;
if (value == Add(ref current, 1))
goto Found1;
if (value == Add(ref current, 2))
goto Found2;
if (value == Add(ref current, 3))
goto Found3;

offset += 4;
lengthToExamine -= 4;
}

pCh += 4;
}
while (lengthToExamine > 0)
{
if (value == Add(ref searchSpace, offset))
goto Found;

while (length > 0)
offset += 1;
lengthToExamine -= 1;
}

// We get past SequentialScan only if IsHardwareAccelerated or an intrinsic .IsSupported is true. However, we still have the redundant check to allow
// the JIT to see that the code is unreachable and eliminate it when the platform is not hardware accelerated.
if (Avx2.IsSupported)
{
if (offset < length)
{
length--;
Debug.Assert(length - offset >= Vector128<ushort>.Count);
if (((nint)Unsafe.AsPointer(ref Unsafe.Add(ref searchSpace, (IntPtr)offset)) & (nint)(Vector256<byte>.Count - 1)) != 0)
{
// Not currently aligned to Vector256 (is aligned to Vector128); this can cause a problem for searches
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the comment could be clarified to explain the logic here.

Basically, under optimal conditions we are not 32-byte aligned and are instead 16-byte aligned. In the non-optimal case, we could have any alignment, but we definitely have at least 16-bytes available to read, so we shouldn't fault.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

However, I'm not sure what (in the non-optimal case) prevents an AV for searches with no upper bound (String.wcslen)....

That is, if someone creates a span over unpinned data and with length: int.MaxValue and then they call IndexOf('\0'), the length could be less than 16, and it could be relocated to just before the end of a page, which could then fault when we do the LoadVector128 below, right?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

However, I'm not sure what (in the non-optimal case) prevents an AV for searches with no upper bound (String.wcslen)....

String.wcslen(char* ptr) is over pointer data so the input is fixed data and the GC can't relocate it. No guarantees are offered for unfixed data and if you are creating a Span that is larger than data you own you're already in trouble; if there is no null terminator then it will rightfully fault.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should likely have a comment explicitly calling that out then. Specifically that this assumes that the length is either correct or that the data is pinned and that you may AV otherwise.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Improved comment

// with no upper bound e.g. String.wcslen. Start with a check on Vector128 to align to Vector256,
// before moving to processing Vector256.

// If the input searchSpan has been fixed or pinned, this ensures we do not fault across memory pages
// while searching for an end of string. Specifically, this assumes that the length is either correct
// or that the data is pinned; otherwise it may cause an AccessViolation from crossing a page boundary into an
// unowned page. If the search is unbounded (e.g. null terminator in wcslen) and the search value is not found,
// again this will likely cause an AccessViolation. However, correctly bounded searches will return -1 rather
// than ever causing an AV.

// If the searchSpan has not been fixed or pinned the GC can relocate it during the execution of this
// method, so the alignment only acts as best endeavour. The GC cost is likely to dominate over
// the misalignment that may occur after; so we default to giving the GC a free hand to relocate and
// it's up to the caller whether they are operating over fixed data.
Vector128<ushort> values = Vector128.Create((ushort)value);
Vector128<ushort> search = LoadVector128(ref searchSpace, offset);

// Same method as below
int matches = Sse2.MoveMask(Sse2.CompareEqual(values, search).AsByte());
if (matches == 0)
{
// Zero flags set so no matches
offset += Vector128<ushort>.Count;
}
else
{
// Find bitflag offset of first match and add to current offset
return (int)(offset + (BitOperations.TrailingZeroCount(matches) / sizeof(char)));
}
}

if (pCh[0] == value)
goto Found;
lengthToExamine = GetCharVector256SpanLength(offset, length);
if (lengthToExamine > 0)
{
Vector256<ushort> values = Vector256.Create((ushort)value);
do
{
Debug.Assert(lengthToExamine >= Vector256<ushort>.Count);

Vector256<ushort> search = LoadVector256(ref searchSpace, offset);
int matches = Avx2.MoveMask(Avx2.CompareEqual(values, search).AsByte());
// Note that MoveMask has converted the equal vector elements into a set of bit flags,
// So the bit position in 'matches' corresponds to the element offset.
if (matches == 0)
{
// Zero flags set so no matches
offset += Vector256<ushort>.Count;
lengthToExamine -= Vector256<ushort>.Count;
continue;
}

// Find bitflag offset of first match and add to current offset,
// flags are in bytes so divide for chars
return (int)(offset + (BitOperations.TrailingZeroCount(matches) / sizeof(char)));
} while (lengthToExamine > 0);
}

pCh++;
lengthToExamine = GetCharVector128SpanLength(offset, length);
if (lengthToExamine > 0)
{
Debug.Assert(lengthToExamine >= Vector128<ushort>.Count);

Vector128<ushort> values = Vector128.Create((ushort)value);
Vector128<ushort> search = LoadVector128(ref searchSpace, offset);

// Same method as above
int matches = Sse2.MoveMask(Sse2.CompareEqual(values, search).AsByte());
if (matches == 0)
{
// Zero flags set so no matches
offset += Vector128<ushort>.Count;
// Don't need to change lengthToExamine here as we don't use its current value again.
}
else
{
// Find bitflag offset of first match and add to current offset,
// flags are in bytes so divide for chars
return (int)(offset + (BitOperations.TrailingZeroCount(matches) / sizeof(char)));
}
}

if (offset < length)
{
lengthToExamine = length - offset;
goto SequentialScan;
}
}
}
else if (Sse2.IsSupported)
{
if (offset < length)
{
Debug.Assert(length - offset >= Vector128<ushort>.Count);

// We get past SequentialScan only if IsHardwareAccelerated is true. However, we still have the redundant check to allow
// the JIT to see that the code is unreachable and eliminate it when the platform does not have hardware accelerated.
if (Vector.IsHardwareAccelerated && pCh < pEndCh)
lengthToExamine = GetCharVector128SpanLength(offset, length);
if (lengthToExamine > 0)
{
Vector128<ushort> values = Vector128.Create((ushort)value);
do
{
Debug.Assert(lengthToExamine >= Vector128<ushort>.Count);

Vector128<ushort> search = LoadVector128(ref searchSpace, offset);

// Same method as above
int matches = Sse2.MoveMask(Sse2.CompareEqual(values, search).AsByte());
if (matches == 0)
{
// Zero flags set so no matches
offset += Vector128<ushort>.Count;
lengthToExamine -= Vector128<ushort>.Count;
continue;
}

// Find bitflag offset of first match and add to current offset,
// flags are in bytes so divide for chars
return (int)(offset + (BitOperations.TrailingZeroCount(matches) / sizeof(char)));
} while (lengthToExamine > 0);
}

if (offset < length)
{
lengthToExamine = length - offset;
goto SequentialScan;
}
}
}
else if (Vector.IsHardwareAccelerated)
{
if (offset < length)
{
// Get the highest multiple of Vector<ushort>.Count that is within the search space.
// That will be how many times we iterate in the loop below.
// This is equivalent to: length = Vector<ushort>.Count * ((int)(pEndCh - pCh) / Vector<ushort>.Count)
length = (int)((pEndCh - pCh) & ~(Vector<ushort>.Count - 1));
Debug.Assert(length - offset >= Vector<ushort>.Count);

// Get comparison Vector
Vector<ushort> vComparison = new Vector<ushort>(value);
lengthToExamine = GetCharVectorSpanLength(offset, length);

while (length > 0)
if (lengthToExamine > 0)
{
// Using Unsafe.Read instead of ReadUnaligned since the search space is pinned and pCh is always vector aligned
Debug.Assert(((int)pCh & (Unsafe.SizeOf<Vector<ushort>>() - 1)) == 0);
Vector<ushort> vMatches = Vector.Equals(vComparison, Unsafe.Read<Vector<ushort>>(pCh));
if (Vector<ushort>.Zero.Equals(vMatches))
Vector<ushort> values = new Vector<ushort>((ushort)value);
do
{
pCh += Vector<ushort>.Count;
length -= Vector<ushort>.Count;
continue;
}
// Find offset of first match
return (int)(pCh - pChars) + LocateFirstFoundChar(vMatches);
Debug.Assert(lengthToExamine >= Vector<ushort>.Count);

var matches = Vector.Equals(values, LoadVector(ref searchSpace, offset));
if (Vector<ushort>.Zero.Equals(matches))
{
offset += Vector<ushort>.Count;
lengthToExamine -= Vector<ushort>.Count;
continue;
}

// Find offset of first match
return (int)(offset + LocateFirstFoundChar(matches));
} while (lengthToExamine > 0);
}

if (pCh < pEndCh)
if (offset < length)
{
length = (int)(pEndCh - pCh);
lengthToExamine = length - offset;
goto SequentialScan;
}
}

return -1;
Found3:
pCh++;
Found2:
pCh++;
Found1:
pCh++;
Found:
return (int)(pCh - pChars);
}
return -1;
Found3:
return (int)(offset + 3);
Found2:
return (int)(offset + 2);
Found1:
return (int)(offset + 1);
Found:
return (int)(offset);
}

[MethodImpl(MethodImplOptions.AggressiveOptimization)]
Expand Down Expand Up @@ -876,5 +1029,62 @@ private static int LocateLastFoundChar(ulong match)
{
return 3 - (BitOperations.LeadingZeroCount(match) >> 4);
}

/// <summary>
/// Returns a reference to the <see cref="char"/> located <paramref name="elementOffset"/>
/// elements after <paramref name="source"/>.
/// </summary>
/// <param name="source">Reference to the starting element.</param>
/// <param name="elementOffset">Offset, counted in chars (not bytes), to add to <paramref name="source"/>.</param>
/// <returns>A reference to the element at the requested offset; no bounds checking is performed here.</returns>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static ref char Add(ref char source, nint elementOffset)
{
    // Unsafe.Add takes an IntPtr element offset, so convert the native-sized integer explicitly.
    return ref Unsafe.Add(ref source, (IntPtr)elementOffset);
}

/// <summary>
/// Reads a <see cref="Vector{T}"/> of <see cref="ushort"/> starting <paramref name="offset"/> chars
/// after <paramref name="start"/>, without requiring the address to be vector aligned.
/// </summary>
/// <param name="start">Reference to the first char of the search space.</param>
/// <param name="offset">Offset, in chars, of the first element to load.</param>
/// <returns>The vector read from that location.</returns>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static unsafe Vector<ushort> LoadVector(ref char start, nint offset)
{
    ref char first = ref Unsafe.Add(ref start, (IntPtr)offset);
    // Reinterpret as bytes so the unaligned read overload can be used.
    ref byte raw = ref Unsafe.As<char, byte>(ref first);
    return Unsafe.ReadUnaligned<Vector<ushort>>(ref raw);
}

/// <summary>
/// Reads a <see cref="Vector128{T}"/> of <see cref="ushort"/> starting <paramref name="offset"/> chars
/// after <paramref name="start"/>, without requiring the address to be 16-byte aligned.
/// </summary>
/// <param name="start">Reference to the first char of the search space.</param>
/// <param name="offset">Offset, in chars, of the first element to load.</param>
/// <returns>The 128-bit vector read from that location.</returns>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static unsafe Vector128<ushort> LoadVector128(ref char start, nint offset)
{
    ref char first = ref Unsafe.Add(ref start, (IntPtr)offset);
    // Reinterpret as bytes so the unaligned read overload can be used.
    ref byte raw = ref Unsafe.As<char, byte>(ref first);
    return Unsafe.ReadUnaligned<Vector128<ushort>>(ref raw);
}

/// <summary>
/// Reads a <see cref="Vector256{T}"/> of <see cref="ushort"/> starting <paramref name="offset"/> chars
/// after <paramref name="start"/>, without requiring the address to be 32-byte aligned.
/// </summary>
/// <param name="start">Reference to the first char of the search space.</param>
/// <param name="offset">Offset, in chars, of the first element to load.</param>
/// <returns>The 256-bit vector read from that location.</returns>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static unsafe Vector256<ushort> LoadVector256(ref char start, nint offset)
{
    ref char first = ref Unsafe.Add(ref start, (IntPtr)offset);
    // Reinterpret as bytes so the unaligned read overload can be used.
    ref byte raw = ref Unsafe.As<char, byte>(ref first);
    return Unsafe.ReadUnaligned<Vector256<ushort>>(ref raw);
}

/// <summary>
/// Reads a native-sized unsigned integer from the bytes starting <paramref name="offset"/> chars
/// after <paramref name="start"/>, without requiring the address to be aligned.
/// </summary>
/// <param name="start">Reference to the first char of the search space.</param>
/// <param name="offset">Offset, in chars, of the first element to read.</param>
/// <returns>The <see cref="UIntPtr"/> value read from that location.</returns>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static unsafe UIntPtr LoadUIntPtr(ref char start, nint offset)
{
    ref char first = ref Unsafe.Add(ref start, (IntPtr)offset);
    // Reinterpret as bytes so the unaligned read overload can be used.
    ref byte raw = ref Unsafe.As<char, byte>(ref first);
    return Unsafe.ReadUnaligned<UIntPtr>(ref raw);
}

/// <summary>
/// Computes how many of the chars remaining between <paramref name="offset"/> and
/// <paramref name="length"/> can be processed in whole <see cref="Vector{T}"/> of
/// <see cref="ushort"/> steps.
/// </summary>
/// <param name="offset">Current position in the search space, in chars.</param>
/// <param name="length">Total length of the search space, in chars.</param>
/// <returns>
/// <c>length - offset</c> with its low bits cleared by the mask <c>~(Vector&lt;ushort&gt;.Count - 1)</c>,
/// i.e. rounded down to a multiple of the vector element count.
/// </returns>
// No pointer operations are performed here, so the method is not marked 'unsafe'
// (consistent with GetCharVector256SpanLength).
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static nint GetCharVectorSpanLength(nint offset, nint length)
    => ((length - offset) & ~(Vector<ushort>.Count - 1));

/// <summary>
/// Computes how many of the chars remaining between <paramref name="offset"/> and
/// <paramref name="length"/> can be processed in whole <see cref="Vector128{T}"/> of
/// <see cref="ushort"/> steps.
/// </summary>
/// <param name="offset">Current position in the search space, in chars.</param>
/// <param name="length">Total length of the search space, in chars.</param>
/// <returns>
/// <c>length - offset</c> with its low bits cleared by the mask <c>~(Vector128&lt;ushort&gt;.Count - 1)</c>,
/// i.e. rounded down to a multiple of the vector element count.
/// </returns>
// No pointer operations are performed here, so the method is not marked 'unsafe'
// (consistent with GetCharVector256SpanLength).
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static nint GetCharVector128SpanLength(nint offset, nint length)
    => ((length - offset) & ~(Vector128<ushort>.Count - 1));

/// <summary>
/// Computes how many of the chars remaining between <paramref name="offset"/> and
/// <paramref name="length"/> can be processed in whole <see cref="Vector256{T}"/> of
/// <see cref="ushort"/> steps.
/// </summary>
/// <param name="offset">Current position in the search space, in chars.</param>
/// <param name="length">Total length of the search space, in chars.</param>
/// <returns>
/// <c>length - offset</c> with its low bits cleared by the mask <c>~(Vector256&lt;ushort&gt;.Count - 1)</c>,
/// i.e. rounded down to a multiple of the vector element count.
/// </returns>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static nint GetCharVector256SpanLength(nint offset, nint length)
{
    nint remaining = length - offset;
    return remaining & ~(Vector256<ushort>.Count - 1);
}

/// <summary>
/// Computes a char count from the current address of <paramref name="searchSpace"/>: the negated
/// address, divided by the size of a char, masked to <c>Vector&lt;ushort&gt;.Count - 1</c>.
/// For a char-aligned address this is the number of chars to step over before reaching an address
/// that is a multiple of the <see cref="Vector{T}"/> size.
/// </summary>
/// <param name="searchSpace">Reference to the first char of the search space.</param>
/// <returns>A value in the range <c>[0, Vector&lt;ushort&gt;.Count - 1]</c>.</returns>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static unsafe nint UnalignedCountVector(ref char searchSpace)
{
    // sizeof(ushort) / sizeof(byte): the number of bytes occupied by one element.
    const int BytesPerChar = sizeof(ushort) / sizeof(byte);

    // The address is only a snapshot: the reference is not pinned, so the GC may relocate the data
    // afterwards and the computed alignment is best effort. If a GC does occur and alignment is lost,
    // the GC cost will outweigh any gains from alignment, so pinning just to keep it is not worthwhile.
    int address = (int)Unsafe.AsPointer(ref searchSpace);

    // Negating the address and masking yields the distance (in chars) up to the next vector boundary.
    int negatedCharAddress = -address / BytesPerChar;
    return (nint)(uint)negatedCharAddress & (Vector<ushort>.Count - 1);
}

/// <summary>
/// Computes a char count from the current address of <paramref name="searchSpace"/>: the negated
/// address, divided by the size of a char, masked to <c>Vector128&lt;ushort&gt;.Count - 1</c>.
/// For a char-aligned address this is the number of chars to step over before reaching an address
/// that is a multiple of the <see cref="Vector128{T}"/> size.
/// </summary>
/// <param name="searchSpace">Reference to the first char of the search space.</param>
/// <returns>A value in the range <c>[0, Vector128&lt;ushort&gt;.Count - 1]</c>.</returns>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static unsafe nint UnalignedCountVector128(ref char searchSpace)
{
    // sizeof(ushort) / sizeof(byte): the number of bytes occupied by one element.
    const int BytesPerChar = sizeof(ushort) / sizeof(byte);

    // The address is only a snapshot: the reference is not pinned, so the GC may relocate the data
    // afterwards and the computed alignment is best effort. If a GC does occur and alignment is lost,
    // the GC cost will outweigh any gains from alignment, so pinning just to keep it is not worthwhile.
    int address = (int)Unsafe.AsPointer(ref searchSpace);

    // Negating the address and masking yields the distance (in chars) up to the next 16-byte boundary.
    int negatedCharAddress = -address / BytesPerChar;
    return (nint)(uint)negatedCharAddress & (Vector128<ushort>.Count - 1);
}
}
}
Loading