336 changes: 336 additions & 0 deletions examples/convolution/example_convolution_autotune.py
@@ -0,0 +1,336 @@
# Copyright (c) Tile-AI Corporation.
# Licensed under the MIT License.
import torch
import argparse
import itertools
import tilelang  # the module itself is referenced below via tilelang.layout
import tilelang as tl
from tilelang.autotuner import *
import tilelang.language as T
from tilelang.autotuner import AutoTuner
from tilelang.carver.template import ConvTemplate
from tilelang.carver.arch import CUDA
from tilelang.carver.roller.rasterization import NoRasterization


def check_hopper():
    if not torch.cuda.is_available():
        return None
    props = torch.cuda.get_device_properties(0)
    compute_capability = props.major, props.minor
    return compute_capability == (9, 0)


def ref_program(stride, padding, dilation):

    def main(A, B):
        A = A.permute(0, 3, 1, 2)  # N, H, W, C -> N, C, H, W
        B = B.permute(3, 2, 0, 1)  # H, W, C, F -> F, C, H, W
        C = torch.conv2d(A, B, stride=stride, padding=padding, dilation=dilation)
        C = C.permute(0, 2, 3, 1)  # N, C, H, W -> N, H, W, C
        return C

    return main


def get_configs(N, C, H, W, F, K, S, D, P, with_roller=False, topk=15):
    if with_roller:
        arch = CUDA("cuda")
        carve_template = ConvTemplate(
            N=N,
            C=C,
            H=H,
            W=W,
            F=F,
            K=K,
            S=S,
            D=D,
            P=P,
            in_dtype="float16",
            out_dtype="float16",
            accum_dtype="float",
        ).with_arch(arch)

        func = carve_template.equivalent_function()
        assert func is not None, "Function is None"
        roller_hints = carve_template.recommend_hints(topk=topk)
        if roller_hints is None:
            raise ValueError("No Roller Hints Found for TensorCore Scheduling")
        configs = []
        for hint in roller_hints:
            config = {}
            block_m, block_n = hint.block
            warp_m, warp_n = hint.warp
            # block_rows, block_cols represents warp partitioning
            block_rows, block_cols = block_m // warp_m, block_n // warp_n
            config["block_M"] = block_m
            config["block_N"] = block_n
            config["block_K"] = hint.rstep[0]
            config["num_stages"] = hint.pipeline_stage if hint.pipeline_stage > 1 else 0
            config["thread_num"] = block_rows * block_cols * 32
            config["enable_rasteration"] = hint.rasterization_plan is not NoRasterization
            configs.append(config)
    else:
        block_M = [64, 128, 256]
        block_N = [64, 128, 256]
        block_K = [32, 64]
        num_stages = [0, 1, 2, 3]
        thread_num = [128, 256]
        enable_rasterization = [True, False]
        _configs = list(
            itertools.product(
                block_M,
                block_N,
                block_K,
                num_stages,
                thread_num,
                enable_rasterization,
            ))

        configs = [
            {
                "block_M": c[0],
                "block_N": c[1],
                "block_K": c[2],
                "num_stages": c[3],
                "thread_num": c[4],
                "enable_rasteration": c[5],  # keep param name for backward-compat
            } for c in _configs
        ]
    return configs


def get_best_config(N, C, H, W, F, K, S, D, P, ref_prog, with_roller=False):

    def kernel(
        block_M=None,
        block_N=None,
        block_K=None,
        num_stages=None,
        thread_num=None,
        enable_rasteration=None,
    ):
        dtype = "float16"
        accum_dtype = "float"
        KH, KW = K, K
        OH = (H + 2 * P - D * (K - 1) - 1) // S + 1
        OW = (W + 2 * P - D * (K - 1) - 1) // S + 1
        is_hopper = check_hopper()

        @T.prim_func
        def main(
                data: T.Tensor((N, H, W, C), dtype),
                kernel: T.Tensor((KH, KW, C, F), dtype),
                out: T.Tensor((N, OH, OW, F), dtype),
        ):
            with T.Kernel(
                    T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M),
                    threads=thread_num) as (bx, by):
                data_shared = T.alloc_shared((block_M, block_K), dtype)
                kernel_shared = T.alloc_shared((block_K, block_N), dtype)
                out_local = T.alloc_fragment((block_M, block_N), accum_dtype)
                out_shared = T.alloc_shared((block_M, block_N), dtype)

                kernel_flat = T.Tensor((KH * KW * C, F), dtype, kernel.data)
                out_flat = T.Tensor((N * OH * OW, F), dtype, out.data)

                T.annotate_layout({
                    out_shared: tilelang.layout.make_swizzled_layout(out_shared),
                    data_shared: tilelang.layout.make_swizzled_layout(data_shared),
                    kernel_shared: tilelang.layout.make_swizzled_layout(kernel_shared),
                })

                T.clear(out_local)
                for k_iter in T.Pipelined(T.ceildiv(KH * KW * C, block_K), num_stages=num_stages):
                    if is_hopper:
                        T.c2d_im2col(data, data_shared, by, k_iter, KH, S, D, P)
                    else:
                        for i, j in T.Parallel(block_M, block_K):
                            k = k_iter * block_K + j
                            m = by * block_M + i
                            access_h = m % (OH * OW) // OW * S + k // (KW * C) * D - P
                            access_w = m % OW * S + k // C % KW * D - P
                            in_bound = ((access_h >= 0) and (access_w >= 0) and (access_h < H) and
                                        (access_w < W))
                            data_shared[i, j] = T.if_then_else(
                                in_bound, data[m // (OH * OW), access_h, access_w, k % C], 0)
                    T.copy(kernel_flat[k_iter * block_K, bx * block_N], kernel_shared)
                    T.gemm(data_shared, kernel_shared, out_local)

                T.copy(out_local, out_shared)
                T.copy(out_shared, out_flat[by * block_M, bx * block_N])

        return main

    autotuner = AutoTuner.from_kernel(
        kernel=kernel,
        configs=get_configs(N, C, H, W, F, K, S, D, P, with_roller),
    ).set_compile_args(
        out_idx=[2],
        target="auto",
    ).set_profile_args(
        supply_type=tl.TensorSupplyType.Integer,
        ref_prog=ref_prog,
        skip_check=False,
    )
    return autotuner.run(warmup=3, rep=20)


def get_heuristic_config() -> dict:
    # Get CUDA device properties
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is not available")
    device = torch.cuda.current_device()
    sm_major, sm_minor = torch.cuda.get_device_capability(device)
    sm_version = sm_major * 10 + sm_minor
    print(f"CUDA device capability: {sm_version}")
    if sm_version in {80}:
        return {
            "block_M": 128,
            "block_N": 256,
            "block_K": 32,
            "num_stages": 2,
            "thread_num": 128,
            "enable_rasteration": True
        }
    elif sm_version in {90}:
        return {
            "block_M": 128,
            "block_N": 256,
            "block_K": 64,
            "num_stages": 3,
            "thread_num": 256,
            "enable_rasteration": True
        }
    else:
        return {
            "block_M": 128,
            "block_N": 256,
            "block_K": 32,
            "num_stages": 0,
            "thread_num": 128,
            "enable_rasteration": True
        }


def convolution(N,
                C,
                H,
                W,
                F,
                K,
                S,
                D,
                P,
                block_M,
                block_N,
                block_K,
                num_stages,
                thread_num,
                dtype="float16",
                accum_dtype="float"):
    KH, KW = K, K
    OH = (H + 2 * P - D * (K - 1) - 1) // S + 1
    OW = (W + 2 * P - D * (K - 1) - 1) // S + 1
    dtype = "float16"
    accum_dtype = "float"
Comment on lines +233 to +234
Contributor (high)

The dtype and accum_dtype parameters of the convolution function (defined on lines 228-229) are shadowed by these hardcoded assignments. This means any values passed for dtype and accum_dtype when calling this function will be ignored, and it will always use "float16" and "float" respectively.

Was this intentional? If these types should be configurable, consider removing these hardcoded lines and using the function parameters directly.

Suggested change
dtype = "float16"
accum_dtype = "float"
# dtype = "float16" # Consider removing or using the function parameter
# accum_dtype = "float" # Consider removing or using the function parameter

    is_hopper = check_hopper()

    @T.prim_func
    def main(
            data: T.Tensor((N, H, W, C), dtype),
            kernel: T.Tensor((KH, KW, C, F), dtype),
            out: T.Tensor((N, OH, OW, F), dtype),
    ):
        with T.Kernel(
                T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M),
                threads=thread_num) as (bx, by):
            data_shared = T.alloc_shared((block_M, block_K), dtype)
            kernel_shared = T.alloc_shared((block_K, block_N), dtype)
            out_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            out_shared = T.alloc_shared((block_M, block_N), dtype)

            kernel_flat = T.Tensor((KH * KW * C, F), dtype, kernel.data)
            out_flat = T.Tensor((N * OH * OW, F), dtype, out.data)

            T.annotate_layout({
                out_shared: tilelang.layout.make_swizzled_layout(out_shared),
                data_shared: tilelang.layout.make_swizzled_layout(data_shared),
                kernel_shared: tilelang.layout.make_swizzled_layout(kernel_shared),
            })

            T.clear(out_local)
            for k_iter in T.Pipelined(T.ceildiv(KH * KW * C, block_K), num_stages=num_stages):
                if is_hopper:
                    T.c2d_im2col(data, data_shared, by, k_iter, KH, S, D, P)
                else:
                    for i, j in T.Parallel(block_M, block_K):
                        k = k_iter * block_K + j
                        m = by * block_M + i
                        access_h = m % (OH * OW) // OW * S + k // (KW * C) * D - P
                        access_w = m % OW * S + k // C % KW * D - P
                        in_bound = ((access_h >= 0) and (access_w >= 0) and (access_h < H) and
                                    (access_w < W))
                        data_shared[i, j] = T.if_then_else(
                            in_bound, data[m // (OH * OW), access_h, access_w, k % C], 0)
                T.copy(kernel_flat[k_iter * block_K, bx * block_N], kernel_shared)
                T.gemm(data_shared, kernel_shared, out_local)

            T.copy(out_local, out_shared)
            T.copy(out_shared, out_flat[by * block_M, bx * block_N])

    return main
Comment on lines +104 to +280
Contributor (medium)

The convolution function (lines 214-280) is nearly identical to the inner kernel function defined within get_best_config (lines 104-162). This significant code duplication can make maintenance harder, as changes might need to be applied in two places.

Could these be refactored into a single, shared function to define the kernel logic? This shared function could then be called by both get_best_config (for autotuning) and by the main execution path when use_autotune is False.
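
A minimal sketch of that refactor, assuming the convolution factory defined later in this file remains the single definition of the prim_func and that the kernel closure keeps accepting (and ignoring) enable_rasteration, just as the current inner kernel does:

def get_best_config(N, C, H, W, F, K, S, D, P, ref_prog, with_roller=False):

    def kernel(block_M=None,
               block_N=None,
               block_K=None,
               num_stages=None,
               thread_num=None,
               enable_rasteration=None):
        # Delegate to the shared factory instead of redeclaring an identical prim_func.
        return convolution(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages,
                           thread_num)

    autotuner = AutoTuner.from_kernel(
        kernel=kernel,
        configs=get_configs(N, C, H, W, F, K, S, D, P, with_roller),
    ).set_compile_args(
        out_idx=[2],
        target="auto",
    ).set_profile_args(
        supply_type=tl.TensorSupplyType.Integer,
        ref_prog=ref_prog,
        skip_check=False,
    )
    return autotuner.run(warmup=3, rep=20)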



def main(n: int = 128,
         c: int = 128,
         h: int = 64,
         w: int = 64,
         f: int = 128,
         k: int = 3,
         s: int = 1,
         d: int = 1,
         p: int = 1,
         use_autotune: bool = False,
         with_roller: bool = True):
    N, C, H, W, F, K, S, D, P = n, c, h, w, f, k, s, d, p
    ref_prog = ref_program(S, P, D)
    use_autotune = True
Contributor (high)

The use_autotune variable is hardcoded to True here. This overrides the use_autotune parameter of the main function (line 292) and the corresponding command-line argument --use_autotune (lines 325-328), making them ineffective.

Should this line be removed to allow the parameter and CLI flag to control whether autotuning is used?

Suggested change
use_autotune = True
# use_autotune = True # This overrides the function parameter and CLI argument

    if use_autotune:
        result = get_best_config(N, C, H, W, F, K, S, D, P, ref_prog, with_roller)
        print(result.config)
        kernel = result.kernel
    else:
        config = get_heuristic_config()
        kernel = tl.compile(convolution(N, C, H, W, F, K, S, D, P, **config), out_idx=[2])
Contributor (critical)

The get_heuristic_config() function (lines 177-211) returns a configuration dictionary that includes the key "enable_rasteration". However, the convolution function (defined on lines 214-229) does not accept enable_rasteration as a parameter. When **config is used here, it will attempt to pass enable_rasteration as a keyword argument, leading to a TypeError.

How should this be handled? Options include:

  1. Modifying convolution to accept enable_rasteration (e.g., enable_rasteration=None), even if it doesn't use it.
  2. Filtering "enable_rasteration" out of the config dictionary before passing it to convolution.
  3. Removing "enable_rasteration" from the dictionary returned by get_heuristic_config() if it's not relevant for the heuristic path.
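
A minimal sketch of option 2, dropping the unsupported key before unpacking, using only names already defined in this file:

        config = get_heuristic_config()
        conv_params = {key: value for key, value in config.items() if key != "enable_rasteration"}
        kernel = tl.compile(convolution(N, C, H, W, F, K, S, D, P, **conv_params), out_idx=[2])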


    profiler = kernel.get_profiler(tensor_supply_type=tl.TensorSupplyType.Auto)
    tilelang_latency = profiler.do_bench()
    ref_latency = profiler.do_bench(ref_prog)
    profiler.assert_allclose(ref_prog, atol=1e-2, rtol=1e-2)
    print(f"TileLang latency: {tilelang_latency}")
    print(f"Ref latency: {ref_latency}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Autotuned Convolution Benchmark")
    parser.add_argument('--n', type=int, default=128, help='n')
    parser.add_argument('--c', type=int, default=128, help='c')
    parser.add_argument('--h', type=int, default=64, help='h')
    parser.add_argument('--w', type=int, default=64, help='w')
    parser.add_argument('--f', type=int, default=128, help='f')
    parser.add_argument('--k', type=int, default=3, help='k')
    parser.add_argument('--s', type=int, default=1, help='s')
    parser.add_argument('--d', type=int, default=1, help='d')
    parser.add_argument('--p', type=int, default=1, help='p')
    parser.add_argument(
        "--use_autotune",
        action="store_true",
        default=False,
        help="Whether to use autotune for convolution configs")
    parser.add_argument(
        "--with_roller",
        action="store_true",
        default=True,
        help="Whether to enable BitBLAS roller for search space")
    args = parser.parse_args()
    main(args.n, args.c, args.h, args.w, args.f, args.k, args.s, args.d, args.p, args.use_autotune,
         args.with_roller)
6 changes: 5 additions & 1 deletion examples/convolution/test_example_convolution.py
@@ -3,12 +3,16 @@
import tilelang.testing

import example_convolution
import example_convolution_autotune


@tilelang.testing.requires_cuda
def test_example_convolution():
Contributor (critical)

The @tilelang.testing.requires_cuda decorator seems to have been removed from this test function. The example_convolution.main([]) function (from example_convolution.py) initializes tensors on CUDA (e.g., torch.randn(...).cuda().half()).

Without this decorator, the test might fail in environments where CUDA is not available or not correctly configured. Was the removal intentional? If example_convolution.main still requires CUDA, this decorator should likely be reinstated to ensure test stability and correct skipping on non-CUDA environments.

Suggested change
def test_example_convolution():
@tilelang.testing.requires_cuda
def test_example_convolution():

    example_convolution.main([])


def test_example_convolution_autotune():
Contributor (critical)

The new test test_example_convolution_autotune calls example_convolution_autotune.main(). The example_convolution_autotune.py script extensively uses CUDA operations (e.g., torch.cuda.is_available(), torch.cuda.get_device_properties(0), and the underlying TileLang compilation for CUDA).

This test function should also be decorated with @tilelang.testing.requires_cuda to ensure it's skipped appropriately on non-CUDA environments and to clearly indicate its CUDA dependency.

Suggested change
def test_example_convolution_autotune():
@tilelang.testing.requires_cuda
def test_example_convolution_autotune():

    example_convolution_autotune.main()


if __name__ == "__main__":
    tilelang.testing.main()