Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files: browse the repository at this point in the history.
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Apr 15, 2024
1 parent 4819ce9 commit e135bd2
Showing 17 changed files with 84 additions and 22 deletions.
26 changes: 14 additions & 12 deletions tests/integration/function_wrapper/test_single_function.py
Original file line number Diff line number Diff line change
@@ -24,18 +24,20 @@ def test_example_func(proj_path):

def test_example_func_dry_run(proj_path):
script = example_func(dry_run=True)
assert " ".join(script) == " ".join([
"stage",
"add",
"-n",
"example_func",
"--force",
"--params",
"params.yaml:example_func",
"--outs",
"test.txt",
"zntrack run test_single_function.example_func",
])
assert " ".join(script) == " ".join(
[
"stage",
"add",
"-n",
"example_func",
"--force",
"--params",
"params.yaml:example_func",
"--outs",
"test.txt",
"zntrack run test_single_function.example_func",
]
)


@zntrack.nodify(outs=[pathlib.Path("test.txt")], params={"text": "Lorem Ipsum"})
22 changes: 12 additions & 10 deletions tests/integration/test_cli.py
Original file line number Diff line number Diff line change
@@ -141,16 +141,18 @@ def test_list_groups(proj_path, runner):
"ParamsToOuts",
"ParamsToOuts_1",
],
"nested": [{
"GRP1": [
"ParamsToOuts -> nested_GRP1_ParamsToOuts",
"ParamsToOuts_1 -> nested_GRP1_ParamsToOuts_1",
],
"GRP2": [
"ParamsToOuts -> nested_GRP2_ParamsToOuts",
"ParamsToOuts_1 -> nested_GRP2_ParamsToOuts_1",
],
}],
"nested": [
{
"GRP1": [
"ParamsToOuts -> nested_GRP1_ParamsToOuts",
"ParamsToOuts_1 -> nested_GRP1_ParamsToOuts_1",
],
"GRP2": [
"ParamsToOuts -> nested_GRP2_ParamsToOuts",
"ParamsToOuts_1 -> nested_GRP2_ParamsToOuts_1",
],
}
],
}

groups, _ = utils.cli.get_groups(remote=proj_path, rev=None)
2 changes: 2 additions & 0 deletions zntrack/core/load.py
Original file line number Diff line number Diff line change
@@ -58,6 +58,7 @@ def _import_from_tempfile(package_and_module: str, remote, rev):
If the module could not be found.
FileNotFoundError
If the file could not be found.
"""
file = pathlib.Path(*package_and_module.split(".")).with_suffix(".py")
fs = dvc.api.DVCFileSystem(url=remote, rev=rev)
@@ -93,6 +94,7 @@ def from_rev(name, remote=".", rev=None, **kwargs) -> T:
-------
Node
The loaded node.
"""
if isinstance(name, Node):
name = name.name
4 changes: 4 additions & 0 deletions zntrack/core/node.py
Original file line number Diff line number Diff line change
@@ -59,6 +59,7 @@ class NodeStatus:
The temporary path used for loading the data.
This is only set within the context manager 'use_tmp_path'.
If neither 'remote' nor 'rev' are set, tmp_path will not be used.
"""

loaded: bool
@@ -202,6 +203,7 @@ class Node(zninit.ZnInit, znflow.Node):
information about the state of the Node.
nwd : pathlib.Path
the node working directory.
"""

_state: NodeStatus = None
@@ -235,6 +237,7 @@ def convert_notebook(cls, nb_name: str = None):
----------
nb_name: str
Notebook name when not using config.nb_name (this is not recommended)
"""
# TODO this should not be a class method, but a function.
jupyter_class_to_file(nb_name=nb_name, module_name=cls.__name__)
@@ -324,6 +327,7 @@ def load(self, lazy: bool = None, results: bool = True) -> None:
Whether to load the node lazily. If None, the value from the config is used.
results : bool, default = True
Whether to load the results. If False, only the parameters are loaded.
"""
from zntrack.fields.field import Field, FieldGroup

