Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions src/tl_templates/cuda/debug.h
Original file line number Diff line number Diff line change
Expand Up @@ -257,3 +257,12 @@ __device__ void debug_print_buffer_value<int16_t>(const char *msg,
msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
threadIdx.z, buf_name, index, (int32_t)var);
}

// Thin wrapper that forwards `cond` to assert().
// NOTE(review): assert() compiles to a no-op when NDEBUG is defined — confirm
// the flags used to build generated kernels keep it enabled.
TL_DEVICE void device_assert(bool cond) {
  assert(cond);
}

// Assertion variant that prints `msg` before failing.
// When `cond` holds this is a no-op; otherwise the message is printed with a
// fixed "Device assert failed: " prefix and assert(0) is invoked.
// NOTE(review): if NDEBUG is defined, assert(0) does nothing and only the
// printf remains — confirm the build flags for generated kernels.
TL_DEVICE void device_assert_with_msg(bool cond, const char *msg) {
  if (cond) {
    return;
  }
  printf("Device assert failed: %s\n", msg);
  assert(0);
}
36 changes: 36 additions & 0 deletions testing/python/debug/test_device_assert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# type: ignore
import tilelang
import tilelang.testing
import tilelang.language as T


# TODO(dyq): This intentionally triggers a device-side assert, so it cannot be included in CI.
# Run it manually when you want to verify that device_assert actually traps on the GPU.
def _manual_device_assert_triggered():
    """Compile and launch a CUDA kernel whose device assert is meant to fail.

    The kernel asserts ``tid > 0`` with a message, using 128 threads.
    Presumably the thread whose binding is 0 violates the condition — verify
    against ``T.get_thread_binding``. The name has no ``test_`` prefix, so it
    is only run when called explicitly.
    """

    @T.prim_func
    def program():
        with T.Kernel(threads=128):
            tid = T.get_thread_binding()
            T.device_assert(tid > 0, "Assertion Trigger !")

    kernel = tilelang.compile(program, target="cuda")
    kernel.get_profiler().run_once()


def test_device_assert_no_trigger():
    """Compile and launch a CUDA kernel whose device assert should not fire.

    The asserted condition is ``tid == tid`` and no message is given, so the
    message-free form of ``T.device_assert`` is exercised.
    """

    @T.prim_func
    def program():
        with T.Kernel(threads=128):
            tid = T.get_thread_binding()
            T.device_assert(tid == tid)

    kernel = tilelang.compile(program, target="cuda")
    kernel.get_profiler().run_once()


# Direct execution runs only the manual case, which is intended to trip the device-side assert.
if __name__ == "__main__":
_manual_device_assert_triggered()
2 changes: 1 addition & 1 deletion tilelang/language/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@
cumsum, # noqa: F401
finalize_reducer, # noqa: F401
)
from .print import print # noqa: F401
from .print import print, device_assert # noqa: F401
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Export OK; drop unused noqa to satisfy Ruff (RUF100).

Ruff flags # noqa: F401 here as unused. Either remove it or enable the rule in config. Minimal fix below.

-from .print import print, device_assert  # noqa: F401
+from .print import print, device_assert
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
from .print import print, device_assert # noqa: F401
from .print import print, device_assert
🧰 Tools
🪛 Ruff (0.14.1)

67-67: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)

🤖 Prompt for AI Agents
In tilelang/language/__init__.py around line 67, the trailing comment "# noqa:
F401" is flagged by Ruff as unused; remove the unnecessary noqa from the import
line so the export remains but the linter warning is resolved (alternatively, if
intentional, enable the rule in the Ruff config), ensuring the import still
exposes print and device_assert.

from .customize import (
atomic_max, # noqa: F401
atomic_min, # noqa: F401
Expand Down
23 changes: 22 additions & 1 deletion tilelang/language/print.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""
This module provides macros and utilities for debugging TileLang (tl) programs.
It includes functionality to print variables, print values in buffers, and conditionally execute debug prints.
It includes functionality to print variables, print values in buffers, conditionally execute debug prints, and emit device-side asserts.
"""

from tvm import tir
Expand Down Expand Up @@ -133,6 +133,27 @@ def print_local_buffer_with_condition(condition: tir.PrimExpr,
buffer[coords])


from tilelang.utils.target import check_cuda_availability
import warnings

# Evaluated once at import time; device_assert emits nothing when this is falsy.
_IS_CUDA_AVAILABLE = check_cuda_availability()


@macro
def device_assert(condition: tir.PrimExpr, msg: str = ""):
    """
    Emit a device-side assertion on ``condition``.

    Nothing is emitted unless the module-level ``_IS_CUDA_AVAILABLE`` flag is
    truthy. With the default empty ``msg`` an extern call to ``device_assert``
    is generated. Otherwise a warning is issued and an extern call to
    ``device_assert_with_msg`` is generated with ``msg`` as an extra argument.

    Args:
        condition: Expression forwarded to the extern assert call.
        msg: Optional message forwarded to ``device_assert_with_msg``.
    """
    if _IS_CUDA_AVAILABLE:
        if msg != "":
            warnings.warn("Non-empty msg may slightly slow down the kernel", stacklevel=2)
            tir.call_extern("void", "device_assert_with_msg", condition, msg)
        else:
            tir.call_extern("void", "device_assert", condition)


def print(obj: Any, msg: str = "", warp_group_id: int = 0, warp_id: int = 0) -> tir.PrimExpr:
"""
A generic print function that handles both TIR buffers and primitive expressions.
Expand Down
Loading