Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/mkldnn.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,7 @@ enum algorithm {
eltwise_clamp = mkldnn_eltwise_clamp,
eltwise_not = mkldnn_eltwise_not,
eltwise_swish = mkldnn_eltwise_swish,
eltwise_mish = mkldnn_eltwise_mish,
depthwise_scale_shift = mkldnn_depthwise_scale_shift,
depthwise_prelu = mkldnn_depthwise_prelu,
lrn_across_channels = mkldnn_lrn_across_channels,
Expand Down
2 changes: 2 additions & 0 deletions include/mkldnn_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -617,6 +617,8 @@ typedef enum {
mkldnn_eltwise_not = 0xef,
/** Eltwise: swish */
mkldnn_eltwise_swish = 0xff,
/** Eltwise: mish */
mkldnn_eltwise_mish = 0x1f0,
/** Max pooling */
mkldnn_pooling_max = 0x1ff,
/** Average pooling include padding */
Expand Down
1 change: 1 addition & 0 deletions src/common/c_types_map.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ namespace alg_kind {
const alg_kind_t eltwise_clamp = mkldnn_eltwise_clamp;
const alg_kind_t eltwise_not = mkldnn_eltwise_not;
const alg_kind_t eltwise_swish = mkldnn_eltwise_swish;
const alg_kind_t eltwise_mish = mkldnn_eltwise_mish;
const alg_kind_t depthwise_scale_shift = mkldnn_depthwise_scale_shift;
const alg_kind_t depthwise_prelu = mkldnn_depthwise_prelu;
const alg_kind_t pooling_max = mkldnn_pooling_max;
Expand Down
3 changes: 2 additions & 1 deletion src/common/eltwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ status_t eltwise_desc_init(eltwise_desc_t *eltwise_desc, prop_kind_t prop_kind,
&& one_of(alg_kind, eltwise_relu, eltwise_tanh, eltwise_elu,
eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear,
eltwise_bounded_relu, eltwise_soft_relu, eltwise_logistic,
eltwise_exp, eltwise_gelu, eltwise_clamp, eltwise_not, eltwise_swish)
eltwise_exp, eltwise_gelu, eltwise_clamp, eltwise_not, eltwise_swish,
eltwise_mish)
&& IMPLICATION(prop_kind == backward_data, diff_data_desc != nullptr);
if (!args_ok) return invalid_arguments;

Expand Down
7 changes: 7 additions & 0 deletions src/common/math_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,13 @@ inline U swish_bwd(T dd, T s, A alpha) {
return dd * (v + s * alpha * v * (1 - v));
}

template <typename T,
typename U = typename utils::remove_reference<T>::type>
inline U mish_fwd(T s) {
float v = ::log1pf(::expf((float)s));
return (U)(s * ::tanhf(v));
}

template <typename T, typename A,
typename U = typename utils::remove_reference<T>::type>
inline U scale_shift_fwd(T s_val, A w_val, A b_val) {
Expand Down
1 change: 1 addition & 0 deletions src/common/mkldnn_debug.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,7 @@ const char *mkldnn_alg_kind2str(mkldnn_alg_kind_t v) {
if (v == mkldnn_eltwise_clamp) return "eltwise_clamp";
if (v == mkldnn_eltwise_not) return "eltwise_not";
if (v == mkldnn_eltwise_swish) return "eltwise_swish";
if (v == mkldnn_eltwise_mish) return "eltwise_mish";
if (v == mkldnn_pooling_max) return "pooling_max";
if (v == mkldnn_pooling_avg_include_padding) return "pooling_avg_include_padding";
if (v == mkldnn_pooling_avg_exclude_padding) return "pooling_avg_exclude_padding";
Expand Down
3 changes: 2 additions & 1 deletion src/common/primitive_attr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,8 @@ status_t post_ops_t::append_eltwise(float scale, alg_kind_t alg, float alpha,
bool known_alg = one_of(alg, eltwise_relu, eltwise_tanh, eltwise_elu,
eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear,
eltwise_bounded_relu, eltwise_soft_relu, eltwise_logistic,
eltwise_exp, eltwise_gelu, eltwise_clamp, eltwise_not, eltwise_swish);
eltwise_exp, eltwise_gelu, eltwise_clamp, eltwise_not, eltwise_swish,
eltwise_mish);
if (!known_alg)
return invalid_arguments;

Expand Down
177 changes: 175 additions & 2 deletions src/cpu/jit_uni_eltwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -608,6 +608,117 @@ void jit_uni_eltwise_injector_f32<isa>::swish_compute_vector(
h->uni_vmulps(vmm_src, vmm_src, vmm_aux0);
}

template <cpu_isa_t isa>
void jit_uni_eltwise_injector_f32<isa>::mish_compute_vector(
const Vmm &vmm_src) {
// Save src data on stack for later usage
h->sub(h->rsp, vlen);
h->uni_vmovups(h->ptr[h->rsp], vmm_src);

// soft_relu - ln(1+exp(x))
// duplicate src
h->uni_vmovups(vmm_aux2, vmm_src);

h->uni_vminps(vmm_src, vmm_src, table_val(25));
h->uni_vmaxps(vmm_src, vmm_src, table_val(26));
h->uni_vmovups(vmm_aux1, vmm_src);
// calculate exp(x)
// fx = x * log2ef + 0.5
h->uni_vmulps(vmm_src, vmm_src, table_val(2));
h->uni_vaddps(vmm_src, vmm_src, table_val(1));

// tmp = floorf(fx)
h->uni_vroundps(vmm_aux0, vmm_src, _op_floor);

// keep fx for further computations
h->uni_vmovups(vmm_src, vmm_aux0); //vmm_src = fx
// calculation fx * ln2
h->uni_vmulps(vmm_aux0, vmm_aux0, table_val(3));
// x = x - fx * ln2
h->uni_vsubps(vmm_aux1, vmm_aux1, vmm_aux0);
// y = p5
h->uni_vmovups(vmm_aux3, table_val(9));
// y = y * x + p4
h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, table_val(8));
// y = y * x + p3
h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, table_val(7));
// y = y * x + p2
h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, table_val(6));
// y = y * x + p1
h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, table_val(0));
// y = y * x + p0
h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, table_val(5));

// compute 2^(-n)
if (isa == avx512_common) {
h->vmulps(vmm_aux1, vmm_src, table_val(27));
h->vcvtps2dq(vmm_aux1, vmm_aux1);
} else {
h->uni_vcvtps2dq(vmm_aux1, vmm_src);
h->uni_vpsignd(vmm_aux1, vmm_aux1, table_val(27));
}

h->uni_vpaddd(vmm_aux1, vmm_aux1, table_val(4));
h->uni_vpslld(vmm_aux1, vmm_aux1, 23); //vmm_aux1 = 2^-fx
// calculate ln(1 + y)
h->uni_vaddps(vmm_aux3, vmm_aux3, vmm_aux1);
// x = y; y is free; keep x for further computations
h->uni_vmovups(vmm_src, vmm_aux3);
// frexp()
h->uni_vpsrld(vmm_src, vmm_src, 23);
h->uni_vcvtdq2ps(vmm_src, vmm_src);
// got n. where n is x = 2^n * y. y = 0.5 .. 1
h->uni_vsubps(vmm_src, vmm_src, table_val(28));

h->uni_vandps(vmm_aux3, vmm_aux3, table_val(29));
// got y. (mantisa) 0.5 < y < 1
h->uni_vorps(vmm_aux3, vmm_aux3, table_val(30));
// y = y - 1
h->uni_vsubps(vmm_aux3, vmm_aux3, table_val(0));
// y = p8
h->uni_vmovups(vmm_aux1, table_val(39));
// y = y * x + p7
h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(38));
// y = y * x + p6
h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(37));
// y = y * x + p5
h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(36));
// y = y * x + p4
h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(35));
// y = y * x + p3
h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(34));
// y = y * x + p2
h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(33));
// y = y * x + p1
h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(32));
// y = y * x + p0 ; p0 = 0
h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(31));
//calculate ln(2) * n
h->uni_vmulps(vmm_src, vmm_src, table_val(3));
h->uni_vaddps(vmm_aux1, vmm_aux1, vmm_src);
h->uni_vaddps(vmm_aux1, vmm_aux1, vmm_aux0);

