Skip to content

Commit

Permalink
add prepare_dvc_script for nodify and Node (#246)
Browse files Browse the repository at this point in the history
* add `prepare_dvc_script` for `nodify` and `Node`

* add tests

* fix test

* add docstrings

* allow dotdict on all NodeConfig attributes

* add NodeConfig dotdict test

- test NodeConfig.convert_fields_to_dotdict()
  • Loading branch information
PythonFZ authored Mar 2, 2022
1 parent d709a62 commit d5067f5
Show file tree
Hide file tree
Showing 7 changed files with 335 additions and 117 deletions.
10 changes: 10 additions & 0 deletions tests/integration_tests/function_wrapper/test_single_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,3 +105,13 @@ def nodea():
@nodify(outs=[25, "str"])
def nodeb():
pass


@nodify(outs={"data": pathlib.Path("data.txt")})
def example_func_with_outs_dict(cfg: NodeConfig):
    """Write a fixed greeting to the file registered under ``outs["data"]``.

    The outs are declared as a dict, and the entry is reached through
    attribute access on ``cfg.outs``.
    """
    target = cfg.outs.data
    target.write_text("Hello World")


def test_example_func_with_outs_dict(proj_path):
    """Run the dict-outs example node and check the file content it wrote."""
    example_func_with_outs_dict(run=True)

    written = pathlib.Path("data.txt").read_text()
    assert written == "Hello World"
48 changes: 48 additions & 0 deletions tests/unit_tests/core/test_decorator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import pathlib

import pytest

from zntrack.core.functions.decorator import NodeConfig


def test_not_supported_outs():
    """``NodeConfig`` must reject integer outs, bare or inside a list/dict."""
    # An int on its own, inside a list, and as a dict value.
    for invalid_outs in (25, [25], {"path": 25}):
        with pytest.raises(ValueError):
            NodeConfig(outs=invalid_outs)

    # A plain string is accepted and stored unchanged.
    cfg = NodeConfig(outs="path")
    assert cfg.outs == "path"


def test_not_supported_params():
    """``NodeConfig`` must reject an int or a list containing an int as params."""
    for invalid_params in (25, [25]):
        with pytest.raises(ValueError):
            NodeConfig(params=invalid_params)


def test_supported_params():
    """A dict of params is accepted and readable by key."""
    cfg = NodeConfig(params={"name": "John"})
    assert cfg.params["name"] == "John"


def test_NodeConfig_convert_fields_to_dotdict():
    """Check attribute-style access on ``outs`` after dotdict conversion."""
    # A plain string value is readable without any conversion.
    plain = NodeConfig(outs="file")
    assert plain.outs == "file"

    # Flat dict: both attribute and item access work after conversion.
    flat = NodeConfig(outs={"file": "file.txt"})
    flat.convert_fields_to_dotdict()
    assert flat.outs.file == "file.txt"
    assert flat.outs["file"] == "file.txt"

    # Nested dict: inner levels are reachable by attribute access as well.
    nested = NodeConfig(outs={"files": {"data": "datafile.txt"}})
    nested.convert_fields_to_dotdict()
    assert nested.outs.files.data == "datafile.txt"

    # Nested dict holding both a str and a pathlib.Path value.
    mixed_outs = {"files": {"file": "datafile.txt", "path": pathlib.Path("data.txt")}}
    mixed = NodeConfig(outs=mixed_outs)
    mixed.convert_fields_to_dotdict()
    assert mixed.outs.files.file == "datafile.txt"
    assert mixed.outs.files.path == pathlib.Path("data.txt")
64 changes: 64 additions & 0 deletions tests/unit_tests/core/test_dvcgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
filter_ZnTrackOption,
handle_deps,
handle_dvc,
prepare_dvc_script,
)


Expand Down Expand Up @@ -131,3 +132,66 @@ def test_descriptor_list_filter():
zntrack_type=utils.ZnTypes.PARAMS,
return_with_type=True,
) == {"params": {"param1": 1, "param2": 2}}


def test_prepare_dvc_script():
    """Check the command list built by ``prepare_dvc_script``.

    The script is built twice with identical arguments except ``nb_name``:
    once without a notebook and once with one. With a notebook name an
    additional ``--deps src/file.py`` pair is expected right before the
    python command.
    """
    run_options = DVCRunOptions(
        no_commit=False,
        external=True,
        always_changed=True,
        no_run_cache=False,
        no_exec=True,
        force=True,
    )

    def build(nb_name):
        # Every call gets its own fresh ``custom_args`` list.
        return prepare_dvc_script(
            node_name="node01",
            dvc_run_option=run_options,
            custom_args=["--deps", "file.txt"],
            nb_name=nb_name,
            module="src.file",
            func_or_cls="MyNode",
            call_args=".load().run_and_save()",
        )

    # Only the options set to True above show up as flags.
    expected_head = [
        "dvc",
        "run",
        "-n",
        "node01",
        "--external",
        "--always-changed",
        "--no-exec",
        "--force",
        "--deps",
        "file.txt",
    ]
    expected_cmd = (
        f'{utils.get_python_interpreter()} -c "from src.file import MyNode;'
        ' MyNode.load().run_and_save()" '
    )

    # Without a notebook: no additional dependency on the module file.
    assert build(None) == [*expected_head, expected_cmd]

    # With a notebook: the module file is added as a dependency.
    assert build("notebook.ipynb") == [
        *expected_head,
        "--deps",
        "src/file.py",
        expected_cmd,
    ]
