Skip to content

Commit

Permalink
fix squeeze to output (1,) if all axes are squeezed. E.g squeeze((1,1…
Browse files Browse the repository at this point in the history
…,1...), None) case (#498)
  • Loading branch information
sxjscience authored and tqchen committed Sep 28, 2017
1 parent e44b38e commit 5d9647e
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 1 deletion.
2 changes: 2 additions & 0 deletions topi/python/topi/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,8 @@ def squeeze(a, axis=None):
for i, a_dim in enumerate(a_shape):
if i not in search_axis:
out_shape.append(a_dim)
if not out_shape:
out_shape.append(1)
def _compute(*indices):
real_indices = []
flag = 0
Expand Down
7 changes: 6 additions & 1 deletion topi/tests/python/test_topi_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,11 @@ def check_device(device):
data_npy = np.random.normal(size=src_shape).astype(A.dtype)
out_npy = np.squeeze(data_npy, axis=axis)
data_nd = tvm.nd.array(data_npy, ctx)
out_nd = tvm.nd.empty(out_npy.shape, ctx=ctx, dtype=B.dtype)
if out_npy.shape == ():
out_nd_shape = (1,)
else:
out_nd_shape = out_npy.shape
out_nd = tvm.nd.empty(out_nd_shape, ctx=ctx, dtype=B.dtype)
foo(data_nd, out_nd)
np.testing.assert_allclose(out_nd.asnumpy(), out_npy)

Expand Down Expand Up @@ -159,6 +163,7 @@ def test_squeeze():
verify_squeeze((1, 2, 3, 4), 0)
verify_squeeze((1, 2, 1, 4), None)
verify_squeeze((1, 1, 1, 4), (1, 2))
verify_squeeze((1, 1, 1, 1), None)


def test_concatenate():
Expand Down

0 comments on commit 5d9647e

Please sign in to comment.