Skip to content

Commit

Permalink
[Relay][Hashing] Structural hash - incorporate the var type into its …
Browse files Browse the repository at this point in the history
…hash (apache#3267)

Currently, the BindVar function does not take Var type into account. This causes
two same graph structures with different var shapes to have same hash.
Structural hash is used for keeping track of which operators we have
already compiled. Because of this, two operators with different shapes end up
pointing to same compiled code. The failure is encountered at runtime, where the
expected input shape asserts are not met.
  • Loading branch information
anijain2305 authored and Wei Chen committed Jun 26, 2019
1 parent bd4ead2 commit d9f4f5d
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 0 deletions.
3 changes: 3 additions & 0 deletions src/relay/ir/hash.cc
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,9 @@ class RelayHashHandler:
size_t BindVar(const NodeRef& var) {
size_t hash = std::hash<int>()(var_counter++);
CHECK_EQ(hash_map_.count(var), 0);
if (auto var_node = var.as<VarNode>()) {
hash = Combine(hash, TypeHash(var_node->type_annotation));
}
hash_map_[var] = hash;

const auto* ty_param = var.as<TypeVarNode>();
Expand Down
18 changes: 18 additions & 0 deletions tests/python/relay/test_pass_alpha_equal.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,7 +594,24 @@ def test_graph_equal():
# Check the difference in the text format.
assert not alpha_equal(z0, z3)

def test_hash_unequal():
x1 = relay.var("x1", shape=(10, 10), dtype="float32")
y1 = relay.var("y1", shape=(10, 10), dtype="float32")
func1 = relay.Function([x1, y1], relay.add(x1, y1))

# func2 is exactly same structure with same variables shapes and dtypes
x2 = relay.var("x2", shape=(10, 10), dtype="float32")
y2 = relay.var("y2", shape=(10, 10), dtype="float32")
func2 = relay.Function([x2, y2], relay.add(x2, y2))

assert ir_pass.structural_hash(func1) == ir_pass.structural_hash(func2)

# func3 is same as func1 but with different var shapes
x3 = relay.var("x3", shape=(20, 10), dtype="float32")
y3 = relay.var("y3", shape=(20, 10), dtype="float32")
func3 = relay.Function([x3, y3], relay.add(x3, y3))

assert not ir_pass.structural_hash(func1) == ir_pass.structural_hash(func3)

if __name__ == "__main__":
test_tensor_type_alpha_equal()
Expand All @@ -617,3 +634,4 @@ def test_graph_equal():
test_op_alpha_equal()
test_var_alpha_equal()
test_graph_equal()
test_hash_unequal()

0 comments on commit d9f4f5d

Please sign in to comment.