Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TOPI] add squeeze #494

Merged
merged 2 commits into from
Sep 26, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 52 additions & 1 deletion topi/python/topi/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from __future__ import absolute_import as _abs
import tvm
from . import tag
from .util import ravel_index, unravel_index, get_const_int
from .util import ravel_index, unravel_index, get_const_int, get_const_tuple

@tvm.tag_scope(tag=tag.BROADCAST)
def expand_dims(a, axis, num_newaxis=1):
Expand Down Expand Up @@ -77,6 +77,57 @@ def reshape(a, newshape):
lambda *indices: a(*unravel_index(ravel_index(indices, newshape), a_shape)))


@tvm.tag_scope(tag=tag.INJECTIVE)
def squeeze(a, axis=None):
    """Remove single-dimensional entries from the shape of an array.

    Parameters
    ----------
    a : tvm.Tensor
        The input tensor.

    axis : None or int or tuple of ints, optional
        Selects a subset of the single-dimensional entries in the shape.
        If an axis is selected with shape entry greater than one, an error is raised.
        When None, every axis of extent 1 is removed.

    Returns
    -------
    squeezed : tvm.Tensor
    """
    ndim = len(a.shape)
    const_shape = get_const_tuple(a.shape)

    # Normalize `axis` into the set of dimension indices to drop.
    if axis is None:
        removed = {dim for dim, extent in enumerate(const_shape) if extent == 1}
    elif isinstance(axis, int):
        normalized = axis + ndim if axis < 0 else axis
        assert const_shape[normalized] == 1
        removed = {normalized}
    else:
        removed = set()
        for ax in axis:
            normalized = ax + ndim if ax < 0 else ax
            assert const_shape[normalized] == 1
            removed.add(normalized)

    # Output shape keeps every dimension that is not squeezed away.
    out_shape = [extent for dim, extent in enumerate(const_shape)
                 if dim not in removed]

    def _index_map(*out_indices):
        # Map output coordinates back to input coordinates: squeezed
        # dimensions are pinned at 0, the rest consume output indices in order.
        src_indices = []
        skipped = 0
        for dim in range(ndim):
            if dim in removed:
                src_indices.append(0)
                skipped += 1
            else:
                src_indices.append(out_indices[dim - skipped])
        return a(*src_indices)

    return tvm.compute(out_shape, _index_map)


@tvm.tag_scope(tag=tag.INJECTIVE)
def concatenate(a_tuple, axis=0):
"""Join a sequence of arrays along an existing axis.
Expand Down
29 changes: 29 additions & 0 deletions topi/tests/python/test_topi_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,28 @@ def check_device(device):
check_device("metal")


def verify_squeeze(src_shape, axis):
    """Check topi.squeeze against numpy.squeeze on each enabled GPU backend."""
    in_tensor = tvm.placeholder(shape=src_shape, name="A")
    out_tensor = topi.squeeze(in_tensor, axis=axis)
    sched = topi.cuda.schedule_injective(out_tensor)

    def check_device(device):
        # Skip silently-unavailable backends so the suite runs anywhere.
        if not tvm.module.enabled(device):
            print("Skip because %s is not enabled" % device)
            return
        ctx = tvm.cl(0) if device != "cuda" else tvm.gpu(0)
        func = tvm.build(sched, [in_tensor, out_tensor], device, name="squeeze")
        in_np = np.random.normal(size=src_shape).astype(in_tensor.dtype)
        expected_np = np.squeeze(in_np, axis=axis)
        in_nd = tvm.nd.array(in_np, ctx)
        out_nd = tvm.nd.empty(expected_np.shape, ctx=ctx, dtype=out_tensor.dtype)
        func(in_nd, out_nd)
        np.testing.assert_allclose(out_nd.asnumpy(), expected_np)

    for target in ("cuda", "opencl", "metal"):
        check_device(target)


def verify_concatenate(shapes, axis):
tensor_l = []
for i, shape in enumerate(shapes):
Expand Down Expand Up @@ -133,6 +155,12 @@ def test_reshape():
verify_reshape((16, ), (2, 2, 2, 2))


def test_squeeze():
    """Exercise squeeze with an explicit int axis, the default (all unit
    dims), and a tuple of axes."""
    cases = [
        ((1, 2, 3, 4), 0),
        ((1, 2, 1, 4), None),
        ((1, 1, 1, 4), (1, 2)),
    ]
    for shape, axis in cases:
        verify_squeeze(shape, axis)


def test_concatenate():
verify_concatenate([(2, 3, 4), (2, 2, 4), (2, 5, 4)], 1)
verify_concatenate([(1, 2, 4), (1, 2, 3), (1, 2, 7), (1, 2, 8), (1, 2, 1)], -1)
Expand All @@ -152,6 +180,7 @@ def test_split():
test_tranpose()
test_expand_dims()
test_reshape()
test_squeeze()
test_concatenate()
test_split()