Skip to content

Commit 8e5cba3

Browse files
Fix paddle.incubate.nn.functional.fused_rotary_position_embedding illegal memory error (#74347)
* error 700:fused_rotary_position_embedding test=develop
* error 700:fused_rotary_position_embedding test=develop
* error 700:fused_rotary_position_embedding test=develop
1 parent 5b8c1e0 commit 8e5cba3

File tree

1 file changed

+25
-2
lines changed

1 file changed

+25
-2
lines changed

paddle/phi/kernels/fusion/gpu/fused_rope_kernel.cu

Lines changed: 25 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -75,13 +75,36 @@ void FusedRopeKernel(const Context& dev_ctx,
7575
outs_data[0] = out_q->data<T>();
7676
int num_inputs = 1;
7777

78-
if (out_k && out_k->numel() > 0) {
78+
if (k) {
79+
auto k_dims = k->dims();
80+
auto k_batch_size = time_major ? k_dims[1] : k_dims[0];
81+
PADDLE_ENFORCE_LE(
82+
batch_size,
83+
k_batch_size,
84+
common::errors::InvalidArgument("The batch_size of q (%d) must be less "
85+
"than or equal to k's (%d).",
86+
batch_size,
87+
k_batch_size));
88+
89+
dev_ctx.template Alloc<T>(out_k);
7990
ins_data[num_inputs] = k->data<T>();
8091
outs_data[num_inputs] = out_k->data<T>();
8192
inputs_num_heads[num_inputs] = k->dims()[2];
8293
num_inputs++;
8394
}
84-
if (out_v && out_v->numel() > 0) {
95+
96+
if (v) {
97+
auto v_dims = v->dims();
98+
auto v_batch_size = time_major ? v_dims[1] : v_dims[0];
99+
PADDLE_ENFORCE_LE(
100+
batch_size,
101+
v_batch_size,
102+
common::errors::InvalidArgument("The batch_size of q (%d) must be less "
103+
"than or equal to v's (%d).",
104+
batch_size,
105+
v_batch_size));
106+
107+
dev_ctx.template Alloc<T>(out_v);
85108
ins_data[num_inputs] = v->data<T>();
86109
outs_data[num_inputs] = out_v->data<T>();
87110
inputs_num_heads[num_inputs] = v->dims()[2];

0 commit comments

Comments
 (0)