Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add unstack, moveaxis, swapaxes #1137

Merged
merged 2 commits into from
Mar 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions dpctl/tensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,14 @@
finfo,
flip,
iinfo,
moveaxis,
permute_dims,
result_type,
roll,
squeeze,
stack,
swapaxes,
unstack,
)
from dpctl.tensor._print import (
get_print_options,
Expand Down Expand Up @@ -143,6 +146,9 @@
"complex128",
"iinfo",
"finfo",
"unstack",
"moveaxis",
"swapaxes",
"can_cast",
"result_type",
"meshgrid",
Expand Down
115 changes: 114 additions & 1 deletion dpctl/tensor/_manipulation_functions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Data Parallel Control (dpctl)
#
# Copyright 2020-2022 Intel Corporation
# Copyright 2020-2023 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -741,6 +741,119 @@ def finfo(dtype):
return finfo_object(dtype)


def unstack(X, axis=0):
"""unstack(x, axis=0)

Splits an array in a sequence of arrays along the given axis.

Args:
x (usm_ndarray): input array

axis (int, optional): axis along which `x` is unstacked.
If `x` has rank (i.e, number of dimensions) `N`,
a valid `axis` must reside in the half-open interval `[-N, N)`.
Default: `0`.

Returns:
Tuple[usm_ndarray,...]: A tuple of arrays.

Raises:
AxisError: if the `axis` value is invalid.
"""
if not isinstance(X, dpt.usm_ndarray):
raise TypeError(f"Expected usm_ndarray type, got {type(X)}.")

axis = normalize_axis_index(axis, X.ndim)
Y = dpt.moveaxis(X, axis, 0)

return tuple(Y[i] for i in range(Y.shape[0]))


def moveaxis(X, src, dst):
"""moveaxis(x, src, dst)

Moves axes of an array to new positions.

Args:
x (usm_ndarray): input array

src (int or a sequence of int):
Original positions of the axes to move.
These must be unique. If `x` has rank (i.e., number of
dimensions) `N`, a valid `axis` must be in the
half-open interval `[-N, N)`.

dst (int or a sequence of int):
Destination positions for each of the original axes.
These must also be unique. If `x` has rank
(i.e., number of dimensions) `N`, a valid `axis` must be
in the half-open interval `[-N, N)`.

Returns:
usm_narray: Array with moved axes.
The returned array must has the same data type as `x`,
is created on the same device as `x` and has the same
USM allocation type as `x`.

Raises:
AxisError: if `axis` value is invalid.
"""
if not isinstance(X, dpt.usm_ndarray):
raise TypeError(f"Expected usm_ndarray type, got {type(X)}.")

if not isinstance(src, (tuple, list)):
src = (src,)

if not isinstance(dst, (tuple, list)):
dst = (dst,)

src = normalize_axis_tuple(src, X.ndim, "src")
dst = normalize_axis_tuple(dst, X.ndim, "dst")
ind = list(range(0, X.ndim))
for i in range(len(src)):
ind.remove(src[i]) # using the value here which is the same as index
ind.insert(dst[i], src[i])

return dpt.permute_dims(X, tuple(ind))


def swapaxes(X, axis1, axis2):
"""swapaxes(x, axis1, axis2)

Interchanges two axes of an array.

Args:
x (usm_ndarray): input array

axis1 (int): First axis.
If `x` has rank (i.e., number of dimensions) `N`,
a valid `axis` must be in the half-open interval `[-N, N)`.

axis2 (int): Second axis.
If `x` has rank (i.e., number of dimensions) `N`,
a valid `axis` must be in the half-open interval `[-N, N)`.

Returns:
usm_narray: Array with swapped axes.
The returned array must has the same data type as `x`,
is created on the same device as `x` and has the same USM
allocation type as `x`.

Raises:
AxisError: if `axis` value is invalid.
"""
if not isinstance(X, dpt.usm_ndarray):
raise TypeError(f"Expected usm_ndarray type, got {type(X)}.")

axis1 = normalize_axis_index(axis1, X.ndim, "axis1")
axis2 = normalize_axis_index(axis2, X.ndim, "axis2")

ind = list(range(0, X.ndim))
ind[axis1] = axis2
ind[axis2] = axis1
return dpt.permute_dims(X, tuple(ind))
Comment on lines +851 to +854
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of creating list to be able to mutate it only to then copy it into tuple, use generator to create the right tuple directly. Here is an example:

In [1]: def permuted_range(n, a1, a2):
   ...:     for i in range(n):
   ...:         if i == a1:
   ...:             yield a2
   ...:         elif i == a2:
   ...:             yield a1
   ...:         else:
   ...:             yield i
   ...:

In [2]: tuple(permuted_range(5, 1, 3))
Out[2]: (0, 3, 2, 1, 4)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, the way it is currently implement is faster:


In [13]: def way1(n, a1, a2):
    ...:     ind = list(range(n))
    ...:     ind[a1] = a2
    ...:     ind[a2] = a1
    ...:     return tuple(ind)
    ...:

In [14]: def permuted_range(n, a1, a2):
    ...:     for i in range(n):
    ...:         if i == a1:
    ...:             yield a2
    ...:         elif i == a2:
    ...:             yield a1
    ...:         else:
    ...:             yield i
    ...:

In [15]: def way2(n, a1, a2):
    ...:     return tuple(permuted_range(n, a1, a2))
    ...:

In [16]: way1(5, 1, 3) == way2(5, 1, 3)
Out[16]: True

In [17]: %timeit way1(5, 1, 3)
434 ns ± 23 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)

