diff --git a/poetry.lock b/poetry.lock index c5aa3fdb..777a55a6 100644 --- a/poetry.lock +++ b/poetry.lock @@ -5318,6 +5318,23 @@ h2 = ["h2 (>=4,<5)"] socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"] zstd = ["zstandard (>=0.18.0)"] +[[package]] +name = "varname" +version = "0.13.0" +description = "Dark magics about variable names in python." +optional = false +python-versions = ">=3.8,<4.0" +files = [ + {file = "varname-0.13.0-py3-none-any.whl", hash = "sha256:d41101d69d92d167e15f2997039c0a52a73ff0c8e02256c3c8b549a11fc4b176"}, + {file = "varname-0.13.0.tar.gz", hash = "sha256:64e9052029fd4d49686ac6443a9ed182c2c149686e42acef69ebaa3a27811beb"}, +] + +[package.dependencies] +executing = ">=2.0,<3.0" + +[package.extras] +all = ["asttokens (==2.*)", "pure_eval (==0.*)"] + [[package]] name = "vine" version = "5.1.0" @@ -5696,4 +5713,4 @@ files = [ [metadata] lock-version = "2.0" python-versions = ">=3.9,<4.0.0" -content-hash = "47eb0b34d17c7c715566aa4626beb5a1a3e4d6f16a55f14448dd8d3b1c7e4393" +content-hash = "cd187c08029a8632406528c1ac370397729406979b98390499726e26ff411c46" diff --git a/pyproject.toml b/pyproject.toml index 066b5e72..6f57c26e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,6 +20,7 @@ dot4dict = "^0.1" zninit = "^0.1" znjson = "^0.2" znflow = "^0.1" +varname = "^0.13" # for Python3.12 compatibliity pyzmq = "^25" diff --git a/tests/integration/test_project.py b/tests/integration/test_project.py index 65549c34..6238790c 100644 --- a/tests/integration/test_project.py +++ b/tests/integration/test_project.py @@ -552,3 +552,35 @@ def test_auto_remove(proj_path): n1 = zntrack.examples.ParamsToOuts.from_rev(n1.name) with pytest.raises(zntrack.exceptions.NodeNotAvailableError): n2 = zntrack.examples.ParamsToOuts.from_rev(n2.name) + + +def test_magic_names(proj_path): + node = zntrack.examples.ParamsToOuts(params="Lorem Ipsum") + assert node.name == "ParamsToOuts" + with pytest.raises(ValueError): + project = zntrack.Project(magic_names=True, automatic_node_names=True) + + project = zntrack.Project(magic_names=True, automatic_node_names=False) + with project: + node01 = zntrack.examples.ParamsToOuts(params="Lorem Ipsum") + node02 = zntrack.examples.ParamsToOuts(params="Dolor Sit") + node03 = zntrack.examples.ParamsToOuts(params="Test01") + assert node01.name == "node01" + assert node02.name == "node02" + assert node03.name == "node03" + + with project.group("Grp01"): + node01 = zntrack.examples.ParamsToOuts(params="Lorem Ipsum") + node02 = zntrack.examples.ParamsToOuts(params="Dolor Sit") + grp_node03 = zntrack.examples.ParamsToOuts(params="Test02") + + assert node01.name == "Grp01_node01" + assert node02.name == "Grp01_node02" + assert grp_node03.name == "Grp01_grp_node03" + + project.run() + + zntrack.from_rev(node01.name).outs == "Lorem Ipsum" + zntrack.from_rev(node02.name).outs == "Dolor Sit" + zntrack.from_rev(node03.name).outs == "Test01" + zntrack.from_rev(grp_node03.name).outs == "Test02" diff --git a/zntrack/core/node.py b/zntrack/core/node.py index f382ece9..814e8ff7 100644 --- a/zntrack/core/node.py +++ b/zntrack/core/node.py @@ -21,6 +21,7 @@ import znflow import zninit import znjson +from varname import VarnameException, varname from zntrack import exceptions from zntrack.notebooks.jupyter import jupyter_class_to_file @@ -161,8 +162,12 @@ def __set__(self, instance, value): if isinstance(value, NodeName): if not instance._external_: value.update_suffix(instance._graph_.project, instance) + with contextlib.suppress(VarnameException): + value.varname = varname(frame=4) instance._name_ = value elif isinstance(getattr(instance, "_name_"), NodeName): + with contextlib.suppress(VarnameException): + instance._name_.varname = varname(frame=4) instance._name_.name = value instance._name_.suffix = 0 instance._name_.update_suffix(instance._graph_.project, instance) diff --git a/zntrack/project/zntrack_project.py b/zntrack/project/zntrack_project.py index e1ea5c29..bfd824d8 100644 --- a/zntrack/project/zntrack_project.py +++ b/zntrack/project/zntrack_project.py @@ -82,6 +82,11 @@ class Project: This will require a DVC remote to be setup. force : bool, default = False overwrite existing nodes. + magic_names : bool, default = False + If True, use magic names for the nodes. This will use the variable name of the + node as the node name. E.g. `node = Node()` will result in a node name of 'node'. + If used within a group, the group name will be added to the node name. E.g. + `group.name = Grp1` and `model = Node()` will result in a name of 'Grp1_model'. """ graph: ZnTrackGraph = dataclasses.field(default_factory=ZnTrackGraph, init=False) @@ -90,6 +95,7 @@ class Project: automatic_node_names: bool = True git_only_repo: bool = True force: bool = False + magic_names: bool = False _groups: dict[str, NodeGroup] = dataclasses.field( default_factory=dict, init=False, repr=False @@ -116,6 +122,11 @@ def __post_init__(self): config.files.params.unlink(missing_ok=True) shutil.rmtree("nodes", ignore_errors=True) + if self.automatic_node_names and self.magic_names: + raise ValueError( + "automatic_node_names and magic_names can not be True at the same time" + ) + def __enter__(self, *args, **kwargs): """Enter the graph context.""" self.graph.__enter__(*args, **kwargs) diff --git a/zntrack/utils/__init__.py b/zntrack/utils/__init__.py index 62b77886..4bb823be 100644 --- a/zntrack/utils/__init__.py +++ b/zntrack/utils/__init__.py @@ -227,21 +227,28 @@ class NodeName: groups: list[str] name: str + varname: str = None suffix: int = 0 + use_varname: bool = False def __str__(self) -> str: """Get the node name.""" name = [] if self.groups is not None: name.extend(self.groups) - name.append(self.name) + if self.use_varname: + name.append(self.varname) + else: + name.append(self.name) + if self.suffix > 0 and self.use_varname: + raise ValueError("Suffixes are not supported for magic names (varnames).") if self.suffix > 0: name.append(str(self.suffix)) return "_".join(name) def get_name_without_groups(self) -> str: """Get the node name without the groups.""" - name = self.name + name = self.varname if self.use_varname else self.name if self.suffix > 0: name += f"_{self.suffix}" return name @@ -249,6 +256,7 @@ def get_name_without_groups(self) -> str: def update_suffix(self, project: "Project", node: "Node") -> None: """Update the suffix.""" node_names = [x["value"].name for x in project.graph.nodes.values()] + self.use_varname = project.magic_names node_names = [] for node_uuid in project.graph.nodes: