Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create DVC experiment on live.end. #366

Merged
merged 2 commits into from
Dec 1, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 62 additions & 1 deletion src/dvclive/dvc.py

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions src/dvclive/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,5 @@
STUDIO_ENDPOINT = "STUDIO_ENDPOINT"
STUDIO_REPO_URL = "STUDIO_REPO_URL"
STUDIO_TOKEN = "STUDIO_TOKEN" # nosec B105
DVC_EXP_BASELINE_REV = "DVC_EXP_BASELINE_REV"
DVC_EXP_NAME = "DVC_EXP_NAME"
7 changes: 6 additions & 1 deletion src/dvclive/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,16 @@ def __init__(
dir: Optional[str] = None, # noqa pylint: disable=redefined-builtin
resume: bool = False,
report: Optional[str] = "auto",
save_dvc_exp: bool = False,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@dberenbaum named the option like this. wdyt?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know if it makes sense for the dvc.yaml saving part and whether those should be coupled together, but I think it's fine for now.

Copy link
Contributor Author

@daavoo daavoo Dec 1, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought your concerns applied to both the git ref and the dvc.yaml creation.

):

super().__init__()
self._prefix = prefix
self._live_init: Dict[str, Any] = {"resume": resume, "report": report}
self._live_init: Dict[str, Any] = {
"resume": resume,
"report": report,
"save_dvc_exp": save_dvc_exp,
}
if dir is not None:
self._live_init["dir"] = dir
self._experiment = experiment
Expand Down
178 changes: 117 additions & 61 deletions src/dvclive/live.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
import os
import shutil
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Set, Union

from ruamel.yaml.representer import RepresenterError

