
Commit f94bdcb

[0-size Tensor No.87, No.214] Add 0-size Tensor support for paddle.incubate.nn.functional.fused_rotary_position_embedding and scaled_dot_product_attention (#74323)

* fix fused rope and VLM attention
* fix flash attention
* add unittest case
* fix bug and add unittest case
1 parent 6b610ee commit f94bdcb
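
For orientation, the sketch below (not part of the diff) shows the user-facing behavior this commit enables: calling fused_rotary_position_embedding with zero-size query/key/value now allocates the outputs and returns cleanly. It assumes a CUDA or ROCm build of Paddle and mirrors the shapes used in the new unit test added in this commit.

```python
# Minimal sketch (not part of this commit) of the 0-size path the change enables.
# Assumes a CUDA/ROCm build of Paddle; shapes mirror the new unit test below.
import paddle
from paddle.incubate.nn.functional import fused_rotary_position_embedding

q = paddle.randn([0, 1, 8, 8], dtype="float32")  # batch size 0 -> 0-size tensor
k = paddle.randn([0, 1, 8, 8], dtype="float32")
v = paddle.randn([0, 1, 8, 8], dtype="float32")
sin = paddle.sin(paddle.randn([1, 1, 1, 8], dtype="float32"))
cos = paddle.cos(paddle.randn([1, 1, 1, 8], dtype="float32"))

out_q, out_k, out_v = fused_rotary_position_embedding(
    q, k, v, sin=sin, cos=cos, use_neox_rotary_style=False
)
print(out_q.shape)  # expected [0, 1, 8, 8]: outputs are allocated even when numel is 0
```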

File tree

8 files changed: +122 −13 lines

paddle/phi/infermeta/fusion.cc

Lines changed: 14 additions & 0 deletions
@@ -4508,6 +4508,20 @@ void VariableLengthMemoryEfficientAttentionInferMeta(
       true,
       common::errors::InvalidArgument(
           "The seq length of Key, Value should be equal."));
+  if (mask) {
+    PADDLE_ENFORCE_EQ(
+        mask.dims().size(),
+        4,
+        common::errors::InvalidArgument("Mask should be a 4-D tensor"
+                                        "But received Value dimension(%s)",
+                                        mask.dims().size()));
+    const int64_t mask_batch_size = mask.dims()[0];
+    PADDLE_ENFORCE_EQ(
+        query_batch_size == mask_batch_size,
+        true,
+        common::errors::InvalidArgument(
+            "The batch size of Query, Key, Value and Mask should be equal."));
+  }
 
   std::vector<int64_t> out_dims(
       {query_batch_size, query_num_head, query_seq_length, value_head_size});
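
As a hedged illustration of the contract these new checks enforce: the mask must be a 4-D tensor whose leading (batch) dimension matches the query's batch size. The sketch below builds such a mask at the Python level via paddle.nn.functional.scaled_dot_product_attention; whether a given call actually dispatches to this fused variable-length kernel depends on dtype, device, and build, so treat it purely as a shape example.

```python
# Shape-contract sketch only (assumptions: scaled_dot_product_attention accepts a
# 4-D additive attn_mask broadcastable over heads; actual kernel dispatch varies).
import paddle

B, S, H, D = 2, 16, 8, 64
q = paddle.randn([B, S, H, D], dtype="float32")
k = paddle.randn([B, S, H, D], dtype="float32")
v = paddle.randn([B, S, H, D], dtype="float32")

# 4-D mask whose batch dimension matches the query batch size, as the new
# PADDLE_ENFORCE_EQ checks require; a 3-D mask or a mismatched batch would fail.
mask = paddle.zeros([B, 1, S, S], dtype="float32")

out = paddle.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask)
print(out.shape)  # [B, S, H, D]
```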

paddle/phi/kernels/fusion/cutlass/variable_length_memory_efficient_attention_kernel.cu

Lines changed: 2 additions & 1 deletion
@@ -67,7 +67,8 @@ void MultiHeadAttentionVariableForwardKernel(
   params.causal = causal;
   params.pre_cache_length = pre_cache_length;
 
-  if (mask) {
+  // if the mask is 0-size tensor, we don't need to set mask_ptr
+  if (mask && mask.get().numel() > 0) {
     // [B, 1, S, D]
     auto mask_tensor = mask.get();
     int64_t mask_num_heads = mask_tensor.dims()[1];

paddle/phi/kernels/fusion/gpu/fused_rope_grad_kernel.cu

Lines changed: 5 additions & 6 deletions
@@ -38,8 +38,10 @@ void FusedRopeGradKernel(const Context& dev_ctx,
                          DenseTensor* dk,
                          DenseTensor* dv) {
   int64_t numel = dout_q.numel();
-  if (numel <= 0) return;
   dev_ctx.template Alloc<T>(dq);
+  if (dout_k) dev_ctx.template Alloc<T>(dk);
+  if (dout_v) dev_ctx.template Alloc<T>(dv);
+  if (numel <= 0) return;
 
   phi::Array<int64_t, 3> inputs_num_heads;
   // small size for broadcast
@@ -70,22 +72,19 @@ void FusedRopeGradKernel(const Context& dev_ctx,
   outs_data[0] = dq->data<T>();
   int num_inputs = 1;
 
-  if (dout_k) {
-    dev_ctx.template Alloc<T>(dk);
+  if (dk && dk->numel() > 0) {
     outs_data[num_inputs] = dk->data<T>();
     ins_data[num_inputs] = dout_k->data<T>();
     inputs_num_heads[num_inputs] = dk->dims()[2];
     num_inputs++;
   }
 
-  if (dout_v) {
-    dev_ctx.template Alloc<T>(dv);
+  if (dv && dv->numel() > 0) {
     outs_data[num_inputs] = dv->data<T>();
     ins_data[num_inputs] = dout_v->data<T>();
     inputs_num_heads[num_inputs] = dv->dims()[2];
     num_inputs++;
   }
-
   using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
   MPType div_c = static_cast<MPType>(1.0f / head_dim);

paddle/phi/kernels/fusion/gpu/fused_rope_kernel.cu

Lines changed: 5 additions & 6 deletions
@@ -38,8 +38,10 @@ void FusedRopeKernel(const Context& dev_ctx,
                      DenseTensor* out_k,
                      DenseTensor* out_v) {
   int64_t numel = q.numel();
-  if (numel <= 0) return;
   dev_ctx.template Alloc<T>(out_q);
+  if (k) dev_ctx.template Alloc<T>(out_k);
+  if (v) dev_ctx.template Alloc<T>(out_v);
+  if (numel <= 0) return;
 
   phi::Array<int64_t, 3> inputs_num_heads;
 
@@ -73,16 +75,13 @@ void FusedRopeKernel(const Context& dev_ctx,
   outs_data[0] = out_q->data<T>();
   int num_inputs = 1;
 
-  if (k) {
-    dev_ctx.template Alloc<T>(out_k);
+  if (out_k && out_k->numel() > 0) {
     ins_data[num_inputs] = k->data<T>();
     outs_data[num_inputs] = out_k->data<T>();
     inputs_num_heads[num_inputs] = k->dims()[2];
     num_inputs++;
   }
-
-  if (v) {
-    dev_ctx.template Alloc<T>(out_v);
+  if (out_v && out_v->numel() > 0) {
     ins_data[num_inputs] = v->data<T>();
     outs_data[num_inputs] = out_v->data<T>();
     inputs_num_heads[num_inputs] = v->dims()[2];

paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu

Lines changed: 12 additions & 0 deletions
@@ -927,6 +927,18 @@ void FlashAttnGradKernel(const Context& dev_ctx,
   if (dv) {
     dev_ctx.template Alloc<T>(dv);
   }
+  if (dout.numel() == 0) {
+    if (dq)
+      Full<T, Context>(
+          dev_ctx, phi::IntArray(common::vectorize(dq->dims())), 0, dq);
+    if (dk)
+      Full<T, Context>(
+          dev_ctx, phi::IntArray(common::vectorize(dk->dims())), 0, dk);
+    if (dv)
+      Full<T, Context>(
+          dev_ctx, phi::IntArray(common::vectorize(dv->dims())), 0, dv);
+    return;
+  }
   FlashAttnGradBaseKernel<T, Context>(dev_ctx,
                                       q,
                                       k,

paddle/phi/kernels/gpu/flash_attn_kernel.cu

Lines changed: 25 additions & 0 deletions
@@ -633,6 +633,31 @@ void FlashAttnKernel(const Context& dev_ctx,
                      DenseTensor* softmax,
                      DenseTensor* softmax_lse,
                      DenseTensor* seed_offset) {
+  if (q.numel() == 0 || k.numel() == 0 || v.numel() == 0) {
+    if (out) {
+      Full<T, Context>(
+          dev_ctx, phi::IntArray(common::vectorize(out->dims())), 0, out);
+    }
+    if (softmax) {
+      Full<T, Context>(dev_ctx,
+                       phi::IntArray(common::vectorize(softmax->dims())),
+                       0,
+                       softmax);
+    }
+    if (softmax_lse) {
+      Full<T, Context>(dev_ctx,
+                       phi::IntArray(common::vectorize(softmax_lse->dims())),
+                       0,
+                       softmax_lse);
+    }
+    if (seed_offset) {
+      Full<T, Context>(dev_ctx,
+                       phi::IntArray(common::vectorize(seed_offset->dims())),
+                       0,
+                       seed_offset);
+    }
+    return;
+  }
   FlashAttnBaseKernel<T, Context>(dev_ctx,
                                   q,
                                   k,
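
For reference, a minimal sketch (not from the commit) of what the new early-return path means at the Python level, assuming scaled_dot_product_attention reaches FlashAttnKernel on a CUDA build; the zero-size shape mirrors the new test added below.

```python
# Sketch only: with a zero-size query/key/value, FlashAttnKernel now zero-fills
# its outputs and returns early instead of launching FlashAttention.
# Assumes a CUDA build; shapes mirror TestAttentionWithBoolMaskZeroSize below.
import paddle

q = paddle.randn([0, 1, 8, 8], dtype="float32")  # zero batch -> q.numel() == 0
k = paddle.randn([0, 1, 8, 8], dtype="float32")
v = paddle.randn([0, 1, 8, 8], dtype="float32")

out = paddle.nn.functional.scaled_dot_product_attention(
    q, k, v, dropout_p=0.0, is_causal=False
)
print(out.shape)  # expected [0, 1, 8, 8]
```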

test/legacy_test/test_fused_rotary_position_embedding.py

Lines changed: 50 additions & 0 deletions
@@ -692,5 +692,55 @@ def test_error2():
         self.assertRaises(AssertionError, test_error2)
 
 
+@unittest.skipIf(
+    not core.is_compiled_with_cuda() and not paddle.is_compiled_with_rocm(),
+    "core is not compiled with CUDA or ROCM ",
+)
+class TestFusedRotaryPositionEmbeddingZeroSize(unittest.TestCase):
+    def setUp(self):
+        self.dtype = "float32"
+        self.qkv_shape = [0, 1, 8, 8]
+        self.sin_cos_shape = [1, 1, 1, 8]
+
+    def init_data(self):
+        self.q = paddle.randn(self.qkv_shape, dtype=self.dtype)
+        self.k = paddle.randn(self.qkv_shape, dtype=self.dtype)
+        self.v = paddle.randn(self.qkv_shape, dtype=self.dtype)
+        self.q.stop_gradient = False
+        self.k.stop_gradient = False
+        self.v.stop_gradient = False
+        self.sin = paddle.sin(
+            paddle.randn(self.sin_cos_shape, dtype=self.dtype)
+        )
+        self.cos = paddle.cos(
+            paddle.randn(self.sin_cos_shape, dtype=self.dtype)
+        )
+
+    def _test_forward_backward(self):
+        out_q, out_k, out_v = fused_rotary_position_embedding(
+            self.q,
+            self.k,
+            self.v,
+            sin=self.sin,
+            cos=self.cos,
+            use_neox_rotary_style=False,
+        )
+        out = out_q + out_k + out_v
+        out.backward()
+        np.testing.assert_allclose(
+            self.q.shape, self.q.grad.shape, rtol=1e-05, atol=1e-06
+        )
+        np.testing.assert_allclose(
+            self.k.shape, self.k.grad.shape, rtol=1e-05, atol=1e-06
+        )
+        np.testing.assert_allclose(
+            self.v.shape, self.v.grad.shape, rtol=1e-05, atol=1e-06
+        )
+
+    def test_zero_size(self):
+        self.init_data()
+        self._test_forward_backward()
+
+
 if __name__ == "__main__":
     unittest.main()

test/legacy_test/test_scaled_dot_product_attention.py

Lines changed: 9 additions & 0 deletions
@@ -220,5 +220,14 @@ def test_3d_input(self):
         np.testing.assert_allclose(out.numpy(), out_ref, rtol=5e-03, atol=1e-03)
 
 
+class TestAttentionWithBoolMaskZeroSize(TestAttentionWithBoolMask):
+    def setUp(self):
+        self.place = paddle.CUDAPlace(0)
+        self.shape = (0, 1, 8, 8)
+        self.dtype = 'float32'
+        self.dropout = 0.0
+        self.causal = False
+
+
 if __name__ == '__main__':
     unittest.main()
