Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Test] simd_builder ver. 2 #12603

Merged
merged 3 commits into from
Sep 7, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
163 changes: 99 additions & 64 deletions Utilities/JIT.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -357,17 +357,17 @@ asmjit::inline_runtime::~inline_runtime()
asmjit::simd_builder::simd_builder(CodeHolder* ch) noexcept
: native_asm(ch)
{
_init(true);
_init(0);
consts[~v128()] = this->newLabel();
}

asmjit::simd_builder::~simd_builder()
{
}

void asmjit::simd_builder::_init(bool full)
void asmjit::simd_builder::_init(uint new_vsize)
{
if (full && utils::has_avx512_icl())
if ((!new_vsize && utils::has_avx512_icl()) || new_vsize == 64)
{
v0 = x86::zmm0;
v1 = x86::zmm1;
Expand All @@ -377,7 +377,7 @@ void asmjit::simd_builder::_init(bool full)
v5 = x86::zmm5;
vsize = 64;
}
else if (full && utils::has_avx2())
else if ((!new_vsize && utils::has_avx2()) || new_vsize == 32)
{
v0 = x86::ymm0;
v1 = x86::ymm1;
Expand All @@ -395,10 +395,10 @@ void asmjit::simd_builder::_init(bool full)
v3 = x86::xmm3;
v4 = x86::xmm4;
v5 = x86::xmm5;
vsize = 16;
vsize = new_vsize ? new_vsize : 16;
}

if (full && utils::has_avx512())
if (!new_vsize && utils::has_avx512())
{
vmask = -1;
}
Expand Down Expand Up @@ -480,6 +480,10 @@ void asmjit::simd_builder::vec_clobbering_test(u32 esize, const Operand& v, cons
{
this->emit(x86::Inst::kIdVptest, v, rhs);
}
else if (esize == 16 && utils::has_avx())
{
this->emit(x86::Inst::kIdVptest, v, rhs);
}
else if (esize == 16 && utils::has_sse41())
{
this->emit(x86::Inst::kIdPtest, v, rhs);
Expand All @@ -504,6 +508,54 @@ void asmjit::simd_builder::vec_clobbering_test(u32 esize, const Operand& v, cons
}
}

void asmjit::simd_builder::vec_broadcast_gpr(u32 esize, const Operand& v, const x86::Gp& r)
{
if (esize == 2)
{
if (utils::has_avx512())
this->emit(x86::Inst::kIdVpbroadcastw, v, r.r32());
else if (utils::has_avx())
{
this->emit(x86::Inst::kIdVmovd, v, r.r32());
if (utils::has_avx2())
this->emit(x86::Inst::kIdVpbroadcastw, v, v);
else
{
this->emit(x86::Inst::kIdVpunpcklwd, v, v);
this->emit(x86::Inst::kIdVpshufd, v, v, Imm(0));
}
}
else
{
this->emit(x86::Inst::kIdMovd, v, r.r32());
this->emit(x86::Inst::kIdPunpcklwd, v, v);
this->emit(x86::Inst::kIdPshufd, v, v, Imm(0));
}
}
else if (esize == 4)
{
if (utils::has_avx512())
this->emit(x86::Inst::kIdVpbroadcastd, v, r.r32());
else if (utils::has_avx())
{
this->emit(x86::Inst::kIdVmovd, v, r.r32());
if (utils::has_avx2())
this->emit(x86::Inst::kIdVpbroadcastd, v, v);
else
this->emit(x86::Inst::kIdVpshufd, v, v, Imm(0));
}
else
{
this->emit(x86::Inst::kIdMovd, v, r.r32());
this->emit(x86::Inst::kIdPshufd, v, v, Imm(0));
}
}
else
{
fmt::throw_exception("Unimplemented");
}
}

asmjit::x86::Mem asmjit::simd_builder::ptr_scale_for_vec(u32 esize, const x86::Gp& base, const x86::Gp& index)
{
switch (ensure(esize))
Expand Down Expand Up @@ -636,7 +688,7 @@ void asmjit::simd_builder::_vec_binary_op(x86::Inst::Id sse_op, x86::Inst::Id ve
{
if (utils::has_avx())
{
if (vex_op == x86::Inst::kIdNone || this->_extraReg.isReg())
if (evex_op != x86::Inst::kIdNone && (vex_op == x86::Inst::kIdNone || this->_extraReg.isReg() || vsize >= 64))
{
this->evex().emit(evex_op, dst, lhs, rhs);
}
Expand Down Expand Up @@ -694,40 +746,31 @@ void asmjit::simd_builder::vec_umax(u32 esize, const Operand& dst, const Operand
fmt::throw_exception("Unimplemented");
}

void asmjit::simd_builder::vec_umin_horizontal_i128(u32 esize, const x86::Gp& dst, const Operand& src, const Operand& tmp)
void asmjit::simd_builder::vec_cmp_eq(u32 esize, const Operand& dst, const Operand& lhs, const Operand& rhs)
{
using enum x86::Inst::Id;
if (!utils::has_sse41())
{
fmt::throw_exception("Unimplemented");
}

ensure(src != tmp);

if (esize == 2)
{
this->emit(utils::has_avx() ? kIdVphminposuw : kIdPhminposuw, x86::Xmm(tmp.id()), x86::Xmm(src.id()));
this->emit(utils::has_avx() ? kIdVpextrw : kIdPextrw, dst, x86::Xmm(tmp.id()), Imm(0));
if (vsize == 64)
{
this->evex().emit(kIdVpcmpeqw, x86::k0, lhs, rhs);
this->evex().emit(kIdVpmovm2w, dst, x86::k0);
}
else
{
_vec_binary_op(kIdPcmpeqw, kIdVpcmpeqw, kIdNone, dst, lhs, rhs);
}
}
else if (esize == 4)
{
if (utils::has_avx())
if (vsize == 64)
{
this->vpsrldq(x86::Xmm(tmp.id()), x86::Xmm(src.id()), 8);
this->vpminud(x86::Xmm(tmp.id()), x86::Xmm(tmp.id()), x86::Xmm(src.id()));
this->vpsrldq(x86::Xmm(src.id()), x86::Xmm(tmp.id()), 4);
this->vpminud(x86::Xmm(src.id()), x86::Xmm(src.id()), x86::Xmm(tmp.id()));
this->vmovd(dst.r32(), x86::Xmm(src.id()));
this->evex().emit(kIdVpcmpeqd, x86::k0, lhs, rhs);
this->evex().emit(kIdVpmovm2d, dst, x86::k0);
}
else
{
this->movdqa(x86::Xmm(tmp.id()), x86::Xmm(src.id()));
this->psrldq(x86::Xmm(tmp.id()), 8);
this->pminud(x86::Xmm(tmp.id()), x86::Xmm(src.id()));
this->movdqa(x86::Xmm(src.id()), x86::Xmm(tmp.id()));
this->psrldq(x86::Xmm(src.id()), 4);
this->pminud(x86::Xmm(src.id()), x86::Xmm(tmp.id()));
this->movd(dst.r32(), x86::Xmm(src.id()));
_vec_binary_op(kIdPcmpeqw, kIdVpcmpeqw, kIdNone, dst, lhs, rhs);
}
}
else
Expand All @@ -736,50 +779,42 @@ void asmjit::simd_builder::vec_umin_horizontal_i128(u32 esize, const x86::Gp& ds
}
}

void asmjit::simd_builder::vec_umax_horizontal_i128(u32 esize, const x86::Gp& dst, const Operand& src, const Operand& tmp)
void asmjit::simd_builder::vec_extract_high(u32, const Operand& dst, const Operand& src)
{
using enum x86::Inst::Id;
if (!utils::has_sse41())
{
fmt::throw_exception("Unimplemented");
}

ensure(src != tmp);

if (esize == 2)
{
vec_set_all_ones(x86::Xmm(tmp.id()));
vec_xor(esize, x86::Xmm(tmp.id()), x86::Xmm(tmp.id()), x86::Xmm(src.id()));
this->emit(utils::has_avx() ? kIdVphminposuw : kIdPhminposuw, x86::Xmm(tmp.id()), x86::Xmm(tmp.id()));
this->emit(utils::has_avx() ? kIdVpextrw : kIdPextrw, dst, x86::Xmm(tmp.id()), Imm(0));
this->not_(dst.r16());
}
else if (esize == 4)
if (vsize == 32)
this->vextracti32x8(x86::Ymm(dst.id()), x86::Zmm(src.id()), 1);
else if (vsize == 16)
this->vextracti128(x86::Xmm(dst.id()), x86::Ymm(src.id()), 1);
else
{
if (utils::has_avx())
{
this->vpsrldq(x86::Xmm(tmp.id()), x86::Xmm(src.id()), 8);
this->vpmaxud(x86::Xmm(tmp.id()), x86::Xmm(tmp.id()), x86::Xmm(src.id()));
this->vpsrldq(x86::Xmm(src.id()), x86::Xmm(tmp.id()), 4);
this->vpmaxud(x86::Xmm(src.id()), x86::Xmm(src.id()), x86::Xmm(tmp.id()));
this->vmovd(dst.r32(), x86::Xmm(src.id()));
}
this->vpsrldq(x86::Xmm(dst.id()), x86::Xmm(src.id()), vsize);
else
{
this->movdqa(x86::Xmm(tmp.id()), x86::Xmm(src.id()));
this->psrldq(x86::Xmm(tmp.id()), 8);
this->pmaxud(x86::Xmm(tmp.id()), x86::Xmm(src.id()));
this->movdqa(x86::Xmm(src.id()), x86::Xmm(tmp.id()));
this->psrldq(x86::Xmm(src.id()), 4);
this->pmaxud(x86::Xmm(src.id()), x86::Xmm(tmp.id()));
this->movd(dst.r32(), x86::Xmm(src.id()));
this->movdqa(x86::Xmm(dst.id()), x86::Xmm(src.id()));
this->psrldq(x86::Xmm(dst.id()), vsize);
}
}
}

void asmjit::simd_builder::vec_extract_gpr(u32 esize, const x86::Gp& dst, const Operand& src)
{
if (esize == 8 && utils::has_avx())
this->vmovq(dst.r64(), x86::Xmm(src.id()));
else if (esize == 8)
this->movq(dst.r64(), x86::Xmm(src.id()));
else if (esize == 4 && utils::has_avx())
this->vmovd(dst.r32(), x86::Xmm(src.id()));
else if (esize == 4)
this->movd(dst.r32(), x86::Xmm(src.id()));
else if (esize == 2 && utils::has_avx())
this->vpextrw(dst.r32(), x86::Xmm(src.id()), 0);
else if (esize == 2)
this->pextrw(dst.r32(), x86::Xmm(src.id()), 0);
else
{
fmt::throw_exception("Unimplemented");
}
}

#endif /* X86 */

#ifdef LLVM_AVAILABLE
Expand Down
87 changes: 62 additions & 25 deletions Utilities/JIT.h
Original file line number Diff line number Diff line change
Expand Up @@ -226,12 +226,13 @@ namespace asmjit

void operator()() noexcept;

void _init(bool full);
void _init(uint new_vsize = 0);
void vec_cleanup_ret();
void vec_set_all_zeros(const Operand& v);
void vec_set_all_ones(const Operand& v);
void vec_set_const(const Operand& v, const v128& value);
void vec_clobbering_test(u32 esize, const Operand& v, const Operand& rhs);
void vec_broadcast_gpr(u32 esize, const Operand& v, const x86::Gp& r);

// return x86::ptr(base, ctr, X, 0) where X is set for esize accordingly
x86::Mem ptr_scale_for_vec(u32 esize, const x86::Gp& base, const x86::Gp& index);
Expand Down Expand Up @@ -260,11 +261,18 @@ namespace asmjit
_vec_binary_op(kIdPor, kIdVpor, kIdVpord, dst, lhs, rhs);
}

void vec_andn(u32, const Operand& dst, const Operand& lhs, const Operand& rhs)
{
using enum x86::Inst::Id;
_vec_binary_op(kIdPandn, kIdVpandn, kIdVpandnd, dst, lhs, rhs);
}

void vec_umin(u32 esize, const Operand& dst, const Operand& lhs, const Operand& rhs);
void vec_umax(u32 esize, const Operand& dst, const Operand& lhs, const Operand& rhs);
void vec_cmp_eq(u32 esize, const Operand& dst, const Operand& lhs, const Operand& rhs);

void vec_umin_horizontal_i128(u32 esize, const x86::Gp& dst, const Operand& src, const Operand& tmp);
void vec_umax_horizontal_i128(u32 esize, const x86::Gp& dst, const Operand& src, const Operand& tmp);
void vec_extract_high(u32 esize, const Operand& dst, const Operand& src);
void vec_extract_gpr(u32 esize, const x86::Gp& dst, const Operand& src);

simd_builder& keep_if_not_masked()
{
Expand All @@ -287,7 +295,7 @@ namespace asmjit
return *this;
}

void build_loop(u32 esize, auto reg_ctr, auto reg_cnt, auto&& build, auto&& reduce)
void build_loop(u32 esize, const x86::Gp& reg_ctr, const x86::Gp& reg_cnt, auto&& build, auto&& reduce)
{
ensure((esize & (esize - 1)) == 0);
ensure(esize <= vsize);
Expand All @@ -299,47 +307,76 @@ namespace asmjit
const u32 step = vsize / esize;

this->xor_(reg_ctr.r32(), reg_ctr.r32()); // Reset counter reg
this->sub(reg_cnt, step);
this->cmp(reg_cnt, step);
this->jb(next); // If count < step, skip main loop body
this->align(AlignMode::kCode, 16);
this->bind(body);
this->sub(reg_cnt, step);
build();
this->add(reg_ctr, step);
this->sub(reg_cnt, step);
this->ja(body);
this->cmp(reg_cnt, step);
this->jae(body);
this->bind(next);
if (!vmask)
reduce();
this->add(reg_cnt, step);
this->jz(exit);

if (vmask)
{
// Build single last iteration (masked)
this->test(reg_cnt, reg_cnt);
this->jz(exit);
this->bzhi(reg_cnt, x86::Mem(consts[~u128()], 0), reg_cnt);
this->kmovq(x86::k7, reg_cnt);
vmask = 7;
build();
vmask = -1;
reduce();

// Rollout reduction step
this->bind(exit);
while (true)
{
vsize /= 2;
if (vsize < esize)
break;
this->_init(vsize);
reduce();
}
}
else
{
// Build tail loop (reduced vector width)
Label body = this->newLabel();
this->align(AlignMode::kCode, 16);
this->bind(body);
const uint vsz = vsize / step;
this->_init(false);
vsize = vsz;
build();
this->_init(true);
this->inc(reg_ctr);
this->sub(reg_cnt, 1);
this->ja(body);
// Build unrolled loop tail (reduced vector width)
while (true)
{
vsize /= 2;
if (vsize < esize)
break;

// Shall not clobber flags
this->_init(vsize);
reduce();

if (vsize == esize)
{
// Last "iteration"
this->test(reg_cnt, reg_cnt);
this->jz(exit);
build();
}
else
{
const u32 step = vsize / esize;
Label next = this->newLabel();
this->cmp(reg_cnt, step);
this->jb(next);
build();
this->add(reg_ctr, step);
this->sub(reg_cnt, step);
this->bind(next);
}
}

this->bind(exit);
}

this->bind(exit);
this->_init(0);
}
};

Expand Down
Loading