diff --git a/src/CodeGen_X86.cpp b/src/CodeGen_X86.cpp
index 064fb933572e..45d4e224e277 100644
--- a/src/CodeGen_X86.cpp
+++ b/src/CodeGen_X86.cpp
@@ -195,6 +195,7 @@ const x86Intrinsic intrinsic_defs[] = {
     {"packuswbx16", UInt(8, 16), "saturating_narrow", {Int(16, 16)}},
 
     // Widening multiplies that use (v)pmaddwd
+    {"wmul_pmaddwd_avx512", Int(32, 16), "widening_mul", {Int(16, 16), Int(16, 16)}, Target::AVX512_Skylake},
     {"wmul_pmaddwd_avx2", Int(32, 8), "widening_mul", {Int(16, 8), Int(16, 8)}, Target::AVX2},
     {"wmul_pmaddwd_sse2", Int(32, 4), "widening_mul", {Int(16, 4), Int(16, 4)}},
 
@@ -221,15 +222,21 @@ const x86Intrinsic intrinsic_defs[] = {
     {"llvm.x86.ssse3.pmadd.ub.sw.128", Int(16, 8), "saturating_dot_product", {UInt(8, 16), Int(8, 16)}, Target::SSE41},
 
     // Horizontal widening adds using 2-way dot products.
-    {"hadd_pmadd_u8_sse3", UInt(16, 8), "horizontal_widening_add", {UInt(8, 16)}, Target::SSE41},
-    {"hadd_pmadd_u8_sse3", Int(16, 8), "horizontal_widening_add", {UInt(8, 16)}, Target::SSE41},
-    {"hadd_pmadd_i8_sse3", Int(16, 8), "horizontal_widening_add", {Int(8, 16)}, Target::SSE41},
+    {"hadd_pmadd_u8_avx512", UInt(16, 32), "horizontal_widening_add", {UInt(8, 64)}, Target::AVX512_Skylake},
+    {"hadd_pmadd_u8_avx512", Int(16, 32), "horizontal_widening_add", {UInt(8, 64)}, Target::AVX512_Skylake},
+    {"hadd_pmadd_i8_avx512", Int(16, 32), "horizontal_widening_add", {Int(8, 64)}, Target::AVX512_Skylake},
     {"hadd_pmadd_u8_avx2", UInt(16, 16), "horizontal_widening_add", {UInt(8, 32)}, Target::AVX2},
     {"hadd_pmadd_u8_avx2", Int(16, 16), "horizontal_widening_add", {UInt(8, 32)}, Target::AVX2},
     {"hadd_pmadd_i8_avx2", Int(16, 16), "horizontal_widening_add", {Int(8, 32)}, Target::AVX2},
+    {"hadd_pmadd_u8_sse3", UInt(16, 8), "horizontal_widening_add", {UInt(8, 16)}, Target::SSE41},
+    {"hadd_pmadd_u8_sse3", Int(16, 8), "horizontal_widening_add", {UInt(8, 16)}, Target::SSE41},
+    {"hadd_pmadd_i8_sse3", Int(16, 8), "horizontal_widening_add", {Int(8, 16)}, Target::SSE41},
+
+    {"hadd_pmadd_i16_avx512", Int(32, 16), "horizontal_widening_add", {Int(16, 32)}, Target::AVX512_Skylake},
+    {"hadd_pmadd_i16_avx2", Int(32, 8), "horizontal_widening_add", {Int(16, 16)}, Target::AVX2},
+    {"hadd_pmadd_i16_sse2", Int(32, 4), "horizontal_widening_add", {Int(16, 8)}},
 
     {"llvm.x86.avx512.pmaddw.d.512", Int(32, 16), "dot_product", {Int(16, 32), Int(16, 32)}, Target::AVX512_Skylake},
-    {"llvm.x86.avx512.pmaddw.d.512", Int(32, 16), "dot_product", {Int(16, 32), Int(16, 32)}, Target::AVX512_Cannonlake},
     {"llvm.x86.avx2.pmadd.wd", Int(32, 8), "dot_product", {Int(16, 16), Int(16, 16)}, Target::AVX2},
     {"llvm.x86.sse2.pmadd.wd", Int(32, 4), "dot_product", {Int(16, 8), Int(16, 8)}},
 
@@ -718,12 +725,11 @@ void CodeGen_X86::codegen_vector_reduce(const VectorReduce *op, const Expr &init
         {VectorReduce::Add, 2, wild_f32x_ * wild_f32x_, "dot_product", BFloat(16), Pattern::CombineInit},
 
-        // One could do a horizontal widening addition with
-        // other dot_products against a vector of ones. Currently disabled
-        // because I haven't found other cases where it's clearly better.
+        // Horizontal widening addition using a dot_product against a vector of ones.
         {VectorReduce::Add, 2, u16(wild_u8x_), "horizontal_widening_add", {}, Pattern::SingleArg},
         {VectorReduce::Add, 2, i16(wild_u8x_), "horizontal_widening_add", {}, Pattern::SingleArg},
         {VectorReduce::Add, 2, i16(wild_i8x_), "horizontal_widening_add", {}, Pattern::SingleArg},
+        {VectorReduce::Add, 2, i32(wild_i16x_), "horizontal_widening_add", {}, Pattern::SingleArg},
 
         // Sum of absolute differences
         {VectorReduce::Add, 8, u64(absd(wild_u8x_, wild_u8x_)), "sum_of_absolute_differences", {}},
diff --git a/src/FindIntrinsics.cpp b/src/FindIntrinsics.cpp
index 9700d06e319b..1d375835a12b 100644
--- a/src/FindIntrinsics.cpp
+++ b/src/FindIntrinsics.cpp
@@ -727,6 +727,7 @@ class FindIntrinsics : public IRMutator {
         }
 
         const int bits = op->type.bits();
+        const int lanes = op->type.lanes();
         const auto is_x_same_int = op->type.is_int() && is_int(x, bits);
         const auto is_x_same_uint = op->type.is_uint() && is_uint(x, bits);
         const auto is_x_same_int_or_uint = is_x_same_int || is_x_same_uint;
@@ -781,6 +782,13 @@ class FindIntrinsics : public IRMutator {
                 rewrite(saturating_cast(op->type, rounding_shift_right(widening_mul(x, y), z)),
                         rounding_mul_shift_right(x, y, cast(unsigned_type, z)),
                         is_x_same_int_or_uint && x_y_same_sign && is_uint(z)) ||
+
+                // Rewrite combinations of deinterleaves into horizontal ops
+                rewrite(widening_add(slice(x, 0, 2, lanes), slice(x, 1, 2, lanes)),
+                        h_add(cast(op->type.with_lanes(lanes * 2), x), lanes)) ||
+                rewrite(widening_add(slice(x, 1, 2, lanes), slice(x, 0, 2, lanes)),
+                        h_add(cast(op->type.with_lanes(lanes * 2), x), lanes)) ||
+
                 // We can remove unnecessary widening if we are then performing a saturating narrow.
                 // This is similar to the logic inside `visit_min_or_max`.
                 (((bits <= 32) &&
diff --git a/src/runtime/x86.ll b/src/runtime/x86.ll
index 5e6b5613e9f6..78c96bf5fa5c 100644
--- a/src/runtime/x86.ll
+++ b/src/runtime/x86.ll
@@ -51,17 +51,6 @@ define weak_odr <8 x i16> @packssdwx8(<8 x i32> %arg) nounwind alwaysinline {
   ret <8 x i16> %3
 }
 
-define weak_odr <8 x i32> @wmul_pmaddwd_avx2(<8 x i16> %a, <8 x i16> %b) nounwind alwaysinline {
-  %1 = zext <8 x i16> %a to <8 x i32>
-  %2 = zext <8 x i16> %b to <8 x i32>
-  %3 = bitcast <8 x i32> %1 to <16 x i16>
-  %4 = bitcast <8 x i32> %2 to <16 x i16>
-  %res = call <8 x i32> @llvm.x86.avx2.pmadd.wd(<16 x i16> %3, <16 x i16> %4)
-  ret <8 x i32> %res
-}
-
-declare <8 x i32> @llvm.x86.avx2.pmadd.wd(<16 x i16>, <16 x i16>) nounwind readnone
-
 define weak_odr <4 x i32> @wmul_pmaddwd_sse2(<4 x i16> %a, <4 x i16> %b) nounwind alwaysinline {
   %1 = zext <4 x i16> %a to <4 x i32>
   %2 = zext <4 x i16> %b to <4 x i32>
@@ -73,6 +62,11 @@ define weak_odr <4 x i32> @wmul_pmaddwd_sse2(<4 x i16> %a, <4 x i16> %b) nounwin
 
 declare <4 x i32> @llvm.x86.sse2.pmadd.wd(<8 x i16>, <8 x i16>) nounwind readnone
 
+define weak_odr <4 x i32> @hadd_pmadd_i16_sse2(<8 x i16> %a) nounwind alwaysinline {
+  %res = call <4 x i32> @llvm.x86.sse2.pmadd.wd(<8 x i16> %a, <8 x i16> <i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1>)
+  ret <4 x i32> %res
+}
+
 define weak_odr <4 x float> @sqrt_f32x4(<4 x float> %x) nounwind uwtable readnone alwaysinline {
   %1 = tail call <4 x float> @llvm.x86.sse.sqrt.ps(<4 x float> %x) nounwind
   ret <4 x float> %1
diff --git a/src/runtime/x86_avx2.ll b/src/runtime/x86_avx2.ll
index f89f6d502e30..3407c03c7029 100644
--- a/src/runtime/x86_avx2.ll
+++ b/src/runtime/x86_avx2.ll
@@ -59,3 +59,20 @@ define weak_odr <16 x i16> @hadd_pmadd_i8_avx2(<32 x i8> %a) nounwind alwaysinli
   ret <16 x i16> %1
 }
 declare <16 x i16> @llvm.x86.avx2.pmadd.ub.sw(<32 x i8>, <32 x i8>) nounwind readnone
+
+define weak_odr <8 x i32> @wmul_pmaddwd_avx2(<8 x i16> %a, <8 x i16> %b) nounwind alwaysinline {
+  %1 = zext <8 x i16> %a to <8 x i32>
+  %2 = zext <8 x i16> %b to <8 x i32>
+  %3 = bitcast <8 x i32> %1 to <16 x i16>
+  %4 = bitcast <8 x i32> %2 to <16 x i16>
+  %res = call <8 x i32> @llvm.x86.avx2.pmadd.wd(<16 x i16> %3, <16 x i16> %4)
+  ret <8 x i32> %res
+}
+
+define weak_odr <8 x i32> @hadd_pmadd_i16_avx2(<16 x i16> %a) nounwind alwaysinline {
+  %res = call <8 x i32> @llvm.x86.avx2.pmadd.wd(<16 x i16> %a, <16 x i16> <i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1>)
+  ret <8 x i32> %res
+}
+
+declare <8 x i32> @llvm.x86.avx2.pmadd.wd(<16 x i16>, <16 x i16>) nounwind readnone
+
diff --git a/src/runtime/x86_avx512.ll b/src/runtime/x86_avx512.ll
index 22401897eee2..97e6c680153f 100644
--- a/src/runtime/x86_avx512.ll
+++ b/src/runtime/x86_avx512.ll
@@ -156,3 +156,29 @@ define weak_odr <16 x i32> @abs_i32x16(<16 x i32> %arg) {
   ret <16 x i32> %1
 }
 declare <16 x i32> @llvm.abs.v16i32(<16 x i32>, i1) nounwind readnone
+
+define weak_odr <32 x i16> @hadd_pmadd_u8_avx512(<64 x i8> %a) nounwind alwaysinline {
+  %1 = tail call <32 x i16> @llvm.x86.avx512.pmaddubs.w.512(<64 x i8> %a, <64 x i8> <i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1>)
+  ret <32 x i16> %1
+}
+
+define weak_odr <32 x i16> @hadd_pmadd_i8_avx512(<64 x i8> %a) nounwind alwaysinline {
+  %1 = tail call <32 x i16> @llvm.x86.avx512.pmaddubs.w.512(<64 x i8> <i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1>, <64 x i8> %a)
+  ret <32 x i16> %1
+}
+declare <32 x i16> @llvm.x86.avx512.pmaddubs.w.512(<64 x i8>, <64 x i8>) nounwind readnone
+
+define weak_odr <16 x i32> @wmul_pmaddwd_avx512(<16 x i16> %a, <16 x i16> %b) nounwind alwaysinline {
+  %1 = zext <16 x i16> %a to <16 x i32>
+  %2 = zext <16 x i16> %b to <16 x i32>
+  %3 = bitcast <16 x i32> %1 to <32 x i16>
+  %4 = bitcast <16 x i32> %2 to <32 x i16>
+  %res = call <16 x i32> @llvm.x86.avx512.pmaddw.d.512(<32 x i16> %3, <32 x i16> %4)
+  ret <16 x i32> %res
+}
+
+define weak_odr <16 x i32> @hadd_pmadd_i16_avx512(<32 x i16> %a) nounwind alwaysinline {
+  %res = call <16 x i32> @llvm.x86.avx512.pmaddw.d.512(<32 x i16> %a, <32 x i16> <i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1>)
+  ret <16 x i32> %res
+}
+declare <16 x i32> @llvm.x86.avx512.pmaddw.d.512(<32 x i16>, <32 x i16>) nounwind readnone
\ No newline at end of file
diff --git a/test/correctness/simd_op_check_x86.cpp b/test/correctness/simd_op_check_x86.cpp
index b79620956a7b..f86134d37630 100644
--- a/test/correctness/simd_op_check_x86.cpp
+++ b/test/correctness/simd_op_check_x86.cpp
@@ -312,6 +312,10 @@ class SimdOpCheckX86 : public SimdOpCheckTest {
             check(check_pmaddubsw, 4 * w, sum(u16(in_u8(2 * x + r2))));
             check(check_pmaddubsw, 4 * w, sum(i16(in_u8(2 * x + r2))));
             check(check_pmaddubsw, 4 * w, sum(i16(in_i8(2 * x + r2))));
+
+            check(check_pmaddubsw, 4 * w, u16(in_u8(2 * x)) + in_u8(2 * x + 1));
+            check(check_pmaddubsw, 4 * w, i16(in_u8(2 * x)) + in_u8(2 * x + 1));
+            check(check_pmaddubsw, 4 * w, i16(in_i8(2 * x)) + in_i8(2 * x + 1));
         }
     }
 
@@ -328,6 +332,8 @@ class SimdOpCheckX86 : public SimdOpCheckTest {
         RDom r4(0, 4);
         check(check_pmaddwd, 2 * w, sum(i32(in_i16(x * 4 + r4)) * in_i16(x * 4 + r4 + 32)));
 
+        check(check_pmaddwd, 2 * w, i32(in_i16(x * 2)) + in_i16(x * 2 + 1));
+
         // Also generate for widening_mul
        check(check_pmaddwd, 2 * w, i32(i16_1) * i32(i16_2));
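
A note on why the hadd_pmadd_* wrappers above work: (v)pmaddwd computes r[i] = a[2i]*b[2i] + a[2i+1]*b[2i+1] on signed 16-bit lanes, so a dot product against a vector of ones is exactly the pairwise widening sum; for example pmaddwd([1, 2, 3, 4], [1, 1, 1, 1]) = [1+2, 3+4] = [3, 7]. The new FindIntrinsics rewrite recognizes that sum when it is written as a widening_add of the two deinterleaved halves of a vector. Below is a minimal sketch of a Halide pipeline that should now take the new i16 path; the output path and the names "f" and "in" are illustrative, not part of this patch:

    #include "Halide.h"
    using namespace Halide;

    int main() {
        ImageParam in(Int(16), 1);
        Func f("f");
        Var x("x");
        // When vectorized, this is widening_add(slice(v, 0, 2, n), slice(v, 1, 2, n))
        // in IR terms; FindIntrinsics rewrites it to horizontal_widening_add, which
        // CodeGen_X86 lowers to hadd_pmadd_i16_{sse2,avx2,avx512}.
        f(x) = cast<int32_t>(in(2 * x)) + cast<int32_t>(in(2 * x + 1));
        f.vectorize(x, 8);  // 8 i32 lanes: one vpmaddwd per output vector on AVX2
        f.compile_to_assembly("/tmp/f.s", {in}, "f");
        return 0;
    }

The wmul_pmaddwd_* wrappers lean on the same instruction for a plain widening multiply: zero-extending each i16 to i32 and bitcasting back to i16 lanes interleaves every element with a zero, so each pair contributes a[i]*b[i] + 0*0, i.e. a single signed widening multiply per lane.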