Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add filename kwargs
Browse files Browse the repository at this point in the history
PythonFZ committed May 2, 2022
1 parent 5bb55e5 commit 2ceb7e7
Showing 2 changed files with 36 additions and 1 deletion.
25 changes: 25 additions & 0 deletions tests/integration_tests/test_single_node.py
Original file line number Diff line number Diff line change
@@ -325,3 +325,28 @@ def test_load_named_nodes(proj_path):
# this will run load with name=Node01, lazy=True/False
assert ExampleNode01[{"name": "Node01", "lazy": True}].outputs == 42
assert ExampleNode01[{"name": "Node01", "lazy": False}].outputs == 42


class NodeCustomFileName(Node):
output_std = zn.outs()
output_custom = zn.outs(filename="custom_data")

def run(self):
self.output_std = "Hello World"
self.output_custom = "Lorem Ipsum"


def test_NodeCustomFileName(proj_path):
NodeCustomFileName().write_graph(run=True)

assert NodeCustomFileName.load().output_std == "Hello World"
assert NodeCustomFileName.load().output_custom == "Lorem Ipsum"

output_std = pathlib.Path("nodes", "NodeCustomFileName", "outs.json")
output_custom = pathlib.Path("nodes", "NodeCustomFileName", "custom_data.json")

assert output_std.exists()
assert output_custom.exists()
#
assert json.loads(output_std.read_text())["output_std"] == "Hello World"
assert json.loads(output_custom.read_text())["output_custom"] == "Lorem Ipsum"
12 changes: 11 additions & 1 deletion zntrack/core/zntrackoption.py
Original file line number Diff line number Diff line change
@@ -60,6 +60,14 @@ class ZnTrackOption(descriptor.Descriptor):

def __init__(self, default_value=None, **kwargs):
"""Constructor for ZnTrackOptions
Attributes
----------
default_value: Any
The default value of the descriptor
filename:
part of the kwargs, optional filename overwrite.
Raises
------
ValueError: If dvc_option is None and the class name is not in utils.DVCOptions
@@ -70,6 +78,8 @@ def __init__(self, default_value=None, **kwargs):
if self.dvc_option is None:
# use the name of the class as DVCOption if registered in DVCOptions
self.dvc_option = utils.DVCOptions(self.__class__.__name__).value

self.filename = kwargs.pop("filename", self.dvc_option)
super().__init__(default_value=default_value, **kwargs)

@property
@@ -132,7 +142,7 @@ def __get__(self, instance, owner):
def get_filename(self, instance) -> pathlib.Path:
"""Get the name of the file this ZnTrackOption will save its values to"""
if uses_node_name(self.zn_type, instance) is None:
return pathlib.Path("nodes", instance.node_name, f"{self.dvc_option}.json")
return pathlib.Path("nodes", instance.node_name, f"{self.filename}.json")
return pathlib.Path(self.file)

def save(self, instance):

0 comments on commit 2ceb7e7

Please sign in to comment.