Skip to content

Commit ab23bcd

Browse files
authored
experiments: support dvc repro --exp command line params (#4331)
* experiments: support passing params values instead of reading from user's workspace * repro: add --params option for experiments * support toml params * update invalid param error message * update tests
1 parent 4759517 commit ab23bcd

File tree

6 files changed

+131
-16
lines changed

6 files changed

+131
-16
lines changed

dvc/command/repro.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ def run(self):
4444
queue=self.args.queue,
4545
run_all=self.args.run_all,
4646
jobs=self.args.jobs,
47+
params=self.args.params,
4748
)
4849

4950
if len(stages) == 0:
@@ -177,6 +178,13 @@ def add_parser(subparsers, parent_parser):
177178
default=False,
178179
help=argparse.SUPPRESS,
179180
)
181+
repro_parser.add_argument(
182+
"--params",
183+
action="append",
184+
default=[],
185+
help="Declare parameter values for an experiment.",
186+
metavar="[<filename>:]<params_list>",
187+
)
180188
repro_parser.add_argument(
181189
"--queue", action="store_true", default=False, help=argparse.SUPPRESS
182190
)

dvc/exceptions.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,14 @@ def __init__(self, path):
194194
)
195195

196196

197+
class TOMLFileCorruptedError(DvcException):
198+
def __init__(self, path):
199+
path = relpath(path)
200+
super().__init__(
201+
f"unable to read: '{path}', TOML file structure is corrupted"
202+
)
203+
204+
197205
class RecursiveAddingWhileUsingFilename(DvcException):
198206
def __init__(self):
199207
super().__init__(

dvc/repo/experiments/__init__.py

Lines changed: 65 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,16 @@
22
import os
33
import re
44
import tempfile
5+
from collections import defaultdict
6+
from collections.abc import Mapping
57
from concurrent.futures import ProcessPoolExecutor, as_completed
68
from contextlib import contextmanager
79
from typing import Iterable, Optional
810

911
from funcy import cached_property
1012

1113
from dvc.exceptions import DvcException
14+
from dvc.path_info import PathInfo
1215
from dvc.repo.experiments.executor import ExperimentExecutor, LocalExecutor
1316
from dvc.scm.git import Git
1417
from dvc.stage.serialize import to_lockfile
@@ -139,21 +142,39 @@ def _scm_checkout(self, rev):
139142
logger.debug("Checking out experiment commit '%s'", rev)
140143
self.scm.checkout(rev)
141144

142-
def _stash_exp(self, *args, **kwargs):
145+
def _stash_exp(self, *args, params: Optional[dict] = None, **kwargs):
143146
"""Stash changes from the current (parent) workspace as an experiment.
147+
148+
Args:
149+
params: Optional dictionary of parameter values to be used.
150+
Values take priority over any parameters specified in the
151+
user's workspace.
144152
"""
145153
rev = self.scm.get_rev()
154+
155+
# patch user's workspace into experiments clone
146156
tmp = tempfile.NamedTemporaryFile(delete=False).name
147157
try:
148158
self.repo.scm.repo.git.diff(patch=True, output=tmp)
149159
if os.path.getsize(tmp):
150160
logger.debug("Patching experiment workspace")
151161
self.scm.repo.git.apply(tmp)
152-
else:
162+
elif not params:
163+
# experiment matches original baseline
153164
raise UnchangedExperimentError(rev)
154165
finally:
155166
remove(tmp)
167+
168+
# update experiment params from command line
169+
if params:
170+
self._update_params(params)
171+
172+
# save additional repro command line arguments
156173
self._pack_args(*args, **kwargs)
174+
175+
# save experiment as a stash commit w/message containing baseline rev
176+
# (stash commits are merge commits and do not contain a parent commit
177+
# SHA)
157178
msg = f"{self.STASH_MSG_PREFIX}{rev}"
158179
self.scm.repo.git.stash("push", "-m", msg)
159180
return self.scm.resolve_rev("stash@{0}")
@@ -166,6 +187,36 @@ def _unpack_args(self, tree=None):
166187
args_file = os.path.join(self.exp_dvc.tmp_dir, self.PACKED_ARGS_FILE)
167188
return ExperimentExecutor.unpack_repro_args(args_file, tree=tree)
168189

190+
def _update_params(self, params: dict):
191+
"""Update experiment params files with the specified values."""
192+
from dvc.utils.toml import dump_toml, parse_toml_for_update
193+
from dvc.utils.yaml import dump_yaml, parse_yaml_for_update
194+
195+
logger.debug("Using experiment params '%s'", params)
196+
197+
# recursive dict update
198+
def _update(dict_, other):
199+
for key, value in other.items():
200+
if isinstance(value, Mapping):
201+
dict_[key] = _update(dict_.get(key, {}), value)
202+
else:
203+
dict_[key] = value
204+
return dict_
205+
206+
loaders = defaultdict(lambda: parse_yaml_for_update)
207+
loaders.update({".toml": parse_toml_for_update})
208+
dumpers = defaultdict(lambda: dump_yaml)
209+
dumpers.update({".toml": dump_toml})
210+
211+
for params_fname in params:
212+
path = PathInfo(self.exp_dvc.root_dir) / params_fname
213+
with self.exp_dvc.tree.open(path, "r") as fobj:
214+
text = fobj.read()
215+
suffix = path.suffix.lower()
216+
data = loaders[suffix](text, path)
217+
_update(data, params[params_fname])
218+
dumpers[suffix](path, data)
219+
169220
def _commit(self, exp_hash, check_exists=True, branch=True):
170221
"""Commit stages as an experiment and return the commit SHA."""
171222
if not self.scm.is_dirty():
@@ -207,23 +258,19 @@ def reproduce_queued(self, **kwargs):
207258
)
208259
return results
209260

