diff --git a/src/coreclr/jit/gentree.cpp b/src/coreclr/jit/gentree.cpp
index 78ed00c13041cf..9edc37b5b242ed 100644
--- a/src/coreclr/jit/gentree.cpp
+++ b/src/coreclr/jit/gentree.cpp
@@ -32109,6 +32109,15 @@ bool GenTree::CanDivOrModPossiblyOverflow(Compiler* comp) const
     return true;
 }
 
+//------------------------------------------------------------------------
+// gtFoldExprHWIntrinsic: Attempt to fold a HWIntrinsic
+//
+// Arguments:
+//    tree - HWIntrinsic to fold
+//
+// Return Value:
+//    The folded expression if folding was possible, else the original tree
+//
 #if defined(FEATURE_HW_INTRINSICS)
 GenTree* Compiler::gtFoldExprHWIntrinsic(GenTreeHWIntrinsic* tree)
 {
@@ -32252,7 +32261,8 @@ GenTree* Compiler::gtFoldExprHWIntrinsic(GenTreeHWIntrinsic* tree)
     // We shouldn't find AND_NOT nodes since it should only be produced in lowering
     assert(oper != GT_AND_NOT);
 
-#if defined(FEATURE_MASKED_HW_INTRINSICS) && defined(TARGET_XARCH)
+#ifdef FEATURE_MASKED_HW_INTRINSICS
+#ifdef TARGET_XARCH
     if (GenTreeHWIntrinsic::OperIsBitwiseHWIntrinsic(oper))
     {
         // Comparisons that produce masks lead to more verbose trees than
@@ -32370,7 +32380,75 @@ GenTree* Compiler::gtFoldExprHWIntrinsic(GenTreeHWIntrinsic* tree)
             }
         }
     }
-#endif // FEATURE_MASKED_HW_INTRINSICS && TARGET_XARCH
+#elif defined(TARGET_ARM64)
+    // Check whether the tree can be folded into its mask variant
+    if (HWIntrinsicInfo::HasAllMaskVariant(tree->GetHWIntrinsicId()))
+    {
+        NamedIntrinsic maskVariant = HWIntrinsicInfo::GetMaskVariant(tree->GetHWIntrinsicId());
+
+        assert(opCount == (size_t)HWIntrinsicInfo::lookupNumArgs(maskVariant));
+
+        // Check all operands are valid
+        bool canFold = true;
+        if (ni == NI_Sve_ConditionalSelect)
+        {
+            assert(varTypeIsMask(op1));
+            canFold = (op2->OperIsConvertMaskToVector() && op3->OperIsConvertMaskToVector());
+        }
+        else
+        {
+            for (size_t i = 1; i <= opCount && canFold; i++)
+            {
+                canFold &= tree->Op(i)->OperIsConvertMaskToVector();
+            }
+        }
+
+        if (canFold)
+        {
+            // Convert all the operands to masks
+            for (size_t i = 1; i <= opCount; i++)
+            {
+                if (tree->Op(i)->OperIsConvertMaskToVector())
+                {
+                    // Replace the conversion with its op1 (the underlying mask).
+                    tree->Op(i) = tree->Op(i)->AsHWIntrinsic()->Op(1);
+                }
+                else if (tree->Op(i)->IsVectorZero())
+                {
+                    // Replace the vector of zeroes with a mask of zeroes.
+                    tree->Op(i) = gtNewSimdFalseMaskByteNode();
+                    tree->Op(i)->SetMorphed(this);
+                }
+                assert(varTypeIsMask(tree->Op(i)));
+            }
+
+            // Switch to the mask variant
+            switch (opCount)
+            {
+                case 1:
+                    tree->ResetHWIntrinsicId(maskVariant, tree->Op(1));
+                    break;
+                case 2:
+                    tree->ResetHWIntrinsicId(maskVariant, tree->Op(1), tree->Op(2));
+                    break;
+                case 3:
+                    tree->ResetHWIntrinsicId(maskVariant, this, tree->Op(1), tree->Op(2), tree->Op(3));
+                    break;
+                default:
+                    unreached();
+            }
+
+            tree->gtType = TYP_MASK;
+            tree->SetMorphed(this);
+            tree = gtNewSimdCvtMaskToVectorNode(retType, tree, simdBaseJitType, simdSize)->AsHWIntrinsic();
+            tree->SetMorphed(this);
+            op1 = tree->Op(1);
+            op2 = nullptr;
+            op3 = nullptr;
+        }
+    }
+#endif // TARGET_ARM64
+#endif // FEATURE_MASKED_HW_INTRINSICS
 
     switch (ni)
     {
@@ -33559,7 +33637,7 @@ GenTree* Compiler::gtFoldExprHWIntrinsic(GenTreeHWIntrinsic* tree)
                 // op2 = op2 & op1
                 op2->AsVecCon()->EvaluateBinaryInPlace(GT_AND, false, simdBaseType, op1->AsVecCon());
 
-                // op3 = op2 & ~op1
+                // op3 = op3 & ~op1
                 op3->AsVecCon()->EvaluateBinaryInPlace(GT_AND_NOT, false, simdBaseType, op1->AsVecCon());
 
                 // op2 = op2 | op3
@@ -33572,8 +33650,8 @@ GenTree* Compiler::gtFoldExprHWIntrinsic(GenTreeHWIntrinsic* tree)
 
 #if defined(TARGET_ARM64)
         case NI_Sve_ConditionalSelect:
+        case NI_Sve_ConditionalSelect_Predicates:
         {
-            assert(!varTypeIsMask(retType));
             assert(varTypeIsMask(op1));
 
             if (cnsNode != op1)
@@ -33602,10 +33680,11 @@ GenTree* Compiler::gtFoldExprHWIntrinsic(GenTreeHWIntrinsic* tree)
 
             if (op2->IsCnsVec() && op3->IsCnsVec())
             {
+                assert(ni == NI_Sve_ConditionalSelect);
                 assert(op2->gtType == TYP_SIMD16);
                 assert(op3->gtType == TYP_SIMD16);
 
-                simd16_t op1SimdVal;
+                simd16_t op1SimdVal = {};
                 EvaluateSimdCvtMaskToVector(simdBaseType, &op1SimdVal, op1->AsMskCon()->gtSimdMaskVal);
 
                 // op2 = op2 & op1
@@ -33614,7 +33693,7 @@ GenTree* Compiler::gtFoldExprHWIntrinsic(GenTreeHWIntrinsic* tree)
                                                 op1SimdVal);
                 op2->AsVecCon()->gtSimd16Val = result;
 
-                // op3 = op2 & ~op1
+                // op3 = op3 & ~op1
                 result = {};
                 EvaluateBinarySimd(GT_AND_NOT, false, simdBaseType, &result, op3->AsVecCon()->gtSimd16Val,
                                    op1SimdVal);
@@ -33625,6 +33704,30 @@ GenTree* Compiler::gtFoldExprHWIntrinsic(GenTreeHWIntrinsic* tree)
 
                 resultNode = op2;
             }
+            else if (op2->IsCnsMsk() && op3->IsCnsMsk())
+            {
+                assert(ni == NI_Sve_ConditionalSelect_Predicates);
+
+                // op2 = op2 & op1
+                simdmask_t result = {};
+                EvaluateBinaryMask(GT_AND, false, simdBaseType, &result, op2->AsMskCon()->gtSimdMaskVal,
+                                   op1->AsMskCon()->gtSimdMaskVal);
+                op2->AsMskCon()->gtSimdMaskVal = result;
+
+                // op3 = op3 & ~op1
+                result = {};
+                EvaluateBinaryMask(GT_AND_NOT, false, simdBaseType, &result,
+                                   op3->AsMskCon()->gtSimdMaskVal, op1->AsMskCon()->gtSimdMaskVal);
+                op3->AsMskCon()->gtSimdMaskVal = result;
+
+                // op2 = op2 | op3
+                result = {};
+                EvaluateBinaryMask(GT_OR, false, simdBaseType, &result, op2->AsMskCon()->gtSimdMaskVal,
+                                   op3->AsMskCon()->gtSimdMaskVal);
+                op2->AsMskCon()->gtSimdMaskVal = result;
+
+                resultNode = op2;
+            }
             break;
         }
 #endif // TARGET_ARM64
diff --git a/src/coreclr/jit/simd.h b/src/coreclr/jit/simd.h
index f6da9993f90d45..9ade53c5f60cf9 100644
--- a/src/coreclr/jit/simd.h
+++ b/src/coreclr/jit/simd.h
@@ -1090,6 +1090,12 @@ void EvaluateBinaryMask(
             break;
         }
 
+        case 16:
+        {
+            bitMask = 0x0001000100010001;
+            break;
+        }
+
         default:
         {
             unreached();
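
For reference, the ConditionalSelect constant folding above (for both NI_Sve_ConditionalSelect and NI_Sve_ConditionalSelect_Predicates) evaluates the usual bitwise select identity lane by lane: result = (op2 & op1) | (op3 & ~op1). A minimal C# sketch of that identity on plain bits, illustrative only and not part of the patch:

    // Illustrative only: the select identity the folding above evaluates,
    // applied here to a single 64-bit value instead of a JIT simdmask_t.
    static ulong SelectBits(ulong mask, ulong whenTrue, ulong whenFalse)
        => (whenTrue & mask) | (whenFalse & ~mask);

    // Picks bits of 0xFF where the mask is set, bits of 0x00 elsewhere: prints "A" (0b1010).
    System.Console.WriteLine(SelectBits(0b1010UL, 0xFFUL, 0x00UL).ToString("X"));
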
diff --git a/src/tests/JIT/opt/SVE/PredicateInstructions.cs b/src/tests/JIT/opt/SVE/PredicateInstructions.cs
index b1336674f1638b..287e90b30e4fe6 100644
--- a/src/tests/JIT/opt/SVE/PredicateInstructions.cs
+++ b/src/tests/JIT/opt/SVE/PredicateInstructions.cs
@@ -17,110 +17,164 @@ public static void TestPredicateInstructions()
     {
         if (Sve.IsSupported)
         {
-            ZipLow();
-            ZipHigh();
-            UnzipOdd();
-            UnzipEven();
-            TransposeOdd();
-            TransposeEven();
-            ReverseElement();
-            And();
-            BitwiseClear();
-            Xor();
-            Or();
-            ConditionalSelect();
+            Vector<sbyte> vecsb = Vector.Create<sbyte>(2);
+            Vector<short> vecs = Vector.Create<short>(2);
+            Vector<ushort> vecus = Vector.Create<ushort>(2);
+            Vector<int> veci = Vector.Create<int>(3);
+            Vector<uint> vecui = Vector.Create<uint>(5);
+            Vector<long> vecl = Vector.Create<long>(7);
+
+            ZipLowMask(vecs, vecs);
+            ZipHighMask(vecui, vecui);
+            UnzipOddMask(vecs, vecs);
+            UnzipEvenMask(vecsb, vecsb);
+            TransposeEvenMask(vecl, vecl);
+            TransposeOddMask(vecs, vecs);
+            ReverseElementMask(vecs, vecs);
+            AndMask(vecs, vecs);
+            BitwiseClearMask(vecs, vecs);
+            XorMask(veci, veci);
+            OrMask(vecs, vecs);
+            ConditionalSelectMask(veci, veci, veci);
+
+            UnzipEvenZipLowMask(vecs, vecs);
+            TransposeEvenAndMask(vecs, vecs, vecs);
+        }
     }
 
+    // These methods should compile down to the predicate variants of the instructions.
+    // Sve intrinsics that return masks (Compare*) or take mask arguments (CreateBreakAfterMask)
+    // are used to ensure the operands and results are kept as masks.
+
     [MethodImpl(MethodImplOptions.NoInlining)]
-    static Vector<short> ZipLow()
+    static Vector<short> ZipLowMask(Vector<short> a, Vector<short> b)
     {
-        return Sve.ZipLow(Vector<short>.Zero, Sve.CreateTrueMaskInt16());
+        //ARM64-FULL-LINE: zip1 {{p[0-9]+}}.h, {{p[0-9]+}}.h, {{p[0-9]+}}.h
+        return Sve.ZipLow(Sve.CompareGreaterThan(a, b), Sve.CompareEqual(a, b));
     }
 
     [MethodImpl(MethodImplOptions.NoInlining)]
-    static Vector<uint> ZipHigh()
+    static Vector<uint> ZipHighMask(Vector<uint> a, Vector<uint> b)
     {
-        return Sve.ZipHigh(Sve.CreateTrueMaskUInt32(), Sve.CreateTrueMaskUInt32());
+        //ARM64-FULL-LINE: zip2 {{p[0-9]+}}.s, {{p[0-9]+}}.s, {{p[0-9]+}}.s
+        return Sve.CreateBreakAfterMask(Sve.ZipHigh(Sve.CompareGreaterThan(a, b), Sve.CompareEqual(a, b)), Sve.CreateTrueMaskUInt32());
     }
 
     [MethodImpl(MethodImplOptions.NoInlining)]
-    static Vector<sbyte> UnzipEven()
+    static Vector<sbyte> UnzipEvenMask(Vector<sbyte> a, Vector<sbyte> b)
     {
-        return Sve.UnzipEven(Sve.CreateTrueMaskSByte(), Vector<sbyte>.Zero);
+        //ARM64-FULL-LINE: uzp1 {{p[0-9]+}}.b, {{p[0-9]+}}.b, {{p[0-9]+}}.b
+        return Sve.UnzipEven(Sve.CompareGreaterThan(a, b), Sve.CompareEqual(a, b));
     }
 
     [MethodImpl(MethodImplOptions.NoInlining)]
-    static Vector<short> UnzipOdd()
+    static Vector<short> UnzipOddMask(Vector<short> a, Vector<short> b)
     {
-        return Sve.UnzipOdd(Sve.CreateTrueMaskInt16(), Sve.CreateFalseMaskInt16());
+        //ARM64-FULL-LINE: uzp2 {{p[0-9]+}}.h, {{p[0-9]+}}.h, {{p[0-9]+}}.h
+        return Sve.CreateBreakAfterMask(Sve.UnzipOdd(Sve.CompareGreaterThan(a, b), Sve.CompareEqual(a, b)), Sve.CreateTrueMaskInt16());
     }
 
     [MethodImpl(MethodImplOptions.NoInlining)]
-    static Vector<long> TransposeEven()
+    static Vector<long> TransposeEvenMask(Vector<long> a, Vector<long> b)
     {
-        return Sve.TransposeEven(Sve.CreateFalseMaskInt64(), Sve.CreateTrueMaskInt64());
+        //ARM64-FULL-LINE: trn1 {{p[0-9]+}}.d, {{p[0-9]+}}.d, {{p[0-9]+}}.d
+        return Sve.CreateBreakAfterMask(Sve.TransposeEven(Sve.CompareGreaterThan(a, b), Sve.CompareEqual(a, b)), Sve.CreateFalseMaskInt64());
     }
 
     [MethodImpl(MethodImplOptions.NoInlining)]
-    static Vector<short> TransposeOdd()
+    static Vector<short> TransposeOddMask(Vector<short> a, Vector<short> b)
     {
-        return Sve.TransposeOdd(Vector<short>.Zero, Sve.CreateTrueMaskInt16());
+        //ARM64-FULL-LINE: trn2 {{p[0-9]+}}.h, {{p[0-9]+}}.h, {{p[0-9]+}}.h
+        return Sve.TransposeOdd(Sve.CompareGreaterThan(a, b), Sve.CompareEqual(a, b));
     }
 
     [MethodImpl(MethodImplOptions.NoInlining)]
-    static Vector<short> ReverseElement()
+    static Vector<short> ReverseElementMask(Vector<short> a, Vector<short> b)
     {
-        return Sve.ReverseElement(Sve.CreateTrueMaskInt16());
+        //ARM64-FULL-LINE: rev {{p[0-9]+}}.h, {{p[0-9]+}}.h
+        return Sve.CreateBreakAfterMask(Sve.ReverseElement(Sve.CompareGreaterThan(a, b)), Sve.CreateFalseMaskInt16());
     }
 
     [MethodImpl(MethodImplOptions.NoInlining)]
-    static Vector<short> And()
+    static Vector<short> AndMask(Vector<short> a, Vector<short> b)
     {
-        return Sve.ConditionalSelect(
-            Sve.CreateTrueMaskInt16(),
-            Sve.And(Sve.CreateTrueMaskInt16(), Sve.CreateTrueMaskInt16()),
-            Vector<short>.Zero
-        );
+        //ARM64-FULL-LINE: and {{p[0-9]+}}.b, {{p[0-9]+}}/z, {{p[0-9]+}}.b, {{p[0-9]+}}.b
+        return Sve.CreateBreakAfterMask(
+            Sve.ConditionalSelect(
+                Sve.CreateTrueMaskInt16(),
+                Sve.And(Sve.CompareGreaterThan(a, b), Sve.CompareEqual(a, b)),
+                Vector<short>.Zero),
+            Sve.CreateFalseMaskInt16());
     }
 
     [MethodImpl(MethodImplOptions.NoInlining)]
-    static Vector<short> BitwiseClear()
+    static Vector<short> BitwiseClearMask(Vector<short> a, Vector<short> b)
     {
+        //ARM64-FULL-LINE: bic {{p[0-9]+}}.b, {{p[0-9]+}}/z, {{p[0-9]+}}.b, {{p[0-9]+}}.b
         return Sve.ConditionalSelect(
-            Sve.CreateFalseMaskInt16(),
-            Sve.BitwiseClear(Sve.CreateTrueMaskInt16(), Sve.CreateTrueMaskInt16()),
-            Vector<short>.Zero
-        );
+            Sve.CreateTrueMaskInt16(),
+            Sve.BitwiseClear(Sve.CompareGreaterThan(a, b), Sve.CompareEqual(a, b)),
+            Vector<short>.Zero);
     }
 
     [MethodImpl(MethodImplOptions.NoInlining)]
-    static Vector<int> Xor()
+    static Vector<int> XorMask(Vector<int> a, Vector<int> b)
     {
-        return Sve.ConditionalSelect(
-            Sve.CreateTrueMaskInt32(),
-            Sve.Xor(Sve.CreateTrueMaskInt32(), Sve.CreateTrueMaskInt32()),
-            Vector<int>.Zero
-        );
+        //ARM64-FULL-LINE: eor {{p[0-9]+}}.b, {{p[0-9]+}}/z, {{p[0-9]+}}.b, {{p[0-9]+}}.b
+        return Sve.CreateBreakAfterMask(
+            Sve.ConditionalSelect(
+                Sve.CreateTrueMaskInt32(),
+                Sve.Xor(Sve.CompareGreaterThan(a, b), Sve.CompareEqual(a, b)),
+                Vector<int>.Zero),
+            Sve.CreateFalseMaskInt32());
     }
 
     [MethodImpl(MethodImplOptions.NoInlining)]
-    static Vector<short> Or()
+    static Vector<short> OrMask(Vector<short> a, Vector<short> b)
     {
+        //ARM64-FULL-LINE: orr {{p[0-9]+}}.b, {{p[0-9]+}}/z, {{p[0-9]+}}.b, {{p[0-9]+}}.b
         return Sve.ConditionalSelect(
-            Sve.CreateTrueMaskInt16(),
-            Sve.Or(Sve.CreateTrueMaskInt16(), Sve.CreateTrueMaskInt16()),
-            Vector<short>.Zero
-        );
+            Sve.CreateTrueMaskInt16(),
+            Sve.Or(Sve.CompareGreaterThan(a, b), Sve.CompareEqual(a, b)),
+            Vector<short>.Zero);
     }
 
     [MethodImpl(MethodImplOptions.NoInlining)]
-    static Vector<int> ConditionalSelect()
+    static Vector<int> ConditionalSelectMask(Vector<int> v, Vector<int> a, Vector<int> b)
     {
-        return Sve.ConditionalSelect(
-            Vector<int>.Zero,
-            Sve.CreateFalseMaskInt32(),
-            Sve.CreateTrueMaskInt32()
-        );
+        // Use a passed-in vector for the mask to prevent the select being optimised away
+        //ARM64-FULL-LINE: sel {{p[0-9]+}}.b, {{p[0-9]+}}, {{p[0-9]+}}.b, {{p[0-9]+}}.b
+        return Sve.CreateBreakAfterMask(
+            Sve.ConditionalSelect(v, Sve.CompareGreaterThan(a, b), Sve.CompareEqual(a, b)),
+            Sve.CreateFalseMaskInt32());
+    }
+
+    // These methods have multiple uses of the predicate variants.
+
+    [MethodImpl(MethodImplOptions.NoInlining)]
+    static Vector<short> UnzipEvenZipLowMask(Vector<short> a, Vector<short> b)
+    {
+        //ARM64-FULL-LINE: zip1 {{p[0-9]+}}.h, {{p[0-9]+}}.h, {{p[0-9]+}}.h
+        //ARM64-FULL-LINE: uzp1 {{p[0-9]+}}.h, {{p[0-9]+}}.h, {{p[0-9]+}}.h
+        return Sve.CreateBreakAfterMask(
+            Sve.UnzipEven(
+                Sve.ZipLow(Sve.CompareGreaterThan(a, b), Sve.CompareEqual(a, b)),
+                Sve.CompareLessThan(a, b)),
+            Sve.CreateTrueMaskInt16());
+    }
+
+    [MethodImpl(MethodImplOptions.NoInlining)]
+    static Vector<short> TransposeEvenAndMask(Vector<short> v, Vector<short> a, Vector<short> b)
+    {
+        //ARM64-FULL-LINE: and {{p[0-9]+}}.b, {{p[0-9]+}}/z, {{p[0-9]+}}.b, {{p[0-9]+}}.b
+        //ARM64-FULL-LINE: trn1 {{p[0-9]+}}.h, {{p[0-9]+}}.h, {{p[0-9]+}}.h
+        return Sve.TransposeEven(
+            Sve.CompareGreaterThan(a, b),
+            Sve.ConditionalSelect(
+                Sve.CreateTrueMaskInt16(),
+                Sve.And(Sve.CompareGreaterThan(a, b), Sve.CompareEqual(a, b)),
+                Sve.CompareLessThan(a, b)));
     }
-}
\ No newline at end of file
+}
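
As a usage sketch (not part of the patch): a minimal standalone program that exercises the same ZipLow-on-compare-results pattern the test checks, so the folded predicate form of zip1 can be inspected in the disassembly on an SVE-capable arm64 machine running a runtime with these changes. The class and method names here are illustrative, not from the repository.

    using System;
    using System.Numerics;
    using System.Runtime.CompilerServices;
    using System.Runtime.Intrinsics.Arm;

    class ZipLowPredicateDemo
    {
        [MethodImpl(MethodImplOptions.NoInlining)]
        static Vector<short> ZipLowOfCompares(Vector<short> a, Vector<short> b)
        {
            // Both compares produce masks; with this change the JIT can keep them as
            // predicates and use the mask variant of ZipLow instead of converting the
            // masks to vectors and back.
            return Sve.ZipLow(Sve.CompareGreaterThan(a, b), Sve.CompareEqual(a, b));
        }

        static void Main()
        {
            if (Sve.IsSupported)
            {
                Vector<short> a = Vector.Create((short)2);
                Vector<short> b = Vector.Create((short)1);
                Console.WriteLine(ZipLowOfCompares(a, b));
            }
            else
            {
                Console.WriteLine("SVE is not supported on this machine.");
            }
        }
    }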