diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 6be3fa71479be..b0fc99f6eff86 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -11363,16 +11363,20 @@ combineBinOpOfExtractToReduceTree(SDNode *N, SelectionDAG &DAG,
   const unsigned ReduceOpc = getVecReduceOpcode(Opc);
   assert(Opc == ISD::getVecReduceBaseOpcode(ReduceOpc) &&
          "Inconsistent mappings");
-  const SDValue LHS = N->getOperand(0);
-  const SDValue RHS = N->getOperand(1);
+  SDValue LHS = N->getOperand(0);
+  SDValue RHS = N->getOperand(1);
 
   if (!LHS.hasOneUse() || !RHS.hasOneUse())
     return SDValue();
 
+  if (RHS.getOpcode() != ISD::EXTRACT_VECTOR_ELT)
+    std::swap(LHS, RHS);
+
   if (RHS.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
       !isa<ConstantSDNode>(RHS.getOperand(1)))
     return SDValue();
 
+  uint64_t RHSIdx = cast<ConstantSDNode>(RHS.getOperand(1))->getLimitedValue();
   SDValue SrcVec = RHS.getOperand(0);
   EVT SrcVecVT = SrcVec.getValueType();
   assert(SrcVecVT.getVectorElementType() == VT);
@@ -11385,14 +11389,17 @@ combineBinOpOfExtractToReduceTree(SDNode *N, SelectionDAG &DAG,
   // match binop (extract_vector_elt V, 0), (extract_vector_elt V, 1) to
   // reduce_op (extract_subvector [2 x VT] from V). This will form the
   // root of our reduction tree. TODO: We could extend this to any two
-  // adjacent constant indices if desired.
+  // adjacent aligned constant indices if desired.
   if (LHS.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
-      LHS.getOperand(0) == SrcVec && isNullConstant(LHS.getOperand(1)) &&
-      isOneConstant(RHS.getOperand(1))) {
-    EVT ReduceVT = EVT::getVectorVT(*DAG.getContext(), VT, 2);
-    SDValue Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ReduceVT, SrcVec,
-                              DAG.getVectorIdxConstant(0, DL));
-    return DAG.getNode(ReduceOpc, DL, VT, Vec, N->getFlags());
+      LHS.getOperand(0) == SrcVec && isa<ConstantSDNode>(LHS.getOperand(1))) {
+    uint64_t LHSIdx =
+        cast<ConstantSDNode>(LHS.getOperand(1))->getLimitedValue();
+    if (0 == std::min(LHSIdx, RHSIdx) && 1 == std::max(LHSIdx, RHSIdx)) {
+      EVT ReduceVT = EVT::getVectorVT(*DAG.getContext(), VT, 2);
+      SDValue Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ReduceVT, SrcVec,
+                                DAG.getVectorIdxConstant(0, DL));
+      return DAG.getNode(ReduceOpc, DL, VT, Vec, N->getFlags());
+    }
   }
 
   // Match (binop (reduce (extract_subvector V, 0),
@@ -11404,20 +11411,18 @@ combineBinOpOfExtractToReduceTree(SDNode *N, SelectionDAG &DAG,
   SDValue ReduceVec = LHS.getOperand(0);
   if (ReduceVec.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
       ReduceVec.hasOneUse() && ReduceVec.getOperand(0) == RHS.getOperand(0) &&
-      isNullConstant(ReduceVec.getOperand(1))) {
-    uint64_t Idx = cast<ConstantSDNode>(RHS.getOperand(1))->getLimitedValue();
-    if (ReduceVec.getValueType().getVectorNumElements() == Idx) {
-      // For illegal types (e.g. 3xi32), most will be combined again into a
-      // wider (hopefully legal) type. If this is a terminal state, we are
-      // relying on type legalization here to produce something reasonable
-      // and this lowering quality could probably be improved. (TODO)
-      EVT ReduceVT = EVT::getVectorVT(*DAG.getContext(), VT, Idx + 1);
-      SDValue Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ReduceVT, SrcVec,
-                                DAG.getVectorIdxConstant(0, DL));
-      auto Flags = ReduceVec->getFlags();
-      Flags.intersectWith(N->getFlags());
-      return DAG.getNode(ReduceOpc, DL, VT, Vec, Flags);
-    }
+      isNullConstant(ReduceVec.getOperand(1)) &&
+      ReduceVec.getValueType().getVectorNumElements() == RHSIdx) {
+    // For illegal types (e.g. 3xi32), most will be combined again into a
+    // wider (hopefully legal) type. If this is a terminal state, we are
+    // relying on type legalization here to produce something reasonable
+    // and this lowering quality could probably be improved. (TODO)
+    EVT ReduceVT = EVT::getVectorVT(*DAG.getContext(), VT, RHSIdx + 1);
+    SDValue Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ReduceVT, SrcVec,
+                              DAG.getVectorIdxConstant(0, DL));
+    auto Flags = ReduceVec->getFlags();
+    Flags.intersectWith(N->getFlags());
+    return DAG.getNode(ReduceOpc, DL, VT, Vec, Flags);
   }
 
   return SDValue();
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-formation.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-formation.ll
index 76df097a76971..fd4a54b468f15 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-formation.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-formation.ll
@@ -34,7 +34,6 @@ define i32 @reduce_sum_4xi32(<4 x i32> %v) {
   ret i32 %add2
 }
 
-
 define i32 @reduce_sum_8xi32(<8 x i32> %v) {
 ; CHECK-LABEL: reduce_sum_8xi32:
 ; CHECK:       # %bb.0:
@@ -449,6 +448,68 @@ define i32 @reduce_sum_16xi32_prefix15(ptr %p) {
   ret i32 %add13
 }
 
+; Check that we can match with the operand order reversed, but the
+; reduction order unchanged.
+define i32 @reduce_sum_4xi32_op_order(<4 x i32> %v) {
+; CHECK-LABEL: reduce_sum_4xi32_op_order:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
+; CHECK-NEXT:    vmv.s.x v9, zero
+; CHECK-NEXT:    vredsum.vs v8, v8, v9
+; CHECK-NEXT:    vmv.x.s a0, v8
+; CHECK-NEXT:    ret
+  %e0 = extractelement <4 x i32> %v, i32 0
+  %e1 = extractelement <4 x i32> %v, i32 1
+  %e2 = extractelement <4 x i32> %v, i32 2
+  %e3 = extractelement <4 x i32> %v, i32 3
+  %add0 = add i32 %e1, %e0
+  %add1 = add i32 %e2, %add0
+  %add2 = add i32 %add1, %e3
+  ret i32 %add2
+}
+
+; Negative test - Reduction order isn't compatible with current
+; incremental matching scheme.
+define i32 @reduce_sum_4xi32_reduce_order(<4 x i32> %v) {
+; RV32-LABEL: reduce_sum_4xi32_reduce_order:
+; RV32:       # %bb.0:
+; RV32-NEXT:    vsetivli zero, 1, e32, m1, ta, ma
+; RV32-NEXT:    vmv.x.s a0, v8
+; RV32-NEXT:    vslidedown.vi v9, v8, 1
+; RV32-NEXT:    vmv.x.s a1, v9
+; RV32-NEXT:    vslidedown.vi v9, v8, 2
+; RV32-NEXT:    vmv.x.s a2, v9
+; RV32-NEXT:    vslidedown.vi v8, v8, 3
+; RV32-NEXT:    vmv.x.s a3, v8
+; RV32-NEXT:    add a1, a1, a2
+; RV32-NEXT:    add a0, a0, a3
+; RV32-NEXT:    add a0, a0, a1
+; RV32-NEXT:    ret
+;
+; RV64-LABEL: reduce_sum_4xi32_reduce_order:
+; RV64:       # %bb.0:
+; RV64-NEXT:    vsetivli zero, 1, e32, m1, ta, ma
+; RV64-NEXT:    vmv.x.s a0, v8
+; RV64-NEXT:    vslidedown.vi v9, v8, 1
+; RV64-NEXT:    vmv.x.s a1, v9
+; RV64-NEXT:    vslidedown.vi v9, v8, 2
+; RV64-NEXT:    vmv.x.s a2, v9
+; RV64-NEXT:    vslidedown.vi v8, v8, 3
+; RV64-NEXT:    vmv.x.s a3, v8
+; RV64-NEXT:    add a1, a1, a2
+; RV64-NEXT:    add a0, a0, a3
+; RV64-NEXT:    addw a0, a0, a1
+; RV64-NEXT:    ret
+  %e0 = extractelement <4 x i32> %v, i32 0
+  %e1 = extractelement <4 x i32> %v, i32 1
+  %e2 = extractelement <4 x i32> %v, i32 2
+  %e3 = extractelement <4 x i32> %v, i32 3
+  %add0 = add i32 %e1, %e2
+  %add1 = add i32 %e0, %add0
+  %add2 = add i32 %add1, %e3
+  ret i32 %add2
+}
+
 ;; Most of the cornercases are exercised above, the following just
 ;; makes sure that other opcodes work as expected.
 
@@ -923,6 +984,3 @@ define float @reduce_fadd_4xi32_non_associative2(ptr %p) {
 }
 
 
-;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
-; RV32: {{.*}}
-; RV64: {{.*}}
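
The key idea in the first hunk is a small canonicalization: because the binop being combined is commutative, the operands are swapped up front so that the extract_vector_elt always ends up on the RHS, and the rest of the matching (which keys everything off RHSIdx) only has to handle one operand order. The standalone C++ sketch below illustrates that idiom in isolation; it is not LLVM code, and Val, isExtract, and matchBinOpWithExtract are made-up stand-ins rather than the SDValue/ISD API.

// Illustrative sketch only: the "swap to canonicalize, then match one form"
// pattern used by the patch, with hypothetical stand-in types and helpers.
#include <algorithm>
#include <cassert>
#include <string>

struct Val {
  std::string Kind; // stand-in for a node's opcode
};

static bool isExtract(const Val &V) { return V.Kind == "extract_vector_elt"; }

// Returns true if the commutative binop (LHS, RHS) matches
// "anything op extract_vector_elt", regardless of operand order.
static bool matchBinOpWithExtract(Val LHS, Val RHS) {
  if (!isExtract(RHS))
    std::swap(LHS, RHS); // canonicalize: the extract goes to the right
  return isExtract(RHS); // only one operand order left to check
}

int main() {
  assert(matchBinOpWithExtract({"extract_vector_elt"}, {"reduce"})); // commuted
  assert(matchBinOpWithExtract({"reduce"}, {"extract_vector_elt"})); // original
  assert(!matchBinOpWithExtract({"reduce"}, {"reduce"}));            // no match
  return 0;
}

This mirrors what the new reduce_sum_4xi32_op_order test exercises: the same reduction tree as reduce_sum_4xi32, just with the operands of the adds written in the opposite order, now folds to a single vredsum.vs.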