[x86] Generate AVX512 fixed-point instructions #7129

Merged: 13 commits, Oct 31, 2022
54 changes: 51 additions & 3 deletions src/CodeGen_X86.cpp
@@ -35,6 +35,9 @@ Target complete_x86_target(Target t) {
if (t.has_feature(Target::AVX512_Cannonlake) ||
t.has_feature(Target::AVX512_Skylake) ||
t.has_feature(Target::AVX512_KNL)) {
t.set_feature(Target::AVX512);
}
if (t.has_feature(Target::AVX512)) {
t.set_feature(Target::AVX2);
}
if (t.has_feature(Target::AVX2)) {
@@ -111,6 +114,12 @@ struct x86Intrinsic {

// clang-format off
const x86Intrinsic intrinsic_defs[] = {
// AVX2/SSSE3 LLVM intrinsics for pabs fail in JIT. The integer wrappers
// just call `llvm.abs` (which requires a second argument).
// AVX512BW's pabs instructions aren't directly exposed by LLVM.
{"abs_i8x64", UInt(8, 64), "abs", {Int(8, 64)}, Target::AVX512_Skylake},
{"abs_i16x32", UInt(16, 32), "abs", {Int(16, 32)}, Target::AVX512_Skylake},
{"abs_i32x16", UInt(32, 16), "abs", {Int(32, 16)}, Target::AVX512_Skylake},
{"abs_i8x32", UInt(8, 32), "abs", {Int(8, 32)}, Target::AVX2},
{"abs_i16x16", UInt(16, 16), "abs", {Int(16, 16)}, Target::AVX2},
{"abs_i32x8", UInt(32, 8), "abs", {Int(32, 8)}, Target::AVX2},
Expand All @@ -125,15 +134,19 @@ const x86Intrinsic intrinsic_defs[] = {
{"round_f32x8", Float(32, 8), "round", {Float(32, 8)}, Target::AVX},
{"round_f64x4", Float(64, 4), "round", {Float(64, 4)}, Target::AVX},

{"llvm.sadd.sat.v64i8", Int(8, 64), "saturating_add", {Int(8, 64), Int(8, 64)}, Target::AVX512_Skylake},
{"llvm.sadd.sat.v32i8", Int(8, 32), "saturating_add", {Int(8, 32), Int(8, 32)}, Target::AVX2},
{"llvm.sadd.sat.v16i8", Int(8, 16), "saturating_add", {Int(8, 16), Int(8, 16)}},
{"llvm.sadd.sat.v8i8", Int(8, 8), "saturating_add", {Int(8, 8), Int(8, 8)}},
{"llvm.ssub.sat.v64i8", Int(8, 64), "saturating_sub", {Int(8, 64), Int(8, 64)}, Target::AVX512_Skylake},
{"llvm.ssub.sat.v32i8", Int(8, 32), "saturating_sub", {Int(8, 32), Int(8, 32)}, Target::AVX2},
{"llvm.ssub.sat.v16i8", Int(8, 16), "saturating_sub", {Int(8, 16), Int(8, 16)}},
{"llvm.ssub.sat.v8i8", Int(8, 8), "saturating_sub", {Int(8, 8), Int(8, 8)}},

{"llvm.sadd.sat.v32i16", Int(16, 32), "saturating_add", {Int(16, 32), Int(16, 32)}, Target::AVX512_Skylake},
{"llvm.sadd.sat.v16i16", Int(16, 16), "saturating_add", {Int(16, 16), Int(16, 16)}, Target::AVX2},
{"llvm.sadd.sat.v8i16", Int(16, 8), "saturating_add", {Int(16, 8), Int(16, 8)}},
{"llvm.ssub.sat.v32i16", Int(16, 32), "saturating_sub", {Int(16, 32), Int(16, 32)}, Target::AVX512_Skylake},
{"llvm.ssub.sat.v16i16", Int(16, 16), "saturating_sub", {Int(16, 16), Int(16, 16)}, Target::AVX2},
{"llvm.ssub.sat.v8i16", Int(16, 8), "saturating_sub", {Int(16, 8), Int(16, 8)}},

Expand All @@ -149,13 +162,17 @@ const x86Intrinsic intrinsic_defs[] = {
// Target::AVX instead of Target::AVX2 as the feature flag
// requirement.
// TODO: Just use llvm.*add/*sub.sat, and verify the above comment?
{"llvm.uadd.sat.v64i8", UInt(8, 64), "saturating_add", {UInt(8, 64), UInt(8, 64)}, Target::AVX512_Skylake},
{"paddusbx32", UInt(8, 32), "saturating_add", {UInt(8, 32), UInt(8, 32)}, Target::AVX},
{"paddusbx16", UInt(8, 16), "saturating_add", {UInt(8, 16), UInt(8, 16)}},
{"llvm.usub.sat.v64i8", UInt(8, 64), "saturating_sub", {UInt(8, 64), UInt(8, 64)}, Target::AVX512_Skylake},
{"psubusbx32", UInt(8, 32), "saturating_sub", {UInt(8, 32), UInt(8, 32)}, Target::AVX},
{"psubusbx16", UInt(8, 16), "saturating_sub", {UInt(8, 16), UInt(8, 16)}},

{"llvm.uadd.sat.v32i16", UInt(16, 32), "saturating_add", {UInt(16, 32), UInt(16, 32)}, Target::AVX512_Skylake},
{"padduswx16", UInt(16, 16), "saturating_add", {UInt(16, 16), UInt(16, 16)}, Target::AVX},
{"padduswx8", UInt(16, 8), "saturating_add", {UInt(16, 8), UInt(16, 8)}},
{"llvm.usub.sat.v32i16", UInt(16, 32), "saturating_sub", {UInt(16, 32), UInt(16, 32)}, Target::AVX512_Skylake},
{"psubuswx16", UInt(16, 16), "saturating_sub", {UInt(16, 16), UInt(16, 16)}, Target::AVX},
{"psubuswx8", UInt(16, 8), "saturating_sub", {UInt(16, 8), UInt(16, 8)}},

Expand All @@ -180,14 +197,15 @@ const x86Intrinsic intrinsic_defs[] = {
{"wmul_pmaddwd_sse2", Int(32, 4), "widening_mul", {Int(16, 4), Int(16, 4)}},

// Multiply keep high half
{"llvm.x86.avx512.pmulh.w.512", Int(16, 32), "pmulh", {Int(16, 32), Int(16, 32)}, Target::AVX512_Skylake},
{"llvm.x86.avx2.pmulh.w", Int(16, 16), "pmulh", {Int(16, 16), Int(16, 16)}, Target::AVX2},
{"llvm.x86.avx512.pmulhu.w.512", UInt(16, 32), "pmulh", {UInt(16, 32), UInt(16, 32)}, Target::AVX512_Skylake},
{"llvm.x86.avx2.pmulhu.w", UInt(16, 16), "pmulh", {UInt(16, 16), UInt(16, 16)}, Target::AVX2},
{"llvm.x86.avx512.pmul.hr.sw.512", Int(16, 32), "pmulhrs", {Int(16, 32), Int(16, 32)}, Target::AVX512_Skylake},
{"llvm.x86.avx2.pmul.hr.sw", Int(16, 16), "pmulhrs", {Int(16, 16), Int(16, 16)}, Target::AVX2},
{"saturating_pmulhrswx16", Int(16, 16), "saturating_pmulhrs", {Int(16, 16), Int(16, 16)}, Target::AVX2},
{"llvm.x86.sse2.pmulh.w", Int(16, 8), "pmulh", {Int(16, 8), Int(16, 8)}},
{"llvm.x86.sse2.pmulhu.w", UInt(16, 8), "pmulh", {UInt(16, 8), UInt(16, 8)}},
{"llvm.x86.ssse3.pmul.hr.sw.128", Int(16, 8), "pmulhrs", {Int(16, 8), Int(16, 8)}, Target::SSE41},
{"saturating_pmulhrswx8", Int(16, 8), "saturating_pmulhrs", {Int(16, 8), Int(16, 8)}, Target::SSE41},

// Convert FP32 to BF16
{"vcvtne2ps2bf16x32", BFloat(16, 32), "f32_to_bf16", {Float(32, 32)}, Target::AVX512_SapphireRapids},
@@ -582,7 +600,6 @@ void CodeGen_X86::visit(const Call *op) {
static Pattern patterns[] = {
{"pmulh", mul_shift_right(wild_i16x_, wild_i16x_, 16)},
{"pmulh", mul_shift_right(wild_u16x_, wild_u16x_, 16)},
{"saturating_pmulhrs", rounding_mul_shift_right(wild_i16x_, wild_i16x_, 15)},
{"saturating_narrow", i16_sat(wild_i32x_)},
{"saturating_narrow", u16_sat(wild_i32x_)},
{"saturating_narrow", i8_sat(wild_i16x_)},
Expand All @@ -600,6 +617,37 @@ void CodeGen_X86::visit(const Call *op) {
}
}

    // Check for saturating_pmulhrs. On x86, pmulhrs keeps only the low 16 bits of the
    // rounded product, so the one overflow case (INT16_MIN * INT16_MIN) wraps instead of
    // saturating. It is still faster to use pmulhrs and fix up that single case than to
    // lower this to a widening multiplication.
static Expr saturating_pmulhrs = rounding_mul_shift_right(wild_i16x_, wild_i16x_, 15);
if (expr_match(saturating_pmulhrs, op, matches)) {
// Rewrite so that we can take advantage of pmulhrs.
internal_assert(matches.size() == 2);
internal_assert(op->type.element_of() == Int(16));
const Expr &a = matches[0];
const Expr &b = matches[1];

Expr pmulhrs = i16(rounding_shift_right(widening_mul(a, b), 15));

Expr i16_min = op->type.min();
Expr i16_max = op->type.max();

// Handle edge case of possible overflow.
// See https://github.com/halide/Halide/pull/7129/files#r1008331426
// On AVX512 (and with enough lanes) we can use a mask register.
if (target.has_feature(Target::AVX512) && op->type.lanes() >= 32) {
Expr expr = select((a == i16_min) && (b == i16_min), i16_max, pmulhrs);
expr.accept(this);
} else {
Expr mask = select(max(a, b) == i16_min, cast(op->type, -1), cast(op->type, 0));
Expr expr = mask ^ pmulhrs;
expr.accept(this);
}

return;
}

CodeGen_Posix::visit(op);
}

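Worked example of the overflow case handled above: pmulhrs computes (a * b + 2^14) >> 15 and keeps only the low 16 bits, so the single input pair INT16_MIN * INT16_MIN produces 32768, which wraps to -32768 instead of saturating to 32767. The scalar C++ sketch below is illustrative only, not part of the patch, and the helper names are made up:

#include <cstdint>
#include <cstdio>

// Scalar model of x86 vpmulhrsw: (a * b + 2^14) >> 15, truncated to 16 bits.
int16_t pmulhrs_scalar(int16_t a, int16_t b) {
    int32_t r = (int32_t(a) * int32_t(b) + (1 << 14)) >> 15;
    return (int16_t)r;  // 32768 wraps to -32768 on the one overflow case
}

// Saturating variant, matching what the rounding_mul_shift_right(a, b, 15) pattern requires.
int16_t saturating_pmulhrs_scalar(int16_t a, int16_t b) {
    int16_t r = pmulhrs_scalar(a, b);
    // INT16_MIN * INT16_MIN is the only input pair that overflows; patch it up.
    return (a == INT16_MIN && b == INT16_MIN) ? INT16_MAX : r;
}

int main() {
    std::printf("%d\n", pmulhrs_scalar(INT16_MIN, INT16_MIN));             // -32768 (wrapped)
    std::printf("%d\n", saturating_pmulhrs_scalar(INT16_MIN, INT16_MIN));  // 32767
    return 0;
}

The non-AVX512 fallback in the patch reaches the same result without selecting on the value lanes: when both inputs are INT16_MIN the truncated product is 0x8000, and XOR-ing it with an all-ones mask flips it to 0x7FFF, while the mask is zero everywhere else.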
31 changes: 9 additions & 22 deletions src/runtime/x86_avx2.ll
@@ -32,35 +32,22 @@ define weak_odr <16 x i16> @packusdwx16(<16 x i32> %arg) nounwind alwaysinline
declare <16 x i16> @llvm.x86.avx2.packusdw(<8 x i32>, <8 x i32>) nounwind readnone

define weak_odr <32 x i8> @abs_i8x32(<32 x i8> %arg) {
%1 = sub <32 x i8> zeroinitializer, %arg
%2 = icmp sgt <32 x i8> %arg, zeroinitializer
%3 = select <32 x i1> %2, <32 x i8> %arg, <32 x i8> %1
ret <32 x i8> %3
%1 = tail call <32 x i8> @llvm.abs.v32i8(<32 x i8> %arg, i1 false)
ret <32 x i8> %1
}
declare <32 x i8> @llvm.abs.v32i8(<32 x i8>, i1) nounwind readnone

define weak_odr <16 x i16> @abs_i16x16(<16 x i16> %arg) {
%1 = sub <16 x i16> zeroinitializer, %arg
%2 = icmp sgt <16 x i16> %arg, zeroinitializer
%3 = select <16 x i1> %2, <16 x i16> %arg, <16 x i16> %1
ret <16 x i16> %3
%1 = tail call <16 x i16> @llvm.abs.v16i16(<16 x i16> %arg, i1 false)
ret <16 x i16> %1
}
declare <16 x i16> @llvm.abs.v16i16(<16 x i16>, i1) nounwind readnone

define weak_odr <8 x i32> @abs_i32x8(<8 x i32> %arg) {
%1 = sub <8 x i32> zeroinitializer, %arg
%2 = icmp sgt <8 x i32> %arg, zeroinitializer
%3 = select <8 x i1> %2, <8 x i32> %arg, <8 x i32> %1
ret <8 x i32> %3
}

define weak_odr <16 x i16> @saturating_pmulhrswx16(<16 x i16> %a, <16 x i16> %b) nounwind uwtable readnone alwaysinline {
%1 = tail call <16 x i16> @llvm.x86.avx2.pmul.hr.sw(<16 x i16> %a, <16 x i16> %b)
%2 = icmp eq <16 x i16> %a, <i16 -32768, i16 -32768, i16 -32768, i16 -32768, i16 -32768, i16 -32768, i16 -32768, i16 -32768, i16 -32768, i16 -32768, i16 -32768, i16 -32768, i16 -32768, i16 -32768, i16 -32768, i16 -32768>
%3 = icmp eq <16 x i16> %b, <i16 -32768, i16 -32768, i16 -32768, i16 -32768, i16 -32768, i16 -32768, i16 -32768, i16 -32768, i16 -32768, i16 -32768, i16 -32768, i16 -32768, i16 -32768, i16 -32768, i16 -32768, i16 -32768>
%4 = and <16 x i1> %2, %3
%5 = select <16 x i1> %4, <16 x i16> <i16 32767, i16 32767, i16 32767, i16 32767, i16 32767, i16 32767, i16 32767, i16 32767, i16 32767, i16 32767, i16 32767, i16 32767, i16 32767, i16 32767, i16 32767, i16 32767>, <16 x i16> %1
ret <16 x i16> %5
%1 = tail call <8 x i32> @llvm.abs.v8i32(<8 x i32> %arg, i1 false)
ret <8 x i32> %1
}
declare <16 x i16> @llvm.x86.avx2.pmul.hr.sw(<16 x i16>, <16 x i16>) nounwind readnone
declare <8 x i32> @llvm.abs.v8i32(<8 x i32>, i1) nounwind readnone

define weak_odr <16 x i16> @hadd_pmadd_u8_avx2(<32 x i8> %a) nounwind alwaysinline {
%1 = tail call <16 x i16> @llvm.x86.avx2.pmadd.ub.sw(<32 x i8> %a, <32 x i8> <i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1>)
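A note on the `i1 false` argument used above: `llvm.abs` takes a flag that, when set, makes an INT_MIN input poison; passing false keeps the wrapping behaviour of the sub/icmp/select sequences these wrappers replace, i.e. abs of the minimum value returns the minimum value. A scalar C++ sketch of one 16-bit lane (illustrative only, the helper name is made up):

#include <cstdint>
#include <cstdio>

// Per-lane model of abs_i16x16 / llvm.abs.v16i16 with the poison flag false:
// negation wraps, so abs(-32768) stays -32768 rather than being undefined.
int16_t abs_wrapping(int16_t x) {
    // The same select the replaced IR performed: x > 0 ? x : 0 - x.
    return x > 0 ? x : (int16_t)(0 - x);
}

int main() {
    std::printf("%d\n", abs_wrapping(-5));      // 5
    std::printf("%d\n", abs_wrapping(-32768));  // -32768 (wraps)
    return 0;
}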
18 changes: 18 additions & 0 deletions src/runtime/x86_avx512.ll
@@ -138,3 +138,21 @@ define weak_odr <4 x i32> @dpwssdsx4(<4 x i32> %init, <8 x i16> %a, <8 x i16> %
ret <4 x i32> %3
}
declare <4 x i32> @llvm.x86.avx512.vpdpwssds.128(<4 x i32>, <4 x i32>, <4 x i32>)

define weak_odr <64 x i8> @abs_i8x64(<64 x i8> %arg) {
%1 = tail call <64 x i8> @llvm.abs.v64i8(<64 x i8> %arg, i1 false)
ret <64 x i8> %1
}
declare <64 x i8> @llvm.abs.v64i8(<64 x i8>, i1) nounwind readnone

define weak_odr <32 x i16> @abs_i16x32(<32 x i16> %arg) {
%1 = tail call <32 x i16> @llvm.abs.v32i16(<32 x i16> %arg, i1 false)
ret <32 x i16> %1
}
declare <32 x i16> @llvm.abs.v32i16(<32 x i16>, i1) nounwind readnone

define weak_odr <16 x i32> @abs_i32x16(<16 x i32> %arg) {
%1 = tail call <16 x i32> @llvm.abs.v16i32(<16 x i32> %arg, i1 false)
ret <16 x i32> %1
}
declare <16 x i32> @llvm.abs.v16i32(<16 x i32>, i1) nounwind readnone
31 changes: 9 additions & 22 deletions src/runtime/x86_sse41.ll
@@ -52,35 +52,22 @@ define weak_odr <2 x double> @trunc_f64x2(<2 x double> %x) nounwind uwtable read
}

define weak_odr <16 x i8> @abs_i8x16(<16 x i8> %x) nounwind uwtable readnone alwaysinline {
%1 = sub <16 x i8> zeroinitializer, %x
%2 = icmp sgt <16 x i8> %x, zeroinitializer
%3 = select <16 x i1> %2, <16 x i8> %x, <16 x i8> %1
ret <16 x i8> %3
%1 = tail call <16 x i8> @llvm.abs.v16i8(<16 x i8> %x, i1 false)
ret <16 x i8> %1
}
declare <16 x i8> @llvm.abs.v16i8(<16 x i8>, i1) nounwind readnone

define weak_odr <8 x i16> @abs_i16x8(<8 x i16> %x) nounwind uwtable readnone alwaysinline {
%1 = sub <8 x i16> zeroinitializer, %x
%2 = icmp sgt <8 x i16> %x, zeroinitializer
%3 = select <8 x i1> %2, <8 x i16> %x, <8 x i16> %1
ret <8 x i16> %3
%1 = tail call <8 x i16> @llvm.abs.v8i16(<8 x i16> %x, i1 false)
ret <8 x i16> %1
}
declare <8 x i16> @llvm.abs.v8i16(<8 x i16>, i1) nounwind readnone

define weak_odr <4 x i32> @abs_i32x4(<4 x i32> %x) nounwind uwtable readnone alwaysinline {
%1 = sub <4 x i32> zeroinitializer, %x
%2 = icmp sgt <4 x i32> %x, zeroinitializer
%3 = select <4 x i1> %2, <4 x i32> %x, <4 x i32> %1
ret <4 x i32> %3
}

define weak_odr <8 x i16> @saturating_pmulhrswx8(<8 x i16> %a, <8 x i16> %b) nounwind uwtable readnone alwaysinline {
%1 = tail call <8 x i16> @llvm.x86.ssse3.pmul.hr.sw.128(<8 x i16> %a, <8 x i16> %b)
%2 = icmp eq <8 x i16> %a, <i16 -32768, i16 -32768, i16 -32768, i16 -32768, i16 -32768, i16 -32768, i16 -32768, i16 -32768>
%3 = icmp eq <8 x i16> %b, <i16 -32768, i16 -32768, i16 -32768, i16 -32768, i16 -32768, i16 -32768, i16 -32768, i16 -32768>
%4 = and <8 x i1> %2, %3
%5 = select <8 x i1> %4, <8 x i16> <i16 32767, i16 32767, i16 32767, i16 32767, i16 32767, i16 32767, i16 32767, i16 32767>, <8 x i16> %1
ret <8 x i16> %5
%1 = tail call <4 x i32> @llvm.abs.v4i32(<4 x i32> %x, i1 false)
ret <4 x i32> %1
}
declare <8 x i16> @llvm.x86.ssse3.pmul.hr.sw.128(<8 x i16>, <8 x i16>) nounwind readnone
declare <4 x i32> @llvm.abs.v4i32(<4 x i32>, i1) nounwind readnone

define weak_odr <8 x i16> @hadd_pmadd_u8_sse3(<16 x i8> %a) nounwind alwaysinline {
%1 = tail call <8 x i16> @llvm.x86.ssse3.pmadd.ub.sw.128(<16 x i8> %a, <16 x i8> <i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1>)
88 changes: 48 additions & 40 deletions test/correctness/simd_op_check.cpp
@@ -448,51 +448,59 @@ class SimdOpCheck : public SimdOpCheckTest {
// AVX 2

if (use_avx2) {
check("vpaddb*ymm", 32, u8_1 + u8_2);
check("vpsubb*ymm", 32, u8_1 - u8_2);
check("vpaddsb*ymm", 32, i8_sat(i16(i8_1) + i16(i8_2)));
check("vpsubsb*ymm", 32, i8_sat(i16(i8_1) - i16(i8_2)));
check("vpaddusb*ymm", 32, u8(min(u16(u8_1) + u16(u8_2), max_u8)));
check("vpsubusb*ymm", 32, u8(max(i16(u8_1) - i16(u8_2), 0)));
check("vpaddw*ymm", 16, u16_1 + u16_2);
check("vpsubw*ymm", 16, u16_1 - u16_2);
check("vpaddsw*ymm", 16, i16_sat(i32(i16_1) + i32(i16_2)));
check("vpsubsw*ymm", 16, i16_sat(i32(i16_1) - i32(i16_2)));
check("vpaddusw*ymm", 16, u16(min(u32(u16_1) + u32(u16_2), max_u16)));
check("vpsubusw*ymm", 16, u16(max(i32(u16_1) - i32(u16_2), 0)));
check("vpaddd*ymm", 8, i32_1 + i32_2);
check("vpsubd*ymm", 8, i32_1 - i32_2);
check("vpmulhw*ymm", 16, i16((i32(i16_1) * i32(i16_2)) / (256 * 256)));
check("vpmulhw*ymm", 16, i16((i32(i16_1) * i32(i16_2)) >> cast<unsigned>(16)));
check("vpmulhw*ymm", 16, i16((i32(i16_1) * i32(i16_2)) >> cast<int>(16)));
check("vpmulhw*ymm", 16, i16((i32(i16_1) * i32(i16_2)) << cast<int>(-16)));
check("vpmullw*ymm", 16, i16_1 * i16_2);

check("vpmulhrsw*ymm", 16, i16((((i32(i16_1) * i32(i16_2)) + 16384)) / 32768));
check("vpmulhrsw*ymm", 16, i16_sat((((i32(i16_1) * i32(i16_2)) + 16384)) / 32768));

check("vpcmp*b*ymm", 32, select(u8_1 == u8_2, u8(1), u8(2)));
check("vpcmp*b*ymm", 32, select(u8_1 > u8_2, u8(1), u8(2)));
check("vpcmp*w*ymm", 16, select(u16_1 == u16_2, u16(1), u16(2)));
check("vpcmp*w*ymm", 16, select(u16_1 > u16_2, u16(1), u16(2)));
check("vpcmp*d*ymm", 8, select(u32_1 == u32_2, u32(1), u32(2)));
check("vpcmp*d*ymm", 8, select(u32_1 > u32_2, u32(1), u32(2)));

check("vpavgb*ymm", 32, u8((u16(u8_1) + u16(u8_2) + 1) / 2));
check("vpavgw*ymm", 16, u16((u32(u16_1) + u32(u16_2) + 1) / 2));
check("vpmaxsw*ymm", 16, max(i16_1, i16_2));
check("vpminsw*ymm", 16, min(i16_1, i16_2));
check("vpmaxub*ymm", 32, max(u8_1, u8_2));
check("vpminub*ymm", 32, min(u8_1, u8_2));
auto check_x86_fixed_point = [&](const std::string &suffix, const int m) {
check("vpaddb*" + suffix, 32 * m, u8_1 + u8_2);
check("vpsubb*" + suffix, 32 * m, u8_1 - u8_2);
check("vpaddsb*" + suffix, 32 * m, i8_sat(i16(i8_1) + i16(i8_2)));
check("vpsubsb*" + suffix, 32 * m, i8_sat(i16(i8_1) - i16(i8_2)));
check("vpaddusb*" + suffix, 32 * m, u8(min(u16(u8_1) + u16(u8_2), max_u8)));
check("vpsubusb*" + suffix, 32 * m, u8(max(i16(u8_1) - i16(u8_2), 0)));
check("vpaddw*" + suffix, 16 * m, u16_1 + u16_2);
check("vpsubw*" + suffix, 16 * m, u16_1 - u16_2);
check("vpaddsw*" + suffix, 16 * m, i16_sat(i32(i16_1) + i32(i16_2)));
check("vpsubsw*" + suffix, 16 * m, i16_sat(i32(i16_1) - i32(i16_2)));
check("vpaddusw*" + suffix, 16 * m, u16(min(u32(u16_1) + u32(u16_2), max_u16)));
check("vpsubusw*" + suffix, 16 * m, u16(max(i32(u16_1) - i32(u16_2), 0)));
check("vpaddd*" + suffix, 8 * m, i32_1 + i32_2);
check("vpsubd*" + suffix, 8 * m, i32_1 - i32_2);
check("vpmulhw*" + suffix, 16 * m, i16((i32(i16_1) * i32(i16_2)) / (256 * 256)));
check("vpmulhw*" + suffix, 16 * m, i16((i32(i16_1) * i32(i16_2)) >> cast<unsigned>(16)));
check("vpmulhw*" + suffix, 16 * m, i16((i32(i16_1) * i32(i16_2)) >> cast<int>(16)));
check("vpmulhw*" + suffix, 16 * m, i16((i32(i16_1) * i32(i16_2)) << cast<int>(-16)));
check("vpmullw*" + suffix, 16 * m, i16_1 * i16_2);

check("vpmulhrsw*" + suffix, 16 * m, i16((((i32(i16_1) * i32(i16_2)) + 16384)) / 32768));
check("vpmulhrsw*" + suffix, 16 * m, i16_sat((((i32(i16_1) * i32(i16_2)) + 16384)) / 32768));

check("vpcmp*b*" + suffix, 32 * m, select(u8_1 == u8_2, u8(1), u8(2)));
check("vpcmp*b*" + suffix, 32 * m, select(u8_1 > u8_2, u8(1), u8(2)));
check("vpcmp*w*" + suffix, 16 * m, select(u16_1 == u16_2, u16(1), u16(2)));
check("vpcmp*w*" + suffix, 16 * m, select(u16_1 > u16_2, u16(1), u16(2)));
check("vpcmp*d*" + suffix, 8 * m, select(u32_1 == u32_2, u32(1), u32(2)));
check("vpcmp*d*" + suffix, 8 * m, select(u32_1 > u32_2, u32(1), u32(2)));

check("vpavgb*" + suffix, 32 * m, u8((u16(u8_1) + u16(u8_2) + 1) / 2));
check("vpavgw*" + suffix, 16 * m, u16((u32(u16_1) + u32(u16_2) + 1) / 2));
check("vpmaxsw*" + suffix, 16 * m, max(i16_1, i16_2));
check("vpminsw*" + suffix, 16 * m, min(i16_1, i16_2));
check("vpmaxub*" + suffix, 32 * m, max(u8_1, u8_2));
check("vpminub*" + suffix, 32 * m, min(u8_1, u8_2));

check("vpabsb*" + suffix, 32 * m, abs(i8_1));
check("vpabsw*" + suffix, 16 * m, abs(i16_1));
check("vpabsd*" + suffix, 8 * m, abs(i32_1));
};

check_x86_fixed_point("ymm", 1);

if (use_avx512) {
check_x86_fixed_point("zmm", 2);
}

check(use_avx512 ? "vpaddq*zmm" : "vpaddq*ymm", 8, i64_1 + i64_2);
check(use_avx512 ? "vpsubq*zmm" : "vpsubq*ymm", 8, i64_1 - i64_2);
check(use_avx512 ? "vpmullq" : "vpmuludq*ymm", 8, u64_1 * u64_2);

check("vpabsb*ymm", 32, abs(i8_1));
check("vpabsw*ymm", 16, abs(i16_1));
check("vpabsd*ymm", 8, abs(i32_1));

// llvm doesn't distinguish between signed and unsigned multiplies
// check("vpmuldq", 8, i64(i32_1) * i64(i32_2));
if (!use_avx512) {
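The test refactor above folds the per-instruction checks into a lambda parameterized by register suffix and a lane multiplier: the same body runs once as check_x86_fixed_point("ymm", 1) for the 256-bit AVX2 forms and, when AVX512 is enabled, again as check_x86_fixed_point("zmm", 2) so that the doubled vector widths force the 512-bit encodings. A stripped-down sketch of that pattern, with a hypothetical check() stand-in rather than the real SimdOpCheckTest API:

#include <cstdio>
#include <string>

// Hypothetical stand-in for SimdOpCheck::check(instruction pattern, vector width, expr).
void check(const std::string &instr, int lanes) {
    std::printf("expect %-14s at %d lanes\n", instr.c_str(), lanes);
}

int main() {
    bool use_avx512 = true;  // assumption for the sketch
    auto check_fixed_point = [&](const std::string &suffix, int m) {
        check("vpaddb*" + suffix, 32 * m);  // 8-bit adds
        check("vpabsw*" + suffix, 16 * m);  // 16-bit abs
    };
    check_fixed_point("ymm", 1);      // 256-bit AVX2 forms
    if (use_avx512) {
        check_fixed_point("zmm", 2);  // 512-bit AVX512 forms, twice the lanes
    }
    return 0;
}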