diff --git a/python/tvm/relax/frontend/nn/llm/kv_cache.py b/python/tvm/relax/frontend/nn/llm/kv_cache.py index ae0537f0d9af..9b16fc2fbfee 100644 --- a/python/tvm/relax/frontend/nn/llm/kv_cache.py +++ b/python/tvm/relax/frontend/nn/llm/kv_cache.py @@ -1237,7 +1237,7 @@ def merge_state_inplace( def _attention_sequence_prefill( - batch_size, h_kv, h_q, d, dtype, target: Target, causal=0, attn_score_scaling_factor=1.0 + h_kv, h_q, d, dtype, target: Target, causal=0, attn_score_scaling_factor=1.0 ): # pylint: disable=line-too-long LOAD_VEC = 8 // ((DataType(dtype).bits + 7) // 8) # 8 bytes group_size = h_q // h_kv @@ -1264,6 +1264,7 @@ def batch_sequence_prefill_kv( # pylint: disable=too-many-branches var_output: T.handle, # [total_len, h_q, d] var_lse: T.handle # [total_len, h_q] ): + batch_size = T.int32(is_size_var=True) qo_len = T.int32(is_size_var=True) kv_len = T.int32(is_size_var=True) q = T.match_buffer(var_q, (batch_size, qo_len, h_q, d), dtype)