diff --git a/dpnp/linalg/dpnp_iface_linalg.py b/dpnp/linalg/dpnp_iface_linalg.py index 81680cbe0287..c3cb52f9abac 100644 --- a/dpnp/linalg/dpnp_iface_linalg.py +++ b/dpnp/linalg/dpnp_iface_linalg.py @@ -51,6 +51,7 @@ dpnp_det, dpnp_eigh, dpnp_inv, + dpnp_matrix_power, dpnp_matrix_rank, dpnp_multi_dot, dpnp_pinv, @@ -370,33 +371,65 @@ def inv(a): return dpnp_inv(a) -def matrix_power(input, count): +def matrix_power(a, n): """ - Raise a square matrix to the (integer) power `count`. + Raise a square matrix to the (integer) power `n`. + + For full documentation refer to :obj:`numpy.linalg.matrix_power`. Parameters ---------- - input : sequence of array_like + a : (..., M, M) {dpnp.ndarray, usm_ndarray} + Matrix to be "powered". + n : int + The exponent can be any integer or long integer, positive, negative, or zero. Returns ------- - output : array - Returns the dot product of the supplied arrays. + a**n : (..., M, M) dpnp.ndarray + The return value is the same shape and type as `M`; + if the exponent is positive or zero then the type of the + elements is the same as those of `M`. If the exponent is + negative the elements are floating-point. - See Also - -------- - :obj:`numpy.linalg.matrix_power` + >>> import dpnp as np + >>> i = np.array([[0, 1], [-1, 0]]) # matrix equiv. of the imaginary unit + >>> np.linalg.matrix_power(i, 3) # should = -i + array([[ 0, -1], + [ 1, 0]]) + >>> np.linalg.matrix_power(i, 0) + array([[1, 0], + [0, 1]]) + >>> np.linalg.matrix_power(i, -3) # should = 1/(-i) = i, but w/ f.p. elements + array([[ 0., 1.], + [-1., 0.]]) + + Somewhat more sophisticated example + + >>> q = np.zeros((4, 4)) + >>> q[0:2, 0:2] = -i + >>> q[2:4, 2:4] = i + >>> q # one of the three quaternion units not equal to 1 + array([[ 0., -1., 0., 0.], + [ 1., 0., 0., 0.], + [ 0., 0., 0., 1.], + [ 0., 0., -1., 0.]]) + >>> np.linalg.matrix_power(q, 2) # = -np.eye(4) + array([[-1., 0., 0., 0.], + [ 0., -1., 0., 0.], + [ 0., 0., -1., 0.], + [ 0., 0., 0., -1.]]) """ - if not use_origin_backend() and count > 0: - result = input - for _ in range(count - 1): - result = dpnp.matmul(result, input) + dpnp.check_supported_arrays_type(a) + check_stacked_2d(a) + check_stacked_square(a) - return result + if not isinstance(n, int): + raise TypeError("exponent must be an integer") - return call_origin(numpy.linalg.matrix_power, input, count) + return dpnp_matrix_power(a, n) def matrix_rank(A, tol=None, hermitian=False): diff --git a/dpnp/linalg/dpnp_utils_linalg.py b/dpnp/linalg/dpnp_utils_linalg.py index 0305182389b7..544fa936da0b 100644 --- a/dpnp/linalg/dpnp_utils_linalg.py +++ b/dpnp/linalg/dpnp_utils_linalg.py @@ -40,6 +40,7 @@ "dpnp_det", "dpnp_eigh", "dpnp_inv", + "dpnp_matrix_power", "dpnp_matrix_rank", "dpnp_multi_dot", "dpnp_pinv", @@ -526,9 +527,50 @@ def _stacked_identity( """ shape = batch_shape + (n, n) - idx = dpnp.arange(n, usm_type=usm_type, sycl_queue=sycl_queue) - x = dpnp.zeros(shape, dtype=dtype, usm_type=usm_type, sycl_queue=sycl_queue) - x[..., idx, idx] = 1 + x = dpnp.empty(shape, dtype=dtype, usm_type=usm_type, sycl_queue=sycl_queue) + x[...] = dpnp.eye( + n, dtype=x.dtype, usm_type=x.usm_type, sycl_queue=x.sycl_queue + ) + return x + + +def _stacked_identity_like(x): + """ + Create stacked identity matrices based on the shape and properties of `x`. + + Parameters + ---------- + x : dpnp.ndarray + Input array based on whose properties (shape, data type, USM type and SYCL queue) + the identity matrices will be created. + + Returns + ------- + out : dpnp.ndarray + Array of stacked `n x n` identity matrices, + where `n` is the size of the last dimension of `x`. + The returned array has the same shape, data type, USM type + and uses the same SYCL queue as `x`, if applicable. + + Example + ------- + >>> import dpnp + >>> x = dpnp.zeros((2, 3, 3), dtype=dpnp.int64) + >>> _stacked_identity_like(x) + array([[[1, 0, 0], + [0, 1, 0], + [0, 0, 1]], + + [[1, 0, 0], + [0, 1, 0], + [0, 0, 1]]], dtype=int32) + + """ + + x = dpnp.empty_like(x) + x[...] = dpnp.eye( + x.shape[-2], dtype=x.dtype, usm_type=x.usm_type, sycl_queue=x.sycl_queue + ) return x @@ -1082,6 +1124,46 @@ def dpnp_inv(a): return b_f +def dpnp_matrix_power(a, n): + """ + dpnp_matrix_power(a, n) + + Raise a square matrix to the (integer) power `n`. + + """ + + if n == 0: + return _stacked_identity_like(a) + elif n < 0: + a = dpnp.linalg.inv(a) + n *= -1 + + if n == 1: + return a + elif n == 2: + return dpnp.matmul(a, a) + elif n == 3: + return dpnp.matmul(dpnp.matmul(a, a), a) + + # Use binary decomposition to reduce the number of matrix + # multiplications for n > 3. + # `result` will hold the final matrix power, + # while `acc` serves as an accumulator for the intermediate matrix powers. + result = None + acc = a.copy() + while n > 0: + n, bit = divmod(n, 2) + if bit: + if result is None: + result = acc.copy() + else: + dpnp.matmul(result, acc, out=result) + if n > 0: + dpnp.matmul(acc, acc, out=acc) + + return result + + def dpnp_matrix_rank(A, tol=None, hermitian=False): """ dpnp_matrix_rank(A, tol=None, hermitian=False) diff --git a/tests/skipped_tests.tbl b/tests/skipped_tests.tbl index 6a900fda666f..7ee04717abed 100644 --- a/tests/skipped_tests.tbl +++ b/tests/skipped_tests.tbl @@ -332,10 +332,6 @@ tests/third_party/cupy/linalg_tests/test_einsum.py::TestListArgEinSumError::test tests/third_party/cupy/linalg_tests/test_einsum.py::TestListArgEinSumError::test_invalid_sub1 tests/third_party/cupy/linalg_tests/test_einsum.py::TestListArgEinSumError::test_too_many_dims3 -tests/third_party/cupy/linalg_tests/test_product.py::TestMatrixPower::test_matrix_power_invlarge -tests/third_party/cupy/linalg_tests/test_product.py::TestMatrixPower::test_matrix_power_large -tests/third_party/cupy/linalg_tests/test_product.py::TestMatrixPower::test_matrix_power_of_two - tests/third_party/cupy/logic_tests/test_comparison.py::TestArrayEqual::test_array_equal_broadcast_not_allowed tests/third_party/cupy/logic_tests/test_comparison.py::TestArrayEqual::test_array_equal_diff_dtypes_is_equal tests/third_party/cupy/logic_tests/test_comparison.py::TestArrayEqual::test_array_equal_diff_dtypes_not_equal diff --git a/tests/skipped_tests_gpu.tbl b/tests/skipped_tests_gpu.tbl index cae7e25765e5..c0401fc1a6d7 100644 --- a/tests/skipped_tests_gpu.tbl +++ b/tests/skipped_tests_gpu.tbl @@ -434,10 +434,6 @@ tests/third_party/cupy/linalg_tests/test_einsum.py::TestEinSumUnaryOperationWith tests/third_party/cupy/linalg_tests/test_einsum.py::TestEinSumUnaryOperationWithScalar::test_scalar_int tests/third_party/cupy/linalg_tests/test_einsum.py::TestListArgEinSumError::test_invalid_sub1 -tests/third_party/cupy/linalg_tests/test_product.py::TestMatrixPower::test_matrix_power_invlarge -tests/third_party/cupy/linalg_tests/test_product.py::TestMatrixPower::test_matrix_power_large -tests/third_party/cupy/linalg_tests/test_product.py::TestMatrixPower::test_matrix_power_of_two - tests/third_party/cupy/logic_tests/test_comparison.py::TestArrayEqual::test_array_equal_broadcast_not_allowed tests/third_party/cupy/logic_tests/test_comparison.py::TestArrayEqual::test_array_equal_diff_dtypes_is_equal tests/third_party/cupy/logic_tests/test_comparison.py::TestArrayEqual::test_array_equal_diff_dtypes_not_equal diff --git a/tests/test_linalg.py b/tests/test_linalg.py index d42e17441001..f591de24d501 100644 --- a/tests/test_linalg.py +++ b/tests/test_linalg.py @@ -566,6 +566,56 @@ def test_inv_errors(self): assert_raises(inp.linalg.LinAlgError, inp.linalg.inv, a_dp) +class TestMatrixPower: + @pytest.mark.parametrize("dtype", get_all_dtypes()) + @pytest.mark.parametrize( + "data, power", + [ + ( + numpy.block( + [ + [numpy.eye(2), numpy.zeros((2, 2))], + [numpy.zeros((2, 2)), numpy.eye(2) * 2], + ] + ), + 3, + ), # Block-diagonal matrix + (numpy.eye(3, k=1) + numpy.eye(3), 3), # Non-diagonal matrix + ( + numpy.eye(3, k=1) + numpy.eye(3), + -3, + ), # Inverse of non-diagonal matrix + ], + ) + def test_matrix_power(self, data, power, dtype): + a = data.astype(dtype) + a_dp = inp.array(a) + + result = inp.linalg.matrix_power(a_dp, power) + expected = numpy.linalg.matrix_power(a, power) + + assert_dtype_allclose(result, expected) + + def test_matrix_power_errors(self): + a_dp = inp.eye(4, dtype="float32") + + # unsupported type `a` + a_np = inp.asnumpy(a_dp) + assert_raises(TypeError, inp.linalg.matrix_power, a_np, 2) + + # unsupported type `power` + assert_raises(TypeError, inp.linalg.matrix_power, a_dp, 1.5) + assert_raises(TypeError, inp.linalg.matrix_power, a_dp, [2]) + + # not invertible + # TODO: remove it when mkl>=2024.0 is released (MKLD-16626) + if not is_cpu_device(): + noninv = inp.array([[1, 0], [0, 0]]) + assert_raises( + inp.linalg.LinAlgError, inp.linalg.matrix_power, noninv, -1 + ) + + class TestMatrixRank: @pytest.mark.parametrize("dtype", get_all_dtypes()) @pytest.mark.parametrize( diff --git a/tests/test_sycl_queue.py b/tests/test_sycl_queue.py index f82e82f98d41..9402c393a1b4 100644 --- a/tests/test_sycl_queue.py +++ b/tests/test_sycl_queue.py @@ -1299,6 +1299,30 @@ def test_inv(shape, is_empty, device): assert_sycl_queue_equal(result_queue, expected_queue) +@pytest.mark.parametrize( + "n", + [-1, 0, 1, 2, 3], + ids=["-1", "0", "1", "2", "3"], +) +@pytest.mark.parametrize( + "device", + valid_devices, + ids=[device.filter_string for device in valid_devices], +) +def test_matrix_power(n, device): + data = numpy.array([[1, 2], [3, 5]], dtype=dpnp.default_float_type(device)) + dp_data = dpnp.array(data, device=device) + + result = dpnp.linalg.matrix_power(dp_data, n) + expected = numpy.linalg.matrix_power(data, n) + assert_dtype_allclose(result, expected) + + expected_queue = dp_data.get_array().sycl_queue + result_queue = result.get_array().sycl_queue + + assert_sycl_queue_equal(result_queue, expected_queue) + + @pytest.mark.parametrize( "data, tol", [ diff --git a/tests/test_usm_type.py b/tests/test_usm_type.py index 14898c9aee0d..7c8ce6362d84 100644 --- a/tests/test_usm_type.py +++ b/tests/test_usm_type.py @@ -924,6 +924,19 @@ def test_svd(usm_type, shape, full_matrices_param, compute_uv_param): assert x.usm_type == s.usm_type +@pytest.mark.parametrize( + "n", + [-1, 0, 1, 2, 3], + ids=["-1", "0", "1", "2", "3"], +) +@pytest.mark.parametrize("usm_type", list_of_usm_types, ids=list_of_usm_types) +def test_matrix_power(n, usm_type): + a = dp.array([[1, 2], [3, 5]], usm_type=usm_type) + + dp_res = dp.linalg.matrix_power(a, n) + assert a.usm_type == dp_res.usm_type + + @pytest.mark.parametrize( "data, tol", [ diff --git a/tests/third_party/cupy/linalg_tests/test_product.py b/tests/third_party/cupy/linalg_tests/test_product.py index 17d7861cb64e..b23efa51838d 100644 --- a/tests/third_party/cupy/linalg_tests/test_product.py +++ b/tests/third_party/cupy/linalg_tests/test_product.py @@ -430,7 +430,6 @@ def test_tensordot_zero_length(self, xp, dtype): class TestMatrixPower(unittest.TestCase): - @pytest.mark.usefixtures("allow_fall_back_on_numpy") @testing.for_all_dtypes() @testing.numpy_cupy_allclose() def test_matrix_power_0(self, xp, dtype): @@ -455,7 +454,6 @@ def test_matrix_power_3(self, xp, dtype): a = testing.shaped_arange((3, 3), xp, dtype) return xp.linalg.matrix_power(a, 3) - @pytest.mark.usefixtures("allow_fall_back_on_numpy") @testing.for_float_dtypes(no_float16=True) @testing.numpy_cupy_allclose(rtol=1e-5) def test_matrix_power_inv1(self, xp, dtype): @@ -463,7 +461,6 @@ def test_matrix_power_inv1(self, xp, dtype): a = a * a % 30 return xp.linalg.matrix_power(a, -1) - @pytest.mark.usefixtures("allow_fall_back_on_numpy") @testing.for_float_dtypes(no_float16=True) @testing.numpy_cupy_allclose(rtol=1e-5) def test_matrix_power_inv2(self, xp, dtype): @@ -471,7 +468,6 @@ def test_matrix_power_inv2(self, xp, dtype): a = a * a % 30 return xp.linalg.matrix_power(a, -2) - @pytest.mark.usefixtures("allow_fall_back_on_numpy") @testing.for_float_dtypes(no_float16=True) @testing.numpy_cupy_allclose(rtol=1e-4) def test_matrix_power_inv3(self, xp, dtype): @@ -496,3 +492,20 @@ def test_matrix_power_large(self, xp, dtype): def test_matrix_power_invlarge(self, xp, dtype): a = xp.eye(23, k=17, dtype=dtype) + xp.eye(23, k=-6, dtype=dtype) return xp.linalg.matrix_power(a, -987654321987654321) + + +@pytest.mark.parametrize( + "shape", + [ + (2, 3, 3), + (3, 0, 0), + ], +) +@pytest.mark.parametrize("n", [0, 5, -7]) +class TestMatrixPowerBatched: + @testing.for_float_dtypes(no_float16=True) + @testing.numpy_cupy_allclose(rtol=5e-5) + def test_matrix_power_batched(self, xp, dtype, shape, n): + a = testing.shaped_arange(shape, xp, dtype) + a += xp.identity(shape[-1], dtype) + return xp.linalg.matrix_power(a, n)