Skip to content

Commit 6ac309c

Browse files
authored
[Feature]:Add device assert (tile-ai#1116)
* update * update
1 parent 37fa1e1 commit 6ac309c

File tree

4 files changed

+68
-2
lines changed

4 files changed

+68
-2
lines changed

src/tl_templates/cuda/debug.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,3 +257,12 @@ __device__ void debug_print_buffer_value<int16_t>(const char *msg,
257257
msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
258258
threadIdx.z, buf_name, index, (int32_t)var);
259259
}
260+
261+
// Device-side assertion helper: forwards `cond` to assert().
TL_DEVICE void device_assert(bool cond) {
  assert(cond);
}
262+
263+
// Device-side assertion helper that reports a message.
// When `cond` is true this does nothing. Otherwise it prints
// "Device assert failed: <msg>" with printf and then calls assert(0).
TL_DEVICE void device_assert_with_msg(bool cond, const char *msg) {
  if (cond) {
    return;
  }
  printf("Device assert failed: %s\n", msg);
  assert(0);
}
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# type: ignore
2+
import tilelang
3+
import tilelang.testing
4+
import tilelang.language as T
5+
6+
7+
# TODO(dyq) It intentionally triggers a device-side assert so we can't include this in CI
8+
# Please run manually when you want to verify that device_assert actually traps on GPU.
9+
def _manual_device_assert_triggered():
    """Compile and launch a kernel whose device assert is meant to fail.

    The kernel uses 128 threads and asserts ``tid > 0`` with the message
    "Assertion Trigger !". The leading underscore keeps the name outside the
    ``test_*`` pattern; the function is invoked from the ``__main__`` guard
    of this file instead.
    """

    @T.prim_func
    def program():
        with T.Kernel(threads=128):
            tid = T.get_thread_binding()
            T.device_assert(tid > 0, "Assertion Trigger !")

    # Build for CUDA, then execute the kernel a single time via its profiler.
    kernel = tilelang.compile(program, target="cuda")
    kernel.get_profiler().run_once()
20+
21+
22+
def test_device_assert_no_trigger():
    """Compile and launch a kernel whose device assert condition is ``tid == tid``.

    No message is passed to ``T.device_assert`` here, so the default empty
    ``msg`` is used. The test passes if compiling and running the kernel once
    raises nothing.
    """

    @T.prim_func
    def program():
        with T.Kernel(threads=128):
            tid = T.get_thread_binding()
            T.device_assert(tid == tid)

    # Build for CUDA, then execute the kernel a single time via its profiler.
    kernel = tilelang.compile(program, target="cuda")
    kernel.get_profiler().run_once()
33+
34+
35+
if __name__ == "__main__":
    # Running this file directly executes only the manual, intentionally
    # failing scenario; test_device_assert_no_trigger is not called here.
    _manual_device_assert_triggered()

tilelang/language/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@
6464
cumsum, # noqa: F401
6565
finalize_reducer, # noqa: F401
6666
)
67-
from .print import print # noqa: F401
67+
from .print import print, device_assert # noqa: F401
6868
from .customize import (
6969
atomic_max, # noqa: F401
7070
atomic_min, # noqa: F401

tilelang/language/print.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""
22
This module provides macros and utilities for debugging TileLang (tl) programs.
3-
It includes functionality to print variables, print values in buffers, and conditionally execute debug prints.
3+
It includes functionality to print variables, print values in buffers, conditionally execute debug prints, and emit device-side asserts.
44
"""
55

66
from tvm import tir
@@ -133,6 +133,27 @@ def print_local_buffer_with_condition(condition: tir.PrimExpr,
133133
buffer[coords])
134134

135135

136+
from tilelang.utils.target import check_cuda_availability
137+
import warnings
138+
139+
# Result of check_cuda_availability(), evaluated once at module level.
# device_assert only emits its extern call when this value is truthy.
_IS_CUDA_AVAILABLE = check_cuda_availability()
140+
141+
142+
@macro
def device_assert(condition: tir.PrimExpr, msg: str = ""):
    """
    Emit a device-side assertion on ``condition``.

    Nothing is emitted unless the module-level ``_IS_CUDA_AVAILABLE`` flag is
    truthy. When it is:

    * with a non-empty ``msg``, a Python warning is issued and an extern call
      to ``device_assert_with_msg(condition, msg)`` is generated;
    * with the default empty ``msg``, an extern call to
      ``device_assert(condition)`` is generated.

    Args:
        condition: Expression passed to the extern assert helper.
        msg: Optional message forwarded to ``device_assert_with_msg``.
    """
    if _IS_CUDA_AVAILABLE:
        if msg != "":
            # The message variant goes through a different extern helper; the
            # warning text below is what callers see when they pass a message.
            warnings.warn("Non-empty msg may slightly slow down the kernel", stacklevel=2)
            tir.call_extern("void", "device_assert_with_msg", condition, msg)
        else:
            tir.call_extern("void", "device_assert", condition)
155+
156+
136157
def print(obj: Any, msg: str = "", warp_group_id: int = 0, warp_id: int = 0) -> tir.PrimExpr:
137158
"""
138159
A generic print function that handles both TIR buffers and primitive expressions.

0 commit comments

Comments
 (0)