From 5e191192ff2fe46fc23a258a666d4a6fc3a165ad Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Mon, 3 Mar 2025 18:11:33 +0800 Subject: [PATCH] [relax] Fix tree attention for Qwen2-1.5 models Fix the compilation error for Qwen2-1.5 models in the tree attention implementation on the Vulkan backend. --- python/tvm/relax/frontend/nn/llm/tree_attn.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/tvm/relax/frontend/nn/llm/tree_attn.py b/python/tvm/relax/frontend/nn/llm/tree_attn.py index 36a6e2dab84a..3a666fb291e0 100644 --- a/python/tvm/relax/frontend/nn/llm/tree_attn.py +++ b/python/tvm/relax/frontend/nn/llm/tree_attn.py @@ -425,8 +425,8 @@ def batch_tree_attn( # pylint: disable=too-many-branches batch_tiles[0] = T.ceildiv(batch_rows[0], tile_x) if T.tvm_thread_invariant(batch_idx[0] < batch_size_plus_1 - 1): - b_idx: T.int32 = batch_idx[0] - LH_start: T.int32 = tile_id[0] * tile_x + b_idx: T.int32(is_size_var=True) = batch_idx[0] + LH_start: T.int32(is_size_var=True) = tile_id[0] * tile_x q_indptr_val: T.int32 = q_indptr[b_idx] kv_chunk_len[0] = kv_indptr[b_idx + 1] - kv_indptr[b_idx] @@ -1049,8 +1049,8 @@ def tree_attn_paged_kv( batch_tiles[0] = T.ceildiv(batch_rows[0], tile_x) if T.tvm_thread_invariant(batch_idx[0] < batch_size): - b_idx: T.int32 = batch_idx[0] - LH_start: T.int32 = tile_id[0] * tile_x + b_idx: T.int32(is_size_var=True) = batch_idx[0] + LH_start: T.int32(is_size_var=True) = tile_id[0] * tile_x q_indptr_val: T.int32 = q_indptr[b_idx] cur_page_indptr_begin: T.int32 = page_indptr[b_idx]