Fix nan in fused multi transformer (PaddlePaddle#44093)
LiYuRio authored Jul 6, 2022
1 parent 54a9daf commit d7f4599
Showing 2 changed files with 64 additions and 7 deletions.
5 changes: 4 additions & 1 deletion paddle/fluid/distributed/store/tcp_store.cc
@@ -125,7 +125,10 @@ void MasterDaemon::CloseControlFd() {
 void MasterDaemon::StopByControlFd() {
   VLOG(4) << ("begin to run StopByControlFd");
   if (_control_fd[1] != -1) {
-    ::write(_control_fd[1], "\0", 1);
+    PADDLE_ENFORCE_NE(::write(_control_fd[1], "\0", 1),
+                      -1,
+                      platform::errors::Fatal(
+                          "failed to write control pipe errno:%d", errno));
     // close the write end of the pipe
     ::close(_control_fd[1]);
     _control_fd[1] = -1;
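For context on the tcp_store.cc change: POSIX `write(2)` returns -1 on failure and sets `errno`, and on toolchains that mark it warn-unused-result, silently discarding that return value is itself a build warning or error. A minimal standalone sketch of the same check, with a plain abort standing in for `platform::errors::Fatal` (the helper name here is ours, not Paddle's):

```cpp
#include <unistd.h>

#include <cerrno>
#include <cstdio>
#include <cstdlib>

// Write one wake-up byte to the daemon's control pipe, failing loudly
// instead of silently dropping a write error.
void notify_control_fd(int fd) {
  if (fd == -1) return;
  if (::write(fd, "\0", 1) == -1) {
    std::fprintf(stderr, "failed to write control pipe errno:%d\n", errno);
    std::abort();
  }
}
```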
66 changes: 60 additions & 6 deletions paddle/fluid/operators/fused/fused_multi_transformer_op.cu
@@ -294,6 +294,52 @@ inline __device__ uint4 mul(uint4 a, uint4 b) {
   return c;
 }
 
+template <>
+inline __device__ uint32_t mul(uint32_t a, float b) {
+  float2 tmp = half2_to_float2(a);
+  float2 tmp_res;
+  tmp_res.x = tmp.x * b;
+  tmp_res.y = tmp.y * b;
+  uint32_t res = float2_to_half2(tmp_res);
+  return res;
+}
+
+template <>
+inline __device__ uint2 mul(uint2 a, float b) {
+  uint2 res;
+  res.x = mul<uint32_t, uint32_t, float>(a.x, b);
+  res.y = mul<uint32_t, uint32_t, float>(a.y, b);
+  return res;
+}
+
+template <>
+inline __device__ uint4 mul(uint4 a, float b) {
+  uint4 res;
+  res.x = mul<uint32_t, uint32_t, float>(a.x, b);
+  res.y = mul<uint32_t, uint32_t, float>(a.y, b);
+  res.z = mul<uint32_t, uint32_t, float>(a.z, b);
+  res.w = mul<uint32_t, uint32_t, float>(a.w, b);
+  return res;
+}
+
+template <>
+inline __device__ float2 mul(float2 a, float b) {
+  float2 res;
+  res.x = a.x * b;
+  res.y = a.y * b;
+  return res;
+}
+
+template <>
+inline __device__ float4 mul(float4 a, float b) {
+  float4 res;
+  res.x = a.x * b;
+  res.y = a.y * b;
+  res.z = a.z * b;
+  res.w = a.w * b;
+  return res;
+}
+
 inline __device__ float sum(float v) { return v; }
 inline __device__ float sum(float2 v) { return v.x + v.y; }
 inline __device__ float sum(float4 v) { return v.x + v.y + v.z + v.w; }
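The new `mul` specializations above broadcast one fp32 scalar across the kernel's bit-packed fp16 vectors: a `uint32_t` carries two fp16 lanes, a `uint2` four, and a `uint4` eight, and each lane is unpacked to fp32 via the file's `half2_to_float2` helper, scaled there, and repacked with `float2_to_half2`, so the scaling itself happens in full precision. The same operation written against the standard `<cuda_fp16.h>` types, as an illustrative sketch (the function name is ours):

```cuda
#include <cuda_fp16.h>

// What mul<uint32_t, uint32_t, float> does, using __half2 instead of the
// kernel's raw uint32_t bit-packing: unpack to fp32, scale, round back.
__device__ __half2 scale_half2(__half2 a, float b) {
  float2 tmp = __half22float2(a);  // two fp16 lanes -> fp32
  tmp.x *= b;                      // multiply in fp32, not fp16
  tmp.y *= b;
  return __float22half2_rn(tmp);   // round-to-nearest back to fp16
}
```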
Expand Down Expand Up @@ -445,11 +491,15 @@ inline __device__ Float8_ cast_to_float(uint4 u) {
}

template <int THREADS_PER_KEY, typename K_vec, int N>
inline __device__ float qk_dot_(const K_vec (&q)[N], const K_vec (&k)[N]) {
K_vec qk_vec = mul<K_vec, K_vec, K_vec>(q[0], k[0]);
inline __device__ float qk_dot_(const K_vec (&q)[N],
const K_vec (&k)[N],
float inv_sqrt_dh) {
K_vec inv_q = mul<K_vec, K_vec, float>(q[0], inv_sqrt_dh);
K_vec qk_vec = mul<K_vec, K_vec, K_vec>(inv_q, k[0]);
#pragma unroll
for (int ii = 1; ii < N; ++ii) {
qk_vec = fma(q[ii], k[ii], qk_vec);
inv_q = mul<K_vec, K_vec, float>(q[ii], inv_sqrt_dh);
qk_vec = fma(inv_q, k[ii], qk_vec);
}

float qk = sum(qk_vec);
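The `qk_dot_` change above is the heart of the fix, and it is purely a reordering:

$$\frac{1}{\sqrt{d_h}} \sum_i q_i k_i \;=\; \sum_i \frac{q_i}{\sqrt{d_h}} \, k_i$$

Both sides are equal in exact arithmetic, but on the left the unscaled partial sums accumulate in the `K_vec` element type, which is fp16 for half-precision models; once a partial sum passes 65504, the largest finite fp16 value, it saturates to inf, the later multiply by `inv_sqrt_dh` cannot bring it back, and the softmax that follows can turn inf into the NaN the commit title refers to. Scaling each `q` element first keeps every term, and hence every partial sum, a factor of sqrt(d_h) smaller.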
@@ -463,8 +513,10 @@ inline __device__ float qk_dot_(const K_vec (&q)[N], const K_vec (&k)[N]) {
 template <typename T, int THREADS_PER_KEY>
 struct Qk_dot {
   template <typename K_vec, int N>
-  static inline __device__ float dot(const K_vec (&q)[N], const K_vec (&k)[N]) {
-    return qk_dot_<THREADS_PER_KEY>(q, k);
+  static inline __device__ float dot(const K_vec (&q)[N],
+                                     const K_vec (&k)[N],
+                                     float inv_sqrt_dh) {
+    return qk_dot_<THREADS_PER_KEY>(q, k, inv_sqrt_dh);
   }
 };
 
@@ -706,7 +758,9 @@ __global__ void masked_multihead_attention_kernel(
     }
   }
 
-  float qk = Qk_dot<T, THREADS_PER_KEY>::dot(q, k) * params.inv_sqrt_dh;
+  // NOTE(liyurui): We should multiply q by inv_sqrt_dh first, since dot(q, k)
+  // may overflow in FP16 with large models.
+  float qk = Qk_dot<T, THREADS_PER_KEY>::dot(q, k, params.inv_sqrt_dh);
 
   // bool is_mask = false;
   if (ti < params.timestep && tid % THREADS_PER_KEY == 0) {
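To make the failure mode concrete, here is a small self-contained demo (ours, not part of the commit; needs a GPU with native fp16 arithmetic, sm_53 or newer) contrasting the two evaluation orders for a single product, with `inv_sqrt_dh = 1/sqrt(128)`:

```cuda
#include <cuda_fp16.h>
#include <cuda_runtime.h>

#include <cstdio>

// Scale-last: 300 * 300 = 90000 > 65504, so the fp16 product saturates to
// inf, and multiplying inf by inv_sqrt_dh afterwards stays inf.
// Scale-first: 300 * 0.0884 ~= 26.5, and 26.5 * 300 ~= 7955 stays finite.
__global__ void overflow_demo() {
  const __half scale = __float2half(0.0883883f);  // 1/sqrt(128)
  const __half q = __float2half(300.0f);
  const __half k = __float2half(300.0f);
  __half scale_last = __hmul(__hmul(q, k), scale);
  __half scale_first = __hmul(__hmul(q, scale), k);
  printf("scale-last: %f  scale-first: %f\n",
         __half2float(scale_last), __half2float(scale_first));
}

int main() {
  overflow_demo<<<1, 1>>>();
  cudaDeviceSynchronize();
  return 0;
}
```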
