From 06323f61a74f5c447acae70912f378d13fa7df23 Mon Sep 17 00:00:00 2001 From: Tanner Gooding Date: Tue, 16 Jul 2024 07:33:34 -0700 Subject: [PATCH] Ensure we check the right types and handle some edge cases --- src/coreclr/jit/hwintrinsicxarch.cpp | 1 + src/coreclr/jit/morph.cpp | 40 ++++++++++--------------- src/coreclr/jit/valuenum.cpp | 44 +++++++++++++++------------- 3 files changed, 40 insertions(+), 45 deletions(-) diff --git a/src/coreclr/jit/hwintrinsicxarch.cpp b/src/coreclr/jit/hwintrinsicxarch.cpp index 90ac95ac52a3bc..59b35027d71bbd 100644 --- a/src/coreclr/jit/hwintrinsicxarch.cpp +++ b/src/coreclr/jit/hwintrinsicxarch.cpp @@ -882,6 +882,7 @@ NamedIntrinsic HWIntrinsicInfo::lookupEvexMaskId(NamedIntrinsic intrinsic) case NI_SSE2_CompareLessThan: case NI_SSE42_CompareLessThan: case NI_AVX_CompareLessThan: + case NI_AVX2_CompareLessThan: { return NI_EVEX_CompareLessThanMask; } diff --git a/src/coreclr/jit/morph.cpp b/src/coreclr/jit/morph.cpp index a716b56ba0438b..9d2b493a67e5cf 100644 --- a/src/coreclr/jit/morph.cpp +++ b/src/coreclr/jit/morph.cpp @@ -9939,15 +9939,19 @@ GenTree* Compiler::fgOptimizeHWIntrinsic(GenTreeHWIntrinsic* node) GenTreeHWIntrinsic* op1Intrin = op1->AsHWIntrinsic(); - if (!isCndSel) + if (isCndSel) { - // CndSel knows how to handle mismatched mask sizes, but not all consumers can - - if (genTypeSize(op1Intrin->GetSimdBaseType()) != genTypeSize(simdBaseType)) + if (op1->OperIsConvertMaskToVector()) { + // We're already in the "correct" shape for lowering to take advantage break; } } + else if (genTypeSize(op1Intrin->GetSimdBaseType()) != genTypeSize(simdBaseType)) + { + // CndSel knows how to handle mismatched mask sizes, but not all consumers can + break; + } if (!canUseEvexEncoding()) { @@ -9956,8 +9960,7 @@ GenTree* Compiler::fgOptimizeHWIntrinsic(GenTreeHWIntrinsic* node) // We have something expecting a mask and have a case where we could be producing a mask directly - NamedIntrinsic op1IntrinId = op1Intrin->GetHWIntrinsicId(); - + NamedIntrinsic op1IntrinId = op1Intrin->GetHWIntrinsicId(); NamedIntrinsic evexIntrinId = HWIntrinsicInfo::lookupEvexMaskId(op1IntrinId); if (evexIntrinId != NI_Illegal) @@ -9967,18 +9970,7 @@ GenTree* Compiler::fgOptimizeHWIntrinsic(GenTreeHWIntrinsic* node) op1Intrin->ChangeHWIntrinsicId(evexIntrinId); op1Intrin->gtType = TYP_MASK; -#ifdef DEBUG - // We want to remorph the nodes introduced below, so clear the flag - - auto resetMorphedFlag = [](GenTree** slot, fgWalkData* data) -> fgWalkResult { - (*slot)->gtDebugFlags &= ~GTF_DEBUG_NODE_MORPHED; - return WALK_CONTINUE; - }; - - fgWalkTreePost(&op1, resetMorphedFlag); -#endif // DEBUG - - switch (op1IntrinId) + switch (evexIntrinId) { case NI_EVEX_AndMask: case NI_EVEX_AndNotMask: @@ -9986,22 +9978,21 @@ GenTree* Compiler::fgOptimizeHWIntrinsic(GenTreeHWIntrinsic* node) case NI_EVEX_XorMask: { // There's a few special nodes which are allowed to combine masks - // and so we handle these by inserting a CvtVectorToMask over each - // operand and remorphing, which will get us the optimized sequence + // and so we handle these explicitly by introducing the vector to + // mask conversions required and folding, which will get us to an + // optimized sequence where relevant cvtNode = op1Intrin->Op(1); cvtNode = gtNewSimdCvtVectorToMaskNode(TYP_MASK, cvtNode, simdBaseJitType, simdSize); - cvtNode = fgMorphHWIntrinsic(cvtNode->AsHWIntrinsic()); + cvtNode = gtFoldExpr(cvtNode); op1Intrin->Op(1) = cvtNode; cvtNode = op1Intrin->Op(2); cvtNode = gtNewSimdCvtVectorToMaskNode(TYP_MASK, cvtNode, simdBaseJitType, simdSize); - cvtNode = fgMorphHWIntrinsic(cvtNode->AsHWIntrinsic()); + cvtNode = gtFoldExpr(cvtNode); op1Intrin->Op(2) = cvtNode; - - op1 = fgMorphHWIntrinsic(op1Intrin); break; } @@ -10015,7 +10006,6 @@ GenTree* Compiler::fgOptimizeHWIntrinsic(GenTreeHWIntrinsic* node) { // This will allow lowering to emit a vblendm and potentially do embedded masking cvtNode = gtNewSimdCvtMaskToVectorNode(retType, op1, simdBaseJitType, simdSize); - cvtNode = fgMorphHWIntrinsic(cvtNode->AsHWIntrinsic()); node->Op(1) = cvtNode; return node; diff --git a/src/coreclr/jit/valuenum.cpp b/src/coreclr/jit/valuenum.cpp index 3fbfec141dd497..c45ad191c5baf9 100644 --- a/src/coreclr/jit/valuenum.cpp +++ b/src/coreclr/jit/valuenum.cpp @@ -8268,7 +8268,9 @@ ValueNum ValueNumStore::EvalHWIntrinsicFunBinary(GenTreeHWIntrinsic* tree, if (varTypeIsFloating(baseType)) { // Handle `(x == NaN) == false` and `(NaN == x) == false` for floating-point types - if (VNIsVectorNaN(type, baseType, cnsVN)) + var_types simdType = Compiler::getSIMDTypeForSize(simdSize); + + if (VNIsVectorNaN(simdType, baseType, cnsVN)) { return VNZeroForType(type); } @@ -8278,22 +8280,22 @@ ValueNum ValueNumStore::EvalHWIntrinsicFunBinary(GenTreeHWIntrinsic* tree, case GT_GT: { - ValueNum zeroVN = VNZeroForType(type); + var_types simdType = Compiler::getSIMDTypeForSize(simdSize); if (varTypeIsUnsigned(baseType)) { // Handle `(0 > x) == false` for unsigned types. - if ((cnsVN == arg0VN) && (cnsVN == zeroVN)) + if ((cnsVN == arg0VN) && (cnsVN == VNZeroForType(simdType))) { - return zeroVN; + return VNZeroForType(type); } } else if (varTypeIsFloating(baseType)) { // Handle `(x > NaN) == false` and `(NaN > x) == false` for floating-point types - if (VNIsVectorNaN(type, baseType, cnsVN)) + if (VNIsVectorNaN(simdType, baseType, cnsVN)) { - return zeroVN; + return VNZeroForType(type); } } break; @@ -8301,12 +8303,12 @@ ValueNum ValueNumStore::EvalHWIntrinsicFunBinary(GenTreeHWIntrinsic* tree, case GT_GE: { - ValueNum zeroVN = VNZeroForType(type); + var_types simdType = Compiler::getSIMDTypeForSize(simdSize); if (varTypeIsUnsigned(baseType)) { // Handle `x >= 0 == true` for unsigned types. - if ((cnsVN == arg1VN) && (cnsVN == zeroVN)) + if ((cnsVN == arg1VN) && (cnsVN == VNZeroForType(simdType))) { return VNAllBitsForType(type); } @@ -8314,9 +8316,9 @@ ValueNum ValueNumStore::EvalHWIntrinsicFunBinary(GenTreeHWIntrinsic* tree, else if (varTypeIsFloating(baseType)) { // Handle `(x >= NaN) == false` and `(NaN >= x) == false` for floating-point types - if (VNIsVectorNaN(type, baseType, cnsVN)) + if (VNIsVectorNaN(simdType, baseType, cnsVN)) { - return zeroVN; + return VNZeroForType(type); } } break; @@ -8324,22 +8326,22 @@ ValueNum ValueNumStore::EvalHWIntrinsicFunBinary(GenTreeHWIntrinsic* tree, case GT_LT: { - ValueNum zeroVN = VNZeroForType(type); + var_types simdType = Compiler::getSIMDTypeForSize(simdSize); if (varTypeIsUnsigned(baseType)) { // Handle `x < 0 == false` for unsigned types. - if ((cnsVN == arg1VN) && (cnsVN == zeroVN)) + if ((cnsVN == arg1VN) && (cnsVN == VNZeroForType(simdType))) { - return zeroVN; + return VNZeroForType(type); } } else if (varTypeIsFloating(baseType)) { // Handle `(x < NaN) == false` and `(NaN < x) == false` for floating-point types - if (VNIsVectorNaN(type, baseType, cnsVN)) + if (VNIsVectorNaN(simdType, baseType, cnsVN)) { - return zeroVN; + return VNZeroForType(type); } } break; @@ -8347,12 +8349,12 @@ ValueNum ValueNumStore::EvalHWIntrinsicFunBinary(GenTreeHWIntrinsic* tree, case GT_LE: { - ValueNum zeroVN = VNZeroForType(type); + var_types simdType = Compiler::getSIMDTypeForSize(simdSize); if (varTypeIsUnsigned(baseType)) { // Handle `0 <= x == true` for unsigned types. - if ((cnsVN == arg0VN) && (cnsVN == zeroVN)) + if ((cnsVN == arg0VN) && (cnsVN == VNZeroForType(simdType))) { return VNAllBitsForType(type); } @@ -8360,9 +8362,9 @@ ValueNum ValueNumStore::EvalHWIntrinsicFunBinary(GenTreeHWIntrinsic* tree, else if (varTypeIsFloating(baseType)) { // Handle `(x <= NaN) == false` and `(NaN <= x) == false` for floating-point types - if (VNIsVectorNaN(type, baseType, cnsVN)) + if (VNIsVectorNaN(simdType, baseType, cnsVN)) { - return zeroVN; + return VNZeroForType(type); } } break; @@ -8417,10 +8419,12 @@ ValueNum ValueNumStore::EvalHWIntrinsicFunBinary(GenTreeHWIntrinsic* tree, case GT_NE: { + var_types simdType = Compiler::getSIMDTypeForSize(simdSize); + if (varTypeIsFloating(baseType)) { // Handle `(x != NaN) == true` and `(NaN != x) == true` for floating-point types - if (VNIsVectorNaN(type, baseType, cnsVN)) + if (VNIsVectorNaN(simdType, baseType, cnsVN)) { return VNAllBitsForType(type); }