Commit f0edf24

fix zero bug; fix num_splits (PaddlePaddle#2)
kuizhiqing authored Feb 13, 2023
1 parent 5994ce0 commit f0edf24
Showing 2 changed files with 10 additions and 7 deletions.
15 changes: 9 additions & 6 deletions csrc/flash_attn/flash_attn.cpp
@@ -422,10 +422,14 @@ bool flash_attn_bwd(
     void *dq_tmp_ptr = workspace_ptr;
     // nullptr out to calculate workspace size
     if (out == nullptr) {
-        if (loop || num_splits > 1) {
-            *workspace_size = uint64_t(total_q) * num_heads * head_size * sizeof(float);
-        } else {
+        // There are two cases in which no workspace needs to be allocated:
+        // 1) num_splits == 1
+        // 2) num_splits == 0 (auto calculation) that resolves to num_splits == 1;
+        //    for simplicity we still allocate in case 2
+        if (num_splits == 1) {
             *workspace_size = 0;
+        } else {
+            *workspace_size = uint64_t(total_q) * num_heads * head_size * sizeof(float);
         }
         return true;
     }
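
The out == nullptr call is the query mode referenced by the context comment above: the function only reports the required workspace size and returns, and the caller allocates before calling again with real pointers. A minimal standalone sketch of that sizing rule and the query-then-allocate pattern (the helper name BwdWorkspaceBytes and the example shapes are hypothetical, not part of this API):

#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>

// Hypothetical standalone version of the sizing rule above: dq is
// accumulated in float, one value per (query token, head, head_dim)
// element, so the scratch buffer holds total_q * num_heads * head_size
// floats whenever more than one split may run.
static uint64_t BwdWorkspaceBytes(int total_q, int num_heads, int head_size,
                                  int num_splits) {
    if (num_splits == 1) {
        return 0;  // a single split writes dq directly; no float scratch needed
    }
    // num_splits == 0 (auto) or > 1: allocate the float accumulation buffer
    return uint64_t(total_q) * num_heads * head_size * sizeof(float);
}

int main() {
    // Example shapes: 4096 query tokens, 16 heads, head_size 64, auto splits.
    uint64_t bytes = BwdWorkspaceBytes(4096, 16, 64, /*num_splits=*/0);
    void *workspace = nullptr;
    if (bytes > 0) {
        cudaMalloc(&workspace, bytes);
    }
    printf("workspace bytes: %llu\n", (unsigned long long)bytes);
    cudaFree(workspace);  // safe on nullptr
    return 0;
}
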
@@ -462,14 +466,13 @@ bool flash_attn_bwd(
         is_bf16,
         num_splits);
 
+    // calculate and set params.num_splits if num_splits == 0
     launch(params, stream, /*configure=*/true);
 
     if (params.num_splits > 1) {
-        SetZero(dq_tmp_ptr, 4, {total_q, num_heads, head_size}, stream);
         if (!loop) {
+            SetZero(dq_tmp_ptr, 4, {total_q, num_heads, head_size}, stream);
             params.o_tmp_ptr = dq_tmp_ptr; // o_tmp stores dq_tmp in the backward pass
-        } else {
-            SetZero(dq_tmp_ptr, 4, {total_q, num_heads, head_size}, stream);
         }
     }

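Why the zeroing matters: with num_splits > 1 the backward pass combines partial dq results from each split into the same float buffer, so the buffer must start at zero or stale workspace contents leak into the gradient. A toy sketch of that accumulation pattern, assuming atomic adds per split; this is illustrative only, not the actual flash-attention backward kernels:

#include <cstdio>
#include <cuda_runtime.h>

// Each "split" atomically adds its partial gradient into dq_tmp, mimicking
// how a multi-split backward pass combines results in one buffer.
__global__ void accumulate_partial(float *dq_tmp, const float *partial, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) atomicAdd(&dq_tmp[i], partial[i]);
}

int main() {
    const int n = 1024;
    float *dq_tmp, *partial;
    cudaMalloc(&dq_tmp, n * sizeof(float));
    cudaMalloc(&partial, n * sizeof(float));
    cudaMemset(partial, 0, n * sizeof(float));

    // The SetZero step: without it, whatever bytes the allocator left in
    // dq_tmp would be folded into the accumulated gradient.
    cudaMemset(dq_tmp, 0, n * sizeof(float));

    const int num_splits = 4;
    for (int s = 0; s < num_splits; ++s) {
        accumulate_partial<<<(n + 255) / 256, 256>>>(dq_tmp, partial, n);
    }
    cudaDeviceSynchronize();
    printf("accumulated %d splits\n", num_splits);
    cudaFree(dq_tmp);
    cudaFree(partial);
    return 0;
}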
2 changes: 1 addition & 1 deletion csrc/flash_attn/src/utils.cu
@@ -34,7 +34,7 @@ static __global__ void _float2half(float *float_ptr, __half *half_ptr, size_t n)
 void Float2Half(void *float_ptr, void *half_ptr, size_t n, cudaStream_t stream) {
     constexpr auto kNumThreads = 1024;
     auto block = (n + kNumThreads - 1) / kNumThreads;
-    _float2half<<<block, kNumThreads, 0, stream>>>(static_cast<float *>(float_ptr), static_cast<__half *>(float_ptr), n);
+    _float2half<<<block, kNumThreads, 0, stream>>>(static_cast<float *>(float_ptr), static_cast<__half *>(half_ptr), n);
 }
 
 static __global__ void _float2bfloat16(float *float_ptr, __nv_bfloat16 *bf16_ptr, size_t n) {
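The one-line fix above: the pre-fix launch passed float_ptr for both casts, so the kernel reinterpreted the float input buffer as its __half output and converted in place over the input, while half_ptr was never written. A minimal check against the fixed version (the Float2Half declaration matches the definition in utils.cu shown above; main and the buffers are illustrative):

#include <cstdio>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

// Declaration matching the definition in csrc/flash_attn/src/utils.cu.
void Float2Half(void *float_ptr, void *half_ptr, size_t n, cudaStream_t stream);

int main() {
    const size_t n = 8;
    float host_in[n];
    for (size_t i = 0; i < n; ++i) host_in[i] = 0.5f * i;

    float *dev_f;
    __half *dev_h;
    cudaMalloc(&dev_f, n * sizeof(float));
    cudaMalloc(&dev_h, n * sizeof(__half));
    cudaMemcpy(dev_f, host_in, n * sizeof(float), cudaMemcpyHostToDevice);

    // With the pre-fix code this clobbered dev_f instead, leaving dev_h
    // with whatever uninitialized contents the allocation held.
    Float2Half(dev_f, dev_h, n, /*stream=*/0);
    cudaDeviceSynchronize();

    __half host_out[n];
    cudaMemcpy(host_out, dev_h, n * sizeof(__half), cudaMemcpyDeviceToHost);
    for (size_t i = 0; i < n; ++i) {
        printf("%zu: %f\n", i, __half2float(host_out[i]));
    }
    cudaFree(dev_f);
    cudaFree(dev_h);
    return 0;
}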