4 changes: 4 additions & 0 deletions tests/unit_tests/utlis/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,10 @@ def test_check_type():
assert not utils.check_type([None], str, allow_none=True)
assert utils.check_type([None], str, allow_none=True, allow_iterable=True)

assert utils.check_type({"key": "val"}, str, allow_dict=True)
assert not utils.check_type({"key": "val"}, str)
assert utils.check_type({"a": {"b": "c"}}, str, allow_dict=True)


def test_python_interpreter():
assert utils.get_python_interpreter() in ["python", "python3"]
Expand Down
82 changes: 65 additions & 17 deletions zntrack/core/dvcgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,54 @@ def filter_ZnTrackOption(
return {x.name: getattr(cls, x.name) for x in data}


def prepare_dvc_script(
    node_name,
    dvc_run_option: DVCRunOptions,
    custom_args: list,
    nb_name,
    module,
    func_or_cls,
    call_args,
) -> list:
    """Prepare the dvc command to be called by subprocess

    Parameters
    ----------
    node_name: str
        Name of the Node, passed to ``dvc run -n``
    dvc_run_option: DVCRunOptions
        dataclass collecting the special DVC run options; its ``dvc_args``
        are placed right after the node name
    custom_args: list[str]
        all the params / deps / ... to be added to the script
    nb_name: str|None
        Notebook name for jupyter support. If it is not None, the file path
        derived from ``module`` is added as an extra ``--deps`` entry.
    module: str like "src.my_module"
        Module the Node class or function is imported from
    func_or_cls: str
        The name of the Node class or function to be imported and run
    call_args: str
        Additional str like "(run_func=True)" or ".load().run_and_save()"
        that is appended directly to ``func_or_cls`` in the python command

    Returns
    -------
    list[str]
        The list to be passed to the subprocess call
    """
    script = ["dvc", "run", "-n", node_name, *dvc_run_option.dvc_args, *custom_args]

    if nb_name is not None:
        module_file = utils.module_to_path(module)
        script.extend(["--deps", module_file.as_posix()])

    interpreter = utils.get_python_interpreter()
    # The trailing space after the closing quote is part of the command string.
    python_cmd = (
        f'{interpreter} -c "from {module} import {func_or_cls}; '
        f'{func_or_cls}{call_args}" '
    )
    script.append(python_cmd)

    log.debug(f"dvc script: {' '.join(map(str, script))}")
    return script


class GraphWriter:
"""Write the DVC Graph
Expand Down Expand Up @@ -288,42 +336,39 @@ def write_graph(

log.debug("--- Writing new DVC file ---")

script = ["dvc", "run", "-n", self.node_name]

script += DVCRunOptions(
dvc_run_option = DVCRunOptions(
no_commit=no_commit,
external=external,
always_changed=always_changed,
no_run_cache=no_run_cache,
no_exec=no_exec,
force=force,
).dvc_args
)

# Jupyter Notebook
nb_name = utils.update_nb_name(nb_name)

if nb_name is not None:
self._module = f"{utils.config.nb_class_path}.{self.__class__.__name__}"
if notebook:
self.convert_notebook(nb_name)
script += ["--deps", utils.module_to_path(self.module).as_posix()]

custom_args = []
# Handle Parameter
params_list = filter_ZnTrackOption(
data=self._descriptor_list,
cls=self,
zntrack_type=[utils.ZnTypes.PARAMS],
)
if len(params_list) > 0:
script += [
custom_args += [
"--params",
f"{utils.Files.params}:{self.node_name}",
]
zn_options_set = set()
for option in self._descriptor_list:
if option.zntrack_type == utils.ZnTypes.DVC:
value = getattr(self, option.name)
script += handle_dvc(value, option.dvc_args)
custom_args += handle_dvc(value, option.dvc_args)
# Handle Zn Options
elif option.zntrack_type in [
utils.ZnTypes.RESULTS,
Expand All @@ -337,22 +382,25 @@ def write_graph(
)
elif option.zntrack_type == utils.ZnTypes.DEPS:
value = getattr(self, option.name)
script += handle_deps(value)
custom_args += handle_deps(value)

for pair in zn_options_set:
script += pair
custom_args += pair

script = prepare_dvc_script(
node_name=self.node_name,
dvc_run_option=dvc_run_option,
custom_args=custom_args,
nb_name=nb_name,
module=self.module,
func_or_cls=self.__class__.__name__,
call_args=f".load(name='{self.node_name}').run_and_save()",
)

# Add command to run the script
cls_name = self.__class__.__name__
script.append(
f"""{utils.get_python_interpreter()} -c "from {self.module} import """
f"""{cls_name}; {cls_name}.load(name='{self.node_name}').run_and_save()" """
)

self.save()

log.debug(f"running script: {' '.join([str(x) for x in script])}")

log.debug(
"If you are using a jupyter notebook, you may not be able to see the "
"output in real time!"
Expand Down
Loading

0 comments on commit d5067f5

Please sign in to comment.