Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Combine zn.deps and zn.nodes to zntrack.deps #719

Merged
merged 19 commits into from
Sep 22, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 74 additions & 0 deletions tests/integration/test_zntrack_deps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
"""Tests for 'zntrack.deps'-field which can be used as both `zntrack.zn.deps` and `zntrack.zn.nodes`."""

import zntrack.examples

# TODO: change the parameters, rerun and see if it updates!
PythonFZ marked this conversation as resolved.
Show resolved Hide resolved


def test_as_deps(proj_path):
"""Test for 'zntrack.deps' acting as `zntrack.zn.deps`-like field."""
project = zntrack.Project(automatic_node_names=True)

with project:
a = zntrack.examples.ComputeRandomNumber(params_file="a.json")
b = zntrack.examples.ComputeRandomNumber(params_file="b.json")
c = zntrack.examples.SumRandomNumbers([a, b])

a.write_params(min=1, max=5, seed=42)
b.write_params(min=5, max=10, seed=42)

project.run()

a.load()
b.load()
c.load()

assert a.number == 1
assert b.number == 10
assert c.result == 11


def test_as_nodes(proj_path):
"""Test for 'zntrack.deps' acting as `zntrack.zn.nodes`-like field."""
project = zntrack.Project(automatic_node_names=True)

a = zntrack.examples.ComputeRandomNumber(params_file="a.json")
b = zntrack.examples.ComputeRandomNumber(params_file="b.json")

with project:
c = zntrack.examples.SumRandomNumbers([a, b])

a.write_params(min=1, max=5, seed=42)
b.write_params(min=5, max=10, seed=42)

project.run()

# TODO: good error messages when someone tries to load a node that is not on the graph
# a.load()
# b.load()
# assert a.number == 1
# assert b.number == 10

c.load()
assert c.result == 11


def test_mixed(proj_path):
project = zntrack.Project(automatic_node_names=True)

a = zntrack.examples.ComputeRandomNumber(params_file="a.json")

with project:
b = zntrack.examples.ComputeRandomNumber(params_file="b.json")
c = zntrack.examples.SumRandomNumbers([a, b])

a.write_params(min=1, max=5, seed=42)
b.write_params(min=5, max=10, seed=42)

project.run()

b.load()
c.load()

assert b.number == 10
assert c.result == 11
44 changes: 36 additions & 8 deletions zntrack/core/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import contextlib
import importlib
import importlib.util
import json
import pathlib
import sys
import tempfile
Expand All @@ -14,6 +15,7 @@
import dvc.stage

from zntrack.core.node import Node
from zntrack.utils import config

T = typing.TypeVar("T", bound=Node)

Expand Down Expand Up @@ -93,14 +95,40 @@
"""
if isinstance(name, Node):
name = name.name
stage = _get_stage(name, remote, rev)

cmd = stage.cmd
run_str = cmd.split()[2]
name = cmd.split()[4]

package_and_module, cls_name = run_str.rsplit(".", 1)
module = None
if "+" in name:
fs = dvc.api.DVCFileSystem(url=remote, rev=rev)

components = name.split("+")

Check warning on line 101 in zntrack/core/load.py

View check run for this annotation

Codecov / codecov/patch

zntrack/core/load.py#L101

Added line #L101 was not covered by tests

if len(components) == 3:
parent, attribute, key = components
else:
parent, attribute = components
key = None

Check warning on line 107 in zntrack/core/load.py

View check run for this annotation

Codecov / codecov/patch

zntrack/core/load.py#L107

Added line #L107 was not covered by tests

with fs.open(config.files.zntrack) as fs:
zntrack_config = json.load(fs)
data = zntrack_config[parent][attribute]

Check warning on line 111 in zntrack/core/load.py

View check run for this annotation

Codecov / codecov/patch

zntrack/core/load.py#L111

Added line #L111 was not covered by tests
if key is not None:
try:
data = data[int(key)]
except (ValueError, KeyError):
data = data[key]

Check warning on line 116 in zntrack/core/load.py

View check run for this annotation

Codecov / codecov/patch

zntrack/core/load.py#L114-L116

Added lines #L114 - L116 were not covered by tests
assert (
data["_type"] == "zntrack.Node"
), f"Expected zntrack.Node, got {data['_type']}"
package_and_module = data["value"]["module"]
cls_name = data["value"]["cls"]
module = None

Check warning on line 122 in zntrack/core/load.py

View check run for this annotation

Codecov / codecov/patch

zntrack/core/load.py#L120-L122

Added lines #L120 - L122 were not covered by tests
else:
stage = _get_stage(name, remote, rev)

cmd = stage.cmd
run_str = cmd.split()[2]
name = cmd.split()[4]

package_and_module, cls_name = run_str.rsplit(".", 1)
module = None
try:
module = importlib.import_module(package_and_module)
except ModuleNotFoundError:
Expand Down
6 changes: 4 additions & 2 deletions zntrack/core/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,10 +163,12 @@ def convert_notebook(cls, nb_name: str = None):
@property
def _init_descriptors_(self):
from zntrack import fields
from zntrack.fields.dependency import Dependency

return [
fields.zn.Params,
fields.zn.Dependency,
Dependency,
fields.meta.Text,
fields.meta.Environment,
fields.dvc.DVCOption,
Expand Down Expand Up @@ -367,7 +369,7 @@ def get_dvc_cmd(

@dataclasses.dataclass
class NodeIdentifier:
"""All information that uniquly identifies a node."""
"""All information that uniquely identifies a node."""

module: str
cls: str
Expand All @@ -378,7 +380,7 @@ class NodeIdentifier:
@classmethod
def from_node(cls, node: Node):
"""Create a _NodeIdentifier from a Node object."""
# TODO module and cls are not needed (from_rev can handle name, rev, remote only)
# TODO module and cls are only required for `zn.nodes`
return cls(
module=module_handler(node),
cls=node.__class__.__name__,
Expand Down
41 changes: 41 additions & 0 deletions zntrack/examples/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@

These nodes are primarily used for testing and demonstration purposes.
"""
import json
import pathlib
import random
import typing as t

import pandas as pd

import zntrack
Expand Down Expand Up @@ -120,3 +125,39 @@ class WriteDVCOuts(zntrack.Node):
def run(self):
"""Write an output file."""
self.outs.write_text(str(self.params))


class ComputeRandomNumber(zntrack.Node):
"""Compute a random number."""

params_file = zntrack.params_path()

number = zntrack.outs()

def _post_init_(self):
self.params_file = pathlib.Path(self.params_file)

def run(self):
"""Compute a random number."""
self.number = self.get_random_number()

def get_random_number(self):
"""Compute a random number."""
params = json.loads(self.params_file.read_text())
random.seed(params["seed"])
return random.randint(params["min"], params["max"])

def write_params(self, min, max, seed):
"""Write params to file."""
self.params_file.write_text(json.dumps({"min": min, "max": max, "seed": seed}))


class SumRandomNumbers(zntrack.Node):
"""Sum a list of random numbers."""

numbers: t.List[ComputeRandomNumber] = zntrack.deps()
result: int = zntrack.outs()

def run(self):
"""Sum a list of random numbers."""
self.result = sum(x.get_random_number() for x in self.numbers)
Loading
Loading