Skip to content

Commit

Permalink
Specialize by length in single-value SearchValues<string> (#96429)
Browse files Browse the repository at this point in the history
* Specialize by length in single-value SearchValues<string>

* Extra assert

* More comments
  • Loading branch information
MihaZupan authored Jan 3, 2024
1 parent 9459844 commit 7957edc
Show file tree
Hide file tree
Showing 5 changed files with 173 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Runtime.Intrinsics;
using System.Text;

namespace System.Buffers
{
Expand Down Expand Up @@ -61,7 +62,7 @@ public static bool StartsWith<TCaseSensitivity>(ref char matchStart, int lengthR
return false;
}

return TCaseSensitivity.Equals(ref matchStart, candidate);
return TCaseSensitivity.Equals<ValueLength8OrLongerOrUnknown>(ref matchStart, candidate);
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
Expand All @@ -79,13 +80,38 @@ private static bool ScalarEquals<TCaseSensitivity>(ref char matchStart, string c
return true;
}

// Compile-time description of how long the searched value is.
// Implementations are passed as generic type arguments so that comparison code can branch on these properties.
public interface IValueLength
{
// True when the value is known to contain at least 4 characters.
static abstract bool AtLeast4Chars { get; }
// True when the value contains at least 8 characters, or when the caller does not know the value's length.
static abstract bool AtLeast8CharsOrUnknown { get; }
}

/// <summary>
/// <see cref="IValueLength"/> marker for values that are shorter than 4 characters.
/// </summary>
public readonly struct ValueLengthLessThan4 : IValueLength
{
    /// <summary>Always <see langword="false"/>: the value is shorter than 8 characters and its length is known.</summary>
    public static bool AtLeast8CharsOrUnknown { get => false; }

    /// <summary>Always <see langword="false"/>: the value is shorter than 4 characters.</summary>
    public static bool AtLeast4Chars { get => false; }
}

/// <summary>
/// <see cref="IValueLength"/> marker for values that are between 4 and 7 characters long.
/// </summary>
public readonly struct ValueLength4To7 : IValueLength
{
    /// <summary>Always <see langword="false"/>: the value is shorter than 8 characters and its length is known.</summary>
    public static bool AtLeast8CharsOrUnknown { get => false; }

    /// <summary>Always <see langword="true"/>: the value has at least 4 characters.</summary>
    public static bool AtLeast4Chars { get => true; }
}

// "Unknown" is currently only used by Teddy when confirming matches.
public readonly struct ValueLength8OrLongerOrUnknown : IValueLength
{
    // Both properties report true: this marker stands for values of 8 or more characters,
    // and for callers that do not know the value's length.
    public static bool AtLeast8CharsOrUnknown { get => true; }

    public static bool AtLeast4Chars { get => true; }
}

// Describes how input is transformed and compared for one case-sensitivity mode.
// Implementations are passed as generic type arguments to the search code.
public interface ICaseSensitivity
{
// Applies this mode's transformation to a single input character.
static abstract char TransformInput(char input);
// Vector overloads of the same transformation for 128-, 256- and 512-bit inputs.
static abstract Vector128<byte> TransformInput(Vector128<byte> input);
static abstract Vector256<byte> TransformInput(Vector256<byte> input);
static abstract Vector512<byte> TransformInput(Vector512<byte> input);
static abstract bool Equals(ref char matchStart, string candidate);
// Checks whether the text starting at matchStart matches candidate under this mode.
// TValueLength describes candidate's length so that implementations can specialize the comparison.
static abstract bool Equals<TValueLength>(ref char matchStart, string candidate) where TValueLength : struct, IValueLength;
}

// Performs no case transformations.
Expand All @@ -104,8 +130,41 @@ public interface ICaseSensitivity
public static Vector512<byte> TransformInput(Vector512<byte> input) => input;

[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static bool Equals(ref char matchStart, string candidate) =>
ScalarEquals<CaseSensitive>(ref matchStart, candidate);
public static bool Equals<TValueLength>(ref char matchStart, string candidate)
where TValueLength : struct, IValueLength
{
// Case-sensitive comparison of candidate against the text starting at matchStart.
// The comparison strategy is selected by TValueLength.
Debug.Assert(candidate.Length > 1);

// Reinterpret both inputs as bytes; every char contributes 2 bytes to byteLength.
ref byte first = ref Unsafe.As<char, byte>(ref matchStart);
ref byte second = ref Unsafe.As<char, byte>(ref candidate.GetRawStringData());
nuint byteLength = (nuint)(uint)candidate.Length * 2;

if (TValueLength.AtLeast8CharsOrUnknown)
{
// 8 or more characters, or unknown length: defer to the general byte-sequence comparison.
return SpanHelpers.SequenceEqual(ref first, ref second, byteLength);
}

Debug.Assert(matchStart == candidate[0], "This should only be called after the first character has been checked");

if (TValueLength.AtLeast4Chars)
{
// 4 to 7 characters (8 to 14 bytes): compare the first 8 bytes and the last 8 bytes.
// The two reads may overlap, and together they cover the whole value.
// Each subtraction is zero only when its operands are equal, so the OR-ed result is zero only if both reads match.
nuint offset = byteLength - sizeof(ulong);
ulong differentBits = Unsafe.ReadUnaligned<ulong>(ref first) - Unsafe.ReadUnaligned<ulong>(ref second);
differentBits |= Unsafe.ReadUnaligned<ulong>(ref Unsafe.Add(ref first, offset)) - Unsafe.ReadUnaligned<ulong>(ref Unsafe.Add(ref second, offset));
return differentBits == 0;
}
else
{
Debug.Assert(candidate.Length is 2 or 3);

// We know that the candidate is 2 or 3 characters long, and that the first character has already been checked.
// We only have to check that the last 2 characters also match.
nuint offset = byteLength - sizeof(uint);

// A single 4-byte read covers the last 2 characters of each input.
return Unsafe.ReadUnaligned<uint>(ref Unsafe.Add(ref first, offset))
== Unsafe.ReadUnaligned<uint>(ref Unsafe.Add(ref second, offset));
}
}
}

// Transforms inputs to their uppercase variants with the assumption that all input characters are ASCII letters.
Expand All @@ -125,8 +184,38 @@ public static bool Equals(ref char matchStart, string candidate) =>
public static Vector512<byte> TransformInput(Vector512<byte> input) => input & Vector512.Create(unchecked((byte)~0x20));

[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static bool Equals(ref char matchStart, string candidate) =>
ScalarEquals<CaseInsensitiveAsciiLetters>(ref matchStart, candidate);
public static bool Equals<TValueLength>(ref char matchStart, string candidate)
where TValueLength : struct, IValueLength
{
// Case-insensitive comparison of candidate against the text starting at matchStart.
// The candidate is expected to already be uppercase (asserted below), so only the input is masked.
Debug.Assert(candidate.Length > 1);
Debug.Assert(candidate.ToUpperInvariant() == candidate);

if (TValueLength.AtLeast8CharsOrUnknown)
{
// 8 or more characters, or unknown length: defer to the Ascii helper.
return Ascii.EqualsIgnoreCase(ref matchStart, ref candidate.GetRawStringData(), (uint)candidate.Length);
}

// Reinterpret both inputs as bytes; every char contributes 2 bytes to byteLength.
ref byte first = ref Unsafe.As<char, byte>(ref matchStart);
ref byte second = ref Unsafe.As<char, byte>(ref candidate.GetRawStringData());
nuint byteLength = (nuint)(uint)candidate.Length * 2;

if (TValueLength.AtLeast4Chars)
{
// 4 to 7 characters: compare the first 4 chars and the last 4 chars (the reads may overlap).
// CaseMask clears the 0x20 bit of each of the 4 chars read from the input before comparing.
const ulong CaseMask = ~0x20002000200020u;
nuint offset = byteLength - sizeof(ulong);
ulong differentBits = (Unsafe.ReadUnaligned<ulong>(ref first) & CaseMask) - Unsafe.ReadUnaligned<ulong>(ref second);
differentBits |= (Unsafe.ReadUnaligned<ulong>(ref Unsafe.Add(ref first, offset)) & CaseMask) - Unsafe.ReadUnaligned<ulong>(ref Unsafe.Add(ref second, offset));
// Zero only when both masked reads equal the corresponding candidate bytes.
return differentBits == 0;
}
else
{
// Fewer than 4 characters: compare the first 2 chars and the last 2 chars (the reads may overlap).
// CaseMask clears the 0x20 bit of each of the 2 chars read from the input before comparing.
const uint CaseMask = ~0x200020u;
nuint offset = byteLength - sizeof(uint);
uint differentBits = (Unsafe.ReadUnaligned<uint>(ref first) & CaseMask) - Unsafe.ReadUnaligned<uint>(ref second);
differentBits |= (Unsafe.ReadUnaligned<uint>(ref Unsafe.Add(ref first, offset)) & CaseMask) - Unsafe.ReadUnaligned<uint>(ref Unsafe.Add(ref second, offset));
// Zero only when both masked reads equal the corresponding candidate bytes.
return differentBits == 0;
}
}
}

// Transforms inputs to their uppercase variants with the assumption that all input characters are ASCII.
Expand Down Expand Up @@ -170,8 +259,16 @@ public static Vector512<byte> TransformInput(Vector512<byte> input)
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static bool Equals(ref char matchStart, string candidate) =>
ScalarEquals<CaseInsensitiveAscii>(ref matchStart, candidate);
public static bool Equals<TValueLength>(ref char matchStart, string candidate)
    where TValueLength : struct, IValueLength
{
    // Values of 8 or more characters (or of unknown length) are compared by the Ascii helper;
    // anything shorter is handed to ScalarEquals.
    return TValueLength.AtLeast8CharsOrUnknown
        ? Ascii.EqualsIgnoreCase(ref matchStart, ref candidate.GetRawStringData(), (uint)candidate.Length)
        : ScalarEquals<CaseInsensitiveAscii>(ref matchStart, candidate);
}
}

// We can't efficiently map non-ASCII inputs to their Ordinal uppercase variants,
Expand All @@ -184,8 +281,16 @@ public static bool Equals(ref char matchStart, string candidate) =>
public static Vector512<byte> TransformInput(Vector512<byte> input) => throw new UnreachableException();

[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static bool Equals(ref char matchStart, string candidate) =>
Ordinal.EqualsIgnoreCase(ref matchStart, ref candidate.GetRawStringData(), candidate.Length);
public static bool Equals<TValueLength>(ref char matchStart, string candidate)
    where TValueLength : struct, IValueLength
{
    // Values of 8 or more characters (or of unknown length) use Ordinal.EqualsIgnoreCase;
    // anything shorter calls its scalar variant directly.
    return TValueLength.AtLeast8CharsOrUnknown
        ? Ordinal.EqualsIgnoreCase(ref matchStart, ref candidate.GetRawStringData(), candidate.Length)
        : Ordinal.EqualsIgnoreCase_Scalar(ref matchStart, ref candidate.GetRawStringData(), candidate.Length);
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ namespace System.Buffers
// This implementation uses 3 precomputed anchor points when searching.
// This implementation may also be used for length=2 values, in which case two anchors point at the same position.
// Has an O(i * m) worst-case, with the expected time closer to O(n) for most inputs.
internal sealed class SingleStringSearchValuesThreeChars<TCaseSensitivity> : StringSearchValuesBase
internal sealed class SingleStringSearchValuesThreeChars<TValueLength, TCaseSensitivity> : StringSearchValuesBase
where TValueLength : struct, IValueLength
where TCaseSensitivity : struct, ICaseSensitivity
{
private const ushort CaseConversionMask = unchecked((ushort)~0x20);
Expand All @@ -34,6 +35,7 @@ public SingleStringSearchValuesThreeChars(HashSet<string>? uniqueValues, string
{
// We could have more than one entry in 'uniqueValues' if this value is an exact prefix of all the others.
Debug.Assert(value.Length > 1);
Debug.Assert((value.Length >= 8) == TValueLength.AtLeast8CharsOrUnknown);

CharacterFrequencyHelper.GetSingleStringMultiCharacterOffsets(value, IgnoreCase, out int ch2Offset, out int ch3Offset);

Expand Down Expand Up @@ -228,7 +230,7 @@ private int IndexOf(ref char searchSpace, int searchSpaceLength)

// CaseInsensitiveUnicode doesn't support single-character transformations, so we skip checking the first character first.
if ((typeof(TCaseSensitivity) == typeof(CaseInsensitiveUnicode) || TCaseSensitivity.TransformInput(cur) == valueHead) &&
TCaseSensitivity.Equals(ref cur, value))
TCaseSensitivity.Equals<TValueLength>(ref cur, value))
{
return (int)i;
}
Expand Down Expand Up @@ -325,7 +327,11 @@ private bool TryMatch(ref char searchSpaceStart, int searchSpaceLength, ref char

ValidateReadPosition(ref searchSpaceStart, searchSpaceLength, ref matchRef, _value.Length);

if (TCaseSensitivity.Equals(ref matchRef, _value))
// If the value is short (!TValueLength.AtLeast4Chars => 2 or 3 characters), the anchors already represent the whole value.
// With case-sensitive comparisons, we've therefore already confirmed the match, so we can skip doing so here.
// With case-insensitive comparisons, we applied a mask to the input, so while the anchors likely matched, we can't be sure.
if ((typeof(TCaseSensitivity) == typeof(CaseSensitive) && !TValueLength.AtLeast4Chars) ||
TCaseSensitivity.Equals<TValueLength>(ref matchRef, _value))
{
offsetFromStart = (int)((nuint)Unsafe.ByteOffset(ref searchSpaceStart, ref matchRef) / 2);
return true;
Expand Down Expand Up @@ -353,7 +359,11 @@ private bool TryMatch(ref char searchSpaceStart, int searchSpaceLength, ref char

ValidateReadPosition(ref searchSpaceStart, searchSpaceLength, ref matchRef, _value.Length);

if (TCaseSensitivity.Equals(ref matchRef, _value))
// If the value is short (!TValueLength.AtLeast4Chars => 2 or 3 characters), the anchors already represent the whole value.
// With case-sensitive comparisons, we've therefore already confirmed the match, so we can skip doing so here.
// With case-insensitive comparisons, we applied a mask to the input, so while the anchors likely matched, we can't be sure.
if ((typeof(TCaseSensitivity) == typeof(CaseSensitive) && !TValueLength.AtLeast4Chars) ||
TCaseSensitivity.Equals<TValueLength>(ref matchRef, _value))
{
offsetFromStart = (int)((nuint)Unsafe.ByteOffset(ref searchSpaceStart, ref matchRef) / 2);
return true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -309,25 +309,16 @@ private static SearchValues<string> CreateForSingleValue(

if (Vector128.IsHardwareAccelerated && value.Length > 1 && value.Length <= maxLength)
{
if (!ignoreCase)
SearchValues<string>? searchValues = value.Length switch
{
return new SingleStringSearchValuesThreeChars<CaseSensitive>(uniqueValues, value);
}

if (asciiLettersOnly)
{
return new SingleStringSearchValuesThreeChars<CaseInsensitiveAsciiLetters>(uniqueValues, value);
}
< 4 => TryCreateSingleValuesThreeChars<ValueLengthLessThan4>(value, uniqueValues, ignoreCase, allAscii, asciiLettersOnly),
< 8 => TryCreateSingleValuesThreeChars<ValueLength4To7>(value, uniqueValues, ignoreCase, allAscii, asciiLettersOnly),
_ => TryCreateSingleValuesThreeChars<ValueLength8OrLongerOrUnknown>(value, uniqueValues, ignoreCase, allAscii, asciiLettersOnly),
};

if (allAscii)
if (searchValues is not null)
{
return new SingleStringSearchValuesThreeChars<CaseInsensitiveAscii>(uniqueValues, value);
}

// When ignoring casing, all anchor chars we search for must be ASCII.
if (char.IsAscii(value[0]) && value.AsSpan().LastIndexOfAnyInRange((char)0, (char)127) > 0)
{
return new SingleStringSearchValuesThreeChars<CaseInsensitiveUnicode>(uniqueValues, value);
return searchValues;
}
}

Expand All @@ -338,6 +329,39 @@ private static SearchValues<string> CreateForSingleValue(
: new SingleStringSearchValuesFallback<SearchValues.FalseConst>(value, uniqueValues);
}

/// <summary>
/// Creates the <see cref="SingleStringSearchValuesThreeChars{TValueLength, TCaseSensitivity}"/> specialization
/// that matches the requested casing behavior and the character makeup of <paramref name="value"/>.
/// </summary>
/// <returns>
/// The specialized instance, or <see langword="null"/> when casing is ignored and
/// <paramref name="value"/> does not have the ASCII anchor characters that the implementation requires.
/// </returns>
private static SearchValues<string>? TryCreateSingleValuesThreeChars<TValueLength>(
    string value,
    HashSet<string>? uniqueValues,
    bool ignoreCase,
    bool allAscii,
    bool asciiLettersOnly)
    where TValueLength : struct, IValueLength
{
    if (ignoreCase)
    {
        if (asciiLettersOnly)
        {
            return new SingleStringSearchValuesThreeChars<TValueLength, CaseInsensitiveAsciiLetters>(uniqueValues, value);
        }

        if (allAscii)
        {
            return new SingleStringSearchValuesThreeChars<TValueLength, CaseInsensitiveAscii>(uniqueValues, value);
        }

        // SingleStringSearchValuesThreeChars doesn't have logic to handle non-ASCII case conversion, so we require that anchor characters are ASCII.
        // Right now we're always selecting the first character as one of the anchors, and we need at least two.
        bool hasTwoAsciiAnchors = char.IsAscii(value[0]) && value.AsSpan(1).ContainsAnyInRange((char)0, (char)127);

        if (!hasTwoAsciiAnchors)
        {
            return null;
        }

        return new SingleStringSearchValuesThreeChars<TValueLength, CaseInsensitiveUnicode>(uniqueValues, value);
    }

    // Case-sensitive searches place no restrictions on the value's characters.
    return new SingleStringSearchValuesThreeChars<TValueLength, CaseSensitive>(uniqueValues, value);
}

private static void AnalyzeValues(
ReadOnlySpan<string> values,
ref bool ignoreCase,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ internal abstract class StringSearchValuesBase : SearchValues<string>
private readonly HashSet<string>? _uniqueValues;

/// <summary>
/// This exists to allow <see cref="SingleStringSearchValuesThreeChars{TCaseSensitivity}"/> to avoid the HashSet allocation.
/// This exists to allow <see cref="SingleStringSearchValuesThreeChars{TValueLength, TCaseSensitivity}"/> to avoid the HashSet allocation.
/// </summary>
protected bool HasUniqueValues => _uniqueValues is not null;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,9 @@ public static bool EqualsIgnoreCase(ReadOnlySpan<char> left, ReadOnlySpan<char>
=> left.Length == right.Length
&& EqualsIgnoreCase<ushort, ushort, PlainLoader<ushort>>(ref Unsafe.As<char, ushort>(ref MemoryMarshal.GetReference(left)), ref Unsafe.As<char, ushort>(ref MemoryMarshal.GetReference(right)), (uint)right.Length);

/// <summary>
/// Reference-based overload that compares <paramref name="length"/> chars starting at
/// <paramref name="left"/> and <paramref name="right"/>. Unlike the span overload, it performs
/// no length check of its own; the caller supplies the single shared length.
/// </summary>
internal static bool EqualsIgnoreCase(ref char left, ref char right, nuint length)
{
    // Reinterpret the chars as ushort values, which is what the generic implementation operates on.
    ref ushort leftValues = ref Unsafe.As<char, ushort>(ref left);
    ref ushort rightValues = ref Unsafe.As<char, ushort>(ref right);

    return EqualsIgnoreCase<ushort, ushort, PlainLoader<ushort>>(ref leftValues, ref rightValues, length);
}

private static bool EqualsIgnoreCase<TLeft, TRight, TLoader>(ref TLeft left, ref TRight right, nuint length)
where TLeft : unmanaged, INumberBase<TLeft>
where TRight : unmanaged, INumberBase<TRight>
Expand Down

0 comments on commit 7957edc

Please sign in to comment.