Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

628 reconstruct node from dvcyaml #629

Merged
merged 12 commits into from
May 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ repos:
rev: 'v0.0.269'
hooks:
- id: ruff
args: ['--fix']
- repo: https://github.com/executablebooks/mdformat
rev: 0.7.16
hooks:
Expand Down
19 changes: 18 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,26 @@ Node object.

```python
hello_world.load()
print(hello_world.random_numer)
print(hello_world.random_number)
```

> ## Tip
>
> You can easily load this Node directly from a repository.
>
> ```python
> import zntrack
>
> node = zntrack.from_rev(
> "HelloWorld",
> remote="https://github.com/PythonFZ/ZnTrackExamples.git",
> rev="b9316bf",
> )
> ```
>
> Try accessing the `max_number` parameter and `random_number` output. All Nodes
> from this and many other repositories can be loaded like this.

An overview of all the ZnTrack features as well as more detailed examples can be
found in the [ZnTrack Documentation](https://zntrack.readthedocs.io/en/latest/).

Expand Down
33 changes: 33 additions & 0 deletions tests/integration/test_from_rev.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import dvc.scm
import pytest

import zntrack


def test_module_not_installed():
with pytest.raises(ModuleNotFoundError):
zntrack.from_rev(
"ASEMD",
remote="https://github.com/IPSProjects/IPS-Water",
rev="ca0eef0ccfcbfb72a82136849a9ca35eac8b7629",
)


def test_commit_not_found():
with pytest.raises(dvc.scm.RevError):
zntrack.from_rev(
"ASEMD",
remote="https://github.com/IPSProjects/IPS-Water",
rev="this-does-not-exist",
)


def test_import_from_remote():
node = zntrack.from_rev(
"HelloWorld",
remote="https://github.com/PythonFZ/ZnTrackExamples.git",
rev="b9316bf",
)
assert node.max_number == 512
assert node.random_number == 64
assert node.name == "HelloWorld"
6 changes: 6 additions & 0 deletions tests/integration/test_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,12 @@ def test_WriteIO_no_name(tmp_path_2, assert_before_exp):
assert exp2["WriteIO"].inputs == "Lorem Ipsum"
assert exp2["WriteIO"].outputs == "Lorem Ipsum"

assert zntrack.from_rev("WriteIO", rev=exp1.name).inputs == "Hello World"
assert zntrack.from_rev("WriteIO", rev=exp1.name).outputs == "Hello World"

assert zntrack.from_rev("WriteIO", rev=exp2.name).inputs == "Lorem Ipsum"
assert zntrack.from_rev("WriteIO", rev=exp2.name).outputs == "Lorem Ipsum"


def test_project_remove_graph(proj_path):
with zntrack.Project() as project:
Expand Down
15 changes: 15 additions & 0 deletions tests/integration/test_single_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,21 @@ def test_AddNumbers_remove_params(proj_path):
add_numbers.load()


def test_znrack_from_rev(proj_path):
with zntrack.Project() as project:
add_numbers = AddNumbers(a=1, b=2)

assert not add_numbers.state.loaded

project.run()

node = zntrack.from_rev(add_numbers.name)
assert node.a == 1
assert node.b == 2
assert node.c == 3
assert node.state.loaded


@pytest.mark.parametrize("eager", [True, False])
def test_AddNumbers_named(proj_path, eager):
with zntrack.Project() as project:
Expand Down
2 changes: 2 additions & 0 deletions zntrack/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import importlib.metadata

from zntrack import exceptions, tools
from zntrack.core.load import from_rev
from zntrack.core.node import Node
from zntrack.core.nodify import NodeConfig, nodify
from zntrack.fields import Field, FieldGroup, LazyField, dvc, meta, zn
Expand All @@ -29,4 +30,5 @@
"NodeConfig",
"tools",
"exceptions",
"from_rev",
]
117 changes: 117 additions & 0 deletions zntrack/core/load.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
"""Load a node from a dvc stage."""

import contextlib
import importlib
import importlib.util
import pathlib
import sys
import tempfile
import typing
import uuid

import dvc.api
import dvc.repo
import dvc.stage

from zntrack.core.node import Node

T = typing.TypeVar("T", bound=Node)


def _get_stage(name, remote, rev) -> dvc.stage.PipelineStage:
"""Get a stage from a dvc.Repo."""
with dvc.repo.Repo.open(url=remote, rev=rev) as repo:
for stage in repo.index.stages:
with contextlib.suppress(AttributeError):
# non pipeline stage don't have name
if stage.name == name:
return stage

raise ValueError(
f"Stage {name} not found in {remote}" + (f"/tree/{rev}" if rev else "")
)


def _import_from_tempfile(package_and_module: str, remote, rev):
"""Create a temporary file to import from.

Parameters
----------
package_and_module : str
The package and module to import, e.g. "zntrack.core.node.Node".
remote : str
The remote to load the module from.
rev : str
The revision to load the module from.

Returns
-------
ModuleType
The imported module.

Raises
------
ModuleNotFoundError
If the module could not be found.
FileNotFoundError
If the file could not be found.
"""
file = pathlib.Path(*package_and_module.split(".")).with_suffix(".py")
fs = dvc.api.DVCFileSystem(url=remote, rev=rev)
with tempfile.NamedTemporaryFile(suffix=".py") as temp_file, fs.open(file) as f:
temp_file.write(f.read())
temp_file.flush()

# we use a random uuid to avoid name clashes
ref_module = f"{uuid.uuid4()}.{package_and_module}"

spec = importlib.util.spec_from_file_location(ref_module, temp_file.name)
module = importlib.util.module_from_spec(spec)
sys.modules[ref_module] = module
spec.loader.exec_module(module)
return module


def from_rev(name, remote=".", rev=None, **kwargs) -> T:
"""Load a ZnTrack Node by its name.

Parameters
----------
name : str
The name of the node.
remote : str, optional
The remote to load the node from. Defaults to workspace.
rev : str, optional
The revision to load the node from. Defaults to HEAD.
**kwargs
Additional keyword arguments to pass to the node's constructor.

Returns
-------
Node
The loaded node.
"""
stage = _get_stage(name, remote, rev)

cmd = stage.cmd
run_str = cmd.split()[2]
name = cmd.split()[4]

package_and_module, cls_name = run_str.rsplit(".", 1)
module = None
try:
module = importlib.import_module(package_and_module)
except ModuleNotFoundError:
with contextlib.suppress(FileNotFoundError, ModuleNotFoundError):
module = _import_from_tempfile(package_and_module, remote, rev)

if module is None:
module_name = package_and_module.split(".")[0]
raise ModuleNotFoundError(
f"No module named '{module_name}'. The package might be available via 'pip"
f" install {module_name}' or from the remote via 'pip install git+{remote}'."
)

cls = getattr(module, cls_name)

return cls.from_rev(name, remote, rev, **kwargs)