Skip to content

Commit

Permalink
Use pmaddubsw for non-RDom horizontal widening adds (halide#7440)
Browse files Browse the repository at this point in the history
  • Loading branch information
abadams authored and ardier committed Mar 3, 2024
1 parent f7719f4 commit 08a4226
Show file tree
Hide file tree
Showing 6 changed files with 75 additions and 18 deletions.
20 changes: 13 additions & 7 deletions src/CodeGen_X86.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ const x86Intrinsic intrinsic_defs[] = {
{"packuswbx16", UInt(8, 16), "saturating_narrow", {Int(16, 16)}},

// Widening multiplies that use (v)pmaddwd
{"wmul_pmaddwd_avx512", Int(32, 16), "widening_mul", {Int(16, 16), Int(16, 16)}, Target::AVX512_Skylake},
{"wmul_pmaddwd_avx2", Int(32, 8), "widening_mul", {Int(16, 8), Int(16, 8)}, Target::AVX2},
{"wmul_pmaddwd_sse2", Int(32, 4), "widening_mul", {Int(16, 4), Int(16, 4)}},

Expand All @@ -221,15 +222,21 @@ const x86Intrinsic intrinsic_defs[] = {
{"llvm.x86.ssse3.pmadd.ub.sw.128", Int(16, 8), "saturating_dot_product", {UInt(8, 16), Int(8, 16)}, Target::SSE41},

// Horizontal widening adds using 2-way dot products.
{"hadd_pmadd_u8_sse3", UInt(16, 8), "horizontal_widening_add", {UInt(8, 16)}, Target::SSE41},
{"hadd_pmadd_u8_sse3", Int(16, 8), "horizontal_widening_add", {UInt(8, 16)}, Target::SSE41},
{"hadd_pmadd_i8_sse3", Int(16, 8), "horizontal_widening_add", {Int(8, 16)}, Target::SSE41},
{"hadd_pmadd_u8_avx512", UInt(16, 32), "horizontal_widening_add", {UInt(8, 64)}, Target::AVX512_Skylake},
{"hadd_pmadd_u8_avx512", Int(16, 32), "horizontal_widening_add", {UInt(8, 64)}, Target::AVX512_Skylake},
{"hadd_pmadd_i8_avx512", Int(16, 32), "horizontal_widening_add", {Int(8, 64)}, Target::AVX512_Skylake},
{"hadd_pmadd_u8_avx2", UInt(16, 16), "horizontal_widening_add", {UInt(8, 32)}, Target::AVX2},
{"hadd_pmadd_u8_avx2", Int(16, 16), "horizontal_widening_add", {UInt(8, 32)}, Target::AVX2},
{"hadd_pmadd_i8_avx2", Int(16, 16), "horizontal_widening_add", {Int(8, 32)}, Target::AVX2},
{"hadd_pmadd_u8_sse3", UInt(16, 8), "horizontal_widening_add", {UInt(8, 16)}, Target::SSE41},
{"hadd_pmadd_u8_sse3", Int(16, 8), "horizontal_widening_add", {UInt(8, 16)}, Target::SSE41},
{"hadd_pmadd_i8_sse3", Int(16, 8), "horizontal_widening_add", {Int(8, 16)}, Target::SSE41},

{"hadd_pmadd_i16_avx512", Int(32, 16), "horizontal_widening_add", {Int(16, 32)}, Target::AVX512_Skylake},
{"hadd_pmadd_i16_avx2", Int(32, 8), "horizontal_widening_add", {Int(16, 16)}, Target::AVX2},
{"hadd_pmadd_i16_sse2", Int(32, 4), "horizontal_widening_add", {Int(16, 8)}},

{"llvm.x86.avx512.pmaddw.d.512", Int(32, 16), "dot_product", {Int(16, 32), Int(16, 32)}, Target::AVX512_Skylake},
{"llvm.x86.avx512.pmaddw.d.512", Int(32, 16), "dot_product", {Int(16, 32), Int(16, 32)}, Target::AVX512_Cannonlake},
{"llvm.x86.avx2.pmadd.wd", Int(32, 8), "dot_product", {Int(16, 16), Int(16, 16)}, Target::AVX2},
{"llvm.x86.sse2.pmadd.wd", Int(32, 4), "dot_product", {Int(16, 8), Int(16, 8)}},

Expand Down Expand Up @@ -718,12 +725,11 @@ void CodeGen_X86::codegen_vector_reduce(const VectorReduce *op, const Expr &init

{VectorReduce::Add, 2, wild_f32x_ * wild_f32x_, "dot_product", BFloat(16), Pattern::CombineInit},

// One could do a horizontal widening addition with
// other dot_products against a vector of ones. Currently disabled
// because I haven't found other cases where it's clearly better.
// Horizontal widening addition using a dot_product against a vector of ones.
{VectorReduce::Add, 2, u16(wild_u8x_), "horizontal_widening_add", {}, Pattern::SingleArg},
{VectorReduce::Add, 2, i16(wild_u8x_), "horizontal_widening_add", {}, Pattern::SingleArg},
{VectorReduce::Add, 2, i16(wild_i8x_), "horizontal_widening_add", {}, Pattern::SingleArg},
{VectorReduce::Add, 2, i32(wild_i16x_), "horizontal_widening_add", {}, Pattern::SingleArg},

// Sum of absolute differences
{VectorReduce::Add, 8, u64(absd(wild_u8x_, wild_u8x_)), "sum_of_absolute_differences", {}},
Expand Down
8 changes: 8 additions & 0 deletions src/FindIntrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -727,6 +727,7 @@ class FindIntrinsics : public IRMutator {
}

const int bits = op->type.bits();
const int lanes = op->type.lanes();
const auto is_x_same_int = op->type.is_int() && is_int(x, bits);
const auto is_x_same_uint = op->type.is_uint() && is_uint(x, bits);
const auto is_x_same_int_or_uint = is_x_same_int || is_x_same_uint;
Expand Down Expand Up @@ -781,6 +782,13 @@ class FindIntrinsics : public IRMutator {
rewrite(saturating_cast(op->type, rounding_shift_right(widening_mul(x, y), z)),
rounding_mul_shift_right(x, y, cast(unsigned_type, z)),
is_x_same_int_or_uint && x_y_same_sign && is_uint(z)) ||

// Rewrite combinations of deinterleaves into horizontal ops
rewrite(widening_add(slice(x, 0, 2, lanes), slice(x, 1, 2, lanes)),
h_add(cast(op->type.with_lanes(lanes * 2), x), lanes)) ||
rewrite(widening_add(slice(x, 1, 2, lanes), slice(x, 0, 2, lanes)),
h_add(cast(op->type.with_lanes(lanes * 2), x), lanes)) ||

// We can remove unnecessary widening if we are then performing a saturating narrow.
// This is similar to the logic inside `visit_min_or_max`.
(((bits <= 32) &&
Expand Down
16 changes: 5 additions & 11 deletions src/runtime/x86.ll
Original file line number Diff line number Diff line change
Expand Up @@ -51,17 +51,6 @@ define weak_odr <8 x i16> @packssdwx8(<8 x i32> %arg) nounwind alwaysinline {
ret <8 x i16> %3
}

define weak_odr <8 x i32> @wmul_pmaddwd_avx2(<8 x i16> %a, <8 x i16> %b) nounwind alwaysinline {
%1 = zext <8 x i16> %a to <8 x i32>
%2 = zext <8 x i16> %b to <8 x i32>
%3 = bitcast <8 x i32> %1 to <16 x i16>
%4 = bitcast <8 x i32> %2 to <16 x i16>
%res = call <8 x i32> @llvm.x86.avx2.pmadd.wd(<16 x i16> %3, <16 x i16> %4)
ret <8 x i32> %res
}

declare <8 x i32> @llvm.x86.avx2.pmadd.wd(<16 x i16>, <16 x i16>) nounwind readnone

define weak_odr <4 x i32> @wmul_pmaddwd_sse2(<4 x i16> %a, <4 x i16> %b) nounwind alwaysinline {
%1 = zext <4 x i16> %a to <4 x i32>
%2 = zext <4 x i16> %b to <4 x i32>
Expand All @@ -73,6 +62,11 @@ define weak_odr <4 x i32> @wmul_pmaddwd_sse2(<4 x i16> %a, <4 x i16> %b) nounwin

declare <4 x i32> @llvm.x86.sse2.pmadd.wd(<8 x i16>, <8 x i16>) nounwind readnone

define weak_odr <4 x i32> @hadd_pmadd_i16_sse2(<8 x i16> %a) nounwind alwaysinline {
%res = call <4 x i32> @llvm.x86.sse2.pmadd.wd(<8 x i16> %a, <8 x i16> <i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1>)
ret <4 x i32> %res
}

define weak_odr <4 x float> @sqrt_f32x4(<4 x float> %x) nounwind uwtable readnone alwaysinline {
%1 = tail call <4 x float> @llvm.x86.sse.sqrt.ps(<4 x float> %x) nounwind
ret <4 x float> %1
Expand Down
17 changes: 17 additions & 0 deletions src/runtime/x86_avx2.ll
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,20 @@ define weak_odr <16 x i16> @hadd_pmadd_i8_avx2(<32 x i8> %a) nounwind alwaysinli
ret <16 x i16> %1
}
declare <16 x i16> @llvm.x86.avx2.pmadd.ub.sw(<32 x i8>, <32 x i8>) nounwind readnone

define weak_odr <8 x i32> @wmul_pmaddwd_avx2(<8 x i16> %a, <8 x i16> %b) nounwind alwaysinline {
%1 = zext <8 x i16> %a to <8 x i32>
%2 = zext <8 x i16> %b to <8 x i32>
%3 = bitcast <8 x i32> %1 to <16 x i16>
%4 = bitcast <8 x i32> %2 to <16 x i16>
%res = call <8 x i32> @llvm.x86.avx2.pmadd.wd(<16 x i16> %3, <16 x i16> %4)
ret <8 x i32> %res
}

define weak_odr <8 x i32> @hadd_pmadd_i16_avx2(<16 x i16> %a) nounwind alwaysinline {
%res = call <8 x i32> @llvm.x86.avx2.pmadd.wd(<16 x i16> %a, <16 x i16> <i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1>)
ret <8 x i32> %res
}

declare <8 x i32> @llvm.x86.avx2.pmadd.wd(<16 x i16>, <16 x i16>) nounwind readnone

26 changes: 26 additions & 0 deletions src/runtime/x86_avx512.ll
Original file line number Diff line number Diff line change
Expand Up @@ -156,3 +156,29 @@ define weak_odr <16 x i32> @abs_i32x16(<16 x i32> %arg) {
ret <16 x i32> %1
}
declare <16 x i32> @llvm.abs.v16i32(<16 x i32>, i1) nounwind readnone

define weak_odr <32 x i16> @hadd_pmadd_u8_avx512(<64 x i8> %a) nounwind alwaysinline {
%1 = tail call <32 x i16> @llvm.x86.avx512.pmaddubs.w.512(<64 x i8> %a, <64 x i8> <i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1>)
ret <32 x i16> %1
}

define weak_odr <32 x i16> @hadd_pmadd_i8_avx512(<64 x i8> %a) nounwind alwaysinline {
%1 = tail call <32 x i16> @llvm.x86.avx512.pmaddubs.w.512(<64 x i8> <i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1>, <64 x i8> %a)
ret <32 x i16> %1
}
declare <32 x i16> @llvm.x86.avx512.pmaddubs.w.512(<64 x i8>, <64 x i8>) nounwind readnone

define weak_odr <16 x i32> @wmul_pmaddwd_avx512(<16 x i16> %a, <16 x i16> %b) nounwind alwaysinline {
%1 = zext <16 x i16> %a to <16 x i32>
%2 = zext <16 x i16> %b to <16 x i32>
%3 = bitcast <16 x i32> %1 to <32 x i16>
%4 = bitcast <16 x i32> %2 to <32 x i16>
%res = call <16 x i32> @llvm.x86.avx512.pmaddw.d.512(<32 x i16> %3, <32 x i16> %4)
ret <16 x i32> %res
}

define weak_odr <16 x i32> @hadd_pmadd_i16_avx512(<32 x i16> %a) nounwind alwaysinline {
%res = call <16 x i32> @llvm.x86.avx512.pmaddw.d.512(<32 x i16> %a, <32 x i16> <i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1>)
ret <16 x i32> %res
}
declare <16 x i32> @llvm.x86.avx512.pmaddw.d.512(<32 x i16>, <32 x i16>) nounwind readnone
6 changes: 6 additions & 0 deletions test/correctness/simd_op_check_x86.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,10 @@ class SimdOpCheckX86 : public SimdOpCheckTest {
check(check_pmaddubsw, 4 * w, sum(u16(in_u8(2 * x + r2))));
check(check_pmaddubsw, 4 * w, sum(i16(in_u8(2 * x + r2))));
check(check_pmaddubsw, 4 * w, sum(i16(in_i8(2 * x + r2))));

check(check_pmaddubsw, 4 * w, u16(in_u8(2 * x)) + in_u8(2 * x + 1));
check(check_pmaddubsw, 4 * w, i16(in_u8(2 * x)) + in_u8(2 * x + 1));
check(check_pmaddubsw, 4 * w, i16(in_i8(2 * x)) + in_i8(2 * x + 1));
}
}

Expand All @@ -328,6 +332,8 @@ class SimdOpCheckX86 : public SimdOpCheckTest {
RDom r4(0, 4);
check(check_pmaddwd, 2 * w, sum(i32(in_i16(x * 4 + r4)) * in_i16(x * 4 + r4 + 32)));

check(check_pmaddwd, 2 * w, i32(in_i16(x * 2)) + in_i16(x * 2 + 1));

// Also generate for widening_mul
check(check_pmaddwd, 2 * w, i32(i16_1) * i32(i16_2));
}
Expand Down

0 comments on commit 08a4226

Please sign in to comment.