Skip to content

Commit 25da9bb

Browse files
authored
[RISCV] Allow swapped operands in reduction formation (#68634)
Very straightforward, but worth landing on its own in advance of a more complicated generalization.
1 parent aab0626 commit 25da9bb

File tree

2 files changed

+90
-27
lines changed

2 files changed

+90
-27
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 28 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -11379,16 +11379,20 @@ combineBinOpOfExtractToReduceTree(SDNode *N, SelectionDAG &DAG,
1137911379
const unsigned ReduceOpc = getVecReduceOpcode(Opc);
1138011380
assert(Opc == ISD::getVecReduceBaseOpcode(ReduceOpc) &&
1138111381
"Inconsistent mappings");
11382-
const SDValue LHS = N->getOperand(0);
11383-
const SDValue RHS = N->getOperand(1);
11382+
SDValue LHS = N->getOperand(0);
11383+
SDValue RHS = N->getOperand(1);
1138411384

1138511385
if (!LHS.hasOneUse() || !RHS.hasOneUse())
1138611386
return SDValue();
1138711387

11388+
if (RHS.getOpcode() != ISD::EXTRACT_VECTOR_ELT)
11389+
std::swap(LHS, RHS);
11390+
1138811391
if (RHS.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
1138911392
!isa<ConstantSDNode>(RHS.getOperand(1)))
1139011393
return SDValue();
1139111394

11395+
uint64_t RHSIdx = cast<ConstantSDNode>(RHS.getOperand(1))->getLimitedValue();
1139211396
SDValue SrcVec = RHS.getOperand(0);
1139311397
EVT SrcVecVT = SrcVec.getValueType();
1139411398
assert(SrcVecVT.getVectorElementType() == VT);
@@ -11401,14 +11405,17 @@ combineBinOpOfExtractToReduceTree(SDNode *N, SelectionDAG &DAG,
1140111405
// match binop (extract_vector_elt V, 0), (extract_vector_elt V, 1) to
1140211406
// reduce_op (extract_subvector [2 x VT] from V). This will form the
1140311407
// root of our reduction tree. TODO: We could extend this to any two
11404-
// adjacent constant indices if desired.
11408+
// adjacent aligned constant indices if desired.
1140511409
if (LHS.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
11406-
LHS.getOperand(0) == SrcVec && isNullConstant(LHS.getOperand(1)) &&
11407-
isOneConstant(RHS.getOperand(1))) {
11408-
EVT ReduceVT = EVT::getVectorVT(*DAG.getContext(), VT, 2);
11409-
SDValue Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ReduceVT, SrcVec,
11410-
DAG.getVectorIdxConstant(0, DL));
11411-
return DAG.getNode(ReduceOpc, DL, VT, Vec, N->getFlags());
11410+
LHS.getOperand(0) == SrcVec && isa<ConstantSDNode>(LHS.getOperand(1))) {
11411+
uint64_t LHSIdx =
11412+
cast<ConstantSDNode>(LHS.getOperand(1))->getLimitedValue();
11413+
if (0 == std::min(LHSIdx, RHSIdx) && 1 == std::max(LHSIdx, RHSIdx)) {
11414+
EVT ReduceVT = EVT::getVectorVT(*DAG.getContext(), VT, 2);
11415+
SDValue Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ReduceVT, SrcVec,
11416+
DAG.getVectorIdxConstant(0, DL));
11417+
return DAG.getNode(ReduceOpc, DL, VT, Vec, N->getFlags());
11418+
}
1141211419
}
1141311420

1141411421
// Match (binop (reduce (extract_subvector V, 0),
@@ -11420,20 +11427,18 @@ combineBinOpOfExtractToReduceTree(SDNode *N, SelectionDAG &DAG,
1142011427
SDValue ReduceVec = LHS.getOperand(0);
1142111428
if (ReduceVec.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
1142211429
ReduceVec.hasOneUse() && ReduceVec.getOperand(0) == RHS.getOperand(0) &&
11423-
isNullConstant(ReduceVec.getOperand(1))) {
11424-
uint64_t Idx = cast<ConstantSDNode>(RHS.getOperand(1))->getLimitedValue();
11425-
if (ReduceVec.getValueType().getVectorNumElements() == Idx) {
11426-
// For illegal types (e.g. 3xi32), most will be combined again into a
11427-
// wider (hopefully legal) type. If this is a terminal state, we are
11428-
// relying on type legalization here to produce something reasonable
11429-
// and this lowering quality could probably be improved. (TODO)
11430-
EVT ReduceVT = EVT::getVectorVT(*DAG.getContext(), VT, Idx + 1);
11431-
SDValue Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ReduceVT, SrcVec,
11432-
DAG.getVectorIdxConstant(0, DL));
11433-
auto Flags = ReduceVec->getFlags();
11434-
Flags.intersectWith(N->getFlags());
11435-
return DAG.getNode(ReduceOpc, DL, VT, Vec, Flags);
11436-
}
11430+
isNullConstant(ReduceVec.getOperand(1)) &&
11431+
ReduceVec.getValueType().getVectorNumElements() == RHSIdx) {
11432+
// For illegal types (e.g. 3xi32), most will be combined again into a
11433+
// wider (hopefully legal) type. If this is a terminal state, we are
11434+
// relying on type legalization here to produce something reasonable
11435+
// and this lowering quality could probably be improved. (TODO)
11436+
EVT ReduceVT = EVT::getVectorVT(*DAG.getContext(), VT, RHSIdx + 1);
11437+
SDValue Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ReduceVT, SrcVec,
11438+
DAG.getVectorIdxConstant(0, DL));
11439+
auto Flags = ReduceVec->getFlags();
11440+
Flags.intersectWith(N->getFlags());
11441+
return DAG.getNode(ReduceOpc, DL, VT, Vec, Flags);
1143711442
}
1143811443

