Skip to content

Commit b37d035

Browse files
committed
[Feature]:Add device assert
1 parent 50e789d commit b37d035

File tree

4 files changed

+49
-2
lines changed

4 files changed

+49
-2
lines changed

src/tl_templates/cuda/debug.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,3 +257,10 @@ __device__ void debug_print_buffer_value<int16_t>(const char *msg,
257257
msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
258258
threadIdx.z, buf_name, index, (int32_t)var);
259259
}
260+
261+
__device__ void device_assert(bool cond, const char *msg) {
262+
if (!cond) {
263+
printf("Device assert failed: %s\n", msg);
264+
assert(0);
265+
}
266+
}
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
# type: ignore
2+
3+
import tilelang
4+
import tilelang.testing
5+
import tilelang.language as T
6+
import os
7+
8+
os.environ["TILELANG_DEBUG"] = "1"
9+
10+
11+
def test_device_assert():
12+
13+
@T.prim_func
14+
def program():
15+
with T.Kernel(threads=128):
16+
tid = T.get_thread_binding()
17+
T.device_assert(tid != 0, "Assertion Trigger !")
18+
19+
jit_kernel = tilelang.compile(program, target="cuda")
20+
print(jit_kernel.kernel_source)
21+
profiler = jit_kernel.get_profiler()
22+
profiler.run_once()
23+
24+
25+
if __name__ == "__main__":
26+
27+
tilelang.testing.main()

tilelang/language/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@
6464
cumsum, # noqa: F401
6565
finalize_reducer, # noqa: F401
6666
)
67-
from .print import print # noqa: F401
67+
from .print import print, device_assert # noqa: F401
6868
from .customize import (
6969
atomic_max, # noqa: F401
7070
atomic_min, # noqa: F401

tilelang/language/print.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
"""
22
This module provides macros and utilities for debugging TileLang (tl) programs.
3-
It includes functionality to print variables, print values in buffers, and conditionally execute debug prints.
3+
It includes functionality to print variables, print values in buffers, conditionally execute debug prints and assert.
44
"""
55

6+
import os
67
from tvm import tir
78
from typing import Any
89
from tilelang.language.kernel import get_thread_bindings
@@ -133,6 +134,18 @@ def print_local_buffer_with_condition(condition: tir.PrimExpr,
133134
buffer[coords])
134135

135136

137+
@macro
138+
def device_assert(condition: tir.PrimExpr, msg: str = ""):
139+
"""
140+
Device-side assert emulation.
141+
In debug mode (TILELANG_DEBUG=1), performs Python-side conditional check
142+
and injects a printf + abort on the device if desired.
143+
"""
144+
debug_mode = os.environ.get("TILELANG_DEBUG", "0") not in ("0", "", "false", "False")
145+
if debug_mode:
146+
tir.call_extern("void", "device_assert", condition, msg)
147+
148+
136149
def print(obj: Any, msg: str = "", warp_group_id: int = 0, warp_id: int = 0) -> tir.PrimExpr:
137150
"""
138151
A generic print function that handles both TIR buffers and primitive expressions.

0 commit comments

Comments
 (0)