Inductor: Allow small sizes of m for mixed mm autotuning (#127663)
For mixed mm with small sizes of m, such as the example provided in #127056, being able to set BLOCK_M to 16 leads to better performance. This PR introduces kernel configs specific to mixed mm by extending the mm configs with two additional configs that work well for that example.
The config (BLOCK_M=16, BLOCK_K=16, BLOCK_N=64) is excluded because Triton crashes when it is used.

For the example in #127056 (a sketch of the pattern follows this list):
- Without the changes in this PR, skip_triton evaluates to true, which disables autotuning; on my machine I achieve 146 GB/s.
- With autotuning enabled but BLOCK_M >= 32, I achieve 614 GB/s.
- With the changes in this PR (autotuning enabled and BLOCK_M=16), I achieve 772 GB/s.
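The snippet below is a minimal sketch, not part of the PR, of the kind of mixed mm this change targets: an fp16 activation multiplied by an int8 weight whose cast is folded into the matmul, using the small-m shape (16, 8192) x (8192, 8192) referenced in the new config comment. The function name `mixed_mm` and the use of `mode="max-autotune"` are assumptions made for the sketch; the actual lowering is handled by Inductor's `tuned_mixed_mm` shown in the diff below.

```python
import torch

# Illustrative small-m mixed mm: fp16 activation x int8 weight, with the cast
# folded into the matmul. Shapes follow the (16, 8192) x (8192, 8192) example.
@torch.compile(mode="max-autotune")
def mixed_mm(a, b):
    return torch.mm(a, b.to(a.dtype))

a = torch.randn(16, 8192, dtype=torch.float16, device="cuda")
b = torch.randint(-128, 127, (8192, 8192), dtype=torch.int8, device="cuda")
out = mixed_mm(a, b)
```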

Pull Request resolved: #127663
Approved by: https://github.com/Chillee
AlnisM authored and pytorchmergebot committed Jun 3, 2024
1 parent 7c3740d commit d8d0bf2
Showing 2 changed files with 39 additions and 8 deletions.
15 changes: 12 additions & 3 deletions torch/_inductor/kernel/mm.py
@@ -26,6 +26,7 @@
from .mm_common import (
addmm_epilogue,
int8_mm_configs,
mixed_mm_configs,
mm_args,
mm_configs,
mm_grid,
@@ -407,15 +408,23 @@ def tuned_mixed_mm(mat1, mat2, mat2_dtype):

# can't use triton kernel unless one of these is true or if running on v100 (numerical issues)
skip_triton = (
mat1.layout.dtype != torch.float32 and not mat2.layout.is_contiguous()
mat1.layout.dtype != torch.float32
and not (mat2.layout.is_contiguous() or mat2.layout.is_transposed())
) or _is_sm7x_or_older_gpu(layout.device.index)

if inductor_config.force_mixed_mm:
choices = []
if not skip_triton:
b_prologue_cast_type = f"tl.{mat2_dtype}".replace("torch.", "")
has_int8_tensor = _is_int8_mat(mat1) or _is_int8_mat(mat2)
for config in mm_configs(m, n, k, has_int8_tensor=has_int8_tensor):
for config in mixed_mm_configs(m, n, k):
# skipping this config because triton crashes on it
# See: https://github.com/triton-lang/triton/issues/2156#issuecomment-1695897424
if (
config.kwargs["BLOCK_M"] == 16
and config.kwargs["BLOCK_K"] == 16
and config.kwargs["BLOCK_N"] == 64
):
continue
mm_template.maybe_append_choice(
choices,
input_nodes=(mat1, mat2),
32 changes: 27 additions & 5 deletions torch/_inductor/kernel/mm_common.py
@@ -27,14 +27,10 @@ def filtered_configs(
n: int,
k: int,
configs: List[Tuple[int, int, int, int, int]],
has_int8_tensor=False,
):
"""Heuristic to shrink configs when they are bigger than the input size"""

# According to https://github.com/openai/triton/issues/2156#issuecomment-1695897424
# it's safer to use at least [32, 32] block size for int8/uint8
# tensors
min_block_size = 32 if has_int8_tensor else 16
min_block_size = 16
m = max(
next_power_of_2(
V.graph.sizevars.size_hint(
@@ -166,6 +162,18 @@ def filtered_configs(
{"config": (256, 128, 128, 3, 8), "cond": torch.version.hip is None},
]

# Mixed precision kernel configs for small sizes of m for mm's like (16, 8192) x (8192, 8192).
mixed_mm_kernel_configs_small_m = [
{"config": (16, 128, 256, 3, 4), "cond": True},
{"config": (16, 128, 256, 5, 8), "cond": True},
]

mixed_mm_kernel_configs = (
mm_kernel_configs + mixed_mm_kernel_configs_small_m
if inductor_config.max_autotune_gemm_search_space != "EXHAUSTIVE"
else mm_kernel_configs
)

# Create filtered list of configs based on cond evaluation


@@ -179,6 +187,11 @@ def filtered_configs(
for config in int8_mm_kernel_configs
if config["cond"]
)
mixed_mm_platform_configs = tuple(
cast(Tuple[int, int, int, int, int], config["config"])
for config in mixed_mm_kernel_configs
if config["cond"]
)

# On ROCm convert num_stages to 0 to enable software pipelining
if torch.version.hip:
@@ -190,6 +203,10 @@ def filtered_configs(
(config[0], config[1], config[2], 0, config[4])
for config in mm_platform_configs
)
mixed_mm_platform_configs = tuple(
(config[0], config[1], config[2], 0, config[4])
for config in mixed_mm_platform_configs
)

mm_configs = functools.partial(
filtered_configs,
@@ -201,6 +218,11 @@ def filtered_configs(
configs=int8_platform_configs,
)

mixed_mm_configs = functools.partial(
filtered_configs,
configs=mixed_mm_platform_configs,
)


def mm_grid(m, n, meta):
"""
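For context, here is a self-contained sketch, not the Inductor source, of the selection pattern the mm_common.py hunks extend: candidate (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps) tuples are filtered by their "cond" flag and bound to the shared filtered_configs helper via functools.partial, which is how the new mixed_mm_configs entry point is produced. The shrinking heuristic below is a simplified stand-in for the real one.

```python
import functools
from typing import List, Tuple

# Candidate configs: (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps).
mixed_mm_kernel_configs_small_m = [
    {"config": (16, 128, 256, 3, 4), "cond": True},
    {"config": (16, 128, 256, 5, 8), "cond": True},
]


def filtered_configs(
    m: int, n: int, k: int, configs: List[Tuple[int, int, int, int, int]]
):
    """Simplified stand-in for the real heuristic ("shrink configs when they are
    bigger than the input size"): round each dimension up to a power of two,
    clamp to a minimum block size of 16, cap block sizes at that, de-duplicate."""

    def clamp(x: int) -> int:
        return max(1 << max(x - 1, 1).bit_length(), 16)

    m, n, k = clamp(m), clamp(n), clamp(k)
    seen, out = set(), []
    for block_m, block_n, block_k, num_stages, num_warps in configs:
        cfg = (min(block_m, m), min(block_n, n), min(block_k, k), num_stages, num_warps)
        if cfg not in seen:
            seen.add(cfg)
            out.append(cfg)
    return out


mixed_mm_platform_configs = tuple(
    c["config"] for c in mixed_mm_kernel_configs_small_m if c["cond"]
)
mixed_mm_configs = functools.partial(
    filtered_configs, configs=mixed_mm_platform_configs
)

# For the (16, 8192) x (8192, 8192) example, both small-m configs survive
# with BLOCK_M kept at 16.
print(mixed_mm_configs(16, 8192, 8192))
```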

1 comment on commit d8d0bf2

@pytorchmergebot
Collaborator


Reverted #127663 on behalf of https://github.com/soulitzer because it breaks torch ao CI; see #127924 (comment).
