diff --git a/python/tvm/driver/tvmc/common.py b/python/tvm/driver/tvmc/common.py index 48e18fb6b6ad..033522d0e81a 100644 --- a/python/tvm/driver/tvmc/common.py +++ b/python/tvm/driver/tvmc/common.py @@ -415,3 +415,103 @@ def parse_shape_string(inputs_string): shape_dict[name] = shape return shape_dict + + +def get_pass_config_value(name, value, config_type): + """Get a PassContext configuration value, based on its config data type. + + Parameters + ---------- + name: str + config identifier name. + value: str + value assigned to the config, provided via command line. + config_type: str + data type defined to the config, as string. + + Returns + ------- + parsed_value: bool, int or str + a representation of the input value, converted to the type + specified by config_type. + """ + + if config_type == "IntImm": + # "Bool" configurations in the PassContext are recognized as + # IntImm, so deal with this case here + mapping_values = { + "false": False, + "true": True, + } + + if value.isdigit(): + parsed_value = int(value) + else: + # if not an int, accept only values on the mapping table, case insensitive + parsed_value = mapping_values.get(value.lower(), None) + + if parsed_value is None: + raise TVMCException(f"Invalid value '{value}' for configuration '{name}'. ") + + if config_type == "runtime.String": + parsed_value = value + + return parsed_value + + +def parse_configs(input_configs): + """Parse configuration values set via command line. + + Parameters + ---------- + input_configs: list of str + list of configurations provided via command line. + + Returns + ------- + pass_context_configs: dict + a dict containing key-value configs to be used in the PassContext. + """ + if not input_configs: + return {} + + all_configs = tvm.ir.transform.PassContext.list_configs() + supported_config_types = ("IntImm", "runtime.String") + supported_configs = [ + name for name in all_configs.keys() if all_configs[name]["type"] in supported_config_types + ] + + pass_context_configs = {} + + for config in input_configs: + if not config: + raise TVMCException( + f"Invalid format for configuration '{config}', use =" + ) + + # Each config is expected to be provided as "name=value" + try: + name, value = config.split("=") + name = name.strip() + value = value.strip() + except ValueError: + raise TVMCException( + f"Invalid format for configuration '{config}', use =" + ) + + if name not in all_configs: + raise TVMCException( + f"Configuration '{name}' is not defined in TVM. " + f"These are the existing configurations: {', '.join(all_configs)}" + ) + + if name not in supported_configs: + raise TVMCException( + f"Configuration '{name}' uses a data type not supported by TVMC. " + f"The following configurations are supported: {', '.join(supported_configs)}" + ) + + parsed_value = get_pass_config_value(name, value, all_configs[name]["type"]) + pass_context_configs[name] = parsed_value + + return pass_context_configs diff --git a/python/tvm/driver/tvmc/compiler.py b/python/tvm/driver/tvmc/compiler.py index 071474a31594..2240eaaa3f07 100644 --- a/python/tvm/driver/tvmc/compiler.py +++ b/python/tvm/driver/tvmc/compiler.py @@ -83,6 +83,14 @@ def add_compile_parser(subparsers): help="output format. Use 'so' for shared object or 'mlf' for Model Library Format " "(only for µTVM targets). Defaults to 'so'.", ) + parser.add_argument( + "--pass-config", + action="append", + metavar=("name=value"), + help="configurations to be used at compile time. This option can be provided multiple " + "times, each one to set one configuration value, " + "e.g. '--pass-config relay.backend.use_auto_scheduler=0'.", + ) parser.add_argument( "--target", help="compilation targets as comma separated string, inline JSON or path to a JSON file.", @@ -145,6 +153,7 @@ def drive_compile(args): target_host=None, desired_layout=args.desired_layout, disabled_pass=args.disabled_pass, + pass_context_configs=args.pass_config, ) return 0 @@ -162,6 +171,7 @@ def compile_model( target_host: Optional[str] = None, desired_layout: Optional[str] = None, disabled_pass: Optional[str] = None, + pass_context_configs: Optional[str] = None, ): """Compile a model from a supported framework into a TVM module. @@ -202,6 +212,9 @@ def compile_model( disabled_pass: str, optional Comma-separated list of passes which needs to be disabled during compilation + pass_context_configs: str, optional + String containing a set of configurations to be passed to the + PassContext. Returns @@ -212,7 +225,7 @@ def compile_model( """ mod, params = tvmc_model.mod, tvmc_model.params - config = {} + config = common.parse_configs(pass_context_configs) if desired_layout: mod = common.convert_graph_layout(mod, desired_layout) diff --git a/tests/python/driver/tvmc/test_tvmc_common.py b/tests/python/driver/tvmc/test_tvmc_common.py index 476fac5da1b9..cb6b82a32937 100644 --- a/tests/python/driver/tvmc/test_tvmc_common.py +++ b/tests/python/driver/tvmc/test_tvmc_common.py @@ -19,6 +19,7 @@ import pytest import tvm +from tvm.contrib.target.vitis_ai import vitis_ai_available from tvm.driver import tvmc from tvm.driver.tvmc.common import TVMCException @@ -306,3 +307,53 @@ def test_parse_quotes_and_separators_on_options(): assert len(targets_double_quote) == 1 assert "+v1.0x,+value" == targets_double_quote[0]["opts"]["option1"] + + +def test_config_invalid_format(): + with pytest.raises(TVMCException): + _ = tvmc.common.parse_configs(["relay.backend.use_auto_scheduler.missing.value"]) + + +def test_config_missing_from_tvm(): + with pytest.raises(TVMCException): + _ = tvmc.common.parse_configs(["relay.backend.use_auto_scheduler.missing.value=1234"]) + + +def test_config_unsupported_tvmc_config(): + with pytest.raises(TVMCException): + _ = tvmc.common.parse_configs(["tir.LoopPartition=value"]) + + +def test_config_empty(): + with pytest.raises(TVMCException): + _ = tvmc.common.parse_configs([""]) + + +def test_config_valid_config_bool(): + configs = tvmc.common.parse_configs(["relay.backend.use_auto_scheduler=true"]) + + assert len(configs) == 1 + assert "relay.backend.use_auto_scheduler" in configs.keys() + assert configs["relay.backend.use_auto_scheduler"] == True + + +@pytest.mark.skipif( + not vitis_ai_available(), + reason="--target vitis-ai is not available. TVM built with 'USE_VITIS_AI OFF'", +) +def test_config_valid_multiple_configs(): + configs = tvmc.common.parse_configs( + [ + "relay.backend.use_auto_scheduler=false", + "tir.detect_global_barrier=10", + "relay.ext.vitis_ai.options.build_dir=mystring", + ] + ) + + assert len(configs) == 3 + assert "relay.backend.use_auto_scheduler" in configs.keys() + assert configs["relay.backend.use_auto_scheduler"] == False + assert "tir.detect_global_barrier" in configs.keys() + assert configs["tir.detect_global_barrier"] == 10 + assert "relay.ext.vitis_ai.options.build_dir" in configs.keys() + assert configs["relay.ext.vitis_ai.options.build_dir"] == "mystring"