src/runtime/relax_vm/kv_state.h: 63 additions & 0 deletions
@@ -181,6 +181,69 @@ class AttentionKVCacheObj : public KVStateObj {
virtual void AttentionWithFusedQKV(int64_t layer_id, NDArray qkv_data, Optional<NDArray> mask,
NDArray o_data, double attn_score_scaling_factor) = 0;

/*!
* \brief Compute attention with Q/K/V data.
* \param layer_id The model layer where the attention compute happens.
* \param q_data The input Q data, in layout `(total_length, num_qo_heads, head_dim)`
* \param k_data The input K data, in layout `(total_length, num_kv_heads, head_dim)`
* \param v_data The input V data, in layout `(total_length, num_kv_heads, head_dim)`
* \param mask The input mask data, in layout `(total_sqr_length)`.
* \param o_data The output O data, in layout `(total_length, num_qo_heads, head_dim)`.
* \param attn_score_scaling_factor The additional attention scaling factor.
*/
virtual void AttentionWithSeparateQKV(int64_t layer_id, NDArray q_data, NDArray k_data,
NDArray v_data, Optional<NDArray> mask, NDArray o_data,
double attn_score_scaling_factor) = 0;

/*!
* \brief Compute multi-head latent attention after applying weight absorption.
* \param layer_id The model layer where the attention compute happens.
* \param q_data The input Q data, in layout `(total_length, num_qo_heads, qk_head_dim)`
* \param compressed_kv_data The compressed latent KV data, in layout
* `(total_length, num_kv_heads, kv_lora_rank)`
* \param k_pe_data The positional embedding part of K data, in layout
* `(total_length, num_kv_heads, qk_rope_head_dim)`, where `kv_lora_rank + qk_rope_head_dim`
* equals qk_head_dim
* \param o_data The output O data, in layout `(total_length, num_qo_heads, v_head_dim)`.
* \param attn_score_scaling_factor The additional attention scaling factor.
*/
virtual void MLAAbsorbed(int64_t layer_id, NDArray q_data, NDArray compressed_kv_data,
NDArray k_pe_data, NDArray o_data, double attn_score_scaling_factor) = 0;
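// Illustrative shape sketch (hypothetical sizes, DeepSeek-style, not mandated by this
// interface): with kv_lora_rank = 512 and qk_rope_head_dim = 64, the query head dim is
// qk_head_dim = 512 + 64 = 576, so under the layouts documented above:
//
//   q_data             : (total_length, num_qo_heads, 576)
//   compressed_kv_data : (total_length, num_kv_heads, 512)
//   k_pe_data          : (total_length, num_kv_heads, 64)
//   o_data             : (total_length, num_qo_heads, v_head_dim)
//   cache->MLAAbsorbed(layer_id, q_data, compressed_kv_data, k_pe_data, o_data,
//                      /*attn_score_scaling_factor=*/1.0);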

/*!
* \brief Compute multi-head latent attention in normal style.
* \param layer_id The model layer where the attention compute happens.
* \param q_data The input Q data, in layout
* `(total_length, num_qo_heads, qk_nope_head_dim + qk_rope_head_dim)`
* \param k_data The input K data, in layout
* `(total_length, num_qo_heads, qk_nope_head_dim + qk_rope_head_dim)`
* \param v_data The input V data, in layout
* `(total_length, num_qo_heads, v_head_dim)`
* \param compressed_kv_data The compressed latent KV data, in layout
* `(total_length, num_kv_heads, kv_lora_rank)`
* \param k_pe_data The positional embedding part of K data, in layout
* `(total_length, num_kv_heads, qk_rope_head_dim)`, where `kv_lora_rank + qk_rope_head_dim`
* equals qk_head_dim
* \param o_data The output O data, in layout `(total_length, num_qo_heads, v_head_dim)`.
* \param attn_score_scaling_factor The additional attention scaling factor.
*/
virtual void MLANormal(int64_t layer_id, NDArray q_data, NDArray k_data, NDArray v_data,
NDArray compressed_kv_data, NDArray k_pe_data, NDArray o_data,
double attn_score_scaling_factor) = 0;
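// Illustrative call sketch (hypothetical caller; `cache` and the shape comments are
// assumptions). Both the per-head K/V tensors and the compressed latent KV plus rope-K
// slices are supplied, each with the layouts documented above:
//
//   cache->MLANormal(layer_id,
//                    q_data,              // (total_length, num_qo_heads, qk_nope_head_dim + qk_rope_head_dim)
//                    k_data,              // (total_length, num_qo_heads, qk_nope_head_dim + qk_rope_head_dim)
//                    v_data,              // (total_length, num_qo_heads, v_head_dim)
//                    compressed_kv_data,  // (total_length, num_kv_heads, kv_lora_rank)
//                    k_pe_data,           // (total_length, num_kv_heads, qk_rope_head_dim)
//                    o_data,              // (total_length, num_qo_heads, v_head_dim)
//                    /*attn_score_scaling_factor=*/1.0);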

/*!
* \brief Compute linear attention with Q/K/V data.
* \param layer_id The model layer where the attention compute happens.
* \param q_data The input Q data, in layout `(total_length, num_qo_heads, head_dim)`.
* \param k_data The input K data, in layout `(total_length, num_kv_heads, head_dim)`.
* \param v_data The input V data, in layout `(total_length, num_kv_heads, head_dim)`.
* \param attn_score_scaling_factor The additional attention scaling factor.
* \sa AttentionKVCache::Attention
*/
virtual void LinearAttention(int64_t layer_id, NDArray q_data, NDArray k_data, NDArray v_data,
double attn_score_scaling_factor) = 0;
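// Illustrative call sketch (hypothetical caller; `cache` is assumed to be an
// AttentionKVCacheObj pointer). Q/K/V follow the layouts documented above:
//
//   cache->LinearAttention(layer_id, q_data, k_data, v_data,
//                          /*attn_score_scaling_factor=*/1.0);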

/************** Positions **************/

/*!