Skip to content

Commit

Permalink
Add with project.group() (#642)
Browse files Browse the repository at this point in the history
* add context manager for group

* update docstring

* update tests

* remove graph reference

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
PythonFZ and pre-commit-ci[bot] authored Jun 15, 2023
1 parent a8329ee commit 9d5c55f
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 0 deletions.
33 changes: 33 additions & 0 deletions tests/integration/test_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,3 +145,36 @@ def test_automatic_node_names_True(tmp_path_2):
assert node.outputs == "Hello World"
assert node2.outputs == "Lorem Ipsum"
assert node3.outputs == "Dolor Sit"


def test_group_nodes(tmp_path_2):
with zntrack.Project(automatic_node_names=True) as project:
with project.group():
node_1 = WriteIO(inputs="Lorem Ipsum")
node_2 = WriteIO(inputs="Dolor Sit")
with project.group():
node_3 = WriteIO(inputs="Amet Consectetur")
node_4 = WriteIO(inputs="Adipiscing Elit")
with project.group(name="NamedGrp"):
node_5 = WriteIO(inputs="Sed Do", name="NodeA")
node_6 = WriteIO(inputs="Eiusmod Tempor", name="NodeB")

node7 = WriteIO(inputs="Hello World")
node8 = WriteIO(inputs="How are you?")
node9 = WriteIO(inputs="I'm fine, thanks!", name="NodeC")

project.run()

assert node_1.name == "Group1_WriteIO"
assert node_2.name == "Group1_WriteIO_1"
assert node_3.name == "Group2_WriteIO"
assert node_4.name == "Group2_WriteIO_1"

assert node_5.name == "NamedGrp_NodeA"
assert node_6.name == "NamedGrp_NodeB"

assert node7.name == "WriteIO"
assert node8.name == "WriteIO_1"
assert node9.name == "NodeC"

assert WriteIO.from_rev(name="NamedGrp_NodeA").inputs == "Sed Do"
25 changes: 25 additions & 0 deletions zntrack/project/zntrack_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ class Project:
automatic_node_names: bool = False
force: bool = False

_groups: list = dataclasses.field(default_factory=list, init=False, repr=False)

def __post_init__(self):
"""Initialize the Project.
Expand Down Expand Up @@ -128,6 +130,29 @@ def update_node_names(self, check=True):
raise exceptions.DuplicateNodeNameError(node)
node_names.append(node.name)

@contextlib.contextmanager
def group(self, name: str = None):
"""Group nodes together.
Parameters
----------
name : str, optional
The name of the group. If None, the group will be named 'GroupX' where X is
the number of groups + 1.
"""
if name is None:
name = f"Group{len(self._groups) + 1}"
self._groups.append(name)

existing_nodes = self.graph.get_sorted_nodes()
try:
yield
finally:
for node_uuid in self.graph.get_sorted_nodes():
node: Node = self.graph.nodes[node_uuid]["value"]
if node_uuid not in existing_nodes:
node.name = f"{name}_{node.name}"

def run(
self,
eager=False,
Expand Down

0 comments on commit 9d5c55f

Please sign in to comment.