Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

introduce build(nodes=list) #654

Merged
merged 1 commit into from
Jul 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 77 additions & 3 deletions tests/integration/test_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,13 +173,13 @@ def test_automatic_node_names_True(tmp_path_2):

def test_group_nodes(tmp_path_2):
with zntrack.Project(automatic_node_names=True) as project:
with project.group():
with project.group() as group_1:
node_1 = WriteIO(inputs="Lorem Ipsum")
node_2 = WriteIO(inputs="Dolor Sit")
with project.group():
with project.group() as group_2:
node_3 = WriteIO(inputs="Amet Consectetur")
node_4 = WriteIO(inputs="Adipiscing Elit")
with project.group(name="NamedGrp"):
with project.group(name="NamedGrp") as group_3:
node_5 = WriteIO(inputs="Sed Do", name="NodeA")
node_6 = WriteIO(inputs="Eiusmod Tempor", name="NodeB")

Expand All @@ -189,6 +189,18 @@ def test_group_nodes(tmp_path_2):

project.run()

assert node_1 in group_1
assert node_2 in group_1
assert node_3 not in group_1
assert node_4 not in group_1
assert len(group_1) == 2
assert group_1.name == "Group1"

assert node_3 in group_2
assert node_4 in group_2
assert node_5 in group_3
assert node_6 in group_3

assert node_1.name == "Group1_WriteIO"
assert node_2.name == "Group1_WriteIO_1"
assert node_3.name == "Group2_WriteIO"
Expand All @@ -202,3 +214,65 @@ def test_group_nodes(tmp_path_2):
assert node9.name == "NodeC"

assert WriteIO.from_rev(name="NamedGrp_NodeA").inputs == "Sed Do"


def test_build_certain_nodes(tmp_path_2):
# TODO support passing groups to project.build
with zntrack.Project(automatic_node_names=True) as project:
node_1 = WriteIO(inputs="Lorem Ipsum")
node_2 = WriteIO(inputs="Dolor Sit")
project.build(nodes=[node_1, node_2])
project.repro()

assert zntrack.from_rev(node_1).outputs == "Lorem Ipsum"
assert zntrack.from_rev(node_2).outputs == "Dolor Sit"

node_1.inputs = "ABC"
node_2.inputs = "DEF"

project.build(nodes=[node_1])
project.repro()

assert zntrack.from_rev(node_1).outputs == "ABC"
assert zntrack.from_rev(node_2).outputs == "Dolor Sit"

project.run(nodes=[node_2])

assert zntrack.from_rev(node_1).outputs == "ABC"
assert zntrack.from_rev(node_2).outputs == "DEF"


def test_build_groups(tmp_path_2):
with zntrack.Project(automatic_node_names=True) as project:
with project.group() as group_1:
node_1 = WriteIO(inputs="Lorem Ipsum")
node_2 = WriteIO(inputs="Dolor Sit")
with project.group() as group_2:
node_3 = WriteIO(inputs="Amet Consectetur")
node_4 = WriteIO(inputs="Adipiscing Elit")

project.run(nodes=[group_1])

assert zntrack.from_rev(node_1).outputs == "Lorem Ipsum"
assert zntrack.from_rev(node_2).outputs == "Dolor Sit"

with pytest.raises(ValueError):
zntrack.from_rev(node_3)
with pytest.raises(ValueError):
zntrack.from_rev(node_4)

node_2.inputs = "DEF"

project.run(nodes=[group_2, node_2])

assert zntrack.from_rev(node_1).outputs == "Lorem Ipsum"

assert zntrack.from_rev(node_2).outputs == "DEF"
assert zntrack.from_rev(node_3).outputs == "Amet Consectetur"
assert zntrack.from_rev(node_4).outputs == "Adipiscing Elit"

with pytest.raises(TypeError):
project.run(nodes=42)

with pytest.raises(ValueError):
project.run(nodes=[42])
4 changes: 3 additions & 1 deletion zntrack/core/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def from_rev(name, remote=".", rev=None, **kwargs) -> T:

Parameters
----------
name : str
name : str|Node
The name of the node.
remote : str, optional
The remote to load the node from. Defaults to workspace.
Expand All @@ -91,6 +91,8 @@ def from_rev(name, remote=".", rev=None, **kwargs) -> T:
Node
The loaded node.
"""
if isinstance(name, Node):
name = name.name
stage = _get_stage(name, remote, rev)

cmd = stage.cmd
Expand Down
46 changes: 43 additions & 3 deletions zntrack/project/zntrack_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,13 +148,17 @@ def group(self, name: str = None):
self._groups.append(name)

existing_nodes = self.graph.get_sorted_nodes()

group = NodeGroup(name=name, nodes=[])

try:
yield
yield group
finally:
for node_uuid in self.graph.get_sorted_nodes():
node: Node = self.graph.nodes[node_uuid]["value"]
if node_uuid not in existing_nodes:
node.name = f"{name}_{node.name}"
group.nodes.append(node)

def run(
self,
Expand All @@ -163,6 +167,7 @@ def run(
optional: dict = None,
save: bool = True,
environment: dict = None,
nodes: list = None,
):
"""Run the Project Graph.

Expand All @@ -182,6 +187,8 @@ def run(
Possible arg_names are e.g. 'always_changed: True'
environment : dict, default = None
A dictionary of environment variables for all nodes.
nodes : list, default = None
A list of node names to run. If None, run all nodes.
"""
if not save and not eager:
raise ValueError("Save can only be false if eager is True")
Expand All @@ -191,8 +198,23 @@ def run(
if optional is None:
optional = {}

node_names = None
if nodes is not None:
node_names = []
for node in nodes:
if isinstance(node, str):
node_names.append(node)
elif isinstance(node, Node):
node_names.append(node.name)
elif isinstance(node, NodeGroup):
node_names.extend([x.name for x in node.nodes])
else:
raise ValueError(f"Unknown node type {type(node)}")

for node_uuid in self.graph.get_sorted_nodes():
node: Node = self.graph.nodes[node_uuid]["value"]
if node_names is not None and node.name not in node_names:
continue
if node._external_:
continue
if eager:
Expand All @@ -213,9 +235,11 @@ def run(
self.repro()
# TODO should we load the nodes here? Maybe, if lazy loading is implemented.

def build(self, environment: dict = None, optional: dict = None) -> None:
def build(
self, environment: dict = None, optional: dict = None, nodes: list = None
) -> None:
"""Build the project graph without running it."""
self.run(repro=False, environment=environment, optional=optional)
self.run(repro=False, environment=environment, optional=optional, nodes=nodes)

def repro(self) -> None:
"""Run dvc repro."""
Expand Down Expand Up @@ -385,3 +409,19 @@ def queue(self, name: str):
node = self.graph.nodes[node_uuid]["value"]
node.state.rev = name
active_branch.checkout()


@dataclasses.dataclass
class NodeGroup:
"""A group of nodes."""

name: str
nodes: list[Node]

def __contains__(self, item: Node) -> bool:
"""Check if the Node is in the group."""
return item in self.nodes

def __len__(self) -> int:
"""Get the number of nodes in the group."""
return len(self.nodes)