Skip to content

Commit 299d690

Browse files
committed
[RISCV] Custom-legalise 32-bit variable shifts on RV64
The previous DAG combiner-based approach had an issue with infinite loops between the target-dependent and target-independent combiner logic (see PR40333). Although this was worked around in rL351806, the combiner-based approach is still potentially brittle and can fail to select the 32-bit shift variant when profitable to do so, as demonstrated in the pr40333.ll test case. This patch instead introduces target-specific SelectionDAG nodes for SHLW/SRLW/SRAW and custom-lowers variable i32 shifts to them. pr40333.ll is a good example of how this approach can improve codegen. This adds DAG combine that does SimplifyDemandedBits on the operands (only lower 32-bits of first operand and lower 5 bits of second operand are read). This seems better than implementing SimplifyDemandedBitsForTargetNode as there is no guarantee that would be called (and it's not for e.g. the anyext return test cases). Also implements ComputeNumSignBitsForTargetNode. There are codegen changes in atomic-rmw.ll and atomic-cmpxchg.ll but the new instruction sequences are semantically equivalent. Differential Revision: https://reviews.llvm.org/D57085 llvm-svn: 352169
1 parent 3b9a82f commit 299d690

File tree

6 files changed

+366
-335
lines changed

6 files changed

+366
-335
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

+86-32
Original file line numberDiff line numberDiff line change
@@ -80,10 +80,10 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
8080
setOperationAction(ISD::SIGN_EXTEND_INREG, VT, Expand);
8181

8282
if (Subtarget.is64Bit()) {
83-
setTargetDAGCombine(ISD::SHL);
84-
setTargetDAGCombine(ISD::SRL);
85-
setTargetDAGCombine(ISD::SRA);
8683
setTargetDAGCombine(ISD::ANY_EXTEND);
84+
setOperationAction(ISD::SHL, MVT::i32, Custom);
85+
setOperationAction(ISD::SRA, MVT::i32, Custom);
86+
setOperationAction(ISD::SRL, MVT::i32, Custom);
8787
}
8888

8989
if (!Subtarget.hasStdExtM()) {
@@ -512,15 +512,52 @@ SDValue RISCVTargetLowering::lowerRETURNADDR(SDValue Op,
512512
return DAG.getCopyFromReg(DAG.getEntryNode(), DL, Reg, XLenVT);
513513
}
514514

515-
// Return true if the given node is a shift with a non-constant shift amount.
516-
static bool isVariableShift(SDValue Val) {
517-
switch (Val.getOpcode()) {
515+
// Returns the opcode of the target-specific SDNode that implements the 32-bit
516+
// form of the given Opcode.
517+
static RISCVISD::NodeType getRISCVWOpcode(unsigned Opcode) {
518+
switch (Opcode) {
518519
default:
519-
return false;
520+
llvm_unreachable("Unexpected opcode");
520521
case ISD::SHL:
522+
return RISCVISD::SLLW;
521523
case ISD::SRA:
524+
return RISCVISD::SRAW;
522525
case ISD::SRL:
523-
return Val.getOperand(1).getOpcode() != ISD::Constant;
526+
return RISCVISD::SRLW;
527+
}
528+
}
529+
530+
// Converts the given 32-bit operation to a target-specific SelectionDAG node.
531+
// Because i32 isn't a legal type for RV64, these operations would otherwise
532+
// be promoted to i64, making it difficult to select the SLLW/DIVUW/.../*W
533+
// later one because the fact the operation was originally of type i32 is
534+
// lost.
535+
static SDValue customLegalizeToWOp(SDNode *N, SelectionDAG &DAG) {
536+
SDLoc DL(N);
537+
RISCVISD::NodeType WOpcode = getRISCVWOpcode(N->getOpcode());
538+
SDValue NewOp0 = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i64, N->getOperand(0));
539+
SDValue NewOp1 = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i64, N->getOperand(1));
540+
SDValue NewRes = DAG.getNode(WOpcode, DL, MVT::i64, NewOp0, NewOp1);
541+
// ReplaceNodeResults requires we maintain the same type for the return value.
542+
return DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, NewRes);
543+
}
544+
545+
void RISCVTargetLowering::ReplaceNodeResults(SDNode *N,
546+
SmallVectorImpl<SDValue> &Results,
547+
SelectionDAG &DAG) const {
548+
SDLoc DL(N);
549+
switch (N->getOpcode()) {
550+
default:
551+
llvm_unreachable("Don't know how to custom type legalize this operation!");
552+
case ISD::SHL:
553+
case ISD::SRA:
554+
case ISD::SRL:
555+
assert(N->getValueType(0) == MVT::i32 && Subtarget.is64Bit() &&
556+
"Unexpected custom legalisation");
557+
if (N->getOperand(1).getOpcode() == ISD::Constant)
558+
return;
559+
Results.push_back(customLegalizeToWOp(N, DAG));
560+
break;
524561
}
525562
}
526563

@@ -545,34 +582,14 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
545582
switch (N->getOpcode()) {
546583
default:
547584
break;
548-
case ISD::SHL:
549-
case ISD::SRL:
550-
case ISD::SRA: {
551-
assert(Subtarget.getXLen() == 64 && "Combine should be 64-bit only");
552-
if (!DCI.isBeforeLegalize())
553-
break;
554-
SDValue RHS = N->getOperand(1);
555-
if (N->getValueType(0) != MVT::i32 || RHS->getOpcode() == ISD::Constant ||
556-
(RHS->getOpcode() == ISD::AssertZext &&
557-
cast<VTSDNode>(RHS->getOperand(1))->getVT().getSizeInBits() <= 5))
558-
break;
559-
SDValue LHS = N->getOperand(0);
560-
SDLoc DL(N);
561-
SDValue NewRHS =
562-
DAG.getNode(ISD::AssertZext, DL, RHS.getValueType(), RHS,
563-
DAG.getValueType(EVT::getIntegerVT(*DAG.getContext(), 5)));
564-
return DCI.CombineTo(
565-
N, DAG.getNode(N->getOpcode(), DL, LHS.getValueType(), LHS, NewRHS));
566-
}
567585
case ISD::ANY_EXTEND: {
568-
// If any-extending an i32 variable-length shift or sdiv/udiv/urem to i64,
569-
// then instead sign-extend in order to increase the chance of being able
570-
// to select the sllw/srlw/sraw/divw/divuw/remuw instructions.
586+
// If any-extending an i32 sdiv/udiv/urem to i64, then instead sign-extend
587+
// in order to increase the chance of being able to select the
588+
// divw/divuw/remuw instructions.
571589
SDValue Src = N->getOperand(0);
572590
if (N->getValueType(0) != MVT::i64 || Src.getValueType() != MVT::i32)
573591
break;
574-
if (!isVariableShift(Src) &&
575-
!(Subtarget.hasStdExtM() && isVariableSDivUDivURem(Src)))
592+
if (!(Subtarget.hasStdExtM() && isVariableSDivUDivURem(Src)))
576593
break;
577594
SDLoc DL(N);
578595
// Don't add the new node to the DAGCombiner worklist, in order to avoid
@@ -589,11 +606,42 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
589606
break;
590607
return DCI.CombineTo(N, Op0.getOperand(0), Op0.getOperand(1));
591608
}
609+
case RISCVISD::SLLW:
610+
case RISCVISD::SRAW:
611+
case RISCVISD::SRLW: {
612+
// Only the lower 32 bits of LHS and lower 5 bits of RHS are read.
613+
SDValue LHS = N->getOperand(0);
614+
SDValue RHS = N->getOperand(1);
615+
APInt LHSMask = APInt::getLowBitsSet(LHS.getValueSizeInBits(), 32);
616+
APInt RHSMask = APInt::getLowBitsSet(RHS.getValueSizeInBits(), 5);
617+
if ((SimplifyDemandedBits(N->getOperand(0), LHSMask, DCI)) ||
618+
(SimplifyDemandedBits(N->getOperand(1), RHSMask, DCI)))
619+
return SDValue();
620+
break;
621+
}
592622
}
593623

594624
return SDValue();
595625
}
596626

