Skip to content

Commit

Permalink
[Relay][Text Format] Reverse CallNode Print Order (#2882)
Browse files Browse the repository at this point in the history
  • Loading branch information
joshpoll authored and tqchen committed Mar 23, 2019
1 parent 8a7f41c commit e23913f
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 2 deletions.
4 changes: 3 additions & 1 deletion src/relay/ir/pretty_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -416,11 +416,13 @@ class PrettyPrinter :

Doc VisitExpr_(const CallNode* op) final {
Doc doc;
doc << Print(op->op);
// visit args first so they are lifted before the op
// this places op closer to its call site
std::vector<Doc> args;
for (Expr arg : op->args) {
args.push_back(Print(arg));
}
doc << Print(op->op);
return doc << "(" << PrintVec(args) << PrintAttrs(op->attrs, op->op) << ")";
}

Expand Down
17 changes: 16 additions & 1 deletion tests/python/relay/test_ir_text_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
import numpy as np
from tvm import relay


do_print = [False]

SEMVER = "v0.0.1\n"

def show(text):
if do_print[0]:
print("---------------------------")
Expand Down Expand Up @@ -152,6 +153,19 @@ def test_densenet():
net, params = tvm.relay.testing.densenet.get_workload(batch_size=1)
net.astext()

def test_call_node_order():
x = relay.var("x")
y = relay.var("y")
assert relay.Call(relay.Function([x], x), [relay.Call(relay.Function([y], y), [relay.const(1)])]).astext() == SEMVER + \
("%0 = fn (%y) {\n"
" %y\n"
"}\n"
"%1 = %0(1)\n"
"%2 = fn (%x) {\n"
" %x\n"
"}\n"
"%3 = %2(%1)\n"
"%3")

if __name__ == "__main__":
do_print[0] = True
Expand All @@ -170,3 +184,4 @@ def test_densenet():
test_call_attrs()
test_let_if_scope()
test_variable_name()
test_call_node_order()

0 comments on commit e23913f

Please sign in to comment.