Skip to content

Commit

Permalink
Optimize FMA codegen based on the overwritten operand (#58196)
Browse files Browse the repository at this point in the history
* Optimize FMA codegen based on the overwritten operand

* Improve function/var names

* Add assertions

* Get use of FMA with TryGetUse

* Decide FMA form with two conditions, OverwrittenOpNum and isContained

* Fix op reg error in codegen

* Decide form using lastUse and isContained in no overwritten case

* Clean up code

* Separate default case overwrittenOpNum==0

* Apply format patch

* Change variable and function names

* Update regOptional for op1 and resolve some other comments

* Optimize FMA codegen based on the overwritten operand

* Improve function/var names

* Add assertions

* Get use of FMA with TryGetUse

* Decide FMA form with two conditions, OverwrittenOpNum and isContained

* Fix op reg error in codegen

* Decide form using lastUse and isContained in no overwritten case

* Clean up code

* Separate default case overwrittenOpNum==0

* Apply format patch

* Change variable and function names

* Update regOptional for op1 and resolve some other comments

* Change var names

* Fix jit format

* Fix build node error when op1 is regOptional

* Use targetReg instead of GetResultOpNumForFMA in codegen

* Update variable names

* Refactor lsra to fix the assertion failure caused by the lastUse status changing

* Add check to prioritize contained op in lsra

* Update for jit format

* Simplify code

* Resolve comments

* Comment out assert because of lastUse change

* Fix some copiesUpperBits related errors

* Update src/coreclr/jit/lsraxarch.cpp

Co-authored-by: Kunal Pathak <Kunal.Pathak@microsoft.com>

* Add link to the new issue

Co-authored-by: Kunal Pathak <Kunal.Pathak@microsoft.com>
  • Loading branch information
weilinwa and kunalspathak authored Dec 1, 2021
1 parent e9c6c04 commit 42777cc
Show file tree
Hide file tree
Showing 5 changed files with 210 additions and 89 deletions.
47 changes: 47 additions & 0 deletions src/coreclr/jit/gentree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21898,6 +21898,53 @@ uint16_t GenTreeLclVarCommon::GetLclOffs() const
}
}

#if defined(TARGET_XARCH) && defined(FEATURE_HW_INTRINSICS)
//------------------------------------------------------------------------
// GetResultOpNumForFMA: Select which FMA source operand, if any, the result should be associated with.
//
// Arguments:
//    use - the node that consumes this FMA node's value; may be nullptr
//    op1 - first source operand of the FMA node
//    op2 - second source operand of the FMA node
//    op3 - third source operand of the FMA node
//
// Return Value:
//    The 1-based number of the first operand, scanning op1, op2, op3 in that order, that
//      - is a local with the same local number as the one written by 'use', when 'use'
//        is a GT_STORE_LCL_VAR (the operand is overwritten by the result); or, if no
//        operand is overwritten,
//      - is a GT_LCL_VAR for which IsLastUse(0) holds.
//    0 when neither search finds an operand.
//
// Notes:
//    The overwritten search always takes priority over the last-use search.
//
unsigned GenTreeHWIntrinsic::GetResultOpNumForFMA(GenTree* use, GenTree* op1, GenTree* op2, GenTree* op3)
{
    // This helper is meaningful for FMA intrinsic nodes only.
    assert(HWIntrinsicInfo::lookupIsa(gtHWIntrinsicId) == InstructionSet_FMA);

    // Keep the operands in their original order so that both scans below
    // preserve the op1 > op2 > op3 priority.
    GenTree* const operands[]   = {op1, op2, op3};
    const unsigned operandCount = 3;

    // First pass: when the consumer stores into a local, look for an operand
    // that reads that same local, i.e. an operand the store overwrites.
    if ((use != nullptr) && use->OperIs(GT_STORE_LCL_VAR))
    {
        const unsigned storedLclNum = use->AsLclVarCommon()->GetLclNum();

        for (unsigned index = 0; index < operandCount; index++)
        {
            GenTree* const operand = operands[index];

            if (operand->IsLocal() && (operand->AsLclVarCommon()->GetLclNum() == storedLclNum))
            {
                return index + 1;
            }
        }
    }

    // Second pass: nothing is overwritten, so fall back to a last-use local.
    // https://github.com/dotnet/runtime/issues/62215
    for (unsigned index = 0; index < operandCount; index++)
    {
        GenTree* const operand = operands[index];

        if (operand->OperIs(GT_LCL_VAR) && operand->IsLastUse(0))
        {
            return index + 1;
        }
    }

    // Neither an overwritten nor a last-use operand was found.
    return 0;
}
#endif // TARGET_XARCH && FEATURE_HW_INTRINSICS

