Skip to content

Commit

Permalink
int8 support + debug
Browse files Browse the repository at this point in the history
  • Loading branch information
eshoguli committed Feb 7, 2024
1 parent 974677f commit 1e881d1
Show file tree
Hide file tree
Showing 4 changed files with 128 additions and 20 deletions.
2 changes: 1 addition & 1 deletion src/core/src/op/multiply.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ bool Multiply::evaluate(TensorVector& outputs, const TensorVector& inputs) const

using namespace ov::element;
return IF_TYPE_OF(v1_Multiply_evaluate,
OV_PP_ET_LIST(bf16, f16, f32, f64, i32, i64, u32, u64),
OV_PP_ET_LIST(bf16, f16, f32, f64, i8, i32, i64, u8, u32, u64),
multiply::Evaluate,
inputs[0].get_element_type(),
inputs[0],
Expand Down
4 changes: 3 additions & 1 deletion src/plugins/intel_cpu/src/nodes/eltwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,9 @@ bool jitIsSupported(const Node* node,
ov::element::f16,
ov::element::f32,
ov::element::i32,
ov::element::u32
ov::element::u32,
ov::element::i8,
ov::element::u8
};

if (!check_precisions(input_precisions, supported_precisions)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,32 @@ void jit_uni_eltwise_generic<isa>::generate() {
}
}

namespace utils {
template <typename T1, typename T2>
void load_vector(const T1& data_lane,
const T2& data_lanes,
const Xbyak_aarch64::XReg &ptr_reg,
const int64_t offset,
const bool broadcast,
jit_generator* h) {
if (broadcast) {
if (offset == 0) {
h->ld1r(data_lane, ptr(ptr_reg));
} else {
h->add_imm(h->X_DEFAULT_ADDR, ptr_reg, offset, h->X_TMP_0);
h->ld1r(data_lane, ptr(h->X_DEFAULT_ADDR));
}
} else {
if (offset == 0) {
h->ld1(data_lanes, Xbyak_aarch64::ptr(ptr_reg));
} else {
h->add_imm(h->X_DEFAULT_ADDR, ptr_reg, offset, h->X_TMP_0);
h->ld1(data_lanes, Xbyak_aarch64::ptr(h->X_DEFAULT_ADDR));
}
}
}
} // namespace utils

template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
void jit_uni_eltwise_generic<isa>::load_vector(const TReg& data,
const XReg& ptr_reg,
Expand All @@ -283,16 +309,7 @@ void jit_uni_eltwise_generic<isa>::load_vector(const TReg& data,
const int32_t ptr_offset) {
switch (src_prc) {
case ov::element::f16: {
if (broadcast) {
if (ptr_offset == 0) {
ld1r(data.h, ptr(ptr_reg));
} else {
add_imm(ptr_reg, ptr_reg, ptr_offset, X_DEFAULT_ADDR);
ld1r(data.h, ptr(ptr_reg));
}
} else {
ldr(Xbyak_aarch64::DReg(data.getIdx()), Xbyak_aarch64::ptr(ptr_reg, ptr_offset));
}
utils::load_vector(data.h, data.h4, ptr_reg, ptr_offset, broadcast, this);
break;
}
case ov::element::f32:
Expand All @@ -305,8 +322,13 @@ void jit_uni_eltwise_generic<isa>::load_vector(const TReg& data,
}
break;
}
case ov::element::i8:
case ov::element::u8: {
utils::load_vector(data.b, data.s, ptr_reg, ptr_offset, broadcast, this);
break;
}
default: {
IE_THROW(Unexpected) << "src_prc " << src_prc << " is not supported";;
IE_THROW(Unexpected) << "src_prc " << src_prc << " is not supported";
}
}

Expand All @@ -322,10 +344,22 @@ void jit_uni_eltwise_generic<isa>::load_vector(const TReg& data,
scvtf(data.s, data.s);
break;
}
case ov::element::i8: {
sshll(data.h8, data.b8, 0);
sshll(data.s4, data.h4, 0);
scvtf(data.s, data.s);
break;
}
case ov::element::u32: {
ucvtf(data.s, data.s);
break;
}
case ov::element::u8: {
ushll(data.h8, data.b8, 0);
ushll(data.s4, data.h4, 0);
ucvtf(data.s, data.s);
break;
}
default:
IE_THROW(Unexpected) << "src_prc " << src_prc << " is not supported";;
}
Expand Down Expand Up @@ -353,6 +387,24 @@ void jit_uni_eltwise_generic<isa>::load_scalar(const SReg& data,
ldr(data, Xbyak_aarch64::ptr(ptr, ptr_offset));
break;
}
case ov::element::i8: {
ldr(Xbyak_aarch64::BReg(data.getIdx()), Xbyak_aarch64::ptr(ptr, ptr_offset));

// scalar is loaded, operates with vector
TReg vec(data.getIdx());
sshll(vec.h8, vec.b8, 0);
sshll(vec.s4, vec.h4, 0);
break;
}
case ov::element::u8: {
ldr(Xbyak_aarch64::BReg(data.getIdx()), Xbyak_aarch64::ptr(ptr, ptr_offset));

// scalar is loaded, operates with vector
TReg vec(data.getIdx());
ushll(vec.h8, vec.b8, 0);
ushll(vec.s4, vec.h4, 0);
break;
}
default: {
IE_THROW(Unexpected) << "dst_prc " << dst_prc << " is not supported";;
}
Expand All @@ -366,11 +418,13 @@ void jit_uni_eltwise_generic<isa>::load_scalar(const SReg& data,
fcvt(Xbyak_aarch64::SReg(data.getIdx()), Xbyak_aarch64::HReg(data.getIdx()));
break;
}
case ov::element::i32: {
case ov::element::i32:
case ov::element::i8: {
scvtf(Xbyak_aarch64::SReg(data.getIdx()), Xbyak_aarch64::SReg(data.getIdx()));
break;
}
case ov::element::u32: {
case ov::element::u32:
case ov::element::u8: {
ucvtf(Xbyak_aarch64::SReg(data.getIdx()), Xbyak_aarch64::SReg(data.getIdx()));
break;
}
Expand Down Expand Up @@ -406,6 +460,18 @@ void jit_uni_eltwise_generic<isa>::store_vector(const XReg& ptr,
fcvtnu(data.s, data.s);
break;
}
case ov::element::i8: {
fcvtns(data.s, data.s);
xtn(data.h4, data.s4);
xtn(data.b8, data.h8);
break;
}
case ov::element::u8: {
fcvtnu(data.s, data.s);
xtn(data.h4, data.s4);
xtn(data.b8, data.h8);
break;
}
default: {
IE_THROW(Unexpected) << "src_prc " << src_prc << " is not supported";;
}
Expand All @@ -429,6 +495,11 @@ void jit_uni_eltwise_generic<isa>::store_vector(const XReg& ptr,
str(Xbyak_aarch64::QReg(data.getIdx()), Xbyak_aarch64::ptr(ptr, ptr_offset));
break;
}
case ov::element::i8:
case ov::element::u8: {
str(Xbyak_aarch64::SReg(data.getIdx()), Xbyak_aarch64::ptr(ptr, ptr_offset));
break;
}
default: {
IE_THROW(Unexpected) << "dst_prc " << dst_prc << " is not supported";;
}
Expand Down Expand Up @@ -457,6 +528,20 @@ void jit_uni_eltwise_generic<isa>::store_scalar(const XReg& ptr,
fcvtnu(data, data);
break;
}
case ov::element::i8: {
TReg vec_data(data.getIdx());
fcvtns(vec_data.s, vec_data.s);
xtn(vec_data.h4, vec_data.s4);
xtn(vec_data.b8, vec_data.h8);
break;
}
case ov::element::u8: {
TReg vec_data(data.getIdx());
fcvtnu(vec_data.s, vec_data.s);
xtn(vec_data.h4, vec_data.s4);
xtn(vec_data.b8, vec_data.h8);
break;
}
default: {
IE_THROW(Unexpected) << "src_prc " << src_prc << " is not supported";;
}
Expand All @@ -480,6 +565,11 @@ void jit_uni_eltwise_generic<isa>::store_scalar(const XReg& ptr,
str(data, Xbyak_aarch64::ptr(ptr, ptr_offset));
break;
}
case ov::element::i8:
case ov::element::u8: {
str(Xbyak_aarch64::BReg(data.getIdx()), Xbyak_aarch64::ptr(ptr, ptr_offset));
break;
}
default: {
IE_THROW(Unexpected) << "dst_prc " << src_prc << " is not supported";;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,18 @@ ov::Tensor EltwiseLayerCPUTest::generate_eltwise_input(const ov::element::Type&
} else {
switch (type) {
case ov::element::i8:
params = gen_params(INT8_MAX, INT8_MIN);
if (adopt_intervals) {
params = gen_params(11 * 2, -11);
} else {
params = gen_params(INT8_MAX, INT8_MIN);
}
break;
case ov::element::u8:
params = gen_params(UINT8_MAX, 0);
if (adopt_intervals) {
params = gen_params(15, 0);
} else {
params = gen_params(UINT8_MAX, 0);
}
break;
case ov::element::i16:
params = gen_params(INT16_MAX, INT16_MIN);
Expand Down Expand Up @@ -109,7 +117,8 @@ void EltwiseLayerCPUTest::generate_inputs(const std::vector<ov::Shape>& targetIn
inputs.insert({funcInput.get_node_shared_ptr(), generate_eltwise_input(
funcInput.get_element_type(),
targetInputStaticShapes[i],
(funcInput.get_element_type() == element::i32) || (funcInput.get_element_type() == element::u32))});
(funcInput.get_element_type() == element::i32) || (funcInput.get_element_type() == element::u32) ||
(funcInput.get_element_type() == element::i8) || (funcInput.get_element_type() == element::u8))});
}
}

Expand Down Expand Up @@ -199,7 +208,11 @@ void EltwiseLayerCPUTest::SetUp() {
}
}

auto data_tensor = generate_eltwise_input(netType, shape, (netType == element::i32) || (netType == element::u32));
auto data_tensor = generate_eltwise_input(
netType,
shape,
(netType == element::i32) || (netType == element::u32) ||
(netType == element::i8) || (netType == element::u8));
if ((netType == ElementType::i8) || (netType == ElementType::u8)) {
auto data_ptr = reinterpret_cast<uint8_t*>(data_tensor.data());
std::vector<uint8_t> data(data_ptr, data_ptr + ov::shape_size(shape));
Expand Down Expand Up @@ -272,8 +285,11 @@ const std::vector<ElementType>& netType() {

const std::vector<ElementType>& netTypeJit() {
static const std::vector<ElementType> netType = {
ElementType::f16,
ElementType::i32,
ElementType::f32};
ElementType::f32,
ElementType::i8,
ElementType::u8};
return netType;
}

Expand Down

0 comments on commit 1e881d1

Please sign in to comment.