diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index b7fe2cf62b5a..822a4315856d 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -885,6 +885,18 @@ def _impl_v1(cls, inputs, attr, params): return _op.logical_and(inputs[0], inputs[1]) +class Tile(Elemwise): + """Operator converter for Tile + """ + @classmethod + def _impl_v1(cls, inputs, attr, params): + if 'repeats' not in attr: + raise tvm.error.OpAttributeInvalid('Attribute "repeats" should be set ' + 'for operator Tile.') + reps = attr.pop('repeats') # The number of times repeating the tensor data. + return _op.tile(inputs[0], reps) + + # compatible operators that do NOT require any conversion. _identity_list = [] @@ -1002,7 +1014,8 @@ def _get_convert_map(opset): 'Sign': Sign.get_converter(opset), 'Equal': Equal.get_converter(opset), 'Not': Not.get_converter(opset), - 'And': And.get_converter(opset) + 'And': And.get_converter(opset), + 'Tile': Tile.get_converter(opset) } diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 7e0e11f4686e..cdcc596e5cef 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -1205,6 +1205,27 @@ def test_and(): verify_and(indata=[x, y], dtype=bool) +def verify_tile(indata, outdata, **kwargs): + node = helper.make_node('Tile', inputs=['in'], outputs=['out'], **kwargs) + graph = helper.make_graph([node], + 'tile_test', + inputs=[helper.make_tensor_value_info("in", TensorProto.FLOAT, list(indata.shape))], + outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(outdata.shape))]) + + model = helper.make_model(graph, producer_name='tile_test') + + for target, ctx in ctx_list(): + tvm_out = get_tvm_output(model, [indata], target, ctx, outdata.shape) + tvm.testing.assert_allclose(outdata, tvm_out) + + +def test_tile(): + x = np.random.rand(2, 3, 4, 5).astype(np.float32) + repeats = np.random.randint(low=1, high=10, size=(np.ndim(x),)).astype(np.int64) + z = np.tile(x, repeats) + verify_tile(x, z, repeats=repeats) + + if __name__ == '__main__': test_flatten() test_reshape() @@ -1250,3 +1271,4 @@ def test_and(): test_sign() test_not() test_and() + test_tile()