Skip to content

Commit

Permalink
fix #116
Browse files Browse the repository at this point in the history
  • Loading branch information
PythonFZ committed Jan 20, 2022
1 parent 70ce174 commit 653cd27
Show file tree
Hide file tree
Showing 4 changed files with 124 additions and 6 deletions.
64 changes: 64 additions & 0 deletions tests/integration_tests/test_zn_plots.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import os
import shutil
import subprocess

import pandas as pd
import pytest

from zntrack import Node, zn


@pytest.fixture
def proj_path(tmp_path):
shutil.copy(__file__, tmp_path)
os.chdir(tmp_path)
subprocess.check_call(["git", "init"])
subprocess.check_call(["dvc", "init"])

return tmp_path


class WritePlots(Node):
plots: pd.DataFrame = zn.plots()

def run(self):
self.plots = pd.DataFrame({"value": [x for x in range(100)]})
self.plots.index.name = "index"


class WritePlotsNoIndex(Node):
plots: pd.DataFrame = zn.plots()

def run(self):
self.plots = pd.DataFrame({"value": [x for x in range(100)]})


class WritePlotsWrongData(Node):
plots: pd.DataFrame = zn.plots()

def run(self):
self.plots = {"value": [x for x in range(100)]}


def test_write_plots(proj_path):
WritePlots().write_graph(no_exec=False)
subprocess.check_call(["dvc", "plots", "show"])


def test_load_plots(proj_path):
WritePlots().write_graph(no_exec=False)

df = pd.DataFrame({"value": [x for x in range(100)]})
df.index.name = "index"

assert df.equals(WritePlots.load().plots)


def test_write_plots_value_error(proj_path):
with pytest.raises(ValueError):
WritePlotsNoIndex().run_and_save()


def test_write_plots_type_error(proj_path):
with pytest.raises(TypeError):
WritePlotsWrongData().run_and_save()
7 changes: 1 addition & 6 deletions zntrack/core/dvcgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,12 +311,7 @@ def write_graph(
# Handle Zn Options
elif option.metadata.zntrack_type in ["zn", "metadata"]:
zn_options_set.add(
(
f"--{option.metadata.dvc_args}",
pathlib.Path("nodes")
/ self.node_name
/ f"{option.metadata.dvc_option}.json",
)
(f"--{option.metadata.dvc_args}", option.get_filename(self).path)
)
elif option.metadata.zntrack_type == "deps":
script += handle_deps(value)
Expand Down
7 changes: 7 additions & 0 deletions zntrack/zn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,13 @@

log = logging.getLogger(__name__)

try:
from .plots import plots

__all__ = [plots.__name__]
except ImportError:
pass


# module class definitions to be used via zn.<option>
# detailed explanations on https://dvc.org/doc/command-reference/run#options
Expand Down
52 changes: 52 additions & 0 deletions zntrack/zn/plots.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import logging
import pathlib

import pandas as pd

from zntrack.core.parameter import File, ZnTrackOption
from zntrack.descriptor import Metadata

log = logging.getLogger(__name__)


class plots(ZnTrackOption):
metadata = Metadata(dvc_option="plots_no_cache", zntrack_type="zn")

def get_filename(self, instance) -> File:
"""Overwrite filename to csv"""
return File(
path=pathlib.Path(
"nodes", instance.node_name, f"{self.metadata.dvc_option}.csv"
),
tracked=True,
)

def save(self, instance):
"""Save value with pd.DataFrame.to_csv"""
value = self.__get__(instance, self.owner)

if value is None:
return

if not isinstance(value, pd.DataFrame):
raise TypeError(
f"zn.plots() only supports <pd.DataFrame> and not {type(value)}"
)

if value.index.name is None:
raise ValueError(
"pd.DataFrame must have an index name! You can set the name via"
" DataFrame.index.name = <index name>."
)

file = self.get_filename(instance)
file.path.parent.mkdir(exist_ok=True, parents=True)
value.to_csv(file.path)

def load(self, instance):
"""Load value with pd.read_csv"""
file = self.get_filename(instance)
try:
instance.__dict__.update({self.name: pd.read_csv(file.path, index_col=0)})
except (FileNotFoundError, KeyError):
pass

0 comments on commit 653cd27

Please sign in to comment.