Skip to content

Commit

Permalink
[Fixbug] Fix graph metadata hash (#428)
Browse files Browse the repository at this point in the history
Same graphs should have the same hash. Graph node traversal used in
generating hash is non-deterministic, leading to different hashes for
the same graph.

This can prevent graph runs from using the fast path, since dispatch
tables cannot be found for the current (different) hash!
  • Loading branch information
KTong821 authored Mar 4, 2024
1 parent 06cf139 commit b7c9026
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 8 deletions.
11 changes: 6 additions & 5 deletions python/hidet/drivers/build_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Set, Dict
from typing import List, Dict
import os
import json
import shutil
Expand All @@ -34,15 +34,15 @@ def get_graph_weights(graph):
Get the weights of the graph. All constant tensors used by the operators in the graph, or returned directly by the
graph, are considered as weights.
"""
weights: Set[Tensor] = set()
weights: List[Tensor] = []
for node in graph.nodes:
for x in node.inputs:
if x.storage is not None:
weights.add(x)
weights.append(x)
for y in graph.outputs:
if y.storage is not None:
weights.add(y)
return list(weights)
weights.append(y)
return weights


def get_graph_intermediates(graph):
Expand Down Expand Up @@ -145,6 +145,7 @@ def get_graph_meta_data(graph: FlowGraph, num_kernels, space: int) -> GraphMetaD
lines.append(str(node.task))
lines.append(str(graph))
lines.append(str(space))

graph_hash = sha256('\n'.join(lines).encode('utf-8')).hexdigest()[:16]

return GraphMetaData(
Expand Down
8 changes: 5 additions & 3 deletions python/hidet/graph/impl/graph_impl.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Tuple, Dict, Set, Optional, Union
from typing import List, Tuple, Dict, Optional, Union
from collections import defaultdict
import hidet.option
from hidet.graph.tensor import Tensor
Expand Down Expand Up @@ -39,10 +39,12 @@ def graph_analyze(
stop_tensors: List[Tensor] = stop_tensors or []

# find out all nodes
all_nodes: Set[Operator] = set()
# use dict for ordered set behaviour
# ordering needed for deterministic node ordering
all_nodes: Dict[Operator, bool] = {}

def find_all_nodes(u: Operator):
all_nodes.add(u)
all_nodes[u] = True
for x in u.inputs:
if x.op is None or x in stop_tensors:
continue
Expand Down

0 comments on commit b7c9026

Please sign in to comment.