Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[GlobalISel] Fold G_ICMP if possible #86357

Merged
merged 1 commit into from
Mar 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions llvm/include/llvm/CodeGen/GlobalISel/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,10 @@ std::optional<SmallVector<unsigned>>
ConstantFoldCountZeros(Register Src, const MachineRegisterInfo &MRI,
std::function<unsigned(APInt)> CB);

std::optional<SmallVector<APInt>>
ConstantFoldICmp(unsigned Pred, const Register Op1, const Register Op2,
const MachineRegisterInfo &MRI);

/// Test if the given value is known to have exactly one bit set. This differs
/// from computeKnownBits in that it doesn't necessarily determine which bit is
/// set.
Expand Down
14 changes: 14 additions & 0 deletions llvm/lib/CodeGen/GlobalISel/CSEMIRBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,20 @@ MachineInstrBuilder CSEMIRBuilder::buildInstr(unsigned Opc,
switch (Opc) {
default:
break;
case TargetOpcode::G_ICMP: {
assert(SrcOps.size() == 3 && "Invalid sources");
assert(DstOps.size() == 1 && "Invalid dsts");
LLT SrcTy = SrcOps[1].getLLTTy(*getMRI());

if (std::optional<SmallVector<APInt>> Cst =
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see the point of using APInt in this interface. You're discarding the used result type here, and then forcing it into a 1-bit APInt. Either preserve the original result type, or just return a bool?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is a type check in MachineIRBuilder::buildConstant to make sure their bits are same. The use of APInt is to unify the ConstantFoldICmp interface to make it support both vector and scalar value.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could add a non-APInt version of buildBuildVectorConstant

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In that case it would accept ArrayRef<int64_t> I suppose, which still needs to have an explicit construction of that array because we can't have implicit cast from, say ArrayRef<bool> to ArrayRef<int64_t>. That doesn't look too much different from ArrayRef<APInt>.

ConstantFoldICmp(SrcOps[0].getPredicate(), SrcOps[1].getReg(),
SrcOps[2].getReg(), *getMRI())) {
if (SrcTy.isVector())
return buildBuildVectorConstant(DstOps[0], *Cst);
return buildConstant(DstOps[0], Cst->front());
}
break;
}
case TargetOpcode::G_ADD:
case TargetOpcode::G_PTR_ADD:
case TargetOpcode::G_AND:
Expand Down
34 changes: 25 additions & 9 deletions llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3768,9 +3768,11 @@ LegalizerHelper::lower(MachineInstr &MI, unsigned TypeIdx, LLT LowerHintTy) {
}
case TargetOpcode::G_ATOMIC_CMPXCHG_WITH_SUCCESS: {
auto [OldValRes, SuccessRes, Addr, CmpVal, NewVal] = MI.getFirst5Regs();
MIRBuilder.buildAtomicCmpXchg(OldValRes, Addr, CmpVal, NewVal,
Register NewOldValRes = MRI.cloneVirtualRegister(OldValRes);
MIRBuilder.buildAtomicCmpXchg(NewOldValRes, Addr, CmpVal, NewVal,
**MI.memoperands_begin());
MIRBuilder.buildICmp(CmpInst::ICMP_EQ, SuccessRes, OldValRes, CmpVal);
MIRBuilder.buildICmp(CmpInst::ICMP_EQ, SuccessRes, NewOldValRes, CmpVal);
MIRBuilder.buildCopy(OldValRes, NewOldValRes);
MI.eraseFromParent();
return Legalized;
}
Expand All @@ -3789,8 +3791,12 @@ LegalizerHelper::lower(MachineInstr &MI, unsigned TypeIdx, LLT LowerHintTy) {
case G_UADDO: {
auto [Res, CarryOut, LHS, RHS] = MI.getFirst4Regs();

MIRBuilder.buildAdd(Res, LHS, RHS);
MIRBuilder.buildICmp(CmpInst::ICMP_ULT, CarryOut, Res, RHS);
Register NewRes = MRI.cloneVirtualRegister(Res);

MIRBuilder.buildAdd(NewRes, LHS, RHS);
MIRBuilder.buildICmp(CmpInst::ICMP_ULT, CarryOut, NewRes, RHS);

MIRBuilder.buildCopy(Res, NewRes);

MI.eraseFromParent();
return Legalized;
Expand All @@ -3800,6 +3806,8 @@ LegalizerHelper::lower(MachineInstr &MI, unsigned TypeIdx, LLT LowerHintTy) {
const LLT CondTy = MRI.getType(CarryOut);
const LLT Ty = MRI.getType(Res);

Register NewRes = MRI.cloneVirtualRegister(Res);

// Initial add of the two operands.
auto TmpRes = MIRBuilder.buildAdd(Ty, LHS, RHS);

Expand All @@ -3808,15 +3816,18 @@ LegalizerHelper::lower(MachineInstr &MI, unsigned TypeIdx, LLT LowerHintTy) {

// Add the sum and the carry.
auto ZExtCarryIn = MIRBuilder.buildZExt(Ty, CarryIn);
MIRBuilder.buildAdd(Res, TmpRes, ZExtCarryIn);
MIRBuilder.buildAdd(NewRes, TmpRes, ZExtCarryIn);

// Second check for carry. We can only carry if the initial sum is all 1s
// and the carry is set, resulting in a new sum of 0.
auto Zero = MIRBuilder.buildConstant(Ty, 0);
auto ResEqZero = MIRBuilder.buildICmp(CmpInst::ICMP_EQ, CondTy, Res, Zero);
auto ResEqZero =
MIRBuilder.buildICmp(CmpInst::ICMP_EQ, CondTy, NewRes, Zero);
auto Carry2 = MIRBuilder.buildAnd(CondTy, ResEqZero, CarryIn);
MIRBuilder.buildOr(CarryOut, Carry, Carry2);

MIRBuilder.buildCopy(Res, NewRes);

MI.eraseFromParent();
return Legalized;
}
Expand Down Expand Up @@ -7671,10 +7682,12 @@ LegalizerHelper::lowerSADDO_SSUBO(MachineInstr &MI) {
LLT Ty = Dst0Ty;
LLT BoolTy = Dst1Ty;

Register NewDst0 = MRI.cloneVirtualRegister(Dst0);

if (IsAdd)
MIRBuilder.buildAdd(Dst0, LHS, RHS);
MIRBuilder.buildAdd(NewDst0, LHS, RHS);
else
MIRBuilder.buildSub(Dst0, LHS, RHS);
MIRBuilder.buildSub(NewDst0, LHS, RHS);

// TODO: If SADDSAT/SSUBSAT is legal, compare results to detect overflow.

Expand All @@ -7687,12 +7700,15 @@ LegalizerHelper::lowerSADDO_SSUBO(MachineInstr &MI) {
// (LHS) if and only if the other operand (RHS) is (non-zero) positive,
// otherwise there will be overflow.
auto ResultLowerThanLHS =
MIRBuilder.buildICmp(CmpInst::ICMP_SLT, BoolTy, Dst0, LHS);
shiltian marked this conversation as resolved.
Show resolved Hide resolved
MIRBuilder.buildICmp(CmpInst::ICMP_SLT, BoolTy, NewDst0, LHS);
auto ConditionRHS = MIRBuilder.buildICmp(
IsAdd ? CmpInst::ICMP_SLT : CmpInst::ICMP_SGT, BoolTy, RHS, Zero);

MIRBuilder.buildXor(Dst1, ConditionRHS, ResultLowerThanLHS);

MIRBuilder.buildCopy(Dst0, NewDst0);
MI.eraseFromParent();

return Legalized;
}

Expand Down
68 changes: 68 additions & 0 deletions llvm/lib/CodeGen/GlobalISel/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -997,6 +997,74 @@ llvm::ConstantFoldCountZeros(Register Src, const MachineRegisterInfo &MRI,
return std::nullopt;
}

std::optional<SmallVector<APInt>>
llvm::ConstantFoldICmp(unsigned Pred, const Register Op1, const Register Op2,
const MachineRegisterInfo &MRI) {
LLT Ty = MRI.getType(Op1);
if (Ty != MRI.getType(Op2))
return std::nullopt;

auto TryFoldScalar = [&MRI, Pred](Register LHS,
Register RHS) -> std::optional<APInt> {
auto LHSCst = getIConstantVRegVal(LHS, MRI);
shiltian marked this conversation as resolved.
Show resolved Hide resolved
auto RHSCst = getIConstantVRegVal(RHS, MRI);
if (!LHSCst || !RHSCst)
return std::nullopt;

switch (Pred) {
case CmpInst::Predicate::ICMP_EQ:
return APInt(/*numBits=*/1, LHSCst->eq(*RHSCst));
case CmpInst::Predicate::ICMP_NE:
return APInt(/*numBits=*/1, LHSCst->ne(*RHSCst));
case CmpInst::Predicate::ICMP_UGT:
return APInt(/*numBits=*/1, LHSCst->ugt(*RHSCst));
case CmpInst::Predicate::ICMP_UGE:
return APInt(/*numBits=*/1, LHSCst->uge(*RHSCst));
case CmpInst::Predicate::ICMP_ULT:
return APInt(/*numBits=*/1, LHSCst->ult(*RHSCst));
case CmpInst::Predicate::ICMP_ULE:
return APInt(/*numBits=*/1, LHSCst->ule(*RHSCst));
case CmpInst::Predicate::ICMP_SGT:
return APInt(/*numBits=*/1, LHSCst->sgt(*RHSCst));
case CmpInst::Predicate::ICMP_SGE:
return APInt(/*numBits=*/1, LHSCst->sge(*RHSCst));
case CmpInst::Predicate::ICMP_SLT:
return APInt(/*numBits=*/1, LHSCst->slt(*RHSCst));
case CmpInst::Predicate::ICMP_SLE:
return APInt(/*numBits=*/1, LHSCst->sle(*RHSCst));
default:
return std::nullopt;
}
};

SmallVector<APInt> FoldedICmps;

if (Ty.isVector()) {
// Try to constant fold each element.
auto *BV1 = getOpcodeDef<GBuildVector>(Op1, MRI);
auto *BV2 = getOpcodeDef<GBuildVector>(Op2, MRI);
if (!BV1 || !BV2)
return std::nullopt;
assert(BV1->getNumSources() == BV2->getNumSources() && "Invalid vectors");
for (unsigned I = 0; I < BV1->getNumSources(); ++I) {
if (auto MaybeFold =
TryFoldScalar(BV1->getSourceReg(I), BV2->getSourceReg(I))) {
FoldedICmps.emplace_back(*MaybeFold);
continue;
}
return std::nullopt;
}
return FoldedICmps;
}

if (auto MaybeCst = TryFoldScalar(Op1, Op2)) {
FoldedICmps.emplace_back(*MaybeCst);
return FoldedICmps;
}

return std::nullopt;
}

bool llvm::isKnownToBeAPowerOfTwo(Register Reg, const MachineRegisterInfo &MRI,
GISelKnownBits *KB) {
std::optional<DefinitionAndSourceRegister> DefSrcReg =
Expand Down
Loading
Loading