#ifdef TARGET_ARM
//------------------------------------------------------------------------
// IsOffsetMisaligned: check if the field needs a special handling on arm.
Expand Down
1 change: 1 addition & 0 deletions src/coreclr/jit/gentree.h
Original file line number Diff line number Diff line change
Expand Up @@ -5526,6 +5526,7 @@ struct GenTreeHWIntrinsic : public GenTreeJitIntrinsic
{
return (gtFlags & GTF_SIMDASHW_OP) != 0;
}
unsigned GetResultOpNumForFMA(GenTree* use, GenTree* op1, GenTree* op2, GenTree* op3);

NamedIntrinsic GetHWIntrinsicId() const;

Expand Down
97 changes: 56 additions & 41 deletions src/coreclr/jit/hwintrinsiccodegenxarch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2034,67 +2034,82 @@ void CodeGen::genFMAIntrinsic(GenTreeHWIntrinsic* node)
NamedIntrinsic intrinsicId = node->GetHWIntrinsicId();
var_types baseType = node->GetSimdBaseType();
emitAttr attr = emitActualTypeSize(Compiler::getSIMDTypeForSize(node->GetSimdSize()));
instruction ins = HWIntrinsicInfo::lookupIns(intrinsicId, baseType);
instruction ins = HWIntrinsicInfo::lookupIns(intrinsicId, baseType); // 213 form
instruction _132form = (instruction)(ins - 1);
instruction _231form = (instruction)(ins + 1);
GenTree* op1 = node->Op(1);
GenTree* op2 = node->Op(2);
GenTree* op3 = node->Op(3);
regNumber targetReg = node->GetRegNum();

regNumber targetReg = node->GetRegNum();

genConsumeMultiOpOperands(node);

regNumber op1Reg;
regNumber op2Reg;
regNumber op1NodeReg = op1->GetRegNum();
regNumber op2NodeReg = op2->GetRegNum();
regNumber op3NodeReg = op3->GetRegNum();

GenTree* emitOp1 = op1;
GenTree* emitOp2 = op2;
GenTree* emitOp3 = op3;

bool isCommutative = false;
const bool copiesUpperBits = HWIntrinsicInfo::CopiesUpperBits(intrinsicId);

// Intrinsics with CopyUpperBits semantics cannot have op1 be contained
assert(!copiesUpperBits || !op1->isContained());

if (op2->isContained() || op2->isUsedFromSpillTemp())
if (op1->isContained() || op1->isUsedFromSpillTemp())
{
// 132 form: op1 = (op1 * op3) + [op2]

ins = (instruction)(ins - 1);
op1Reg = op1->GetRegNum();
op2Reg = op3->GetRegNum();
op3 = op2;
if (targetReg == op2NodeReg)
{
std::swap(emitOp1, emitOp2);
// op2 = ([op1] * op2) + op3
// 132 form: XMM1 = (XMM1 * [XMM3]) + XMM2
ins = _132form;
std::swap(emitOp2, emitOp3);
}
else
{
// targetReg == op3NodeReg or targetReg == ?
// op3 = ([op1] * op2) + op3
// 231 form: XMM1 = (XMM2 * [XMM3]) + XMM1
ins = _231form;
std::swap(emitOp1, emitOp3);
}
}
else if (op1->isContained() || op1->isUsedFromSpillTemp())
else if (op2->isContained() || op2->isUsedFromSpillTemp())
{
// 231 form: op3 = (op2 * op3) + [op1]

ins = (instruction)(ins + 1);
op1Reg = op3->GetRegNum();
op2Reg = op2->GetRegNum();
op3 = op1;
if (!copiesUpperBits && (targetReg == op3NodeReg))
{
// op3 = (op1 * [op2]) + op3
// 231 form: XMM1 = (XMM2 * [XMM3]) + XMM1
ins = _231form;
std::swap(emitOp1, emitOp3);
}
else
{
// targetReg == op1NodeReg or targetReg == ?
// op1 = (op1 * [op2]) + op3
// 132 form: XMM1 = (XMM1 * [XMM3]) + XMM2
ins = _132form;
}
std::swap(emitOp2, emitOp3);
}
else
{
// 213 form: op1 = (op2 * op1) + [op3]

op1Reg = op1->GetRegNum();
op2Reg = op2->GetRegNum();

isCommutative = !copiesUpperBits;
}