627+
unsigned RISCVTargetLowering::ComputeNumSignBitsForTargetNode(
628+
SDValue Op, const APInt &DemandedElts, const SelectionDAG &DAG,
629+
unsigned Depth) const {
630+
switch (Op.getOpcode()) {
631+
default:
632+
break;
633+
case RISCVISD::SLLW:
634+
case RISCVISD::SRAW:
635+
case RISCVISD::SRLW:
636+
// TODO: As the result is sign-extended, this is conservatively correct. A
637+
// more precise answer could be calculated for SRAW depending on known
638+
// bits in the shift amount.
639+
return 33;
640+
}
641+
642+
return 1;
643+
}
644+
597645
static MachineBasicBlock *emitSplitF64Pseudo(MachineInstr &MI,
598646
MachineBasicBlock *BB) {
599647
assert(MI.getOpcode() == RISCV::SplitF64Pseudo && "Unexpected instruction");
@@ -1682,6 +1730,12 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
16821730
return "RISCVISD::SplitF64";
16831731
case RISCVISD::TAIL:
16841732
return "RISCVISD::TAIL";
1733+
case RISCVISD::SLLW:
1734+
return "RISCVISD::SLLW";
1735+
case RISCVISD::SRAW:
1736+
return "RISCVISD::SRAW";
1737+
case RISCVISD::SRLW:
1738+
return "RISCVISD::SRLW";
16851739
}
16861740
return nullptr;
16871741
}

