Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize FMA codegen base on the overwritten #58196

Merged
merged 40 commits into from
Dec 1, 2021
Merged
Show file tree
Hide file tree
Changes from 36 commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
ee2c0b6
Optimize FMA codegen base on the overwritten
weilinwa Jul 20, 2021
46d0011
Improve function/var names
weilinwa Aug 27, 2021
cce4bda
Add assertions
weilinwa Aug 27, 2021
b825291
Get use of FMA with TryGetUse
weilinwa Sep 7, 2021
f615e39
Decide FMA form with two conditions, OverwrittenOpNum and isContained
weilinwa Sep 8, 2021
b698036
Fix op reg error in codegen
weilinwa Sep 10, 2021
7d9c0d6
Decide form using lastUse and isContained in no overwritten case
weilinwa Sep 15, 2021
1344d92
Clean up code
weilinwa Sep 18, 2021
029a9b5
Separate default case overwrittenOpNum==0
weilinwa Sep 20, 2021
f2a371f
Apply format patch
weilinwa Sep 29, 2021
9955389
Change variable and function names
weilinwa Oct 1, 2021
7c56653
Update regOptional for op1 and resolve some other comments
weilinwa Oct 5, 2021
1d51caa
Optimize FMA codegen base on the overwritten
weilinwa Jul 20, 2021
091133e
Improve function/var names
weilinwa Aug 27, 2021
9a6ae44
Add assertions
weilinwa Aug 27, 2021
ffcff76
Get use of FMA with TryGetUse
weilinwa Sep 7, 2021
5641f8f
Decide FMA form with two conditions, OverwrittenOpNum and isContained
weilinwa Sep 8, 2021
b7312ac
Fix op reg error in codegen
weilinwa Sep 10, 2021
a325fe3
Decide form using lastUse and isContained in no overwritten case
weilinwa Sep 15, 2021
0f950dd
Clean up code
weilinwa Sep 18, 2021
33a596d
Separate default case overwrittenOpNum==0
weilinwa Sep 20, 2021
5da9368
Apply format patch
weilinwa Sep 29, 2021
c3a9f07
Change variable and function names
weilinwa Oct 1, 2021
9e356aa
Update regOptional for op1 and resolve some other comments
weilinwa Oct 5, 2021
f8159bc
Change var names
weilinwa Oct 13, 2021
18bbe4d
Resolve merge conflicts.
weilinwa Oct 13, 2021
2ca2524
Fix jit format
weilinwa Oct 13, 2021
17bd967
Fix build node error for op1 is regOptional
weilinwa Oct 14, 2021
eed5912
Use targetReg instead of GetResultOpNumForFMA in codegen
weilinwa Oct 28, 2021
43c5034
Update variable names
weilinwa Nov 2, 2021
5ef70a5
Refactor lsra to solve lastUse status changed caused assertion failure
weilinwa Nov 7, 2021
bfa6924
Add check to prioritize contained op in lsra
weilinwa Nov 7, 2021
12f260b
Update for jit format
weilinwa Nov 7, 2021
5ca658e
Simplify code
weilinwa Nov 17, 2021
ec4ef66
Resolve comments
weilinwa Nov 17, 2021
aa93a85
Comment out assert because of lastUse change
weilinwa Nov 19, 2021
c66a018
Fix some copiesUpperBits related errors
weilinwa Nov 22, 2021
ff5a433
Merge branch 'main' into fma_opt
weilinwa Nov 22, 2021
a4657c7
Update src/coreclr/jit/lsraxarch.cpp
weilinwa Nov 30, 2021
75d7a37
Add link to the new issue
weilinwa Nov 30, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions src/coreclr/jit/gentree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21818,6 +21818,52 @@ uint16_t GenTreeLclVarCommon::GetLclOffs() const
}
}

