From e0c3aca1a2d096ec001af024ef5027fa44723d98 Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Mon, 12 Feb 2024 15:59:27 -0500 Subject: [PATCH] Vectorize TensorPrimitives.PopCount --- .../netcore/TensorPrimitives.netcore.cs | 267 +++++++++++++----- 1 file changed, 194 insertions(+), 73 deletions(-) diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.netcore.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.netcore.cs index dec6446cd7653b..c5027ab02c80a1 100644 --- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.netcore.cs +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.netcore.cs @@ -12959,14 +12959,10 @@ private static T HorizontalAggregate(Vector128 x) where TAggre x = TAggregate.Invoke(x, Vector128.Shuffle(x.AsInt32(), Vector128.Create(2, 3, 0, 1)).As()); x = TAggregate.Invoke(x, Vector128.Shuffle(x.AsInt32(), Vector128.Create(1, 0, 3, 2)).As()); } - else if (Unsafe.SizeOf() == 8) - { - x = TAggregate.Invoke(x, Vector128.Shuffle(x.AsInt64(), Vector128.Create(1, 0)).As()); - } else { - Debug.Fail("Should not be reachable"); - throw new NotSupportedException(); + Debug.Assert(Unsafe.SizeOf() == 8); + x = TAggregate.Invoke(x, Vector128.Shuffle(x.AsInt64(), Vector128.Create(1, 0)).As()); } return x.ToScalar(); @@ -13068,15 +13064,12 @@ ref Unsafe.As(ref MemoryMarshal.GetReference(AlignmentUInt32Mask_16x17) (uint)(count * 16)); } - if (Unsafe.SizeOf() == 8) + Debug.Assert(Unsafe.SizeOf() == 8); { return Vector128.LoadUnsafe( ref Unsafe.As(ref MemoryMarshal.GetReference(AlignmentUInt64Mask_8x9)), (uint)(count * 8)); } - - Debug.Fail("Shouldn't get here"); - throw new NotSupportedException(); } /// @@ -13107,15 +13100,12 @@ ref Unsafe.As(ref MemoryMarshal.GetReference(AlignmentUInt32Mask_16x17) (uint)(count * 16)); } - if (Unsafe.SizeOf() == 8) + Debug.Assert(Unsafe.SizeOf() == 8); { return Vector256.LoadUnsafe( ref Unsafe.As(ref MemoryMarshal.GetReference(AlignmentUInt64Mask_8x9)), (uint)(count * 8)); } - - Debug.Fail("Shouldn't get here"); - throw new NotSupportedException(); } /// @@ -13146,15 +13136,12 @@ ref Unsafe.As(ref MemoryMarshal.GetReference(AlignmentUInt32Mask_16x17) (uint)(count * 16)); } - if (Unsafe.SizeOf() == 8) + Debug.Assert(Unsafe.SizeOf() == 8); { return Vector512.LoadUnsafe( ref Unsafe.As(ref MemoryMarshal.GetReference(AlignmentUInt64Mask_8x9)), (uint)(count * 8)); } - - Debug.Fail("Shouldn't get here - CreateAlignmentMaskVector512"); - throw new NotSupportedException(); } /// @@ -13185,15 +13172,12 @@ ref Unsafe.As(ref MemoryMarshal.GetReference(RemainderUInt32Mask_16x17) (uint)(count * 16) + 12); // last 4 ints in the row } - if (Unsafe.SizeOf() == 8) + Debug.Assert(Unsafe.SizeOf() == 8); { return Vector128.LoadUnsafe( ref Unsafe.As(ref MemoryMarshal.GetReference(RemainderUInt64Mask_8x9)), (uint)(count * 8) + 6); // last 2 longs in the row } - - Debug.Fail("Shouldn't get here - CreateRemainderMaskVector128"); - throw new NotSupportedException(); } /// @@ -13224,15 +13208,12 @@ ref Unsafe.As(ref MemoryMarshal.GetReference(RemainderUInt32Mask_16x17) (uint)(count * 16) + 8); // last 8 ints in the row } - if (Unsafe.SizeOf() == 8) + Debug.Assert(Unsafe.SizeOf() == 8); { return Vector256.LoadUnsafe( ref Unsafe.As(ref MemoryMarshal.GetReference(RemainderUInt64Mask_8x9)), (uint)(count * 8) + 4); // last 4 longs in the row } - - Debug.Fail("Shouldn't get here - CreateRemainderMaskVector256"); - throw new NotSupportedException(); } /// @@ -13263,15 +13244,12 @@ ref Unsafe.As(ref MemoryMarshal.GetReference(RemainderUInt32Mask_16x17) (uint)(count * 16)); } - if (Unsafe.SizeOf() == 8) + Debug.Assert(Unsafe.SizeOf() == 8); { return Vector512.LoadUnsafe( ref Unsafe.As(ref MemoryMarshal.GetReference(RemainderUInt64Mask_8x9)), (uint)(count * 8)); } - - Debug.Fail("Shouldn't get here - CreateRemainderMaskVector512"); - throw new NotSupportedException(); } // TODO: The uses of these ApplyScalar methods are all as part of operators when handling edge cases (NaN, Infinity, really large inputs, etc.) @@ -13750,7 +13728,7 @@ private static int IndexOfFinalAggregate(Vector128 resul return resultIndex.As().ToScalar(); } - if (sizeof(T) == 1) + Debug.Assert(sizeof(T) == 1); { // Compare 0,1,2,3,4,5,6,7 with 8,9,10,11,12,13,14,15 tmpResult = Vector128.Shuffle(result.AsByte(), Vector128.Create((byte)8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7)).As(); @@ -13775,8 +13753,6 @@ private static int IndexOfFinalAggregate(Vector128 resul // Return 0 return resultIndex.As().ToScalar(); } - - throw new NotSupportedException(); } [MethodImpl(MethodImplOptions.AggressiveInlining)] @@ -18500,13 +18476,11 @@ public static Vector128 Invoke(Vector128 x) { return Vector128.Ceiling(x.AsSingle()).As(); } - - if (typeof(T) == typeof(double)) + else { + Debug.Assert(typeof(T) == typeof(double)); return Vector128.Ceiling(x.AsDouble()).As(); } - - throw new NotSupportedException(); } public static Vector256 Invoke(Vector256 x) @@ -18515,13 +18489,11 @@ public static Vector256 Invoke(Vector256 x) { return Vector256.Ceiling(x.AsSingle()).As(); } - - if (typeof(T) == typeof(double)) + else { + Debug.Assert(typeof(T) == typeof(double)); return Vector256.Ceiling(x.AsDouble()).As(); } - - throw new NotSupportedException(); } public static Vector512 Invoke(Vector512 x) @@ -18530,13 +18502,11 @@ public static Vector512 Invoke(Vector512 x) { return Vector512.Ceiling(x.AsSingle()).As(); } - - if (typeof(T) == typeof(double)) + else { + Debug.Assert(typeof(T) == typeof(double)); return Vector512.Ceiling(x.AsDouble()).As(); } - - throw new NotSupportedException(); } } @@ -18552,13 +18522,11 @@ public static Vector128 Invoke(Vector128 x) { return Vector128.Floor(x.AsSingle()).As(); } - - if (typeof(T) == typeof(double)) + else { + Debug.Assert(typeof(T) == typeof(double)); return Vector128.Floor(x.AsDouble()).As(); } - - throw new NotSupportedException(); } public static Vector256 Invoke(Vector256 x) @@ -18567,13 +18535,11 @@ public static Vector256 Invoke(Vector256 x) { return Vector256.Floor(x.AsSingle()).As(); } - - if (typeof(T) == typeof(double)) + else { + Debug.Assert(typeof(T) == typeof(double)); return Vector256.Floor(x.AsDouble()).As(); } - - throw new NotSupportedException(); } public static Vector512 Invoke(Vector512 x) @@ -18582,13 +18548,11 @@ public static Vector512 Invoke(Vector512 x) { return Vector512.Floor(x.AsSingle()).As(); } - - if (typeof(T) == typeof(double)) + else { + Debug.Assert(typeof(T) == typeof(double)); return Vector512.Floor(x.AsDouble()).As(); } - - throw new NotSupportedException(); } } @@ -18609,9 +18573,10 @@ public static Vector128 Invoke(Vector128 x) Vector128.Floor(x.AsSingle()).As(), Vector128.Ceiling(x.AsSingle()).As()); } - - if (typeof(T) == typeof(double)) + else { + Debug.Assert(typeof(T) == typeof(double)); + if (Sse41.IsSupported) return Sse41.RoundToZero(x.AsDouble()).As(); if (AdvSimd.Arm64.IsSupported) return AdvSimd.Arm64.RoundToZero(x.AsDouble()).As(); @@ -18619,8 +18584,6 @@ public static Vector128 Invoke(Vector128 x) Vector128.Floor(x.AsDouble()).As(), Vector128.Ceiling(x.AsDouble()).As()); } - - throw new NotSupportedException(); } public static Vector256 Invoke(Vector256 x) @@ -18633,17 +18596,16 @@ public static Vector256 Invoke(Vector256 x) Vector256.Floor(x.AsSingle()).As(), Vector256.Ceiling(x.AsSingle()).As()); } - - if (typeof(T) == typeof(double)) + else { + Debug.Assert(typeof(T) == typeof(double)); + if (Avx.IsSupported) return Avx.RoundToZero(x.AsDouble()).As(); return Vector256.ConditionalSelect(Vector256.GreaterThanOrEqual(x, Vector256.Zero), Vector256.Floor(x.AsDouble()).As(), Vector256.Ceiling(x.AsDouble()).As()); } - - throw new NotSupportedException(); } public static Vector512 Invoke(Vector512 x) @@ -18656,28 +18618,187 @@ public static Vector512 Invoke(Vector512 x) Vector512.Floor(x.AsSingle()).As(), Vector512.Ceiling(x.AsSingle()).As()); } - - if (typeof(T) == typeof(double)) + else { + Debug.Assert(typeof(T) == typeof(double)); + if (Avx512F.IsSupported) return Avx512F.RoundScale(x.AsDouble(), 0b11).As(); return Vector512.ConditionalSelect(Vector512.GreaterThanOrEqual(x, Vector512.Zero), Vector512.Floor(x.AsDouble()).As(), Vector512.Ceiling(x.AsDouble()).As()); } - - throw new NotSupportedException(); } } /// T.PopCount(x) internal readonly struct PopCountOperator : IUnaryOperator where T : IBinaryInteger { - public static bool Vectorizable => false; // TODO: Vectorize + // TODO https://github.com/dotnet/runtime/issues/96162: Use AVX512 popcount operations when available + + public static bool Vectorizable => + // The fallback approach used for sizeof(T) <= 4 requires 64-bit shifts for + // sizeof(T) == 8, and such shifts aren't accelerated on today's hardware. Alternative + // approaches, such as doing two 32-bit operations and combining them were observed + // to not provide any meaningfuls speedup over scalar. So for now, we don't vectorize + // when sizeof(T) == 8. + sizeof(T) == 1 || sizeof(T) == 2 || sizeof(T) == 4; + public static T Invoke(T x) => T.PopCount(x); - public static Vector128 Invoke(Vector128 x) => throw new NotSupportedException(); - public static Vector256 Invoke(Vector256 x) => throw new NotSupportedException(); - public static Vector512 Invoke(Vector512 x) => throw new NotSupportedException(); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector128 Invoke(Vector128 x) + { + if (sizeof(T) == 1) + { + if (AdvSimd.IsSupported) + { + return AdvSimd.PopCount(x.AsByte()).As(); + } + + if (PackedSimd.IsSupported) + { + return PackedSimd.PopCount(x.AsByte()).As(); + } + + Vector128 c1 = Vector128.Create((byte)0x55); + Vector128 c2 = Vector128.Create((byte)0x33); + Vector128 c3 = Vector128.Create((byte)0x0F); + + // We don't have a per element shuffle for byte on some platforms. + // However, we do currently always have a 16-bit shift available and + // due to how the algorithm works, we don't need to worry about + // any bits that shift into the lower 8-bits from the upper 8-bits. + Vector128 tmp = x.AsByte(); + tmp -= (x.AsUInt16() >> 1).AsByte() & c1; + tmp = (tmp & c2) + ((tmp.AsUInt16() >> 2).AsByte() & c2); + return ((tmp + (tmp.AsUInt16() >> 4).AsByte()) & c3).As(); + } + + if (sizeof(T) == 2) + { + Vector128 c1 = Vector128.Create((ushort)0x5555); + Vector128 c2 = Vector128.Create((ushort)0x3333); + Vector128 c3 = Vector128.Create((ushort)0x0F0F); + Vector128 c4 = Vector128.Create((ushort)0x0101); + + Vector128 tmp = x.AsUInt16(); + tmp -= (tmp >> 1) & c1; + tmp = (tmp & c2) + ((tmp >> 2) & c2); + tmp = (((tmp + (tmp >> 4)) & c3) * c4) >> 8; + return tmp.As(); + } + + Debug.Assert(sizeof(T) == 4); + { + Vector128 c1 = Vector128.Create(0x55555555u); + Vector128 c2 = Vector128.Create(0x33333333u); + Vector128 c3 = Vector128.Create(0x0F0F0F0Fu); + Vector128 c4 = Vector128.Create(0x01010101u); + + Vector128 tmp = x.AsUInt32(); + tmp -= (tmp >> 1) & c1; + tmp = (tmp & c2) + ((tmp >> 2) & c2); + tmp = (((tmp + (tmp >> 4)) & c3) * c4) >> 24; + return tmp.As(); + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector256 Invoke(Vector256 x) + { + if (sizeof(T) == 1) + { + Vector256 c1 = Vector256.Create((byte)0x55); + Vector256 c2 = Vector256.Create((byte)0x33); + Vector256 c3 = Vector256.Create((byte)0x0F); + + // We don't have a per element shuffle for byte on some platforms. + // However, we do currently always have a 16-bit shift available and + // due to how the algorithm works, we don't need to worry about + // any bits that shift into the lower 8-bits from the upper 8-bits. + Vector256 tmp = x.AsByte(); + tmp -= (x.AsUInt16() >> 1).AsByte() & c1; + tmp = (tmp & c2) + ((tmp.AsUInt16() >> 2).AsByte() & c2); + return ((tmp + (tmp.AsUInt16() >> 4).AsByte()) & c3).As(); + } + + if (sizeof(T) == 2) + { + Vector256 c1 = Vector256.Create((ushort)0x5555); + Vector256 c2 = Vector256.Create((ushort)0x3333); + Vector256 c3 = Vector256.Create((ushort)0x0F0F); + Vector256 c4 = Vector256.Create((ushort)0x0101); + + Vector256 tmp = x.AsUInt16(); + tmp -= (tmp >> 1) & c1; + tmp = (tmp & c2) + ((tmp >> 2) & c2); + tmp = (((tmp + (tmp >> 4)) & c3) * c4) >> 8; + return tmp.As(); + } + + Debug.Assert(sizeof(T) == 4); + { + Vector256 c1 = Vector256.Create(0x55555555u); + Vector256 c2 = Vector256.Create(0x33333333u); + Vector256 c3 = Vector256.Create(0x0F0F0F0Fu); + Vector256 c4 = Vector256.Create(0x01010101u); + + Vector256 tmp = x.AsUInt32(); + tmp -= (tmp >> 1) & c1; + tmp = (tmp & c2) + ((tmp >> 2) & c2); + tmp = (((tmp + (tmp >> 4)) & c3) * c4) >> 24; + return tmp.As(); + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector512 Invoke(Vector512 x) + { + if (sizeof(T) == 1) + { + Vector512 c1 = Vector512.Create((byte)0x55); + Vector512 c2 = Vector512.Create((byte)0x33); + Vector512 c3 = Vector512.Create((byte)0x0F); + + // We don't have a per element shuffle for byte on some platforms. + // However, we do currently always have a 16-bit shift available and + // due to how the algorithm works, we don't need to worry about + // any bits that shift into the lower 8-bits from the upper 8-bits. + Vector512 tmp = x.AsByte(); + tmp -= (x.AsUInt16() >> 1).AsByte() & c1; + tmp = (tmp & c2) + ((tmp.AsUInt16() >> 2).AsByte() & c2); + return ((tmp + (tmp.AsUInt16() >> 4).AsByte()) & c3).As(); + } + + if (sizeof(T) == 2) + { + Vector512 c1 = Vector512.Create((ushort)0x5555); + Vector512 c2 = Vector512.Create((ushort)0x3333); + Vector512 c3 = Vector512.Create((ushort)0x0F0F); + Vector512 c4 = Vector512.Create((ushort)0x0101); + + Vector512 tmp = x.AsUInt16(); + tmp -= (tmp >> 1) & c1; + tmp = (tmp & c2) + ((tmp >> 2) & c2); + tmp = (((tmp + (tmp >> 4)) & c3) * c4) >> 8; + return tmp.As(); + } + + Debug.Assert(sizeof(T) == 4); + { + Vector512 c1 = Vector512.Create(0x55555555u); + Vector512 c2 = Vector512.Create(0x33333333u); + Vector512 c3 = Vector512.Create(0x0F0F0F0Fu); + Vector512 c4 = Vector512.Create(0x01010101u); + + Vector512 tmp = x.AsUInt32(); + tmp -= (tmp >> 1) & c1; + tmp = (tmp & c2) + ((tmp >> 2) & c2); + tmp = (((tmp + (tmp >> 4)) & c3) * c4) >> 24; + return tmp.As(); + } + } } /// T.LeadingZeroCount(x)