From 8466f366d487816535cfd258ffb1ca3e44096f7e Mon Sep 17 00:00:00 2001 From: Akira Saitoh Date: Wed, 9 Nov 2022 22:53:02 +0900 Subject: [PATCH] AArch64: Implement Vector Masked Compare operations This commit implements Vector Masked Compare evaluators. Signed-off-by: Akira Saitoh --- compiler/aarch64/codegen/OMRCodeGenerator.cpp | 6 + compiler/aarch64/codegen/OMRTreeEvaluator.cpp | 226 +++++++++--------- 2 files changed, 113 insertions(+), 119 deletions(-) diff --git a/compiler/aarch64/codegen/OMRCodeGenerator.cpp b/compiler/aarch64/codegen/OMRCodeGenerator.cpp index 36c1210f618..cbf4979b44e 100644 --- a/compiler/aarch64/codegen/OMRCodeGenerator.cpp +++ b/compiler/aarch64/codegen/OMRCodeGenerator.cpp @@ -672,6 +672,12 @@ bool OMR::ARM64::CodeGenerator::getSupportsOpCodeForAutoSIMD(TR::CPU *cpu, TR::I case TR::vcmple: case TR::vcmplt: case TR::vcmpne: + case TR::vmcmpeq: + case TR::vmcmpge: + case TR::vmcmpgt: + case TR::vmcmple: + case TR::vmcmplt: + case TR::vmcmpne: return true; case TR::vand: case TR::vor: diff --git a/compiler/aarch64/codegen/OMRTreeEvaluator.cpp b/compiler/aarch64/codegen/OMRTreeEvaluator.cpp index 7bc06fdc1c2..93c3ce52456 100644 --- a/compiler/aarch64/codegen/OMRTreeEvaluator.cpp +++ b/compiler/aarch64/codegen/OMRTreeEvaluator.cpp @@ -1813,6 +1813,69 @@ static const TR::InstOpCode::Mnemonic vectorCompareZeroOpCodes[NumVectorCompareO { TR::InstOpCode::vcmge16b_zero, TR::InstOpCode::vcmge8h_zero, TR::InstOpCode::vcmge4s_zero, TR::InstOpCode::vcmge2d_zero, TR::InstOpCode::vfcmge4s_zero, TR::InstOpCode::vfcmge2d_zero}, // GE }; +// prototype declaration of vcmpHelper +static TR::Register* vcmpHelper(TR::Node *node, VectorCompareOps compareOp, bool omitNot, bool *flipCompareResult, TR::CodeGenerator *cg); + + +static +VectorCompareOps getVectorCompareOp(TR::VectorOperation op) + { + switch (op) + { + case TR::vcmpeq: + return VECTOR_COMPARE_EQ; + case TR::vcmpne: + return VECTOR_COMPARE_NE; + case TR::vcmpgt: + return VECTOR_COMPARE_GT; + case 
TR::vcmpge: + return VECTOR_COMPARE_GE; + case TR::vcmplt: + return VECTOR_COMPARE_LT; + case TR::vcmple: + return VECTOR_COMPARE_LE; + default: + return VECTOR_COMPARE_INVALID; + } + } + +/** + * @brief Evaluates a mask node and returns a mask register + * + * @param[in] node: node + * @param[out] flipMask: true if mask value needs to be flipped + * @param[in] cg: CodeGenerator + * @return mask register + */ +static inline +TR::Register *evaluateMaskNode(TR::Node *node, bool &flipMask, TR::CodeGenerator *cg) + { + TR::ILOpCode maskOp = node->getOpCode(); + + TR::Register *maskReg = NULL; + VectorCompareOps compareOp; + TR::VectorOperation convOp; + if (maskOp.isVectorOpCode() && maskOp.isBooleanCompare() + && ((compareOp = getVectorCompareOp(maskOp.getVectorOperation())) != VECTOR_COMPARE_INVALID) + && (node->getReferenceCount() == 1) && (node->getRegister() == NULL)) + { + maskReg = vcmpHelper(node, compareOp, true, &flipMask, cg); + } + else if (maskOp.isVectorOpCode() && maskOp.isConversion() && maskOp.isMaskResult() + && (((convOp = maskOp.getVectorOperation()) == TR::s2m) || (convOp == TR::v2m)) + && (node->getReferenceCount() == 1) && (node->getRegister() == NULL)) + { + flipMask = true; + maskReg = (convOp == TR::s2m) ? 
toMaskConversionHelper(node, true, cg) : toMaskConversionHelper(node, true, cg); + } + else + { + maskReg = cg->evaluate(node); + } + TR_ASSERT_FATAL_WITH_NODE(node, maskReg->getKind() == TR_VRF, "unexpected Register kind"); + return maskReg; + } + /** * @brief A helper function for generating instuction sequence for vector compare operations * @@ -1829,11 +1892,13 @@ vcmpHelper(TR::Node *node, VectorCompareOps compareOp, bool omitNot, bool *flipC TR::Node *firstChild = node->getFirstChild(); TR::Node *secondChild = node->getSecondChild(); TR::InstOpCode::Mnemonic op; - const bool notAfterCompare = (compareOp == VECTOR_COMPARE_NE); + bool notAfterCompare = (compareOp == VECTOR_COMPARE_NE); bool recursivelyDecRefCountOnSecondChild; TR::Register *firstReg = cg->evaluate(firstChild); TR::Register *targetReg = cg->allocateRegister(TR_VRF); TR::DataType elemType = firstChild->getDataType().getVectorElementType(); + TR::ILOpCode opcode = node->getOpCode(); + const bool isMasked = opcode.isVectorMasked(); TR_ASSERT_FATAL_WITH_NODE(node, (elemType >= TR::Int8) && (elemType <= TR::Double) , "unrecognized vector type %s", firstChild->getDataType().toString()); @@ -1859,9 +1924,38 @@ vcmpHelper(TR::Node *node, VectorCompareOps compareOp, bool omitNot, bool *flipC } } + if (isMasked) + { + TR::Node *thirdChild = node->getThirdChild(); + bool flipMask = false; + TR::Register *maskReg = evaluateMaskNode(thirdChild, flipMask, cg); + + /* masked compare returns the bitwise logical conjunction of the comparison result and the mask */ + if (notAfterCompare) + { + if (flipMask) + { + /* (not a) and (not b) = not (a or b) */ + generateTrg1Src2Instruction(cg, TR::InstOpCode::vorr16b, node, targetReg, targetReg, maskReg); + } + else + { + /* (not a) and b = b bic a: BIC applies the NOT itself, so no NOT is needed after the compare */ + generateTrg1Src2Instruction(cg, TR::InstOpCode::vbic16b, node, targetReg, maskReg, targetReg); + notAfterCompare = false; + } + } + else + { + generateTrg1Src2Instruction(cg, flipMask ? 
TR::InstOpCode::vbic16b : TR::InstOpCode::vand16b, node, targetReg, targetReg, maskReg); + } + + cg->decReferenceCount(thirdChild); + } + /* - * If this vector compare node only appears once as a child of masked binary operations, - * the NOT instruction can be omitted because the inlineVectorMaskedBinaryOp generates `bit` instead of `bif`. + * If this vector compare node only appears once as a child of masked operations, + * the NOT instruction can be omitted because the caller method generates `bit` instead of `bif`. */ if (!omitNot) { @@ -2616,28 +2710,6 @@ OMR::ARM64::TreeEvaluator::vRegStoreEvaluator(TR::Node *node, TR::CodeGenerator return TR::TreeEvaluator::unImpOpEvaluator(node, cg); } -static -VectorCompareOps getVectorCompareOp(TR::VectorOperation op) - { - switch (op) - { - case TR::vcmpeq: - return VECTOR_COMPARE_EQ; - case TR::vcmpne: - return VECTOR_COMPARE_NE; - case TR::vcmpgt: - return VECTOR_COMPARE_GT; - case TR::vcmpge: - return VECTOR_COMPARE_GE; - case TR::vcmplt: - return VECTOR_COMPARE_LT; - case TR::vcmple: - return VECTOR_COMPARE_LE; - default: - return VECTOR_COMPARE_INVALID; - } - } - typedef TR::Register *(*binaryEvaluatorHelper)(TR::Node *node, TR::Register *resReg, TR::Register *lhsRes, TR::Register *rhsReg, TR::CodeGenerator *cg); /** * @brief Helper functions for generating instruction sequence for masked binary operations @@ -2673,29 +2745,8 @@ inlineVectorMaskedBinaryOp(TR::Node *node, TR::CodeGenerator *cg, TR::InstOpCode generateTrg1Src2Instruction(cg, op, node, resReg, lhsReg, rhsReg); } - TR::ILOpCode thirdOp = thirdChild->getOpCode(); bool flipMask = false; - TR::Register *maskReg = NULL; - VectorCompareOps compareOp; - TR::VectorOperation convOp; - if (thirdOp.isVectorOpCode() && thirdOp.isBooleanCompare() && (!thirdOp.isVectorMasked()) - && ((compareOp = getVectorCompareOp(thirdOp.getVectorOperation())) != VECTOR_COMPARE_INVALID) - && (thirdChild->getReferenceCount() == 1) && (thirdChild->getRegister() == NULL)) - { - 
maskReg = vcmpHelper(thirdChild, compareOp, true, &flipMask, cg); - } - else if (thirdOp.isVectorOpCode() && thirdOp.isConversion() && thirdOp.isMaskResult() - && (((convOp = thirdOp.getVectorOperation()) == TR::s2m) || (convOp == TR::v2m)) - && (thirdChild->getReferenceCount() == 1) && (thirdChild->getRegister() == NULL)) - { - flipMask = true; - maskReg = (convOp == TR::s2m) ? toMaskConversionHelper(thirdChild, true, cg) : toMaskConversionHelper(thirdChild, true, cg); - } - else - { - maskReg = cg->evaluate(thirdChild); - } - TR_ASSERT_FATAL_WITH_NODE(node, maskReg->getKind() == TR_VRF, "unexpected Register kind"); + TR::Register *maskReg = evaluateMaskNode(thirdChild, flipMask, cg); /* * BIT inserts each bit from the first source if the corresponding bit of the second source is 1. @@ -2732,29 +2783,8 @@ inlineVectorMaskedUnaryOp(TR::Node *node, TR::CodeGenerator *cg, TR::InstOpCode: generateTrg1Src1Instruction(cg, op, node, resReg, srcReg); - TR::ILOpCode secondOp = secondChild->getOpCode(); bool flipMask = false; - TR::Register *maskReg = NULL; - VectorCompareOps compareOp; - TR::VectorOperation convOp; - if (secondOp.isVectorOpCode() && secondOp.isBooleanCompare() && (!secondOp.isVectorMasked()) - && ((compareOp = getVectorCompareOp(secondOp.getVectorOperation())) != VECTOR_COMPARE_INVALID) - && (secondChild->getReferenceCount() == 1) && (secondChild->getRegister() == NULL)) - { - maskReg = vcmpHelper(secondChild, compareOp, true, &flipMask, cg); - } - else if (secondOp.isVectorOpCode() && secondOp.isConversion() && secondOp.isMaskResult() - && (((convOp = secondOp.getVectorOperation()) == TR::s2m) || (convOp == TR::v2m)) - && (secondChild->getReferenceCount() == 1) && (secondChild->getRegister() == NULL)) - { - flipMask = true; - maskReg = (convOp == TR::s2m) ? 
toMaskConversionHelper(secondChild, true, cg) : toMaskConversionHelper(secondChild, true, cg); - } - else - { - maskReg = cg->evaluate(secondChild); - } - TR_ASSERT_FATAL_WITH_NODE(node, maskReg->getKind() == TR_VRF, "unexpected Register kind"); + TR::Register *maskReg = evaluateMaskNode(secondChild, flipMask, cg); /* * BIT inserts each bit from the first source if the corresponding bit of the second source is 1. @@ -2857,37 +2887,37 @@ OMR::ARM64::TreeEvaluator::vmandEvaluator(TR::Node *node, TR::CodeGenerator *cg) TR::Register* OMR::ARM64::TreeEvaluator::vmcmpeqEvaluator(TR::Node *node, TR::CodeGenerator *cg) { - return TR::TreeEvaluator::unImpOpEvaluator(node, cg); + return vcmpHelper(node, VECTOR_COMPARE_EQ, false, NULL, cg); } TR::Register* OMR::ARM64::TreeEvaluator::vmcmpneEvaluator(TR::Node *node, TR::CodeGenerator *cg) { - return TR::TreeEvaluator::unImpOpEvaluator(node, cg); + return vcmpHelper(node, VECTOR_COMPARE_NE, false, NULL, cg); } TR::Register* OMR::ARM64::TreeEvaluator::vmcmpgtEvaluator(TR::Node *node, TR::CodeGenerator *cg) { - return TR::TreeEvaluator::unImpOpEvaluator(node, cg); + return vcmpHelper(node, VECTOR_COMPARE_GT, false, NULL, cg); } TR::Register* OMR::ARM64::TreeEvaluator::vmcmpgeEvaluator(TR::Node *node, TR::CodeGenerator *cg) { - return TR::TreeEvaluator::unImpOpEvaluator(node, cg); + return vcmpHelper(node, VECTOR_COMPARE_GE, false, NULL, cg); } TR::Register* OMR::ARM64::TreeEvaluator::vmcmpltEvaluator(TR::Node *node, TR::CodeGenerator *cg) { - return TR::TreeEvaluator::unImpOpEvaluator(node, cg); + return vcmpHelper(node, VECTOR_COMPARE_LT, false, NULL, cg); } TR::Register* OMR::ARM64::TreeEvaluator::vmcmpleEvaluator(TR::Node *node, TR::CodeGenerator *cg) { - return TR::TreeEvaluator::unImpOpEvaluator(node, cg); + return vcmpHelper(node, VECTOR_COMPARE_LE, false, NULL, cg); } TR::Register* @@ -2962,29 +2992,8 @@ OMR::ARM64::TreeEvaluator::vmfmaEvaluator(TR::Node *node, TR::CodeGenerator *cg) } generateTrg1Src2Instruction(cg, op, 
node, targetReg, firstReg, secondReg); - TR::ILOpCode fourthOp = fourthChild->getOpCode(); bool flipMask = false; - TR::Register *maskReg = NULL; - VectorCompareOps compareOp; - TR::VectorOperation convOp; - if (fourthOp.isVectorOpCode() && fourthOp.isBooleanCompare() && (!fourthOp.isVectorMasked()) - && ((compareOp = getVectorCompareOp(fourthOp.getVectorOperation())) != VECTOR_COMPARE_INVALID) - && (fourthChild->getReferenceCount() == 1) && (fourthChild->getRegister() == NULL)) - { - maskReg = vcmpHelper(fourthChild, compareOp, true, &flipMask, cg); - } - else if (fourthOp.isVectorOpCode() && fourthOp.isConversion() && fourthOp.isMaskResult() - && (((convOp = fourthOp.getVectorOperation()) == TR::s2m) || (convOp == TR::v2m)) - && (fourthChild->getReferenceCount() == 1) && (fourthChild->getRegister() == NULL)) - { - flipMask = true; - maskReg = (convOp == TR::s2m) ? toMaskConversionHelper(fourthChild, true, cg) : toMaskConversionHelper(fourthChild, true, cg); - } - else - { - maskReg = cg->evaluate(fourthChild); - } - TR_ASSERT_FATAL_WITH_NODE(node, maskReg->getKind() == TR_VRF, "unexpected Register kind"); + TR::Register *maskReg = evaluateMaskNode(fourthChild, flipMask, cg); /* * BIT inserts each bit from the first source if the corresponding bit of the second source is 1. 
@@ -3224,29 +3233,8 @@ inlineVectorMaskedReductionOp(TR::Node *node, TR::CodeGenerator *cg, TR::DataTyp TR_ASSERT_FATAL_WITH_NODE(node, sourceReg->getKind() == TR_VRF, "unexpected Register kind"); - TR::ILOpCode secondOp = secondChild->getOpCode(); bool flipMask = false; - TR::Register *maskReg = NULL; - VectorCompareOps compareOp; - TR::VectorOperation convOp; - if (secondOp.isVectorOpCode() && secondOp.isBooleanCompare() && (!secondOp.isVectorMasked()) - && ((compareOp = getVectorCompareOp(secondOp.getVectorOperation())) != VECTOR_COMPARE_INVALID) - && (secondChild->getReferenceCount() == 1) && (secondChild->getRegister() == NULL)) - { - maskReg = vcmpHelper(secondChild, compareOp, true, &flipMask, cg); - } - else if (secondOp.isVectorOpCode() && secondOp.isConversion() && secondOp.isMaskResult() - && (((convOp = secondOp.getVectorOperation()) == TR::s2m) || (convOp == TR::v2m)) - && (secondChild->getReferenceCount() == 1) && (secondChild->getRegister() == NULL)) - { - flipMask = true; - maskReg = (convOp == TR::s2m) ? toMaskConversionHelper(secondChild, true, cg) : toMaskConversionHelper(secondChild, true, cg); - } - else - { - maskReg = cg->evaluate(secondChild); - } - TR_ASSERT_FATAL_WITH_NODE(node, maskReg->getKind() == TR_VRF, "unexpected Register kind"); + TR::Register *maskReg = evaluateMaskNode(secondChild, flipMask, cg); TR::Register *tmpReg = cg->allocateRegister(TR_VRF); /* loads identity vector to tmpReg */