Skip to content

Commit

Permalink
[BUGFIX]bugfix in tensorflow space_to_batch_nd (#5175)
Browse files Browse the repository at this point in the history
* [BUGFIX]bugfix in tensorflow space_to_batch_nd

* Test case added
  • Loading branch information
siju-samuel authored Apr 1, 2020
1 parent b2a32dd commit afb8bf0
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 1 deletion.
2 changes: 1 addition & 1 deletion python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1516,7 +1516,7 @@ def _impl(inputs, attr, params, mod):
paddings = _infer_value(inputs[2], params).asnumpy()
paddings = np.squeeze(paddings)
if len(paddings.shape) == 1:
paddings = np.expand_dims(paddings, exis=0)
paddings = np.expand_dims(paddings, axis=0)
paddings = paddings.tolist()
N = len(input_shape)
M = len(block_shape)
Expand Down
16 changes: 16 additions & 0 deletions tests/python/frontend/tensorflow/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -593,6 +593,17 @@ def _test_space_to_batch_nd(input_shape, block_shape, paddings, dtype='int32'):

compare_tf_with_tvm(data, in_data.name, out.name)

def _test_space_to_batch_nd_infer_paddings(input_shape, block_shape, dtype='int32'):
data = np.random.uniform(0, 5, size=input_shape).astype(dtype)
padding_np = np.array([0, 1]).astype(np.int32).reshape((1, 2))
with tf.Graph().as_default():
in_data = tf.placeholder(shape=input_shape, dtype=dtype)
const1 = tf.constant(padding_np, dtype=tf.int32)
# make paddings an input to tf.transpose, but not an input to the graph,
# so it can be extracted with infer_value_simulated
paddings = tf.reverse(const1, axis=[-1])
out = tf.space_to_batch_nd(in_data, block_shape, paddings)
compare_tf_with_tvm(data, in_data.name, out.name)

def test_forward_space_to_batch_nd():
# test cases: https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/space-to-batch-n-d
Expand Down Expand Up @@ -637,6 +648,11 @@ def test_forward_space_to_batch_nd():
dtype='float64'
)

_test_space_to_batch_nd_infer_paddings(
input_shape=[2, 3, 2],
block_shape=[2]
)

#######################################################################
# BatchToSpaceND
# --------------
Expand Down

0 comments on commit afb8bf0

Please sign in to comment.