diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 1b09cf307554..f52c318c8e97 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -2601,7 +2601,7 @@ def convert_batch_to_space_nd(self, op): cropped = reshaped_permuted for axis in range(1, M + 1): crop = crops[axis - 1] - if (crop != [0, 0]).all(): + if (crop != [0, 0]).any(): indices = _op.arange( _expr.const(crop[0]), _expr.const(reshaped_permuted_shape[axis] - crop[1]), diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 27980047e909..caa41806c8aa 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -706,6 +706,8 @@ def test_forward_batch_to_space_nd(): _test_batch_to_space_nd(input_shape=[4, 2, 2, 1], block_shape=[2, 2], crops=[[0, 0], [0, 0]]) + _test_batch_to_space_nd(input_shape=[4, 3, 3, 1], block_shape=[2, 2], crops=[[0, 1], [0, 1]]) + ###################################################################### # SpaceToBatchND