Skip to content

Commit

Permalink
add __hash__() based on the parameters and node_name (#289)
Browse files Browse the repository at this point in the history
* add `__hash__()` based on the parameters and node_name

* add `zn.Hash`
  • Loading branch information
PythonFZ authored May 2, 2022
1 parent 270e8d3 commit bbb2691
Show file tree
Hide file tree
Showing 8 changed files with 151 additions and 4 deletions.
32 changes: 32 additions & 0 deletions tests/integration_tests/test_node_to_node_dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,3 +339,35 @@ def test_stacked_name_getdeps(proj_path, steps):

for step in range(steps):
assert AddOne[f"add_{step}"].output == step + 1


class ParameterNodeWithHash(Node):
    """Parameter-only Node whose ``run`` does nothing.

    Besides the two parameters it only declares a ``zn.Hash`` option.
    """

    param1 = zn.params()
    param2 = zn.params()

    # the only non-parameter option of this Node
    _hash = zn.Hash()

    def run(self):
        """Nothing to compute, this Node only carries parameters."""


class ParamDeps(Node):
    """Node that depends on ``ParameterNodeWithHash`` and sums its parameters."""

    deps: ParameterNodeWithHash = zn.deps(ParameterNodeWithHash)
    outs = zn.outs()

    def run(self):
        """Store ``param1 + param2`` of the dependency in ``outs``."""
        first = self.deps.param1
        second = self.deps.param2
        self.outs = first + second


def test_ParameterNodeWithHash(proj_path):
    """Saving new parameters on the dependency must change the ``ParamDeps`` result."""
    initial_params = ParameterNodeWithHash(param1=2, param2=40)
    initial_params.write_graph()
    dependent = ParamDeps()
    dependent.write_graph()

    subprocess.check_call(["dvc", "repro"])
    assert ParamDeps.load().outs == 42

    # Only save new parameter values and check that the dependent Node picks
    # them up on the next 'dvc repro'.
    updated_params = ParameterNodeWithHash(param1=10, param2=7)
    updated_params.save()
    for dvc_command in ("dag", "repro"):
        subprocess.check_call(["dvc", dvc_command])
    assert ParamDeps.load().outs == 17
22 changes: 22 additions & 0 deletions tests/unit_tests/core/test_dvcgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,3 +210,25 @@ def test_ZnTrackInfo_collect():
example = ExampleClassWithParams()

assert example.zntrack.collect(zn.params) == {"param1": 1, "param2": 2}


@pytest.mark.parametrize(
    ("param1", "param2"),
    (
        (5, 10),
        ("a", "b"),
        ([1, 2, 3], [1, 2, 3]),
        # Every case has to be a (param1, param2) pair. A bare dict in
        # parentheses is not a tuple, so no dict value would be tested.
        ({"a": [1, 2, 3], "b": [4, 5, 6]}, {"a": [1, 2, 3], "b": [4, 5, 6]}),
    ),
)
def test_params_hash(param1, param2):
    """Check ``hash()`` of Nodes with equal parameters.

    Two instances with the same parameters and the default name must have the
    same hash; an instance with the same parameters but a custom ``name`` must
    have a different one.
    """
    self_1 = ExampleClassWithParams()
    self_1.param1 = param1
    self_1.param2 = param2

    self_2 = ExampleClassWithParams()
    self_2.param1 = param1
    self_2.param2 = param2

    assert hash(self_1) == hash(self_2)

    # same parameters, different node name
    self_3 = ExampleClassWithParams(name="CustomCls")
    self_3.param1 = param1
    self_3.param2 = param2

    assert hash(self_1) != hash(self_3)
30 changes: 30 additions & 0 deletions tests/unit_tests/zn/test_zn_hash.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import pytest

from zntrack import zn


class ExampleNodeWithTime:
    """Plain class with a constant ``__hash__`` and a default ``zn.Hash`` option."""

    hash = zn.Hash()

    def __hash__(self):
        """Return the same value for every instance."""
        return 1234


def test_hash():
    """The default ``zn.Hash`` differs between two reads and can not be assigned."""
    assert ExampleNodeWithTime().hash != ExampleNodeWithTime().hash

    # Create the instance outside of the 'raises' context, so that only the
    # assignment itself can satisfy the expected ValueError.
    example = ExampleNodeWithTime()
    with pytest.raises(ValueError):
        example.hash = 1234


class ExampleNode:
    """Plain class with a constant ``__hash__`` and a ``zn.Hash`` without time salt."""

    hash = zn.Hash(use_time=False)

    def __hash__(self):
        """Return the same value for every instance."""
        return 1234


def test_constant_hash():
    """With ``use_time=False`` the option returns ``hash(instance)`` unchanged."""
    first, second = ExampleNode(), ExampleNode()
    assert first.hash == 1234
    assert first.hash == second.hash
3 changes: 2 additions & 1 deletion zntrack/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,8 @@ def update_options(self, lazy=None):
if lazy is None:
lazy = utils.config.lazy
for option in self._descriptor_list:
self.__dict__[option.name] = utils.LazyOption
if option.allow_lazy:
self.__dict__[option.name] = utils.LazyOption
if not lazy:
# trigger loading the data into memory
value = getattr(self, option.name)
Expand Down
14 changes: 12 additions & 2 deletions zntrack/core/dvcgraph.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
from __future__ import annotations

import dataclasses
import json
import logging
import pathlib
import typing

from zntrack import descriptor, utils
from zntrack.core.jupyter import jupyter_class_to_file
from zntrack.core.zntrackoption import ZnTrackOption
from zntrack.descriptor import BaseDescriptorType
from zntrack.zn import params as zntrack_params
from zntrack.zn.dependencies import NodeAttribute

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -211,7 +214,7 @@ class ZnTrackInfo:
def __init__(self, parent):
    """Keep a reference to the passed object.

    Parameters
    ----------
    parent
        Stored unchanged as ``self._parent``.
    """
    self._parent = parent

def collect(self, zntrackoption: descriptor.BaseDescriptorType) -> dict:
def collect(self, zntrackoption: typing.Type[descriptor.BaseDescriptorType]) -> dict:
"""Collect the values of all ZnTrackOptions of the passed type
Parameters
Expand Down Expand Up @@ -259,8 +262,15 @@ def __init__(self, **kwargs):
if len(kwargs) > 0:
raise TypeError(f"'{kwargs}' are an invalid keyword argument")

def __hash__(self):
    """Compute the hash based on the parameters and node_name.

    The collected ``zn.params`` values and the ``node_name`` are serialized
    to JSON with sorted keys. The JSON string is digested with SHA-256
    instead of the builtin ``hash``: the builtin hash of a ``str`` is salted
    per interpreter process, so it would differ between two runs with
    identical parameters.

    Returns
    -------
    int
        Integer built from the first 8 bytes of the digest.

    Raises
    ------
    TypeError
        If a collected value can not be serialized by ``json.dumps``.
    """
    import hashlib  # local import, only needed for this method

    # copy, so that the dict returned by 'collect' is not modified
    params_dict = dict(self.zntrack.collect(zntrack_params))
    params_dict["node_name"] = self.node_name

    serialized = json.dumps(params_dict, sort_keys=True)
    digest = hashlib.sha256(serialized.encode("utf-8")).digest()
    return int.from_bytes(digest[:8], "big")

@property
def _descriptor_list(self) -> typing.List[BaseDescriptorType]:
    """Get all descriptors of this instance

    Returns
    -------
    list
        The result of ``descriptor.get_descriptors`` for ``ZnTrackOption``
        and this instance.
    """
    # Only one signature is kept: a 'def' line followed directly by another
    # 'def' line has no body and is not valid Python.
    return descriptor.get_descriptors(ZnTrackOption, self=self)

Expand Down
3 changes: 3 additions & 0 deletions zntrack/core/zntrackoption.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,14 @@ class ZnTrackOption(descriptor.Descriptor):
The cmd to use with DVC, e.g. dvc --outs ... would be "outs"
zn_type: utils.ZnTypes
The internal ZnType to select the correct ZnTrack behaviour
allow_lazy: bool, default=True
Allow this option to be lazy loaded.
"""

file = None
dvc_option: str = None
zn_type: utils.ZnTypes = None
allow_lazy: bool = True

def __init__(self, default_value=None, **kwargs):
"""Constructor for ZnTrackOptions
Expand Down
3 changes: 2 additions & 1 deletion zntrack/zn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@
from zntrack.core.zntrackoption import ZnTrackOption
from zntrack.zn.method import Method
from zntrack.zn.split_option import SplitZnTrackOption
from zntrack.zn.zn_hash import Hash

log = logging.getLogger(__name__)

__all__ = [Method.__name__]
__all__ = [Method.__name__, Hash.__name__]

try:
from .plots import plots
Expand Down
48 changes: 48 additions & 0 deletions zntrack/zn/zn_hash.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import datetime

from zntrack import utils
from zntrack.core.zntrackoption import ZnTrackOption


class Hash(ZnTrackOption):
    """Special ZnTrack outs

    This 'zn.Hash' can be useful if you are dealing with a Node that typically
    has no outputs but is used e.g. for storing parameters. Other Nodes don't
    depend on the parameters of such a Node but on its outputs, so it is
    important that the value of this option changes whenever parameters or
    dependencies change.

    TODO consider passing the parameters of such a Node to the dependent Node
     instead of using this trick.
    """

    zn_type = utils.ZnTypes.RESULTS
    dvc_option = utils.DVCOptions.OUTS_NO_CACHE.value
    # this option must not be lazy loaded
    allow_lazy: bool = False

    def __init__(self, *, use_time: bool = True, **kwargs):
        """Create the option with the fixed filename "hash".

        Parameters
        ----------
        use_time: bool, default = True
            Add the hash of datetime.now() to provide extra salt for the hash
            value to change independently of the given parameters. This is the
            default, because rerunning the Node typically is associated with
            some changed dependencies which are not accounted for in the
            parameters.
        kwargs
            Passed on to the parent constructor.
        """
        super().__init__(filename="hash", **kwargs)
        self.use_time = use_time

    def __get__(self, instance, owner=None) -> int:
        """Return ``hash(instance)``, salted with the current time if requested.

        Accessed on the class (``instance is None``) the descriptor itself is
        returned.
        """
        if instance is None:
            return self
        value = hash(instance)
        if self.use_time:
            value += hash(datetime.datetime.now())
        return value

    def __set__(self, instance, value):
        """Don't allow to set the value

        Raises
        ------
        ValueError
            Always.
        """
        raise ValueError("Can not set value of zn.Hash")

0 comments on commit bbb2691

Please sign in to comment.