diff --git a/tests/integration_tests/test_single_node.py b/tests/integration_tests/test_single_node.py index 8e49e2aa..b5506f0a 100644 --- a/tests/integration_tests/test_single_node.py +++ b/tests/integration_tests/test_single_node.py @@ -325,3 +325,28 @@ def test_load_named_nodes(proj_path): # this will run load with name=Node01, lazy=True/False assert ExampleNode01[{"name": "Node01", "lazy": True}].outputs == 42 assert ExampleNode01[{"name": "Node01", "lazy": False}].outputs == 42 + + +class NodeCustomFileName(Node): + output_std = zn.outs() + output_custom = zn.outs(filename="custom_data") + + def run(self): + self.output_std = "Hello World" + self.output_custom = "Lorem Ipsum" + + +def test_NodeCustomFileName(proj_path): + NodeCustomFileName().write_graph(run=True) + + assert NodeCustomFileName.load().output_std == "Hello World" + assert NodeCustomFileName.load().output_custom == "Lorem Ipsum" + + output_std = pathlib.Path("nodes", "NodeCustomFileName", "outs.json") + output_custom = pathlib.Path("nodes", "NodeCustomFileName", "custom_data.json") + + assert output_std.exists() + assert output_custom.exists() + # + assert json.loads(output_std.read_text())["output_std"] == "Hello World" + assert json.loads(output_custom.read_text())["output_custom"] == "Lorem Ipsum" diff --git a/zntrack/core/zntrackoption.py b/zntrack/core/zntrackoption.py index 03506442..a8fd9378 100644 --- a/zntrack/core/zntrackoption.py +++ b/zntrack/core/zntrackoption.py @@ -60,6 +60,14 @@ class ZnTrackOption(descriptor.Descriptor): def __init__(self, default_value=None, **kwargs): """Constructor for ZnTrackOptions + + Attributes + ---------- + default_value: Any + The default value of the descriptor + filename: + part of the kwargs, optional filename overwrite. + Raises ------ ValueError: If dvc_option is None and the class name is not in utils.DVCOptions @@ -70,6 +78,8 @@ def __init__(self, default_value=None, **kwargs): if self.dvc_option is None: # use the name of the class as DVCOption if registered in DVCOptions self.dvc_option = utils.DVCOptions(self.__class__.__name__).value + + self.filename = kwargs.pop("filename", self.dvc_option) super().__init__(default_value=default_value, **kwargs) @property @@ -132,7 +142,7 @@ def __get__(self, instance, owner): def get_filename(self, instance) -> pathlib.Path: """Get the name of the file this ZnTrackOption will save its values to""" if uses_node_name(self.zn_type, instance) is None: - return pathlib.Path("nodes", instance.node_name, f"{self.dvc_option}.json") + return pathlib.Path("nodes", instance.node_name, f"{self.filename}.json") return pathlib.Path(self.file) def save(self, instance):