-
-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
124 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
import os | ||
import shutil | ||
import subprocess | ||
|
||
import pandas as pd | ||
import pytest | ||
|
||
from zntrack import Node, zn | ||
|
||
|
||
@pytest.fixture | ||
def proj_path(tmp_path): | ||
shutil.copy(__file__, tmp_path) | ||
os.chdir(tmp_path) | ||
subprocess.check_call(["git", "init"]) | ||
subprocess.check_call(["dvc", "init"]) | ||
|
||
return tmp_path | ||
|
||
|
||
class WritePlots(Node): | ||
plots: pd.DataFrame = zn.plots() | ||
|
||
def run(self): | ||
self.plots = pd.DataFrame({"value": [x for x in range(100)]}) | ||
self.plots.index.name = "index" | ||
|
||
|
||
class WritePlotsNoIndex(Node): | ||
plots: pd.DataFrame = zn.plots() | ||
|
||
def run(self): | ||
self.plots = pd.DataFrame({"value": [x for x in range(100)]}) | ||
|
||
|
||
class WritePlotsWrongData(Node): | ||
plots: pd.DataFrame = zn.plots() | ||
|
||
def run(self): | ||
self.plots = {"value": [x for x in range(100)]} | ||
|
||
|
||
def test_write_plots(proj_path): | ||
WritePlots().write_graph(no_exec=False) | ||
subprocess.check_call(["dvc", "plots", "show"]) | ||
|
||
|
||
def test_load_plots(proj_path): | ||
WritePlots().write_graph(no_exec=False) | ||
|
||
df = pd.DataFrame({"value": [x for x in range(100)]}) | ||
df.index.name = "index" | ||
|
||
assert df.equals(WritePlots.load().plots) | ||
|
||
|
||
def test_write_plots_value_error(proj_path): | ||
with pytest.raises(ValueError): | ||
WritePlotsNoIndex().run_and_save() | ||
|
||
|
||
def test_write_plots_type_error(proj_path): | ||
with pytest.raises(TypeError): | ||
WritePlotsWrongData().run_and_save() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
import logging | ||
import pathlib | ||
|
||
import pandas as pd | ||
|
||
from zntrack.core.parameter import File, ZnTrackOption | ||
from zntrack.descriptor import Metadata | ||
|
||
log = logging.getLogger(__name__) | ||
|
||
|
||
class plots(ZnTrackOption): | ||
metadata = Metadata(dvc_option="plots_no_cache", zntrack_type="zn") | ||
|
||
def get_filename(self, instance) -> File: | ||
"""Overwrite filename to csv""" | ||
return File( | ||
path=pathlib.Path( | ||
"nodes", instance.node_name, f"{self.metadata.dvc_option}.csv" | ||
), | ||
tracked=True, | ||
) | ||
|
||
def save(self, instance): | ||
"""Save value with pd.DataFrame.to_csv""" | ||
value = self.__get__(instance, self.owner) | ||
|
||
if value is None: | ||
return | ||
|
||
if not isinstance(value, pd.DataFrame): | ||
raise TypeError( | ||
f"zn.plots() only supports <pd.DataFrame> and not {type(value)}" | ||
) | ||
|
||
if value.index.name is None: | ||
raise ValueError( | ||
"pd.DataFrame must have an index name! You can set the name via" | ||
" DataFrame.index.name = <index name>." | ||
) | ||
|
||
file = self.get_filename(instance) | ||
file.path.parent.mkdir(exist_ok=True, parents=True) | ||
value.to_csv(file.path) | ||
|
||
def load(self, instance): | ||
"""Load value with pd.read_csv""" | ||
file = self.get_filename(instance) | ||
try: | ||
instance.__dict__.update({self.name: pd.read_csv(file.path, index_col=0)}) | ||
except (FileNotFoundError, KeyError): | ||
pass |