Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Unity] Scaffolding DLight #15141

Merged
merged 1 commit into from
Jun 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions python/tvm/dlight/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""DLight package provides efficient schedules out-of-box for deep learning workloads."""
from . import gpu
from .base import ApplyDefaultSchedule, ScheduleRule
19 changes: 19 additions & 0 deletions python/tvm/dlight/base/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Base infra"""
from .schedule_rule import ScheduleRule
from .transform import ApplyDefaultSchedule
105 changes: 105 additions & 0 deletions python/tvm/dlight/base/schedule_rule.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""A lightweight wrapper on an arbitrary function that can be used to schedule a TIR PrimFunc."""
from typing import Callable, List, Union

from tvm import tir
from tvm.target import Target


class ScheduleRule:  # pylint: disable=too-few-public-methods
    """A thin wrapper on an arbitrary function that can be used to schedule a TIR PrimFunc.

    Given a PrimFunc, a target, and a tunable flag, the apply method of a ScheduleRule
    returns either a Schedule, a list of Schedules, or None, where None means that the rule
    is not applicable to the given PrimFunc. If the tunable flag is True, the ScheduleRule is
    allowed to return either a Schedule or a list of Schedules, and the Schedules are allowed to
    contain tunable instructions. If the tunable flag is False, the ScheduleRule is only allowed to
    return a Schedule, and the Schedule is not allowed to contain tunable instructions.
    """

    def apply(
        self,
        func: tir.PrimFunc,
        target: Target,
        tunable: bool,
    ) -> Union[None, tir.Schedule, List[tir.Schedule]]:
        """Apply the ScheduleRule to the given PrimFunc.

        Subclasses must override this method; the base implementation always raises.

        Parameters
        ----------
        func : tir.PrimFunc
            The PrimFunc to apply the ScheduleRule to.
        target : Target
            The compilation target the schedule is supposed to be built for.
        tunable : bool
            Whether the schedule is allowed to contain tunable instructions.

        Returns
        -------
        results : Union[None, tir.Schedule, List[tir.Schedule]]
            Either a Schedule, a list of Schedules, or None, where None means that the rule
            is not applicable to the given PrimFunc.

        Raises
        ------
        NotImplementedError
            Always, when called on the base class.
        """
        raise NotImplementedError

    @staticmethod
    def from_callable(
        name,
    ) -> Callable[
        [
            Callable[
                [tir.PrimFunc, Target, bool],
                Union[None, tir.Schedule, List[tir.Schedule]],
            ],
        ],
        "ScheduleRule",
    ]:
        """Create a ScheduleRule from a callable.

        Parameters
        ----------
        name : str
            The class name given to the generated rule. It is used for both ``__name__`` and
            ``__qualname__`` so that it shows up in ``repr()`` and in tracebacks.

        Returns
        -------
        decorator : Callable
            A decorator that takes a callable and returns a ScheduleRule instance whose
            ``apply`` method forwards ``(func, target, tunable)`` to that callable.

        Examples
        --------
        .. code-block:: python

            @ScheduleRule.from_callable("MyRule")
            def my_rule(func: tir.PrimFunc, target: Target, tunable: bool) -> Union[None, Schedule]:
                # Do something with func and target
                ...
        """

        def decorator(f) -> "ScheduleRule":  # pylint: disable=invalid-name
            class _Rule(ScheduleRule):
                def apply(
                    self,
                    func: tir.PrimFunc,
                    target: Target,
                    tunable: bool,
                ) -> Union[None, tir.Schedule, List[tir.Schedule]]:
                    return f(func, target, tunable)

            _Rule.__name__ = name
            # Without this the qualified name stays "...<locals>._Rule", which is what
            # repr() and tracebacks print, hiding the name the caller asked for.
            _Rule.__qualname__ = name
            return _Rule()

        return decorator
78 changes: 78 additions & 0 deletions python/tvm/dlight/base/transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Apply ScheduleRules onto an IRModule to generate default schedules without tuning,
or a space for MetaSchedule tuning
"""
from typing import List, Optional

