Skip to content

Commit

Permalink
[Relay][Frontend][ONNX] operator support: Tile (apache#3941)
Browse files Browse the repository at this point in the history
* [Relay][Frontend][ONNX] operator support: Tile

* Trigger notification
  • Loading branch information
cchung100m authored and wweic committed Oct 1, 2019
1 parent 9e91345 commit bbe1e40
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 1 deletion.
15 changes: 14 additions & 1 deletion python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -885,6 +885,18 @@ def _impl_v1(cls, inputs, attr, params):
return _op.logical_and(inputs[0], inputs[1])


class Tile(Elemwise):
"""Operator converter for Tile
"""
@classmethod
def _impl_v1(cls, inputs, attr, params):
if 'repeats' not in attr:
raise tvm.error.OpAttributeInvalid('Attribute "repeats" should be set '
'for operator Tile.')
reps = attr.pop('repeats') # The number of times repeating the tensor data.
return _op.tile(inputs[0], reps)


# compatible operators that do NOT require any conversion.
_identity_list = []

Expand Down Expand Up @@ -1002,7 +1014,8 @@ def _get_convert_map(opset):
'Sign': Sign.get_converter(opset),
'Equal': Equal.get_converter(opset),
'Not': Not.get_converter(opset),
'And': And.get_converter(opset)
'And': And.get_converter(opset),
'Tile': Tile.get_converter(opset)
}


Expand Down
22 changes: 22 additions & 0 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1205,6 +1205,27 @@ def test_and():
verify_and(indata=[x, y], dtype=bool)


def verify_tile(indata, outdata, **kwargs):
node = helper.make_node('Tile', inputs=['in'], outputs=['out'], **kwargs)
graph = helper.make_graph([node],
'tile_test',
inputs=[helper.make_tensor_value_info("in", TensorProto.FLOAT, list(indata.shape))],
outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(outdata.shape))])

model = helper.make_model(graph, producer_name='tile_test')

for target, ctx in ctx_list():
tvm_out = get_tvm_output(model, [indata], target, ctx, outdata.shape)
tvm.testing.assert_allclose(outdata, tvm_out)


def test_tile():
x = np.random.rand(2, 3, 4, 5).astype(np.float32)
repeats = np.random.randint(low=1, high=10, size=(np.ndim(x),)).astype(np.int64)
z = np.tile(x, repeats)
verify_tile(x, z, repeats=repeats)


if __name__ == '__main__':
test_flatten()
test_reshape()
Expand Down Expand Up @@ -1250,3 +1271,4 @@ def test_and():
test_sign()
test_not()
test_and()
test_tile()

0 comments on commit bbe1e40

Please sign in to comment.