From bd067465dc284091fc123f7a1b801a1b69b5dc30 Mon Sep 17 00:00:00 2001 From: PythonFZ Date: Sun, 23 Jul 2023 17:37:25 +0200 Subject: [PATCH 1/4] construct test cases --- tests/integration/test_project.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/tests/integration/test_project.py b/tests/integration/test_project.py index 9671621a..ae4a0016 100644 --- a/tests/integration/test_project.py +++ b/tests/integration/test_project.py @@ -276,3 +276,22 @@ def test_build_groups(tmp_path_2): with pytest.raises(ValueError): project.run(nodes=[42]) + + +def test_groups_nwd(tmp_path_2): + with zntrack.Project(automatic_node_names=True) as project: + node_1 = WriteIO(inputs="Lorem Ipsum") + with project.group() as group_1: + node_2 = WriteIO(inputs="Dolor Sit") + with project.group(name="CustomGroup") as group_2: + node_3 = WriteIO(inputs="Adipiscing Elit") + + project.build() + + assert node_1.nwd == pathlib.Path("nodes", node_1.name) + assert node_2.nwd == pathlib.Path( + "nodes", group_1.name, node_2.name.replace(f"{group_1.name}_", "") + ) + assert node_3.nwd == pathlib.Path( + "nodes", group_2.name, node_3.name.replace(f"{group_2.name}_", "") + ) From ef49a0df9c61cf5a615586e3475cabe3be94e401 Mon Sep 17 00:00:00 2001 From: PythonFZ Date: Mon, 24 Jul 2023 09:31:06 +0200 Subject: [PATCH 2/4] group based nwd --- zntrack/core/node.py | 11 ++++++++++- zntrack/project/zntrack_project.py | 1 + 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/zntrack/core/node.py b/zntrack/core/node.py index 4e32ab6a..76381dbe 100644 --- a/zntrack/core/node.py +++ b/zntrack/core/node.py @@ -162,7 +162,16 @@ def state(self) -> NodeStatus: @property def nwd(self) -> pathlib.Path: """Get the node working directory.""" - nwd = pathlib.Path("nodes", znflow.get_attribute(self, "name")) + try: + nwd = self.__dict__["nwd"] + except KeyError: + try: + zntrack_config = json.loads(pathlib.Path("zntrack.json").read_text()) + nwd = pathlib.Path( + zntrack_config[znflow.get_attribute(self, "name")]["nwd"] + ) + except (FileNotFoundError, KeyError): + nwd = pathlib.Path("nodes", znflow.get_attribute(self, "name")) if not nwd.exists(): nwd.mkdir(parents=True) return nwd diff --git a/zntrack/project/zntrack_project.py b/zntrack/project/zntrack_project.py index 111e3282..f46d8bbc 100644 --- a/zntrack/project/zntrack_project.py +++ b/zntrack/project/zntrack_project.py @@ -157,6 +157,7 @@ def group(self, name: str = None): for node_uuid in self.graph.get_sorted_nodes(): node: Node = self.graph.nodes[node_uuid]["value"] if node_uuid not in existing_nodes: + node.__dict__["nwd"] = pathlib.Path("nodes", group.name, node.name) node.name = f"{name}_{node.name}" group.nodes.append(node) From ec61bd1dc71a4f8ada187602f6e8deae0d13cf96 Mon Sep 17 00:00:00 2001 From: PythonFZ Date: Mon, 24 Jul 2023 09:42:10 +0200 Subject: [PATCH 3/4] update tests; use znjson --- tests/integration/test_project.py | 24 ++++++++++++++++++++++++ zntrack/core/node.py | 13 ++++++++++--- 2 files changed, 34 insertions(+), 3 deletions(-) diff --git a/tests/integration/test_project.py b/tests/integration/test_project.py index ae4a0016..732697be 100644 --- a/tests/integration/test_project.py +++ b/tests/integration/test_project.py @@ -1,3 +1,4 @@ +import json import pathlib import pytest @@ -295,3 +296,26 @@ def test_groups_nwd(tmp_path_2): assert node_3.nwd == pathlib.Path( "nodes", group_2.name, node_3.name.replace(f"{group_2.name}_", "") ) + # now load the Nodes and assert as well + + assert zntrack.from_rev(node_1).nwd == pathlib.Path("nodes", node_1.name) + assert zntrack.from_rev(node_2).nwd == pathlib.Path( + "nodes", group_1.name, node_2.name.replace(f"{group_1.name}_", "") + ) + assert zntrack.from_rev(node_3).nwd == pathlib.Path( + "nodes", group_2.name, node_3.name.replace(f"{group_2.name}_", "") + ) + + with open("zntrack.json") as f: + data = json.load(f) + data[node_1.name]["nwd"]["value"] = "test" + data[node_2.name].pop("nwd") + + with open("zntrack.json", "w") as f: + json.dump(data, f) + + assert zntrack.from_rev(node_1).nwd == pathlib.Path("test") + assert zntrack.from_rev(node_2).nwd == pathlib.Path("nodes", node_2.name) + assert zntrack.from_rev(node_3).nwd == pathlib.Path( + "nodes", group_2.name, node_3.name.replace(f"{group_2.name}_", "") + ) diff --git a/zntrack/core/node.py b/zntrack/core/node.py index 76381dbe..62ddc533 100644 --- a/zntrack/core/node.py +++ b/zntrack/core/node.py @@ -166,10 +166,10 @@ def nwd(self) -> pathlib.Path: nwd = self.__dict__["nwd"] except KeyError: try: - zntrack_config = json.loads(pathlib.Path("zntrack.json").read_text()) - nwd = pathlib.Path( - zntrack_config[znflow.get_attribute(self, "name")]["nwd"] + zntrack_config = json.loads( + pathlib.Path("zntrack.json").read_text(), cls=znjson.ZnDecoder ) + nwd = zntrack_config[znflow.get_attribute(self, "name")]["nwd"] except (FileNotFoundError, KeyError): nwd = pathlib.Path("nodes", znflow.get_attribute(self, "name")) if not nwd.exists(): @@ -211,6 +211,13 @@ def save( f"Field {attr} has no group. Please assign a group from" f" '{FieldGroup.__module__}.{FieldGroup.__name__}'." ) + # save the nwd to zntrack.json + file_io.update_config_file( + file=pathlib.Path("zntrack.json"), + node_name=self.name, + value_name="nwd", + value=self.nwd, + ) def run(self) -> None: """Run the node's code.""" From 52e81fa2ed6ab1568ad225e6a07899ecca4b7426 Mon Sep 17 00:00:00 2001 From: PythonFZ Date: Mon, 24 Jul 2023 10:16:04 +0200 Subject: [PATCH 4/4] fix bug --- zntrack/core/node.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/zntrack/core/node.py b/zntrack/core/node.py index 62ddc533..b48c1aff 100644 --- a/zntrack/core/node.py +++ b/zntrack/core/node.py @@ -166,10 +166,9 @@ def nwd(self) -> pathlib.Path: nwd = self.__dict__["nwd"] except KeyError: try: - zntrack_config = json.loads( - pathlib.Path("zntrack.json").read_text(), cls=znjson.ZnDecoder - ) + zntrack_config = json.loads(pathlib.Path("zntrack.json").read_text()) nwd = zntrack_config[znflow.get_attribute(self, "name")]["nwd"] + nwd = json.loads(json.dumps(nwd), cls=znjson.ZnDecoder) except (FileNotFoundError, KeyError): nwd = pathlib.Path("nodes", znflow.get_attribute(self, "name")) if not nwd.exists():