Skip to content

Commit

Permalink
AArch64: Implement Vector Masked Compare operations
Browse files Browse the repository at this point in the history
This commit implements Vector Masked Compare evaluators.

Signed-off-by: Akira Saitoh <saiaki@jp.ibm.com>
  • Loading branch information
Akira Saitoh committed Nov 11, 2022
1 parent 29a6de3 commit 8466f36
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 119 deletions.
6 changes: 6 additions & 0 deletions compiler/aarch64/codegen/OMRCodeGenerator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -672,6 +672,12 @@ bool OMR::ARM64::CodeGenerator::getSupportsOpCodeForAutoSIMD(TR::CPU *cpu, TR::I
case TR::vcmple:
case TR::vcmplt:
case TR::vcmpne:
case TR::vmcmpeq:
case TR::vmcmpge:
case TR::vmcmpgt:
case TR::vmcmple:
case TR::vmcmplt:
case TR::vmcmpne:
return true;
case TR::vand:
case TR::vor:
Expand Down
226 changes: 107 additions & 119 deletions compiler/aarch64/codegen/OMRTreeEvaluator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1813,6 +1813,69 @@ static const TR::InstOpCode::Mnemonic vectorCompareZeroOpCodes[NumVectorCompareO
{ TR::InstOpCode::vcmge16b_zero, TR::InstOpCode::vcmge8h_zero, TR::InstOpCode::vcmge4s_zero, TR::InstOpCode::vcmge2d_zero, TR::InstOpCode::vfcmge4s_zero, TR::InstOpCode::vfcmge2d_zero}, // GE
};

// prototype declaration of vcmpHelper
static TR::Register* vcmpHelper(TR::Node *node, VectorCompareOps compareOp, bool omitNot, bool *flipCompareResult, TR::CodeGenerator *cg);


static
VectorCompareOps getVectorCompareOp(TR::VectorOperation op)
{
switch (op)
{
case TR::vcmpeq:
return VECTOR_COMPARE_EQ;
case TR::vcmpne:
return VECTOR_COMPARE_NE;
case TR::vcmpgt:
return VECTOR_COMPARE_GT;
case TR::vcmpge:
return VECTOR_COMPARE_GE;
case TR::vcmplt:
return VECTOR_COMPARE_LT;
case TR::vcmple:
return VECTOR_COMPARE_LE;
default:
return VECTOR_COMPARE_INVALID;
}
}

/**
* @brief Evaluates a mask node and returns a mask register
*
* @param[in] node: node
* @param[out] flipMask: true if mask value needs to be flipped
* @param[in] cg: CodeGenerator
* @return mask register
*/
static inline
TR::Register *evaluateMaskNode(TR::Node *node, bool &flipMask, TR::CodeGenerator *cg)
{
TR::ILOpCode maskOp = node->getOpCode();

TR::Register *maskReg = NULL;
VectorCompareOps compareOp;
TR::VectorOperation convOp;
if (maskOp.isVectorOpCode() && maskOp.isBooleanCompare()
&& ((compareOp = getVectorCompareOp(maskOp.getVectorOperation())) != VECTOR_COMPARE_INVALID)
&& (node->getReferenceCount() == 1) && (node->getRegister() == NULL))
{
maskReg = vcmpHelper(node, compareOp, true, &flipMask, cg);
}
else if (maskOp.isVectorOpCode() && maskOp.isConversion() && maskOp.isMaskResult()
&& (((convOp = maskOp.getVectorOperation()) == TR::s2m) || (convOp == TR::v2m))
&& (node->getReferenceCount() == 1) && (node->getRegister() == NULL))
{
flipMask = true;
maskReg = (convOp == TR::s2m) ? toMaskConversionHelper<TR::s2m>(node, true, cg) : toMaskConversionHelper<TR::v2m>(node, true, cg);
}
else
{
maskReg = cg->evaluate(node);
}
TR_ASSERT_FATAL_WITH_NODE(node, maskReg->getKind() == TR_VRF, "unexpected Register kind");
return maskReg;
}

/**
* @brief A helper function for generating instuction sequence for vector compare operations
*
Expand All @@ -1829,11 +1892,13 @@ vcmpHelper(TR::Node *node, VectorCompareOps compareOp, bool omitNot, bool *flipC
TR::Node *firstChild = node->getFirstChild();
TR::Node *secondChild = node->getSecondChild();
TR::InstOpCode::Mnemonic op;
const bool notAfterCompare = (compareOp == VECTOR_COMPARE_NE);
bool notAfterCompare = (compareOp == VECTOR_COMPARE_NE);
bool recursivelyDecRefCountOnSecondChild;
TR::Register *firstReg = cg->evaluate(firstChild);
TR::Register *targetReg = cg->allocateRegister(TR_VRF);
TR::DataType elemType = firstChild->getDataType().getVectorElementType();
TR::ILOpCode opcode = node->getOpCode();
const bool isMasked = opcode.isVectorMasked();

TR_ASSERT_FATAL_WITH_NODE(node, (elemType >= TR::Int8) && (elemType <= TR::Double) , "unrecognized vector type %s", firstChild->getDataType().toString());

Expand All @@ -1859,9 +1924,38 @@ vcmpHelper(TR::Node *node, VectorCompareOps compareOp, bool omitNot, bool *flipC
}
}

if (isMasked)
{
TR::Node *thirdChild = node->getThirdChild();
bool flipMask = false;
TR::Register *maskReg = evaluateMaskNode(thirdChild, flipMask, cg);

/* masked compare returns the bitwise logical conjunction of the comparison result and the mask */
if (notAfterCompare)
{
if (flipMask)
{
/* (not a) and (not b) = not (a or b) */
generateTrg1Src2Instruction(cg, TR::InstOpCode::vorr16b, node, targetReg, targetReg, maskReg);
}
else
{
/* a and (not b) = not ((not a) and b) */
generateTrg1Src2Instruction(cg, TR::InstOpCode::vbic16b, node, targetReg, maskReg, targetReg);
notAfterCompare = false;
}
}
else
{
generateTrg1Src2Instruction(cg, flipMask ? TR::InstOpCode::vbic16b : TR::InstOpCode::vand16b, node, targetReg, targetReg, maskReg);
}

cg->decReferenceCount(thirdChild);
}

