[AArch64][SVE] Add codegen support for partial reduction lowering to wide add instructions (#114406)

For partial reductions where the number of elements is halved, a pair of wide add instructions can be used.
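
As a rough editorial illustration (not part of the commit): the rewrite is legal because llvm.experimental.vector.partial.reduce.add leaves the association of input lanes with accumulator lanes unspecified, so the even-indexed (bottom) and odd-indexed (top) input lanes can be widened and accumulated separately, which is what SADDWB/SADDWT (and UADDWB/UADDWT for zero-extends) compute. A hand-written LLVM IR equivalent for the signed nxv4i32-into-nxv2i64 case could look like this; the function name and the use of llvm.vector.deinterleave2 are illustrative only:

define <vscale x 2 x i64> @wide_add_by_hand(<vscale x 2 x i64> %acc, <vscale x 4 x i32> %input) {
  ; Split the input into its even-indexed (bottom) and odd-indexed (top) lanes.
  %halves = call { <vscale x 2 x i32>, <vscale x 2 x i32> } @llvm.vector.deinterleave2.nxv4i32(<vscale x 4 x i32> %input)
  %bot = extractvalue { <vscale x 2 x i32>, <vscale x 2 x i32> } %halves, 0
  %top = extractvalue { <vscale x 2 x i32>, <vscale x 2 x i32> } %halves, 1
  ; Widen each half and accumulate; SADDWB and SADDWT each perform one of these
  ; extend-and-add steps in a single instruction.
  %bot.wide = sext <vscale x 2 x i32> %bot to <vscale x 2 x i64>
  %top.wide = sext <vscale x 2 x i32> %top to <vscale x 2 x i64>
  %acc.bot = add <vscale x 2 x i64> %acc, %bot.wide
  %acc.top = add <vscale x 2 x i64> %acc.bot, %top.wide
  ret <vscale x 2 x i64> %acc.top
}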
JamesChesterman authored Nov 12, 2024
1 parent e05d91b commit c3c2e1e
Showing 2 changed files with 199 additions and 2 deletions.
60 changes: 58 additions & 2 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -2039,8 +2039,13 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
return true;

EVT VT = EVT::getEVT(I->getType());
-  return VT != MVT::nxv4i64 && VT != MVT::nxv4i32 && VT != MVT::nxv2i64 &&
-         VT != MVT::v4i64 && VT != MVT::v4i32 && VT != MVT::v2i32;
auto Op1 = I->getOperand(1);
EVT Op1VT = EVT::getEVT(Op1->getType());
if (Op1VT.getVectorElementType() == VT.getVectorElementType() &&
(VT.getVectorElementCount() * 4 == Op1VT.getVectorElementCount() ||
VT.getVectorElementCount() * 2 == Op1VT.getVectorElementCount()))
return false;
return true;
}

bool AArch64TargetLowering::shouldExpandCttzElements(EVT VT) const {
@@ -21784,6 +21789,55 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
return DAG.getNode(Opcode, DL, ReducedType, NarrowOp, A, B);
}

SDValue tryLowerPartialReductionToWideAdd(SDNode *N,
const AArch64Subtarget *Subtarget,
SelectionDAG &DAG) {

assert(N->getOpcode() == ISD::INTRINSIC_WO_CHAIN &&
getIntrinsicID(N) ==
Intrinsic::experimental_vector_partial_reduce_add &&
"Expected a partial reduction node");

if (!Subtarget->isSVEorStreamingSVEAvailable())
return SDValue();

SDLoc DL(N);

auto Acc = N->getOperand(1);
auto ExtInput = N->getOperand(2);

EVT AccVT = Acc.getValueType();
EVT AccElemVT = AccVT.getVectorElementType();

if (ExtInput.getValueType().getVectorElementType() != AccElemVT)
return SDValue();

unsigned ExtInputOpcode = ExtInput->getOpcode();
if (!ISD::isExtOpcode(ExtInputOpcode))
return SDValue();

auto Input = ExtInput->getOperand(0);
EVT InputVT = Input.getValueType();

if (!(InputVT == MVT::nxv4i32 && AccVT == MVT::nxv2i64) &&
!(InputVT == MVT::nxv8i16 && AccVT == MVT::nxv4i32) &&
!(InputVT == MVT::nxv16i8 && AccVT == MVT::nxv8i16))
return SDValue();

bool InputIsSigned = ExtInputOpcode == ISD::SIGN_EXTEND;
auto BottomIntrinsic = InputIsSigned ? Intrinsic::aarch64_sve_saddwb
: Intrinsic::aarch64_sve_uaddwb;
auto TopIntrinsic = InputIsSigned ? Intrinsic::aarch64_sve_saddwt
: Intrinsic::aarch64_sve_uaddwt;

auto BottomID = DAG.getTargetConstant(BottomIntrinsic, DL, AccElemVT);
auto BottomNode =
DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, AccVT, BottomID, Acc, Input);
auto TopID = DAG.getTargetConstant(TopIntrinsic, DL, AccElemVT);
return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, AccVT, TopID, BottomNode,
Input);
}

static SDValue performIntrinsicCombine(SDNode *N,
TargetLowering::DAGCombinerInfo &DCI,
const AArch64Subtarget *Subtarget) {
@@ -21795,6 +21849,8 @@ static SDValue performIntrinsicCombine(SDNode *N,
case Intrinsic::experimental_vector_partial_reduce_add: {
if (auto Dot = tryLowerPartialReductionToDot(N, Subtarget, DAG))
return Dot;
if (auto WideAdd = tryLowerPartialReductionToWideAdd(N, Subtarget, DAG))
return WideAdd;
return DAG.getPartialReduceAdd(SDLoc(N), N->getValueType(0),
N->getOperand(1), N->getOperand(2));
}
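
Taken together (an editorial sketch, not part of the diff): a partial reduction whose accumulator and wide operand share an element type at a 2x element-count ratio now passes the relaxed shouldExpandPartialReductionIntrinsic check, reaches the DAG as a target intrinsic node, and is matched by tryLowerPartialReductionToWideAdd, which emits the aarch64_sve_{s,u}addwb / aarch64_sve_{s,u}addwt pair. The IR below, mirroring the first function in the new test file that follows, walks through one such case; the function name is illustrative only:

define <vscale x 2 x i64> @partial_reduce_to_wide_add(<vscale x 2 x i64> %acc, <vscale x 4 x i32> %input) {
  ; The sign-extended operand has i64 elements, matching the accumulator, at twice
  ; its element count, so the intrinsic is kept for the target rather than expanded.
  %input.wide = sext <vscale x 4 x i32> %input to <vscale x 4 x i64>
  ; Expected lowering with +sve2: saddwb z0.d, z0.d, z1.s followed by saddwt z0.d, z0.d, z1.s.
  %partial.reduce = call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv4i64(<vscale x 2 x i64> %acc, <vscale x 4 x i64> %input.wide)
  ret <vscale x 2 x i64> %partial.reduce
}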
141 changes: 141 additions & 0 deletions llvm/test/CodeGen/AArch64/sve-partial-reduce-wide-add.ll
@@ -0,0 +1,141 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
; RUN: llc -mtriple=aarch64 -mattr=+sve2 %s -o - | FileCheck %s

define <vscale x 2 x i64> @signed_wide_add_nxv4i32(<vscale x 2 x i64> %acc, <vscale x 4 x i32> %input){
; CHECK-LABEL: signed_wide_add_nxv4i32:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: saddwb z0.d, z0.d, z1.s
; CHECK-NEXT: saddwt z0.d, z0.d, z1.s
; CHECK-NEXT: ret
entry:
%input.wide = sext <vscale x 4 x i32> %input to <vscale x 4 x i64>
%partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv4i64(<vscale x 2 x i64> %acc, <vscale x 4 x i64> %input.wide)
ret <vscale x 2 x i64> %partial.reduce
}

define <vscale x 2 x i64> @unsigned_wide_add_nxv4i32(<vscale x 2 x i64> %acc, <vscale x 4 x i32> %input){
; CHECK-LABEL: unsigned_wide_add_nxv4i32:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: uaddwb z0.d, z0.d, z1.s
; CHECK-NEXT: uaddwt z0.d, z0.d, z1.s
; CHECK-NEXT: ret
entry:
%input.wide = zext <vscale x 4 x i32> %input to <vscale x 4 x i64>
%partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv4i64(<vscale x 2 x i64> %acc, <vscale x 4 x i64> %input.wide)
ret <vscale x 2 x i64> %partial.reduce
}

define <vscale x 4 x i32> @signed_wide_add_nxv8i16(<vscale x 4 x i32> %acc, <vscale x 8 x i16> %input){
; CHECK-LABEL: signed_wide_add_nxv8i16:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: saddwb z0.s, z0.s, z1.h
; CHECK-NEXT: saddwt z0.s, z0.s, z1.h
; CHECK-NEXT: ret
entry:
%input.wide = sext <vscale x 8 x i16> %input to <vscale x 8 x i32>
%partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv8i32(<vscale x 4 x i32> %acc, <vscale x 8 x i32> %input.wide)
ret <vscale x 4 x i32> %partial.reduce
}

define <vscale x 4 x i32> @unsigned_wide_add_nxv8i16(<vscale x 4 x i32> %acc, <vscale x 8 x i16> %input){
; CHECK-LABEL: unsigned_wide_add_nxv8i16:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: uaddwb z0.s, z0.s, z1.h
; CHECK-NEXT: uaddwt z0.s, z0.s, z1.h
; CHECK-NEXT: ret
entry:
%input.wide = zext <vscale x 8 x i16> %input to <vscale x 8 x i32>
%partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv8i32(<vscale x 4 x i32> %acc, <vscale x 8 x i32> %input.wide)
ret <vscale x 4 x i32> %partial.reduce
}

define <vscale x 8 x i16> @signed_wide_add_nxv16i8(<vscale x 8 x i16> %acc, <vscale x 16 x i8> %input){
; CHECK-LABEL: signed_wide_add_nxv16i8:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: saddwb z0.h, z0.h, z1.b
; CHECK-NEXT: saddwt z0.h, z0.h, z1.b
; CHECK-NEXT: ret
entry:
%input.wide = sext <vscale x 16 x i8> %input to <vscale x 16 x i16>
%partial.reduce = tail call <vscale x 8 x i16> @llvm.experimental.vector.partial.reduce.add.nxv8i16.nxv16i16(<vscale x 8 x i16> %acc, <vscale x 16 x i16> %input.wide)
ret <vscale x 8 x i16> %partial.reduce
}

define <vscale x 8 x i16> @unsigned_wide_add_nxv16i8(<vscale x 8 x i16> %acc, <vscale x 16 x i8> %input){
; CHECK-LABEL: unsigned_wide_add_nxv16i8:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: uaddwb z0.h, z0.h, z1.b
; CHECK-NEXT: uaddwt z0.h, z0.h, z1.b
; CHECK-NEXT: ret
entry:
%input.wide = zext <vscale x 16 x i8> %input to <vscale x 16 x i16>
%partial.reduce = tail call <vscale x 8 x i16> @llvm.experimental.vector.partial.reduce.add.nxv8i16.nxv16i16(<vscale x 8 x i16> %acc, <vscale x 16 x i16> %input.wide)
ret <vscale x 8 x i16> %partial.reduce
}

define <vscale x 2 x i32> @signed_wide_add_nxv4i16(<vscale x 2 x i32> %acc, <vscale x 4 x i16> %input){
; CHECK-LABEL: signed_wide_add_nxv4i16:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: ptrue p0.s
; CHECK-NEXT: sxth z1.s, p0/m, z1.s
; CHECK-NEXT: uunpklo z2.d, z1.s
; CHECK-NEXT: uunpkhi z1.d, z1.s
; CHECK-NEXT: add z0.d, z0.d, z2.d
; CHECK-NEXT: add z0.d, z1.d, z0.d
; CHECK-NEXT: ret
entry:
%input.wide = sext <vscale x 4 x i16> %input to <vscale x 4 x i32>
%partial.reduce = tail call <vscale x 2 x i32> @llvm.experimental.vector.partial.reduce.add.nxv2i32.nxv4i32(<vscale x 2 x i32> %acc, <vscale x 4 x i32> %input.wide)
ret <vscale x 2 x i32> %partial.reduce
}

define <vscale x 2 x i32> @unsigned_wide_add_nxv4i16(<vscale x 2 x i32> %acc, <vscale x 4 x i16> %input){
; CHECK-LABEL: unsigned_wide_add_nxv4i16:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: and z1.s, z1.s, #0xffff
; CHECK-NEXT: uunpklo z2.d, z1.s
; CHECK-NEXT: uunpkhi z1.d, z1.s
; CHECK-NEXT: add z0.d, z0.d, z2.d
; CHECK-NEXT: add z0.d, z1.d, z0.d
; CHECK-NEXT: ret
entry:
%input.wide = zext <vscale x 4 x i16> %input to <vscale x 4 x i32>
%partial.reduce = tail call <vscale x 2 x i32> @llvm.experimental.vector.partial.reduce.add.nxv2i32.nxv4i32(<vscale x 2 x i32> %acc, <vscale x 4 x i32> %input.wide)
ret <vscale x 2 x i32> %partial.reduce
}

define <vscale x 4 x i64> @signed_wide_add_nxv8i32(<vscale x 4 x i64> %acc, <vscale x 8 x i32> %input){
; CHECK-LABEL: signed_wide_add_nxv8i32:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: sunpkhi z4.d, z2.s
; CHECK-NEXT: sunpklo z2.d, z2.s
; CHECK-NEXT: sunpkhi z5.d, z3.s
; CHECK-NEXT: sunpklo z3.d, z3.s
; CHECK-NEXT: add z0.d, z0.d, z2.d
; CHECK-NEXT: add z1.d, z1.d, z4.d
; CHECK-NEXT: add z0.d, z3.d, z0.d
; CHECK-NEXT: add z1.d, z5.d, z1.d
; CHECK-NEXT: ret
entry:
%input.wide = sext <vscale x 8 x i32> %input to <vscale x 8 x i64>
%partial.reduce = tail call <vscale x 4 x i64> @llvm.experimental.vector.partial.reduce.add.nxv4i64.nxv8i64(<vscale x 4 x i64> %acc, <vscale x 8 x i64> %input.wide)
ret <vscale x 4 x i64> %partial.reduce
}

define <vscale x 4 x i64> @unsigned_wide_add_nxv8i32(<vscale x 4 x i64> %acc, <vscale x 8 x i32> %input){
; CHECK-LABEL: unsigned_wide_add_nxv8i32:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: uunpkhi z4.d, z2.s
; CHECK-NEXT: uunpklo z2.d, z2.s
; CHECK-NEXT: uunpkhi z5.d, z3.s
; CHECK-NEXT: uunpklo z3.d, z3.s
; CHECK-NEXT: add z0.d, z0.d, z2.d
; CHECK-NEXT: add z1.d, z1.d, z4.d
; CHECK-NEXT: add z0.d, z3.d, z0.d
; CHECK-NEXT: add z1.d, z5.d, z1.d
; CHECK-NEXT: ret
entry:
%input.wide = zext <vscale x 8 x i32> %input to <vscale x 8 x i64>
%partial.reduce = tail call <vscale x 4 x i64> @llvm.experimental.vector.partial.reduce.add.nxv4i64.nxv8i64(<vscale x 4 x i64> %acc, <vscale x 8 x i64> %input.wide)
ret <vscale x 4 x i64> %partial.reduce
}
