Skip to content

Commit

Permalink
Review comments fixed
Browse files Browse the repository at this point in the history
  • Loading branch information
siju-samuel committed Mar 3, 2020
1 parent 7c72525 commit 1b8f641
Show file tree
Hide file tree
Showing 2 changed files with 118 additions and 42 deletions.
79 changes: 75 additions & 4 deletions python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# pylint: disable=invalid-name, unused-argument, too-many-lines, import-outside-toplevel
"""Tensorflow lite frontend."""
import math
import itertools
import numpy as np
import tvm
from tvm.ir import IRModule
Expand Down Expand Up @@ -943,6 +944,8 @@ def convert_gather(self, op):
raise ImportError("The tflite package must be installed")

input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 2, "input tensors length should be 2"

data = self.get_expr(input_tensors[0].tensor_idx)

indices = input_tensors[1]
Expand All @@ -958,18 +961,87 @@ def convert_gather(self, op):
gather_options.Init(op_options.Bytes, op_options.Pos)
axis = gather_options.Axis()

out = _op.take(data, indices, axis=axis)
# Check the indices are with in bounds.
data_shape = list(input_tensors[0].tensor.ShapeAsNumpy())
data_dim = len(data_shape)

axis_n = axis
if axis_n < 0:
axis_n += axis_n + data_dim
assert axis_n >= 0, "Axis out of bounds"
assert axis_n < data_dim, "Axis out of bounds"

indices_val = self.get_tensor_value(input_tensors[1])
indices_shape = list(indices_val.shape)
indices_len = len(indices_shape)

out_shape = []
for i in range(data_dim):
if axis_n == i:
for j in range(indices_len):
out_shape.append(indices_shape[j])
else:
out_shape.append(data_shape[i])

loopover = [range(s) for s in out_shape]
for idx in list(itertools.product(*loopover)):
indices_position = [idx[j] for j in range(axis_n, axis_n+indices_len)]

real_indices = [idx[j] for j in range(axis_n)]
real_indices.append(indices_val[tuple(indices_position)])
real_indices.extend([idx[j] for j in range(axis_n + indices_len, len(idx))])
for r, d in zip(real_indices, data_shape):
if r >= d:
raise ValueError("TFLite out of bound indices are not supported.")

# Use mode 'fast' since indices are already checked within bounds.
out = _op.take(data, indices, axis=axis, mode="fast")
return out

def convert_strided_slice(self, op):
"""Method to Convert TFLite STRIDED_SLICE operator"""
"""Method to Convert TFLite STRIDED_SLICE operator.
NOTE: Eventhough tensorflow supports begin_mask, end_mask, ellipsis_mask, new_axis_mask
and shrink_axis_mask, tflite doesn't support these and expect these values to be zero.
But in future, they may open up the mask implementation, so kept the implementation
same as tensorflow.
This op extracts a slice of size (end - begin) / stride from the given input tensor.
Starting at the location specified by begin the slice continues by adding stride to the
index until all dimensions are not less than end. Note that a stride can be negative,
which causes a reverse slice.
For slice input[val0, val1, ..., valn], begin/end/strides will be vectors of length n.
In each mask field(begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask)
the ith bit will correspond to the ith val.
If the ith bit of begin_mask is set, begin[i] is ignored and the fullest possible range
in that dimension is used instead.
If the ith bit of ellipsis_mask is set, as many unspecified dimensions as needed will be
inserted between other dimensions. Only one non-zero bit is allowed in ellipsis_mask.
If the ith bit of new_axis_mask is set, then begin, end, and stride are ignored and a
new length 1 dimension is added at this point in the output tensor.
If the ith bit of shrink_axis_mask is set, it implies that the ith specification shrinks
the dimensionality by 1, taking on the value at index begin[i]. end[i] and strides[i]
are ignored in this case.
begin and end are zero-indexed. strides entries must be non-zero.
TVM Relay implementation of doesn't support mask, so the mask values are processed in
this function and begin/end/strides are updated accordingly. If any mask is present, and
since tvm doesn't support mask computation directly, the output need a final reshape.
"""
try:
from tflite.BuiltinOptions import BuiltinOptions
from tflite.StridedSliceOptions import StridedSliceOptions
except ImportError:
raise ImportError("The tflite package must be installed")

input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 4, "input tensors length should be 4"

data_expr = self.get_expr(input_tensors[0].tensor_idx)

begin = list(self.get_tensor_value(input_tensors[1]))
Expand All @@ -988,8 +1060,7 @@ def convert_strided_slice(self, op):

data_shape = list(input_tensors[0].tensor.ShapeAsNumpy())
data_dim = len(data_shape)
stride_dim = len(list(input_tensors[3].tensor.ShapeAsNumpy()))

stride_dim = len(stride)
def _transform_mask(stride_dim, ellipsis_mask):
"""Handle mask inputs to create new begin, end, stride and output shape"""
m_begin = [0] * data_dim
Expand Down
81 changes: 43 additions & 38 deletions tests/python/frontend/tflite/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,69 +273,74 @@ def test_forward_slice():
# Gather
# ------

def _test_gather(dshape, indices, axis, dtype):
def _test_gather(dshape, indices, axis, dtype, quantized=False, oob=False):
""" One iteration of Gather """
data = np.random.uniform(1, 10, size=dshape).astype(dtype)
indices = np.asarray(indices).astype('int32')

with tf.Graph().as_default():
in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
out = array_ops.gather(in_data, indices, axis=axis)
compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out])

