-
Notifications
You must be signed in to change notification settings - Fork 3.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
This PR introduces a new package `tvm.dlight`. DLight provides infra to implement composable default schedules to TVM IRModule.
- Loading branch information
Showing
9 changed files
with
417 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
"""DLight package provides efficient schedules out-of-box for deep learning workloads.""" | ||
from . import gpu | ||
from .base import ApplyDefaultSchedule, ScheduleRule |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
"""Base infra""" | ||
from .schedule_rule import ScheduleRule | ||
from .transform import ApplyDefaultSchedule |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
"""A lightweight wrapper on an arbitrary function that can be used to schedule a TIR PrimFunc.""" | ||
from typing import Callable, List, Union | ||
|
||
from tvm import tir | ||
from tvm.target import Target | ||
|
||
|
||
class ScheduleRule: # pylint: disable=too-few-public-methods | ||
"""A thin wrapper on an arbitrary function that can be used to schedule a TIR PrimFunc. | ||
Given a PrimFunc, a target, and a tunable flag, the apply method of a ScheduleRule | ||
returns either a Schedule, a list of Schedules, or None, where None means that the rule | ||
is not applicable to the given PrimFunc. If the tunable flag is True, the ScheduleRule is | ||
allowed to return either a Schedule or a list of Schedules, and the Schedules are allowed to | ||
contain tunable instructions. If the tunable flag is False, the ScheduleRule is only allowed to | ||
return a Schedule, and the Schedule is not allowed to contain tunable instructions. | ||
""" | ||
|
||
def apply( | ||
self, | ||
func: tir.PrimFunc, | ||
target: Target, | ||
tunable: bool, | ||
) -> Union[None, tir.Schedule, List[tir.Schedule]]: | ||
"""Apply the ScheduleRule to the given PrimFunc. | ||
Parameters | ||
---------- | ||
func : tir.PrimFunc | ||
The PrimFunc to apply the ScheduleRule to. | ||
target : Target | ||
The compilation target the schedule is supposed to be built for. | ||
tunable : bool | ||
Whether the schedule is allowed to contain tunable instructions. | ||
Returns | ||
------- | ||
results : Union[None, tir.Schedule, List[tir.Schedule]] | ||
Either a Schedule, a list of Schedules, or None, where None means that the rule | ||
is not applicable to the given PrimFunc. | ||
""" | ||
raise NotImplementedError | ||
|
||
@staticmethod | ||
def from_callable( | ||
name, | ||
) -> Callable[ | ||
[ | ||
Callable[ | ||
[tir.PrimFunc, Target, bool], | ||
Union[None, tir.Schedule, List[tir.Schedule]], | ||
], | ||
], | ||
"ScheduleRule", | ||
]: | ||
"""Create a ScheduleRule from a callable. | ||
Parameters | ||
---------- | ||
name : str | ||
Returns | ||
------- | ||
decorator : Callable | ||
A decorator that takes a callable and returns a ScheduleRule. | ||
Examples | ||
-------- | ||
.. code-block:: python | ||
@ScheduleRule.from_callable("MyRule") | ||
def my_rule(func: tir.PrimFunc, target: Target, tunable: bool) -> Union[None, Schedule] | ||
# Do something with func and target | ||
""" | ||
|
||
def decorator(f) -> "ScheduleRule": # pylint: disable=invalid-name | ||
class _Rule(ScheduleRule): | ||
def apply( | ||
self, | ||
func: tir.PrimFunc, | ||
target: Target, | ||
tunable: bool, | ||
) -> Union[None, tir.Schedule, List[tir.Schedule]]: | ||
return f(func, target, tunable) | ||
|
||
_Rule.__name__ = name | ||
return _Rule() | ||
|
||
return decorator |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
""" | ||
Apply ScheduleRules onto an IRModule to generate default schedules without tuning, | ||
or a space for MetaSchedule tuning | ||
""" | ||
from typing import List, Optional | ||
|
||
from tvm import tir | ||
from tvm.ir import IRModule | ||
from tvm.ir.transform import PassContext, module_pass | ||
from tvm.target import Target | ||
|
||
from .schedule_rule import ScheduleRule | ||
|
||
|
||
@module_pass(opt_level=0, name="ApplyDefaultSchedule") | ||
class ApplyDefaultSchedule: # pylint: disable=too-few-public-methods | ||
"""A IRModule pass that applies a list of ScheduleRules to all PrimFuncs in the module.""" | ||
|
||
def __init__(self, *rules: ScheduleRule): | ||
"""Construct a new ApplyDefaultSchedule pass. | ||
Parameters | ||
---------- | ||
*rules : ScheduleRule | ||
The ScheduleRules to apply to all PrimFuncs in the module. | ||
""" | ||
self.rules = list(rules) | ||
|
||
def transform_module( # pylint: disable=missing-function-docstring | ||
self, | ||
mod: IRModule, | ||
_: PassContext, | ||
) -> IRModule: | ||
target = Target.current(allow_none=False) | ||
updated_functions = {} | ||
for g_var, func in mod.functions.items(): | ||
if isinstance(func, tir.PrimFunc) and ( | ||
not func.attrs or not func.attrs.get("tir.is_scheduled", 0) | ||
): | ||
sch = _apply_rules(func, target, self.rules, tunable=False) | ||
if sch is not None: | ||
assert len(sch) == 1 | ||
updated_functions[g_var] = sch[0].mod["main"].with_attr("tir.is_scheduled", 1) | ||
for g_var, func in updated_functions.items(): | ||
mod[g_var] = func | ||
return mod | ||
|
||
|
||
def _apply_rules( | ||
func: tir.PrimFunc, | ||
target: Target, | ||
rules: List[ScheduleRule], | ||
tunable: bool, | ||
) -> Optional[List[tir.Schedule]]: | ||
for rule in rules: | ||
space = rule.apply(func, target, tunable) | ||
if space is None: | ||
continue | ||
if isinstance(space, tir.Schedule): | ||
space = [space] | ||
return space | ||
return None |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
""" | ||
GPU-generic schedule rules. | ||
For CUDA/ROCm/Vulkan/Metal-specific rules, use `tvm.dlight.cuda/rocm/vulkan/metal` instead | ||
""" | ||
from .fallback import Fallback |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
# pylint: disable=missing-docstring | ||
"""A fallback schedule rule for GPU operators.""" | ||
from typing import Callable, List | ||
|
||
from tvm import tir | ||
from tvm._ffi import get_global_func | ||
from tvm.target import Target | ||
|
||
from ..base import ScheduleRule | ||
|
||
|
||
def _max_threads_per_block(target: Target) -> int: | ||
max_threads_per_block = None | ||
for name in ["max_threads_per_block", "max_num_threads"]: | ||
if max_threads_per_block is None: | ||
max_threads_per_block = target.attrs.get(name, None) | ||
if max_threads_per_block is None: | ||
max_threads_per_block = 64 | ||
return int(max_threads_per_block) | ||
|
||
|
||
class Fallback(ScheduleRule): | ||
""" | ||
A fallback schedule rule for all GPU operators. It will try to inline all the blocks first, | ||
and then apply a simple block/grid mapping to the spatial loops on top of the remaining blocks. | ||
""" | ||
|
||
def apply( # pylint: disable=too-many-locals,missing-docstring | ||
self, | ||
func: tir.PrimFunc, | ||
target: Target, | ||
_: bool, | ||
) -> tir.Schedule: | ||
max_threads_per_block = _max_threads_per_block(target) | ||
get_loop_iter_type = get_global_func("tir.schedule.GetLoopIterType") | ||
|
||
sch = tir.Schedule(func) | ||
blocks = sch.get_child_blocks(sch.get_block(sch.mod["main"].body.block.name_hint)) | ||
|
||
while True: | ||
|
||
def _try_inline(func: Callable): | ||
for i, block in enumerate(blocks): | ||
try: | ||
func(block) | ||
except: # pylint: disable=bare-except | ||
continue | ||
return i | ||
return None | ||
|
||
i = _try_inline(sch.compute_inline) | ||
if i is None: | ||
i = _try_inline(sch.reverse_compute_inline) | ||
if i is None: | ||
break | ||
blocks.pop(i) | ||
|
||
for block in blocks: | ||
s_loops: List[tir.schedule.LoopRV] = [] | ||
r_loops: List[tir.schedule.LoopRV] = [] | ||
o_loops: List[tir.schedule.LoopRV] = [] | ||
for loop in sch.get_loops(block): | ||
iter_type = get_loop_iter_type(sch, loop) | ||
{"S": s_loops, "R": r_loops, "O": o_loops}[iter_type].append(loop) | ||
|
||
if not s_loops: | ||
s_loops.append(sch.add_unit_loop(block)) | ||
sch.reorder(*s_loops, *r_loops, *o_loops) | ||
bx, tx = sch.split( # pylint: disable=invalid-name | ||
sch.fuse(*s_loops), | ||
factors=[None, max_threads_per_block], | ||
) | ||
sch.bind(bx, "blockIdx.x") | ||
sch.bind(tx, "threadIdx.x") | ||
return sch |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.