Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 8 additions & 141 deletions llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,6 @@ WebAssemblyTargetLowering::WebAssemblyTargetLowering(
// SIMD-specific configuration
if (Subtarget->hasSIMD128()) {

// Combine partial.reduce.add before legalization gets confused.
setTargetDAGCombine(ISD::INTRINSIC_WO_CHAIN);

// Combine wide-vector muls, with extend inputs, to extmul_half.
Expand Down Expand Up @@ -317,6 +316,12 @@ WebAssemblyTargetLowering::WebAssemblyTargetLowering(
setOperationAction(ISD::SIGN_EXTEND_VECTOR_INREG, T, Custom);
setOperationAction(ISD::ZERO_EXTEND_VECTOR_INREG, T, Custom);
}

// Partial MLA reductions.
for (auto Op : {ISD::PARTIAL_REDUCE_SMLA, ISD::PARTIAL_REDUCE_UMLA}) {
setPartialReduceMLAAction(Op, MVT::v4i32, MVT::v16i8, Legal);
setPartialReduceMLAAction(Op, MVT::v4i32, MVT::v8i16, Legal);
}
}

// As a special case, these operators use the type to mean the type to
Expand Down Expand Up @@ -416,41 +421,6 @@ MVT WebAssemblyTargetLowering::getPointerMemTy(const DataLayout &DL,
return TargetLowering::getPointerMemTy(DL, AS);
}

bool WebAssemblyTargetLowering::shouldExpandPartialReductionIntrinsic(
const IntrinsicInst *I) const {
if (I->getIntrinsicID() != Intrinsic::vector_partial_reduce_add)
return true;

EVT VT = EVT::getEVT(I->getType());
if (VT.getSizeInBits() > 128)
return true;

auto Op1 = I->getOperand(1);

if (auto *InputInst = dyn_cast<Instruction>(Op1)) {
unsigned Opcode = InstructionOpcodeToISD(InputInst->getOpcode());
if (Opcode == ISD::MUL) {
if (isa<Instruction>(InputInst->getOperand(0)) &&
isa<Instruction>(InputInst->getOperand(1))) {
// dot only supports signed inputs but also support lowering unsigned.
if (cast<Instruction>(InputInst->getOperand(0))->getOpcode() !=
cast<Instruction>(InputInst->getOperand(1))->getOpcode())
return true;

EVT Op1VT = EVT::getEVT(Op1->getType());
if (Op1VT.getVectorElementType() == VT.getVectorElementType() &&
((VT.getVectorElementCount() * 2 ==
Op1VT.getVectorElementCount()) ||
(VT.getVectorElementCount() * 4 == Op1VT.getVectorElementCount())))
return false;
}
} else if (ISD::isExtOpcode(Opcode)) {
return false;
}
}
return true;
}

TargetLowering::AtomicExpansionKind
WebAssemblyTargetLowering::shouldExpandAtomicRMWInIR(AtomicRMWInst *AI) const {
// We have wasm instructions for these
Expand Down Expand Up @@ -2113,106 +2083,6 @@ SDValue WebAssemblyTargetLowering::LowerVASTART(SDValue Op,
MachinePointerInfo(SV));
}

// Try to lower partial.reduce.add to a dot or fallback to a sequence with
// extmul and adds.
SDValue performLowerPartialReduction(SDNode *N, SelectionDAG &DAG) {
assert(N->getOpcode() == ISD::INTRINSIC_WO_CHAIN);
if (N->getConstantOperandVal(0) != Intrinsic::vector_partial_reduce_add)
return SDValue();

assert(N->getValueType(0) == MVT::v4i32 && "can only support v4i32");
SDLoc DL(N);

SDValue Input = N->getOperand(2);
if (Input->getOpcode() == ISD::MUL) {
SDValue ExtendLHS = Input->getOperand(0);
SDValue ExtendRHS = Input->getOperand(1);
assert((ISD::isExtOpcode(ExtendLHS.getOpcode()) &&
ISD::isExtOpcode(ExtendRHS.getOpcode())) &&
"expected widening mul or add");
assert(ExtendLHS.getOpcode() == ExtendRHS.getOpcode() &&
"expected binop to use the same extend for both operands");

SDValue ExtendInLHS = ExtendLHS->getOperand(0);
SDValue ExtendInRHS = ExtendRHS->getOperand(0);
bool IsSigned = ExtendLHS->getOpcode() == ISD::SIGN_EXTEND;
unsigned LowOpc =
IsSigned ? WebAssemblyISD::EXTEND_LOW_S : WebAssemblyISD::EXTEND_LOW_U;
unsigned HighOpc = IsSigned ? WebAssemblyISD::EXTEND_HIGH_S
: WebAssemblyISD::EXTEND_HIGH_U;
SDValue LowLHS;
SDValue LowRHS;
SDValue HighLHS;
SDValue HighRHS;

auto AssignInputs = [&](MVT VT) {
LowLHS = DAG.getNode(LowOpc, DL, VT, ExtendInLHS);
LowRHS = DAG.getNode(LowOpc, DL, VT, ExtendInRHS);
HighLHS = DAG.getNode(HighOpc, DL, VT, ExtendInLHS);
HighRHS = DAG.getNode(HighOpc, DL, VT, ExtendInRHS);
};

if (ExtendInLHS->getValueType(0) == MVT::v8i16) {
if (IsSigned) {
// i32x4.dot_i16x8_s
SDValue Dot = DAG.getNode(WebAssemblyISD::DOT, DL, MVT::v4i32,
ExtendInLHS, ExtendInRHS);
return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Dot);
}

// (add (add (extmul_low_sx lhs, rhs), (extmul_high_sx lhs, rhs)))
MVT VT = MVT::v4i32;
AssignInputs(VT);
SDValue MulLow = DAG.getNode(ISD::MUL, DL, VT, LowLHS, LowRHS);
SDValue MulHigh = DAG.getNode(ISD::MUL, DL, VT, HighLHS, HighRHS);
SDValue Add = DAG.getNode(ISD::ADD, DL, VT, MulLow, MulHigh);
return DAG.getNode(ISD::ADD, DL, VT, N->getOperand(1), Add);
} else {
assert(ExtendInLHS->getValueType(0) == MVT::v16i8 &&
"expected v16i8 input types");
AssignInputs(MVT::v8i16);
// Lower to a wider tree, using twice the operations compared to above.
if (IsSigned) {
// Use two dots
SDValue DotLHS =
DAG.getNode(WebAssemblyISD::DOT, DL, MVT::v4i32, LowLHS, LowRHS);
SDValue DotRHS =
DAG.getNode(WebAssemblyISD::DOT, DL, MVT::v4i32, HighLHS, HighRHS);
SDValue Add = DAG.getNode(ISD::ADD, DL, MVT::v4i32, DotLHS, DotRHS);
return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Add);
}

SDValue MulLow = DAG.getNode(ISD::MUL, DL, MVT::v8i16, LowLHS, LowRHS);
SDValue MulHigh = DAG.getNode(ISD::MUL, DL, MVT::v8i16, HighLHS, HighRHS);

SDValue AddLow = DAG.getNode(WebAssemblyISD::EXT_ADD_PAIRWISE_U, DL,
MVT::v4i32, MulLow);
SDValue AddHigh = DAG.getNode(WebAssemblyISD::EXT_ADD_PAIRWISE_U, DL,
MVT::v4i32, MulHigh);
SDValue Add = DAG.getNode(ISD::ADD, DL, MVT::v4i32, AddLow, AddHigh);
return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Add);
}
} else {
// Accumulate the input using extadd_pairwise.
assert(ISD::isExtOpcode(Input.getOpcode()) && "expected extend");
bool IsSigned = Input->getOpcode() == ISD::SIGN_EXTEND;
unsigned PairwiseOpc = IsSigned ? WebAssemblyISD::EXT_ADD_PAIRWISE_S
: WebAssemblyISD::EXT_ADD_PAIRWISE_U;
SDValue ExtendIn = Input->getOperand(0);
if (ExtendIn->getValueType(0) == MVT::v8i16) {
SDValue Add = DAG.getNode(PairwiseOpc, DL, MVT::v4i32, ExtendIn);
return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Add);
}

