Skip to content

Commit 7957edc

Browse files
authored
Specialize by length in single-value SearchValues<string> (#96429)
* Specialize by length in single-value SearchValues<string> * Extra assert * More comments
1 parent 9459844 commit 7957edc

File tree

5 files changed

+173
-31
lines changed

5 files changed

+173
-31
lines changed

src/libraries/System.Private.CoreLib/src/System/SearchValues/Strings/Helpers/StringSearchValuesHelper.cs

+115-10
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
using System.Runtime.CompilerServices;
77
using System.Runtime.InteropServices;
88
using System.Runtime.Intrinsics;
9+
using System.Text;
910

1011
namespace System.Buffers
1112
{
@@ -61,7 +62,7 @@ public static bool StartsWith<TCaseSensitivity>(ref char matchStart, int lengthR
6162
return false;
6263
}
6364

64-
return TCaseSensitivity.Equals(ref matchStart, candidate);
65+
return TCaseSensitivity.Equals<ValueLength8OrLongerOrUnknown>(ref matchStart, candidate);
6566
}
6667

6768
[MethodImpl(MethodImplOptions.AggressiveInlining)]
@@ -79,13 +80,38 @@ private static bool ScalarEquals<TCaseSensitivity>(ref char matchStart, string c
7980
return true;
8081
}
8182

83+
public interface IValueLength
84+
{
85+
static abstract bool AtLeast4Chars { get; }
86+
static abstract bool AtLeast8CharsOrUnknown { get; }
87+
}
88+
89+
public readonly struct ValueLengthLessThan4 : IValueLength
90+
{
91+
public static bool AtLeast4Chars => false;
92+
public static bool AtLeast8CharsOrUnknown => false;
93+
}
94+
95+
public readonly struct ValueLength4To7 : IValueLength
96+
{
97+
public static bool AtLeast4Chars => true;
98+
public static bool AtLeast8CharsOrUnknown => false;
99+
}
100+
101+
// "Unknown" is currently only used by Teddy when confirming matches.
102+
public readonly struct ValueLength8OrLongerOrUnknown : IValueLength
103+
{
104+
public static bool AtLeast4Chars => true;
105+
public static bool AtLeast8CharsOrUnknown => true;
106+
}
107+
82108
public interface ICaseSensitivity
83109
{
84110
static abstract char TransformInput(char input);
85111
static abstract Vector128<byte> TransformInput(Vector128<byte> input);
86112
static abstract Vector256<byte> TransformInput(Vector256<byte> input);
87113
static abstract Vector512<byte> TransformInput(Vector512<byte> input);
88-
static abstract bool Equals(ref char matchStart, string candidate);
114+
static abstract bool Equals<TValueLength>(ref char matchStart, string candidate) where TValueLength : struct, IValueLength;
89115
}
90116

91117
// Performs no case transformations.
@@ -104,8 +130,41 @@ public interface ICaseSensitivity
104130
public static Vector512<byte> TransformInput(Vector512<byte> input) => input;
105131

106132
[MethodImpl(MethodImplOptions.AggressiveInlining)]
107-
public static bool Equals(ref char matchStart, string candidate) =>
108-
ScalarEquals<CaseSensitive>(ref matchStart, candidate);
133+
public static bool Equals<TValueLength>(ref char matchStart, string candidate)
134+
where TValueLength : struct, IValueLength
135+
{
136+
Debug.Assert(candidate.Length > 1);
137+
138+
ref byte first = ref Unsafe.As<char, byte>(ref matchStart);
139+
ref byte second = ref Unsafe.As<char, byte>(ref candidate.GetRawStringData());
140+
nuint byteLength = (nuint)(uint)candidate.Length * 2;
141+
142+
if (TValueLength.AtLeast8CharsOrUnknown)
143+
{
144+
return SpanHelpers.SequenceEqual(ref first, ref second, byteLength);
145+
}
146+
147+
Debug.Assert(matchStart == candidate[0], "This should only be called after the first character has been checked");
148+
149+
if (TValueLength.AtLeast4Chars)
150+
{
151+
nuint offset = byteLength - sizeof(ulong);
152+
ulong differentBits = Unsafe.ReadUnaligned<ulong>(ref first) - Unsafe.ReadUnaligned<ulong>(ref second);
153+
differentBits |= Unsafe.ReadUnaligned<ulong>(ref Unsafe.Add(ref first, offset)) - Unsafe.ReadUnaligned<ulong>(ref Unsafe.Add(ref second, offset));
154+
return differentBits == 0;
155+
}
156+
else
157+
{
158+
Debug.Assert(candidate.Length is 2 or 3);
159+
160+
// We know that the candidate is 2 or 3 characters long, and that the first character has already been checked.
161+
// We only have to to check the last 2 characters also match.
162+
nuint offset = byteLength - sizeof(uint);
163+
164+
return Unsafe.ReadUnaligned<uint>(ref Unsafe.Add(ref first, offset))
165+
== Unsafe.ReadUnaligned<uint>(ref Unsafe.Add(ref second, offset));
166+
}
167+
}
109168
}
110169

