Skip to content

Commit

Permalink
fix zn.Nodes issue with zn.plots (#338)
Browse files Browse the repository at this point in the history
* wrap in try/except

* fix issue with `zn.Nodes` by introducing `hash_only` keyword

* add and fix tests, add documentation
  • Loading branch information
PythonFZ authored Sep 22, 2022
1 parent ef90816 commit 1803cb5
Show file tree
Hide file tree
Showing 6 changed files with 158 additions and 24 deletions.
70 changes: 70 additions & 0 deletions tests/integration_tests/test_zn_nodes2.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import dataclasses

import pytest
import yaml
import znjson

from zntrack import zn
Expand All @@ -25,6 +27,24 @@ def run(self):
self.outs = self.params1.param1 + self.params2.param2


class NodeWithPlots(Node):
    """Node declaring a hash, a ``zn.plots`` output and one ``factor`` parameter.

    It is used as the ``zn.Nodes`` attribute of ``ExampleUsesPlots``.
    """

    _hash = zn.Hash()
    plots = zn.plots()
    factor: float = zn.params()

    def run(self):
        """Intentionally a no-op: nothing is assigned to ``plots`` here."""


class ExampleUsesPlots(Node):
    """Node that scales its own ``param`` by the ``factor`` of a nested node."""

    node_with_plots: NodeWithPlots = zn.Nodes()
    param: int = zn.params()
    out = zn.outs()

    def run(self):
        """Store ``node_with_plots.factor * param`` in ``out``."""
        factor = self.node_with_plots.factor
        self.out = factor * self.param


def test_ExampleNode(proj_path):
ExampleNode(
params1=NodeViaParams(param1="Hello", param2="World"),
Expand Down Expand Up @@ -152,3 +172,53 @@ def test_DataclassNode(proj_path):
assert node.node.parameter.value == 10

znjson.deregister(ParameterConverter)


@pytest.mark.parametrize("node_name", ("ExampleUsesPlots", "test12"))
def test_ExampleUsesPlots(proj_path, node_name):
    """Check a ``zn.Nodes`` attribute whose Node also declares ``zn.plots``.

    The nested node must be flagged as an attribute, be named
    ``<node_name>-node_with_plots`` and expose exactly two descriptors.
    After running the graph, the output is verified, and finally the nested
    parameter is changed in ``params.yaml`` and read back.

    Parameters
    ----------
    proj_path:
        pytest fixture; not used directly in the body (presumably provides
        an isolated project directory -- confirm against conftest).
    node_name: str
        Name given to the main Node, also used to load it again.
    """
    node = ExampleUsesPlots(
        node_with_plots=NodeWithPlots(factor=2.5), param=2.0, name=node_name
    )
    assert node.node_with_plots._is_attribute is True
    assert node.node_with_plots.node_name == f"{node_name}-node_with_plots"
    assert len(node.node_with_plots._descriptor_list) == 2

    node.write_graph()
    ExampleUsesPlots[node_name].run_and_save()

    assert ExampleUsesPlots[node_name].out == 2.5 * 2.0

    # Just checking if changing the parameters works as well
    with open("params.yaml", "r") as file:
        parameters = yaml.safe_load(file)
    parameters[f"{node_name}-node_with_plots"]["factor"] = 1.0
    # Rewrite the file ("w"). Appending would leave the old content in place
    # and add a second copy of every top-level key, i.e. a YAML document
    # with duplicate mapping keys.
    with open("params.yaml", "w") as file:
        yaml.safe_dump(parameters, file)

    assert ExampleUsesPlots[node_name].node_with_plots.factor == 1.0


class NodeAsDataClass(Node):
    """Node without a ``run`` method: only a hash and three parameters.

    It is passed to ``UseNodeAsDataClass`` as a ``zn.Nodes`` attribute.
    """

    _hash = zn.Hash()
    param1 = zn.params()
    param2 = zn.params()
    param3 = zn.params()


class UseNodeAsDataClass(Node):
    """Node that reads its inputs from a nested ``NodeAsDataClass``."""

    params: NodeAsDataClass = zn.Nodes()
    output = zn.outs()

    def run(self):
        """Add the three parameters of the nested node and store the result."""
        first_two = self.params.param1 + self.params.param2
        self.output = first_two + self.params.param3


def test_UseNodeAsDataClass(proj_path):
    """Run a Node fed by a parameter-only nested Node and check the loaded state."""
    UseNodeAsDataClass(
        params=NodeAsDataClass(param1=1, param2=10, param3=100)
    ).write_graph(run=True)

    loaded = UseNodeAsDataClass.load()
    assert loaded.output == 111
    # The nested parameters must be available again after loading.
    for attribute, expected in (("param1", 1), ("param2", 10), ("param3", 100)):
        assert getattr(loaded.params, attribute) == expected
46 changes: 45 additions & 1 deletion tests/unit_tests/core/test_core_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pytest
import yaml

from zntrack import dvc, zn
from zntrack import dvc, utils, zn
from zntrack.core.base import LoadViaGetItem, Node, update_dependency_options


Expand All @@ -31,6 +31,15 @@ def run(self):
self.zn_outs = "outs"


class ExampleHashNode(Node):
    """Node with a ``zn.Hash`` next to other option types.

    It is saved with ``hash_only=True`` in ``test_save_only_hash``.
    """

    hash = zn.Hash()
    # The options below are not tested; a hash-only save should ignore them.
    params = zn.params(10)
    zn_outs = zn.outs()
    dvc_outs = dvc.outs("file.txt")
    deps = dvc.deps("deps.inp")


@pytest.mark.parametrize("run", (True, False))
def test_save(run):
zntrack_mock = mock_open(read_data="{}")
Expand All @@ -56,8 +65,11 @@ def pathlib_open(*args, **kwargs):
assert zn_outs_mock().write.mock_calls == [
call(json.dumps({"zn_outs": "outs"}, indent=4))
]
assert not zntrack_mock().write.called
assert not params_mock().write.called
else:
example.save()
assert not zn_outs_mock().write.called
assert zntrack_mock().write.mock_calls == [
call(json.dumps({})), # clear everything first
call(
Expand All @@ -79,6 +91,38 @@ def pathlib_open(*args, **kwargs):
]


def test_save_only_hash():
    """``save(hash_only=True)`` must write the hash file and nothing else.

    A Node without ``zn.Hash`` has to raise ``DescriptorMissing``. For a Node
    with a hash, only ``nodes/ExampleHashNode/hash.json`` may be written.
    """
    zntrack_mock = mock_open(read_data="{}")
    params_mock = mock_open(read_data="{}")
    zn_outs_mock = mock_open(read_data="{}")
    hash_mock = mock_open(read_data="{}")

    # ExampleFullNode defines no zn.Hash, so a hash-only save must fail.
    with pytest.raises(utils.exceptions.DescriptorMissing):
        ExampleFullNode().save(hash_only=True)

    # Every file the Node is allowed to open, mapped to its dedicated mock.
    mock_by_path = {
        pathlib.Path("zntrack.json"): zntrack_mock,
        pathlib.Path("params.yaml"): params_mock,
        pathlib.Path("nodes/ExampleFullNode/outs.json"): zn_outs_mock,
        pathlib.Path("nodes/ExampleHashNode/hash.json"): hash_mock,
    }

    def pathlib_open(*args, **kwargs):
        # args[0] is the Path instance that ``open`` was called on.
        opener = mock_by_path.get(args[0])
        if opener is None:
            raise ValueError(args)
        return opener(*args, **kwargs)

    hash_node = ExampleHashNode()
    with patch.object(pathlib.Path, "open", pathlib_open):
        hash_node.save(hash_only=True)

    for untouched in (params_mock, zntrack_mock, zn_outs_mock):
        assert not untouched().write.called
    assert hash_mock().write.called


def test__load():
zntrack_mock = mock_open(
read_data=json.dumps(
Expand Down
20 changes: 19 additions & 1 deletion tests/unit_tests/core/test_dvcgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,17 +107,35 @@ def test_affected_files():


class ExampleClassWithParams(Node):
    """Node with two defaulted ``zn.params`` for the descriptor-list tests."""

    # NOTE(review): presumably marks the instance as not loaded from files;
    # confirm against the Node base class.
    is_loaded = False
    param1 = zn.params(default=1)
    param2 = zn.params(default=2)


class ExampleClassDifferentTypes(Node):
    """Node declaring one option of each of five kinds.

    ``test_descriptor_list_attr`` uses it to check how ``_is_attribute``
    changes the length of ``_descriptor_list``.
    """

    # Class-level default; test_descriptor_list_attr switches it to False.
    _is_attribute = True
    _hash = zn.Hash()
    param = zn.params(1)
    outs = dvc.outs("file.txt")
    metrics = zn.metrics()
    plots = zn.plots()


def test__descriptor_list():
    """``ExampleClassWithParams`` must report exactly two descriptors."""
    descriptors = ExampleClassWithParams()._descriptor_list

    assert len(descriptors) == 2


def test_descriptor_list_attr():
    """Test the descriptor list if ``_is_attribute=True``.

    With the flag set, two descriptors are expected; with it cleared,
    all five declared options are expected.
    """
    example = ExampleClassDifferentTypes()

    as_attribute = example._descriptor_list
    assert len(as_attribute) == 2

    example._is_attribute = False
    as_plain_node = example._descriptor_list
    assert len(as_plain_node) == 5


def test_descriptor_list_filter():
example = ExampleClassWithParams()

Expand Down
23 changes: 17 additions & 6 deletions zntrack/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ def save_plots(self):
if option.zn_type is utils.ZnTypes.PLOTS:
option.save(instance=self)

def save(self, results: bool = False):
def save(self, results: bool = False, hash_only: bool = False):
"""Save Class state to files
Parameters
Expand All @@ -239,12 +239,25 @@ def save(self, results: bool = False):
By default, this function saves e.g. parameters from zn.params / dvc.<option>,
but does not save results that are stored in zn.<option>.
Set this option to True if they should be saved, e.g. in run_and_save
hash_only: bool, default = False
Only save zn.Hash and nothing else. This is required for usage as zn.Nodes
"""
if hash_only:
try:
zninit.get_descriptors(zn.Hash, self=self)[0].save(instance=self)
except IndexError as err:
raise utils.exceptions.DescriptorMissing(
"Could not find a hash descriptor. Please add zn.Hash()"
) from err
return

if not results:
# Reset everything in params.yaml and zntrack.json before saving
utils.file_io.clear_config_file(utils.Files.params, node_name=self.node_name)
utils.file_io.clear_config_file(utils.Files.zntrack, node_name=self.node_name)
# Save dvc.<option>, dvc.deps, zn.Method

for option in self._descriptor_list:
if results:
if option.zn_type in utils.VALUE_DVC_TRACKED:
Expand Down Expand Up @@ -416,17 +429,15 @@ def _handle_nodes_as_methods(self):
zn.Nodes ZnTrackOptions will require a dedicated graph to be written.
They are shown in the dvc dag and have their own parameter section.
The name is <nodename>-<attributename> for these Nodes and they only
The name is <nodename>-<attributename> for these Nodes. They only
have a single hash output to be available for DVC dependencies.
"""
for attribute, node in self.zntrack.collect(zn_nodes).items():
if node is None:
continue
node.node_name = f"{self.node_name}-{attribute}"
node._is_attribute = True
node.write_graph(
run=True,
call_args=f".load(name='{node.node_name}').save(results=True)",
call_args=f"['{node.node_name}'].save(hash_only=True)",
)

@property
Expand Down Expand Up @@ -553,7 +564,7 @@ def write_graph(
custom_args += pair

if call_args is None:
call_args = f".load(name='{self.node_name}').run_and_save()"
call_args = f"['{self.node_name}'].run_and_save()"

script = prepare_dvc_script(
node_name=self.node_name,
Expand Down
22 changes: 7 additions & 15 deletions zntrack/zn/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,22 +9,10 @@


class Nodes(ZnTrackOption):
"""ZnTrack methods passing descriptor
This descriptor allows to pass a class instance that is not a ZnTrack Node as a
method that can be used later. It requires that all passed class attributes have
the same name in the __init__ and via getattr an that they are serializable.
Example
--------
>>> class HelloWorld:
>>> def __init__(self, name):
>>> self.name = name
>>>
>>> class MyNode(zntrack.Node)
>>> my_method = Method()
>>> MyNode().my_method = HelloWorld(name="Max")
"""Have a ZnTrack Node as an attribute to another (main) Node
If you want to use a method of another ZnTrack Node you can pass it
as a zn.Nodes()
"""

dvc_option = utils.DVCOptions.DEPS
Expand Down Expand Up @@ -69,6 +57,10 @@ def __get__(self, instance, owner=None):
if instance is None:
return self
value = super().__get__(instance, owner)
if value is not None:
value._is_attribute = True
value.node_name = f"{instance.node_name}-{self.name}"
# value._is_attribute = True # value can be None
value = utils.utils.load_node_dependency(value) # use value = Cls.load()
setattr(instance, self.name, value)
return value
1 change: 0 additions & 1 deletion zntrack/zn/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,5 @@ def save(self, instance):

def get_data_from_files(self, instance):
    """Read the file of this option for ``instance`` with ``pd.read_csv``.

    The path comes from ``self.get_filename(instance)``; the first column
    of the CSV is used as the index.
    """
    return pd.read_csv(self.get_filename(instance), index_col=0)

0 comments on commit 1803cb5

Please sign in to comment.