diff --git a/dvc/parsing/interpolate.py b/dvc/parsing/interpolate.py index 97138c00af..7bfcf255d3 100644 --- a/dvc/parsing/interpolate.py +++ b/dvc/parsing/interpolate.py @@ -1,5 +1,6 @@ import re import typing +from functools import singledispatch from pyparsing import ( CharsNotIn, @@ -57,6 +58,16 @@ def format_and_raise_parse_error(exc): raise ParseError(_format_exc_msg(exc)) +@singledispatch +def to_str(obj): + return str(obj) + + +@to_str.register(bool) +def _(obj: bool): + return "true" if obj else "false" + + def _format_exc_msg(exc: ParseException): exc.loc += 2 # 2 because we append `${` at the start of expr below @@ -103,7 +114,7 @@ def str_interpolate(template: str, matches: "List[Match]", context: "Context"): raise ParseError( f"Cannot interpolate data of type '{type(value).__name__}'" ) - buf += template[index:start] + str(value) + buf += template[index:start] + to_str(value) index = end buf += template[index:] # regex already backtracks and avoids any `${` starting with diff --git a/tests/func/test_stage_resolver.py b/tests/func/test_stage_resolver.py index dee9d14731..d92c36799b 100644 --- a/tests/func/test_stage_resolver.py +++ b/tests/func/test_stage_resolver.py @@ -337,12 +337,16 @@ def test_set(tmp_dir, dvc, value): } } resolver = DataResolver(dvc, PathInfo(str(tmp_dir)), d) + if isinstance(value, bool): + stringified_value = "true" if value else "false" + else: + stringified_value = str(value) assert_stage_equal( resolver.resolve(), { "stages": { "build": { - "cmd": f"python script.py --thresh {value}", + "cmd": f"python script.py --thresh {stringified_value}", "always_changed": value, } } diff --git a/tests/unit/test_context.py b/tests/unit/test_context.py index 924b8f986c..507a16c1dd 100644 --- a/tests/unit/test_context.py +++ b/tests/unit/test_context.py @@ -411,6 +411,17 @@ def test_resolve_resolves_dict_keys(): } +def test_resolve_resolves_boolean_value(): + d = {"enabled": True, "disabled": False} + context = Context(d) + + assert context.resolve_str("${enabled}") is True + assert context.resolve_str("${disabled}") is False + + assert context.resolve_str("--flag ${enabled}") == "--flag true" + assert context.resolve_str("--flag ${disabled}") == "--flag false" + + def test_merge_from_raises_if_file_not_exist(tmp_dir, dvc): context = Context(foo="bar") with pytest.raises(ParamsFileNotFound):