From a39009e64c7b745846030ff8616d77b1f7099342 Mon Sep 17 00:00:00 2001 From: Andrey Malyshev Date: Mon, 17 May 2021 19:05:32 +0300 Subject: [PATCH] Fix recast of relay ops without attributes (#8043) * Fix recast of ops without attributes * fix test for pylint pass --- python/tvm/relay/transform/recast.py | 2 +- tests/python/relay/test_recast.py | 30 ++++++++++++++++++++++++++++ 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/transform/recast.py b/python/tvm/relay/transform/recast.py index 2c88f10dc8f4..c1722ab67d6b 100644 --- a/python/tvm/relay/transform/recast.py +++ b/python/tvm/relay/transform/recast.py @@ -74,7 +74,7 @@ def visit_call(self, call): # If out_dtype is in the attributes, we need to update it. orig_dtype = None - if "out_dtype" in call.attrs.keys(): + if call.attrs is not None and "out_dtype" in call.attrs.keys(): new_attr_dict = {} for attr in call.attrs.keys(): attr_value = call.attrs[attr] diff --git a/tests/python/relay/test_recast.py b/tests/python/relay/test_recast.py index 43def9df41ce..19803594c968 100644 --- a/tests/python/relay/test_recast.py +++ b/tests/python/relay/test_recast.py @@ -126,7 +126,37 @@ def expected(): assert tvm.ir.structural_equal(expected, post) +def test_recast_relu(): + """Recast a ReLU operator which does not have attributes.""" + + def before(): + x = relay.var("x", shape=[8, 8, 8, 8]) + w = relay.var("w", shape=[8, 8, 3, 3]) + c = relay.nn.conv2d(x, w, padding=(1, 1), out_dtype="float32") + r = relay.nn.relu(c) + return relay.Function([x, w], r) + + def expected(): + x = relay.var("x", shape=[8, 8, 8, 8]) + w = relay.var("w", shape=[8, 8, 3, 3]) + x_fp16 = relay.cast(x, "float16") + w_fp16 = relay.cast(w, "float16") + c = relay.nn.conv2d(x_fp16, w_fp16, padding=(1, 1), out_dtype="float16") + c_float32 = relay.cast(c, "float32") + c_float16 = relay.cast(c_float32, "float16") + r = relay.nn.relu(c_float16) + r_float32 = relay.cast(r, "float32") + return relay.Function([x, w], r_float32) + + pre = before() + post = recast(pre, "float16", "float16", ops=["nn.conv2d", "nn.relu"]) + expected = expected() + assert tvm.ir.structural_equal(expected, post) + + if __name__ == "__main__": test_recast_simple() test_recast_medium() test_recast_skip() + test_recast_concat() + test_recast_relu()