Skip to content

Commit

Permalink
Use pmaddubsw 8-bit horizontal widening adds (Fixes #6859) (#6873)
Browse files Browse the repository at this point in the history
* use pmaddubsw 8-bit horizontal widening adds

* add SSE3 versions too

* add pmaddubsw tests
  • Loading branch information
rootjalex authored Jul 21, 2022
1 parent 967c3bf commit 9a94756
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 25 deletions.
91 changes: 66 additions & 25 deletions src/CodeGen_X86.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,14 @@ const x86Intrinsic intrinsic_defs[] = {
{"llvm.x86.avx2.pmadd.ub.sw", Int(16, 16), "saturating_dot_product", {UInt(8, 32), Int(8, 32)}, Target::AVX2},
{"llvm.x86.ssse3.pmadd.ub.sw.128", Int(16, 8), "saturating_dot_product", {UInt(8, 16), Int(8, 16)}, Target::SSE41},

// Horizontal widening adds using 2-way dot products.
{"hadd_pmadd_u8_sse3", UInt(16, 8), "horizontal_widening_add", {UInt(8, 16)}, Target::SSE41},
{"hadd_pmadd_u8_sse3", Int(16, 8), "horizontal_widening_add", {UInt(8, 16)}, Target::SSE41},
{"hadd_pmadd_i8_sse3", Int(16, 8), "horizontal_widening_add", {Int(8, 16)}, Target::SSE41},
{"hadd_pmadd_u8_avx2", UInt(16, 16), "horizontal_widening_add", {UInt(8, 32)}, Target::AVX2},
{"hadd_pmadd_u8_avx2", Int(16, 16), "horizontal_widening_add", {UInt(8, 32)}, Target::AVX2},
{"hadd_pmadd_i8_avx2", Int(16, 16), "horizontal_widening_add", {Int(8, 32)}, Target::AVX2},

{"llvm.x86.avx512.pmaddw.d.512", Int(32, 16), "dot_product", {Int(16, 32), Int(16, 32)}, Target::AVX512_Skylake},
{"llvm.x86.avx512.pmaddw.d.512", Int(32, 16), "dot_product", {Int(16, 32), Int(16, 32)}, Target::AVX512_Cannonlake},
{"llvm.x86.avx2.pmadd.wd", Int(32, 8), "dot_product", {Int(16, 16), Int(16, 16)}, Target::AVX2},
Expand Down Expand Up @@ -595,6 +603,7 @@ void CodeGen_X86::codegen_vector_reduce(const VectorReduce *op, const Expr &init
enum {
CombineInit = 1 << 0,
SwapOperands = 1 << 1,
SingleArg = 1 << 2,
};
};
// clang-format off
Expand Down Expand Up @@ -624,8 +633,12 @@ void CodeGen_X86::codegen_vector_reduce(const VectorReduce *op, const Expr &init
{VectorReduce::Add, 2, wild_f32x_ * wild_f32x_, "dot_product", BFloat(16), Pattern::CombineInit},

// One could do a horizontal widening addition with
// dot_product against a vector of ones. Currently disabled
// because I haven't found a case where it's clearly better.
// other dot_products against a vector of ones. Currently disabled
// because I haven't found other cases where it's clearly better.

{VectorReduce::Add, 2, u16(wild_u8x_), "horizontal_widening_add", {}, Pattern::SingleArg},
{VectorReduce::Add, 2, i16(wild_u8x_), "horizontal_widening_add", {}, Pattern::SingleArg},
{VectorReduce::Add, 2, i16(wild_i8x_), "horizontal_widening_add", {}, Pattern::SingleArg},
};
// clang-format on

Expand All @@ -635,33 +648,61 @@ void CodeGen_X86::codegen_vector_reduce(const VectorReduce *op, const Expr &init
continue;
}
if (expr_match(p.pattern, op->value, matches)) {
Expr a = matches[0];
Expr b = matches[1];
if (p.flags & Pattern::SwapOperands) {
std::swap(a, b);
}
if (p.narrow_type.bits() > 0) {
a = lossless_cast(p.narrow_type.with_lanes(a.type().lanes()), a);
b = lossless_cast(p.narrow_type.with_lanes(b.type().lanes()), b);
}
if (!a.defined() || !b.defined()) {
continue;
}
if (p.flags & Pattern::SingleArg) {
Expr a = matches[0];

if (init.defined() && (p.flags & Pattern::CombineInit)) {
value = call_overloaded_intrin(op->type, p.intrin, {init, a, b});
if (value) {
return;
if (p.narrow_type.bits() > 0) {
a = lossless_cast(p.narrow_type.with_lanes(a.type().lanes()), a);
}
if (!a.defined()) {
continue;
}

if (init.defined() && (p.flags & Pattern::CombineInit)) {
value = call_overloaded_intrin(op->type, p.intrin, {init, a});
if (value) {
return;
}
} else {
value = call_overloaded_intrin(op->type, p.intrin, {a});
if (value) {
if (init.defined()) {
Value *x = value;
Value *y = codegen(init);
value = builder->CreateAdd(x, y);
}
return;
}
}
} else {
value = call_overloaded_intrin(op->type, p.intrin, {a, b});
if (value) {
if (init.defined()) {
Value *x = value;
Value *y = codegen(init);
value = builder->CreateAdd(x, y);
Expr a = matches[0];
Expr b = matches[1];
if (p.flags & Pattern::SwapOperands) {
std::swap(a, b);
}
if (p.narrow_type.bits() > 0) {
a = lossless_cast(p.narrow_type.with_lanes(a.type().lanes()), a);
b = lossless_cast(p.narrow_type.with_lanes(b.type().lanes()), b);
}
if (!a.defined() || !b.defined()) {
continue;
}

if (init.defined() && (p.flags & Pattern::CombineInit)) {
value = call_overloaded_intrin(op->type, p.intrin, {init, a, b});
if (value) {
return;
}
} else {
value = call_overloaded_intrin(op->type, p.intrin, {a, b});
if (value) {
if (init.defined()) {
Value *x = value;
Value *y = codegen(init);
value = builder->CreateAdd(x, y);
}
return;
}
return;
}
}
}
Expand Down
11 changes: 11 additions & 0 deletions src/runtime/x86_avx2.ll
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,14 @@ define weak_odr <16 x i16> @saturating_pmulhrswx16(<16 x i16> %a, <16 x i16> %b)
ret <16 x i16> %5
}
declare <16 x i16> @llvm.x86.avx2.pmul.hr.sw(<16 x i16>, <16 x i16>) nounwind readnone

; Horizontal widening add for 8-bit unsigned lanes: <32 x i8> in, <16 x i16> out.
;
; Implemented as a 2-way dot product of %a against a constant vector of ones
; via llvm.x86.avx2.pmadd.ub.sw. CodeGen_X86.cpp registers that intrinsic as a
; "saturating_dot_product" taking (UInt(8), Int(8)) operands, i.e. the first
; operand is the unsigned one. %a is therefore passed first so its bytes are
; read as unsigned, and the ones vector occupies the signed slot.
;
; With unsigned 8-bit inputs the largest pair sum is 255 + 255 = 510, which
; fits in i16, so the intrinsic's saturation should never alter the result.
define weak_odr <16 x i16> @hadd_pmadd_u8_avx2(<32 x i8> %a) nounwind alwaysinline {
; Multiplying every lane by 1 reduces the dot product to a sum of each adjacent pair.
%1 = tail call <16 x i16> @llvm.x86.avx2.pmadd.ub.sw(<32 x i8> %a, <32 x i8> <i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1>)
ret <16 x i16> %1
}

; Horizontal widening add for 8-bit signed lanes: <32 x i8> in, <16 x i16> out.
;
; Implemented as a 2-way dot product of a constant vector of ones against %a
; via llvm.x86.avx2.pmadd.ub.sw. CodeGen_X86.cpp registers that intrinsic as a
; "saturating_dot_product" taking (UInt(8), Int(8)) operands, i.e. the second
; operand is the signed one. The operand order is deliberately swapped relative
; to a plain "a * ones": the ones vector goes in the unsigned slot and %a goes
; second so its bytes are read as signed.
;
; With signed 8-bit inputs each pair sum lies in [-256, 254], which fits in
; i16, so the intrinsic's saturation should never alter the result.
define weak_odr <16 x i16> @hadd_pmadd_i8_avx2(<32 x i8> %a) nounwind alwaysinline {
; Multiplying every lane by 1 reduces the dot product to a sum of each adjacent pair.
%1 = tail call <16 x i16> @llvm.x86.avx2.pmadd.ub.sw(<32 x i8> <i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1>, <32 x i8> %a)
ret <16 x i16> %1
}
declare <16 x i16> @llvm.x86.avx2.pmadd.ub.sw(<32 x i8>, <32 x i8>) nounwind readnone
11 changes: 11 additions & 0 deletions src/runtime/x86_sse41.ll
Original file line number Diff line number Diff line change
Expand Up @@ -81,3 +81,14 @@ define weak_odr <8 x i16> @saturating_pmulhrswx8(<8 x i16> %a, <8 x i16> %b) nou
ret <8 x i16> %5
}
declare <8 x i16> @llvm.x86.ssse3.pmul.hr.sw.128(<8 x i16>, <8 x i16>) nounwind readnone

; Horizontal widening add for 8-bit unsigned lanes: <16 x i8> in, <8 x i16> out.
;
; Implemented as a 2-way dot product of %a against a constant vector of ones
; via llvm.x86.ssse3.pmadd.ub.sw.128. CodeGen_X86.cpp registers that intrinsic
; as a "saturating_dot_product" taking (UInt(8), Int(8)) operands, i.e. the
; first operand is the unsigned one. %a is therefore passed first so its bytes
; are read as unsigned, and the ones vector occupies the signed slot.
;
; With unsigned 8-bit inputs the largest pair sum is 255 + 255 = 510, which
; fits in i16, so the intrinsic's saturation should never alter the result.
;
; NOTE(review): the function name says "sse3" but the intrinsic it wraps is an
; SSSE3 one; the name is kept because CodeGen_X86.cpp looks it up by this string.
define weak_odr <8 x i16> @hadd_pmadd_u8_sse3(<16 x i8> %a) nounwind alwaysinline {
; Multiplying every lane by 1 reduces the dot product to a sum of each adjacent pair.
%1 = tail call <8 x i16> @llvm.x86.ssse3.pmadd.ub.sw.128(<16 x i8> %a, <16 x i8> <i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1>)
ret <8 x i16> %1
}

; Horizontal widening add for 8-bit signed lanes: <16 x i8> in, <8 x i16> out.
;
; Implemented as a 2-way dot product of a constant vector of ones against %a
; via llvm.x86.ssse3.pmadd.ub.sw.128. CodeGen_X86.cpp registers that intrinsic
; as a "saturating_dot_product" taking (UInt(8), Int(8)) operands, i.e. the
; second operand is the signed one. The operand order is deliberately swapped
; relative to a plain "a * ones": the ones vector goes in the unsigned slot and
; %a goes second so its bytes are read as signed.
;
; With signed 8-bit inputs each pair sum lies in [-256, 254], which fits in
; i16, so the intrinsic's saturation should never alter the result.
;
; NOTE(review): the function name says "sse3" but the intrinsic it wraps is an
; SSSE3 one; the name is kept because CodeGen_X86.cpp looks it up by this string.
define weak_odr <8 x i16> @hadd_pmadd_i8_sse3(<16 x i8> %a) nounwind alwaysinline {
; Multiplying every lane by 1 reduces the dot product to a sum of each adjacent pair.
%1 = tail call <8 x i16> @llvm.x86.ssse3.pmadd.ub.sw.128(<16 x i8> <i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1>, <16 x i8> %a)
ret <8 x i16> %1
}
declare <8 x i16> @llvm.x86.ssse3.pmadd.ub.sw.128(<16 x i8>, <16 x i8>) nounwind readnone
5 changes: 5 additions & 0 deletions test/correctness/simd_op_check.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,11 @@ class SimdOpCheck : public SimdOpCheckTest {
RDom r2(0, 2);
check(check_pmaddubsw, 4 * w, saturating_sum(i16(in_u8(2 * x + r2)) * in_i8(2 * x + r2 + 32)));
check(check_pmaddubsw, 4 * w, saturating_sum(i16(in_i8(2 * x + r2)) * in_u8(2 * x + r2 + 32)));

// uint8 -> uint16 or int16 and int8 -> int16 horizontal widening adds should use pmaddubsw.
check(check_pmaddubsw, 4 * w, sum(u16(in_u8(2 * x + r2))));
check(check_pmaddubsw, 4 * w, sum(i16(in_u8(2 * x + r2))));
check(check_pmaddubsw, 4 * w, sum(i16(in_i8(2 * x + r2))));
}
}

Expand Down

0 comments on commit 9a94756

Please sign in to comment.