diff --git a/README.md b/README.md index 9104d548..4cb27166 100644 --- a/README.md +++ b/README.md @@ -52,16 +52,18 @@ class HelloWorld(Node): if __name__ == "__main__": # Write the computational graph - HelloWorld(max_number=512).write_graph() + with zntrack.Project() as project: + hello_world = HelloWorld(max_number=512) + project.run() ``` This will create a [DVC](https://dvc.org) stage ``HelloWorld``. The workflow is defined in ``dvc.yaml`` and the parameters are stored in ``params.yaml``. -You can run the workflow with ``dvc repro``. +This will run the workflow with ``dvc repro`` automatically. Once the graph is executed, the results, i.e. the random number can be accessed directly by the Node object. ```python -hello_world = HelloWorld.load() +hello_world.load() print(hello_world.random_numer) ``` An overview of all the ZnTrack features as well as more detailed examples can be found in the [ZnTrack Documentation](https://zntrack.readthedocs.io/en/latest/). @@ -81,7 +83,9 @@ def write_text(cfg: NodeConfig): cfg.params.text ) # build the DVC graph -write_text() +with zntrack.Project() as project: + write_text() +project.run() ```` The ``cfg`` dataclass passed to the function provides access to all configured files diff --git a/pyproject.toml b/pyproject.toml index 3c335968..2e0c70ba 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "ZnTrack" -version = "0.6.0a4" +version = "0.6.0a5" description = "Create, Run and Benchmark DVC Pipelines in Python" authors = ["zincwarecode "] license = "Apache-2.0" diff --git a/tests/integration/test_project.py b/tests/integration/test_project.py index d2efdce6..f3d1cbfe 100644 --- a/tests/integration/test_project.py +++ b/tests/integration/test_project.py @@ -106,9 +106,11 @@ def test_automatic_node_names_True(tmp_path_2): with zntrack.Project(automatic_node_names=True) as project: node = WriteIO(inputs="Hello World") node2 = WriteIO(inputs="Lorem Ipsum") + node3 = WriteIO(inputs="Lorem Ipsum") assert node.name == "WriteIO" assert node2.name == "WriteIO_1" + assert node3.name == "WriteIO_2" project.run() project.load() diff --git a/tests/test_zntrack.py b/tests/test_zntrack.py index 3c1e2e59..eab76eef 100644 --- a/tests/test_zntrack.py +++ b/tests/test_zntrack.py @@ -4,4 +4,4 @@ def test_version(): """Test 'ZnTrack' version.""" - assert __version__ == "0.6.0a4" + assert __version__ == "0.6.0a5" diff --git a/zntrack/project/zntrack_project.py b/zntrack/project/zntrack_project.py index bd884fbf..5d907a1c 100644 --- a/zntrack/project/zntrack_project.py +++ b/zntrack/project/zntrack_project.py @@ -6,6 +6,7 @@ import json import logging import pathlib +import shutil import git import yaml @@ -80,6 +81,7 @@ def __post_init__(self): pathlib.Path("zntrack.json").unlink(missing_ok=True) pathlib.Path("dvc.yaml").unlink(missing_ok=True) pathlib.Path("params.yaml").unlink(missing_ok=True) + shutil.rmtree("nodes", ignore_errors=True) def __enter__(self, *args, **kwargs): """Enter the graph context.""" @@ -98,11 +100,12 @@ def update_node_names(self): for node_uuid in self.graph.get_sorted_nodes(): node: Node = self.graph.nodes[node_uuid]["value"] if self.automatic_node_names: - idx = 1 - while node.name in node_names: + if node.name in node_names: + idx = 1 + while f"{node.name}_{idx}" in node_names: + idx += 1 node.name = f"{node.name}_{idx}" log.debug(f"Updating {node.name = }") - idx += 1 elif node.name in node_names: raise exceptions.DuplicateNodeNameError(node)