Skip to content

Commit

Permalink
update znflow; use znflow groups
Browse files Browse the repository at this point in the history
  • Loading branch information
PythonFZ committed Sep 21, 2023
1 parent 9ab4d52 commit 9104287
Show file tree
Hide file tree
Showing 6 changed files with 238 additions and 195 deletions.
248 changes: 125 additions & 123 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ typer = "^0.7.0"
dot4dict = "^0.1.1"
zninit = "^0.1.9"
znjson = "^0.2.2"
znflow = "^0.1.13"
znflow = "^0.1.14"


[tool.poetry.urls]
Expand Down
19 changes: 12 additions & 7 deletions tests/integration/test_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,10 @@ def test_group_nodes(tmp_path_2):
assert node_5 in group_3
assert node_6 in group_3

assert group_1.name == "Group1"
assert group_2.name == "Group2"
assert group_3.name == "NamedGrp"

assert node_1.name == "Group1_ParamsToOuts"
assert node_2.name == "Group1_ParamsToOuts_1"
assert node_3.name == "Group2_ParamsToOuts"
Expand Down Expand Up @@ -373,13 +377,14 @@ def test_groups_nwd_zn_nodes(tmp_path_2):
assert node_3.result == "Lorem Ipsum"


def test_test_reopening_groups(proj_path):
with zntrack.Project(automatic_node_names=True) as project:
with project.group("GroupA"):
node_1 = zntrack.examples.ParamsToOuts(params="Lorem Ipsum")
with pytest.raises(ValueError):
with project.group("GroupA"):
node_2 = zntrack.examples.ParamsToOuts(params="Dolor Sit")
# This is allowed now
# def test_test_reopening_groups(proj_path):
# with zntrack.Project(automatic_node_names=True) as project:
# with project.group("GroupA"):
# node_1 = zntrack.examples.ParamsToOuts(params="Lorem Ipsum")
# with pytest.raises(ValueError):
# with project.group("GroupA"):
# node_2 = zntrack.examples.ParamsToOuts(params="Dolor Sit")


# def test_reopening_groups(proj_path):
Expand Down
18 changes: 14 additions & 4 deletions zntrack/core/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

from zntrack import exceptions
from zntrack.notebooks.jupyter import jupyter_class_to_file
from zntrack.utils import NodeStatusResults, config, file_io, module_handler
from zntrack.utils import NodeName, NodeStatusResults, config, file_io, module_handler

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -104,13 +104,23 @@ def __get__(self, instance, owner=None):
if instance is None:
return self
if getattr(instance, "_name_") is None:
instance._name_ = instance.__class__.__name__
return getattr(instance, "_name_")
return instance.__class__.__name__
return str(getattr(instance, "_name_"))

def __set__(self, instance, value):
if value is None:
return
instance._name_ = value
if isinstance(value, NodeName):
value.update_suffix(instance._graph_.project, instance)
instance._name_ = value
elif isinstance(getattr(instance, "_name_"), NodeName):
instance._name_.name = value
instance._name_.suffix = 0
instance._name_.update_suffix(instance._graph_.project, instance)
else:
# This should only happen if an instance is loaded.
instance._name_ = value
# name is set from auto_init, can we use in_construction?


class Node(zninit.ZnInit, znflow.Node):
Expand Down
101 changes: 41 additions & 60 deletions zntrack/project/zntrack_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,11 @@
import git
import yaml
import znflow
from znflow.base import empty, get_graph
from znflow.handler import UpdateConnectors

from zntrack import exceptions
from zntrack.core.node import Node, get_dvc_cmd
from zntrack.utils import config, run_dvc_cmd
from zntrack.utils import NodeName, config, run_dvc_cmd

log = logging.getLogger(__name__)

Expand All @@ -46,14 +45,16 @@ class ZnTrackGraph(znflow.DiGraph):

project: Project = None

def add_node(self, node_for_adding, **attr):
def add_node(self, node_for_adding: Node, **attr):
"""Rename Nodes if required."""
value = super().add_node(node_for_adding, **attr)
self.project.update_node_names(check=False)
# this is called in __new__ and therefore,
# the name might not be set correctly.
# update node names only works, if name is not set.
return value
if self.active_group is not None:
name = NodeName(self.active_group, node_for_adding.name)
else:
name = NodeName(None, node_for_adding.name)

node_for_adding.name = name

super().add_node(node_for_adding, **attr)


@dataclasses.dataclass
Expand Down Expand Up @@ -82,14 +83,16 @@ class Project:
overwrite existing nodes.
"""

graph: znflow.DiGraph = dataclasses.field(default_factory=ZnTrackGraph, init=False)
graph: ZnTrackGraph = dataclasses.field(default_factory=ZnTrackGraph, init=False)
initialize: bool = True
remove_existing_graph: bool = False
automatic_node_names: bool = False
git_only_repo: bool = True
force: bool = False

_groups: list = dataclasses.field(default_factory=list, init=False, repr=False)
_groups: dict[str, NodeGroup] = dataclasses.field(
default_factory=dict, init=False, repr=False
)

def __post_init__(self):
"""Initialize the Project.
Expand Down Expand Up @@ -120,25 +123,16 @@ def __enter__(self, *args, **kwargs):
def __exit__(self, *args, **kwargs):
"""Exit the graph context."""
self.graph.__exit__(*args, **kwargs)
self.update_node_names()