if (isCommutative && (op1Reg != targetReg) && (op2Reg == targetReg))
{
assert(node->isRMWHWIntrinsic(compiler));

// We have "reg2 = (reg1 * reg2) +/- op3" where "reg1 != reg2" on a RMW intrinsic.
//
// For non-commutative intrinsics, we should have ensured that op2 was marked
// delay free in order to prevent it from getting assigned the same register
// as target. However, for commutative intrinsics, we can just swap the operands
// in order to have "reg2 = reg2 op reg1" which will end up producing the right code.

op2Reg = op1Reg;
op1Reg = targetReg;
// targetReg could be op1NodeReg, op2NodeReg, or not equal to any op
// op1 = (op1 * op2) + [op3] or op2 = (op1 * op2) + [op3]
// ? = (op1 * op2) + [op3] or ? = (op1 * op2) + op3
// 213 form: XMM1 = (XMM2 * XMM1) + [XMM3]
if (!copiesUpperBits && (targetReg == op2NodeReg))
{
// op2 = (op1 * op2) + [op3]
// 213 form: XMM1 = (XMM2 * XMM1) + [XMM3]
std::swap(emitOp1, emitOp2);
}
}

genHWIntrinsic_R_R_R_RM(ins, attr, targetReg, op1Reg, op2Reg, op3);
genHWIntrinsic_R_R_R_RM(ins, attr, targetReg, emitOp1->GetRegNum(), emitOp2->GetRegNum(), emitOp3);
genProduceReg(node);
}

