From a6c697b90d4d821f5f15c33db1c035fc0ad767f9 Mon Sep 17 00:00:00 2001 From: Mengshiun Yu Date: Tue, 10 Sep 2024 22:40:52 -0400 Subject: [PATCH] [Relax][KV Cache] Refactor `_attention_sequence_prefill` function to handle dynamic `batch_size` in TIR This PR removes `batch_size` from the function signature, instead mapping it within the function body. --- python/tvm/relax/frontend/nn/llm/kv_cache.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/tvm/relax/frontend/nn/llm/kv_cache.py b/python/tvm/relax/frontend/nn/llm/kv_cache.py index ae0537f0d9af..9b16fc2fbfee 100644 --- a/python/tvm/relax/frontend/nn/llm/kv_cache.py +++ b/python/tvm/relax/frontend/nn/llm/kv_cache.py @@ -1237,7 +1237,7 @@ def merge_state_inplace( def _attention_sequence_prefill( - batch_size, h_kv, h_q, d, dtype, target: Target, causal=0, attn_score_scaling_factor=1.0 + h_kv, h_q, d, dtype, target: Target, causal=0, attn_score_scaling_factor=1.0 ): # pylint: disable=line-too-long LOAD_VEC = 8 // ((DataType(dtype).bits + 7) // 8) # 8 bytes group_size = h_q // h_kv @@ -1264,6 +1264,7 @@ def batch_sequence_prefill_kv( # pylint: disable=too-many-branches var_output: T.handle, # [total_len, h_q, d] var_lse: T.handle # [total_len, h_q] ): + batch_size = T.int32(is_size_var=True) qo_len = T.int32(is_size_var=True) kv_len = T.int32(is_size_var=True) q = T.match_buffer(var_q, (batch_size, qo_len, h_q, d), dtype)