From c8d46b6686caa186c79496c624577a531d5a99c2 Mon Sep 17 00:00:00 2001 From: Marisa Kirisame Date: Mon, 15 Apr 2019 15:13:27 -0700 Subject: [PATCH 1/3] save --- tests/python/relay/test_backend_interpreter.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/python/relay/test_backend_interpreter.py b/tests/python/relay/test_backend_interpreter.py index da794e25ab56..bf7d96194bf2 100644 --- a/tests/python/relay/test_backend_interpreter.py +++ b/tests/python/relay/test_backend_interpreter.py @@ -50,6 +50,12 @@ def test_tuple_value(): np.testing.assert_allclose(tv[2].asnumpy(), 3) +def test_tuple_getitem(): + two = relay.add(relay.const(1), relay.const(1)) + func = relay.Function([], relay.TupleGetItem(relay.Tuple([relay.const(1), relay.const(2)]), 0)) + check_eval(func, [], 1) + + def test_id(): x = relay.var('x', 'float32') ident = relay.Function([x], x) @@ -181,3 +187,5 @@ def test_kwargs_params(): test_kwargs_params() test_ref() test_tensor_value() + test_tuple_value() + test_tuple_getitem() From 348a2175bac138e26eb0c758c77a97984f8852ab Mon Sep 17 00:00:00 2001 From: Marisa Kirisame Date: Tue, 16 Apr 2019 18:20:17 -0700 Subject: [PATCH 2/3] fix --- src/relay/pass/fuse_ops.cc | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/relay/pass/fuse_ops.cc b/src/relay/pass/fuse_ops.cc index 6de9c2d65f90..198b697d0017 100644 --- a/src/relay/pass/fuse_ops.cc +++ b/src/relay/pass/fuse_ops.cc @@ -865,9 +865,17 @@ class FuseMutator : private ExprMutator { } Expr MakeNewFunction(GraphPartitioner::Group* group, Type ret_type, Expr body) { + // If the function has no call, it is not a primitive function. + struct HasCallVisitor : ExprVisitor { + bool HasCall = false; + void VisitExpr_(const CallNode* op) final { + HasCall = true; + } + } visitor; + visitor(body); const GroupInfo& ginfo = ginfo_[group]; auto func = FunctionNode::make(ginfo.params, body, ret_type, {}); - func = FunctionSetAttr(func, "Primitive", tvm::Integer(1)); + func = FunctionSetAttr(func, "Primitive", tvm::Integer(visitor.HasCall)); return CallNode::make(func, ginfo.arguments, Attrs()); } From dd4b3a6e91273680d52953c340e8b4047f0cd833 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=BE=E9=9B=A8=E9=AD=94=E7=90=86=E6=B2=99?= Date: Tue, 16 Apr 2019 21:19:13 -0700 Subject: [PATCH 3/3] Update fuse_ops.cc --- src/relay/pass/fuse_ops.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/relay/pass/fuse_ops.cc b/src/relay/pass/fuse_ops.cc index 198b697d0017..12e3174dcade 100644 --- a/src/relay/pass/fuse_ops.cc +++ b/src/relay/pass/fuse_ops.cc @@ -867,15 +867,15 @@ class FuseMutator : private ExprMutator { Expr MakeNewFunction(GraphPartitioner::Group* group, Type ret_type, Expr body) { // If the function has no call, it is not a primitive function. struct HasCallVisitor : ExprVisitor { - bool HasCall = false; + bool has_call = false; void VisitExpr_(const CallNode* op) final { - HasCall = true; + has_call = true; } } visitor; visitor(body); const GroupInfo& ginfo = ginfo_[group]; auto func = FunctionNode::make(ginfo.params, body, ret_type, {}); - func = FunctionSetAttr(func, "Primitive", tvm::Integer(visitor.HasCall)); + func = FunctionSetAttr(func, "Primitive", tvm::Integer(visitor.has_call)); return CallNode::make(func, ginfo.arguments, Attrs()); }