diff --git a/python/tvm/driver/tvmc/common.py b/python/tvm/driver/tvmc/common.py index c5cb5f29031f..fbd7bc897683 100644 --- a/python/tvm/driver/tvmc/common.py +++ b/python/tvm/driver/tvmc/common.py @@ -131,9 +131,19 @@ def tokenize_target(target): a list of parsed tokens extracted from the target string """ + # Regex to tokenize the "--target" value. It is split into five parts + # to match with: + # 1. target and option names e.g. llvm, -mattr=, -mcpu= + # 2. option values, all together, without quotes e.g. -mattr=+foo,+opt + # 3. option values, when single quotes are used e.g. -mattr='+foo, +opt' + # 4. option values, when double quotes are used e.g. -mattr="+foo ,+opt" + # 5. commas that separate different targets e.g. "my-target, llvm" target_pattern = ( r"(\-{0,2}[\w\-]+\=?" - r"(?:[\w\+\-\.]+(?:,[\w\+\-\.])*|[\'][\w\+\-,\s\.]+[\']|[\"][\w\+\-,\s\.]+[\"])*|,)" + r"(?:[\w\+\-\.]+(?:,[\w\+\-\.])*" + r"|[\'][\w\+\-,\s\.]+[\']" + r"|[\"][\w\+\-,\s\.]+[\"])*" + r"|,)" ) return re.findall(target_pattern, target) diff --git a/tests/python/driver/tvmc/test_tvmc_common.py b/tests/python/driver/tvmc/test_tvmc_common.py index 23ea4f46b2ff..474649d8b1b3 100644 --- a/tests/python/driver/tvmc/test_tvmc_common.py +++ b/tests/python/driver/tvmc/test_tvmc_common.py @@ -275,22 +275,16 @@ def test_parse_multiple_target_with_opts(): assert "llvm" == targets[1]["name"] -def test_parse_multiple_separators_on_target(): - targets = tvmc.common.parse_target("foo -option1=+v1.0x,+value,+bar") - - assert len(targets) == 1 - assert "+v1.0x,+value,+bar" == targets[0]["opts"]["option1"] +def test_parse_quotes_and_separators_on_options(): + targets_no_quote = tvmc.common.parse_target("foo -option1=+v1.0x,+value,+bar") + targets_single_quote = tvmc.common.parse_target("foo -option1='+v1.0x,+value'") + targets_double_quote = tvmc.common.parse_target('foo -option1="+v1.0x,+value"') + assert len(targets_no_quote) == 1 + assert "+v1.0x,+value,+bar" == targets_no_quote[0]["opts"]["option1"] -def test_parse_single_quoted_multiple_separators_on_target(): - targets = tvmc.common.parse_target("foo -option1='+v1.0x,+value'") - - assert len(targets) == 1 - assert "+v1.0x,+value" == targets[0]["opts"]["option1"] + assert len(targets_single_quote) == 1 + assert "+v1.0x,+value" == targets_single_quote[0]["opts"]["option1"] - -def test_parse_double_quoted_multiple_separators_on_target(): - targets = tvmc.common.parse_target('foo -option1="+v1.0x,+value"') - - assert len(targets) == 1 - assert "+v1.0x,+value" == targets[0]["opts"]["option1"] + assert len(targets_double_quote) == 1 + assert "+v1.0x,+value" == targets_double_quote[0]["opts"]["option1"]