Skip to content

Commit

Permalink
exp run: Separate exp param parsing from other commands (#5310)
Browse files Browse the repository at this point in the history
* Separate exp param parsing from other commands

* Update exp run params metavar

* Updated experiments run param argument name
  • Loading branch information
sjawhar authored Feb 16, 2021
1 parent ae8efde commit 66f24b8
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 27 deletions.
9 changes: 5 additions & 4 deletions dvc/command/experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,7 +521,7 @@ def run(self):
queue=self.args.queue,
run_all=self.args.run_all,
jobs=self.args.jobs,
params=self.args.params,
params=self.args.set_param,
checkpoint_resume=self.args.checkpoint_resume,
reset=self.args.reset,
tmp_dir=self.args.tmp_dir,
Expand Down Expand Up @@ -1184,11 +1184,12 @@ def _add_run_common(parser):
metavar="<name>",
)
parser.add_argument(
"--params",
"-S",
"--set-param",
action="append",
default=[],
help="Use the specified param values when reproducing pipelines.",
metavar="[<filename>:]<params_list>",
help="Use the specified param value when reproducing pipelines.",
metavar="[<filename>:]<param_name>=<param_value>",
)
parser.add_argument(
"--queue",
Expand Down
4 changes: 2 additions & 2 deletions dvc/repo/experiments/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Iterable, Optional

from dvc.repo import locked
from dvc.utils.cli_parse import loads_params
from dvc.utils.cli_parse import loads_param_overrides

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -33,7 +33,7 @@ def run(
return repo.experiments.reproduce_queued(jobs=jobs)

if params:
params = loads_params(params)
params = loads_param_overrides(params)
return repo.experiments.reproduce_one(
targets=targets, params=params, tmp_dir=tmp_dir, **kwargs
)
40 changes: 22 additions & 18 deletions dvc/utils/cli_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,31 +17,35 @@ def parse_params(path_params: Iterable[str]) -> List[Dict[str, List[str]]]:
return [{path: params} for path, params in ret.items()]


def loads_params(path_params: Iterable[str],) -> Dict[str, Dict[str, Any]]:

def loads_param_overrides(
path_params: Iterable[str],
) -> Dict[str, Dict[str, Any]]:
"""Loads the content of params from the cli as Python object."""
from ruamel.yaml import YAMLError

from dvc.dependency.param import ParamsDependency
from dvc.exceptions import InvalidArgumentError

from .serialize import loads_yaml

normalized_params = parse_params(path_params)
ret: Dict[str, Dict[str, Any]] = defaultdict(dict)

for part in normalized_params:
assert part
(item,) = part.items()
path, param_keys = item

for param_str in param_keys:
try:
key, _, value = param_str.partition("=")
# interpret value strings using YAML rules
parsed = loads_yaml(value)
ret[path][key] = parsed
except (ValueError, YAMLError):
raise InvalidArgumentError(
f"Invalid param/value pair '{param_str}'"
)
for path_param in path_params:
param_name, _, param_value = path_param.partition("=")
if not param_value:
raise InvalidArgumentError(
f"Must provide a value for parameter '{param_name}'"
)
path, _, param_name = param_name.partition(":")
if not param_name:
param_name = path
path = ParamsDependency.DEFAULT_PARAMS_FILE

try:
ret[path][param_name] = loads_yaml(param_value)
except (ValueError, YAMLError):
raise InvalidArgumentError(
f"Invalid parameter value for '{param_name}': '{param_value}"
)

return ret
10 changes: 7 additions & 3 deletions tests/func/experiments/test_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def test_failed_exp(tmp_dir, scm, dvc, exp_stage, mocker, caplog):
"changes, expected",
[
[["foo=baz"], "{foo: baz, goo: {bag: 3}, lorem: false}"],
[["foo=baz,goo=bar"], "{foo: baz, goo: bar, lorem: false}"],
[["foo=baz", "goo=bar"], "{foo: baz, goo: bar, lorem: false}"],
[
["goo.bag=4"],
"{foo: [bar: 1, baz: 2], goo: {bag: 4}, lorem: false}",
Expand All @@ -114,7 +114,7 @@ def test_failed_exp(tmp_dir, scm, dvc, exp_stage, mocker, caplog):
"{foo: [bar: 1, baz: 3], goo: {bag: 3}, lorem: false}",
],
[
["foo[1]=- baz\n- goo"],
["foo[1]=[baz, goo]"],
"{foo: [bar: 1, [baz, goo]], goo: {bag: 3}, lorem: false}",
],
[
Expand Down Expand Up @@ -261,7 +261,11 @@ def test_update_py_params(tmp_dir, scm, dvc):

results = dvc.experiments.run(
stage.addressing,
params=["params.py:FLOAT=0.1,Train.seed=2121,Klass.a=222"],
params=[
"params.py:FLOAT=0.1",
"params.py:Train.seed=2121",
"params.py:Klass.a=222",
],
tmp_dir=True,
)
exp_a = first(results)
Expand Down

0 comments on commit 66f24b8

Please sign in to comment.