Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions src/cpu/jit_uni_depthwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,9 @@ struct jit_uni_depthwise_kernel_f32 : public c_compatible {
};

template <cpu_isa_t isa>
int jit_uni_depthwise_injector_f32<isa>::aux_vecs_count(alg_kind_t depthwise_alg) {
int jit_uni_depthwise_injector_f32<isa>::aux_vecs_count(alg_kind_t depthwise_alg, bool is_broadcast) {
switch (depthwise_alg) {
case alg_kind::depthwise_scale_shift: return isa == sse42 ? 1 : 0;
case alg_kind::depthwise_scale_shift: return isa == sse42 || is_broadcast ? 1 : 0;
case alg_kind::depthwise_prelu: return 2;
default: assert(!"unsupported depthwise algorithm");
}
Expand All @@ -66,9 +66,9 @@ int jit_uni_depthwise_injector_f32<isa>::aux_vecs_count(alg_kind_t depthwise_alg
}

template <cpu_isa_t isa>
void jit_uni_depthwise_injector_f32<isa>::injector_preamble(size_t start_idx, size_t end_idx) {
void jit_uni_depthwise_injector_f32<isa>::injector_preamble(size_t start_idx, size_t end_idx, bool is_broadcast) {
preserved_vecs_count = 0;
vecs_to_preserve = (size_t)jit_uni_depthwise_injector_f32<isa>::aux_vecs_count(depthwise_alg);
vecs_to_preserve = (size_t)jit_uni_depthwise_injector_f32<isa>::aux_vecs_count(depthwise_alg, is_broadcast);

for (size_t i = 0; i < vecs_count; i++) {
if (preserved_vecs_count >= vecs_to_preserve)
Expand Down Expand Up @@ -210,7 +210,7 @@ void jit_uni_depthwise_injector_f32<isa>::compute_body(size_t start_idx, size_t
template <cpu_isa_t isa>
void jit_uni_depthwise_injector_f32<isa>::compute_vector_range(int start_idx, int end_idx,
const Xbyak::Reg64& p_weights, const Xbyak::Reg64& p_bias, bool is_broadcast) {
injector_preamble(start_idx, end_idx);
injector_preamble(start_idx, end_idx, is_broadcast);
compute_body(start_idx_tail, end_idx, p_weights, p_bias, is_broadcast);
injector_preamble_tail(start_idx, end_idx);
compute_body(start_idx, start_idx_tail, p_weights, p_bias, is_broadcast);
Expand Down
4 changes: 2 additions & 2 deletions src/cpu/jit_uni_depthwise.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,10 @@ struct jit_uni_depthwise_injector_f32 {
size_t preserved_vec_idxs[preserved_vecs_max] = {0};
size_t start_idx_tail = 0;

int aux_vecs_count(alg_kind_t elt_alg);
int aux_vecs_count(alg_kind_t elt_alg, bool is_broadcast);

void compute_body(size_t start_idx, size_t end_idx, const Xbyak::Reg64& p_weights, const Xbyak::Reg64& p_bias, bool is_broadcast = false);
void injector_preamble(size_t start_idx, size_t end_idx);
void injector_preamble(size_t start_idx, size_t end_idx, bool is_broadcast = false);
void injector_preamble_tail(size_t start_idx, size_t end_idx);
void injector_postamble();
void assign_regs();
Expand Down