Skip to content

Commit

Permalink
Fix recast of relay ops without attributes (apache#8043)
Browse files Browse the repository at this point in the history
* Fix recast of ops without attributes

* fix test for pylint pass
  • Loading branch information
elvin-n authored and trevor-m committed Jun 17, 2021
1 parent b08de21 commit a39009e
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 1 deletion.
2 changes: 1 addition & 1 deletion python/tvm/relay/transform/recast.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def visit_call(self, call):

# If out_dtype is in the attributes, we need to update it.
orig_dtype = None
if "out_dtype" in call.attrs.keys():
if call.attrs is not None and "out_dtype" in call.attrs.keys():
new_attr_dict = {}
for attr in call.attrs.keys():
attr_value = call.attrs[attr]
Expand Down
30 changes: 30 additions & 0 deletions tests/python/relay/test_recast.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,37 @@ def expected():
assert tvm.ir.structural_equal(expected, post)


def test_recast_relu():
"""Recast a ReLU operator which does not have attributes."""

def before():
x = relay.var("x", shape=[8, 8, 8, 8])
w = relay.var("w", shape=[8, 8, 3, 3])
c = relay.nn.conv2d(x, w, padding=(1, 1), out_dtype="float32")
r = relay.nn.relu(c)
return relay.Function([x, w], r)

def expected():
x = relay.var("x", shape=[8, 8, 8, 8])
w = relay.var("w", shape=[8, 8, 3, 3])
x_fp16 = relay.cast(x, "float16")
w_fp16 = relay.cast(w, "float16")
c = relay.nn.conv2d(x_fp16, w_fp16, padding=(1, 1), out_dtype="float16")
c_float32 = relay.cast(c, "float32")
c_float16 = relay.cast(c_float32, "float16")
r = relay.nn.relu(c_float16)
r_float32 = relay.cast(r, "float32")
return relay.Function([x, w], r_float32)

pre = before()
post = recast(pre, "float16", "float16", ops=["nn.conv2d", "nn.relu"])
expected = expected()
assert tvm.ir.structural_equal(expected, post)


if __name__ == "__main__":
test_recast_simple()
test_recast_medium()
test_recast_skip()
test_recast_concat()
test_recast_relu()

0 comments on commit a39009e

Please sign in to comment.