@@ -11379,16 +11379,20 @@ combineBinOpOfExtractToReduceTree(SDNode *N, SelectionDAG &DAG,
1137911379 const unsigned ReduceOpc = getVecReduceOpcode(Opc);
1138011380 assert(Opc == ISD::getVecReduceBaseOpcode(ReduceOpc) &&
1138111381 "Inconsistent mappings");
11382- const SDValue LHS = N->getOperand(0);
11383- const SDValue RHS = N->getOperand(1);
11382+ SDValue LHS = N->getOperand(0);
11383+ SDValue RHS = N->getOperand(1);
1138411384
1138511385 if (!LHS.hasOneUse() || !RHS.hasOneUse())
1138611386 return SDValue();
1138711387
11388+ if (RHS.getOpcode() != ISD::EXTRACT_VECTOR_ELT)
11389+ std::swap(LHS, RHS);
11390+
1138811391 if (RHS.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
1138911392 !isa<ConstantSDNode>(RHS.getOperand(1)))
1139011393 return SDValue();
1139111394
11395+ uint64_t RHSIdx = cast<ConstantSDNode>(RHS.getOperand(1))->getLimitedValue();
1139211396 SDValue SrcVec = RHS.getOperand(0);
1139311397 EVT SrcVecVT = SrcVec.getValueType();
1139411398 assert(SrcVecVT.getVectorElementType() == VT);
@@ -11401,14 +11405,17 @@ combineBinOpOfExtractToReduceTree(SDNode *N, SelectionDAG &DAG,
1140111405 // match binop (extract_vector_elt V, 0), (extract_vector_elt V, 1) to
1140211406 // reduce_op (extract_subvector [2 x VT] from V). This will form the
1140311407 // root of our reduction tree. TODO: We could extend this to any two
11404- // adjacent constant indices if desired.
11408+ // adjacent aligned constant indices if desired.
1140511409 if (LHS.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
11406- LHS.getOperand(0) == SrcVec && isNullConstant(LHS.getOperand(1)) &&
11407- isOneConstant(RHS.getOperand(1))) {
11408- EVT ReduceVT = EVT::getVectorVT(*DAG.getContext(), VT, 2);
11409- SDValue Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ReduceVT, SrcVec,
11410- DAG.getVectorIdxConstant(0, DL));
11411- return DAG.getNode(ReduceOpc, DL, VT, Vec, N->getFlags());
11410+ LHS.getOperand(0) == SrcVec && isa<ConstantSDNode>(LHS.getOperand(1))) {
11411+ uint64_t LHSIdx =
11412+ cast<ConstantSDNode>(LHS.getOperand(1))->getLimitedValue();
11413+ if (0 == std::min(LHSIdx, RHSIdx) && 1 == std::max(LHSIdx, RHSIdx)) {
11414+ EVT ReduceVT = EVT::getVectorVT(*DAG.getContext(), VT, 2);
11415+ SDValue Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ReduceVT, SrcVec,
11416+ DAG.getVectorIdxConstant(0, DL));
11417+ return DAG.getNode(ReduceOpc, DL, VT, Vec, N->getFlags());
11418+ }
1141211419 }
1141311420
1141411421 // Match (binop (reduce (extract_subvector V, 0),
@@ -11420,20 +11427,18 @@ combineBinOpOfExtractToReduceTree(SDNode *N, SelectionDAG &DAG,
1142011427 SDValue ReduceVec = LHS.getOperand(0);
1142111428 if (ReduceVec.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
1142211429 ReduceVec.hasOneUse() && ReduceVec.getOperand(0) == RHS.getOperand(0) &&
11423- isNullConstant(ReduceVec.getOperand(1))) {
11424- uint64_t Idx = cast<ConstantSDNode>(RHS.getOperand(1))->getLimitedValue();
11425- if (ReduceVec.getValueType().getVectorNumElements() == Idx) {
11426- // For illegal types (e.g. 3xi32), most will be combined again into a
11427- // wider (hopefully legal) type. If this is a terminal state, we are
11428- // relying on type legalization here to produce something reasonable
11429- // and this lowering quality could probably be improved. (TODO)
11430- EVT ReduceVT = EVT::getVectorVT(*DAG.getContext(), VT, Idx + 1);
11431- SDValue Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ReduceVT, SrcVec,
11432- DAG.getVectorIdxConstant(0, DL));
11433- auto Flags = ReduceVec->getFlags();
11434- Flags.intersectWith(N->getFlags());
11435- return DAG.getNode(ReduceOpc, DL, VT, Vec, Flags);
11436- }
11430+ isNullConstant(ReduceVec.getOperand(1)) &&
11431+ ReduceVec.getValueType().getVectorNumElements() == RHSIdx) {
11432+ // For illegal types (e.g. 3xi32), most will be combined again into a
11433+ // wider (hopefully legal) type. If this is a terminal state, we are
11434+ // relying on type legalization here to produce something reasonable
11435+ // and this lowering quality could probably be improved. (TODO)
11436+ EVT ReduceVT = EVT::getVectorVT(*DAG.getContext(), VT, RHSIdx + 1);
11437+ SDValue Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ReduceVT, SrcVec,
11438+ DAG.getVectorIdxConstant(0, DL));
11439+ auto Flags = ReduceVec->getFlags();
11440+ Flags.intersectWith(N->getFlags());
11441+ return DAG.getNode(ReduceOpc, DL, VT, Vec, Flags);
1143711442 }
1143811443
1143911444 return SDValue();
0 commit comments