From 7f0055b5a4447592442fb0915ffcdee4076779dc Mon Sep 17 00:00:00 2001 From: Li Xiaoquan Date: Fri, 6 Nov 2020 18:53:43 +0800 Subject: [PATCH] [Relay][TF] Keep node name in span --- python/tvm/relay/expr.py | 14 ++++++++++---- python/tvm/relay/expr_functor.py | 4 ++-- python/tvm/relay/frontend/tensorflow.py | 21 +++++++++++++++++++-- src/printer/relay_text_printer.cc | 14 +++++++++++++- src/printer/text_printer.h | 1 + src/relay/ir/expr.cc | 8 ++++---- src/relay/transforms/fuse_ops.cc | 2 +- tests/python/relay/test_ir_text_printer.py | 18 ++++++++++++++++++ 8 files changed, 68 insertions(+), 14 deletions(-) diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py index 6d304648fa1c..7b6e4b4ccf80 100644 --- a/python/tvm/relay/expr.py +++ b/python/tvm/relay/expr.py @@ -185,10 +185,13 @@ class Tuple(ExprWithOp): ---------- fields : List[tvm.relay.Expr] The fields in the tuple. + + span: Optional[tvm.relay.Span] + Span that points to original source code """ - def __init__(self, fields): - self.__init_handle_by_constructor__(_ffi_api.Tuple, fields) + def __init__(self, fields, span=None): + self.__init_handle_by_constructor__(_ffi_api.Tuple, fields, span) def __getitem__(self, index): if index >= len(self): @@ -251,12 +254,15 @@ class Call(ExprWithOp): type_args: Optional[List[tvm.relay.Type]] The additional type arguments, this is only used in advanced usecase of template functions. + + span: Optional[tvm.relay.Span] + Span that points to original source code """ - def __init__(self, op, args, attrs=None, type_args=None): + def __init__(self, op, args, attrs=None, type_args=None, span=None): if not type_args: type_args = [] - self.__init_handle_by_constructor__(_ffi_api.Call, op, args, attrs, type_args) + self.__init_handle_by_constructor__(_ffi_api.Call, op, args, attrs, type_args, span) @tvm._ffi.register_object("relay.Let") diff --git a/python/tvm/relay/expr_functor.py b/python/tvm/relay/expr_functor.py index 0a37e4d4393c..40a116ab0b43 100644 --- a/python/tvm/relay/expr_functor.py +++ b/python/tvm/relay/expr_functor.py @@ -213,7 +213,7 @@ def visit_let(self, let): def visit_call(self, call): new_fn = self.visit(call.op) new_args = [self.visit(arg) for arg in call.args] - return Call(new_fn, new_args, call.attrs) + return Call(new_fn, new_args, call.attrs, call.type_args, call.span) def visit_var(self, var): return var @@ -225,7 +225,7 @@ def visit_if(self, ite): return If(self.visit(ite.cond), self.visit(ite.true_branch), self.visit(ite.false_branch)) def visit_tuple(self, tup): - return Tuple([self.visit(field) for field in tup.fields]) + return Tuple([self.visit(field) for field in tup.fields], tup.span) def visit_tuple_getitem(self, op): tuple_value = self.visit(op.tuple_value) diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index c6079b4535c4..d46470d73af1 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -3401,7 +3401,7 @@ def _partition_call_operator(self, inputs, attr): return ret def _convert_operator( - self, op_name, inputs, attrs, graph, identity_list=None, convert_map=None + self, op_name, node_name, inputs, attrs, identity_list=None, convert_map=None ): """Convert from Tensorflow operator to relay operator. The converter must specify conversions explicitly for incompatible name, and @@ -3440,6 +3440,23 @@ def _convert_operator( sym = self._partition_call_operator(inputs, attrs) else: raise NotImplementedError("Operator {} not implemented.".format(op_name)) + + sym = self._set_span(sym, node_name) + + return sym + + @staticmethod + def _set_span(sym, node_name): + span = tvm.relay.Span(tvm.relay.SourceName(node_name), 0, 0, 0, 0) + if isinstance(sym, _expr.Call): + sym = _expr.Call(sym.op, sym.args, sym.attrs, sym.type_args, span) + elif isinstance(sym, _expr.TupleWrapper): + tuple_value = sym.tuple_value + if isinstance(tuple_value, _expr.Call): + tuple_value = _expr.Call( + tuple_value.op, tuple_value.args, tuple_value.attrs, tuple_value.type_args, span + ) + sym = _expr.TupleWrapper(tuple_value, sym.size) return sym def _licm_construct(self, loop_name, node_name): @@ -3576,7 +3593,7 @@ def _backtrack_construct(self, node_name): actual_input = self._licm_construct(plname, iname) inputs[i] = actual_input - op = self._convert_operator(node.op, inputs, attr, self._graph) + op = self._convert_operator(node.op, node.name, inputs, attr) if isinstance(op, np.ndarray): self._params[node.name] = tvm.nd.array(op) diff --git a/src/printer/relay_text_printer.cc b/src/printer/relay_text_printer.cc index 4132ab14ff29..da4f8cadfb3d 100644 --- a/src/printer/relay_text_printer.cc +++ b/src/printer/relay_text_printer.cc @@ -489,7 +489,11 @@ Doc RelayTextPrinter::VisitExpr_(const CallNode* op) { // don't print as a call if it's a 0-arity cons return doc; } else { - return doc << "(" << Doc::Concat(args) << ")"; + doc << "(" << Doc::Concat(args) << ")"; + if (op->span.defined()) { + doc << " /* " << PrintSpan(op->span) << " */"; + } + return doc; } } @@ -840,6 +844,14 @@ std::vector RelayTextPrinter::PrintFuncAttrs(const Attrs& attrs) { return docs; } +Doc RelayTextPrinter::PrintSpan(const Span& span) { + Doc doc; + const auto* span_node = span.as(); + ICHECK(span_node); + doc << span_node->source_name->name; + return doc; +} + TVM_REGISTER_GLOBAL("ir.TextPrinter").set_body_typed([](ObjectRef node) { auto text = AsText(node, false, nullptr); return text; diff --git a/src/printer/text_printer.h b/src/printer/text_printer.h index e519969d6a4b..9a24fe65b4b1 100644 --- a/src/printer/text_printer.h +++ b/src/printer/text_printer.h @@ -74,6 +74,7 @@ class RelayTextPrinter : public ExprFunctor, Doc PrintFinal(const ObjectRef& node); std::vector PrintCallAttrs(const Attrs& attrs, const Expr& op); std::vector PrintFuncAttrs(const Attrs& attrs); + Doc PrintSpan(const Span& span); Doc Print(const ObjectRef& node, bool meta = false, bool try_inline = false); diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index f2e0b363eb2b..89d1f1ab0f11 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -73,8 +73,8 @@ Tuple::Tuple(tvm::Array fields, Span span) { TVM_REGISTER_NODE_TYPE(TupleNode); -TVM_REGISTER_GLOBAL("relay.ir.Tuple").set_body_typed([](tvm::Array fields) { - return Tuple(fields); +TVM_REGISTER_GLOBAL("relay.ir.Tuple").set_body_typed([](tvm::Array fields, Span span) { + return Tuple(fields, span); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) @@ -121,8 +121,8 @@ Call::Call(Expr op, Array args, Attrs attrs, Array type_args, Span s TVM_REGISTER_NODE_TYPE(CallNode); TVM_REGISTER_GLOBAL("relay.ir.Call") - .set_body_typed([](Expr op, Array args, Attrs attrs, Array type_args) { - return Call(op, args, attrs, type_args); + .set_body_typed([](Expr op, Array args, Attrs attrs, Array type_args, Span span) { + return Call(op, args, attrs, type_args, span); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) diff --git a/src/relay/transforms/fuse_ops.cc b/src/relay/transforms/fuse_ops.cc index 8023305f3f64..29f3bfa0a17e 100644 --- a/src/relay/transforms/fuse_ops.cc +++ b/src/relay/transforms/fuse_ops.cc @@ -870,7 +870,7 @@ class FuseMutator : private ExprMutator { auto* ret_group = gmap_.at(call)->FindRoot(); Array new_args = GetNewArguments(call->args, ret_group); - auto new_call = Call(call->op, new_args, call->attrs, call->type_args); + auto new_call = Call(call->op, new_args, call->attrs, call->type_args, call->span); if (ret_group->root_ref == call) { // This is the root of the group diff --git a/tests/python/relay/test_ir_text_printer.py b/tests/python/relay/test_ir_text_printer.py index 6c2f7166f446..4a3569aca2ec 100644 --- a/tests/python/relay/test_ir_text_printer.py +++ b/tests/python/relay/test_ir_text_printer.py @@ -250,6 +250,24 @@ def test_null_attribute(): assert "TestAttribute=(nullptr)" in txt +def test_span(): + x = relay.var("x", shape=(3, 2)) + y = relay.var("y") + one = relay.const(10e10, dtype="float32") + z = relay.add(x, one) + z = relay.Call( + z.op, z.args, z.attrs, z.type_args, relay.Span(relay.SourceName("Add0"), 0, 0, 0, 0) + ) + z = relay.add(z, z) + z = relay.Call( + z.op, z.args, z.attrs, z.type_args, relay.Span(relay.SourceName("Add1"), 0, 0, 0, 0) + ) + f = relay.Function([x, y], z) + txt = astext(f) + assert "Add0" in txt + assert "Add1" in txt + + if __name__ == "__main__": import sys