[BatchNorm Optimize x86] AVX512 intrinsic (#4061)
* Add test samples for elempack==16

* Add AVX512 support for batchnorm
LRY89757 authored Jul 21, 2022
1 parent e33c85c commit 13a9533
Showing 2 changed files with 75 additions and 7 deletions.
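
For context, ncnn's BatchNorm folds the normalization into a per-channel affine transform at model-load time, so the forward pass reduces to a single fused multiply-add per element: y = b * x + a, with b = slope / sqrt(var + eps) and a = bias - mean * b (that precomputation lives in the generic BatchNorm layer, not in this diff). Below is a minimal scalar sketch of what each _mm512_fmadd_ps in this commit computes on 16 floats at once; channels, size and blob are illustrative names, while a_data and b_data are assumed to hold the precomputed coefficients:

// Scalar reference for the pack-16 AVX512 path (a sketch, not part of
// this commit). Assumes per-channel coefficients precomputed at load
// time: b[c] = slope[c] / sqrt(var[c] + eps),
//       a[c] = bias[c] - mean[c] * b[c].
for (int c = 0; c < channels; c++)
{
    float* ptr = blob.channel(c);
    const float a = a_data[c];
    const float b = b_data[c];
    for (int i = 0; i < size; i++)
    {
        ptr[i] = b * ptr[i] + a; // the fmadd: y = b * x + a
    }
}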
70 changes: 66 additions & 4 deletions src/layer/x86/batchnorm_x86.cpp
@@ -42,12 +42,74 @@ int BatchNorm_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
#if __AVX512F__
    if (elempack == 16)
    {
        // pack-16 path: each 16-float element holds 16 channels' values,
        // so one _mm512_fmadd_ps applies y = b * x + a to all of them
        if (dims == 1)
        {
            int w = bottom_top_blob.w;

            #pragma omp parallel for num_threads(opt.num_threads)
            for (int i = 0; i < w; i++)
            {
                float* ptr = (float*)bottom_top_blob + i * 16;

                __m512 _a = _mm512_loadu_ps((const float*)a_data + i * 16);
                __m512 _b = _mm512_loadu_ps((const float*)b_data + i * 16);

                __m512 _p = _mm512_loadu_ps(ptr);
                _p = _mm512_fmadd_ps(_p, _b, _a);
                _mm512_storeu_ps(ptr, _p);
            }
        }

        if (dims == 2)
        {
            int w = bottom_top_blob.w;
            int h = bottom_top_blob.h;

            #pragma omp parallel for num_threads(opt.num_threads)
            for (int i = 0; i < h; i++)
            {
                // per-row coefficients are loaded once and reused across the row
                __m512 _a = _mm512_loadu_ps((const float*)a_data + i * 16);
                __m512 _b = _mm512_loadu_ps((const float*)b_data + i * 16);

                float* ptr = bottom_top_blob.row(i);

                for (int j = 0; j < w; j++)
                {
                    __m512 _p = _mm512_loadu_ps(ptr);
                    _p = _mm512_fmadd_ps(_p, _b, _a);
                    _mm512_storeu_ps(ptr, _p);

                    ptr += 16;
                }
            }
        }

        if (dims == 3 || dims == 4)
        {
            int w = bottom_top_blob.w;
            int h = bottom_top_blob.h;
            int d = bottom_top_blob.d;
            int c = bottom_top_blob.c;
            int size = w * h * d;

            #pragma omp parallel for num_threads(opt.num_threads)
            for (int q = 0; q < c; q++)
            {
                // per-channel coefficients are loaded once and reused across the plane
                __m512 _a = _mm512_loadu_ps((const float*)a_data + q * 16);
                __m512 _b = _mm512_loadu_ps((const float*)b_data + q * 16);

                float* ptr = bottom_top_blob.channel(q);

                for (int i = 0; i < size; i++)
                {
                    __m512 _p = _mm512_loadu_ps(ptr);
                    _p = _mm512_fmadd_ps(_p, _b, _a);
                    _mm512_storeu_ps(ptr, _p);

                    ptr += 16;
                }
            }
        }

        return 0;
    }
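
Design note: before this commit, the elempack==16 case was handled by repacking, converting the blob to pack-8 with convert_packing, recursing into the pack-8 forward_inplace, and converting back to pack-16 (those are the four deleted lines counted in the header above). The new code keeps the data in pack-16 layout and applies _mm512_fmadd_ps directly; the 16 per-channel coefficients _a and _b are loaded once per element (dims == 1), per row (dims == 2), or per channel (dims == 3/4) and reused across the inner loop.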
12 changes: 9 additions & 3 deletions tests/test_batchnorm.cpp
@@ -52,7 +52,9 @@ static int test_batchnorm_0()
           || test_batchnorm(RandomMat(7, 8, 9, 12), 0.f)
           || test_batchnorm(RandomMat(7, 8, 9, 12), 0.001f)
           || test_batchnorm(RandomMat(3, 4, 5, 13), 0.f)
           || test_batchnorm(RandomMat(3, 4, 5, 13), 0.001f)
           || test_batchnorm(RandomMat(3, 4, 6, 32), 0.f)
           || test_batchnorm(RandomMat(3, 4, 5, 32), 0.001f);
}

static int test_batchnorm_1()
@@ -63,7 +65,9 @@ static int test_batchnorm_1()
           || test_batchnorm(RandomMat(7, 9, 12), 0.f)
           || test_batchnorm(RandomMat(7, 9, 12), 0.001f)
           || test_batchnorm(RandomMat(3, 5, 13), 0.f)
           || test_batchnorm(RandomMat(3, 5, 13), 0.001f)
           || test_batchnorm(RandomMat(3, 5, 16), 0.001f)
           || test_batchnorm(RandomMat(3, 5, 32), 0.001f);
}

static int test_batchnorm_2()
@@ -74,7 +78,9 @@ static int test_batchnorm_2()
           || test_batchnorm(RandomMat(17, 12), 0.f)
           || test_batchnorm(RandomMat(17, 12), 0.001f)
           || test_batchnorm(RandomMat(19, 15), 0.f)
           || test_batchnorm(RandomMat(19, 15), 0.001f)
           || test_batchnorm(RandomMat(128, 16), 0.f)
           || test_batchnorm(RandomMat(16, 128), 0.001f);
}

static int test_batchnorm_3()
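
Note on the new test shapes: the added cases use channel (or height, for 2-D mats) counts divisible by 16, such as 16, 32, and 128, which is what lets ncnn select elempack == 16 on an AVX512-capable build with packing enabled, so the new intrinsic path above actually gets exercised.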
