-
Notifications
You must be signed in to change notification settings - Fork 38
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Create DVC experiment on `live.end`.
#366
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,12 +3,12 @@ | |
import os | ||
import shutil | ||
from pathlib import Path | ||
from typing import Any, Dict, List, Optional, Union | ||
from typing import Any, Dict, List, Optional, Set, Union | ||
|
||
from ruamel.yaml.representer import RepresenterError | ||
|
||
from . import env | ||
from .dvc import make_checkpoint | ||
from .dvc import get_dvc_repo, make_checkpoint, make_dvcyaml, random_exp_name | ||
from .error import ( | ||
InvalidDataTypeError, | ||
InvalidParameterTypeError, | ||
|
@@ -40,96 +40,138 @@ def __init__( | |
dir: str = "dvclive", # noqa pylint: disable=redefined-builtin | ||
resume: bool = False, | ||
report: Optional[str] = "auto", | ||
save_dvc_exp: bool = False, | ||
): | ||
self.summary: Dict[str, Any] = {} | ||
|
||
self._dir: str = dir | ||
self._resume: bool = resume or env2bool(env.DVCLIVE_RESUME) | ||
self._ended: bool = False | ||
self.studio_url = os.getenv(env.STUDIO_REPO_URL, None) | ||
self.studio_token = os.getenv(env.STUDIO_TOKEN, None) | ||
self.rev = None | ||
|
||
if report == "auto": | ||
if self.studio_url and self.studio_token: | ||
report = "studio" | ||
elif env2bool("CI") and matplotlib_installed(): | ||
report = "md" | ||
else: | ||
report = "html" | ||
else: | ||
if report not in {None, "html", "md"}: | ||
raise ValueError( | ||
"`report` can only be `None`, `auto`, `html` or `md`" | ||
) | ||
|
||
self.report_mode: Optional[str] = report | ||
self.report_file = "" | ||
|
||
self.summary: Dict[str, Any] = {} | ||
self._save_dvc_exp: bool = save_dvc_exp | ||
self._step: Optional[int] = None | ||
self._metrics: Dict[str, Any] = {} | ||
self._images: Dict[str, Any] = {} | ||
self._plots: Dict[str, Any] = {} | ||
self._params: Dict[str, Any] = {} | ||
self._plots: Dict[str, Any] = {} | ||
|
||
self._init_paths() | ||
os.makedirs(self.dir, exist_ok=True) | ||
|
||
if self.report_mode in ("html", "md"): | ||
if not self.report_file: | ||
self.report_file = os.path.join(self.dir, f"report.{report}") | ||
out = Path(self.report_file).resolve() | ||
logger.info(f"Report file (if generated): {out}") | ||
self._report_mode: Optional[str] = report | ||
self._init_report() | ||
|
||
if self._resume: | ||
self._read_params() | ||
self._step = self.read_step() | ||
if self._step != 0: | ||
self._step += 1 | ||
logger.info(f"Resumed from step {self._step}") | ||
self._init_resume() | ||
else: | ||
self._cleanup() | ||
self._init_cleanup() | ||
|
||
self._baseline_rev: Optional[str] = None | ||
self._exp_name: Optional[str] = None | ||
self._inside_dvc_exp: bool = False | ||
self._dvc_repo = None | ||
self._init_dvc() | ||
|
||
self._studio_url: Optional[str] = None | ||
self._studio_token: Optional[str] = None | ||
self._latest_studio_step = self.step if resume else -1 | ||
if self.report_mode == "studio": | ||
from scmrepo.git import Git | ||
self._studio_events_to_skip: Set[str] = set() | ||
self._init_studio() | ||
|
||
self.rev = Git().get_rev() | ||
def _init_resume(self): | ||
self._read_params() | ||
self._step = self.read_step() | ||
if self._step != 0: | ||
self._step += 1 | ||
logger.debug(f"{self._step=}") | ||
|
||
if not post_to_studio(self, "start", logger): | ||
logger.warning( | ||
"`post_to_studio` `start` event failed. " | ||
"`studio` report cancelled." | ||
) | ||
self.report_mode = None | ||
|
||
def _cleanup(self): | ||
def _init_cleanup(self): | ||
for plot_type in PLOT_TYPES: | ||
shutil.rmtree( | ||
Path(self.plots_dir) / plot_type.subfolder, ignore_errors=True | ||
) | ||
|
||
for f in (self.metrics_file, self.report_file, self.params_file): | ||
if os.path.exists(f): | ||
if f and os.path.exists(f): | ||
os.remove(f) | ||
|
||
def _init_paths(self): | ||
os.makedirs(self.dir, exist_ok=True) | ||
def _init_dvc(self): | ||
self._dvc_repo = get_dvc_repo() | ||
if os.getenv(env.DVC_EXP_BASELINE_REV, None): | ||
# `dvc exp` execution | ||
self._baseline_rev = os.getenv(env.DVC_EXP_BASELINE_REV, "") | ||
self._exp_name = os.getenv(env.DVC_EXP_NAME, "") | ||
self._inside_dvc_exp = True | ||
elif self._save_dvc_exp: | ||
# `Python Only` execution | ||
# TODO: How to handle `dvc repro` execution? | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @dberenbaum what do you think we should do if DVCLive is used inside a dvc pipeline that has been executed with `dvc repro`? I think we should skip the experiment creation. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, let's skip it. We may still want to call |
||
if self._dvc_repo is not None: | ||
self._baseline_rev = self._dvc_repo.scm.get_rev() | ||
self._exp_name = random_exp_name( | ||
self._dvc_repo, self._baseline_rev | ||
) | ||
make_dvcyaml(self) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not as sure that we should skip There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Another follow up: pass an experiment name to use. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. And another: consider whether to include the dvclive user script/notebook in the tracked files. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Do you mean to include it in the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, whether to include it in |
||
|
||
def _init_studio(self): | ||
if not self._dvc_repo: | ||
logger.warning("`studio` report can't be used without a DVC Repo.") | ||
return | ||
|
||
self._studio_url = os.getenv(env.STUDIO_REPO_URL, None) | ||
self._studio_token = os.getenv(env.STUDIO_TOKEN, None) | ||
|
||
if self._studio_url and self._studio_token: | ||
if self._inside_dvc_exp: | ||
logger.debug( | ||
"Skipping `post_to_studio` `start` and `done` events." | ||
) | ||
self._studio_events_to_skip.add("start") | ||
self._studio_events_to_skip.add("done") | ||
elif not post_to_studio(self, "start", logger): | ||
logger.warning( | ||
"`post_to_studio` `start` event failed. " | ||
"`studio` report cancelled." | ||
) | ||
self._studio_events_to_skip.add("start") | ||
self._studio_events_to_skip.add("data") | ||
self._studio_events_to_skip.add("done") | ||
logger.debug("Skipping `studio` report.") | ||
|
||
def _init_report(self): | ||
if self._report_mode == "auto": | ||
if env2bool("CI") and matplotlib_installed(): | ||
self._report_mode = "md" | ||
else: | ||
self._report_mode = "html" | ||
elif self._report_mode not in {None, "html", "md"}: | ||
raise ValueError( | ||
"`report` can only be `None`, `auto`, `html` or `md`" | ||
) | ||
logger.debug(f"{self._report_mode=}") | ||
|
||
@property | ||
def dir(self): | ||
def dir(self) -> str: | ||
return self._dir | ||
|
||
@property | ||
def params_file(self): | ||
def params_file(self) -> str: | ||
return os.path.join(self.dir, "params.yaml") | ||
|
||
@property | ||
def metrics_file(self): | ||
def metrics_file(self) -> str: | ||
return os.path.join(self.dir, "metrics.json") | ||
|
||
@property | ||
def plots_dir(self): | ||
def dvc_file(self) -> str: | ||
return os.path.join(self.dir, "dvc.yaml") | ||
|
||
@property | ||
def plots_dir(self) -> str: | ||
return os.path.join(self.dir, "plots") | ||
|
||
@property | ||
def report_file(self) -> Optional[str]: | ||
if self._report_mode in ("html", "md"): | ||
return os.path.join(self.dir, f"report.{self._report_mode}") | ||
return None | ||
|
||
@property | ||
def step(self) -> int: | ||
return self._step or 0 | ||
|
@@ -227,29 +269,43 @@ def make_summary(self): | |
dump_json(self.summary, self.metrics_file, cls=NumpyEncoder) | ||
|
||
def make_report(self): | ||
if self.report_mode == "studio": | ||
if ( | ||
self._studio_url | ||
and self._studio_token | ||
and "data" not in self._studio_events_to_skip | ||
): | ||
if not post_to_studio(self, "data", logger): | ||
logger.warning( | ||
"`post_to_studio` `data` event failed." | ||
" Data will be resent on next call." | ||
) | ||
else: | ||
self._latest_studio_step = self.step | ||
elif self.report_mode is not None: | ||
|
||
if self._report_mode is not None: | ||
make_report(self) | ||
if self.report_mode == "html" and env2bool(env.DVCLIVE_OPEN): | ||
if self._report_mode == "html" and env2bool(env.DVCLIVE_OPEN): | ||
open_file_in_browser(self.report_file) | ||
|
||
def end(self): | ||
self.make_summary() | ||
if self.report_mode == "studio": | ||
if not self._ended: | ||
if self._studio_url and self._studio_token: | ||
if "done" not in self._studio_events_to_skip: | ||
if not post_to_studio(self, "done", logger): | ||
logger.warning("`post_to_studio` `done` event failed.") | ||
self._ended = True | ||
self._studio_events_to_skip.add("done") | ||
else: | ||
self.make_report() | ||
|
||
if ( | ||
self._dvc_repo is not None | ||
and not self._inside_dvc_exp | ||
and self._save_dvc_exp | ||
): | ||
self._dvc_repo.experiments.save( | ||
name=self._exp_name, include_untracked=self.dir | ||
) | ||
|
||
def make_checkpoint(self): | ||
if env2bool(env.DVC_CHECKPOINT): | ||
make_checkpoint() | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@dberenbaum named the option like this. wdyt?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't know if it makes sense for the
dvc.yaml
saving part and whether those should be coupled together, but I think it's fine for now.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I thought your concerns applied to both the git ref and the
dvc.yaml
creation