Skip to content

Commit

Permalink
Bugfixes (#582)
Browse files Browse the repository at this point in the history
* automatic_node_name bugfix

* bump version

* update README.md
  • Loading branch information
PythonFZ authored Apr 10, 2023
1 parent 45534d0 commit db2d94f
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 9 deletions.
12 changes: 8 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,16 +52,18 @@ class HelloWorld(Node):

if __name__ == "__main__":
# Write the computational graph
HelloWorld(max_number=512).write_graph()
with zntrack.Project() as project:
hello_world = HelloWorld(max_number=512)
project.run()
```

This will create a [DVC](https://dvc.org) stage ``HelloWorld``.
The workflow is defined in ``dvc.yaml`` and the parameters are stored in ``params.yaml``.

You can run the workflow with ``dvc repro``.
This will run the workflow with ``dvc repro`` automatically.
Once the graph is executed, the results, i.e. the random number can be accessed directly by the Node object.
```python
hello_world = HelloWorld.load()
hello_world.load()
print(hello_world.random_numer)
```
An overview of all the ZnTrack features as well as more detailed examples can be found in the [ZnTrack Documentation](https://zntrack.readthedocs.io/en/latest/).
Expand All @@ -81,7 +83,9 @@ def write_text(cfg: NodeConfig):
cfg.params.text
)
# build the DVC graph
write_text()
with zntrack.Project() as project:
write_text()
project.run()
````

The ``cfg`` dataclass passed to the function provides access to all configured files
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "ZnTrack"
version = "0.6.0a4"
version = "0.6.0a5"
description = "Create, Run and Benchmark DVC Pipelines in Python"
authors = ["zincwarecode <zincwarecode@gmail.com>"]
license = "Apache-2.0"
Expand Down
2 changes: 2 additions & 0 deletions tests/integration/test_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,11 @@ def test_automatic_node_names_True(tmp_path_2):
with zntrack.Project(automatic_node_names=True) as project:
node = WriteIO(inputs="Hello World")
node2 = WriteIO(inputs="Lorem Ipsum")
node3 = WriteIO(inputs="Lorem Ipsum")

assert node.name == "WriteIO"
assert node2.name == "WriteIO_1"
assert node3.name == "WriteIO_2"

project.run()
project.load()
Expand Down
2 changes: 1 addition & 1 deletion tests/test_zntrack.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@

def test_version():
"""Test 'ZnTrack' version."""
assert __version__ == "0.6.0a4"
assert __version__ == "0.6.0a5"
9 changes: 6 additions & 3 deletions zntrack/project/zntrack_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import json
import logging
import pathlib
import shutil

import git
import yaml
Expand Down Expand Up @@ -80,6 +81,7 @@ def __post_init__(self):
pathlib.Path("zntrack.json").unlink(missing_ok=True)
pathlib.Path("dvc.yaml").unlink(missing_ok=True)
pathlib.Path("params.yaml").unlink(missing_ok=True)
shutil.rmtree("nodes", ignore_errors=True)

def __enter__(self, *args, **kwargs):
"""Enter the graph context."""
Expand All @@ -98,11 +100,12 @@ def update_node_names(self):
for node_uuid in self.graph.get_sorted_nodes():
node: Node = self.graph.nodes[node_uuid]["value"]
if self.automatic_node_names:
idx = 1
while node.name in node_names:
if node.name in node_names:
idx = 1
while f"{node.name}_{idx}" in node_names:
idx += 1
node.name = f"{node.name}_{idx}"
log.debug(f"Updating {node.name = }")
idx += 1

elif node.name in node_names:
raise exceptions.DuplicateNodeNameError(node)
Expand Down

0 comments on commit db2d94f

Please sign in to comment.