#if defined(TARGET_XARCH) && defined(FEATURE_HW_INTRINSICS)
//------------------------------------------------------------------------
// GetResultOpNumForFMA: check if the result is written into one of the operands.
// In the case that none of the operand is overwritten, check if any of them is lastUse.
//
// Return Value:
// The operand number overwritten or lastUse. 0 is the default value, where the result is written into
// a destination that is not one of the source operands and there is no last use op.
//
unsigned GenTreeHWIntrinsic::GetResultOpNumForFMA(GenTree* use, GenTree* op1, GenTree* op2, GenTree* op3)
{
// only FMA intrinsic node should call into this function
tannergooding marked this conversation as resolved.
Show resolved Hide resolved
assert(HWIntrinsicInfo::lookupIsa(gtHWIntrinsicId) == InstructionSet_FMA);
if (use != nullptr && use->OperIs(GT_STORE_LCL_VAR))
{
// For store_lcl_var, check if any op is overwritten

GenTreeLclVarCommon* overwritten = use->AsLclVarCommon();
unsigned overwrittenLclNum = overwritten->GetLclNum();
if (op1->IsLocal() && op1->AsLclVarCommon()->GetLclNum() == overwrittenLclNum)
{
return 1;
}
else if (op2->IsLocal() && op2->AsLclVarCommon()->GetLclNum() == overwrittenLclNum)
{
return 2;
}
else if (op3->IsLocal() && op3->AsLclVarCommon()->GetLclNum() == overwrittenLclNum)
{
return 3;
}
}

// If no overwritten op, check if there is any last use op

if (op1->OperIs(GT_LCL_VAR) && op1->IsLastUse(0))
return 1;
else if (op2->OperIs(GT_LCL_VAR) && op2->IsLastUse(0))
return 2;
else if (op3->OperIs(GT_LCL_VAR) && op3->IsLastUse(0))
return 3;

return 0;
}
#endif // TARGET_XARCH && FEATURE_HW_INTRINSICS

#ifdef TARGET_ARM
//------------------------------------------------------------------------
// IsOffsetMisaligned: check if the field needs a special handling on arm.
Expand Down
1 change: 1 addition & 0 deletions src/coreclr/jit/gentree.h
Original file line number Diff line number Diff line change
Expand Up @@ -5278,6 +5278,7 @@ struct GenTreeHWIntrinsic : public GenTreeJitIntrinsic
{
return (gtFlags & GTF_SIMDASHW_OP) != 0;
}
unsigned GetResultOpNumForFMA(GenTree* use, GenTree* op1, GenTree* op2, GenTree* op3);

#if DEBUGGABLE_GENTREE
GenTreeHWIntrinsic() : GenTreeJitIntrinsic()
Expand Down
85 changes: 56 additions & 29 deletions src/coreclr/jit/hwintrinsiccodegenxarch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2106,7 +2106,9 @@ void CodeGen::genFMAIntrinsic(GenTreeHWIntrinsic* node)
NamedIntrinsic intrinsicId = node->gtHWIntrinsicId;
var_types baseType = node->GetSimdBaseType();
emitAttr attr = emitActualTypeSize(Compiler::getSIMDTypeForSize(node->GetSimdSize()));
instruction ins = HWIntrinsicInfo::lookupIns(intrinsicId, baseType);
instruction ins = HWIntrinsicInfo::lookupIns(intrinsicId, baseType); // 213 form
instruction _132form = (instruction)(ins - 1);
instruction _231form = (instruction)(ins + 1);
GenTree* op1 = node->gtGetOp1();
regNumber targetReg = node->GetRegNum();

Expand All @@ -2122,44 +2124,71 @@ void CodeGen::genFMAIntrinsic(GenTreeHWIntrinsic* node)
argList = argList->Rest();
GenTree* op3 = argList->Current();

regNumber op1Reg;
regNumber op2Reg;
regNumber op1NodeReg = op1->GetRegNum();
regNumber op2NodeReg = op2->GetRegNum();
regNumber op3NodeReg = op3->GetRegNum();

GenTree* emitOp1 = op1;
GenTree* emitOp2 = op2;
GenTree* emitOp3 = op3;

bool isCommutative = false;
const bool copiesUpperBits = HWIntrinsicInfo::CopiesUpperBits(intrinsicId);

