From bbb2691f452c877d942f7670a1093e09a8f9ed78 Mon Sep 17 00:00:00 2001
From: Fabian Zills <46721498+PythonFZ@users.noreply.github.com>
Date: Mon, 2 May 2022 15:41:15 +0200
Subject: [PATCH] add `__hash__()` based on the parameters and node_name (#289)

* add `__hash__()` based on the parameters and node_name

* add `zn.Hash`
---
 .../test_node_to_node_dependencies.py        | 32 +++++++++++++
 tests/unit_tests/core/test_dvcgraph.py       | 22 +++++++++
 tests/unit_tests/zn/test_zn_hash.py          | 30 ++++++++++++
 zntrack/core/base.py                         |  3 +-
 zntrack/core/dvcgraph.py                     | 14 +++++-
 zntrack/core/zntrackoption.py                |  3 ++
 zntrack/zn/__init__.py                       |  3 +-
 zntrack/zn/zn_hash.py                        | 48 +++++++++++++++++++
 8 files changed, 151 insertions(+), 4 deletions(-)
 create mode 100644 tests/unit_tests/zn/test_zn_hash.py
 create mode 100644 zntrack/zn/zn_hash.py

diff --git a/tests/integration_tests/test_node_to_node_dependencies.py b/tests/integration_tests/test_node_to_node_dependencies.py
index 9ae2c2dc..2aa1f49d 100644
--- a/tests/integration_tests/test_node_to_node_dependencies.py
+++ b/tests/integration_tests/test_node_to_node_dependencies.py
@@ -339,3 +339,35 @@ def test_stacked_name_getdeps(proj_path, steps):
 
     for step in range(steps):
         assert AddOne[f"add_{step}"].output == step + 1
+
+
+class ParameterNodeWithHash(Node):
+    param1 = zn.params()
+    param2 = zn.params()
+
+    _hash = zn.Hash()
+
+    def run(self):
+        pass
+
+
+class ParamDeps(Node):
+    deps: ParameterNodeWithHash = zn.deps(ParameterNodeWithHash)
+    outs = zn.outs()
+
+    def run(self):
+        self.outs = self.deps.param1 + self.deps.param2
+
+
+def test_ParameterNodeWithHash(proj_path):
+    ParameterNodeWithHash(param1=2, param2=40).write_graph()
+    ParamDeps().write_graph()
+
+    subprocess.check_call(["dvc", "repro"])
+    assert ParamDeps.load().outs == 42
+
+    # Change parameters and thereby check if the dependencies are executed
+    ParameterNodeWithHash(param1=10, param2=7).save()
+    subprocess.check_call(["dvc", "dag"])
+    subprocess.check_call(["dvc", "repro"])
+    assert ParamDeps.load().outs == 17
diff --git a/tests/unit_tests/core/test_dvcgraph.py b/tests/unit_tests/core/test_dvcgraph.py
index e9298ec5..cdefd030 100644
--- a/tests/unit_tests/core/test_dvcgraph.py
+++ b/tests/unit_tests/core/test_dvcgraph.py
@@ -210,3 +210,25 @@ def test_ZnTrackInfo_collect():
     example = ExampleClassWithParams()
 
     assert example.zntrack.collect(zn.params) == {"param1": 1, "param2": 2}
+
+
+@pytest.mark.parametrize(
+    ("param1", "param2"),
+    ((5, 10), ("a", "b"), ([1, 2, 3], [1, 2, 3]), ({"a": [1, 2, 3]}, {"b": [4, 5, 6]})),
+)
+def test_params_hash(param1, param2):
+    self_1 = ExampleClassWithParams()
+    self_1.param1 = param1
+    self_1.param2 = param2
+
+    self_2 = ExampleClassWithParams()
+    self_2.param1 = param1
+    self_2.param2 = param2
+
+    assert hash(self_1) == hash(self_2)
+
+    self_3 = ExampleClassWithParams(name="CustomCls")
+    self_3.param1 = param1
+    self_3.param2 = param2
+
+    assert hash(self_1) != hash(self_3)
diff --git a/tests/unit_tests/zn/test_zn_hash.py b/tests/unit_tests/zn/test_zn_hash.py
new file mode 100644
index 00000000..c9d88f5e
--- /dev/null
+++ b/tests/unit_tests/zn/test_zn_hash.py
@@ -0,0 +1,30 @@
+import pytest
+
+from zntrack import zn
+
+
+class ExampleNodeWithTime:
+    hash = zn.Hash()
+
+    def __hash__(self):
+        return 1234
+
+
+def test_hash():
+    assert ExampleNodeWithTime().hash != ExampleNodeWithTime().hash
+
+    with pytest.raises(ValueError):
+        example = ExampleNodeWithTime()
+        example.hash = 1234
+
+
+class ExampleNode:
+    hash = zn.Hash(use_time=False)
+
+    def __hash__(self):
+        return 1234
+
+
+def test_constant_hash():
+    assert ExampleNode().hash == 1234
+    assert ExampleNode().hash == ExampleNode().hash
diff --git a/zntrack/core/base.py b/zntrack/core/base.py
index ee638761..69b34a28 100644
--- a/zntrack/core/base.py
+++ b/zntrack/core/base.py
@@ -222,7 +222,8 @@ def update_options(self, lazy=None):
         if lazy is None:
             lazy = utils.config.lazy
         for option in self._descriptor_list:
-            self.__dict__[option.name] = utils.LazyOption
+            if option.allow_lazy:
+                self.__dict__[option.name] = utils.LazyOption
             if not lazy:
                 # trigger loading the data into memory
                 value = getattr(self, option.name)
diff --git a/zntrack/core/dvcgraph.py b/zntrack/core/dvcgraph.py
index a2602b21..243fddbc 100644
--- a/zntrack/core/dvcgraph.py
+++ b/zntrack/core/dvcgraph.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import dataclasses
+import json
 import logging
 import pathlib
 import typing
@@ -8,6 +9,8 @@
 from zntrack import descriptor, utils
 from zntrack.core.jupyter import jupyter_class_to_file
 from zntrack.core.zntrackoption import ZnTrackOption
+from zntrack.descriptor import BaseDescriptorType
+from zntrack.zn import params as zntrack_params
 from zntrack.zn.dependencies import NodeAttribute
 
 log = logging.getLogger(__name__)
@@ -211,7 +214,7 @@ class ZnTrackInfo:
     def __init__(self, parent):
         self._parent = parent
 
-    def collect(self, zntrackoption: descriptor.BaseDescriptorType) -> dict:
+    def collect(self, zntrackoption: typing.Type[descriptor.BaseDescriptorType]) -> dict:
         """Collect the values of all ZnTrackOptions of the passed type
 
         Parameters
@@ -259,8 +262,15 @@ def __init__(self, **kwargs):
         if len(kwargs) > 0:
             raise TypeError(f"'{kwargs}' are an invalid keyword argument")
 
+    def __hash__(self):
+        """Hash of parameters and node_name; salted per process (see PYTHONHASHSEED)"""
+        params_dict = self.zntrack.collect(zntrack_params)
+        params_dict["node_name"] = self.node_name
+
+        return hash(json.dumps(params_dict, sort_keys=True))
+
     @property
-    def _descriptor_list(self) -> typing.List[ZnTrackOption]:
+    def _descriptor_list(self) -> typing.List[BaseDescriptorType]:
         """Get all descriptors of this instance"""
         return descriptor.get_descriptors(ZnTrackOption, self=self)
 
diff --git a/zntrack/core/zntrackoption.py b/zntrack/core/zntrackoption.py
index 5d823548..777bbd3f 100644
--- a/zntrack/core/zntrackoption.py
+++ b/zntrack/core/zntrackoption.py
@@ -52,11 +52,14 @@ class ZnTrackOption(descriptor.Descriptor):
         The cmd to use with DVC, e.g. dvc --outs ... would be "outs"
     zn_type: utils.ZnTypes
         The internal ZnType to select the correct ZnTrack behaviour
+    allow_lazy: bool, default=True
+        Allow this option to be lazy loaded.
     """
 
     file = None
     dvc_option: str = None
     zn_type: utils.ZnTypes = None
+    allow_lazy: bool = True
 
     def __init__(self, default_value=None, **kwargs):
         """Constructor for ZnTrackOptions
diff --git a/zntrack/zn/__init__.py b/zntrack/zn/__init__.py
index 19813f9a..dd645c6a 100644
--- a/zntrack/zn/__init__.py
+++ b/zntrack/zn/__init__.py
@@ -11,10 +11,11 @@
 from zntrack.core.zntrackoption import ZnTrackOption
 from zntrack.zn.method import Method
 from zntrack.zn.split_option import SplitZnTrackOption
+from zntrack.zn.zn_hash import Hash
 
 log = logging.getLogger(__name__)
 
-__all__ = [Method.__name__]
+__all__ = [Method.__name__, Hash.__name__]
 
 try:
     from .plots import plots
diff --git a/zntrack/zn/zn_hash.py b/zntrack/zn/zn_hash.py
new file mode 100644
index 00000000..71bd15d5
--- /dev/null
+++ b/zntrack/zn/zn_hash.py
@@ -0,0 +1,48 @@
+import datetime
+
+from zntrack import utils
+from zntrack.core.zntrackoption import ZnTrackOption
+
+
+class Hash(ZnTrackOption):
+    """Special ZnTrack outs
+
+    This 'zn.Hash' can be useful if you are dealing with a Node that typically has no
+    outputs but is used e.g. for storing parameters. Because other Nodes don't use the
+    parameters of this Node but rather the outputs of this Node as a dependency,
+    it is important that its value changes whenever parameters or dependencies change.
+
+    TODO consider passing the parameters of such a Node to the dependent Node instead
+    of using this trick.
+    """
+
+    zn_type = utils.ZnTypes.RESULTS
+    dvc_option = utils.DVCOptions.OUTS_NO_CACHE.value
+    allow_lazy: bool = False
+
+    def __init__(self, *, use_time: bool = True, **kwargs):
+        """
+
+        Parameters
+        ----------
+        use_time: bool, default = True
+            Add the hash of datetime.now() to provide extra salt for the hash value
+            to change independently of the given parameters. This is the default,
+            because rerunning the Node typically is associated with some changed
+            dependencies which are not accounted for in the parameters.
+        kwargs
+        """
+        super().__init__(filename="hash", **kwargs)
+        self.use_time = use_time
+
+    def __get__(self, instance, owner=None) -> int:
+        """Return the hash of the instance, salted with the current time if use_time"""
+        if instance is None:
+            return self
+        if self.use_time:
+            return hash(instance) + hash(datetime.datetime.now())
+        return hash(instance)
+
+    def __set__(self, instance, value):
+        """Don't allow to set the value"""
+        raise ValueError("Can not set value of zn.Hash")