diff --git a/tilelang/autotuner/__init__.py b/tilelang/autotuner/__init__.py
index ecbb0b495..2e71b8b73 100644
--- a/tilelang/autotuner/__init__.py
+++ b/tilelang/autotuner/__init__.py
@@ -13,6 +13,7 @@
 from typing import Callable, List, Literal, Any, Optional, Union
 from tqdm import tqdm
 import logging
+import functools
 from dataclasses import dataclass
 import concurrent.futures
 import torch
@@ -63,6 +64,7 @@ class JITContext:
     atol: float
     max_mismatched_ratio: float
     skip_check: bool
+    manual_check_prog: Callable
     cache_input_tensors: bool
     kernel: tilelang.JITKernel
     supply_type: tilelang.TensorSupplyType
@@ -106,6 +108,7 @@ class CompileArgs:
     atol: float = 1e-2
     max_mismatched_ratio: float = 0.01
     skip_check: bool = False
+    manual_check_prog: Callable = None
     cache_input_tensors: bool = True
     target: Literal['auto', 'cuda', 'hip'] = 'auto'
     """
@@ -118,6 +121,7 @@ class CompileArgs:
     atol: float = 1e-2
     max_mismatched_ratio: float = 0.01
     skip_check: bool = False
+    manual_check_prog: Callable = None
     cache_input_tensors: bool = True
     target: Literal['auto', 'cuda', 'hip'] = 'auto'

@@ -164,6 +168,7 @@ def set_compile_args(self,
                          atol: float = 1e-2,
                          max_mismatched_ratio: float = 0.01,
                          skip_check: bool = False,
+                         manual_check_prog: Callable = None,
                          cache_input_tensors: bool = True,
                          target: Literal['auto', 'cuda', 'hip'] = 'auto'):
         """Set compilation arguments for the auto-tuner.
@@ -177,6 +182,7 @@ def set_compile_args(self,
             atol: Absolute tolerance for validation.
             max_mismatched_ratio: Maximum allowed mismatch ratio.
             skip_check: Whether to skip validation.
+            manual_check_prog: Optional callable that validates outputs in place of the default allclose check.
             cache_input_tensors: Whether to cache input tensors.
             target: Target platform.

@@ -192,6 +198,7 @@ def set_compile_args(self,
             atol=atol,
             max_mismatched_ratio=max_mismatched_ratio,
             skip_check=skip_check,
+            manual_check_prog=manual_check_prog,
             cache_input_tensors=cache_input_tensors,
             target=target)

@@ -234,6 +241,7 @@ def _compile(*config_arg):
                 atol=compile_args.atol,
                 max_mismatched_ratio=compile_args.max_mismatched_ratio,
                 skip_check=compile_args.skip_check,
+                manual_check_prog=compile_args.manual_check_prog,
                 cache_input_tensors=compile_args.cache_input_tensors,
                 kernel=kernel,
                 supply_type=compile_args.supply_type,
@@ -248,6 +256,7 @@ def target_fn(jit_context: JITContext):
             kernel = jit_context.kernel
             supply_type = jit_context.supply_type
             skip_check = jit_context.skip_check
+            manual_check_prog = jit_context.manual_check_prog
             cache_input_tensors = jit_context.cache_input_tensors
             ref_prog = jit_context.ref_prog
             supply_prog = jit_context.supply_prog
@@ -293,12 +302,18 @@ def func():
                 self.jit_input_tensors = jit_input_tensors_supply()

             if (not skip_check) and (ref_prog is not None):
-                profiler.assert_allclose(
-                    ref_prog,
-                    input_tensors=self.jit_input_tensors,
-                    rtol=rtol,
-                    atol=atol,
-                    max_mismatched_ratio=max_mismatched_ratio)
+                if manual_check_prog is not None:
+                    profiler.manual_assert_close(
+                        ref_prog,
+                        input_tensors=self.jit_input_tensors,
+                        manual_check_prog=manual_check_prog)
+                else:
+                    profiler.assert_allclose(
+                        ref_prog,
+                        input_tensors=self.jit_input_tensors,
+                        rtol=rtol,
+                        atol=atol,
+                        max_mismatched_ratio=max_mismatched_ratio)
             latency = profiler.do_bench(
                 warmup=warmup, rep=rep, input_tensors=self.jit_input_tensors)
             if self.ref_latency_cache is None and ref_prog is not None:
@@ -325,9 +340,14 @@ def func():
         pool = concurrent.futures.ThreadPoolExecutor(max_workers=num_workers)
         futures = []
         future_to_index = {}
+
+        def device_wrapper(func, device, *config_arg):
+            torch.cuda.set_device(device)
+            return func(*config_arg)
+
         for i, config_arg in enumerate(config_args):
             future = pool.submit(
-                self.jit_compile,
+                functools.partial(device_wrapper, self.jit_compile, torch.cuda.current_device()),
                 *config_arg,
             )
             futures.append(future)
@@ -357,7 +377,9 @@ def func():
                 # Because tma init may behave strangely with one thread
                 # latency, ref_latency = target_fn(jit_context)
                 benchmark_executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
-                future = benchmark_executor.submit(target_fn, jit_context)
+                future = benchmark_executor.submit(
+                    functools.partial(device_wrapper, target_fn, torch.cuda.current_device()),
+                    jit_context)
                 latency, ref_latency = future.result(timeout=timeout)
             except concurrent.futures.TimeoutError:
                 logger.info(
@@ -436,6 +458,7 @@ def jit(out_idx: Optional[List[int]] = None,
         atol: float = 1e-2,
         max_mismatched_ratio: float = 0.01,
         skip_check: bool = False,
+        manual_check_prog: Callable = None,
         cache_input_tensors: bool = True,
         target: Literal['auto', 'cuda', 'hip'] = 'auto') -> Callable:
     """Just-In-Time compilation decorator for tilelang programs.
@@ -449,6 +472,7 @@ def jit(out_idx: Optional[List[int]] = None,
         atol: Absolute tolerance for output validation.
         max_mismatched_ratio: Maximum allowed ratio of mismatched elements.
         skip_check: Whether to skip validation checks.
+        manual_check_prog: Optional callable that validates outputs in place of the default allclose check.
         cache_input_tensors: Whether to cache input tensors for each compilation.
         target: Target platform ('auto', 'cuda', or 'hip').

@@ -477,6 +501,7 @@ def decorator(*args, **kwargs) -> float:
             atol=atol,
             max_mismatched_ratio=max_mismatched_ratio,
             skip_check=skip_check,
+            manual_check_prog=manual_check_prog,
             cache_input_tensors=cache_input_tensors,
             kernel=kernel,
             supply_type=supply_type,
diff --git a/tilelang/profiler/__init__.py b/tilelang/profiler/__init__.py
index 304f68c80..2c7e95c13 100644
--- a/tilelang/profiler/__init__.py
+++ b/tilelang/profiler/__init__.py
@@ -123,6 +123,37 @@ def assert_allclose(
             ref_name="ref",
         )

+    def manual_assert_close(
+        self,
+        reference_program: Callable,
+        input_tensors: Optional[List[torch.Tensor]] = None,
+        manual_check_prog: Callable = None,
+    ):
+        """Validates kernel output against a reference implementation.
+
+        Args:
+            reference_program: Reference implementation to compare against
+            input_tensors: Optional pre-generated input tensors
+            manual_check_prog: Callable taking (lib_outs, ref_outs) that raises
+                (e.g. via assert) when the kernel outputs do not match the
+                reference outputs
+        """
+        ins = self._get_inputs() if input_tensors is None else input_tensors
+        ref_outs = reference_program(*ins)
+        torch.cuda.synchronize()
+        lib_outs = self.func(*ins)
+        torch.cuda.synchronize()
+
+        if isinstance(lib_outs, torch.Tensor):
+            lib_outs = [lib_outs]
+        if isinstance(ref_outs, torch.Tensor):
+            ref_outs = [ref_outs]
+        elif ref_outs is None:
+            ref_outs = []
+        assert len(lib_outs) == len(ref_outs), f"{len(lib_outs)=} does not equal {len(ref_outs)=}!"
+        torch.set_printoptions(edgeitems=torch.inf)
+        manual_check_prog(lib_outs, ref_outs)
+
     def assert_consistent(self, repeat=10):
         """Checks for kernel consistency across multiple runs.
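
A minimal sketch of what a `manual_check_prog` callback could look like, assuming only the call contract visible in this diff: the profiler invokes `manual_check_prog(lib_outs, ref_outs)` with both arguments normalized to lists of tensors, and the callback is forwarded through `set_compile_args(...)` or the `jit(...)` decorator. The `relaxed_check` name and its NaN-masking policy are illustrative, not part of the PR.

```python
import torch


def relaxed_check(lib_outs, ref_outs):
    """Hypothetical manual check: compare only positions where the reference
    output is finite, with loose tolerances."""
    assert len(lib_outs) == len(ref_outs)
    for lib, ref in zip(lib_outs, ref_outs):
        mask = torch.isfinite(ref)
        torch.testing.assert_close(lib[mask], ref[mask], rtol=1e-2, atol=1e-2)


# It would then be wired into the autotuner roughly as:
#   autotuner.set_compile_args(..., manual_check_prog=relaxed_check)
# or, via the decorator path touched in this diff:
#   @jit(out_idx=[-1], ref_prog=ref_program, manual_check_prog=relaxed_check)
```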
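The thread-pool change follows a common pattern: bind the submitting thread's CUDA device into the callable with `functools.partial`, since worker threads otherwise run on the default device. A self-contained sketch of that pattern under the same assumption; the `report_device` helper is hypothetical and only shows the effect.

```python
import concurrent.futures
import functools

import torch


def device_wrapper(fn, device, *args):
    # Pin the worker thread to the caller's CUDA device before running fn.
    torch.cuda.set_device(device)
    return fn(*args)


def report_device():
    # Stand-in for jit_compile / target_fn: report the worker's active device.
    return torch.cuda.current_device()


if __name__ == "__main__" and torch.cuda.is_available():
    with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool:
        fut = pool.submit(
            functools.partial(device_wrapper, report_device, torch.cuda.current_device()))
        print("worker ran on device", fut.result())
```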