From f5783b56898577de08ee3c53ea0013ee009df4a2 Mon Sep 17 00:00:00 2001 From: LinHeLurking Date: Mon, 18 Jul 2022 20:59:06 +0800 Subject: [PATCH 01/22] A LayerNorm_x86 class mocking LayerNorm for tests; --- src/layer/x86/layernorm_x86.cpp | 27 +++++++++++++++++++++++++++ src/layer/x86/layernorm_x86.h | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 59 insertions(+) create mode 100644 src/layer/x86/layernorm_x86.cpp create mode 100644 src/layer/x86/layernorm_x86.h diff --git a/src/layer/x86/layernorm_x86.cpp b/src/layer/x86/layernorm_x86.cpp new file mode 100644 index 000000000000..6b563b1ba137 --- /dev/null +++ b/src/layer/x86/layernorm_x86.cpp @@ -0,0 +1,27 @@ +#include "layernorm_x86.h" + +#include + +namespace ncnn { + +LayerNorm_x86::LayerNorm_x86() + : LayerNorm() +{ +} + +int LayerNorm_x86::load_param(const ParamDict& pd) +{ + return LayerNorm::load_param(pd); +} + +int LayerNorm_x86::load_model(const ModelBin& mb) +{ + return LayerNorm::load_model(mb); +} + +int LayerNorm_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const +{ + return LayerNorm::forward_inplace(bottom_top_blob, opt); +} + +} // namespace ncnn diff --git a/src/layer/x86/layernorm_x86.h b/src/layer/x86/layernorm_x86.h new file mode 100644 index 000000000000..e736eaee0e7f --- /dev/null +++ b/src/layer/x86/layernorm_x86.h @@ -0,0 +1,32 @@ +#ifndef LAYER_LAYERNORM_X86_H +#define LAYER_LAYERNORM_X86_H + +#include "layernorm.h" + +namespace ncnn { + +class LayerNorm_x86 : virtual public LayerNorm +{ +public: + LayerNorm_x86(); + + virtual int load_param(const ParamDict& pd); + + virtual int load_model(const ModelBin& mb); + + virtual int forward_inplace(Mat& bottom_top_blob, const Option& opt) const; + +public: + // param + int affine_size; + float eps; + int affine; + + // model + Mat gamma_data; + Mat beta_data; +}; + +} // namespace ncnn + +#endif // LAYER_LAYERNORM_X86_H \ No newline at end of file From 7a94b1a4cd36569959bddeb03dad3cb60dc6cf01 Mon Sep 17 00:00:00 2001 From: LinHeLurking Date: Wed, 20 Jul 2022 14:30:36 +0000 Subject: [PATCH 02/22] All SIMD optimizations success wihout support_packing; Maybe there's something strange in packing layout; --- src/layer/x86/layernorm_x86.cpp | 302 +++++++++++++++++++++++++++++++- src/layer/x86/layernorm_x86.h | 16 +- 2 files changed, 299 insertions(+), 19 deletions(-) diff --git a/src/layer/x86/layernorm_x86.cpp b/src/layer/x86/layernorm_x86.cpp index 6b563b1ba137..cd69ae25d881 100644 --- a/src/layer/x86/layernorm_x86.cpp +++ b/src/layer/x86/layernorm_x86.cpp @@ -2,26 +2,316 @@ #include +#if __SSE2__ +#include +#if __AVX__ +#include +#endif // __AVX__ +#endif // __SSE2__ + +static NCNN_FORCEINLINE float fast_sum(float* ptr, int size) +{ + float sum = 0.0f; + int i = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + { + __m512 _sum = _mm512_setzero_ps(); + for (; i + 16 <= size; i += 16, ptr += 16) + { + __m512 _cur = _mm512_loadu_ps(ptr); + _sum = _mm512_add_ps(_sum, _cur); + } + sum += _mm512_reduce_add_ps(_sum); + } +#endif // __AVX512F__ + { + __m256 _sum = _mm256_setzero_ps(); + for (; i + 8 <= size; i += 8, ptr += 8) + { + __m256 _cur = _mm256_loadu_ps(ptr); + _sum = _mm256_add_ps(_sum, _cur); + } + sum += _sum[0] + _sum[1] + _sum[2] + _sum[3] + _sum[4] + _sum[5] + _sum[6] + _sum[7]; + } +#endif // __AVX__ + { + __m128 _sum = _mm_setzero_ps(); + for (; i + 4 <= size; i += 4, ptr += 4) + { + __m128 _cur = _mm_loadu_ps(ptr); + _sum = _mm_add_ps(_sum, _cur); + } + sum += _sum[0] + _sum[1] + _sum[2] + _sum[3]; + } +#endif // __SSE2__ + for (; i < size; 
++i, ++ptr) + { + sum += *ptr; + } + return sum; +} + +static NCNN_FORCEINLINE float fast_var(float* ptr, int size, float mean) +{ + float sq_sum = 0.0f; + int i = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + { + __m512 _mean = _mm512_set1_ps(mean); + __m512 _sq_sum = _mm512_setzero_ps(); + for (; i + 16 <= size; i += 16, ptr += 16) + { + __m512 _cur = _mm512_loadu_ps(ptr); + _cur = _mm512_sub_ps(_cur, _mean); + _cur = _mm512_mul_ps(_cur, _cur); + _sq_sum = _mm512_add_ps(_sq_sum, _cur); + } + sq_sum += _mm512_reduce_add_ps(_sq_sum); + } +#endif // __AVX512F__ + { + __m256 _mean = _mm256_set1_ps(mean); + __m256 _sq_sum = _mm256_setzero_ps(); + for (; i + 8 <= size; i += 8, ptr += 8) + { + __m256 _cur = _mm256_loadu_ps(ptr); + _cur = _mm256_sub_ps(_cur, _mean); + _cur = _mm256_mul_ps(_cur, _cur); + _sq_sum = _mm256_add_ps(_sq_sum, _cur); + } + sq_sum += _sq_sum[0] + _sq_sum[1] + _sq_sum[2] + _sq_sum[3] + _sq_sum[4] + _sq_sum[5] + _sq_sum[6] + _sq_sum[7]; + } +#endif // __AVX__ + { + __m128 _mean = _mm_set1_ps(mean); + __m128 _sq_sum = _mm_setzero_ps(); + for (; i + 4 <= size; i += 4, ptr += 4) + { + __m128 _cur = _mm_loadu_ps(ptr); + _cur = _mm_sub_ps(_cur, _mean); + _cur = _mm_mul_ps(_cur, _cur); + _sq_sum = _mm_add_ps(_sq_sum, _cur); + } + sq_sum += _sq_sum[0] + _sq_sum[1] + _sq_sum[2] + _sq_sum[3]; + } +#endif // __SSE2__ + for (; i < size; ++i, ++ptr) + { + float tmp = *ptr - mean; + sq_sum += tmp * tmp; + } + return sq_sum / size; +} + +static void fast_fmadd(float* ptr, float a, float b, int size) +{ + int i = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + { + // 512 bit FMA instructions are included in AVX512F. + __m512 _a = _mm512_set1_ps(a); + __m512 _b = _mm512_set1_ps(b); + for (; i + 16 <= size; i += 16, ptr += 16) + { + __m512 _cur = _mm512_loadu_ps(ptr); + _cur = _mm512_fmadd_ps(_cur, _a, _b); + _mm512_storeu_ps(ptr, _cur); + } + } +#endif // __AVX512F__ + { + // 256 bit FMA instructions are not included in AVX1 + __m256 _a = _mm256_set1_ps(a); + __m256 _b = _mm256_set1_ps(b); + for (; i + 8 <= size; i += 8, ptr += 8) + { + __m256 _cur = _mm256_loadu_ps(ptr); +#if __FMA__ + _cur = _mm256_fmadd_ps(_cur, _a, _b); +#else + _cur = _mm256_mul_ps(_cur, _a); + _cur = _mm256_add_ps(_cur, _b); +#endif + _mm256_storeu_ps(ptr, _cur); + } + } +#endif // __AVX__ + { + __m128 _a = _mm_set1_ps(a); + __m128 _b = _mm_set1_ps(b); + for (; i + 4 <= size; i += 4, ptr += 4) + { + __m128 _cur = _mm_loadu_ps(ptr); + _cur = _mm_mul_ps(_cur, _a); + _cur = _mm_add_ps(_cur, _b); + _mm_storeu_ps(ptr, _cur); + } + } +#endif // __SSE2__ + for (; i < size; ++i, ++ptr) + { + *ptr = (*ptr) * a + b; + } +} + namespace ncnn { LayerNorm_x86::LayerNorm_x86() - : LayerNorm() { +#if __SSE2__ + support_packing = false; +#endif // __SSE2__ } -int LayerNorm_x86::load_param(const ParamDict& pd) +void LayerNorm_x86::fast_fmadd_fmadd(float* ptr, float a, float b, int size) const { - return LayerNorm::load_param(pd); + int i = 0; + auto gamma = static_cast(gamma_data); + auto beta = static_cast(beta_data); +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + { + __m512 _a = _mm512_set1_ps(a); + __m512 _b = _mm512_set1_ps(b); + for (; i + 16 <= size; i += 16, ptr += 16, gamma += 16, beta += 16) + { + __m512 _cur = _mm512_loadu_ps(ptr); + __m512 _gamma = _mm512_loadu_ps(gamma); + __m512 _beta = _mm512_loadu_ps(beta); + _cur = _mm512_fmadd_ps(_cur, _a, _b); + _cur = _mm512_fmadd_ps(_cur, _gamma, _beta); + _mm512_storeu_ps(ptr, _cur); + } + } +#endif // __AVX512F__ + { + __m256 _a = _mm256_set1_ps(a); + __m256 _b = 
_mm256_set1_ps(b); + + for (; i + 8 <= size; i += 8, ptr += 8, gamma += 8, beta += 8) + { + __m256 _cur = _mm256_loadu_ps(ptr); + __m256 _gamma = _mm256_loadu_ps(gamma); + __m256 _beta = _mm256_loadu_ps(beta); +#if __FMA__ + _cur = _mm256_fmadd_ps(_cur, _a, _b); + _cur = _mm256_fmadd_ps(_cur, _gamma, _beta); +#else + _cur = _mm256_mul_ps(_cur, _a); + _cur = _mm256_add_ps(_cur, _b); + _cur = _mm256_mul_ps(_cur, _gamma); + _cur = _mm256_add_ps(_cur, _beta); +#endif + _mm256_storeu_ps(ptr, _cur); + } + } +#endif // __AVX__ + { + __m128 _a = _mm_set1_ps(a); + __m128 _b = _mm_set1_ps(b); + for (; i + 4 <= size; i += 4, ptr += 4, gamma += 4, beta += 4) + { + __m128 _cur = _mm_loadu_ps(ptr); + __m128 _gamma = _mm_loadu_ps(gamma); + __m128 _beta = _mm_loadu_ps(beta); + _cur = _mm_mul_ps(_cur, _a); + _cur = _mm_add_ps(_cur, _b); + _cur = _mm_mul_ps(_cur, _gamma); + _cur = _mm_add_ps(_cur, _beta); + _mm_storeu_ps(ptr, _cur); + } + } +#endif // __SSE2__ + for (; i < size; ++i, ++ptr, ++gamma, ++beta) + { + *ptr = ((*ptr) * a + b) * (*gamma) + (*beta); + } } -int LayerNorm_x86::load_model(const ModelBin& mb) +void LayerNorm_x86::fast_1d_layer_norm(float* ptr, int size) const { - return LayerNorm::load_model(mb); + // mean and var + float sum = fast_sum(ptr, size); + float mean = sum / size; + float var = fast_var(ptr, size, mean); + + float a = static_cast(1.0f / sqrt(var + eps)); + float b = -mean * a; + + if (affine) + { + fast_fmadd_fmadd(ptr, a, b, size); + } + else + { + fast_fmadd(ptr, a, b, size); + } } int LayerNorm_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const { - return LayerNorm::forward_inplace(bottom_top_blob, opt); + int dims = bottom_top_blob.dims; + + if (dims == 1) + { + int size = bottom_top_blob.w * bottom_top_blob.elempack; + float* ptr = bottom_top_blob; + fast_1d_layer_norm(ptr, size); + } + + if (dims == 2) + { + int w = bottom_top_blob.w; + int h = bottom_top_blob.h; + +#pragma omp parallel for num_threads(opt.num_threads) + for (int i = 0; i < h; ++i) + { + float* ptr = bottom_top_blob.row(i); + int size = w * bottom_top_blob.elempack; + fast_1d_layer_norm(ptr, size); + } + } + + if (dims == 3) + { + int w = bottom_top_blob.w; + int h = bottom_top_blob.h; + int channels = bottom_top_blob.c; + int size = w * h * bottom_top_blob.elempack; + + if (affine_size == w) + { + size = w * bottom_top_blob.elempack; + // #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + for (int i = 0; i < h; i++) + { + float* ptr = bottom_top_blob.channel(q).row(i); + fast_1d_layer_norm(ptr, size); + } + } + } + else // if (affine_size == size) + { + // #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + float* ptr = bottom_top_blob.channel(q); + fast_1d_layer_norm(ptr, size); + } + } + } + + return 0; } } // namespace ncnn diff --git a/src/layer/x86/layernorm_x86.h b/src/layer/x86/layernorm_x86.h index e736eaee0e7f..1c509a555edc 100644 --- a/src/layer/x86/layernorm_x86.h +++ b/src/layer/x86/layernorm_x86.h @@ -10,21 +10,11 @@ class LayerNorm_x86 : virtual public LayerNorm public: LayerNorm_x86(); - virtual int load_param(const ParamDict& pd); - - virtual int load_model(const ModelBin& mb); - virtual int forward_inplace(Mat& bottom_top_blob, const Option& opt) const; -public: - // param - int affine_size; - float eps; - int affine; - - // model - Mat gamma_data; - Mat beta_data; +protected: + void fast_1d_layer_norm(float* ptr, int size) const; + void fast_fmadd_fmadd(float* ptr, float a, 
float b, int size) const; }; } // namespace ncnn From b126e0f73459e1bb8e4eb7094bdc3d9c28b15c48 Mon Sep 17 00:00:00 2001 From: LinHeLurking Date: Thu, 21 Jul 2022 01:52:14 +0000 Subject: [PATCH 03/22] Located error about packed layout. --- src/layer/x86/layernorm_x86.cpp | 37 ++++++++++++++++++++++----------- src/layer/x86/layernorm_x86.h | 7 +++++-- tests/test_layernorm.cpp | 4 ++-- 3 files changed, 32 insertions(+), 16 deletions(-) diff --git a/src/layer/x86/layernorm_x86.cpp b/src/layer/x86/layernorm_x86.cpp index cd69ae25d881..378c116a160b 100644 --- a/src/layer/x86/layernorm_x86.cpp +++ b/src/layer/x86/layernorm_x86.cpp @@ -107,7 +107,7 @@ static NCNN_FORCEINLINE float fast_var(float* ptr, int size, float mean) return sq_sum / size; } -static void fast_fmadd(float* ptr, float a, float b, int size) +static NCNN_FORCEINLINE void fast_fmadd(float* ptr, float a, float b, int size) { int i = 0; #if __SSE2__ @@ -165,11 +165,11 @@ namespace ncnn { LayerNorm_x86::LayerNorm_x86() { #if __SSE2__ - support_packing = false; + support_packing = true; #endif // __SSE2__ } -void LayerNorm_x86::fast_fmadd_fmadd(float* ptr, float a, float b, int size) const +void NCNN_FORCEINLINE LayerNorm_x86::fast_fmadd_fmadd(float* ptr, float a, float b, int size) const { int i = 0; auto gamma = static_cast(gamma_data); @@ -235,7 +235,7 @@ void LayerNorm_x86::fast_fmadd_fmadd(float* ptr, float a, float b, int size) con } } -void LayerNorm_x86::fast_1d_layer_norm(float* ptr, int size) const +void NCNN_FORCEINLINE LayerNorm_x86::fast_1d_layer_norm(float* ptr, int size) const { // mean and var float sum = fast_sum(ptr, size); @@ -255,13 +255,13 @@ void LayerNorm_x86::fast_1d_layer_norm(float* ptr, int size) const } } -int LayerNorm_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const +int NCNN_FORCEINLINE LayerNorm_x86::forward_inplace_unpacked(Mat& bottom_top_blob, const Option& opt) const { int dims = bottom_top_blob.dims; if (dims == 1) { - int size = bottom_top_blob.w * bottom_top_blob.elempack; + int size = bottom_top_blob.w; float* ptr = bottom_top_blob; fast_1d_layer_norm(ptr, size); } @@ -270,12 +270,12 @@ int LayerNorm_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) cons { int w = bottom_top_blob.w; int h = bottom_top_blob.h; - + int size = w; #pragma omp parallel for num_threads(opt.num_threads) for (int i = 0; i < h; ++i) { float* ptr = bottom_top_blob.row(i); - int size = w * bottom_top_blob.elempack; + fast_1d_layer_norm(ptr, size); } } @@ -285,12 +285,12 @@ int LayerNorm_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) cons int w = bottom_top_blob.w; int h = bottom_top_blob.h; int channels = bottom_top_blob.c; - int size = w * h * bottom_top_blob.elempack; + int size = w * h; if (affine_size == w) { - size = w * bottom_top_blob.elempack; - // #pragma omp parallel for num_threads(opt.num_threads) + size = w; +#pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < channels; q++) { for (int i = 0; i < h; i++) @@ -302,7 +302,7 @@ int LayerNorm_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) cons } else // if (affine_size == size) { - // #pragma omp parallel for num_threads(opt.num_threads) +#pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < channels; q++) { float* ptr = bottom_top_blob.channel(q); @@ -314,4 +314,17 @@ int LayerNorm_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) cons return 0; } +int LayerNorm_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const +{ + if 
(bottom_top_blob.elempack == 1) + { + return forward_inplace_unpacked(bottom_top_blob, opt); + } + else + { + fprintf(stderr, "Packed forward not implemented!\n"); + return -1; + } +} + } // namespace ncnn diff --git a/src/layer/x86/layernorm_x86.h b/src/layer/x86/layernorm_x86.h index 1c509a555edc..9aaad8d549d4 100644 --- a/src/layer/x86/layernorm_x86.h +++ b/src/layer/x86/layernorm_x86.h @@ -13,8 +13,11 @@ class LayerNorm_x86 : virtual public LayerNorm virtual int forward_inplace(Mat& bottom_top_blob, const Option& opt) const; protected: - void fast_1d_layer_norm(float* ptr, int size) const; - void fast_fmadd_fmadd(float* ptr, float a, float b, int size) const; + NCNN_FORCEINLINE void fast_1d_layer_norm(float* ptr, int size) const; + NCNN_FORCEINLINE void fast_fmadd_fmadd(float* ptr, float a, float b, int size) const; + + NCNN_FORCEINLINE int forward_inplace_unpacked(Mat& bottom_top_blob, const Option& opt) const; + int forward_inplace_packed(Mat& bottom_top_blob, const Option& opt) const; }; } // namespace ncnn diff --git a/tests/test_layernorm.cpp b/tests/test_layernorm.cpp index b4e3ad7fa1e6..5205de96251d 100644 --- a/tests/test_layernorm.cpp +++ b/tests/test_layernorm.cpp @@ -38,8 +38,8 @@ static int test_layernorm(const ncnn::Mat& a, int affine_size, float eps, int af static int test_layernorm_0() { return 0 - || test_layernorm(RandomMat(6, 4, 2), 6, 0.01f, 0) - || test_layernorm(RandomMat(4, 5, 6), 4, 0.01f, 0) + // || test_layernorm(RandomMat(6, 4, 2), 6, 0.01f, 0) + // || test_layernorm(RandomMat(4, 5, 6), 4, 0.01f, 0) || test_layernorm(RandomMat(3, 3, 8), 3, 0.002f, 0) || test_layernorm(RandomMat(5, 6, 12), 5, 0.02f, 0) || test_layernorm(RandomMat(6, 7, 24), 6, 0.001f, 0) From 1982605f1e3d8344beb77068e9eea829d0dc5f92 Mon Sep 17 00:00:00 2001 From: LinHeLurking Date: Thu, 21 Jul 2022 05:37:00 +0000 Subject: [PATCH 04/22] All test passed; Now it supports packing layout --- src/layer/x86/layernorm_x86.cpp | 386 ++++++++++++++++++++++++++++---- src/layer/x86/layernorm_x86.h | 6 +- tests/test_layernorm.cpp | 4 +- 3 files changed, 352 insertions(+), 44 deletions(-) diff --git a/src/layer/x86/layernorm_x86.cpp b/src/layer/x86/layernorm_x86.cpp index 378c116a160b..f4b9dfbb914e 100644 --- a/src/layer/x86/layernorm_x86.cpp +++ b/src/layer/x86/layernorm_x86.cpp @@ -1,6 +1,7 @@ #include "layernorm_x86.h" #include +#include #if __SSE2__ #include @@ -9,7 +10,7 @@ #endif // __AVX__ #endif // __SSE2__ -static NCNN_FORCEINLINE float fast_sum(float* ptr, int size) +static NCNN_FORCEINLINE float fast_mean(float* ptr, int elemcount) { float sum = 0.0f; int i = 0; @@ -18,7 +19,7 @@ static NCNN_FORCEINLINE float fast_sum(float* ptr, int size) #if __AVX512F__ { __m512 _sum = _mm512_setzero_ps(); - for (; i + 16 <= size; i += 16, ptr += 16) + for (; i + 16 <= elemcount; i += 16, ptr += 16) { __m512 _cur = _mm512_loadu_ps(ptr); _sum = _mm512_add_ps(_sum, _cur); @@ -28,7 +29,7 @@ static NCNN_FORCEINLINE float fast_sum(float* ptr, int size) #endif // __AVX512F__ { __m256 _sum = _mm256_setzero_ps(); - for (; i + 8 <= size; i += 8, ptr += 8) + for (; i + 8 <= elemcount; i += 8, ptr += 8) { __m256 _cur = _mm256_loadu_ps(ptr); _sum = _mm256_add_ps(_sum, _cur); @@ -38,7 +39,7 @@ static NCNN_FORCEINLINE float fast_sum(float* ptr, int size) #endif // __AVX__ { __m128 _sum = _mm_setzero_ps(); - for (; i + 4 <= size; i += 4, ptr += 4) + for (; i + 4 <= elemcount; i += 4, ptr += 4) { __m128 _cur = _mm_loadu_ps(ptr); _sum = _mm_add_ps(_sum, _cur); @@ -46,14 +47,14 @@ static NCNN_FORCEINLINE float fast_sum(float* 
ptr, int size) sum += _sum[0] + _sum[1] + _sum[2] + _sum[3]; } #endif // __SSE2__ - for (; i < size; ++i, ++ptr) + for (; i < elemcount; ++i, ++ptr) { sum += *ptr; } - return sum; + return sum / elemcount; } -static NCNN_FORCEINLINE float fast_var(float* ptr, int size, float mean) +static NCNN_FORCEINLINE float fast_var(float* ptr, int elemcount, float mean) { float sq_sum = 0.0f; int i = 0; @@ -63,7 +64,7 @@ static NCNN_FORCEINLINE float fast_var(float* ptr, int size, float mean) { __m512 _mean = _mm512_set1_ps(mean); __m512 _sq_sum = _mm512_setzero_ps(); - for (; i + 16 <= size; i += 16, ptr += 16) + for (; i + 16 <= elemcount; i += 16, ptr += 16) { __m512 _cur = _mm512_loadu_ps(ptr); _cur = _mm512_sub_ps(_cur, _mean); @@ -76,7 +77,7 @@ static NCNN_FORCEINLINE float fast_var(float* ptr, int size, float mean) { __m256 _mean = _mm256_set1_ps(mean); __m256 _sq_sum = _mm256_setzero_ps(); - for (; i + 8 <= size; i += 8, ptr += 8) + for (; i + 8 <= elemcount; i += 8, ptr += 8) { __m256 _cur = _mm256_loadu_ps(ptr); _cur = _mm256_sub_ps(_cur, _mean); @@ -89,7 +90,7 @@ static NCNN_FORCEINLINE float fast_var(float* ptr, int size, float mean) { __m128 _mean = _mm_set1_ps(mean); __m128 _sq_sum = _mm_setzero_ps(); - for (; i + 4 <= size; i += 4, ptr += 4) + for (; i + 4 <= elemcount; i += 4, ptr += 4) { __m128 _cur = _mm_loadu_ps(ptr); _cur = _mm_sub_ps(_cur, _mean); @@ -99,15 +100,15 @@ static NCNN_FORCEINLINE float fast_var(float* ptr, int size, float mean) sq_sum += _sq_sum[0] + _sq_sum[1] + _sq_sum[2] + _sq_sum[3]; } #endif // __SSE2__ - for (; i < size; ++i, ++ptr) + for (; i < elemcount; ++i, ++ptr) { float tmp = *ptr - mean; sq_sum += tmp * tmp; } - return sq_sum / size; + return sq_sum / elemcount; } -static NCNN_FORCEINLINE void fast_fmadd(float* ptr, float a, float b, int size) +static NCNN_FORCEINLINE void fast_fmadd(float* ptr, float a, float b, int elemcount) { int i = 0; #if __SSE2__ @@ -117,7 +118,7 @@ static NCNN_FORCEINLINE void fast_fmadd(float* ptr, float a, float b, int size) // 512 bit FMA instructions are included in AVX512F. 
__m512 _a = _mm512_set1_ps(a); __m512 _b = _mm512_set1_ps(b); - for (; i + 16 <= size; i += 16, ptr += 16) + for (; i + 16 <= elemcount; i += 16, ptr += 16) { __m512 _cur = _mm512_loadu_ps(ptr); _cur = _mm512_fmadd_ps(_cur, _a, _b); @@ -129,7 +130,7 @@ static NCNN_FORCEINLINE void fast_fmadd(float* ptr, float a, float b, int size) // 256 bit FMA instructions are not included in AVX1 __m256 _a = _mm256_set1_ps(a); __m256 _b = _mm256_set1_ps(b); - for (; i + 8 <= size; i += 8, ptr += 8) + for (; i + 8 <= elemcount; i += 8, ptr += 8) { __m256 _cur = _mm256_loadu_ps(ptr); #if __FMA__ @@ -145,7 +146,7 @@ static NCNN_FORCEINLINE void fast_fmadd(float* ptr, float a, float b, int size) { __m128 _a = _mm_set1_ps(a); __m128 _b = _mm_set1_ps(b); - for (; i + 4 <= size; i += 4, ptr += 4) + for (; i + 4 <= elemcount; i += 4, ptr += 4) { __m128 _cur = _mm_loadu_ps(ptr); _cur = _mm_mul_ps(_cur, _a); @@ -154,12 +155,147 @@ static NCNN_FORCEINLINE void fast_fmadd(float* ptr, float a, float b, int size) } } #endif // __SSE2__ - for (; i < size; ++i, ++ptr) + for (; i < elemcount; ++i, ++ptr) { *ptr = (*ptr) * a + b; } } +static void fast_mean_packed(float* ptr, float* mean, int elempack, int elemcount, int size) +{ + int i = 0; + if (elempack == 4) + { + __m128 _sum = _mm_setzero_ps(); + __m128 _elemcount = _mm_set1_ps(float(elemcount)); + for (; i < size; i += 4, ptr += 4) + { + __m128 _cur = _mm_loadu_ps(ptr); + _sum = _mm_add_ps(_sum, _cur); + } + __m128 _mean = _mm_div_ps(_sum, _elemcount); + _mm_storeu_ps(mean, _mean); + } + else if (elempack == 8) + { + __m256 _sum = _mm256_setzero_ps(); + __m256 _elemcount = _mm256_set1_ps(float(elemcount)); + for (; i < size; i += 8, ptr += 8) + { + __m256 _cur = _mm256_loadu_ps(ptr); + _sum = _mm256_add_ps(_sum, _cur); + } + __m256 _mean = _mm256_div_ps(_sum, _elemcount); + _mm256_storeu_ps(mean, _mean); + } + else if (elempack == 16) + { + __m512 _sum = _mm512_setzero_ps(); + __m512 _elemcount = _mm512_set1_ps(float(elemcount)); + for (; i < size; i += 16, ptr += 16) + { + __m512 _cur = _mm512_loadu_ps(ptr); + _sum = _mm512_add_ps(_sum, _cur); + } + __m512 _mean = _mm512_div_ps(_sum, _elemcount); + _mm512_storeu_ps(mean, _mean); + } +} + +static void fast_var_packed(float* ptr, float* var, float* mean, int elempack, int elemcount, int size) +{ + int i = 0; + if (elempack == 4) + { + __m128 _mean = _mm_loadu_ps(mean); + __m128 _sq_sum = _mm_setzero_ps(); + __m128 _elemcount = _mm_set1_ps(float(elemcount)); + for (; i < size; i += 4, ptr += 4) + { + __m128 _cur = _mm_loadu_ps(ptr); + _cur = _mm_sub_ps(_cur, _mean); + _cur = _mm_mul_ps(_cur, _cur); + _sq_sum = _mm_add_ps(_sq_sum, _cur); + } + __m128 _var = _mm_div_ps(_sq_sum, _elemcount); + _mm_storeu_ps(var, _var); + } + else if (elempack == 8) + { + __m256 _mean = _mm256_loadu_ps(mean); + __m256 _sq_sum = _mm256_setzero_ps(); + __m256 _elemcount = _mm256_set1_ps(float(elemcount)); + for (; i < size; i += 8, ptr += 8) + { + __m256 _cur = _mm256_loadu_ps(ptr); + _cur = _mm256_sub_ps(_cur, _mean); + _cur = _mm256_mul_ps(_cur, _cur); + _sq_sum = _mm256_add_ps(_sq_sum, _cur); + } + __m256 _var = _mm256_div_ps(_sq_sum, _elemcount); + _mm256_storeu_ps(var, _var); + } + else if (elempack == 16) + { + __m512 _mean = _mm512_loadu_ps(mean); + __m512 _sq_sum = _mm512_setzero_ps(); + __m512 _elemcount = _mm512_set1_ps(float(elemcount)); + for (; i < size; i += 16, ptr += 16) + { + __m512 _cur = _mm512_loadu_ps(ptr); + _cur = _mm512_sub_ps(_cur, _mean); + _cur = _mm512_mul_ps(_cur, _cur); + _sq_sum = _mm512_add_ps(_sq_sum, 
_cur); + } + __m512 _var = _mm512_div_ps(_sq_sum, _elemcount); + _mm512_storeu_ps(var, _var); + } +} + +static void fast_fmadd_packed(float* ptr, float* a, float* b, int elempack, int size) +{ + int i = 0; + if (elempack == 4) + { + __m128 _a = _mm_loadu_ps(a); + __m128 _b = _mm_loadu_ps(b); + for (; i < size; i += 4, ptr += 4) + { + __m128 _cur = _mm_loadu_ps(ptr); + _cur = _mm_mul_ps(_cur, _a); + _cur = _mm_add_ps(_cur, _b); + _mm_storeu_ps(ptr, _cur); + } + } + else if (elempack == 8) + { + __m256 _a = _mm256_loadu_ps(a); + __m256 _b = _mm256_loadu_ps(b); + for (; i < size; i += 8, ptr += 8) + { + __m256 _cur = _mm256_loadu_ps(ptr); +#if __FMA__ + _cur = _mm256_fmadd_ps(_cur, _a, _b); +#else + _cur = _mm256_mul_ps(_cur, _a); + _cur = _mm256_add_ps(_cur, _b); +#endif + _mm256_storeu_ps(ptr, _cur); + } + } + else if (elempack == 16) + { + __m512 _a = _mm512_loadu_ps(a); + __m512 _b = _mm512_loadu_ps(b); + for (; i < size; i += 16, ptr += 16) + { + __m512 _cur = _mm512_loadu_ps(ptr); + _cur = _mm512_fmadd_ps(_cur, _a, _b); + _mm512_storeu_ps(ptr, _cur); + } + } +} + namespace ncnn { LayerNorm_x86::LayerNorm_x86() @@ -169,7 +305,7 @@ LayerNorm_x86::LayerNorm_x86() #endif // __SSE2__ } -void NCNN_FORCEINLINE LayerNorm_x86::fast_fmadd_fmadd(float* ptr, float a, float b, int size) const +void NCNN_FORCEINLINE LayerNorm_x86::fast_fmadd_fmadd(float* ptr, float a, float b, int elemcount) const { int i = 0; auto gamma = static_cast(gamma_data); @@ -180,7 +316,7 @@ void NCNN_FORCEINLINE LayerNorm_x86::fast_fmadd_fmadd(float* ptr, float a, float { __m512 _a = _mm512_set1_ps(a); __m512 _b = _mm512_set1_ps(b); - for (; i + 16 <= size; i += 16, ptr += 16, gamma += 16, beta += 16) + for (; i + 16 <= elemcount; i += 16, ptr += 16, gamma += 16, beta += 16) { __m512 _cur = _mm512_loadu_ps(ptr); __m512 _gamma = _mm512_loadu_ps(gamma); @@ -195,7 +331,7 @@ void NCNN_FORCEINLINE LayerNorm_x86::fast_fmadd_fmadd(float* ptr, float a, float __m256 _a = _mm256_set1_ps(a); __m256 _b = _mm256_set1_ps(b); - for (; i + 8 <= size; i += 8, ptr += 8, gamma += 8, beta += 8) + for (; i + 8 <= elemcount; i += 8, ptr += 8, gamma += 8, beta += 8) { __m256 _cur = _mm256_loadu_ps(ptr); __m256 _gamma = _mm256_loadu_ps(gamma); @@ -216,7 +352,7 @@ void NCNN_FORCEINLINE LayerNorm_x86::fast_fmadd_fmadd(float* ptr, float a, float { __m128 _a = _mm_set1_ps(a); __m128 _b = _mm_set1_ps(b); - for (; i + 4 <= size; i += 4, ptr += 4, gamma += 4, beta += 4) + for (; i + 4 <= elemcount; i += 4, ptr += 4, gamma += 4, beta += 4) { __m128 _cur = _mm_loadu_ps(ptr); __m128 _gamma = _mm_loadu_ps(gamma); @@ -229,29 +365,155 @@ void NCNN_FORCEINLINE LayerNorm_x86::fast_fmadd_fmadd(float* ptr, float a, float } } #endif // __SSE2__ - for (; i < size; ++i, ++ptr, ++gamma, ++beta) + for (; i < elemcount; ++i, ++ptr, ++gamma, ++beta) { *ptr = ((*ptr) * a + b) * (*gamma) + (*beta); } } -void NCNN_FORCEINLINE LayerNorm_x86::fast_1d_layer_norm(float* ptr, int size) const +void LayerNorm_x86::fast_fmadd_fmadd_packed(float* ptr, float* a, float* b, int elempack, int size) const +{ + int i = 0; + auto gamma = static_cast(gamma_data); + auto beta = static_cast(beta_data); + if (elempack == 4) + { + __m128 _a = _mm_loadu_ps(a); + __m128 _b = _mm_loadu_ps(b); + for (; i < size; i += 4, ptr += 4, ++gamma, ++beta) + { + __m128 _cur = _mm_loadu_ps(ptr); + __m128 _gamma = _mm_set1_ps(*gamma); + __m128 _beta = _mm_set1_ps(*beta); + _cur = _mm_mul_ps(_cur, _a); + _cur = _mm_add_ps(_cur, _b); + _cur = _mm_mul_ps(_cur, _gamma); + _cur = _mm_add_ps(_cur, _beta); + 
_mm_storeu_ps(ptr, _cur); + } + } + else if (elempack == 8) + { + __m256 _a = _mm256_loadu_ps(a); + __m256 _b = _mm256_loadu_ps(b); + for (; i < size; i += 8, ptr += 8, ++gamma,++beta) + { + __m256 _cur = _mm256_loadu_ps(ptr); + __m256 _gamma = _mm256_set1_ps(*gamma); + __m256 _beta = _mm256_set1_ps(*beta); +#if __FMA__ + _cur = _mm256_fmadd_ps(_cur, _a, _b); + _cur = _mm256_fmadd_ps(_cur, _gamma, _beta); +#else + _cur = _mm256_mul_ps(_cur, _a); + _cur = _mm256_add_ps(_cur, _b); + _cur = _mm256_mul_ps(_cur, _gamma); + _cur = _mm256_add_ps(_cur, _beta); +#endif + _mm256_storeu_ps(ptr, _cur); + } + } + else if (elempack == 16) + { + __m512 _a = _mm512_loadu_ps(a); + __m512 _b = _mm512_loadu_ps(b); + for (; i < size; i += 16, ptr += 16, ++gamma,++beta) + { + __m512 _cur = _mm512_loadu_ps(ptr); + __m512 _gamma = _mm512_set1_ps(*gamma); + __m512 _beta = _mm512_set1_ps(*beta); + _cur = _mm512_fmadd_ps(_cur, _a, _b); + _cur = _mm512_fmadd_ps(_cur, _gamma, _beta); + _mm512_storeu_ps(ptr, _cur); + } + } +} + +void NCNN_FORCEINLINE LayerNorm_x86::fast_1d_layer_norm(float* ptr, int elemcount) const { // mean and var - float sum = fast_sum(ptr, size); - float mean = sum / size; - float var = fast_var(ptr, size, mean); + float mean = fast_mean(ptr, elemcount); + float var = fast_var(ptr, elemcount, mean); float a = static_cast(1.0f / sqrt(var + eps)); float b = -mean * a; if (affine) { - fast_fmadd_fmadd(ptr, a, b, size); + fast_fmadd_fmadd(ptr, a, b, elemcount); } else { - fast_fmadd(ptr, a, b, size); + fast_fmadd(ptr, a, b, elemcount); + } +} + +void LayerNorm_x86::fast_1d_layer_norm_packed(float* ptr, int elempack, int elemcount, int size) const +{ + float mean[16], var[16]; + fast_mean_packed(ptr, mean, elempack, elemcount, size); + fast_var_packed(ptr, var, mean, elempack, elemcount, size); + float *a = var, *b = mean; + + if (elempack == 4) + { + __m128 _a = _mm_set1_ps(1.0f); + __m128 _eps = _mm_set1_ps(eps); + __m128 _b = _mm_setzero_ps(); + __m128 _var = _mm_loadu_ps(var); + _var = _mm_add_ps(_var, _eps); + __m128 _sqrt_var = _mm_sqrt_ps(_var); + _a = _mm_div_ps(_a, _sqrt_var); + __m128 _mean_a = _mm_loadu_ps(mean); + _mean_a = _mm_mul_ps(_mean_a, _a); + _b = _mm_sub_ps(_b, _mean_a); + + _mm_storeu_ps(a, _a); + _mm_storeu_ps(b, _b); + } + else if (elempack == 8) + { + __m256 _a = _mm256_set1_ps(1.0f); + __m256 _eps = _mm256_set1_ps(eps); + __m256 _b = _mm256_setzero_ps(); + __m256 _var = _mm256_loadu_ps(var); + _var = _mm256_add_ps(_var, _eps); + __m256 _sqrt_var = _mm256_sqrt_ps(_var); + _a = _mm256_div_ps(_a, _sqrt_var); +#if __FMA__ + __m256 _mean = _mm256_loadu_ps(mean); + _b = _mm256_fnmadd_ps(_mean, _a, _b); +#else + __m256 _mean_a = _mm256_loadu_ps(mean); + _mean_a = _mm256_mul_ps(_mean_a, _a); + _b = _mm256_sub_ps(_b, _mean_a); +#endif + _mm256_storeu_ps(a, _a); + _mm256_storeu_ps(b, _b); + } + else if (elempack == 16) + { + __m512 _a = _mm512_set1_ps(1.0f); + __m512 _eps = _mm512_set1_ps(eps); + __m512 _b = _mm512_setzero_ps(); + __m512 _var = _mm512_loadu_ps(var); + _var = _mm512_add_ps(_var, _eps); + __m512 _sqrt_var = _mm512_sqrt_ps(_var); + _a = _mm512_div_ps(_a, _sqrt_var); + __m512 _mean = _mm512_loadu_ps(mean); + _b = _mm512_fnmadd_ps(_mean, _a, _b); + + _mm512_storeu_ps(a, _a); + _mm512_storeu_ps(b, _b); + } + + if (affine) + { + fast_fmadd_fmadd_packed(ptr, a, b, elempack, size); + } + else + { + fast_fmadd_packed(ptr, a, b, elempack, size); } } @@ -261,22 +523,22 @@ int NCNN_FORCEINLINE LayerNorm_x86::forward_inplace_unpacked(Mat& bottom_top_blo if (dims == 1) { - int size = 
bottom_top_blob.w; + int elemcount = bottom_top_blob.w * bottom_top_blob.elempack; float* ptr = bottom_top_blob; - fast_1d_layer_norm(ptr, size); + fast_1d_layer_norm(ptr, elemcount); } if (dims == 2) { int w = bottom_top_blob.w; int h = bottom_top_blob.h; - int size = w; + int elemcount = w * bottom_top_blob.elempack; #pragma omp parallel for num_threads(opt.num_threads) for (int i = 0; i < h; ++i) { float* ptr = bottom_top_blob.row(i); - fast_1d_layer_norm(ptr, size); + fast_1d_layer_norm(ptr, elemcount); } } @@ -285,28 +547,73 @@ int NCNN_FORCEINLINE LayerNorm_x86::forward_inplace_unpacked(Mat& bottom_top_blo int w = bottom_top_blob.w; int h = bottom_top_blob.h; int channels = bottom_top_blob.c; - int size = w * h; + int elemcount = w * h * bottom_top_blob.elempack; if (affine_size == w) { - size = w; + elemcount = w * bottom_top_blob.elempack; #pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < channels; q++) { for (int i = 0; i < h; i++) { float* ptr = bottom_top_blob.channel(q).row(i); - fast_1d_layer_norm(ptr, size); + fast_1d_layer_norm(ptr, elemcount); } } } - else // if (affine_size == size) + else // if (affine_elemcount == elemcount) { #pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < channels; q++) { float* ptr = bottom_top_blob.channel(q); - fast_1d_layer_norm(ptr, size); + fast_1d_layer_norm(ptr, elemcount); + } + } + } + + return 0; +} + +int LayerNorm_x86::forward_inplace_packed(Mat& bottom_top_blob, const Option& opt) const +{ + int elempack = bottom_top_blob.elempack; + int w = bottom_top_blob.w; + int h = bottom_top_blob.h; + + // Now, bottoms_top_blob.dims >= 2 + if (bottom_top_blob.dims == 2) + { +#pragma omp parallel for num_threads(opt.num_threads) + for (int i = 0; i < h; ++i) + { + float* ptr = bottom_top_blob.row(i); + fast_1d_layer_norm_packed(ptr, elempack, w, w * elempack); + } + } + else if (bottom_top_blob.dims == 3) + { + int channels = bottom_top_blob.c; + if (affine_size == w) + { +#pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; ++q) + { + for (int i = 0; i < h; ++i) + { + float* ptr = bottom_top_blob.channel(q).row(i); + fast_1d_layer_norm_packed(ptr, elempack, w, w * elempack); + } + } + } + else // if(affine_size == w * h) + { +#pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; ++q) + { + float* ptr = bottom_top_blob.channel(q); + fast_1d_layer_norm_packed(ptr, elempack, w * h, w * h * elempack); } } } @@ -316,14 +623,13 @@ int NCNN_FORCEINLINE LayerNorm_x86::forward_inplace_unpacked(Mat& bottom_top_blo int LayerNorm_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const { - if (bottom_top_blob.elempack == 1) + if (bottom_top_blob.elempack == 1 || bottom_top_blob.dims == 1) { return forward_inplace_unpacked(bottom_top_blob, opt); } else { - fprintf(stderr, "Packed forward not implemented!\n"); - return -1; + return forward_inplace_packed(bottom_top_blob, opt); } } diff --git a/src/layer/x86/layernorm_x86.h b/src/layer/x86/layernorm_x86.h index 9aaad8d549d4..f521703e0ff8 100644 --- a/src/layer/x86/layernorm_x86.h +++ b/src/layer/x86/layernorm_x86.h @@ -13,8 +13,10 @@ class LayerNorm_x86 : virtual public LayerNorm virtual int forward_inplace(Mat& bottom_top_blob, const Option& opt) const; protected: - NCNN_FORCEINLINE void fast_1d_layer_norm(float* ptr, int size) const; - NCNN_FORCEINLINE void fast_fmadd_fmadd(float* ptr, float a, float b, int size) const; + NCNN_FORCEINLINE void fast_1d_layer_norm(float* ptr, int elemcount) 
const; + void fast_1d_layer_norm_packed(float* ptr, int elempack, int elemcount, int size) const; + NCNN_FORCEINLINE void fast_fmadd_fmadd(float* ptr, float a, float b, int elemcount) const; + void fast_fmadd_fmadd_packed(float* ptr, float* a, float* b, int elempack, int size) const; NCNN_FORCEINLINE int forward_inplace_unpacked(Mat& bottom_top_blob, const Option& opt) const; int forward_inplace_packed(Mat& bottom_top_blob, const Option& opt) const; diff --git a/tests/test_layernorm.cpp b/tests/test_layernorm.cpp index 5205de96251d..b4e3ad7fa1e6 100644 --- a/tests/test_layernorm.cpp +++ b/tests/test_layernorm.cpp @@ -38,8 +38,8 @@ static int test_layernorm(const ncnn::Mat& a, int affine_size, float eps, int af static int test_layernorm_0() { return 0 - // || test_layernorm(RandomMat(6, 4, 2), 6, 0.01f, 0) - // || test_layernorm(RandomMat(4, 5, 6), 4, 0.01f, 0) + || test_layernorm(RandomMat(6, 4, 2), 6, 0.01f, 0) + || test_layernorm(RandomMat(4, 5, 6), 4, 0.01f, 0) || test_layernorm(RandomMat(3, 3, 8), 3, 0.002f, 0) || test_layernorm(RandomMat(5, 6, 12), 5, 0.02f, 0) || test_layernorm(RandomMat(6, 7, 24), 6, 0.001f, 0) From 0fa868924169b1d7722791a0d570a7537a80b549 Mon Sep 17 00:00:00 2001 From: LinHeLurking Date: Thu, 21 Jul 2022 06:03:53 +0000 Subject: [PATCH 05/22] Fix runtime cpu dispatch; --- src/layer/x86/layernorm_x86.cpp | 36 +++++++++++++++++++++++++-------- src/layer/x86/layernorm_x86.h | 6 +++--- 2 files changed, 31 insertions(+), 11 deletions(-) diff --git a/src/layer/x86/layernorm_x86.cpp b/src/layer/x86/layernorm_x86.cpp index f4b9dfbb914e..69a1bbe3a53c 100644 --- a/src/layer/x86/layernorm_x86.cpp +++ b/src/layer/x86/layernorm_x86.cpp @@ -161,7 +161,7 @@ static NCNN_FORCEINLINE void fast_fmadd(float* ptr, float a, float b, int elemco } } -static void fast_mean_packed(float* ptr, float* mean, int elempack, int elemcount, int size) +static NCNN_FORCEINLINE void fast_mean_packed(float* ptr, float* mean, int elempack, int elemcount, int size) { int i = 0; if (elempack == 4) @@ -178,6 +178,7 @@ static void fast_mean_packed(float* ptr, float* mean, int elempack, int elemcoun } else if (elempack == 8) { +#if __AVX__ __m256 _sum = _mm256_setzero_ps(); __m256 _elemcount = _mm256_set1_ps(float(elemcount)); for (; i < size; i += 8, ptr += 8) @@ -187,9 +188,11 @@ static void fast_mean_packed(float* ptr, float* mean, int elempack, int elemcoun } __m256 _mean = _mm256_div_ps(_sum, _elemcount); _mm256_storeu_ps(mean, _mean); +#endif } else if (elempack == 16) { +#if __AVX512F__ __m512 _sum = _mm512_setzero_ps(); __m512 _elemcount = _mm512_set1_ps(float(elemcount)); for (; i < size; i += 16, ptr += 16) @@ -199,10 +202,11 @@ static void fast_mean_packed(float* ptr, float* mean, int elempack, int elemcoun } __m512 _mean = _mm512_div_ps(_sum, _elemcount); _mm512_storeu_ps(mean, _mean); +#endif } } -static void fast_var_packed(float* ptr, float* var, float* mean, int elempack, int elemcount, int size) +static NCNN_FORCEINLINE void fast_var_packed(float* ptr, float* var, float* mean, int elempack, int elemcount, int size) { int i = 0; if (elempack == 4) @@ -222,6 +226,7 @@ static void fast_var_packed(float* ptr, float* var, float* mean, int elempack, i } else if (elempack == 8) { +#if __AVX__ __m256 _mean = _mm256_loadu_ps(mean); __m256 _sq_sum = _mm256_setzero_ps(); __m256 _elemcount = _mm256_set1_ps(float(elemcount)); @@ -234,9 +239,11 @@ static void fast_var_packed(float* ptr, float* var, float* mean, int elempack, i } __m256 _var = _mm256_div_ps(_sq_sum, _elemcount); _mm256_storeu_ps(var, 
_var); +#endif } else if (elempack == 16) { +#if __AVX512F__ __m512 _mean = _mm512_loadu_ps(mean); __m512 _sq_sum = _mm512_setzero_ps(); __m512 _elemcount = _mm512_set1_ps(float(elemcount)); @@ -249,10 +256,11 @@ static void fast_var_packed(float* ptr, float* var, float* mean, int elempack, i } __m512 _var = _mm512_div_ps(_sq_sum, _elemcount); _mm512_storeu_ps(var, _var); +#endif } } -static void fast_fmadd_packed(float* ptr, float* a, float* b, int elempack, int size) +static NCNN_FORCEINLINE void fast_fmadd_packed(float* ptr, float* a, float* b, int elempack, int size) { int i = 0; if (elempack == 4) @@ -269,6 +277,7 @@ static void fast_fmadd_packed(float* ptr, float* a, float* b, int elempack, int } else if (elempack == 8) { +#if __AVX__ __m256 _a = _mm256_loadu_ps(a); __m256 _b = _mm256_loadu_ps(b); for (; i < size; i += 8, ptr += 8) @@ -282,9 +291,11 @@ static void fast_fmadd_packed(float* ptr, float* a, float* b, int elempack, int #endif _mm256_storeu_ps(ptr, _cur); } +#endif } else if (elempack == 16) { +#if __AVX512F__ __m512 _a = _mm512_loadu_ps(a); __m512 _b = _mm512_loadu_ps(b); for (; i < size; i += 16, ptr += 16) @@ -293,6 +304,7 @@ static void fast_fmadd_packed(float* ptr, float* a, float* b, int elempack, int _cur = _mm512_fmadd_ps(_cur, _a, _b); _mm512_storeu_ps(ptr, _cur); } +#endif } } @@ -371,7 +383,7 @@ void NCNN_FORCEINLINE LayerNorm_x86::fast_fmadd_fmadd(float* ptr, float a, float } } -void LayerNorm_x86::fast_fmadd_fmadd_packed(float* ptr, float* a, float* b, int elempack, int size) const +void NCNN_FORCEINLINE LayerNorm_x86::fast_fmadd_fmadd_packed(float* ptr, float* a, float* b, int elempack, int size) const { int i = 0; auto gamma = static_cast(gamma_data); @@ -394,9 +406,10 @@ void LayerNorm_x86::fast_fmadd_fmadd_packed(float* ptr, float* a, float* b, int } else if (elempack == 8) { +#if __AVX__ __m256 _a = _mm256_loadu_ps(a); __m256 _b = _mm256_loadu_ps(b); - for (; i < size; i += 8, ptr += 8, ++gamma,++beta) + for (; i < size; i += 8, ptr += 8, ++gamma, ++beta) { __m256 _cur = _mm256_loadu_ps(ptr); __m256 _gamma = _mm256_set1_ps(*gamma); @@ -412,12 +425,14 @@ void LayerNorm_x86::fast_fmadd_fmadd_packed(float* ptr, float* a, float* b, int #endif _mm256_storeu_ps(ptr, _cur); } +#endif } else if (elempack == 16) { +#if __AVX512F__ __m512 _a = _mm512_loadu_ps(a); __m512 _b = _mm512_loadu_ps(b); - for (; i < size; i += 16, ptr += 16, ++gamma,++beta) + for (; i < size; i += 16, ptr += 16, ++gamma, ++beta) { __m512 _cur = _mm512_loadu_ps(ptr); __m512 _gamma = _mm512_set1_ps(*gamma); @@ -426,6 +441,7 @@ void LayerNorm_x86::fast_fmadd_fmadd_packed(float* ptr, float* a, float* b, int _cur = _mm512_fmadd_ps(_cur, _gamma, _beta); _mm512_storeu_ps(ptr, _cur); } +#endif } } @@ -448,7 +464,7 @@ void NCNN_FORCEINLINE LayerNorm_x86::fast_1d_layer_norm(float* ptr, int elemcoun } } -void LayerNorm_x86::fast_1d_layer_norm_packed(float* ptr, int elempack, int elemcount, int size) const +void NCNN_FORCEINLINE LayerNorm_x86::fast_1d_layer_norm_packed(float* ptr, int elempack, int elemcount, int size) const { float mean[16], var[16]; fast_mean_packed(ptr, mean, elempack, elemcount, size); @@ -473,6 +489,7 @@ void LayerNorm_x86::fast_1d_layer_norm_packed(float* ptr, int elempack, int elem } else if (elempack == 8) { +#if __AVX__ __m256 _a = _mm256_set1_ps(1.0f); __m256 _eps = _mm256_set1_ps(eps); __m256 _b = _mm256_setzero_ps(); @@ -490,9 +507,11 @@ void LayerNorm_x86::fast_1d_layer_norm_packed(float* ptr, int elempack, int elem #endif _mm256_storeu_ps(a, _a); _mm256_storeu_ps(b, _b); 
+#endif } else if (elempack == 16) { +#if __AVX512F__ __m512 _a = _mm512_set1_ps(1.0f); __m512 _eps = _mm512_set1_ps(eps); __m512 _b = _mm512_setzero_ps(); @@ -505,6 +524,7 @@ void LayerNorm_x86::fast_1d_layer_norm_packed(float* ptr, int elempack, int elem _mm512_storeu_ps(a, _a); _mm512_storeu_ps(b, _b); +#endif } if (affine) @@ -576,7 +596,7 @@ int NCNN_FORCEINLINE LayerNorm_x86::forward_inplace_unpacked(Mat& bottom_top_blo return 0; } -int LayerNorm_x86::forward_inplace_packed(Mat& bottom_top_blob, const Option& opt) const +int NCNN_FORCEINLINE LayerNorm_x86::forward_inplace_packed(Mat& bottom_top_blob, const Option& opt) const { int elempack = bottom_top_blob.elempack; int w = bottom_top_blob.w; diff --git a/src/layer/x86/layernorm_x86.h b/src/layer/x86/layernorm_x86.h index f521703e0ff8..20423f6a8464 100644 --- a/src/layer/x86/layernorm_x86.h +++ b/src/layer/x86/layernorm_x86.h @@ -14,12 +14,12 @@ class LayerNorm_x86 : virtual public LayerNorm protected: NCNN_FORCEINLINE void fast_1d_layer_norm(float* ptr, int elemcount) const; - void fast_1d_layer_norm_packed(float* ptr, int elempack, int elemcount, int size) const; + NCNN_FORCEINLINE void fast_1d_layer_norm_packed(float* ptr, int elempack, int elemcount, int size) const; NCNN_FORCEINLINE void fast_fmadd_fmadd(float* ptr, float a, float b, int elemcount) const; - void fast_fmadd_fmadd_packed(float* ptr, float* a, float* b, int elempack, int size) const; + NCNN_FORCEINLINE void fast_fmadd_fmadd_packed(float* ptr, float* a, float* b, int elempack, int size) const; NCNN_FORCEINLINE int forward_inplace_unpacked(Mat& bottom_top_blob, const Option& opt) const; - int forward_inplace_packed(Mat& bottom_top_blob, const Option& opt) const; + NCNN_FORCEINLINE int forward_inplace_packed(Mat& bottom_top_blob, const Option& opt) const; }; } // namespace ncnn From 6a683d357e685d6aba1ce7ec53458eee0f4e81fd Mon Sep 17 00:00:00 2001 From: LinHeLurking Date: Fri, 22 Jul 2022 02:02:43 +0000 Subject: [PATCH 06/22] Use fmadd wrapper in x86_usability.h; --- src/layer/x86/layernorm_x86.cpp | 106 ++++++++++---------------------- 1 file changed, 33 insertions(+), 73 deletions(-) diff --git a/src/layer/x86/layernorm_x86.cpp b/src/layer/x86/layernorm_x86.cpp index 69a1bbe3a53c..99c9544f4d36 100644 --- a/src/layer/x86/layernorm_x86.cpp +++ b/src/layer/x86/layernorm_x86.cpp @@ -1,5 +1,5 @@ #include "layernorm_x86.h" - +#include "x86_usability.h" #include #include @@ -34,7 +34,7 @@ static NCNN_FORCEINLINE float fast_mean(float* ptr, int elemcount) __m256 _cur = _mm256_loadu_ps(ptr); _sum = _mm256_add_ps(_sum, _cur); } - sum += _sum[0] + _sum[1] + _sum[2] + _sum[3] + _sum[4] + _sum[5] + _sum[6] + _sum[7]; + sum += _mm256_reduce_add_ps(_sum); } #endif // __AVX__ { @@ -44,7 +44,7 @@ static NCNN_FORCEINLINE float fast_mean(float* ptr, int elemcount) __m128 _cur = _mm_loadu_ps(ptr); _sum = _mm_add_ps(_sum, _cur); } - sum += _sum[0] + _sum[1] + _sum[2] + _sum[3]; + sum += _mm_reduce_add_ps(_sum); } #endif // __SSE2__ for (; i < elemcount; ++i, ++ptr) @@ -84,7 +84,7 @@ static NCNN_FORCEINLINE float fast_var(float* ptr, int elemcount, float mean) _cur = _mm256_mul_ps(_cur, _cur); _sq_sum = _mm256_add_ps(_sq_sum, _cur); } - sq_sum += _sq_sum[0] + _sq_sum[1] + _sq_sum[2] + _sq_sum[3] + _sq_sum[4] + _sq_sum[5] + _sq_sum[6] + _sq_sum[7]; + sq_sum += _mm256_reduce_add_ps(_sq_sum); } #endif // __AVX__ { @@ -97,7 +97,7 @@ static NCNN_FORCEINLINE float fast_var(float* ptr, int elemcount, float mean) _cur = _mm_mul_ps(_cur, _cur); _sq_sum = _mm_add_ps(_sq_sum, _cur); } - sq_sum += 
_sq_sum[0] + _sq_sum[1] + _sq_sum[2] + _sq_sum[3]; + sq_sum += _mm_reduce_add_ps(_sq_sum); } #endif // __SSE2__ for (; i < elemcount; ++i, ++ptr) @@ -133,12 +133,7 @@ static NCNN_FORCEINLINE void fast_fmadd(float* ptr, float a, float b, int elemco for (; i + 8 <= elemcount; i += 8, ptr += 8) { __m256 _cur = _mm256_loadu_ps(ptr); -#if __FMA__ - _cur = _mm256_fmadd_ps(_cur, _a, _b); -#else - _cur = _mm256_mul_ps(_cur, _a); - _cur = _mm256_add_ps(_cur, _b); -#endif + _cur = _mm256_comp_fmadd_ps(_cur, _a, _b); _mm256_storeu_ps(ptr, _cur); } } @@ -149,8 +144,7 @@ static NCNN_FORCEINLINE void fast_fmadd(float* ptr, float a, float b, int elemco for (; i + 4 <= elemcount; i += 4, ptr += 4) { __m128 _cur = _mm_loadu_ps(ptr); - _cur = _mm_mul_ps(_cur, _a); - _cur = _mm_add_ps(_cur, _b); + _cur = _mm_comp_fmadd_ps(_cur, _a, _b); _mm_storeu_ps(ptr, _cur); } } @@ -218,8 +212,7 @@ static NCNN_FORCEINLINE void fast_var_packed(float* ptr, float* var, float* mean { __m128 _cur = _mm_loadu_ps(ptr); _cur = _mm_sub_ps(_cur, _mean); - _cur = _mm_mul_ps(_cur, _cur); - _sq_sum = _mm_add_ps(_sq_sum, _cur); + _sq_sum = _mm_comp_fmadd_ps(_cur, _cur, _sq_sum); } __m128 _var = _mm_div_ps(_sq_sum, _elemcount); _mm_storeu_ps(var, _var); @@ -234,8 +227,7 @@ static NCNN_FORCEINLINE void fast_var_packed(float* ptr, float* var, float* mean { __m256 _cur = _mm256_loadu_ps(ptr); _cur = _mm256_sub_ps(_cur, _mean); - _cur = _mm256_mul_ps(_cur, _cur); - _sq_sum = _mm256_add_ps(_sq_sum, _cur); + _sq_sum = _mm256_comp_fmadd_ps(_cur, _cur, _sq_sum); } __m256 _var = _mm256_div_ps(_sq_sum, _elemcount); _mm256_storeu_ps(var, _var); @@ -251,8 +243,7 @@ static NCNN_FORCEINLINE void fast_var_packed(float* ptr, float* var, float* mean { __m512 _cur = _mm512_loadu_ps(ptr); _cur = _mm512_sub_ps(_cur, _mean); - _cur = _mm512_mul_ps(_cur, _cur); - _sq_sum = _mm512_add_ps(_sq_sum, _cur); + _sq_sum = _mm512_fmadd_ps(_cur, _cur, _sq_sum); } __m512 _var = _mm512_div_ps(_sq_sum, _elemcount); _mm512_storeu_ps(var, _var); @@ -270,8 +261,7 @@ static NCNN_FORCEINLINE void fast_fmadd_packed(float* ptr, float* a, float* b, i for (; i < size; i += 4, ptr += 4) { __m128 _cur = _mm_loadu_ps(ptr); - _cur = _mm_mul_ps(_cur, _a); - _cur = _mm_add_ps(_cur, _b); + _cur = _mm_comp_fmadd_ps(_cur, _a, _b); _mm_storeu_ps(ptr, _cur); } } @@ -283,12 +273,7 @@ static NCNN_FORCEINLINE void fast_fmadd_packed(float* ptr, float* a, float* b, i for (; i < size; i += 8, ptr += 8) { __m256 _cur = _mm256_loadu_ps(ptr); -#if __FMA__ - _cur = _mm256_fmadd_ps(_cur, _a, _b); -#else - _cur = _mm256_mul_ps(_cur, _a); - _cur = _mm256_add_ps(_cur, _b); -#endif + _cur = _mm256_comp_fmadd_ps(_cur, _a, _b); _mm256_storeu_ps(ptr, _cur); } #endif @@ -320,8 +305,8 @@ LayerNorm_x86::LayerNorm_x86() void NCNN_FORCEINLINE LayerNorm_x86::fast_fmadd_fmadd(float* ptr, float a, float b, int elemcount) const { int i = 0; - auto gamma = static_cast(gamma_data); - auto beta = static_cast(beta_data); + const float* gamma = static_cast(gamma_data); + const float* beta = static_cast(beta_data); #if __SSE2__ #if __AVX__ #if __AVX512F__ @@ -348,15 +333,8 @@ void NCNN_FORCEINLINE LayerNorm_x86::fast_fmadd_fmadd(float* ptr, float a, float __m256 _cur = _mm256_loadu_ps(ptr); __m256 _gamma = _mm256_loadu_ps(gamma); __m256 _beta = _mm256_loadu_ps(beta); -#if __FMA__ - _cur = _mm256_fmadd_ps(_cur, _a, _b); - _cur = _mm256_fmadd_ps(_cur, _gamma, _beta); -#else - _cur = _mm256_mul_ps(_cur, _a); - _cur = _mm256_add_ps(_cur, _b); - _cur = _mm256_mul_ps(_cur, _gamma); - _cur = _mm256_add_ps(_cur, _beta); -#endif + 
_cur = _mm256_comp_fmadd_ps(_cur, _a, _b); + _cur = _mm256_comp_fmadd_ps(_cur, _gamma, _beta); _mm256_storeu_ps(ptr, _cur); } } @@ -369,10 +347,8 @@ void NCNN_FORCEINLINE LayerNorm_x86::fast_fmadd_fmadd(float* ptr, float a, float __m128 _cur = _mm_loadu_ps(ptr); __m128 _gamma = _mm_loadu_ps(gamma); __m128 _beta = _mm_loadu_ps(beta); - _cur = _mm_mul_ps(_cur, _a); - _cur = _mm_add_ps(_cur, _b); - _cur = _mm_mul_ps(_cur, _gamma); - _cur = _mm_add_ps(_cur, _beta); + _cur = _mm_comp_fmadd_ps(_cur, _a, _b); + _cur = _mm_comp_fmadd_ps(_cur, _gamma, _beta); _mm_storeu_ps(ptr, _cur); } } @@ -386,8 +362,8 @@ void NCNN_FORCEINLINE LayerNorm_x86::fast_fmadd_fmadd(float* ptr, float a, float void NCNN_FORCEINLINE LayerNorm_x86::fast_fmadd_fmadd_packed(float* ptr, float* a, float* b, int elempack, int size) const { int i = 0; - auto gamma = static_cast(gamma_data); - auto beta = static_cast(beta_data); + const float* gamma = static_cast(gamma_data); + const float* beta = static_cast(beta_data); if (elempack == 4) { __m128 _a = _mm_loadu_ps(a); @@ -397,10 +373,8 @@ void NCNN_FORCEINLINE LayerNorm_x86::fast_fmadd_fmadd_packed(float* ptr, float* __m128 _cur = _mm_loadu_ps(ptr); __m128 _gamma = _mm_set1_ps(*gamma); __m128 _beta = _mm_set1_ps(*beta); - _cur = _mm_mul_ps(_cur, _a); - _cur = _mm_add_ps(_cur, _b); - _cur = _mm_mul_ps(_cur, _gamma); - _cur = _mm_add_ps(_cur, _beta); + _cur = _mm_comp_fmadd_ps(_cur, _a, _b); + _cur = _mm_comp_fmadd_ps(_cur, _gamma, _beta); _mm_storeu_ps(ptr, _cur); } } @@ -414,15 +388,8 @@ void NCNN_FORCEINLINE LayerNorm_x86::fast_fmadd_fmadd_packed(float* ptr, float* __m256 _cur = _mm256_loadu_ps(ptr); __m256 _gamma = _mm256_set1_ps(*gamma); __m256 _beta = _mm256_set1_ps(*beta); -#if __FMA__ - _cur = _mm256_fmadd_ps(_cur, _a, _b); - _cur = _mm256_fmadd_ps(_cur, _gamma, _beta); -#else - _cur = _mm256_mul_ps(_cur, _a); - _cur = _mm256_add_ps(_cur, _b); - _cur = _mm256_mul_ps(_cur, _gamma); - _cur = _mm256_add_ps(_cur, _beta); -#endif + _cur = _mm256_comp_fmadd_ps(_cur, _a, _b); + _cur = _mm256_comp_fmadd_ps(_cur, _gamma, _beta); _mm256_storeu_ps(ptr, _cur); } #endif @@ -480,9 +447,8 @@ void NCNN_FORCEINLINE LayerNorm_x86::fast_1d_layer_norm_packed(float* ptr, int e _var = _mm_add_ps(_var, _eps); __m128 _sqrt_var = _mm_sqrt_ps(_var); _a = _mm_div_ps(_a, _sqrt_var); - __m128 _mean_a = _mm_loadu_ps(mean); - _mean_a = _mm_mul_ps(_mean_a, _a); - _b = _mm_sub_ps(_b, _mean_a); + __m128 _mean = _mm_loadu_ps(mean); + _b = _mm_comp_fnmadd_ps(_mean, _a, _b); _mm_storeu_ps(a, _a); _mm_storeu_ps(b, _b); @@ -497,14 +463,8 @@ void NCNN_FORCEINLINE LayerNorm_x86::fast_1d_layer_norm_packed(float* ptr, int e _var = _mm256_add_ps(_var, _eps); __m256 _sqrt_var = _mm256_sqrt_ps(_var); _a = _mm256_div_ps(_a, _sqrt_var); -#if __FMA__ __m256 _mean = _mm256_loadu_ps(mean); - _b = _mm256_fnmadd_ps(_mean, _a, _b); -#else - __m256 _mean_a = _mm256_loadu_ps(mean); - _mean_a = _mm256_mul_ps(_mean_a, _a); - _b = _mm256_sub_ps(_b, _mean_a); -#endif + _b = _mm256_comp_fnmadd_ps(_mean, _a, _b); _mm256_storeu_ps(a, _a); _mm256_storeu_ps(b, _b); #endif @@ -553,7 +513,7 @@ int NCNN_FORCEINLINE LayerNorm_x86::forward_inplace_unpacked(Mat& bottom_top_blo int w = bottom_top_blob.w; int h = bottom_top_blob.h; int elemcount = w * bottom_top_blob.elempack; -#pragma omp parallel for num_threads(opt.num_threads) + #pragma omp parallel for num_threads(opt.num_threads) for (int i = 0; i < h; ++i) { float* ptr = bottom_top_blob.row(i); @@ -572,7 +532,7 @@ int NCNN_FORCEINLINE LayerNorm_x86::forward_inplace_unpacked(Mat& 
bottom_top_blo if (affine_size == w) { elemcount = w * bottom_top_blob.elempack; -#pragma omp parallel for num_threads(opt.num_threads) + #pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < channels; q++) { for (int i = 0; i < h; i++) @@ -584,7 +544,7 @@ int NCNN_FORCEINLINE LayerNorm_x86::forward_inplace_unpacked(Mat& bottom_top_blo } else // if (affine_elemcount == elemcount) { -#pragma omp parallel for num_threads(opt.num_threads) + #pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < channels; q++) { float* ptr = bottom_top_blob.channel(q); @@ -605,7 +565,7 @@ int NCNN_FORCEINLINE LayerNorm_x86::forward_inplace_packed(Mat& bottom_top_blob, // Now, bottoms_top_blob.dims >= 2 if (bottom_top_blob.dims == 2) { -#pragma omp parallel for num_threads(opt.num_threads) + #pragma omp parallel for num_threads(opt.num_threads) for (int i = 0; i < h; ++i) { float* ptr = bottom_top_blob.row(i); @@ -617,7 +577,7 @@ int NCNN_FORCEINLINE LayerNorm_x86::forward_inplace_packed(Mat& bottom_top_blob, int channels = bottom_top_blob.c; if (affine_size == w) { -#pragma omp parallel for num_threads(opt.num_threads) + #pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < channels; ++q) { for (int i = 0; i < h; ++i) @@ -629,7 +589,7 @@ int NCNN_FORCEINLINE LayerNorm_x86::forward_inplace_packed(Mat& bottom_top_blob, } else // if(affine_size == w * h) { -#pragma omp parallel for num_threads(opt.num_threads) + #pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < channels; ++q) { float* ptr = bottom_top_blob.channel(q); From bf9531275a1cd2c7e332b5cc641616eda6559403 Mon Sep 17 00:00:00 2001 From: LinHeLurking Date: Fri, 22 Jul 2022 05:13:48 +0000 Subject: [PATCH 07/22] Merge packed & unpacked code. 
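
The scalar (elempack == 1) path and the packed paths now share single helpers
(fast_mean_packed / fast_var_packed / fast_fmadd_packed) that branch on elempack,
so the horizontal reduction and the per-lane vertical reduction live in one place.
A minimal standalone sketch of that dispatch shape (hypothetical names, SSE2 lane
width only, not the actual ncnn helpers) is:

    // sketch_mean_dispatch.cpp - illustration of dispatch-on-elempack, not ncnn code
    #include <emmintrin.h> // SSE2 intrinsics
    #include <cstdio>

    // Per-lane means of `size` floats stored with `elempack` interleaved lanes.
    // elempack == 1: one scalar mean over all elements (horizontal reduction).
    // elempack == 4: four independent means, one per SSE lane (vertical reduction).
    static void mean_dispatch(const float* ptr, float* mean, int elempack, int elemcount, int size)
    {
        if (elempack == 4)
        {
            __m128 _sum = _mm_setzero_ps();
            for (int i = 0; i < size; i += 4, ptr += 4)
                _sum = _mm_add_ps(_sum, _mm_loadu_ps(ptr));
            _mm_storeu_ps(mean, _mm_div_ps(_sum, _mm_set1_ps((float)elemcount)));
        }
        else // elempack == 1
        {
            float sum = 0.f;
            for (int i = 0; i < size; i++)
                sum += ptr[i];
            mean[0] = sum / elemcount;
        }
    }

    int main()
    {
        const float x[8] = {1, 2, 3, 4, 5, 6, 7, 8};
        float m[4];
        mean_dispatch(x, m, 1, 8, 8); // scalar path: one mean over all 8 values
        printf("elempack=1: %f\n", m[0]);
        mean_dispatch(x, m, 4, 2, 8); // packed path: per-lane means over 2 rows
        printf("elempack=4: %f %f %f %f\n", m[0], m[1], m[2], m[3]);
        return 0;
    }
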
--- src/layer/x86/layernorm_x86.cpp | 653 ++++++++++++++------------------ src/layer/x86/layernorm_x86.h | 9 +- 2 files changed, 292 insertions(+), 370 deletions(-) diff --git a/src/layer/x86/layernorm_x86.cpp b/src/layer/x86/layernorm_x86.cpp index 99c9544f4d36..43a7155f2707 100644 --- a/src/layer/x86/layernorm_x86.cpp +++ b/src/layer/x86/layernorm_x86.cpp @@ -10,250 +10,234 @@ #endif // __AVX__ #endif // __SSE2__ -static NCNN_FORCEINLINE float fast_mean(float* ptr, int elemcount) +static NCNN_FORCEINLINE void fast_mean_packed(float* ptr, float* mean, int elempack, int elemcount, int size) { - float sum = 0.0f; int i = 0; + #if __SSE2__ #if __AVX__ #if __AVX512F__ + if (elempack == 16) { __m512 _sum = _mm512_setzero_ps(); - for (; i + 16 <= elemcount; i += 16, ptr += 16) + __m512 _elemcount = _mm512_set1_ps(float(elemcount)); + for (; i < size; i += 16, ptr += 16) { __m512 _cur = _mm512_loadu_ps(ptr); _sum = _mm512_add_ps(_sum, _cur); } - sum += _mm512_reduce_add_ps(_sum); + __m512 _mean = _mm512_div_ps(_sum, _elemcount); + _mm512_storeu_ps(mean, _mean); } #endif // __AVX512F__ + if (elempack == 8) { __m256 _sum = _mm256_setzero_ps(); - for (; i + 8 <= elemcount; i += 8, ptr += 8) + __m256 _elemcount = _mm256_set1_ps(float(elemcount)); + for (; i < size; i += 8, ptr += 8) { __m256 _cur = _mm256_loadu_ps(ptr); _sum = _mm256_add_ps(_sum, _cur); } - sum += _mm256_reduce_add_ps(_sum); + __m256 _mean = _mm256_div_ps(_sum, _elemcount); + _mm256_storeu_ps(mean, _mean); } #endif // __AVX__ + if (elempack == 4) { __m128 _sum = _mm_setzero_ps(); - for (; i + 4 <= elemcount; i += 4, ptr += 4) + __m128 _elemcount = _mm_set1_ps(float(elemcount)); + for (; i < size; i += 4, ptr += 4) { __m128 _cur = _mm_loadu_ps(ptr); _sum = _mm_add_ps(_sum, _cur); } - sum += _mm_reduce_add_ps(_sum); + __m128 _mean = _mm_div_ps(_sum, _elemcount); + _mm_storeu_ps(mean, _mean); } #endif // __SSE2__ - for (; i < elemcount; ++i, ++ptr) + if (elempack == 1) { - sum += *ptr; + float sum = 0.0f; + int i = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + { + __m512 _sum = _mm512_setzero_ps(); + for (; i + 16 <= elemcount; i += 16, ptr += 16) + { + __m512 _cur = _mm512_loadu_ps(ptr); + _sum = _mm512_add_ps(_sum, _cur); + } + sum += _mm512_reduce_add_ps(_sum); + } +#endif // __AVX512F__ + { + __m256 _sum = _mm256_setzero_ps(); + for (; i + 8 <= elemcount; i += 8, ptr += 8) + { + __m256 _cur = _mm256_loadu_ps(ptr); + _sum = _mm256_add_ps(_sum, _cur); + } + sum += _mm256_reduce_add_ps(_sum); + } +#endif // __AVX__ + { + __m128 _sum = _mm_setzero_ps(); + for (; i + 4 <= elemcount; i += 4, ptr += 4) + { + __m128 _cur = _mm_loadu_ps(ptr); + _sum = _mm_add_ps(_sum, _cur); + } + sum += _mm_reduce_add_ps(_sum); + } +#endif // __SSE2__ + for (; i < elemcount; ++i, ++ptr) + { + sum += *ptr; + } + *mean = sum / elemcount; } - return sum / elemcount; } -static NCNN_FORCEINLINE float fast_var(float* ptr, int elemcount, float mean) +static NCNN_FORCEINLINE void fast_var_packed(float* ptr, float* var, float* mean, int elempack, int elemcount, int size) { - float sq_sum = 0.0f; int i = 0; #if __SSE2__ #if __AVX__ #if __AVX512F__ + if (elempack == 16) { - __m512 _mean = _mm512_set1_ps(mean); + __m512 _mean = _mm512_loadu_ps(mean); __m512 _sq_sum = _mm512_setzero_ps(); - for (; i + 16 <= elemcount; i += 16, ptr += 16) + __m512 _elemcount = _mm512_set1_ps(float(elemcount)); + for (; i < size; i += 16, ptr += 16) { __m512 _cur = _mm512_loadu_ps(ptr); _cur = _mm512_sub_ps(_cur, _mean); - _cur = _mm512_mul_ps(_cur, _cur); - _sq_sum = 
_mm512_add_ps(_sq_sum, _cur); + _sq_sum = _mm512_fmadd_ps(_cur, _cur, _sq_sum); } - sq_sum += _mm512_reduce_add_ps(_sq_sum); + __m512 _var = _mm512_div_ps(_sq_sum, _elemcount); + _mm512_storeu_ps(var, _var); } #endif // __AVX512F__ + if (elempack == 8) { - __m256 _mean = _mm256_set1_ps(mean); + __m256 _mean = _mm256_loadu_ps(mean); __m256 _sq_sum = _mm256_setzero_ps(); - for (; i + 8 <= elemcount; i += 8, ptr += 8) + __m256 _elemcount = _mm256_set1_ps(float(elemcount)); + for (; i < size; i += 8, ptr += 8) { __m256 _cur = _mm256_loadu_ps(ptr); _cur = _mm256_sub_ps(_cur, _mean); - _cur = _mm256_mul_ps(_cur, _cur); - _sq_sum = _mm256_add_ps(_sq_sum, _cur); + _sq_sum = _mm256_comp_fmadd_ps(_cur, _cur, _sq_sum); } - sq_sum += _mm256_reduce_add_ps(_sq_sum); + __m256 _var = _mm256_div_ps(_sq_sum, _elemcount); + _mm256_storeu_ps(var, _var); } #endif // __AVX__ + if (elempack == 4) { - __m128 _mean = _mm_set1_ps(mean); + __m128 _mean = _mm_loadu_ps(mean); __m128 _sq_sum = _mm_setzero_ps(); - for (; i + 4 <= elemcount; i += 4, ptr += 4) + __m128 _elemcount = _mm_set1_ps(float(elemcount)); + for (; i < size; i += 4, ptr += 4) { __m128 _cur = _mm_loadu_ps(ptr); _cur = _mm_sub_ps(_cur, _mean); - _cur = _mm_mul_ps(_cur, _cur); - _sq_sum = _mm_add_ps(_sq_sum, _cur); + _sq_sum = _mm_comp_fmadd_ps(_cur, _cur, _sq_sum); } - sq_sum += _mm_reduce_add_ps(_sq_sum); + __m128 _var = _mm_div_ps(_sq_sum, _elemcount); + _mm_storeu_ps(var, _var); } #endif // __SSE2__ - for (; i < elemcount; ++i, ++ptr) + if (elempack == 1) { - float tmp = *ptr - mean; - sq_sum += tmp * tmp; - } - return sq_sum / elemcount; -} - -static NCNN_FORCEINLINE void fast_fmadd(float* ptr, float a, float b, int elemcount) -{ - int i = 0; + float sq_sum = 0.0f; + int i = 0; #if __SSE2__ #if __AVX__ #if __AVX512F__ - { - // 512 bit FMA instructions are included in AVX512F. 
- __m512 _a = _mm512_set1_ps(a); - __m512 _b = _mm512_set1_ps(b); - for (; i + 16 <= elemcount; i += 16, ptr += 16) { - __m512 _cur = _mm512_loadu_ps(ptr); - _cur = _mm512_fmadd_ps(_cur, _a, _b); - _mm512_storeu_ps(ptr, _cur); + __m512 _mean = _mm512_set1_ps(*mean); + __m512 _sq_sum = _mm512_setzero_ps(); + for (; i + 16 <= elemcount; i += 16, ptr += 16) + { + __m512 _cur = _mm512_loadu_ps(ptr); + _cur = _mm512_sub_ps(_cur, _mean); + _cur = _mm512_mul_ps(_cur, _cur); + _sq_sum = _mm512_add_ps(_sq_sum, _cur); + } + sq_sum += _mm512_reduce_add_ps(_sq_sum); } - } #endif // __AVX512F__ - { - // 256 bit FMA instructions are not included in AVX1 - __m256 _a = _mm256_set1_ps(a); - __m256 _b = _mm256_set1_ps(b); - for (; i + 8 <= elemcount; i += 8, ptr += 8) { - __m256 _cur = _mm256_loadu_ps(ptr); - _cur = _mm256_comp_fmadd_ps(_cur, _a, _b); - _mm256_storeu_ps(ptr, _cur); + __m256 _mean = _mm256_set1_ps(*mean); + __m256 _sq_sum = _mm256_setzero_ps(); + for (; i + 8 <= elemcount; i += 8, ptr += 8) + { + __m256 _cur = _mm256_loadu_ps(ptr); + _cur = _mm256_sub_ps(_cur, _mean); + _cur = _mm256_mul_ps(_cur, _cur); + _sq_sum = _mm256_add_ps(_sq_sum, _cur); + } + sq_sum += _mm256_reduce_add_ps(_sq_sum); } - } #endif // __AVX__ - { - __m128 _a = _mm_set1_ps(a); - __m128 _b = _mm_set1_ps(b); - for (; i + 4 <= elemcount; i += 4, ptr += 4) { - __m128 _cur = _mm_loadu_ps(ptr); - _cur = _mm_comp_fmadd_ps(_cur, _a, _b); - _mm_storeu_ps(ptr, _cur); + __m128 _mean = _mm_set1_ps(*mean); + __m128 _sq_sum = _mm_setzero_ps(); + for (; i + 4 <= elemcount; i += 4, ptr += 4) + { + __m128 _cur = _mm_loadu_ps(ptr); + _cur = _mm_sub_ps(_cur, _mean); + _cur = _mm_mul_ps(_cur, _cur); + _sq_sum = _mm_add_ps(_sq_sum, _cur); + } + sq_sum += _mm_reduce_add_ps(_sq_sum); } - } #endif // __SSE2__ - for (; i < elemcount; ++i, ++ptr) - { - *ptr = (*ptr) * a + b; + for (; i < elemcount; ++i, ++ptr) + { + float tmp = *ptr - *mean; + sq_sum += tmp * tmp; + } + *var = sq_sum / elemcount; } } -static NCNN_FORCEINLINE void fast_mean_packed(float* ptr, float* mean, int elempack, int elemcount, int size) +static NCNN_FORCEINLINE void fast_fmadd_packed(float* ptr, float* a, float* b, int elempack, int elemcount, int size) { int i = 0; - if (elempack == 4) - { - __m128 _sum = _mm_setzero_ps(); - __m128 _elemcount = _mm_set1_ps(float(elemcount)); - for (; i < size; i += 4, ptr += 4) - { - __m128 _cur = _mm_loadu_ps(ptr); - _sum = _mm_add_ps(_sum, _cur); - } - __m128 _mean = _mm_div_ps(_sum, _elemcount); - _mm_storeu_ps(mean, _mean); - } - else if (elempack == 8) - { + +#if __SSE2__ #if __AVX__ - __m256 _sum = _mm256_setzero_ps(); - __m256 _elemcount = _mm256_set1_ps(float(elemcount)); - for (; i < size; i += 8, ptr += 8) - { - __m256 _cur = _mm256_loadu_ps(ptr); - _sum = _mm256_add_ps(_sum, _cur); - } - __m256 _mean = _mm256_div_ps(_sum, _elemcount); - _mm256_storeu_ps(mean, _mean); -#endif - } - else if (elempack == 16) - { #if __AVX512F__ - __m512 _sum = _mm512_setzero_ps(); - __m512 _elemcount = _mm512_set1_ps(float(elemcount)); + if (elempack == 16) + { + __m512 _a = _mm512_loadu_ps(a); + __m512 _b = _mm512_loadu_ps(b); for (; i < size; i += 16, ptr += 16) { __m512 _cur = _mm512_loadu_ps(ptr); - _sum = _mm512_add_ps(_sum, _cur); - } - __m512 _mean = _mm512_div_ps(_sum, _elemcount); - _mm512_storeu_ps(mean, _mean); -#endif - } -} - -static NCNN_FORCEINLINE void fast_var_packed(float* ptr, float* var, float* mean, int elempack, int elemcount, int size) -{ - int i = 0; - if (elempack == 4) - { - __m128 _mean = _mm_loadu_ps(mean); - __m128 
_sq_sum = _mm_setzero_ps(); - __m128 _elemcount = _mm_set1_ps(float(elemcount)); - for (; i < size; i += 4, ptr += 4) - { - __m128 _cur = _mm_loadu_ps(ptr); - _cur = _mm_sub_ps(_cur, _mean); - _sq_sum = _mm_comp_fmadd_ps(_cur, _cur, _sq_sum); + _cur = _mm512_fmadd_ps(_cur, _a, _b); + _mm512_storeu_ps(ptr, _cur); } - __m128 _var = _mm_div_ps(_sq_sum, _elemcount); - _mm_storeu_ps(var, _var); } - else if (elempack == 8) +#endif // __AVX512F__ + if (elempack == 8) { -#if __AVX__ - __m256 _mean = _mm256_loadu_ps(mean); - __m256 _sq_sum = _mm256_setzero_ps(); - __m256 _elemcount = _mm256_set1_ps(float(elemcount)); + __m256 _a = _mm256_loadu_ps(a); + __m256 _b = _mm256_loadu_ps(b); for (; i < size; i += 8, ptr += 8) { __m256 _cur = _mm256_loadu_ps(ptr); - _cur = _mm256_sub_ps(_cur, _mean); - _sq_sum = _mm256_comp_fmadd_ps(_cur, _cur, _sq_sum); - } - __m256 _var = _mm256_div_ps(_sq_sum, _elemcount); - _mm256_storeu_ps(var, _var); -#endif - } - else if (elempack == 16) - { -#if __AVX512F__ - __m512 _mean = _mm512_loadu_ps(mean); - __m512 _sq_sum = _mm512_setzero_ps(); - __m512 _elemcount = _mm512_set1_ps(float(elemcount)); - for (; i < size; i += 16, ptr += 16) - { - __m512 _cur = _mm512_loadu_ps(ptr); - _cur = _mm512_sub_ps(_cur, _mean); - _sq_sum = _mm512_fmadd_ps(_cur, _cur, _sq_sum); + _cur = _mm256_comp_fmadd_ps(_cur, _a, _b); + _mm256_storeu_ps(ptr, _cur); } - __m512 _var = _mm512_div_ps(_sq_sum, _elemcount); - _mm512_storeu_ps(var, _var); -#endif } -} - -static NCNN_FORCEINLINE void fast_fmadd_packed(float* ptr, float* a, float* b, int elempack, int size) -{ - int i = 0; +#endif // __AVX__ if (elempack == 4) { __m128 _a = _mm_loadu_ps(a); @@ -265,31 +249,51 @@ static NCNN_FORCEINLINE void fast_fmadd_packed(float* ptr, float* a, float* b, i _mm_storeu_ps(ptr, _cur); } } - else if (elempack == 8) +#endif // __SSE2__ + if (elempack == 1) { +#if __SSE2__ #if __AVX__ - __m256 _a = _mm256_loadu_ps(a); - __m256 _b = _mm256_loadu_ps(b); - for (; i < size; i += 8, ptr += 8) +#if __AVX512F__ { - __m256 _cur = _mm256_loadu_ps(ptr); - _cur = _mm256_comp_fmadd_ps(_cur, _a, _b); - _mm256_storeu_ps(ptr, _cur); + // 512 bit FMA instructions are included in AVX512F. 
+ __m512 _a = _mm512_set1_ps(*a); + __m512 _b = _mm512_set1_ps(*b); + for (; i + 16 <= elemcount; i += 16, ptr += 16) + { + __m512 _cur = _mm512_loadu_ps(ptr); + _cur = _mm512_fmadd_ps(_cur, _a, _b); + _mm512_storeu_ps(ptr, _cur); + } } -#endif - } - else if (elempack == 16) - { -#if __AVX512F__ - __m512 _a = _mm512_loadu_ps(a); - __m512 _b = _mm512_loadu_ps(b); - for (; i < size; i += 16, ptr += 16) +#endif // __AVX512F__ { - __m512 _cur = _mm512_loadu_ps(ptr); - _cur = _mm512_fmadd_ps(_cur, _a, _b); - _mm512_storeu_ps(ptr, _cur); + // 256 bit FMA instructions are not included in AVX1 + __m256 _a = _mm256_set1_ps(*a); + __m256 _b = _mm256_set1_ps(*b); + for (; i + 8 <= elemcount; i += 8, ptr += 8) + { + __m256 _cur = _mm256_loadu_ps(ptr); + _cur = _mm256_comp_fmadd_ps(_cur, _a, _b); + _mm256_storeu_ps(ptr, _cur); + } + } +#endif // __AVX__ + { + __m128 _a = _mm_set1_ps(*a); + __m128 _b = _mm_set1_ps(*b); + for (; i + 4 <= elemcount; i += 4, ptr += 4) + { + __m128 _cur = _mm_loadu_ps(ptr); + _cur = _mm_comp_fmadd_ps(_cur, _a, _b); + _mm_storeu_ps(ptr, _cur); + } + } +#endif // __SSE2__ + for (; i < elemcount; ++i, ++ptr) + { + *ptr = (*ptr) * (*a) + (*b); } -#endif } } @@ -302,68 +306,45 @@ LayerNorm_x86::LayerNorm_x86() #endif // __SSE2__ } -void NCNN_FORCEINLINE LayerNorm_x86::fast_fmadd_fmadd(float* ptr, float a, float b, int elemcount) const +void NCNN_FORCEINLINE LayerNorm_x86::fast_fmadd_fmadd(float* ptr, float* a, float* b, int elempack, int elemcount, int size) const { int i = 0; const float* gamma = static_cast(gamma_data); const float* beta = static_cast(beta_data); + #if __SSE2__ #if __AVX__ #if __AVX512F__ + if (elempack == 16) { - __m512 _a = _mm512_set1_ps(a); - __m512 _b = _mm512_set1_ps(b); - for (; i + 16 <= elemcount; i += 16, ptr += 16, gamma += 16, beta += 16) + __m512 _a = _mm512_loadu_ps(a); + __m512 _b = _mm512_loadu_ps(b); + for (; i < size; i += 16, ptr += 16, ++gamma, ++beta) { __m512 _cur = _mm512_loadu_ps(ptr); - __m512 _gamma = _mm512_loadu_ps(gamma); - __m512 _beta = _mm512_loadu_ps(beta); + __m512 _gamma = _mm512_set1_ps(*gamma); + __m512 _beta = _mm512_set1_ps(*beta); _cur = _mm512_fmadd_ps(_cur, _a, _b); _cur = _mm512_fmadd_ps(_cur, _gamma, _beta); _mm512_storeu_ps(ptr, _cur); } } #endif // __AVX512F__ + if (elempack == 8) { - __m256 _a = _mm256_set1_ps(a); - __m256 _b = _mm256_set1_ps(b); - - for (; i + 8 <= elemcount; i += 8, ptr += 8, gamma += 8, beta += 8) + __m256 _a = _mm256_loadu_ps(a); + __m256 _b = _mm256_loadu_ps(b); + for (; i < size; i += 8, ptr += 8, ++gamma, ++beta) { __m256 _cur = _mm256_loadu_ps(ptr); - __m256 _gamma = _mm256_loadu_ps(gamma); - __m256 _beta = _mm256_loadu_ps(beta); + __m256 _gamma = _mm256_set1_ps(*gamma); + __m256 _beta = _mm256_set1_ps(*beta); _cur = _mm256_comp_fmadd_ps(_cur, _a, _b); _cur = _mm256_comp_fmadd_ps(_cur, _gamma, _beta); _mm256_storeu_ps(ptr, _cur); } } #endif // __AVX__ - { - __m128 _a = _mm_set1_ps(a); - __m128 _b = _mm_set1_ps(b); - for (; i + 4 <= elemcount; i += 4, ptr += 4, gamma += 4, beta += 4) - { - __m128 _cur = _mm_loadu_ps(ptr); - __m128 _gamma = _mm_loadu_ps(gamma); - __m128 _beta = _mm_loadu_ps(beta); - _cur = _mm_comp_fmadd_ps(_cur, _a, _b); - _cur = _mm_comp_fmadd_ps(_cur, _gamma, _beta); - _mm_storeu_ps(ptr, _cur); - } - } -#endif // __SSE2__ - for (; i < elemcount; ++i, ++ptr, ++gamma, ++beta) - { - *ptr = ((*ptr) * a + b) * (*gamma) + (*beta); - } -} - -void NCNN_FORCEINLINE LayerNorm_x86::fast_fmadd_fmadd_packed(float* ptr, float* a, float* b, int elempack, int size) const -{ - int i = 0; 
- const float* gamma = static_cast(gamma_data); - const float* beta = static_cast(beta_data); if (elempack == 4) { __m128 _a = _mm_loadu_ps(a); @@ -378,84 +359,90 @@ void NCNN_FORCEINLINE LayerNorm_x86::fast_fmadd_fmadd_packed(float* ptr, float* _mm_storeu_ps(ptr, _cur); } } - else if (elempack == 8) +#endif // __SSE2__ + if (elempack == 1) { +#if __SSE2__ #if __AVX__ - __m256 _a = _mm256_loadu_ps(a); - __m256 _b = _mm256_loadu_ps(b); - for (; i < size; i += 8, ptr += 8, ++gamma, ++beta) - { - __m256 _cur = _mm256_loadu_ps(ptr); - __m256 _gamma = _mm256_set1_ps(*gamma); - __m256 _beta = _mm256_set1_ps(*beta); - _cur = _mm256_comp_fmadd_ps(_cur, _a, _b); - _cur = _mm256_comp_fmadd_ps(_cur, _gamma, _beta); - _mm256_storeu_ps(ptr, _cur); - } -#endif - } - else if (elempack == 16) - { #if __AVX512F__ - __m512 _a = _mm512_loadu_ps(a); - __m512 _b = _mm512_loadu_ps(b); - for (; i < size; i += 16, ptr += 16, ++gamma, ++beta) { - __m512 _cur = _mm512_loadu_ps(ptr); - __m512 _gamma = _mm512_set1_ps(*gamma); - __m512 _beta = _mm512_set1_ps(*beta); - _cur = _mm512_fmadd_ps(_cur, _a, _b); - _cur = _mm512_fmadd_ps(_cur, _gamma, _beta); - _mm512_storeu_ps(ptr, _cur); + __m512 _a = _mm512_set1_ps(*a); + __m512 _b = _mm512_set1_ps(*b); + for (; i + 16 <= elemcount; i += 16, ptr += 16, gamma += 16, beta += 16) + { + __m512 _cur = _mm512_loadu_ps(ptr); + __m512 _gamma = _mm512_loadu_ps(gamma); + __m512 _beta = _mm512_loadu_ps(beta); + _cur = _mm512_fmadd_ps(_cur, _a, _b); + _cur = _mm512_fmadd_ps(_cur, _gamma, _beta); + _mm512_storeu_ps(ptr, _cur); + } } -#endif - } -} - -void NCNN_FORCEINLINE LayerNorm_x86::fast_1d_layer_norm(float* ptr, int elemcount) const -{ - // mean and var - float mean = fast_mean(ptr, elemcount); - float var = fast_var(ptr, elemcount, mean); - - float a = static_cast(1.0f / sqrt(var + eps)); - float b = -mean * a; +#endif // __AVX512F__ + { + __m256 _a = _mm256_set1_ps(*a); + __m256 _b = _mm256_set1_ps(*b); - if (affine) - { - fast_fmadd_fmadd(ptr, a, b, elemcount); - } - else - { - fast_fmadd(ptr, a, b, elemcount); + for (; i + 8 <= elemcount; i += 8, ptr += 8, gamma += 8, beta += 8) + { + __m256 _cur = _mm256_loadu_ps(ptr); + __m256 _gamma = _mm256_loadu_ps(gamma); + __m256 _beta = _mm256_loadu_ps(beta); + _cur = _mm256_comp_fmadd_ps(_cur, _a, _b); + _cur = _mm256_comp_fmadd_ps(_cur, _gamma, _beta); + _mm256_storeu_ps(ptr, _cur); + } + } +#endif // __AVX__ + { + __m128 _a = _mm_set1_ps(*a); + __m128 _b = _mm_set1_ps(*b); + for (; i + 4 <= elemcount; i += 4, ptr += 4, gamma += 4, beta += 4) + { + __m128 _cur = _mm_loadu_ps(ptr); + __m128 _gamma = _mm_loadu_ps(gamma); + __m128 _beta = _mm_loadu_ps(beta); + _cur = _mm_comp_fmadd_ps(_cur, _a, _b); + _cur = _mm_comp_fmadd_ps(_cur, _gamma, _beta); + _mm_storeu_ps(ptr, _cur); + } + } +#endif // __SSE2__ + for (; i < elemcount; ++i, ++ptr, ++gamma, ++beta) + { + *ptr = ((*ptr) * (*a) + (*b)) * (*gamma) + (*beta); + } } } -void NCNN_FORCEINLINE LayerNorm_x86::fast_1d_layer_norm_packed(float* ptr, int elempack, int elemcount, int size) const +void NCNN_FORCEINLINE LayerNorm_x86::fast_1d_layer_norm(float* ptr, int elempack, int elemcount, int size) const { float mean[16], var[16]; fast_mean_packed(ptr, mean, elempack, elemcount, size); fast_var_packed(ptr, var, mean, elempack, elemcount, size); float *a = var, *b = mean; - if (elempack == 4) +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + if (elempack == 16) { - __m128 _a = _mm_set1_ps(1.0f); - __m128 _eps = _mm_set1_ps(eps); - __m128 _b = _mm_setzero_ps(); - __m128 _var = _mm_loadu_ps(var); - 
_var = _mm_add_ps(_var, _eps); - __m128 _sqrt_var = _mm_sqrt_ps(_var); - _a = _mm_div_ps(_a, _sqrt_var); - __m128 _mean = _mm_loadu_ps(mean); - _b = _mm_comp_fnmadd_ps(_mean, _a, _b); + __m512 _a = _mm512_set1_ps(1.0f); + __m512 _eps = _mm512_set1_ps(eps); + __m512 _b = _mm512_setzero_ps(); + __m512 _var = _mm512_loadu_ps(var); + _var = _mm512_add_ps(_var, _eps); + __m512 _sqrt_var = _mm512_sqrt_ps(_var); + _a = _mm512_div_ps(_a, _sqrt_var); + __m512 _mean = _mm512_loadu_ps(mean); + _b = _mm512_fnmadd_ps(_mean, _a, _b); - _mm_storeu_ps(a, _a); - _mm_storeu_ps(b, _b); + _mm512_storeu_ps(a, _a); + _mm512_storeu_ps(b, _b); } - else if (elempack == 8) +#endif // __AVX512F__ + if (elempack == 8) { -#if __AVX__ __m256 _a = _mm256_set1_ps(1.0f); __m256 _eps = _mm256_set1_ps(eps); __m256 _b = _mm256_setzero_ps(); @@ -467,133 +454,85 @@ void NCNN_FORCEINLINE LayerNorm_x86::fast_1d_layer_norm_packed(float* ptr, int e _b = _mm256_comp_fnmadd_ps(_mean, _a, _b); _mm256_storeu_ps(a, _a); _mm256_storeu_ps(b, _b); -#endif } - else if (elempack == 16) +#endif // __AVX__ + if (elempack == 4) { -#if __AVX512F__ - __m512 _a = _mm512_set1_ps(1.0f); - __m512 _eps = _mm512_set1_ps(eps); - __m512 _b = _mm512_setzero_ps(); - __m512 _var = _mm512_loadu_ps(var); - _var = _mm512_add_ps(_var, _eps); - __m512 _sqrt_var = _mm512_sqrt_ps(_var); - _a = _mm512_div_ps(_a, _sqrt_var); - __m512 _mean = _mm512_loadu_ps(mean); - _b = _mm512_fnmadd_ps(_mean, _a, _b); + __m128 _a = _mm_set1_ps(1.0f); + __m128 _eps = _mm_set1_ps(eps); + __m128 _b = _mm_setzero_ps(); + __m128 _var = _mm_loadu_ps(var); + _var = _mm_add_ps(_var, _eps); + __m128 _sqrt_var = _mm_sqrt_ps(_var); + _a = _mm_div_ps(_a, _sqrt_var); + __m128 _mean = _mm_loadu_ps(mean); + _b = _mm_comp_fnmadd_ps(_mean, _a, _b); - _mm512_storeu_ps(a, _a); - _mm512_storeu_ps(b, _b); -#endif + _mm_storeu_ps(a, _a); + _mm_storeu_ps(b, _b); + } +#endif // __SSE2__ + if (elempack == 1) + { + *a = static_cast(1.0f / sqrt(*var + eps)); + *b = -*mean * (*a); } if (affine) { - fast_fmadd_fmadd_packed(ptr, a, b, elempack, size); + fast_fmadd_fmadd(ptr, a, b, elempack, elemcount, size); } else { - fast_fmadd_packed(ptr, a, b, elempack, size); + fast_fmadd_packed(ptr, a, b, elempack, elemcount, size); } } -int NCNN_FORCEINLINE LayerNorm_x86::forward_inplace_unpacked(Mat& bottom_top_blob, const Option& opt) const +int LayerNorm_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const { int dims = bottom_top_blob.dims; + int elempack = bottom_top_blob.elempack; + int w = bottom_top_blob.w; + int h = bottom_top_blob.h; + int channels = bottom_top_blob.c; if (dims == 1) { - int elemcount = bottom_top_blob.w * bottom_top_blob.elempack; + int elemcount = w * elempack; float* ptr = bottom_top_blob; - fast_1d_layer_norm(ptr, elemcount); - } - - if (dims == 2) - { - int w = bottom_top_blob.w; - int h = bottom_top_blob.h; - int elemcount = w * bottom_top_blob.elempack; - #pragma omp parallel for num_threads(opt.num_threads) - for (int i = 0; i < h; ++i) - { - float* ptr = bottom_top_blob.row(i); - - fast_1d_layer_norm(ptr, elemcount); - } - } - - if (dims == 3) - { - int w = bottom_top_blob.w; - int h = bottom_top_blob.h; - int channels = bottom_top_blob.c; - int elemcount = w * h * bottom_top_blob.elempack; - - if (affine_size == w) - { - elemcount = w * bottom_top_blob.elempack; - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - for (int i = 0; i < h; i++) - { - float* ptr = bottom_top_blob.channel(q).row(i); - fast_1d_layer_norm(ptr, 
elemcount); - } - } - } - else // if (affine_elemcount == elemcount) - { - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - float* ptr = bottom_top_blob.channel(q); - fast_1d_layer_norm(ptr, elemcount); - } - } + // 1D layer norm is special. Treat them as unpacked. + fast_1d_layer_norm(ptr, 1, elemcount, elemcount); } - - return 0; -} - -int NCNN_FORCEINLINE LayerNorm_x86::forward_inplace_packed(Mat& bottom_top_blob, const Option& opt) const -{ - int elempack = bottom_top_blob.elempack; - int w = bottom_top_blob.w; - int h = bottom_top_blob.h; - - // Now, bottoms_top_blob.dims >= 2 - if (bottom_top_blob.dims == 2) + else if (dims == 2) { - #pragma omp parallel for num_threads(opt.num_threads) +#pragma omp parallel for num_threads(opt.num_threads) for (int i = 0; i < h; ++i) { float* ptr = bottom_top_blob.row(i); - fast_1d_layer_norm_packed(ptr, elempack, w, w * elempack); + fast_1d_layer_norm(ptr, elempack, w, w * elempack); } } - else if (bottom_top_blob.dims == 3) + else if (dims == 3) { - int channels = bottom_top_blob.c; if (affine_size == w) { - #pragma omp parallel for num_threads(opt.num_threads) +#pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < channels; ++q) { for (int i = 0; i < h; ++i) { float* ptr = bottom_top_blob.channel(q).row(i); - fast_1d_layer_norm_packed(ptr, elempack, w, w * elempack); + fast_1d_layer_norm(ptr, elempack, w, w * elempack); } } } else // if(affine_size == w * h) { - #pragma omp parallel for num_threads(opt.num_threads) +#pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < channels; ++q) { float* ptr = bottom_top_blob.channel(q); - fast_1d_layer_norm_packed(ptr, elempack, w * h, w * h * elempack); + fast_1d_layer_norm(ptr, elempack, w * h, w * h * elempack); } } } @@ -601,16 +540,4 @@ int NCNN_FORCEINLINE LayerNorm_x86::forward_inplace_packed(Mat& bottom_top_blob, return 0; } -int LayerNorm_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const -{ - if (bottom_top_blob.elempack == 1 || bottom_top_blob.dims == 1) - { - return forward_inplace_unpacked(bottom_top_blob, opt); - } - else - { - return forward_inplace_packed(bottom_top_blob, opt); - } -} - } // namespace ncnn diff --git a/src/layer/x86/layernorm_x86.h b/src/layer/x86/layernorm_x86.h index 20423f6a8464..e4b81d5ca8d3 100644 --- a/src/layer/x86/layernorm_x86.h +++ b/src/layer/x86/layernorm_x86.h @@ -13,13 +13,8 @@ class LayerNorm_x86 : virtual public LayerNorm virtual int forward_inplace(Mat& bottom_top_blob, const Option& opt) const; protected: - NCNN_FORCEINLINE void fast_1d_layer_norm(float* ptr, int elemcount) const; - NCNN_FORCEINLINE void fast_1d_layer_norm_packed(float* ptr, int elempack, int elemcount, int size) const; - NCNN_FORCEINLINE void fast_fmadd_fmadd(float* ptr, float a, float b, int elemcount) const; - NCNN_FORCEINLINE void fast_fmadd_fmadd_packed(float* ptr, float* a, float* b, int elempack, int size) const; - - NCNN_FORCEINLINE int forward_inplace_unpacked(Mat& bottom_top_blob, const Option& opt) const; - NCNN_FORCEINLINE int forward_inplace_packed(Mat& bottom_top_blob, const Option& opt) const; + NCNN_FORCEINLINE void fast_1d_layer_norm(float* ptr, int elempack, int elemcount, int size) const; + NCNN_FORCEINLINE void fast_fmadd_fmadd(float* ptr, float* a, float* b, int elempack, int elemcount, int size) const; }; } // namespace ncnn From af97b0531c78be926590fd11e135251776359441 Mon Sep 17 00:00:00 2001 From: LinHeLurking Date: Fri, 22 Jul 2022 05:15:43 +0000 Subject: [PATCH 
08/22] Func rename. --- src/layer/x86/layernorm_x86.cpp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/layer/x86/layernorm_x86.cpp b/src/layer/x86/layernorm_x86.cpp index 43a7155f2707..0f041d4f29a8 100644 --- a/src/layer/x86/layernorm_x86.cpp +++ b/src/layer/x86/layernorm_x86.cpp @@ -10,7 +10,7 @@ #endif // __AVX__ #endif // __SSE2__ -static NCNN_FORCEINLINE void fast_mean_packed(float* ptr, float* mean, int elempack, int elemcount, int size) +static NCNN_FORCEINLINE void fast_mean(float* ptr, float* mean, int elempack, int elemcount, int size) { int i = 0; @@ -101,7 +101,7 @@ static NCNN_FORCEINLINE void fast_mean_packed(float* ptr, float* mean, int elemp } } -static NCNN_FORCEINLINE void fast_var_packed(float* ptr, float* var, float* mean, int elempack, int elemcount, int size) +static NCNN_FORCEINLINE void fast_var(float* ptr, float* var, float* mean, int elempack, int elemcount, int size) { int i = 0; #if __SSE2__ @@ -207,7 +207,7 @@ static NCNN_FORCEINLINE void fast_var_packed(float* ptr, float* var, float* mean } } -static NCNN_FORCEINLINE void fast_fmadd_packed(float* ptr, float* a, float* b, int elempack, int elemcount, int size) +static NCNN_FORCEINLINE void fast_fmadd(float* ptr, float* a, float* b, int elempack, int elemcount, int size) { int i = 0; @@ -418,8 +418,8 @@ void NCNN_FORCEINLINE LayerNorm_x86::fast_fmadd_fmadd(float* ptr, float* a, floa void NCNN_FORCEINLINE LayerNorm_x86::fast_1d_layer_norm(float* ptr, int elempack, int elemcount, int size) const { float mean[16], var[16]; - fast_mean_packed(ptr, mean, elempack, elemcount, size); - fast_var_packed(ptr, var, mean, elempack, elemcount, size); + fast_mean(ptr, mean, elempack, elemcount, size); + fast_var(ptr, var, mean, elempack, elemcount, size); float *a = var, *b = mean; #if __SSE2__ @@ -484,7 +484,7 @@ void NCNN_FORCEINLINE LayerNorm_x86::fast_1d_layer_norm(float* ptr, int elempack } else { - fast_fmadd_packed(ptr, a, b, elempack, elemcount, size); + fast_fmadd(ptr, a, b, elempack, elemcount, size); } } From a9be63a9e326508e5ce0a8e3f8c5da934c551ce2 Mon Sep 17 00:00:00 2001 From: LinHeLurking Date: Fri, 22 Jul 2022 06:35:32 +0000 Subject: [PATCH 09/22] Simplify and merge more branches about packed layout; --- src/layer/x86/layernorm_x86.cpp | 343 +++++++++++++------------------- 1 file changed, 133 insertions(+), 210 deletions(-) diff --git a/src/layer/x86/layernorm_x86.cpp b/src/layer/x86/layernorm_x86.cpp index 0f041d4f29a8..44a979a98e45 100644 --- a/src/layer/x86/layernorm_x86.cpp +++ b/src/layer/x86/layernorm_x86.cpp @@ -13,87 +13,73 @@ static NCNN_FORCEINLINE void fast_mean(float* ptr, float* mean, int elempack, int elemcount, int size) { int i = 0; - + float sum = 0.0f; #if __SSE2__ #if __AVX__ #if __AVX512F__ - if (elempack == 16) + if (elempack == 16 || elempack == 1) { __m512 _sum = _mm512_setzero_ps(); - __m512 _elemcount = _mm512_set1_ps(float(elemcount)); - for (; i < size; i += 16, ptr += 16) + for (; i + 16 <= size; i += 16, ptr += 16) { __m512 _cur = _mm512_loadu_ps(ptr); _sum = _mm512_add_ps(_sum, _cur); } - __m512 _mean = _mm512_div_ps(_sum, _elemcount); - _mm512_storeu_ps(mean, _mean); + if (elempack == 16) + { + __m512 _elemcount = _mm512_set1_ps(float(elemcount)); + __m512 _mean = _mm512_div_ps(_sum, _elemcount); + _mm512_storeu_ps(mean, _mean); + } + else + { + sum += _mm512_reduce_add_ps(_sum); + } } #endif // __AVX512F__ - if (elempack == 8) + if (elempack == 8 || elempack == 1) { __m256 _sum = _mm256_setzero_ps(); - __m256 _elemcount = 
_mm256_set1_ps(float(elemcount)); - for (; i < size; i += 8, ptr += 8) + for (; i + 8 <= size; i += 8, ptr += 8) { __m256 _cur = _mm256_loadu_ps(ptr); _sum = _mm256_add_ps(_sum, _cur); } - __m256 _mean = _mm256_div_ps(_sum, _elemcount); - _mm256_storeu_ps(mean, _mean); + if (elempack == 8) + { + __m256 _elemcount = _mm256_set1_ps(float(elemcount)); + __m256 _mean = _mm256_div_ps(_sum, _elemcount); + _mm256_storeu_ps(mean, _mean); + } + else + { + sum += _mm256_reduce_add_ps(_sum); + } } #endif // __AVX__ - if (elempack == 4) + if (elempack == 4 || elempack == 1) { __m128 _sum = _mm_setzero_ps(); - __m128 _elemcount = _mm_set1_ps(float(elemcount)); - for (; i < size; i += 4, ptr += 4) + for (; i + 4 <= size; i += 4, ptr += 4) { __m128 _cur = _mm_loadu_ps(ptr); _sum = _mm_add_ps(_sum, _cur); } - __m128 _mean = _mm_div_ps(_sum, _elemcount); - _mm_storeu_ps(mean, _mean); - } -#endif // __SSE2__ - if (elempack == 1) - { - float sum = 0.0f; - int i = 0; -#if __SSE2__ -#if __AVX__ -#if __AVX512F__ - { - __m512 _sum = _mm512_setzero_ps(); - for (; i + 16 <= elemcount; i += 16, ptr += 16) - { - __m512 _cur = _mm512_loadu_ps(ptr); - _sum = _mm512_add_ps(_sum, _cur); - } - sum += _mm512_reduce_add_ps(_sum); - } -#endif // __AVX512F__ + if (elempack == 4) { - __m256 _sum = _mm256_setzero_ps(); - for (; i + 8 <= elemcount; i += 8, ptr += 8) - { - __m256 _cur = _mm256_loadu_ps(ptr); - _sum = _mm256_add_ps(_sum, _cur); - } - sum += _mm256_reduce_add_ps(_sum); + __m128 _elemcount = _mm_set1_ps(float(elemcount)); + __m128 _mean = _mm_div_ps(_sum, _elemcount); + _mm_storeu_ps(mean, _mean); } -#endif // __AVX__ + else { - __m128 _sum = _mm_setzero_ps(); - for (; i + 4 <= elemcount; i += 4, ptr += 4) - { - __m128 _cur = _mm_loadu_ps(ptr); - _sum = _mm_add_ps(_sum, _cur); - } sum += _mm_reduce_add_ps(_sum); } + } #endif // __SSE2__ - for (; i < elemcount; ++i, ++ptr) + if (elempack == 1) + { + for (; i < size; ++i, ++ptr) { sum += *ptr; } @@ -104,103 +90,82 @@ static NCNN_FORCEINLINE void fast_mean(float* ptr, float* mean, int elempack, in static NCNN_FORCEINLINE void fast_var(float* ptr, float* var, float* mean, int elempack, int elemcount, int size) { int i = 0; + float sq_sum = 0.0f; #if __SSE2__ #if __AVX__ #if __AVX512F__ - if (elempack == 16) + if (elempack == 16 || elempack == 1) { - __m512 _mean = _mm512_loadu_ps(mean); + __m512 _mean = elempack == 1 ? _mm512_set1_ps(*mean) : _mm512_loadu_ps(mean); __m512 _sq_sum = _mm512_setzero_ps(); - __m512 _elemcount = _mm512_set1_ps(float(elemcount)); - for (; i < size; i += 16, ptr += 16) + for (; i + 16 <= size; i += 16, ptr += 16) { __m512 _cur = _mm512_loadu_ps(ptr); _cur = _mm512_sub_ps(_cur, _mean); _sq_sum = _mm512_fmadd_ps(_cur, _cur, _sq_sum); } - __m512 _var = _mm512_div_ps(_sq_sum, _elemcount); - _mm512_storeu_ps(var, _var); + if (elempack == 16) + { + __m512 _elemcount = _mm512_set1_ps(float(elemcount)); + __m512 _var = _mm512_div_ps(_sq_sum, _elemcount); + _mm512_storeu_ps(var, _var); + } + else + { + sq_sum += _mm512_reduce_add_ps(_sq_sum); + } } #endif // __AVX512F__ - if (elempack == 8) + if (elempack == 8 || elempack == 1) { - __m256 _mean = _mm256_loadu_ps(mean); + __m256 _mean = elempack == 1 ? 
_mm256_set1_ps(*mean) : _mm256_loadu_ps(mean); __m256 _sq_sum = _mm256_setzero_ps(); - __m256 _elemcount = _mm256_set1_ps(float(elemcount)); - for (; i < size; i += 8, ptr += 8) + for (; i + 8 <= size; i += 8, ptr += 8) { __m256 _cur = _mm256_loadu_ps(ptr); _cur = _mm256_sub_ps(_cur, _mean); _sq_sum = _mm256_comp_fmadd_ps(_cur, _cur, _sq_sum); } - __m256 _var = _mm256_div_ps(_sq_sum, _elemcount); - _mm256_storeu_ps(var, _var); + if (elempack == 8) + { + __m256 _elemcount = _mm256_set1_ps(float(elemcount)); + __m256 _var = _mm256_div_ps(_sq_sum, _elemcount); + _mm256_storeu_ps(var, _var); + } + else + { + sq_sum += _mm256_reduce_add_ps(_sq_sum); + } } #endif // __AVX__ - if (elempack == 4) + if (elempack == 4 || elempack == 1) { - __m128 _mean = _mm_loadu_ps(mean); + __m128 _mean = elempack == 1 ? _mm_set1_ps(*mean) : _mm_loadu_ps(mean); __m128 _sq_sum = _mm_setzero_ps(); - __m128 _elemcount = _mm_set1_ps(float(elemcount)); - for (; i < size; i += 4, ptr += 4) + for (; i + 4 <= size; i += 4, ptr += 4) { __m128 _cur = _mm_loadu_ps(ptr); _cur = _mm_sub_ps(_cur, _mean); _sq_sum = _mm_comp_fmadd_ps(_cur, _cur, _sq_sum); } - __m128 _var = _mm_div_ps(_sq_sum, _elemcount); - _mm_storeu_ps(var, _var); - } -#endif // __SSE2__ - if (elempack == 1) - { - float sq_sum = 0.0f; - int i = 0; -#if __SSE2__ -#if __AVX__ -#if __AVX512F__ - { - __m512 _mean = _mm512_set1_ps(*mean); - __m512 _sq_sum = _mm512_setzero_ps(); - for (; i + 16 <= elemcount; i += 16, ptr += 16) - { - __m512 _cur = _mm512_loadu_ps(ptr); - _cur = _mm512_sub_ps(_cur, _mean); - _cur = _mm512_mul_ps(_cur, _cur); - _sq_sum = _mm512_add_ps(_sq_sum, _cur); - } - sq_sum += _mm512_reduce_add_ps(_sq_sum); - } -#endif // __AVX512F__ + if (elempack == 4) { - __m256 _mean = _mm256_set1_ps(*mean); - __m256 _sq_sum = _mm256_setzero_ps(); - for (; i + 8 <= elemcount; i += 8, ptr += 8) - { - __m256 _cur = _mm256_loadu_ps(ptr); - _cur = _mm256_sub_ps(_cur, _mean); - _cur = _mm256_mul_ps(_cur, _cur); - _sq_sum = _mm256_add_ps(_sq_sum, _cur); - } - sq_sum += _mm256_reduce_add_ps(_sq_sum); + __m128 _elemcount = _mm_set1_ps(float(elemcount)); + __m128 _var = _mm_div_ps(_sq_sum, _elemcount); + _mm_storeu_ps(var, _var); } -#endif // __AVX__ + else { - __m128 _mean = _mm_set1_ps(*mean); - __m128 _sq_sum = _mm_setzero_ps(); - for (; i + 4 <= elemcount; i += 4, ptr += 4) - { - __m128 _cur = _mm_loadu_ps(ptr); - _cur = _mm_sub_ps(_cur, _mean); - _cur = _mm_mul_ps(_cur, _cur); - _sq_sum = _mm_add_ps(_sq_sum, _cur); - } sq_sum += _mm_reduce_add_ps(_sq_sum); } + } #endif // __SSE2__ - for (; i < elemcount; ++i, ++ptr) + if (elempack == 1) + { + float _mean = *mean; + for (; i < size; ++i, ++ptr) { - float tmp = *ptr - *mean; + float tmp = *ptr - _mean; sq_sum += tmp * tmp; } *var = sq_sum / elemcount; @@ -214,11 +179,11 @@ static NCNN_FORCEINLINE void fast_fmadd(float* ptr, float* a, float* b, int elem #if __SSE2__ #if __AVX__ #if __AVX512F__ - if (elempack == 16) + if (elempack == 16 || elempack == 1) { - __m512 _a = _mm512_loadu_ps(a); - __m512 _b = _mm512_loadu_ps(b); - for (; i < size; i += 16, ptr += 16) + __m512 _a = elempack == 1 ? _mm512_set1_ps(*a) : _mm512_loadu_ps(a); + __m512 _b = elempack == 1 ? 
_mm512_set1_ps(*b) : _mm512_loadu_ps(b); + for (; i + 16 <= size; i += 16, ptr += 16) { __m512 _cur = _mm512_loadu_ps(ptr); _cur = _mm512_fmadd_ps(_cur, _a, _b); @@ -226,11 +191,11 @@ static NCNN_FORCEINLINE void fast_fmadd(float* ptr, float* a, float* b, int elem } } #endif // __AVX512F__ - if (elempack == 8) + if (elempack == 8 || elempack == 1) { - __m256 _a = _mm256_loadu_ps(a); - __m256 _b = _mm256_loadu_ps(b); - for (; i < size; i += 8, ptr += 8) + __m256 _a = elempack == 1 ? _mm256_set1_ps(*a) : _mm256_loadu_ps(a); + __m256 _b = elempack == 1 ? _mm256_set1_ps(*b) : _mm256_loadu_ps(b); + for (; i + 8 <= size; i += 8, ptr += 8) { __m256 _cur = _mm256_loadu_ps(ptr); _cur = _mm256_comp_fmadd_ps(_cur, _a, _b); @@ -238,11 +203,11 @@ static NCNN_FORCEINLINE void fast_fmadd(float* ptr, float* a, float* b, int elem } } #endif // __AVX__ - if (elempack == 4) + if (elempack == 4 || elempack == 1) { - __m128 _a = _mm_loadu_ps(a); - __m128 _b = _mm_loadu_ps(b); - for (; i < size; i += 4, ptr += 4) + __m128 _a = elempack == 1 ? _mm_set1_ps(*a) : _mm_loadu_ps(a); + __m128 _b = elempack == 1 ? _mm_set1_ps(*b) : _mm_loadu_ps(b); + for (; i + 4 <= size; i += 4, ptr += 4) { __m128 _cur = _mm_loadu_ps(ptr); _cur = _mm_comp_fmadd_ps(_cur, _a, _b); @@ -252,44 +217,6 @@ static NCNN_FORCEINLINE void fast_fmadd(float* ptr, float* a, float* b, int elem #endif // __SSE2__ if (elempack == 1) { -#if __SSE2__ -#if __AVX__ -#if __AVX512F__ - { - // 512 bit FMA instructions are included in AVX512F. - __m512 _a = _mm512_set1_ps(*a); - __m512 _b = _mm512_set1_ps(*b); - for (; i + 16 <= elemcount; i += 16, ptr += 16) - { - __m512 _cur = _mm512_loadu_ps(ptr); - _cur = _mm512_fmadd_ps(_cur, _a, _b); - _mm512_storeu_ps(ptr, _cur); - } - } -#endif // __AVX512F__ - { - // 256 bit FMA instructions are not included in AVX1 - __m256 _a = _mm256_set1_ps(*a); - __m256 _b = _mm256_set1_ps(*b); - for (; i + 8 <= elemcount; i += 8, ptr += 8) - { - __m256 _cur = _mm256_loadu_ps(ptr); - _cur = _mm256_comp_fmadd_ps(_cur, _a, _b); - _mm256_storeu_ps(ptr, _cur); - } - } -#endif // __AVX__ - { - __m128 _a = _mm_set1_ps(*a); - __m128 _b = _mm_set1_ps(*b); - for (; i + 4 <= elemcount; i += 4, ptr += 4) - { - __m128 _cur = _mm_loadu_ps(ptr); - _cur = _mm_comp_fmadd_ps(_cur, _a, _b); - _mm_storeu_ps(ptr, _cur); - } - } -#endif // __SSE2__ for (; i < elemcount; ++i, ++ptr) { *ptr = (*ptr) * (*a) + (*b); @@ -319,7 +246,7 @@ void NCNN_FORCEINLINE LayerNorm_x86::fast_fmadd_fmadd(float* ptr, float* a, floa { __m512 _a = _mm512_loadu_ps(a); __m512 _b = _mm512_loadu_ps(b); - for (; i < size; i += 16, ptr += 16, ++gamma, ++beta) + for (; i + 16 <= size; i += 16, ptr += 16, ++gamma, ++beta) { __m512 _cur = _mm512_loadu_ps(ptr); __m512 _gamma = _mm512_set1_ps(*gamma); @@ -329,12 +256,26 @@ void NCNN_FORCEINLINE LayerNorm_x86::fast_fmadd_fmadd(float* ptr, float* a, floa _mm512_storeu_ps(ptr, _cur); } } + else if (elempack == 1) + { + __m512 _a = _mm512_set1_ps(*a); + __m512 _b = _mm512_set1_ps(*b); + for (; i + 16 <= elemcount; i += 16, ptr += 16, gamma += 16, beta += 16) + { + __m512 _cur = _mm512_loadu_ps(ptr); + __m512 _gamma = _mm512_loadu_ps(gamma); + __m512 _beta = _mm512_loadu_ps(beta); + _cur = _mm512_fmadd_ps(_cur, _a, _b); + _cur = _mm512_fmadd_ps(_cur, _gamma, _beta); + _mm512_storeu_ps(ptr, _cur); + } + } #endif // __AVX512F__ if (elempack == 8) { __m256 _a = _mm256_loadu_ps(a); __m256 _b = _mm256_loadu_ps(b); - for (; i < size; i += 8, ptr += 8, ++gamma, ++beta) + for (; i + 8 <= size; i += 8, ptr += 8, ++gamma, ++beta) { __m256 _cur = 
_mm256_loadu_ps(ptr); __m256 _gamma = _mm256_set1_ps(*gamma); @@ -344,12 +285,26 @@ void NCNN_FORCEINLINE LayerNorm_x86::fast_fmadd_fmadd(float* ptr, float* a, floa _mm256_storeu_ps(ptr, _cur); } } + else if (elempack == 1) + { + __m256 _a = _mm256_set1_ps(*a); + __m256 _b = _mm256_set1_ps(*b); + for (; i + 8 <= elemcount; i += 8, ptr += 8, gamma += 8, beta += 8) + { + __m256 _cur = _mm256_loadu_ps(ptr); + __m256 _gamma = _mm256_loadu_ps(gamma); + __m256 _beta = _mm256_loadu_ps(beta); + _cur = _mm256_comp_fmadd_ps(_cur, _a, _b); + _cur = _mm256_comp_fmadd_ps(_cur, _gamma, _beta); + _mm256_storeu_ps(ptr, _cur); + } + } #endif // __AVX__ if (elempack == 4) { __m128 _a = _mm_loadu_ps(a); __m128 _b = _mm_loadu_ps(b); - for (; i < size; i += 4, ptr += 4, ++gamma, ++beta) + for (; i + 4 <= size; i += 4, ptr += 4, ++gamma, ++beta) { __m128 _cur = _mm_loadu_ps(ptr); __m128 _gamma = _mm_set1_ps(*gamma); @@ -359,55 +314,23 @@ void NCNN_FORCEINLINE LayerNorm_x86::fast_fmadd_fmadd(float* ptr, float* a, floa _mm_storeu_ps(ptr, _cur); } } -#endif // __SSE2__ - if (elempack == 1) + else if (elempack == 1) { -#if __SSE2__ -#if __AVX__ -#if __AVX512F__ - { - __m512 _a = _mm512_set1_ps(*a); - __m512 _b = _mm512_set1_ps(*b); - for (; i + 16 <= elemcount; i += 16, ptr += 16, gamma += 16, beta += 16) - { - __m512 _cur = _mm512_loadu_ps(ptr); - __m512 _gamma = _mm512_loadu_ps(gamma); - __m512 _beta = _mm512_loadu_ps(beta); - _cur = _mm512_fmadd_ps(_cur, _a, _b); - _cur = _mm512_fmadd_ps(_cur, _gamma, _beta); - _mm512_storeu_ps(ptr, _cur); - } - } -#endif // __AVX512F__ - { - __m256 _a = _mm256_set1_ps(*a); - __m256 _b = _mm256_set1_ps(*b); - - for (; i + 8 <= elemcount; i += 8, ptr += 8, gamma += 8, beta += 8) - { - __m256 _cur = _mm256_loadu_ps(ptr); - __m256 _gamma = _mm256_loadu_ps(gamma); - __m256 _beta = _mm256_loadu_ps(beta); - _cur = _mm256_comp_fmadd_ps(_cur, _a, _b); - _cur = _mm256_comp_fmadd_ps(_cur, _gamma, _beta); - _mm256_storeu_ps(ptr, _cur); - } - } -#endif // __AVX__ + __m128 _a = _mm_set1_ps(*a); + __m128 _b = _mm_set1_ps(*b); + for (; i + 4 <= elemcount; i += 4, ptr += 4, gamma += 4, beta += 4) { - __m128 _a = _mm_set1_ps(*a); - __m128 _b = _mm_set1_ps(*b); - for (; i + 4 <= elemcount; i += 4, ptr += 4, gamma += 4, beta += 4) - { - __m128 _cur = _mm_loadu_ps(ptr); - __m128 _gamma = _mm_loadu_ps(gamma); - __m128 _beta = _mm_loadu_ps(beta); - _cur = _mm_comp_fmadd_ps(_cur, _a, _b); - _cur = _mm_comp_fmadd_ps(_cur, _gamma, _beta); - _mm_storeu_ps(ptr, _cur); - } + __m128 _cur = _mm_loadu_ps(ptr); + __m128 _gamma = _mm_loadu_ps(gamma); + __m128 _beta = _mm_loadu_ps(beta); + _cur = _mm_comp_fmadd_ps(_cur, _a, _b); + _cur = _mm_comp_fmadd_ps(_cur, _gamma, _beta); + _mm_storeu_ps(ptr, _cur); } + } #endif // __SSE2__ + if (elempack == 1) + { for (; i < elemcount; ++i, ++ptr, ++gamma, ++beta) { *ptr = ((*ptr) * (*a) + (*b)) * (*gamma) + (*beta); From 976692ad04c9b7824f290a2bc71b9c6af49ced3b Mon Sep 17 00:00:00 2001 From: LinHeLurking Date: Sun, 24 Jul 2022 05:23:46 +0000 Subject: [PATCH 10/22] Code format --- src/layer/x86/layernorm_x86.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/layer/x86/layernorm_x86.cpp b/src/layer/x86/layernorm_x86.cpp index 44a979a98e45..8766ab964cdb 100644 --- a/src/layer/x86/layernorm_x86.cpp +++ b/src/layer/x86/layernorm_x86.cpp @@ -428,7 +428,7 @@ int LayerNorm_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) cons } else if (dims == 2) { -#pragma omp parallel for num_threads(opt.num_threads) + #pragma omp parallel for 
num_threads(opt.num_threads) for (int i = 0; i < h; ++i) { float* ptr = bottom_top_blob.row(i); @@ -439,7 +439,7 @@ int LayerNorm_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) cons { if (affine_size == w) { -#pragma omp parallel for num_threads(opt.num_threads) + #pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < channels; ++q) { for (int i = 0; i < h; ++i) @@ -451,7 +451,7 @@ int LayerNorm_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) cons } else // if(affine_size == w * h) { -#pragma omp parallel for num_threads(opt.num_threads) + #pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < channels; ++q) { float* ptr = bottom_top_blob.channel(q); From d7007c341ece453fd6077c59300e03003e0d5f48 Mon Sep 17 00:00:00 2001 From: LinHeLurking Date: Mon, 25 Jul 2022 03:28:05 +0000 Subject: [PATCH 11/22] Replace some member functions with static inline functions. --- src/layer/x86/layernorm_x86.cpp | 27 +++++++++++++++------------ src/layer/x86/layernorm_x86.h | 4 ---- 2 files changed, 15 insertions(+), 16 deletions(-) diff --git a/src/layer/x86/layernorm_x86.cpp b/src/layer/x86/layernorm_x86.cpp index 8766ab964cdb..c79b902eb01e 100644 --- a/src/layer/x86/layernorm_x86.cpp +++ b/src/layer/x86/layernorm_x86.cpp @@ -233,11 +233,11 @@ LayerNorm_x86::LayerNorm_x86() #endif // __SSE2__ } -void NCNN_FORCEINLINE LayerNorm_x86::fast_fmadd_fmadd(float* ptr, float* a, float* b, int elempack, int elemcount, int size) const +NCNN_FORCEINLINE static void fast_fmadd_fmadd(float* ptr, float* a, float* b, const float* gamma, const float* beta, int elempack, int elemcount, int size) { int i = 0; - const float* gamma = static_cast(gamma_data); - const float* beta = static_cast(beta_data); + // const float* gamma = static_cast(gamma_data); + // const float* beta = static_cast(beta_data); #if __SSE2__ #if __AVX__ @@ -338,7 +338,7 @@ void NCNN_FORCEINLINE LayerNorm_x86::fast_fmadd_fmadd(float* ptr, float* a, floa } } -void NCNN_FORCEINLINE LayerNorm_x86::fast_1d_layer_norm(float* ptr, int elempack, int elemcount, int size) const +NCNN_FORCEINLINE static void fast_1d_layer_norm(float* ptr, int elempack, int elemcount, int size, const float* gamma, const float* beta, int affine, float eps) { float mean[16], var[16]; fast_mean(ptr, mean, elempack, elemcount, size); @@ -403,7 +403,7 @@ void NCNN_FORCEINLINE LayerNorm_x86::fast_1d_layer_norm(float* ptr, int elempack if (affine) { - fast_fmadd_fmadd(ptr, a, b, elempack, elemcount, size); + fast_fmadd_fmadd(ptr, a, b, gamma, beta, elempack, elemcount, size); } else { @@ -419,43 +419,46 @@ int LayerNorm_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) cons int h = bottom_top_blob.h; int channels = bottom_top_blob.c; + const float* gamma = static_cast(gamma_data); + const float* beta = static_cast(beta_data); + if (dims == 1) { int elemcount = w * elempack; float* ptr = bottom_top_blob; // 1D layer norm is special. Treat them as unpacked. 
- fast_1d_layer_norm(ptr, 1, elemcount, elemcount); + fast_1d_layer_norm(ptr, 1, elemcount, elemcount, gamma, beta, affine, eps); } else if (dims == 2) { - #pragma omp parallel for num_threads(opt.num_threads) +#pragma omp parallel for num_threads(opt.num_threads) for (int i = 0; i < h; ++i) { float* ptr = bottom_top_blob.row(i); - fast_1d_layer_norm(ptr, elempack, w, w * elempack); + fast_1d_layer_norm(ptr, elempack, w, w * elempack, gamma, beta, affine, eps); } } else if (dims == 3) { if (affine_size == w) { - #pragma omp parallel for num_threads(opt.num_threads) +#pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < channels; ++q) { for (int i = 0; i < h; ++i) { float* ptr = bottom_top_blob.channel(q).row(i); - fast_1d_layer_norm(ptr, elempack, w, w * elempack); + fast_1d_layer_norm(ptr, elempack, w, w * elempack, gamma, beta, affine, eps); } } } else // if(affine_size == w * h) { - #pragma omp parallel for num_threads(opt.num_threads) +#pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < channels; ++q) { float* ptr = bottom_top_blob.channel(q); - fast_1d_layer_norm(ptr, elempack, w * h, w * h * elempack); + fast_1d_layer_norm(ptr, elempack, w * h, w * h * elempack, gamma, beta, affine, eps); } } } diff --git a/src/layer/x86/layernorm_x86.h b/src/layer/x86/layernorm_x86.h index e4b81d5ca8d3..98b7b92c6b74 100644 --- a/src/layer/x86/layernorm_x86.h +++ b/src/layer/x86/layernorm_x86.h @@ -11,10 +11,6 @@ class LayerNorm_x86 : virtual public LayerNorm LayerNorm_x86(); virtual int forward_inplace(Mat& bottom_top_blob, const Option& opt) const; - -protected: - NCNN_FORCEINLINE void fast_1d_layer_norm(float* ptr, int elempack, int elemcount, int size) const; - NCNN_FORCEINLINE void fast_fmadd_fmadd(float* ptr, float* a, float* b, int elempack, int elemcount, int size) const; }; } // namespace ncnn From 508d143c44cd894b28c9253f6500fac39e3a260d Mon Sep 17 00:00:00 2001 From: LinHeLurking Date: Mon, 25 Jul 2022 03:28:45 +0000 Subject: [PATCH 12/22] Add copyright header --- src/layer/x86/layernorm_x86.cpp | 14 ++++++++++++++ src/layer/x86/layernorm_x86.h | 14 ++++++++++++++ 2 files changed, 28 insertions(+) diff --git a/src/layer/x86/layernorm_x86.cpp b/src/layer/x86/layernorm_x86.cpp index c79b902eb01e..f34f51d3b565 100644 --- a/src/layer/x86/layernorm_x86.cpp +++ b/src/layer/x86/layernorm_x86.cpp @@ -1,3 +1,17 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + #include "layernorm_x86.h" #include "x86_usability.h" #include diff --git a/src/layer/x86/layernorm_x86.h b/src/layer/x86/layernorm_x86.h index 98b7b92c6b74..e6f902b55de3 100644 --- a/src/layer/x86/layernorm_x86.h +++ b/src/layer/x86/layernorm_x86.h @@ -1,3 +1,17 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved. 
+// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + #ifndef LAYER_LAYERNORM_X86_H #define LAYER_LAYERNORM_X86_H From cf015d8d94468913d44f59d72a8a88658026efd0 Mon Sep 17 00:00:00 2001 From: LinHeLurking Date: Mon, 25 Jul 2022 03:55:27 +0000 Subject: [PATCH 13/22] apply code-format changes --- src/layer/x86/layernorm_x86.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/layer/x86/layernorm_x86.cpp b/src/layer/x86/layernorm_x86.cpp index f34f51d3b565..58b3b9699ae0 100644 --- a/src/layer/x86/layernorm_x86.cpp +++ b/src/layer/x86/layernorm_x86.cpp @@ -445,7 +445,7 @@ int LayerNorm_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) cons } else if (dims == 2) { -#pragma omp parallel for num_threads(opt.num_threads) + #pragma omp parallel for num_threads(opt.num_threads) for (int i = 0; i < h; ++i) { float* ptr = bottom_top_blob.row(i); @@ -456,7 +456,7 @@ int LayerNorm_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) cons { if (affine_size == w) { -#pragma omp parallel for num_threads(opt.num_threads) + #pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < channels; ++q) { for (int i = 0; i < h; ++i) @@ -468,7 +468,7 @@ int LayerNorm_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) cons } else // if(affine_size == w * h) { -#pragma omp parallel for num_threads(opt.num_threads) + #pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < channels; ++q) { float* ptr = bottom_top_blob.channel(q); From 5084955f306de1b0816065a01d6fd3d8f132c156 Mon Sep 17 00:00:00 2001 From: LinHeLurking Date: Mon, 25 Jul 2022 08:27:03 +0000 Subject: [PATCH 14/22] Add more tests with 16 packed for AVX512 --- tests/test_layernorm.cpp | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/tests/test_layernorm.cpp b/tests/test_layernorm.cpp index b4e3ad7fa1e6..a6010e54e85d 100644 --- a/tests/test_layernorm.cpp +++ b/tests/test_layernorm.cpp @@ -42,12 +42,16 @@ static int test_layernorm_0() || test_layernorm(RandomMat(4, 5, 6), 4, 0.01f, 0) || test_layernorm(RandomMat(3, 3, 8), 3, 0.002f, 0) || test_layernorm(RandomMat(5, 6, 12), 5, 0.02f, 0) + || test_layernorm(RandomMat(4, 7, 16), 4, 0.02f, 0) || test_layernorm(RandomMat(6, 7, 24), 6, 0.001f, 0) + || test_layernorm(RandomMat(5, 8, 32), 5, 0.001f, 0) || test_layernorm(RandomMat(6, 4, 2), 6, 0.01f, 1) || test_layernorm(RandomMat(4, 5, 6), 4, 0.01f, 1) || test_layernorm(RandomMat(3, 3, 8), 3, 0.002f, 1) || test_layernorm(RandomMat(5, 6, 12), 5, 0.02f, 1) - || test_layernorm(RandomMat(6, 7, 24), 6, 0.001f, 1); + || test_layernorm(RandomMat(4, 7, 16), 4, 0.02f, 1) + || test_layernorm(RandomMat(6, 7, 24), 6, 0.001f, 1) + || test_layernorm(RandomMat(5, 8, 32), 5, 0.001f, 1); } static int test_layernorm_1() @@ -57,12 +61,16 @@ static int test_layernorm_1() || test_layernorm(RandomMat(4, 5, 6), 20, 0.01f, 0) || test_layernorm(RandomMat(3, 3, 8), 9, 0.002f, 0) || test_layernorm(RandomMat(5, 6, 12), 30, 0.02f, 0) + || test_layernorm(RandomMat(4, 7, 
16), 28, 0.02f, 0) || test_layernorm(RandomMat(6, 7, 24), 42, 0.001f, 0) + || test_layernorm(RandomMat(5, 8, 32), 40, 0.001f, 0) || test_layernorm(RandomMat(6, 4, 2), 24, 0.01f, 1) || test_layernorm(RandomMat(4, 5, 6), 20, 0.01f, 1) || test_layernorm(RandomMat(3, 3, 8), 9, 0.002f, 1) || test_layernorm(RandomMat(5, 6, 12), 30, 0.02f, 1) - || test_layernorm(RandomMat(6, 7, 24), 42, 0.001f, 1); + || test_layernorm(RandomMat(4, 7, 16), 28, 0.02f, 1) + || test_layernorm(RandomMat(6, 7, 24), 42, 0.001f, 1) + || test_layernorm(RandomMat(5, 8, 24), 40, 0.001f, 1); } static int test_layernorm_2() @@ -72,12 +80,16 @@ static int test_layernorm_2() || test_layernorm(RandomMat(5, 6), 5, 0.01f, 0) || test_layernorm(RandomMat(3, 8), 3, 0.002f, 0) || test_layernorm(RandomMat(6, 12), 6, 0.02f, 0) + || test_layernorm(RandomMat(4, 16), 4, 0.02f, 0) || test_layernorm(RandomMat(7, 24), 7, 0.001f, 0) + || test_layernorm(RandomMat(8, 32), 8, 0.001f, 0) || test_layernorm(RandomMat(4, 2), 4, 0.01f, 1) || test_layernorm(RandomMat(5, 6), 5, 0.01f, 1) || test_layernorm(RandomMat(3, 8), 3, 0.002f, 1) || test_layernorm(RandomMat(6, 12), 6, 0.02f, 1) - || test_layernorm(RandomMat(7, 24), 7, 0.001f, 1); + || test_layernorm(RandomMat(4, 16), 4, 0.02f, 1) + || test_layernorm(RandomMat(7, 24), 7, 0.001f, 1) + || test_layernorm(RandomMat(8, 32), 8, 0.001f, 1); } static int test_layernorm_3() @@ -87,12 +99,16 @@ static int test_layernorm_3() || test_layernorm(RandomMat(6), 6, 0.01f, 0) || test_layernorm(RandomMat(8), 8, 0.002f, 0) || test_layernorm(RandomMat(12), 12, 0.02f, 0) + || test_layernorm(RandomMat(16), 16, 0.02f, 0) || test_layernorm(RandomMat(24), 24, 0.001f, 0) + || test_layernorm(RandomMat(32), 32, 0.001f, 0) || test_layernorm(RandomMat(2), 2, 0.01f, 1) || test_layernorm(RandomMat(6), 6, 0.01f, 1) || test_layernorm(RandomMat(8), 8, 0.002f, 1) || test_layernorm(RandomMat(12), 12, 0.02f, 1) - || test_layernorm(RandomMat(24), 24, 0.001f, 1); + || test_layernorm(RandomMat(16), 16, 0.02f, 1) + || test_layernorm(RandomMat(24), 24, 0.001f, 1) + || test_layernorm(RandomMat(32), 32, 0.001f, 1); } int main() From 48fb4ea89c7a4ccab3ee86ed265a69cc2aa7321d Mon Sep 17 00:00:00 2001 From: LinHeLurking Date: Mon, 25 Jul 2022 08:30:56 +0000 Subject: [PATCH 15/22] Code format --- src/layer/x86/layernorm_x86.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/layer/x86/layernorm_x86.cpp b/src/layer/x86/layernorm_x86.cpp index f34f51d3b565..58b3b9699ae0 100644 --- a/src/layer/x86/layernorm_x86.cpp +++ b/src/layer/x86/layernorm_x86.cpp @@ -445,7 +445,7 @@ int LayerNorm_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) cons } else if (dims == 2) { -#pragma omp parallel for num_threads(opt.num_threads) + #pragma omp parallel for num_threads(opt.num_threads) for (int i = 0; i < h; ++i) { float* ptr = bottom_top_blob.row(i); @@ -456,7 +456,7 @@ int LayerNorm_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) cons { if (affine_size == w) { -#pragma omp parallel for num_threads(opt.num_threads) + #pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < channels; ++q) { for (int i = 0; i < h; ++i) @@ -468,7 +468,7 @@ int LayerNorm_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) cons } else // if(affine_size == w * h) { -#pragma omp parallel for num_threads(opt.num_threads) + #pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < channels; ++q) { float* ptr = bottom_top_blob.channel(q); From 487568d3e4f356b8fc55770479704dabc9dd9974 Mon Sep 17 
00:00:00 2001 From: LinHeLurking Date: Tue, 26 Jul 2022 03:52:00 +0000 Subject: [PATCH 16/22] Copyright statement year fixed --- src/layer/x86/layernorm_x86.cpp | 2 +- src/layer/x86/layernorm_x86.h | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/layer/x86/layernorm_x86.cpp b/src/layer/x86/layernorm_x86.cpp index 58b3b9699ae0..8cdc10211920 100644 --- a/src/layer/x86/layernorm_x86.cpp +++ b/src/layer/x86/layernorm_x86.cpp @@ -1,6 +1,6 @@ // Tencent is pleased to support the open source community by making ncnn available. // -// Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved. +// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. // // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except // in compliance with the License. You may obtain a copy of the License at diff --git a/src/layer/x86/layernorm_x86.h b/src/layer/x86/layernorm_x86.h index e6f902b55de3..e62e5dddee73 100644 --- a/src/layer/x86/layernorm_x86.h +++ b/src/layer/x86/layernorm_x86.h @@ -1,6 +1,6 @@ // Tencent is pleased to support the open source community by making ncnn available. // -// Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved. +// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. // // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except // in compliance with the License. You may obtain a copy of the License at From 23db5ab130e216363c73bea2c7344a5671be194f Mon Sep 17 00:00:00 2001 From: LinHeLurking Date: Tue, 26 Jul 2022 05:33:29 +0000 Subject: [PATCH 17/22] Fix accidentally added corelation of mean/var and SIMD ISA --- src/layer/x86/layernorm_x86.cpp | 293 ++++++++++++++++++++------------ 1 file changed, 189 insertions(+), 104 deletions(-) diff --git a/src/layer/x86/layernorm_x86.cpp b/src/layer/x86/layernorm_x86.cpp index 8cdc10211920..e0f4aae68b2a 100644 --- a/src/layer/x86/layernorm_x86.cpp +++ b/src/layer/x86/layernorm_x86.cpp @@ -31,73 +31,117 @@ static NCNN_FORCEINLINE void fast_mean(float* ptr, float* mean, int elempack, in #if __SSE2__ #if __AVX__ #if __AVX512F__ - if (elempack == 16 || elempack == 1) + __m512 _sum_512 = _mm512_setzero_ps(); + for (; i + 16 <= size; i += 16, ptr += 16) { - __m512 _sum = _mm512_setzero_ps(); - for (; i + 16 <= size; i += 16, ptr += 16) - { - __m512 _cur = _mm512_loadu_ps(ptr); - _sum = _mm512_add_ps(_sum, _cur); - } - if (elempack == 16) - { - __m512 _elemcount = _mm512_set1_ps(float(elemcount)); - __m512 _mean = _mm512_div_ps(_sum, _elemcount); - _mm512_storeu_ps(mean, _mean); - } - else - { - sum += _mm512_reduce_add_ps(_sum); - } + __m512 _cur = _mm512_loadu_ps(ptr); + _sum_512 = _mm512_add_ps(_sum_512, _cur); } #endif // __AVX512F__ - if (elempack == 8 || elempack == 1) + __m256 _sum_256 = _mm256_setzero_ps(); + for (; i + 8 <= size; i += 8, ptr += 8) { - __m256 _sum = _mm256_setzero_ps(); - for (; i + 8 <= size; i += 8, ptr += 8) + __m256 _cur = _mm256_loadu_ps(ptr); + _sum_256 = _mm256_add_ps(_sum_256, _cur); + } +#endif // __AVX__ + __m128 _sum_128 = _mm_setzero_ps(); + for (; i + 4 <= size; i += 4, ptr += 4) + { + __m128 _cur = _mm_loadu_ps(ptr); + _sum_128 = _mm_add_ps(_sum_128, _cur); + } +#endif // __SSE2__ + for (; i < size; ++i, ++ptr) + { + sum += *ptr; + } + + if (elempack == 16) + { + __m512 _mean = _mm512_div_ps(_sum_512, _mm512_set1_ps((float)elemcount)); + _mm512_storeu_ps(mean, _mean); + } + if (elempack == 8) + { +#if __AVX512F__ { - __m256 _cur = 
_mm256_loadu_ps(ptr); - _sum = _mm256_add_ps(_sum, _cur); + __m256 _low = _mm512_castps512_ps256(_sum_512); + __m256 _high = _mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(_sum_512), 1)); + _sum_256 = _mm256_add_ps(_sum_256, _high); + _sum_256 = _mm256_add_ps(_sum_256, _low); } - if (elempack == 8) +#endif // __AVX512F__ + __m256 _mean = _mm256_div_ps(_sum_256, _mm256_set1_ps((float)elemcount)); + _mm256_storeu_ps(mean, _mean); + // duplicate until len is 16 + _mm256_storeu_ps(mean + 8, _mean); + } + if (elempack == 4) + { +#if __AVX__ +#if __AVX512F__ { - __m256 _elemcount = _mm256_set1_ps(float(elemcount)); - __m256 _mean = _mm256_div_ps(_sum, _elemcount); - _mm256_storeu_ps(mean, _mean); + __m256 _low = _mm512_castps512_ps256(_sum_512); + __m256 _high = _mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(_sum_512), 1)); + _sum_256 = _mm256_add_ps(_sum_256, _high); + _sum_256 = _mm256_add_ps(_sum_256, _low); } - else +#endif // __AVX512F__ { - sum += _mm256_reduce_add_ps(_sum); + __m128 _low = _mm256_castps256_ps128(_sum_256); + __m128 _high = _mm256_extractf128_ps(_sum_256, 1); + _sum_128 = _mm_add_ps(_sum_128, _low); + _sum_128 = _mm_add_ps(_sum_128, _high); } - } #endif // __AVX__ - if (elempack == 4 || elempack == 1) + __m128 _mean = _mm_div_ps(_sum_128, _mm_set1_ps((float)elemcount)); + _mm_storeu_ps(mean, _mean); + // duplicate until len is 16 + _mm_storeu_ps(mean + 4, _mean); + _mm_storeu_ps(mean + 8, _mean); + _mm_storeu_ps(mean + 12, _mean); + } + if (elempack == 1) { - __m128 _sum = _mm_setzero_ps(); - for (; i + 4 <= size; i += 4, ptr += 4) +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + sum += _mm512_reduce_add_ps(_sum_512); +#endif // __AVX512F__ + sum += _mm256_reduce_add_ps(_sum_256); +#endif // __AVX__ + sum += _mm_reduce_add_ps(_sum_128); +#endif // __SSE2__ +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ { - __m128 _cur = _mm_loadu_ps(ptr); - _sum = _mm_add_ps(_sum, _cur); + _mm512_storeu_ps(mean, _mm512_set1_ps(sum / elemcount)); + return; } - if (elempack == 4) +#endif // __AVX512F__ { - __m128 _elemcount = _mm_set1_ps(float(elemcount)); - __m128 _mean = _mm_div_ps(_sum, _elemcount); - _mm_storeu_ps(mean, _mean); + __m256 _mean = _mm256_set1_ps(sum / elemcount); + _mm256_storeu_ps(mean, _mean); + _mm256_storeu_ps(mean + 8, _mean); + return; } - else +#endif // __AVX__ { - sum += _mm_reduce_add_ps(_sum); + __m128 _mean = _mm_set1_ps(sum / elemcount); + _mm_storeu_ps(mean, _mean); + _mm_storeu_ps(mean + 4, _mean); + _mm_storeu_ps(mean + 8, _mean); + _mm_storeu_ps(mean + 12, _mean); + return; } - } #endif // __SSE2__ - if (elempack == 1) - { - for (; i < size; ++i, ++ptr) + float _mean = sum / elemcount; + for (int i = 0; i < 16; ++i) { - sum += *ptr; + mean[i] = _mean; } - *mean = sum / elemcount; } } @@ -108,88 +152,129 @@ static NCNN_FORCEINLINE void fast_var(float* ptr, float* var, float* mean, int e #if __SSE2__ #if __AVX__ #if __AVX512F__ - if (elempack == 16 || elempack == 1) + __m512 _mean_512 = _mm512_loadu_ps(mean); + __m512 _sq_sum_512 = _mm512_setzero_ps(); + for (; i + 16 <= size; i += 16, ptr += 16) { - __m512 _mean = elempack == 1 ? 
_mm512_set1_ps(*mean) : _mm512_loadu_ps(mean); - __m512 _sq_sum = _mm512_setzero_ps(); - for (; i + 16 <= size; i += 16, ptr += 16) - { - __m512 _cur = _mm512_loadu_ps(ptr); - _cur = _mm512_sub_ps(_cur, _mean); - _sq_sum = _mm512_fmadd_ps(_cur, _cur, _sq_sum); - } - if (elempack == 16) - { - __m512 _elemcount = _mm512_set1_ps(float(elemcount)); - __m512 _var = _mm512_div_ps(_sq_sum, _elemcount); - _mm512_storeu_ps(var, _var); - } - else - { - sq_sum += _mm512_reduce_add_ps(_sq_sum); - } + __m512 _cur = _mm512_loadu_ps(ptr); + _cur = _mm512_sub_ps(_cur, _mean_512); + _sq_sum_512 = _mm512_fmadd_ps(_cur, _cur, _sq_sum_512); } #endif // __AVX512F__ - if (elempack == 8 || elempack == 1) + __m256 _mean_256 = _mm256_loadu_ps(mean); + __m256 _sq_sum_256 = _mm256_setzero_ps(); + for (; i + 8 <= size; i += 8, ptr += 8) { - __m256 _mean = elempack == 1 ? _mm256_set1_ps(*mean) : _mm256_loadu_ps(mean); - __m256 _sq_sum = _mm256_setzero_ps(); - for (; i + 8 <= size; i += 8, ptr += 8) + __m256 _cur = _mm256_loadu_ps(ptr); + _cur = _mm256_sub_ps(_cur, _mean_256); + _sq_sum_256 = _mm256_comp_fmadd_ps(_cur, _cur, _sq_sum_256); + } +#endif // __AVX__ + __m128 _mean_128 = _mm_loadu_ps(mean); + __m128 _sq_sum_128 = _mm_setzero_ps(); + for (; i + 4 <= size; i += 4, ptr += 4) + { + __m128 _cur = _mm_loadu_ps(ptr); + _cur = _mm_sub_ps(_cur, _mean_128); + _sq_sum_128 = _mm_comp_fmadd_ps(_cur, _cur, _sq_sum_128); + } +#endif // __SSE2__ + float _mean = *mean; + for (; i < size; ++i, ++ptr) + { + float tmp = *ptr - _mean; + sq_sum += tmp * tmp; + } + + if (elempack == 16) + { + __m512 _var = _mm512_div_ps(_sq_sum_512, _mm512_set1_ps((float)elemcount)); + _mm512_storeu_ps(var, _var); + } + if (elempack == 8) + { +#if __AVX512F__ { - __m256 _cur = _mm256_loadu_ps(ptr); - _cur = _mm256_sub_ps(_cur, _mean); - _sq_sum = _mm256_comp_fmadd_ps(_cur, _cur, _sq_sum); + __m256 _low = _mm512_castps512_ps256(_sq_sum_512); + __m256 _high = _mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(_sq_sum_512), 1)); + _sq_sum_256 = _mm256_add_ps(_sq_sum_256, _low); + _sq_sum_256 = _mm256_add_ps(_sq_sum_256, _high); } - if (elempack == 8) +#endif // __AVX512F__ + __m256 _var = _mm256_div_ps(_sq_sum_256, _mm256_set1_ps((float)elemcount)); + _mm256_storeu_ps(var, _var); + _mm256_storeu_ps(var + 8, _var); + } + if (elempack == 4) + { +#if __AVX__ +#if __AVX512F__ { - __m256 _elemcount = _mm256_set1_ps(float(elemcount)); - __m256 _var = _mm256_div_ps(_sq_sum, _elemcount); - _mm256_storeu_ps(var, _var); + __m256 _low = _mm512_castps512_ps256(_sq_sum_512); + __m256 _high = _mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(_sq_sum_512), 1)); + _sq_sum_256 = _mm256_add_ps(_sq_sum_256, _high); + _sq_sum_256 = _mm256_add_ps(_sq_sum_256, _low); } - else +#endif // __AVX512F__ { - sq_sum += _mm256_reduce_add_ps(_sq_sum); + __m128 _low = _mm256_castps256_ps128(_sq_sum_256); + __m128 _high = _mm256_extractf128_ps(_sq_sum_256, 1); + _sq_sum_128 = _mm_add_ps(_sq_sum_128, _low); + _sq_sum_128 = _mm_add_ps(_sq_sum_128, _high); } - } #endif // __AVX__ - if (elempack == 4 || elempack == 1) + __m128 _var = _mm_div_ps(_sq_sum_128, _mm_set1_ps((float)elemcount)); + _mm_storeu_ps(var, _var); + _mm_storeu_ps(var + 4, _var); + _mm_storeu_ps(var + 8, _var); + _mm_storeu_ps(var + 12, _var); + } + if (elempack == 1) { - __m128 _mean = elempack == 1 ? 
_mm_set1_ps(*mean) : _mm_loadu_ps(mean); - __m128 _sq_sum = _mm_setzero_ps(); - for (; i + 4 <= size; i += 4, ptr += 4) +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + sq_sum += _mm512_reduce_add_ps(_sq_sum_512); +#endif // __AVX512F__ + sq_sum += _mm256_reduce_add_ps(_sq_sum_256); +#endif // __AVX__ + sq_sum += _mm_reduce_add_ps(_sq_sum_128); +#endif // __SSE2__ +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ { - __m128 _cur = _mm_loadu_ps(ptr); - _cur = _mm_sub_ps(_cur, _mean); - _sq_sum = _mm_comp_fmadd_ps(_cur, _cur, _sq_sum); + _mm512_storeu_ps(var, _mm512_set1_ps(sq_sum / elemcount)); + return; } - if (elempack == 4) +#endif // __AVX512F__ { - __m128 _elemcount = _mm_set1_ps(float(elemcount)); - __m128 _var = _mm_div_ps(_sq_sum, _elemcount); - _mm_storeu_ps(var, _var); + __m256 _var = _mm256_set1_ps(sq_sum / elemcount); + _mm256_storeu_ps(var, _var); + _mm256_storeu_ps(var + 8, _var); + return; } - else +#endif // __AVX__ { - sq_sum += _mm_reduce_add_ps(_sq_sum); + __m128 _var = _mm_set1_ps(sq_sum / elemcount); + _mm_storeu_ps(var, _var); + _mm_storeu_ps(var + 4, _var); + _mm_storeu_ps(var + 8, _var); + _mm_storeu_ps(var + 12, _var); + return; } - } #endif // __SSE2__ - if (elempack == 1) - { - float _mean = *mean; - for (; i < size; ++i, ++ptr) + float _var = sq_sum / elemcount; + for (int i = 0; i < 16; ++i) { - float tmp = *ptr - _mean; - sq_sum += tmp * tmp; + var[i] = _var; } - *var = sq_sum / elemcount; } } static NCNN_FORCEINLINE void fast_fmadd(float* ptr, float* a, float* b, int elempack, int elemcount, int size) { int i = 0; - #if __SSE2__ #if __AVX__ #if __AVX512F__ @@ -445,7 +530,7 @@ int LayerNorm_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) cons } else if (dims == 2) { - #pragma omp parallel for num_threads(opt.num_threads) +#pragma omp parallel for num_threads(opt.num_threads) for (int i = 0; i < h; ++i) { float* ptr = bottom_top_blob.row(i); @@ -456,7 +541,7 @@ int LayerNorm_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) cons { if (affine_size == w) { - #pragma omp parallel for num_threads(opt.num_threads) +#pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < channels; ++q) { for (int i = 0; i < h; ++i) @@ -468,7 +553,7 @@ int LayerNorm_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) cons } else // if(affine_size == w * h) { - #pragma omp parallel for num_threads(opt.num_threads) +#pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < channels; ++q) { float* ptr = bottom_top_blob.channel(q); From 72777b4fed75f558390f7abfe67dcca55b58064a Mon Sep 17 00:00:00 2001 From: LinHeLurking Date: Tue, 26 Jul 2022 06:48:26 +0000 Subject: [PATCH 18/22] Fix accidentally added corelation of fmadd/affine_fmadd and SIMD ISA --- src/layer/x86/layernorm_x86.cpp | 315 +++++++++++++++++++------------- 1 file changed, 186 insertions(+), 129 deletions(-) diff --git a/src/layer/x86/layernorm_x86.cpp b/src/layer/x86/layernorm_x86.cpp index e0f4aae68b2a..02cebbf676f8 100644 --- a/src/layer/x86/layernorm_x86.cpp +++ b/src/layer/x86/layernorm_x86.cpp @@ -272,16 +272,15 @@ static NCNN_FORCEINLINE void fast_var(float* ptr, float* var, float* mean, int e } } -static NCNN_FORCEINLINE void fast_fmadd(float* ptr, float* a, float* b, int elempack, int elemcount, int size) +static NCNN_FORCEINLINE void fast_fmadd(float* ptr, float* a, float* b, int size) { int i = 0; #if __SSE2__ #if __AVX__ #if __AVX512F__ - if (elempack == 16 || elempack == 1) { - __m512 _a = elempack == 1 ? 
_mm512_set1_ps(*a) : _mm512_loadu_ps(a); - __m512 _b = elempack == 1 ? _mm512_set1_ps(*b) : _mm512_loadu_ps(b); + __m512 _a = _mm512_loadu_ps(a); + __m512 _b = _mm512_loadu_ps(b); for (; i + 16 <= size; i += 16, ptr += 16) { __m512 _cur = _mm512_loadu_ps(ptr); @@ -290,10 +289,9 @@ static NCNN_FORCEINLINE void fast_fmadd(float* ptr, float* a, float* b, int elem } } #endif // __AVX512F__ - if (elempack == 8 || elempack == 1) { - __m256 _a = elempack == 1 ? _mm256_set1_ps(*a) : _mm256_loadu_ps(a); - __m256 _b = elempack == 1 ? _mm256_set1_ps(*b) : _mm256_loadu_ps(b); + __m256 _a = _mm256_loadu_ps(a); + __m256 _b = _mm256_loadu_ps(b); for (; i + 8 <= size; i += 8, ptr += 8) { __m256 _cur = _mm256_loadu_ps(ptr); @@ -302,10 +300,9 @@ static NCNN_FORCEINLINE void fast_fmadd(float* ptr, float* a, float* b, int elem } } #endif // __AVX__ - if (elempack == 4 || elempack == 1) { - __m128 _a = elempack == 1 ? _mm_set1_ps(*a) : _mm_loadu_ps(a); - __m128 _b = elempack == 1 ? _mm_set1_ps(*b) : _mm_loadu_ps(b); + __m128 _a = _mm_loadu_ps(a); + __m128 _b = _mm_loadu_ps(b); for (; i + 4 <= size; i += 4, ptr += 4) { __m128 _cur = _mm_loadu_ps(ptr); @@ -314,12 +311,9 @@ static NCNN_FORCEINLINE void fast_fmadd(float* ptr, float* a, float* b, int elem } } #endif // __SSE2__ - if (elempack == 1) + for (; i < size; ++i, ++ptr) { - for (; i < elemcount; ++i, ++ptr) - { - *ptr = (*ptr) * (*a) + (*b); - } + *ptr = (*ptr) * (*a) + (*b); } } @@ -332,105 +326,158 @@ LayerNorm_x86::LayerNorm_x86() #endif // __SSE2__ } -NCNN_FORCEINLINE static void fast_fmadd_fmadd(float* ptr, float* a, float* b, const float* gamma, const float* beta, int elempack, int elemcount, int size) +NCNN_FORCEINLINE static void fast_fmadd_fmadd(float* ptr, float* a, float* b, const float* gamma, const float* beta, int elempack, int size) { int i = 0; - // const float* gamma = static_cast(gamma_data); - // const float* beta = static_cast(beta_data); - #if __SSE2__ #if __AVX__ #if __AVX512F__ - if (elempack == 16) { __m512 _a = _mm512_loadu_ps(a); __m512 _b = _mm512_loadu_ps(b); - for (; i + 16 <= size; i += 16, ptr += 16, ++gamma, ++beta) + if (elempack == 16) { - __m512 _cur = _mm512_loadu_ps(ptr); - __m512 _gamma = _mm512_set1_ps(*gamma); - __m512 _beta = _mm512_set1_ps(*beta); - _cur = _mm512_fmadd_ps(_cur, _a, _b); - _cur = _mm512_fmadd_ps(_cur, _gamma, _beta); - _mm512_storeu_ps(ptr, _cur); + for (; i + 16 <= size; i += 16, ptr += 16, ++gamma, ++beta) + { + __m512 _cur = _mm512_loadu_ps(ptr); + __m512 _gamma = _mm512_set1_ps(*gamma); + __m512 _beta = _mm512_set1_ps(*beta); + _cur = _mm512_fmadd_ps(_cur, _a, _b); + _cur = _mm512_fmadd_ps(_cur, _gamma, _beta); + _mm512_storeu_ps(ptr, _cur); + } } - } - else if (elempack == 1) - { - __m512 _a = _mm512_set1_ps(*a); - __m512 _b = _mm512_set1_ps(*b); - for (; i + 16 <= elemcount; i += 16, ptr += 16, gamma += 16, beta += 16) + else if (elempack == 8) { - __m512 _cur = _mm512_loadu_ps(ptr); - __m512 _gamma = _mm512_loadu_ps(gamma); - __m512 _beta = _mm512_loadu_ps(beta); - _cur = _mm512_fmadd_ps(_cur, _a, _b); - _cur = _mm512_fmadd_ps(_cur, _gamma, _beta); - _mm512_storeu_ps(ptr, _cur); + for (; i + 16 <= size; i += 16, ptr += 16, gamma += 2, beta += 2) + { + __m512 _cur = _mm512_loadu_ps(ptr); + __m512 _gamma_0 = _mm512_set1_ps(gamma[0]); + __m512 _gamma_1 = _mm512_set1_ps(gamma[1]); + __m512 _beta_0 = _mm512_set1_ps(beta[0]); + __m512 _beta_1 = _mm512_set1_ps(beta[1]); + _gamma_0 = _mm512_mask_blend_ps(0xFF00, _gamma_0, _gamma_1); + _beta_0 = _mm512_mask_blend_ps(0xFF00, _beta_0, _beta_1); + _cur = 
_mm512_fmadd_ps(_cur, _a, _b); + _cur = _mm512_fmadd_ps(_cur, _gamma_0, _beta_0); + _mm512_storeu_ps(ptr, _cur); + } + } + else if (elempack == 4) + { + for (; i + 16 <= size; i += 16, ptr += 16, gamma += 4, beta += 4) + { + __m512 _cur = _mm512_loadu_ps(ptr); + __m512 _gamma_0 = _mm512_set1_ps(gamma[0]); + __m512 _gamma_1 = _mm512_set1_ps(gamma[1]); + __m512 _gamma_2 = _mm512_set1_ps(gamma[2]); + __m512 _gamma_3 = _mm512_set1_ps(gamma[3]); + __m512 _beta_0 = _mm512_set1_ps(beta[0]); + __m512 _beta_1 = _mm512_set1_ps(beta[1]); + __m512 _beta_2 = _mm512_set1_ps(beta[2]); + __m512 _beta_3 = _mm512_set1_ps(beta[3]); + _gamma_0 = _mm512_mask_blend_ps(0x00F0, _gamma_0, _gamma_1); + _gamma_0 = _mm512_mask_blend_ps(0x0F00, _gamma_0, _gamma_2); + _gamma_0 = _mm512_mask_blend_ps(0xF000, _gamma_0, _gamma_3); + _beta_0 = _mm512_mask_blend_ps(0x00F0, _beta_0, _beta_1); + _beta_0 = _mm512_mask_blend_ps(0x0F00, _beta_0, _beta_2); + _beta_0 = _mm512_mask_blend_ps(0xF000, _beta_0, _beta_3); + _cur = _mm512_fmadd_ps(_cur, _a, _b); + _cur = _mm512_fmadd_ps(_cur, _gamma_0, _beta_0); + _mm512_storeu_ps(ptr, _cur); + } + } + else if (elempack == 1) + { + for (; i + 16 <= size; i += 16, ptr += 16, gamma += 16, beta += 16) + { + __m512 _cur = _mm512_loadu_ps(ptr); + __m512 _gamma = _mm512_loadu_ps(gamma); + __m512 _beta = _mm512_loadu_ps(beta); + _cur = _mm512_fmadd_ps(_cur, _a, _b); + _cur = _mm512_fmadd_ps(_cur, _gamma, _beta); + _mm512_storeu_ps(ptr, _cur); + } } } #endif // __AVX512F__ - if (elempack == 8) { __m256 _a = _mm256_loadu_ps(a); __m256 _b = _mm256_loadu_ps(b); - for (; i + 8 <= size; i += 8, ptr += 8, ++gamma, ++beta) + if (elempack == 8) { - __m256 _cur = _mm256_loadu_ps(ptr); - __m256 _gamma = _mm256_set1_ps(*gamma); - __m256 _beta = _mm256_set1_ps(*beta); - _cur = _mm256_comp_fmadd_ps(_cur, _a, _b); - _cur = _mm256_comp_fmadd_ps(_cur, _gamma, _beta); - _mm256_storeu_ps(ptr, _cur); + for (; i + 8 <= size; i += 8, ptr += 8, ++gamma, ++beta) + { + __m256 _cur = _mm256_loadu_ps(ptr); + __m256 _gamma = _mm256_set1_ps(*gamma); + __m256 _beta = _mm256_set1_ps(*beta); + _cur = _mm256_comp_fmadd_ps(_cur, _a, _b); + _cur = _mm256_comp_fmadd_ps(_cur, _gamma, _beta); + _mm256_storeu_ps(ptr, _cur); + } } - } - else if (elempack == 1) - { - __m256 _a = _mm256_set1_ps(*a); - __m256 _b = _mm256_set1_ps(*b); - for (; i + 8 <= elemcount; i += 8, ptr += 8, gamma += 8, beta += 8) + else if (elempack == 4) { - __m256 _cur = _mm256_loadu_ps(ptr); - __m256 _gamma = _mm256_loadu_ps(gamma); - __m256 _beta = _mm256_loadu_ps(beta); - _cur = _mm256_comp_fmadd_ps(_cur, _a, _b); - _cur = _mm256_comp_fmadd_ps(_cur, _gamma, _beta); - _mm256_storeu_ps(ptr, _cur); + for (; i + 8 <= size; i += 8, ptr += 8, gamma += 2, beta += 2) + { + __m256 _cur = _mm256_loadu_ps(ptr); + __m256 _gamma_0 = _mm256_set1_ps(gamma[0]); + __m256 _gamma_1 = _mm256_set1_ps(gamma[1]); + __m256 _beta_0 = _mm256_set1_ps(beta[0]); + __m256 _beta_1 = _mm256_set1_ps(beta[1]); + _gamma_0 = _mm256_blend_ps(_gamma_0, _gamma_1, 0xF0); + _beta_0 = _mm256_blend_ps(_beta_0, _beta_1, 0xF0); + _cur = _mm256_comp_fmadd_ps(_cur, _a, _b); + _cur = _mm256_comp_fmadd_ps(_cur, _gamma_0, _beta_0); + _mm256_storeu_ps(ptr, _cur); + } + } + else if (elempack == 1) + { + for (; i + 8 <= size; i += 8, ptr += 8, gamma += 8, beta += 8) + { + __m256 _cur = _mm256_loadu_ps(ptr); + __m256 _gamma = _mm256_loadu_ps(gamma); + __m256 _beta = _mm256_loadu_ps(beta); + _cur = _mm256_comp_fmadd_ps(_cur, _a, _b); + _cur = _mm256_comp_fmadd_ps(_cur, _gamma, _beta); + _mm256_storeu_ps(ptr, 
_cur); + } } } #endif // __AVX__ - if (elempack == 4) { __m128 _a = _mm_loadu_ps(a); __m128 _b = _mm_loadu_ps(b); - for (; i + 4 <= size; i += 4, ptr += 4, ++gamma, ++beta) + if (elempack == 4) { - __m128 _cur = _mm_loadu_ps(ptr); - __m128 _gamma = _mm_set1_ps(*gamma); - __m128 _beta = _mm_set1_ps(*beta); - _cur = _mm_comp_fmadd_ps(_cur, _a, _b); - _cur = _mm_comp_fmadd_ps(_cur, _gamma, _beta); - _mm_storeu_ps(ptr, _cur); + for (; i + 4 <= size; i += 4, ptr += 4, ++gamma, ++beta) + { + __m128 _cur = _mm_loadu_ps(ptr); + __m128 _gamma = _mm_set1_ps(*gamma); + __m128 _beta = _mm_set1_ps(*beta); + _cur = _mm_comp_fmadd_ps(_cur, _a, _b); + _cur = _mm_comp_fmadd_ps(_cur, _gamma, _beta); + _mm_storeu_ps(ptr, _cur); + } } - } - else if (elempack == 1) - { - __m128 _a = _mm_set1_ps(*a); - __m128 _b = _mm_set1_ps(*b); - for (; i + 4 <= elemcount; i += 4, ptr += 4, gamma += 4, beta += 4) + else if (elempack == 1) { - __m128 _cur = _mm_loadu_ps(ptr); - __m128 _gamma = _mm_loadu_ps(gamma); - __m128 _beta = _mm_loadu_ps(beta); - _cur = _mm_comp_fmadd_ps(_cur, _a, _b); - _cur = _mm_comp_fmadd_ps(_cur, _gamma, _beta); - _mm_storeu_ps(ptr, _cur); + for (; i + 4 <= size; i += 4, ptr += 4, gamma += 4, beta += 4) + { + __m128 _cur = _mm_loadu_ps(ptr); + __m128 _gamma = _mm_loadu_ps(gamma); + __m128 _beta = _mm_loadu_ps(beta); + _cur = _mm_comp_fmadd_ps(_cur, _a, _b); + _cur = _mm_comp_fmadd_ps(_cur, _gamma, _beta); + _mm_storeu_ps(ptr, _cur); + } } } #endif // __SSE2__ if (elempack == 1) { - for (; i < elemcount; ++i, ++ptr, ++gamma, ++beta) + for (; i < size; ++i, ++ptr, ++gamma, ++beta) { *ptr = ((*ptr) * (*a) + (*b)) * (*gamma) + (*beta); } @@ -444,69 +491,79 @@ NCNN_FORCEINLINE static void fast_1d_layer_norm(float* ptr, int elempack, int el fast_var(ptr, var, mean, elempack, elemcount, size); float *a = var, *b = mean; + do { #if __SSE2__ #if __AVX__ #if __AVX512F__ - if (elempack == 16) - { - __m512 _a = _mm512_set1_ps(1.0f); - __m512 _eps = _mm512_set1_ps(eps); - __m512 _b = _mm512_setzero_ps(); - __m512 _var = _mm512_loadu_ps(var); - _var = _mm512_add_ps(_var, _eps); - __m512 _sqrt_var = _mm512_sqrt_ps(_var); - _a = _mm512_div_ps(_a, _sqrt_var); - __m512 _mean = _mm512_loadu_ps(mean); - _b = _mm512_fnmadd_ps(_mean, _a, _b); + { + __m512 _a = _mm512_set1_ps(1.0f); + __m512 _eps = _mm512_set1_ps(eps); + __m512 _b = _mm512_setzero_ps(); + __m512 _var = _mm512_loadu_ps(var); + _var = _mm512_add_ps(_var, _eps); + __m512 _sqrt_var = _mm512_sqrt_ps(_var); + _a = _mm512_div_ps(_a, _sqrt_var); + __m512 _mean = _mm512_loadu_ps(mean); + _b = _mm512_fnmadd_ps(_mean, _a, _b); - _mm512_storeu_ps(a, _a); - _mm512_storeu_ps(b, _b); - } + _mm512_storeu_ps(a, _a); + _mm512_storeu_ps(b, _b); + break; + } #endif // __AVX512F__ - if (elempack == 8) - { - __m256 _a = _mm256_set1_ps(1.0f); - __m256 _eps = _mm256_set1_ps(eps); - __m256 _b = _mm256_setzero_ps(); - __m256 _var = _mm256_loadu_ps(var); - _var = _mm256_add_ps(_var, _eps); - __m256 _sqrt_var = _mm256_sqrt_ps(_var); - _a = _mm256_div_ps(_a, _sqrt_var); - __m256 _mean = _mm256_loadu_ps(mean); - _b = _mm256_comp_fnmadd_ps(_mean, _a, _b); - _mm256_storeu_ps(a, _a); - _mm256_storeu_ps(b, _b); - } + { + __m256 _a = _mm256_set1_ps(1.0f); + __m256 _eps = _mm256_set1_ps(eps); + __m256 _b = _mm256_setzero_ps(); + __m256 _var = _mm256_loadu_ps(var); + _var = _mm256_add_ps(_var, _eps); + __m256 _sqrt_var = _mm256_sqrt_ps(_var); + _a = _mm256_div_ps(_a, _sqrt_var); + __m256 _mean = _mm256_loadu_ps(mean); + _b = _mm256_comp_fnmadd_ps(_mean, _a, _b); + _mm256_storeu_ps(a, _a); + 
_mm256_storeu_ps(a + 8, _a); + _mm256_storeu_ps(b, _b); + _mm256_storeu_ps(b + 8, _b); + break; + } #endif // __AVX__ - if (elempack == 4) - { - __m128 _a = _mm_set1_ps(1.0f); - __m128 _eps = _mm_set1_ps(eps); - __m128 _b = _mm_setzero_ps(); - __m128 _var = _mm_loadu_ps(var); - _var = _mm_add_ps(_var, _eps); - __m128 _sqrt_var = _mm_sqrt_ps(_var); - _a = _mm_div_ps(_a, _sqrt_var); - __m128 _mean = _mm_loadu_ps(mean); - _b = _mm_comp_fnmadd_ps(_mean, _a, _b); + { + __m128 _a = _mm_set1_ps(1.0f); + __m128 _eps = _mm_set1_ps(eps); + __m128 _b = _mm_setzero_ps(); + __m128 _var = _mm_loadu_ps(var); + _var = _mm_add_ps(_var, _eps); + __m128 _sqrt_var = _mm_sqrt_ps(_var); + _a = _mm_div_ps(_a, _sqrt_var); + __m128 _mean = _mm_loadu_ps(mean); + _b = _mm_comp_fnmadd_ps(_mean, _a, _b); - _mm_storeu_ps(a, _a); - _mm_storeu_ps(b, _b); - } + _mm_storeu_ps(a, _a); + _mm_storeu_ps(a + 4, _a); + _mm_storeu_ps(a + 8, _a); + _mm_storeu_ps(a + 12, _a); + _mm_storeu_ps(b, _b); + _mm_storeu_ps(b + 4, _b); + _mm_storeu_ps(b + 8, _b); + _mm_storeu_ps(b + 12, _b); + break; + } #endif // __SSE2__ - if (elempack == 1) - { - *a = static_cast(1.0f / sqrt(*var + eps)); - *b = -*mean * (*a); - } + for (int i = 0; i < 16; ++i) + { + a[i] = static_cast(1.0f / sqrt(var[i] + eps)); + b[i] = -mean[i] * (a[i]); + } + } while (0); if (affine) { - fast_fmadd_fmadd(ptr, a, b, gamma, beta, elempack, elemcount, size); + fast_fmadd_fmadd(ptr, a, b, gamma, beta, elempack, size); } else { - fast_fmadd(ptr, a, b, elempack, elemcount, size); + fast_fmadd(ptr, a, b, size); } } From b20d29848d286dc3b2e7fe06655534f57edf52e9 Mon Sep 17 00:00:00 2001 From: LinHeLurking Date: Tue, 26 Jul 2022 06:49:26 +0000 Subject: [PATCH 19/22] Fix a wrong test param --- tests/test_layernorm.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_layernorm.cpp b/tests/test_layernorm.cpp index a6010e54e85d..fefb37c8a4c9 100644 --- a/tests/test_layernorm.cpp +++ b/tests/test_layernorm.cpp @@ -70,7 +70,7 @@ static int test_layernorm_1() || test_layernorm(RandomMat(5, 6, 12), 30, 0.02f, 1) || test_layernorm(RandomMat(4, 7, 16), 28, 0.02f, 1) || test_layernorm(RandomMat(6, 7, 24), 42, 0.001f, 1) - || test_layernorm(RandomMat(5, 8, 24), 40, 0.001f, 1); + || test_layernorm(RandomMat(5, 8, 32), 40, 0.001f, 1); } static int test_layernorm_2() From 4fddf9e69b81db5495b465b1fce6aeeb20a9920a Mon Sep 17 00:00:00 2001 From: LinHeLurking Date: Tue, 26 Jul 2022 07:01:14 +0000 Subject: [PATCH 20/22] Fix runtime dispatch --- src/layer/x86/layernorm_x86.cpp | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/src/layer/x86/layernorm_x86.cpp b/src/layer/x86/layernorm_x86.cpp index 02cebbf676f8..9b9190dc4c5d 100644 --- a/src/layer/x86/layernorm_x86.cpp +++ b/src/layer/x86/layernorm_x86.cpp @@ -56,12 +56,14 @@ static NCNN_FORCEINLINE void fast_mean(float* ptr, float* mean, int elempack, in { sum += *ptr; } - +#if __AVX512F__ if (elempack == 16) { __m512 _mean = _mm512_div_ps(_sum_512, _mm512_set1_ps((float)elemcount)); _mm512_storeu_ps(mean, _mean); } +#endif // __AVX512F__ +#if __AVX__ if (elempack == 8) { #if __AVX512F__ @@ -77,6 +79,8 @@ static NCNN_FORCEINLINE void fast_mean(float* ptr, float* mean, int elempack, in // duplicate until len is 16 _mm256_storeu_ps(mean + 8, _mean); } +#endif // __AVX__ +#if __SSE2__ if (elempack == 4) { #if __AVX__ @@ -102,6 +106,7 @@ static NCNN_FORCEINLINE void fast_mean(float* ptr, float* mean, int elempack, in _mm_storeu_ps(mean + 8, _mean); _mm_storeu_ps(mean + 12, _mean); } 
+#endif // __SSE2__ if (elempack == 1) { #if __SSE2__ @@ -186,11 +191,14 @@ static NCNN_FORCEINLINE void fast_var(float* ptr, float* var, float* mean, int e sq_sum += tmp * tmp; } +#if __AVX512F__ if (elempack == 16) { __m512 _var = _mm512_div_ps(_sq_sum_512, _mm512_set1_ps((float)elemcount)); _mm512_storeu_ps(var, _var); } +#endif // __AVX512F__ +#if __AVX__ if (elempack == 8) { #if __AVX512F__ @@ -205,6 +213,8 @@ static NCNN_FORCEINLINE void fast_var(float* ptr, float* var, float* mean, int e _mm256_storeu_ps(var, _var); _mm256_storeu_ps(var + 8, _var); } +#endif // __AVX__ +#if __SSE2__ if (elempack == 4) { #if __AVX__ @@ -229,6 +239,7 @@ static NCNN_FORCEINLINE void fast_var(float* ptr, float* var, float* mean, int e _mm_storeu_ps(var + 8, _var); _mm_storeu_ps(var + 12, _var); } +#endif // __SSE2__ if (elempack == 1) { #if __SSE2__ @@ -587,7 +598,7 @@ int LayerNorm_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) cons } else if (dims == 2) { -#pragma omp parallel for num_threads(opt.num_threads) + #pragma omp parallel for num_threads(opt.num_threads) for (int i = 0; i < h; ++i) { float* ptr = bottom_top_blob.row(i); @@ -598,7 +609,7 @@ int LayerNorm_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) cons { if (affine_size == w) { -#pragma omp parallel for num_threads(opt.num_threads) + #pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < channels; ++q) { for (int i = 0; i < h; ++i) @@ -610,7 +621,7 @@ int LayerNorm_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) cons } else // if(affine_size == w * h) { -#pragma omp parallel for num_threads(opt.num_threads) + #pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < channels; ++q) { float* ptr = bottom_top_blob.channel(q); From 2555b3e8bf965cf51731e33b599a2211ef541c8f Mon Sep 17 00:00:00 2001 From: LinHeLurking Date: Tue, 26 Jul 2022 07:02:49 +0000 Subject: [PATCH 21/22] apply code-format changes --- src/layer/x86/layernorm_x86.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/layer/x86/layernorm_x86.cpp b/src/layer/x86/layernorm_x86.cpp index 9b9190dc4c5d..9ea2e2556337 100644 --- a/src/layer/x86/layernorm_x86.cpp +++ b/src/layer/x86/layernorm_x86.cpp @@ -502,7 +502,8 @@ NCNN_FORCEINLINE static void fast_1d_layer_norm(float* ptr, int elempack, int el fast_var(ptr, var, mean, elempack, elemcount, size); float *a = var, *b = mean; - do { + do + { #if __SSE2__ #if __AVX__ #if __AVX512F__ From 1b118f783a08636516f57bcd990410a86e325265 Mon Sep 17 00:00:00 2001 From: nihuini Date: Fri, 29 Jul 2022 15:26:08 +0800 Subject: [PATCH 22/22] no store duplicates --- src/layer/x86/layernorm_x86.cpp | 561 +++++++++++++++----------------- src/layer/x86/layernorm_x86.h | 2 +- 2 files changed, 260 insertions(+), 303 deletions(-) diff --git a/src/layer/x86/layernorm_x86.cpp b/src/layer/x86/layernorm_x86.cpp index 9ea2e2556337..3f6a66a5ec03 100644 --- a/src/layer/x86/layernorm_x86.cpp +++ b/src/layer/x86/layernorm_x86.cpp @@ -24,10 +24,18 @@ #endif // __AVX__ #endif // __SSE2__ +namespace ncnn { + +LayerNorm_x86::LayerNorm_x86() +{ +#if __SSE2__ + support_packing = true; +#endif // __SSE2__ +} + static NCNN_FORCEINLINE void fast_mean(float* ptr, float* mean, int elempack, int elemcount, int size) { int i = 0; - float sum = 0.0f; #if __SSE2__ #if __AVX__ #if __AVX512F__ @@ -52,10 +60,14 @@ static NCNN_FORCEINLINE void fast_mean(float* ptr, float* mean, int elempack, in _sum_128 = _mm_add_ps(_sum_128, _cur); } #endif // __SSE2__ + float sum = 0.0f; for (; i < 
size; ++i, ++ptr) { sum += *ptr; } + +#if __SSE2__ +#if __AVX__ #if __AVX512F__ if (elempack == 16) { @@ -63,7 +75,7 @@ static NCNN_FORCEINLINE void fast_mean(float* ptr, float* mean, int elempack, in _mm512_storeu_ps(mean, _mean); } #endif // __AVX512F__ -#if __AVX__ + if (elempack == 8) { #if __AVX512F__ @@ -76,11 +88,9 @@ static NCNN_FORCEINLINE void fast_mean(float* ptr, float* mean, int elempack, in #endif // __AVX512F__ __m256 _mean = _mm256_div_ps(_sum_256, _mm256_set1_ps((float)elemcount)); _mm256_storeu_ps(mean, _mean); - // duplicate until len is 16 - _mm256_storeu_ps(mean + 8, _mean); } #endif // __AVX__ -#if __SSE2__ + if (elempack == 4) { #if __AVX__ @@ -101,12 +111,9 @@ static NCNN_FORCEINLINE void fast_mean(float* ptr, float* mean, int elempack, in #endif // __AVX__ __m128 _mean = _mm_div_ps(_sum_128, _mm_set1_ps((float)elemcount)); _mm_storeu_ps(mean, _mean); - // duplicate until len is 16 - _mm_storeu_ps(mean + 4, _mean); - _mm_storeu_ps(mean + 8, _mean); - _mm_storeu_ps(mean + 12, _mean); } #endif // __SSE2__ + if (elempack == 1) { #if __SSE2__ @@ -118,46 +125,27 @@ static NCNN_FORCEINLINE void fast_mean(float* ptr, float* mean, int elempack, in #endif // __AVX__ sum += _mm_reduce_add_ps(_sum_128); #endif // __SSE2__ + mean[0] = sum / elemcount; + } +} + +static NCNN_FORCEINLINE void fast_var(float* ptr, float* var, const float* mean, int elempack, int elemcount, int size) +{ + const float _mean = mean[0]; #if __SSE2__ + __m128 _mean_128 = (elempack == 4) ? _mm_loadu_ps(mean) : _mm_set1_ps(_mean); #if __AVX__ + __m256 _mean_256 = (elempack == 8) ? _mm256_loadu_ps(mean) : _mm256_insertf128_ps(_mm256_castps128_ps256(_mean_128), _mean_128, 1); #if __AVX512F__ - { - _mm512_storeu_ps(mean, _mm512_set1_ps(sum / elemcount)); - return; - } + __m512 _mean_512 = (elempack == 16) ? 
_mm512_loadu_ps(mean) : _mm512_insertf32x8(_mm512_castps256_ps512(_mean_256), _mean_256, 1); #endif // __AVX512F__ - { - __m256 _mean = _mm256_set1_ps(sum / elemcount); - _mm256_storeu_ps(mean, _mean); - _mm256_storeu_ps(mean + 8, _mean); - return; - } #endif // __AVX__ - { - __m128 _mean = _mm_set1_ps(sum / elemcount); - _mm_storeu_ps(mean, _mean); - _mm_storeu_ps(mean + 4, _mean); - _mm_storeu_ps(mean + 8, _mean); - _mm_storeu_ps(mean + 12, _mean); - return; - } #endif // __SSE2__ - float _mean = sum / elemcount; - for (int i = 0; i < 16; ++i) - { - mean[i] = _mean; - } - } -} -static NCNN_FORCEINLINE void fast_var(float* ptr, float* var, float* mean, int elempack, int elemcount, int size) -{ int i = 0; - float sq_sum = 0.0f; #if __SSE2__ #if __AVX__ #if __AVX512F__ - __m512 _mean_512 = _mm512_loadu_ps(mean); __m512 _sq_sum_512 = _mm512_setzero_ps(); for (; i + 16 <= size; i += 16, ptr += 16) { @@ -166,7 +154,6 @@ static NCNN_FORCEINLINE void fast_var(float* ptr, float* var, float* mean, int e _sq_sum_512 = _mm512_fmadd_ps(_cur, _cur, _sq_sum_512); } #endif // __AVX512F__ - __m256 _mean_256 = _mm256_loadu_ps(mean); __m256 _sq_sum_256 = _mm256_setzero_ps(); for (; i + 8 <= size; i += 8, ptr += 8) { @@ -175,7 +162,6 @@ static NCNN_FORCEINLINE void fast_var(float* ptr, float* var, float* mean, int e _sq_sum_256 = _mm256_comp_fmadd_ps(_cur, _cur, _sq_sum_256); } #endif // __AVX__ - __m128 _mean_128 = _mm_loadu_ps(mean); __m128 _sq_sum_128 = _mm_setzero_ps(); for (; i + 4 <= size; i += 4, ptr += 4) { @@ -184,13 +170,15 @@ static NCNN_FORCEINLINE void fast_var(float* ptr, float* var, float* mean, int e _sq_sum_128 = _mm_comp_fmadd_ps(_cur, _cur, _sq_sum_128); } #endif // __SSE2__ - float _mean = *mean; + float sq_sum = 0.0f; for (; i < size; ++i, ++ptr) { float tmp = *ptr - _mean; sq_sum += tmp * tmp; } +#if __SSE2__ +#if __AVX__ #if __AVX512F__ if (elempack == 16) { @@ -198,7 +186,7 @@ static NCNN_FORCEINLINE void fast_var(float* ptr, float* var, float* mean, int e _mm512_storeu_ps(var, _var); } #endif // __AVX512F__ -#if __AVX__ + if (elempack == 8) { #if __AVX512F__ @@ -211,10 +199,9 @@ static NCNN_FORCEINLINE void fast_var(float* ptr, float* var, float* mean, int e #endif // __AVX512F__ __m256 _var = _mm256_div_ps(_sq_sum_256, _mm256_set1_ps((float)elemcount)); _mm256_storeu_ps(var, _var); - _mm256_storeu_ps(var + 8, _var); } #endif // __AVX__ -#if __SSE2__ + if (elempack == 4) { #if __AVX__ @@ -235,11 +222,9 @@ static NCNN_FORCEINLINE void fast_var(float* ptr, float* var, float* mean, int e #endif // __AVX__ __m128 _var = _mm_div_ps(_sq_sum_128, _mm_set1_ps((float)elemcount)); _mm_storeu_ps(var, _var); - _mm_storeu_ps(var + 4, _var); - _mm_storeu_ps(var + 8, _var); - _mm_storeu_ps(var + 12, _var); } #endif // __SSE2__ + if (elempack == 1) { #if __SSE2__ @@ -251,323 +236,293 @@ static NCNN_FORCEINLINE void fast_var(float* ptr, float* var, float* mean, int e #endif // __AVX__ sq_sum += _mm_reduce_add_ps(_sq_sum_128); #endif // __SSE2__ + var[0] = sq_sum / elemcount; + } +} + +static NCNN_FORCEINLINE void fast_fmadd(float* ptr, const float* a, const float* b, int elempack, int size) +{ + const float _a = a[0]; + const float _b = b[0]; #if __SSE2__ + __m128 _a_128 = (elempack == 4) ? _mm_loadu_ps(a) : _mm_set1_ps(_a); + __m128 _b_128 = (elempack == 4) ? _mm_loadu_ps(b) : _mm_set1_ps(_b); #if __AVX__ + __m256 _a_256 = (elempack == 8) ? _mm256_loadu_ps(a) : _mm256_insertf128_ps(_mm256_castps128_ps256(_a_128), _a_128, 1); + __m256 _b_256 = (elempack == 8) ? 
_mm256_loadu_ps(b) : _mm256_insertf128_ps(_mm256_castps128_ps256(_b_128), _b_128, 1); #if __AVX512F__ - { - _mm512_storeu_ps(var, _mm512_set1_ps(sq_sum / elemcount)); - return; - } + __m512 _a_512 = (elempack == 16) ? _mm512_loadu_ps(a) : _mm512_insertf32x8(_mm512_castps256_ps512(_a_256), _a_256, 1); + __m512 _b_512 = (elempack == 16) ? _mm512_loadu_ps(b) : _mm512_insertf32x8(_mm512_castps256_ps512(_b_256), _b_256, 1); #endif // __AVX512F__ - { - __m256 _var = _mm256_set1_ps(sq_sum / elemcount); - _mm256_storeu_ps(var, _var); - _mm256_storeu_ps(var + 8, _var); - return; - } #endif // __AVX__ - { - __m128 _var = _mm_set1_ps(sq_sum / elemcount); - _mm_storeu_ps(var, _var); - _mm_storeu_ps(var + 4, _var); - _mm_storeu_ps(var + 8, _var); - _mm_storeu_ps(var + 12, _var); - return; - } #endif // __SSE2__ - float _var = sq_sum / elemcount; - for (int i = 0; i < 16; ++i) - { - var[i] = _var; - } - } -} -static NCNN_FORCEINLINE void fast_fmadd(float* ptr, float* a, float* b, int size) -{ int i = 0; #if __SSE2__ #if __AVX__ #if __AVX512F__ + for (; i + 16 <= size; i += 16, ptr += 16) { - __m512 _a = _mm512_loadu_ps(a); - __m512 _b = _mm512_loadu_ps(b); - for (; i + 16 <= size; i += 16, ptr += 16) - { - __m512 _cur = _mm512_loadu_ps(ptr); - _cur = _mm512_fmadd_ps(_cur, _a, _b); - _mm512_storeu_ps(ptr, _cur); - } + __m512 _cur = _mm512_loadu_ps(ptr); + _cur = _mm512_fmadd_ps(_cur, _a_512, _b_512); + _mm512_storeu_ps(ptr, _cur); } #endif // __AVX512F__ + for (; i + 8 <= size; i += 8, ptr += 8) { - __m256 _a = _mm256_loadu_ps(a); - __m256 _b = _mm256_loadu_ps(b); - for (; i + 8 <= size; i += 8, ptr += 8) - { - __m256 _cur = _mm256_loadu_ps(ptr); - _cur = _mm256_comp_fmadd_ps(_cur, _a, _b); - _mm256_storeu_ps(ptr, _cur); - } + __m256 _cur = _mm256_loadu_ps(ptr); + _cur = _mm256_comp_fmadd_ps(_cur, _a_256, _b_256); + _mm256_storeu_ps(ptr, _cur); } #endif // __AVX__ + for (; i + 4 <= size; i += 4, ptr += 4) { - __m128 _a = _mm_loadu_ps(a); - __m128 _b = _mm_loadu_ps(b); - for (; i + 4 <= size; i += 4, ptr += 4) - { - __m128 _cur = _mm_loadu_ps(ptr); - _cur = _mm_comp_fmadd_ps(_cur, _a, _b); - _mm_storeu_ps(ptr, _cur); - } + __m128 _cur = _mm_loadu_ps(ptr); + _cur = _mm_comp_fmadd_ps(_cur, _a_128, _b_128); + _mm_storeu_ps(ptr, _cur); } #endif // __SSE2__ for (; i < size; ++i, ++ptr) { - *ptr = (*ptr) * (*a) + (*b); + *ptr = (*ptr) * _a + _b; } } -namespace ncnn { - -LayerNorm_x86::LayerNorm_x86() +static NCNN_FORCEINLINE void fast_fmadd_fmadd(float* ptr, const float* a, const float* b, const float* gamma, const float* beta, int elempack, int size) { -#if __SSE2__ - support_packing = true; -#endif // __SSE2__ -} - -NCNN_FORCEINLINE static void fast_fmadd_fmadd(float* ptr, float* a, float* b, const float* gamma, const float* beta, int elempack, int size) -{ - int i = 0; #if __SSE2__ #if __AVX__ #if __AVX512F__ + if (elempack == 16) { - __m512 _a = _mm512_loadu_ps(a); - __m512 _b = _mm512_loadu_ps(b); - if (elempack == 16) + int i = 0; + __m512 _a_512 = _mm512_loadu_ps(a); + __m512 _b_512 = _mm512_loadu_ps(b); + for (; i + 16 <= size; i += 16, ptr += 16, ++gamma, ++beta) { - for (; i + 16 <= size; i += 16, ptr += 16, ++gamma, ++beta) - { - __m512 _cur = _mm512_loadu_ps(ptr); - __m512 _gamma = _mm512_set1_ps(*gamma); - __m512 _beta = _mm512_set1_ps(*beta); - _cur = _mm512_fmadd_ps(_cur, _a, _b); - _cur = _mm512_fmadd_ps(_cur, _gamma, _beta); - _mm512_storeu_ps(ptr, _cur); - } - } - else if (elempack == 8) - { - for (; i + 16 <= size; i += 16, ptr += 16, gamma += 2, beta += 2) - { - __m512 _cur = 
_mm512_loadu_ps(ptr); - __m512 _gamma_0 = _mm512_set1_ps(gamma[0]); - __m512 _gamma_1 = _mm512_set1_ps(gamma[1]); - __m512 _beta_0 = _mm512_set1_ps(beta[0]); - __m512 _beta_1 = _mm512_set1_ps(beta[1]); - _gamma_0 = _mm512_mask_blend_ps(0xFF00, _gamma_0, _gamma_1); - _beta_0 = _mm512_mask_blend_ps(0xFF00, _beta_0, _beta_1); - _cur = _mm512_fmadd_ps(_cur, _a, _b); - _cur = _mm512_fmadd_ps(_cur, _gamma_0, _beta_0); - _mm512_storeu_ps(ptr, _cur); - } + __m512 _cur = _mm512_loadu_ps(ptr); + __m512 _gamma = _mm512_set1_ps(*gamma); + __m512 _beta = _mm512_set1_ps(*beta); + _cur = _mm512_fmadd_ps(_cur, _a_512, _b_512); + _cur = _mm512_fmadd_ps(_cur, _gamma, _beta); + _mm512_storeu_ps(ptr, _cur); } - else if (elempack == 4) + } +#endif // __AVX512F__ + + if (elempack == 8) + { + int i = 0; + __m256 _a_256 = _mm256_loadu_ps(a); + __m256 _b_256 = _mm256_loadu_ps(b); +#if __AVX512F__ + __m512 _a_512 = _mm512_insertf32x8(_mm512_castps256_ps512(_a_256), _a_256, 1); + __m512 _b_512 = _mm512_insertf32x8(_mm512_castps256_ps512(_b_256), _b_256, 1); + for (; i + 16 <= size; i += 16, ptr += 16, gamma += 2, beta += 2) { - for (; i + 16 <= size; i += 16, ptr += 16, gamma += 4, beta += 4) - { - __m512 _cur = _mm512_loadu_ps(ptr); - __m512 _gamma_0 = _mm512_set1_ps(gamma[0]); - __m512 _gamma_1 = _mm512_set1_ps(gamma[1]); - __m512 _gamma_2 = _mm512_set1_ps(gamma[2]); - __m512 _gamma_3 = _mm512_set1_ps(gamma[3]); - __m512 _beta_0 = _mm512_set1_ps(beta[0]); - __m512 _beta_1 = _mm512_set1_ps(beta[1]); - __m512 _beta_2 = _mm512_set1_ps(beta[2]); - __m512 _beta_3 = _mm512_set1_ps(beta[3]); - _gamma_0 = _mm512_mask_blend_ps(0x00F0, _gamma_0, _gamma_1); - _gamma_0 = _mm512_mask_blend_ps(0x0F00, _gamma_0, _gamma_2); - _gamma_0 = _mm512_mask_blend_ps(0xF000, _gamma_0, _gamma_3); - _beta_0 = _mm512_mask_blend_ps(0x00F0, _beta_0, _beta_1); - _beta_0 = _mm512_mask_blend_ps(0x0F00, _beta_0, _beta_2); - _beta_0 = _mm512_mask_blend_ps(0xF000, _beta_0, _beta_3); - _cur = _mm512_fmadd_ps(_cur, _a, _b); - _cur = _mm512_fmadd_ps(_cur, _gamma_0, _beta_0); - _mm512_storeu_ps(ptr, _cur); - } + __m512 _cur = _mm512_loadu_ps(ptr); + __m512 _gamma_0 = _mm512_set1_ps(gamma[0]); + __m512 _gamma_1 = _mm512_set1_ps(gamma[1]); + __m512 _beta_0 = _mm512_set1_ps(beta[0]); + __m512 _beta_1 = _mm512_set1_ps(beta[1]); + _gamma_0 = _mm512_mask_blend_ps(0xFF00, _gamma_0, _gamma_1); + _beta_0 = _mm512_mask_blend_ps(0xFF00, _beta_0, _beta_1); + _cur = _mm512_fmadd_ps(_cur, _a_512, _b_512); + _cur = _mm512_fmadd_ps(_cur, _gamma_0, _beta_0); + _mm512_storeu_ps(ptr, _cur); } - else if (elempack == 1) +#endif // __AVX512F__ + + for (; i + 8 <= size; i += 8, ptr += 8, ++gamma, ++beta) { - for (; i + 16 <= size; i += 16, ptr += 16, gamma += 16, beta += 16) - { - __m512 _cur = _mm512_loadu_ps(ptr); - __m512 _gamma = _mm512_loadu_ps(gamma); - __m512 _beta = _mm512_loadu_ps(beta); - _cur = _mm512_fmadd_ps(_cur, _a, _b); - _cur = _mm512_fmadd_ps(_cur, _gamma, _beta); - _mm512_storeu_ps(ptr, _cur); - } + __m256 _cur = _mm256_loadu_ps(ptr); + __m256 _gamma = _mm256_set1_ps(*gamma); + __m256 _beta = _mm256_set1_ps(*beta); + _cur = _mm256_comp_fmadd_ps(_cur, _a_256, _b_256); + _cur = _mm256_comp_fmadd_ps(_cur, _gamma, _beta); + _mm256_storeu_ps(ptr, _cur); } } -#endif // __AVX512F__ +#endif // __AVX__ + + if (elempack == 4) { - __m256 _a = _mm256_loadu_ps(a); - __m256 _b = _mm256_loadu_ps(b); - if (elempack == 8) + int i = 0; + __m128 _a_128 = _mm_loadu_ps(a); + __m128 _b_128 = _mm_loadu_ps(b); +#if __AVX__ + __m256 _a_256 = 
_mm256_insertf128_ps(_mm256_castps128_ps256(_a_128), _a_128, 1); + __m256 _b_256 = _mm256_insertf128_ps(_mm256_castps128_ps256(_b_128), _b_128, 1); +#if __AVX512F__ + __m512 _a_512 = _mm512_insertf32x8(_mm512_castps256_ps512(_a_256), _a_256, 1); + __m512 _b_512 = _mm512_insertf32x8(_mm512_castps256_ps512(_b_256), _b_256, 1); + for (; i + 16 <= size; i += 16, ptr += 16, gamma += 4, beta += 4) { - for (; i + 8 <= size; i += 8, ptr += 8, ++gamma, ++beta) - { - __m256 _cur = _mm256_loadu_ps(ptr); - __m256 _gamma = _mm256_set1_ps(*gamma); - __m256 _beta = _mm256_set1_ps(*beta); - _cur = _mm256_comp_fmadd_ps(_cur, _a, _b); - _cur = _mm256_comp_fmadd_ps(_cur, _gamma, _beta); - _mm256_storeu_ps(ptr, _cur); - } + __m512 _cur = _mm512_loadu_ps(ptr); + __m512 _gamma_0 = _mm512_set1_ps(gamma[0]); + __m512 _gamma_1 = _mm512_set1_ps(gamma[1]); + __m512 _gamma_2 = _mm512_set1_ps(gamma[2]); + __m512 _gamma_3 = _mm512_set1_ps(gamma[3]); + __m512 _beta_0 = _mm512_set1_ps(beta[0]); + __m512 _beta_1 = _mm512_set1_ps(beta[1]); + __m512 _beta_2 = _mm512_set1_ps(beta[2]); + __m512 _beta_3 = _mm512_set1_ps(beta[3]); + _gamma_0 = _mm512_mask_blend_ps(0x00F0, _gamma_0, _gamma_1); + _gamma_0 = _mm512_mask_blend_ps(0x0F00, _gamma_0, _gamma_2); + _gamma_0 = _mm512_mask_blend_ps(0xF000, _gamma_0, _gamma_3); + _beta_0 = _mm512_mask_blend_ps(0x00F0, _beta_0, _beta_1); + _beta_0 = _mm512_mask_blend_ps(0x0F00, _beta_0, _beta_2); + _beta_0 = _mm512_mask_blend_ps(0xF000, _beta_0, _beta_3); + _cur = _mm512_fmadd_ps(_cur, _a_512, _b_512); + _cur = _mm512_fmadd_ps(_cur, _gamma_0, _beta_0); + _mm512_storeu_ps(ptr, _cur); } - else if (elempack == 4) +#endif // __AVX512F__ + + for (; i + 8 <= size; i += 8, ptr += 8, gamma += 2, beta += 2) { - for (; i + 8 <= size; i += 8, ptr += 8, gamma += 2, beta += 2) - { - __m256 _cur = _mm256_loadu_ps(ptr); - __m256 _gamma_0 = _mm256_set1_ps(gamma[0]); - __m256 _gamma_1 = _mm256_set1_ps(gamma[1]); - __m256 _beta_0 = _mm256_set1_ps(beta[0]); - __m256 _beta_1 = _mm256_set1_ps(beta[1]); - _gamma_0 = _mm256_blend_ps(_gamma_0, _gamma_1, 0xF0); - _beta_0 = _mm256_blend_ps(_beta_0, _beta_1, 0xF0); - _cur = _mm256_comp_fmadd_ps(_cur, _a, _b); - _cur = _mm256_comp_fmadd_ps(_cur, _gamma_0, _beta_0); - _mm256_storeu_ps(ptr, _cur); - } + __m256 _cur = _mm256_loadu_ps(ptr); + __m256 _gamma_0 = _mm256_set1_ps(gamma[0]); + __m256 _gamma_1 = _mm256_set1_ps(gamma[1]); + __m256 _beta_0 = _mm256_set1_ps(beta[0]); + __m256 _beta_1 = _mm256_set1_ps(beta[1]); + _gamma_0 = _mm256_blend_ps(_gamma_0, _gamma_1, 0xF0); + _beta_0 = _mm256_blend_ps(_beta_0, _beta_1, 0xF0); + _cur = _mm256_comp_fmadd_ps(_cur, _a_256, _b_256); + _cur = _mm256_comp_fmadd_ps(_cur, _gamma_0, _beta_0); + _mm256_storeu_ps(ptr, _cur); } - else if (elempack == 1) +#endif // __AVX__ + + for (; i + 4 <= size; i += 4, ptr += 4, ++gamma, ++beta) { - for (; i + 8 <= size; i += 8, ptr += 8, gamma += 8, beta += 8) - { - __m256 _cur = _mm256_loadu_ps(ptr); - __m256 _gamma = _mm256_loadu_ps(gamma); - __m256 _beta = _mm256_loadu_ps(beta); - _cur = _mm256_comp_fmadd_ps(_cur, _a, _b); - _cur = _mm256_comp_fmadd_ps(_cur, _gamma, _beta); - _mm256_storeu_ps(ptr, _cur); - } + __m128 _cur = _mm_loadu_ps(ptr); + __m128 _gamma = _mm_set1_ps(*gamma); + __m128 _beta = _mm_set1_ps(*beta); + _cur = _mm_comp_fmadd_ps(_cur, _a_128, _b_128); + _cur = _mm_comp_fmadd_ps(_cur, _gamma, _beta); + _mm_storeu_ps(ptr, _cur); } } -#endif // __AVX__ +#endif // __SSE2__ + + if (elempack == 1) { - __m128 _a = _mm_loadu_ps(a); - __m128 _b = _mm_loadu_ps(b); - if (elempack == 4) + int 
i = 0; + const float _a = a[0]; + const float _b = b[0]; +#if __SSE2__ + __m128 _a_128 = _mm_set1_ps(_a); + __m128 _b_128 = _mm_set1_ps(_b); +#if __AVX__ + __m256 _a_256 = _mm256_insertf128_ps(_mm256_castps128_ps256(_a_128), _a_128, 1); + __m256 _b_256 = _mm256_insertf128_ps(_mm256_castps128_ps256(_b_128), _b_128, 1); +#if __AVX512F__ + __m512 _a_512 = _mm512_insertf32x8(_mm512_castps256_ps512(_a_256), _a_256, 1); + __m512 _b_512 = _mm512_insertf32x8(_mm512_castps256_ps512(_b_256), _b_256, 1); + for (; i + 16 <= size; i += 16, ptr += 16, gamma += 16, beta += 16) { - for (; i + 4 <= size; i += 4, ptr += 4, ++gamma, ++beta) - { - __m128 _cur = _mm_loadu_ps(ptr); - __m128 _gamma = _mm_set1_ps(*gamma); - __m128 _beta = _mm_set1_ps(*beta); - _cur = _mm_comp_fmadd_ps(_cur, _a, _b); - _cur = _mm_comp_fmadd_ps(_cur, _gamma, _beta); - _mm_storeu_ps(ptr, _cur); - } + __m512 _cur = _mm512_loadu_ps(ptr); + __m512 _gamma = _mm512_loadu_ps(gamma); + __m512 _beta = _mm512_loadu_ps(beta); + _cur = _mm512_fmadd_ps(_cur, _a_512, _b_512); + _cur = _mm512_fmadd_ps(_cur, _gamma, _beta); + _mm512_storeu_ps(ptr, _cur); } - else if (elempack == 1) +#endif // __AVX512F__ + + for (; i + 8 <= size; i += 8, ptr += 8, gamma += 8, beta += 8) { - for (; i + 4 <= size; i += 4, ptr += 4, gamma += 4, beta += 4) - { - __m128 _cur = _mm_loadu_ps(ptr); - __m128 _gamma = _mm_loadu_ps(gamma); - __m128 _beta = _mm_loadu_ps(beta); - _cur = _mm_comp_fmadd_ps(_cur, _a, _b); - _cur = _mm_comp_fmadd_ps(_cur, _gamma, _beta); - _mm_storeu_ps(ptr, _cur); - } + __m256 _cur = _mm256_loadu_ps(ptr); + __m256 _gamma = _mm256_loadu_ps(gamma); + __m256 _beta = _mm256_loadu_ps(beta); + _cur = _mm256_comp_fmadd_ps(_cur, _a_256, _b_256); + _cur = _mm256_comp_fmadd_ps(_cur, _gamma, _beta); + _mm256_storeu_ps(ptr, _cur); + } +#endif // __AVX__ + + for (; i + 4 <= size; i += 4, ptr += 4, gamma += 4, beta += 4) + { + __m128 _cur = _mm_loadu_ps(ptr); + __m128 _gamma = _mm_loadu_ps(gamma); + __m128 _beta = _mm_loadu_ps(beta); + _cur = _mm_comp_fmadd_ps(_cur, _a_128, _b_128); + _cur = _mm_comp_fmadd_ps(_cur, _gamma, _beta); + _mm_storeu_ps(ptr, _cur); } - } #endif // __SSE2__ - if (elempack == 1) - { + for (; i < size; ++i, ++ptr, ++gamma, ++beta) { - *ptr = ((*ptr) * (*a) + (*b)) * (*gamma) + (*beta); + *ptr = ((*ptr) * _a + _b) * (*gamma) + (*beta); } } } -NCNN_FORCEINLINE static void fast_1d_layer_norm(float* ptr, int elempack, int elemcount, int size, const float* gamma, const float* beta, int affine, float eps) +static NCNN_FORCEINLINE void fast_1d_layer_norm(float* ptr, int elempack, int elemcount, int size, const float* gamma, const float* beta, int affine, float eps) { - float mean[16], var[16]; + float mean[16] = {0.f}, var[16] = {0.f}; fast_mean(ptr, mean, elempack, elemcount, size); fast_var(ptr, var, mean, elempack, elemcount, size); float *a = var, *b = mean; - do - { #if __SSE2__ #if __AVX__ #if __AVX512F__ - { - __m512 _a = _mm512_set1_ps(1.0f); - __m512 _eps = _mm512_set1_ps(eps); - __m512 _b = _mm512_setzero_ps(); - __m512 _var = _mm512_loadu_ps(var); - _var = _mm512_add_ps(_var, _eps); - __m512 _sqrt_var = _mm512_sqrt_ps(_var); - _a = _mm512_div_ps(_a, _sqrt_var); - __m512 _mean = _mm512_loadu_ps(mean); - _b = _mm512_fnmadd_ps(_mean, _a, _b); - - _mm512_storeu_ps(a, _a); - _mm512_storeu_ps(b, _b); - break; - } + if (elempack == 16) + { + __m512 _a = _mm512_set1_ps(1.0f); + __m512 _eps = _mm512_set1_ps(eps); + __m512 _b = _mm512_setzero_ps(); + __m512 _var = _mm512_loadu_ps(var); + _var = _mm512_add_ps(_var, _eps); + __m512 _sqrt_var = 
_mm512_sqrt_ps(_var); + _a = _mm512_div_ps(_a, _sqrt_var); + __m512 _mean = _mm512_loadu_ps(mean); + _b = _mm512_fnmadd_ps(_mean, _a, _b); + + _mm512_storeu_ps(a, _a); + _mm512_storeu_ps(b, _b); + } #endif // __AVX512F__ - { - __m256 _a = _mm256_set1_ps(1.0f); - __m256 _eps = _mm256_set1_ps(eps); - __m256 _b = _mm256_setzero_ps(); - __m256 _var = _mm256_loadu_ps(var); - _var = _mm256_add_ps(_var, _eps); - __m256 _sqrt_var = _mm256_sqrt_ps(_var); - _a = _mm256_div_ps(_a, _sqrt_var); - __m256 _mean = _mm256_loadu_ps(mean); - _b = _mm256_comp_fnmadd_ps(_mean, _a, _b); - _mm256_storeu_ps(a, _a); - _mm256_storeu_ps(a + 8, _a); - _mm256_storeu_ps(b, _b); - _mm256_storeu_ps(b + 8, _b); - break; - } + if (elempack == 8) + { + __m256 _a = _mm256_set1_ps(1.0f); + __m256 _eps = _mm256_set1_ps(eps); + __m256 _b = _mm256_setzero_ps(); + __m256 _var = _mm256_loadu_ps(var); + _var = _mm256_add_ps(_var, _eps); + __m256 _sqrt_var = _mm256_sqrt_ps(_var); + _a = _mm256_div_ps(_a, _sqrt_var); + __m256 _mean = _mm256_loadu_ps(mean); + _b = _mm256_comp_fnmadd_ps(_mean, _a, _b); + + _mm256_storeu_ps(a, _a); + _mm256_storeu_ps(b, _b); + } #endif // __AVX__ - { - __m128 _a = _mm_set1_ps(1.0f); - __m128 _eps = _mm_set1_ps(eps); - __m128 _b = _mm_setzero_ps(); - __m128 _var = _mm_loadu_ps(var); - _var = _mm_add_ps(_var, _eps); - __m128 _sqrt_var = _mm_sqrt_ps(_var); - _a = _mm_div_ps(_a, _sqrt_var); - __m128 _mean = _mm_loadu_ps(mean); - _b = _mm_comp_fnmadd_ps(_mean, _a, _b); - - _mm_storeu_ps(a, _a); - _mm_storeu_ps(a + 4, _a); - _mm_storeu_ps(a + 8, _a); - _mm_storeu_ps(a + 12, _a); - _mm_storeu_ps(b, _b); - _mm_storeu_ps(b + 4, _b); - _mm_storeu_ps(b + 8, _b); - _mm_storeu_ps(b + 12, _b); - break; - } + if (elempack == 4) + { + __m128 _a = _mm_set1_ps(1.0f); + __m128 _eps = _mm_set1_ps(eps); + __m128 _b = _mm_setzero_ps(); + __m128 _var = _mm_loadu_ps(var); + _var = _mm_add_ps(_var, _eps); + __m128 _sqrt_var = _mm_sqrt_ps(_var); + _a = _mm_div_ps(_a, _sqrt_var); + __m128 _mean = _mm_loadu_ps(mean); + _b = _mm_comp_fnmadd_ps(_mean, _a, _b); + + _mm_storeu_ps(a, _a); + _mm_storeu_ps(b, _b); + } #endif // __SSE2__ - for (int i = 0; i < 16; ++i) - { - a[i] = static_cast(1.0f / sqrt(var[i] + eps)); - b[i] = -mean[i] * (a[i]); - } - } while (0); + if (elempack == 1) + { + a[0] = static_cast(1.0f / sqrt(var[0] + eps)); + b[0] = -mean[0] * (a[0]); + } if (affine) { @@ -575,7 +530,7 @@ NCNN_FORCEINLINE static void fast_1d_layer_norm(float* ptr, int elempack, int el } else { - fast_fmadd(ptr, a, b, size); + fast_fmadd(ptr, a, b, elempack, size); } } @@ -587,8 +542,8 @@ int LayerNorm_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) cons int h = bottom_top_blob.h; int channels = bottom_top_blob.c; - const float* gamma = static_cast(gamma_data); - const float* beta = static_cast(beta_data); + const float* gamma = gamma_data; + const float* beta = beta_data; if (dims == 1) { @@ -597,7 +552,8 @@ int LayerNorm_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) cons // 1D layer norm is special. Treat them as unpacked. 
fast_1d_layer_norm(ptr, 1, elemcount, elemcount, gamma, beta, affine, eps); } - else if (dims == 2) + + if (dims == 2) { #pragma omp parallel for num_threads(opt.num_threads) for (int i = 0; i < h; ++i) @@ -606,7 +562,8 @@ int LayerNorm_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) cons fast_1d_layer_norm(ptr, elempack, w, w * elempack, gamma, beta, affine, eps); } } - else if (dims == 3) + + if (dims == 3) { if (affine_size == w) { @@ -620,7 +577,7 @@ int LayerNorm_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) cons } } } - else // if(affine_size == w * h) + else // if (affine_size == w * h) { #pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < channels; ++q) diff --git a/src/layer/x86/layernorm_x86.h b/src/layer/x86/layernorm_x86.h index e62e5dddee73..42eb551ed95d 100644 --- a/src/layer/x86/layernorm_x86.h +++ b/src/layer/x86/layernorm_x86.h @@ -29,4 +29,4 @@ class LayerNorm_x86 : virtual public LayerNorm } // namespace ncnn -#endif // LAYER_LAYERNORM_X86_H \ No newline at end of file +#endif // LAYER_LAYERNORM_X86_H
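
For readers following the intrinsics above, a minimal scalar sketch of the math the optimized fast_1d_layer_norm paths implement (the elempack == 1 case) may help: mean and variance are reduced first, then folded into a = 1/sqrt(var + eps) and b = -mean * a, so that both the plain and the affine paths reduce to fused multiply-adds. The function name and layout below are illustrative only, not part of the patch series.

    #include <cmath>
    #include <cstddef>

    // Scalar reference for one normalized row: x' = x * a + b, optionally
    // followed by the per-element affine transform x' * gamma + beta.
    static void layernorm_1d_reference(float* ptr, size_t count,
                                       const float* gamma, const float* beta,
                                       bool affine, float eps)
    {
        // mean over the normalized axis
        float mean = 0.f;
        for (size_t i = 0; i < count; i++)
            mean += ptr[i];
        mean /= (float)count;

        // population variance over the normalized axis
        float var = 0.f;
        for (size_t i = 0; i < count; i++)
        {
            float d = ptr[i] - mean;
            var += d * d;
        }
        var /= (float)count;

        // fold normalization into a single fmadd; a and b play the same role
        // as the a/b buffers handed to fast_fmadd / fast_fmadd_fmadd
        float a = 1.f / std::sqrt(var + eps);
        float b = -mean * a;

        for (size_t i = 0; i < count; i++)
        {
            float v = ptr[i] * a + b;
            ptr[i] = affine ? v * gamma[i] + beta[i] : v;
        }
    }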
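
The last commit ("no store duplicates") drops the earlier workaround of writing the packed mean/var values two to four times into a 16-float scratch buffer. Instead, the narrower vector is widened by duplicating its lanes, so one loop body can consume a 128-, 256- or 512-bit broadcast of a and b. A standalone sketch of that widening idiom follows; the helper names are mine, and note that _mm512_insertf32x8 formally belongs to AVX512DQ (the patch issues the same call inside its __AVX512F__ block, assuming a toolchain that also enables AVX512DQ).

    #include <immintrin.h>

    #if defined(__AVX__)
    // duplicate the 4 lanes of a __m128 into both halves of a __m256
    static inline __m256 widen_ps128_to_ps256(__m128 v)
    {
        return _mm256_insertf128_ps(_mm256_castps128_ps256(v), v, 1);
    }
    #endif

    #if defined(__AVX512F__) && defined(__AVX512DQ__)
    // duplicate the 8 lanes of a __m256 into both halves of a __m512
    static inline __m512 widen_ps256_to_ps512(__m256 v)
    {
        return _mm512_insertf32x8(_mm512_castps256_ps512(v), v, 1);
    }
    #endif

With this idiom, fast_fmadd and fast_fmadd_fmadd load a and b once at the width matching elempack and widen them as needed, instead of reading the same values back from a duplicated buffer.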