41 changes: 33 additions & 8 deletions tilelang/autotuner/__init__.py
@@ -13,6 +13,7 @@
 from typing import Callable, List, Literal, Any, Optional, Union
 from tqdm import tqdm
 import logging
+import functools
 from dataclasses import dataclass
 import concurrent.futures
 import torch
@@ -63,6 +64,7 @@ class JITContext:
     atol: float
     max_mismatched_ratio: float
     skip_check: bool
+    manual_check_prog: Callable
     cache_input_tensors: bool
     kernel: tilelang.JITKernel
     supply_type: tilelang.TensorSupplyType
@@ -106,6 +108,7 @@ class CompileArgs:
     atol: float = 1e-2
     max_mismatched_ratio: float = 0.01
     skip_check: bool = False
+    manual_check_prog: Callable = None
     cache_input_tensors: bool = True
     target: Literal['auto', 'cuda', 'hip'] = 'auto'
     """
@@ -118,6 +121,7 @@ class CompileArgs:
     atol: float = 1e-2
     max_mismatched_ratio: float = 0.01
     skip_check: bool = False
+    manual_check_prog: Callable = None
     cache_input_tensors: bool = True
     target: Literal['auto', 'cuda', 'hip'] = 'auto'
 
@@ -164,6 +168,7 @@ def set_compile_args(self,
                          atol: float = 1e-2,
                          max_mismatched_ratio: float = 0.01,
                          skip_check: bool = False,
+                         manual_check_prog: Callable = None,
                          cache_input_tensors: bool = True,
                          target: Literal['auto', 'cuda', 'hip'] = 'auto'):
         """Set compilation arguments for the auto-tuner.
@@ -177,6 +182,7 @@ def set_compile_args(self,
             atol: Absolute tolerance for validation.
             max_mismatched_ratio: Maximum allowed mismatch ratio.
             skip_check: Whether to skip validation.
+            manual_check_prog: Manual check program for validation.
             cache_input_tensors: Whether to cache input tensors.
             target: Target platform.
 
@@ -192,6 +198,7 @@ def set_compile_args(self,
             atol=atol,
             max_mismatched_ratio=max_mismatched_ratio,
             skip_check=skip_check,
+            manual_check_prog=manual_check_prog,
             cache_input_tensors=cache_input_tensors,
             target=target)
 
@@ -234,6 +241,7 @@ def _compile(*config_arg):
                 atol=compile_args.atol,
                 max_mismatched_ratio=compile_args.max_mismatched_ratio,
                 skip_check=compile_args.skip_check,
+                manual_check_prog=compile_args.manual_check_prog,
                 cache_input_tensors=compile_args.cache_input_tensors,
                 kernel=kernel,
                 supply_type=compile_args.supply_type,
@@ -248,6 +256,7 @@ def target_fn(jit_context: JITContext):
             kernel = jit_context.kernel
             supply_type = jit_context.supply_type
             skip_check = jit_context.skip_check
+            manual_check_prog = jit_context.manual_check_prog
             cache_input_tensors = jit_context.cache_input_tensors
             ref_prog = jit_context.ref_prog
             supply_prog = jit_context.supply_prog
@@ -293,12 +302,18 @@ def func():
                 self.jit_input_tensors = jit_input_tensors_supply()
 
             if (not skip_check) and (ref_prog is not None):
-                profiler.assert_allclose(
-                    ref_prog,
-                    input_tensors=self.jit_input_tensors,
-                    rtol=rtol,
-                    atol=atol,
-                    max_mismatched_ratio=max_mismatched_ratio)
+                if manual_check_prog is not None:
+                    profiler.manual_assert_close(
+                        ref_prog,
+                        input_tensors=self.jit_input_tensors,
+                        manual_check_prog=manual_check_prog)
+                else:
+                    profiler.assert_allclose(
+                        ref_prog,
+                        input_tensors=self.jit_input_tensors,
+                        rtol=rtol,
+                        atol=atol,
+                        max_mismatched_ratio=max_mismatched_ratio)
             latency = profiler.do_bench(
                 warmup=warmup, rep=rep, input_tensors=self.jit_input_tensors)
             if self.ref_latency_cache is None and ref_prog is not None:
@@ -325,9 +340,14 @@ def func():
         pool = concurrent.futures.ThreadPoolExecutor(max_workers=num_workers)
         futures = []
         future_to_index = {}
+
+        def device_wrapper(func, device, *config_arg):
+            torch.cuda.set_device(device)
+            return func(*config_arg)
+
         for i, config_arg in enumerate(config_args):
             future = pool.submit(
-                self.jit_compile,
+                functools.partial(device_wrapper, self.jit_compile, torch.cuda.current_device()),
                 *config_arg,
             )
             futures.append(future)
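The `device_wrapper` introduced above is needed because PyTorch's current CUDA device is thread-local: a fresh `ThreadPoolExecutor` worker starts on device 0 regardless of what the submitting thread selected, so every submitted callable re-pins the device first. A minimal standalone sketch of the same pattern, independent of the tilelang classes (the `work` function is an illustrative stand-in for `jit_compile` / `target_fn`):

```python
import concurrent.futures
import functools

import torch


def device_wrapper(func, device, *args):
    # The CUDA "current device" is per-thread, so pin it inside the worker
    # thread before running the wrapped callable.
    torch.cuda.set_device(device)
    return func(*args)


def work() -> int:
    # Placeholder workload; just reports the device it actually ran on.
    return torch.cuda.current_device()


if torch.cuda.is_available():
    wanted = torch.cuda.current_device()  # device chosen by the submitting thread
    with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool:
        future = pool.submit(functools.partial(device_wrapper, work, wanted))
        assert future.result() == wanted  # without the wrapper this may be device 0
```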
@@ -357,7 +377,9 @@ def func():
             # Because tma init may behave strangely with one thread
             # latency, ref_latency = target_fn(jit_context)
             benchmark_executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
-            future = benchmark_executor.submit(target_fn, jit_context)
+            future = benchmark_executor.submit(
+                functools.partial(device_wrapper, target_fn, torch.cuda.current_device()),
+                jit_context)
             latency, ref_latency = future.result(timeout=timeout)
         except concurrent.futures.TimeoutError:
             logger.info(
@@ -436,6 +458,7 @@ def jit(out_idx: Optional[List[int]] = None,
         atol: float = 1e-2,
         max_mismatched_ratio: float = 0.01,
         skip_check: bool = False,
+        manual_check_prog: Callable = None,
         cache_input_tensors: bool = True,
         target: Literal['auto', 'cuda', 'hip'] = 'auto') -> Callable:
     """Just-In-Time compilation decorator for tilelang programs.
@@ -449,6 +472,7 @@ def jit(out_idx: Optional[List[int]] = None,
         atol: Absolute tolerance for output validation.
         max_mismatched_ratio: Maximum allowed ratio of mismatched elements.
         skip_check: Whether to skip validation checks.
+        manual_check_prog: Manual check program for validation.
         cache_input_tensors: Whether to cache input tensors for each compilation.
         target: Target platform ('auto', 'cuda', or 'hip').
 
@@ -477,6 +501,7 @@ def decorator(*args, **kwargs) -> float:
             atol=atol,
             max_mismatched_ratio=max_mismatched_ratio,
             skip_check=skip_check,
+            manual_check_prog=manual_check_prog,
             cache_input_tensors=cache_input_tensors,
             kernel=kernel,
             supply_type=supply_type,
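For context on how the new keyword is consumed, a hedged usage sketch: the reference program, check function, and the commented decorator wiring below are illustrative placeholders, and only the `manual_check_prog=` keyword threaded through `jit(...)` / `set_compile_args(...)` comes from this diff. The check callable receives the kernel outputs and the reference outputs as lists of tensors (see `manual_assert_close` in the profiler changes below).

```python
from typing import List

import torch


def ref_program(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Illustrative reference implementation.
    return a @ b


def my_check(lib_outs: List[torch.Tensor], ref_outs: List[torch.Tensor]) -> None:
    # Invoked by the profiler as manual_check_prog(lib_outs, ref_outs);
    # raise or assert on mismatch instead of the built-in allclose check.
    for got, want in zip(lib_outs, ref_outs):
        torch.testing.assert_close(got.float(), want.float(), rtol=1e-2, atol=1e-2)


# Hypothetical wiring (decorator names and all arguments other than
# manual_check_prog are placeholders):
#
# @autotune(configs=get_configs(), warmup=10, rep=10)
# @jit(out_idx=[-1], ref_prog=ref_program, manual_check_prog=my_check)
# def matmul(block_M=None, block_N=None, block_K=None):
#     ...
```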
31 changes: 31 additions & 0 deletions tilelang/profiler/__init__.py
@@ -123,6 +123,37 @@ def assert_allclose(
             ref_name="ref",
         )
 
+    def manual_assert_close(
+        self,
+        reference_program: Callable,
+        input_tensors: Optional[List[torch.Tensor]] = None,
+        manual_check_prog: Callable = None,
+    ):
+        """Validates kernel output using a user-supplied check callable.
+
+        Args:
+            reference_program: Reference implementation to compare against
+            input_tensors: Optional pre-generated input tensors
+            manual_check_prog: Callable invoked as manual_check_prog(lib_outs, ref_outs);
+                it should raise or assert if the kernel outputs do not match the
+                reference outputs
+        """
+        ins = self._get_inputs() if input_tensors is None else input_tensors
+        ref_outs = reference_program(*ins)
+        torch.cuda.synchronize()
+        lib_outs = self.func(*ins)
+        torch.cuda.synchronize()
+
+        if isinstance(lib_outs, torch.Tensor):
+            lib_outs = [lib_outs]
+        if isinstance(ref_outs, torch.Tensor):
+            ref_outs = [ref_outs]
+        elif ref_outs is None:
+            ref_outs = []
+        assert len(lib_outs) == len(ref_outs), f"{len(lib_outs)=} does not equal {len(ref_outs)=}!"
+        torch.set_printoptions(edgeitems=torch.inf)
+        manual_check_prog(lib_outs, ref_outs)
+
     def assert_consistent(self, repeat=10):
         """Checks for kernel consistency across multiple runs.
 
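A hedged sketch of what a manual check callable can do with this hook: `manual_assert_close` normalizes single tensors (or `None`) into lists and then calls `manual_check_prog(lib_outs, ref_outs)`, so the callable is free to apply a custom acceptance policy. The mismatch-ratio threshold and the call-site comment below are illustrative, not part of this diff.

```python
from typing import List

import torch


def ratio_check(lib_outs: List[torch.Tensor], ref_outs: List[torch.Tensor]) -> None:
    # Tolerate a small fraction of mismatched elements instead of requiring a
    # global allclose, e.g. for kernels with nondeterministic reduction order.
    for got, want in zip(lib_outs, ref_outs):
        got, want = got.float(), want.float()
        mismatch = (got - want).abs() > (1e-2 + 1e-2 * want.abs())
        ratio = mismatch.float().mean().item()
        assert ratio <= 0.01, f"mismatched ratio {ratio:.4%} exceeds 1%"


# Hypothetical call site, given a Profiler instance and a reference program:
# profiler.manual_assert_close(ref_program, manual_check_prog=ratio_check)
```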