From 66f24b870844b6d1171f50baa67c3e546785c5d4 Mon Sep 17 00:00:00 2001 From: Sami Jawhar Date: Tue, 16 Feb 2021 08:19:56 -0600 Subject: [PATCH] exp run: Separate exp param parsing from other commands (#5310) * Separate exp param parsing from other commands * Update exp run params metavar * Updated experiments run param argument name --- dvc/command/experiments.py | 9 ++--- dvc/repo/experiments/run.py | 4 +-- dvc/utils/cli_parse.py | 40 ++++++++++++---------- tests/func/experiments/test_experiments.py | 10 ++++-- 4 files changed, 36 insertions(+), 27 deletions(-) diff --git a/dvc/command/experiments.py b/dvc/command/experiments.py index 8e426fd14a..534c7c1526 100644 --- a/dvc/command/experiments.py +++ b/dvc/command/experiments.py @@ -521,7 +521,7 @@ def run(self): queue=self.args.queue, run_all=self.args.run_all, jobs=self.args.jobs, - params=self.args.params, + params=self.args.set_param, checkpoint_resume=self.args.checkpoint_resume, reset=self.args.reset, tmp_dir=self.args.tmp_dir, @@ -1184,11 +1184,12 @@ def _add_run_common(parser): metavar="", ) parser.add_argument( - "--params", + "-S", + "--set-param", action="append", default=[], - help="Use the specified param values when reproducing pipelines.", - metavar="[:]", + help="Use the specified param value when reproducing pipelines.", + metavar="[:]=", ) parser.add_argument( "--queue", diff --git a/dvc/repo/experiments/run.py b/dvc/repo/experiments/run.py index 3851a0a596..2659ea9659 100644 --- a/dvc/repo/experiments/run.py +++ b/dvc/repo/experiments/run.py @@ -2,7 +2,7 @@ from typing import Iterable, Optional from dvc.repo import locked -from dvc.utils.cli_parse import loads_params +from dvc.utils.cli_parse import loads_param_overrides logger = logging.getLogger(__name__) @@ -33,7 +33,7 @@ def run( return repo.experiments.reproduce_queued(jobs=jobs) if params: - params = loads_params(params) + params = loads_param_overrides(params) return repo.experiments.reproduce_one( targets=targets, params=params, tmp_dir=tmp_dir, **kwargs ) diff --git a/dvc/utils/cli_parse.py b/dvc/utils/cli_parse.py index 3bbd597f9d..c288548468 100644 --- a/dvc/utils/cli_parse.py +++ b/dvc/utils/cli_parse.py @@ -17,31 +17,35 @@ def parse_params(path_params: Iterable[str]) -> List[Dict[str, List[str]]]: return [{path: params} for path, params in ret.items()] -def loads_params(path_params: Iterable[str],) -> Dict[str, Dict[str, Any]]: - +def loads_param_overrides( + path_params: Iterable[str], +) -> Dict[str, Dict[str, Any]]: """Loads the content of params from the cli as Python object.""" from ruamel.yaml import YAMLError + from dvc.dependency.param import ParamsDependency from dvc.exceptions import InvalidArgumentError from .serialize import loads_yaml - normalized_params = parse_params(path_params) ret: Dict[str, Dict[str, Any]] = defaultdict(dict) - for part in normalized_params: - assert part - (item,) = part.items() - path, param_keys = item - - for param_str in param_keys: - try: - key, _, value = param_str.partition("=") - # interpret value strings using YAML rules - parsed = loads_yaml(value) - ret[path][key] = parsed - except (ValueError, YAMLError): - raise InvalidArgumentError( - f"Invalid param/value pair '{param_str}'" - ) + for path_param in path_params: + param_name, _, param_value = path_param.partition("=") + if not param_value: + raise InvalidArgumentError( + f"Must provide a value for parameter '{param_name}'" + ) + path, _, param_name = param_name.partition(":") + if not param_name: + param_name = path + path = ParamsDependency.DEFAULT_PARAMS_FILE + + try: + ret[path][param_name] = loads_yaml(param_value) + except (ValueError, YAMLError): + raise InvalidArgumentError( + f"Invalid parameter value for '{param_name}': '{param_value}" + ) + return ret diff --git a/tests/func/experiments/test_experiments.py b/tests/func/experiments/test_experiments.py index 0f81e7b2b9..3558351c60 100644 --- a/tests/func/experiments/test_experiments.py +++ b/tests/func/experiments/test_experiments.py @@ -103,7 +103,7 @@ def test_failed_exp(tmp_dir, scm, dvc, exp_stage, mocker, caplog): "changes, expected", [ [["foo=baz"], "{foo: baz, goo: {bag: 3}, lorem: false}"], - [["foo=baz,goo=bar"], "{foo: baz, goo: bar, lorem: false}"], + [["foo=baz", "goo=bar"], "{foo: baz, goo: bar, lorem: false}"], [ ["goo.bag=4"], "{foo: [bar: 1, baz: 2], goo: {bag: 4}, lorem: false}", @@ -114,7 +114,7 @@ def test_failed_exp(tmp_dir, scm, dvc, exp_stage, mocker, caplog): "{foo: [bar: 1, baz: 3], goo: {bag: 3}, lorem: false}", ], [ - ["foo[1]=- baz\n- goo"], + ["foo[1]=[baz, goo]"], "{foo: [bar: 1, [baz, goo]], goo: {bag: 3}, lorem: false}", ], [ @@ -261,7 +261,11 @@ def test_update_py_params(tmp_dir, scm, dvc): results = dvc.experiments.run( stage.addressing, - params=["params.py:FLOAT=0.1,Train.seed=2121,Klass.a=222"], + params=[ + "params.py:FLOAT=0.1", + "params.py:Train.seed=2121", + "params.py:Klass.a=222", + ], tmp_dir=True, ) exp_a = first(results)