diff --git a/nnvm/python/nnvm/frontend/__init__.py b/nnvm/python/nnvm/frontend/__init__.py index f95e134cf0dd..49f53df1174f 100644 --- a/nnvm/python/nnvm/frontend/__init__.py +++ b/nnvm/python/nnvm/frontend/__init__.py @@ -7,10 +7,3 @@ from .darknet import from_darknet from .tensorflow import from_tensorflow from .caffe2 import from_caffe2 -from .common import raise_not_supported, get_nnvm_op, required_attr, \ - warn_not_used, parse_tshape, parse_bool_str -from tvm.error_handling import raise_attribute_required, \ - raise_attribute_invalid, \ - raise_operator_unimplemented, \ - raise_attribute_unimplemented, \ - warn_not_used diff --git a/nnvm/python/nnvm/frontend/caffe2.py b/nnvm/python/nnvm/frontend/caffe2.py index 32d08678a0f8..63b7913dd755 100755 --- a/nnvm/python/nnvm/frontend/caffe2.py +++ b/nnvm/python/nnvm/frontend/caffe2.py @@ -3,7 +3,7 @@ from __future__ import absolute_import as _abs import tvm from nnvm import symbol as _sym -from nnvm.frontend.common import get_nnvm_op, Renamer, AttrConverter as AttrCvt +from .common import get_nnvm_op, Renamer, AttrConverter as AttrCvt from .onnx_caffe2_utils import dimension_picker, dimension_constraint, infer_channels, revert_caffe2_pad from . import onnx @@ -73,7 +73,8 @@ def get_converter(cls): if hasattr(cls, '_impl'): return getattr(cls, '_impl') - raise_operator_unimplemented(cls.__name__) + raise tvm.error.OpNotImplemented( + 'Operator {} is not implemented in frontend Caffe2.'.format(cls.__name__)) _caffe2_internal_args = { @@ -175,7 +176,7 @@ def _get_axis_from_order_str(order): return 1 if order == 'NHWC': return 3 - raise_attribute_invalid(order, 'storage order', 'concat') + raise tvm.error.OpAttributeInvalid('Value {} in attribute {} of operator {} is not valid.'.format(order, 'order', 'Concat')) return AttrCvt( op_name='concatenate', @@ -425,7 +426,8 @@ def _convert_operator(self, # Add a sanitizing step to convert all byte strings in args to strings sym = convert_map[op_type](inputs, args, self._params) else: - raise_operator_unimplemented(op_type) + raise tvm.error.OpNotImplemented( + 'Operator {} is not supported in frontend Caffe2.'.format(op_type)) return sym diff --git a/nnvm/python/nnvm/frontend/common.py b/nnvm/python/nnvm/frontend/common.py index 58ce6703b28d..5a8defdb3d6e 100644 --- a/nnvm/python/nnvm/frontend/common.py +++ b/nnvm/python/nnvm/frontend/common.py @@ -7,13 +7,15 @@ def get_nnvm_op(op_name): op = getattr(_sym, op_name) if not op: - raise_operator_unimplemented(op_name) + raise tvm.error.OpNotImplemented( + 'Operator {} is not supported.'.format(op_name)) return op def required_attr(attr, key, op_name): assert isinstance(attr, dict) if key not in attr: - raise_attribute_required(key, op_name) + raise tvm.error.OpAttributeRequired( + 'Required attribute {} not found in operator {}.'.format(key, op_name)) return attr[key] def parse_tshape(tshape): diff --git a/nnvm/python/nnvm/frontend/coreml.py b/nnvm/python/nnvm/frontend/coreml.py index e7c5a0d7eda8..1483e95cf6f0 100644 --- a/nnvm/python/nnvm/frontend/coreml.py +++ b/nnvm/python/nnvm/frontend/coreml.py @@ -2,11 +2,10 @@ """CoreML frontend.""" from __future__ import absolute_import as _abs import numpy as np - import tvm +from .common import SymbolTable from ..
import symbol as _sym from .._base import string_types -from .common import SymbolTable __all__ = ['from_coreml'] @@ -83,7 +82,8 @@ def BatchnormLayerParams(op, insym, symtab): """Get layer of batchnorm parameter""" # this changes the symbol if op.instanceNormalization: - raise_operator_unimplemented('instance normalization') + msg = 'Operator "instance normalization" is not supported in frontend CoreML.' + raise tvm.error.OpNotImplemented(msg) else: params = {'gamma':symtab.new_const(list(op.gamma.floatValue)), 'beta':symtab.new_const(list(op.beta.floatValue)), @@ -136,7 +136,8 @@ def ActivationParams(op, insym, symtab): betasym = symtab.new_const(beta) return _sym.broadcast_mul(_sym.log(_sym.broadcast_add( _sym.exp(insym), betasym)), alphasym) - raise_operator_unimplemented(whichActivation) + raise tvm.error.OpNotImplemented( + 'Operator {} is not supported in frontend CoreML.'.format(whichActivation)) def ScaleLayerParams(op, insym, symtab): """Scale layer params.""" @@ -158,7 +159,8 @@ def PoolingLayerParams(op, insym, symtab): return _sym.global_max_pool2d(insym) if op.type == 1: return _sym.global_avg_pool2d(insym) - raise_operator_unimplemented('pooling (not max or average)') + raise tvm.error.OpNotImplemented( + 'Operator pooling (not max or average) is not supported in frontend CoreML.') else: params = {'pool_size':list(op.kernelSize), @@ -178,8 +180,8 @@ def PoolingLayerParams(op, insym, symtab): params['padding'] = padding params['ceil_mode'] = True else: - raise_attribute_invalid(op.WhichOneof('PoolingPaddingType'), - 'PoolingPaddingType', 'pooling') + msg = 'Value {} in attribute PoolingPaddingType of operator Pooling is not valid.' + raise tvm.error.OpAttributeInvalid(msg.format(op.WhichOneof('PoolingPaddingType'))) # consume padding layer if symtab.in_padding: @@ -191,7 +193,8 @@ def PoolingLayerParams(op, insym, symtab): return _sym.max_pool2d(insym, **params) if op.type == 1: return _sym.avg_pool2d(insym, **params) - raise_operator_unimplemented('pooling (not max or average)') + msg = 'Operator pooling (not max or average) is not supported in frontend CoreML.' + raise tvm.error.OpNotImplemented(msg) def SoftmaxLayerParams(op, insym, symtab): return _sym.softmax(_sym.flatten(insym)) @@ -230,7 +233,8 @@ def ConcatLayerParams(op, insyms, symtab): if not isinstance(insyms, list): insyms = [insyms] if op.sequenceConcat: - raise_operator_unimplemented('sequence concat') + raise tvm.error.OpNotImplemented( + 'Operator Sequence Concat is not supported in frontend CoreML.') ret = _sym.concatenate(*insyms, axis=1) return ret @@ -244,14 +248,16 @@ def PaddingLayerParams(op, insym, symtab): if op.WhichOneof('PaddingType') == 'constant': constant = op.constant if constant.value != 0: - raise_attribute_invalid(constant.value, 'padding value', 'padding') + msg = 'Value {} in attribute "padding value" of operator Padding is not valid.' 
+ raise tvm.error.OpAttributeInvalid(msg.format(constant.value)) padding = [b.startEdgeSize for b in op.paddingAmounts.borderAmounts] padding2 = [b.endEdgeSize for b in op.paddingAmounts.borderAmounts] for i, j in zip(padding, padding2): assert i == j symtab.set_padding(padding) else: - raise_operator_unimplemented('non-constant padding') + raise tvm.error.OpNotImplemented( + 'Operator "non-constant padding" is not supported in frontend CoreML.') return insym def PermuteLayerParams(op, insym, symtab): @@ -260,8 +266,8 @@ def PermuteLayerParams(op, insym, symtab): def UpsampleLayerParams(op, insym, symtab): if op.scalingFactor[0] != op.scalingFactor[1]: - raise_attribute_invalid(op.scalingFactor, 'scaling factors', - 'upsample') + raise tvm.error.OpAttributeInvalid( + 'Height and width scaling factors of Upsample operator must be equal.') interpolationMode = 'NEAREST_NEIGHBOR' if op.mode == 0 else 'BILINEAR' return _sym.upsampling(insym, scale=op.scalingFactor[0], method=interpolationMode) @@ -342,7 +348,8 @@ def coreml_op_to_nnvm(op, inname, outname, symtab): """ classname = type(op).__name__ if classname not in _convert_map: - raise_operator_unimplemented(classname) + raise tvm.error.OpNotImplemented( + 'Operator {} is not supported in frontend CoreML.'.format(classname)) if isinstance(inname, string_types): insym = symtab.get_var(inname) else: diff --git a/nnvm/python/nnvm/frontend/darknet.py b/nnvm/python/nnvm/frontend/darknet.py index bbb0926f29c8..bf5a832258fa 100644 --- a/nnvm/python/nnvm/frontend/darknet.py +++ b/nnvm/python/nnvm/frontend/darknet.py @@ -6,6 +6,7 @@ import numpy as np import tvm from .. import symbol as _sym +from .common import get_nnvm_op, required_attr, parse_tshape, parse_bool_str class LAYERTYPE(object): """Darknet LAYERTYPE Class constant.""" @@ -61,7 +62,8 @@ def _darknet_maxpooling(inputs, attrs): """Process the max pool 2d operation.""" kernel = parse_tshape(required_attr(attrs, 'kernel', 'maxpool')) if len(kernel) != 1: - raise_attribute_unimplemented('non-2d kernel', 'pool_2d') + raise tvm.error.OpAttributeUnimplemented( + 'Non-2D kernels for Max Pooling are not supported in frontend Darknet.') op_name, new_attrs = 'max_pool2d', {} strides = int(attrs.get('stride', (1, 1))) @@ -79,7 +81,8 @@ def _darknet_avgpooling(inputs, attrs): """Process the average pool 2d operation.""" kernel = parse_tshape(required_attr(attrs, 'kernel', 'avgpool')) if len(kernel) != 1: - raise_attribute_unimplemented('non-2d kernel', 'pool_2d') + raise tvm.error.OpAttributeUnimplemented( + 'Non-2D kernels for Average Pooling are not supported in frontend Darknet.') op_name, new_attrs = 'avg_pool2d', {} strides = int(attrs.get('stride', (1, 1))) @@ -103,10 +106,12 @@ def _darknet_conv2d(inputs, attrs): """Process the convolution 2d operation.""" kernel = parse_tshape(required_attr(attrs, 'kernel', 'conv2d')) if len(kernel) != 1: - raise_attribute_unimplemented('non 2d kernel', 'conv2d') + raise tvm.error.OpAttributeUnimplemented('Non-2D kernels for Conv2D are unsupported ' + 'in frontend Darknet.') layout = attrs.get('layout', 'NCHW') if layout not in ['NCHW', 'NHWC']: - raise_attribute_invalid(layout, 'layout', 'conv2d') + raise tvm.error.OpAttributeInvalid( + 'Value {} in attribute "layout" of operator Conv2D is not valid.'.format(layout)) strides = int(attrs.get('stride', (1, 1))) pads = int(attrs.get('pad', (0, 0))) @@ -142,13 +147,16 @@ def _darknet_conv2d(inputs, attrs): def _darknet_conv2d_transpose(inputs, attrs): """Process the convolution 2d transpose operation.""" if 
'target_shape' in attrs: - raise_attribute_unimplemented('target_shape', 'conv2d_transpose') + raise tvm.error.OpAttributeUnimplemented( + 'Attribute "target_shape" is not supported in operator Conv2D-transpose.') kernel = parse_tshape(required_attr(attrs, 'kernel', 'conv2d_transpose')) if len(kernel) != 2: - raise_attribute_unimplemented('non-2d kernel', 'conv2d_transpose') + raise tvm.error.OpAttributeUnimplemented( + 'Non-2D kernels are not supported in operator Conv2D-transpose.') layout = attrs.get('layout', 'NCHW') if layout not in ['NCHW', 'NHWC']: - raise_attribute_invalid(layout, 'layout', 'conv2d_transpose') + msg = 'Value {} in attribute "layout" of operator Conv2D-transpose is not valid.' + raise tvm.error.OpAttributeInvalid(msg.format(layout)) op_name, new_attrs = 'conv2d_transpose', {} new_attrs['channels'] = required_attr(attrs, 'num_filter', 'conv2d_transpose') new_attrs['kernel_size'] = kernel @@ -222,7 +230,8 @@ def _darknet_dropout(inputs, attrs): def _darknet_reshape(inputs, attrs): """Process the reshape operation.""" if parse_bool_str(attrs, 'reverse'): - raise_attribute_unimplemented('reverse', 'reshape') + raise tvm.error.OpAttributeUnimplemented( + 'Attribute "reverse" is not supported in operator Reshape.') op_name, new_attrs = 'reshape', {} new_attrs['shape'] = required_attr(attrs, 'shape', 'reshape') return get_nnvm_op(op_name)(*inputs, **new_attrs), None @@ -324,7 +333,8 @@ def _darknet_activations(inputs, attrs): elif ACTIVATION.ELU == act: act_type = 'elu' else: - raise_operator_unimplemented('act: ' + act) + raise tvm.error.OpNotImplemented( + 'Operator act: {} is not supported in framework Darknet.'.format(act)) if act_type in ['relu', 'tanh']: op_name, new_attrs = act_type, {} @@ -339,7 +349,8 @@ def _darknet_activations(inputs, attrs): op_name, new_attrs = act_type, {} sym = get_nnvm_op(op_name)(*inputs, **new_attrs) else: - raise_operator_unimplemented('act_type: ' + act_type) + raise tvm.error.OpNotImplemented( + 'Operator act: {} is not supported in framework Darknet.'.format(act)) return sym, None def _darknet_op_not_support(inputs, attrs): @@ -402,7 +413,8 @@ def _darknet_convert_symbol(op_name, inputs, attrs): if op_name in _DARKNET_CONVERT_MAP: sym, out_name = _DARKNET_CONVERT_MAP[op_name](inputs, attrs) else: - raise_operator_unimplemented(op_name) + raise tvm.error.OpNotImplemented( + 'Operator {} is not supported in frontend Darknet.'.format(op_name)) if out_name is None: out_name = sym.list_output_names()[0].replace('_output', '') return out_name, sym @@ -448,9 +460,10 @@ def _get_convolution_weights(self, layer, opname): if layer.nweights == 0: return - if (layer.n * layer.c * layer.size * layer.size) != layer.nweights: - raise_attribute_invalid(layer.n * layer.c * layer.size * layer.size, - 'layer weights size', 'conv2d') + if layer.n * layer.c * layer.size * layer.size != layer.nweights: + msg = 'nweights ({}) != n * c * h * w ({}) in operator {}' + msg = msg.format(layer.nweights, layer.n * layer.c * layer.size ** 2, opname) + raise tvm.error.OpAttributeInvalid(msg) shape = (layer.n, layer.c, layer.size, layer.size) weights = self._read_memory_buffer(shape, layer.weights) @@ -630,7 +643,8 @@ def _get_darknet_attrs(self, layer, layer_num): pass else: - raise_operator_unimplemented(layer.type) + raise tvm.error.OpNotImplemented( + 'Operator {} is not supported in frontend Darknet.'.format(layer.type)) return attr @@ -763,7 +777,8 @@ def _handle_darknet_rnn_layers(self, layer_num, sym): elif LAYERTYPE.LSTM == layer.type: if layer.steps > 1: - 
raise_attribute_invalid(layer.steps, 'number of steps', 'RNN') + raise tvm.error.OpAttributeInvalid( + 'Number of steps {} of RNN is not valid.'.format(layer.steps)) op_name_add = 'elemwise_add' op_name_mul = 'elemwise_mul' @@ -829,7 +844,8 @@ def _handle_darknet_rnn_layers(self, layer_num, sym): elif LAYERTYPE.GRU == layer.type: if layer.steps > 1: - raise_attribute_invalid(layer.steps, 'number of steps', 'RNN') + raise tvm.error.OpAttributeInvalid( + 'Number of steps {} is not valid in RNN.'.format(layer.steps)) op_name_add = 'elemwise_add' op_name_mul = 'elemwise_mul' diff --git a/nnvm/python/nnvm/frontend/keras.py b/nnvm/python/nnvm/frontend/keras.py index d15d2b3f01ab..63b4122a4060 100644 --- a/nnvm/python/nnvm/frontend/keras.py +++ b/nnvm/python/nnvm/frontend/keras.py @@ -74,7 +74,8 @@ def _convert_activation(insym, keras_layer, _): if act_type == 'hard_sigmoid': transformX = (0.2 * insym) + 0.5 return _sym.clip(transformX, a_min=0, a_max=1) - raise_operator_unimplemented(act_type) + raise tvm.error.OpNotImplemented( + 'Operator {} is not supported in frontend Keras.'.format(act_type)) def _convert_advanced_activation(insym, keras_layer, symtab): @@ -100,7 +101,8 @@ def _convert_advanced_activation(insym, keras_layer, symtab): theta = keras_layer.theta if hasattr(keras_layer, "theta") else 1.0 theta_tensor = _sym.full_like(insym[0], fill_value=float(theta)) return _sym.elemwise_mul(insym[0], _sym.greater(insym[0], theta_tensor, out_type="float32")) - raise_operator_unimplemented(act_type) + raise tvm.error.OpNotImplemented( + 'Operator {} is not supported in frontend Keras.'.format(act_type)) def _convert_merge(insym, keras_layer, _): @@ -113,12 +115,9 @@ def _convert_merge(insym, keras_layer, _): ret = _sym.elemwise_sub(ret, insym[i]) elif merge_type == 'Multiply': ret = _sym.elemwise_mul(ret, insym[i]) - elif merge_type == 'Average': - raise_operator_unimplemented('average merge') - elif merge_type == 'Maximum': - raise_operator_unimplemented('maximum merge') else: - raise_operator_unimplemented(merge_type) + raise tvm.error.OpNotImplemented( + 'Operator {} Merge is not supported in frontend Keras.'.format(merge_type)) return ret @@ -135,7 +134,8 @@ def _convert_dense(insym, keras_layer, symtab): if input_dim > 2: input_shape = tuple(dim if dim else 1 for dim in _as_list(input_shape)[0]) if input_dim != 3 or input_shape[0] != 1 or input_shape[1] != 1: - raise_attribute_invalid(input_shape, 'input shape', 'dense') + msg = 'Value {} in attribute "input_shape" of operator Dense is not valid.' + raise tvm.error.OpAttributeInvalid(msg.format(input_shape)) insym = _sym.squeeze(insym, axis=0) out = _sym.dense(data=insym, **params) # defuse activation @@ -199,7 +199,8 @@ def _convert_convolution(insym, keras_layer, symtab): else: insym = _sym.pad(data=insym, pad_width=((0, 0), (0, 0), (pad_t, pad_b), (pad_l, pad_r))) else: - raise_operator_unimplemented('padding with {}'.format(keras_layer.padding)) + msg = 'Value {} in attribute "padding" of operator Convolution is not valid.' + raise tvm.error.OpAttributeInvalid(msg.format(keras_layer.padding)) if is_deconv: out = _sym.conv2d_transpose(data=insym, **params) else: @@ -240,7 +241,8 @@ def _convert_separable_convolution(insym, keras_layer, symtab): insym = _sym.pad(data=insym, pad_width=( (0, 0), (0, 0), (pad_t, pad_b), (pad_l, pad_r))) else: - raise_operator_unimplemented('padding with {}'.format(keras_layer.padding)) + msg = 'Value {} in attribute "padding" of operator Separable Convolution is not valid.' 
+ raise tvm.error.OpAttributeInvalid(msg.format(keras_layer.padding)) depthconv = _sym.conv2d(data=insym, **params0) # pointwise conv weight1 = weightList[1].transpose([3, 2, 0, 1]) @@ -294,13 +296,15 @@ def _convert_pooling(insym, keras_layer, symtab): pad_l, pad_r = _get_pad_pair(in_w, pool_w, stride_w) params['padding'] = [pad_t, pad_l, pad_b, pad_r] else: - raise_operator_unimplemented('padding with {}'.format(keras_layer.padding)) + msg = 'Value {} in attribute "padding" of operator Pooling is not valid.' + raise tvm.error.OpAttributeInvalid(msg.format(keras_layer.padding)) if pool_type == 'MaxPooling2D': return _sym.max_pool2d(insym, **params) if pool_type == 'AveragePooling2D': # TODO: in keras, padded zeros are not calculated return _sym.avg_pool2d(insym, **params) - raise_operator_unimplemented('pooling with {}'.format(keras_layer)) + msg = 'Operator {} is not supported in frontend Keras.' + raise tvm.error.OpNotImplemented(msg.format(pool_type)) def _convert_upsample(insym, keras_layer, _): @@ -312,28 +316,30 @@ def _convert_upsample(insym, keras_layer, _): elif upsample_type == "UpSampling2D": h, w = keras_layer.size if h != w: - raise_attribute_invalid(keras_layer.size, 'size', 'upsample') + raise tvm.error.OpAttributeInvalid( + 'Upsample height ({}) must equal width ({})'.format(h, w)) params = {'scale': h} elif upsample_type == "UpSampling3D": h, w, d = keras_layer.size if h != w or w != d: - raise_attribute_invalid(keras_layer.size, 'size', 'upsample') + raise tvm.error.OpAttributeInvalid( + 'Upsample height ({}), width ({}), and depth ({}) must be equal.'.format(h, w, d)) params = {'scale': h} else: - raise_operator_unimplemented(upsample_type) + msg = 'Operator {} is not supported in frontend Keras.' + raise tvm.error.OpNotImplemented(msg.format(upsample_type)) return _sym.upsampling(insym, **params) def _convert_cropping(insym, keras_layer, _): _check_data_format(keras_layer) crop_type = type(keras_layer).__name__ - if crop_type == "Cropping1D": - raise_operator_unimplemented(crop_type) - elif crop_type == "Cropping2D": + if crop_type == "Cropping2D": (_, in_h, in_w, _) = keras_layer.input_shape ((crop_t, crop_b), (crop_l, crop_r)) = keras_layer.cropping else: - raise_operator_unimplemented(crop_type) + raise tvm.error.OpNotImplemented( + 'Operator {} is not supported in frontend Keras.'.format(crop_type)) int32_max = np.iinfo(np.int32).max return _sym.strided_slice(insym, begin=[0, 0, crop_t, crop_l], end=[int32_max, int32_max, in_h-crop_b, in_w-crop_r]) @@ -377,11 +383,13 @@ def _convert_padding(insym, keras_layer, _): top, bottom = padding[0] left, right = padding[1] else: - raise_attribute_invalid(str(padding), 'padding', padding_type) + msg = 'Value {} in attribute "padding" of operator {} is not valid.' + raise tvm.error.OpAttributeInvalid(msg.format(str(padding), padding_type)) else: - raise_attribute_invalid(str(padding), 'padding', padding_type) + msg = 'Value {} in attribute "padding" of operator {} is not valid.'
+ raise tvm.error.OpAttributeInvalid(msg.format(str(padding), padding_type)) else: - raise_operator_unimplemented(padding_type) + raise tvm.error.OpNotImplemented('Operator {} is not supported in frontend Keras.'.format(padding_type)) return _sym.pad(data=insym, pad_width=((0, 0), (0, 0), (top, bottom), (left, right))) @@ -588,8 +596,10 @@ def _default_skip(insym, keras_layer, _): # pylint: disable=unused-argument def _check_unsupported_layers(model): for layer in model.layers: - if type(layer).__name__ not in _convert_map: - raise_operator_unimplemented(type(layer).__name__) + op_name = type(layer).__name__ + if op_name not in _convert_map: + raise tvm.error.OpNotImplemented( + 'Operator {} is not supported in frontend Keras.'.format(op_name)) def _as_list(arr): """Force being a list, ignore if already is.""" @@ -614,9 +624,11 @@ def keras_op_to_nnvm(insym, keras_layer, outname, symtab): symtab : nnvm.frontend.common.SymbolTable The global symbol table to be updated """ - if type(keras_layer).__name__ not in _convert_map: - raise_operator_unimplemented(type(keras_layer).__name__) - outs = _convert_map[type(keras_layer).__name__](insym, keras_layer, symtab) + op_name = type(keras_layer).__name__ + if op_name not in _convert_map: + raise tvm.error.OpNotImplemented( + 'Operator {} is not supported in frontend Keras.'.format(op_name)) + outs = _convert_map[op_name](insym, keras_layer, symtab) outs = _as_list(outs) for t_idx, out in enumerate(outs): diff --git a/nnvm/python/nnvm/frontend/mxnet.py b/nnvm/python/nnvm/frontend/mxnet.py index 372f10bd98b9..da5e154bce12 100644 --- a/nnvm/python/nnvm/frontend/mxnet.py +++ b/nnvm/python/nnvm/frontend/mxnet.py @@ -4,6 +4,7 @@ import json import tvm from .. import symbol as _sym +from .common import get_nnvm_op, required_attr, parse_tshape, parse_bool_str __all__ = ['from_mxnet'] @@ -15,11 +16,13 @@ def impl(inputs, attrs): def _pooling(inputs, attrs): kernel = parse_tshape(required_attr(attrs, 'kernel', 'pooling')) if len(kernel) != 2: - raise_attribute_unimplemented('non-2d kernel', 'pool_2d') + raise tvm.error.OpAttributeUnimplemented( + 'Non-2D kernels are not supported for Pool2D.') global_pool = 'global' if parse_bool_str(attrs, 'global_pool') else '' pool_type = required_attr(attrs, 'pool_type', 'pooling') if pool_type not in ['avg', 'max']: - raise_attribute_unimplemented('non-avg/max', 'pool2d') + raise tvm.error.OpNotImplemented( + 'Only max and average pooling are supported in frontend MXNet.') op_name, new_attrs = '_'.join([global_pool, pool_type, 'pool2d']).strip('_'), {} # new_attrs['layout'] = 'NCHW' if not global_pool: @@ -32,11 +35,15 @@ def _pooling(inputs, attrs): return get_nnvm_op(op_name)(*inputs, **new_attrs) def _batch_norm(inputs, attrs): - raise_attribute_unimplemented('output_mean_var', 'batch_norm') + if parse_bool_str(attrs, 'output_mean_var'): + raise tvm.error.OpAttributeUnimplemented( + 'Attribute "output_mean_var" is not supported in operator batch_norm.') # if parse_bool_str(attrs, 'fix_gamma'): # _warn_not_used('fix_gamma', 'batch_norm') if parse_bool_str(attrs, 'use_global_stats'): - warn_not_used('use_global_stats', 'batch_norm') + from warnings import warn + warn( + 'Attribute "use_global_stats" is ignored in operator batch_norm.') # if parse_bool_str(attrs, 'momentum'): # _warn_not_used('momentum', 'batch_norm') op_name, new_attrs = 'batch_norm', {} @@ -54,10 +61,12 @@ def _concat(inputs, attrs): def _conv2d(inputs, attrs): kernel = parse_tshape(required_attr(attrs, 'kernel', 'conv2d')) if len(kernel) != 2: -
raise_attribute_unimplemented('non 2d kernel', 'conv2d') + raise tvm.error.OpAttributeUnimplemented( + 'Non-2D kernels are not supported for operator Conv2D.') layout = attrs.get('layout', 'NCHW') if layout not in ['NCHW', 'NHWC']: - raise_attribute_unimplemented('layout: ' + layout, 'conv2d') + raise tvm.error.OpAttributeUnimplemented( + 'Layout {} is not supported in operator Conv2D.'.format(layout)) if 'kernel_layout' in attrs: kernel_layout = attrs['kernel_layout'] else: @@ -76,13 +85,16 @@ def _conv2d(inputs, attrs): def _conv2d_transpose(inputs, attrs): if 'target_shape' in attrs: - raise_attribute_unimplemented('target_shape', 'conv2d_transpose') + raise tvm.error.OpAttributeUnimplemented( + 'Attribute "target_shape" is not supported in operator Conv2D-transpose.') kernel = parse_tshape(required_attr(attrs, 'kernel', 'conv2d_transpose')) if len(kernel) != 2: - raise_attribute_invalid(len(kernel), 'kernel dim', 'conv2d_transpose') + raise tvm.error.OpAttributeInvalid( + 'Non-2D kernels are not supported in Conv2D-transpose.') layout = attrs.get('layout', 'NCHW') if layout not in ['NCHW', 'NHWC']: - raise_attribute_unimplemented('layout: ' + layout, 'conv2d_transpose') + raise tvm.error.OpAttributeUnimplemented( + 'Layout {} is not supported in operator Conv2D-transpose.'.format(layout)) if 'kernel_layout' in attrs: kernel_layout = attrs['kernel_layout'] else: @@ -138,7 +150,8 @@ def _leaky_relu(inputs, attrs): op_name, new_attrs = 'leaky_relu', {'alpha': str(slope)} sym = get_nnvm_op(op_name)(*inputs, **new_attrs) else: - raise_attribute_unimplemented([act_type]) + raise tvm.error.OpNotImplemented( + 'Operator {} is not supported in frontend MXNet.'.format(act_type)) return sym def _activations(inputs, attrs): @@ -149,12 +162,14 @@ def _activations(inputs, attrs): elif act_type == 'softrelu': sym = _sym.log((1 + _sym.exp(*inputs))) else: - raise_operator_unimplemented(act_type) + raise tvm.error.OpNotImplemented( + 'Operator {} is not supported in frontend MXNet.'.format(act_type)) return sym def _reshape(inputs, attrs): if parse_bool_str(attrs, 'reverse'): - raise_attribute_unimplemented('reverse', 'reshape') + raise tvm.error.OpAttributeUnimplemented( + 'Attribute "reverse" is not supported in operator Reshape.') op_name, new_attrs = 'reshape', {} new_attrs['shape'] = required_attr(attrs, 'shape', 'reshape') return get_nnvm_op(op_name)(*inputs, **new_attrs) @@ -218,7 +233,7 @@ def _contrib_multibox_detection(inputs, attrs): new_attrs1 = {'return_indices': False, 'iou_threshold': float(nms_threshold), 'force_suppress': force_suppress, 'top_k': int(nms_topk)} data, valid_count = get_nnvm_op('multibox_transform_loc')(inputs[0], inputs[1], - inputs[2], **new_attrs0) + inputs[2], **new_attrs0) return get_nnvm_op('non_max_suppression')(data, valid_count, **new_attrs1) def _elemwise_sum(inputs, _): @@ -231,10 +246,12 @@ def _crop_like(inputs, attrs): tuple([float(x.strip()) for x in attrs.get('offsets').strip('()').split(',')]) \ if attrs.get('offsets') is not None else (0, 0) if offsets != (0, 0): - raise_attribute_invalid(offsets, 'offsets', 'crop_like') + raise tvm.error.OpAttributeInvalid( + 'crop_like offsets must equal (0,0).') center_crop = parse_bool_str(attrs, 'center_crop', default="False") if center_crop: - raise_attribute_unimplemented('center crop', 'crop_like') + raise tvm.error.OpAttributeUnimplemented( + 'Center crop is not supported in operator crop_like.') if len(inputs) < 2: raise RuntimeError("Only support crop_like pattern.") new_attrs["axis"] = [2, 3] @@ -381,7 +398,8 @@ def
_convert_symbol(op_name, inputs, attrs, elif op_name in convert_map: sym = convert_map[op_name](inputs, attrs) else: - raise_operator_unimplemented(op_name) + raise tvm.error.OpNotImplemented( + 'Operator {} is not supported in frontend MXNet.'.format(op_name)) return sym def _as_list(arr): diff --git a/nnvm/python/nnvm/frontend/onnx.py b/nnvm/python/nnvm/frontend/onnx.py index 1262bebbb85f..18eb213bab7b 100644 --- a/nnvm/python/nnvm/frontend/onnx.py +++ b/nnvm/python/nnvm/frontend/onnx.py @@ -397,7 +397,8 @@ def _impl_v7(cls, inputs, attr, params): elif mode == b'linear': method = "BILINEAR" else: - raise_attribute_invalid(mode, 'mode', 'upsample') + raise tvm.error.OpAttributeInvalid( + 'Value {} in attribute "mode" of operator Upsample is not valid.'.format(mode)) return _sym.upsampling(inputs[0], scale=int(scales[-1]), method=method, layout='NCHW') @@ -922,7 +923,8 @@ def _convert_operator(self, elif op_name in convert_map: sym = convert_map[op_name](inputs, attrs, self._params) else: - raise_operator_unimplemented(op_name) + raise tvm.error.OpNotImplemented( + 'Operator {} is not supported in frontend ONNX.'.format(op_name)) return sym def _fix_outputs(self, op_name, outputs): diff --git a/nnvm/python/nnvm/frontend/tensorflow.py b/nnvm/python/nnvm/frontend/tensorflow.py index 140fa900eefa..f2ff60294489 100644 --- a/nnvm/python/nnvm/frontend/tensorflow.py +++ b/nnvm/python/nnvm/frontend/tensorflow.py @@ -11,7 +11,7 @@ from .. import symbol as _sym from .. import graph as _graph from .. compiler import graph_util, build_module -from .common import AttrConverter as AttrConvert +from .common import get_nnvm_op, AttrConverter as AttrConvert __all__ = ['from_tensorflow'] @@ -68,7 +68,8 @@ def _impl(attr): kernel = attr['kernel_shape'] if len(kernel) == 2: return prefix + '2d' + surfix - raise_attribute_unimplemented('non-2d kernel', prefix) + raise tvm.error.OpAttributeUnimplemented( + 'Non-2D kernels are not supported for operator {}.'.format(prefix)) return _impl def _dimension_constraint(): @@ -129,7 +130,8 @@ def _impl(inputs, attr, params): attr['kernel_shape'] = (attr['ksize'][2], attr['ksize'][3]) attr['strides'] = (attr['strides'][2], attr['strides'][3]) else: - raise_attribute_invalid(attr['data_format'], 'data_format', 'pooling') + msg = 'Value {} in attribute "data_format" of operator Pooling is not valid.' + raise tvm.error.OpAttributeInvalid(msg.format(attr['data_format'])) if attr['_target_layout'] == "NCHW" and attr['data_format'] == "NHWC": tmp_shape = attr['_input_shapes'][inputs[0]] @@ -158,7 +160,8 @@ def _impl(inputs, attr, params): attr['padding'] = [pad_v[0], pad_h[0], pad_v[1], pad_h[1]] else: - raise_attribute_unimplemented(attr['padding'], 'padding', 'pooling') + msg = 'Value {} in attribute "padding" of operator Pooling is not valid.' + raise tvm.error.OpAttributeUnimplemented(msg.format(attr['padding'])) if name == "avg_pool": attr['count_include_pad'] = False @@ -232,7 +235,8 @@ def _impl(inputs, attr, params): attr['dilations'] = (attr['dilations'][2], attr['dilations'][3]) attr['strides'] = (attr['strides'][2], attr['strides'][3]) else: - raise_attribute_invalid(attr['data_format'], 'data_format', 'conv') + msg = 'Value {} in attribute "data_format" of operator Conv is not valid.'
+ raise tvm.error.OpAttributeInvalid(msg.format(attr['data_format'])) if opname == 'depthwise': @@ -276,7 +280,8 @@ def _impl(inputs, attr, params): attr['padding'] = [0, 0] else: - raise_attribute_invalid(attr['padding'], 'padding', 'conv') + msg = 'Value {} in attribute "padding" of operator Conv is not valid.' + raise tvm.error.OpAttributeInvalid(msg.format(attr['padding'])) if 'kernel_layout' not in attr: if opname == 'conv': @@ -432,7 +437,8 @@ def _impl(inputs, attr, params): op_name="reshape", extras={'shape':tuple(params_new[0].asnumpy().flatten())}, ignores=['Tshape'])(inputs, attr) - raise_attribute_unimplemented('dynamic shape', 'reshape') + raise tvm.error.OpAttributeUnimplemented( + 'Attribute "dynamic shape" of operator Reshape is not supported.') return _impl def _bias_add(): @@ -736,7 +742,8 @@ def _impl(inputs, attr, params): if padlist_key in params: padlist = params.pop(padlist_key).asnumpy() else: - raise_attribute_required(padlist_key, 'pad') + raise tvm.error.OpAttributeRequired( + 'Required attribute "{}" not found in operator Pad.'.format(padlist_key)) paddings = tuple([tuple(l) for l in padlist]) attr['pad_width'] = paddings attr['pad_value'] = 0 @@ -1188,7 +1195,9 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None): missing_operators = self._parse_import_prerequisites(graph) if missing_operators: - raise_operator_unimplemented(*missing_operators) + msg = 'The following operators are not supported in frontend TensorFlow: {}' + ops = str(list(missing_operators)).strip('[,]') + raise tvm.error.OpNotImplemented(msg.format(ops)) for node in graph.node: if node.op == 'Placeholder': @@ -1528,7 +1537,8 @@ def _convert_operator(self, op_name, inputs, attrs, self._params, graph, convert_map_rnn) else: - raise_operator_unimplemented(op_name) + raise tvm.error.OpNotImplemented( + 'Operator {} is not supported in frontend TensorFlow.'.format(op_name)) return sym def _fix_extranodes(self, op_name, attr, inputs): diff --git a/python/tvm/error_handling/__init__.py b/python/tvm/error_handling/__init__.py deleted file mode 100644 index 8616d1ba973a..000000000000 --- a/python/tvm/error_handling/__init__.py +++ /dev/null @@ -1,44 +0,0 @@ -import warnings -import traceback -import sys - -def _excepthook(type, value, tb): - print(''.join(traceback.format_exception(type, value, tb))) - -sys.excepthook = _excepthook - -class OperatorError(Exception): - pass - -def _raise_error_helper(exception, msg, *args): - raise exception(msg.format(*args)) - -def raise_attribute_required(key, op_name): - class OperatorAttributeRequired(OperatorError): - pass - msg = 'Required attribute {} not found in operator {}.' - _raise_error_helper(OperatorAttributeRequired, msg, key, op_name) - -def raise_attribute_invalid(val, attr, op_name): - class OperatorAttributeValueNotValid(OperatorError): - pass - msg = 'Value {} in attr {} is not valid in operator {}.' - _raise_error_helper(OperatorAttributeValueNotValid, msg, val, attr, - op_name) - -def raise_operator_unimplemented(*missing_ops): - class OperatorNotImplemented(OperatorError): - pass - missing_ops = str(missing_ops).strip('(,)') - msg = 'The following operators are not supported: {}.' - _raise_error_helper(OperatorNotImplemented, msg, missing_ops) - -def raise_attribute_unimplemented(key, op_name): - class OperatorAttributeNotImplemented(OperatorError): - pass - msg = 'Attribute {} is not supported in operator {}.' 
- _raise_error_helper(OperatorAttributeNotImplemented, msg, key, op_name) -def warn_not_used(attr, op_name): - msg = '{} is ignored in {}.'.format(attr, op_name) - warnings.warn(msg) diff --git a/python/tvm/relay/frontend/__init__.py b/python/tvm/relay/frontend/__init__.py index 6ba2f0bde12d..dee3999ad3f1 100644 --- a/python/tvm/relay/frontend/__init__.py +++ b/python/tvm/relay/frontend/__init__.py @@ -14,8 +14,3 @@ from .coreml import from_coreml from .caffe2 import from_caffe2 from .tensorflow import from_tensorflow -from tvm.error_handling import raise_attribute_required, \ - raise_attribute_invalid, \ - raise_operator_unimplemented, \ - raise_attribute_unimplemented, \ - warn_not_used diff --git a/python/tvm/relay/frontend/caffe2.py b/python/tvm/relay/frontend/caffe2.py index 5ae7a294d306..769740df0be3 100755 --- a/python/tvm/relay/frontend/caffe2.py +++ b/python/tvm/relay/frontend/caffe2.py @@ -1,6 +1,7 @@ # pylint: disable=import-self, invalid-name, line-too-long, unused-argument """Caffe2 frontend""" from __future__ import absolute_import as _abs +import tvm from .. import ir_pass from .. import expr as _expr from .. import op as _op @@ -15,7 +16,8 @@ def _impl(attr): kernel = attr['kernel_shape'] if len(kernel) == 2: return prefix + '2d' + surfix - raise_operator_unimplemented('non 2d kernel') + raise tvm.error.OpAttributeUnimplemented( + 'Non-2D kernels are not supported for operator {}2d'.format(prefix)) return _impl @@ -27,7 +29,8 @@ def revert_caffe2_pad(pads): elif len(pads) == 2: pass else: - raise_attribute_invalid(str(len(pads)), 'len(pads)', 'padding') + raise tvm.error.OpAttributeInvalid( + 'Number of pads must equal 2 or 4.') return pads @@ -103,7 +106,8 @@ def get_converter(cls): if hasattr(cls, '_impl'): return getattr(cls, '_impl') - raise_operator_unimplemented(cls.__name__) + raise tvm.error.OpNotImplemented( + 'Operator {} is not supported in frontend Caffe2.'.format(cls.__name__)) _caffe2_internal_args = [ @@ -223,7 +227,8 @@ def _get_axis_from_order_str(order): return 1 if order == 'NHWC': return 3 - raise_attribute_unimplemented(order, 'Concat') + raise tvm.error.OpAttributeUnimplemented( + 'Order {} is not supported in operator Concat.'.format(order)) return AttrCvt( op_name='concatenate', @@ -515,7 +520,8 @@ def _convert_operator(self, # Add a sanitizing step to convert all byte strings in args to strings func = convert_map[op_type](inputs, args, self._params) else: - raise_operator_unimplemented(op_type) + raise tvm.error.OpNotImplemented( + 'Operator {} is not supported in frontend Caffe2.'.format(op_type)) return func diff --git a/python/tvm/relay/frontend/coreml.py b/python/tvm/relay/frontend/coreml.py index 369a5d4bb3a4..963b21f38297 100644 --- a/python/tvm/relay/frontend/coreml.py +++ b/python/tvm/relay/frontend/coreml.py @@ -1,6 +1,7 @@ # pylint: disable=invalid-name, import-self, unused-argument, unused-variable, inconsistent-return-statements """CoreML frontend.""" from __future__ import absolute_import as _abs +import tvm import numpy as np from .. import ir_pass from ..
import expr as _expr @@ -81,7 +82,8 @@ def _BatchnormLayerParams(op, inexpr, etab): """Get layer of batchnorm parameter""" # this changes the symbol if op.instanceNormalization: - raise_operator_unimplemented('instance normalization') + raise tvm.error.OpNotImplemented( + 'Operator "instance normalization" is not supported in frontend CoreML.') else: params = {'gamma':etab.new_const(list(op.gamma.floatValue)), 'beta':etab.new_const(list(op.beta.floatValue)), @@ -142,7 +144,8 @@ def _ActivationParams(op, inexpr, etab): alpha_expr = etab.new_const(alpha) beta_expr = etab.new_const(beta) return _op.multiply(_op.log(_op.add(_op.exp(inexpr), beta_expr)), alpha_expr) - raise_operator_unimplemented(whichActivation) + raise tvm.error.OpNotImplemented( + 'Operator {} is not supported in frontend CoreML.'.format(whichActivation)) def _ScaleLayerParams(op, inexpr, etab): @@ -164,7 +167,8 @@ def _PoolingLayerParams(op, inexpr, etab): return _op.nn.global_max_pool2d(inexpr) if op.type == 1: return _op.nn.global_avg_pool2d(inexpr) - raise_operator_unimplemented('pooling (not max or average)') + raise tvm.error.OpNotImplemented( + 'Only Max and Average Pooling are supported in frontend CoreML.') else: params = {'pool_size':list(op.kernelSize), @@ -184,8 +188,9 @@ def _PoolingLayerParams(op, inexpr, etab): params['padding'] = padding params['ceil_mode'] = True else: - raise_attribute_unimplemented(op.WhichOneof('PoolingPaddingType'), - 'PoolingPaddingType', 'pooling') + msg = 'PoolingPaddingType {} is not supported in operator Pooling.' + op_name = op.WhichOneof('PoolingPaddingType') + raise tvm.error.OpAttributeUnimplemented(msg.format(op_name)) # consume padding layer if etab.in_padding: @@ -197,7 +202,8 @@ def _PoolingLayerParams(op, inexpr, etab): return _op.nn.max_pool2d(inexpr, **params) if op.type == 1: return _op.nn.avg_pool2d(inexpr, **params) - raise_operator_unimplemented('pooling (not max or average)') + raise tvm.error.OpNotImplemented( + 'Only Max and Average Pooling are supported in CoreML.') def _SoftmaxLayerParams(op, inexpr, etab): @@ -240,7 +246,8 @@ def _ConcatLayerParams(op, inexpr, etab): if not isinstance(inexpr, list): inexpr = [inexpr] if op.sequenceConcat: - raise_operator_unimplemented('Sequence Concat') + raise tvm.error.OpNotImplemented( + 'Operator Sequence Concat is not supported in frontend CoreML.') ret = _op.concatenate(inexpr, axis=1) return ret @@ -256,14 +263,16 @@ def _PaddingLayerParams(op, inexpr, etab): if op.WhichOneof('PaddingType') == 'constant': constant = op.constant if constant.value != 0: - raise_attribute_unimplemented(constant.value, 'padding value', 'padding') + raise tvm.error.OpAttributeUnimplemented( + '{} is not supported in operator Padding.'.format(constant.value)) padding = [b.startEdgeSize for b in op.paddingAmounts.borderAmounts] padding2 = [b.endEdgeSize for b in op.paddingAmounts.borderAmounts] for i, j in zip(padding, padding2): assert i == j etab.set_padding(padding) else: - raise_operator_unimplemented('non-constant padding') + raise tvm.error.OpNotImplemented( + 'Non-constant padding is not supported in frontend CoreML.') return inexpr @@ -274,7 +283,8 @@ def _PermuteLayerParams(op, inexpr, etab): def _UpsampleLayerParams(op, inexpr, etab): if op.scalingFactor[0] != op.scalingFactor[1]: - raise_attribute_unimplemented('unequal height/width scaling factors', 'upsample') + raise tvm.error.OpAttributeUnimplemented( + 'Upsample height and width must be equal.') interpolationMode = 'NEAREST_NEIGHBOR' if op.mode == 0 else 'BILINEAR' return 
_op.nn.upsampling(inexpr, scale=op.scalingFactor[0], method=interpolationMode) @@ -364,7 +374,8 @@ def coreml_op_to_relay(op, inname, outname, etab): """ classname = type(op).__name__ if classname not in _convert_map: - raise_operator_unimplemented(classname) + raise tvm.error.OpNotImplemented( + 'Operator {} is not supported in frontend CoreML.'.format(classname)) if isinstance(inname, _base.string_types): insym = etab.get_expr(inname) else: diff --git a/python/tvm/relay/frontend/keras.py b/python/tvm/relay/frontend/keras.py index 2e266852f9dc..bd7cb4f3b110 100644 --- a/python/tvm/relay/frontend/keras.py +++ b/python/tvm/relay/frontend/keras.py @@ -2,6 +2,7 @@ """Keras frontend.""" from __future__ import absolute_import as _abs import sys +import tvm import numpy as np from .. import ir_pass from .. import expr as _expr @@ -91,7 +92,8 @@ def _convert_activation(inexpr, keras_layer, _): x = (_expr.const(0.2, dtype='float32') * inexpr) + _expr.const(0.5, dtype='float32') return _op.clip(x, a_min=0., a_max=1.) - raise_operator_unimplemented(act_type) + raise tvm.error.OpNotImplemented( + 'Operator {} is not supported in frontend Keras.'.format(act_type)) def _convert_advanced_activation(inexpr, keras_layer, etab): @@ -118,7 +120,8 @@ def _convert_advanced_activation(inexpr, keras_layer, etab): return _op.multiply(inexpr, _op.greater(inexpr, \ _expr.const(theta, dtype='float32')).astype('float32')) - raise_operator_unimplemented(act_type) + raise tvm.error.OpNotImplemented( + 'Operator {} is not supported in frontend Keras.'.format(act_type)) def _convert_merge(inexpr, keras_layer, _): @@ -136,7 +139,8 @@ def _convert_merge(inexpr, keras_layer, _): ret = _op.add(ret, inexpr[i]) ret = ret / _expr.const(len(inexpr), dtype='float32') else: - raise_operator_unimplemented(merge_type) + raise tvm.error.OpNotImplemented( + 'Operator {} is not supported in frontend Keras.'.format(merge_type)) return ret @@ -150,7 +154,8 @@ def _convert_dense(inexpr, keras_layer, etab): if input_dim > 2: input_shape = tuple(dim if dim else 1 for dim in _as_list(input_shape)[0]) if input_dim != 3 or input_shape[0] != 1 or input_shape[1] != 1: - raise_attribute_invaid(input_shape, 'input shape', 'dense') + raise tvm.error.OpAttributeInvalid( + 'Input shape {} is not valid for operator Dense.'.format(input_shape)) inexpr = _op.squeeze(inexpr, axis=0) out = _op.nn.dense(data=inexpr, **params) if keras_layer.use_bias: @@ -214,7 +219,9 @@ def _convert_convolution(inexpr, keras_layer, etab): inexpr = _op.nn.pad(data=inexpr, pad_width=( (0, 0), (0, 0), (pad_t, pad_b), (pad_l, pad_r))) else: - raise_operator_unimplemented('padding with {}'.format(keras_layer.padding)) + msg = 'Padding with {} is not supported for operator Convolution ' \ + 'in frontend Keras.' + raise tvm.error.OpAttributeUnimplemented(msg.format(keras_layer.padding)) if is_deconv: out = _op.nn.conv2d_transpose(data=inexpr, **params) else: @@ -260,7 +267,10 @@ def _convert_separable_convolution(inexpr, keras_layer, etab): inexpr = _op.nn.pad(data=inexpr, pad_width=( (0, 0), (0, 0), (pad_t, pad_b), (pad_l, pad_r))) else: - raise_operator_unimplemented('padding with {}'.format(keras_layer.padding)) + msg = 'Padding with {} is not supported for operator Separable ' \ + 'Convolution in frontend Keras.' 
+ raise tvm.error.OpAttributeUnimplemented(msg.format(keras_layer.padding)) + depthconv = _op.nn.conv2d(data=inexpr, **params0) # pointwise conv weight1 = weightList[1].transpose([3, 2, 0, 1]) @@ -313,13 +323,15 @@ def _convert_pooling(inexpr, keras_layer, etab): pad_l, pad_r = _get_pad_pair(in_w, pool_w, stride_w) params['padding'] = [pad_t, pad_l, pad_b, pad_r] else: - raise_operator_unimplemented('padding with {}'.format(keras_layer.padding)) + raise tvm.error.OpAttributeUnimplemented( + 'Padding with {} is not supported in operator Pooling.'.format(keras_layer.padding)) if pool_type == 'MaxPooling2D': return _op.nn.max_pool2d(inexpr, **params) if pool_type == 'AveragePooling2D': params['count_include_pad'] = False return _op.nn.avg_pool2d(inexpr, **params) - raise_operator_unimplemented('pooling type {}'.format(keras_layer)) + raise tvm.error.OpNotImplemented( + 'Operator {} is not supported for frontend Keras.'.format(keras_layer)) def _convert_upsample(inexpr, keras_layer, _): @@ -331,7 +343,8 @@ def _convert_upsample(inexpr, keras_layer, _): elif upsample_type == 'UpSampling2D': h, w = keras_layer.size if h != w: - raise_attribute_invalid(keras_layer.size, 'size', 'upsample') + raise tvm.error.OpAttributeInvalid( + 'Height must equal width for operator Upsample.') params = {'scale': h} if hasattr(keras_layer, 'interpolation'): @@ -344,23 +357,24 @@ def _convert_upsample(inexpr, keras_layer, _): elif upsample_type == 'UpSampling3D': h, w, d = keras_layer.size if h != w or w != d: - raise_attribute_invalid(keras_layer.size, 'size', 'upsample') + raise tvm.error.OpAttributeInvalid( + 'Height, width, and depth must all be equal for operator Upsample.') params = {'scale': h} else: - raise_operator_unimplemented(upsample_type) + raise tvm.error.OpNotImplemented( + 'Operator {} is not supported for frontend Keras.'.format(upsample_type)) return _op.nn.upsampling(inexpr, **params) def _convert_cropping(inexpr, keras_layer, _): _check_data_format(keras_layer) crop_type = type(keras_layer).__name__ - if crop_type == 'Cropping1D': - raise_operator_unimplemented(crop_type) - elif crop_type == 'Cropping2D': + if crop_type == 'Cropping2D': (_, in_h, in_w, _) = keras_layer.input_shape ((crop_t, crop_b), (crop_l, crop_r)) = keras_layer.cropping else: - raise_operator_unimplemented(crop_type) + raise tvm.error.OpNotImplemented( + 'Operator {} is not supported for frontend Keras.'.format(crop_type)) int32_max = np.iinfo(np.int32).max return _op.strided_slice(inexpr, begin=[0, 0, crop_t, crop_l], \ end=[int32_max, int32_max, in_h-crop_b, in_w-crop_r]) @@ -405,14 +419,18 @@ def _convert_padding(inexpr, keras_layer, _): top, bottom = padding[0] left, right = padding[1] else: - raise_attribute_invalid(str(padding), 'padding', 'padding') + msg = 'Value {} in attribute "padding" of operator Padding ' \ + 'is not valid.' + raise tvm.error.OpAttributeInvalid(msg.format(str(padding))) else: - raise_attribute_invalid(str(padding), 'padding', 'padding') - elif padding_type == 'ZeroPadding1D': - raise_operator_unimplemented(padding_type) + msg = 'Value {} in attribute "padding" of operator Padding is ' \ + 'not valid.' + raise tvm.error.OpAttributeInvalid(msg.format(str(padding))) else: - raise_operator_unimplemented(padding_type) - return _op.nn.pad(data=inexpr, pad_width=((0, 0), (0, 0), (top, bottom), (left, right))) + msg = 'Operator {} is not supported in frontend Keras.' 
+ raise tvm.error.OpNotImplemented(msg.format(padding_type)) + return _op.nn.pad(data=inexpr, + pad_width=((0, 0), (0, 0), (top, bottom), (left, right))) def _convert_concat(inexpr, keras_layer, _): @@ -599,8 +617,10 @@ def _default_skip(inexpr, keras_layer, _): # pylint: disable=unused-argument def _check_unsupported_layers(model): for layer in model.layers: - if type(layer).__name__ not in _convert_map: - raise_operator_unimplemented(type(layer).__name__) + op_name = type(layer).__name__ + if op_name not in _convert_map: + raise tvm.error.OpNotImplemented( + 'Operator {} is not supported in frontend Keras.'.format(op_name)) def keras_op_to_relay(inexpr, keras_layer, outname, etab): @@ -620,9 +640,11 @@ def keras_op_to_relay(inexpr, keras_layer, outname, etab): etab : relay.frontend.common.ExprTable The global expression table to be updated. """ - if type(keras_layer).__name__ not in _convert_map: - raise_operator_unimplemented(type(keras_layer).__name__) - outs = _convert_map[type(keras_layer).__name__](inexpr, keras_layer, etab) + op_name = type(keras_layer).__name__ + if op_name not in _convert_map: + raise tvm.error.OpNotImplemented( + 'Operator {} is not supported for frontend Keras.'.format(op_name)) + outs = _convert_map[op_name](inexpr, keras_layer, etab) outs = _as_list(outs) for t_idx, out in enumerate(outs): name = outname + ":" + str(t_idx) diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py index b28558bb25f9..39daaf91063a 100644 --- a/python/tvm/relay/frontend/mxnet.py +++ b/python/tvm/relay/frontend/mxnet.py @@ -3,10 +3,12 @@ from __future__ import absolute_import as _abs import json +import tvm from .. import ir_pass from .. import expr as _expr from .. import op as _op from ... import nd as _nd + from .common import StrAttrsDict from .nnvm_common import _rename, _binop_scalar, _rbinop_scalar, _reduce from .nnvm_common import _arg_reduce, _init_op, _softmax_op, _cast @@ -41,7 +43,8 @@ def _get_channel_axis(layout, op_name): return 1 if layout == "NHWC": return 3 - raise_attribute_invalid(layout, 'layout', op_name) + raise tvm.error.OpAttributeInvalid( + 'Value {} in attribute "layout" of operator {} is not valid.'.format(layout, op_name)) def _mx_activations(inputs, attrs): @@ -61,7 +64,8 @@ def _stable_softrelu(x): return _op.add(_op.log(_op.add(one, exp_neg_abs_x)), _op.nn.relu(x)) return _stable_softrelu(inputs[0]) - raise_operator_unimplemented(act_type) + raise tvm.error.OpNotImplemented( + 'Operator {} is not supported for frontend MXNet.'.format(act_type)) def _mx_compare(new_op, wrapper): @@ -74,7 +78,8 @@ def impl(inputs, attrs): def _mx_conv2d(inputs, attrs): kernel_size = attrs.get_int_tuple("kernel") if len(kernel_size) != 2: - raise_attribute_invalid(kernel_size, 'kernel size', 'conv2d') + raise tvm.error.OpAttributeInvalid( + 'Non-2D kernels are not supported for operator Conv2D.') data_layout = attrs.get_str("layout", "NCHW") channel_axis = _get_channel_axis(data_layout, "conv2d") @@ -102,10 +107,12 @@ def _mx_conv2d(inputs, attrs): def _mx_conv2d_transpose(inputs, attrs): if "target_shape" in attrs.attrs: - raise_attribute_unimplemented('target_shape', 'conv2d_transpose') + raise tvm.error.OpAttributeUnimplemented( + 'Attribute "target_shape" is not supported for operator Conv2D-transpose.') kernel_size = attrs.get_int_tuple("kernel") if len(kernel_size) != 2: - raise_attribute_invalid(len(kernel_size), 'kernel dimensionality', 'conv2d') + raise tvm.error.OpAttributeInvalid( + 'Non-2D kernels are not supported for operator 
Conv2D-transpose.') data_layout = attrs.get_str("layout", "NCHW") channel_axis = _get_channel_axis(data_layout, "conv2d_transpose") @@ -140,7 +147,8 @@ def _mx_pooling(inputs, attrs): def _pool2d(new_op, is_avg): kernel_size = attrs.get_int_tuple("kernel") if len(kernel_size) != 2: - raise_attribute_invalid(len(kernel_size), 'kernel dimensionality', 'pool2d') + raise tvm.error.OpAttributeInvalid( + 'Only 2D kernels are supported for operator Pool2D.') new_attrs = {} new_attrs["pool_size"] = kernel_size new_attrs["strides"] = attrs.get_int_tuple("stride", (1, 1)) @@ -158,7 +166,8 @@ def _pool2d(new_op, is_avg): if global_pool: return _op.nn.global_avg_pool2d(inputs[0]) return _pool2d(_op.nn.avg_pool2d, True) - raise_operator_unimplemented(pool_type) + raise tvm.error.OpNotImplemented( + 'Operator {} Pooling is not supported for frontend MXNet.'.format(pool_type.capitalize())) def _mx_dropout(inputs, attrs): @@ -172,7 +181,8 @@ def _mx_BlockGrad(inputs, attrs): #pylint: disable=unused-argument def _mx_batch_norm(inputs, attrs): if attrs.get_bool("output_mean_var", False): - raise_attribute_unimplemented('output_mean_var', 'batch_norm') + raise tvm.error.OpAttributeUnimplemented( + 'Attribute "output_mean_var" is not supported for operator Batch Norm.') if attrs.get_bool("use_global_stats", False): _warn_not_used("use_global_stats", "batch_norm") new_attrs = {} @@ -189,13 +199,17 @@ def _mx_slice(inputs, attrs): end = attrs.get_int_tuple('end', None) stride = attrs.get_int_tuple('step', None) if begin is None: - raise_attribute_required('begin', 'slice') + raise tvm.error.OpAttributeRequired( + 'Attribute "begin" not found in operator Slice.') if end is None: - raise_attribute_required('end', 'slice') + raise tvm.error.OpAttributeRequired( + 'Attribute "end" not found in operator Slice.') if None in begin: - raise_attribute_unimplemented('None in begin', 'slice') + raise tvm.error.OpAttributeInvalid( + 'Value None in attribute "begin" of operator Slice is not valid.') if None in end: - raise_attribute_unimplemented('None in end', 'slice') + raise tvm.error.OpAttributeInvalid( + 'Value None in attribute "end" of operator Slice is not valid.') new_attrs = {'begin': begin, 'end': end} if stride is not None: new_attrs['strides'] = stride @@ -299,7 +313,8 @@ def _mx_leaky_relu(inputs, attrs): upper_bound = attrs.get_float("upper_bound") alpha = (lower_bound + upper_bound) / 2.0 return _op.nn.leaky_relu(inputs[0], alpha=alpha) - raise_operator_unimplemented(act_type) + raise tvm.error.OpNotImplemented( + 'Operator {} is not supported for frontend MXNet.'.format(act_type)) def _mx_make_power(power): @@ -393,7 +408,9 @@ def _mx_batch_dot(inputs, attrs): transpose_a = attrs.get_bool("transpose_a", False) transpose_b = attrs.get_bool("transpose_b", False) if transpose_a is True: - raise_attribute_invalid(transpose_a, 'transpose_a', 'batch_dot') + msg = 'Value {} in attribute "transpose_a" of operator batch_dot ' \ + 'is not valid.' 
+ raise tvm.error.OpAttributeInvalid(msg.format(transpose_a)) if transpose_b is False: b = _op.transpose(b, axes=[0, 2, 1]) return _op.batch_matmul(a, b) @@ -402,7 +419,8 @@ def _mx_batch_dot(inputs, attrs): def _mx_arange(inputs, attrs): assert len(inputs) == 0 if attrs.get_int("repeat", 1) != 1: - raise_attribute_unimplemented('repeat', 'arange') + raise tvm.error.OpAttributeUnimplemented( + 'Attribute "repeat" is not supported in operator arange.') new_attrs = {} new_attrs["start"] = attrs.get_float("start", 0) new_attrs["stop"] = attrs.get_float("stop") @@ -486,15 +504,20 @@ def _mx_box_nms(inputs, attrs): in_format = attrs.get_str('in_format', 'corner') out_format = attrs.get_str('out_format', 'corner') if coord_start != 2: - raise_attribute_invalid(coord_start, 'coord_start', 'box_nms') + raise tvm.error.OpAttributeInvalid( + 'Value of attribute "coord_start" must equal 2 for operator box_nms.') if score_index != 1: - raise_attribute_invalid(score_index, 'score_index', 'box_nms') + raise tvm.error.OpAttributeInvalid( + 'Value of attribute "score_index" must equal 1 for operator box_nms.') if id_index != -1 and int(id_index) != 0: - raise_attribute_invalid(id_index, 'id_index', 'box_nms') + raise tvm.error.OpAttributeInvalid( + 'Value of attribute "id_index" must equal either -1 or 0 for operator box_nms.') if in_format != 'corner': - raise_attribute_invalid(in_format, 'in_format', 'box_nms') + raise tvm.error.OpAttributeInvalid( + 'Value of attribute "in_format" must equal "corner" for operator box_nms.') if out_format != 'corner': - raise_attribute_invalid(out_format, 'out_format', 'box_nms') + raise tvm.error.OpAttributeInvalid( + 'Value of attribute "out_format" must equal "corner" for operator box_nms.') ret = _op.vision.get_valid_counts(inputs[0], score_threshold=valid_thresh) nms_out = _op.vision.non_max_suppression(ret[1], @@ -512,7 +535,8 @@ def _mx_l2_normalize(inputs, attrs): new_attrs = {} mode = attrs.get_str('mode', 'instance') if mode != 'channel': - raise_attribute_invalid(mode, 'mode', 'l2_normalize') + raise tvm.error.OpAttributeInvalid( + 'Value of attribute "mode" must equal "channel" for operator l2_normalize.') new_attrs['eps'] = attrs.get_float('eps', 1e-10) new_attrs['axis'] = [1] return _op.nn.l2_normalize(inputs[0], **new_attrs) @@ -772,10 +796,11 @@ def _from_mxnet_impl(symbol, shape_dict, dtype_info): elif isinstance(res, _expr.Expr): res = [res] else: - raise_attribute_invalid(type(res), 'type(res)', op_name) + raise RuntimeError("unexpected type %s" % type(res)) node_map[nid] = res else: - raise_operator_unimplemented(op_name) + raise tvm.error.OpNotImplemented( + 'Operator {} is not supported in frontend MXNet.'.format(op_name)) outputs = [node_map[e[0]][e[1]] for e in jgraph["heads"]] outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 1bffdfd4bcd9..a6851b833931 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -3,6 +3,7 @@ from __future__ import absolute_import as _abs import logging +import tvm import numpy as np from ... import nd as _nd from .. import ir_pass @@ -18,7 +19,9 @@ def _impl(attr): kernel = attr['kernel_shape'] if len(kernel) == 2: return prefix + '2d' + surfix - raise_attribute_invalid(len(kernel), 'kernel dimensionality', prefix) + msg = 'Only 2D kernels are supported for operator {}.' 
+ op_name = prefix + '2d' + raise tvm.error.OpAttributeInvalid(msg.format(op_name)) return _impl @@ -29,7 +32,8 @@ def revert_caffe2_pad(pads): elif len(pads) == 2: pass else: - raise_attribute_invalid(len(pads), 'len(pads)', 'padding') + raise tvm.error.OpAttributeInvalid( + 'Number of pads must be either 2 or 4.') return pads def dimension_constraint(): @@ -461,7 +465,8 @@ def _impl_v9(cls, inputs, attr, params): elif mode == b'linear': method = "BILINEAR" else: - raise_attribute_invalid(mode, 'mode', 'upsample') + raise tvm.error.OpAttributeInvalid( + 'Value {} in attribute "mode" of operator Upsample is not valid.'.format(mode)) attr = {'scale':int(scales[-1]), 'method':method, 'layout':'NCHW'} return AttrCvt('upsampling')(inputs, attr) @@ -717,7 +722,10 @@ def _impl_v1(cls, inputs, attr, params): if 'input_as_shape' in attr and attr['input_as_shape']: shape = params[get_name(inputs[0])].asnumpy() else: - raise_attribute_required('extra_shape', 'ConstantFill') + if 'extra_shape' in attr: + raise tvm.error.OpAttributeInvalid('Attribute "extra_shape" not ' + 'supported with "fill_like" for ' + 'operator ConstantFill.') return _op.full_like(inputs[0], inputs[1]) if 'extra_shape' in attr: diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index f795aa70a596..afeaee7e8f95 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -27,7 +27,8 @@ def _get_relay_op(op_name): op = getattr(_op.image, op_name) if not op: - raise_operator_unimplemented(op_name) + raise tvm.error.OpNotImplemented( + 'Operator {} is not supported for frontend TensorFlow.'.format(op_name)) return op class AttrCvt(object): @@ -99,7 +100,8 @@ def __call__(self, inputs, attrs, *args): new_attrs = {} for k in attrs.keys(): if k in self._excludes: - raise_operator_unimplemented(k, op_name) + raise tvm.error.OpAttributeUnimplemented( + 'Attribute {} in operator {} is not supported.'.format(k, op_name)) elif k in self._disables: logging.warning("Attribute %s is disabled in relay.%s", k, op_name) elif k in self._ignores: @@ -148,7 +150,8 @@ def _required_attr(self, attr, key): """Wrapper for getting required attributes.""" assert isinstance(attr, dict) if key not in attr: - raise_attribute_required(key, self._op_name) + raise tvm.error.OpAttributeRequired( + 'Attribute {} not found in operator {}'.format(key, self._op_name)) return attr[key] def _get_pad_pair(input1d, kernel1d, stride1d): @@ -178,7 +181,8 @@ def _impl(attr): kernel = attr['kernel_shape'] if len(kernel) == 2: return prefix + '2d' + surfix - raise_attribute_invalid(len(kernel), 'kernel dimensionality', prefix) + raise tvm.error.OpAttributeInvalid( + 'Only 2D kernels are supported for operator {}'.format(prefix + '2d')) return _impl def _dimension_constraint(): @@ -238,7 +242,9 @@ def _impl(inputs, attr, params): attr['kernel_shape'] = (attr['ksize'][2], attr['ksize'][3]) attr['strides'] = (attr['strides'][2], attr['strides'][3]) else: - raise_attribute_invalid(attr['data_format'], 'data_format', 'pooling') + msg = 'Value {} of attribute "data_format" of operator Pooling ' \ + 'is not valid.' 
+            raise tvm.error.OpAttributeInvalid(msg.format(attr['data_format']))
 
         if attr['_target_layout'] == "NCHW" and attr['data_format'] == "NHWC":
             tmp_shape = attr['_input_shapes'][inputs[0]]
@@ -267,7 +273,9 @@ def _impl(inputs, attr, params):
             attr['padding'] = [pad_v[0], pad_h[0], pad_v[1], pad_h[1]]
         else:
-            raise_attribute_invalid(attr['padding'], 'padding', 'padding')
+            msg = 'Value {} in attribute "padding" of operator Pooling is ' \
+                  'not valid.'
+            raise tvm.error.OpAttributeInvalid(msg.format(attr['padding']))
 
         if name == "avg_pool":
             attr['count_include_pad'] = False
@@ -341,7 +349,9 @@ def _impl(inputs, attr, params):
             attr['dilations'] = (attr['dilations'][2], attr['dilations'][3])
             attr['strides'] = (attr['strides'][2], attr['strides'][3])
         else:
-            raise_attribute_unimplemented(attr['data_format'], 'data_format', 'conv')
+            msg = 'Value {} in attribute "data_format" of operator Conv is ' \
+                  'not valid.'
+            raise tvm.error.OpAttributeInvalid(msg.format(attr['data_format']))
 
         if opname == 'depthwise':
@@ -386,7 +396,9 @@ def _impl(inputs, attr, params):
             attr['padding'] = [0, 0]
         else:
-            raise_attribute_invalid(attr['padding'], 'padding', 'conv')
+            msg = 'Value {} in attribute "padding" of operator Conv is not ' \
+                  'valid.'
+            raise tvm.error.OpAttributeInvalid(msg.format(attr['padding']))
 
         if 'kernel_layout' not in attr:
             if opname == 'conv':
@@ -791,7 +803,8 @@ def _impl(inputs, attr, params):
         if padlist_key in params:
             padlist = params.pop(padlist_key).asnumpy()
         else:
-            raise_attribute_required(padlist_key, 'pad')
+            raise tvm.error.OpAttributeRequired(
+                'Attribute {} not found in operator Pad.'.format(padlist_key))
         paddings = tuple([tuple(l) for l in padlist])
         attr['pad_width'] = paddings
         attr['pad_value'] = 0
diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py
index 37f4e1367e53..0e31500fe67d 100644
--- a/python/tvm/relay/frontend/tflite.py
+++ b/python/tvm/relay/frontend/tflite.py
@@ -3,6 +3,7 @@
 from __future__ import absolute_import as _abs
 import math
 import numpy as np
+import tvm
 from .. import ir_pass
 from .. import expr as _expr
 from .. import op as _op
@@ -59,7 +60,10 @@ def check_unsupported_ops(self):
                 unsupported_ops_set.add(op_code_str)
 
         if unsupported_ops_set:
-            raise_operator_unimplemented(*upsupported_ops_set)
+            msg = 'The following operators are not supported in frontend ' \
+                  'TFLite: {}'
+            ops = str(list(unsupported_ops_set)).strip('[,]')
+            raise tvm.error.OpNotImplemented(msg.format(ops))
 
     def convert_op_to_relay(self):
         """Convert TFLite ops to relay ops"""
@@ -204,7 +208,8 @@ def convert_reshape(self, op):
             # finally convert back if necessary
             in_expr = _op.transpose(in_expr, axes=(0, 2, 3, 1))
         else:
-            raise_attribute_invalid(input_shape_length, 'input shape length', 'reshape')
+            msg = 'Input shape length {} for operator Reshape is not valid.'
+            raise tvm.error.OpAttributeInvalid(msg.format(input_shape_length))
 
         out = _op.reshape(in_expr, newshape=tuple(target_shape))
@@ -221,7 +226,8 @@ def convert_reshape(self, op):
         elif len(target_shape) == 4:
             out = _op.transpose(out, axes=(0, 3, 1, 2))
         else:
-            raise_attribute_invalid(len(target_shape), 'shape length', 'reshape')
+            raise tvm.error.OpAttributeInvalid(
+                'Length of target shape must be between 1 and 5 for operator Reshape.')
 
         return out
@@ -327,7 +333,8 @@ def convert_squeeze(self, op):
             # finally convert back if necessary
             in_expr = _op.transpose(in_expr, axes=(0, 2, 3, 1))
         else:
-            raise_attribute_invalid(input_shape_length, 'input shape length', 'squeeze')
+            msg = 'Input shape length {} for operator Squeeze is not valid.'
+            raise tvm.error.OpAttributeInvalid(msg.format(input_shape_length))
 
         out = _op.squeeze(in_expr, axis=tuple(squeeze_axis))
@@ -344,7 +351,8 @@ def convert_squeeze(self, op):
         elif output_shape_length == 4:
             out = _op.transpose(out, axes=(0, 3, 1, 2))
         else:
-            raise_attribute_invalid(output_shape_length, 'output_shape_length', 'squeeze')
+            msg = 'Output shape length {} for operator Squeeze is not valid.'
+            raise tvm.error.OpAttributeInvalid(msg.format(output_shape_length))
 
         return out
@@ -364,7 +372,8 @@ def convert_fused_activation_function(self, in_expr, fused_activation_fn):
         if fused_activation_fn == ActivationFunctionType.TANH:
             return _op.tanh(in_expr)
         fused_activation_fn_str = self.activation_fn_type[fused_activation_fn]
-        raise_operator_unimplemented(fused_activation_fn_str)
+        raise tvm.error.OpNotImplemented(
+            'Operator {} is not supported for frontend TFLite.'.format(fused_activation_fn_str))
 
     def convert_conv(self, op, conv_type):
         """convolution implementation."""
@@ -403,7 +412,8 @@ def convert_conv(self, op, conv_type):
             assert depth_multiplier == 1, "TF frontend have transformed it be 1 " \
                                           "no matter original value be set by 0.25, 0.5 or any else"
         else:
-            raise_operator_unimplemented(conv_type)
+            raise tvm.error.OpNotImplemented(
+                'Operator {} is not supported for frontend TFLite.'.format(conv_type))
 
         stride_h = conv_options.StrideH()
         stride_w = conv_options.StrideW()
@@ -460,7 +470,8 @@ def convert_conv(self, op, conv_type):
                                            (pad_top, pad_bottom),
                                            (pad_left, pad_right)))
         else:
-            raise_attribute_invalid(padding, 'padding format', 'conv')
+            raise tvm.error.OpAttributeUnimplemented(
+                'Padding format {} is not supported for operator Conv.'.format(padding))
 
         out = _op.nn.conv2d(data=in_expr, weight=weight_expr, **params)
@@ -523,14 +534,16 @@ def convert_pool2d(self, op, pool_type):
             pad_left, pad_right = get_pad_value(input_w, filter_w, stride_w)
             params['padding'] = [pad_top, pad_left, pad_bottom, pad_right]
         else:
-            raise_attribute_invalid(padding, 'padding', 'pool2d')
+            raise tvm.error.OpAttributeUnimplemented(
+                'Padding format {} for operator Pool2D is not supported.'.format(padding))
 
         if pool_type == "average":
             out = _op.nn.avg_pool2d(in_expr, **params)
         elif pool_type == "max":
             out = _op.nn.max_pool2d(in_expr, **params)
         else:
-            raise_operator_unimplemented(pool_type + ' pool')
+            raise tvm.error.OpNotImplemented(
+                'Operator {} is not supported for frontend TFLite.'.format(pool_type + ' pool'))
 
         # If we have fused activations
         if fused_activation_fn != ActivationFunctionType.NONE: