Skip to content

Commit 399af08

Browse files
authored
[BugFix] alloc_var init failed to handle complex expression (#1144)
* [Fix] init var with complex expression * fix lint error
1 parent 60567ba commit 399af08

File tree

2 files changed

+58
-2
lines changed

2 files changed

+58
-2
lines changed
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import tilelang
2+
import tilelang.language as T
3+
import tilelang.testing
4+
5+
6+
def test_var_assign() -> None:
7+
8+
@tilelang.jit(out_idx=-1)
9+
def jit_kernel():
10+
11+
@T.prim_func
12+
def test_var_assign(A: T.Tensor((2,), 'int32')):
13+
with T.Kernel(1) as _:
14+
a = T.alloc_var('int32', init=1)
15+
b = T.alloc_var('int32', init=a) # b gets value of a
16+
a = 2
17+
d = T.alloc_var('int32', init=a) # c gets new value of a
18+
A[0] = b
19+
A[1] = d
20+
21+
print(test_var_assign)
22+
return test_var_assign
23+
24+
kernel = jit_kernel()
25+
print(kernel.get_kernel_source())
26+
res = kernel()
27+
assert res[0] == 1
28+
assert res[1] == 2
29+
30+
31+
if __name__ == '__main__':
32+
tilelang.testing.main()

tilelang/language/allocate.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,13 @@
1515
"""
1616

1717
from __future__ import annotations
18+
from typing import overload
1819
from tilelang import tvm as tvm
1920
from tvm.script import tir as T
2021
from tvm.tir import PrimExpr
2122
from tvm.script.parser.tir import block_attr
23+
from tvm.tir.buffer import Buffer
24+
from tvm.tir.expr import FloatImm, IntImm
2225

2326

2427
def alloc_shared(shape, dtype, scope="shared.dyn"):
@@ -67,6 +70,19 @@ def alloc_fragment(shape, dtype, scope="local.fragment"):
6770
return T.alloc_buffer(shape, dtype, scope=scope)
6871

6972

73+
@overload
74+
def alloc_var(dtype: str, init: PrimExpr | int | float, scope: str = 'local.var') -> Buffer:
75+
...
76+
77+
78+
@overload
79+
def alloc_var(dtype: str,
80+
scope: str = 'local.var',
81+
*,
82+
init: PrimExpr | int | float | None = None) -> Buffer:
83+
...
84+
85+
7086
def alloc_var(dtype, *args, scope="local.var", init: PrimExpr | None = None):
7187
"""Allocate a single-element variable buffer.
7288
@@ -82,7 +98,12 @@ def alloc_var(dtype, *args, scope="local.var", init: PrimExpr | None = None):
8298
init (PrimExpr, optional): The optional initializer value. When provided,
8399
the generated code will initialize the variable with this value instead
84100
of defaulting to zero.
85-
101+
Examples:
102+
a = T.alloc_var('int32', 1) # var with init 1
103+
a = T.alloc_var('int32', 'local.var') # var with local.var scope
104+
a = T.alloc_var('int32', 1, 'local.var') # var with init 1 and local.var scope
105+
a = T.alloc_var('int32', 'local.var', init=1) # var with init 1 and local.var scope
106+
a = T.alloc_var('int32', init=1) # var with init 1 and local.var scope
86107
Returns:
87108
T.Buffer: A TVM buffer object allocated as a single-element variable
88109
"""
@@ -113,7 +134,10 @@ def alloc_var(dtype, *args, scope="local.var", init: PrimExpr | None = None):
113134

114135
buffer = T.alloc_buffer([1], dtype, scope=parsed_scope)
115136
if parsed_init is not None:
116-
block_attr({"tl.local_var_init": {buffer.data: parsed_init}})
137+
if isinstance(parsed_init, (int, float, IntImm, FloatImm)):
138+
block_attr({"tl.local_var_init": {buffer.data: parsed_init}})
139+
else:
140+
T.buffer_store(buffer, parsed_init, 0)
117141
return buffer
118142

119143

0 commit comments

Comments
 (0)