From 6737739c3753f3ebce96b3426de7cd0e546582fa Mon Sep 17 00:00:00 2001 From: Ashutosh Parkhi Date: Mon, 14 Jan 2019 16:15:42 +0530 Subject: [PATCH] [Tensorflow] Support for Crop (#2285) fixes fixes --- nnvm/python/nnvm/frontend/tensorflow.py | 16 ++++++++++++++++ .../python/frontend/tensorflow/test_forward.py | 18 ++++++++++++++++++ 2 files changed, 34 insertions(+) diff --git a/nnvm/python/nnvm/frontend/tensorflow.py b/nnvm/python/nnvm/frontend/tensorflow.py index a869abac9c4f..c0848bb1092c 100644 --- a/nnvm/python/nnvm/frontend/tensorflow.py +++ b/nnvm/python/nnvm/frontend/tensorflow.py @@ -388,6 +388,21 @@ def _impl(inputs, attr, params): return _impl +def _slice(): + def _impl(inputs, attr, params): + begin = params.pop(inputs[1].list_output_names()[0]).asnumpy().tolist() + size = params.pop(inputs[2].list_output_names()[0]).asnumpy().tolist() + data_shape = attr['_input_shapes'][inputs[0]] + data_dim = len(data_shape) + end = size + for i in range(data_dim): + if size[i] == -1: + end[i] = data_shape[i] - begin[i] + else: + end[i] += begin[i] + return _sym.strided_slice(inputs[0], begin=begin, end=size) + return _impl + def _reshape(): def _impl(inputs, attr, params): try: @@ -883,6 +898,7 @@ def _expand_dims_0d_aware(data, attr, axis, num_newaxis=1): 'Sum' : _sum(), 'Square' : _square(), 'Pack' : _pack(), + 'Slice' : _slice(), 'LeakyRelu' : AttrCvt('leaky_relu'), 'Relu' : AttrCvt('relu'), 'Reshape' : _reshape(), diff --git a/nnvm/tests/python/frontend/tensorflow/test_forward.py b/nnvm/tests/python/frontend/tensorflow/test_forward.py index 5b8f11695790..0ea92248f0f5 100644 --- a/nnvm/tests/python/frontend/tensorflow/test_forward.py +++ b/nnvm/tests/python/frontend/tensorflow/test_forward.py @@ -655,6 +655,23 @@ def test_forward_resize_bilinear(): _test_resize_bilinear((6, 32, 64, 64), [20, 20], True) +####################################################################### +# Crop to bounding box +# -------------------- + +def _test_crop(in_shape, off_h, off_w, tar_h, tar_w): + """ Crop to bounding box """ + data = np.random.uniform(size=in_shape).astype('float32') + with tf.Graph().as_default(): + in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype) + tf.image.crop_to_bounding_box(in_data, off_h, off_w, tar_h, tar_w) + compare_tf_with_tvm(data, 'Placeholder:0', 'crop_to_bounding_box/Slice:0') + +def test_forward_crop(): + """ Crop to bounding box """ + _test_crop((1, 224, 224, 3), 20, 20, 120, 120) + + ####################################################################### # LSTM # ---- @@ -1139,6 +1156,7 @@ def test_forward_rel_ops(): test_forward_squeeze() test_forward_pack() test_forward_resize_bilinear() + test_forward_crop() test_forward_pad() test_forward_gather() test_forward_stridedslice()