Skip to content

Conversation

sparker-arm
Copy link
Contributor

Addressing issue #160847.

Move away from combining the intrinsic call and instead lower the ISD nodes, using more tablegen for pattern matching.

Move away from combining the intrinsic call and instead lower the ISD
nodes, using more tablegen for pattern matching.
@llvmbot
Copy link
Member

llvmbot commented Sep 29, 2025

@llvm/pr-subscribers-backend-webassembly

Author: Sam Parker (sparker-arm)

Changes

Addressing issue #160847.

Move away from combining the intrinsic call and instead lower the ISD nodes, using more tablegen for pattern matching.


Full diff: https://github.com/llvm/llvm-project/pull/161184.diff

4 Files Affected:

  • (modified) llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp (+49-91)
  • (modified) llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.h (+2-4)
  • (modified) llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td (+9)
  • (modified) llvm/test/CodeGen/WebAssembly/partial-reduce-accumulate.ll (+1-1)
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
index 64b9dc31f75b7..e830def066087 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
@@ -186,7 +186,6 @@ WebAssemblyTargetLowering::WebAssemblyTargetLowering(
   // SIMD-specific configuration
   if (Subtarget->hasSIMD128()) {
 
-    // Combine partial.reduce.add before legalization gets confused.
     setTargetDAGCombine(ISD::INTRINSIC_WO_CHAIN);
 
     // Combine wide-vector muls, with extend inputs, to extmul_half.
@@ -317,6 +316,18 @@ WebAssemblyTargetLowering::WebAssemblyTargetLowering(
       setOperationAction(ISD::SIGN_EXTEND_VECTOR_INREG, T, Custom);
       setOperationAction(ISD::ZERO_EXTEND_VECTOR_INREG, T, Custom);
     }
+
+    // Partial MLA reductions.
+    // We only have native support with i32x4.dot_i16x8_s, the rest are custom
+    // lowered.
+    setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_SMLA, MVT::v4i32, MVT::v8i16,
+                              Legal);
+    setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_UMLA, MVT::v4i32, MVT::v8i16,
+                              Custom);
+    setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_SMLA, MVT::v4i32, MVT::v16i8,
+                              Custom);
+    setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_UMLA, MVT::v4i32, MVT::v16i8,
+                              Custom);
   }
 
   // As a special case, these operators use the type to mean the type to
@@ -416,41 +427,6 @@ MVT WebAssemblyTargetLowering::getPointerMemTy(const DataLayout &DL,
   return TargetLowering::getPointerMemTy(DL, AS);
 }
 
-bool WebAssemblyTargetLowering::shouldExpandPartialReductionIntrinsic(
-    const IntrinsicInst *I) const {
-  if (I->getIntrinsicID() != Intrinsic::vector_partial_reduce_add)
-    return true;
-
-  EVT VT = EVT::getEVT(I->getType());
-  if (VT.getSizeInBits() > 128)
-    return true;
-
-  auto Op1 = I->getOperand(1);
-
-  if (auto *InputInst = dyn_cast<Instruction>(Op1)) {
-    unsigned Opcode = InstructionOpcodeToISD(InputInst->getOpcode());
-    if (Opcode == ISD::MUL) {
-      if (isa<Instruction>(InputInst->getOperand(0)) &&
-          isa<Instruction>(InputInst->getOperand(1))) {
-        // dot only supports signed inputs but also support lowering unsigned.
-        if (cast<Instruction>(InputInst->getOperand(0))->getOpcode() !=
-            cast<Instruction>(InputInst->getOperand(1))->getOpcode())
-          return true;
-
-        EVT Op1VT = EVT::getEVT(Op1->getType());
-        if (Op1VT.getVectorElementType() == VT.getVectorElementType() &&
-            ((VT.getVectorElementCount() * 2 ==
-              Op1VT.getVectorElementCount()) ||
-             (VT.getVectorElementCount() * 4 == Op1VT.getVectorElementCount())))
-          return false;
-      }
-    } else if (ISD::isExtOpcode(Opcode)) {
-      return false;
-    }
-  }
-  return true;
-}
-
 TargetLowering::AtomicExpansionKind
 WebAssemblyTargetLowering::shouldExpandAtomicRMWInIR(AtomicRMWInst *AI) const {
   // We have wasm instructions for these
@@ -1706,6 +1682,9 @@ SDValue WebAssemblyTargetLowering::LowerOperation(SDValue Op,
     return LowerMUL_LOHI(Op, DAG);
   case ISD::UADDO:
     return LowerUADDO(Op, DAG);
+  case ISD::PARTIAL_REDUCE_SMLA:
+  case ISD::PARTIAL_REDUCE_UMLA:
+    return LowerPARTIAL_REDUCE_MLA(Op, DAG);
   }
 }
 
