Skip to content

Commit

Permalink
Add support for Tflite operator SPLIT (#3520)
Browse files Browse the repository at this point in the history
* [RFC] Initial support for Tflite operator SPLIT

This patch adds initial support for the tflite operator split. However
I am not yet sure how to handle the axis parameter for the split
operator and support it in the test infrastructure. Putting this up for
an initial review and comment.

The split operator in tflite according to
https://www.tensorflow.org/lite/guide/ops_compatibility

appears to take num_or_size_split as a 0D tensor.

I also note that tflite.split is one of the few operators that returns
multiple outputs and thus the helper routines in the tests needed some
massaging to make this work.

@apivarov , could you please review this ?

Thanks,
Ramana

* Fix the axis parameter

Add more tests

* Address review comments

* Try out frozen_gene's suggestion

* Handle split of 1 element

* int32 is only supported in tflite 1.14, let's check that version here.

* Keep this at python3.5

* Add packaging as a python package to be installed
  • Loading branch information
Ramana Radhakrishnan authored and kevinthesun committed Jul 22, 2019
1 parent 443b5b4 commit 19eb829
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 4 deletions.
4 changes: 2 additions & 2 deletions docker/install/ubuntu_install_python_package.sh
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,5 @@ set -u
set -o pipefail

# install libraries for python package on ubuntu
pip2 install nose pylint==1.9.4 six numpy nose-timer cython decorator scipy tornado typing antlr4-python2-runtime attrs
pip3 install nose pylint==1.9.4 six numpy nose-timer cython decorator scipy tornado typed_ast pytest mypy orderedset antlr4-python3-runtime attrs requests Pillow
pip2 install nose pylint==1.9.4 six numpy nose-timer cython decorator scipy tornado typing antlr4-python2-runtime attrs packaging
pip3 install nose pylint==1.9.4 six numpy nose-timer cython decorator scipy tornado typed_ast pytest mypy orderedset antlr4-python3-runtime attrs requests Pillow packaging
38 changes: 38 additions & 0 deletions python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ def __init__(self, model, subgraph, exp_tab):
'PAD': self.convert_pad,
'PACK': self.convert_pack,
'LOGISTIC': self.convert_logistic,
'SPLIT': self.convert_split
}

def check_unsupported_ops(self):
Expand Down Expand Up @@ -705,6 +706,43 @@ def convert_conv(self, op, conv_type):

return out

def convert_split(self, op):
"""split implementation."""
try:
from tflite.BuiltinOptions import BuiltinOptions
from tflite.Operator import Operator
from tflite.SplitOptions import SplitOptions
except ImportError:
raise ImportError("The tflite package must be installed")

assert isinstance(op, Operator)
input_tensors = self.get_input_tensors(op)

assert len(input_tensors) == 2, "input tensors length should be == 2"

axis_tensor = input_tensors[0]
split_axis = self.get_tensor_value(axis_tensor)
input_tensor = input_tensors[1]
input_tensor_idx = input_tensor.tensor_idx

assert op.BuiltinOptionsType() == BuiltinOptions.SplitOptions
op_options = op.BuiltinOptions()
split_options = SplitOptions()
split_options.Init(op_options.Bytes, op_options.Pos)
num_splits = split_options.NumSplits()

in_expr = self.get_expr(input_tensor_idx)
out = _op.split(in_expr, num_splits, axis=int(split_axis))
# Relay does not like a TupleWrapper of 1 element, further this
# only shows up with tf1.13 if we use a split with num_splits==1.
# In tf 1.14 this doesn't appear as it is automatically a reshape
# operation.
if isinstance(out, _expr.TupleWrapper):
if out.size == 1:
out = out[0]

return out

def convert_pool2d(self, op, pool_type):
"""pool2d implementation."""
try:
Expand Down
45 changes: 43 additions & 2 deletions tests/python/frontend/tflite/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from tensorflow.contrib import lite as interpreter_wrapper

import tvm.relay.testing.tf as tf_testing
from packaging import version as package_version

#######################################################################
# Generic run functions for TVM & TFLite
Expand Down Expand Up @@ -120,10 +121,11 @@ def run_tflite_graph(tflite_model_buf, input_data):


def compare_tflite_with_tvm(in_data, in_name, input_tensors,
output_tensors, init_global_variables=False):
output_tensors, init_global_variables=False, out_names=None):
"""Generic function to generate and compare TFLite and TVM output"""
in_data = convert_to_list(in_data)
in_name = convert_to_list(in_name)
out_names = convert_to_list(out_names)
in_node = [0] * len(in_name)
for i in range(len(in_name)):
in_node[i] = in_name[i].split(':')[0] if ":" in in_name[i] else in_name[i]
Expand All @@ -143,7 +145,8 @@ def compare_tflite_with_tvm(in_data, in_name, input_tensors,
print("Skip because %s is not enabled" % device)
continue

tvm_output = run_tvm_graph(tflite_model_buffer, in_data, in_node, target=device)
tvm_output = run_tvm_graph(tflite_model_buffer, in_data, in_node, target=device,
num_output=len(out_names), out_names=out_names)
for i in range(len(tflite_output)):
tvm.testing.assert_allclose(tflite_output[i], tvm_output[i], atol=1e-5, rtol=1e-5)

Expand All @@ -161,6 +164,42 @@ def with_fused_activation_function(input_tensor, fn_name):
return math_ops.tanh(input_tensor)
raise AssertionError("Unknown fused_activation_function {}".format(fn_name))

def _test_split(in_shape, axis, num_Splits, dtype):
'''internal split tester taking as parameters in_shape, number of tensors to split into
and dtype (data type)'''
np_data = np.random.uniform(-5, 5, size=in_shape).astype(dtype)
with tf.Graph().as_default():
in_data = array_ops.placeholder(shape=in_shape, dtype=dtype)
out = array_ops.split(in_data, num_Splits, axis=axis)
out_names = ['out_' + str(n) + ':0' for n in range(num_Splits)]
compare_tflite_with_tvm([np_data], ['Placeholder:0'], [in_data], out,
out_names=out_names)

def test_forward_split():
'''test split layer'''
# rank 1
_test_split((3,), 0, 1, 'float32')
_test_split((3,), 0, 3, 'float32')
_test_split((6,), 0, 3, 'float32')
# rank 2
_test_split((6, 2), 0, 3, 'float32')
_test_split((2, 6), 1, 6, 'float32')
# rank 3
if package_version.parse(tf.VERSION) >= package_version.parse('1.14.0'):
_test_split((6, 2, 4), 0, 2, 'int32')

_test_split((2, 6, 4), 1, 3, 'float32')
_test_split((2, 4, 6), 2, 1, 'float32')
# rank 4
_test_split((6, 1, 3, 5), 0, 3, 'float32')
_test_split((1, 6, 3, 5), 1, 3, 'float32')
_test_split((1, 3, 6, 5), 2, 3, 'float32')
_test_split((1, 3, 5, 6), 3, 3, 'float32')
# split along negative axis
_test_split((6, 1, 3, 5), -4, 3, 'float32')
_test_split((1, 6, 3, 5), -3, 3, 'float32')
_test_split((1, 3, 6, 5), -2, 3, 'float32')
_test_split((1, 3, 5, 6), -1, 3, 'float32')

#######################################################################
# Pooling
Expand Down Expand Up @@ -782,6 +821,8 @@ def test_forward_ssd_mobilenet_v1():
# Main
# ----
if __name__ == '__main__':
# Split
test_forward_split()
# Transforms
test_forward_concatenation()
test_forward_pad()
Expand Down

0 comments on commit 19eb829

Please sign in to comment.