Revert "[KVCACHE] Improved schedule for prefill attention (#17432)"
This reverts commit 79abc03.
MasterJH5574 authored Oct 14, 2024
1 parent 43f6c08 commit 16cdb7c
Showing 1 changed file with 11 additions and 49 deletions.
60 changes: 11 additions & 49 deletions python/tvm/relax/frontend/nn/llm/kv_cache.py
@@ -925,12 +925,8 @@ def _attention_decode(

     THREAD_LIMIT = 512
     TILE_SIZE_PER_BDX = 2
-    if target.kind.name == "opencl" and (
-        ("android" in str(target.host)) or ("adreno" in str(target.attrs))
-    ):
-        # Keeping lower thread limit for this kernel on adreno target
-        # to avoid register spill
-        THREAD_LIMIT = 256
+    if target.kind.name == "opencl" and "android" in str(target.host):
+        THREAD_LIMIT = 256 if H_kv < 8 else 512
         TILE_SIZE_PER_BDX = 1
     max_num_threads_per_block = get_max_num_threads_per_block(target)
     thread_limit = min(max_num_threads_per_block, THREAD_LIMIT)
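
Note: the restored branch above caps the decode kernel's thread budget by KV-head count instead of uniformly lowering it for Adreno. A minimal pure-Python sketch of that selection logic (names mirror the diff; the function and targets are illustrative stand-ins, not TVM API):

def pick_thread_limit(target_kind: str, target_host: str, h_kv: int, max_threads_per_block: int):
    # Defaults from the context lines above.
    thread_limit = 512
    tile_size_per_bdx = 2
    if target_kind == "opencl" and "android" in target_host:
        # Restored heuristic: fewer than 8 KV heads -> lower the cap to 256.
        thread_limit = 256 if h_kv < 8 else 512
        tile_size_per_bdx = 1
    # The kernel never exceeds what the target allows per block.
    return min(max_threads_per_block, thread_limit), tile_size_per_bdx

assert pick_thread_limit("opencl", "llvm -mtriple=aarch64-linux-android", 4, 1024) == (256, 1)
assert pick_thread_limit("cuda", "", 32, 1024) == (512, 2)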
@@ -1574,11 +1570,7 @@ def _attention_prefill_ragged(h_kv, h_q, d, dtype, rope_scaling: Dict[str, Any],

     bdx = 32
     num_warps = 4
-    tile_x, tile_y, tile_z = (
-        64 // ((DataType(dtype).bits + 7) // 8) // max(d // 128, 1),
-        d,
-        64 // ((DataType(dtype).bits + 7) // 8) // max(d // 128, 1),
-    )
+    tile_x, tile_y, tile_z = 64 // ((DataType(dtype).bits + 7) // 8) // max(d // 128, 1), d, 16

     # Otherwise we would exceed maxComputeWorkgroupStorageSize
     if (
@@ -1588,12 +1580,6 @@ def _attention_prefill_ragged(h_kv, h_q, d, dtype, rope_scaling: Dict[str, Any],
         tile_z = 8
         num_warps = 2

-    if target.kind.name == "opencl" and (
-        ("android" in str(target.host)) or ("adreno" in str(target.attrs))
-    ):
-        LOAD_VEC = 16 // ((DataType(dtype).bits + 7) // 8)  # 16 bytes
-        NUM_BLKS = group_size * 8
-
     # fmt: off
     @T.prim_func
     def batch_prefill_ragged_kv(  # pylint: disable=too-many-branches
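
The restored one-liner above fixes tile_z at 16 and derives tile_x from the element size and head dimension. A worked example, assuming float16 and d = 128 (values chosen for illustration):

from tvm import DataType

dtype, d = "float16", 128
bytes_per_elem = (DataType(dtype).bits + 7) // 8    # 16 bits -> 2 bytes
tile_x = 64 // bytes_per_elem // max(d // 128, 1)   # 64 // 2 // 1 = 32
tile_y, tile_z = d, 16
assert (tile_x, tile_y, tile_z) == (32, 128, 16)

The deleted version instead tied tile_z to the same formula as tile_x, which under these assumptions would give tile_z = 32.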
@@ -1722,6 +1708,8 @@ def batch_prefill_ragged_kv(  # pylint: disable=too-many-branches
                     for lz, ly in T.grid(tile_z, tile_y):
                         with T.block("K_load"):
                             i, j = T.axis.remap("SS", [lz, ly])
+                            T.reads()
+                            T.writes()
                             cur_L = L_kv_start + i
                             if cur_L < kv_chunk_len[0]:
                                 K_smem[i, j] = T.if_then_else(
@@ -1836,14 +1824,6 @@ def batch_prefill_ragged_kv(  # pylint: disable=too-many-branches
     # fmt: on
     # pylint: enable=line-too-long,too-many-branches
     sch = tir.Schedule(batch_prefill_ragged_kv)
-    get_extent = lambda *lps: [int(sch.get(lp).extent) for lp in lps]
-
-    def get_vecsize(extent):
-        return min(LOAD_VEC, (extent & ~(extent - 1)))
-
-    def getxy_vecsize(x, y, t):
-        assert (x * y) % t == 0
-        return min(get_vecsize(y), get_vecsize(x * y // t))

     def get_tile_size(x, y, t):
         cnt = (x * y) // t
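
The deleted get_vecsize helper picked a vector width as the largest power of two dividing the loop extent, capped at LOAD_VEC: extent & ~(extent - 1) isolates the lowest set bit. A standalone sketch (the cap value 8 is an assumed example, not the value from the file):

def get_vecsize(extent: int, load_vec: int = 8) -> int:
    # extent & ~(extent - 1) == lowest set bit == largest power-of-two divisor.
    return min(load_vec, extent & ~(extent - 1))

assert get_vecsize(96) == 8   # 96 = 2**5 * 3, capped at 8
assert get_vecsize(12) == 4   # 12 = 2**2 * 3
assert get_vecsize(7) == 1    # odd extent: no vectorization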
@@ -1857,37 +1837,26 @@ def get_tile_size(x, y, t):

     def apply_to_qkv_load(sch: tir.Schedule, block):
         loop_x, loop_y = sch.get_loops(block)[-2:]
-        x_extent, y_extent = get_extent(loop_x, loop_y)
-        vec_size = getxy_vecsize(x_extent, y_extent, bdx * num_warps)
-        yo, yv = sch.split(loop_y, [None, vec_size])
-        yo_extent = y_extent // vec_size
-        tile_x, tile_y = get_tile_size(x_extent, yo_extent, (bdx * num_warps))
-        xo, xi = sch.split(loop_x, [tile_x, None])
-        yo, yi = sch.split(yo, [tile_y, None])
-        sch.reorder(xi, yi, xo, yo)
-        t = sch.fuse(xi, yi)
-        ty, tx = sch.split(t, [num_warps, bdx])
+        loop = sch.fuse(loop_x, loop_y)
+        _, ty, tx, vec = sch.split(
+            loop, factors=[None, num_warps, bdx, LOAD_VEC], preserve_unit_iters=True
+        )
         sch.bind(ty, "threadIdx.y")
         sch.bind(tx, "threadIdx.x")
-        sch.vectorize(yv)
+        sch.vectorize(vec)

     def apply_to_so_ewise(sch: tir.Schedule, block, tile):
         loop_x, loop_y = sch.get_loops(block)[-2:]
         xo, xi = sch.split(loop_x, factors=[None, tile[0]])
         yo, yi = sch.split(loop_y, factors=[None, tile[1]])
         sch.reorder(xo, yo, xi, yi)
-        sch.unroll(xi)
-        yiv_extent = get_vecsize(tile[1])
-        yio, yiv = sch.split(yi, [None, yiv_extent])
-        sch.unroll(yio)
-        sch.vectorize(yiv)
         t = sch.fuse(xo, yo)
         ty, tx = sch.split(t, factors=[None, bdx])
         sch.bind(ty, "threadIdx.y")
         sch.bind(tx, "threadIdx.x")

     def apply_to_gemm(  # pylint: disable=unused-argument
-        sch: tir.Schedule, block, tile, read_0, read_1, r_len=16, k_major=False
+        sch: tir.Schedule, block, tile, read_0, read_1, r_len=8, k_major=False
     ):
         loop_x, loop_y, loop_z = sch.get_loops(block)[-3:]
         xo, xi = sch.split(loop_x, factors=[None, tile[0]])
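
The restored apply_to_qkv_load flattens the two innermost loops and splits the result into (outer, threadIdx.y, threadIdx.x, vector) lanes in a single step. A self-contained sketch of that pattern on a toy copy kernel (the kernel, shapes, and LOAD_VEC = 8 are illustrative assumptions, not code from this file):

import tvm
from tvm import tir
from tvm.script import tir as T

NUM_WARPS, BDX, LOAD_VEC = 4, 32, 8

@T.prim_func
def copy(A: T.Buffer((128, 128), "float16"), B: T.Buffer((128, 128), "float16")):
    for i, j in T.grid(128, 128):
        with T.block("copy"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj]

sch = tir.Schedule(copy)
i, j = sch.get_loops(sch.get_block("copy"))
fused = sch.fuse(i, j)
_, ty, tx, vec = sch.split(
    fused, factors=[None, NUM_WARPS, BDX, LOAD_VEC], preserve_unit_iters=True
)
sch.bind(ty, "threadIdx.y")
sch.bind(tx, "threadIdx.x")
sch.vectorize(vec)
print(sch.mod.script())  # each thread now issues 8-wide vectorized copies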
@@ -1903,12 +1872,6 @@ def apply_to_gemm(  # pylint: disable=unused-argument
             sch.reorder(ko, xi, yi, ki)
         else:
             sch.reorder(ko, ki, xi, yi)
-        yiv_extent = get_vecsize(tile[1])
-        yio, yiv = sch.split(yi, [None, yiv_extent])
-        sch.unroll(yio)
-        sch.vectorize(yiv)
-        sch.unroll(xi)
-        sch.unroll(ki)
         sch.decompose_reduction(block, ty)

     def apply_to_md(sch, block):
@@ -1917,7 +1880,6 @@ def apply_to_md(sch, block):
         sch.bind(ty, "threadIdx.y")
         sch.bind(tx, "threadIdx.x")

-    sch.transform_layout("K_load", ("write", 0), lambda i, j: (j, i))
     tile_s = get_tile_size(tile_x, tile_z, bdx * num_warps)
     tile_o = get_tile_size(tile_x, tile_y, bdx * num_warps)
     apply_to_gemm(sch, sch.get_block("S_gemm"), tile_s, 0, 1, k_major=True)
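
The removed transform_layout line stored the K tile transposed in shared memory, rewriting the block's write indices from (i, j) to (j, i). For reference, a minimal sketch of that API on a toy block (buffer names and shapes assumed for illustration):

import tvm
from tvm import tir
from tvm.script import tir as T

@T.prim_func
def load(A: T.Buffer((16, 64), "float16"), K_smem: T.Buffer((16, 64), "float16")):
    for i, j in T.grid(16, 64):
        with T.block("K_load"):
            vi, vj = T.axis.remap("SS", [i, j])
            K_smem[vi, vj] = A[vi, vj]

sch = tir.Schedule(load)
# Transpose the layout of the block's first write buffer, as the deleted line did.
sch.transform_layout("K_load", ("write", 0), lambda i, j: (j, i))
print(sch.mod.script())  # K_smem is now stored transposed, with shape (64, 16)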
