Skip to content

Commit

Permalink
[TVMC] --disable-pass option added to compile mode
Browse files Browse the repository at this point in the history
Added --disable-pass option to TVMC compile mode to disallow
certain supplied passes in PassContext for the compiler.

Change-Id: Iae1849d7b051ac9288509dc458a58788c865537a
  • Loading branch information
d-smirnov committed Apr 9, 2021
1 parent 461d06e commit d63f6b4
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 3 deletions.
23 changes: 23 additions & 0 deletions python/tvm/driver/tvmc/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,29 @@ def tracker_host_port_from_cli(rpc_tracker_str):
return rpc_hostname, rpc_port


def parse_disabled_pass(input_string):
"""Parse an input string for disabled passes
Parameters
----------
input_string: str
Possibly comma-separated string with the names of disabled passes
Returns
-------
list: a list of disabled passes.
"""
if input_string is not None:
pass_list = input_string.split(",")
nf = [_ for _ in pass_list if tvm.get_global_func("relay._transform." + _, True) is None]
if len(nf) > 0:
raise argparse.ArgumentTypeError(
"Following passes are not registered within tvm: " + str(nf)
)
return pass_list
return None


def parse_shape_string(inputs_string):
"""Parse an input shape dictionary string to a usable dictionary.
Expand Down
22 changes: 19 additions & 3 deletions python/tvm/driver/tvmc/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,12 @@ def add_compile_parser(subparsers):
type=common.parse_shape_string,
default=None,
)
parser.add_argument(
"--disabled-pass",
help="disable specific passes, comma-separated list of pass names",
type=common.parse_disabled_pass,
default=None,
)


def drive_compile(args):
Expand All @@ -121,6 +127,7 @@ def drive_compile(args):
None,
args.tuning_records,
args.desired_layout,
args.disabled_pass,
)

if dumps:
Expand All @@ -138,6 +145,7 @@ def compile_model(
target_host=None,
tuning_records=None,
alter_layout=None,
disabled_pass=None,
):
"""Compile a model from a supported framework into a TVM module.
Expand Down Expand Up @@ -167,6 +175,10 @@ def compile_model(
The layout to convert the graph to. Note, the convert layout
pass doesn't currently guarantee the whole of the graph will
be converted to the chosen layout.
disabled_pass: str, optional
Comma-separated list of passes which needs to be disabled
during compilation
Returns
-------
Expand Down Expand Up @@ -209,16 +221,20 @@ def compile_model(
if use_autoscheduler:
with auto_scheduler.ApplyHistoryBest(tuning_records):
config["relay.backend.use_auto_scheduler"] = True
with tvm.transform.PassContext(opt_level=3, config=config):
with tvm.transform.PassContext(
opt_level=3, config=config, disabled_pass=disabled_pass
):
logger.debug("building relay graph with autoscheduler")
graph_module = relay.build(mod, target=target, params=params)
else:
with autotvm.apply_history_best(tuning_records):
with tvm.transform.PassContext(opt_level=3, config=config):
with tvm.transform.PassContext(
opt_level=3, config=config, disabled_pass=disabled_pass
):
logger.debug("building relay graph with tuning records")
graph_module = relay.build(mod, tvm_target, params=params)
else:
with tvm.transform.PassContext(opt_level=3, config=config):
with tvm.transform.PassContext(opt_level=3, config=config, disabled_pass=disabled_pass):
logger.debug("building relay graph (no tuning records provided)")
graph_module = relay.build(mod, tvm_target, params=params)

Expand Down

0 comments on commit d63f6b4

Please sign in to comment.