Skip to content

Commit

Permalink
Bugfixes (#557)
Browse files Browse the repository at this point in the history
* update version callback

* do not load results

* bump version

* fix #555

* fix #539

* sort imports

* fix #543

* fix cli

* lint

* fix detached head issue

* poetry update znflow
  • Loading branch information
PythonFZ authored Mar 28, 2023
1 parent 44deaa9 commit 4f05747
Show file tree
Hide file tree
Showing 9 changed files with 245 additions and 199 deletions.
359 changes: 181 additions & 178 deletions poetry.lock

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "ZnTrack"
version = "0.6.0a2"
version = "0.6.0a3"
description = "Create, Run and Benchmark DVC Pipelines in Python"
authors = ["zincwarecode <zincwarecode@gmail.com>"]
license = "Apache-2.0"
Expand All @@ -19,7 +19,7 @@ typer = "^0.7.0"
dot4dict = "^0.1.1"
zninit = "^0.1.9"
znjson = "^0.2.2"
znflow = "^0.1.6"
znflow = "^0.1.10"


[tool.poetry.urls]
Expand Down
2 changes: 1 addition & 1 deletion tests/test_zntrack.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@

def test_version():
"""Test 'ZnTrack' version."""
assert __version__ == "0.6.0a2"
assert __version__ == "0.6.0a3"
10 changes: 10 additions & 0 deletions tests/unit_tests/utils/test_converter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import zntrack
import znjson
import json
import pytest


@pytest.mark.parametrize("myslice", [slice(1, 2, 3), slice(1, 2), slice(1)])
def test_slice_converter(myslice):
dump = json.dumps(myslice, cls=znjson.ZnEncoder)
assert json.loads(dump, cls=znjson.ZnDecoder) == myslice
22 changes: 19 additions & 3 deletions zntrack/cli/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
"""The ZnTrack CLI."""
import contextlib
import importlib.metadata
import os
import pathlib
import sys
import uuid

import git
import typer
import yaml

Expand All @@ -16,7 +18,21 @@
def version_callback(value: bool) -> None:
"""Get the installed 'ZnTrack' version."""
if value:
typer.echo(f"ZnTrack {importlib.metadata.version('zntrack')}")
path = pathlib.Path(__file__).parent.parent.parent
report = f"ZnTrack {importlib.metadata.version('zntrack')} at '{path}'"

with contextlib.suppress(git.exc.InvalidGitRepositoryError):
repo = git.Repo(path)
_ = repo.git_dir

report += " - "
with contextlib.suppress(TypeError): # detached head
report += f"{repo.active_branch.name}@"
report += f"{repo.head.object.hexsha[:7]}"
if repo.is_dirty():
report += " (dirty)"

typer.echo(report)
raise typer.Exit()


Expand Down Expand Up @@ -52,9 +68,9 @@ def run(node: str, name: str = None, hash_only: bool = False) -> None:
if getattr(cls, "is_node", False):
cls(exec_func=True)
elif issubclass(cls, Node):
node: Node = cls.from_rev(name=name)
node: Node = cls.from_rev(name=name, results=False)
if hash_only:
(node.nwd / "hash").write_text(str(uuid.uuid4))
(node.nwd / "hash").write_text(str(uuid.uuid4()))
else:
node.run()
node.save(parameter=False)
Expand Down
29 changes: 24 additions & 5 deletions zntrack/core/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import contextlib
import dataclasses
import functools
import importlib
import logging
import pathlib
Expand Down Expand Up @@ -47,6 +48,12 @@ class NodeStatus:
rev: str = None

def get_file_system(self) -> dvc.api.DVCFileSystem:
"""Get the file system of the Node."""
log.warning("Deprecated. Use 'state.fs' instead.")
return self.fs

@functools.cached_property
def fs(self) -> dvc.api.DVCFileSystem:
"""Get the file system of the Node."""
return dvc.api.DVCFileSystem(
url=self.remote,
Expand Down Expand Up @@ -159,16 +166,26 @@ def save(self, parameter: bool = True, results: bool = True) -> None:
def run(self) -> None:
"""Run the node's code."""

def load(self, lazy: bool = None) -> None:
"""Load the node's output from disk."""
from zntrack.fields.field import Field
def load(self, lazy: bool = None, results: bool = True) -> None:
"""Load the node's output from disk.
Attributes
----------
lazy : bool, default = None
Whether to load the node lazily. If None, the value from the config is used.
results : bool, default = True
Whether to load the results. If False, only the parameters are loaded.
"""
from zntrack.fields.field import Field, FieldGroup

kwargs = {} if lazy is None else {"lazy": lazy}
self.state.loaded = True # we assume loading will be successful.
try:
with config.updated_config(**kwargs):
# TODO: it would be much nicer not to use a global config object here.
for attr in zninit.get_descriptors(Field, self=self):
if attr.group == FieldGroup.RESULT and not results:
continue
attr.load(self)
except KeyError as err:
raise exceptions.NodeNotAvailableError(self) from err
Expand All @@ -177,7 +194,9 @@ def load(self, lazy: bool = None) -> None:
self._post_load_()

@classmethod
def from_rev(cls, name=None, remote=None, rev=None, lazy: bool = None) -> Node:
def from_rev(
cls, name=None, remote=None, rev=None, lazy: bool = None, results: bool = True
) -> Node:
"""Create a Node instance from an experiment."""
node = cls.__new__(cls)
node.name = name
Expand All @@ -198,7 +217,7 @@ def from_rev(cls, name=None, remote=None, rev=None, lazy: bool = None) -> Node:

kwargs = {} if lazy is None else {"lazy": lazy}
with config.updated_config(**kwargs):
node.load()
node.load(results=results)

return node

Expand Down
2 changes: 1 addition & 1 deletion zntrack/fields/dvc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def get_data(self, instance: "Node") -> any:
The value of the field from the configuration file.
"""
zntrack_dict = json.loads(
instance.state.get_file_system().read_text("zntrack.json"),
instance.state.fs.read_text("zntrack.json"),
)
return json.loads(
json.dumps(zntrack_dict[instance.name][self.name]), cls=znjson.ZnDecoder
Expand Down
4 changes: 2 additions & 2 deletions zntrack/fields/meta/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def save(self, instance):

def get_data(self, instance: "Node") -> any:
"""Get the value of the field from the file."""
dvc_dict = yaml.safe_load(instance.state.get_file_system().read_text("dvc.yaml"))
dvc_dict = yaml.safe_load(instance.state.fs.read_text("dvc.yaml"))
return dvc_dict["stages"][instance.name]["meta"].get(self.name, None)

def get_stage_add_argument(self, instance) -> typing.List[tuple]:
Expand Down Expand Up @@ -72,7 +72,7 @@ def save(self, instance):

def get_data(self, instance: "Node") -> any:
"""Get the value of the field from the file."""
env_dict = yaml.safe_load(instance.state.get_file_system().read_text("env.yaml"))
env_dict = yaml.safe_load(instance.state.fs.read_text("env.yaml"))
return env_dict.get(instance.name, {}).get(self.name, None)

def get_stage_add_argument(self, instance) -> typing.List[tuple]:
Expand Down
12 changes: 5 additions & 7 deletions zntrack/fields/zn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def encode(self, obj: slice) -> dict:

def decode(self, value: dict) -> znflow.Connection:
"""Create znflow.Connection object from dict."""
return slice(*value.values())
return slice(value["start"], value["stop"], value["step"])


znjson.config.register(SliceConverter)
Expand Down Expand Up @@ -137,7 +137,7 @@ def save(self, instance: "Node"):
def get_data(self, instance: "Node") -> any:
"""Get the value of the field from the file."""
file = self.get_files(instance)[0]
params_dict = yaml.safe_load(instance.state.get_file_system().read_text(file))
params_dict = yaml.safe_load(instance.state.fs.read_text(file))
value = params_dict[instance.name].get(self.name, None)
return json.loads(json.dumps(value), cls=znjson.ZnDecoder)

Expand Down Expand Up @@ -212,7 +212,7 @@ def get_data(self, instance: "Node") -> any:
"""Get the value of the field from the file."""
file = self.get_files(instance)[0]
return json.loads(
instance.state.get_file_system().read_text(file.as_posix()),
instance.state.fs.read_text(file.as_posix()),
cls=znjson.ZnDecoder,
)

Expand Down Expand Up @@ -259,9 +259,7 @@ def save(self, instance: "Node"):
def get_data(self, instance: "Node") -> any:
"""Get the value of the field from the file."""
file = self.get_files(instance)[0]
return pd.read_csv(
instance.state.get_file_system().open(file.as_posix()), index_col=0
)
return pd.read_csv(instance.state.fs.open(file.as_posix()), index_col=0)

def get_stage_add_argument(self, instance) -> typing.List[tuple]:
"""Get the dvc command for this field."""
Expand Down Expand Up @@ -348,7 +346,7 @@ def save(self, instance: "Node"):
def get_data(self, instance: "Node") -> any:
"""Get the value of the field from the file."""
zntrack_dict = json.loads(
instance.state.get_file_system().read_text("zntrack.json"),
instance.state.fs.read_text("zntrack.json"),
)
value = zntrack_dict[instance.name][self.name]

Expand Down

0 comments on commit 4f05747

Please sign in to comment.