Skip to content

Commit

Permalink
Breaking Update Add UUID output to every Node (#647)
Browse files Browse the repository at this point in the history
* rename hash to UUID

* UUID outs for every node

* update cmd

* do not allow NWD as outs

* test error message

* remove commented out code

* working mess

* clean up a bit

* test correct dvc.yaml deps

* poetry update

* use uuid property
  • Loading branch information
PythonFZ authored Jun 28, 2023
1 parent e4479cb commit aa92f6c
Show file tree
Hide file tree
Showing 9 changed files with 295 additions and 258 deletions.
6 changes: 2 additions & 4 deletions examples/docs/02_Inheritance.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
Expand All @@ -433,10 +434,7 @@
"You might still want to use the other Nodes to avoid overhead though.\n",
"\n",
"In the following we will use the run method of a `NodeBase` Node and also have a dataclass Node just for storing parameters.\n",
"Internally, ZnTrack disables all outputs of the given Node.\n",
"To keep the DAG working, a `_hash = zn.Hash()` is introduced.\n",
"This value is computed from the parameters as well as the current timestamp and only serves as a file dependency for DVC.\n",
"Adding `zn.Hash()` to any Node will add an output file but won't have any additional effect."
"Internally, ZnTrack disables all outputs of the given Node except for a UUID file."
]
},
{
Expand Down
414 changes: 207 additions & 207 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ readme = "README.md"

[tool.poetry.dependencies]
python = ">=3.8,<4.0.0"
dvc = "^2.52.0, !=2.53.0"
dvc = "^3.2.2"
pyyaml = "^6.0"
tqdm = "^4.64.0"
pandas = "^1.4.3"
Expand Down
41 changes: 9 additions & 32 deletions tests/integration/test_node_nwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,32 +14,21 @@ def run(self):
self.file[0].write_text(self.text)


class OutsAsNWD(zntrack.Node):
text = zntrack.zn.params()
outs: pathlib.Path = zntrack.dvc.outs(zntrack.nwd)

def run(self):
(self.outs / "test.txt").write_text(self.text)

@property
def file(self):
return self.outs / "test.txt"


class FileToOuts(zntrack.Node):
# although, this is a file path, it has to be zn.deps
file = zntrack.zn.deps()
text = zntrack.zn.outs()

def run(self):
with open(self.file, "r") as f:
with open(self.file[0], "r") as f:
self.text = f.read()


@pytest.mark.parametrize("eager", [True, False])
def test_WriteToNWD(proj_path, eager):
with zntrack.Project() as project:
write_to_nwd = WriteToNWD(text="Hello World")
file_to_outs = FileToOuts(file=write_to_nwd.file)

project.run(eager=eager)
assert write_to_nwd.file[0].read_text() == "Hello World"
Expand All @@ -48,25 +37,13 @@ def test_WriteToNWD(proj_path, eager):
write_to_nwd.load()
assert write_to_nwd.__dict__["file"] == [pathlib.Path("$nwd$", "test.txt")]

file_to_outs.load()
assert file_to_outs.text == "Hello World"

@pytest.mark.parametrize("eager", [True, False])
def test_OutAsNWD(proj_path, eager):
with zntrack.Project() as project:
outs_as_nwd = OutsAsNWD(text="Hello World")

project.run(eager=eager)
assert (outs_as_nwd.outs / "test.txt").read_text() == "Hello World"
assert outs_as_nwd.outs == pathlib.Path("nodes", "OutsAsNWD")
if not eager:
outs_as_nwd.load()
assert outs_as_nwd.__dict__["outs"] == zntrack.nwd

def test_OutAsNWD(proj_path):
with pytest.raises(ValueError):

def test_FileToOuts(proj_path):
with zntrack.Project() as project:
write_to_nwd = OutsAsNWD(text="Hello World")
file_to_outs = FileToOuts(file=write_to_nwd.file)

project.run()
file_to_outs.load()
assert file_to_outs.text == "Hello World"
class OutsAsNWD(zntrack.Node):
text = zntrack.zn.params()
outs: pathlib.Path = zntrack.dvc.outs(zntrack.nwd)
32 changes: 32 additions & 0 deletions tests/unit_tests/test_zn_deps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import pathlib

import yaml

import zntrack


class NodeWithOuts(zntrack.Node):
    """Upstream Node that declares no fields of its own.

    The test below expects a dependent stage to list
    ``nodes/NodeWithOuts/uuid`` as its dependency, i.e. the UUID file is the
    only output this Node provides.
    """

    def run(self):
        """Do nothing; the Node only has to exist as a stage."""


class DependentNode(zntrack.Node):
    """Downstream Node whose single ``zn.deps`` field points at another Node."""

    # Receives the upstream Node instance (see ``test_DependentNode``).
    deps = zntrack.zn.deps()

    def run(self):
        """Do nothing; only the generated stage definition is inspected."""


def test_DependentNode(proj_path):
    """Check the ``dvc.yaml`` stage written for a Node depending on another Node.

    The dependent stage must be started through ``zntrack run`` and must list
    the UUID file of the upstream Node as its only dependency.
    """
    with zntrack.Project() as proj:
        upstream = NodeWithOuts()
        downstream = DependentNode(deps=upstream)  # noqa: F841 - registered in graph

    # Only write the stage definitions, do not reproduce them.
    proj.run(repro=False)

    stages = yaml.safe_load((proj_path / "dvc.yaml").read_text())["stages"]
    dependent_stage = stages["DependentNode"]

    expected_cmd = "zntrack run test_zn_deps.DependentNode --name DependentNode"
    assert dependent_stage["cmd"] == expected_cmd
    assert dependent_stage["deps"] == ["nodes/NodeWithOuts/uuid"]
8 changes: 3 additions & 5 deletions zntrack/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import os
import pathlib
import sys
import uuid

import git
import typer
Expand Down Expand Up @@ -47,7 +46,7 @@ def main(


@app.command()
def run(node: str, name: str = None, hash_only: bool = False) -> None:
def run(node: str, name: str = None, uuid_only: bool = False) -> None:
"""Execute a ZnTrack Node.
Use as 'zntrack run module.Node --name node_name'.
Expand Down Expand Up @@ -78,11 +77,10 @@ def run(node: str, name: str = None, hash_only: bool = False) -> None:
cls(exec_func=True)
elif issubclass(cls, Node):
node: Node = cls.from_rev(name=name, results=False)
if hash_only:
(node.nwd / "hash").write_text(str(uuid.uuid4()))
else:
if not uuid_only:
node.run()
node.save(parameter=False)
node.save(uuid_only=True)
else:
raise ValueError(f"Node {node} is not a ZnTrack Node.")

Expand Down
11 changes: 10 additions & 1 deletion zntrack/core/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,8 +166,14 @@ def nwd(self) -> pathlib.Path:
nwd.mkdir(parents=True)
return nwd

def save(self, parameter: bool = True, results: bool = True) -> None:
def save(
self, parameter: bool = True, results: bool = True, uuid_only: bool = False
) -> None:
"""Save the node's output to disk."""
if uuid_only:
(self.nwd / "uuid").write_text(str(self.uuid))
return

# TODO have an option to save and run dvc commit afterwards.

# TODO: check if there is a difference in saving
Expand Down Expand Up @@ -298,9 +304,12 @@ def get_dvc_cmd(
for attr in zninit.get_descriptors(Field, self=node):
field_cmds += attr.get_stage_add_argument(node)
optionals += attr.get_optional_dvc_cmd(node)

for field_cmd in set(field_cmds):
cmd += list(field_cmd)

cmd += ["--outs", f"nodes/{node.name}/uuid"]

module = module_handler(node.__class__)
cmd += [f"zntrack run {module}.{node.__class__.__name__} --name {node.name}"]
optionals = [x for x in optionals if x] # remove empty entries []
Expand Down
6 changes: 6 additions & 0 deletions zntrack/fields/dvc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@ class DVCOption(Field):

def __init__(self, *args, **kwargs):
    """Create a DVCOption field.

    Parameters
    ----------
    *args, **kwargs
        Forwarded to the parent ``Field`` constructor after ``dvc_option``
        has been removed from ``kwargs``.

    Raises
    ------
    ValueError
        If the bare ``zntrack.nwd`` placeholder is passed as a positional or
        keyword value. Only paths below it (``zntrack.nwd / ...``) are
        accepted.
    KeyError
        If the ``dvc_option`` keyword argument is missing.
    """
    if node_wd.nwd in args or node_wd.nwd in kwargs.values():
        # The field is not initialised yet (``super().__init__`` runs below),
        # so report the field class name instead of formatting ``self``.
        raise ValueError(
            f"Can not set `zntrack.nwd` as value for {type(self).__name__}."
            " Please use `zntrack.nwd/...` to create a path relative to the"
            " node working directory."
        )
    self.dvc_option = kwargs.pop("dvc_option")
    super().__init__(*args, **kwargs)

Expand Down
33 changes: 25 additions & 8 deletions zntrack/fields/zn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,22 @@ def get_stage_add_argument(self, instance) -> typing.List[tuple]:
_default = object()


def _get_all_connections_and_instances(value) -> typing.List["Node"]:
    """Collect the instances behind a Connection or CombinedConnections.

    ``CombinedConnections`` are expanded into their ``connections``; for every
    ``Connection`` the ``instance`` chain is followed until it no longer points
    at another ``Connection``. Values of any other type are ignored.

    Parameters
    ----------
    value
        A ``znflow.Connection``, a ``znflow.CombinedConnections`` or any other
        object.

    Returns
    -------
    typing.List[Node]
        The resolved instances. Because pending items are taken from the end
        of the work list, members of a ``CombinedConnections`` are returned in
        reverse order. Empty if ``value`` is neither kind of connection.
    """
    instances = []
    pending = [value]
    while pending:
        item = pending.pop()
        if isinstance(item, znflow.CombinedConnections):
            pending.extend(item.connections)
        elif isinstance(item, znflow.Connection):
            # A Connection may wrap another Connection; unwrap to the source.
            instance = item.instance
            while isinstance(instance, znflow.Connection):
                instance = instance.instance
            instances.append(instance)
    return instances


class Dependency(LazyField):
"""A dependency field."""

Expand Down Expand Up @@ -317,16 +333,17 @@ def get_files(self, instance) -> list:

others = []
for node in value:
if isinstance(node, znflow.CombinedConnections):
others.extend(node.connections)
if isinstance(node, (znflow.CombinedConnections, znflow.Connection)):
others.extend(_get_all_connections_and_instances(node))
else:
others.append(node)

value.extend(others)
value = others

for node in value:
if node is None:
continue
if isinstance(node, znflow.Connection):
node = node.instance
files.append(node.nwd / "uuid")
for field in zninit.get_descriptors(Field, self=node):
if field.dvc_option in ["params", "deps"]:
# We do not want to depend on parameter files or
Expand Down Expand Up @@ -467,7 +484,7 @@ def get_optional_dvc_cmd(self, instance: "Node") -> typing.List[list]:
name,
"--force",
"--outs",
f"nodes/{name}/hash",
f"nodes/{name}/uuid",
"--params",
f"zntrack.json:{instance.name}.{self.name}",
]
Expand All @@ -480,7 +497,7 @@ def get_optional_dvc_cmd(self, instance: "Node") -> typing.List[list]:

_cmd += [
f"zntrack run {module}.{node.__class__.__name__} --name"
f" {name} --hash-only"
f" {name} --uuid-only"
]

cmd.append(_cmd)
Expand All @@ -490,7 +507,7 @@ def get_optional_dvc_cmd(self, instance: "Node") -> typing.List[list]:
def get_files(self, instance: "Node") -> list:
"""Get the files affected by this field."""
return [
pathlib.Path(f"nodes/{name}/hash") for name in self.get_node_names(instance)
pathlib.Path(f"nodes/{name}/uuid") for name in self.get_node_names(instance)
]


Expand Down

0 comments on commit aa92f6c

Please sign in to comment.