From f82a63d0c5fc9f27c562f3443086f4a44a2ce82f Mon Sep 17 00:00:00 2001 From: PythonFZ Date: Tue, 14 May 2024 16:07:02 +0200 Subject: [PATCH 1/4] add "zntrack.apply" --- tests/integration/test_apply.py | 25 +++++++++++++++++++++++++ zntrack/__init__.py | 3 ++- zntrack/cli/__init__.py | 7 +++++-- zntrack/core/node.py | 8 +++++++- zntrack/examples/__init__.py | 4 ++++ zntrack/utils/apply.py | 17 +++++++++++++++++ 6 files changed, 60 insertions(+), 4 deletions(-) create mode 100644 tests/integration/test_apply.py create mode 100644 zntrack/utils/apply.py diff --git a/tests/integration/test_apply.py b/tests/integration/test_apply.py new file mode 100644 index 00000000..ac64ffca --- /dev/null +++ b/tests/integration/test_apply.py @@ -0,0 +1,25 @@ +"""Test the apply function.""" + +import zntrack.examples + + +def test_apply(proj_path) -> None: + """Test the "zntrack.apply" function.""" + project = zntrack.Project() + + JoinedParamsToOuts = zntrack.apply(zntrack.examples.ParamsToOuts, "join") + + with project: + a = zntrack.examples.ParamsToOuts(params=["a", "b"]) + b = JoinedParamsToOuts(params=["a", "b"]) + c = zntrack.apply(zntrack.examples.ParamsToOuts, "join")(params=["a", "b", "c"]) + + project.run() + + a.load() + b.load() + c.load() + + assert a.outs == ["a", "b"] + assert b.outs == "a-b" + assert c.outs == "a-b-c" diff --git a/zntrack/__init__.py b/zntrack/__init__.py index fc4c6d28..83e0a499 100644 --- a/zntrack/__init__.py +++ b/zntrack/__init__.py @@ -23,7 +23,7 @@ plots_path, ) from zntrack.project import Project -from zntrack.utils import config +from zntrack.utils import apply, config from zntrack.utils.node_wd import nwd __version__ = importlib.metadata.version("zntrack") @@ -45,6 +45,7 @@ "exceptions", "from_rev", "get_nodes", + "apply", ] __all__ += [ diff --git a/zntrack/cli/__init__.py b/zntrack/cli/__init__.py index f5e8b219..08975720 100644 --- a/zntrack/cli/__init__.py +++ b/zntrack/cli/__init__.py @@ -47,7 +47,9 @@ def main( @app.command() -def run(node: str, name: str = None, meta_only: bool = False) -> None: +def run( + node: str, name: str = None, meta_only: bool = False, method: str = "run" +) -> None: """Execute a ZnTrack Node. Use as 'zntrack run module.Node --name node_name'. @@ -80,7 +82,8 @@ def run(node: str, name: str = None, meta_only: bool = False) -> None: node: Node = cls.from_rev(name=name, results=False) node.save(meta_only=True) if not meta_only: - node.run() + # dynamic version of node.run() + getattr(node, method)() node.save(parameter=False) else: raise ValueError(f"Node {node} is not a ZnTrack Node.") diff --git a/zntrack/core/node.py b/zntrack/core/node.py index e05a90fc..4b8ff292 100644 --- a/zntrack/core/node.py +++ b/zntrack/core/node.py @@ -443,7 +443,13 @@ def get_dvc_cmd( cmd += ["--outs", f"{(get_nwd(node) /'node-meta.json').as_posix()}"] module = module_handler(node.__class__) - cmd += [f"zntrack run {module}.{node.__class__.__name__} --name {node.name}"] + + zntrack_run = f"zntrack run {module}.{node.__class__.__name__} --name {node.name}" + if hasattr(node, "_method"): + zntrack_run += f" --method {node._method}" + + cmd += [zntrack_run] + optionals = [x for x in optionals if x] # remove empty entries [] return [cmd] + optionals diff --git a/zntrack/examples/__init__.py b/zntrack/examples/__init__.py index d5b848c1..aab5b42e 100644 --- a/zntrack/examples/__init__.py +++ b/zntrack/examples/__init__.py @@ -23,6 +23,10 @@ def run(self) -> None: """Save params to outs.""" self.outs = self.params + def join(self) -> None: + """Join the results.""" + self.outs = "-".join(self.params) + class ParamsToMetrics(zntrack.Node): """Save params to metrics.""" diff --git a/zntrack/utils/apply.py b/zntrack/utils/apply.py new file mode 100644 index 00000000..65232b86 --- /dev/null +++ b/zntrack/utils/apply.py @@ -0,0 +1,17 @@ +"""Zntrack apply module for custom "run" methods.""" + +import typing as t + +o = t.TypeVar("o") + + +def apply(obj: o, method: str) -> o: + """Return a new object like "o" which has the method string attached.""" + + class _(obj): + _method = method + + _.__module__ = obj.__module__ + _.__name__ = obj.__name__ + + return _ From 8d50706319c94c1031f9122807364ff533abcd3f Mon Sep 17 00:00:00 2001 From: PythonFZ Date: Tue, 14 May 2024 16:08:28 +0200 Subject: [PATCH 2/4] rename class name --- zntrack/utils/apply.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/zntrack/utils/apply.py b/zntrack/utils/apply.py index 65232b86..ebb91110 100644 --- a/zntrack/utils/apply.py +++ b/zntrack/utils/apply.py @@ -8,10 +8,16 @@ def apply(obj: o, method: str) -> o: """Return a new object like "o" which has the method string attached.""" - class _(obj): + class MockInheritanceClass(obj): + """Copy of the original class with the new method attribute. + + We can not set the method directly on the original class, because + it would be used by all the other instances of the class as well. + """ + _method = method - _.__module__ = obj.__module__ - _.__name__ = obj.__name__ + MockInheritanceClass.__module__ = obj.__module__ + MockInheritanceClass.__name__ = obj.__name__ - return _ + return MockInheritanceClass From a855cafb2a2399aa018295d9c6061d5b5f12da35 Mon Sep 17 00:00:00 2001 From: PythonFZ Date: Tue, 14 May 2024 16:10:39 +0200 Subject: [PATCH 3/4] update docstrings --- zntrack/cli/__init__.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/zntrack/cli/__init__.py b/zntrack/cli/__init__.py index 08975720..fe62aeb1 100644 --- a/zntrack/cli/__init__.py +++ b/zntrack/cli/__init__.py @@ -53,6 +53,17 @@ def run( """Execute a ZnTrack Node. Use as 'zntrack run module.Node --name node_name'. + + Arguments: + --------- + node : str + The node to run. + name : str + The name of the node. + meta_only : bool + Save only the metadata. + method : str, default 'run' + The method to run on the node. """ env_file = pathlib.Path("env.yaml") if env_file.exists(): From 9fb5b9581e0bb40117ae2cb0f32609edf8c8fd76 Mon Sep 17 00:00:00 2001 From: PythonFZ Date: Tue, 14 May 2024 16:22:10 +0200 Subject: [PATCH 4/4] fix eager --- tests/integration/test_apply.py | 7 +++++-- zntrack/__init__.py | 3 ++- zntrack/project/zntrack_project.py | 5 ++++- 3 files changed, 11 insertions(+), 4 deletions(-) diff --git a/tests/integration/test_apply.py b/tests/integration/test_apply.py index ac64ffca..fc9a953e 100644 --- a/tests/integration/test_apply.py +++ b/tests/integration/test_apply.py @@ -1,9 +1,12 @@ """Test the apply function.""" +import pytest + import zntrack.examples -def test_apply(proj_path) -> None: +@pytest.mark.parametrize("eager", [True, False]) +def test_apply(proj_path, eager) -> None: """Test the "zntrack.apply" function.""" project = zntrack.Project() @@ -14,7 +17,7 @@ def test_apply(proj_path) -> None: b = JoinedParamsToOuts(params=["a", "b"]) c = zntrack.apply(zntrack.examples.ParamsToOuts, "join")(params=["a", "b", "c"]) - project.run() + project.run(eager=eager) a.load() b.load() diff --git a/zntrack/__init__.py b/zntrack/__init__.py index 83e0a499..99d4c665 100644 --- a/zntrack/__init__.py +++ b/zntrack/__init__.py @@ -23,7 +23,8 @@ plots_path, ) from zntrack.project import Project -from zntrack.utils import apply, config +from zntrack.utils import config +from zntrack.utils.apply import apply from zntrack.utils.node_wd import nwd __version__ = importlib.metadata.version("zntrack") diff --git a/zntrack/project/zntrack_project.py b/zntrack/project/zntrack_project.py index bfd824d8..e1338583 100644 --- a/zntrack/project/zntrack_project.py +++ b/zntrack/project/zntrack_project.py @@ -287,7 +287,10 @@ def run( # update connectors log.info(f"Running node {node}") self.graph._update_node_attributes(node, UpdateConnectors()) - node.run() + if hasattr(node, "_method"): + getattr(node, node._method)() + else: + node.run() if save: node.save() node.state.loaded = True