Skip to content

Commit

Permalink
[TVMScript] Support starred indices in for-loop
Browse files Browse the repository at this point in the history
An extension of apache#15404, which
allowed starred expressions in the rhs of
assignments (e.g. `T.decl_buffer(shape=[*dim, 128])`), this PR also
enables starred expressions in the lhs of
assignments (e.g. `for *spatial,reduction in T.grid(*A.shape)`).
  • Loading branch information
Lunderberg committed Jul 31, 2023
1 parent 619bb1d commit 62bac84
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 2 deletions.
2 changes: 2 additions & 0 deletions python/tvm/script/parser/core/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,8 @@ def _duplicate_lhs_check(self, target: doc.expr) -> Union[bool, Set[str]]:
return vars
elif isinstance(target, doc.Name):
return {target.id}
elif isinstance(target, doc.Starred):
return self._duplicate_lhs_check(target.value)
else:
self.report_error(target, "Invalid type in assign statement")
raise NotImplementedError
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/script/parser/tir/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def bind_for_value(self: Parser, node: doc.expr, var_name: str, value: Any) -> A
res : Any
The bound value.
"""
if isinstance(value, (list, tuple)):
if isinstance(value, (list, tuple, tvm.ir.Array)):
for i, v in enumerate(value):
bind_for_value(self, node, f"{var_name}_{i}", v)
return value
Expand Down Expand Up @@ -255,7 +255,7 @@ def visit_assign(self: Parser, node: doc.Assign) -> None:
for index in lhs.slice.elts:
indices.append(self.eval_expr(index))
else:
indices = [self.eval_expr(lhs.slice)]
indices = self.eval_expr(lhs.slice)
T.buffer_store(self.eval_expr(lhs.value), rhs, indices)
else:
self.eval_assign(target=lhs, source=rhs, bind_value=bind_assign_value)
Expand Down
62 changes: 62 additions & 0 deletions tests/python/unittest/test_tvmscript_parser_tir.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,5 +230,67 @@ def non_starred(a: T.handle) -> None:
tvm.ir.assert_structural_equal(starred, non_starred)


def test_tir_starred_shape_expression():
dims = (128, 128)

@T.prim_func(private=True)
def starred(a: T.handle) -> None:
A = T.match_buffer(a, [128, *dims], "int32")
for i, j, k in T.grid(*A.shape):
A[i, j, k] = T.int32(1)

@T.prim_func(private=True)
def non_starred(a: T.handle) -> None:
A = T.match_buffer(a, [128, 128, 128], "int32")
for i, j, k in T.grid(128, 128, 128):
A[i, j, k] = T.int32(1)

tvm.ir.assert_structural_equal(starred, non_starred)


def test_tir_dynamic_for_loop():
dims = (128, 128)

@T.prim_func(private=True)
def starred(a: T.handle) -> None:
A = T.match_buffer(a, [128, *dims], "int32")
for iters in T.grid(*A.shape):
A[iters] = T.int32(1)

@T.prim_func(private=True)
def non_starred(a: T.handle) -> None:
A = T.match_buffer(a, [128, 128, 128], "int32")
for i, j, k in T.grid(128, 128, 128):
A[i, j, k] = T.int32(1)

tvm.ir.assert_structural_equal(starred, non_starred)


def test_tir_starred_for_loop():
dims = (128, 128)

@T.prim_func(private=True)
def starred(a: T.handle, b: T.handle):
A = T.match_buffer(a, [*dims, 128], "int32")
B = T.match_buffer(a, dims, "int32")
for *spatial, reduction in T.grid(*A.shape):
with T.block("reduce"):
with T.init():
B[spatial] = T.int32(0)
B[spatial] = B[spatial] + A[(*spatial, reduction)]

@T.prim_func(private=True)
def non_starred(a: T.handle, b: T.handle):
A = T.match_buffer(a, [128, 128, 128], "int32")
B = T.match_buffer(a, [128, 128], "int32")
for i, j, k in T.grid(128, 128, 128):
with T.block("reduce"):
with T.init():
B[i, j] = T.int32(0)
B[i, j] = B[i, j] + A[i, j, k]

tvm.ir.assert_structural_equal(starred, non_starred)


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit 62bac84

Please sign in to comment.