Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Breaking Update Add UUID output to every Node #647

Merged
merged 11 commits into from
Jun 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions examples/docs/02_Inheritance.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
Expand All @@ -433,10 +434,7 @@
"You might still want to use the other Nodes to avoid overhead though.\n",
"\n",
"In the following we will use the run method of a `NodeBase` Node and also have a dataclass Node just for storing parameters.\n",
"Internally, ZnTrack disables all outputs of the given Node.\n",
"To keep the DAG working, a `_hash = zn.Hash()` is introduced.\n",
"This value is computed from the parameters as well as the current timestamp and only serves as a file dependency for DVC.\n",
"Adding `zn.Hash()` to any Node will add an output file but won't have any additional effect."
"Internally, ZnTrack disables all outputs of the given Node except for a UUID file."
]
},
{
Expand Down
414 changes: 207 additions & 207 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ readme = "README.md"

[tool.poetry.dependencies]
python = ">=3.8,<4.0.0"
dvc = "^2.52.0, !=2.53.0"
dvc = "^3.2.2"
pyyaml = "^6.0"
tqdm = "^4.64.0"
pandas = "^1.4.3"
Expand Down
41 changes: 9 additions & 32 deletions tests/integration/test_node_nwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,32 +14,21 @@ def run(self):
self.file[0].write_text(self.text)


class OutsAsNWD(zntrack.Node):
text = zntrack.zn.params()
outs: pathlib.Path = zntrack.dvc.outs(zntrack.nwd)

def run(self):
(self.outs / "test.txt").write_text(self.text)

@property
def file(self):
return self.outs / "test.txt"


class FileToOuts(zntrack.Node):
# although, this is a file path, it has to be zn.deps
file = zntrack.zn.deps()
text = zntrack.zn.outs()

def run(self):
with open(self.file, "r") as f:
with open(self.file[0], "r") as f:
self.text = f.read()


@pytest.mark.parametrize("eager", [True, False])
def test_WriteToNWD(proj_path, eager):
with zntrack.Project() as project:
write_to_nwd = WriteToNWD(text="Hello World")
file_to_outs = FileToOuts(file=write_to_nwd.file)

project.run(eager=eager)
assert write_to_nwd.file[0].read_text() == "Hello World"
Expand All @@ -48,25 +37,13 @@ def test_WriteToNWD(proj_path, eager):
write_to_nwd.load()
assert write_to_nwd.__dict__["file"] == [pathlib.Path("$nwd$", "test.txt")]

file_to_outs.load()
assert file_to_outs.text == "Hello World"

@pytest.mark.parametrize("eager", [True, False])
def test_OutAsNWD(proj_path, eager):
with zntrack.Project() as project:
outs_as_nwd = OutsAsNWD(text="Hello World")

project.run(eager=eager)
assert (outs_as_nwd.outs / "test.txt").read_text() == "Hello World"
assert outs_as_nwd.outs == pathlib.Path("nodes", "OutsAsNWD")
if not eager:
outs_as_nwd.load()
assert outs_as_nwd.__dict__["outs"] == zntrack.nwd

def test_OutAsNWD(proj_path):
with pytest.raises(ValueError):

def test_FileToOuts(proj_path):
with zntrack.Project() as project:
write_to_nwd = OutsAsNWD(text="Hello World")
file_to_outs = FileToOuts(file=write_to_nwd.file)

project.run()
file_to_outs.load()
assert file_to_outs.text == "Hello World"
class OutsAsNWD(zntrack.Node):
text = zntrack.zn.params()
outs: pathlib.Path = zntrack.dvc.outs(zntrack.nwd)
32 changes: 32 additions & 0 deletions tests/unit_tests/test_zn_deps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import pathlib

import yaml

import zntrack


class NodeWithOuts(zntrack.Node):
    """Node that declares no fields of its own; used as the upstream Node."""

    def run(self):
        """Do nothing."""
        return None


class DependentNode(zntrack.Node):
    """Node with a single ``zn.deps`` field and a no-op ``run`` method."""

    deps = zntrack.zn.deps()

    def run(self):
        """Do nothing."""
        return None


def test_DependentNode(proj_path):
    """Check the ``dvc.yaml`` stage written for a Node that uses ``zn.deps``.

    The graph is built without reproducing it (``repro=False``); afterwards
    the generated ``dvc.yaml`` is inspected for the stage command and for the
    dependency on the upstream Node's ``uuid`` file.
    """
    with zntrack.Project() as proj:
        upstream = NodeWithOuts()
        downstream = DependentNode(deps=upstream)  # noqa: F841

    proj.run(repro=False)

    dvc_file = proj_path / "dvc.yaml"
    stage = yaml.safe_load(dvc_file.read_text())["stages"]["DependentNode"]

    expected_cmd = "zntrack run test_zn_deps.DependentNode --name DependentNode"
    assert stage["cmd"] == expected_cmd
    assert stage["deps"] == ["nodes/NodeWithOuts/uuid"]
8 changes: 3 additions & 5 deletions zntrack/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import os
import pathlib
import sys
import uuid

import git
import typer
Expand Down Expand Up @@ -47,7 +46,7 @@ def main(


@app.command()
def run(node: str, name: str = None, hash_only: bool = False) -> None:
def run(node: str, name: str = None, uuid_only: bool = False) -> None:
"""Execute a ZnTrack Node.
Use as 'zntrack run module.Node --name node_name'.
Expand Down Expand Up @@ -78,11 +77,10 @@ def run(node: str, name: str = None, hash_only: bool = False) -> None:
cls(exec_func=True)
elif issubclass(cls, Node):
node: Node = cls.from_rev(name=name, results=False)
if hash_only:
(node.nwd / "hash").write_text(str(uuid.uuid4()))
else:
if not uuid_only:
node.run()
node.save(parameter=False)
node.save(uuid_only=True)
else:
raise ValueError(f"Node {node} is not a ZnTrack Node.")

Expand Down
11 changes: 10 additions & 1 deletion zntrack/core/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,8 +166,14 @@ def nwd(self) -> pathlib.Path:
nwd.mkdir(parents=True)
return nwd

def save(self, parameter: bool = True, results: bool = True) -> None:
def save(
self, parameter: bool = True, results: bool = True, uuid_only: bool = False
) -> None:
"""Save the node's output to disk."""
if uuid_only:
(self.nwd / "uuid").write_text(str(self.uuid))
return

# TODO have an option to save and run dvc commit afterwards.

# TODO: check if there is a difference in saving
Expand Down Expand Up @@ -298,9 +304,12 @@ def get_dvc_cmd(
for attr in zninit.get_descriptors(Field, self=node):
field_cmds += attr.get_stage_add_argument(node)
optionals += attr.get_optional_dvc_cmd(node)

for field_cmd in set(field_cmds):
cmd += list(field_cmd)

cmd += ["--outs", f"nodes/{node.name}/uuid"]

module = module_handler(node.__class__)
cmd += [f"zntrack run {module}.{node.__class__.__name__} --name {node.name}"]
optionals = [x for x in optionals if x] # remove empty entries []
Expand Down
6 changes: 6 additions & 0 deletions zntrack/fields/dvc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@ class DVCOption(Field):

def __init__(self, *args, **kwargs):
    """Create a DVCOption field.

    Parameters
    ----------
    *args
        Positional arguments forwarded to the parent ``__init__``.
    **kwargs
        Keyword arguments forwarded to the parent ``__init__``. Must contain
        ``dvc_option``, which is removed from ``kwargs`` and stored on the
        field before forwarding.

    Raises
    ------
    ValueError
        If ``zntrack.nwd`` itself is passed as a positional or keyword value.
    KeyError
        If ``dvc_option`` is not given as a keyword argument.
    """
    if node_wd.nwd in args or node_wd.nwd in kwargs.values():
        # The field is not initialised yet, so only the class name is used
        # in the message instead of ``repr(self)``.
        raise ValueError(
            f"Can not set `zntrack.nwd` as value for {type(self).__name__}."
            " Please use `zntrack.nwd/...` to create a path relative to the"
            " node working directory."
        )
    self.dvc_option = kwargs.pop("dvc_option")
    super().__init__(*args, **kwargs)

Expand Down
33 changes: 25 additions & 8 deletions zntrack/fields/zn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,22 @@ def get_stage_add_argument(self, instance) -> typing.List[tuple]:
_default = object()


def _get_all_connections_and_instances(value) -> typing.List["Node"]:
    """Get Nodes from Connections and CombinedConnections.

    ``znflow.CombinedConnections`` are expanded into their ``connections``
    and every ``znflow.Connection`` is resolved to its ``instance``. Nested
    connections are followed until the instance is no longer a Connection.

    Parameters
    ----------
    value
        The object to resolve. Anything that is neither a
        ``znflow.Connection`` nor a ``znflow.CombinedConnections`` is ignored.

    Returns
    -------
    list
        The resolved instances. Items are taken from the end of the work
        list, so members of a ``CombinedConnections`` come out in reverse
        order.
    """
    # ``typing.List`` is used in the annotation (as elsewhere in this module)
    # because ``list[...]`` cannot be subscripted at definition time on
    # Python 3.8.
    instances = []
    pending = [value]
    while pending:
        item = pending.pop()
        if isinstance(item, znflow.CombinedConnections):
            pending.extend(item.connections)
        elif isinstance(item, znflow.Connection):
            instance = item.instance
            # A Connection may point to another Connection; unwrap them all.
            while isinstance(instance, znflow.Connection):
                instance = instance.instance
            instances.append(instance)
    return instances


class Dependency(LazyField):
"""A dependency field."""

Expand Down Expand Up @@ -317,16 +333,17 @@ def get_files(self, instance) -> list:

others = []
for node in value:
if isinstance(node, znflow.CombinedConnections):
others.extend(node.connections)
if isinstance(node, (znflow.CombinedConnections, znflow.Connection)):
others.extend(_get_all_connections_and_instances(node))
else:
others.append(node)

value.extend(others)
value = others

for node in value:
if node is None:
continue
if isinstance(node, znflow.Connection):
node = node.instance
files.append(node.nwd / "uuid")
for field in zninit.get_descriptors(Field, self=node):
if field.dvc_option in ["params", "deps"]:
# We do not want to depend on parameter files or
Expand Down Expand Up @@ -467,7 +484,7 @@ def get_optional_dvc_cmd(self, instance: "Node") -> typing.List[list]:
name,
"--force",
"--outs",
f"nodes/{name}/hash",
f"nodes/{name}/uuid",
"--params",
f"zntrack.json:{instance.name}.{self.name}",
]
Expand All @@ -480,7 +497,7 @@ def get_optional_dvc_cmd(self, instance: "Node") -> typing.List[list]:

_cmd += [
f"zntrack run {module}.{node.__class__.__name__} --name"
f" {name} --hash-only"
f" {name} --uuid-only"
]

cmd.append(_cmd)
Expand All @@ -490,7 +507,7 @@ def get_optional_dvc_cmd(self, instance: "Node") -> typing.List[list]:
def get_files(self, instance: "Node") -> list:
"""Get the files affected by this field."""
return [
pathlib.Path(f"nodes/{name}/hash") for name in self.get_node_names(instance)
pathlib.Path(f"nodes/{name}/uuid") for name in self.get_node_names(instance)
]


Expand Down