From 5a21c76bbc8581f2a9408af2a718584907a7fb07 Mon Sep 17 00:00:00 2001 From: dmitrygo Date: Wed, 27 May 2020 20:26:22 +0300 Subject: [PATCH] Fixed depthwise injector aux_vec_count for broadcasting case --- src/cpu/jit_uni_depthwise.cpp | 10 +++++----- src/cpu/jit_uni_depthwise.hpp | 4 ++-- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/cpu/jit_uni_depthwise.cpp b/src/cpu/jit_uni_depthwise.cpp index 2c9f19a4408..7d9952fca8a 100644 --- a/src/cpu/jit_uni_depthwise.cpp +++ b/src/cpu/jit_uni_depthwise.cpp @@ -55,9 +55,9 @@ struct jit_uni_depthwise_kernel_f32 : public c_compatible { }; template -int jit_uni_depthwise_injector_f32::aux_vecs_count(alg_kind_t depthwise_alg) { +int jit_uni_depthwise_injector_f32::aux_vecs_count(alg_kind_t depthwise_alg, bool is_broadcast) { switch (depthwise_alg) { - case alg_kind::depthwise_scale_shift: return isa == sse42 ? 1 : 0; + case alg_kind::depthwise_scale_shift: return isa == sse42 || is_broadcast ? 1 : 0; case alg_kind::depthwise_prelu: return 2; default: assert(!"unsupported depthwise algorithm"); } @@ -66,9 +66,9 @@ int jit_uni_depthwise_injector_f32::aux_vecs_count(alg_kind_t depthwise_alg } template -void jit_uni_depthwise_injector_f32::injector_preamble(size_t start_idx, size_t end_idx) { +void jit_uni_depthwise_injector_f32::injector_preamble(size_t start_idx, size_t end_idx, bool is_broadcast) { preserved_vecs_count = 0; - vecs_to_preserve = (size_t)jit_uni_depthwise_injector_f32::aux_vecs_count(depthwise_alg); + vecs_to_preserve = (size_t)jit_uni_depthwise_injector_f32::aux_vecs_count(depthwise_alg, is_broadcast); for (size_t i = 0; i < vecs_count; i++) { if (preserved_vecs_count >= vecs_to_preserve) @@ -210,7 +210,7 @@ void jit_uni_depthwise_injector_f32::compute_body(size_t start_idx, size_t template void jit_uni_depthwise_injector_f32::compute_vector_range(int start_idx, int end_idx, const Xbyak::Reg64& p_weights, const Xbyak::Reg64& p_bias, bool is_broadcast) { - injector_preamble(start_idx, end_idx); + injector_preamble(start_idx, end_idx, is_broadcast); compute_body(start_idx_tail, end_idx, p_weights, p_bias, is_broadcast); injector_preamble_tail(start_idx, end_idx); compute_body(start_idx, start_idx_tail, p_weights, p_bias, is_broadcast); diff --git a/src/cpu/jit_uni_depthwise.hpp b/src/cpu/jit_uni_depthwise.hpp index 1ab4ed2f38a..fc78ce99e29 100644 --- a/src/cpu/jit_uni_depthwise.hpp +++ b/src/cpu/jit_uni_depthwise.hpp @@ -65,10 +65,10 @@ struct jit_uni_depthwise_injector_f32 { size_t preserved_vec_idxs[preserved_vecs_max] = {0}; size_t start_idx_tail = 0; - int aux_vecs_count(alg_kind_t elt_alg); + int aux_vecs_count(alg_kind_t elt_alg, bool is_broadcast); void compute_body(size_t start_idx, size_t end_idx, const Xbyak::Reg64& p_weights, const Xbyak::Reg64& p_bias, bool is_broadcast = false); - void injector_preamble(size_t start_idx, size_t end_idx); + void injector_preamble(size_t start_idx, size_t end_idx, bool is_broadcast = false); void injector_preamble_tail(size_t start_idx, size_t end_idx); void injector_postamble(); void assign_regs();