diff --git a/src/tl_templates/cuda/common.h b/src/tl_templates/cuda/common.h
index 7a3545e80..9c352a4b5 100644
--- a/src/tl_templates/cuda/common.h
+++ b/src/tl_templates/cuda/common.h
@@ -27,6 +27,16 @@ using int4_t = int4;
 #define TL_DEVICE_NOINLINE __noinline__ __device__
 #define TL_PATCH
 
+#define TILELANG_CHECK(stmt)                                                  \
+  do {                                                                        \
+    cudaError_t __err = (stmt);                                               \
+    if (__err != cudaSuccess) {                                               \
+      snprintf(error_buf, ERROR_BUF_SIZE, "%s:%d: %s - %s", __FILE__,         \
+               __LINE__, cudaGetErrorName(__err), cudaGetErrorString(__err)); \
+      return -1;                                                              \
+    }                                                                         \
+  } while (0)
+
 // abs function for bfloat_t and half_t since there is no implicit convertion
 // method
 TL_PATCH TL_DEVICE half_t __habs(const half_t x) {
diff --git a/src/tl_templates/hip/common.h b/src/tl_templates/hip/common.h
index 15838ff68..4ffc47e7b 100644
--- a/src/tl_templates/hip/common.h
+++ b/src/tl_templates/hip/common.h
@@ -26,6 +26,17 @@
 #define ushort unsigned short
 
 #define TL_DEVICE __forceinline__ __device__
+#define TL_DEVICE_NOINLINE __noinline__ __device__
+
+#define TILELANG_CHECK(stmt)                                                  \
+  do {                                                                        \
+    hipError_t __err = (stmt);                                                \
+    if (__err != hipSuccess) {                                                \
+      snprintf(error_buf, ERROR_BUF_SIZE, "%s:%d: %s - %s", __FILE__,         \
+               __LINE__, hipGetErrorName(__err), hipGetErrorString(__err));   \
+      return -1;                                                              \
+    }                                                                         \
+  } while (0)
 
 #define half _Float16
 #define __float2half_rn(x) half(x)
diff --git a/src/transform/loop_partition.cc b/src/transform/loop_partition.cc
index 6efdfbf80..6a053063e 100644
--- a/src/transform/loop_partition.cc
+++ b/src/transform/loop_partition.cc
@@ -141,10 +141,31 @@ class LoopPartitioner : public StmtExprVisitor {
     PrimExpr thd = FloorMod(access_idx, num_thread);
     PrimExpr idx = FloorDiv(access_idx, num_thread) * vectorize_size +
                    FloorMod(flattened, vectorize_size);
-    return Fragment(loop_vars_, {idx}, {thd}, {});
+    auto fragment = Fragment(loop_vars_, {idx}, {thd}, {});
+    if (has_fragment_) {
+      // for fragment buffers, replicate the loop layout across the remaining threads
+      auto thread_extent = *as_const_int(fragment->ThreadExtent());
+      auto num_thread_fragment = num_thread / thread_extent;
+      fragment = fragment->Replicate(num_thread_fragment);
+    }
+    return fragment;
   }
 
 private:
+  void VisitExpr_(const BufferLoadNode *op) final {
+    if (op->buffer.scope() == "local.fragment") {
+      has_fragment_ = true;
+    }
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitStmt_(const BufferStoreNode *op) final {
+    if (op->buffer.scope() == "local.fragment") {
+      has_fragment_ = true;
+    }
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
   void VisitStmt_(const ForNode *node) final {
     if (node->kind == ForKind::kParallel) {
       body_ = node->body;
@@ -157,6 +178,7 @@ class LoopPartitioner : public StmtExprVisitor {
 
   Stmt body_;
   PrimExpr flattened = 0;
+  bool has_fragment_ = false;
   Array<IterVar> loop_vars_;
 };
 
diff --git a/tilelang/__init__.py b/tilelang/__init__.py
index 476a9bc7b..061a03275 100644
--- a/tilelang/__init__.py
+++ b/tilelang/__init__.py
@@ -99,6 +99,7 @@ def _load_tile_lang_lib():
     language,  # noqa: F401
     engine,  # noqa: F401
 )
+from .transform import PassConfigKey  # noqa: F401
 from .engine import lower, register_cuda_postproc, register_hip_postproc  # noqa: F401
 
 
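With the `tilelang/__init__.py` change above, `PassConfigKey` becomes importable straight from the package root. A minimal sketch of what that enables, assuming a built TileLang install:

```python
# PassConfigKey is a str-valued Enum (see pass_config.py below), so each
# member both compares equal to and carries its raw config-key string.
import tilelang

key = tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER
print(key.value)             # "tl.disable_tma_lower"
print(isinstance(key, str))  # True: PassConfigKey subclasses str
```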
diff --git a/tilelang/engine/phase.py b/tilelang/engine/phase.py
index 98f31523b..8cca5118b 100644
--- a/tilelang/engine/phase.py
+++ b/tilelang/engine/phase.py
@@ -3,17 +3,28 @@
 from tvm import tir, IRModule
 from tvm.target import Target
 import tilelang
+from tilelang.transform import PassContext
+from typing import Optional
 
 
-def allow_tma_and_warp_specialized(target: Target) -> bool:
+def allow_tma_and_warp_specialized(pass_ctx: Optional[PassContext] = None,
+                                   target: Optional[Target] = None) -> bool:
+    if pass_ctx is None:
+        pass_ctx = tilelang.transform.get_pass_context()
     if target.arch not in {"sm_90"}:
         return False
-    cur_pass_ctx = tilelang.transform.get_pass_context()
-    disable_tma_lower = cur_pass_ctx.config.get("tl.disable_tma_lower", False)
-    disable_warp_specialized = cur_pass_ctx.config.get("tl.disable_warp_specialized", False)
+    disable_tma_lower = pass_ctx.config.get("tl.disable_tma_lower", False)
+    disable_warp_specialized = pass_ctx.config.get("tl.disable_warp_specialized", False)
     return not (disable_tma_lower and disable_warp_specialized)
 
 
+def allow_vectorize(pass_ctx: Optional[PassContext] = None) -> bool:
+    if pass_ctx is None:
+        pass_ctx = tilelang.transform.get_pass_context()
+    disable_vectorize = pass_ctx.config.get("tir.disable_vectorize", False)
+    return not disable_vectorize
+
+
 def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule:
     # Bind the target device information to the module
     mod = tir.transform.BindTarget(target)(mod)
@@ -40,8 +51,9 @@ def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule:
 
 
 def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
+    pass_ctx = tilelang.transform.get_pass_context()
     # which may be introduced by the LegalizeSafeMemoryAccess
-    if allow_tma_and_warp_specialized(target):
+    if allow_tma_and_warp_specialized(pass_ctx=pass_ctx, target=target):
         mod = tilelang.transform.IfStmtBinding()(mod)
         mod = tilelang.transform.MultiVersionBuffer()(mod)
         mod = tilelang.transform.WarpSpecialized()(mod)
@@ -59,11 +71,14 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
     mod = tilelang.transform.PipelinePlanning()(mod)
     mod = tilelang.transform.InjectSoftwarePipeline()(mod)
     mod = tilelang.transform.MergeIfStmt()(mod)
+    mod = tir.transform.LowerOpaqueBlock()(mod)
     mod = tilelang.transform.FlattenBuffer()(mod)
     mod = tir.transform.NarrowDataType(32)(mod)
     mod = tir.transform.Simplify()(mod)
-    mod = tilelang.transform.VectorizeLoop()(mod)
+
+    mod = tilelang.transform.VectorizeLoop(enable_vectorize=allow_vectorize(pass_ctx=pass_ctx))(mod)
+
     mod = tir.transform.StorageRewrite()(mod)
     mod = tir.transform.UnrollLoop()(mod)
     mod = tir.transform.RenormalizeSplitPattern()(mod)
 
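The helpers above now take an optional `pass_ctx`, so `OptimizeForTarget` can fetch the pass context once and thread it through instead of re-querying it in every check. A minimal sketch of the fallback behavior, assuming a working TileLang/TVM install:

```python
from tvm.ir.transform import PassContext
from tilelang.engine.phase import allow_vectorize

# With no explicit pass_ctx, the helper falls back to the current context.
print(allow_vectorize())  # True under the default pass context

with PassContext(config={"tir.disable_vectorize": True}):
    print(allow_vectorize())  # False: vectorization disabled via config
```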
diff --git a/tilelang/jit/adapter/wrapper.py b/tilelang/jit/adapter/wrapper.py
index 991eec585..6ec9a8492 100644
--- a/tilelang/jit/adapter/wrapper.py
+++ b/tilelang/jit/adapter/wrapper.py
@@ -45,7 +45,7 @@
 PREDEF_HOST_FUNC = """
 extern "C" int call({}) {{
 {}
-    return 0;
+\treturn 0;
 }}
 """
 
@@ -196,7 +196,7 @@ def legalize_c(p):
                 p = int(p)
             return str(p).replace("//", "/")
 
-        _call_str = """"""
+        kernel_launch_code = """"""
         desc_name_map: Dict[str, str] = {}
         for function_name, function_info in function_informations.items():
             block_info = function_info["block_info"]
@@ -221,14 +221,18 @@
             grid_str = "dim3({}, {}, {})".format(
                 legalize_c(grid_info[0]), legalize_c(grid_info[1]), legalize_c(grid_info[2]))
             smem_str = 0 if dynamic_smem_buf is None else dynamic_smem_buf
-            _call_str += "\t{}<<<{}, {}, {}, stream>>>({});\n".format(function_name, grid_str,
-                                                                      block_str, smem_str,
-                                                                      call_args)
+            kernel_launch_code += "\t{}<<<{}, {}, {}, stream>>>({});\n".format(
+                function_name, grid_str, block_str, smem_str, call_args)
+            kernel_launch_code += "\tcudaError_t err = cudaGetLastError();\n"
+            kernel_launch_code += "\tif (err != cudaSuccess) {\n"
+            kernel_launch_code += f"\t\tsnprintf(error_buf, ERROR_BUF_SIZE, \"{function_name}: %s - %s\", cudaGetErrorName(err), cudaGetErrorString(err));\n"
+            kernel_launch_code += "\t\treturn -1;\n"
+            kernel_launch_code += "\t}\n"
 
-        _call_str = self.generate_tma_descriptor_args(desc_name_map) + _call_str
+        kernel_launch_code = self.generate_tma_descriptor_args(desc_name_map) + kernel_launch_code
 
         # Wrap the kernel dispatch logic in an external C function
-        host_func = PREDEF_HOST_FUNC.format(def_args, _call_str)
+        host_func = PREDEF_HOST_FUNC.format(def_args, kernel_launch_code)
         return host_func
 
     def generate_tma_descriptor_args(self, desc_name_map: Dict[str, str]) -> str:
diff --git a/tilelang/transform/__init__.py b/tilelang/transform/__init__.py
index f0adb2803..f7ef54b40 100644
--- a/tilelang/transform/__init__.py
+++ b/tilelang/transform/__init__.py
@@ -5,12 +5,14 @@
 
 from . import _ffi_api
 from .simplify import Simplify, simplify_prim_func  # noqa: F401
+from .pass_config import PassConfigKey  # noqa: F401
+from tilelang import tvm as tvm  # noqa: F401
+from tvm.ir.transform import PassContext  # noqa: F401
 
 
 def get_pass_context():
     """Get the current pass context"""
-    from tilelang import tvm as tvm
-    return tvm.transform.PassContext.current()
+    return PassContext.current()
 
 
 def ClusterPlanning():
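For reference, the wrapper change above makes the generated host function check `cudaGetLastError()` after every launch instead of returning 0 unconditionally. The sketch below is a hypothetical, stripped-down re-creation of that emission loop (`emit_launch` is not part of the codebase) to show the C fragment that now follows each kernel:

```python
# Hypothetical helper mirroring the string-building in wrapper.py; the real
# code lives in the wrapper class and derives grid/block/smem per kernel.
def emit_launch(function_name: str, grid: str, block: str, smem: int, args: str) -> str:
    code = f"\t{function_name}<<<{grid}, {block}, {smem}, stream>>>({args});\n"
    code += "\tcudaError_t err = cudaGetLastError();\n"
    code += "\tif (err != cudaSuccess) {\n"
    code += (f"\t\tsnprintf(error_buf, ERROR_BUF_SIZE, \"{function_name}: %s - %s\", "
             "cudaGetErrorName(err), cudaGetErrorString(err));\n")
    code += "\t\treturn -1;\n"
    code += "\t}\n"
    return code

print(emit_launch("main_kernel", "dim3(1, 1, 1)", "dim3(128, 1, 1)", 0, "A, B, C"))
```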
diff --git a/tilelang/transform/pass_config.py b/tilelang/transform/pass_config.py
new file mode 100644
index 000000000..44720f664
--- /dev/null
+++ b/tilelang/transform/pass_config.py
@@ -0,0 +1,61 @@
+# Copyright (c) Tile-AI Corporation.
+# Licensed under the MIT License.
+# TODO: Add more documentation for each pass config
+
+from enum import Enum
+
+
+class PassConfigKey(str, Enum):
+    """Pass configuration keys for the TileLang compiler."""
+    # TileLang-specific configs
+    TL_SIMPLIFY = "tl.Simplify"
+    """Enable/disable TileLang simplification passes. Default: True"""
+
+    TL_DYNAMIC_ALIGNMENT = "tl.dynamic_alignment"
+    """Memory alignment requirement for dynamic shapes. Default: 16"""
+
+    TL_DISABLE_DYNAMIC_TAIL_SPLIT = "tl.disable_dynamic_tail_split"
+    """Disable dynamic tail splitting optimization. Default: False"""
+
+    TL_DISABLE_WARP_SPECIALIZED = "tl.disable_warp_specialized"
+    """Disable warp specialization optimization. Default: False"""
+
+    TL_CONFIG_INDEX_BITWIDTH = "tl.config_index_bitwidth"
+    """Bitwidth for configuration indices. Default: 32"""
+
+    TL_DISABLE_TMA_LOWER = "tl.disable_tma_lower"
+    """Disable TMA (Tensor Memory Accelerator) lowering. Default: False"""
+
+    # TIR-related configs
+    TIR_ENABLE_EQUIV_TERMS_IN_CSE = "tir.enable_equiv_terms_in_cse_tir"
+    """Enable equivalent terms in TIR common subexpression elimination. Default: True"""
+
+    TIR_DISABLE_CSE = "tir.disable_cse_tir"
+    """Disable TIR common subexpression elimination. Default: False"""
+
+    TIR_SIMPLIFY = "tir.Simplify"
+    """Enable/disable TIR simplification passes. Default: True"""
+
+    TIR_DISABLE_STORAGE_REWRITE = "tir.disable_storage_rewrite"
+    """Disable storage rewrite optimization. Default: False"""
+
+    TIR_DISABLE_VECTORIZE = "tir.disable_vectorize"
+    """Disable vectorization optimization. Default: False"""
+
+    TIR_USE_ASYNC_COPY = "tir.use_async_copy"
+    """Enable asynchronous memory copy operations. Default: True"""
+
+    TIR_ENABLE_DEBUG = "tir.enable_debug"
+    """Enable debug information in generated code. Default: False"""
+
+    TIR_MERGE_STATIC_SMEM = "tir.merge_static_smem"
+    """Merge static shared memory allocations. Default: True"""
+
+    TIR_ADD_LOWER_PASS = "tir.add_lower_pass"
+    """Additional lowering passes to be applied. Default: None"""
+
+    TIR_NOALIAS = "tir.noalias"
+    """Enable pointer non-aliasing assumptions. Default: True"""
+
+    CUDA_KERNELS_OUTPUT_DIR = "cuda.kernels_output_dir"
+    """Output directory for generated CUDA kernels. Default: empty string"""
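Taken together with the `phase.py` checks, the new enum gives a typo-proof way to spell config keys. A minimal end-to-end sketch, assuming these keys are registered with TVM's `PassContext` (using `.value` so plain strings reach the config map rather than relying on how str-Enum members stringify):

```python
from tvm.ir.transform import PassContext
from tilelang.transform import PassConfigKey

config = {
    PassConfigKey.TL_DISABLE_TMA_LOWER.value: True,    # "tl.disable_tma_lower"
    PassConfigKey.TIR_DISABLE_VECTORIZE.value: True,   # "tir.disable_vectorize"
}

with PassContext(config=config):
    ...  # lowering here would skip TMA lowering and loop vectorization
```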