Skip to content

Commit

Permalink
support for magic node names using varnames (#776)
Browse files Browse the repository at this point in the history
* support for magic node names using `varnames`

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* shorten line

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* bugfix

* use contextlib

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove poetry lock

* do not allow automatic node names and magic node names

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
PythonFZ and pre-commit-ci[bot] authored Feb 22, 2024
1 parent fdb193f commit 6698c4a
Show file tree
Hide file tree
Showing 6 changed files with 77 additions and 3 deletions.
19 changes: 18 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ dot4dict = "^0.1"
zninit = "^0.1"
znjson = "^0.2"
znflow = "^0.1"
varname = "^0.13"
# for Python3.12 compatibliity
pyzmq = "^25"

Expand Down
32 changes: 32 additions & 0 deletions tests/integration/test_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,3 +552,35 @@ def test_auto_remove(proj_path):
n1 = zntrack.examples.ParamsToOuts.from_rev(n1.name)
with pytest.raises(zntrack.exceptions.NodeNotAvailableError):
n2 = zntrack.examples.ParamsToOuts.from_rev(n2.name)


def test_magic_names(proj_path):
node = zntrack.examples.ParamsToOuts(params="Lorem Ipsum")
assert node.name == "ParamsToOuts"
with pytest.raises(ValueError):
project = zntrack.Project(magic_names=True, automatic_node_names=True)

project = zntrack.Project(magic_names=True, automatic_node_names=False)
with project:
node01 = zntrack.examples.ParamsToOuts(params="Lorem Ipsum")
node02 = zntrack.examples.ParamsToOuts(params="Dolor Sit")
node03 = zntrack.examples.ParamsToOuts(params="Test01")
assert node01.name == "node01"
assert node02.name == "node02"
assert node03.name == "node03"

with project.group("Grp01"):
node01 = zntrack.examples.ParamsToOuts(params="Lorem Ipsum")
node02 = zntrack.examples.ParamsToOuts(params="Dolor Sit")
grp_node03 = zntrack.examples.ParamsToOuts(params="Test02")

assert node01.name == "Grp01_node01"
assert node02.name == "Grp01_node02"
assert grp_node03.name == "Grp01_grp_node03"

project.run()

zntrack.from_rev(node01.name).outs == "Lorem Ipsum"
zntrack.from_rev(node02.name).outs == "Dolor Sit"
zntrack.from_rev(node03.name).outs == "Test01"
zntrack.from_rev(grp_node03.name).outs == "Test02"
5 changes: 5 additions & 0 deletions zntrack/core/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import znflow
import zninit
import znjson
from varname import VarnameException, varname

from zntrack import exceptions
from zntrack.notebooks.jupyter import jupyter_class_to_file
Expand Down Expand Up @@ -161,8 +162,12 @@ def __set__(self, instance, value):
if isinstance(value, NodeName):
if not instance._external_:
value.update_suffix(instance._graph_.project, instance)
with contextlib.suppress(VarnameException):
value.varname = varname(frame=4)
instance._name_ = value
elif isinstance(getattr(instance, "_name_"), NodeName):
with contextlib.suppress(VarnameException):
instance._name_.varname = varname(frame=4)
instance._name_.name = value
instance._name_.suffix = 0
instance._name_.update_suffix(instance._graph_.project, instance)
Expand Down
11 changes: 11 additions & 0 deletions zntrack/project/zntrack_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,11 @@ class Project:
This will require a DVC remote to be setup.
force : bool, default = False
overwrite existing nodes.
magic_names : bool, default = False
If True, use magic names for the nodes. This will use the variable name of the
node as the node name. E.g. `node = Node()` will result in a node name of 'node'.
If used within a group, the group name will be added to the node name. E.g.
`group.name = Grp1` and `model = Node()` will result in a name of 'Grp1_model'.
"""

graph: ZnTrackGraph = dataclasses.field(default_factory=ZnTrackGraph, init=False)
Expand All @@ -90,6 +95,7 @@ class Project:
automatic_node_names: bool = True
git_only_repo: bool = True
force: bool = False
magic_names: bool = False

_groups: dict[str, NodeGroup] = dataclasses.field(
default_factory=dict, init=False, repr=False
Expand All @@ -116,6 +122,11 @@ def __post_init__(self):
config.files.params.unlink(missing_ok=True)
shutil.rmtree("nodes", ignore_errors=True)

if self.automatic_node_names and self.magic_names:
raise ValueError(
"automatic_node_names and magic_names can not be True at the same time"
)

def __enter__(self, *args, **kwargs):
"""Enter the graph context."""
self.graph.__enter__(*args, **kwargs)
Expand Down
12 changes: 10 additions & 2 deletions zntrack/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,28 +227,36 @@ class NodeName:

groups: list[str]
name: str
varname: str = None
suffix: int = 0
use_varname: bool = False

def __str__(self) -> str:
"""Get the node name."""
name = []
if self.groups is not None:
name.extend(self.groups)
name.append(self.name)
if self.use_varname:
name.append(self.varname)
else:
name.append(self.name)
if self.suffix > 0 and self.use_varname:
raise ValueError("Suffixes are not supported for magic names (varnames).")
if self.suffix > 0:
name.append(str(self.suffix))
return "_".join(name)

def get_name_without_groups(self) -> str:
"""Get the node name without the groups."""
name = self.name
name = self.varname if self.use_varname else self.name
if self.suffix > 0:
name += f"_{self.suffix}"
return name

def update_suffix(self, project: "Project", node: "Node") -> None:
"""Update the suffix."""
node_names = [x["value"].name for x in project.graph.nodes.values()]
self.use_varname = project.magic_names

node_names = []
for node_uuid in project.graph.nodes:
Expand Down

0 comments on commit 6698c4a

Please sign in to comment.