Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RISCV] Generaize reduction tree matching to all integer reductions #68014

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 51 additions & 7 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11108,6 +11108,31 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N,
}
}

/// Given an integer binary operator, return the generic ISD::VECREDUCE_OP
/// which corresponds to it.
static unsigned getVecReduceOpcode(unsigned Opc) {
switch (Opc) {
default:
llvm_unreachable("Unhandled binary to transfrom reduction");
case ISD::ADD:
return ISD::VECREDUCE_ADD;
case ISD::UMAX:
return ISD::VECREDUCE_UMAX;
case ISD::SMAX:
return ISD::VECREDUCE_SMAX;
case ISD::UMIN:
return ISD::VECREDUCE_UMIN;
case ISD::SMIN:
return ISD::VECREDUCE_SMIN;
case ISD::AND:
return ISD::VECREDUCE_AND;
case ISD::OR:
return ISD::VECREDUCE_OR;
case ISD::XOR:
return ISD::VECREDUCE_XOR;
}
};
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: do we intend to have this extra semicolon after the bracket?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in 63bbc25


/// Perform two related transforms whose purpose is to incrementally recognize
/// an explode_vector followed by scalar reduction as a vector reduction node.
/// This exists to recover from a deficiency in SLP which can't handle
Expand All @@ -11126,8 +11151,15 @@ combineBinOpOfExtractToReduceTree(SDNode *N, SelectionDAG &DAG,

const SDLoc DL(N);
const EVT VT = N->getValueType(0);
[[maybe_unused]] const unsigned Opc = N->getOpcode();
assert(Opc == ISD::ADD && "extend this to other reduction types");

// TODO: Handle floating point here.
if (!VT.isInteger())
return SDValue();

const unsigned Opc = N->getOpcode();
const unsigned ReduceOpc = getVecReduceOpcode(Opc);
assert(Opc == ISD::getVecReduceBaseOpcode(ReduceOpc) &&
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe isOperationLegalOrCustomOrPromote(ReduceOpc, VT)?

"Inconsistent mappings");
const SDValue LHS = N->getOperand(0);
const SDValue RHS = N->getOperand(1);

Expand Down Expand Up @@ -11157,13 +11189,13 @@ combineBinOpOfExtractToReduceTree(SDNode *N, SelectionDAG &DAG,
EVT ReduceVT = EVT::getVectorVT(*DAG.getContext(), VT, 2);
SDValue Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ReduceVT, SrcVec,
DAG.getVectorIdxConstant(0, DL));
return DAG.getNode(ISD::VECREDUCE_ADD, DL, VT, Vec);
return DAG.getNode(ReduceOpc, DL, VT, Vec);
}

// Match (binop (reduce (extract_subvector V, 0),
// (extract_vector_elt V, sizeof(SubVec))))
// into a reduction of one more element from the original vector V.
if (LHS.getOpcode() != ISD::VECREDUCE_ADD)
if (LHS.getOpcode() != ReduceOpc)
return SDValue();

SDValue ReduceVec = LHS.getOperand(0);
Expand All @@ -11179,7 +11211,7 @@ combineBinOpOfExtractToReduceTree(SDNode *N, SelectionDAG &DAG,
EVT ReduceVT = EVT::getVectorVT(*DAG.getContext(), VT, Idx + 1);
SDValue Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ReduceVT, SrcVec,
DAG.getVectorIdxConstant(0, DL));
return DAG.getNode(ISD::VECREDUCE_ADD, DL, VT, Vec);
return DAG.getNode(ReduceOpc, DL, VT, Vec);
}
}

Expand Down Expand Up @@ -11687,6 +11719,8 @@ static SDValue performANDCombine(SDNode *N,

if (SDValue V = combineBinOpToReduce(N, DAG, Subtarget))
return V;
if (SDValue V = combineBinOpOfExtractToReduceTree(N, DAG, Subtarget))
return V;

if (DCI.isAfterLegalizeDAG())
if (SDValue V = combineDeMorganOfBoolean(N, DAG))
Expand Down Expand Up @@ -11739,6 +11773,8 @@ static SDValue performORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,

if (SDValue V = combineBinOpToReduce(N, DAG, Subtarget))
return V;
if (SDValue V = combineBinOpOfExtractToReduceTree(N, DAG, Subtarget))
return V;

if (DCI.isAfterLegalizeDAG())
if (SDValue V = combineDeMorganOfBoolean(N, DAG))
Expand Down Expand Up @@ -11790,6 +11826,9 @@ static SDValue performXORCombine(SDNode *N, SelectionDAG &DAG,

if (SDValue V = combineBinOpToReduce(N, DAG, Subtarget))
return V;
if (SDValue V = combineBinOpOfExtractToReduceTree(N, DAG, Subtarget))
return V;

// fold (xor (select cond, 0, y), x) ->
// (select cond, x, (xor x, y))
return combineSelectAndUseCommutative(N, DAG, /*AllOnes*/ false, Subtarget);
Expand Down Expand Up @@ -13995,8 +14034,13 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
case ISD::SMAX:
case ISD::SMIN:
case ISD::FMAXNUM:
case ISD::FMINNUM:
return combineBinOpToReduce(N, DAG, Subtarget);
case ISD::FMINNUM: {
if (SDValue V = combineBinOpToReduce(N, DAG, Subtarget))
return V;
if (SDValue V = combineBinOpOfExtractToReduceTree(N, DAG, Subtarget))
return V;
return SDValue();
}
case ISD::SETCC:
return performSETCCCombine(N, DAG, Subtarget);
case ISD::SIGN_EXTEND_INREG:
Expand Down
Loading