def update_node_names(self, check=True):
"""Update the node names to be unique."""
node_names = []
for node_uuid in self.graph.get_sorted_nodes():
node: Node = self.graph.nodes[node_uuid]["value"]
for node_uuid in self.graph.nodes:
node = self.graph.nodes[node_uuid]["value"]
if node._external_:
continue

if node.name in node_names:
if node._external_:
continue
if self.automatic_node_names:
idx = 1
while f"{node.name}_{idx}" in node_names:
idx += 1
node.name = f"{node.name}_{idx}"
log.debug(f"Updating {node.name = }")

elif not self.force and check:
raise exceptions.DuplicateNodeNameError(node)
raise exceptions.DuplicateNodeNameError(node)

node_names.append(node.name)

@contextlib.contextmanager
Expand All @@ -152,44 +146,30 @@ def group(self, *names: typing.List[str]):
the number of groups + 1. If more than one name is given, the groups will
be nested to 'nwd = name[0]/name[1]/.../name[-1]'
"""
if len(names) == 0:
# names = (f"Group{len(self._groups) + 1}",)
name = "Group1"
while pathlib.Path("nodes", name).exists():
name = f"Group{int(name[5:]) + 1}"
names = (name,)

@contextlib.contextmanager
def _get_group(names):
if len(names) == 0:
# names = (f"Group{len(self._groups) + 1}",)
name = "Group1"
while pathlib.Path("nodes", name).exists():
name = f"Group{int(name[5:]) + 1}"
names = (name,)

try:
grp = self._groups[names]
except KeyError:
nwd = pathlib.Path("nodes", *names)
if any(x.nwd == nwd for x in self._groups):
raise ValueError(f"Group {names} already exists.")

nwd.mkdir(parents=True, exist_ok=True)
grp = NodeGroup(name="_".join(names), nwd=nwd, nodes=[])
self._groups[names] = grp

existing_nodes = self.graph.get_sorted_nodes()

group = NodeGroup(nwd=nwd, nodes=[])
with self.graph.group(names):
yield grp
# TODO: do we even need the group object?
# self.graph.get_group() should be sufficient.
grp.nodes = [self.graph.nodes[x]["value"] for x in self.graph.get_group(names)]
# we need to update the nwd for all new nodes

try:
yield group
finally:
for node_uuid in self.graph.get_sorted_nodes():
node: Node = self.graph.nodes[node_uuid]["value"]
if node_uuid not in existing_nodes:
node.__dict__["nwd"] = group.nwd / node.name
node.name = f"{'_'.join(names)}_{node.name}"
group.nodes.append(node)
self._groups.append(group)

if get_graph() is not empty:
with _get_group(names) as group:
yield group
else:
with self:
with _get_group(names) as group:
yield group
for node in grp.nodes:
node.__dict__["nwd"] = grp.nwd / node._name_.get_name_without_groups()

def run(
self,
Expand Down Expand Up @@ -448,6 +428,7 @@ def queue(self, name: str):
class NodeGroup:
"""A group of nodes."""

name: tuple[str]
nwd: pathlib.Path
nodes: list[Node]

Expand Down
45 changes: 45 additions & 0 deletions zntrack/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
"""Standard python init file for the utils directory."""
import dataclasses
import enum
import logging
import os
import pathlib
import shutil
import sys
import tempfile
import typing as t

import dvc.cli

Expand All @@ -18,6 +20,9 @@
"config",
]

if t.TYPE_CHECKING:
from zntrack import Node, Project


class LazyOption:
"""Indicates that the value of the field should is loaded lazily."""
Expand Down Expand Up @@ -204,3 +209,43 @@ def cwd_temp_dir(required_files=None) -> tempfile.TemporaryDirectory:
os.chdir(temp_dir.name)

return temp_dir


@dataclasses.dataclass
class NodeName:
    """The name of a node, built from group names, a base name and a suffix.

    The string form joins all parts with underscores, e.g.
    ``Group1_ParamsToOuts_1``.

    Attributes
    ----------
    groups : list[str] or None
        Names of the groups the node belongs to, in the order they are
        prepended to the name. ``None`` if the node is not in a group.
    name : str
        The base name of the node.
    suffix : int
        Counter appended to make the name unique. A value of ``0`` means
        that no suffix is appended.
    """

    groups: t.Optional[t.List[str]]
    name: str
    suffix: int = 0

    def __str__(self) -> str:
        """Get the full node name including groups and suffix."""
        parts = [] if self.groups is None else list(self.groups)
        # Reuse the group-less name so both representations always agree
        # on how the suffix is attached.
        parts.append(self.get_name_without_groups())
        return "_".join(parts)

    def get_name_without_groups(self) -> str:
        """Get the node name without the groups.

        The suffix is separated by an underscore, exactly as in ``str(self)``,
        so that ``str(self)`` always ends with the value returned here.
        """
        if self.suffix > 0:
            return f"{self.name}_{self.suffix}"
        return self.name

    def update_suffix(self, project: "Project", node: "Node") -> None:
        """Increase the suffix until the name is unique within the project.

        Nothing happens unless ``project.automatic_node_names`` is set.

        Parameters
        ----------
        project : Project
            The project whose graph provides the names that are already taken.
        node : Node
            The node this name belongs to. It is skipped when collecting the
            taken names, so a node never collides with itself.
        """
        if not project.automatic_node_names:
            return

        # A set gives constant-time membership checks in the loop below.
        taken_names = {
            entry["value"].name
            for node_uuid, entry in project.graph.nodes.items()
            if node_uuid != node.uuid
        }
        while str(self) in taken_names:
            self.suffix += 1

0 comments on commit 9104287

Please sign in to comment.