From cad4c87f079edaaf89fbba56977ac9212eaec1cd Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Wed, 21 Jun 2023 21:50:50 -0700 Subject: [PATCH] [Unity] Scaffolding DLight This PR introduces a new package `tvm.dlight`. DLight provides infra to implement composable default schedules to TVM IRModule. --- python/tvm/dlight/__init__.py | 19 ++++ python/tvm/dlight/base/__init__.py | 19 ++++ python/tvm/dlight/base/schedule_rule.py | 105 ++++++++++++++++++++++ python/tvm/dlight/base/transform.py | 78 ++++++++++++++++ python/tvm/dlight/gpu/__init__.py | 21 +++++ python/tvm/dlight/gpu/fallback.py | 91 +++++++++++++++++++ src/tir/schedule/analysis/analysis.cc | 12 +++ tests/python/dlight/test_schedule_rule.py | 71 +++++++++++++++ tests/scripts/unity/task_python_relax.sh | 1 + 9 files changed, 417 insertions(+) create mode 100644 python/tvm/dlight/__init__.py create mode 100644 python/tvm/dlight/base/__init__.py create mode 100644 python/tvm/dlight/base/schedule_rule.py create mode 100644 python/tvm/dlight/base/transform.py create mode 100644 python/tvm/dlight/gpu/__init__.py create mode 100644 python/tvm/dlight/gpu/fallback.py create mode 100644 tests/python/dlight/test_schedule_rule.py diff --git a/python/tvm/dlight/__init__.py b/python/tvm/dlight/__init__.py new file mode 100644 index 000000000000..23dd17993cee --- /dev/null +++ b/python/tvm/dlight/__init__.py @@ -0,0 +1,19 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""DLight package provides efficient schedules out-of-box for deep learning workloads.""" +from . import gpu +from .base import ApplyDefaultSchedule, ScheduleRule diff --git a/python/tvm/dlight/base/__init__.py b/python/tvm/dlight/base/__init__.py new file mode 100644 index 000000000000..6088add37ef7 --- /dev/null +++ b/python/tvm/dlight/base/__init__.py @@ -0,0 +1,19 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
"""A lightweight wrapper on an arbitrary function that can be used to schedule a TIR PrimFunc."""
from typing import Callable, List, Union

from tvm import tir
from tvm.target import Target


class ScheduleRule:  # pylint: disable=too-few-public-methods
    """A thin wrapper on an arbitrary function that can be used to schedule a TIR PrimFunc.

    Given a PrimFunc, a target, and a tunable flag, the apply method of a ScheduleRule
    returns either a Schedule, a list of Schedules, or None, where None means that the rule
    is not applicable to the given PrimFunc.

    If the tunable flag is True, the ScheduleRule is allowed to return either a Schedule or
    a list of Schedules, and the Schedules are allowed to contain tunable instructions.
    If the tunable flag is False, the ScheduleRule is only allowed to return a Schedule,
    and the Schedule is not allowed to contain tunable instructions.
    """

    def apply(
        self,
        func: tir.PrimFunc,
        target: Target,
        tunable: bool,
    ) -> Union[None, tir.Schedule, List[tir.Schedule]]:
        """Apply the ScheduleRule to the given PrimFunc.

        Subclasses must override this method; the base implementation always raises.

        Parameters
        ----------
        func : tir.PrimFunc
            The PrimFunc to apply the ScheduleRule to.
        target : Target
            The compilation target the schedule is supposed to be built for.
        tunable : bool
            Whether the schedule is allowed to contain tunable instructions.

        Returns
        -------
        results : Union[None, tir.Schedule, List[tir.Schedule]]
            Either a Schedule, a list of Schedules, or None, where None means that the rule
            is not applicable to the given PrimFunc.

        Raises
        ------
        NotImplementedError
            Always, when called on the base class.
        """
        raise NotImplementedError

    @staticmethod
    def from_callable(
        name,
    ) -> Callable[
        [
            Callable[
                [tir.PrimFunc, Target, bool],
                Union[None, tir.Schedule, List[tir.Schedule]],
            ],
        ],
        "ScheduleRule",
    ]:
        """Create a ScheduleRule from a callable.

        Parameters
        ----------
        name : str
            The class name given to the generated ScheduleRule subclass.

        Returns
        -------
        decorator : Callable
            A decorator that takes a callable with the signature
            ``(func, target, tunable)`` and returns a ScheduleRule *instance* whose
            ``apply`` method forwards to that callable.

        Examples
        --------
        .. code-block:: python

            @ScheduleRule.from_callable("MyRule")
            def my_rule(
                func: tir.PrimFunc, target: Target, tunable: bool
            ) -> Union[None, tir.Schedule]:
                # Do something with func and target
                ...
        """

        def decorator(f) -> "ScheduleRule":  # pylint: disable=invalid-name
            class _Rule(ScheduleRule):
                def apply(
                    self,
                    func: tir.PrimFunc,
                    target: Target,
                    tunable: bool,
                ) -> Union[None, tir.Schedule, List[tir.Schedule]]:
                    return f(func, target, tunable)

            # `repr()` and tracebacks are built from `__qualname__`, so both names must be
            # set for the requested rule name to actually show up in diagnostics.
            _Rule.__name__ = name
            _Rule.__qualname__ = name
            # Carry over the wrapped callable's documentation (None when it has none).
            _Rule.__doc__ = f.__doc__
            return _Rule()

        return decorator
"""
Apply ScheduleRules onto an IRModule to generate default schedules without tuning,
or a space for MetaSchedule tuning
"""
from typing import List, Optional

from tvm import tir
from tvm.ir import IRModule
from tvm.ir.transform import PassContext, module_pass
from tvm.target import Target

from .schedule_rule import ScheduleRule


@module_pass(opt_level=0, name="ApplyDefaultSchedule")
class ApplyDefaultSchedule:  # pylint: disable=too-few-public-methods
    """Module pass that runs a list of ScheduleRules over the PrimFuncs of an IRModule.

    PrimFuncs whose ``tir.is_scheduled`` attribute is already truthy are left untouched.
    """

    def __init__(self, *rules: ScheduleRule):
        """Construct a new ApplyDefaultSchedule pass.

        Parameters
        ----------
        *rules : ScheduleRule
            The ScheduleRules to try, in order, on every PrimFunc in the module.
        """
        self.rules = list(rules)

    def transform_module(  # pylint: disable=missing-function-docstring
        self,
        mod: IRModule,
        _: PassContext,
    ) -> IRModule:
        # The rules need a target; it is taken from the enclosing `with target:` scope.
        target = Target.current(allow_none=False)
        scheduled = {}
        for g_var, func in mod.functions.items():
            if not isinstance(func, tir.PrimFunc):
                continue
            if func.attrs and func.attrs.get("tir.is_scheduled", 0):
                continue
            sch = _apply_rules(func, target, self.rules, tunable=False)
            if sch is None:
                # No rule was applicable; keep the function as it is.
                continue
            assert len(sch) == 1
            scheduled[g_var] = sch[0].mod["main"].with_attr("tir.is_scheduled", 1)
        # Write back only after the scan, so `mod.functions` is not modified while iterated.
        for g_var, new_func in scheduled.items():
            mod[g_var] = new_func
        return mod


def _apply_rules(
    func: tir.PrimFunc,
    target: Target,
    rules: List[ScheduleRule],
    tunable: bool,
) -> Optional[List[tir.Schedule]]:
    """Return the result of the first rule that applies to `func`, or None if none does.

    A single Schedule is wrapped into a one-element list; any other non-None result is
    returned unchanged.
    """
    for rule in rules:
        result = rule.apply(func, target, tunable)
        if result is not None:
            return [result] if isinstance(result, tir.Schedule) else result
    return None
Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +GPU-generic schedule rules. +For CUDA/ROCm/Vulkan/Metal-specific rules, use `tvm.dlight.cuda/rocm/vulkan/metal` instead +""" +from .fallback import Fallback diff --git a/python/tvm/dlight/gpu/fallback.py b/python/tvm/dlight/gpu/fallback.py new file mode 100644 index 000000000000..354361323ca3 --- /dev/null +++ b/python/tvm/dlight/gpu/fallback.py @@ -0,0 +1,91 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
# pylint: disable=missing-docstring
"""A fallback schedule rule for GPU operators."""
from typing import Callable, List, Optional

