diff --git a/python/tvm/driver/tvmc/common.py b/python/tvm/driver/tvmc/common.py index 864c3a9bddb4..e4ff27c6fcd8 100644 --- a/python/tvm/driver/tvmc/common.py +++ b/python/tvm/driver/tvmc/common.py @@ -333,6 +333,29 @@ def tracker_host_port_from_cli(rpc_tracker_str): return rpc_hostname, rpc_port +def parse_disabled_pass(input_string): + """Parse an input string for disabled passes + + Parameters + ---------- + input_string: str + Possibly comma-separated string with the names of disabled passes + + Returns + ------- + list: a list of disabled passes. + """ + if input_string is not None: + pass_list = input_string.split(",") + nf = [_ for _ in pass_list if tvm.get_global_func("relay._transform." + _, True) is None] + if len(nf) > 0: + raise argparse.ArgumentTypeError( + "Following passes are not registered within tvm: " + str(nf) + ) + return pass_list + return None + + def parse_shape_string(inputs_string): """Parse an input shape dictionary string to a usable dictionary. diff --git a/python/tvm/driver/tvmc/compiler.py b/python/tvm/driver/tvmc/compiler.py index b8450750f115..eff262c0efb7 100644 --- a/python/tvm/driver/tvmc/compiler.py +++ b/python/tvm/driver/tvmc/compiler.py @@ -95,6 +95,12 @@ def add_compile_parser(subparsers): type=common.parse_shape_string, default=None, ) + parser.add_argument( + "--disabled-pass", + help="disable specific passes, comma-separated list of pass names", + type=common.parse_disabled_pass, + default=None, + ) def drive_compile(args): @@ -121,6 +127,7 @@ def drive_compile(args): None, args.tuning_records, args.desired_layout, + args.disabled_pass, ) if dumps: @@ -138,6 +145,7 @@ def compile_model( target_host=None, tuning_records=None, alter_layout=None, + disabled_pass=None, ): """Compile a model from a supported framework into a TVM module. @@ -167,6 +175,10 @@ def compile_model( The layout to convert the graph to. Note, the convert layout pass doesn't currently guarantee the whole of the graph will be converted to the chosen layout. + disabled_pass: str, optional + Comma-separated list of passes which needs to be disabled + during compilation + Returns ------- @@ -209,16 +221,20 @@ def compile_model( if use_autoscheduler: with auto_scheduler.ApplyHistoryBest(tuning_records): config["relay.backend.use_auto_scheduler"] = True - with tvm.transform.PassContext(opt_level=3, config=config): + with tvm.transform.PassContext( + opt_level=3, config=config, disabled_pass=disabled_pass + ): logger.debug("building relay graph with autoscheduler") graph_module = relay.build(mod, target=target, params=params) else: with autotvm.apply_history_best(tuning_records): - with tvm.transform.PassContext(opt_level=3, config=config): + with tvm.transform.PassContext( + opt_level=3, config=config, disabled_pass=disabled_pass + ): logger.debug("building relay graph with tuning records") graph_module = relay.build(mod, tvm_target, params=params) else: - with tvm.transform.PassContext(opt_level=3, config=config): + with tvm.transform.PassContext(opt_level=3, config=config, disabled_pass=disabled_pass): logger.debug("building relay graph (no tuning records provided)") graph_module = relay.build(mod, tvm_target, params=params)