Skip to content

Commit

Permalink
[TIR] Allow starred expressions in TIR script (#15404)
Browse files Browse the repository at this point in the history
Small change in the evaluator to allow it to handle starred expressions
(i.e. list/tuple splicing).
  • Loading branch information
Krzysztof Parzyszek authored Jul 25, 2023
1 parent 9ff74fb commit 304aa1e
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 0 deletions.
11 changes: 11 additions & 0 deletions python/tvm/script/parser/core/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,17 @@ def _visit(self, node: doc.AST) -> Any:
return node
if isinstance(node, doc.Lambda):
return self._eval_lambda(node)
if isinstance(node, doc.Starred):
value = self._visit(node.value)
return doc.Starred(
value=value,
ctx=node.ctx,
lineno=node.lineno,
col_offset=node.col_offset,
end_lineno=node.end_lineno,
end_col_offset=node.end_col_offset,
)

fields = {}
for field in node.__class__._FIELDS: # pylint: disable=protected-access
attr = getattr(node, field)
Expand Down
18 changes: 18 additions & 0 deletions tests/python/unittest/test_tvmscript_parser_tir.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,5 +212,23 @@ def expected_non_hygienic(A: T.Buffer((1024,), "int32"), B: T.Buffer((), "int32"
tvm.ir.assert_structural_equal(use_non_hygienic, expected_non_hygienic)


def test_tir_starred_expression():
dims = (128, 128)

@T.prim_func(private=True)
def starred(a: T.handle) -> None:
A = T.match_buffer(a, [128, *dims], "int32")
for i, j, k in T.grid(128, *dims):
A[i, j, k] = T.int32(1)

@T.prim_func(private=True)
def non_starred(a: T.handle) -> None:
A = T.match_buffer(a, [128, 128, 128], "int32")
for i, j, k in T.grid(128, 128, 128):
A[i, j, k] = T.int32(1)

tvm.ir.assert_structural_equal(starred, non_starred)


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit 304aa1e

Please sign in to comment.