// Intrinsics with CopyUpperBits semantics cannot have op1 be contained
assert(!copiesUpperBits || !op1->isContained());

if (op2->isContained() || op2->isUsedFromSpillTemp())
if (op1->isContained() || op1->isUsedFromSpillTemp())
{
// 132 form: op1 = (op1 * op3) + [op2]

ins = (instruction)(ins - 1);
op1Reg = op1->GetRegNum();
op2Reg = op3->GetRegNum();
op3 = op2;
if (targetReg == op2NodeReg)
{
std::swap(emitOp1, emitOp2);
// op2 = ([op1] * op2) + op3
// 132 form: XMM1 = (XMM1 * [XMM3]) + XMM2
ins = _132form;
std::swap(emitOp2, emitOp3);
}
else
{
// targetReg == op3NodeReg or targetReg == ?
// op3 = ([op1] * op2) + op3
// 231 form: XMM1 = (XMM2 * [XMM3]) + XMM1
ins = _231form;
std::swap(emitOp1, emitOp3);
}
}
else if (op1->isContained() || op1->isUsedFromSpillTemp())
else if (op2->isContained() || op2->isUsedFromSpillTemp())
{
// 231 form: op3 = (op2 * op3) + [op1]

ins = (instruction)(ins + 1);
op1Reg = op3->GetRegNum();
op2Reg = op2->GetRegNum();
op3 = op1;
if (targetReg == op3NodeReg)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this needs to be !copiesUpperBits && (targetReg == op3NodeReg)

Otherwise, copiesUpperBits can be true since op1 is not Contained or UsedFromSpillTemp and therefore swapping emitOp1 isn't correct.

{
// op3 = (op1 * [op2]) + op3
// 231 form: XMM1 = (XMM2 * [XMM3]) + XMM1
ins = _231form;
std::swap(emitOp1, emitOp3);
}
else
{
// targetReg == op1NodeReg or targetReg == ?
// op1 = (op1 * [op2]) + op3
// 132 form: XMM1 = (XMM1 * [XMM3]) + XMM2
ins = _132form;
}
std::swap(emitOp2, emitOp3);
}
else
{
// 213 form: op1 = (op2 * op1) + [op3]

op1Reg = op1->GetRegNum();
op2Reg = op2->GetRegNum();

isCommutative = !copiesUpperBits;
// targetReg could be op1NodeReg, op2NodeReg, or not equal to any op
// op1 = (op1 * op2) + [op3] or op2 = (op1 * op2) + [op3]
// ? = (op1 * op2) + [op3] or ? = (op1 * op2) + op3
// 213 form: XMM1 = (XMM2 * XMM1) + [XMM3]
if (targetReg == op2NodeReg)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Likewise, I think this needs to be if (!copiesUpperBits && (targetReg == op2NodeReg)) for the same reason.

I think we also don't need the below section doing if (!copiesUpperBits && (emitOp2->GetRegNum() == targetReg)) as it will have already been covered up here.

{
// op2 = (op1 * op2) + [op3]
// 213 form: XMM1 = (XMM2 * XMM1) + [XMM3]
std::swap(emitOp1, emitOp2);
}
}

if (isCommutative && (op1Reg != targetReg) && (op2Reg == targetReg))
if (!copiesUpperBits && (emitOp2->GetRegNum() == targetReg))
{
assert(node->isRMWHWIntrinsic(compiler));

Expand All @@ -2170,11 +2199,9 @@ void CodeGen::genFMAIntrinsic(GenTreeHWIntrinsic* node)
// as target. However, for commutative intrinsics, we can just swap the operands
// in order to have "reg2 = reg2 op reg1" which will end up producing the right code.

op2Reg = op1Reg;
op1Reg = targetReg;
std::swap(emitOp1, emitOp2);
}

genHWIntrinsic_R_R_R_RM(ins, attr, targetReg, op1Reg, op2Reg, op3);
genHWIntrinsic_R_R_R_RM(ins, attr, targetReg, emitOp1->GetRegNum(), emitOp2->GetRegNum(), emitOp3);
genProduceReg(node);
}

