Skip to content

Commit

Permalink
Reduce node: fix illegal instruction on SSE4.1
Browse files (browse the repository at this point in the history)
  • Loading branch information
xuchen-intel committed Jun 1, 2021
1 parent 115aa14 commit a218de4
Showing 1 changed file with 42 additions and 33 deletions.
75 changes: 42 additions & 33 deletions inference-engine/src/mkldnn_plugin/nodes/mkldnn_reduce_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -316,8 +316,12 @@ struct jit_uni_reduce_kernel_f32 : public jit_uni_reduce_kernel, public jit_gene
}
// reduce
reduce_main_loop();
if (jcp_.reduce_mode == ReduceOr && isa != avx512_common) {
vcmpneqps(vmm_dst, vmm_dst, vmm_zero);
if (jcp_.reduce_mode == ReduceOr && isa != cpu::x64::avx512_common) {
if (isa == cpu::x64::avx2) {
vcmpneqps(vmm_dst, vmm_dst, vmm_zero);
} else if (isa == cpu::x64::sse41) {
cmpneqps(vmm_dst, vmm_zero);
}
uni_vandps(vmm_dst, vmm_dst, vmm_aux);
}
// store
Expand Down Expand Up @@ -361,7 +365,11 @@ struct jit_uni_reduce_kernel_f32 : public jit_uni_reduce_kernel, public jit_gene
// reduce
reduce_kernel_scalar(xmm_src, xmm_dst);
if (jcp_.reduce_mode == ReduceOr) {
vcmpneqps(xmm_dst, xmm_dst, xmm_zero);
if (isa == cpu::x64::sse41) {
cmpneqps(xmm_dst, xmm_zero);
} else {
vcmpneqps(xmm_dst, xmm_dst, xmm_zero);
}
uni_vandps(xmm_dst, xmm_dst, xmm_aux);
}

Expand Down Expand Up @@ -400,7 +408,11 @@ struct jit_uni_reduce_kernel_f32 : public jit_uni_reduce_kernel, public jit_gene

reduce_kernel_scalar(xmm_src, xmm_dst);
if (jcp_.reduce_mode == ReduceOr) {
vcmpneqps(xmm_dst, xmm_dst, xmm_zero);
if (isa == cpu::x64::sse41) {
cmpneqps(xmm_dst, xmm_zero);
} else {
vcmpneqps(xmm_dst, xmm_dst, xmm_zero);
}
uni_vandps(xmm_dst, xmm_dst, xmm_aux);
}

Expand Down Expand Up @@ -448,11 +460,13 @@ struct jit_uni_reduce_kernel_f32 : public jit_uni_reduce_kernel, public jit_gene
inline void reduce_kernel(Vmm vmm_src, Vmm vmm_dst) {
switch (jcp_.reduce_mode) {
case ReduceAnd:
if (isa == avx512_common) {
if (isa == cpu::x64::avx512_common) {
vcmpps(k_mask, vmm_src, vmm_zero, _cmp_neq_uq);
vblendmps(vmm_src | k_mask, vmm_zero, vmm_aux);
} else {
} else if (isa == cpu::x64::avx2) {
vcmpneqps(vmm_src, vmm_src, vmm_zero);
} else {
cmpneqps(vmm_src, vmm_zero);
}
uni_vandps(vmm_dst, vmm_dst, vmm_src);
break;
Expand Down Expand Up @@ -481,7 +495,7 @@ struct jit_uni_reduce_kernel_f32 : public jit_uni_reduce_kernel, public jit_gene
uni_vaddps(vmm_dst, vmm_dst, vmm_src);
break;
case ReduceOr:
if (isa == avx512_common) {
if (isa == cpu::x64::avx512_common) {
vcmpps(k_mask, vmm_src, vmm_zero, _cmp_neq_uq);
vblendmps(vmm_src | k_mask, vmm_zero, vmm_aux);
}
Expand All @@ -498,7 +512,11 @@ struct jit_uni_reduce_kernel_f32 : public jit_uni_reduce_kernel, public jit_gene
inline void reduce_kernel_scalar(Xmm xmm_src, Xmm xmm_dst) {
switch (jcp_.reduce_mode) {
case ReduceAnd:
vcmpneqps(xmm_src, xmm_src, xmm_zero);
if (isa == cpu::x64::sse41) {
cmpneqps(xmm_src, xmm_zero);
} else {
vcmpneqps(xmm_src, xmm_src, xmm_zero);
}
uni_vandps(xmm_dst, xmm_dst, xmm_src);
break;
case ReduceL1:
Expand Down Expand Up @@ -543,11 +561,16 @@ struct jit_uni_reduce_kernel_f32 : public jit_uni_reduce_kernel, public jit_gene
}

inline void store_dst_vector() {
if (jcp_.reduce_mode == ReduceOr && isa != avx512_common) {
vcmpneqps(vmm_dst, vmm_dst, vmm_zero);
if (jcp_.reduce_mode == ReduceOr && isa != cpu::x64::avx512_common) {
if (isa == cpu::x64::avx2) {
vcmpneqps(vmm_dst, vmm_dst, vmm_zero);
} else if (isa == cpu::x64::sse41) {
cmpneqps(vmm_dst, vmm_zero);
}
uni_vandps(vmm_dst, vmm_dst, vmm_aux);

if (isa == cpu::x64::sse41) {
vcmpneqps(vmm_dst_aux, vmm_dst_aux, vmm_zero);
cmpneqps(vmm_dst_aux, vmm_zero);
uni_vandps(vmm_dst_aux, vmm_dst_aux, vmm_aux);
}
}
Expand Down Expand Up @@ -628,7 +651,7 @@ struct jit_uni_reduce_kernel_f32 : public jit_uni_reduce_kernel, public jit_gene
vmovdqu16(op, ymm_dst);
break;
case memory::data_type::s8:
if (isa == avx512_common) {
if (isa == cpu::x64::avx512_common) {
vmaxps(vmm_dst, vmm_zero, vmm_dst);
vpmovsdb(op, vmm_dst);
} else {
Expand All @@ -643,7 +666,7 @@ struct jit_uni_reduce_kernel_f32 : public jit_uni_reduce_kernel, public jit_gene
}
break;
case memory::data_type::u8:
if (isa == avx512_common) {
if (isa == cpu::x64::avx512_common) {
vpmovusdb(op, vmm_dst);
} else {
uni_vpackusdw(vmm_dst, vmm_dst, vmm_dst);
Expand Down Expand Up @@ -719,34 +742,27 @@ struct jit_uni_reduce_kernel_f32 : public jit_uni_reduce_kernel, public jit_gene
horiz_ps(xmm_dst, xmm_aux3); // dst:f(1,2),f(2,2),f(3,4),f(4,4)
movhlps(xmm_aux3, xmm_dst); // aux3:f(3,4),f(4,4),4,4
horiz_ps(xmm_dst, xmm_aux3); // dst:f(1,2,3,4),...
load_scalar(xmm_aux3, ptr[reg_dst], dst_dt);

switch (dst_dt) {
case memory::data_type::f32:
case memory::data_type::bf16:
load_scalar(xmm_aux3, ptr[reg_dst], dst_dt);
horiz_ps(xmm_dst, xmm_aux3);
store_scalar(ptr[reg_dst], xmm_dst, dst_dt);
break;
case memory::data_type::s32:
movss(xmm_aux3, ptr[reg_dst]);
uni_vcvtdq2ps(xmm_aux3, xmm_aux3);
horiz_ps(xmm_dst, xmm_aux3);
uni_vcvtps2dq(xmm_dst, xmm_dst);
movss(ptr[reg_dst], xmm_dst);
break;
case memory::data_type::u8:
vpbroadcastb(xmm_aux3, ptr[reg_dst]);
uni_vpmovzxbd(xmm_aux3, xmm_aux3);
uni_vcvtdq2ps(xmm_aux3, xmm_aux3);
horiz_ps(xmm_dst, xmm_aux3);
uni_vcvtps2dq(xmm_dst, xmm_dst);
uni_vpackusdw(xmm_dst, xmm_dst, xmm_dst);
uni_vpackuswb(xmm_dst, xmm_dst, xmm_dst);
pextrb(ptr[reg_dst], xmm_dst, 0);
break;
case memory::data_type::s8:
vpbroadcastb(xmm_aux3, ptr[reg_dst]);
uni_vpmovsxbd(xmm_aux3, xmm_aux3);
uni_vcvtdq2ps(xmm_aux3, xmm_aux3);
horiz_ps(xmm_dst, xmm_aux3);
uni_vcvtps2dq(xmm_dst, xmm_dst);
uni_vpackssdw(xmm_dst, xmm_dst, xmm_dst);
Expand Down Expand Up @@ -1102,7 +1118,7 @@ struct jit_uni_reduce_post_kernel_f32 : public jit_uni_reduce_post_kernel, publi
vmovdqu16(op, ymm_dst);
break;
case memory::data_type::s8:
if (isa == avx512_common) {
if (isa == cpu::x64::avx512_common) {
vmaxps(vmm_dst, vmm_zero, vmm_dst);
vpmovsdb(op, vmm_dst);
} else {
Expand All @@ -1117,7 +1133,7 @@ struct jit_uni_reduce_post_kernel_f32 : public jit_uni_reduce_post_kernel, publi
}
break;
case memory::data_type::u8:
if (isa == avx512_common) {
if (isa == cpu::x64::avx512_common) {
vpmovusdb(op, vmm_dst);
} else {
uni_vpackusdw(vmm_dst, vmm_dst, vmm_dst);
Expand Down Expand Up @@ -1249,34 +1265,27 @@ struct jit_uni_reduce_post_kernel_f32 : public jit_uni_reduce_post_kernel, publi
horiz_ps(xmm_dst, xmm_aux3); // dst:f(1,2),f(2,2),f(3,4),f(4,4)
movhlps(xmm_aux3, xmm_dst); // aux3:f(3,4),f(4,4),4,4
horiz_ps(xmm_dst, xmm_aux3); // dst:f(1,2,3,4),...
load_scalar(xmm_aux3, ptr[reg_dst], dst_dt);

switch (dst_dt) {
case memory::data_type::f32:
case memory::data_type::bf16:
load_scalar(xmm_aux3, ptr[reg_dst], dst_dt);
horiz_ps(xmm_dst, xmm_aux3);
store_scalar(ptr[reg_dst], xmm_dst, dst_dt);
break;
case memory::data_type::s32:
movss(xmm_aux3, ptr[reg_dst]);
uni_vcvtdq2ps(xmm_aux3, xmm_aux3);
horiz_ps(xmm_dst, xmm_aux3);
uni_vcvtps2dq(xmm_dst, xmm_dst);
movss(ptr[reg_dst], xmm_dst);
break;
case memory::data_type::u8:
vpbroadcastb(xmm_aux3, ptr[reg_dst]);
uni_vpmovzxbd(xmm_aux3, xmm_aux3);
uni_vcvtdq2ps(xmm_aux3, xmm_aux3);
horiz_ps(xmm_dst, xmm_aux3);
uni_vcvtps2dq(xmm_dst, xmm_dst);
uni_vpackusdw(xmm_dst, xmm_dst, xmm_dst);
uni_vpackuswb(xmm_dst, xmm_dst, xmm_dst);
pextrb(ptr[reg_dst], xmm_dst, 0);
break;
case memory::data_type::s8:
vpbroadcastb(xmm_aux3, ptr[reg_dst]);
uni_vpmovsxbd(xmm_aux3, xmm_aux3);
uni_vcvtdq2ps(xmm_aux3, xmm_aux3);
horiz_ps(xmm_dst, xmm_aux3);
uni_vcvtps2dq(xmm_dst, xmm_dst);
uni_vpackssdw(xmm_dst, xmm_dst, xmm_dst);
Expand Down

0 comments on commit a218de4

Please sign in to comment.