[Dlight] Enhance Decode-GEMV Rules
This PR enhances the Decode-GEMV rule with the following changes:
- Normalize the GEMV iter domain to S-R-C via transform-block-layout (see
  the first sketch after this list). This helps further analysis and
  scheduling, for example when the original reduction block has no spatial
  loop.
- Get rid of the ad hoc iter-type analysis, including the logic that called
  into the TVM packed func `tir.schedule.GetLoopIterType` via
  `tvm._ffi.get_global_func`.
- Split the scheduling logic into two separate cases, depending on whether
  the innermost dimension is spatial or reduction.
- Introduce `suggest_threads_per_block` to estimate the number of threads to
  allocate per threadblock (see the second sketch after this list). This
  avoids the previous behavior where dlight allocated 256 threads for a
  workload whose degree of parallelism is only 128.
- Misc improvements.
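
For illustration, here is a minimal, self-contained sketch of what the
normalization amounts to. This is not code from the PR: the actual rule builds
the `tir.IndexMap` programmatically in `_normalize` (see the diff below) and
also appends a unit C axis, and the workload shape and names here are made up.

```python
import tvm
from tvm.script import tir as T


@T.prim_func
def gemv(a: T.handle, w: T.handle, o: T.handle):
    A = T.match_buffer(a, (4096,), "float16")
    W = T.match_buffer(w, (32, 128, 4096), "float16")
    O = T.match_buffer(o, (32, 128), "float16")
    for i, j, k in T.grid(32, 128, 4096):
        with T.block("gemv"):
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            with T.init():
                O[vi, vj] = T.float16(0)
            O[vi, vj] = O[vi, vj] + W[vi, vj, vk] * A[vk]


sch = tvm.tir.Schedule(gemv)
block = sch.get_block("gemv")
# Fuse the two spatial iters into one S axis and keep R last, so the block's
# iter domain becomes the normalized (S, R) form that later scheduling assumes.
sch.transform_block_layout(block, lambda vi, vj, vk: (vi * 128 + vj, vk))
print(sch.mod.script())
```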
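The real `suggest_threads_per_block` lives in `python/tvm/dlight/gpu/utils.py`,
which this page does not show; the sketch below is only a guess at its core
idea, capping the thread budget by each loop's actual extent.

```python
from typing import List, Optional


def suggest_threads_per_block_sketch(
    max_threads: int, loop_extents: List[Optional[int]]
) -> List[int]:
    """Distribute at most `max_threads` threads over the given loops.

    `loop_extents` holds static extents, with None for dynamic loops.
    """
    budget = max_threads
    result = []
    for extent in loop_extents:
        if extent is None:
            # A dynamic loop absorbs whatever budget is left.
            result.append(budget)
            budget = 1
            continue
        # Use the largest power of two that exceeds neither the loop's
        # extent nor the remaining thread budget.
        use = 1
        while use * 2 <= min(extent, budget):
            use *= 2
        result.append(use)
        budget = max(budget // use, 1)
    return result


# A 256-thread budget collapses to 128 threads when the only parallel loop
# has 128 iterations:
assert suggest_threads_per_block_sketch(256, [128]) == [128]
```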

The rest of the changes were split out into separate PRs that are already
merged to main:
- [x] Pass hints to the arithmetic analyzer that shape variables are
positive (apache#15210); see the sketch after this list
- [x] Eliminate unnecessary block predicate generation; the predicates should
be provable via affine analysis (apache#15193)
- [x] Shrink local memory allocation if only one element `X[threadIdx.x]`
is used (apache#15207)
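
The exact wiring of apache#15210 is not shown in this diff; the snippet below
is only a generic illustration of why positivity hints matter to `tvm.arith`.

```python
from tvm import arith, tir

analyzer = arith.Analyzer()
n = tir.Var("n", "int32")  # a symbolic shape variable

# Without any hint, the analyzer must assume `n` could be zero or negative.
print(analyzer.can_prove(n * 8 > 0))  # False

# Hint that `n` is a shape extent, hence at least 1.
analyzer.update(n, arith.ConstIntBound(1, arith.ConstIntBound.POS_INF))
print(analyzer.can_prove(n * 8 > 0))  # True
```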
junrushao committed Jul 4, 2023
1 parent 780d6e6 commit 0d8ff66
Showing 8 changed files with 273 additions and 142 deletions.
4 changes: 4 additions & 0 deletions pyproject.toml
@@ -14,6 +14,10 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
+[tool.isort]
+profile = "black"
+src_paths = ["python", "tests/python"]
+
 
 [tool.black]
 line-length = 100
3 changes: 1 addition & 2 deletions python/tvm/dlight/base/analysis.py
@@ -17,13 +17,12 @@
 """Analysis on TIR blocks, loops and functions."""
 from typing import List, Optional, Union
 
-from typing_extensions import Literal
-
 from tvm import tir
 from tvm._ffi import get_global_func
 from tvm.target.target import Target
 from tvm.tir import Schedule
 from tvm.tir.schedule import BlockRV
+from typing_extensions import Literal
 
 
 class IterInfo:
4 changes: 2 additions & 2 deletions python/tvm/dlight/gpu/__init__.py
@@ -18,7 +18,7 @@
 GPU-generic schedule rules.
 For CUDA/ROCm/Vulkan/Metal-specific rules, use `tvm.dlight.cuda/rocm/vulkan/metal` instead
 """
-from .fallback import Fallback
 from .decode_gemv import DecodeGEMV
-from .reduction import Reduction
+from .fallback import Fallback
 from .matmul import Matmul
+from .reduction import Reduction
255 changes: 147 additions & 108 deletions python/tvm/dlight/gpu/decode_gemv.py
@@ -14,19 +14,20 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=missing-docstring
-"""A fallback schedule rule for GPU operators."""
+# pylint: disable=invalid-name
+"""A rule for DecodeGEMV."""
+from typing import List, Optional, Set, Tuple, Union
 
-from typing import List, Optional, Union
-
-from tvm import tir
-from tvm._ffi import get_global_func
-from tvm.arith import normalize_to_iter_sum
+from tvm import arith, tir
 from tvm.ir import structural_equal
 from tvm.target import Target
 
-from ..base import ScheduleRule, normalize_prim_func, try_inline_contiguous_spatial
+from ..base import (
+    BlockInfo,
+    ScheduleRule,
+    normalize_prim_func,
+    try_inline_contiguous_spatial,
+)
+from . import utils
 
 
 def _get_reduction_expr(block: tir.Block) -> Optional[tir.PrimExpr]:
@@ -47,13 +48,13 @@ def _get_reduction_expr(block: tir.Block) -> Optional[tir.PrimExpr]:
 
 def _detect_dominant_read(block: tir.Block) -> tir.PrimExpr:
     dominant_read, read_iters = None, None
-    tir_vars = set()
+    tir_vars: Set[tir.Var] = set()
     for buffer_region in block.reads:
         tir_vars.clear()
 
-        def _collect_tir_var(e):
-            if isinstance(e, tir.Var):
-                tir_vars.add(e)
+        def _collect_tir_var(expr):
+            if isinstance(expr, tir.Var):
+                tir_vars.add(expr)
 
         for expr in buffer_region.region:
             assert expr.extent == 1
@@ -68,27 +69,18 @@ def _collect_tir_var(e):
 
 
 class DecodeGEMV(ScheduleRule):
-    def __init__(self) -> None:
-        super().__init__()
-        self.get_loop_iter_type = get_global_func("tir.schedule.GetLoopIterType")
+    """A rule for DecodeGEMV."""
 
-    def apply(  # pylint: disable=too-many-locals
+    def apply(  # pylint: disable=too-many-locals,too-many-branches,too-many-return-statements
         self,
         func: tir.PrimFunc,
         target: Target,
         _: bool,
     ) -> Union[None, tir.Schedule, List[tir.Schedule]]:
         if not isinstance(func, tir.PrimFunc):
             return None
-
-        if target.kind.name == "cuda":
-            len_tx, len_ty = 16, 16
-        else:
-            len_tx, len_ty = 8, 8
-
         sch = tir.Schedule(func)
         block_infos = try_inline_contiguous_spatial(sch, normalize_prim_func(sch))
-
         if block_infos is None or len(block_infos) > 2:
             return None
@@ -97,96 +89,143 @@ def apply(  # pylint: disable=too-many-locals
         block_stmt = sch.get(block)
 
         # Step 1. Check reduction block
-        if not block_info.is_reduction():
+        if (
+            (not block_info.is_reduction())
+            or len(block_stmt.writes) != 1
+            or _get_reduction_expr(block_stmt) is None
+        ):
             return None
-        if len(block_stmt.writes) != 1:
-            return None
-        if _get_reduction_expr(block_stmt) is None:
-            return None
 
-        # Step 2. Sort out the spatial and reduction loops
-        sorted_iter_access = normalize_to_iter_sum(
-            _detect_dominant_read(block_stmt),
-            input_iters={i.var: i.dom for i in block_stmt.iter_vars},
+        # Step 2. Normalize the block, merge spatial and reduction iters
+        is_inner_reduction, c_factor = self._normalize(
+            sch,
+            block_info,
+            arith.normalize_to_iter_sum(
+                _detect_dominant_read(block_stmt),
+                input_iters={i.var: i.dom for i in block_stmt.iter_vars},
+            ),
         )
-        if sorted_iter_access.base != 0:
-            return None
-        iter_to_info = {i.var: i for i in block_info.iters}
-        s_loops, r_loops, c_loops = [], [], []
-        for split in sorted_iter_access.args:
-            block_var = split.source.source
-            block_var_info = iter_to_info[block_var]
-            loop_rv = block_var_info.loop_rv
-            is_inner_reduction = block_var_info.kind == "R"
-            if split.lower_factor > 1:
-                c_loop_factor = split.lower_factor
-                loop_rv, c_loop = sch.split(loop_rv, factors=[None, c_loop_factor])
-                c_loops.append(c_loop)
-                is_loop_c_reduction = is_inner_reduction
-            if is_inner_reduction:
-                r_loops.append(loop_rv)
-            else:
-                s_loops.append(loop_rv)
-
-        if len(c_loops) > 1:
-            return None
-        if len(s_loops) != len([_ for i in block_info.iters if i.kind == "S"]):
+        if is_inner_reduction is None and c_factor is None:
             return None
-        if len(s_loops) == 0 or len(r_loops) == 0:
-            return None
-
-        sch.reorder(*s_loops, *r_loops, *c_loops)
-        s = sch.fuse(*s_loops)
-        r = sch.fuse(*r_loops)
-
-        if is_inner_reduction:
-            _, tx = sch.split(r, factors=[None, len_tx * len_ty])
-            rf = sch.rfactor(tx, 0)
-            s, r, tx = sch.get_loops(rf)[:3]
-            sch.reorder(s, tx, r)
-            sch.reverse_compute_at(block, s, preserve_unit_loops=True)
-            sch.bind(tx, "threadIdx.x")
-            sch.bind(s, "blockIdx.x")
-        else:
-            sch.split(s, factors=[None, len_tx])
-            _, ty = sch.split(r, factors=[None, len_ty])
-            rf = sch.rfactor(ty, 0)
-            bx, tx, r, ty = sch.get_loops(rf)[:4]
-            sch.reorder(bx, tx, ty, r)
-            sch.reverse_compute_at(block, bx, preserve_unit_loops=True)
-            sch.bind(tx, "threadIdx.x")
-            sch.bind(ty, "threadIdx.y")
-            sch.bind(bx, "blockIdx.x")
-
-        s_loops, r_loops = [], []
-        for loop_rv in sch.get_loops(block)[1:]:
-            iter_type = self.get_loop_iter_type(sch, loop_rv)
-            if iter_type == "S":
-                s_loops.append(loop_rv)
-            elif iter_type == "R":
-                r_loops.append(loop_rv)
-            else:
-                raise RuntimeError("Unknown loop type " + str(iter_type))
-        sch.reorder(*s_loops, *r_loops)
-        s_ctr = sch.fuse(*s_loops)
-        r_ctr = sch.fuse(*r_loops)
-
-        if c_loops and not is_loop_c_reduction:
-            s_ctr, inner = sch.split(s_ctr, factors=[None, c_loop_factor])
-            sch.reorder(s_ctr, r_ctr, inner)
-
+        # Step 3. Do the scheduling
         if is_inner_reduction:
-            sch.bind(r_ctr, "threadIdx.x")
-            sch.set_scope(rf, 0, "local")
-            sch.decompose_reduction(rf, sch.get_loops(rf)[2])
+            self._sch_inner_reduction(sch, target, block, c_factor)
         else:
-            sch.bind(s_ctr, "threadIdx.x")
-            sch.bind(r_ctr, "threadIdx.y")
-            sch.set_scope(rf, 0, "local")
-            sch.decompose_reduction(rf, sch.get_loops(rf)[3])
-
+            self._sch_inner_spatial(sch, target, block, c_factor)
+        # Step 4. Schedule epilogue
         if len(block_infos) == 2:
             sch.set_scope(block, 0, "local")
             sch.reverse_compute_at(block_infos[1].block_rv, sch.get_loops(block)[0])
 
         return sch
 
+    def _normalize(
+        self,
+        sch: tir.Schedule,
+        block_info: BlockInfo,
+        iter_sum: arith.IterSumExpr,
+    ) -> Tuple[Optional[bool], Optional[int]]:
+        if iter_sum.base != 0:
+            return None, None
+        iter_to_info = {i.var: i for i in block_info.iters}
+        s_dom, r_dom, c_dom, c_factor = None, None, None, None
+        for split in iter_sum.args:
+            var = split.source.source
+            info = iter_to_info[var]
+            dom = info.dom
+            is_inner_reduction = info.kind == "R"
+            if split.lower_factor > 1:
+                if c_dom is not None:
+                    return None, None
+                c_dom = tir.floormod(var, split.lower_factor)
+                var = tir.floordiv(var, split.lower_factor)
+                dom = tir.floordiv(dom, split.lower_factor)
+                if not is_inner_reduction:
+                    c_factor = split.lower_factor
+            if is_inner_reduction:
+                if r_dom is None:
+                    r_dom = var
+                else:
+                    r_dom = r_dom * dom + var
+            else:
+                if s_dom is None:
+                    s_dom = var
+                else:
+                    s_dom = s_dom * dom + var
+
+        assert r_dom is not None
+        if s_dom is None:
+            s_dom = tir.const(1, r_dom.dtype)
+        if c_dom is None:
+            c_dom = tir.const(1, r_dom.dtype)
+        sch.transform_block_layout(
+            block_info.block_rv,
+            tir.IndexMap(
+                [i.var for i in block_info.iters],
+                [s_dom, r_dom, c_dom],
+                None,
+            ),
+        )
+        return is_inner_reduction, c_factor
+
+    def _sch_inner_reduction(
+        self,
+        sch: tir.Schedule,
+        target: Target,
+        block: tir.schedule.BlockRV,
+        unroll_spatial_factor: Optional[int],
+    ):
+        # pylint: disable=invalid-name
+        _, r, _ = sch.get_loops(block)
+        (len_tx,) = utils.suggest_threads_per_block(  # pylint: disable=unbalanced-tuple-unpacking
+            target, [sch.get(r)]
+        )
+
+        _, tx = sch.split(r, factors=[None, len_tx])
+        rf = sch.rfactor(tx, 0)
+        s, r, tx = sch.get_loops(rf)[:3]
+        sch.reorder(s, tx, r)
+        sch.reverse_compute_at(block, s, preserve_unit_loops=True)
+        sch.bind(tx, "threadIdx.x")
+        sch.bind(s, "blockIdx.x")
+
+        _, r, *s = sch.get_loops(block)
+        s = sch.fuse(*s)
+        sch.reorder(s, r)
+        if unroll_spatial_factor:
+            s, inner = sch.split(s, factors=[None, unroll_spatial_factor])
+            sch.reorder(s, r, inner)
+        sch.bind(r, "threadIdx.x")
+        sch.set_scope(rf, 0, "local")
+        sch.decompose_reduction(rf, sch.get_loops(rf)[2])
+        # pylint: enable=invalid-name
+
+    def _sch_inner_spatial(
+        self,
+        sch: tir.Schedule,
+        _: Target,
+        block: tir.schedule.BlockRV,
+        unroll_spatial_factor: Optional[int],
+    ):
+        # pylint: disable=invalid-name
+        s, r, _ = sch.get_loops(block)
+        len_tx, len_ty = 16, 16
+        sch.split(s, factors=[None, len_tx])
+        _, ty = sch.split(r, factors=[None, len_ty])
+        rf = sch.rfactor(ty, 0)
+        bx, tx, r, ty = sch.get_loops(rf)[:4]
+        sch.reorder(bx, tx, ty, r)
+        sch.reverse_compute_at(block, bx, preserve_unit_loops=True)
+        sch.bind(tx, "threadIdx.x")
+        sch.bind(ty, "threadIdx.y")
+        sch.bind(bx, "blockIdx.x")
+
+        _, r, *s = sch.get_loops(block)
+        s = sch.fuse(*s)
+        sch.reorder(s, r)
+        if unroll_spatial_factor:
+            s, inner = sch.split(s, factors=[None, unroll_spatial_factor])
+            sch.reorder(s, r, inner)
+        sch.bind(s, "threadIdx.x")
+        sch.bind(r, "threadIdx.y")
+        sch.set_scope(rf, 0, "local")
+        sch.decompose_reduction(rf, sch.get_loops(rf)[3])
+        # pylint: enable=invalid-name
5 changes: 3 additions & 2 deletions python/tvm/dlight/gpu/fallback.py
@@ -21,7 +21,8 @@
 from tvm import tir
 from tvm.target import Target
 
-from ..base import ScheduleRule, analysis, normalize_prim_func, try_inline
+from ..base import ScheduleRule, normalize_prim_func, try_inline
+from . import utils
 
 
 class Fallback(ScheduleRule):
@@ -36,7 +37,7 @@ def apply(  # pylint: disable=too-many-locals,missing-docstring
         target: Target,
         _: bool,
     ) -> tir.Schedule:
-        max_threads_per_block = analysis.get_max_threads_per_block(target)
+        max_threads_per_block = utils.max_threads_per_block(target)
 
         sch = tir.Schedule(func)
         block_infos = try_inline(sch, normalize_prim_func(sch))
4 changes: 2 additions & 2 deletions python/tvm/dlight/gpu/matmul.py
@@ -16,14 +16,14 @@
 # under the License.
 # pylint: disable=missing-docstring, invalid-name
 """A GEMM schedule rule for GPU operators."""
-from enum import Enum
 from dataclasses import dataclass
+from enum import Enum
 from typing import Dict, List, Optional, Set, Tuple
 
 from tvm import tir
 from tvm.ir import Range
 from tvm.target import Target
-from tvm.tir import PrimExpr, Var, IterVar
+from tvm.tir import IterVar, PrimExpr, Var
 from tvm.tir.analysis import undefined_vars
 from tvm.tir.schedule.schedule import BlockRV