From 93032226c833f0679f7b0e95128952cdfa497118 Mon Sep 17 00:00:00 2001
From: mbrookhart
Date: Tue, 9 Feb 2021 11:12:33 -0700
Subject: [PATCH] refactor onnx importer to do more static imports by constant folding

---
 python/tvm/relay/frontend/common.py |   6 ++
 python/tvm/relay/frontend/onnx.py   | 162 ++++++++++++++++------------
 2 files changed, 99 insertions(+), 69 deletions(-)

diff --git a/python/tvm/relay/frontend/common.py b/python/tvm/relay/frontend/common.py
index 6323c63ab9b32..2db420a409924 100644
--- a/python/tvm/relay/frontend/common.py
+++ b/python/tvm/relay/frontend/common.py
@@ -491,6 +491,12 @@ def infer_type(node, mod=None):
     return ret
 
 
+def fold_constant(node, mod=None):
+    if mod is None:
+        mod = IRModule.from_expr(node)
+    return _transform.FoldConstantExpr(node, mod)
+
+
 def infer_channels(inputs, transpose=False):
     """A hack for getting 'channels' or 'units' since caffe2 does not provide
     these attributes. We check the shape of weights provided to get the number.
diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index c423598a2ee76..c5729e39b96df 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -34,7 +34,7 @@
 from .. import ty as _ty
 
 from .common import AttrCvt, Renamer
-from .common import get_relay_op, new_var, infer_shape, infer_channels
+from .common import get_relay_op, new_var, infer_shape, infer_channels, fold_constant
 from .common import infer_type, get_name
 
 
@@ -364,7 +364,7 @@ def autopad(data, strides, kernel_shape, dilations, ndim, pad_type="constant", d
         ),
         dtype="int64",
     )
-    shape = _op.strided_slice(_op.shape_of(data, dtype="int64"), [2], [ndim])
+    shape = _op.strided_slice(shape_of(data, dtype="int64"), [2], [ndim])
     # get input shape
 
     # set up integer constants
@@ -545,9 +545,9 @@ class MatMul(OnnxOpConverter):
     def _impl_v1(cls, inputs, attr, params):
         assert len(inputs) == 2, "MatMul op take 2 inputs, {} given".format(len(inputs))
         # Need to check input shape as batch matmul must be supported.
-        a_shape = _op.shape_of(inputs[0])
+        a_shape = shape_of(inputs[0])
         a_rank = infer_shape(a_shape)[0]
-        b_shape = _op.shape_of(inputs[1])
+        b_shape = shape_of(inputs[1])
         b_rank = infer_shape(b_shape)[0]
         # When performing a batch matmul, we need to properly handle N-dim shapes.
         if a_rank > 2 or b_rank > 2:
@@ -555,9 +555,13 @@ def _impl_v1(cls, inputs, attr, params):
             def flatten_to_3d(x, x_shape):
                 ndims = infer_shape(x_shape)[0]
                 newshape = _op.concatenate(
-                    [_expr.const([-1]), _op.strided_slice(x_shape, [ndims - 2], [ndims])], 0
+                    [
+                        _expr.const([-1], dtype=infer_type(x_shape).checked_type.dtype),
+                        _op.strided_slice(x_shape, [ndims - 2], [ndims]),
+                    ],
+                    0,
                 )
-                out = _op.reshape(x, newshape)
+                out = _op.reshape(x, fold_constant(newshape))
                 return out
 
             # Convert a and b into 3 dimensional tensors.
@@ -598,7 +602,7 @@ def flatten_to_3d(x, x_shape):
                 ],
                 0,
             )
-            return _op.reshape(output, final_shape)
+            return _op.reshape(output, fold_constant(final_shape))
         # Otherwise a simple dense op will get the job done.
         input_1_t = _op.transpose(inputs[1], axes=(1, 0))
         return _op.nn.dense(inputs[0], input_1_t)
@@ -646,7 +650,7 @@ def _impl_v11(cls, inputs, attr, params):
         multiplier = _op.concatenate(
             [_expr.const([1, 1], dtype="int64"), _expr.const(list(strides), dtype="int64")], axis=0
         )
-        total_output_shape = multiplier * _op.shape_of(data, dtype="int64")
+        total_output_shape = multiplier * shape_of(data, dtype="int64")
         # Add extra dimensions from kernel size and stride mismatch
         total_output_shape += _op.concatenate(
             [_expr.const([0, 0], "int64"), _expr.const(list(kernel_shape), "int64")], axis=0
@@ -792,11 +796,11 @@ def _impl_v2(cls, inputs, attr, params):
     def _impl_v11(cls, inputs, attr, params):
         pads = inputs[1]
         if len(inputs) == 3:
-            value = _op.take(inputs[2], _op.const(0))
+            value = fold_constant(_op.take(inputs[2], _op.const(0)))
         else:
             value = 0
 
-        pad_width_expr = _op.transpose(_op.reshape(pads, (2, -1)))
+        pad_width_expr = fold_constant(_op.transpose(_op.reshape(pads, (2, -1))))
         pad_mode = attr.get("mode", b"constant").decode("utf-8")
 
         if not pad_mode in ["constant", "edge", "reflect"]:
@@ -823,7 +827,7 @@ class Prelu(OnnxOpConverter):
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
        assert len(inputs) == 2, "Prelu need 2 inputs, {} given".format(len(inputs))
-        input_shape = _op.shape_of(inputs[0])
+        input_shape = shape_of(inputs[0])
         alpha = _op.broadcast_to_like(inputs[1], inputs[0])
         alpha = _op.reshape(alpha, [-1])
         output = _op.nn.prelu(_op.reshape(inputs[0], [-1]), alpha, axis=0)
@@ -875,7 +879,6 @@ class DepthToSpace(OnnxOpConverter):
 
     @classmethod
     def _impl_v11(cls, inputs, attr, params):
-
         block_size = int(attr["blocksize"])
         mode = attr.get("mode", b"DCR").decode("utf-8")
         return _op.nn.depth_to_space(inputs[0], block_size, mode=mode)
@@ -1015,8 +1018,9 @@ def _impl_v9(cls, inputs, attr, params):
             scales = params[inputs[1].name_hint].asnumpy()
         else:
             scales = inputs[1]
-
-        if not isinstance(scales, _expr.Call):
+        if isinstance(scales, _expr.Constant):
+            scales = list(scales.data.asnumpy())
+        if not isinstance(scales, _expr.Expr):
             assert scales[0] == 1.0 and scales[1] == 1.0
 
         mode = attr.get("mode")
@@ -1067,12 +1071,19 @@ def _impl_v9(cls, inputs, attr, params):
         return out
 
 
+def shape_of(x, dtype="int64"):
+    ttype = infer_type(x).checked_type
+    if not _ty.is_dynamic(ttype):
+        return _expr.const([i for i in ttype.shape], dtype)
+    return _op.shape_of(x, dtype)
+
+
 class Shape(OnnxOpConverter):
     """Operator converter for Shape."""
 
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
-        return _op.shape_of(inputs[0], "int64")
+        return shape_of(inputs[0], "int64")
 
 
 class Cast(OnnxOpConverter):
@@ -1182,7 +1193,7 @@ def _impl_v10(cls, inputs, attr, params):
 
         # Update the starts and ends according to axes if required.
         if axes is not None:
-            data_shape = _op.shape_of(inputs[0], dtype=infer_type(ends).checked_type.dtype)
+            data_shape = shape_of(inputs[0], dtype=infer_type(ends).checked_type.dtype)
             starts = _op.scatter(
                 _op.const([0] * data_rank, dtype=infer_type(starts).checked_type.dtype),
                 axes,
@@ -1201,7 +1212,9 @@ def _impl_v10(cls, inputs, attr, params):
         if steps is None:
             steps = _op.const([1] * data_rank, dtype=infer_type(starts).checked_type.dtype)
 
-        return _op.strided_slice(inputs[0], starts, ends, steps)
+        return _op.strided_slice(
+            inputs[0], fold_constant(starts), fold_constant(ends), fold_constant(steps)
+        )
 
 
 class Gather(OnnxOpConverter):
@@ -1509,6 +1522,20 @@ def _impl_v9(cls, inputs, attr, params):
         return output
 
 
+class Constant(OnnxOpConverter):
+    """Operator converter for Constant."""
+
+    @classmethod
+    def _impl_v9(cls, inputs, attr, params):
+        if "value" in attr:
+            np_value = get_numpy(attr.pop("value"))
+            dtype = np_value.dtype.name
+            value = _expr.const(np_value, dtype)
+            return value
+        else:
+            raise tvm.error.OpAttributeRequired("no value in Constant")
+
+
 class Sign(OnnxOpConverter):
     """Operator converter for Sign."""
 
@@ -1569,12 +1596,14 @@ def _impl_v9(cls, inputs, attr, params):
         # to that shape.
         max_rank = max(ranks)
         max_rank_idxs = [i for i, x in enumerate(ranks) if x == max_rank]
-        broadcast_shape = _op.shape_of(inputs[max_rank_idxs[0]])
+        broadcast_shape = shape_of(inputs[max_rank_idxs[0]])
         # If two or more inputs have the same rank, compute the broadcast
         # shape by taking the maximum value of each dimensions.
         if len(max_rank_idxs) > 1:
             for idx in max_rank_idxs:
-                broadcast_shape = _op.maximum(broadcast_shape, _op.shape_of(inputs[idx]))
+                broadcast_shape = _op.maximum(broadcast_shape, shape_of(inputs[idx]))
+
+        broadcast_shape = fold_constant(broadcast_shape)
 
         condition = _op.broadcast_to(inputs[0], broadcast_shape)
         x = _op.broadcast_to(inputs[1], broadcast_shape)
@@ -1596,7 +1625,7 @@ class Expand(OnnxOpConverter):
     @classmethod
     def _impl_v8(cls, inputs, attr, params):
         dtype = infer_type(inputs[1]).checked_type.dtype
-        in_shape = _op.shape_of(inputs[0], dtype=dtype)
+        in_shape = shape_of(inputs[0], dtype=dtype)
         shape = inputs[1]
 
         # Currently 'op.broadcast_to' expect the rank of the given 'shape'
@@ -1645,7 +1674,7 @@ def expand_shape(in_shape, shape):
             new_shape = _op.maximum(in_shape, shape)
             return new_shape
 
-        shape = expand_shape(in_shape, shape)
+        shape = fold_constant(expand_shape(in_shape, shape))
         return _op.broadcast_to(inputs[0], shape=shape)
 
 
@@ -1920,10 +1949,10 @@ def _impl_v10(cls, inputs, attr, params):
             )
 
         scale = inputs[1]
-        size = _op.cast(_op.shape_of(inputs[0]), infer_type(scale).checked_type.dtype) * scale
-
+        size = _op.cast(shape_of(inputs[0]), infer_type(scale).checked_type.dtype) * scale
+        size = _op.cast(size, "int64")
         layout = "NCHW"  # ONNX assumes NCHW layout
-        out_size = _op.strided_slice(size, [2], [4])
+        out_size = fold_constant(_op.strided_slice(size, [2], [4]))
         return _op.image.resize(inputs[0], out_size, layout, method, "asymmetric")
 
     @classmethod
@@ -1947,7 +1976,8 @@ def _impl_v11(cls, inputs, attr, params):
             size = inputs[3]
         else:
             assert len(scale_shape) != 0, "One of scale or size should be passed."
-            size = _op.cast(_op.shape_of(inputs[0]), infer_type(scale).checked_type.dtype) * scale
+            size = _op.cast(shape_of(inputs[0]), infer_type(scale).checked_type.dtype) * scale
+            size = _op.cast(size, "int64")
 
         coord_trans = attr.get("coordinate_transformation_mode")
         if coord_trans in [b"pytorch_half_pixel", b"half_pixel"]:
@@ -1961,7 +1991,7 @@ def _impl_v11(cls, inputs, attr, params):
                 "Unsupported coordinate_transformation_mode: {}".format(coord_trans)
             )
         layout = "NCHW"  # ONNX assumes NCHW layout
-        out_size = _op.strided_slice(size, [2], [4])
+        out_size = fold_constant(_op.strided_slice(size, [2], [4]))
         return _op.image.resize(inputs[0], out_size, layout, method, coord_trans)
 
 
@@ -2224,7 +2254,7 @@ def body_fn(*loop_inputs):
                 expand_scan = _op.expand_dims(new_scan, axis=0)
                 # For non scalar outputs we need to broadcast the initial value.
                 if rank > 0:
-                    new_scan_shape = _op.shape_of(new_scan, dtype=iter_dtype)
+                    new_scan_shape = shape_of(new_scan, dtype=iter_dtype)
                     scan_broadcast = _op.concatenate(
                         [_op.reshape(loop_count, [1]), new_scan_shape], axis=0
                     )
@@ -2446,9 +2476,9 @@ def _first_body(
             # partially prepare ONNX output format by labeling batch_num, class_id
             nms_padded_out = _op.expand_dims(nms_ret[0], -1, 1)
             batch_num = _op.expand_dims(_op.arange(_op.squeeze(B, [0]), dtype="int64"), -1, 1)
-            batch_num = _op.broadcast_to(batch_num, _op.shape_of(nms_ret[0], dtype="int64"))
+            batch_num = _op.broadcast_to(batch_num, shape_of(nms_ret[0], dtype="int64"))
             batch_num = _op.expand_dims(batch_num, -1, 1)
-            class_num = _op.broadcast_to(i, _op.shape_of(nms_padded_out, dtype="int64"))
+            class_num = _op.broadcast_to(i, shape_of(nms_padded_out, dtype="int64"))
             new_onnx_out = _op.concatenate(
                 [batch_num, class_num, _op.cast(nms_padded_out, "int64")], -1
             )
@@ -2548,7 +2578,7 @@ def _outer_body(i, B, C, onnx_out, nms_size_out, out):
         )
 
         # Call the first loop, perform NMS
-        B, C, S = _op.split(_op.shape_of(scores, dtype="int64"), 3)
+        B, C, S = _op.split(shape_of(scores, dtype="int64"), 3)
         init_count = _op.const(np.array([0]), dtype="int64")
         init_onnx_out = _op.const([1], dtype="int64")
         init_onnx_out = _op.broadcast_to(init_onnx_out, _op.concatenate([B, one, S, three], 0))
@@ -2595,6 +2625,7 @@ def _get_convert_map(opset):
         "ThresholdedRelu": ThresholdedRelu.get_converter(opset),
         "ScaledTanh": ScaledTanh.get_converter(opset),
         "ParametricSoftplus": ParametricSoftPlus.get_converter(opset),
+        "Constant": Constant.get_converter(opset),
         "ConstantOfShape": ConstantOfShape.get_converter(opset),
         # 'GivenTensorFill'
         "FC": AttrCvt("dense", ignores=["axis", "axis_w"]),
@@ -2827,12 +2858,16 @@ def from_onnx(self, graph, opset, freeze_params=False, get_output_expr=False):
         for init_tensor in graph.initializer:
             if not init_tensor.name.strip():
                 raise ValueError("Tensor's name is required.")
-            self._params[init_tensor.name] = self._parse_array(init_tensor)
-            self._nodes[init_tensor.name] = new_var(
-                init_tensor.name,
-                shape=self._params[init_tensor.name].shape,
-                dtype=self._params[init_tensor.name].dtype,
-            )
+            if freeze_params:
+                array = self._parse_array(init_tensor)
+                self._nodes[init_tensor.name] = _expr.const(array)
+            else:
+                self._params[init_tensor.name] = self._parse_array(init_tensor)
+                self._nodes[init_tensor.name] = new_var(
+                    init_tensor.name,
+                    shape=self._params[init_tensor.name].shape,
+                    dtype=self._params[init_tensor.name].dtype,
+                )
         for i in graph.input:
             # from onnx v0.2, GraphProto.input has type ValueInfoProto,
             # and the name is 'i.name'
@@ -2844,6 +2879,8 @@ def from_onnx(self, graph, opset, freeze_params=False, get_output_expr=False):
                 self._nodes[i_name] = new_var(
                     i_name, shape=self._params[i_name].shape, dtype=self._params[i_name].dtype
                 )
+            elif i_name in self._nodes:
+                continue
             else:
                 self._num_input += 1
                 if i_name in self._shape:
@@ -2886,37 +2923,27 @@ def from_onnx(self, graph, opset, freeze_params=False, get_output_expr=False):
             for i in node.input:
                 if i != "":
                     inputs[i] = self._nodes[self._renames.get(i, i)]
-            if op_name == "Constant":
-                t_proto = self._parse_attr(node.attribute)["value"]
-                self._num_param += 1
-                # We should convert scalar integers to int32, to normalize.
-                array = self._parse_array(t_proto)
-                self._params[node.output[0]] = array
-                self._nodes[node.output[0]] = new_var(
-                    node.output[0], shape=list(t_proto.dims), dtype=array.dtype
-                )
+            i_name = self._parse_value_proto(node)
+            node_output = self._fix_outputs(op_name, node.output)
+            attr["tvm_custom"] = {}
+            attr["tvm_custom"]["name"] = i_name
+            attr["tvm_custom"]["num_outputs"] = len(node_output)
+
+            op = self._convert_operator(op_name, inputs, attr, opset)
+            if not isinstance(op, _expr.TupleWrapper):
+                outputs_num = 1
             else:
-                i_name = self._parse_value_proto(node)
-                node_output = self._fix_outputs(op_name, node.output)
-                attr["tvm_custom"] = {}
-                attr["tvm_custom"]["name"] = i_name
-                attr["tvm_custom"]["num_outputs"] = len(node_output)
-
-                op = self._convert_operator(op_name, inputs, attr, opset)
-                if not isinstance(op, _expr.TupleWrapper):
-                    outputs_num = 1
-                else:
-                    outputs_num = len(op)
-                assert (
-                    len(node_output) == outputs_num
-                ), "Number of output mismatch {} vs {} in {}.".format(
-                    len(node_output), outputs_num, op_name
-                )
-                if outputs_num == 1:
-                    self._nodes[node_output[0]] = op
-                else:
-                    for k, i in zip(list(node_output), range(len(node_output))):
-                        self._nodes[k] = op[i]
+                outputs_num = len(op)
+            assert (
+                len(node_output) == outputs_num
+            ), "Number of output mismatch {} vs {} in {}.".format(
+                len(node_output), outputs_num, op_name
+            )
+            if outputs_num == 1:
+                self._nodes[node_output[0]] = op
+            else:
+                for k, i in zip(list(node_output), range(len(node_output))):
+                    self._nodes[k] = op[i]
 
         # now return the outputs
         outputs = [self._nodes[self._parse_value_proto(i)] for i in graph.output]
@@ -2934,9 +2961,6 @@ def from_onnx(self, graph, opset, freeze_params=False, get_output_expr=False):
                 self._inputs[i_name] = self._nodes[i_name]
         # Create a function from our output expression and all input variables.
         func = _function.Function([v for k, v in self._inputs.items()], outputs)
-        if freeze_params:
-            func, params = self.freeze(func, self._params)
-            return IRModule.from_expr(func), params
         return IRModule.from_expr(func), self._params
 
     def _parse_value_proto(self, value_proto):
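
A minimal usage sketch of the freeze_params path this change exercises, assuming an ONNX model file "model.onnx" with a single input named "data" (the file name, input name, and shape below are illustrative assumptions, not taken from the patch):

    import onnx
    from tvm import relay

    onnx_model = onnx.load("model.onnx")        # hypothetical model file
    shape_dict = {"data": (1, 3, 224, 224)}     # hypothetical input name/shape

    # With freeze_params=True, graph initializers are imported as relay
    # Constants instead of free parameters, so the shape_of/fold_constant
    # helpers added above can collapse shape computations at import time;
    # the returned params dict is then typically empty because the weights
    # are baked into the module.
    mod, params = relay.frontend.from_onnx(onnx_model, shape_dict, freeze_params=True)
    print(mod)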