diff --git a/src/runtime/relax_vm/paged_kv_cache.cc b/src/runtime/relax_vm/paged_kv_cache.cc index 8809a1b0729e..876be9395cab 100644 --- a/src/runtime/relax_vm/paged_kv_cache.cc +++ b/src/runtime/relax_vm/paged_kv_cache.cc @@ -1750,7 +1750,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { for (int64_t seq_id = 0; seq_id < cur_batch_size_; ++seq_id) { total_seq_length += cur_append_lengths_[seq_id]; } - CHECK_EQ(total_seq_length, qkv_data->shape[0]); + CHECK_LE(total_seq_length, qkv_data->shape[0]); // Sync the copy stream and the compute stream. ComputeStreamWaitForCopyStream(); // The auxiliary data structure on device must have been synchronized. @@ -1762,12 +1762,21 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { qkv_data->dtype); NDArray v_data = temp_attn_v_device_.CreateView({total_seq_length, num_kv_heads_, head_dim_}, qkv_data->dtype); + + NDArray qkv_data_view = qkv_data; + NDArray o_data_view = o_data; + if (total_seq_length != qkv_data->shape[0]) { + qkv_data_view = qkv_data.CreateView( + {total_seq_length, qkv_data->shape[1], qkv_data->shape[2]}, qkv_data->dtype); + o_data_view = + o_data.CreateView({total_seq_length, num_qo_heads_, head_dim_}, qkv_data->dtype); + } // Part 2. Split fused qkv and apply rotary embedding to q/k data. if (!rope_ext_factors_.defined()) { - f_split_rotary_(qkv_data, q_rope_position_map_view_, q_data, k_data, v_data, + f_split_rotary_(qkv_data_view, q_rope_position_map_view_, q_data, k_data, v_data, static_cast(rope_mode_ == RoPEMode::kNormal)); } else { - f_split_rotary_(qkv_data, q_rope_position_map_view_, q_data, k_data, v_data, + f_split_rotary_(qkv_data_view, q_rope_position_map_view_, q_data, k_data, v_data, rope_ext_factors_.value()); } @@ -1776,7 +1785,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { f_transpose_append_(pages_[local_layer_id], k_data, v_data, append_position_map_view_); } // Part 4: perform attention - AttentionInternal(layer_id, q_data, k_data, v_data, o_data, attn_score_scaling_factor); + AttentionInternal(layer_id, q_data, k_data, v_data, o_data_view, attn_score_scaling_factor); // Part 5. Append k/v data to kv-cache if flag "append_before_attn" is not set. if (!append_before_attn_) { f_transpose_append_(pages_[local_layer_id], k_data, v_data, append_position_map_view_);