// get vmm_mask = src > max logf
h->uni_vmovups(vmm_mask, vmm_aux2);
if (isa == avx512_common) {
// y = (x < max log f) ? soft_relu(x) : x
h->vcmpps(k_mask, vmm_mask, table_val(25), _cmp_nle_us);
h->vblendmps(vmm_aux1 | k_mask, vmm_aux1, vmm_aux2);
} else {
// y = (x < max log f) ? soft_relu(x) : x
h->uni_vcmpgtps(vmm_mask, vmm_mask, table_val(25));
h->uni_vblendvps(vmm_aux1, vmm_aux1, vmm_aux2, vmm_mask);
}
h->uni_vmovups(vmm_src, vmm_aux1);

// tanh(ln(1+exp(x)))
tanh_compute_vector(vmm_src);
// x*tanh(ln(1+exp(x)))
h->uni_vmovups(vmm_aux0, h->ptr[h->rsp]);
h->add(h->rsp, vlen);
h->uni_vmulps(vmm_src, vmm_src, vmm_aux0);
}

template <cpu_isa_t isa>
void jit_uni_eltwise_injector_f32<isa>::relu_prepare_table() {
for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(float2int(alpha_));
Expand Down Expand Up @@ -725,6 +836,63 @@ void jit_uni_eltwise_injector_f32<isa>::clamp_prepare_table() {
for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(float2int(beta_));
}

template <cpu_isa_t isa>
void jit_uni_eltwise_injector_f32<isa>::mish_prepare_table() {
const unsigned int cvals[] = {
0x3f800000, // [0] 1.0f
0x3f000000, // [1] 0.5f
0x3fb8aa3b, // [2] log2ef = 1.44269502f
0x3f317218, // [3] ln2f = 0.69314718f
0x0000007f, // [4] 0x7f
// exp(x) polynom
0x3f800001, // [5] p0 = 1.0000001f
0x3efffe85, // [6] p2 = 0.4999887f
0x3e2aaa3e, // [7] p3 = 0.16666505f
0x3d2bb1b1, // [8] p4 = 0.041917507f
0x3c091ec1, // [9] p5 = 0.008369149f
0x42b17218, //[10] logf(FLT_MAX)
0xc2aeac50, //[11] logf(FLT_MIN)
// tanh(x) constants,
0x80000000, //[12] mask to extract sign
0x39ddb3d7, //[13] arg below which tanh(x) = x
0x3f0c9f54, //[14] arg below which pol approx is valid
0x41102cb4, //[15] arg after which tanh(x) = 1
0xc0000000, //[16] -2.0f
0x7fffffff, //[17] mask to make positive
// tanh pol approx
0x3f7fffff, //[18] p0
0xbeaaa9cf, //[19] p1
0x3e085f1f, //[20] p2
0xbd572bda, //[21] p3
0x3c84fd08, //[22] p4
// gelu approx constants
0x3d372713, //[23] 0.044715
0x3f4c4229, //[24] sqrt(2/pi)
// TODO: update values [24] and [25] from comments as they are more precise
0x42b0c0a5, //[25] max logf = 88.3762589f //0x42b17218, //[24] logf(FLT_MAX)
0xc1766666, //[26] min logf = -14.5f //0xc2aeac50, //[25] logf(FLT_MIN)
//
0xbf800000, //[27] is required for sign changing
0x42fc0000, //[28] 126
0x807fffff, //[29] and with (to get 0.5 * mantissa)
0x3f000000, //[30] or with (to get 0.5 * mantissa)
// ln(1 + x) polynomial
0xb2b4637d, //[31] p0 = 0.0000000244f
0x3f7fff8e, //[32] p1 = 0.9999976971f
0xbf001759, //[33] p2 = -0.5002478215f
0x3ea70608, //[34] p3 = 0.3272714505f
0xbea3d7bf, //[35] p4 = -0.3153830071f
0xbe361d04, //[36] p5 = -0.1701777461f
0xbfa8f1e6, //[37] p6 = -1.3254635147f
0xbfe1e812, //[38] p7 = -1.7971917960f
0xbfc4d30e, //[39] p8 = -1.5652673123f
};

for (size_t i = 0; i < sizeof(cvals) / sizeof(cvals[0]); ++i) {
for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(cvals[i]);
}
}

template <cpu_isa_t isa>
int jit_uni_eltwise_injector_f32<isa>::aux_vecs_count(alg_kind_t alg_) {
switch (alg_) {
Expand All @@ -742,6 +910,7 @@ int jit_uni_eltwise_injector_f32<isa>::aux_vecs_count(alg_kind_t alg_) {
case alg_kind::eltwise_gelu: return 5;
case alg_kind::eltwise_clamp: return 0;
case alg_kind::eltwise_swish: return 4;
case alg_kind::eltwise_mish: return 5;
default: assert(!"unsupported eltwise algorithm");
}

Expand Down Expand Up @@ -771,6 +940,7 @@ void jit_uni_eltwise_injector_f32<isa>::compute_body(size_t start_idx,
case eltwise_gelu: gelu_compute_vector(Vmm(idx)); break;
case eltwise_clamp: clamp_compute_vector(Vmm(idx)); break;
case eltwise_swish: swish_compute_vector(Vmm(idx)); break;
case eltwise_mish: mish_compute_vector(Vmm(idx)); break;
default: assert(!"unsupported eltwise algorithm");
}
}
Expand Down Expand Up @@ -812,6 +982,7 @@ void jit_uni_eltwise_injector_f32<isa>::prepare_table(bool gen_table) {
case eltwise_bounded_relu: bounded_relu_prepare_table(); break;
case eltwise_square: break;
case eltwise_clamp: clamp_prepare_table(); break;
case eltwise_mish: mish_prepare_table(); break;
default: assert(!"unsupported eltwise algorithm");
}
}
Expand Down Expand Up @@ -1124,7 +1295,8 @@ struct jit_uni_kernel_fwd_f32: public jit_uni_eltwise_kernel_f32,
assert(utils::one_of(desc.alg_kind, eltwise_tanh, eltwise_elu,
eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear,
eltwise_bounded_relu, eltwise_soft_relu, eltwise_logistic,
eltwise_exp, eltwise_gelu, eltwise_clamp, eltwise_swish));
eltwise_exp, eltwise_gelu, eltwise_clamp, eltwise_swish,
eltwise_mish));

