Skip to content

Commit

Permalink
Add dpnp.linalg.tensorinv() implementation (#1752)
Browse files Browse the repository at this point in the history
* Add a new dpnp.linalg.tensorinv impl

* Add tests for tensorinv

---------

Co-authored-by: Anton <100830759+antonwolfy@users.noreply.github.com>
  • Loading branch information
vlad-perevezentsev and antonwolfy authored Mar 22, 2024
1 parent e44469c commit e5d3127
Show file tree
Hide file tree
Showing 5 changed files with 168 additions and 0 deletions.
59 changes: 59 additions & 0 deletions dpnp/linalg/dpnp_iface_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@
"solve",
"svd",
"slogdet",
"tensorinv",
]


Expand Down Expand Up @@ -930,3 +931,61 @@ def slogdet(a):
check_stacked_square(a)

return dpnp_slogdet(a)


def tensorinv(a, ind=2):
"""
Compute the `inverse` of a tensor.
For full documentation refer to :obj:`numpy.linalg.tensorinv`.
Parameters
----------
a : {dpnp.ndarray, usm_ndarray}
Tensor to `invert`. Its shape must be 'square', i. e.,
``prod(a.shape[:ind]) == prod(a.shape[ind:])``.
ind : int
Number of first indices that are involved in the inverse sum.
Must be a positive integer.
Default: 2.
Returns
-------
out : dpnp.ndarray
The inverse of a tensor whose shape is equivalent to
``a.shape[ind:] + a.shape[:ind]``.
See Also
--------
:obj:`dpnp.linalg.tensordot` : Compute tensor dot product along specified axes.
:obj:`dpnp.linalg.tensorsolve` : Solve the tensor equation ``a x = b`` for x.
Examples
--------
>>> import dpnp as np
>>> a = np.eye(4*6)
>>> a.shape = (4, 6, 8, 3)
>>> ainv = np.linalg.tensorinv(a, ind=2)
>>> ainv.shape
(8, 3, 4, 6)
>>> a = np.eye(4*6)
>>> a.shape = (24, 8, 3)
>>> ainv = np.linalg.tensorinv(a, ind=1)
>>> ainv.shape
(8, 3, 24)
"""

dpnp.check_supported_arrays_type(a)

if ind <= 0:
raise ValueError("Invalid ind argument")

old_shape = a.shape
inv_shape = old_shape[ind:] + old_shape[:ind]
prod = numpy.prod(old_shape[ind:])
a = a.reshape(prod, -1)
a_inv = inv(a)

return a_inv.reshape(*inv_shape)
39 changes: 39 additions & 0 deletions tests/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -1459,3 +1459,42 @@ def test_pinv_errors(self):
a_dp_q = inp.array(a_dp, sycl_queue=a_queue)
rcond_dp_q = inp.array([0.5], dtype="float32", sycl_queue=rcond_queue)
assert_raises(ValueError, inp.linalg.pinv, a_dp_q, rcond_dp_q)


class TestTensorinv:
@pytest.mark.parametrize("dtype", get_all_dtypes())
@pytest.mark.parametrize(
"shape, ind",
[
((4, 6, 8, 3), 2),
((24, 8, 3), 1),
],
ids=[
"(4, 6, 8, 3)",
"(24, 8, 3)",
],
)
def test_tensorinv(self, dtype, shape, ind):
a = numpy.eye(24, dtype=dtype).reshape(shape)
a_dp = inp.array(a)

ainv = numpy.linalg.tensorinv(a, ind=ind)
ainv_dp = inp.linalg.tensorinv(a_dp, ind=ind)

assert ainv.shape == ainv_dp.shape
assert_dtype_allclose(ainv_dp, ainv)

def test_test_tensorinv_errors(self):
a_dp = inp.eye(24, dtype="float32").reshape(4, 6, 8, 3)

# unsupported type `a`
a_np = inp.asnumpy(a_dp)
assert_raises(TypeError, inp.linalg.pinv, a_np)

