
Commit

fix comment
zhwesky2010 committed Jul 5, 2022
1 parent 7ef44da commit 849deff
Showing 7 changed files with 59 additions and 31 deletions.
2 changes: 1 addition & 1 deletion paddle/phi/kernels/funcs/sparse/sparse_blas_impl.cu.h
@@ -232,7 +232,7 @@ class CuSparseDnMatDescriptor {

PADDLE_ENFORCE_EQ(x.numel(), batch_size * M * N);
if (batch_size > 1) {
- #if CUDA_VERSION >= 11030
+ #if CUDA_VERSION >= 11070
dev_ctx_.CusparseCall([&](cusparseHandle_t handle) {
phi::dynload::cusparseDnMatSetStridedBatch(
descriptor_, batch_size, M * N);
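This hunk raises the CUDA floor for the strided-batch dense-matrix descriptor from 11.3 to 11.7. The descriptor treats `x` as `batch_size` dense M×N matrices stored back to back, with a stride of M*N values between batches — which is what the `PADDLE_ENFORCE_EQ` on `x.numel()` guarantees. A small NumPy sketch of that layout assumption (illustrative values only):

```python
import numpy as np

# Illustrative layout for cusparseDnMatSetStridedBatch: batch_size dense M x N
# matrices stored back to back, so batch b starts at offset b * M * N.
batch_size, M, N = 2, 4, 3
x = np.arange(batch_size * M * N, dtype=np.float32)      # flat buffer
assert x.size == batch_size * M * N                       # mirrors the PADDLE_ENFORCE_EQ above

batches = x.reshape(batch_size, M, N)                     # batch stride is M * N values
for b in range(batch_size):
    start = b * M * N
    assert np.array_equal(batches[b].ravel(), x[start:start + M * N])
```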
@@ -31,7 +31,7 @@ void FusedAttentionCsrGradKernel(const Context& dev_ctx,
DenseTensor* dkey,
DenseTensor* dvalue) {
PD_THROW(
"Only support 'fused_attention' CPU backward kernel of SparseTensor now");
"Not support CPU kernel of 'sparse.nn.functional.fused_attention' now");
}

} // namespace sparse
3 changes: 2 additions & 1 deletion paddle/phi/kernels/sparse/cpu/fused_attention_kernel.cc
@@ -30,7 +30,8 @@ void FusedAttentionCsrKernel(const Context& dev_ctx,
const DenseTensor& attn_mask,
DenseTensor* out,
SparseCsrTensor* softmax) {
PD_THROW("Only support 'fused_attention' CPU kernel of SparseTensor now");
PD_THROW(
"Not support CPU kernel of 'sparse.nn.functional.fused_attention' now");
}

} // namespace sparse
23 changes: 13 additions & 10 deletions paddle/phi/kernels/sparse/gpu/fused_attention_grad_kernel.cu
@@ -31,33 +31,30 @@ __global__ void AttnSoftmaxGpuGradKernel(const int64_t* out_crows,
T* dx_values,
int M,
int total_row_num,
- float scale) {
+ float scale,
+ int batch_nnz) {
// dx = (dout - sum(dout * out)) * out
int row = blockIdx.x * blockDim.y + threadIdx.y;
- int non_zero_idx = threadIdx.x;
if (row >= total_row_num) return;

int cur_batch = row / M;
int crow_idx = cur_batch * (M + 1) + (row % M);
- int cur_batch_offset = 0;
- for (int i = 1; i < cur_batch + 1; ++i) {
-   cur_batch_offset += static_cast<int>(out_crows[i * (M + 1) - 1]);
- }
- int row_first = cur_batch_offset + static_cast<int>(out_crows[crow_idx]);
+ int row_first = cur_batch * batch_nnz + static_cast<int>(out_crows[crow_idx]);
int row_nnz = static_cast<int>(out_crows[crow_idx + 1] - out_crows[crow_idx]);
if (row_nnz == 0) return;

int kIteration = (row_nnz + WARP_SIZE - 1) / WARP_SIZE;
T mul_result = 0;
for (int i = 0; i < kIteration; ++i) {
- int idx = non_zero_idx + i * WARP_SIZE;
+ int idx = threadIdx.x + i * WARP_SIZE;
if (idx >= row_nnz) break;

mul_result += out_values[row_first + idx] * dout_values[row_first + idx];
}
T sum = phi::funcs::warpReduceSum<T>(mul_result, 0xFFFFFFFF);

for (int i = 0; i < kIteration; ++i) {
- int idx = non_zero_idx + i * WARP_SIZE;
+ int idx = threadIdx.x + i * WARP_SIZE;
if (idx >= row_nnz) break;

dx_values[row_first + idx] = (dout_values[row_first + idx] - sum) *
@@ -88,11 +85,16 @@ void FusedAttentionCsrGradKernel(const Context& dev_ctx,
auto q_rank = q_dim.size();

int total_row_num = 1;
+ int batch_num = 1;
for (int i = 0; i < q_rank - 1; ++i) {
total_row_num *= q_dim[i];
+ if (i < q_rank - 2) {
+   batch_num *= q_dim[i];
+ }
}
int M = q_dim[q_rank - 2];
int N = q_dim[q_rank - 1];
+ int batch_nnz = softmax.nnz() / batch_num;

dim3 grid((total_row_num + 3) / 4);
dim3 block(WARP_SIZE, 4);
@@ -104,7 +106,8 @@ void FusedAttentionCsrGradKernel(const Context& dev_ctx,
d_sdd_result.mutable_non_zero_elements()->data<T>(),
M,
total_row_num,
- std::sqrt(N));
+ std::sqrt(N),
+ batch_nnz);

/* Step3: Forward: query{Dense} * key'{Dense} -> sdd_result{SparseCsr} */
auto sparse_blas = phi::funcs::sparse::GetSparseBlas<Context, T>(dev_ctx);
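The backward kernel evaluates dx = (dout - sum(dout * out)) * out independently for every CSR row. This change replaces the per-row loop that accumulated the last crows entry of every preceding batch with a constant per-batch stride, batch_nnz = nnz / batch_num, computed once on the host. A NumPy sketch of the same indexing and math (names are illustrative; the kernel also receives scale = sqrt(N), which is not applied in this sketch):

```python
import numpy as np

def csr_softmax_grad(crows, out_vals, dout_vals, M, batch_num):
    """Row-wise softmax backward over batched CSR values:
    dx = (dout - sum(dout * out)) * out."""
    batch_nnz = len(out_vals) // batch_num            # every batch holds the same nnz
    dx_vals = np.zeros_like(out_vals)
    for row in range(batch_num * M):                  # total_row_num rows
        cur_batch, cur_row = divmod(row, M)
        crow_idx = cur_batch * (M + 1) + cur_row
        # new offset rule: a constant per-batch stride instead of summing crows tails
        row_first = cur_batch * batch_nnz + int(crows[crow_idx])
        row_nnz = int(crows[crow_idx + 1] - crows[crow_idx])
        if row_nnz == 0:
            continue
        o = out_vals[row_first:row_first + row_nnz]
        do = dout_vals[row_first:row_first + row_nnz]
        dx_vals[row_first:row_first + row_nnz] = (do - np.dot(do, o)) * o
    return dx_vals
```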
47 changes: 36 additions & 11 deletions paddle/phi/kernels/sparse/gpu/fused_attention_kernel.cu
@@ -59,19 +59,16 @@ __global__ void AttnSoftmaxGpuKernel(const int64_t* x_crows,
int M,
int total_row_num,
float scale,
- int num_heads) {
+ int num_heads,
+ int batch_nnz) {
// out = exp(x-x_max) / sum(exp(x-x_max))
int row = blockIdx.x * blockDim.y + threadIdx.y;
- int non_zero_idx = threadIdx.x;
if (row >= total_row_num) return;

int cur_batch = row / M;
int cur_row = row % M;
int crow_idx = cur_batch * (M + 1) + cur_row;
- int cur_batch_offset = 0;
- for (int i = 1; i < cur_batch + 1; ++i) {
-   cur_batch_offset += static_cast<int>(x_crows[i * (M + 1) - 1]);
- }
- int row_first = cur_batch_offset + static_cast<int>(x_crows[crow_idx]);
+ int row_first = cur_batch * batch_nnz + static_cast<int>(x_crows[crow_idx]);
int row_nnz = static_cast<int>(x_crows[crow_idx + 1] - x_crows[crow_idx]);
if (row_nnz == 0) return;

@@ -81,7 +78,7 @@ __global__ void AttnSoftmaxGpuKernel(const int64_t* x_crows,
T max_val = -std::numeric_limits<T>::infinity();
for (int i = 0; i < kIteration; ++i) {
bool mask = false;
- int idx = non_zero_idx + i * WARP_SIZE;
+ int idx = threadIdx.x + i * WARP_SIZE;
if (idx >= row_nnz) break;

int col_idx = static_cast<int>(x_cols[row_first + idx]);
@@ -106,7 +103,7 @@ __global__ void AttnSoftmaxGpuKernel(const int64_t* x_crows,
auto functor = phi::funcs::CudaExpFunctor<T>();
T exp_sum = 0;
for (int i = 0; i < kIteration; ++i) {
- int idx = non_zero_idx + i * WARP_SIZE;
+ int idx = threadIdx.x + i * WARP_SIZE;
if (idx >= row_nnz) break;

if (buffer[i]) {
@@ -118,7 +115,7 @@ __global__ void AttnSoftmaxGpuKernel(const int64_t* x_crows,
T row_exp_sum = phi::funcs::warpReduceSum<T>(exp_sum, 0xFFFFFFFF);

for (int i = 0; i < kIteration; ++i) {
- int idx = non_zero_idx + i * WARP_SIZE;
+ int idx = threadIdx.x + i * WARP_SIZE;
if (idx >= row_nnz) break;

if (buffer[i]) {
@@ -145,8 +142,12 @@ void FusedAttentionCsrKernel(const Context& dev_ctx,
auto q_rank = q_dim.size();

int total_row_num = 1;
+ int batch_num = 1;
for (int i = 0; i < q_rank - 1; ++i) {
total_row_num *= q_dim[i];
+ if (i < q_rank - 2) {
+   batch_num *= q_dim[i];
+ }
}
int M = q_dim[q_rank - 2];
int N = q_dim[q_rank - 1];
@@ -161,6 +162,27 @@ void FusedAttentionCsrKernel(const Context& dev_ctx,
4,
phi::errors::InvalidArgument(" 'value' must be 4D Tensor"));

+ PADDLE_ENFORCE_EQ(
+     sparse_mask.dims().size(),
+     3,
+     phi::errors::InvalidArgument("dense shape of 'sparse_mask' must be "
+                                  "[batch_size*num_heads, seq_len, seq_len]"));
+ PADDLE_ENFORCE_EQ(
+     sparse_mask.dims()[0],
+     q_dim[0] * q_dim[1],
+     phi::errors::InvalidArgument("dense shape of 'sparse_mask' must be "
+                                  "[batch_size*num_heads, seq_len, seq_len]"));
+ PADDLE_ENFORCE_EQ(
+     sparse_mask.dims()[1],
+     M,
+     phi::errors::InvalidArgument("dense shape of 'sparse_mask' must be "
+                                  "[batch_size*num_heads, seq_len, seq_len]"));
+ PADDLE_ENFORCE_EQ(
+     sparse_mask.dims()[2],
+     M,
+     phi::errors::InvalidArgument("dense shape of 'sparse_mask' must be "
+                                  "[batch_size*num_heads, seq_len, seq_len]"));

PADDLE_ENFORCE_EQ(
key_padding_mask.dims().size(),
2,
@@ -215,6 +237,8 @@ void FusedAttentionCsrKernel(const Context& dev_ctx,
dim3 grid((total_row_num + 3) / 4);
dim3 block(WARP_SIZE, 4);

+ int batch_nnz = sdd_result.nnz() / batch_num;

VISIT_ATTN_SFOTMAX(buffer_size, "AttnSoftmaxGpuKernel", [&] {
AttnSoftmaxGpuKernel<T, KBufferSize><<<grid, block, 0, dev_ctx.stream()>>>(
sdd_result.non_zero_crows().data<int64_t>(),
@@ -226,7 +250,8 @@ void FusedAttentionCsrKernel(const Context& dev_ctx,
M,
total_row_num,
std::sqrt(N),
- q_dim[1]);
+ q_dim[1],
+ batch_nnz);
});

/* Step3: DSD Matmul, reuse */
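The forward kernel performs a numerically stable softmax over each CSR row of the scaled QK^T scores, skipping columns rejected by the key-padding or attention masks, and uses the same cur_batch * batch_nnz offset as the backward pass. A rough NumPy sketch of one row, assuming a simplified nonzero-means-keep mask convention (the kernel's exact mask semantics are not shown in this hunk):

```python
import numpy as np

def csr_row_softmax(row_vals, row_cols, scale, kp_mask=None, attn_mask=None):
    """Numerically stable softmax over one CSR row of attention scores.
    scale corresponds to sqrt(head_dim); masked columns get 0 in the output.
    Nonzero-means-keep is an assumption made for this sketch."""
    keep = np.ones(len(row_vals), dtype=bool)
    if kp_mask is not None:                 # key padding mask, indexed by column
        keep &= kp_mask[row_cols] != 0
    if attn_mask is not None:               # attention mask row, indexed by column
        keep &= attn_mask[row_cols] != 0

    out = np.zeros_like(row_vals)
    if not keep.any():
        return out
    x = row_vals[keep] / scale
    x = np.exp(x - x.max())                 # subtract the row max before exp
    out[keep] = x / x.sum()
    return out
```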
@@ -1,4 +1,4 @@
- # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+ # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -36,7 +36,7 @@ def get_cuda_version():


@unittest.skipIf(
- not core.is_compiled_with_cuda() or get_cuda_version() < 11030,
+ not core.is_compiled_with_cuda() or get_cuda_version() < 11070,
"core is not compiled with CUDA and cuda version need larger than or equal to 11.3"
)
class TestSparseAttentionAPI1(unittest.TestCase):
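The skip threshold moves from 11030 to 11070, i.e. from CUDA 11.3 to CUDA 11.7, matching the CUDA_VERSION guard in the cuSPARSE descriptor change above. The helper get_cuda_version() is not shown in this hunk; below is a hedged sketch of the integer encoding it evidently compares against (major * 1000 + minor * 10), not the test's actual implementation:

```python
def cuda_version_to_int(version_str):
    """Encode 'major.minor' as major * 1000 + minor * 10, e.g. '11.7' -> 11070.
    This mirrors the CUDA_VERSION macro convention; the test's real
    get_cuda_version() may obtain and parse the version differently."""
    major, minor = version_str.split(".")[:2]
    return int(major) * 1000 + int(minor) * 10

assert cuda_version_to_int("11.3") == 11030
assert cuda_version_to_int("11.7") == 11070
```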
9 changes: 4 additions & 5 deletions python/paddle/incubate/sparse/nn/functional/transformer.py
@@ -37,19 +37,18 @@ def attention(query,
.. math::
- result=softmax(\frac{ Q * K^T }{\sqrt{d}}) * V
+ result = softmax(\frac{ Q * K^T }{\sqrt{d}}) * V
where : ``Q``, ``K``, and ``V`` represent the three input parameters of the attention module.
- The shape of the three parameters are:
- The dimensions of these three parameters are: [batch_size, num_heads, seq_len, head_dim].
+ The shape of the three parameters are: `[batch_size, num_heads, seq_len, head_dim]`, and
``d`` represents ``head_dim`` .
Args:
query(DenseTensor): `query` in the Attention module. 4D Tensor with float32 or float64.
key(DenseTensor): `key` in the Attention module. 4D Tensor with float32 or float64.
value(DenseTensor): `value` in the Attention module. 4D Tensor with float32 or float64.
- sparse_mask(SparseCsrTensor): The sparse layout in the Attention module. shape of `crows` is
- [batch_size, num_heads, seq_len + 1], shape of `cols` is [batch_size, num_heads, nnz].
+ sparse_mask(SparseCsrTensor): The sparse layout in the Attention module. Its dense shape
+ is `[batch_size*num_heads, seq_len, seq_len]` . `nnz` of each batch must be the same.
dtype of `crows` and `cols` must be int64, dtype of `values` can be float32 or float64.
key_padding_mask(DenseTensor): The key padding mask tensor in the Attention module.
2D tensor with shape: [batch_size, seq_len]. dtype can be float32 or float64.
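The reworked docstring pins down the shapes: query/key/value are dense [batch_size, num_heads, seq_len, head_dim] tensors, and sparse_mask is a SparseCsrTensor whose dense shape is [batch_size*num_heads, seq_len, seq_len] with the same nnz in every batch (enforced by the new checks in the GPU kernel). A hedged usage sketch under those rules; the import path, the to_sparse_csr() conversion, and treating key_padding_mask/attn_mask as optional are assumptions to verify against the released API, and the op needs a GPU build with CUDA >= 11.7 per this commit:

```python
import paddle
# Import path assumed from this commit's file layout; verify against the released API.
from paddle.incubate.sparse.nn.functional import attention

batch_size, num_heads, seq_len, head_dim = 2, 4, 16, 8

q = paddle.randn([batch_size, num_heads, seq_len, head_dim])
k = paddle.randn([batch_size, num_heads, seq_len, head_dim])
v = paddle.randn([batch_size, num_heads, seq_len, head_dim])

# A causal layout: identical pattern per (batch, head), so nnz matches across batches,
# and the dense shape is [batch_size*num_heads, seq_len, seq_len] as required.
dense_mask = paddle.tril(paddle.ones([batch_size * num_heads, seq_len, seq_len]))
sparse_mask = dense_mask.to_sparse_csr()   # assumption: dense -> CSR conversion helper

# key_padding_mask / attn_mask are omitted here and assumed optional.
out = attention(q, k, v, sparse_mask)
print(out.shape)                           # expected: [batch_size, num_heads, seq_len, head_dim]
```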

1 comment on commit 849deff

@paddle-bot-old
Congratulation! Your pull request passed all required CI. You could ask reviewer(s) to approve and merge. 🎉