preamble();

Expand Down Expand Up @@ -1289,7 +1461,8 @@ status_t jit_uni_eltwise_fwd_t<isa, d_type>::pd_t::init() {
&& utils::one_of(desc()->alg_kind, eltwise_relu, eltwise_tanh,
eltwise_elu, eltwise_square, eltwise_abs, eltwise_sqrt,
eltwise_linear, eltwise_bounded_relu, eltwise_soft_relu,
eltwise_logistic, eltwise_exp, eltwise_gelu, eltwise_clamp, eltwise_swish)
eltwise_logistic, eltwise_exp, eltwise_gelu, eltwise_clamp,
eltwise_swish, eltwise_mish)
&& memory_desc_wrapper(src_pd()).is_dense(true)
&& IMPLICATION(!memory_desc_wrapper(src_pd()).is_dense(false),
math::eltwise_fwd_preserves_zero(desc()->alg_kind, true))
Expand Down
5 changes: 4 additions & 1 deletion src/cpu/jit_uni_eltwise.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ struct jit_uni_eltwise_injector_f32 {
assert(utils::one_of(alg_, eltwise_relu, eltwise_tanh, eltwise_elu,
eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear,
eltwise_bounded_relu, eltwise_soft_relu, eltwise_logistic,
eltwise_exp, eltwise_gelu, eltwise_clamp, eltwise_swish));
eltwise_exp, eltwise_gelu, eltwise_clamp, eltwise_swish,
eltwise_mish));
}