111170
// Transforms inputs to their uppercase variants with the assumption that all input characters are ASCII letters.
@@ -125,8 +184,38 @@ public static bool Equals(ref char matchStart, string candidate) =>
125184
public static Vector512<byte> TransformInput(Vector512<byte> input) => input & Vector512.Create(unchecked((byte)~0x20));
126185

127186
[MethodImpl(MethodImplOptions.AggressiveInlining)]
128-
public static bool Equals(ref char matchStart, string candidate) =>
129-
ScalarEquals<CaseInsensitiveAsciiLetters>(ref matchStart, candidate);
187+
public static bool Equals<TValueLength>(ref char matchStart, string candidate)
188+
where TValueLength : struct, IValueLength
189+
{
190+
Debug.Assert(candidate.Length > 1);
191+
Debug.Assert(candidate.ToUpperInvariant() == candidate);
192+
193+
if (TValueLength.AtLeast8CharsOrUnknown)
194+
{
195+
return Ascii.EqualsIgnoreCase(ref matchStart, ref candidate.GetRawStringData(), (uint)candidate.Length);
196+
}
197+
198+
ref byte first = ref Unsafe.As<char, byte>(ref matchStart);
199+
ref byte second = ref Unsafe.As<char, byte>(ref candidate.GetRawStringData());
200+
nuint byteLength = (nuint)(uint)candidate.Length * 2;
201+
202+
if (TValueLength.AtLeast4Chars)
203+
{
204+
const ulong CaseMask = ~0x20002000200020u;
205+
nuint offset = byteLength - sizeof(ulong);
206+
ulong differentBits = (Unsafe.ReadUnaligned<ulong>(ref first) & CaseMask) - Unsafe.ReadUnaligned<ulong>(ref second);
207+
differentBits |= (Unsafe.ReadUnaligned<ulong>(ref Unsafe.Add(ref first, offset)) & CaseMask) - Unsafe.ReadUnaligned<ulong>(ref Unsafe.Add(ref second, offset));
208+
return differentBits == 0;
209+
}
210+
else
211+
{
212+
const uint CaseMask = ~0x200020u;
213+
nuint offset = byteLength - sizeof(uint);
214+
uint differentBits = (Unsafe.ReadUnaligned<uint>(ref first) & CaseMask) - Unsafe.ReadUnaligned<uint>(ref second);
215+
differentBits |= (Unsafe.ReadUnaligned<uint>(ref Unsafe.Add(ref first, offset)) & CaseMask) - Unsafe.ReadUnaligned<uint>(ref Unsafe.Add(ref second, offset));
216+
return differentBits == 0;
217+
}
218+
}
130219
}
131220

132221
// Transforms inputs to their uppercase variants with the assumption that all input characters are ASCII.
@@ -170,8 +259,16 @@ public static Vector512<byte> TransformInput(Vector512<byte> input)
170259
}
171260