Expand Down
57 changes: 35 additions & 22 deletions src/coreclr/jit/lowerxarch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6000,40 +6000,53 @@ void Lowering::ContainCheckHWIntrinsic(GenTreeHWIntrinsic* node)
{
if ((intrinsicId >= NI_FMA_MultiplyAdd) && (intrinsicId <= NI_FMA_MultiplySubtractNegatedScalar))
{
bool supportsRegOptional = false;
bool supportsOp1RegOptional = false;
bool supportsOp2RegOptional = false;
bool supportsOp3RegOptional = false;
unsigned resultOpNum = 0;
LIR::Use use;
GenTree* user = nullptr;

if (BlockRange().TryGetUse(node, &use))
{
user = use.User();
}
resultOpNum = node->GetResultOpNumForFMA(user, op1, op2, op3);

// Prioritize Containable op. Check if any one of the op is containable first.
// Set op regOptional only if none of them is containable.

if (IsContainableHWIntrinsicOp(node, op3, &supportsRegOptional))
// Prefer to make op3 contained,
if (resultOpNum != 3 && IsContainableHWIntrinsicOp(node, op3, &supportsOp3RegOptional))
{
// 213 form: op1 = (op2 * op1) + [op3]
// result = (op1 * op2) + [op3]
MakeSrcContained(node, op3);
}
else if (IsContainableHWIntrinsicOp(node, op2, &supportsRegOptional))
else if (resultOpNum != 2 && IsContainableHWIntrinsicOp(node, op2, &supportsOp2RegOptional))
{
// 132 form: op1 = (op1 * op3) + [op2]
// result = (op1 * [op2]) + op3
MakeSrcContained(node, op2);
}
else if (IsContainableHWIntrinsicOp(node, op1, &supportsRegOptional))
else if (resultOpNum != 1 && !HWIntrinsicInfo::CopiesUpperBits(intrinsicId) &&
IsContainableHWIntrinsicOp(node, op1, &supportsOp1RegOptional))
{
// Intrinsics with CopyUpperBits semantics cannot have op1 be contained

if (!HWIntrinsicInfo::CopiesUpperBits(intrinsicId))
{
// 231 form: op3 = (op2 * op3) + [op1]
MakeSrcContained(node, op1);
}
// result = ([op1] * op2) + op3
MakeSrcContained(node, op1);
}
else
else if (supportsOp3RegOptional)
{
assert(supportsRegOptional);

// TODO-XArch-CQ: Technically any one of the three operands can
// be reg-optional. With a limitation on op1 where
// it can only be so if CopyUpperBits is off.
// https://github.com/dotnet/runtime/issues/6358

// 213 form: op1 = (op2 * op1) + op3
assert(resultOpNum != 3);
op3->SetRegOptional();
}
else if (supportsOp2RegOptional)
{
assert(resultOpNum != 2);
op2->SetRegOptional();
}
else if (supportsOp1RegOptional)
{
op1->SetRegOptional();
}
}
else
{
Expand Down
97 changes: 71 additions & 26 deletions src/coreclr/jit/lsraxarch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2272,48 +2272,93 @@ int LinearScan::BuildHWIntrinsic(GenTreeHWIntrinsic* intrinsicTree)

const bool copiesUpperBits = HWIntrinsicInfo::CopiesUpperBits(intrinsicId);

// Intrinsics with CopyUpperBits semantics cannot have op1 be contained
assert(!copiesUpperBits || !op1->isContained());
unsigned resultOpNum = 0;
LIR::Use use;
GenTree* user = nullptr;

if (op2->isContained())
if (LIR::AsRange(blockSequence[curBBSeqNum]).TryGetUse(intrinsicTree, &use))
{
// 132 form: op1 = (op1 * op3) + [op2]
user = use.User();
}
resultOpNum = intrinsicTree->GetResultOpNumForFMA(user, op1, op2, op3);

tgtPrefUse = BuildUse(op1);
unsigned containedOpNum = 0;

srcCount += 1;
srcCount += BuildOperandUses(op2);
srcCount += BuildDelayFreeUses(op3, op1);
// containedOpNum remains 0 when no operand is contained or regOptional
if (op1->isContained() || op1->IsRegOptional())
{
containedOpNum = 1;
}
else if (op1->isContained())
else if (op2->isContained() || op2->IsRegOptional())
{
// 231 form: op3 = (op2 * op3) + [op1]

tgtPrefUse = BuildUse(op3);

srcCount += BuildOperandUses(op1);
srcCount += BuildDelayFreeUses(op2, op1);
srcCount += 1;
containedOpNum = 2;
}
else
else if (op3->isContained() || op3->IsRegOptional())
{
// 213 form: op1 = (op2 * op1) + [op3]
containedOpNum = 3;
}

tgtPrefUse = BuildUse(op1);
srcCount += 1;
GenTree* emitOp1 = op1;
GenTree* emitOp2 = op2;
GenTree* emitOp3 = op3;

if (copiesUpperBits)
// Intrinsics with CopyUpperBits semantics must have op1 as target
assert(containedOpNum != 1 || !copiesUpperBits);

if (containedOpNum == 1)
{
// https://github.com/dotnet/runtime/issues/62215
// resultOpNum might change between lowering and lsra, comment out assertion for now.
// assert(containedOpNum != resultOpNum);
// resultOpNum is 3 or 0: op3/? = ([op1] * op2) + op3
std::swap(emitOp1, emitOp3);

if (resultOpNum == 2)
{
srcCount += BuildDelayFreeUses(op2, op1);
// op2 = ([op1] * op2) + op3
std::swap(emitOp2, emitOp3);
}
else
}
else if (containedOpNum == 3)
{
// assert(containedOpNum != resultOpNum);
if (resultOpNum == 2 && !copiesUpperBits)
{
tgtPrefUse2 = BuildUse(op2);
srcCount += 1;
// op2 = (op1 * op2) + [op3]
std::swap(emitOp1, emitOp2);
}
// else: op1/? = (op1 * op2) + [op3]
}
else if (containedOpNum == 2)
{
// assert(containedOpNum != resultOpNum);

srcCount += op3->isContained() ? BuildOperandUses(op3) : BuildDelayFreeUses(op3, op1);
// op1/? = (op1 * [op2]) + op3
std::swap(emitOp2, emitOp3);
if (resultOpNum == 3 && !copiesUpperBits)
{
// op3 = (op1 * [op2]) + op3
std::swap(emitOp1, emitOp2);
}
}
else
{
// containedOpNum == 0
// no extra work when resultOpNum is 0 or 1
if (resultOpNum == 2)
{
std::swap(emitOp1, emitOp2);
}
else if (resultOpNum == 3)
{
std::swap(emitOp1, emitOp3);
}
}
tgtPrefUse = BuildUse(emitOp1);

srcCount += 1;
srcCount += BuildDelayFreeUses(emitOp2, emitOp1);
srcCount += emitOp3->isContained() ? BuildOperandUses(emitOp3) : BuildDelayFreeUses(emitOp3, emitOp1);

buildUses = false;
break;
Expand Down

0 comments on commit 42777cc

Please sign in to comment.