Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions testing/python/language/test_tilelang_language_frontend_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,5 +342,23 @@ def swap_idx(A: T.Tensor[(2,), T.float32]):
torch.testing.assert_close(data, ref)


def test_while_loop():

@tilelang.jit(out_idx=-1)
@T.prim_func
def test_while_loop(A: T.Tensor((1,), T.int32)):
with T.Kernel(1) as _:
i = T.alloc_var(T.int32, 0)
sum = T.alloc_var(T.int32)
while i < 10:
sum += i
i += 1
A[0] = sum

ker = test_while_loop()
A = ker()
assert A[0].item() == sum(range(10)), f"Expected {sum(range(10))}, but got {A[0].item()}"


if __name__ == '__main__':
tilelang.testing.main()
106 changes: 106 additions & 0 deletions tilelang/language/tir/ir.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
from typing import TypeVar, Literal
from tvm.tir.expr import Span, PrimExpr, BufferLoad, Var, IntImm

_T = TypeVar('_T')

def abs(x: _T, span: Span | None=None) -> _T: ...
def acos(x: _T) -> _T: ...
def acosh(x: _T) -> _T: ...
def address_of(buffer_load: BufferLoad, span: Span | None=None) -> PrimExpr: ...
def asin(x: _T) -> _T: ...
def asinh(x: _T) -> _T: ...
def atan(x: _T) -> _T: ...
def atan2(x1: _T, x2: _T) -> _T: ...
def atanh(x: _T) -> _T: ...
def bitwise_and(x: _T, y: _T, span: Span | None=None) -> _T: ...
def bitwise_not(x: _T, span: Span | None=None) -> _T: ...
def bitwise_or(x: _T, y: _T, span: Span | None=None) -> _T: ...
def bitwise_xor(x: _T, y: _T, span: Span | None=None) -> _T: ...
def ceil(x: _T, span: Span | None=None) -> _T: ...
def clz(x: _T) -> _T: ...
def copysign(x1: _T, x2: _T) -> _T: ...
def cos(x: _T) -> _T: ...
def cosh(x: _T) -> _T: ...
def erf(x: _T) -> _T: ...
def exp(x: _T) -> _T: ...
def exp2(x: _T) -> _T: ...
def exp10(x: _T) -> _T: ...
def floor(x: _T, span: Span | None=None) -> _T: ...
def ceildiv(lhs: _T, rhs: _T, span: Span | None=None) -> _T: ...
def floordiv(a: _T, b: _T, span: Span | None=None) -> _T: ...
def floormod(a: _T, b: _T, span: Span | None=None) -> _T: ...
def fmod(x: _T, y: _T) -> _T: ...
def hypot(x1: _T, x2: _T) -> _T: ...
def if_then_else(cond: PrimExpr, t: _T, f: _T, span: Span | None=None) -> _T: ...
def infinity(dtype: _T, span: Span | None=None) -> _T: ...
def isfinite(x: _T, span: Span | None=None) -> _T: ...
def isinf(x: _T, span: Span | None=None) -> _T: ...
def isnan(x: _T, span: Span | None=None) -> _T: ...
def isnullptr(x: _T, span: Span | None=None) -> _T: ...
def ldexp(x1: _T, x2: _T) -> _T: ...
def likely(cond: _T, span: Span | None=None) -> _T: ...
def log(x: _T) -> _T: ...
def log1p(x: _T) -> _T: ...
def log2(x: _T) -> _T: ...
def log10(x: _T) -> _T: ...
def lookup_param(param_name: str, span: Span | None=None) -> PrimExpr: ...
def max_value(dtype: str, span: Span | None=None) -> PrimExpr: ...
def min_value(dtype: str, span: Span | None=None) -> PrimExpr: ...
def nearbyint(x: _T, span: Span | None=None) -> _T: ...
def nextafter(x1: _T, x2: _T) -> _T: ...
def popcount(x: _T) -> _T: ...
def pow(x: _T, y: _T, span: Span | None=None) -> _T: ...
def q_multiply_shift(x: _T, y: _T, q: _T, s: _T) -> _T: ...
def q_multiply_shift_per_axis(x: _T, y: _T, ls: _T, rs: _T, q: IntImm, is_lshift_required: IntImm, is_rshift_required: IntImm) -> PrimExpr: ...
def ret(val: _T) -> _T: ...
def round(x: _T, span: Span | None=None) -> _T: ...
def rsqrt(x: _T) -> _T: ...
def shift_left(x: _T, y: _T, span=None) -> _T: ...
def shift_right(x: _T, y: _T, span=None) -> _T: ...
def sigmoid(x: _T) -> _T: ...
def sin(x: _T) -> _T: ...
def sinh(x: _T) -> _T: ...
def sqrt(x: _T) -> _T: ...
def tan(x: _T) -> _T: ...
def tanh(x: _T) -> _T: ...
def trunc(x: _T, span: Span | None=None) -> _T: ...
def truncdiv(a: _T, b: _T, span: Span | None=None) -> _T: ...
def truncmod(a: _T, b: _T, span: Span | None=None) -> _T: ...
def tvm_access_ptr(ptype: PrimExpr, data, offset: int, extent: int, rw_mask: int) -> PrimExpr: ...
def tvm_throw_last_error() -> _T: ...
def tvm_stack_alloca(dtype_str: str, num: int) -> PrimExpr: ...
def tvm_stack_make_shape(*args) -> _T: ...
def tvm_stack_make_array(data: PrimExpr, shape: PrimExpr, strides: PrimExpr, ndim: PrimExpr, arr_dtype: PrimExpr, elem_offset) -> PrimExpr: ...
def tvm_check_return(expected: int, return_unexpected: int, nested_call: PrimExpr) -> PrimExpr: ...
def call_packed(*args, span=None) -> _T: ...
def call_cpacked(*args, span=None) -> _T: ...
def call_packed_lowered(*args, span=None) -> _T: ...
def call_cpacked_lowered(*args, span=None) -> _T: ...
def tvm_tuple(*value) -> _T: ...
def tvm_struct_set(arr, index: int, field: int, value: PrimExpr) -> PrimExpr: ...
def tvm_thread_invariant(cond: _T) -> _T: ...
def tvm_thread_allreduce(*freduce_args) -> _T: ...
def tvm_load_matrix_sync(fragment: Var, m: IntImm, n: IntImm, k: IntImm, index: PrimExpr, buffer_ptr: PrimExpr, stride: PrimExpr, layout: Literal['row_major', 'column_major']) -> PrimExpr: ...
def tvm_mma_sync(fragment_d: Var, index_d: PrimExpr, fragment_a: Var, index_a: PrimExpr, fragment_b: Var, index_b: PrimExpr, fragment_c: Var, index_c: PrimExpr) -> PrimExpr: ...
def tvm_bmma_sync(fragment_d: Var, index_d: PrimExpr, fragment_a: Var, index_a: PrimExpr, fragment_b: Var, index_b: PrimExpr, fragment_c: Var, index_c: PrimExpr) -> PrimExpr: ...
def tvm_fill_fragment(fragment: Var, m: IntImm, n: IntImm, k: IntImm, index: PrimExpr, value: PrimExpr) -> PrimExpr: ...
def tvm_store_matrix_sync(fragment: Var, m: IntImm, n: IntImm, k: IntImm, index: PrimExpr, buffer_ptr: PrimExpr, stride: PrimExpr, layout: Literal['row_major', 'column_major']) -> PrimExpr: ...
def ptx_wait_group(num: int) -> PrimExpr: ...
def ptx_commit_group() -> _T: ...
def ptx_cp_async_barrier(barrier_id: int) -> PrimExpr: ...
def ptx_init_barrier_thread_count(barrier_id: int, thread_count: int) -> PrimExpr: ...
def ptx_arrive_barrier(barrier_id: int) -> PrimExpr: ...
def ptx_arrive_barrier_expect_tx(barrier_id: int, byte_count: int) -> PrimExpr: ...
def ptx_wait_barrier(barrier_id: int) -> PrimExpr: ...
def create_barriers(barrier_count: int) -> PrimExpr: ...
def assume(cond: _T=None) -> _T: ...
def undef() -> _T: ...
def TVMBackendAllocWorkspace(device_type: int, device_id: int, nbytes: int, dtype_code_hint: int, dtype_bits_hint: int) -> PrimExpr: ...
def TVMBackendFreeWorkspace(device_type: int, device_id: int, ptr: Var) -> PrimExpr: ...
def start_profile_intrinsic(id: int) -> PrimExpr: ...
def end_profile_intrinsic(id: int) -> PrimExpr: ...
def anylist_getitem(list_handle, index) -> PrimExpr: ...
def anylist_resetitem(list_handle, index) -> PrimExpr: ...
def anylist_setitem_call_packed(list_handle, index, func_name, *args) -> PrimExpr: ...
def anylist_setitem_call_cpacked(list_handle, index, func_name, *args) -> PrimExpr: ...
def vscale() -> _T: ...
1 change: 1 addition & 0 deletions tilelang/language/v2/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,6 +469,7 @@ def visit_AnnAssign(self, node: ast.AnnAssign):
return self._emit_assign_target(node.target, rval, annot=node.annotation)

