Skip to content

Commit

Permalink
add save_plots and introduce new ZnTypes.PLOTS (#269)
Browse files Browse the repository at this point in the history
* add `save_plots` and introduce new `ZnTypes.PLOTS`

* Update test_zn_plots.py

* Update setup.py

* Update __init__.py

* Update test_zn_plots.py
  • Loading branch information
PythonFZ authored Apr 20, 2022
1 parent b40c936 commit e3f9dfd
Show file tree
Hide file tree
Showing 7 changed files with 24 additions and 8 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

setuptools.setup(
name="zntrack",
version="0.4.0",
version="0.4.1",
author="zincwarecode",
author_email="zincwarecode@gmail.com",
description="A Python package for parameter and data version control with DVC",
Expand Down
8 changes: 8 additions & 0 deletions tests/integration_tests/test_zn_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,14 @@ def test_write_plots_type_error(proj_path):
WritePlotsWrongData().run_and_save()


def test_save_plots(proj_path):
write_plots = WritePlots()
write_plots.run()
write_plots.save_plots()

assert pathlib.Path("nodes/WritePlots/plots.csv").exists()


class WriteTwoPlots(Node):
plots_a: pd.DataFrame = zn.plots()
plots_b: pd.DataFrame = zn.plots()
Expand Down
2 changes: 1 addition & 1 deletion zntrack/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
"getdeps",
]

__version__ = "0.4.0"
__version__ = "0.4.1"

logger = logging.getLogger(__name__)
logger.setLevel(config.log_level)
Expand Down
10 changes: 10 additions & 0 deletions zntrack/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,16 @@ def __init_subclass__(cls, **kwargs):
signature = inspect.Signature(parameters=signature_params)
setattr(cls, "__signature__", signature)

def save_plots(self):
"""Save the zn.plots
Similar to DVC Live this can be used to save the plots during a run
for live output.
"""
for option in self._descriptor_list:
if option.zn_type is utils.ZnTypes.PLOTS:
option.save(instance=self)

def save(self, results: bool = False):
"""Save Class state to files
Expand Down
5 changes: 1 addition & 4 deletions zntrack/core/dvcgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,10 +372,7 @@ def write_graph(
value = getattr(self, option.name)
custom_args += handle_dvc(value, option.dvc_args)
# Handle Zn Options
elif option.zn_type in [
utils.ZnTypes.RESULTS,
utils.ZnTypes.METADATA,
]:
elif option.zn_type in utils.VALUE_DVC_TRACKED:
zn_options_set.add(
(
f"--{option.dvc_args}",
Expand Down
3 changes: 2 additions & 1 deletion zntrack/utils/structs.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,13 @@ class ZnTypes(enum.Enum):
PARAMS = enum.auto()
ITERABLE = enum.auto()
RESULTS = enum.auto()
PLOTS = enum.auto()


FILE_DVC_TRACKED = [ZnTypes.DVC]
# if the getattr(instance, self.name) is an affected file,
# e.g. the dvc.<outs> is a file / list of files
VALUE_DVC_TRACKED = [ZnTypes.RESULTS, ZnTypes.METADATA]
VALUE_DVC_TRACKED = [ZnTypes.RESULTS, ZnTypes.METADATA, ZnTypes.PLOTS]


# if the internal file,
Expand Down
2 changes: 1 addition & 1 deletion zntrack/zn/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

class plots(PlotsModifyOption):
dvc_option = utils.DVCOptions.PLOTS_NO_CACHE.value
zn_type = utils.ZnTypes.RESULTS
zn_type = utils.ZnTypes.PLOTS

def get_filename(self, instance) -> pathlib.Path:
"""Overwrite filename to csv"""
Expand Down

0 comments on commit e3f9dfd

Please sign in to comment.