assert(ExtendIn->getValueType(0) == MVT::v16i8 &&
"expected v16i8 input types");
SDValue Add =
DAG.getNode(PairwiseOpc, DL, MVT::v4i32,
DAG.getNode(PairwiseOpc, DL, MVT::v8i16, ExtendIn));
return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Add);
}
}

SDValue WebAssemblyTargetLowering::LowerIntrinsic(SDValue Op,
SelectionDAG &DAG) const {
MachineFunction &MF = DAG.getMachineFunction();
Expand Down Expand Up @@ -3683,11 +3553,8 @@ WebAssemblyTargetLowering::PerformDAGCombine(SDNode *N,
return performVectorTruncZeroCombine(N, DCI);
case ISD::TRUNCATE:
return performTruncateCombine(N, DCI);
case ISD::INTRINSIC_WO_CHAIN: {
if (auto AnyAllCombine = performAnyAllCombine(N, DCI.DAG))
return AnyAllCombine;
return performLowerPartialReduction(N, DCI.DAG);
}
case ISD::INTRINSIC_WO_CHAIN:
return performAnyAllCombine(N, DCI.DAG);
case ISD::MUL:
return performMulCombine(N, DCI);
}
Expand Down
5 changes: 1 addition & 4 deletions llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,6 @@ class WebAssemblyTargetLowering final : public TargetLowering {
/// right decision when generating code for different targets.
const WebAssemblySubtarget *Subtarget;

bool
shouldExpandPartialReductionIntrinsic(const IntrinsicInst *I) const override;
AtomicExpansionKind shouldExpandAtomicRMWInIR(AtomicRMWInst *) const override;
bool shouldScalarizeBinop(SDValue VecOp) const override;
FastISel *createFastISel(FunctionLoweringInfo &FuncInfo,
Expand Down Expand Up @@ -89,8 +87,7 @@ class WebAssemblyTargetLowering final : public TargetLowering {
bool CanLowerReturn(CallingConv::ID CallConv, MachineFunction &MF,
bool isVarArg,
const SmallVectorImpl<ISD::OutputArg> &Outs,
LLVMContext &Context,
const Type *RetTy) const override;
LLVMContext &Context, const Type *RetTy) const override;
SDValue LowerReturn(SDValue Chain, CallingConv::ID CallConv, bool isVarArg,
const SmallVectorImpl<ISD::OutputArg> &Outs,
const SmallVectorImpl<SDValue> &OutVals, const SDLoc &dl,
Expand Down
45 changes: 45 additions & 0 deletions llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td
Original file line number Diff line number Diff line change
Expand Up @@ -1504,6 +1504,51 @@ def : Pat<(v2f64 (extloadv2f32 (i64 I64:$addr))),
defm Q15MULR_SAT_S :
SIMDBinary<I16x8, int_wasm_q15mulr_sat_signed, "q15mulr_sat_s", 0x82>;

//===----------------------------------------------------------------------===//
// Partial reductions, using: dot, extmul and extadd_pairwise
//===----------------------------------------------------------------------===//
// MLA: v8i16 -> v4i32
def : Pat<(v4i32 (partial_reduce_smla (v4i32 V128:$acc), (v8i16 V128:$lhs),
(v8i16 V128:$rhs))),
(ADD_I32x4 (DOT $lhs, $rhs), $acc)>;
def : Pat<(v4i32 (partial_reduce_umla (v4i32 V128:$acc), (v8i16 V128:$lhs),
(v8i16 V128:$rhs))),
(ADD_I32x4 (ADD_I32x4 (EXTMUL_LOW_U_I32x4 $lhs, $rhs),
(EXTMUL_HIGH_U_I32x4 $lhs, $rhs)),
$acc)>;
// MLA: v16i8 -> v4i32
def : Pat<(v4i32 (partial_reduce_smla (v4i32 V128:$acc), (v16i8 V128:$lhs),
(v16i8 V128:$rhs))),
(ADD_I32x4 (ADD_I32x4 (DOT (extend_low_s_I16x8 $lhs),
(extend_low_s_I16x8 $rhs)),
(DOT (extend_high_s_I16x8 $lhs),
(extend_high_s_I16x8 $rhs))),
$acc)>;
def : Pat<(v4i32 (partial_reduce_umla (v4i32 V128:$acc), (v16i8 V128:$lhs),
(v16i8 V128:$rhs))),
(ADD_I32x4 (ADD_I32x4 (extadd_pairwise_u_I32x4 (EXTMUL_LOW_U_I16x8 $lhs, $rhs)),
(extadd_pairwise_u_I32x4 (EXTMUL_HIGH_U_I16x8 $lhs, $rhs))),
$acc)>;

// Accumulate: v8i16 -> v4i32
def : Pat<(v4i32 (partial_reduce_smla (v4i32 V128:$acc), (v8i16 V128:$in),
(I16x8.splat (i32 1)))),
(ADD_I32x4 (extadd_pairwise_s_I32x4 $in), $acc)>;

def : Pat<(v4i32 (partial_reduce_umla (v4i32 V128:$acc), (v8i16 V128:$in),
(I16x8.splat (i32 1)))),
(ADD_I32x4 (extadd_pairwise_u_I32x4 $in), $acc)>;

// Accumulate: v16i8 -> v4i32
def : Pat<(v4i32 (partial_reduce_smla (v4i32 V128:$acc), (v16i8 V128:$in),
(I8x16.splat (i32 1)))),
(ADD_I32x4 (extadd_pairwise_s_I32x4 (extadd_pairwise_s_I16x8 $in)),
$acc)>;
def : Pat<(v4i32 (partial_reduce_umla (v4i32 V128:$acc), (v16i8 V128:$in),
(I8x16.splat (i32 1)))),
(ADD_I32x4 (extadd_pairwise_u_I32x4 (extadd_pairwise_u_I16x8 $in)),
$acc)>;

//===----------------------------------------------------------------------===//
// Relaxed swizzle
//===----------------------------------------------------------------------===//
Expand Down
15 changes: 7 additions & 8 deletions llvm/test/CodeGen/WebAssembly/partial-reduce-accumulate.ll
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@ define hidden i32 @accumulate_add_u8_u8(ptr noundef readonly %a, ptr noundef re
; MAX-BANDWIDTH: v128.load
; MAX-BANDWIDTH: i16x8.extadd_pairwise_i8x16_u
; MAX-BANDWIDTH: i32x4.extadd_pairwise_i16x8_u
; MAX-BANDWIDTH: i32x4.add
; MAX-BANDWIDTH: v128.load
; MAX-BANDWIDTH: i16x8.extadd_pairwise_i8x16_u
; MAX-BANDWIDTH: i32x4.extadd_pairwise_i16x8_u
; MAX-BANDWIDTH: i32x4.add
; MAX-BANDWIDTH: i32x4.add

entry:
%cmp8.not = icmp eq i32 %N, 0
Expand Down Expand Up @@ -65,11 +65,11 @@ define hidden i32 @accumulate_add_s8_s8(ptr noundef readonly %a, ptr noundef re
; MAX-BANDWIDTH: v128.load
; MAX-BANDWIDTH: i16x8.extadd_pairwise_i8x16_s
; MAX-BANDWIDTH: i32x4.extadd_pairwise_i16x8_s
; MAX-BANDWIDTH: i32x4.add
; MAX-BANDWIDTH: v128.load
; MAX-BANDWIDTH: i16x8.extadd_pairwise_i8x16_s
; MAX-BANDWIDTH: i32x4.extadd_pairwise_i16x8_s
; MAX-BANDWIDTH: i32x4.add
; MAX-BANDWIDTH: i32x4.add
entry:
%cmp8.not = icmp eq i32 %N, 0
br i1 %cmp8.not, label %for.cond.cleanup, label %for.body
Expand Down Expand Up @@ -108,12 +108,11 @@ define hidden i32 @accumulate_add_s8_u8(ptr noundef readonly %a, ptr noundef re

; MAX-BANDWIDTH: loop
; MAX-BANDWIDTH: v128.load
; MAX-BANDWIDTH: i16x8.extadd_pairwise_i8x16_s
; MAX-BANDWIDTH: i32x4.extadd_pairwise_i16x8_s
; MAX-BANDWIDTH: i32x4.add
; MAX-BANDWIDTH: v128.load
; MAX-BANDWIDTH: i16x8.extadd_pairwise_i8x16_u
; MAX-BANDWIDTH: i32x4.extadd_pairwise_i16x8_u
; MAX-BANDWIDTH: v128.load
; MAX-BANDWIDTH: i16x8.extadd_pairwise_i8x16_s
; MAX-BANDWIDTH: i32x4.extadd_pairwise_i16x8_s
; MAX-BANDWIDTH: i32x4.add
entry:
%cmp8.not = icmp eq i32 %N, 0
Expand Down Expand Up @@ -363,10 +362,10 @@ define hidden i32 @accumulate_add_u16_u16(ptr noundef readonly %a, ptr noundef
; MAX-BANDWIDTH: loop
; MAX-BANDWIDTH: v128.load
; MAX-BANDWIDTH: i32x4.extadd_pairwise_i16x8_u
; MAX-BANDWIDTH: i32x4.add
; MAX-BANDWIDTH: v128.load
; MAX-BANDWIDTH: i32x4.extadd_pairwise_i16x8_u
; MAX-BANDWIDTH: i32x4.add
; MAX-BANDWIDTH: i32x4.add
entry:
%cmp8.not = icmp eq i32 %N, 0
br i1 %cmp8.not, label %for.cond.cleanup, label %for.body
Expand Down Expand Up @@ -402,10 +401,10 @@ define hidden i32 @accumulate_add_s16_s16(ptr noundef readonly %a, ptr noundef
; MAX-BANDWIDTH: loop
; MAX-BANDWIDTH: v128.load
; MAX-BANDWIDTH: i32x4.extadd_pairwise_i16x8_s
; MAX-BANDWIDTH: i32x4.add
; MAX-BANDWIDTH: v128.load
; MAX-BANDWIDTH: i32x4.extadd_pairwise_i16x8_s
; MAX-BANDWIDTH: i32x4.add
; MAX-BANDWIDTH: i32x4.add
entry:
%cmp8.not = icmp eq i32 %N, 0
br i1 %cmp8.not, label %for.cond.cleanup, label %for.body
Expand Down