From 42777ccc1fde6e452afd290cfc64eac12e50de05 Mon Sep 17 00:00:00 2001 From: weilinwa Date: Wed, 1 Dec 2021 06:47:13 -0800 Subject: [PATCH] Optimize FMA codegen base on the overwritten (#58196) * Optimize FMA codegen base on the overwritten * Improve function/var names * Add assertions * Get use of FMA with TryGetUse * Decide FMA form with two conditions, OverwrittenOpNum and isContained * Fix op reg error in codegen * Decide form using lastUse and isContained in no overwritten case * Clean up code * Separate default case overwrittenOpNum==0 * Apply format patch * Change variable and function names * Update regOptional for op1 and resolve some other comments * Optimize FMA codegen base on the overwritten * Improve function/var names * Add assertions * Get use of FMA with TryGetUse * Decide FMA form with two conditions, OverwrittenOpNum and isContained * Fix op reg error in codegen * Decide form using lastUse and isContained in no overwritten case * Clean up code * Separate default case overwrittenOpNum==0 * Apply format patch * Change variable and function names * Update regOptional for op1 and resolve some other comments * Change var names * Fix jit format * Fix build node error for op1 is regOptional * Use targetReg instead of GetResultOpNumForFMA in codegen * Update variable names * Refactor lsra to solve lastUse status changed caused assertion failure * Add check to prioritize contained op in lsra * Update for jit format * Simplify code * Resolve comments * Comment out assert because of lastUse change * Fix some copiesUpperBits related errors * Update src/coreclr/jit/lsraxarch.cpp Co-authored-by: Kunal Pathak * Add link to the new issue Co-authored-by: Kunal Pathak --- src/coreclr/jit/gentree.cpp | 47 ++++++++++ src/coreclr/jit/gentree.h | 1 + src/coreclr/jit/hwintrinsiccodegenxarch.cpp | 97 ++++++++++++--------- src/coreclr/jit/lowerxarch.cpp | 57 +++++++----- src/coreclr/jit/lsraxarch.cpp | 97 +++++++++++++++------ 5 files changed, 210 insertions(+), 89 deletions(-) diff --git a/src/coreclr/jit/gentree.cpp b/src/coreclr/jit/gentree.cpp index 09e596a774b6e..00c58d70bae1d 100644 --- a/src/coreclr/jit/gentree.cpp +++ b/src/coreclr/jit/gentree.cpp @@ -21898,6 +21898,53 @@ uint16_t GenTreeLclVarCommon::GetLclOffs() const } } +#if defined(TARGET_XARCH) && defined(FEATURE_HW_INTRINSICS) +//------------------------------------------------------------------------ +// GetResultOpNumForFMA: check if the result is written into one of the operands. +// In the case that none of the operand is overwritten, check if any of them is lastUse. +// +// Return Value: +// The operand number overwritten or lastUse. 0 is the default value, where the result is written into +// a destination that is not one of the source operands and there is no last use op. +// +unsigned GenTreeHWIntrinsic::GetResultOpNumForFMA(GenTree* use, GenTree* op1, GenTree* op2, GenTree* op3) +{ + // only FMA intrinsic node should call into this function + assert(HWIntrinsicInfo::lookupIsa(gtHWIntrinsicId) == InstructionSet_FMA); + if (use != nullptr && use->OperIs(GT_STORE_LCL_VAR)) + { + // For store_lcl_var, check if any op is overwritten + + GenTreeLclVarCommon* overwritten = use->AsLclVarCommon(); + unsigned overwrittenLclNum = overwritten->GetLclNum(); + if (op1->IsLocal() && op1->AsLclVarCommon()->GetLclNum() == overwrittenLclNum) + { + return 1; + } + else if (op2->IsLocal() && op2->AsLclVarCommon()->GetLclNum() == overwrittenLclNum) + { + return 2; + } + else if (op3->IsLocal() && op3->AsLclVarCommon()->GetLclNum() == overwrittenLclNum) + { + return 3; + } + } + + // If no overwritten op, check if there is any last use op + // https://github.com/dotnet/runtime/issues/62215 + + if (op1->OperIs(GT_LCL_VAR) && op1->IsLastUse(0)) + return 1; + else if (op2->OperIs(GT_LCL_VAR) && op2->IsLastUse(0)) + return 2; + else if (op3->OperIs(GT_LCL_VAR) && op3->IsLastUse(0)) + return 3; + + return 0; +} +#endif // TARGET_XARCH && FEATURE_HW_INTRINSICS + #ifdef TARGET_ARM //------------------------------------------------------------------------ // IsOffsetMisaligned: check if the field needs a special handling on arm. diff --git a/src/coreclr/jit/gentree.h b/src/coreclr/jit/gentree.h index 553ed29837c83..1a0865073982d 100644 --- a/src/coreclr/jit/gentree.h +++ b/src/coreclr/jit/gentree.h @@ -5526,6 +5526,7 @@ struct GenTreeHWIntrinsic : public GenTreeJitIntrinsic { return (gtFlags & GTF_SIMDASHW_OP) != 0; } + unsigned GetResultOpNumForFMA(GenTree* use, GenTree* op1, GenTree* op2, GenTree* op3); NamedIntrinsic GetHWIntrinsicId() const; diff --git a/src/coreclr/jit/hwintrinsiccodegenxarch.cpp b/src/coreclr/jit/hwintrinsiccodegenxarch.cpp index 2e6f018ddaf6c..bb6a6daa815d8 100644 --- a/src/coreclr/jit/hwintrinsiccodegenxarch.cpp +++ b/src/coreclr/jit/hwintrinsiccodegenxarch.cpp @@ -2034,67 +2034,82 @@ void CodeGen::genFMAIntrinsic(GenTreeHWIntrinsic* node) NamedIntrinsic intrinsicId = node->GetHWIntrinsicId(); var_types baseType = node->GetSimdBaseType(); emitAttr attr = emitActualTypeSize(Compiler::getSIMDTypeForSize(node->GetSimdSize())); - instruction ins = HWIntrinsicInfo::lookupIns(intrinsicId, baseType); + instruction ins = HWIntrinsicInfo::lookupIns(intrinsicId, baseType); // 213 form + instruction _132form = (instruction)(ins - 1); + instruction _231form = (instruction)(ins + 1); GenTree* op1 = node->Op(1); GenTree* op2 = node->Op(2); GenTree* op3 = node->Op(3); - regNumber targetReg = node->GetRegNum(); + + regNumber targetReg = node->GetRegNum(); genConsumeMultiOpOperands(node); - regNumber op1Reg; - regNumber op2Reg; + regNumber op1NodeReg = op1->GetRegNum(); + regNumber op2NodeReg = op2->GetRegNum(); + regNumber op3NodeReg = op3->GetRegNum(); + + GenTree* emitOp1 = op1; + GenTree* emitOp2 = op2; + GenTree* emitOp3 = op3; - bool isCommutative = false; const bool copiesUpperBits = HWIntrinsicInfo::CopiesUpperBits(intrinsicId); // Intrinsics with CopyUpperBits semantics cannot have op1 be contained assert(!copiesUpperBits || !op1->isContained()); - if (op2->isContained() || op2->isUsedFromSpillTemp()) + if (op1->isContained() || op1->isUsedFromSpillTemp()) { - // 132 form: op1 = (op1 * op3) + [op2] - - ins = (instruction)(ins - 1); - op1Reg = op1->GetRegNum(); - op2Reg = op3->GetRegNum(); - op3 = op2; + if (targetReg == op2NodeReg) + { + std::swap(emitOp1, emitOp2); + // op2 = ([op1] * op2) + op3 + // 132 form: XMM1 = (XMM1 * [XMM3]) + XMM2 + ins = _132form; + std::swap(emitOp2, emitOp3); + } + else + { + // targetReg == op3NodeReg or targetReg == ? + // op3 = ([op1] * op2) + op3 + // 231 form: XMM1 = (XMM2 * [XMM3]) + XMM1 + ins = _231form; + std::swap(emitOp1, emitOp3); + } } - else if (op1->isContained() || op1->isUsedFromSpillTemp()) + else if (op2->isContained() || op2->isUsedFromSpillTemp()) { - // 231 form: op3 = (op2 * op3) + [op1] - - ins = (instruction)(ins + 1); - op1Reg = op3->GetRegNum(); - op2Reg = op2->GetRegNum(); - op3 = op1; + if (!copiesUpperBits && (targetReg == op3NodeReg)) + { + // op3 = (op1 * [op2]) + op3 + // 231 form: XMM1 = (XMM2 * [XMM3]) + XMM1 + ins = _231form; + std::swap(emitOp1, emitOp3); + } + else + { + // targetReg == op1NodeReg or targetReg == ? + // op1 = (op1 * [op2]) + op3 + // 132 form: XMM1 = (XMM1 * [XMM3]) + XMM2 + ins = _132form; + } + std::swap(emitOp2, emitOp3); } else { - // 213 form: op1 = (op2 * op1) + [op3] - - op1Reg = op1->GetRegNum(); - op2Reg = op2->GetRegNum(); - - isCommutative = !copiesUpperBits; - } - - if (isCommutative && (op1Reg != targetReg) && (op2Reg == targetReg)) - { - assert(node->isRMWHWIntrinsic(compiler)); - - // We have "reg2 = (reg1 * reg2) +/- op3" where "reg1 != reg2" on a RMW intrinsic. - // - // For non-commutative intrinsics, we should have ensured that op2 was marked - // delay free in order to prevent it from getting assigned the same register - // as target. However, for commutative intrinsics, we can just swap the operands - // in order to have "reg2 = reg2 op reg1" which will end up producing the right code. - - op2Reg = op1Reg; - op1Reg = targetReg; + // targetReg could be op1NodeReg, op2NodeReg, or not equal to any op + // op1 = (op1 * op2) + [op3] or op2 = (op1 * op2) + [op3] + // ? = (op1 * op2) + [op3] or ? = (op1 * op2) + op3 + // 213 form: XMM1 = (XMM2 * XMM1) + [XMM3] + if (!copiesUpperBits && (targetReg == op2NodeReg)) + { + // op2 = (op1 * op2) + [op3] + // 213 form: XMM1 = (XMM2 * XMM1) + [XMM3] + std::swap(emitOp1, emitOp2); + } } - genHWIntrinsic_R_R_R_RM(ins, attr, targetReg, op1Reg, op2Reg, op3); + genHWIntrinsic_R_R_R_RM(ins, attr, targetReg, emitOp1->GetRegNum(), emitOp2->GetRegNum(), emitOp3); genProduceReg(node); } diff --git a/src/coreclr/jit/lowerxarch.cpp b/src/coreclr/jit/lowerxarch.cpp index c1b55e992af1f..548a6033d872b 100644 --- a/src/coreclr/jit/lowerxarch.cpp +++ b/src/coreclr/jit/lowerxarch.cpp @@ -6000,40 +6000,53 @@ void Lowering::ContainCheckHWIntrinsic(GenTreeHWIntrinsic* node) { if ((intrinsicId >= NI_FMA_MultiplyAdd) && (intrinsicId <= NI_FMA_MultiplySubtractNegatedScalar)) { - bool supportsRegOptional = false; + bool supportsOp1RegOptional = false; + bool supportsOp2RegOptional = false; + bool supportsOp3RegOptional = false; + unsigned resultOpNum = 0; + LIR::Use use; + GenTree* user = nullptr; + + if (BlockRange().TryGetUse(node, &use)) + { + user = use.User(); + } + resultOpNum = node->GetResultOpNumForFMA(user, op1, op2, op3); + + // Prioritize Containable op. Check if any one of the op is containable first. + // Set op regOptional only if none of them is containable. - if (IsContainableHWIntrinsicOp(node, op3, &supportsRegOptional)) + // Prefer to make op3 contained, + if (resultOpNum != 3 && IsContainableHWIntrinsicOp(node, op3, &supportsOp3RegOptional)) { - // 213 form: op1 = (op2 * op1) + [op3] + // result = (op1 * op2) + [op3] MakeSrcContained(node, op3); } - else if (IsContainableHWIntrinsicOp(node, op2, &supportsRegOptional)) + else if (resultOpNum != 2 && IsContainableHWIntrinsicOp(node, op2, &supportsOp2RegOptional)) { - // 132 form: op1 = (op1 * op3) + [op2] + // result = (op1 * [op2]) + op3 MakeSrcContained(node, op2); } - else if (IsContainableHWIntrinsicOp(node, op1, &supportsRegOptional)) + else if (resultOpNum != 1 && !HWIntrinsicInfo::CopiesUpperBits(intrinsicId) && + IsContainableHWIntrinsicOp(node, op1, &supportsOp1RegOptional)) { - // Intrinsics with CopyUpperBits semantics cannot have op1 be contained - - if (!HWIntrinsicInfo::CopiesUpperBits(intrinsicId)) - { - // 231 form: op3 = (op2 * op3) + [op1] - MakeSrcContained(node, op1); - } + // result = ([op1] * op2) + op3 + MakeSrcContained(node, op1); } - else + else if (supportsOp3RegOptional) { - assert(supportsRegOptional); - - // TODO-XArch-CQ: Technically any one of the three operands can - // be reg-optional. With a limitation on op1 where - // it can only be so if CopyUpperBits is off. - // https://github.com/dotnet/runtime/issues/6358 - - // 213 form: op1 = (op2 * op1) + op3 + assert(resultOpNum != 3); op3->SetRegOptional(); } + else if (supportsOp2RegOptional) + { + assert(resultOpNum != 2); + op2->SetRegOptional(); + } + else if (supportsOp1RegOptional) + { + op1->SetRegOptional(); + } } else { diff --git a/src/coreclr/jit/lsraxarch.cpp b/src/coreclr/jit/lsraxarch.cpp index 2926f54ce1b1b..ca169600e83f8 100644 --- a/src/coreclr/jit/lsraxarch.cpp +++ b/src/coreclr/jit/lsraxarch.cpp @@ -2272,48 +2272,93 @@ int LinearScan::BuildHWIntrinsic(GenTreeHWIntrinsic* intrinsicTree) const bool copiesUpperBits = HWIntrinsicInfo::CopiesUpperBits(intrinsicId); - // Intrinsics with CopyUpperBits semantics cannot have op1 be contained - assert(!copiesUpperBits || !op1->isContained()); + unsigned resultOpNum = 0; + LIR::Use use; + GenTree* user = nullptr; - if (op2->isContained()) + if (LIR::AsRange(blockSequence[curBBSeqNum]).TryGetUse(intrinsicTree, &use)) { - // 132 form: op1 = (op1 * op3) + [op2] + user = use.User(); + } + resultOpNum = intrinsicTree->GetResultOpNumForFMA(user, op1, op2, op3); - tgtPrefUse = BuildUse(op1); + unsigned containedOpNum = 0; - srcCount += 1; - srcCount += BuildOperandUses(op2); - srcCount += BuildDelayFreeUses(op3, op1); + // containedOpNum remains 0 when no operand is contained or regOptional + if (op1->isContained() || op1->IsRegOptional()) + { + containedOpNum = 1; } - else if (op1->isContained()) + else if (op2->isContained() || op2->IsRegOptional()) { - // 231 form: op3 = (op2 * op3) + [op1] - - tgtPrefUse = BuildUse(op3); - - srcCount += BuildOperandUses(op1); - srcCount += BuildDelayFreeUses(op2, op1); - srcCount += 1; + containedOpNum = 2; } - else + else if (op3->isContained() || op3->IsRegOptional()) { - // 213 form: op1 = (op2 * op1) + [op3] + containedOpNum = 3; + } - tgtPrefUse = BuildUse(op1); - srcCount += 1; + GenTree* emitOp1 = op1; + GenTree* emitOp2 = op2; + GenTree* emitOp3 = op3; - if (copiesUpperBits) + // Intrinsics with CopyUpperBits semantics must have op1 as target + assert(containedOpNum != 1 || !copiesUpperBits); + + if (containedOpNum == 1) + { + // https://github.com/dotnet/runtime/issues/62215 + // resultOpNum might change between lowering and lsra, comment out assertion for now. + // assert(containedOpNum != resultOpNum); + // resultOpNum is 3 or 0: op3/? = ([op1] * op2) + op3 + std::swap(emitOp1, emitOp3); + + if (resultOpNum == 2) { - srcCount += BuildDelayFreeUses(op2, op1); + // op2 = ([op1] * op2) + op3 + std::swap(emitOp2, emitOp3); } - else + } + else if (containedOpNum == 3) + { + // assert(containedOpNum != resultOpNum); + if (resultOpNum == 2 && !copiesUpperBits) { - tgtPrefUse2 = BuildUse(op2); - srcCount += 1; + // op2 = (op1 * op2) + [op3] + std::swap(emitOp1, emitOp2); } + // else: op1/? = (op1 * op2) + [op3] + } + else if (containedOpNum == 2) + { + // assert(containedOpNum != resultOpNum); - srcCount += op3->isContained() ? BuildOperandUses(op3) : BuildDelayFreeUses(op3, op1); + // op1/? = (op1 * [op2]) + op3 + std::swap(emitOp2, emitOp3); + if (resultOpNum == 3 && !copiesUpperBits) + { + // op3 = (op1 * [op2]) + op3 + std::swap(emitOp1, emitOp2); + } } + else + { + // containedOpNum == 0 + // no extra work when resultOpNum is 0 or 1 + if (resultOpNum == 2) + { + std::swap(emitOp1, emitOp2); + } + else if (resultOpNum == 3) + { + std::swap(emitOp1, emitOp3); + } + } + tgtPrefUse = BuildUse(emitOp1); + + srcCount += 1; + srcCount += BuildDelayFreeUses(emitOp2, emitOp1); + srcCount += emitOp3->isContained() ? BuildOperandUses(emitOp3) : BuildDelayFreeUses(emitOp3, emitOp1); buildUses = false; break;