#Test quantized input
data = np.random.uniform(1, 10, size=dshape).astype(np.uint8)
data = np.random.uniform(1, 10, size=dshape)
data = data.astype(np.uint8) if quantized else data.astype(dtype)
with tf.Graph().as_default():
in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype, name="in_data")
out = array_ops.gather(in_data, indices, axis=axis)
compare_tflite_with_tvm([data], ['in_data:0'], [in_data], [out], quantized=True)
if axis:
out = array_ops.gather(in_data, indices, axis=axis)
else:
out = array_ops.gather(in_data, indices) #tflite conversion fails for None axis
input_range = {'in_data': (-100, 100)} if quantized else None
try:
compare_tflite_with_tvm([data], ['in_data:0'], [in_data], [out],
quantized=quantized, input_range=input_range)
except ValueError as e:
if not oob:
raise e
except Exception as e:
raise e

def test_forward_gather():
""" GATHER """
_test_gather((4,), [1], 0, 'float32')
_test_gather((1, 4), [0], 0, 'int32')
_test_gather((4,), [[[1, 0], [0, 1]]], 0, 'float32')
_test_gather((2, 2), [[[1, 0], [0, 1]]], 0, 'int32')
_test_gather((2, 2), [[[1, 0], [0, 1]]], 1, 'int32')
_test_gather((2, 2), [[[1, 0], [0, 1]]], 0, 'float32')
_test_gather((3, 3, 3), [[[1, 0]]], 0, 'int32')
_test_gather((3, 3, 3), [[[1, 0]]], 2, 'int32')
_test_gather((4, 3, 5, 6), [[2, 1, 0, 0]], 0, 'float32')
for quantized in [False, True]:
_test_gather((4,), [1], 0, 'float32', quantized)
_test_gather((4,), [1], None, 'int32', quantized)
_test_gather((1, 4), [0], 0, 'int32', quantized)
_test_gather((4,), [[[1, 0], [0, 1]]], 0, 'float32', quantized)
_test_gather((2, 2), [[[1, 0], [0, 1]]], 1, 'int32', quantized)
_test_gather((2, 2), [[[1, 0], [0, 1]]], None, 'float32', quantized)
_test_gather((3, 3, 3), [[[1, 0]]], 0, 'int32', quantized)
_test_gather((3, 3, 3), [[[1, 0]]], 2, 'int32', quantized)
_test_gather((4, 3, 5, 6), [[2, 1, 0, 0]], 0, 'float32', quantized)
_test_gather((3, 3, 3), [[[2, 1]]], -1, 'int32', quantized)
_test_gather((4,), [16], 0, 'float32', quantized, oob=True)
_test_gather((1, 3, 3), [12], 0, 'int32', quantized, oob=True)
_test_gather((1, 3, 3), [20], 1, 'float32', quantized, oob=True)
_test_gather((1, 3, 3), [20, 20], 2, 'float32', quantized, oob=True)

#######################################################################
# StridedSlice
# ------------

def _test_stridedslice(ip_shape, begin, end, stride, dtype,
begin_mask=0, end_mask=0, new_axis_mask=0,
shrink_axis_mask=0, ellipsis_mask=0):
shrink_axis_mask=0, ellipsis_mask=0, quantized=False):
""" One iteration of a Stridedslice """
data = np.random.uniform(size=ip_shape).astype(dtype)
data = data.astype(np.uint8) if quantized else data.astype(dtype)
with tf.Graph().as_default():
in_data = tf.placeholder(dtype, ip_shape, name="in_data")
out = array_ops.strided_slice(in_data, begin, end, stride,
begin_mask=begin_mask,
end_mask=end_mask, new_axis_mask=new_axis_mask,
shrink_axis_mask=shrink_axis_mask,
ellipsis_mask=ellipsis_mask)
compare_tflite_with_tvm(data, 'in_data:0', [in_data], [out])

#Test with quantized inputs
data = np.random.uniform(size=ip_shape).astype(np.uint8)
with tf.Graph().as_default():
in_data = tf.placeholder(dtype, ip_shape, name="in_data")
out = array_ops.strided_slice(in_data, begin, end, stride,
begin_mask=begin_mask,
end_mask=end_mask, new_axis_mask=new_axis_mask,
end_mask=end_mask,
new_axis_mask=new_axis_mask,
shrink_axis_mask=shrink_axis_mask,
ellipsis_mask=ellipsis_mask)
compare_tflite_with_tvm([data], ['in_data:0'], [in_data], [out], quantized=True)
input_range = {'in_data': (-100, 100)} if quantized else None
compare_tflite_with_tvm([data], ['in_data:0'], [in_data], [out], quantized=quantized,
input_range=input_range)

def test_forward_stridedslice():
'''test StridedSlice'''
_test_stridedslice((2), [1], [1], [1], 'float32', shrink_axis_mask=1)
_test_stridedslice((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1], 'float32')
_test_stridedslice((3, 4), [1, 0], [4, 4], [1, 1], 'float32', shrink_axis_mask=1)
for quantized in [False, True]:
_test_stridedslice((2), [1], [1], [1], 'float32', shrink_axis_mask=1, quantized=quantized)
_test_stridedslice((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1], 'float32', quantized=quantized)
_test_stridedslice((3, 4), [1, 0], [4, 4], [1, 1], 'float32', shrink_axis_mask=0, quantized=quantized)
_test_stridedslice((4, 4), [1, 0], [4, 4], [1, 1], 'float32', shrink_axis_mask=2, quantized=quantized)

#######################################################################
# transpose
Expand Down

0 comments on commit 1b8f641

Please sign in to comment.