/*
* If this vector compare node only appears once as a child of masked binary operations,
* the NOT instruction can be omitted because the inlineVectorMaskedBinaryOp generates `bit` instead of `bif`.
* If this vector compare node only appears once as a child of masked operations,
* the NOT instruction can be omitted because the caller method generates `bit` instead of `bif`.
*/
if (!omitNot)
{
Expand Down Expand Up @@ -2616,28 +2710,6 @@ OMR::ARM64::TreeEvaluator::vRegStoreEvaluator(TR::Node *node, TR::CodeGenerator
return TR::TreeEvaluator::unImpOpEvaluator(node, cg);
}

static
VectorCompareOps getVectorCompareOp(TR::VectorOperation op)
{
switch (op)
{
case TR::vcmpeq:
return VECTOR_COMPARE_EQ;
case TR::vcmpne:
return VECTOR_COMPARE_NE;
case TR::vcmpgt:
return VECTOR_COMPARE_GT;
case TR::vcmpge:
return VECTOR_COMPARE_GE;
case TR::vcmplt:
return VECTOR_COMPARE_LT;
case TR::vcmple:
return VECTOR_COMPARE_LE;
default:
return VECTOR_COMPARE_INVALID;
}
}

typedef TR::Register *(*binaryEvaluatorHelper)(TR::Node *node, TR::Register *resReg, TR::Register *lhsRes, TR::Register *rhsReg, TR::CodeGenerator *cg);
/**
* @brief Helper functions for generating instruction sequence for masked binary operations
Expand Down Expand Up @@ -2673,29 +2745,8 @@ inlineVectorMaskedBinaryOp(TR::Node *node, TR::CodeGenerator *cg, TR::InstOpCode
generateTrg1Src2Instruction(cg, op, node, resReg, lhsReg, rhsReg);
}

TR::ILOpCode thirdOp = thirdChild->getOpCode();
bool flipMask = false;
TR::Register *maskReg = NULL;
VectorCompareOps compareOp;
TR::VectorOperation convOp;
if (thirdOp.isVectorOpCode() && thirdOp.isBooleanCompare() && (!thirdOp.isVectorMasked())
&& ((compareOp = getVectorCompareOp(thirdOp.getVectorOperation())) != VECTOR_COMPARE_INVALID)
&& (thirdChild->getReferenceCount() == 1) && (thirdChild->getRegister() == NULL))
{
maskReg = vcmpHelper(thirdChild, compareOp, true, &flipMask, cg);
}
else if (thirdOp.isVectorOpCode() && thirdOp.isConversion() && thirdOp.isMaskResult()
&& (((convOp = thirdOp.getVectorOperation()) == TR::s2m) || (convOp == TR::v2m))
&& (thirdChild->getReferenceCount() == 1) && (thirdChild->getRegister() == NULL))
{
flipMask = true;
maskReg = (convOp == TR::s2m) ? toMaskConversionHelper<TR::s2m>(thirdChild, true, cg) : toMaskConversionHelper<TR::v2m>(thirdChild, true, cg);
}
else
{
maskReg = cg->evaluate(thirdChild);
}
TR_ASSERT_FATAL_WITH_NODE(node, maskReg->getKind() == TR_VRF, "unexpected Register kind");
TR::Register *maskReg = evaluateMaskNode(thirdChild, flipMask, cg);

/*
* BIT inserts each bit from the first source if the corresponding bit of the second source is 1.
Expand Down Expand Up @@ -2732,29 +2783,8 @@ inlineVectorMaskedUnaryOp(TR::Node *node, TR::CodeGenerator *cg, TR::InstOpCode:

generateTrg1Src1Instruction(cg, op, node, resReg, srcReg);

TR::ILOpCode secondOp = secondChild->getOpCode();
bool flipMask = false;
TR::Register *maskReg = NULL;
VectorCompareOps compareOp;
TR::VectorOperation convOp;
if (secondOp.isVectorOpCode() && secondOp.isBooleanCompare() && (!secondOp.isVectorMasked())
&& ((compareOp = getVectorCompareOp(secondOp.getVectorOperation())) != VECTOR_COMPARE_INVALID)
&& (secondChild->getReferenceCount() == 1) && (secondChild->getRegister() == NULL))
{
maskReg = vcmpHelper(secondChild, compareOp, true, &flipMask, cg);
}
else if (secondOp.isVectorOpCode() && secondOp.isConversion() && secondOp.isMaskResult()
&& (((convOp = secondOp.getVectorOperation()) == TR::s2m) || (convOp == TR::v2m))
&& (secondChild->getReferenceCount() == 1) && (secondChild->getRegister() == NULL))
{
flipMask = true;
maskReg = (convOp == TR::s2m) ? toMaskConversionHelper<TR::s2m>(secondChild, true, cg) : toMaskConversionHelper<TR::v2m>(secondChild, true, cg);
}
else
{
maskReg = cg->evaluate(secondChild);
}
TR_ASSERT_FATAL_WITH_NODE(node, maskReg->getKind() == TR_VRF, "unexpected Register kind");
TR::Register *maskReg = evaluateMaskNode(secondChild, flipMask, cg);

/*
* BIT inserts each bit from the first source if the corresponding bit of the second source is 1.
Expand Down Expand Up @@ -2857,37 +2887,37 @@ OMR::ARM64::TreeEvaluator::vmandEvaluator(TR::Node *node, TR::CodeGenerator *cg)
TR::Register*
OMR::ARM64::TreeEvaluator::vmcmpeqEvaluator(TR::Node *node, TR::CodeGenerator *cg)
{
return TR::TreeEvaluator::unImpOpEvaluator(node, cg);
return vcmpHelper(node, VECTOR_COMPARE_EQ, false, NULL, cg);
}

TR::Register*
OMR::ARM64::TreeEvaluator::vmcmpneEvaluator(TR::Node *node, TR::CodeGenerator *cg)
{
return TR::TreeEvaluator::unImpOpEvaluator(node, cg);
return vcmpHelper(node, VECTOR_COMPARE_NE, false, NULL, cg);
}

TR::Register*
OMR::ARM64::TreeEvaluator::vmcmpgtEvaluator(TR::Node *node, TR::CodeGenerator *cg)
{
return TR::TreeEvaluator::unImpOpEvaluator(node, cg);
return vcmpHelper(node, VECTOR_COMPARE_GT, false, NULL, cg);
}

TR::Register*
OMR::ARM64::TreeEvaluator::vmcmpgeEvaluator(TR::Node *node, TR::CodeGenerator *cg)
{
return TR::TreeEvaluator::unImpOpEvaluator(node, cg);
return vcmpHelper(node, VECTOR_COMPARE_GE, false, NULL, cg);
}

TR::Register*
OMR::ARM64::TreeEvaluator::vmcmpltEvaluator(TR::Node *node, TR::CodeGenerator *cg)
{
return TR::TreeEvaluator::unImpOpEvaluator(node, cg);
return vcmpHelper(node, VECTOR_COMPARE_LT, false, NULL, cg);
}

TR::Register*
OMR::ARM64::TreeEvaluator::vmcmpleEvaluator(TR::Node *node, TR::CodeGenerator *cg)
{
return TR::TreeEvaluator::unImpOpEvaluator(node, cg);
return vcmpHelper(node, VECTOR_COMPARE_LE, false, NULL, cg);
}

TR::Register*
Expand Down Expand Up @@ -2962,29 +2992,8 @@ OMR::ARM64::TreeEvaluator::vmfmaEvaluator(TR::Node *node, TR::CodeGenerator *cg)
}
generateTrg1Src2Instruction(cg, op, node, targetReg, firstReg, secondReg);

TR::ILOpCode fourthOp = fourthChild->getOpCode();
bool flipMask = false;
TR::Register *maskReg = NULL;
VectorCompareOps compareOp;
TR::VectorOperation convOp;
if (fourthOp.isVectorOpCode() && fourthOp.isBooleanCompare() && (!fourthOp.isVectorMasked())
&& ((compareOp = getVectorCompareOp(fourthOp.getVectorOperation())) != VECTOR_COMPARE_INVALID)
&& (fourthChild->getReferenceCount() == 1) && (fourthChild->getRegister() == NULL))
{
maskReg = vcmpHelper(fourthChild, compareOp, true, &flipMask, cg);
}
else if (fourthOp.isVectorOpCode() && fourthOp.isConversion() && fourthOp.isMaskResult()
&& (((convOp = fourthOp.getVectorOperation()) == TR::s2m) || (convOp == TR::v2m))
&& (fourthChild->getReferenceCount() == 1) && (fourthChild->getRegister() == NULL))
{
flipMask = true;
maskReg = (convOp == TR::s2m) ? toMaskConversionHelper<TR::s2m>(fourthChild, true, cg) : toMaskConversionHelper<TR::v2m>(fourthChild, true, cg);
}
else
{
maskReg = cg->evaluate(fourthChild);
}
TR_ASSERT_FATAL_WITH_NODE(node, maskReg->getKind() == TR_VRF, "unexpected Register kind");
TR::Register *maskReg = evaluateMaskNode(fourthChild, flipMask, cg);

/*
* BIT inserts each bit from the first source if the corresponding bit of the second source is 1.
Expand Down Expand Up @@ -3224,29 +3233,8 @@ inlineVectorMaskedReductionOp(TR::Node *node, TR::CodeGenerator *cg, TR::DataTyp

TR_ASSERT_FATAL_WITH_NODE(node, sourceReg->getKind() == TR_VRF, "unexpected Register kind");

TR::ILOpCode secondOp = secondChild->getOpCode();
bool flipMask = false;
TR::Register *maskReg = NULL;
VectorCompareOps compareOp;
TR::VectorOperation convOp;
if (secondOp.isVectorOpCode() && secondOp.isBooleanCompare() && (!secondOp.isVectorMasked())
&& ((compareOp = getVectorCompareOp(secondOp.getVectorOperation())) != VECTOR_COMPARE_INVALID)
&& (secondChild->getReferenceCount() == 1) && (secondChild->getRegister() == NULL))
{
maskReg = vcmpHelper(secondChild, compareOp, true, &flipMask, cg);
}
else if (secondOp.isVectorOpCode() && secondOp.isConversion() && secondOp.isMaskResult()
&& (((convOp = secondOp.getVectorOperation()) == TR::s2m) || (convOp == TR::v2m))
&& (secondChild->getReferenceCount() == 1) && (secondChild->getRegister() == NULL))
{
flipMask = true;
maskReg = (convOp == TR::s2m) ? toMaskConversionHelper<TR::s2m>(secondChild, true, cg) : toMaskConversionHelper<TR::v2m>(secondChild, true, cg);
}
else
{
maskReg = cg->evaluate(secondChild);
}
TR_ASSERT_FATAL_WITH_NODE(node, maskReg->getKind() == TR_VRF, "unexpected Register kind");
TR::Register *maskReg = evaluateMaskNode(secondChild, flipMask, cg);

TR::Register *tmpReg = cg->allocateRegister(TR_VRF);
/* loads identity vector to tmpReg */
Expand Down

0 comments on commit 8466f36

Please sign in to comment.