6 changes: 6 additions & 0 deletions zntrack/core/nodify.py
Original file line number Diff line number Diff line change
@@ -35,6 +35,7 @@ class DVCRunOptions:
References
----------
https://dvc.org/doc/command-reference/run#options.
"""

no_commit: bool
@@ -51,6 +52,7 @@ def dvc_args(self) -> list:
-------
list: A list of strings for the subprocess call, e.g.:
["--no-commit", "--external"].
"""
out = []
for datacls_field in dataclasses.fields(self):
@@ -97,6 +99,7 @@ def prepare_dvc_script(
-------
list[str]
The list to be passed to the subprocess call.
"""
script = ["stage", "add", "-n", node_name]
script += dvc_run_option.dvc_args
@@ -134,6 +137,7 @@ def check_type(
accept None even if not in types.
allow_dict:
allow for {key: types}
"""
if isinstance(obj, (list, tuple, set)) and allow_iterable:
for value in obj:
@@ -254,6 +258,7 @@ def save_node_config_to_files(cfg: NodeConfig, node_name: str):
The NodeConfig object which should be serialized to zntrack.json / params.yaml
node_name: str
The name of the node, usually func.__name__.
"""
for value_name, value in dataclasses.asdict(cfg).items():
if value_name == "params":
@@ -339,6 +344,7 @@ def nodify(
References
----------
https://dvc.org/doc/command-reference/run#options
"""
cfg_ = NodeConfig(
outs=outs,
3 changes: 3 additions & 0 deletions zntrack/exceptions/__init__.py
Original file line number Diff line number Diff line change
@@ -11,6 +11,7 @@ def __init__(self, arg):
----------
arg : str|Node
Custom Error message or Node that is not available.
"""
if isinstance(arg, str):
super().__init__(arg)
@@ -33,6 +34,7 @@ def __init__(self, node, field, instance):
The 'zn.nodes' field
instance : Node
The node that contains the 'zn.nodes' field
"""
msg = (
f"Can not set '{field.name}' of Node<'{instance.name}'> to"
@@ -59,6 +61,7 @@ def __init__(self, node):
----------
node: Node
The node that is already on the graph.
"""
msg = (
f"Node name '{node.name}' is already used in the graph. Please use"
1 change: 1 addition & 0 deletions zntrack/fields/dependency.py
Original file line number Diff line number Diff line change
@@ -96,6 +96,7 @@ def _get_nodes_on_off_graph(self, instance) -> t.Tuple[list, list]:
The nodes that are on the graph.
off_graph : list
The nodes that are off the graph.
"""
values = getattr(instance, self.name)
# TODO use IterableHandler?
1 change: 1 addition & 0 deletions zntrack/fields/dvc/options.py
Original file line number Diff line number Diff line change
@@ -111,6 +111,7 @@ def get_data(self, instance: "Node") -> any:
-------
any
The value of the field from the configuration file.
"""
zntrack_dict = json.loads(
instance.state.fs.read_text("zntrack.json"),
9 changes: 9 additions & 0 deletions zntrack/fields/field.py
Original file line number Diff line number Diff line change
@@ -36,6 +36,7 @@ class Field(zninit.Descriptor, abc.ABC):
----------
dvc_option : str
The dvc command option for this field.
"""

dvc_option: str = None
@@ -49,6 +50,7 @@ def save(self, instance: "Node"):
----------
instance : Node
The Node instance to save the field for.
"""
raise NotImplementedError

@@ -70,6 +72,7 @@ def get_files(self, instance: "Node") -> list:
-------
list
The affected files.
"""
raise NotImplementedError

@@ -83,6 +86,7 @@ def load(self, instance: "Node", lazy: bool = None):
lazy : bool, optional
Whether to load the field lazily.
This only applies to 'LazyField' classes.
"""
try:
instance.__dict__[self.name] = self.get_data(instance)
@@ -103,6 +107,7 @@ def get_stage_add_argument(self, instance: "Node") -> typing.List[tuple]:
-------
typing.List[tuple]
The stage add argument for this field.
"""
return [
(f"--{self.dvc_option}", pathlib.Path(x).as_posix())
@@ -127,6 +132,7 @@ def get_optional_dvc_cmd(
-------
typing.List[str]
The optional dvc commands.
"""
return []

@@ -173,6 +179,7 @@ def get_value_except_lazy(self, instance):
------
DataIsLazyError
If the value is lazy.
"""
with contextlib.suppress(KeyError):
if instance.__dict__[self.name] is LazyOption:
@@ -198,6 +205,7 @@ def load(self, instance: "Node", lazy: bool = None):
The Node instance to load the field for.
lazy : bool, optional
Whether to load the field lazily, by default 'zntrack.config.lazy'.
"""
if lazy in {None, True} and config.lazy:
instance.__dict__[self.name] = LazyOption
@@ -226,6 +234,7 @@ def __init__(
----------
use_global_plots : bool
Save the plots config not in 'stages' but in 'plots' in the dvc.yaml file.
"""
super().__init__(*args, **kwargs)
self.plots_options = {}
5 changes: 5 additions & 0 deletions zntrack/fields/fields.py
Original file line number Diff line number Diff line change
@@ -17,6 +17,7 @@ def outs():
The object is serialized and deserialized by ZnTrack
and stored in the node working directory.
see https://dvc.org/doc/command-reference/stage/add#-o
"""
return Output(dvc_option="outs", use_repr=False)

@@ -49,6 +50,7 @@ def params(*args, **kwargs):
see https://dvc.org/doc/command-reference/stage/add#-p
kwargs: dict
Additional keyword arguments.
"""
return Params(*args, **kwargs)

@@ -63,6 +65,7 @@ def deps(*data):
This can either be a Node or an attribute of a Node.
It can not be an object that is not part of the Node graph.
see https://dvc.org/doc/command-reference/stage/add#-d
"""
return Dependency(*data)

@@ -132,6 +135,7 @@ def params_path(*args, **kwargs):
see https://dvc.org/doc/command-reference/stage/add#-p
kwargs: dict
Additional keyword arguments.
"""
return DVCOption(*args, dvc_option="params", **kwargs)

@@ -163,5 +167,6 @@ def plots_path(*args, dvc_option="plots", **kwargs):
The DVC option to use for this field.
kwargs: dict
Additional keyword arguments that are used for plotting.
"""
return PlotsOption(*args, dvc_option=dvc_option, **kwargs)
8 changes: 8 additions & 0 deletions zntrack/fields/zn/options.py
Original file line number Diff line number Diff line change
@@ -103,6 +103,7 @@ class Params(Field):
----------
dvc_option: str
The DVC option to use. Default is "params".
"""

dvc_option: str = "params"
@@ -115,6 +116,7 @@ def get_files(self, instance: "Node") -> list:
-------
list
A list of file paths.
"""
return [config.files.params]

@@ -125,6 +127,7 @@ def save(self, instance: "Node"):
----------
instance : Node
The node instance associated with this field.
"""
file = self.get_files(instance)[0]

@@ -161,6 +164,7 @@ def get_stage_add_argument(self, instance: "Node") -> typing.List[tuple]:
-------
list
A list of tuples containing the DVC option and the file path.
"""
file = self.get_files(instance)[0]
return [(f"--{self.dvc_option}", f"{file}:{instance.name}")]
@@ -180,6 +184,7 @@ def __init__(self, dvc_option: str, **kwargs):
The DVC option used to specify the output file.
**kwargs
Additional arguments to pass to the parent constructor.
"""
self.dvc_option = dvc_option
super().__init__(**kwargs)
@@ -196,6 +201,7 @@ def get_files(self, instance) -> list:
-------
list
A list containing the path of the file.
"""
return [get_nwd(instance) / f"{self.name}.json"]

@@ -206,6 +212,7 @@ def save(self, instance: "Node"):
----------
instance : Node
The node instance.
"""
try:
value = self.get_value_except_lazy(instance)
@@ -236,6 +243,7 @@ def get_stage_add_argument(self, instance) -> typing.List[tuple]:
-------
list
A list containing the DVC command for this field.
"""
file = self.get_files(instance)[0]
return [(f"--{self.dvc_option}", file.as_posix())]
4 changes: 4 additions & 0 deletions zntrack/project/zntrack_project.py
Original file line number Diff line number Diff line change
@@ -87,6 +87,7 @@ class Project:
node as the node name. E.g. `node = Node()` will result in a node name of 'node'.
If used within a group, the group name will be added to the node name. E.g.
`group.name = Grp1` and `model = Node()` will result in a name of 'Grp1_model'.
"""

graph: ZnTrackGraph = dataclasses.field(default_factory=ZnTrackGraph, init=False)
@@ -111,6 +112,7 @@ def __post_init__(self):
remove_existing_graph : bool, default = False
If True, remove 'dvc.yaml', 'zntrack.json' and 'params.yaml'
before writing new nodes.
"""
self.graph.project = self
if self.initialize:
@@ -157,6 +159,7 @@ def group(self, *names: typing.List[str]):
The name of the group. If None, the group will be named 'GroupX' where X is
the number of groups + 1. If more than one name is given, the groups will
be nested to 'nwd = name[0]/name[1]/.../name[-1]'
"""
if not names:
name = "Group1"
@@ -248,6 +251,7 @@ def run(
auto_remove : bool, default = False
If True, remove all nodes from 'dvc.yaml' that are not in the graph.
This is the same as calling 'project.auto_remove()'
"""
if not save and not eager:
raise ValueError("Save can only be false if eager is True")
1 change: 1 addition & 0 deletions zntrack/tools/__init__.py
Original file line number Diff line number Diff line change
@@ -21,6 +21,7 @@ def timeit(field: str):
field : str
The field to store the time in.
The value is stored as {func_name: time} or {func_name: [time1, time2, ...]}
"""

def decorator(func):
5 changes: 5 additions & 0 deletions zntrack/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -39,6 +39,7 @@ def __init__(self) -> None:
------
NotImplementedError:
This class is not meant to be instantiated.
"""
raise NotImplementedError("This class is not meant to be instantiated.")

@@ -60,6 +61,7 @@ def module_handler(obj) -> str:
----------
obj:
Any object that implements __module__
"""
if config.nb_name:
try:
@@ -111,6 +113,7 @@ def run_dvc_cmd(script, stdout=None):
------
DVCProcessError:
if the dvc cli command fails.
"""
dvc_short_string = " ".join(script[:5])
if len(script) > 5:
@@ -177,6 +180,7 @@ class NodeStatusResults(enum.Enum):
the Node instance has failed to run.
AVAILABLE : int
the Node instance was loaded and results are available.
"""

UNKNOWN = 0
@@ -202,6 +206,7 @@ def cwd_temp_dir(required_files=None) -> tempfile.TemporaryDirectory:
-------
temp_dir:
The temporary directory file. Close with temp_dir.cleanup() at the end.
"""
temp_dir = tempfile.TemporaryDirectory() # pylint: disable=consider-using-with
# add ignore_cleanup_errors=True in Py3.10?
2 changes: 2 additions & 0 deletions zntrack/utils/cli.py
Original file line number Diff line number Diff line change
@@ -40,6 +40,7 @@ def check_empty(self):
Raises
------
typer.Exit: if the directory is not empty and force is false
"""
is_empty = not any(pathlib.Path(".").iterdir())
if not is_empty and not self.force:
@@ -92,6 +93,7 @@ def get_groups(remote, rev) -> (dict, list):
values. Contains "short-name -> long-name" if inside a group.
node_names: list
A list of all node names in the project.
"""
fs = DVCFileSystem(url=remote, rev=rev)
with fs.open("zntrack.json") as f:
3 changes: 3 additions & 0 deletions zntrack/utils/config.py
Original file line number Diff line number Diff line change
@@ -15,6 +15,7 @@ class Files:
Notes
-----
Currently frozen because changing the value is not tested.
"""

zntrack: Path = Path("zntrack.json")
@@ -55,6 +56,7 @@ class Config:
Use the `dvc.cli.main` function instead of subprocess
disable_operating_directory: bool, default = False
Global config to disable operating directory context manager.
"""

nb_name: str = None
@@ -86,6 +88,7 @@ def updated_config(self, **kwargs) -> None:
Yields
------
Environment with temporarily changed config.
"""
state = {}
for key, value in kwargs.items():
4 changes: 4 additions & 0 deletions zntrack/utils/file_io.py
Original file line number Diff line number Diff line change
@@ -28,6 +28,7 @@ def read_file(file: pathlib.Path) -> dict:
-------
dict:
Content of the json/yaml file
"""
if file.suffix in [".yaml", ".yml"]:
file_content = yaml.safe_load(file.read_text())
@@ -51,6 +52,7 @@ def write_file(file: pathlib.Path, value: dict, mkdir: bool = True):
Any serializable data to save
mkdir: bool
Create a parent directory if necessary
"""
if mkdir:
file.parent.mkdir(exist_ok=True, parents=True)
@@ -72,6 +74,7 @@ def clear_config_file(file: typing.Union[pathlib.Path, str], node_name: str):
The file to read from, e.g. params.yaml / zntrack.json
node_name: str
The name of the Node
"""
file = pathlib.Path(file)
try:
@@ -107,6 +110,7 @@ def update_config_file(
be {node_name: value}.
value:
The value to write to the file
"""
# Read file
if node_name is None and value_name is None:

0 comments on commit e135bd2

Please sign in to comment.