[Runtime] Support KV cache with RoPE extension factor array
This PR enhances the KV cache with RoPE extension factor support.
With this PR, the KV cache can support models like Phi3.5, which come
with extension factors.
MasterJH5574 committed Aug 23, 2024
1 parent 8db545d commit 18fc6b4
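For context: the extension factors are a per-dimension array (Phi3.5 ships them as part of its LongRoPE-style rope scaling) that rescales the rotary inverse frequencies. A minimal sketch of the idea, assuming factors of length `head_dim // 2` that divide the base frequencies; the function name is illustrative and not part of this PR:

```python
import numpy as np

def rope_inv_freq(head_dim, theta=10000.0, ext_factors=None):
    # Standard RoPE inverse frequencies for dimension pairs 0, 2, ..., head_dim - 2.
    inv_freq = 1.0 / (theta ** (np.arange(0, head_dim, 2) / head_dim))
    if ext_factors is not None:
        # Per-dimension extension factors stretch selected frequencies,
        # which is how long-context variants such as Phi3.5 extend RoPE.
        inv_freq = inv_freq / np.asarray(ext_factors, dtype=inv_freq.dtype)
    return inv_freq
```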
Showing 4 changed files with 43 additions and 25 deletions.
1 change: 1 addition & 0 deletions src/runtime/relax_vm/kv_state.h
@@ -167,6 +167,7 @@ class AttentionKVCacheObj : public KVStateObj {
* `(total_length, num_qo_heads + 2 * num_kv_heads, head_dim)`.
* \param mask The input mask data, in layout `(total_sqr_length)`.
* \param o_data The output O data, in layout `(total_length, num_qo_heads, head_dim)`.
* \param attn_score_scaling_factor The additional attention scaling factor.
* \sa AttentionKVCache::Attention
*/
virtual void AttentionWithFusedQKV(int64_t layer_id, NDArray qkv_data, Optional<NDArray> mask,
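The layouts documented above fix how the fused QKV tensor packs query, key, and value heads along the head axis. A small numpy sketch of that split, following the documented `(total_length, num_qo_heads + 2 * num_kv_heads, head_dim)` layout (the helper name is illustrative):

```python
import numpy as np

def split_fused_qkv(qkv, num_qo_heads, num_kv_heads):
    # qkv: (total_length, num_qo_heads + 2 * num_kv_heads, head_dim)
    q = qkv[:, :num_qo_heads, :]                             # (total_length, num_qo_heads, head_dim)
    k = qkv[:, num_qo_heads:num_qo_heads + num_kv_heads, :]  # (total_length, num_kv_heads, head_dim)
    v = qkv[:, num_qo_heads + num_kv_heads:, :]              # (total_length, num_kv_heads, head_dim)
    return q, k, v

qkv = np.random.rand(7, 32 + 2 * 8, 128).astype("float32")
q, k, v = split_fused_qkv(qkv, num_qo_heads=32, num_kv_heads=8)
```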
63 changes: 38 additions & 25 deletions src/runtime/relax_vm/paged_kv_cache.cc
@@ -848,6 +848,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
const double rotary_scale_;
/*! \brief The RoPE theta. */
const double rotary_theta_;
/*! \brief The optional RoPE extension factors for RoPE scaling. */
const Optional<NDArray> rope_ext_factors_;

/*! \brief We fix int32 to be the index dtype of auxiliary data. */
const DLDataType dtype_aux_ = DLDataType(DataType::Int(32, 1));
@@ -988,7 +990,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
int64_t page_size, int64_t num_layers, int64_t layer_id_begin_offset, //
int64_t num_qo_heads, int64_t num_kv_heads, int64_t head_dim, int64_t reserved_num_seqs,
int64_t num_total_pages, int64_t prefill_chunk_size, bool support_sliding_window,
RoPEMode rope_mode, double rotary_scale, double rotary_theta, DLDataType dtype, Device device,
RoPEMode rope_mode, double rotary_scale, double rotary_theta,
Optional<NDArray> rope_ext_factors, DLDataType dtype, Device device,
PackedFunc f_transpose_append, PackedFunc f_compact_copy, PackedFunc f_attention_prefill,
PackedFunc f_attention_decode, PackedFunc f_attention_prefill_sliding_window,
PackedFunc f_attention_decode_sliding_window, PackedFunc f_attention_prefill_ragged,
@@ -1013,6 +1016,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
: rope_mode),
rotary_scale_(rotary_scale),
rotary_theta_(rotary_theta),
rope_ext_factors_(std::move(rope_ext_factors)),
f_transpose_append_(std::move(f_transpose_append)),
f_compact_copy_(std::move(f_compact_copy)),
f_attention_prefill_(std::move(f_attention_prefill)),
@@ -1132,6 +1136,12 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
reserved_num_seqs, num_total_pages, prefill_chunk_size, dtype_aux_, device,
preferred_host_device, copy_stream_);
}

// Right now only the "normal" RoPE mode supports the RoPE extension factors.
if (rope_ext_factors_.defined()) {
CHECK(rope_mode_ == RoPEMode::kNormal)
<< "The RoPE mode must be normal to support RoPE extension factors.";
}
}

~PagedAttentionKVCacheObj() {
@@ -1726,8 +1736,13 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
NDArray v_data = temp_attn_v_device_.CreateView({total_seq_length, num_kv_heads_, head_dim_},
qkv_data->dtype);
// Part 2. Split fused qkv and apply rotary embedding to q/k data.
f_split_rotary_(qkv_data, q_rope_position_map_view_, q_data, k_data, v_data,
static_cast<int>(rope_mode_ == RoPEMode::kNormal));
if (!rope_ext_factors_.defined()) {
f_split_rotary_(qkv_data, q_rope_position_map_view_, q_data, k_data, v_data,
static_cast<int>(rope_mode_ == RoPEMode::kNormal));
} else {
f_split_rotary_(qkv_data, q_rope_position_map_view_, q_data, k_data, v_data,
rope_ext_factors_.value());
}

// Part 3. Append k/v data to kv-cache if flag "append_before_attn" is set.
if (append_before_attn_) {
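For reference, a hedged numpy sketch of the rotary step that `f_split_rotary_` is dispatched to perform on the already-split q/k data; this is not the TIR kernel itself, and the pairing convention and position scaling are assumptions:

```python
import numpy as np

def apply_rope(x, positions, theta=10000.0, scale=1.0, ext_factors=None):
    # x: (total_length, num_heads, head_dim); positions: (total_length,) integer positions.
    head_dim = x.shape[-1]
    inv_freq = 1.0 / (theta ** (np.arange(0, head_dim, 2) / head_dim))
    if ext_factors is not None:
        inv_freq = inv_freq / np.asarray(ext_factors)           # LongRoPE-style rescaling
    angles = (positions[:, None] / scale) * inv_freq[None, :]   # (total_length, head_dim // 2)
    cos = np.cos(angles)[:, None, :]
    sin = np.sin(angles)[:, None, :]
    x1, x2 = x[..., : head_dim // 2], x[..., head_dim // 2 :]
    # Rotate the (x1, x2) pairs: (x1 * cos - x2 * sin, x2 * cos + x1 * sin).
    return np.concatenate([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)
```

The two branches above correspond to calling the kernel without factors (passing the RoPE-mode flag) versus with the `rope_ext_factors_` array.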
@@ -2462,7 +2477,7 @@ TVM_REGISTER_OBJECT_TYPE(PagedAttentionKVCacheObj);

TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create")
.set_body([](TVMArgs args, TVMRetValue* rv) {
CHECK(args.size() == 25 || args.size() == 26 || args.size() == 27)
CHECK(args.size() == 27 || args.size() == 28)
<< "Invalid number of KV cache constructor args.";
ShapeTuple cache_config = args[0];
ShapeTuple layer_indptr_tuple = args[1];
@@ -2499,14 +2514,12 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create")
PackedFunc f_split_rotary = args[22];
PackedFunc f_copy_single_page = args[23];
Optional<PackedFunc> f_debug_get_kv = args[24];
PackedFunc f_compact_copy{nullptr};
PackedFunc f_attention_prefill_with_tree_mask{nullptr};
PackedFunc f_compact_copy = args[25];
PackedFunc f_attention_prefill_with_tree_mask = args[26];
Optional<NDArray> rope_ext_factors = NullOpt;

if (args.size() >= 26) {
f_compact_copy = args[25].AsObjectRef<PackedFunc>();
}
if (args.size() >= 27) {
f_attention_prefill_with_tree_mask = args[26].AsObjectRef<PackedFunc>();
if (args.size() >= 28 && args[27].IsObjectRef<NDArray>()) {
rope_ext_factors = args[27].AsObjectRef<NDArray>();
}

CHECK_EQ(cache_config.size(), 5);
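From the caller's side the new trailing slot is positional: the packed function now takes 27 or 28 arguments, and the 28th is used as extension factors only when it is an NDArray. A hedged sketch of the two call shapes (the long argument list is abbreviated, so this is illustrative only):

```python
import numpy as np
import tvm

fcreate = tvm.get_global_func("vm.builtin.paged_attention_kv_cache_create")

# Without extension factors: stop at 27 args, or pass None as the 28th;
# the IsObjectRef<NDArray> check then leaves rope_ext_factors as NullOpt.
#   cache = fcreate(*first_27_args)
#   cache = fcreate(*first_27_args, None)

# With extension factors: pass a device NDArray as the 28th argument
# (assumed here to hold one float per rotary frequency).
#   ext = tvm.nd.array(np.ones(head_dim // 2, dtype="float32"), device)
#   cache = fcreate(*first_27_args, ext)
```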
@@ -2523,9 +2536,10 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create")
ObjectPtr<PagedAttentionKVCacheObj> n = make_object<PagedAttentionKVCacheObj>(
page_size, num_layers, layer_id_begin_offset, num_qo_heads, num_kv_heads, head_dim,
reserved_num_seqs, num_total_pages, prefill_chunk_size, support_sliding_window,
RoPEMode(rope_mode), rotary_scale, rotary_theta, init->dtype, init->device,
std::move(f_transpose_append), std::move(f_compact_copy), std::move(f_attention_prefill),
std::move(f_attention_decode), std::move(f_attention_prefill_sliding_window),
RoPEMode(rope_mode), rotary_scale, rotary_theta, std::move(rope_ext_factors), //
init->dtype, init->device, std::move(f_transpose_append), std::move(f_compact_copy),
std::move(f_attention_prefill), std::move(f_attention_decode),
std::move(f_attention_prefill_sliding_window),
std::move(f_attention_decode_sliding_window), std::move(f_attention_prefill_ragged),
std::move(f_attention_prefill_with_tree_mask),
std::move(f_attention_prefill_ragged_begin_forward),
@@ -2539,7 +2553,7 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create")

TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced")
.set_body([](TVMArgs args, TVMRetValue* rv) {
CHECK(args.size() == 19 || args.size() == 20 || args.size() == 21)
CHECK(args.size() == 21 || args.size() == 22)
<< "Invalid number of KV cache constructor args.";
ShapeTuple cache_config = args[0];
ShapeTuple layer_indptr_tuple = args[1];
@@ -2570,14 +2584,12 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced")
PackedFunc f_split_rotary = args[16];
PackedFunc f_copy_single_page = args[17];
Optional<PackedFunc> f_debug_get_kv = args[18];
PackedFunc f_compact_copy{nullptr};
PackedFunc f_attention_prefill_with_tree_mask{nullptr};
PackedFunc f_compact_copy = args[19];
PackedFunc f_attention_prefill_with_tree_mask = args[20];
Optional<NDArray> rope_ext_factors = NullOpt;

if (args.size() >= 20) {
f_compact_copy = args[19].AsObjectRef<PackedFunc>();
}
if (args.size() >= 21) {
f_attention_prefill_with_tree_mask = args[20].AsObjectRef<PackedFunc>();
if (args.size() >= 22 && args[21].IsObjectRef<NDArray>()) {
rope_ext_factors = args[21].AsObjectRef<NDArray>();
}

CHECK_EQ(cache_config.size(), 5);
@@ -2594,9 +2606,10 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced")
ObjectPtr<PagedAttentionKVCacheObj> n = make_object<PagedAttentionKVCacheObj>(
page_size, num_layers, layer_id_begin_offset, num_qo_heads, num_kv_heads, head_dim,
reserved_num_seqs, num_total_pages, prefill_chunk_size, support_sliding_window,
RoPEMode(rope_mode), rotary_scale, rotary_theta, init->dtype, init->device,
std::move(f_transpose_append), std::move(f_compact_copy), std::move(f_attention_prefill),
std::move(f_attention_decode), std::move(f_attention_prefill_sliding_window),
RoPEMode(rope_mode), rotary_scale, rotary_theta, std::move(rope_ext_factors), //
init->dtype, init->device, std::move(f_transpose_append), std::move(f_compact_copy),
std::move(f_attention_prefill), std::move(f_attention_decode),
std::move(f_attention_prefill_sliding_window),
std::move(f_attention_decode_sliding_window), std::move(f_attention_prefill_ragged),
std::move(f_attention_prefill_with_tree_mask), //
NullOpt, NullOpt, NullOpt, NullOpt, NullOpt, NullOpt, //
@@ -379,6 +379,9 @@ def create_kv_cache(rope_mode):
fsplit_rotary,
fcopy_single_page,
fcopy_cache,
None,
None,
None,
)
return cache
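This test leaves the three new trailing arguments as None (no compact copy, no tree-mask prefill, no extension factors). A hedged sketch of what the last slot might carry for a Phi3.5-style model, assuming one float32 factor per rotary frequency (length `head_dim // 2`) on the cache's device:

```python
import numpy as np
import tvm

head_dim = 128
# Hypothetical factors; a real model ships these in its config
# (e.g. the long/short factor arrays of Phi3.5's rope scaling).
rope_ext_factors = tvm.nd.array(
    np.ones(head_dim // 2, dtype="float32"),  # all ones == no rescaling
    tvm.cuda(0),
)
# Then pass `rope_ext_factors` in place of the final None when creating the cache.
```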

@@ -180,6 +180,7 @@ def create_kv_cache(head_dim, dtype, rope_mode, support_sliding_window):
fcopy_cache,
fcompact_copy,
fattn_prefill_with_tree_mask,
None,
)
return cache

