Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add while node support in TVMScript #9004

Merged
merged 2 commits into from
Sep 14, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docker/install/ubuntu_install_python_package.sh
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,6 @@ pip3 install \
pytest-xdist \
requests \
scipy \
synr==0.3.0 \
synr==0.4.0 \
six \
tornado
2 changes: 1 addition & 1 deletion python/gen_requirements.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@
("sphinx_autodoc_annotation", None),
("sphinx_gallery", None),
("sphinx_rtd_theme", None),
("synr", "==0.3.0"),
("synr", "==0.4.0"),
("tensorflow", None),
("tensorflow-estimator", None),
("tflite", None),
Expand Down
13 changes: 13 additions & 0 deletions python/tvm/script/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,6 +594,19 @@ def transform_For(self, node):
self.current_lineno, self.current_col_offset = old_lineno, old_col_offset
return res

def transform_While(self, node):
"""While visitor
AST abstract grammar:
While(expr condition, stmt* body)
"""
condition = self.transform(node.condition)
# body
self.context.enter_scope(nodes=node.body.stmts)
body = self.parse_body(node)
self.context.exit_scope()

return tvm.tir.While(condition, body, span=tvm_span_from_synr(node.span))

def transform_With(self, node):
"""With visitor
AST abstract grammar:
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/tir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from .expr import Select, BufferLoad, ProducerLoad, Load, Ramp, Broadcast, Shuffle
from .expr import Call, CallEffectKind, Let, IterVar, Any

from .stmt import Stmt, LetStmt, AssertStmt, ForKind, For
from .stmt import Stmt, LetStmt, AssertStmt, ForKind, For, While
from .stmt import BufferStore, BufferRealize, Store, ProducerStore, Allocate, AttrStmt
from .stmt import ProducerRealize, SeqStmt
from .stmt import IfThenElse, Evaluate, Prefetch, stmt_seq, stmt_list
Expand Down
8 changes: 8 additions & 0 deletions src/printer/tvmscript_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ class TVMScriptPrinter : public StmtFunctor<Doc(const Stmt&)>,
Doc VisitStmt_(const IfThenElseNode* op) override;
Doc VisitStmt_(const SeqStmtNode* op) override;
Doc VisitStmt_(const ForNode* op) override;
Doc VisitStmt_(const WhileNode* op) override;
Doc VisitStmt_(const PrefetchNode* op) override;
Doc VisitStmt_(const EvaluateNode* op) override;
Doc VisitStmt_(const BlockRealizeNode* op) override;
Expand Down Expand Up @@ -830,6 +831,13 @@ Doc TVMScriptPrinter::VisitStmt_(const PrefetchNode* op) {
return doc;
}

Doc TVMScriptPrinter::VisitStmt_(const WhileNode* op) {
Doc doc;
doc << "while " << Print(op->condition) << ":";
doc << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body));
return doc;
}

Doc TVMScriptPrinter::VisitType_(const PrimTypeNode* node) {
Doc doc;
doc << "ty." << runtime::DLDataType2String(node->dtype);
Expand Down
17 changes: 17 additions & 0 deletions tests/python/unittest/test_tvmscript_roundtrip.py
Original file line number Diff line number Diff line change
Expand Up @@ -3066,5 +3066,22 @@ def test_same_name_var():
assert out_str.find("i_") == -1


@tvm.script.tir
def while_loop(a: ty.handle, b: ty.handle) -> None:
A = tir.match_buffer(a, (16,), "float32")
B = tir.match_buffer(b, (16,), "float32")
i = tir.alloc_buffer((), "int32", scope="local")
with tir.block([16]) as [vi]:
B[vi] = 0
while i[()] < 10:
for j in range(16):
B[j] += A[j]


def test_while_loop():
rt_func = tvm.script.from_source(tvm.script.asscript(while_loop, True))
tvm.ir.assert_structural_equal(while_loop, rt_func)


if __name__ == "__main__":
sys.exit(pytest.main([__file__] + sys.argv[1:]))
2 changes: 1 addition & 1 deletion tests/scripts/task_ci_setup.sh
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ set -o pipefail
#
echo "Addtiional setup in" ${CI_IMAGE_NAME}

python3 -m pip install --user tlcpack-sphinx-addon==0.2.1 synr==0.3.0
python3 -m pip install --user tlcpack-sphinx-addon==0.2.1 synr==0.4.0

# Rebuild standalone_crt in build/ tree. This file is not currently archived by pack_lib() in
# Jenkinsfile. We expect config.cmake to be present from pack_lib().
Expand Down