diff --git a/tests/integration_tests/test_zn_nodes2.py b/tests/integration_tests/test_zn_nodes2.py
index c4ffeae6..6be24b7f 100644
--- a/tests/integration_tests/test_zn_nodes2.py
+++ b/tests/integration_tests/test_zn_nodes2.py
@@ -1,5 +1,7 @@
import dataclasses
+import pytest
+import yaml
import znjson
from zntrack import zn
@@ -25,6 +27,24 @@ def run(self):
self.outs = self.params1.param1 + self.params2.param2
+class NodeWithPlots(Node):
+ _hash = zn.Hash()
+ plots = zn.plots()
+ factor: float = zn.params()
+
+ def run(self):
+ pass
+
+
+class ExampleUsesPlots(Node):
+ node_with_plots: NodeWithPlots = zn.Nodes()
+ param: int = zn.params()
+ out = zn.outs()
+
+ def run(self):
+ self.out = self.node_with_plots.factor * self.param
+
+
def test_ExampleNode(proj_path):
ExampleNode(
params1=NodeViaParams(param1="Hello", param2="World"),
@@ -152,3 +172,53 @@ def test_DataclassNode(proj_path):
assert node.node.parameter.value == 10
znjson.deregister(ParameterConverter)
+
+
+@pytest.mark.parametrize("node_name", ("ExampleUsesPlots", "test12"))
+def test_ExampleUsesPlots(proj_path, node_name):
+ node = ExampleUsesPlots(
+ node_with_plots=NodeWithPlots(factor=2.5), param=2.0, name=node_name
+ )
+ assert node.node_with_plots._is_attribute is True
+ assert node.node_with_plots.node_name == f"{node_name}-node_with_plots"
+ assert len(node.node_with_plots._descriptor_list) == 2
+
+ node.write_graph()
+ ExampleUsesPlots[node_name].run_and_save()
+
+ assert ExampleUsesPlots[node_name].out == 2.5 * 2.0
+
+ # Just checking if changing the parameters works as well
+ with open("params.yaml", "r") as file:
+ parameters = yaml.safe_load(file)
+ parameters[f"{node_name}-node_with_plots"]["factor"] = 1.0
+ with open("params.yaml", "a") as file:
+ yaml.safe_dump(parameters, file)
+
+ assert ExampleUsesPlots[node_name].node_with_plots.factor == 1.0
+
+
+class NodeAsDataClass(Node):
+ _hash = zn.Hash()
+ param1 = zn.params()
+ param2 = zn.params()
+ param3 = zn.params()
+
+
+class UseNodeAsDataClass(Node):
+ params: NodeAsDataClass = zn.Nodes()
+ output = zn.outs()
+
+ def run(self):
+ self.output = self.params.param1 + self.params.param2 + self.params.param3
+
+
+def test_UseNodeAsDataClass(proj_path):
+ node = UseNodeAsDataClass(params=NodeAsDataClass(param1=1, param2=10, param3=100))
+ node.write_graph(run=True)
+
+ node = UseNodeAsDataClass.load()
+ assert node.output == 111
+ assert node.params.param1 == 1
+ assert node.params.param2 == 10
+ assert node.params.param3 == 100
diff --git a/tests/unit_tests/core/test_core_base.py b/tests/unit_tests/core/test_core_base.py
index 9386828c..bd445300 100644
--- a/tests/unit_tests/core/test_core_base.py
+++ b/tests/unit_tests/core/test_core_base.py
@@ -5,7 +5,7 @@
import pytest
import yaml
-from zntrack import dvc, zn
+from zntrack import dvc, utils, zn
from zntrack.core.base import LoadViaGetItem, Node, update_dependency_options
@@ -31,6 +31,15 @@ def run(self):
self.zn_outs = "outs"
+class ExampleHashNode(Node):
+ hash = zn.Hash()
+ # None of these are tested, they should be ignored
+ params = zn.params(10)
+ zn_outs = zn.outs()
+ dvc_outs = dvc.outs("file.txt")
+ deps = dvc.deps("deps.inp")
+
+
@pytest.mark.parametrize("run", (True, False))
def test_save(run):
zntrack_mock = mock_open(read_data="{}")
@@ -56,8 +65,11 @@ def pathlib_open(*args, **kwargs):
assert zn_outs_mock().write.mock_calls == [
call(json.dumps({"zn_outs": "outs"}, indent=4))
]
+ assert not zntrack_mock().write.called
+ assert not params_mock().write.called
else:
example.save()
+ assert not zn_outs_mock().write.called
assert zntrack_mock().write.mock_calls == [
call(json.dumps({})), # clear everything first
call(
@@ -79,6 +91,38 @@ def pathlib_open(*args, **kwargs):
]
+def test_save_only_hash():
+ zntrack_mock = mock_open(read_data="{}")
+ params_mock = mock_open(read_data="{}")
+ zn_outs_mock = mock_open(read_data="{}")
+ hash_mock = mock_open(read_data="{}")
+
+ example = ExampleFullNode()
+
+ with pytest.raises(utils.exceptions.DescriptorMissing):
+ example.save(hash_only=True)
+
+ def pathlib_open(*args, **kwargs):
+ if args[0] == pathlib.Path("zntrack.json"):
+ return zntrack_mock(*args, **kwargs)
+ elif args[0] == pathlib.Path("params.yaml"):
+ return params_mock(*args, **kwargs)
+ elif args[0] == pathlib.Path("nodes/ExampleFullNode/outs.json"):
+ return zn_outs_mock(*args, **kwargs)
+ elif args[0] == pathlib.Path("nodes/ExampleHashNode/hash.json"):
+ return hash_mock(*args, **kwargs)
+ else:
+ raise ValueError(args)
+
+ example = ExampleHashNode()
+ with patch.object(pathlib.Path, "open", pathlib_open):
+ example.save(hash_only=True)
+ assert not params_mock().write.called
+ assert not zntrack_mock().write.called
+ assert not zn_outs_mock().write.called
+ assert hash_mock().write.called
+
+
def test__load():
zntrack_mock = mock_open(
read_data=json.dumps(
diff --git a/tests/unit_tests/core/test_dvcgraph.py b/tests/unit_tests/core/test_dvcgraph.py
index d6fc321e..350969f3 100644
--- a/tests/unit_tests/core/test_dvcgraph.py
+++ b/tests/unit_tests/core/test_dvcgraph.py
@@ -107,17 +107,35 @@ def test_affected_files():
class ExampleClassWithParams(Node):
- is_loaded = False
param1 = zn.params(default=1)
param2 = zn.params(default=2)
+class ExampleClassDifferentTypes(Node):
+ _is_attribute = True
+ _hash = zn.Hash()
+ param = zn.params(1)
+ outs = dvc.outs("file.txt")
+ metrics = zn.metrics()
+ plots = zn.plots()
+
+
def test__descriptor_list():
example = ExampleClassWithParams()
assert len(example._descriptor_list) == 2
+def test_descriptor_list_attr():
+ """test the descriptor list if _is_attribute=True"""
+ example = ExampleClassDifferentTypes()
+
+ assert len(example._descriptor_list) == 2
+
+ example._is_attribute = False
+ assert len(example._descriptor_list) == 5
+
+
def test_descriptor_list_filter():
example = ExampleClassWithParams()
diff --git a/zntrack/core/base.py b/zntrack/core/base.py
index 6880bcf2..db681511 100644
--- a/zntrack/core/base.py
+++ b/zntrack/core/base.py
@@ -229,7 +229,7 @@ def save_plots(self):
if option.zn_type is utils.ZnTypes.PLOTS:
option.save(instance=self)
- def save(self, results: bool = False):
+ def save(self, results: bool = False, hash_only: bool = False):
"""Save Class state to files
Parameters
@@ -239,12 +239,25 @@ def save(self, results: bool = False):
By default, this function saves e.g. parameters from zn.params / dvc.