From bddd802db7ebdb4a2c4b98c19a50f7740d598d2b Mon Sep 17 00:00:00 2001 From: Sam Parker Date: Mon, 29 Sep 2025 12:54:29 +0100 Subject: [PATCH 1/2] [WebAssembly] Use partial_reduce_mla ISD nodes Move away from combining the intrinsic call and instead lower the ISD nodes, using more tablegen for pattern matching. --- .../WebAssembly/WebAssemblyISelLowering.cpp | 140 ++++++------------ .../WebAssembly/WebAssemblyISelLowering.h | 6 +- .../WebAssembly/WebAssemblyInstrSIMD.td | 9 ++ .../WebAssembly/partial-reduce-accumulate.ll | 2 +- 4 files changed, 61 insertions(+), 96 deletions(-) diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp index 64b9dc31f75b7..e830def066087 100644 --- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp +++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp @@ -186,7 +186,6 @@ WebAssemblyTargetLowering::WebAssemblyTargetLowering( // SIMD-specific configuration if (Subtarget->hasSIMD128()) { - // Combine partial.reduce.add before legalization gets confused. setTargetDAGCombine(ISD::INTRINSIC_WO_CHAIN); // Combine wide-vector muls, with extend inputs, to extmul_half. @@ -317,6 +316,18 @@ WebAssemblyTargetLowering::WebAssemblyTargetLowering( setOperationAction(ISD::SIGN_EXTEND_VECTOR_INREG, T, Custom); setOperationAction(ISD::ZERO_EXTEND_VECTOR_INREG, T, Custom); } + + // Partial MLA reductions. + // We only have native support with i32x4.dot_i16x8_s, the rest are custom + // lowered. + setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_SMLA, MVT::v4i32, MVT::v8i16, + Legal); + setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_UMLA, MVT::v4i32, MVT::v8i16, + Custom); + setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_SMLA, MVT::v4i32, MVT::v16i8, + Custom); + setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_UMLA, MVT::v4i32, MVT::v16i8, + Custom); } // As a special case, these operators use the type to mean the type to @@ -416,41 +427,6 @@ MVT WebAssemblyTargetLowering::getPointerMemTy(const DataLayout &DL, return TargetLowering::getPointerMemTy(DL, AS); } -bool WebAssemblyTargetLowering::shouldExpandPartialReductionIntrinsic( - const IntrinsicInst *I) const { - if (I->getIntrinsicID() != Intrinsic::vector_partial_reduce_add) - return true; - - EVT VT = EVT::getEVT(I->getType()); - if (VT.getSizeInBits() > 128) - return true; - - auto Op1 = I->getOperand(1); - - if (auto *InputInst = dyn_cast(Op1)) { - unsigned Opcode = InstructionOpcodeToISD(InputInst->getOpcode()); - if (Opcode == ISD::MUL) { - if (isa(InputInst->getOperand(0)) && - isa(InputInst->getOperand(1))) { - // dot only supports signed inputs but also support lowering unsigned. 
- if (cast(InputInst->getOperand(0))->getOpcode() != - cast(InputInst->getOperand(1))->getOpcode()) - return true; - - EVT Op1VT = EVT::getEVT(Op1->getType()); - if (Op1VT.getVectorElementType() == VT.getVectorElementType() && - ((VT.getVectorElementCount() * 2 == - Op1VT.getVectorElementCount()) || - (VT.getVectorElementCount() * 4 == Op1VT.getVectorElementCount()))) - return false; - } - } else if (ISD::isExtOpcode(Opcode)) { - return false; - } - } - return true; -} - TargetLowering::AtomicExpansionKind WebAssemblyTargetLowering::shouldExpandAtomicRMWInIR(AtomicRMWInst *AI) const { // We have wasm instructions for these @@ -1706,6 +1682,9 @@ SDValue WebAssemblyTargetLowering::LowerOperation(SDValue Op, return LowerMUL_LOHI(Op, DAG); case ISD::UADDO: return LowerUADDO(Op, DAG); + case ISD::PARTIAL_REDUCE_SMLA: + case ISD::PARTIAL_REDUCE_UMLA: + return LowerPARTIAL_REDUCE_MLA(Op, DAG); } } @@ -2113,29 +2092,36 @@ SDValue WebAssemblyTargetLowering::LowerVASTART(SDValue Op, MachinePointerInfo(SV)); } -// Try to lower partial.reduce.add to a dot or fallback to a sequence with -// extmul and adds. -SDValue performLowerPartialReduction(SDNode *N, SelectionDAG &DAG) { - assert(N->getOpcode() == ISD::INTRINSIC_WO_CHAIN); - if (N->getConstantOperandVal(0) != Intrinsic::vector_partial_reduce_add) - return SDValue(); +// We only have native support with i32x4.dot_i16x8_s, so for the unsigned +// case we can expand to extmul and add. For v16i8 inputs, we can use two dots, +// for signed, for an expanded tree of extmul adds for unsigned. +SDValue +WebAssemblyTargetLowering::LowerPARTIAL_REDUCE_MLA(SDValue Op, + SelectionDAG &DAG) const { + assert(Op->getValueType(0) == MVT::v4i32 && "can only support v4i32"); + SDLoc DL(Op); - assert(N->getValueType(0) == MVT::v4i32 && "can only support v4i32"); - SDLoc DL(N); + SDValue Acc = Op.getOperand(0); + SDValue ExtendInLHS = Op.getOperand(1); + SDValue ExtendInRHS = Op.getOperand(2); + bool IsSigned = Op.getOpcode() == ISD::PARTIAL_REDUCE_SMLA; - SDValue Input = N->getOperand(2); - if (Input->getOpcode() == ISD::MUL) { - SDValue ExtendLHS = Input->getOperand(0); - SDValue ExtendRHS = Input->getOperand(1); - assert((ISD::isExtOpcode(ExtendLHS.getOpcode()) && - ISD::isExtOpcode(ExtendRHS.getOpcode())) && - "expected widening mul or add"); - assert(ExtendLHS.getOpcode() == ExtendRHS.getOpcode() && - "expected binop to use the same extend for both operands"); - - SDValue ExtendInLHS = ExtendLHS->getOperand(0); - SDValue ExtendInRHS = ExtendRHS->getOperand(0); - bool IsSigned = ExtendLHS->getOpcode() == ISD::SIGN_EXTEND; + APInt Imm; + if (ISD::isConstantSplatVector(ExtendInRHS.getNode(), Imm) && Imm == 1) { + // Accumulate the input using extadd_pairwise. + unsigned PairwiseOpc = IsSigned ? WebAssemblyISD::EXT_ADD_PAIRWISE_S + : WebAssemblyISD::EXT_ADD_PAIRWISE_U; + if (ExtendInLHS->getValueType(0) == MVT::v8i16) { + SDValue Add = DAG.getNode(PairwiseOpc, DL, MVT::v4i32, ExtendInLHS); + return DAG.getNode(ISD::ADD, DL, MVT::v4i32, Acc, Add); + } + assert(ExtendInLHS->getValueType(0) == MVT::v16i8 && + "expected v16i8 input types"); + SDValue Add = + DAG.getNode(PairwiseOpc, DL, MVT::v4i32, + DAG.getNode(PairwiseOpc, DL, MVT::v8i16, ExtendInLHS)); + return DAG.getNode(ISD::ADD, DL, MVT::v4i32, Acc, Add); + } else { unsigned LowOpc = IsSigned ? WebAssemblyISD::EXTEND_LOW_S : WebAssemblyISD::EXTEND_LOW_U; unsigned HighOpc = IsSigned ? 
WebAssemblyISD::EXTEND_HIGH_S @@ -2151,22 +2137,15 @@ SDValue performLowerPartialReduction(SDNode *N, SelectionDAG &DAG) { HighLHS = DAG.getNode(HighOpc, DL, VT, ExtendInLHS); HighRHS = DAG.getNode(HighOpc, DL, VT, ExtendInRHS); }; - if (ExtendInLHS->getValueType(0) == MVT::v8i16) { - if (IsSigned) { - // i32x4.dot_i16x8_s - SDValue Dot = DAG.getNode(WebAssemblyISD::DOT, DL, MVT::v4i32, - ExtendInLHS, ExtendInRHS); - return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Dot); - } - + assert(!IsSigned && "expected unsigned"); // (add (add (extmul_low_sx lhs, rhs), (extmul_high_sx lhs, rhs))) MVT VT = MVT::v4i32; AssignInputs(VT); SDValue MulLow = DAG.getNode(ISD::MUL, DL, VT, LowLHS, LowRHS); SDValue MulHigh = DAG.getNode(ISD::MUL, DL, VT, HighLHS, HighRHS); SDValue Add = DAG.getNode(ISD::ADD, DL, VT, MulLow, MulHigh); - return DAG.getNode(ISD::ADD, DL, VT, N->getOperand(1), Add); + return DAG.getNode(ISD::ADD, DL, VT, Acc, Add); } else { assert(ExtendInLHS->getValueType(0) == MVT::v16i8 && "expected v16i8 input types"); @@ -2179,7 +2158,7 @@ SDValue performLowerPartialReduction(SDNode *N, SelectionDAG &DAG) { SDValue DotRHS = DAG.getNode(WebAssemblyISD::DOT, DL, MVT::v4i32, HighLHS, HighRHS); SDValue Add = DAG.getNode(ISD::ADD, DL, MVT::v4i32, DotLHS, DotRHS); - return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Add); + return DAG.getNode(ISD::ADD, DL, MVT::v4i32, Acc, Add); } SDValue MulLow = DAG.getNode(ISD::MUL, DL, MVT::v8i16, LowLHS, LowRHS); @@ -2190,26 +2169,8 @@ SDValue performLowerPartialReduction(SDNode *N, SelectionDAG &DAG) { SDValue AddHigh = DAG.getNode(WebAssemblyISD::EXT_ADD_PAIRWISE_U, DL, MVT::v4i32, MulHigh); SDValue Add = DAG.getNode(ISD::ADD, DL, MVT::v4i32, AddLow, AddHigh); - return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Add); + return DAG.getNode(ISD::ADD, DL, MVT::v4i32, Acc, Add); } - } else { - // Accumulate the input using extadd_pairwise. - assert(ISD::isExtOpcode(Input.getOpcode()) && "expected extend"); - bool IsSigned = Input->getOpcode() == ISD::SIGN_EXTEND; - unsigned PairwiseOpc = IsSigned ? 
WebAssemblyISD::EXT_ADD_PAIRWISE_S - : WebAssemblyISD::EXT_ADD_PAIRWISE_U; - SDValue ExtendIn = Input->getOperand(0); - if (ExtendIn->getValueType(0) == MVT::v8i16) { - SDValue Add = DAG.getNode(PairwiseOpc, DL, MVT::v4i32, ExtendIn); - return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Add); - } - - assert(ExtendIn->getValueType(0) == MVT::v16i8 && - "expected v16i8 input types"); - SDValue Add = - DAG.getNode(PairwiseOpc, DL, MVT::v4i32, - DAG.getNode(PairwiseOpc, DL, MVT::v8i16, ExtendIn)); - return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Add); } } @@ -3683,11 +3644,8 @@ WebAssemblyTargetLowering::PerformDAGCombine(SDNode *N, return performVectorTruncZeroCombine(N, DCI); case ISD::TRUNCATE: return performTruncateCombine(N, DCI); - case ISD::INTRINSIC_WO_CHAIN: { - if (auto AnyAllCombine = performAnyAllCombine(N, DCI.DAG)) - return AnyAllCombine; - return performLowerPartialReduction(N, DCI.DAG); - } + case ISD::INTRINSIC_WO_CHAIN: + return performAnyAllCombine(N, DCI.DAG); case ISD::MUL: return performMulCombine(N, DCI); } diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.h b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.h index 72401a7a259c0..3ff8346e12a6f 100644 --- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.h +++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.h @@ -45,8 +45,6 @@ class WebAssemblyTargetLowering final : public TargetLowering { /// right decision when generating code for different targets. const WebAssemblySubtarget *Subtarget; - bool - shouldExpandPartialReductionIntrinsic(const IntrinsicInst *I) const override; AtomicExpansionKind shouldExpandAtomicRMWInIR(AtomicRMWInst *) const override; bool shouldScalarizeBinop(SDValue VecOp) const override; FastISel *createFastISel(FunctionLoweringInfo &FuncInfo, @@ -89,8 +87,7 @@ class WebAssemblyTargetLowering final : public TargetLowering { bool CanLowerReturn(CallingConv::ID CallConv, MachineFunction &MF, bool isVarArg, const SmallVectorImpl &Outs, - LLVMContext &Context, - const Type *RetTy) const override; + LLVMContext &Context, const Type *RetTy) const override; SDValue LowerReturn(SDValue Chain, CallingConv::ID CallConv, bool isVarArg, const SmallVectorImpl &Outs, const SmallVectorImpl &OutVals, const SDLoc &dl, @@ -134,6 +131,7 @@ class WebAssemblyTargetLowering final : public TargetLowering { SDValue LowerMUL_LOHI(SDValue Op, SelectionDAG &DAG) const; SDValue Replace128Op(SDNode *N, SelectionDAG &DAG) const; SDValue LowerUADDO(SDValue Op, SelectionDAG &DAG) const; + SDValue LowerPARTIAL_REDUCE_MLA(SDValue Op, SelectionDAG &DAG) const; // Custom DAG combine hooks SDValue diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td b/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td index d8948ad2df037..b5724ecd90155 100644 --- a/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td +++ b/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td @@ -1159,6 +1159,9 @@ defm DOT : SIMD_I<(outs V128:$dst), (ins V128:$lhs, V128:$rhs), (outs), (ins), 186>; def : Pat<(wasm_dot V128:$lhs, V128:$rhs), (DOT $lhs, $rhs)>; +def : Pat<(v4i32 (partial_reduce_smla (v4i32 V128:$acc), (v8i16 V128:$lhs), + (v8i16 V128:$rhs))), + (ADD_I32x4 (DOT $lhs, $rhs), $acc)>; // Extending multiplication: extmul_{low,high}_P, extmul_high def extend_t : SDTypeProfile<1, 1, [SDTCisVec<0>, SDTCisVec<1>]>; @@ -1473,6 +1476,12 @@ def : Pat<(v4i32 (int_wasm_extadd_pairwise_signed (v8i16 V128:$in))), (extadd_pairwise_s_I32x4 V128:$in)>; def : Pat<(v8i16 (int_wasm_extadd_pairwise_signed (v16i8 
V128:$in))), (extadd_pairwise_s_I16x8 V128:$in)>; +def : Pat<(v4i32 (partial_reduce_smla (v4i32 V128:$acc), (v8i16 V128:$in), + (I16x8.splat (i32 1)))), + (ADD_I32x4 (extadd_pairwise_s_I32x4 V128:$in), V128:$acc)>; +def : Pat<(v4i32 (partial_reduce_umla (v4i32 V128:$acc), (v8i16 V128:$in), + (I16x8.splat (i32 1)))), + (ADD_I32x4 (extadd_pairwise_u_I32x4 V128:$in), V128:$acc)>; // f64x2 <-> f32x4 conversions def demote_t : SDTypeProfile<1, 1, [SDTCisVec<0>, SDTCisVec<1>]>; diff --git a/llvm/test/CodeGen/WebAssembly/partial-reduce-accumulate.ll b/llvm/test/CodeGen/WebAssembly/partial-reduce-accumulate.ll index 47ea762864cc2..c9e486a3f29b4 100644 --- a/llvm/test/CodeGen/WebAssembly/partial-reduce-accumulate.ll +++ b/llvm/test/CodeGen/WebAssembly/partial-reduce-accumulate.ll @@ -402,10 +402,10 @@ define hidden i32 @accumulate_add_s16_s16(ptr noundef readonly %a, ptr noundef ; MAX-BANDWIDTH: loop ; MAX-BANDWIDTH: v128.load ; MAX-BANDWIDTH: i32x4.extadd_pairwise_i16x8_s -; MAX-BANDWIDTH: i32x4.add ; MAX-BANDWIDTH: v128.load ; MAX-BANDWIDTH: i32x4.extadd_pairwise_i16x8_s ; MAX-BANDWIDTH: i32x4.add +; MAX-BANDWIDTH: i32x4.add entry: %cmp8.not = icmp eq i32 %N, 0 br i1 %cmp8.not, label %for.cond.cleanup, label %for.body From 519449d35536b57a75144e4b4c45c213e8e2d203 Mon Sep 17 00:00:00 2001 From: Sam Parker Date: Mon, 29 Sep 2025 16:13:16 +0100 Subject: [PATCH 2/2] Now all done in Tablegen --- .../WebAssembly/WebAssemblyISelLowering.cpp | 99 +------------------ .../WebAssembly/WebAssemblyISelLowering.h | 1 - .../WebAssembly/WebAssemblyInstrSIMD.td | 54 ++++++++-- .../WebAssembly/partial-reduce-accumulate.ll | 13 ++- 4 files changed, 55 insertions(+), 112 deletions(-) diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp index e830def066087..163bf9ba5b089 100644 --- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp +++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp @@ -318,16 +318,10 @@ WebAssemblyTargetLowering::WebAssemblyTargetLowering( } // Partial MLA reductions. - // We only have native support with i32x4.dot_i16x8_s, the rest are custom - // lowered. - setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_SMLA, MVT::v4i32, MVT::v8i16, - Legal); - setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_UMLA, MVT::v4i32, MVT::v8i16, - Custom); - setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_SMLA, MVT::v4i32, MVT::v16i8, - Custom); - setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_UMLA, MVT::v4i32, MVT::v16i8, - Custom); + for (auto Op : {ISD::PARTIAL_REDUCE_SMLA, ISD::PARTIAL_REDUCE_UMLA}) { + setPartialReduceMLAAction(Op, MVT::v4i32, MVT::v16i8, Legal); + setPartialReduceMLAAction(Op, MVT::v4i32, MVT::v8i16, Legal); + } } // As a special case, these operators use the type to mean the type to @@ -1682,9 +1676,6 @@ SDValue WebAssemblyTargetLowering::LowerOperation(SDValue Op, return LowerMUL_LOHI(Op, DAG); case ISD::UADDO: return LowerUADDO(Op, DAG); - case ISD::PARTIAL_REDUCE_SMLA: - case ISD::PARTIAL_REDUCE_UMLA: - return LowerPARTIAL_REDUCE_MLA(Op, DAG); } } @@ -2092,88 +2083,6 @@ SDValue WebAssemblyTargetLowering::LowerVASTART(SDValue Op, MachinePointerInfo(SV)); } -// We only have native support with i32x4.dot_i16x8_s, so for the unsigned -// case we can expand to extmul and add. For v16i8 inputs, we can use two dots, -// for signed, for an expanded tree of extmul adds for unsigned. 
-SDValue -WebAssemblyTargetLowering::LowerPARTIAL_REDUCE_MLA(SDValue Op, - SelectionDAG &DAG) const { - assert(Op->getValueType(0) == MVT::v4i32 && "can only support v4i32"); - SDLoc DL(Op); - - SDValue Acc = Op.getOperand(0); - SDValue ExtendInLHS = Op.getOperand(1); - SDValue ExtendInRHS = Op.getOperand(2); - bool IsSigned = Op.getOpcode() == ISD::PARTIAL_REDUCE_SMLA; - - APInt Imm; - if (ISD::isConstantSplatVector(ExtendInRHS.getNode(), Imm) && Imm == 1) { - // Accumulate the input using extadd_pairwise. - unsigned PairwiseOpc = IsSigned ? WebAssemblyISD::EXT_ADD_PAIRWISE_S - : WebAssemblyISD::EXT_ADD_PAIRWISE_U; - if (ExtendInLHS->getValueType(0) == MVT::v8i16) { - SDValue Add = DAG.getNode(PairwiseOpc, DL, MVT::v4i32, ExtendInLHS); - return DAG.getNode(ISD::ADD, DL, MVT::v4i32, Acc, Add); - } - assert(ExtendInLHS->getValueType(0) == MVT::v16i8 && - "expected v16i8 input types"); - SDValue Add = - DAG.getNode(PairwiseOpc, DL, MVT::v4i32, - DAG.getNode(PairwiseOpc, DL, MVT::v8i16, ExtendInLHS)); - return DAG.getNode(ISD::ADD, DL, MVT::v4i32, Acc, Add); - } else { - unsigned LowOpc = - IsSigned ? WebAssemblyISD::EXTEND_LOW_S : WebAssemblyISD::EXTEND_LOW_U; - unsigned HighOpc = IsSigned ? WebAssemblyISD::EXTEND_HIGH_S - : WebAssemblyISD::EXTEND_HIGH_U; - SDValue LowLHS; - SDValue LowRHS; - SDValue HighLHS; - SDValue HighRHS; - - auto AssignInputs = [&](MVT VT) { - LowLHS = DAG.getNode(LowOpc, DL, VT, ExtendInLHS); - LowRHS = DAG.getNode(LowOpc, DL, VT, ExtendInRHS); - HighLHS = DAG.getNode(HighOpc, DL, VT, ExtendInLHS); - HighRHS = DAG.getNode(HighOpc, DL, VT, ExtendInRHS); - }; - if (ExtendInLHS->getValueType(0) == MVT::v8i16) { - assert(!IsSigned && "expected unsigned"); - // (add (add (extmul_low_sx lhs, rhs), (extmul_high_sx lhs, rhs))) - MVT VT = MVT::v4i32; - AssignInputs(VT); - SDValue MulLow = DAG.getNode(ISD::MUL, DL, VT, LowLHS, LowRHS); - SDValue MulHigh = DAG.getNode(ISD::MUL, DL, VT, HighLHS, HighRHS); - SDValue Add = DAG.getNode(ISD::ADD, DL, VT, MulLow, MulHigh); - return DAG.getNode(ISD::ADD, DL, VT, Acc, Add); - } else { - assert(ExtendInLHS->getValueType(0) == MVT::v16i8 && - "expected v16i8 input types"); - AssignInputs(MVT::v8i16); - // Lower to a wider tree, using twice the operations compared to above. 
- if (IsSigned) { - // Use two dots - SDValue DotLHS = - DAG.getNode(WebAssemblyISD::DOT, DL, MVT::v4i32, LowLHS, LowRHS); - SDValue DotRHS = - DAG.getNode(WebAssemblyISD::DOT, DL, MVT::v4i32, HighLHS, HighRHS); - SDValue Add = DAG.getNode(ISD::ADD, DL, MVT::v4i32, DotLHS, DotRHS); - return DAG.getNode(ISD::ADD, DL, MVT::v4i32, Acc, Add); - } - - SDValue MulLow = DAG.getNode(ISD::MUL, DL, MVT::v8i16, LowLHS, LowRHS); - SDValue MulHigh = DAG.getNode(ISD::MUL, DL, MVT::v8i16, HighLHS, HighRHS); - - SDValue AddLow = DAG.getNode(WebAssemblyISD::EXT_ADD_PAIRWISE_U, DL, - MVT::v4i32, MulLow); - SDValue AddHigh = DAG.getNode(WebAssemblyISD::EXT_ADD_PAIRWISE_U, DL, - MVT::v4i32, MulHigh); - SDValue Add = DAG.getNode(ISD::ADD, DL, MVT::v4i32, AddLow, AddHigh); - return DAG.getNode(ISD::ADD, DL, MVT::v4i32, Acc, Add); - } - } -} - SDValue WebAssemblyTargetLowering::LowerIntrinsic(SDValue Op, SelectionDAG &DAG) const { MachineFunction &MF = DAG.getMachineFunction(); diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.h b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.h index 3ff8346e12a6f..b33a8530310be 100644 --- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.h +++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.h @@ -131,7 +131,6 @@ class WebAssemblyTargetLowering final : public TargetLowering { SDValue LowerMUL_LOHI(SDValue Op, SelectionDAG &DAG) const; SDValue Replace128Op(SDNode *N, SelectionDAG &DAG) const; SDValue LowerUADDO(SDValue Op, SelectionDAG &DAG) const; - SDValue LowerPARTIAL_REDUCE_MLA(SDValue Op, SelectionDAG &DAG) const; // Custom DAG combine hooks SDValue diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td b/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td index b5724ecd90155..130602650d34e 100644 --- a/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td +++ b/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td @@ -1159,9 +1159,6 @@ defm DOT : SIMD_I<(outs V128:$dst), (ins V128:$lhs, V128:$rhs), (outs), (ins), 186>; def : Pat<(wasm_dot V128:$lhs, V128:$rhs), (DOT $lhs, $rhs)>; -def : Pat<(v4i32 (partial_reduce_smla (v4i32 V128:$acc), (v8i16 V128:$lhs), - (v8i16 V128:$rhs))), - (ADD_I32x4 (DOT $lhs, $rhs), $acc)>; // Extending multiplication: extmul_{low,high}_P, extmul_high def extend_t : SDTypeProfile<1, 1, [SDTCisVec<0>, SDTCisVec<1>]>; @@ -1476,12 +1473,6 @@ def : Pat<(v4i32 (int_wasm_extadd_pairwise_signed (v8i16 V128:$in))), (extadd_pairwise_s_I32x4 V128:$in)>; def : Pat<(v8i16 (int_wasm_extadd_pairwise_signed (v16i8 V128:$in))), (extadd_pairwise_s_I16x8 V128:$in)>; -def : Pat<(v4i32 (partial_reduce_smla (v4i32 V128:$acc), (v8i16 V128:$in), - (I16x8.splat (i32 1)))), - (ADD_I32x4 (extadd_pairwise_s_I32x4 V128:$in), V128:$acc)>; -def : Pat<(v4i32 (partial_reduce_umla (v4i32 V128:$acc), (v8i16 V128:$in), - (I16x8.splat (i32 1)))), - (ADD_I32x4 (extadd_pairwise_u_I32x4 V128:$in), V128:$acc)>; // f64x2 <-> f32x4 conversions def demote_t : SDTypeProfile<1, 1, [SDTCisVec<0>, SDTCisVec<1>]>; @@ -1513,6 +1504,51 @@ def : Pat<(v2f64 (extloadv2f32 (i64 I64:$addr))), defm Q15MULR_SAT_S : SIMDBinary; +//===----------------------------------------------------------------------===// +// Partial reductions, using: dot, extmul and extadd_pairwise +//===----------------------------------------------------------------------===// +// MLA: v8i16 -> v4i32 +def : Pat<(v4i32 (partial_reduce_smla (v4i32 V128:$acc), (v8i16 V128:$lhs), + (v8i16 V128:$rhs))), + (ADD_I32x4 (DOT $lhs, $rhs), $acc)>; +def : Pat<(v4i32 (partial_reduce_umla (v4i32 V128:$acc), 
(v8i16 V128:$lhs), + (v8i16 V128:$rhs))), + (ADD_I32x4 (ADD_I32x4 (EXTMUL_LOW_U_I32x4 $lhs, $rhs), + (EXTMUL_HIGH_U_I32x4 $lhs, $rhs)), + $acc)>; +// MLA: v16i8 -> v4i32 +def : Pat<(v4i32 (partial_reduce_smla (v4i32 V128:$acc), (v16i8 V128:$lhs), + (v16i8 V128:$rhs))), + (ADD_I32x4 (ADD_I32x4 (DOT (extend_low_s_I16x8 $lhs), + (extend_low_s_I16x8 $rhs)), + (DOT (extend_high_s_I16x8 $lhs), + (extend_high_s_I16x8 $rhs))), + $acc)>; +def : Pat<(v4i32 (partial_reduce_umla (v4i32 V128:$acc), (v16i8 V128:$lhs), + (v16i8 V128:$rhs))), + (ADD_I32x4 (ADD_I32x4 (extadd_pairwise_u_I32x4 (EXTMUL_LOW_U_I16x8 $lhs, $rhs)), + (extadd_pairwise_u_I32x4 (EXTMUL_HIGH_U_I16x8 $lhs, $rhs))), + $acc)>; + +// Accumulate: v8i16 -> v4i32 +def : Pat<(v4i32 (partial_reduce_smla (v4i32 V128:$acc), (v8i16 V128:$in), + (I16x8.splat (i32 1)))), + (ADD_I32x4 (extadd_pairwise_s_I32x4 $in), $acc)>; + +def : Pat<(v4i32 (partial_reduce_umla (v4i32 V128:$acc), (v8i16 V128:$in), + (I16x8.splat (i32 1)))), + (ADD_I32x4 (extadd_pairwise_u_I32x4 $in), $acc)>; + +// Accumulate: v16i8 -> v4i32 +def : Pat<(v4i32 (partial_reduce_smla (v4i32 V128:$acc), (v16i8 V128:$in), + (I8x16.splat (i32 1)))), + (ADD_I32x4 (extadd_pairwise_s_I32x4 (extadd_pairwise_s_I16x8 $in)), + $acc)>; +def : Pat<(v4i32 (partial_reduce_umla (v4i32 V128:$acc), (v16i8 V128:$in), + (I8x16.splat (i32 1)))), + (ADD_I32x4 (extadd_pairwise_u_I32x4 (extadd_pairwise_u_I16x8 $in)), + $acc)>; + //===----------------------------------------------------------------------===// // Relaxed swizzle //===----------------------------------------------------------------------===// diff --git a/llvm/test/CodeGen/WebAssembly/partial-reduce-accumulate.ll b/llvm/test/CodeGen/WebAssembly/partial-reduce-accumulate.ll index c9e486a3f29b4..a599f4653f323 100644 --- a/llvm/test/CodeGen/WebAssembly/partial-reduce-accumulate.ll +++ b/llvm/test/CodeGen/WebAssembly/partial-reduce-accumulate.ll @@ -19,11 +19,11 @@ define hidden i32 @accumulate_add_u8_u8(ptr noundef readonly %a, ptr noundef re ; MAX-BANDWIDTH: v128.load ; MAX-BANDWIDTH: i16x8.extadd_pairwise_i8x16_u ; MAX-BANDWIDTH: i32x4.extadd_pairwise_i16x8_u -; MAX-BANDWIDTH: i32x4.add ; MAX-BANDWIDTH: v128.load ; MAX-BANDWIDTH: i16x8.extadd_pairwise_i8x16_u ; MAX-BANDWIDTH: i32x4.extadd_pairwise_i16x8_u ; MAX-BANDWIDTH: i32x4.add +; MAX-BANDWIDTH: i32x4.add entry: %cmp8.not = icmp eq i32 %N, 0 @@ -65,11 +65,11 @@ define hidden i32 @accumulate_add_s8_s8(ptr noundef readonly %a, ptr noundef re ; MAX-BANDWIDTH: v128.load ; MAX-BANDWIDTH: i16x8.extadd_pairwise_i8x16_s ; MAX-BANDWIDTH: i32x4.extadd_pairwise_i16x8_s -; MAX-BANDWIDTH: i32x4.add ; MAX-BANDWIDTH: v128.load ; MAX-BANDWIDTH: i16x8.extadd_pairwise_i8x16_s ; MAX-BANDWIDTH: i32x4.extadd_pairwise_i16x8_s ; MAX-BANDWIDTH: i32x4.add +; MAX-BANDWIDTH: i32x4.add entry: %cmp8.not = icmp eq i32 %N, 0 br i1 %cmp8.not, label %for.cond.cleanup, label %for.body @@ -108,12 +108,11 @@ define hidden i32 @accumulate_add_s8_u8(ptr noundef readonly %a, ptr noundef re ; MAX-BANDWIDTH: loop ; MAX-BANDWIDTH: v128.load -; MAX-BANDWIDTH: i16x8.extadd_pairwise_i8x16_s -; MAX-BANDWIDTH: i32x4.extadd_pairwise_i16x8_s -; MAX-BANDWIDTH: i32x4.add -; MAX-BANDWIDTH: v128.load ; MAX-BANDWIDTH: i16x8.extadd_pairwise_i8x16_u ; MAX-BANDWIDTH: i32x4.extadd_pairwise_i16x8_u +; MAX-BANDWIDTH: v128.load +; MAX-BANDWIDTH: i16x8.extadd_pairwise_i8x16_s +; MAX-BANDWIDTH: i32x4.extadd_pairwise_i16x8_s ; MAX-BANDWIDTH: i32x4.add entry: %cmp8.not = icmp eq i32 %N, 0 @@ -363,10 +362,10 @@ define hidden i32 
@accumulate_add_u16_u16(ptr noundef readonly %a, ptr noundef ; MAX-BANDWIDTH: loop ; MAX-BANDWIDTH: v128.load ; MAX-BANDWIDTH: i32x4.extadd_pairwise_i16x8_u -; MAX-BANDWIDTH: i32x4.add ; MAX-BANDWIDTH: v128.load ; MAX-BANDWIDTH: i32x4.extadd_pairwise_i16x8_u ; MAX-BANDWIDTH: i32x4.add +; MAX-BANDWIDTH: i32x4.add entry: %cmp8.not = icmp eq i32 %N, 0 br i1 %cmp8.not, label %for.cond.cleanup, label %for.body
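
For illustration (hand-written IR, not taken from the test files above): the CHECK
lines come from vectorized loops, but the mapping to the new patterns can be seen
directly on a small partial reduction. A minimal sketch, assuming the
llvm.vector.partial.reduce.add intrinsic and hypothetical function names; the first
function should select to i32x4.dot_i16x8_s plus i32x4.add, the second to
i32x4.extadd_pairwise_i16x8_u plus i32x4.add, e.g. with
llc -mtriple=wasm32-unknown-unknown -mattr=+simd128:

declare <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v8i32(<4 x i32>, <8 x i32>)

define <4 x i32> @dot_accumulate_sketch(<8 x i16> %a, <8 x i16> %b, <4 x i32> %acc) {
  ; A sign-extended multiply feeding the partial reduction becomes
  ; PARTIAL_REDUCE_SMLA(acc, a, b) in the DAG, matched by the DOT pattern.
  %a.ext = sext <8 x i16> %a to <8 x i32>
  %b.ext = sext <8 x i16> %b to <8 x i32>
  %mul = mul nsw <8 x i32> %a.ext, %b.ext
  %red = call <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v8i32(<4 x i32> %acc, <8 x i32> %mul)
  ret <4 x i32> %red
}

define <4 x i32> @extadd_accumulate_sketch(<8 x i16> %a, <4 x i32> %acc) {
  ; With no multiply, the DAG uses PARTIAL_REDUCE_UMLA(acc, a, splat 1), which
  ; the splat-of-1 patterns select to extadd_pairwise followed by an add.
  %a.ext = zext <8 x i16> %a to <8 x i32>
  %red = call <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v8i32(<4 x i32> %acc, <8 x i32> %a.ext)
  ret <4 x i32> %red
}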