[RISCV][GISel] Add ISel support for SHXADD from Zba extension #67863

Merged: 9 commits, Oct 18, 2023
109 changes: 109 additions & 0 deletions llvm/lib/Target/RISCV/GISel/RISCVInstructionSelector.cpp
@@ -18,6 +18,7 @@
#include "llvm/CodeGen/GlobalISel/GIMatchTableExecutorImpl.h"
#include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
#include "llvm/CodeGen/GlobalISel/InstructionSelector.h"
#include "llvm/CodeGen/GlobalISel/MIPatternMatch.h"
#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
#include "llvm/IR/IntrinsicsRISCV.h"
#include "llvm/Support/Debug.h"
@@ -68,6 +69,12 @@ class RISCVInstructionSelector : public InstructionSelector {
ComplexRendererFns selectShiftMask(MachineOperand &Root) const;
ComplexRendererFns selectAddrRegImm(MachineOperand &Root) const;

ComplexRendererFns selectSHXADDOp(MachineOperand &Root, unsigned ShAmt) const;
template <unsigned ShAmt>
ComplexRendererFns selectSHXADDOp(MachineOperand &Root) const {
return selectSHXADDOp(Root, ShAmt);
}

// Custom renderers for tablegen
void renderNegImm(MachineInstrBuilder &MIB, const MachineInstr &MI,
int OpIdx) const;
@@ -122,6 +129,108 @@ RISCVInstructionSelector::selectShiftMask(MachineOperand &Root) const {
return {{[=](MachineInstrBuilder &MIB) { MIB.add(Root); }}};
}

InstructionSelector::ComplexRendererFns
RISCVInstructionSelector::selectSHXADDOp(MachineOperand &Root,
unsigned ShAmt) const {
using namespace llvm::MIPatternMatch;
MachineFunction &MF = *Root.getParent()->getParent()->getParent();
MachineRegisterInfo &MRI = MF.getRegInfo();

if (!Root.isReg())
return std::nullopt;
Register RootReg = Root.getReg();

const unsigned XLen = STI.getXLen();
APInt Mask, C2;
Register RegY;
std::optional<bool> LeftShift;
// (and (shl y, c2), mask)
if (mi_match(RootReg, MRI,
m_GAnd(m_GShl(m_Reg(RegY), m_ICst(C2)), m_ICst(Mask))))
LeftShift = true;
// (and (lshr y, c2), mask)
else if (mi_match(RootReg, MRI,
m_GAnd(m_GLShr(m_Reg(RegY), m_ICst(C2)), m_ICst(Mask))))
LeftShift = false;

if (LeftShift.has_value()) {
if (*LeftShift)
Mask &= maskTrailingZeros<uint64_t>(C2.getLimitedValue());
else
Mask &= maskTrailingOnes<uint64_t>(XLen - C2.getLimitedValue());

if (Mask.isShiftedMask()) {
unsigned Leading = XLen - Mask.getActiveBits();
unsigned Trailing = Mask.countr_zero();
// Given (and (shl y, c2), mask) in which mask has no leading zeros and
// c3 trailing zeros, we can use an SRLI by c3 - c2 followed by a SHXADD.
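// For example, with ShAmt == 3: (and (shl y, 2), 0xFFFFFFFFFFFFFFF8) is
// selected as (SH3ADD (SRLI y, 1), ...), since Trailing - C2 == 3 - 2 == 1.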
if (*LeftShift && Leading == 0 && C2.ult(Trailing) && Trailing == ShAmt) {
Register DstReg =
MRI.createGenericVirtualRegister(MRI.getType(RootReg));
return {{[=](MachineInstrBuilder &MIB) {
MachineIRBuilder(*MIB.getInstr())
.buildInstr(RISCV::SRLI, {DstReg}, {RegY})
.addImm(Trailing - C2.getLimitedValue());
MIB.addReg(DstReg);
}}};
}

// Given (and (lshr y, c2), mask) in which mask has c2 leading zeros and
// c3 trailing zeros, we can use an SRLI by c2 + c3 followed by a SHXADD.
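// For example, with ShAmt == 3: (and (lshr y, 2), 0x3FFFFFFFFFFFFFF8) is
// selected as (SH3ADD (SRLI y, 5), ...), since Leading + Trailing == 2 + 3.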
if (!*LeftShift && Leading == C2 && Trailing == ShAmt) {
Register DstReg =
MRI.createGenericVirtualRegister(MRI.getType(RootReg));
return {{[=](MachineInstrBuilder &MIB) {
MachineIRBuilder(*MIB.getInstr())
.buildInstr(RISCV::SRLI, {DstReg}, {RegY})
.addImm(Leading + Trailing);
MIB.addReg(DstReg);
}}};
}
}
}

LeftShift.reset();

// (shl (and y, mask), c2)
if (mi_match(RootReg, MRI,
m_GShl(m_OneNonDBGUse(m_GAnd(m_Reg(RegY), m_ICst(Mask))),
m_ICst(C2))))
LeftShift = true;
// (lshr (and y, mask), c2)
else if (mi_match(RootReg, MRI,
m_GLShr(m_OneNonDBGUse(m_GAnd(m_Reg(RegY), m_ICst(Mask))),
m_ICst(C2))))
LeftShift = false;

if (LeftShift.has_value() && Mask.isShiftedMask()) {
unsigned Leading = XLen - Mask.getActiveBits();
unsigned Trailing = Mask.countr_zero();

// Given (shl (and y, mask), c2) in which mask has 32 leading zeros and
// c3 trailing zeros: if c2 + c3 == ShAmt, we can emit SRLIW + SHXADD.
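// For example, with ShAmt == 3: (shl (and y, 0xFFFFFFFC), 1) is selected as
// (SH3ADD (SRLIW y, 2), ...), since Trailing + C2 == 2 + 1 == 3.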
bool Cond = *LeftShift && Leading == 32 && Trailing > 0 &&
(Trailing + C2.getLimitedValue()) == ShAmt;
if (!Cond)
// Given (lshr (and y, mask), c2) in which mask has 32 leading zeros and
// c3 trailing zeros: if c3 - c2 == ShAmt, we can emit SRLIW + SHXADD.
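// For example, with ShAmt == 3: (lshr (and y, 0xFFFFFFF0), 1) is selected as
// (SH3ADD (SRLIW y, 4), ...), since Trailing - C2 == 4 - 1 == 3.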
Cond = !*LeftShift && Leading == 32 && C2.ult(Trailing) &&
(Trailing - C2.getLimitedValue()) == ShAmt;

if (Cond) {
Register DstReg = MRI.createGenericVirtualRegister(MRI.getType(RootReg));
return {{[=](MachineInstrBuilder &MIB) {
MachineIRBuilder(*MIB.getInstr())
.buildInstr(RISCV::SRLIW, {DstReg}, {RegY})
.addImm(Trailing);
MIB.addReg(DstReg);
}}};
}
}

return std::nullopt;
}

InstructionSelector::ComplexRendererFns
RISCVInstructionSelector::selectAddrRegImm(MachineOperand &Root) const {
// TODO: Need to get the immediate from a G_PTR_ADD. Should this be done in
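As a quick aside (not part of the patch), the bit-level identities behind the two SRLI + SHXADD rewrites above can be sanity-checked with a small standalone C++ program. The masks and shift amounts below are illustrative picks for the ShAmt == 3 cases described in the comments; only the standard library is used.

#include <cassert>
#include <cstdint>
#include <random>

int main() {
  std::mt19937_64 Rng(42);
  for (int I = 0; I < 10000; ++I) {
    uint64_t Y = Rng();
    // Case 1: (and (shl y, 2), mask) where mask has no leading zeros and 3
    // trailing zeros. The selector emits SRLI by Trailing - C2 == 1; SH3ADD
    // then shifts that operand left by 3 before adding.
    assert(((Y << 2) & 0xFFFFFFFFFFFFFFF8ULL) == ((Y >> 1) << 3));
    // Case 2: (and (lshr y, 2), mask) where mask has 2 leading and 3 trailing
    // zeros. The selector emits SRLI by Leading + Trailing == 5.
    assert(((Y >> 2) & 0x3FFFFFFFFFFFFFF8ULL) == ((Y >> 5) << 3));
  }
  return 0;
}

The SRLIW-based cases check out the same way once Y is first truncated to its low 32 bits, mirroring what SRLIW does before shifting.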
7 changes: 7 additions & 0 deletions llvm/lib/Target/RISCV/RISCVGISel.td
@@ -66,6 +66,13 @@ def ShiftMaskGI :
GIComplexOperandMatcher<s32, "selectShiftMask">,
GIComplexPatternEquiv<shiftMaskXLen>;

def gi_sh1add_op : GIComplexOperandMatcher<s32, "selectSHXADDOp<1>">,
GIComplexPatternEquiv<sh1add_op>;
def gi_sh2add_op : GIComplexOperandMatcher<s32, "selectSHXADDOp<2>">,
GIComplexPatternEquiv<sh2add_op>;
def gi_sh3add_op : GIComplexOperandMatcher<s32, "selectSHXADDOp<3>">,
GIComplexPatternEquiv<sh3add_op>;

// FIXME: Canonicalize (sub X, C) -> (add X, -C) earlier.
def : Pat<(XLenVT (sub GPR:$rs1, simm12Plus1:$imm)),
(ADDI GPR:$rs1, (NegImm simm12Plus1:$imm))>;
85 changes: 54 additions & 31 deletions llvm/lib/Target/RISCV/RISCVInstrInfoZb.td
@@ -231,11 +231,39 @@ def SimmShiftRightBy3XForm : SDNodeXForm<imm, [{
}]>;

// Pattern to exclude simm12 immediates from matching.
// Note: this will be removed once the GISel complex patterns for
// SHXADD_UW have landed.
def non_imm12 : PatLeaf<(XLenVT GPR:$a), [{
auto *C = dyn_cast<ConstantSDNode>(N);
return !C || !isInt<12>(C->getSExtValue());
}]>;

// GISel currently doesn't support PatFrag on leaf nodes, so `non_imm12`
// cannot be used directly there. To share patterns between the two ISels, we
// instead define PatFrags on the operators that take a `non_imm12` operand.
class binop_with_non_imm12<SDPatternOperator binop>
: PatFrag<(ops node:$x, node:$y), (binop node:$x, node:$y), [{
auto *C = dyn_cast<ConstantSDNode>(Operands[1]);
return !C || !isInt<12>(C->getSExtValue());
}]> {
let PredicateCodeUsesOperands = 1;
let GISelPredicateCode = [{
const MachineOperand &ImmOp = *Operands[1];
const MachineFunction &MF = *MI.getParent()->getParent();
const MachineRegisterInfo &MRI = MF.getRegInfo();

if (ImmOp.isReg() && ImmOp.getReg())
if (auto Val = getIConstantVRegValWithLookThrough(ImmOp.getReg(), MRI)) {
// We do NOT want immediates that fit in 12 bits.
return !isInt<12>(Val->Value.getSExtValue());
}

return true;
}];
}
def add_non_imm12 : binop_with_non_imm12<add>;
def or_is_add_non_imm12 : binop_with_non_imm12<or_is_add>;
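// For example, add_non_imm12 rejects (add GPR:$x, 16): a simm12 addend is
// better folded into an ADDI than materialized in a register for an SHXADD.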

def Shifted32OnesMask : PatLeaf<(imm), [{
uint64_t Imm = N->getZExtValue();
if (!isShiftedMask_64(Imm))
@@ -647,20 +675,17 @@ let Predicates = [HasStdExtZbb, IsRV64] in
def : Pat<(i64 (and GPR:$rs, 0xFFFF)), (ZEXT_H_RV64 GPR:$rs)>;

let Predicates = [HasStdExtZba] in {
def : Pat<(add (shl GPR:$rs1, (XLenVT 1)), non_imm12:$rs2),
(SH1ADD GPR:$rs1, GPR:$rs2)>;
def : Pat<(add (shl GPR:$rs1, (XLenVT 2)), non_imm12:$rs2),
(SH2ADD GPR:$rs1, GPR:$rs2)>;
def : Pat<(add (shl GPR:$rs1, (XLenVT 3)), non_imm12:$rs2),
(SH3ADD GPR:$rs1, GPR:$rs2)>;

// More complex cases use a ComplexPattern.
def : Pat<(add sh1add_op:$rs1, non_imm12:$rs2),
(SH1ADD sh1add_op:$rs1, GPR:$rs2)>;
def : Pat<(add sh2add_op:$rs1, non_imm12:$rs2),
(SH2ADD sh2add_op:$rs1, GPR:$rs2)>;
def : Pat<(add sh3add_op:$rs1, non_imm12:$rs2),
(SH3ADD sh3add_op:$rs1, GPR:$rs2)>;
foreach i = {1,2,3} in {
defvar shxadd = !cast<Instruction>("SH"#i#"ADD");
def : Pat<(XLenVT (add_non_imm12 (shl GPR:$rs1, (XLenVT i)), GPR:$rs2)),
(shxadd GPR:$rs1, GPR:$rs2)>;

defvar pat = !cast<ComplexPattern>("sh"#i#"add_op");
// More complex cases use a ComplexPattern.
def : Pat<(XLenVT (add_non_imm12 pat:$rs1, GPR:$rs2)),
(shxadd pat:$rs1, GPR:$rs2)>;
}

def : Pat<(add (mul_oneuse GPR:$rs1, (XLenVT 6)), GPR:$rs2),
(SH1ADD (SH1ADD GPR:$rs1, GPR:$rs1), GPR:$rs2)>;
@@ -730,26 +755,24 @@ def : Pat<(i64 (shl (and GPR:$rs1, 0xFFFFFFFF), uimm5:$shamt)),
def : Pat<(i64 (and GPR:$rs1, Shifted32OnesMask:$mask)),
(SLLI_UW (SRLI GPR:$rs1, Shifted32OnesMask:$mask),
Shifted32OnesMask:$mask)>;

def : Pat<(i64 (add (and GPR:$rs1, 0xFFFFFFFF), non_imm12:$rs2)),
def : Pat<(i64 (add_non_imm12 (and GPR:$rs1, 0xFFFFFFFF), GPR:$rs2)),
(ADD_UW GPR:$rs1, GPR:$rs2)>;
def : Pat<(i64 (and GPR:$rs, 0xFFFFFFFF)), (ADD_UW GPR:$rs, (XLenVT X0))>;

def : Pat<(i64 (or_is_add (and GPR:$rs1, 0xFFFFFFFF), non_imm12:$rs2)),
def : Pat<(i64 (or_is_add_non_imm12 (and GPR:$rs1, 0xFFFFFFFF), GPR:$rs2)),
(ADD_UW GPR:$rs1, GPR:$rs2)>;

def : Pat<(i64 (add (shl (and GPR:$rs1, 0xFFFFFFFF), (i64 1)), non_imm12:$rs2)),
(SH1ADD_UW GPR:$rs1, GPR:$rs2)>;
def : Pat<(i64 (add (shl (and GPR:$rs1, 0xFFFFFFFF), (i64 2)), non_imm12:$rs2)),
(SH2ADD_UW GPR:$rs1, GPR:$rs2)>;
def : Pat<(i64 (add (shl (and GPR:$rs1, 0xFFFFFFFF), (i64 3)), non_imm12:$rs2)),
(SH3ADD_UW GPR:$rs1, GPR:$rs2)>;
foreach i = {1,2,3} in {
defvar shxadd_uw = !cast<Instruction>("SH"#i#"ADD_UW");
def : Pat<(i64 (add_non_imm12 (shl (and GPR:$rs1, 0xFFFFFFFF), (i64 i)), (XLenVT GPR:$rs2))),
(shxadd_uw GPR:$rs1, GPR:$rs2)>;
}

def : Pat<(i64 (add (and (shl GPR:$rs1, (i64 1)), 0x1FFFFFFFF), non_imm12:$rs2)),
def : Pat<(i64 (add_non_imm12 (and (shl GPR:$rs1, (i64 1)), 0x1FFFFFFFF), (XLenVT GPR:$rs2))),
(SH1ADD_UW GPR:$rs1, GPR:$rs2)>;
def : Pat<(i64 (add (and (shl GPR:$rs1, (i64 2)), 0x3FFFFFFFF), non_imm12:$rs2)),
def : Pat<(i64 (add_non_imm12 (and (shl GPR:$rs1, (i64 2)), 0x3FFFFFFFF), (XLenVT GPR:$rs2))),
(SH2ADD_UW GPR:$rs1, GPR:$rs2)>;
def : Pat<(i64 (add (and (shl GPR:$rs1, (i64 3)), 0x7FFFFFFFF), non_imm12:$rs2)),
def : Pat<(i64 (add_non_imm12 (and (shl GPR:$rs1, (i64 3)), 0x7FFFFFFFF), (XLenVT GPR:$rs2))),
(SH3ADD_UW GPR:$rs1, GPR:$rs2)>;

// More complex cases use a ComplexPattern.
@@ -760,19 +783,19 @@ def : Pat<(i64 (add sh2add_uw_op:$rs1, non_imm12:$rs2)),
def : Pat<(i64 (add sh3add_uw_op:$rs1, non_imm12:$rs2)),
(SH3ADD_UW sh3add_uw_op:$rs1, GPR:$rs2)>;

def : Pat<(i64 (add (and GPR:$rs1, 0xFFFFFFFE), non_imm12:$rs2)),
def : Pat<(i64 (add_non_imm12 (and GPR:$rs1, 0xFFFFFFFE), (XLenVT GPR:$rs2))),
(SH1ADD (SRLIW GPR:$rs1, 1), GPR:$rs2)>;
def : Pat<(i64 (add (and GPR:$rs1, 0xFFFFFFFC), non_imm12:$rs2)),
def : Pat<(i64 (add_non_imm12 (and GPR:$rs1, 0xFFFFFFFC), (XLenVT GPR:$rs2))),
(SH2ADD (SRLIW GPR:$rs1, 2), GPR:$rs2)>;
def : Pat<(i64 (add (and GPR:$rs1, 0xFFFFFFF8), non_imm12:$rs2)),
def : Pat<(i64 (add_non_imm12 (and GPR:$rs1, 0xFFFFFFF8), (XLenVT GPR:$rs2))),
(SH3ADD (SRLIW GPR:$rs1, 3), GPR:$rs2)>;

// Use SRLI to clear the LSBs and SHXADD_UW to mask and shift.
def : Pat<(i64 (add (and GPR:$rs1, 0x1FFFFFFFE), non_imm12:$rs2)),
def : Pat<(i64 (add_non_imm12 (and GPR:$rs1, 0x1FFFFFFFE), (XLenVT GPR:$rs2))),
(SH1ADD_UW (SRLI GPR:$rs1, 1), GPR:$rs2)>;
def : Pat<(i64 (add (and GPR:$rs1, 0x3FFFFFFFC), non_imm12:$rs2)),
def : Pat<(i64 (add_non_imm12 (and GPR:$rs1, 0x3FFFFFFFC), (XLenVT GPR:$rs2))),
(SH2ADD_UW (SRLI GPR:$rs1, 2), GPR:$rs2)>;
def : Pat<(i64 (add (and GPR:$rs1, 0x7FFFFFFF8), non_imm12:$rs2)),
def : Pat<(i64 (add_non_imm12 (and GPR:$rs1, 0x7FFFFFFF8), (XLenVT GPR:$rs2))),
(SH3ADD_UW (SRLI GPR:$rs1, 3), GPR:$rs2)>;

def : Pat<(i64 (mul (and_oneuse GPR:$r, 0xFFFFFFFF), C3LeftShiftUW:$i)),