diff --git a/python/tvm/relax/frontend/nn/llm/tree_attn.py b/python/tvm/relax/frontend/nn/llm/tree_attn.py index 36a6e2dab84a..3a666fb291e0 100644 --- a/python/tvm/relax/frontend/nn/llm/tree_attn.py +++ b/python/tvm/relax/frontend/nn/llm/tree_attn.py @@ -425,8 +425,8 @@ def batch_tree_attn( # pylint: disable=too-many-branches batch_tiles[0] = T.ceildiv(batch_rows[0], tile_x) if T.tvm_thread_invariant(batch_idx[0] < batch_size_plus_1 - 1): - b_idx: T.int32 = batch_idx[0] - LH_start: T.int32 = tile_id[0] * tile_x + b_idx: T.int32(is_size_var=True) = batch_idx[0] + LH_start: T.int32(is_size_var=True) = tile_id[0] * tile_x q_indptr_val: T.int32 = q_indptr[b_idx] kv_chunk_len[0] = kv_indptr[b_idx + 1] - kv_indptr[b_idx] @@ -1049,8 +1049,8 @@ def tree_attn_paged_kv( batch_tiles[0] = T.ceildiv(batch_rows[0], tile_x) if T.tvm_thread_invariant(batch_idx[0] < batch_size): - b_idx: T.int32 = batch_idx[0] - LH_start: T.int32 = tile_id[0] * tile_x + b_idx: T.int32(is_size_var=True) = batch_idx[0] + LH_start: T.int32(is_size_var=True) = tile_id[0] * tile_x q_indptr_val: T.int32 = q_indptr[b_idx] cur_page_indptr_begin: T.int32 = page_indptr[b_idx]