-
Notifications
You must be signed in to change notification settings - Fork 3.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
10 changed files
with
2,994 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,267 @@ | ||
# pylint: disable=missing-docstring | ||
from typing import List, Optional | ||
|
||
from tvm import IRModule | ||
from tvm import meta_schedule as ms | ||
from tvm import tir | ||
from tvm.dlight import ScheduleGenerator, ScheduleRule, auto_inline_consumers | ||
|
||
|
||
class Decode(ScheduleRule): | ||
def __init__(self): | ||
... | ||
|
||
def _initialize_with_tune_context(self, context: ms.TuneContext) -> None: | ||
pass | ||
|
||
def clone(self) -> ScheduleRule: | ||
return Decode() | ||
|
||
def apply(self, mod: IRModule) -> Optional[tir.Schedule]: # pylint: disable=too-many-locals | ||
sch = tir.Schedule(mod) | ||
try: | ||
decode = sch.get_block("decode") | ||
except: # pylint: disable=bare-except | ||
return None | ||
len_tx: int = 8 | ||
len_ty: int = 8 | ||
len_yi: int = 1 | ||
len_yc: int = 8 | ||
|
||
# Step 1. Tile the decoding | ||
i, j = sch.get_loops(decode) | ||
by, ty, yi, yc = sch.split( # pylint: disable=invalid-name | ||
i, factors=[None, len_ty, len_yi, len_yc] | ||
) | ||
bx, tx = sch.split(j, factors=[None, len_tx]) # pylint: disable=invalid-name | ||
sch.reorder(by, bx, ty, tx, yi, yc) | ||
sch.bind(by, "blockIdx.y") | ||
sch.bind(bx, "blockIdx.x") | ||
sch.bind(ty, "threadIdx.y") | ||
sch.bind(tx, "threadIdx.x") | ||
sch.unroll(yc) | ||
# Step 2. Cache results in shared memory | ||
rb = sch.cache_write(decode, 0, "shared") # pylint: disable=invalid-name | ||
consumers = sch.get_consumers(rb) | ||
if consumers: | ||
(consumer,) = consumers | ||
auto_inline_consumers(sch, consumer) | ||
sch.compute_inline(rb) | ||
rb = consumer # pylint: disable=invalid-name | ||
# Step 3. Schedule the shared memory write back | ||
sch.reverse_compute_at(rb, bx, preserve_unit_loops=True) | ||
loop = sch.fuse(*sch.get_loops(rb)[-2:]) | ||
_, ty, tx, vec = sch.split( # pylint: disable=invalid-name | ||
loop, factors=[None, len_ty, len_tx, 4] | ||
) | ||
sch.bind(ty, "threadIdx.y") | ||
sch.bind(tx, "threadIdx.x") | ||
sch.vectorize(vec) | ||
sch.storage_align(decode, buffer_index=0, axis=0, factor=32, offset=1) | ||
return sch | ||
|
||
|
||
class DecodeGemv(ScheduleRule): | ||
def __init__(self): | ||
... | ||
|
||
def _initialize_with_tune_context(self, context: ms.TuneContext) -> None: | ||
pass | ||
|
||
def clone(self) -> ScheduleRule: | ||
return DecodeGemv() | ||
|
||
def apply(self, mod: IRModule) -> Optional[tir.Schedule]: # pylint: disable=too-many-locals | ||
sch = tir.Schedule(mod) | ||
try: | ||
gemv = sch.get_block("matmul") | ||
decode = sch.get_block("decode") | ||
except: # pylint: disable=bare-except | ||
return None | ||
len_vx: int = 2 | ||
len_tx: int = 64 | ||
len_km: int = 2 | ||
len_ki: int = 1 * 8 | ||
# Step 1. Schedule GEMV | ||
# [b=1, i=1, j, k] | ||
# split j => [b=1, i=1, (bx, vx, tx), k] | ||
# fuse (b, i, bx) => [bx, vx, tx, (k)] | ||
# split k => [bx, vx, tx, (ko, k_m, ki * 8)] | ||
rb = sch.cache_write(gemv, 0, "local") # pylint: disable=invalid-name | ||
b, i, j, k = sch.get_loops(gemv) # pylint: disable=invalid-name | ||
assert sch.get(b).extent.value == 1 | ||
assert sch.get(i).extent.value == 1 | ||
bx, vx, tx = sch.split(j, [None, len_vx, len_tx]) # pylint: disable=invalid-name | ||
bx = sch.fuse(b, i, bx) # pylint: disable=invalid-name | ||
k_o, k_m, k_i = sch.split(k, [None, len_km, len_ki]) | ||
sch.bind(bx, thread_axis="blockIdx.x") | ||
sch.bind(vx, thread_axis="vthread.x") | ||
sch.bind(tx, thread_axis="threadIdx.x") | ||
sch.reorder(bx, vx, tx, k_o, k_m, k_i) | ||
sch.unroll(k_i) | ||
# Step 2. Schedule decode: move to under threadIdx.x and fetch separately for each thread | ||
sch.compute_at(decode, k_m, preserve_unit_loops=True) | ||
sch.set_scope(decode, 0, "local") | ||
_, unroll = sch.split(sch.get_loops(decode)[-2], [None, 8]) | ||
sch.unroll(unroll) | ||
|
||
# Step 3. Cooperative fetch GEMV | ||
def cooperative_fetch(block, tx): # pylint: disable=invalid-name | ||
block = sch.cache_read(block, 0, "shared") | ||
sch.compute_at(block, tx, preserve_unit_loops=True) | ||
loop = sch.fuse(*sch.get_loops(block)[-2:]) | ||
len_vector = sch.sample_categorical( | ||
[1, 2, 3, 4], | ||
probs=[0.25, 0.25, 0.25, 0.25], | ||
) | ||
_, tx, vec = sch.split(loop, [None, len_tx, len_vector]) | ||
sch.bind(tx, thread_axis="threadIdx.x") | ||
sch.vectorize(vec) | ||
sch.storage_align(block, buffer_index=0, axis=-2, factor=32, offset=8) | ||
|
||
cooperative_fetch(gemv, k_o) | ||
# Step 4. Schedule epilogue | ||
auto_inline_consumers(sch, rb) | ||
sch.reverse_compute_at(rb, tx, preserve_unit_loops=True) | ||
# Step 5. Postprocess: decompose reduction | ||
sch.decompose_reduction(gemv, k_o) | ||
return [sch] | ||
|
||
|
||
class Normalization(ScheduleRule): | ||
def __init__(self): | ||
... | ||
|
||
def _initialize_with_tune_context(self, context: ms.TuneContext) -> None: | ||
pass | ||
|
||
def clone(self) -> ScheduleRule: | ||
return Normalization() | ||
|
||
def apply(self, mod: IRModule) -> Optional[tir.Schedule]: # pylint: disable=too-many-locals | ||
sch = tir.Schedule(mod) | ||
b_reduce = None | ||
for name in ["Ared_temp", "A_red_temp"]: | ||
try: | ||
b_reduce = sch.get_block(name) | ||
except: | ||
continue | ||
else: | ||
break | ||
if b_reduce is None: | ||
return None | ||
len_tx: int = 256 | ||
unroll_depth: int = 256 | ||
|
||
(b_spatial,) = sch.get_consumers(b_reduce) | ||
loops = sch.get_loops(b_spatial) | ||
bx = sch.fuse(*loops[:-1]) # pylint: disable=invalid-name | ||
_, tx = sch.split(loops[-1], [None, len_tx]) # pylint: disable=invalid-name | ||
sch.bind(bx, "blockIdx.x") | ||
sch.bind(tx, "threadIdx.x") | ||
|
||
for i, _ in enumerate(sch.get(b_reduce).writes): | ||
sch.set_scope(b_reduce, buffer_index=i, storage_scope="shared") | ||
sch.compute_at(b_reduce, bx, preserve_unit_loops=True) | ||
_, tx = sch.split( # pylint: disable=invalid-name | ||
sch.get_loops(b_reduce)[-1], | ||
[None, len_tx], | ||
) | ||
sch.bind(tx, "threadIdx.x") | ||
auto_inline_consumers(sch, b_spatial) | ||
sch.annotate(bx, ann_key="pragma_auto_unroll_max_step", ann_val=unroll_depth) | ||
sch.annotate(bx, ann_key="pragma_unroll_explicit", ann_val=1) | ||
|
||
return sch | ||
|
||
|
||
class Softmax(ScheduleRule): | ||
def __init__(self): | ||
... | ||
|
||
def _initialize_with_tune_context(self, context: ms.TuneContext) -> None: | ||
pass | ||
|
||
def clone(self) -> ScheduleRule: | ||
return Softmax() | ||
|
||
def apply(self, mod: IRModule) -> Optional[tir.Schedule]: # pylint: disable=too-many-locals | ||
sch = tir.Schedule(mod) | ||
try: | ||
b_reduce_0 = sch.get_block("T_softmax_maxelem") | ||
b_reduce_1 = sch.get_block("T_softmax_expsum") | ||
b_spatial = sch.get_block("T_softmax_norm") | ||
except: # pylint: disable=bare-except | ||
return None | ||
|
||
len_tx: int = 256 | ||
unroll_depth: int = 256 | ||
|
||
sch.compute_inline(sch.get_producers(b_reduce_1)[0]) | ||
|
||
loops = sch.get_loops(b_spatial) | ||
bx = sch.fuse(*loops[:-1]) # pylint: disable=invalid-name | ||
_, tx = sch.split(loops[-1], [None, len_tx]) # pylint: disable=invalid-name | ||
sch.bind(bx, "blockIdx.x") | ||
sch.bind(tx, "threadIdx.x") | ||
|
||
sch.set_scope(b_reduce_1, buffer_index=0, storage_scope="shared") | ||
sch.compute_at(b_reduce_1, bx, preserve_unit_loops=True) | ||
_, tx = sch.split( # pylint: disable=invalid-name | ||
sch.get_loops(b_reduce_1)[-1], | ||
[None, len_tx], | ||
) | ||
sch.bind(tx, "threadIdx.x") | ||
|
||
sch.set_scope(b_reduce_0, buffer_index=0, storage_scope="shared") | ||
sch.compute_at(b_reduce_0, bx, preserve_unit_loops=True) | ||
_, tx = sch.split( # pylint: disable=invalid-name | ||
sch.get_loops(b_reduce_0)[-1], | ||
[None, len_tx], | ||
) | ||
sch.bind(tx, "threadIdx.x") | ||
|
||
sch.annotate(bx, ann_key="pragma_auto_unroll_max_step", ann_val=unroll_depth) | ||
sch.annotate(bx, ann_key="pragma_unroll_explicit", ann_val=1) | ||
|
||
auto_inline_consumers(sch, b_spatial) | ||
return [sch] | ||
|
||
|
||
def main(): | ||
from tvm.dlight.testing import mod_decode, mod_decode_gemv, mod_norm, mod_softmax | ||
|
||
gen = ScheduleGenerator( | ||
rules=[ | ||
DecodeGemv(), | ||
Decode(), | ||
Normalization(), | ||
Softmax(), | ||
] | ||
) | ||
|
||
for py_mod in [ | ||
mod_decode, | ||
mod_decode_gemv, | ||
mod_norm, | ||
mod_softmax, # Needs to upstream `compute-inline` | ||
]: | ||
i = 1 | ||
while True: | ||
try: | ||
func = py_mod.Module[f"func{i}"] | ||
except: # pylint: disable=bare-except | ||
break | ||
else: | ||
print(f"Working on {py_mod}::func{i}") | ||
i += 1 | ||
mod = IRModule.from_expr(func.with_attr("global_symbol", "main")) | ||
schedules = gen.generate_design_space(mod) | ||
assert schedules is not None | ||
assert len(schedules) == 1 | ||
(sch,) = schedules | ||
sch.mod.show(black_format=False) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from .base import ScheduleGenerator, ScheduleRule | ||
from .schedule_utils import auto_inline_consumers, auto_inline_producers |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
from typing import List, Optional, Union | ||
|
||
from tvm import meta_schedule as ms | ||
from tvm import tir | ||
from tvm.ir import IRModule | ||
|
||
|
||
class ScheduleRule: | ||
def _initialize_with_tune_context(self, context: ms.TuneContext) -> None: | ||
raise NotImplementedError | ||
|
||
def apply(self, mod: IRModule) -> Union[None, tir.Schedule, List[tir.Schedule]]: | ||
raise NotImplementedError | ||
|
||
def clone(self) -> "ScheduleRule": | ||
raise NotImplementedError | ||
|
||
|
||
@ms.utils.derived_object | ||
class ScheduleGenerator(ms.space_generator.PySpaceGenerator): | ||
def __init__(self, rules: List[ScheduleRule]): | ||
self.sch_rules = [] | ||
self.postprocs = [] | ||
self.mutator_probs = {} | ||
self.rules = rules | ||
|
||
def _initialize_with_tune_context(self, context: ms.TuneContext) -> None: | ||
for rule in self.rules: | ||
rule._initialize_with_tune_context(context) # pylint: disable=protected-access | ||
|
||
def generate_design_space(self, mod: IRModule) -> Optional[List[tir.Schedule]]: | ||
for rule in self.rules: | ||
space = rule.apply(mod) | ||
if space is None: | ||
continue | ||
if isinstance(space, tir.Schedule): | ||
space = [space] | ||
return space | ||
return None | ||
|
||
def clone(self) -> "ScheduleGenerator": | ||
rules = [rule.clone() for rule in self.rules] | ||
return ScheduleGenerator(rules) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
# pylint: disable=missing-docstring | ||
from tvm import tir | ||
|
||
|
||
def _collect_producers(sch: tir.Schedule, block: tir.schedule.BlockRV): | ||
result = [] | ||
for producer in sch.get_producers(block): | ||
result.append(producer) | ||
result.extend(_collect_producers(sch, producer)) | ||
return result | ||
|
||
|
||
def _collect_consumers(sch: tir.Schedule, block: tir.schedule.BlockRV): | ||
result = [] | ||
for consumer in sch.get_consumers(block): | ||
result.append(consumer) | ||
result.extend(_collect_consumers(sch, consumer)) | ||
return result | ||
|
||
|
||
def auto_inline_producers( | ||
sch: tir.Schedule, | ||
block: tir.schedule.BlockRV, | ||
): | ||
while True: | ||
inlined_cnt = 0 | ||
producers = _collect_producers(sch, block) | ||
for producer in producers: | ||
try: | ||
sch.compute_inline(producer) | ||
inlined_cnt += 1 | ||
except: # pylint: disable=bare-except | ||
continue | ||
if inlined_cnt == 0: | ||
return | ||
|
||
|
||
def auto_inline_consumers( | ||
sch: tir.Schedule, | ||
block: tir.schedule.BlockRV, | ||
): | ||
while True: | ||
inlined_cnt = 0 | ||
consumers = _collect_consumers(sch, block) | ||
for consumer in consumers: | ||
try: | ||
sch.compute_inline(consumer) | ||
inlined_cnt += 1 | ||
except: # pylint: disable=bare-except | ||
continue | ||
for consumer in consumers: | ||
try: | ||
sch.reverse_compute_inline(consumer) | ||
inlined_cnt += 1 | ||
except: # pylint: disable=bare-except | ||
continue | ||
if inlined_cnt == 0: | ||
return |
Empty file.
Oops, something went wrong.