from tvm import tir
from tvm._ffi import get_global_func
from tvm.target import Target

from ..base import ScheduleRule


def _max_threads_per_block(target: Target) -> int:
    """Return the thread-per-block limit to use for `target`.

    Looks up ``max_threads_per_block`` and then ``max_num_threads`` in the target
    attributes and returns the first one that is set; falls back to 64 when neither is.
    """
    for name in ("max_threads_per_block", "max_num_threads"):
        value = target.attrs.get(name, None)
        if value is not None:
            return int(value)
    return 64


class Fallback(ScheduleRule):
    """
    A fallback schedule rule for all GPU operators. It will try to inline all the blocks first,
    and then apply a simple block/grid mapping to the spatial loops on top of the remaining blocks.
    """

    def apply(  # pylint: disable=too-many-locals,missing-docstring
        self,
        func: tir.PrimFunc,
        target: Target,
        _: bool,
    ) -> tir.Schedule:
        max_threads_per_block = _max_threads_per_block(target)
        # Packed function registered in src/tir/schedule/analysis/analysis.cc; it returns
        # "S" for data-parallel loops, "R" for commutative-reduce loops, "O" otherwise.
        get_loop_iter_type = get_global_func("tir.schedule.GetLoopIterType")

        sch = tir.Schedule(func)
        blocks = sch.get_child_blocks(sch.get_block(sch.mod["main"].body.block.name_hint))

        def _try_inline(inline: Callable) -> Optional[int]:
            """Apply `inline` to the first block that accepts it; return its index or None."""
            for i, block in enumerate(blocks):
                try:
                    inline(block)
                # Inlining is best-effort: a block that cannot be inlined is skipped.
                # `Exception` (not a bare `except`) keeps KeyboardInterrupt/SystemExit alive.
                except Exception:  # pylint: disable=broad-except
                    continue
                return i
            return None

        # Inline one block per round (forward first, then reverse) until nothing is inlinable.
        while True:
            i = _try_inline(sch.compute_inline)
            if i is None:
                i = _try_inline(sch.reverse_compute_inline)
            if i is None:
                break
            blocks.pop(i)

        for block in blocks:
            s_loops: List[tir.schedule.LoopRV] = []
            r_loops: List[tir.schedule.LoopRV] = []
            o_loops: List[tir.schedule.LoopRV] = []
            loops_by_type = {"S": s_loops, "R": r_loops, "O": o_loops}
            for loop in sch.get_loops(block):
                loops_by_type[get_loop_iter_type(sch, loop)].append(loop)

            if not s_loops:
                # Nothing spatial to bind threads to: add a unit loop so the binding below works.
                s_loops.append(sch.add_unit_loop(block))
            sch.reorder(*s_loops, *r_loops, *o_loops)
            bx, tx = sch.split(  # pylint: disable=invalid-name
                sch.fuse(*s_loops),
                factors=[None, max_threads_per_block],
            )
            sch.bind(bx, "blockIdx.x")
            sch.bind(tx, "threadIdx.x")
        return sch
