
update for windows
bukejiyu committed Apr 9, 2024
1 parent 3262a2f commit 26a0f9b
Showing 3 changed files with 14 additions and 9 deletions.
paddle/phi/kernels/fusion/cpu/fused_layer_norm_avx_kernel.cc (6 additions, 1 deletion)
@@ -155,7 +155,12 @@ void LayerNormFunc(const T* x_data,
       if (norm_bias_data) {
         vbeta = _mm512_maskz_loadu_ps(mask, norm_bias_data + col);
       }
-      __m512 vy = (vx - vmean) * vgamma * vvar + vbeta;
+      // (vx - vmean) * vgamma * vvar + vbeta
+      __m512 vy;
+      vx = _mm512_mask_sub_ps(vx, mask, vx, vmean);
+      vx = _mm512_mask_mul_ps(vx, mask, vx, vgamma);
+      vx = _mm512_mask_mul_ps(vx, mask, vx, vvar);
+      vy = _mm512_maskz_add_ps(mask, vx, vbeta);
       _mm512_mask_storeu_ps(py + col, mask, vy);
     }
   }
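The deleted line leaned on infix arithmetic over __m512 values, a GCC/Clang vector extension that MSVC does not provide; the replacement spells the same expression out as explicit AVX-512 intrinsics so the kernel builds on Windows. Below is a self-contained sketch of this masked normalization step; the function name, scalar parameters, and the inputs in main are illustrative, not taken from the kernel.

#include <immintrin.h>

#include <cstdio>

// One strip (up to 16 floats) of y = (x - mean) * gamma * rvar + beta,
// written with explicit intrinsics only, since MSVC has no infix
// operators for __m512.
void normalize_strip(const float* x, const float* gamma, const float* beta,
                     float* y, float mean, float rvar, int remain) {
  __mmask16 mask = remain >= 16 ? 0xffff : (__mmask16)((1 << remain) - 1);
  __m512 vx = _mm512_maskz_loadu_ps(mask, x);
  __m512 vgamma = _mm512_maskz_loadu_ps(mask, gamma);
  __m512 vbeta = _mm512_maskz_loadu_ps(mask, beta);
  __m512 vmean = _mm512_set1_ps(mean);
  __m512 vvar = _mm512_set1_ps(rvar);  // reciprocal standard deviation
  vx = _mm512_mask_sub_ps(vx, mask, vx, vmean);
  vx = _mm512_mask_mul_ps(vx, mask, vx, vgamma);
  vx = _mm512_mask_mul_ps(vx, mask, vx, vvar);
  __m512 vy = _mm512_maskz_add_ps(mask, vx, vbeta);  // tail lanes zeroed
  _mm512_mask_storeu_ps(y, mask, vy);                // tail lanes untouched
}

int main() {
  float x[5] = {1.f, 2.f, 3.f, 4.f, 5.f};
  float gamma[5] = {1.f, 1.f, 1.f, 1.f, 1.f};
  float beta[5] = {0.f, 0.f, 0.f, 0.f, 0.f};
  float y[5];
  // For x above: mean = 3, variance = 2, so rvar = 1/sqrt(2).
  normalize_strip(x, gamma, beta, y, 3.0f, 0.70710678f, 5);
  for (float v : y) std::printf("%g ", v);  // ~ -1.41 -0.71 0 0.71 1.41
  std::printf("\n");
}

Builds with g++ -mavx512f or MSVC /arch:AVX512 and needs an AVX-512 machine to run. Using _mm512_maskz_add_ps for the final step keeps every lane of vy defined, while the masked store still touches only the first remain floats of y.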
paddle/phi/kernels/fusion/cpu/self_dp_attention_kernel.cc (8 additions, 4 deletions)
@@ -257,7 +257,9 @@ void softmax_sum_max(float* AB,
     __mmask16 mask = (remain >= 16 ? 0xffff : (1 << remain) - 1);
 
     __m512 vx = _mm512_maskz_loadu_ps(mask, buf + off);
-    vx = vexp(vx * vrefac - vmax);
+    vx = _mm512_mask_mul_ps(vx, mask, vx, vrefac);
+    vx = _mm512_mask_sub_ps(vx, mask, vx, vmax);
+    vx = vexp(vx);
 
     _mm512_mask_storeu_ps(buf + off, mask, vx);
 
@@ -275,8 +277,7 @@ void softmax_sum_max(float* AB,
     __mmask16 mask = (remain >= 16 ? 0xffff : (1 << remain) - 1);
 
     __m512 vx = _mm512_maskz_loadu_ps(mask, buf + off);
-    vx = vx * vrsum;
-
+    vx = _mm512_mask_mul_ps(vx, mask, vx, vrsum);
     _mm512_mask_storeu_ps(buf + off, mask, vx);
   }
 }
@@ -301,7 +302,10 @@ void update_out_blk(float* output,
     __mmask16 mask = (remain >= 16 ? 0xffff : (1 << remain) - 1);
     __m512 vout = _mm512_maskz_loadu_ps(mask, outbuf + off);
     __m512 vabc = _mm512_maskz_loadu_ps(mask, buf + off);
-    __m512 vupt = vout * merr * vfac + vabc;
+    vout = _mm512_mask_mul_ps(vout, mask, vout, merr);
+    vout = _mm512_mask_mul_ps(vout, mask, vout, vfac);
+    __m512 vupt;
+    vupt = _mm512_maskz_add_ps(mask, vout, vabc);
     _mm512_mask_storeu_ps(outbuf + off, mask, vupt);
   }
   pre_sum[i] = sum[i];
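The attention kernel gets the same operator-to-intrinsic substitution in three places. Below is a compilable sketch of the softmax rescale and the output-block update; vexp here is a scalar stand-in for the kernel's vectorized exponential, and the function names and scalar parameters (err and fac standing for the running-max/running-sum correction factors) are assumptions for illustration, not Paddle's API.

#include <immintrin.h>

#include <cmath>

// Stand-in for the kernel's vexp helper: lane-by-lane exp through memory.
// The real vexp is a SIMD approximation; any lane-wise exp shows the flow.
static __m512 vexp(__m512 v) {
  alignas(64) float tmp[16];
  _mm512_store_ps(tmp, v);
  for (float& t : tmp) t = std::exp(t);
  return _mm512_load_ps(tmp);
}

// buf <- exp(buf * refac - maxval) over the first `remain` lanes.
void softmax_rescale(float* buf, float refac, float maxval, int remain) {
  __mmask16 mask = remain >= 16 ? 0xffff : (__mmask16)((1 << remain) - 1);
  __m512 vrefac = _mm512_set1_ps(refac);
  __m512 vmax = _mm512_set1_ps(maxval);
  __m512 vx = _mm512_maskz_loadu_ps(mask, buf);
  vx = _mm512_mask_mul_ps(vx, mask, vx, vrefac);
  vx = _mm512_mask_sub_ps(vx, mask, vx, vmax);
  vx = vexp(vx);
  _mm512_mask_storeu_ps(buf, mask, vx);
}

// outbuf <- outbuf * err * fac + buf over the first `remain` lanes,
// mirroring the update_out_blk hunk above.
void update_block(float* outbuf, const float* buf, float err, float fac,
                  int remain) {
  __mmask16 mask = remain >= 16 ? 0xffff : (__mmask16)((1 << remain) - 1);
  __m512 merr = _mm512_set1_ps(err);
  __m512 vfac = _mm512_set1_ps(fac);
  __m512 vout = _mm512_maskz_loadu_ps(mask, outbuf);
  __m512 vabc = _mm512_maskz_loadu_ps(mask, buf);
  vout = _mm512_mask_mul_ps(vout, mask, vout, merr);
  vout = _mm512_mask_mul_ps(vout, mask, vout, vfac);
  __m512 vupt = _mm512_maskz_add_ps(mask, vout, vabc);
  _mm512_mask_storeu_ps(outbuf, mask, vupt);
}

Because every load, arithmetic step, and store reuses the same mask, lanes beyond remain never reach memory, and the maskz forms additionally zero those lanes in registers so nothing is left undefined.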
test/legacy_test/test_fused_layernorm_op.py (0 additions, 4 deletions)
@@ -961,10 +961,6 @@ def setUp(self):
         self.epsilon = 1e-5
         self.residual_alpha = np.random.uniform(low=0.1, high=1.1, size=[1])
 
-        self.quant_scale = 0.15
-        self.quant_round_type = 1
-        self.quant_max_bound = 127
-        self.quant_min_bound = -127
         self.place = paddle.CPUPlace()
 
     def check_layernorm(self, x_np, gamma_np, beta_np, dtype):
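The four deleted attributes configured the quantized-output variant of the fused op, which this CPUPlace test no longer exercises. Read generically, a scale, a round type, and bounds of ±127 describe round-and-clamp int8 quantization; a minimal sketch under that reading (illustrative only, not Paddle's kernel or its exact formula):

#include <algorithm>
#include <cmath>
#include <cstdint>

// Generic int8 quantization: scale into [-max_bound, max_bound], round,
// then clamp. Parameter names mirror the deleted test attributes; the
// mapping of round_type values is an assumption.
int8_t quantize(float x, float scale, int round_type, float max_bound,
                float min_bound) {
  float q = max_bound * scale * x;
  q = round_type == 0 ? std::nearbyint(q)  // round half to even (default mode)
                      : std::round(q);     // round half away from zero
  return static_cast<int8_t>(std::min(std::max(q, min_bound), max_bound));
}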
