-
Notifications
You must be signed in to change notification settings - Fork 331
[Feature]:Add device assert #1116
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀 |
WalkthroughAdded CUDA device-side assertion helpers in the runtime templates, exposed a Changes
Sequence Diagram(s)sequenceDiagram
autonumber
participant Source as TileLang source
participant Macro as device_assert macro
participant Compiler as TIR compiler
participant CUDA as CUDA toolchain/runtime
participant Device as CUDA device
Source->>Macro: device_assert(condition, msg)
Macro->>Macro: check TILELANG_DEBUG && CUDA available
alt debug && cuda available
Macro->>Compiler: emit tir.call_extern("device_assert..." , condition, msg)
Compiler->>CUDA: compile & launch kernel
CUDA->>Device: execute kernel
Device->>Device: device_assert(cond,msg)
alt cond false
Note right of Device: print msg (if provided)\nand trigger assert(0)
Device-->>CUDA: kernel failure/report
else cond true
Device-->>CUDA: continue normally
end
else skip emission
Macro-->>Source: no extern emitted
end
Estimated code review effort🎯 3 (Moderate) | ⏱️ ~25 minutes
Poem
Pre-merge checks and finishing touches❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
✨ Finishing touches
🧪 Generate unit tests (beta)
📜 Recent review detailsConfiguration used: CodeRabbit UI Review profile: CHILL Plan: Pro 📒 Files selected for processing (1)
🚧 Files skipped from review as they are similar to previous changes (1)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 4
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
tilelang/language/print.py (1)
60-62: Undefined variables used in else-branch.
iandcoordsare out of scope here; this will fail IR generation whenconditionis False. Remove the else-branch or guard properly.- else: - tir.call_extern("handle", "debug_print_buffer_value", msg, buffer.name, i, buffer[coords]) + # else: no-op when condition is Falsesrc/tl_templates/cuda/debug.h (1)
6-8: Include <assert.h> for device assert.Be explicit about the assert dependency to avoid build surprises across NVCC/NVRTC versions.
#ifndef __CUDACC_RTC__ #include <cstdio> #endif +#include <assert.h>
🧹 Nitpick comments (1)
src/tl_templates/cuda/debug.h (1)
261-266: Consider C linkage to match extern call; keep device abort.If TVM emits a plain
device_assert(...)call without a C++ prototype in scope, C++ name mangling can bite. Wrap inextern "C"to provide a stable symbol.-__device__ void device_assert(bool cond, const char *msg) { +extern "C" __device__ void device_assert(bool cond, const char *msg) { if (!cond) { printf("Device assert failed: %s\n", msg); assert(0); } }If your kernel prologue already includes this header and provides a matching prototype, keeping C++ linkage is fine—please confirm.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
src/tl_templates/cuda/debug.h(1 hunks)testing/python/debug/test_device_assert.py(1 hunks)tilelang/language/__init__.py(1 hunks)tilelang/language/print.py(2 hunks)
🧰 Additional context used
🧬 Code graph analysis (4)
src/tl_templates/cuda/debug.h (1)
tilelang/language/print.py (1)
device_assert(138-146)
testing/python/debug/test_device_assert.py (3)
tilelang/language/print.py (1)
device_assert(138-146)tilelang/jit/__init__.py (1)
compile(30-79)tilelang/jit/kernel.py (2)
kernel_source(461-462)get_profiler(367-383)
tilelang/language/print.py (1)
tilelang/language/tir/op.py (1)
call_extern(173-195)
tilelang/language/__init__.py (1)
tilelang/language/print.py (1)
device_assert(138-146)
🪛 Ruff (0.14.1)
tilelang/language/__init__.py
67-67: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Quick Lint
🔇 Additional comments (1)
tilelang/language/__init__.py (1)
69-81: Verify new re-exports resolve at import time.Double‑check all of these names exist in
tilelang.language.customizeacross targets to avoidImportErroron import.
| finalize_reducer, # noqa: F401 | ||
| ) | ||
| from .print import print # noqa: F401 | ||
| from .print import print, device_assert # noqa: F401 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Export OK; drop unused noqa to satisfy Ruff (RUF100).
Ruff flags # noqa: F401 here as unused. Either remove it or enable the rule in config. Minimal fix below.
-from .print import print, device_assert # noqa: F401
+from .print import print, device_assert📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| from .print import print, device_assert # noqa: F401 | |
| from .print import print, device_assert |
🧰 Tools
🪛 Ruff (0.14.1)
67-67: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
🤖 Prompt for AI Agents
In tilelang/language/__init__.py around line 67, the trailing comment "# noqa:
F401" is flagged by Ruff as unused; remove the unnecessary noqa from the import
line so the export remains but the linter warning is resolved (alternatively, if
intentional, enable the rule in the Ruff config), ensuring the import still
exposes print and device_assert.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 0
🧹 Nitpick comments (3)
tilelang/language/print.py (3)
1-4: Module docstring updated appropriately.The addition of "assert" functionality to the docstring is correct. Minor suggestion: consider "assertions" for consistency with "print variables" and "print values" phrasing.
137-139: Move imports to top of file.While the CUDA availability check pattern is correct, the import should be placed at the top of the file with other imports (after line 11) rather than mid-file. This follows Python conventions and matches the suggestion from past reviews.
Apply this diff to move the imports to the top:
from tilelang.language.kernel import get_thread_bindings from tilelang.language import copy, macro, serial, alloc_shared from tilelang.language.utils import index_to_coordinates +from tilelang.utils.target import check_cuda_availability + +_IS_CUDA_AVAILABLE = check_cuda_availability()And remove from lines 137-139:
-from tilelang.utils.target import check_cuda_availability - -_IS_CUDA_AVAILABLE = check_cuda_availability() -
142-151: Implementation correctly addresses past review concerns.The
device_assertmacro properly:
- Gates the extern call to CUDA-available targets using
_IS_CUDA_AVAILABLE- Checks debug mode via
TILELANG_DEBUGenvironment variable- Uses the established availability pattern from the codebase
The logic correctly prevents undefined extern symbols on non-CUDA backends, resolving the critical issue raised in previous reviews.
Optional: The environment variable check is case-sensitive for the string values. Consider using
.lower()for more robust parsing:debug_mode = os.environ.get("TILELANG_DEBUG", "0").lower() not in ("0", "", "false")
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
testing/python/debug/test_device_assert.py(1 hunks)tilelang/language/print.py(2 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- testing/python/debug/test_device_assert.py
🧰 Additional context used
🧬 Code graph analysis (1)
tilelang/language/print.py (2)
tilelang/utils/target.py (1)
check_cuda_availability(28-38)tilelang/language/tir/op.py (1)
call_extern(173-195)
🔇 Additional comments (1)
tilelang/language/print.py (1)
6-6: LGTM: Import placed correctly.The
osimport is properly positioned and necessary for the environment variable check indevice_assert.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 0
♻️ Duplicate comments (1)
testing/python/debug/test_device_assert.py (1)
34-34: Fix misleading assertion message.The message "Assertion Trigger !" is misleading since the condition
tid == tidis always true and will never trigger the assertion. Use a message that reflects the expected behavior.Apply this diff:
- T.device_assert(tid == tid, "Assertion Trigger !") + T.device_assert(tid == tid, "Should not trigger")
🧹 Nitpick comments (2)
testing/python/debug/test_device_assert.py (2)
10-21: Consider cleaning up environment variable for robustness.While this manual test function is clearly marked as not for CI, it would be more robust to restore
TILELANG_DEBUGin a finally block to avoid state leakage if this function is accidentally invoked.Apply this diff to add cleanup:
def _manual_device_assert_triggered(): - os.environ["TILELANG_DEBUG"] = "1" + prev = os.environ.get("TILELANG_DEBUG") + os.environ["TILELANG_DEBUG"] = "1" - @T.prim_func - def program(): - with T.Kernel(threads=128): - tid = T.get_thread_binding() - T.device_assert(tid > 0, "Assertion Trigger !") - - jit_kernel = tilelang.compile(program, target="cuda") - profiler = jit_kernel.get_profiler() - profiler.run_once() + try: + @T.prim_func + def program(): + with T.Kernel(threads=128): + tid = T.get_thread_binding() + T.device_assert(tid > 0, "Assertion Trigger !") + + jit_kernel = tilelang.compile(program, target="cuda") + profiler = jit_kernel.get_profiler() + profiler.run_once() + finally: + if prev is None: + os.environ.pop("TILELANG_DEBUG", None) + else: + os.environ["TILELANG_DEBUG"] = prev
24-43: Consider adding a test that verifies assertion failure handling.Currently, only the non-triggering case is tested. Consider adding a test that deliberately triggers an assertion and verifies the expected failure behavior (e.g., checking that an exception is raised or the kernel aborts as expected), guarded by a CUDA availability check.
Example structure:
def test_device_assert_trigger(): # Only run if CUDA is available and we can capture the failure if not is_cuda_available(): pytest.skip("CUDA not available") prev = os.environ.get("TILELANG_DEBUG") os.environ["TILELANG_DEBUG"] = "1" try: @T.prim_func def program(): with T.Kernel(threads=128): tid = T.get_thread_binding() T.device_assert(tid < 0, "Expected to trigger") jit_kernel = tilelang.compile(program, target="cuda") profiler = jit_kernel.get_profiler() # Expect this to raise an exception with pytest.raises(Exception): profiler.run_once() finally: if prev is None: os.environ.pop("TILELANG_DEBUG", None) else: os.environ["TILELANG_DEBUG"] = prev
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
testing/python/debug/test_device_assert.py(1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
testing/python/debug/test_device_assert.py (3)
tilelang/language/print.py (1)
device_assert(143-150)tilelang/jit/__init__.py (1)
compile(30-79)tilelang/jit/kernel.py (1)
get_profiler(367-383)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: Test for Python 3.12 with ROCm-6.3 (on self-hosted-amd)
- GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
- GitHub Check: Test for Python 3.12 with CUDA-12.8 (on self-hosted-nvidia)
🔇 Additional comments (1)
testing/python/debug/test_device_assert.py (1)
24-43: Verify CUDA availability before testing device assertions.The test doesn't check whether CUDA is available before running. Since
device_assertonly emits the actual device-side call when_IS_CUDA_AVAILABLEis true (from the implementation intilelang/language/print.py), this test silently skips the assertion logic if CUDA isn't available on the system, reducing test coverage.Consider adding a CUDA availability check (following the pattern used in other test files like
test_tilelang_language_parallel.py) or verifying that the compiled kernel source contains the expected assertion code.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 0
🧹 Nitpick comments (2)
testing/python/debug/test_device_assert.py (2)
10-21: Consider restoring the environment variable for consistency.While this is a manual test (not run in CI), the environment variable is not restored after use. For consistency and to prevent state leakage if run in a test suite, consider wrapping the logic in a try/finally block similar to
test_device_assert_no_trigger().Apply this diff to add cleanup:
def _manual_device_assert_triggered(): - os.environ["TILELANG_DEBUG"] = "1" + prev = os.environ.get("TILELANG_DEBUG") + os.environ["TILELANG_DEBUG"] = "1" - @T.prim_func - def program(): - with T.Kernel(threads=128): - tid = T.get_thread_binding() - T.device_assert(tid > 0, "Assertion Trigger !") - - jit_kernel = tilelang.compile(program, target="cuda") - profiler = jit_kernel.get_profiler() - profiler.run_once() + try: + @T.prim_func + def program(): + with T.Kernel(threads=128): + tid = T.get_thread_binding() + T.device_assert(tid > 0, "Assertion Trigger !") + + jit_kernel = tilelang.compile(program, target="cuda") + profiler = jit_kernel.get_profiler() + profiler.run_once() + finally: + if prev is None: + os.environ.pop("TILELANG_DEBUG", None) + else: + os.environ["TILELANG_DEBUG"] = prev
24-43: Consider expanding test coverage.The current test only covers the scenario where
TILELANG_DEBUG="1"and the assertion doesn't trigger. Consider adding tests for:
- Debug mode disabled: Verify that
device_assertis a no-op whenTILELANG_DEBUGis unset or "0"- Message parameter: Test the assertion with a non-empty message to exercise the
device_assert_with_msgpathExample tests:
def test_device_assert_debug_disabled(): """Verify device_assert is no-op when debug is disabled.""" if not tilelang.is_cuda_available(): pytest.skip("CUDA not available") prev = os.environ.get("TILELANG_DEBUG") os.environ["TILELANG_DEBUG"] = "0" try: @T.prim_func def program(): with T.Kernel(threads=128): tid = T.get_thread_binding() # Even a false condition should be no-op when debug is off T.device_assert(tid < 0, "This should not trigger") jit_kernel = tilelang.compile(program, target="cuda") profiler = jit_kernel.get_profiler() profiler.run_once() # Should succeed without assertion finally: if prev is None: os.environ.pop("TILELANG_DEBUG", None) else: os.environ["TILELANG_DEBUG"] = prev def test_device_assert_with_message(): """Verify device_assert works with a message.""" if not tilelang.is_cuda_available(): pytest.skip("CUDA not available") prev = os.environ.get("TILELANG_DEBUG") os.environ["TILELANG_DEBUG"] = "1" try: @T.prim_func def program(): with T.Kernel(threads=128): tid = T.get_thread_binding() T.device_assert(tid >= 0, "Thread ID should be non-negative") jit_kernel = tilelang.compile(program, target="cuda") profiler = jit_kernel.get_profiler() profiler.run_once() finally: if prev is None: os.environ.pop("TILELANG_DEBUG", None) else: os.environ["TILELANG_DEBUG"] = prev
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
src/tl_templates/cuda/debug.h(1 hunks)testing/python/debug/test_device_assert.py(1 hunks)tilelang/language/print.py(2 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
- src/tl_templates/cuda/debug.h
- tilelang/language/print.py
🧰 Additional context used
🧬 Code graph analysis (1)
testing/python/debug/test_device_assert.py (4)
tilelang/language/print.py (1)
device_assert(144-155)tilelang/jit/__init__.py (1)
compile(30-79)tilelang/jit/kernel.py (1)
get_profiler(367-383)tilelang/env.py (1)
get(175-178)
🔇 Additional comments (3)
testing/python/debug/test_device_assert.py (3)
1-6: LGTM!The imports are appropriate for testing the device_assert feature.
46-47: LGTM!The main guard follows the standard TileLang testing pattern.
24-43: Remove the review comment — this test is intentionally designed to be run manually, not in CI.The file already documents this with a TODO comment: "It intentionally triggers a device-side assert so we can't include this in CI. Please run manually when you want to verify that device_assert actually traps on GPU."
Additionally, the review suggestion references
tilelang.is_cuda_available(), which doesn't exist as a public API. The correct pattern (used elsewhere) would betorch.cuda.is_available(), but this is unnecessary here since the test is not meant for automated CI runs.Likely an incorrect or invalid review comment.
src/tl_templates/cuda/debug.h
Outdated
| threadIdx.z, buf_name, index, (int32_t)var); | ||
| } | ||
|
|
||
| __device__ void device_assert(bool cond) { assert(cond); } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why don't we use TL_DEVICE?
| if prev is None: | ||
| os.environ.pop("TILELANG_DEBUG", None) | ||
| else: | ||
| os.environ["TILELANG_DEBUG"] = prev |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why we need TILELANG_DEBUG and likely I didn't find any configs to exclude this file in ci?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
♻️ Duplicate comments (1)
tilelang/language/__init__.py (1)
67-67: Remove the unusednoqadirective.Static analysis correctly flags this
noqadirective as unnecessary. The import is valid without it.Apply this diff:
-from .print import print, device_assert # noqa: F401 +from .print import print, device_assert
🧹 Nitpick comments (1)
testing/python/debug/test_device_assert.py (1)
22-32: Add CUDA availability guard to prevent test failures on non-CUDA systems.The test will fail on systems without CUDA. Add a skip decorator to make the test conditional on CUDA availability.
Add this import at the top of the file:
import tilelang import tilelang.testing import tilelang.language as T +import pytest +from tilelang.utils.target import check_cuda_availabilityThen add a decorator to the test:
+@pytest.mark.skipif(not check_cuda_availability(), reason="CUDA not available") def test_device_assert_no_trigger():
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
src/tl_templates/cuda/debug.h(1 hunks)testing/python/debug/test_device_assert.py(1 hunks)tilelang/language/__init__.py(1 hunks)tilelang/language/print.py(2 hunks)
🧰 Additional context used
🧬 Code graph analysis (4)
tilelang/language/print.py (2)
tilelang/utils/target.py (1)
check_cuda_availability(28-38)tilelang/language/tir/op.py (1)
call_extern(173-195)
src/tl_templates/cuda/debug.h (1)
tilelang/language/print.py (1)
device_assert(143-153)
testing/python/debug/test_device_assert.py (3)
tilelang/language/print.py (1)
device_assert(143-153)tilelang/jit/__init__.py (1)
compile(30-79)tilelang/jit/kernel.py (1)
get_profiler(367-383)
tilelang/language/__init__.py (1)
tilelang/language/print.py (1)
device_assert(143-153)
🪛 Ruff (0.14.1)
tilelang/language/__init__.py
67-67: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
🔇 Additional comments (3)
src/tl_templates/cuda/debug.h (1)
261-268: LGTM! Device assertion helpers are correctly implemented.The implementation is clean and consistent with the file's existing patterns. Both functions use the
TL_DEVICEmacro appropriately, and the logic is straightforward:device_assertwraps the nativeassert, whiledevice_assert_with_msgprints a diagnostic message before triggering the assertion.tilelang/language/print.py (1)
136-141: LGTM! Follows established availability-check pattern.The imports and module-level
_IS_CUDA_AVAILABLEflag follow the same pattern used elsewhere in the codebase (e.g.,_IS_HIP_AVAILABLEinbuiltin.py).testing/python/debug/test_device_assert.py (1)
1-20: LGTM! Manual test function is appropriately documented.The
_manual_device_assert_triggered()function is correctly marked for manual testing only with a clear TODO comment explaining why it can't be included in CI. The underscore prefix prevents pytest from collecting it automatically.
* update * update
Summary by CodeRabbit
New Features
Tests
Documentation