210-
def new(self, *args, workspace=True, **kwargs):
261+
def new(self, *args, **kwargs):
211262
"""Create a new experiment.
212263
213264
Experiment will be reproduced and checked out into the user's
214265
workspace.
215266
"""
216267
rev = self.repo.scm.get_rev()
217268
self._scm_checkout(rev)
218-
if workspace:
219-
try:
220-
stash_rev = self._stash_exp(*args, **kwargs)
221-
except UnchangedExperimentError as exc:
222-
logger.info("Reproducing existing experiment '%s'.", rev[:7])
223-
raise exc
224-
else:
225-
# configure params via command line here
226-
pass
269+
try:
270+
stash_rev = self._stash_exp(*args, **kwargs)
271+
except UnchangedExperimentError as exc:
272+
logger.info("Reproducing existing experiment '%s'.", rev[:7])
273+
raise exc
227274
logger.debug(
228275
"Stashed experiment '%s' for future execution.", stash_rev[:7]
229276
)
@@ -365,8 +412,10 @@ def checkout_exp(self, rev):
365412
tmp = tempfile.NamedTemporaryFile(delete=False).name
366413
self.scm.repo.head.commit.diff("HEAD~1", patch=True, output=tmp)
367414

368-
logger.debug("Stashing workspace changes.")
369-
self.repo.scm.repo.git.stash("push")
415+
dirty = self.repo.scm.is_dirty()
416+
if dirty:
417+
logger.debug("Stashing workspace changes.")
418+
self.repo.scm.repo.git.stash("push")
370419

371420
try:
372421
if os.path.getsize(tmp):
@@ -379,7 +428,8 @@ def checkout_exp(self, rev):
379428
raise DvcException("failed to apply experiment changes.")
380429
finally:
381430
remove(tmp)
382-
self._unstash_workspace()
431+
if dirty:
432+
self._unstash_workspace()
383433

384434
if need_checkout:
385435
dvc_checkout(self.repo)

dvc/repo/reproduce.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def reproduce(
6060
recursive=False,
6161
pipeline=False,
6262
all_pipelines=False,
63-
**kwargs
63+
**kwargs,
6464
):
6565
from dvc.utils import parse_target
6666

@@ -71,6 +71,7 @@ def reproduce(
7171
)
7272

7373
experiment = kwargs.pop("experiment", False)
74+
params = _parse_params(kwargs.pop("params", []))
7475
queue = kwargs.pop("queue", False)
7576
run_all = kwargs.pop("run_all", False)
7677
jobs = kwargs.pop("jobs", 1)
@@ -81,6 +82,7 @@ def reproduce(
8182
target=target,
8283
recursive=recursive,
8384
all_pipelines=all_pipelines,
85+
params=params,
8486
queue=queue,
8587
run_all=run_all,
8688
jobs=jobs,
@@ -116,6 +118,31 @@ def reproduce(
116118
return _reproduce_stages(active_graph, targets, **kwargs)
117119

118120

121+
def _parse_params(path_params):
122+
from flatten_json import unflatten
123+
from yaml import safe_load, YAMLError
124+
from dvc.dependency.param import ParamsDependency
125+
126+
ret = {}
127+
for path_param in path_params:
128+
path, _, params_str = path_param.rpartition(":")
129+
# remove empty strings from params, on condition such as `-p "file1:"`
130+
params = {}
131+
for param_str in filter(bool, params_str.split(",")):
132+
try:
133+
# interpret value strings using YAML rules
134+
key, value = param_str.split("=")
135+
params[key] = safe_load(value)
136+
except (ValueError, YAMLError):
137+
raise InvalidArgumentError(
138+
f"Invalid param/value pair '{param_str}'"
139+
)
140+
if not path:
141+
path = ParamsDependency.DEFAULT_PARAMS_FILE
142+
ret[path] = unflatten(params, ".")
143+
return ret
144+
145+
119146
def _reproduce_experiments(repo, run_all=False, jobs=1, **kwargs):
120147
if run_all:
121148
return repo.experiments.reproduce_queued(jobs=jobs)

dvc/utils/toml.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import toml
2+
3+
from dvc.exceptions import TOMLFileCorruptedError
4+
5+
6+
def parse_toml_for_update(text, path):
7+
"""Parses text into Python structure.
8+
9+
NOTE: Python toml package does not currently use ordered dicts, so
10+
keys may be re-ordered between load/dump, but this function will at
11+
least preserve comments.
12+
"""
13+
try:
14+
return toml.loads(text, decoder=toml.TomlPreserveCommentDecoder())
15+
except toml.TomlDecodeError as exc:
16+
raise TOMLFileCorruptedError(path) from exc
17+
18+
19+
def dump_toml(path, data):
20+
with open(path, "w", encoding="utf-8") as fobj:
21+
toml.dump(data, fobj, encoder=toml.TomlPreserveCommentEncoder())

tests/unit/command/test_repro.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"recursive": False,
1616
"force_downstream": False,
1717
"experiment": False,
18+
"params": [],
1819
"queue": False,
1920
"run_all": False,
2021
"jobs": None,

0 commit comments

Comments
 (0)