// note that eltwise.scale is ignored
Expand Down Expand Up @@ -120,6 +121,7 @@ struct jit_uni_eltwise_injector_f32 {
void gelu_compute_vector(const Vmm &vmm_src);
void clamp_compute_vector(const Vmm &vmm_src);
void swish_compute_vector(const Vmm &vmm_src);
void mish_compute_vector(const Vmm &vmm_src);

void relu_prepare_table();
void elu_prepare_table();
Expand All @@ -129,6 +131,7 @@ struct jit_uni_eltwise_injector_f32 {
void linear_prepare_table();
void bounded_relu_prepare_table();
void clamp_prepare_table();
void mish_prepare_table();
};

struct jit_uni_eltwise_kernel_f32;
Expand Down
7 changes: 6 additions & 1 deletion src/cpu/ref_eltwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ ref_eltwise_scalar_fwd_t::ref_eltwise_scalar_fwd_t(alg_kind_t alg, float alpha,
assert(utils::one_of(alg_, eltwise_relu, eltwise_tanh, eltwise_elu,
eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear,
eltwise_bounded_relu, eltwise_soft_relu, eltwise_logistic,
eltwise_exp, eltwise_gelu, eltwise_clamp, eltwise_not, eltwise_swish));
eltwise_exp, eltwise_gelu, eltwise_clamp, eltwise_not, eltwise_swish,
eltwise_mish));
}

