From 96c67df2c3ec4a2cc8091d58dc9d394dcb5c1703 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Saugat=20Pachhai=20=28=E0=A4=B8=E0=A5=8C=E0=A4=97=E0=A4=BE?= =?UTF-8?q?=E0=A4=A4=29?= Date: Tue, 1 Dec 2020 15:24:55 +0545 Subject: [PATCH] parametrization: convert bool to string as "true"/"false" Previously, Python's `str()` was being used that resulted in boolean transformed into "True"/"False". Fixes #4996 --- dvc/parsing/interpolate.py | 13 ++++++++++++- tests/func/test_stage_resolver.py | 6 +++++- tests/unit/test_context.py | 11 +++++++++++ 3 files changed, 28 insertions(+), 2 deletions(-) diff --git a/dvc/parsing/interpolate.py b/dvc/parsing/interpolate.py index 97138c00af..7bfcf255d3 100644 --- a/dvc/parsing/interpolate.py +++ b/dvc/parsing/interpolate.py @@ -1,5 +1,6 @@ import re import typing +from functools import singledispatch from pyparsing import ( CharsNotIn, @@ -57,6 +58,16 @@ def format_and_raise_parse_error(exc): raise ParseError(_format_exc_msg(exc)) +@singledispatch +def to_str(obj): + return str(obj) + + +@to_str.register(bool) +def _(obj: bool): + return "true" if obj else "false" + + def _format_exc_msg(exc: ParseException): exc.loc += 2 # 2 because we append `${` at the start of expr below @@ -103,7 +114,7 @@ def str_interpolate(template: str, matches: "List[Match]", context: "Context"): raise ParseError( f"Cannot interpolate data of type '{type(value).__name__}'" ) - buf += template[index:start] + str(value) + buf += template[index:start] + to_str(value) index = end buf += template[index:] # regex already backtracks and avoids any `${` starting with diff --git a/tests/func/test_stage_resolver.py b/tests/func/test_stage_resolver.py index dee9d14731..d92c36799b 100644 --- a/tests/func/test_stage_resolver.py +++ b/tests/func/test_stage_resolver.py @@ -337,12 +337,16 @@ def test_set(tmp_dir, dvc, value): } } resolver = DataResolver(dvc, PathInfo(str(tmp_dir)), d) + if isinstance(value, bool): + stringified_value = "true" if value else "false" + else: + stringified_value = str(value) assert_stage_equal( resolver.resolve(), { "stages": { "build": { - "cmd": f"python script.py --thresh {value}", + "cmd": f"python script.py --thresh {stringified_value}", "always_changed": value, } } diff --git a/tests/unit/test_context.py b/tests/unit/test_context.py index 924b8f986c..507a16c1dd 100644 --- a/tests/unit/test_context.py +++ b/tests/unit/test_context.py @@ -411,6 +411,17 @@ def test_resolve_resolves_dict_keys(): } +def test_resolve_resolves_boolean_value(): + d = {"enabled": True, "disabled": False} + context = Context(d) + + assert context.resolve_str("${enabled}") is True + assert context.resolve_str("${disabled}") is False + + assert context.resolve_str("--flag ${enabled}") == "--flag true" + assert context.resolve_str("--flag ${disabled}") == "--flag false" + + def test_merge_from_raises_if_file_not_exist(tmp_dir, dvc): context = Context(foo="bar") with pytest.raises(ParamsFileNotFound):