Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Specialize by length in single-value SearchValues<string> #96429

Merged
merged 3 commits into from
Jan 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Runtime.Intrinsics;
using System.Text;

namespace System.Buffers
{
Expand Down Expand Up @@ -61,7 +62,7 @@ public static bool StartsWith<TCaseSensitivity>(ref char matchStart, int lengthR
return false;
}

return TCaseSensitivity.Equals(ref matchStart, candidate);
return TCaseSensitivity.Equals<ValueLength8OrLongerOrUnknown>(ref matchStart, candidate);
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
Expand All @@ -79,13 +80,38 @@ private static bool ScalarEquals<TCaseSensitivity>(ref char matchStart, string c
return true;
}

public interface IValueLength
{
static abstract bool AtLeast4Chars { get; }
static abstract bool AtLeast8CharsOrUnknown { get; }
}

public readonly struct ValueLengthLessThan4 : IValueLength
{
public static bool AtLeast4Chars => false;
public static bool AtLeast8CharsOrUnknown => false;
}

public readonly struct ValueLength4To7 : IValueLength
{
public static bool AtLeast4Chars => true;
public static bool AtLeast8CharsOrUnknown => false;
}

// "Unknown" is currently only used by Teddy when confirming matches.
public readonly struct ValueLength8OrLongerOrUnknown : IValueLength
{
public static bool AtLeast4Chars => true;
public static bool AtLeast8CharsOrUnknown => true;
}

public interface ICaseSensitivity
{
static abstract char TransformInput(char input);
static abstract Vector128<byte> TransformInput(Vector128<byte> input);
static abstract Vector256<byte> TransformInput(Vector256<byte> input);
static abstract Vector512<byte> TransformInput(Vector512<byte> input);
static abstract bool Equals(ref char matchStart, string candidate);
static abstract bool Equals<TValueLength>(ref char matchStart, string candidate) where TValueLength : struct, IValueLength;
}

// Performs no case transformations.
Expand All @@ -104,8 +130,41 @@ public interface ICaseSensitivity
public static Vector512<byte> TransformInput(Vector512<byte> input) => input;

[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static bool Equals(ref char matchStart, string candidate) =>
ScalarEquals<CaseSensitive>(ref matchStart, candidate);
public static bool Equals<TValueLength>(ref char matchStart, string candidate)
where TValueLength : struct, IValueLength
{
Debug.Assert(candidate.Length > 1);

ref byte first = ref Unsafe.As<char, byte>(ref matchStart);
ref byte second = ref Unsafe.As<char, byte>(ref candidate.GetRawStringData());
nuint byteLength = (nuint)(uint)candidate.Length * 2;

if (TValueLength.AtLeast8CharsOrUnknown)
{
return SpanHelpers.SequenceEqual(ref first, ref second, byteLength);
}

Debug.Assert(matchStart == candidate[0], "This should only be called after the first character has been checked");

if (TValueLength.AtLeast4Chars)
{
nuint offset = byteLength - sizeof(ulong);
ulong differentBits = Unsafe.ReadUnaligned<ulong>(ref first) - Unsafe.ReadUnaligned<ulong>(ref second);
differentBits |= Unsafe.ReadUnaligned<ulong>(ref Unsafe.Add(ref first, offset)) - Unsafe.ReadUnaligned<ulong>(ref Unsafe.Add(ref second, offset));
return differentBits == 0;
}
else
{
Debug.Assert(candidate.Length is 2 or 3);

// We know that the candidate is 2 or 3 characters long, and that the first character has already been checked.
// We only have to to check the last 2 characters also match.
nuint offset = byteLength - sizeof(uint);

return Unsafe.ReadUnaligned<uint>(ref Unsafe.Add(ref first, offset))
MihaZupan marked this conversation as resolved.
Show resolved Hide resolved
== Unsafe.ReadUnaligned<uint>(ref Unsafe.Add(ref second, offset));
}
}
}

// Transforms inputs to their uppercase variants with the assumption that all input characters are ASCII letters.
Expand All @@ -125,8 +184,38 @@ public static bool Equals(ref char matchStart, string candidate) =>
public static Vector512<byte> TransformInput(Vector512<byte> input) => input & Vector512.Create(unchecked((byte)~0x20));

[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static bool Equals(ref char matchStart, string candidate) =>
ScalarEquals<CaseInsensitiveAsciiLetters>(ref matchStart, candidate);
public static bool Equals<TValueLength>(ref char matchStart, string candidate)
where TValueLength : struct, IValueLength
{
Debug.Assert(candidate.Length > 1);
Debug.Assert(candidate.ToUpperInvariant() == candidate);

if (TValueLength.AtLeast8CharsOrUnknown)
{
return Ascii.EqualsIgnoreCase(ref matchStart, ref candidate.GetRawStringData(), (uint)candidate.Length);
}

ref byte first = ref Unsafe.As<char, byte>(ref matchStart);
ref byte second = ref Unsafe.As<char, byte>(ref candidate.GetRawStringData());
nuint byteLength = (nuint)(uint)candidate.Length * 2;

if (TValueLength.AtLeast4Chars)
{
const ulong CaseMask = ~0x20002000200020u;
nuint offset = byteLength - sizeof(ulong);
ulong differentBits = (Unsafe.ReadUnaligned<ulong>(ref first) & CaseMask) - Unsafe.ReadUnaligned<ulong>(ref second);
differentBits |= (Unsafe.ReadUnaligned<ulong>(ref Unsafe.Add(ref first, offset)) & CaseMask) - Unsafe.ReadUnaligned<ulong>(ref Unsafe.Add(ref second, offset));
return differentBits == 0;
}
else
{
const uint CaseMask = ~0x200020u;
nuint offset = byteLength - sizeof(uint);
uint differentBits = (Unsafe.ReadUnaligned<uint>(ref first) & CaseMask) - Unsafe.ReadUnaligned<uint>(ref second);
differentBits |= (Unsafe.ReadUnaligned<uint>(ref Unsafe.Add(ref first, offset)) & CaseMask) - Unsafe.ReadUnaligned<uint>(ref Unsafe.Add(ref second, offset));
return differentBits == 0;
}
}
}

// Transforms inputs to their uppercase variants with the assumption that all input characters are ASCII.
Expand Down Expand Up @@ -170,8 +259,16 @@ public static Vector512<byte> TransformInput(Vector512<byte> input)
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static bool Equals(ref char matchStart, string candidate) =>
ScalarEquals<CaseInsensitiveAscii>(ref matchStart, candidate);
public static bool Equals<TValueLength>(ref char matchStart, string candidate)
where TValueLength : struct, IValueLength
{
if (TValueLength.AtLeast8CharsOrUnknown)
{
return Ascii.EqualsIgnoreCase(ref matchStart, ref candidate.GetRawStringData(), (uint)candidate.Length);
}

return ScalarEquals<CaseInsensitiveAscii>(ref matchStart, candidate);
}
}