# unsupported type `ind`
assert_raises(TypeError, inp.linalg.tensorinv, a_dp, 2.0)
assert_raises(TypeError, inp.linalg.tensorinv, a_dp, [2.0])
assert_raises(ValueError, inp.linalg.tensorinv, a_dp, -1)

# non-square
assert_raises(inp.linalg.LinAlgError, inp.linalg.tensorinv, a_dp, 1)
18 changes: 18 additions & 0 deletions tests/test_sycl_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -1873,3 +1873,21 @@ def test_pinv(shape, hermitian, rcond_as_array, device):
B_queue = B_result.sycl_queue

assert_sycl_queue_equal(B_queue, a_dp.sycl_queue)


@pytest.mark.parametrize(
"device",
valid_devices,
ids=[device.filter_string for device in valid_devices],
)
def test_tensorinv(device):
a_np = numpy.eye(12).reshape(12, 4, 3)
a_dp = dpnp.array(a_np, device=device)

result = dpnp.linalg.tensorinv(a_dp, ind=1)
expected = numpy.linalg.tensorinv(a_np, ind=1)
assert_dtype_allclose(result, expected)

result_queue = result.sycl_queue

assert_sycl_queue_equal(result_queue, a_dp.sycl_queue)
8 changes: 8 additions & 0 deletions tests/test_usm_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -1027,3 +1027,11 @@ def test_qr(shape, mode, usm_type):

assert a.usm_type == dp_q.usm_type
assert a.usm_type == dp_r.usm_type


@pytest.mark.parametrize("usm_type", list_of_usm_types, ids=list_of_usm_types)
def test_tensorinv(usm_type):
a = dp.eye(12, usm_type=usm_type).reshape(12, 4, 3)
ainv = dp.linalg.tensorinv(a, ind=1)

assert a.usm_type == ainv.usm_type
44 changes: 44 additions & 0 deletions tests/third_party/cupy/linalg_tests/test_solve.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,3 +208,47 @@ def test_pinv_size_0(self):
self.check_x((0, 0), rcond=1e-15)
self.check_x((0, 2, 3), rcond=1e-15)
self.check_x((2, 0, 3), rcond=1e-15)


class TestTensorInv(unittest.TestCase):
@testing.for_dtypes("ifdFD")
@_condition.retry(10)
def check_x(self, a_shape, ind, dtype):
a_cpu = numpy.random.randint(0, 10, size=a_shape).astype(dtype)
a_gpu = cupy.asarray(a_cpu)
a_gpu_copy = a_gpu.copy()
result_cpu = numpy.linalg.tensorinv(a_cpu, ind=ind)
result_gpu = cupy.linalg.tensorinv(a_gpu, ind=ind)
assert_dtype_allclose(result_gpu, result_cpu)
testing.assert_array_equal(a_gpu_copy, a_gpu)

def check_shape(self, a_shape, ind):
a = cupy.random.rand(*a_shape)
with self.assertRaises(
(numpy.linalg.LinAlgError, cupy.linalg.LinAlgError)
):
cupy.linalg.tensorinv(a, ind=ind)

def check_ind(self, a_shape, ind):
a = cupy.random.rand(*a_shape)
with self.assertRaises(ValueError):
cupy.linalg.tensorinv(a, ind=ind)

def test_tensorinv(self):
self.check_x((12, 3, 4), ind=1)
self.check_x((3, 8, 24), ind=2)
self.check_x((18, 3, 3, 2), ind=1)
self.check_x((1, 4, 2, 2), ind=2)
self.check_x((2, 3, 5, 30), ind=3)
self.check_x((24, 2, 2, 3, 2), ind=1)
self.check_x((3, 4, 2, 3, 2), ind=2)
self.check_x((1, 2, 3, 2, 3), ind=3)
self.check_x((3, 2, 1, 2, 12), ind=4)

def test_invalid_shape(self):
self.check_shape((2, 3, 4), ind=1)
self.check_shape((1, 2, 3, 4), ind=3)

def test_invalid_index(self):
self.check_ind((12, 3, 4), ind=-1)
self.check_ind((18, 3, 3, 2), ind=0)

0 comments on commit e5d3127

Please sign in to comment.