Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Feb 12, 2024
1 parent 3fbf7e3 commit 1546854
Show file tree
Hide file tree
Showing 17 changed files with 84 additions and 22 deletions.
26 changes: 14 additions & 12 deletions tests/integration/function_wrapper/test_single_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,20 @@ def test_example_func(proj_path):

def test_example_func_dry_run(proj_path):
script = example_func(dry_run=True)
assert " ".join(script) == " ".join([
"stage",
"add",
"-n",
"example_func",
"--force",
"--params",
"params.yaml:example_func",
"--outs",
"test.txt",
"zntrack run test_single_function.example_func",
])
assert " ".join(script) == " ".join(
[
"stage",
"add",
"-n",
"example_func",
"--force",
"--params",
"params.yaml:example_func",
"--outs",
"test.txt",
"zntrack run test_single_function.example_func",
]
)


@zntrack.nodify(outs=[pathlib.Path("test.txt")], params={"text": "Lorem Ipsum"})
Expand Down
22 changes: 12 additions & 10 deletions tests/integration/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,16 +141,18 @@ def test_list_groups(proj_path, runner):
"ParamsToOuts",
"ParamsToOuts_1",
],
"nested": [{
"GRP1": [
"ParamsToOuts -> nested_GRP1_ParamsToOuts",
"ParamsToOuts_1 -> nested_GRP1_ParamsToOuts_1",
],
"GRP2": [
"ParamsToOuts -> nested_GRP2_ParamsToOuts",
"ParamsToOuts_1 -> nested_GRP2_ParamsToOuts_1",
],
}],
"nested": [
{
"GRP1": [
"ParamsToOuts -> nested_GRP1_ParamsToOuts",
"ParamsToOuts_1 -> nested_GRP1_ParamsToOuts_1",
],
"GRP2": [
"ParamsToOuts -> nested_GRP2_ParamsToOuts",
"ParamsToOuts_1 -> nested_GRP2_ParamsToOuts_1",
],
}
],
}

groups, _ = utils.cli.get_groups(remote=proj_path, rev=None)
Expand Down
2 changes: 2 additions & 0 deletions zntrack/core/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def _import_from_tempfile(package_and_module: str, remote, rev):
If the module could not be found.
FileNotFoundError
If the file could not be found.
"""
file = pathlib.Path(*package_and_module.split(".")).with_suffix(".py")
fs = dvc.api.DVCFileSystem(url=remote, rev=rev)
Expand Down Expand Up @@ -93,6 +94,7 @@ def from_rev(name, remote=".", rev=None, **kwargs) -> T:
-------
Node
The loaded node.
"""
if isinstance(name, Node):
name = name.name
Expand Down
4 changes: 4 additions & 0 deletions zntrack/core/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ class NodeStatus:
The temporary path used for loading the data.
This is only set within the context manager 'use_tmp_path'.
If neither 'remote' nor 'rev' are set, tmp_path will not be used.
"""

loaded: bool
Expand Down Expand Up @@ -182,6 +183,7 @@ class Node(zninit.ZnInit, znflow.Node):
information about the state of the Node.
nwd : pathlib.Path
the node working directory.
"""

_state: NodeStatus = None
Expand Down Expand Up @@ -215,6 +217,7 @@ def convert_notebook(cls, nb_name: str = None):
----------
nb_name: str
Notebook name when not using config.nb_name (this is not recommended)
"""
# TODO this should not be a class method, but a function.
jupyter_class_to_file(nb_name=nb_name, module_name=cls.__name__)
Expand Down Expand Up @@ -302,6 +305,7 @@ def load(self, lazy: bool = None, results: bool = True) -> None:
Whether to load the node lazily. If None, the value from the config is used.
results : bool, default = True
Whether to load the results. If False, only the parameters are loaded.
"""
from zntrack.fields.field import Field, FieldGroup

