Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Vectorized GetNonRandomizedHashCode #98838

Closed
wants to merge 19 commits into from
Closed
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -104,30 +104,34 @@ static T AssertNotNull<T>(T value, [CallerArgumentExpression(nameof(value))] str

// Generates a possible string with a well-known non-randomized hash code:
// - string.GetNonRandomizedHashCode returns 0.
// - string.GetNonRandomizedHashCodeOrdinalIgnoreCase returns 0x24716ca0.
// - string.GetNonRandomizedHashCodeOrdinalIgnoreCase returns 0xe05180a0.
// Provide a different seed to produce a different string.
// Must check OrdinalIgnoreCase hash code to ensure correctness.
string candidate = string.Create(8, currentSeed, static (span, seed) =>
string candidate = string.Create(16, currentSeed, static (span, seed) =>
{
Span<byte> asBytes = MemoryMarshal.AsBytes(span);

uint hash1 = (5381 << 16) + 5381;
uint hash2 = BitOperations.RotateLeft(hash1, 5) + hash1;
uint hash3 = BitOperations.RotateLeft(hash1, 5) + hash1;
uint hash4 = BitOperations.RotateLeft(hash1, 5) + hash1;

MemoryMarshal.Write(asBytes, in seed);
MemoryMarshal.Write(asBytes.Slice(4), in hash2); // set hash2 := 0 (for Ordinal)
MemoryMarshal.Write(asBytes.Slice(12), in hash2); // set hash2 := 0 (for Ordinal)
MemoryMarshal.Write(asBytes.Slice(8), in hash3); // set hash3 := 0 (for Ordinal)
MemoryMarshal.Write(asBytes.Slice(4), in hash4); // set hash4 := 0 (for Ordinal)

hash1 = (BitOperations.RotateLeft(hash1, 5) + hash1) ^ (uint)seed;
hash1 = (BitOperations.RotateLeft(hash1, 5) + hash1);

MemoryMarshal.Write(asBytes.Slice(8), in hash1); // set hash1 := 0 (for Ordinal)
MemoryMarshal.Write(asBytes.Slice(16), in hash1); // set hash1 := 0 (for Ordinal)
});

int ordinalHashCode = nonRandomizedOrdinal(candidate);
Assert.Equal(0, ordinalHashCode); // ensure has a zero hash code Ordinal

int ordinalIgnoreCaseHashCode = nonRandomizedOrdinalIgnoreCase(candidate);
if (ordinalIgnoreCaseHashCode == 0x24716ca0) // ensure has a zero hash code OrdinalIgnoreCase (might not have one)
if (ordinalIgnoreCaseHashCode == unchecked((int)0xe05180a0)) // ensure has a zero hash code OrdinalIgnoreCase (might not have one)
{
collidingStrings.Add(candidate); // success!
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ private static List<string> GenerateCollidingStrings(int count)
Assert.Equal(0, ordinalHashCode); // ensure has a zero hash code Ordinal

int ordinalIgnoreCaseHashCode = _lazyGetNonRandomizedOrdinalIgnoreCaseHashCodeDel.Value(candidate);
if (ordinalIgnoreCaseHashCode == 0x24716ca0) // ensure has a zero hash code OrdinalIgnoreCase (might not have one)
if (ordinalIgnoreCaseHashCode == unchecked((int)0xe05180a0)) // ensure has a zero hash code OrdinalIgnoreCase (might not have one)
{
collidingStrings.Add(candidate); // success!
}
Expand All @@ -291,25 +291,29 @@ private static List<string> GenerateCollidingStrings(int count)

// Generates a possible string with a well-known non-randomized hash code:
// - string.GetNonRandomizedHashCode returns 0.
// - string.GetNonRandomizedHashCodeOrdinalIgnoreCase returns 0x24716ca0.
// - string.GetNonRandomizedHashCodeOrdinalIgnoreCase returns 0xe05180a0.
// Provide a different seed to produce a different string.
// Caller must check OrdinalIgnoreCase hash code to ensure correctness.
static string GenerateCollidingStringCandidate(int seed)
{
return string.Create(8, seed, (span, seed) =>
return string.Create(16, seed, (span, seed) =>
{
Span<byte> asBytes = MemoryMarshal.AsBytes(span);

uint hash1 = (5381 << 16) + 5381;
uint hash2 = BitOperations.RotateLeft(hash1, 5) + hash1;
uint hash3 = BitOperations.RotateLeft(hash1, 5) + hash1;
uint hash4 = BitOperations.RotateLeft(hash1, 5) + hash1;

MemoryMarshal.Write(asBytes, in seed);
MemoryMarshal.Write(asBytes.Slice(4), in hash2); // set hash2 := 0 (for Ordinal)
MemoryMarshal.Write(asBytes.Slice(12), in hash2); // set hash2 := 0 (for Ordinal)
MemoryMarshal.Write(asBytes.Slice(8), in hash3); // set hash3 := 0 (for Ordinal)
MemoryMarshal.Write(asBytes.Slice(4), in hash4); // set hash4 := 0 (for Ordinal)

hash1 = (BitOperations.RotateLeft(hash1, 5) + hash1) ^ (uint)seed;
hash1 = (BitOperations.RotateLeft(hash1, 5) + hash1);

MemoryMarshal.Write(asBytes.Slice(8), in hash1); // set hash1 := 0 (for Ordinal)
MemoryMarshal.Write(asBytes.Slice(16), in hash1); // set hash1 := 0 (for Ordinal)
});
}
}
Expand Down
212 changes: 188 additions & 24 deletions src/libraries/System.Private.CoreLib/src/System/String.Comparison.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Text.Unicode;
using System.Text;
using System.Runtime.Intrinsics;

namespace System
{
Expand Down Expand Up @@ -815,40 +817,107 @@ internal static int GetHashCodeOrdinalIgnoreCase(ReadOnlySpan<char> value)
// or are otherwise mitigated
internal unsafe int GetNonRandomizedHashCode()
{
uint hash1 = (5381 << 16) + 5381;
uint hash2 = hash1;
uint hash3 = hash1;
uint hash4 = hash1;

fixed (char* src = &_firstChar)
{
Debug.Assert(src[this.Length] == '\0', "src[this.Length] == '\\0'");
Debug.Assert(((int)src) % 4 == 0, "Managed string should start at 4 bytes boundary");

uint hash1 = (5381 << 16) + 5381;
uint hash2 = hash1;
Debug.Assert(((int) src) % 4 == 0, "Managed string should start at 4 bytes boundary");

uint* ptr = (uint*)src;
uint* ptr = (uint*) src;
int length = this.Length;

while (length > 2)
if (Vector128.IsHardwareAccelerated && length >= 2 * Vector128<ushort>.Count)
{
Vector128<uint> hashVector = Vector128.Create(hash1);

while (length > 8)
{
Vector128<uint> srcVec = Vector128.Load(ptr);
length -= 8;
hashVector = (hashVector + RotateLeft(hashVector, 5)) ^ srcVec;
ptr += 4;
}

uint hashed1 = hashVector.GetElement(0);
uint hashed2 = hashVector.GetElement(1);
uint hashed3 = hashVector.GetElement(2);
uint hashed4 = hashVector.GetElement(3);

while (length > 4)
Copy link
Member

@jkotas jkotas Mar 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure why you have changed this to a `while` loop. I think the `if` was perfectly fine here, and more efficient too.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The original main loop processes the data stream in the following way:

while(length > 2)
{
   length -= 4
}

There will be an extra iteration when length is 4n-1 (e.g. 7), which consumes the null terminator as well. It also guarantees that the remaining length after the loop is at most 2, so the tail can be handled by the single `if` statement that follows.

In the updated code, however, the main loop operates at a larger granularity, so the same trick might read memory beyond the null terminator. That is why I used `while` here.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can the condition in the main while loop be while (length >= 8) and the check if (length >= 4)?

I understand that the existing code does tricks with the null terminator to save a few instructions. You do not have to match those tricks.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, thanks for the suggestion, I can try that.

{
uint p0 = ptr[0];
uint p1 = ptr[1];

length -= 4;
hashed3 = (BitOperations.RotateLeft(hashed3, 5) + hashed3) ^ (p0);
hashed4 = (BitOperations.RotateLeft(hashed4, 5) + hashed4) ^ (p1);
ptr += 2;
}

while (length > 0)
{
uint p0 = ptr[0];

length -= 2;
hashed4 = (BitOperations.RotateLeft(hashed4, 5) + hashed4) ^ (p0);
ptr += 1;
}

uint res = (((BitOperations.RotateLeft(hashed1, 5) + hashed1)) ^ hashed3) + 1566083941 * (((BitOperations.RotateLeft(hashed2, 5) + hashed2)) ^ hashed4);
return (int)res;
}


while (length > 8)
{
uint p0 = ptr[0];
uint p1 = ptr[1];
uint p2 = ptr[2];
uint p3 = ptr[3];
length -= 8;
// hashVector = (hashVector + RotateLeft(hashVector, 5)) ^ srcVec;
hash1 = (BitOperations.RotateLeft(hash1, 5) + hash1) ^ (p0);
hash2 = (BitOperations.RotateLeft(hash2, 5) + hash2) ^ (p1);
hash3 = (BitOperations.RotateLeft(hash3, 5) + hash3) ^ (p2);
hash4 = (BitOperations.RotateLeft(hash4, 5) + hash4) ^ (p3);
ptr += 4;
}

while (length > 4)
{
uint p0 = ptr[0];
uint p1 = ptr[1];
length -= 4;
// Where length is 4n-1 (e.g. 3,7,11,15,19) this additionally consumes the null terminator
hash1 = (BitOperations.RotateLeft(hash1, 5) + hash1) ^ ptr[0];
hash2 = (BitOperations.RotateLeft(hash2, 5) + hash2) ^ ptr[1];
hash3 = (BitOperations.RotateLeft(hash3, 5) + hash3) ^ (p0);
hash4 = (BitOperations.RotateLeft(hash4, 5) + hash4) ^ (p1);
ptr += 2;
}

if (length > 0)
while (length > 0)
{
// Where length is 4n-3 (e.g. 1,5,9,13,17) this additionally consumes the null terminator
hash2 = (BitOperations.RotateLeft(hash2, 5) + hash2) ^ ptr[0];
}
uint p0 = ptr[0];

return (int)(hash1 + (hash2 * 1566083941));
length -= 2;
hash4 = (BitOperations.RotateLeft(hash4, 5) + hash4) ^ (p0);
ptr += 1;
}
}

uint resOnScalarPath = (((BitOperations.RotateLeft(hash1, 5) + hash1)) ^ hash3) + 1566083941 * (((BitOperations.RotateLeft(hash2, 5) + hash2)) ^ hash4);
return (int)resOnScalarPath;
}

internal unsafe int GetNonRandomizedHashCodeOrdinalIgnoreCase()
{
uint hash1 = (5381 << 16) + 5381;
uint hash2 = hash1;
uint hash3 = hash1;
uint hash4 = hash1;

fixed (char* src = &_firstChar)
{
Expand All @@ -863,7 +932,77 @@ internal unsafe int GetNonRandomizedHashCodeOrdinalIgnoreCase()
// be ok because we expect this to be very rare in practice.
const uint NormalizeToLowercase = 0x0020_0020u; // valid both for big-endian and for little-endian

while (length > 2)
if (Vector128.IsHardwareAccelerated && length >= 2 * Vector128<ushort>.Count)
{
Vector128<uint> hashVector = Vector128.Create(hash1);
Vector128<uint> NormalizeToLowercaseVec = Vector128.Create(NormalizeToLowercase);

while (length > 8)
{
Vector128<uint> srcVec = Vector128.Load(ptr);
if (Ascii.VectorContainsNonAsciiChar(srcVec.AsUInt16()))
{
goto NotAscii;
}
length -= 8;
hashVector = (hashVector + RotateLeft(hashVector, 5)) ^ (srcVec | NormalizeToLowercaseVec);
ptr += 4;
}

uint hashed1 = hashVector.GetElement(0);
uint hashed2 = hashVector.GetElement(1);
uint hashed3 = hashVector.GetElement(2);
uint hashed4 = hashVector.GetElement(3);

while (length > 4)
{
uint p0 = ptr[0];
uint p1 = ptr[1];
if (!Utf16Utility.AllCharsInUInt32AreAscii(p0 | p1))
{
goto NotAscii;
}

length -= 4;
hashed3 = (BitOperations.RotateLeft(hashed3, 5) + hashed3) ^ (p0 | NormalizeToLowercase);
hashed4 = (BitOperations.RotateLeft(hashed4, 5) + hashed4) ^ (p1 | NormalizeToLowercase);
ptr += 2;
}
while (length > 0)
{
uint p0 = ptr[0];
if (!Utf16Utility.AllCharsInUInt32AreAscii(p0))
{
goto NotAscii;
}
hashed4 = (BitOperations.RotateLeft(hashed4, 5) + hashed4) ^ (p0 | NormalizeToLowercase);
length -= 2;
ptr += 1;
}

uint res = (((BitOperations.RotateLeft(hashed1, 5) + hashed1)) ^ hashed3) + 1566083941 * (((BitOperations.RotateLeft(hashed2, 5) + hashed2)) ^ hashed4);
return (int)res;
}

while (length > 8)
{
uint p0 = ptr[0];
uint p1 = ptr[1];
uint p2 = ptr[2];
uint p3 = ptr[3];
if (!Utf16Utility.AllCharsInUInt32AreAscii(p0 | p1 | p2 | p3))
{
goto NotAscii;
}
length -= 8;
// hashVector = (hashVector + RotateLeft(hashVector, 5)) ^ srcVec;
hash1 = (BitOperations.RotateLeft(hash1, 5) + hash1) ^ (p0 | NormalizeToLowercase);
hash2 = (BitOperations.RotateLeft(hash2, 5) + hash2) ^ (p1 | NormalizeToLowercase);
hash3 = (BitOperations.RotateLeft(hash3, 5) + hash3) ^ (p2 | NormalizeToLowercase);
hash4 = (BitOperations.RotateLeft(hash4, 5) + hash4) ^ (p3 | NormalizeToLowercase);
ptr += 4;
}
while (length > 4)
{
uint p0 = ptr[0];
uint p1 = ptr[1];
Expand All @@ -874,25 +1013,28 @@ internal unsafe int GetNonRandomizedHashCodeOrdinalIgnoreCase()

length -= 4;
// Where length is 4n-1 (e.g. 3,7,11,15,19) this additionally consumes the null terminator
hash1 = (BitOperations.RotateLeft(hash1, 5) + hash1) ^ (p0 | NormalizeToLowercase);
hash2 = (BitOperations.RotateLeft(hash2, 5) + hash2) ^ (p1 | NormalizeToLowercase);
hash3 = (BitOperations.RotateLeft(hash3, 5) + hash3) ^ (p0 | NormalizeToLowercase);
hash4 = (BitOperations.RotateLeft(hash4, 5) + hash4) ^ (p1 | NormalizeToLowercase);
ptr += 2;
}

if (length > 0)
while (length > 0)
{
uint p0 = ptr[0];
if (!Utf16Utility.AllCharsInUInt32AreAscii(p0))
{
goto NotAscii;
}

length -= 2;
// Where length is 4n-3 (e.g. 1,5,9,13,17) this additionally consumes the null terminator
hash2 = (BitOperations.RotateLeft(hash2, 5) + hash2) ^ (p0 | NormalizeToLowercase);
hash4 = (BitOperations.RotateLeft(hash4, 5) + hash4) ^ (p0 | NormalizeToLowercase);
ptr += 1;
}
}

return (int)(hash1 + (hash2 * 1566083941));
uint resOnScalarPath = (((BitOperations.RotateLeft(hash1, 5) + hash1)) ^ hash3) + 1566083941 * (((BitOperations.RotateLeft(hash2, 5) + hash2)) ^ hash4);
return (int)resOnScalarPath;

NotAscii:
return GetNonRandomizedHashCodeOrdinalIgnoreCaseSlow(this);
Expand All @@ -912,30 +1054,46 @@ static int GetNonRandomizedHashCodeOrdinalIgnoreCaseSlow(string str)
const uint NormalizeToLowercase = 0x0020_0020u;
uint hash1 = (5381 << 16) + 5381;
uint hash2 = hash1;
uint hash3 = hash1;
uint hash4 = hash1;

// Duplicate the main loop, can be removed once JIT gets "Loop Unswitching" optimization
fixed (char* src = scratch)
{
uint* ptr = (uint*)src;
while (length > 2)
while (length > 8)
{
length -= 4;
length -= 8;
// hashVector = (hashVector + RotateLeft(hashVector, 5)) ^ srcVec;
hash1 = (BitOperations.RotateLeft(hash1, 5) + hash1) ^ (ptr[0] | NormalizeToLowercase);
hash2 = (BitOperations.RotateLeft(hash2, 5) + hash2) ^ (ptr[1] | NormalizeToLowercase);
hash3 = (BitOperations.RotateLeft(hash3, 5) + hash3) ^ (ptr[2] | NormalizeToLowercase);
hash4 = (BitOperations.RotateLeft(hash4, 5) + hash4) ^ (ptr[3] | NormalizeToLowercase);
ptr += 4;
}

while (length > 4)
{
length -= 4;
hash3 = (BitOperations.RotateLeft(hash3, 5) + hash3) ^ (ptr[0] | NormalizeToLowercase);
hash4 = (BitOperations.RotateLeft(hash4, 5) + hash4) ^ (ptr[1] | NormalizeToLowercase);
ptr += 2;
}

if (length > 0)
while (length > 0)
{
hash2 = (BitOperations.RotateLeft(hash2, 5) + hash2) ^ (ptr[0] | NormalizeToLowercase);
length -= 2;
hash4 = (BitOperations.RotateLeft(hash4, 5) + hash4) ^ (ptr[0] | NormalizeToLowercase);
ptr += 1;
}
}

if (borrowedArr != null)
{
ArrayPool<char>.Shared.Return(borrowedArr);
}
return (int)(hash1 + (hash2 * 1566083941));
uint resOnSlowPath = (((BitOperations.RotateLeft(hash1, 5) + hash1)) ^ hash3) + 1566083941 * (((BitOperations.RotateLeft(hash2, 5) + hash2)) ^ hash4);
return (int)resOnSlowPath;
}
}

Expand Down Expand Up @@ -1066,5 +1224,11 @@ private static CompareOptions GetCompareOptionsFromOrdinalStringComparison(Strin
int ct = (int)comparisonType;
return (CompareOptions)((ct & -ct) << 28); // neg and shl
}

/// <summary>
/// Rotates every 32-bit element of <paramref name="src"/> to the left by
/// <paramref name="control"/> bits.
/// </summary>
/// <param name="src">The vector whose four <see cref="uint"/> lanes are rotated independently.</param>
/// <param name="control">The number of bit positions to rotate by.</param>
/// <returns>
/// A vector in which each lane is the bitwise OR of the lane shifted left by
/// <paramref name="control"/> and the lane logically shifted right by <c>32 - control</c>.
/// </returns>
/// <remarks>
/// NOTE(review): for <paramref name="control"/> outside 1..31 the result depends on how the
/// vector shift helpers treat out-of-range shift counts - confirm before relying on it.
/// </remarks>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
internal static Vector128<uint> RotateLeft(Vector128<uint> src, int control)
{
    // Bits that stay inside the lane after moving towards the most significant end.
    Vector128<uint> shiftedUp = Vector128.ShiftLeft(src, control);

    // Bits that fell off the top, brought back in at the least significant end.
    Vector128<uint> wrappedAround = Vector128.ShiftRightLogical(src, 32 - control);

    return shiftedUp | wrappedAround;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1517,7 +1517,7 @@ private static bool VectorContainsNonAsciiChar(Vector128<byte> asciiVector)
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static bool VectorContainsNonAsciiChar(Vector128<ushort> utf16Vector)
internal static bool VectorContainsNonAsciiChar(Vector128<ushort> utf16Vector)
{
// prefer architecture specific intrinsic as they offer better perf
if (Sse2.IsSupported)
Expand Down
Loading