@@ -2113,29 +2092,36 @@ SDValue WebAssemblyTargetLowering::LowerVASTART(SDValue Op,
                       MachinePointerInfo(SV));
 }
 
-// Try to lower partial.reduce.add to a dot or fallback to a sequence with
-// extmul and adds.
-SDValue performLowerPartialReduction(SDNode *N, SelectionDAG &DAG) {
-  assert(N->getOpcode() == ISD::INTRINSIC_WO_CHAIN);
-  if (N->getConstantOperandVal(0) != Intrinsic::vector_partial_reduce_add)
-    return SDValue();
+// We only have native support with i32x4.dot_i16x8_s, so for the unsigned
+// case we can expand to extmul and add. For v16i8 inputs, we can use two dots,
+// for signed, for an expanded tree of extmul adds for unsigned.
+SDValue
+WebAssemblyTargetLowering::LowerPARTIAL_REDUCE_MLA(SDValue Op,
+                                                   SelectionDAG &DAG) const {
+  assert(Op->getValueType(0) == MVT::v4i32 && "can only support v4i32");
+  SDLoc DL(Op);
 
-  assert(N->getValueType(0) == MVT::v4i32 && "can only support v4i32");
-  SDLoc DL(N);
+  SDValue Acc = Op.getOperand(0);
+  SDValue ExtendInLHS = Op.getOperand(1);
+  SDValue ExtendInRHS = Op.getOperand(2);
+  bool IsSigned = Op.getOpcode() == ISD::PARTIAL_REDUCE_SMLA;
 
-  SDValue Input = N->getOperand(2);
-  if (Input->getOpcode() == ISD::MUL) {
-    SDValue ExtendLHS = Input->getOperand(0);
-    SDValue ExtendRHS = Input->getOperand(1);
-    assert((ISD::isExtOpcode(ExtendLHS.getOpcode()) &&
-            ISD::isExtOpcode(ExtendRHS.getOpcode())) &&
-           "expected widening mul or add");
-    assert(ExtendLHS.getOpcode() == ExtendRHS.getOpcode() &&
-           "expected binop to use the same extend for both operands");
-
-    SDValue ExtendInLHS = ExtendLHS->getOperand(0);
-    SDValue ExtendInRHS = ExtendRHS->getOperand(0);
-    bool IsSigned = ExtendLHS->getOpcode() == ISD::SIGN_EXTEND;
+  APInt Imm;
+  if (ISD::isConstantSplatVector(ExtendInRHS.getNode(), Imm) && Imm == 1) {
+    // Accumulate the input using extadd_pairwise.
+    unsigned PairwiseOpc = IsSigned ? WebAssemblyISD::EXT_ADD_PAIRWISE_S
+                                    : WebAssemblyISD::EXT_ADD_PAIRWISE_U;
+    if (ExtendInLHS->getValueType(0) == MVT::v8i16) {
+      SDValue Add = DAG.getNode(PairwiseOpc, DL, MVT::v4i32, ExtendInLHS);
+      return DAG.getNode(ISD::ADD, DL, MVT::v4i32, Acc, Add);
+    }
+    assert(ExtendInLHS->getValueType(0) == MVT::v16i8 &&
+           "expected v16i8 input types");
+    SDValue Add =
+        DAG.getNode(PairwiseOpc, DL, MVT::v4i32,
+                    DAG.getNode(PairwiseOpc, DL, MVT::v8i16, ExtendInLHS));
+    return DAG.getNode(ISD::ADD, DL, MVT::v4i32, Acc, Add);
+  } else {
     unsigned LowOpc =
         IsSigned ? WebAssemblyISD::EXTEND_LOW_S : WebAssemblyISD::EXTEND_LOW_U;
     unsigned HighOpc = IsSigned ? WebAssemblyISD::EXTEND_HIGH_S
@@ -2151,22 +2137,15 @@ SDValue performLowerPartialReduction(SDNode *N, SelectionDAG &DAG) {
       HighLHS = DAG.getNode(HighOpc, DL, VT, ExtendInLHS);
       HighRHS = DAG.getNode(HighOpc, DL, VT, ExtendInRHS);
     };
-
     if (ExtendInLHS->getValueType(0) == MVT::v8i16) {
-      if (IsSigned) {
-        // i32x4.dot_i16x8_s
-        SDValue Dot = DAG.getNode(WebAssemblyISD::DOT, DL, MVT::v4i32,
-                                  ExtendInLHS, ExtendInRHS);
-        return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Dot);
-      }
-
+      assert(!IsSigned && "expected unsigned");
       // (add (add (extmul_low_sx lhs, rhs), (extmul_high_sx lhs, rhs)))
       MVT VT = MVT::v4i32;
       AssignInputs(VT);
       SDValue MulLow = DAG.getNode(ISD::MUL, DL, VT, LowLHS, LowRHS);
       SDValue MulHigh = DAG.getNode(ISD::MUL, DL, VT, HighLHS, HighRHS);
       SDValue Add = DAG.getNode(ISD::ADD, DL, VT, MulLow, MulHigh);
-      return DAG.getNode(ISD::ADD, DL, VT, N->getOperand(1), Add);
+      return DAG.getNode(ISD::ADD, DL, VT, Acc, Add);
     } else {
       assert(ExtendInLHS->getValueType(0) == MVT::v16i8 &&
              "expected v16i8 input types");
@@ -2179,7 +2158,7 @@ SDValue performLowerPartialReduction(SDNode *N, SelectionDAG &DAG) {
         SDValue DotRHS =
             DAG.getNode(WebAssemblyISD::DOT, DL, MVT::v4i32, HighLHS, HighRHS);
         SDValue Add = DAG.getNode(ISD::ADD, DL, MVT::v4i32, DotLHS, DotRHS);
-        return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Add);
+        return DAG.getNode(ISD::ADD, DL, MVT::v4i32, Acc, Add);
       }
 
       SDValue MulLow = DAG.getNode(ISD::MUL, DL, MVT::v8i16, LowLHS, LowRHS);
@@ -2190,26 +2169,8 @@ SDValue performLowerPartialReduction(SDNode *N, SelectionDAG &DAG) {
       SDValue AddHigh = DAG.getNode(WebAssemblyISD::EXT_ADD_PAIRWISE_U, DL,
                                     MVT::v4i32, MulHigh);
       SDValue Add = DAG.getNode(ISD::ADD, DL, MVT::v4i32, AddLow, AddHigh);
-      return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Add);
+      return DAG.getNode(ISD::ADD, DL, MVT::v4i32, Acc, Add);
     }
-  } else {
-    // Accumulate the input using extadd_pairwise.
-    assert(ISD::isExtOpcode(Input.getOpcode()) && "expected extend");
-    bool IsSigned = Input->getOpcode() == ISD::SIGN_EXTEND;
-    unsigned PairwiseOpc = IsSigned ? WebAssemblyISD::EXT_ADD_PAIRWISE_S
-                                    : WebAssemblyISD::EXT_ADD_PAIRWISE_U;
-    SDValue ExtendIn = Input->getOperand(0);
-    if (ExtendIn->getValueType(0) == MVT::v8i16) {
-      SDValue Add = DAG.getNode(PairwiseOpc, DL, MVT::v4i32, ExtendIn);
-      return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Add);
-    }
-
-    assert(ExtendIn->getValueType(0) == MVT::v16i8 &&
-           "expected v16i8 input types");
-    SDValue Add =
-        DAG.getNode(PairwiseOpc, DL, MVT::v4i32,
-                    DAG.getNode(PairwiseOpc, DL, MVT::v8i16, ExtendIn));
-    return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Add);
   }
 }
 