from . import env
from .dvc import make_checkpoint
from .dvc import get_dvc_repo, make_checkpoint, make_dvcyaml, random_exp_name
from .error import (
InvalidDataTypeError,
InvalidParameterTypeError,
Expand Down Expand Up @@ -40,96 +40,138 @@ def __init__(
dir: str = "dvclive", # noqa pylint: disable=redefined-builtin
resume: bool = False,
report: Optional[str] = "auto",
save_dvc_exp: bool = False,
):
self.summary: Dict[str, Any] = {}

self._dir: str = dir
self._resume: bool = resume or env2bool(env.DVCLIVE_RESUME)
self._ended: bool = False
self.studio_url = os.getenv(env.STUDIO_REPO_URL, None)
self.studio_token = os.getenv(env.STUDIO_TOKEN, None)
self.rev = None

if report == "auto":
if self.studio_url and self.studio_token:
report = "studio"
elif env2bool("CI") and matplotlib_installed():
report = "md"
else:
report = "html"
else:
if report not in {None, "html", "md"}:
raise ValueError(
"`report` can only be `None`, `auto`, `html` or `md`"
)

self.report_mode: Optional[str] = report
self.report_file = ""

self.summary: Dict[str, Any] = {}
self._save_dvc_exp: bool = save_dvc_exp
self._step: Optional[int] = None
self._metrics: Dict[str, Any] = {}
self._images: Dict[str, Any] = {}
self._plots: Dict[str, Any] = {}
self._params: Dict[str, Any] = {}
self._plots: Dict[str, Any] = {}

self._init_paths()
os.makedirs(self.dir, exist_ok=True)

if self.report_mode in ("html", "md"):
if not self.report_file:
self.report_file = os.path.join(self.dir, f"report.{report}")
out = Path(self.report_file).resolve()
logger.info(f"Report file (if generated): {out}")
self._report_mode: Optional[str] = report
self._init_report()

if self._resume:
self._read_params()
self._step = self.read_step()
if self._step != 0:
self._step += 1
logger.info(f"Resumed from step {self._step}")
self._init_resume()
else:
self._cleanup()
self._init_cleanup()

self._baseline_rev: Optional[str] = None
self._exp_name: Optional[str] = None
self._inside_dvc_exp: bool = False
self._dvc_repo = None
self._init_dvc()

self._studio_url: Optional[str] = None
self._studio_token: Optional[str] = None
self._latest_studio_step = self.step if resume else -1
if self.report_mode == "studio":
from scmrepo.git import Git
self._studio_events_to_skip: Set[str] = set()
self._init_studio()

self.rev = Git().get_rev()
def _init_resume(self):
self._read_params()
self._step = self.read_step()
if self._step != 0:
self._step += 1
logger.debug(f"{self._step=}")

if not post_to_studio(self, "start", logger):
logger.warning(
"`post_to_studio` `start` event failed. "
"`studio` report cancelled."
)
self.report_mode = None

def _cleanup(self):
def _init_cleanup(self):
for plot_type in PLOT_TYPES:
shutil.rmtree(
Path(self.plots_dir) / plot_type.subfolder, ignore_errors=True
)

for f in (self.metrics_file, self.report_file, self.params_file):
if os.path.exists(f):
if f and os.path.exists(f):
os.remove(f)

def _init_paths(self):
os.makedirs(self.dir, exist_ok=True)
def _init_dvc(self):
self._dvc_repo = get_dvc_repo()
if os.getenv(env.DVC_EXP_BASELINE_REV, None):
# `dvc exp` execution
self._baseline_rev = os.getenv(env.DVC_EXP_BASELINE_REV, "")
self._exp_name = os.getenv(env.DVC_EXP_NAME, "")
self._inside_dvc_exp = True
elif self._save_dvc_exp:
# `Python Only` execution
# TODO: How to handle `dvc repro` execution?
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@dberenbaum what do you think we should do if DVCLive is used inside a dvc pipeline that has been executed with dvc repro?

I think we should skip the experiment creation.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, let's skip it. We may still want to call make_dvcyaml, but we can skip that also for now if it simplifies things.

if self._dvc_repo is not None:
self._baseline_rev = self._dvc_repo.scm.get_rev()
self._exp_name = random_exp_name(
self._dvc_repo, self._baseline_rev
)
make_dvcyaml(self)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not as sure that we should skip make_dvcyaml, but we can follow up on this separately from this PR.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another follow-up: pass an experiment name to use.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And another: consider whether to include the dvclive user script/notebook in the tracked files.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And another: consider whether to include the dvclive user script/notebook in the tracked files.

Do you mean including it in the include_untracked list passed to experiments.save, or something else?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, whether to include it in include_untracked. I'm not sure it's needed, so it was more of a discussion point than a request. There's something clean about the default being to only save the live dir.


def _init_studio(self):
if not self._dvc_repo:
logger.warning("`studio` report can't be used without a DVC Repo.")
return

self._studio_url = os.getenv(env.STUDIO_REPO_URL, None)
self._studio_token = os.getenv(env.STUDIO_TOKEN, None)

if self._studio_url and self._studio_token:
if self._inside_dvc_exp:
logger.debug(
"Skipping `post_to_studio` `start` and `done` events."
)
self._studio_events_to_skip.add("start")
self._studio_events_to_skip.add("done")
elif not post_to_studio(self, "start", logger):
logger.warning(
"`post_to_studio` `start` event failed. "
"`studio` report cancelled."
)
self._studio_events_to_skip.add("start")
self._studio_events_to_skip.add("data")
self._studio_events_to_skip.add("done")
logger.debug("Skipping `studio` report.")

def _init_report(self):
if self._report_mode == "auto":
if env2bool("CI") and matplotlib_installed():
self._report_mode = "md"
else:
self._report_mode = "html"
elif self._report_mode not in {None, "html", "md"}:
raise ValueError(
"`report` can only be `None`, `auto`, `html` or `md`"
)
logger.debug(f"{self._report_mode=}")

@property
def dir(self):
def dir(self) -> str:
return self._dir

@property
def params_file(self):
def params_file(self) -> str:
return os.path.join(self.dir, "params.yaml")

@property
def metrics_file(self):
def metrics_file(self) -> str:
return os.path.join(self.dir, "metrics.json")

@property
def plots_dir(self):
def dvc_file(self) -> str:
return os.path.join(self.dir, "dvc.yaml")

@property
def plots_dir(self) -> str:
return os.path.join(self.dir, "plots")

@property
def report_file(self) -> Optional[str]:
if self._report_mode in ("html", "md"):
return os.path.join(self.dir, f"report.{self._report_mode}")
return None

@property
def step(self) -> int:
return self._step or 0
Expand Down Expand Up @@ -227,29 +269,43 @@ def make_summary(self):
dump_json(self.summary, self.metrics_file, cls=NumpyEncoder)

def make_report(self):
if self.report_mode == "studio":
if (
self._studio_url
and self._studio_token
and "data" not in self._studio_events_to_skip
):
if not post_to_studio(self, "data", logger):
logger.warning(
"`post_to_studio` `data` event failed."
" Data will be resent on next call."
)
else:
self._latest_studio_step = self.step
elif self.report_mode is not None:

if self._report_mode is not None:
make_report(self)
if self.report_mode == "html" and env2bool(env.DVCLIVE_OPEN):
if self._report_mode == "html" and env2bool(env.DVCLIVE_OPEN):
open_file_in_browser(self.report_file)

def end(self):
self.make_summary()
if self.report_mode == "studio":
if not self._ended:
if self._studio_url and self._studio_token:
if "done" not in self._studio_events_to_skip:
if not post_to_studio(self, "done", logger):
logger.warning("`post_to_studio` `done` event failed.")
self._ended = True
self._studio_events_to_skip.add("done")
else:
self.make_report()

if (
self._dvc_repo is not None
and not self._inside_dvc_exp
and self._save_dvc_exp
):
self._dvc_repo.experiments.save(
name=self._exp_name, include_untracked=self.dir
)

def make_checkpoint(self):
if env2bool(env.DVC_CHECKPOINT):
make_checkpoint()
Expand Down
6 changes: 3 additions & 3 deletions src/dvclive/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,9 +125,9 @@ def make_report(live: "Live"):
get_plot_renderers(plots_path / SKLearnPlot.subfolder, live)
)

if live.report_mode == "html":
if live._report_mode == "html":
render_html(renderers, live.report_file, refresh_seconds=5)
elif live.report_mode == "md":
elif live._report_mode == "md":
render_markdown(renderers, live.report_file)
else:
raise ValueError(f"Invalid `mode` {live.report_mode}.")
raise ValueError(f"Invalid `mode` {live._report_mode}.")
7 changes: 4 additions & 3 deletions src/dvclive/studio.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# pylint: disable=protected-access
from os import getenv

from dvclive.env import STUDIO_ENDPOINT
Expand Down Expand Up @@ -46,8 +47,8 @@ def post_to_studio(live, event_type, logger) -> bool:

data = {
"type": event_type,
"repo_url": live.studio_url,
"rev": live.rev,
"repo_url": live._studio_url,
"rev": live._baseline_rev,
"client": "dvclive",
}

Expand All @@ -65,7 +66,7 @@ def post_to_studio(live, event_type, logger) -> bool:
json=data,
headers={
"Content-type": "application/json",
"Authorization": f"token {live.studio_token}",
"Authorization": f"token {live._studio_token}",
},
timeout=5,
)
Expand Down
Loading