Skip to content

Commit

Permalink
[TVMScript] Support starred indices in for-loop (#15442)
Browse files Browse the repository at this point in the history
* [TVMScript] Support starred indices in for-loop

An extension of #15404, which
allowed starred expressions in the rhs of
assignments (e.g. `T.decl_buffer(shape=[*dim, 128])`), this PR also
enables starred expressions in the lhs of
assignments (e.g. `for *spatial,reduction in T.grid(*A.shape)`).

* Fix single-argument indices

* Updated test case for T.grid()
  • Loading branch information
Lunderberg authored Aug 2, 2023
1 parent bab295e commit 1b7175b
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 6 deletions.
3 changes: 3 additions & 0 deletions python/tvm/script/ir_builder/tir/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -1277,6 +1277,9 @@ def buffer_store(
"""
from tvm.arith import Analyzer # pylint: disable=import-outside-toplevel

if not isinstance(indices, (list, tuple, ir.Array)):
indices = [indices]

expr_indices = []
for index in indices:
if isinstance(index, slice):
Expand Down
2 changes: 2 additions & 0 deletions python/tvm/script/parser/core/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,8 @@ def _duplicate_lhs_check(self, target: doc.expr) -> Union[bool, Set[str]]:
return vars
elif isinstance(target, doc.Name):
return {target.id}
elif isinstance(target, doc.Starred):
return self._duplicate_lhs_check(target.value)
else:
self.report_error(target, "Invalid type in assign statement")
raise NotImplementedError
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/script/parser/tir/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def bind_for_value(self: Parser, node: doc.expr, var_name: str, value: Any) -> A
res : Any
The bound value.
"""
if isinstance(value, (list, tuple)):
if isinstance(value, (list, tuple, tvm.ir.Array)):
for i, v in enumerate(value):
bind_for_value(self, node, f"{var_name}_{i}", v)
return value
Expand Down Expand Up @@ -255,7 +255,7 @@ def visit_assign(self: Parser, node: doc.Assign) -> None:
for index in lhs.slice.elts:
indices.append(self.eval_expr(index))
else:
indices = [self.eval_expr(lhs.slice)]
indices = self.eval_expr(lhs.slice)
T.buffer_store(self.eval_expr(lhs.value), rhs, indices)
else:
self.eval_assign(target=lhs, source=rhs, bind_value=bind_assign_value)
Expand Down
8 changes: 4 additions & 4 deletions tests/python/unittest/test_tvmscript_error_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,11 +236,11 @@ def invalid_loop_var() -> None:


def test_inconsistent_grid():
def inconsistent_grid() -> None:
for i in T.grid(16, 16): # error
T.evaluate(1.0)
def inconsistent_grid(A: T.Buffer(16)) -> None:
for i in T.grid(16, 16): # valid, i is a tuple (iter0, iter1)
T.evaluate(A[i]) # error

check_error(inconsistent_grid, 2)
check_error(inconsistent_grid, 3)


def test_invalid_match_buffer_region():
Expand Down
62 changes: 62 additions & 0 deletions tests/python/unittest/test_tvmscript_parser_tir.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,5 +230,67 @@ def non_starred(a: T.handle) -> None:
tvm.ir.assert_structural_equal(starred, non_starred)


def test_tir_starred_shape_expression():
dims = (128, 128)

@T.prim_func(private=True)
def starred(a: T.handle) -> None:
A = T.match_buffer(a, [128, *dims], "int32")
for i, j, k in T.grid(*A.shape):
A[i, j, k] = T.int32(1)

@T.prim_func(private=True)
def non_starred(a: T.handle) -> None:
A = T.match_buffer(a, [128, 128, 128], "int32")
for i, j, k in T.grid(128, 128, 128):
A[i, j, k] = T.int32(1)

tvm.ir.assert_structural_equal(starred, non_starred)


def test_tir_dynamic_for_loop():
dims = (128, 128)

@T.prim_func(private=True)
def starred(a: T.handle) -> None:
A = T.match_buffer(a, [128, *dims], "int32")
for iters in T.grid(*A.shape):
A[iters] = T.int32(1)

@T.prim_func(private=True)
def non_starred(a: T.handle) -> None:
A = T.match_buffer(a, [128, 128, 128], "int32")
for i, j, k in T.grid(128, 128, 128):
A[i, j, k] = T.int32(1)

tvm.ir.assert_structural_equal(starred, non_starred)


def test_tir_starred_for_loop():
dims = (128, 128)

@T.prim_func(private=True)
def starred(a: T.handle, b: T.handle):
A = T.match_buffer(a, [*dims, 128], "int32")
B = T.match_buffer(a, dims, "int32")
for *spatial, reduction in T.grid(*A.shape):
with T.block("reduce"):
with T.init():
B[spatial] = T.int32(0)
B[spatial] = B[spatial] + A[(*spatial, reduction)]

@T.prim_func(private=True)
def non_starred(a: T.handle, b: T.handle):
A = T.match_buffer(a, [128, 128, 128], "int32")
B = T.match_buffer(a, [128, 128], "int32")
for i, j, k in T.grid(128, 128, 128):
with T.block("reduce"):
with T.init():
B[i, j] = T.int32(0)
B[i, j] = B[i, j] + A[i, j, k]

tvm.ir.assert_structural_equal(starred, non_starred)


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit 1b7175b

Please sign in to comment.