From b28e7e201f9d959be4df95325789638bb9c8e782 Mon Sep 17 00:00:00 2001
From: Matthew Brookhart
Date: Tue, 4 Aug 2020 18:26:46 -0700
Subject: [PATCH 01/17] Refactor ONNX frontend to be dynamic

Make OneHot dynamic
Support BatchMatMul with dynamically shaped inputs
fix dynamic broadcast
Add null checks to broadcast_to rel functions
fail more isolated broadcast_to test
use StructuralEqual instead of pointer comparisons in dynamic_to_static pass
add an optional weight freeze argument to onnx importer
convert onnx resize to dynamic op
add dynamic expand to onnx importer
add a shape_func for power
fix BERTSquad, lint
handle onnx graph initializer parameters more intelligently
---
 include/tvm/relay/transform.h              |  11 +
 include/tvm/topi/broadcast.h               |  11 +-
 python/tvm/relay/frontend/onnx.py          | 535 +++++++-----------
 python/tvm/relay/op/_tensor.py             |   1 +
 python/tvm/relay/op/nn/_nn.py              |  17 +
 python/tvm/relay/op/strategy/generic.py    |   2 +-
 python/tvm/relay/op/strategy/x86.py        |  15 +-
 python/tvm/topi/cuda/batch_matmul.py       |   2 +-
 python/tvm/topi/nn/batch_matmul.py         |  25 +-
 python/tvm/topi/x86/batch_matmul.py        |   2 +-
 src/relay/backend/build_module.cc          |   1 +
 src/relay/op/dyn/tensor/transform.cc       |  18 +-
 src/relay/op/nn/nn.cc                      |  29 +-
 src/relay/op/nn/nn.h                       |   8 +-
 src/relay/op/tensor/transform.cc           |  14 +-
 src/relay/transforms/dynamic_to_static.cc  |   6 +-
 tests/python/frontend/onnx/test_forward.py | 114 ++--
 .../relay/dyn/test_dynamic_op_level10.py   |  78 ++-
 tests/python/relay/test_op_level10.py      |  26 +
 19 files changed, 454 insertions(+), 461 deletions(-)

diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h
index de2bcc4f4318..493c3e027b12 100644
--- a/include/tvm/relay/transform.h
+++ b/include/tvm/relay/transform.h
@@ -208,6 +208,17 @@ TVM_DLL Pass SimplifyInference();
  */
 TVM_DLL Pass FastMath();
 
+/*!
+ * \brief Find Dynamic ops and make them static
+ *
+ * Searches the graph for dynamic ops. If the dynamic inputs to those ops are constants, it replaces
+ * them with static ops and re-performs type inference and constant folding. The pass repeats
+ * itself until the graph stops changing or we run too many iterations.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass DynamicToStatic();
+
 /*!
 * \brief Infer the type of an expression.
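A minimal sketch of how the new pass is expected to be driven from Python, assuming the global registered later in this patch (relay._transform.DynamicToStatic) is exposed as relay.transform.DynamicToStatic() and that relay.reshape already dispatches to the dyn namespace when given a Relay expression; the shapes below are illustrative only:

    import numpy as np
    import tvm
    from tvm import relay

    # Build a toy function whose reshape target arrives as a Relay constant,
    # i.e. the kind of dynamic op the ONNX importer now emits.
    x = relay.var("x", shape=(2, 6), dtype="float32")
    y = relay.reshape(x, relay.const(np.array([3, 4]), dtype="int64"))
    mod = tvm.IRModule.from_expr(relay.Function([x], y))

    # Because the shape argument is a constant, DynamicToStatic should be able
    # to fold the dynamic reshape back into a static one.
    mod = relay.transform.InferType()(mod)
    mod = relay.transform.DynamicToStatic()(mod)
    print(mod)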
* diff --git a/include/tvm/topi/broadcast.h b/include/tvm/topi/broadcast.h index 8fabaaee14f9..d03ddc93b4c0 100644 --- a/include/tvm/topi/broadcast.h +++ b/include/tvm/topi/broadcast.h @@ -54,14 +54,19 @@ inline tvm::te::Tensor broadcast_to(const tvm::te::Tensor& t, << "\nvs\ninput: " << t; auto bh = detail::BroadcastShape(output_shape, t->shape); CHECK_EQ(output_shape.size(), bh.common_shape.size()); + Array oshape; for (size_t i = 0; i < output_shape.size(); ++i) { - CHECK(topi::detail::EqualCheck(output_shape[i], bh.common_shape[i])); + if (output_shape[i].as() == nullptr) { + oshape.push_back(output_shape[i]); + } else { + CHECK(topi::detail::EqualCheck(output_shape[i], bh.common_shape[i])); + oshape.push_back(bh.common_shape[i]); + } } auto l = [&](tvm::Array ovars) { return t(detail::InputIndexFromBroadcast(ovars, t, bh.vars2, bh.all_vars)); }; - return tvm::te::compute(tvm::Array(bh.common_shape.begin(), bh.common_shape.end()), - l, name, tag); + return tvm::te::compute(oshape, l, name, tag); } #define TOPI_DEFINE_BCAST_OP(Name, ComputeRule) \ diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index ea39010df066..a663da046d63 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -38,23 +38,13 @@ from .common import AttrCvt, Renamer from .common import get_relay_op, new_var, infer_shape, infer_channels -from .common import infer_type, get_name -from .common import infer_value as _infer_value -from .common import infer_value_simulated as _infer_value_simulated +from .common import infer_type, get_name, infer_value, infer_value_simulated __all__ = ['from_onnx'] -g = None - -def infer_value(input_val, params, mod=None): - return g.infer_value(input_val, params, mod) - -def infer_value_simulated(input_val, params): - return g.infer_value_simulated(input_val, params) class onnx_input(): """ Dual purpose list or dictionary access object.""" - def __init__(self): self.input_keys = [] self.input_dict = {} @@ -107,8 +97,7 @@ def get_numpy(tensor_proto): try: from onnx.numpy_helper import to_array except ImportError as e: - raise ImportError( - "Unable to import onnx which is required {}".format(e)) + raise ImportError("Unable to import onnx which is required {}".format(e)) return to_array(tensor_proto) @@ -136,8 +125,7 @@ def revert_caffe2_pad(pads): elif len(pads) == 2: pass else: - raise tvm.error.OpAttributeInvalid( - 'Number of pads must be either 2 or 4.') + raise tvm.error.OpAttributeInvalid('Number of pads must be either 2 or 4.') return pads @@ -192,7 +180,6 @@ def _dim_check(attrs): class OnnxOpConverter(object): """ A helper class for holding onnx op converters. """ - @classmethod def get_converter(cls, opset): """ Get converter matches given opset. @@ -207,17 +194,13 @@ def get_converter(cls, opset): converter, which should be `_impl_vx`. Number x is the biggest number smaller than or equal to opset belongs to all support versions. 
""" - versions = [ - int(d.replace('_impl_v', '')) for d in dir(cls) if '_impl_v' in d - ] + versions = [int(d.replace('_impl_v', '')) for d in dir(cls) if '_impl_v' in d] versions = sorted(versions + [opset]) - version = versions[ - max([i for i, v in enumerate(versions) if v == opset]) - 1] + version = versions[max([i for i, v in enumerate(versions) if v == opset]) - 1] if hasattr(cls, '_impl_v{}'.format(version)): return getattr(cls, '_impl_v{}'.format(version)) - raise NotImplementedError( - 'opset version {} of {} not implemented'.format( - version, cls.__name__)) + raise NotImplementedError('opset version {} of {} not implemented'.format( + version, cls.__name__)) class Unary(OnnxOpConverter): @@ -240,8 +223,7 @@ class Elemwise(OnnxOpConverter): @classmethod def _impl_v1(cls, inputs, attr, params): - assert len(inputs) == 2, "Math op {} take 2 inputs, {} given".format( - cls.name, len(inputs)) + assert len(inputs) == 2, "Math op {} take 2 inputs, {} given".format(cls.name, len(inputs)) op_name = cls.name conv_ops = ["conv2d", "conv2d_transpose"] if attr.get('broadcast', 0) and any(x in str(inputs[0]) for x in conv_ops): @@ -286,14 +268,13 @@ def _impl_v1(cls, inputs, attr, params): else: attr['layout'] = onnx_default_layout(dims=(len(input_shape) - 2)) - return AttrCvt( - op_name=dimension_picker(cls.name), - transforms={ - 'kernel_shape': 'pool_size', - 'pads': ('padding', 0) - }, - ignores=['dilations', 'storage_order'], - custom_check=dimension_constraint())(inputs, attr, params) + return AttrCvt(op_name=dimension_picker(cls.name), + transforms={ + 'kernel_shape': 'pool_size', + 'pads': ('padding', 0) + }, + ignores=['dilations', 'storage_order'], + custom_check=dimension_constraint())(inputs, attr, params) class Absolute(Unary): @@ -317,21 +298,18 @@ class AveragePool(Pool): class BatchNorm(OnnxOpConverter): """ Operator converter for BatchNorm. """ - @classmethod def _impl_v1(cls, inputs, attr, params): # TODO(zhreshold): 'spatial' is not properly handled here. - out = AttrCvt( - op_name='batch_norm', - ignores=['spatial', 'is_test', 'consumed_inputs', 'momentum'])(inputs, attr, - params) + out = AttrCvt(op_name='batch_norm', + ignores=['spatial', 'is_test', 'consumed_inputs', 'momentum'])(inputs, attr, + params) return out[0] class InstanceNorm(OnnxOpConverter): """ Operator converter for BatchNorm. """ - @classmethod def _impl_v1(cls, inputs, attr, params): return AttrCvt(op_name='instance_norm')(inputs, attr, params) @@ -340,7 +318,6 @@ def _impl_v1(cls, inputs, attr, params): class Conv(OnnxOpConverter): """ Operator converter for Conv. """ - @classmethod def _impl_v1(cls, inputs, attr, params): # Use shape of input to determine convolution type. 
@@ -379,15 +356,14 @@ def _impl_v1(cls, inputs, attr, params): if sym_pad: attr['pads'] = padding[0::2] - out = AttrCvt( - op_name=dimension_picker('conv'), - transforms={ - 'kernel_shape': 'kernel_size', - 'dilations': ('dilation', 1), - 'pads': ('padding', 0), - 'group': ('groups', 1) - }, - custom_check=dimension_constraint())(inputs[:2], attr, params) + out = AttrCvt(op_name=dimension_picker('conv'), + transforms={ + 'kernel_shape': 'kernel_size', + 'dilations': ('dilation', 1), + 'pads': ('padding', 0), + 'group': ('groups', 1) + }, + custom_check=dimension_constraint())(inputs[:2], attr, params) use_bias = len(inputs) == 3 if use_bias: @@ -428,15 +404,14 @@ def _impl_v1(cls, inputs, attr, params): raise tvm.error.OpAttributeInvalid(msg.format(attr['auto_pad'])) attr.pop('auto_pad') - out = AttrCvt( - op_name=dimension_picker('conv', '_transpose'), - transforms={ - 'kernel_shape': 'kernel_size', - 'dilations': ('dilation', (0, 0)), - 'pads': ('padding', (0, 0), revert_caffe2_pad) - }, - disables=['output_shape'], - custom_check=dimension_constraint())(inputs[:2], attr, params) + out = AttrCvt(op_name=dimension_picker('conv', '_transpose'), + transforms={ + 'kernel_shape': 'kernel_size', + 'dilations': ('dilation', (0, 0)), + 'pads': ('padding', (0, 0), revert_caffe2_pad) + }, + disables=['output_shape'], + custom_check=dimension_constraint())(inputs[:2], attr, params) use_bias = len(inputs) == 3 if use_bias: out = _op.nn.bias_add(out, inputs[2]) @@ -452,7 +427,6 @@ class Div(Elemwise): class Elu(OnnxOpConverter): """ Operator converter for Elu. """ - @classmethod def _impl_v1(cls, inputs, attr, params): alpha = float(attr.get('alpha', 1.0)) @@ -463,11 +437,9 @@ def _impl_v1(cls, inputs, attr, params): class Gemm(OnnxOpConverter): """ Operator converter for Gemm. """ - @classmethod def _impl_v1(cls, inputs, attr, params): - assert len(inputs) == 3, "Gemm op take 3 inputs, {} given".format( - len(inputs)) + assert len(inputs) == 3, "Gemm op take 3 inputs, {} given".format(len(inputs)) # Y = alpha * A * B + beta * C alpha = float(attr.get('alpha', 1.0)) beta = float(attr.get('beta', 1.0)) @@ -495,30 +467,43 @@ def _impl_v1(cls, inputs, attr, params): class MatMul(OnnxOpConverter): """ Operator converter for MatMul. """ - @classmethod def _impl_v1(cls, inputs, attr, params): assert len(inputs) == 2, "MatMul op take 2 inputs, {} given".format(len(inputs)) # Need to check input shape as batch matmul must be supported. - a_shape = infer_shape(inputs[0]) + a_shape = _op.shape_of(inputs[0]) # When performing a batch matmul, we need to properly handle N-dim shapes. - if len(a_shape) > 2: - b_shape = infer_shape(inputs[1]) + if infer_shape(a_shape)[0] > 2: + b_shape = _op.shape_of(inputs[1]) + + def flatten_to_3d(x, x_shape): + ndims = infer_shape(x_shape)[0] + newshape = _op.concatenate( + [_expr.const([-1]), + _op.strided_slice(x_shape, [ndims - 2], [ndims])], 0) + out = _op.reshape(x, newshape) + return out + # Convert a and b into 3 dimensional tensors. 
- a = _op.reshape(inputs[0], [-1, a_shape[-2], a_shape[-1]]) - b = _op.reshape(inputs[1], [-1, b_shape[-2], b_shape[-1]]) + a = flatten_to_3d(inputs[0], a_shape) + b = flatten_to_3d(inputs[1], b_shape) # Broadcast b to match batch size of a - new_b_shape = list(infer_shape(b)) - new_a_shape = infer_shape(a) - if new_a_shape[0] > new_b_shape[0]: - new_b_shape[0] = new_a_shape[0] - b = _op.broadcast_to(b, new_b_shape) + new_b_shape = _op.concatenate([ + _op.strided_slice(_op.shape_of(a), [0], [1]), + _op.strided_slice(_op.shape_of(b), [1], [3]) + ], 0) + b = _op.broadcast_to(b, new_b_shape) # Transpose matrix dimensions of b. b = _op.transpose(b, [0, 2, 1]) # Perform a batch matmul. output = _op.nn.batch_matmul(a, b) # Reshape output to original dimensions. - return _op.reshape(output, [*a_shape[:-2], a_shape[-2], b_shape[-1]]) + final_shape = _op.concatenate([ + _op.strided_slice(a_shape, [0], [infer_shape(a_shape)[0] - 1]), + _op.strided_slice(b_shape, [infer_shape(b_shape)[0] - 1], + [infer_shape(b_shape)[0]]) + ], 0) + return _op.reshape(output, final_shape) # Otherwise a simple dense op will get the job done. input_1_t = _op.transpose(inputs[1], axes=(1, 0)) return _op.nn.dense(inputs[0], input_1_t) @@ -527,7 +512,6 @@ def _impl_v1(cls, inputs, attr, params): class Mod(OnnxOpConverter): """ Operator converter for Mod. """ - @classmethod def _impl_v1(cls, inputs, attr, params): assert len(inputs) == 2, "Mod op take 2 inputs, {} given".format(len(inputs)) @@ -548,6 +532,7 @@ class MaxPool(Pool): """ name = 'max_pool' + class LpPool(OnnxOpConverter): """ A helper class for lppool op converters. """ @@ -609,29 +594,28 @@ class Mul(Elemwise): class Pad(OnnxOpConverter): """ Operator converter for Pad. """ - @classmethod def _impl_v1(cls, inputs, attr, params): pad_width = [] pads = attr.pop('paddings') dims = int(len(pads) / 2) for i in range(dims): - pad_width.append((pads[i], pads[i+dims])) + pad_width.append((pads[i], pads[i + dims])) attr['pad_width'] = pad_width pad_mode = attr.get('mode', b'constant').decode('utf-8') if pad_mode in ['constant', 'edge', 'reflect']: attr['pad_mode'] = pad_mode attr.pop('mode', None) else: - raise tvm.error.OpAttributeInvalid( - 'Value ' + pad_mode + ' in attribute "mode" is invalid for operator Pad.') + raise tvm.error.OpAttributeInvalid('Value ' + pad_mode + + ' in attribute "mode" is invalid for operator Pad.') return AttrCvt( _op.nn.pad, transforms={ 'value': 'pad_value', }, - )(inputs, attr, params) + )(inputs, attr, params) @classmethod def _impl_v2(cls, inputs, attr, params): @@ -639,22 +623,22 @@ def _impl_v2(cls, inputs, attr, params): pads = attr.pop('pads') dims = int(len(pads) / 2) for i in range(dims): - pad_width.append((pads[i], pads[i+dims])) + pad_width.append((pads[i], pads[i + dims])) attr['pad_width'] = pad_width pad_mode = attr.get('mode', b'constant').decode('utf-8') if pad_mode in ['constant', 'edge', 'reflect']: attr['pad_mode'] = pad_mode attr.pop('mode', None) else: - raise tvm.error.OpAttributeInvalid( - 'Value ' + pad_mode + ' in attribute "mode" is invalid for operator Pad.') + raise tvm.error.OpAttributeInvalid('Value ' + pad_mode + + ' in attribute "mode" is invalid for operator Pad.') return AttrCvt( 'pad', transforms={ 'value': 'pad_value', }, - )(inputs, attr, params) + )(inputs, attr, params) @classmethod def _impl_v11(cls, inputs, attr, params): @@ -667,25 +651,22 @@ def _impl_v11(cls, inputs, attr, params): attr["pad_value"] = value dims = int(len(pads) / 2) for i in range(dims): - pad_width.append((pads[i], 
pads[i+dims])) + pad_width.append((pads[i], pads[i + dims])) attr['pad_width'] = pad_width pad_mode = attr.get('mode', b'constant').decode('utf-8') if pad_mode in ['constant', 'edge', 'reflect']: attr['pad_mode'] = pad_mode attr.pop('mode', None) else: - raise tvm.error.OpAttributeInvalid( - 'Value ' + pad_mode + ' in attribute "mode" is invalid for operator Pad.') + raise tvm.error.OpAttributeInvalid('Value ' + pad_mode + + ' in attribute "mode" is invalid for operator Pad.') return AttrCvt('pad')(inputs[:1], attr, params) - - class ParametricSoftPlus(OnnxOpConverter): """ Operator converter for ParametricSoftPlus. """ - @classmethod def _impl_v1(cls, inputs, attr, params): alpha = _expr.const(float(attr.get('alpha', 1.0))) @@ -696,13 +677,12 @@ def _impl_v1(cls, inputs, attr, params): class Prelu(OnnxOpConverter): """ Operator converter for Prelu. """ - @classmethod def _impl_v1(cls, inputs, attr, params): assert len(inputs) == 2, "Prelu need 2 inputs, {} given".format(len(inputs)) alpha_shape = infer_shape(inputs[1]) if len(alpha_shape) != 1: - alpha = _op.reshape(inputs[1], (-1,)) + alpha = _op.reshape(inputs[1], (-1, )) else: alpha = inputs[1] return _op.nn.prelu(inputs[0], alpha) @@ -711,7 +691,6 @@ def _impl_v1(cls, inputs, attr, params): class Reciprocal(OnnxOpConverter): """ Operator converter for Reciprocal. """ - @classmethod def _impl_v1(cls, inputs, attr, params): return _expr.const(1.0) / inputs[0] @@ -720,7 +699,6 @@ def _impl_v1(cls, inputs, attr, params): class Flatten(OnnxOpConverter): """ Operator converter for Flatten. """ - @classmethod def _impl_v1(cls, inputs, attr, params): axis = attr.get('axis', 1) @@ -736,7 +714,6 @@ def _impl_v1(cls, inputs, attr, params): class Reshape(OnnxOpConverter): """ Operator converter for Reshape. """ - @classmethod def _impl_v1(cls, inputs, attr, params): return _op.reshape(inputs[0], attr['shape']) @@ -748,17 +725,13 @@ def _impl_v5(cls, inputs, attr, params): shape = tuple(params.pop(inputs[1].name_hint).asnumpy().astype("int32")) out = _op.reshape(inputs[0], shape) else: - data, shape = inputs - static_shape = infer_value_simulated(shape, params) - out = _op.reshape(data, newshape=tuple( - static_shape.asnumpy().astype('int32'))) + out = _op.reshape(*inputs) return out class DepthToSpace(OnnxOpConverter): """ Operator converter for DepthToSpace. """ - @classmethod def _impl_v11(cls, inputs, attr, params): @@ -770,7 +743,6 @@ def _impl_v11(cls, inputs, attr, params): class SpaceToDepth(OnnxOpConverter): """ Operator converter for SpaceToDepth. """ - @classmethod def _impl_v1(cls, inputs, attr, params): @@ -781,15 +753,14 @@ def _impl_v1(cls, inputs, attr, params): class Concat(OnnxOpConverter): """ Operator converter for Concat. """ - @classmethod def _impl_v1(cls, inputs, args, params): - return AttrCvt(op_name='concatenate')((inputs,), args) + return AttrCvt(op_name='concatenate')((inputs, ), args) + class Scale(OnnxOpConverter): """ Operator converter for Scale. """ - @classmethod def _impl_v1(cls, inputs, attr, params): scale = float(attr.get('scale', 1.0)) @@ -799,20 +770,18 @@ def _impl_v1(cls, inputs, attr, params): class Selu(OnnxOpConverter): """ Operator converter for Selu. """ - @classmethod def _impl_v1(cls, inputs, attr, params): alpha = float(attr.get('alpha', 1.6732)) gamma = float(attr.get('gamma', 1.0507)) - return _expr.const(gamma) * (_expr.const(-alpha) * - _op.nn.relu(_expr.const(1.) 
- _op.exp(inputs[0])) + - _op.nn.relu(inputs[0])) + return _expr.const(gamma) * ( + _expr.const(-alpha) * _op.nn.relu(_expr.const(1.) - _op.exp(inputs[0])) + + _op.nn.relu(inputs[0])) class ScaledTanh(OnnxOpConverter): """ Operator converter for ScaledTanh. """ - @classmethod def _impl_v1(cls, inputs, attr, params): alpha = float(attr.get('alpha', 1.0)) @@ -823,7 +792,6 @@ def _impl_v1(cls, inputs, attr, params): class SoftPlus(OnnxOpConverter): """ Operator converter for SoftPlus. """ - @classmethod def _impl_v1(cls, inputs, attr, params): return _op.log(_op.exp(inputs[0]) + _expr.const(1.)) @@ -832,7 +800,6 @@ def _impl_v1(cls, inputs, attr, params): class Softsign(OnnxOpConverter): """ Operator converter for Softsign. """ - @classmethod def _impl_v1(cls, inputs, attr, params): return inputs[0] / (_expr.const(1.) + Absolute.get_converter(1)(inputs, attr, params)) @@ -847,7 +814,6 @@ class Sub(Elemwise): class Sum(OnnxOpConverter): """ Operator converter for Sum. """ - @classmethod def _impl_v1(cls, inputs, attr, params): # Onnx Sum Operator @@ -860,7 +826,6 @@ def _impl_v1(cls, inputs, attr, params): class Affine(OnnxOpConverter): """ Operator converter for Affine transformation. """ - @classmethod def _impl_v1(cls, inputs, attr, params): alpha = _expr.const(attr.get('alpha', 1.0)) @@ -871,7 +836,6 @@ def _impl_v1(cls, inputs, attr, params): class ThresholdedRelu(OnnxOpConverter): """ Operator converter for ThresholdedRelu. """ - @classmethod def _impl_v1(cls, inputs, attr, params): alpha = float(attr.get('alpha', 1.0)) @@ -881,7 +845,6 @@ def _impl_v1(cls, inputs, attr, params): def _broadcast_constraint(): - def _broadcast_check(attrs): if attrs.get('axis', None): return False @@ -891,7 +854,6 @@ def _broadcast_check(attrs): def _fully_connected(opset): - def _impl(inputs, attr, params): # get number of channels channels = infer_channels(inputs[1], params) @@ -904,7 +866,6 @@ def _impl(inputs, attr, params): class Upsample(OnnxOpConverter): """ Operator converter for Upsample (nearest mode). """ - @classmethod def _impl_v9(cls, inputs, attr, params): scales = attr.get('scales') @@ -927,9 +888,7 @@ def _impl_v9(cls, inputs, attr, params): else: raise tvm.error.OpAttributeInvalid( 'Value {} in attribute "mode" of operator Upsample is not valid.'.format(mode)) - attr = {'scale_h': scales[-2], - 'scale_w': scales[-1], - 'method': method} + attr = {'scale_h': scales[-2], 'scale_w': scales[-1], 'method': method} if dims == 5: assert len(scales) == 5 attr['scale_d'] = scales[-3] @@ -945,18 +904,18 @@ def _impl_v9(cls, inputs, attr, params): op_name = 'upsampling' return AttrCvt(op_name)(inputs, attr) + class Shape(OnnxOpConverter): """ Operator converter for Shape. """ - @classmethod def _impl_v1(cls, inputs, attr, params): return _op.shape_of(inputs[0], "int64") + class Cast(OnnxOpConverter): """ Operator converter for Cast. """ - @classmethod def _impl_v1(cls, inputs, attr, params): return AttrCvt(op_name='cast', transforms={'to': 'dtype'})(inputs, attr) @@ -967,15 +926,13 @@ def _impl_v5(cls, inputs, attr, params): from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE attr['to'] = str(TENSOR_TYPE_TO_NP_TYPE[attr['to']]) except ImportError as e: - raise ImportError( - "Unable to import onnx.mapping which is required {}".format(e)) + raise ImportError("Unable to import onnx.mapping which is required {}".format(e)) return AttrCvt(op_name='cast', transforms={'to': 'dtype'})(inputs, attr) class Unsqueeze(OnnxOpConverter): """ Operator converter for Unsqueeze. 
""" - @classmethod def _impl_v1(cls, inputs, attr, params): for axes in attr['axes']: @@ -986,7 +943,6 @@ def _impl_v1(cls, inputs, attr, params): class Split(OnnxOpConverter): """ Operator converter for Split. """ - @classmethod def _impl_v1(cls, inputs, attr, params): splits = attr.get('split', False) @@ -998,17 +954,13 @@ def _impl_v1(cls, inputs, attr, params): attr['indices_or_sections'].append(index) # When splits isnt specified divide evenly over axis. else: - in_shape = infer_shape(inputs[0]) - attr['indices_or_sections'] = in_shape[attr['axis']] - return AttrCvt( - 'split', - ignores=['split'])(inputs, attr, params) + attr['indices_or_sections'] = attr['tvm_custom']['num_outputs'] + return AttrCvt('split', ignores=['split'])(inputs, attr, params) class Slice(OnnxOpConverter): """ Operator converter for Slice. """ - @classmethod def _common(cls, starts, ends, axes): new_axes = [] @@ -1030,16 +982,16 @@ def _common(cls, starts, ends, axes): @classmethod def _impl_v1(cls, inputs, attr, params): if isinstance(attr['starts'], int): - attr['starts'] = (attr['starts'],) - attr['ends'] = (attr['ends'],) + attr['starts'] = (attr['starts'], ) + attr['ends'] = (attr['ends'], ) try: # Update the starts and ends according to axes if required. if isinstance(attr['axes'], int): - attr['axes'] = (attr['axes'],) + attr['axes'] = (attr['axes'], ) if (max(attr['axes']) + 1) != len(attr['axes']): - new_starts, new_ends, new_axes = cls._common( - attr['starts'], attr['ends'], attr['axes']) + new_starts, new_ends, new_axes = cls._common(attr['starts'], attr['ends'], + attr['axes']) attr['axes'] = new_axes attr['starts'] = new_starts attr['ends'] = new_ends @@ -1054,19 +1006,21 @@ def _impl_v1(cls, inputs, attr, params): @classmethod def _impl_v10(cls, inputs, attr, params): - attrs = {'starts' : inputs[1], 'ends' : inputs[2]} + attrs = {'starts': inputs[1], 'ends': inputs[2]} if len(inputs) >= 4: attrs['axes'] = inputs[3] - attrs = {k : (v, get_name(v)) for (k, v) in attrs.items()} - attrs = {k : params[v[1]].asnumpy() if v[1] in params else - infer_value_simulated(v[0], params).asnumpy() - for (k, v) in attrs.items()} + attrs = {k: (v, get_name(v)) for (k, v) in attrs.items()} + attrs = { + k: params[v[1]].asnumpy() + if v[1] in params else infer_value_simulated(v[0], params).asnumpy() + for (k, v) in attrs.items() + } # Update the starts and ends according to axes if required. if 'axes' in attrs: if max(attrs['axes'] + 1) != len(attrs['axes']): - new_starts, new_ends, _ = cls._common( - attrs['starts'], attrs['ends'], attrs['axes']) + new_starts, new_ends, _ = cls._common(attrs['starts'], attrs['ends'], + attrs['axes']) attrs['starts'] = new_starts attrs['ends'] = new_ends return _op.strided_slice(inputs[0], @@ -1080,8 +1034,7 @@ class Gather(OnnxOpConverter): @classmethod def _impl_v1(cls, inputs, attr, params): axis = attr.get('axis', 0) - return AttrCvt('take', - extras={'axis': axis})(inputs, {}) + return AttrCvt('take', extras={'axis': axis})(inputs, {}) class GatherND(OnnxOpConverter): @@ -1095,7 +1048,6 @@ def _impl_v1(cls, inputs, attr, params): class Scatter(OnnxOpConverter): """ Operator converter for Scatter. """ - @classmethod def _impl_v1(cls, inputs, attr, params): axis = attr.get('axis', 0) @@ -1134,6 +1086,7 @@ def _impl_v1(cls, inputs, attr, params): attr = {'size': nsize, 'axis': axis, 'alpha': alpha, 'beta': beta, 'bias': bias} return AttrCvt('lrn')(inputs, attr) + class Maximum(OnnxOpConverter): """ Operator converter for Maximum. 
""" @@ -1146,6 +1099,7 @@ def _impl_v1(cls, inputs, attr, params): _max = AttrCvt('maximum')([_max, inputs[i]], {}) return _max + class Minimum(OnnxOpConverter): """ Operator converter for Minimum. """ @@ -1158,6 +1112,7 @@ def _impl_v1(cls, inputs, attr, params): _min = AttrCvt('minimum')([_min, inputs[i]], {}) return _min + class Mean(OnnxOpConverter): """ Operator converter for Mean. """ @@ -1169,6 +1124,7 @@ def _impl_v1(cls, inputs, attr, params): concat = _op.concatenate([_op.expand_dims(x, axis=0) for x in inputs], axis=0) return _op.mean(concat, axis=0, keepdims=False) + class HardSigmoid(OnnxOpConverter): """ Operator converter for HardSigmoid. """ @@ -1180,10 +1136,12 @@ def _impl_v1(cls, inputs, attr, params): attr = {'a_min': 0, 'a_max': 1} return AttrCvt('clip')([transformX], attr) + class Reduce(OnnxOpConverter): """ Operator converter for reduce ops. """ name = '' + @classmethod def _impl_v1(cls, inputs, attr, params): if 'axes' in attr: @@ -1194,31 +1152,37 @@ def _impl_v1(cls, inputs, attr, params): attr = {'axis': axis, 'keepdims': attr.get('keepdims', True)} return AttrCvt(cls.name)(inputs, attr) + class ReduceMax(Reduce): """ Operator converter for ReduceMax. """ name = 'max' + class ReduceMin(Reduce): """ Operator converter for ReduceMin. """ name = 'min' + class ReduceSum(Reduce): """ Operator converter for ReduceSum. """ name = 'sum' + class ReduceMean(Reduce): """ Operator converter for ReduceMean. """ name = 'mean' + class ReduceProd(Reduce): """ Operator converter for ReduceProd. """ name = 'prod' + class ReduceLogSumExp(Reduce): """ Operator converter for ReduceLogSumExp. """ @@ -1300,6 +1264,7 @@ def _impl_v1(cls, inputs, attr, params): attr = {'axis': axis, 'keepdims': keepdims} return AttrCvt('argmax')(inputs, attr) + class ArgMin(OnnxOpConverter): """ Operator converter for ArgMin. """ @@ -1310,6 +1275,7 @@ def _impl_v1(cls, inputs, attr, params): attr = {'axis': axis, 'keepdims': keepdims} return AttrCvt('argmin')(inputs, attr) + class Softmax(OnnxOpConverter): """ Operator converter for Softmax. """ @@ -1329,21 +1295,13 @@ def _impl_v9(cls, inputs, attr, params): # Extract relay one_hot inputs. indices, depth, values = inputs # Split onnx on off values into two separate expressions. - off_value, on_value = _op.take( - values, _op.const(0)), _op.take(values, _op.const(1)) + off_value, on_value = _op.take(values, _op.const(0)), _op.take(values, _op.const(1)) # Extract the datatype of the output from on_value. dtype = infer_type(on_value).checked_type.dtype - # Convert depth into an integer. - depth = int(infer_value(depth, params).asnumpy()[0]) # set default value when axis is not set in the model if 'axis' not in attr: attr['axis'] = -1 - return _op.one_hot(indices, - on_value, - off_value, - depth, - int(attr['axis']), - dtype=dtype) + return _op.one_hot(indices, on_value, off_value, depth, int(attr['axis']), dtype=dtype) class ConstantOfShape(OnnxOpConverter): @@ -1358,9 +1316,7 @@ def _impl_v9(cls, inputs, attr, params): else: value = _expr.const(0) dtype = 'float32' - static_shape = infer_value_simulated(inputs[0], params) - output = _op.full( - value, shape=tuple(static_shape.asnumpy().astype('int32')), dtype=dtype) + output = _op.full(value, inputs[0], dtype=dtype) return output @@ -1371,6 +1327,7 @@ class Sign(OnnxOpConverter): def _impl_v1(cls, inputs, attr, params): return _op.sign(inputs[0]) + class Equal(Elemwise): """ Operator converter for Equal. 
""" @@ -1406,9 +1363,8 @@ def _impl_v1(cls, inputs, attr, params): @classmethod def _impl_v6(cls, inputs, attr, params): - reps = tuple(infer_value_simulated( - inputs[1], params).asnumpy().astype('int32')) - return _op.tile(inputs[0], reps) + return _op.tile(inputs[0], inputs[1]) + class Erf(OnnxOpConverter): """Operator converter for Erf @@ -1417,6 +1373,7 @@ class Erf(OnnxOpConverter): def _impl_v1(cls, inputs, attr, params): return _op.erf(inputs[0]) + class Where(OnnxOpConverter): """Operator converter for Where """ @@ -1451,6 +1408,7 @@ def _impl_v9(cls, inputs, attr, params): inputs[2] = _op.broadcast_to(inputs[2], broadcast_shape) return _op.where(inputs[0], inputs[1], inputs[2]) + class Or(Elemwise): """ Operator converter for Or. """ @@ -1464,11 +1422,8 @@ class Expand(OnnxOpConverter): """ @classmethod def _impl_v8(cls, inputs, attr, params): - in_shape = np.array(infer_shape(inputs[0])).astype('int32') - if get_name(inputs[1]) in params: - shape = params[inputs[1].name_hint].asnumpy().astype('int32') - else: - shape = infer_value_simulated(inputs[1], params).asnumpy().astype('int32') + in_shape = _op.shape_of(inputs[0]) + shape = inputs[1] # Currently 'op.broadcast_to' expect the rank of the given 'shape' # (the 2nd input) is always higher than that of the given 'input' (the 1st input) @@ -1483,34 +1438,27 @@ def expand_shape(in_shape, shape): intput. Also it replaces the extent of the shape with the corresponding extent of the intput when it is 1. """ - - # here we flip the shapes because this can be more simply written - # when the innermost dimension is located at the index 0. - in_shape = np.flip(in_shape, axis=0) - shape = np.flip(shape, axis=0) - - if in_shape.size < shape.size: - for i in range(shape.size): - if i < in_shape.size and in_shape[i] > shape[i]: - shape[i] = in_shape[i] - else: - for i in range(in_shape.size): - if i >= shape.size: - np.append(shape, in_shape[i]) - elif shape[i] == 1: - shape[i] = in_shape[i] - - new_shape = np.flip(shape, axis=0) + in_dims = infer_shape(in_shape)[0] + new_dims = infer_shape(shape)[0] + if in_dims < new_dims: + in_shape = _op.concatenate([_expr.const([ + 1, + ] * (new_dims - in_dims)), in_shape], + axis=0) + elif new_dims > in_dims: + shape = _op.concatenate([_expr.const([ + 1, + ] * (in_dims - new_dims)), shape], axis=0) + new_shape = _op.maximum(in_shape, shape) return new_shape shape = expand_shape(in_shape, shape) - return _op.broadcast_to(inputs[0], shape=tuple(shape)) + return _op.broadcast_to(inputs[0], shape=shape) class RNN(OnnxOpConverter): """ Operator converter for RNNs such as LSTM and GRU. """ - @classmethod def _activation_helper(cls, activation, alpha, beta): convert_map = _get_convert_map(1) @@ -1546,7 +1494,6 @@ def _activation_needs_beta(cls, activation): class LSTM(RNN): """Operator converter for LSTM """ - @classmethod def _impl_v7(cls, inputs, attr, params): # Unpack inputs, note that if optional and not provided then value will be None. 
@@ -1596,8 +1543,7 @@ def _impl_v7(cls, inputs, attr, params): if 'activations' in attr: activations = attr['activations'] if len(activations) != 3: - raise NotImplementedError( - "LSTM assumes 3 activation functions are provided") + raise NotImplementedError("LSTM assumes 3 activation functions are provided") alpha_loc = 0 alphas = attr.get('activation_alpha', []) if isinstance(alphas, float): @@ -1611,12 +1557,10 @@ def _impl_v7(cls, inputs, attr, params): alpha = None beta = None activation = activations[i] - if cls._activation_needs_alpha( - activation) and len(alphas) > alpha_loc: + if cls._activation_needs_alpha(activation) and len(alphas) > alpha_loc: alpha = alphas[alpha_loc] alpha_loc += 1 - if cls._activation_needs_beta( - activation) and len(betas) > beta_loc: + if cls._activation_needs_beta(activation) and len(betas) > beta_loc: beta = betas[beta_loc] beta_loc += 1 acts.append(cls._activation_helper(activation, alpha, beta)) @@ -1663,7 +1607,6 @@ def _impl_v7(cls, inputs, attr, params): class GRU(RNN): """Operator convert for GRU """ - @classmethod def _impl_v7(cls, inputs, attr, params): # Unpack inputs, note that if optional and not provided then value will be None. @@ -1704,8 +1647,7 @@ def _impl_v7(cls, inputs, attr, params): if 'activations' in attr: activations = attr['activations'] if len(activations) != 2: - raise NotImplementedError( - "GRU assumes 2 activation functions are provided") + raise NotImplementedError("GRU assumes 2 activation functions are provided") alpha_loc = 0 alphas = attr.get('activation_alpha', []) if isinstance(alphas, float): @@ -1719,12 +1661,10 @@ def _impl_v7(cls, inputs, attr, params): alpha = None beta = None activation = activations[i] - if cls._activation_needs_alpha( - activation) and len(alphas) > alpha_loc: + if cls._activation_needs_alpha(activation) and len(alphas) > alpha_loc: alpha = alphas[alpha_loc] alpha_loc += 1 - if cls._activation_needs_beta( - activation) and len(betas) > beta_loc: + if cls._activation_needs_beta(activation) and len(betas) > beta_loc: beta = betas[beta_loc] beta_loc += 1 acts.append(cls._activation_helper(activation, alpha, beta)) @@ -1785,14 +1725,16 @@ def _impl_v11(cls, inputs, attr, params): raise tvm.error.OpAttributeInvalid( 'Value {} in attribute "mode" of operator Resize is not valid.'.format(mode)) - in_size = np.array(infer_shape(inputs[0])) - scale = infer_value_simulated(inputs[2], params).asnumpy() + scale = inputs[2] + scale_shape = infer_shape(scale) if len(inputs) == 4: - assert len(scale) == 0, "One of scale or size should be passed, not both." - size = infer_value_simulated(inputs[3], params).asnumpy().astype(np.int32) + assert len(scale_shape) == 0 or scale_shape[ + 0] == 0, "One of scale or size should be passed, not both." + size = inputs[3] else: - assert len(scale) != 0, "One of scale or size should be passed." - size = (in_size * scale).astype(np.int32) + assert len(scale_shape) != 0, "One of scale or size should be passed." 
+ size = _op.cast(_op.shape_of(inputs[0]), + infer_type(scale).type_annotation.dtype) * scale coord_trans = attr.get('coordinate_transformation_mode') if coord_trans in [b'pytorch_half_pixel', b'half_pixel']: @@ -1805,7 +1747,7 @@ def _impl_v11(cls, inputs, attr, params): raise tvm.error.OpAttributeInvalid( 'Unsupported coordinate_transformation_mode: {}'.format(coord_trans)) layout = "NCHW" # ONNX assumes NCHW layout - out_size = (size[2], size[3]) + out_size = _op.strided_slice(size, [2], [4]) return _op.image.resize(inputs[0], out_size, layout, method, coord_trans) @@ -1822,6 +1764,7 @@ def _impl_v9(cls, inputs, attr, params): output = _op.cast(output, "int64") return _op.transpose(output, axes=(1, 0)) + class TopK(OnnxOpConverter): """Operator converter for TopK """ @@ -1835,9 +1778,7 @@ def _impl_v1(cls, inputs, attr, params): if largest == 0: raise ValueError("TVM only supports finding TopK largest elements") - K = int(infer_value(inputs[1], params).asnumpy()[0]) - - return _op.topk(inputs[0], k=K, axis=axis) + return _op.topk(inputs[0], inputs[1], axis=axis) class MaxRoiPool(OnnxOpConverter): @@ -1875,12 +1816,12 @@ def _impl_v1(cls, inputs, attr, params): spatial_scale = attr.get("spatial_scale", 1.0) batch_indices = _op.expand_dims(batch_indices, axis=1, num_newaxis=1) - batch_indices = _op.cast( - batch_indices, infer_type(rois).type_annotation.dtype) + batch_indices = _op.cast(batch_indices, infer_type(rois).type_annotation.dtype) rois = _op.concatenate([batch_indices, rois], 1) - return _vision.roi_align(x, rois, [output_height, output_width], - spatial_scale, sampling_ratio) + return _vision.roi_align(x, rois, [output_height, output_width], spatial_scale, + sampling_ratio) + class Clip(OnnxOpConverter): """Operator converter for Clip. @@ -2069,7 +2010,8 @@ def _get_convert_map(opset): 'NonZero': NonZero.get_converter(opset), } -class GraphProto(ExprFunctor): + +class GraphProto(): """A helper class for handling Relay expression copying from pb2.GraphProto. 
Definition: https://github.com/onnx/onnx/blob/master/onnx/onnx.proto @@ -2081,113 +2023,25 @@ class GraphProto(ExprFunctor): dtype : str or dict of str to str The input types to the graph """ - def __init__(self, shape, dtype): self._nodes = {} self._params = {} + self._inputs = {} self._renames = {} self._num_input = 0 self._num_param = 0 self._shape = shape if shape else {} self._dtype = dtype - #For infering Values - self._tmp_params = {} - self._infer_simulated = True - self._mod = None - super(GraphProto, self).__init__() - - def infer_value(self, input_val, params, mod=None): - self._tmp_params = params - self._infer_simulated = False - self._mod = mod - return self.visit(input_val).data - - def infer_value_simulated(self, input_val, params): - self._tmp_params = params - self._infer_simulated = True - return self.visit(input_val).data - - def infer(self, expr): - if self._infer_simulated: - out = _infer_value_simulated(expr, self._tmp_params) - else: - out = _infer_value(expr, self._tmp_params) - return _expr.const(out.asnumpy()) - - def visit_function(self, fn): - new_params = [self.visit(x) for x in fn.params] - new_body = self.visit(fn.body) - return self.infer(Function( - list(new_params), - new_body, - fn.ret_type, - fn.type_params, - fn.attrs)) - - def visit_let(self, let): - newvar = self.visit(let.var) - newval = self.visit(let.value) - newbody = self.visit(let.body) - return self.infer(Let(newvar, newval, newbody)) - - def visit_call(self, call): - new_fn = self.visit(call.op) - new_args = [self.visit(arg) for arg in call.args] - call = Call(new_fn, new_args, call.attrs) - if new_fn == _op.get("nn.batch_norm"): - return call - return self.infer(call) - - def visit_var(self, var): - return self.infer(var) - - def visit_global_id(self, global_var): - return self.infer(global_var) - - def visit_if(self, ite): - return self.infer(If( - self.visit(ite.cond), - self.visit(ite.true_branch), - self.visit(ite.false_branch))) - - def visit_tuple(self, tup): - return Tuple([self.visit(field) for field in tup.fields]) - - def visit_tuple_getitem(self, op): - tuple_value = self.visit(op.tuple_value) - if not tuple_value.same_as(op.tuple_value): - return self.infer(TupleGetItem(tuple_value, op.index)) - return self.infer(op) - - def visit_global_var(self, gvar): - return self.infer(gvar) - - def visit_op(self, op): - return op - - def visit_constant(self, const): - return const - - def visit_constructor(self, con): - return con - - def visit_match(self, m): - return self.infer(Match( - self.visit(m.data), - [Clause(c.lhs, self.visit(c.rhs)) for c in m.clauses], - complete=m.complete)) - - def visit_ref_create(self, r): - return RefCreate(self.visit(r.value)) - - def visit_ref_write(self, r): - return RefWrite(self.visit(r.ref), self.visit(r.value)) - - def visit_ref_read(self, r): - return RefRead(self.visit(r.ref)) - - def from_onnx(self, graph, opset): + def freeze(self, func, params): + bind_map = {} + for name in params.keys(): + bind_map[self._nodes[name]] = _expr.const(params[name]) + body = _expr.bind(func.body, bind_map) + fn = _function.Function(analysis.free_vars(body), body) + return fn, {} + + def from_onnx(self, graph, opset, freeze_params=False): """Construct Relay expression from ONNX graph. Onnx graph is a python protobuf object. 
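A hedged sketch of how the new freeze_params flag introduced here is meant to be used from user code; the model path and input name/shape are placeholders, not values from this patch:

    import onnx
    from tvm import relay

    model = onnx.load("model.onnx")              # placeholder path
    shape_dict = {"input": (1, 3, 224, 224)}     # placeholder input name/shape

    # Default behaviour: initializers come back in `params` as free variables.
    mod, params = relay.frontend.from_onnx(model, shape_dict)

    # With freeze_params=True the importer binds the initializers as Relay
    # constants via GraphProto.freeze, so `params` is returned empty and any
    # shape-producing subgraphs can later be constant-folded (e.g. by the
    # DynamicToStatic pass added in this patch).
    mod, params = relay.frontend.from_onnx(model, shape_dict, freeze_params=True)
    assert params == {}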
@@ -2243,6 +2097,7 @@ def from_onnx(self, graph, opset):
                 else:
                     dtype = d_type
                 self._nodes[i_name] = new_var(i_name, shape=tshape, dtype=dtype)
+                self._inputs[i_name] = self._nodes[i_name]
         # get list of unsupported ops
         convert_map = _get_convert_map(opset)
         unsupported_ops = set()
@@ -2271,17 +2126,17 @@ def from_onnx(self, graph, opset):
                 # We should convert scalar integers to int32, to normalize.
                 array = self._parse_array(t_proto)
                 self._params[node.output[0]] = array
-                self._nodes[node.output[0]] = new_var(
-                    node.output[0],
-                    shape=list(t_proto.dims),
-                    dtype=array.dtype)
+                self._nodes[node.output[0]] = new_var(node.output[0],
+                                                      shape=list(t_proto.dims),
+                                                      dtype=array.dtype)
             else:
                 i_name = self._parse_value_proto(node)
+                node_output = self._fix_outputs(op_name, node.output)
                 attr['tvm_custom'] = {}
                 attr['tvm_custom']['name'] = i_name
+                attr['tvm_custom']['num_outputs'] = len(node_output)
 
                 op = self._convert_operator(op_name, inputs, attr, opset)
-                node_output = self._fix_outputs(op_name, node.output)
                 if not isinstance(op, _expr.TupleWrapper):
                     outputs_num = 1
                 else:
@@ -2298,7 +2153,18 @@ def from_onnx(self, graph, opset):
         # now return the outputs
         outputs = [self._nodes[self._parse_value_proto(i)] for i in graph.output]
         outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs)
-        func = _function.Function(analysis.free_vars(outputs), outputs)
+        ## Maintain the order of inputs and parameters from the ONNX graph, but only include
+        ## those parameters that are needed to execute the relay graph
+        free_vars = analysis.free_vars(outputs)
+        nodes = {v: k for k, v in self._nodes.items()}
+        free_vars = [nodes[var] for var in free_vars]
+        for i_name in self._params:
+            if i_name in free_vars and i_name not in self._inputs:
+                self._inputs[i_name] = self._nodes[i_name]
+        func = _function.Function([v for k, v in self._inputs.items()], outputs)
+        if freeze_params:
+            func, params = self.freeze(func, self._params)
+            return IRModule.from_expr(func), params
         return IRModule.from_expr(func), self._params
 
     def _parse_value_proto(self, value_proto):
@@ -2341,21 +2207,15 @@ def _parse_attr(self, attr_proto):
                 attrs[a.name] = tuple(getattr(a, f))
             for f in ['g']:
                 if a.HasField(f):
-                    raise NotImplementedError(
-                        "Filed {} is not supported in relay.".format(f))
+                    raise NotImplementedError("Field {} is not supported in relay.".format(f))
             for f in ['graphs']:
                 if list(getattr(a, f)):
-                    raise NotImplementedError(
-                        "Filed {} is not supported in relay.".format(f))
+                    raise NotImplementedError("Field {} is not supported in relay.".format(f))
             if a.name not in attrs:
                 raise ValueError("Cannot parse attribute: \n{}\n.".format(a))
         return attrs
 
-    def _convert_operator(self,
-                          op_name,
-                          inputs,
-                          attrs,
-                          opset):
+    def _convert_operator(self, op_name, inputs, attrs, opset):
         """Convert ONNX operator into a Relay operator.
         The converter must specify conversions explicitly for
         incompatible name, and apply handlers to operator attributes.
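The num_outputs entry recorded in tvm_custom above is what lets the Split converter earlier in this patch drop its infer_shape call: when the ONNX node omits the split attribute, handing the node's output count straight to relay.split already yields an even split. A small sketch of that Relay behaviour (tensor shapes are illustrative):

    from tvm import relay

    x = relay.var("x", shape=(6, 4), dtype="float32")
    # An integer indices_or_sections splits the axis into that many equal
    # parts, matching ONNX Split with no explicit `split` attribute.
    parts = relay.split(x, indices_or_sections=3, axis=0)
    print(relay.Function([x], parts.astuple()))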
@@ -2382,8 +2242,7 @@ def _convert_operator(self, elif op_name in convert_map: sym = convert_map[op_name](inputs, attrs, self._params) else: - raise NotImplementedError( - "Operator {} not implemented.".format(op_name)) + raise NotImplementedError("Operator {} not implemented.".format(op_name)) return sym def _fix_outputs(self, op_name, outputs): @@ -2397,10 +2256,8 @@ def _fix_outputs(self, op_name, outputs): outputs = outputs[:-1] return outputs -def from_onnx(model, - shape=None, - dtype="float32", - opset=None): + +def from_onnx(model, shape=None, dtype="float32", opset=None, freeze_params=False): """Convert a ONNX model into an equivalent Relay Function. ONNX graphs are represented as Python Protobuf objects. @@ -2445,7 +2302,6 @@ def from_onnx(model, warnings.warn(str(e)) except ImportError: pass - global g g = GraphProto(shape, dtype) graph = model.graph if opset is None: @@ -2453,6 +2309,5 @@ def from_onnx(model, opset = model.opset_import[0].version if model.opset_import else 1 except AttributeError: opset = 1 - mod, params = g.from_onnx(graph, opset) - g = None + mod, params = g.from_onnx(graph, opset, freeze_params) return mod, params diff --git a/python/tvm/relay/op/_tensor.py b/python/tvm/relay/op/_tensor.py index c81d4c51c502..9f04d011b6cb 100644 --- a/python/tvm/relay/op/_tensor.py +++ b/python/tvm/relay/op/_tensor.py @@ -215,6 +215,7 @@ def elemwise_shape_func(attrs, inputs, _): register_shape_func("multiply", False, broadcast_shape_func) register_shape_func("divide", False, broadcast_shape_func) register_shape_func("floor_divide", False, broadcast_shape_func) +register_shape_func("power", False, broadcast_shape_func) register_shape_func("mod", False, broadcast_shape_func) register_shape_func("floor_mod", False, broadcast_shape_func) register_shape_func("logical_and", False, broadcast_shape_func) diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index 43fca6d5f80f..0fa742700637 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -811,6 +811,23 @@ def dense_shape_func(attrs, inputs, _): ret = [_dense_shape_func(inputs[0], inputs[1])] return ret +@script +def _batch_matmul_shape_func(data_shape, weight_shape): + out = output_tensor((data_shape.shape[0],), "int64") + for i in const_range(out.shape[0] - 1): + out[i] = data_shape[i] + out[out.shape[0] - 1] = weight_shape[weight_shape.shape[0] - 2] + + return out + +@reg.register_shape_func("nn.batch_matmul", False) +def batch_matmul_shape_func(attrs, inputs, _): + """ + Shape function for dense op. 
+ """ + ret = [_batch_matmul_shape_func(inputs[0], inputs[1])] + return ret + @script def _pad_shape_func(data_shape, pad_width): out = output_tensor((data_shape.shape[0],), "int64") diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index 69c9bd7caeff..879034684cc9 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -593,7 +593,7 @@ def dense_strategy(attrs, inputs, out_type, target): def wrap_compute_batch_matmul(topi_compute): """wrap batch_matmul topi compute""" def _compute_batch_matmul(attrs, inputs, out_type): - return [topi_compute(inputs[0], inputs[1])] + return [topi_compute(inputs[0], inputs[1], out_type.shape)] return _compute_batch_matmul @override_native_generic_func("batch_matmul_strategy") diff --git a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py index d30b6a43984f..7a10bff403f6 100644 --- a/python/tvm/relay/op/strategy/x86.py +++ b/python/tvm/relay/op/strategy/x86.py @@ -21,6 +21,7 @@ import re from tvm import topi from tvm.te import SpecializedCondition +from tvm.relay.ty import is_dynamic from .generic import * from .. import op as _op @@ -305,10 +306,16 @@ def dense_strategy_cpu(attrs, inputs, out_type, target): def batch_matmul_strategy_cpu(attrs, inputs, out_type, target): """batch_matmul x86 strategy""" strategy = _op.OpStrategy() - strategy.add_implementation(wrap_compute_batch_matmul(topi.x86.batch_matmul), - wrap_topi_schedule(topi.x86.schedule_batch_matmul), - name="batch_matmul.x86", - plevel=10) + if is_dynamic(out_type): + strategy.add_implementation(wrap_compute_batch_matmul(topi.nn.batch_matmul), + wrap_topi_schedule(topi.generic.nn.schedule_batch_matmul), + name="batch_matmul.generic", + plevel=10) + else: + strategy.add_implementation(wrap_compute_batch_matmul(topi.x86.batch_matmul), + wrap_topi_schedule(topi.x86.schedule_batch_matmul), + name="batch_matmul.x86", + plevel=10) if "cblas" in target.libs: strategy.add_implementation(wrap_compute_batch_matmul(topi.x86.batch_matmul_cblas), wrap_topi_schedule(topi.x86.schedule_batch_matmul_cblas), diff --git a/python/tvm/topi/cuda/batch_matmul.py b/python/tvm/topi/cuda/batch_matmul.py index 373b1ec10111..1f6a9a293096 100644 --- a/python/tvm/topi/cuda/batch_matmul.py +++ b/python/tvm/topi/cuda/batch_matmul.py @@ -25,7 +25,7 @@ from ..util import traverse_inline, get_const_tuple, get_max_power2_factor @autotvm.register_topi_compute("batch_matmul.cuda") -def batch_matmul(cfg, x, y): +def batch_matmul(cfg, x, y, out_shape=None): """Compute conv2d with NCHW layout""" return nn.batch_matmul(x, y) diff --git a/python/tvm/topi/nn/batch_matmul.py b/python/tvm/topi/nn/batch_matmul.py index 0d9f3510d097..6a41504e0972 100644 --- a/python/tvm/topi/nn/batch_matmul.py +++ b/python/tvm/topi/nn/batch_matmul.py @@ -19,7 +19,7 @@ from tvm import te from ..util import get_const_tuple -def batch_matmul(x, y): +def batch_matmul(x, y, oshape=None): """Computes batch matrix multiplication of `x` and `y` when `x` and `y` are data in batch. 
@@ -36,14 +36,19 @@ def batch_matmul(x, y): output : tvm.te.Tensor 3-D with shape [batch, M, N] """ - assert len(x.shape) == 3 and len(y.shape) == 3, "only support 3-dim batch_matmul" - x_shape = get_const_tuple(x.shape) - y_shape = get_const_tuple(y.shape) - assert x_shape[0] == y_shape[0], "batch dimension doesn't match" - assert x_shape[2] == y_shape[2], "shapes of x and y is inconsistant" - batch, M, K = x.shape - N = y.shape[1] - k = te.reduce_axis((0, K), name='k') - return te.compute((batch, M, N), + if oshape is None: + assert len(x.shape) == 3 and len(y.shape) == 3, "only support 3-dim batch_matmul" + x_shape = get_const_tuple(x.shape) + y_shape = get_const_tuple(y.shape) + assert x_shape[0] == y_shape[0], "batch dimension doesn't match" + assert x_shape[2] == y_shape[2], "shapes of x and y is inconsistant" + batch, M, K = x.shape + N = y.shape[1] + k = te.reduce_axis((0, K), name='k') + oshape = (batch, M, N) + else: + _, _, K = x.shape + k = te.reduce_axis((0, K), name='k') + return te.compute(oshape, lambda b, i, j: te.sum(x[b, i, k] * y[b, j, k], axis=k), tag='batch_matmul') diff --git a/python/tvm/topi/x86/batch_matmul.py b/python/tvm/topi/x86/batch_matmul.py index 539a918f1f87..a2dccb6d9489 100644 --- a/python/tvm/topi/x86/batch_matmul.py +++ b/python/tvm/topi/x86/batch_matmul.py @@ -25,7 +25,7 @@ @autotvm.register_topi_compute("batch_matmul.x86") -def batch_matmul(cfg, x, y): +def batch_matmul(cfg, x, y, out_shape=None): """Computes batch matrix multiplication of `x` and `y` when `x` and `y` are data in batch. diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index 533619ec8a19..fdca7dfdf0a8 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -254,6 +254,7 @@ class RelayBuildModule : public runtime::ModuleNode { Array entry_functions{"main"}; pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions)); pass_seqs.push_back(transform::ToBasicBlockNormalForm()); + pass_seqs.push_back(transform::DynamicToStatic()); // Run all dialect legalization passes. 
pass_seqs.push_back(relay::qnn::transform::Legalize()); diff --git a/src/relay/op/dyn/tensor/transform.cc b/src/relay/op/dyn/tensor/transform.cc index de1cc5a4ed95..4b594ffccfa5 100644 --- a/src/relay/op/dyn/tensor/transform.cc +++ b/src/relay/op/dyn/tensor/transform.cc @@ -58,6 +58,11 @@ bool ReshapeRel(const Array& types, int num_inputs, const Attrs& attrs, Array oshape; const auto* newshape = types[1].as(); + if (newshape == nullptr) { + CHECK(types[1].as()) + << "reshape: expect input type to be TensorType but get " << types[1]; + return false; + } // Doesn't support dynamic output rank for (int i = 0; i < newshape->shape[0].as()->value; i++) { @@ -209,10 +214,17 @@ bool BroadCastToRel(const Array& types, int num_inputs, const Attrs& attrs // types = [data_type, broadcast_shape_type, ret_type] CHECK_EQ(types.size(), 3); - const auto* target_shape = types[1].as(); - DataType out_dtype = types[0].as()->dtype; + const auto* input_type = types[0].as(); + const auto* target_type = types[1].as(); + if (target_type == nullptr) { + return false; + } + if (input_type == nullptr) { + return false; + } + auto out_dtype = input_type->dtype; // rank must be static - const IntImmNode* rank = target_shape->shape[0].as(); + const IntImmNode* rank = target_type->shape[0].as(); CHECK(rank) << "Target shape must have static rank"; // rank must be static even in dyn pass // could add support for dyn rank in futures diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc index 19348c018dbf..3f1d2ba9a6da 100644 --- a/src/relay/op/nn/nn.cc +++ b/src/relay/op/nn/nn.cc @@ -851,15 +851,26 @@ bool BatchMatmulRel(const Array& types, int num_inputs, const Attrs& attrs const auto* y = types[1].as(); if (x == nullptr || y == nullptr) return false; CHECK(x->shape.size() == 3 && y->shape.size() == 3); - CHECK(reporter->AssertEQ(x->shape[0], y->shape[0])) - << "BatchDot: batch dimension doesn't match, " - << " x shape=" << x->shape << ", y shape=" << y->shape; - CHECK(reporter->AssertEQ(x->shape[2], y->shape[2])) - << "BatchDot: shapes of x and y is inconsistent, " - << " x shape=" << x->shape << ", y shape=" << y->shape; - - Array oshape = x->shape; - oshape.Set(2, y->shape[1]); + bool is_dyn = false; + Array oshape; + for (size_t i = 0; i < 3; ++i) { + if (x->shape[i].as() != nullptr || y->shape[i].as() != nullptr) { + is_dyn = true; + oshape.push_back(Any()); + } else { + oshape.push_back(x->shape[i]); + } + } + if (!is_dyn) { + CHECK(reporter->AssertEQ(x->shape[0], y->shape[0])) + << "BatchDot: batch dimension doesn't match, " + << " x shape=" << x->shape << ", y shape=" << y->shape; + CHECK(reporter->AssertEQ(x->shape[2], y->shape[2])) + << "BatchDot: shapes of x and y is inconsistent, " + << " x shape=" << x->shape << ", y shape=" << y->shape; + + oshape.Set(2, y->shape[1]); + } // assign output type reporter->Assign(types[2], TensorType(oshape, x->dtype)); diff --git a/src/relay/op/nn/nn.h b/src/relay/op/nn/nn.h index 0fb02638db07..e7f5a4b9d618 100644 --- a/src/relay/op/nn/nn.h +++ b/src/relay/op/nn/nn.h @@ -63,9 +63,11 @@ bool DenseRel(const Array& types, int num_inputs, const Attrs& attrs, if (weight == nullptr) return false; Array wshape = weight->shape; CHECK(static_cast(weight->shape.size()) == 2); - CHECK(reporter->AssertEQ(data->shape[data->shape.size() - 1], weight->shape[1])) - << "DenseRel: input dimension doesn't match," - << " data shape=" << data->shape << ", weight shape=" << weight->shape; + if (!data->shape.back().as()) { + CHECK(reporter->AssertEQ(data->shape[data->shape.size() - 1], 
weight->shape[1])) + << "DenseRel: input dimension doesn't match," + << " data shape=" << data->shape << ", weight shape=" << weight->shape; + } oshape.Set((oshape.size() - 1), wshape[0]); } diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 293875ebf6ea..acc19768d348 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -1783,9 +1783,9 @@ bool SqueezeRel(const Array& types, int num_inputs, const Attrs& attrs, if (p.second) { result_shape.push_back(p.first); } else { - const int64_t* axis_ptr = tir::as_const_int(p.first); - CHECK(axis_ptr != nullptr) << "cannot get concrete shape of input tensor"; - CHECK_EQ(*axis_ptr, 1) << "cannot squeeze axis with dimension not equal to 1"; + if (const int64_t* axis_ptr = tir::as_const_int(p.first)) { + CHECK_EQ(*axis_ptr, 1) << "cannot squeeze axis with dimension not equal to 1"; + } } } } @@ -1989,9 +1989,13 @@ bool StridedSliceRel(const Array& types, int num_inputs, const Attrs& attr const TypeReporter& reporter) { CHECK_EQ(types.size(), 2); const StridedSliceAttrs* param = attrs.as(); - CHECK(param != nullptr); + if (param == nullptr) { + return false; + } const auto* data = types[0].as(); - CHECK(data != nullptr); + if (data == nullptr) { + return false; + } auto dshape = data->shape; int64_t num_axis = dshape.size(); diff --git a/src/relay/transforms/dynamic_to_static.cc b/src/relay/transforms/dynamic_to_static.cc index 113b599579ab..07358f1955fb 100644 --- a/src/relay/transforms/dynamic_to_static.cc +++ b/src/relay/transforms/dynamic_to_static.cc @@ -236,13 +236,13 @@ Expr DynamicToStatic(Function f, IRModule m) { expr = mutator.Mutate(m->functions[gv]); m->Update(gv, Downcast(expr)); i += 1; - } while (pre != expr && i < 1000); + } while (!StructuralEqual()(pre, expr) && i < 1000); return expr; } namespace transform { -Pass ConvertDynamicToStatic() { +Pass DynamicToStatic() { runtime::TypedPackedFunc pass_func = [=](Function f, IRModule m, PassContext pc) { return Downcast(DynamicToStatic(f, m)); @@ -251,7 +251,7 @@ Pass ConvertDynamicToStatic() { } TVM_REGISTER_GLOBAL("relay._transform.DynamicToStatic").set_body_typed([]() { - return ConvertDynamicToStatic(); + return DynamicToStatic(); }); } // namespace transform diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 394c7458cc05..12a49e280e29 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -44,17 +44,19 @@ def get_input_data_shape_dict(graph_def, input_data): return input_names, shape_dict -def get_tvm_output_with_vm(graph_def, input_data, target, ctx, opset=None): +def get_tvm_output_with_vm(graph_def, input_data, target, ctx, opset=None, freeze_params=False): """ Generic function to execute and get tvm output with vm executor""" + if not isinstance(input_data, list): + input_data = [input_data] + input_names, shape_dict = get_input_data_shape_dict(graph_def, input_data) - _, shape_dict = get_input_data_shape_dict(graph_def, input_data) - - mod, params = relay.frontend.from_onnx(graph_def, shape_dict, opset=opset) + mod, params = relay.frontend.from_onnx(graph_def, shape_dict, opset=opset, freeze_params=freeze_params) ex = relay.create_executor('vm', mod=mod, ctx=ctx, target=target) - indata = tvm.nd.array(input_data) - result = ex.evaluate()(indata) - return result.asnumpy() + result = ex.evaluate()(*input_data) + if isinstance(result, tvm.runtime.NDArray): + return result.asnumpy() + return [r.asnumpy() for r 
in result] def get_tvm_output(graph_def, input_data, target, ctx, output_shape=None, output_dtype='float32', opset=None): @@ -161,7 +163,8 @@ def test_reshape(): tvm.testing.assert_allclose(ref_shape, tvm_out.shape) -@tvm.testing.uses_gpu +# TODO(mbrookhart): enable once VM supports heterogenous execution +# @tvm.testing.uses_gpu def test_expand(): def _test_expand(name, data, shape, ref_data): @@ -185,7 +188,7 @@ def _test_expand(name, data, shape, ref_data): model = helper.make_model(graph, producer_name=name) for target, ctx in tvm.testing.enabled_targets(): - tvm_out = get_tvm_output(model, data, target, ctx, ref_data.shape, 'float32') + tvm_out = get_tvm_output_with_vm(model, data, target, ctx, freeze_params=True) tvm.testing.assert_allclose(ref_data, tvm_out) @@ -738,13 +741,14 @@ def test_gather_nd(): verify_gather_nd((4, 3, 5, 6), [[2, 1, 0, 0]], 'float32') -@tvm.testing.uses_gpu +# TODO(mbrookhart): enable once VM supports heterogenous execution +# @tvm.testing.uses_gpu def test_onehot(): indices_shape = [10] indices_array = np.random.randint( low=0, high=9, size=indices_shape, dtype='int32') depth = 10 - values = np.asarray([0, 1]) + values = np.asarray([0, 1]).astype("int32") out_np = np.eye(depth)[indices_array.reshape(-1)] onehot_node = helper.make_node( @@ -758,15 +762,13 @@ def test_onehot(): TensorProto.INT32, [1]), helper.make_tensor_value_info("values", TensorProto.INT32, values.shape)], - initializer=[helper.make_tensor("depth", TensorProto.INT32, [1], [depth]), - helper.make_tensor("values", TensorProto.INT32, values.shape, values)], outputs=[helper.make_tensor_value_info("out", TensorProto.INT32, out_np.shape)]) model = helper.make_model(graph, producer_name="onehot_test") for target, ctx in tvm.testing.enabled_targets(): - tvm_out = get_tvm_output( - model, [indices_array], target, ctx, out_np.shape) + tvm_out = get_tvm_output_with_vm( + model, [indices_array, np.array([depth]).astype("int32"), values], target, ctx) tvm.testing.assert_allclose(out_np, tvm_out, rtol=1e-5, atol=1e-5) @@ -816,11 +818,12 @@ def verify_batch_matmul(a_shape, b_shape): model = helper.make_model(graph, producer_name='matmul_test') for target, ctx in tvm.testing.enabled_targets(): - tvm_out = get_tvm_output( - model, [a_array, b_array], target, ctx, out_np.shape) + tvm_out = get_tvm_output_with_vm( + model, [a_array, b_array], target, ctx) tvm.testing.assert_allclose(out_np, tvm_out, rtol=1e-5, atol=1e-5) -@tvm.testing.uses_gpu +# TODO(mbrookhart): enable once VM supports heterogenous execution +# @tvm.testing.uses_gpu def test_batch_matmul(): verify_batch_matmul((2, 3, 4, 3), (2, 3, 3, 4)) verify_batch_matmul((2, 4, 3), (3, 4)) @@ -1367,21 +1370,19 @@ def verify_constantofshape(input_dim, value, dtype): outputs=[ helper.make_tensor_value_info("output", TensorProto.FLOAT, list(out.shape)) - ], - initializer=[ - helper.make_tensor("input", TensorProto.INT32, (len(input_dim), ), - input_dim) ]) model = helper.make_model(graph, producer_name='fill_test') for target, ctx in tvm.testing.enabled_targets(): - tvm_out = get_tvm_output(model, [], target, ctx, out.shape) + input_np = np.array(input_dim).astype("float32") + tvm_out = get_tvm_output_with_vm(model, [input_np], target, ctx) tvm.testing.assert_allclose(out, tvm_out, rtol=1e-5, atol=1e-5) -@tvm.testing.uses_gpu +# TODO(mbrookhart): enable once VM supports heterogenous execution +# @tvm.testing.uses_gpu def test_constantofshape(): verify_constantofshape((2, 3, 4, 5), 10, 'float32') verify_constantofshape((3, 3), 0, 'int32') @@ -1577,20 
+1578,28 @@ def test_all_reduce_funcs(): axis=(1,), keepdims=keepdims) -def verify_split(indata, outdatas, split, axis=0): +def verify_split(indata, outdatas, split, axis=0, pass_split=True): indata = np.array(indata).astype(np.float32) outdatas = [np.array(o).astype(np.float32) for o in outdatas] if split: split_index = range(len(split)) else: split_index = range(len(outdatas)) - node = helper.make_node( - 'Split', - inputs=['input'], - outputs=['output_{}'.format(i) for i in range(len(split_index))], - axis=axis, - split=split - ) + if pass_split: + node = helper.make_node( + 'Split', + inputs=['input'], + outputs=['output_{}'.format(i) for i in range(len(split_index))], + axis=axis, + split=split + ) + else: + node = helper.make_node( + 'Split', + inputs=['input'], + outputs=['output_{}'.format(i) for i in range(len(split_index))], + axis=axis, + ) graph = helper.make_graph([node], 'split_test', inputs=[helper.make_tensor_value_info("input", @@ -1601,13 +1610,17 @@ def verify_split(indata, outdatas, split, axis=0): ]) model = helper.make_model(graph, producer_name='split_test') + import onnxruntime.backend + rep = onnxruntime.backend.prepare(model, 'CPU') + onnx_out = rep.run(indata) + for target, ctx in tvm.testing.enabled_targets(): output_shape = [o.shape for o in outdatas] output_type = ['float32', 'float32', 'float32'] tvm_out = get_tvm_output( model, indata, target, ctx, output_shape, output_type) - for o, t in zip(outdatas, tvm_out): - tvm.testing.assert_allclose(o, t) + for o, t in zip(onnx_out, tvm_out): + tvm.testing.assert_allclose(o, t) @tvm.testing.uses_gpu @@ -1615,13 +1628,15 @@ def test_split(): # 1D verify_split([1., 2., 3., 4., 5., 6.], [ [1., 2.], [3., 4.], [5., 6.]], [2, 2, 2], 0) + verify_split([1., 2., 3., 4., 5., 6.], [ + [1., 2.], [3., 4.], [5., 6.]], [2, 2, 2], 0, False) verify_split([1., 2., 3., 4., 5., 6.], [ [1., 2.], [3.], [4., 5., 6.]], [2, 1, 3], 0) # 2D verify_split([[1., 2., 3., 4.], [7., 8., 9., 10.]], [[[1., 2.], [7., 8.]], [[3., 4.], [9., 10.]]], [2, 2], 1) # Split evenly (unstack) - verify_split([1, 2, 3], [[1], [2], [3]], False) + verify_split([1, 2, 3], [[1], [2], [3]], False, 0, False) @tvm.testing.uses_gpu @@ -2008,24 +2023,19 @@ def verify_tile_v6(indata, repeats, outdata): outputs=[ helper.make_tensor_value_info("out", TensorProto.FLOAT, list(outdata.shape)) - ], - initializer=[ - helper.make_tensor("repeats", TensorProto.INT64, - list(repeats.shape), repeats) ]) model = helper.make_model(graph, producer_name='tile_test') for target, ctx in tvm.testing.enabled_targets(): - tvm_out = get_tvm_output(model, [indata], + tvm_out = get_tvm_output_with_vm(model, [indata, repeats], target, ctx, - outdata.shape, opset=6) tvm.testing.assert_allclose(outdata, tvm_out) - -@tvm.testing.uses_gpu +# TODO(mbrookhart): enable once VM supports heterogenous execution +# @tvm.testing.uses_gpu def test_tile(): x = np.random.rand(2, 3, 4, 5).astype(np.float32) repeats = np.random.randint( @@ -2196,9 +2206,11 @@ def verify_batch_norm(in_shape): verify_batch_norm([16, 16, 10, 10]) -@tvm.testing.uses_gpu +# TODO(mbrookhart): enable once VM supports heterogenous execution +# @tvm.testing.uses_gpu def test_batch_norm_dynamic_subgraph(): def verify_batch_norm_dynamic_subgraph(in_shape, o_shape): + batchnorm = onnx.helper.make_node('BatchNormalization', inputs=["x", "scale", "B", "mean", "var"], outputs=['Y']) @@ -2233,7 +2245,7 @@ def verify_batch_norm_dynamic_subgraph(in_shape, o_shape): mean = np.random.uniform(size=in_shape[1]).astype('float32') var = 
np.random.uniform(size=in_shape[1]).astype('float32') onnx_out = get_onnxruntime_output(model, [x, inp, scale, b, mean, var], 'float32')[0] - tvm_out = get_tvm_output(model, [x, inp, scale, b, mean, var], target, ctx, in_shape, 'float32') + tvm_out = get_tvm_output_with_vm(model, [x, inp, scale, b, mean, var], target, ctx) tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-5, atol=1e-5) verify_batch_norm_dynamic_subgraph([16, 16, 10, 10], [160, 160]) @@ -3043,7 +3055,8 @@ def test_gru(): rnn_type='GRU') -@tvm.testing.uses_gpu +# TODO(mbrookhart): enable once VM supports heterogenous execution +# @tvm.testing.uses_gpu def test_resize(): def verify(ishape, oshape, scales, mode, coord_trans): nodes = [ @@ -3064,7 +3077,6 @@ def verify(ishape, oshape, scales, mode, coord_trans): if oshape == []: oshape = [round(dim * scale) for (dim, scale) in zip(ishape, scales)] - graph = helper.make_graph(nodes, "resize_test", inputs=[helper.make_tensor_value_info("X", TensorProto.FLOAT, ishape)], @@ -3075,7 +3087,7 @@ def verify(ishape, oshape, scales, mode, coord_trans): for target, ctx in tvm.testing.enabled_targets(): x = np.random.uniform(size=ishape).astype('float32') onnx_out = get_onnxruntime_output(model, x, 'float32') - tvm_out = get_tvm_output(model, x, target, ctx, oshape, 'float32', opset=11) + tvm_out = get_tvm_output_with_vm(model, x, target, ctx, opset=11, freeze_params=True) tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-05, atol=1e-05) @@ -3136,18 +3148,16 @@ def verify_topk(input_dims, K, axis=-1): "topk_test", inputs=[helper.make_tensor_value_info("X", TensorProto.FLOAT, list(input_dims)), helper.make_tensor_value_info("K", TensorProto.INT64, [1,])], - initializer=[helper.make_tensor("K", TensorProto.INT64, [1], [K])], outputs=[helper.make_tensor_value_info("Values", TensorProto.FLOAT, output_dims), helper.make_tensor_value_info("Indicies", TensorProto.INT64, output_dims)]) model = helper.make_model(graph, producer_name='topk_test') indata = np.random.uniform(-10, 10, input_dims).astype(np.float32) - onnx_out = get_onnxruntime_output(model, [indata, k]) + onnx_out = get_onnxruntime_output(model, [indata, np.array([K])]) for target, ctx in [('llvm', tvm.cpu())]: - tvm_out = get_tvm_output(model, indata, target, ctx, [output_dims, output_dims], - output_dtype=['float32', 'int64']) + tvm_out = get_tvm_output_with_vm(model, [indata, np.array(K)], target, ctx) tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-05, atol=1e-05) for n in [12, 32]: diff --git a/tests/python/relay/dyn/test_dynamic_op_level10.py b/tests/python/relay/dyn/test_dynamic_op_level10.py index e3c8c9eb0bea..8f8eb2ecbc9e 100644 --- a/tests/python/relay/dyn/test_dynamic_op_level10.py +++ b/tests/python/relay/dyn/test_dynamic_op_level10.py @@ -27,33 +27,59 @@ import random import tvm.testing -# TODO(mbrookhart): Enable when VM supports heterogenus execution +# TODO(mbrookhart): Enable when the VM supports heterogenus execution # @tvm.testing.uses_gpu -def test_dyn_broadcast_to(): - dtype = 'uint8' - rank = 3 - shape_type = 'int64' - dyn_shape = relay.Var("shape", relay.ty.TensorType((rank, ), shape_type)) - x_shape = (1, ) - x = relay.Var("x", relay.ty.TensorType(x_shape, dtype)) - z = relay.broadcast_to(x, dyn_shape) - zz = run_infer_type(z) - - assert zz.checked_type == relay.ty.TensorType((relay.Any(), ) * rank, dtype) - - func = relay.Function([x, dyn_shape], z) - - x = np.random.uniform(size=x_shape).astype(dtype) - dyn_shape = (1, ) * rank - ref_res = np.broadcast_to(x, dyn_shape) - for target, ctx in 
tvm.testing.enabled_targets(): - for kind in ["vm", "debug"]: - mod = tvm.ir.IRModule.from_expr(func) - intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) - op_res = intrp.evaluate(func)(x, np.array(dyn_shape).astype(shape_type)) - tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5) - -# TODO(mbrookhart): Enable when VM supports heterogenus execution +def test_broadcast_to(): + def verify_more_dynamic_broadcast_to(x_shape, out_shape): + rank = len(out_shape) + dtype = 'float32' + shape_type = 'int64' + reshape_shape = relay.Var("shape", relay.ty.TensorType((len(x_shape), ), shape_type)) + broadcast_shape = relay.Var("shape", relay.ty.TensorType((rank, ), shape_type)) + x = relay.Var("x", relay.ty.TensorType((np.prod(x_shape), ), dtype)) + r = relay.reshape(x, reshape_shape) + z = relay.broadcast_to(r, broadcast_shape) + + func = relay.Function([x, reshape_shape, broadcast_shape], z) + + x = np.random.uniform(size=np.prod(x_shape)).astype(dtype) + ref_res = np.broadcast_to(np.reshape(x, x_shape), out_shape) + for target, ctx in tvm.testing.enabled_targets(): + for kind in ["vm", "debug"]: + mod = tvm.ir.IRModule.from_expr(func) + intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) + op_res = intrp.evaluate(func)(x, np.array(x_shape).astype(shape_type), np.array(out_shape).astype(shape_type)) + tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5) + + verify_more_dynamic_broadcast_to((4, 3), (3, 4, 3)) + + def verify_broadcast_to(x_shape, out_shape): + rank = len(out_shape) + dtype = 'float32' + shape_type = 'int64' + dyn_shape = relay.Var("shape", relay.ty.TensorType((rank, ), shape_type)) + x = relay.Var("x", relay.ty.TensorType(x_shape, dtype)) + z = relay.broadcast_to(x, dyn_shape) + zz = run_infer_type(z) + + assert zz.checked_type == relay.ty.TensorType((relay.Any(), ) * rank, dtype) + + func = relay.Function([x, dyn_shape], z) + + x = np.random.uniform(size=x_shape).astype(dtype) + ref_res = np.broadcast_to(x, out_shape) + for target, ctx in tvm.testing.enabled_targets(): + for kind in ["vm", "debug"]: + mod = tvm.ir.IRModule.from_expr(func) + intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) + op_res = intrp.evaluate(func)(x, np.array(out_shape).astype(shape_type)) + tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5) + + verify_broadcast_to((1,), (1, 1, 1)) + verify_broadcast_to((1, 1), (4, 1, 1)) + verify_broadcast_to((4, 1), (1, 4, 3)) + +# TODO(mbrookhart): Enable when the VM supports heterogenus execution # @tvm.testing.uses_gpu def test_dyn_one_hot(): def _get_oshape(indices_shape, depth, axis): diff --git a/tests/python/relay/test_op_level10.py b/tests/python/relay/test_op_level10.py index 3aaa76d771d3..5f62669b0505 100644 --- a/tests/python/relay/test_op_level10.py +++ b/tests/python/relay/test_op_level10.py @@ -348,6 +348,32 @@ def test_batch_matmul(): verify_batch_matmul((5, 16, 32), (5, 20, 32), (5, 16, 20)) verify_batch_matmul((30, 16, 32), (30, 20, 32), (30, 16, 20)) +def verify_dynamic_batch_matmul(x_shape, y_shape, out_shape, dtype="float32"): + x = relay.var("x", relay.TensorType(x_shape, dtype)) + y = relay.var("y", relay.TensorType((relay.Any(), ) * len(y_shape), dtype)) + z = relay.nn.batch_matmul(x, y) + + func = relay.Function([x, y], z) + x_np = np.random.uniform(size=x_shape).astype(dtype) + y_np = np.random.uniform(size=y_shape).astype(dtype) + z_np = tvm.topi.testing.batch_matmul(x_np, y_np) + + for target, ctx in ctx_list(): + for kind in ["vm", "debug"]: + mod = 
tvm.ir.IRModule.from_expr(func) + intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) + z = intrp.evaluate()(x_np, y_np) + tvm.testing.assert_allclose(z.asnumpy(), z_np, rtol=1e-5) + +# TODO(mbrookhart): enable once VM supports heterogenous execution +# @tvm.testing.uses_gpu +def test_dynamic_batch_matmul(): + verify_dynamic_batch_matmul((1, 16, 32), (1, 16, 32), (1, 16, 16)) + verify_dynamic_batch_matmul((5, 16, 32), (5, 16, 32), (5, 16, 16)) + verify_dynamic_batch_matmul((5, 16, 32), (5, 20, 32), (5, 16, 20)) + verify_dynamic_batch_matmul((30, 16, 32), (30, 20, 32), (30, 16, 20)) + + @tvm.testing.uses_gpu def test_shape_of(): shape = (10, 5, 12) From c1e993c876db7de73a08cb651385490d3c0ef151 Mon Sep 17 00:00:00 2001 From: Lily Orth-Smith Date: Thu, 27 Aug 2020 13:53:05 -0700 Subject: [PATCH 02/17] Dynamic ONNX importer: Upsampling and Pad (#2) fix lint fix Call reference fix a type issue with expand fix a bad test refactor respond to review comments, fix batch matmul tests --- include/tvm/relay/transform.h | 2 +- python/tvm/relay/frontend/onnx.py | 127 +++++++++++++-------- python/tvm/topi/x86/batch_matmul.py | 4 + tests/python/frontend/onnx/test_forward.py | 56 +++++---- tests/python/relay/test_op_level10.py | 2 +- 5 files changed, 117 insertions(+), 74 deletions(-) diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index 493c3e027b12..faa2698fdcbc 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -213,7 +213,7 @@ TVM_DLL Pass FastMath(); * * Searches the graph for dynamic ops. If the dynamic inputs to those ops are constants, it replaces * them with static ops and re-performs type inference and constant folding. The pass repeats - * istself until the graph stops changing or we run too many iterations. + * itself until the graph stops changing or we run too many iterations. * * \return The pass. */ diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index a663da046d63..8639c93a9b91 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -28,17 +28,9 @@ from .. import op as _op from .. 
import vision as _vision -from ..function import Function -from ..expr import Call, Let -from ..expr import If, Tuple, TupleGetItem -from ..expr import RefCreate, RefRead, RefWrite -from ..expr_functor import ExprFunctor -from ..adt import Match, Clause -from ..op.tensor import minimum as _minimum, maximum as _maximum - from .common import AttrCvt, Renamer from .common import get_relay_op, new_var, infer_shape, infer_channels -from .common import infer_type, get_name, infer_value, infer_value_simulated +from .common import infer_type, get_name, infer_value_simulated __all__ = ['from_onnx'] @@ -642,26 +634,22 @@ def _impl_v2(cls, inputs, attr, params): @classmethod def _impl_v11(cls, inputs, attr, params): - pad_width = [] - pads = infer_value_simulated(inputs[1], params).asnumpy() + pads = inputs[1] if len(inputs) == 3: - value = infer_value_simulated(inputs[2], params).asnumpy().item() + value = _op.take(inputs[2], _op.const(0)) else: value = 0 - attr["pad_value"] = value - dims = int(len(pads) / 2) - for i in range(dims): - pad_width.append((pads[i], pads[i + dims])) - attr['pad_width'] = pad_width + + pads_shape = infer_shape(pads) + dims = int(pads_shape[0] / 2) + pad_width_expr = _op.transpose(_op.reshape(pads, (2, dims))) pad_mode = attr.get('mode', b'constant').decode('utf-8') - if pad_mode in ['constant', 'edge', 'reflect']: - attr['pad_mode'] = pad_mode - attr.pop('mode', None) - else: + + if not pad_mode in ['constant', 'edge', 'reflect']: raise tvm.error.OpAttributeInvalid('Value ' + pad_mode + ' in attribute "mode" is invalid for operator Pad.') - return AttrCvt('pad')(inputs[:1], attr, params) + return _op.nn.pad(inputs[0], pad_width_expr, value, pad_mode=pad_mode) class ParametricSoftPlus(OnnxOpConverter): @@ -869,17 +857,24 @@ class Upsample(OnnxOpConverter): @classmethod def _impl_v9(cls, inputs, attr, params): scales = attr.get('scales') + + input_shape = infer_shape(inputs[0]) + dims = len(input_shape) + if not scales: #Here we are going to higher OPSET version. 
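[Editor's note] The opset-11 Pad conversion above no longer evaluates the pads tensor at import time; it rearranges the ONNX pad vector into relay's per-axis pad_width layout symbolically via _op.transpose(_op.reshape(pads, (2, dims))). A minimal NumPy sketch of that layout transformation, with hypothetical pad values for a rank-2 tensor (illustration only, not part of the patch):

import numpy as np

# ONNX order: [x1_begin, x2_begin, ..., x1_end, x2_end, ...]
pads = np.array([1, 2, 3, 4])           # hypothetical pads for a rank-2 input
dims = pads.shape[0] // 2
# reshape to (2, dims) then transpose, mirroring the converter above
pad_width = pads.reshape(2, dims).T
print(pad_width.tolist())               # [[1, 3], [2, 4]] == [(before, after) per axis]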
- assert len(inputs) == 2, "Upsample op take 2 inputs, {} given".format(len(inputs)) + assert len(inputs) == 2, "Upsample op takes 2 inputs, {} given".format(len(inputs)) + if get_name(inputs[1]) in params: scales = params[inputs[1].name_hint].asnumpy() - else: + elif dims == 5: scales = infer_value_simulated(inputs[1], params).asnumpy() - inputs = inputs[:1] - assert scales[0] == 1.0 and scales[1] == 1.0 - input_shape = infer_shape(inputs[0]) - dims = len(input_shape) + else: + scales = inputs[1] + + if not isinstance(scales, _expr.Call): + assert scales[0] == 1.0 and scales[1] == 1.0 + mode = attr.get('mode') if mode == b'nearest': method = "nearest_neighbor" @@ -888,21 +883,41 @@ def _impl_v9(cls, inputs, attr, params): else: raise tvm.error.OpAttributeInvalid( 'Value {} in attribute "mode" of operator Upsample is not valid.'.format(mode)) - attr = {'scale_h': scales[-2], 'scale_w': scales[-1], 'method': method} + + if method == 'nearest_neighbor': + align_corners = False + else: + align_corners = True + # in 3d case, we use the purely static op if dims == 5: - assert len(scales) == 5 - attr['scale_d'] = scales[-3] - attr['layout'] = 'NCDHW' - op_name = 'upsampling3d' + scale_h = scales[-2] + scale_w = scales[-1] + scale_d = scales[-3] + layout = 'NCDHW' + out = _op.nn.upsampling3d(inputs[0], + scale_d, + scale_h, + scale_w, + layout=layout, + method=method) + # in 2d case, use dynamic op else: - assert len(scales) == 4 - attr['layout'] = 'NCHW' - if method == 'nearest_neighbor': - attr['align_corners'] = False + if isinstance(scales, _expr.Call): + scale_h = _op.take(scales, _op.const(3)) + scale_w = _op.take(scales, _op.const(4)) else: - attr['align_corners'] = True - op_name = 'upsampling' - return AttrCvt(op_name)(inputs, attr) + assert len(scales) == 4 + scale_h = scales[-2] + scale_w = scales[-1] + layout = 'NCHW' + + out = _op.nn.upsampling(inputs[0], + scale_h, + scale_w, + layout=layout, + method=method, + align_corners=align_corners) + return out class Shape(OnnxOpConverter): @@ -1422,7 +1437,8 @@ class Expand(OnnxOpConverter): """ @classmethod def _impl_v8(cls, inputs, attr, params): - in_shape = _op.shape_of(inputs[0]) + dtype = infer_type(inputs[1]).checked_type.dtype + in_shape = _op.shape_of(inputs[0], dtype=dtype) shape = inputs[1] # Currently 'op.broadcast_to' expect the rank of the given 'shape' @@ -1441,14 +1457,11 @@ def expand_shape(in_shape, shape): in_dims = infer_shape(in_shape)[0] new_dims = infer_shape(shape)[0] if in_dims < new_dims: - in_shape = _op.concatenate([_expr.const([ - 1, - ] * (new_dims - in_dims)), in_shape], - axis=0) + in_shape = _op.concatenate([_expr.const([1, ] * (new_dims - in_dims), dtype=dtype), + in_shape], axis=0) elif new_dims > in_dims: - shape = _op.concatenate([_expr.const([ - 1, - ] * (in_dims - new_dims)), shape], axis=0) + shape = _op.concatenate([_expr.const([1, ] * (in_dims - new_dims), dtype=dtype), + shape], axis=0) new_shape = _op.maximum(in_shape, shape) return new_shape @@ -2058,6 +2071,13 @@ def from_onnx(self, graph, opset, freeze_params=False): opset : opset version + freeze_params: bool + If this parameter is true, the importer will take any provided + onnx input values (weights, shapes, etc) and embed them into the relay model + as Constants instead of variables. This allows more aggressive optimizations + at compile time and helps in making models static if certain inputs represent + attributes relay would traditionally consider compile-time constants. 
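[Editor's note] The same freeze_params flag is exposed on the public relay.frontend.from_onnx entry point, and the updated tests exercise it through the VM executor. A minimal usage sketch, assuming a placeholder model file and input name:

import onnx
import numpy as np
import tvm
from tvm import relay

model = onnx.load("model.onnx")                  # placeholder path
shape_dict = {"data": (1, 3, 224, 224)}          # placeholder input name/shape
# freeze_params embeds ONNX initializers as relay constants, so shape-carrying
# inputs can be folded and dynamic ops turned back into static ones.
mod, params = relay.frontend.from_onnx(model, shape_dict, freeze_params=True)

ex = relay.create_executor("vm", mod=mod, ctx=tvm.cpu(), target="llvm")
data = np.random.uniform(size=shape_dict["data"]).astype("float32")
result = ex.evaluate()(data).asnumpy()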
+ Returns ------- mod : tvm.IRModule @@ -2156,12 +2176,12 @@ def from_onnx(self, graph, opset, freeze_params=False): ## Maintain the order of inputs and parametersfrom the ONNX graph, but only include ## those parameters that are needed to execute the relay graph free_vars = analysis.free_vars(outputs) - nodes = {v:k for k,v in self._nodes.items()} + nodes = {v: k for k, v in self._nodes.items()} free_vars = [nodes[var] for var in free_vars] for i_name in self._params: if i_name in free_vars and i_name not in self._inputs: self._inputs[i_name] = self._nodes[i_name] - func = _function.Function([v for k,v in self._inputs.items()], outputs) + func = _function.Function([v for k, v in self._inputs.items()], outputs) if freeze_params: func, params = self.freeze(func, self._params) return IRModule.from_expr(func), params @@ -2282,6 +2302,13 @@ def from_onnx(model, shape=None, dtype="float32", opset=None, freeze_params=Fals Override to autodetected opset. This can be helpful for some testing. + freeze_params: bool + If this parameter is true, the importer will take any provided + onnx input values (weights, shapes, etc) and embed them into the relay model + as Constants instead of variables. This allows more aggressive optimizations + at compile time and helps in making models static if certain inputs represent + attributes relay would traditionally consider compile-time constants. + Returns ------- mod : tvm.IRModule diff --git a/python/tvm/topi/x86/batch_matmul.py b/python/tvm/topi/x86/batch_matmul.py index a2dccb6d9489..9f6a8f289c6f 100644 --- a/python/tvm/topi/x86/batch_matmul.py +++ b/python/tvm/topi/x86/batch_matmul.py @@ -50,6 +50,10 @@ def batch_matmul(cfg, x, y, out_shape=None): assert XK == YK, "shapes of x and y is inconsistant" B = XB K = XK + if out_shape is not None: + assert out_shape[0] == B, "got invalid output shape" + assert out_shape[1] == M, "got invalid output shape" + assert out_shape[2] == N, "got invalid output shape" if cfg.is_fallback: _default_batch_matmul_config(cfg, M, N, K) diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 12a49e280e29..d3074aa3bce3 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -48,7 +48,7 @@ def get_tvm_output_with_vm(graph_def, input_data, target, ctx, opset=None, freez """ Generic function to execute and get tvm output with vm executor""" if not isinstance(input_data, list): input_data = [input_data] - input_names, shape_dict = get_input_data_shape_dict(graph_def, input_data) + _, shape_dict = get_input_data_shape_dict(graph_def, input_data) mod, params = relay.frontend.from_onnx(graph_def, shape_dict, opset=opset, freeze_params=freeze_params) @@ -167,15 +167,26 @@ def test_reshape(): # @tvm.testing.uses_gpu def test_expand(): - def _test_expand(name, data, shape, ref_data): + def _test_expand(name, data, shape, ref_data, dtype="int32"): shape_array = np.array(shape) - shape_node = onnx.helper.make_node('Constant', - inputs=[], - outputs=['shape'], - value=onnx.helper.make_tensor(name = 'const_tensor', - data_type = onnx.TensorProto.INT32, - dims = shape_array.shape, - vals = shape_array.flatten().astype('int32'))) + if dtype == "int32": + shape_node = onnx.helper.make_node('Constant', + inputs=[], + outputs=['shape'], + value=onnx.helper.make_tensor(name = 'const_tensor', + data_type = onnx.TensorProto.INT32, + dims = shape_array.shape, + vals = shape_array.flatten().astype('int32'))) + elif dtype == "int64": + shape_node = 
onnx.helper.make_node('Constant', + inputs=[], + outputs=['shape'], + value=onnx.helper.make_tensor(name = 'const_tensor', + data_type = onnx.TensorProto.INT64, + dims = shape_array.shape, + vals = shape_array.flatten().astype('int64'))) + else: + raise "Invalid dtype" expand_node = helper.make_node("Expand", ["in", "shape"], ["out"]) graph = helper.make_graph([shape_node, expand_node], @@ -196,13 +207,15 @@ def _test_expand(name, data, shape, ref_data): shape = (3, 4) data = np.random.uniform(size=in_shape).astype(np.float32) ref_data = np.tile(data, 4) - _test_expand('expand_with_dim_unchanged_test', data, shape, ref_data) + _test_expand('expand_with_dim_unchanged_test', data, shape, ref_data, "int32") + _test_expand('expand_with_dim_unchanged_test', data, shape, ref_data, "int64") in_shape = (3, 1) shape = (2, 1, 6) data = np.random.uniform(size=in_shape).astype(np.float32) ref_data = data * np.ones(shape, dtype=np.float32) - _test_expand('expand_with_dim_changed_test', data, shape, ref_data) + _test_expand('expand_with_dim_changed_test', data, shape, ref_data, "int32") + _test_expand('expand_with_dim_changed_test', data, shape, ref_data, "int64") def verify_depth_to_space(inshape, outshape, mode, blockSize): @@ -822,8 +835,8 @@ def verify_batch_matmul(a_shape, b_shape): model, [a_array, b_array], target, ctx) tvm.testing.assert_allclose(out_np, tvm_out, rtol=1e-5, atol=1e-5) -# TODO(mbrookhart): enable once VM supports heterogenous execution -# @tvm.testing.uses_gpu +# TODO(mbrookhart): enable cuda once VM supports heterogenous execution +@tvm.testing.parametrize_targets("llvm") def test_batch_matmul(): verify_batch_matmul((2, 3, 4, 3), (2, 3, 3, 4)) verify_batch_matmul((2, 4, 3), (3, 4)) @@ -1024,11 +1037,9 @@ def _test_upsample_bilinear_opset9(): graph, producer_name='upsample_bilinear_opset9_test') for target, ctx in tvm.testing.enabled_targets(): - tvm_out = get_tvm_output( - model, in_array, target, ctx, out_shape, 'float32') + tvm_out = get_tvm_output_with_vm(model, [in_array], target, ctx, opset=9, freeze_params=True) tvm.testing.assert_allclose(out_array, tvm_out, rtol=1e-5, atol=1e-5) - def _test_upsample3d_trilinear(): scale = 2 in_shape = (1, 1, 3, 3, 3) @@ -1062,7 +1073,8 @@ def _test_upsample3d_trilinear(): model, in_array, target, ctx, out_shape, 'float32') tvm.testing.assert_allclose(out_array, tvm_out, rtol=1e-5, atol=1e-5) -@tvm.testing.uses_gpu +# TODO(mbrookhart): enable once VM supports heterogenous execution +# @tvm.testing.uses_gpu def test_upsample(): _test_upsample_nearest() _test_upsample_bilinear() @@ -1455,7 +1467,7 @@ def verify_pad_v11(indata, pads, mode='constant', value=0.0): outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, list(outdata.shape))]) else: - inputs = [indata, pads, np.array([value])] + inputs = [indata, pads, np.array([value]).astype("float32")] outdata = np.pad(indata, pad_width=np_pads, mode='constant', constant_values=value) node = helper.make_node( @@ -1471,7 +1483,7 @@ def verify_pad_v11(indata, pads, mode='constant', value=0.0): helper.make_tensor_value_info("pads", TensorProto.INT64,(len(pads),)), helper.make_tensor_value_info("constant_value", - TensorProto.INT64,(1,)), + TensorProto.FLOAT,(1,)), ], initializer=[helper.make_tensor("pads", TensorProto.INT64, (len(pads),), pads), helper.make_tensor("constant_value", TensorProto.FLOAT, (1,), [value])], @@ -1480,12 +1492,12 @@ def verify_pad_v11(indata, pads, mode='constant', value=0.0): model = helper.make_model(graph, producer_name='pad_test') # tvm result for target, 
ctx in tvm.testing.enabled_targets(): - tvm_out = get_tvm_output( - model, inputs, target, ctx, outdata.shape, 'float32', opset=11) + tvm_out = get_tvm_output_with_vm(model, inputs, target, ctx, opset=11, freeze_params=False) tvm.testing.assert_allclose(outdata, tvm_out, rtol=1e-5, atol=1e-5) -@tvm.testing.uses_gpu +# TODO(mbrookhart): enable once VM supports heterogenous execution +# @tvm.testing.uses_gpu def test_pad(): verify_pad(np.random.randn(2, 2).astype( np.float32), [0, 1, 0, 0], 'constant', 0.0) diff --git a/tests/python/relay/test_op_level10.py b/tests/python/relay/test_op_level10.py index 5f62669b0505..d9cd4d8d2b3e 100644 --- a/tests/python/relay/test_op_level10.py +++ b/tests/python/relay/test_op_level10.py @@ -358,7 +358,7 @@ def verify_dynamic_batch_matmul(x_shape, y_shape, out_shape, dtype="float32"): y_np = np.random.uniform(size=y_shape).astype(dtype) z_np = tvm.topi.testing.batch_matmul(x_np, y_np) - for target, ctx in ctx_list(): + for target, ctx in tvm.testing.enabled_targets(): for kind in ["vm", "debug"]: mod = tvm.ir.IRModule.from_expr(func) intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) From 80962b7e4c03ba189b6a623d2107c1a5d121e00b Mon Sep 17 00:00:00 2001 From: Lily Orth-Smith Date: Fri, 4 Sep 2020 14:18:57 -0700 Subject: [PATCH 03/17] Change onnx importer to use dynamic upsampling3d (#3) fix pylint --- python/tvm/relay/frontend/onnx.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 8639c93a9b91..297bc538023a 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -867,8 +867,6 @@ def _impl_v9(cls, inputs, attr, params): if get_name(inputs[1]) in params: scales = params[inputs[1].name_hint].asnumpy() - elif dims == 5: - scales = infer_value_simulated(inputs[1], params).asnumpy() else: scales = inputs[1] @@ -890,9 +888,16 @@ def _impl_v9(cls, inputs, attr, params): align_corners = True # in 3d case, we use the purely static op if dims == 5: - scale_h = scales[-2] - scale_w = scales[-1] - scale_d = scales[-3] + if isinstance(scales, _expr.Call): + scale_h = _op.take(scales, _op.const(3)) + scale_w = _op.take(scales, _op.const(4)) + scale_d = _op.take(scales, _op.const(1)) + else: + assert len(scales) == 5 + scale_h = scales[-2] + scale_w = scales[-1] + scale_d = scales[-3] + layout = 'NCDHW' out = _op.nn.upsampling3d(inputs[0], scale_d, From 077f74c927480b8c24bffab4c90905fad177cf35 Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Fri, 11 Sep 2020 11:39:33 -0600 Subject: [PATCH 04/17] black format --- python/tvm/relay/frontend/onnx.py | 1427 ++++---- python/tvm/relay/op/_tensor.py | 36 +- python/tvm/relay/op/nn/_nn.py | 228 +- python/tvm/relay/op/strategy/generic.py | 441 ++- python/tvm/relay/op/strategy/x86.py | 188 +- python/tvm/topi/cuda/batch_matmul.py | 30 +- python/tvm/topi/nn/batch_matmul.py | 11 +- python/tvm/topi/x86/batch_matmul.py | 17 +- tests/python/frontend/onnx/test_forward.py | 3230 +++++++++-------- .../relay/dyn/test_dynamic_op_level10.py | 29 +- tests/python/relay/test_op_level10.py | 89 +- 11 files changed, 3162 insertions(+), 2564 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 297bc538023a..9603fea670b5 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -32,11 +32,12 @@ from .common import get_relay_op, new_var, infer_shape, infer_channels from .common import infer_type, get_name, 
infer_value_simulated -__all__ = ['from_onnx'] +__all__ = ["from_onnx"] -class onnx_input(): +class onnx_input: """ Dual purpose list or dictionary access object.""" + def __init__(self): self.input_keys = [] self.input_dict = {} @@ -93,18 +94,19 @@ def get_numpy(tensor_proto): return to_array(tensor_proto) -def dimension_picker(prefix, suffix=''): +def dimension_picker(prefix, suffix=""): """Check that dimensions are supported.""" + def _impl(attr): - kernel = attr['kernel_shape'] + kernel = attr["kernel_shape"] if len(kernel) == 1: - return prefix + '1d' + suffix + return prefix + "1d" + suffix if len(kernel) == 2: - return prefix + '2d' + suffix + return prefix + "2d" + suffix if len(kernel) == 3: - return prefix + '3d' + suffix - msg = 'Only 1D, 2D, and 3D kernels are supported for operator {}.' - op_name = prefix + '1d/2d/3d' + return prefix + "3d" + suffix + msg = "Only 1D, 2D, and 3D kernels are supported for operator {}." + op_name = prefix + "1d/2d/3d" raise tvm.error.OpAttributeInvalid(msg.format(op_name)) return _impl @@ -117,7 +119,7 @@ def revert_caffe2_pad(pads): elif len(pads) == 2: pass else: - raise tvm.error.OpAttributeInvalid('Number of pads must be either 2 or 4.') + raise tvm.error.OpAttributeInvalid("Number of pads must be either 2 or 4.") return pads @@ -134,11 +136,11 @@ def get_pad_pair(input1d, kernel1d, stride1d): def onnx_default_layout(dims): if dims == 1: - return 'NCW' + return "NCW" if dims == 2: - return 'NCHW' + return "NCHW" if dims == 3: - return 'NCDHW' + return "NCDHW" msg = "Only 1D, 2D and 3D layouts are currently supported" raise tvm.error.OpAttributeInvalid(msg.format(op_name)) @@ -147,14 +149,14 @@ def onnx_default_layout(dims): def onnx_storage_order2layout(storage_order, dims=2): """converter of onnx storage order parameter to tvm storage order format""" if storage_order not in (0, 1): - raise tvm.error.OpAttributeInvalid('Mode of storage_order must be either 0 or 1') + raise tvm.error.OpAttributeInvalid("Mode of storage_order must be either 0 or 1") if dims == 1: - return 'NCW' if storage_order == 0 else 'NWC' + return "NCW" if storage_order == 0 else "NWC" if dims == 2: - return 'NCHW' if storage_order == 0 else 'NHWC' + return "NCHW" if storage_order == 0 else "NHWC" if dims == 3: - return 'NCDHW' if storage_order == 0 else 'NDHWC' + return "NCDHW" if storage_order == 0 else "NDHWC" msg = "Only 1D, 2D and 3D layouts are currently supported" raise tvm.error.OpAttributeInvalid(msg.format(op_name)) @@ -162,7 +164,7 @@ def onnx_storage_order2layout(storage_order, dims=2): def dimension_constraint(): def _dim_check(attrs): - if len(attrs['kernel_shape']) in [1, 2, 3]: + if len(attrs["kernel_shape"]) in [1, 2, 3]: return True return False @@ -170,11 +172,11 @@ def _dim_check(attrs): class OnnxOpConverter(object): - """ A helper class for holding onnx op converters. - """ + """A helper class for holding onnx op converters.""" + @classmethod def get_converter(cls, opset): - """ Get converter matches given opset. + """Get converter matches given opset. Parameters ---------- @@ -186,176 +188,180 @@ def get_converter(cls, opset): converter, which should be `_impl_vx`. Number x is the biggest number smaller than or equal to opset belongs to all support versions. 
""" - versions = [int(d.replace('_impl_v', '')) for d in dir(cls) if '_impl_v' in d] + versions = [int(d.replace("_impl_v", "")) for d in dir(cls) if "_impl_v" in d] versions = sorted(versions + [opset]) version = versions[max([i for i, v in enumerate(versions) if v == opset]) - 1] - if hasattr(cls, '_impl_v{}'.format(version)): - return getattr(cls, '_impl_v{}'.format(version)) - raise NotImplementedError('opset version {} of {} not implemented'.format( - version, cls.__name__)) + if hasattr(cls, "_impl_v{}".format(version)): + return getattr(cls, "_impl_v{}".format(version)) + raise NotImplementedError( + "opset version {} of {} not implemented".format(version, cls.__name__) + ) class Unary(OnnxOpConverter): - """ A helper class for unary op converters. - """ - name = '' + """A helper class for unary op converters.""" + + name = "" @classmethod def _impl_v1(cls, inputs, attr, params): assert len(inputs) == 1, "Unary math op {} takes 1 input, {} given".format( - cls.name, len(inputs)) + cls.name, len(inputs) + ) op_name = cls.name return get_relay_op(op_name)(*inputs) class Elemwise(OnnxOpConverter): - """ A helper class for elemwise op converters. - """ - name = '' + """A helper class for elemwise op converters.""" + + name = "" @classmethod def _impl_v1(cls, inputs, attr, params): assert len(inputs) == 2, "Math op {} take 2 inputs, {} given".format(cls.name, len(inputs)) op_name = cls.name conv_ops = ["conv2d", "conv2d_transpose"] - if attr.get('broadcast', 0) and any(x in str(inputs[0]) for x in conv_ops): + if attr.get("broadcast", 0) and any(x in str(inputs[0]) for x in conv_ops): # TODO(zhreshold): remove hard coded infershape - axis = int(attr.get('axis', 0)) + axis = int(attr.get("axis", 0)) inputs[1] = _op.expand_dims(inputs[1], axis=axis, num_newaxis=2) return get_relay_op(op_name)(*inputs) class Pool(OnnxOpConverter): - """ A helper class for pool op converters. - """ - name = '' + """A helper class for pool op converters.""" + + name = "" @classmethod def _impl_v1(cls, inputs, attr, params): input_shape = infer_shape(inputs[0]) - if 'auto_pad' in attr: - attr['auto_pad'] = attr['auto_pad'].decode('utf-8') - if attr['auto_pad'] in ('SAME_UPPER', 'SAME_LOWER'): + if "auto_pad" in attr: + attr["auto_pad"] = attr["auto_pad"].decode("utf-8") + if attr["auto_pad"] in ("SAME_UPPER", "SAME_LOWER"): pad_tuple = [] for axis in range(len(input_shape) - 2): axis_shape = input_shape[2 + axis] - stride = attr['strides'][axis] - kernel = attr['kernel_shape'][axis] + stride = attr["strides"][axis] + kernel = attr["kernel_shape"][axis] pad = get_pad_pair(axis_shape, kernel, stride) pad_tuple.append(pad) pad_tuple = tuple([val for pair in zip(*pad_tuple) for val in pair]) - attr['pads'] = pad_tuple - elif attr['auto_pad'] == 'VALID': - attr['pads'] = 0 - elif attr['auto_pad'] == 'NOTSET': + attr["pads"] = pad_tuple + elif attr["auto_pad"] == "VALID": + attr["pads"] = 0 + elif attr["auto_pad"] == "NOTSET": pass else: msg = 'Value {} in attribute "auto_pad" of operator {} is invalid.' 
- raise tvm.error.OpAttributeInvalid(msg.format(attr['auto_pad'], cls.name)) + raise tvm.error.OpAttributeInvalid(msg.format(attr["auto_pad"], cls.name)) attr.pop("auto_pad") - if 'storage_order' in attr: - attr['layout'] = onnx_storage_order2layout(attr['storage_order'], - dims=(len(input_shape) - 2)) + if "storage_order" in attr: + attr["layout"] = onnx_storage_order2layout( + attr["storage_order"], dims=(len(input_shape) - 2) + ) else: - attr['layout'] = onnx_default_layout(dims=(len(input_shape) - 2)) + attr["layout"] = onnx_default_layout(dims=(len(input_shape) - 2)) - return AttrCvt(op_name=dimension_picker(cls.name), - transforms={ - 'kernel_shape': 'pool_size', - 'pads': ('padding', 0) - }, - ignores=['dilations', 'storage_order'], - custom_check=dimension_constraint())(inputs, attr, params) + return AttrCvt( + op_name=dimension_picker(cls.name), + transforms={"kernel_shape": "pool_size", "pads": ("padding", 0)}, + ignores=["dilations", "storage_order"], + custom_check=dimension_constraint(), + )(inputs, attr, params) class Absolute(Unary): - """ Operator converter for Absolute. - """ - name = 'abs' + """Operator converter for Absolute.""" + + name = "abs" class Add(Elemwise): - """ Operator converter for Add. - """ - name = 'add' + """Operator converter for Add.""" + + name = "add" class AveragePool(Pool): - """ Operator converter for AveragePool. - """ - name = 'avg_pool' + """Operator converter for AveragePool.""" + + name = "avg_pool" class BatchNorm(OnnxOpConverter): - """ Operator converter for BatchNorm. - """ + """Operator converter for BatchNorm.""" + @classmethod def _impl_v1(cls, inputs, attr, params): # TODO(zhreshold): 'spatial' is not properly handled here. - out = AttrCvt(op_name='batch_norm', - ignores=['spatial', 'is_test', 'consumed_inputs', 'momentum'])(inputs, attr, - params) + out = AttrCvt( + op_name="batch_norm", ignores=["spatial", "is_test", "consumed_inputs", "momentum"] + )(inputs, attr, params) return out[0] class InstanceNorm(OnnxOpConverter): - """ Operator converter for BatchNorm. - """ + """Operator converter for BatchNorm.""" + @classmethod def _impl_v1(cls, inputs, attr, params): - return AttrCvt(op_name='instance_norm')(inputs, attr, params) + return AttrCvt(op_name="instance_norm")(inputs, attr, params) class Conv(OnnxOpConverter): - """ Operator converter for Conv. - """ + """Operator converter for Conv.""" + @classmethod def _impl_v1(cls, inputs, attr, params): # Use shape of input to determine convolution type. 
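[Editor's note] The comment below refers to rank-based dispatch: dimension_picker("conv"), shown earlier in this patch, selects the relay op from the kernel rank. A compressed sketch of that mapping with a hypothetical attribute dict (not the helper itself):

attr = {"kernel_shape": [3, 3]}   # hypothetical 2-D kernel
op_name = "conv" + {1: "1d", 2: "2d", 3: "3d"}[len(attr["kernel_shape"])]
print(op_name)   # conv2d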
input_shape = infer_shape(inputs[0]) - if 'auto_pad' in attr: - attr['auto_pad'] = attr['auto_pad'].decode('utf-8') - if attr['auto_pad'] in ('SAME_UPPER', 'SAME_LOWER'): + if "auto_pad" in attr: + attr["auto_pad"] = attr["auto_pad"].decode("utf-8") + if attr["auto_pad"] in ("SAME_UPPER", "SAME_LOWER"): pad_tuple = [] for axis in range(len(input_shape) - 2): axis_shape = input_shape[2 + axis] - stride = attr['strides'][axis] - kernel = attr['kernel_shape'][axis] - dilation = attr['dilations'][axis] + stride = attr["strides"][axis] + kernel = attr["kernel_shape"][axis] + dilation = attr["dilations"][axis] dilated_kernel = (kernel - 1) * dilation + 1 pad = get_pad_pair(axis_shape, dilated_kernel, stride) pad_tuple.append(pad) pad_tuple = tuple([val for pair in zip(*pad_tuple) for val in pair]) - attr['pads'] = pad_tuple - elif attr['auto_pad'] == 'VALID': - attr['pads'] = tuple([0 for i in range(len(input_shape) - 2)]) - elif attr['auto_pad'] == 'NOTSET': + attr["pads"] = pad_tuple + elif attr["auto_pad"] == "VALID": + attr["pads"] = tuple([0 for i in range(len(input_shape) - 2)]) + elif attr["auto_pad"] == "NOTSET": pass else: msg = 'Value {} in attribute "auto_pad" of operator Conv is invalid.' - raise tvm.error.OpAttributeInvalid(msg.format(attr['auto_pad'])) - attr.pop('auto_pad') - elif len(attr['kernel_shape']) == 2: + raise tvm.error.OpAttributeInvalid(msg.format(attr["auto_pad"])) + attr.pop("auto_pad") + elif len(attr["kernel_shape"]) == 2: sym_pad = True - if 'pads' in attr: - padding = attr['pads'] + if "pads" in attr: + padding = attr["pads"] else: padding = [0, 0, 0, 0] for i in range(0, len(padding), 2): sym_pad = sym_pad and padding[i] == padding[i + 1] if sym_pad: - attr['pads'] = padding[0::2] + attr["pads"] = padding[0::2] - out = AttrCvt(op_name=dimension_picker('conv'), - transforms={ - 'kernel_shape': 'kernel_size', - 'dilations': ('dilation', 1), - 'pads': ('padding', 0), - 'group': ('groups', 1) - }, - custom_check=dimension_constraint())(inputs[:2], attr, params) + out = AttrCvt( + op_name=dimension_picker("conv"), + transforms={ + "kernel_shape": "kernel_size", + "dilations": ("dilation", 1), + "pads": ("padding", 0), + "group": ("groups", 1), + }, + custom_check=dimension_constraint(), + )(inputs[:2], attr, params) use_bias = len(inputs) == 3 if use_bias: @@ -364,46 +370,48 @@ def _impl_v1(cls, inputs, attr, params): class ConvTranspose(OnnxOpConverter): - """ Operator converter for ConvTranspose. 
- """ + """Operator converter for ConvTranspose.""" + @classmethod def _impl_v1(cls, inputs, attr, params): # get number of channels channels = infer_channels(inputs[1], True) - attr['channels'] = channels - groups = attr.pop('group') - attr['groups'] = groups + attr["channels"] = channels + groups = attr.pop("group") + attr["groups"] = groups # infer pads for auto_pad - if 'auto_pad' in attr: - attr['auto_pad'] = attr['auto_pad'].decode('utf-8') - if attr['auto_pad'] in ('SAME_UPPER', 'SAME_LOWER'): + if "auto_pad" in attr: + attr["auto_pad"] = attr["auto_pad"].decode("utf-8") + if attr["auto_pad"] in ("SAME_UPPER", "SAME_LOWER"): input_shape = infer_shape(inputs[0]) in_h, in_w = input_shape[2], input_shape[3] - stride_h, stride_w = attr['strides'] - kernel_h, kernel_w = attr['kernel_shape'] - dilation_h, dilation_w = attr['dilations'] + stride_h, stride_w = attr["strides"] + kernel_h, kernel_w = attr["kernel_shape"] + dilation_h, dilation_w = attr["dilations"] dilated_kernel_h = (kernel_h - 1) * dilation_h + 1 dilated_kernel_w = (kernel_w - 1) * dilation_w + 1 pad_v = get_pad_pair(in_h, dilated_kernel_h, stride_h) pad_h = get_pad_pair(in_w, dilated_kernel_w, stride_w) - attr['pads'] = (pad_v[0], pad_h[0], pad_v[1], pad_h[1]) - elif attr['auto_pad'] == 'VALID': - attr['pads'] = (0, 0) - elif attr['auto_pad'] == 'NOTSET': + attr["pads"] = (pad_v[0], pad_h[0], pad_v[1], pad_h[1]) + elif attr["auto_pad"] == "VALID": + attr["pads"] = (0, 0) + elif attr["auto_pad"] == "NOTSET": pass else: msg = 'Value {} in attribute "auto_pad" of operator Conv is invalid.' - raise tvm.error.OpAttributeInvalid(msg.format(attr['auto_pad'])) - attr.pop('auto_pad') - - out = AttrCvt(op_name=dimension_picker('conv', '_transpose'), - transforms={ - 'kernel_shape': 'kernel_size', - 'dilations': ('dilation', (0, 0)), - 'pads': ('padding', (0, 0), revert_caffe2_pad) - }, - disables=['output_shape'], - custom_check=dimension_constraint())(inputs[:2], attr, params) + raise tvm.error.OpAttributeInvalid(msg.format(attr["auto_pad"])) + attr.pop("auto_pad") + + out = AttrCvt( + op_name=dimension_picker("conv", "_transpose"), + transforms={ + "kernel_shape": "kernel_size", + "dilations": ("dilation", (0, 0)), + "pads": ("padding", (0, 0), revert_caffe2_pad), + }, + disables=["output_shape"], + custom_check=dimension_constraint(), + )(inputs[:2], attr, params) use_bias = len(inputs) == 3 if use_bias: out = _op.nn.bias_add(out, inputs[2]) @@ -411,32 +419,33 @@ def _impl_v1(cls, inputs, attr, params): class Div(Elemwise): - """ Operator converter for Divide. - """ - name = 'divide' + """Operator converter for Divide.""" + + name = "divide" class Elu(OnnxOpConverter): - """ Operator converter for Elu. - """ + """Operator converter for Elu.""" + @classmethod def _impl_v1(cls, inputs, attr, params): - alpha = float(attr.get('alpha', 1.0)) - return _expr.const(-alpha) * _op.nn.relu(_expr.const(1.) - _op.exp(inputs[0])) + \ - _op.nn.relu(inputs[0]) + alpha = float(attr.get("alpha", 1.0)) + return _expr.const(-alpha) * _op.nn.relu( + _expr.const(1.0) - _op.exp(inputs[0]) + ) + _op.nn.relu(inputs[0]) class Gemm(OnnxOpConverter): - """ Operator converter for Gemm. 
- """ + """Operator converter for Gemm.""" + @classmethod def _impl_v1(cls, inputs, attr, params): assert len(inputs) == 3, "Gemm op take 3 inputs, {} given".format(len(inputs)) # Y = alpha * A * B + beta * C - alpha = float(attr.get('alpha', 1.0)) - beta = float(attr.get('beta', 1.0)) - transA = int(attr.get('transA', 0)) - transB = int(attr.get('transB', 0)) + alpha = float(attr.get("alpha", 1.0)) + beta = float(attr.get("beta", 1.0)) + transA = int(attr.get("transA", 0)) + transB = int(attr.get("transB", 0)) # get number of channels channels = infer_channels(inputs[1], not transB) if transA: @@ -457,8 +466,8 @@ def _impl_v1(cls, inputs, attr, params): class MatMul(OnnxOpConverter): - """ Operator converter for MatMul. - """ + """Operator converter for MatMul.""" + @classmethod def _impl_v1(cls, inputs, attr, params): assert len(inputs) == 2, "MatMul op take 2 inputs, {} given".format(len(inputs)) @@ -471,8 +480,8 @@ def _impl_v1(cls, inputs, attr, params): def flatten_to_3d(x, x_shape): ndims = infer_shape(x_shape)[0] newshape = _op.concatenate( - [_expr.const([-1]), - _op.strided_slice(x_shape, [ndims - 2], [ndims])], 0) + [_expr.const([-1]), _op.strided_slice(x_shape, [ndims - 2], [ndims])], 0 + ) out = _op.reshape(x, newshape) return out @@ -480,21 +489,28 @@ def flatten_to_3d(x, x_shape): a = flatten_to_3d(inputs[0], a_shape) b = flatten_to_3d(inputs[1], b_shape) # Broadcast b to match batch size of a - new_b_shape = _op.concatenate([ - _op.strided_slice(_op.shape_of(a), [0], [1]), - _op.strided_slice(_op.shape_of(b), [1], [3]) - ], 0) + new_b_shape = _op.concatenate( + [ + _op.strided_slice(_op.shape_of(a), [0], [1]), + _op.strided_slice(_op.shape_of(b), [1], [3]), + ], + 0, + ) b = _op.broadcast_to(b, new_b_shape) # Transpose matrix dimensions of b. b = _op.transpose(b, [0, 2, 1]) # Perform a batch matmul. output = _op.nn.batch_matmul(a, b) # Reshape output to original dimensions. - final_shape = _op.concatenate([ - _op.strided_slice(a_shape, [0], [infer_shape(a_shape)[0] - 1]), - _op.strided_slice(b_shape, [infer_shape(b_shape)[0] - 1], - [infer_shape(b_shape)[0]]) - ], 0) + final_shape = _op.concatenate( + [ + _op.strided_slice(a_shape, [0], [infer_shape(a_shape)[0] - 1]), + _op.strided_slice( + b_shape, [infer_shape(b_shape)[0] - 1], [infer_shape(b_shape)[0]] + ), + ], + 0, + ) return _op.reshape(output, final_shape) # Otherwise a simple dense op will get the job done. input_1_t = _op.transpose(inputs[1], axes=(1, 0)) @@ -502,8 +518,8 @@ def flatten_to_3d(x, x_shape): class Mod(OnnxOpConverter): - """ Operator converter for Mod. - """ + """Operator converter for Mod.""" + @classmethod def _impl_v1(cls, inputs, attr, params): assert len(inputs) == 2, "Mod op take 2 inputs, {} given".format(len(inputs)) @@ -511,7 +527,7 @@ def _impl_v1(cls, inputs, attr, params): # Note: attr['fmod'] determines whether the operator should behave like np.fmod or np.mod. # attr['fmod'] == 0 will behave as np.mod and attr['fmod'] == 1 will force fmod treatment. # The relay equivalent of np.fmod is relay.mod and np.mod is relay.floor_mod - if attr['fmod'] == 0: + if attr["fmod"] == 0: op_name = "floor_mod" else: op_name = "mod" @@ -520,115 +536,117 @@ def _impl_v1(cls, inputs, attr, params): class MaxPool(Pool): - """ Operator converter for MaxPool - """ - name = 'max_pool' + """Operator converter for MaxPool""" + + name = "max_pool" class LpPool(OnnxOpConverter): - """ A helper class for lppool op converters. 
- """ + """A helper class for lppool op converters.""" + @classmethod def _impl_v1(cls, inputs, attr, params): input_shape = infer_shape(inputs[0]) dtype = infer_type(inputs[0]).checked_type.dtype - if 'auto_pad' in attr: - attr['auto_pad'] = attr['auto_pad'].decode('utf-8') - if attr['auto_pad'] in ('SAME_UPPER', 'SAME_LOWER'): + if "auto_pad" in attr: + attr["auto_pad"] = attr["auto_pad"].decode("utf-8") + if attr["auto_pad"] in ("SAME_UPPER", "SAME_LOWER"): pad_tuple = [] for axis in range(len(input_shape) - 2): axis_shape = input_shape[2 + axis] - stride = attr['strides'][axis] - kernel = attr['kernel_shape'][axis] + stride = attr["strides"][axis] + kernel = attr["kernel_shape"][axis] pad = get_pad_pair(axis_shape, kernel, stride) pad_tuple.append(pad) pad_tuple = tuple([val for pair in zip(*pad_tuple) for val in pair]) - attr['pads'] = pad_tuple - elif attr['auto_pad'] == 'VALID': - attr['pads'] = 0 - elif attr['auto_pad'] == 'NOTSET': + attr["pads"] = pad_tuple + elif attr["auto_pad"] == "VALID": + attr["pads"] = 0 + elif attr["auto_pad"] == "NOTSET": pass else: msg = 'Value {} in attribute "auto_pad" of operator {} is invalid.' - raise tvm.error.OpAttributeInvalid(msg.format(attr['auto_pad'], "LpPool")) + raise tvm.error.OpAttributeInvalid(msg.format(attr["auto_pad"], "LpPool")) attr.pop("auto_pad") - if 'storage_order' in attr: - attr['layout'] = onnx_storage_order2layout(attr['storage_order'], - dims=(len(input_shape) - 2)) + if "storage_order" in attr: + attr["layout"] = onnx_storage_order2layout( + attr["storage_order"], dims=(len(input_shape) - 2) + ) else: - attr['layout'] = onnx_default_layout(dims=(len(input_shape) - 2)) + attr["layout"] = onnx_default_layout(dims=(len(input_shape) - 2)) - p = _expr.const(attr['p'], dtype) - reci_p = _expr.const(1.0 / attr['p'], dtype) + p = _expr.const(attr["p"], dtype) + reci_p = _expr.const(1.0 / attr["p"], dtype) inputs[0] = _op.power(inputs[0], p) - out = AttrCvt(op_name=dimension_picker("avg_pool"), - transforms={ - 'kernel_shape': 'pool_size', - 'pads': ('padding', 0) - }, - extras={'count_include_pad': True}, - ignores=['p'], - custom_check=dimension_constraint())(inputs, attr, params) - kernels = attr['kernel_shape'] + out = AttrCvt( + op_name=dimension_picker("avg_pool"), + transforms={"kernel_shape": "pool_size", "pads": ("padding", 0)}, + extras={"count_include_pad": True}, + ignores=["p"], + custom_check=dimension_constraint(), + )(inputs, attr, params) + kernels = attr["kernel_shape"] out = _op.abs(out) * _expr.const(np.prod(kernels).astype(dtype)) return _op.power(out, reci_p) class Mul(Elemwise): - """ Operator converter for Multiply. - """ - name = 'multiply' + """Operator converter for Multiply.""" + + name = "multiply" class Pad(OnnxOpConverter): - """ Operator converter for Pad. 
- """ + """Operator converter for Pad.""" + @classmethod def _impl_v1(cls, inputs, attr, params): pad_width = [] - pads = attr.pop('paddings') + pads = attr.pop("paddings") dims = int(len(pads) / 2) for i in range(dims): pad_width.append((pads[i], pads[i + dims])) - attr['pad_width'] = pad_width - pad_mode = attr.get('mode', b'constant').decode('utf-8') - if pad_mode in ['constant', 'edge', 'reflect']: - attr['pad_mode'] = pad_mode - attr.pop('mode', None) + attr["pad_width"] = pad_width + pad_mode = attr.get("mode", b"constant").decode("utf-8") + if pad_mode in ["constant", "edge", "reflect"]: + attr["pad_mode"] = pad_mode + attr.pop("mode", None) else: - raise tvm.error.OpAttributeInvalid('Value ' + pad_mode + - ' in attribute "mode" is invalid for operator Pad.') + raise tvm.error.OpAttributeInvalid( + "Value " + pad_mode + ' in attribute "mode" is invalid for operator Pad.' + ) return AttrCvt( _op.nn.pad, transforms={ - 'value': 'pad_value', + "value": "pad_value", }, )(inputs, attr, params) @classmethod def _impl_v2(cls, inputs, attr, params): pad_width = [] - pads = attr.pop('pads') + pads = attr.pop("pads") dims = int(len(pads) / 2) for i in range(dims): pad_width.append((pads[i], pads[i + dims])) - attr['pad_width'] = pad_width - pad_mode = attr.get('mode', b'constant').decode('utf-8') - if pad_mode in ['constant', 'edge', 'reflect']: - attr['pad_mode'] = pad_mode - attr.pop('mode', None) + attr["pad_width"] = pad_width + pad_mode = attr.get("mode", b"constant").decode("utf-8") + if pad_mode in ["constant", "edge", "reflect"]: + attr["pad_mode"] = pad_mode + attr.pop("mode", None) else: - raise tvm.error.OpAttributeInvalid('Value ' + pad_mode + - ' in attribute "mode" is invalid for operator Pad.') + raise tvm.error.OpAttributeInvalid( + "Value " + pad_mode + ' in attribute "mode" is invalid for operator Pad.' + ) return AttrCvt( - 'pad', + "pad", transforms={ - 'value': 'pad_value', + "value": "pad_value", }, )(inputs, attr, params) @@ -643,53 +661,54 @@ def _impl_v11(cls, inputs, attr, params): pads_shape = infer_shape(pads) dims = int(pads_shape[0] / 2) pad_width_expr = _op.transpose(_op.reshape(pads, (2, dims))) - pad_mode = attr.get('mode', b'constant').decode('utf-8') + pad_mode = attr.get("mode", b"constant").decode("utf-8") - if not pad_mode in ['constant', 'edge', 'reflect']: - raise tvm.error.OpAttributeInvalid('Value ' + pad_mode + - ' in attribute "mode" is invalid for operator Pad.') + if not pad_mode in ["constant", "edge", "reflect"]: + raise tvm.error.OpAttributeInvalid( + "Value " + pad_mode + ' in attribute "mode" is invalid for operator Pad.' + ) return _op.nn.pad(inputs[0], pad_width_expr, value, pad_mode=pad_mode) class ParametricSoftPlus(OnnxOpConverter): - """ Operator converter for ParametricSoftPlus. - """ + """Operator converter for ParametricSoftPlus.""" + @classmethod def _impl_v1(cls, inputs, attr, params): - alpha = _expr.const(float(attr.get('alpha', 1.0))) - beta = _expr.const(float(attr.get('beta', 1.0))) - return _op.log(_op.exp(beta * inputs[0]) + _expr.const(1.)) * alpha + alpha = _expr.const(float(attr.get("alpha", 1.0))) + beta = _expr.const(float(attr.get("beta", 1.0))) + return _op.log(_op.exp(beta * inputs[0]) + _expr.const(1.0)) * alpha class Prelu(OnnxOpConverter): - """ Operator converter for Prelu. 
- """ + """Operator converter for Prelu.""" + @classmethod def _impl_v1(cls, inputs, attr, params): assert len(inputs) == 2, "Prelu need 2 inputs, {} given".format(len(inputs)) alpha_shape = infer_shape(inputs[1]) if len(alpha_shape) != 1: - alpha = _op.reshape(inputs[1], (-1, )) + alpha = _op.reshape(inputs[1], (-1,)) else: alpha = inputs[1] return _op.nn.prelu(inputs[0], alpha) class Reciprocal(OnnxOpConverter): - """ Operator converter for Reciprocal. - """ + """Operator converter for Reciprocal.""" + @classmethod def _impl_v1(cls, inputs, attr, params): return _expr.const(1.0) / inputs[0] class Flatten(OnnxOpConverter): - """ Operator converter for Flatten. - """ + """Operator converter for Flatten.""" + @classmethod def _impl_v1(cls, inputs, attr, params): - axis = attr.get('axis', 1) + axis = attr.get("axis", 1) if axis == 1: out = _op.nn.batch_flatten(inputs[0]) else: @@ -700,11 +719,11 @@ def _impl_v1(cls, inputs, attr, params): class Reshape(OnnxOpConverter): - """ Operator converter for Reshape. - """ + """Operator converter for Reshape.""" + @classmethod def _impl_v1(cls, inputs, attr, params): - return _op.reshape(inputs[0], attr['shape']) + return _op.reshape(inputs[0], attr["shape"]) @classmethod def _impl_v5(cls, inputs, attr, params): @@ -718,90 +737,91 @@ def _impl_v5(cls, inputs, attr, params): class DepthToSpace(OnnxOpConverter): - """ Operator converter for DepthToSpace. - """ + """Operator converter for DepthToSpace.""" + @classmethod def _impl_v11(cls, inputs, attr, params): - block_size = int(attr['blocksize']) - mode = attr.get('mode', b'DCR').decode('utf-8') + block_size = int(attr["blocksize"]) + mode = attr.get("mode", b"DCR").decode("utf-8") return _op.nn.depth_to_space(inputs[0], block_size, mode=mode) class SpaceToDepth(OnnxOpConverter): - """ Operator converter for SpaceToDepth. - """ + """Operator converter for SpaceToDepth.""" + @classmethod def _impl_v1(cls, inputs, attr, params): - block_size = int(attr['blocksize']) + block_size = int(attr["blocksize"]) return _op.nn.space_to_depth(inputs[0], block_size) class Concat(OnnxOpConverter): - """ Operator converter for Concat. - """ + """Operator converter for Concat.""" + @classmethod def _impl_v1(cls, inputs, args, params): - return AttrCvt(op_name='concatenate')((inputs, ), args) + return AttrCvt(op_name="concatenate")((inputs,), args) class Scale(OnnxOpConverter): - """ Operator converter for Scale. - """ + """Operator converter for Scale.""" + @classmethod def _impl_v1(cls, inputs, attr, params): - scale = float(attr.get('scale', 1.0)) + scale = float(attr.get("scale", 1.0)) return inputs[0] * _expr.const(scale) class Selu(OnnxOpConverter): - """ Operator converter for Selu. - """ + """Operator converter for Selu.""" + @classmethod def _impl_v1(cls, inputs, attr, params): - alpha = float(attr.get('alpha', 1.6732)) - gamma = float(attr.get('gamma', 1.0507)) + alpha = float(attr.get("alpha", 1.6732)) + gamma = float(attr.get("gamma", 1.0507)) return _expr.const(gamma) * ( - _expr.const(-alpha) * _op.nn.relu(_expr.const(1.) - _op.exp(inputs[0])) + - _op.nn.relu(inputs[0])) + _expr.const(-alpha) * _op.nn.relu(_expr.const(1.0) - _op.exp(inputs[0])) + + _op.nn.relu(inputs[0]) + ) class ScaledTanh(OnnxOpConverter): - """ Operator converter for ScaledTanh. 
- """ + """Operator converter for ScaledTanh.""" + @classmethod def _impl_v1(cls, inputs, attr, params): - alpha = float(attr.get('alpha', 1.0)) - beta = float(attr.get('beta', 1.0)) + alpha = float(attr.get("alpha", 1.0)) + beta = float(attr.get("beta", 1.0)) return _op.tanh(_expr.const(beta) * inputs[0]) * _expr.const(alpha) class SoftPlus(OnnxOpConverter): - """ Operator converter for SoftPlus. - """ + """Operator converter for SoftPlus.""" + @classmethod def _impl_v1(cls, inputs, attr, params): - return _op.log(_op.exp(inputs[0]) + _expr.const(1.)) + return _op.log(_op.exp(inputs[0]) + _expr.const(1.0)) class Softsign(OnnxOpConverter): - """ Operator converter for Softsign. - """ + """Operator converter for Softsign.""" + @classmethod def _impl_v1(cls, inputs, attr, params): - return inputs[0] / (_expr.const(1.) + Absolute.get_converter(1)(inputs, attr, params)) + return inputs[0] / (_expr.const(1.0) + Absolute.get_converter(1)(inputs, attr, params)) class Sub(Elemwise): - """ Operator converter for Subtract. - """ - name = 'subtract' + """Operator converter for Subtract.""" + + name = "subtract" class Sum(OnnxOpConverter): - """ Operator converter for Sum. - """ + """Operator converter for Sum.""" + @classmethod def _impl_v1(cls, inputs, attr, params): # Onnx Sum Operator @@ -812,21 +832,21 @@ def _impl_v1(cls, inputs, attr, params): class Affine(OnnxOpConverter): - """ Operator converter for Affine transformation. - """ + """Operator converter for Affine transformation.""" + @classmethod def _impl_v1(cls, inputs, attr, params): - alpha = _expr.const(attr.get('alpha', 1.0)) - beta = _expr.const(attr.get('beta', 0.0)) + alpha = _expr.const(attr.get("alpha", 1.0)) + beta = _expr.const(attr.get("beta", 0.0)) return (alpha * inputs[0]) + beta class ThresholdedRelu(OnnxOpConverter): - """ Operator converter for ThresholdedRelu. - """ + """Operator converter for ThresholdedRelu.""" + @classmethod def _impl_v1(cls, inputs, attr, params): - alpha = float(attr.get('alpha', 1.0)) + alpha = float(attr.get("alpha", 1.0)) alpha_tensor = _op.full_like(inputs[0], fill_value=_expr.const(alpha)) mask = _op.greater(inputs[0], alpha_tensor).astype("float32") return inputs[0] * mask @@ -834,7 +854,7 @@ def _impl_v1(cls, inputs, attr, params): def _broadcast_constraint(): def _broadcast_check(attrs): - if attrs.get('axis', None): + if attrs.get("axis", None): return False return True @@ -845,24 +865,24 @@ def _fully_connected(opset): def _impl(inputs, attr, params): # get number of channels channels = infer_channels(inputs[1], params) - attr['units'] = channels - return AttrCvt('dense', ignores=['axis', 'axis_w'])(inputs, attr) + attr["units"] = channels + return AttrCvt("dense", ignores=["axis", "axis_w"])(inputs, attr) return _impl class Upsample(OnnxOpConverter): - """ Operator converter for Upsample (nearest mode). - """ + """Operator converter for Upsample (nearest mode).""" + @classmethod def _impl_v9(cls, inputs, attr, params): - scales = attr.get('scales') + scales = attr.get("scales") input_shape = infer_shape(inputs[0]) dims = len(input_shape) if not scales: - #Here we are going to higher OPSET version. + # Here we are going to higher OPSET version. 
assert len(inputs) == 2, "Upsample op takes 2 inputs, {} given".format(len(inputs)) if get_name(inputs[1]) in params: @@ -873,16 +893,17 @@ def _impl_v9(cls, inputs, attr, params): if not isinstance(scales, _expr.Call): assert scales[0] == 1.0 and scales[1] == 1.0 - mode = attr.get('mode') - if mode == b'nearest': + mode = attr.get("mode") + if mode == b"nearest": method = "nearest_neighbor" - elif mode == b'linear': + elif mode == b"linear": method = "trilinear" if dims == 5 else "bilinear" else: raise tvm.error.OpAttributeInvalid( - 'Value {} in attribute "mode" of operator Upsample is not valid.'.format(mode)) + 'Value {} in attribute "mode" of operator Upsample is not valid.'.format(mode) + ) - if method == 'nearest_neighbor': + if method == "nearest_neighbor": align_corners = False else: align_corners = True @@ -898,13 +919,10 @@ def _impl_v9(cls, inputs, attr, params): scale_w = scales[-1] scale_d = scales[-3] - layout = 'NCDHW' - out = _op.nn.upsampling3d(inputs[0], - scale_d, - scale_h, - scale_w, - layout=layout, - method=method) + layout = "NCDHW" + out = _op.nn.upsampling3d( + inputs[0], scale_d, scale_h, scale_w, layout=layout, method=method + ) # in 2d case, use dynamic op else: if isinstance(scales, _expr.Call): @@ -914,73 +932,76 @@ def _impl_v9(cls, inputs, attr, params): assert len(scales) == 4 scale_h = scales[-2] scale_w = scales[-1] - layout = 'NCHW' - - out = _op.nn.upsampling(inputs[0], - scale_h, - scale_w, - layout=layout, - method=method, - align_corners=align_corners) + layout = "NCHW" + + out = _op.nn.upsampling( + inputs[0], + scale_h, + scale_w, + layout=layout, + method=method, + align_corners=align_corners, + ) return out class Shape(OnnxOpConverter): - """ Operator converter for Shape. - """ + """Operator converter for Shape.""" + @classmethod def _impl_v1(cls, inputs, attr, params): return _op.shape_of(inputs[0], "int64") class Cast(OnnxOpConverter): - """ Operator converter for Cast. - """ + """Operator converter for Cast.""" + @classmethod def _impl_v1(cls, inputs, attr, params): - return AttrCvt(op_name='cast', transforms={'to': 'dtype'})(inputs, attr) + return AttrCvt(op_name="cast", transforms={"to": "dtype"})(inputs, attr) @classmethod def _impl_v5(cls, inputs, attr, params): try: from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE - attr['to'] = str(TENSOR_TYPE_TO_NP_TYPE[attr['to']]) + + attr["to"] = str(TENSOR_TYPE_TO_NP_TYPE[attr["to"]]) except ImportError as e: raise ImportError("Unable to import onnx.mapping which is required {}".format(e)) - return AttrCvt(op_name='cast', transforms={'to': 'dtype'})(inputs, attr) + return AttrCvt(op_name="cast", transforms={"to": "dtype"})(inputs, attr) class Unsqueeze(OnnxOpConverter): - """ Operator converter for Unsqueeze. - """ + """Operator converter for Unsqueeze.""" + @classmethod def _impl_v1(cls, inputs, attr, params): - for axes in attr['axes']: + for axes in attr["axes"]: inputs[0] = _op.expand_dims(inputs[0], axis=axes, num_newaxis=1) return inputs[0] class Split(OnnxOpConverter): - """ Operator converter for Split. - """ + """Operator converter for Split.""" + @classmethod def _impl_v1(cls, inputs, attr, params): - splits = attr.get('split', False) + splits = attr.get("split", False) if splits: - attr['indices_or_sections'] = [] + attr["indices_or_sections"] = [] index = 0 for i in splits[:-1]: index += i - attr['indices_or_sections'].append(index) + attr["indices_or_sections"].append(index) # When splits isnt specified divide evenly over axis. 
else: - attr['indices_or_sections'] = attr['tvm_custom']['num_outputs'] - return AttrCvt('split', ignores=['split'])(inputs, attr, params) + attr["indices_or_sections"] = attr["tvm_custom"]["num_outputs"] + return AttrCvt("split", ignores=["split"])(inputs, attr, params) class Slice(OnnxOpConverter): - """ Operator converter for Slice. - """ + """Operator converter for Slice.""" + @classmethod def _common(cls, starts, ends, axes): new_axes = [] @@ -1001,141 +1022,138 @@ def _common(cls, starts, ends, axes): @classmethod def _impl_v1(cls, inputs, attr, params): - if isinstance(attr['starts'], int): - attr['starts'] = (attr['starts'], ) - attr['ends'] = (attr['ends'], ) + if isinstance(attr["starts"], int): + attr["starts"] = (attr["starts"],) + attr["ends"] = (attr["ends"],) try: # Update the starts and ends according to axes if required. - if isinstance(attr['axes'], int): - attr['axes'] = (attr['axes'], ) - if (max(attr['axes']) + 1) != len(attr['axes']): - new_starts, new_ends, new_axes = cls._common(attr['starts'], attr['ends'], - attr['axes']) - attr['axes'] = new_axes - attr['starts'] = new_starts - attr['ends'] = new_ends + if isinstance(attr["axes"], int): + attr["axes"] = (attr["axes"],) + if (max(attr["axes"]) + 1) != len(attr["axes"]): + new_starts, new_ends, new_axes = cls._common( + attr["starts"], attr["ends"], attr["axes"] + ) + attr["axes"] = new_axes + attr["starts"] = new_starts + attr["ends"] = new_ends except KeyError: pass - begin = list(attr['starts']) - end = list(attr['ends']) + begin = list(attr["starts"]) + end = list(attr["ends"]) - return _op.strided_slice(inputs[0], - begin=begin, - end=end) + return _op.strided_slice(inputs[0], begin=begin, end=end) @classmethod def _impl_v10(cls, inputs, attr, params): - attrs = {'starts': inputs[1], 'ends': inputs[2]} + attrs = {"starts": inputs[1], "ends": inputs[2]} if len(inputs) >= 4: - attrs['axes'] = inputs[3] + attrs["axes"] = inputs[3] attrs = {k: (v, get_name(v)) for (k, v) in attrs.items()} attrs = { k: params[v[1]].asnumpy() - if v[1] in params else infer_value_simulated(v[0], params).asnumpy() + if v[1] in params + else infer_value_simulated(v[0], params).asnumpy() for (k, v) in attrs.items() } # Update the starts and ends according to axes if required. - if 'axes' in attrs: - if max(attrs['axes'] + 1) != len(attrs['axes']): - new_starts, new_ends, _ = cls._common(attrs['starts'], attrs['ends'], - attrs['axes']) - attrs['starts'] = new_starts - attrs['ends'] = new_ends - return _op.strided_slice(inputs[0], - begin=list(attrs['starts']), - end=list(attrs['ends'])) + if "axes" in attrs: + if max(attrs["axes"] + 1) != len(attrs["axes"]): + new_starts, new_ends, _ = cls._common(attrs["starts"], attrs["ends"], attrs["axes"]) + attrs["starts"] = new_starts + attrs["ends"] = new_ends + return _op.strided_slice(inputs[0], begin=list(attrs["starts"]), end=list(attrs["ends"])) class Gather(OnnxOpConverter): - """ Operator converter for Gather. - """ + """Operator converter for Gather.""" + @classmethod def _impl_v1(cls, inputs, attr, params): - axis = attr.get('axis', 0) - return AttrCvt('take', extras={'axis': axis})(inputs, {}) + axis = attr.get("axis", 0) + return AttrCvt("take", extras={"axis": axis})(inputs, {}) class GatherND(OnnxOpConverter): - """ Operator converter for GatherND. - """ + """Operator converter for GatherND.""" + @classmethod def _impl_v1(cls, inputs, attr, params): return _op.gather_nd(inputs[0], inputs[1]) class Scatter(OnnxOpConverter): - """ Operator converter for Scatter. 
- """ + """Operator converter for Scatter.""" + @classmethod def _impl_v1(cls, inputs, attr, params): - axis = attr.get('axis', 0) + axis = attr.get("axis", 0) return _op.scatter(inputs[0], inputs[1], inputs[2], axis) class Greater(OnnxOpConverter): - """ Operator logical greater. - """ + """Operator logical greater.""" + @classmethod def _impl_v7(cls, inputs, attr, params): return _op.greater(inputs[0], inputs[1]) class Less(OnnxOpConverter): - """ Operator logical less than. - """ + """Operator logical less than.""" + @classmethod def _impl_v7(cls, inputs, attr, params): return _op.less(inputs[0], inputs[1]) class LRN(OnnxOpConverter): - """ Operator converter for Local Response Normalization. - """ + """Operator converter for Local Response Normalization.""" + @classmethod def _impl_v1(cls, inputs, attr, params): """LRN support only NCHW format https://github.com/onnx/onnx/blob/master/docs/Operators.md#LRN """ axis = 1 - alpha = attr.get('alpha', 0.0001) - beta = attr.get('beta', 0.75) - bias = attr.get('bias', 1.0) - nsize = attr.get('size') - attr = {'size': nsize, 'axis': axis, 'alpha': alpha, 'beta': beta, 'bias': bias} - return AttrCvt('lrn')(inputs, attr) + alpha = attr.get("alpha", 0.0001) + beta = attr.get("beta", 0.75) + bias = attr.get("bias", 1.0) + nsize = attr.get("size") + attr = {"size": nsize, "axis": axis, "alpha": alpha, "beta": beta, "bias": bias} + return AttrCvt("lrn")(inputs, attr) class Maximum(OnnxOpConverter): - """ Operator converter for Maximum. - """ + """Operator converter for Maximum.""" + @classmethod def _impl_v1(cls, inputs, attr, params): if not isinstance(inputs, (list, onnx_input)) or len(inputs) < 2: raise ValueError("Expect minimum 2 inputs") _max = inputs[0] for i in range(1, len(inputs)): - _max = AttrCvt('maximum')([_max, inputs[i]], {}) + _max = AttrCvt("maximum")([_max, inputs[i]], {}) return _max class Minimum(OnnxOpConverter): - """ Operator converter for Minimum. - """ + """Operator converter for Minimum.""" + @classmethod def _impl_v1(cls, inputs, attr, params): if not isinstance(inputs, (list, onnx_input)) or len(inputs) < 2: raise ValueError("Expect minimum 2 inputs") _min = inputs[0] for i in range(1, len(inputs)): - _min = AttrCvt('minimum')([_min, inputs[i]], {}) + _min = AttrCvt("minimum")([_min, inputs[i]], {}) return _min class Mean(OnnxOpConverter): - """ Operator converter for Mean. - """ + """Operator converter for Mean.""" + @classmethod def _impl_v1(cls, inputs, attr, params): if not isinstance(inputs, (list, onnx_input)) or len(inputs) < 2: @@ -1146,112 +1164,112 @@ def _impl_v1(cls, inputs, attr, params): class HardSigmoid(OnnxOpConverter): - """ Operator converter for HardSigmoid. - """ + """Operator converter for HardSigmoid.""" + @classmethod def _impl_v1(cls, inputs, attr, params): - alpha = attr.get('alpha', 0.2) - beta = attr.get('beta', 0.5) + alpha = attr.get("alpha", 0.2) + beta = attr.get("beta", 0.5) transformX = (inputs[0] * _expr.const(alpha)) + _expr.const(beta) - attr = {'a_min': 0, 'a_max': 1} - return AttrCvt('clip')([transformX], attr) + attr = {"a_min": 0, "a_max": 1} + return AttrCvt("clip")([transformX], attr) class Reduce(OnnxOpConverter): - """ Operator converter for reduce ops. 
- """ - name = '' + """Operator converter for reduce ops.""" + + name = "" @classmethod def _impl_v1(cls, inputs, attr, params): - if 'axes' in attr: - axis = attr.get('axes', 0) + if "axes" in attr: + axis = attr.get("axes", 0) else: axis_len = len(infer_shape(inputs[0])) axis = list(range(axis_len)) - attr = {'axis': axis, 'keepdims': attr.get('keepdims', True)} + attr = {"axis": axis, "keepdims": attr.get("keepdims", True)} return AttrCvt(cls.name)(inputs, attr) class ReduceMax(Reduce): - """ Operator converter for ReduceMax. - """ - name = 'max' + """Operator converter for ReduceMax.""" + + name = "max" class ReduceMin(Reduce): - """ Operator converter for ReduceMin. - """ - name = 'min' + """Operator converter for ReduceMin.""" + + name = "min" class ReduceSum(Reduce): - """ Operator converter for ReduceSum. - """ - name = 'sum' + """Operator converter for ReduceSum.""" + + name = "sum" class ReduceMean(Reduce): - """ Operator converter for ReduceMean. - """ - name = 'mean' + """Operator converter for ReduceMean.""" + + name = "mean" class ReduceProd(Reduce): - """ Operator converter for ReduceProd. - """ - name = 'prod' + """Operator converter for ReduceProd.""" + + name = "prod" class ReduceLogSumExp(Reduce): - """ Operator converter for ReduceLogSumExp. - """ - name = 'logsumexp' + """Operator converter for ReduceLogSumExp.""" + + name = "logsumexp" class ReduceSumSquare(OnnxOpConverter): - """ Operator converter for ReduceSumSquare. - """ + """Operator converter for ReduceSumSquare.""" + @classmethod def _impl_v1(cls, inputs, attr, params): - if 'axes' in attr: - axis = attr.get('axes', 0) + if "axes" in attr: + axis = attr.get("axes", 0) else: axis_len = len(infer_shape(inputs[0])) axis = list(range(axis_len)) - attr = {'axis': axis, 'keepdims': attr.get('keepdims', True)} + attr = {"axis": axis, "keepdims": attr.get("keepdims", True)} inputs[0] = inputs[0] * inputs[0] return AttrCvt("sum")(inputs, attr) class ReduceL1(OnnxOpConverter): - """ Operator converter for ReduceL1. - """ + """Operator converter for ReduceL1.""" + @classmethod def _impl_v1(cls, inputs, attr, params): - if 'axes' in attr: - axis = attr.get('axes', 0) + if "axes" in attr: + axis = attr.get("axes", 0) else: axis_len = len(infer_shape(inputs[0])) axis = list(range(axis_len)) - attr = {'axis': axis, 'keepdims': attr.get('keepdims', True)} + attr = {"axis": axis, "keepdims": attr.get("keepdims", True)} inputs[0] = _op.abs(inputs[0]) return AttrCvt("sum")(inputs, attr) class ReduceL2(OnnxOpConverter): - """ Operator converter for ReduceL2. - """ + """Operator converter for ReduceL2.""" + @classmethod def _impl_v1(cls, inputs, attr, params): - if 'axes' in attr: - axis = attr.get('axes', 0) + if "axes" in attr: + axis = attr.get("axes", 0) else: axis_len = len(infer_shape(inputs[0])) axis = list(range(axis_len)) - attr = {'axis': axis, 'keepdims': attr.get('keepdims', True)} + attr = {"axis": axis, "keepdims": attr.get("keepdims", True)} inputs[0] = inputs[0] * inputs[0] out = AttrCvt("sum")(inputs, attr) @@ -1259,57 +1277,57 @@ def _impl_v1(cls, inputs, attr, params): class ReduceLogSum(OnnxOpConverter): - """ Operator converter for ReduceLogSum. 
- """ + """Operator converter for ReduceLogSum.""" + @classmethod def _impl_v1(cls, inputs, attr, params): - if 'axes' in attr: - axis = attr.get('axes', 0) + if "axes" in attr: + axis = attr.get("axes", 0) else: axis_len = len(infer_shape(inputs[0])) axis = list(range(axis_len)) - attr = {'axis': axis, 'keepdims': attr.get('keepdims', True)} + attr = {"axis": axis, "keepdims": attr.get("keepdims", True)} out = AttrCvt("sum")(inputs, attr) return _op.log(out) class ArgMax(OnnxOpConverter): - """ Operator converter for ArgMax. - """ + """Operator converter for ArgMax.""" + @classmethod def _impl_v1(cls, inputs, attr, params): - axis = attr.get('axis', 0) - keepdims = attr.get('keepdims', True) - attr = {'axis': axis, 'keepdims': keepdims} - return AttrCvt('argmax')(inputs, attr) + axis = attr.get("axis", 0) + keepdims = attr.get("keepdims", True) + attr = {"axis": axis, "keepdims": keepdims} + return AttrCvt("argmax")(inputs, attr) class ArgMin(OnnxOpConverter): - """ Operator converter for ArgMin. - """ + """Operator converter for ArgMin.""" + @classmethod def _impl_v1(cls, inputs, attr, params): - axis = attr.get('axis', 0) - keepdims = attr.get('keepdims', True) - attr = {'axis': axis, 'keepdims': keepdims} - return AttrCvt('argmin')(inputs, attr) + axis = attr.get("axis", 0) + keepdims = attr.get("keepdims", True) + attr = {"axis": axis, "keepdims": keepdims} + return AttrCvt("argmin")(inputs, attr) class Softmax(OnnxOpConverter): - """ Operator converter for Softmax. - """ + """Operator converter for Softmax.""" + @classmethod def _impl_v1(cls, inputs, attr, params): # set default value when axis is not set in the model - if 'axis' not in attr: - attr['axis'] = 1 - return AttrCvt('softmax', transforms={'axis': ('axis', 1)})(inputs, attr, params) + if "axis" not in attr: + attr["axis"] = 1 + return AttrCvt("softmax", transforms={"axis": ("axis", 1)})(inputs, attr, params) class OneHot(OnnxOpConverter): - """ Operator converter for OneHot. - """ + """Operator converter for OneHot.""" + @classmethod def _impl_v9(cls, inputs, attr, params): # Extract relay one_hot inputs. @@ -1319,66 +1337,67 @@ def _impl_v9(cls, inputs, attr, params): # Extract the datatype of the output from on_value. dtype = infer_type(on_value).checked_type.dtype # set default value when axis is not set in the model - if 'axis' not in attr: - attr['axis'] = -1 - return _op.one_hot(indices, on_value, off_value, depth, int(attr['axis']), dtype=dtype) + if "axis" not in attr: + attr["axis"] = -1 + return _op.one_hot(indices, on_value, off_value, depth, int(attr["axis"]), dtype=dtype) class ConstantOfShape(OnnxOpConverter): - """ Operator converter for ConstantOfShape. - """ + """Operator converter for ConstantOfShape.""" + @classmethod def _impl_v9(cls, inputs, attr, params): - if 'value' in attr: - np_value = get_numpy(attr.pop('value'))[0] + if "value" in attr: + np_value = get_numpy(attr.pop("value"))[0] value = _expr.const(np_value) dtype = np_value.dtype.name else: value = _expr.const(0) - dtype = 'float32' + dtype = "float32" output = _op.full(value, inputs[0], dtype=dtype) return output class Sign(OnnxOpConverter): - """ Operator converter for Sign. - """ + """Operator converter for Sign.""" + @classmethod def _impl_v1(cls, inputs, attr, params): return _op.sign(inputs[0]) class Equal(Elemwise): - """ Operator converter for Equal. - """ - name = 'equal' + """Operator converter for Equal.""" + + name = "equal" class Not(Elemwise): - """ Operator converter for Not. 
- """ + """Operator converter for Not.""" + @classmethod def _impl_v1(cls, inputs, attr, params): return _op.logical_not(inputs[0]) class And(Elemwise): - """ Operator converter for And. - """ + """Operator converter for And.""" + @classmethod def _impl_v1(cls, inputs, attr, params): return _op.logical_and(inputs[0], inputs[1]) class Tile(Elemwise): - """Operator converter for Tile - """ + """Operator converter for Tile""" + @classmethod def _impl_v1(cls, inputs, attr, params): - if 'repeats' not in attr: - raise tvm.error.OpAttributeInvalid('Attribute "repeats" should be set ' - 'for operator Tile.') - reps = attr.pop('repeats') # The number of times repeating the tensor data. + if "repeats" not in attr: + raise tvm.error.OpAttributeInvalid( + 'Attribute "repeats" should be set ' "for operator Tile." + ) + reps = attr.pop("repeats") # The number of times repeating the tensor data. return _op.tile(inputs[0], reps) @classmethod @@ -1387,16 +1406,16 @@ def _impl_v6(cls, inputs, attr, params): class Erf(OnnxOpConverter): - """Operator converter for Erf - """ + """Operator converter for Erf""" + @classmethod def _impl_v1(cls, inputs, attr, params): return _op.erf(inputs[0]) class Where(OnnxOpConverter): - """Operator converter for Where - """ + """Operator converter for Where""" + @classmethod def _impl_v9(cls, inputs, attr, params): condition_shape = infer_shape(inputs[0]) @@ -1430,16 +1449,16 @@ def _impl_v9(cls, inputs, attr, params): class Or(Elemwise): - """ Operator converter for Or. - """ + """Operator converter for Or.""" + @classmethod def _impl_v7(cls, inputs, attr, params): return _op.logical_or(inputs[0], inputs[1]) class Expand(OnnxOpConverter): - """ Operator converter for Expand. - """ + """Operator converter for Expand.""" + @classmethod def _impl_v8(cls, inputs, attr, params): dtype = infer_type(inputs[1]).checked_type.dtype @@ -1455,18 +1474,40 @@ def _impl_v8(cls, inputs, attr, params): # In above cases, we cannot directorly apply 'op.broadcast_to' instead of 'expand' # so, here we solved this problem by expanding the given 'shape' itself. def expand_shape(in_shape, shape): - """ A function expands the shape when the rank is lower than that of the given + """A function expands the shape when the rank is lower than that of the given intput. Also it replaces the extent of the shape with the corresponding extent of the intput when it is 1. """ in_dims = infer_shape(in_shape)[0] new_dims = infer_shape(shape)[0] if in_dims < new_dims: - in_shape = _op.concatenate([_expr.const([1, ] * (new_dims - in_dims), dtype=dtype), - in_shape], axis=0) + in_shape = _op.concatenate( + [ + _expr.const( + [ + 1, + ] + * (new_dims - in_dims), + dtype=dtype, + ), + in_shape, + ], + axis=0, + ) elif new_dims > in_dims: - shape = _op.concatenate([_expr.const([1, ] * (in_dims - new_dims), dtype=dtype), - shape], axis=0) + shape = _op.concatenate( + [ + _expr.const( + [ + 1, + ] + * (in_dims - new_dims), + dtype=dtype, + ), + shape, + ], + axis=0, + ) new_shape = _op.maximum(in_shape, shape) return new_shape @@ -1475,16 +1516,16 @@ def expand_shape(in_shape, shape): class RNN(OnnxOpConverter): - """ Operator converter for RNNs such as LSTM and GRU. 
- """ + """Operator converter for RNNs such as LSTM and GRU.""" + @classmethod def _activation_helper(cls, activation, alpha, beta): convert_map = _get_convert_map(1) attrs = {} if alpha is not None: - attrs['alpha'] = alpha + attrs["alpha"] = alpha if beta is not None: - attrs['beta'] = beta + attrs["beta"] = beta return lambda x: convert_map[activation.decode("utf-8")]([x], attrs, {}) @classmethod @@ -1510,8 +1551,8 @@ def _activation_needs_beta(cls, activation): class LSTM(RNN): - """Operator converter for LSTM - """ + """Operator converter for LSTM""" + @classmethod def _impl_v7(cls, inputs, attr, params): # Unpack inputs, note that if optional and not provided then value will be None. @@ -1520,7 +1561,7 @@ def _impl_v7(cls, inputs, attr, params): R = inputs[2] B = inputs[3] # Sequence length currently unused as it can be inferred from shapes. - #sequence_lens = inputs['sequence_lens'] + # sequence_lens = inputs['sequence_lens'] h_0 = inputs[5] c_0 = inputs[6] P = inputs[7] @@ -1558,16 +1599,16 @@ def _impl_v7(cls, inputs, attr, params): C_t = c_0 h_list = [] - if 'activations' in attr: - activations = attr['activations'] + if "activations" in attr: + activations = attr["activations"] if len(activations) != 3: raise NotImplementedError("LSTM assumes 3 activation functions are provided") alpha_loc = 0 - alphas = attr.get('activation_alpha', []) + alphas = attr.get("activation_alpha", []) if isinstance(alphas, float): alphas = [alphas] beta_loc = 0 - betas = attr.get('activation_beta', []) + betas = attr.get("activation_beta", []) if isinstance(betas, float): betas = [betas] acts = [] @@ -1623,8 +1664,8 @@ def _impl_v7(cls, inputs, attr, params): class GRU(RNN): - """Operator convert for GRU - """ + """Operator convert for GRU""" + @classmethod def _impl_v7(cls, inputs, attr, params): # Unpack inputs, note that if optional and not provided then value will be None. @@ -1633,9 +1674,9 @@ def _impl_v7(cls, inputs, attr, params): R = inputs[2] B = inputs[3] # Sequence length currently unused as it can be inferred from shapes. 
- #sequence_lens = inputs['sequence_lens'] + # sequence_lens = inputs['sequence_lens'] h_0 = inputs[5] - linear_before_reset = attr.get('linear_before_reset', 0) + linear_before_reset = attr.get("linear_before_reset", 0) num_directions = infer_shape(W)[0] W_dtype = infer_type(W).type_annotation.dtype @@ -1662,16 +1703,16 @@ def _impl_v7(cls, inputs, attr, params): H_t = h_0 h_list = [] - if 'activations' in attr: - activations = attr['activations'] + if "activations" in attr: + activations = attr["activations"] if len(activations) != 2: raise NotImplementedError("GRU assumes 2 activation functions are provided") alpha_loc = 0 - alphas = attr.get('activation_alpha', []) + alphas = attr.get("activation_alpha", []) if isinstance(alphas, float): alphas = [alphas] beta_loc = 0 - betas = attr.get('activation_beta', []) + betas = attr.get("activation_beta", []) if isinstance(betas, float): betas = [betas] acts = [] @@ -1730,62 +1771,66 @@ def _impl_v7(cls, inputs, attr, params): class Resize(OnnxOpConverter): - """Operator converter for Resize - """ + """Operator converter for Resize""" + @classmethod def _impl_v11(cls, inputs, attr, params): - mode = attr.get('mode') - if mode == b'nearest': + mode = attr.get("mode") + if mode == b"nearest": method = "nearest_neighbor" - elif mode == b'linear': + elif mode == b"linear": method = "bilinear" else: raise tvm.error.OpAttributeInvalid( - 'Value {} in attribute "mode" of operator Resize is not valid.'.format(mode)) + 'Value {} in attribute "mode" of operator Resize is not valid.'.format(mode) + ) scale = inputs[2] scale_shape = infer_shape(scale) if len(inputs) == 4: - assert len(scale_shape) == 0 or scale_shape[ - 0] == 0, "One of scale or size should be passed, not both." + assert ( + len(scale_shape) == 0 or scale_shape[0] == 0 + ), "One of scale or size should be passed, not both." size = inputs[3] else: assert len(scale_shape) != 0, "One of scale or size should be passed." 
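# A worked example with hypothetical shapes: for a 1x3x32x32 input and
# scale = [1, 1, 2, 2], the cast-and-multiply below produces
# size = [1, 3, 64, 64]; the strided_slice further down then keeps only the
# spatial extent [64, 64] that is passed to _op.image.resize as out_size.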
- size = _op.cast(_op.shape_of(inputs[0]), - infer_type(scale).type_annotation.dtype) * scale + size = ( + _op.cast(_op.shape_of(inputs[0]), infer_type(scale).type_annotation.dtype) * scale + ) - coord_trans = attr.get('coordinate_transformation_mode') - if coord_trans in [b'pytorch_half_pixel', b'half_pixel']: + coord_trans = attr.get("coordinate_transformation_mode") + if coord_trans in [b"pytorch_half_pixel", b"half_pixel"]: coord_trans = "half_pixel" - elif coord_trans == b'align_corners': + elif coord_trans == b"align_corners": coord_trans = "align_corners" - elif coord_trans == b'asymmetric' or method == "nearest_neighbor": + elif coord_trans == b"asymmetric" or method == "nearest_neighbor": coord_trans = "asymmetric" else: raise tvm.error.OpAttributeInvalid( - 'Unsupported coordinate_transformation_mode: {}'.format(coord_trans)) + "Unsupported coordinate_transformation_mode: {}".format(coord_trans) + ) layout = "NCHW" # ONNX assumes NCHW layout out_size = _op.strided_slice(size, [2], [4]) return _op.image.resize(inputs[0], out_size, layout, method, coord_trans) class NonZero(OnnxOpConverter): - """Operator converter for NonZero - """ + """Operator converter for NonZero""" + @classmethod def _impl_v9(cls, inputs, attr, params): if len(inputs) > 1: raise ValueError("Expect 1 input only") - output = AttrCvt(op_name='argwhere')(inputs, attr, params) + output = AttrCvt(op_name="argwhere")(inputs, attr, params) # ONNX NonZero always outputs int64 output = _op.cast(output, "int64") return _op.transpose(output, axes=(1, 0)) class TopK(OnnxOpConverter): - """Operator converter for TopK - """ + """Operator converter for TopK""" + @classmethod def _impl_v1(cls, inputs, attr, params): if len(inputs) != 2: @@ -1800,8 +1845,8 @@ def _impl_v1(cls, inputs, attr, params): class MaxRoiPool(OnnxOpConverter): - """Operator converter for MaxRoiPool. - """ + """Operator converter for MaxRoiPool.""" + @classmethod def _impl_v1(cls, inputs, attr, params): assert len(inputs) == 2, "MMaxRoiPool op take 2 inputs, {} given".format(len(inputs)) @@ -1815,8 +1860,8 @@ def _impl_v1(cls, inputs, attr, params): class RoiAlign(OnnxOpConverter): - """Operator converter for RoiAlign. - """ + """Operator converter for RoiAlign.""" + @classmethod def _impl_v1(cls, inputs, attr, params): if len(inputs) != 3: @@ -1825,7 +1870,7 @@ def _impl_v1(cls, inputs, attr, params): rois = inputs[1] batch_indices = inputs[2] mode = attr.get("mode", "avg") - if mode != b'avg': + if mode != b"avg": raise ValueError("RoiAlign in Relay only uses avg mode") output_height = attr.get("output_height", 1) output_width = attr.get("output_width", 1) @@ -1837,16 +1882,17 @@ def _impl_v1(cls, inputs, attr, params): batch_indices = _op.cast(batch_indices, infer_type(rois).type_annotation.dtype) rois = _op.concatenate([batch_indices, rois], 1) - return _vision.roi_align(x, rois, [output_height, output_width], spatial_scale, - sampling_ratio) + return _vision.roi_align( + x, rois, [output_height, output_width], spatial_scale, sampling_ratio + ) class Clip(OnnxOpConverter): - """Operator converter for Clip. 
- """ + """Operator converter for Clip.""" + @staticmethod def convert_attributes(inputs, attr, params): - convert = AttrCvt('clip', transforms={'min': 'a_min', 'max': 'a_max'}) + convert = AttrCvt("clip", transforms={"min": "a_min", "max": "a_max"}) return convert(inputs, attr, params) @classmethod @@ -1855,16 +1901,17 @@ def _impl_v1(cls, inputs, attr, params): @classmethod def _impl_v11(cls, inputs, attr, params): - if 'min' in attr and 'max' in attr: + if "min" in attr and "max" in attr: return Clip.convert_attributes(inputs, attr, params) assert len(inputs) <= 3, "Clip-11 takes up to 3 inputs, input, min, max" result = inputs[0] for i, op in enumerate([_maximum, _minimum]): if i < len(inputs) - 1: - result = op(result, inputs[i+1]) + result = op(result, inputs[i + 1]) return result + # compatible operators that do NOT require any conversion. _identity_list = [] @@ -1877,159 +1924,151 @@ def _impl_v11(cls, inputs, attr, params): def _get_convert_map(opset): return { # defs/experimental - 'Identity': Renamer('copy'), - 'Affine': Affine.get_converter(opset), - 'ThresholdedRelu': ThresholdedRelu.get_converter(opset), - 'ScaledTanh': ScaledTanh.get_converter(opset), - 'ParametricSoftplus': ParametricSoftPlus.get_converter(opset), - 'ConstantOfShape': ConstantOfShape.get_converter(opset), + "Identity": Renamer("copy"), + "Affine": Affine.get_converter(opset), + "ThresholdedRelu": ThresholdedRelu.get_converter(opset), + "ScaledTanh": ScaledTanh.get_converter(opset), + "ParametricSoftplus": ParametricSoftPlus.get_converter(opset), + "ConstantOfShape": ConstantOfShape.get_converter(opset), # 'GivenTensorFill' - 'FC': AttrCvt('dense', ignores=['axis', 'axis_w']), - 'Scale': Scale.get_converter(opset), + "FC": AttrCvt("dense", ignores=["axis", "axis_w"]), + "Scale": Scale.get_converter(opset), # 'GRUUnit' # 'ATen' # 'ImageScaler' # 'MeanVarianceNormalization' # 'Crop' # 'Embedding' - 'Upsample': Upsample.get_converter(opset), - 'SpatialBN': BatchNorm.get_converter(opset), - + "Upsample": Upsample.get_converter(opset), + "SpatialBN": BatchNorm.get_converter(opset), # defs/generator # 'Constant' # Implemented # 'RandomUniform' # 'RandomNormal' # 'RandomUniformLike' # 'RandomNormalLike' - # defs/logical - # defs/math - 'Add': Add.get_converter(opset), - 'Sub': Sub.get_converter(opset), - 'Mul': Mul.get_converter(opset), - 'Div': Div.get_converter(opset), - 'Neg': Renamer('negative'), - 'Abs': Absolute.get_converter(opset), - 'Reciprocal': Reciprocal.get_converter(opset), - 'Floor': Renamer('floor'), - 'Ceil': Renamer('ceil'), - 'Round': Renamer('round'), - 'IsInf': Renamer('isinf'), - 'IsNaN': Renamer('isnan'), - 'Sqrt': Renamer('sqrt'), - 'Relu': Renamer('relu'), - 'LeakyRelu': Renamer('leaky_relu'), - 'Selu': Selu.get_converter(opset), - 'Elu': Elu.get_converter(opset), - 'Exp': Renamer('exp'), - 'Greater': Greater.get_converter(opset), - 'Less': Less.get_converter(opset), - 'Log': Renamer('log'), - 'ACos': Renamer('acos'), - 'ACosh': Renamer('acosh'), - 'ASin': Renamer('asin'), - 'ASinh': Renamer('asinh'), - 'ATan': Renamer('atan'), - 'ATanh': Renamer('atanh'), - 'Cos': Renamer('cos'), - 'Cosh': Renamer('cosh'), - 'Sin': Renamer('sin'), - 'Sinh': Renamer('sinh'), - 'Tan': Renamer('tan'), - 'Tanh': Renamer('tanh'), - 'Pow': Renamer('power'), - 'PRelu': Prelu.get_converter(opset), - 'Sigmoid': Renamer('sigmoid'), - 'HardSigmoid': HardSigmoid.get_converter(opset), - 'Max': Maximum.get_converter(opset), - 'Min': Minimum.get_converter(opset), - 'Sum': Sum.get_converter(opset), - 'Mean': 
Mean.get_converter(opset), - 'Clip': Clip.get_converter(opset), + "Add": Add.get_converter(opset), + "Sub": Sub.get_converter(opset), + "Mul": Mul.get_converter(opset), + "Div": Div.get_converter(opset), + "Neg": Renamer("negative"), + "Abs": Absolute.get_converter(opset), + "Reciprocal": Reciprocal.get_converter(opset), + "Floor": Renamer("floor"), + "Ceil": Renamer("ceil"), + "Round": Renamer("round"), + "IsInf": Renamer("isinf"), + "IsNaN": Renamer("isnan"), + "Sqrt": Renamer("sqrt"), + "Relu": Renamer("relu"), + "LeakyRelu": Renamer("leaky_relu"), + "Selu": Selu.get_converter(opset), + "Elu": Elu.get_converter(opset), + "Exp": Renamer("exp"), + "Greater": Greater.get_converter(opset), + "Less": Less.get_converter(opset), + "Log": Renamer("log"), + "ACos": Renamer("acos"), + "ACosh": Renamer("acosh"), + "ASin": Renamer("asin"), + "ASinh": Renamer("asinh"), + "ATan": Renamer("atan"), + "ATanh": Renamer("atanh"), + "Cos": Renamer("cos"), + "Cosh": Renamer("cosh"), + "Sin": Renamer("sin"), + "Sinh": Renamer("sinh"), + "Tan": Renamer("tan"), + "Tanh": Renamer("tanh"), + "Pow": Renamer("power"), + "PRelu": Prelu.get_converter(opset), + "Sigmoid": Renamer("sigmoid"), + "HardSigmoid": HardSigmoid.get_converter(opset), + "Max": Maximum.get_converter(opset), + "Min": Minimum.get_converter(opset), + "Sum": Sum.get_converter(opset), + "Mean": Mean.get_converter(opset), + "Clip": Clip.get_converter(opset), # softmax default axis is different in onnx - 'Softmax': Softmax.get_converter(opset), - 'LogSoftmax': AttrCvt('log_softmax', {'axis': ('axis', 1)}), - 'OneHot': OneHot.get_converter(opset), + "Softmax": Softmax.get_converter(opset), + "LogSoftmax": AttrCvt("log_softmax", {"axis": ("axis", 1)}), + "OneHot": OneHot.get_converter(opset), # 'Hardmax' - 'Softsign': Softsign.get_converter(opset), - 'SoftPlus': SoftPlus.get_converter(opset), - 'Gemm': Gemm.get_converter(opset), - 'MatMul': MatMul.get_converter(opset), - 'Mod': Mod.get_converter(opset), - 'Xor': Renamer('logical_xor'), - + "Softsign": Softsign.get_converter(opset), + "SoftPlus": SoftPlus.get_converter(opset), + "Gemm": Gemm.get_converter(opset), + "MatMul": MatMul.get_converter(opset), + "Mod": Mod.get_converter(opset), + "Xor": Renamer("logical_xor"), # defs/nn - 'AveragePool': AveragePool.get_converter(opset), - 'LpPool': LpPool.get_converter(opset), - 'MaxPool': MaxPool.get_converter(opset), - 'Conv': Conv.get_converter(opset), - 'ConvTranspose': ConvTranspose.get_converter(opset), - 'GlobalAveragePool': Renamer('global_avg_pool2d'), - 'GlobalMaxPool': Renamer('global_max_pool2d'), - 'BatchNormalization': BatchNorm.get_converter(opset), - 'InstanceNormalization': InstanceNorm.get_converter(opset), + "AveragePool": AveragePool.get_converter(opset), + "LpPool": LpPool.get_converter(opset), + "MaxPool": MaxPool.get_converter(opset), + "Conv": Conv.get_converter(opset), + "ConvTranspose": ConvTranspose.get_converter(opset), + "GlobalAveragePool": Renamer("global_avg_pool2d"), + "GlobalMaxPool": Renamer("global_max_pool2d"), + "BatchNormalization": BatchNorm.get_converter(opset), + "InstanceNormalization": InstanceNorm.get_converter(opset), # 'LpNormalization' - 'Dropout': AttrCvt('dropout', {'ratio': 'rate'}, ignores=['is_test']), - 'Flatten': Flatten.get_converter(opset), - 'LRN': LRN.get_converter(opset), + "Dropout": AttrCvt("dropout", {"ratio": "rate"}, ignores=["is_test"]), + "Flatten": Flatten.get_converter(opset), + "LRN": LRN.get_converter(opset), # Recurrent Layers - 'LSTM': LSTM.get_converter(opset), - 'GRU': 
GRU.get_converter(opset), - + "LSTM": LSTM.get_converter(opset), + "GRU": GRU.get_converter(opset), # defs/vision - 'MaxRoiPool': MaxRoiPool.get_converter(opset), - 'RoiAlign': RoiAlign.get_converter(opset), - + "MaxRoiPool": MaxRoiPool.get_converter(opset), + "RoiAlign": RoiAlign.get_converter(opset), # defs/reduction - 'ReduceMax': ReduceMax.get_converter(opset), - 'ReduceMin': ReduceMin.get_converter(opset), - 'ReduceSum': ReduceSum.get_converter(opset), - 'ReduceMean': ReduceMean.get_converter(opset), - 'ReduceProd': ReduceProd.get_converter(opset), - 'ReduceLogSumExp': ReduceLogSumExp.get_converter(opset), - 'ReduceLogSum': ReduceLogSum.get_converter(opset), - 'ReduceSumSquare': ReduceSumSquare.get_converter(opset), - 'ReduceL1': ReduceL1.get_converter(opset), - 'ReduceL2': ReduceL2.get_converter(opset), - - #defs/sorting - 'ArgMax': ArgMax.get_converter(opset), - 'ArgMin': ArgMin.get_converter(opset), - 'TopK': TopK.get_converter(opset), - + "ReduceMax": ReduceMax.get_converter(opset), + "ReduceMin": ReduceMin.get_converter(opset), + "ReduceSum": ReduceSum.get_converter(opset), + "ReduceMean": ReduceMean.get_converter(opset), + "ReduceProd": ReduceProd.get_converter(opset), + "ReduceLogSumExp": ReduceLogSumExp.get_converter(opset), + "ReduceLogSum": ReduceLogSum.get_converter(opset), + "ReduceSumSquare": ReduceSumSquare.get_converter(opset), + "ReduceL1": ReduceL1.get_converter(opset), + "ReduceL2": ReduceL2.get_converter(opset), + # defs/sorting + "ArgMax": ArgMax.get_converter(opset), + "ArgMin": ArgMin.get_converter(opset), + "TopK": TopK.get_converter(opset), # defs/tensor - 'Cast': Cast.get_converter(opset), - 'Reshape': Reshape.get_converter(opset), - 'Expand': Expand.get_converter(opset), - 'Concat': Concat.get_converter(opset), - 'Split': Split.get_converter(opset), - 'Slice': Slice.get_converter(opset), - 'Transpose': AttrCvt('transpose', {'perm': 'axes'}), - 'DepthToSpace': DepthToSpace.get_converter(opset), - 'SpaceToDepth': SpaceToDepth.get_converter(opset), - 'Gather': Gather.get_converter(opset), - 'GatherND': GatherND.get_converter(opset), - 'Scatter': Scatter.get_converter(opset), - 'ScatterElements': Scatter.get_converter(opset), - 'Squeeze': AttrCvt('squeeze', {'axes': 'axis'}), - 'Unsqueeze': Unsqueeze.get_converter(opset), - 'Pad': Pad.get_converter(opset), - 'Shape': Shape.get_converter(opset), - 'Sign': Sign.get_converter(opset), - 'Equal': Equal.get_converter(opset), - 'Not': Not.get_converter(opset), - 'And': And.get_converter(opset), - 'Tile': Tile.get_converter(opset), - 'Erf': Erf.get_converter(opset), - 'Where': Where.get_converter(opset), - 'Or': Or.get_converter(opset), - 'Resize': Resize.get_converter(opset), - 'NonZero': NonZero.get_converter(opset), + "Cast": Cast.get_converter(opset), + "Reshape": Reshape.get_converter(opset), + "Expand": Expand.get_converter(opset), + "Concat": Concat.get_converter(opset), + "Split": Split.get_converter(opset), + "Slice": Slice.get_converter(opset), + "Transpose": AttrCvt("transpose", {"perm": "axes"}), + "DepthToSpace": DepthToSpace.get_converter(opset), + "SpaceToDepth": SpaceToDepth.get_converter(opset), + "Gather": Gather.get_converter(opset), + "GatherND": GatherND.get_converter(opset), + "Scatter": Scatter.get_converter(opset), + "ScatterElements": Scatter.get_converter(opset), + "Squeeze": AttrCvt("squeeze", {"axes": "axis"}), + "Unsqueeze": Unsqueeze.get_converter(opset), + "Pad": Pad.get_converter(opset), + "Shape": Shape.get_converter(opset), + "Sign": Sign.get_converter(opset), + "Equal": 
Equal.get_converter(opset), + "Not": Not.get_converter(opset), + "And": And.get_converter(opset), + "Tile": Tile.get_converter(opset), + "Erf": Erf.get_converter(opset), + "Where": Where.get_converter(opset), + "Or": Or.get_converter(opset), + "Resize": Resize.get_converter(opset), + "NonZero": NonZero.get_converter(opset), } -class GraphProto(): +class GraphProto: """A helper class for handling Relay expression copying from pb2.GraphProto. Definition: https://github.com/onnx/onnx/blob/master/onnx/onnx.proto @@ -2041,6 +2080,7 @@ class GraphProto(): dtype : str or dict of str to str The input types to the graph """ + def __init__(self, shape, dtype): self._nodes = {} self._params = {} @@ -2096,21 +2136,23 @@ def from_onnx(self, graph, opset, freeze_params=False): if not init_tensor.name.strip(): raise ValueError("Tensor's name is required.") self._params[init_tensor.name] = self._parse_array(init_tensor) - self._nodes[init_tensor.name] = new_var(init_tensor.name, - shape=self._params[init_tensor.name].shape, - dtype=self._params[init_tensor.name].dtype) + self._nodes[init_tensor.name] = new_var( + init_tensor.name, + shape=self._params[init_tensor.name].shape, + dtype=self._params[init_tensor.name].dtype, + ) for i in graph.input: # from onnx v0.2, GraphProto.input has type ValueInfoProto, # and the name is 'i.name' i_name = self._parse_value_proto(i) - d_type = self._parse_dtype(i, 'float32') + d_type = self._parse_dtype(i, "float32") if i_name in self._params: # i is a param instead of input self._num_param += 1 self._params[i_name] = self._params.pop(i_name) - self._nodes[i_name] = new_var(i_name, - shape=self._params[i_name].shape, - dtype=self._params[i_name].dtype) + self._nodes[i_name] = new_var( + i_name, shape=self._params[i_name].shape, dtype=self._params[i_name].dtype + ) else: self._num_input += 1 if i_name in self._shape: @@ -2128,13 +2170,15 @@ def from_onnx(self, graph, opset, freeze_params=False): unsupported_ops = set() for node in graph.node: op_name = node.op_type - if op_name not in convert_map and \ - op_name != 'Constant' and \ - op_name not in _identity_list: + if ( + op_name not in convert_map + and op_name != "Constant" + and op_name not in _identity_list + ): unsupported_ops.add(op_name) if unsupported_ops: - msg = 'The following operators are not supported for frontend ONNX: ' - msg += ', '.join(unsupported_ops) + msg = "The following operators are not supported for frontend ONNX: " + msg += ", ".join(unsupported_ops) raise tvm.error.OpNotImplemented(msg) # construct nodes, nodes are stored as directed acyclic graph for node in graph.node: @@ -2143,7 +2187,7 @@ def from_onnx(self, graph, opset, freeze_params=False): # Create and populate onnx input object. inputs = onnx_input() for i in node.input: - if i != '': + if i != "": inputs[i] = self._nodes[self._renames.get(i, i)] if op_name == "Constant": t_proto = self._parse_attr(node.attribute)["value"] @@ -2151,24 +2195,26 @@ def from_onnx(self, graph, opset, freeze_params=False): # We should convert scalar integers to int32, to normalize. 
array = self._parse_array(t_proto) self._params[node.output[0]] = array - self._nodes[node.output[0]] = new_var(node.output[0], - shape=list(t_proto.dims), - dtype=array.dtype) + self._nodes[node.output[0]] = new_var( + node.output[0], shape=list(t_proto.dims), dtype=array.dtype + ) else: i_name = self._parse_value_proto(node) node_output = self._fix_outputs(op_name, node.output) - attr['tvm_custom'] = {} - attr['tvm_custom']['name'] = i_name - attr['tvm_custom']['num_outputs'] = len(node_output) + attr["tvm_custom"] = {} + attr["tvm_custom"]["name"] = i_name + attr["tvm_custom"]["num_outputs"] = len(node_output) op = self._convert_operator(op_name, inputs, attr, opset) if not isinstance(op, _expr.TupleWrapper): outputs_num = 1 else: outputs_num = len(op) - assert len(node_output) == outputs_num, ( - "Number of output mismatch {} vs {} in {}.".format( - len(node_output), outputs_num, op_name)) + assert ( + len(node_output) == outputs_num + ), "Number of output mismatch {} vs {} in {}.".format( + len(node_output), outputs_num, op_name + ) if outputs_num == 1: self._nodes[node_output[0]] = op else: @@ -2204,6 +2250,7 @@ def _parse_dtype(self, value_proto, dtype): """Parse dtype.""" try: from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE + return TENSOR_TYPE_TO_NP_TYPE[value_proto.type.tensor_type.elem_type].name except AttributeError: return dtype @@ -2216,24 +2263,24 @@ def _parse_attr(self, attr_proto): """Convert a list of AttributeProto to a dict, with names as keys.""" attrs = {} for a in attr_proto: - for f in ['f', 'i', 's']: + for f in ["f", "i", "s"]: if a.HasField(f): attrs[a.name] = getattr(a, f) - for f in ['floats', 'ints', 'strings']: + for f in ["floats", "ints", "strings"]: if list(getattr(a, f)): assert a.name not in attrs, "Only one type of attr is allowed" attrs[a.name] = tuple(getattr(a, f)) - for f in ['t']: + for f in ["t"]: if a.HasField(f): attrs[a.name] = getattr(a, f) - for f in ['tensors']: + for f in ["tensors"]: if list(getattr(a, f)): assert a.name not in attrs, "Only one type of attr is allowed" attrs[a.name] = tuple(getattr(a, f)) - for f in ['g']: + for f in ["g"]: if a.HasField(f): raise NotImplementedError("Filed {} is not supported in relay.".format(f)) - for f in ['graphs']: + for f in ["graphs"]: if list(getattr(a, f)): raise NotImplementedError("Filed {} is not supported in relay.".format(f)) if a.name not in attrs: @@ -2274,7 +2321,7 @@ def _fix_outputs(self, op_name, outputs): """A hack to handle dropout or similar operator that have more than one out in ONNX. """ - if op_name == 'Dropout': + if op_name == "Dropout": if len(outputs) == 1: return outputs # TODO(zhreshold): support dropout mask? @@ -2324,12 +2371,14 @@ def from_onnx(model, shape=None, dtype="float32", opset=None, freeze_params=Fals """ try: import onnx - if hasattr(onnx.checker, 'check_model'): + + if hasattr(onnx.checker, "check_model"): # try use onnx's own model checker before converting any model try: onnx.checker.check_model(model) except onnx.onnx_cpp2py_export.checker.ValidationError as e: import warnings + # the checker is a bit violent about errors, so simply print warnings here warnings.warn(str(e)) except ImportError: diff --git a/python/tvm/relay/op/_tensor.py b/python/tvm/relay/op/_tensor.py index 9f04d011b6cb..976e7e2eb4f0 100644 --- a/python/tvm/relay/op/_tensor.py +++ b/python/tvm/relay/op/_tensor.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-#pylint: disable=invalid-name, unused-argument, len-as-condition +# pylint: disable=invalid-name, unused-argument, len-as-condition """Backend compiler related feature registration""" from tvm.te.hybrid import script @@ -96,6 +96,7 @@ def zeros_compute(attrs, inputs, output_type): assert not inputs return [topi.full(output_type.shape, output_type.dtype, 0.0)] + register_broadcast_schedule("zeros") register_pattern("zeros", OpPattern.ELEMWISE) @@ -105,6 +106,7 @@ def zeros_like_compute(attrs, inputs, output_type): assert len(inputs) == 1 return [topi.full_like(inputs[0], 0.0)] + register_broadcast_schedule("zeros_like") # ones @@ -113,6 +115,7 @@ def ones_compute(attrs, inputs, output_type): assert not inputs return [topi.full(output_type.shape, output_type.dtype, 1.0)] + register_broadcast_schedule("ones") register_pattern("ones", OpPattern.ELEMWISE) @@ -122,6 +125,7 @@ def ones_like_compute(attrs, inputs, output_type): assert len(inputs) == 1 return [topi.full_like(inputs[0], 1.0)] + register_broadcast_schedule("ones_like") # clip @@ -130,6 +134,7 @@ def clip_compute(attrs, inputs, output_type): assert len(inputs) == 1 return [topi.clip(inputs[0], attrs.a_min, attrs.a_max)] + register_injective_schedule("clip") # fixed point multiply @@ -138,6 +143,7 @@ def fixed_point_multiply_compute(attrs, inputs, output_type): assert len(inputs) == 1 return [topi.fixed_point_multiply(inputs[0], attrs.multiplier, attrs.shift)] + register_injective_schedule("fixed_point_multiply") # full @@ -149,18 +155,21 @@ def _full_shape_func(shape): out[i] = int64(shape[i]) return out + def full_shape_func(attrs, inputs, out_ndims): """ Shape func for full. """ return [_full_shape_func(inputs[1])] + def no_data_full_shape_func(attrs, inputs, out_ndims): """ Shape func for zeros and ones. """ return [_full_shape_func(inputs[0])] + @script def _broadcast_shape_func(x, y, ndim): out = output_tensor((ndim,), "int64") @@ -173,34 +182,39 @@ def _broadcast_shape_func(x, y, ndim): else: ndim1 = x.shape[0] ndim2 = y.shape[0] - for i in const_range(1, min(ndim1, ndim2)+1): - if x[ndim1-i] == y[ndim2-i]: - out[ndim-i] = x[ndim1-i] - elif x[ndim1-i] == 1: - out[ndim-i] = y[ndim2-i] + for i in const_range(1, min(ndim1, ndim2) + 1): + if x[ndim1 - i] == y[ndim2 - i]: + out[ndim - i] = x[ndim1 - i] + elif x[ndim1 - i] == 1: + out[ndim - i] = y[ndim2 - i] else: assert y[ndim2 - i] == 1, "Incompatible broadcast type %s and %s" % ( - x[ndim1-i], y[ndim2-i]) - out[ndim-i] = x[ndim1-i] - for i in const_range(min(ndim1, ndim2)+1, ndim+1): + x[ndim1 - i], + y[ndim2 - i], + ) + out[ndim - i] = x[ndim1 - i] + for i in const_range(min(ndim1, ndim2) + 1, ndim + 1): if ndim1 >= ndim2: - out[ndim-i] = x[ndim1-i] + out[ndim - i] = x[ndim1 - i] else: - out[ndim-i] = y[ndim2-i] + out[ndim - i] = y[ndim2 - i] return out + def broadcast_shape_func(attrs, inputs, out_ndims): """ Shape function for broadcast op. """ return [_broadcast_shape_func(*inputs, out_ndims[0])] + def elemwise_shape_func(attrs, inputs, _): """ Shape function for elemwise op. 
""" return [topi.math.identity(inputs[0])] + register_shape_func("cast", False, elemwise_shape_func) register_shape_func("zeros", False, full_shape_func) register_shape_func("zeros_like", False, elemwise_shape_func) diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index 0fa742700637..66a3ee218f88 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -50,9 +50,10 @@ # fifo_buffer -@reg.register_compute('nn.fifo_buffer') +@reg.register_compute("nn.fifo_buffer") def compute_fifo_buffer(attrs, inputs, out_type): - return [topi.nn.fifo_buffer(inputs[0], inputs[1], axis=attrs.get_int('axis'))] + return [topi.nn.fifo_buffer(inputs[0], inputs[1], axis=attrs.get_int("axis"))] + reg.register_injective_schedule("nn.fifo_buffer") reg.register_pattern("nn.fifo_buffer", OpPattern.OPAQUE) @@ -69,6 +70,7 @@ def compute_sparse_dense(attrs, inputs, out_type): """Compute definition of sparse_dense""" return [topi.nn.sparse_dense(inputs[0], inputs[1], inputs[2], inputs[3])] + reg.register_strategy("nn.sparse_dense", strategy.sparse_dense_strategy) reg.register_pattern("nn.sparse_dense", reg.OpPattern.OUT_ELEMWISE_FUSABLE) @@ -79,6 +81,7 @@ def compute_sparse_transpose(attrs, inputs, out_type): """Compute definition of sparse_transpose""" return topi.nn.sparse_transpose(inputs[0], inputs[1], inputs[2]) + reg.register_schedule("nn.sparse_transpose", strategy.schedule_sparse_transpose) reg.register_pattern("nn.sparse_transpose", reg.OpPattern.OUT_ELEMWISE_FUSABLE) @@ -92,11 +95,13 @@ def compute_sparse_transpose(attrs, inputs, out_type): reg.register_strategy("nn.conv2d", strategy.conv2d_strategy) reg.register_pattern("nn.conv2d", OpPattern.OUT_ELEMWISE_FUSABLE) + @reg.register_alter_op_layout("nn.conv2d") def alter_op_layout_conv2d(attrs, inputs, tinfos, out_type): """Alternate the layout of conv2d""" return topi.nn.conv2d_alter_layout(attrs, inputs, tinfos, out_type) + @reg.register_legalize("nn.conv2d") def legalize_conv2d(attrs, inputs, types): """Legalize conv2d op. @@ -117,6 +122,7 @@ def legalize_conv2d(attrs, inputs, types): """ return topi.nn.conv2d_legalize(attrs, inputs, types) + @reg.register_convert_op_layout("nn.conv2d") def convert_conv2d(attrs, inputs, tinfos, desired_layouts): """Convert Layout pass registration for conv2d op. @@ -140,28 +146,30 @@ def convert_conv2d(attrs, inputs, tinfos, desired_layouts): """ # pylint: disable=import-outside-toplevel from tvm import relay + data, weight = inputs new_attrs = dict(attrs) assert len(desired_layouts) == 2, "A desired layout is expected for both of nn.conv2d's inputs" desired_data_layout, desired_kernel_layout = map(str, desired_layouts) assert desired_data_layout != "default", "Data layout cannot be default" - new_attrs['data_layout'] = desired_data_layout + new_attrs["data_layout"] = desired_data_layout if desired_kernel_layout != "default": - new_attrs['kernel_layout'] = desired_kernel_layout + new_attrs["kernel_layout"] = desired_kernel_layout return relay.nn.conv2d(data, weight, **new_attrs) # Handle default kernel layouts - if desired_data_layout == 'NCHW': - new_attrs['kernel_layout'] = 'OIHW' + if desired_data_layout == "NCHW": + new_attrs["kernel_layout"] = "OIHW" return relay.nn.conv2d(data, weight, **new_attrs) - elif desired_data_layout == 'NHWC': + elif desired_data_layout == "NHWC": # Check for depthwise convolution. 
- if is_depthwise_conv2d(data.shape, attrs['data_layout'], weight.shape, - attrs['kernel_layout'], attrs['groups']): - new_attrs['kernel_layout'] = 'HWOI' + if is_depthwise_conv2d( + data.shape, attrs["data_layout"], weight.shape, attrs["kernel_layout"], attrs["groups"] + ): + new_attrs["kernel_layout"] = "HWOI" else: - new_attrs['kernel_layout'] = 'HWIO' + new_attrs["kernel_layout"] = "HWIO" return relay.nn.conv2d(data, weight, **new_attrs) raise ValueError("Layout %s is not yet supported." % desired_data_layout) @@ -192,6 +200,7 @@ def legalize_conv2d_transpose(attrs, inputs, types): """ return topi.nn.conv2d_transpose_legalize(attrs, inputs, types) + @reg.register_convert_op_layout("nn.conv2d_transpose") def convert_conv2d_transpose(attrs, inputs, tinfos, desired_layouts): """Convert Layout pass registration for conv2d_transpose op. @@ -215,31 +224,34 @@ def convert_conv2d_transpose(attrs, inputs, tinfos, desired_layouts): """ # pylint: disable=import-outside-toplevel from tvm import relay + data, weight = inputs new_attrs = dict(attrs) assert len(desired_layouts) == 2, "A desired layout is expected for both of nn.conv2d's inputs" desired_data_layout, desired_kernel_layout = map(str, desired_layouts) assert desired_data_layout != "default", "Data layout cannot be default" - new_attrs['data_layout'] = desired_data_layout + new_attrs["data_layout"] = desired_data_layout if desired_kernel_layout != "default": - new_attrs['kernel_layout'] = desired_kernel_layout + new_attrs["kernel_layout"] = desired_kernel_layout return relay.nn.conv2d_transpose(data, weight, **new_attrs) # Handle default kernel layouts - if desired_data_layout == 'NCHW': - new_attrs['kernel_layout'] = 'OIHW' + if desired_data_layout == "NCHW": + new_attrs["kernel_layout"] = "OIHW" return relay.nn.conv2d_transpose(data, weight, **new_attrs) - elif desired_data_layout == 'NHWC': - new_attrs['kernel_layout'] = 'HWIO' + elif desired_data_layout == "NHWC": + new_attrs["kernel_layout"] = "HWIO" return relay.nn.conv2d_transpose(data, weight, **new_attrs) raise ValueError("Layout %s is not yet supported." % desired_data_layout) + # conv3d_transpose reg.register_strategy("nn.conv3d_transpose", strategy.conv3d_transpose_strategy) reg.register_pattern("nn.conv3d_transpose", OpPattern.OUT_ELEMWISE_FUSABLE) + @reg.register_legalize("nn.conv3d_transpose") def legalize_conv3d_transpose(attrs, inputs, types): """Legalize conv3d_transpose op. @@ -265,11 +277,13 @@ def legalize_conv3d_transpose(attrs, inputs, types): reg.register_strategy("nn.conv3d", strategy.conv3d_strategy) reg.register_pattern("nn.conv3d", OpPattern.OUT_ELEMWISE_FUSABLE) + @reg.register_alter_op_layout("nn.conv3d") def alter_op_layout_conv3d(attrs, inputs, tinfos, out_type): """Alternate the layout of conv3d""" return topi.nn.conv3d_alter_layout(attrs, inputs, tinfos, out_type) + @reg.register_convert_op_layout("nn.conv3d") def convert_conv3d(attrs, inputs, tinfos, desired_layouts): """Convert Layout pass registration for conv3d op. 
@@ -293,45 +307,51 @@ def convert_conv3d(attrs, inputs, tinfos, desired_layouts): """ # pylint: disable=import-outside-toplevel from tvm import relay + data, weight = inputs new_attrs = dict(attrs) assert len(desired_layouts) == 2, "A desired layout is expected for both of nn.conv3d's inputs" desired_data_layout, desired_kernel_layout = map(str, desired_layouts) assert desired_data_layout != "default", "Data layout cannot be default" - new_attrs['data_layout'] = desired_data_layout + new_attrs["data_layout"] = desired_data_layout if desired_kernel_layout != "default": - new_attrs['kernel_layout'] = desired_kernel_layout + new_attrs["kernel_layout"] = desired_kernel_layout return relay.nn.conv3d(data, weight, **new_attrs) # Handle default kernel layouts - if desired_data_layout == 'NCDHW': - new_attrs['kernel_layout'] = 'OIDHW' + if desired_data_layout == "NCDHW": + new_attrs["kernel_layout"] = "OIDHW" return relay.nn.conv3d(data, weight, **new_attrs) elif desired_data_layout == "NDHWC": - new_attrs['kernel_layout'] = 'DHWIO' + new_attrs["kernel_layout"] = "DHWIO" return relay.nn.conv3d(data, weight, **new_attrs) raise ValueError("Layout %s is not yet supported" % desired_data_layout) # conv3d_winograd related operators -reg.register_strategy("nn.contrib_conv3d_winograd_without_weight_transform", - strategy.conv3d_winograd_without_weight_transfrom_strategy) -reg.register_pattern("nn.contrib_conv3d_winograd_without_weight_transform", - OpPattern.OUT_ELEMWISE_FUSABLE) +reg.register_strategy( + "nn.contrib_conv3d_winograd_without_weight_transform", + strategy.conv3d_winograd_without_weight_transfrom_strategy, +) +reg.register_pattern( + "nn.contrib_conv3d_winograd_without_weight_transform", OpPattern.OUT_ELEMWISE_FUSABLE +) + @reg.register_compute("nn.contrib_conv3d_winograd_weight_transform") def compute_contrib_conv3d_winograd_weight_transform(attrs, inputs, out_dtype): """Compute definition of contrib_conv3d_winograd_weight_transform""" - out = topi.nn.conv3d_winograd_weight_transform( - inputs[0], attrs.get_int('tile_size')) + out = topi.nn.conv3d_winograd_weight_transform(inputs[0], attrs.get_int("tile_size")) return [out] -reg.register_schedule("nn.contrib_conv3d_winograd_weight_transform", - strategy.schedule_conv3d_winograd_weight_transform) -reg.register_pattern("nn.contrib_conv3d_winograd_weight_transform", - OpPattern.OUT_ELEMWISE_FUSABLE) + +reg.register_schedule( + "nn.contrib_conv3d_winograd_weight_transform", + strategy.schedule_conv3d_winograd_weight_transform, +) +reg.register_pattern("nn.contrib_conv3d_winograd_weight_transform", OpPattern.OUT_ELEMWISE_FUSABLE) # conv1d_transpose @@ -434,8 +454,8 @@ def compute_contrib_conv3d_winograd_weight_transform(attrs, inputs, out_dtype): def compute_lrn(attrs, inputs, out_dtype): """Compute definition of lrn""" assert len(inputs) == 1 - return [topi.nn.lrn(inputs[0], attrs.size, attrs.axis, - attrs.alpha, attrs.beta, attrs.bias)] + return [topi.nn.lrn(inputs[0], attrs.size, attrs.axis, attrs.alpha, attrs.beta, attrs.bias)] + reg.register_schedule("nn.lrn", strategy.schedule_lrn) reg.register_pattern("nn.lrn", OpPattern.OPAQUE) @@ -451,6 +471,7 @@ def compute_upsampling(attrs, inputs, out_dtype): align_corners = attrs.align_corners return [topi.nn.upsampling(inputs[0], scale_h, scale_w, layout, method, align_corners)] + reg.register_injective_schedule("nn.upsampling") @@ -463,8 +484,12 @@ def compute_upsampling3d(attrs, inputs, out_dtype): layout = attrs.layout method = attrs.method coordinate_transformation_mode = 
attrs.coordinate_transformation_mode - return [topi.nn.upsampling3d(inputs[0], scale_d, scale_h, scale_w, layout, method,\ - coordinate_transformation_mode)] + return [ + topi.nn.upsampling3d( + inputs[0], scale_d, scale_h, scale_w, layout, method, coordinate_transformation_mode + ) + ] + reg.register_injective_schedule("nn.upsampling3d") @@ -481,6 +506,7 @@ def compute_mirror_pad(attrs, inputs, out_dtype): out = topi.nn.mirror_pad(inputs[0], pad_before=pad_before, pad_after=pad_after, mode=mode) return [out] + reg.register_broadcast_schedule("nn.mirror_pad") @@ -491,6 +517,7 @@ def _mirror_pad_func(data_shape, pad_width): out[i] = data_shape[i] + int64(pad_width[i][0]) + int64(pad_width[i][1]) return out + @reg.register_shape_func("nn.mirror_pad", False) def mirror_pad_func(attrs, inputs, _): pad_width_tuple = [get_const_tuple(p) for p in attrs.pad_width] @@ -498,65 +525,75 @@ def mirror_pad_func(attrs, inputs, _): # conv2d_winograd related operators -reg.register_strategy("nn.contrib_conv2d_winograd_without_weight_transform", - strategy.conv2d_winograd_without_weight_transfrom_strategy) -reg.register_pattern("nn.contrib_conv2d_winograd_without_weight_transform", - OpPattern.OUT_ELEMWISE_FUSABLE) +reg.register_strategy( + "nn.contrib_conv2d_winograd_without_weight_transform", + strategy.conv2d_winograd_without_weight_transfrom_strategy, +) +reg.register_pattern( + "nn.contrib_conv2d_winograd_without_weight_transform", OpPattern.OUT_ELEMWISE_FUSABLE +) # conv2d_gemm related operators -reg.register_strategy("nn.contrib_conv2d_gemm_without_weight_transform", - strategy.conv2d_gemm_without_weight_transform_strategy) -reg.register_pattern("nn.contrib_conv2d_gemm_without_weight_transform", - OpPattern.OUT_ELEMWISE_FUSABLE) +reg.register_strategy( + "nn.contrib_conv2d_gemm_without_weight_transform", + strategy.conv2d_gemm_without_weight_transform_strategy, +) +reg.register_pattern( + "nn.contrib_conv2d_gemm_without_weight_transform", OpPattern.OUT_ELEMWISE_FUSABLE +) + @reg.register_compute("nn.contrib_conv2d_gemm_weight_transform") def compute_contrib_conv2d_gemm_weight_transform(attrs, inputs, out_dtype): """Compute definition of contrib_conv2d_gemm_weight_transform""" - out = topi.nn.conv2d_gemm_weight_transform( - inputs[0], attrs.tile_rows, attrs.tile_cols) + out = topi.nn.conv2d_gemm_weight_transform(inputs[0], attrs.tile_rows, attrs.tile_cols) return [out] -reg.register_schedule("nn.contrib_conv2d_gemm_weight_transform", - strategy.schedule_conv2d_gemm_weight_transform) -reg.register_pattern("nn.contrib_conv2d_gemm_weight_transform", - OpPattern.OUT_ELEMWISE_FUSABLE) + +reg.register_schedule( + "nn.contrib_conv2d_gemm_weight_transform", strategy.schedule_conv2d_gemm_weight_transform +) +reg.register_pattern("nn.contrib_conv2d_gemm_weight_transform", OpPattern.OUT_ELEMWISE_FUSABLE) + @reg.register_compute("nn.contrib_conv2d_winograd_weight_transform") def compute_contrib_conv2d_winograd_weight_transform(attrs, inputs, out_dtype): """Compute definition of contrib_conv2d_winograd_weight_transform""" - out = topi.nn.conv2d_winograd_weight_transform( - inputs[0], attrs.get_int('tile_size')) + out = topi.nn.conv2d_winograd_weight_transform(inputs[0], attrs.get_int("tile_size")) return [out] -reg.register_schedule("nn.contrib_conv2d_winograd_weight_transform", - strategy.schedule_conv2d_winograd_weight_transform) -reg.register_pattern("nn.contrib_conv2d_winograd_weight_transform", - OpPattern.OUT_ELEMWISE_FUSABLE) + +reg.register_schedule( + "nn.contrib_conv2d_winograd_weight_transform", + 
strategy.schedule_conv2d_winograd_weight_transform, +) +reg.register_pattern("nn.contrib_conv2d_winograd_weight_transform", OpPattern.OUT_ELEMWISE_FUSABLE) + @reg.register_compute("nn.contrib_conv2d_winograd_nnpack_weight_transform") def compute_contrib_conv2d_winograd_nnpack_weight_transform(attrs, inputs, out_dtype): """Compute definition of contrib_conv2d_winograd_nnpack_weight_transform""" - convolution_algorithm = attrs.get_int('convolution_algorithm') + convolution_algorithm = attrs.get_int("convolution_algorithm") out = topi.nn.conv2d_winograd_nnpack_weight_transform( - inputs[0], convolution_algorithm, out_dtype) + inputs[0], convolution_algorithm, out_dtype + ) return [out] -reg.register_schedule("nn.contrib_conv2d_winograd_nnpack_weight_transform", - strategy.schedule_conv2d_winograd_nnpack_weight_transform) -reg.register_pattern("nn.contrib_conv2d_winograd_nnpack_weight_transform", - OpPattern.OPAQUE) + +reg.register_schedule( + "nn.contrib_conv2d_winograd_nnpack_weight_transform", + strategy.schedule_conv2d_winograd_nnpack_weight_transform, +) +reg.register_pattern("nn.contrib_conv2d_winograd_nnpack_weight_transform", OpPattern.OPAQUE) # conv2d_NCHWc reg.register_strategy("nn.contrib_conv2d_NCHWc", strategy.conv2d_NCHWc_strategy) -reg.register_pattern("nn.contrib_conv2d_NCHWc", - OpPattern.OUT_ELEMWISE_FUSABLE) +reg.register_pattern("nn.contrib_conv2d_NCHWc", OpPattern.OUT_ELEMWISE_FUSABLE) # depthwise_conv2d_NCHWc -reg.register_strategy("nn.contrib_depthwise_conv2d_NCHWc", - strategy.depthwise_conv2d_NCHWc_strategy) -reg.register_pattern("nn.contrib_depthwise_conv2d_NCHWc", - OpPattern.OUT_ELEMWISE_FUSABLE) +reg.register_strategy("nn.contrib_depthwise_conv2d_NCHWc", strategy.depthwise_conv2d_NCHWc_strategy) +reg.register_pattern("nn.contrib_depthwise_conv2d_NCHWc", OpPattern.OUT_ELEMWISE_FUSABLE) # deformable_conv2d @@ -576,6 +613,7 @@ def compute_bitpack(attrs, inputs, out_dtype): out = topi.nn.bitpack(inputs[0], bits, pack_axis, bit_axis, pack_type, name) return [out] + reg.register_schedule("nn.bitpack", strategy.schedule_bitpack) reg.register_pattern("nn.bitpack", OpPattern.INJECTIVE) @@ -584,6 +622,7 @@ def compute_bitpack(attrs, inputs, out_dtype): reg.register_strategy("nn.bitserial_conv2d", strategy.bitserial_conv2d_strategy) reg.register_pattern("nn.bitserial_conv2d", OpPattern.OUT_ELEMWISE_FUSABLE) + @reg.register_legalize("nn.bitserial_conv2d") def legalize_bitserial_conv2d(attrs, inputs, types): """Legalize bitserial_conv2d op. 
@@ -616,6 +655,7 @@ def compute_cross_entropy(attrs, inputs, out_dtype): x, y = inputs return [-topi.sum(topi.log(x) * y) / x.shape[0]] + reg.register_reduce_schedule("nn.cross_entropy") reg.register_pattern("nn.cross_entropy", OpPattern.OPAQUE) @@ -625,6 +665,7 @@ def compute_cross_entropy(attrs, inputs, out_dtype): def compute_dilate(attrs, inputs, out_dtype): return [topi.nn.dilate(inputs[0], attrs.strides)] + reg.register_broadcast_schedule("nn.dilate") reg.register_pattern("nn.dilate", OpPattern.INJECTIVE) @@ -635,6 +676,7 @@ def compute_cross_entropy_with_logits(attrs, inputs, out_dtype): x, y = inputs return [-topi.sum(x * y) / x.shape[0]] + reg.register_reduce_schedule("nn.cross_entropy_with_logits") reg.register_pattern("nn.cross_entropy_with_logits", OpPattern.OPAQUE) @@ -647,6 +689,7 @@ def compute_depth_to_space(attrs, inputs, out_dtype): mode = attrs.mode return [topi.nn.depth_to_space(inputs[0], block_size, layout=layout, mode=mode)] + reg.register_injective_schedule("nn.depth_to_space") reg.register_pattern("nn.depth_to_space", OpPattern.INJECTIVE) @@ -658,6 +701,7 @@ def compute_space_to_depth(attrs, inputs, out_dtype): layout = attrs.layout return [topi.nn.space_to_depth(inputs[0], block_size, layout=layout)] + reg.register_injective_schedule("nn.space_to_depth") reg.register_pattern("nn.space_to_depth", OpPattern.INJECTIVE) @@ -671,6 +715,7 @@ def compute_space_to_depth(attrs, inputs, out_dtype): # Shape functions # ##################### + @script def _conv2d_NCHWc_shape_func(dshape, kshape, strides, padding, dilation, oc_bn): out = output_tensor((dshape.shape[0],), "int64") @@ -699,6 +744,7 @@ def _conv2d_NCHWc_shape_func(dshape, kshape, strides, padding, dilation, oc_bn): out[4] = int64(oc_bn) return out + @reg.register_shape_func("nn.contrib_conv2d_NCHWc", False) def conv2d_NCHWc_shape_func(attrs, inputs, _): """ @@ -710,13 +756,20 @@ def conv2d_NCHWc_shape_func(attrs, inputs, _): out_layout = attrs.out_layout oc_bn = int(out_layout[4:-1]) - return [_conv2d_NCHWc_shape_func(inputs[0], inputs[1], - convert(strides), convert(padding), - convert(dilation), convert(oc_bn))] + return [ + _conv2d_NCHWc_shape_func( + inputs[0], + inputs[1], + convert(strides), + convert(padding), + convert(dilation), + convert(oc_bn), + ) + ] + @script -def _pool2d_shape_func(data_shape, pool_size, strides, - padding, height_axis, width_axis): +def _pool2d_shape_func(data_shape, pool_size, strides, padding, height_axis, width_axis): out = output_tensor((data_shape.shape[0],), "int64") for i in const_range(data_shape.shape[0]): if i == height_axis: @@ -728,6 +781,7 @@ def _pool2d_shape_func(data_shape, pool_size, strides, return out + def pool2d_shape_func(attrs, inputs, _): """ Shape function for pool2d op. 
@@ -743,13 +797,22 @@ def pool2d_shape_func(attrs, inputs, _): elif len(padding) == 2: padding = [padding[0], padding[1], padding[0], padding[1]] - return [_pool2d_shape_func(inputs[0], convert(pool_size), - convert(strides), convert(padding), - convert(height_axis), convert(width_axis))] + return [ + _pool2d_shape_func( + inputs[0], + convert(pool_size), + convert(strides), + convert(padding), + convert(height_axis), + convert(width_axis), + ) + ] + reg.register_shape_func("nn.max_pool2d", False, pool2d_shape_func) reg.register_shape_func("nn.avg_pool2d", False, pool2d_shape_func) + @script def _global_pool2d_shape_func(data_shape, height_axis, width_axis): out = output_tensor((data_shape.shape[0],), "int64") @@ -761,6 +824,7 @@ def _global_pool2d_shape_func(data_shape, height_axis, width_axis): return out + def global_pool2d_shape_func(attrs, inputs, _): """ Shape function for global pool2d op. @@ -774,9 +838,11 @@ def global_pool2d_shape_func(attrs, inputs, _): width_axis = i return [_global_pool2d_shape_func(inputs[0], convert(height_axis), convert(width_axis))] + reg.register_shape_func("nn.global_max_pool2d", False, global_pool2d_shape_func) reg.register_shape_func("nn.global_avg_pool2d", False, global_pool2d_shape_func) + @script def _batch_flatten_shape_func(data_shape): out = output_tensor((2,), "int64") @@ -787,6 +853,7 @@ def _batch_flatten_shape_func(data_shape): return out + @reg.register_shape_func("nn.batch_flatten", False) def batch_flatten_shape_func(attrs, inputs, _): """ @@ -794,6 +861,7 @@ def batch_flatten_shape_func(attrs, inputs, _): """ return [_batch_flatten_shape_func(inputs[0])] + @script def _dense_shape_func(data_shape, weight_shape): out = output_tensor((data_shape.shape[0],), "int64") @@ -803,6 +871,7 @@ def _dense_shape_func(data_shape, weight_shape): return out + @reg.register_shape_func("nn.dense", False) def dense_shape_func(attrs, inputs, _): """ @@ -811,6 +880,7 @@ def dense_shape_func(attrs, inputs, _): ret = [_dense_shape_func(inputs[0], inputs[1])] return ret + @script def _batch_matmul_shape_func(data_shape, weight_shape): out = output_tensor((data_shape.shape[0],), "int64") @@ -820,6 +890,7 @@ def _batch_matmul_shape_func(data_shape, weight_shape): return out + @reg.register_shape_func("nn.batch_matmul", False) def batch_matmul_shape_func(attrs, inputs, _): """ @@ -828,6 +899,7 @@ def batch_matmul_shape_func(attrs, inputs, _): ret = [_batch_matmul_shape_func(inputs[0], inputs[1])] return ret + @script def _pad_shape_func(data_shape, pad_width): out = output_tensor((data_shape.shape[0],), "int64") @@ -836,6 +908,7 @@ def _pad_shape_func(data_shape, pad_width): return out + @reg.register_shape_func("nn.pad", False) def pad_shape_func(attrs, inputs, _): """ @@ -846,6 +919,7 @@ def pad_shape_func(attrs, inputs, _): pad_width.append(get_const_tuple(pair)) return [_pad_shape_func(inputs[0], convert(pad_width))] + @script def _dilate_shape_func(data_shape, strides): out = output_tensor((data_shape.shape[0],), "int64") @@ -854,6 +928,7 @@ def _dilate_shape_func(data_shape, strides): return out + @reg.register_shape_func("nn.dilate", False) def dilate_shape_func(attrs, inputs, _): """ @@ -861,6 +936,7 @@ def dilate_shape_func(attrs, inputs, _): """ return [_dilate_shape_func(inputs[0], convert(attrs.strides))] + reg.register_shape_func("nn.bias_add", False, elemwise_shape_func) reg.register_shape_func("nn.softmax", False, elemwise_shape_func) reg.register_shape_func("nn.relu", False, elemwise_shape_func) diff --git a/python/tvm/relay/op/strategy/generic.py 
b/python/tvm/relay/op/strategy/generic.py index 879034684cc9..2eac497d162a 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -24,15 +24,19 @@ from .. import op as _op from ....target import generic_func, override_native_generic_func -logger = logging.getLogger('strategy') +logger = logging.getLogger("strategy") + def wrap_topi_schedule(topi_schedule): """Wrap TOPI schedule which doesn't use attrs""" + def wrapper(attrs, outs, target): with target: return topi_schedule(outs) + return wrapper + def get_conv2d_in_channels(data_shape, data_layout): """Get conv2d input channels""" data_shape = get_const_tuple(data_shape) @@ -45,6 +49,7 @@ def get_conv2d_in_channels(data_shape, data_layout): return data_shape[1] * data_shape[4] raise ValueError("Unknown conv2d data layout {}".format(data_layout)) + def get_conv2d_out_channels(kernel_shape, kernel_layout): """Get conv2d output channels""" kernel_shape = get_const_tuple(kernel_shape) @@ -58,23 +63,27 @@ def get_conv2d_out_channels(kernel_shape, kernel_layout): return kernel_shape[0] * kernel_shape[4] raise ValueError("Unknown conv2d kernel layout {}".format(kernel_layout)) + def is_depthwise_conv2d(data_shape, data_layout, kernel_shape, kernel_layout, groups): ic = get_conv2d_in_channels(data_shape, data_layout) oc = get_conv2d_out_channels(kernel_shape, kernel_layout) return ic == oc == groups + @generic_func def schedule_injective(attrs, outs, target): """Schedule injective ops""" with target: return topi.generic.schedule_injective(outs) + @generic_func def schedule_reduce(attrs, outs, target): """Schedule reduction ops""" with target: return topi.generic.schedule_reduce(outs) + _op._schedule_injective = schedule_injective _op._schedule_reduce = schedule_reduce @@ -85,6 +94,7 @@ def schedule_concatenate(attrs, outs, target): with target: return topi.generic.schedule_injective(outs) + # pool @generic_func def schedule_pool(attrs, outs, target): @@ -92,6 +102,7 @@ def schedule_pool(attrs, outs, target): with target: return topi.generic.schedule_pool(outs, attrs.layout) + # pool_grad @generic_func def schedule_pool_grad(attrs, outs, target): @@ -99,6 +110,7 @@ def schedule_pool_grad(attrs, outs, target): with target: return topi.generic.schedule_pool_grad(outs) + # adaptive pool @generic_func def schedule_adaptive_pool(attrs, outs, target): @@ -106,14 +118,18 @@ def schedule_adaptive_pool(attrs, outs, target): with target: return topi.generic.schedule_adaptive_pool(outs) + # softmax def wrap_compute_softmax(topi_compute): """Wrap softmax topi compute""" + def _compute_softmax(attrs, inputs, out_type): axis = attrs.get_int("axis") return [topi_compute(inputs[0], axis)] + return _compute_softmax + @override_native_generic_func("softmax_strategy") def softmax_strategy(attrs, inputs, out_type, target): """softmax generic strategy""" @@ -121,9 +137,11 @@ def softmax_strategy(attrs, inputs, out_type, target): strategy.add_implementation( wrap_compute_softmax(topi.nn.softmax), wrap_topi_schedule(topi.generic.schedule_softmax), - name="softmax.generic") + name="softmax.generic", + ) return strategy + # log_softmax @generic_func def schedule_log_softmax(attrs, outs, target): @@ -131,6 +149,7 @@ def schedule_log_softmax(attrs, outs, target): with target: return topi.generic.schedule_softmax(outs) + # lrn @generic_func def schedule_lrn(attrs, outs, target): @@ -138,6 +157,7 @@ def schedule_lrn(attrs, outs, target): with target: return topi.generic.schedule_lrn(outs) + # bitpack @generic_func def 
schedule_bitpack(attrs, outs, target): @@ -145,10 +165,13 @@ def schedule_bitpack(attrs, outs, target): with target: return topi.generic.schedule_bitpack(outs) + # conv2d -def wrap_compute_conv2d(topi_compute, need_data_layout=False, need_out_layout=False, - has_groups=False): +def wrap_compute_conv2d( + topi_compute, need_data_layout=False, need_out_layout=False, has_groups=False +): """Wrap conv2d topi compute""" + def _compute_conv2d(attrs, inputs, out_type): padding = get_const_tuple(attrs.padding) strides = get_const_tuple(attrs.strides) @@ -156,8 +179,7 @@ def _compute_conv2d(attrs, inputs, out_type): data_layout = attrs.get_str("data_layout") out_layout = attrs.get_str("out_layout") out_dtype = attrs.out_dtype - out_dtype = (inputs[0].dtype if out_dtype in ("same", "") - else out_dtype) + out_dtype = inputs[0].dtype if out_dtype in ("same", "") else out_dtype args = [inputs[0], inputs[1], strides, padding, dilation] if has_groups: args.append(attrs.groups) @@ -167,8 +189,10 @@ def _compute_conv2d(attrs, inputs, out_type): args.append(out_layout) args.append(out_dtype) return [topi_compute(*args)] + return _compute_conv2d + @override_native_generic_func("conv2d_strategy") def conv2d_strategy(attrs, inputs, out_type, target): """conv2d generic strategy""" @@ -189,19 +213,22 @@ def conv2d_strategy(attrs, inputs, out_type, target): strategy.add_implementation( wrap_compute_conv2d(topi.nn.conv2d_nchw), wrap_topi_schedule(topi.generic.schedule_conv2d_nchw), - name="conv2d_nchw.generic") + name="conv2d_nchw.generic", + ) elif layout == "NHWC": assert kernel_layout == "HWIO" strategy.add_implementation( wrap_compute_conv2d(topi.nn.conv2d_nhwc), wrap_topi_schedule(topi.generic.schedule_conv2d_nhwc), - name="conv2d_nhwc.generic") + name="conv2d_nhwc.generic", + ) elif layout == "HWCN": assert kernel_layout == "HWIO" strategy.add_implementation( wrap_compute_conv2d(topi.nn.conv2d_hwcn), wrap_topi_schedule(topi.generic.schedule_conv2d_hwcn), - name="conv2d_hwcn.generic") + name="conv2d_hwcn.generic", + ) else: raise RuntimeError("Unsupported conv2d layout {}".format(layout)) elif is_depthwise_conv2d(data.shape, layout, kernel.shape, kernel_layout, groups): @@ -210,26 +237,30 @@ def conv2d_strategy(attrs, inputs, out_type, target): strategy.add_implementation( wrap_compute_conv2d(topi.nn.depthwise_conv2d_nchw), wrap_topi_schedule(topi.generic.schedule_depthwise_conv2d_nchw), - name="depthwise_conv2d_nchw.generic") + name="depthwise_conv2d_nchw.generic", + ) elif layout == "NHWC": assert kernel_layout == "HWOI" strategy.add_implementation( wrap_compute_conv2d(topi.nn.depthwise_conv2d_nhwc), wrap_topi_schedule(topi.generic.schedule_depthwise_conv2d_nhwc), - name="depthwise_conv2d_nhwc.generic") + name="depthwise_conv2d_nhwc.generic", + ) else: raise RuntimeError("Unsupported depthwise_conv2d layout {}".format(layout)) - else: # group_conv2d - if layout == 'NCHW': + else: # group_conv2d + if layout == "NCHW": assert kernel_layout == "OIHW" strategy.add_implementation( wrap_compute_conv2d(topi.nn.group_conv2d_nchw, has_groups=True), wrap_topi_schedule(topi.generic.schedule_group_conv2d_nchw), - name="group_conv2d_nchw.generic") + name="group_conv2d_nchw.generic", + ) else: raise RuntimeError("Unsupported group_conv2d layout {}".format(layout)) return strategy + # conv2d_NCHWc @override_native_generic_func("conv2d_NCHWc_strategy") def conv2d_NCHWc_strategy(attrs, inputs, out_type, target): @@ -240,14 +271,17 @@ def conv2d_NCHWc_strategy(attrs, inputs, out_type, target): strategy.add_implementation( 
wrap_compute_conv2d(topi.nn.conv2d_NCHWc_int8, True, True), wrap_topi_schedule(topi.generic.schedule_conv2d_NCHWc_int8), - name="conv2d_NCHWc_int8.generic") + name="conv2d_NCHWc_int8.generic", + ) else: strategy.add_implementation( wrap_compute_conv2d(topi.nn.conv2d_NCHWc, True, True), wrap_topi_schedule(topi.generic.schedule_conv2d_NCHWc), - name="conv2d_NCHWc.generic") + name="conv2d_NCHWc.generic", + ) return strategy + # depthwise_conv2d_NCHWc @override_native_generic_func("depthwise_conv2d_NCHWc_strategy") def depthwise_conv2d_NCHWc_strategy(attrs, inputs, out_type, target): @@ -257,21 +291,25 @@ def depthwise_conv2d_NCHWc_strategy(attrs, inputs, out_type, target): strategy.add_implementation( wrap_compute_conv2d(topi.nn.depthwise_conv2d_NCHWc, True, True), wrap_topi_schedule(topi.generic.schedule_depthwise_conv2d_NCHWc), - name="depthwise_conv2d_NCHWc.generic") + name="depthwise_conv2d_NCHWc.generic", + ) return strategy + # conv2d_winograd_without_weight_transform @override_native_generic_func("conv2d_winograd_without_weight_transform_strategy") def conv2d_winograd_without_weight_transfrom_strategy(attrs, inputs, out_type, target): """conv2d_winograd_without_weight_transfrom generic strategy""" raise ValueError("No generic implemenation for conv2d_winograd_without_weight_transform") + # conv2d_gemm_without_weight_transform @override_native_generic_func("conv2d_gemm_without_weight_transform_strategy") def conv2d_gemm_without_weight_transform_strategy(attrs, inputs, out_type, target): """conv2d_gemm_without_weight_transfrom generic strategy""" raise ValueError("No generic implemenation for conv2d_gemm_without_weight_transform") + # conv2d_winograd_weight_transform @generic_func def schedule_conv2d_winograd_weight_transform(attrs, outs, target): @@ -279,6 +317,7 @@ def schedule_conv2d_winograd_weight_transform(attrs, outs, target): with target: return topi.generic.schedule_conv2d_winograd_weight_transform(outs) + # conv2d_winograd_nnpack_weight_transform @generic_func def schedule_conv2d_winograd_nnpack_weight_transform(attrs, outs, target): @@ -286,6 +325,7 @@ def schedule_conv2d_winograd_nnpack_weight_transform(attrs, outs, target): with target: return topi.generic.schedule_conv2d_winograd_nnpack_weight_transform(outs) + # conv2d_gemm_weight_transform @generic_func def schedule_conv2d_gemm_weight_transform(attrs, outs, target): @@ -293,9 +333,11 @@ def schedule_conv2d_gemm_weight_transform(attrs, outs, target): with target: return topi.generic.schedule_conv2d_gemm_weight_transform(outs) + # deformable_conv2d def wrap_compute_deformable_conv2d(topi_compute): """wrap deformable_conv2d topi compute""" + def _compute_deformable_conv2d(attrs, inputs, out_dtype): assert attrs.data_layout == "NCHW" padding = get_const_tuple(attrs.padding) @@ -305,11 +347,22 @@ def _compute_deformable_conv2d(attrs, inputs, out_dtype): groups = attrs.groups out_dtype = attrs.out_dtype out_dtype = inputs[0].dtype if out_dtype in ("same", "") else out_dtype - out = topi_compute(inputs[0], inputs[1], inputs[2], strides, padding, - dilation, deformable_groups, groups, out_dtype) + out = topi_compute( + inputs[0], + inputs[1], + inputs[2], + strides, + padding, + dilation, + deformable_groups, + groups, + out_dtype, + ) return [out] + return _compute_deformable_conv2d + @override_native_generic_func("deformable_conv2d_strategy") def deformable_conv2d_strategy(attrs, inputs, out_type, target): """deformable_conv2d generic strategy""" @@ -320,25 +373,28 @@ def deformable_conv2d_strategy(attrs, inputs, out_type, 
target): strategy.add_implementation( wrap_compute_deformable_conv2d(topi.nn.deformable_conv2d_nchw), wrap_topi_schedule(topi.generic.schedule_deformable_conv2d_nchw), - name="deformable_conv2d.generic") + name="deformable_conv2d.generic", + ) return strategy + # conv2d_transpose def wrap_compute_conv2d_transpose(topi_compute): """wrap conv2d_transpose topi compute""" + def compute_conv2d_transpose(attrs, inputs, out_dtype): """Compute definition of conv2d_transpose""" padding = get_const_tuple(attrs.padding) strides = get_const_tuple(attrs.strides) out_dtype = attrs.out_dtype - out_dtype = (inputs[0].dtype if out_dtype in ("same", "") - else out_dtype) + out_dtype = inputs[0].dtype if out_dtype in ("same", "") else out_dtype output_padding = get_const_tuple(attrs.output_padding) - out = topi_compute( - inputs[0], inputs[1], strides, padding, out_dtype, output_padding) + out = topi_compute(inputs[0], inputs[1], strides, padding, out_dtype, output_padding) return [out] + return compute_conv2d_transpose + @override_native_generic_func("conv2d_transpose_strategy") def conv2d_transpose_strategy(attrs, inputs, out_type, target): """conv2d_transpose generic strategy""" @@ -353,24 +409,25 @@ def conv2d_transpose_strategy(attrs, inputs, out_type, target): strategy.add_implementation( wrap_compute_conv2d_transpose(topi.nn.conv2d_transpose_nchw), wrap_topi_schedule(topi.generic.schedule_conv2d_transpose_nchw), - name="conv2d_transpose_nchw.generic") + name="conv2d_transpose_nchw.generic", + ) return strategy # conv3d_transpose def wrap_compute_conv3d_transpose(topi_compute): """wrap conv3d_transpose topi compute""" + def compute_conv3d_transpose(attrs, inputs, out_dtype): """Compute definition of conv3d_transpose""" padding = get_const_tuple(attrs.padding) strides = get_const_tuple(attrs.strides) output_padding = get_const_tuple(attrs.output_padding) out_dtype = attrs.out_dtype - out_dtype = (inputs[0].dtype if out_dtype in ("same", "") - else out_dtype) - out = topi_compute( - inputs[0], inputs[1], strides, padding, out_dtype, output_padding) + out_dtype = inputs[0].dtype if out_dtype in ("same", "") else out_dtype + out = topi_compute(inputs[0], inputs[1], strides, padding, out_dtype, output_padding) return [out] + return compute_conv3d_transpose @@ -388,12 +445,15 @@ def conv3d_transpose_strategy(attrs, inputs, out_type, target): strategy.add_implementation( wrap_compute_conv3d_transpose(topi.nn.conv3d_transpose_ncdhw), wrap_topi_schedule(topi.generic.schedule_conv3d_transpose_ncdhw), - name="conv3d_transpose_ncdhw.generic") + name="conv3d_transpose_ncdhw.generic", + ) return strategy + # conv3d def wrap_compute_conv3d(topi_compute, need_layout=False): """wrap conv3d topi compute""" + def _compute_conv3d(attrs, inputs, out_type): padding = get_const_tuple(attrs.padding) strides = get_const_tuple(attrs.strides) @@ -401,8 +461,7 @@ def _compute_conv3d(attrs, inputs, out_type): groups = attrs.groups layout = attrs.data_layout out_dtype = attrs.out_dtype - out_dtype = (inputs[0].dtype if out_dtype in ("same", "") - else out_dtype) + out_dtype = inputs[0].dtype if out_dtype in ("same", "") else out_dtype (dilation_d, dilation_h, dilation_w) = dilation if dilation_d < 1 or dilation_h < 1 or dilation_w < 1: @@ -410,14 +469,14 @@ def _compute_conv3d(attrs, inputs, out_type): if groups != 1: raise ValueError("Not support arbitrary group number for conv3d") if need_layout: - out = topi_compute(inputs[0], inputs[1], strides, padding, dilation, - layout, out_dtype) + out = topi_compute(inputs[0], inputs[1], 
strides, padding, dilation, layout, out_dtype) else: - out = topi_compute(inputs[0], inputs[1], strides, padding, dilation, - out_dtype) + out = topi_compute(inputs[0], inputs[1], strides, padding, dilation, out_dtype) return [out] + return _compute_conv3d + @override_native_generic_func("conv3d_strategy") def conv3d_strategy(attrs, inputs, out_type, target): """conv3d generic strategy""" @@ -428,22 +487,26 @@ def conv3d_strategy(attrs, inputs, out_type, target): strategy.add_implementation( wrap_compute_conv3d(topi.nn.conv3d_ncdhw), wrap_topi_schedule(topi.generic.schedule_conv3d_ncdhw), - name="conv3d_ncdhw.generic") + name="conv3d_ncdhw.generic", + ) elif layout == "NDHWC": strategy.add_implementation( wrap_compute_conv3d(topi.nn.conv3d_ndhwc), wrap_topi_schedule(topi.generic.schedule_conv3d_ndhwc), - name="conv3d_ndhwc.generic") + name="conv3d_ndhwc.generic", + ) else: raise ValueError("Not support this layout {} yet".format(layout)) return strategy + # conv3d_winograd_without_weight_transform @override_native_generic_func("conv3d_winograd_without_weight_transform_strategy") def conv3d_winograd_without_weight_transfrom_strategy(attrs, inputs, out_type, target): """conv3d_winograd_without_weight_transfrom generic strategy""" raise ValueError("No generic implemenation for conv3d_winograd_without_weight_transform") + # conv3d_winograd_weight_transform @generic_func def schedule_conv3d_winograd_weight_transform(attrs, outs, target): @@ -451,21 +514,23 @@ def schedule_conv3d_winograd_weight_transform(attrs, outs, target): with target: return topi.generic.schedule_conv3d_winograd_weight_transform(outs) + # conv1d def wrap_compute_conv1d(topi_compute): """wrap conv1d topi compute""" + def _compute_conv1d(attrs, inputs, out_type): """Compute definition of conv1d""" strides = get_const_tuple(attrs.strides) padding = get_const_tuple(attrs.padding) dilation = get_const_tuple(attrs.dilation) out_dtype = attrs.out_dtype - out_dtype = (inputs[0].dtype if out_dtype in ("same", "") - else out_dtype) - return [topi_compute(inputs[0], inputs[1], strides, padding, dilation, - out_dtype)] + out_dtype = inputs[0].dtype if out_dtype in ("same", "") else out_dtype + return [topi_compute(inputs[0], inputs[1], strides, padding, dilation, out_dtype)] + return _compute_conv1d + @override_native_generic_func("conv1d_strategy") def conv1d_strategy(attrs, inputs, out_type, target): """conv1d generic strategy""" @@ -479,29 +544,35 @@ def conv1d_strategy(attrs, inputs, out_type, target): strategy.add_implementation( wrap_compute_conv1d(topi.nn.conv1d_ncw), wrap_topi_schedule(topi.generic.schedule_conv1d_ncw), - name="conv1d_ncw.generic") + name="conv1d_ncw.generic", + ) elif layout == "NWC": strategy.add_implementation( wrap_compute_conv1d(topi.nn.conv1d_nwc), wrap_topi_schedule(topi.generic.schedule_conv1d_nwc), - name="conv1d_nwc.generic") + name="conv1d_nwc.generic", + ) else: raise ValueError("Unsupported conv1d layout {}".format(layout)) return strategy + # conv1d_transpose def wrap_compute_conv1d_transpose(topi_compute): """wrap conv1d_transpose topi compute""" + def _compute_conv1d_tranpsoe(attrs, inputs, out_type): padding = get_const_tuple(attrs.padding) strides = get_const_tuple(attrs.strides) out_dtype = attrs.out_dtype - out_dtype = (inputs[0].dtype if out_dtype in ("same", "") else out_dtype) + out_dtype = inputs[0].dtype if out_dtype in ("same", "") else out_dtype output_padding = get_const_tuple(attrs.output_padding) out = topi_compute(inputs[0], inputs[1], strides, padding, out_dtype, output_padding) 
return [out] + return _compute_conv1d_tranpsoe + @override_native_generic_func("conv1d_transpose_strategy") def conv1d_transpose_strategy(attrs, inputs, out_type, target): """conv1d_transpose generic strategy""" @@ -513,28 +584,31 @@ def conv1d_transpose_strategy(attrs, inputs, out_type, target): assert layout == "NCW", "conv1d_transpose ncw only supported" assert dilation == (1,), "conv1d_transpose dilation is not supported" assert groups == 1, "conv1d_transpose groups == 1 only supported" - strategy.add_implementation(wrap_compute_conv1d_transpose(topi.nn.conv1d_transpose_ncw), - wrap_topi_schedule(topi.generic.schedule_conv1d_transpose_ncw), - name="conv1d_transpose_ncw.generic") + strategy.add_implementation( + wrap_compute_conv1d_transpose(topi.nn.conv1d_transpose_ncw), + wrap_topi_schedule(topi.generic.schedule_conv1d_transpose_ncw), + name="conv1d_transpose_ncw.generic", + ) return strategy # dilation2d def wrap_compute_dilation2d(topi_compute, need_data_layout=False): """Wrap dilation2d topi compute""" + def _compute_dilation2d(attrs, inputs, out_type): padding = get_const_tuple(attrs.padding) strides = get_const_tuple(attrs.strides) dilations = get_const_tuple(attrs.dilations) data_layout = attrs.get_str("data_layout") out_dtype = attrs.out_dtype - out_dtype = (inputs[0].dtype if out_dtype in ("same", "") - else out_dtype) + out_dtype = inputs[0].dtype if out_dtype in ("same", "") else out_dtype args = [inputs[0], inputs[1], strides, padding, dilations] if need_data_layout: args.append(data_layout) args.append(out_dtype) return [topi_compute(*args)] + return _compute_dilation2d @@ -557,13 +631,15 @@ def dilation2d_strategy(attrs, inputs, out_type, target): strategy.add_implementation( wrap_compute_dilation2d(topi.image.dilation2d_nchw), wrap_topi_schedule(topi.generic.schedule_dilation2d_nchw), - name="dilation2d_nchw.generic") + name="dilation2d_nchw.generic", + ) elif layout == "NHWC": assert kernel_layout == "HWI" strategy.add_implementation( wrap_compute_dilation2d(topi.image.dilation2d_nhwc), wrap_topi_schedule(topi.generic.schedule_dilation2d_nhwc), - name="dilation2d_nhwc.generic") + name="dilation2d_nhwc.generic", + ) else: raise RuntimeError("Unsupported dilation2d layout {}".format(layout)) return strategy @@ -572,57 +648,75 @@ def dilation2d_strategy(attrs, inputs, out_type, target): # dense def wrap_compute_dense(topi_compute): """wrap dense topi compute""" + def _compute_dense(attrs, inputs, out_type): """Compute definition of dense""" out_dtype = attrs.out_dtype out_dtype = inputs[0].dtype if out_dtype == "" else out_dtype return [topi_compute(inputs[0], inputs[1], None, out_dtype)] + return _compute_dense + @override_native_generic_func("dense_strategy") def dense_strategy(attrs, inputs, out_type, target): """dense generic strategy""" logger.warning("dense is not optimized for this platform.") strategy = _op.OpStrategy() - strategy.add_implementation(wrap_compute_dense(topi.nn.dense), - wrap_topi_schedule(topi.generic.schedule_dense), - name="dense.generic") + strategy.add_implementation( + wrap_compute_dense(topi.nn.dense), + wrap_topi_schedule(topi.generic.schedule_dense), + name="dense.generic", + ) return strategy + # batch_matmul def wrap_compute_batch_matmul(topi_compute): """wrap batch_matmul topi compute""" + def _compute_batch_matmul(attrs, inputs, out_type): return [topi_compute(inputs[0], inputs[1], out_type.shape)] + return _compute_batch_matmul + @override_native_generic_func("batch_matmul_strategy") def batch_matmul_strategy(attrs, inputs, out_type, 
target): """batch_matmul generic strategy""" logger.warning("batch_matmul is not optimized for this platform.") strategy = _op.OpStrategy() - strategy.add_implementation(wrap_compute_batch_matmul(topi.nn.batch_matmul), - wrap_topi_schedule(topi.generic.schedule_batch_matmul), - name="batch_matmul.generic") + strategy.add_implementation( + wrap_compute_batch_matmul(topi.nn.batch_matmul), + wrap_topi_schedule(topi.generic.schedule_batch_matmul), + name="batch_matmul.generic", + ) return strategy + # sparse dense def wrap_compute_sparse_dense(topi_compute): """wrap sparse dense topi compute""" + def _compute_sparse_dense(attrs, inputs, out_type): return [topi_compute(inputs[0], inputs[1], inputs[2], inputs[3])] + return _compute_sparse_dense + @override_native_generic_func("sparse_dense_strategy") def sparse_dense_strategy(attrs, inputs, out_type, target): """sparse dense generic strategy""" logger.warning("sparse dense is not optimized for this platform.") strategy = _op.OpStrategy() - strategy.add_implementation(wrap_compute_sparse_dense(topi.nn.sparse_dense), - wrap_topi_schedule(topi.generic.schedule_sparse_dense), - name="sparse_dense.generic") + strategy.add_implementation( + wrap_compute_sparse_dense(topi.nn.sparse_dense), + wrap_topi_schedule(topi.generic.schedule_sparse_dense), + name="sparse_dense.generic", + ) return strategy + # sparse_transpose @generic_func def schedule_sparse_transpose(attrs, outs, target): @@ -630,28 +724,36 @@ def schedule_sparse_transpose(attrs, outs, target): with target: return topi.generic.schedule_sparse_transpose(outs) + # argsort def wrap_compute_argsort(topi_compute): """Wrap argsort topi compute""" + def _compute_argsort(attrs, inputs, _): axis = get_const_int(attrs.axis) is_ascend = bool(get_const_int(attrs.is_ascend)) dtype = attrs.dtype return [topi_compute(inputs[0], axis=axis, is_ascend=is_ascend, dtype=dtype)] + return _compute_argsort + @override_native_generic_func("argsort_strategy") def argsort_strategy(attrs, inputs, out_type, target): """argsort generic strategy""" strategy = _op.OpStrategy() - strategy.add_implementation(wrap_compute_argsort(topi.argsort), - wrap_topi_schedule(topi.generic.schedule_argsort), - name="argsort.generic") + strategy.add_implementation( + wrap_compute_argsort(topi.argsort), + wrap_topi_schedule(topi.generic.schedule_argsort), + name="argsort.generic", + ) return strategy + # topk def wrap_compute_topk(topi_compute): """Wrap topk compute""" + def _compute_topk(attrs, inputs, out_type): if attrs.k is not None: k = attrs.k @@ -664,20 +766,26 @@ def _compute_topk(attrs, inputs, out_type): out = topi_compute(inputs[0], k, axis, ret_type, is_ascend, dtype) out = out if isinstance(out, list) else [out] return out + return _compute_topk + @override_native_generic_func("topk_strategy") def topk_strategy(attrs, inputs, out_type, target): """topk generic strategy""" strategy = _op.OpStrategy() - strategy.add_implementation(wrap_compute_topk(topi.topk), - wrap_topi_schedule(topi.generic.schedule_topk), - name="topk.generic") + strategy.add_implementation( + wrap_compute_topk(topi.topk), + wrap_topi_schedule(topi.generic.schedule_topk), + name="topk.generic", + ) return strategy + # multibox_prior def wrap_compute_multibox_prior(topi_compute): """Wrap multibox_prior compute""" + def _compute_multibox_prior(attrs, inputs, _): """Compute definition of multibox_prior""" sizes = get_float_tuple(attrs.sizes) @@ -686,29 +794,36 @@ def _compute_multibox_prior(attrs, inputs, _): offsets = get_float_tuple(attrs.offsets) clip = 
bool(get_const_int(attrs.clip)) return [topi_compute(inputs[0], sizes, ratios, steps, offsets, clip)] + return _compute_multibox_prior + @override_native_generic_func("multibox_prior_strategy") def multibox_prior_strategy(attrs, inputs, out_type, target): """multibox_prior generic strategy""" strategy = _op.OpStrategy() - strategy.add_implementation(wrap_compute_multibox_prior(topi.vision.ssd.multibox_prior), - wrap_topi_schedule(topi.generic.schedule_multibox_prior), - name="multibox_prior.generic") + strategy.add_implementation( + wrap_compute_multibox_prior(topi.vision.ssd.multibox_prior), + wrap_topi_schedule(topi.generic.schedule_multibox_prior), + name="multibox_prior.generic", + ) return strategy + # multibox_transform_loc def wrap_compute_multibox_transform_loc(topi_compute): """Wrap multibox_transform_loc compute""" + def _compute_multibox_transform_loc(attrs, inputs, _): """Compute definition of multibox_detection""" clip = bool(get_const_int(attrs.clip)) threshold = get_const_float(attrs.threshold) variances = get_float_tuple(attrs.variances) - return topi_compute( - inputs[0], inputs[1], inputs[2], clip, threshold, variances) + return topi_compute(inputs[0], inputs[1], inputs[2], clip, threshold, variances) + return _compute_multibox_transform_loc + @override_native_generic_func("multibox_transform_loc_strategy") def multibox_transform_loc_strategy(attrs, inputs, out_type, target): """schedule multibox_transform_loc""" @@ -716,31 +831,40 @@ def multibox_transform_loc_strategy(attrs, inputs, out_type, target): strategy.add_implementation( wrap_compute_multibox_transform_loc(topi.vision.ssd.multibox_transform_loc), wrap_topi_schedule(topi.generic.schedule_multibox_transform_loc), - name="multibox_transform_loc.generic") + name="multibox_transform_loc.generic", + ) return strategy + # get_valid_counts def wrap_compute_get_valid_counts(topi_compute): """wrap get_valid_counts topi compute""" + def _compute_get_valid_counts(attrs, inputs, out_type): score_threshold = get_const_float(attrs.score_threshold) id_index = get_const_int(attrs.id_index) score_index = get_const_int(attrs.score_index) return topi_compute(inputs[0], score_threshold, id_index, score_index) + return _compute_get_valid_counts + @override_native_generic_func("get_valid_counts_strategy") def get_valid_counts_strategy(attrs, inputs, out_type, target): """get_valid_counts generic strategy""" strategy = _op.OpStrategy() - strategy.add_implementation(wrap_compute_get_valid_counts(topi.vision.get_valid_counts), - wrap_topi_schedule(topi.generic.schedule_get_valid_counts), - name="get_valid_counts.generic") + strategy.add_implementation( + wrap_compute_get_valid_counts(topi.vision.get_valid_counts), + wrap_topi_schedule(topi.generic.schedule_get_valid_counts), + name="get_valid_counts.generic", + ) return strategy + # non-maximum suppression def wrap_compute_nms(topi_compute): """wrap nms topi compute""" + def _compute_nms(attrs, inputs, out_type): max_output_size = inputs[3] if attrs.max_output_size is not None: @@ -754,44 +878,84 @@ def _compute_nms(attrs, inputs, out_type): id_index = get_const_int(attrs.id_index) invalid_to_bottom = bool(get_const_int(attrs.invalid_to_bottom)) if return_indices: - return topi_compute(inputs[0], inputs[1], inputs[2], max_output_size, iou_threshold, - force_suppress, top_k, coord_start, score_index, id_index, - return_indices, invalid_to_bottom) - return [topi_compute(inputs[0], inputs[1], inputs[2], max_output_size, iou_threshold, - force_suppress, top_k, coord_start, score_index, 
id_index, - return_indices, invalid_to_bottom)] + return topi_compute( + inputs[0], + inputs[1], + inputs[2], + max_output_size, + iou_threshold, + force_suppress, + top_k, + coord_start, + score_index, + id_index, + return_indices, + invalid_to_bottom, + ) + return [ + topi_compute( + inputs[0], + inputs[1], + inputs[2], + max_output_size, + iou_threshold, + force_suppress, + top_k, + coord_start, + score_index, + id_index, + return_indices, + invalid_to_bottom, + ) + ] + return _compute_nms + @override_native_generic_func("non_max_suppression_strategy") def nms_strategy(attrs, inputs, out_type, target): """nms generic strategy""" strategy = _op.OpStrategy() - strategy.add_implementation(wrap_compute_nms(topi.vision.non_max_suppression), - wrap_topi_schedule(topi.generic.schedule_nms), - name="nms.generic") + strategy.add_implementation( + wrap_compute_nms(topi.vision.non_max_suppression), + wrap_topi_schedule(topi.generic.schedule_nms), + name="nms.generic", + ) return strategy + # roi_align def wrap_compute_roi_align(topi_compute): """wrap roi_align topi compute""" + def _compute_roi_align(attrs, inputs, out_type): assert attrs.layout == "NCHW" pooled_size = get_const_tuple(attrs.pooled_size) - return [topi_compute(inputs[0], inputs[1], - pooled_size=pooled_size, - spatial_scale=attrs.spatial_scale, - sample_ratio=attrs.sample_ratio)] + return [ + topi_compute( + inputs[0], + inputs[1], + pooled_size=pooled_size, + spatial_scale=attrs.spatial_scale, + sample_ratio=attrs.sample_ratio, + ) + ] + return _compute_roi_align + @override_native_generic_func("roi_align_strategy") def roi_align_strategy(attrs, inputs, out_type, target): """roi_align generic strategy""" strategy = _op.OpStrategy() - strategy.add_implementation(wrap_compute_roi_align(topi.vision.rcnn.roi_align_nchw), - wrap_topi_schedule(topi.generic.schedule_roi_align), - name="roi_align.generic") + strategy.add_implementation( + wrap_compute_roi_align(topi.vision.rcnn.roi_align_nchw), + wrap_topi_schedule(topi.generic.schedule_roi_align), + name="roi_align.generic", + ) return strategy + # roi_pool @generic_func def schedule_roi_pool(attrs, outs, target): @@ -799,9 +963,11 @@ def schedule_roi_pool(attrs, outs, target): with target: return topi.generic.schedule_roi_pool(outs) + # proposal def wrap_compute_proposal(topi_compute): """wrap proposal topi compute""" + def _compute_proposal(attrs, inputs, out_type): scales = get_float_tuple(attrs.scales) ratios = get_float_tuple(attrs.ratios) @@ -811,20 +977,37 @@ def _compute_proposal(attrs, inputs, out_type): rpn_post_nms_top_n = attrs.rpn_post_nms_top_n rpn_min_size = attrs.rpn_min_size iou_loss = bool(get_const_int(attrs.iou_loss)) - return [topi_compute(inputs[0], inputs[1], inputs[2], scales, ratios, - feature_stride, threshold, rpn_pre_nms_top_n, - rpn_post_nms_top_n, rpn_min_size, iou_loss)] + return [ + topi_compute( + inputs[0], + inputs[1], + inputs[2], + scales, + ratios, + feature_stride, + threshold, + rpn_pre_nms_top_n, + rpn_post_nms_top_n, + rpn_min_size, + iou_loss, + ) + ] + return _compute_proposal + @override_native_generic_func("proposal_strategy") def proposal_strategy(attrs, inputs, out_type, target): """proposal generic strategy""" strategy = _op.OpStrategy() - strategy.add_implementation(wrap_compute_proposal(topi.vision.rcnn.proposal), - wrap_topi_schedule(topi.generic.schedule_proposal), - name="proposal.generic") + strategy.add_implementation( + wrap_compute_proposal(topi.vision.rcnn.proposal), + wrap_topi_schedule(topi.generic.schedule_proposal), + 
name="proposal.generic", + ) return strategy + # argwhere @generic_func def schedule_argwhere(attrs, outs, target): @@ -832,6 +1015,7 @@ def schedule_argwhere(attrs, outs, target): with target: return topi.generic.schedule_argwhere(outs) + # scatter @generic_func def schedule_scatter(attrs, outs, target): @@ -839,6 +1023,7 @@ def schedule_scatter(attrs, outs, target): with target: return topi.generic.schedule_scatter(outs) + # scatter_add @generic_func def schedule_scatter_add(attrs, outs, target): @@ -846,9 +1031,11 @@ def schedule_scatter_add(attrs, outs, target): with target: return topi.generic.schedule_scatter_add(outs) + # bitserial_conv2d def wrap_compute_bitserial_conv2d(topi_compute): """wrap bitserial_conv2d topi compute""" + def compute_bitserial_conv2d(attrs, inputs, out_dtype): """Compute definition for bitserial conv2d.""" padding = get_const_tuple(attrs.padding) @@ -858,10 +1045,23 @@ def compute_bitserial_conv2d(attrs, inputs, out_dtype): pack_dtype = attrs.pack_dtype out_dtype = attrs.out_dtype unipolar = attrs.unipolar - return [topi_compute(inputs[0], inputs[1], strides, padding, activation_bits, - weight_bits, pack_dtype, out_dtype, unipolar)] + return [ + topi_compute( + inputs[0], + inputs[1], + strides, + padding, + activation_bits, + weight_bits, + pack_dtype, + out_dtype, + unipolar, + ) + ] + return compute_bitserial_conv2d + @override_native_generic_func("bitserial_conv2d_strategy") def bitserial_conv2d_strategy(attrs, inputs, out_type, target): """bitserial_conv2d generic strategy""" @@ -872,19 +1072,23 @@ def bitserial_conv2d_strategy(attrs, inputs, out_type, target): strategy.add_implementation( wrap_compute_bitserial_conv2d(topi.nn.bitserial_conv2d_nchw), wrap_topi_schedule(topi.generic.schedule_bitserial_conv2d_nchw), - name="bitserial_conv2d_nchw.generic") + name="bitserial_conv2d_nchw.generic", + ) elif layout == "NHWC": strategy.add_implementation( wrap_compute_bitserial_conv2d(topi.nn.bitserial_conv2d_nhwc), wrap_topi_schedule(topi.generic.schedule_bitserial_conv2d_nhwc), - name="bitserial_conv2d_nhwc.generic") + name="bitserial_conv2d_nhwc.generic", + ) else: raise ValueError("Data layout {} not supported.".format(layout)) return strategy + # bitserial_dense def wrap_compute_bitserial_dense(topi_compute): """wrap bitserial_dense topi compute""" + def compute_bitserial_dense(attrs, inputs, out_type): """Compute definition of bitserial dense""" data_bits = attrs.data_bits @@ -893,10 +1097,15 @@ def compute_bitserial_dense(attrs, inputs, out_type): out_dtype = attrs.out_dtype out_dtype = inputs[0].dtype if out_dtype == "" else out_dtype unipolar = attrs.unipolar - return [topi_compute(inputs[0], inputs[1], data_bits, weight_bits, - pack_dtype, out_dtype, unipolar)] + return [ + topi_compute( + inputs[0], inputs[1], data_bits, weight_bits, pack_dtype, out_dtype, unipolar + ) + ] + return compute_bitserial_dense + @override_native_generic_func("bitserial_dense_strategy") def bitserial_dense_strategy(attrs, inputs, out_type, target): """bitserial_dense generic strategy""" @@ -905,12 +1114,15 @@ def bitserial_dense_strategy(attrs, inputs, out_type, target): strategy.add_implementation( wrap_compute_bitserial_dense(topi.nn.bitserial_dense), wrap_topi_schedule(topi.generic.schedule_bitserial_dense), - name="bitserial_dense.generic") + name="bitserial_dense.generic", + ) return strategy + # correlation def wrap_compute_correlation(topi_compute): """wrap correlation topi compute""" + def _compute_correlation(attrs, inputs, out_type): kernel_size = 
attrs.kernel_size max_displacement = attrs.max_displacement @@ -918,10 +1130,22 @@ def _compute_correlation(attrs, inputs, out_type): stride2 = attrs.stride2 padding = get_const_tuple(attrs.padding) is_multiply = attrs.is_multiply - return [topi_compute(inputs[0], inputs[1], kernel_size, max_displacement, stride1, stride2, - padding, is_multiply)] + return [ + topi_compute( + inputs[0], + inputs[1], + kernel_size, + max_displacement, + stride1, + stride2, + padding, + is_multiply, + ) + ] + return _compute_correlation + @override_native_generic_func("correlation_strategy") def correlation_strategy(attrs, inputs, out_type, target): """correlation generic strategy""" @@ -932,5 +1156,6 @@ def correlation_strategy(attrs, inputs, out_type, target): strategy.add_implementation( wrap_compute_correlation(topi.nn.correlation_nchw), wrap_topi_schedule(topi.generic.schedule_correlation_nchw), - name="correlation.generic") + name="correlation.generic", + ) return strategy diff --git a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py index 7a10bff403f6..38c2c4dc1aea 100644 --- a/python/tvm/relay/op/strategy/x86.py +++ b/python/tvm/relay/op/strategy/x86.py @@ -25,41 +25,47 @@ from .generic import * from .. import op as _op -logger = logging.getLogger('strategy') +logger = logging.getLogger("strategy") _NCHWc_matcher = re.compile("^NCHW[0-9]+c$") _OIHWio_matcher = re.compile("^OIHW[0-9]+i[0-9]+o$") + @schedule_injective.register("cpu") def schedule_injective_cpu(attrs, outs, target): """schedule injective ops for x86""" with target: return topi.x86.schedule_injective(outs) + @schedule_reduce.register("cpu") def schedule_reduce_cpu(attrs, outs, target): """schedule reduction ops for x86""" with target: return topi.x86.schedule_reduce(outs) + @schedule_concatenate.register("cpu") def schedule_concatenate_cpu(attrs, outs, target): """schedule concatenate op for x86""" with target: return topi.x86.schedule_concatenate(outs) + @schedule_pool.register("cpu") def schedule_pool_cpu(attrs, outs, target): """schedule pooling ops for x86""" with target: return topi.x86.schedule_pool(outs, attrs.layout) + @schedule_adaptive_pool.register("cpu") def schedule_adaptive_pool_cpu(attrs, outs, target): """schedule adaptive pooling ops for x86""" with target: return topi.x86.schedule_adaptive_pool(outs) + @softmax_strategy.register("cpu") def softmax_strategy_cpu(attrs, inputs, out_type, target): """softmax x86 strategy""" @@ -67,15 +73,18 @@ def softmax_strategy_cpu(attrs, inputs, out_type, target): strategy.add_implementation( wrap_compute_softmax(topi.nn.softmax), wrap_topi_schedule(topi.x86.schedule_softmax), - name="softmax.x86") + name="softmax.x86", + ) return strategy + @schedule_log_softmax.register("cpu") def schedule_log_softmax_cpu(attrs, outs, target): """schedule log_softmax op for x86""" with target: return topi.x86.schedule_softmax(outs) + @conv2d_strategy.register("cpu") def conv2d_strategy_cpu(attrs, inputs, out_type, target): """conv2d x86 strategy""" @@ -95,14 +104,16 @@ def conv2d_strategy_cpu(attrs, inputs, out_type, target): strategy.add_implementation( wrap_compute_conv2d(topi.x86.conv2d_nchw_int8), wrap_topi_schedule(topi.x86.schedule_conv2d_nchw_int8), - name="conv2d_nchw_int8.x86") + name="conv2d_nchw_int8.x86", + ) else: strategy.add_implementation( wrap_compute_conv2d(topi.x86.conv2d_nchw), wrap_topi_schedule(topi.x86.schedule_conv2d_nchw), - name="conv2d_nchw.x86") - elif _NCHWc_matcher.match(layout): # check if layout is NCHWxc - assert 
_OIHWio_matcher.match(kernel_layout) # check if kernel is OIHWio + name="conv2d_nchw.x86", + ) + elif _NCHWc_matcher.match(layout): # check if layout is NCHWxc + assert _OIHWio_matcher.match(kernel_layout) # check if kernel is OIHWio return conv2d_NCHWc_strategy_cpu(attrs, inputs, out_type, target) elif layout == "NHWC": assert kernel_layout == "HWIO" @@ -110,14 +121,16 @@ def conv2d_strategy_cpu(attrs, inputs, out_type, target): strategy.add_implementation( wrap_compute_conv2d(topi.nn.conv2d_nhwc), wrap_topi_schedule(topi.x86.schedule_conv2d_nhwc), - name="conv2d_nhwc.x86") + name="conv2d_nhwc.x86", + ) elif layout == "HWCN": assert kernel_layout == "HWIO" logger.warning("conv2d HWCN layout is not optimized for x86.") strategy.add_implementation( wrap_compute_conv2d(topi.nn.conv2d_hwcn), wrap_topi_schedule(topi.generic.schedule_conv2d_hwcn), - name="conv2d_hwcn.generic") + name="conv2d_hwcn.generic", + ) else: raise RuntimeError("Unsupported conv2d layout {} for x86".format(layout)) elif is_depthwise_conv2d(data.shape, layout, kernel.shape, kernel_layout, groups): @@ -128,16 +141,20 @@ def conv2d_strategy_cpu(attrs, inputs, out_type, target): strategy.add_implementation( wrap_compute_conv2d(topi.x86.depthwise_conv2d_nchw), wrap_topi_schedule(topi.x86.schedule_depthwise_conv2d_nchw), - name="depthwise_conv2d_nchw.x86") + name="depthwise_conv2d_nchw.x86", + ) else: - logger.warning("For x86 target, depthwise_conv2d with channel " - "multiplier greater than 1 is not optimized") + logger.warning( + "For x86 target, depthwise_conv2d with channel " + "multiplier greater than 1 is not optimized" + ) strategy.add_implementation( wrap_compute_conv2d(topi.nn.depthwise_conv2d_nchw), wrap_topi_schedule(topi.generic.schedule_depthwise_conv2d_nchw), - name="depthwise_conv2d_nchw.generic") - elif _NCHWc_matcher.match(layout): # check if layout is NCHWxc - assert _OIHWio_matcher.match(kernel_layout) # check if kernel is OIHWio + name="depthwise_conv2d_nchw.generic", + ) + elif _NCHWc_matcher.match(layout): # check if layout is NCHWxc + assert _OIHWio_matcher.match(kernel_layout) # check if kernel is OIHWio return depthwise_conv2d_NCHWc_strategy_cpu(attrs, inputs, out_type, target) elif layout == "NHWC": assert kernel_layout == "HWOI" @@ -145,21 +162,24 @@ def conv2d_strategy_cpu(attrs, inputs, out_type, target): strategy.add_implementation( wrap_compute_conv2d(topi.nn.depthwise_conv2d_nhwc), wrap_topi_schedule(topi.generic.schedule_depthwise_conv2d_nhwc), - name="depthwise_conv2d_nhwc.generic") + name="depthwise_conv2d_nhwc.generic", + ) else: raise RuntimeError("Unsupported depthwise_conv2d layout {}".format(layout)) - else: # group_conv2d - if layout == 'NCHW': + else: # group_conv2d + if layout == "NCHW": assert kernel_layout == "OIHW" logger.warning("group_conv2d is not optimized for x86.") strategy.add_implementation( wrap_compute_conv2d(topi.nn.group_conv2d_nchw, has_groups=True), wrap_topi_schedule(topi.generic.schedule_group_conv2d_nchw), - name="group_conv2d_nchw.generic") + name="group_conv2d_nchw.generic", + ) else: raise RuntimeError("Unsupported group_conv2d layout {}".format(layout)) return strategy + @conv2d_NCHWc_strategy.register("cpu") def conv2d_NCHWc_strategy_cpu(attrs, inputs, out_type, target): """conv2d_NCHWc x86 strategy""" @@ -169,14 +189,17 @@ def conv2d_NCHWc_strategy_cpu(attrs, inputs, out_type, target): strategy.add_implementation( wrap_compute_conv2d(topi.x86.conv2d_NCHWc_int8, True, True), wrap_topi_schedule(topi.x86.schedule_conv2d_NCHWc_int8), - 
name="conv2d_NCHWc_int8.x86") + name="conv2d_NCHWc_int8.x86", + ) else: strategy.add_implementation( wrap_compute_conv2d(topi.x86.conv2d_NCHWc, True, True), wrap_topi_schedule(topi.x86.schedule_conv2d_NCHWc), - name="conv2d_NCHWc.x86") + name="conv2d_NCHWc.x86", + ) return strategy + @depthwise_conv2d_NCHWc_strategy.register("cpu") def depthwise_conv2d_NCHWc_strategy_cpu(attrs, inputs, out_type, target): """depthwise_conv2d x86 strategy""" @@ -184,9 +207,11 @@ def depthwise_conv2d_NCHWc_strategy_cpu(attrs, inputs, out_type, target): strategy.add_implementation( wrap_compute_conv2d(topi.x86.depthwise_conv2d_NCHWc, True, True), wrap_topi_schedule(topi.x86.schedule_depthwise_conv2d_NCHWc), - name="depthwise_conv2d_NCHWc.x86") + name="depthwise_conv2d_NCHWc.x86", + ) return strategy + @conv2d_transpose_strategy.register("cpu") def conv2d_transpose_strategy_cpu(attrs, inputs, out_type, target): """conv2d_transpose x86 strategy""" @@ -200,7 +225,8 @@ def conv2d_transpose_strategy_cpu(attrs, inputs, out_type, target): strategy.add_implementation( wrap_compute_conv2d_transpose(topi.x86.conv2d_transpose_nchw), wrap_topi_schedule(topi.x86.schedule_conv2d_transpose_nchw), - name="conv2d_transpose_nchw.x86") + name="conv2d_transpose_nchw.x86", + ) return strategy @@ -217,7 +243,8 @@ def conv3d_transpose_strategy_cpu(attrs, inputs, out_type, target): strategy.add_implementation( wrap_compute_conv3d_transpose(topi.x86.conv3d_transpose_ncdhw), wrap_topi_schedule(topi.x86.schedule_conv3d_transpose_ncdhw), - name="conv3d_transpose_ncdhw.x86") + name="conv3d_transpose_ncdhw.x86", + ) return strategy @@ -227,17 +254,22 @@ def conv3d_strategy_cpu(attrs, inputs, out_type, target): strategy = _op.OpStrategy() layout = attrs.data_layout if layout == "NCDHW": - strategy.add_implementation(wrap_compute_conv3d(topi.x86.conv3d_ncdhw), - wrap_topi_schedule(topi.x86.schedule_conv3d_ncdhw), - name="conv3d_ncdhw.x86") + strategy.add_implementation( + wrap_compute_conv3d(topi.x86.conv3d_ncdhw), + wrap_topi_schedule(topi.x86.schedule_conv3d_ncdhw), + name="conv3d_ncdhw.x86", + ) elif layout == "NDHWC": - strategy.add_implementation(wrap_compute_conv3d(topi.x86.conv3d_ndhwc), - wrap_topi_schedule(topi.x86.schedule_conv3d_ndhwc), - name="conv3d_ndhwc.x86") + strategy.add_implementation( + wrap_compute_conv3d(topi.x86.conv3d_ndhwc), + wrap_topi_schedule(topi.x86.schedule_conv3d_ndhwc), + name="conv3d_ndhwc.x86", + ) else: raise ValueError("Not support this layout {} yet".format(layout)) return strategy + @conv1d_strategy.register("cpu") def conv1d_strategy_cpu(attrs, inputs, out_type, target): """conv1d x86 strategy""" @@ -247,17 +279,22 @@ def conv1d_strategy_cpu(attrs, inputs, out_type, target): raise ValueError("dilation should be a positive value") strategy = _op.OpStrategy() if layout == "NCW": - strategy.add_implementation(wrap_compute_conv1d(topi.nn.conv1d_ncw), - wrap_topi_schedule(topi.x86.schedule_conv1d_ncw), - name="conv1d_ncw.x86") + strategy.add_implementation( + wrap_compute_conv1d(topi.nn.conv1d_ncw), + wrap_topi_schedule(topi.x86.schedule_conv1d_ncw), + name="conv1d_ncw.x86", + ) elif layout == "NWC": - strategy.add_implementation(wrap_compute_conv1d(topi.nn.conv1d_nwc), - wrap_topi_schedule(topi.x86.schedule_conv1d_nwc), - name="conv1d_nwc.x86") + strategy.add_implementation( + wrap_compute_conv1d(topi.nn.conv1d_nwc), + wrap_topi_schedule(topi.x86.schedule_conv1d_nwc), + name="conv1d_nwc.x86", + ) else: raise ValueError("Unsupported conv1d layout {}".format(layout)) return strategy + 
@dense_strategy.register("cpu") def dense_strategy_cpu(attrs, inputs, out_type, target): """dense x86 strategy""" @@ -266,10 +303,12 @@ def dense_strategy_cpu(attrs, inputs, out_type, target): same_type = inputs[0].dtype == inputs[1].dtype == out_type.dtype dtype = inputs[0].dtype u8s8s32 = dtype == "uint8" and inputs[1].dtype == "int8" and out_type.dtype == "int32" - strategy.add_implementation(wrap_compute_dense(topi.x86.dense_nopack), - wrap_topi_schedule(topi.x86.schedule_dense_nopack), - name="dense_nopack.x86", - plevel=10) + strategy.add_implementation( + wrap_compute_dense(topi.x86.dense_nopack), + wrap_topi_schedule(topi.x86.schedule_dense_nopack), + name="dense_nopack.x86", + plevel=10, + ) if "cblas" in target.libs: with SpecializedCondition(same_type and dtype in ["float32", "float64"]): strategy.add_implementation( @@ -296,41 +335,53 @@ def dense_strategy_cpu(attrs, inputs, out_type, target): ) with SpecializedCondition(m >= 16): # this implementation may not be well-optimized, so use plevel=8 for now. - strategy.add_implementation(wrap_compute_dense(topi.x86.dense_pack), - wrap_topi_schedule(topi.x86.schedule_dense_pack), - name="dense_pack.x86", - plevel=5) + strategy.add_implementation( + wrap_compute_dense(topi.x86.dense_pack), + wrap_topi_schedule(topi.x86.schedule_dense_pack), + name="dense_pack.x86", + plevel=5, + ) return strategy + @batch_matmul_strategy.register("cpu") def batch_matmul_strategy_cpu(attrs, inputs, out_type, target): """batch_matmul x86 strategy""" strategy = _op.OpStrategy() if is_dynamic(out_type): - strategy.add_implementation(wrap_compute_batch_matmul(topi.nn.batch_matmul), - wrap_topi_schedule(topi.generic.nn.schedule_batch_matmul), - name="batch_matmul.generic", - plevel=10) + strategy.add_implementation( + wrap_compute_batch_matmul(topi.nn.batch_matmul), + wrap_topi_schedule(topi.generic.nn.schedule_batch_matmul), + name="batch_matmul.generic", + plevel=10, + ) else: - strategy.add_implementation(wrap_compute_batch_matmul(topi.x86.batch_matmul), - wrap_topi_schedule(topi.x86.schedule_batch_matmul), - name="batch_matmul.x86", - plevel=10) + strategy.add_implementation( + wrap_compute_batch_matmul(topi.x86.batch_matmul), + wrap_topi_schedule(topi.x86.schedule_batch_matmul), + name="batch_matmul.x86", + plevel=10, + ) if "cblas" in target.libs: - strategy.add_implementation(wrap_compute_batch_matmul(topi.x86.batch_matmul_cblas), - wrap_topi_schedule(topi.x86.schedule_batch_matmul_cblas), - name="batch_matmul_cblas.x86", - plevel=15) + strategy.add_implementation( + wrap_compute_batch_matmul(topi.x86.batch_matmul_cblas), + wrap_topi_schedule(topi.x86.schedule_batch_matmul_cblas), + name="batch_matmul_cblas.x86", + plevel=15, + ) return strategy + @sparse_dense_strategy.register("cpu") def sparse_dense_strategy_cpu(attrs, inputs, out_type, target): """sparse dense x86 strategy""" strategy = _op.OpStrategy() - strategy.add_implementation(wrap_compute_sparse_dense(topi.nn.sparse_dense), - wrap_topi_schedule(topi.x86.schedule_sparse_dense), - name="sparse_dense.x86", - plevel=10) + strategy.add_implementation( + wrap_compute_sparse_dense(topi.nn.sparse_dense), + wrap_topi_schedule(topi.x86.schedule_sparse_dense), + name="sparse_dense.x86", + plevel=10, + ) return strategy @@ -338,11 +389,14 @@ def sparse_dense_strategy_cpu(attrs, inputs, out_type, target): def roi_align_strategy_cpu(attrs, inputs, out_type, target): """roi_align x86 strategy""" strategy = _op.OpStrategy() - strategy.add_implementation(wrap_compute_roi_align(topi.x86.roi_align_nchw), - 
wrap_topi_schedule(topi.generic.schedule_roi_align), - name="roi_align.x86") + strategy.add_implementation( + wrap_compute_roi_align(topi.x86.roi_align_nchw), + wrap_topi_schedule(topi.generic.schedule_roi_align), + name="roi_align.x86", + ) return strategy + @bitserial_conv2d_strategy.register("cpu") def bitserial_conv2d_strategy_cpu(attrs, inputs, out_type, target): """bitserial_conv2d x86 strategy""" @@ -352,16 +406,19 @@ def bitserial_conv2d_strategy_cpu(attrs, inputs, out_type, target): strategy.add_implementation( wrap_compute_bitserial_conv2d(topi.x86.bitserial_conv2d_nchw), wrap_topi_schedule(topi.x86.schedule_bitserial_conv2d_nchw), - name="bitserial_conv2d_nchw.x86") + name="bitserial_conv2d_nchw.x86", + ) elif layout == "NHWC": strategy.add_implementation( wrap_compute_bitserial_conv2d(topi.x86.bitserial_conv2d_nhwc), wrap_topi_schedule(topi.x86.schedule_bitserial_conv2d_nhwc), - name="bitserial_conv2d_nhwc.x86") + name="bitserial_conv2d_nhwc.x86", + ) else: raise ValueError("Data layout {} not supported.".format(layout)) return strategy + @bitserial_dense_strategy.register("cpu") def bitserial_dense_strategy_cpu(attrs, inputs, out_type, target): """bitserial_dense x86 strategy""" @@ -369,5 +426,6 @@ def bitserial_dense_strategy_cpu(attrs, inputs, out_type, target): strategy.add_implementation( wrap_compute_bitserial_dense(topi.x86.bitserial_dense), wrap_topi_schedule(topi.x86.schedule_bitserial_dense), - name="bitserial_dense.x86") + name="bitserial_dense.x86", + ) return strategy diff --git a/python/tvm/topi/cuda/batch_matmul.py b/python/tvm/topi/cuda/batch_matmul.py index 1f6a9a293096..bb060b3ad8a7 100644 --- a/python/tvm/topi/cuda/batch_matmul.py +++ b/python/tvm/topi/cuda/batch_matmul.py @@ -24,6 +24,7 @@ from .. import nn from ..util import traverse_inline, get_const_tuple, get_max_power2_factor + @autotvm.register_topi_compute("batch_matmul.cuda") def batch_matmul(cfg, x, y, out_shape=None): """Compute conv2d with NCHW layout""" @@ -62,14 +63,14 @@ def _schedule(cfg, op): C = s.outputs[0].output(0) b, y, x = s[C].op.axis - k, = s[CC].op.reduce_axis + (k,) = s[CC].op.reduce_axis cfg.define_split("tile_y", y, num_outputs=3) cfg.define_split("tile_x", x, num_outputs=3) cfg.define_split("tile_k", k, num_outputs=2) cfg.define_knob("auto_unroll_max_step", [8, 16, 32, 64]) target = tvm.target.Target.current() - if target.kind.name in ['nvptx', 'rocm']: + if target.kind.name in ["nvptx", "rocm"]: # llvm-based backends cannot do non-explicit unrolling cfg.define_knob("unroll_explicit", [1]) else: @@ -80,10 +81,10 @@ def _schedule(cfg, op): x_bn = get_max_power2_factor(N, 64) y_nthreads = min(y_bn, 8) x_nthreads = min(x_bn, 8) - cfg['tile_x'] = SplitEntity([-1, x_nthreads, x_bn // x_nthreads]) - cfg['tile_y'] = SplitEntity([-1, y_nthreads, y_bn // y_nthreads]) - cfg['tile_k'] = SplitEntity([-1, 8]) - cfg['auto_unroll_max_step'] = OtherOptionEntity(16) + cfg["tile_x"] = SplitEntity([-1, x_nthreads, x_bn // x_nthreads]) + cfg["tile_y"] = SplitEntity([-1, y_nthreads, y_bn // y_nthreads]) + cfg["tile_k"] = SplitEntity([-1, 8]) + cfg["auto_unroll_max_step"] = OtherOptionEntity(16) by, ty, yi = cfg["tile_y"].apply(s, C, y) bx, tx, xi = cfg["tile_x"].apply(s, C, x) @@ -97,15 +98,15 @@ def _schedule(cfg, op): s[C].bind(bx, te.thread_axis("blockIdx.x")) s[C].bind(ty, thread_y) s[C].bind(tx, thread_x) - s[C].pragma(yi, "auto_unroll_max_step", cfg['auto_unroll_max_step'].val) - s[C].pragma(yi, 'unroll_explicit', cfg['unroll_explicit'].val) + s[C].pragma(yi, "auto_unroll_max_step", 
cfg["auto_unroll_max_step"].val) + s[C].pragma(yi, "unroll_explicit", cfg["unroll_explicit"].val) s[CC].compute_at(s[C], tx) _, yi, xi = s[CC].op.axis ko, ki = cfg["tile_k"].apply(s, CC, k) s[CC].reorder(ko, ki, yi, xi) - s[CC].pragma(ki, "auto_unroll_max_step", cfg['auto_unroll_max_step'].val) - s[CC].pragma(ki, 'unroll_explicit', cfg['unroll_explicit'].val) + s[CC].pragma(ki, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val) + s[CC].pragma(ki, "unroll_explicit", cfg["unroll_explicit"].val) s[AA].compute_at(s[CC], ko) s[AL].compute_at(s[CC], ki) @@ -117,8 +118,8 @@ def _schedule(cfg, op): s[AA].reorder(ty, tx, yi, ki) s[AA].bind(ty, thread_y) s[AA].bind(tx, thread_x) - s[AA].pragma(yi, "auto_unroll_max_step", cfg['auto_unroll_max_step'].val) - s[AA].pragma(yi, 'unroll_explicit', cfg['unroll_explicit'].val) + s[AA].pragma(yi, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val) + s[AA].pragma(yi, "unroll_explicit", cfg["unroll_explicit"].val) _, x, k = s[BB].op.axis ty, xi = s[BB].split(x, nparts=cfg["tile_y"].size[1]) @@ -126,8 +127,8 @@ def _schedule(cfg, op): s[BB].bind(ty, thread_y) s[BB].bind(tx, thread_x) s[BB].reorder(ty, tx, xi, ki) - s[BB].pragma(xi, "auto_unroll_max_step", cfg['auto_unroll_max_step'].val) - s[BB].pragma(xi, 'unroll_explicit', cfg['unroll_explicit'].val) + s[BB].pragma(xi, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val) + s[BB].pragma(xi, "unroll_explicit", cfg["unroll_explicit"].val) def _callback(op): if "batch_matmul" in op.tag: @@ -136,6 +137,7 @@ def _callback(op): traverse_inline(s, outs[0].op, _callback) return s + def batch_matmul_cublas(x, y): """Computes batch matrix multiplication of `x` and `y` when `x` and `y` are data in batch. diff --git a/python/tvm/topi/nn/batch_matmul.py b/python/tvm/topi/nn/batch_matmul.py index 6a41504e0972..34a8c6dafc87 100644 --- a/python/tvm/topi/nn/batch_matmul.py +++ b/python/tvm/topi/nn/batch_matmul.py @@ -19,6 +19,7 @@ from tvm import te from ..util import get_const_tuple + def batch_matmul(x, y, oshape=None): """Computes batch matrix multiplication of `x` and `y` when `x` and `y` are data in batch. 
@@ -44,11 +45,11 @@ def batch_matmul(x, y, oshape=None): assert x_shape[2] == y_shape[2], "shapes of x and y is inconsistant" batch, M, K = x.shape N = y.shape[1] - k = te.reduce_axis((0, K), name='k') + k = te.reduce_axis((0, K), name="k") oshape = (batch, M, N) else: _, _, K = x.shape - k = te.reduce_axis((0, K), name='k') - return te.compute(oshape, - lambda b, i, j: te.sum(x[b, i, k] * y[b, j, k], axis=k), - tag='batch_matmul') + k = te.reduce_axis((0, K), name="k") + return te.compute( + oshape, lambda b, i, j: te.sum(x[b, i, k] * y[b, j, k], axis=k), tag="batch_matmul" + ) diff --git a/python/tvm/topi/x86/batch_matmul.py b/python/tvm/topi/x86/batch_matmul.py index 9f6a8f289c6f..c095dcb0b6bb 100644 --- a/python/tvm/topi/x86/batch_matmul.py +++ b/python/tvm/topi/x86/batch_matmul.py @@ -42,8 +42,7 @@ def batch_matmul(cfg, x, y, out_shape=None): output : tvm.te.Tensor 3-D with shape [batch, M, N] """ - assert len(x.shape) == 3 and len( - y.shape) == 3, "only support 3-dim batch_matmul" + assert len(x.shape) == 3 and len(y.shape) == 3, "only support 3-dim batch_matmul" XB, M, XK = get_const_tuple(x.shape) YB, N, YK = get_const_tuple(y.shape) assert XB == YB, "batch dimension doesn't match" @@ -57,11 +56,10 @@ def batch_matmul(cfg, x, y, out_shape=None): if cfg.is_fallback: _default_batch_matmul_config(cfg, M, N, K) - k = te.reduce_axis((0, K), name='k') + k = te.reduce_axis((0, K), name="k") C = te.compute( - (B, M, N), - lambda b, i, j: te.sum(x[b, i, k] * y[b, j, k], axis=k), - tag='batch_matmul') + (B, M, N), lambda b, i, j: te.sum(x[b, i, k] * y[b, j, k], axis=k), tag="batch_matmul" + ) return C @@ -112,7 +110,7 @@ def _callback(op): s[O].parallel(bxyo) s[CC].compute_at(s[O], bxyo) - k, = s[CC].op.reduce_axis + (k,) = s[CC].op.reduce_axis ko, ki = cfg["tile_k"].apply(s, CC, k) Crf = s.rfactor(CC, ki) @@ -120,7 +118,7 @@ def _callback(op): _, _, y, x = s[Crf].op.axis s[Crf].fuse(y, x) s[Crf].vectorize(s[Crf].op.axis[0]) - s[O].pragma(bxyo, 'auto_unroll_max_step', 16) + s[O].pragma(bxyo, "auto_unroll_max_step", 16) traverse_inline(s, outs[0].op, _callback) return s @@ -152,8 +150,7 @@ def batch_matmul_cblas(cfg, x, y): output : tvm.te.Tensor 3-D with shape [batch, M, N] """ - assert len(x.shape) == 3 and len( - y.shape) == 3, "only support 3-dim batch_matmul" + assert len(x.shape) == 3 and len(y.shape) == 3, "only support 3-dim batch_matmul" XB, M, XK = get_const_tuple(x.shape) YB, N, YK = get_const_tuple(y.shape) assert XB == YB, "batch dimension doesn't match" diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index d3074aa3bce3..f60cd612d131 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -50,27 +50,29 @@ def get_tvm_output_with_vm(graph_def, input_data, target, ctx, opset=None, freez input_data = [input_data] _, shape_dict = get_input_data_shape_dict(graph_def, input_data) - mod, params = relay.frontend.from_onnx(graph_def, shape_dict, opset=opset, freeze_params=freeze_params) + mod, params = relay.frontend.from_onnx( + graph_def, shape_dict, opset=opset, freeze_params=freeze_params + ) - ex = relay.create_executor('vm', mod=mod, ctx=ctx, target=target) + ex = relay.create_executor("vm", mod=mod, ctx=ctx, target=target) result = ex.evaluate()(*input_data) if isinstance(result, tvm.runtime.NDArray): return result.asnumpy() return [r.asnumpy() for r in result] -def get_tvm_output(graph_def, input_data, target, ctx, output_shape=None, output_dtype='float32', opset=None): +def 
get_tvm_output( + graph_def, input_data, target, ctx, output_shape=None, output_dtype="float32", opset=None +): """ Generic function to execute and get tvm output""" - target = 'llvm' + target = "llvm" input_names, shape_dict = get_input_data_shape_dict(graph_def, input_data) mod, params = relay.frontend.from_onnx(graph_def, shape_dict, opset=opset) with tvm.transform.PassContext(opt_level=1): - graph, lib, params = relay.build(mod, - target, - params=params) + graph, lib, params = relay.build(mod, target, params=params) ctx = tvm.cpu(0) m = graph_runtime.create(graph, lib, ctx) @@ -80,13 +82,11 @@ def get_tvm_output(graph_def, input_data, target, ctx, output_shape=None, output # Its possible for some onnx inputs to not be needed in the tvm # module, confirm its present before setting. try: - m.set_input(input_names[i], tvm.nd.array( - input_data[i].astype(input_data[i].dtype))) + m.set_input(input_names[i], tvm.nd.array(input_data[i].astype(input_data[i].dtype))) except: continue else: - m.set_input(input_names, tvm.nd.array( - input_data.astype(input_data.dtype))) + m.set_input(input_names, tvm.nd.array(input_data.astype(input_data.dtype))) m.set_input(**params) # execute @@ -103,9 +103,10 @@ def get_tvm_output(graph_def, input_data, target, ctx, output_shape=None, output return tvm_output.asnumpy() -def get_onnxruntime_output(model, inputs, dtype='float32'): +def get_onnxruntime_output(model, inputs, dtype="float32"): import onnxruntime.backend - rep = onnxruntime.backend.prepare(model, 'CPU') + + rep = onnxruntime.backend.prepare(model, "CPU") if isinstance(inputs, list) and len(inputs) > 1: ort_out = rep.run(inputs) else: @@ -115,7 +116,7 @@ def get_onnxruntime_output(model, inputs, dtype='float32'): def verify_onnx_forward_impl(graph_file, data_shape, out_shape): - dtype = 'float32' + dtype = "float32" x = np.random.uniform(size=data_shape) model = onnx.load_model(graph_file) c2_out = get_onnxruntime_output(model, x, dtype) @@ -123,14 +124,15 @@ def verify_onnx_forward_impl(graph_file, data_shape, out_shape): tvm_out = get_tvm_output(model, x, target, ctx, out_shape, dtype) tvm.testing.assert_allclose(c2_out, tvm_out, rtol=1e-5, atol=1e-5) + def make_constant_node(name, data_type, dims, vals): - return helper.make_node('Constant', - inputs=[], - outputs=[name], - value=helper.make_tensor(name=name, - data_type=data_type, - dims=dims, - vals=vals)) + return helper.make_node( + "Constant", + inputs=[], + outputs=[name], + value=helper.make_tensor(name=name, data_type=data_type, dims=dims, vals=vals), + ) + @tvm.testing.uses_gpu def test_reshape(): @@ -138,27 +140,31 @@ def test_reshape(): ref_shape = (6, 2, 4, 3) ref_array = np.array(ref_shape) - ref_node = onnx.helper.make_node('Constant', - inputs=[], - outputs=['ref_in'], - value=onnx.helper.make_tensor(name='const_tensor', - data_type=onnx.TensorProto.INT32, - dims=ref_array.shape, - vals=ref_array.flatten().astype(int))) + ref_node = onnx.helper.make_node( + "Constant", + inputs=[], + outputs=["ref_in"], + value=onnx.helper.make_tensor( + name="const_tensor", + data_type=onnx.TensorProto.INT32, + dims=ref_array.shape, + vals=ref_array.flatten().astype(int), + ), + ) reshape_node = helper.make_node("Reshape", ["in", "ref_in"], ["out"]) - graph = helper.make_graph([ref_node, reshape_node], - "reshape_test", - inputs=[helper.make_tensor_value_info("in", - TensorProto.FLOAT, list(in_shape))], - outputs=[helper.make_tensor_value_info("out", - TensorProto.FLOAT, list(ref_shape))]) + graph = helper.make_graph( + [ref_node, reshape_node], 
+ "reshape_test", + inputs=[helper.make_tensor_value_info("in", TensorProto.FLOAT, list(in_shape))], + outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(ref_shape))], + ) - model = helper.make_model(graph, producer_name='reshape_test') + model = helper.make_model(graph, producer_name="reshape_test") for target, ctx in tvm.testing.enabled_targets(): - x = np.random.uniform(size=in_shape).astype('int32') - tvm_out = get_tvm_output(model, x, target, ctx, ref_shape, 'float32') + x = np.random.uniform(size=in_shape).astype("int32") + tvm_out = get_tvm_output(model, x, target, ctx, ref_shape, "float32") tvm.testing.assert_allclose(ref_shape, tvm_out.shape) @@ -166,35 +172,42 @@ def test_reshape(): # TODO(mbrookhart): enable once VM supports heterogenous execution # @tvm.testing.uses_gpu def test_expand(): - def _test_expand(name, data, shape, ref_data, dtype="int32"): shape_array = np.array(shape) if dtype == "int32": - shape_node = onnx.helper.make_node('Constant', - inputs=[], - outputs=['shape'], - value=onnx.helper.make_tensor(name = 'const_tensor', - data_type = onnx.TensorProto.INT32, - dims = shape_array.shape, - vals = shape_array.flatten().astype('int32'))) + shape_node = onnx.helper.make_node( + "Constant", + inputs=[], + outputs=["shape"], + value=onnx.helper.make_tensor( + name="const_tensor", + data_type=onnx.TensorProto.INT32, + dims=shape_array.shape, + vals=shape_array.flatten().astype("int32"), + ), + ) elif dtype == "int64": - shape_node = onnx.helper.make_node('Constant', - inputs=[], - outputs=['shape'], - value=onnx.helper.make_tensor(name = 'const_tensor', - data_type = onnx.TensorProto.INT64, - dims = shape_array.shape, - vals = shape_array.flatten().astype('int64'))) + shape_node = onnx.helper.make_node( + "Constant", + inputs=[], + outputs=["shape"], + value=onnx.helper.make_tensor( + name="const_tensor", + data_type=onnx.TensorProto.INT64, + dims=shape_array.shape, + vals=shape_array.flatten().astype("int64"), + ), + ) else: raise "Invalid dtype" expand_node = helper.make_node("Expand", ["in", "shape"], ["out"]) - graph = helper.make_graph([shape_node, expand_node], - "expand_test", - inputs = [helper.make_tensor_value_info("in", - TensorProto.FLOAT, list(data.shape))], - outputs = [helper.make_tensor_value_info("out", - TensorProto.FLOAT, list(ref_data.shape))]) + graph = helper.make_graph( + [shape_node, expand_node], + "expand_test", + inputs=[helper.make_tensor_value_info("in", TensorProto.FLOAT, list(data.shape))], + outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(ref_data.shape))], + ) model = helper.make_model(graph, producer_name=name) @@ -207,34 +220,33 @@ def _test_expand(name, data, shape, ref_data, dtype="int32"): shape = (3, 4) data = np.random.uniform(size=in_shape).astype(np.float32) ref_data = np.tile(data, 4) - _test_expand('expand_with_dim_unchanged_test', data, shape, ref_data, "int32") - _test_expand('expand_with_dim_unchanged_test', data, shape, ref_data, "int64") + _test_expand("expand_with_dim_unchanged_test", data, shape, ref_data, "int32") + _test_expand("expand_with_dim_unchanged_test", data, shape, ref_data, "int64") in_shape = (3, 1) shape = (2, 1, 6) data = np.random.uniform(size=in_shape).astype(np.float32) ref_data = data * np.ones(shape, dtype=np.float32) - _test_expand('expand_with_dim_changed_test', data, shape, ref_data, "int32") - _test_expand('expand_with_dim_changed_test', data, shape, ref_data, "int64") + _test_expand("expand_with_dim_changed_test", data, shape, ref_data, "int32") + 
_test_expand("expand_with_dim_changed_test", data, shape, ref_data, "int64") def verify_depth_to_space(inshape, outshape, mode, blockSize): - node = onnx.helper.make_node('DepthToSpace', - inputs=['x'], - outputs=['y'], - blocksize=blockSize) + node = onnx.helper.make_node("DepthToSpace", inputs=["x"], outputs=["y"], blocksize=blockSize) - graph = helper.make_graph([node], - "depth_to_space_test", - inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, list(inshape))], - outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, list(outshape))]) + graph = helper.make_graph( + [node], + "depth_to_space_test", + inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, list(inshape))], + outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, list(outshape))], + ) - model = helper.make_model(graph, producer_name='depth_to_space_test') + model = helper.make_model(graph, producer_name="depth_to_space_test") for target, ctx in tvm.testing.enabled_targets(): - x = np.random.uniform(size=inshape).astype('float32') - tvm_out = get_tvm_output(model, x, target, ctx, outshape, 'float32') - onnx_out = get_onnxruntime_output(model, x, 'float32') + x = np.random.uniform(size=inshape).astype("float32") + tvm_out = get_tvm_output(model, x, target, ctx, outshape, "float32") + onnx_out = get_onnxruntime_output(model, x, "float32") tvm.testing.assert_allclose(onnx_out, tvm_out) @@ -247,22 +259,21 @@ def test_depth_to_space(): def verify_space_to_depth(inshape, outshape, blockSize): - node = onnx.helper.make_node('SpaceToDepth', - inputs=['x'], - outputs=['y'], - blocksize=blockSize) + node = onnx.helper.make_node("SpaceToDepth", inputs=["x"], outputs=["y"], blocksize=blockSize) - graph = helper.make_graph([node], - "space_to_depth_test", - inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, list(inshape))], - outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, list(outshape))]) + graph = helper.make_graph( + [node], + "space_to_depth_test", + inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, list(inshape))], + outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, list(outshape))], + ) - model = helper.make_model(graph, producer_name='space_to_depth_test') + model = helper.make_model(graph, producer_name="space_to_depth_test") for target, ctx in tvm.testing.enabled_targets(): - x = np.random.uniform(size=inshape).astype('float32') - tvm_out = get_tvm_output(model, x, target, ctx, outshape, 'float32') - onnx_out = get_onnxruntime_output(model, x, 'float32') + x = np.random.uniform(size=inshape).astype("float32") + tvm_out = get_tvm_output(model, x, target, ctx, outshape, "float32") + onnx_out = get_onnxruntime_output(model, x, "float32") tvm.testing.assert_allclose(onnx_out, tvm_out) @@ -277,29 +288,33 @@ def test_shape(): ref_shape = (6, 2, 4, 3) ref_array = np.array(ref_shape) - ref_node = onnx.helper.make_node('Constant', - inputs=[], - outputs=['ref_in'], - value=onnx.helper.make_tensor(name='const_tensor', - data_type=onnx.TensorProto.INT32, - dims=ref_array.shape, - vals=ref_array.flatten().astype(int))) + ref_node = onnx.helper.make_node( + "Constant", + inputs=[], + outputs=["ref_in"], + value=onnx.helper.make_tensor( + name="const_tensor", + data_type=onnx.TensorProto.INT32, + dims=ref_array.shape, + vals=ref_array.flatten().astype(int), + ), + ) reshape_node = helper.make_node("Reshape", ["in", "ref_in"], ["out"]) - shape_node = helper.make_node("Shape", ['out'], ['final_out']) + shape_node = helper.make_node("Shape", ["out"], 
["final_out"]) - graph = helper.make_graph([ref_node, reshape_node, shape_node], - "shape_test", - inputs=[helper.make_tensor_value_info("in", - TensorProto.FLOAT, list(in_shape))], - outputs=[helper.make_tensor_value_info("final_out", - TensorProto.FLOAT, list(ref_shape))]) + graph = helper.make_graph( + [ref_node, reshape_node, shape_node], + "shape_test", + inputs=[helper.make_tensor_value_info("in", TensorProto.FLOAT, list(in_shape))], + outputs=[helper.make_tensor_value_info("final_out", TensorProto.FLOAT, list(ref_shape))], + ) - model = helper.make_model(graph, producer_name='shape_test') + model = helper.make_model(graph, producer_name="shape_test") for target, ctx in tvm.testing.enabled_targets(): - x = np.random.uniform(size=in_shape).astype('int32') - tvm_out = get_tvm_output(model, x, target, ctx, ref_shape, 'int32') + x = np.random.uniform(size=in_shape).astype("int32") + tvm_out = get_tvm_output(model, x, target, ctx, ref_shape, "int32") tvm.testing.assert_allclose(ref_shape, tvm_out) @@ -313,18 +328,19 @@ def _test_power_iteration(x_shape, y_shape): np_res = np.power(x, y).astype(np.float32) - res = helper.make_node("Pow", ['x', 'y'], ['out']) + res = helper.make_node("Pow", ["x", "y"], ["out"]) - graph = helper.make_graph([res], - 'power_test', - inputs=[helper.make_tensor_value_info("x", - TensorProto.FLOAT, list(x_shape)), - helper.make_tensor_value_info("y", - TensorProto.FLOAT, list(y_shape))], - outputs=[helper.make_tensor_value_info("out", - TensorProto.FLOAT, list(np_res.shape))]) + graph = helper.make_graph( + [res], + "power_test", + inputs=[ + helper.make_tensor_value_info("x", TensorProto.FLOAT, list(x_shape)), + helper.make_tensor_value_info("y", TensorProto.FLOAT, list(y_shape)), + ], + outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(np_res.shape))], + ) - model = helper.make_model(graph, producer_name='power_test') + model = helper.make_model(graph, producer_name="power_test") for target, ctx in tvm.testing.enabled_targets(): tvm_out = get_tvm_output(model, [x, y], target, ctx, np_res.shape) @@ -342,20 +358,20 @@ def test_power(): def test_squeeze(): in_shape = (1, 3, 1, 3, 1, 1) out_shape = (3, 3) - y = helper.make_node("Squeeze", ['in'], ['out'], axes=[0, 2, 4, 5]) + y = helper.make_node("Squeeze", ["in"], ["out"], axes=[0, 2, 4, 5]) - graph = helper.make_graph([y], - 'squeeze_test', - inputs=[helper.make_tensor_value_info("in", - TensorProto.FLOAT, list(in_shape))], - outputs=[helper.make_tensor_value_info("out", - TensorProto.FLOAT, list(out_shape))]) + graph = helper.make_graph( + [y], + "squeeze_test", + inputs=[helper.make_tensor_value_info("in", TensorProto.FLOAT, list(in_shape))], + outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(out_shape))], + ) - model = helper.make_model(graph, producer_name='squeeze_test') + model = helper.make_model(graph, producer_name="squeeze_test") for target, ctx in tvm.testing.enabled_targets(): - x = np.random.uniform(size=in_shape).astype('float32') - tvm_out = get_tvm_output(model, x, target, ctx, out_shape, 'float32') + x = np.random.uniform(size=in_shape).astype("float32") + tvm_out = get_tvm_output(model, x, target, ctx, out_shape, "float32") tvm.testing.assert_allclose(out_shape, tvm_out.shape) @@ -369,18 +385,18 @@ def test_flatten(): flatten_node = helper.make_node("Flatten", ["in"], ["out"], axis=axis) - graph = helper.make_graph([flatten_node], - "flatten_test", - inputs=[helper.make_tensor_value_info("in", - TensorProto.FLOAT, list(in_shape))], - 
outputs=[helper.make_tensor_value_info("out", - TensorProto.FLOAT, list(ref_shape))]) + graph = helper.make_graph( + [flatten_node], + "flatten_test", + inputs=[helper.make_tensor_value_info("in", TensorProto.FLOAT, list(in_shape))], + outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(ref_shape))], + ) - model = helper.make_model(graph, producer_name='flatten_test') + model = helper.make_model(graph, producer_name="flatten_test") for target, ctx in tvm.testing.enabled_targets(): - x = np.random.uniform(size=in_shape).astype('int32') - tvm_out = get_tvm_output(model, x, target, ctx, ref_shape, 'float32') + x = np.random.uniform(size=in_shape).astype("int32") + tvm_out = get_tvm_output(model, x, target, ctx, ref_shape, "float32") tvm.testing.assert_allclose(ref_shape, tvm_out.shape) @@ -390,20 +406,20 @@ def test_unsqueeze(): in_shape = (3, 3) axis = (0, 3, 4) out_shape = (1, 3, 3, 1, 1) - y = helper.make_node("Unsqueeze", ['in'], ['out'], axes=list(axis)) + y = helper.make_node("Unsqueeze", ["in"], ["out"], axes=list(axis)) - graph = helper.make_graph([y], - 'squeeze_test', - inputs=[helper.make_tensor_value_info("in", - TensorProto.FLOAT, list(in_shape))], - outputs=[helper.make_tensor_value_info("out", - TensorProto.FLOAT, list(out_shape))]) + graph = helper.make_graph( + [y], + "squeeze_test", + inputs=[helper.make_tensor_value_info("in", TensorProto.FLOAT, list(in_shape))], + outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(out_shape))], + ) - model = helper.make_model(graph, producer_name='squeeze_test') + model = helper.make_model(graph, producer_name="squeeze_test") for target, ctx in tvm.testing.enabled_targets(): - x = np.random.uniform(size=in_shape).astype('float32') - tvm_out = get_tvm_output(model, x, target, ctx, out_shape, 'float32') + x = np.random.uniform(size=in_shape).astype("float32") + tvm_out = get_tvm_output(model, x, target, ctx, out_shape, "float32") tvm.testing.assert_allclose(out_shape, tvm_out.shape) @@ -413,32 +429,32 @@ def verify_gather(in_shape, indices, axis, dtype): indices = np.array(indices, dtype="int32") out_np = np.take(x, indices, axis=axis) - y = helper.make_node("Gather", ['in', 'indices'], ['out'], axis=axis) + y = helper.make_node("Gather", ["in", "indices"], ["out"], axis=axis) - graph = helper.make_graph([y], - 'gather_test', - inputs=[helper.make_tensor_value_info("in", - TensorProto.FLOAT, list(in_shape)), - helper.make_tensor_value_info("indices", - TensorProto.INT32, list(indices.shape))], - outputs=[helper.make_tensor_value_info("out", - TensorProto.FLOAT, list(out_np.shape))]) - model = helper.make_model(graph, producer_name='gather_test') + graph = helper.make_graph( + [y], + "gather_test", + inputs=[ + helper.make_tensor_value_info("in", TensorProto.FLOAT, list(in_shape)), + helper.make_tensor_value_info("indices", TensorProto.INT32, list(indices.shape)), + ], + outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(out_np.shape))], + ) + model = helper.make_model(graph, producer_name="gather_test") for target, ctx in tvm.testing.enabled_targets(): - tvm_out = get_tvm_output( - model, [x, indices], target, ctx, out_np.shape) + tvm_out = get_tvm_output(model, [x, indices], target, ctx, out_np.shape) tvm.testing.assert_allclose(out_np, tvm_out) @tvm.testing.uses_gpu def test_gather(): - verify_gather((4,), [1], 0, 'int32') - verify_gather((1, 4), [0], 0, 'int32') - verify_gather((4,), [[[1, 0], [0, 1]]], 0, 'float32') - verify_gather((2, 2), [[[1, 0], [0, 1]]], 1, 'int32') - 
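[editor's note] The Gather tests use numpy.take as the reference implementation; for the non-negative indices exercised here, ONNX Gather and np.take agree, with output shape data.shape[:axis] + indices.shape + data.shape[axis+1:]. A tiny worked example (illustrative values only):

# Worked example of the np.take reference used by verify_gather.
import numpy as np

data = np.arange(12, dtype="float32").reshape(3, 4)
indices = np.array([[0, 2], [1, 1]], dtype="int32")

expected = np.take(data, indices, axis=1)
print(expected.shape)  # (3, 2, 2): data.shape[:1] + indices.shape
print(expected[0])     # columns gathered from data[0]: [[0., 2.], [1., 1.]]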
verify_gather((3, 3, 3), [[[1, 0]]], -1, 'int32') - verify_gather((4, 3, 5, 6), [[2, 1, 0, 0]], 0, 'float32') + verify_gather((4,), [1], 0, "int32") + verify_gather((1, 4), [0], 0, "int32") + verify_gather((4,), [[[1, 0], [0, 1]]], 0, "float32") + verify_gather((2, 2), [[[1, 0], [0, 1]]], 1, "int32") + verify_gather((3, 3, 3), [[[1, 0]]], -1, "int32") + verify_gather((4, 3, 5, 6), [[2, 1, 0, 0]], 0, "float32") def verify_scatter(in_shape, indices, axis): @@ -446,24 +462,23 @@ def verify_scatter(in_shape, indices, axis): indices = np.array(indices, dtype="int32") updates = np.random.uniform(size=indices.shape).astype("float32") - y = helper.make_node("ScatterElements", ['data', 'indices', 'updates'], ['output'], axis=axis) - - graph = helper.make_graph([y], - 'scatter_test', - inputs=[helper.make_tensor_value_info("data", - TensorProto.FLOAT, list(in_shape)), - helper.make_tensor_value_info("indices", - TensorProto.INT32, list(indices.shape)), - helper.make_tensor_value_info("updates", - TensorProto.FLOAT, list(indices.shape))], - outputs=[helper.make_tensor_value_info("output", - TensorProto.FLOAT, list(in_shape))]) - model = helper.make_model(graph, producer_name='scatter_test') + y = helper.make_node("ScatterElements", ["data", "indices", "updates"], ["output"], axis=axis) + + graph = helper.make_graph( + [y], + "scatter_test", + inputs=[ + helper.make_tensor_value_info("data", TensorProto.FLOAT, list(in_shape)), + helper.make_tensor_value_info("indices", TensorProto.INT32, list(indices.shape)), + helper.make_tensor_value_info("updates", TensorProto.FLOAT, list(indices.shape)), + ], + outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, list(in_shape))], + ) + model = helper.make_model(graph, producer_name="scatter_test") onnx_out = get_onnxruntime_output(model, [x, indices, updates]) for target, ctx in tvm.testing.enabled_targets(): - tvm_out = get_tvm_output( - model, [x, indices, updates], target, ctx, onnx_out[0].shape) + tvm_out = get_tvm_output(model, [x, indices, updates], target, ctx, onnx_out[0].shape) tvm.testing.assert_allclose(onnx_out[0], tvm_out) @@ -479,75 +494,82 @@ def test_scatter(): def _test_slice_iteration_v1(indata, outdata, starts, ends, axes=None): if axes: - y = helper.make_node( - "Slice", ['in'], ['out'], axes=axes, starts=starts, ends=ends) + y = helper.make_node("Slice", ["in"], ["out"], axes=axes, starts=starts, ends=ends) else: - y = helper.make_node( - "Slice", ['in'], ['out'], starts=starts, ends=ends) + y = helper.make_node("Slice", ["in"], ["out"], starts=starts, ends=ends) - graph = helper.make_graph([y], - 'slice_test', - inputs=[helper.make_tensor_value_info("in", - TensorProto.FLOAT, list(indata.shape))], - outputs=[helper.make_tensor_value_info("out", - TensorProto.FLOAT, list(outdata.shape))]) + graph = helper.make_graph( + [y], + "slice_test", + inputs=[helper.make_tensor_value_info("in", TensorProto.FLOAT, list(indata.shape))], + outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(outdata.shape))], + ) - model = helper.make_model(graph, producer_name='slice_test') + model = helper.make_model(graph, producer_name="slice_test") for target, ctx in tvm.testing.enabled_targets(): - tvm_out = get_tvm_output( - model, indata, target, ctx, outdata.shape, 'float32', opset=1) + tvm_out = get_tvm_output(model, indata, target, ctx, outdata.shape, "float32", opset=1) tvm.testing.assert_allclose(outdata, tvm_out) + def _test_slice_iteration_v10(indata, outdata, **attrs): - starts = attrs['starts'] - ends = attrs['ends'] - axes 
= None if 'axes' not in attrs else attrs['axes'] + starts = attrs["starts"] + ends = attrs["ends"] + axes = None if "axes" not in attrs else attrs["axes"] starts = np.asarray(starts) ends = np.asarray(ends) inputs = [ - helper.make_tensor_value_info("data", TensorProto.FLOAT, - list(indata.shape)), - helper.make_tensor_value_info("starts", TensorProto.INT64, - list(starts.shape)), - helper.make_tensor_value_info("ends", TensorProto.INT64, - list(ends.shape)) + helper.make_tensor_value_info("data", TensorProto.FLOAT, list(indata.shape)), + helper.make_tensor_value_info("starts", TensorProto.INT64, list(starts.shape)), + helper.make_tensor_value_info("ends", TensorProto.INT64, list(ends.shape)), ] initializer = [ - helper.make_tensor("starts", TensorProto.INT64, list(starts.shape), - starts), - helper.make_tensor("ends", TensorProto.INT64, list(ends.shape), ends) + helper.make_tensor("starts", TensorProto.INT64, list(starts.shape), starts), + helper.make_tensor("ends", TensorProto.INT64, list(ends.shape), ends), ] nodes = [] - if 'add_noop_to_input_attrs' in attrs: + if "add_noop_to_input_attrs" in attrs: + def add_noop_to_input_attr(attr_name, attr): - output_name = attr_name+"_output" + output_name = attr_name + "_output" ref_shape = list(np.array(attr).shape) ref_shape.insert(0, 1) ref_shape = tuple(ref_shape) ref_array = np.array(ref_shape) - ref_node = onnx.helper.make_node('Constant', - inputs=[], - outputs=['ref_in_'+attr_name], - value=onnx.helper.make_tensor(name='const_tensor__1_'+attr_name, - data_type=onnx.TensorProto.INT64, - dims=ref_array.shape, - vals=ref_array.flatten().astype(int))) + ref_node = onnx.helper.make_node( + "Constant", + inputs=[], + outputs=["ref_in_" + attr_name], + value=onnx.helper.make_tensor( + name="const_tensor__1_" + attr_name, + data_type=onnx.TensorProto.INT64, + dims=ref_array.shape, + vals=ref_array.flatten().astype(int), + ), + ) in_shape = np.array(attr).shape in_array = np.array(in_shape) - ref_node2 = onnx.helper.make_node('Constant', - inputs=[], - outputs=['input_shape_'+attr_name], - value=onnx.helper.make_tensor(name='const_tensor__2_'+attr_name, - data_type=onnx.TensorProto.INT64, - dims=in_array.shape, - vals=in_array.flatten().astype(int))) - - reshape1_node = helper.make_node("Reshape", [attr_name, "ref_in_"+attr_name], ["reshape_"+attr_name]) - reshape2_node = helper.make_node("Reshape", ["reshape_"+attr_name, "input_shape_"+attr_name], [output_name]) + ref_node2 = onnx.helper.make_node( + "Constant", + inputs=[], + outputs=["input_shape_" + attr_name], + value=onnx.helper.make_tensor( + name="const_tensor__2_" + attr_name, + data_type=onnx.TensorProto.INT64, + dims=in_array.shape, + vals=in_array.flatten().astype(int), + ), + ) + + reshape1_node = helper.make_node( + "Reshape", [attr_name, "ref_in_" + attr_name], ["reshape_" + attr_name] + ) + reshape2_node = helper.make_node( + "Reshape", ["reshape_" + attr_name, "input_shape_" + attr_name], [output_name] + ) return [ref_node, ref_node2, reshape1_node, reshape2_node] slice_inputs = [] @@ -562,34 +584,22 @@ def add_noop_to_input_attr(attr_name, attr): if axes: axes = np.asarray(axes) - inputs.append( - helper.make_tensor_value_info("axes", TensorProto.INT32, - list(axes.shape))) - initializer.append( - helper.make_tensor("axes", TensorProto.INT32, list(axes.shape), - axes)) + inputs.append(helper.make_tensor_value_info("axes", TensorProto.INT32, list(axes.shape))) + initializer.append(helper.make_tensor("axes", TensorProto.INT32, list(axes.shape), axes)) y = helper.make_node("Slice", 
["data", *slice_inputs], ["out"]) nodes.append(y) - graph = helper.make_graph(nodes, - 'slice_test', - inputs=inputs, - outputs=[ - helper.make_tensor_value_info( - "out", TensorProto.FLOAT, - list(outdata.shape)) - ], - initializer=initializer) - model = helper.make_model(graph, producer_name='slice_test') + graph = helper.make_graph( + nodes, + "slice_test", + inputs=inputs, + outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(outdata.shape))], + initializer=initializer, + ) + model = helper.make_model(graph, producer_name="slice_test") for target, ctx in tvm.testing.enabled_targets(): - tvm_out = get_tvm_output(model, - indata, - target, - ctx, - outdata.shape, - 'float32', - opset=10) + tvm_out = get_tvm_output(model, indata, target, ctx, outdata.shape, "float32", opset=10) tvm.testing.assert_allclose(outdata, tvm_out) @@ -605,123 +615,158 @@ def test_slice(): _test_slice_iteration_v10(x, x[:, :, 3:4], starts=(0, 0, 3), ends=(20, 10, 4)) _test_slice_iteration_v10(x, x[:, 1:1000], starts=(1,), ends=(1000,), axes=(1,)) _test_slice_iteration_v10(x, x[:, 0:-1], starts=(0,), ends=(-1,), axes=(1,)) - _test_slice_iteration_v10(x, x[0:3, 0:10], starts=(0, 0), ends=(3, 10), axes=(0, 1), add_noop_to_input_attrs=["starts"]) - _test_slice_iteration_v10(x, x[:, :, 3:4], starts=(0, 0, 3), ends=(20, 10, 4), add_noop_to_input_attrs=["ends"]) - _test_slice_iteration_v10(x, x[:, 1:1000], starts=(1,), ends=(1000,), axes=(1,), add_noop_to_input_attrs=["axes"]) - _test_slice_iteration_v10(x, x[:, 0:-1], starts=(0,), ends=(-1,), axes=(1,), add_noop_to_input_attrs=["starts", "ends"]) - _test_slice_iteration_v10(x, x[0:3, 0:10], starts=(0, 0), ends=(3, 10), axes=(0, 1), add_noop_to_input_attrs=["ends", "axes"]) - _test_slice_iteration_v10(x, x[:, :, 3:4], starts=(0, 0, 3), ends=(20, 10, 4), add_noop_to_input_attrs=["starts", "axes"]) - _test_slice_iteration_v10(x, x[:, 1:1000], starts=(1,), ends=(1000,), axes=(1,), add_noop_to_input_attrs=["starts", "ends", "axes"]) + _test_slice_iteration_v10( + x, + x[0:3, 0:10], + starts=(0, 0), + ends=(3, 10), + axes=(0, 1), + add_noop_to_input_attrs=["starts"], + ) + _test_slice_iteration_v10( + x, x[:, :, 3:4], starts=(0, 0, 3), ends=(20, 10, 4), add_noop_to_input_attrs=["ends"] + ) + _test_slice_iteration_v10( + x, x[:, 1:1000], starts=(1,), ends=(1000,), axes=(1,), add_noop_to_input_attrs=["axes"] + ) + _test_slice_iteration_v10( + x, + x[:, 0:-1], + starts=(0,), + ends=(-1,), + axes=(1,), + add_noop_to_input_attrs=["starts", "ends"], + ) + _test_slice_iteration_v10( + x, + x[0:3, 0:10], + starts=(0, 0), + ends=(3, 10), + axes=(0, 1), + add_noop_to_input_attrs=["ends", "axes"], + ) + _test_slice_iteration_v10( + x, + x[:, :, 3:4], + starts=(0, 0, 3), + ends=(20, 10, 4), + add_noop_to_input_attrs=["starts", "axes"], + ) + _test_slice_iteration_v10( + x, + x[:, 1:1000], + starts=(1,), + ends=(1000,), + axes=(1,), + add_noop_to_input_attrs=["starts", "ends", "axes"], + ) x = np.random.randn(1, 1, 1, 128).astype(np.float32) - _test_slice_iteration_v10(x, x, starts=(0, 0), ends=(9223372036854775807, 9223372036854775807), axes=(0, 3)) + _test_slice_iteration_v10( + x, x, starts=(0, 0), ends=(9223372036854775807, 9223372036854775807), axes=(0, 3) + ) def _test_onnx_op_elementwise(inshape, outfunc, npargs, dtype, opname, kwargs): indata = np.random.uniform(-1, 1, size=inshape).astype(dtype) outdata = outfunc(indata, **npargs) - y = helper.make_node(opname, ['in'], ['out'], **kwargs) + y = helper.make_node(opname, ["in"], ["out"], **kwargs) - graph = 
helper.make_graph([y], - opname+'_test', - inputs=[helper.make_tensor_value_info("in", - TensorProto.FLOAT, list(indata.shape))], - outputs=[helper.make_tensor_value_info("out", - TensorProto.FLOAT, list(outdata.shape))]) + graph = helper.make_graph( + [y], + opname + "_test", + inputs=[helper.make_tensor_value_info("in", TensorProto.FLOAT, list(indata.shape))], + outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(outdata.shape))], + ) - model = helper.make_model(graph, producer_name=opname+'_test') + model = helper.make_model(graph, producer_name=opname + "_test") for target, ctx in tvm.testing.enabled_targets(): - tvm_out = get_tvm_output( - model, indata, target, ctx, outdata.shape, dtype) + tvm_out = get_tvm_output(model, indata, target, ctx, outdata.shape, dtype) tvm.testing.assert_allclose(outdata, tvm_out) @tvm.testing.uses_gpu def test_floor(): - _test_onnx_op_elementwise((2, 4, 5, 6), np.floor, - {}, 'float32', 'Floor', {}) + _test_onnx_op_elementwise((2, 4, 5, 6), np.floor, {}, "float32", "Floor", {}) @tvm.testing.uses_gpu def test_ceil(): - _test_onnx_op_elementwise((2, 4, 5, 6), np.ceil, {}, 'float32', 'Ceil', {}) + _test_onnx_op_elementwise((2, 4, 5, 6), np.ceil, {}, "float32", "Ceil", {}) @tvm.testing.uses_gpu def test_clip(): - _test_onnx_op_elementwise((2, 4, 5, 6), - np.clip, - {'a_min': -1.0, 'a_max': 1.0}, - 'float32', - 'Clip', - {'min': -1.0, 'max': 1.0}) + _test_onnx_op_elementwise( + (2, 4, 5, 6), + np.clip, + {"a_min": -1.0, "a_max": 1.0}, + "float32", + "Clip", + {"min": -1.0, "max": 1.0}, + ) @tvm.testing.uses_gpu def test_clip_min_max_as_inputs(): - input_shape=(2,4,5,6) + input_shape = (2, 4, 5, 6) nodes = [ - make_constant_node('min', onnx.TensorProto.FLOAT, (), [0.]), - make_constant_node('max', onnx.TensorProto.FLOAT, (), [6.]), + make_constant_node("min", onnx.TensorProto.FLOAT, (), [0.0]), + make_constant_node("max", onnx.TensorProto.FLOAT, (), [6.0]), ] - input_names = ['in', 'min', 'max'] - nodes.append(helper.make_node( - 'Clip', - inputs=input_names, - outputs=['out'])) - graph = helper.make_graph(nodes, - "clip_test", - inputs=[helper.make_tensor_value_info("in", - TensorProto.FLOAT, list(input_shape))], - outputs=[helper.make_tensor_value_info("out", - TensorProto.FLOAT, list(input_shape))]) - model = helper.make_model(graph, producer_name='clip_test') - - indata = np.random.uniform(-1, 7, size=input_shape).astype('float32') - onnx_out = get_onnxruntime_output(model, indata, 'float32') + input_names = ["in", "min", "max"] + nodes.append(helper.make_node("Clip", inputs=input_names, outputs=["out"])) + graph = helper.make_graph( + nodes, + "clip_test", + inputs=[helper.make_tensor_value_info("in", TensorProto.FLOAT, list(input_shape))], + outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(input_shape))], + ) + model = helper.make_model(graph, producer_name="clip_test") + + indata = np.random.uniform(-1, 7, size=input_shape).astype("float32") + onnx_out = get_onnxruntime_output(model, indata, "float32") for target, ctx in tvm.testing.enabled_targets(): - tvm_out = get_tvm_output( - model, indata, target, ctx, input_shape, 'float32') + tvm_out = get_tvm_output(model, indata, target, ctx, input_shape, "float32") tvm.testing.assert_allclose(onnx_out, tvm_out) @tvm.testing.uses_gpu def test_round(): - _test_onnx_op_elementwise((2, 4, 5, 6), np.round, {}, 'float32', 'Round', {}) + _test_onnx_op_elementwise((2, 4, 5, 6), np.round, {}, "float32", "Round", {}) def _test_finite_ops(inshape, outfunc, npargs, dtype, opname, 
kwargs): indata = np.random.choice(a=[np.nan, np.inf, -np.inf, 0.5, 1.0, 0], size=inshape).astype(dtype) outdata = outfunc(indata, **npargs) - y = helper.make_node(opname, ['in'], ['out'], **kwargs) + y = helper.make_node(opname, ["in"], ["out"], **kwargs) - graph = helper.make_graph([y], - opname+'_test', - inputs=[helper.make_tensor_value_info("in", - TensorProto.FLOAT, list(indata.shape))], - outputs=[helper.make_tensor_value_info("out", - TensorProto.BOOL, list(outdata.shape))]) + graph = helper.make_graph( + [y], + opname + "_test", + inputs=[helper.make_tensor_value_info("in", TensorProto.FLOAT, list(indata.shape))], + outputs=[helper.make_tensor_value_info("out", TensorProto.BOOL, list(outdata.shape))], + ) - model = helper.make_model(graph, producer_name=opname+'_test') + model = helper.make_model(graph, producer_name=opname + "_test") for target, ctx in tvm.testing.enabled_targets(): - tvm_out = get_tvm_output( - model, indata, target, ctx, outdata.shape, dtype) + tvm_out = get_tvm_output(model, indata, target, ctx, outdata.shape, dtype) tvm.testing.assert_allclose(outdata, tvm_out) @tvm.testing.uses_gpu def test_isinf(): - _test_finite_ops((2, 4, 5, 6), np.isinf, {}, 'float32', 'IsInf', {}) + _test_finite_ops((2, 4, 5, 6), np.isinf, {}, "float32", "IsInf", {}) @tvm.testing.uses_gpu def test_isnan(): - _test_finite_ops((2, 4, 5, 6), np.isnan, {}, 'float32', 'IsNaN', {}) + _test_finite_ops((2, 4, 5, 6), np.isnan, {}, "float32", "IsNaN", {}) def verify_gather_nd(in_shape, indices, dtype): @@ -729,59 +774,59 @@ def verify_gather_nd(in_shape, indices, dtype): indices = np.array(indices, dtype="int32") out_np = tvm.topi.testing.gather_nd_python(x, indices) - y = helper.make_node("GatherND", ['in', 'indices'], ['out']) + y = helper.make_node("GatherND", ["in", "indices"], ["out"]) - graph = helper.make_graph([y], - 'gather_test', - inputs=[helper.make_tensor_value_info("in", - TensorProto.FLOAT, list(in_shape)), - helper.make_tensor_value_info("indices", - TensorProto.INT32, list(indices.shape))], - outputs=[helper.make_tensor_value_info("out", - TensorProto.FLOAT, list(out_np.shape))]) - model = helper.make_model(graph, producer_name='gather_test') + graph = helper.make_graph( + [y], + "gather_test", + inputs=[ + helper.make_tensor_value_info("in", TensorProto.FLOAT, list(in_shape)), + helper.make_tensor_value_info("indices", TensorProto.INT32, list(indices.shape)), + ], + outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(out_np.shape))], + ) + model = helper.make_model(graph, producer_name="gather_test") for target, ctx in tvm.testing.enabled_targets(): - tvm_out = get_tvm_output( - model, [x, indices], target, ctx, out_np.shape) + tvm_out = get_tvm_output(model, [x, indices], target, ctx, out_np.shape) tvm.testing.assert_allclose(out_np, tvm_out) @tvm.testing.uses_gpu def test_gather_nd(): - verify_gather_nd((2, 2), [[0,0],[1,1]], 'int32') - verify_gather_nd((3, 3, 3), [[0,1],[1,0]] , 'float32') - verify_gather_nd((4, 3, 5, 6), [[2, 1, 0, 0]], 'float32') + verify_gather_nd((2, 2), [[0, 0], [1, 1]], "int32") + verify_gather_nd((3, 3, 3), [[0, 1], [1, 0]], "float32") + verify_gather_nd((4, 3, 5, 6), [[2, 1, 0, 0]], "float32") # TODO(mbrookhart): enable once VM supports heterogenous execution # @tvm.testing.uses_gpu def test_onehot(): indices_shape = [10] - indices_array = np.random.randint( - low=0, high=9, size=indices_shape, dtype='int32') + indices_array = np.random.randint(low=0, high=9, size=indices_shape, dtype="int32") depth = 10 values = np.asarray([0, 
1]).astype("int32") out_np = np.eye(depth)[indices_array.reshape(-1)] - onehot_node = helper.make_node( - "OneHot", ["indices", "depth", "values"], ["out"]) + onehot_node = helper.make_node("OneHot", ["indices", "depth", "values"], ["out"]) - graph = helper.make_graph([onehot_node], - "onehot_test", - inputs=[helper.make_tensor_value_info("indices", - TensorProto.INT32, indices_shape), - helper.make_tensor_value_info("depth", - TensorProto.INT32, [1]), - helper.make_tensor_value_info("values", - TensorProto.INT32, values.shape)], - outputs=[helper.make_tensor_value_info("out", TensorProto.INT32, out_np.shape)]) + graph = helper.make_graph( + [onehot_node], + "onehot_test", + inputs=[ + helper.make_tensor_value_info("indices", TensorProto.INT32, indices_shape), + helper.make_tensor_value_info("depth", TensorProto.INT32, [1]), + helper.make_tensor_value_info("values", TensorProto.INT32, values.shape), + ], + outputs=[helper.make_tensor_value_info("out", TensorProto.INT32, out_np.shape)], + ) model = helper.make_model(graph, producer_name="onehot_test") for target, ctx in tvm.testing.enabled_targets(): tvm_out = get_tvm_output_with_vm( - model, [indices_array, np.array([depth]).astype("int32"), values], target, ctx) + model, [indices_array, np.array([depth]).astype("int32"), values], target, ctx + ) tvm.testing.assert_allclose(out_np, tvm_out, rtol=1e-5, atol=1e-5) @@ -790,51 +835,53 @@ def test_matmul(): a_shape = (4, 3) b_shape = (3, 4) - a_array = np.random.uniform(size=a_shape).astype('float32') - b_array = np.random.uniform(size=b_shape).astype('float32') + a_array = np.random.uniform(size=a_shape).astype("float32") + b_array = np.random.uniform(size=b_shape).astype("float32") out_np = np.matmul(a_array, b_array) mul_node = helper.make_node("MatMul", ["a", "b"], ["out"]) - graph = helper.make_graph([mul_node], - "matmul_test", - inputs=[helper.make_tensor_value_info("a", - TensorProto.FLOAT, list(a_shape)), - helper.make_tensor_value_info("b", - TensorProto.FLOAT, list(b_shape))], - outputs=[helper.make_tensor_value_info("out", - TensorProto.FLOAT, list(out_np.shape))]) + graph = helper.make_graph( + [mul_node], + "matmul_test", + inputs=[ + helper.make_tensor_value_info("a", TensorProto.FLOAT, list(a_shape)), + helper.make_tensor_value_info("b", TensorProto.FLOAT, list(b_shape)), + ], + outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(out_np.shape))], + ) - model = helper.make_model(graph, producer_name='matmul_test') + model = helper.make_model(graph, producer_name="matmul_test") for target, ctx in tvm.testing.enabled_targets(): - tvm_out = get_tvm_output( - model, [a_array, b_array], target, ctx, out_np.shape) + tvm_out = get_tvm_output(model, [a_array, b_array], target, ctx, out_np.shape) tvm.testing.assert_allclose(out_np, tvm_out, rtol=1e-5, atol=1e-5) + def verify_batch_matmul(a_shape, b_shape): - a_array = np.random.uniform(size=a_shape).astype('float32') - b_array = np.random.uniform(size=b_shape).astype('float32') + a_array = np.random.uniform(size=a_shape).astype("float32") + b_array = np.random.uniform(size=b_shape).astype("float32") out_np = np.matmul(a_array, b_array) mul_node = helper.make_node("MatMul", ["a", "b"], ["out"]) - graph = helper.make_graph([mul_node], - "matmul_test", - inputs=[helper.make_tensor_value_info("a", - TensorProto.FLOAT, list(a_shape)), - helper.make_tensor_value_info("b", - TensorProto.FLOAT, list(b_shape))], - outputs=[helper.make_tensor_value_info("out", - TensorProto.FLOAT, list(out_np.shape))]) + graph = helper.make_graph( 
+ [mul_node], + "matmul_test", + inputs=[ + helper.make_tensor_value_info("a", TensorProto.FLOAT, list(a_shape)), + helper.make_tensor_value_info("b", TensorProto.FLOAT, list(b_shape)), + ], + outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(out_np.shape))], + ) - model = helper.make_model(graph, producer_name='matmul_test') + model = helper.make_model(graph, producer_name="matmul_test") for target, ctx in tvm.testing.enabled_targets(): - tvm_out = get_tvm_output_with_vm( - model, [a_array, b_array], target, ctx) + tvm_out = get_tvm_output_with_vm(model, [a_array, b_array], target, ctx) tvm.testing.assert_allclose(out_np, tvm_out, rtol=1e-5, atol=1e-5) + # TODO(mbrookhart): enable cuda once VM supports heterogenous execution @tvm.testing.parametrize_targets("llvm") def test_batch_matmul(): @@ -842,6 +889,7 @@ def test_batch_matmul(): verify_batch_matmul((2, 4, 3), (3, 4)) verify_batch_matmul((2, 3, 4, 3), (3, 4)) + def verify_lrn(shape, nsize, dtype, alpha=None, beta=None, bias=None): in_array = np.random.uniform(size=shape).astype(dtype) @@ -849,46 +897,51 @@ def verify_lrn(shape, nsize, dtype, alpha=None, beta=None, bias=None): alpha = 0.0001 beta = 0.75 bias = 1.0 - node = onnx.helper.make_node( - 'LRN', inputs=['in'], outputs=['out'], size=nsize) + node = onnx.helper.make_node("LRN", inputs=["in"], outputs=["out"], size=nsize) else: - node = onnx.helper.make_node('LRN', inputs=['in'], outputs=['out'], alpha=alpha, - beta=beta, bias=bias, size=nsize) + node = onnx.helper.make_node( + "LRN", inputs=["in"], outputs=["out"], alpha=alpha, beta=beta, bias=bias, size=nsize + ) - graph = helper.make_graph([node], - "lrn_test", - inputs=[helper.make_tensor_value_info( - "in", TensorProto.FLOAT, list(shape))], - outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(shape))]) - model = helper.make_model(graph, producer_name='lrn_test') + graph = helper.make_graph( + [node], + "lrn_test", + inputs=[helper.make_tensor_value_info("in", TensorProto.FLOAT, list(shape))], + outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(shape))], + ) + model = helper.make_model(graph, producer_name="lrn_test") def _get_python_lrn(): square_sum = np.zeros(shape).astype(dtype) for n, c, h, w in np.ndindex(in_array.shape): - square_sum[n, c, h, w] = sum(in_array[n, - max(0, c - int(math.floor((nsize - 1) / 2))): - min(5, c + int(math.ceil((nsize - 1) / 2)) + 1), - h, - w] ** 2) + square_sum[n, c, h, w] = sum( + in_array[ + n, + max(0, c - int(math.floor((nsize - 1) / 2))) : min( + 5, c + int(math.ceil((nsize - 1) / 2)) + 1 + ), + h, + w, + ] + ** 2 + ) py_out = in_array / ((bias + (alpha / nsize) * square_sum) ** beta) return py_out for target, ctx in tvm.testing.enabled_targets(): input_name = model.graph.input[0].name py_out = _get_python_lrn() - tvm_out = get_tvm_output( - model, in_array, target, ctx, py_out.shape, 'float32') + tvm_out = get_tvm_output(model, in_array, target, ctx, py_out.shape, "float32") tvm.testing.assert_allclose(py_out, tvm_out, rtol=1e-5, atol=1e-5) @tvm.testing.uses_gpu def test_lrn(): - verify_lrn((5, 5, 5, 5), 3, 'float32') - verify_lrn((5, 5, 5, 5), 3, 'float32', alpha=0.0002, beta=0.5, bias=2.0) + verify_lrn((5, 5, 5, 5), 3, "float32") + verify_lrn((5, 5, 5, 5), 3, "float32", alpha=0.0002, beta=0.5, bias=2.0) def verify_instance_norm(shape, axis=1): - def _get_python_instance_norm(x, gamma, beta, epsilon=1e-5): dims_x = len(x.shape) axis = tuple(range(2, dims_x)) @@ -906,22 +959,24 @@ def _get_python_instance_norm(x, gamma, beta, 
epsilon=1e-5): y = _get_python_instance_norm(x, gamma, beta, epsilon).astype(np.float32) node = onnx.helper.make_node( - 'InstanceNormalization', - inputs=['x', 'gamma', 'beta'], - outputs=['y'], + "InstanceNormalization", + inputs=["x", "gamma", "beta"], + outputs=["y"], epsilon=epsilon, ) - graph = helper.make_graph([node], - "instance_norm_test", - inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, list(shape)), - helper.make_tensor_value_info( - "gamma", TensorProto.FLOAT, (shape[1],)), - helper.make_tensor_value_info("beta", TensorProto.FLOAT, (shape[1],))], - outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, list(shape))]) - model = helper.make_model(graph, producer_name='instance_norm_test') + graph = helper.make_graph( + [node], + "instance_norm_test", + inputs=[ + helper.make_tensor_value_info("x", TensorProto.FLOAT, list(shape)), + helper.make_tensor_value_info("gamma", TensorProto.FLOAT, (shape[1],)), + helper.make_tensor_value_info("beta", TensorProto.FLOAT, (shape[1],)), + ], + outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, list(shape))], + ) + model = helper.make_model(graph, producer_name="instance_norm_test") for target, ctx in tvm.testing.enabled_targets(): - tvm_out = get_tvm_output( - model, [x, gamma, beta], target, ctx, shape, 'float32') + tvm_out = get_tvm_output(model, [x, gamma, beta], target, ctx, shape, "float32") tvm.testing.assert_allclose(y, tvm_out, rtol=1e-5, atol=1e-5) @@ -936,143 +991,154 @@ def test_instance_norm(): def _test_upsample_nearest(): scale = 2 in_shape = (1, 1, 3, 3) - out_shape = (1, 1, 3*scale, 3*scale) - y = helper.make_node("Upsample", ['in'], [ - 'out'], mode='nearest', scales=[1.0, 1.0, 2.0, 2.0]) + out_shape = (1, 1, 3 * scale, 3 * scale) + y = helper.make_node("Upsample", ["in"], ["out"], mode="nearest", scales=[1.0, 1.0, 2.0, 2.0]) in_array = np.random.uniform(size=in_shape).astype(np.float32) - out_array = tvm.topi.testing.upsampling_python( - in_array, (scale, scale), "NCHW") + out_array = tvm.topi.testing.upsampling_python(in_array, (scale, scale), "NCHW") - graph = helper.make_graph([y], - 'upsample_nearest_test', - inputs=[helper.make_tensor_value_info( - "in", TensorProto.FLOAT, list(in_shape))], - outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(out_shape))]) + graph = helper.make_graph( + [y], + "upsample_nearest_test", + inputs=[helper.make_tensor_value_info("in", TensorProto.FLOAT, list(in_shape))], + outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(out_shape))], + ) - model = helper.make_model(graph, producer_name='upsample_nearest_test') + model = helper.make_model(graph, producer_name="upsample_nearest_test") for target, ctx in tvm.testing.enabled_targets(): - tvm_out = get_tvm_output( - model, in_array, target, ctx, out_shape, 'float32') + tvm_out = get_tvm_output(model, in_array, target, ctx, out_shape, "float32") tvm.testing.assert_allclose(out_array, tvm_out) def _test_upsample3d_nearest(): scale = 2 in_shape = (1, 1, 3, 3, 3) - out_shape = (1, 1, 3*scale, 3*scale, 3*scale) - y = helper.make_node("Upsample", ['in'], [ - 'out'], mode='nearest', scales=[1.0, 1.0, 2.0, 2.0, 2.0]) + out_shape = (1, 1, 3 * scale, 3 * scale, 3 * scale) + y = helper.make_node( + "Upsample", ["in"], ["out"], mode="nearest", scales=[1.0, 1.0, 2.0, 2.0, 2.0] + ) in_array = np.random.uniform(size=in_shape).astype(np.float32) - out_array = tvm.topi.testing.upsampling3d_python( - in_array, (scale, scale, scale), "NCDHW") + out_array = 
tvm.topi.testing.upsampling3d_python(in_array, (scale, scale, scale), "NCDHW") - graph = helper.make_graph([y], - 'upsample_nearest_test', - inputs=[helper.make_tensor_value_info( - "in", TensorProto.FLOAT, list(in_shape))], - outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(out_shape))]) + graph = helper.make_graph( + [y], + "upsample_nearest_test", + inputs=[helper.make_tensor_value_info("in", TensorProto.FLOAT, list(in_shape))], + outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(out_shape))], + ) - model = helper.make_model(graph, producer_name='upsample_nearest_test') + model = helper.make_model(graph, producer_name="upsample_nearest_test") for target, ctx in tvm.testing.enabled_targets(): - tvm_out = get_tvm_output( - model, in_array, target, ctx, out_shape, 'float32') + tvm_out = get_tvm_output(model, in_array, target, ctx, out_shape, "float32") tvm.testing.assert_allclose(out_array, tvm_out) + def _test_upsample_bilinear(): scale = 2 in_shape = (1, 1, 3, 3) - out_shape = (1, 1, 3*scale, 3*scale) - y = helper.make_node("Upsample", ['in'], [ - 'out'], mode='linear', scales=[1.0, 1.0, 2.0, 2.0]) + out_shape = (1, 1, 3 * scale, 3 * scale) + y = helper.make_node("Upsample", ["in"], ["out"], mode="linear", scales=[1.0, 1.0, 2.0, 2.0]) in_array = np.random.uniform(size=in_shape).astype(np.float32) - out_array = tvm.topi.testing.bilinear_resize_python( - in_array, (3*scale, 3*scale), "NCHW") + out_array = tvm.topi.testing.bilinear_resize_python(in_array, (3 * scale, 3 * scale), "NCHW") - graph = helper.make_graph([y], - 'upsample_bilinear_test', - inputs=[helper.make_tensor_value_info( - "in", TensorProto.FLOAT, list(in_shape))], - outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(out_shape))]) + graph = helper.make_graph( + [y], + "upsample_bilinear_test", + inputs=[helper.make_tensor_value_info("in", TensorProto.FLOAT, list(in_shape))], + outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(out_shape))], + ) - model = helper.make_model(graph, producer_name='upsample_bilinear_test') + model = helper.make_model(graph, producer_name="upsample_bilinear_test") for target, ctx in tvm.testing.enabled_targets(): - tvm_out = get_tvm_output( - model, in_array, target, ctx, out_shape, 'float32') + tvm_out = get_tvm_output(model, in_array, target, ctx, out_shape, "float32") tvm.testing.assert_allclose(out_array, tvm_out, rtol=1e-5, atol=1e-5) def _test_upsample_bilinear_opset9(): scale = 2 in_shape = (1, 1, 3, 3) - out_shape = (1, 1, 3*scale, 3*scale) - y = helper.make_node("Upsample", ['in', 'scales'], ['out'], mode='linear') + out_shape = (1, 1, 3 * scale, 3 * scale) + y = helper.make_node("Upsample", ["in", "scales"], ["out"], mode="linear") scales = [1, 1, 2, 2] in_array = np.random.uniform(size=in_shape).astype(np.float32) - out_array = tvm.topi.testing.bilinear_resize_python( - in_array, (3*scale, 3*scale), "NCHW") - - ref_node = helper.make_node('Constant', - inputs=[], - outputs=['const'], - value=onnx.helper.make_tensor(name='const_tensor', - data_type=TensorProto.FLOAT, - dims=scales, - vals=np.random.random(scales).flatten().astype(float))) + out_array = tvm.topi.testing.bilinear_resize_python(in_array, (3 * scale, 3 * scale), "NCHW") + + ref_node = helper.make_node( + "Constant", + inputs=[], + outputs=["const"], + value=onnx.helper.make_tensor( + name="const_tensor", + data_type=TensorProto.FLOAT, + dims=scales, + vals=np.random.random(scales).flatten().astype(float), + ), + ) - shape_node = 
helper.make_node("Shape", ['const'], ['scales']) + shape_node = helper.make_node("Shape", ["const"], ["scales"]) - graph = helper.make_graph([ref_node, shape_node, y], - 'upsample_bilinear_opset9_test', - inputs=[helper.make_tensor_value_info( - "in", TensorProto.FLOAT, list(in_shape))], - outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(out_shape))]) + graph = helper.make_graph( + [ref_node, shape_node, y], + "upsample_bilinear_opset9_test", + inputs=[helper.make_tensor_value_info("in", TensorProto.FLOAT, list(in_shape))], + outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(out_shape))], + ) - model = helper.make_model( - graph, producer_name='upsample_bilinear_opset9_test') + model = helper.make_model(graph, producer_name="upsample_bilinear_opset9_test") for target, ctx in tvm.testing.enabled_targets(): - tvm_out = get_tvm_output_with_vm(model, [in_array], target, ctx, opset=9, freeze_params=True) + tvm_out = get_tvm_output_with_vm( + model, [in_array], target, ctx, opset=9, freeze_params=True + ) tvm.testing.assert_allclose(out_array, tvm_out, rtol=1e-5, atol=1e-5) + def _test_upsample3d_trilinear(): scale = 2 in_shape = (1, 1, 3, 3, 3) - out_shape = (1, 1, 3*scale, 3*scale, 3*scale) - y = helper.make_node("Upsample", ['in', 'scales'], ['out'], mode='linear') + out_shape = (1, 1, 3 * scale, 3 * scale, 3 * scale) + y = helper.make_node("Upsample", ["in", "scales"], ["out"], mode="linear") scales = [1.0, 1.0, 2.0, 2.0, 2.0] in_array = np.random.uniform(size=in_shape).astype(np.float32) out_array = tvm.topi.testing.trilinear_resize3d_python( - in_array, (3*scale, 3*scale, 3*scale), "NCDHW", coordinate_transformation_mode="half_pixel") + in_array, + (3 * scale, 3 * scale, 3 * scale), + "NCDHW", + coordinate_transformation_mode="half_pixel", + ) ref_array = np.array(scales) - ref_node = helper.make_node('Constant', - inputs=[], - outputs=['scales'], - value=onnx.helper.make_tensor(name='const_tensor', - data_type=TensorProto.FLOAT, - dims=ref_array.shape, - vals=ref_array.flatten().astype(float))) - - graph = helper.make_graph([ref_node, y], - 'upsample_trilinear_test', - inputs=[helper.make_tensor_value_info( - "in", TensorProto.FLOAT, list(in_shape))], - outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(out_shape))]) - - model = helper.make_model( - graph, producer_name='upsample_trilinear_test') + ref_node = helper.make_node( + "Constant", + inputs=[], + outputs=["scales"], + value=onnx.helper.make_tensor( + name="const_tensor", + data_type=TensorProto.FLOAT, + dims=ref_array.shape, + vals=ref_array.flatten().astype(float), + ), + ) + + graph = helper.make_graph( + [ref_node, y], + "upsample_trilinear_test", + inputs=[helper.make_tensor_value_info("in", TensorProto.FLOAT, list(in_shape))], + outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(out_shape))], + ) + + model = helper.make_model(graph, producer_name="upsample_trilinear_test") for target, ctx in tvm.testing.enabled_targets(): - tvm_out = get_tvm_output( - model, in_array, target, ctx, out_shape, 'float32') + tvm_out = get_tvm_output(model, in_array, target, ctx, out_shape, "float32") tvm.testing.assert_allclose(out_array, tvm_out, rtol=1e-5, atol=1e-5) + # TODO(mbrookhart): enable once VM supports heterogenous execution # @tvm.testing.uses_gpu def test_upsample(): @@ -1082,28 +1148,28 @@ def test_upsample(): _test_upsample3d_nearest() _test_upsample3d_trilinear() + def _test_softmax(inshape, axis): - opname = 'Softmax' + opname = "Softmax" indata = 
np.random.uniform(size=inshape).astype(np.float32) outshape = inshape outdata = tvm.topi.testing.softmax_python(indata) if isinstance(axis, int): - y = helper.make_node(opname, ['in'], ['out'], axis=axis) + y = helper.make_node(opname, ["in"], ["out"], axis=axis) elif axis is None: - y = helper.make_node(opname, ['in'], ['out']) + y = helper.make_node(opname, ["in"], ["out"]) - graph = helper.make_graph([y], - opname+'_test', - inputs=[helper.make_tensor_value_info("in", - TensorProto.FLOAT, list(indata.shape))], - outputs=[helper.make_tensor_value_info("out", - TensorProto.FLOAT, list(outdata.shape))]) + graph = helper.make_graph( + [y], + opname + "_test", + inputs=[helper.make_tensor_value_info("in", TensorProto.FLOAT, list(indata.shape))], + outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(outdata.shape))], + ) - model = helper.make_model(graph, producer_name=opname+'_test') + model = helper.make_model(graph, producer_name=opname + "_test") for target, ctx in tvm.testing.enabled_targets(): - tvm_out = get_tvm_output( - model, indata, target, ctx, outshape, 'float32') + tvm_out = get_tvm_output(model, indata, target, ctx, outshape, "float32") tvm.testing.assert_allclose(outdata, tvm_out, rtol=1e-5, atol=1e-5) @@ -1114,7 +1180,7 @@ def test_softmax(): def verify_min(input_dim): - dtype = 'float32' + dtype = "float32" a_np1 = np.random.uniform(size=input_dim).astype(dtype) a_np2 = np.random.uniform(size=input_dim).astype(dtype) @@ -1124,22 +1190,21 @@ def verify_min(input_dim): min_node = helper.make_node("Min", ["a_np1", "a_np2", "a_np3"], ["out"]) - graph = helper.make_graph([min_node], - "Min_test", - inputs=[helper.make_tensor_value_info("a_np1", - TensorProto.FLOAT, list(input_dim)), - helper.make_tensor_value_info("a_np2", - TensorProto.FLOAT, list(input_dim)), - helper.make_tensor_value_info("a_np3", - TensorProto.FLOAT, list(input_dim))], - outputs=[helper.make_tensor_value_info("out", - TensorProto.FLOAT, list(b_np.shape))]) + graph = helper.make_graph( + [min_node], + "Min_test", + inputs=[ + helper.make_tensor_value_info("a_np1", TensorProto.FLOAT, list(input_dim)), + helper.make_tensor_value_info("a_np2", TensorProto.FLOAT, list(input_dim)), + helper.make_tensor_value_info("a_np3", TensorProto.FLOAT, list(input_dim)), + ], + outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(b_np.shape))], + ) - model = helper.make_model(graph, producer_name='Min_test') + model = helper.make_model(graph, producer_name="Min_test") for target, ctx in tvm.testing.enabled_targets(): - tvm_out = get_tvm_output( - model, [a_np1, a_np2, a_np3], target, ctx, b_np.shape) + tvm_out = get_tvm_output(model, [a_np1, a_np2, a_np3], target, ctx, b_np.shape) tvm.testing.assert_allclose(b_np, tvm_out, rtol=1e-5, atol=1e-5) @@ -1150,7 +1215,7 @@ def test_forward_min(): def verify_max(input_dim): - dtype = 'float32' + dtype = "float32" a_np1 = np.random.uniform(size=input_dim).astype(dtype) a_np2 = np.random.uniform(size=input_dim).astype(dtype) @@ -1160,22 +1225,21 @@ def verify_max(input_dim): max_node = helper.make_node("Max", ["a_np1", "a_np2", "a_np3"], ["out"]) - graph = helper.make_graph([max_node], - "Max_test", - inputs=[helper.make_tensor_value_info("a_np1", - TensorProto.FLOAT, list(input_dim)), - helper.make_tensor_value_info("a_np2", - TensorProto.FLOAT, list(input_dim)), - helper.make_tensor_value_info("a_np3", - TensorProto.FLOAT, list(input_dim))], - outputs=[helper.make_tensor_value_info("out", - TensorProto.FLOAT, list(b_np.shape))]) + graph = 
helper.make_graph( + [max_node], + "Max_test", + inputs=[ + helper.make_tensor_value_info("a_np1", TensorProto.FLOAT, list(input_dim)), + helper.make_tensor_value_info("a_np2", TensorProto.FLOAT, list(input_dim)), + helper.make_tensor_value_info("a_np3", TensorProto.FLOAT, list(input_dim)), + ], + outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(b_np.shape))], + ) - model = helper.make_model(graph, producer_name='Max_test') + model = helper.make_model(graph, producer_name="Max_test") for target, ctx in tvm.testing.enabled_targets(): - tvm_out = get_tvm_output( - model, [a_np1, a_np2, a_np3], target, ctx, b_np.shape) + tvm_out = get_tvm_output(model, [a_np1, a_np2, a_np3], target, ctx, b_np.shape) tvm.testing.assert_allclose(b_np, tvm_out, rtol=1e-5, atol=1e-5) @@ -1186,7 +1250,7 @@ def test_forward_max(): def verify_mean(input_dim): - dtype = 'float32' + dtype = "float32" a_np1 = np.random.uniform(size=input_dim).astype(dtype) a_np2 = np.random.uniform(size=input_dim).astype(dtype) @@ -1196,22 +1260,21 @@ def verify_mean(input_dim): mean_node = helper.make_node("Mean", ["a_np1", "a_np2", "a_np3"], ["out"]) - graph = helper.make_graph([mean_node], - "Mean_test", - inputs=[helper.make_tensor_value_info("a_np1", - TensorProto.FLOAT, list(input_dim)), - helper.make_tensor_value_info("a_np2", - TensorProto.FLOAT, list(input_dim)), - helper.make_tensor_value_info("a_np3", - TensorProto.FLOAT, list(input_dim))], - outputs=[helper.make_tensor_value_info("out", - TensorProto.FLOAT, list(b_np.shape))]) + graph = helper.make_graph( + [mean_node], + "Mean_test", + inputs=[ + helper.make_tensor_value_info("a_np1", TensorProto.FLOAT, list(input_dim)), + helper.make_tensor_value_info("a_np2", TensorProto.FLOAT, list(input_dim)), + helper.make_tensor_value_info("a_np3", TensorProto.FLOAT, list(input_dim)), + ], + outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(b_np.shape))], + ) - model = helper.make_model(graph, producer_name='Mean_test') + model = helper.make_model(graph, producer_name="Mean_test") for target, ctx in tvm.testing.enabled_targets(): - tvm_out = get_tvm_output( - model, [a_np1, a_np2, a_np3], target, ctx, b_np.shape) + tvm_out = get_tvm_output(model, [a_np1, a_np2, a_np3], target, ctx, b_np.shape) tvm.testing.assert_allclose(b_np, tvm_out, rtol=1e-5, atol=1e-5) @@ -1222,23 +1285,22 @@ def test_forward_mean(): def verify_hardsigmoid(input_dim, alpha, beta): - dtype = 'float32' + dtype = "float32" a_np1 = np.random.uniform(size=input_dim).astype(dtype) b_np = np.clip(a_np1 * alpha + beta, 0, 1) - hardsigmoid_node = helper.make_node("HardSigmoid", ["a_np1"], [ - "out"], alpha=alpha, beta=beta) + hardsigmoid_node = helper.make_node("HardSigmoid", ["a_np1"], ["out"], alpha=alpha, beta=beta) - graph = helper.make_graph([hardsigmoid_node], - "HardSigmoid_test", - inputs=[helper.make_tensor_value_info("a_np1", - TensorProto.FLOAT, list(input_dim))], - outputs=[helper.make_tensor_value_info("out", - TensorProto.FLOAT, list(b_np.shape))]) + graph = helper.make_graph( + [hardsigmoid_node], + "HardSigmoid_test", + inputs=[helper.make_tensor_value_info("a_np1", TensorProto.FLOAT, list(input_dim))], + outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(b_np.shape))], + ) - model = helper.make_model(graph, producer_name='HardSigmoid_test') + model = helper.make_model(graph, producer_name="HardSigmoid_test") for target, ctx in tvm.testing.enabled_targets(): tvm_out = get_tvm_output(model, [a_np1], target, ctx, b_np.shape) @@ -1254,101 +1316,79 
@@ def test_forward_hardsigmoid(): def verify_argmin(input_dim, axis=None, keepdims=None): def _argmin_numpy(data, axis=0, keepdims=True): result = np.argmin(data, axis=axis) - if (keepdims == 1): + if keepdims == 1: result = np.expand_dims(result, axis) return result.astype(data.dtype) a_np1 = np.random.uniform(-10, 10, input_dim).astype(np.int32) if keepdims is None and axis is None: b_np = _argmin_numpy(a_np1) - node = onnx.helper.make_node('ArgMin', - inputs=['a_np1'], - outputs=['out']) + node = onnx.helper.make_node("ArgMin", inputs=["a_np1"], outputs=["out"]) elif axis is None: b_np = _argmin_numpy(a_np1, keepdims=keepdims) - node = onnx.helper.make_node('ArgMin', - inputs=['a_np1'], - outputs=['out'], - keepdims=keepdims) + node = onnx.helper.make_node("ArgMin", inputs=["a_np1"], outputs=["out"], keepdims=keepdims) elif keepdims is None: b_np = _argmin_numpy(a_np1, axis=axis) - node = onnx.helper.make_node('ArgMin', - inputs=['a_np1'], - outputs=['out'], - axis=axis) + node = onnx.helper.make_node("ArgMin", inputs=["a_np1"], outputs=["out"], axis=axis) else: b_np = _argmin_numpy(a_np1, axis=axis, keepdims=keepdims) - node = onnx.helper.make_node('ArgMin', - inputs=['a_np1'], - outputs=['out'], - axis=axis, - keepdims=keepdims) - graph = helper.make_graph([node], - "argmin_test", - inputs=[helper.make_tensor_value_info("a_np1", - TensorProto.INT32, list(a_np1.shape))], - outputs=[helper.make_tensor_value_info("out", - TensorProto.INT32, list(b_np.shape))]) - - model = helper.make_model(graph, producer_name='argmin_test') + node = onnx.helper.make_node( + "ArgMin", inputs=["a_np1"], outputs=["out"], axis=axis, keepdims=keepdims + ) + graph = helper.make_graph( + [node], + "argmin_test", + inputs=[helper.make_tensor_value_info("a_np1", TensorProto.INT32, list(a_np1.shape))], + outputs=[helper.make_tensor_value_info("out", TensorProto.INT32, list(b_np.shape))], + ) + + model = helper.make_model(graph, producer_name="argmin_test") for target, ctx in tvm.testing.enabled_targets(): - tvm_out = get_tvm_output( - model, [a_np1], target, ctx, b_np.shape, b_np.dtype) + tvm_out = get_tvm_output(model, [a_np1], target, ctx, b_np.shape, b_np.dtype) tvm.testing.assert_allclose(b_np, tvm_out, rtol=1e-5, atol=1e-5) def verify_argmax(input_dim, axis=None, keepdims=None): def _argmax_numpy(data, axis=0, keepdims=True): result = np.argmax(data, axis=axis) - if (keepdims == 1): + if keepdims == 1: result = np.expand_dims(result, axis) return result.astype(data.dtype) a_np1 = np.random.uniform(-10, 10, input_dim).astype(np.int32) if keepdims is None and axis is None: b_np = _argmax_numpy(a_np1) - node = onnx.helper.make_node('ArgMax', - inputs=['a_np1'], - outputs=['out']) + node = onnx.helper.make_node("ArgMax", inputs=["a_np1"], outputs=["out"]) elif axis is None: b_np = _argmax_numpy(a_np1, keepdims=keepdims) - node = onnx.helper.make_node('ArgMax', - inputs=['a_np1'], - outputs=['out'], - keepdims=keepdims) + node = onnx.helper.make_node("ArgMax", inputs=["a_np1"], outputs=["out"], keepdims=keepdims) elif keepdims is None: b_np = _argmax_numpy(a_np1, axis=axis) - node = onnx.helper.make_node('ArgMax', - inputs=['a_np1'], - outputs=['out'], - axis=axis) + node = onnx.helper.make_node("ArgMax", inputs=["a_np1"], outputs=["out"], axis=axis) else: b_np = _argmax_numpy(a_np1, axis=axis, keepdims=keepdims) - node = onnx.helper.make_node('ArgMax', - inputs=['a_np1'], - outputs=['out'], - axis=axis, - keepdims=keepdims) + node = onnx.helper.make_node( + "ArgMax", inputs=["a_np1"], outputs=["out"], 
axis=axis, keepdims=keepdims + ) - graph = helper.make_graph([node], - "argmax_test", - inputs=[helper.make_tensor_value_info("a_np1", - TensorProto.INT32, list(a_np1.shape))], - outputs=[helper.make_tensor_value_info("out", - TensorProto.INT32, list(b_np.shape))]) + graph = helper.make_graph( + [node], + "argmax_test", + inputs=[helper.make_tensor_value_info("a_np1", TensorProto.INT32, list(a_np1.shape))], + outputs=[helper.make_tensor_value_info("out", TensorProto.INT32, list(b_np.shape))], + ) - model = helper.make_model(graph, producer_name='argmax_test') + model = helper.make_model(graph, producer_name="argmax_test") for target, ctx in tvm.testing.enabled_targets(): - tvm_out = get_tvm_output( - model, [a_np1], target, ctx, b_np.shape, b_np.dtype) + tvm_out = get_tvm_output(model, [a_np1], target, ctx, b_np.shape, b_np.dtype) tvm.testing.assert_allclose(b_np, tvm_out, rtol=1e-5, atol=1e-5) @tvm.testing.uses_gpu def test_forward_arg_min_max(): - '''Verify argmin and argmax''' + """Verify argmin and argmax""" verify_argmin([3, 4, 4]) verify_argmax([3, 4, 4]) verify_argmin([3, 4, 4], axis=1) @@ -1365,26 +1405,25 @@ def verify_constantofshape(input_dim, value, dtype): out = np.empty(shape=input_dim, dtype=dtype) out.fill(value) - fill_node = helper.make_node("ConstantOfShape", ["input"], ["output"], - value=helper.make_tensor( - 'value', - mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(dtype)], - (1, ), (value, ))) + fill_node = helper.make_node( + "ConstantOfShape", + ["input"], + ["output"], + value=helper.make_tensor( + "value", mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(dtype)], (1,), (value,) + ), + ) - inputs = [ - helper.make_tensor_value_info("input", TensorProto.FLOAT, input_dim) - ] + inputs = [helper.make_tensor_value_info("input", TensorProto.FLOAT, input_dim)] graph = helper.make_graph( [fill_node], "fill_test", inputs, - outputs=[ - helper.make_tensor_value_info("output", TensorProto.FLOAT, - list(out.shape)) - ]) + outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, list(out.shape))], + ) - model = helper.make_model(graph, producer_name='fill_test') + model = helper.make_model(graph, producer_name="fill_test") for target, ctx in tvm.testing.enabled_targets(): input_np = np.array(input_dim).astype("float32") @@ -1396,100 +1435,90 @@ def verify_constantofshape(input_dim, value, dtype): # TODO(mbrookhart): enable once VM supports heterogenous execution # @tvm.testing.uses_gpu def test_constantofshape(): - verify_constantofshape((2, 3, 4, 5), 10, 'float32') - verify_constantofshape((3, 3), 0, 'int32') - verify_constantofshape((1, 2, 3), -1, 'float32') + verify_constantofshape((2, 3, 4, 5), 10, "float32") + verify_constantofshape((3, 3), 0, "int32") + verify_constantofshape((1, 2, 3), -1, "float32") -def verify_pad(indata, pads, mode='constant', value=0.0): +def verify_pad(indata, pads, mode="constant", value=0.0): indata = np.array(indata).astype(np.float32) # numpy expect result len_dim = len(pads) // 2 - np_pads = [(pads[i], pads[i+len_dim]) for i in range(len_dim)] + np_pads = [(pads[i], pads[i + len_dim]) for i in range(len_dim)] # onnx graph - if mode in ['edge', 'reflect']: + if mode in ["edge", "reflect"]: outdata = np.pad(indata, pad_width=np_pads, mode=mode) node = helper.make_node( - 'Pad', - inputs=['input'], - outputs=['output'], + "Pad", + inputs=["input"], + outputs=["output"], mode=mode, pads=pads, ) else: - outdata = np.pad(indata, pad_width=np_pads, - mode='constant', constant_values=value) + outdata = np.pad(indata, pad_width=np_pads, mode="constant", 
constant_values=value) node = helper.make_node( - 'Pad', - inputs=['input'], - outputs=['output'], - mode='constant', - pads=pads, - value=value + "Pad", inputs=["input"], outputs=["output"], mode="constant", pads=pads, value=value ) - graph = helper.make_graph([node], - 'pad_test', - inputs=[helper.make_tensor_value_info("input", - TensorProto.FLOAT, list(indata.shape))], - outputs=[helper.make_tensor_value_info("output", - TensorProto.FLOAT, list(outdata.shape))]) - model = helper.make_model(graph, producer_name='pad_test') + graph = helper.make_graph( + [node], + "pad_test", + inputs=[helper.make_tensor_value_info("input", TensorProto.FLOAT, list(indata.shape))], + outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, list(outdata.shape))], + ) + model = helper.make_model(graph, producer_name="pad_test") # tvm result for target, ctx in tvm.testing.enabled_targets(): - tvm_out = get_tvm_output( - model, indata, target, ctx, outdata.shape, 'float32', opset=2) + tvm_out = get_tvm_output(model, indata, target, ctx, outdata.shape, "float32", opset=2) tvm.testing.assert_allclose(outdata, tvm_out, rtol=1e-5, atol=1e-5) -def verify_pad_v11(indata, pads, mode='constant', value=0.0): +def verify_pad_v11(indata, pads, mode="constant", value=0.0): indata = np.array(indata).astype(np.float32) # numpy expect result len_dim = len(pads) // 2 - np_pads = [(pads[i], pads[i+len_dim]) for i in range(len_dim)] + np_pads = [(pads[i], pads[i + len_dim]) for i in range(len_dim)] pads = np.array(pads) # onnx graph - if mode in ['edge', 'reflect']: + if mode in ["edge", "reflect"]: inputs = [indata, pads] outdata = np.pad(indata, pad_width=np_pads, mode=mode) - node = helper.make_node( - 'Pad', - inputs=['input', 'pads'], - outputs=['output'], - mode=mode + node = helper.make_node("Pad", inputs=["input", "pads"], outputs=["output"], mode=mode) + graph = helper.make_graph( + [node], + "pad_test", + inputs=[ + helper.make_tensor_value_info("input", TensorProto.FLOAT, list(indata.shape)), + helper.make_tensor_value_info("pads", TensorProto.INT64, (len(pads),)), + ], + initializer=[helper.make_tensor("pads", TensorProto.INT64, (len(pads),), pads)], + outputs=[ + helper.make_tensor_value_info("output", TensorProto.FLOAT, list(outdata.shape)) + ], ) - graph = helper.make_graph([node], - 'pad_test', - inputs=[helper.make_tensor_value_info("input", - TensorProto.FLOAT, list(indata.shape)), - helper.make_tensor_value_info("pads", - TensorProto.INT64,(len(pads),))], - initializer=[helper.make_tensor("pads", TensorProto.INT64, (len(pads),), pads)], - outputs=[helper.make_tensor_value_info("output", - TensorProto.FLOAT, list(outdata.shape))]) else: inputs = [indata, pads, np.array([value]).astype("float32")] - outdata = np.pad(indata, pad_width=np_pads, - mode='constant', constant_values=value) + outdata = np.pad(indata, pad_width=np_pads, mode="constant", constant_values=value) node = helper.make_node( - 'Pad', - inputs=['input', 'pads', 'constant_value'], - outputs=['output'], - mode='constant' + "Pad", inputs=["input", "pads", "constant_value"], outputs=["output"], mode="constant" ) - graph = helper.make_graph([node], - 'pad_test', - inputs=[helper.make_tensor_value_info("input", - TensorProto.FLOAT, list(indata.shape)), - helper.make_tensor_value_info("pads", - TensorProto.INT64,(len(pads),)), - helper.make_tensor_value_info("constant_value", - TensorProto.FLOAT,(1,)), - ], - initializer=[helper.make_tensor("pads", TensorProto.INT64, (len(pads),), pads), - helper.make_tensor("constant_value", 
TensorProto.FLOAT, (1,), [value])], - outputs=[helper.make_tensor_value_info("output", - TensorProto.FLOAT, list(outdata.shape))]) - model = helper.make_model(graph, producer_name='pad_test') + graph = helper.make_graph( + [node], + "pad_test", + inputs=[ + helper.make_tensor_value_info("input", TensorProto.FLOAT, list(indata.shape)), + helper.make_tensor_value_info("pads", TensorProto.INT64, (len(pads),)), + helper.make_tensor_value_info("constant_value", TensorProto.FLOAT, (1,)), + ], + initializer=[ + helper.make_tensor("pads", TensorProto.INT64, (len(pads),), pads), + helper.make_tensor("constant_value", TensorProto.FLOAT, (1,), [value]), + ], + outputs=[ + helper.make_tensor_value_info("output", TensorProto.FLOAT, list(outdata.shape)) + ], + ) + model = helper.make_model(graph, producer_name="pad_test") # tvm result for target, ctx in tvm.testing.enabled_targets(): tvm_out = get_tvm_output_with_vm(model, inputs, target, ctx, opset=11, freeze_params=False) @@ -1499,27 +1528,19 @@ def verify_pad_v11(indata, pads, mode='constant', value=0.0): # TODO(mbrookhart): enable once VM supports heterogenous execution # @tvm.testing.uses_gpu def test_pad(): - verify_pad(np.random.randn(2, 2).astype( - np.float32), [0, 1, 0, 0], 'constant', 0.0) - verify_pad(np.random.randn(2, 3).astype( - np.float32), [1, 0, 0, 1], 'constant', 0.0) - verify_pad(np.random.randn(3, 2).astype( - np.float32), [0, 0, 1, 0], 'constant', 5.0) - verify_pad(np.random.randn(1, 3, 4, 5).astype( - np.float32), [0, 0, 1, 1, 0, 0, 1, 1], 'edge') - verify_pad(np.random.randn(1, 3, 4, 5).astype( - np.float32), [0, 0, 1, 1, 0, 0, 1, 1], 'reflect') - - verify_pad_v11(np.random.randn(2, 2).astype( - np.float32), [0, 1, 0, 0], 'constant', 0.0) - verify_pad_v11(np.random.randn(2, 3).astype( - np.float32), [1, 0, 0, 1], 'constant', 0.0) - verify_pad_v11(np.random.randn(3, 2).astype( - np.float32), [0, 0, 1, 0], 'constant', 5.0) - verify_pad_v11(np.random.randn(1, 3, 4, 5).astype( - np.float32), [0, 0, 1, 1, 0, 0, 1, 1], 'edge') - verify_pad_v11(np.random.randn(1, 3, 4, 5).astype( - np.float32), [0, 0, 1, 1, 0, 0, 1, 1], 'reflect') + verify_pad(np.random.randn(2, 2).astype(np.float32), [0, 1, 0, 0], "constant", 0.0) + verify_pad(np.random.randn(2, 3).astype(np.float32), [1, 0, 0, 1], "constant", 0.0) + verify_pad(np.random.randn(3, 2).astype(np.float32), [0, 0, 1, 0], "constant", 5.0) + verify_pad(np.random.randn(1, 3, 4, 5).astype(np.float32), [0, 0, 1, 1, 0, 0, 1, 1], "edge") + verify_pad(np.random.randn(1, 3, 4, 5).astype(np.float32), [0, 0, 1, 1, 0, 0, 1, 1], "reflect") + + verify_pad_v11(np.random.randn(2, 2).astype(np.float32), [0, 1, 0, 0], "constant", 0.0) + verify_pad_v11(np.random.randn(2, 3).astype(np.float32), [1, 0, 0, 1], "constant", 0.0) + verify_pad_v11(np.random.randn(3, 2).astype(np.float32), [0, 0, 1, 0], "constant", 5.0) + verify_pad_v11(np.random.randn(1, 3, 4, 5).astype(np.float32), [0, 0, 1, 1, 0, 0, 1, 1], "edge") + verify_pad_v11( + np.random.randn(1, 3, 4, 5).astype(np.float32), [0, 0, 1, 1, 0, 0, 1, 1], "reflect" + ) def verify_reduce_func(func, data, axis, keepdims): @@ -1527,67 +1548,67 @@ def verify_reduce_func(func, data, axis, keepdims): outshape = np.sum(data, axis=axis, keepdims=keepdims == 1).shape if axis: - node = onnx.helper.make_node(func, - inputs=['x'], - outputs=['y'], - axes=axis, - keepdims=keepdims) + node = onnx.helper.make_node( + func, inputs=["x"], outputs=["y"], axes=axis, keepdims=keepdims + ) else: - node = onnx.helper.make_node(func, - inputs=['x'], - outputs=['y'], - 
keepdims=keepdims) + node = onnx.helper.make_node(func, inputs=["x"], outputs=["y"], keepdims=keepdims) - graph = helper.make_graph([node], - "reduce_test", - inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, list(inshape))], - outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, list(outshape))]) + graph = helper.make_graph( + [node], + "reduce_test", + inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, list(inshape))], + outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, list(outshape))], + ) - model = helper.make_model(graph, producer_name='reduce_test') + model = helper.make_model(graph, producer_name="reduce_test") - onnx_out = get_onnxruntime_output(model, data, 'float32') + onnx_out = get_onnxruntime_output(model, data, "float32") for target, ctx in tvm.testing.enabled_targets(): - tvm_out = get_tvm_output(model, data, target, ctx, outshape, 'float32') + tvm_out = get_tvm_output(model, data, target, ctx, outshape, "float32") tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-5, atol=1e-5) + @tvm.testing.uses_gpu def test_all_reduce_funcs(): - funcs = ["ReduceMax", - "ReduceMean", - "ReduceMin", - "ReduceProd", - "ReduceSum", - 'ReduceSumSquare', - "ReduceLogSum", - "ReduceLogSumExp", - "ReduceL1", - "ReduceL2"] + funcs = [ + "ReduceMax", + "ReduceMean", + "ReduceMin", + "ReduceProd", + "ReduceSum", + "ReduceSumSquare", + "ReduceLogSum", + "ReduceLogSumExp", + "ReduceL1", + "ReduceL2", + ] for func in funcs: for keepdims in [True, False]: - verify_reduce_func(func, - np.random.randn(3, 2, 2).astype(np.float32), - axis=None, keepdims=keepdims) + verify_reduce_func( + func, np.random.randn(3, 2, 2).astype(np.float32), axis=None, keepdims=keepdims + ) - verify_reduce_func(func, - np.random.randn(3, 2, 3).astype(np.float32), - axis=None, keepdims=keepdims) + verify_reduce_func( + func, np.random.randn(3, 2, 3).astype(np.float32), axis=None, keepdims=keepdims + ) - verify_reduce_func(func, - np.random.randn(3, 3, 3).astype(np.float32), - axis=(1,), keepdims=keepdims) + verify_reduce_func( + func, np.random.randn(3, 3, 3).astype(np.float32), axis=(1,), keepdims=keepdims + ) - verify_reduce_func(func, - np.random.randn(3, 3, 3, 1).astype(np.float32), - axis=(1, 2), keepdims=keepdims) + verify_reduce_func( + func, np.random.randn(3, 3, 3, 1).astype(np.float32), axis=(1, 2), keepdims=keepdims + ) - verify_reduce_func(func, - np.random.randn(3, 3, 3, 1).astype(np.float32), - axis=(1,), keepdims=keepdims) + verify_reduce_func( + func, np.random.randn(3, 3, 3, 1).astype(np.float32), axis=(1,), keepdims=keepdims + ) - verify_reduce_func(func, - np.random.randn(1, 3, 4, 1).astype(np.float32), - axis=(1,), keepdims=keepdims) + verify_reduce_func( + func, np.random.randn(1, 3, 4, 1).astype(np.float32), axis=(1,), keepdims=keepdims + ) def verify_split(indata, outdatas, split, axis=0, pass_split=True): @@ -1599,38 +1620,41 @@ def verify_split(indata, outdatas, split, axis=0, pass_split=True): split_index = range(len(outdatas)) if pass_split: node = helper.make_node( - 'Split', - inputs=['input'], - outputs=['output_{}'.format(i) for i in range(len(split_index))], + "Split", + inputs=["input"], + outputs=["output_{}".format(i) for i in range(len(split_index))], axis=axis, - split=split + split=split, ) else: node = helper.make_node( - 'Split', - inputs=['input'], - outputs=['output_{}'.format(i) for i in range(len(split_index))], + "Split", + inputs=["input"], + outputs=["output_{}".format(i) for i in range(len(split_index))], axis=axis, ) - graph = 
helper.make_graph([node], - 'split_test', - inputs=[helper.make_tensor_value_info("input", - TensorProto.FLOAT, list(indata.shape))], - outputs=[helper.make_tensor_value_info("output_{}".format(i), - TensorProto.FLOAT, list(outdatas[i].shape)) - for i in range(len(split_index)) - ]) - model = helper.make_model(graph, producer_name='split_test') + graph = helper.make_graph( + [node], + "split_test", + inputs=[helper.make_tensor_value_info("input", TensorProto.FLOAT, list(indata.shape))], + outputs=[ + helper.make_tensor_value_info( + "output_{}".format(i), TensorProto.FLOAT, list(outdatas[i].shape) + ) + for i in range(len(split_index)) + ], + ) + model = helper.make_model(graph, producer_name="split_test") import onnxruntime.backend - rep = onnxruntime.backend.prepare(model, 'CPU') + + rep = onnxruntime.backend.prepare(model, "CPU") onnx_out = rep.run(indata) for target, ctx in tvm.testing.enabled_targets(): output_shape = [o.shape for o in outdatas] - output_type = ['float32', 'float32', 'float32'] - tvm_out = get_tvm_output( - model, indata, target, ctx, output_shape, output_type) + output_type = ["float32", "float32", "float32"] + tvm_out = get_tvm_output(model, indata, target, ctx, output_shape, output_type) for o, t in zip(onnx_out, tvm_out): tvm.testing.assert_allclose(o, t) @@ -1638,15 +1662,18 @@ def verify_split(indata, outdatas, split, axis=0, pass_split=True): @tvm.testing.uses_gpu def test_split(): # 1D - verify_split([1., 2., 3., 4., 5., 6.], [ - [1., 2.], [3., 4.], [5., 6.]], [2, 2, 2], 0) - verify_split([1., 2., 3., 4., 5., 6.], [ - [1., 2.], [3., 4.], [5., 6.]], [2, 2, 2], 0, False) - verify_split([1., 2., 3., 4., 5., 6.], [ - [1., 2.], [3.], [4., 5., 6.]], [2, 1, 3], 0) + verify_split([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], [2, 2, 2], 0) + verify_split( + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0], [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], [2, 2, 2], 0, False + ) + verify_split([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], [[1.0, 2.0], [3.0], [4.0, 5.0, 6.0]], [2, 1, 3], 0) # 2D - verify_split([[1., 2., 3., 4.], [7., 8., 9., 10.]], - [[[1., 2.], [7., 8.]], [[3., 4.], [9., 10.]]], [2, 2], 1) + verify_split( + [[1.0, 2.0, 3.0, 4.0], [7.0, 8.0, 9.0, 10.0]], + [[[1.0, 2.0], [7.0, 8.0]], [[3.0, 4.0], [9.0, 10.0]]], + [2, 2], + 1, + ) # Split evenly (unstack) verify_split([1, 2, 3], [[1], [2], [3]], False, 0, False) @@ -1657,20 +1684,21 @@ def test_binary_ops(): dtype = "float32" out_shape = in_shape - def verify_binary_ops(op, x, y, out_np, x_name='in1', y_name='in2', broadcast=None): + def verify_binary_ops(op, x, y, out_np, x_name="in1", y_name="in2", broadcast=None): if broadcast is None: - z = helper.make_node(op, [x_name, y_name], ['out']) + z = helper.make_node(op, [x_name, y_name], ["out"]) else: - z = helper.make_node(op, [x_name, y_name], ['out'], broadcast=1) - graph = helper.make_graph([z], - '_test', - inputs=[helper.make_tensor_value_info(x_name, - TensorProto.FLOAT, list(in_shape)), - helper.make_tensor_value_info(y_name, - TensorProto.FLOAT, list(in_shape))], - outputs=[helper.make_tensor_value_info("out", - TensorProto.FLOAT, list(out_shape))]) - model = helper.make_model(graph, producer_name='_test') + z = helper.make_node(op, [x_name, y_name], ["out"], broadcast=1) + graph = helper.make_graph( + [z], + "_test", + inputs=[ + helper.make_tensor_value_info(x_name, TensorProto.FLOAT, list(in_shape)), + helper.make_tensor_value_info(y_name, TensorProto.FLOAT, list(in_shape)), + ], + outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(out_shape))], 
+ ) + model = helper.make_model(graph, producer_name="_test") for target, ctx in tvm.testing.enabled_targets(): tvm_out = get_tvm_output(model, [x, y], target, ctx) tvm.testing.assert_allclose(out_np, tvm_out, rtol=1e-5, atol=1e-5) @@ -1679,12 +1707,12 @@ def verify_binary_ops(op, x, y, out_np, x_name='in1', y_name='in2', broadcast=No y = np.random.uniform(size=in_shape).astype(dtype) z = np.random.uniform(size=(3,)).astype(dtype) verify_binary_ops("Add", x, y, x + y, broadcast=None) - verify_binary_ops("Add", x, z, x + z, broadcast=True) + verify_binary_ops("Add", x, z, x + z, broadcast=True) verify_binary_ops("Sub", x, y, x - y, broadcast=None) verify_binary_ops("Sub", x, z, x - z, broadcast=True) verify_binary_ops("Mul", x, y, x * y, broadcast=None) - verify_binary_ops("Mul", x, z, x * z, broadcast=True) - verify_binary_ops("Mul", x, x, x * x, x_name='in1', y_name='in1', broadcast=None) + verify_binary_ops("Mul", x, z, x * z, broadcast=True) + verify_binary_ops("Mul", x, x, x * x, x_name="in1", y_name="in1", broadcast=None) verify_binary_ops("Div", x, y, x / y, broadcast=None) verify_binary_ops("Div", x, z, x / z, broadcast=True) verify_binary_ops("Sum", x, y, x + y, broadcast=None) @@ -1700,14 +1728,16 @@ def test_single_ops(): out_shape = in_shape def verify_single_ops(op, x, out_np, rtol=1e-5, atol=1e-5): - z = helper.make_node(op, ['in1'], ['out']) - graph = helper.make_graph([z], - '_test', - inputs=[helper.make_tensor_value_info("in1", - TensorProto.FLOAT, list(in_shape)), ], - outputs=[helper.make_tensor_value_info("out", - TensorProto.FLOAT, list(out_shape))]) - model = helper.make_model(graph, producer_name='_test') + z = helper.make_node(op, ["in1"], ["out"]) + graph = helper.make_graph( + [z], + "_test", + inputs=[ + helper.make_tensor_value_info("in1", TensorProto.FLOAT, list(in_shape)), + ], + outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(out_shape))], + ) + model = helper.make_model(graph, producer_name="_test") for target, ctx in tvm.testing.enabled_targets(): tvm_out = get_tvm_output(model, [x], target, ctx) tvm.testing.assert_allclose(out_np, tvm_out, rtol=rtol, atol=atol) @@ -1715,7 +1745,7 @@ def verify_single_ops(op, x, out_np, rtol=1e-5, atol=1e-5): x = np.random.uniform(size=in_shape).astype(dtype) verify_single_ops("Neg", x, -x) verify_single_ops("Abs", x, np.abs(x)) - verify_single_ops("Reciprocal", x, 1/x) + verify_single_ops("Reciprocal", x, 1 / x) verify_single_ops("Sqrt", x, np.sqrt(x)) verify_single_ops("Relu", x, np.maximum(x, 0)) verify_single_ops("Exp", x, np.exp(x)) @@ -1742,65 +1772,67 @@ def verify_single_ops(op, x, out_np, rtol=1e-5, atol=1e-5): def test_leaky_relu(): def leaky_relu_x(x, alpha): return np.where(x >= 0, x, x * alpha) - _test_onnx_op_elementwise((2, 4, 5, 6), - leaky_relu_x, - {'alpha': 0.25}, - 'float32', - 'LeakyRelu', - {'alpha': 0.25}) + + _test_onnx_op_elementwise( + (2, 4, 5, 6), leaky_relu_x, {"alpha": 0.25}, "float32", "LeakyRelu", {"alpha": 0.25} + ) @tvm.testing.uses_gpu def test_elu(): def elu_x(x, alpha): return np.where(x > 0, x, alpha * (np.exp(x) - 1.0)) - _test_onnx_op_elementwise((2, 4, 5, 6), - elu_x, - {'alpha': 0.25}, - 'float32', - 'Elu', - {'alpha': 0.25}) + + _test_onnx_op_elementwise( + (2, 4, 5, 6), elu_x, {"alpha": 0.25}, "float32", "Elu", {"alpha": 0.25} + ) @tvm.testing.uses_gpu def test_selu(): def selu_x(x, alpha, gamma): return gamma * np.where(x > 0, x, alpha * (np.exp(x) - 1.0)) - _test_onnx_op_elementwise((2, 4, 5, 6), - selu_x, - {'alpha': 0.25, 'gamma': 0.3}, - 'float32', - 
'Selu', - {'alpha': 0.25, 'gamma': 0.3}) + + _test_onnx_op_elementwise( + (2, 4, 5, 6), + selu_x, + {"alpha": 0.25, "gamma": 0.3}, + "float32", + "Selu", + {"alpha": 0.25, "gamma": 0.3}, + ) @tvm.testing.uses_gpu def test_prelu(): def verify_prelu(x_shape, a_shape): - node = helper.make_node('PRelu', - inputs=['X', 'slope'], - outputs=['Y']) - - graph = helper.make_graph([node], - "prelu_test", - inputs=[helper.make_tensor_value_info("X", TensorProto.FLOAT, list(x_shape)), - helper.make_tensor_value_info("slope", TensorProto.FLOAT, list(a_shape))], - outputs=[helper.make_tensor_value_info("Y", TensorProto.FLOAT, list(x_shape))]) + node = helper.make_node("PRelu", inputs=["X", "slope"], outputs=["Y"]) + + graph = helper.make_graph( + [node], + "prelu_test", + inputs=[ + helper.make_tensor_value_info("X", TensorProto.FLOAT, list(x_shape)), + helper.make_tensor_value_info("slope", TensorProto.FLOAT, list(a_shape)), + ], + outputs=[helper.make_tensor_value_info("Y", TensorProto.FLOAT, list(x_shape))], + ) - model = helper.make_model(graph, producer_name='prelu_test') + model = helper.make_model(graph, producer_name="prelu_test") indata = np.random.uniform(-10, 10, x_shape).astype(np.float32) slopedata = np.random.uniform(-10, 10, a_shape).astype(np.float32) onnx_out = get_onnxruntime_output(model, [indata, slopedata]) - for target, ctx in [('llvm', tvm.cpu())]: - tvm_out = get_tvm_output(model, [indata, slopedata], target, ctx, list(x_shape), - output_dtype='float32') + for target, ctx in [("llvm", tvm.cpu())]: + tvm_out = get_tvm_output( + model, [indata, slopedata], target, ctx, list(x_shape), output_dtype="float32" + ) tvm.testing.assert_allclose(onnx_out[0], tvm_out, rtol=1e-05, atol=1e-05) - verify_prelu([3,4,5,6], [1, 4, 1, 1]) - verify_prelu([1,8,5,6], [1, 8, 1, 1]) - verify_prelu([2,12,16,16], [1, 12, 1, 1]) + verify_prelu([3, 4, 5, 6], [1, 4, 1, 1]) + verify_prelu([1, 8, 5, 6], [1, 8, 1, 1]) + verify_prelu([2, 12, 16, 16], [1, 12, 1, 1]) @tvm.testing.uses_gpu @@ -1809,69 +1841,72 @@ def ThresholdedRelu_x(x, alpha): out_np = np.clip(x, alpha, np.inf) out_np[out_np == alpha] = 0 return out_np - _test_onnx_op_elementwise((2, 4, 5, 6), - ThresholdedRelu_x, - {'alpha': 0.25}, - 'float32', - 'ThresholdedRelu', - {'alpha': 0.25}) + + _test_onnx_op_elementwise( + (2, 4, 5, 6), + ThresholdedRelu_x, + {"alpha": 0.25}, + "float32", + "ThresholdedRelu", + {"alpha": 0.25}, + ) @tvm.testing.uses_gpu def test_ScaledTanh(): def ScaledTanh_x(x, alpha, beta): return alpha * np.tanh(beta * x) - _test_onnx_op_elementwise((2, 4, 5, 6), - ScaledTanh_x, - {'alpha': 0.25, 'beta': 0.3}, - 'float32', - 'ScaledTanh', - {'alpha': 0.25, 'beta': 0.3}) + + _test_onnx_op_elementwise( + (2, 4, 5, 6), + ScaledTanh_x, + {"alpha": 0.25, "beta": 0.3}, + "float32", + "ScaledTanh", + {"alpha": 0.25, "beta": 0.3}, + ) @tvm.testing.uses_gpu def test_ParametricSoftplus(): def ParametricSoftplus_x(x, alpha, beta): return alpha * np.log(np.exp(beta * x) + 1) - _test_onnx_op_elementwise((2, 4, 5, 6), - ParametricSoftplus_x, - {'alpha': 0.25, 'beta': 0.3}, - 'float32', - 'ParametricSoftplus', - {'alpha': 0.25, 'beta': 0.3}) + + _test_onnx_op_elementwise( + (2, 4, 5, 6), + ParametricSoftplus_x, + {"alpha": 0.25, "beta": 0.3}, + "float32", + "ParametricSoftplus", + {"alpha": 0.25, "beta": 0.3}, + ) @tvm.testing.uses_gpu def test_Scale(): def Scale_x(x, scale): return scale * x - _test_onnx_op_elementwise((2, 4, 5, 6), - Scale_x, - {'scale': 0.25}, - 'float32', - 'Scale', - {'scale': 0.25}) + + _test_onnx_op_elementwise( + (2, 4, 5, 
6), Scale_x, {"scale": 0.25}, "float32", "Scale", {"scale": 0.25} + ) @tvm.testing.uses_gpu def test_LogSoftmax(): - _test_onnx_op_elementwise((1, 4), - tvm.topi.testing.log_softmax_python, - {}, - 'float32', - 'LogSoftmax', - {'axis': 1}) + _test_onnx_op_elementwise( + (1, 4), tvm.topi.testing.log_softmax_python, {}, "float32", "LogSoftmax", {"axis": 1} + ) def check_torch_conversion(model, input_size): dummy_input = torch.randn(*input_size) - file_name = '{}.onnx'.format(model.__name__) + file_name = "{}.onnx".format(model.__name__) # Set verbose=True for more output - torch.onnx.export(model(), dummy_input, file_name, - export_params=True, verbose=False) + torch.onnx.export(model(), dummy_input, file_name, export_params=True, verbose=False) onnx_model = onnx.load(file_name) for target, ctx in tvm.testing.enabled_targets(): - input_data = np.random.uniform(size=input_size).astype('int32') + input_data = np.random.uniform(size=input_size).astype("int32") c2_out = get_onnxruntime_output(onnx_model, input_data) tvm_out = get_tvm_output(onnx_model, input_data, target, ctx) tvm.testing.assert_allclose(c2_out, tvm_out) @@ -1882,6 +1917,7 @@ def test_resnet(): check_torch_conversion(torchvision.models.resnet18, (1, 3, 224, 224)) # check_torch_conversion(torchvision.models.resnet101, (1,3,224,224)) + # def test_alexnet(): # Torch's ONNX export does not support the adaptive pooling used by AlexNet? # check_torch_conversion(torchvision.models.alexnet, (1,3,224,224)) @@ -1905,6 +1941,7 @@ def test_densenet(): def test_inception(): check_torch_conversion(torchvision.models.inception_v3, (1, 3, 224, 224)) + # TODO(@jroesch): Update Torch + ONNX to support this import. # def test_googlenet(): # check_torch_conversion(torchvision.models.googlenet, (1,3,224,224)) @@ -1918,27 +1955,28 @@ def test_inception(): def test_sign(): def Sign_x(x): return np.sign(x) - _test_onnx_op_elementwise((3, 4, 5, 6), - Sign_x, - {}, - 'float32', - 'Sign', - {}) + + _test_onnx_op_elementwise((3, 4, 5, 6), Sign_x, {}, "float32", "Sign", {}) def verify_not(indata, dtype): x = indata.astype(dtype) outdata = np.logical_not(x) - node = helper.make_node('Not', inputs=['in'], outputs=['out'],) + node = helper.make_node( + "Not", + inputs=["in"], + outputs=["out"], + ) - graph = helper.make_graph([node], - 'not_test', - inputs=[helper.make_tensor_value_info( - "in", TensorProto.BOOL, list(x.shape))], - outputs=[helper.make_tensor_value_info("out", TensorProto.BOOL, list(outdata.shape))]) + graph = helper.make_graph( + [node], + "not_test", + inputs=[helper.make_tensor_value_info("in", TensorProto.BOOL, list(x.shape))], + outputs=[helper.make_tensor_value_info("out", TensorProto.BOOL, list(outdata.shape))], + ) - model = helper.make_model(graph, producer_name='not_test') + model = helper.make_model(graph, producer_name="not_test") for target, ctx in tvm.testing.enabled_targets(): tvm_out = get_tvm_output(model, [x], target, ctx, outdata.shape) @@ -1960,15 +1998,23 @@ def verify_and(indata, dtype): y = indata[1].astype(dtype) outdata = np.logical_and(x, y) - node = helper.make_node('And', inputs=['in1', 'in2'], outputs=['out'], ) + node = helper.make_node( + "And", + inputs=["in1", "in2"], + outputs=["out"], + ) - graph = helper.make_graph([node], - 'and_test', - inputs=[helper.make_tensor_value_info("in1", TensorProto.BOOL, list(x.shape)), - helper.make_tensor_value_info("in2", TensorProto.BOOL, list(y.shape))], - outputs=[helper.make_tensor_value_info("out", TensorProto.BOOL, list(outdata.shape))]) + graph = helper.make_graph( + 
[node], + "and_test", + inputs=[ + helper.make_tensor_value_info("in1", TensorProto.BOOL, list(x.shape)), + helper.make_tensor_value_info("in2", TensorProto.BOOL, list(y.shape)), + ], + outputs=[helper.make_tensor_value_info("out", TensorProto.BOOL, list(outdata.shape))], + ) - model = helper.make_model(graph, producer_name='and_test') + model = helper.make_model(graph, producer_name="and_test") for target, ctx in tvm.testing.enabled_targets(): tvm_out = get_tvm_output(model, [x, y], target, ctx, outdata.shape) @@ -1978,93 +2024,85 @@ def verify_and(indata, dtype): @tvm.testing.uses_gpu def test_and(): # 2d - x = (np.random.randn(3, 4) > 0) - y = (np.random.randn(3, 4) > 0) + x = np.random.randn(3, 4) > 0 + y = np.random.randn(3, 4) > 0 verify_and(indata=[x, y], dtype=bool) # 3d - x = (np.random.randn(3, 4, 5) > 0) - y = (np.random.randn(3, 4, 5) > 0) + x = np.random.randn(3, 4, 5) > 0 + y = np.random.randn(3, 4, 5) > 0 verify_and(indata=[x, y], dtype=bool) # 4d - x = (np.random.randn(3, 4, 5, 6) > 0) - y = (np.random.randn(3, 4, 5, 6) > 0) + x = np.random.randn(3, 4, 5, 6) > 0 + y = np.random.randn(3, 4, 5, 6) > 0 verify_and(indata=[x, y], dtype=bool) # 3d vs 1d - x = (np.random.randn(3, 4, 5) > 0) - y = (np.random.randn(5) > 0) + x = np.random.randn(3, 4, 5) > 0 + y = np.random.randn(5) > 0 verify_and(indata=[x, y], dtype=bool) # 3d vs 2d - x = (np.random.randn(3, 4, 5) > 0) - y = (np.random.randn(4, 5) > 0) + x = np.random.randn(3, 4, 5) > 0 + y = np.random.randn(4, 5) > 0 verify_and(indata=[x, y], dtype=bool) def verify_tile_v1(indata, outdata, **kwargs): - node = helper.make_node('Tile', inputs=['in'], outputs=['out'], **kwargs) - graph = helper.make_graph([node], - 'tile_test', - inputs=[helper.make_tensor_value_info( - "in", TensorProto.FLOAT, list(indata.shape))], - outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(outdata.shape))]) + node = helper.make_node("Tile", inputs=["in"], outputs=["out"], **kwargs) + graph = helper.make_graph( + [node], + "tile_test", + inputs=[helper.make_tensor_value_info("in", TensorProto.FLOAT, list(indata.shape))], + outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(outdata.shape))], + ) - model = helper.make_model(graph, producer_name='tile_test') + model = helper.make_model(graph, producer_name="tile_test") for target, ctx in tvm.testing.enabled_targets(): - tvm_out = get_tvm_output( - model, [indata], target, ctx, outdata.shape, opset=1) + tvm_out = get_tvm_output(model, [indata], target, ctx, outdata.shape, opset=1) tvm.testing.assert_allclose(outdata, tvm_out) def verify_tile_v6(indata, repeats, outdata): - node = helper.make_node('Tile', - inputs=['input', 'repeats'], - outputs=['out']) + node = helper.make_node("Tile", inputs=["input", "repeats"], outputs=["out"]) graph = helper.make_graph( [node], - 'tile_test', + "tile_test", inputs=[ - helper.make_tensor_value_info("input", TensorProto.FLOAT, - list(indata.shape)), - helper.make_tensor_value_info("repeats", TensorProto.INT64, - list(repeats.shape)) + helper.make_tensor_value_info("input", TensorProto.FLOAT, list(indata.shape)), + helper.make_tensor_value_info("repeats", TensorProto.INT64, list(repeats.shape)), ], - outputs=[ - helper.make_tensor_value_info("out", TensorProto.FLOAT, - list(outdata.shape)) - ]) + outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(outdata.shape))], + ) - model = helper.make_model(graph, producer_name='tile_test') + model = helper.make_model(graph, producer_name="tile_test") for target, ctx in 
tvm.testing.enabled_targets(): - tvm_out = get_tvm_output_with_vm(model, [indata, repeats], - target, - ctx, - opset=6) + tvm_out = get_tvm_output_with_vm(model, [indata, repeats], target, ctx, opset=6) tvm.testing.assert_allclose(outdata, tvm_out) + # TODO(mbrookhart): enable once VM supports heterogenous execution # @tvm.testing.uses_gpu def test_tile(): x = np.random.rand(2, 3, 4, 5).astype(np.float32) - repeats = np.random.randint( - low=1, high=10, size=(np.ndim(x),)).astype(np.int64) + repeats = np.random.randint(low=1, high=10, size=(np.ndim(x),)).astype(np.int64) z = np.tile(x, repeats) verify_tile_v1(x, z, repeats=repeats) verify_tile_v6(x, repeats, z) def verify_erf(indata, outdata): - node = helper.make_node('Erf', inputs=['in'], outputs=['out']) - graph = helper.make_graph([node], - 'erf_test', - inputs=[helper.make_tensor_value_info( - 'in', TensorProto.FLOAT, list(indata.shape))], - outputs=[helper.make_tensor_value_info('out', TensorProto.FLOAT, list(outdata.shape))]) - model = helper.make_model(graph, producer_name='erf_test') + node = helper.make_node("Erf", inputs=["in"], outputs=["out"]) + graph = helper.make_graph( + [node], + "erf_test", + inputs=[helper.make_tensor_value_info("in", TensorProto.FLOAT, list(indata.shape))], + outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(outdata.shape))], + ) + model = helper.make_model(graph, producer_name="erf_test") for target, ctx in tvm.testing.enabled_targets(): tvm_out = get_tvm_output(model, [indata], target, ctx, outdata.shape) @@ -2079,14 +2117,18 @@ def test_erf(): def verify_where(condition, x, y, dtype, outdata): - node = helper.make_node('Where', inputs=['condition', 'x', 'y'], outputs=['out']) - graph = helper.make_graph([node], - 'where_test', - inputs=[helper.make_tensor_value_info('condition', TensorProto.BOOL, list(condition.shape)), - helper.make_tensor_value_info('x', dtype, list(x.shape)), - helper.make_tensor_value_info('y', dtype, list(y.shape))], - outputs=[helper.make_tensor_value_info('out', dtype, list(outdata.shape))]) - model = helper.make_model(graph, producer_name='where_test') + node = helper.make_node("Where", inputs=["condition", "x", "y"], outputs=["out"]) + graph = helper.make_graph( + [node], + "where_test", + inputs=[ + helper.make_tensor_value_info("condition", TensorProto.BOOL, list(condition.shape)), + helper.make_tensor_value_info("x", dtype, list(x.shape)), + helper.make_tensor_value_info("y", dtype, list(y.shape)), + ], + outputs=[helper.make_tensor_value_info("out", dtype, list(outdata.shape))], + ) + model = helper.make_model(graph, producer_name="where_test") for target, ctx in tvm.testing.enabled_targets(): tvm_out = get_tvm_output(model, [condition, x, y], target, ctx, outdata.shape) @@ -2133,15 +2175,23 @@ def verify_or(indata, dtype): y = indata[1].astype(dtype) outdata = np.logical_or(x, y) - node = helper.make_node('Or', inputs=['in1', 'in2'], outputs=['out'], ) + node = helper.make_node( + "Or", + inputs=["in1", "in2"], + outputs=["out"], + ) - graph = helper.make_graph([node], - 'or_test', - inputs=[helper.make_tensor_value_info("in1", TensorProto.BOOL, list(x.shape)), - helper.make_tensor_value_info("in2", TensorProto.BOOL, list(y.shape))], - outputs=[helper.make_tensor_value_info("out", TensorProto.BOOL, list(outdata.shape))]) + graph = helper.make_graph( + [node], + "or_test", + inputs=[ + helper.make_tensor_value_info("in1", TensorProto.BOOL, list(x.shape)), + helper.make_tensor_value_info("in2", TensorProto.BOOL, list(y.shape)), + ], + 
outputs=[helper.make_tensor_value_info("out", TensorProto.BOOL, list(outdata.shape))], + ) - model = helper.make_model(graph, producer_name='or_test') + model = helper.make_model(graph, producer_name="or_test") for target, ctx in tvm.testing.enabled_targets(): tvm_out = get_tvm_output(model, [x, y], target, ctx, outdata.shape) @@ -2151,64 +2201,63 @@ def verify_or(indata, dtype): @tvm.testing.uses_gpu def test_or(): # 2d - x = (np.random.randn(3, 4) > 0) - y = (np.random.randn(3, 4) > 0) + x = np.random.randn(3, 4) > 0 + y = np.random.randn(3, 4) > 0 verify_or(indata=[x, y], dtype=bool) # 3d - x = (np.random.randn(3, 4, 5) > 0) - y = (np.random.randn(3, 4, 5) > 0) + x = np.random.randn(3, 4, 5) > 0 + y = np.random.randn(3, 4, 5) > 0 verify_or(indata=[x, y], dtype=bool) # 4d - x = (np.random.randn(3, 4, 5, 6) > 0) - y = (np.random.randn(3, 4, 5, 6) > 0) + x = np.random.randn(3, 4, 5, 6) > 0 + y = np.random.randn(3, 4, 5, 6) > 0 verify_or(indata=[x, y], dtype=bool) # 3d vs 1d - x = (np.random.randn(3, 4, 5) > 0) - y = (np.random.randn(5) > 0) + x = np.random.randn(3, 4, 5) > 0 + y = np.random.randn(5) > 0 verify_or(indata=[x, y], dtype=bool) # 3d vs 2d - x = (np.random.randn(3, 4, 5) > 0) - y = (np.random.randn(4, 5) > 0) + x = np.random.randn(3, 4, 5) > 0 + y = np.random.randn(4, 5) > 0 verify_or(indata=[x, y], dtype=bool) @tvm.testing.uses_gpu def test_batch_norm(): def verify_batch_norm(in_shape): - batchnorm = onnx.helper.make_node('BatchNormalization', - inputs=["x", "scale", "B", "mean", "var"], - outputs=['Y']) - - graph = helper.make_graph([batchnorm], - "batchnorm_test", - inputs=[helper.make_tensor_value_info("x", - TensorProto.FLOAT, list(in_shape)), - helper.make_tensor_value_info("scale", - TensorProto.FLOAT, [in_shape[1]]), - helper.make_tensor_value_info("B", - TensorProto.FLOAT, [in_shape[1]]), - helper.make_tensor_value_info("mean", - TensorProto.FLOAT, [in_shape[1]]), - helper.make_tensor_value_info("var", - TensorProto.FLOAT, [in_shape[1]]), - ], - outputs=[helper.make_tensor_value_info("Y", - TensorProto.FLOAT, list(in_shape))]) - - model = helper.make_model(graph, producer_name='batchnorm_test') + batchnorm = onnx.helper.make_node( + "BatchNormalization", inputs=["x", "scale", "B", "mean", "var"], outputs=["Y"] + ) + + graph = helper.make_graph( + [batchnorm], + "batchnorm_test", + inputs=[ + helper.make_tensor_value_info("x", TensorProto.FLOAT, list(in_shape)), + helper.make_tensor_value_info("scale", TensorProto.FLOAT, [in_shape[1]]), + helper.make_tensor_value_info("B", TensorProto.FLOAT, [in_shape[1]]), + helper.make_tensor_value_info("mean", TensorProto.FLOAT, [in_shape[1]]), + helper.make_tensor_value_info("var", TensorProto.FLOAT, [in_shape[1]]), + ], + outputs=[helper.make_tensor_value_info("Y", TensorProto.FLOAT, list(in_shape))], + ) + + model = helper.make_model(graph, producer_name="batchnorm_test") for target, ctx in tvm.testing.enabled_targets(): - x = np.random.uniform(size=in_shape).astype('float32') - scale = np.random.uniform(size=in_shape[1]).astype('float32') - b = np.random.uniform(size=in_shape[1]).astype('float32') - mean = np.random.uniform(size=in_shape[1]).astype('float32') - var = np.random.uniform(size=in_shape[1]).astype('float32') - onnx_out = get_onnxruntime_output(model, [x, scale, b, mean, var], 'float32')[0] - tvm_out = get_tvm_output(model, [x, scale, b, mean, var], target, ctx, in_shape, 'float32') + x = np.random.uniform(size=in_shape).astype("float32") + scale = np.random.uniform(size=in_shape[1]).astype("float32") + b = 
np.random.uniform(size=in_shape[1]).astype("float32") + mean = np.random.uniform(size=in_shape[1]).astype("float32") + var = np.random.uniform(size=in_shape[1]).astype("float32") + onnx_out = get_onnxruntime_output(model, [x, scale, b, mean, var], "float32")[0] + tvm_out = get_tvm_output( + model, [x, scale, b, mean, var], target, ctx, in_shape, "float32" + ) tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-5, atol=1e-5) verify_batch_norm([1, 3, 224, 224]) @@ -2223,91 +2272,106 @@ def verify_batch_norm(in_shape): def test_batch_norm_dynamic_subgraph(): def verify_batch_norm_dynamic_subgraph(in_shape, o_shape): - batchnorm = onnx.helper.make_node('BatchNormalization', - inputs=["x", "scale", "B", "mean", "var"], - outputs=['Y']) + batchnorm = onnx.helper.make_node( + "BatchNormalization", inputs=["x", "scale", "B", "mean", "var"], outputs=["Y"] + ) - shape_node = helper.make_node("Shape", ['Y'], ['shape']) + shape_node = helper.make_node("Shape", ["Y"], ["shape"]) reshape_node = helper.make_node("Reshape", ["in", "shape"], ["out"]) - graph = helper.make_graph([batchnorm, shape_node, reshape_node], - "batchnorm_test", - inputs=[helper.make_tensor_value_info("x", - TensorProto.FLOAT, list(in_shape)), - helper.make_tensor_value_info("in", - TensorProto.FLOAT, list(o_shape)), - helper.make_tensor_value_info("scale", - TensorProto.FLOAT, [in_shape[1]]), - helper.make_tensor_value_info("B", - TensorProto.FLOAT, [in_shape[1]]), - helper.make_tensor_value_info("mean", - TensorProto.FLOAT, [in_shape[1]]), - helper.make_tensor_value_info("var", - TensorProto.FLOAT, [in_shape[1]]), - ], - outputs=[helper.make_tensor_value_info("out", - TensorProto.FLOAT, list(in_shape))]) - - model = helper.make_model(graph, producer_name='batchnorm_test') + graph = helper.make_graph( + [batchnorm, shape_node, reshape_node], + "batchnorm_test", + inputs=[ + helper.make_tensor_value_info("x", TensorProto.FLOAT, list(in_shape)), + helper.make_tensor_value_info("in", TensorProto.FLOAT, list(o_shape)), + helper.make_tensor_value_info("scale", TensorProto.FLOAT, [in_shape[1]]), + helper.make_tensor_value_info("B", TensorProto.FLOAT, [in_shape[1]]), + helper.make_tensor_value_info("mean", TensorProto.FLOAT, [in_shape[1]]), + helper.make_tensor_value_info("var", TensorProto.FLOAT, [in_shape[1]]), + ], + outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(in_shape))], + ) + + model = helper.make_model(graph, producer_name="batchnorm_test") for target, ctx in tvm.testing.enabled_targets(): - x = np.random.uniform(size=in_shape).astype('float32') - inp = np.random.uniform(size=o_shape).astype('float32') - scale = np.random.uniform(size=in_shape[1]).astype('float32') - b = np.random.uniform(size=in_shape[1]).astype('float32') - mean = np.random.uniform(size=in_shape[1]).astype('float32') - var = np.random.uniform(size=in_shape[1]).astype('float32') - onnx_out = get_onnxruntime_output(model, [x, inp, scale, b, mean, var], 'float32')[0] + x = np.random.uniform(size=in_shape).astype("float32") + inp = np.random.uniform(size=o_shape).astype("float32") + scale = np.random.uniform(size=in_shape[1]).astype("float32") + b = np.random.uniform(size=in_shape[1]).astype("float32") + mean = np.random.uniform(size=in_shape[1]).astype("float32") + var = np.random.uniform(size=in_shape[1]).astype("float32") + onnx_out = get_onnxruntime_output(model, [x, inp, scale, b, mean, var], "float32")[0] tvm_out = get_tvm_output_with_vm(model, [x, inp, scale, b, mean, var], target, ctx) tvm.testing.assert_allclose(onnx_out, 
tvm_out, rtol=1e-5, atol=1e-5) verify_batch_norm_dynamic_subgraph([16, 16, 10, 10], [160, 160]) -def verify_conv(x_shape, w_shape, y_shape, padding, kernel_shape, strides, dilations, auto_pad="NOTSET", unset_pad=False): +def verify_conv( + x_shape, + w_shape, + y_shape, + padding, + kernel_shape, + strides, + dilations, + auto_pad="NOTSET", + unset_pad=False, +): if unset_pad: - node = helper.make_node('Conv', - inputs=['x', 'W'], - outputs=['y'], - kernel_shape=kernel_shape, - # Default values for other attributes: - strides=strides, - dilations=dilations, - # groups=1 - ) + node = helper.make_node( + "Conv", + inputs=["x", "W"], + outputs=["y"], + kernel_shape=kernel_shape, + # Default values for other attributes: + strides=strides, + dilations=dilations, + # groups=1 + ) elif padding is None: - node = helper.make_node('Conv', - inputs=['x', 'W'], - outputs=['y'], - kernel_shape=kernel_shape, - # Default values for other attributes: - strides=strides, - dilations=dilations, - # groups=1 - auto_pad=auto_pad) + node = helper.make_node( + "Conv", + inputs=["x", "W"], + outputs=["y"], + kernel_shape=kernel_shape, + # Default values for other attributes: + strides=strides, + dilations=dilations, + # groups=1 + auto_pad=auto_pad, + ) else: - node = helper.make_node('Conv', - inputs=['x', 'W'], - outputs=['y'], - kernel_shape=kernel_shape, - # Default values for other attributes: - strides=strides, - dilations=dilations, - # groups=1 - pads=padding) - - graph = helper.make_graph([node], - 'conv_test', - inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, list(x_shape)), - helper.make_tensor_value_info("W", TensorProto.FLOAT, list(w_shape))], - outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, list(y_shape))]) - - model = helper.make_model(graph, producer_name='conv_test') + node = helper.make_node( + "Conv", + inputs=["x", "W"], + outputs=["y"], + kernel_shape=kernel_shape, + # Default values for other attributes: + strides=strides, + dilations=dilations, + # groups=1 + pads=padding, + ) + + graph = helper.make_graph( + [node], + "conv_test", + inputs=[ + helper.make_tensor_value_info("x", TensorProto.FLOAT, list(x_shape)), + helper.make_tensor_value_info("W", TensorProto.FLOAT, list(w_shape)), + ], + outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, list(y_shape))], + ) + + model = helper.make_model(graph, producer_name="conv_test") for target, ctx in tvm.testing.enabled_targets(): - x = np.random.uniform(size=x_shape).astype('float32') - W = np.random.uniform(size=w_shape).astype('float32') + x = np.random.uniform(size=x_shape).astype("float32") + W = np.random.uniform(size=w_shape).astype("float32") tvm_out = get_tvm_output(model, [x, W], target, ctx, y_shape) - onnx_out = get_onnxruntime_output(model, [x, W], 'float32')[0] + onnx_out = get_onnxruntime_output(model, [x, W], "float32")[0] tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-5, atol=1e-5) @@ -2315,90 +2379,112 @@ def verify_conv(x_shape, w_shape, y_shape, padding, kernel_shape, strides, dilat def test_conv(): def repeat(N, D): return tuple([N for _ in range(D)]) + for D in [1, 2, 3]: # Convolution with padding - verify_conv((1, 1) + repeat(5, D), - (1, 1) + repeat(3, D), - (1, 1) + repeat(5, D), - 2 * repeat(1, D), - repeat(3, D), - repeat(1, D), - repeat(1, D)) + verify_conv( + (1, 1) + repeat(5, D), + (1, 1) + repeat(3, D), + (1, 1) + repeat(5, D), + 2 * repeat(1, D), + repeat(3, D), + repeat(1, D), + repeat(1, D), + ) # Convolution without padding - verify_conv((1, 1) + repeat(5, D), - 
(1, 1) + repeat(3, D), - (1, 1) + repeat(3, D), - 2 * repeat(0, D), - repeat(3, D), - repeat(1, D), - repeat(1, D)) + verify_conv( + (1, 1) + repeat(5, D), + (1, 1) + repeat(3, D), + (1, 1) + repeat(3, D), + 2 * repeat(0, D), + repeat(3, D), + repeat(1, D), + repeat(1, D), + ) # Convolution with autopadding - verify_conv((1, 1) + repeat(5, D), - (1, 1) + repeat(3, D), - (1, 1) + repeat(5, D), - None, - repeat(3, D), - repeat(1, D), - repeat(1, D), - auto_pad="SAME_UPPER") + verify_conv( + (1, 1) + repeat(5, D), + (1, 1) + repeat(3, D), + (1, 1) + repeat(5, D), + None, + repeat(3, D), + repeat(1, D), + repeat(1, D), + auto_pad="SAME_UPPER", + ) # Convolution with valid autopadding - verify_conv((1, 1) + repeat(5, D), - (1, 1) + repeat(3, D), - (1, 1) + repeat(3, D), - None, - repeat(3, D), - repeat(1, D), - repeat(1, D), - auto_pad="VALID") + verify_conv( + (1, 1) + repeat(5, D), + (1, 1) + repeat(3, D), + (1, 1) + repeat(3, D), + None, + repeat(3, D), + repeat(1, D), + repeat(1, D), + auto_pad="VALID", + ) # Convolution with unset padding - verify_conv((1, 1) + repeat(5, D), - (1, 1) + repeat(3, D), - (1, 1) + repeat(3, D), - 2 * repeat(0, D), - repeat(3, D), - repeat(1, D), - repeat(1, D), - True) + verify_conv( + (1, 1) + repeat(5, D), + (1, 1) + repeat(3, D), + (1, 1) + repeat(3, D), + 2 * repeat(0, D), + repeat(3, D), + repeat(1, D), + repeat(1, D), + True, + ) # Convolution with non uniform stride - verify_conv((1, 1) + repeat(5, D), - (1, 1) + repeat(3, D), - (1, 1) + repeat(3, D), - None, - repeat(3, D), - repeat(2, D), - repeat(1, D), - auto_pad="SAME_UPPER") + verify_conv( + (1, 1) + repeat(5, D), + (1, 1) + repeat(3, D), + (1, 1) + repeat(3, D), + None, + repeat(3, D), + repeat(2, D), + repeat(1, D), + auto_pad="SAME_UPPER", + ) # Convolution with dilation - verify_conv((1, 1) + repeat(5, D), - (1, 1) + repeat(3, D), - (1, 1) + repeat(5, D), - 2 * repeat(2, D), - repeat(3, D), - repeat(1, D), - repeat(2, D)) + verify_conv( + (1, 1) + repeat(5, D), + (1, 1) + repeat(3, D), + (1, 1) + repeat(5, D), + 2 * repeat(2, D), + repeat(3, D), + repeat(1, D), + repeat(2, D), + ) + def verify_convtranspose(x_shape, w_shape, y_shape, p): - node = onnx.helper.make_node("ConvTranspose", - inputs=["x", "W"], - outputs=['y'], - strides=[3, 2], - group=1, - kernel_shape=[3, 3], - pads=p) - - graph = helper.make_graph([node], - 'verify_convtranspose_test', - inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, list(x_shape)), - helper.make_tensor_value_info("W", TensorProto.FLOAT, list(w_shape))], - outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, list(y_shape))]) - - model = helper.make_model(graph, producer_name='convtranspose_trest') + node = onnx.helper.make_node( + "ConvTranspose", + inputs=["x", "W"], + outputs=["y"], + strides=[3, 2], + group=1, + kernel_shape=[3, 3], + pads=p, + ) + + graph = helper.make_graph( + [node], + "verify_convtranspose_test", + inputs=[ + helper.make_tensor_value_info("x", TensorProto.FLOAT, list(x_shape)), + helper.make_tensor_value_info("W", TensorProto.FLOAT, list(w_shape)), + ], + outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, list(y_shape))], + ) + + model = helper.make_model(graph, producer_name="convtranspose_trest") for target, ctx in tvm.testing.enabled_targets(): - x = np.random.uniform(size=x_shape).astype('float32') - W = np.random.uniform(size=w_shape).astype('float32') + x = np.random.uniform(size=x_shape).astype("float32") + W = np.random.uniform(size=w_shape).astype("float32") tvm_out = get_tvm_output(model, [x, 
W], target, ctx, y_shape) - onnx_out = get_onnxruntime_output(model, [x, W], 'float32')[0] + onnx_out = get_onnxruntime_output(model, [x, W], "float32")[0] tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-5, atol=1e-5) @@ -2415,11 +2501,13 @@ def test_convtranspose(): @tvm.testing.uses_gpu def test_unsqueeze_constant(): from torch.nn import Linear, Sequential, Module + class Flatten(Module): def forward(self, input): return input.view(input.size(0), -1) import tempfile + with tempfile.NamedTemporaryFile() as fp: file_name = fp.name input_size = (1, 16, 32, 32) @@ -2428,156 +2516,185 @@ def forward(self, input): torch.onnx.export(layer, dummy_input, file_name, export_params=True) onnx_model = onnx.load(file_name) - relay.frontend.from_onnx(onnx_model, {'0': input_size}) + relay.frontend.from_onnx(onnx_model, {"0": input_size}) def verify_pooling(x_shape, kernel_shape, strides, pads, out_shape, mode, auto_pad="NOTSET"): - x_np = np.random.uniform(size=x_shape).astype('float32') + x_np = np.random.uniform(size=x_shape).astype("float32") - if mode == 'max': + if mode == "max": node_type = "MaxPool" - elif mode == 'average': + elif mode == "average": node_type = "AveragePool" else: raise ValueError("Pool method {} is not supported.".format(mode)) pool_node = helper.make_node( - node_type, inputs=["x"], outputs=["y"], kernel_shape=kernel_shape, strides=strides) + node_type, inputs=["x"], outputs=["y"], kernel_shape=kernel_shape, strides=strides + ) if pads is None: - pad_attr = helper.make_attribute('auto_pad', auto_pad) + pad_attr = helper.make_attribute("auto_pad", auto_pad) else: - pad_attr = helper.make_attribute('pads', pads) + pad_attr = helper.make_attribute("pads", pads) pool_node.attribute.append(pad_attr) - if mode == 'max': - storage_attr = helper.make_attribute('storage_order', 0) + if mode == "max": + storage_attr = helper.make_attribute("storage_order", 0) pool_node.attribute.append(storage_attr) - graph = helper.make_graph([pool_node], - "pooling_test", - inputs=[helper.make_tensor_value_info("x", - TensorProto.FLOAT, list(x_shape))], - outputs=[helper.make_tensor_value_info("y", - TensorProto.FLOAT, list(out_shape))]) + graph = helper.make_graph( + [pool_node], + "pooling_test", + inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, list(x_shape))], + outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, list(out_shape))], + ) - model = helper.make_model(graph, producer_name='pooling_test') + model = helper.make_model(graph, producer_name="pooling_test") for target, ctx in tvm.testing.enabled_targets(): - onnx_out = get_onnxruntime_output(model, x_np, 'float32') - tvm_out = get_tvm_output( - model, [x_np], target, ctx, out_shape) + onnx_out = get_onnxruntime_output(model, x_np, "float32") + tvm_out = get_tvm_output(model, [x_np], target, ctx, out_shape) tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-5, atol=1e-5) @tvm.testing.uses_gpu def test_pooling(): - for mode in ['max', 'average']: + for mode in ["max", "average"]: # Pool1D - verify_pooling(x_shape=[1, 1, 32], - kernel_shape=[3], - strides=[1], - pads=[1, 1], - out_shape=[1, 1, 32], - mode=mode) + verify_pooling( + x_shape=[1, 1, 32], + kernel_shape=[3], + strides=[1], + pads=[1, 1], + out_shape=[1, 1, 32], + mode=mode, + ) # Pool2D - verify_pooling(x_shape=[1, 1, 32, 32], - kernel_shape=[3, 3], - strides=[1, 1], - pads=[1, 1, 1, 1], - out_shape=[1, 1, 32, 32], - mode=mode) + verify_pooling( + x_shape=[1, 1, 32, 32], + kernel_shape=[3, 3], + strides=[1, 1], + pads=[1, 1, 1, 1], + 
out_shape=[1, 1, 32, 32], + mode=mode, + ) # Pool1D with stride - verify_pooling(x_shape=[1, 1, 32], - kernel_shape=[3], - strides=[2], - pads=[1, 1], - out_shape=[1, 1, 16], - mode=mode) + verify_pooling( + x_shape=[1, 1, 32], + kernel_shape=[3], + strides=[2], + pads=[1, 1], + out_shape=[1, 1, 16], + mode=mode, + ) # Pool2D with stride - verify_pooling(x_shape=[1, 1, 32, 32], - kernel_shape=[3, 3], - strides=[2, 2], - pads=[1, 1, 1, 1], - out_shape=[1, 1, 16, 16], - mode=mode) + verify_pooling( + x_shape=[1, 1, 32, 32], + kernel_shape=[3, 3], + strides=[2, 2], + pads=[1, 1, 1, 1], + out_shape=[1, 1, 16, 16], + mode=mode, + ) # Pool1D with stride and autopadding - verify_pooling(x_shape=[1, 1, 32], - kernel_shape=[3], - strides=[2], - pads=None, - out_shape=[1, 1, 16], - mode=mode, - auto_pad='SAME_UPPER') + verify_pooling( + x_shape=[1, 1, 32], + kernel_shape=[3], + strides=[2], + pads=None, + out_shape=[1, 1, 16], + mode=mode, + auto_pad="SAME_UPPER", + ) # Pool2D with stride and autopadding - verify_pooling(x_shape=[1, 1, 32, 32], - kernel_shape=[3, 3], - strides=[2, 2], - pads=None, - out_shape=[1, 1, 16, 16], - mode=mode, - auto_pad='SAME_UPPER') + verify_pooling( + x_shape=[1, 1, 32, 32], + kernel_shape=[3, 3], + strides=[2, 2], + pads=None, + out_shape=[1, 1, 16, 16], + mode=mode, + auto_pad="SAME_UPPER", + ) # Pool3D with stride - verify_pooling(x_shape=[1, 1, 32, 32, 32], - kernel_shape=[3, 3, 3], - strides=[2, 2, 2], - pads=[1, 1, 1, 1, 1, 1], - out_shape=[1, 1, 16, 16, 16], - mode=mode) + verify_pooling( + x_shape=[1, 1, 32, 32, 32], + kernel_shape=[3, 3, 3], + strides=[2, 2, 2], + pads=[1, 1, 1, 1, 1, 1], + out_shape=[1, 1, 16, 16, 16], + mode=mode, + ) # Pool3D with stride and autopadding - verify_pooling(x_shape=[1, 1, 32, 32, 32], - kernel_shape=[3, 3, 3], - strides=[2, 2, 2], - pads=None, - out_shape=[1, 1, 16, 16, 16], - mode=mode, - auto_pad='SAME_UPPER') + verify_pooling( + x_shape=[1, 1, 32, 32, 32], + kernel_shape=[3, 3, 3], + strides=[2, 2, 2], + pads=None, + out_shape=[1, 1, 16, 16, 16], + mode=mode, + auto_pad="SAME_UPPER", + ) -def verify_mod(x_shape, y_shape, fmod, out_shape, dtype='float32'): +def verify_mod(x_shape, y_shape, fmod, out_shape, dtype="float32"): x_np = np.random.uniform(-100.0, 100.0, x_shape).astype(dtype) y_np = np.random.uniform(-100.0, 100.0, y_shape).astype(dtype) - y_np = np.where(y_np==0, 1, y_np) #remove 0's to avoid division by zero error + y_np = np.where(y_np == 0, 1, y_np) # remove 0's to avoid division by zero error - mod_node = helper.make_node("Mod", - inputs=["x", "y"], - outputs=["z"], - fmod=fmod) + mod_node = helper.make_node("Mod", inputs=["x", "y"], outputs=["z"], fmod=fmod) onnx_dtype = TensorProto.FLOAT if dtype == "float32" else TensorProto.INT32 - graph = helper.make_graph([mod_node], - "mod_test", - inputs=[helper.make_tensor_value_info("x", - onnx_dtype, list(x_shape)), - helper.make_tensor_value_info("y", - onnx_dtype, list(y_shape))], - outputs=[helper.make_tensor_value_info("z", - onnx_dtype, list(out_shape))]) - model = helper.make_model(graph, producer_name='mod_test') + graph = helper.make_graph( + [mod_node], + "mod_test", + inputs=[ + helper.make_tensor_value_info("x", onnx_dtype, list(x_shape)), + helper.make_tensor_value_info("y", onnx_dtype, list(y_shape)), + ], + outputs=[helper.make_tensor_value_info("z", onnx_dtype, list(out_shape))], + ) + model = helper.make_model(graph, producer_name="mod_test") onnx_out = get_onnxruntime_output(model, [x_np, y_np], dtype)[0] for target, ctx in 
tvm.testing.enabled_targets(): - tvm_out = get_tvm_output( - model, [x_np, y_np], target, ctx, out_shape) + tvm_out = get_tvm_output(model, [x_np, y_np], target, ctx, out_shape) tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-5, atol=1e-5) @tvm.testing.uses_gpu def test_mod(): # Mod - verify_mod(x_shape=[1, 32, 32], y_shape=[1, 1, 32], fmod=0, out_shape=(1, 32, 32), dtype="int32") - verify_mod(x_shape=[1, 32, 32, 32], y_shape=[1, 32, 32, 32], fmod=0, out_shape=(1, 32, 32, 32), dtype="int32") + verify_mod( + x_shape=[1, 32, 32], y_shape=[1, 1, 32], fmod=0, out_shape=(1, 32, 32), dtype="int32" + ) + verify_mod( + x_shape=[1, 32, 32, 32], + y_shape=[1, 32, 32, 32], + fmod=0, + out_shape=(1, 32, 32, 32), + dtype="int32", + ) # fmod - verify_mod(x_shape=[1, 32, 32], y_shape=[1, 32, 32], fmod=1, out_shape=(1, 32, 32), dtype="int32") + verify_mod( + x_shape=[1, 32, 32], y_shape=[1, 32, 32], fmod=1, out_shape=(1, 32, 32), dtype="int32" + ) verify_mod(x_shape=[1, 1, 32, 32], y_shape=[1, 32, 32, 32], fmod=1, out_shape=(1, 32, 32, 32)) verify_mod(x_shape=[1, 32, 32, 32], y_shape=[1, 1, 32, 32], fmod=1, out_shape=(1, 32, 32, 32)) - verify_mod(x_shape=[1, 32, 32, 32], y_shape=[1, 32, 32, 32], fmod=1, out_shape=(1, 32, 32, 32), dtype="int32") + verify_mod( + x_shape=[1, 32, 32, 32], + y_shape=[1, 32, 32, 32], + fmod=1, + out_shape=(1, 32, 32, 32), + dtype="int32", + ) verify_mod(x_shape=[1, 32, 32, 32], y_shape=[1, 32, 32, 32], fmod=1, out_shape=(1, 32, 32, 32)) @@ -2588,24 +2705,22 @@ def verify_xor(x_shape, y_shape): np_out = np.logical_xor(x_np, y_np) out_shape = np_out.shape - xor_node = helper.make_node("Xor", - inputs=["x", "y"], - outputs=["z"]) + xor_node = helper.make_node("Xor", inputs=["x", "y"], outputs=["z"]) onnx_dtype = TensorProto.BOOL - graph = helper.make_graph([xor_node], - "xor_test", - inputs=[helper.make_tensor_value_info("x", - onnx_dtype, list(x_shape)), - helper.make_tensor_value_info("y", - onnx_dtype, list(y_shape))], - outputs=[helper.make_tensor_value_info("z", - onnx_dtype, list(out_shape))]) - model = helper.make_model(graph, producer_name='xor_test') + graph = helper.make_graph( + [xor_node], + "xor_test", + inputs=[ + helper.make_tensor_value_info("x", onnx_dtype, list(x_shape)), + helper.make_tensor_value_info("y", onnx_dtype, list(y_shape)), + ], + outputs=[helper.make_tensor_value_info("z", onnx_dtype, list(out_shape))], + ) + model = helper.make_model(graph, producer_name="xor_test") for target, ctx in tvm.testing.enabled_targets(): - tvm_out = get_tvm_output( - model, [x_np, y_np], target, ctx, out_shape) + tvm_out = get_tvm_output(model, [x_np, y_np], target, ctx, out_shape) tvm.testing.assert_allclose(np_out, tvm_out, rtol=1e-5, atol=1e-5) @@ -2619,199 +2734,248 @@ def test_xor(): def verify_max_roi_pool(x_shape, rois_shape, pooled_shape, spatial_scale, out_shape): - x_np = np.random.uniform(size=x_shape).astype('float32') - rois_np = np.random.uniform(size=rois_shape).astype('float32') + x_np = np.random.uniform(size=x_shape).astype("float32") + rois_np = np.random.uniform(size=rois_shape).astype("float32") if spatial_scale is None: - pool_node = helper.make_node("MaxRoiPool", - inputs=["x", "rois"], - outputs=["y"], - pooled_shape=pooled_shape) + pool_node = helper.make_node( + "MaxRoiPool", inputs=["x", "rois"], outputs=["y"], pooled_shape=pooled_shape + ) else: - pool_node = helper.make_node("MaxRoiPool", - inputs=["x", "rois"], - outputs=["y"], - pooled_shape=pooled_shape, - spatial_scale=spatial_scale) - - graph = helper.make_graph([pool_node], - 
"pool_test", - inputs=[helper.make_tensor_value_info("x", - TensorProto.FLOAT, list(x_shape)), - helper.make_tensor_value_info("rois", - TensorProto.FLOAT, list(rois_shape))], - outputs=[helper.make_tensor_value_info("y", - TensorProto.FLOAT, list(out_shape))]) - - model = helper.make_model(graph, producer_name='pool_test') - - onnx_out = get_onnxruntime_output(model, [x_np, rois_np], 'float32')[0] + pool_node = helper.make_node( + "MaxRoiPool", + inputs=["x", "rois"], + outputs=["y"], + pooled_shape=pooled_shape, + spatial_scale=spatial_scale, + ) + + graph = helper.make_graph( + [pool_node], + "pool_test", + inputs=[ + helper.make_tensor_value_info("x", TensorProto.FLOAT, list(x_shape)), + helper.make_tensor_value_info("rois", TensorProto.FLOAT, list(rois_shape)), + ], + outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, list(out_shape))], + ) + + model = helper.make_model(graph, producer_name="pool_test") + + onnx_out = get_onnxruntime_output(model, [x_np, rois_np], "float32")[0] for target, ctx in tvm.testing.enabled_targets(): - tvm_out = get_tvm_output( - model, [x_np, rois_np], target, ctx, out_shape) + tvm_out = get_tvm_output(model, [x_np, rois_np], target, ctx, out_shape) tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-5, atol=1e-5) @tvm.testing.uses_gpu def test_max_roi_pool(): - verify_max_roi_pool(x_shape=[1, 3, 6, 6], - rois_shape=[3, 5], - pooled_shape=[1, 1], - spatial_scale=None, - out_shape=[3, 3, 1, 1]) + verify_max_roi_pool( + x_shape=[1, 3, 6, 6], + rois_shape=[3, 5], + pooled_shape=[1, 1], + spatial_scale=None, + out_shape=[3, 3, 1, 1], + ) - verify_max_roi_pool(x_shape=[1, 3, 10, 10], - rois_shape=[4, 5], - pooled_shape=[2, 2], - spatial_scale=2.0, - out_shape=[4, 3, 2, 2]) + verify_max_roi_pool( + x_shape=[1, 3, 10, 10], + rois_shape=[4, 5], + pooled_shape=[2, 2], + spatial_scale=2.0, + out_shape=[4, 3, 2, 2], + ) def verify_lppool(x_shape, kernel_shape, p, strides, pads, out_shape, auto_pad="NOTSET"): - x_np = np.random.uniform(size=x_shape).astype('float32') + x_np = np.random.uniform(size=x_shape).astype("float32") if pads is None: - pool_node = helper.make_node("LpPool", - inputs=["x"], - outputs=["y"], - kernel_shape=kernel_shape, - p = p, - auto_pad=auto_pad, - strides=strides) + pool_node = helper.make_node( + "LpPool", + inputs=["x"], + outputs=["y"], + kernel_shape=kernel_shape, + p=p, + auto_pad=auto_pad, + strides=strides, + ) else: - pool_node = helper.make_node("LpPool", - inputs=["x"], - outputs=["y"], - kernel_shape=kernel_shape, - p = p, - pads=pads, - strides=strides) - - graph = helper.make_graph([pool_node], - "lppool_test", - inputs=[helper.make_tensor_value_info("x", - TensorProto.FLOAT, list(x_shape))], - outputs=[helper.make_tensor_value_info("y", - TensorProto.FLOAT, list(out_shape))]) - - model = helper.make_model(graph, producer_name='lppool_test') + pool_node = helper.make_node( + "LpPool", + inputs=["x"], + outputs=["y"], + kernel_shape=kernel_shape, + p=p, + pads=pads, + strides=strides, + ) + + graph = helper.make_graph( + [pool_node], + "lppool_test", + inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, list(x_shape))], + outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, list(out_shape))], + ) + + model = helper.make_model(graph, producer_name="lppool_test") for target, ctx in tvm.testing.enabled_targets(): - onnx_out = get_onnxruntime_output(model, x_np, 'float32') - tvm_out = get_tvm_output( - model, [x_np], target, ctx, out_shape) + onnx_out = get_onnxruntime_output(model, x_np, "float32") 
+ tvm_out = get_tvm_output(model, [x_np], target, ctx, out_shape) tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-5, atol=1e-5) @tvm.testing.uses_gpu def test_lppool(): # Pool1D - verify_lppool(x_shape=[1, 1, 32], kernel_shape=[3], p=2, strides=[1], pads=[1, 1], - out_shape=[1, 1, 32]) + verify_lppool( + x_shape=[1, 1, 32], kernel_shape=[3], p=2, strides=[1], pads=[1, 1], out_shape=[1, 1, 32] + ) # Pool2D - verify_lppool(x_shape=[1, 1, 32, 32], kernel_shape=[3, 3], p=2, strides=[1, 1], - pads=[1, 1, 1, 1], out_shape=[1, 1, 32, 32]) + verify_lppool( + x_shape=[1, 1, 32, 32], + kernel_shape=[3, 3], + p=2, + strides=[1, 1], + pads=[1, 1, 1, 1], + out_shape=[1, 1, 32, 32], + ) # Pool1D with stride - verify_lppool(x_shape=[1, 1, 32], kernel_shape=[3], p=2, strides=[2], pads=[1, 1], - out_shape=[1, 1, 16]) + verify_lppool( + x_shape=[1, 1, 32], kernel_shape=[3], p=2, strides=[2], pads=[1, 1], out_shape=[1, 1, 16] + ) # Pool2D with stride - verify_lppool(x_shape=[1, 1, 32, 32], kernel_shape=[3, 3], p=2, strides=[2, 2], - pads=[1, 1, 1, 1], out_shape=[1, 1, 16, 16]) + verify_lppool( + x_shape=[1, 1, 32, 32], + kernel_shape=[3, 3], + p=2, + strides=[2, 2], + pads=[1, 1, 1, 1], + out_shape=[1, 1, 16, 16], + ) # Pool1D with stride and autopadding - verify_lppool(x_shape=[1, 1, 32], kernel_shape=[3], p=2, strides=[2], pads=None, - out_shape=[1, 1, 16], auto_pad='SAME_UPPER') + verify_lppool( + x_shape=[1, 1, 32], + kernel_shape=[3], + p=2, + strides=[2], + pads=None, + out_shape=[1, 1, 16], + auto_pad="SAME_UPPER", + ) # Pool2D with stride and autopadding - verify_lppool(x_shape=[1, 1, 32, 32], kernel_shape=[3, 3], p=2, strides=[2, 2], - pads=None, out_shape=[1, 1, 16, 16], auto_pad='SAME_UPPER') + verify_lppool( + x_shape=[1, 1, 32, 32], + kernel_shape=[3, 3], + p=2, + strides=[2, 2], + pads=None, + out_shape=[1, 1, 16, 16], + auto_pad="SAME_UPPER", + ) # Pool3D with stride - verify_lppool(x_shape=[1, 1, 32, 32, 32], kernel_shape=[3, 3, 3], p=2, strides=[2, 2, 2], - pads=[1, 1, 1, 1, 1, 1], out_shape=[1, 1, 16, 16, 16]) + verify_lppool( + x_shape=[1, 1, 32, 32, 32], + kernel_shape=[3, 3, 3], + p=2, + strides=[2, 2, 2], + pads=[1, 1, 1, 1, 1, 1], + out_shape=[1, 1, 16, 16, 16], + ) # Pool3D with stride and autopadding - verify_lppool(x_shape=[1, 1, 32, 32, 32], kernel_shape=[3, 3, 3], p=2, strides=[2, 2, 2], - pads=None, out_shape=[1, 1, 16, 16, 16], auto_pad='SAME_UPPER') - - -def verify_rnn(seq_length, - batch_size, - input_size, - hidden_size, - rnn_type='LSTM', - use_bias=False, - activations=None, - alphas=None, - betas=None, - use_initial_state=False, - use_peep=False, - linear_before_reset=False): - if rnn_type == 'LSTM': + verify_lppool( + x_shape=[1, 1, 32, 32, 32], + kernel_shape=[3, 3, 3], + p=2, + strides=[2, 2, 2], + pads=None, + out_shape=[1, 1, 16, 16, 16], + auto_pad="SAME_UPPER", + ) + + +def verify_rnn( + seq_length, + batch_size, + input_size, + hidden_size, + rnn_type="LSTM", + use_bias=False, + activations=None, + alphas=None, + betas=None, + use_initial_state=False, + use_peep=False, + linear_before_reset=False, +): + if rnn_type == "LSTM": multiplier = 4 - elif rnn_type == 'GRU': + elif rnn_type == "GRU": multiplier = 3 else: raise NotImplementedError("%s RNNs not yet supported." 
% rnn_type) - x_np = np.random.uniform(size=(seq_length, batch_size, - input_size)).astype('float32') - w_np = np.random.uniform(size=(1, multiplier * hidden_size, - input_size)).astype('float32') - r_np = np.random.uniform(size=(1, multiplier * hidden_size, - hidden_size)).astype('float32') + x_np = np.random.uniform(size=(seq_length, batch_size, input_size)).astype("float32") + w_np = np.random.uniform(size=(1, multiplier * hidden_size, input_size)).astype("float32") + r_np = np.random.uniform(size=(1, multiplier * hidden_size, hidden_size)).astype("float32") input_names = ["X", "W", "R"] input_tensors = [ helper.make_tensor_value_info("X", TensorProto.FLOAT, list(x_np.shape)), helper.make_tensor_value_info("W", TensorProto.FLOAT, list(w_np.shape)), - helper.make_tensor_value_info("R", TensorProto.FLOAT, list(r_np.shape)) + helper.make_tensor_value_info("R", TensorProto.FLOAT, list(r_np.shape)), ] input_values = [x_np, w_np, r_np] if use_bias: - b_np = np.random.uniform(size=(1, multiplier * 2 * - hidden_size)).astype('float32') + b_np = np.random.uniform(size=(1, multiplier * 2 * hidden_size)).astype("float32") input_names.append("B") input_tensors.append( - helper.make_tensor_value_info("B", TensorProto.FLOAT, - [1, multiplier * 2 * hidden_size])) + helper.make_tensor_value_info("B", TensorProto.FLOAT, [1, multiplier * 2 * hidden_size]) + ) input_values.append(b_np) if use_initial_state: assert use_bias == True, "Initial states must have bias specified." - sequence_np = np.repeat(seq_length, batch_size).astype('int32') + sequence_np = np.repeat(seq_length, batch_size).astype("int32") input_names.append("sequence_lens") input_tensors.append( - helper.make_tensor_value_info("sequence_lens", TensorProto.INT32, - [batch_size])) + helper.make_tensor_value_info("sequence_lens", TensorProto.INT32, [batch_size]) + ) input_values.append(sequence_np) - initial_h_np = np.random.uniform(size=(1, batch_size, - hidden_size)).astype('float32') + initial_h_np = np.random.uniform(size=(1, batch_size, hidden_size)).astype("float32") input_names.append("initial_h") input_tensors.append( - helper.make_tensor_value_info("initial_h", TensorProto.FLOAT, - [1, batch_size, hidden_size])) + helper.make_tensor_value_info( + "initial_h", TensorProto.FLOAT, [1, batch_size, hidden_size] + ) + ) input_values.append(initial_h_np) - if rnn_type == 'LSTM': - initial_c_np = np.random.uniform( - size=(1, batch_size, hidden_size)).astype('float32') + if rnn_type == "LSTM": + initial_c_np = np.random.uniform(size=(1, batch_size, hidden_size)).astype("float32") input_names.append("initial_c") input_tensors.append( - helper.make_tensor_value_info("initial_c", TensorProto.FLOAT, - [1, batch_size, hidden_size])) + helper.make_tensor_value_info( + "initial_c", TensorProto.FLOAT, [1, batch_size, hidden_size] + ) + ) input_values.append(initial_c_np) - if use_peep and rnn_type == 'LSTM': + if use_peep and rnn_type == "LSTM": assert use_initial_state == True, "Peepholes require initial state to be specified." 
- p_np = np.random.uniform(size=(1, 3 * hidden_size)).astype('float32') + p_np = np.random.uniform(size=(1, 3 * hidden_size)).astype("float32") input_names.append("P") input_tensors.append( - helper.make_tensor_value_info("P", TensorProto.FLOAT, - [1, 3 * hidden_size])) + helper.make_tensor_value_info("P", TensorProto.FLOAT, [1, 3 * hidden_size]) + ) input_values.append(p_np) Y_shape = [seq_length, 1, batch_size, hidden_size] @@ -2819,49 +2983,48 @@ def verify_rnn(seq_length, outputs = ["Y", "Y_h"] graph_outputs = [ helper.make_tensor_value_info("Y", TensorProto.FLOAT, list(Y_shape)), - helper.make_tensor_value_info("Y_h", TensorProto.FLOAT, list(Y_h_shape)) + helper.make_tensor_value_info("Y_h", TensorProto.FLOAT, list(Y_h_shape)), ] output_shapes = [Y_shape, Y_h_shape] - if rnn_type == 'LSTM': + if rnn_type == "LSTM": Y_c_shape = [1, batch_size, hidden_size] outputs.append("Y_c") graph_outputs.append( - helper.make_tensor_value_info("Y_c", TensorProto.FLOAT, - list(Y_c_shape))) + helper.make_tensor_value_info("Y_c", TensorProto.FLOAT, list(Y_c_shape)) + ) output_shapes.append(Y_c_shape) rnn_node = helper.make_node( - rnn_type, inputs=input_names, outputs=outputs, hidden_size=hidden_size) + rnn_type, inputs=input_names, outputs=outputs, hidden_size=hidden_size + ) if activations is not None: - activations_attr = helper.make_attribute('activations', activations) + activations_attr = helper.make_attribute("activations", activations) rnn_node.attribute.append(activations_attr) if alphas is not None: - alphas_attr = helper.make_attribute('activation_alpha', alphas) + alphas_attr = helper.make_attribute("activation_alpha", alphas) rnn_node.attribute.append(alphas_attr) if betas is not None: - betas_attr = helper.make_attribute('activation_beta', betas) + betas_attr = helper.make_attribute("activation_beta", betas) rnn_node.attribute.append(betas_attr) - if linear_before_reset and rnn_type == 'GRU': - lbr_attr = helper.make_attribute('linear_before_reset', 1) + if linear_before_reset and rnn_type == "GRU": + lbr_attr = helper.make_attribute("linear_before_reset", 1) rnn_node.attribute.append(lbr_attr) - graph = helper.make_graph([rnn_node], - "rnn_test", - inputs=input_tensors, - outputs=graph_outputs) + graph = helper.make_graph([rnn_node], "rnn_test", inputs=input_tensors, outputs=graph_outputs) - model = helper.make_model(graph, producer_name='rnn_test') + model = helper.make_model(graph, producer_name="rnn_test") for target, ctx in tvm.testing.enabled_targets(): - onnx_out = get_onnxruntime_output(model, input_values, 'float32') + onnx_out = get_onnxruntime_output(model, input_values, "float32") tvm_out = get_tvm_output( model, input_values, target, ctx, output_shapes, - output_dtype=['float32'] * len(output_shapes)) + output_dtype=["float32"] * len(output_shapes), + ) for o_out, t_out in zip(onnx_out, tvm_out): tvm.testing.assert_allclose(o_out, t_out, rtol=5e-3, atol=5e-3) @@ -2870,52 +3033,28 @@ def verify_rnn(seq_length, def test_lstm(): # No bias. verify_rnn( - seq_length=2, - batch_size=1, - input_size=16, - hidden_size=32, - use_bias=False, - rnn_type='LSTM') + seq_length=2, batch_size=1, input_size=16, hidden_size=32, use_bias=False, rnn_type="LSTM" + ) # large batch. verify_rnn( - seq_length=4, - batch_size=8, - input_size=16, - hidden_size=32, - use_bias=True, - rnn_type='LSTM') + seq_length=4, batch_size=8, input_size=16, hidden_size=32, use_bias=True, rnn_type="LSTM" + ) # Non power of two. 
verify_rnn( - seq_length=3, - batch_size=3, - input_size=16, - hidden_size=40, - use_bias=True, - rnn_type='LSTM') + seq_length=3, batch_size=3, input_size=16, hidden_size=40, use_bias=True, rnn_type="LSTM" + ) # Long sequence. verify_rnn( - seq_length=8, - batch_size=1, - input_size=16, - hidden_size=32, - use_bias=True, - rnn_type='LSTM') + seq_length=8, batch_size=1, input_size=16, hidden_size=32, use_bias=True, rnn_type="LSTM" + ) # Large hidden. verify_rnn( - seq_length=2, - batch_size=1, - input_size=16, - hidden_size=128, - use_bias=True, - rnn_type='LSTM') + seq_length=2, batch_size=1, input_size=16, hidden_size=128, use_bias=True, rnn_type="LSTM" + ) # Large input. verify_rnn( - seq_length=2, - batch_size=1, - input_size=64, - hidden_size=32, - use_bias=True, - rnn_type='LSTM') + seq_length=2, batch_size=1, input_size=64, hidden_size=32, use_bias=True, rnn_type="LSTM" + ) # Different activation testing. # Default value hardsigmoid. @@ -2925,8 +3064,9 @@ def test_lstm(): input_size=16, hidden_size=32, use_bias=False, - activations=['HardSigmoid', 'Tanh', 'Tanh'], - rnn_type='LSTM') + activations=["HardSigmoid", "Tanh", "Tanh"], + rnn_type="LSTM", + ) # Multiple parameterized activations. verify_rnn( seq_length=2, @@ -2934,10 +3074,11 @@ def test_lstm(): input_size=16, hidden_size=32, use_bias=False, - activations=['HardSigmoid', 'LeakyRelu', 'Tanh'], + activations=["HardSigmoid", "LeakyRelu", "Tanh"], alphas=[2.0, 0.5], - betas=[.3], - rnn_type='LSTM') + betas=[0.3], + rnn_type="LSTM", + ) # All parameterized with new Affine activation. verify_rnn( seq_length=2, @@ -2945,10 +3086,11 @@ def test_lstm(): input_size=16, hidden_size=32, use_bias=False, - activations=['HardSigmoid', 'LeakyRelu', 'Affine'], + activations=["HardSigmoid", "LeakyRelu", "Affine"], alphas=[2.0, 0.5, 0.8], - betas=[.3, 0.1], - rnn_type='LSTM') + betas=[0.3, 0.1], + rnn_type="LSTM", + ) # Testing with initial state and peepholes verify_rnn( @@ -2958,7 +3100,8 @@ def test_lstm(): hidden_size=32, use_bias=True, use_initial_state=True, - rnn_type='LSTM') + rnn_type="LSTM", + ) verify_rnn( seq_length=2, @@ -2968,19 +3111,16 @@ def test_lstm(): use_bias=True, use_initial_state=True, use_peep=True, - rnn_type='LSTM') + rnn_type="LSTM", + ) @tvm.testing.uses_gpu def test_gru(): # No bias. verify_rnn( - seq_length=2, - batch_size=1, - input_size=16, - hidden_size=32, - use_bias=False, - rnn_type='GRU') + seq_length=2, batch_size=1, input_size=16, hidden_size=32, use_bias=False, rnn_type="GRU" + ) # large batch. verify_rnn( seq_length=4, @@ -2988,40 +3128,25 @@ def test_gru(): input_size=16, hidden_size=32, use_bias=True, - rnn_type='GRU', - linear_before_reset=True) + rnn_type="GRU", + linear_before_reset=True, + ) # Non power of two. verify_rnn( - seq_length=3, - batch_size=3, - input_size=16, - hidden_size=40, - use_bias=True, - rnn_type='GRU') + seq_length=3, batch_size=3, input_size=16, hidden_size=40, use_bias=True, rnn_type="GRU" + ) # Long sequence. verify_rnn( - seq_length=8, - batch_size=1, - input_size=16, - hidden_size=32, - use_bias=True, - rnn_type='GRU') + seq_length=8, batch_size=1, input_size=16, hidden_size=32, use_bias=True, rnn_type="GRU" + ) # Large hidden. verify_rnn( - seq_length=2, - batch_size=1, - input_size=16, - hidden_size=128, - use_bias=True, - rnn_type='GRU') + seq_length=2, batch_size=1, input_size=16, hidden_size=128, use_bias=True, rnn_type="GRU" + ) # Large input. 
verify_rnn( - seq_length=2, - batch_size=1, - input_size=64, - hidden_size=32, - use_bias=True, - rnn_type='GRU') + seq_length=2, batch_size=1, input_size=64, hidden_size=32, use_bias=True, rnn_type="GRU" + ) # Different activation testing. # Default value hardsigmoid. @@ -3031,8 +3156,9 @@ def test_gru(): input_size=16, hidden_size=32, use_bias=False, - activations=['HardSigmoid', 'Softsign'], - rnn_type='GRU') + activations=["HardSigmoid", "Softsign"], + rnn_type="GRU", + ) # Multiple parameterized activations. verify_rnn( seq_length=2, @@ -3040,10 +3166,11 @@ def test_gru(): input_size=16, hidden_size=32, use_bias=False, - activations=['HardSigmoid', 'LeakyRelu'], + activations=["HardSigmoid", "LeakyRelu"], alphas=[2.0, 0.5], - betas=[.3], - rnn_type='GRU') + betas=[0.3], + rnn_type="GRU", + ) # All parameterized with new Affine activation. verify_rnn( seq_length=2, @@ -3051,10 +3178,11 @@ def test_gru(): input_size=16, hidden_size=32, use_bias=False, - activations=['HardSigmoid', 'Affine'], + activations=["HardSigmoid", "Affine"], alphas=[2.0, 0.8], - betas=[.3, 0.1], - rnn_type='GRU') + betas=[0.3, 0.1], + rnn_type="GRU", + ) # Testing with initial state verify_rnn( @@ -3064,7 +3192,8 @@ def test_gru(): hidden_size=32, use_bias=True, use_initial_state=True, - rnn_type='GRU') + rnn_type="GRU", + ) # TODO(mbrookhart): enable once VM supports heterogenous execution @@ -3072,33 +3201,39 @@ def test_gru(): def test_resize(): def verify(ishape, oshape, scales, mode, coord_trans): nodes = [ - make_constant_node('roi', onnx.TensorProto.FLOAT, (0,), []), - make_constant_node('scales', onnx.TensorProto.FLOAT, (len(scales),), scales) + make_constant_node("roi", onnx.TensorProto.FLOAT, (0,), []), + make_constant_node("scales", onnx.TensorProto.FLOAT, (len(scales),), scales), ] - input_names = ['X', 'roi', 'scales'] + input_names = ["X", "roi", "scales"] if oshape != []: - nodes.append(make_constant_node('sizes', onnx.TensorProto.INT64, (len(oshape),), oshape)) - input_names.append('sizes') - nodes.append(helper.make_node( - 'Resize', - inputs=input_names, - outputs=['Y'], - mode=mode, - coordinate_transformation_mode=coord_trans - )) + nodes.append( + make_constant_node("sizes", onnx.TensorProto.INT64, (len(oshape),), oshape) + ) + input_names.append("sizes") + nodes.append( + helper.make_node( + "Resize", + inputs=input_names, + outputs=["Y"], + mode=mode, + coordinate_transformation_mode=coord_trans, + ) + ) if oshape == []: oshape = [round(dim * scale) for (dim, scale) in zip(ishape, scales)] - graph = helper.make_graph(nodes, - "resize_test", - inputs=[helper.make_tensor_value_info("X", TensorProto.FLOAT, ishape)], - outputs=[helper.make_tensor_value_info("Y", TensorProto.FLOAT, oshape)]) + graph = helper.make_graph( + nodes, + "resize_test", + inputs=[helper.make_tensor_value_info("X", TensorProto.FLOAT, ishape)], + outputs=[helper.make_tensor_value_info("Y", TensorProto.FLOAT, oshape)], + ) - model = helper.make_model(graph, producer_name='resize_test') + model = helper.make_model(graph, producer_name="resize_test") for target, ctx in tvm.testing.enabled_targets(): - x = np.random.uniform(size=ishape).astype('float32') - onnx_out = get_onnxruntime_output(model, x, 'float32') + x = np.random.uniform(size=ishape).astype("float32") + onnx_out = get_onnxruntime_output(model, x, "float32") tvm_out = get_tvm_output_with_vm(model, x, target, ctx, opset=11, freeze_params=True) tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-05, atol=1e-05) @@ -3118,22 +3253,25 @@ def verify(ishape, oshape, 
scales, mode, coord_trans): @tvm.testing.uses_gpu def test_nonzero(): - def verify_nonzero(indata, outdata, dtype): - node = helper.make_node('NonZero', - inputs=['X'], - outputs=['Y'],) + node = helper.make_node( + "NonZero", + inputs=["X"], + outputs=["Y"], + ) - graph = helper.make_graph([node], - "nonzero_test", - inputs=[helper.make_tensor_value_info("X", TensorProto.INT64, list(indata.shape))], - outputs=[helper.make_tensor_value_info("Y", TensorProto.INT64, list(outdata.shape))]) + graph = helper.make_graph( + [node], + "nonzero_test", + inputs=[helper.make_tensor_value_info("X", TensorProto.INT64, list(indata.shape))], + outputs=[helper.make_tensor_value_info("Y", TensorProto.INT64, list(outdata.shape))], + ) - model = helper.make_model(graph, producer_name='nonzero_test') + model = helper.make_model(graph, producer_name="nonzero_test") onnx_out = get_onnxruntime_output(model, indata, dtype) - for target, ctx in [('llvm', tvm.cpu())]: + for target, ctx in [("llvm", tvm.cpu())]: tvm_out = get_tvm_output_with_vm(model, indata, target, ctx, opset=9) tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-05, atol=1e-05) @@ -3145,30 +3283,42 @@ def verify_nonzero(indata, outdata, dtype): result = np.array((np.nonzero(input_data))) # expected output [[0, 1, 2, 2], [0, 1, 0, 1]] verify_nonzero(input_data, result, dtype=np.int64) + @tvm.testing.uses_gpu def test_topk(): def verify_topk(input_dims, K, axis=-1): output_dims = list(input_dims) output_dims[axis] = K - node = helper.make_node('TopK', - inputs=['X', 'K'], - outputs=['Values', 'Indicies'], - axis=axis) + node = helper.make_node( + "TopK", inputs=["X", "K"], outputs=["Values", "Indicies"], axis=axis + ) - graph = helper.make_graph([node], - "topk_test", - inputs=[helper.make_tensor_value_info("X", TensorProto.FLOAT, list(input_dims)), - helper.make_tensor_value_info("K", TensorProto.INT64, [1,])], - outputs=[helper.make_tensor_value_info("Values", TensorProto.FLOAT, output_dims), - helper.make_tensor_value_info("Indicies", TensorProto.INT64, output_dims)]) + graph = helper.make_graph( + [node], + "topk_test", + inputs=[ + helper.make_tensor_value_info("X", TensorProto.FLOAT, list(input_dims)), + helper.make_tensor_value_info( + "K", + TensorProto.INT64, + [ + 1, + ], + ), + ], + outputs=[ + helper.make_tensor_value_info("Values", TensorProto.FLOAT, output_dims), + helper.make_tensor_value_info("Indicies", TensorProto.INT64, output_dims), + ], + ) - model = helper.make_model(graph, producer_name='topk_test') + model = helper.make_model(graph, producer_name="topk_test") indata = np.random.uniform(-10, 10, input_dims).astype(np.float32) onnx_out = get_onnxruntime_output(model, [indata, np.array([K])]) - for target, ctx in [('llvm', tvm.cpu())]: + for target, ctx in [("llvm", tvm.cpu())]: tvm_out = get_tvm_output_with_vm(model, [indata, np.array(K)], target, ctx) tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-05, atol=1e-05) @@ -3184,68 +3334,70 @@ def verify_topk(input_dims, K, axis=-1): @tvm.testing.uses_gpu def test_roi_align(): - def verify_roi_align(input_dims, num_roi, output_height, output_width, sampling_ratio=0, spatial_scale=1.0): + def verify_roi_align( + input_dims, num_roi, output_height, output_width, sampling_ratio=0, spatial_scale=1.0 + ): output_dims = [num_roi, input_dims[1], output_height, output_width] - node = helper.make_node('RoiAlign', - inputs=['X', 'rois', 'batch_indicies'], - outputs=['Y'], - mode="avg", - output_height=output_height, - output_width=output_width, - sampling_ratio=sampling_ratio, - 
spatial_scale=spatial_scale, - ) - - graph = helper.make_graph([node], - "roialign_test", - inputs=[helper.make_tensor_value_info("X", TensorProto.FLOAT, list(input_dims)), - helper.make_tensor_value_info( - "rois", TensorProto.FLOAT, [num_roi, 4]), - helper.make_tensor_value_info( - "batch_indicies", TensorProto.INT64, [num_roi, ]), - ], - outputs=[helper.make_tensor_value_info("Y", TensorProto.FLOAT, output_dims)]) - - model = helper.make_model(graph, producer_name='roialign_test') + node = helper.make_node( + "RoiAlign", + inputs=["X", "rois", "batch_indicies"], + outputs=["Y"], + mode="avg", + output_height=output_height, + output_width=output_width, + sampling_ratio=sampling_ratio, + spatial_scale=spatial_scale, + ) + + graph = helper.make_graph( + [node], + "roialign_test", + inputs=[ + helper.make_tensor_value_info("X", TensorProto.FLOAT, list(input_dims)), + helper.make_tensor_value_info("rois", TensorProto.FLOAT, [num_roi, 4]), + helper.make_tensor_value_info( + "batch_indicies", + TensorProto.INT64, + [ + num_roi, + ], + ), + ], + outputs=[helper.make_tensor_value_info("Y", TensorProto.FLOAT, output_dims)], + ) + + model = helper.make_model(graph, producer_name="roialign_test") np_data = np.random.uniform(size=input_dims).astype("float32") - np_rois = np.random.uniform(size=[num_roi, 4]).astype( - 'float32') * input_dims[2] - np_batch_indicies = np.random.randint( - low=0, high=input_dims[0], size=num_roi) - - onnx_out = get_onnxruntime_output( - model, [np_data, np_rois, np_batch_indicies]) - for target, ctx in [('llvm', tvm.cpu())]: - tvm_out = get_tvm_output(model, [np_data, np_rois, np_batch_indicies], target, ctx, output_dims, - output_dtype='float32') - tvm.testing.assert_allclose( - onnx_out[0], tvm_out, rtol=1e-05, atol=1e-05) - - verify_roi_align((1, 4, 16, 16), 32, 7, 7, - sampling_ratio=0, spatial_scale=1.0) - verify_roi_align((4, 4, 16, 32), 32, 7, 7, - sampling_ratio=0, spatial_scale=1.0) - verify_roi_align((1, 8, 16, 16), 32, 7, 7, - sampling_ratio=0, spatial_scale=1.0) - verify_roi_align((1, 4, 8, 8), 32, 7, 7, - sampling_ratio=0, spatial_scale=1.0) - verify_roi_align((1, 4, 16, 16), 16, 5, 7, - sampling_ratio=0, spatial_scale=1.0) - verify_roi_align((1, 4, 16, 12), 8, 7, 3, - sampling_ratio=0, spatial_scale=1.0) - verify_roi_align((1, 4, 16, 16), 32, 7, 7, - sampling_ratio=0, spatial_scale=0.5) - verify_roi_align((3, 4, 12, 16), 32, 7, 7, - sampling_ratio=0, spatial_scale=1.5) - verify_roi_align((5, 4, 16, 14), 32, 7, 7, - sampling_ratio=1, spatial_scale=1.0) - verify_roi_align((1, 4, 16, 16), 32, 7, 7, - sampling_ratio=2, spatial_scale=1.0) - - -if __name__ == '__main__': + np_rois = np.random.uniform(size=[num_roi, 4]).astype("float32") * input_dims[2] + np_batch_indicies = np.random.randint(low=0, high=input_dims[0], size=num_roi) + + onnx_out = get_onnxruntime_output(model, [np_data, np_rois, np_batch_indicies]) + for target, ctx in [("llvm", tvm.cpu())]: + tvm_out = get_tvm_output( + model, + [np_data, np_rois, np_batch_indicies], + target, + ctx, + output_dims, + output_dtype="float32", + ) + tvm.testing.assert_allclose(onnx_out[0], tvm_out, rtol=1e-05, atol=1e-05) + + verify_roi_align((1, 4, 16, 16), 32, 7, 7, sampling_ratio=0, spatial_scale=1.0) + verify_roi_align((4, 4, 16, 32), 32, 7, 7, sampling_ratio=0, spatial_scale=1.0) + verify_roi_align((1, 8, 16, 16), 32, 7, 7, sampling_ratio=0, spatial_scale=1.0) + verify_roi_align((1, 4, 8, 8), 32, 7, 7, sampling_ratio=0, spatial_scale=1.0) + verify_roi_align((1, 4, 16, 16), 16, 5, 7, sampling_ratio=0, 
spatial_scale=1.0) + verify_roi_align((1, 4, 16, 12), 8, 7, 3, sampling_ratio=0, spatial_scale=1.0) + verify_roi_align((1, 4, 16, 16), 32, 7, 7, sampling_ratio=0, spatial_scale=0.5) + verify_roi_align((3, 4, 12, 16), 32, 7, 7, sampling_ratio=0, spatial_scale=1.5) + verify_roi_align((5, 4, 16, 14), 32, 7, 7, sampling_ratio=1, spatial_scale=1.0) + verify_roi_align((1, 4, 16, 16), 32, 7, 7, sampling_ratio=2, spatial_scale=1.0) + + +if __name__ == "__main__": test_flatten() test_reshape() test_shape() diff --git a/tests/python/relay/dyn/test_dynamic_op_level10.py b/tests/python/relay/dyn/test_dynamic_op_level10.py index 8f8eb2ecbc9e..18e1dd5bb72e 100644 --- a/tests/python/relay/dyn/test_dynamic_op_level10.py +++ b/tests/python/relay/dyn/test_dynamic_op_level10.py @@ -32,14 +32,14 @@ def test_broadcast_to(): def verify_more_dynamic_broadcast_to(x_shape, out_shape): rank = len(out_shape) - dtype = 'float32' - shape_type = 'int64' - reshape_shape = relay.Var("shape", relay.ty.TensorType((len(x_shape), ), shape_type)) - broadcast_shape = relay.Var("shape", relay.ty.TensorType((rank, ), shape_type)) - x = relay.Var("x", relay.ty.TensorType((np.prod(x_shape), ), dtype)) + dtype = "float32" + shape_type = "int64" + reshape_shape = relay.Var("shape", relay.ty.TensorType((len(x_shape),), shape_type)) + broadcast_shape = relay.Var("shape", relay.ty.TensorType((rank,), shape_type)) + x = relay.Var("x", relay.ty.TensorType((np.prod(x_shape),), dtype)) r = relay.reshape(x, reshape_shape) z = relay.broadcast_to(r, broadcast_shape) - + func = relay.Function([x, reshape_shape, broadcast_shape], z) x = np.random.uniform(size=np.prod(x_shape)).astype(dtype) @@ -48,21 +48,23 @@ def verify_more_dynamic_broadcast_to(x_shape, out_shape): for kind in ["vm", "debug"]: mod = tvm.ir.IRModule.from_expr(func) intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) - op_res = intrp.evaluate(func)(x, np.array(x_shape).astype(shape_type), np.array(out_shape).astype(shape_type)) + op_res = intrp.evaluate(func)( + x, np.array(x_shape).astype(shape_type), np.array(out_shape).astype(shape_type) + ) tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5) verify_more_dynamic_broadcast_to((4, 3), (3, 4, 3)) def verify_broadcast_to(x_shape, out_shape): rank = len(out_shape) - dtype = 'float32' - shape_type = 'int64' - dyn_shape = relay.Var("shape", relay.ty.TensorType((rank, ), shape_type)) + dtype = "float32" + shape_type = "int64" + dyn_shape = relay.Var("shape", relay.ty.TensorType((rank,), shape_type)) x = relay.Var("x", relay.ty.TensorType(x_shape, dtype)) z = relay.broadcast_to(x, dyn_shape) zz = run_infer_type(z) - assert zz.checked_type == relay.ty.TensorType((relay.Any(), ) * rank, dtype) + assert zz.checked_type == relay.ty.TensorType((relay.Any(),) * rank, dtype) func = relay.Function([x, dyn_shape], z) @@ -79,6 +81,7 @@ def verify_broadcast_to(x_shape, out_shape): verify_broadcast_to((1, 1), (4, 1, 1)) verify_broadcast_to((4, 1), (1, 4, 3)) + # TODO(mbrookhart): Enable when the VM supports heterogenus execution # @tvm.testing.uses_gpu def test_dyn_one_hot(): @@ -112,8 +115,8 @@ def _verify(indices_shape, depth, on_value, off_value, axis, dtype): out_relay = intrp.evaluate()(indices_np, np.array(depth).astype("int32")) tvm.testing.assert_allclose(out_relay.asnumpy(), out_np) - _verify((3, ), 3, 1, 0, -1, "int32") - _verify((3, ), 3, 1.0, 0.0, -1, "float32") + _verify((3,), 3, 1, 0, -1, "int32") + _verify((3,), 3, 1.0, 0.0, -1, "float32") _verify((2, 2), 5, 2, -2, 0, "int32") _verify((2, 2), 5, 0.5, 
-0.5, 1, "float32") _verify((3, 2, 4, 5), 6, 1, 0, 1, "int32") diff --git a/tests/python/relay/test_op_level10.py b/tests/python/relay/test_op_level10.py index d9cd4d8d2b3e..8ad40a617b34 100644 --- a/tests/python/relay/test_op_level10.py +++ b/tests/python/relay/test_op_level10.py @@ -47,17 +47,20 @@ def test_checkpoint(): f_checkpoint_res = intrp.evaluate(f_checkpoint)(*inputs) tvm.testing.assert_allclose(f_res.asnumpy(), f_checkpoint_res.asnumpy(), 0, 0) + def test_checkpoint_alpha_equal(): xs = [relay.var("x{}".format(i), relay.TensorType((1,), "float32")) for i in range(4)] - f = relay.Function(xs, relay.annotation.checkpoint( - relay.multiply(relay.add(xs[0], xs[1]), relay.add(xs[2], xs[3])) - )) + f = relay.Function( + xs, + relay.annotation.checkpoint( + relay.multiply(relay.add(xs[0], xs[1]), relay.add(xs[2], xs[3])) + ), + ) df = transform.gradient(run_infer_type(f)) # run PE and DCE with tvm.transform.PassContext(opt_level=3): - passes = [transform.PartialEvaluate(), - transform.DeadCodeElimination(inline_once=True)] + passes = [transform.PartialEvaluate(), transform.DeadCodeElimination(inline_once=True)] mod = tvm.transform.Sequential(passes)(tvm.IRModule.from_expr(df)) df = mod["main"] @@ -103,17 +106,20 @@ def test_checkpoint_alpha_equal(): tvm.ir.assert_structural_equal(df, df_parsed) + def test_checkpoint_alpha_equal_tuple(): xs = [relay.var("x{}".format(i), relay.TensorType((1,), "float32")) for i in range(4)] - f = relay.Function(xs, relay.annotation.checkpoint( - relay.Tuple([relay.add(xs[0], xs[1]), relay.add(xs[2], xs[3])]) - )) + f = relay.Function( + xs, + relay.annotation.checkpoint( + relay.Tuple([relay.add(xs[0], xs[1]), relay.add(xs[2], xs[3])]) + ), + ) df = transform.gradient(run_infer_type(f)) # run PE and DCE with tvm.transform.PassContext(opt_level=3): - passes = [transform.PartialEvaluate(), - transform.DeadCodeElimination(inline_once=True)] + passes = [transform.PartialEvaluate(), transform.DeadCodeElimination(inline_once=True)] mod = tvm.transform.Sequential(passes)(tvm.IRModule.from_expr(df)) df = mod["main"] @@ -150,12 +156,13 @@ def test_checkpoint_alpha_equal_tuple(): tvm.ir.assert_structural_equal(df, df_parsed) + @tvm.testing.uses_gpu def test_collapse_sum_like(): shape = (3, 4, 5, 6) shape_like = (4, 5, 6) dtype = "float32" - x = relay.Var("x", relay.ty.TensorType(shape , dtype)) + x = relay.Var("x", relay.ty.TensorType(shape, dtype)) y = relay.Var("y", relay.ty.TensorType(shape_like, dtype)) z = relay.collapse_sum_like(x, y) zz = run_infer_type(z) @@ -177,7 +184,7 @@ def test_collapse_sum_to(): shape = (3, 4, 5, 6) shape_to = (4, 5, 6) dtype = "float32" - x = relay.Var("x", relay.ty.TensorType(shape , dtype)) + x = relay.Var("x", relay.ty.TensorType(shape, dtype)) z = relay.collapse_sum_to(x, shape_to) zz = run_infer_type(z) assert zz.checked_type == relay.ty.TensorType(shape_to, dtype) @@ -197,7 +204,7 @@ def test_broadcast_to(): shape = (4, 1, 6) shape_like = (3, 4, 5, 6) dtype = "float32" - x = relay.Var("x", relay.ty.TensorType(shape , dtype)) + x = relay.Var("x", relay.ty.TensorType(shape, dtype)) z = relay.broadcast_to(x, shape=shape_like) zz = run_infer_type(z) assert zz.checked_type == relay.ty.TensorType(shape_like, dtype) @@ -211,12 +218,13 @@ def test_broadcast_to(): op_res = intrp.evaluate(func)(x) tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5) + @tvm.testing.uses_gpu def test_broadcast_to_like(): shape = (4, 1, 6) shape_like = (3, 4, 5, 6) dtype = "float32" - x = relay.Var("x", relay.ty.TensorType(shape , dtype)) + x 
= relay.Var("x", relay.ty.TensorType(shape, dtype)) y = relay.Var("y", relay.ty.TensorType(shape_like, dtype)) z = relay.broadcast_to_like(x, y) @@ -263,8 +271,9 @@ def verify_slice_like(data, slice_like, axes, output, dtype="float32"): assert "axes" in z.astext() assert zz.checked_type == relay.ty.TensorType(output, dtype) - if all(isinstance(v, int) == 0 for v in data) or \ - all(isinstance(v, int) == 0 for v in slice_like): + if all(isinstance(v, int) == 0 for v in data) or all( + isinstance(v, int) == 0 for v in slice_like + ): return func = relay.Function([x, y], z) @@ -278,20 +287,21 @@ def verify_slice_like(data, slice_like, axes, output, dtype="float32"): op_res = intrp.evaluate(func)(x_data, y_data) tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5) + @tvm.testing.uses_gpu def test_slice_like(): d1, d2, d3, d4 = te.var("d1"), te.var("d2"), te.var("d3"), te.var("d4") verify_slice_like(data=(d1, d2, d3), slice_like=(1, 2, 3), axes=None, output=(1, 2, 3)) verify_slice_like(data=(1, 2, 3), slice_like=(d1, d2, d3), axes=None, output=(d1, d2, d3)) - verify_slice_like(data=(d2, d3, d4), slice_like=(d1, d2, d3), axes=(1,2), output=(d2, d2, d3)) + verify_slice_like(data=(d2, d3, d4), slice_like=(d1, d2, d3), axes=(1, 2), output=(d2, d2, d3)) verify_slice_like(data=(3, 4, 5), slice_like=(1, 2, 3), axes=None, output=(1, 2, 3)) verify_slice_like(data=(3, 4, 5), slice_like=(1, 2), axes=None, output=(1, 2, 5)) verify_slice_like(data=(3, 4, 5), slice_like=(1, 2, 3), axes=(1, 2), output=(3, 2, 3)) verify_slice_like(data=(3, 4, 5), slice_like=(1, 2, 3), axes=(-1, -3), output=(1, 4, 3)) - verify_slice_like(data=(1, 3, 224, 224), - slice_like=(1, 3, 112, 112), - axes=(2, 3), - output=(1, 3, 112, 112)) + verify_slice_like( + data=(1, 3, 224, 224), slice_like=(1, 3, 112, 112), axes=(2, 3), output=(1, 3, 112, 112) + ) + @tvm.testing.uses_gpu def test_reverse_reshape(): @@ -310,12 +320,14 @@ def verify_reverse_reshape(shape, newshape, oshape): intrp = relay.create_executor(kind, ctx=ctx, target=target) op_res = intrp.evaluate(func)(x_data) tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5) + verify_reverse_reshape((2, 3, 4), (4, 0, 2), (4, 3, 2)) verify_reverse_reshape((2, 3, 4), (2, 0, 0), (2, 3, 4)) verify_reverse_reshape((2, 3, 4), (0, -1), (3, 8)) verify_reverse_reshape((2, 3, 4), (-1, 0), (6, 4)) verify_reverse_reshape((2, 3, 4), (0, -3), (2, 12)) + def verify_batch_matmul(x_shape, y_shape, out_shape, dtype="float32"): x = relay.var("x", relay.TensorType(x_shape, dtype)) y = relay.var("y", relay.TensorType(y_shape, dtype)) @@ -334,6 +346,7 @@ def verify_batch_matmul(x_shape, y_shape, out_shape, dtype="float32"): z = intrp.evaluate(func)(x_np, y_np) tvm.testing.assert_allclose(z.asnumpy(), z_np, rtol=1e-5) + @tvm.testing.uses_gpu def test_batch_matmul(): b, m, n, k = te.size_var("b"), te.size_var("m"), te.size_var("n"), te.size_var("k") @@ -348,9 +361,10 @@ def test_batch_matmul(): verify_batch_matmul((5, 16, 32), (5, 20, 32), (5, 16, 20)) verify_batch_matmul((30, 16, 32), (30, 20, 32), (30, 16, 20)) + def verify_dynamic_batch_matmul(x_shape, y_shape, out_shape, dtype="float32"): x = relay.var("x", relay.TensorType(x_shape, dtype)) - y = relay.var("y", relay.TensorType((relay.Any(), ) * len(y_shape), dtype)) + y = relay.var("y", relay.TensorType((relay.Any(),) * len(y_shape), dtype)) z = relay.nn.batch_matmul(x, y) func = relay.Function([x, y], z) @@ -365,6 +379,7 @@ def verify_dynamic_batch_matmul(x_shape, y_shape, out_shape, dtype="float32"): z = 
intrp.evaluate()(x_np, y_np) tvm.testing.assert_allclose(z.asnumpy(), z_np, rtol=1e-5) + # TODO(mbrookhart): enable once VM supports heterogenous execution # @tvm.testing.uses_gpu def test_dynamic_batch_matmul(): @@ -380,15 +395,15 @@ def test_shape_of(): x = relay.var("x", shape=shape) func = relay.Function([x], relay.op.shape_of(x)) func = run_infer_type(func) - x_data = np.random.rand(*shape).astype('float32') + x_data = np.random.rand(*shape).astype("float32") for target, ctx in tvm.testing.enabled_targets(): # Because using graph executor, this op will be optimized after # constant folding pass, here we only test with interpreter for kind in ["debug"]: intrp = relay.create_executor(kind, ctx=ctx, target=target) op_res = intrp.evaluate(func)(x_data) - tvm.testing.assert_allclose(op_res.asnumpy(), - np.array(shape).astype('int32')) + tvm.testing.assert_allclose(op_res.asnumpy(), np.array(shape).astype("int32")) + @tvm.testing.uses_gpu def test_ndarray_size(): @@ -403,8 +418,8 @@ def verify_ndarray_size(shape): for kind in ["graph", "debug"]: intrp = relay.create_executor(kind, ctx=ctx, target=target) op_res = intrp.evaluate(func)(x_data) - tvm.testing.assert_allclose(op_res.asnumpy(), - ref_res) + tvm.testing.assert_allclose(op_res.asnumpy(), ref_res) + verify_ndarray_size((2, 3, 5)) verify_ndarray_size((2, 3, 5, 7)) @@ -467,9 +482,11 @@ def _verify(data_shape, mask_value, axis, dtype, itype): intrp = relay.create_executor(kind, ctx=ctx, target=target) out_relay = intrp.evaluate(func)(data_np, valid_length_np) tvm.testing.assert_allclose(out_relay.asnumpy(), gt_out_np) - _verify((5, 10), 0.0, 1, 'float32', 'int32') - _verify((2, 3, 5, 3), 0.0, 0, 'float32', 'int64') - _verify((5, 8, 3), 0.1, 1, 'float64', 'float32') + + _verify((5, 10), 0.0, 1, "float32", "int32") + _verify((2, 3, 5, 3), 0.0, 0, "float32", "int64") + _verify((5, 8, 3), 0.1, 1, "float64", "float32") + @tvm.testing.uses_gpu def test_one_hot(): @@ -493,7 +510,9 @@ def _verify(indices_shape, depth, on_value, off_value, axis, dtype): off_value_const = relay.const(off_value) out = relay.one_hot(indices, on_value_const, off_value_const, depth, axis, dtype) checked = run_infer_type(out) - assert checked.checked_type == relay.ty.TensorType(_get_oshape(indices_shape, depth, axis), dtype) + assert checked.checked_type == relay.ty.TensorType( + _get_oshape(indices_shape, depth, axis), dtype + ) func = relay.Function([indices], out) indices_np = np.random.randint(0, depth, size=indices_shape).astype("int32") out_np = tvm.topi.testing.one_hot(indices_np, on_value, off_value, depth, axis, dtype) @@ -511,6 +530,7 @@ def _verify(indices_shape, depth, on_value, off_value, axis, dtype): _verify((3, 2, 4, 5), 6, 1, 0, 1, "int32") _verify((3, 2, 4, 5), 6, 1.0, 0.0, 0, "float32") + @tvm.testing.uses_gpu def test_matrix_set_diag(): def _verify(input_shape, dtype): @@ -535,9 +555,10 @@ def _verify(input_shape, dtype): out_relay = intrp.evaluate(func)(input_np, diagonal_np) tvm.testing.assert_allclose(out_relay.asnumpy(), out_np) - _verify((2, 2), 'float32') - _verify((4, 3, 3), 'int32') - _verify((2, 3, 4), 'float32') + _verify((2, 2), "float32") + _verify((4, 3, 3), "int32") + _verify((2, 3, 4), "float32") + if __name__ == "__main__": test_adaptive_pool() From 2892e6a5ec43f05563dafc6f7c988933836795d0 Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Fri, 11 Sep 2020 13:35:25 -0600 Subject: [PATCH 05/17] fix batch matmul test --- tests/python/frontend/onnx/test_forward.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 
deletions(-) diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index dc9eb67baa20..6457319a9200 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -858,7 +858,7 @@ def test_matmul(): tvm.testing.assert_allclose(out_np, tvm_out, rtol=1e-5, atol=1e-5) -def verify_batch_matmul(a_shape, b_shape): +def verify_batch_matmul(a_shape, b_shape, target, ctx): a_array = np.random.uniform(size=a_shape).astype("float32") b_array = np.random.uniform(size=b_shape).astype("float32") out_np = np.matmul(a_array, b_array) @@ -877,17 +877,16 @@ def verify_batch_matmul(a_shape, b_shape): model = helper.make_model(graph, producer_name="matmul_test") - for target, ctx in tvm.testing.enabled_targets(): - tvm_out = get_tvm_output_with_vm(model, [a_array, b_array], target, ctx) - tvm.testing.assert_allclose(out_np, tvm_out, rtol=1e-5, atol=1e-5) + tvm_out = get_tvm_output_with_vm(model, [a_array, b_array], target, ctx) + tvm.testing.assert_allclose(out_np, tvm_out, rtol=1e-5, atol=1e-5) # TODO(mbrookhart): enable cuda once VM supports heterogenous execution @tvm.testing.parametrize_targets("llvm") -def test_batch_matmul(): - verify_batch_matmul((2, 3, 4, 3), (2, 3, 3, 4)) - verify_batch_matmul((2, 4, 3), (3, 4)) - verify_batch_matmul((2, 3, 4, 3), (3, 4)) +def test_batch_matmul(target, ctx): + verify_batch_matmul((2, 3, 4, 3), (2, 3, 3, 4), target, ctx) + verify_batch_matmul((2, 4, 3), (3, 4), target, ctx) + verify_batch_matmul((2, 3, 4, 3), (3, 4), target, ctx) def verify_lrn(shape, nsize, dtype, alpha=None, beta=None, bias=None): From 1fc37215270ff450a56f468d198fc2c918df5184 Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Fri, 11 Sep 2020 14:19:12 -0600 Subject: [PATCH 06/17] add dynamic strided slice to the onnx importer --- python/tvm/relay/frontend/onnx.py | 47 +++++++++++++--------- tests/python/frontend/onnx/test_forward.py | 5 ++- 2 files changed, 31 insertions(+), 21 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index aa1ed7cb8e96..464e3470c9ce 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -30,7 +30,7 @@ from .common import AttrCvt, Renamer from .common import get_relay_op, new_var, infer_shape, infer_channels -from .common import infer_type, get_name, infer_value_simulated +from .common import infer_type, get_name __all__ = ["from_onnx"] @@ -945,7 +945,6 @@ def _impl_v9(cls, inputs, attr, params): return out - class Shape(OnnxOpConverter): """Operator converter for Shape.""" @@ -1047,24 +1046,35 @@ def _impl_v1(cls, inputs, attr, params): @classmethod def _impl_v10(cls, inputs, attr, params): - attrs = {"starts": inputs[1], "ends": inputs[2]} - if len(inputs) >= 4: - attrs["axes"] = inputs[3] - attrs = {k: (v, get_name(v)) for (k, v) in attrs.items()} - attrs = { - k: params[v[1]].asnumpy() - if v[1] in params - else infer_value_simulated(v[0], params).asnumpy() - for (k, v) in attrs.items() - } + starts = inputs[1] + ends = inputs[2] + axes = inputs[3] + steps = inputs[4] + + data_rank = len(infer_shape(inputs[0])) # Update the starts and ends according to axes if required. 
- if "axes" in attrs: - if max(attrs["axes"] + 1) != len(attrs["axes"]): - new_starts, new_ends, _ = cls._common(attrs["starts"], attrs["ends"], attrs["axes"]) - attrs["starts"] = new_starts - attrs["ends"] = new_ends - return _op.strided_slice(inputs[0], begin=list(attrs["starts"]), end=list(attrs["ends"])) + if axes is not None: + data_shape = _op.shape_of(inputs[0], dtype=infer_type(ends).checked_type.dtype) + starts = _op.scatter( + _op.const([0] * data_rank, dtype=infer_type(starts).checked_type.dtype), + axes, + starts, + axis=0, + ) + ends = _op.scatter(data_shape, axes, ends, axis=0) + if steps is not None: + steps = _op.scatter( + _op.const([1] * data_rank, dtype=infer_type(steps).checked_type.dtype), + axes, + steps, + axis=0, + ) + + if steps is None: + steps = _op.const([1] * data_rank, dtype=infer_type(starts).checked_type.dtype) + + return _op.strided_slice(inputs[0], starts, ends, steps) class Gather(OnnxOpConverter): @@ -1406,7 +1416,6 @@ def _impl_v6(cls, inputs, attr, params): return _op.tile(inputs[0], inputs[1]) - class Erf(OnnxOpConverter): """Operator converter for Erf""" diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 6457319a9200..3c26a722c17e 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -599,12 +599,13 @@ def add_noop_to_input_attr(attr_name, attr): model = helper.make_model(graph, producer_name="slice_test") for target, ctx in tvm.testing.enabled_targets(): - tvm_out = get_tvm_output(model, indata, target, ctx, outdata.shape, "float32", opset=10) + tvm_out = get_tvm_output_with_vm(model, indata, target, ctx, opset=10, freeze_params=True) tvm.testing.assert_allclose(outdata, tvm_out) -@tvm.testing.uses_gpu +# TODO(mbrookhart): enable once VM supports heterogenous execution +# @tvm.testing.uses_gpu def test_slice(): x = np.random.randn(20, 10, 5).astype(np.float32) _test_slice_iteration_v1(x, x[0:3, 0:10], starts=(0, 0), ends=(3, 10), axes=(0, 1)) From 55b40123b3cbb99f5b4e4de215bcd7c8c53625d4 Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Sat, 12 Sep 2020 20:28:47 -0600 Subject: [PATCH 07/17] fix clip importer --- python/tvm/relay/frontend/onnx.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 464e3470c9ce..f3ac9c015100 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -1917,7 +1917,7 @@ def _impl_v11(cls, inputs, attr, params): assert len(inputs) <= 3, "Clip-11 takes up to 3 inputs, input, min, max" result = inputs[0] - for i, op in enumerate([_maximum, _minimum]): + for i, op in enumerate([_op.tensor.maximum, _op.tensor.minimum]): if i < len(inputs) - 1: result = op(result, inputs[i + 1]) return result From 8864abcb420cb0a5cdca37574a6886357645d30b Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Sat, 12 Sep 2020 21:34:15 -0600 Subject: [PATCH 08/17] fix qnn tutorial --- src/relay/backend/build_module.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index 912add4603b2..b95e0962bd27 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -254,7 +254,6 @@ class RelayBuildModule : public runtime::ModuleNode { Array entry_functions{"main"}; pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions)); pass_seqs.push_back(transform::ToBasicBlockNormalForm()); - 
pass_seqs.push_back(transform::DynamicToStatic()); // Run all dialect legalization passes. pass_seqs.push_back(relay::qnn::transform::Legalize()); @@ -264,6 +263,9 @@ class RelayBuildModule : public runtime::ModuleNode { pass_seqs.push_back(transform::Legalize()); } + // Convert Dynamic ops to static versions + pass_seqs.push_back(transform::DynamicToStatic()); + pass_seqs.push_back(transform::SimplifyInference()); PackedFunc fskip = PackedFunc([](TVMArgs args, TVMRetValue* rv) { Expr expr = args[0]; From 53f7a7b640efc48db49e54aaed5f61cbca04575d Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Mon, 14 Sep 2020 14:10:29 -0600 Subject: [PATCH 09/17] fix bad merge, respond to review comments --- python/tvm/relay/frontend/onnx.py | 2 +- src/relay/transforms/dynamic_to_static.cc | 5 ++++- tests/python/frontend/onnx/test_forward.py | 11 +++++++---- 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index a3ded3733da0..222eb36da1b5 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -2247,7 +2247,7 @@ def from_onnx(self, graph, opset, freeze_params=False): # now return the outputs outputs = [self._nodes[self._parse_value_proto(i)] for i in graph.output] outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs) - ## Maintain the order of inputs and parametersfrom the ONNX graph, but only include + ## Maintain the order of inputs and parameters from the ONNX graph, but only include ## those parameters that are needed to execute the relay graph free_vars = analysis.free_vars(outputs) nodes = {v: k for k, v in self._nodes.items()} diff --git a/src/relay/transforms/dynamic_to_static.cc b/src/relay/transforms/dynamic_to_static.cc index 07358f1955fb..edcb83972cc7 100644 --- a/src/relay/transforms/dynamic_to_static.cc +++ b/src/relay/transforms/dynamic_to_static.cc @@ -227,6 +227,9 @@ Expr DynamicToStatic(Function f, IRModule m) { vars.Set(kv.second, kv.first); } const auto gv = vars[f]; + // Put a limit on the while loop + // Primarily used to prevent accidental infinite lops in development + const int loop_limit = 1000; int i = 0; do { pre = expr; @@ -236,7 +239,7 @@ Expr DynamicToStatic(Function f, IRModule m) { expr = mutator.Mutate(m->functions[gv]); m->Update(gv, Downcast(expr)); i += 1; - } while (!StructuralEqual()(pre, expr) && i < 1000); + } while (!StructuralEqual()(pre, expr) && i < loop_limit); return expr; } diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 8d3e9c9edd34..58e8f499b8d1 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -121,6 +121,7 @@ def verify_with_ort_with_inputs( targets=None, use_vm=False, opset=None, + freeze_params=False, dtype="float32", rtol=1e-5, atol=1e-5, @@ -141,7 +142,9 @@ def flatten(out): ctx = tvm.context(target, 0) if use_vm: - tvm_out = get_tvm_output_with_vm(model, inputs, target, ctx, opset=opset) + tvm_out = get_tvm_output_with_vm( + model, inputs, target, ctx, opset=opset, freeze_params=freeze_params + ) else: tvm_out = get_tvm_output(model, inputs, target, ctx, out_shape, dtype, opset=opset) @@ -155,6 +158,7 @@ def verify_with_ort( targets=None, use_vm=False, opset=None, + freeze_params=False, dtype="float32", rtol=1e-5, atol=1e-5, @@ -167,6 +171,7 @@ def verify_with_ort( targets=targets, use_vm=use_vm, opset=opset, + freeze_params=freeze_params, dtype=dtype, rtol=rtol, atol=atol, @@ -2352,7 +2357,6 @@ def 
verify_batch_norm_dynamic_subgraph(in_shape, o_shape): inshapes = [in_shape, o_shape, in_shape[1], in_shape[1], in_shape[1], in_shape[1]] verify_with_ort(model, inshapes, in_shape, use_vm=True) - verify_batch_norm_dynamic_subgraph([16, 16, 10, 10], [160, 160]) @@ -3246,8 +3250,7 @@ def verify(ishape, oshape, scales, mode, coord_trans): model = helper.make_model(graph, producer_name="resize_test") - verify_with_ort(model, [ishape], oshape, use_vm=True, opset=11) - + verify_with_ort(model, [ishape], oshape, use_vm=True, opset=11, freeze_params=True) # upsampling verify([1, 16, 32, 32], [1, 16, 64, 64], [], "nearest", "asymmetric") From ab33d69835737ce257e1bf980ed7ff59e3d74b54 Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Wed, 16 Sep 2020 11:05:06 -0600 Subject: [PATCH 10/17] add a simple dynamic model test --- tests/python/frontend/onnx/test_forward.py | 51 ++++++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 58e8f499b8d1..4ce08e3547f9 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -973,6 +973,57 @@ def test_batch_matmul(target, ctx): verify_batch_matmul((2, 3, 4, 3), (3, 4), target, ctx) +def verify_simple_dynamic_model(a_shape, b_shape, target, ctx): + def verify_model(ex, a_shape, b_shape): + a_array = np.random.uniform(size=a_shape).astype("float32") + b_array = np.random.uniform(size=b_shape).astype("float32") + # matmul + out_np = np.matmul(a_array, b_array) + # relu + out_np[out_np < 0] = 0 + + tvm_out = ex.evaluate()(a_array, b_array).asnumpy() + tvm.testing.assert_allclose(out_np, tvm_out, rtol=1e-5, atol=1e-5) + + mul_node = helper.make_node("MatMul", ["a", "b"], ["out"]) + relu_node = helper.make_node("Relu", ["out"], ["relu"]) + + a_array = np.random.uniform(size=a_shape).astype("float32") + b_array = np.random.uniform(size=b_shape).astype("float32") + # matmul + out_np = np.matmul(a_array, b_array) + + graph = helper.make_graph( + [mul_node, relu_node], + "matmul_test", + inputs=[ + helper.make_tensor_value_info("a", TensorProto.FLOAT, list(a_shape)), + helper.make_tensor_value_info("b", TensorProto.FLOAT, list(b_shape)), + ], + outputs=[helper.make_tensor_value_info("relu", TensorProto.FLOAT, list(out_np.shape))], + ) + + model = helper.make_model(graph, producer_name="matmul_test") + + a_anys = [relay.Any()] * len(a_shape) + b_anys = [relay.Any()] * len(b_shape) + + mod, params = relay.frontend.from_onnx(model, {"a": a_anys, "b": b_anys}) + + ex = relay.create_executor("vm", mod=mod, ctx=ctx, target=target) + verify_model(ex, a_shape, b_shape) + verify_model(ex, [a * 2 for a in a_shape], [b * 2 for b in b_shape]) + verify_model(ex, [a * 3 for a in a_shape], [b * 3 for b in b_shape]) + + +# TODO(mbrookhart): enable cuda once VM supports heterogenous execution +@tvm.testing.parametrize_targets("llvm") +def test_batch_matmul_dynamic_model(target, ctx): + verify_simple_dynamic_model((2, 3, 4, 3), (2, 3, 3, 4), target, ctx) + verify_simple_dynamic_model((2, 4, 3), (3, 4), target, ctx) + verify_simple_dynamic_model((2, 3, 4, 3), (3, 4), target, ctx) + + def verify_lrn(shape, nsize, dtype, alpha=None, beta=None, bias=None): in_array = np.random.uniform(size=shape).astype(dtype) From db74cc386294665bf3023bb702e9c22827d6222a Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Fri, 18 Sep 2020 14:38:06 -0600 Subject: [PATCH 11/17] Add dynamic-shaped autopadding to convolution and pooling ops --- 
python/tvm/autotvm/record.py | 2 +- python/tvm/relay/frontend/onnx.py | 161 ++++++++++++------- python/tvm/relay/op/nn/_nn.py | 36 +++++ src/relay/op/nn/convolution.h | 77 +++++++--- tests/python/frontend/onnx/test_forward.py | 170 ++++++++++++++++++++- 5 files changed, 366 insertions(+), 80 deletions(-) diff --git a/python/tvm/autotvm/record.py b/python/tvm/autotvm/record.py index af3540e3ea49..6650f500a996 100644 --- a/python/tvm/autotvm/record.py +++ b/python/tvm/autotvm/record.py @@ -152,7 +152,7 @@ def decode(row, protocol="json"): tgt, task_name, task_args, task_kwargs = row["input"] tgt = str(tgt) if "-target" in tgt: - logger.warning('"-target" is deprecated, use "-mtriple" instead.') + # logger.warning('"-target" is deprecated, use "-mtriple" instead.') tgt = tgt.replace("-target", "-mtriple") tgt = Target(str(tgt)) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 222eb36da1b5..0077ae12ac1d 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -18,6 +18,7 @@ # pylint: disable=import-outside-toplevel """ONNX: Open Neural Network Exchange frontend for Relay.""" import numpy as np +import logging import tvm from tvm.ir import IRModule @@ -32,6 +33,9 @@ from .common import get_relay_op, new_var, infer_shape, infer_channels from .common import infer_type, get_name + +logger = logging.getLogger("onnx_frontend") + __all__ = ["from_onnx"] @@ -236,21 +240,29 @@ class Pool(OnnxOpConverter): @classmethod def _impl_v1(cls, inputs, attr, params): - input_shape = infer_shape(inputs[0]) + data = inputs[0] + input_shape = infer_shape(data) + ndim = len(input_shape) if "auto_pad" in attr: attr["auto_pad"] = attr["auto_pad"].decode("utf-8") if attr["auto_pad"] in ("SAME_UPPER", "SAME_LOWER"): - pad_tuple = [] - for axis in range(len(input_shape) - 2): - axis_shape = input_shape[2 + axis] - stride = attr["strides"][axis] - kernel = attr["kernel_shape"][axis] - pad = get_pad_pair(axis_shape, kernel, stride) - pad_tuple.append(pad) - pad_tuple = tuple([val for pair in zip(*pad_tuple) for val in pair]) - attr["pads"] = pad_tuple + if cls.name == "avg_pool": + pad_tuple = [] + for axis in range(len(input_shape) - 2): + axis_shape = input_shape[2 + axis] + stride = attr["strides"][axis] + kernel = attr["kernel_shape"][axis] + pad = get_pad_pair(axis_shape, kernel, stride) + pad_tuple.append(pad) + pad_tuple = tuple([val for pair in zip(*pad_tuple) for val in pair]) + attr["pads"] = pad_tuple + else: + logger.warning( + "Performing dynamic autopadding on Pool. Pool kernels don't currently support dynamic shapes." 
+ ) + data = autopad(data, attr["strides"], attr["kernel_shape"], [1] * ndim, ndim) elif attr["auto_pad"] == "VALID": - attr["pads"] = 0 + attr["pads"] = tuple([0 for i in range(ndim - 2)]) elif attr["auto_pad"] == "NOTSET": pass else: @@ -270,7 +282,7 @@ def _impl_v1(cls, inputs, attr, params): transforms={"kernel_shape": "pool_size", "pads": ("padding", 0)}, ignores=["dilations", "storage_order"], custom_check=dimension_constraint(), - )(inputs, attr, params) + )([data], attr, params) class Absolute(Unary): @@ -311,29 +323,69 @@ def _impl_v1(cls, inputs, attr, params): return AttrCvt(op_name="instance_norm")(inputs, attr, params) +def autopad(data, strides, kernel_shape, dilations, ndim, pad_type="constant", deconv=False): + """ + Perform autopadding with dynamic input shapes + """ + # get attributes as constants + strides = _op.const(np.array(strides), dtype="int64") + dilated_kernel_shape = _op.const( + np.array( + [(kernel - 1) * dilation + 1 for kernel, dilation in zip(kernel_shape, dilations)] + ), + dtype="int64", + ) + shape = _op.strided_slice(_op.shape_of(data, dtype="int64"), [2], [ndim]) + # get input shape + + # set up integer constants + zero = _op.const(0, dtype="int64") + one = _op.const(1, dtype="int64") + two = _op.const(2, dtype="int64") + + # Calculate total padding + mod = _op.mod(shape, strides) + + left = _op.maximum(dilated_kernel_shape - strides, zero) + right = _op.maximum(dilated_kernel_shape - mod, zero) + + total_pad = _op.where(_op.equal(mod, zero), left, right) + if deconv: + total_pad = _op.const(np.array(kernel_shape), dtype="int64") - one - total_pad + + # split total padding into before and after + pad_before = _op.floor_divide(total_pad, two) + pad_after = total_pad - pad_before + + # combine + pad = _op.concatenate( + [_op.reshape(pad_before, [-1, 1]), _op.reshape(pad_after, [-1, 1])], axis=1 + ) + + # pad N and C with zeros + pad = _op.concatenate([_op.const(np.zeros([2, 2], dtype="int64"), dtype="int64"), pad], axis=0) + + return _op.nn.pad(data, pad, _op.const(0.0), pad_type) + + class Conv(OnnxOpConverter): """Operator converter for Conv.""" @classmethod def _impl_v1(cls, inputs, attr, params): # Use shape of input to determine convolution type. - input_shape = infer_shape(inputs[0]) + data = inputs[0] + input_shape = infer_shape(data) + ndim = len(input_shape) if "auto_pad" in attr: attr["auto_pad"] = attr["auto_pad"].decode("utf-8") if attr["auto_pad"] in ("SAME_UPPER", "SAME_LOWER"): - pad_tuple = [] - for axis in range(len(input_shape) - 2): - axis_shape = input_shape[2 + axis] - stride = attr["strides"][axis] - kernel = attr["kernel_shape"][axis] - dilation = attr["dilations"][axis] - dilated_kernel = (kernel - 1) * dilation + 1 - pad = get_pad_pair(axis_shape, dilated_kernel, stride) - pad_tuple.append(pad) - pad_tuple = tuple([val for pair in zip(*pad_tuple) for val in pair]) - attr["pads"] = pad_tuple + logger.warning( + "Performing dynamic autopadding on Conv. Conv kernels don't currently support dynamic shapes." 
+ ) + data = autopad(data, attr["strides"], attr["kernel_shape"], attr["dilations"], ndim) elif attr["auto_pad"] == "VALID": - attr["pads"] = tuple([0 for i in range(len(input_shape) - 2)]) + attr["pads"] = tuple([0 for i in range(ndim - 2)]) elif attr["auto_pad"] == "NOTSET": pass else: @@ -361,7 +413,7 @@ def _impl_v1(cls, inputs, attr, params): "group": ("groups", 1), }, custom_check=dimension_constraint(), - )(inputs[:2], attr, params) + )([data, inputs[1]], attr, params) use_bias = len(inputs) == 3 if use_bias: @@ -380,21 +432,25 @@ def _impl_v1(cls, inputs, attr, params): groups = attr.pop("group") attr["groups"] = groups # infer pads for auto_pad + data = inputs[0] + input_shape = infer_shape(data) + ndim = len(input_shape) if "auto_pad" in attr: attr["auto_pad"] = attr["auto_pad"].decode("utf-8") if attr["auto_pad"] in ("SAME_UPPER", "SAME_LOWER"): - input_shape = infer_shape(inputs[0]) - in_h, in_w = input_shape[2], input_shape[3] - stride_h, stride_w = attr["strides"] - kernel_h, kernel_w = attr["kernel_shape"] - dilation_h, dilation_w = attr["dilations"] - dilated_kernel_h = (kernel_h - 1) * dilation_h + 1 - dilated_kernel_w = (kernel_w - 1) * dilation_w + 1 - pad_v = get_pad_pair(in_h, dilated_kernel_h, stride_h) - pad_h = get_pad_pair(in_w, dilated_kernel_w, stride_w) - attr["pads"] = (pad_v[0], pad_h[0], pad_v[1], pad_h[1]) + logger.warning( + "Performing dynamic autopadding on ConvTranspose. ConvTranspose kernels don't currently support dynamic shapes." + ) + data = autopad( + data, + attr["strides"], + attr["kernel_shape"], + attr["dilations"], + ndim, + deconv=True, + ) elif attr["auto_pad"] == "VALID": - attr["pads"] = (0, 0) + attr["pads"] = tuple([0 for i in range(ndim - 2)]) elif attr["auto_pad"] == "NOTSET": pass else: @@ -406,12 +462,13 @@ def _impl_v1(cls, inputs, attr, params): op_name=dimension_picker("conv", "_transpose"), transforms={ "kernel_shape": "kernel_size", - "dilations": ("dilation", (0, 0)), - "pads": ("padding", (0, 0), revert_caffe2_pad), + "dilations": ("dilation", 1), + "pads": ("padding", 0), + "group": ("groups", 1), }, disables=["output_shape"], custom_check=dimension_constraint(), - )(inputs[:2], attr, params) + )([data, inputs[1]], attr, params) use_bias = len(inputs) == 3 if use_bias: out = _op.nn.bias_add(out, inputs[2]) @@ -546,23 +603,19 @@ class LpPool(OnnxOpConverter): @classmethod def _impl_v1(cls, inputs, attr, params): - input_shape = infer_shape(inputs[0]) dtype = infer_type(inputs[0]).checked_type.dtype - + data = inputs[0] + input_shape = infer_shape(data) + ndim = len(input_shape) if "auto_pad" in attr: attr["auto_pad"] = attr["auto_pad"].decode("utf-8") if attr["auto_pad"] in ("SAME_UPPER", "SAME_LOWER"): - pad_tuple = [] - for axis in range(len(input_shape) - 2): - axis_shape = input_shape[2 + axis] - stride = attr["strides"][axis] - kernel = attr["kernel_shape"][axis] - pad = get_pad_pair(axis_shape, kernel, stride) - pad_tuple.append(pad) - pad_tuple = tuple([val for pair in zip(*pad_tuple) for val in pair]) - attr["pads"] = pad_tuple + logger.warning( + "Performing dynamic autopadding on LpPool. LpPool kernels don't currently support dynamic shapes." 
+ ) + data = autopad(data, attr["strides"], attr["kernel_shape"], [1] * ndim, ndim) elif attr["auto_pad"] == "VALID": - attr["pads"] = 0 + attr["pads"] = tuple([0 for i in range(ndim - 2)]) elif attr["auto_pad"] == "NOTSET": pass else: @@ -579,7 +632,7 @@ def _impl_v1(cls, inputs, attr, params): p = _expr.const(attr["p"], dtype) reci_p = _expr.const(1.0 / attr["p"], dtype) - inputs[0] = _op.power(inputs[0], p) + data = _op.power(data, p) out = AttrCvt( op_name=dimension_picker("avg_pool"), @@ -587,7 +640,7 @@ def _impl_v1(cls, inputs, attr, params): extras={"count_include_pad": True}, ignores=["p"], custom_check=dimension_constraint(), - )(inputs, attr, params) + )([data], attr, params) kernels = attr["kernel_shape"] out = _op.abs(out) * _expr.const(np.prod(kernels).astype(dtype)) return _op.power(out, reci_p) diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index 6917030f7268..086603550532 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -721,6 +721,42 @@ def compute_space_to_depth(attrs, inputs, out_dtype): ##################### +@script +def _conv_shape_func(dshape, kshape, strides, padding, dilation): + out = output_tensor((dshape.shape[0],), "int64") + out[0] = dshape[0] + out[1] = kshape[0] + + for i in const_range(dshape.shape[0] - 2): + dilated_k = (kshape[i + 2] - 1) * dilation[i] + 1 + out[i + 2] = (dshape[i + 2] + 2 * padding[i] - dilated_k) // strides[i] + 1 + return out + + +def conv_shape_func(attrs, inputs, _): + """ + Shape function for contrib_conv2d_NCHWc op. + """ + strides = get_const_tuple(attrs.strides) + padding = get_const_tuple(attrs.padding) + dilation = get_const_tuple(attrs.dilation) + + return [ + _conv_shape_func( + inputs[0], + inputs[1], + convert(strides), + convert(padding), + convert(dilation), + ) + ] + + +reg.register_shape_func("nn.conv1d", False, conv_shape_func) +reg.register_shape_func("nn.conv2d", False, conv_shape_func) +reg.register_shape_func("nn.conv3d", False, conv_shape_func) + + @script def _conv2d_NCHWc_shape_func(dshape, kshape, strides, padding, dilation, oc_bn): out = output_tensor((dshape.shape[0],), "int64") diff --git a/src/relay/op/nn/convolution.h b/src/relay/op/nn/convolution.h index f53f4e0454a4..bd98547ab72b 100644 --- a/src/relay/op/nn/convolution.h +++ b/src/relay/op/nn/convolution.h @@ -100,7 +100,9 @@ bool Conv1DRel(const Array& types, int num_inputs, const Attrs& attrs, << "Conv1D: shape of weight is inconsistent with channels, " << " channels=" << param->channels << " wshape=" << wshape; } - CHECK(reporter->AssertEQ(dshape_ncw[1], wshape[1])); + if (!dshape_ncw[1].as() && !wshape[1].as()) { + CHECK(reporter->AssertEQ(dshape_ncw[1], wshape[1])); + } channels = wshape[0]; dilated_ksize = 1 + (wshape[2] - 1) * param->dilation[0]; } @@ -211,7 +213,9 @@ bool Conv2DRel(const Array& types, int num_inputs, const Attrs& attrs, << "Conv2D: shape of weight is inconsistent with channels, " << " channels=" << param->channels << " wshape=" << wshape; } - CHECK(reporter->AssertEQ(indexdiv(dshape_nchw[1], param->groups), wshape[1])); + if (!dshape_nchw[1].as() && !wshape[1].as()) { + CHECK(reporter->AssertEQ(indexdiv(dshape_nchw[1], param->groups), wshape[1])); + } channels = wshape[0]; dilated_ksize_y = 1 + (wshape[2] - 1) * param->dilation[0]; dilated_ksize_x = 1 + (wshape[3] - 1) * param->dilation[1]; @@ -322,7 +326,9 @@ bool Conv3DRel(const Array& types, int num_inputs, const Attrs& attrs, << "Conv3D: shape of weight is inconsistent with channels, " << " channels=" << 
param->channels << " wshape=" << wshape; } - CHECK(reporter->AssertEQ(indexdiv(dshape_ncdhw[1], param->groups), wshape[1])); + if (!dshape_ncdhw[1].as() && !wshape[1].as()) { + CHECK(reporter->AssertEQ(indexdiv(dshape_ncdhw[1], param->groups), wshape[1])); + } channels = wshape[0]; dilated_ksize_z = 1 + (wshape[2] - 1) * param->dilation[0]; dilated_ksize_y = 1 + (wshape[3] - 1) * param->dilation[1]; @@ -800,7 +806,9 @@ bool Conv1DTransposeRel(const Array& types, int num_inputs, const Attrs& a << "Conv1D: shape of weight is inconsistent with channels, " << " channels=" << param->channels << " wshape=" << Array(wshape); } - CHECK(reporter->AssertEQ(indexdiv(dshape_ncw[1], param->groups), wshape[0])); + if (!dshape_ncw[1].as() && !wshape[0].as()) { + CHECK(reporter->AssertEQ(indexdiv(dshape_ncw[1], param->groups), wshape[0])); + } channels = wshape[1]; dilated_ksize_x = 1 + (wshape[2] - 1) * param->dilation[0]; } @@ -808,8 +816,12 @@ bool Conv1DTransposeRel(const Array& types, int num_inputs, const Attrs& a IndexExpr pad_w; GetPaddingWidth(param->padding, &pad_w); Array oshape({dshape_ncw[0], channels, 0}); - oshape.Set(2, (param->strides[0] * (dshape_ncw[2] - 1) + dilated_ksize_x - pad_w + - param->output_padding[0])); + if (!dshape_ncw[1].as()) { + oshape.Set(2, (param->strides[0] * (dshape_ncw[2] - 1) + dilated_ksize_x - pad_w + + param->output_padding[0])); + } else { + oshape.Set(2, dshape_ncw[2]); + } DataType out_dtype = param->out_dtype; if (out_dtype.bits() == 0) { @@ -890,7 +902,9 @@ bool Conv3DTransposeRel(const Array& types, int num_inputs, const Attrs& a << "Conv3D: shape of weight is inconsistent with channels, " << " channels=" << param->channels << " wshape=" << Array(wshape); } - CHECK(reporter->AssertEQ(indexdiv(dshape_ncdhw[1], param->groups), wshape[0])); + if (!dshape_ncdhw[1].as() && !wshape[0].as()) { + CHECK(reporter->AssertEQ(indexdiv(dshape_ncdhw[1], param->groups), wshape[0])); + } channels = wshape[1]; dilated_ksize_d = 1 + (wshape[2] - 1) * param->dilation[0]; dilated_ksize_x = 1 + (wshape[3] - 1) * param->dilation[1]; @@ -901,12 +915,25 @@ bool Conv3DTransposeRel(const Array& types, int num_inputs, const Attrs& a Array oshape({dshape_ncdhw[0], channels, 0, 0, 0}); IndexExpr pad_d, pad_h, pad_w; GetPaddingDepthHeightWidth(param->padding, &pad_d, &pad_h, &pad_w); - oshape.Set(2, (param->strides[0] * (dshape_ncdhw[2] - 1) + dilated_ksize_d - pad_d + - param->output_padding[0])); - oshape.Set(3, (param->strides[1] * (dshape_ncdhw[3] - 1) + dilated_ksize_y - pad_h + - param->output_padding[1])); - oshape.Set(4, (param->strides[2] * (dshape_ncdhw[4] - 1) + dilated_ksize_x - pad_w + - param->output_padding[2])); + + if (!dshape_ncdhw[2].as()) { + oshape.Set(2, (param->strides[0] * (dshape_ncdhw[2] - 1) + dilated_ksize_d - pad_d + + param->output_padding[0])); + } else { + oshape.Set(2, dshape_ncdhw[2]); + } + if (!dshape_ncdhw[3].as()) { + oshape.Set(3, (param->strides[1] * (dshape_ncdhw[3] - 1) + dilated_ksize_x - pad_h + + param->output_padding[1])); + } else { + oshape.Set(3, dshape_ncdhw[3]); + } + if (!dshape_ncdhw[4].as()) { + oshape.Set(4, (param->strides[2] * (dshape_ncdhw[4] - 1) + dilated_ksize_y - pad_w + + param->output_padding[2])); + } else { + oshape.Set(4, dshape_ncdhw[4]); + } DataType out_dtype = param->out_dtype; if (out_dtype.bits() == 0) { @@ -985,7 +1012,9 @@ bool Conv2DTransposeRel(const Array& types, int num_inputs, const Attrs& a << "Conv2D: shape of weight is inconsistent with channels, " << " channels=" << param->channels << " wshape=" << 
Array(wshape); } - CHECK(reporter->AssertEQ(indexdiv(dshape_nchw[1], param->groups), wshape[0])); + if (!dshape_nchw[1].as() && !wshape[0].as()) { + CHECK(reporter->AssertEQ(indexdiv(dshape_nchw[1], param->groups), wshape[0])); + } channels = wshape[1]; dilated_ksize_y = 1 + (wshape[2] - 1) * param->dilation[0]; dilated_ksize_x = 1 + (wshape[3] - 1) * param->dilation[1]; @@ -994,10 +1023,18 @@ bool Conv2DTransposeRel(const Array& types, int num_inputs, const Attrs& a Array oshape({dshape_nchw[0], channels, 0, 0}); IndexExpr pad_h, pad_w; GetPaddingHeightWidth(param->padding, &pad_h, &pad_w); - oshape.Set(2, (param->strides[0] * (dshape_nchw[2] - 1) + dilated_ksize_y - pad_h + - param->output_padding[0])); - oshape.Set(3, (param->strides[1] * (dshape_nchw[3] - 1) + dilated_ksize_x - pad_w + - param->output_padding[1])); + if (!dshape_nchw[2].as()) { + oshape.Set(2, (param->strides[0] * (dshape_nchw[2] - 1) + dilated_ksize_y - pad_h + + param->output_padding[0])); + } else { + oshape.Set(2, dshape_nchw[2]); + } + if (!dshape_nchw[3].as()) { + oshape.Set(3, (param->strides[1] * (dshape_nchw[3] - 1) + dilated_ksize_x - pad_w + + param->output_padding[1])); + } else { + oshape.Set(3, dshape_nchw[3]); + } DataType out_dtype = param->out_dtype; if (out_dtype.bits() == 0) { @@ -1053,7 +1090,9 @@ bool DeformableConv2DRel(const Array& types, int num_inputs, const Attrs& << "DeformableConv2D: shape of weight is inconsistent with channels, " << " channels=" << param->channels << " wshape=" << wshape; } - CHECK(reporter->AssertEQ(indexdiv(data->shape[1], param->groups), wshape[1])); + if (!data->shape[1].as() && !wshape[1].as()) { + CHECK(reporter->AssertEQ(indexdiv(data->shape[1], param->groups), wshape[1])); + } channels = wshape[0]; ksize_y = wshape[2]; ksize_x = wshape[3]; diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 4ce08e3547f9..c0eda419175f 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -42,7 +42,9 @@ def get_input_data_shape_dict(graph_def, input_data): return input_names, shape_dict -def get_tvm_output_with_vm(graph_def, input_data, target, ctx, opset=None, freeze_params=False): +def get_tvm_output_with_vm( + graph_def, input_data, target, ctx, opset=None, freeze_params=False, convert_to_static=False +): """ Generic function to execute and get tvm output with vm executor""" if not isinstance(input_data, list): input_data = [input_data] @@ -51,6 +53,10 @@ def get_tvm_output_with_vm(graph_def, input_data, target, ctx, opset=None, freez mod, params = relay.frontend.from_onnx( graph_def, shape_dict, opset=opset, freeze_params=freeze_params ) + if convert_to_static: + from tvm.relay import transform + + mod = transform.DynamicToStatic()(mod) ex = relay.create_executor("vm", mod=mod, ctx=ctx, target=target) result = ex.evaluate()(*input_data) @@ -122,6 +128,7 @@ def verify_with_ort_with_inputs( use_vm=False, opset=None, freeze_params=False, + convert_to_static=False, dtype="float32", rtol=1e-5, atol=1e-5, @@ -140,10 +147,15 @@ def flatten(out): for target in targets: ctx = tvm.context(target, 0) - if use_vm: tvm_out = get_tvm_output_with_vm( - model, inputs, target, ctx, opset=opset, freeze_params=freeze_params + model, + inputs, + target, + ctx, + opset=opset, + freeze_params=freeze_params, + convert_to_static=convert_to_static, ) else: tvm_out = get_tvm_output(model, inputs, target, ctx, out_shape, dtype, opset=opset) @@ -159,6 +171,7 @@ def verify_with_ort( use_vm=False, 
opset=None, freeze_params=False, + convert_to_static=False, dtype="float32", rtol=1e-5, atol=1e-5, @@ -172,6 +185,7 @@ def verify_with_ort( use_vm=use_vm, opset=opset, freeze_params=freeze_params, + convert_to_static=convert_to_static, dtype=dtype, rtol=rtol, atol=atol, @@ -2470,7 +2484,7 @@ def verify_conv( model = helper.make_model(graph, producer_name="conv_test") - verify_with_ort(model, [x_shape, w_shape], y_shape) + verify_with_ort(model, [x_shape, w_shape], y_shape, use_vm=True, convert_to_static=True) @tvm.testing.uses_gpu @@ -2555,6 +2569,68 @@ def repeat(N, D): ) +def verify_convtranspose_with_padding( + x_shape, + w_shape, + y_shape, + padding, + kernel_shape, + strides, + dilations, + auto_pad="NOTSET", + unset_pad=False, +): + if unset_pad: + node = helper.make_node( + "ConvTranspose", + inputs=["x", "W"], + outputs=["y"], + kernel_shape=kernel_shape, + # Default values for other attributes: + strides=strides, + dilations=dilations, + group=1, + ) + elif padding is None: + node = helper.make_node( + "ConvTranspose", + inputs=["x", "W"], + outputs=["y"], + kernel_shape=kernel_shape, + # Default values for other attributes: + strides=strides, + dilations=dilations, + group=1, + auto_pad=auto_pad, + ) + else: + node = helper.make_node( + "ConvTranspose", + inputs=["x", "W"], + outputs=["y"], + kernel_shape=kernel_shape, + # Default values for other attributes: + strides=strides, + dilations=dilations, + group=1, + pads=padding, + ) + + graph = helper.make_graph( + [node], + "convtranspose_test", + inputs=[ + helper.make_tensor_value_info("x", TensorProto.FLOAT, list(x_shape)), + helper.make_tensor_value_info("W", TensorProto.FLOAT, list(w_shape)), + ], + outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, list(y_shape))], + ) + + model = helper.make_model(graph, producer_name="conv_test") + + verify_with_ort(model, [x_shape, w_shape], y_shape, use_vm=True, convert_to_static=True) + + def verify_convtranspose(x_shape, w_shape, y_shape, p): node = onnx.helper.make_node( "ConvTranspose", @@ -2589,6 +2665,87 @@ def test_convtranspose(): # [1, 2, 1, 2] list for pads verify_convtranspose((1, 1, 3, 3), (1, 2, 3, 3), (1, 2, 7, 3), [1, 2, 1, 2]) + def repeat(N, D): + return tuple([N for _ in range(D)]) + + # TODO(mbrookhart): onnxruntime doesn't support conv3d_transpose, find something else to test against + for D in [1, 2]: + # Convolution with padding + verify_convtranspose_with_padding( + (1, 1) + repeat(5, D), + (1, 1) + repeat(3, D), + (1, 1) + repeat(5, D), + 2 * repeat(1, D), + repeat(3, D), + repeat(1, D), + repeat(1, D), + ) + # Convolution without padding + verify_convtranspose_with_padding( + (1, 1) + repeat(5, D), + (1, 1) + repeat(3, D), + (1, 1) + repeat(7, D), + 2 * repeat(0, D), + repeat(3, D), + repeat(1, D), + repeat(1, D), + ) + # Convolution with autopadding + verify_convtranspose_with_padding( + (1, 1) + repeat(5, D), + (1, 1) + repeat(3, D), + (1, 1) + repeat(5, D), + None, + repeat(3, D), + repeat(1, D), + repeat(1, D), + auto_pad="SAME_UPPER", + ) + # Convolution with valid autopadding + verify_convtranspose_with_padding( + (1, 1) + repeat(5, D), + (1, 1) + repeat(3, D), + (1, 1) + repeat(7, D), + None, + repeat(3, D), + repeat(1, D), + repeat(1, D), + auto_pad="VALID", + ) + # Convolution with unset padding + verify_convtranspose_with_padding( + (1, 1) + repeat(5, D), + (1, 1) + repeat(3, D), + (1, 1) + repeat(7, D), + 2 * repeat(0, D), + repeat(3, D), + repeat(1, D), + repeat(1, D), + True, + ) + # Convolution with non uniform stride + 
verify_convtranspose_with_padding( + (1, 1) + repeat(5, D), + (1, 1) + repeat(3, D), + (1, 1) + repeat(9, D), + None, + repeat(3, D), + repeat(2, D), + repeat(1, D), + auto_pad="SAME_UPPER", + ) + # Convolution with dilation + # TODO(mbrookhart): Relay doesn't currently support convtranspose with dilation + # verify_convtranspose_with_padding( + # (1, 1) + repeat(5, D), + # (1, 1) + repeat(3, D), + # (1, 1) + repeat(5, D), + # 2 * repeat(2, D), + # repeat(3, D), + # repeat(1, D), + # repeat(2, D), + # ) + @tvm.testing.uses_gpu def test_unsqueeze_constant(): @@ -2612,6 +2769,7 @@ def forward(self, input): def verify_pooling(x_shape, kernel_shape, strides, pads, out_shape, mode, auto_pad="NOTSET"): + print(x_shape, kernel_shape, strides, mode, pads, auto_pad) x_np = np.random.uniform(size=x_shape).astype("float32") if mode == "max": @@ -2643,7 +2801,7 @@ def verify_pooling(x_shape, kernel_shape, strides, pads, out_shape, mode, auto_p ) model = helper.make_model(graph, producer_name="pooling_test") - verify_with_ort(model, [x_shape], out_shape) + verify_with_ort(model, [x_shape], out_shape, use_vm=True, convert_to_static=True) @tvm.testing.uses_gpu @@ -2893,7 +3051,7 @@ def verify_lppool(x_shape, kernel_shape, p, strides, pads, out_shape, auto_pad=" ) model = helper.make_model(graph, producer_name="lppool_test") - verify_with_ort(model, [x_shape], out_shape) + verify_with_ort(model, [x_shape], out_shape, use_vm=True, convert_to_static=True) @tvm.testing.uses_gpu From dfeed7fd4ef629f0313e6d6532ef1f5ecaef6cc5 Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Fri, 18 Sep 2020 16:15:05 -0600 Subject: [PATCH 12/17] fix dynamic issues in a few ops --- src/relay/op/nn/nn.cc | 24 ++++++++++++++++++------ src/relay/op/tensor/transform.cc | 4 +++- 2 files changed, 21 insertions(+), 7 deletions(-) diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc index 3f1d2ba9a6da..8ab1ab1be66d 100644 --- a/src/relay/op/nn/nn.cc +++ b/src/relay/op/nn/nn.cc @@ -1031,9 +1031,15 @@ bool DepthToSpaceRel(const Array& types, int num_inputs, const Attrs& attr << " But got " << in_layout; auto oshape = layout_converter.ForwardShape(data->shape); - oshape.Set(1, indexdiv(oshape[1], (block_size * block_size))); - oshape.Set(2, oshape[2] * block_size); - oshape.Set(3, oshape[3] * block_size); + if (!oshape[1].as()) { + oshape.Set(1, indexdiv(oshape[1], (block_size * block_size))); + } + if (!oshape[2].as()) { + oshape.Set(2, oshape[2] * block_size); + } + if (!oshape[3].as()) { + oshape.Set(3, oshape[3] * block_size); + } // Assign output type reporter->Assign(types[1], TensorType(layout_converter.BackwardShape(oshape), data->dtype)); @@ -1088,9 +1094,15 @@ bool SpaceToDepthRel(const Array& types, int num_inputs, const Attrs& attr << " But got " << in_layout; auto oshape = layout_converter.ForwardShape(data->shape); - oshape.Set(1, oshape[1] * (block_size * block_size)); - oshape.Set(2, indexdiv(oshape[2], block_size)); - oshape.Set(3, indexdiv(oshape[3], block_size)); + if (!oshape[1].as()) { + oshape.Set(1, oshape[1] * (block_size * block_size)); + } + if (!oshape[2].as()) { + oshape.Set(2, indexdiv(oshape[2], block_size)); + } + if (!oshape[3].as()) { + oshape.Set(3, indexdiv(oshape[3], block_size)); + } // Assign output type reporter->Assign(types[1], TensorType(layout_converter.BackwardShape(oshape), data->dtype)); diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 3812d4ea2b8a..c9807c39764e 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc 
@@ -1676,7 +1676,9 @@ bool WhereRel(const Array& types, int num_inputs, const Attrs& attrs, const auto* condition = types[0].as(); const auto* x = types[1].as(); const auto* y = types[2].as(); - CHECK(condition != nullptr && x != nullptr && y != nullptr); + if (!(condition != nullptr && x != nullptr && y != nullptr)) { + return false; + } const auto& cond_shape = condition->shape; const auto& x_shape = x->shape; From dd6a1d7a7757f29d92e53a4be8301b85c5169aa2 Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Fri, 18 Sep 2020 16:30:34 -0600 Subject: [PATCH 13/17] fix pylint --- python/tvm/relay/frontend/onnx.py | 26 +++++++++++++++++--------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 0077ae12ac1d..467dd4fce2af 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -17,8 +17,8 @@ # pylint: disable=invalid-name, import-self, len-as-condition, unused-argument, too-many-lines # pylint: disable=import-outside-toplevel """ONNX: Open Neural Network Exchange frontend for Relay.""" -import numpy as np import logging +import numpy as np import tvm from tvm.ir import IRModule @@ -257,9 +257,11 @@ def _impl_v1(cls, inputs, attr, params): pad_tuple = tuple([val for pair in zip(*pad_tuple) for val in pair]) attr["pads"] = pad_tuple else: - logger.warning( - "Performing dynamic autopadding on Pool. Pool kernels don't currently support dynamic shapes." + warning = ( + "Performing dynamic autopadding on Pool. " + + "Pool kernels don't currently support dynamic shapes." ) + logger.warning(warning) data = autopad(data, attr["strides"], attr["kernel_shape"], [1] * ndim, ndim) elif attr["auto_pad"] == "VALID": attr["pads"] = tuple([0 for i in range(ndim - 2)]) @@ -380,9 +382,11 @@ def _impl_v1(cls, inputs, attr, params): if "auto_pad" in attr: attr["auto_pad"] = attr["auto_pad"].decode("utf-8") if attr["auto_pad"] in ("SAME_UPPER", "SAME_LOWER"): - logger.warning( - "Performing dynamic autopadding on Conv. Conv kernels don't currently support dynamic shapes." + warning = ( + "Performing dynamic autopadding on Conv. " + + "Conv kernels don't currently support dynamic shapes." ) + logger.warning(warning) data = autopad(data, attr["strides"], attr["kernel_shape"], attr["dilations"], ndim) elif attr["auto_pad"] == "VALID": attr["pads"] = tuple([0 for i in range(ndim - 2)]) @@ -438,9 +442,11 @@ def _impl_v1(cls, inputs, attr, params): if "auto_pad" in attr: attr["auto_pad"] = attr["auto_pad"].decode("utf-8") if attr["auto_pad"] in ("SAME_UPPER", "SAME_LOWER"): - logger.warning( - "Performing dynamic autopadding on ConvTranspose. ConvTranspose kernels don't currently support dynamic shapes." + warning = ( + "Performing dynamic autopadding on ConvTranspose. " + + "ConvTranspose kernels don't currently support dynamic shapes." ) + logger.warning(warning) data = autopad( data, attr["strides"], @@ -610,9 +616,11 @@ def _impl_v1(cls, inputs, attr, params): if "auto_pad" in attr: attr["auto_pad"] = attr["auto_pad"].decode("utf-8") if attr["auto_pad"] in ("SAME_UPPER", "SAME_LOWER"): - logger.warning( - "Performing dynamic autopadding on LpPool. LpPool kernels don't currently support dynamic shapes." + warning = ( + "Performing dynamic autopadding on LpPool. " + + "LpPool kernels don't currently support dynamic shapes." 
) + logger.warning(warning) data = autopad(data, attr["strides"], attr["kernel_shape"], [1] * ndim, ndim) elif attr["auto_pad"] == "VALID": attr["pads"] = tuple([0 for i in range(ndim - 2)]) From 2faf2aad9a57629917cf0068ee0173171c3babe4 Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Mon, 21 Sep 2020 10:07:46 -0600 Subject: [PATCH 14/17] disable tests onnxrt doesn't support --- tests/python/frontend/onnx/test_forward.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index c0eda419175f..bda0f472148b 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -2668,8 +2668,9 @@ def test_convtranspose(): def repeat(N, D): return tuple([N for _ in range(D)]) - # TODO(mbrookhart): onnxruntime doesn't support conv3d_transpose, find something else to test against - for D in [1, 2]: + # TODO(mbrookhart): onnxruntime in CI only supports 2D, + # find something else to test 1D and 3D against + for D in [2]: # Convolution with padding verify_convtranspose_with_padding( (1, 1) + repeat(5, D), From 78d3ff5a0cf7e663b25900b656a3acb55c0e1385 Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Tue, 22 Sep 2020 16:40:15 -0600 Subject: [PATCH 15/17] fix pytorch test --- src/relay/op/nn/convolution.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/relay/op/nn/convolution.h b/src/relay/op/nn/convolution.h index bd98547ab72b..2311585deb60 100644 --- a/src/relay/op/nn/convolution.h +++ b/src/relay/op/nn/convolution.h @@ -816,7 +816,7 @@ bool Conv1DTransposeRel(const Array& types, int num_inputs, const Attrs& a IndexExpr pad_w; GetPaddingWidth(param->padding, &pad_w); Array oshape({dshape_ncw[0], channels, 0}); - if (!dshape_ncw[1].as()) { + if (!dshape_ncw[2].as()) { oshape.Set(2, (param->strides[0] * (dshape_ncw[2] - 1) + dilated_ksize_x - pad_w + param->output_padding[0])); } else { @@ -923,13 +923,13 @@ bool Conv3DTransposeRel(const Array& types, int num_inputs, const Attrs& a oshape.Set(2, dshape_ncdhw[2]); } if (!dshape_ncdhw[3].as()) { - oshape.Set(3, (param->strides[1] * (dshape_ncdhw[3] - 1) + dilated_ksize_x - pad_h + + oshape.Set(3, (param->strides[1] * (dshape_ncdhw[3] - 1) + dilated_ksize_y - pad_h + param->output_padding[1])); } else { oshape.Set(3, dshape_ncdhw[3]); } if (!dshape_ncdhw[4].as()) { - oshape.Set(4, (param->strides[2] * (dshape_ncdhw[4] - 1) + dilated_ksize_y - pad_w + + oshape.Set(4, (param->strides[2] * (dshape_ncdhw[4] - 1) + dilated_ksize_x - pad_w + param->output_padding[2])); } else { oshape.Set(4, dshape_ncdhw[4]); From 54dd8d53068061277d995a5c97ba0e29c75de7ba Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Fri, 25 Sep 2020 16:07:21 -0600 Subject: [PATCH 16/17] respond to review comments --- python/tvm/autotvm/record.py | 2 +- python/tvm/relay/frontend/onnx.py | 31 ++++++++----------------------- 2 files changed, 9 insertions(+), 24 deletions(-) diff --git a/python/tvm/autotvm/record.py b/python/tvm/autotvm/record.py index 5a4e26b6d2ee..a1b89404b5a1 100644 --- a/python/tvm/autotvm/record.py +++ b/python/tvm/autotvm/record.py @@ -152,7 +152,7 @@ def decode(row, protocol="json"): tgt, task_name, task_args, task_kwargs = row["input"] tgt = str(tgt) if "-target" in tgt: - # logger.warning('"-target" is deprecated, use "-mtriple" instead.') + logger.warning('"-target" is deprecated, use "-mtriple" instead.') tgt = tgt.replace("-target", "-mtriple") tgt = Target(str(tgt)) diff --git 
a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 40ad44961405..a9c7666d4ff4 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -17,7 +17,6 @@ # pylint: disable=invalid-name, import-self, len-as-condition, unused-argument, too-many-lines # pylint: disable=import-outside-toplevel """ONNX: Open Neural Network Exchange frontend for Relay.""" -import logging import numpy as np import tvm from tvm.ir import IRModule @@ -34,8 +33,6 @@ from .common import infer_type, get_name -logger = logging.getLogger("onnx_frontend") - __all__ = ["from_onnx"] @@ -257,11 +254,8 @@ def _impl_v1(cls, inputs, attr, params): pad_tuple = tuple([val for pair in zip(*pad_tuple) for val in pair]) attr["pads"] = pad_tuple else: - warning = ( - "Performing dynamic autopadding on Pool. " - + "Pool kernels don't currently support dynamic shapes." - ) - logger.warning(warning) + # Warning: Pool does not yet support dynamic shapes, + # one will need to run dynamic_to_static on this model after import data = autopad(data, attr["strides"], attr["kernel_shape"], [1] * ndim, ndim) elif attr["auto_pad"] == "VALID": attr["pads"] = tuple([0 for i in range(ndim - 2)]) @@ -382,11 +376,8 @@ def _impl_v1(cls, inputs, attr, params): if "auto_pad" in attr: attr["auto_pad"] = attr["auto_pad"].decode("utf-8") if attr["auto_pad"] in ("SAME_UPPER", "SAME_LOWER"): - warning = ( - "Performing dynamic autopadding on Conv. " - + "Conv kernels don't currently support dynamic shapes." - ) - logger.warning(warning) + # Warning: Convolution does not yet support dynamic shapes, + # one will need to run dynamic_to_static on this model after import data = autopad(data, attr["strides"], attr["kernel_shape"], attr["dilations"], ndim) elif attr["auto_pad"] == "VALID": attr["pads"] = tuple([0 for i in range(ndim - 2)]) @@ -442,11 +433,8 @@ def _impl_v1(cls, inputs, attr, params): if "auto_pad" in attr: attr["auto_pad"] = attr["auto_pad"].decode("utf-8") if attr["auto_pad"] in ("SAME_UPPER", "SAME_LOWER"): - warning = ( - "Performing dynamic autopadding on ConvTranspose. " - + "ConvTranspose kernels don't currently support dynamic shapes." - ) - logger.warning(warning) + # Warning: Convolution does not yet support dynamic shapes, + # one will need to run dynamic_to_static on this model after import data = autopad( data, attr["strides"], @@ -616,11 +604,8 @@ def _impl_v1(cls, inputs, attr, params): if "auto_pad" in attr: attr["auto_pad"] = attr["auto_pad"].decode("utf-8") if attr["auto_pad"] in ("SAME_UPPER", "SAME_LOWER"): - warning = ( - "Performing dynamic autopadding on LpPool. " - + "LpPool kernels don't currently support dynamic shapes." 
- ) - logger.warning(warning) + # Warning: LpPool does not yet support dynamic shapes, + # one will need to run dynamic_to_static on this model after import data = autopad(data, attr["strides"], attr["kernel_shape"], [1] * ndim, ndim) elif attr["auto_pad"] == "VALID": attr["pads"] = tuple([0 for i in range(ndim - 2)]) From 24a9e22cafba23763079c500078a5ff6916f3a42 Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Fri, 2 Oct 2020 12:27:27 -0600 Subject: [PATCH 17/17] add documentation about partially supporting dynamic shapes --- python/tvm/relay/frontend/onnx.py | 7 +++++++ tutorials/frontend/from_onnx.py | 9 +++++++++ 2 files changed, 16 insertions(+) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index a9c7666d4ff4..59fdb32d1a16 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -2409,6 +2409,13 @@ def from_onnx(model, shape=None, dtype="float32", opset=None, freeze_params=Fals For convenience, we rename the `real` input names to "input_0", "input_1"... And renaming parameters to "param_0", "param_1"... + By default, ONNX defines models in terms of dynamic shapes. The ONNX importer + retains that dynamism upon import, and the compiler attempts to convert the + model into a static shapes at compile time. If this fails, there may still + be dynamic operations in the model. Not all TVM kernels currently support + dynamic shapes, please file an issue on discuss.tvm.ai + if you hit an error with dynamic kernels. + Parameters ---------- model : protobuf object diff --git a/tutorials/frontend/from_onnx.py b/tutorials/frontend/from_onnx.py index e68a398e44b0..22c839cede12 100644 --- a/tutorials/frontend/from_onnx.py +++ b/tutorials/frontend/from_onnx.py @@ -103,3 +103,12 @@ canvas[:, 672:, :] = np.asarray(result) plt.imshow(canvas.astype(np.uint8)) plt.show() + +###################################################################### +# Notes +# --------------------------------------------- +# By default, ONNX defines models in terms of dynamic shapes. The ONNX importer +# retains that dynamism upon import, and the compiler attemps to convert the model +# into a static shapes at compile time. If this fails, there may still be dynamic +# operations in the model. Not all TVM kernels currently support dynamic shapes, +# please file an issue on discuss.tvm.ai if you hit an error with dynamic kernels.
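
The final patch documents the intended workflow: import with frozen parameters, then convert dynamic ops to static before compiling. Below is a minimal end-to-end sketch of that workflow, not part of the patches themselves; the model path, input name, input shape, and llvm/CPU target are placeholders, and only APIs exercised elsewhere in this series (from_onnx with freeze_params, DynamicToStatic, and the VM executor) are used.

import numpy as np
import onnx
import tvm
from tvm import relay

# Placeholder model path, input name, and shape; substitute a real model.
onnx_model = onnx.load("model.onnx")
shape_dict = {"input": (1, 3, 224, 224)}

# freeze_params=True embeds the graph initializers as Relay constants so that
# shape-producing subgraphs can later be constant folded.
mod, params = relay.frontend.from_onnx(onnx_model, shape_dict, freeze_params=True)

# Replace dynamic ops whose inputs are now constant with their static forms,
# the same pass relay.build runs internally.
mod = relay.transform.DynamicToStatic()(mod)

# Any shapes that remain dynamic require the VM executor rather than the graph executor.
ex = relay.create_executor("vm", mod=mod, ctx=tvm.cpu(0), target="llvm")
data = np.random.uniform(size=shape_dict["input"]).astype("float32")
result = ex.evaluate()(data).asnumpy()

With the parameters frozen into the module, only the model input is passed at run time, matching the VM test helper used throughout the test changes above.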
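
Patch 11 builds SAME-style padding for Conv, ConvTranspose, and the pooling ops as Relay expressions over the runtime shape. The arithmetic is easier to check on concrete numbers, so the sketch below is a NumPy rendering of the autopad computation (mod, maximum, where, floor divide); the helper name is illustrative, not importer code.

import numpy as np

def same_pad_amounts(spatial_shape, kernel_shape, strides, dilations, deconv=False):
    """NumPy mirror of the autopad arithmetic: (pad_before, pad_after) per spatial axis."""
    shape = np.asarray(spatial_shape, dtype="int64")
    strides = np.asarray(strides, dtype="int64")
    dilated_k = np.asarray(
        [(k - 1) * d + 1 for k, d in zip(kernel_shape, dilations)], dtype="int64"
    )

    remainder = shape % strides
    # If the stride divides the axis evenly, pad up to dilated_kernel - stride;
    # otherwise pad up to dilated_kernel - remainder.
    total_pad = np.where(
        remainder == 0,
        np.maximum(dilated_k - strides, 0),
        np.maximum(dilated_k - remainder, 0),
    )
    if deconv:
        # ConvTranspose inverts the padding relationship.
        total_pad = np.asarray(kernel_shape, dtype="int64") - 1 - total_pad

    pad_before = total_pad // 2
    pad_after = total_pad - pad_before
    return pad_before, pad_after

# A 5x5 input with a 3x3 kernel and stride 1 needs one cell of padding on every side:
# pad_before = [1 1], pad_after = [1 1]
print(same_pad_amounts((5, 5), (3, 3), (1, 1), (1, 1)))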
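
Patch 11 also registers a shape function so conv1d/conv2d/conv3d can type-check with dynamically shaped data. The hybrid script computes the standard convolution output size per spatial axis; the plain-Python sketch below shows that formula, assuming one symmetric padding value per spatial axis as the script does, with an illustrative function name.

def conv_out_shape(dshape, kshape, strides, padding, dilation):
    """Plain-Python mirror of the conv shape function (NCHW-style layouts)."""
    out = [dshape[0], kshape[0]]  # batch size, output channels
    for i in range(len(dshape) - 2):
        dilated_k = (kshape[i + 2] - 1) * dilation[i] + 1
        out.append((dshape[i + 2] + 2 * padding[i] - dilated_k) // strides[i] + 1)
    return out

# (1, 3, 224, 224) data with 64 3x3 filters, stride 1, pad 1 -> [1, 64, 224, 224]
print(conv_out_shape((1, 3, 224, 224), (64, 3, 3, 3), (1, 1), (1, 1), (1, 1)))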
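
Patch 06 rewrites the Slice-10 conversion to scatter the provided starts, ends, and steps onto full-rank defaults indexed by axes, so strided_slice always receives one entry per dimension. The NumPy rendering below shows that normalization on an example taken from the slice tests; the helper name is illustrative.

import numpy as np

def normalize_slice_args(data_shape, starts, ends, axes=None, steps=None):
    """Scatter partial Slice inputs onto full-rank defaults, one entry per axis."""
    rank = len(data_shape)
    axes = list(range(rank)) if axes is None else list(axes)

    full_starts = np.zeros(rank, dtype="int64")        # default start: 0
    full_ends = np.array(data_shape, dtype="int64")    # default end: the axis length
    full_steps = np.ones(rank, dtype="int64")          # default step: 1

    full_starts[axes] = starts
    full_ends[axes] = ends
    if steps is not None:
        full_steps[axes] = steps
    return full_starts, full_ends, full_steps

# Slicing a (20, 10, 5) tensor to x[0:3, 0:10] only specifies axes 0 and 1;
# axis 2 keeps its defaults, so the result is equivalent to x[0:3, 0:10, 0:5].
print(normalize_slice_args((20, 10, 5), starts=[0, 0], ends=[3, 10], axes=[0, 1]))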