From 1803cb50bc7b81f730e60f2564a603c49cbe8180 Mon Sep 17 00:00:00 2001 From: Fabian Zills <46721498+PythonFZ@users.noreply.github.com> Date: Thu, 22 Sep 2022 16:01:11 +0200 Subject: [PATCH] fix `zn.Nodes` issue with `zn.plots` (#338) * wrap in try/except * fis issue with `zn.Nodes` by introducing `hash_only` keyword * add and fix tests, add documentation --- tests/integration_tests/test_zn_nodes2.py | 70 +++++++++++++++++++++++ tests/unit_tests/core/test_core_base.py | 46 ++++++++++++++- tests/unit_tests/core/test_dvcgraph.py | 20 ++++++- zntrack/core/base.py | 23 ++++++-- zntrack/zn/nodes.py | 22 +++---- zntrack/zn/plots.py | 1 - 6 files changed, 158 insertions(+), 24 deletions(-) diff --git a/tests/integration_tests/test_zn_nodes2.py b/tests/integration_tests/test_zn_nodes2.py index c4ffeae6..6be24b7f 100644 --- a/tests/integration_tests/test_zn_nodes2.py +++ b/tests/integration_tests/test_zn_nodes2.py @@ -1,5 +1,7 @@ import dataclasses +import pytest +import yaml import znjson from zntrack import zn @@ -25,6 +27,24 @@ def run(self): self.outs = self.params1.param1 + self.params2.param2 +class NodeWithPlots(Node): + _hash = zn.Hash() + plots = zn.plots() + factor: float = zn.params() + + def run(self): + pass + + +class ExampleUsesPlots(Node): + node_with_plots: NodeWithPlots = zn.Nodes() + param: int = zn.params() + out = zn.outs() + + def run(self): + self.out = self.node_with_plots.factor * self.param + + def test_ExampleNode(proj_path): ExampleNode( params1=NodeViaParams(param1="Hello", param2="World"), @@ -152,3 +172,53 @@ def test_DataclassNode(proj_path): assert node.node.parameter.value == 10 znjson.deregister(ParameterConverter) + + +@pytest.mark.parametrize("node_name", ("ExampleUsesPlots", "test12")) +def test_ExampleUsesPlots(proj_path, node_name): + node = ExampleUsesPlots( + node_with_plots=NodeWithPlots(factor=2.5), param=2.0, name=node_name + ) + assert node.node_with_plots._is_attribute is True + assert node.node_with_plots.node_name == f"{node_name}-node_with_plots" + assert len(node.node_with_plots._descriptor_list) == 2 + + node.write_graph() + ExampleUsesPlots[node_name].run_and_save() + + assert ExampleUsesPlots[node_name].out == 2.5 * 2.0 + + # Just checking if changing the parameters works as well + with open("params.yaml", "r") as file: + parameters = yaml.safe_load(file) + parameters[f"{node_name}-node_with_plots"]["factor"] = 1.0 + with open("params.yaml", "a") as file: + yaml.safe_dump(parameters, file) + + assert ExampleUsesPlots[node_name].node_with_plots.factor == 1.0 + + +class NodeAsDataClass(Node): + _hash = zn.Hash() + param1 = zn.params() + param2 = zn.params() + param3 = zn.params() + + +class UseNodeAsDataClass(Node): + params: NodeAsDataClass = zn.Nodes() + output = zn.outs() + + def run(self): + self.output = self.params.param1 + self.params.param2 + self.params.param3 + + +def test_UseNodeAsDataClass(proj_path): + node = UseNodeAsDataClass(params=NodeAsDataClass(param1=1, param2=10, param3=100)) + node.write_graph(run=True) + + node = UseNodeAsDataClass.load() + assert node.output == 111 + assert node.params.param1 == 1 + assert node.params.param2 == 10 + assert node.params.param3 == 100 diff --git a/tests/unit_tests/core/test_core_base.py b/tests/unit_tests/core/test_core_base.py index 9386828c..bd445300 100644 --- a/tests/unit_tests/core/test_core_base.py +++ b/tests/unit_tests/core/test_core_base.py @@ -5,7 +5,7 @@ import pytest import yaml -from zntrack import dvc, zn +from zntrack import dvc, utils, zn from zntrack.core.base import LoadViaGetItem, Node, update_dependency_options @@ -31,6 +31,15 @@ def run(self): self.zn_outs = "outs" +class ExampleHashNode(Node): + hash = zn.Hash() + # None of these are tested, they should be ignored + params = zn.params(10) + zn_outs = zn.outs() + dvc_outs = dvc.outs("file.txt") + deps = dvc.deps("deps.inp") + + @pytest.mark.parametrize("run", (True, False)) def test_save(run): zntrack_mock = mock_open(read_data="{}") @@ -56,8 +65,11 @@ def pathlib_open(*args, **kwargs): assert zn_outs_mock().write.mock_calls == [ call(json.dumps({"zn_outs": "outs"}, indent=4)) ] + assert not zntrack_mock().write.called + assert not params_mock().write.called else: example.save() + assert not zn_outs_mock().write.called assert zntrack_mock().write.mock_calls == [ call(json.dumps({})), # clear everything first call( @@ -79,6 +91,38 @@ def pathlib_open(*args, **kwargs): ] +def test_save_only_hash(): + zntrack_mock = mock_open(read_data="{}") + params_mock = mock_open(read_data="{}") + zn_outs_mock = mock_open(read_data="{}") + hash_mock = mock_open(read_data="{}") + + example = ExampleFullNode() + + with pytest.raises(utils.exceptions.DescriptorMissing): + example.save(hash_only=True) + + def pathlib_open(*args, **kwargs): + if args[0] == pathlib.Path("zntrack.json"): + return zntrack_mock(*args, **kwargs) + elif args[0] == pathlib.Path("params.yaml"): + return params_mock(*args, **kwargs) + elif args[0] == pathlib.Path("nodes/ExampleFullNode/outs.json"): + return zn_outs_mock(*args, **kwargs) + elif args[0] == pathlib.Path("nodes/ExampleHashNode/hash.json"): + return hash_mock(*args, **kwargs) + else: + raise ValueError(args) + + example = ExampleHashNode() + with patch.object(pathlib.Path, "open", pathlib_open): + example.save(hash_only=True) + assert not params_mock().write.called + assert not zntrack_mock().write.called + assert not zn_outs_mock().write.called + assert hash_mock().write.called + + def test__load(): zntrack_mock = mock_open( read_data=json.dumps( diff --git a/tests/unit_tests/core/test_dvcgraph.py b/tests/unit_tests/core/test_dvcgraph.py index d6fc321e..350969f3 100644 --- a/tests/unit_tests/core/test_dvcgraph.py +++ b/tests/unit_tests/core/test_dvcgraph.py @@ -107,17 +107,35 @@ def test_affected_files(): class ExampleClassWithParams(Node): - is_loaded = False param1 = zn.params(default=1) param2 = zn.params(default=2) +class ExampleClassDifferentTypes(Node): + _is_attribute = True + _hash = zn.Hash() + param = zn.params(1) + outs = dvc.outs("file.txt") + metrics = zn.metrics() + plots = zn.plots() + + def test__descriptor_list(): example = ExampleClassWithParams() assert len(example._descriptor_list) == 2 +def test_descriptor_list_attr(): + """test the descriptor list if _is_attribute=True""" + example = ExampleClassDifferentTypes() + + assert len(example._descriptor_list) == 2 + + example._is_attribute = False + assert len(example._descriptor_list) == 5 + + def test_descriptor_list_filter(): example = ExampleClassWithParams() diff --git a/zntrack/core/base.py b/zntrack/core/base.py index 6880bcf2..db681511 100644 --- a/zntrack/core/base.py +++ b/zntrack/core/base.py @@ -229,7 +229,7 @@ def save_plots(self): if option.zn_type is utils.ZnTypes.PLOTS: option.save(instance=self) - def save(self, results: bool = False): + def save(self, results: bool = False, hash_only: bool = False): """Save Class state to files Parameters @@ -239,12 +239,25 @@ def save(self, results: bool = False): By default, this function saves e.g. parameters from zn.params / dvc.