Skip to content

Commit

Permalink
[Relay][Frontend][TFLite] transpose implementation for tflite.py (apa…
Browse files Browse the repository at this point in the history
…che#3705)

* transpose implementation for tflite.py

* add TRANSPOSE to convert_map

* Fix Unexpected keyword argument 'axis' in function call

* add test for transpose oprator

* Add the parameter 'axes' handling

* add test for transpose oprator

* solve conflict within CONTRIBUTORS.md

* Improve the if condition for empty tuple

* Add one unit test to cover empty tuple

* solve conflict within CONTRIBUTORS.md
  • Loading branch information
cchung100m authored and wweic committed Sep 16, 2019
1 parent db00a7f commit 900c0f0
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 1 deletion.
1 change: 1 addition & 0 deletions CONTRIBUTORS.md
Original file line number Diff line number Diff line change
Expand Up @@ -113,3 +113,4 @@ We do encourage everyone to work anything they are interested in.
- [Cody Hao Yu](https://github.com/comaniac)
- [Chris Nuernberger](https://github.com/cnuernber)
- [Shoubhik Bhattacharya](https://github.com/shoubhik)
- [Neo Chien](https://github.com/cchung100m)
28 changes: 27 additions & 1 deletion python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,8 @@ def __init__(self, model, subgraph, exp_tab):
'PAD': self.convert_pad,
'PACK': self.convert_pack,
'LOGISTIC': self.convert_logistic,
'SPLIT': self.convert_split
'SPLIT': self.convert_split,
'TRANSPOSE': self.convert_transpose
}

def check_unsupported_ops(self):
Expand Down Expand Up @@ -743,6 +744,31 @@ def convert_split(self, op):

return out

def convert_transpose(self, op):
"""transpose implementation."""
try:
from tflite.Operator import Operator
except ImportError:
raise ImportError("The tflite package must be installed")

assert isinstance(op, Operator)
input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 2, "input tensors length should be 2"
input_tensor = input_tensors[0]
input_tensor_idx = input_tensor.tensor_idx

in_expr = self.get_expr(input_tensor_idx)

# axis
in_axis = tuple(self.get_tensor_value(input_tensors[1]))

if not in_axis:
out = _op.transpose(in_expr)
else:
out = _op.transpose(in_expr, in_axis)

return out

def convert_pool2d(self, op, pool_type):
"""pool2d implementation."""
try:
Expand Down
33 changes: 33 additions & 0 deletions tests/python/frontend/tflite/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,35 @@ def test_forward_split():
_test_split((1, 3, 6, 5), -2, 3, 'float32')
_test_split((1, 3, 5, 6), -1, 3, 'float32')

#######################################################################
# transpose
# ---------


def _test_forward_transpose(ishape, axes=()):
data = np.random.uniform(size=ishape).astype(np.float32)

with tf.Graph().as_default():
in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)

if not axes:
out = array_ops.transpose(in_data)
else:
out = array_ops.transpose(in_data, axes)

compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out])


def test_forward_transpose():
_test_forward_transpose((2, 2))
_test_forward_transpose((2, 3, 4))
_test_forward_transpose((7, 8, 8, 10))
_test_forward_transpose((2, 3, 4), (1, 2, 0))
_test_forward_transpose((2, 3, 4), (0, 1, 2))
_test_forward_transpose((2, 3, 4, 5), (3, 0, 1, 2))
_test_forward_transpose((2, 3, 4, 5), ())


#######################################################################
# Pooling
# -------
Expand Down Expand Up @@ -823,6 +852,10 @@ def test_forward_ssd_mobilenet_v1():
if __name__ == '__main__':
# Split
test_forward_split()

# Transpose
test_forward_transpose()

# Transforms
test_forward_concatenation()
test_forward_pad()
Expand Down

0 comments on commit 900c0f0

Please sign in to comment.