Skip to content

Commit

Permalink
shape agnistic & dynamic support
Browse files Browse the repository at this point in the history
  • Loading branch information
eshoguli committed Feb 6, 2024
1 parent 70478d2 commit 974677f
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,8 @@ void jit_emitter::emitter_preamble(const std::vector<size_t>& in_idxs,
}

void jit_emitter::emitter_postamble() const {
for (size_t i = 0; i < preserved_gpr_idxs.size(); ++i) {
const int size = static_cast<int>(preserved_gpr_idxs.size());
for (int i = (size - 1); i >= 0; --i) {
h->ldr(Xbyak_aarch64::XReg(preserved_gpr_idxs[i]), post_ptr(h->sp, 16));
}
preserved_gpr_idxs.clear();
Expand Down
7 changes: 5 additions & 2 deletions src/plugins/intel_cpu/src/nodes/eltwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2225,7 +2225,8 @@ void Eltwise::initSupportedPrimitiveDescriptors() {

// if dim rank is greater than the maximum possible, we should use the reference execution
#if defined (OPENVINO_ARCH_ARM64)
bool canUseOptimizedImpl = mayiuse(dnnl::impl::cpu::aarch64::asimd) && getInputShapeAtPort(0).getRank() <= MAX_ELTWISE_DIM_RANK;
bool canUseOptimizedImpl = mayiuse(dnnl::impl::cpu::aarch64::asimd) && (getInputShapeAtPort(0).getRank() <= MAX_ELTWISE_DIM_RANK);
bool canUseOptimizedShapeAgnosticImpl = isDynamicNode() && canUseOptimizedImpl;
#else
bool canUseOptimizedImpl = mayiuse(x64::sse41) && getInputShapeAtPort(0).getRank() <= MAX_ELTWISE_DIM_RANK;
// TODO: Add EltwiseLog algorithm support for JIT implementation
Expand Down Expand Up @@ -2316,7 +2317,9 @@ void Eltwise::initSupportedPrimitiveDescriptors() {
canUseOptimizedImpl = false;
}

implType = (useJit && canUseOptimizedImpl) ? EltwiseImplType::optimized : EltwiseImplType::reference;
implType = (useJit && canUseOptimizedImpl) ?
(canUseOptimizedShapeAgnosticImpl ? EltwiseImplType::optimizedShapeAgnostic : EltwiseImplType::optimized) :
EltwiseImplType::reference;
#else
OPENVINO_THROW("Unknow CPU architecture");
#endif
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,31 @@ void jit_uni_eltwise_generic<isa>::generate() {

// ptrs initializing
if (jep.use_runtime_ptrs) {
IE_THROW(NotImplemented) << "jit_uni_eltwise_generic<isa>::generate: jep.use_runtime_ptrs is not implemented";
for (size_t i = 0; i < jep.inputs_number; i++) {
ldr(start_to_offsets, ptr(reg_const_params, static_cast<int32_t>(offsetof(node::jit_eltwise_call_args_ptrs, src_offsets) + i * sizeof(size_t))));
ldr(get_src_reg(i), ptr(reg_const_params, static_cast<int32_t>(offsetof(node::jit_eltwise_call_args_ptrs, src_ptr[0]) + i * sizeof(size_t))));
XReg offset_reg = get_aux_gpr(0); // X_TMP_0;
XReg index_reg = get_aux_gpr(1); // X_TMP_1;
for (int j = 0; j < offset_count; j++) {
ldr(offset_reg, ptr(start_to_offsets, static_cast<int32_t>(j * sizeof(size_t))));
ldr(index_reg, ptr(reg_indexes, static_cast<int32_t>(j * sizeof(size_t))));
madd(get_src_reg(i), offset_reg, index_reg, get_src_reg(i));
}
}

ldr(start_to_offsets, ptr(reg_const_params, static_cast<int32_t>(offsetof(node::jit_eltwise_call_args_ptrs, dst_offsets))));
ldr(reg_dst, ptr(reg_const_params, static_cast<int32_t>(offsetof(node::jit_eltwise_call_args_ptrs, dst_ptr))));
XReg offset_reg = get_aux_gpr(0); // X_TMP_0;
XReg index_reg = get_aux_gpr(1); // X_TMP_1;
for (int j = 0; j < offset_count; j++) {
ldr(offset_reg, ptr(start_to_offsets, static_cast<int32_t>(j * sizeof(size_t))));
ldr(index_reg, ptr(reg_indexes, static_cast<int32_t>(j * sizeof(size_t))));
madd(reg_dst, offset_reg, index_reg, reg_dst);
}

mov(reg_oc_off, 0);

ldr(reg_work_amount, ptr(reg_const_params, static_cast<int32_t>(offsetof(node::jit_eltwise_call_args_ptrs, work_amount))));
} else {
auto init_ptrs_with_offsets = [this, offset_count, param2](XReg pointer, const std::vector<size_t>& offsets) {
for (int j = 0; j < offset_count; j++) {
Expand All @@ -70,9 +94,12 @@ void jit_uni_eltwise_generic<isa>::generate() {
init_ptrs_with_offsets(get_src_reg(i), jep.src_offsets[i]);
}

ldr(reg_dst, ptr(param1, static_cast<int32_t>(offsetof(node::jit_eltwise_call_args_ptrs, dst_ptr))));
ldr(reg_dst, ptr(reg_const_params, static_cast<int32_t>(offsetof(node::jit_eltwise_call_args_ptrs, dst_ptr))));
init_ptrs_with_offsets(reg_dst, jep.dst_offsets);

mov(reg_oc_off, 0);
init_ptrs_with_offsets(reg_oc_off, jep.oc_offsets);

mov(reg_work_amount, jep.work_amount);
}

Expand Down Expand Up @@ -158,7 +185,7 @@ void jit_uni_eltwise_generic<isa>::generate() {
add(reg_dst, reg_dst, jep.dst_prc.size() * loop_step);
sub(reg_work_amount, reg_work_amount, loop_step);
if (jep_.oc_size > 1 && jep_.oc_size != min_src_size)
IE_THROW(NotImplemented) << "jit_uni_eltwise_generic<isa>::generate: reg_oc_off";
add(reg_oc_off, reg_oc_off, loop_step * sizeof(float));

b(AL, unroll_loop_label);
}
Expand Down Expand Up @@ -197,7 +224,7 @@ void jit_uni_eltwise_generic<isa>::generate() {
add(reg_dst, reg_dst, jep.dst_prc.size() * loop_step);
sub(reg_work_amount, reg_work_amount, loop_step);
if (jep_.oc_size > 1)
IE_THROW(NotImplemented) << "jit_uni_eltwise_generic<isa>::generate: reg_oc_off";
add(reg_oc_off, reg_oc_off, loop_step * sizeof(float));

b(AL, main_loop_label);
}
Expand Down Expand Up @@ -233,7 +260,7 @@ void jit_uni_eltwise_generic<isa>::generate() {
add(reg_dst, reg_dst, jep.dst_prc.size() * loop_step);
sub(reg_work_amount, reg_work_amount, loop_step);
if (jep_.oc_size > 1)
IE_THROW(NotImplemented) << "jit_uni_eltwise_generic<isa>::generate: reg_oc_off";
add(reg_oc_off, reg_oc_off, loop_step * sizeof(float));

b(AL, tail_loop_label);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,16 @@ struct jit_uni_eltwise_generic : public jit_uni_eltwise_kernel, jit_generator {
void generate() override;

private:
const Xbyak_aarch64::XReg X_TMP_0 = x10;
const Xbyak_aarch64::XReg X_TMP_1 = x11;

XReg reg_post_op_ptrs = X_TMP_0;
XReg start_to_offsets = reg_post_op_ptrs;

XReg reg_oc_off = x12;
XReg reg_const_params = abi_param1;
XReg reg_indexes = abi_param2;

using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits<isa>::TReg;
using TRegS = typename dnnl::impl::cpu::aarch64::cpu_isa_traits<isa>::TRegS;

Expand All @@ -113,9 +123,9 @@ struct jit_uni_eltwise_generic : public jit_uni_eltwise_kernel, jit_generator {
// X7 | [not used] | RSP | <stack pointer>
// X8 | [not used] | R8 | src ptr
// X9 | work amount | R9 | src ptr
// X10 | [not used] | R10 | src ptr
// X11 | [not used] | R11 | src ptr
// X12 | [not used] | R12 | src ptr
// X10 | ker temporary| R10 | src ptr
// X11 | ker temporary| R11 | src ptr
// X12 | ker temporary (abi_not_param1) | R12 | src ptr
// X13 | [not used] | R13 | src ptr
// X14 | [not used] | R14 | src ptr
// X15 | dst | R15 | temporary
Expand Down

0 comments on commit 974677f

Please sign in to comment.