File tree Expand file tree Collapse file tree 1 file changed +25
-2
lines changed
paddle/phi/kernels/fusion/gpu Expand file tree Collapse file tree 1 file changed +25
-2
lines changed Original file line number Diff line number Diff line change @@ -75,13 +75,36 @@ void FusedRopeKernel(const Context& dev_ctx,
7575 outs_data[0 ] = out_q->data <T>();
7676 int num_inputs = 1 ;
7777
78- if (out_k && out_k->numel () > 0 ) {
78+ if (k) {
79+ auto k_dims = k->dims ();
80+ auto k_batch_size = time_major ? k_dims[1 ] : k_dims[0 ];
81+ PADDLE_ENFORCE_LE (
82+ batch_size,
83+ k_batch_size,
84+ common::errors::InvalidArgument (" The batch_size of q (%d) must be less "
85+ " than or equal to k's (%d)." ,
86+ batch_size,
87+ k_batch_size));
88+
89+ dev_ctx.template Alloc <T>(out_k);
7990 ins_data[num_inputs] = k->data <T>();
8091 outs_data[num_inputs] = out_k->data <T>();
8192 inputs_num_heads[num_inputs] = k->dims ()[2 ];
8293 num_inputs++;
8394 }
84- if (out_v && out_v->numel () > 0 ) {
95+
96+ if (v) {
97+ auto v_dims = v->dims ();
98+ auto v_batch_size = time_major ? v_dims[1 ] : v_dims[0 ];
99+ PADDLE_ENFORCE_LE (
100+ batch_size,
101+ v_batch_size,
102+ common::errors::InvalidArgument (" The batch_size of q (%d) must be less "
103+ " than or equal to v's (%d)." ,
104+ batch_size,
105+ v_batch_size));
106+
107+ dev_ctx.template Alloc <T>(out_v);
85108 ins_data[num_inputs] = v->data <T>();
86109 outs_data[num_inputs] = out_v->data <T>();
87110 inputs_num_heads[num_inputs] = v->dims ()[2 ];
You can’t perform that action at this time.
0 commit comments