@@ -3683,11 +3644,8 @@ WebAssemblyTargetLowering::PerformDAGCombine(SDNode *N,
     return performVectorTruncZeroCombine(N, DCI);
   case ISD::TRUNCATE:
     return performTruncateCombine(N, DCI);
-  case ISD::INTRINSIC_WO_CHAIN: {
-    if (auto AnyAllCombine = performAnyAllCombine(N, DCI.DAG))
-      return AnyAllCombine;
-    return performLowerPartialReduction(N, DCI.DAG);
-  }
+  case ISD::INTRINSIC_WO_CHAIN:
+    return performAnyAllCombine(N, DCI.DAG);
   case ISD::MUL:
     return performMulCombine(N, DCI);
   }
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.h b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.h
index 72401a7a259c0..3ff8346e12a6f 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.h
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.h
@@ -45,8 +45,6 @@ class WebAssemblyTargetLowering final : public TargetLowering {
   /// right decision when generating code for different targets.
   const WebAssemblySubtarget *Subtarget;
 
-  bool
-  shouldExpandPartialReductionIntrinsic(const IntrinsicInst *I) const override;
   AtomicExpansionKind shouldExpandAtomicRMWInIR(AtomicRMWInst *) const override;
   bool shouldScalarizeBinop(SDValue VecOp) const override;
   FastISel *createFastISel(FunctionLoweringInfo &FuncInfo,
@@ -89,8 +87,7 @@ class WebAssemblyTargetLowering final : public TargetLowering {
   bool CanLowerReturn(CallingConv::ID CallConv, MachineFunction &MF,
                       bool isVarArg,
                       const SmallVectorImpl<ISD::OutputArg> &Outs,
-                      LLVMContext &Context,
-                      const Type *RetTy) const override;
+                      LLVMContext &Context, const Type *RetTy) const override;
   SDValue LowerReturn(SDValue Chain, CallingConv::ID CallConv, bool isVarArg,
                       const SmallVectorImpl<ISD::OutputArg> &Outs,
                       const SmallVectorImpl<SDValue> &OutVals, const SDLoc &dl,
@@ -134,6 +131,7 @@ class WebAssemblyTargetLowering final : public TargetLowering {
   SDValue LowerMUL_LOHI(SDValue Op, SelectionDAG &DAG) const;
   SDValue Replace128Op(SDNode *N, SelectionDAG &DAG) const;
   SDValue LowerUADDO(SDValue Op, SelectionDAG &DAG) const;
+  SDValue LowerPARTIAL_REDUCE_MLA(SDValue Op, SelectionDAG &DAG) const;
 
   // Custom DAG combine hooks
   SDValue
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td b/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td
index d8948ad2df037..b5724ecd90155 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td
@@ -1159,6 +1159,9 @@ defm DOT : SIMD_I<(outs V128:$dst), (ins V128:$lhs, V128:$rhs), (outs), (ins),
                   186>;
 def : Pat<(wasm_dot V128:$lhs, V128:$rhs),
           (DOT $lhs, $rhs)>;
+def : Pat<(v4i32 (partial_reduce_smla (v4i32 V128:$acc), (v8i16 V128:$lhs),
+                                                         (v8i16 V128:$rhs))),
+          (ADD_I32x4 (DOT $lhs, $rhs), $acc)>;
 
 // Extending multiplication: extmul_{low,high}_P, extmul_high
 def extend_t : SDTypeProfile<1, 1, [SDTCisVec<0>, SDTCisVec<1>]>;
@@ -1473,6 +1476,12 @@ def : Pat<(v4i32 (int_wasm_extadd_pairwise_signed (v8i16 V128:$in))),
           (extadd_pairwise_s_I32x4 V128:$in)>;
 def : Pat<(v8i16 (int_wasm_extadd_pairwise_signed (v16i8 V128:$in))),
           (extadd_pairwise_s_I16x8 V128:$in)>;
+def : Pat<(v4i32 (partial_reduce_smla (v4i32 V128:$acc), (v8i16 V128:$in),
+                                                         (I16x8.splat (i32 1)))),
+          (ADD_I32x4 (extadd_pairwise_s_I32x4 V128:$in), V128:$acc)>;
+def : Pat<(v4i32 (partial_reduce_umla (v4i32 V128:$acc), (v8i16 V128:$in),
+                                                         (I16x8.splat (i32 1)))),
+          (ADD_I32x4 (extadd_pairwise_u_I32x4 V128:$in), V128:$acc)>;
 
 // f64x2 <-> f32x4 conversions
 def demote_t : SDTypeProfile<1, 1, [SDTCisVec<0>, SDTCisVec<1>]>;
diff --git a/llvm/test/CodeGen/WebAssembly/partial-reduce-accumulate.ll b/llvm/test/CodeGen/WebAssembly/partial-reduce-accumulate.ll
index 47ea762864cc2..c9e486a3f29b4 100644
--- a/llvm/test/CodeGen/WebAssembly/partial-reduce-accumulate.ll
+++ b/llvm/test/CodeGen/WebAssembly/partial-reduce-accumulate.ll
@@ -402,10 +402,10 @@ define hidden i32 @accumulate_add_s16_s16(ptr noundef readonly  %a, ptr noundef
 ; MAX-BANDWIDTH: loop
 ; MAX-BANDWIDTH: v128.load
 ; MAX-BANDWIDTH: i32x4.extadd_pairwise_i16x8_s
-; MAX-BANDWIDTH: i32x4.add
 ; MAX-BANDWIDTH: v128.load
 ; MAX-BANDWIDTH: i32x4.extadd_pairwise_i16x8_s
 ; MAX-BANDWIDTH: i32x4.add
+; MAX-BANDWIDTH: i32x4.add
 entry:
   %cmp8.not = icmp eq i32 %N, 0
   br i1 %cmp8.not, label %for.cond.cleanup, label %for.body

@sparker-arm sparker-arm requested a review from dschuff September 29, 2025 12:03
@sparker-arm sparker-arm marked this pull request as draft September 29, 2025 14:26
@sparker-arm
Copy link
Contributor Author

I'm going to try more tablegen patterns.

@sparker-arm sparker-arm marked this pull request as ready for review September 29, 2025 15:15
Copy link
Member

@dschuff dschuff left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

@sparker-arm sparker-arm merged commit 156e9b4 into llvm:main Sep 30, 2025
11 checks passed
mahesh-attarde pushed a commit to mahesh-attarde/llvm-project that referenced this pull request Oct 3, 2025
Addressing issue llvm#160847.
 
Move away from combining the intrinsic call and instead lower the ISD
nodes, using tablegen for pattern matching.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants