diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 561a3fbe..b2dc38a9 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -40,6 +40,7 @@ repos: rev: 'v0.0.269' hooks: - id: ruff + args: ['--fix'] - repo: https://github.com/executablebooks/mdformat rev: 0.7.16 hooks: diff --git a/README.md b/README.md index 0442194a..66ec185d 100644 --- a/README.md +++ b/README.md @@ -73,9 +73,26 @@ Node object. ```python hello_world.load() -print(hello_world.random_numer) +print(hello_world.random_number) ``` +> ## Tip +> +> You can easily load this Node directly from a repository. +> +> ```python +> import zntrack +> +> node = zntrack.from_rev( +> "HelloWorld", +> remote="https://github.com/PythonFZ/ZnTrackExamples.git", +> rev="b9316bf", +> ) +> ``` +> +> Try accessing the `max_number` parameter and `random_number` output. All Nodes +> from this and many other repositories can be loaded like this. + An overview of all the ZnTrack features as well as more detailed examples can be found in the [ZnTrack Documentation](https://zntrack.readthedocs.io/en/latest/). diff --git a/tests/integration/test_from_rev.py b/tests/integration/test_from_rev.py new file mode 100644 index 00000000..eb0719b2 --- /dev/null +++ b/tests/integration/test_from_rev.py @@ -0,0 +1,33 @@ +import dvc.scm +import pytest + +import zntrack + + +def test_module_not_installed(): + with pytest.raises(ModuleNotFoundError): + zntrack.from_rev( + "ASEMD", + remote="https://github.com/IPSProjects/IPS-Water", + rev="ca0eef0ccfcbfb72a82136849a9ca35eac8b7629", + ) + + +def test_commit_not_found(): + with pytest.raises(dvc.scm.RevError): + zntrack.from_rev( + "ASEMD", + remote="https://github.com/IPSProjects/IPS-Water", + rev="this-does-not-exist", + ) + + +def test_import_from_remote(): + node = zntrack.from_rev( + "HelloWorld", + remote="https://github.com/PythonFZ/ZnTrackExamples.git", + rev="b9316bf", + ) + assert node.max_number == 512 + assert node.random_number == 64 + assert node.name == "HelloWorld" diff --git a/tests/integration/test_project.py b/tests/integration/test_project.py index 12ec5b71..7172b067 100644 --- a/tests/integration/test_project.py +++ b/tests/integration/test_project.py @@ -72,6 +72,12 @@ def test_WriteIO_no_name(tmp_path_2, assert_before_exp): assert exp2["WriteIO"].inputs == "Lorem Ipsum" assert exp2["WriteIO"].outputs == "Lorem Ipsum" + assert zntrack.from_rev("WriteIO", rev=exp1.name).inputs == "Hello World" + assert zntrack.from_rev("WriteIO", rev=exp1.name).outputs == "Hello World" + + assert zntrack.from_rev("WriteIO", rev=exp2.name).inputs == "Lorem Ipsum" + assert zntrack.from_rev("WriteIO", rev=exp2.name).outputs == "Lorem Ipsum" + def test_project_remove_graph(proj_path): with zntrack.Project() as project: diff --git a/tests/integration/test_single_node.py b/tests/integration/test_single_node.py index 9b9fafc2..f15c4dbd 100644 --- a/tests/integration/test_single_node.py +++ b/tests/integration/test_single_node.py @@ -46,6 +46,21 @@ def test_AddNumbers_remove_params(proj_path): add_numbers.load() +def test_znrack_from_rev(proj_path): + with zntrack.Project() as project: + add_numbers = AddNumbers(a=1, b=2) + + assert not add_numbers.state.loaded + + project.run() + + node = zntrack.from_rev(add_numbers.name) + assert node.a == 1 + assert node.b == 2 + assert node.c == 3 + assert node.state.loaded + + @pytest.mark.parametrize("eager", [True, False]) def test_AddNumbers_named(proj_path, eager): with zntrack.Project() as project: diff --git a/zntrack/__init__.py b/zntrack/__init__.py index cff88f76..ea203a5b 100644 --- a/zntrack/__init__.py +++ b/zntrack/__init__.py @@ -5,6 +5,7 @@ import importlib.metadata from zntrack import exceptions, tools +from zntrack.core.load import from_rev from zntrack.core.node import Node from zntrack.core.nodify import NodeConfig, nodify from zntrack.fields import Field, FieldGroup, LazyField, dvc, meta, zn @@ -29,4 +30,5 @@ "NodeConfig", "tools", "exceptions", + "from_rev", ] diff --git a/zntrack/core/load.py b/zntrack/core/load.py new file mode 100644 index 00000000..7f5c1e30 --- /dev/null +++ b/zntrack/core/load.py @@ -0,0 +1,117 @@ +"""Load a node from a dvc stage.""" + +import contextlib +import importlib +import importlib.util +import pathlib +import sys +import tempfile +import typing +import uuid + +import dvc.api +import dvc.repo +import dvc.stage + +from zntrack.core.node import Node + +T = typing.TypeVar("T", bound=Node) + + +def _get_stage(name, remote, rev) -> dvc.stage.PipelineStage: + """Get a stage from a dvc.Repo.""" + with dvc.repo.Repo.open(url=remote, rev=rev) as repo: + for stage in repo.index.stages: + with contextlib.suppress(AttributeError): + # non pipeline stage don't have name + if stage.name == name: + return stage + + raise ValueError( + f"Stage {name} not found in {remote}" + (f"/tree/{rev}" if rev else "") + ) + + +def _import_from_tempfile(package_and_module: str, remote, rev): + """Create a temporary file to import from. + + Parameters + ---------- + package_and_module : str + The package and module to import, e.g. "zntrack.core.node.Node". + remote : str + The remote to load the module from. + rev : str + The revision to load the module from. + + Returns + ------- + ModuleType + The imported module. + + Raises + ------ + ModuleNotFoundError + If the module could not be found. + FileNotFoundError + If the file could not be found. + """ + file = pathlib.Path(*package_and_module.split(".")).with_suffix(".py") + fs = dvc.api.DVCFileSystem(url=remote, rev=rev) + with tempfile.NamedTemporaryFile(suffix=".py") as temp_file, fs.open(file) as f: + temp_file.write(f.read()) + temp_file.flush() + + # we use a random uuid to avoid name clashes + ref_module = f"{uuid.uuid4()}.{package_and_module}" + + spec = importlib.util.spec_from_file_location(ref_module, temp_file.name) + module = importlib.util.module_from_spec(spec) + sys.modules[ref_module] = module + spec.loader.exec_module(module) + return module + + +def from_rev(name, remote=".", rev=None, **kwargs) -> T: + """Load a ZnTrack Node by its name. + + Parameters + ---------- + name : str + The name of the node. + remote : str, optional + The remote to load the node from. Defaults to workspace. + rev : str, optional + The revision to load the node from. Defaults to HEAD. + **kwargs + Additional keyword arguments to pass to the node's constructor. + + Returns + ------- + Node + The loaded node. + """ + stage = _get_stage(name, remote, rev) + + cmd = stage.cmd + run_str = cmd.split()[2] + name = cmd.split()[4] + + package_and_module, cls_name = run_str.rsplit(".", 1) + module = None + try: + module = importlib.import_module(package_and_module) + except ModuleNotFoundError: + with contextlib.suppress(FileNotFoundError, ModuleNotFoundError): + module = _import_from_tempfile(package_and_module, remote, rev) + + if module is None: + module_name = package_and_module.split(".")[0] + raise ModuleNotFoundError( + f"No module named '{module_name}'. The package might be available via 'pip" + f" install {module_name}' or from the remote via 'pip install git+{remote}'." + ) + + cls = getattr(module, cls_name) + + return cls.from_rev(name, remote, rev, **kwargs)