Skip to content

Commit

Permalink
628 reconstruct node from dvcyaml (#629)
Browse files Browse the repository at this point in the history
* add from_rev

* add test

* test from_rev for experiments

* test module / wrong rev

* add docstrings

* try loading from temporary file

* remove typehint

* add uuid

* bugfix

* revert changes

* update readme
  • Loading branch information
PythonFZ authored May 26, 2023
1 parent 644ed74 commit a8329ee
Show file tree
Hide file tree
Showing 7 changed files with 192 additions and 1 deletion.
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ repos:
rev: 'v0.0.269'
hooks:
- id: ruff
args: ['--fix']
- repo: https://github.com/executablebooks/mdformat
rev: 0.7.16
hooks:
Expand Down
19 changes: 18 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,26 @@ Node object.

```python
hello_world.load()
print(hello_world.random_numer)
print(hello_world.random_number)
```

> ## Tip
>
> You can easily load this Node directly from a repository.
>
> ```python
> import zntrack
>
> node = zntrack.from_rev(
> "HelloWorld",
> remote="https://github.com/PythonFZ/ZnTrackExamples.git",
> rev="b9316bf",
> )
> ```
>
> Try accessing the `max_number` parameter and `random_number` output. All Nodes
> from this and many other repositories can be loaded like this.
An overview of all the ZnTrack features as well as more detailed examples can be
found in the [ZnTrack Documentation](https://zntrack.readthedocs.io/en/latest/).
Expand Down
33 changes: 33 additions & 0 deletions tests/integration/test_from_rev.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import dvc.scm
import pytest

import zntrack


def test_module_not_installed():
    """Expect ModuleNotFoundError when loading 'ASEMD' from the IPS-Water remote."""
    source = {
        "remote": "https://github.com/IPSProjects/IPS-Water",
        "rev": "ca0eef0ccfcbfb72a82136849a9ca35eac8b7629",
    }
    with pytest.raises(ModuleNotFoundError):
        zntrack.from_rev("ASEMD", **source)


def test_commit_not_found():
    """Expect dvc.scm.RevError when the requested revision does not exist."""
    source = {
        "remote": "https://github.com/IPSProjects/IPS-Water",
        "rev": "this-does-not-exist",
    }
    with pytest.raises(dvc.scm.RevError):
        zntrack.from_rev("ASEMD", **source)


def test_import_from_remote():
    """Load 'HelloWorld' from the examples remote and check its attributes."""
    loaded = zntrack.from_rev(
        "HelloWorld",
        rev="b9316bf",
        remote="https://github.com/PythonFZ/ZnTrackExamples.git",
    )
    expected = {"max_number": 512, "random_number": 64, "name": "HelloWorld"}
    # checked one by one, in order, so the first mismatch fails the test
    for attribute, value in expected.items():
        assert getattr(loaded, attribute) == value
6 changes: 6 additions & 0 deletions tests/integration/test_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,12 @@ def test_WriteIO_no_name(tmp_path_2, assert_before_exp):
assert exp2["WriteIO"].inputs == "Lorem Ipsum"
assert exp2["WriteIO"].outputs == "Lorem Ipsum"

assert zntrack.from_rev("WriteIO", rev=exp1.name).inputs == "Hello World"
assert zntrack.from_rev("WriteIO", rev=exp1.name).outputs == "Hello World"

assert zntrack.from_rev("WriteIO", rev=exp2.name).inputs == "Lorem Ipsum"
assert zntrack.from_rev("WriteIO", rev=exp2.name).outputs == "Lorem Ipsum"


def test_project_remove_graph(proj_path):
with zntrack.Project() as project:
Expand Down
15 changes: 15 additions & 0 deletions tests/integration/test_single_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,21 @@ def test_AddNumbers_remove_params(proj_path):
add_numbers.load()


def test_znrack_from_rev(proj_path):
    """Run an AddNumbers node and reload it by name via zntrack.from_rev."""
    with zntrack.Project() as project:
        add_numbers = AddNumbers(a=1, b=2)

    # the node has not been loaded before the project is run
    assert not add_numbers.state.loaded

    project.run()

    loaded = zntrack.from_rev(add_numbers.name)
    for attribute, value in {"a": 1, "b": 2, "c": 3}.items():
        assert getattr(loaded, attribute) == value
    assert loaded.state.loaded


@pytest.mark.parametrize("eager", [True, False])
def test_AddNumbers_named(proj_path, eager):
with zntrack.Project() as project:
Expand Down
2 changes: 2 additions & 0 deletions zntrack/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import importlib.metadata

from zntrack import exceptions, tools
from zntrack.core.load import from_rev
from zntrack.core.node import Node
from zntrack.core.nodify import NodeConfig, nodify
from zntrack.fields import Field, FieldGroup, LazyField, dvc, meta, zn
Expand All @@ -29,4 +30,5 @@
"NodeConfig",
"tools",
"exceptions",
"from_rev",
]
117 changes: 117 additions & 0 deletions zntrack/core/load.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
"""Load a node from a dvc stage."""

import contextlib
import importlib
import importlib.util
import pathlib
import sys
import tempfile
import typing
import uuid

import dvc.api
import dvc.repo
import dvc.stage

from zntrack.core.node import Node

# Type variable bound to Node; used as the return annotation of from_rev.
T = typing.TypeVar("T", bound=Node)


def _get_stage(name, remote, rev) -> dvc.stage.PipelineStage:
    """Find the stage called ``name`` in the DVC repository at ``remote``/``rev``.

    Raises
    ------
    ValueError
        If none of the stages in the repository index has the given name.
    """
    with dvc.repo.Repo.open(url=remote, rev=rev) as repo:
        for candidate in repo.index.stages:
            try:
                matches = candidate.name == name
            except AttributeError:
                # stages that are not pipeline stages have no name
                continue
            if matches:
                return candidate

    location = f"/tree/{rev}" if rev else ""
    raise ValueError(f"Stage {name} not found in {remote}{location}")


def _import_from_tempfile(package_and_module: str, remote, rev):
    """Import a module from a temporary copy of its source file.

    The source file is read from ``remote`` at ``rev``, written to a temporary
    file and imported from there under a unique module name.

    Parameters
    ----------
    package_and_module : str
        The dotted path of the module to import, e.g. "zntrack.core.node".
    remote : str
        The remote to load the module from.
    rev : str
        The revision to load the module from.

    Returns
    -------
    ModuleType
        The imported module.

    Raises
    ------
    ModuleNotFoundError
        If the module could not be found.
    FileNotFoundError
        If the file could not be found.
    """
    file = pathlib.Path(*package_and_module.split(".")).with_suffix(".py")
    fs = dvc.api.DVCFileSystem(url=remote, rev=rev)
    with tempfile.NamedTemporaryFile(suffix=".py") as temp_file, fs.open(file) as f:
        temp_file.write(f.read())
        # make sure the content is on disk before it is imported by file name
        temp_file.flush()

        # we use a random uuid to avoid name clashes
        ref_module = f"{uuid.uuid4()}.{package_and_module}"

        spec = importlib.util.spec_from_file_location(ref_module, temp_file.name)
        module = importlib.util.module_from_spec(spec)
        sys.modules[ref_module] = module
        try:
            spec.loader.exec_module(module)
        except BaseException:
            # do not leave a partially initialised module registered
            sys.modules.pop(ref_module, None)
            raise
        return module


def from_rev(name, remote=".", rev=None, **kwargs) -> T:
    """Load a ZnTrack Node by its name.

    The stage called ``name`` is looked up in the DVC repository, the Node
    class is resolved from the stage command and the Node is loaded through
    the ``from_rev`` method of that class.

    Parameters
    ----------
    name : str
        The name of the node.
    remote : str, optional
        The remote to load the node from. Defaults to workspace.
    rev : str, optional
        The revision to load the node from. Defaults to HEAD.
    **kwargs
        Additional keyword arguments forwarded to ``from_rev`` of the
        resolved Node class.

    Returns
    -------
    Node
        The loaded node.

    Raises
    ------
    ValueError
        If no stage with the given name exists in the repository.
    ModuleNotFoundError
        If the module defining the Node can neither be imported from the
        current environment nor loaded from the remote.
    """
    stage = _get_stage(name, remote, rev)

    # NOTE(review): presumably the command has the form
    # 'zntrack run <package.module.Class> --name <NodeName>' - confirm.
    tokens = stage.cmd.split()
    run_str, name = tokens[2], tokens[4]

    package_and_module, cls_name = run_str.rsplit(".", 1)
    try:
        module = importlib.import_module(package_and_module)
    except ModuleNotFoundError as import_error:
        # fall back to importing the module source directly from the remote
        try:
            module = _import_from_tempfile(package_and_module, remote, rev)
        except (FileNotFoundError, ModuleNotFoundError):
            module_name = package_and_module.split(".")[0]
            # keep the original import failure attached as the cause
            raise ModuleNotFoundError(
                f"No module named '{module_name}'. The package might be available via 'pip"
                f" install {module_name}' or from the remote via 'pip install git+{remote}'."
            ) from import_error

    cls = getattr(module, cls_name)

    return cls.from_rev(name, remote, rev, **kwargs)

0 comments on commit a8329ee

Please sign in to comment.