# pylint: disable=missing-docstring
from tvm import dlight as dl
from tvm.ir import assert_structural_equal
from tvm.script import ir as I
from tvm.script import tir as T
from tvm.target import Target


def test_fallback():
    """Check that `dl.gpu.Fallback` turns `Before` into `After` via ApplyDefaultSchedule."""

    # Input module: a transpose block writing the intermediate buffer B, followed by a
    # reshape block that reads B.
    @I.ir_module
    class Before:
        @T.prim_func
        def main(
            A: T.Buffer((1, 32, 1, 128), "float16"),
            C: T.Buffer((1, 1, 4096), "float16"),
        ):
            B = T.alloc_buffer((1, 1, 32, 128), "float16")
            for i, j, k, l in T.grid(1, 1, 32, 128):
                with T.block("T_transpose"):
                    vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l])
                    B[vi, vj, vk, vl] = A[vi, vk, vj, vl]
            for i, j, k in T.grid(1, 1, 4096):
                with T.block("T_reshape"):
                    vi, vj, vk = T.axis.remap("SSS", [i, j, k])
                    C[vi, vj, vk] = B[0, 0, vk % 4096 // 128, vk % 128]

    # Expected module: only the reshape block remains and it reads A directly; its fused
    # loop is split into blockIdx.x (extent 4) and threadIdx.x (extent 1024), and the
    # function is tagged with "tir.is_scheduled".
    # NOTE(review): 1024 is presumably the thread limit of the target used below - confirm.
    @I.ir_module
    class After:
        @T.prim_func
        def main(
            A: T.Buffer((1, 32, 1, 128), "float16"),
            C: T.Buffer((1, 1, 4096), "float16"),
        ):
            T.func_attr({"tir.is_scheduled": 1})
            # with T.block("root"):
            for i_j_k_fused_0 in T.thread_binding(4, thread="blockIdx.x"):
                for i_j_k_fused_1 in T.thread_binding(1024, thread="threadIdx.x"):
                    with T.block("T_reshape"):
                        vi = T.axis.spatial(1, 0)
                        vj = T.axis.spatial(1, 0)
                        vk = T.axis.spatial(4096, i_j_k_fused_0 * 1024 + i_j_k_fused_1)
                        T.reads(A[0, vk % 4096 // 128, 0, vk % 128])
                        T.writes(C[vi, vj, vk])
                        C[vi, vj, vk] = A[0, vk % 4096 // 128, 0, vk % 128]

    target = Target("nvidia/geforce-rtx-3090-ti")
    # The pass is invoked inside the target scope.
    with target:
        mod = dl.ApplyDefaultSchedule(  # pylint: disable=not-callable
            dl.gpu.Fallback(),
        )(Before)
    assert_structural_equal(mod, After)


if __name__ == "__main__":
    test_fallback()
+TVM_TEST_TARGETS="${TVM_RELAY_TEST_TARGETS:-llvm}" pytest tests/python/dlight # Run Relax examples # python3 ./apps/relax_examples/mlp.py