diff --git a/include/mkldnn.hpp b/include/mkldnn.hpp index 2d3ef0dd56c..f5996d6a52b 100644 --- a/include/mkldnn.hpp +++ b/include/mkldnn.hpp @@ -278,6 +278,7 @@ enum algorithm { eltwise_clamp = mkldnn_eltwise_clamp, eltwise_not = mkldnn_eltwise_not, eltwise_swish = mkldnn_eltwise_swish, + eltwise_mish = mkldnn_eltwise_mish, depthwise_scale_shift = mkldnn_depthwise_scale_shift, depthwise_prelu = mkldnn_depthwise_prelu, lrn_across_channels = mkldnn_lrn_across_channels, diff --git a/include/mkldnn_types.h b/include/mkldnn_types.h index 1218184e98f..67766901573 100644 --- a/include/mkldnn_types.h +++ b/include/mkldnn_types.h @@ -617,6 +617,8 @@ typedef enum { mkldnn_eltwise_not = 0xef, /** Eltwise: swish */ mkldnn_eltwise_swish = 0xff, + /** Eltwise: mish */ + mkldnn_eltwise_mish = 0x1f0, /** Max pooling */ mkldnn_pooling_max = 0x1ff, /** Average pooling include padding */ diff --git a/src/common/c_types_map.hpp b/src/common/c_types_map.hpp index 9352da08e5a..8345f3bafa8 100644 --- a/src/common/c_types_map.hpp +++ b/src/common/c_types_map.hpp @@ -80,6 +80,7 @@ namespace alg_kind { const alg_kind_t eltwise_clamp = mkldnn_eltwise_clamp; const alg_kind_t eltwise_not = mkldnn_eltwise_not; const alg_kind_t eltwise_swish = mkldnn_eltwise_swish; + const alg_kind_t eltwise_mish = mkldnn_eltwise_mish; const alg_kind_t depthwise_scale_shift = mkldnn_depthwise_scale_shift; const alg_kind_t depthwise_prelu = mkldnn_depthwise_prelu; const alg_kind_t pooling_max = mkldnn_pooling_max; diff --git a/src/common/eltwise.cpp b/src/common/eltwise.cpp index 126dfd16d75..1089b5e7de6 100644 --- a/src/common/eltwise.cpp +++ b/src/common/eltwise.cpp @@ -39,7 +39,8 @@ status_t eltwise_desc_init(eltwise_desc_t *eltwise_desc, prop_kind_t prop_kind, && one_of(alg_kind, eltwise_relu, eltwise_tanh, eltwise_elu, eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear, eltwise_bounded_relu, eltwise_soft_relu, eltwise_logistic, - eltwise_exp, eltwise_gelu, eltwise_clamp, eltwise_not, eltwise_swish) + eltwise_exp, eltwise_gelu, eltwise_clamp, eltwise_not, eltwise_swish, + eltwise_mish) && IMPLICATION(prop_kind == backward_data, diff_data_desc != nullptr); if (!args_ok) return invalid_arguments; diff --git a/src/common/math_utils.hpp b/src/common/math_utils.hpp index b0de3101f3a..dc0cb442ba9 100644 --- a/src/common/math_utils.hpp +++ b/src/common/math_utils.hpp @@ -294,6 +294,13 @@ inline U swish_bwd(T dd, T s, A alpha) { return dd * (v + s * alpha * v * (1 - v)); } +template ::type> +inline U mish_fwd(T s) { + float v = ::log1pf(::expf((float)s)); + return (U)(s * ::tanhf(v)); +} + template ::type> inline U scale_shift_fwd(T s_val, A w_val, A b_val) { diff --git a/src/common/mkldnn_debug.cpp b/src/common/mkldnn_debug.cpp index 71e2c794ec4..f9a8ff8325e 100644 --- a/src/common/mkldnn_debug.cpp +++ b/src/common/mkldnn_debug.cpp @@ -432,6 +432,7 @@ const char *mkldnn_alg_kind2str(mkldnn_alg_kind_t v) { if (v == mkldnn_eltwise_clamp) return "eltwise_clamp"; if (v == mkldnn_eltwise_not) return "eltwise_not"; if (v == mkldnn_eltwise_swish) return "eltwise_swish"; + if (v == mkldnn_eltwise_mish) return "eltwise_mish"; if (v == mkldnn_pooling_max) return "pooling_max"; if (v == mkldnn_pooling_avg_include_padding) return "pooling_avg_include_padding"; if (v == mkldnn_pooling_avg_exclude_padding) return "pooling_avg_exclude_padding"; diff --git a/src/common/primitive_attr.cpp b/src/common/primitive_attr.cpp index e251a97d6ff..cc893ed25a6 100644 --- a/src/common/primitive_attr.cpp +++ b/src/common/primitive_attr.cpp @@ -93,7 +93,8 @@ status_t post_ops_t::append_eltwise(float scale, alg_kind_t alg, float alpha, bool known_alg = one_of(alg, eltwise_relu, eltwise_tanh, eltwise_elu, eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear, eltwise_bounded_relu, eltwise_soft_relu, eltwise_logistic, - eltwise_exp, eltwise_gelu, eltwise_clamp, eltwise_not, eltwise_swish); + eltwise_exp, eltwise_gelu, eltwise_clamp, eltwise_not, eltwise_swish, + eltwise_mish); if (!known_alg) return invalid_arguments; diff --git a/src/cpu/jit_uni_eltwise.cpp b/src/cpu/jit_uni_eltwise.cpp index f54154c5c54..579d626ef88 100644 --- a/src/cpu/jit_uni_eltwise.cpp +++ b/src/cpu/jit_uni_eltwise.cpp @@ -608,6 +608,117 @@ void jit_uni_eltwise_injector_f32::swish_compute_vector( h->uni_vmulps(vmm_src, vmm_src, vmm_aux0); } +template +void jit_uni_eltwise_injector_f32::mish_compute_vector( + const Vmm &vmm_src) { + // Save src data on stack for later usage + h->sub(h->rsp, vlen); + h->uni_vmovups(h->ptr[h->rsp], vmm_src); + + // soft_relu - ln(1+exp(x)) + // duplicate src + h->uni_vmovups(vmm_aux2, vmm_src); + + h->uni_vminps(vmm_src, vmm_src, table_val(25)); + h->uni_vmaxps(vmm_src, vmm_src, table_val(26)); + h->uni_vmovups(vmm_aux1, vmm_src); + // calculate exp(x) + // fx = x * log2ef + 0.5 + h->uni_vmulps(vmm_src, vmm_src, table_val(2)); + h->uni_vaddps(vmm_src, vmm_src, table_val(1)); + + // tmp = floorf(fx) + h->uni_vroundps(vmm_aux0, vmm_src, _op_floor); + + // keep fx for further computations + h->uni_vmovups(vmm_src, vmm_aux0); //vmm_src = fx + // calculation fx * ln2 + h->uni_vmulps(vmm_aux0, vmm_aux0, table_val(3)); + // x = x - fx * ln2 + h->uni_vsubps(vmm_aux1, vmm_aux1, vmm_aux0); + // y = p5 + h->uni_vmovups(vmm_aux3, table_val(9)); + // y = y * x + p4 + h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, table_val(8)); + // y = y * x + p3 + h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, table_val(7)); + // y = y * x + p2 + h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, table_val(6)); + // y = y * x + p1 + h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, table_val(0)); + // y = y * x + p0 + h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, table_val(5)); + + // compute 2^(-n) + if (isa == avx512_common) { + h->vmulps(vmm_aux1, vmm_src, table_val(27)); + h->vcvtps2dq(vmm_aux1, vmm_aux1); + } else { + h->uni_vcvtps2dq(vmm_aux1, vmm_src); + h->uni_vpsignd(vmm_aux1, vmm_aux1, table_val(27)); + } + + h->uni_vpaddd(vmm_aux1, vmm_aux1, table_val(4)); + h->uni_vpslld(vmm_aux1, vmm_aux1, 23); //vmm_aux1 = 2^-fx + // calculate ln(1 + y) + h->uni_vaddps(vmm_aux3, vmm_aux3, vmm_aux1); + // x = y; y is free; keep x for further computations + h->uni_vmovups(vmm_src, vmm_aux3); + // frexp() + h->uni_vpsrld(vmm_src, vmm_src, 23); + h->uni_vcvtdq2ps(vmm_src, vmm_src); + // got n. where n is x = 2^n * y. y = 0.5 .. 1 + h->uni_vsubps(vmm_src, vmm_src, table_val(28)); + + h->uni_vandps(vmm_aux3, vmm_aux3, table_val(29)); + // got y. (mantisa) 0.5 < y < 1 + h->uni_vorps(vmm_aux3, vmm_aux3, table_val(30)); + // y = y - 1 + h->uni_vsubps(vmm_aux3, vmm_aux3, table_val(0)); + // y = p8 + h->uni_vmovups(vmm_aux1, table_val(39)); + // y = y * x + p7 + h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(38)); + // y = y * x + p6 + h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(37)); + // y = y * x + p5 + h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(36)); + // y = y * x + p4 + h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(35)); + // y = y * x + p3 + h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(34)); + // y = y * x + p2 + h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(33)); + // y = y * x + p1 + h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(32)); + // y = y * x + p0 ; p0 = 0 + h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(31)); + //calculate ln(2) * n + h->uni_vmulps(vmm_src, vmm_src, table_val(3)); + h->uni_vaddps(vmm_aux1, vmm_aux1, vmm_src); + h->uni_vaddps(vmm_aux1, vmm_aux1, vmm_aux0); + + // get vmm_mask = src > max logf + h->uni_vmovups(vmm_mask, vmm_aux2); + if (isa == avx512_common) { + // y = (x < max log f) ? soft_relu(x) : x + h->vcmpps(k_mask, vmm_mask, table_val(25), _cmp_nle_us); + h->vblendmps(vmm_aux1 | k_mask, vmm_aux1, vmm_aux2); + } else { + // y = (x < max log f) ? soft_relu(x) : x + h->uni_vcmpgtps(vmm_mask, vmm_mask, table_val(25)); + h->uni_vblendvps(vmm_aux1, vmm_aux1, vmm_aux2, vmm_mask); + } + h->uni_vmovups(vmm_src, vmm_aux1); + + // tanh(ln(1+exp(x))) + tanh_compute_vector(vmm_src); + // x*tanh(ln(1+exp(x))) + h->uni_vmovups(vmm_aux0, h->ptr[h->rsp]); + h->add(h->rsp, vlen); + h->uni_vmulps(vmm_src, vmm_src, vmm_aux0); +} + template void jit_uni_eltwise_injector_f32::relu_prepare_table() { for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(float2int(alpha_)); @@ -725,6 +836,63 @@ void jit_uni_eltwise_injector_f32::clamp_prepare_table() { for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(float2int(beta_)); } +template +void jit_uni_eltwise_injector_f32::mish_prepare_table() { + const unsigned int cvals[] = { + 0x3f800000, // [0] 1.0f + 0x3f000000, // [1] 0.5f + 0x3fb8aa3b, // [2] log2ef = 1.44269502f + 0x3f317218, // [3] ln2f = 0.69314718f + 0x0000007f, // [4] 0x7f + // exp(x) polynom + 0x3f800001, // [5] p0 = 1.0000001f + 0x3efffe85, // [6] p2 = 0.4999887f + 0x3e2aaa3e, // [7] p3 = 0.16666505f + 0x3d2bb1b1, // [8] p4 = 0.041917507f + 0x3c091ec1, // [9] p5 = 0.008369149f + 0x42b17218, //[10] logf(FLT_MAX) + 0xc2aeac50, //[11] logf(FLT_MIN) + // tanh(x) constants, + 0x80000000, //[12] mask to extract sign + 0x39ddb3d7, //[13] arg below which tanh(x) = x + 0x3f0c9f54, //[14] arg below which pol approx is valid + 0x41102cb4, //[15] arg after which tanh(x) = 1 + 0xc0000000, //[16] -2.0f + 0x7fffffff, //[17] mask to make positive + // tanh pol approx + 0x3f7fffff, //[18] p0 + 0xbeaaa9cf, //[19] p1 + 0x3e085f1f, //[20] p2 + 0xbd572bda, //[21] p3 + 0x3c84fd08, //[22] p4 + // gelu approx constants + 0x3d372713, //[23] 0.044715 + 0x3f4c4229, //[24] sqrt(2/pi) + // TODO: update values [24] and [25] from comments as they are more precise + 0x42b0c0a5, //[25] max logf = 88.3762589f //0x42b17218, //[24] logf(FLT_MAX) + 0xc1766666, //[26] min logf = -14.5f //0xc2aeac50, //[25] logf(FLT_MIN) + // + 0xbf800000, //[27] is required for sign changing + 0x42fc0000, //[28] 126 + 0x807fffff, //[29] and with (to get 0.5 * mantissa) + 0x3f000000, //[30] or with (to get 0.5 * mantissa) + // ln(1 + x) polynomial + 0xb2b4637d, //[31] p0 = 0.0000000244f + 0x3f7fff8e, //[32] p1 = 0.9999976971f + 0xbf001759, //[33] p2 = -0.5002478215f + 0x3ea70608, //[34] p3 = 0.3272714505f + 0xbea3d7bf, //[35] p4 = -0.3153830071f + 0xbe361d04, //[36] p5 = -0.1701777461f + 0xbfa8f1e6, //[37] p6 = -1.3254635147f + 0xbfe1e812, //[38] p7 = -1.7971917960f + 0xbfc4d30e, //[39] p8 = -1.5652673123f + }; + + for (size_t i = 0; i < sizeof(cvals) / sizeof(cvals[0]); ++i) { + for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(cvals[i]); + } +} + template int jit_uni_eltwise_injector_f32::aux_vecs_count(alg_kind_t alg_) { switch (alg_) { @@ -742,6 +910,7 @@ int jit_uni_eltwise_injector_f32::aux_vecs_count(alg_kind_t alg_) { case alg_kind::eltwise_gelu: return 5; case alg_kind::eltwise_clamp: return 0; case alg_kind::eltwise_swish: return 4; + case alg_kind::eltwise_mish: return 5; default: assert(!"unsupported eltwise algorithm"); } @@ -771,6 +940,7 @@ void jit_uni_eltwise_injector_f32::compute_body(size_t start_idx, case eltwise_gelu: gelu_compute_vector(Vmm(idx)); break; case eltwise_clamp: clamp_compute_vector(Vmm(idx)); break; case eltwise_swish: swish_compute_vector(Vmm(idx)); break; + case eltwise_mish: mish_compute_vector(Vmm(idx)); break; default: assert(!"unsupported eltwise algorithm"); } } @@ -812,6 +982,7 @@ void jit_uni_eltwise_injector_f32::prepare_table(bool gen_table) { case eltwise_bounded_relu: bounded_relu_prepare_table(); break; case eltwise_square: break; case eltwise_clamp: clamp_prepare_table(); break; + case eltwise_mish: mish_prepare_table(); break; default: assert(!"unsupported eltwise algorithm"); } } @@ -1124,7 +1295,8 @@ struct jit_uni_kernel_fwd_f32: public jit_uni_eltwise_kernel_f32, assert(utils::one_of(desc.alg_kind, eltwise_tanh, eltwise_elu, eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear, eltwise_bounded_relu, eltwise_soft_relu, eltwise_logistic, - eltwise_exp, eltwise_gelu, eltwise_clamp, eltwise_swish)); + eltwise_exp, eltwise_gelu, eltwise_clamp, eltwise_swish, + eltwise_mish)); preamble(); @@ -1289,7 +1461,8 @@ status_t jit_uni_eltwise_fwd_t::pd_t::init() { && utils::one_of(desc()->alg_kind, eltwise_relu, eltwise_tanh, eltwise_elu, eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear, eltwise_bounded_relu, eltwise_soft_relu, - eltwise_logistic, eltwise_exp, eltwise_gelu, eltwise_clamp, eltwise_swish) + eltwise_logistic, eltwise_exp, eltwise_gelu, eltwise_clamp, + eltwise_swish, eltwise_mish) && memory_desc_wrapper(src_pd()).is_dense(true) && IMPLICATION(!memory_desc_wrapper(src_pd()).is_dense(false), math::eltwise_fwd_preserves_zero(desc()->alg_kind, true)) diff --git a/src/cpu/jit_uni_eltwise.hpp b/src/cpu/jit_uni_eltwise.hpp index b5227134afc..d42469df833 100644 --- a/src/cpu/jit_uni_eltwise.hpp +++ b/src/cpu/jit_uni_eltwise.hpp @@ -47,7 +47,8 @@ struct jit_uni_eltwise_injector_f32 { assert(utils::one_of(alg_, eltwise_relu, eltwise_tanh, eltwise_elu, eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear, eltwise_bounded_relu, eltwise_soft_relu, eltwise_logistic, - eltwise_exp, eltwise_gelu, eltwise_clamp, eltwise_swish)); + eltwise_exp, eltwise_gelu, eltwise_clamp, eltwise_swish, + eltwise_mish)); } // note that eltwise.scale is ignored @@ -120,6 +121,7 @@ struct jit_uni_eltwise_injector_f32 { void gelu_compute_vector(const Vmm &vmm_src); void clamp_compute_vector(const Vmm &vmm_src); void swish_compute_vector(const Vmm &vmm_src); + void mish_compute_vector(const Vmm &vmm_src); void relu_prepare_table(); void elu_prepare_table(); @@ -129,6 +131,7 @@ struct jit_uni_eltwise_injector_f32 { void linear_prepare_table(); void bounded_relu_prepare_table(); void clamp_prepare_table(); + void mish_prepare_table(); }; struct jit_uni_eltwise_kernel_f32; diff --git a/src/cpu/ref_eltwise.cpp b/src/cpu/ref_eltwise.cpp index b2ade1a09d4..c892cdd167a 100644 --- a/src/cpu/ref_eltwise.cpp +++ b/src/cpu/ref_eltwise.cpp @@ -37,7 +37,8 @@ ref_eltwise_scalar_fwd_t::ref_eltwise_scalar_fwd_t(alg_kind_t alg, float alpha, assert(utils::one_of(alg_, eltwise_relu, eltwise_tanh, eltwise_elu, eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear, eltwise_bounded_relu, eltwise_soft_relu, eltwise_logistic, - eltwise_exp, eltwise_gelu, eltwise_clamp, eltwise_not, eltwise_swish)); + eltwise_exp, eltwise_gelu, eltwise_clamp, eltwise_not, eltwise_swish, + eltwise_mish)); } ref_eltwise_scalar_fwd_t::ref_eltwise_scalar_fwd_t( @@ -61,6 +62,7 @@ float ref_eltwise_scalar_fwd_t::compute_scalar(float s) { case eltwise_clamp: return clamp_fwd(s, alpha_, beta_); case eltwise_not: return not_fwd(s); case eltwise_swish: return swish_fwd(s, alpha_); + case eltwise_mish: return mish_fwd(s); default: assert(!"unknown eltwise alg_kind"); } @@ -96,6 +98,7 @@ void ref_eltwise_fwd_t::execute_forward_nCspBc_padded() const { case eltwise_clamp: d = clamp_fwd(s, alpha, beta); break; case eltwise_not: d = not_fwd(s); break; case eltwise_swish: d = swish_fwd(s, alpha); break; + case eltwise_mish: d = mish_fwd(s); break; default: assert(!"unknown eltwise alg_kind"); } }; @@ -205,6 +208,7 @@ void ref_eltwise_fwd_t::execute_forward_generic() const { case eltwise_clamp: d = clamp_fwd(s, alpha, beta); break; case eltwise_not: d = not_fwd(s); break; case eltwise_swish: d = swish_fwd(s, alpha); break; + case eltwise_mish: d = mish_fwd(s); break; default: assert(!"unknown eltwise alg_kind"); } }); @@ -299,6 +303,7 @@ void ref_eltwise_fwd_t::execute_forward_dense() const { case eltwise_clamp: d = clamp_fwd(s, alpha, beta); break; case eltwise_not: d = not_fwd(s); break; case eltwise_swish: d = swish_fwd(s, alpha); break; + case eltwise_mish: d = mish_fwd(s); break; default: assert(!"unknown eltwise alg_kind"); } });