172261
[MethodImpl(MethodImplOptions.AggressiveInlining)]
173-
public static bool Equals(ref char matchStart, string candidate) =>
174-
ScalarEquals<CaseInsensitiveAscii>(ref matchStart, candidate);
262+
public static bool Equals<TValueLength>(ref char matchStart, string candidate)
263+
where TValueLength : struct, IValueLength
264+
{
265+
if (TValueLength.AtLeast8CharsOrUnknown)
266+
{
267+
return Ascii.EqualsIgnoreCase(ref matchStart, ref candidate.GetRawStringData(), (uint)candidate.Length);
268+
}
269+
270+
return ScalarEquals<CaseInsensitiveAscii>(ref matchStart, candidate);
271+
}
175272
}
176273

177274
// We can't efficiently map non-ASCII inputs to their Ordinal uppercase variants,
@@ -184,8 +281,16 @@ public static bool Equals(ref char matchStart, string candidate) =>
184281
public static Vector512<byte> TransformInput(Vector512<byte> input) => throw new UnreachableException();
185282

186283
[MethodImpl(MethodImplOptions.AggressiveInlining)]
187-
public static bool Equals(ref char matchStart, string candidate) =>
188-
Ordinal.EqualsIgnoreCase(ref matchStart, ref candidate.GetRawStringData(), candidate.Length);
284+
public static bool Equals<TValueLength>(ref char matchStart, string candidate)
285+
where TValueLength : struct, IValueLength
286+
{
287+
if (TValueLength.AtLeast8CharsOrUnknown)
288+
{
289+
return Ordinal.EqualsIgnoreCase(ref matchStart, ref candidate.GetRawStringData(), candidate.Length);
290+
}
291+
292+
return Ordinal.EqualsIgnoreCase_Scalar(ref matchStart, ref candidate.GetRawStringData(), candidate.Length);
293+
}
189294
}
190295
}
191296
}

src/libraries/System.Private.CoreLib/src/System/SearchValues/Strings/SingleStringSearchValuesThreeChars.cs

