Skip to content

Commit

Permalink
[Relay][TF] Keep node name in span (#6885)
Browse files Browse the repository at this point in the history
  • Loading branch information
lixiaoquan authored Nov 9, 2020
1 parent b7318a7 commit eb1fa29
Show file tree
Hide file tree
Showing 8 changed files with 68 additions and 14 deletions.
14 changes: 10 additions & 4 deletions python/tvm/relay/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,10 +185,13 @@ class Tuple(ExprWithOp):
----------
fields : List[tvm.relay.Expr]
The fields in the tuple.
span: Optional[tvm.relay.Span]
Span that points to original source code
"""

def __init__(self, fields):
self.__init_handle_by_constructor__(_ffi_api.Tuple, fields)
def __init__(self, fields, span=None):
self.__init_handle_by_constructor__(_ffi_api.Tuple, fields, span)

def __getitem__(self, index):
if index >= len(self):
Expand Down Expand Up @@ -251,12 +254,15 @@ class Call(ExprWithOp):
type_args: Optional[List[tvm.relay.Type]]
The additional type arguments, this is only
used in advanced usecase of template functions.
span: Optional[tvm.relay.Span]
Span that points to original source code
"""

def __init__(self, op, args, attrs=None, type_args=None):
def __init__(self, op, args, attrs=None, type_args=None, span=None):
if not type_args:
type_args = []
self.__init_handle_by_constructor__(_ffi_api.Call, op, args, attrs, type_args)
self.__init_handle_by_constructor__(_ffi_api.Call, op, args, attrs, type_args, span)


@tvm._ffi.register_object("relay.Let")
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/relay/expr_functor.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ def visit_let(self, let):
def visit_call(self, call):
new_fn = self.visit(call.op)
new_args = [self.visit(arg) for arg in call.args]
return Call(new_fn, new_args, call.attrs)
return Call(new_fn, new_args, call.attrs, call.type_args, call.span)

def visit_var(self, var):
return var
Expand All @@ -225,7 +225,7 @@ def visit_if(self, ite):
return If(self.visit(ite.cond), self.visit(ite.true_branch), self.visit(ite.false_branch))

def visit_tuple(self, tup):
return Tuple([self.visit(field) for field in tup.fields])
return Tuple([self.visit(field) for field in tup.fields], tup.span)

def visit_tuple_getitem(self, op):
tuple_value = self.visit(op.tuple_value)
Expand Down
21 changes: 19 additions & 2 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -3402,7 +3402,7 @@ def _partition_call_operator(self, inputs, attr):
return ret

def _convert_operator(
self, op_name, inputs, attrs, graph, identity_list=None, convert_map=None
self, op_name, node_name, inputs, attrs, identity_list=None, convert_map=None
):
"""Convert from Tensorflow operator to relay operator.
The converter must specify conversions explicitly for incompatible name, and
Expand Down Expand Up @@ -3441,6 +3441,23 @@ def _convert_operator(
sym = self._partition_call_operator(inputs, attrs)
else:
raise NotImplementedError("Operator {} not implemented.".format(op_name))

sym = self._set_span(sym, node_name)

return sym

@staticmethod
def _set_span(sym, node_name):
span = tvm.relay.Span(tvm.relay.SourceName(node_name), 0, 0, 0, 0)
if isinstance(sym, _expr.Call):
sym = _expr.Call(sym.op, sym.args, sym.attrs, sym.type_args, span)
elif isinstance(sym, _expr.TupleWrapper):
tuple_value = sym.tuple_value
if isinstance(tuple_value, _expr.Call):
tuple_value = _expr.Call(
tuple_value.op, tuple_value.args, tuple_value.attrs, tuple_value.type_args, span
)
sym = _expr.TupleWrapper(tuple_value, sym.size)
return sym

def _licm_construct(self, loop_name, node_name):
Expand Down Expand Up @@ -3577,7 +3594,7 @@ def _backtrack_construct(self, node_name):
actual_input = self._licm_construct(plname, iname)
inputs[i] = actual_input

op = self._convert_operator(node.op, inputs, attr, self._graph)
op = self._convert_operator(node.op, node.name, inputs, attr)

if isinstance(op, np.ndarray):
self._params[node.name] = tvm.nd.array(op)
Expand Down
14 changes: 13 additions & 1 deletion src/printer/relay_text_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -489,7 +489,11 @@ Doc RelayTextPrinter::VisitExpr_(const CallNode* op) {
// don't print as a call if it's a 0-arity cons
return doc;
} else {
return doc << "(" << Doc::Concat(args) << ")";
doc << "(" << Doc::Concat(args) << ")";
if (op->span.defined()) {
doc << " /* " << PrintSpan(op->span) << " */";
}
return doc;
}
}

Expand Down Expand Up @@ -840,6 +844,14 @@ std::vector<Doc> RelayTextPrinter::PrintFuncAttrs(const Attrs& attrs) {
return docs;
}

Doc RelayTextPrinter::PrintSpan(const Span& span) {
Doc doc;
const auto* span_node = span.as<SpanNode>();
ICHECK(span_node);
doc << span_node->source_name->name;
return doc;
}

TVM_REGISTER_GLOBAL("ir.TextPrinter").set_body_typed([](ObjectRef node) {
auto text = AsText(node, false, nullptr);
return text;
Expand Down
1 change: 1 addition & 0 deletions src/printer/text_printer.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ class RelayTextPrinter : public ExprFunctor<Doc(const Expr&)>,
Doc PrintFinal(const ObjectRef& node);
std::vector<Doc> PrintCallAttrs(const Attrs& attrs, const Expr& op);
std::vector<Doc> PrintFuncAttrs(const Attrs& attrs);
Doc PrintSpan(const Span& span);

Doc Print(const ObjectRef& node, bool meta = false, bool try_inline = false);

Expand Down
8 changes: 4 additions & 4 deletions src/relay/ir/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,8 @@ Tuple::Tuple(tvm::Array<relay::Expr> fields, Span span) {

TVM_REGISTER_NODE_TYPE(TupleNode);

TVM_REGISTER_GLOBAL("relay.ir.Tuple").set_body_typed([](tvm::Array<relay::Expr> fields) {
return Tuple(fields);
TVM_REGISTER_GLOBAL("relay.ir.Tuple").set_body_typed([](tvm::Array<relay::Expr> fields, Span span) {
return Tuple(fields, span);
});

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
Expand Down Expand Up @@ -121,8 +121,8 @@ Call::Call(Expr op, Array<Expr> args, Attrs attrs, Array<Type> type_args, Span s
TVM_REGISTER_NODE_TYPE(CallNode);

TVM_REGISTER_GLOBAL("relay.ir.Call")
.set_body_typed([](Expr op, Array<Expr> args, Attrs attrs, Array<Type> type_args) {
return Call(op, args, attrs, type_args);
.set_body_typed([](Expr op, Array<Expr> args, Attrs attrs, Array<Type> type_args, Span span) {
return Call(op, args, attrs, type_args, span);
});

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
Expand Down
2 changes: 1 addition & 1 deletion src/relay/transforms/fuse_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -870,7 +870,7 @@ class FuseMutator : private ExprMutator {
auto* ret_group = gmap_.at(call)->FindRoot();
Array<Expr> new_args = GetNewArguments(call->args, ret_group);

auto new_call = Call(call->op, new_args, call->attrs, call->type_args);
auto new_call = Call(call->op, new_args, call->attrs, call->type_args, call->span);

if (ret_group->root_ref == call) {
// This is the root of the group
Expand Down
18 changes: 18 additions & 0 deletions tests/python/relay/test_ir_text_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,24 @@ def test_null_attribute():
assert "TestAttribute=(nullptr)" in txt


def test_span():
x = relay.var("x", shape=(3, 2))
y = relay.var("y")
one = relay.const(10e10, dtype="float32")
z = relay.add(x, one)
z = relay.Call(
z.op, z.args, z.attrs, z.type_args, relay.Span(relay.SourceName("Add0"), 0, 0, 0, 0)
)
z = relay.add(z, z)
z = relay.Call(
z.op, z.args, z.attrs, z.type_args, relay.Span(relay.SourceName("Add1"), 0, 0, 0, 0)
)
f = relay.Function([x, y], z)
txt = astext(f)
assert "Add0" in txt
assert "Add1" in txt


if __name__ == "__main__":
import sys

Expand Down

0 comments on commit eb1fa29

Please sign in to comment.