-
Notifications
You must be signed in to change notification settings - Fork 12.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[NVPTX] Remove NVPTX::IMAD
opcode, and rely on instruction selection only
#121724
Conversation
32f5cc5
to
66adc32
Compare
@llvm/pr-subscribers-backend-nvptx @llvm/pr-subscribers-llvm-selectiondag Author: None (peterbell10) Changes: I noticed that NVPTX will sometimes emit `mad.lo` to multiply by 1. This happens when DAGCombiner operates on the add before the mul, so the imad contraction happens regardless of whether the mul could have been simplified. To fix this, I add some combiner patterns for IMAD. In particular, this PR adds: `mad x 1 y => add x y`, `mad x -1 y => sub y x`, `mad x 0 y => y`, `mad x y 0 => mul x y`, and `mad c0 c1 z => add z (c0 * c1)`.
Another option might be to remove `NVPTXISD::IMAD` and only combine to mad during selection. I found testing this change to be quite tricky, as there is no mad intrinsic, so we have to write
For the Full diff: https://github.com/llvm/llvm-project/pull/121724.diff 7 Files Affected:
diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h
index ff7caec41855fd..3a015c8df2066a 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAG.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -2460,6 +2460,11 @@ class SelectionDAG {
SDNode *FindNodeOrInsertPos(const FoldingSetNodeID &ID, const SDLoc &DL,
void *&InsertPos);
+ SDValue getNodeImpl(unsigned Opcode, const SDLoc &DL, EVT VT,
+ ArrayRef<SDValue> Ops, SDNodeFlags Flags);
+ SDValue getNodeImpl(unsigned Opcode, const SDLoc &DL, SDVTList VTs,
+ ArrayRef<SDValue> Ops, SDNodeFlags Flags);
+
/// Maps to auto-CSE operations.
std::vector<CondCodeSDNode*> CondCodeNodes;
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 6b2501591c81a3..6d75809cdaf69f 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -153,6 +153,13 @@ static cl::opt<bool> EnableVectorFCopySignExtendRound(
"combiner-vector-fcopysign-extend-round", cl::Hidden, cl::init(false),
cl::desc(
"Enable merging extends and rounds into FCOPYSIGN on vector types"));
+
+static cl::opt<bool>
+ EnableGenericCombines("combiner-generic-combines", cl::Hidden,
+ cl::init(true),
+ cl::desc("Enable generic DAGCombine patterns. Useful "
+ "for testing target-specific combines."));
+
namespace {
class DAGCombiner {
@@ -251,7 +258,8 @@ namespace {
: DAG(D), TLI(D.getTargetLoweringInfo()),
STI(D.getSubtarget().getSelectionDAGInfo()), OptLevel(OL), AA(AA) {
ForCodeSize = DAG.shouldOptForSize();
- DisableGenericCombines = STI && STI->disableGenericCombines(OptLevel);
+ DisableGenericCombines = !EnableGenericCombines ||
+ (STI && STI->disableGenericCombines(OptLevel));
MaximumLegalStoreInBits = 0;
// We use the minimum store size here, since that's all we can guarantee
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 10e8ba93359fbd..6a3799e02edd94 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -115,6 +115,10 @@ static cl::opt<unsigned>
MaxSteps("has-predecessor-max-steps", cl::Hidden, cl::init(8192),
cl::desc("DAG combiner limit number of steps when searching DAG "
"for predecessor nodes"));
+static cl::opt<bool> EnableSimplifyNodes(
+ "selectiondag-simplify-nodes", cl::Hidden, cl::init(true),
+ cl::desc("Enable SelectionDAG::getNode simplifications. Useful for testing "
+ "DAG combines."));
static void NewSDValueDbgMsg(SDValue V, StringRef Msg, SelectionDAG *G) {
LLVM_DEBUG(dbgs() << Msg; V.getNode()->dump(G););
@@ -6157,23 +6161,46 @@ static SDValue foldCONCAT_VECTORS(const SDLoc &DL, EVT VT,
}
/// Gets or creates the specified node.
-SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT) {
+SDValue SelectionDAG::getNodeImpl(unsigned Opcode, const SDLoc &DL, EVT VT,
+ ArrayRef<SDValue> Ops,
+ const SDNodeFlags Flags) {
SDVTList VTs = getVTList(VT);
- FoldingSetNodeID ID;
- AddNodeIDNode(ID, Opcode, VTs, {});
- void *IP = nullptr;
- if (SDNode *E = FindNodeOrInsertPos(ID, DL, IP))
- return SDValue(E, 0);
+ return getNodeImpl(Opcode, DL, VTs, Ops, Flags);
+}
- auto *N = newSDNode<SDNode>(Opcode, DL.getIROrder(), DL.getDebugLoc(), VTs);
- CSEMap.InsertNode(N, IP);
+SDValue SelectionDAG::getNodeImpl(unsigned Opcode, const SDLoc &DL,
+ SDVTList VTs, ArrayRef<SDValue> Ops,
+ const SDNodeFlags Flags) { // Gets or creates the node exactly as requested; no folding or simplification happens in this body.
+ SDNode *N;
+ // Don't CSE glue-producing nodes (decided by the last result type in VTs).
+ if (VTs.VTs[VTs.NumVTs - 1] != MVT::Glue) {
+ FoldingSetNodeID ID;
+ AddNodeIDNode(ID, Opcode, VTs, Ops);
+ void *IP = nullptr;
+ if (SDNode *E = FindNodeOrInsertPos(ID, DL, IP)) { // A matching node was found: reuse it instead of creating a new one.
+ E->intersectFlagsWith(Flags); // The reused node's flags are intersected with the requested flags.
+ return SDValue(E, 0);
+ }
+
+ N = newSDNode<SDNode>(Opcode, DL.getIROrder(), DL.getDebugLoc(), VTs);
+ createOperands(N, Ops);
+ CSEMap.InsertNode(N, IP); // Memoize the new node at the insert position obtained from the lookup above.
+ } else {
+ N = newSDNode<SDNode>(Opcode, DL.getIROrder(), DL.getDebugLoc(), VTs); // Glue result: the node is created without consulting or updating CSEMap.
+ createOperands(N, Ops);
+ }
+ N->setFlags(Flags); // Flags are applied on both the memoized path and the glue path.
InsertNode(N);
SDValue V = SDValue(N, 0);
NewSDValueDbgMsg(V, "Creating new node: ", this);
return V;
}
+SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT) {
+ return getNodeImpl(Opcode, DL, VT, {}, SDNodeFlags{});
+}
+
SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
SDValue N1) {
SDNodeFlags Flags;
@@ -6185,6 +6212,8 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
SDValue N1, const SDNodeFlags Flags) {
assert(N1.getOpcode() != ISD::DELETED_NODE && "Operand is DELETED_NODE!");
+ if (!EnableSimplifyNodes)
+ return getNodeImpl(Opcode, DL, VT, {N1}, Flags);
// Constant fold unary operations with a vector integer or float operand.
switch (Opcode) {
@@ -6501,31 +6530,7 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
break;
}
- SDNode *N;
- SDVTList VTs = getVTList(VT);
- SDValue Ops[] = {N1};
- if (VT != MVT::Glue) { // Don't CSE glue producing nodes
- FoldingSetNodeID ID;
- AddNodeIDNode(ID, Opcode, VTs, Ops);
- void *IP = nullptr;
- if (SDNode *E = FindNodeOrInsertPos(ID, DL, IP)) {
- E->intersectFlagsWith(Flags);
- return SDValue(E, 0);
- }
-
- N = newSDNode<SDNode>(Opcode, DL.getIROrder(), DL.getDebugLoc(), VTs);
- N->setFlags(Flags);
- createOperands(N, Ops);
- CSEMap.InsertNode(N, IP);
- } else {
- N = newSDNode<SDNode>(Opcode, DL.getIROrder(), DL.getDebugLoc(), VTs);
- createOperands(N, Ops);
- }
-
- InsertNode(N);
- SDValue V = SDValue(N, 0);
- NewSDValueDbgMsg(V, "Creating new node: ", this);
- return V;
+ return getNodeImpl(Opcode, DL, VT, {N1}, Flags);
}
static std::optional<APInt> FoldValue(unsigned Opcode, const APInt &C1,
@@ -7219,6 +7224,8 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
assert(N1.getOpcode() != ISD::DELETED_NODE &&
N2.getOpcode() != ISD::DELETED_NODE &&
"Operand is DELETED_NODE!");
+ if (!EnableSimplifyNodes)
+ return getNodeImpl(Opcode, DL, VT, {N1, N2}, Flags);
canonicalizeCommutativeBinop(Opcode, N1, N2);
@@ -7665,32 +7672,7 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
}
}
- // Memoize this node if possible.
- SDNode *N;
- SDVTList VTs = getVTList(VT);
- SDValue Ops[] = {N1, N2};
- if (VT != MVT::Glue) {
- FoldingSetNodeID ID;
- AddNodeIDNode(ID, Opcode, VTs, Ops);
- void *IP = nullptr;
- if (SDNode *E = FindNodeOrInsertPos(ID, DL, IP)) {
- E->intersectFlagsWith(Flags);
- return SDValue(E, 0);
- }
-
- N = newSDNode<SDNode>(Opcode, DL.getIROrder(), DL.getDebugLoc(), VTs);
- N->setFlags(Flags);
- createOperands(N, Ops);
- CSEMap.InsertNode(N, IP);
- } else {
- N = newSDNode<SDNode>(Opcode, DL.getIROrder(), DL.getDebugLoc(), VTs);
- createOperands(N, Ops);
- }
-
- InsertNode(N);
- SDValue V = SDValue(N, 0);
- NewSDValueDbgMsg(V, "Creating new node: ", this);
- return V;
+ return getNodeImpl(Opcode, DL, VT, {N1, N2}, Flags);
}
SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
@@ -7708,6 +7690,9 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
N2.getOpcode() != ISD::DELETED_NODE &&
N3.getOpcode() != ISD::DELETED_NODE &&
"Operand is DELETED_NODE!");
+ if (!EnableSimplifyNodes)
+ return getNodeImpl(Opcode, DL, VT, {N1, N2, N3}, Flags);
+
// Perform various simplifications.
switch (Opcode) {
case ISD::FMA:
@@ -7862,33 +7847,7 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
break;
}
}
-
- // Memoize node if it doesn't produce a glue result.
- SDNode *N;
- SDVTList VTs = getVTList(VT);
- SDValue Ops[] = {N1, N2, N3};
- if (VT != MVT::Glue) {
- FoldingSetNodeID ID;
- AddNodeIDNode(ID, Opcode, VTs, Ops);
- void *IP = nullptr;
- if (SDNode *E = FindNodeOrInsertPos(ID, DL, IP)) {
- E->intersectFlagsWith(Flags);
- return SDValue(E, 0);
- }
-
- N = newSDNode<SDNode>(Opcode, DL.getIROrder(), DL.getDebugLoc(), VTs);
- N->setFlags(Flags);
- createOperands(N, Ops);
- CSEMap.InsertNode(N, IP);
- } else {
- N = newSDNode<SDNode>(Opcode, DL.getIROrder(), DL.getDebugLoc(), VTs);
- createOperands(N, Ops);
- }
-
- InsertNode(N);
- SDValue V = SDValue(N, 0);
- NewSDValueDbgMsg(V, "Creating new node: ", this);
- return V;
+ return getNodeImpl(Opcode, DL, VT, {N1, N2, N3}, Flags);
}
SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
@@ -10343,6 +10302,8 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
assert(Op.getOpcode() != ISD::DELETED_NODE &&
"Operand is DELETED_NODE!");
#endif
+ if (!EnableSimplifyNodes)
+ return getNodeImpl(Opcode, DL, VT, Ops, Flags);
switch (Opcode) {
default: break;
@@ -10411,34 +10372,7 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
break;
}
- // Memoize nodes.
- SDNode *N;
- SDVTList VTs = getVTList(VT);
-
- if (VT != MVT::Glue) {
- FoldingSetNodeID ID;
- AddNodeIDNode(ID, Opcode, VTs, Ops);
- void *IP = nullptr;
-
- if (SDNode *E = FindNodeOrInsertPos(ID, DL, IP)) {
- E->intersectFlagsWith(Flags);
- return SDValue(E, 0);
- }
-
- N = newSDNode<SDNode>(Opcode, DL.getIROrder(), DL.getDebugLoc(), VTs);
- createOperands(N, Ops);
-
- CSEMap.InsertNode(N, IP);
- } else {
- N = newSDNode<SDNode>(Opcode, DL.getIROrder(), DL.getDebugLoc(), VTs);
- createOperands(N, Ops);
- }
-
- N->setFlags(Flags);
- InsertNode(N);
- SDValue V(N, 0);
- NewSDValueDbgMsg(V, "Creating new node: ", this);
- return V;
+ return getNodeImpl(Opcode, DL, VT, Ops, Flags);
}
SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL,
@@ -10458,6 +10392,8 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, SDVTList VTList,
ArrayRef<SDValue> Ops, const SDNodeFlags Flags) {
if (VTList.NumVTs == 1)
return getNode(Opcode, DL, VTList.VTs[0], Ops, Flags);
+ if (!EnableSimplifyNodes)
+ return getNodeImpl(Opcode, DL, VTList, Ops, Flags);
#ifndef NDEBUG
for (const auto &Op : Ops)
@@ -10622,30 +10558,7 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, SDVTList VTList,
#endif
}
- // Memoize the node unless it returns a glue result.
- SDNode *N;
- if (VTList.VTs[VTList.NumVTs-1] != MVT::Glue) {
- FoldingSetNodeID ID;
- AddNodeIDNode(ID, Opcode, VTList, Ops);
- void *IP = nullptr;
- if (SDNode *E = FindNodeOrInsertPos(ID, DL, IP)) {
- E->intersectFlagsWith(Flags);
- return SDValue(E, 0);
- }
-
- N = newSDNode<SDNode>(Opcode, DL.getIROrder(), DL.getDebugLoc(), VTList);
- createOperands(N, Ops);
- CSEMap.InsertNode(N, IP);
- } else {
- N = newSDNode<SDNode>(Opcode, DL.getIROrder(), DL.getDebugLoc(), VTList);
- createOperands(N, Ops);
- }
-
- N->setFlags(Flags);
- InsertNode(N);
- SDValue V(N, 0);
- NewSDValueDbgMsg(V, "Creating new node: ", this);
- return V;
+ return getNodeImpl(Opcode, DL, VTList, Ops, Flags);
}
SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL,
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 5c1f717694a4c7..c4529c9151bc2b 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -5164,6 +5164,53 @@ PerformBUILD_VECTORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
return DAG.getNode(ISD::BITCAST, DL, VT, PRMT);
}
+/// Try to simplify an NVPTXISD::IMAD node, treating \p N0 and \p N1 as the
+/// multiply operands (in that order) and \p N2 as the addend.
+///
+/// Only constants in the \p N1 and \p N2 positions are inspected, so the
+/// caller is expected to retry with \p N0 and \p N1 swapped to cover the
+/// commuted forms.
+///
+/// \returns the replacement value, or an empty SDValue if no pattern matched.
+static SDValue
+PerformIMADCombineWithOperands(SDNode *N, SDValue N0, SDValue N1, SDValue N2,
+ TargetLowering::DAGCombinerInfo &DCI) {
+ ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
+ ConstantSDNode *N2C = dyn_cast<ConstantSDNode>(N2);
+ // Take the type from the IMAD node itself. `N0->getValueType(0)` goes
+ // through SDValue::operator-> and therefore reports result #0 of the node
+ // defining N0, ignoring which of its results N0 actually refers to.
+ EVT VT = N->getValueType(0);
+ SDLoc DL(N);
+ SDNodeFlags Flags = N->getFlags();
+
+ // mad x 1 y => add x y
+ if (N1C && N1C->isOne())
+ return DCI.DAG.getNode(ISD::ADD, DL, VT, N0, N2, Flags);
+
+ // mad x -1 y => sub y x
+ if (N1C && N1C->isAllOnes()) {
+ // The no-unsigned-wrap flag is cleared before the flags are reused on the
+ // subtraction.
+ Flags.setNoUnsignedWrap(false);
+ return DCI.DAG.getNode(ISD::SUB, DL, VT, N2, N0, Flags);
+ }
+
+ // mad x 0 y => y
+ if (N1C && N1C->isZero())
+ return N2;
+
+ // mad x y 0 => mul x y
+ if (N2C && N2C->isZero())
+ return DCI.DAG.getNode(ISD::MUL, DL, VT, N0, N1, Flags);
+
+ // mad c0 c1 x => add x (c0*c1)
+ // FoldConstantArithmetic yields an empty SDValue when the product cannot be
+ // folded, in which case we fall through to "no match".
+ if (SDValue C =
+ DCI.DAG.FoldConstantArithmetic(ISD::MUL, DL, VT, {N0, N1}, Flags))
+ return DCI.DAG.getNode(ISD::ADD, DL, VT, N2, C, Flags);
+
+ return {};
+}
+
+/// DAG combine entry point for NVPTXISD::IMAD nodes.
+///
+/// Operands 0 and 1 of \p N are the multiply operands and operand 2 is the
+/// addend. The simplification helper is tried with the multiply operands in
+/// their original order first and, if nothing matched, once more with them
+/// swapped.
+///
+/// \returns the simplified value, or an empty SDValue if neither order matched.
+static SDValue PerformIMADCombine(SDNode *N,
+ TargetLowering::DAGCombinerInfo &DCI) {
+ const SDValue MulLHS = N->getOperand(0);
+ const SDValue MulRHS = N->getOperand(1);
+ const SDValue Addend = N->getOperand(2);
+
+ // First attempt: multiply operands as they appear on the node.
+ if (SDValue Simplified =
+ PerformIMADCombineWithOperands(N, MulLHS, MulRHS, Addend, DCI))
+ return Simplified;
+
+ // Second attempt: same addend, multiply operands exchanged.
+ return PerformIMADCombineWithOperands(N, MulRHS, MulLHS, Addend, DCI);
+}
+
SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
DAGCombinerInfo &DCI) const {
CodeGenOptLevel OptLevel = getTargetMachine().getOptLevel();
@@ -5198,6 +5245,8 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
return PerformVSELECTCombine(N, DCI);
case ISD::BUILD_VECTOR:
return PerformBUILD_VECTORCombine(N, DCI);
+ case NVPTXISD::IMAD:
+ return PerformIMADCombine(N, DCI);
}
return SDValue();
}
diff --git a/llvm/test/CodeGen/NVPTX/combine-mad-only.ll b/llvm/test/CodeGen/NVPTX/combine-mad-only.ll
new file mode 100644
index 00000000000000..fb4bcc39b5a64d
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/combine-mad-only.ll
@@ -0,0 +1,87 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -mtriple=nvptx -mcpu=sm_20 -O1 --selectiondag-simplify-nodes=false --combiner-generic-combines=false --debug-counter=early-cse=100, | FileCheck %s
+; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_20 -O1 --selectiondag-simplify-nodes=false --combiner-generic-combines=false --debug-counter=early-cse=100, | FileCheck %s
+; RUN: %if ptxas && !ptxas-12.0 %{ llc < %s -mtriple=nvptx -mcpu=sm_20 -O1 --selectiondag-simplify-nodes=false --combiner-generic-combines=false --debug-counter=early-cse=100, | %ptxas-verify %}
+; RUN: %if ptxas %{ llc < %s -mtriple=nvptx64 -mcpu=sm_20 -O1 --selectiondag-simplify-nodes=false --combiner-generic-combines=false --debug-counter=early-cse=100, | %ptxas-verify %}
+
+;; mad x 1 y => add y x
+define i32 @test_mad_mul_1(i32 %x, i32 %y) {
+; CHECK-LABEL: test_mad_mul_1(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<4>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.u32 %r1, [test_mad_mul_1_param_0];
+; CHECK-NEXT: ld.param.u32 %r2, [test_mad_mul_1_param_1];
+; CHECK-NEXT: add.s32 %r3, %r1, %r2;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r3;
+; CHECK-NEXT: ret;
+ %mul = mul i32 %x, 1
+ %add = add i32 %mul, %y
+ ret i32 %add
+}
+
+;; mad x -1 y => sub y x
+define i32 @test_mad_mul_neg_1(i32 %x, i32 %y) {
+; CHECK-LABEL: test_mad_mul_neg_1(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<4>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.u32 %r1, [test_mad_mul_neg_1_param_0];
+; CHECK-NEXT: ld.param.u32 %r2, [test_mad_mul_neg_1_param_1];
+; CHECK-NEXT: sub.s32 %r3, %r2, %r1;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r3;
+; CHECK-NEXT: ret;
+ %mul = mul i32 %x, -1
+ %add = add i32 %mul, %y
+ ret i32 %add
+}
+
+;; mad x 0 y => y
+define i32 @test_mad_mul_0(i32 %x, i32 %y) {
+; CHECK-LABEL: test_mad_mul_0(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.u32 %r1, [test_mad_mul_0_param_1];
+; CHECK-NEXT: st.param.b32 [func_retval0], %r1;
+; CHECK-NEXT: ret;
+ %mul = mul i32 %x, 0
+ %add = add i32 %mul, %y
+ ret i32 %add
+}
+
+;; mad x y 0 => mul x y
+define i32 @test_mad_add_0(i32 %x, i32 %y) {
+; CHECK-LABEL: test_mad_add_0(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<4>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.u32 %r1, [test_mad_add_0_param_0];
+; CHECK-NEXT: ld.param.u32 %r2, [test_mad_add_0_param_1];
+; CHECK-NEXT: mul.lo.s32 %r3, %r1, %r2;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r3;
+; CHECK-NEXT: ret;
+ %mul = mul i32 %x, %y
+ %add = add i32 %mul, 0
+ ret i32 %add
+}
+
+;; mad c0 c1 x => add x (c0*c1)
+define i32 @test_mad_fold_mul(i32 %x) {
+; CHECK-LABEL: test_mad_fold_mul(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.u32 %r1, [test_mad_fold_mul_param_0];
+; CHECK-NEXT: add.s32 %r2, %r1, 12;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r2;
+; CHECK-NEXT: ret;
+ %mul = mul i32 4, 3
+ %add = add i32 %mul, %x
+ ret i32 %add
+}
diff --git a/llvm/test/CodeGen/NVPTX/combine-mad.ll b/llvm/test/CodeGen/NVPTX/combine-mad.ll
index 1b22cfde39725f..7d523a835a1f3f 100644
--- a/llvm/test/CodeGen/NVPTX/combine-mad.ll
+++ b/llvm/test/CodeGen/NVPTX/combine-mad.ll
@@ -183,3 +183,23 @@ define i32 @test4_rev(i32 %a, i32 %b, i32 %c, i1 %p) {
%add = add i32 %c, %sel
ret i32 %add
}
+
+;; This case relies on mad x 1 y => add x y, previously we emit:
+;; mad.lo.s32 %r3, %r1, 1, %r2;
+define i32 @test_mad_fold(i32 %x) {
+; CHECK-LABEL: test_mad_fold(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<7>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.u32 %r1, [test_mad_fold_param_0];
+; CHECK-NEXT: mul.hi.s32 %r2, %r1, -2147221471;
+; CHECK-NEXT: add.s32 %r3, %r1, %r2;
+; CHECK-NEXT: shr.u32 %r4, %r3, 31;
+; CHECK-NEXT: shr.s32 %r5, %r3, 12;
+; CHECK-NEXT: add.s32 %r6, %r5, %r4;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r6;
+; CHECK-NEXT: ret;
+ %div = sdiv i32 %x, 8191
+ ret i32 %div
+}
diff --git a/llvm/test/CodeGen/NVPTX/dynamic_stackalloc.ll b/llvm/test/CodeGen/NVPTX/dynamic_stackalloc.ll
index 27a523b9dd91d2..de19d2983f3435 100644
--- a/llvm/test/CodeGen/NVPTX/dynamic_stackalloc.ll
+++ b/llvm/test/CodeGen/NVPTX/dynamic_stackalloc.ll
@@ -12,7 +12,7 @@
; CHECK-NOT: __local_depot
; CHECK-32: ld.param.u32 %r[[SIZE:[0-9]]], [test_dynamic_stackalloc_param_0];
-; CHECK-32-NEXT: mad.lo.s32 %r[[SIZE2:[0-9]]], %r[[SIZE]], 1, 7;
+; CHECK-32-NEXT: add.s32 %r[[SIZE2:[0-9]]], %r[[SIZE]], 7;
; CHECK-32-NEXT: and.b32 %r[[SIZE3:[0-9]]], %r[[SIZE2]], -8;
; CHECK-32-NEXT: alloca.u32 %r[[ALLOCA:[0-9]]], %r[[SIZE3]], 16;
; CHECK-32-NEXT: cvta.local.u32 %r[[ALLOCA]], %r[[ALLOCA]];
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Another option might be to remove NVPTXISD::IMAD and only combine to mad during selection. This would allow the default DAGCombiner patterns to simplify the graph without any NVPTX-specific intervention.
This would be better without a stronger justification for why an intermediate node is useful
However, it also risks DAGCombiner breaking up the mul-add patterns, which is why I haven't done it that way.
Should have a more concrete justification. Generally combines that would break the pattern are new instances of the pattern to handle (e.g. multiply-to-shift + add)
Perhaps the NVPTX maintainers know the history better here.
That's a very relevant example. multiply-to-shift is a pessimisation if it breaks up a mad, since |
508628e
to
1765171
Compare
✅ With the latest revision this PR passed the C/C++ code formatter. |
NVPTX::IMAD
opcode, and rely on instruction selection only
Okay, I've reworked this to remove `NVPTXISD::IMAD`. Even today there are only one or two references to IMAD in combine patterns, so I think this is likely fine. The i128 lit test even seems to be doing a better job of generating mad now. |
65b3efc
to
b841af0
Compare
@Artem-B would you mind taking a look? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Overall, this looks like great work, thanks @peterbell10!
The only real question I have is: #121724 (comment).
Adding @AlexMaclean and @Artem-B as reviewers in case I'm unaware of a good reason for |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM.
(ins Int64Regs:$a, i64imm:$b, i64imm:$c), | ||
"mad.lo.s64 \t$dst, $a, $b, $c;", | ||
[(set i64:$dst, (imad i64:$a, imm:$b, imm:$c))]>; | ||
def mul_oneuse : PatFrag<(ops node:$a, node:$b), (mul node:$a, node:$b), [{ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
RISC-V has a generalized form of one-use pattern:
llvm-project/llvm/lib/Target/RISCV/RISCVInstrInfo.td
Lines 1276 to 1283 in 1de3dc7
class binop_oneuse<SDPatternOperator operator> | |
: PatFrag<(ops node:$A, node:$B), | |
(operator node:$A, node:$B), [{ | |
return N->hasOneUse(); | |
}]>; | |
def and_oneuse : binop_oneuse<and>; | |
def mul_oneuse : binop_oneuse<mul>; |
It may be something worth extracting into a common tablegen file. We have quite a few uses of `hasOneUse()` in the backends. Could be in a separate patch.
@@ -141,6 +141,7 @@ def hasLDG : Predicate<"Subtarget->hasLDG()">; | |||
def hasLDU : Predicate<"Subtarget->hasLDU()">; | |||
def hasPTXASUnreachableBug : Predicate<"Subtarget->hasPTXASUnreachableBug()">; | |||
def noPTXASUnreachableBug : Predicate<"!Subtarget->hasPTXASUnreachableBug()">; | |||
def hasO1 : Predicate<"TM.getOptLevel() != CodeGenOptLevel::None">; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: I'd rename it to something more distinct from `01`. Perhaps `hasOptEnabled`.
I noticed that NVPTX will sometimes emit `mad.lo` to multiply by 1, e.g. in https://gcc.godbolt.org/z/45W3Wcnxz This happens when DAGCombiner operates on the add before the mul, so the imad contraction happens regardless of whether the mul could have been simplified. This PR adds: ``` mad x 1 y => add x y mad x -1 y => sub y x mad x 0 y => y mad x y 0 => mul x y mad c0 c1 z => add z (C0 * C1) ``` Another option might be to remove `NVPTXISD::IMAD` and only combine to mad during selection. This would allow the normal DAGCombiner patterns to simplify the graph without any NVPTX-specific intervention. However, it also risks DAGCombiner breaking up the mul-add patterns, which is why I haven't done it that way.
391c8d1
to
b0bd6c9
Compare
Thanks all for reviewing! |
… only (llvm#121724) I noticed that NVPTX will sometimes emit `mad.lo` to multiply by 1, e.g. in https://gcc.godbolt.org/z/4j47Y9W4c. This happens when DAGCombiner operates on the add before the mul, so the imad contraction happens regardless of whether the mul could have been simplified. To fix this, I remove `NVPTXISD::IMAD` and only combine to mad during selection. This allows the default DAGCombiner patterns to simplify the graph without any NVPTX-specific intervention.
… only (llvm#121724) I noticed that NVPTX will sometimes emit `mad.lo` to multiply by 1, e.g. in https://gcc.godbolt.org/z/4j47Y9W4c. This happens when DAGCombiner operates on the add before the mul, so the imad contraction happens regardless of whether the mul could have been simplified. To fix this, I remove `NVPTXISD::IMAD` and only combine to mad during selection. This allows the default DAGCombiner patterns to simplify the graph without any NVPTX-specific intervention.
I noticed that NVPTX will sometimes emit `mad.lo` to multiply by 1, e.g. in https://gcc.godbolt.org/z/4j47Y9W4c. This happens when DAGCombiner operates on the add before the mul, so the imad contraction happens regardless of whether the mul could have been simplified.
To fix this, I remove `NVPTXISD::IMAD` and only combine to mad during selection. This would allow the default DAGCombiner patterns to simplify the graph without any NVPTX-specific intervention.