+14-4
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@ namespace System.Buffers
1515
// This implementation uses 3 precomputed anchor points when searching.
1616
// This implementation may also be used for length=2 values, in which case two anchors point at the same position.
1717
// Has an O(i * m) worst-case, with the expected time closer to O(n) for most inputs.
18-
internal sealed class SingleStringSearchValuesThreeChars<TCaseSensitivity> : StringSearchValuesBase
18+
internal sealed class SingleStringSearchValuesThreeChars<TValueLength, TCaseSensitivity> : StringSearchValuesBase
19+
where TValueLength : struct, IValueLength
1920
where TCaseSensitivity : struct, ICaseSensitivity
2021
{
2122
private const ushort CaseConversionMask = unchecked((ushort)~0x20);
@@ -34,6 +35,7 @@ public SingleStringSearchValuesThreeChars(HashSet<string>? uniqueValues, string
3435
{
3536
// We could have more than one entry in 'uniqueValues' if this value is an exact prefix of all the others.
3637
Debug.Assert(value.Length > 1);
38+
Debug.Assert((value.Length >= 8) == TValueLength.AtLeast8CharsOrUnknown);
3739

3840
CharacterFrequencyHelper.GetSingleStringMultiCharacterOffsets(value, IgnoreCase, out int ch2Offset, out int ch3Offset);
3941

@@ -228,7 +230,7 @@ private int IndexOf(ref char searchSpace, int searchSpaceLength)
228230

229231
// CaseInsensitiveUnicode doesn't support single-character transformations, so we skip checking the first character first.
230232
if ((typeof(TCaseSensitivity) == typeof(CaseInsensitiveUnicode) || TCaseSensitivity.TransformInput(cur) == valueHead) &&
231-
TCaseSensitivity.Equals(ref cur, value))
233+
TCaseSensitivity.Equals<TValueLength>(ref cur, value))
232234
{
233235
return (int)i;
234236
}
@@ -325,7 +327,11 @@ private bool TryMatch(ref char searchSpaceStart, int searchSpaceLength, ref char
325327

326328
ValidateReadPosition(ref searchSpaceStart, searchSpaceLength, ref matchRef, _value.Length);
327329

328-
if (TCaseSensitivity.Equals(ref matchRef, _value))
330+
// If the value is short (!TValueLength.AtLeast4Chars => 2 or 3 characters), the anchors already represent the whole value.
331+
// With case-sensitive comparisons, we've therefore already confirmed the match, so we can skip doing so here.
332+
// With case-insensitive comparisons, we applied a mask to the input, so while the anchors likely matched, we can't be sure.
333+
if ((typeof(TCaseSensitivity) == typeof(CaseSensitive) && !TValueLength.AtLeast4Chars) ||
334+
TCaseSensitivity.Equals<TValueLength>(ref matchRef, _value))
329335
{
330336
offsetFromStart = (int)((nuint)Unsafe.ByteOffset(ref searchSpaceStart, ref matchRef) / 2);
331337
return true;
@@ -353,7 +359,11 @@ private bool TryMatch(ref char searchSpaceStart, int searchSpaceLength, ref char
353359

354360
ValidateReadPosition(ref searchSpaceStart, searchSpaceLength, ref matchRef, _value.Length);
355361

356-
if (TCaseSensitivity.Equals(ref matchRef, _value))
362+
// If the value is short (!TValueLength.AtLeast4Chars => 2 or 3 characters), the anchors already represent the whole value.
363+
// With case-sensitive comparisons, we've therefore already confirmed the match, so we can skip doing so here.
364+
// With case-insensitive comparisons, we applied a mask to the input, so while the anchors likely matched, we can't be sure.
365+
if ((typeof(TCaseSensitivity) == typeof(CaseSensitive) && !TValueLength.AtLeast4Chars) ||
366+
TCaseSensitivity.Equals<TValueLength>(ref matchRef, _value))
357367
{
358368
offsetFromStart = (int)((nuint)Unsafe.ByteOffset(ref searchSpaceStart, ref matchRef) / 2);
359369
return true;

src/libraries/System.Private.CoreLib/src/System/SearchValues/Strings/StringSearchValues.cs

+40-16
Original file line numberDiff line numberDiff line change
@@ -309,25 +309,16 @@ private static SearchValues<string> CreateForSingleValue(
309309

310310
if (Vector128.IsHardwareAccelerated && value.Length > 1 && value.Length <= maxLength)
311311
{
312-
if (!ignoreCase)
312+
SearchValues<string>? searchValues = value.Length switch
313313
{
314-
return new SingleStringSearchValuesThreeChars<CaseSensitive>(uniqueValues, value);
315-
}
316-
317-
if (asciiLettersOnly)
318-
{
319-
return new SingleStringSearchValuesThreeChars<CaseInsensitiveAsciiLetters>(uniqueValues, value);
320-
}
314+
< 4 => TryCreateSingleValuesThreeChars<ValueLengthLessThan4>(value, uniqueValues, ignoreCase, allAscii, asciiLettersOnly),
315+
< 8 => TryCreateSingleValuesThreeChars<ValueLength4To7>(value, uniqueValues, ignoreCase, allAscii, asciiLettersOnly),
316+
_ => TryCreateSingleValuesThreeChars<ValueLength8OrLongerOrUnknown>(value, uniqueValues, ignoreCase, allAscii, asciiLettersOnly),
317+
};
321318

322-
if (allAscii)
319+
if (searchValues is not null)
323320
{
324-
return new SingleStringSearchValuesThreeChars<CaseInsensitiveAscii>(uniqueValues, value);
325-
}
326-
327-
// When ignoring casing, all anchor chars we search for must be ASCII.
328-
if (char.IsAscii(value[0]) && value.AsSpan().LastIndexOfAnyInRange((char)0, (char)127) > 0)
329-
{
330-
return new SingleStringSearchValuesThreeChars<CaseInsensitiveUnicode>(uniqueValues, value);
321+
return searchValues;
331322
}
332323
}
333324

@@ -338,6 +329,39 @@ private static SearchValues<string> CreateForSingleValue(
338329
: new SingleStringSearchValuesFallback<SearchValues.FalseConst>(value, uniqueValues);
339330
}
340331

332+
private static SearchValues<string>? TryCreateSingleValuesThreeChars<TValueLength>(
333+
string value,
334+
HashSet<string>? uniqueValues,
335+
bool ignoreCase,
336+
bool allAscii,
337+
bool asciiLettersOnly)
338+
where TValueLength : struct, IValueLength
339+
{
340+
if (!ignoreCase)
341+
{
342+
return new SingleStringSearchValuesThreeChars<TValueLength, CaseSensitive>(uniqueValues, value);
343+
}
344+
345+
if (asciiLettersOnly)
346+
{
347+
return new SingleStringSearchValuesThreeChars<TValueLength, CaseInsensitiveAsciiLetters>(uniqueValues, value);
348+
}
349+
350+
if (allAscii)
351+
{
352+
return new SingleStringSearchValuesThreeChars<TValueLength, CaseInsensitiveAscii>(uniqueValues, value);
353+
}
354+
355+
// SingleStringSearchValuesThreeChars doesn't have logic to handle non-ASCII case conversion, so we require that anchor characters are ASCII.
356+
// Right now we're always selecting the first character as one of the anchors, and we need at least two.
357+
if (char.IsAscii(value[0]) && value.AsSpan(1).ContainsAnyInRange((char)0, (char)127))
358+
{
359+
return new SingleStringSearchValuesThreeChars<TValueLength, CaseInsensitiveUnicode>(uniqueValues, value);
360+
}
361+
362+
return null;
363+
}
364+
341365
private static void AnalyzeValues(
342366
ReadOnlySpan<string> values,
343367
ref bool ignoreCase,

src/libraries/System.Private.CoreLib/src/System/SearchValues/Strings/StringSearchValuesBase.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ internal abstract class StringSearchValuesBase : SearchValues<string>
1818
private readonly HashSet<string>? _uniqueValues;
1919

2020
/// <summary>
21-
/// This exists to allow <see cref="SingleStringSearchValuesThreeChars{TCaseSensitivity}"/> to avoid the HashSet allocation.
21+
/// This exists to allow <see cref="SingleStringSearchValuesThreeChars{TValueLength, TCaseSensitivity}"/> to avoid the HashSet allocation.
2222
/// </summary>
2323
protected bool HasUniqueValues => _uniqueValues is not null;
2424

src/libraries/System.Private.CoreLib/src/System/Text/Ascii.Equality.cs

+3
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,9 @@ public static bool EqualsIgnoreCase(ReadOnlySpan<char> left, ReadOnlySpan<char>
189189
=> left.Length == right.Length
190190
&& EqualsIgnoreCase<ushort, ushort, PlainLoader<ushort>>(ref Unsafe.As<char, ushort>(ref MemoryMarshal.GetReference(left)), ref Unsafe.As<char, ushort>(ref MemoryMarshal.GetReference(right)), (uint)right.Length);
191191

192+
internal static bool EqualsIgnoreCase(ref char left, ref char right, nuint length) =>
193+
EqualsIgnoreCase<ushort, ushort, PlainLoader<ushort>>(ref Unsafe.As<char, ushort>(ref left), ref Unsafe.As<char, ushort>(ref right), length);
194+
192195
private static bool EqualsIgnoreCase<TLeft, TRight, TLoader>(ref TLeft left, ref TRight right, nuint length)
193196
where TLeft : unmanaged, INumberBase<TLeft>
194197
where TRight : unmanaged, INumberBase<TRight>

0 commit comments

Comments
 (0)