diff --git a/src/relay/ir/hash.cc b/src/relay/ir/hash.cc index bce3610da47b..f37b1a4c10be 100644 --- a/src/relay/ir/hash.cc +++ b/src/relay/ir/hash.cc @@ -267,6 +267,8 @@ class RelayHashHandler: hash = Combine(hash, TypeHash(func->ret_type)); hash = Combine(hash, ExprHash(func->body)); + hash = Combine(hash, AttrHash(func->attrs)); + return hash; } diff --git a/tests/python/relay/test_pass_alpha_equal.py b/tests/python/relay/test_pass_alpha_equal.py index b240daf962d5..6ef435a19388 100644 --- a/tests/python/relay/test_pass_alpha_equal.py +++ b/tests/python/relay/test_pass_alpha_equal.py @@ -313,6 +313,29 @@ def test_tuple_get_item_alpha_equal(): assert alpha_equal(relay.TupleGetItem(x, 1), relay.TupleGetItem(x, 1)) +def test_multi_node_subgraph(): + x0 = relay.var('x0', shape=(10, 10)) + w00 = relay.var('w00', shape=(10, 10)) + w01 = relay.var('w01', shape=(10, 10)) + w02 = relay.var('w02', shape=(10, 10)) + z00 = relay.add(x0, w00) + p00 = relay.subtract(z00, w01) + q00 = relay.multiply(p00, w02) + func0 = relay.Function([x0, w00, w01, w02], q00) + func0 = func0.set_attribute("FuncName", tvm.expr.StringImm("a")) + + x1 = relay.var('x1', shape=(10, 10)) + w10 = relay.var('w10', shape=(10, 10)) + w11 = relay.var('w11', shape=(10, 10)) + w12 = relay.var('w12', shape=(10, 10)) + z10 = relay.add(x1, w10) + p10 = relay.subtract(z10, w11) + q10 = relay.multiply(p10, w12) + func1 = relay.Function([x1, w10, w11, w12], q10) + func1 = func1.set_attribute("FuncName", tvm.expr.StringImm("b")) + assert not alpha_equal(func0, func1) + + def test_function_alpha_equal(): tt1 = relay.TensorType((1, 2, 3), "float32") tt2 = relay.TensorType((4, 5, 6), "int8") @@ -639,6 +662,7 @@ def test_tuple_match(): test_tuple_alpha_equal() test_tuple_get_item_alpha_equal() test_function_alpha_equal() + test_function_attr() test_call_alpha_equal() test_let_alpha_equal() test_if_alpha_equal()