diff --git a/python/tvm/relax/frontend/nn/llm/kv_cache.py b/python/tvm/relax/frontend/nn/llm/kv_cache.py index 9b16fc2fbfee..fd866ae06c16 100644 --- a/python/tvm/relax/frontend/nn/llm/kv_cache.py +++ b/python/tvm/relax/frontend/nn/llm/kv_cache.py @@ -925,8 +925,12 @@ def _attention_decode( THREAD_LIMIT = 512 TILE_SIZE_PER_BDX = 2 - if target.kind.name == "opencl" and "android" in str(target.host): - THREAD_LIMIT = 256 if H_kv < 8 else 512 + if target.kind.name == "opencl" and ( + ("android" in str(target.host)) or ("adreno" in str(target.attrs)) + ): + # Keeping lower thread limit for this kernel on adreno target + # to avoid register spill + THREAD_LIMIT = 256 TILE_SIZE_PER_BDX = 1 max_num_threads_per_block = get_max_num_threads_per_block(target) thread_limit = min(max_num_threads_per_block, THREAD_LIMIT) @@ -1570,7 +1574,11 @@ def _attention_prefill_ragged(h_kv, h_q, d, dtype, rope_scaling: Dict[str, Any], bdx = 32 num_warps = 4 - tile_x, tile_y, tile_z = 64 // ((DataType(dtype).bits + 7) // 8) // max(d // 128, 1), d, 16 + tile_x, tile_y, tile_z = ( + 64 // ((DataType(dtype).bits + 7) // 8) // max(d // 128, 1), + d, + 64 // ((DataType(dtype).bits + 7) // 8) // max(d // 128, 1), + ) # Otherwise we would exceed maxComputeWorkgroupStorageSize if ( @@ -1580,6 +1588,12 @@ def _attention_prefill_ragged(h_kv, h_q, d, dtype, rope_scaling: Dict[str, Any], tile_z = 8 num_warps = 2 + if target.kind.name == "opencl" and ( + ("android" in str(target.host)) or ("adreno" in str(target.attrs)) + ): + LOAD_VEC = 16 // ((DataType(dtype).bits + 7) // 8) # 16 bytes + NUM_BLKS = group_size * 8 + # fmt: off @T.prim_func def batch_prefill_ragged_kv( # pylint: disable=too-many-branches @@ -1708,8 +1722,6 @@ def batch_prefill_ragged_kv( # pylint: disable=too-many-branches for lz, ly in T.grid(tile_z, tile_y): with T.block("K_load"): i, j = T.axis.remap("SS", [lz, ly]) - T.reads() - T.writes() cur_L = L_kv_start + i if cur_L < kv_chunk_len[0]: K_smem[i, j] = T.if_then_else( @@ -1824,6 +1836,14 @@ def batch_prefill_ragged_kv( # pylint: disable=too-many-branches # fmt: on # pylint: enable=line-too-long,too-many-branches sch = tir.Schedule(batch_prefill_ragged_kv) + get_extent = lambda *lps: [int(sch.get(lp).extent) for lp in lps] + + def get_vecsize(extent): + return min(LOAD_VEC, (extent & ~(extent - 1))) + + def getxy_vecsize(x, y, t): + assert (x * y) % t == 0 + return min(get_vecsize(y), get_vecsize(x * y // t)) def get_tile_size(x, y, t): cnt = (x * y) // t @@ -1837,26 +1857,37 @@ def get_tile_size(x, y, t): def apply_to_qkv_load(sch: tir.Schedule, block): loop_x, loop_y = sch.get_loops(block)[-2:] - loop = sch.fuse(loop_x, loop_y) - _, ty, tx, vec = sch.split( - loop, factors=[None, num_warps, bdx, LOAD_VEC], preserve_unit_iters=True - ) + x_extent, y_extent = get_extent(loop_x, loop_y) + vec_size = getxy_vecsize(x_extent, y_extent, bdx * num_warps) + yo, yv = sch.split(loop_y, [None, vec_size]) + yo_extent = y_extent // vec_size + tile_x, tile_y = get_tile_size(x_extent, yo_extent, (bdx * num_warps)) + xo, xi = sch.split(loop_x, [tile_x, None]) + yo, yi = sch.split(yo, [tile_y, None]) + sch.reorder(xi, yi, xo, yo) + t = sch.fuse(xi, yi) + ty, tx = sch.split(t, [num_warps, bdx]) sch.bind(ty, "threadIdx.y") sch.bind(tx, "threadIdx.x") - sch.vectorize(vec) + sch.vectorize(yv) def apply_to_so_ewise(sch: tir.Schedule, block, tile): loop_x, loop_y = sch.get_loops(block)[-2:] xo, xi = sch.split(loop_x, factors=[None, tile[0]]) yo, yi = sch.split(loop_y, factors=[None, tile[1]]) sch.reorder(xo, yo, xi, yi) + sch.unroll(xi) + yiv_extent = get_vecsize(tile[1]) + yio, yiv = sch.split(yi, [None, yiv_extent]) + sch.unroll(yio) + sch.vectorize(yiv) t = sch.fuse(xo, yo) ty, tx = sch.split(t, factors=[None, bdx]) sch.bind(ty, "threadIdx.y") sch.bind(tx, "threadIdx.x") def apply_to_gemm( # pylint: disable=unused-argument - sch: tir.Schedule, block, tile, read_0, read_1, r_len=8, k_major=False + sch: tir.Schedule, block, tile, read_0, read_1, r_len=16, k_major=False ): loop_x, loop_y, loop_z = sch.get_loops(block)[-3:] xo, xi = sch.split(loop_x, factors=[None, tile[0]]) @@ -1872,6 +1903,12 @@ def apply_to_gemm( # pylint: disable=unused-argument sch.reorder(ko, xi, yi, ki) else: sch.reorder(ko, ki, xi, yi) + yiv_extent = get_vecsize(tile[1]) + yio, yiv = sch.split(yi, [None, yiv_extent]) + sch.unroll(yio) + sch.vectorize(yiv) + sch.unroll(xi) + sch.unroll(ki) sch.decompose_reduction(block, ty) def apply_to_md(sch, block): @@ -1880,6 +1917,7 @@ def apply_to_md(sch, block): sch.bind(ty, "threadIdx.y") sch.bind(tx, "threadIdx.x") + sch.transform_layout("K_load", ("write", 0), lambda i, j: (j, i)) tile_s = get_tile_size(tile_x, tile_z, bdx * num_warps) tile_o = get_tile_size(tile_x, tile_y, bdx * num_warps) apply_to_gemm(sch, sch.get_block("S_gemm"), tile_s, 0, 1, k_major=True)