Scaffolding DLight
junrushao committed Jun 22, 2023
1 parent cbdee2f commit d1aa184
Showing 10 changed files with 2,994 additions and 0 deletions.
267 changes: 267 additions & 0 deletions main.py
@@ -0,0 +1,267 @@
# pylint: disable=missing-docstring
from typing import List, Optional

from tvm import IRModule
from tvm import meta_schedule as ms
from tvm import tir
from tvm.dlight import ScheduleGenerator, ScheduleRule, auto_inline_consumers


class Decode(ScheduleRule):
def __init__(self):
...

def _initialize_with_tune_context(self, context: ms.TuneContext) -> None:
pass

def clone(self) -> ScheduleRule:
return Decode()

def apply(self, mod: IRModule) -> Optional[tir.Schedule]: # pylint: disable=too-many-locals
sch = tir.Schedule(mod)
try:
decode = sch.get_block("decode")
except: # pylint: disable=bare-except
return None
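        # Tiling factors: an 8x8 (threadIdx.y, threadIdx.x) thread block,
        # with each thread unrolling 8 elements along the first axis.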
len_tx: int = 8
len_ty: int = 8
len_yi: int = 1
len_yc: int = 8

# Step 1. Tile the decoding
i, j = sch.get_loops(decode)
by, ty, yi, yc = sch.split( # pylint: disable=invalid-name
i, factors=[None, len_ty, len_yi, len_yc]
)
bx, tx = sch.split(j, factors=[None, len_tx]) # pylint: disable=invalid-name
sch.reorder(by, bx, ty, tx, yi, yc)
sch.bind(by, "blockIdx.y")
sch.bind(bx, "blockIdx.x")
sch.bind(ty, "threadIdx.y")
sch.bind(tx, "threadIdx.x")
sch.unroll(yc)
# Step 2. Cache results in shared memory
rb = sch.cache_write(decode, 0, "shared") # pylint: disable=invalid-name
consumers = sch.get_consumers(rb)
if consumers:
(consumer,) = consumers
auto_inline_consumers(sch, consumer)
sch.compute_inline(rb)
rb = consumer # pylint: disable=invalid-name
# Step 3. Schedule the shared memory write back
sch.reverse_compute_at(rb, bx, preserve_unit_loops=True)
loop = sch.fuse(*sch.get_loops(rb)[-2:])
_, ty, tx, vec = sch.split( # pylint: disable=invalid-name
loop, factors=[None, len_ty, len_tx, 4]
)
sch.bind(ty, "threadIdx.y")
sch.bind(tx, "threadIdx.x")
sch.vectorize(vec)
sch.storage_align(decode, buffer_index=0, axis=0, factor=32, offset=1)
return sch


class DecodeGemv(ScheduleRule):
def __init__(self):
...

def _initialize_with_tune_context(self, context: ms.TuneContext) -> None:
pass

def clone(self) -> ScheduleRule:
return DecodeGemv()

    def apply(self, mod: IRModule) -> Optional[List[tir.Schedule]]: # pylint: disable=too-many-locals
sch = tir.Schedule(mod)
try:
gemv = sch.get_block("matmul")
decode = sch.get_block("decode")
except: # pylint: disable=bare-except
return None
len_vx: int = 2
len_tx: int = 64
len_km: int = 2
len_ki: int = 1 * 8
# Step 1. Schedule GEMV
# [b=1, i=1, j, k]
# split j => [b=1, i=1, (bx, vx, tx), k]
# fuse (b, i, bx) => [bx, vx, tx, (k)]
# split k => [bx, vx, tx, (ko, k_m, ki * 8)]
rb = sch.cache_write(gemv, 0, "local") # pylint: disable=invalid-name
b, i, j, k = sch.get_loops(gemv) # pylint: disable=invalid-name
assert sch.get(b).extent.value == 1
assert sch.get(i).extent.value == 1
bx, vx, tx = sch.split(j, [None, len_vx, len_tx]) # pylint: disable=invalid-name
bx = sch.fuse(b, i, bx) # pylint: disable=invalid-name
k_o, k_m, k_i = sch.split(k, [None, len_km, len_ki])
sch.bind(bx, thread_axis="blockIdx.x")
sch.bind(vx, thread_axis="vthread.x")
sch.bind(tx, thread_axis="threadIdx.x")
sch.reorder(bx, vx, tx, k_o, k_m, k_i)
sch.unroll(k_i)
# Step 2. Schedule decode: move to under threadIdx.x and fetch separately for each thread
sch.compute_at(decode, k_m, preserve_unit_loops=True)
sch.set_scope(decode, 0, "local")
_, unroll = sch.split(sch.get_loops(decode)[-2], [None, 8])
sch.unroll(unroll)

# Step 3. Cooperative fetch GEMV
def cooperative_fetch(block, tx): # pylint: disable=invalid-name
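            # Cache the block's first read buffer in shared memory, attach the cache at the
            # given loop, and let all threadIdx.x threads fetch it cooperatively with a
            # sampled vectorization width.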
block = sch.cache_read(block, 0, "shared")
sch.compute_at(block, tx, preserve_unit_loops=True)
loop = sch.fuse(*sch.get_loops(block)[-2:])
len_vector = sch.sample_categorical(
[1, 2, 3, 4],
probs=[0.25, 0.25, 0.25, 0.25],
)
_, tx, vec = sch.split(loop, [None, len_tx, len_vector])
sch.bind(tx, thread_axis="threadIdx.x")
sch.vectorize(vec)
sch.storage_align(block, buffer_index=0, axis=-2, factor=32, offset=8)

cooperative_fetch(gemv, k_o)
# Step 4. Schedule epilogue
auto_inline_consumers(sch, rb)
sch.reverse_compute_at(rb, tx, preserve_unit_loops=True)
# Step 5. Postprocess: decompose reduction
sch.decompose_reduction(gemv, k_o)
return [sch]


class Normalization(ScheduleRule):
def __init__(self):
...

def _initialize_with_tune_context(self, context: ms.TuneContext) -> None:
pass

def clone(self) -> ScheduleRule:
return Normalization()

def apply(self, mod: IRModule) -> Optional[tir.Schedule]: # pylint: disable=too-many-locals
sch = tir.Schedule(mod)
b_reduce = None
for name in ["Ared_temp", "A_red_temp"]:
try:
b_reduce = sch.get_block(name)
            except: # pylint: disable=bare-except
continue
else:
break
if b_reduce is None:
return None
len_tx: int = 256
unroll_depth: int = 256

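        # Schedule the spatial block: fuse its outer loops into blockIdx.x and
        # split the last loop across threadIdx.x threads.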
(b_spatial,) = sch.get_consumers(b_reduce)
loops = sch.get_loops(b_spatial)
bx = sch.fuse(*loops[:-1]) # pylint: disable=invalid-name
_, tx = sch.split(loops[-1], [None, len_tx]) # pylint: disable=invalid-name
sch.bind(bx, "blockIdx.x")
sch.bind(tx, "threadIdx.x")

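        # Compute the reduction into shared memory under the same blockIdx.x,
        # reusing the threadIdx.x threads for the reduction loop.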
for i, _ in enumerate(sch.get(b_reduce).writes):
sch.set_scope(b_reduce, buffer_index=i, storage_scope="shared")
sch.compute_at(b_reduce, bx, preserve_unit_loops=True)
_, tx = sch.split( # pylint: disable=invalid-name
sch.get_loops(b_reduce)[-1],
[None, len_tx],
)
sch.bind(tx, "threadIdx.x")
auto_inline_consumers(sch, b_spatial)
sch.annotate(bx, ann_key="pragma_auto_unroll_max_step", ann_val=unroll_depth)
sch.annotate(bx, ann_key="pragma_unroll_explicit", ann_val=1)

return sch


class Softmax(ScheduleRule):
def __init__(self):
...

def _initialize_with_tune_context(self, context: ms.TuneContext) -> None:
pass

def clone(self) -> ScheduleRule:
return Softmax()

    def apply(self, mod: IRModule) -> Optional[List[tir.Schedule]]: # pylint: disable=too-many-locals
sch = tir.Schedule(mod)
try:
b_reduce_0 = sch.get_block("T_softmax_maxelem")
b_reduce_1 = sch.get_block("T_softmax_expsum")
b_spatial = sch.get_block("T_softmax_norm")
except: # pylint: disable=bare-except
return None

len_tx: int = 256
unroll_depth: int = 256

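        # Inline the producer of the exp-sum block (the elementwise exp computation)
        # so the exponentials are recomputed rather than materialized.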
sch.compute_inline(sch.get_producers(b_reduce_1)[0])

loops = sch.get_loops(b_spatial)
bx = sch.fuse(*loops[:-1]) # pylint: disable=invalid-name
_, tx = sch.split(loops[-1], [None, len_tx]) # pylint: disable=invalid-name
sch.bind(bx, "blockIdx.x")
sch.bind(tx, "threadIdx.x")

sch.set_scope(b_reduce_1, buffer_index=0, storage_scope="shared")
sch.compute_at(b_reduce_1, bx, preserve_unit_loops=True)
_, tx = sch.split( # pylint: disable=invalid-name
sch.get_loops(b_reduce_1)[-1],
[None, len_tx],
)
sch.bind(tx, "threadIdx.x")

sch.set_scope(b_reduce_0, buffer_index=0, storage_scope="shared")
sch.compute_at(b_reduce_0, bx, preserve_unit_loops=True)
_, tx = sch.split( # pylint: disable=invalid-name
sch.get_loops(b_reduce_0)[-1],
[None, len_tx],
)
sch.bind(tx, "threadIdx.x")

sch.annotate(bx, ann_key="pragma_auto_unroll_max_step", ann_val=unroll_depth)
sch.annotate(bx, ann_key="pragma_unroll_explicit", ann_val=1)

auto_inline_consumers(sch, b_spatial)
return [sch]


def main():
from tvm.dlight.testing import mod_decode, mod_decode_gemv, mod_norm, mod_softmax

gen = ScheduleGenerator(
rules=[
DecodeGemv(),
Decode(),
Normalization(),
Softmax(),
]
)

for py_mod in [
mod_decode,
mod_decode_gemv,
mod_norm,
mod_softmax, # Needs to upstream `compute-inline`
]:
i = 1
while True:
try:
func = py_mod.Module[f"func{i}"]
except: # pylint: disable=bare-except
break
else:
print(f"Working on {py_mod}::func{i}")
i += 1
mod = IRModule.from_expr(func.with_attr("global_symbol", "main"))
schedules = gen.generate_design_space(mod)
assert schedules is not None
assert len(schedules) == 1
(sch,) = schedules
sch.mod.show(black_format=False)


if __name__ == "__main__":
main()
2 changes: 2 additions & 0 deletions python/tvm/dlight/__init__.py
@@ -0,0 +1,2 @@
from .base import ScheduleGenerator, ScheduleRule
from .schedule_utils import auto_inline_consumers, auto_inline_producers
43 changes: 43 additions & 0 deletions python/tvm/dlight/base.py
@@ -0,0 +1,43 @@
from typing import List, Optional, Union

from tvm import meta_schedule as ms
from tvm import tir
from tvm.ir import IRModule


class ScheduleRule:
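    # Interface for a DLight rule: `apply` returns one or more schedules for an IRModule,
    # or None when the rule does not match the workload.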
def _initialize_with_tune_context(self, context: ms.TuneContext) -> None:
raise NotImplementedError

def apply(self, mod: IRModule) -> Union[None, tir.Schedule, List[tir.Schedule]]:
raise NotImplementedError

def clone(self) -> "ScheduleRule":
raise NotImplementedError


@ms.utils.derived_object
class ScheduleGenerator(ms.space_generator.PySpaceGenerator):
def __init__(self, rules: List[ScheduleRule]):
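        # Leave MetaSchedule's default schedule rules, postprocessors and mutators empty;
        # the design space comes entirely from the DLight rules passed in.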
self.sch_rules = []
self.postprocs = []
self.mutator_probs = {}
self.rules = rules

def _initialize_with_tune_context(self, context: ms.TuneContext) -> None:
for rule in self.rules:
rule._initialize_with_tune_context(context) # pylint: disable=protected-access

def generate_design_space(self, mod: IRModule) -> Optional[List[tir.Schedule]]:
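        # Apply the rules in order; the first rule that produces a schedule defines the design space.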
for rule in self.rules:
space = rule.apply(mod)
if space is None:
continue
if isinstance(space, tir.Schedule):
space = [space]
return space
return None

def clone(self) -> "ScheduleGenerator":
rules = [rule.clone() for rule in self.rules]
return ScheduleGenerator(rules)
58 changes: 58 additions & 0 deletions python/tvm/dlight/schedule_utils.py
@@ -0,0 +1,58 @@
# pylint: disable=missing-docstring
from tvm import tir


def _collect_producers(sch: tir.Schedule, block: tir.schedule.BlockRV):
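    # Recursively collect the transitive producers of `block`.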
result = []
for producer in sch.get_producers(block):
result.append(producer)
result.extend(_collect_producers(sch, producer))
return result


def _collect_consumers(sch: tir.Schedule, block: tir.schedule.BlockRV):
result = []
for consumer in sch.get_consumers(block):
result.append(consumer)
result.extend(_collect_consumers(sch, consumer))
return result


def auto_inline_producers(
sch: tir.Schedule,
block: tir.schedule.BlockRV,
):
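    # Keep inlining producers until a fixed point: inlining one block can make others inlinable.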
while True:
inlined_cnt = 0
producers = _collect_producers(sch, block)
for producer in producers:
try:
sch.compute_inline(producer)
inlined_cnt += 1
except: # pylint: disable=bare-except
continue
if inlined_cnt == 0:
return


def auto_inline_consumers(
sch: tir.Schedule,
block: tir.schedule.BlockRV,
):
while True:
inlined_cnt = 0
consumers = _collect_consumers(sch, block)
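        # compute_inline pushes a consumer forward into its own consumers; the final block in
        # the chain has nothing downstream, so reverse_compute_inline folds it back into its producer.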
for consumer in consumers:
try:
sch.compute_inline(consumer)
inlined_cnt += 1
except: # pylint: disable=bare-except
continue
for consumer in consumers:
try:
sch.reverse_compute_inline(consumer)
inlined_cnt += 1
except: # pylint: disable=bare-except
continue
if inlined_cnt == 0:
return