def visit_While(self, node):
node = self.generic_visit(node)
return quote1(
"for _ in __tb.ctx_while(lambda: cond):\n pass",
cond=node.test,
Expand Down
17 changes: 16 additions & 1 deletion tilelang/language/v2/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,22 @@ def ctx_break(self):

def ctx_while(self, cond):
self.check_continue_break()
raise RuntimeError("while loops are not supported in TileLang builder")
cond_v = cond()
cond_v_unwrap = unwrap_cond(cond_v)
if not isinstance(cond_v_unwrap, PrimExpr):
if cond_v_unwrap:
raise RuntimeError(
f'Infinite while loop detected in TileLang\n'
f'Condition: {cond_v}({type(cond_v)}) => {cond_v_unwrap}({type(cond_v_unwrap)})\n'
)
else:
logger.warning(
'While loop with constant false condition detected in Tilelang, the loop body will never be executed.\n',
f'Condition: {cond_v}({type(cond_v)}) => {cond_v_unwrap}({type(cond_v_unwrap)})\n',
stack_info=True,
stacklevel=2)
with self.with_frame(tir.While(cond_v_unwrap)):
yield None

Comment on lines 293 to 311
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

Fix logging call in ctx_while to avoid runtime errors

The new ctx_while logic overall makes sense (evaluate cond() once to obtain a PrimExpr, detect obviously infinite loops, and then build a tir.While frame), but the warning branch for constant‑false conditions has a subtle bug:

logger.warning(
    'While loop with constant false condition detected in Tilelang, the loop body will never be executed.\n',
    f'Condition: {cond_v}({type(cond_v)}) => {cond_v_unwrap}({type(cond_v_unwrap)})\n',
    stack_info=True,
    stacklevel=2)

logging treats the second positional argument as *args for msg % args formatting. Since the format string contains no % placeholders, this will raise TypeError: not all arguments converted during string formatting when this branch is hit.

You can fix this by folding the condition detail into the message string and passing only keyword arguments:

-        else:
-                logger.warning(
-                    'While loop with constant false condition detected in Tilelang, the loop body will never be executed.\n',
-                    f'Condition: {cond_v}({type(cond_v)}) => {cond_v_unwrap}({type(cond_v_unwrap)})\n',
-                    stack_info=True,
-                    stacklevel=2)
+        else:
+            logger.warning(
+                'While loop with constant false condition detected in TileLang; '
+                f'the loop body will never be executed.\n'
+                f'Condition: {cond_v}({type(cond_v)}) => {cond_v_unwrap}({type(cond_v_unwrap)})\n',
+                stack_info=True,
+                stacklevel=2,
+            )

Everything else in ctx_while looks consistent with the existing builder design (use of unwrap_cond, check_continue_break, and tir.While).


I'll verify the review comment by examining the actual code and confirming the logging bug.

# Verify Python logging behavior with extra positional args
import logging

# Test the exact scenario from the review
test_logger = logging.getLogger("test")
handler = logging.StreamHandler()
test_logger.addHandler(handler)
test_logger.setLevel(logging.WARNING)

# This should demonstrate the bug
try:
    test_logger.warning(
        'First message\n',
        f'Second message\n',
        stack_info=True,
        stacklevel=2
    )
    print("No error occurred - bug description may be incorrect")
except TypeError as e:
    print(f"TypeError caught as expected: {e}")

🏁 Script executed:

#!/bin/bash
# Extract the ctx_while method to verify the exact code
fd -t f "builder.py" | head -1 | xargs -I {} sed -n '293,311p' {}

Length of output: 870


🏁 Script executed:

#!/bin/bash
# Search for ctx_while in the file to understand full context
fd -t f "builder.py" | head -1 | xargs -I {} grep -n "def ctx_while" {}

Length of output: 34


🏁 Script executed:

#!/bin/bash
# Get a broader view of the method and surrounding code
fd -t f "builder.py" | head -1 | xargs -I {} sed -n '290,315p' {}

Length of output: 1129


I need to verify the Python logging behavior. Let me run the test:

#!/usr/bin/env python3
# Verify Python logging behavior with extra positional args
import logging
import sys

# Set up logging to see the actual error
logging.basicConfig(level=logging.WARNING, stream=sys.stdout)
test_logger = logging.getLogger("test")

# Simulate the exact scenario from the code
cond_v = "some_condition"
cond_v_unwrap = False

try:
    test_logger.warning(
        'While loop with constant false condition detected in Tilelang, the loop body will never be executed.\n',
        f'Condition: {cond_v}({type(cond_v)}) => {cond_v_unwrap}({type(cond_v_unwrap)})\n',
        stack_info=True,
        stacklevel=2
    )
    print("SUCCESS: No error occurred - the logging call worked")
except TypeError as e:
    print(f"CONFIRMED BUG: TypeError caught: {e}")
except Exception as e:
    print(f"OTHER ERROR: {type(e).__name__}: {e}")

Fix logging call in ctx_while to avoid runtime errors

The code has been verified and the issue is confirmed. The logger.warning() call passes two positional string arguments:

logger.warning(
    'While loop with constant false condition detected in Tilelang, the loop body will never be executed.\n',
    f'Condition: {cond_v}({type(cond_v)}) => {cond_v_unwrap}({type(cond_v_unwrap)})\n',
    stack_info=True,
    stacklevel=2)

Python's logging module treats extra positional arguments as formatting arguments (via msg % args). Since the first message string contains no % placeholders, this will raise TypeError: not all arguments converted during string formatting when the constant-false condition branch is executed.

The suggested fix combines the messages into a single string and uses only keyword arguments:

            else:
-                logger.warning(
-                    'While loop with constant false condition detected in Tilelang, the loop body will never be executed.\n',
-                    f'Condition: {cond_v}({type(cond_v)}) => {cond_v_unwrap}({type(cond_v_unwrap)})\n',
-                    stack_info=True,
-                    stacklevel=2)
+            logger.warning(
+                'While loop with constant false condition detected in TileLang; '
+                f'the loop body will never be executed.\n'
+                f'Condition: {cond_v}({type(cond_v)}) => {cond_v_unwrap}({type(cond_v_unwrap)})\n',
+                stack_info=True,
+                stacklevel=2,
+            )
🧰 Tools
🪛 Ruff (0.14.4)

299-302: Avoid specifying long messages outside the exception class

(TRY003)

🤖 Prompt for AI Agents
In tilelang/language/v2/builder.py around lines 293 to 311, the logger.warning
call passes two positional string arguments which the logging module interprets
as format args and can raise a TypeError; combine the two message parts into a
single formatted string (e.g. one f-string that includes the condition and
types) and call logger.warning with that single message and the existing keyword
arguments stack_info=True and stacklevel=2 so no extra positional args are
passed.

def bind(self, name, value, annot=BaseBuilder.empty):
self.check_continue_break()
Expand Down
Loading