4 changes: 3 additions & 1 deletion src/transform/inject_assumes.cc
@@ -6,6 +6,7 @@
#include "tvm/node/structural_hash.h"
#include "tvm/tir/builtin.h"
#include "tvm/tir/expr.h"
#include "tvm/tir/op.h"
#include "tvm/tir/stmt.h"
#include "tvm/tir/stmt_functor.h"
#include "tvm/tir/transform.h"
@@ -62,7 +63,8 @@ class AssumeInjector : public tvm::tir::StmtExprMutator {
   Stmt build(Stmt body) {
     auto analyzer = arith::Analyzer{};
     for (const auto &e : items) {
-      auto simplified = analyzer.Simplify(GT(e.expr, 0));
+      auto simplified =
+          analyzer.Simplify(GT(e.expr, make_zero(e.expr->dtype)));
       std::stringstream ss;
       ss << "Buffer shape should be greater than 0: shape `" << e.expr
          << "` from buffer ";
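For context on this change, a minimal illustrative sketch (not part of the patch): in TIR an untyped literal 0 defaults to int32, while a shape from `T.symbolic("n", "int64")` is int64, so the guard is now built against a zero of the expression's own dtype. The snippet below mirrors the idea through TVM's Python API; `tir.IntImm(n.dtype, 0)` plays roughly the role of `make_zero(e.expr->dtype)`.

```python
from tvm import tir

# A symbolic int64 extent, like the one produced by T.symbolic("n", "int64").
n = tir.Var("n", "int64")

# Zero constructed with the expression's own dtype, so the comparison is
# int64 > int64 throughout instead of mixing in an int32 literal.
zero = tir.IntImm(n.dtype, 0)
guard = tir.GT(n, zero)
print(guard.a.dtype, guard.b.dtype)  # int64 int64
```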
66 changes: 66 additions & 0 deletions testing/python/language/test_tilelang_language_int64.py
@@ -0,0 +1,66 @@
import tilelang
import tilelang.language as T


@tilelang.jit
def fill_symbolic(value: float, dtype="bfloat16"):
    n = T.symbolic("n", "int64")
    block_n = 512

    @T.prim_func
    def main(x: T.Tensor[n, dtype]):
        # Initialize Kernel Context
        with T.Kernel(T.ceildiv(n, block_n), threads=128) as bx:
            # Doesn't yet work with int64-shaped global tensor
            # T.fill(x[bx * block_n : (bx + 1) * block_n], value)
            for i in T.Parallel(block_n):
                x[bx * block_n + i] = value

    return main


def run_fill_symbolic(n: int):
    import torch

    x = torch.zeros(n, dtype=torch.bfloat16, device="cuda")
    fill_symbolic(1.0)(x)
    assert x.min() == 1.0 and x.max() == 1.0


def test_fill_symbolic():
    # Requires 8GB VRAM
    run_fill_symbolic(2**32)


@tilelang.jit
def fill_static(n: int, value: float, dtype="bfloat16"):
    block_n = 512

    @T.prim_func
    def main(x: T.Tensor[n, dtype]):
        # Initialize Kernel Context
        with T.Kernel(T.ceildiv(n, block_n), threads=128) as bx:
            # Doesn't yet work with int64-shaped global tensor
            # T.fill(x[bx * block_n : (bx + 1) * block_n], value)
            for i in T.Parallel(block_n):
                x[bx * block_n + i] = value

    return main


def run_fill_static(n: int):
    import torch

    x = torch.zeros(n, dtype=torch.bfloat16, device="cuda")
    fill_static(n, 1.0)(x)
    assert x.min() == 1.0 and x.max() == 1.0


def test_fill_static():
    # Requires 8GB VRAM
    run_fill_static(2**32)
Comment on lines +22 to +61
Contributor

⚠️ Potential issue | 🔴 Critical

Prevent OOM and missing-CUDA crashes in the int64 fill tests

run_fill_symbolic / run_fill_static unconditionally create a torch.zeros buffer on CUDA, and the tests invoke them with n = 2**32. On hosts without a CUDA-capable GPU this raises immediately because torch.cuda.is_available() is false, and even when a GPU exists the allocation needs 8 GiB for the bfloat16 data, which will throw torch.cuda.OutOfMemoryError on the default CI machines long before the kernel is exercised.

Please gate these helpers on CUDA availability, and skip when the requested tensor cannot fit in the active device's memory, before attempting the allocation. One option is:

@@
-import tilelang
-import tilelang.language as T
+import pytest
+import torch
+import tilelang
+import tilelang.language as T
@@
-def run_fill_symbolic(n: int):
-    import torch
-
-    x = torch.zeros(n, dtype=torch.bfloat16, device="cuda")
+def run_fill_symbolic(n: int):
+    if not torch.cuda.is_available():
+        pytest.skip("CUDA device required for int64 fill tests")
+    elem_bytes = torch.tensor([], dtype=torch.bfloat16).element_size()
+    total_mem = torch.cuda.get_device_properties(0).total_memory
+    if n * elem_bytes > total_mem:
+        pytest.skip(
+            f"Requires ~{n * elem_bytes / (1 << 30):.1f} GiB,"
+            f" but only {total_mem / (1 << 30):.1f} GiB available"
+        )
+    device = torch.device("cuda")
+    x = torch.zeros(n, dtype=torch.bfloat16, device=device)
@@
-def run_fill_static(n: int):
-    import torch
-
-    x = torch.zeros(n, dtype=torch.bfloat16, device="cuda")
+def run_fill_static(n: int):
+    if not torch.cuda.is_available():
+        pytest.skip("CUDA device required for int64 fill tests")
+    elem_bytes = torch.tensor([], dtype=torch.bfloat16).element_size()
+    total_mem = torch.cuda.get_device_properties(0).total_memory
+    if n * elem_bytes > total_mem:
+        pytest.skip(
+            f"Requires ~{n * elem_bytes / (1 << 30):.1f} GiB,"
+            f" but only {total_mem / (1 << 30):.1f} GiB available"
+        )
+    device = torch.device("cuda")
+    x = torch.zeros(n, dtype=torch.bfloat16, device=device)

This keeps the int64 coverage when the hardware can handle it and lets the suite pass everywhere else.

🤖 Prompt for AI Agents
testing/python/language/test_tilelang_language_int64.py lines 22-61: The helpers
unconditionally allocate a large CUDA tensor (n=2**32) which fails on machines
without CUDA or without enough free VRAM; modify run_fill_symbolic and
run_fill_static to first check torch.cuda.is_available() and skip (or return)
when CUDA is absent, then compute required_bytes = n * element_size (bfloat16 ->
2) and query the device free memory (torch.cuda.mem_get_info or
torch.cuda.get_device_properties/free mem API) and skip when required_bytes >
free_bytes; perform the allocation only after these checks and use pytest.skip
with a clear message so tests are gated safely.



if __name__ == "__main__":
    test_fill_symbolic()
    test_fill_static()
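As a complement to the review suggestion above, here is a minimal sketch of gating on free (rather than total) device memory via torch.cuda.mem_get_info, which the agent prompt also mentions. The helper name and threshold handling are illustrative assumptions, not code from this PR.

```python
import pytest
import torch


def skip_unless_cuda_can_hold(n_elems: int, dtype=torch.bfloat16) -> None:
    # Illustrative helper (not part of this PR): skip when CUDA is missing or
    # the requested tensor would not fit in the device's currently free memory.
    if not torch.cuda.is_available():
        pytest.skip("CUDA device required for int64 fill tests")
    required = n_elems * torch.tensor([], dtype=dtype).element_size()
    free_bytes, _total_bytes = torch.cuda.mem_get_info()
    if required > free_bytes:
        pytest.skip(f"Needs ~{required / (1 << 30):.1f} GiB, "
                    f"only {free_bytes / (1 << 30):.1f} GiB free")
```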
4 changes: 2 additions & 2 deletions tilelang/jit/adapter/cython/cython_wrapper.pyx
@@ -267,9 +267,9 @@ cdef class CythonKernelWrapper:
         # Add dynamic dimension values to kernel arguments
         for _, (ref_id, buffer_idx, shape_idx) in self.dynamic_symbolic_map.items():
             if ref_id == 0:
-                call_args.append(tensor_list[buffer_idx].shape[shape_idx])
+                call_args.append(ctypes.c_int64(tensor_list[buffer_idx].shape[shape_idx]))
             else:
-                call_args.append(tensor_list[buffer_idx].stride(shape_idx))
+                call_args.append(ctypes.c_int64(tensor_list[buffer_idx].stride(shape_idx)))

         # Add CUDA stream to kernel arguments
         call_args.append(ctypes.c_void_p(stream))
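A quick check of what the ctypes.c_int64 wrapping buys here: a 2**32-element extent (the size used by the new test) is carried through the FFI call without being squeezed into a 32-bit slot. This is an illustration, not wrapper code.

```python
import ctypes

n = 2**32  # extent used by test_tilelang_language_int64.py

# c_int64 preserves the full 64-bit extent, which is what the wrapper now
# passes for dynamic shapes and strides.
arg = ctypes.c_int64(n)
assert arg.value == n
print(arg.value)  # 4294967296
```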
4 changes: 2 additions & 2 deletions tilelang/jit/adapter/nvrtc/wrapper.py
@@ -313,9 +313,9 @@ def create_dispatch_func(self, code, function_informations):
                 raise ValueError(
                     f"Parameter {param} is not in the buffer map of the primary function.")
         # Add dynamic symbols as integer arguments
-        for dyn_sym in dynamic_symbolic_set:
+        for dyn_sym, dyn_sym_dtype in dynamic_symbolic_set:
             if dyn_sym not in [arg["name"] for arg in function_args]:
-                function_args.append({"name": dyn_sym, "type": "ctypes.c_int"})
+                function_args.append({"name": dyn_sym, "type": self._lookup_type(dyn_sym_dtype)})

         function_args.append(self.get_stream_type())

28 changes: 15 additions & 13 deletions tilelang/jit/adapter/wrapper.py
@@ -220,9 +220,9 @@ def create_dispatch_func(self, code, function_informations):
                 raise ValueError(
                     f"Parameter {param} is not in the buffer map of the primary function.")
         # Add dynamic symbols as integer arguments
-        for dyn_sym in dynamic_symbolic_set:
+        for dyn_sym, dyn_sym_dtype in dynamic_symbolic_set:
             if dyn_sym not in [arg["name"] for arg in function_args]:
-                function_args.append({"name": dyn_sym, "type": "int"})
+                function_args.append({"name": dyn_sym, "type": self._lookup_type(dyn_sym_dtype)})

         function_args.append(self.get_stream_type())
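For readers following the _lookup_type call: its implementation is not shown in this diff, so the mapping below is only a hypothetical sketch of the kind of dtype-to-C-type lookup such a helper might perform when building the dispatch signature. The table contents and function name are assumptions, not the library's actual API.

```python
# Hypothetical TVM-dtype -> C parameter type table; the real _lookup_type in
# tilelang may differ. Shown only to make the dyn_sym_dtype flow concrete.
_C_TYPE_BY_TVM_DTYPE = {
    "int32": "int",
    "int64": "int64_t",
    "uint32": "unsigned int",
    "uint64": "uint64_t",
}


def lookup_c_type(dtype: str) -> str:
    # Fall back to "int", the type the wrapper emitted before this change.
    return _C_TYPE_BY_TVM_DTYPE.get(dtype, "int")


print(lookup_c_type("int64"))  # int64_t
```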

@@ -405,28 +405,30 @@ def parse_source_information(self):

     def get_dynamic_symbolic_set(self, prim_func):
         # Determine the set of dynamic symbols used in the function
-        dynamic_symbolic_set: list[str] = []
+        dynamic_symbolic_set: dict[str, str] = {}

-        def unique_push_back(name: str):
+        def unique_push_back(name: str, dtype: str):
             if name not in dynamic_symbolic_set:
-                dynamic_symbolic_set.append(name)
+                dynamic_symbolic_set[name] = dtype
+            else:
+                assert dtype == dynamic_symbolic_set[name]

         for param in prim_func.params:
             if param in prim_func.buffer_map:
                 buffer = prim_func.buffer_map[param]
                 for dim in buffer.shape:
                     if isinstance(dim, tvm.tir.Var):
-                        unique_push_back(dim.name)
+                        unique_push_back(dim.name, str(dim.dtype))

         # Note: In buffer definitions, any dynamic symbols appearing in strides are listed after those in the shape.
         for param in prim_func.params:
             if param in prim_func.buffer_map:
                 buffer = prim_func.buffer_map[param]
                 for stride in buffer.strides:
                     if isinstance(stride, tvm.tir.Var):
-                        unique_push_back(stride.name)
+                        unique_push_back(stride.name, str(stride.dtype))

-        return dynamic_symbolic_set
+        return list(dynamic_symbolic_set.items())

     def get_init_func(self):
         # Initialize an empty string for the CUDA function call
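A plain-Python rehearsal of the bookkeeping above: the first occurrence of a dynamic symbol records its dtype, later occurrences must agree, and the result is handed back as (name, dtype) pairs. The symbol name "n" is taken from the new test; the rest is illustrative.

```python
# Standalone sketch of the dedup-with-dtype-check logic used for dynamic symbols.
dynamic_symbolic_set: dict[str, str] = {}


def unique_push_back(name: str, dtype: str) -> None:
    if name not in dynamic_symbolic_set:
        dynamic_symbolic_set[name] = dtype
    else:
        # A symbol reused with a different dtype would be flagged here.
        assert dtype == dynamic_symbolic_set[name]


unique_push_back("n", "int64")  # from a buffer shape
unique_push_back("n", "int64")  # same symbol again (e.g. in a stride): no-op
print(list(dynamic_symbolic_set.items()))  # [('n', 'int64')]
```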
@@ -665,8 +667,8 @@ def create_call_func(self, code, function_informations):
                 raise ValueError(
                     f"Parameter {param} is not in the buffer map of the primary function.")
         # Add dynamic symbols as integer arguments
-        for dyn_sym in dynamic_symbolic_set:
-            function_args.append({"name": dyn_sym, "type": "int"})
+        for dyn_sym, dyn_sym_dtype in dynamic_symbolic_set:
+            function_args.append({"name": dyn_sym, "type": self._lookup_type(dyn_sym_dtype)})
         # Format the function arguments for declaration
         def_args = ", ".join([f"{arg['type']} {arg['name']}" for arg in function_args])

@@ -715,14 +717,14 @@ def parse_source_information(self):

     def get_dynamic_symbolic_set(self, prim_func):
         # Determine the set of dynamic symbols used in the function
-        dynamic_symbolic_set: list[str] = []
+        dynamic_symbolic_set: dict[str, str] = {}
         for param in prim_func.params:
             if param in prim_func.buffer_map:
                 buffer = prim_func.buffer_map[param]
                 for dim in buffer.shape:
                     if isinstance(dim, tvm.tir.Var) and (dim.name not in dynamic_symbolic_set):
-                        dynamic_symbolic_set.append(dim.name)
-        return dynamic_symbolic_set
+                        dynamic_symbolic_set[dim.name] = str(dim.dtype)
+        return list(dynamic_symbolic_set.items())

     def get_cpu_init_func(self):
         # Provide init() and get_last_error() for CPU backend