Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DLIGHT][GPU] Improve matmul schedule for adreno #17430

Merged
merged 1 commit into from
Sep 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 61 additions & 47 deletions python/tvm/dlight/gpu/matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from tvm.tir import IterVar, PrimExpr, Var
from tvm.tir.analysis import undefined_vars
from tvm.tir.schedule.schedule import BlockRV
from tvm.script import tir as T

from ..base import analysis, BlockInfo, IterInfo
from .base import GPUScheduleRule
Expand Down Expand Up @@ -945,14 +946,14 @@ def get_configs(self, target: Target) -> Config:
):
return Matmul.Config(
block_size_x=32,
block_size_y=8,
block_size_y=4,
vthread_x=1,
vthread_y=1,
micro_size_x=8,
micro_size_y=2,
micro_size_k=16,
vector_size=8,
unroll=4,
unroll=16,
use_shared=False,
storage_align=False,
inner_x=True,
Expand Down Expand Up @@ -1147,7 +1148,7 @@ def get_max_factor(n, factors):
if not (
isinstance(sch.get(n).extent, tir.IntImm)
and isinstance(sch.get(mb).extent, tir.IntImm)
and isinstance(sch.get(ms).extent, tir.Var)
and not isinstance(sch.get(ms).extent, tir.IntImm)
):
return None

Expand All @@ -1157,6 +1158,7 @@ def get_max_factor(n, factors):
config.vector_size,
config.unroll,
)

VecSize = min(get_max_factor(sch.get(n).extent // Threads_X, [1, 2, 4, 8]), VecSize)
dequant_block = None
matmul_block = reduction_block
Expand All @@ -1169,61 +1171,73 @@ def get_max_factor(n, factors):
elif blk is not matmul_block:
sch.compute_inline(blk)

m = sch.fuse(mb, ms)

sch.pad_einsum(matmul_block, [1, Threads_Y * Unroll_M, Threads_X * VecSize, 1])

rmat_block, wmat_block = (
block = sch.reindex(reduction_block, ("read", 0))
sch.pad_einsum(reduction_block, [1, Unroll_M, 1, 1])
sch.compute_inline(block)
trans_block, matmul_reindex = (
sch.get_producers(matmul_block)[0],
sch.get_consumers(matmul_block)[0],
)
mo, mi, mu = sch.split(m, [None, Threads_Y, Unroll_M])
no, ni, nv = sch.split(n, [None, Threads_X, VecSize])
k0, k1, k2, k3 = sch.split(k, [None, (Threads_X * VecSize) // 32, 4, 8])
sch.reorder(no, mo, ni, mi, k0, k1, k2, k3, mu, nv)

sch.compute_at(rmat_block, k0)
if dequant_block is not None:
sch.compute_at(dequant_block, k3)
sch.reverse_compute_at(wmat_block, mi)
sch.set_scope(rmat_block, 0, "shared")
sch.set_scope(matmul_block, 0, "local")
if epilogue_block is not None:
sch.compute_inline(matmul_reindex)
matmul_reindex = epilogue_block

if dequant_block is not None:
sch.set_scope(dequant_block, 0, "local")
sch.transform_layout(
trans_block,
("write", 0),
T.index_map(lambda i0, i1, i2: (i0, i1 // Unroll_M, i2, i1 % Unroll_M)),
)

sch.bind(mo, "blockIdx.y")
sch.bind(no, "blockIdx.x")
sch.bind(mi, "threadIdx.y")
sch.bind(ni, "threadIdx.x")
sch.vectorize(sch.get_loops(matmul_block)[-1])
# transpose block schedules
# sch.set_scope(trans_block, 0, "global.texture-1d")
tb, tn, tk = sch.get_loops(trans_block)
tbx, ttx = sch.split(tk, [None, Threads_X])
tby, tty, tc = sch.split(tn, [None, Threads_Y, Unroll_M])
sch.bind(tb, "blockIdx.z")
sch.bind(tby, "blockIdx.y")
sch.bind(tbx, "blockIdx.x")
sch.bind(tty, "threadIdx.y")
sch.bind(ttx, "threadIdx.x")
sch.reorder(tb, tby, tbx, tty, ttx, tc)
sch.vectorize(tc)

mb, ms, n, k = sch.get_loops(matmul_block)
m = sch.fuse(mb, ms)
bx, tx, vec = sch.split(n, [None, Threads_X, VecSize])
by, ty, unr = sch.split(m, [None, Threads_Y, Unroll_M])
k1, k2, k3 = sch.split(k, [None, 4, 8])
sch.reorder(bx, by, tx, ty, k1, k2, k3, unr, vec)
sch.set_scope(matmul_block, 0, "local")
if dequant_block is not None:
sch.vectorize(sch.get_loops(dequant_block)[-1])
sch.compute_at(dequant_block, k3)
sch.set_scope(dequant_block, 0, "local")
sch.bind(by, "blockIdx.y")
sch.bind(bx, "blockIdx.x")
sch.bind(ty, "threadIdx.y")
sch.bind(tx, "threadIdx.x")
sch.vectorize(vec)

# Co-operative Memory Fetch
ro, rv = sch.split(sch.get_loops(rmat_block)[-1], [None, VecSize])
sch.bind(ro, "threadIdx.x")
sch.vectorize(rv)
inp = sch.cache_read(matmul_block, read_buffer_index=0, storage_scope="local")
sch.compute_at(inp, k3, preserve_unit_loops=True)
sch.vectorize(sch.get_loops(inp)[-1])

wv = sch.get_loops(wmat_block)[-1]
sch.vectorize(wv)
sch.unroll(unr)
sch.unroll(k3)

# Scale and Quant Cache
if dequant_block is not None:
qb = sch.cache_read(dequant_block, 0, "local")
sb = sch.cache_read(dequant_block, 1, "local")
sch.compute_at(sb, k1)
sch.compute_at(qb, k2)
sch.set_scope(sb, 0, "local")
sch.set_scope(qb, 0, "local")
sch.vectorize(sch.get_loops(qb)[-1])
sch.vectorize(sch.get_loops(sb)[-1])
Aq_local = sch.cache_read(dequant_block, read_buffer_index=0, storage_scope="local")
sch.compute_at(Aq_local, k2, preserve_unit_loops=True)
sch.vectorize(sch.get_loops(Aq_local)[-1])
As_local = sch.cache_read(dequant_block, read_buffer_index=1, storage_scope="local")
sch.compute_at(As_local, k1, preserve_unit_loops=True)
sch.vectorize(sch.get_loops(As_local)[-1])
sch.vectorize(sch.get_loops(dequant_block)[-1])

if epilogue_block is not None:
sch.reverse_compute_at(epilogue_block, mi, preserve_unit_loops=True)
sch.set_scope(wmat_block, 0, "local")
sch.compute_inline(wmat_block)
sch.vectorize(sch.get_loops(epilogue_block)[-1])
sch.reverse_compute_at(matmul_reindex, ty)
o_ur, o_vec = sch.get_loops(matmul_reindex)[-2:]
sch.vectorize(o_vec)
sch.unroll(o_ur)
sch.decompose_reduction(matmul_block, k1)

sch.decompose_reduction(matmul_block, k0)
return sch
Loading
Loading