Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CUTLASS] Profile only the largest-possible alignment by default #10036

Merged
merged 15 commits into from
Jan 26, 2022
46 changes: 31 additions & 15 deletions python/tvm/contrib/cutlass/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def select_gemm_kernel(
arg1_dtype,
use_3xtf32,
batched,
profile_all,
find_first_valid,
use_multiprocessing,
):
"""Run CUTLASS profiler to select the best kernel, or return the default one for dynamic
Expand All @@ -126,10 +126,10 @@ def select_gemm_kernel(
arg1_dtype,
use_3xtf32,
batched=batched,
profile_all=profile_all,
find_first_valid=find_first_valid,
use_multiprocessing=use_multiprocessing,
)
if profile_all:
if not find_first_valid:
logger.info("The best kernel is %s", name)
else:
logger.info("Picked the first kernel found %s", name)
Expand All @@ -146,7 +146,7 @@ def handle_batch_matmul(
arg0_dtype,
arg1_dtype,
use_3xtf32,
profile_all,
find_first_valid,
use_multiprocessing,
):
"""Profile and select a kernel for batch_matmul op workload."""
Expand All @@ -165,7 +165,7 @@ def handle_batch_matmul(
arg1_dtype,
use_3xtf32,
True,
profile_all,
find_first_valid,
use_multiprocessing,
)

Expand All @@ -191,7 +191,7 @@ def handle_dense(
arg0_dtype,
arg1_dtype,
use_3xtf32,
profile_all,
find_first_valid,
use_multiprocessing,
):
"""Profile and select a kernel for dense op workload."""
Expand All @@ -210,7 +210,7 @@ def handle_dense(
arg1_dtype,
use_3xtf32,
False,
profile_all,
find_first_valid,
use_multiprocessing,
)

Expand All @@ -237,7 +237,8 @@ def handle_conv2d(
data_dtype,
weight_dtype,
use_3xtf32,
profile_all,
profile_all_alignments,
find_first_valid,
use_multiprocessing,
):
"""Profile and select a kernel for conv2d op workload."""
Expand All @@ -257,10 +258,11 @@ def handle_conv2d(
data_dtype,
weight_dtype,
use_3xtf32,
profile_all=profile_all,
profile_all_alignments,
find_first_valid=find_first_valid,
use_multiprocessing=use_multiprocessing,
)
if profile_all:
if not find_first_valid:
logger.info("The best kernel is %s", name)
else:
logger.info("Picked the first kernel found %s", name)
Expand All @@ -272,7 +274,13 @@ def handle_conv2d(


def tune_cutlass_kernels(
mod, sm, use_3xtf32=True, profile_all=True, use_multiprocessing=False, tmp_dir="./tmp"
mod,
sm,
use_3xtf32=True,
profile_all_alignments=False,
find_first_valid=False,
use_multiprocessing=False,
tmp_dir="./tmp",
):
"""Given a module partitioned for CUTLASS offloading, profile each workload to select which
kernels to emit.
Expand All @@ -286,7 +294,14 @@ def tune_cutlass_kernels(
An integer specifying the compute capability. For example, 75 for Turing and
80 or 86 for Ampere.

profile_all : bool
use_3xtf32 : bool
Whether or not to use the slower but very accurate (compared to tf32) 3xtf32 mode for
fp32 inputs on tensorcore.

profile_all_alignments : bool
When True, profile all kernel variants with smaller alignments than the largest possible.

find_first_valid : bool
Whether to stop profiling after the first applicable kernel is found, instead of
profiling all candidate kernels to select the best one.

Expand Down Expand Up @@ -342,7 +357,8 @@ def tune_cutlass_kernels(
arg0_dtype,
arg1_dtype,
use_3xtf32,
profile_all,
profile_all_alignments,
find_first_valid,
use_multiprocessing,
)
)
Expand All @@ -357,7 +373,7 @@ def tune_cutlass_kernels(
arg0_dtype,
arg1_dtype,
use_3xtf32,
profile_all,
find_first_valid,
use_multiprocessing,
)
)
Expand All @@ -372,7 +388,7 @@ def tune_cutlass_kernels(
arg0_dtype,
arg1_dtype,
use_3xtf32,
profile_all,
find_first_valid,
use_multiprocessing,
)
)
Expand Down
36 changes: 17 additions & 19 deletions python/tvm/contrib/cutlass/gen_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
# under the License.
# pylint: disable=invalid-name
"""Conv2d kernel generator and profiler for CUTLASS."""
import re
from .conv2d_operation import Conv2dOperation, EmitConv2dInstance
from .gen_gemm import CutlassGemmProfiler
from .conv2d_profiler import Conv2dProfilerEmitter
Expand Down Expand Up @@ -168,14 +167,6 @@ def get_default(self, op_type, out_dtype, arg0_dtype, arg1_dtype, use_3xtf32):
)
return {"name": name, "opdef": opdef}