1143911444
return SDValue();

llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-formation.ll

Lines changed: 62 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@ define i32 @reduce_sum_4xi32(<4 x i32> %v) {
3434
ret i32 %add2
3535
}
3636

37-
3837
define i32 @reduce_sum_8xi32(<8 x i32> %v) {
3938
; CHECK-LABEL: reduce_sum_8xi32:
4039
; CHECK: # %bb.0:
@@ -449,6 +448,68 @@ define i32 @reduce_sum_16xi32_prefix15(ptr %p) {
449448
ret i32 %add13
450449
}
451450

451+
; Check that we can match with the operand order reversed, but the
452+
; reduction order unchanged.
453+
define i32 @reduce_sum_4xi32_op_order(<4 x i32> %v) {
454+
; CHECK-LABEL: reduce_sum_4xi32_op_order:
455+
; CHECK: # %bb.0:
456+
; CHECK-NEXT: vsetivli zero, 4, e32, m1, ta, ma
457+
; CHECK-NEXT: vmv.s.x v9, zero
458+
; CHECK-NEXT: vredsum.vs v8, v8, v9
459+
; CHECK-NEXT: vmv.x.s a0, v8
460+
; CHECK-NEXT: ret
461+
%e0 = extractelement <4 x i32> %v, i32 0
462+
%e1 = extractelement <4 x i32> %v, i32 1
463+
%e2 = extractelement <4 x i32> %v, i32 2
464+
%e3 = extractelement <4 x i32> %v, i32 3
465+
%add0 = add i32 %e1, %e0
466+
%add1 = add i32 %e2, %add0
467+
%add2 = add i32 %add1, %e3
468+
ret i32 %add2
469+
}
470+
471+
; Negative test - Reduction order isn't compatible with current
472+
; incremental matching scheme.
473+
define i32 @reduce_sum_4xi32_reduce_order(<4 x i32> %v) {
474+
; RV32-LABEL: reduce_sum_4xi32_reduce_order:
475+
; RV32: # %bb.0:
476+
; RV32-NEXT: vsetivli zero, 1, e32, m1, ta, ma
477+
; RV32-NEXT: vmv.x.s a0, v8
478+
; RV32-NEXT: vslidedown.vi v9, v8, 1
479+
; RV32-NEXT: vmv.x.s a1, v9
480+
; RV32-NEXT: vslidedown.vi v9, v8, 2
481+
; RV32-NEXT: vmv.x.s a2, v9
482+
; RV32-NEXT: vslidedown.vi v8, v8, 3
483+
; RV32-NEXT: vmv.x.s a3, v8
484+
; RV32-NEXT: add a1, a1, a2
485+
; RV32-NEXT: add a0, a0, a3
486+
; RV32-NEXT: add a0, a0, a1
487+
; RV32-NEXT: ret
488+
;
489+
; RV64-LABEL: reduce_sum_4xi32_reduce_order:
490+
; RV64: # %bb.0:
491+
; RV64-NEXT: vsetivli zero, 1, e32, m1, ta, ma
492+
; RV64-NEXT: vmv.x.s a0, v8
493+
; RV64-NEXT: vslidedown.vi v9, v8, 1
494+
; RV64-NEXT: vmv.x.s a1, v9
495+
; RV64-NEXT: vslidedown.vi v9, v8, 2
496+
; RV64-NEXT: vmv.x.s a2, v9
497+
; RV64-NEXT: vslidedown.vi v8, v8, 3
498+
; RV64-NEXT: vmv.x.s a3, v8
499+
; RV64-NEXT: add a1, a1, a2
500+
; RV64-NEXT: add a0, a0, a3
501+
; RV64-NEXT: addw a0, a0, a1
502+
; RV64-NEXT: ret
503+
%e0 = extractelement <4 x i32> %v, i32 0
504+
%e1 = extractelement <4 x i32> %v, i32 1
505+
%e2 = extractelement <4 x i32> %v, i32 2
506+
%e3 = extractelement <4 x i32> %v, i32 3
507+
%add0 = add i32 %e1, %e2
508+
%add1 = add i32 %e0, %add0
509+
%add2 = add i32 %add1, %e3
510+
ret i32 %add2
511+
}
512+
452513
;; Most of the cornercases are exercised above, the following just
453514
;; makes sure that other opcodes work as expected.
454515

@@ -923,6 +984,3 @@ define float @reduce_fadd_4xi32_non_associative2(ptr %p) {
923984
}
924985

925986

926-
;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
927-
; RV32: {{.*}}
928-
; RV64: {{.*}}

0 commit comments

Comments
 (0)