From 670c5231d10672256532a20765e8d5e8abf6f31e Mon Sep 17 00:00:00 2001
From: Siva Rama Krishna Reddy B
Date: Thu, 23 Feb 2023 00:04:51 +0530
Subject: [PATCH 1/5] [TVMC][TRANSFORMS] ToMixedPrecision transform support
 with custom options enabled

Adds the new command line options --mixed-precision, --mixed-precision-ops,
--mixed-precision-input, --mixed-precision-output and --desired-layout-ops.

This PR also enhances the Python interface by replacing alter_layout with
transform_args. transform_args is a dict with all transform-related options,
including the existing desired_layout (alter_layout) option.
---
 python/tvm/driver/tvmc/autotuner.py        |  47 ++---
 python/tvm/driver/tvmc/compiler.py         |  23 +--
 python/tvm/driver/tvmc/transform.py        | 196 ++++++++++++++++++++-
 tests/python/driver/tvmc/test_compiler.py  |  12 +-
 tests/python/driver/tvmc/test_transform.py |  56 +++++-
 5 files changed, 276 insertions(+), 58 deletions(-)

diff --git a/python/tvm/driver/tvmc/autotuner.py b/python/tvm/driver/tvmc/autotuner.py
index b7766efb4796..26836cd20bbf 100644
--- a/python/tvm/driver/tvmc/autotuner.py
+++ b/python/tvm/driver/tvmc/autotuner.py
@@ -39,7 +39,7 @@
 from .model import TVMCModel
 from .target import target_from_cli, generate_target_args, reconstruct_target_args
 from .shape_parser import parse_shape_string
-from .transform import convert_graph_layout
+from .transform import generate_transform_args, parse_graph_transform_args, apply_graph_transforms


 # pylint: disable=invalid-name
@@ -127,12 +127,7 @@ def add_tune_parser(subparsers, _, json_params):
         metavar="PATH",
         help="path to an auto-tuning log file by AutoTVM.",
     )
-    parser.add_argument(
-        "--desired-layout",
-        choices=["NCHW", "NHWC"],
-        default=None,
-        help="change the data layout of the whole graph",
-    )
+    generate_transform_args(parser)
     parser.add_argument(
         "--enable-autoscheduler",
         help="enable tuning the graph through the AutoScheduler tuner",
@@ -269,6 +264,8 @@ def drive_tune(args):
         rpc_hostname = None
         rpc_port = None

+    transform_args = parse_graph_transform_args(args)
+
     tune_model(
         tvmc_model,
         args.target,
@@ -283,7 +280,7 @@ def drive_tune(args):
        tuner=args.tuner,
        min_repeat_ms=args.min_repeat_ms,
        early_stopping=args.early_stopping,
-       desired_layout=args.desired_layout,
+       transform_args=transform_args,
        timeout=args.timeout,
        repeat=args.repeat,
        number=args.number,
@@ -309,7 +306,7 @@ def tune_model(
     tuner: str = "xgb",
     min_repeat_ms: Optional[int] = None,
     early_stopping: Optional[int] = None,
-    desired_layout: Optional[str] = None,
+    transform_args: Optional[Dict[str, Any]] = None,
     timeout: int = 10,
     repeat: int = 1,
     number: int = 10,
@@ -354,10 +351,8 @@ def tune_model(
         Minimum time to run each trial. Defaults to 0 on x86 and 1000 on other targets.
     early_stopping : int, optional
         When specified, stop tuning after this number of trials if results aren't improving.
-    desired_layout : str, optional
-        Can be one of "NCHW" or "NHWC". When specified, compatible operations in the graph
-        will have their layout set to this format. Tasks will then be tuned using this
-        specified layout.
+    transform_args: dict, optional
+        Graph transformation arguments that are applied to the relay module.
     timeout : int, optional,
         If a kernel trial lasts longer than this duration in seconds, it will be
         considered a failure.
@@ -453,7 +448,7 @@ def tune_model( mod=mod, params=params, target=target, - alter_layout=desired_layout, + transform_args=transform_args, hardware_params=hardware_params, include_simple_tasks=include_simple_tasks, ) @@ -475,7 +470,7 @@ def tune_model( mod=mod, params=params, target=target, - alter_layout=desired_layout, + transform_args=transform_args, ) # In autotvm, trials is specified per task. We can convert the per-model input @@ -504,7 +499,7 @@ def autotvm_get_tuning_tasks( params: Dict[str, tvm.nd.NDArray], target: str, target_host: Optional[str] = None, - alter_layout: Optional[str] = None, + transform_args: Optional[Dict[str, Any]] = None, ): """Get the autotvm tuning tasks for a given relay module. @@ -518,10 +513,8 @@ def autotvm_get_tuning_tasks( The compilation target. target_host : str, optional The compilation target for the host. - alter_layout : str, optional - The layout to convert the graph to. Note, the convert layout - pass doesn't currently guarantee the whole of the graph will - be converted to the chosen layout. + transform_args: dict, optional + Graph transformation arguments that are applied to the relay module. Returns ------- @@ -530,8 +523,7 @@ def autotvm_get_tuning_tasks( """ target, target_host = Target.canon_target_and_host(target, target_host) - if alter_layout: - mod = convert_graph_layout(mod, alter_layout) + mod = apply_graph_transforms(mod, transform_args) tasks = autotvm.task.extract_from_program( mod["main"], @@ -547,7 +539,7 @@ def autoscheduler_get_tuning_tasks( params: Dict[str, tvm.nd.NDArray], target: str, target_host: Optional[str] = None, - alter_layout: Optional[str] = None, + transform_args: Optional[Dict[str, Any]] = None, hardware_params: Optional[HardwareParams] = None, include_simple_tasks: bool = False, ): @@ -563,10 +555,8 @@ def autoscheduler_get_tuning_tasks( The compilation target. target_host : str, optional The compilation target for the host. - alter_layout : str, optional - The layout to convert the graph to. Note, the convert layout - pass doesn't currently guarantee the whole of the graph will - be converted to the chosen layout. + transform_args: dict, optional + Graph transformation arguments that are applied to the relay module. hardware_params : Optional[HardwareParams] Hardware parameters used for the search tasks @@ -579,8 +569,7 @@ def autoscheduler_get_tuning_tasks( """ target, target_host = Target.canon_target_and_host(target, target_host) - if alter_layout: - mod = convert_graph_layout(mod, alter_layout) + mod = apply_graph_transforms(mod, transform_args) # Extract the tasks tasks, task_weights = auto_scheduler.extract_tasks( diff --git a/python/tvm/driver/tvmc/compiler.py b/python/tvm/driver/tvmc/compiler.py index eec80820cdb1..f960687d286f 100644 --- a/python/tvm/driver/tvmc/compiler.py +++ b/python/tvm/driver/tvmc/compiler.py @@ -37,7 +37,7 @@ from .target import target_from_cli, generate_target_args, reconstruct_target_args from .pass_config import parse_configs from .pass_list import parse_pass_list_str -from .transform import convert_graph_layout +from .transform import generate_transform_args, parse_graph_transform_args, apply_graph_transforms from .shape_parser import parse_shape_string from .workspace_pools import generate_workspace_pools_args, workspace_pools_recombobulate @@ -61,12 +61,7 @@ def add_compile_parser(subparsers, _, json_params): default="", help="the cross compiler options to generate target libraries, e.g. 
'-mfpu=neon-vfpv4'.", ) - parser.add_argument( - "--desired-layout", - choices=["NCHW", "NHWC"], - default=None, - help="change the data layout of the whole graph.", - ) + generate_transform_args(parser) parser.add_argument( "--dump-code", metavar="FORMAT", @@ -177,6 +172,7 @@ def drive_compile(args): additional_targets = reconstruct_target_args(args) workspace_pools_target, extra_targets = target_from_cli(args.target, additional_targets) + transform_args = parse_graph_transform_args(args) compile_model( tvmc_model, @@ -191,7 +187,7 @@ def drive_compile(args): output_format=args.output_format, dump_code=dump_code, target_host=None, - desired_layout=args.desired_layout, + transform_args=transform_args, disabled_pass=args.disabled_pass, pass_context_configs=args.pass_config, mod_name=args.module_name, @@ -217,7 +213,7 @@ def compile_model( output_format: str = "so", dump_code: Optional[List[str]] = None, target_host: Optional[str] = None, - desired_layout: Optional[str] = None, + transform_args: Optional[Dict[str, Any]] = None, disabled_pass: Optional[str] = None, pass_context_configs: Optional[List[str]] = None, additional_target_options: Optional[Dict[str, Dict[str, Any]]] = None, @@ -260,10 +256,8 @@ def compile_model( target_host : str, optional The target of the host machine if host-side code needs to be generated. - desired_layout: str, optional - The layout to convert the graph to. Note, the convert layout - pass doesn't currently guarantee the whole of the graph will - be converted to the chosen layout. + transform_args: dict, optional + Graph transformation arguments that are applied to the relay module. disabled_pass: str, optional Comma-separated list of passes which needs to be disabled during compilation @@ -310,8 +304,7 @@ def compile_model( disabled_pass=disabled_pass, instruments=instruments, ): - if desired_layout: - mod = convert_graph_layout(mod, desired_layout) + mod = apply_graph_transforms(mod, transform_args) for partition_function, opts in zip(partition_functions, partition_opts): mod = partition_function(mod, params, mod_name=mod_name, **opts) diff --git a/python/tvm/driver/tvmc/transform.py b/python/tvm/driver/tvmc/transform.py index 8527c48b6b04..80799695d99e 100644 --- a/python/tvm/driver/tvmc/transform.py +++ b/python/tvm/driver/tvmc/transform.py @@ -13,6 +13,7 @@ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language +# pylint: disable=unused-argument """ TVMC Graph Transforms """ @@ -20,8 +21,89 @@ from tvm import relay, transform from tvm.driver.tvmc import TVMCException +# ToMixedPrecision +ACC_DTYPE = "float32" -def convert_graph_layout(mod, desired_layout): + +def mixed_precision_rule(call_node: "relay.Call", mixed_precision_type: str): + global ACC_DTYPE + return [ + relay.transform.mixed_precision.MIXED_PRECISION_ALWAYS, + ACC_DTYPE, + mixed_precision_type, + ] + + +class MixedPrecision(object): + """Temporarily changes attr of ops to enable required precision.""" + + def __init__(self, ops): + """Saves the required info for RAII pattern usage. 
+
+        Parameters
+        ----------
+        ops : list
+            list of operators
+        """
+        self.older_attr = {}
+        self.ops = ops
+        self.attr_key = "FTVMMixedPrecisionConversionType"
+
+    def __enter__(self):
+        for op_name in self.ops:
+            op = relay.op.get(op_name)
+            self.older_attr[op_name] = op.get_attr(self.attr_key)
+            op.reset_attr(self.attr_key)
+            op.set_attr(self.attr_key, mixed_precision_rule)
+        return self
+
+    def __exit__(self, ptype, value, trace):
+        for op_name in self.ops:
+            op = relay.op.get(op_name)
+            op.reset_attr(self.attr_key)
+            if self.older_attr[op_name]:
+                op.set_attr(self.attr_key, self.older_attr[op_name])
+
+
+def convert_to_mixed_precision(
+    mod, ops="nn.conv2d,nn.dense", input_type="float16", out_type="float16"
+):
+    """Converts the operator datatypes
+
+    Parameters
+    ----------
+    mod : tvm.IRModule
+        The relay module to convert.
+    ops : str
+        List of operators to be precision converted.
+    input_type: str
+        Input precision to be used.
+    output_type: str
+        Output or accumulation precision to be used.
+
+    Returns
+    -------
+    mod : tvm.IRModule
+        The converted module.
+    """
+
+    global ACC_DTYPE
+    ACC_DTYPE = out_type
+
+    with MixedPrecision(ops.split(",")):
+        seq = transform.Sequential(
+            [relay.transform.InferType(), relay.transform.ToMixedPrecision()]
+        )
+        with transform.PassContext(
+            config={"relay.ToMixedPrecision.keep_orig_output_dtype": True}, opt_level=3
+        ):
+            try:
+                return seq(mod)
+            except Exception as err:
+                raise TVMCException("Error converting mixed precision : {0}".format(str(err)))
+
+
+def convert_graph_layout(mod, desired_layout, ops="nn.conv2d,nn.conv2d_transpose,qnn.conv2d"):
     """Alter the layout of the input graph.

     Parameters
@@ -30,6 +112,8 @@ def convert_graph_layout(mod, desired_layout):
         The relay module to convert.
     desired_layout : str
         The layout to convert to.
+    ops : str
+        List of operators to be layout converted.

     Returns
     -------
@@ -37,13 +121,7 @@ def convert_graph_layout(mod, desired_layout):
         The converted module.
     """

-    # Assume for the time being that graphs only have
-    # conv2d as heavily-sensitive operators.
-    desired_layouts = {
-        "nn.conv2d": [desired_layout, "default"],
-        "nn.conv2d_transpose": [desired_layout, "default"],
-        "qnn.conv2d": [desired_layout, "default"],
-    }
+    desired_layouts = {op: [desired_layout, "default"] for op in ops.split(",")}

     # Convert the layout of the graph where possible.
     seq = transform.Sequential(
@@ -58,3 +136,105 @@
         return seq(mod)
     except Exception as err:
         raise TVMCException("Error converting layout to {0}: {1}".format(desired_layout, str(err)))
+
+
+def apply_graph_transforms(mod, args):
+    """Apply graph transforms to the input graph.
+
+    Parameters
+    ----------
+    mod : tvm.IRModule
+        The relay module to convert.
+    args : dict
+        The transform arguments.
+
+    Returns
+    -------
+    mod : tvm.IRModule
+        The converted module.
+    """
+    if not args:
+        return mod
+
+    # AlterLayout
+    if args.get("desired_layout", False):
+        mod = convert_graph_layout(mod, args["desired_layout"])
+
+    # ToMixedPrecision
+    if args.get("mixed_precision", False):
+        mod = convert_to_mixed_precision(
+            mod,
+            args.get("mixed_precision_ops", "nn.conv2d,nn.dense"),
+            args.get("mixed_precision_input", "float16"),
+            args.get("mixed_precision_output", "float16"),
+        )
+    return mod
+
+
+def parse_graph_transform_args(args):
+    """Parse incoming options for graph transform arguments.
+
+    Parameters
+    ----------
+    args: argparse.Namespace
+        Arguments from command line parser.
+ + Returns + ------- + transform_args : dict + Graph transform arguments + """ + + args_dict = vars(args) + + transform_args = [ + "desired_layout", + "desired_layout_ops", + "mixed_precision", + "mixed_precision_ops", + "mixed_precision_input", + "mixed_precision_output", + ] + transform_args = {key: args_dict.get(key, None) for key in transform_args} + return transform_args + + +def generate_transform_args(parser): + """Add graph transform related args""" + + # AlterLayout + parser.add_argument( + "--desired-layout", + choices=["NCHW", "NHWC"], + default=None, + help="Change the data layout of the whole graph.", + ) + parser.add_argument( + "--desired-layout-ops", + default="nn.conv2d,nn.conv2d_transpose,qnn.conv2d", + help="List of operators to be layout converted.", + ) + + # ToMixedPrecision + parser.add_argument( + "--mixed-precision", + help="Enable mixed precision conversion", + action="store_true", + ) + parser.add_argument( + "--mixed-precision-ops", + default="nn.conv2d,nn.dense", + help="List of operators to be converted to mixed precision", + ) + parser.add_argument( + "--mixed-precision-input", + choices=["float16", "float32"], + default="float16", + help="Input precision type", + ) + parser.add_argument( + "--mixed-precision-output", + choices=["float16", "float32"], + default="float16", + help="Output or accumulator precision type", + ) diff --git a/tests/python/driver/tvmc/test_compiler.py b/tests/python/driver/tvmc/test_compiler.py index 3a3f297729fd..c086860be62f 100644 --- a/tests/python/driver/tvmc/test_compiler.py +++ b/tests/python/driver/tvmc/test_compiler.py @@ -69,7 +69,11 @@ def verify_compile_tflite_module(model, shape_dict=None, use_vm=False): pytest.importorskip("tflite") tvmc_model = tvmc.load(model, shape_dict=shape_dict) tvmc_package = tvmc.compile( - tvmc_model, target="llvm", dump_code="ll", desired_layout="NCHW", use_vm=use_vm + tvmc_model, + target="llvm", + dump_code="ll", + transform_args={"desired_layout": "NCHW"}, + use_vm=use_vm, ) dumps_path = tvmc_package.package_path + ".ll" verify_tvmc_package(tvmc_package, dumps_path, use_vm=use_vm) @@ -286,7 +290,9 @@ def test_cross_compile_options_aarch64_onnx_module(onnx_resnet50): def verify_compile_paddle_module(model, shape_dict=None): pytest.importorskip("paddle") tvmc_model = tvmc.load(model, "paddle", shape_dict=shape_dict) - tvmc_package = tvmc.compile(tvmc_model, target="llvm", dump_code="ll", desired_layout="NCHW") + tvmc_package = tvmc.compile( + tvmc_model, target="llvm", dump_code="ll", transform_args={"desired_layout": "NCHW"} + ) dumps_path = tvmc_package.package_path + ".ll" # check for output types @@ -368,7 +374,7 @@ def test_compile_opencl(tflite_mobilenet_v1_0_25_128): tvmc_package = tvmc.compile( tvmc_model, target="opencl -host=llvm", - desired_layout="NCHW", + transform_args={"desired_layout": "NCHW"}, dump_code="asm", ) dumps_path = tvmc_package.package_path + ".asm" diff --git a/tests/python/driver/tvmc/test_transform.py b/tests/python/driver/tvmc/test_transform.py index 98bd3b5f98a3..fb7d8b633566 100644 --- a/tests/python/driver/tvmc/test_transform.py +++ b/tests/python/driver/tvmc/test_transform.py @@ -20,7 +20,7 @@ import tvm from tvm import relay from tvm.ir.instrument import pass_instrument -from tvm.driver.tvmc.transform import convert_graph_layout +from tvm.driver.tvmc.transform import apply_graph_transforms def test_layout_transform_fold_constant(relay_conv2d): @@ -39,7 +39,7 @@ def run_after_pass(self, _, info): pass_names = CollectPassNames() with 
tvm.transform.PassContext(opt_level=3, instruments=[pass_names]):
-        convert_graph_layout(relay_conv2d, desired_layout)
+        apply_graph_transforms(relay_conv2d, {"desired_layout": desired_layout})

     names = pass_names.names
     assert "ConvertLayout" in names
@@ -59,7 +59,7 @@ def test_layout_transform_convert_layout_pass_args(relay_conv2d, monkeypatch):
     monkeypatch.setattr(relay.transform, "ConvertLayout", mock_convert_layout)

     with tvm.transform.PassContext(opt_level=3):
-        convert_graph_layout(relay_conv2d, desired_layout)
+        apply_graph_transforms(relay_conv2d, {"desired_layout": desired_layout})

     mock_convert_layout.assert_called_once_with(
         {
@@ -70,5 +70,55 @@
     )

+
+def test_layout_transform_to_mixed_precision_pass_args(relay_conv2d, monkeypatch):
+    """
+    Check the mixed precision arguments which are expected when
+    mixed precision arguments are provided.
+    """
+    mock_mixed_precision = MagicMock()
+    mock_mixed_precision.return_value = tvm.driver.tvmc.transform.MixedPrecision([])
+    monkeypatch.setattr(tvm.driver.tvmc.transform, "MixedPrecision", mock_mixed_precision)
+
+    with tvm.transform.PassContext(opt_level=3):
+        apply_graph_transforms(
+            relay_conv2d,
+            {
+                "mixed_precision": True,
+            },
+        )
+        mock_mixed_precision.assert_called_with(["nn.conv2d", "nn.dense"])
+
+        apply_graph_transforms(
+            relay_conv2d,
+            {
+                "mixed_precision": True,
+                "mixed_precision_ops": "nn.conv2d",
+            },
+        )
+        mock_mixed_precision.assert_called_with(["nn.conv2d"])
+
+        apply_graph_transforms(
+            relay_conv2d,
+            {
+                "mixed_precision": True,
+                "mixed_precision_ops": "nn.conv2d,nn.dense",
+                "mixed_precision_input": "float16",
+                "mixed_precision_output": "float16",
+            },
+        )
+        mock_mixed_precision.assert_called_with(["nn.conv2d", "nn.dense"])
+
+        apply_graph_transforms(
+            relay_conv2d,
+            {
+                "mixed_precision": True,
+                "mixed_precision_ops": "nn.conv2d,nn.dense",
+                "mixed_precision_input": "float16",
+                "mixed_precision_output": "float32",
+            },
+        )
+        mock_mixed_precision.assert_called_with(["nn.conv2d", "nn.dense"])
+
+
 if __name__ == "__main__":
     tvm.testing.main()

From e55faaf4c21f603b581c585244d921615e144636 Mon Sep 17 00:00:00 2001
From: Siva Rama Krishna Reddy B
Date: Thu, 2 Mar 2023 11:57:18 +0530
Subject: [PATCH 2/5] * review comments.

---
 python/tvm/driver/tvmc/autotuner.py        | 28 ++++++--
 python/tvm/driver/tvmc/compiler.py         | 27 ++++++--
 python/tvm/driver/tvmc/transform.py        | 81 ++++++++++++----------
 tests/python/driver/tvmc/test_compiler.py  |  8 +--
 tests/python/driver/tvmc/test_transform.py | 35 +++-------
 5 files changed, 102 insertions(+), 77 deletions(-)

diff --git a/python/tvm/driver/tvmc/autotuner.py b/python/tvm/driver/tvmc/autotuner.py
index 26836cd20bbf..bed829ef6b29 100644
--- a/python/tvm/driver/tvmc/autotuner.py
+++ b/python/tvm/driver/tvmc/autotuner.py
@@ -14,6 +14,7 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
+# pylint: disable=unused-argument
 """
 Provides support to auto-tuning networks using AutoTVM.
""" @@ -280,7 +281,6 @@ def drive_tune(args): tuner=args.tuner, min_repeat_ms=args.min_repeat_ms, early_stopping=args.early_stopping, - transform_args=transform_args, timeout=args.timeout, repeat=args.repeat, number=args.number, @@ -289,6 +289,7 @@ def drive_tune(args): include_simple_tasks=args.include_simple_tasks, log_estimated_latency=args.log_estimated_latency, additional_target_options=reconstruct_target_args(args), + **transform_args, ) @@ -306,7 +307,6 @@ def tune_model( tuner: str = "xgb", min_repeat_ms: Optional[int] = None, early_stopping: Optional[int] = None, - transform_args: Optional[Dict[str, Any]] = None, timeout: int = 10, repeat: int = 1, number: int = 10, @@ -315,6 +315,12 @@ def tune_model( include_simple_tasks: bool = False, log_estimated_latency: bool = False, additional_target_options: Optional[Dict[str, Dict[str, Any]]] = None, + desired_layout: Optional[str] = None, + desired_layout_ops: Optional[List[str]] = None, + mixed_precision: bool = False, + mixed_precision_ops: Optional[List[str]] = None, + mixed_precision_calculation_type: Optional[str] = None, + mixed_precision_acc_type: Optional[str] = None, ): """Use tuning to automatically optimize the functions in a model. @@ -351,8 +357,6 @@ def tune_model( Minimum time to run each trial. Defaults to 0 on x86 and 1000 on other targets. early_stopping : int, optional When specified, stop tuning after this number of trials if results aren't improving. - transform_args: dict, optional - Graph transformation arguments that are applied to the relay module. timeout : int, optional, If a kernel trial lasts longer than this duration in seconds, it will be considered a failure. @@ -371,12 +375,28 @@ def tune_model( If using the autoscheduler, write the estimated latency at each step of tuning to file. additional_target_options: Optional[Dict[str, Dict[str, Any]]] Additional target options in a dictionary to combine with initial Target arguments + desired_layout: str, optional + Can be one of "NCHW" or "NHWC". When specified, compatible operations in the graph + will have their layout set to this format. Tasks will then be tuned using this + specified layout. + desired_layout_ops: list[str], optional + The list of operators to be transformed with desired layout. + mixed_precision: bool + To enable mixed precision transformation. + mixed_precision_ops: list[str], optional + The list of operators to be converted to mixed precision. + mixed_precision_calculation_type: str + The calculation dtype to be used while mixed precision. + mixed_precision_acc_type: str + The accumulation data type to be used while mixed precision. + Returns ------- tuning_records : str The path to the produced tuning log file. """ + transform_args = parse_graph_transform_args(locals()) target, extra_targets = target_from_cli(target, additional_target_options) target, target_host = Target.canon_target_and_host(target, target_host) # TODO(jwfromm) Remove this deepcopy once AlterOpLayout bug that mutates source diff --git a/python/tvm/driver/tvmc/compiler.py b/python/tvm/driver/tvmc/compiler.py index f960687d286f..2d392226c505 100644 --- a/python/tvm/driver/tvmc/compiler.py +++ b/python/tvm/driver/tvmc/compiler.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +# pylint: disable=unused-argument """ Provides support to compile networks both AOT and JIT. 
""" @@ -187,7 +188,6 @@ def drive_compile(args): output_format=args.output_format, dump_code=dump_code, target_host=None, - transform_args=transform_args, disabled_pass=args.disabled_pass, pass_context_configs=args.pass_config, mod_name=args.module_name, @@ -195,6 +195,7 @@ def drive_compile(args): workspace_pools=( workspace_pools_recombobulate(args, [workspace_pools_target], extra_targets) ), + **transform_args, ) return 0 @@ -213,7 +214,6 @@ def compile_model( output_format: str = "so", dump_code: Optional[List[str]] = None, target_host: Optional[str] = None, - transform_args: Optional[Dict[str, Any]] = None, disabled_pass: Optional[str] = None, pass_context_configs: Optional[List[str]] = None, additional_target_options: Optional[Dict[str, Dict[str, Any]]] = None, @@ -221,6 +221,12 @@ def compile_model( mod_name: Optional[str] = "default", workspace_pools: Optional[WorkspaceMemoryPools] = None, instruments: Optional[Sequence[PassInstrument]] = None, + desired_layout: Optional[str] = None, + desired_layout_ops: Optional[List[str]] = None, + mixed_precision: bool = False, + mixed_precision_ops: Optional[List[str]] = None, + mixed_precision_calculation_type: Optional[str] = None, + mixed_precision_acc_type: Optional[str] = None, ): """Compile a model from a supported framework into a TVM module. @@ -256,8 +262,6 @@ def compile_model( target_host : str, optional The target of the host machine if host-side code needs to be generated. - transform_args: dict, optional - Graph transformation arguments that are applied to the relay module. disabled_pass: str, optional Comma-separated list of passes which needs to be disabled during compilation @@ -275,6 +279,20 @@ def compile_model( compilation. instruments: Optional[Sequence[PassInstrument]] The list of pass instrument implementations. + desired_layout: str, optional + Can be one of "NCHW" or "NHWC". When specified, compatible operations in the graph + will have their layout set to this format. Tasks will then be tuned using this + specified layout. + desired_layout_ops: list[str], optional + The list of operators to be transformed with desired layout. + mixed_precision: bool + To enable mixed precision transformation. + mixed_precision_ops: list[str], optional + The list of operators to be converted to mixed precision. + mixed_precision_calculation_type: str + The calculation dtype to be used while mixed precision. + mixed_precision_acc_type: str + The accumulation data type to be used while mixed precision. 
Returns ------- @@ -304,6 +322,7 @@ def compile_model( disabled_pass=disabled_pass, instruments=instruments, ): + transform_args = parse_graph_transform_args(locals()) mod = apply_graph_transforms(mod, transform_args) for partition_function, opts in zip(partition_functions, partition_opts): diff --git a/python/tvm/driver/tvmc/transform.py b/python/tvm/driver/tvmc/transform.py index 80799695d99e..fab6c5a146d0 100644 --- a/python/tvm/driver/tvmc/transform.py +++ b/python/tvm/driver/tvmc/transform.py @@ -21,32 +21,34 @@ from tvm import relay, transform from tvm.driver.tvmc import TVMCException -# ToMixedPrecision -ACC_DTYPE = "float32" +def generate_mixed_precision_rule(acc_dtype): + def _mixed_precision_rule(call_node: "relay.Call", mixed_precision_type: str): + return [ + relay.transform.mixed_precision.MIXED_PRECISION_ALWAYS, + acc_dtype, + mixed_precision_type, + ] -def mixed_precision_rule(call_node: "relay.Call", mixed_precision_type: str): - global ACC_DTYPE - return [ - relay.transform.mixed_precision.MIXED_PRECISION_ALWAYS, - ACC_DTYPE, - mixed_precision_type, - ] + return _mixed_precision_rule class MixedPrecision(object): """Temporarily changes attr of ops to enable required precision.""" - def __init__(self, ops): + def __init__(self, ops, acc_type): """Saves the required info for RAII pattern usage. Parameters ---------- ops : list list of operators + acc_type: str + Output or accumulation precision to be used. """ self.older_attr = {} self.ops = ops + self.acc_type = acc_type self.attr_key = "FTVMMixedPrecisionConversionType" def __enter__(self): @@ -54,7 +56,7 @@ def __enter__(self): op = relay.op.get(op_name) self.older_attr[op_name] = op.get_attr(self.attr_key) op.reset_attr(self.attr_key) - op.set_attr(self.attr_key, mixed_precision_rule) + op.set_attr(self.attr_key, generate_mixed_precision_rule(self.acc_type)) return self def __exit__(self, ptype, value, trace): @@ -65,20 +67,18 @@ def __exit__(self, ptype, value, trace): op.set_attr(self.attr_key, self.older_attr[op_name]) -def convert_to_mixed_precision( - mod, ops="nn.conv2d,nn.dense", input_type="float16", out_type="float16" -): +def convert_to_mixed_precision(mod, ops=None, calculation_type="float16", acc_type="float16"): """Converts the operator datatypes Parameters ---------- mod : tvm.IRModule The relay module to convert. - ops : str + ops : list List of operators to be precision converted. - input_type: str + calculation_type: str Input precision to be used. - output_type: str + acc_type: str Output or accumulation precision to be used. Returns @@ -87,10 +87,10 @@ def convert_to_mixed_precision( The converted module. """ - global ACC_DTYPE - ACC_DTYPE = out_type + if ops is None: + ops = ["nn.conv2d", "nn.dense"] - with MixedPrecision(ops.split(",")): + with MixedPrecision(ops, acc_type): seq = transform.Sequential( [relay.transform.InferType(), relay.transform.ToMixedPrecision()] ) @@ -103,7 +103,7 @@ def convert_to_mixed_precision( raise TVMCException("Error converting mixed precision : {0}".format(str(err))) -def convert_graph_layout(mod, desired_layout, ops="nn.conv2d,nn.conv2d_transpose,qnn.conv2d"): +def convert_graph_layout(mod, desired_layout, ops=None): """Alter the layout of the input graph. Parameters @@ -112,7 +112,7 @@ def convert_graph_layout(mod, desired_layout, ops="nn.conv2d,nn.conv2d_transpose The relay module to convert. desired_layout : str The layout to convert to. - ops : str + ops : list List of operators to be layout converted. 
Returns @@ -120,8 +120,10 @@ def convert_graph_layout(mod, desired_layout, ops="nn.conv2d,nn.conv2d_transpose mod : tvm.IRModule The converted module. """ + if ops is None: + ops = ["nn.conv2d", "nn.conv2d_transpose", "qnn.conv2d"] - desired_layouts = {op: [desired_layout, "default"] for op in ops.split(",")} + desired_layouts = {op: [desired_layout, "default"] for op in ops} # Convert the layout of the graph where possible. seq = transform.Sequential( @@ -164,9 +166,9 @@ def apply_graph_transforms(mod, args): if args.get("mixed_precision", False): mod = convert_to_mixed_precision( mod, - args.get("mixed_precision_ops", "nn.conv2d,nn.dense"), - args.get("mixed_precision_input", "float16"), - args.get("mixed_precision_output", "float16"), + args.get("mixed_precision_ops"), + args.get("mixed_precision_calculation_type"), + args.get("mixed_precision_acc_type"), ) return mod @@ -176,8 +178,8 @@ def parse_graph_transform_args(args): Parameters ---------- - args: argparse.Namespace - Arguments from command line parser. + args: argparse.Namespace or dict + Arguments. Returns ------- @@ -185,17 +187,18 @@ def parse_graph_transform_args(args): Graph transform arguments """ - args_dict = vars(args) + if not isinstance(args, dict): + args = vars(args) transform_args = [ "desired_layout", "desired_layout_ops", "mixed_precision", "mixed_precision_ops", - "mixed_precision_input", - "mixed_precision_output", + "mixed_precision_calculation_type", + "mixed_precision_acc_type", ] - transform_args = {key: args_dict.get(key, None) for key in transform_args} + transform_args = {key: args.get(key, None) for key in transform_args} return transform_args @@ -211,7 +214,8 @@ def generate_transform_args(parser): ) parser.add_argument( "--desired-layout-ops", - default="nn.conv2d,nn.conv2d_transpose,qnn.conv2d", + default=["nn.conv2d", "nn.conv2d_transpose", "qnn.conv2d"], + nargs="+", help="List of operators to be layout converted.", ) @@ -223,18 +227,19 @@ def generate_transform_args(parser): ) parser.add_argument( "--mixed-precision-ops", - default="nn.conv2d,nn.dense", + default=["nn.conv2d", "nn.dense"], + nargs="+", help="List of operators to be converted to mixed precision", ) parser.add_argument( - "--mixed-precision-input", + "--mixed-precision-calculation-type", choices=["float16", "float32"], default="float16", - help="Input precision type", + help="Calculation precision type", ) parser.add_argument( - "--mixed-precision-output", + "--mixed-precision-acc-type", choices=["float16", "float32"], default="float16", - help="Output or accumulator precision type", + help="Accumulator precision type", ) diff --git a/tests/python/driver/tvmc/test_compiler.py b/tests/python/driver/tvmc/test_compiler.py index c086860be62f..6653491a4ab4 100644 --- a/tests/python/driver/tvmc/test_compiler.py +++ b/tests/python/driver/tvmc/test_compiler.py @@ -72,7 +72,7 @@ def verify_compile_tflite_module(model, shape_dict=None, use_vm=False): tvmc_model, target="llvm", dump_code="ll", - transform_args={"desired_layout": "NCHW"}, + desired_layout="NCHW", use_vm=use_vm, ) dumps_path = tvmc_package.package_path + ".ll" @@ -290,9 +290,7 @@ def test_cross_compile_options_aarch64_onnx_module(onnx_resnet50): def verify_compile_paddle_module(model, shape_dict=None): pytest.importorskip("paddle") tvmc_model = tvmc.load(model, "paddle", shape_dict=shape_dict) - tvmc_package = tvmc.compile( - tvmc_model, target="llvm", dump_code="ll", transform_args={"desired_layout": "NCHW"} - ) + tvmc_package = tvmc.compile(tvmc_model, target="llvm", 
dump_code="ll", desired_layout="NCHW") dumps_path = tvmc_package.package_path + ".ll" # check for output types @@ -374,7 +372,7 @@ def test_compile_opencl(tflite_mobilenet_v1_0_25_128): tvmc_package = tvmc.compile( tvmc_model, target="opencl -host=llvm", - transform_args={"desired_layout": "NCHW"}, + desired_layout="NCHW", dump_code="asm", ) dumps_path = tvmc_package.package_path + ".asm" diff --git a/tests/python/driver/tvmc/test_transform.py b/tests/python/driver/tvmc/test_transform.py index fb7d8b633566..1c13b2e9f815 100644 --- a/tests/python/driver/tvmc/test_transform.py +++ b/tests/python/driver/tvmc/test_transform.py @@ -76,7 +76,7 @@ def test_layout_transform_to_mixed_precision_pass_args(relay_conv2d, monkeypatch mixed precision arguments are provided. """ mock_mixed_precision = MagicMock() - mock_mixed_precision.return_value = tvm.driver.tvmc.transform.MixedPrecision([]) + mock_mixed_precision.return_value = tvm.driver.tvmc.transform.MixedPrecision([], "") monkeypatch.setattr(tvm.driver.tvmc.transform, "MixedPrecision", mock_mixed_precision) with tvm.transform.PassContext(opt_level=3): @@ -84,40 +84,23 @@ def test_layout_transform_to_mixed_precision_pass_args(relay_conv2d, monkeypatch relay_conv2d, { "mixed_precision": True, + "mixed_precision_ops": ["nn.conv2d"], + "mixed_precision_calculation_type": "float16", + "mixed_precision_acc_type": "float16", }, ) - mock_mixed_precision.assert_called_with(["nn.conv2d", "nn.dense"]) + mock_mixed_precision.assert_called_with(["nn.conv2d"], "float16") apply_graph_transforms( relay_conv2d, { "mixed_precision": True, - "mixed_precision_ops": "nn.conv2d", + "mixed_precision_ops": ["nn.conv2d", "nn.dense"], + "mixed_precision_calculation_type": "float16", + "mixed_precision_acc_type": "float32", }, ) - mock_mixed_precision.assert_called_with(["nn.conv2d"]) - - apply_graph_transforms( - relay_conv2d, - { - "mixed_precision": True, - "mixed_precision_ops": "nn.conv2d,nn.dense", - "mixed_precision_input": "float16", - "mixed_precision_output": "float16", - }, - ) - mock_mixed_precision.assert_called_with(["nn.conv2d", "nn.dense"]) - - apply_graph_transforms( - relay_conv2d, - { - "mixed_precision": True, - "mixed_precision_ops": "nn.conv2d,nn.dense", - "mixed_precision_input": "float16", - "mixed_precision_output": "float32", - }, - ) - mock_mixed_precision.assert_called_with(["nn.conv2d", "nn.dense"]) + mock_mixed_precision.assert_called_with(["nn.conv2d", "nn.dense"], "float32") if __name__ == "__main__": From 804d5d0b7431adebdd898b9f1f8d122bcb226850 Mon Sep 17 00:00:00 2001 From: Siva Rama Krishna Reddy B Date: Fri, 3 Mar 2023 15:21:36 +0530 Subject: [PATCH 3/5] * review comments --- tests/python/driver/tvmc/test_transform.py | 64 +++++++++++++++++++++- 1 file changed, 63 insertions(+), 1 deletion(-) diff --git a/tests/python/driver/tvmc/test_transform.py b/tests/python/driver/tvmc/test_transform.py index 1c13b2e9f815..72c7cda6ff1a 100644 --- a/tests/python/driver/tvmc/test_transform.py +++ b/tests/python/driver/tvmc/test_transform.py @@ -19,6 +19,8 @@ import tvm from tvm import relay +from tvm.relay import testing +from tvm.relay.expr_functor import ExprMutator from tvm.ir.instrument import pass_instrument from tvm.driver.tvmc.transform import apply_graph_transforms @@ -70,7 +72,7 @@ def test_layout_transform_convert_layout_pass_args(relay_conv2d, monkeypatch): ) -def test_layout_transform_to_mixed_precision_pass_args(relay_conv2d, monkeypatch): +def test_layout_transform_to_mixed_precision_pass_args_mock(relay_conv2d, monkeypatch): """ Check 
the mixed precision arguments which are expected when
     mixed precision arguments are provided.
     """
@@ -103,5 +105,65 @@
     mock_mixed_precision.assert_called_with(["nn.conv2d", "nn.dense"], "float32")

+
+def test_layout_transform_to_mixed_precision_pass_args_graph():
+    """
+    Check the mixed precision arguments application within a graph.
+    """
+
+    mod, params = testing.mobilenet.get_workload(batch_size=1, dtype="float32")
+
+    class CheckOpMutator(ExprMutator):
+        """Inspect ops according to expected types."""
+
+        def __init__(self, calculation_type, acc_type, op):
+            self.calculation_type = calculation_type
+            self.acc_type = acc_type
+            self.op = op
+            self.is_expected = True
+            super().__init__()
+
+        def visit_call(self, call):
+            visit = super().visit(call.args[0])
+            if call.op == relay.op.get(self.op):
+                if self.is_expected:
+                    self.is_expected = (
+                        call.checked_type.dtype == self.acc_type
+                        or call.args[0].checked_type.dtype == self.calculation_type
+                    )
+            return call
+
+        def check(self, func):
+            self.visit(func)
+            return self.is_expected
+
+    mod = apply_graph_transforms(
+        mod,
+        {
+            "mixed_precision": True,
+            "mixed_precision_ops": ["nn.conv2d", "nn.dense"],
+            "mixed_precision_calculation_type": "float16",
+            "mixed_precision_acc_type": "float16",
+        },
+    )
+    ret = CheckOpMutator("float16", "float16", "nn.conv2d").check(mod["main"])
+    assert ret
+    ret = CheckOpMutator("float16", "float16", "nn.dense").check(mod["main"])
+    assert ret
+
+    mod = apply_graph_transforms(
+        mod,
+        {
+            "mixed_precision": True,
+            "mixed_precision_ops": ["nn.conv2d", "nn.dense"],
+            "mixed_precision_calculation_type": "float16",
+            "mixed_precision_acc_type": "float32",
+        },
+    )
+    ret = CheckOpMutator("float16", "float32", "nn.conv2d").check(mod["main"])
+    assert ret
+    ret = CheckOpMutator("float16", "float32", "nn.dense").check(mod["main"])
+    assert ret
+
+
 if __name__ == "__main__":
     tvm.testing.main()

From 499c86e801f66f26fd39f1d7ae4dfe4386b4ec85 Mon Sep 17 00:00:00 2001
From: Siva Rama Krishna Reddy B
Date: Sat, 4 Mar 2023 09:51:28 +0530
Subject: [PATCH 4/5] * review

---
 python/tvm/driver/tvmc/transform.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/python/tvm/driver/tvmc/transform.py b/python/tvm/driver/tvmc/transform.py
index fab6c5a146d0..2b34ba11b49f 100644
--- a/python/tvm/driver/tvmc/transform.py
+++ b/python/tvm/driver/tvmc/transform.py
@@ -92,7 +92,7 @@ def convert_to_mixed_precision(mod, ops=None, calculation_type="float16", acc_ty

     with MixedPrecision(ops, acc_type):
         seq = transform.Sequential(
-            [relay.transform.InferType(), relay.transform.ToMixedPrecision()]
+            [relay.transform.InferType(), relay.transform.ToMixedPrecision(calculation_type)]
         )
         with transform.PassContext(
             config={"relay.ToMixedPrecision.keep_orig_output_dtype": True}, opt_level=3
@@ -160,7 +160,9 @@ def apply_graph_transforms(mod, args):

     # AlterLayout
     if args.get("desired_layout", False):
-        mod = convert_graph_layout(mod, args["desired_layout"])
+        mod = convert_graph_layout(
+            mod, args["desired_layout"], args.get("desired_layout_ops", None)
+        )

     # ToMixedPrecision
     if args.get("mixed_precision", False):

From b0379b4d052e8ac94d9fd7819d16c87c8198ea77 Mon Sep 17 00:00:00 2001
From: Siva Rama Krishna Reddy B
Date: Mon, 6 Mar 2023 14:49:11 +0530
Subject: [PATCH 5/5] * review comments

---
 python/tvm/driver/tvmc/compiler.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/python/tvm/driver/tvmc/compiler.py b/python/tvm/driver/tvmc/compiler.py
index 2d392226c505..97eeeb98c76d 100644
--- a/python/tvm/driver/tvmc/compiler.py
+++ b/python/tvm/driver/tvmc/compiler.py
@@ -286,13 +286,14 @@ def compile_model(
     desired_layout_ops: list[str], optional
         The list of operators to be transformed with desired layout.
     mixed_precision: bool
-        To enable mixed precision transformation.
+        To enable mixed precision transformation. Disabled by default.
     mixed_precision_ops: list[str], optional
         The list of operators to be converted to mixed precision.
+        Set to ["nn.conv2d", "nn.dense"] by default.
     mixed_precision_calculation_type: str
-        The calculation dtype to be used while mixed precision.
+        The calculation dtype to be used during mixed precision conversion. Set to "float16" by default.
     mixed_precision_acc_type: str
-        The accumulation data type to be used while mixed precision.
+        The accumulation data type to be used during mixed precision conversion. Set to "float16" by default.

     Returns
     -------
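
---

Usage sketch (reviewer note, not part of the patches above): the snippet below
exercises the interface this series ends up with, using the final parameter
names from patch 2. It is a minimal sketch; the model file "model.onnx" and
the chosen option values are placeholders, and the command line shown mirrors
the options registered by generate_transform_args (the list options use
nargs="+", i.e. space-separated values).

    from tvm.driver import tvmc

    # Load a model (placeholder path) and compile it with the new
    # graph-transform keyword arguments added by this series.
    tvmc_model = tvmc.load("model.onnx")
    tvmc.compile(
        tvmc_model,
        target="llvm",
        desired_layout="NCHW",                          # ConvertLayout, as before
        desired_layout_ops=["nn.conv2d"],               # new: restrict layout conversion
        mixed_precision=True,                           # new: run ToMixedPrecision
        mixed_precision_ops=["nn.conv2d", "nn.dense"],
        mixed_precision_calculation_type="float16",
        mixed_precision_acc_type="float32",
    )

    # Assumed-equivalent command line:
    #   tvmc compile --target "llvm" --desired-layout NCHW \
    #       --mixed-precision --mixed-precision-ops nn.conv2d nn.dense \
    #       --mixed-precision-calculation-type float16 \
    #       --mixed-precision-acc-type float32 \
    #       model.onnx

Both paths funnel through parse_graph_transform_args and apply_graph_transforms,
so the CLI flags and the Python keyword arguments stay in sync by construction.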