def check_align(self, op_name, C, K):
"""Filter out kernels that cannot be supported."""
match = re.match(".*_align([1-9]+)", op_name)
assert match is not None and len(match.groups()) == 1
# The same alignment is used for all axes
align = int(match.groups()[0])
return all([dim % align == 0 for dim in [C, K]])

def select_op(
self,
d_shape,
Expand All @@ -187,7 +178,8 @@ def select_op(
data_dtype,
weight_dtype,
use_3xtf32,
profile_all=True,
profile_all_alignments=False,
find_first_valid=False,
use_multiprocessing=False,
):
"""
Expand Down Expand Up @@ -216,12 +208,16 @@ def select_op(
return self.cache[workload]

ops = GENERATOR_FUNC_TABLE[self.sm](
out_dtype, data_dtype, weight_dtype, enumerate_conv2d_operators, use_3xtf32
out_dtype,
data_dtype,
weight_dtype,
enumerate_conv2d_operators,
lambda align: all([dim % align == 0 for dim in [IC, OC]]),
use_3xtf32,
profile_all_alignments,
)

ops = list(filter(lambda op: self.check_align(op["name"], IC, OC), ops))

if profile_all:
if not find_first_valid:
self.engine.compile_all(ops, use_multiprocessing)

args = (
Expand All @@ -232,7 +228,7 @@ def select_op(
for op in ops:
out = self.engine.evaluate(op, args.split(" "))
op["runtime"] = out
if out < float("inf") and not profile_all:
if out < float("inf") and find_first_valid:
self.cache[workload] = op
return op

Expand All @@ -252,11 +248,12 @@ def profile(
data_dtype,
weight_dtype,
use_3xtf32=True,
profile_all=True,
profile_all_alignments=False,
find_first_valid=False,
use_multiprocessing=False,
):
"""Profile and select the best kernel from candidate kernels.
If profile_all is False, return immediately after the first applicable kernel is found.
If find_first_valid is True, return immediately after the first applicable kernel is found.
If use_multiprocessing is True, compile all profiler executables in parallel.
"""
op = self.select_op(
Expand All @@ -269,8 +266,9 @@ def profile(
data_dtype,
weight_dtype,
use_3xtf32,
profile_all=profile_all,
use_multiprocessing=use_multiprocessing,
profile_all_alignments,
find_first_valid,
use_multiprocessing,
)

name, opdef = create_conv2d_operator_with_epilogue(
Expand Down
50 changes: 27 additions & 23 deletions python/tvm/contrib/cutlass/gen_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
# under the License.
# pylint: disable=invalid-name
"""GEMM kernel generator and profiler for CUTLASS."""
import re
from .gemm_operation import GemmOperation, EmitGemmInstance
from .gemm_profiler import GemmProfilerEmitter
from .gen_tensor_op import ProfilerEngine, GENERATOR_FUNC_TABLE, EPILOGUE_MAP
Expand Down Expand Up @@ -63,8 +62,9 @@ def create_gemm_operator_with_epilogue(
swizzling_functor,
)

return op.procedural_name(), EmitGemmInstance().emit(
op, no_beta_scaling=no_beta_scaling, batched=batched
return (
op.procedural_name(),
EmitGemmInstance().emit(op, no_beta_scaling=no_beta_scaling, batched=batched),
)


Expand Down Expand Up @@ -150,26 +150,22 @@ def __init__(self, sm, cutlass_path, binary_path):
self.sm = sm
self.cache = {}

def check_align(self, op_name, M, N, K):
"""Filter out kernels that cannot be supported."""
match = re.match(".*_align([1-9]+)", op_name)
assert match is not None and len(match.groups()) == 1
# The same alignment is used for all axes
align = int(match.groups()[0])
# TODO(masahi): CUTLASS alignment check on gemm kernels is too restrictive.
# See https://github.com/NVIDIA/cutlass/issues/362.
# When the above issue is resolved, we can remove the alignment check on M below.
return all([dim % align == 0 for dim in [M, N, K]])

def get_default(
self, op_type, out_dtype, arg0_dtype, arg1_dtype, use_3xtf32=True, batched=False
):
"""Return the default kernel for the requested architecture.
For now, the default kernel was picked arbitrary.
"""
ops = GENERATOR_FUNC_TABLE[self.sm](
out_dtype, arg0_dtype, arg1_dtype, enumerate_gemm_operators, use_3xtf32
out_dtype,
arg0_dtype,
arg1_dtype,
enumerate_gemm_operators,
lambda align: align == 1, # Only request align1 kernels
use_3xtf32,
profile_all_alignments=True, # To include all align1 kernels
)

default_kernel_name = DEFAULT_KERNELS[self.sm][(arg0_dtype, out_dtype)]

if arg0_dtype == "float32":
Expand Down Expand Up @@ -200,7 +196,8 @@ def select_op(
arg0_dtype,
arg1_dtype,
use_3xtf32,
profile_all=True,
profile_all_alignments=False,
find_first_valid=False,
use_multiprocessing=False,
):
"""
Expand All @@ -211,22 +208,27 @@ def select_op(
op = self.cache[(M, N, K)]
return op

# TODO(masahi): CUTLASS alignment check on gemm kernels is too restrictive.
# See https://github.com/NVIDIA/cutlass/issues/362.
# When the above issue is resolved, we can remove the alignment check on M below.

ops = GENERATOR_FUNC_TABLE[self.sm](
out_dtype,
arg0_dtype,
arg1_dtype,
enumerate_gemm_operators,
use_3xtf32=use_3xtf32,
lambda align: all([dim % align == 0 for dim in [M, N, K]]),
use_3xtf32,
profile_all_alignments=profile_all_alignments,
)
ops = list(filter(lambda op: self.check_align(op["name"], M, N, K), ops))

if profile_all:
if not find_first_valid:
self.engine.compile_all(ops, use_multiprocessing)

for op in ops:
out = self.engine.evaluate(op, [M, N, K])
op["runtime"] = out
if out < float("inf") and not profile_all:
if out < float("inf") and find_first_valid:
self.cache[(M, N, K)] = op
return op

Expand All @@ -244,12 +246,13 @@ def profile(
arg0_dtype,
arg1_dtype,
use_3xtf32=True,
profile_all=True,
profile_all_alignments=False,
find_first_valid=False,
use_multiprocessing=False,
batched=False,
):
"""Profile and select the best kernel from candidate kernels.
If profile_all is False, return immediately after the first applicable kernel is found.
If find_first_valid is True, return immediately after the first applicable kernel is found.
If use_multiprocessing is True, compile all profiler executables in parallel.
"""
op = self.select_op(
Expand All @@ -260,7 +263,8 @@ def profile(
arg0_dtype,
arg1_dtype,
use_3xtf32,
profile_all=profile_all,
profile_all_alignments=profile_all_alignments,
find_first_valid=find_first_valid,
use_multiprocessing=use_multiprocessing,
)

Expand Down
Loading