Expand Down
57 changes: 35 additions & 22 deletions src/coreclr/jit/lowerxarch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6335,40 +6335,53 @@ void Lowering::ContainCheckHWIntrinsic(GenTreeHWIntrinsic* node)
{
if ((intrinsicId >= NI_FMA_MultiplyAdd) && (intrinsicId <= NI_FMA_MultiplySubtractNegatedScalar))
{
bool supportsRegOptional = false;
bool supportsOp1RegOptional = false;
bool supportsOp2RegOptional = false;
bool supportsOp3RegOptional = false;
unsigned resultOpNum = 0;
LIR::Use use;
GenTree* user = nullptr;

if (BlockRange().TryGetUse(node, &use))
{
user = use.User();
}
resultOpNum = node->GetResultOpNumForFMA(user, op1, op2, op3);

// Prioritize Containable op. Check if any one of the op is containable first.
// Set op regOptional only if none of them is containable.

if (IsContainableHWIntrinsicOp(node, op3, &supportsRegOptional))
// Prefer to make op3 contained,
if (resultOpNum != 3 && IsContainableHWIntrinsicOp(node, op3, &supportsOp3RegOptional))
{
// 213 form: op1 = (op2 * op1) + [op3]
// result = (op1 * op2) + [op3]
MakeSrcContained(node, op3);
}
else if (IsContainableHWIntrinsicOp(node, op2, &supportsRegOptional))
else if (resultOpNum != 2 && IsContainableHWIntrinsicOp(node, op2, &supportsOp2RegOptional))
{
// 132 form: op1 = (op1 * op3) + [op2]
// result = (op1 * [op2]) + op3
MakeSrcContained(node, op2);
}
else if (IsContainableHWIntrinsicOp(node, op1, &supportsRegOptional))
else if (resultOpNum != 1 && !HWIntrinsicInfo::CopiesUpperBits(intrinsicId) &&
IsContainableHWIntrinsicOp(node, op1, &supportsOp1RegOptional))
{
// Intrinsics with CopyUpperBits semantics cannot have op1 be contained

if (!HWIntrinsicInfo::CopiesUpperBits(intrinsicId))
{
// 231 form: op3 = (op2 * op3) + [op1]
MakeSrcContained(node, op1);
}
// result = ([op1] * op2) + op3
MakeSrcContained(node, op1);
}
else
else if (supportsOp3RegOptional)
{
assert(supportsRegOptional);

// TODO-XArch-CQ: Technically any one of the three operands can
// be reg-optional. With a limitation on op1 where
// it can only be so if CopyUpperBits is off.
// https://github.com/dotnet/runtime/issues/6358

// 213 form: op1 = (op2 * op1) + op3
assert(resultOpNum != 3);
op3->SetRegOptional();
}
else if (supportsOp2RegOptional)
{
assert(resultOpNum != 2);
op2->SetRegOptional();
}
else if (supportsOp1RegOptional)
{
op1->SetRegOptional();
}
}
else
{
Expand Down
96 changes: 70 additions & 26 deletions src/coreclr/jit/lsraxarch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2328,48 +2328,92 @@ int LinearScan::BuildHWIntrinsic(GenTreeHWIntrinsic* intrinsicTree)

const bool copiesUpperBits = HWIntrinsicInfo::CopiesUpperBits(intrinsicId);

// Intrinsics with CopyUpperBits semantics cannot have op1 be contained
assert(!copiesUpperBits || !op1->isContained());
unsigned resultOpNum = 0;
LIR::Use use;
GenTree* user = nullptr;

if (op2->isContained())
if (LIR::AsRange(blockSequence[curBBSeqNum]).TryGetUse(intrinsicTree, &use))
{
// 132 form: op1 = (op1 * op3) + [op2]
user = use.User();
}
resultOpNum = intrinsicTree->GetResultOpNumForFMA(user, op1, op2, op3);

tgtPrefUse = BuildUse(op1);
unsigned containedOpNum = 0;

srcCount += 1;
srcCount += BuildOperandUses(op2);
srcCount += BuildDelayFreeUses(op3, op1);
// containedOpNum remains 0 when no op is contianed or regOptional
weilinwa marked this conversation as resolved.
Show resolved Hide resolved
if (op1->isContained() || op1->IsRegOptional())
{
containedOpNum = 1;
}
else if (op1->isContained())
else if (op2->isContained() || op2->IsRegOptional())
{
// 231 form: op3 = (op2 * op3) + [op1]

tgtPrefUse = BuildUse(op3);

srcCount += BuildOperandUses(op1);
srcCount += BuildDelayFreeUses(op2, op1);
srcCount += 1;
containedOpNum = 2;
}
else
else if (op3->isContained() || op3->IsRegOptional())
{
// 213 form: op1 = (op2 * op1) + [op3]
containedOpNum = 3;
}

tgtPrefUse = BuildUse(op1);
srcCount += 1;
GenTree* emitOp1 = op1;
GenTree* emitOp2 = op2;
GenTree* emitOp3 = op3;

if (copiesUpperBits)
// Intrinsics with CopyUpperBits semantics must have op1 as target
assert(containedOpNum != 1 || !copiesUpperBits);

if (containedOpNum == 1)
{
// resultOpNum might change between lowering and lsra, comment out assertion for now.
// assert(containedOpNum != resultOpNum);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Need to uncomment this assert?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This assertion cannot be uncommented because the last use value could change after lowering step. I left them here for follow up work if necessary.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please create a issue for it and add the link to the issue in the comment here?

// resultOpNum is 3 or 0: op3/? = ([op1] * op2) + op3
std::swap(emitOp1, emitOp3);

if (resultOpNum == 2)
{
srcCount += BuildDelayFreeUses(op2, op1);
// op2 = ([op1] * op2) + op3
std::swap(emitOp2, emitOp3);
}
else
}
else if (containedOpNum == 3)
{
// assert(containedOpNum != resultOpNum);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same here?

if (resultOpNum == 2 && !copiesUpperBits)
{
tgtPrefUse2 = BuildUse(op2);
srcCount += 1;
// op2 = (op1 * op2) + [op3]
std::swap(emitOp1, emitOp2);
}
// else: op1/? = (op1 * op2) + [op3]
}
else if (containedOpNum == 2)
{
// assert(containedOpNum != resultOpNum);

srcCount += op3->isContained() ? BuildOperandUses(op3) : BuildDelayFreeUses(op3, op1);
// op1/? = (op1 * [op2]) + op3
std::swap(emitOp2, emitOp3);
if (resultOpNum == 3 && !copiesUpperBits)
Copy link
Member

@tannergooding tannergooding Nov 17, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just capturing a comment, I don't think we need to do anything in this PR.

I think the logic around copiesUpperBits could be simplified a bit so we don't need these extra checks everywhere. That is, if copiesUpperBits is true, then resultOpNum doesn't matter if its not 1 so maybe we should be forcing resultOpNum to be 0 in that case (that is if copiesUpperBits == true and resultOpNum != 1, then treat it as 0, because no matter what we do, op1 cannot be swapped or moved about and op2/op3 will be delay free or contained).

{
// op3 = (op1 * [op2]) + op3
std::swap(emitOp1, emitOp2);
}
}
else
{
// containedOpNum == 0
// no extra work when resultOpNum is 0 or 1
if (resultOpNum == 2)
{
std::swap(emitOp1, emitOp2);
}
else if (resultOpNum == 3)
{
std::swap(emitOp1, emitOp3);
}
}
tgtPrefUse = BuildUse(emitOp1);

srcCount += 1;
srcCount += BuildDelayFreeUses(emitOp2, emitOp1);
srcCount += emitOp3->isContained() ? BuildOperandUses(emitOp3) : BuildDelayFreeUses(emitOp3, emitOp1);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a lot smaller and easier to follow now 🎉


buildUses = false;
break;
Expand Down