Skip to content

Commit

Permalink
add get_hash
Browse files Browse the repository at this point in the history
  • Loading branch information
matthiasdiener committed Sep 18, 2023
1 parent 5274e91 commit df96738
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 0 deletions.
34 changes: 34 additions & 0 deletions pytato/analysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@
.. autofunction:: get_num_call_sites
.. autofunction:: get_hash
.. autoclass:: DirectPredecessorsGetter
"""

Expand Down Expand Up @@ -453,4 +455,36 @@ def get_num_call_sites(outputs: Union[Array, DictOfNamedArrays]) -> int:

# }}}


# {{{ get_hash

class HashMapper(CachedWalkMapper):
"""
A mapper that generates a hash for a given DAG.
"""
def __init__(self) -> None:
super().__init__()
import hashlib
self.hash = hashlib.sha256()

def get_cache_key(self, expr: ArrayOrNames, *args: Any, **kwargs: Any) -> Any:
return expr

def post_visit(self, expr: ArrayOrNames, *args: Any, **kwargs: Any) -> None:
self.hash.update(str(hash(expr)).encode("ascii"))


def get_hash(outputs: Union[Array, DictOfNamedArrays]) -> str:
"""Returns a hash of the DAG *outputs*."""

from pytato.codegen import normalize_outputs
outputs = normalize_outputs(outputs)

hm = HashMapper()
hm(outputs)

return hm.hash.hexdigest()

# }}}

# vim: fdm=marker
28 changes: 28 additions & 0 deletions test/test_pytato.py
Original file line number Diff line number Diff line change
Expand Up @@ -1116,6 +1116,34 @@ def test_dot_visualizers():
# }}}


def test_get_hash():
from pytato.analysis import get_hash

axis_len = 5

seen_hashes = set()

for i in range(100):
rdagc1 = RandomDAGContext(np.random.default_rng(seed=i),
axis_len=axis_len, use_numpy=False)
rdagc2 = RandomDAGContext(np.random.default_rng(seed=i),
axis_len=axis_len, use_numpy=False)
rdagc3 = RandomDAGContext(np.random.default_rng(seed=i),
axis_len=axis_len, use_numpy=False)

dag1 = make_random_dag(rdagc1)
dag2 = make_random_dag(rdagc2)
dag3 = make_random_dag(rdagc3)

h1 = get_hash(dag1)
h2 = get_hash(dag2)
h3 = get_hash(dag3)

assert h1 == h2 == h3
assert h1 not in seen_hashes
seen_hashes.add(h1)


if __name__ == "__main__":
if len(sys.argv) > 1:
exec(sys.argv[1])
Expand Down

0 comments on commit df96738

Please sign in to comment.