Implement the last of the approved cross platform hardware intrinsics, except shuffle (dotnet#63414)

* Exposing Sum<T> for Vector64/128/256<T>

* Adding support for ShiftLeft, ShiftRightArithmetic, and ShiftRightLogical to Vector<T> and Vector64/128/256<T>

* Adding support for Load, LoadAligned, LoadAlignedNonTemporal, and LoadUnsafe to Vector64/128/256<T>

* Adding support for Store, StoreAligned, StoreAlignedNonTemporal, and StoreUnsafe to Vector64/128/256<T>

* Adding support for ExtractMostSignificantBits to Vector64/128/256<T>

* Adding tests covering Vector64/128/256<T>.Sum

* Adding tests covering Vector64/128/256<T>.ShiftLeft, ShiftRightArithmetic, and ShiftRightLogical

* Moving System.Runtime.InteropServices.UnmanagedType down to System.Runtime so the `unmanaged` constraint can be used

* Adding tests covering Vector64/128/256<T>.Load, LoadAligned, LoadAlignedNonTemporal, and LoadUnsafe

* Fixing a few issues in the source and tests to ensure the right paths are being taken

* Adding tests covering Vector64/128/256<T>.Store, StoreAligned, StoreAlignedNonTemporal, and StoreUnsafe

* Adding tests covering Vector64/128/256<T>.ExtractMostSignificantBits

* Ensure AlignedAlloc is matched by AlignedFree

* Fixing a couple test issues and the handling of Scalar.ExtractMostSignificantBit for nint/nuint

* Applying formatting patch

* Ensure gtNewOperNode uses TYP_INT when dealing with the shiftCount

* Fixing a couple ARM64 node types

* Ensure the shift intrinsics use impPopStack().val on ARM64

* Responding to PR feedback
tannergooding authored Jan 13, 2022
1 parent 8fef95b commit cfe5e98
Showing 25 changed files with 16,231 additions and 163 deletions.
3 changes: 3 additions & 0 deletions src/coreclr/jit/compiler.h
@@ -3300,6 +3300,9 @@ class Compiler
GenTree* gtNewSimdSqrtNode(
var_types type, GenTree* op1, CorInfoType simdBaseJitType, unsigned simdSize, bool isSimdAsHWIntrinsic);

GenTree* gtNewSimdSumNode(
var_types type, GenTree* op1, CorInfoType simdBaseJitType, unsigned simdSize, bool isSimdAsHWIntrinsic);

GenTree* gtNewSimdUnOpNode(genTreeOps op,
var_types type,
GenTree* op1,
280 changes: 279 additions & 1 deletion src/coreclr/jit/gentree.cpp
@@ -18074,7 +18074,15 @@ GenTree* Compiler::gtNewSimdBinOpNode(genTreeOps op,
assert(op1->TypeIs(type, simdBaseType, genActualType(simdBaseType)));

assert(op2 != nullptr);
assert(op2->TypeIs(type, simdBaseType, genActualType(simdBaseType)));

if ((op == GT_LSH) || (op == GT_RSH) || (op == GT_RSZ))
{
assert(op2->TypeIs(TYP_INT));
}
else
{
assert(op2->TypeIs(type, simdBaseType, genActualType(simdBaseType)));
}

NamedIntrinsic intrinsic = NI_Illegal;
CORINFO_CLASS_HANDLE clsHnd = gtGetStructHandleForSIMD(type, simdBaseJitType);
@@ -18201,6 +18209,67 @@ GenTree* Compiler::gtNewSimdBinOpNode(genTreeOps op,
break;
}

case GT_LSH:
case GT_RSH:
case GT_RSZ:
{
assert(!varTypeIsByte(simdBaseType));
assert(!varTypeIsFloating(simdBaseType));
assert((op != GT_RSH) || !varTypeIsUnsigned(simdBaseType));

// "Over shifting" is platform-specific behavior. We will match the C# behavior,
// which requires masking the shift count with (sizeof(T) * 8) - 1 so that it cannot
// exceed the number of bits available in `T`. This is roughly equivalent to
// x % (sizeof(T) * 8), but that is "more expensive" and only equivalent for unsigned
// inputs, whereas we have a signed input and so negative values would differ.

unsigned shiftCountMask = (genTypeSize(simdBaseType) * 8) - 1;

if (op2->IsCnsIntOrI())
{
op2->AsIntCon()->gtIconVal &= shiftCountMask;
}
else
{
op2 = gtNewOperNode(GT_AND, TYP_INT, op2, gtNewIconNode(shiftCountMask));
op2 = gtNewSimdHWIntrinsicNode(TYP_SIMD16, op2, NI_SSE2_ConvertScalarToVector128Int32, CORINFO_TYPE_INT,
16, isSimdAsHWIntrinsic);
}

if (simdSize == 32)
{
assert(compIsaSupportedDebugOnly(InstructionSet_AVX2));

if (op == GT_LSH)
{
intrinsic = NI_AVX2_ShiftLeftLogical;
}
else if (op == GT_RSH)
{
intrinsic = NI_AVX2_ShiftRightArithmetic;
}
else
{
assert(op == GT_RSZ);
intrinsic = NI_AVX2_ShiftRightLogical;
}
}
else if (op == GT_LSH)
{
intrinsic = NI_SSE2_ShiftLeftLogical;
}
else if (op == GT_RSH)
{
intrinsic = NI_SSE2_ShiftRightArithmetic;
}
else
{
assert(op == GT_RSZ);
intrinsic = NI_SSE2_ShiftRightLogical;
}
break;
}
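// Illustrative sketch, not part of this change: masking the count with
// (sizeof(T) * 8) - 1 models the scalar C# shift semantics. For a hypothetical
// 32-bit element:
//
//   int32_t value = 1;
//   int     mask  = (sizeof(int32_t) * 8) - 1;  // 31
//   int32_t a     = value << (33 & mask);       // shifts by 1, matching C# `value << 33`
//   int32_t b     = value << (-1 & mask);       // shifts by 31, matching C# `value << -1`
//
// A modulo (-1 % 32 == -1 in both C++ and C#) would instead yield a negative, invalid
// count for the second case. For non-constant counts, the masked value is moved into a
// vector via ConvertScalarToVector128Int32 because the SSE2/AVX2 per-vector shift
// instructions read the count from the low bits of an XMM operand.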

case GT_MUL:
{
GenTree** broadcastOp = nullptr;
@@ -18469,6 +18538,98 @@ GenTree* Compiler::gtNewSimdBinOpNode(genTreeOps op,
break;
}

case GT_LSH:
case GT_RSH:
case GT_RSZ:
{
assert(!varTypeIsFloating(simdBaseType));
assert((op != GT_RSH) || !varTypeIsUnsigned(simdBaseType));

// "Over shifting" is platform-specific behavior. We will match the C# behavior,
// which requires masking the shift count with (sizeof(T) * 8) - 1 so that it cannot
// exceed the number of bits available in `T`. This is roughly equivalent to
// x % (sizeof(T) * 8), but that is "more expensive" and only equivalent for unsigned
// inputs, whereas we have a signed input and so negative values would differ.

unsigned shiftCountMask = (genTypeSize(simdBaseType) * 8) - 1;

if (op2->IsCnsIntOrI())
{
op2->AsIntCon()->gtIconVal &= shiftCountMask;

if ((simdSize == 8) && varTypeIsLong(simdBaseType))
{
if (op == GT_LSH)
{
intrinsic = NI_AdvSimd_ShiftLeftLogicalScalar;
}
else if (op == GT_RSH)
{
intrinsic = NI_AdvSimd_ShiftRightArithmeticScalar;
}
else
{
assert(op == GT_RSZ);
intrinsic = NI_AdvSimd_ShiftRightLogicalScalar;
}
}
else if (op == GT_LSH)
{
intrinsic = NI_AdvSimd_ShiftLeftLogical;
}
else if (op == GT_RSH)
{
intrinsic = NI_AdvSimd_ShiftRightArithmetic;
}
else
{
assert(op == GT_RSZ);
intrinsic = NI_AdvSimd_ShiftRightLogical;
}
}
else
{
op2 = gtNewOperNode(GT_AND, TYP_INT, op2, gtNewIconNode(shiftCountMask));

if (op != GT_LSH)
{
op2 = gtNewOperNode(GT_NEG, TYP_INT, op2);
}

op2 = gtNewSimdCreateBroadcastNode(type, op2, simdBaseJitType, simdSize, isSimdAsHWIntrinsic);

if ((simdSize == 8) && varTypeIsLong(simdBaseType))
{
if (op == GT_LSH)
{
intrinsic = NI_AdvSimd_ShiftLogicalScalar;
}
else if (op == GT_RSH)
{
intrinsic = NI_AdvSimd_ShiftArithmeticScalar;
}
else
{
intrinsic = NI_AdvSimd_ShiftLogicalScalar;
}
}
else if (op == GT_LSH)
{
intrinsic = NI_AdvSimd_ShiftLogical;
}
else if (op == GT_RSH)
{
intrinsic = NI_AdvSimd_ShiftArithmetic;
}
else
{
assert(op == GT_RSZ);
intrinsic = NI_AdvSimd_ShiftLogical;
}
}
break;
}
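// Illustrative note, not part of this change: AdvSimd's variable-count shifts
// (ShiftLogical / ShiftArithmetic, i.e. USHL / SSHL) shift left for positive per-element
// counts and right for negative counts, which is why the non-constant path above negates
// the masked count for right shifts before broadcasting it. Conceptually, per element:
//
//   count     = (op == GT_LSH) ? maskedCount : -maskedCount;
//   result[i] = (count >= 0) ? (value[i] << count) : (value[i] >> -count);
//
// where `>>` is a logical or arithmetic shift depending on the intrinsic selected.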

case GT_MUL:
{
assert(!varTypeIsLong(simdBaseType));
@@ -20596,6 +20757,123 @@ GenTree* Compiler::gtNewSimdSqrtNode(
return gtNewSimdHWIntrinsicNode(type, op1, intrinsic, simdBaseJitType, simdSize, isSimdAsHWIntrinsic);
}
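
// Descriptive note, not part of this change: gtNewSimdSumNode builds a node that adds
// every element of the SIMD operand `op1` into a single scalar of type `type`. A scalar
// reference for the intended semantics, for a vector of N elements, would be:
//
//   T sum = 0;
//   for (unsigned i = 0; i < N; i++)
//   {
//       sum += op1[i];
//   }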

GenTree* Compiler::gtNewSimdSumNode(
var_types type, GenTree* op1, CorInfoType simdBaseJitType, unsigned simdSize, bool isSimdAsHWIntrinsic)
{
assert(IsBaselineSimdIsaSupportedDebugOnly());

var_types simdType = getSIMDTypeForSize(simdSize);
assert(varTypeIsSIMD(simdType));

assert(op1 != nullptr);
assert(op1->TypeIs(simdType));

var_types simdBaseType = JitType2PreciseVarType(simdBaseJitType);
assert(varTypeIsArithmetic(simdBaseType));

NamedIntrinsic intrinsic = NI_Illegal;
GenTree* tmp = nullptr;
CORINFO_CLASS_HANDLE clsHnd = gtGetStructHandleForSIMD(simdType, simdBaseJitType);

#if defined(TARGET_XARCH)
assert(!varTypeIsByte(simdBaseType) && !varTypeIsLong(simdBaseType));

// HorizontalAdd combines pairs so we need log2(vectorLength) passes to sum all elements together.
unsigned vectorLength = getSIMDVectorLength(simdSize, simdBaseType);
int haddCount = genLog2(vectorLength);

if (simdSize == 32)
{
// Minus 1 because, for the last pass, we split the vector into low and high halves and add them together.
haddCount -= 1;

if (varTypeIsFloating(simdBaseType))
{
assert(compIsaSupportedDebugOnly(InstructionSet_AVX));
intrinsic = NI_AVX_HorizontalAdd;
}
else
{
assert(compIsaSupportedDebugOnly(InstructionSet_AVX2));
intrinsic = NI_AVX2_HorizontalAdd;
}
}
else if (varTypeIsFloating(simdBaseType))
{
assert(compIsaSupportedDebugOnly(InstructionSet_SSE3));
intrinsic = NI_SSE3_HorizontalAdd;
}
else
{
assert(compIsaSupportedDebugOnly(InstructionSet_SSSE3));
intrinsic = NI_SSSE3_HorizontalAdd;
}

for (int i = 0; i < haddCount; i++)
{
op1 = impCloneExpr(op1, &tmp, clsHnd, (unsigned)CHECK_SPILL_ALL, nullptr DEBUGARG("Clone op1 for vector sum"));
op1 = gtNewSimdAsHWIntrinsicNode(simdType, op1, tmp, intrinsic, simdBaseJitType, simdSize);
}

if (simdSize == 32)
{
intrinsic = (simdBaseType == TYP_FLOAT) ? NI_SSE_Add : NI_SSE2_Add;

op1 = impCloneExpr(op1, &tmp, clsHnd, (unsigned)CHECK_SPILL_ALL, nullptr DEBUGARG("Clone op1 for vector sum"));
op1 = gtNewSimdAsHWIntrinsicNode(TYP_SIMD16, op1, gtNewIconNode(0x01, TYP_INT), NI_AVX_ExtractVector128,
simdBaseJitType, simdSize);

tmp = gtNewSimdAsHWIntrinsicNode(simdType, tmp, NI_Vector256_GetLower, simdBaseJitType, simdSize);
op1 = gtNewSimdAsHWIntrinsicNode(TYP_SIMD16, op1, tmp, intrinsic, simdBaseJitType, 16);
}

return gtNewSimdAsHWIntrinsicNode(type, op1, NI_Vector128_ToScalar, simdBaseJitType, simdSize);
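
// Worked example, not part of this change: for a Vector256<float> v = [a0..a7],
// vectorLength is 8, so haddCount = log2(8) - 1 = 2:
//
//   pass 1: HorizontalAdd(v, v) -> per 128-bit lane: [a0+a1, a2+a3, a0+a1, a2+a3] / [a4+a5, a6+a7, ...]
//   pass 2: HorizontalAdd(., .) -> low lane holds a0+a1+a2+a3, high lane holds a4+a5+a6+a7
//   final : ExtractVector128(., 1) added to GetLower(.) puts the full sum in element 0,
//           which ToScalar then returns.
//
// HorizontalAdd only combines within 128-bit lanes, which is why the last step splits the
// 256-bit vector into halves and adds them rather than doing another horizontal add.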
#elif defined(TARGET_ARM64)
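// Note, not part of this change: on ARM64, AddAcross (ADDV) sums every element of the
// vector in a single instruction for 8/16/32-bit integer elements; float falls back to
// log2(vectorLength) pairwise adds; and 64-bit element types are handled with a single
// AddPairwiseScalar.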
switch (simdBaseType)
{
case TYP_BYTE:
case TYP_UBYTE:
case TYP_SHORT:
case TYP_USHORT:
case TYP_INT:
case TYP_UINT:
{
tmp = gtNewSimdAsHWIntrinsicNode(simdType, op1, NI_AdvSimd_Arm64_AddAcross, simdBaseJitType, simdSize);
return gtNewSimdAsHWIntrinsicNode(type, tmp, NI_Vector64_ToScalar, simdBaseJitType, 8);
}
case TYP_FLOAT:
{
unsigned vectorLength = getSIMDVectorLength(simdSize, simdBaseType);
int haddCount = genLog2(vectorLength);

for (int i = 0; i < haddCount; i++)
{
op1 = impCloneExpr(op1, &tmp, clsHnd, (unsigned)CHECK_SPILL_ALL,
nullptr DEBUGARG("Clone op1 for vector sum"));
op1 = gtNewSimdAsHWIntrinsicNode(simdType, op1, tmp, NI_AdvSimd_Arm64_AddPairwise, simdBaseJitType,
simdSize);
}

return gtNewSimdAsHWIntrinsicNode(type, op1, NI_Vector128_ToScalar, simdBaseJitType, simdSize);
}
case TYP_DOUBLE:
case TYP_LONG:
case TYP_ULONG:
{
op1 = gtNewSimdAsHWIntrinsicNode(TYP_SIMD8, op1, NI_AdvSimd_Arm64_AddPairwiseScalar, simdBaseJitType,
simdSize);
return gtNewSimdAsHWIntrinsicNode(type, op1, NI_Vector64_ToScalar, simdBaseJitType, 8);
}
default:
{
unreached();
}
}
#else
#error Unsupported platform
#endif // !TARGET_XARCH && !TARGET_ARM64
}

GenTree* Compiler::gtNewSimdUnOpNode(genTreeOps op,
var_types type,
GenTree* op1,