In [18]: %timeit way2(5, 1, 3)
865 ns ± 98 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)



def _supported_dtype(dtypes):
for dtype in dtypes:
if dtype.char not in "?bBhHiIlLqQefdFD":
Expand Down
76 changes: 76 additions & 0 deletions dpctl/tests/test_usm_ndarray_manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1046,3 +1046,79 @@ def test_result_type():
X_np = [np.ones((2), dtype=np.int64), np.int32, "float16"]

assert dpt.result_type(*X) == np.result_type(*X_np)


def test_swapaxes_1d():
x = np.array([[1, 2, 3]])
exp = np.swapaxes(x, 0, 1)

y = dpt.asarray([[1, 2, 3]])
res = dpt.swapaxes(y, 0, 1)

assert_array_equal(exp, dpt.asnumpy(res))


def test_swapaxes_2d():
x = np.array([[[0, 1], [2, 3]], [[4, 5], [6, 7]]])
exp = np.swapaxes(x, 0, 2)

y = dpt.asarray([[[0, 1], [2, 3]], [[4, 5], [6, 7]]])
res = dpt.swapaxes(y, 0, 2)

assert_array_equal(exp, dpt.asnumpy(res))


def test_moveaxis_1axis():
x = np.arange(60).reshape((3, 4, 5))
exp = np.moveaxis(x, 0, -1)

y = dpt.reshape(dpt.arange(60), (3, 4, 5))
res = dpt.moveaxis(y, 0, -1)

assert_array_equal(exp, dpt.asnumpy(res))


def test_moveaxis_2axes():
x = np.arange(60).reshape((3, 4, 5))
exp = np.moveaxis(x, [0, 1], [-1, -2])

y = dpt.reshape(dpt.arange(60), (3, 4, 5))
res = dpt.moveaxis(y, [0, 1], [-1, -2])

assert_array_equal(exp, dpt.asnumpy(res))


def test_moveaxis_3axes():
x = np.arange(60).reshape((3, 4, 5))
exp = np.moveaxis(x, [0, 1, 2], [-1, -2, -3])

y = dpt.reshape(dpt.arange(60), (3, 4, 5))
res = dpt.moveaxis(y, [0, 1, 2], [-1, -2, -3])

assert_array_equal(exp, dpt.asnumpy(res))


def test_unstack_axis0():
y = dpt.reshape(dpt.arange(6), (2, 3))
res = dpt.unstack(y)

assert_array_equal(dpt.asnumpy(y[0, ...]), dpt.asnumpy(res[0]))
assert_array_equal(dpt.asnumpy(y[1, ...]), dpt.asnumpy(res[1]))


def test_unstack_axis1():
y = dpt.reshape(dpt.arange(6), (2, 3))
res = dpt.unstack(y, 1)

assert_array_equal(dpt.asnumpy(y[:, 0, ...]), dpt.asnumpy(res[0]))
assert_array_equal(dpt.asnumpy(y[:, 1, ...]), dpt.asnumpy(res[1]))
assert_array_equal(dpt.asnumpy(y[:, 2, ...]), dpt.asnumpy(res[2]))


def test_unstack_axis2():
y = dpt.reshape(dpt.arange(60), (4, 5, 3))
res = dpt.unstack(y, 2)

assert_array_equal(dpt.asnumpy(y[:, :, 0, ...]), dpt.asnumpy(res[0]))
assert_array_equal(dpt.asnumpy(y[:, :, 1, ...]), dpt.asnumpy(res[1]))
assert_array_equal(dpt.asnumpy(y[:, :, 2, ...]), dpt.asnumpy(res[2]))