From 156aa5900ab04d9176cd333bb7d1ce10dce19faa Mon Sep 17 00:00:00 2001
From: Josh Fromm
Date: Wed, 30 Oct 2019 11:24:47 -0700
Subject: [PATCH] [Relay][Frontend][ONNX] New Operators and Opsets to Support
 BERT (#4197)

* Added slice v10
* Added constantofshape operation and small refactor.
* Finished one_hot implementation.
* Reshape working across all bert layers.
* Fixed constantofshape and removed code duplication.
* onnx model fully ingested.
* Working on improving onnx tests.
* Changed onnx testing to use onnxruntime instead of caffe2, also formatted.
* Add arbitrary output nodes to onnx frontend.
* Added v6 tiling for bert squad 8 support.
* Small syntax fixes
* Reduced code duplication in split opset versions.
* Added batch matmul test
* Added unstack split testing.
* Added onehot test, needs a little cleanup probably.
* Replaced deprecated constant fill with constantofshape and updated tests accordingly.
* Added tests for new opset version of slice and tile.
* lint clean up
* Lint fixes
* Changed onnx dependency
* Went back to caffe2 runtime for CI integration.
* Rebase and small typo/syntax changes.
* Added hard casting of onehot attributes to int.
---
 python/tvm/relay/frontend/common.py        |  46 ++
 python/tvm/relay/frontend/onnx.py          | 236 +++---
 python/tvm/relay/frontend/tensorflow.py    |  14 +-
 tests/python/frontend/onnx/test_forward.py | 829 ++++++++++++++-------
 4 files changed, 744 insertions(+), 381 deletions(-)

diff --git a/python/tvm/relay/frontend/common.py b/python/tvm/relay/frontend/common.py
index d4b9162d6f3d..25ba0ef31d72 100644
--- a/python/tvm/relay/frontend/common.py
+++ b/python/tvm/relay/frontend/common.py
@@ -19,11 +19,13 @@
 import logging
 
 import tvm
+import numpy as np
 from topi.util import get_const_tuple
 from .. import expr as _expr
 from .. import module as _module
 from .. import transform as _transform
 from .. import op as _op
+from .. import analysis
 
 
 class RequiredAttr(object):
@@ -474,6 +476,50 @@ def infer_channels(inputs, transpose=False):
     return channels
 
 
+def infer_value(input_val, params):
+    """A hack for getting the value of an expression by evaluating a
+    portion of the relay graph. This is often needed for functions
+    whose output shape depends on the value of a tensor.
+    """
+    from tvm.contrib import graph_runtime
+    # Check that all free variables have associated parameters.
+    assert all(var.name_hint in params.keys() for var in analysis.free_vars(
+        input_val)), "All inputs to infer must be available in params."
+    func = _expr.Function(analysis.free_vars(input_val), input_val)
+    with tvm.relay.build_config(opt_level=0):
+        graph, lib, params = tvm.relay.build(func, target="llvm", params=params)
+    ctx = tvm.cpu(0)
+    m = graph_runtime.create(graph, lib, ctx)
+    m.set_input(**params)
+    m.run()
+    return m.get_output(0)
+
+
+def infer_value_simulated(input_val, params):
+    """Extension to infer_value that can be used when some input
+    values are missing. This function creates dummy inputs with the same
+    shape and random values, then calls infer_value. This is helpful when
+    implementing certain ONNX operators where we need to evaluate the graph
+    to determine a static shape.
+    """
+    fake_params = []
+    # Add a fake copy of all missing params.
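+    # Any free variable missing from params gets a stand-in tensor built
+    # from its type annotation; the values are random, so the result is
+    # only meaningful when the caller needs shapes rather than numerics.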
+ for free_param in analysis.free_vars(input_val): + if free_param.name_hint not in params: + fp_dtype = free_param.type_annotation.dtype + fp_shape = [s.value for s in free_param.type_annotation.shape] + fake_params.append(free_param) + params[free_param.name_hint] = tvm.nd.array( + np.random.rand(*fp_shape).astype(fp_dtype) + ) + # Now infer the value. + output_value = infer_value(input_val, params) + # Clean fake params out of param dictionary. + for fake_p in fake_params: + params.pop(fake_p.name_hint, None) + return output_value + + def new_var(name_hint, type_annotation=None, shape=None, diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 1d74a01b1860..41fafbc55405 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -18,20 +18,30 @@ """ONNX: Open Neural Network Exchange frontend for Relay.""" from __future__ import absolute_import as _abs -import logging import numpy as np import tvm from ... import nd as _nd from .. import analysis -from .. import transform as _transform from .. import expr as _expr from .. import module as _module from .. import op as _op from .common import AttrCvt, Renamer -from .common import get_relay_op, new_var, infer_shape, infer_channels, get_name +from .common import get_relay_op, new_var, infer_shape, infer_channels +from .common import infer_type, infer_value, infer_value_simulated, get_name __all__ = ['from_onnx'] + +def get_numpy(tensor_proto): + """Grab data in TensorProto and convert to numpy array.""" + try: + from onnx.numpy_helper import to_array + except ImportError as e: + raise ImportError( + "Unable to import onnx which is required {}".format(e)) + return to_array(tensor_proto) + + def dimension_picker(prefix, surfix=''): def _impl(attr): kernel = attr['kernel_shape'] @@ -43,6 +53,7 @@ def _impl(attr): return _impl + def revert_caffe2_pad(pads): """Caffe2 requires two times the normal padding.""" if len(pads) == 4: @@ -279,6 +290,21 @@ class MatMul(OnnxOpConverter): @classmethod def _impl_v1(cls, inputs, attr, params): assert len(inputs) == 2, "MatMul op take 2 inputs, {} given".format(len(inputs)) + # Need to check input shape as batch matmul must be supported. + a_shape = infer_shape(inputs[0]) + # When performing a batch matmul, we need to properly handle N-dim shapes. + if len(a_shape) > 2: + b_shape = infer_shape(inputs[1]) + # Convert a and b into 3 dimensional tensors. + a = _op.reshape(inputs[0], [-1, a_shape[-2], a_shape[-1]]) + b = _op.reshape(inputs[1], [-1, b_shape[-2], b_shape[-1]]) + # Transpose matrix dimensions of b. + b = _op.transpose(b, [0, 2, 1]) + # Perform a batch matmul. + output = _op.nn.batch_matmul(a, b) + # Reshape output to original dimensions. + return _op.reshape(output, [*a_shape[:-2], a_shape[-2], b_shape[-1]]) + # Otherwise a simple dense op will get the job done. 
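+        # (nn.dense computes data * transpose(weight), so the second input
+        # is transposed first to recover a plain matrix multiply.)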
         input_1_t = _op.transpose(inputs[1], axes=(1, 0))
         return _op.nn.dense(inputs[0], input_1_t)
 
@@ -426,35 +452,18 @@ class Reshape(OnnxOpConverter):
 
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
-        if 'shape' in attr:
-            return _op.reshape(inputs[0], attr['shape'])
+        return _op.reshape(inputs[0], attr['shape'])
 
+    @classmethod
+    def _impl_v5(cls, inputs, attr, params):
         if get_name(inputs[1]) in params:
             shape = tuple(params[inputs[1].name_hint].asnumpy())
             out = _op.reshape(inputs[0], shape)
         else:
             data, shape = inputs
-            logging.warning("Constant evaluating Reshape's shape argument, may reduce performance")
-            shape_params = analysis.free_vars(shape)
-            func = _expr.Function(shape_params, shape)
-            mod = _module.Module.from_expr(func)
-            seq = _transform.Sequential([_transform.InferType(),
-                                         _transform.FoldConstant(),
-                                         _transform.FuseOps(0),
-                                         _transform.InferType()])
-            with tvm.relay.PassContext(opt_level=2):
-                mod = seq(mod)
-            with tvm.relay.build_config(opt_level=0):
-                ex = tvm.relay.create_executor("debug", mod=mod)
-            inputs = []
-            for sp in shape_params:
-                if not sp.name_hint in params:
-                    sh = [int(i) for i in sp.type_annotation.shape]
-                    inputs.append(
-                        tvm.nd.array(np.random.rand(*sh).astype('float32')))
-            static_shape = ex.evaluate()(*inputs, **params)
-            out = _op.reshape(data, newshape=tuple(static_shape.asnumpy()))
-
+            static_shape = infer_value_simulated(shape, params)
+            out = _op.reshape(data, newshape=tuple(
+                static_shape.asnumpy().astype('int32')))
         return out
 
 class Concat(OnnxOpConverter):
@@ -640,11 +649,17 @@ class Split(OnnxOpConverter):
 
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
-        attr['indices_or_sections'] = []
-        index = 0
-        for i in attr['split'][:-1]:
-            index += i
-            attr['indices_or_sections'].append(index)
+        splits = attr.get('split', False)
+        if splits:
+            attr['indices_or_sections'] = []
+            index = 0
+            for i in splits[:-1]:
+                index += i
+                attr['indices_or_sections'].append(index)
+        # When splits isn't specified, divide evenly over axis.
+        else:
+            in_shape = infer_shape(inputs[0])
+            attr['indices_or_sections'] = in_shape[attr['axis']]
         return AttrCvt(
             'split',
             ignores=['split'])(inputs, attr, params)
@@ -653,6 +668,25 @@ def _impl_v1(cls, inputs, attr, params):
 class Slice(OnnxOpConverter):
     """ Operator converter for Slice.
     """
+
+    @classmethod
+    def _common(cls, starts, ends, axes):
+        new_axes = []
+        new_starts = []
+        new_ends = []
+        pop_index = 0
+        for i in range(max(axes) + 1):
+            if i in axes:
+                new_axes.append(i)
+                new_starts.append(starts[pop_index])
+                new_ends.append(ends[pop_index])
+                pop_index += 1
+            else:
+                new_axes.append(i)
+                new_starts.append(0)
+                new_ends.append(np.iinfo(np.int32).max)
+        return new_starts, new_ends, new_axes
+
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
         if isinstance(attr['starts'], int):
@@ -663,22 +697,9 @@ def _impl_v1(cls, inputs, attr, params):
 
         # Update the starts and ends according to axes if required.
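+        # (strided_slice expects a start/end pair for every axis up to the
+        # largest axis referenced, so cls._common pads any skipped axes
+        # with full-range defaults of 0 to INT32_MAX.)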
         if isinstance(attr['axes'], int):
             attr['axes'] = (attr['axes'],)
         if (max(attr['axes']) + 1) != len(attr['axes']):
-            new_axes = []
-            new_starts = []
-            new_ends = []
-            pop_index = 0
-            for i in range(max(attr['axes']) + 1):
-                if i in attr['axes']:
-                    new_axes.append(i)
-                    new_starts.append(attr['starts'][pop_index])
-                    new_ends.append(attr['ends'][pop_index])
-                    pop_index += 1
-                else:
-                    new_axes.append(i)
-                    new_starts.append(0)
-                    new_ends.append(np.iinfo(np.int32).max)
+            new_starts, new_ends, new_axes = cls._common(
+                attr['starts'], attr['ends'], attr['axes'])
             attr['axes'] = new_axes
             attr['starts'] = new_starts
             attr['ends'] = new_ends
@@ -690,6 +711,23 @@ def _impl_v1(cls, inputs, attr, params):
                                    'ends': 'end'},
                        ignores=['axes'])(inputs, attr)
 
+    @classmethod
+    def _impl_v10(cls, inputs, attr, params):
+        starts = params[get_name(inputs[1])].asnumpy()
+        ends = params[get_name(inputs[2])].asnumpy()
+
+        # Update the starts and ends according to axes if required.
+        if len(inputs) >= 4:
+            axes = params[get_name(inputs[3])].asnumpy()
+
+            if max(axes + 1) != len(axes):
+                new_starts, new_ends, _ = cls._common(
+                    starts, ends, axes)
+                starts = new_starts
+                ends = new_ends
+        return _op.strided_slice(inputs[0], begin=starts, end=ends)
+
+
 class Gather(OnnxOpConverter):
     """ Operator converter for Gather.
     """
@@ -698,7 +736,6 @@ def _impl_v1(cls, inputs, attr, params):
         axis = attr.get('axis', 0)
         return AttrCvt('take', extras={'axis':axis})(inputs, {})
-        #return _op.take(inputs[0], inputs[1], axis)
 
 
 class Greater(OnnxOpConverter):
@@ -848,33 +885,49 @@ def _impl_v1(cls, inputs, attr, params):
             attr['axis'] = 1
         return AttrCvt('softmax', transforms={'axis': ('axis', 1)})(inputs, attr, params)
 
-class ConstantFill(OnnxOpConverter):
-    """ Operator converter for ConstantFill.
+
+class OneHot(OnnxOpConverter):
+    """ Operator converter for OneHot.
     """
     @classmethod
-    def _impl_v1(cls, inputs, attr, params):
-        num_inputs = len(inputs)
-        if 'shape' in attr:
-            if num_inputs > 1:
-                raise ImportError(
-                    "Can't set shape and input tensor at a time")
-            shape = attr.pop('shape')
+    def _impl_v9(cls, inputs, attr, params):
+        # Extract relay one_hot inputs.
+        indices, depth, values = inputs
+        # Split the ONNX on/off values into two separate expressions.
+        off_value, on_value = _op.take(
+            values, _op.const(0)), _op.take(values, _op.const(1))
+        # Extract the datatype of the output from on_value.
+        dtype = infer_type(on_value).checked_type.dtype
+        # Convert depth into an integer.
+        depth = int(infer_value(depth, params).asnumpy()[0])
+        # Set the default value when axis is not set in the model.
+        if 'axis' not in attr:
+            attr['axis'] = -1
+        return _op.one_hot(indices,
+                           on_value,
+                           off_value,
+                           depth,
+                           int(attr['axis']),
+                           dtype=dtype)
+
+
+class ConstantOfShape(OnnxOpConverter):
+    """ Operator converter for ConstantOfShape.
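+
+    The target shape is taken from the value of the first input, which must
+    be computable at import time; the output is filled with the tensor held
+    in the 'value' attribute, defaulting to a float32 zero when the
+    attribute is absent.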
+ """ + @classmethod + def _impl_v9(cls, inputs, attr, params): + if 'value' in attr: + np_value = get_numpy(attr.pop('value'))[0] + value = _expr.const(np_value) + dtype = np_value.dtype.name else: - if num_inputs == 1: - raise ImportError( - "Either shape attribute or input should be set") - if 'input_as_shape' in attr and attr['input_as_shape']: - shape = params[get_name(inputs[0])].asnumpy() - else: - if 'extra_shape' in attr: - raise tvm.error.OpAttributeInvalid('Attribute "extra_shape" not ' - 'supported with "fill_like" for ' - 'operator ConstantFill.') - return _op.full_like(inputs[0], inputs[1]) + value = _expr.const(0) + dtype = 'float32' + static_shape = infer_value_simulated(inputs[0], params) + output = _op.full( + value, shape=tuple(static_shape.asnumpy().astype('int32')), dtype=dtype) + return output - if 'extra_shape' in attr: - shape = shape + attr.pop('extra_shape') - return _op.full(inputs[0], shape) class Sign(OnnxOpConverter): """ Operator converter for Sign. @@ -916,6 +969,12 @@ def _impl_v1(cls, inputs, attr, params): reps = attr.pop('repeats') # The number of times repeating the tensor data. return _op.tile(inputs[0], reps) + @classmethod + def _impl_v6(cls, inputs, attr, params): + reps = tuple(infer_value_simulated( + inputs[1], params).asnumpy().astype('int32')) + return _op.tile(inputs[0], reps) + class Erf(OnnxOpConverter): """Operator converter for Erf """ @@ -948,7 +1007,7 @@ def _get_convert_map(opset): 'ThresholdedRelu': ThresholdedRelu.get_converter(opset), 'ScaledTanh': ScaledTanh.get_converter(opset), 'ParametricSoftplus': ParametricSoftPlus.get_converter(opset), - 'ConstantFill': ConstantFill.get_converter(opset), + 'ConstantOfShape': ConstantOfShape.get_converter(opset), # 'GivenTensorFill' 'FC': AttrCvt('dense', ignores=['axis', 'axis_w']), 'Scale': Scale.get_converter(opset), @@ -958,7 +1017,7 @@ def _get_convert_map(opset): # 'MeanVarianceNormalization' # 'Crop' # 'Embedding' - 'Upsample' : Upsample.get_converter(opset), + 'Upsample': Upsample.get_converter(opset), 'SpatialBN': BatchNorm.get_converter(opset), # defs/generator @@ -1002,6 +1061,7 @@ def _get_convert_map(opset): # softmax default axis is different in onnx 'Softmax': Softmax.get_converter(opset), 'LogSoftmax': AttrCvt('log_softmax', {'axis': ('axis', 1)}), + 'OneHot': OneHot.get_converter(opset), # 'Hardmax' 'Softsign': Softsign.get_converter(opset), 'SoftPlus': SoftPlus.get_converter(opset), @@ -1164,14 +1224,6 @@ def from_onnx(self, graph, opset): shape=list(t_proto.dims), dtype=array.dtype) else: - if op_name == "ConstantFill": - fill_value = attr.get('value', 0.0) - dtype = attr.get('dtype', b'int32').decode("utf-8") - i_name = node.output[0] - self._params[i_name] = fill_value - self._nodes[i_name] = new_var(node.output[0], shape=(), dtype=dtype) - inputs.append(self._nodes[i_name]) - i_name = self._parse_value_proto(node) attr['tvm_custom'] = {} attr['tvm_custom']['name'] = i_name @@ -1214,13 +1266,7 @@ def _parse_dtype(self, value_proto, dtype): return dtype def _parse_array(self, tensor_proto): - """Grab data in TensorProto and convert to numpy array.""" - try: - from onnx.numpy_helper import to_array - except ImportError as e: - raise ImportError( - "Unable to import onnx which is required {}".format(e)) - np_array = to_array(tensor_proto).reshape(tuple(tensor_proto.dims)) + np_array = get_numpy(tensor_proto).reshape(tuple(tensor_proto.dims)) return _nd.array(np_array) def _parse_attr(self, attr_proto): @@ -1301,7 +1347,8 @@ def _fix_outputs(self, op_name, outputs): def 
from_onnx(model,
               shape=None,
-              dtype="float32"):
+              dtype="float32",
+              opset=None):
     """Convert an ONNX model into an equivalent Relay Function.
 
     ONNX graphs are represented as Python Protobuf objects.
@@ -1322,6 +1369,10 @@ def from_onnx(model,
     dtype : str or dict of str to str
         The input types to the graph
 
+    opset : int, optional
+        Override the autodetected opset.
+        This can be helpful for some testing.
+
     Returns
     -------
     mod : tvm.relay.Module
@@ -1344,9 +1395,10 @@ def from_onnx(model,
         pass
     g = GraphProto(shape, dtype)
     graph = model.graph
-    try:
-        opset = model.opset_import[0].version if model.opset_import else 1
-    except AttributeError:
-        opset = 1
+    if opset is None:
+        try:
+            opset = model.opset_import[0].version if model.opset_import else 1
+        except AttributeError:
+            opset = 1
     mod, params = g.from_onnx(graph, opset)
     return mod, params
diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py
index bfa3431ba29e..2ef8d15fe291 100644
--- a/python/tvm/relay/frontend/tensorflow.py
+++ b/python/tvm/relay/frontend/tensorflow.py
@@ -39,22 +39,10 @@
 from .common import infer_type as _infer_type
 from .common import infer_shape as _infer_shape
 from .common import infer_channels as _infer_channels
+from .common import infer_value as _infer_value
 
 __all__ = ['from_tensorflow']
 
-def _infer_value(input_val, params):
-    from tvm.contrib import graph_runtime
-    # Check that all free variables have associated parameters.
-    assert all(var.name_hint in params.keys() for var in analysis.free_vars(
-        input_val)), "All inputs to infer must be available in params."
-    func = _expr.Function(analysis.free_vars(input_val), input_val)
-    with tvm.relay.build_config(opt_level=0):
-        graph, lib, params = tvm.relay.build(func, target="llvm", params=params)
-    ctx = tvm.context("llvm", 0)
-    m = graph_runtime.create(graph, lib, ctx)
-    m.set_input(**params)
-    m.run()
-    return m.get_output(0)
 
 def _get_pad_pair(input1d, kernel1d, stride1d):
     if input1d % stride1d == 0:
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index 3d1262f436bb..2d2265b57b95 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -14,7 +14,6 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
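+# Several tests below pin a specific opset by passing `opset=` through
+# get_tvm_output to relay.frontend.from_onnx, overriding autodetection.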
-import attr import numpy as np import math import torch @@ -26,11 +25,11 @@ from tvm.contrib import graph_runtime from nnvm.testing.config import ctx_list import onnx -from onnx import helper, TensorProto -import unittest +from onnx import helper, TensorProto, mapping import scipy -def get_tvm_output(graph_def, input_data, target, ctx, output_shape=None, output_dtype='float32'): + +def get_tvm_output(graph_def, input_data, target, ctx, output_shape=None, output_dtype='float32', opset=None): """ Generic function to execute and get tvm output""" target = 'llvm' if isinstance(input_data, list): @@ -46,21 +45,22 @@ def get_tvm_output(graph_def, input_data, target, ctx, output_shape=None, output shape_dict = {input_names: input_data.shape} dtype_dict = {input_names: input_data.dtype} - mod, params = relay.frontend.from_onnx(graph_def, shape_dict) + mod, params = relay.frontend.from_onnx(graph_def, shape_dict, opset=opset) with relay.build_config(opt_level=1): graph, lib, params = relay.build(mod, target, params=params) ctx = tvm.cpu(0) - from tvm.contrib import graph_runtime m = graph_runtime.create(graph, lib, ctx) # set inputs if isinstance(input_data, list): for i, e in enumerate(input_names): - m.set_input(input_names[i], tvm.nd.array(input_data[i].astype(input_data[i].dtype))) + m.set_input(input_names[i], tvm.nd.array( + input_data[i].astype(input_data[i].dtype))) else: - m.set_input(input_names, tvm.nd.array(input_data.astype(input_data.dtype))) + m.set_input(input_names, tvm.nd.array( + input_data.astype(input_data.dtype))) m.set_input(**params) # execute @@ -76,6 +76,7 @@ def get_tvm_output(graph_def, input_data, target, ctx, output_shape=None, output tvm_output = m.get_output(0) return tvm_output.asnumpy() + def get_caffe2_output(model, x, dtype='float32'): import caffe2.python.onnx.backend prepared_backend = caffe2.python.onnx.backend.prepare(model) @@ -93,15 +94,20 @@ def verify_onnx_forward_impl(graph_file, data_shape, out_shape): tvm_out = get_tvm_output(model, x, target, ctx, out_shape, dtype) tvm.testing.assert_allclose(c2_out, tvm_out, rtol=1e-5, atol=1e-5) + def verify_super_resolution_example(): - verify_onnx_forward_impl(super_resolution, (1, 1, 224, 224), (1, 1, 672, 672)) + verify_onnx_forward_impl( + super_resolution, (1, 1, 224, 224), (1, 1, 672, 672)) + def verify_squeezenet1_1(): verify_onnx_forward_impl(squeezenet1_1, (1, 3, 224, 224), (1, 1000)) + def verify_lenet(): verify_onnx_forward_impl(lenet, (1, 1, 28, 28), (1, 10)) + def verify_resnet18(): verify_onnx_forward_impl(resnet18_1_0, (1, 3, 224, 224), (1, 1000)) @@ -112,20 +118,20 @@ def test_reshape(): ref_array = np.array(ref_shape) ref_node = onnx.helper.make_node('Constant', - inputs=[], - outputs=['ref_in'], - value=onnx.helper.make_tensor(name = 'const_tensor', - data_type = onnx.TensorProto.INT32, - dims = ref_array.shape, - vals = ref_array.flatten().astype(int))) + inputs=[], + outputs=['ref_in'], + value=onnx.helper.make_tensor(name='const_tensor', + data_type=onnx.TensorProto.INT32, + dims=ref_array.shape, + vals=ref_array.flatten().astype(int))) reshape_node = helper.make_node("Reshape", ["in", "ref_in"], ["out"]) graph = helper.make_graph([ref_node, reshape_node], "reshape_test", - inputs = [helper.make_tensor_value_info("in", - TensorProto.FLOAT, list(in_shape))], - outputs = [helper.make_tensor_value_info("out", - TensorProto.FLOAT, list(ref_shape))]) + inputs=[helper.make_tensor_value_info("in", + TensorProto.FLOAT, list(in_shape))], + outputs=[helper.make_tensor_value_info("out", + 
TensorProto.FLOAT, list(ref_shape))]) model = helper.make_model(graph, producer_name='reshape_test') @@ -135,28 +141,29 @@ def test_reshape(): tvm.testing.assert_allclose(ref_shape, tvm_out.shape) + def test_shape(): in_shape = (4, 3, 3, 4) ref_shape = (6, 2, 4, 3) ref_array = np.array(ref_shape) ref_node = onnx.helper.make_node('Constant', - inputs=[], - outputs=['ref_in'], - value=onnx.helper.make_tensor(name = 'const_tensor', - data_type = onnx.TensorProto.INT32, - dims = ref_array.shape, - vals = ref_array.flatten().astype(int))) + inputs=[], + outputs=['ref_in'], + value=onnx.helper.make_tensor(name='const_tensor', + data_type=onnx.TensorProto.INT32, + dims=ref_array.shape, + vals=ref_array.flatten().astype(int))) reshape_node = helper.make_node("Reshape", ["in", "ref_in"], ["out"]) shape_node = helper.make_node("Shape", ['out'], ['final_out']) graph = helper.make_graph([ref_node, reshape_node, shape_node], "shape_test", - inputs = [helper.make_tensor_value_info("in", - TensorProto.FLOAT, list(in_shape))], - outputs = [helper.make_tensor_value_info("final_out", - TensorProto.FLOAT, list(ref_shape))]) + inputs=[helper.make_tensor_value_info("in", + TensorProto.FLOAT, list(in_shape))], + outputs=[helper.make_tensor_value_info("final_out", + TensorProto.FLOAT, list(ref_shape))]) model = helper.make_model(graph, producer_name='shape_test') @@ -166,6 +173,7 @@ def test_shape(): tvm.testing.assert_allclose(ref_shape, tvm_out) + def _test_power_iteration(x_shape, y_shape): if isinstance(y_shape, int): y_shape = [y_shape] @@ -179,12 +187,12 @@ def _test_power_iteration(x_shape, y_shape): graph = helper.make_graph([res], 'power_test', - inputs = [helper.make_tensor_value_info("x", - TensorProto.FLOAT, list(x_shape)), - helper.make_tensor_value_info("y", - TensorProto.FLOAT, list(y_shape))], - outputs = [helper.make_tensor_value_info("out", - TensorProto.FLOAT, list(np_res.shape))]) + inputs=[helper.make_tensor_value_info("x", + TensorProto.FLOAT, list(x_shape)), + helper.make_tensor_value_info("y", + TensorProto.FLOAT, list(y_shape))], + outputs=[helper.make_tensor_value_info("out", + TensorProto.FLOAT, list(np_res.shape))]) model = helper.make_model(graph, producer_name='power_test') @@ -192,11 +200,13 @@ def _test_power_iteration(x_shape, y_shape): tvm_out = get_tvm_output(model, [x, y], target, ctx, np_res.shape) tvm.testing.assert_allclose(np_res, tvm_out, rtol=1e-5, atol=1e-5) + def test_power(): _test_power_iteration((1, 3), (1)) _test_power_iteration((2, 3), (2, 3)) _test_power_iteration((2, 3), (1, 3)) + def test_squeeze(): in_shape = (1, 3, 1, 3, 1, 1) out_shape = (3, 3) @@ -204,10 +214,10 @@ def test_squeeze(): graph = helper.make_graph([y], 'squeeze_test', - inputs = [helper.make_tensor_value_info("in", - TensorProto.FLOAT, list(in_shape))], - outputs = [helper.make_tensor_value_info("out", - TensorProto.FLOAT, list(out_shape))]) + inputs=[helper.make_tensor_value_info("in", + TensorProto.FLOAT, list(in_shape))], + outputs=[helper.make_tensor_value_info("out", + TensorProto.FLOAT, list(out_shape))]) model = helper.make_model(graph, producer_name='squeeze_test') @@ -217,20 +227,21 @@ def test_squeeze(): tvm.testing.assert_allclose(out_shape, tvm_out.shape) + def test_flatten(): in_shape = (1, 3, 4, 4) axis = 1 ref_shape = (1, 48) - flatten_node = helper.make_node("Flatten", ["in"], ["out"], axis = axis) + flatten_node = helper.make_node("Flatten", ["in"], ["out"], axis=axis) graph = helper.make_graph([flatten_node], "flatten_test", - inputs = [helper.make_tensor_value_info("in", - 
TensorProto.FLOAT, list(in_shape))], - outputs = [helper.make_tensor_value_info("out", - TensorProto.FLOAT, list(ref_shape))]) + inputs=[helper.make_tensor_value_info("in", + TensorProto.FLOAT, list(in_shape))], + outputs=[helper.make_tensor_value_info("out", + TensorProto.FLOAT, list(ref_shape))]) model = helper.make_model(graph, producer_name='flatten_test') @@ -240,6 +251,7 @@ def test_flatten(): tvm.testing.assert_allclose(ref_shape, tvm_out.shape) + def test_unsqueeze(): in_shape = (3, 3) axis = (0, 3, 4) @@ -248,10 +260,10 @@ def test_unsqueeze(): graph = helper.make_graph([y], 'squeeze_test', - inputs = [helper.make_tensor_value_info("in", - TensorProto.FLOAT, list(in_shape))], - outputs = [helper.make_tensor_value_info("out", - TensorProto.FLOAT, list(out_shape))]) + inputs=[helper.make_tensor_value_info("in", + TensorProto.FLOAT, list(in_shape))], + outputs=[helper.make_tensor_value_info("out", + TensorProto.FLOAT, list(out_shape))]) model = helper.make_model(graph, producer_name='squeeze_test') @@ -261,6 +273,7 @@ def test_unsqueeze(): tvm.testing.assert_allclose(out_shape, tvm_out.shape) + def verify_gather(in_shape, indices, axis, dtype): x = np.random.uniform(size=in_shape).astype(dtype) indices = np.array(indices, dtype="int32") @@ -270,52 +283,123 @@ def verify_gather(in_shape, indices, axis, dtype): graph = helper.make_graph([y], 'gather_test', - inputs = [helper.make_tensor_value_info("in", - TensorProto.FLOAT, list(in_shape)), - helper.make_tensor_value_info("indices", - TensorProto.INT32, list(indices.shape))], - outputs = [helper.make_tensor_value_info("out", - TensorProto.FLOAT, list(out_np.shape))]) + inputs=[helper.make_tensor_value_info("in", + TensorProto.FLOAT, list(in_shape)), + helper.make_tensor_value_info("indices", + TensorProto.INT32, list(indices.shape))], + outputs=[helper.make_tensor_value_info("out", + TensorProto.FLOAT, list(out_np.shape))]) model = helper.make_model(graph, producer_name='gather_test') for target, ctx in ctx_list(): - tvm_out = get_tvm_output(model, [x, indices], target, ctx, out_np.shape) + tvm_out = get_tvm_output( + model, [x, indices], target, ctx, out_np.shape) tvm.testing.assert_allclose(out_np, tvm_out) + def test_gather(): verify_gather((4,), [1], 0, 'int32') - verify_gather((1,4), [0], 0, 'int32') - verify_gather((4,), [[[1,0],[0,1]]], 0, 'float32') - verify_gather((2,2), [[[1,0],[0,1]]], 1, 'int32') - verify_gather((3,3,3), [[[1,0]]], -1, 'int32') - verify_gather((4,3,5,6), [[2,1,0,0]], 0, 'float32') + verify_gather((1, 4), [0], 0, 'int32') + verify_gather((4,), [[[1, 0], [0, 1]]], 0, 'float32') + verify_gather((2, 2), [[[1, 0], [0, 1]]], 1, 'int32') + verify_gather((3, 3, 3), [[[1, 0]]], -1, 'int32') + verify_gather((4, 3, 5, 6), [[2, 1, 0, 0]], 0, 'float32') + -def _test_slice_iteration(indata, outdata, starts, ends, axes=None): +def _test_slice_iteration_v1(indata, outdata, starts, ends, axes=None): if axes: - y = helper.make_node("Slice", ['in'], ['out'], axes=axes, starts=starts, ends=ends) + y = helper.make_node( + "Slice", ['in'], ['out'], axes=axes, starts=starts, ends=ends) else: - y = helper.make_node("Slice", ['in'], ['out'], starts=starts, ends=ends) + y = helper.make_node( + "Slice", ['in'], ['out'], starts=starts, ends=ends) graph = helper.make_graph([y], 'slice_test', - inputs = [helper.make_tensor_value_info("in", - TensorProto.FLOAT, list(indata.shape))], - outputs = [helper.make_tensor_value_info("out", - TensorProto.FLOAT, list(outdata.shape))]) + inputs=[helper.make_tensor_value_info("in", + 
TensorProto.FLOAT, list(indata.shape))], + outputs=[helper.make_tensor_value_info("out", + TensorProto.FLOAT, list(outdata.shape))]) model = helper.make_model(graph, producer_name='slice_test') for target, ctx in ctx_list(): - tvm_out = get_tvm_output(model, indata, target, ctx, outdata.shape, 'float32') + tvm_out = get_tvm_output( + model, indata, target, ctx, outdata.shape, 'float32', opset=1) tvm.testing.assert_allclose(outdata, tvm_out) + +def _test_slice_iteration_v10(indata, outdata, starts, ends, axes=None): + if isinstance(starts, int): + starts = (starts, ) + if isinstance(ends, int): + ends = (ends, ) + if isinstance(axes, int): + axes = (axes, ) + starts = np.asarray(starts) + ends = np.asarray(ends) + inputs = [ + helper.make_tensor_value_info("data", TensorProto.FLOAT, + list(indata.shape)), + helper.make_tensor_value_info("starts", TensorProto.INT32, + list(starts.shape)), + helper.make_tensor_value_info("ends", TensorProto.INT32, + list(ends.shape)) + ] + initializer = [ + helper.make_tensor("starts", TensorProto.INT32, list(starts.shape), + starts), + helper.make_tensor("ends", TensorProto.INT32, list(ends.shape), ends) + ] + + if axes: + axes = np.asarray(axes) + y = helper.make_node("Slice", ["data", "starts", "ends", "axes"], + ["out"]) + inputs.append( + helper.make_tensor_value_info("axes", TensorProto.INT32, + list(axes.shape))) + initializer.append( + helper.make_tensor("axes", TensorProto.INT32, list(axes.shape), + axes)) + else: + y = helper.make_node("Slice", ["data", "starts", "ends"], ["out"]) + + graph = helper.make_graph([y], + 'slice_test', + inputs=inputs, + outputs=[ + helper.make_tensor_value_info( + "out", TensorProto.FLOAT, + list(outdata.shape)) + ], + initializer=initializer) + model = helper.make_model(graph, producer_name='slice_test') + + for target, ctx in ctx_list(): + tvm_out = get_tvm_output(model, + indata, + target, + ctx, + outdata.shape, + 'float32', + opset=10) + + tvm.testing.assert_allclose(outdata, tvm_out) + + def test_slice(): x = np.random.randn(20, 10, 5).astype(np.float32) - _test_slice_iteration(x, x[0:3, 0:10], (0, 0), (3, 10), (0, 1)) - _test_slice_iteration(x, x[:, :, 3:4], (0, 0, 3), (20, 10, 4)) - _test_slice_iteration(x, x[:, 1:1000], (1), (1000), (1)) - _test_slice_iteration(x, x[:, 0:-1], (0), (-1), (1)) + _test_slice_iteration_v1(x, x[0:3, 0:10], (0, 0), (3, 10), (0, 1)) + _test_slice_iteration_v1(x, x[:, :, 3:4], (0, 0, 3), (20, 10, 4)) + _test_slice_iteration_v1(x, x[:, 1:1000], (1), (1000), (1)) + _test_slice_iteration_v1(x, x[:, 0:-1], (0), (-1), (1)) + _test_slice_iteration_v10(x, x[0:3, 0:10], (0, 0), (3, 10), (0, 1)) + _test_slice_iteration_v10(x, x[:, :, 3:4], (0, 0, 3), (20, 10, 4)) + _test_slice_iteration_v10(x, x[:, 1:1000], (1), (1000), (1)) + _test_slice_iteration_v10(x, x[:, 0:-1], (0), (-1), (1)) + def _test_onnx_op_elementwise(inshape, outfunc, npargs, dtype, opname, kwargs): indata = np.random.uniform(-1, 1, size=inshape).astype(dtype) @@ -325,24 +409,29 @@ def _test_onnx_op_elementwise(inshape, outfunc, npargs, dtype, opname, kwargs): graph = helper.make_graph([y], opname+'_test', - inputs = [helper.make_tensor_value_info("in", - TensorProto.FLOAT, list(indata.shape))], - outputs = [helper.make_tensor_value_info("out", - TensorProto.FLOAT, list(outdata.shape))]) + inputs=[helper.make_tensor_value_info("in", + TensorProto.FLOAT, list(indata.shape))], + outputs=[helper.make_tensor_value_info("out", + TensorProto.FLOAT, list(outdata.shape))]) model = helper.make_model(graph, producer_name=opname+'_test') for 
target, ctx in ctx_list(): - tvm_out = get_tvm_output(model, indata, target, ctx, outdata.shape, dtype) + tvm_out = get_tvm_output( + model, indata, target, ctx, outdata.shape, dtype) tvm.testing.assert_allclose(outdata, tvm_out) + def test_floor(): - _test_onnx_op_elementwise((2, 4, 5, 6), np.floor, {}, 'float32', 'Floor', {}) + _test_onnx_op_elementwise((2, 4, 5, 6), np.floor, + {}, 'float32', 'Floor', {}) + def test_ceil(): _test_onnx_op_elementwise((2, 4, 5, 6), np.ceil, {}, 'float32', 'Ceil', {}) + def test_clip(): _test_onnx_op_elementwise((2, 4, 5, 6), np.clip, @@ -351,6 +440,38 @@ def test_clip(): 'Clip', {'min': -1.0, 'max': 1.0}) + +def test_onehot(): + indices_shape = [10] + indices_array = np.random.randint( + low=0, high=9, size=indices_shape, dtype='int32') + depth = 10 + values = np.asarray([0, 1]) + out_np = np.eye(depth)[indices_array.reshape(-1)] + + onehot_node = helper.make_node( + "OneHot", ["indices", "depth", "values"], ["out"]) + + graph = helper.make_graph([onehot_node], + "onehot_test", + inputs=[helper.make_tensor_value_info("indices", + TensorProto.INT32, indices_shape), + helper.make_tensor_value_info("depth", + TensorProto.INT32, [1]), + helper.make_tensor_value_info("values", + TensorProto.INT32, values.shape)], + initializer=[helper.make_tensor("depth", TensorProto.INT32, [1], [depth]), + helper.make_tensor("values", TensorProto.INT32, values.shape, values)], + outputs=[helper.make_tensor_value_info("out", TensorProto.INT32, out_np.shape)]) + + model = helper.make_model(graph, producer_name="onehot_test") + + for target, ctx in ctx_list(): + tvm_out = get_tvm_output( + model, [indices_array], target, ctx, out_np.shape) + tvm.testing.assert_allclose(out_np, tvm_out, rtol=1e-5, atol=1e-5) + + def test_matmul(): a_shape = (4, 3) b_shape = (3, 4) @@ -363,52 +484,84 @@ def test_matmul(): graph = helper.make_graph([mul_node], "matmul_test", - inputs = [helper.make_tensor_value_info("a", - TensorProto.FLOAT, list(a_shape)), - helper.make_tensor_value_info("b", - TensorProto.FLOAT, list(b_shape))], - outputs = [helper.make_tensor_value_info("out", - TensorProto.FLOAT, list(out_np.shape))]) + inputs=[helper.make_tensor_value_info("a", + TensorProto.FLOAT, list(a_shape)), + helper.make_tensor_value_info("b", + TensorProto.FLOAT, list(b_shape))], + outputs=[helper.make_tensor_value_info("out", + TensorProto.FLOAT, list(out_np.shape))]) + + model = helper.make_model(graph, producer_name='matmul_test') + + for target, ctx in ctx_list(): + tvm_out = get_tvm_output( + model, [a_array, b_array], target, ctx, out_np.shape) + tvm.testing.assert_allclose(out_np, tvm_out, rtol=1e-5, atol=1e-5) + + +def test_batch_matmul(): + a_shape = (2, 3, 4, 3) + b_shape = (2, 3, 3, 4) + + a_array = np.random.uniform(size=a_shape).astype('float32') + b_array = np.random.uniform(size=b_shape).astype('float32') + out_np = np.matmul(a_array, b_array) + + mul_node = helper.make_node("MatMul", ["a", "b"], ["out"]) + + graph = helper.make_graph([mul_node], + "matmul_test", + inputs=[helper.make_tensor_value_info("a", + TensorProto.FLOAT, list(a_shape)), + helper.make_tensor_value_info("b", + TensorProto.FLOAT, list(b_shape))], + outputs=[helper.make_tensor_value_info("out", + TensorProto.FLOAT, list(out_np.shape))]) model = helper.make_model(graph, producer_name='matmul_test') for target, ctx in ctx_list(): - tvm_out = get_tvm_output(model, [a_array, b_array], target, ctx, out_np.shape) + tvm_out = get_tvm_output( + model, [a_array, b_array], target, ctx, out_np.shape) 
tvm.testing.assert_allclose(out_np, tvm_out, rtol=1e-5, atol=1e-5) + def verify_lrn(shape, nsize, dtype, alpha=None, beta=None, bias=None): in_array = np.random.uniform(size=shape).astype(dtype) - if alpha == None and beta == None and bias==None: + if alpha == None and beta == None and bias == None: alpha = 0.0001 beta = 0.75 bias = 1.0 - node = onnx.helper.make_node('LRN', inputs=['in'], outputs=['out'], size=nsize) + node = onnx.helper.make_node( + 'LRN', inputs=['in'], outputs=['out'], size=nsize) else: node = onnx.helper.make_node('LRN', inputs=['in'], outputs=['out'], alpha=alpha, beta=beta, bias=bias, size=nsize) graph = helper.make_graph([node], "lrn_test", - inputs = [helper.make_tensor_value_info("in", TensorProto.FLOAT, list(shape))], - outputs = [helper.make_tensor_value_info("out", TensorProto.FLOAT, list(shape))]) + inputs=[helper.make_tensor_value_info( + "in", TensorProto.FLOAT, list(shape))], + outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(shape))]) model = helper.make_model(graph, producer_name='lrn_test') def _get_python_lrn(): square_sum = np.zeros(shape).astype(dtype) for n, c, h, w in np.ndindex(in_array.shape): square_sum[n, c, h, w] = sum(in_array[n, - max(0, c - int(math.floor((nsize - 1) / 2))): \ - min(5, c + int(math.ceil((nsize - 1) / 2)) + 1), - h, - w] ** 2) + max(0, c - int(math.floor((nsize - 1) / 2))): + min(5, c + int(math.ceil((nsize - 1) / 2)) + 1), + h, + w] ** 2) py_out = in_array / ((bias + (alpha / nsize) * square_sum) ** beta) return py_out for target, ctx in ctx_list(): input_name = model.graph.input[0].name py_out = _get_python_lrn() - tvm_out = get_tvm_output(model, in_array, target, ctx, py_out.shape, 'float32') + tvm_out = get_tvm_output( + model, in_array, target, ctx, py_out.shape, 'float32') tvm.testing.assert_allclose(py_out, tvm_out, rtol=1e-5, atol=1e-5) @@ -436,20 +589,22 @@ def _get_python_instance_norm(x, gamma, beta, epsilon=1e-5): y = _get_python_instance_norm(x, gamma, beta, epsilon).astype(np.float32) node = onnx.helper.make_node( - 'InstanceNormalization', - inputs=['x', 'gamma', 'beta'], - outputs=['y'], - epsilon=epsilon, - ) + 'InstanceNormalization', + inputs=['x', 'gamma', 'beta'], + outputs=['y'], + epsilon=epsilon, + ) graph = helper.make_graph([node], "instance_norm_test", inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, list(shape)), - helper.make_tensor_value_info("gamma", TensorProto.FLOAT, (shape[1],)), + helper.make_tensor_value_info( + "gamma", TensorProto.FLOAT, (shape[1],)), helper.make_tensor_value_info("beta", TensorProto.FLOAT, (shape[1],))], outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, list(shape))]) model = helper.make_model(graph, producer_name='instance_norm_test') for target, ctx in ctx_list(): - tvm_out = get_tvm_output(model, [x, gamma, beta], target, ctx, shape, 'float32') + tvm_out = get_tvm_output( + model, [x, gamma, beta], target, ctx, shape, 'float32') tvm.testing.assert_allclose(y, tvm_out, rtol=1e-5, atol=1e-5) @@ -464,103 +619,122 @@ def _test_upsample_nearest(): scale = 2 in_shape = (1, 1, 3, 3) out_shape = (1, 1, 3*scale, 3*scale) - y = helper.make_node("Upsample", ['in'], ['out'], mode='nearest', scales=[1.0, 1.0, 2.0, 2.0]) + y = helper.make_node("Upsample", ['in'], [ + 'out'], mode='nearest', scales=[1.0, 1.0, 2.0, 2.0]) in_array = np.random.uniform(size=in_shape).astype(np.float32) - out_array = topi.testing.upsampling_python(in_array, (scale, scale), "NCHW") + out_array = topi.testing.upsampling_python( + in_array, (scale, scale), 
"NCHW") graph = helper.make_graph([y], 'upsample_nearest_test', - inputs = [helper.make_tensor_value_info("in", TensorProto.FLOAT, list(in_shape))], - outputs = [helper.make_tensor_value_info("out", TensorProto.FLOAT, list(out_shape))]) + inputs=[helper.make_tensor_value_info( + "in", TensorProto.FLOAT, list(in_shape))], + outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(out_shape))]) model = helper.make_model(graph, producer_name='upsample_nearest_test') for target, ctx in ctx_list(): - tvm_out = get_tvm_output(model, in_array, target, ctx, out_shape, 'float32') + tvm_out = get_tvm_output( + model, in_array, target, ctx, out_shape, 'float32') tvm.testing.assert_allclose(out_array, tvm_out) + def _test_upsample_bilinear(): scale = 2 in_shape = (1, 1, 3, 3) out_shape = (1, 1, 3*scale, 3*scale) - y = helper.make_node("Upsample", ['in'], ['out'], mode='linear', scales=[1.0, 1.0, 2.0, 2.0]) + y = helper.make_node("Upsample", ['in'], [ + 'out'], mode='linear', scales=[1.0, 1.0, 2.0, 2.0]) in_array = np.random.uniform(size=in_shape).astype(np.float32) - out_array = topi.testing.bilinear_resize_python(in_array, (3*scale, 3*scale), "NCHW") + out_array = topi.testing.bilinear_resize_python( + in_array, (3*scale, 3*scale), "NCHW") graph = helper.make_graph([y], 'upsample_bilinear_test', - inputs = [helper.make_tensor_value_info("in", TensorProto.FLOAT, list(in_shape))], - outputs = [helper.make_tensor_value_info("out", TensorProto.FLOAT, list(out_shape))]) + inputs=[helper.make_tensor_value_info( + "in", TensorProto.FLOAT, list(in_shape))], + outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(out_shape))]) model = helper.make_model(graph, producer_name='upsample_bilinear_test') for target, ctx in ctx_list(): - tvm_out = get_tvm_output(model, in_array, target, ctx, out_shape, 'float32') + tvm_out = get_tvm_output( + model, in_array, target, ctx, out_shape, 'float32') tvm.testing.assert_allclose(out_array, tvm_out, rtol=1e-5, atol=1e-5) + def _test_upsample_bilinear_opset9(): scale = 2 in_shape = (1, 1, 3, 3) out_shape = (1, 1, 3*scale, 3*scale) - y = helper.make_node("Upsample", ['in','scales'], ['out'], mode='linear') - scales=[1.0, 1.0, 2.0, 2.0] + y = helper.make_node("Upsample", ['in', 'scales'], ['out'], mode='linear') + scales = [1.0, 1.0, 2.0, 2.0] in_array = np.random.uniform(size=in_shape).astype(np.float32) - out_array = topi.testing.bilinear_resize_python(in_array, (3*scale, 3*scale), "NCHW") + out_array = topi.testing.bilinear_resize_python( + in_array, (3*scale, 3*scale), "NCHW") ref_array = np.array(scales) ref_node = helper.make_node('Constant', - inputs=[], - outputs=['scales'], - value=onnx.helper.make_tensor(name = 'const_tensor', - data_type = TensorProto.FLOAT, - dims = ref_array.shape, - vals = ref_array.flatten().astype(float))) + inputs=[], + outputs=['scales'], + value=onnx.helper.make_tensor(name='const_tensor', + data_type=TensorProto.FLOAT, + dims=ref_array.shape, + vals=ref_array.flatten().astype(float))) graph = helper.make_graph([ref_node, y], 'upsample_bilinear_opset9_test', - inputs = [helper.make_tensor_value_info("in", TensorProto.FLOAT, list(in_shape))], - outputs = [helper.make_tensor_value_info("out", TensorProto.FLOAT, list(out_shape))]) + inputs=[helper.make_tensor_value_info( + "in", TensorProto.FLOAT, list(in_shape))], + outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(out_shape))]) - model = helper.make_model(graph, producer_name='upsample_bilinear_opset9_test') + model = helper.make_model( + graph, 
producer_name='upsample_bilinear_opset9_test') for target, ctx in ctx_list(): - tvm_out = get_tvm_output(model, in_array, target, ctx, out_shape, 'float32') + tvm_out = get_tvm_output( + model, in_array, target, ctx, out_shape, 'float32') tvm.testing.assert_allclose(out_array, tvm_out, rtol=1e-5, atol=1e-5) + def test_upsample(): _test_upsample_nearest() _test_upsample_bilinear() _test_upsample_bilinear_opset9() + def _test_softmax(inshape, axis): opname = 'Softmax' indata = np.random.uniform(size=inshape).astype(np.float32) outshape = inshape outdata = topi.testing.softmax_python(indata) if isinstance(axis, int): - y = helper.make_node(opname, ['in'], ['out'], axis = axis) + y = helper.make_node(opname, ['in'], ['out'], axis=axis) elif axis is None: y = helper.make_node(opname, ['in'], ['out']) graph = helper.make_graph([y], opname+'_test', - inputs = [helper.make_tensor_value_info("in", - TensorProto.FLOAT, list(indata.shape))], - outputs = [helper.make_tensor_value_info("out", - TensorProto.FLOAT, list(outdata.shape))]) + inputs=[helper.make_tensor_value_info("in", + TensorProto.FLOAT, list(indata.shape))], + outputs=[helper.make_tensor_value_info("out", + TensorProto.FLOAT, list(outdata.shape))]) model = helper.make_model(graph, producer_name=opname+'_test') for target, ctx in ctx_list(): - tvm_out = get_tvm_output(model, indata, target, ctx, outshape, 'float32') + tvm_out = get_tvm_output( + model, indata, target, ctx, outshape, 'float32') tvm.testing.assert_allclose(outdata, tvm_out, rtol=1e-5, atol=1e-5) + def test_softmax(): _test_softmax((1, 10), None) _test_softmax((1, 10), 1) + def verify_min(input_dim): dtype = 'float32' @@ -574,25 +748,28 @@ def verify_min(input_dim): graph = helper.make_graph([min_node], "Min_test", - inputs = [helper.make_tensor_value_info("a_np1", - TensorProto.FLOAT, list(input_dim)), - helper.make_tensor_value_info("a_np2", - TensorProto.FLOAT, list(input_dim)), - helper.make_tensor_value_info("a_np3", - TensorProto.FLOAT, list(input_dim))], - outputs = [helper.make_tensor_value_info("out", - TensorProto.FLOAT, list(b_np.shape))]) + inputs=[helper.make_tensor_value_info("a_np1", + TensorProto.FLOAT, list(input_dim)), + helper.make_tensor_value_info("a_np2", + TensorProto.FLOAT, list(input_dim)), + helper.make_tensor_value_info("a_np3", + TensorProto.FLOAT, list(input_dim))], + outputs=[helper.make_tensor_value_info("out", + TensorProto.FLOAT, list(b_np.shape))]) model = helper.make_model(graph, producer_name='Min_test') for target, ctx in ctx_list(): - tvm_out = get_tvm_output(model, [a_np1, a_np2, a_np3], target, ctx, b_np.shape) + tvm_out = get_tvm_output( + model, [a_np1, a_np2, a_np3], target, ctx, b_np.shape) tvm.testing.assert_allclose(b_np, tvm_out, rtol=1e-5, atol=1e-5) + def test_forward_min(): verify_min((1, 3, 20, 20)) verify_min((20, 20)) + def verify_max(input_dim): dtype = 'float32' @@ -606,25 +783,28 @@ def verify_max(input_dim): graph = helper.make_graph([max_node], "Max_test", - inputs = [helper.make_tensor_value_info("a_np1", - TensorProto.FLOAT, list(input_dim)), - helper.make_tensor_value_info("a_np2", - TensorProto.FLOAT, list(input_dim)), - helper.make_tensor_value_info("a_np3", - TensorProto.FLOAT, list(input_dim))], - outputs = [helper.make_tensor_value_info("out", - TensorProto.FLOAT, list(b_np.shape))]) + inputs=[helper.make_tensor_value_info("a_np1", + TensorProto.FLOAT, list(input_dim)), + helper.make_tensor_value_info("a_np2", + TensorProto.FLOAT, list(input_dim)), + helper.make_tensor_value_info("a_np3", + TensorProto.FLOAT, 
list(input_dim))], + outputs=[helper.make_tensor_value_info("out", + TensorProto.FLOAT, list(b_np.shape))]) model = helper.make_model(graph, producer_name='Max_test') for target, ctx in ctx_list(): - tvm_out = get_tvm_output(model, [a_np1, a_np2, a_np3], target, ctx, b_np.shape) + tvm_out = get_tvm_output( + model, [a_np1, a_np2, a_np3], target, ctx, b_np.shape) tvm.testing.assert_allclose(b_np, tvm_out, rtol=1e-5, atol=1e-5) + def test_forward_max(): verify_max((1, 3, 20, 20)) verify_max((20, 20)) + def verify_mean(input_dim): dtype = 'float32' @@ -638,25 +818,28 @@ def verify_mean(input_dim): graph = helper.make_graph([mean_node], "Mean_test", - inputs = [helper.make_tensor_value_info("a_np1", - TensorProto.FLOAT, list(input_dim)), - helper.make_tensor_value_info("a_np2", - TensorProto.FLOAT, list(input_dim)), - helper.make_tensor_value_info("a_np3", - TensorProto.FLOAT, list(input_dim))], - outputs = [helper.make_tensor_value_info("out", - TensorProto.FLOAT, list(b_np.shape))]) + inputs=[helper.make_tensor_value_info("a_np1", + TensorProto.FLOAT, list(input_dim)), + helper.make_tensor_value_info("a_np2", + TensorProto.FLOAT, list(input_dim)), + helper.make_tensor_value_info("a_np3", + TensorProto.FLOAT, list(input_dim))], + outputs=[helper.make_tensor_value_info("out", + TensorProto.FLOAT, list(b_np.shape))]) model = helper.make_model(graph, producer_name='Mean_test') for target, ctx in ctx_list(): - tvm_out = get_tvm_output(model, [a_np1, a_np2, a_np3], target, ctx, b_np.shape) + tvm_out = get_tvm_output( + model, [a_np1, a_np2, a_np3], target, ctx, b_np.shape) tvm.testing.assert_allclose(b_np, tvm_out, rtol=1e-5, atol=1e-5) + def test_forward_mean(): verify_mean((1, 3, 20, 20)) verify_mean((20, 20)) + def verify_hardsigmoid(input_dim, alpha, beta): dtype = 'float32' @@ -664,14 +847,15 @@ def verify_hardsigmoid(input_dim, alpha, beta): b_np = np.clip(a_np1 * alpha + beta, 0, 1) - hardsigmoid_node = helper.make_node("HardSigmoid", ["a_np1"], ["out"], alpha=alpha, beta=beta) + hardsigmoid_node = helper.make_node("HardSigmoid", ["a_np1"], [ + "out"], alpha=alpha, beta=beta) graph = helper.make_graph([hardsigmoid_node], "HardSigmoid_test", - inputs = [helper.make_tensor_value_info("a_np1", - TensorProto.FLOAT, list(input_dim))], - outputs = [helper.make_tensor_value_info("out", - TensorProto.FLOAT, list(b_np.shape))]) + inputs=[helper.make_tensor_value_info("a_np1", + TensorProto.FLOAT, list(input_dim))], + outputs=[helper.make_tensor_value_info("out", + TensorProto.FLOAT, list(b_np.shape))]) model = helper.make_model(graph, producer_name='HardSigmoid_test') @@ -679,10 +863,12 @@ def verify_hardsigmoid(input_dim, alpha, beta): tvm_out = get_tvm_output(model, [a_np1], target, ctx, b_np.shape) tvm.testing.assert_allclose(b_np, tvm_out, rtol=1e-5, atol=1e-5) + def test_forward_hardsigmoid(): verify_hardsigmoid((1, 3, 20, 20), 0.5, 0.6) verify_hardsigmoid((20, 20), 0.3, 0.4) + def verify_argmin(input_dim, axis=None, keepdims=None): def _argmin_numpy(data, axis=0, keepdims=True): result = np.argmin(data, axis=axis) @@ -717,17 +903,19 @@ def _argmin_numpy(data, axis=0, keepdims=True): keepdims=keepdims) graph = helper.make_graph([node], "argmin_test", - inputs = [helper.make_tensor_value_info("a_np1", - TensorProto.INT32, list(a_np1.shape))], - outputs = [helper.make_tensor_value_info("out", - TensorProto.INT32, list(b_np.shape))]) + inputs=[helper.make_tensor_value_info("a_np1", + TensorProto.INT32, list(a_np1.shape))], + outputs=[helper.make_tensor_value_info("out", + TensorProto.INT32, 
list(b_np.shape))]) model = helper.make_model(graph, producer_name='argmin_test') for target, ctx in ctx_list(): - tvm_out = get_tvm_output(model, [a_np1], target, ctx, b_np.shape, b_np.dtype) + tvm_out = get_tvm_output( + model, [a_np1], target, ctx, b_np.shape, b_np.dtype) tvm.testing.assert_allclose(b_np, tvm_out, rtol=1e-5, atol=1e-5) + def verify_argmax(input_dim, axis=None, keepdims=None): def _argmax_numpy(data, axis=0, keepdims=True): result = np.argmax(data, axis=axis) @@ -763,66 +951,72 @@ def _argmax_numpy(data, axis=0, keepdims=True): graph = helper.make_graph([node], "argmax_test", - inputs = [helper.make_tensor_value_info("a_np1", - TensorProto.INT32, list(a_np1.shape))], - outputs = [helper.make_tensor_value_info("out", - TensorProto.INT32, list(b_np.shape))]) + inputs=[helper.make_tensor_value_info("a_np1", + TensorProto.INT32, list(a_np1.shape))], + outputs=[helper.make_tensor_value_info("out", + TensorProto.INT32, list(b_np.shape))]) model = helper.make_model(graph, producer_name='argmax_test') for target, ctx in ctx_list(): - tvm_out = get_tvm_output(model, [a_np1], target, ctx, b_np.shape, b_np.dtype) + tvm_out = get_tvm_output( + model, [a_np1], target, ctx, b_np.shape, b_np.dtype) tvm.testing.assert_allclose(b_np, tvm_out, rtol=1e-5, atol=1e-5) + def test_forward_arg_min_max(): '''Verify argmin and argmax''' - verify_argmin([3,4,4]) - verify_argmax([3,4,4]) - verify_argmin([3,4,4], axis=1) - verify_argmax([3,4,4], axis=0) - verify_argmin([3,4,4], keepdims=0) - verify_argmax([3,4,4], keepdims=1) - for axis in [None, 0,1,2]: - for keepdims in [None, True,False]: - verify_argmin([3,4,4], axis, keepdims) - verify_argmax([3,4,4], axis, keepdims) - -def verify_constantfill(is_shape, input_dim, out_dim, value, dtype, **kwargs): - input_a = np.random.uniform(size=input_dim).astype(dtype) - out = np.empty(shape=out_dim, dtype=dtype) + verify_argmin([3, 4, 4]) + verify_argmax([3, 4, 4]) + verify_argmin([3, 4, 4], axis=1) + verify_argmax([3, 4, 4], axis=0) + verify_argmin([3, 4, 4], keepdims=0) + verify_argmax([3, 4, 4], keepdims=1) + for axis in [None, 0, 1, 2]: + for keepdims in [None, True, False]: + verify_argmin([3, 4, 4], axis, keepdims) + verify_argmax([3, 4, 4], axis, keepdims) + + +def verify_constantofshape(input_dim, value, dtype): + out = np.empty(shape=input_dim, dtype=dtype) out.fill(value) - if is_shape == True: - fill_node = helper.make_node("ConstantFill", [], ["out"], shape=input_dim, value=value, **kwargs) - else: - fill_node = helper.make_node("ConstantFill", ["input_a"], ["out"], value=value, dtype=dtype, **kwargs) - - if is_shape == True: - inputs = [] - else: - inputs = [helper.make_tensor_value_info("input_a", - TensorProto.FLOAT, list(input_dim))] - - graph = helper.make_graph([fill_node], - "fill_test", - inputs, - outputs = [helper.make_tensor_value_info("out", - TensorProto.FLOAT, list(out.shape))]) + fill_node = helper.make_node("ConstantOfShape", ["input"], ["output"], + value=helper.make_tensor( + 'value', + mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(dtype)], + (1, ), (value, ))) + + inputs = [ + helper.make_tensor_value_info("input", TensorProto.FLOAT, input_dim) + ] + + graph = helper.make_graph( + [fill_node], + "fill_test", + inputs, + outputs=[ + helper.make_tensor_value_info("output", TensorProto.FLOAT, + list(out.shape)) + ], + initializer=[ + helper.make_tensor("input", TensorProto.INT32, (len(input_dim), ), + input_dim) + ]) model = helper.make_model(graph, producer_name='fill_test') for target, ctx in ctx_list(): - if is_shape == True: - 
tvm_out = get_tvm_output(model, [], target, ctx, out.shape) - else: - tvm_out = get_tvm_output(model, [input_a], target, ctx, out.shape) + tvm_out = get_tvm_output(model, [], target, ctx, out.shape) tvm.testing.assert_allclose(out, tvm_out, rtol=1e-5, atol=1e-5) -def test_constantfill(): - verify_constantfill(True, (2, 3, 4, 5), (2, 3, 4, 5), 10, 'float32') - verify_constantfill(False, (2, 3, 4, 5), (2, 3, 4, 5), 10, 'float32') - verify_constantfill(True, (2, 3, 4, 5), (2, 3, 4, 5, 4, 5, 6), 10, 'float32', extra_shape=(4, 5, 6)) + +def test_constantofshape(): + verify_constantofshape((2, 3, 4, 5), 10, 'float32') + verify_constantofshape((3, 3), 0, 'int32') + verify_constantofshape((1, 2, 3), -1, 'float32') def verify_pad(indata, pads, mode='constant', value=0.0): @@ -841,7 +1035,8 @@ def verify_pad(indata, pads, mode='constant', value=0.0): pads=pads, ) else: - outdata = np.pad(indata, pad_width=np_pads, mode='constant', constant_values=value) + outdata = np.pad(indata, pad_width=np_pads, + mode='constant', constant_values=value) node = helper.make_node( 'Pad', inputs=['input'], @@ -852,22 +1047,30 @@ def verify_pad(indata, pads, mode='constant', value=0.0): ) graph = helper.make_graph([node], 'pad_test', - inputs = [helper.make_tensor_value_info("input", - TensorProto.FLOAT, list(indata.shape))], - outputs = [helper.make_tensor_value_info("output", - TensorProto.FLOAT, list(outdata.shape))]) + inputs=[helper.make_tensor_value_info("input", + TensorProto.FLOAT, list(indata.shape))], + outputs=[helper.make_tensor_value_info("output", + TensorProto.FLOAT, list(outdata.shape))]) model = helper.make_model(graph, producer_name='pad_test') # tvm result for target, ctx in ctx_list(): - tvm_out = get_tvm_output(model, indata, target, ctx, outdata.shape, 'float32') + tvm_out = get_tvm_output( + model, indata, target, ctx, outdata.shape, 'float32') tvm.testing.assert_allclose(outdata, tvm_out, rtol=1e-5, atol=1e-5) + def test_pad(): - verify_pad(np.random.randn(2, 2).astype(np.float32), [0, 1, 0, 0], 'constant', 0.0) - verify_pad(np.random.randn(2, 3).astype(np.float32), [1, 0, 0, 1], 'constant', 0.0) - verify_pad(np.random.randn(3, 2).astype(np.float32), [0, 0, 1, 0], 'constant', 5.0) - verify_pad(np.random.randn(1, 3, 4, 5).astype(np.float32), [0, 0, 1, 1, 0, 0, 1, 1], 'edge') - verify_pad(np.random.randn(1, 3, 4, 5).astype(np.float32), [0, 0, 1, 1, 0, 0, 1, 1], 'reflect') + verify_pad(np.random.randn(2, 2).astype( + np.float32), [0, 1, 0, 0], 'constant', 0.0) + verify_pad(np.random.randn(2, 3).astype( + np.float32), [1, 0, 0, 1], 'constant', 0.0) + verify_pad(np.random.randn(3, 2).astype( + np.float32), [0, 0, 1, 0], 'constant', 5.0) + verify_pad(np.random.randn(1, 3, 4, 5).astype( + np.float32), [0, 0, 1, 1, 0, 0, 1, 1], 'edge') + verify_pad(np.random.randn(1, 3, 4, 5).astype( + np.float32), [0, 0, 1, 1, 0, 0, 1, 1], 'reflect') + def verify_reduce_x(name, indata, axis, keepdims): indata = np.array(indata).astype(np.float32) @@ -893,16 +1096,18 @@ def verify_reduce_x(name, indata, axis, keepdims): axes=axis, keepdims=keepdims) graph = helper.make_graph([node], '{}_test'.format(name), - inputs = [helper.make_tensor_value_info("input", - TensorProto.FLOAT, list(indata.shape))], - outputs = [helper.make_tensor_value_info("output", - TensorProto.FLOAT, list(outdata.shape))]) + inputs=[helper.make_tensor_value_info("input", + TensorProto.FLOAT, list(indata.shape))], + outputs=[helper.make_tensor_value_info("output", + TensorProto.FLOAT, list(outdata.shape))]) model = helper.make_model(graph, 
producer_name='{}_test'.format(name)) # tvm result for target, ctx in ctx_list(): - tvm_out = get_tvm_output(model, indata, target, ctx, outdata.shape, 'float32') + tvm_out = get_tvm_output( + model, indata, target, ctx, outdata.shape, 'float32') tvm.testing.assert_allclose(outdata, tvm_out, rtol=1e-5, atol=1e-5) + def test_reduce_max(): verify_reduce_x("ReduceMax", np.random.randn(3, 2, 2).astype(np.float32), @@ -914,6 +1119,7 @@ def test_reduce_max(): np.random.randn(3, 3, 3).astype(np.float32), axis=(1,), keepdims=1) + def test_reduce_min(): verify_reduce_x("ReduceMin", np.random.randn(3, 2, 2).astype(np.float32), @@ -925,6 +1131,7 @@ def test_reduce_min(): np.random.randn(3, 3, 3).astype(np.float32), axis=(1,), keepdims=1) + def test_reduce_sum(): verify_reduce_x("ReduceSum", np.random.randn(3, 2, 2).astype(np.float32), @@ -936,6 +1143,7 @@ def test_reduce_sum(): np.random.randn(3, 3, 3).astype(np.float32), axis=(1,), keepdims=1) + def test_reduce_mean(): verify_reduce_x("ReduceMean", np.random.randn(3, 2, 2).astype(np.float32), @@ -947,40 +1155,52 @@ def test_reduce_mean(): np.random.randn(3, 3, 3).astype(np.float32), axis=(1,), keepdims=1) + def verify_split(indata, outdatas, split, axis=0): indata = np.array(indata).astype(np.float32) outdatas = [np.array(o).astype(np.float32) for o in outdatas] + if split: + split_index = range(len(split)) + else: + split_index = range(len(outdatas)) node = helper.make_node( 'Split', inputs=['input'], - outputs=['output_{}'.format(i) for i in range(len(split))], + outputs=['output_{}'.format(i) for i in range(len(split_index))], axis=axis, split=split ) graph = helper.make_graph([node], 'split_test', - inputs = [helper.make_tensor_value_info("input", - TensorProto.FLOAT, list(indata.shape))], - outputs = [helper.make_tensor_value_info("output_{}".format(i), - TensorProto.FLOAT, list(outdatas[i].shape)) - for i in range(len(split)) - ]) + inputs=[helper.make_tensor_value_info("input", + TensorProto.FLOAT, list(indata.shape))], + outputs=[helper.make_tensor_value_info("output_{}".format(i), + TensorProto.FLOAT, list(outdatas[i].shape)) + for i in range(len(split_index)) + ]) model = helper.make_model(graph, producer_name='split_test') for target, ctx in ctx_list(): output_shape = [o.shape for o in outdatas] output_type = ['float32', 'float32', 'float32'] - tvm_out = get_tvm_output(model, indata, target, ctx, output_shape, output_type) + tvm_out = get_tvm_output( + model, indata, target, ctx, output_shape, output_type) for o, t in zip(outdatas, tvm_out): tvm.testing.assert_allclose(o, t) + def test_split(): # 1D - verify_split([1., 2., 3., 4., 5., 6.], [[1., 2.], [3., 4.], [5., 6.]], [2, 2, 2], 0) - verify_split([1., 2., 3., 4., 5., 6.], [[1., 2.], [3.], [4., 5., 6.]], [2, 1, 3], 0) + verify_split([1., 2., 3., 4., 5., 6.], [ + [1., 2.], [3., 4.], [5., 6.]], [2, 2, 2], 0) + verify_split([1., 2., 3., 4., 5., 6.], [ + [1., 2.], [3.], [4., 5., 6.]], [2, 1, 3], 0) # 2D verify_split([[1., 2., 3., 4.], [7., 8., 9., 10.]], [[[1., 2.], [7., 8.]], [[3., 4.], [9., 10.]]], [2, 2], 1) + # Split evenly (unstack) + verify_split([1, 2, 3], [[1], [2], [3]], False) + def test_binary_ops(): in_shape = (1, 2, 3, 3) @@ -993,13 +1213,13 @@ def verify_binary_ops(op, x, y, out_np, broadcast=None): else: z = helper.make_node(op, ['in1', 'in2'], ['out'], broadcast=1) graph = helper.make_graph([z], - '_test', - inputs = [helper.make_tensor_value_info("in1", - TensorProto.FLOAT, list(in_shape)), - helper.make_tensor_value_info("in2", - TensorProto.FLOAT, list(in_shape))], - 
         graph = helper.make_graph([z],
-                                  '_test',
-                                  inputs = [helper.make_tensor_value_info("in1",
-                                                TensorProto.FLOAT, list(in_shape)),
-                                            helper.make_tensor_value_info("in2",
-                                                TensorProto.FLOAT, list(in_shape))],
-                                  outputs = [helper.make_tensor_value_info("out",
-                                                TensorProto.FLOAT, list(out_shape))])
+                                  '_test',
+                                  inputs=[helper.make_tensor_value_info("in1",
+                                                                        TensorProto.FLOAT, list(in_shape)),
+                                          helper.make_tensor_value_info("in2",
+                                                                        TensorProto.FLOAT, list(in_shape))],
+                                  outputs=[helper.make_tensor_value_info("out",
+                                                                         TensorProto.FLOAT, list(out_shape))])
         model = helper.make_model(graph, producer_name='_test')
         for target, ctx in ctx_list():
             tvm_out = get_tvm_output(model, [x, y], target, ctx)
@@ -1008,11 +1228,11 @@ def verify_binary_ops(op, x, y, out_np, broadcast=None):
     x = np.random.uniform(size=in_shape).astype(dtype)
     y = np.random.uniform(size=in_shape).astype(dtype)
     z = np.random.uniform(size=(3,)).astype(dtype)
-    verify_binary_ops("Add",x, y, x + y, broadcast=None)
+    verify_binary_ops("Add", x, y, x + y, broadcast=None)
     verify_binary_ops("Add", x, z, x + z, broadcast=True)
     verify_binary_ops("Sub", x, y, x - y, broadcast=None)
     verify_binary_ops("Sub", x, z, x - z, broadcast=True)
-    verify_binary_ops("Mul",x, y, x * y, broadcast=None)
+    verify_binary_ops("Mul", x, y, x * y, broadcast=None)
     verify_binary_ops("Mul", x, z, x * z, broadcast=True)
     verify_binary_ops("Div", x, y, x / y, broadcast=None)
     verify_binary_ops("Div", x, z, x / z, broadcast=True)
@@ -1021,6 +1241,7 @@ def verify_binary_ops(op, x, y, out_np, broadcast=None):
     verify_binary_ops("Less", x, y, x < y, broadcast=True)
     verify_binary_ops("Equal", x, y, x == y, broadcast=True)
 
+
 def test_single_ops():
     in_shape = (1, 2, 3, 3)
     dtype = "float32"
@@ -1029,29 +1250,29 @@ def verify_single_ops(op, x, out_np, rtol=1e-5, atol=1e-5):
         z = helper.make_node(op, ['in1'], ['out'])
         graph = helper.make_graph([z],
-                                  '_test',
-                                  inputs = [helper.make_tensor_value_info("in1",
-                                                TensorProto.FLOAT, list(in_shape)),],
-                                  outputs = [helper.make_tensor_value_info("out",
-                                                TensorProto.FLOAT, list(out_shape))])
+                                  '_test',
+                                  inputs=[helper.make_tensor_value_info("in1",
+                                                                        TensorProto.FLOAT, list(in_shape)), ],
+                                  outputs=[helper.make_tensor_value_info("out",
+                                                                         TensorProto.FLOAT, list(out_shape))])
         model = helper.make_model(graph, producer_name='_test')
         for target, ctx in ctx_list():
             tvm_out = get_tvm_output(model, [x], target, ctx)
             tvm.testing.assert_allclose(out_np, tvm_out, rtol=rtol, atol=atol)
     x = np.random.uniform(size=in_shape).astype(dtype)
-    verify_single_ops("Neg",x, -x)
-    verify_single_ops("Abs",x, np.abs(x))
-    verify_single_ops("Reciprocal",x, 1/x)
-    verify_single_ops("Sqrt",x, np.sqrt(x))
-    verify_single_ops("Relu",x, np.maximum(x, 0))
-    verify_single_ops("Exp",x, np.exp(x))
-    verify_single_ops("Log",x, np.log(x))
-    verify_single_ops("Log",x, np.log(x))
-    verify_single_ops("Tanh",x, np.tanh(x))
-    verify_single_ops("Sigmoid",x, 1 / (1 + np.exp(-x)))
-    verify_single_ops("Softsign",x, x / (1 + np.abs(x)))
-    verify_single_ops("SoftPlus",x, np.log(1 + np.exp(x)))
+    verify_single_ops("Neg", x, -x)
+    verify_single_ops("Abs", x, np.abs(x))
+    verify_single_ops("Reciprocal", x, 1/x)
+    verify_single_ops("Sqrt", x, np.sqrt(x))
+    verify_single_ops("Relu", x, np.maximum(x, 0))
+    verify_single_ops("Exp", x, np.exp(x))
+    verify_single_ops("Log", x, np.log(x))
+    verify_single_ops("Tanh", x, np.tanh(x))
+    verify_single_ops("Sigmoid", x, 1 / (1 + np.exp(-x)))
+    verify_single_ops("Softsign", x, x / (1 + np.abs(x)))
+    verify_single_ops("SoftPlus", x, np.log(1 + np.exp(x)))
+
 
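+# Each activation test below supplies a numpy reference implementation to
+# the _test_onnx_op_elementwise helper for comparison with the TVM output.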
 def test_leaky_relu():
     def leaky_relu_x(x, alpha):
@@ -1063,6 +1285,7 @@ def leaky_relu_x(x, alpha):
                               'LeakyRelu',
                               {'alpha': 0.25})
 
+
 def test_elu():
     def elu_x(x, alpha):
         return np.where(x > 0, x, alpha * (np.exp(x) - 1.0))
@@ -1073,6 +1296,7 @@ def elu_x(x, alpha):
                               'Elu',
                               {'alpha': 0.25})
 
+
 def test_selu():
     def selu_x(x, alpha, gamma):
         return gamma * np.where(x > 0, x, alpha * (np.exp(x) - 1.0))
@@ -1083,6 +1307,7 @@ def selu_x(x, alpha, gamma):
                               'Selu',
                               {'alpha': 0.25, 'gamma': 0.3})
 
+
 def test_ThresholdedRelu():
     def ThresholdedRelu_x(x, alpha):
         out_np = np.clip(x, alpha, np.inf)
@@ -1095,6 +1320,7 @@ def ThresholdedRelu_x(x, alpha):
                               'ThresholdedRelu',
                               {'alpha': 0.25})
 
+
 def test_ScaledTanh():
     def ScaledTanh_x(x, alpha, beta):
         return alpha * np.tanh(beta * x)
@@ -1105,6 +1331,7 @@ def ScaledTanh_x(x, alpha, beta):
                               'ScaledTanh',
                               {'alpha': 0.25, 'beta': 0.3})
 
+
 def test_ParametricSoftplus():
     def ParametricSoftplus_x(x, alpha, beta):
         return alpha * np.log(np.exp(beta * x) + 1)
@@ -1115,6 +1342,7 @@ def ParametricSoftplus_x(x, alpha, beta):
                               'ParametricSoftplus',
                               {'alpha': 0.25, 'beta': 0.3})
 
+
 def test_Scale():
     def Scale_x(x, scale):
         return scale * x
@@ -1125,6 +1353,7 @@ def Scale_x(x, scale):
                               'Scale',
                               {'scale': 0.25})
 
+
 def test_LogSoftmax():
     _test_onnx_op_elementwise((1, 4),
                               topi.testing.log_softmax_python,
@@ -1138,7 +1367,8 @@ def check_torch_conversion(model, input_size):
     dummy_input = torch.randn(*input_size)
     file_name = '{}.onnx'.format(model.__name__)
     # Set verbose=True for more output
-    torch.onnx.export(model(), dummy_input, file_name, export_params=True, verbose=False)
+    torch.onnx.export(model(), dummy_input, file_name,
+                      export_params=True, verbose=False)
     onnx_model = onnx.load(file_name)
     for target, ctx in ctx_list():
         input_data = np.random.uniform(size=input_size).astype('int32')
@@ -1146,13 +1376,14 @@ def check_torch_conversion(model, input_size):
         tvm_out = get_tvm_output(onnx_model, input_data, target, ctx)
         tvm.testing.assert_allclose(c2_out, tvm_out)
 
+
 def test_resnet():
-    check_torch_conversion(torchvision.models.resnet18, (1,3,224,224))
+    check_torch_conversion(torchvision.models.resnet18, (1, 3, 224, 224))
     # check_torch_conversion(torchvision.models.resnet101, (1,3,224,224))
 
 # def test_alexnet():
-    # Torch's ONNX export does not support the adaptive pooling used by AlexNet?
-    # check_torch_conversion(torchvision.models.alexnet, (1,3,224,224))
+# Torch's ONNX export does not support the adaptive pooling used by AlexNet?
+# check_torch_conversion(torchvision.models.alexnet, (1,3,224,224))
 
 # Torch's ONNX export does not support the adaptive pooling used by vgg16?
 # def test_vgg16():
@@ -1163,11 +1394,13 @@ def test_resnet():
 # # Torch's ONNX export does not support the max pooling used by Squezenet
 #     check_torch_conversion(torchvision.models.squeezenet1_0, (1,3,224,224))
 
+
 def test_densenet():
-    check_torch_conversion(torchvision.models.densenet161, (1,3,224,224))
+    check_torch_conversion(torchvision.models.densenet161, (1, 3, 224, 224))
 
+
 def test_inception():
-    check_torch_conversion(torchvision.models.inception_v3, (1,3,224,224))
+    check_torch_conversion(torchvision.models.inception_v3, (1, 3, 224, 224))
 
 # TODO(@jroesch): Update Torch + ONNX to support this import.
 # def test_googlenet():
@@ -1177,6 +1410,7 @@ def test_inception():
 # def test_shufflenetv2():
 #     check_torch_conversion(torchvision.models.shufflenetv2, (1,3,224,224))
 
+
 def test_sign():
     def Sign_x(x):
         return np.sign(x)
@@ -1196,7 +1430,8 @@ def verify_not(indata, dtype):
     graph = helper.make_graph([node],
                               'not_test',
-                              inputs=[helper.make_tensor_value_info("in", TensorProto.BOOL, list(x.shape))],
+                              inputs=[helper.make_tensor_value_info(
+                                  "in", TensorProto.BOOL, list(x.shape))],
                               outputs=[helper.make_tensor_value_info("out", TensorProto.BOOL, list(outdata.shape))])
     model = helper.make_model(graph, producer_name='not_test')
@@ -1262,31 +1497,70 @@ def test_and():
     verify_and(indata=[x, y], dtype=bool)
 
 
-def verify_tile(indata, outdata, **kwargs):
+def verify_tile_v1(indata, outdata, **kwargs):
     node = helper.make_node('Tile', inputs=['in'], outputs=['out'], **kwargs)
     graph = helper.make_graph([node],
                               'tile_test',
-                              inputs=[helper.make_tensor_value_info("in", TensorProto.FLOAT, list(indata.shape))],
+                              inputs=[helper.make_tensor_value_info(
+                                  "in", TensorProto.FLOAT, list(indata.shape))],
                               outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(outdata.shape))])
     model = helper.make_model(graph, producer_name='tile_test')
 
     for target, ctx in ctx_list():
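+        # Pinning opset=1 exercises the original form of Tile, where
+        # 'repeats' is passed as a node attribute.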
-        tvm_out = get_tvm_output(model, [indata], target, ctx, outdata.shape)
+        tvm_out = get_tvm_output(
+            model, [indata], target, ctx, outdata.shape, opset=1)
         tvm.testing.assert_allclose(outdata, tvm_out)
+
+
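+# From opset 6 onward, Tile takes 'repeats' as a second int64 input tensor
+# instead of an attribute; here it is supplied as a graph initializer.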
+def verify_tile_v6(indata, repeats, outdata):
+    node = helper.make_node('Tile',
+                            inputs=['input', 'repeats'],
+                            outputs=['out'])
+    graph = helper.make_graph(
+        [node],
+        'tile_test',
+        inputs=[
+            helper.make_tensor_value_info("input", TensorProto.FLOAT,
+                                          list(indata.shape)),
+            helper.make_tensor_value_info("repeats", TensorProto.INT64,
+                                          list(repeats.shape))
+        ],
+        outputs=[
+            helper.make_tensor_value_info("out", TensorProto.FLOAT,
+                                          list(outdata.shape))
+        ],
+        initializer=[
+            helper.make_tensor("repeats", TensorProto.INT64,
+                               list(repeats.shape), repeats)
+        ])
+
+    model = helper.make_model(graph, producer_name='tile_test')
+
+    for target, ctx in ctx_list():
+        tvm_out = get_tvm_output(model, [indata],
+                                 target,
+                                 ctx,
+                                 outdata.shape,
+                                 opset=6)
+        tvm.testing.assert_allclose(outdata, tvm_out)
 
 
 def test_tile():
     x = np.random.rand(2, 3, 4, 5).astype(np.float32)
-    repeats = np.random.randint(low=1, high=10, size=(np.ndim(x),)).astype(np.int64)
+    repeats = np.random.randint(
+        low=1, high=10, size=(np.ndim(x),)).astype(np.int64)
     z = np.tile(x, repeats)
-    verify_tile(x, z, repeats=repeats)
+    verify_tile_v1(x, z, repeats=repeats)
+    verify_tile_v6(x, repeats, z)
+
 
 def verify_erf(indata, outdata):
     node = helper.make_node('Erf', inputs=['in'], outputs=['out'])
     graph = helper.make_graph([node],
                               'erf_test',
-                              inputs=[helper.make_tensor_value_info('in', TensorProto.FLOAT, list(indata.shape))],
+                              inputs=[helper.make_tensor_value_info(
+                                  'in', TensorProto.FLOAT, list(indata.shape))],
                               outputs=[helper.make_tensor_value_info('out', TensorProto.FLOAT, list(outdata.shape))])
     model = helper.make_model(graph, producer_name='erf_test')
@@ -1294,6 +1568,7 @@ def verify_erf(indata, outdata):
         tvm_out = get_tvm_output(model, [indata], target, ctx, outdata.shape)
         tvm.testing.assert_allclose(outdata, tvm_out)
 
+
 def test_erf():
     x = np.random.rand(2, 3, 4, 6).astype(np.float32)
     z = scipy.special.erf(x)
@@ -1337,7 +1612,9 @@ def test_where():
     test_floor()
     test_ceil()
     test_clip()
+    test_onehot()
     test_matmul()
+    test_batch_matmul()
     test_gather()
     test_lrn()
     test_instance_norm()
@@ -1348,7 +1625,7 @@ def test_where():
     test_forward_hardsigmoid()
     test_forward_arg_min_max()
     test_softmax()
-    test_constantfill()
+    test_constantofshape()
     test_reduce_max()
     test_reduce_min()
     test_reduce_sum()