Expand Down
6 changes: 6 additions & 0 deletions zntrack/core/nodify.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class DVCRunOptions:
References
----------
https://dvc.org/doc/command-reference/run#options.
"""

no_commit: bool
Expand All @@ -51,6 +52,7 @@ def dvc_args(self) -> list:
-------
list: A list of strings for the subprocess call, e.g.:
["--no-commit", "--external"].
"""
out = []
for datacls_field in dataclasses.fields(self):
Expand Down Expand Up @@ -97,6 +99,7 @@ def prepare_dvc_script(
-------
list[str]
The list to be passed to the subprocess call.
"""
script = ["stage", "add", "-n", node_name]
script += dvc_run_option.dvc_args
Expand Down Expand Up @@ -134,6 +137,7 @@ def check_type(
accept None even if not in types.
allow_dict:
allow for {key: types}
"""
if isinstance(obj, (list, tuple, set)) and allow_iterable:
for value in obj:
Expand Down Expand Up @@ -254,6 +258,7 @@ def save_node_config_to_files(cfg: NodeConfig, node_name: str):
The NodeConfig object which should be serialized to zntrack.json / params.yaml
node_name: str
The name of the node, usually func.__name__.
"""
for value_name, value in dataclasses.asdict(cfg).items():
if value_name == "params":
Expand Down Expand Up @@ -339,6 +344,7 @@ def nodify(
References
----------
https://dvc.org/doc/command-reference/run#options
"""
cfg_ = NodeConfig(
outs=outs,
Expand Down
3 changes: 3 additions & 0 deletions zntrack/exceptions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ def __init__(self, arg):
----------
arg : str|Node
Custom Error message or Node that is not available.
"""
if isinstance(arg, str):
super().__init__(arg)
Expand All @@ -33,6 +34,7 @@ def __init__(self, node, field, instance):
The 'zn.nodes' field
instance : Node
The node that contains the 'zn.nodes' field
"""
msg = (
f"Can not set '{field.name}' of Node<'{instance.name}'> to"
Expand All @@ -59,6 +61,7 @@ def __init__(self, node):
----------
node: Node
The node that is already on the graph.
"""
msg = (
f"Node name '{node.name}' is already used in the graph. Please use"
Expand Down
1 change: 1 addition & 0 deletions zntrack/fields/dependency.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def _get_nodes_on_off_graph(self, instance) -> t.Tuple[list, list]:
The nodes that are on the graph.
off_graph : list
The nodes that are off the graph.
"""
values = getattr(instance, self.name)
# TODO use IterableHandler?
Expand Down
1 change: 1 addition & 0 deletions zntrack/fields/dvc/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ def get_data(self, instance: "Node") -> any:
-------
any
The value of the field from the configuration file.
"""
zntrack_dict = json.loads(
instance.state.fs.read_text("zntrack.json"),
Expand Down
9 changes: 9 additions & 0 deletions zntrack/fields/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class Field(zninit.Descriptor, abc.ABC):
----------
dvc_option : str
The dvc command option for this field.
"""

dvc_option: str = None
Expand All @@ -49,6 +50,7 @@ def save(self, instance: "Node"):
----------
instance : Node
The Node instance to save the field for.
"""
raise NotImplementedError

Expand All @@ -70,6 +72,7 @@ def get_files(self, instance: "Node") -> list:
-------
list
The affected files.
"""
raise NotImplementedError

Expand All @@ -83,6 +86,7 @@ def load(self, instance: "Node", lazy: bool = None):
lazy : bool, optional
Whether to load the field lazily.
This only applies to 'LazyField' classes.
"""
try:
instance.__dict__[self.name] = self.get_data(instance)
Expand All @@ -103,6 +107,7 @@ def get_stage_add_argument(self, instance: "Node") -> typing.List[tuple]:
-------
typing.List[tuple]
The stage add argument for this field.
"""
return [
(f"--{self.dvc_option}", pathlib.Path(x).as_posix())
Expand All @@ -127,6 +132,7 @@ def get_optional_dvc_cmd(
-------
typing.List[str]
The optional dvc commands.
"""
return []

Expand Down Expand Up @@ -173,6 +179,7 @@ def get_value_except_lazy(self, instance):
------
DataIsLazyError
If the value is lazy.
"""
with contextlib.suppress(KeyError):
if instance.__dict__[self.name] is LazyOption:
Expand All @@ -198,6 +205,7 @@ def load(self, instance: "Node", lazy: bool = None):
The Node instance to load the field for.
lazy : bool, optional
Whether to load the field lazily, by default 'zntrack.config.lazy'.
"""
if lazy in {None, True} and config.lazy:
instance.__dict__[self.name] = LazyOption
Expand Down Expand Up @@ -226,6 +234,7 @@ def __init__(
----------
use_global_plots : bool
Save the plots config not in 'stages' but in 'plots' in the dvc.yaml file.
"""
super().__init__(*args, **kwargs)
self.plots_options = {}
Expand Down
5 changes: 5 additions & 0 deletions zntrack/fields/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def outs():
The object is serialized and deserialized by ZnTrack
and stored in the node working directory.
see https://dvc.org/doc/command-reference/stage/add#-o
"""
return Output(dvc_option="outs", use_repr=False)

Expand Down Expand Up @@ -49,6 +50,7 @@ def params(*args, **kwargs):
see https://dvc.org/doc/command-reference/stage/add#-p
kwargs: dict
Additional keyword arguments.
"""
return Params(*args, **kwargs)

Expand All @@ -63,6 +65,7 @@ def deps(*data):
This can either be a Node or an attribute of a Node.
It can not be an object that is not part of the Node graph.
see https://dvc.org/doc/command-reference/stage/add#-d
"""
return Dependency(*data)

Expand Down Expand Up @@ -132,6 +135,7 @@ def params_path(*args, **kwargs):
see https://dvc.org/doc/command-reference/stage/add#-p
kwargs: dict
Additional keyword arguments.
"""
return DVCOption(*args, dvc_option="params", **kwargs)

Expand Down Expand Up @@ -163,5 +167,6 @@ def plots_path(*args, dvc_option="plots", **kwargs):
The DVC option to use for this field.
kwargs: dict
Additional keyword arguments that are used for plotting.
"""
return PlotsOption(*args, dvc_option=dvc_option, **kwargs)
8 changes: 8 additions & 0 deletions zntrack/fields/zn/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ class Params(Field):
----------
dvc_option: str
The DVC option to use. Default is "params".
"""

dvc_option: str = "params"
Expand All @@ -115,6 +116,7 @@ def get_files(self, instance: "Node") -> list:
-------
list
A list of file paths.
"""
return [config.files.params]

Expand All @@ -125,6 +127,7 @@ def save(self, instance: "Node"):
----------
instance : Node
The node instance associated with this field.
"""
file = self.get_files(instance)[0]

Expand Down Expand Up @@ -161,6 +164,7 @@ def get_stage_add_argument(self, instance: "Node") -> typing.List[tuple]:
-------
list
A list of tuples containing the DVC option and the file path.
"""
file = self.get_files(instance)[0]
return [(f"--{self.dvc_option}", f"{file}:{instance.name}")]
Expand All @@ -180,6 +184,7 @@ def __init__(self, dvc_option: str, **kwargs):
The DVC option used to specify the output file.
**kwargs
Additional arguments to pass to the parent constructor.
"""
self.dvc_option = dvc_option
super().__init__(**kwargs)
Expand All @@ -196,6 +201,7 @@ def get_files(self, instance) -> list:
-------
list
A list containing the path of the file.
"""
return [get_nwd(instance) / f"{self.name}.json"]

Expand All @@ -206,6 +212,7 @@ def save(self, instance: "Node"):
----------
instance : Node
The node instance.
"""
try:
value = self.get_value_except_lazy(instance)
Expand Down Expand Up @@ -236,6 +243,7 @@ def get_stage_add_argument(self, instance) -> typing.List[tuple]:
-------
list
A list containing the DVC command for this field.
"""
file = self.get_files(instance)[0]
return [(f"--{self.dvc_option}", file.as_posix())]
Expand Down
Loading

0 comments on commit 1546854

Please sign in to comment.