Skip to content

Commit

Permalink
[GlobalISel] Handle div-by-pow2
Browse files Browse the repository at this point in the history
This patch adds similar handling of div-by-pow2 as in `SelectionDAG`.
  • Loading branch information
shiltian committed Feb 28, 2024
1 parent 43b7dfc commit 03cfa65
Show file tree
Hide file tree
Showing 4 changed files with 2,409 additions and 14 deletions.
10 changes: 10 additions & 0 deletions llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
Original file line number Diff line number Diff line change
Expand Up @@ -673,6 +673,16 @@ class CombinerHelper {
bool matchSDivByConst(MachineInstr &MI);
void applySDivByConst(MachineInstr &MI);

/// Given a G_SDIV \p MI expressing a signed division by a power-of-two
/// constant, replace it with an equivalent sequence that uses shifts.
bool matchSDivByPow2(MachineInstr &MI);
void applySDivByPow2(MachineInstr &MI);

/// Given a G_UDIV \p MI expressing an unsigned division by a power-of-two
/// constant, replace it with an equivalent logical shift right.
bool matchUDivByPow2(MachineInstr &MI);
void applyUDivByPow2(MachineInstr &MI);

// (G_UMULH x, (1 << c)) -> x >> (bitwidth - c)
bool matchUMulHToLShr(MachineInstr &MI);
void applyUMulHToLShr(MachineInstr &MI);
Expand Down
21 changes: 17 additions & 4 deletions llvm/include/llvm/Target/GlobalISel/Combine.td
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ def combine_extracted_vector_load : GICombineRule<
(match (wip_match_opcode G_EXTRACT_VECTOR_ELT):$root,
[{ return Helper.matchCombineExtractedVectorLoad(*${root}, ${matchinfo}); }]),
(apply [{ Helper.applyBuildFn(*${root}, ${matchinfo}); }])>;

def combine_indexed_load_store : GICombineRule<
(defs root:$root, indexed_load_store_matchdata:$matchinfo),
(match (wip_match_opcode G_LOAD, G_SEXTLOAD, G_ZEXTLOAD, G_STORE):$root,
Expand Down Expand Up @@ -1005,7 +1005,20 @@ def sdiv_by_const : GICombineRule<
[{ return Helper.matchSDivByConst(*${root}); }]),
(apply [{ Helper.applySDivByConst(*${root}); }])>;

def intdiv_combines : GICombineGroup<[udiv_by_const, sdiv_by_const]>;
// Combine rule rooted at a G_SDIV: Helper.matchSDivByPow2 decides whether the
// rule fires and Helper.applySDivByPow2 performs the rewrite.
def sdiv_by_pow2 : GICombineRule<
(defs root:$root),
(match (wip_match_opcode G_SDIV):$root,
[{ return Helper.matchSDivByPow2(*${root}); }]),
(apply [{ Helper.applySDivByPow2(*${root}); }])>;

// Combine rule rooted at a G_UDIV: Helper.matchUDivByPow2 decides whether the
// rule fires and Helper.applyUDivByPow2 performs the rewrite.
def udiv_by_pow2 : GICombineRule<
(defs root:$root),
(match (wip_match_opcode G_UDIV):$root,
[{ return Helper.matchUDivByPow2(*${root}); }]),
(apply [{ Helper.applyUDivByPow2(*${root}); }])>;

// Group of the integer-division combines: the divide-by-constant rules plus
// the divide-by-power-of-two rules.
def intdiv_combines : GICombineGroup<[udiv_by_const, sdiv_by_const,
sdiv_by_pow2, udiv_by_pow2]>;

def reassoc_ptradd : GICombineRule<
(defs root:$root, build_fn_matchinfo:$matchinfo),
Expand Down Expand Up @@ -1325,7 +1338,7 @@ def constant_fold_binops : GICombineGroup<[constant_fold_binop,

def all_combines : GICombineGroup<[trivial_combines, insert_vec_elt_combines,
extract_vec_elt_combines, combines_for_extload, combine_extracted_vector_load,
undef_combines, identity_combines, phi_combines,
undef_combines, identity_combines, phi_combines,
simplify_add_to_sub, hoist_logic_op_with_same_opcode_hands, shifts_too_big,
reassocs, ptr_add_immed_chain,
shl_ashr_to_sext_inreg, sext_inreg_of_load,
Expand All @@ -1342,7 +1355,7 @@ def all_combines : GICombineGroup<[trivial_combines, insert_vec_elt_combines,
intdiv_combines, mulh_combines, redundant_neg_operands,
and_or_disjoint_mask, fma_combines, fold_binop_into_select,
sub_add_reg, select_to_minmax, redundant_binop_in_equality,
fsub_to_fneg, commute_constant_to_rhs, match_ands, match_ors,
fsub_to_fneg, commute_constant_to_rhs, match_ands, match_ors,
combine_concat_vector]>;

// A combine group used for prelegalizer combiners at -O0. The combines in
Expand Down
102 changes: 101 additions & 1 deletion llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1490,7 +1490,7 @@ void CombinerHelper::applyOptBrCondByInvertingCond(MachineInstr &MI,
Observer.changedInstr(*BrCond);
}


bool CombinerHelper::tryEmitMemcpyInline(MachineInstr &MI) {
MachineIRBuilder HelperBuilder(MI);
GISelObserverWrapper DummyObserver;
Expand Down Expand Up @@ -5286,6 +5286,106 @@ MachineInstr *CombinerHelper::buildSDivUsingMul(MachineInstr &MI) {
return MIB.buildMul(Ty, Res, Factor);
}

/// Match a G_SDIV whose divisor is a scalar constant that is a power of two
/// or the negation of a power of two.
///
/// Divisions carrying the IsExact flag are rejected by this matcher.
///
/// \returns true if \p MI can be rewritten by applySDivByPow2.
bool CombinerHelper::matchSDivByPow2(MachineInstr &MI) {
  assert(MI.getOpcode() == TargetOpcode::G_SDIV && "Expected SDIV");
  if (MI.getFlag(MachineInstr::MIFlag::IsExact))
    return false;
  auto &SDiv = cast<GenericMachineInstr>(MI);
  Register RHS = SDiv.getReg(2);
  // applySDivByPow2 asserts that the divisor resolves to a single integer
  // constant, so only accept scalar divisors here; a vector of constants would
  // otherwise be matched and then trip that assertion.
  if (!MRI.getType(RHS).isScalar())
    return false;
  auto MatchPow2 = [](const Constant *C) {
    if (auto *CI = dyn_cast<ConstantInt>(C))
      return CI->getValue().isPowerOf2() || CI->getValue().isNegatedPowerOf2();
    return false;
  };
  return matchUnaryPredicate(MRI, RHS, MatchPow2, /* AllowUndefs */ false);
}

/// Rewrite a G_SDIV by a constant power of two (or negated power of two) into
/// shifts and adds. \p MI must have been accepted by matchSDivByPow2.
///
/// For a divisor of magnitude 2^k the emitted sequence is:
///   sign = ashr LHS, bitwidth - 1     ; all-ones if LHS < 0, otherwise 0
///   bias = lshr sign, bitwidth - k    ; 2^k - 1 if LHS < 0, otherwise 0
///   quot = ashr (LHS + bias), k       ; quotient rounded toward zero
/// and the quotient is negated when the divisor is negative. Divisors of 1
/// and -1 are handled as special cases.
void CombinerHelper::applySDivByPow2(MachineInstr &MI) {
  assert(MI.getOpcode() == TargetOpcode::G_SDIV && "Expected SDIV");
  auto &SDiv = cast<GenericMachineInstr>(MI);
  Register Dst = SDiv.getReg(0);
  Register LHS = SDiv.getReg(1);
  Register RHS = SDiv.getReg(2);
  LLT Ty = MRI.getType(Dst);
  LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(Ty);

  Builder.setInstrAndDebugLoc(MI);

  auto RHSC = getIConstantVRegValWithLookThrough(RHS, MRI);
  assert(RHSC.has_value() && "RHS must be a constant");
  auto RHSCV = RHSC->Value;

  // Special case: (sdiv X, 1) -> X
  // Handled before any instruction is built so no dead constant is left
  // behind.
  if (RHSCV.isOne()) {
    replaceSingleDefInstWithReg(MI, LHS);
    return;
  }

  auto Zero = Builder.buildConstant(Ty, 0);

  // Special case: (sdiv X, -1) -> 0-X
  if (RHSCV.isAllOnes()) {
    auto Sub = Builder.buildSub(Ty, Zero, LHS);
    replaceSingleDefInstWithReg(MI, Sub->getOperand(0).getReg());
    return;
  }

  unsigned Bitwidth = Ty.getScalarSizeInBits();
  unsigned TrailingZeros = RHSCV.countTrailingZeros();
  auto C1 = Builder.buildConstant(ShiftAmtTy, TrailingZeros);
  auto Inexact = Builder.buildConstant(ShiftAmtTy, Bitwidth - TrailingZeros);
  // Sign is all-ones when LHS is negative and zero otherwise.
  auto Sign = Builder.buildAShr(
      Ty, LHS, Builder.buildConstant(ShiftAmtTy, Bitwidth - 1));
  // Add (LHS < 0) ? abs2 - 1 : 0;
  // This must be a *logical shift right*: it keeps the low TrailingZeros bits
  // of Sign, i.e. abs2 - 1 for a negative LHS. A left shift would set the high
  // bits instead and produce a wrong bias.
  auto LSrl = Builder.buildLShr(Ty, Sign, Inexact);
  auto Add = Builder.buildAdd(Ty, LHS, LSrl);
  auto Sra = Builder.buildAShr(Ty, Add, C1);

  // If dividing by a positive value, we're done. Otherwise, the result must
  // be negated.
  auto Res = RHSCV.isNegative() ? Builder.buildSub(Ty, Zero, Sra) : Sra;
  replaceSingleDefInstWithReg(MI, Res->getOperand(0).getReg());
}

/// Match a G_UDIV whose divisor is a scalar constant power of two.
///
/// Divisions carrying the IsExact flag are rejected by this matcher.
///
/// \returns true if \p MI can be rewritten by applyUDivByPow2.
bool CombinerHelper::matchUDivByPow2(MachineInstr &MI) {
  assert(MI.getOpcode() == TargetOpcode::G_UDIV && "Expected UDIV");
  if (MI.getFlag(MachineInstr::MIFlag::IsExact))
    return false;
  auto &UDiv = cast<GenericMachineInstr>(MI);
  Register RHS = UDiv.getReg(2);
  // applyUDivByPow2 asserts that the divisor resolves to a single integer
  // constant, so only accept scalar divisors here; a vector of constants would
  // otherwise be matched and then trip that assertion.
  if (!MRI.getType(RHS).isScalar())
    return false;
  auto MatchPow2 = [](const Constant *C) {
    if (auto *CI = dyn_cast<ConstantInt>(C))
      return CI->getValue().isPowerOf2();
    return false;
  };
  return matchUnaryPredicate(MRI, RHS, MatchPow2, /* AllowUndefs */ false);
}

/// Rewrite a G_UDIV by a constant power of two into a logical shift right by
/// the number of trailing zero bits of the divisor. \p MI must have been
/// accepted by matchUDivByPow2. A divisor of 1 simply forwards the dividend.
void CombinerHelper::applyUDivByPow2(MachineInstr &MI) {
  assert(MI.getOpcode() == TargetOpcode::G_UDIV && "Expected UDIV");
  auto &UDiv = cast<GenericMachineInstr>(MI);
  Register Dst = UDiv.getReg(0);
  Register LHS = UDiv.getReg(1);
  Register RHS = UDiv.getReg(2);
  LLT Ty = MRI.getType(Dst);
  LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(Ty);

  Builder.setInstrAndDebugLoc(MI);

  auto RHSC = getIConstantVRegValWithLookThrough(RHS, MRI);
  assert(RHSC.has_value() && "RHS must be a constant");
  auto RHSCV = RHSC->Value;

  // Special case: (udiv X, 1) -> X
  if (RHSCV.isOne()) {
    replaceSingleDefInstWithReg(MI, LHS);
    return;
  }

  // (udiv X, 2^k) -> (lshr X, k)
  unsigned TrailingZeros = RHSCV.countTrailingZeros();
  auto C1 = Builder.buildConstant(ShiftAmtTy, TrailingZeros);
  auto Res = Builder.buildLShr(Ty, LHS, C1);
  replaceSingleDefInstWithReg(MI, Res->getOperand(0).getReg());
}

bool CombinerHelper::matchUMulHToLShr(MachineInstr &MI) {
assert(MI.getOpcode() == TargetOpcode::G_UMULH);
Register RHS = MI.getOperand(2).getReg();
Expand Down
Loading

0 comments on commit 03cfa65

Please sign in to comment.