// We can't efficiently map non-ASCII inputs to their Ordinal uppercase variants,
Expand All @@ -184,8 +281,16 @@ public static bool Equals(ref char matchStart, string candidate) =>
public static Vector512<byte> TransformInput(Vector512<byte> input) => throw new UnreachableException();

[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static bool Equals(ref char matchStart, string candidate) =>
Ordinal.EqualsIgnoreCase(ref matchStart, ref candidate.GetRawStringData(), candidate.Length);
public static bool Equals<TValueLength>(ref char matchStart, string candidate)
where TValueLength : struct, IValueLength
{
if (TValueLength.AtLeast8CharsOrUnknown)
{
return Ordinal.EqualsIgnoreCase(ref matchStart, ref candidate.GetRawStringData(), candidate.Length);
}

return Ordinal.EqualsIgnoreCase_Scalar(ref matchStart, ref candidate.GetRawStringData(), candidate.Length);
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ namespace System.Buffers
// This implementation uses 3 precomputed anchor points when searching.
// This implementation may also be used for length=2 values, in which case two anchors point at the same position.
// Has an O(i * m) worst-case, with the expected time closer to O(n) for most inputs.
internal sealed class SingleStringSearchValuesThreeChars<TCaseSensitivity> : StringSearchValuesBase
internal sealed class SingleStringSearchValuesThreeChars<TValueLength, TCaseSensitivity> : StringSearchValuesBase
where TValueLength : struct, IValueLength
where TCaseSensitivity : struct, ICaseSensitivity
{
private const ushort CaseConversionMask = unchecked((ushort)~0x20);
Expand All @@ -34,6 +35,7 @@ public SingleStringSearchValuesThreeChars(HashSet<string>? uniqueValues, string
{
// We could have more than one entry in 'uniqueValues' if this value is an exact prefix of all the others.
Debug.Assert(value.Length > 1);
Debug.Assert((value.Length >= 8) == TValueLength.AtLeast8CharsOrUnknown);

CharacterFrequencyHelper.GetSingleStringMultiCharacterOffsets(value, IgnoreCase, out int ch2Offset, out int ch3Offset);

Expand Down Expand Up @@ -228,7 +230,7 @@ private int IndexOf(ref char searchSpace, int searchSpaceLength)

// CaseInsensitiveUnicode doesn't support single-character transformations, so we skip checking the first character first.
if ((typeof(TCaseSensitivity) == typeof(CaseInsensitiveUnicode) || TCaseSensitivity.TransformInput(cur) == valueHead) &&
TCaseSensitivity.Equals(ref cur, value))
TCaseSensitivity.Equals<TValueLength>(ref cur, value))
{
return (int)i;
}
Expand Down Expand Up @@ -325,7 +327,11 @@ private bool TryMatch(ref char searchSpaceStart, int searchSpaceLength, ref char

ValidateReadPosition(ref searchSpaceStart, searchSpaceLength, ref matchRef, _value.Length);

if (TCaseSensitivity.Equals(ref matchRef, _value))
// If the value is short (!TValueLength.AtLeast4Chars => 2 or 3 characters), the anchors already represent the whole value.
// With case-sensitive comparisons, we've therefore already confirmed the match, so we can skip doing so here.
// With case-insensitive comparisons, we applied a mask to the input, so while the anchors likely matched, we can't be sure.
if ((typeof(TCaseSensitivity) == typeof(CaseSensitive) && !TValueLength.AtLeast4Chars) ||
MihaZupan marked this conversation as resolved.
Show resolved Hide resolved
TCaseSensitivity.Equals<TValueLength>(ref matchRef, _value))
{
offsetFromStart = (int)((nuint)Unsafe.ByteOffset(ref searchSpaceStart, ref matchRef) / 2);
return true;
Expand Down Expand Up @@ -353,7 +359,11 @@ private bool TryMatch(ref char searchSpaceStart, int searchSpaceLength, ref char

ValidateReadPosition(ref searchSpaceStart, searchSpaceLength, ref matchRef, _value.Length);

if (TCaseSensitivity.Equals(ref matchRef, _value))
// If the value is short (!TValueLength.AtLeast4Chars => 2 or 3 characters), the anchors already represent the whole value.
// With case-sensitive comparisons, we've therefore already confirmed the match, so we can skip doing so here.
// With case-insensitive comparisons, we applied a mask to the input, so while the anchors likely matched, we can't be sure.
if ((typeof(TCaseSensitivity) == typeof(CaseSensitive) && !TValueLength.AtLeast4Chars) ||
TCaseSensitivity.Equals<TValueLength>(ref matchRef, _value))
{
offsetFromStart = (int)((nuint)Unsafe.ByteOffset(ref searchSpaceStart, ref matchRef) / 2);
return true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -309,25 +309,16 @@ private static SearchValues<string> CreateForSingleValue(

if (Vector128.IsHardwareAccelerated && value.Length > 1 && value.Length <= maxLength)
{
if (!ignoreCase)
SearchValues<string>? searchValues = value.Length switch
{
return new SingleStringSearchValuesThreeChars<CaseSensitive>(uniqueValues, value);
}

if (asciiLettersOnly)
{
return new SingleStringSearchValuesThreeChars<CaseInsensitiveAsciiLetters>(uniqueValues, value);
}
< 4 => TryCreateSingleValuesThreeChars<ValueLengthLessThan4>(value, uniqueValues, ignoreCase, allAscii, asciiLettersOnly),
< 8 => TryCreateSingleValuesThreeChars<ValueLength4To7>(value, uniqueValues, ignoreCase, allAscii, asciiLettersOnly),
_ => TryCreateSingleValuesThreeChars<ValueLength8OrLongerOrUnknown>(value, uniqueValues, ignoreCase, allAscii, asciiLettersOnly),
stephentoub marked this conversation as resolved.
Show resolved Hide resolved
};

if (allAscii)
if (searchValues is not null)
{
return new SingleStringSearchValuesThreeChars<CaseInsensitiveAscii>(uniqueValues, value);
}

// When ignoring casing, all anchor chars we search for must be ASCII.
if (char.IsAscii(value[0]) && value.AsSpan().LastIndexOfAnyInRange((char)0, (char)127) > 0)
{
return new SingleStringSearchValuesThreeChars<CaseInsensitiveUnicode>(uniqueValues, value);
return searchValues;
}
}

Expand All @@ -338,6 +329,39 @@ private static SearchValues<string> CreateForSingleValue(
: new SingleStringSearchValuesFallback<SearchValues.FalseConst>(value, uniqueValues);
}

private static SearchValues<string>? TryCreateSingleValuesThreeChars<TValueLength>(
string value,
HashSet<string>? uniqueValues,
bool ignoreCase,
bool allAscii,
bool asciiLettersOnly)
where TValueLength : struct, IValueLength
{
if (!ignoreCase)
{
return new SingleStringSearchValuesThreeChars<TValueLength, CaseSensitive>(uniqueValues, value);
}

if (asciiLettersOnly)
{
return new SingleStringSearchValuesThreeChars<TValueLength, CaseInsensitiveAsciiLetters>(uniqueValues, value);
}

if (allAscii)
{
return new SingleStringSearchValuesThreeChars<TValueLength, CaseInsensitiveAscii>(uniqueValues, value);
}

// SingleStringSearchValuesThreeChars doesn't have logic to handle non-ASCII case conversion, so we require that anchor characters are ASCII.
// Right now we're always selecting the first character as one of the anchors, and we need at least two.
if (char.IsAscii(value[0]) && value.AsSpan(1).ContainsAnyInRange((char)0, (char)127))
{
return new SingleStringSearchValuesThreeChars<TValueLength, CaseInsensitiveUnicode>(uniqueValues, value);
}

return null;
}

private static void AnalyzeValues(
ReadOnlySpan<string> values,
ref bool ignoreCase,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ internal abstract class StringSearchValuesBase : SearchValues<string>
private readonly HashSet<string>? _uniqueValues;

/// <summary>
/// This exists to allow <see cref="SingleStringSearchValuesThreeChars{TCaseSensitivity}"/> to avoid the HashSet allocation.
/// This exists to allow <see cref="SingleStringSearchValuesThreeChars{TValueLength, TCaseSensitivity}"/> to avoid the HashSet allocation.
/// </summary>
protected bool HasUniqueValues => _uniqueValues is not null;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,9 @@ public static bool EqualsIgnoreCase(ReadOnlySpan<char> left, ReadOnlySpan<char>
=> left.Length == right.Length
&& EqualsIgnoreCase<ushort, ushort, PlainLoader<ushort>>(ref Unsafe.As<char, ushort>(ref MemoryMarshal.GetReference(left)), ref Unsafe.As<char, ushort>(ref MemoryMarshal.GetReference(right)), (uint)right.Length);

internal static bool EqualsIgnoreCase(ref char left, ref char right, nuint length) =>
EqualsIgnoreCase<ushort, ushort, PlainLoader<ushort>>(ref Unsafe.As<char, ushort>(ref left), ref Unsafe.As<char, ushort>(ref right), length);

private static bool EqualsIgnoreCase<TLeft, TRight, TLoader>(ref TLeft left, ref TRight right, nuint length)
where TLeft : unmanaged, INumberBase<TLeft>
where TRight : unmanaged, INumberBase<TRight>
Expand Down