Skip to content

Commit 49c8571

Browse files
Fix various issues under int64_t static and dynamic shape. (#1218)
* Fix various issues under int64_t static and dynamic shape. * Resolve reviewed issues. * Add unit test. * fix --------- Co-authored-by: LeiWang1999 <leiwang1999@outlook.com>
1 parent e805f8e commit 49c8571

File tree

5 files changed

+88
-18
lines changed

5 files changed

+88
-18
lines changed

src/transform/inject_assumes.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include "tvm/node/structural_hash.h"
77
#include "tvm/tir/builtin.h"
88
#include "tvm/tir/expr.h"
9+
#include "tvm/tir/op.h"
910
#include "tvm/tir/stmt.h"
1011
#include "tvm/tir/stmt_functor.h"
1112
#include "tvm/tir/transform.h"
@@ -62,7 +63,8 @@ class AssumeInjector : public tvm::tir::StmtExprMutator {
6263
Stmt build(Stmt body) {
6364
auto analyzer = arith::Analyzer{};
6465
for (const auto &e : items) {
65-
auto simplified = analyzer.Simplify(GT(e.expr, 0));
66+
auto simplified =
67+
analyzer.Simplify(GT(e.expr, make_zero(e.expr->dtype)));
6668
std::stringstream ss;
6769
ss << "Buffer shape should be greater than 0: shape `" << e.expr
6870
<< "` from buffer ";
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
import tilelang
2+
import tilelang.language as T
3+
4+
5+
@tilelang.jit
6+
def fill_symbolic(value: float, dtype="bfloat16"):
7+
n = T.symbolic("n", "int64")
8+
block_n = 512
9+
10+
@T.prim_func
11+
def main(x: T.Tensor[n, dtype]):
12+
# Initialize Kernel Context
13+
with T.Kernel(T.ceildiv(n, block_n), threads=128) as bx:
14+
# Doesn't yet work with int64-shaped global tensor
15+
# T.fill(x[bx * block_n : (bx + 1) * block_n], value)
16+
for i in T.Parallel(block_n):
17+
x[bx * block_n + i] = value
18+
19+
return main
20+
21+
22+
def run_fill_symbolic(n: int):
23+
import torch
24+
25+
x = torch.zeros(n, dtype=torch.bfloat16, device="cuda")
26+
fill_symbolic(1.0)(x)
27+
assert x.min() == 1.0 and x.max() == 1.0
28+
29+
30+
def test_fill_symbolic():
31+
# Requires 8GB VRAM
32+
run_fill_symbolic(2**32)
33+
34+
35+
@tilelang.jit
36+
def fill_static(n: int, value: float, dtype="bfloat16"):
37+
block_n = 512
38+
39+
@T.prim_func
40+
def main(x: T.Tensor[n, dtype]):
41+
# Initialize Kernel Context
42+
with T.Kernel(T.ceildiv(n, block_n), threads=128) as bx:
43+
# Doesn't yet work with int64-shaped global tensor
44+
# T.fill(x[bx * block_n : (bx + 1) * block_n], value)
45+
for i in T.Parallel(block_n):
46+
x[bx * block_n + i] = value
47+
48+
return main
49+
50+
51+
def run_fill_static(n: int):
52+
import torch
53+
54+
x = torch.zeros(n, dtype=torch.bfloat16, device="cuda")
55+
fill_static(n, 1.0)(x)
56+
assert x.min() == 1.0 and x.max() == 1.0
57+
58+
59+
def test_fill_static():
60+
# Requires 8GB VRAM
61+
run_fill_static(2**32)
62+
63+
64+
if __name__ == "__main__":
65+
test_fill_symbolic()
66+
test_fill_static()

tilelang/jit/adapter/cython/cython_wrapper.pyx

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -267,9 +267,9 @@ cdef class CythonKernelWrapper:
267267
# Add dynamic dimension values to kernel arguments
268268
for _, (ref_id, buffer_idx, shape_idx) in self.dynamic_symbolic_map.items():
269269
if ref_id == 0:
270-
call_args.append(tensor_list[buffer_idx].shape[shape_idx])
270+
call_args.append(ctypes.c_int64(tensor_list[buffer_idx].shape[shape_idx]))
271271
else:
272-
call_args.append(tensor_list[buffer_idx].stride(shape_idx))
272+
call_args.append(ctypes.c_int64(tensor_list[buffer_idx].stride(shape_idx)))
273273

274274
# Add CUDA stream to kernel arguments
275275
call_args.append(ctypes.c_void_p(stream))

tilelang/jit/adapter/nvrtc/wrapper.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -313,9 +313,9 @@ def create_dispatch_func(self, code, function_informations):
313313
raise ValueError(
314314
f"Parameter {param} is not in the buffer map of the primary function.")
315315
# Add dynamic symbols as integer arguments
316-
for dyn_sym in dynamic_symbolic_set:
316+
for dyn_sym, dyn_sym_dtype in dynamic_symbolic_set:
317317
if dyn_sym not in [arg["name"] for arg in function_args]:
318-
function_args.append({"name": dyn_sym, "type": "ctypes.c_int"})
318+
function_args.append({"name": dyn_sym, "type": self._lookup_type(dyn_sym_dtype)})
319319

320320
function_args.append(self.get_stream_type())
321321

tilelang/jit/adapter/wrapper.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -220,9 +220,9 @@ def create_dispatch_func(self, code, function_informations):
220220
raise ValueError(
221221
f"Parameter {param} is not in the buffer map of the primary function.")
222222
# Add dynamic symbols as integer arguments
223-
for dyn_sym in dynamic_symbolic_set:
223+
for dyn_sym, dyn_sym_dtype in dynamic_symbolic_set:
224224
if dyn_sym not in [arg["name"] for arg in function_args]:
225-
function_args.append({"name": dyn_sym, "type": "int"})
225+
function_args.append({"name": dyn_sym, "type": self._lookup_type(dyn_sym_dtype)})
226226

227227
function_args.append(self.get_stream_type())
228228

@@ -405,28 +405,30 @@ def parse_source_information(self):
405405

406406
def get_dynamic_symbolic_set(self, prim_func):
407407
# Determine the set of dynamic symbols used in the function
408-
dynamic_symbolic_set: list[str] = []
408+
dynamic_symbolic_set: dict[str, str] = {}
409409

410-
def unique_push_back(name: str):
410+
def unique_push_back(name: str, dtype: str):
411411
if name not in dynamic_symbolic_set:
412-
dynamic_symbolic_set.append(name)
412+
dynamic_symbolic_set[name] = dtype
413+
else:
414+
assert dtype == dynamic_symbolic_set[name]
413415

414416
for param in prim_func.params:
415417
if param in prim_func.buffer_map:
416418
buffer = prim_func.buffer_map[param]
417419
for dim in buffer.shape:
418420
if isinstance(dim, tvm.tir.Var):
419-
unique_push_back(dim.name)
421+
unique_push_back(dim.name, str(dim.dtype))
420422

421423
# Note: In buffer definitions, any dynamic symbols appearing in strides are listed after those in the shape.
422424
for param in prim_func.params:
423425
if param in prim_func.buffer_map:
424426
buffer = prim_func.buffer_map[param]
425427
for stride in buffer.strides:
426428
if isinstance(stride, tvm.tir.Var):
427-
unique_push_back(stride.name)
429+
unique_push_back(stride.name, str(stride.dtype))
428430

429-
return dynamic_symbolic_set
431+
return list(dynamic_symbolic_set.items())
430432

431433
def get_init_func(self):
432434
# Initialize an empty string for the CUDA function call
@@ -665,8 +667,8 @@ def create_call_func(self, code, function_informations):
665667
raise ValueError(
666668
f"Parameter {param} is not in the buffer map of the primary function.")
667669
# Add dynamic symbols as integer arguments
668-
for dyn_sym in dynamic_symbolic_set:
669-
function_args.append({"name": dyn_sym, "type": "int"})
670+
for dyn_sym, dyn_sym_dtype in dynamic_symbolic_set:
671+
function_args.append({"name": dyn_sym, "type": self._lookup_type(dyn_sym_dtype)})
670672
# Format the function arguments for declaration
671673
def_args = ", ".join([f"{arg['type']} {arg['name']}" for arg in function_args])
672674

@@ -715,14 +717,14 @@ def parse_source_information(self):
715717

716718
def get_dynamic_symbolic_set(self, prim_func):
717719
# Determine the set of dynamic symbols used in the function
718-
dynamic_symbolic_set: list[str] = []
720+
dynamic_symbolic_set: dict[str, str] = {}
719721
for param in prim_func.params:
720722
if param in prim_func.buffer_map:
721723
buffer = prim_func.buffer_map[param]
722724
for dim in buffer.shape:
723725
if isinstance(dim, tvm.tir.Var) and (dim.name not in dynamic_symbolic_set):
724-
dynamic_symbolic_set.append(dim.name)
725-
return dynamic_symbolic_set
726+
dynamic_symbolic_set[dim.name] = str(dim.dtype)
727+
return list(dynamic_symbolic_set.items())
726728

727729
def get_cpu_init_func(self):
728730
# Provide init() and get_last_error() for CPU backend

0 commit comments

Comments
 (0)