Skip to content

Commit

Permalink
Use scalar implementation instead of neon implementation to avoid out…
Browse files Browse the repository at this point in the history
… of range memory access in the tail conv3x3.
  • Loading branch information
Xreki committed Sep 28, 2017
1 parent 9928eb8 commit 3fefee8
Showing 1 changed file with 13 additions and 17 deletions.
30 changes: 13 additions & 17 deletions paddle/function/neon/NeonDepthwiseConv.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,25 +18,27 @@ limitations under the License. */
#include "neon_util.h"

namespace paddle {

namespace neon {

#if defined(__ARM_NEON__) || defined(__ARM_NEON)

template <int filterSize, int stride>
struct DepthwiseConvKernel {};

inline float32_t conv3x3(float32x4_t r0,
float32x4_t r1,
float32x4_t r2,
inline float32_t conv3x3(const float* r0,
const float* r1,
const float* r2,
float32x4_t k0,
float32x4_t k1,
float32x4_t k2) {
float32x4_t tmp;
tmp = vmulq_f32(r0, k0);
tmp = vmlaq_f32(tmp, r1, k1);
tmp = vmlaq_f32(tmp, r2, k2);
return vaddvq_f32(tmp);
float32_t tmp[12];
vst1q_f32(&(tmp[0]), k0);
vst1q_f32(&(tmp[4]), k1);
vst1q_f32(&(tmp[8]), k2);
float32_t sum0 = r0[0] * tmp[0] + r0[1] * tmp[1] + r0[2] * tmp[2];
float32_t sum1 = r1[0] * tmp[4] + r1[1] * tmp[5] + r1[2] * tmp[6];
float32_t sum2 = r2[0] * tmp[8] + r2[1] * tmp[9] + r2[2] * tmp[10];
return sum0 + sum1 + sum2;
}

inline float32_t conv4x4(float32x4_t r0,
Expand Down Expand Up @@ -136,10 +138,7 @@ struct DepthwiseConvKernel<3, 1> {
}

for (int r = 0; r < remain; r++) {
float32x4_t i0 = vld1q_f32(r0);
float32x4_t i1 = vld1q_f32(r1);
float32x4_t i2 = vld1q_f32(r2);
*outputData = conv3x3(i0, i1, i2, k[0], k[1], k[2]);
*outputData = conv3x3(r0, r1, r2, k[0], k[1], k[2]);
r0++;
r1++;
r2++;
Expand Down Expand Up @@ -243,10 +242,7 @@ struct DepthwiseConvKernel<3, 2> {
}

for (int r = 0; r < remain; r++) {
float32x4_t i0 = vld1q_f32(r0);
float32x4_t i1 = vld1q_f32(r1);
float32x4_t i2 = vld1q_f32(r2);
*outputData = conv3x3(i0, i1, i2, k[0], k[1], k[2]);
*outputData = conv3x3(r0, r1, r2, k[0], k[1], k[2]);
r0 += 2;
r1 += 2;
r2 += 2;
Expand Down

0 comments on commit 3fefee8

Please sign in to comment.