diff --git a/python/tvm/relay/frontend/common.py b/python/tvm/relay/frontend/common.py index 2871b7f73163..d07f2af3e08b 100644 --- a/python/tvm/relay/frontend/common.py +++ b/python/tvm/relay/frontend/common.py @@ -321,6 +321,10 @@ def __call__(self, inputs, attrs, *args): else: assert callable(self._op_name), "op_name can either be string or callable" op_name = self._op_name(attrs) + + # ignore 'tvm_custom' always + self._ignores.append('tvm_custom') + # convert attributes new_attrs = {} for k in attrs.keys(): @@ -329,7 +333,8 @@ def __call__(self, inputs, attrs, *args): elif k in self._disables: logging.warning("Attribute %s is disabled in relay.sym.%s", k, op_name) elif k in self._ignores: - logging.debug("Attribute %s is ignored in relay.sym.%s", k, op_name) + if k != 'tvm_custom': + logging.warning("Attribute %s is ignored in relay.sym.%s", k, op_name) elif k in self._transforms: new_name, defaults, transform = self._parse_default(self._transforms[k]) if defaults is None: @@ -416,4 +421,6 @@ def __init__(self, new_name): self._new_name = new_name def __call__(self, inputs, attrs, *args): + if 'tvm_custom' in attrs: + attrs.pop('tvm_custom') return get_relay_op(self._new_name)(*inputs, **attrs) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index a6851b833931..f0f0356e1cb0 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -106,7 +106,7 @@ def _impl_v1(cls, inputs, attr, params): 'pads': ('padding', (0, 0), revert_caffe2_pad) }, # very weird attributes here in onnx, force check - ignores=['dilations'], + ignores=['dilations', 'auto_pad'], # TODO(zhreshold): make sure ceil_mode in onnx, and layout? extras={'ceil_mode': False}, custom_check=dimension_constraint())(inputs, attr, params) @@ -160,6 +160,7 @@ def _impl_v1(cls, inputs, attr, params): 'dilations': ('dilation', (0, 0)), 'pads': ('padding', (0, 0), revert_caffe2_pad), 'group': ('groups', 1)}, + ignores=['auto_pad'], custom_check=dimension_constraint())(inputs[:2], attr, params) use_bias = len(inputs) == 3 if use_bias: @@ -332,7 +333,21 @@ def _impl_v1(cls, inputs, attr, params): shape = tuple(params[inputs[1].name_hint].asnumpy()) out = _op.reshape(inputs[0], shape) else: - out = _op.reshape_like(inputs[0], inputs[1]) + # Try to infer shape by precompute prune if possible. + # TODO: good to check inputs to be in params. + # to be enhanced when relay support list_input_names API of NNVM + logging.warning("Infering Reshape argument by precompute") + func = _expr.Function(ir_pass.free_vars(inputs[1]), inputs[1]) + with tvm.relay.build_config(opt_level=0): + graph, lib, params = tvm.relay.build(func, target="llvm", params=params) + ctx = tvm.context("llvm", 0) + from tvm.contrib import graph_runtime + m = graph_runtime.create(graph, lib, ctx) + m.set_input(**params) + m.run() + params_new = m.get_output(0) + inputs.pop(1) + out = _op.reshape(inputs[0], tuple(params_new.asnumpy().astype('int32').flatten())) return out @@ -477,10 +492,7 @@ class Shape(OnnxOpConverter): @classmethod def _impl_v1(cls, inputs, attr, params): - # Result of this operator is prominently used by reshape operator. - # Just pass the input as it is so that reshape_like can be used there. - logging.warning("Shape: Differently implemented in relay as a bypass (dummy operator)") - return inputs[0] + return _op.shape_of(inputs[0]) class Cast(OnnxOpConverter): """ Operator converter for Cast. @@ -494,7 +506,7 @@ def _impl_v1(cls, inputs, attr, params): def _impl_v5(cls, inputs, attr, params): try: from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE - attr['to'] = TENSOR_TYPE_TO_NP_TYPE[attr['to']] + attr['to'] = str(TENSOR_TYPE_TO_NP_TYPE[attr['to']]) except ImportError as e: raise ImportError( "Unable to import onnx.mapping which is required {}".format(e)) @@ -674,6 +686,11 @@ class ReduceMean(Reduce): """ name = 'mean' +class ReduceProd(Reduce): + """ Operator converter for ArgMax. + """ + name = 'prod' + class ArgMax(OnnxOpConverter): """ Operator converter for ArgMax. """ @@ -826,6 +843,7 @@ def _get_convert_map(opset): 'ReduceMin': ReduceMin.get_converter(opset), 'ReduceSum': ReduceSum.get_converter(opset), 'ReduceMean': ReduceMean.get_converter(opset), + 'ReduceProd': ReduceProd.get_converter(opset), # 'ReduceProd' # 'ReduceLogSumExp' 'ArgMax': ArgMax.get_converter(opset), @@ -842,8 +860,7 @@ def _get_convert_map(opset): 'Squeeze': AttrCvt('squeeze', {'axes': 'axis'}), 'Unsqueeze': Unsqueeze.get_converter(opset), 'Pad': Pad.get_converter(opset), - # TODO(zhreshold) Shape op is implemented as bypass op in relay - # 'Shape': Shape.get_converter(opset), + 'Shape': Shape.get_converter(opset), } @@ -883,6 +900,7 @@ def from_onnx(self, graph, opset): ---------- graph : onnx protobuf object The loaded onnx graph + opset : opset version Returns @@ -911,12 +929,12 @@ def from_onnx(self, graph, opset): dtype=self._params[i_name].dtype) else: self._num_input += 1 - shape = self._shape[i_name] if i_name in self._shape else () + tshape = self._shape[i_name] if i_name in self._shape else () if isinstance(self._dtype, dict): dtype = self._dtype[i_name] if i_name in self._dtype else d_type else: dtype = d_type - self._nodes[i_name] = new_var(i_name, shape=shape, dtype=dtype) + self._nodes[i_name] = new_var(i_name, shape=tshape, dtype=dtype) # construct nodes, nodes are stored as directed acyclic graph for node in graph.node: op_name = node.op_type @@ -936,6 +954,10 @@ def from_onnx(self, graph, opset): self._nodes[i_name] = new_var(node.output[0], shape=(), dtype=dtype) inputs.append(self._nodes[i_name]) + i_name = self._parse_value_proto(node) + attr['tvm_custom'] = {} + attr['tvm_custom']['name'] = i_name + op = self._convert_operator(op_name, inputs, attr, opset) node_output = self._fix_outputs(op_name, node.output) if not isinstance(op, _expr.TupleWrapper): diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index de95ff00aef9..1796a548d8ac 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -113,35 +113,36 @@ def test_reshape(): tvm.testing.assert_allclose(ref_shape, tvm_out.shape) -def test_reshape_like(): +def test_shape(): in_shape = (4, 3, 3, 4) - ref_shape = (3, 4, 4, 3) + ref_shape = (6, 2, 4, 3) - ref_array = np.random.uniform(size=ref_shape).astype('float32') + ref_array = np.array(ref_shape) ref_node = onnx.helper.make_node('Constant', inputs=[], outputs=['ref_in'], value=onnx.helper.make_tensor(name = 'const_tensor', - data_type = onnx.TensorProto.FLOAT, + data_type = onnx.TensorProto.INT32, dims = ref_array.shape, - vals = ref_array.flatten().astype(float))) - copy_node = helper.make_node("Identity", ["ref_in"], ["copy_in"]) - reshape_node = helper.make_node("Reshape", ["in", "copy_in"], ["out"]) + vals = ref_array.flatten().astype(int))) + reshape_node = helper.make_node("Reshape", ["in", "ref_in"], ["out"]) + + shape_node = helper.make_node("Shape", ['out'], ['final_out']) - graph = helper.make_graph([ref_node, copy_node, reshape_node], - "reshape_like_test", + graph = helper.make_graph([ref_node, reshape_node, shape_node], + "shape_test", inputs = [helper.make_tensor_value_info("in", TensorProto.FLOAT, list(in_shape))], - outputs = [helper.make_tensor_value_info("out", + outputs = [helper.make_tensor_value_info("final_out", TensorProto.FLOAT, list(ref_shape))]) - model = helper.make_model(graph, producer_name='reshape_like_test') + model = helper.make_model(graph, producer_name='shape_test') for target, ctx in ctx_list(): - x = np.random.uniform(size=in_shape).astype('float32') - tvm_out = get_tvm_output(model, x, target, ctx, ref_shape, 'float32') + x = np.random.uniform(size=in_shape).astype('int32') + tvm_out = get_tvm_output(model, x, target, ctx, ref_shape, 'int32') - tvm.testing.assert_allclose(ref_shape, tvm_out.shape) + tvm.testing.assert_allclose(ref_shape, tvm_out) def _test_power_iteration(x_shape, y_shape): if isinstance(y_shape, int): @@ -995,7 +996,7 @@ def test_LogSoftmax(): if __name__ == '__main__': test_reshape() - test_reshape_like() + test_shape() test_power() test_squeeze() test_unsqueeze()