from tvm import tir
from tvm.ir import IRModule
from tvm.ir.transform import PassContext, module_pass
from tvm.target import Target

from .schedule_rule import ScheduleRule


@module_pass(opt_level=0, name="ApplyDefaultSchedule")
class ApplyDefaultSchedule:  # pylint: disable=too-few-public-methods
    """An IRModule pass that applies a list of ScheduleRules to all PrimFuncs in the module.

    PrimFuncs that already carry a truthy ``tir.is_scheduled`` attribute are left untouched;
    every PrimFunc that gets scheduled by this pass is tagged with ``tir.is_scheduled = 1``.
    """

    def __init__(self, *rules: ScheduleRule):
        """Construct a new ApplyDefaultSchedule pass.

        Parameters
        ----------
        *rules : ScheduleRule
            The ScheduleRules to apply to all PrimFuncs in the module, tried in the
            order they are given.
        """
        self.rules = [*rules]

    def transform_module(  # pylint: disable=missing-function-docstring
        self,
        mod: IRModule,
        _: PassContext,
    ) -> IRModule:
        # The target is taken from the ambient target scope; allow_none=False
        # presumably makes this fail when no target is active -- confirm against Target.current.
        target = Target.current(allow_none=False)

        # Collect replacements first and write them back afterwards, so the module is
        # not modified while its function table is being iterated.
        scheduled = {}
        for g_var, func in mod.functions.items():
            if not isinstance(func, tir.PrimFunc):
                continue
            if func.attrs and func.attrs.get("tir.is_scheduled", 0):
                continue
            sch = _apply_rules(func, target, self.rules, tunable=False)
            if sch is None:
                # No rule was applicable; keep the original function.
                continue
            # With tunable=False a rule must produce exactly one schedule.
            assert len(sch) == 1
            scheduled[g_var] = sch[0].mod["main"].with_attr("tir.is_scheduled", 1)

        for g_var, new_func in scheduled.items():
            mod[g_var] = new_func
        return mod


def _apply_rules(
    func: tir.PrimFunc,
    target: Target,
    rules: List[ScheduleRule],
    tunable: bool,
) -> Optional[List[tir.Schedule]]:
    """Return the result of the first rule that applies to ``func``.

    Parameters
    ----------
    func : tir.PrimFunc
        The PrimFunc to schedule.
    target : Target
        The compilation target, forwarded to each rule.
    rules : List[ScheduleRule]
        Candidate rules, tried in order.
    tunable : bool
        Forwarded to each rule unchanged.

    Returns
    -------
    space : Optional[List[tir.Schedule]]
        The output of the first rule that does not return None, with a single
        ``tir.Schedule`` wrapped into a one-element list. None if every rule returns None
        (or ``rules`` is empty).
    """
    for candidate in rules:
        outcome = candidate.apply(func, target, tunable)
        if outcome is not None:
            # Normalise a lone Schedule to a list; anything else is passed through as-is.
            return [outcome] if isinstance(outcome, tir.Schedule) else outcome
    return None
21 changes: 21 additions & 0 deletions python/tvm/dlight/gpu/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
GPU-generic schedule rules.
For CUDA-, ROCm-, Vulkan-, or Metal-specific rules, use the corresponding package
(`tvm.dlight.cuda`, `tvm.dlight.rocm`, `tvm.dlight.vulkan`, or `tvm.dlight.metal`) instead.
"""
from .fallback import Fallback
91 changes: 91 additions & 0 deletions python/tvm/dlight/gpu/fallback.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=missing-docstring
"""A fallback schedule rule for GPU operators."""
from typing import Callable, List

from tvm import tir
from tvm._ffi import get_global_func
from tvm.target import Target

from ..base import ScheduleRule


def _max_threads_per_block(target: Target) -> int:
    """Look up the thread-per-block limit of ``target``.

    The target attributes ``max_threads_per_block`` and ``max_num_threads`` are consulted in
    that order; the first one whose value is not None is returned as an ``int``. If neither
    is set, 64 is used.
    """
    for key in ("max_threads_per_block", "max_num_threads"):
        limit = target.attrs.get(key, None)
        if limit is not None:
            return int(limit)
    return 64


class Fallback(ScheduleRule):
    """
    A fallback schedule rule for all GPU operators. It will try to inline all the blocks first,
    and then apply a simple block/grid mapping to the spatial loops on top of the remaining blocks.
    """

    def apply(  # pylint: disable=too-many-locals,missing-docstring
        self,
        func: tir.PrimFunc,
        target: Target,
        _: bool,
    ) -> tir.Schedule:
        """Schedule ``func`` with the generic inline-then-bind strategy.

        Parameters
        ----------
        func : tir.PrimFunc
            The PrimFunc to schedule.
        target : Target
            The compilation target; only its thread-per-block limit is read.
        _ : bool
            The tunable flag. It is ignored: this rule never emits tunable instructions.

        Returns
        -------
        sch : tir.Schedule
            The schedule after inlining and thread binding. This rule never returns None.
        """
        max_threads_per_block = _max_threads_per_block(target)
        get_loop_iter_type = get_global_func("tir.schedule.GetLoopIterType")

        sch = tir.Schedule(func)
        blocks = sch.get_child_blocks(sch.get_block(sch.mod["main"].body.block.name_hint))

        def _try_inline(inline: Callable):
            """Apply ``inline`` to the first block that accepts it; return its index or None."""
            for i, block in enumerate(blocks):
                try:
                    inline(block)
                # A failed inlining attempt just means "not applicable to this block".
                # Catch Exception rather than everything so that KeyboardInterrupt and
                # SystemExit still propagate.
                except Exception:  # pylint: disable=broad-except
                    continue
                return i
            return None

        # Inline one block at a time until neither forward nor reverse inlining succeeds
        # on any remaining block. Each success removes that block from the work list.
        while True:
            i = _try_inline(sch.compute_inline)
            if i is None:
                i = _try_inline(sch.reverse_compute_inline)
            if i is None:
                break
            blocks.pop(i)

        for block in blocks:
            s_loops: List[tir.schedule.LoopRV] = []
            r_loops: List[tir.schedule.LoopRV] = []
            o_loops: List[tir.schedule.LoopRV] = []
            # The global function classifies each loop as "S", "R" or "O".
            for loop in sch.get_loops(block):
                iter_type = get_loop_iter_type(sch, loop)
                {"S": s_loops, "R": r_loops, "O": o_loops}[iter_type].append(loop)

            if not s_loops:
                # Guarantee there is at least one "S" loop to fuse and bind below.
                s_loops.append(sch.add_unit_loop(block))
            sch.reorder(*s_loops, *r_loops, *o_loops)
            bx, tx = sch.split(  # pylint: disable=invalid-name
                sch.fuse(*s_loops),
                factors=[None, max_threads_per_block],
            )
            sch.bind(bx, "blockIdx.x")
            sch.bind(tx, "threadIdx.x")
        return sch
12 changes: 12 additions & 0 deletions src/tir/schedule/analysis/analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2104,5 +2104,17 @@ TVM_REGISTER_GLOBAL("tir.schedule.IsOutputBlock").set_body_typed([](Schedule sch
return IsOutputBlock(state, block_sref, GetScopeRoot(state, block_sref, false));
});

// Expose the iteration kind of a loop through the global-function registry as a
// one-letter code: "S" for kDataPar, "R" for kCommReduce, "O" for every other kind.
TVM_REGISTER_GLOBAL("tir.schedule.GetLoopIterType")
    .set_body_typed([](Schedule sch, LoopRV loop) -> String {
      switch (GetLoopIterType(sch->GetSRef(loop))) {
        case kDataPar:
          return "S";
        case kCommReduce:
          return "R";
        default:
          return "O";
      }
    });

} // namespace tir
} // namespace tvm
Loading