llvm/lib/Target/RISCV/RISCVISelLowering.h

+13-1
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,12 @@ enum NodeType : unsigned {
3131
SELECT_CC,
3232
BuildPairF64,
3333
SplitF64,
34-
TAIL
34+
TAIL,
35+
// RV64I shifts, directly matching the semantics of the named RISC-V
36+
// instructions.
37+
SLLW,
38+
SRAW,
39+
SRLW
3540
};
3641
}
3742

@@ -57,9 +62,16 @@ class RISCVTargetLowering : public TargetLowering {
5762

5863
// Provide custom lowering hooks for some operations.
5964
SDValue LowerOperation(SDValue Op, SelectionDAG &DAG) const override;
65+
void ReplaceNodeResults(SDNode *N, SmallVectorImpl<SDValue> &Results,
66+
SelectionDAG &DAG) const override;
6067

6168
SDValue PerformDAGCombine(SDNode *N, DAGCombinerInfo &DCI) const override;
6269

70+
unsigned ComputeNumSignBitsForTargetNode(SDValue Op,
71+
const APInt &DemandedElts,
72+
const SelectionDAG &DAG,
73+
unsigned Depth) const override;
74+
6375
// This method returns the name of a target specific DAG node.
6476
const char *getTargetNodeName(unsigned Opcode) const override;
6577

llvm/lib/Target/RISCV/RISCVInstrInfo.td

+6-34
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,9 @@ def riscv_selectcc : SDNode<"RISCVISD::SELECT_CC", SDT_RISCVSelectCC,
5151
def riscv_tail : SDNode<"RISCVISD::TAIL", SDT_RISCVCall,
5252
[SDNPHasChain, SDNPOptInGlue, SDNPOutGlue,
5353
SDNPVariadic]>;
54+
def riscv_sllw : SDNode<"RISCVISD::SLLW", SDTIntShiftOp>;
55+
def riscv_sraw : SDNode<"RISCVISD::SRAW", SDTIntShiftOp>;
56+
def riscv_srlw : SDNode<"RISCVISD::SRLW", SDTIntShiftOp>;
5457

5558
//===----------------------------------------------------------------------===//
5659
// Operand and SDNode transformation definitions.
@@ -672,21 +675,9 @@ def sexti32 : PatFrags<(ops node:$src),
672675
def assertzexti32 : PatFrag<(ops node:$src), (assertzext node:$src), [{
673676
return cast<VTSDNode>(N->getOperand(1))->getVT() == MVT::i32;
674677
}]>;
675-
def assertzexti5 : PatFrag<(ops node:$src), (assertzext node:$src), [{
676-
return cast<VTSDNode>(N->getOperand(1))->getVT().getSizeInBits() <= 5;
677-
}]>;
678678
def zexti32 : PatFrags<(ops node:$src),
679679
[(and node:$src, 0xffffffff),
680680
(assertzexti32 node:$src)]>;
681-
// Defines a legal mask for (assertzexti5 (and src, mask)) to be combinable
682-
// with a shiftw operation. The mask mustn't modify the lower 5 bits or the
683-
// upper 32 bits.
684-
def shiftwamt_mask : ImmLeaf<XLenVT, [{
685-
return countTrailingOnes<uint64_t>(Imm) >= 5 && isUInt<32>(Imm);
686-
}]>;
687-
def shiftwamt : PatFrags<(ops node:$src),
688-
[(assertzexti5 (and node:$src, shiftwamt_mask)),
689-
(assertzexti5 node:$src)]>;
690681

691682
/// Immediates
692683

@@ -946,28 +937,9 @@ def : Pat<(sext_inreg (shl GPR:$rs1, uimm5:$shamt), i32),
946937
def : Pat<(sra (sext_inreg GPR:$rs1, i32), uimm5:$shamt),
947938
(SRAIW GPR:$rs1, uimm5:$shamt)>;
948939

949-
// For variable-length shifts, we rely on assertzexti5 being inserted during
950-
// lowering (see RISCVTargetLowering::PerformDAGCombine). This enables us to
951-
// guarantee that selecting a 32-bit variable shift is legal (as the variable
952-
// shift is known to be <= 32). We must also be careful not to create
953-
// semantically incorrect patterns. For instance, selecting SRLW for
954-
// (srl (zexti32 GPR:$rs1), (shiftwamt GPR:$rs2)),
955-
// is not guaranteed to be safe, as we don't know whether the upper 32-bits of
956-
// the result are used or not (in the case where rs2=0, this is a
957-
// sign-extension operation).
958-
959-
def : Pat<(sext_inreg (shl GPR:$rs1, (shiftwamt GPR:$rs2)), i32),
960-
(SLLW GPR:$rs1, GPR:$rs2)>;
961-
def : Pat<(zexti32 (shl GPR:$rs1, (shiftwamt GPR:$rs2))),
962-
(SRLI (SLLI (SLLW GPR:$rs1, GPR:$rs2), 32), 32)>;
963-
964-
def : Pat<(sext_inreg (srl (zexti32 GPR:$rs1), (shiftwamt GPR:$rs2)), i32),
965-
(SRLW GPR:$rs1, GPR:$rs2)>;
966-
def : Pat<(zexti32 (srl (zexti32 GPR:$rs1), (shiftwamt GPR:$rs2))),
967-
(SRLI (SLLI (SRLW GPR:$rs1, GPR:$rs2), 32), 32)>;
968-
969-
def : Pat<(sra (sexti32 GPR:$rs1), (shiftwamt GPR:$rs2)),
970-
(SRAW GPR:$rs1, GPR:$rs2)>;
940+
def : PatGprGpr<riscv_sllw, SLLW>;
941+
def : PatGprGpr<riscv_srlw, SRLW>;
942+
def : PatGprGpr<riscv_sraw, SRAW>;
971943

972944
/// Loads
973945

0 commit comments

Comments
 (0)