-
Notifications
You must be signed in to change notification settings - Fork 4.8k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add TensorPrimitives.HammingDistance and friends (#103305)
* Add TensorPrimitives.HammingDistance and friends * Address PR feedback
- Loading branch information
1 parent
1ab6888
commit fa987b6
Showing
8 changed files
with
353 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
202 changes: 202 additions & 0 deletions
202
....Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.HammingDistance.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,202 @@ | ||
// Licensed to the .NET Foundation under one or more agreements. | ||
// The .NET Foundation licenses this file to you under the MIT license. | ||
|
||
using System.Collections.Generic; | ||
using System.Diagnostics; | ||
using System.Runtime.CompilerServices; | ||
using System.Runtime.InteropServices; | ||
using System.Runtime.Intrinsics; | ||
|
||
namespace System.Numerics.Tensors | ||
{ | ||
public static partial class TensorPrimitives | ||
{ | ||
/// <summary>Computes the bitwise Hamming distance between two equal-length tensors of values.</summary> | ||
/// <param name="x">The first tensor, represented as a span.</param> | ||
/// <param name="y">The second tensor, represented as a span.</param> | ||
/// <returns>The number of bits that differ between the two spans.</returns> | ||
/// <exception cref="ArgumentException">Length of <paramref name="x" /> must be same as length of <paramref name="y" />.</exception> | ||
/// <exception cref="ArgumentException"><paramref name="x" /> and <paramref name="y" /> must not be empty.</exception> | ||
public static long HammingBitDistance<T>(ReadOnlySpan<T> x, ReadOnlySpan<T> y) where T : IBinaryInteger<T> | ||
{ | ||
if (x.Length != y.Length) | ||
{ | ||
ThrowHelper.ThrowArgument_SpansMustHaveSameLength(); | ||
} | ||
|
||
long count = 0; | ||
for (int i = 0; i < x.Length; i++) | ||
{ | ||
count += long.CreateTruncating(T.PopCount(x[i] ^ y[i])); | ||
} | ||
|
||
return count; | ||
} | ||
|
||
/// <summary>Computes the Hamming distance between two equal-length tensors of values.</summary> | ||
/// <param name="x">The first tensor, represented as a span.</param> | ||
/// <param name="y">The second tensor, represented as a span.</param> | ||
/// <returns>The number of elements that differ between the two spans.</returns> | ||
/// <exception cref="ArgumentException">Length of <paramref name="x" /> must be same as length of <paramref name="y" />.</exception> | ||
/// <exception cref="ArgumentException"><paramref name="x" /> and <paramref name="y" /> must not be empty.</exception> | ||
/// <remarks> | ||
/// <para> | ||
/// This method computes the number of locations <c>i</c> where <c>!EqualityComparer>T<.Default.Equal(x[i], y[i])</c>. | ||
/// </para> | ||
/// </remarks> | ||
[MethodImpl(MethodImplOptions.AggressiveInlining)] | ||
public static int HammingDistance<T>(ReadOnlySpan<T> x, ReadOnlySpan<T> y) | ||
{ | ||
if (typeof(T) == typeof(char)) | ||
{ | ||
// Special-case char, as it's reasonable for someone to want to use HammingDistance on strings, | ||
// and we want that accelerated. This can be removed if/when VectorXx<T> supports char. | ||
return CountUnequalElements<ushort>( | ||
MemoryMarshal.CreateReadOnlySpan(ref Unsafe.As<T, ushort>(ref MemoryMarshal.GetReference(x)), x.Length), | ||
MemoryMarshal.CreateReadOnlySpan(ref Unsafe.As<T, ushort>(ref MemoryMarshal.GetReference(y)), y.Length)); | ||
} | ||
|
||
return CountUnequalElements(x, y); | ||
} | ||
|
||
/// <summary>Counts the number of elements that are pair-wise different between the two spans.</summary> | ||
private static int CountUnequalElements<T>(ReadOnlySpan<T> x, ReadOnlySpan<T> y) | ||
{ | ||
if (x.Length != y.Length) | ||
{ | ||
ThrowHelper.ThrowArgument_SpansMustHaveSameLength(); | ||
} | ||
|
||
// TODO: This has a very similar structure to CosineSimilarity, which is also open-coded rather than | ||
// using a shared routine plus operator, as we don't have one implemented that exactly fits. We should | ||
// look at refactoring these to share the core logic. | ||
|
||
int count = 0; | ||
if (Vector128.IsHardwareAccelerated && Vector128<T>.IsSupported && x.Length >= Vector128<T>.Count) | ||
{ | ||
if (Vector256.IsHardwareAccelerated && Vector256<T>.IsSupported && x.Length >= Vector256<T>.Count) | ||
{ | ||
if (Vector512.IsHardwareAccelerated && Vector512<T>.IsSupported && x.Length >= Vector512<T>.Count) | ||
{ | ||
ref T xRef = ref MemoryMarshal.GetReference(x); | ||
ref T yRef = ref MemoryMarshal.GetReference(y); | ||
|
||
int oneVectorFromEnd = x.Length - Vector512<T>.Count; | ||
int i = 0; | ||
do | ||
{ | ||
Vector512<T> xVec = Vector512.LoadUnsafe(ref xRef, (uint)i); | ||
Vector512<T> yVec = Vector512.LoadUnsafe(ref yRef, (uint)i); | ||
|
||
count += BitOperations.PopCount((~Vector512.Equals(xVec, yVec)).ExtractMostSignificantBits()); | ||
|
||
i += Vector512<T>.Count; | ||
} | ||
while (i <= oneVectorFromEnd); | ||
|
||
// Process the last vector in the span, masking off elements already processed. | ||
if (i != x.Length) | ||
{ | ||
Vector512<T> xVec = Vector512.LoadUnsafe(ref xRef, (uint)(x.Length - Vector512<T>.Count)); | ||
Vector512<T> yVec = Vector512.LoadUnsafe(ref yRef, (uint)(x.Length - Vector512<T>.Count)); | ||
|
||
Vector512<T> remainderMask = CreateRemainderMaskVector512<T>(x.Length - i); | ||
xVec &= remainderMask; | ||
yVec &= remainderMask; | ||
|
||
count += BitOperations.PopCount((~Vector512.Equals(xVec, yVec)).ExtractMostSignificantBits()); | ||
} | ||
} | ||
else | ||
{ | ||
ref T xRef = ref MemoryMarshal.GetReference(x); | ||
ref T yRef = ref MemoryMarshal.GetReference(y); | ||
|
||
// Process vectors, summing their dot products and squares, as long as there's a vector's worth remaining. | ||
int oneVectorFromEnd = x.Length - Vector256<T>.Count; | ||
int i = 0; | ||
do | ||
{ | ||
Vector256<T> xVec = Vector256.LoadUnsafe(ref xRef, (uint)i); | ||
Vector256<T> yVec = Vector256.LoadUnsafe(ref yRef, (uint)i); | ||
|
||
count += BitOperations.PopCount((~Vector256.Equals(xVec, yVec)).ExtractMostSignificantBits()); | ||
|
||
i += Vector256<T>.Count; | ||
} | ||
while (i <= oneVectorFromEnd); | ||
|
||
// Process the last vector in the span, masking off elements already processed. | ||
if (i != x.Length) | ||
{ | ||
Vector256<T> xVec = Vector256.LoadUnsafe(ref xRef, (uint)(x.Length - Vector256<T>.Count)); | ||
Vector256<T> yVec = Vector256.LoadUnsafe(ref yRef, (uint)(x.Length - Vector256<T>.Count)); | ||
|
||
Vector256<T> remainderMask = CreateRemainderMaskVector256<T>(x.Length - i); | ||
xVec &= remainderMask; | ||
yVec &= remainderMask; | ||
|
||
count += BitOperations.PopCount((~Vector256.Equals(xVec, yVec)).ExtractMostSignificantBits()); | ||
} | ||
} | ||
} | ||
else | ||
{ | ||
ref T xRef = ref MemoryMarshal.GetReference(x); | ||
ref T yRef = ref MemoryMarshal.GetReference(y); | ||
|
||
// Process vectors, summing their dot products and squares, as long as there's a vector's worth remaining. | ||
int oneVectorFromEnd = x.Length - Vector128<T>.Count; | ||
int i = 0; | ||
do | ||
{ | ||
Vector128<T> xVec = Vector128.LoadUnsafe(ref xRef, (uint)i); | ||
Vector128<T> yVec = Vector128.LoadUnsafe(ref yRef, (uint)i); | ||
|
||
count += BitOperations.PopCount((~Vector128.Equals(xVec, yVec)).ExtractMostSignificantBits()); | ||
|
||
i += Vector128<T>.Count; | ||
} | ||
while (i <= oneVectorFromEnd); | ||
|
||
// Process the last vector in the span, masking off elements already processed. | ||
if (i != x.Length) | ||
{ | ||
Vector128<T> xVec = Vector128.LoadUnsafe(ref xRef, (uint)(x.Length - Vector128<T>.Count)); | ||
Vector128<T> yVec = Vector128.LoadUnsafe(ref yRef, (uint)(x.Length - Vector128<T>.Count)); | ||
|
||
Vector128<T> remainderMask = CreateRemainderMaskVector128<T>(x.Length - i); | ||
xVec &= remainderMask; | ||
yVec &= remainderMask; | ||
|
||
count += BitOperations.PopCount((~Vector128.Equals(xVec, yVec)).ExtractMostSignificantBits()); | ||
} | ||
} | ||
} | ||
else if (typeof(T).IsValueType) | ||
{ | ||
for (int i = 0; i < x.Length; i++) | ||
{ | ||
if (!EqualityComparer<T>.Default.Equals(x[i], y[i])) | ||
{ | ||
count++; | ||
} | ||
} | ||
} | ||
else | ||
{ | ||
EqualityComparer<T> comparer = EqualityComparer<T>.Default; | ||
for (int i = 0; i < x.Length; i++) | ||
{ | ||
if (!comparer.Equals(x[i], y[i])) | ||
{ | ||
count++; | ||
} | ||
} | ||
} | ||
|
||
Debug.Assert(count >= 0 && count <= x.Length, $"Expected count to be in the range [0, {x.Length}], got {count}."); | ||
return count; | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.