From 5f37380055403199897c667053be799733e550e8 Mon Sep 17 00:00:00 2001 From: Andrew Liu Date: Wed, 25 Nov 2020 17:00:31 -0800 Subject: [PATCH] [Frontend][Relay][Parser] fix unparsable yolo formals (#6963) * fix yolo formals * fix lint * move test to test_forward --- python/tvm/relay/frontend/darknet.py | 2 +- tests/python/frontend/darknet/test_forward.py | 15 +++++++++++++++ tests/python/relay/test_ir_text_printer.py | 5 ++--- 3 files changed, 18 insertions(+), 4 deletions(-) diff --git a/python/tvm/relay/frontend/darknet.py b/python/tvm/relay/frontend/darknet.py index 87e55593e943..363812fd562b 100644 --- a/python/tvm/relay/frontend/darknet.py +++ b/python/tvm/relay/frontend/darknet.py @@ -40,7 +40,7 @@ def _darknet_not_support(attr, op="relay"): def _get_params_prefix(opname, layer_num): """Makes the params prefix name from opname and layer number.""" - return str(opname) + str(layer_num) + return str(opname).replace(".", "_") + str(layer_num) def _get_params_name(prefix, item): diff --git a/tests/python/frontend/darknet/test_forward.py b/tests/python/frontend/darknet/test_forward.py index 74c1a2199caa..b6dc815a9530 100644 --- a/tests/python/frontend/darknet/test_forward.py +++ b/tests/python/frontend/darknet/test_forward.py @@ -46,6 +46,17 @@ ) +def astext(program, unify_free_vars=False): + """check that program is parsable in text format""" + text = program.astext() + if isinstance(program, relay.Expr): + roundtrip_program = tvm.parser.parse_expr(text) + else: + roundtrip_program = tvm.parser.fromtext(text) + + tvm.ir.assert_structural_equal(roundtrip_program, program, map_free_vars=True) + + def _read_memory_buffer(shape, data, dtype="float32"): length = 1 for x in shape: @@ -60,6 +71,10 @@ def _get_tvm_output(net, data, build_dtype="float32", states=None): """Compute TVM output""" dtype = "float32" mod, params = relay.frontend.from_darknet(net, data.shape, dtype) + # verify that from_darknet creates a valid, parsable relay program + mod = relay.transform.InferType()(mod) + astext(mod) + target = "llvm" shape_dict = {"data": data.shape} lib = relay.build(mod, target, params=params) diff --git a/tests/python/relay/test_ir_text_printer.py b/tests/python/relay/test_ir_text_printer.py index 4a3569aca2ec..72a243dbbb67 100644 --- a/tests/python/relay/test_ir_text_printer.py +++ b/tests/python/relay/test_ir_text_printer.py @@ -21,6 +21,7 @@ import numpy as np from tvm.relay import Expr from tvm.relay.analysis import free_vars +import pytest DEBUG_PRINT = False @@ -269,6 +270,4 @@ def test_span(): if __name__ == "__main__": - import sys - - pytext.argv(sys.argv) + pytest.main([__file__])