10 changes: 10 additions & 0 deletions src/tl_templates/cuda/common.h
@@ -27,6 +27,16 @@ using int4_t = int4;
#define TL_DEVICE_NOINLINE __noinline__ __device__
#define TL_PATCH

#define TILELANG_CHECK(stmt)                                                   \
  do {                                                                         \
    cudaError_t __err = (stmt);                                                \
    if (__err != cudaSuccess) {                                                \
      snprintf(error_buf, ERROR_BUF_SIZE, "%s:%d: %s - %s", __FILE__,          \
               __LINE__, cudaGetErrorName(__err), cudaGetErrorString(__err));  \
      return -1;                                                               \
    }                                                                          \
  } while (0)

// abs function for bfloat_t and half_t since there is no implicit conversion
// method
TL_PATCH TL_DEVICE half_t __habs(const half_t x) {
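For reference, a minimal sketch of how TILELANG_CHECK is intended to be used from a generated host function. The error_buf buffer, ERROR_BUF_SIZE constant, and the call signature are assumptions for illustration only; the HIP variant below works identically with hipError_t:

// Sketch only: error_buf and ERROR_BUF_SIZE are assumed to be defined by the
// generated host source; they are not declared in this header.
#define ERROR_BUF_SIZE 1024
static char error_buf[ERROR_BUF_SIZE];

extern "C" int call(float *x, size_t bytes, cudaStream_t stream) {
  // On failure, each check fills error_buf and returns -1 from call().
  TILELANG_CHECK(cudaMemsetAsync(x, 0, bytes, stream));
  TILELANG_CHECK(cudaStreamSynchronize(stream));
  return 0;
}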
11 changes: 11 additions & 0 deletions src/tl_templates/hip/common.h
@@ -26,6 +26,17 @@
#define ushort unsigned short

#define TL_DEVICE __forceinline__ __device__
#define TL_DEVICE_NOINLINE __noinline__ __device__

#define TILELANG_CHECK(stmt)                                                   \
  do {                                                                         \
    hipError_t __err = (stmt);                                                 \
    if (__err != hipSuccess) {                                                 \
      snprintf(error_buf, ERROR_BUF_SIZE, "%s:%d: %s - %s", __FILE__,          \
               __LINE__, hipGetErrorName(__err), hipGetErrorString(__err));    \
      return -1;                                                               \
    }                                                                          \
  } while (0)

#define half _Float16
#define __float2half_rn(x) half(x)
24 changes: 23 additions & 1 deletion src/transform/loop_partition.cc
@@ -141,10 +141,31 @@ class LoopPartitioner : public StmtExprVisitor {
    PrimExpr thd = FloorMod(access_idx, num_thread);
    PrimExpr idx = FloorDiv(access_idx, num_thread) * vectorize_size +
                   FloorMod(flattened, vectorize_size);
    return Fragment(loop_vars_, {idx}, {thd}, {});
    auto fragment = Fragment(loop_vars_, {idx}, {thd}, {});
    if (has_fragment_) {
      // for fragment buffer, we don't need to replicate the loop layout
      auto thread_extent = *as_const_int(fragment->ThreadExtent());
      auto num_thread_fragment = num_thread / thread_extent;
      fragment = fragment->Replicate(num_thread_fragment);
    }
    return fragment;
  }

 private:
  void VisitExpr_(const BufferLoadNode *op) final {
    if (op->buffer.scope() == "local.fragment") {
      has_fragment_ = true;
    }
    StmtExprVisitor::VisitExpr_(op);
  }

  void VisitStmt_(const BufferStoreNode *op) final {
    if (op->buffer.scope() == "local.fragment") {
      has_fragment_ = true;
    }
    StmtExprVisitor::VisitStmt_(op);
  }

  void VisitStmt_(const ForNode *node) final {
    if (node->kind == ForKind::kParallel) {
      body_ = node->body;
@@ -157,6 +178,7 @@ class LoopPartitioner : public StmtExprVisitor {

  Stmt body_;
  PrimExpr flattened = 0;
  bool has_fragment_ = false;
  Array<IterVar> loop_vars_;
};

1 change: 1 addition & 0 deletions tilelang/__init__.py
@@ -99,6 +99,7 @@ def _load_tile_lang_lib():
    language,  # noqa: F401
    engine,  # noqa: F401
)
from .transform import PassConfigKey # noqa: F401

from .engine import lower, register_cuda_postproc, register_hip_postproc # noqa: F401

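A side note on this one-line change: re-exporting the enum at the package root means downstream code can import it directly:

from tilelang import PassConfigKey  # same object as tilelang.transform.PassConfigKey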
28 changes: 22 additions & 6 deletions tilelang/engine/phase.py
@@ -3,17 +3,29 @@
from tvm import tir, IRModule
from tvm.target import Target
import tilelang
from tilelang.transform import PassContext
from typing import Optional


def allow_tma_and_warp_specialized(target: Target) -> bool:
def allow_tma_and_warp_specialized(pass_ctx: Optional[PassContext] = None,
                                   target: Optional[Target] = None) -> bool:
    if pass_ctx is None:
        pass_ctx = tilelang.transform.get_pass_context()
    if target.arch not in {"sm_90"}:
        return False
    cur_pass_ctx = tilelang.transform.get_pass_context()
    disable_tma_lower = cur_pass_ctx.config.get("tl.disable_tma_lower", False)
    disable_warp_specialized = cur_pass_ctx.config.get("tl.disable_warp_specialized", False)
    disable_tma_lower = pass_ctx.config.get("tl.disable_tma_lower", False)
    disable_warp_specialized = pass_ctx.config.get("tl.disable_warp_specialized", False)
    return not (disable_tma_lower and disable_warp_specialized)


def allow_vectorize(pass_ctx: Optional[PassContext] = None) -> bool:
    if pass_ctx is None:
        pass_ctx = tilelang.transform.get_pass_context()
    disable_vectorize = pass_ctx.config.get("tir.disable_vectorize", False)
    return not disable_vectorize


def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule:
    # Bind the target device information to the module
    mod = tir.transform.BindTarget(target)(mod)
@@ -40,8 +52,9 @@ def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule:


def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
    pass_ctx = tilelang.transform.get_pass_context()
    # which may be introduced by the LegalizeSafeMemoryAccess
    if allow_tma_and_warp_specialized(target):
    if allow_tma_and_warp_specialized(pass_ctx=pass_ctx, target=target):
        mod = tilelang.transform.IfStmtBinding()(mod)
        mod = tilelang.transform.MultiVersionBuffer()(mod)
        mod = tilelang.transform.WarpSpecialized()(mod)
@@ -59,11 +72,14 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
    mod = tilelang.transform.PipelinePlanning()(mod)
    mod = tilelang.transform.InjectSoftwarePipeline()(mod)
    mod = tilelang.transform.MergeIfStmt()(mod)

    mod = tir.transform.LowerOpaqueBlock()(mod)
    mod = tilelang.transform.FlattenBuffer()(mod)
    mod = tir.transform.NarrowDataType(32)(mod)
    mod = tir.transform.Simplify()(mod)
    mod = tilelang.transform.VectorizeLoop()(mod)

    mod = tilelang.transform.VectorizeLoop(enable_vectorize=allow_vectorize(pass_ctx=pass_ctx))(mod)

    mod = tir.transform.StorageRewrite()(mod)
    mod = tir.transform.UnrollLoop()(mod)
    mod = tir.transform.RenormalizeSplitPattern()(mod)
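As a usage sketch of the two helpers above: the config values here are hypothetical, the sm_90 target string is only for illustration, and the helpers are assumed importable from tilelang.engine.phase as laid out in this diff:

import tilelang
from tvm.target import Target
from tilelang.engine.phase import allow_vectorize, allow_tma_and_warp_specialized

# Hypothetical configuration: turn off vectorization and TMA lowering.
with tilelang.transform.PassContext(config={"tir.disable_vectorize": True,
                                            "tl.disable_tma_lower": True}):
    ctx = tilelang.transform.get_pass_context()
    assert not allow_vectorize(pass_ctx=ctx)
    # Still True: the TMA/warp-specialized path is only skipped when both
    # disable flags are set.
    assert allow_tma_and_warp_specialized(pass_ctx=ctx, target=Target("cuda -arch=sm_90"))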
18 changes: 11 additions & 7 deletions tilelang/jit/adapter/wrapper.py
@@ -45,7 +45,7 @@
PREDEF_HOST_FUNC = """
extern "C" int call({}) {{
{}
return 0;
\treturn 0;
}}
"""

@@ -196,7 +196,7 @@ def legalize_c(p):
                p = int(p)
            return str(p).replace("//", "/")

        _call_str = """"""
        kernel_launch_code = """"""
        desc_name_map: Dict[str, str] = {}
        for function_name, function_info in function_informations.items():
            block_info = function_info["block_info"]
@@ -221,14 +221,18 @@ def legalize_c(p):
grid_str = "dim3({}, {}, {})".format(
legalize_c(grid_info[0]), legalize_c(grid_info[1]), legalize_c(grid_info[2]))
smem_str = 0 if dynamic_smem_buf is None else dynamic_smem_buf
_call_str += "\t{}<<<{}, {}, {}, stream>>>({});\n".format(function_name, grid_str,
block_str, smem_str,
call_args)
kernel_launch_code += "\t{}<<<{}, {}, {}, stream>>>({});\n".format(
function_name, grid_str, block_str, smem_str, call_args)
kernel_launch_code += "\tcudaError_t err = cudaGetLastError();\n"
kernel_launch_code += "\tif (err != cudaSuccess) {{\n"
kernel_launch_code += f"\t\tsnprintf(error_buf, ERROR_BUF_SIZE, \"{function_name}: %s - %s\", cudaGetErrorName(err), cudaGetErrorString(err));\n"
kernel_launch_code += "\t\treturn -1;\n"
kernel_launch_code += "\t}}\n"

_call_str = self.generate_tma_descriptor_args(desc_name_map) + _call_str
kernel_launch_code = self.generate_tma_descriptor_args(desc_name_map) + kernel_launch_code

# Wrap the kernel dispatch logic in an external C function
host_func = PREDEF_HOST_FUNC.format(def_args, _call_str)
host_func = PREDEF_HOST_FUNC.format(def_args, kernel_launch_code)
return host_func

def generate_tma_descriptor_args(self, desc_name_map: Dict[str, str]) -> str:
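To make the effect of this change concrete, here is roughly what the generated host entry point could look like after this diff, for a single hypothetical kernel main_kernel with made-up launch dimensions and arguments:

extern "C" int call(float* A, float* B, cudaStream_t stream) {
    // Launch, then surface any launch-time error through error_buf.
    main_kernel<<<dim3(128, 1, 1), dim3(128, 1, 1), 0, stream>>>(A, B);
    cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess) {
        snprintf(error_buf, ERROR_BUF_SIZE, "main_kernel: %s - %s",
                 cudaGetErrorName(err), cudaGetErrorString(err));
        return -1;
    }
    return 0;
}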
6 changes: 4 additions & 2 deletions tilelang/transform/__init__.py
@@ -5,12 +5,14 @@

from . import _ffi_api
from .simplify import Simplify, simplify_prim_func # noqa: F401
from .pass_config import PassConfigKey # noqa: F401
from tilelang import tvm as tvm # noqa: F401
from tvm.ir.transform import PassContext # noqa: F401


def get_pass_context():
"""Get the current pass context"""
from tilelang import tvm as tvm
return tvm.transform.PassContext.current()
return PassContext.current()


def ClusterPlanning():
61 changes: 61 additions & 0 deletions tilelang/transform/pass_config.py
@@ -0,0 +1,61 @@
# Copyright (c) Tile-AI Corporation.
# Licensed under the MIT License.
# TODO: Add more documentation for each pass config

from enum import Enum


class PassConfigKey(str, Enum):
"""Pass configuration keys for TileLang compiler."""
# TileLang specific configs
TL_SIMPLIFY = "tl.Simplify"
"""Enable/disable TileLang simplification passes. Default: True"""

TL_DYNAMIC_ALIGNMENT = "tl.dynamic_alignment"
"""Memory alignment requirement for dynamic shapes. Default: 16"""

TL_DISABLE_DYNAMIC_TAIL_SPLIT = "tl.disable_dynamic_tail_split"
"""Disable dynamic tail splitting optimization. Default: False"""

TL_DISABLE_WARP_SPECIALIZED = "tl.disable_warp_specialized"
"""Disable warp specialization optimization. Default: False"""

TL_CONFIG_INDEX_BITWIDTH = "tl.config_index_bitwidth"
"""Bitwidth for configuration indices. Default: 32"""

TL_DISABLE_TMA_LOWER = "tl.disable_tma_lower"
"""Disable TMA (Tensor Memory Access) lowering. Default: False"""

# TIR related configs
TIR_ENABLE_EQUIV_TERMS_IN_CSE = "tir.enable_equiv_terms_in_cse_tir"
"""Enable equivalent terms in TIR Common Subexpression Elimination. Default: True"""

TIR_DISABLE_CSE = "tir.disable_cse_tir"
"""Disable TIR Common Subexpression Elimination. Default: False"""

TIR_SIMPLIFY = "tir.Simplify"
"""Enable/disable TIR simplification passes. Default: True"""

TIR_DISABLE_STORAGE_REWRITE = "tir.disable_storage_rewrite"
"""Disable storage rewrite optimization. Default: False"""

TIR_DISABLE_VECTORIZE = "tir.disable_vectorize"
"""Disable vectorization optimization. Default: False"""

TIR_USE_ASYNC_COPY = "tir.use_async_copy"
"""Enable asynchronous memory copy operations. Default: True"""

TIR_ENABLE_DEBUG = "tir.enable_debug"
"""Enable debug information in generated code. Default: False"""

TIR_MERGE_STATIC_SMEM = "tir.merge_static_smem"
"""Merge static shared memory allocations. Default: True"""

TIR_ADD_LOWER_PASS = "tir.add_lower_pass"
"""Additional lowering passes to be applied. Default: None"""

TIR_NOALIAS = "tir.noalias"
"""Enable pointer non-aliasing assumptions. Default: True"""

CUDA_KERNELS_OUTPUT_DIR = "cuda.kernels_output_dir"
"""Output directory for generated CUDA kernels. Default: empty string"""