diff --git a/src/printer/relay_text_printer.cc b/src/printer/relay_text_printer.cc index 076339d774b4..ad16f862ac2b 100644 --- a/src/printer/relay_text_printer.cc +++ b/src/printer/relay_text_printer.cc @@ -91,7 +91,8 @@ Doc RelayTextPrinter::PrintScope(const ObjectRef& node) { } Doc RelayTextPrinter::PrintFinal(const ObjectRef& node) { - if (node->IsInstance() && !node->IsInstance()) { + if (node.defined() && node->IsInstance() && + !node->IsInstance()) { // Temporarily skip non-relay functions. // TODO(tvm-team) enhance the code to work for all functions } else if (node.as()) { @@ -105,8 +106,8 @@ Doc RelayTextPrinter::PrintFinal(const ObjectRef& node) { } Doc RelayTextPrinter::Print(const ObjectRef& node, bool meta, bool try_inline) { - bool is_non_relay_func = - node->IsInstance() && !node->IsInstance(); + bool is_non_relay_func = node.defined() && node->IsInstance() && + !node->IsInstance(); if (node.as() && !is_non_relay_func) { return PrintExpr(Downcast(node), meta, try_inline); } else if (node.as()) { diff --git a/tests/python/relay/test_ir_text_printer.py b/tests/python/relay/test_ir_text_printer.py index 61dbca33ca7a..2a88c0c99ae7 100644 --- a/tests/python/relay/test_ir_text_printer.py +++ b/tests/python/relay/test_ir_text_printer.py @@ -240,6 +240,15 @@ def @main[A]() -> fn (A, List[A]) -> List[A] { assert main_def_str.strip() in mod_str +def test_null_attribute(): + x = relay.var("x") + y = relay.var("y") + z = relay.Function([x], y) + z = z.with_attr("TestAttribute", None) + txt = astext(z) + assert "TestAttribute=(nullptr)" in txt + + if __name__ == "__main__": do_print[0] = True test_lstm() @@ -262,3 +271,4 @@ def @main[A]() -> fn (A, List[A]) -> List[A] { test_variable_name() test_call_node_order() test_unapplied_constructor() + test_null_attribute()