Commit 2ea45ae

[Bugfix] Fix layout inference for free fragment buffer (tile-ai#443)
* [Enhancement] Improve layout inference accuracy in ParallelOp (tile-ai#441)
  * Added logic to use non-replicated buffers as source buffers for more accurate layout inference.
  * Enhanced comments to clarify the rationale behind buffer selection in the layout inference process.
* [Enhancement] Add error handling macros and refactor loop partitioning logic
  * Introduced the TILELANG_CHECK macro for improved error handling in CUDA and HIP code, providing detailed error messages for kernel launches.
  * Enhanced loop partitioning logic to handle fragment buffers more effectively, ensuring correct replication based on thread extent.
  * Added logging for the thread range in PlanLoopPartition to aid in debugging and performance analysis.
  * Updated pass configuration management to streamline vectorization control in the optimization process.
* lint fix
* remove debug print
1 parent 734c7fb commit 2ea45ae

8 files changed: +141 −16 lines


src/tl_templates/cuda/common.h

Lines changed: 10 additions & 0 deletions
@@ -25,6 +25,16 @@ using int4_t = int4;
 #define TL_DEVICE_NOINLINE __noinline__ __device__
 #define TL_PATCH
 
+#define TILELANG_CHECK(stmt)                                                   \
+  do {                                                                         \
+    cudaError_t __err = (stmt);                                                \
+    if (__err != cudaSuccess) {                                                \
+      snprintf(error_buf, ERROR_BUF_SIZE, "%s:%d: %s - %s", __FILE__,          \
+               __LINE__, cudaGetErrorName(__err), cudaGetErrorString(__err));  \
+      return -1;                                                               \
+    }                                                                          \
+  } while (0)
+
 // abs function for bfloat_t and half_t since there is no implicit convertion
 // method
 TL_PATCH TL_DEVICE half_t __habs(const half_t x) {

src/tl_templates/hip/common.h

Lines changed: 11 additions & 0 deletions
@@ -24,6 +24,17 @@
 #define ushort unsigned short
 
 #define TL_DEVICE __forceinline__ __device__
+#define TL_DEVICE_NOINLINE __noinline__ __device__
+
+#define TILELANG_CHECK(stmt)                                                   \
+  do {                                                                         \
+    hipError_t __err = (stmt);                                                 \
+    if (__err != hipSuccess) {                                                 \
+      snprintf(error_buf, ERROR_BUF_SIZE, "%s:%d: %s - %s", __FILE__,          \
+               __LINE__, hipGetErrorName(__err), hipGetErrorString(__err));    \
+      return -1;                                                               \
+    }                                                                          \
+  } while (0)
 
 #define half _Float16
 #define __float2half_rn(x) half(x)

src/transform/loop_partition.cc

Lines changed: 23 additions & 1 deletion
@@ -141,10 +141,31 @@ class LoopPartitioner : public StmtExprVisitor {
     PrimExpr thd = FloorMod(access_idx, num_thread);
     PrimExpr idx = FloorDiv(access_idx, num_thread) * vectorize_size +
                    FloorMod(flattened, vectorize_size);
-    return Fragment(loop_vars_, {idx}, {thd}, {});
+    auto fragment = Fragment(loop_vars_, {idx}, {thd}, {});
+    if (has_fragment_) {
+      // for fragment buffer, we don't need to replicate the loop layout
+      auto thread_extent = *as_const_int(fragment->ThreadExtent());
+      auto num_thread_fragment = num_thread / thread_extent;
+      fragment = fragment->Replicate(num_thread_fragment);
+    }
+    return fragment;
   }
 
 private:
+  void VisitExpr_(const BufferLoadNode *op) final {
+    if (op->buffer.scope() == "local.fragment") {
+      has_fragment_ = true;
+    }
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitStmt_(const BufferStoreNode *op) final {
+    if (op->buffer.scope() == "local.fragment") {
+      has_fragment_ = true;
+    }
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
   void VisitStmt_(const ForNode *node) final {
     if (node->kind == ForKind::kParallel) {
       body_ = node->body;
@@ -157,6 +178,7 @@ class LoopPartitioner : public StmtExprVisitor {
 
   Stmt body_;
   PrimExpr flattened = 0;
+  bool has_fragment_ = false;
   Array<IterVar> loop_vars_;
 };
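The replication factor computed above is simply the ratio of the loop's available threads to the thread extent already bound in the inferred fragment layout. A worked example of that arithmetic in Python, with all numbers assumed for illustration rather than taken from the commit:

    # Hypothetical values: the parallel loop runs with 128 threads, but the
    # layout inferred for a local.fragment access only binds 32 of them.
    num_thread = 128      # threads available to the parallel loop (assumed)
    thread_extent = 32    # *as_const_int(fragment->ThreadExtent()) (assumed)
    num_thread_fragment = num_thread // thread_extent
    print(num_thread_fragment)  # 4 -> fragment->Replicate(4) covers all 128 threads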
tilelang/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -96,6 +96,7 @@ def _load_tile_lang_lib():
     language,  # noqa: F401
     engine,  # noqa: F401
 )
+from .transform import PassConfigKey  # noqa: F401
 
 from .engine import lower, register_cuda_postproc, register_hip_postproc  # noqa: F401
101102

tilelang/engine/phase.py

Lines changed: 22 additions & 6 deletions
@@ -1,17 +1,29 @@
 from tvm import tir, IRModule
 from tvm.target import Target
 import tilelang
+from tilelang.transform import PassContext
+from typing import Optional
 
 
-def allow_tma_and_warp_specialized(target: Target) -> bool:
+def allow_tma_and_warp_specialized(pass_ctx: Optional[PassContext] = None,
+                                   target: Optional[Target] = None) -> bool:
+    if pass_ctx is None:
+        pass_ctx = tilelang.transform.get_pass_context()
     if target.arch not in {"sm_90"}:
         return False
-    cur_pass_ctx = tilelang.transform.get_pass_context()
-    disable_tma_lower = cur_pass_ctx.config.get("tl.disable_tma_lower", False)
-    disable_warp_specialized = cur_pass_ctx.config.get("tl.disable_warp_specialized", False)
+    disable_tma_lower = pass_ctx.config.get("tl.disable_tma_lower", False)
+    disable_tma_lower = pass_ctx.config.get("tl.disable_tma_lower", False)
+    disable_warp_specialized = pass_ctx.config.get("tl.disable_warp_specialized", False)
     return not (disable_tma_lower and disable_warp_specialized)
 
 
+def allow_vectorize(pass_ctx: Optional[PassContext] = None) -> bool:
+    if pass_ctx is None:
+        pass_ctx = tilelang.transform.get_pass_context()
+    disable_vectorize = pass_ctx.config.get("tir.disable_vectorize", False)
+    return not disable_vectorize
+
+
 def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule:
     # Bind the target device information to the module
     mod = tir.transform.BindTarget(target)(mod)
@@ -38,8 +50,9 @@ def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule:
 
 
 def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
+    pass_ctx = tilelang.transform.get_pass_context()
     # which may be introduced by the LegalizeSafeMemoryAccess
-    if allow_tma_and_warp_specialized(target):
+    if allow_tma_and_warp_specialized(pass_ctx=pass_ctx, target=target):
         mod = tilelang.transform.IfStmtBinding()(mod)
         mod = tilelang.transform.MultiVersionBuffer()(mod)
         mod = tilelang.transform.WarpSpecialized()(mod)
@@ -57,11 +70,14 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
     mod = tilelang.transform.PipelinePlanning()(mod)
     mod = tilelang.transform.InjectSoftwarePipeline()(mod)
     mod = tilelang.transform.MergeIfStmt()(mod)
+
     mod = tir.transform.LowerOpaqueBlock()(mod)
     mod = tilelang.transform.FlattenBuffer()(mod)
     mod = tir.transform.NarrowDataType(32)(mod)
     mod = tir.transform.Simplify()(mod)
-    mod = tilelang.transform.VectorizeLoop()(mod)
+
+    mod = tilelang.transform.VectorizeLoop(enable_vectorize=allow_vectorize(pass_ctx=pass_ctx))(mod)
+
     mod = tir.transform.StorageRewrite()(mod)
     mod = tir.transform.UnrollLoop()(mod)
     mod = tir.transform.RenormalizeSplitPattern()(mod)
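With allow_vectorize in place, OptimizeForTarget reads the vectorization switch once from the pass context and threads it into VectorizeLoop, rather than calling it unconditionally. A minimal sketch of toggling it from user code, assuming the standard TVM pass-context API (tir.disable_vectorize is a registered TVM config key):

    from tilelang.engine.phase import allow_vectorize
    from tvm import transform

    # Inside a context that sets tir.disable_vectorize, the pipeline calls
    # VectorizeLoop(enable_vectorize=False); otherwise vectorization stays on.
    with transform.PassContext(config={"tir.disable_vectorize": True}):
        print(allow_vectorize())  # False
    print(allow_vectorize())      # True in the default context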

tilelang/jit/adapter/wrapper.py

Lines changed: 11 additions & 7 deletions
@@ -42,7 +42,7 @@
 PREDEF_HOST_FUNC = """
 extern "C" int call({}) {{
 {}
-return 0;
+\treturn 0;
 }}
 """
 
@@ -193,7 +193,7 @@ def legalize_c(p):
             p = int(p)
             return str(p).replace("//", "/")
 
-        _call_str = """"""
+        kernel_launch_code = """"""
         desc_name_map: Dict[str, str] = {}
         for function_name, function_info in function_informations.items():
             block_info = function_info["block_info"]
@@ -218,14 +218,18 @@ def legalize_c(p):
             grid_str = "dim3({}, {}, {})".format(
                 legalize_c(grid_info[0]), legalize_c(grid_info[1]), legalize_c(grid_info[2]))
             smem_str = 0 if dynamic_smem_buf is None else dynamic_smem_buf
-            _call_str += "\t{}<<<{}, {}, {}, stream>>>({});\n".format(function_name, grid_str,
-                                                                      block_str, smem_str,
-                                                                      call_args)
+            kernel_launch_code += "\t{}<<<{}, {}, {}, stream>>>({});\n".format(
+                function_name, grid_str, block_str, smem_str, call_args)
+            kernel_launch_code += "\tcudaError_t err = cudaGetLastError();\n"
+            kernel_launch_code += "\tif (err != cudaSuccess) {{\n"
+            kernel_launch_code += f"\t\tsnprintf(error_buf, ERROR_BUF_SIZE, \"{function_name}: %s - %s\", cudaGetErrorName(err), cudaGetErrorString(err));\n"
+            kernel_launch_code += "\t\treturn -1;\n"
+            kernel_launch_code += "\t}}\n"
 
-        _call_str = self.generate_tma_descriptor_args(desc_name_map) + _call_str
+        kernel_launch_code = self.generate_tma_descriptor_args(desc_name_map) + kernel_launch_code
 
         # Wrap the kernel dispatch logic in an external C function
-        host_func = PREDEF_HOST_FUNC.format(def_args, _call_str)
+        host_func = PREDEF_HOST_FUNC.format(def_args, kernel_launch_code)
         return host_func
 
     def generate_tma_descriptor_args(self, desc_name_map: Dict[str, str]) -> str:
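Together with the TILELANG_CHECK macro added to common.h, every generated launch is now followed by a cudaGetLastError() check that records the failure in error_buf and makes the host call() function return -1. A sketch of the C snippet this loop emits for one kernel, with the kernel name, grid, and arguments assumed for illustration and the {{ }} brace escapes shown resolved:

    main_kernel<<<dim3(132, 1, 1), dim3(128, 1, 1), 0, stream>>>(A, B);
    cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess) {
        snprintf(error_buf, ERROR_BUF_SIZE, "main_kernel: %s - %s", cudaGetErrorName(err), cudaGetErrorString(err));
        return -1;
    }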

tilelang/transform/__init__.py

Lines changed: 4 additions & 2 deletions
@@ -3,12 +3,14 @@
 
 from . import _ffi_api
 from .simplify import Simplify, simplify_prim_func  # noqa: F401
+from .pass_config import PassConfigKey  # noqa: F401
+from tilelang import tvm as tvm  # noqa: F401
+from tvm.ir.transform import PassContext  # noqa: F401
 
 
 def get_pass_context():
     """Get the current pass context"""
-    from tilelang import tvm as tvm
-    return tvm.transform.PassContext.current()
+    return PassContext.current()
 
 
 def ClusterPlanning():

tilelang/transform/pass_config.py

Lines changed: 59 additions & 0 deletions
@@ -0,0 +1,59 @@
+# TODO: Add more documentation for each pass config
+
+from enum import Enum
+
+
+class PassConfigKey(str, Enum):
+    """Pass configuration keys for TileLang compiler."""
+    # TileLang specific configs
+    TL_SIMPLIFY = "tl.Simplify"
+    """Enable/disable TileLang simplification passes. Default: True"""
+
+    TL_DYNAMIC_ALIGNMENT = "tl.dynamic_alignment"
+    """Memory alignment requirement for dynamic shapes. Default: 16"""
+
+    TL_DISABLE_DYNAMIC_TAIL_SPLIT = "tl.disable_dynamic_tail_split"
+    """Disable dynamic tail splitting optimization. Default: False"""
+
+    TL_DISABLE_WARP_SPECIALIZED = "tl.disable_warp_specialized"
+    """Disable warp specialization optimization. Default: False"""
+
+    TL_CONFIG_INDEX_BITWIDTH = "tl.config_index_bitwidth"
+    """Bitwidth for configuration indices. Default: 32"""
+
+    TL_DISABLE_TMA_LOWER = "tl.disable_tma_lower"
+    """Disable TMA (Tensor Memory Access) lowering. Default: False"""
+
+    # TIR related configs
+    TIR_ENABLE_EQUIV_TERMS_IN_CSE = "tir.enable_equiv_terms_in_cse_tir"
+    """Enable equivalent terms in TIR Common Subexpression Elimination. Default: True"""
+
+    TIR_DISABLE_CSE = "tir.disable_cse_tir"
+    """Disable TIR Common Subexpression Elimination. Default: False"""
+
+    TIR_SIMPLIFY = "tir.Simplify"
+    """Enable/disable TIR simplification passes. Default: True"""
+
+    TIR_DISABLE_STORAGE_REWRITE = "tir.disable_storage_rewrite"
+    """Disable storage rewrite optimization. Default: False"""
+
+    TIR_DISABLE_VECTORIZE = "tir.disable_vectorize"
+    """Disable vectorization optimization. Default: False"""
+
+    TIR_USE_ASYNC_COPY = "tir.use_async_copy"
+    """Enable asynchronous memory copy operations. Default: True"""
+
+    TIR_ENABLE_DEBUG = "tir.enable_debug"
+    """Enable debug information in generated code. Default: False"""
+
+    TIR_MERGE_STATIC_SMEM = "tir.merge_static_smem"
+    """Merge static shared memory allocations. Default: True"""
+
+    TIR_ADD_LOWER_PASS = "tir.add_lower_pass"
+    """Additional lowering passes to be applied. Default: None"""
+
+    TIR_NOALIAS = "tir.noalias"
+    """Enable pointer non-aliasing assumptions. Default: True"""
+
+    CUDA_KERNELS_OUTPUT_DIR = "cuda.kernels_output_dir"
+    """Output directory for generated CUDA kernels. Default: empty string"""
