From 425471d732caa15857000a7e295a2d3c9d0fc48f Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Thu, 13 May 2021 17:48:48 -0700 Subject: [PATCH 01/59] Initial skeleton for fp16 pass. initial green gray and red lists move fp16 conversion to own fodler second pass example split up files a bit more cool nodes bro initial transofmr pass --- .../transform/fp16_conversion/__init__.py | 0 .../fp16_conversion/fp16_op_description.py | 47 ++++ .../transform/fp16_conversion/fp32_to_fp16.py | 233 ++++++++++++++++++ .../transform/fp16_conversion/graph_colors.py | 113 +++++++++ 4 files changed, 393 insertions(+) create mode 100644 python/tvm/relay/transform/fp16_conversion/__init__.py create mode 100644 python/tvm/relay/transform/fp16_conversion/fp16_op_description.py create mode 100644 python/tvm/relay/transform/fp16_conversion/fp32_to_fp16.py create mode 100644 python/tvm/relay/transform/fp16_conversion/graph_colors.py diff --git a/python/tvm/relay/transform/fp16_conversion/__init__.py b/python/tvm/relay/transform/fp16_conversion/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/python/tvm/relay/transform/fp16_conversion/fp16_op_description.py b/python/tvm/relay/transform/fp16_conversion/fp16_op_description.py new file mode 100644 index 000000000000..6a3b3e56ae39 --- /dev/null +++ b/python/tvm/relay/transform/fp16_conversion/fp16_op_description.py @@ -0,0 +1,47 @@ +from typing import * + +from tvm import relay +from tvm.relay.transform.fp16_conversion import graph_colors + +FP16OutDtype = NamedTuple("FP16OutDtype", [("accumulation_dtype", str), ("output_dtype", str)]) + + +class DefaultFP16TypeDefinition: + # These fp16 operations accumulate their results in a 32 bit buffer + DEFAULT_FP32_ACCUMULATION_LIST = [ + "nn.conv1d", + "nn.conv2d", + "nn.conv3d", + "nn.conv1d_transpose", + "nn.conv2d_transpose", + "nn.conv3d_transpose", + "nn.dense", + "nn.avg_pool1d", + "nn.avg_pool2d", + "nn.avg_pool3d", + "nn.adaptive_avg_pool1d", + "nn.adaptive_avg_pool2d", + "nn.adaptive_avg_pool3d", + ] + + # These fp16 operations return fp32 results. If an operation has + # an fp32 accumulator but is not in this list, it is assumed the accumulator + # is quantized to 16 bits before being used in other operations. + DEFAULT_FP32_OUTPUT_LIST = [] + + SUPPORTED_OPS = { + + } + + def __init__( + self, + fp32_accumulation_ops: List[str] = DEFAULT_FP32_ACCUMULATION_LIST, + fp32_output_ops: List[str] = DEFAULT_FP32_OUTPUT_LIST, + ): + self.fp32_accumulation_ops = set(graph_colors.create_op_list(fp32_accumulation_ops)) + self.fp32_output_ops = set(graph_colors.create_op_list(fp32_output_ops)) + + def __call__(self, call_node: relay.Call) -> FP16OutDtype: + accumulation_dtype = "float32" if call_node.op in self.fp32_accumulation_ops else "float16" + output_dtype = "float32" if call_node.op in self.fp32_output_ops else "float16" + return FP16OutDtype(accumulation_dtype, output_dtype) diff --git a/python/tvm/relay/transform/fp16_conversion/fp32_to_fp16.py b/python/tvm/relay/transform/fp16_conversion/fp32_to_fp16.py new file mode 100644 index 000000000000..9ffd560ab9d0 --- /dev/null +++ b/python/tvm/relay/transform/fp16_conversion/fp32_to_fp16.py @@ -0,0 +1,233 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Relay type recasting pass""" +from typing import * + +import tvm +from tvm import relay +from tvm.relay.expr_functor import ExprVisitor +from tvm.relay.testing import resnet +from tvm.relay.transform import InferType +from tvm.relay.transform.fp16_conversion import fp16_op_description, graph_colors + + +class InitialGraphColorer(ExprVisitor): + """Color ops""" + + def __init__(self, color_function: Callable[[relay.Call], graph_colors.ConversionCategory]): + super().__init__() + self.color_function = color_function + self.result_map = {} + + def visit_call(self, call: relay.Call): + self.result_map[call] = self.color_function(call) + super().visit_call(call) + + +class PropagateColors(ExprVisitor): + """Propagate colors outward through gray colored nodes. + + A gray node becomes green if all it's inputs are fp16 or compile time constants (which can be cast at compile time). + Otherwise the node will become red. + """ + + def __init__( + self, + result_map: Dict[relay.Call, graph_colors.ConversionCategory], + output_dtype_function: Callable[[relay.Call], fp16_op_description.FP16OutDtype], + ): + super().__init__() + self.result_map = result_map.copy() + self.output_dtype_function = output_dtype_function + + def visit_call(self, call: relay.Call): + super().visit_call(call) + + if self.result_map[call] != graph_colors.ConversionCategory.GRAY: + return + + is_green = True + for arg in call.args: + is_green = is_green and self.is_fp16_compatible_arg(arg) + + self.result_map[call] = ( + graph_colors.ConversionCategory.GREEN + if is_green + else graph_colors.ConversionCategory.RED + ) + + def is_fp16_compatible_arg(self, arg: relay.Expr) -> bool: + """ + For vars and constants, assume can cast to fp16 always and have constant folding + """ + if isinstance(arg, relay.Var) or isinstance(arg, relay.Constant): + return True + elif isinstance(arg, relay.Call): + return ( + self.output_dtype_function(arg).output_dtype == "float16" + and self.result_map[arg] == graph_colors.ConversionCategory.GREEN + ) + elif isinstance(arg, relay.TupleGetItem): + return self.is_fp16_compatible_arg(arg.tuple_value) + # TODO: propogate through other control flow + else: + raise ValueError(f"Unknown node type {type(arg)} for args") + + +class RewriteBasedOnColors(relay.ExprMutator): + def __init__( + self, + result_map: Dict[relay.Call, graph_colors.ConversionCategory], + fp16_dtype_func: Callable[[relay.Call], fp16_op_description.FP16OutDtype], + ): + super().__init__() + self.result_map = result_map.copy() + self.fp16_dtype_func = fp16_dtype_func + + def visit_call(self, call): + if self.result_map[call] == graph_colors.ConversionCategory.GRAY: + raise ValueError("Rewriting encountered gray! Remember to run PropagateColors pass!") + elif self.result_map[call] == graph_colors.ConversionCategory.RED: + return super().visit_call(call) + + call_op = self.visit(call.op) + args = [self.visit(arg) for arg in call.args] + new_args = [] + for arg in args: + if isinstance(arg, relay.Var) or isinstance(arg, relay.Constant): + new_args.append(relay.cast(arg, "float16")) + elif isinstance(arg, relay.Call): + if self.result_map[arg] == graph_colors.ConversionCategory.GREEN: + arg = ( + arg + if self.fp16_dtype_func(arg).output_dtype == "float16" + else relay.cast(arg, "float16") + ) + else: + arg = relay.cast(arg, "float16") + new_args.append(arg) + else: + new_args.append(arg) + + # TODO: what do we do about operations without control over the accumulation dtype? + fp16_op_output = self.fp16_dtype_func(call) + + if call.attrs is not None and "out_dtype" in call.attrs.keys(): + new_attr_dict = {} + for attr in call.attrs.keys(): + attr_value = call.attrs[attr] + if isinstance(attr_value, tvm.ir.container.Array): + attr_value = tuple(attr_value) + new_attr_dict[str(attr)] = attr_value + new_attr_dict["out_dtype"] = fp16_op_output.accumulation_dtype + attr_type = str(call.attrs).split("(")[0] + new_attrs = tvm.ir.make_node(attr_type, **new_attr_dict) + else: + new_attrs = call.attrs + + # Inject proper arg types here based on fp16 op description func + output = relay.Call(call_op, new_args, new_attrs) + + if fp16_op_output.accumulation_dtype != fp16_op_output.output_dtype: + output = relay.cast(output, fp16_op_output.output_dtype) + + self.result_map[output] = self.result_map[call] + return output + + +class PrintVisitor(ExprVisitor): + def __init__(self, result_map: Dict[relay.Call, graph_colors.ConversionCategory]): + super().__init__() + self.result_map = result_map.copy() + + def visit_call(self, call): + super().visit_call(call) + + if call.checked_type == None: + raise ValueError( + "Warning! Could not infer type for f{call.op} operation. Did you run InferType pass?" + ) + + if isinstance(call.checked_type, tvm.ir.tensor_type.TensorType): + # Assume this refers to the output tensor + output_dtype = call.checked_type.dtype + elif isinstance(call.checked_type, tvm.ir.type.TupleType): + output_dtype = call.checked_type.fields[0].dtype + else: + raise ValueError(f"Unknown type {type(call.checked_type)}") + + print(f"Operation {call.op} output dtype {output_dtype}, color {self.result_map[call]}") + + if call.op == relay.op.get("nn.batch_norm"): + pass + elif call.op == relay.op.get("nn.conv2d"): + pass + elif call.op == relay.op.get("nn.relu"): + pass + elif call.op == relay.op.get("add"): + pass + elif call.op == relay.op.get("nn.global_avg_pool2d"): + pass + elif call.op == relay.op.get("nn.batch_flatten"): + pass + elif call.op == relay.op.get("nn.dense"): + pass + elif call.op == relay.op.get("nn.bias_add"): + pass + elif call.op == relay.op.get("nn.softmax"): + pass + else: + raise ValueError(f"Unknown call {call.op}") + + # print() + # import pdb + # pdb.set_trace() + # print(call) + + +if __name__ == "__main__": + c = resnet.get_net(1, 5, num_layers=18, image_shape=(1, 32, 32)) + + infer_type_pass = InferType() + + mod = tvm.IRModule.from_expr(c) + + out = infer_type_pass(mod) + relay_node_out = out["main"].body + + color_func = graph_colors.DefaultColorer() + colorer = InitialGraphColorer(color_func) + colorer.visit(relay_node_out) + + print("Initial color") + visitor = PrintVisitor(colorer.result_map) + visitor.visit(relay_node_out) + + fp16_op_descriptor = fp16_op_description.DefaultFP16TypeDefinition() + propagater = PropagateColors(colorer.result_map, fp16_op_descriptor) + propagater.visit_call(relay_node_out) + + print() + print("After propogate") + visitor = PrintVisitor(propagater.result_map) + visitor.visit(relay_node_out) + + rewriter = RewriteBasedOnColors(visitor.result_map, fp16_op_descriptor) + out = rewriter.visit_call(relay_node_out) + import pdb + + pdb.set_trace() diff --git a/python/tvm/relay/transform/fp16_conversion/graph_colors.py b/python/tvm/relay/transform/fp16_conversion/graph_colors.py new file mode 100644 index 000000000000..855e93e8b6a1 --- /dev/null +++ b/python/tvm/relay/transform/fp16_conversion/graph_colors.py @@ -0,0 +1,113 @@ +import enum +from typing import * + +import tvm +from tvm import relay + + +def create_op_list(op_list: List[str]) -> List[tvm.ir.Op]: + return [relay.op.get(op_name) for op_name in op_list] + + +class ConversionCategory(enum.Enum): + """ + Green: will cast to fp16 version of the op which takes in fp16 inputs + Gray: may cast after doing analysis + Red: will not cast to fp16 version + """ + + GREEN = "Green" + GRAY = "Gray" + RED = "Red" + + +class DefaultColorer: + # Default lists inspired from TF's classifications: + # https://github.com/tensorflow/tensorflow/blob/v2.5.0/tensorflow/core/grappler/optimizers/auto_mixed_precision_lists.h + # They might have a bias toward NVidia's Tensor Cores so be aware and modify lists per your hardware choice. + + # These should always be done in fp16 if possible + DEFAULT_GREEN_LIST = { + "nn.conv1d", + "nn.conv2d", + "nn.conv3d", + "nn.conv1d_transpose", + "nn.conv2d_transpose", + "nn.conv3d_transpose", + "nn.dense", + } + + # These can be done in fp16 or fp32 with no point in casting between + DEFAULT_GRAY_LIST = { + # These ops add new data or change shape + "nn.pad", + "nn.batch_flatten", + # Simple arithmetic + "add", + "nn.bias_add", + "nn.batch_norm", + # Simple activations + "nn.relu", + "nn.leaky_relu", + "nn.prelu", + "nn.dropout", + # Pooling operations + "nn.max_pool1d", + "nn.max_pool2d", + "nn.max_pool3d", + "nn.avg_pool1d", + "nn.avg_pool2d", + "nn.avg_pool3d", + ## "nn.global_max_pool1d", # does not exist + "nn.global_max_pool2d", + ## "nn.global_max_pool3d", # does not exist + ## "nn.global_avg_pool1d", # does not exist + "nn.global_avg_pool2d", + ## "nn.global_avg_pool3d", # does not exist + "nn.adaptive_max_pool1d", + "nn.adaptive_max_pool2d", + "nn.adaptive_max_pool3d", + "nn.adaptive_avg_pool1d", + "nn.adaptive_avg_pool2d", + "nn.adaptive_avg_pool3d", + } + + # These should always be done in fp32 + DEFAULT_RED_LIST = { + # Activations with exponents or division + "nn.cross_entropy", + "nn.cross_entropy_with_logits", + "nn.softmax", + # Other + "nn.l2_normalize", + } + + def __init__( + self, + green_list: List[str] = DEFAULT_GREEN_LIST, + gray_list: List[str] = DEFAULT_GRAY_LIST, + red_list: List[str] = DEFAULT_RED_LIST, + ): + # Convert each list to entry + green_list = create_op_list(green_list) + gray_list = create_op_list(gray_list) + red_list = create_op_list(red_list) + + # Create lookup table mapping relay op -> color in grpah + self.lookup_table = {} + for op_list, val in [ + (green_list, ConversionCategory.GREEN), + (gray_list, ConversionCategory.GRAY), + (red_list, ConversionCategory.RED), + ]: + for op in op_list: + self.lookup_table[op] = val + + def __call__(self, call_node: relay.Call, ignore_missing: bool = False) -> ConversionCategory: + if call_node.op not in self.lookup_table: + if ignore_missing: + return ConversionCategory.RED + else: + raise ValueError(f"Unknown op {call_node.op}") + + return self.lookup_table[call_node.op] From 2bd53119fb03145a72b403878763694a8d999455 Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Mon, 17 May 2021 11:49:45 -0400 Subject: [PATCH 02/59] Working python version of fp16 pass. fix topi conv2d not casting kernel to output type working resnet, but conv2d topi intrinsics need work tests for resnet add more tests, extend coverage for converter update tests, ensure red ops convert back to fp32 clean up code a bit simplify fp16 output dtype examination fix pass update tests initial coloring --- .../fp16_conversion/fp16_op_description.py | 48 +---- .../transform/fp16_conversion/fp32_to_fp16.py | 192 +++++++++--------- .../transform/fp16_conversion/graph_colors.py | 1 + python/tvm/topi/nn/conv2d.py | 10 +- src/relay/transforms/fp32_to_fp16.cc | 57 ++++++ src/relay/transforms/fp32_to_fp16.h | 116 +++++++++++ .../relay/test_fp32_to_fp16_transform.py | 70 +++++++ 7 files changed, 362 insertions(+), 132 deletions(-) create mode 100644 src/relay/transforms/fp32_to_fp16.cc create mode 100644 src/relay/transforms/fp32_to_fp16.h create mode 100644 tests/python/relay/test_fp32_to_fp16_transform.py diff --git a/python/tvm/relay/transform/fp16_conversion/fp16_op_description.py b/python/tvm/relay/transform/fp16_conversion/fp16_op_description.py index 6a3b3e56ae39..19020573e09f 100644 --- a/python/tvm/relay/transform/fp16_conversion/fp16_op_description.py +++ b/python/tvm/relay/transform/fp16_conversion/fp16_op_description.py @@ -3,45 +3,19 @@ from tvm import relay from tvm.relay.transform.fp16_conversion import graph_colors -FP16OutDtype = NamedTuple("FP16OutDtype", [("accumulation_dtype", str), ("output_dtype", str)]) +FP16OutDtype = NamedTuple("FP16OutDtype", [("accumulation_dtype", Optional[str]), ("output_dtype", str)]) class DefaultFP16TypeDefinition: - # These fp16 operations accumulate their results in a 32 bit buffer - DEFAULT_FP32_ACCUMULATION_LIST = [ - "nn.conv1d", - "nn.conv2d", - "nn.conv3d", - "nn.conv1d_transpose", - "nn.conv2d_transpose", - "nn.conv3d_transpose", - "nn.dense", - "nn.avg_pool1d", - "nn.avg_pool2d", - "nn.avg_pool3d", - "nn.adaptive_avg_pool1d", - "nn.adaptive_avg_pool2d", - "nn.adaptive_avg_pool3d", - ] - - # These fp16 operations return fp32 results. If an operation has - # an fp32 accumulator but is not in this list, it is assumed the accumulator - # is quantized to 16 bits before being used in other operations. - DEFAULT_FP32_OUTPUT_LIST = [] - - SUPPORTED_OPS = { - - } - - def __init__( - self, - fp32_accumulation_ops: List[str] = DEFAULT_FP32_ACCUMULATION_LIST, - fp32_output_ops: List[str] = DEFAULT_FP32_OUTPUT_LIST, - ): - self.fp32_accumulation_ops = set(graph_colors.create_op_list(fp32_accumulation_ops)) - self.fp32_output_ops = set(graph_colors.create_op_list(fp32_output_ops)) + """By default we assume every node wants the final output to be float16. + + If the relay node supports out_dtype we try to accumulate the fp32 before + casting beack. + """ def __call__(self, call_node: relay.Call) -> FP16OutDtype: - accumulation_dtype = "float32" if call_node.op in self.fp32_accumulation_ops else "float16" - output_dtype = "float32" if call_node.op in self.fp32_output_ops else "float16" - return FP16OutDtype(accumulation_dtype, output_dtype) + if call_node.attrs is not None and hasattr(call_node.attrs, "out_dtype"): + # Assume for now we always accumulate into fp32 if given the option + return FP16OutDtype(accumulation_dtype="float32", output_dtype="float16") + else: + return FP16OutDtype(accumulation_dtype=None, output_dtype="float16") diff --git a/python/tvm/relay/transform/fp16_conversion/fp32_to_fp16.py b/python/tvm/relay/transform/fp16_conversion/fp32_to_fp16.py index 9ffd560ab9d0..1ca91132630c 100644 --- a/python/tvm/relay/transform/fp16_conversion/fp32_to_fp16.py +++ b/python/tvm/relay/transform/fp16_conversion/fp32_to_fp16.py @@ -17,6 +17,7 @@ """Relay type recasting pass""" from typing import * +import numpy as np import tvm from tvm import relay from tvm.relay.expr_functor import ExprVisitor @@ -26,7 +27,11 @@ class InitialGraphColorer(ExprVisitor): - """Color ops""" + """Color relay Call operations green, gray or red via a given color function. + + These determine whether the operation will be operating on inputs in the fp16 or + fp32 space. + """ def __init__(self, color_function: Callable[[relay.Call], graph_colors.ConversionCategory]): super().__init__() @@ -56,34 +61,43 @@ def __init__( def visit_call(self, call: relay.Call): super().visit_call(call) - if self.result_map[call] != graph_colors.ConversionCategory.GRAY: return - is_green = True for arg in call.args: - is_green = is_green and self.is_fp16_compatible_arg(arg) + if not self.is_fp16_compatible_arg(arg): + self.result_map[call] = graph_colors.ConversionCategory.RED - self.result_map[call] = ( - graph_colors.ConversionCategory.GREEN - if is_green - else graph_colors.ConversionCategory.RED - ) + self.result_map[call] = graph_colors.ConversionCategory.GREEN def is_fp16_compatible_arg(self, arg: relay.Expr) -> bool: """ - For vars and constants, assume can cast to fp16 always and have constant folding + Returns whether the argument is either fp16 after conversion, or is a constant that + can be cast before runtime. + + For vars and constants, assume we can always safely cast to fp16. """ if isinstance(arg, relay.Var) or isinstance(arg, relay.Constant): return True elif isinstance(arg, relay.Call): + # A call result is fp16 if we will replace it with an fp16 operation + # (i.e. it is green) and the output will be fp16. return ( - self.output_dtype_function(arg).output_dtype == "float16" - and self.result_map[arg] == graph_colors.ConversionCategory.GREEN + self.result_map[arg] == graph_colors.ConversionCategory.GREEN + and self.output_dtype_function(arg).output_dtype == "float16" ) elif isinstance(arg, relay.TupleGetItem): return self.is_fp16_compatible_arg(arg.tuple_value) - # TODO: propogate through other control flow + elif isinstance(arg, relay.Tuple): + for ele in arg: + if not self.is_fp16_compatible_arg(ele): + return False + return True + elif isinstance(arg, relay.If): + return self.is_fp16_compatible_arg(arg.true_branch) and self.is_fp16_compatible_arg( + arg.false_branch + ) + # TODO: pass through other control flow else: raise ValueError(f"Unknown node type {type(arg)} for args") @@ -98,63 +112,85 @@ def __init__( self.result_map = result_map.copy() self.fp16_dtype_func = fp16_dtype_func - def visit_call(self, call): - if self.result_map[call] == graph_colors.ConversionCategory.GRAY: + def visit_let(self, let: relay.Let) -> relay.Expr: + raise ValueError("We don't support let bindings in this pass yet.") + + def visit_call(self, call: relay.Call) -> relay.Expr: + # Based on color, determine what dtype we want arguments of the Call to be + if self.result_map[call] == graph_colors.ConversionCategory.RED: + arg_cast_type = "float32" + elif self.result_map[call] == graph_colors.ConversionCategory.GREEN: + arg_cast_type = "float16" + elif self.result_map[call] == graph_colors.ConversionCategory.GRAY: raise ValueError("Rewriting encountered gray! Remember to run PropagateColors pass!") - elif self.result_map[call] == graph_colors.ConversionCategory.RED: - return super().visit_call(call) + else: + raise ValueError(f"Unknown coloring {self.result_map[call]}") call_op = self.visit(call.op) + + # Create new args and attrs taking into account the datatype + new_args = self.get_new_args(call, arg_cast_type) + new_attrs = self.get_new_attrs(call, arg_cast_type) + output = relay.Call(call_op, new_args, new_attrs) + self.result_map[output] = self.result_map[call] + + fp16_op_output = self.fp16_dtype_func(call) + if ( + fp16_op_output.accumulation_dtype is not None + and fp16_op_output.accumulation_dtype != fp16_op_output.output_dtype + ): + output = relay.cast(output, fp16_op_output.output_dtype) + self.result_map[output] = self.result_map[call] + + return output + + def get_new_args(self, call: relay.Call, arg_cast_type: str) -> List[relay.Expr]: args = [self.visit(arg) for arg in call.args] new_args = [] for arg in args: if isinstance(arg, relay.Var) or isinstance(arg, relay.Constant): - new_args.append(relay.cast(arg, "float16")) + # Assume all vars and consts are by default fp32 + new_args.append(relay.cast(arg, "float16") if arg_cast_type == "float16" else arg) elif isinstance(arg, relay.Call): - if self.result_map[arg] == graph_colors.ConversionCategory.GREEN: - arg = ( - arg - if self.fp16_dtype_func(arg).output_dtype == "float16" - else relay.cast(arg, "float16") - ) + if ( + self.result_map[arg] == graph_colors.ConversionCategory.GREEN + and self.fp16_dtype_func(arg).output_dtype == "float16" + ): + arg = arg if arg_cast_type == "float16" else relay.cast(arg, "float32") else: - arg = relay.cast(arg, "float16") + arg = relay.cast(arg, arg_cast_type) new_args.append(arg) else: new_args.append(arg) - - # TODO: what do we do about operations without control over the accumulation dtype? - fp16_op_output = self.fp16_dtype_func(call) - - if call.attrs is not None and "out_dtype" in call.attrs.keys(): + return new_args + + def get_new_attrs(self, call: relay.Call, arg_cast_type: str) -> tvm.ir.Node: + # Create a new attrs node which overwrites the output type if it's a field + if ( + call.attrs is not None + and "out_dtype" in call.attrs.keys() + and arg_cast_type == "float16" + ): new_attr_dict = {} for attr in call.attrs.keys(): attr_value = call.attrs[attr] if isinstance(attr_value, tvm.ir.container.Array): attr_value = tuple(attr_value) new_attr_dict[str(attr)] = attr_value - new_attr_dict["out_dtype"] = fp16_op_output.accumulation_dtype + new_attr_dict["out_dtype"] = self.fp16_dtype_func(call).accumulation_dtype attr_type = str(call.attrs).split("(")[0] - new_attrs = tvm.ir.make_node(attr_type, **new_attr_dict) - else: - new_attrs = call.attrs - - # Inject proper arg types here based on fp16 op description func - output = relay.Call(call_op, new_args, new_attrs) - - if fp16_op_output.accumulation_dtype != fp16_op_output.output_dtype: - output = relay.cast(output, fp16_op_output.output_dtype) - - self.result_map[output] = self.result_map[call] - return output + return tvm.ir.make_node(attr_type, **new_attr_dict) + return call.attrs class PrintVisitor(ExprVisitor): + """Used for debugging. Prints the name, original output type, and color of nodes.""" + def __init__(self, result_map: Dict[relay.Call, graph_colors.ConversionCategory]): super().__init__() self.result_map = result_map.copy() - def visit_call(self, call): + def visit_call(self, call: relay.Call): super().visit_call(call) if call.checked_type == None: @@ -172,62 +208,34 @@ def visit_call(self, call): print(f"Operation {call.op} output dtype {output_dtype}, color {self.result_map[call]}") - if call.op == relay.op.get("nn.batch_norm"): - pass - elif call.op == relay.op.get("nn.conv2d"): - pass - elif call.op == relay.op.get("nn.relu"): - pass - elif call.op == relay.op.get("add"): - pass - elif call.op == relay.op.get("nn.global_avg_pool2d"): - pass - elif call.op == relay.op.get("nn.batch_flatten"): - pass - elif call.op == relay.op.get("nn.dense"): - pass - elif call.op == relay.op.get("nn.bias_add"): - pass - elif call.op == relay.op.get("nn.softmax"): - pass - else: - raise ValueError(f"Unknown call {call.op}") - - # print() - # import pdb - # pdb.set_trace() - # print(call) - - -if __name__ == "__main__": - c = resnet.get_net(1, 5, num_layers=18, image_shape=(1, 32, 32)) - - infer_type_pass = InferType() - - mod = tvm.IRModule.from_expr(c) - out = infer_type_pass(mod) - relay_node_out = out["main"].body +def quantize_to_fp16(body: relay.Expr, debug: bool = False) -> relay.Expr: + if debug: + mod = tvm.ir.IRModule.from_expr(body) + infer_type_pass = InferType() + out = infer_type_pass(mod) + body = out["main"].body color_func = graph_colors.DefaultColorer() colorer = InitialGraphColorer(color_func) - colorer.visit(relay_node_out) + colorer.visit(body) - print("Initial color") - visitor = PrintVisitor(colorer.result_map) - visitor.visit(relay_node_out) + if debug: + print("Initial color") + visitor = PrintVisitor(colorer.result_map) + visitor.visit(body) fp16_op_descriptor = fp16_op_description.DefaultFP16TypeDefinition() propagater = PropagateColors(colorer.result_map, fp16_op_descriptor) - propagater.visit_call(relay_node_out) + propagater.visit_call(body) - print() - print("After propogate") - visitor = PrintVisitor(propagater.result_map) - visitor.visit(relay_node_out) + if debug: + print() + print("After propogate") + visitor = PrintVisitor(propagater.result_map) + visitor.visit(body) - rewriter = RewriteBasedOnColors(visitor.result_map, fp16_op_descriptor) - out = rewriter.visit_call(relay_node_out) - import pdb + rewriter = RewriteBasedOnColors(propagater.result_map, fp16_op_descriptor) + out = rewriter.visit_call(body) - pdb.set_trace() + return out diff --git a/python/tvm/relay/transform/fp16_conversion/graph_colors.py b/python/tvm/relay/transform/fp16_conversion/graph_colors.py index 855e93e8b6a1..38aaba6192f4 100644 --- a/python/tvm/relay/transform/fp16_conversion/graph_colors.py +++ b/python/tvm/relay/transform/fp16_conversion/graph_colors.py @@ -42,6 +42,7 @@ class DefaultColorer: # These ops add new data or change shape "nn.pad", "nn.batch_flatten", + "concatenate", # Simple arithmetic "add", "nn.bias_add", diff --git a/python/tvm/topi/nn/conv2d.py b/python/tvm/topi/nn/conv2d.py index 130eb4b69844..3f72bdc4b667 100644 --- a/python/tvm/topi/nn/conv2d.py +++ b/python/tvm/topi/nn/conv2d.py @@ -18,13 +18,15 @@ # pylint: disable=unused-argument, redefined-builtin """Conv2D operators""" from __future__ import absolute_import as _abs + from collections import namedtuple + import tvm -from tvm import te, auto_scheduler +from tvm import auto_scheduler, te +from ..utils import get_const_int, get_const_tuple, simplify, tag from .pad import pad from .utils import get_pad_tuple -from ..utils import simplify, get_const_tuple, get_const_int, tag from .winograd_util import winograd_transform_matrices # workload description of conv2d @@ -548,7 +550,9 @@ def conv2d_NCHWc(data, kernel, stride, padding, dilation, layout, out_layout, ou ow * WSTR + kw * dilation_w, idxmod(ic, ic_bn), ].astype(out_dtype) - * kernel[oc_chunk, idxdiv(ic, ic_bn), kh, kw, idxmod(ic, ic_bn), oc_block], + * kernel[oc_chunk, idxdiv(ic, ic_bn), kh, kw, idxmod(ic, ic_bn), oc_block].astype( + out_dtype + ), axis=[ic, kh, kw], ), name="conv2d_NCHWc", diff --git a/src/relay/transforms/fp32_to_fp16.cc b/src/relay/transforms/fp32_to_fp16.cc new file mode 100644 index 000000000000..360ab5219daa --- /dev/null +++ b/src/relay/transforms/fp32_to_fp16.cc @@ -0,0 +1,57 @@ + +#include "fp32_to_fp16.h" + +#include +#include +#include + +namespace tvm { +namespace relay { + +using CallColorMap = std::unordered_map; + +class GraphColorer : private ExprVisitor { + using ColorFunc = std::function; + + private: + CallColorMap color_map; + ColorFunc func; + + void VisitExpr_(const CallNode* l) final { + // FP16ConversionCategory c = func(l); + color_map[l] = func(l); + ExprVisitor::VisitExpr_(l); + } + + public: + GraphColorer(ColorFunc func = DefaultColorer()) : func(func) {} + + CallColorMap result() { return color_map; } + void VisitExpr(const Expr& expr) { ExprVisitor::VisitExpr(expr); } +}; + +class ColorPrinter : private ExprVisitor { + private: + CallColorMap color_map; + + public: + explicit ColorPrinter(CallColorMap& color_map) : color_map(color_map) {} + explicit ColorPrinter() {} + void VisitExpr(const Expr& expr) { ExprVisitor::VisitExpr(expr); } + + void VisitExpr_(const CallNode* l) final { + ExprVisitor::VisitExpr_(l); + std::cout << l->op << " is " << conversion_category_strings[color_map[l]] << std::endl; + } +}; + +void PrintColors(const Expr& expr) { + GraphColorer initial_colorer = GraphColorer(); + initial_colorer.VisitExpr(expr); + CallColorMap color_map = initial_colorer.result(); + ColorPrinter(color_map).VisitExpr(expr); +} +TVM_REGISTER_GLOBAL("relay._transform.PrintColorsExpr").set_body_typed(PrintColors); + +} // namespace relay +} // namespace tvm diff --git a/src/relay/transforms/fp32_to_fp16.h b/src/relay/transforms/fp32_to_fp16.h new file mode 100644 index 000000000000..18198fec5c7c --- /dev/null +++ b/src/relay/transforms/fp32_to_fp16.h @@ -0,0 +1,116 @@ +#include +#include + +#include +#include +#include +#include + +namespace tvm { +namespace relay { + +enum FP16ConversionCategory { RED, GRAY, GREEN }; +std::unordered_map conversion_category_strings({{RED, "Red"}, + {GRAY, "Gray"}, + {GREEN, "Green"}}); + +using OpStringSet = std::unordered_set; + +// Default lists inspired from TF's classifications: +// github.com/tensorflow/tensorflow/blob/v2.5.0/tensorflow/core/grappler/optimizers/auto_mixed_precision_lists.h +// They might have a bias toward NVidia's Tensor Cores so be aware and modify lists per your +// hardware choice. +OpStringSet DEFAULT_GREEN_LIST({ + "nn.conv1d", + "nn.conv2d", + "nn.conv3d", + "nn.conv1d_transpose", + "nn.conv2d_transpose", + "nn.conv3d_transpose", + "nn.dense", +}); +OpStringSet DEFAULT_GRAY_LIST({ + // These ops add new data or change shape + "nn.pad", + "nn.batch_flatten", + "concatenate", + // Simple arithmetic + "add", + "nn.bias_add", + "nn.batch_norm", + // Simple activations + "nn.relu", + "nn.leaky_relu", + "nn.prelu", + "nn.dropout", + // Pooling operations + "nn.max_pool1d", + "nn.max_pool2d", + "nn.max_pool3d", + "nn.avg_pool1d", + "nn.avg_pool2d", + "nn.avg_pool3d", + // "nn.global_max_pool1d", // does not exist + "nn.global_max_pool2d", + // "nn.global_max_pool3d", // does not exist + // "nn.global_avg_pool1d", // does not exist + "nn.global_avg_pool2d", + // "nn.global_avg_pool3d", // does not exist + "nn.adaptive_max_pool1d", + "nn.adaptive_max_pool2d", + "nn.adaptive_max_pool3d", + "nn.adaptive_avg_pool1d", + "nn.adaptive_avg_pool2d", + "nn.adaptive_avg_pool3d", +}); +OpStringSet DEFAULT_RED_LIST({ + // Activations with exponents or division + "nn.cross_entropy", + "nn.cross_entropy_with_logits", + "nn.softmax", + // Other + "nn.l2_normalize", +}); + +class DefaultColorer { + private: + std::unordered_map op_to_initial_color; + + public: + DefaultColorer(OpStringSet red_list = DEFAULT_RED_LIST, OpStringSet gray_list = DEFAULT_GRAY_LIST, + OpStringSet green_list = DEFAULT_GREEN_LIST) { + std::vector> lists_and_colors{ + {red_list, RED}, {gray_list, GRAY}, {green_list, GREEN}}; + + for (auto list_and_color : lists_and_colors) { + OpStringSet ops = list_and_color.first; + FP16ConversionCategory color = list_and_color.second; + for (std::string op_name : ops) { + op_to_initial_color.insert({{op_name, color}}); + } + } + } + + FP16ConversionCategory operator()(const tvm::relay::CallNode* call, bool ignore_missing = false) { + auto* op_node = (call->op).as(); + if (op_node == nullptr) { + throw std::invalid_argument("FP16 conversion only supports call nodes with op calls."); + } + + std::string op_name = op_node->name; + auto color = op_to_initial_color.find(op_name); + + if (color == op_to_initial_color.end()) { + if (ignore_missing) { + return RED; + } else { + throw std::invalid_argument("Op name " + op_name + " not in included lists!."); + } + } + + return color->second; + } +}; + +} // namespace relay +} // namespace tvm \ No newline at end of file diff --git a/tests/python/relay/test_fp32_to_fp16_transform.py b/tests/python/relay/test_fp32_to_fp16_transform.py new file mode 100644 index 000000000000..b3f91eceba30 --- /dev/null +++ b/tests/python/relay/test_fp32_to_fp16_transform.py @@ -0,0 +1,70 @@ +from typing import * + +import numpy as np +import tvm +from numpy.lib.type_check import imag +from tvm import relay +from tvm.relay.expr_functor import ExprVisitor +from tvm.relay.testing import densenet, mobilenet, resnet, resnet_3d, squeezenet +from tvm.relay.transform import InferType +from tvm.relay.transform.fp16_conversion import fp32_to_fp16 + + +def run_module(mod, mod_params): + dev = tvm.device("llvm", 0) + intrp = relay.create_executor("debug", mod, device=dev, target="llvm") + # in_data = [tvm.nd.array(value) for value in in_data.values()] + return intrp.evaluate()(**mod_params).asnumpy() + + +def verify_fp32_fp16_output_close(mod, mod_params): + result_fp32 = run_module(mod, mod_params) + + fp16 = fp32_to_fp16.quantize_to_fp16(mod["main"].body) + fp16_mod = tvm.ir.IRModule.from_expr(fp16) + result_fp16 = run_module(fp16_mod, mod_params) + + # Ensure the results are close + np.testing.assert_allclose(result_fp32, result_fp16, rtol=1e-3) + +def test_resnet18(): + np.random.seed(4321) + mod, mod_params = resnet.get_workload(1, 5, num_layers=18, image_shape=(1, 32, 32)) + mod_params["data"] = np.random.uniform(-10, 10, (1, 1, 32, 32)).astype("float32") + + verify_fp32_fp16_output_close(mod, mod_params) + + +def test_resnet18_3d(): + np.random.seed(3215) + mod, mod_params = resnet_3d.get_workload(1, 5, num_layers=18, image_shape=(1, 3, 32, 32)) + mod_params["data"] = np.random.uniform(-10, 10, (1, 1, 3, 32, 32)).astype("float32") + + verify_fp32_fp16_output_close(mod, mod_params) + + +def test_mobilenet(): + np.random.seed(4615) + + mod, mod_params = mobilenet.get_workload(1, 5, image_shape=(1, 32, 32)) + mod_params["data"] = np.random.uniform(-10, 10, (1, 1, 32, 32)).astype("float32") + + verify_fp32_fp16_output_close(mod, mod_params) + + +def test_densenet(): + np.random.seed(3222) + mod, mod_params = densenet.get_workload(classes=5, batch_size=1, image_shape=(1, 224, 224)) + mod_params["data"] = np.random.uniform(-10, 10, (1, 1, 224, 224)).astype("float32") + + verify_fp32_fp16_output_close(mod, mod_params) + + +def test_squeezenet(): + np.random.seed(5628) + mod, mod_params = squeezenet.get_workload(1, 5, image_shape=(1, 32, 32)) + mod_params["data"] = np.random.uniform(-10, 10, (1, 1, 32, 32)).astype("float32") + + verify_fp32_fp16_output_close(mod, mod_params) + +#def test_ From 9fda090311a6d2a058dbccef09edfa078ec66803 Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Wed, 19 May 2021 16:07:13 -0400 Subject: [PATCH 03/59] Rewrite python passes in C++ inspect arg fields add propagate colors pass" private -> public inheritance" rewrite draft full transformation in c++ remove prints fp16 pass the proper wrapping insert extra cast to pass type checking fix previously broken test by removing cast in wrong scenario remove old python_files --- .../transform/fp16_conversion/__init__.py | 0 .../fp16_conversion/fp16_op_description.py | 21 -- .../transform/fp16_conversion/fp32_to_fp16.py | 241 ----------------- .../transform/fp16_conversion/graph_colors.py | 114 -------- python/tvm/relay/transform/transform.py | 14 +- src/relay/transforms/fp32_to_fp16.cc | 249 +++++++++++++++++- src/relay/transforms/fp32_to_fp16.h | 32 ++- .../relay/test_fp32_to_fp16_transform.py | 16 +- 8 files changed, 279 insertions(+), 408 deletions(-) delete mode 100644 python/tvm/relay/transform/fp16_conversion/__init__.py delete mode 100644 python/tvm/relay/transform/fp16_conversion/fp16_op_description.py delete mode 100644 python/tvm/relay/transform/fp16_conversion/fp32_to_fp16.py delete mode 100644 python/tvm/relay/transform/fp16_conversion/graph_colors.py diff --git a/python/tvm/relay/transform/fp16_conversion/__init__.py b/python/tvm/relay/transform/fp16_conversion/__init__.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/python/tvm/relay/transform/fp16_conversion/fp16_op_description.py b/python/tvm/relay/transform/fp16_conversion/fp16_op_description.py deleted file mode 100644 index 19020573e09f..000000000000 --- a/python/tvm/relay/transform/fp16_conversion/fp16_op_description.py +++ /dev/null @@ -1,21 +0,0 @@ -from typing import * - -from tvm import relay -from tvm.relay.transform.fp16_conversion import graph_colors - -FP16OutDtype = NamedTuple("FP16OutDtype", [("accumulation_dtype", Optional[str]), ("output_dtype", str)]) - - -class DefaultFP16TypeDefinition: - """By default we assume every node wants the final output to be float16. - - If the relay node supports out_dtype we try to accumulate the fp32 before - casting beack. - """ - - def __call__(self, call_node: relay.Call) -> FP16OutDtype: - if call_node.attrs is not None and hasattr(call_node.attrs, "out_dtype"): - # Assume for now we always accumulate into fp32 if given the option - return FP16OutDtype(accumulation_dtype="float32", output_dtype="float16") - else: - return FP16OutDtype(accumulation_dtype=None, output_dtype="float16") diff --git a/python/tvm/relay/transform/fp16_conversion/fp32_to_fp16.py b/python/tvm/relay/transform/fp16_conversion/fp32_to_fp16.py deleted file mode 100644 index 1ca91132630c..000000000000 --- a/python/tvm/relay/transform/fp16_conversion/fp32_to_fp16.py +++ /dev/null @@ -1,241 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""Relay type recasting pass""" -from typing import * - -import numpy as np -import tvm -from tvm import relay -from tvm.relay.expr_functor import ExprVisitor -from tvm.relay.testing import resnet -from tvm.relay.transform import InferType -from tvm.relay.transform.fp16_conversion import fp16_op_description, graph_colors - - -class InitialGraphColorer(ExprVisitor): - """Color relay Call operations green, gray or red via a given color function. - - These determine whether the operation will be operating on inputs in the fp16 or - fp32 space. - """ - - def __init__(self, color_function: Callable[[relay.Call], graph_colors.ConversionCategory]): - super().__init__() - self.color_function = color_function - self.result_map = {} - - def visit_call(self, call: relay.Call): - self.result_map[call] = self.color_function(call) - super().visit_call(call) - - -class PropagateColors(ExprVisitor): - """Propagate colors outward through gray colored nodes. - - A gray node becomes green if all it's inputs are fp16 or compile time constants (which can be cast at compile time). - Otherwise the node will become red. - """ - - def __init__( - self, - result_map: Dict[relay.Call, graph_colors.ConversionCategory], - output_dtype_function: Callable[[relay.Call], fp16_op_description.FP16OutDtype], - ): - super().__init__() - self.result_map = result_map.copy() - self.output_dtype_function = output_dtype_function - - def visit_call(self, call: relay.Call): - super().visit_call(call) - if self.result_map[call] != graph_colors.ConversionCategory.GRAY: - return - - for arg in call.args: - if not self.is_fp16_compatible_arg(arg): - self.result_map[call] = graph_colors.ConversionCategory.RED - - self.result_map[call] = graph_colors.ConversionCategory.GREEN - - def is_fp16_compatible_arg(self, arg: relay.Expr) -> bool: - """ - Returns whether the argument is either fp16 after conversion, or is a constant that - can be cast before runtime. - - For vars and constants, assume we can always safely cast to fp16. - """ - if isinstance(arg, relay.Var) or isinstance(arg, relay.Constant): - return True - elif isinstance(arg, relay.Call): - # A call result is fp16 if we will replace it with an fp16 operation - # (i.e. it is green) and the output will be fp16. - return ( - self.result_map[arg] == graph_colors.ConversionCategory.GREEN - and self.output_dtype_function(arg).output_dtype == "float16" - ) - elif isinstance(arg, relay.TupleGetItem): - return self.is_fp16_compatible_arg(arg.tuple_value) - elif isinstance(arg, relay.Tuple): - for ele in arg: - if not self.is_fp16_compatible_arg(ele): - return False - return True - elif isinstance(arg, relay.If): - return self.is_fp16_compatible_arg(arg.true_branch) and self.is_fp16_compatible_arg( - arg.false_branch - ) - # TODO: pass through other control flow - else: - raise ValueError(f"Unknown node type {type(arg)} for args") - - -class RewriteBasedOnColors(relay.ExprMutator): - def __init__( - self, - result_map: Dict[relay.Call, graph_colors.ConversionCategory], - fp16_dtype_func: Callable[[relay.Call], fp16_op_description.FP16OutDtype], - ): - super().__init__() - self.result_map = result_map.copy() - self.fp16_dtype_func = fp16_dtype_func - - def visit_let(self, let: relay.Let) -> relay.Expr: - raise ValueError("We don't support let bindings in this pass yet.") - - def visit_call(self, call: relay.Call) -> relay.Expr: - # Based on color, determine what dtype we want arguments of the Call to be - if self.result_map[call] == graph_colors.ConversionCategory.RED: - arg_cast_type = "float32" - elif self.result_map[call] == graph_colors.ConversionCategory.GREEN: - arg_cast_type = "float16" - elif self.result_map[call] == graph_colors.ConversionCategory.GRAY: - raise ValueError("Rewriting encountered gray! Remember to run PropagateColors pass!") - else: - raise ValueError(f"Unknown coloring {self.result_map[call]}") - - call_op = self.visit(call.op) - - # Create new args and attrs taking into account the datatype - new_args = self.get_new_args(call, arg_cast_type) - new_attrs = self.get_new_attrs(call, arg_cast_type) - output = relay.Call(call_op, new_args, new_attrs) - self.result_map[output] = self.result_map[call] - - fp16_op_output = self.fp16_dtype_func(call) - if ( - fp16_op_output.accumulation_dtype is not None - and fp16_op_output.accumulation_dtype != fp16_op_output.output_dtype - ): - output = relay.cast(output, fp16_op_output.output_dtype) - self.result_map[output] = self.result_map[call] - - return output - - def get_new_args(self, call: relay.Call, arg_cast_type: str) -> List[relay.Expr]: - args = [self.visit(arg) for arg in call.args] - new_args = [] - for arg in args: - if isinstance(arg, relay.Var) or isinstance(arg, relay.Constant): - # Assume all vars and consts are by default fp32 - new_args.append(relay.cast(arg, "float16") if arg_cast_type == "float16" else arg) - elif isinstance(arg, relay.Call): - if ( - self.result_map[arg] == graph_colors.ConversionCategory.GREEN - and self.fp16_dtype_func(arg).output_dtype == "float16" - ): - arg = arg if arg_cast_type == "float16" else relay.cast(arg, "float32") - else: - arg = relay.cast(arg, arg_cast_type) - new_args.append(arg) - else: - new_args.append(arg) - return new_args - - def get_new_attrs(self, call: relay.Call, arg_cast_type: str) -> tvm.ir.Node: - # Create a new attrs node which overwrites the output type if it's a field - if ( - call.attrs is not None - and "out_dtype" in call.attrs.keys() - and arg_cast_type == "float16" - ): - new_attr_dict = {} - for attr in call.attrs.keys(): - attr_value = call.attrs[attr] - if isinstance(attr_value, tvm.ir.container.Array): - attr_value = tuple(attr_value) - new_attr_dict[str(attr)] = attr_value - new_attr_dict["out_dtype"] = self.fp16_dtype_func(call).accumulation_dtype - attr_type = str(call.attrs).split("(")[0] - return tvm.ir.make_node(attr_type, **new_attr_dict) - return call.attrs - - -class PrintVisitor(ExprVisitor): - """Used for debugging. Prints the name, original output type, and color of nodes.""" - - def __init__(self, result_map: Dict[relay.Call, graph_colors.ConversionCategory]): - super().__init__() - self.result_map = result_map.copy() - - def visit_call(self, call: relay.Call): - super().visit_call(call) - - if call.checked_type == None: - raise ValueError( - "Warning! Could not infer type for f{call.op} operation. Did you run InferType pass?" - ) - - if isinstance(call.checked_type, tvm.ir.tensor_type.TensorType): - # Assume this refers to the output tensor - output_dtype = call.checked_type.dtype - elif isinstance(call.checked_type, tvm.ir.type.TupleType): - output_dtype = call.checked_type.fields[0].dtype - else: - raise ValueError(f"Unknown type {type(call.checked_type)}") - - print(f"Operation {call.op} output dtype {output_dtype}, color {self.result_map[call]}") - - -def quantize_to_fp16(body: relay.Expr, debug: bool = False) -> relay.Expr: - if debug: - mod = tvm.ir.IRModule.from_expr(body) - infer_type_pass = InferType() - out = infer_type_pass(mod) - body = out["main"].body - - color_func = graph_colors.DefaultColorer() - colorer = InitialGraphColorer(color_func) - colorer.visit(body) - - if debug: - print("Initial color") - visitor = PrintVisitor(colorer.result_map) - visitor.visit(body) - - fp16_op_descriptor = fp16_op_description.DefaultFP16TypeDefinition() - propagater = PropagateColors(colorer.result_map, fp16_op_descriptor) - propagater.visit_call(body) - - if debug: - print() - print("After propogate") - visitor = PrintVisitor(propagater.result_map) - visitor.visit(body) - - rewriter = RewriteBasedOnColors(propagater.result_map, fp16_op_descriptor) - out = rewriter.visit_call(body) - - return out diff --git a/python/tvm/relay/transform/fp16_conversion/graph_colors.py b/python/tvm/relay/transform/fp16_conversion/graph_colors.py deleted file mode 100644 index 38aaba6192f4..000000000000 --- a/python/tvm/relay/transform/fp16_conversion/graph_colors.py +++ /dev/null @@ -1,114 +0,0 @@ -import enum -from typing import * - -import tvm -from tvm import relay - - -def create_op_list(op_list: List[str]) -> List[tvm.ir.Op]: - return [relay.op.get(op_name) for op_name in op_list] - - -class ConversionCategory(enum.Enum): - """ - Green: will cast to fp16 version of the op which takes in fp16 inputs - Gray: may cast after doing analysis - Red: will not cast to fp16 version - """ - - GREEN = "Green" - GRAY = "Gray" - RED = "Red" - - -class DefaultColorer: - # Default lists inspired from TF's classifications: - # https://github.com/tensorflow/tensorflow/blob/v2.5.0/tensorflow/core/grappler/optimizers/auto_mixed_precision_lists.h - # They might have a bias toward NVidia's Tensor Cores so be aware and modify lists per your hardware choice. - - # These should always be done in fp16 if possible - DEFAULT_GREEN_LIST = { - "nn.conv1d", - "nn.conv2d", - "nn.conv3d", - "nn.conv1d_transpose", - "nn.conv2d_transpose", - "nn.conv3d_transpose", - "nn.dense", - } - - # These can be done in fp16 or fp32 with no point in casting between - DEFAULT_GRAY_LIST = { - # These ops add new data or change shape - "nn.pad", - "nn.batch_flatten", - "concatenate", - # Simple arithmetic - "add", - "nn.bias_add", - "nn.batch_norm", - # Simple activations - "nn.relu", - "nn.leaky_relu", - "nn.prelu", - "nn.dropout", - # Pooling operations - "nn.max_pool1d", - "nn.max_pool2d", - "nn.max_pool3d", - "nn.avg_pool1d", - "nn.avg_pool2d", - "nn.avg_pool3d", - ## "nn.global_max_pool1d", # does not exist - "nn.global_max_pool2d", - ## "nn.global_max_pool3d", # does not exist - ## "nn.global_avg_pool1d", # does not exist - "nn.global_avg_pool2d", - ## "nn.global_avg_pool3d", # does not exist - "nn.adaptive_max_pool1d", - "nn.adaptive_max_pool2d", - "nn.adaptive_max_pool3d", - "nn.adaptive_avg_pool1d", - "nn.adaptive_avg_pool2d", - "nn.adaptive_avg_pool3d", - } - - # These should always be done in fp32 - DEFAULT_RED_LIST = { - # Activations with exponents or division - "nn.cross_entropy", - "nn.cross_entropy_with_logits", - "nn.softmax", - # Other - "nn.l2_normalize", - } - - def __init__( - self, - green_list: List[str] = DEFAULT_GREEN_LIST, - gray_list: List[str] = DEFAULT_GRAY_LIST, - red_list: List[str] = DEFAULT_RED_LIST, - ): - # Convert each list to entry - green_list = create_op_list(green_list) - gray_list = create_op_list(gray_list) - red_list = create_op_list(red_list) - - # Create lookup table mapping relay op -> color in grpah - self.lookup_table = {} - for op_list, val in [ - (green_list, ConversionCategory.GREEN), - (gray_list, ConversionCategory.GRAY), - (red_list, ConversionCategory.RED), - ]: - for op in op_list: - self.lookup_table[op] = val - - def __call__(self, call_node: relay.Call, ignore_missing: bool = False) -> ConversionCategory: - if call_node.op not in self.lookup_table: - if ignore_missing: - return ConversionCategory.RED - else: - raise ValueError(f"Unknown op {call_node.op}") - - return self.lookup_table[call_node.op] diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index 20e045abab6c..40fb383d55ac 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -18,16 +18,15 @@ """ Relay pass transformation infrastructure. """ -import types -import inspect import functools +import inspect +import types import warnings import tvm.ir -from tvm import te +from tvm import relay, te from tvm.runtime import ndarray as _nd -from tvm import relay from . import _ffi_api @@ -1199,3 +1198,10 @@ def FakeQuantizationToInteger(): The registered SimplifyExpr pass. """ return _ffi_api.FakeQuantizationToInteger() + + +def RewriteFP16(debug=False): + """ + Cool stuff. TODO + """ + return _ffi_api.RewriteFP16(debug) diff --git a/src/relay/transforms/fp32_to_fp16.cc b/src/relay/transforms/fp32_to_fp16.cc index 360ab5219daa..b2eee8eba972 100644 --- a/src/relay/transforms/fp32_to_fp16.cc +++ b/src/relay/transforms/fp32_to_fp16.cc @@ -1,18 +1,21 @@ #include "fp32_to_fp16.h" +#include #include #include #include +#include "pattern_utils.h" + namespace tvm { namespace relay { using CallColorMap = std::unordered_map; +using ColorFunc = std::function; +using OutputDtypeFunc = std::function; -class GraphColorer : private ExprVisitor { - using ColorFunc = std::function; - +class GraphColorer : public ExprVisitor { private: CallColorMap color_map; ColorFunc func; @@ -24,20 +27,197 @@ class GraphColorer : private ExprVisitor { } public: - GraphColorer(ColorFunc func = DefaultColorer()) : func(func) {} + GraphColorer(ColorFunc func = DefaultFP16Colorer()) : func(func) {} CallColorMap result() { return color_map; } - void VisitExpr(const Expr& expr) { ExprVisitor::VisitExpr(expr); } }; -class ColorPrinter : private ExprVisitor { +class PropagateColors : public ExprVisitor { + private: + CallColorMap color_map; + OutputDtypeFunc func; + + /* TODO: replace *Nodes with managed References e.g. CallNode* -> Call*/ + void VisitExpr_(const CallNode* l) final { + ExprVisitor::VisitExpr_(l); + auto result = color_map.find(l); + if (result == color_map.end()) { + throw std::invalid_argument("Unknown node not in initial color map!"); + } + FP16ConversionCategory color = result->second; + if (color != GRAY) return; + + for (Expr arg : l->args) { + if (!is_fp16_compatible_arg(arg)) { + color_map[l] = RED; + return; + } + } + + color_map[l] = GREEN; + } + + bool is_fp16_compatible_arg(Expr arg) { + if (arg->IsInstance() || arg->IsInstance()) { + return true; + } else if (const CallNode* call = arg.as()) { + auto result = color_map.find(call); + if (result == color_map.end()) { + throw std::invalid_argument("Unknown node not in initial color map!"); + } + FP16ConversionCategory color = result->second; + return color == GREEN && func(call).output_dtype == DataType::Float(16); + } else if (const TupleGetItemNode* tuple_get_item = arg.as()) { + return is_fp16_compatible_arg(tuple_get_item->tuple); + } else if (const TupleNode* tuple = arg.as()) { + for (Expr exp : tuple->fields) { + if (!is_fp16_compatible_arg(exp)) { + return false; + } + } + return true; + } else { + throw std::invalid_argument("Unknown node type " + arg->GetTypeKey()); + } + + return true; + } + + public: + PropagateColors(CallColorMap initial_color_map, OutputDtypeFunc func = DefaultFP16OpDefinition()) + : color_map(initial_color_map), func(func) {} + CallColorMap result() { return color_map; } +}; + +class RewriteBasedOnColors : public ExprMutator { + private: + CallColorMap color_map; + OutputDtypeFunc output_func; + + Array get_new_args(const CallNode* call, DataType arg_cast_datatype) { + Array ret; + for (Expr arg : call->args) { + arg = VisitExpr(arg); + Expr new_arg; + if (arg->IsInstance() || arg->IsInstance()) { + // Assume every var and const node is by default fp32, so cast if we are not casting to that + new_arg = arg_cast_datatype != DataType::Float(32) ? Cast(arg, arg_cast_datatype) : arg; + } else if (const CallNode* arg_call = arg.as()) { + auto entry = color_map.find(arg_call); + if (entry == color_map.end()) { + throw std::invalid_argument("Found element not in color map!"); + } + FP16ConversionCategory color = entry->second; + + // Cast result of a call, if we are going to rewrite it + if (color == GREEN) { + new_arg = output_func(arg_call).output_dtype != arg_cast_datatype + ? Cast(arg, arg_cast_datatype) + : arg; + } else { + // Was RED, assume fp32 output so cast to type + new_arg = arg_cast_datatype != DataType::Float(32) ? Cast(arg, arg_cast_datatype) : arg; + } + } else { + // Else assume it's a composite type composed of cast elements + new_arg = arg; + } + + ret.push_back(new_arg); + } + + return ret; + } + + Attrs get_new_attrs(const CallNode* call, DataType accumulation_dtype) { + Attrs new_attrs = Attrs(call->attrs); + if (new_attrs.get() != nullptr) { + // TODO: Figure out a better way to do this + if (auto attrs = new_attrs.as()) { + modify_output_dtype(attrs, accumulation_dtype); + } else if (auto attrs = new_attrs.as()) { + modify_output_dtype(attrs, accumulation_dtype); + } else if (auto attrs = new_attrs.as()) { + modify_output_dtype(attrs, accumulation_dtype); + } else if (auto attrs = new_attrs.as()) { + modify_output_dtype(attrs, accumulation_dtype); + } else if (auto attrs = new_attrs.as()) { + modify_output_dtype(attrs, accumulation_dtype); + } else if (auto attrs = new_attrs.as()) { + modify_output_dtype(attrs, accumulation_dtype); + } else if (auto attrs = new_attrs.as()) { + modify_output_dtype(attrs, accumulation_dtype); + } else if (auto attrs = new_attrs.as()) { + modify_output_dtype(attrs, accumulation_dtype); + } else if (auto attrs = new_attrs.as()) { + modify_output_dtype(attrs, accumulation_dtype); + } else if (auto attrs = new_attrs.as()) { + modify_output_dtype(attrs, accumulation_dtype); + } else if (auto attrs = new_attrs.as()) { + modify_output_dtype(attrs, accumulation_dtype); + } else if (auto attrs = new_attrs.as()) { + modify_output_dtype(attrs, accumulation_dtype); + } + } + + return new_attrs; + } + + template + void modify_output_dtype(const T* attrs, DataType accumulation_dtype) { + // Helper template to modify relevant attributes with out_dtype type. + // TODO: think about a better way to do this + T* mutable_attrs = const_cast(attrs); + mutable_attrs->out_dtype = accumulation_dtype; + } + + public: + RewriteBasedOnColors(CallColorMap color_map, + OutputDtypeFunc output_func = DefaultFP16OpDefinition()) + : color_map(color_map), output_func(output_func) {} + Expr VisitExpr_(const LetNode* op) final { + throw std::invalid_argument("Let nodes not supported for FP16 for now."); + } + + Expr VisitExpr_(const CallNode* call) final { + auto result = color_map.find(call); + if (result == color_map.end()) throw std::invalid_argument("Found element not in color map!"); + FP16ConversionCategory color = result->second; + + if (color == GRAY) { + throw std::invalid_argument( + "Had gray colored node during rewrite! Make sure other passes color all nodes!"); + } + + Expr new_op = Mutate(call->op); + FP16OpDType output_dtypes = output_func(call); + + // Create new node, ensure inputs are all fp32 if red, fp16 if green. + // For sttrs we may overwrite the accumulation dtype field "output_dtype" + // TODO: extend to bfloat types + DataType arg_cast_dtype = color == GREEN ? DataType::Float(16) : DataType::Float(32); + + Array new_args = get_new_args(call, arg_cast_dtype); + Attrs new_attrs = get_new_attrs(call, output_dtypes.accumulation_dtype); + Expr output = Call(new_op, new_args, new_attrs, call->type_args, call->span); + color_map[output.as()] = color_map[call]; + + if (output_dtypes.accumulation_dtype != output_dtypes.output_dtype) { + output = Cast(output, output_dtypes.output_dtype); + color_map[output.as()] = color_map[call]; + } + + return output; + }; +}; + +class ColorPrinter : public ExprVisitor { private: CallColorMap color_map; public: explicit ColorPrinter(CallColorMap& color_map) : color_map(color_map) {} explicit ColorPrinter() {} - void VisitExpr(const Expr& expr) { ExprVisitor::VisitExpr(expr); } void VisitExpr_(const CallNode* l) final { ExprVisitor::VisitExpr_(l); @@ -45,13 +225,60 @@ class ColorPrinter : private ExprVisitor { } }; -void PrintColors(const Expr& expr) { +Expr RewriteFp16Graph(const Expr& expr, bool debug) { + // Do an initial coloring based on each operation GraphColorer initial_colorer = GraphColorer(); initial_colorer.VisitExpr(expr); - CallColorMap color_map = initial_colorer.result(); - ColorPrinter(color_map).VisitExpr(expr); + CallColorMap color_map_initial = initial_colorer.result(); + + if (debug) { + std::cout << "Initial color map:" << std::endl; + ColorPrinter(color_map_initial).VisitExpr(expr); + std::cout << std::endl; + } + + // Propagate colors so gray nodes in adjacent green regions are green + // and those in red regions are red. + PropagateColors propagate_colorer = PropagateColors(color_map_initial); + propagate_colorer.VisitExpr(expr); + CallColorMap color_map_final = propagate_colorer.result(); + + if (debug) { + std::cout << "Propagate color map:" << std::endl; + ColorPrinter(color_map_final).VisitExpr(expr); + } + + // Replace all green nodes with fp16 versions of the ops, inserting casts along way. + RewriteBasedOnColors rewriter = RewriteBasedOnColors(color_map_final); + + // TODO: think about removing extraneous casts which can sometimes be added + // (Usually interactions with non-Call nodes like Tuples) + + // Insert an extraneous cast to FP32 to match old module output + Expr result = rewriter.Mutate(expr); + + // Insert an extra FP32 cast to match the old FP32 output type. + // TODO: look into how to change the type annotation + if (const FunctionNode* func = result.as()) { + const_cast(func)->body = Cast(func->body, DataType::Float(32)); + } + + return result; } -TVM_REGISTER_GLOBAL("relay._transform.PrintColorsExpr").set_body_typed(PrintColors); + +namespace transform { + +Pass RewriteFP16(bool debug) { + runtime::TypedPackedFunc pass_func = + [=](Function f, IRModule m, PassContext pc) { + return Downcast(RewriteFp16Graph(f, debug)); + }; + return CreateFunctionPass(pass_func, 10, "RewriteFp16", {}); +} + +TVM_REGISTER_GLOBAL("relay._transform.RewriteFP16").set_body_typed(RewriteFP16); + +} // namespace transform } // namespace relay } // namespace tvm diff --git a/src/relay/transforms/fp32_to_fp16.h b/src/relay/transforms/fp32_to_fp16.h index 18198fec5c7c..3f0d6432cf89 100644 --- a/src/relay/transforms/fp32_to_fp16.h +++ b/src/relay/transforms/fp32_to_fp16.h @@ -9,10 +9,14 @@ namespace tvm { namespace relay { +struct FP16OpDType { + DataType accumulation_dtype; + DataType output_dtype; +}; + enum FP16ConversionCategory { RED, GRAY, GREEN }; -std::unordered_map conversion_category_strings({{RED, "Red"}, - {GRAY, "Gray"}, - {GREEN, "Green"}}); +std::unordered_map conversion_category_strings( + {{RED, "Red"}, {GRAY, "Gray"}, {GREEN, "Green"}}); using OpStringSet = std::unordered_set; @@ -72,13 +76,14 @@ OpStringSet DEFAULT_RED_LIST({ "nn.l2_normalize", }); -class DefaultColorer { +class DefaultFP16Colorer { private: std::unordered_map op_to_initial_color; public: - DefaultColorer(OpStringSet red_list = DEFAULT_RED_LIST, OpStringSet gray_list = DEFAULT_GRAY_LIST, - OpStringSet green_list = DEFAULT_GREEN_LIST) { + DefaultFP16Colorer(OpStringSet red_list = DEFAULT_RED_LIST, + OpStringSet gray_list = DEFAULT_GRAY_LIST, + OpStringSet green_list = DEFAULT_GREEN_LIST) { std::vector> lists_and_colors{ {red_list, RED}, {gray_list, GRAY}, {green_list, GREEN}}; @@ -91,7 +96,7 @@ class DefaultColorer { } } - FP16ConversionCategory operator()(const tvm::relay::CallNode* call, bool ignore_missing = false) { + FP16ConversionCategory operator()(const CallNode* call, bool ignore_missing = false) { auto* op_node = (call->op).as(); if (op_node == nullptr) { throw std::invalid_argument("FP16 conversion only supports call nodes with op calls."); @@ -112,5 +117,18 @@ class DefaultColorer { } }; +class DefaultFP16OpDefinition { + public: + FP16OpDType operator()(const CallNode* call) { + if (call->attrs != NullValue()) { + Array fields = call->attrs->ListFieldInfo(); + for (AttrFieldInfo field_info : fields) { + if (field_info->name == "out_dtype") return {DataType::Float(32), DataType::Float(16)}; + } + } + return {DataType::Float(16), DataType::Float(16)}; + } +}; + } // namespace relay } // namespace tvm \ No newline at end of file diff --git a/tests/python/relay/test_fp32_to_fp16_transform.py b/tests/python/relay/test_fp32_to_fp16_transform.py index b3f91eceba30..0c5150874af8 100644 --- a/tests/python/relay/test_fp32_to_fp16_transform.py +++ b/tests/python/relay/test_fp32_to_fp16_transform.py @@ -2,31 +2,26 @@ import numpy as np import tvm -from numpy.lib.type_check import imag from tvm import relay -from tvm.relay.expr_functor import ExprVisitor from tvm.relay.testing import densenet, mobilenet, resnet, resnet_3d, squeezenet -from tvm.relay.transform import InferType -from tvm.relay.transform.fp16_conversion import fp32_to_fp16 +from tvm.relay.transform import RewriteFP16 def run_module(mod, mod_params): dev = tvm.device("llvm", 0) intrp = relay.create_executor("debug", mod, device=dev, target="llvm") - # in_data = [tvm.nd.array(value) for value in in_data.values()] return intrp.evaluate()(**mod_params).asnumpy() def verify_fp32_fp16_output_close(mod, mod_params): + fp16_mod = RewriteFP16()(mod) + result_fp16 = run_module(fp16_mod, mod_params) result_fp32 = run_module(mod, mod_params) - fp16 = fp32_to_fp16.quantize_to_fp16(mod["main"].body) - fp16_mod = tvm.ir.IRModule.from_expr(fp16) - result_fp16 = run_module(fp16_mod, mod_params) - # Ensure the results are close np.testing.assert_allclose(result_fp32, result_fp16, rtol=1e-3) + def test_resnet18(): np.random.seed(4321) mod, mod_params = resnet.get_workload(1, 5, num_layers=18, image_shape=(1, 32, 32)) @@ -67,4 +62,5 @@ def test_squeezenet(): verify_fp32_fp16_output_close(mod, mod_params) -#def test_ + +# def test_ From 4903a31a50fc84501c8b1991d95887824f1a71c9 Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Tue, 25 May 2021 16:29:34 -0400 Subject: [PATCH 04/59] Extend support to things besides CallNodes. E.g. tuples and lets fp32 invalidate typing instead of cast adding basic tests skeleton code out Stash work -- casting based on checked types working let statements add more ops, handle functions more generally add multiply, fix broken case support TupleNodes properly, move hash function for datatypes into data_type.h" update simple let test with structural expectation cleanup p1 remove old file --- include/tvm/runtime/data_type.h | 15 ++ src/relay/transforms/fp32_to_fp16.cc | 173 ++++++++++++---- src/relay/transforms/fp32_to_fp16.h | 58 +++--- .../relay/test_fp32_to_fp16_transform.py | 188 +++++++++++++++++- 4 files changed, 366 insertions(+), 68 deletions(-) diff --git a/include/tvm/runtime/data_type.h b/include/tvm/runtime/data_type.h index b4fdcbff58b4..3b767547357b 100644 --- a/include/tvm/runtime/data_type.h +++ b/include/tvm/runtime/data_type.h @@ -389,4 +389,19 @@ inline DLDataType String2DLDataType(std::string s) { using DataType = runtime::DataType; } // namespace tvm + +namespace std { +template <> +struct hash { + inline int cantor_pairing_function(int a, int b) const { return (a + b) * (a + b + 1) / 2 + b; } + std::size_t operator()(tvm::DataType const& dtype) const { + int a = dtype.code(); + int b = dtype.bits(); + int c = dtype.lanes(); + int d = cantor_pairing_function(a, b); + return cantor_pairing_function(c, d); + } +}; +} // namespace std + #endif // TVM_RUNTIME_DATA_TYPE_H_ diff --git a/src/relay/transforms/fp32_to_fp16.cc b/src/relay/transforms/fp32_to_fp16.cc index b2eee8eba972..90c5cf2e2351 100644 --- a/src/relay/transforms/fp32_to_fp16.cc +++ b/src/relay/transforms/fp32_to_fp16.cc @@ -11,7 +11,19 @@ namespace tvm { namespace relay { +struct pair_hash { + template + std::size_t operator()(const std::pair& pair) const { + auto h1 = std::hash()(pair.first); + auto h2 = std::hash()(pair.second); + + return h1 ^ h2; + } +}; + +// A map of call nodes to their fp16 conversion type using CallColorMap = std::unordered_map; +using CachedCastNodes = std::unordered_map, Expr, pair_hash>; using ColorFunc = std::function; using OutputDtypeFunc = std::function; @@ -37,17 +49,29 @@ class PropagateColors : public ExprVisitor { CallColorMap color_map; OutputDtypeFunc func; - /* TODO: replace *Nodes with managed References e.g. CallNode* -> Call*/ void VisitExpr_(const CallNode* l) final { ExprVisitor::VisitExpr_(l); auto result = color_map.find(l); if (result == color_map.end()) { - throw std::invalid_argument("Unknown node not in initial color map!"); + LOG(FATAL) << "Unknown node not in initial color map!"; } FP16ConversionCategory color = result->second; - if (color != GRAY) return; + + // Red and Green colored nodes are final, we only care about gray nodes + if (color != GRAY) { + ExprVisitor::VisitExpr_(l); + return; + }; + + // Make sure to visit everything, paying attention to args + this->VisitSpan(l->span); + this->VisitExpr(l->op); + for (auto ty_arg : l->type_args) { + this->VisitType(ty_arg); + } for (Expr arg : l->args) { + this->VisitExpr(arg); if (!is_fp16_compatible_arg(arg)) { color_map[l] = RED; return; @@ -58,12 +82,14 @@ class PropagateColors : public ExprVisitor { } bool is_fp16_compatible_arg(Expr arg) { + // TODO: examine for correctness -- think about conditions and rethink this method + // Whether this is an argument which can/should be casted to fp16 if (arg->IsInstance() || arg->IsInstance()) { return true; } else if (const CallNode* call = arg.as()) { auto result = color_map.find(call); if (result == color_map.end()) { - throw std::invalid_argument("Unknown node not in initial color map!"); + LOG(FATAL) << "Unknown node not in initial color map!"; } FP16ConversionCategory color = result->second; return color == GREEN && func(call).output_dtype == DataType::Float(16); @@ -77,7 +103,7 @@ class PropagateColors : public ExprVisitor { } return true; } else { - throw std::invalid_argument("Unknown node type " + arg->GetTypeKey()); + LOG(FATAL) << "Unknown node not in initial color map!"; } return true; @@ -93,37 +119,72 @@ class RewriteBasedOnColors : public ExprMutator { private: CallColorMap color_map; OutputDtypeFunc output_func; + CachedCastNodes cached_cast_nodes; + + Expr GetTypedExpr(const Expr& expr) { + // Returns typed version of the expression whose type can be gotten using checked_type(). + auto mod = IRModule::FromExpr(expr); + mod = transform::InferType()(mod); + if (expr.as()) { + return mod->Lookup("main"); + } else { + return mod->Lookup("main").as()->body; + } + } + + Expr cached_cast(Expr expr, DataType expr_dtype, DataType wanted_dtype) { + // Cast tensor to the wanted datatype, returning a cached version if it's already been done. + + // If this is not a floating point type, do not cast. E.g. it might be an integer + if (!expr_dtype.is_float()) { + return expr; + } + + const ExprNode* expr_node = expr.as(); + if (!expr_node) { + LOG(FATAL) << "None expression node found in cast: " << expr; + } + + auto search = cached_cast_nodes.find({expr_node, wanted_dtype}); + if (search != cached_cast_nodes.end()) { + // Use cached result + return search->second; + } + + Expr result = expr_dtype == wanted_dtype ? expr : Cast(expr, wanted_dtype); + cached_cast_nodes[{expr_node, wanted_dtype}] = result; + + // Reverse the cache result too, e.g. if we want to reverse the cast point to original node + const ExprNode* new_expr_node = result.as(); + cached_cast_nodes[{new_expr_node, expr_dtype}] = expr; + return result; + } + + Expr arg_cast_helper(Expr expr, Type t, DataType wanted_dtype) { + // Helper for casting arguments to call_nodes handling all relevant cases. + if (const TensorTypeNode* tensor_type = t.as()) { + return cached_cast(expr, tensor_type->dtype, wanted_dtype); + } else if (const TupleTypeNode* tuple_type = t.as()) { + Array new_expr; + for (int i = 0; i < (tuple_type->fields).size(); i++) { + Expr tuple_expr_element = GetField(expr, i); + Type tuple_expr_element_dtype = (tuple_type->fields)[i]; + new_expr.push_back( + arg_cast_helper(tuple_expr_element, tuple_expr_element_dtype, wanted_dtype)); + } + return Tuple(new_expr); + } else { + LOG(FATAL) << "Unsupported type " << t << " we don't know how to cast for arguments!"; + return expr; + } + } Array get_new_args(const CallNode* call, DataType arg_cast_datatype) { Array ret; for (Expr arg : call->args) { arg = VisitExpr(arg); - Expr new_arg; - if (arg->IsInstance() || arg->IsInstance()) { - // Assume every var and const node is by default fp32, so cast if we are not casting to that - new_arg = arg_cast_datatype != DataType::Float(32) ? Cast(arg, arg_cast_datatype) : arg; - } else if (const CallNode* arg_call = arg.as()) { - auto entry = color_map.find(arg_call); - if (entry == color_map.end()) { - throw std::invalid_argument("Found element not in color map!"); - } - FP16ConversionCategory color = entry->second; - - // Cast result of a call, if we are going to rewrite it - if (color == GREEN) { - new_arg = output_func(arg_call).output_dtype != arg_cast_datatype - ? Cast(arg, arg_cast_datatype) - : arg; - } else { - // Was RED, assume fp32 output so cast to type - new_arg = arg_cast_datatype != DataType::Float(32) ? Cast(arg, arg_cast_datatype) : arg; - } - } else { - // Else assume it's a composite type composed of cast elements - new_arg = arg; - } - - ret.push_back(new_arg); + Type arg_type = GetTypedExpr(arg)->checked_type(); + ret.push_back(arg_cast_helper(arg, arg_type, arg_cast_datatype)); } return ret; @@ -158,6 +219,10 @@ class RewriteBasedOnColors : public ExprMutator { } else if (auto attrs = new_attrs.as()) { modify_output_dtype(attrs, accumulation_dtype); } + + if (auto attrs = new_attrs.as()) { + modify_dtype(attrs, accumulation_dtype); + } } return new_attrs; @@ -166,27 +231,51 @@ class RewriteBasedOnColors : public ExprMutator { template void modify_output_dtype(const T* attrs, DataType accumulation_dtype) { // Helper template to modify relevant attributes with out_dtype type. + // These represent accumulation dtypes for some operations e.g. + // conv2d might take in fp16 and give a fp32 result. // TODO: think about a better way to do this T* mutable_attrs = const_cast(attrs); mutable_attrs->out_dtype = accumulation_dtype; } + template + void modify_dtype(const T* attrs, DataType accumulation_dtype) { + // Helper template to modify relevant attributes with dtype type. + // This determines the output dtype for some ops. For example + // zeros creates a tensor of zeros of the specified dtype. + // TODO: think about a better way to do this + T* mutable_attrs = const_cast(attrs); + mutable_attrs->dtype = accumulation_dtype; + } + public: RewriteBasedOnColors(CallColorMap color_map, OutputDtypeFunc output_func = DefaultFP16OpDefinition()) : color_map(color_map), output_func(output_func) {} + Expr VisitExpr_(const LetNode* op) final { - throw std::invalid_argument("Let nodes not supported for FP16 for now."); + // First convert as much as the bound computation to FP16 as possible + Expr value = this->Mutate(op->value); + + // Then rewrite the var type and associated expression + Var var = Downcast(this->Mutate(op->var)); + VarNode* mutable_var = const_cast((op->var).as()); + mutable_var->type_annotation = GetTypedExpr(value)->checked_type(); + mutable_var->checked_type_ = mutable_var->type_annotation; + + // Mutate body last as it may depend on previous results + Expr body = this->Mutate(op->body); + + return Let(var, value, body, op->span); } Expr VisitExpr_(const CallNode* call) final { auto result = color_map.find(call); - if (result == color_map.end()) throw std::invalid_argument("Found element not in color map!"); + if (result == color_map.end()) LOG(FATAL) << "Unknown node not in initial color map!"; FP16ConversionCategory color = result->second; if (color == GRAY) { - throw std::invalid_argument( - "Had gray colored node during rewrite! Make sure other passes color all nodes!"); + LOG(FATAL) << "Had gray colored node during rewrite! Make sure other passes color all nodes!"; } Expr new_op = Mutate(call->op); @@ -200,15 +289,22 @@ class RewriteBasedOnColors : public ExprMutator { Array new_args = get_new_args(call, arg_cast_dtype); Attrs new_attrs = get_new_attrs(call, output_dtypes.accumulation_dtype); Expr output = Call(new_op, new_args, new_attrs, call->type_args, call->span); - color_map[output.as()] = color_map[call]; + color_map[output.as()] = color_map[call]; if (output_dtypes.accumulation_dtype != output_dtypes.output_dtype) { output = Cast(output, output_dtypes.output_dtype); color_map[output.as()] = color_map[call]; } + // TODO: remember to visit everything! Def missing arg dtypes rn return output; }; + + Expr VisitExpr_(const FunctionNode* func) final { + // Erase the ret_type annotation and let the pass recalculate + const_cast(func)->ret_type = Type(nullptr); + return ExprMutator::VisitExpr_(func); + } }; class ColorPrinter : public ExprVisitor { @@ -257,10 +353,9 @@ Expr RewriteFp16Graph(const Expr& expr, bool debug) { // Insert an extraneous cast to FP32 to match old module output Expr result = rewriter.Mutate(expr); - // Insert an extra FP32 cast to match the old FP32 output type. - // TODO: look into how to change the type annotation + // Old type annotations may no longer be accurate so rewrite if (const FunctionNode* func = result.as()) { - const_cast(func)->body = Cast(func->body, DataType::Float(32)); + const_cast(func)->ret_type = Type(nullptr); } return result; diff --git a/src/relay/transforms/fp32_to_fp16.h b/src/relay/transforms/fp32_to_fp16.h index 3f0d6432cf89..3511c944f574 100644 --- a/src/relay/transforms/fp32_to_fp16.h +++ b/src/relay/transforms/fp32_to_fp16.h @@ -1,5 +1,6 @@ #include #include +#include #include #include @@ -22,8 +23,7 @@ using OpStringSet = std::unordered_set; // Default lists inspired from TF's classifications: // github.com/tensorflow/tensorflow/blob/v2.5.0/tensorflow/core/grappler/optimizers/auto_mixed_precision_lists.h -// They might have a bias toward NVidia's Tensor Cores so be aware and modify lists per your -// hardware choice. +// They have a bias toward Nvidia Tensor Cores so modify lists per your hardware choice. OpStringSet DEFAULT_GREEN_LIST({ "nn.conv1d", "nn.conv2d", @@ -38,8 +38,13 @@ OpStringSet DEFAULT_GRAY_LIST({ "nn.pad", "nn.batch_flatten", "concatenate", + "zeros", + "split", // Simple arithmetic "add", + "subtract", + "multiply", + "divide", "nn.bias_add", "nn.batch_norm", // Simple activations @@ -47,6 +52,9 @@ OpStringSet DEFAULT_GRAY_LIST({ "nn.leaky_relu", "nn.prelu", "nn.dropout", + // Complicated activations which saturate in a narrow range + "sigmoid", + "tanh", // Pooling operations "nn.max_pool1d", "nn.max_pool2d", @@ -54,12 +62,12 @@ OpStringSet DEFAULT_GRAY_LIST({ "nn.avg_pool1d", "nn.avg_pool2d", "nn.avg_pool3d", - // "nn.global_max_pool1d", // does not exist + // "nn.global_max_pool1d", // does not exist yet "nn.global_max_pool2d", - // "nn.global_max_pool3d", // does not exist - // "nn.global_avg_pool1d", // does not exist + // "nn.global_max_pool3d", // does not exist yet + // "nn.global_avg_pool1d", // does not exist yet "nn.global_avg_pool2d", - // "nn.global_avg_pool3d", // does not exist + // "nn.global_avg_pool3d", // does not exist yet "nn.adaptive_max_pool1d", "nn.adaptive_max_pool2d", "nn.adaptive_max_pool3d", @@ -68,11 +76,11 @@ OpStringSet DEFAULT_GRAY_LIST({ "nn.adaptive_avg_pool3d", }); OpStringSet DEFAULT_RED_LIST({ - // Activations with exponents or division + // In general if |f(x)| >> |x| for some expected inputs to the op then put it here. + // Activations with exponents or dividing by small numbers "nn.cross_entropy", "nn.cross_entropy_with_logits", "nn.softmax", - // Other "nn.l2_normalize", }); @@ -96,24 +104,28 @@ class DefaultFP16Colorer { } } - FP16ConversionCategory operator()(const CallNode* call, bool ignore_missing = false) { - auto* op_node = (call->op).as(); - if (op_node == nullptr) { - throw std::invalid_argument("FP16 conversion only supports call nodes with op calls."); - } - - std::string op_name = op_node->name; - auto color = op_to_initial_color.find(op_name); + FP16ConversionCategory operator()(const CallNode* call, bool ignore_missing = true) { + if (auto* op_node = (call->op).as()) { + std::string op_name = op_node->name; + auto color = op_to_initial_color.find(op_name); - if (color == op_to_initial_color.end()) { - if (ignore_missing) { - return RED; - } else { - throw std::invalid_argument("Op name " + op_name + " not in included lists!."); + if (color == op_to_initial_color.end()) { + if (ignore_missing) { + LOG(WARNING) << "Op name " + op_name + " not in included in fp16 conversion lists!."; + return RED; + } else { + LOG(FATAL) << "Op name " + op_name + " not in included in fp16 lists!."; + } } - } - return color->second; + return color->second; + } else if (auto* func_node = (call->op).as()) { + // Make RED to avoid messing with function types which are complicated, fold in other pass + return RED; + } else { + LOG(FATAL) << "FP16 conversion only supports call nodes with op calls got " << call->op; + return RED; + } } }; diff --git a/tests/python/relay/test_fp32_to_fp16_transform.py b/tests/python/relay/test_fp32_to_fp16_transform.py index 0c5150874af8..742d9380dae2 100644 --- a/tests/python/relay/test_fp32_to_fp16_transform.py +++ b/tests/python/relay/test_fp32_to_fp16_transform.py @@ -1,10 +1,13 @@ +from collections import defaultdict from typing import * import numpy as np import tvm from tvm import relay -from tvm.relay.testing import densenet, mobilenet, resnet, resnet_3d, squeezenet +from tvm.relay.op.tensor import exp +from tvm.relay.testing import densenet, lstm, mobilenet, resnet, resnet_3d, squeezenet from tvm.relay.transform import RewriteFP16 +from tvm.relay.transform.transform import AnnotateSpans, InferType def run_module(mod, mod_params): @@ -13,13 +16,17 @@ def run_module(mod, mod_params): return intrp.evaluate()(**mod_params).asnumpy() -def verify_fp32_fp16_output_close(mod, mod_params): +def verify_fp32_fp16_output_close(mod, mod_params, rtol=1e-3, atol=0): + mod = InferType()(mod) + mod = AnnotateSpans()(mod) + result_fp32 = run_module(mod, mod_params) fp16_mod = RewriteFP16()(mod) result_fp16 = run_module(fp16_mod, mod_params) - result_fp32 = run_module(mod, mod_params) # Ensure the results are close - np.testing.assert_allclose(result_fp32, result_fp16, rtol=1e-3) + np.testing.assert_allclose(result_fp32, result_fp16, rtol=rtol, atol=atol) + + return fp16_mod def test_resnet18(): @@ -59,8 +66,177 @@ def test_squeezenet(): np.random.seed(5628) mod, mod_params = squeezenet.get_workload(1, 5, image_shape=(1, 32, 32)) mod_params["data"] = np.random.uniform(-10, 10, (1, 1, 32, 32)).astype("float32") - verify_fp32_fp16_output_close(mod, mod_params) -# def test_ +def test_lstm(): + np.random.seed(5628) + mod, mod_params = lstm.get_workload(5, 3) + + # This is an unrolled lstm so each data should be the previous results but whatever. + # We jsut want to use this to test more complicated let statements + nested funcs + mod_params["data"] = np.random.uniform(-10, 10, (1, 3)).astype("float32") + mod_params["data1"] = np.random.uniform(-10, 10, (1, 3)).astype("float32") + mod_params["data2"] = np.random.uniform(-10, 10, (1, 3)).astype("float32") + mod_params["data3"] = np.random.uniform(-10, 10, (1, 3)).astype("float32") + mod_params["data4"] = np.random.uniform(-10, 10, (1, 3)).astype("float32") + + verify_fp32_fp16_output_close(mod, mod_params, rtol=0.01, atol=0.01) + + +def test_convert_single_conv(): + """Conv is a green listed operation meaning it will always use fp16 workload. + + By default it accumulates to fp32 and outputs fp16. + """ + np.random.seed(208) + + data_shape = (1, 3, 32, 32) + weight_shape = (5, 3, 3, 3) + data = relay.var("data", shape=data_shape, dtype="float32") + weight = relay.var("weight", shape=weight_shape, dtype="float32") + conv = relay.nn.conv2d(data, weight, strides=(1, 1), padding=(1, 1), out_dtype="float32") + mod = tvm.IRModule.from_expr(conv) + mod = tvm.relay.transform.InferType()(mod) + + mod_params = { + "data": np.random.uniform(-1, 1, size=data_shape).astype("float32"), + "weight": np.random.uniform(-1, 1, size=weight_shape).astype("float32"), + } + fp16_mod = verify_fp32_fp16_output_close(mod, mod_params, atol=0.01, rtol=1e-3) + + expected_mod = tvm.IRModule.from_expr( + relay.cast( + relay.nn.conv2d( + relay.cast(data, "float16"), + relay.cast(weight, "float16"), + strides=(1, 1), + padding=(1, 1), + out_dtype="float32", + ), + "float16", + ) + ) + expected_mod = tvm.relay.transform.InferType()(expected_mod) + + assert not tvm.ir.structural_equal(fp16_mod, mod) + assert tvm.ir.structural_equal(fp16_mod, expected_mod) + + +def test_do_not_convert_softmax(): + """Softmax is a red listed operation and therefore should never be fp16.""" + np.random.seed(209) + shape = [1, 2, 3] + a = relay.var("a", shape=shape) + b = relay.nn.softmax(a) + mod = tvm.IRModule.from_expr(b) + mod = tvm.relay.transform.InferType()(mod) + + mod_params = { + "a": np.random.uniform(-1, 1, size=shape).astype("float32"), + } + output_mod = verify_fp32_fp16_output_close(mod, mod_params, atol=0.0, rtol=0) + assert tvm.ir.structural_equal(mod, output_mod) + + +def test_green_gray_propagates_simple(): + """Conv is a green listed operation, while addition is gray. + + When adjacent + """ + np.random.seed(210) + data_shape = (1, 3, 32, 32) + weight_shape = (5, 3, 3, 3) + data = relay.var("data", shape=data_shape, dtype="float32") + weight = relay.var("weight", shape=weight_shape, dtype="float32") + conv = relay.nn.conv2d(data, weight, strides=(1, 1), padding=(1, 1), out_dtype="float32") + conv = conv + conv + mod = tvm.IRModule.from_expr(conv) + mod = tvm.relay.transform.InferType()(mod) + + mod_params = { + "data": np.random.uniform(-1, 1, size=data_shape).astype("float32"), + "weight": np.random.uniform(-1, 1, size=weight_shape).astype("float32"), + } + fp16_mod = verify_fp32_fp16_output_close(mod, mod_params, atol=0.01, rtol=1e-3) + + conv_expr = relay.cast( + relay.nn.conv2d( + relay.cast(data, "float16"), + relay.cast(weight, "float16"), + strides=(1, 1), + padding=(1, 1), + out_dtype="float32", + ), + "float16", + ) + expected_mod = tvm.IRModule.from_expr(conv_expr + conv_expr) + expected_mod = tvm.relay.transform.InferType()(expected_mod) + + assert not tvm.ir.structural_equal(fp16_mod, mod) + assert tvm.ir.structural_equal(fp16_mod, expected_mod) + + +def test_red_gray_propagates_simple(): + """Conv is a green listed operation, while addition is gray. + + When adjacent + """ + np.random.seed(211) + shape = [1, 2, 3] + a = relay.var("a", shape=shape) + b = relay.nn.softmax(a) + c = b + b + mod = tvm.IRModule.from_expr(c) + mod = tvm.relay.transform.InferType()(mod) + + mod_params = { + "a": np.random.uniform(-1, 1, size=shape).astype("float32"), + } + output_mod = verify_fp32_fp16_output_close(mod, mod_params, atol=0.0, rtol=0.0) + + assert tvm.ir.structural_equal(mod, output_mod) + + +def test_let_statement_simple(): + np.random.seed(211) + var1 = relay.var("var1", shape=[1, 20]) + var2 = relay.var("var2", shape=[1, 20]) + + data = relay.var("data", shape=[1, 20]) + weight = relay.var("weight", shape=[20, 20]) + + r1 = var1 + var1 + + r2 = var2 + var2 + let2 = relay.Let(var2, relay.nn.dense(r1, weight, units=20), r2) + let1 = relay.Let(var1, relay.nn.dense(data, weight, units=20), let2) + + mod = tvm.IRModule.from_expr(let1) + mod_params = { + "data": np.random.uniform(-1, 1, size=[1, 20]).astype("float32"), + "weight": np.random.uniform(-1, 1, size=[20, 20]).astype("float32"), + } + output_mod = verify_fp32_fp16_output_close(mod, mod_params, atol=0.01, rtol=0.01) + + # Construct expected structure + var1 = relay.var("var1", shape=[1, 20], dtype="float16") + var2 = relay.var("var2", shape=[1, 20], dtype="float16") + data = relay.cast(relay.var("data", shape=[1, 20]), "float16") + weight = relay.cast(relay.var("weight", shape=[20, 20]), "float16") + r1 = var1 + var1 + r2 = var2 + var2 + let2 = relay.Let( + var2, + relay.cast(relay.nn.dense(r1, weight, units=20, out_dtype="float32"), "float16"), + r2, + ) + let1 = relay.Let( + var1, + relay.cast(relay.nn.dense(data, weight, units=20, out_dtype="float32"), "float16"), + let2, + ) + expected_mod = tvm.IRModule.from_expr(let1) + expected_mod = InferType()(expected_mod) + + assert tvm.ir.structural_equal(expected_mod, output_mod) From 41ac568d2c2c71f3135bd2fb09fe0b8c23c66fae Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Thu, 3 Jun 2021 16:23:11 -0700 Subject: [PATCH 05/59] Rewrite how and when casting is done by checking types directly. add support for GPT2, BERT add some more comments new single pass version formatting make a lot of things const references clean up tests more cleanup more comments final comment add newline --- python/tvm/relay/transform/transform.py | 16 +- src/relay/transforms/fp32_to_fp16.cc | 460 ++++++++---------- src/relay/transforms/fp32_to_fp16.h | 80 ++- .../relay/test_fp32_to_fp16_transform.py | 200 +++++--- 4 files changed, 414 insertions(+), 342 deletions(-) diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index 40fb383d55ac..7dffd27ad499 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -1167,7 +1167,7 @@ def AnnotateSpans(): Returns ------- ret : tvm.transform.Pass - The regsistered AnnotateSpans pass. + The registered AnnotateSpans pass. """ return _ffi_api.AnnotateSpans() @@ -1200,8 +1200,16 @@ def FakeQuantizationToInteger(): return _ffi_api.FakeQuantizationToInteger() -def RewriteFP16(debug=False): +def RewriteFP16(): """ - Cool stuff. TODO + Rewrite an FP32 relay graph into an FP16 version. Note this does mutate + the original graph putting it in a bad state potentially. + + TODO: don't mutate the original graph. + + Returns + ------- + ret : tvm.transform.Pass + The registered RewriteFP16 pass. """ - return _ffi_api.RewriteFP16(debug) + return _ffi_api.RewriteFP16() diff --git a/src/relay/transforms/fp32_to_fp16.cc b/src/relay/transforms/fp32_to_fp16.cc index 90c5cf2e2351..20f23da54f12 100644 --- a/src/relay/transforms/fp32_to_fp16.cc +++ b/src/relay/transforms/fp32_to_fp16.cc @@ -11,129 +11,129 @@ namespace tvm { namespace relay { +// A callable which hashes std::pair struct pair_hash { template std::size_t operator()(const std::pair& pair) const { auto h1 = std::hash()(pair.first); auto h2 = std::hash()(pair.second); - return h1 ^ h2; + return h1 ^ (h2 << 1); } }; -// A map of call nodes to their fp16 conversion type -using CallColorMap = std::unordered_map; +// A map of a parent node and a wanted dtype to existing nodes casted to the wanted dtype using CachedCastNodes = std::unordered_map, Expr, pair_hash>; -using ColorFunc = std::function; -using OutputDtypeFunc = std::function; -class GraphColorer : public ExprVisitor { - private: - CallColorMap color_map; - ColorFunc func; - - void VisitExpr_(const CallNode* l) final { - // FP16ConversionCategory c = func(l); - color_map[l] = func(l); - ExprVisitor::VisitExpr_(l); - } - - public: - GraphColorer(ColorFunc func = DefaultFP16Colorer()) : func(func) {} +// A function which maps CallNodes to their initial conversion color +using ColorFunc = std::function; - CallColorMap result() { return color_map; } -}; +// A function which maps green CallNodes to wanted accumulation and output dtypes +using OutputDtypeFunc = std::function; -class PropagateColors : public ExprVisitor { +class FP16GraphCreator : public ExprMutator { private: - CallColorMap color_map; - OutputDtypeFunc func; - - void VisitExpr_(const CallNode* l) final { - ExprVisitor::VisitExpr_(l); - auto result = color_map.find(l); - if (result == color_map.end()) { - LOG(FATAL) << "Unknown node not in initial color map!"; - } - FP16ConversionCategory color = result->second; - - // Red and Green colored nodes are final, we only care about gray nodes - if (color != GRAY) { - ExprVisitor::VisitExpr_(l); - return; - }; - - // Make sure to visit everything, paying attention to args - this->VisitSpan(l->span); - this->VisitExpr(l->op); - for (auto ty_arg : l->type_args) { - this->VisitType(ty_arg); - } + CachedCastNodes cast_nodes_cache; + const ColorFunc colorer; + const OutputDtypeFunc output_dtype_func; - for (Expr arg : l->args) { - this->VisitExpr(arg); - if (!is_fp16_compatible_arg(arg)) { - color_map[l] = RED; - return; + Attrs GetNewAttrs(const CallNode* call, const DataType& accumulation_dtype) const { + /* If the accumulation dtype is in the attributes make a copy and mutate the field. */ + Attrs new_attrs = Attrs(call->attrs); + if (new_attrs.get() != nullptr) { + // TODO: Figure out a better way to do this + // modify output_dtype attributes (accumulation dtypes for ops) + if (auto attrs = new_attrs.as()) { + ModifyAttrsOutputDType(attrs, accumulation_dtype); + } else if (auto attrs = new_attrs.as()) { + ModifyAttrsOutputDType(attrs, accumulation_dtype); + } else if (auto attrs = new_attrs.as()) { + ModifyAttrsOutputDType(attrs, accumulation_dtype); + } else if (auto attrs = new_attrs.as()) { + ModifyAttrsOutputDType(attrs, accumulation_dtype); + } else if (auto attrs = new_attrs.as()) { + ModifyAttrsOutputDType(attrs, accumulation_dtype); + } else if (auto attrs = new_attrs.as()) { + ModifyAttrsOutputDType(attrs, accumulation_dtype); + } else if (auto attrs = new_attrs.as()) { + ModifyAttrsOutputDType(attrs, accumulation_dtype); + } else if (auto attrs = new_attrs.as()) { + ModifyAttrsOutputDType(attrs, accumulation_dtype); + } else if (auto attrs = new_attrs.as()) { + ModifyAttrsOutputDType(attrs, accumulation_dtype); + } else if (auto attrs = new_attrs.as()) { + ModifyAttrsOutputDType(attrs, accumulation_dtype); + } else if (auto attrs = new_attrs.as()) { + ModifyAttrsOutputDType(attrs, accumulation_dtype); + } else if (auto attrs = new_attrs.as()) { + ModifyAttrsOutputDType(attrs, accumulation_dtype); } - } - - color_map[l] = GREEN; - } - bool is_fp16_compatible_arg(Expr arg) { - // TODO: examine for correctness -- think about conditions and rethink this method - // Whether this is an argument which can/should be casted to fp16 - if (arg->IsInstance() || arg->IsInstance()) { - return true; - } else if (const CallNode* call = arg.as()) { - auto result = color_map.find(call); - if (result == color_map.end()) { - LOG(FATAL) << "Unknown node not in initial color map!"; - } - FP16ConversionCategory color = result->second; - return color == GREEN && func(call).output_dtype == DataType::Float(16); - } else if (const TupleGetItemNode* tuple_get_item = arg.as()) { - return is_fp16_compatible_arg(tuple_get_item->tuple); - } else if (const TupleNode* tuple = arg.as()) { - for (Expr exp : tuple->fields) { - if (!is_fp16_compatible_arg(exp)) { - return false; - } + // modify dtype attributes (creating new tensors of type dtype) + if (auto attrs = new_attrs.as()) { + ModifyAttrsDType(attrs, accumulation_dtype); } - return true; - } else { - LOG(FATAL) << "Unknown node not in initial color map!"; } - return true; + return new_attrs; } - public: - PropagateColors(CallColorMap initial_color_map, OutputDtypeFunc func = DefaultFP16OpDefinition()) - : color_map(initial_color_map), func(func) {} - CallColorMap result() { return color_map; } -}; + template + void ModifyAttrsOutputDType(const T* attrs, const DataType& accumulation_dtype) const { + /* + Helper template to modify relevant attributes with out_dtype type. + These represent accumulation dtypes for some operations e.g. + conv2d might take in fp16 and give a fp32 result. + Attrs is const because we get it as a const. + TODO: think about a better way to do this + */ + T* mutable_attrs = const_cast(attrs); + mutable_attrs->out_dtype = accumulation_dtype; + } -class RewriteBasedOnColors : public ExprMutator { - private: - CallColorMap color_map; - OutputDtypeFunc output_func; - CachedCastNodes cached_cast_nodes; + template + void ModifyAttrsDType(const T* attrs, const DataType& accumulation_dtype) const { + /* + Helper template to modify relevant attributes with dtype type. + This determines the output dtype for some ops. For example + zeros creates a tensor of zeros of the specified dtype. + Attrs is const because we get it as a const. + TODO: think about a better way to do this + */ + T* mutable_attrs = const_cast(attrs); + mutable_attrs->dtype = accumulation_dtype; + } - Expr GetTypedExpr(const Expr& expr) { - // Returns typed version of the expression whose type can be gotten using checked_type(). + Type GetType(const Expr& expr) const { auto mod = IRModule::FromExpr(expr); mod = transform::InferType()(mod); if (expr.as()) { - return mod->Lookup("main"); + return mod->Lookup("main")->checked_type(); } else { - return mod->Lookup("main").as()->body; + return mod->Lookup("main").as()->body->checked_type(); } } - Expr cached_cast(Expr expr, DataType expr_dtype, DataType wanted_dtype) { - // Cast tensor to the wanted datatype, returning a cached version if it's already been done. + bool IsFP16Type(const Type& t, bool ignore_non_float = false) const { + /* Returns whether t is a type with only fp16 elements. + If ignore_non_float, then ignore non-floating types. + */ + if (const TensorTypeNode* tensor_type = t.as()) { + return (!ignore_non_float || (tensor_type->dtype).is_float()) && + tensor_type->dtype == DataType::Float(16); + } else if (const TupleTypeNode* tuple_type = t.as()) { + for (Type t : tuple_type->fields) { + if (!IsFP16Type(t, ignore_non_float)) return false; + } + return true; + } else { + LOG(FATAL) << "Unsupported type " << t << " we don't know how to handle"; + return false; + } + } + + Expr CachedCast(const Expr& expr, const DataType& expr_dtype, const DataType& wanted_dtype) { + /* Cast tensor to the wanted datatype, returning a cached version if it's already been done. */ // If this is not a floating point type, do not cast. E.g. it might be an integer if (!expr_dtype.is_float()) { @@ -142,231 +142,159 @@ class RewriteBasedOnColors : public ExprMutator { const ExprNode* expr_node = expr.as(); if (!expr_node) { - LOG(FATAL) << "None expression node found in cast: " << expr; + LOG(FATAL) << "Non-expression node found in cast: " << expr; } - auto search = cached_cast_nodes.find({expr_node, wanted_dtype}); - if (search != cached_cast_nodes.end()) { - // Use cached result + // Use cached result if possible. + auto search = cast_nodes_cache.find({expr_node, wanted_dtype}); + if (search != cast_nodes_cache.end()) { return search->second; } Expr result = expr_dtype == wanted_dtype ? expr : Cast(expr, wanted_dtype); - cached_cast_nodes[{expr_node, wanted_dtype}] = result; + cast_nodes_cache[{expr_node, wanted_dtype}] = result; - // Reverse the cache result too, e.g. if we want to reverse the cast point to original node + // Reverse the cache result, e.g. if we want to reverse the cast simply point to original node const ExprNode* new_expr_node = result.as(); - cached_cast_nodes[{new_expr_node, expr_dtype}] = expr; + cast_nodes_cache[{new_expr_node, expr_dtype}] = expr; return result; } - Expr arg_cast_helper(Expr expr, Type t, DataType wanted_dtype) { - // Helper for casting arguments to call_nodes handling all relevant cases. - if (const TensorTypeNode* tensor_type = t.as()) { - return cached_cast(expr, tensor_type->dtype, wanted_dtype); - } else if (const TupleTypeNode* tuple_type = t.as()) { + Expr CastArg(const Expr& expr, const Type& expr_type, const DataType& wanted_dtype) { + /* Helper for casting arguments to call_nodes handling all relevant cases. */ + if (const TensorTypeNode* tensor_type = expr_type.as()) { + return CachedCast(expr, tensor_type->dtype, wanted_dtype); + } else if (const TupleTypeNode* tuple_type = expr_type.as()) { Array new_expr; + bool all_same = true; for (int i = 0; i < (tuple_type->fields).size(); i++) { - Expr tuple_expr_element = GetField(expr, i); - Type tuple_expr_element_dtype = (tuple_type->fields)[i]; - new_expr.push_back( - arg_cast_helper(tuple_expr_element, tuple_expr_element_dtype, wanted_dtype)); + Expr tuple_element = GetField(expr, i); + Type tuple_element_dtype = (tuple_type->fields)[i]; + Expr casted_element = CastArg(tuple_element, tuple_element_dtype, wanted_dtype); + new_expr.push_back(casted_element); + all_same &= casted_element.same_as(tuple_element); } - return Tuple(new_expr); + return all_same ? expr : Tuple(new_expr); } else { - LOG(FATAL) << "Unsupported type " << t << " we don't know how to cast for arguments!"; + LOG(FATAL) << "Unsupported type " << expr_type << " we don't know how to cast for arguments!"; return expr; } } - Array get_new_args(const CallNode* call, DataType arg_cast_datatype) { - Array ret; - for (Expr arg : call->args) { - arg = VisitExpr(arg); - Type arg_type = GetTypedExpr(arg)->checked_type(); - ret.push_back(arg_cast_helper(arg, arg_type, arg_cast_datatype)); + std::pair, Array> CastAllArgs(const Array& cur_args, + const Array& cur_arg_types, + const DataType& wanted_dtype) { + Array new_args; + Array new_arg_types; + for (size_t i = 0; i < cur_args.size(); i++) { + Expr cur_arg = cur_args[i]; + Type cur_arg_type = cur_arg_types[i]; + Expr new_arg = CastArg(cur_arg, cur_arg_type, wanted_dtype); + Type new_arg_type = GetType(new_arg); + new_args.push_back(new_arg); + new_arg_types.push_back(new_arg_type); } - - return ret; + return {new_args, new_arg_types}; } - Attrs get_new_attrs(const CallNode* call, DataType accumulation_dtype) { - Attrs new_attrs = Attrs(call->attrs); - if (new_attrs.get() != nullptr) { - // TODO: Figure out a better way to do this - if (auto attrs = new_attrs.as()) { - modify_output_dtype(attrs, accumulation_dtype); - } else if (auto attrs = new_attrs.as()) { - modify_output_dtype(attrs, accumulation_dtype); - } else if (auto attrs = new_attrs.as()) { - modify_output_dtype(attrs, accumulation_dtype); - } else if (auto attrs = new_attrs.as()) { - modify_output_dtype(attrs, accumulation_dtype); - } else if (auto attrs = new_attrs.as()) { - modify_output_dtype(attrs, accumulation_dtype); - } else if (auto attrs = new_attrs.as()) { - modify_output_dtype(attrs, accumulation_dtype); - } else if (auto attrs = new_attrs.as()) { - modify_output_dtype(attrs, accumulation_dtype); - } else if (auto attrs = new_attrs.as()) { - modify_output_dtype(attrs, accumulation_dtype); - } else if (auto attrs = new_attrs.as()) { - modify_output_dtype(attrs, accumulation_dtype); - } else if (auto attrs = new_attrs.as()) { - modify_output_dtype(attrs, accumulation_dtype); - } else if (auto attrs = new_attrs.as()) { - modify_output_dtype(attrs, accumulation_dtype); - } else if (auto attrs = new_attrs.as()) { - modify_output_dtype(attrs, accumulation_dtype); + public: + explicit FP16GraphCreator(ColorFunc colorer, OutputDtypeFunc output_dtype_func) + : ExprMutator(), colorer(colorer), output_dtype_func(output_dtype_func) {} + + Expr VisitExpr_(const CallNode* call_node) { + FP16ConversionCategory initial_color = colorer(call_node); + auto new_op = this->Mutate(call_node->op); + + // Mutate arguments to FP16 form first if possible and keep track of whether all floating point + // tensors are in FP16 form already. This is useful for propagating color. + Array new_args; + Array new_arg_types; + bool all_args_fp16_compatible = true; + for (Expr arg : call_node->args) { + Expr new_arg = this->Mutate(arg); + Type new_arg_type = GetType(new_arg); + new_args.push_back(new_arg); + new_arg_types.push_back(new_arg_type); + + if (all_args_fp16_compatible) { + // We can cast Vars and Constants to the right types so don't care about the types. + bool is_fp16_compatible = IsFP16Type(new_arg_type, true) || arg->IsInstance() || + arg->IsInstance(); + all_args_fp16_compatible &= is_fp16_compatible; } + } - if (auto attrs = new_attrs.as()) { - modify_dtype(attrs, accumulation_dtype); - } + // Determine the final color. + FP16ConversionCategory final_color; + if (initial_color == GRAY) { + final_color = all_args_fp16_compatible ? GREEN : RED; + } else { + final_color = initial_color; } - return new_attrs; - } + // Create the new arguments to the call. + DataType wanted_arg_dtypes = final_color == GREEN ? DataType::Float(16) : DataType::Float(32); + auto call_args_and_types = CastAllArgs(new_args, new_arg_types, wanted_arg_dtypes); - template - void modify_output_dtype(const T* attrs, DataType accumulation_dtype) { - // Helper template to modify relevant attributes with out_dtype type. - // These represent accumulation dtypes for some operations e.g. - // conv2d might take in fp16 and give a fp32 result. - // TODO: think about a better way to do this - T* mutable_attrs = const_cast(attrs); - mutable_attrs->out_dtype = accumulation_dtype; - } + Array call_args = call_args_and_types.first; + Array call_arg_types; + if (call_node->op.as()) { + // Function Nodes don't store type info in the Call, it should be a [] + call_arg_types = call_node->type_args; + } else { + call_arg_types = call_args_and_types.second; + } - template - void modify_dtype(const T* attrs, DataType accumulation_dtype) { - // Helper template to modify relevant attributes with dtype type. - // This determines the output dtype for some ops. For example - // zeros creates a tensor of zeros of the specified dtype. - // TODO: think about a better way to do this - T* mutable_attrs = const_cast(attrs); - mutable_attrs->dtype = accumulation_dtype; + // Finally create the new attributes. + if (final_color == GREEN) { + FP16OpDType output_dtypes = output_dtype_func(call_node); + + Attrs new_attrs = GetNewAttrs(call_node, output_dtypes.accumulation_dtype); + Expr output = Call(new_op, call_args, new_attrs, call_arg_types, call_node->span); + if (output_dtypes.accumulation_dtype != output_dtypes.output_dtype) { + output = CastArg(output, GetType(output), output_dtypes.output_dtype); + } + return output; + } else { + return Call(new_op, call_args, call_node->attrs, call_arg_types, call_node->span); + } } - public: - RewriteBasedOnColors(CallColorMap color_map, - OutputDtypeFunc output_func = DefaultFP16OpDefinition()) - : color_map(color_map), output_func(output_func) {} + Expr VisitExpr_(const FunctionNode* func) final { + // Erase the ret_type annotation and let the normal pass recalculate + const_cast(func)->ret_type = Type(nullptr); + return ExprMutator::VisitExpr_(func); + } Expr VisitExpr_(const LetNode* op) final { - // First convert as much as the bound computation to FP16 as possible + // First convert as much of the bound computation to FP16 as possible Expr value = this->Mutate(op->value); // Then rewrite the var type and associated expression Var var = Downcast(this->Mutate(op->var)); VarNode* mutable_var = const_cast((op->var).as()); - mutable_var->type_annotation = GetTypedExpr(value)->checked_type(); + mutable_var->type_annotation = GetType(value); mutable_var->checked_type_ = mutable_var->type_annotation; // Mutate body last as it may depend on previous results Expr body = this->Mutate(op->body); - return Let(var, value, body, op->span); } - - Expr VisitExpr_(const CallNode* call) final { - auto result = color_map.find(call); - if (result == color_map.end()) LOG(FATAL) << "Unknown node not in initial color map!"; - FP16ConversionCategory color = result->second; - - if (color == GRAY) { - LOG(FATAL) << "Had gray colored node during rewrite! Make sure other passes color all nodes!"; - } - - Expr new_op = Mutate(call->op); - FP16OpDType output_dtypes = output_func(call); - - // Create new node, ensure inputs are all fp32 if red, fp16 if green. - // For sttrs we may overwrite the accumulation dtype field "output_dtype" - // TODO: extend to bfloat types - DataType arg_cast_dtype = color == GREEN ? DataType::Float(16) : DataType::Float(32); - - Array new_args = get_new_args(call, arg_cast_dtype); - Attrs new_attrs = get_new_attrs(call, output_dtypes.accumulation_dtype); - Expr output = Call(new_op, new_args, new_attrs, call->type_args, call->span); - - color_map[output.as()] = color_map[call]; - if (output_dtypes.accumulation_dtype != output_dtypes.output_dtype) { - output = Cast(output, output_dtypes.output_dtype); - color_map[output.as()] = color_map[call]; - } - - // TODO: remember to visit everything! Def missing arg dtypes rn - return output; - }; - - Expr VisitExpr_(const FunctionNode* func) final { - // Erase the ret_type annotation and let the pass recalculate - const_cast(func)->ret_type = Type(nullptr); - return ExprMutator::VisitExpr_(func); - } }; -class ColorPrinter : public ExprVisitor { - private: - CallColorMap color_map; - - public: - explicit ColorPrinter(CallColorMap& color_map) : color_map(color_map) {} - explicit ColorPrinter() {} - - void VisitExpr_(const CallNode* l) final { - ExprVisitor::VisitExpr_(l); - std::cout << l->op << " is " << conversion_category_strings[color_map[l]] << std::endl; - } -}; - -Expr RewriteFp16Graph(const Expr& expr, bool debug) { - // Do an initial coloring based on each operation - GraphColorer initial_colorer = GraphColorer(); - initial_colorer.VisitExpr(expr); - CallColorMap color_map_initial = initial_colorer.result(); - - if (debug) { - std::cout << "Initial color map:" << std::endl; - ColorPrinter(color_map_initial).VisitExpr(expr); - std::cout << std::endl; - } - - // Propagate colors so gray nodes in adjacent green regions are green - // and those in red regions are red. - PropagateColors propagate_colorer = PropagateColors(color_map_initial); - propagate_colorer.VisitExpr(expr); - CallColorMap color_map_final = propagate_colorer.result(); - - if (debug) { - std::cout << "Propagate color map:" << std::endl; - ColorPrinter(color_map_final).VisitExpr(expr); - } - - // Replace all green nodes with fp16 versions of the ops, inserting casts along way. - RewriteBasedOnColors rewriter = RewriteBasedOnColors(color_map_final); - - // TODO: think about removing extraneous casts which can sometimes be added - // (Usually interactions with non-Call nodes like Tuples) - - // Insert an extraneous cast to FP32 to match old module output - Expr result = rewriter.Mutate(expr); - - // Old type annotations may no longer be accurate so rewrite - if (const FunctionNode* func = result.as()) { - const_cast(func)->ret_type = Type(nullptr); - } - - return result; +Expr RewriteFp16Graph(const Expr& expr, const ColorFunc& colorer, + const OutputDtypeFunc& output_dtype_func) { + FP16GraphCreator converter = FP16GraphCreator(colorer, output_dtype_func); + return converter.Mutate(expr); } namespace transform { -Pass RewriteFP16(bool debug) { +Pass RewriteFP16() { runtime::TypedPackedFunc pass_func = [=](Function f, IRModule m, PassContext pc) { - return Downcast(RewriteFp16Graph(f, debug)); + return Downcast( + RewriteFp16Graph(f, DefaultFP16Colorer(), DefaultFP16OpDefinition())); }; return CreateFunctionPass(pass_func, 10, "RewriteFp16", {}); } diff --git a/src/relay/transforms/fp32_to_fp16.h b/src/relay/transforms/fp32_to_fp16.h index 3511c944f574..2b9ca73c60cb 100644 --- a/src/relay/transforms/fp32_to_fp16.h +++ b/src/relay/transforms/fp32_to_fp16.h @@ -15,6 +15,9 @@ struct FP16OpDType { DataType output_dtype; }; +// GREEN colored ops should always be done in FP16 due to the speed and memory savings +// GRAY colored ops can be done in FP16 but don't have speedups to justify a dedicated cast. +// RED colored ops should not be done in FP16 due to numerical reasons. enum FP16ConversionCategory { RED, GRAY, GREEN }; std::unordered_map conversion_category_strings( {{RED, "Red"}, {GRAY, "Gray"}, {GREEN, "Green"}}); @@ -32,7 +35,10 @@ OpStringSet DEFAULT_GREEN_LIST({ "nn.conv2d_transpose", "nn.conv3d_transpose", "nn.dense", + "nn.batch_matmul", }); +// TODO make a list of ops which don't care about the types of tensors coming in for stuff like +// "where" and "strided_slice" OpStringSet DEFAULT_GRAY_LIST({ // These ops add new data or change shape "nn.pad", @@ -40,6 +46,32 @@ OpStringSet DEFAULT_GRAY_LIST({ "concatenate", "zeros", "split", + "squeeze", + "transpose", + "expand_dims", + "reshape", + "dyn.reshape", + "broadcast_to_like", + "dyn.broadcast_to", + "strided_slice", + "dyn.strided_slice", + "take", + "argwhere", + "where", + "tile", + "dyn.tile", + "scatter", + "full", + "dyn.full", + // Comparison + "less", + "greater", + "less_equal", + "greater_equal", + // By definition copy and cast will become green or red based on inputs + "copy", + "cast", + "cast_like", // Simple arithmetic "add", "subtract", @@ -47,7 +79,15 @@ OpStringSet DEFAULT_GRAY_LIST({ "divide", "nn.bias_add", "nn.batch_norm", + "sum", + "mean", + "sqrt", + "shape_of", // Simple activations + "max", + "min", + "maximum", + "minimum", "nn.relu", "nn.leaky_relu", "nn.prelu", @@ -76,15 +116,23 @@ OpStringSet DEFAULT_GRAY_LIST({ "nn.adaptive_avg_pool3d", }); OpStringSet DEFAULT_RED_LIST({ - // In general if |f(x)| >> |x| for some expected inputs to the op then put it here. - // Activations with exponents or dividing by small numbers + // In general if |f(x)| >> |x| for expected inputs then put the op here. + "exp", + "power", "nn.cross_entropy", "nn.cross_entropy_with_logits", "nn.softmax", "nn.l2_normalize", + // Error function doesn't seem to be able to be lowered into fp16 version in llvm. + // Move to gray list when it does. + "erf", }); class DefaultFP16Colorer { + /* The default class to initially color ops for conversion using lists. + + Creates a callable which given a CallNode* returns the node's color. + */ private: std::unordered_map op_to_initial_color; @@ -111,36 +159,54 @@ class DefaultFP16Colorer { if (color == op_to_initial_color.end()) { if (ignore_missing) { - LOG(WARNING) << "Op name " + op_name + " not in included in fp16 conversion lists!."; + LOG(WARNING) << "Op name " << op_name << " not in included in fp16 conversion lists!."; return RED; } else { - LOG(FATAL) << "Op name " + op_name + " not in included in fp16 lists!."; + LOG(FATAL) << "Op name " << op_name << " not in included in fp16 lists!."; } } return color->second; } else if (auto* func_node = (call->op).as()) { - // Make RED to avoid messing with function types which are complicated, fold in other pass + // TODO: make RED to avoid messing with function signatures. For now keep this simple return RED; } else { - LOG(FATAL) << "FP16 conversion only supports call nodes with op calls got " << call->op; + LOG(FATAL) << "FP16 conversion only supports call nodes with OpNodes or Functions got " + << call->op; return RED; } } }; class DefaultFP16OpDefinition { + /* The default class which determines the accumulation and + + Note this is actually kind of hard! Not every op fits neatly into the dichotomy of + returning a floating point type. In the future try using type relations to keep things better. + */ public: FP16OpDType operator()(const CallNode* call) { + // TODO: remove when batch_matmul handles accumulation dtypes well. + // Batched matmul has inconsistent support for mixed precision operations. + // Many schedules ignore the out_dtype attribute which leads to errors when + // input types do not match the out_dtype. Therefore, accumulate to fp16 if green. + if (auto op_node = call->op.as()) { + if (op_node->name == "nn.batch_matmul") { + return {DataType::Float(16), DataType::Float(16)}; + } + } + + // We assume the "out_dtype" field is always an accumulation dtype specification. if (call->attrs != NullValue()) { Array fields = call->attrs->ListFieldInfo(); for (AttrFieldInfo field_info : fields) { if (field_info->name == "out_dtype") return {DataType::Float(32), DataType::Float(16)}; } } + return {DataType::Float(16), DataType::Float(16)}; } }; } // namespace relay -} // namespace tvm \ No newline at end of file +} // namespace tvm diff --git a/tests/python/relay/test_fp32_to_fp16_transform.py b/tests/python/relay/test_fp32_to_fp16_transform.py index 742d9380dae2..4390803652c0 100644 --- a/tests/python/relay/test_fp32_to_fp16_transform.py +++ b/tests/python/relay/test_fp32_to_fp16_transform.py @@ -1,85 +1,56 @@ -from collections import defaultdict -from typing import * +from typing import Any, Dict, List import numpy as np import tvm from tvm import relay -from tvm.relay.op.tensor import exp -from tvm.relay.testing import densenet, lstm, mobilenet, resnet, resnet_3d, squeezenet +from tvm.relay.testing import lstm from tvm.relay.transform import RewriteFP16 -from tvm.relay.transform.transform import AnnotateSpans, InferType +from tvm.relay.transform.transform import InferType -def run_module(mod, mod_params): +def run_module(mod: tvm.runtime.Module, mod_params: Dict[str, Any]) -> List: dev = tvm.device("llvm", 0) intrp = relay.create_executor("debug", mod, device=dev, target="llvm") - return intrp.evaluate()(**mod_params).asnumpy() - - -def verify_fp32_fp16_output_close(mod, mod_params, rtol=1e-3, atol=0): + result = intrp.evaluate()(**mod_params) + if isinstance(result, tvm.runtime.container.ADT): + result = [r.asnumpy() for r in result] + return result + else: + return [result.asnumpy()] + + +def verify_fp32_fp16_output_close( + mod: tvm.runtime.Module, mod_params: Dict[str, Any], rtol: float = 1e-3, atol: float = 0 +) -> tvm.runtime.Module: + # TODO: add InferType to list of required pass before this one mod = InferType()(mod) - mod = AnnotateSpans()(mod) result_fp32 = run_module(mod, mod_params) fp16_mod = RewriteFP16()(mod) result_fp16 = run_module(fp16_mod, mod_params) # Ensure the results are close - np.testing.assert_allclose(result_fp32, result_fp16, rtol=rtol, atol=atol) + for fp32, fp16 in zip(result_fp32, result_fp16): + np.testing.assert_allclose(fp32, fp16, rtol=rtol, atol=atol) return fp16_mod -def test_resnet18(): - np.random.seed(4321) - mod, mod_params = resnet.get_workload(1, 5, num_layers=18, image_shape=(1, 32, 32)) - mod_params["data"] = np.random.uniform(-10, 10, (1, 1, 32, 32)).astype("float32") - - verify_fp32_fp16_output_close(mod, mod_params) - - -def test_resnet18_3d(): - np.random.seed(3215) - mod, mod_params = resnet_3d.get_workload(1, 5, num_layers=18, image_shape=(1, 3, 32, 32)) - mod_params["data"] = np.random.uniform(-10, 10, (1, 1, 3, 32, 32)).astype("float32") - - verify_fp32_fp16_output_close(mod, mod_params) - - -def test_mobilenet(): - np.random.seed(4615) - - mod, mod_params = mobilenet.get_workload(1, 5, image_shape=(1, 32, 32)) - mod_params["data"] = np.random.uniform(-10, 10, (1, 1, 32, 32)).astype("float32") - - verify_fp32_fp16_output_close(mod, mod_params) - - -def test_densenet(): - np.random.seed(3222) - mod, mod_params = densenet.get_workload(classes=5, batch_size=1, image_shape=(1, 224, 224)) - mod_params["data"] = np.random.uniform(-10, 10, (1, 1, 224, 224)).astype("float32") - - verify_fp32_fp16_output_close(mod, mod_params) - - -def test_squeezenet(): - np.random.seed(5628) - mod, mod_params = squeezenet.get_workload(1, 5, image_shape=(1, 32, 32)) - mod_params["data"] = np.random.uniform(-10, 10, (1, 1, 32, 32)).astype("float32") - verify_fp32_fp16_output_close(mod, mod_params) - - def test_lstm(): + """A small stress test on a single unrolled lstm unit. + + Has internal functions and let statements the pass must work on. + """ np.random.seed(5628) - mod, mod_params = lstm.get_workload(5, 3) + units = 3 + iterations = 5 + mod, mod_params = lstm.get_workload(iterations=iterations, num_hidden=units) - # This is an unrolled lstm so each data should be the previous results but whatever. - # We jsut want to use this to test more complicated let statements + nested funcs - mod_params["data"] = np.random.uniform(-10, 10, (1, 3)).astype("float32") - mod_params["data1"] = np.random.uniform(-10, 10, (1, 3)).astype("float32") - mod_params["data2"] = np.random.uniform(-10, 10, (1, 3)).astype("float32") - mod_params["data3"] = np.random.uniform(-10, 10, (1, 3)).astype("float32") - mod_params["data4"] = np.random.uniform(-10, 10, (1, 3)).astype("float32") + # This is an unrolled lstm so each data should be the previous results but + # we don't care, we just want to stress test things. + for i in range(iterations): + mod_params["data" if i == 0 else f"data{i}"] = np.random.uniform( + -10, 10, (1, units) + ).astype("float32") verify_fp32_fp16_output_close(mod, mod_params, rtol=0.01, atol=0.01) @@ -123,6 +94,56 @@ def test_convert_single_conv(): assert tvm.ir.structural_equal(fp16_mod, expected_mod) +def test_convert_conv_bn(): + """Conv is green and batch norm is gray. As Conv should output fp16 batch_norm should be green.""" + np.random.seed(208) + + data_shape = (1, 3, 32, 32) + weight_shape = (5, 3, 3, 3) + data = relay.var("data", shape=data_shape, dtype="float32") + weight = relay.var("weight", shape=weight_shape, dtype="float32") + conv = relay.nn.conv2d(data, weight, strides=(1, 1), padding=(1, 1), out_dtype="float32") + + bn_shape = [5] + gamma = relay.var("gamma", shape=bn_shape) + beta = relay.var("beta", shape=bn_shape) + moving_mean = relay.var("moving_mean", shape=bn_shape) + moving_var = relay.var("moving_var", shape=bn_shape) + bn = relay.nn.batch_norm(conv, gamma, beta, moving_mean, moving_var) + mod = tvm.IRModule.from_expr(bn[0]) + mod = tvm.relay.transform.InferType()(mod) + + mod_params = { + "data": np.random.uniform(-1, 1, size=data_shape).astype("float32"), + "weight": np.random.uniform(-1, 1, size=weight_shape).astype("float32"), + "gamma": np.random.uniform(-1, 1, size=bn_shape).astype("float32"), + "beta": np.random.uniform(-1, 1, size=bn_shape).astype("float32"), + "moving_mean": np.random.uniform(-1, 1, size=bn_shape).astype("float32"), + "moving_var": np.random.uniform(-1, 1, size=bn_shape).astype("float32"), + } + fp16_mod = verify_fp32_fp16_output_close(mod, mod_params, atol=0.01, rtol=1e-3) + + # Creating expected module + data = relay.cast(relay.var("data", shape=data_shape), "float16") + weight = relay.cast(relay.var("weight", shape=weight_shape), "float16") + conv = relay.cast( + relay.nn.conv2d(data, weight, strides=(1, 1), padding=(1, 1), out_dtype="float32"), + "float16", + ) + + bn_shape = [5] + gamma = relay.cast(relay.var("gamma", shape=bn_shape), "float16") + beta = relay.cast(relay.var("beta", shape=bn_shape), "float16") + moving_mean = relay.cast(relay.var("moving_mean", shape=bn_shape), "float16") + moving_var = relay.cast(relay.var("moving_var", shape=bn_shape), "float16") + bn = relay.nn.batch_norm(conv, gamma, beta, moving_mean, moving_var) + + expected_mod = tvm.IRModule.from_expr(bn[0]) + expected_mod = tvm.relay.transform.InferType()(expected_mod) + assert not tvm.ir.structural_equal(fp16_mod, mod) + assert tvm.ir.structural_equal(fp16_mod, expected_mod) + + def test_do_not_convert_softmax(): """Softmax is a red listed operation and therefore should never be fp16.""" np.random.seed(209) @@ -142,7 +163,7 @@ def test_do_not_convert_softmax(): def test_green_gray_propagates_simple(): """Conv is a green listed operation, while addition is gray. - When adjacent + As Conv outputs fp16 the add should be done in fp16. """ np.random.seed(210) data_shape = (1, 3, 32, 32) @@ -178,10 +199,7 @@ def test_green_gray_propagates_simple(): def test_red_gray_propagates_simple(): - """Conv is a green listed operation, while addition is gray. - - When adjacent - """ + """Everything after a softmax should be in FP32 (exception green colored ops)""" np.random.seed(211) shape = [1, 2, 3] a = relay.var("a", shape=shape) @@ -199,6 +217,10 @@ def test_red_gray_propagates_simple(): def test_let_statement_simple(): + """A 'simple' let statement example. + + Noticable is the mutation of the bound variable types. + """ np.random.seed(211) var1 = relay.var("var1", shape=[1, 20]) var2 = relay.var("var2", shape=[1, 20]) @@ -240,3 +262,51 @@ def test_let_statement_simple(): expected_mod = InferType()(expected_mod) assert tvm.ir.structural_equal(expected_mod, output_mod) + + +def test_where_simple(): + data = relay.var("data", shape=[1, 20]) + weight = relay.var("weight", shape=[20, 20]) + a = relay.nn.dense(data, weight, units=20) + b = relay.where(data, a, a) + mod = tvm.IRModule.from_expr(b) + mod_params = { + "data": np.random.uniform(-1, 1, size=[1, 20]).astype("float32"), + "weight": np.random.uniform(-1, 1, size=[20, 20]).astype("float32"), + } + + output_mod = verify_fp32_fp16_output_close(mod, mod_params, atol=0.01, rtol=0.01) + + # Create expected module + data = relay.cast(relay.var("data", shape=[1, 20]), "float16") + weight = relay.cast(relay.var("weight", shape=[20, 20]), "float16") + a = relay.cast(relay.nn.dense(data, weight, units=20, out_dtype="float32"), "float16") + b = relay.where(data, a, a) + expected_mod = tvm.IRModule.from_expr(b) + expected_mod = InferType()(expected_mod) + + assert tvm.ir.structural_equal(expected_mod, output_mod) + + +def test_batch_matmul_simple(): + """Batch matmul is a special case where we try to accumulate to fp16. + + This is due to the fact heterogenous accumulation dtypes does not work + on all platforms at the moment. + """ + data = relay.var("data", shape=[1, 1, 20]) + weight = relay.var("weight", shape=[1, 20, 20]) + a = relay.nn.batch_matmul(data, weight) + mod = tvm.IRModule.from_expr(a) + mod_params = { + "data": np.random.uniform(-1, 1, size=[1, 1, 20]).astype("float32"), + "weight": np.random.uniform(-1, 1, size=[1, 20, 20]).astype("float32"), + } + output_mod = verify_fp32_fp16_output_close(mod, mod_params, atol=0.01, rtol=0.01) + # Create expected module + data = relay.cast(relay.var("data", shape=[1, 1, 20]), "float16") + weight = relay.cast(relay.var("weight", shape=[1, 20, 20]), "float16") + a = relay.nn.batch_matmul(data, weight, out_dtype="float16") + expected_mod = tvm.IRModule.from_expr(a) + expected_mod = InferType()(expected_mod) + assert tvm.ir.structural_equal(expected_mod, output_mod) From bde1c5817795c14f363c0eb08e3d03bba490ca08 Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Fri, 4 Jun 2021 17:53:08 -0700 Subject: [PATCH 06/59] linting and formatting --- src/relay/transforms/fp32_to_fp16.cc | 31 ++++++++++++++++++--- src/relay/transforms/fp32_to_fp16.h | 40 +++++++++++++++++++++------- 2 files changed, 58 insertions(+), 13 deletions(-) diff --git a/src/relay/transforms/fp32_to_fp16.cc b/src/relay/transforms/fp32_to_fp16.cc index 20f23da54f12..5567e706474d 100644 --- a/src/relay/transforms/fp32_to_fp16.cc +++ b/src/relay/transforms/fp32_to_fp16.cc @@ -1,4 +1,27 @@ - +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * + * \file fp32_to_fp16.cc + * \brief Rewrite a graph into an fp16 form. + */ #include "fp32_to_fp16.h" #include @@ -6,6 +29,8 @@ #include #include +#include + #include "pattern_utils.h" namespace tvm { @@ -41,7 +66,7 @@ class FP16GraphCreator : public ExprMutator { /* If the accumulation dtype is in the attributes make a copy and mutate the field. */ Attrs new_attrs = Attrs(call->attrs); if (new_attrs.get() != nullptr) { - // TODO: Figure out a better way to do this + // TODO(AndrewZhaoLuo): Figure out a better way to do this // modify output_dtype attributes (accumulation dtypes for ops) if (auto attrs = new_attrs.as()) { ModifyAttrsOutputDType(attrs, accumulation_dtype); @@ -85,7 +110,6 @@ class FP16GraphCreator : public ExprMutator { These represent accumulation dtypes for some operations e.g. conv2d might take in fp16 and give a fp32 result. Attrs is const because we get it as a const. - TODO: think about a better way to do this */ T* mutable_attrs = const_cast(attrs); mutable_attrs->out_dtype = accumulation_dtype; @@ -98,7 +122,6 @@ class FP16GraphCreator : public ExprMutator { This determines the output dtype for some ops. For example zeros creates a tensor of zeros of the specified dtype. Attrs is const because we get it as a const. - TODO: think about a better way to do this */ T* mutable_attrs = const_cast(attrs); mutable_attrs->dtype = accumulation_dtype; diff --git a/src/relay/transforms/fp32_to_fp16.h b/src/relay/transforms/fp32_to_fp16.h index 2b9ca73c60cb..032ff175bcdb 100644 --- a/src/relay/transforms/fp32_to_fp16.h +++ b/src/relay/transforms/fp32_to_fp16.h @@ -1,3 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file fp32_to_fp16.h + * \brief Utilities and common types used for FP32->FP16 pass. + */ +#ifndef TVM_RELAY_TRANSFORMS_FP32_TO_FP16_H_ +#define TVM_RELAY_TRANSFORMS_FP32_TO_FP16_H_ + #include #include #include @@ -5,6 +31,7 @@ #include #include #include +#include #include namespace tvm { @@ -37,8 +64,6 @@ OpStringSet DEFAULT_GREEN_LIST({ "nn.dense", "nn.batch_matmul", }); -// TODO make a list of ops which don't care about the types of tensors coming in for stuff like -// "where" and "strided_slice" OpStringSet DEFAULT_GRAY_LIST({ // These ops add new data or change shape "nn.pad", @@ -168,7 +193,7 @@ class DefaultFP16Colorer { return color->second; } else if (auto* func_node = (call->op).as()) { - // TODO: make RED to avoid messing with function signatures. For now keep this simple + // Make RED to avoid messing with function headers. return RED; } else { LOG(FATAL) << "FP16 conversion only supports call nodes with OpNodes or Functions got " @@ -179,14 +204,10 @@ class DefaultFP16Colorer { }; class DefaultFP16OpDefinition { - /* The default class which determines the accumulation and - - Note this is actually kind of hard! Not every op fits neatly into the dichotomy of - returning a floating point type. In the future try using type relations to keep things better. - */ + /* The default callable for determining accumulation_dtypes for ops. */ public: FP16OpDType operator()(const CallNode* call) { - // TODO: remove when batch_matmul handles accumulation dtypes well. + // TODO(AndrewZhaoLuo): remove when batch_matmul handles accumulation dtypes well. // Batched matmul has inconsistent support for mixed precision operations. // Many schedules ignore the out_dtype attribute which leads to errors when // input types do not match the out_dtype. Therefore, accumulate to fp16 if green. @@ -210,3 +231,4 @@ class DefaultFP16OpDefinition { } // namespace relay } // namespace tvm +#endif // TVM_RELAY_TRANSFORMS_FP32_TO_FP16_H_ From 2101e6ed66f0f3a3bcaf994fd414d35451eadb42 Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Fri, 4 Jun 2021 17:56:12 -0700 Subject: [PATCH 07/59] add AST header --- .../python/relay/test_fp32_to_fp16_transform.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/tests/python/relay/test_fp32_to_fp16_transform.py b/tests/python/relay/test_fp32_to_fp16_transform.py index 4390803652c0..49a9b41cac19 100644 --- a/tests/python/relay/test_fp32_to_fp16_transform.py +++ b/tests/python/relay/test_fp32_to_fp16_transform.py @@ -1,3 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Unit tests for testing FP32 -> FP16 pass""" from typing import Any, Dict, List import numpy as np From 8e82c40ce95730a161755ba20141c81639ab0bed Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Fri, 4 Jun 2021 18:32:13 -0700 Subject: [PATCH 08/59] remove todo --- tests/python/relay/test_fp32_to_fp16_transform.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/python/relay/test_fp32_to_fp16_transform.py b/tests/python/relay/test_fp32_to_fp16_transform.py index 49a9b41cac19..57888e28937d 100644 --- a/tests/python/relay/test_fp32_to_fp16_transform.py +++ b/tests/python/relay/test_fp32_to_fp16_transform.py @@ -39,7 +39,6 @@ def run_module(mod: tvm.runtime.Module, mod_params: Dict[str, Any]) -> List: def verify_fp32_fp16_output_close( mod: tvm.runtime.Module, mod_params: Dict[str, Any], rtol: float = 1e-3, atol: float = 0 ) -> tvm.runtime.Module: - # TODO: add InferType to list of required pass before this one mod = InferType()(mod) result_fp32 = run_module(mod, mod_params) fp16_mod = RewriteFP16()(mod) From 399121bfc95e92f14a54f4cd819cf7415d385518 Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Sat, 5 Jun 2021 00:29:02 -0700 Subject: [PATCH 09/59] lint errors2 --- src/relay/transforms/fp32_to_fp16.cc | 2 +- src/relay/transforms/fp32_to_fp16.h | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/relay/transforms/fp32_to_fp16.cc b/src/relay/transforms/fp32_to_fp16.cc index 5567e706474d..ec3cb751fea5 100644 --- a/src/relay/transforms/fp32_to_fp16.cc +++ b/src/relay/transforms/fp32_to_fp16.cc @@ -190,7 +190,7 @@ class FP16GraphCreator : public ExprMutator { } else if (const TupleTypeNode* tuple_type = expr_type.as()) { Array new_expr; bool all_same = true; - for (int i = 0; i < (tuple_type->fields).size(); i++) { + for (size_t i = 0; i < (tuple_type->fields).size(); i++) { Expr tuple_element = GetField(expr, i); Type tuple_element_dtype = (tuple_type->fields)[i]; Expr casted_element = CastArg(tuple_element, tuple_element_dtype, wanted_dtype); diff --git a/src/relay/transforms/fp32_to_fp16.h b/src/relay/transforms/fp32_to_fp16.h index 032ff175bcdb..2bd9eb2739af 100644 --- a/src/relay/transforms/fp32_to_fp16.h +++ b/src/relay/transforms/fp32_to_fp16.h @@ -192,7 +192,7 @@ class DefaultFP16Colorer { } return color->second; - } else if (auto* func_node = (call->op).as()) { + } else if ((call->op).as()) { // Make RED to avoid messing with function headers. return RED; } else { From c8f7428f08fda7426dae1b67aee34dc7416031bc Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Sat, 5 Jun 2021 01:39:30 -0700 Subject: [PATCH 10/59] remove i386 incompatible features --- src/relay/transforms/fp32_to_fp16.h | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/relay/transforms/fp32_to_fp16.h b/src/relay/transforms/fp32_to_fp16.h index 2bd9eb2739af..447f89c702dd 100644 --- a/src/relay/transforms/fp32_to_fp16.h +++ b/src/relay/transforms/fp32_to_fp16.h @@ -46,8 +46,6 @@ struct FP16OpDType { // GRAY colored ops can be done in FP16 but don't have speedups to justify a dedicated cast. // RED colored ops should not be done in FP16 due to numerical reasons. enum FP16ConversionCategory { RED, GRAY, GREEN }; -std::unordered_map conversion_category_strings( - {{RED, "Red"}, {GRAY, "Gray"}, {GREEN, "Green"}}); using OpStringSet = std::unordered_set; From 42b0c041fd38c8aaedff6660a32ff4037b2a2c3b Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Sat, 5 Jun 2021 23:03:08 -0700 Subject: [PATCH 11/59] Trigger CI again From 65b8d6c0c56903621232168b0ef2c036e2159870 Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Sat, 5 Jun 2021 23:07:03 -0700 Subject: [PATCH 12/59] set seed --- tests/python/frontend/mxnet/test_forward.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/tests/python/frontend/mxnet/test_forward.py b/tests/python/frontend/mxnet/test_forward.py index 362a9b623d25..5ff11e823bf2 100644 --- a/tests/python/frontend/mxnet/test_forward.py +++ b/tests/python/frontend/mxnet/test_forward.py @@ -14,22 +14,20 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import numpy as np import operator +import random +import numpy as np +import pytest import tvm -from tvm import te +import tvm.testing +from tvm import relay, te from tvm.contrib import graph_executor -from tvm import relay -import mxnet as mx +import model_zoo +import mxnet as mx from mxnet import gluon from mxnet.gluon.model_zoo import vision -import random -import pytest -import model_zoo - -import tvm.testing def verify_mxnet_frontend_impl( @@ -1217,6 +1215,7 @@ def verify(shape, axis=1, fix_gamma=False): @tvm.testing.uses_gpu def test_forward_instance_norm(): + np.random.seed(90) def verify(shape, axis=1, epsilon=1e-5): x = np.random.uniform(size=shape).astype("float32") gamma = np.random.uniform(size=(shape[axis])).astype("float32") From 8860b1c9cda02b31026289524c4c015024cf11f7 Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Sat, 5 Jun 2021 23:11:05 -0700 Subject: [PATCH 13/59] lint --- tests/python/frontend/mxnet/test_forward.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/python/frontend/mxnet/test_forward.py b/tests/python/frontend/mxnet/test_forward.py index 5ff11e823bf2..68641e8a611c 100644 --- a/tests/python/frontend/mxnet/test_forward.py +++ b/tests/python/frontend/mxnet/test_forward.py @@ -1216,6 +1216,7 @@ def verify(shape, axis=1, fix_gamma=False): @tvm.testing.uses_gpu def test_forward_instance_norm(): np.random.seed(90) + def verify(shape, axis=1, epsilon=1e-5): x = np.random.uniform(size=shape).astype("float32") gamma = np.random.uniform(size=(shape[axis])).astype("float32") From b3b877675d0bcd448026ecfc01336bde588be139 Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Mon, 7 Jun 2021 12:43:56 -0700 Subject: [PATCH 14/59] address animesh's initial comments --- python/tvm/relay/transform/transform.py | 12 +++++++----- src/relay/transforms/fp32_to_fp16.cc | 16 ++++++++-------- src/relay/transforms/fp32_to_fp16.h | 4 ++-- .../python/relay/test_fp32_to_fp16_transform.py | 4 ++-- 4 files changed, 19 insertions(+), 17 deletions(-) diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index 7dffd27ad499..26d62fe37947 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -1200,16 +1200,18 @@ def FakeQuantizationToInteger(): return _ffi_api.FakeQuantizationToInteger() -def RewriteFP16(): +def AMPRewrite(): """ - Rewrite an FP32 relay graph into an FP16 version. Note this does mutate - the original graph putting it in a bad state potentially. + Automatic mixed precision rewriter. Rewrite an FP32 relay graph into a version + where as many operations as possible are in FP16. - TODO: don't mutate the original graph. + Note this does mutate the original graph putting it in a bad state potentially. + + TODO(AndrewZhaoLuo): don't mutate the original graph. Returns ------- ret : tvm.transform.Pass The registered RewriteFP16 pass. """ - return _ffi_api.RewriteFP16() + return _ffi_api.AMPRewrite() diff --git a/src/relay/transforms/fp32_to_fp16.cc b/src/relay/transforms/fp32_to_fp16.cc index ec3cb751fea5..5059b30c3621 100644 --- a/src/relay/transforms/fp32_to_fp16.cc +++ b/src/relay/transforms/fp32_to_fp16.cc @@ -56,7 +56,7 @@ using ColorFunc = std::function; // A function which maps green CallNodes to wanted accumulation and output dtypes using OutputDtypeFunc = std::function; -class FP16GraphCreator : public ExprMutator { +class AmpGraphCreator : public ExprMutator { private: CachedCastNodes cast_nodes_cache; const ColorFunc colorer; @@ -221,7 +221,7 @@ class FP16GraphCreator : public ExprMutator { } public: - explicit FP16GraphCreator(ColorFunc colorer, OutputDtypeFunc output_dtype_func) + explicit AmpGraphCreator(ColorFunc colorer, OutputDtypeFunc output_dtype_func) : ExprMutator(), colorer(colorer), output_dtype_func(output_dtype_func) {} Expr VisitExpr_(const CallNode* call_node) { @@ -305,24 +305,24 @@ class FP16GraphCreator : public ExprMutator { } }; -Expr RewriteFp16Graph(const Expr& expr, const ColorFunc& colorer, - const OutputDtypeFunc& output_dtype_func) { - FP16GraphCreator converter = FP16GraphCreator(colorer, output_dtype_func); +Expr AMPRewriteGraph(const Expr& expr, const ColorFunc& colorer, + const OutputDtypeFunc& output_dtype_func) { + AmpGraphCreator converter = AmpGraphCreator(colorer, output_dtype_func); return converter.Mutate(expr); } namespace transform { -Pass RewriteFP16() { +Pass AMPRewrite() { runtime::TypedPackedFunc pass_func = [=](Function f, IRModule m, PassContext pc) { return Downcast( - RewriteFp16Graph(f, DefaultFP16Colorer(), DefaultFP16OpDefinition())); + AMPRewriteGraph(f, DefaultFP16Colorer(), DefaultFP16OpDefinition())); }; return CreateFunctionPass(pass_func, 10, "RewriteFp16", {}); } -TVM_REGISTER_GLOBAL("relay._transform.RewriteFP16").set_body_typed(RewriteFP16); +TVM_REGISTER_GLOBAL("relay._transform.AMPRewrite").set_body_typed(AMPRewrite); } // namespace transform diff --git a/src/relay/transforms/fp32_to_fp16.h b/src/relay/transforms/fp32_to_fp16.h index 447f89c702dd..dafdab5dd94c 100644 --- a/src/relay/transforms/fp32_to_fp16.h +++ b/src/relay/transforms/fp32_to_fp16.h @@ -182,10 +182,10 @@ class DefaultFP16Colorer { if (color == op_to_initial_color.end()) { if (ignore_missing) { - LOG(WARNING) << "Op name " << op_name << " not in included in fp16 conversion lists!."; + LOG(WARNING) << "Op name " << op_name << " not in included in fp16 conversion lists!"; return RED; } else { - LOG(FATAL) << "Op name " << op_name << " not in included in fp16 lists!."; + LOG(FATAL) << "Op name " << op_name << " not in included in fp16 lists!"; } } diff --git a/tests/python/relay/test_fp32_to_fp16_transform.py b/tests/python/relay/test_fp32_to_fp16_transform.py index 57888e28937d..da4509a57ea5 100644 --- a/tests/python/relay/test_fp32_to_fp16_transform.py +++ b/tests/python/relay/test_fp32_to_fp16_transform.py @@ -21,7 +21,7 @@ import tvm from tvm import relay from tvm.relay.testing import lstm -from tvm.relay.transform import RewriteFP16 +from tvm.relay.transform import AMPRewrite from tvm.relay.transform.transform import InferType @@ -41,7 +41,7 @@ def verify_fp32_fp16_output_close( ) -> tvm.runtime.Module: mod = InferType()(mod) result_fp32 = run_module(mod, mod_params) - fp16_mod = RewriteFP16()(mod) + fp16_mod = AMPRewrite()(mod) result_fp16 = run_module(fp16_mod, mod_params) # Ensure the results are close From 479124b25fbfb40e478975a921d27aec1c347432 Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Mon, 7 Jun 2021 18:15:06 -0700 Subject: [PATCH 15/59] mutate attributes only if they were originally floats --- src/relay/transforms/fp32_to_fp16.cc | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/relay/transforms/fp32_to_fp16.cc b/src/relay/transforms/fp32_to_fp16.cc index 5059b30c3621..29f278936755 100644 --- a/src/relay/transforms/fp32_to_fp16.cc +++ b/src/relay/transforms/fp32_to_fp16.cc @@ -112,7 +112,7 @@ class AmpGraphCreator : public ExprMutator { Attrs is const because we get it as a const. */ T* mutable_attrs = const_cast(attrs); - mutable_attrs->out_dtype = accumulation_dtype; + if ((mutable_attrs->out_dtype).is_float()) mutable_attrs->out_dtype = accumulation_dtype; } template @@ -124,12 +124,13 @@ class AmpGraphCreator : public ExprMutator { Attrs is const because we get it as a const. */ T* mutable_attrs = const_cast(attrs); - mutable_attrs->dtype = accumulation_dtype; + if ((mutable_attrs->dtype).is_float()) mutable_attrs->dtype = accumulation_dtype; } Type GetType(const Expr& expr) const { auto mod = IRModule::FromExpr(expr); mod = transform::InferType()(mod); + if (expr.as()) { return mod->Lookup("main")->checked_type(); } else { @@ -261,6 +262,7 @@ class AmpGraphCreator : public ExprMutator { Array call_args = call_args_and_types.first; Array call_arg_types; + if (call_node->op.as()) { // Function Nodes don't store type info in the Call, it should be a [] call_arg_types = call_node->type_args; From 22ae9e77acb9fdcc6bd16cf24002aaa54ffd5212 Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Tue, 8 Jun 2021 13:43:10 -0700 Subject: [PATCH 16/59] initial comments from matthew --- src/relay/transforms/fp32_to_fp16.cc | 2 +- tests/python/frontend/mxnet/test_forward.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/relay/transforms/fp32_to_fp16.cc b/src/relay/transforms/fp32_to_fp16.cc index 29f278936755..af1809b94cfe 100644 --- a/src/relay/transforms/fp32_to_fp16.cc +++ b/src/relay/transforms/fp32_to_fp16.cc @@ -43,7 +43,7 @@ struct pair_hash { auto h1 = std::hash()(pair.first); auto h2 = std::hash()(pair.second); - return h1 ^ (h2 << 1); + return h1 ^ (h1 + 0x9e3779b9 + (h2 << 6) + (h2 >> 2)) } }; diff --git a/tests/python/frontend/mxnet/test_forward.py b/tests/python/frontend/mxnet/test_forward.py index 68641e8a611c..a6c3d6efec56 100644 --- a/tests/python/frontend/mxnet/test_forward.py +++ b/tests/python/frontend/mxnet/test_forward.py @@ -1215,8 +1215,6 @@ def verify(shape, axis=1, fix_gamma=False): @tvm.testing.uses_gpu def test_forward_instance_norm(): - np.random.seed(90) - def verify(shape, axis=1, epsilon=1e-5): x = np.random.uniform(size=shape).astype("float32") gamma = np.random.uniform(size=(shape[axis])).astype("float32") @@ -1231,7 +1229,9 @@ def verify(shape, axis=1, epsilon=1e-5): for kind in ["graph", "debug"]: intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) op_res = intrp.evaluate()(x, gamma, beta) - tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy(), rtol=1e-5, atol=1e-5) + tvm.testing.assert_allclose( + op_res.asnumpy(), ref_res.asnumpy(), rtol=2e-5, atol=1e-5 + ) verify((2, 3, 4, 5)) verify((32, 64, 80, 64)) From d95684855c24ba277bc686dda42002708e0874ed Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Tue, 8 Jun 2021 17:24:40 -0700 Subject: [PATCH 17/59] add comment on hashing strat --- src/relay/transforms/fp32_to_fp16.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/src/relay/transforms/fp32_to_fp16.cc b/src/relay/transforms/fp32_to_fp16.cc index af1809b94cfe..ff6f02282b32 100644 --- a/src/relay/transforms/fp32_to_fp16.cc +++ b/src/relay/transforms/fp32_to_fp16.cc @@ -43,6 +43,7 @@ struct pair_hash { auto h1 = std::hash()(pair.first); auto h2 = std::hash()(pair.second); + // Use boost's combine_hash strategy return h1 ^ (h1 + 0x9e3779b9 + (h2 << 6) + (h2 >> 2)) } }; From cb39e0f98e09937946bf26694fb2effc95b774aa Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Tue, 8 Jun 2021 21:07:14 -0700 Subject: [PATCH 18/59] add missing ; --- src/relay/transforms/fp32_to_fp16.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/relay/transforms/fp32_to_fp16.cc b/src/relay/transforms/fp32_to_fp16.cc index ff6f02282b32..3f8f88c49b80 100644 --- a/src/relay/transforms/fp32_to_fp16.cc +++ b/src/relay/transforms/fp32_to_fp16.cc @@ -44,7 +44,7 @@ struct pair_hash { auto h2 = std::hash()(pair.second); // Use boost's combine_hash strategy - return h1 ^ (h1 + 0x9e3779b9 + (h2 << 6) + (h2 >> 2)) + return h1 ^ (h1 + 0x9e3779b9 + (h2 << 6) + (h2 >> 2)); } }; From a00fd8bc5de54fc92b335a16ea21261c54a4f546 Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Wed, 9 Jun 2021 10:38:24 -0700 Subject: [PATCH 19/59] edge case when mutating attrs --- src/relay/transforms/fp32_to_fp16.cc | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/relay/transforms/fp32_to_fp16.cc b/src/relay/transforms/fp32_to_fp16.cc index 3f8f88c49b80..2e860c6123aa 100644 --- a/src/relay/transforms/fp32_to_fp16.cc +++ b/src/relay/transforms/fp32_to_fp16.cc @@ -113,7 +113,9 @@ class AmpGraphCreator : public ExprMutator { Attrs is const because we get it as a const. */ T* mutable_attrs = const_cast(attrs); - if ((mutable_attrs->out_dtype).is_float()) mutable_attrs->out_dtype = accumulation_dtype; + + DataType cur_type = (mutable_attrs->out_dtype); + if (cur_type.is_float() || cur_type.is_void()) mutable_attrs->out_dtype = accumulation_dtype; } template @@ -125,7 +127,9 @@ class AmpGraphCreator : public ExprMutator { Attrs is const because we get it as a const. */ T* mutable_attrs = const_cast(attrs); - if ((mutable_attrs->dtype).is_float()) mutable_attrs->dtype = accumulation_dtype; + DataType cur_type = (mutable_attrs->dtype); + + if (cur_type.is_float() || cur_type.is_void()) mutable_attrs->dtype = accumulation_dtype; } Type GetType(const Expr& expr) const { From e25c40c0ed4a5206063ed433569489793c357473 Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Wed, 9 Jun 2021 12:34:13 -0700 Subject: [PATCH 20/59] Cody's easy to address comments --- src/relay/transforms/fp32_to_fp16.cc | 6 +++--- tests/python/relay/test_fp32_to_fp16_transform.py | 5 +++++ 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/src/relay/transforms/fp32_to_fp16.cc b/src/relay/transforms/fp32_to_fp16.cc index 2e860c6123aa..5f10841d074e 100644 --- a/src/relay/transforms/fp32_to_fp16.cc +++ b/src/relay/transforms/fp32_to_fp16.cc @@ -245,7 +245,7 @@ class AmpGraphCreator : public ExprMutator { new_args.push_back(new_arg); new_arg_types.push_back(new_arg_type); - if (all_args_fp16_compatible) { + if (initial_color == GRAY && all_args_fp16_compatible) { // We can cast Vars and Constants to the right types so don't care about the types. bool is_fp16_compatible = IsFP16Type(new_arg_type, true) || arg->IsInstance() || arg->IsInstance(); @@ -285,9 +285,9 @@ class AmpGraphCreator : public ExprMutator { output = CastArg(output, GetType(output), output_dtypes.output_dtype); } return output; - } else { - return Call(new_op, call_args, call_node->attrs, call_arg_types, call_node->span); } + + return Call(new_op, call_args, call_node->attrs, call_arg_types, call_node->span); } Expr VisitExpr_(const FunctionNode* func) final { diff --git a/tests/python/relay/test_fp32_to_fp16_transform.py b/tests/python/relay/test_fp32_to_fp16_transform.py index da4509a57ea5..c62d77b7a0e6 100644 --- a/tests/python/relay/test_fp32_to_fp16_transform.py +++ b/tests/python/relay/test_fp32_to_fp16_transform.py @@ -18,6 +18,7 @@ from typing import Any, Dict, List import numpy as np +import pytest import tvm from tvm import relay from tvm.relay.testing import lstm @@ -326,3 +327,7 @@ def test_batch_matmul_simple(): expected_mod = tvm.IRModule.from_expr(a) expected_mod = InferType()(expected_mod) assert tvm.ir.structural_equal(expected_mod, output_mod) + + +if __name__ == "__main__": + pytest.main([__file__]) From 70436f537e132f592766290d114d5ecb0e9fce0b Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Wed, 9 Jun 2021 13:05:16 -0700 Subject: [PATCH 21/59] add test to show green-red casting works --- .../relay/test_fp32_to_fp16_transform.py | 57 +++++++++++++++++++ 1 file changed, 57 insertions(+) diff --git a/tests/python/relay/test_fp32_to_fp16_transform.py b/tests/python/relay/test_fp32_to_fp16_transform.py index c62d77b7a0e6..d0fbc368c2fa 100644 --- a/tests/python/relay/test_fp32_to_fp16_transform.py +++ b/tests/python/relay/test_fp32_to_fp16_transform.py @@ -215,6 +215,63 @@ def test_green_gray_propagates_simple(): assert tvm.ir.structural_equal(fp16_mod, expected_mod) +def test_green_red_not_use_extraneous_cast(): + """Conv. is a green listed operation, while softmax is red. + + Conv. also by default accumulates to fp32 but outputs fp16. + + We want to avoid a situation where we have extraneous casts. + E.g. because softmax wants to operate on FP32 we might have + + conv (FP32) -> cast (FP16) -> cast (FP32) -> softmax (FP32) + + To get around this internally when we cast in the pass we cache + the output nodes and the reverse of the cast back to the original + node. For example casting the `conv (FP32)` to FP16 would produce: + + `conv (FP32) -> cast (FP16)` + + As the outputs. Now anytime we try to cast the `conv (FP32)` node + to FP16 it would return the cached result instead of a new cast node: + + `conv (FP32) -> cast (FP16)` + + Furthermore, if we try to cast the `cast (FP16)` node back to FP32 it + would just return + + `conv (FP32)`. + + This test makes sure this behavior occurs. + """ + data_shape = (1, 3, 32, 32) + weight_shape = (5, 3, 3, 3) + data = relay.var("data", shape=data_shape, dtype="float32") + weight = relay.var("weight", shape=weight_shape, dtype="float32") + conv = relay.nn.conv2d(data, weight, strides=(1, 1), padding=(1, 1), out_dtype="float32") + result = relay.nn.softmax(conv) + mod = tvm.IRModule.from_expr(result) + + mod_params = { + "data": np.random.uniform(-1, 1, size=data_shape).astype("float32"), + "weight": np.random.uniform(-1, 1, size=weight_shape).astype("float32"), + } + fp16_mod = verify_fp32_fp16_output_close(mod, mod_params, atol=0.01, rtol=1e-3) + + # Construct expected structure + conv = relay.nn.conv2d( + relay.cast(data, "float16"), + relay.cast(weight, "float16"), + strides=(1, 1), + padding=(1, 1), + out_dtype="float32", + ) + result = relay.nn.softmax(conv) + expected_mod = tvm.IRModule.from_expr(result) + expected_mod = InferType()(expected_mod) + + assert tvm.ir.structural_equal(expected_mod, fp16_mod) + + def test_red_gray_propagates_simple(): """Everything after a softmax should be in FP32 (exception green colored ops)""" np.random.seed(211) From 2c78317088ee71ad55a4922d3cd46f5d75310e25 Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Wed, 9 Jun 2021 13:07:35 -0700 Subject: [PATCH 22/59] remove np.random seed from each test --- tests/python/relay/test_fp32_to_fp16_transform.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/tests/python/relay/test_fp32_to_fp16_transform.py b/tests/python/relay/test_fp32_to_fp16_transform.py index d0fbc368c2fa..be01ef5c80bc 100644 --- a/tests/python/relay/test_fp32_to_fp16_transform.py +++ b/tests/python/relay/test_fp32_to_fp16_transform.py @@ -57,7 +57,6 @@ def test_lstm(): Has internal functions and let statements the pass must work on. """ - np.random.seed(5628) units = 3 iterations = 5 mod, mod_params = lstm.get_workload(iterations=iterations, num_hidden=units) @@ -77,8 +76,6 @@ def test_convert_single_conv(): By default it accumulates to fp32 and outputs fp16. """ - np.random.seed(208) - data_shape = (1, 3, 32, 32) weight_shape = (5, 3, 3, 3) data = relay.var("data", shape=data_shape, dtype="float32") @@ -113,8 +110,6 @@ def test_convert_single_conv(): def test_convert_conv_bn(): """Conv is green and batch norm is gray. As Conv should output fp16 batch_norm should be green.""" - np.random.seed(208) - data_shape = (1, 3, 32, 32) weight_shape = (5, 3, 3, 3) data = relay.var("data", shape=data_shape, dtype="float32") @@ -163,7 +158,6 @@ def test_convert_conv_bn(): def test_do_not_convert_softmax(): """Softmax is a red listed operation and therefore should never be fp16.""" - np.random.seed(209) shape = [1, 2, 3] a = relay.var("a", shape=shape) b = relay.nn.softmax(a) @@ -182,7 +176,6 @@ def test_green_gray_propagates_simple(): As Conv outputs fp16 the add should be done in fp16. """ - np.random.seed(210) data_shape = (1, 3, 32, 32) weight_shape = (5, 3, 3, 3) data = relay.var("data", shape=data_shape, dtype="float32") @@ -274,7 +267,6 @@ def test_green_red_not_use_extraneous_cast(): def test_red_gray_propagates_simple(): """Everything after a softmax should be in FP32 (exception green colored ops)""" - np.random.seed(211) shape = [1, 2, 3] a = relay.var("a", shape=shape) b = relay.nn.softmax(a) @@ -293,9 +285,8 @@ def test_red_gray_propagates_simple(): def test_let_statement_simple(): """A 'simple' let statement example. - Noticable is the mutation of the bound variable types. + Noticeable is the mutation of the bound variable types. """ - np.random.seed(211) var1 = relay.var("var1", shape=[1, 20]) var2 = relay.var("var2", shape=[1, 20]) From 44b9782f391c5ddb2f6a4213d4fa5ec2e2a19888 Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Wed, 9 Jun 2021 13:55:16 -0700 Subject: [PATCH 23/59] remove as many references to fp16 types in favor of generic mixed types --- src/relay/transforms/fp32_to_fp16.cc | 59 +++++++++++++++---------- src/relay/transforms/fp32_to_fp16.h | 65 ++++++++++++++++------------ 2 files changed, 73 insertions(+), 51 deletions(-) diff --git a/src/relay/transforms/fp32_to_fp16.cc b/src/relay/transforms/fp32_to_fp16.cc index 5f10841d074e..b9eb439806a3 100644 --- a/src/relay/transforms/fp32_to_fp16.cc +++ b/src/relay/transforms/fp32_to_fp16.cc @@ -52,16 +52,17 @@ struct pair_hash { using CachedCastNodes = std::unordered_map, Expr, pair_hash>; // A function which maps CallNodes to their initial conversion color -using ColorFunc = std::function; +using ColorFunc = std::function; // A function which maps green CallNodes to wanted accumulation and output dtypes -using OutputDtypeFunc = std::function; +using OutputDtypeFunc = std::function; class AmpGraphCreator : public ExprMutator { private: CachedCastNodes cast_nodes_cache; const ColorFunc colorer; const OutputDtypeFunc output_dtype_func; + const DataType mixed_precision_type; Attrs GetNewAttrs(const CallNode* call, const DataType& accumulation_dtype) const { /* If the accumulation dtype is in the attributes make a copy and mutate the field. */ @@ -143,16 +144,16 @@ class AmpGraphCreator : public ExprMutator { } } - bool IsFP16Type(const Type& t, bool ignore_non_float = false) const { - /* Returns whether t is a type with only fp16 elements. + bool IsMixedPrecisionType(const Type& t, bool ignore_non_float = false) const { + /* Returns whether t is a type with only target mixed precision type elements. If ignore_non_float, then ignore non-floating types. */ if (const TensorTypeNode* tensor_type = t.as()) { return (!ignore_non_float || (tensor_type->dtype).is_float()) && - tensor_type->dtype == DataType::Float(16); + tensor_type->dtype == mixed_precision_type; } else if (const TupleTypeNode* tuple_type = t.as()) { for (Type t : tuple_type->fields) { - if (!IsFP16Type(t, ignore_non_float)) return false; + if (!IsMixedPrecisionType(t, ignore_non_float)) return false; } return true; } else { @@ -227,42 +228,52 @@ class AmpGraphCreator : public ExprMutator { } public: - explicit AmpGraphCreator(ColorFunc colorer, OutputDtypeFunc output_dtype_func) - : ExprMutator(), colorer(colorer), output_dtype_func(output_dtype_func) {} + explicit AmpGraphCreator(ColorFunc colorer, OutputDtypeFunc output_dtype_func, + DataType mixed_precision_type = DataType::Float(16)) + : ExprMutator(), + colorer(colorer), + output_dtype_func(output_dtype_func), + mixed_precision_type(mixed_precision_type) { + if (!mixed_precision_type.is_float() && !mixed_precision_type.is_bfloat16()) + LOG(FATAL) << "Only support IEEE floating point mixed precision types and bfloat 16 got " + << mixed_precision_type; + } Expr VisitExpr_(const CallNode* call_node) { - FP16ConversionCategory initial_color = colorer(call_node); + MixedTypeConversionCategory initial_color = colorer(call_node); auto new_op = this->Mutate(call_node->op); - // Mutate arguments to FP16 form first if possible and keep track of whether all floating point - // tensors are in FP16 form already. This is useful for propagating color. + // Mutate arguments to reduced precision form first if possible and keep track of + // whether all floating point tensors are in reduced precision form already. This is + // useful for propagating conversion conditions. Array new_args; Array new_arg_types; - bool all_args_fp16_compatible = true; + bool all_args_mixed_type_compatible = true; for (Expr arg : call_node->args) { Expr new_arg = this->Mutate(arg); Type new_arg_type = GetType(new_arg); new_args.push_back(new_arg); new_arg_types.push_back(new_arg_type); - if (initial_color == GRAY && all_args_fp16_compatible) { + if (initial_color == GRAY && all_args_mixed_type_compatible) { // We can cast Vars and Constants to the right types so don't care about the types. - bool is_fp16_compatible = IsFP16Type(new_arg_type, true) || arg->IsInstance() || - arg->IsInstance(); - all_args_fp16_compatible &= is_fp16_compatible; + bool is_mixed_type_compatible = IsMixedPrecisionType(new_arg_type, true) || + arg->IsInstance() || + arg->IsInstance(); + all_args_mixed_type_compatible &= is_mixed_type_compatible; } } // Determine the final color. - FP16ConversionCategory final_color; + MixedTypeConversionCategory final_color; if (initial_color == GRAY) { - final_color = all_args_fp16_compatible ? GREEN : RED; + final_color = all_args_mixed_type_compatible ? GREEN : RED; } else { final_color = initial_color; } // Create the new arguments to the call. - DataType wanted_arg_dtypes = final_color == GREEN ? DataType::Float(16) : DataType::Float(32); + DataType wanted_arg_dtypes = final_color == GREEN ? mixed_precision_type : DataType::Float(32); auto call_args_and_types = CastAllArgs(new_args, new_arg_types, wanted_arg_dtypes); Array call_args = call_args_and_types.first; @@ -277,7 +288,7 @@ class AmpGraphCreator : public ExprMutator { // Finally create the new attributes. if (final_color == GREEN) { - FP16OpDType output_dtypes = output_dtype_func(call_node); + MixedPrecisionOpOutDType output_dtypes = output_dtype_func(call_node); Attrs new_attrs = GetNewAttrs(call_node, output_dtypes.accumulation_dtype); Expr output = Call(new_op, call_args, new_attrs, call_arg_types, call_node->span); @@ -297,7 +308,7 @@ class AmpGraphCreator : public ExprMutator { } Expr VisitExpr_(const LetNode* op) final { - // First convert as much of the bound computation to FP16 as possible + // First convert as much of the bound computation to lower precision as possible Expr value = this->Mutate(op->value); // Then rewrite the var type and associated expression @@ -323,10 +334,10 @@ namespace transform { Pass AMPRewrite() { runtime::TypedPackedFunc pass_func = [=](Function f, IRModule m, PassContext pc) { - return Downcast( - AMPRewriteGraph(f, DefaultFP16Colorer(), DefaultFP16OpDefinition())); + return Downcast(AMPRewriteGraph(f, DefaultMixedPrecisionColorer(), + DefaultMixedPrecisionOpDefinition())); }; - return CreateFunctionPass(pass_func, 10, "RewriteFp16", {}); + return CreateFunctionPass(pass_func, 10, "AMPRewrite", {}); } TVM_REGISTER_GLOBAL("relay._transform.AMPRewrite").set_body_typed(AMPRewrite); diff --git a/src/relay/transforms/fp32_to_fp16.h b/src/relay/transforms/fp32_to_fp16.h index dafdab5dd94c..0d59fd82865f 100644 --- a/src/relay/transforms/fp32_to_fp16.h +++ b/src/relay/transforms/fp32_to_fp16.h @@ -37,15 +37,15 @@ namespace tvm { namespace relay { -struct FP16OpDType { +struct MixedPrecisionOpOutDType { DataType accumulation_dtype; DataType output_dtype; }; -// GREEN colored ops should always be done in FP16 due to the speed and memory savings -// GRAY colored ops can be done in FP16 but don't have speedups to justify a dedicated cast. -// RED colored ops should not be done in FP16 due to numerical reasons. -enum FP16ConversionCategory { RED, GRAY, GREEN }; +// GREEN colored ops should always be done in lower precision due to the speed and memory savings +// GRAY colored ops can be done in lower precision but don't have speedups to justify a dedicated +// cast. RED colored ops should not be done in lower precision due to numerical reasons. +enum MixedTypeConversionCategory { RED, GRAY, GREEN }; using OpStringSet = std::unordered_set; @@ -151,42 +151,40 @@ OpStringSet DEFAULT_RED_LIST({ "erf", }); -class DefaultFP16Colorer { +class DefaultMixedPrecisionColorer { /* The default class to initially color ops for conversion using lists. + Default lists are for NVidia Tensor Cores and FP16. Creates a callable which given a CallNode* returns the node's color. */ private: - std::unordered_map op_to_initial_color; + std::unordered_map op_to_initial_color; public: - DefaultFP16Colorer(OpStringSet red_list = DEFAULT_RED_LIST, - OpStringSet gray_list = DEFAULT_GRAY_LIST, - OpStringSet green_list = DEFAULT_GREEN_LIST) { - std::vector> lists_and_colors{ + DefaultMixedPrecisionColorer(OpStringSet red_list = DEFAULT_RED_LIST, + OpStringSet gray_list = DEFAULT_GRAY_LIST, + OpStringSet green_list = DEFAULT_GREEN_LIST) { + std::vector> lists_and_colors{ {red_list, RED}, {gray_list, GRAY}, {green_list, GREEN}}; for (auto list_and_color : lists_and_colors) { OpStringSet ops = list_and_color.first; - FP16ConversionCategory color = list_and_color.second; + MixedTypeConversionCategory color = list_and_color.second; for (std::string op_name : ops) { op_to_initial_color.insert({{op_name, color}}); } } } - FP16ConversionCategory operator()(const CallNode* call, bool ignore_missing = true) { + MixedTypeConversionCategory operator()(const CallNode* call, bool ignore_missing = true) { if (auto* op_node = (call->op).as()) { std::string op_name = op_node->name; auto color = op_to_initial_color.find(op_name); if (color == op_to_initial_color.end()) { - if (ignore_missing) { - LOG(WARNING) << "Op name " << op_name << " not in included in fp16 conversion lists!"; - return RED; - } else { - LOG(FATAL) << "Op name " << op_name << " not in included in fp16 lists!"; - } + (ignore_missing ? LOG(WARNING) : LOG(FATAL)) + << "Op name " << op_name << " not in included in conversion lists!"; + return RED; } return color->second; @@ -194,24 +192,36 @@ class DefaultFP16Colorer { // Make RED to avoid messing with function headers. return RED; } else { - LOG(FATAL) << "FP16 conversion only supports call nodes with OpNodes or Functions got " + LOG(FATAL) << "Conversion only supports call nodes with OpNodes or Functions got " << call->op; return RED; } } }; -class DefaultFP16OpDefinition { - /* The default callable for determining accumulation_dtypes for ops. */ +class DefaultMixedPrecisionOpDefinition { + /* The default callable for determining accumulation_dtypes for ops. + + Assumes accumulatable operations accumulate to one type and outputs are + all of the same type.*/ + + const DataType default_output_dtype; + const DataType default_accumulation_dtype; + public: - FP16OpDType operator()(const CallNode* call) { + DefaultMixedPrecisionOpDefinition(DataType default_output_dtype = DataType::Float(16), + DataType default_accumulation_dtype = DataType::Float(32)) + : default_output_dtype(default_output_dtype), + default_accumulation_dtype(default_accumulation_dtype) {} + + MixedPrecisionOpOutDType operator()(const CallNode* call) { // TODO(AndrewZhaoLuo): remove when batch_matmul handles accumulation dtypes well. // Batched matmul has inconsistent support for mixed precision operations. // Many schedules ignore the out_dtype attribute which leads to errors when - // input types do not match the out_dtype. Therefore, accumulate to fp16 if green. + // input types do not match the out_dtype. Therefore, accumulate to output_dtype if green. if (auto op_node = call->op.as()) { if (op_node->name == "nn.batch_matmul") { - return {DataType::Float(16), DataType::Float(16)}; + return {default_output_dtype, default_output_dtype}; } } @@ -219,11 +229,12 @@ class DefaultFP16OpDefinition { if (call->attrs != NullValue()) { Array fields = call->attrs->ListFieldInfo(); for (AttrFieldInfo field_info : fields) { - if (field_info->name == "out_dtype") return {DataType::Float(32), DataType::Float(16)}; + if (field_info->name == "out_dtype") + return {default_accumulation_dtype, default_output_dtype}; } } - return {DataType::Float(16), DataType::Float(16)}; + return {default_output_dtype, default_output_dtype}; } }; From 4911d4fefa0aa56de72c75ee84582fd42dd8e885 Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Wed, 9 Jun 2021 14:13:21 -0700 Subject: [PATCH 24/59] rename RED, GREEN, GRAY to MIXED_PRECISION_ALLOW, etc. --- src/relay/transforms/fp32_to_fp16.cc | 22 +++++++------- src/relay/transforms/fp32_to_fp16.h | 43 ++++++++++++++++------------ 2 files changed, 37 insertions(+), 28 deletions(-) diff --git a/src/relay/transforms/fp32_to_fp16.cc b/src/relay/transforms/fp32_to_fp16.cc index b9eb439806a3..a3ed87fd7294 100644 --- a/src/relay/transforms/fp32_to_fp16.cc +++ b/src/relay/transforms/fp32_to_fp16.cc @@ -54,7 +54,7 @@ using CachedCastNodes = std::unordered_map, // A function which maps CallNodes to their initial conversion color using ColorFunc = std::function; -// A function which maps green CallNodes to wanted accumulation and output dtypes +// A function which maps MIXED_PRECISION_ALWAYS CallNodes to wanted accumulation and output dtypes using OutputDtypeFunc = std::function; class AmpGraphCreator : public ExprMutator { @@ -240,7 +240,7 @@ class AmpGraphCreator : public ExprMutator { } Expr VisitExpr_(const CallNode* call_node) { - MixedTypeConversionCategory initial_color = colorer(call_node); + MixedTypeConversionCategory initial_category = colorer(call_node); auto new_op = this->Mutate(call_node->op); // Mutate arguments to reduced precision form first if possible and keep track of @@ -255,7 +255,7 @@ class AmpGraphCreator : public ExprMutator { new_args.push_back(new_arg); new_arg_types.push_back(new_arg_type); - if (initial_color == GRAY && all_args_mixed_type_compatible) { + if (initial_category == MIXED_PRECISION_FOLLOW && all_args_mixed_type_compatible) { // We can cast Vars and Constants to the right types so don't care about the types. bool is_mixed_type_compatible = IsMixedPrecisionType(new_arg_type, true) || arg->IsInstance() || @@ -264,16 +264,18 @@ class AmpGraphCreator : public ExprMutator { } } - // Determine the final color. - MixedTypeConversionCategory final_color; - if (initial_color == GRAY) { - final_color = all_args_mixed_type_compatible ? GREEN : RED; + // Determine the final category for conversion. + MixedTypeConversionCategory final_category; + if (initial_category == MIXED_PRECISION_FOLLOW) { + final_category = + all_args_mixed_type_compatible ? MIXED_PRECISION_ALWAYS : MIXED_PRECISION_NEVER; } else { - final_color = initial_color; + final_category = initial_category; } // Create the new arguments to the call. - DataType wanted_arg_dtypes = final_color == GREEN ? mixed_precision_type : DataType::Float(32); + DataType wanted_arg_dtypes = + final_category == MIXED_PRECISION_ALWAYS ? mixed_precision_type : DataType::Float(32); auto call_args_and_types = CastAllArgs(new_args, new_arg_types, wanted_arg_dtypes); Array call_args = call_args_and_types.first; @@ -287,7 +289,7 @@ class AmpGraphCreator : public ExprMutator { } // Finally create the new attributes. - if (final_color == GREEN) { + if (final_category == MIXED_PRECISION_ALWAYS) { MixedPrecisionOpOutDType output_dtypes = output_dtype_func(call_node); Attrs new_attrs = GetNewAttrs(call_node, output_dtypes.accumulation_dtype); diff --git a/src/relay/transforms/fp32_to_fp16.h b/src/relay/transforms/fp32_to_fp16.h index 0d59fd82865f..3f4ccdc8d05f 100644 --- a/src/relay/transforms/fp32_to_fp16.h +++ b/src/relay/transforms/fp32_to_fp16.h @@ -42,17 +42,22 @@ struct MixedPrecisionOpOutDType { DataType output_dtype; }; -// GREEN colored ops should always be done in lower precision due to the speed and memory savings -// GRAY colored ops can be done in lower precision but don't have speedups to justify a dedicated -// cast. RED colored ops should not be done in lower precision due to numerical reasons. -enum MixedTypeConversionCategory { RED, GRAY, GREEN }; +// MIXED_PRECISION_ALWAYS ops should always be done in lower precision due to the speed and memory +// savings. MIXED_PRECISION_FOLLOW ops can be done in lower precision but don't have speedups to +// justify a cast. MIXED_PRECISION_NEVER colored ops should not be done in lower precision due to +// numerical reasons. +enum MixedTypeConversionCategory { + MIXED_PRECISION_ALWAYS, + MIXED_PRECISION_FOLLOW, + MIXED_PRECISION_NEVER +}; using OpStringSet = std::unordered_set; // Default lists inspired from TF's classifications: // github.com/tensorflow/tensorflow/blob/v2.5.0/tensorflow/core/grappler/optimizers/auto_mixed_precision_lists.h // They have a bias toward Nvidia Tensor Cores so modify lists per your hardware choice. -OpStringSet DEFAULT_GREEN_LIST({ +OpStringSet DEFAULT_ALWAYS_LIST({ "nn.conv1d", "nn.conv2d", "nn.conv3d", @@ -62,7 +67,7 @@ OpStringSet DEFAULT_GREEN_LIST({ "nn.dense", "nn.batch_matmul", }); -OpStringSet DEFAULT_GRAY_LIST({ +OpStringSet DEFAULT_FOLLOW_LIST({ // These ops add new data or change shape "nn.pad", "nn.batch_flatten", @@ -91,7 +96,7 @@ OpStringSet DEFAULT_GRAY_LIST({ "greater", "less_equal", "greater_equal", - // By definition copy and cast will become green or red based on inputs + // By definition copy and cast will depend on inputs for output. "copy", "cast", "cast_like", @@ -138,7 +143,7 @@ OpStringSet DEFAULT_GRAY_LIST({ "nn.adaptive_avg_pool2d", "nn.adaptive_avg_pool3d", }); -OpStringSet DEFAULT_RED_LIST({ +OpStringSet DEFAULT_NEVER_LIST({ // In general if |f(x)| >> |x| for expected inputs then put the op here. "exp", "power", @@ -147,7 +152,7 @@ OpStringSet DEFAULT_RED_LIST({ "nn.softmax", "nn.l2_normalize", // Error function doesn't seem to be able to be lowered into fp16 version in llvm. - // Move to gray list when it does. + // Move to follow list when it does. "erf", }); @@ -161,11 +166,13 @@ class DefaultMixedPrecisionColorer { std::unordered_map op_to_initial_color; public: - DefaultMixedPrecisionColorer(OpStringSet red_list = DEFAULT_RED_LIST, - OpStringSet gray_list = DEFAULT_GRAY_LIST, - OpStringSet green_list = DEFAULT_GREEN_LIST) { + DefaultMixedPrecisionColorer(OpStringSet never_list = DEFAULT_NEVER_LIST, + OpStringSet follow_list = DEFAULT_FOLLOW_LIST, + OpStringSet always_list = DEFAULT_ALWAYS_LIST) { std::vector> lists_and_colors{ - {red_list, RED}, {gray_list, GRAY}, {green_list, GREEN}}; + {never_list, MIXED_PRECISION_NEVER}, + {follow_list, MIXED_PRECISION_FOLLOW}, + {always_list, MIXED_PRECISION_ALWAYS}}; for (auto list_and_color : lists_and_colors) { OpStringSet ops = list_and_color.first; @@ -184,17 +191,17 @@ class DefaultMixedPrecisionColorer { if (color == op_to_initial_color.end()) { (ignore_missing ? LOG(WARNING) : LOG(FATAL)) << "Op name " << op_name << " not in included in conversion lists!"; - return RED; + return MIXED_PRECISION_NEVER; } return color->second; } else if ((call->op).as()) { - // Make RED to avoid messing with function headers. - return RED; + // Make MIXED_PRECISION_NEVER to avoid messing with function headers. + return MIXED_PRECISION_NEVER; } else { LOG(FATAL) << "Conversion only supports call nodes with OpNodes or Functions got " << call->op; - return RED; + return MIXED_PRECISION_NEVER; } } }; @@ -218,7 +225,7 @@ class DefaultMixedPrecisionOpDefinition { // TODO(AndrewZhaoLuo): remove when batch_matmul handles accumulation dtypes well. // Batched matmul has inconsistent support for mixed precision operations. // Many schedules ignore the out_dtype attribute which leads to errors when - // input types do not match the out_dtype. Therefore, accumulate to output_dtype if green. + // input types do not match the out_dtype. Therefore, accumulate to output_dtype. if (auto op_node = call->op.as()) { if (op_node->name == "nn.batch_matmul") { return {default_output_dtype, default_output_dtype}; From 47c2cf8682c894ecbfcd1548fe49b9e562572c47 Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Wed, 9 Jun 2021 15:40:23 -0700 Subject: [PATCH 25/59] skeleton for supporting arbitrary mixed types --- python/tvm/relay/transform/transform.py | 4 +- src/relay/transforms/fp32_to_fp16.cc | 19 +++++--- .../relay/test_fp32_to_fp16_transform.py | 48 +++++++++++++++++-- 3 files changed, 58 insertions(+), 13 deletions(-) diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index 26d62fe37947..7f81ac341541 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -1200,7 +1200,7 @@ def FakeQuantizationToInteger(): return _ffi_api.FakeQuantizationToInteger() -def AMPRewrite(): +def AMPRewrite(mixed_precision_type="float16"): """ Automatic mixed precision rewriter. Rewrite an FP32 relay graph into a version where as many operations as possible are in FP16. @@ -1214,4 +1214,4 @@ def AMPRewrite(): ret : tvm.transform.Pass The registered RewriteFP16 pass. """ - return _ffi_api.AMPRewrite() + return _ffi_api.AMPRewrite(mixed_precision_type) diff --git a/src/relay/transforms/fp32_to_fp16.cc b/src/relay/transforms/fp32_to_fp16.cc index a3ed87fd7294..71e16b8b3ea1 100644 --- a/src/relay/transforms/fp32_to_fp16.cc +++ b/src/relay/transforms/fp32_to_fp16.cc @@ -235,7 +235,7 @@ class AmpGraphCreator : public ExprMutator { output_dtype_func(output_dtype_func), mixed_precision_type(mixed_precision_type) { if (!mixed_precision_type.is_float() && !mixed_precision_type.is_bfloat16()) - LOG(FATAL) << "Only support IEEE floating point mixed precision types and bfloat 16 got " + LOG(FATAL) << "Only support IEEE floating point mixed precision types and bfloat16 got " << mixed_precision_type; } @@ -276,6 +276,7 @@ class AmpGraphCreator : public ExprMutator { // Create the new arguments to the call. DataType wanted_arg_dtypes = final_category == MIXED_PRECISION_ALWAYS ? mixed_precision_type : DataType::Float(32); + auto call_args_and_types = CastAllArgs(new_args, new_arg_types, wanted_arg_dtypes); Array call_args = call_args_and_types.first; @@ -326,18 +327,22 @@ class AmpGraphCreator : public ExprMutator { }; Expr AMPRewriteGraph(const Expr& expr, const ColorFunc& colorer, - const OutputDtypeFunc& output_dtype_func) { - AmpGraphCreator converter = AmpGraphCreator(colorer, output_dtype_func); - return converter.Mutate(expr); + const OutputDtypeFunc& output_dtype_func, + const DataType& mixed_precision_type) { + AmpGraphCreator converter = AmpGraphCreator(colorer, output_dtype_func, mixed_precision_type); + auto result = converter.Mutate(expr); + return result; } namespace transform { -Pass AMPRewrite() { +Pass AMPRewrite(DataType mixed_precision_type) { runtime::TypedPackedFunc pass_func = [=](Function f, IRModule m, PassContext pc) { - return Downcast(AMPRewriteGraph(f, DefaultMixedPrecisionColorer(), - DefaultMixedPrecisionOpDefinition())); + return Downcast(AMPRewriteGraph( + f, DefaultMixedPrecisionColorer(), + DefaultMixedPrecisionOpDefinition(mixed_precision_type, DataType::Float(32)), + mixed_precision_type)); }; return CreateFunctionPass(pass_func, 10, "AMPRewrite", {}); } diff --git a/tests/python/relay/test_fp32_to_fp16_transform.py b/tests/python/relay/test_fp32_to_fp16_transform.py index be01ef5c80bc..39bea90c4325 100644 --- a/tests/python/relay/test_fp32_to_fp16_transform.py +++ b/tests/python/relay/test_fp32_to_fp16_transform.py @@ -22,8 +22,7 @@ import tvm from tvm import relay from tvm.relay.testing import lstm -from tvm.relay.transform import AMPRewrite -from tvm.relay.transform.transform import InferType +from tvm.relay.transform import AMPRewrite, InferType def run_module(mod: tvm.runtime.Module, mod_params: Dict[str, Any]) -> List: @@ -38,11 +37,15 @@ def run_module(mod: tvm.runtime.Module, mod_params: Dict[str, Any]) -> List: def verify_fp32_fp16_output_close( - mod: tvm.runtime.Module, mod_params: Dict[str, Any], rtol: float = 1e-3, atol: float = 0 + mod: tvm.runtime.Module, + mod_params: Dict[str, Any], + mixed_precision_dtype="float16", + rtol: float = 1e-3, + atol: float = 0, ) -> tvm.runtime.Module: mod = InferType()(mod) result_fp32 = run_module(mod, mod_params) - fp16_mod = AMPRewrite()(mod) + fp16_mod = AMPRewrite(mixed_precision_dtype)(mod) result_fp16 = run_module(fp16_mod, mod_params) # Ensure the results are close @@ -108,6 +111,43 @@ def test_convert_single_conv(): assert tvm.ir.structural_equal(fp16_mod, expected_mod) +def test_convert_single_conv_bfloat16(): + """Stuff""" + data_shape = (1, 3, 32, 32) + weight_shape = (5, 3, 3, 3) + data = relay.var("data", shape=data_shape, dtype="float32") + weight = relay.var("weight", shape=weight_shape, dtype="float32") + conv = relay.nn.conv2d(data, weight, strides=(1, 1), padding=(1, 1), out_dtype="float32") + mod = tvm.IRModule.from_expr(conv) + mod = tvm.relay.transform.InferType()(mod) + + mod_params = { + "data": np.random.uniform(-1, 1, size=data_shape).astype("float32"), + "weight": np.random.uniform(-1, 1, size=weight_shape).astype("float32"), + } + + fp16_mod = verify_fp32_fp16_output_close( + mod, mod_params, mixed_precision_dtype="bfloat16", atol=0.01, rtol=1e-3 + ) + + expected_mod = tvm.IRModule.from_expr( + relay.cast( + relay.nn.conv2d( + relay.cast(data, "float16"), + relay.cast(weight, "float16"), + strides=(1, 1), + padding=(1, 1), + out_dtype="float32", + ), + "float16", + ) + ) + expected_mod = tvm.relay.transform.InferType()(expected_mod) + + assert not tvm.ir.structural_equal(fp16_mod, mod) + assert tvm.ir.structural_equal(fp16_mod, expected_mod) + + def test_convert_conv_bn(): """Conv is green and batch norm is gray. As Conv should output fp16 batch_norm should be green.""" data_shape = (1, 3, 32, 32) From 239dbfbcff4da5681edd38132df0ad69d84e1b6d Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Wed, 9 Jun 2021 19:57:01 -0700 Subject: [PATCH 26/59] cool tests --- .../relay/test_fp32_to_fp16_transform.py | 76 ++++++++----------- 1 file changed, 31 insertions(+), 45 deletions(-) diff --git a/tests/python/relay/test_fp32_to_fp16_transform.py b/tests/python/relay/test_fp32_to_fp16_transform.py index 39bea90c4325..cfbc64ca9cf5 100644 --- a/tests/python/relay/test_fp32_to_fp16_transform.py +++ b/tests/python/relay/test_fp32_to_fp16_transform.py @@ -36,7 +36,7 @@ def run_module(mod: tvm.runtime.Module, mod_params: Dict[str, Any]) -> List: return [result.asnumpy()] -def verify_fp32_fp16_output_close( +def verify_mixed_precision_output_close( mod: tvm.runtime.Module, mod_params: Dict[str, Any], mixed_precision_dtype="float16", @@ -71,48 +71,37 @@ def test_lstm(): -10, 10, (1, units) ).astype("float32") - verify_fp32_fp16_output_close(mod, mod_params, rtol=0.01, atol=0.01) + verify_mixed_precision_output_close(mod, mod_params, rtol=0.01, atol=0.01) -def test_convert_single_conv(): - """Conv is a green listed operation meaning it will always use fp16 workload. +def test_lstm_float64(): + """Tests if can handle other mixed precision types. - By default it accumulates to fp32 and outputs fp16. + As a toy example show can convert graph to float64 and have it run. + + It doesn't really make sense to do it, this just shows you can. """ - data_shape = (1, 3, 32, 32) - weight_shape = (5, 3, 3, 3) - data = relay.var("data", shape=data_shape, dtype="float32") - weight = relay.var("weight", shape=weight_shape, dtype="float32") - conv = relay.nn.conv2d(data, weight, strides=(1, 1), padding=(1, 1), out_dtype="float32") - mod = tvm.IRModule.from_expr(conv) - mod = tvm.relay.transform.InferType()(mod) + units = 3 + iterations = 5 + mod, mod_params = lstm.get_workload(iterations=iterations, num_hidden=units) - mod_params = { - "data": np.random.uniform(-1, 1, size=data_shape).astype("float32"), - "weight": np.random.uniform(-1, 1, size=weight_shape).astype("float32"), - } - fp16_mod = verify_fp32_fp16_output_close(mod, mod_params, atol=0.01, rtol=1e-3) + # This is an unrolled lstm so each data should be the previous results but + # we don't care, we just want to stress test things. + for i in range(iterations): + mod_params["data" if i == 0 else f"data{i}"] = np.random.uniform( + -10, 10, (1, units) + ).astype("float32") - expected_mod = tvm.IRModule.from_expr( - relay.cast( - relay.nn.conv2d( - relay.cast(data, "float16"), - relay.cast(weight, "float16"), - strides=(1, 1), - padding=(1, 1), - out_dtype="float32", - ), - "float16", - ) + verify_mixed_precision_output_close( + mod, mod_params, mixed_precision_dtype="float64", rtol=0.01, atol=0.01 ) - expected_mod = tvm.relay.transform.InferType()(expected_mod) - assert not tvm.ir.structural_equal(fp16_mod, mod) - assert tvm.ir.structural_equal(fp16_mod, expected_mod) +def test_convert_single_conv(): + """Conv is a green listed operation meaning it will always use fp16 workload. -def test_convert_single_conv_bfloat16(): - """Stuff""" + By default it accumulates to fp32 and outputs fp16. + """ data_shape = (1, 3, 32, 32) weight_shape = (5, 3, 3, 3) data = relay.var("data", shape=data_shape, dtype="float32") @@ -125,10 +114,7 @@ def test_convert_single_conv_bfloat16(): "data": np.random.uniform(-1, 1, size=data_shape).astype("float32"), "weight": np.random.uniform(-1, 1, size=weight_shape).astype("float32"), } - - fp16_mod = verify_fp32_fp16_output_close( - mod, mod_params, mixed_precision_dtype="bfloat16", atol=0.01, rtol=1e-3 - ) + fp16_mod = verify_mixed_precision_output_close(mod, mod_params, atol=0.01, rtol=1e-3) expected_mod = tvm.IRModule.from_expr( relay.cast( @@ -173,7 +159,7 @@ def test_convert_conv_bn(): "moving_mean": np.random.uniform(-1, 1, size=bn_shape).astype("float32"), "moving_var": np.random.uniform(-1, 1, size=bn_shape).astype("float32"), } - fp16_mod = verify_fp32_fp16_output_close(mod, mod_params, atol=0.01, rtol=1e-3) + fp16_mod = verify_mixed_precision_output_close(mod, mod_params, atol=0.01, rtol=1e-3) # Creating expected module data = relay.cast(relay.var("data", shape=data_shape), "float16") @@ -207,7 +193,7 @@ def test_do_not_convert_softmax(): mod_params = { "a": np.random.uniform(-1, 1, size=shape).astype("float32"), } - output_mod = verify_fp32_fp16_output_close(mod, mod_params, atol=0.0, rtol=0) + output_mod = verify_mixed_precision_output_close(mod, mod_params, atol=0.0, rtol=0) assert tvm.ir.structural_equal(mod, output_mod) @@ -229,7 +215,7 @@ def test_green_gray_propagates_simple(): "data": np.random.uniform(-1, 1, size=data_shape).astype("float32"), "weight": np.random.uniform(-1, 1, size=weight_shape).astype("float32"), } - fp16_mod = verify_fp32_fp16_output_close(mod, mod_params, atol=0.01, rtol=1e-3) + fp16_mod = verify_mixed_precision_output_close(mod, mod_params, atol=0.01, rtol=1e-3) conv_expr = relay.cast( relay.nn.conv2d( @@ -288,7 +274,7 @@ def test_green_red_not_use_extraneous_cast(): "data": np.random.uniform(-1, 1, size=data_shape).astype("float32"), "weight": np.random.uniform(-1, 1, size=weight_shape).astype("float32"), } - fp16_mod = verify_fp32_fp16_output_close(mod, mod_params, atol=0.01, rtol=1e-3) + fp16_mod = verify_mixed_precision_output_close(mod, mod_params, atol=0.01, rtol=1e-3) # Construct expected structure conv = relay.nn.conv2d( @@ -317,7 +303,7 @@ def test_red_gray_propagates_simple(): mod_params = { "a": np.random.uniform(-1, 1, size=shape).astype("float32"), } - output_mod = verify_fp32_fp16_output_close(mod, mod_params, atol=0.0, rtol=0.0) + output_mod = verify_mixed_precision_output_close(mod, mod_params, atol=0.0, rtol=0.0) assert tvm.ir.structural_equal(mod, output_mod) @@ -344,7 +330,7 @@ def test_let_statement_simple(): "data": np.random.uniform(-1, 1, size=[1, 20]).astype("float32"), "weight": np.random.uniform(-1, 1, size=[20, 20]).astype("float32"), } - output_mod = verify_fp32_fp16_output_close(mod, mod_params, atol=0.01, rtol=0.01) + output_mod = verify_mixed_precision_output_close(mod, mod_params, atol=0.01, rtol=0.01) # Construct expected structure var1 = relay.var("var1", shape=[1, 20], dtype="float16") @@ -380,7 +366,7 @@ def test_where_simple(): "weight": np.random.uniform(-1, 1, size=[20, 20]).astype("float32"), } - output_mod = verify_fp32_fp16_output_close(mod, mod_params, atol=0.01, rtol=0.01) + output_mod = verify_mixed_precision_output_close(mod, mod_params, atol=0.01, rtol=0.01) # Create expected module data = relay.cast(relay.var("data", shape=[1, 20]), "float16") @@ -407,7 +393,7 @@ def test_batch_matmul_simple(): "data": np.random.uniform(-1, 1, size=[1, 1, 20]).astype("float32"), "weight": np.random.uniform(-1, 1, size=[1, 20, 20]).astype("float32"), } - output_mod = verify_fp32_fp16_output_close(mod, mod_params, atol=0.01, rtol=0.01) + output_mod = verify_mixed_precision_output_close(mod, mod_params, atol=0.01, rtol=0.01) # Create expected module data = relay.cast(relay.var("data", shape=[1, 1, 20]), "float16") weight = relay.cast(relay.var("weight", shape=[1, 20, 20]), "float16") From 33e286f3dfc03b7fa36e5099fd1bfb4be7b30498 Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Thu, 10 Jun 2021 11:04:08 -0700 Subject: [PATCH 27/59] Using MixedModeMutator --- src/relay/transforms/fp32_to_fp16.cc | 58 ++++++++++++++-------------- 1 file changed, 29 insertions(+), 29 deletions(-) diff --git a/src/relay/transforms/fp32_to_fp16.cc b/src/relay/transforms/fp32_to_fp16.cc index 71e16b8b3ea1..77338aeadf7b 100644 --- a/src/relay/transforms/fp32_to_fp16.cc +++ b/src/relay/transforms/fp32_to_fp16.cc @@ -57,7 +57,7 @@ using ColorFunc = std::function; // A function which maps MIXED_PRECISION_ALWAYS CallNodes to wanted accumulation and output dtypes using OutputDtypeFunc = std::function; -class AmpGraphCreator : public ExprMutator { +class AmpGraphCreator : public MixedModeMutator { private: CachedCastNodes cast_nodes_cache; const ColorFunc colorer; @@ -228,9 +228,11 @@ class AmpGraphCreator : public ExprMutator { } public: + using MixedModeMutator::VisitExpr_; + explicit AmpGraphCreator(ColorFunc colorer, OutputDtypeFunc output_dtype_func, DataType mixed_precision_type = DataType::Float(16)) - : ExprMutator(), + : MixedModeMutator(), colorer(colorer), output_dtype_func(output_dtype_func), mixed_precision_type(mixed_precision_type) { @@ -239,32 +241,32 @@ class AmpGraphCreator : public ExprMutator { << mixed_precision_type; } - Expr VisitExpr_(const CallNode* call_node) { - MixedTypeConversionCategory initial_category = colorer(call_node); - auto new_op = this->Mutate(call_node->op); + Expr Rewrite_(const CallNode* pre_call_node, const Expr& post) final { + MixedTypeConversionCategory initial_category = colorer(pre_call_node); + const CallNode* post_call_node = post.as(); + if (!post_call_node) { + LOG(FATAL) << "Expected a CallNode for the rewrite got " << post; + } - // Mutate arguments to reduced precision form first if possible and keep track of - // whether all floating point tensors are in reduced precision form already. This is - // useful for propagating conversion conditions. - Array new_args; - Array new_arg_types; + Expr cur_op = post_call_node->op; + + // First check if all the new mutated args are in lower precision form + Array cur_arg_types; bool all_args_mixed_type_compatible = true; - for (Expr arg : call_node->args) { - Expr new_arg = this->Mutate(arg); - Type new_arg_type = GetType(new_arg); - new_args.push_back(new_arg); - new_arg_types.push_back(new_arg_type); + for (Expr arg : post_call_node->args) { + Type cur_arg_type = GetType(arg); + cur_arg_types.push_back(cur_arg_type); if (initial_category == MIXED_PRECISION_FOLLOW && all_args_mixed_type_compatible) { // We can cast Vars and Constants to the right types so don't care about the types. - bool is_mixed_type_compatible = IsMixedPrecisionType(new_arg_type, true) || + bool is_mixed_type_compatible = IsMixedPrecisionType(cur_arg_type, true) || arg->IsInstance() || arg->IsInstance(); all_args_mixed_type_compatible &= is_mixed_type_compatible; } } - // Determine the final category for conversion. + // Determine the final category we want for conversion MixedTypeConversionCategory final_category; if (initial_category == MIXED_PRECISION_FOLLOW) { final_category = @@ -276,32 +278,30 @@ class AmpGraphCreator : public ExprMutator { // Create the new arguments to the call. DataType wanted_arg_dtypes = final_category == MIXED_PRECISION_ALWAYS ? mixed_precision_type : DataType::Float(32); + auto call_args_and_types = CastAllArgs(post_call_node->args, cur_arg_types, wanted_arg_dtypes); + Array new_args = call_args_and_types.first; + Array new_arg_types; - auto call_args_and_types = CastAllArgs(new_args, new_arg_types, wanted_arg_dtypes); - - Array call_args = call_args_and_types.first; - Array call_arg_types; - - if (call_node->op.as()) { + if (pre_call_node->op.as()) { // Function Nodes don't store type info in the Call, it should be a [] - call_arg_types = call_node->type_args; + new_arg_types = pre_call_node->type_args; } else { - call_arg_types = call_args_and_types.second; + new_arg_types = call_args_and_types.second; } // Finally create the new attributes. if (final_category == MIXED_PRECISION_ALWAYS) { - MixedPrecisionOpOutDType output_dtypes = output_dtype_func(call_node); + MixedPrecisionOpOutDType output_dtypes = output_dtype_func(pre_call_node); - Attrs new_attrs = GetNewAttrs(call_node, output_dtypes.accumulation_dtype); - Expr output = Call(new_op, call_args, new_attrs, call_arg_types, call_node->span); + Attrs new_attrs = GetNewAttrs(pre_call_node, output_dtypes.accumulation_dtype); + Expr output = Call(cur_op, new_args, new_attrs, new_arg_types, pre_call_node->span); if (output_dtypes.accumulation_dtype != output_dtypes.output_dtype) { output = CastArg(output, GetType(output), output_dtypes.output_dtype); } return output; } - return Call(new_op, call_args, call_node->attrs, call_arg_types, call_node->span); + return Call(cur_op, new_args, pre_call_node->attrs, new_arg_types, pre_call_node->span); } Expr VisitExpr_(const FunctionNode* func) final { From 418f873a6ae46f19ff9aedf28c2c43b901a93cc1 Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Thu, 10 Jun 2021 12:09:11 -0700 Subject: [PATCH 28/59] rename things ToMixedPrecision --- python/tvm/relay/transform/transform.py | 4 ++-- src/relay/transforms/fp32_to_fp16.cc | 12 ++++++------ tests/python/relay/test_fp32_to_fp16_transform.py | 4 ++-- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index 7f81ac341541..7767cb898ef1 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -1200,7 +1200,7 @@ def FakeQuantizationToInteger(): return _ffi_api.FakeQuantizationToInteger() -def AMPRewrite(mixed_precision_type="float16"): +def ToMixedPrecision(mixed_precision_type="float16"): """ Automatic mixed precision rewriter. Rewrite an FP32 relay graph into a version where as many operations as possible are in FP16. @@ -1214,4 +1214,4 @@ def AMPRewrite(mixed_precision_type="float16"): ret : tvm.transform.Pass The registered RewriteFP16 pass. """ - return _ffi_api.AMPRewrite(mixed_precision_type) + return _ffi_api.ToMixedPrecision(mixed_precision_type) diff --git a/src/relay/transforms/fp32_to_fp16.cc b/src/relay/transforms/fp32_to_fp16.cc index 77338aeadf7b..fe37d4a57f07 100644 --- a/src/relay/transforms/fp32_to_fp16.cc +++ b/src/relay/transforms/fp32_to_fp16.cc @@ -57,7 +57,7 @@ using ColorFunc = std::function; // A function which maps MIXED_PRECISION_ALWAYS CallNodes to wanted accumulation and output dtypes using OutputDtypeFunc = std::function; -class AmpGraphCreator : public MixedModeMutator { +class AMPGraphCreator : public MixedModeMutator { private: CachedCastNodes cast_nodes_cache; const ColorFunc colorer; @@ -230,7 +230,7 @@ class AmpGraphCreator : public MixedModeMutator { public: using MixedModeMutator::VisitExpr_; - explicit AmpGraphCreator(ColorFunc colorer, OutputDtypeFunc output_dtype_func, + explicit AMPGraphCreator(ColorFunc colorer, OutputDtypeFunc output_dtype_func, DataType mixed_precision_type = DataType::Float(16)) : MixedModeMutator(), colorer(colorer), @@ -329,14 +329,14 @@ class AmpGraphCreator : public MixedModeMutator { Expr AMPRewriteGraph(const Expr& expr, const ColorFunc& colorer, const OutputDtypeFunc& output_dtype_func, const DataType& mixed_precision_type) { - AmpGraphCreator converter = AmpGraphCreator(colorer, output_dtype_func, mixed_precision_type); + AMPGraphCreator converter = AMPGraphCreator(colorer, output_dtype_func, mixed_precision_type); auto result = converter.Mutate(expr); return result; } namespace transform { -Pass AMPRewrite(DataType mixed_precision_type) { +Pass ToMixedPrecision(DataType mixed_precision_type) { runtime::TypedPackedFunc pass_func = [=](Function f, IRModule m, PassContext pc) { return Downcast(AMPRewriteGraph( @@ -344,10 +344,10 @@ Pass AMPRewrite(DataType mixed_precision_type) { DefaultMixedPrecisionOpDefinition(mixed_precision_type, DataType::Float(32)), mixed_precision_type)); }; - return CreateFunctionPass(pass_func, 10, "AMPRewrite", {}); + return CreateFunctionPass(pass_func, 10, "ToMixedPrecision", {}); } -TVM_REGISTER_GLOBAL("relay._transform.AMPRewrite").set_body_typed(AMPRewrite); +TVM_REGISTER_GLOBAL("relay._transform.ToMixedPrecision").set_body_typed(ToMixedPrecision); } // namespace transform diff --git a/tests/python/relay/test_fp32_to_fp16_transform.py b/tests/python/relay/test_fp32_to_fp16_transform.py index cfbc64ca9cf5..94f54ded55e4 100644 --- a/tests/python/relay/test_fp32_to_fp16_transform.py +++ b/tests/python/relay/test_fp32_to_fp16_transform.py @@ -22,7 +22,7 @@ import tvm from tvm import relay from tvm.relay.testing import lstm -from tvm.relay.transform import AMPRewrite, InferType +from tvm.relay.transform import InferType, ToMixedPrecision def run_module(mod: tvm.runtime.Module, mod_params: Dict[str, Any]) -> List: @@ -45,7 +45,7 @@ def verify_mixed_precision_output_close( ) -> tvm.runtime.Module: mod = InferType()(mod) result_fp32 = run_module(mod, mod_params) - fp16_mod = AMPRewrite(mixed_precision_dtype)(mod) + fp16_mod = ToMixedPrecision(mixed_precision_dtype)(mod) result_fp16 = run_module(fp16_mod, mod_params) # Ensure the results are close From 7d62fe1e9814461ccec088e7693323e7a0415819 Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Thu, 10 Jun 2021 12:17:52 -0700 Subject: [PATCH 29/59] rename passes to amp.cc --- src/relay/transforms/{fp32_to_fp16.cc => amp.cc} | 6 +++--- src/relay/transforms/{fp32_to_fp16.h => amp.h} | 10 +++++----- 2 files changed, 8 insertions(+), 8 deletions(-) rename src/relay/transforms/{fp32_to_fp16.cc => amp.cc} (99%) rename src/relay/transforms/{fp32_to_fp16.h => amp.h} (97%) diff --git a/src/relay/transforms/fp32_to_fp16.cc b/src/relay/transforms/amp.cc similarity index 99% rename from src/relay/transforms/fp32_to_fp16.cc rename to src/relay/transforms/amp.cc index fe37d4a57f07..6cc9e241b610 100644 --- a/src/relay/transforms/fp32_to_fp16.cc +++ b/src/relay/transforms/amp.cc @@ -19,10 +19,10 @@ /*! * - * \file fp32_to_fp16.cc - * \brief Rewrite a graph into an fp16 form. + * \file amp.cc + * \brief Automatic mixed precision for relay graphs. i.e. turn a graph into fp16 form. */ -#include "fp32_to_fp16.h" +#include "amp.h" #include #include diff --git a/src/relay/transforms/fp32_to_fp16.h b/src/relay/transforms/amp.h similarity index 97% rename from src/relay/transforms/fp32_to_fp16.h rename to src/relay/transforms/amp.h index 3f4ccdc8d05f..ae3c81facdea 100644 --- a/src/relay/transforms/fp32_to_fp16.h +++ b/src/relay/transforms/amp.h @@ -18,11 +18,11 @@ */ /*! - * \file fp32_to_fp16.h - * \brief Utilities and common types used for FP32->FP16 pass. + * \file amp.h + * \brief Utilities and common types used for automatic mixed precision pass. */ -#ifndef TVM_RELAY_TRANSFORMS_FP32_TO_FP16_H_ -#define TVM_RELAY_TRANSFORMS_FP32_TO_FP16_H_ +#ifndef TVM_RELAY_TRANSFORMS_AMP_H_ +#define TVM_RELAY_TRANSFORMS_AMP_H_ #include #include @@ -247,4 +247,4 @@ class DefaultMixedPrecisionOpDefinition { } // namespace relay } // namespace tvm -#endif // TVM_RELAY_TRANSFORMS_FP32_TO_FP16_H_ +#endif // TVM_RELAY_TRANSFORMS_AMP_H_ From b4ebd066239227c797b7217749291cc1c0d3bdcf Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Thu, 10 Jun 2021 12:25:51 -0700 Subject: [PATCH 30/59] rename tests to match transform --- ...{test_fp32_to_fp16_transform.py => test_to_mixed_precision.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/python/relay/{test_fp32_to_fp16_transform.py => test_to_mixed_precision.py} (100%) diff --git a/tests/python/relay/test_fp32_to_fp16_transform.py b/tests/python/relay/test_to_mixed_precision.py similarity index 100% rename from tests/python/relay/test_fp32_to_fp16_transform.py rename to tests/python/relay/test_to_mixed_precision.py From 8968cdab22fbce732c42105b13b8087608e1433a Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Thu, 10 Jun 2021 12:26:47 -0700 Subject: [PATCH 31/59] clean up typos --- tests/python/relay/test_to_mixed_precision.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/relay/test_to_mixed_precision.py b/tests/python/relay/test_to_mixed_precision.py index 94f54ded55e4..79fe605956bb 100644 --- a/tests/python/relay/test_to_mixed_precision.py +++ b/tests/python/relay/test_to_mixed_precision.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""Unit tests for testing FP32 -> FP16 pass""" +"""Unit tests for testing AMP pass""" from typing import Any, Dict, List import numpy as np From 180b556840ae26cb31dd7a64efb2e52178c3ec0f Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Thu, 10 Jun 2021 12:38:35 -0700 Subject: [PATCH 32/59] rename even better to_mixed_precision --- .../{amp.cc => to_mixed_precision.cc} | 21 ++++++++++--------- .../{amp.h => to_mixed_precision.h} | 8 +++---- 2 files changed, 15 insertions(+), 14 deletions(-) rename src/relay/transforms/{amp.cc => to_mixed_precision.cc} (95%) rename src/relay/transforms/{amp.h => to_mixed_precision.h} (97%) diff --git a/src/relay/transforms/amp.cc b/src/relay/transforms/to_mixed_precision.cc similarity index 95% rename from src/relay/transforms/amp.cc rename to src/relay/transforms/to_mixed_precision.cc index 6cc9e241b610..56add99db006 100644 --- a/src/relay/transforms/amp.cc +++ b/src/relay/transforms/to_mixed_precision.cc @@ -19,10 +19,10 @@ /*! * - * \file amp.cc + * \file to_mixed_precision.cc * \brief Automatic mixed precision for relay graphs. i.e. turn a graph into fp16 form. */ -#include "amp.h" +#include "to_mixed_precision.h" #include #include @@ -57,7 +57,7 @@ using ColorFunc = std::function; // A function which maps MIXED_PRECISION_ALWAYS CallNodes to wanted accumulation and output dtypes using OutputDtypeFunc = std::function; -class AMPGraphCreator : public MixedModeMutator { +class MixedPrecisionPass : public MixedModeMutator { private: CachedCastNodes cast_nodes_cache; const ColorFunc colorer; @@ -230,8 +230,8 @@ class AMPGraphCreator : public MixedModeMutator { public: using MixedModeMutator::VisitExpr_; - explicit AMPGraphCreator(ColorFunc colorer, OutputDtypeFunc output_dtype_func, - DataType mixed_precision_type = DataType::Float(16)) + explicit MixedPrecisionPass(ColorFunc colorer, OutputDtypeFunc output_dtype_func, + DataType mixed_precision_type = DataType::Float(16)) : MixedModeMutator(), colorer(colorer), output_dtype_func(output_dtype_func), @@ -326,10 +326,11 @@ class AMPGraphCreator : public MixedModeMutator { } }; -Expr AMPRewriteGraph(const Expr& expr, const ColorFunc& colorer, - const OutputDtypeFunc& output_dtype_func, - const DataType& mixed_precision_type) { - AMPGraphCreator converter = AMPGraphCreator(colorer, output_dtype_func, mixed_precision_type); +Expr ToMixedPrecision(const Expr& expr, const ColorFunc& colorer, + const OutputDtypeFunc& output_dtype_func, + const DataType& mixed_precision_type) { + MixedPrecisionPass converter = + MixedPrecisionPass(colorer, output_dtype_func, mixed_precision_type); auto result = converter.Mutate(expr); return result; } @@ -339,7 +340,7 @@ namespace transform { Pass ToMixedPrecision(DataType mixed_precision_type) { runtime::TypedPackedFunc pass_func = [=](Function f, IRModule m, PassContext pc) { - return Downcast(AMPRewriteGraph( + return Downcast(ToMixedPrecision( f, DefaultMixedPrecisionColorer(), DefaultMixedPrecisionOpDefinition(mixed_precision_type, DataType::Float(32)), mixed_precision_type)); diff --git a/src/relay/transforms/amp.h b/src/relay/transforms/to_mixed_precision.h similarity index 97% rename from src/relay/transforms/amp.h rename to src/relay/transforms/to_mixed_precision.h index ae3c81facdea..364a304e8252 100644 --- a/src/relay/transforms/amp.h +++ b/src/relay/transforms/to_mixed_precision.h @@ -18,11 +18,11 @@ */ /*! - * \file amp.h + * \file to_mixed_precision.h * \brief Utilities and common types used for automatic mixed precision pass. */ -#ifndef TVM_RELAY_TRANSFORMS_AMP_H_ -#define TVM_RELAY_TRANSFORMS_AMP_H_ +#ifndef TVM_RELAY_TRANSFORMS_TO_MIXED_PRECISION_H_ +#define TVM_RELAY_TRANSFORMS_TO_MIXED_PRECISION_H_ #include #include @@ -247,4 +247,4 @@ class DefaultMixedPrecisionOpDefinition { } // namespace relay } // namespace tvm -#endif // TVM_RELAY_TRANSFORMS_AMP_H_ +#endif // TVM_RELAY_TRANSFORMS_TO_MIXED_PRECISION_H_ From 528ef7be724c684ed10c9883b03b25e50dc8d53e Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Mon, 14 Jun 2021 10:28:46 -0700 Subject: [PATCH 33/59] don't insert into cache when dtypes equal --- src/relay/transforms/to_mixed_precision.cc | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/relay/transforms/to_mixed_precision.cc b/src/relay/transforms/to_mixed_precision.cc index 56add99db006..22b50b41971d 100644 --- a/src/relay/transforms/to_mixed_precision.cc +++ b/src/relay/transforms/to_mixed_precision.cc @@ -170,6 +170,10 @@ class MixedPrecisionPass : public MixedModeMutator { return expr; } + if (expr_dtype == wanted_dtype) { + return expr; + } + const ExprNode* expr_node = expr.as(); if (!expr_node) { LOG(FATAL) << "Non-expression node found in cast: " << expr; @@ -181,7 +185,7 @@ class MixedPrecisionPass : public MixedModeMutator { return search->second; } - Expr result = expr_dtype == wanted_dtype ? expr : Cast(expr, wanted_dtype); + Expr result = Cast(expr, wanted_dtype); cast_nodes_cache[{expr_node, wanted_dtype}] = result; // Reverse the cache result, e.g. if we want to reverse the cast simply point to original node From 5ca14629398c6c8e6616910076f2a7d5564217fd Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Mon, 14 Jun 2021 16:43:36 -0700 Subject: [PATCH 34/59] new python interface for registering ops --- python/tvm/relay/op/__init__.py | 1 + python/tvm/relay/op/op.py | 12 +- python/tvm/relay/transform/mixed_precision.py | 169 ++++++++++++ python/tvm/relay/transform/transform.py | 6 +- src/relay/transforms/to_mixed_precision.cc | 104 ++++++-- src/relay/transforms/to_mixed_precision.h | 250 ------------------ tests/python/relay/test_to_mixed_precision.py | 6 +- 7 files changed, 262 insertions(+), 286 deletions(-) create mode 100644 python/tvm/relay/transform/mixed_precision.py delete mode 100644 src/relay/transforms/to_mixed_precision.h diff --git a/python/tvm/relay/op/__init__.py b/python/tvm/relay/op/__init__.py index 4c693fe64ee0..2e509a111c4a 100644 --- a/python/tvm/relay/op/__init__.py +++ b/python/tvm/relay/op/__init__.py @@ -29,6 +29,7 @@ debug, register_external_compiler, register_fake_quantization_to_integer, + register_mixed_precision_conversion, ) from . import strategy diff --git a/python/tvm/relay/op/op.py b/python/tvm/relay/op/op.py index ccf011819a97..a5045a0bfca4 100644 --- a/python/tvm/relay/op/op.py +++ b/python/tvm/relay/op/op.py @@ -18,10 +18,11 @@ """The base node types for the Relay language.""" import tvm._ffi import tvm.ir -from tvm.driver import lower, build -from tvm.target import get_native_generic_func, GenericFunc -from tvm.runtime import Object import tvm.ir._ffi_api +from tvm.driver import build, lower +from tvm.runtime import Object +from tvm.target import GenericFunc, get_native_generic_func + from . import _make @@ -457,6 +458,11 @@ def register_fake_quantization_to_integer(op_name, func=None, level=10): return tvm.ir.register_op_attr(op_name, "FTVMFakeQuantizationToInteger", func, level) +def register_mixed_precision_conversion(op_name, func=None, level=10): + """TODO""" + return tvm.ir.register_op_attr(op_name, "FTVMMixedPrecisionConversionType", func, level) + + @tvm._ffi.register_func("relay.op.compiler._lower") def _lower(name, schedule, inputs, outputs): return lower(schedule, list(inputs) + list(outputs), name=name) diff --git a/python/tvm/relay/transform/mixed_precision.py b/python/tvm/relay/transform/mixed_precision.py new file mode 100644 index 000000000000..42cfbf098a1b --- /dev/null +++ b/python/tvm/relay/transform/mixed_precision.py @@ -0,0 +1,169 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http:#www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""TODO""" +import tvm +from tvm import relay + +from ..op import register_mixed_precision_conversion + +# Conversion types +MIXED_PRECISION_ALWAYS = 0 +MIXED_PRECISION_FOLLOW = 1 +MIXED_PRECISION_NEVER = 2 + + +# Functions for FTVMMixedPrecisionConversionType which +# Take in CallNodes and a DType and returns a conversion type, +# an accumulation dtype, and an output_dtype. +def get_generic_dtypes(call_node, mixed_precision_type): + # TODO: examine attributes + if hasattr(call_node.attrs, "out_dtype"): + return ["float32", mixed_precision_type] + + return [mixed_precision_type, mixed_precision_type] + + +def generic_always_op(call_node, mixed_precision_type): + return [MIXED_PRECISION_ALWAYS] + get_generic_dtypes(call_node, mixed_precision_type) + + +def generic_follow_op(call_node, mixed_precision_type): + return [MIXED_PRECISION_FOLLOW] + get_generic_dtypes(call_node, mixed_precision_type) + + +def generic_never_op(call_node, mixed_precision_type): + return [MIXED_PRECISION_NEVER] + get_generic_dtypes(call_node, mixed_precision_type) + + +# Default lists inspired from TF's classifications: +# github.com/tensorflow/tensorflow/blob/v2.5.0/tensorflow/core/grappler/optimizers/auto_mixed_precision_lists.h +# They have a bias toward Nvidia Tensor Cores so modify lists per your hardware choice. +DEFAULT_ALWAYS_LIST = [ + "nn.conv1d", + "nn.conv2d", + "nn.conv3d", + "nn.conv1d_transpose", + "nn.conv2d_transpose", + "nn.conv3d_transpose", + "nn.dense", + # "nn.batch_matmul", # Handled by a special case +] +DEFAULT_FOLLOW_LIST = [ + # These ops add new data or change shape + "nn.pad", + "nn.batch_flatten", + "concatenate", + "zeros", + "split", + "squeeze", + "transpose", + "expand_dims", + "reshape", + "dyn.reshape", + "broadcast_to_like", + "dyn.broadcast_to", + "strided_slice", + "dyn.strided_slice", + "take", + "argwhere", + "where", + "tile", + "dyn.tile", + "scatter", + "full", + "dyn.full", + # Comparison + "less", + "greater", + "less_equal", + "greater_equal", + # By definition copy and cast will depend on inputs for output. + "copy", + "cast", + "cast_like", + # Simple arithmetic + "add", + "subtract", + "multiply", + "divide", + "nn.bias_add", + "nn.batch_norm", + "sum", + "mean", + "sqrt", + "shape_of", + # Simple activations + "max", + "min", + "maximum", + "minimum", + "nn.relu", + "nn.leaky_relu", + "nn.prelu", + "nn.dropout", + # Complicated activations which saturate in a narrow range + "sigmoid", + "tanh", + # Pooling operations + "nn.max_pool1d", + "nn.max_pool2d", + "nn.max_pool3d", + "nn.avg_pool1d", + "nn.avg_pool2d", + "nn.avg_pool3d", + # "nn.global_max_pool1d", # does not exist yet + "nn.global_max_pool2d", + # "nn.global_max_pool3d", # does not exist yet + # "nn.global_avg_pool1d", # does not exist yet + "nn.global_avg_pool2d", + # "nn.global_avg_pool3d", # does not exist yet + "nn.adaptive_max_pool1d", + "nn.adaptive_max_pool2d", + "nn.adaptive_max_pool3d", + "nn.adaptive_avg_pool1d", + "nn.adaptive_avg_pool2d", + "nn.adaptive_avg_pool3d", +] +DEFAULT_NEVER_LIST = [ + # In general if |f(x)| >> |x| for expected inputs then put the op here. + "exp", + "power", + "nn.cross_entropy", + "nn.cross_entropy_with_logits", + "nn.softmax", + "nn.l2_normalize", + # Error function doesn't seem to be able to be lowered into fp16 version in llvm. + # Move to follow list when it does. + "erf", +] + + +def register_default_mixed_precision_attributes(): + for list_of_ops, func in zip( + [DEFAULT_ALWAYS_LIST, DEFAULT_FOLLOW_LIST, DEFAULT_NEVER_LIST], + [generic_always_op, generic_follow_op, generic_never_op], + ): + for op_name in list_of_ops: + register_mixed_precision_conversion(op_name, func=func) + + @register_mixed_precision_conversion("nn.batch_matmul") + def nn_batch_matmul(call_node, mixed_precision_type): + # TODO(AndrewZhaoLuo): remove when batch_matmul handles accumulation dtypes well. + # Batched matmul has inconsistent support for mixed precision operations. + # Many schedules ignore the out_dtype attribute which leads to errors when + # input types do not match the out_dtype. Therefore, accumulate to output_dtype. + return [MIXED_PRECISION_ALWAYS, "float16", "float16"] diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index 7767cb898ef1..c2e2aaf439b0 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -1200,7 +1200,9 @@ def FakeQuantizationToInteger(): return _ffi_api.FakeQuantizationToInteger() -def ToMixedPrecision(mixed_precision_type="float16"): +def ToMixedPrecision( + mixed_precision_type="float16", ignore_missing_ops=True, warn_missing_ops=True +): """ Automatic mixed precision rewriter. Rewrite an FP32 relay graph into a version where as many operations as possible are in FP16. @@ -1214,4 +1216,4 @@ def ToMixedPrecision(mixed_precision_type="float16"): ret : tvm.transform.Pass The registered RewriteFP16 pass. """ - return _ffi_api.ToMixedPrecision(mixed_precision_type) + return _ffi_api.ToMixedPrecision(mixed_precision_type, ignore_missing_ops, warn_missing_ops) diff --git a/src/relay/transforms/to_mixed_precision.cc b/src/relay/transforms/to_mixed_precision.cc index 22b50b41971d..23b59d463ad3 100644 --- a/src/relay/transforms/to_mixed_precision.cc +++ b/src/relay/transforms/to_mixed_precision.cc @@ -20,9 +20,9 @@ /*! * * \file to_mixed_precision.cc - * \brief Automatic mixed precision for relay graphs. i.e. turn a graph into fp16 form. + * \brief Automatic mixed floating point precision for relay graphs. i.e. turn a graph into fp16. + * */ -#include "to_mixed_precision.h" #include #include @@ -48,22 +48,40 @@ struct pair_hash { } }; +// MIXED_PRECISION_ALWAYS ops should always be done in lower precision due to the speed and memory +// savings. MIXED_PRECISION_FOLLOW ops can be done in lower precision but don't have speedups to +// justify a cast. MIXED_PRECISION_NEVER colored ops should not be done in lower precision due to +// numerical reasons. +enum MixedTypeConversionCategory : int { + MIXED_PRECISION_ALWAYS = 0, + MIXED_PRECISION_FOLLOW = 1, + MIXED_PRECISION_NEVER = 2 +}; + // A map of a parent node and a wanted dtype to existing nodes casted to the wanted dtype using CachedCastNodes = std::unordered_map, Expr, pair_hash>; -// A function which maps CallNodes to their initial conversion color -using ColorFunc = std::function; - -// A function which maps MIXED_PRECISION_ALWAYS CallNodes to wanted accumulation and output dtypes -using OutputDtypeFunc = std::function; +// Return array is of type : [MixedTypeConversionCategory (int), String, String] +// The fields are : [ConversionCategory, accumulation_datatype, output_datatype] +// Call is a call node, DataType is the mixed precision type +using FTVMMixedPrecisionConversionType = runtime::TypedPackedFunc( + const Call& call_node, const std::string& target_dtype_str)>; class MixedPrecisionPass : public MixedModeMutator { private: CachedCastNodes cast_nodes_cache; - const ColorFunc colorer; - const OutputDtypeFunc output_dtype_func; + + // The target datatype we want to convert to e.g. FP16 const DataType mixed_precision_type; + // If false, throws a fatal error if an op which is not registered with a + // FTVMMixedPrecisionConversionType is encountered. + bool ignore_missing_ops; + + // If true, emits a warning if an op which is not registered with a + // FTVMMixedPrecisionConversionType is encountered. + bool warn_missing_ops; + Attrs GetNewAttrs(const CallNode* call, const DataType& accumulation_dtype) const { /* If the accumulation dtype is in the attributes make a copy and mutate the field. */ Attrs new_attrs = Attrs(call->attrs); @@ -234,19 +252,18 @@ class MixedPrecisionPass : public MixedModeMutator { public: using MixedModeMutator::VisitExpr_; - explicit MixedPrecisionPass(ColorFunc colorer, OutputDtypeFunc output_dtype_func, - DataType mixed_precision_type = DataType::Float(16)) + explicit MixedPrecisionPass(DataType mixed_precision_type = DataType::Float(16), + bool ignore_missing_ops = true, bool warn_missing_ops = true) : MixedModeMutator(), - colorer(colorer), - output_dtype_func(output_dtype_func), - mixed_precision_type(mixed_precision_type) { + mixed_precision_type(mixed_precision_type), + ignore_missing_ops(ignore_missing_ops), + warn_missing_ops(warn_missing_ops) { if (!mixed_precision_type.is_float() && !mixed_precision_type.is_bfloat16()) LOG(FATAL) << "Only support IEEE floating point mixed precision types and bfloat16 got " << mixed_precision_type; } Expr Rewrite_(const CallNode* pre_call_node, const Expr& post) final { - MixedTypeConversionCategory initial_category = colorer(pre_call_node); const CallNode* post_call_node = post.as(); if (!post_call_node) { LOG(FATAL) << "Expected a CallNode for the rewrite got " << post; @@ -254,6 +271,39 @@ class MixedPrecisionPass : public MixedModeMutator { Expr cur_op = post_call_node->op; + // Results are: conversion category (int), accumulation dtype (str), output dtype (str) + MixedTypeConversionCategory initial_category; + DataType accumulation_dtype, output_dtype; + if (cur_op.as()) { + // Avoid messing with functions to avoid changing signature + initial_category = MIXED_PRECISION_NEVER; + accumulation_dtype = DataType::Float(32); + output_dtype = DataType::Float(32); + } else if (cur_op.as()) { + static auto attr_map = + Op::GetAttrMap("FTVMMixedPrecisionConversionType"); + Op op = Downcast(cur_op); + if (attr_map.count(op)) { + FTVMMixedPrecisionConversionType func = attr_map[op]; + Array op_descriptor = + func(GetRef(pre_call_node), DLDataType2String(mixed_precision_type)); + + int64_t op_conversion_type = Downcast(op_descriptor[0])->value; + initial_category = static_cast(op_conversion_type); + accumulation_dtype = DataType(String2DLDataType(Downcast(op_descriptor[1]))); + output_dtype = DataType(String2DLDataType(Downcast(op_descriptor[2]))); + } else { + if (!ignore_missing_ops) LOG(FATAL) << "Op " << op->name << " not in conversion lists!"; + if (warn_missing_ops) LOG(WARNING) << "Op " << op->name << " not in conversion lists!"; + + initial_category = MIXED_PRECISION_FOLLOW; + accumulation_dtype = DataType::Float(16); + output_dtype = DataType::Float(16); + } + } else { + LOG(FATAL) << "Unsupported op type in CallNode: " << pre_call_node->op; + } + // First check if all the new mutated args are in lower precision form Array cur_arg_types; bool all_args_mixed_type_compatible = true; @@ -295,12 +345,10 @@ class MixedPrecisionPass : public MixedModeMutator { // Finally create the new attributes. if (final_category == MIXED_PRECISION_ALWAYS) { - MixedPrecisionOpOutDType output_dtypes = output_dtype_func(pre_call_node); - - Attrs new_attrs = GetNewAttrs(pre_call_node, output_dtypes.accumulation_dtype); + Attrs new_attrs = GetNewAttrs(pre_call_node, accumulation_dtype); Expr output = Call(cur_op, new_args, new_attrs, new_arg_types, pre_call_node->span); - if (output_dtypes.accumulation_dtype != output_dtypes.output_dtype) { - output = CastArg(output, GetType(output), output_dtypes.output_dtype); + if (accumulation_dtype != output_dtype) { + output = CastArg(output, GetType(output), output_dtype); } return output; } @@ -330,24 +378,22 @@ class MixedPrecisionPass : public MixedModeMutator { } }; -Expr ToMixedPrecision(const Expr& expr, const ColorFunc& colorer, - const OutputDtypeFunc& output_dtype_func, - const DataType& mixed_precision_type) { +Expr ToMixedPrecision(const Expr& expr, const DataType& mixed_precision_type, + bool ignore_missing_ops, bool warn_missing_ops) { MixedPrecisionPass converter = - MixedPrecisionPass(colorer, output_dtype_func, mixed_precision_type); + MixedPrecisionPass(mixed_precision_type, ignore_missing_ops, warn_missing_ops); auto result = converter.Mutate(expr); return result; } namespace transform { -Pass ToMixedPrecision(DataType mixed_precision_type) { +Pass ToMixedPrecision(DataType mixed_precision_type, bool ignore_missing_ops, + bool warn_missing_ops) { runtime::TypedPackedFunc pass_func = [=](Function f, IRModule m, PassContext pc) { - return Downcast(ToMixedPrecision( - f, DefaultMixedPrecisionColorer(), - DefaultMixedPrecisionOpDefinition(mixed_precision_type, DataType::Float(32)), - mixed_precision_type)); + return Downcast( + ToMixedPrecision(f, mixed_precision_type, ignore_missing_ops, warn_missing_ops)); }; return CreateFunctionPass(pass_func, 10, "ToMixedPrecision", {}); } diff --git a/src/relay/transforms/to_mixed_precision.h b/src/relay/transforms/to_mixed_precision.h deleted file mode 100644 index 364a304e8252..000000000000 --- a/src/relay/transforms/to_mixed_precision.h +++ /dev/null @@ -1,250 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file to_mixed_precision.h - * \brief Utilities and common types used for automatic mixed precision pass. - */ -#ifndef TVM_RELAY_TRANSFORMS_TO_MIXED_PRECISION_H_ -#define TVM_RELAY_TRANSFORMS_TO_MIXED_PRECISION_H_ - -#include -#include -#include - -#include -#include -#include -#include -#include - -namespace tvm { -namespace relay { - -struct MixedPrecisionOpOutDType { - DataType accumulation_dtype; - DataType output_dtype; -}; - -// MIXED_PRECISION_ALWAYS ops should always be done in lower precision due to the speed and memory -// savings. MIXED_PRECISION_FOLLOW ops can be done in lower precision but don't have speedups to -// justify a cast. MIXED_PRECISION_NEVER colored ops should not be done in lower precision due to -// numerical reasons. -enum MixedTypeConversionCategory { - MIXED_PRECISION_ALWAYS, - MIXED_PRECISION_FOLLOW, - MIXED_PRECISION_NEVER -}; - -using OpStringSet = std::unordered_set; - -// Default lists inspired from TF's classifications: -// github.com/tensorflow/tensorflow/blob/v2.5.0/tensorflow/core/grappler/optimizers/auto_mixed_precision_lists.h -// They have a bias toward Nvidia Tensor Cores so modify lists per your hardware choice. -OpStringSet DEFAULT_ALWAYS_LIST({ - "nn.conv1d", - "nn.conv2d", - "nn.conv3d", - "nn.conv1d_transpose", - "nn.conv2d_transpose", - "nn.conv3d_transpose", - "nn.dense", - "nn.batch_matmul", -}); -OpStringSet DEFAULT_FOLLOW_LIST({ - // These ops add new data or change shape - "nn.pad", - "nn.batch_flatten", - "concatenate", - "zeros", - "split", - "squeeze", - "transpose", - "expand_dims", - "reshape", - "dyn.reshape", - "broadcast_to_like", - "dyn.broadcast_to", - "strided_slice", - "dyn.strided_slice", - "take", - "argwhere", - "where", - "tile", - "dyn.tile", - "scatter", - "full", - "dyn.full", - // Comparison - "less", - "greater", - "less_equal", - "greater_equal", - // By definition copy and cast will depend on inputs for output. - "copy", - "cast", - "cast_like", - // Simple arithmetic - "add", - "subtract", - "multiply", - "divide", - "nn.bias_add", - "nn.batch_norm", - "sum", - "mean", - "sqrt", - "shape_of", - // Simple activations - "max", - "min", - "maximum", - "minimum", - "nn.relu", - "nn.leaky_relu", - "nn.prelu", - "nn.dropout", - // Complicated activations which saturate in a narrow range - "sigmoid", - "tanh", - // Pooling operations - "nn.max_pool1d", - "nn.max_pool2d", - "nn.max_pool3d", - "nn.avg_pool1d", - "nn.avg_pool2d", - "nn.avg_pool3d", - // "nn.global_max_pool1d", // does not exist yet - "nn.global_max_pool2d", - // "nn.global_max_pool3d", // does not exist yet - // "nn.global_avg_pool1d", // does not exist yet - "nn.global_avg_pool2d", - // "nn.global_avg_pool3d", // does not exist yet - "nn.adaptive_max_pool1d", - "nn.adaptive_max_pool2d", - "nn.adaptive_max_pool3d", - "nn.adaptive_avg_pool1d", - "nn.adaptive_avg_pool2d", - "nn.adaptive_avg_pool3d", -}); -OpStringSet DEFAULT_NEVER_LIST({ - // In general if |f(x)| >> |x| for expected inputs then put the op here. - "exp", - "power", - "nn.cross_entropy", - "nn.cross_entropy_with_logits", - "nn.softmax", - "nn.l2_normalize", - // Error function doesn't seem to be able to be lowered into fp16 version in llvm. - // Move to follow list when it does. - "erf", -}); - -class DefaultMixedPrecisionColorer { - /* The default class to initially color ops for conversion using lists. - Default lists are for NVidia Tensor Cores and FP16. - - Creates a callable which given a CallNode* returns the node's color. - */ - private: - std::unordered_map op_to_initial_color; - - public: - DefaultMixedPrecisionColorer(OpStringSet never_list = DEFAULT_NEVER_LIST, - OpStringSet follow_list = DEFAULT_FOLLOW_LIST, - OpStringSet always_list = DEFAULT_ALWAYS_LIST) { - std::vector> lists_and_colors{ - {never_list, MIXED_PRECISION_NEVER}, - {follow_list, MIXED_PRECISION_FOLLOW}, - {always_list, MIXED_PRECISION_ALWAYS}}; - - for (auto list_and_color : lists_and_colors) { - OpStringSet ops = list_and_color.first; - MixedTypeConversionCategory color = list_and_color.second; - for (std::string op_name : ops) { - op_to_initial_color.insert({{op_name, color}}); - } - } - } - - MixedTypeConversionCategory operator()(const CallNode* call, bool ignore_missing = true) { - if (auto* op_node = (call->op).as()) { - std::string op_name = op_node->name; - auto color = op_to_initial_color.find(op_name); - - if (color == op_to_initial_color.end()) { - (ignore_missing ? LOG(WARNING) : LOG(FATAL)) - << "Op name " << op_name << " not in included in conversion lists!"; - return MIXED_PRECISION_NEVER; - } - - return color->second; - } else if ((call->op).as()) { - // Make MIXED_PRECISION_NEVER to avoid messing with function headers. - return MIXED_PRECISION_NEVER; - } else { - LOG(FATAL) << "Conversion only supports call nodes with OpNodes or Functions got " - << call->op; - return MIXED_PRECISION_NEVER; - } - } -}; - -class DefaultMixedPrecisionOpDefinition { - /* The default callable for determining accumulation_dtypes for ops. - - Assumes accumulatable operations accumulate to one type and outputs are - all of the same type.*/ - - const DataType default_output_dtype; - const DataType default_accumulation_dtype; - - public: - DefaultMixedPrecisionOpDefinition(DataType default_output_dtype = DataType::Float(16), - DataType default_accumulation_dtype = DataType::Float(32)) - : default_output_dtype(default_output_dtype), - default_accumulation_dtype(default_accumulation_dtype) {} - - MixedPrecisionOpOutDType operator()(const CallNode* call) { - // TODO(AndrewZhaoLuo): remove when batch_matmul handles accumulation dtypes well. - // Batched matmul has inconsistent support for mixed precision operations. - // Many schedules ignore the out_dtype attribute which leads to errors when - // input types do not match the out_dtype. Therefore, accumulate to output_dtype. - if (auto op_node = call->op.as()) { - if (op_node->name == "nn.batch_matmul") { - return {default_output_dtype, default_output_dtype}; - } - } - - // We assume the "out_dtype" field is always an accumulation dtype specification. - if (call->attrs != NullValue()) { - Array fields = call->attrs->ListFieldInfo(); - for (AttrFieldInfo field_info : fields) { - if (field_info->name == "out_dtype") - return {default_accumulation_dtype, default_output_dtype}; - } - } - - return {default_output_dtype, default_output_dtype}; - } -}; - -} // namespace relay -} // namespace tvm -#endif // TVM_RELAY_TRANSFORMS_TO_MIXED_PRECISION_H_ diff --git a/tests/python/relay/test_to_mixed_precision.py b/tests/python/relay/test_to_mixed_precision.py index 79fe605956bb..632df870f2ba 100644 --- a/tests/python/relay/test_to_mixed_precision.py +++ b/tests/python/relay/test_to_mixed_precision.py @@ -22,7 +22,9 @@ import tvm from tvm import relay from tvm.relay.testing import lstm -from tvm.relay.transform import InferType, ToMixedPrecision +from tvm.relay.transform import InferType, ToMixedPrecision, mixed_precision + +mixed_precision.register_default_mixed_precision_attributes() def run_module(mod: tvm.runtime.Module, mod_params: Dict[str, Any]) -> List: @@ -43,11 +45,11 @@ def verify_mixed_precision_output_close( rtol: float = 1e-3, atol: float = 0, ) -> tvm.runtime.Module: + mod = InferType()(mod) result_fp32 = run_module(mod, mod_params) fp16_mod = ToMixedPrecision(mixed_precision_dtype)(mod) result_fp16 = run_module(fp16_mod, mod_params) - # Ensure the results are close for fp32, fp16 in zip(result_fp32, result_fp16): np.testing.assert_allclose(fp32, fp16, rtol=rtol, atol=atol) From 9e77cff27f88e02a84dd089007efd802a8fd6272 Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Tue, 15 Jun 2021 00:17:53 -0700 Subject: [PATCH 35/59] cleaner registering ops --- python/tvm/relay/transform/mixed_precision.py | 82 +++++++++---------- tests/python/relay/test_to_mixed_precision.py | 2 - 2 files changed, 41 insertions(+), 43 deletions(-) diff --git a/python/tvm/relay/transform/mixed_precision.py b/python/tvm/relay/transform/mixed_precision.py index 42cfbf098a1b..d05f5030ecaa 100644 --- a/python/tvm/relay/transform/mixed_precision.py +++ b/python/tvm/relay/transform/mixed_precision.py @@ -15,40 +15,13 @@ # specific language governing permissions and limitations # under the License. """TODO""" -import tvm -from tvm import relay - -from ..op import register_mixed_precision_conversion +from tvm.relay.op import register_mixed_precision_conversion # Conversion types MIXED_PRECISION_ALWAYS = 0 MIXED_PRECISION_FOLLOW = 1 MIXED_PRECISION_NEVER = 2 - -# Functions for FTVMMixedPrecisionConversionType which -# Take in CallNodes and a DType and returns a conversion type, -# an accumulation dtype, and an output_dtype. -def get_generic_dtypes(call_node, mixed_precision_type): - # TODO: examine attributes - if hasattr(call_node.attrs, "out_dtype"): - return ["float32", mixed_precision_type] - - return [mixed_precision_type, mixed_precision_type] - - -def generic_always_op(call_node, mixed_precision_type): - return [MIXED_PRECISION_ALWAYS] + get_generic_dtypes(call_node, mixed_precision_type) - - -def generic_follow_op(call_node, mixed_precision_type): - return [MIXED_PRECISION_FOLLOW] + get_generic_dtypes(call_node, mixed_precision_type) - - -def generic_never_op(call_node, mixed_precision_type): - return [MIXED_PRECISION_NEVER] + get_generic_dtypes(call_node, mixed_precision_type) - - # Default lists inspired from TF's classifications: # github.com/tensorflow/tensorflow/blob/v2.5.0/tensorflow/core/grappler/optimizers/auto_mixed_precision_lists.h # They have a bias toward Nvidia Tensor Cores so modify lists per your hardware choice. @@ -152,18 +125,45 @@ def generic_never_op(call_node, mixed_precision_type): ] -def register_default_mixed_precision_attributes(): - for list_of_ops, func in zip( - [DEFAULT_ALWAYS_LIST, DEFAULT_FOLLOW_LIST, DEFAULT_NEVER_LIST], - [generic_always_op, generic_follow_op, generic_never_op], - ): - for op_name in list_of_ops: +# Returns a decorator which registers for every given op, the function under FTVMMixedPrecisionConversionType +def register_func_to_op_list(list_ops=[]): + def decorator(func): + for op_name in list_ops: register_mixed_precision_conversion(op_name, func=func) - @register_mixed_precision_conversion("nn.batch_matmul") - def nn_batch_matmul(call_node, mixed_precision_type): - # TODO(AndrewZhaoLuo): remove when batch_matmul handles accumulation dtypes well. - # Batched matmul has inconsistent support for mixed precision operations. - # Many schedules ignore the out_dtype attribute which leads to errors when - # input types do not match the out_dtype. Therefore, accumulate to output_dtype. - return [MIXED_PRECISION_ALWAYS, "float16", "float16"] + return decorator + + +# Functions for FTVMMixedPrecisionConversionType which +# Take in CallNodes and a DType and returns a conversion type, +# an accumulation dtype, and an output_dtype. +def get_generic_dtypes(call_node, mixed_precision_type): + # TODO: examine attributes + if hasattr(call_node.attrs, "out_dtype"): + return ["float32", mixed_precision_type] + + return [mixed_precision_type, mixed_precision_type] + + +@register_func_to_op_list(list_ops=DEFAULT_ALWAYS_LIST) +def generic_always_op(call_node, mixed_precision_type): + return [MIXED_PRECISION_ALWAYS] + get_generic_dtypes(call_node, mixed_precision_type) + + +@register_func_to_op_list(list_ops=DEFAULT_FOLLOW_LIST) +def generic_follow_op(call_node, mixed_precision_type): + return [MIXED_PRECISION_FOLLOW] + get_generic_dtypes(call_node, mixed_precision_type) + + +@register_func_to_op_list(list_ops=DEFAULT_NEVER_LIST) +def generic_never_op(call_node, mixed_precision_type): + return [MIXED_PRECISION_NEVER] + get_generic_dtypes(call_node, mixed_precision_type) + + +@register_mixed_precision_conversion("nn.batch_matmul") +def nn_batch_matmul(call_node, mixed_precision_type): + # TODO(AndrewZhaoLuo): remove when batch_matmul handles accumulation dtypes well. + # Batched matmul has inconsistent support for mixed precision operations. + # Many schedules ignore the out_dtype attribute which leads to errors when + # input types do not match the out_dtype. Therefore, accumulate to output_dtype. + return [MIXED_PRECISION_ALWAYS, "float16", "float16"] diff --git a/tests/python/relay/test_to_mixed_precision.py b/tests/python/relay/test_to_mixed_precision.py index 632df870f2ba..8db7303a6764 100644 --- a/tests/python/relay/test_to_mixed_precision.py +++ b/tests/python/relay/test_to_mixed_precision.py @@ -24,8 +24,6 @@ from tvm.relay.testing import lstm from tvm.relay.transform import InferType, ToMixedPrecision, mixed_precision -mixed_precision.register_default_mixed_precision_attributes() - def run_module(mod: tvm.runtime.Module, mod_params: Dict[str, Any]) -> List: dev = tvm.device("llvm", 0) From e691e4f90699c55fa3c55c37e9eeafabe5342e8b Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Tue, 15 Jun 2021 10:41:42 -0700 Subject: [PATCH 36/59] add fp64 structural test --- tests/python/relay/test_to_mixed_precision.py | 43 ++++++++++++++++++- 1 file changed, 41 insertions(+), 2 deletions(-) diff --git a/tests/python/relay/test_to_mixed_precision.py b/tests/python/relay/test_to_mixed_precision.py index 8db7303a6764..caccd52d60c2 100644 --- a/tests/python/relay/test_to_mixed_precision.py +++ b/tests/python/relay/test_to_mixed_precision.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""Unit tests for testing AMP pass""" +"""Unit tests for testing ToMixedPrecision pass""" from typing import Any, Dict, List import numpy as np @@ -79,7 +79,8 @@ def test_lstm_float64(): As a toy example show can convert graph to float64 and have it run. - It doesn't really make sense to do it, this just shows you can. + It doesn't really make sense to do it, this just shows we can change + the target mixed_precision_dtype. """ units = 3 iterations = 5 @@ -134,6 +135,44 @@ def test_convert_single_conv(): assert tvm.ir.structural_equal(fp16_mod, expected_mod) +def test_convert_single_conv_fp64(): + """As above but checks choosing a mixed_precision_type other than FP16 works""" + data_shape = (1, 3, 32, 32) + weight_shape = (5, 3, 3, 3) + data = relay.var("data", shape=data_shape, dtype="float32") + weight = relay.var("weight", shape=weight_shape, dtype="float32") + conv = relay.nn.conv2d(data, weight, strides=(1, 1), padding=(1, 1), out_dtype="float32") + mod = tvm.IRModule.from_expr(conv) + mod = tvm.relay.transform.InferType()(mod) + + mod_params = { + "data": np.random.uniform(-1, 1, size=data_shape).astype("float32"), + "weight": np.random.uniform(-1, 1, size=weight_shape).astype("float32"), + } + fp16_mod = verify_mixed_precision_output_close( + mod, mod_params, mixed_precision_dtype="float64", atol=0.01, rtol=1e-3 + ) + + # Note we still accumulate to FP32 by default, a user would need to overwrite default + # behavior to make this make more sense. + expected_mod = tvm.IRModule.from_expr( + relay.cast( + relay.nn.conv2d( + relay.cast(data, "float64"), + relay.cast(weight, "float64"), + strides=(1, 1), + padding=(1, 1), + out_dtype="float32", + ), + "float64", + ) + ) + expected_mod = tvm.relay.transform.InferType()(expected_mod) + + assert not tvm.ir.structural_equal(fp16_mod, mod) + assert tvm.ir.structural_equal(fp16_mod, expected_mod) + + def test_convert_conv_bn(): """Conv is green and batch norm is gray. As Conv should output fp16 batch_norm should be green.""" data_shape = (1, 3, 32, 32) From 37200fdd732db14047fb1558ada933c5a2744b63 Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Tue, 15 Jun 2021 10:42:23 -0700 Subject: [PATCH 37/59] clean up and comments --- python/tvm/relay/op/op.py | 20 ++++++++++- python/tvm/relay/transform/mixed_precision.py | 35 +++++++++++-------- python/tvm/relay/transform/transform.py | 2 +- 3 files changed, 41 insertions(+), 16 deletions(-) diff --git a/python/tvm/relay/op/op.py b/python/tvm/relay/op/op.py index a5045a0bfca4..f6cd6fb53e0e 100644 --- a/python/tvm/relay/op/op.py +++ b/python/tvm/relay/op/op.py @@ -459,7 +459,25 @@ def register_fake_quantization_to_integer(op_name, func=None, level=10): def register_mixed_precision_conversion(op_name, func=None, level=10): - """TODO""" + """Register mixed precision conversion function for an op + + Given an op the function should return information on how the value should be + converted. Specifically the function should take a call node and the target + mixed precision datatype (e.g. FP16) and return the conversion category + (see python/tvm/relay/transform/mixed_precision.py) as well as the accumulation + and output datatype of the oepration. + + Parameters + ---------- + op_name : str + The name of the operator + + func: function (expr: Expr, map: Map) -> new_expr: Expr + The function for translating the op into affine space and integer operators + + level : int + The priority level + """ return tvm.ir.register_op_attr(op_name, "FTVMMixedPrecisionConversionType", func, level) diff --git a/python/tvm/relay/transform/mixed_precision.py b/python/tvm/relay/transform/mixed_precision.py index d05f5030ecaa..c580eecec791 100644 --- a/python/tvm/relay/transform/mixed_precision.py +++ b/python/tvm/relay/transform/mixed_precision.py @@ -14,10 +14,16 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""TODO""" +"""Default behavior for ops in mixed_precision pass. Import this file to use.""" +from typing import List + +from tvm import relay from tvm.relay.op import register_mixed_precision_conversion -# Conversion types +# MIXED_PRECISION_ALWAYS ops should always be done in lower precision due to the speed and memory +# savings. MIXED_PRECISION_FOLLOW ops can be done in lower precision but don't have speedups to +# justify a cast. MIXED_PRECISION_NEVER colored ops should not be done in lower precision due to +# numerical reasons. MIXED_PRECISION_ALWAYS = 0 MIXED_PRECISION_FOLLOW = 1 MIXED_PRECISION_NEVER = 2 @@ -134,34 +140,35 @@ def decorator(func): return decorator -# Functions for FTVMMixedPrecisionConversionType which -# Take in CallNodes and a DType and returns a conversion type, -# an accumulation dtype, and an output_dtype. -def get_generic_dtypes(call_node, mixed_precision_type): - # TODO: examine attributes +def get_generic_out_dtypes(call_node: relay.Call, mixed_precision_type: str) -> List[str]: + # Assume support accumulation dtypes <---> has out_dtype attr if hasattr(call_node.attrs, "out_dtype"): return ["float32", mixed_precision_type] + # [accumulation_dtype, output_dtype] for the operations return [mixed_precision_type, mixed_precision_type] +# Functions for FTVMMixedPrecisionConversionType which +# Take in CallNodes and a DType and returns a conversion type, +# an accumulation dtype, and an output_dtype. @register_func_to_op_list(list_ops=DEFAULT_ALWAYS_LIST) -def generic_always_op(call_node, mixed_precision_type): - return [MIXED_PRECISION_ALWAYS] + get_generic_dtypes(call_node, mixed_precision_type) +def generic_always_op(call_node: relay.Call, mixed_precision_type: str) -> List: + return [MIXED_PRECISION_ALWAYS] + get_generic_out_dtypes(call_node, mixed_precision_type) @register_func_to_op_list(list_ops=DEFAULT_FOLLOW_LIST) -def generic_follow_op(call_node, mixed_precision_type): - return [MIXED_PRECISION_FOLLOW] + get_generic_dtypes(call_node, mixed_precision_type) +def generic_follow_op(call_node: relay.Call, mixed_precision_type: str) -> List: + return [MIXED_PRECISION_FOLLOW] + get_generic_out_dtypes(call_node, mixed_precision_type) @register_func_to_op_list(list_ops=DEFAULT_NEVER_LIST) -def generic_never_op(call_node, mixed_precision_type): - return [MIXED_PRECISION_NEVER] + get_generic_dtypes(call_node, mixed_precision_type) +def generic_never_op(call_node: relay.Call, mixed_precision_type: str) -> List: + return [MIXED_PRECISION_NEVER] + get_generic_out_dtypes(call_node, mixed_precision_type) @register_mixed_precision_conversion("nn.batch_matmul") -def nn_batch_matmul(call_node, mixed_precision_type): +def nn_batch_matmul(call_node: relay.Call, mixed_precision_type: str) -> List: # TODO(AndrewZhaoLuo): remove when batch_matmul handles accumulation dtypes well. # Batched matmul has inconsistent support for mixed precision operations. # Many schedules ignore the out_dtype attribute which leads to errors when diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index c2e2aaf439b0..5ff4da0531f3 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -1205,7 +1205,7 @@ def ToMixedPrecision( ): """ Automatic mixed precision rewriter. Rewrite an FP32 relay graph into a version - where as many operations as possible are in FP16. + where as many operations as possible are in the target mixed_precision_type. Note this does mutate the original graph putting it in a bad state potentially. From 4c93545c73051bcefc07fc324c513dfd3c44f9b3 Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Tue, 15 Jun 2021 11:08:26 -0700 Subject: [PATCH 38/59] make copy of attributes --- src/relay/transforms/to_mixed_precision.cc | 83 +++++++++++----------- 1 file changed, 43 insertions(+), 40 deletions(-) diff --git a/src/relay/transforms/to_mixed_precision.cc b/src/relay/transforms/to_mixed_precision.cc index 23b59d463ad3..ba296825c402 100644 --- a/src/relay/transforms/to_mixed_precision.cc +++ b/src/relay/transforms/to_mixed_precision.cc @@ -84,71 +84,71 @@ class MixedPrecisionPass : public MixedModeMutator { Attrs GetNewAttrs(const CallNode* call, const DataType& accumulation_dtype) const { /* If the accumulation dtype is in the attributes make a copy and mutate the field. */ - Attrs new_attrs = Attrs(call->attrs); - if (new_attrs.get() != nullptr) { + Attrs cur_attrs = call->attrs; + if (cur_attrs.get() != nullptr) { // TODO(AndrewZhaoLuo): Figure out a better way to do this // modify output_dtype attributes (accumulation dtypes for ops) - if (auto attrs = new_attrs.as()) { - ModifyAttrsOutputDType(attrs, accumulation_dtype); - } else if (auto attrs = new_attrs.as()) { - ModifyAttrsOutputDType(attrs, accumulation_dtype); - } else if (auto attrs = new_attrs.as()) { - ModifyAttrsOutputDType(attrs, accumulation_dtype); - } else if (auto attrs = new_attrs.as()) { - ModifyAttrsOutputDType(attrs, accumulation_dtype); - } else if (auto attrs = new_attrs.as()) { - ModifyAttrsOutputDType(attrs, accumulation_dtype); - } else if (auto attrs = new_attrs.as()) { - ModifyAttrsOutputDType(attrs, accumulation_dtype); - } else if (auto attrs = new_attrs.as()) { - ModifyAttrsOutputDType(attrs, accumulation_dtype); - } else if (auto attrs = new_attrs.as()) { - ModifyAttrsOutputDType(attrs, accumulation_dtype); - } else if (auto attrs = new_attrs.as()) { - ModifyAttrsOutputDType(attrs, accumulation_dtype); - } else if (auto attrs = new_attrs.as()) { - ModifyAttrsOutputDType(attrs, accumulation_dtype); - } else if (auto attrs = new_attrs.as()) { - ModifyAttrsOutputDType(attrs, accumulation_dtype); - } else if (auto attrs = new_attrs.as()) { - ModifyAttrsOutputDType(attrs, accumulation_dtype); + if (auto attrs = cur_attrs.as()) { + return ModifyAttrsOutputDType(attrs, accumulation_dtype); + } else if (auto attrs = cur_attrs.as()) { + return ModifyAttrsOutputDType(attrs, accumulation_dtype); + } else if (auto attrs = cur_attrs.as()) { + return ModifyAttrsOutputDType(attrs, accumulation_dtype); + } else if (auto attrs = cur_attrs.as()) { + return ModifyAttrsOutputDType(attrs, accumulation_dtype); + } else if (auto attrs = cur_attrs.as()) { + return ModifyAttrsOutputDType(attrs, accumulation_dtype); + } else if (auto attrs = cur_attrs.as()) { + return ModifyAttrsOutputDType(attrs, accumulation_dtype); + } else if (auto attrs = cur_attrs.as()) { + return ModifyAttrsOutputDType(attrs, accumulation_dtype); + } else if (auto attrs = cur_attrs.as()) { + return ModifyAttrsOutputDType(attrs, accumulation_dtype); + } else if (auto attrs = cur_attrs.as()) { + return ModifyAttrsOutputDType(attrs, accumulation_dtype); + } else if (auto attrs = cur_attrs.as()) { + return ModifyAttrsOutputDType(attrs, accumulation_dtype); + } else if (auto attrs = cur_attrs.as()) { + return ModifyAttrsOutputDType(attrs, accumulation_dtype); + } else if (auto attrs = cur_attrs.as()) { + return ModifyAttrsOutputDType(attrs, accumulation_dtype); } // modify dtype attributes (creating new tensors of type dtype) - if (auto attrs = new_attrs.as()) { - ModifyAttrsDType(attrs, accumulation_dtype); + if (auto attrs = cur_attrs.as()) { + return ModifyAttrsDType(attrs, accumulation_dtype); } } - return new_attrs; + return cur_attrs; } template - void ModifyAttrsOutputDType(const T* attrs, const DataType& accumulation_dtype) const { + Attrs ModifyAttrsOutputDType(const T* attrs, const DataType& accumulation_dtype) const { /* Helper template to modify relevant attributes with out_dtype type. These represent accumulation dtypes for some operations e.g. conv2d might take in fp16 and give a fp32 result. Attrs is const because we get it as a const. */ - T* mutable_attrs = const_cast(attrs); - - DataType cur_type = (mutable_attrs->out_dtype); - if (cur_type.is_float() || cur_type.is_void()) mutable_attrs->out_dtype = accumulation_dtype; + DataType cur_type = (attrs->out_dtype); + ObjectPtr new_attrs = make_object(*attrs); + if (cur_type.is_float() || cur_type.is_void()) new_attrs->out_dtype = accumulation_dtype; + return Attrs(new_attrs); } template - void ModifyAttrsDType(const T* attrs, const DataType& accumulation_dtype) const { + Attrs ModifyAttrsDType(const T* attrs, const DataType& accumulation_dtype) const { /* Helper template to modify relevant attributes with dtype type. This determines the output dtype for some ops. For example zeros creates a tensor of zeros of the specified dtype. Attrs is const because we get it as a const. */ - T* mutable_attrs = const_cast(attrs); - DataType cur_type = (mutable_attrs->dtype); - - if (cur_type.is_float() || cur_type.is_void()) mutable_attrs->dtype = accumulation_dtype; + DataType cur_type = (attrs->dtype); + ObjectPtr new_attrs = make_object(*attrs); + if (cur_type.is_float() || cur_type.is_void()) new_attrs->dtype = accumulation_dtype; + return Attrs(new_attrs); } Type GetType(const Expr& expr) const { @@ -271,7 +271,8 @@ class MixedPrecisionPass : public MixedModeMutator { Expr cur_op = post_call_node->op; - // Results are: conversion category (int), accumulation dtype (str), output dtype (str) + // Get info on the operation being called: + // conversion category (int), accumulation dtype (str), output dtype (str) MixedTypeConversionCategory initial_category; DataType accumulation_dtype, output_dtype; if (cur_op.as()) { @@ -284,6 +285,7 @@ class MixedPrecisionPass : public MixedModeMutator { Op::GetAttrMap("FTVMMixedPrecisionConversionType"); Op op = Downcast(cur_op); if (attr_map.count(op)) { + // Calculate the conversion category and dtypes from registered attribute. FTVMMixedPrecisionConversionType func = attr_map[op]; Array op_descriptor = func(GetRef(pre_call_node), DLDataType2String(mixed_precision_type)); @@ -296,6 +298,7 @@ class MixedPrecisionPass : public MixedModeMutator { if (!ignore_missing_ops) LOG(FATAL) << "Op " << op->name << " not in conversion lists!"; if (warn_missing_ops) LOG(WARNING) << "Op " << op->name << " not in conversion lists!"; + // If not registered, by default assume is a generic FOLLOW operation. initial_category = MIXED_PRECISION_FOLLOW; accumulation_dtype = DataType::Float(16); output_dtype = DataType::Float(16); From 6aa727d1877bcfac1ad4325317f9c5c387c79656 Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Tue, 15 Jun 2021 11:13:13 -0700 Subject: [PATCH 39/59] asf header --- python/tvm/relay/transform/mixed_precision.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relay/transform/mixed_precision.py b/python/tvm/relay/transform/mixed_precision.py index c580eecec791..5f7b359b9ead 100644 --- a/python/tvm/relay/transform/mixed_precision.py +++ b/python/tvm/relay/transform/mixed_precision.py @@ -6,7 +6,7 @@ # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at # -# http:#www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an From 173801bbac7dd0926c2f7530b975656af352798e Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Tue, 15 Jun 2021 11:21:51 -0700 Subject: [PATCH 40/59] pylint --- python/tvm/relay/transform/mixed_precision.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/transform/mixed_precision.py b/python/tvm/relay/transform/mixed_precision.py index 5f7b359b9ead..2839a121ea6b 100644 --- a/python/tvm/relay/transform/mixed_precision.py +++ b/python/tvm/relay/transform/mixed_precision.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +# pylint: disable=line-too-long,unused-argument """Default behavior for ops in mixed_precision pass. Import this file to use.""" from typing import List @@ -132,7 +133,7 @@ # Returns a decorator which registers for every given op, the function under FTVMMixedPrecisionConversionType -def register_func_to_op_list(list_ops=[]): +def register_func_to_op_list(list_ops): def decorator(func): for op_name in list_ops: register_mixed_precision_conversion(op_name, func=func) From f4da2df9c6bb214a2aebcffae40e536c1f8ea0b0 Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Tue, 15 Jun 2021 13:39:33 -0700 Subject: [PATCH 41/59] remove TODO which is solved --- python/tvm/relay/transform/transform.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index 5ff4da0531f3..c44d858fd3cb 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -1207,10 +1207,6 @@ def ToMixedPrecision( Automatic mixed precision rewriter. Rewrite an FP32 relay graph into a version where as many operations as possible are in the target mixed_precision_type. - Note this does mutate the original graph putting it in a bad state potentially. - - TODO(AndrewZhaoLuo): don't mutate the original graph. - Returns ------- ret : tvm.transform.Pass From 7698920985211ba4ad3d3642ecbc824086dac62c Mon Sep 17 00:00:00 2001 From: AndrewZhaoLuo Date: Tue, 15 Jun 2021 15:01:38 -0700 Subject: [PATCH 42/59] Apply nits from code review (comaniac) Co-authored-by: Cody Yu --- python/tvm/relay/op/op.py | 2 +- python/tvm/relay/transform/transform.py | 2 +- src/relay/transforms/to_mixed_precision.cc | 23 ++++++++-------------- 3 files changed, 10 insertions(+), 17 deletions(-) diff --git a/python/tvm/relay/op/op.py b/python/tvm/relay/op/op.py index f6cd6fb53e0e..a955281fdd3e 100644 --- a/python/tvm/relay/op/op.py +++ b/python/tvm/relay/op/op.py @@ -465,7 +465,7 @@ def register_mixed_precision_conversion(op_name, func=None, level=10): converted. Specifically the function should take a call node and the target mixed precision datatype (e.g. FP16) and return the conversion category (see python/tvm/relay/transform/mixed_precision.py) as well as the accumulation - and output datatype of the oepration. + and output datatype of the operation. Parameters ---------- diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index c44d858fd3cb..f1e5812fd2c0 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -1210,6 +1210,6 @@ def ToMixedPrecision( Returns ------- ret : tvm.transform.Pass - The registered RewriteFP16 pass. + The registered pass. """ return _ffi_api.ToMixedPrecision(mixed_precision_type, ignore_missing_ops, warn_missing_ops) diff --git a/src/relay/transforms/to_mixed_precision.cc b/src/relay/transforms/to_mixed_precision.cc index ba296825c402..6de6aba2bc18 100644 --- a/src/relay/transforms/to_mixed_precision.cc +++ b/src/relay/transforms/to_mixed_precision.cc @@ -71,7 +71,7 @@ class MixedPrecisionPass : public MixedModeMutator { private: CachedCastNodes cast_nodes_cache; - // The target datatype we want to convert to e.g. FP16 + /*! \brief The target datatype we want to convert to e.g. FP16 */ const DataType mixed_precision_type; // If false, throws a fatal error if an op which is not registered with a @@ -193,9 +193,7 @@ class MixedPrecisionPass : public MixedModeMutator { } const ExprNode* expr_node = expr.as(); - if (!expr_node) { - LOG(FATAL) << "Non-expression node found in cast: " << expr; - } + CHECK(expr_node) << "Non-expression node found in cast: " << expr; // Use cached result if possible. auto search = cast_nodes_cache.find({expr_node, wanted_dtype}); @@ -227,10 +225,8 @@ class MixedPrecisionPass : public MixedModeMutator { all_same &= casted_element.same_as(tuple_element); } return all_same ? expr : Tuple(new_expr); - } else { - LOG(FATAL) << "Unsupported type " << expr_type << " we don't know how to cast for arguments!"; - return expr; } + CHECK(0) << "Unsupported type " << expr_type << " we don't know how to cast for arguments!"; } std::pair, Array> CastAllArgs(const Array& cur_args, @@ -258,16 +254,15 @@ class MixedPrecisionPass : public MixedModeMutator { mixed_precision_type(mixed_precision_type), ignore_missing_ops(ignore_missing_ops), warn_missing_ops(warn_missing_ops) { - if (!mixed_precision_type.is_float() && !mixed_precision_type.is_bfloat16()) - LOG(FATAL) << "Only support IEEE floating point mixed precision types and bfloat16 got " + if (!mixed_precision_type.is_float() && !mixed_precision_type.is_bfloat16()) { + LOG(FATAL) << "Only support IEEE floating point mixed precision types and bfloat16, but got " << mixed_precision_type; + } } Expr Rewrite_(const CallNode* pre_call_node, const Expr& post) final { const CallNode* post_call_node = post.as(); - if (!post_call_node) { - LOG(FATAL) << "Expected a CallNode for the rewrite got " << post; - } + CHECK(post_call_node) << "Expected a CallNode, but got " << post; Expr cur_op = post_call_node->op; @@ -324,12 +319,10 @@ class MixedPrecisionPass : public MixedModeMutator { } // Determine the final category we want for conversion - MixedTypeConversionCategory final_category; + MixedTypeConversionCategory final_category = initial_category; if (initial_category == MIXED_PRECISION_FOLLOW) { final_category = all_args_mixed_type_compatible ? MIXED_PRECISION_ALWAYS : MIXED_PRECISION_NEVER; - } else { - final_category = initial_category; } // Create the new arguments to the call. From 177f9c4ea186722ece9797a52375f67dff1293da Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Tue, 15 Jun 2021 15:02:40 -0700 Subject: [PATCH 43/59] change cast_node_cache --> cast_node_cache_ --- src/relay/transforms/to_mixed_precision.cc | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/relay/transforms/to_mixed_precision.cc b/src/relay/transforms/to_mixed_precision.cc index 6de6aba2bc18..a156099affd1 100644 --- a/src/relay/transforms/to_mixed_precision.cc +++ b/src/relay/transforms/to_mixed_precision.cc @@ -69,7 +69,7 @@ using FTVMMixedPrecisionConversionType = runtime::TypedPackedFuncsecond; } Expr result = Cast(expr, wanted_dtype); - cast_nodes_cache[{expr_node, wanted_dtype}] = result; + cast_nodes_cache_[{expr_node, wanted_dtype}] = result; // Reverse the cache result, e.g. if we want to reverse the cast simply point to original node const ExprNode* new_expr_node = result.as(); - cast_nodes_cache[{new_expr_node, expr_dtype}] = expr; + cast_nodes_cache_[{new_expr_node, expr_dtype}] = expr; return result; } From 8ddabdaf98b8ecfaa4be8b303ad8ae26f7f62ec3 Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Tue, 15 Jun 2021 15:11:22 -0700 Subject: [PATCH 44/59] add check for returned vals --- src/relay/transforms/to_mixed_precision.cc | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/relay/transforms/to_mixed_precision.cc b/src/relay/transforms/to_mixed_precision.cc index a156099affd1..a22722c13740 100644 --- a/src/relay/transforms/to_mixed_precision.cc +++ b/src/relay/transforms/to_mixed_precision.cc @@ -227,6 +227,7 @@ class MixedPrecisionPass : public MixedModeMutator { return all_same ? expr : Tuple(new_expr); } CHECK(0) << "Unsupported type " << expr_type << " we don't know how to cast for arguments!"; + return expr; } std::pair, Array> CastAllArgs(const Array& cur_args, @@ -284,13 +285,17 @@ class MixedPrecisionPass : public MixedModeMutator { FTVMMixedPrecisionConversionType func = attr_map[op]; Array op_descriptor = func(GetRef(pre_call_node), DLDataType2String(mixed_precision_type)); + ICHECK(op_descriptor.size() == 3) + << "got the wrong number of returned arguments (expected 3) from " + "FTVMMixedPrecisionConversionType for " + << AsText(op, false); int64_t op_conversion_type = Downcast(op_descriptor[0])->value; initial_category = static_cast(op_conversion_type); accumulation_dtype = DataType(String2DLDataType(Downcast(op_descriptor[1]))); output_dtype = DataType(String2DLDataType(Downcast(op_descriptor[2]))); } else { - if (!ignore_missing_ops) LOG(FATAL) << "Op " << op->name << " not in conversion lists!"; + ICHECK(ignore_missing_ops) << "Op " << op->name << " not in conversion lists!"; if (warn_missing_ops) LOG(WARNING) << "Op " << op->name << " not in conversion lists!"; // If not registered, by default assume is a generic FOLLOW operation. From 78b5b31b7421a31d4853bda419de4efdde20003c Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Tue, 15 Jun 2021 15:12:12 -0700 Subject: [PATCH 45/59] better error msg --- src/relay/transforms/to_mixed_precision.cc | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/relay/transforms/to_mixed_precision.cc b/src/relay/transforms/to_mixed_precision.cc index a22722c13740..e342972be80c 100644 --- a/src/relay/transforms/to_mixed_precision.cc +++ b/src/relay/transforms/to_mixed_precision.cc @@ -286,9 +286,8 @@ class MixedPrecisionPass : public MixedModeMutator { Array op_descriptor = func(GetRef(pre_call_node), DLDataType2String(mixed_precision_type)); ICHECK(op_descriptor.size() == 3) - << "got the wrong number of returned arguments (expected 3) from " - "FTVMMixedPrecisionConversionType for " - << AsText(op, false); + << "got the wrong number of returned arguments (expected 3 got " << op_descriptor.size() + << ") from FTVMMixedPrecisionConversionType for " << AsText(op, false); int64_t op_conversion_type = Downcast(op_descriptor[0])->value; initial_category = static_cast(op_conversion_type); From 54d7c3db207fa958e360900ae53293b49500b016 Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Tue, 15 Jun 2021 15:16:30 -0700 Subject: [PATCH 46/59] docstring for pass in python --- python/tvm/relay/transform/transform.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index f1e5812fd2c0..ada1d6727a3b 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -1207,6 +1207,21 @@ def ToMixedPrecision( Automatic mixed precision rewriter. Rewrite an FP32 relay graph into a version where as many operations as possible are in the target mixed_precision_type. + Parameters + ---------- + mixed_precision_type: str + The target datatype to transform operations in the graph to use. + + ignore_missing_ops: bool + If false, throws an error if an op not registered with + FTVMFakeQuantizationToInteger is encountered during the pass. + + warn_missing_ops: bool + If true, emits a warning if an op not registered with + FTVMFakeQuantizationToInteger is encountered during the pass. + By default, such ops will assume to be of conversion category + MIXED_PRECISION_FOLLOW. + Returns ------- ret : tvm.transform.Pass From 33312243ea79c7129cb33303bc8bca7ea3030203 Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Tue, 15 Jun 2021 15:17:35 -0700 Subject: [PATCH 47/59] fix default behavior to be proper --- src/relay/transforms/to_mixed_precision.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/relay/transforms/to_mixed_precision.cc b/src/relay/transforms/to_mixed_precision.cc index e342972be80c..e0d53ebf99b2 100644 --- a/src/relay/transforms/to_mixed_precision.cc +++ b/src/relay/transforms/to_mixed_precision.cc @@ -299,8 +299,8 @@ class MixedPrecisionPass : public MixedModeMutator { // If not registered, by default assume is a generic FOLLOW operation. initial_category = MIXED_PRECISION_FOLLOW; - accumulation_dtype = DataType::Float(16); - output_dtype = DataType::Float(16); + accumulation_dtype = mixed_precision_type; + output_dtype = mixed_precision_type; } } else { LOG(FATAL) << "Unsupported op type in CallNode: " << pre_call_node->op; From c781bf23d2dd03533911c4f62282ae65251e2abc Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Tue, 15 Jun 2021 16:08:12 -0700 Subject: [PATCH 48/59] better error reporting via single flag --- python/tvm/relay/transform/transform.py | 22 ++++---- src/relay/transforms/to_mixed_precision.cc | 60 +++++++++++++--------- 2 files changed, 46 insertions(+), 36 deletions(-) diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index ada1d6727a3b..c3c0cc1b78c5 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -1200,9 +1200,7 @@ def FakeQuantizationToInteger(): return _ffi_api.FakeQuantizationToInteger() -def ToMixedPrecision( - mixed_precision_type="float16", ignore_missing_ops=True, warn_missing_ops=True -): +def ToMixedPrecision(mixed_precision_type="float16", missing_op_mode=1): """ Automatic mixed precision rewriter. Rewrite an FP32 relay graph into a version where as many operations as possible are in the target mixed_precision_type. @@ -1212,19 +1210,17 @@ def ToMixedPrecision( mixed_precision_type: str The target datatype to transform operations in the graph to use. - ignore_missing_ops: bool - If false, throws an error if an op not registered with - FTVMFakeQuantizationToInteger is encountered during the pass. - - warn_missing_ops: bool - If true, emits a warning if an op not registered with - FTVMFakeQuantizationToInteger is encountered during the pass. - By default, such ops will assume to be of conversion category - MIXED_PRECISION_FOLLOW. + missing_op_mode: int + Determines how to handle ops not registered with FTVMMixedPrecisionConversionType + 0: Does not allow any missing ops. Will throw errors when encountering any. + 1: Allow missing ops but throw warnings. + 2: Allow missing ops and silently ignore them. Returns ------- ret : tvm.transform.Pass The registered pass. """ - return _ffi_api.ToMixedPrecision(mixed_precision_type, ignore_missing_ops, warn_missing_ops) + if missing_op_mode < 0 or missing_op_mode > 2: + raise ValueError("Missing op mode is either 0, 1, or 2") + return _ffi_api.ToMixedPrecision(mixed_precision_type, missing_op_mode) diff --git a/src/relay/transforms/to_mixed_precision.cc b/src/relay/transforms/to_mixed_precision.cc index e0d53ebf99b2..c35418466e52 100644 --- a/src/relay/transforms/to_mixed_precision.cc +++ b/src/relay/transforms/to_mixed_precision.cc @@ -74,13 +74,9 @@ class MixedPrecisionPass : public MixedModeMutator { /*! \brief The target datatype we want to convert to e.g. FP16 */ const DataType mixed_precision_type; - // If false, throws a fatal error if an op which is not registered with a - // FTVMMixedPrecisionConversionType is encountered. - bool ignore_missing_ops; - - // If true, emits a warning if an op which is not registered with a - // FTVMMixedPrecisionConversionType is encountered. - bool warn_missing_ops; + // Map of Ops with no associated FTVMMixedPrecisionConversionType to the times they were + // encountered. Used for emitting warnings on missing ops in the pass. + std::unordered_map missing_ops; Attrs GetNewAttrs(const CallNode* call, const DataType& accumulation_dtype) const { /* If the accumulation dtype is in the attributes make a copy and mutate the field. */ @@ -249,12 +245,8 @@ class MixedPrecisionPass : public MixedModeMutator { public: using MixedModeMutator::VisitExpr_; - explicit MixedPrecisionPass(DataType mixed_precision_type = DataType::Float(16), - bool ignore_missing_ops = true, bool warn_missing_ops = true) - : MixedModeMutator(), - mixed_precision_type(mixed_precision_type), - ignore_missing_ops(ignore_missing_ops), - warn_missing_ops(warn_missing_ops) { + explicit MixedPrecisionPass(DataType mixed_precision_type = DataType::Float(16)) + : MixedModeMutator(), mixed_precision_type(mixed_precision_type) { if (!mixed_precision_type.is_float() && !mixed_precision_type.is_bfloat16()) { LOG(FATAL) << "Only support IEEE floating point mixed precision types and bfloat16, but got " << mixed_precision_type; @@ -294,8 +286,7 @@ class MixedPrecisionPass : public MixedModeMutator { accumulation_dtype = DataType(String2DLDataType(Downcast(op_descriptor[1]))); output_dtype = DataType(String2DLDataType(Downcast(op_descriptor[2]))); } else { - ICHECK(ignore_missing_ops) << "Op " << op->name << " not in conversion lists!"; - if (warn_missing_ops) LOG(WARNING) << "Op " << op->name << " not in conversion lists!"; + missing_ops[op->name] += 1; // If not registered, by default assume is a generic FOLLOW operation. initial_category = MIXED_PRECISION_FOLLOW; @@ -376,24 +367,47 @@ class MixedPrecisionPass : public MixedModeMutator { Expr body = this->Mutate(op->body); return Let(var, value, body, op->span); } + + // To access map of ops not registered for error reporting + friend Expr ToMixedPrecision(const Expr& expr, const DataType& mixed_precision_type, + int missing_op_mode); }; -Expr ToMixedPrecision(const Expr& expr, const DataType& mixed_precision_type, - bool ignore_missing_ops, bool warn_missing_ops) { - MixedPrecisionPass converter = - MixedPrecisionPass(mixed_precision_type, ignore_missing_ops, warn_missing_ops); +Expr ToMixedPrecision(const Expr& expr, const DataType& mixed_precision_type, int missing_op_mode) { + /* + missing_op_mode: + + 0: Does not allow any missing ops. Will throw errors and terminate the pass when encountering any. + 1: Allow missing ops but throw warnings. + 2: Allow missing ops and silently ignore them. + */ + ICHECK(missing_op_mode >= 0 && missing_op_mode <= 2) + << " missing_op_mode must be either 0, 1, or 2 got " << missing_op_mode; + + MixedPrecisionPass converter = MixedPrecisionPass(mixed_precision_type); auto result = converter.Mutate(expr); + + for (auto it = converter.missing_ops.begin(); + missing_op_mode != 2 && it != converter.missing_ops.end(); it++) { + std::string op_name = it->first; + int appear_count = it->second; + + LOG(WARNING) << "Op \"" << op_name << "\" not registered " + << "FTVMMixedPrecisionConversionType appears " << appear_count << " in graph."; + } + + if (converter.missing_ops.size() != 0 && missing_op_mode == 0) { + CHECK(0) << "Missing ops were found, please fix!"; + } return result; } namespace transform { -Pass ToMixedPrecision(DataType mixed_precision_type, bool ignore_missing_ops, - bool warn_missing_ops) { +Pass ToMixedPrecision(DataType mixed_precision_type, int missing_op_mode) { runtime::TypedPackedFunc pass_func = [=](Function f, IRModule m, PassContext pc) { - return Downcast( - ToMixedPrecision(f, mixed_precision_type, ignore_missing_ops, warn_missing_ops)); + return Downcast(ToMixedPrecision(f, mixed_precision_type, missing_op_mode)); }; return CreateFunctionPass(pass_func, 10, "ToMixedPrecision", {}); } From b513fee81c15ffac3259d138e26eb65941d3b58f Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Tue, 15 Jun 2021 16:09:25 -0700 Subject: [PATCH 49/59] priority to 0 --- src/relay/transforms/to_mixed_precision.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/relay/transforms/to_mixed_precision.cc b/src/relay/transforms/to_mixed_precision.cc index c35418466e52..4bf8586fd99d 100644 --- a/src/relay/transforms/to_mixed_precision.cc +++ b/src/relay/transforms/to_mixed_precision.cc @@ -409,7 +409,7 @@ Pass ToMixedPrecision(DataType mixed_precision_type, int missing_op_mode) { [=](Function f, IRModule m, PassContext pc) { return Downcast(ToMixedPrecision(f, mixed_precision_type, missing_op_mode)); }; - return CreateFunctionPass(pass_func, 10, "ToMixedPrecision", {}); + return CreateFunctionPass(pass_func, 0, "ToMixedPrecision", {}); } TVM_REGISTER_GLOBAL("relay._transform.ToMixedPrecision").set_body_typed(ToMixedPrecision); From 4fea97895267037c3bbbe54d167bebc54dfd5d69 Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Wed, 16 Jun 2021 13:28:46 -0700 Subject: [PATCH 50/59] address more nits --- src/relay/transforms/to_mixed_precision.cc | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/src/relay/transforms/to_mixed_precision.cc b/src/relay/transforms/to_mixed_precision.cc index 4bf8586fd99d..c485ab6bab0e 100644 --- a/src/relay/transforms/to_mixed_precision.cc +++ b/src/relay/transforms/to_mixed_precision.cc @@ -69,14 +69,16 @@ using FTVMMixedPrecisionConversionType = runtime::TypedPackedFunc missing_ops; + /*! \brief Map of Ops with no associated FTVMMixedPrecisionConversionType to the times they were + * encountered. Used for emitting warnings on missing ops in the pass. + */ + std::unordered_map missing_ops_; Attrs GetNewAttrs(const CallNode* call, const DataType& accumulation_dtype) const { /* If the accumulation dtype is in the attributes make a copy and mutate the field. */ @@ -286,7 +288,7 @@ class MixedPrecisionPass : public MixedModeMutator { accumulation_dtype = DataType(String2DLDataType(Downcast(op_descriptor[1]))); output_dtype = DataType(String2DLDataType(Downcast(op_descriptor[2]))); } else { - missing_ops[op->name] += 1; + missing_ops_[op->name] += 1; // If not registered, by default assume is a generic FOLLOW operation. initial_category = MIXED_PRECISION_FOLLOW; @@ -387,17 +389,18 @@ Expr ToMixedPrecision(const Expr& expr, const DataType& mixed_precision_type, in MixedPrecisionPass converter = MixedPrecisionPass(mixed_precision_type); auto result = converter.Mutate(expr); - for (auto it = converter.missing_ops.begin(); - missing_op_mode != 2 && it != converter.missing_ops.end(); it++) { + for (auto it = converter.missing_ops_.begin(); + missing_op_mode != 2 && it != converter.missing_ops_.end(); it++) { std::string op_name = it->first; int appear_count = it->second; LOG(WARNING) << "Op \"" << op_name << "\" not registered " - << "FTVMMixedPrecisionConversionType appears " << appear_count << " in graph."; + << "FTVMMixedPrecisionConversionType appears " << appear_count + << " times in graph."; } - if (converter.missing_ops.size() != 0 && missing_op_mode == 0) { - CHECK(0) << "Missing ops were found, please fix!"; + if (converter.missing_ops_.size() != 0 && missing_op_mode == 0) { + CHECK(0) << "Missing ops were found!"; } return result; } From 25d8a1de44793460422b6ab95b573032e9796c88 Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Wed, 16 Jun 2021 13:30:02 -0700 Subject: [PATCH 51/59] fix story telling slightly --- python/tvm/relay/transform/transform.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index c3c0cc1b78c5..fa7f4c4db644 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -1213,7 +1213,7 @@ def ToMixedPrecision(mixed_precision_type="float16", missing_op_mode=1): missing_op_mode: int Determines how to handle ops not registered with FTVMMixedPrecisionConversionType 0: Does not allow any missing ops. Will throw errors when encountering any. - 1: Allow missing ops but throw warnings. + 1: Allow missing ops but emit warnings. 2: Allow missing ops and silently ignore them. Returns From a063994ab2337ec9388caab96c9554305ff7c95c Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Wed, 16 Jun 2021 13:32:53 -0700 Subject: [PATCH 52/59] restart From 22841f1c88a67caf42ec1bbb18b0d8587137051a Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Thu, 17 Jun 2021 00:09:36 -0700 Subject: [PATCH 53/59] correct docstring --- python/tvm/relay/op/op.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/python/tvm/relay/op/op.py b/python/tvm/relay/op/op.py index a955281fdd3e..0d90a5cdeafa 100644 --- a/python/tvm/relay/op/op.py +++ b/python/tvm/relay/op/op.py @@ -465,15 +465,18 @@ def register_mixed_precision_conversion(op_name, func=None, level=10): converted. Specifically the function should take a call node and the target mixed precision datatype (e.g. FP16) and return the conversion category (see python/tvm/relay/transform/mixed_precision.py) as well as the accumulation - and output datatype of the operation. + and output datatype of the operation in the mixed precision dtype space. Parameters ---------- op_name : str The name of the operator - func: function (expr: Expr, map: Map) -> new_expr: Expr - The function for translating the op into affine space and integer operators + func: function (call_node: relay.Call, target_dtype: string) + -> [conversion category, accumulation dtype, output dtype]: [int, string, string] + A function which given a call_node and target_dtype (e.g. FP16) returns the + conversion category and associated accumulation/output of the operation + when transformed into the mixed precision dtype space. level : int The priority level From 7a933a5fe900e3b63667039700c844aeff5c3382 Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Thu, 17 Jun 2021 00:19:35 -0700 Subject: [PATCH 54/59] change class fields to have _ at end --- src/relay/transforms/to_mixed_precision.cc | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/relay/transforms/to_mixed_precision.cc b/src/relay/transforms/to_mixed_precision.cc index c485ab6bab0e..9c573e05ff19 100644 --- a/src/relay/transforms/to_mixed_precision.cc +++ b/src/relay/transforms/to_mixed_precision.cc @@ -73,7 +73,7 @@ class MixedPrecisionPass : public MixedModeMutator { CachedCastNodes cast_nodes_cache_; /*! \brief The target datatype we want to convert to e.g. FP16 */ - const DataType mixed_precision_type; + const DataType mixed_precision_type_; /*! \brief Map of Ops with no associated FTVMMixedPrecisionConversionType to the times they were * encountered. Used for emitting warnings on missing ops in the pass. @@ -166,7 +166,7 @@ class MixedPrecisionPass : public MixedModeMutator { */ if (const TensorTypeNode* tensor_type = t.as()) { return (!ignore_non_float || (tensor_type->dtype).is_float()) && - tensor_type->dtype == mixed_precision_type; + tensor_type->dtype == mixed_precision_type_; } else if (const TupleTypeNode* tuple_type = t.as()) { for (Type t : tuple_type->fields) { if (!IsMixedPrecisionType(t, ignore_non_float)) return false; @@ -248,10 +248,10 @@ class MixedPrecisionPass : public MixedModeMutator { using MixedModeMutator::VisitExpr_; explicit MixedPrecisionPass(DataType mixed_precision_type = DataType::Float(16)) - : MixedModeMutator(), mixed_precision_type(mixed_precision_type) { - if (!mixed_precision_type.is_float() && !mixed_precision_type.is_bfloat16()) { + : MixedModeMutator(), mixed_precision_type_(mixed_precision_type) { + if (!mixed_precision_type_.is_float() && !mixed_precision_type_.is_bfloat16()) { LOG(FATAL) << "Only support IEEE floating point mixed precision types and bfloat16, but got " - << mixed_precision_type; + << mixed_precision_type_; } } @@ -278,7 +278,7 @@ class MixedPrecisionPass : public MixedModeMutator { // Calculate the conversion category and dtypes from registered attribute. FTVMMixedPrecisionConversionType func = attr_map[op]; Array op_descriptor = - func(GetRef(pre_call_node), DLDataType2String(mixed_precision_type)); + func(GetRef(pre_call_node), DLDataType2String(mixed_precision_type_)); ICHECK(op_descriptor.size() == 3) << "got the wrong number of returned arguments (expected 3 got " << op_descriptor.size() << ") from FTVMMixedPrecisionConversionType for " << AsText(op, false); @@ -292,8 +292,8 @@ class MixedPrecisionPass : public MixedModeMutator { // If not registered, by default assume is a generic FOLLOW operation. initial_category = MIXED_PRECISION_FOLLOW; - accumulation_dtype = mixed_precision_type; - output_dtype = mixed_precision_type; + accumulation_dtype = mixed_precision_type_; + output_dtype = mixed_precision_type_; } } else { LOG(FATAL) << "Unsupported op type in CallNode: " << pre_call_node->op; @@ -324,7 +324,7 @@ class MixedPrecisionPass : public MixedModeMutator { // Create the new arguments to the call. DataType wanted_arg_dtypes = - final_category == MIXED_PRECISION_ALWAYS ? mixed_precision_type : DataType::Float(32); + final_category == MIXED_PRECISION_ALWAYS ? mixed_precision_type_ : DataType::Float(32); auto call_args_and_types = CastAllArgs(post_call_node->args, cur_arg_types, wanted_arg_dtypes); Array new_args = call_args_and_types.first; Array new_arg_types; From a1dbb683a0783539f04b97b1623670b1922669e6 Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Thu, 17 Jun 2021 00:33:23 -0700 Subject: [PATCH 55/59] add class docstring --- src/relay/transforms/to_mixed_precision.cc | 26 ++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/src/relay/transforms/to_mixed_precision.cc b/src/relay/transforms/to_mixed_precision.cc index 9c573e05ff19..2d923500c385 100644 --- a/src/relay/transforms/to_mixed_precision.cc +++ b/src/relay/transforms/to_mixed_precision.cc @@ -67,6 +67,32 @@ using CachedCastNodes = std::unordered_map, using FTVMMixedPrecisionConversionType = runtime::TypedPackedFunc( const Call& call_node, const std::string& target_dtype_str)>; +/*! \brief This class transforms the given relay module into a version where + * as many operations as possible operate in the target mixed precision dtype. + * + * Input : A Relay module with operations registered with FTVMMixedPrecisionConversionType + * functions. These describe when and how the operations will be transformed + * into the target precision dtype. + * + * Output : A Relay module with some operations transformed according to the below + * methodology. + * + * Methodology : + * 1) Each relay Op is either of conversion category ALWAYS, FOLLOW, NEVER + * defined by the associated FTVMMixedPrecisionConversionType function. + * If an operation is not registered, it by default is assumed to be + * FOLLOW. + * 2) ALWAYS operations always convert the input floating point args into + * the target mixed precision dtype. FOLLOW Ops will convert the input + * floating point args back into FP32 unless all floating point args + * are in the target mixed precision dtypes. NEVER ops will always cast + * inputs back into FP32. + * 3) Each ALWAYS Op, and FOLLOW Op with mixed precision dtype arguments + * also have an associated accumulation_dtype and output_dtype which + * describe whether a larger dtype is used to accumulate the results + * of the operation. The output_dtype meanwhile describes the dtype + * most Ops should use from this accumulator. + */ class MixedPrecisionPass : public MixedModeMutator { private: /*! \brief A cache of nodes + target dtype to a cast version of the node with target dtype. */ From 97fbd897ee349caf17d6f735e7dca37f54700018 Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Thu, 17 Jun 2021 12:01:28 -0700 Subject: [PATCH 56/59] add comment on accumulation dtype hack --- python/tvm/relay/transform/mixed_precision.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/python/tvm/relay/transform/mixed_precision.py b/python/tvm/relay/transform/mixed_precision.py index 2839a121ea6b..7dae6bf8fcff 100644 --- a/python/tvm/relay/transform/mixed_precision.py +++ b/python/tvm/relay/transform/mixed_precision.py @@ -16,7 +16,7 @@ # under the License. # pylint: disable=line-too-long,unused-argument """Default behavior for ops in mixed_precision pass. Import this file to use.""" -from typing import List +from typing import Callable, List from tvm import relay from tvm.relay.op import register_mixed_precision_conversion @@ -133,7 +133,7 @@ # Returns a decorator which registers for every given op, the function under FTVMMixedPrecisionConversionType -def register_func_to_op_list(list_ops): +def register_func_to_op_list(list_ops: List): def decorator(func): for op_name in list_ops: register_mixed_precision_conversion(op_name, func=func) @@ -142,7 +142,11 @@ def decorator(func): def get_generic_out_dtypes(call_node: relay.Call, mixed_precision_type: str) -> List[str]: - # Assume support accumulation dtypes <---> has out_dtype attr + # Assume support accumulation dtypes <---> has out_dtype attr. + # This is because there is no better way right now to tell which ops support accumulating + # at different data types. + # Some discussion here about making this better is here: + # https://discuss.tvm.apache.org/t/rfc-relay-fp32-fp16-model-support/9994/4?u=andrewzhaoluo if hasattr(call_node.attrs, "out_dtype"): return ["float32", mixed_precision_type] From 64408eef19f9c95b5bd1a26b752bb24b883d1142 Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Thu, 17 Jun 2021 15:10:29 -0700 Subject: [PATCH 57/59] ADT warnings --- src/relay/transforms/to_mixed_precision.cc | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/relay/transforms/to_mixed_precision.cc b/src/relay/transforms/to_mixed_precision.cc index 2d923500c385..85286ff4e603 100644 --- a/src/relay/transforms/to_mixed_precision.cc +++ b/src/relay/transforms/to_mixed_precision.cc @@ -178,7 +178,6 @@ class MixedPrecisionPass : public MixedModeMutator { Type GetType(const Expr& expr) const { auto mod = IRModule::FromExpr(expr); mod = transform::InferType()(mod); - if (expr.as()) { return mod->Lookup("main")->checked_type(); } else { @@ -287,6 +286,12 @@ class MixedPrecisionPass : public MixedModeMutator { Expr cur_op = post_call_node->op; + // Relay's algebraic data types are not supported yet. + ICHECK(!cur_op.as() // used to declare functions for recursion + && !cur_op.as() // constructing ADT types + && !cur_op.as()) // used for calling recursive functions + << "Algebraic Data Types (ADT) are not supported yet for mixed precision pass."; + // Get info on the operation being called: // conversion category (int), accumulation dtype (str), output dtype (str) MixedTypeConversionCategory initial_category; From 98e9cea2ddd1bd1a68916fdb79fd6bee13bfbaab Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Thu, 17 Jun 2021 15:29:18 -0700 Subject: [PATCH 58/59] add todo --- src/relay/transforms/to_mixed_precision.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/src/relay/transforms/to_mixed_precision.cc b/src/relay/transforms/to_mixed_precision.cc index 85286ff4e603..ae10c937ff1c 100644 --- a/src/relay/transforms/to_mixed_precision.cc +++ b/src/relay/transforms/to_mixed_precision.cc @@ -286,6 +286,7 @@ class MixedPrecisionPass : public MixedModeMutator { Expr cur_op = post_call_node->op; + // TODO(AndrewZhaoLuo): Support ADTs // Relay's algebraic data types are not supported yet. ICHECK(!cur_op.as() // used to declare functions for recursion && !cur_op.as() // constructing ADT types From 2634182ae259c59547ff2afa4d1889be9800ddfa Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Thu, 17 Jun 2021 20:32:19 -0700 Subject: [PATCH 59/59] fix linter --- python/tvm/relay/transform/mixed_precision.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/transform/mixed_precision.py b/python/tvm/relay/transform/mixed_precision.py index 7dae6bf8fcff..6aa3ac09cfee 100644 --- a/python/tvm/relay/transform/mixed_precision.py +++ b/python/tvm/relay/transform/mixed_precision.py @@ -16,7 +16,7 @@ # under the License. # pylint: disable=line-too-long,unused-argument """Default behavior for ops in mixed_precision pass. Import this file to use.""" -from typing import Callable, List +from typing import List from tvm import relay from tvm.relay.op import register_mixed_precision_conversion @@ -142,6 +142,20 @@ def decorator(func): def get_generic_out_dtypes(call_node: relay.Call, mixed_precision_type: str) -> List[str]: + """A function which returns output dtypes in a way which works for most ops. + + Parameters + --------- + call_node: relay.Call + The call node containing the op. + mixed_precision_type: str + The target type to run the operation in. + Returns + ------- + output_dtypes : [str, str] + A list of two strings. The first represents the datatype used for accumulation + in the operation. The second represents the actual output datatype. + """ # Assume support accumulation dtypes <---> has out_dtype attr. # This is because there is no better way right now to tell which ops support accumulating # at different data types.