ref_eltwise_scalar_fwd_t::ref_eltwise_scalar_fwd_t(
Expand All @@ -61,6 +62,7 @@ float ref_eltwise_scalar_fwd_t::compute_scalar(float s) {
case eltwise_clamp: return clamp_fwd(s, alpha_, beta_);
case eltwise_not: return not_fwd(s);
case eltwise_swish: return swish_fwd(s, alpha_);
case eltwise_mish: return mish_fwd(s);
default: assert(!"unknown eltwise alg_kind");
}

Expand Down Expand Up @@ -96,6 +98,7 @@ void ref_eltwise_fwd_t<data_type>::execute_forward_nCspBc_padded() const {
case eltwise_clamp: d = clamp_fwd(s, alpha, beta); break;
case eltwise_not: d = not_fwd(s); break;
case eltwise_swish: d = swish_fwd(s, alpha); break;
case eltwise_mish: d = mish_fwd(s); break;
default: assert(!"unknown eltwise alg_kind");
}
};
Expand Down Expand Up @@ -205,6 +208,7 @@ void ref_eltwise_fwd_t<data_type>::execute_forward_generic() const {
case eltwise_clamp: d = clamp_fwd(s, alpha, beta); break;
case eltwise_not: d = not_fwd(s); break;
case eltwise_swish: d = swish_fwd(s, alpha); break;
case eltwise_mish: d = mish_fwd(s); break;
default: assert(!"unknown eltwise alg_kind");
}
});
Expand Down Expand Up @@ -299,6 +303,7 @@ void ref_eltwise_fwd_t<data_type>::execute_forward_dense() const {
case eltwise_clamp: d = clamp_fwd(s, alpha, beta); break;
case eltwise_not: d = not_fwd(s); break;
case eltwise_swish: d = swish_fwd(s, alpha); break;
case eltwise_mish: d = mish_fwd(s); break;
default: assert(!"unknown eltwise alg_kind");
}
});
Expand Down