diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 069aab274d3126..e7923ff02de704 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -2039,8 +2039,13 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
     return true;
 
   EVT VT = EVT::getEVT(I->getType());
-  return VT != MVT::nxv4i64 && VT != MVT::nxv4i32 && VT != MVT::nxv2i64 &&
-         VT != MVT::v4i64 && VT != MVT::v4i32 && VT != MVT::v2i32;
+  auto Op1 = I->getOperand(1);
+  EVT Op1VT = EVT::getEVT(Op1->getType());
+  if (Op1VT.getVectorElementType() == VT.getVectorElementType() &&
+      (VT.getVectorElementCount() * 4 == Op1VT.getVectorElementCount() ||
+       VT.getVectorElementCount() * 2 == Op1VT.getVectorElementCount()))
+    return false;
+  return true;
 }
 
 bool AArch64TargetLowering::shouldExpandCttzElements(EVT VT) const {
@@ -21784,6 +21789,55 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
   return DAG.getNode(Opcode, DL, ReducedType, NarrowOp, A, B);
 }
 
+SDValue tryLowerPartialReductionToWideAdd(SDNode *N,
+                                          const AArch64Subtarget *Subtarget,
+                                          SelectionDAG &DAG) {
+
+  assert(N->getOpcode() == ISD::INTRINSIC_WO_CHAIN &&
+         getIntrinsicID(N) ==
+             Intrinsic::experimental_vector_partial_reduce_add &&
+         "Expected a partial reduction node");
+
+  if (!Subtarget->isSVEorStreamingSVEAvailable())
+    return SDValue();
+
+  SDLoc DL(N);
+
+  auto Acc = N->getOperand(1);
+  auto ExtInput = N->getOperand(2);
+
+  EVT AccVT = Acc.getValueType();
+  EVT AccElemVT = AccVT.getVectorElementType();
+
+  if (ExtInput.getValueType().getVectorElementType() != AccElemVT)
+    return SDValue();
+
+  unsigned ExtInputOpcode = ExtInput->getOpcode();
+  if (!ISD::isExtOpcode(ExtInputOpcode))
+    return SDValue();
+
+  auto Input = ExtInput->getOperand(0);
+  EVT InputVT = Input.getValueType();
+
+  if (!(InputVT == MVT::nxv4i32 && AccVT == MVT::nxv2i64) &&
+      !(InputVT == MVT::nxv8i16 && AccVT == MVT::nxv4i32) &&
+      !(InputVT == MVT::nxv16i8 && AccVT == MVT::nxv8i16))
+    return SDValue();
+
+  bool InputIsSigned = ExtInputOpcode == ISD::SIGN_EXTEND;
+  auto BottomIntrinsic = InputIsSigned ? Intrinsic::aarch64_sve_saddwb
+                                       : Intrinsic::aarch64_sve_uaddwb;
+  auto TopIntrinsic = InputIsSigned ? Intrinsic::aarch64_sve_saddwt
+                                    : Intrinsic::aarch64_sve_uaddwt;
+
+  auto BottomID = DAG.getTargetConstant(BottomIntrinsic, DL, AccElemVT);
+  auto BottomNode =
+      DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, AccVT, BottomID, Acc, Input);
+  auto TopID = DAG.getTargetConstant(TopIntrinsic, DL, AccElemVT);
+  return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, AccVT, TopID, BottomNode,
+                     Input);
+}
+
 static SDValue performIntrinsicCombine(SDNode *N,
                                        TargetLowering::DAGCombinerInfo &DCI,
                                        const AArch64Subtarget *Subtarget) {
@@ -21795,6 +21849,8 @@ static SDValue performIntrinsicCombine(SDNode *N,
   case Intrinsic::experimental_vector_partial_reduce_add: {
     if (auto Dot = tryLowerPartialReductionToDot(N, Subtarget, DAG))
       return Dot;
+    if (auto WideAdd = tryLowerPartialReductionToWideAdd(N, Subtarget, DAG))
+      return WideAdd;
     return DAG.getPartialReduceAdd(SDLoc(N), N->getValueType(0),
                                    N->getOperand(1), N->getOperand(2));
   }
diff --git a/llvm/test/CodeGen/AArch64/sve-partial-reduce-wide-add.ll b/llvm/test/CodeGen/AArch64/sve-partial-reduce-wide-add.ll
new file mode 100644
index 00000000000000..1d05649964670d
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve-partial-reduce-wide-add.ll
@@ -0,0 +1,141 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -mtriple=aarch64 -mattr=+sve2 %s -o - | FileCheck %s
+
+define <vscale x 2 x i64> @signed_wide_add_nxv4i32(<vscale x 2 x i64> %acc, <vscale x 4 x i32> %input){
+; CHECK-LABEL: signed_wide_add_nxv4i32:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    saddwb z0.d, z0.d, z1.s
+; CHECK-NEXT:    saddwt z0.d, z0.d, z1.s
+; CHECK-NEXT:    ret
+entry:
+  %input.wide = sext <vscale x 4 x i32> %input to <vscale x 4 x i64>
+  %partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv4i64(<vscale x 2 x i64> %acc, <vscale x 4 x i64> %input.wide)
+  ret <vscale x 2 x i64> %partial.reduce
+}
+
+define <vscale x 2 x i64> @unsigned_wide_add_nxv4i32(<vscale x 2 x i64> %acc, <vscale x 4 x i32> %input){
+; CHECK-LABEL: unsigned_wide_add_nxv4i32:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    uaddwb z0.d, z0.d, z1.s
+; CHECK-NEXT:    uaddwt z0.d, z0.d, z1.s
+; CHECK-NEXT:    ret
+entry:
+  %input.wide = zext <vscale x 4 x i32> %input to <vscale x 4 x i64>
+  %partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv4i64(<vscale x 2 x i64> %acc, <vscale x 4 x i64> %input.wide)
+  ret <vscale x 2 x i64> %partial.reduce
+}
+
+define <vscale x 4 x i32> @signed_wide_add_nxv8i16(<vscale x 4 x i32> %acc, <vscale x 8 x i16> %input){
+; CHECK-LABEL: signed_wide_add_nxv8i16:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    saddwb z0.s, z0.s, z1.h
+; CHECK-NEXT:    saddwt z0.s, z0.s, z1.h
+; CHECK-NEXT:    ret
+entry:
+  %input.wide = sext <vscale x 8 x i16> %input to <vscale x 8 x i32>
+  %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv8i32(<vscale x 4 x i32> %acc, <vscale x 8 x i32> %input.wide)
+  ret <vscale x 4 x i32> %partial.reduce
+}
+
+define <vscale x 4 x i32> @unsigned_wide_add_nxv8i16(<vscale x 4 x i32> %acc, <vscale x 8 x i16> %input){
+; CHECK-LABEL: unsigned_wide_add_nxv8i16:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    uaddwb z0.s, z0.s, z1.h
+; CHECK-NEXT:    uaddwt z0.s, z0.s, z1.h
+; CHECK-NEXT:    ret
+entry:
+  %input.wide = zext <vscale x 8 x i16> %input to <vscale x 8 x i32>
+  %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv8i32(<vscale x 4 x i32> %acc, <vscale x 8 x i32> %input.wide)
+  ret <vscale x 4 x i32> %partial.reduce
+}
+
+define <vscale x 8 x i16> @signed_wide_add_nxv16i8(<vscale x 8 x i16> %acc, <vscale x 16 x i8> %input){
+; CHECK-LABEL: signed_wide_add_nxv16i8:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    saddwb z0.h, z0.h, z1.b
+; CHECK-NEXT:    saddwt z0.h, z0.h, z1.b
+; CHECK-NEXT:    ret
+entry:
+  %input.wide = sext <vscale x 16 x i8> %input to <vscale x 16 x i16>
+  %partial.reduce = tail call <vscale x 8 x i16> @llvm.experimental.vector.partial.reduce.add.nxv8i16.nxv16i16(<vscale x 8 x i16> %acc, <vscale x 16 x i16> %input.wide)
+  ret <vscale x 8 x i16> %partial.reduce
+}
+
+define <vscale x 8 x i16> @unsigned_wide_add_nxv16i8(<vscale x 8 x i16> %acc, <vscale x 16 x i8> %input){
+; CHECK-LABEL: unsigned_wide_add_nxv16i8:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    uaddwb z0.h, z0.h, z1.b
+; CHECK-NEXT:    uaddwt z0.h, z0.h, z1.b
+; CHECK-NEXT:    ret
+entry:
+  %input.wide = zext <vscale x 16 x i8> %input to <vscale x 16 x i16>
+  %partial.reduce = tail call <vscale x 8 x i16> @llvm.experimental.vector.partial.reduce.add.nxv8i16.nxv16i16(<vscale x 8 x i16> %acc, <vscale x 16 x i16> %input.wide)
+  ret <vscale x 8 x i16> %partial.reduce
+}
+
+define <vscale x 2 x i32> @signed_wide_add_nxv4i16(<vscale x 2 x i32> %acc, <vscale x 4 x i16> %input){
+; CHECK-LABEL: signed_wide_add_nxv4i16:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    sxth z1.s, p0/m, z1.s
+; CHECK-NEXT:    uunpklo z2.d, z1.s
+; CHECK-NEXT:    uunpkhi z1.d, z1.s
+; CHECK-NEXT:    add z0.d, z0.d, z2.d
+; CHECK-NEXT:    add z0.d, z1.d, z0.d
+; CHECK-NEXT:    ret
+entry:
+  %input.wide = sext <vscale x 4 x i16> %input to <vscale x 4 x i32>
+  %partial.reduce = tail call <vscale x 2 x i32> @llvm.experimental.vector.partial.reduce.add.nxv2i32.nxv4i32(<vscale x 2 x i32> %acc, <vscale x 4 x i32> %input.wide)
+  ret <vscale x 2 x i32> %partial.reduce
+}
+
+define <vscale x 2 x i32> @unsigned_wide_add_nxv4i16(<vscale x 2 x i32> %acc, <vscale x 4 x i16> %input){
+; CHECK-LABEL: unsigned_wide_add_nxv4i16:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    and z1.s, z1.s, #0xffff
+; CHECK-NEXT:    uunpklo z2.d, z1.s
+; CHECK-NEXT:    uunpkhi z1.d, z1.s
+; CHECK-NEXT:    add z0.d, z0.d, z2.d
+; CHECK-NEXT:    add z0.d, z1.d, z0.d
+; CHECK-NEXT:    ret
+entry:
+  %input.wide = zext <vscale x 4 x i16> %input to <vscale x 4 x i32>
+  %partial.reduce = tail call <vscale x 2 x i32> @llvm.experimental.vector.partial.reduce.add.nxv2i32.nxv4i32(<vscale x 2 x i32> %acc, <vscale x 4 x i32> %input.wide)
+  ret <vscale x 2 x i32> %partial.reduce
+}
+
+define <vscale x 4 x i64> @signed_wide_add_nxv8i32(<vscale x 4 x i64> %acc, <vscale x 8 x i32> %input){
+; CHECK-LABEL: signed_wide_add_nxv8i32:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    sunpkhi z4.d, z2.s
+; CHECK-NEXT:    sunpklo z2.d, z2.s
+; CHECK-NEXT:    sunpkhi z5.d, z3.s
+; CHECK-NEXT:    sunpklo z3.d, z3.s
+; CHECK-NEXT:    add z0.d, z0.d, z2.d
+; CHECK-NEXT:    add z1.d, z1.d, z4.d
+; CHECK-NEXT:    add z0.d, z3.d, z0.d
+; CHECK-NEXT:    add z1.d, z5.d, z1.d
+; CHECK-NEXT:    ret
+entry:
+  %input.wide = sext <vscale x 8 x i32> %input to <vscale x 8 x i64>
+  %partial.reduce = tail call <vscale x 4 x i64> @llvm.experimental.vector.partial.reduce.add.nxv4i64.nxv8i64(<vscale x 4 x i64> %acc, <vscale x 8 x i64> %input.wide)
+  ret <vscale x 4 x i64> %partial.reduce
+}
+
+define <vscale x 4 x i64> @unsigned_wide_add_nxv8i32(<vscale x 4 x i64> %acc, <vscale x 8 x i32> %input){
+; CHECK-LABEL: unsigned_wide_add_nxv8i32:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    uunpkhi z4.d, z2.s
+; CHECK-NEXT:    uunpklo z2.d, z2.s
+; CHECK-NEXT:    uunpkhi z5.d, z3.s
+; CHECK-NEXT:    uunpklo z3.d, z3.s
+; CHECK-NEXT:    add z0.d, z0.d, z2.d
+; CHECK-NEXT:    add z1.d, z1.d, z4.d
+; CHECK-NEXT:    add z0.d, z3.d, z0.d
+; CHECK-NEXT:    add z1.d, z5.d, z1.d
+; CHECK-NEXT:    ret
+entry:
+  %input.wide = zext <vscale x 8 x i32> %input to <vscale x 8 x i64>
+  %partial.reduce = tail call <vscale x 4 x i64> @llvm.experimental.vector.partial.reduce.add.nxv4i64.nxv8i64(<vscale x 4 x i64> %acc, <vscale x 8 x i64> %input.wide)
+  ret <vscale x 4 x i64> %partial.reduce
+}