From 54f9f5bb2a53815301c5c79f1390c583eafd130d Mon Sep 17 00:00:00 2001 From: Vahid Tavanashad Date: Mon, 29 Jan 2024 22:11:37 -0600 Subject: [PATCH 1/5] update matmul for cupy tests --- dpnp/backend/extensions/blas/gemm.cpp | 6 +- dpnp/backend/extensions/blas/gemm.hpp | 4 +- dpnp/backend/extensions/blas/gemm_batch.cpp | 6 +- dpnp/dpnp_iface_linearalgebra.py | 97 ++++++++++++++++++- dpnp/dpnp_utils/dpnp_utils_linearalgebra.py | 49 +++++++--- tests/test_mathematical.py | 96 +++++++++++++++++- .../cupy/math_tests/test_matmul.py | 97 +++++++++++++++++++ 7 files changed, 326 insertions(+), 29 deletions(-) diff --git a/dpnp/backend/extensions/blas/gemm.cpp b/dpnp/backend/extensions/blas/gemm.cpp index a26420f49b3..e243b1c833d 100644 --- a/dpnp/backend/extensions/blas/gemm.cpp +++ b/dpnp/backend/extensions/blas/gemm.cpp @@ -46,7 +46,7 @@ namespace mkl_blas = oneapi::mkl::blas; namespace py = pybind11; namespace type_utils = dpctl::tensor::type_utils; -typedef sycl::event (*gemm_impl_fn_ptr_t)(sycl::queue, +typedef sycl::event (*gemm_impl_fn_ptr_t)(sycl::queue &, oneapi::mkl::transpose, oneapi::mkl::transpose, const std::int64_t, @@ -64,7 +64,7 @@ static gemm_impl_fn_ptr_t gemm_dispatch_table[dpctl_td_ns::num_types] [dpctl_td_ns::num_types]; template -static sycl::event gemm_impl(sycl::queue exec_q, +static sycl::event gemm_impl(sycl::queue &exec_q, oneapi::mkl::transpose transA, oneapi::mkl::transpose transB, const std::int64_t m, @@ -130,7 +130,7 @@ static sycl::event gemm_impl(sycl::queue exec_q, } std::pair - gemm(sycl::queue exec_q, + gemm(sycl::queue &exec_q, dpctl::tensor::usm_ndarray matrixA, dpctl::tensor::usm_ndarray matrixB, dpctl::tensor::usm_ndarray resultC, diff --git a/dpnp/backend/extensions/blas/gemm.hpp b/dpnp/backend/extensions/blas/gemm.hpp index 3f1ec6e745a..cd93494ce03 100644 --- a/dpnp/backend/extensions/blas/gemm.hpp +++ b/dpnp/backend/extensions/blas/gemm.hpp @@ -39,14 +39,14 @@ namespace ext namespace blas { extern std::pair - gemm(sycl::queue exec_q, + gemm(sycl::queue &exec_q, dpctl::tensor::usm_ndarray matrixA, dpctl::tensor::usm_ndarray matrixB, dpctl::tensor::usm_ndarray resultC, const std::vector &depends); extern std::pair - gemm_batch(sycl::queue exec_q, + gemm_batch(sycl::queue &exec_q, dpctl::tensor::usm_ndarray matrixA, dpctl::tensor::usm_ndarray matrixB, dpctl::tensor::usm_ndarray resultC, diff --git a/dpnp/backend/extensions/blas/gemm_batch.cpp b/dpnp/backend/extensions/blas/gemm_batch.cpp index 9359901edd8..41211fb710a 100644 --- a/dpnp/backend/extensions/blas/gemm_batch.cpp +++ b/dpnp/backend/extensions/blas/gemm_batch.cpp @@ -47,7 +47,7 @@ namespace py = pybind11; namespace type_utils = dpctl::tensor::type_utils; typedef sycl::event (*gemm_batch_impl_fn_ptr_t)( - sycl::queue, + sycl::queue &, const std::int64_t, const std::int64_t, const std::int64_t, @@ -69,7 +69,7 @@ static gemm_batch_impl_fn_ptr_t gemm_batch_dispatch_table[dpctl_td_ns::num_types][dpctl_td_ns::num_types]; template -static sycl::event gemm_batch_impl(sycl::queue exec_q, +static sycl::event gemm_batch_impl(sycl::queue &exec_q, const std::int64_t m, const std::int64_t n, const std::int64_t k, @@ -145,7 +145,7 @@ static sycl::event gemm_batch_impl(sycl::queue exec_q, } std::pair - gemm_batch(sycl::queue exec_q, + gemm_batch(sycl::queue &exec_q, dpctl::tensor::usm_ndarray matrixA, dpctl::tensor::usm_ndarray matrixB, dpctl::tensor::usm_ndarray resultC, diff --git a/dpnp/dpnp_iface_linearalgebra.py b/dpnp/dpnp_iface_linearalgebra.py index 7baca14c93b..de0d6ebabc4 100644 --- a/dpnp/dpnp_iface_linearalgebra.py +++ b/dpnp/dpnp_iface_linearalgebra.py @@ -266,18 +266,53 @@ def matmul( order="K", dtype=None, subok=True, + signature=None, + extobj=None, + axes=None, + axis=None, ): """ Matrix product of two arrays. For full documentation refer to :obj:`numpy.matmul`. + Parameters + ---------- + x1 : {dpnp_array, usm_ndarray} + First input array. + x2 : {dpnp_array, usm_ndarray} + Second input array. + out : {dpnp.ndarray, usm_ndarray}, optional + Alternative output array in which to place the result. It must have + a shape that matches the signature `(n,k),(k,m)->(n,m)` but the type + (of the calculated values) will be cast if necessary. Default: ``None``. + dtype : dtype, optional + Type to use in computing the matrix product. By default, the returned + array will have data type that is determined by considering + Promotion Type Rule and device capabilities. + casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe'}, optional + Controls what kind of data casting may occur. Default: ``"same_kind"``. + order : {"C", "F", "A", "K", None}, optional + Memory layout of the newly output array, if parameter `out` is ``None``. + Default: "K". + axes : list of tuples, optional + A list of tuples with indices of axes the matrix product should operate on. + For instance, for the signature of ``(i,j),(j,k)->(i,k)``, the base elements + are 2d matrices and these are taken to be stored in the two last axes of each + argument. The corresponding axes keyword would be [(-2, -1), (-2, -1), (-2, -1)]. + Default: ``None``. + + Returns + ------- + out : dpnp.ndarray + Returns the matrix product of the inputs. + This is a 0-d array only when both `x1`, `x2` are 1-d vectors. + Limitations ----------- - Input arrays and parameter `out` are supported as either :class:`dpnp.ndarray` - or :class:`dpctl.tensor.usm_ndarray`. - Keyword argument `subok` is currently unsupported. - Input array data types are limited by supported DPNP :ref:`Data types`. + Keyword arguments `subok`, `signature`, `extobj`, and `axis` are + only supported with their default value. + Otherwise ``NotImplementedError`` exception will be raised. See Also -------- @@ -338,8 +373,52 @@ def matmul( raise NotImplementedError( "subok keyword argument is only supported by its default value." ) + elif signature is not None: + raise NotImplementedError( + "signature keyword argument is only supported by its default value." + ) + elif extobj is not None: + raise NotImplementedError( + "extobj keyword argument is only supported by its default value." + ) + elif axis is not None: + raise NotImplementedError( + "axis keyword argument is only supported by its default value." + ) else: - return dpnp_matmul( + if axes is not None: + if not isinstance(axes, list): + raise TypeError("Axes should be a list.") + else: + if len(axes) != 3: + raise ValueError( + "Axes should be a list of three tuples for inputs and output." + ) + + for i in range(3): + if not isinstance(axes[i], tuple): + raise TypeError(f"Axes item {i} should be a tuple.") + if len(axes[i]) != 2: + raise ValueError( + f"Axes item {i} should be a tuple with 2 elements." + ) + + for j in range(2): + if not isinstance(axes[i][j], int): + raise TypeError("Axes must be an integer.") + + axes_x1, axes_x2, axes_res = axes + # Move the axes that are going to be used in matrix product, + # to the end of "x1" and "x2" + x1 = dpnp.moveaxis(x1, axes_x1, (-2, -1)) + x2 = dpnp.moveaxis(x2, axes_x2, (-2, -1)) + out_orig = out + if out is not None: + dpnp.check_supported_arrays_type(x1, x2) + # out that is passed to the backend should have the correct shape + out = dpnp.moveaxis(out, axes_res, (-2, -1)) + + result = dpnp_matmul( x1, x2, out=out, @@ -347,6 +426,14 @@ def matmul( order=order, dtype=dtype, ) + if axes is not None: + if out is result: + # out and out_orig contain the same data but they have different shape + return out_orig + # Move the result to the appropriate axes of out array + result = dpnp.moveaxis(result, (-2, -1), axes_res) + + return result def outer(x1, x2, out=None): diff --git a/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py b/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py index 3c36eda042d..2bb47832d09 100644 --- a/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py +++ b/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py @@ -116,21 +116,9 @@ def _gemm_batch_matmul(exec_q, x1, x2, res, x1_is_2D, x2_is_2D, dev_tasks_list): x2_strides = x2.strides res_strides = res.strides - # when shape along any particular dimension is 1, - # the stride along that dimension is not a - # meaningful number and is undefined. Here, we - # standardizing strides before continuing, - # setting stride to 0 if the shape along that axis is <=1 - if x1_is_2D: - x1_strides = tuple( - str_i if sh_i > 1 else 0 - for sh_i, str_i in zip(x1.shape, x1_strides) - ) - if x2_is_2D: - x2_strides = tuple( - str_i if sh_i > 1 else 0 - for sh_i, str_i in zip(x2.shape, x2_strides) - ) + # need to standardize to use in ti._contract_iter2 + x1_strides = _standardize_strides(x1_strides, x1_is_2D, x1.shape, x1.ndim) + x2_strides = _standardize_strides(x2_strides, x2_is_2D, x2.shape, x2.ndim) batch_size = res.shape[:-2][0] stridea = x1_strides[0] @@ -220,6 +208,37 @@ def _op_res_dtype(*arrays, dtype, casting, sycl_queue): return op_dtype, res_dtype +def _standardize_strides(strides, inherently_2D, shape, ndim): + """ + Standardizing the strides. + + When shape of an array along any particular dimension is 1, the stride + along that dimension is undefined. This functions standardize the strides + in the following way: + For N-D arrays that are inherently 2D (all dimesnsion are one except for two of them), + we use zero as the stride for dimensions equal one. + For other N-D arrays, the non-zero value of strides is calculated and used. + + """ + + if inherently_2D: + stndrd_strides = tuple( + str_i if sh_i > 1 else 0 for sh_i, str_i in zip(shape, strides) + ) + else: + stndrd_strides = [ + numpy.prod(shape[i + 1 :]) if strides[i] == 0 else strides[i] + for i in range(ndim - 1) + ] + # last dimension + stndrd_strides.append( + 1 if strides[ndim - 1] == 0 else strides[ndim - 1] + ) + stndrd_strides = tuple(stndrd_strides) + + return stndrd_strides + + def dpnp_dot(a, b, /, out=None, *, conjugate=False): """ Return the dot product of two arrays. diff --git a/tests/test_mathematical.py b/tests/test_mathematical.py index 12115b5256c..5e7987a571e 100644 --- a/tests/test_mathematical.py +++ b/tests/test_mathematical.py @@ -2634,6 +2634,57 @@ def test_matmul_dtype(self, dtype, shape_pair): expected = numpy.matmul(a1, a2, dtype=dtype) assert_dtype_allclose(result, expected) + @pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True)) + @pytest.mark.parametrize( + "axes", + [ + [(-3, -1), (0, 2), (-2, -3)], + [(3, 1), (2, 0), (3, 1)], + [(3, 1), (2, 0), (0, 1)], + ], + ) + def test_matmul_axes(self, dtype, axes): + a = numpy.array( + numpy.random.uniform(-10, 10, 120), dtype=dtype + ).reshape(2, 5, 3, 4) + b = numpy.array( + numpy.random.uniform(-10, 10, 120), dtype=dtype + ).reshape(4, 2, 5, 3) + ia = dpnp.array(a) + ib = dpnp.array(b) + + result = dpnp.matmul(ia, ib, axes=axes) + print(result.shape) + expected = numpy.matmul(a, b, axes=axes) + # TODO: investigate the effect of factor, see SAT-6700 + assert_dtype_allclose(result, expected, factor=24) + + @pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True)) + @pytest.mark.parametrize( + "axes, out_shape", + [ + ([(-3, -1), (0, 2), (-2, -3)], (2, 5, 5, 3)), + ([(3, 1), (2, 0), (3, 1)], (2, 4, 3, 4)), + ([(3, 1), (2, 0), (1, 2)], (2, 4, 4, 3)), + ], + ) + def test_matmul_axes_out(self, dtype, axes, out_shape): + a = numpy.array( + numpy.random.uniform(-10, 10, 120), dtype=dtype + ).reshape(2, 5, 3, 4) + b = numpy.array( + numpy.random.uniform(-10, 10, 120), dtype=dtype + ).reshape(4, 2, 5, 3) + ia = dpnp.array(a) + ib = dpnp.array(b) + + out_dp = dpnp.empty(out_shape, dtype=dtype) + result = dpnp.matmul(ia, ib, axes=axes, out=out_dp) + assert result is out_dp + expected = numpy.matmul(a, b, axes=axes) + # TODO: investigate the effect of factor, see SAT-6700 + assert_dtype_allclose(result, expected, factor=24) + @pytest.mark.parametrize("dtype1", get_all_dtypes(no_bool=True)) @pytest.mark.parametrize( "dtype2", get_all_dtypes(no_bool=True, no_none=True) @@ -2822,9 +2873,52 @@ def test_matmul_casting(self): with pytest.raises(TypeError): dpnp.matmul(a1, a2, out=res, casting="safe") - def test_matmul_subok(self): + def test_matmul_not_implemented(self): a1 = dpnp.arange(2 * 4).reshape(2, 4) a2 = dpnp.arange(4 * 3).reshape(4, 3) with pytest.raises(NotImplementedError): dpnp.matmul(a1, a2, subok=False) + + with pytest.raises(NotImplementedError): + dpnp.matmul( + a1, a2, signature=(dpnp.float32, dpnp.float32, dpnp.float32) + ) + + def custom_error_callback(err): + print("Custom error callback triggered with error:", err) + + with pytest.raises(NotImplementedError): + dpnp.matmul(a1, a2, extobj=[32, 1, custom_error_callback]) + + with pytest.raises(NotImplementedError): + dpnp.matmul(a1, a2, axis=2) + + def test_matmul_axes(self): + a1 = dpnp.arange(120).reshape(2, 5, 3, 4) + a2 = dpnp.arange(120).reshape(4, 2, 5, 3) + + # axes must be a list + axes = ((3, 1), (2, 0), (0, 1)) + with pytest.raises(TypeError): + dpnp.matmul(a1, a2, axes=axes) + + # axes must be be a list of three tuples + axes = [(3, 1), (2, 0)] + with pytest.raises(ValueError): + dpnp.matmul(a1, a2, axes=axes) + + # axes items should be a tuple + axes = [(3, 1), (2, 0), [0, 1]] + with pytest.raises(TypeError): + dpnp.matmul(a1, a2, axes=axes) + + # axes items should be a tuple with 2 elements + axes = [(3, 1), (2, 0), (0, 1, 2)] + with pytest.raises(ValueError): + dpnp.matmul(a1, a2, axes=axes) + + # axes must be an integer + axes = [(3, 1), (2, 0), (0.0, 1)] + with pytest.raises(TypeError): + dpnp.matmul(a1, a2, axes=axes) diff --git a/tests/third_party/cupy/math_tests/test_matmul.py b/tests/third_party/cupy/math_tests/test_matmul.py index 887ed9ae1b9..1298454c0cc 100644 --- a/tests/third_party/cupy/math_tests/test_matmul.py +++ b/tests/third_party/cupy/math_tests/test_matmul.py @@ -5,6 +5,7 @@ import pytest import dpnp +from tests.helper import has_support_aspect64 from tests.third_party.cupy import testing @@ -128,6 +129,15 @@ def test_overlap_both(self, xp, dtype, shape): return xp.matmul(a, a, out=a) +class TestMatmulStrides: + @testing.for_all_dtypes() + @testing.numpy_cupy_allclose(rtol=1e-3, atol=1e-3) # required for uint8 + def test_relaxed_c_contiguous_input(self, xp, dtype): + x1 = testing.shaped_arange((2, 2, 3), xp, dtype)[:, None, :, :] + x2 = testing.shaped_arange((2, 1, 3, 1), xp, dtype) + return x1 @ x2 + + @testing.parameterize( *testing.product( { @@ -186,6 +196,42 @@ def test_cupy_matmul(self, xp, dtype1): return xp.matmul(x1, x2) +@pytest.mark.parametrize( + "shape1, shape2", + [ + # TODO: include it when issue #1540 in dpctl is resolved + # ((256, 256, 3, 2), (256, 256, 2, 4)), + ((256, 256, 3, 2), (2, 4)), + ((3, 2), (256, 256, 2, 4)), + ], +) +class TestMatmulIntegralLargeBatch: + @testing.for_int_dtypes(name="dtype") + @testing.numpy_cupy_array_equal() + def test_operator_matmul(self, xp, dtype, shape1, shape2): + x1 = testing.shaped_random(shape1, xp, dtype) + x2 = testing.shaped_random(shape2, xp, dtype) + return operator.matmul(x1, x2) + + @testing.for_int_dtypes(name="dtype") + @testing.numpy_cupy_array_equal() + def test_cupy_matmul(self, xp, dtype, shape1, shape2): + x1 = testing.shaped_random(shape1, xp, dtype) + x2 = testing.shaped_random(shape2, xp, dtype) + return xp.matmul(x1, x2) + + +@pytest.mark.skip("until issue #1540 in dpctl is resolved") +class TestMatmulOverflow(unittest.TestCase): + @testing.for_int_dtypes(name="dtype", no_bool=True) + @testing.numpy_cupy_allclose(rtol=1e-3, atol=1e-3) # required for uint8 + def test_overflow(self, xp, dtype): + value = numpy.iinfo(dtype).max + a = xp.array([value - 10]).astype(dtype) + b = xp.array([value - 10]).astype(dtype) + return xp.matmul(a, b) + + @testing.parameterize( *testing.product( { @@ -210,3 +256,54 @@ def test_invalid_shape(self): x2 = testing.shaped_arange(shape2, xp, numpy.float32) with pytest.raises(ValueError): xp.matmul(x1, x2) + + +@testing.parameterize( + *testing.product( + { + "shapes_axes": [ + ( + ( + (2, 5, 3, 2, 3, 4), + (3, 5, 1, 1, 1, 4), + (5, 5, 2, 2, 3, 4), + ), + [(1, 2), (0, 1), (0, 1)], + ), + ( + ( + (2, 5, 3, 2, 3, 4), + (2, 5, 3, 1, 4, 1), + (3, 1, 2, 5, 3, 2), + ), + [(-2, -1), (-2, -1), (0, 1)], + ), + ( + ((3, 2, 4, 4), (4, 4, 3, 2), (4, 4, 3, 3)), + [(0, 1), (-1, -2), (-2, -1)], + ), + ( + ((3, 2, 4, 4), (2, 3, 4, 4), (4, 3, 3, 4)), + [(0, 1), (0, 1), (1, 2)], + ), + ], + } + ) +) +class TestMatmulAxes(unittest.TestCase): + @testing.numpy_cupy_allclose(rtol=1e-3, atol=1e-3) # required for uint8 + def test_cupy_matmul_axes(self, xp): + x1 = testing.shaped_arange(self.shapes_axes[0][0], xp) + x2 = testing.shaped_arange(self.shapes_axes[0][1], xp) + return xp.matmul(x1, x2, axes=self.shapes_axes[1]) + + @testing.numpy_cupy_allclose( + rtol=1e-3, atol=1e-3, type_check=has_support_aspect64() + ) # required for uint8 + def test_cupy_matmul_axes_out(self, xp): + x1 = testing.shaped_arange(self.shapes_axes[0][0], xp) + x2 = testing.shaped_arange(self.shapes_axes[0][1], xp) + out = xp.zeros(self.shapes_axes[0][2]) + result = xp.matmul(x1, x2, axes=self.shapes_axes[1], out=out) + assert out is result + return out From 80c867152fc3f87375e48356bfe8daa09c88151e Mon Sep 17 00:00:00 2001 From: Vahid Tavanashad Date: Mon, 12 Feb 2024 14:57:34 -0600 Subject: [PATCH 2/5] address comments --- dpnp/dpnp_iface_linearalgebra.py | 43 +------- dpnp/dpnp_utils/dpnp_utils_linearalgebra.py | 98 ++++++++++++++++++- tests/test_mathematical.py | 70 ++++++++++++- .../cupy/math_tests/test_matmul.py | 42 +++++--- 4 files changed, 190 insertions(+), 63 deletions(-) diff --git a/dpnp/dpnp_iface_linearalgebra.py b/dpnp/dpnp_iface_linearalgebra.py index de0d6ebabc4..883ce7140cc 100644 --- a/dpnp/dpnp_iface_linearalgebra.py +++ b/dpnp/dpnp_iface_linearalgebra.py @@ -386,54 +386,15 @@ def matmul( "axis keyword argument is only supported by its default value." ) else: - if axes is not None: - if not isinstance(axes, list): - raise TypeError("Axes should be a list.") - else: - if len(axes) != 3: - raise ValueError( - "Axes should be a list of three tuples for inputs and output." - ) - - for i in range(3): - if not isinstance(axes[i], tuple): - raise TypeError(f"Axes item {i} should be a tuple.") - if len(axes[i]) != 2: - raise ValueError( - f"Axes item {i} should be a tuple with 2 elements." - ) - - for j in range(2): - if not isinstance(axes[i][j], int): - raise TypeError("Axes must be an integer.") - - axes_x1, axes_x2, axes_res = axes - # Move the axes that are going to be used in matrix product, - # to the end of "x1" and "x2" - x1 = dpnp.moveaxis(x1, axes_x1, (-2, -1)) - x2 = dpnp.moveaxis(x2, axes_x2, (-2, -1)) - out_orig = out - if out is not None: - dpnp.check_supported_arrays_type(x1, x2) - # out that is passed to the backend should have the correct shape - out = dpnp.moveaxis(out, axes_res, (-2, -1)) - - result = dpnp_matmul( + return dpnp_matmul( x1, x2, out=out, casting=casting, order=order, dtype=dtype, + axes=axes, ) - if axes is not None: - if out is result: - # out and out_orig contain the same data but they have different shape - return out_orig - # Move the result to the appropriate axes of out array - result = dpnp.moveaxis(result, (-2, -1), axes_res) - - return result def outer(x1, x2, out=None): diff --git a/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py b/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py index 2bb47832d09..68bb9da2f48 100644 --- a/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py +++ b/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py @@ -27,6 +27,7 @@ import dpctl.tensor as dpt import dpctl.tensor._tensor_impl as ti import numpy +from numpy.core.numeric import normalize_axis_tuple import dpnp import dpnp.backend.extensions.blas._blas_impl as bi @@ -43,7 +44,7 @@ def _create_result_array(x1, x2, out, shape, dtype, usm_type, sycl_queue): If `out` is not ``None`` and its features match the specified `shape`, `dtype, `usm_type`, and `sycl_queue` and it is C-contiguous or F-contiguous and does not have any memory overlap with `x1` and `x2`, `out` itself is returned. - If these conditions are not statisfied, an empty array is returned with the + If these conditions are not satisfied, an empty array is returned with the specified `shape`, `dtype, `usm_type`, and `sycl_queue`. """ @@ -239,6 +240,61 @@ def _standardize_strides(strides, inherently_2D, shape, ndim): return stndrd_strides +def _validate_axes(x1, x2, axes): + """Check axes is valid for matmul function.""" + + def _validate_internal(axes, i, ndim): + if ndim == 1: + iter = 1 + if isinstance(axes, int): + axes = (axes,) + elif not isinstance(axes, tuple): + raise TypeError( + f"Axes item {i}: {type(axes)} object cannot be interpreted as an integer." + ) + + if len(axes) != 1: + raise ValueError( + f"Axes item {i} should be a tuple with a single element, or an integer." + ) + else: + iter = 2 + if not isinstance(axes, tuple): + raise TypeError(f"Axes item {i} should be a tuple.") + if len(axes) != 2: + raise ValueError( + f"Axes item {i} should be a tuple with 2 elements." + ) + + for j in range(iter): + if not isinstance(axes[j], int): + raise TypeError( + f"Axes item {i}: {type(axes[j])} object cannot be interpreted as an integer." + ) + return axes + + if not isinstance(axes, list): + raise TypeError("Axes should be a list.") + else: + if len(axes) != 3: + raise ValueError( + "Axes should be a list of three tuples for inputs and output." + ) + + axes[0] = _validate_internal(axes[0], 0, x1.ndim) + axes[1] = _validate_internal(axes[1], 1, x2.ndim) + + if x1.ndim == 1 and x2.ndim == 1: + if axes[2] != (): + raise TypeError("Axes item 2 should be an empty tuple.") + elif x1.ndim == 1 or x2.ndim == 1: + axes[2] = _validate_internal(axes[2], 2, 1) + else: + axes[2] = _validate_internal(axes[2], 2, 2) + + return axes + + def dpnp_dot(a, b, /, out=None, *, conjugate=False): """ Return the dot product of two arrays. @@ -321,6 +377,7 @@ def dpnp_matmul( casting="same_kind", order="K", dtype=None, + axes=None, ): """ Return the matrix product of two arrays. @@ -346,6 +403,22 @@ def dpnp_matmul( res_usm_type, exec_q = get_usm_allocations([x1, x2]) + if axes is not None: + axes = _validate_axes(x1, x2, axes) + + axes_x1, axes_x2, axes_res = axes + axes_x1 = normalize_axis_tuple(axes_x1, x1.ndim, "axis") + axes_x2 = normalize_axis_tuple(axes_x2, x2.ndim, "axis") + # Move the axes that are going to be used in matrix product, + # to the end of "x1" and "x2" + x1 = dpnp.moveaxis(x1, axes_x1, (-2, -1)) if x1.ndim != 1 else x1 + x2 = dpnp.moveaxis(x2, axes_x2, (-2, -1)) if x2.ndim != 1 else x2 + out_orig = out + if out is not None: + dpnp.check_supported_arrays_type(out) + # out that is passed to the backend should have the correct shape + out = dpnp.moveaxis(out, axes_res, (-2, -1)) + appended_axes = [] if x1_ndim == 1: x1 = x1[dpnp.newaxis, :] @@ -416,9 +489,15 @@ def dpnp_matmul( x2_shape = x2.shape res_shape = tuple(tmp_shape) + (x1_shape[-2], x2_shape[-1]) + # handling a special case to provide a similar result to NumPy + if out is not None and x1.shape == (1, 0) and x2.shape == (0, 1): + res_shape = (0,) + appended_axes = [] + result = _create_result_array( x1, x2, out, res_shape, gemm_dtype, res_usm_type, exec_q ) + # calculate result if result.size == 0: pass @@ -490,12 +569,25 @@ def dpnp_matmul( if gemm_dtype != res_dtype: result = dpnp.astype(result, res_dtype, copy=False) + if out is None: + if axes is not None: + # Move the result to the appropriate axes of out array + if len(axes_res) == 2: + result = dpnp.moveaxis(result, (-2, -1), axes_res) + elif len(axes_res) == 1: + result = dpnp.moveaxis(result, (-1,), axes_res) + return result # If `order` was not passed as default # we need to update it to match the passed `order`. - if order not in ["k", "K"]: + elif order not in ["k", "K"]: return dpnp.array(result, copy=False, order=order) else: return result else: - return dpnp.get_result_array(result, out, casting=casting) + result = dpnp.get_result_array(result, out, casting=casting) + if axes is not None: + if out is result: + # out and out_orig contain the same data but they have different shape + return out_orig + return result diff --git a/tests/test_mathematical.py b/tests/test_mathematical.py index 5e7987a571e..7c302dc6bc5 100644 --- a/tests/test_mathematical.py +++ b/tests/test_mathematical.py @@ -2643,7 +2643,7 @@ def test_matmul_dtype(self, dtype, shape_pair): [(3, 1), (2, 0), (0, 1)], ], ) - def test_matmul_axes(self, dtype, axes): + def test_matmul_axes_ND_ND(self, dtype, axes): a = numpy.array( numpy.random.uniform(-10, 10, 120), dtype=dtype ).reshape(2, 5, 3, 4) @@ -2654,11 +2654,55 @@ def test_matmul_axes(self, dtype, axes): ib = dpnp.array(b) result = dpnp.matmul(ia, ib, axes=axes) - print(result.shape) expected = numpy.matmul(a, b, axes=axes) # TODO: investigate the effect of factor, see SAT-6700 assert_dtype_allclose(result, expected, factor=24) + @pytest.mark.parametrize( + "axes", + [ + [(1, 0), (0), (0)], + [(1, 0), 0, 0], + [(1, 0), (0,), (0,)], + ], + ) + def test_matmul_axes_ND_1D(self, axes): + a = numpy.arange(3 * 4 * 5).reshape(3, 4, 5) + b = numpy.arange(3) + ia = dpnp.array(a) + ib = dpnp.array(b) + + result = dpnp.matmul(ia, ib, axes=axes) + expected = numpy.matmul(a, b, axes=axes) + assert_dtype_allclose(result, expected) + + @pytest.mark.parametrize( + "axes", + [ + [(0,), (0, 1), (0)], + [(0), (0, 1), 0], + [0, (0, 1), (0,)], + ], + ) + def test_matmul_axes_1D_ND(self, axes): + a = numpy.arange(3 * 4 * 5).reshape(3, 4, 5) + b = numpy.arange(3) + ia = dpnp.array(a) + ib = dpnp.array(b) + + result = dpnp.matmul(ib, ia, axes=axes) + expected = numpy.matmul(b, a, axes=axes) + assert_dtype_allclose(result, expected) + + def test_matmul_axes_1D_1D(self): + a = numpy.arange(3) + ia = dpnp.array(a) + + axes = [0, 0, ()] + result = dpnp.matmul(ia, ia, axes=axes) + expected = numpy.matmul(a, a, axes=axes) + assert_dtype_allclose(result, expected) + @pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True)) @pytest.mark.parametrize( "axes, out_shape", @@ -2908,12 +2952,12 @@ def test_matmul_axes(self): with pytest.raises(ValueError): dpnp.matmul(a1, a2, axes=axes) - # axes items should be a tuple + # axes item should be a tuple axes = [(3, 1), (2, 0), [0, 1]] with pytest.raises(TypeError): dpnp.matmul(a1, a2, axes=axes) - # axes items should be a tuple with 2 elements + # axes item should be a tuple with 2 elements axes = [(3, 1), (2, 0), (0, 1, 2)] with pytest.raises(ValueError): dpnp.matmul(a1, a2, axes=axes) @@ -2922,3 +2966,21 @@ def test_matmul_axes(self): axes = [(3, 1), (2, 0), (0.0, 1)] with pytest.raises(TypeError): dpnp.matmul(a1, a2, axes=axes) + + # axes item 2 should be an empty tuple + a = dpnp.arange(3) + axes = [0, 0, 0] + with pytest.raises(TypeError): + dpnp.matmul(a, a, axes=axes) + + a = dpnp.arange(3 * 4 * 5).reshape(3, 4, 5) + b = dpnp.arange(3) + # list object cannot be interpreted as an integer + axes = [(1, 0), (0), [0]] + with pytest.raises(TypeError): + dpnp.matmul(a, b, axes=axes) + + # axes item should be a tuple with a single element, or an integer + axes = [(1, 0), (0), (0, 1)] + with pytest.raises(ValueError): + dpnp.matmul(a, b, axes=axes) diff --git a/tests/third_party/cupy/math_tests/test_matmul.py b/tests/third_party/cupy/math_tests/test_matmul.py index 1298454c0cc..0a3455768a3 100644 --- a/tests/third_party/cupy/math_tests/test_matmul.py +++ b/tests/third_party/cupy/math_tests/test_matmul.py @@ -4,7 +4,7 @@ import numpy import pytest -import dpnp +import dpnp as cupy from tests.helper import has_support_aspect64 from tests.third_party.cupy import testing @@ -60,17 +60,23 @@ ) class TestMatmul(unittest.TestCase): @testing.for_all_dtypes(name="dtype1") - @testing.numpy_cupy_allclose(rtol=1e-3, atol=1e-3) # required for uint8 - def test_operator_matmul(self, xp, dtype1): + @testing.for_all_dtypes(name="dtype2") + @testing.numpy_cupy_allclose( + rtol=1e-3, atol=1e-3, type_check=has_support_aspect64() + ) # required for uint8 + def test_operator_matmul(self, xp, dtype1, dtype2): x1 = testing.shaped_arange(self.shape_pair[0], xp, dtype1) - x2 = testing.shaped_arange(self.shape_pair[1], xp, dtype1) + x2 = testing.shaped_arange(self.shape_pair[1], xp, dtype2) return operator.matmul(x1, x2) @testing.for_all_dtypes(name="dtype1") - @testing.numpy_cupy_allclose(rtol=1e-3, atol=1e-3) # required for uint8 - def test_cupy_matmul(self, xp, dtype1): + @testing.for_all_dtypes(name="dtype2") + @testing.numpy_cupy_allclose( + rtol=1e-3, atol=1e-3, type_check=has_support_aspect64() + ) # required for uint8 + def test_cupy_matmul(self, xp, dtype1, dtype2): x1 = testing.shaped_arange(self.shape_pair[0], xp, dtype1) - x2 = testing.shaped_arange(self.shape_pair[1], xp, dtype1) + x2 = testing.shaped_arange(self.shape_pair[1], xp, dtype2) return xp.matmul(x1, x2) @@ -80,7 +86,7 @@ def test_cupy_matmul(self, xp, dtype1): "shape_pair": [ # dot test ((2, 3), (3, 4), (2, 4)), - # ((0,), (0,), (0,)), + ((0,), (0,), (0,)), # matmul test ((5, 3, 2), (5, 2, 4), (5, 3, 4)), ((0, 3, 2), (0, 2, 4), (0, 3, 4)), @@ -171,20 +177,26 @@ class TestMatmulLarge(unittest.TestCase): } @testing.for_all_dtypes(name="dtype1") - @testing.numpy_cupy_allclose(rtol=1e-3, atol=1e-3) # required for uint8 - def test_operator_matmul(self, xp, dtype1): + @testing.for_all_dtypes(name="dtype2") + @testing.numpy_cupy_allclose( + rtol=1e-3, atol=1e-3, type_check=has_support_aspect64() + ) # required for uint8 + def test_operator_matmul(self, xp, dtype1, dtype2): if (dtype1, dtype1) in self.skip_dtypes or ( dtype1, dtype1, ) in self.skip_dtypes: return xp.array([]) x1 = testing.shaped_random(self.shape_pair[0], xp, dtype1) - x2 = testing.shaped_random(self.shape_pair[1], xp, dtype1) + x2 = testing.shaped_random(self.shape_pair[1], xp, dtype2) return operator.matmul(x1, x2) @testing.for_all_dtypes(name="dtype1") - @testing.numpy_cupy_allclose(rtol=1e-3, atol=1e-3) # required for uint8 - def test_cupy_matmul(self, xp, dtype1): + @testing.for_all_dtypes(name="dtype2") + @testing.numpy_cupy_allclose( + rtol=1e-3, atol=1e-3, type_check=has_support_aspect64() + ) # required for uint8 + def test_cupy_matmul(self, xp, dtype1, dtype2): if (dtype1, dtype1) in self.skip_dtypes or ( dtype1, dtype1, @@ -192,7 +204,7 @@ def test_cupy_matmul(self, xp, dtype1): return xp.array([]) shape1, shape2 = self.shape_pair x1 = testing.shaped_random(shape1, xp, dtype1) - x2 = testing.shaped_random(shape2, xp, dtype1) + x2 = testing.shaped_random(shape2, xp, dtype2) return xp.matmul(x1, x2) @@ -250,7 +262,7 @@ def test_overflow(self, xp, dtype): ) class TestMatmulInvalidShape(unittest.TestCase): def test_invalid_shape(self): - for xp in (numpy, dpnp): + for xp in (numpy, cupy): shape1, shape2 = self.shape_pair x1 = testing.shaped_arange(shape1, xp, numpy.float32) x2 = testing.shaped_arange(shape2, xp, numpy.float32) From d6003632ede990d75b0528bf98154fd0a5871a33 Mon Sep 17 00:00:00 2001 From: Vahid Tavanashad Date: Tue, 13 Feb 2024 08:42:11 -0600 Subject: [PATCH 3/5] address more comments --- dpnp/dpnp_utils/dpnp_utils_linearalgebra.py | 66 ++++++++++++++++--- tests/test_mathematical.py | 63 ++++++++++++++++++ .../cupy/math_tests/test_matmul.py | 16 ++--- 3 files changed, 129 insertions(+), 16 deletions(-) diff --git a/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py b/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py index 68bb9da2f48..5a4d5a36c1b 100644 --- a/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py +++ b/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py @@ -58,7 +58,7 @@ def _create_result_array(x1, x2, out, shape, dtype, usm_type, sycl_queue): and out.shape == shape and out.usm_type == usm_type and out.sycl_queue == sycl_queue - and (out.flags.c_contiguous or out.flags.f_contiguous) + and out.flags.c_contiguous and not ti._array_overlap(x1_usm, out_usm) and not ti._array_overlap(x2_usm, out_usm) ): @@ -417,7 +417,10 @@ def dpnp_matmul( if out is not None: dpnp.check_supported_arrays_type(out) # out that is passed to the backend should have the correct shape - out = dpnp.moveaxis(out, axes_res, (-2, -1)) + if len(axes_res) == 2: + out = dpnp.moveaxis(out, axes_res, (-2, -1)) + elif len(axes_res) == 1: + out = dpnp.moveaxis(out, axes_res, (-1,)) appended_axes = [] if x1_ndim == 1: @@ -439,6 +442,37 @@ def dpnp_matmul( f"(size {x1_shape[-1]} is different from {x2_shape[-2]})" ) + if out is not None: + out_shape = out.shape + if not appended_axes: + if out_shape[-2] != x1_shape[-2]: + raise ValueError( + "Output array has a mismatch in its core dimension 0. " + "The core dimensions should follow this signature: (n?,k),(k,m?)->(n?,m?) " + f"(size {out_shape[-2]} is different from {x1_shape[-2]})" + ) + if out_shape[-1] != x2_shape[-1]: + raise ValueError( + "Output array has a mismatch in its core dimension 1. " + "The core dimensions should follow this signature: (n?,k),(k,m?)->(n?,m?) " + f"(size {out_shape[-1]} is different from {x2_shape[-1]})" + ) + elif len(appended_axes) == 1: + if appended_axes[0] == -1: + if out_shape[-1] != x1_shape[-2]: + raise ValueError( + "Output array has a mismatch in its core dimension 0. " + "The core dimensions should follow this signature: (n?,k),(k,m?)->(n?,m?) " + f"(size {out_shape[-1]} is different from {x1_shape[-2]})" + ) + elif appended_axes[0] == -2: + if out_shape[-1] != x2_shape[-1]: + raise ValueError( + "Output array has a mismatch in its core dimension 0. " + "The core dimensions should follow this signature: (n?,k),(k,m?)->(n?,m?) " + f"(size {out_shape[-1]} is different from {x2_shape[-1]})" + ) + # Determine the appropriate data types gemm_dtype, res_dtype = _op_res_dtype( x1, x2, dtype=dtype, casting=casting, sycl_queue=exec_q @@ -483,10 +517,25 @@ def dpnp_matmul( x2 = dpnp.repeat(x2, x1_shape[i], axis=i) else: raise ValueError( - "arrays could not be broadcast together with remapped shapes." + "Input arrays could not be broadcast together with remapped shapes, " + f"{x1_shape[:-2]} is different from {x2_shape[:-2]}." ) + x1_shape = x1.shape x2_shape = x2.shape + if out is not None: + for i in range(x1_ndim - 2): + if x1_shape[i] != out_shape[i]: + if not appended_axes: + raise ValueError( + "Output array could not be broadcast together with remapped shapes, " + f"{x1_shape[:-2]} is different from {out_shape[:-2]}." + ) + elif len(appended_axes) == 1: + raise ValueError( + "Output array could not be broadcast together with remapped shapes, " + f"{x1_shape[:-2]} is different from {out_shape[:-1]}." + ) res_shape = tuple(tmp_shape) + (x1_shape[-2], x2_shape[-1]) # handling a special case to provide a similar result to NumPy @@ -559,6 +608,8 @@ def dpnp_matmul( if appended_axes: result = dpnp.squeeze(result, tuple(appended_axes)) + if len(appended_axes) == 2 and out is not None: + result = dpnp.tile(result, out.shape) if x1_is_2D and x2_is_2D: # add new axes only if one of the input arrays @@ -572,7 +623,7 @@ def dpnp_matmul( if out is None: if axes is not None: - # Move the result to the appropriate axes of out array + # Move the data to the appropriate axes of the result array if len(axes_res) == 2: result = dpnp.moveaxis(result, (-2, -1), axes_res) elif len(axes_res) == 1: @@ -586,8 +637,7 @@ def dpnp_matmul( return result else: result = dpnp.get_result_array(result, out, casting=casting) - if axes is not None: - if out is result: - # out and out_orig contain the same data but they have different shape - return out_orig + if axes is not None and out is result: + # out and out_orig contain the same data but they have different shape + return out_orig return result diff --git a/tests/test_mathematical.py b/tests/test_mathematical.py index 7c302dc6bc5..d0fb833b6d7 100644 --- a/tests/test_mathematical.py +++ b/tests/test_mathematical.py @@ -2729,6 +2729,29 @@ def test_matmul_axes_out(self, dtype, axes, out_shape): # TODO: investigate the effect of factor, see SAT-6700 assert_dtype_allclose(result, expected, factor=24) + @pytest.mark.parametrize( + "axes, b_shape, out_shape", + [ + ([(1, 0), 0, 0], (3,), (4, 5)), + ([(1, 0), 0, 1], (3,), (5, 4)), + ([(1, 0), (0, 1), (1, 2)], (3, 1), (5, 4, 1)), + ([(1, 0), (0, 1), (0, 2)], (3, 1), (4, 5, 1)), + ([(1, 0), (0, 1), (1, 0)], (3, 1), (1, 4, 5)), + ], + ) + def test_matmul_axes_out_1D(self, axes, b_shape, out_shape): + a = numpy.arange(3 * 4 * 5).reshape(3, 4, 5) + b = numpy.arange(3).reshape(b_shape) + ia = dpnp.array(a) + ib = dpnp.array(b) + + out_dp = dpnp.empty(out_shape) + out_np = numpy.empty(out_shape) + result = dpnp.matmul(ia, ib, axes=axes, out=out_dp) + assert result is out_dp + expected = numpy.matmul(a, b, axes=axes, out=out_np) + assert_dtype_allclose(result, expected) + @pytest.mark.parametrize("dtype1", get_all_dtypes(no_bool=True)) @pytest.mark.parametrize( "dtype2", get_all_dtypes(no_bool=True, no_none=True) @@ -2855,6 +2878,25 @@ def test_matmul_out(self, dtype): assert result is dpnp_out assert_dtype_allclose(result, expected) + @pytest.mark.parametrize( + "out_shape", + [ + ((4, 5)), + ((6,)), + ((4, 7, 2)), + ], + ) + def test_matmul_out_0D(self, out_shape): + a = numpy.arange(3) + b = dpnp.asarray(a) + + numpy_out = numpy.empty(out_shape) + dpnp_out = dpnp.empty(out_shape) + result = dpnp.matmul(b, b, out=dpnp_out) + expected = numpy.matmul(a, a, out=numpy_out) + assert result is dpnp_out + assert_dtype_allclose(result, expected) + class TestMatmulInvalidCases: @pytest.mark.parametrize( @@ -2892,6 +2934,27 @@ def test_invalid_shape(self, shape_pair): with pytest.raises(ValueError): xp.matmul(x1, x2) + @pytest.mark.parametrize( + "shape_pair", + [ + ((5, 4, 3), (3, 1), (3, 4, 1)), + ((5, 4, 3), (3, 1), (5, 6, 1)), + ((5, 4, 3), (3, 1), (5, 4, 2)), + ((5, 4, 3), (3,), (5, 3)), + ((5, 4, 3), (3,), (6, 4)), + ((3,), (3, 4, 5), (3, 5)), + ((3,), (3, 4, 5), (4, 6)), + ], + ) + def test_invalid_shape_out(self, shape_pair): + for xp in (numpy, dpnp): + shape1, shape2, out_shape = shape_pair + x1 = xp.arange(numpy.prod(shape1), dtype=xp.float32).reshape(shape1) + x2 = xp.arange(numpy.prod(shape2), dtype=xp.float32).reshape(shape2) + res = xp.empty(out_shape) + with pytest.raises(ValueError): + xp.matmul(x1, x2, out=res) + @pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True)[:-2]) def test_invalid_dtype(self, dtype): dpnp_dtype = get_all_dtypes(no_none=True)[-1] diff --git a/tests/third_party/cupy/math_tests/test_matmul.py b/tests/third_party/cupy/math_tests/test_matmul.py index 0a3455768a3..d1ebd718a61 100644 --- a/tests/third_party/cupy/math_tests/test_matmul.py +++ b/tests/third_party/cupy/math_tests/test_matmul.py @@ -182,11 +182,11 @@ class TestMatmulLarge(unittest.TestCase): rtol=1e-3, atol=1e-3, type_check=has_support_aspect64() ) # required for uint8 def test_operator_matmul(self, xp, dtype1, dtype2): - if (dtype1, dtype1) in self.skip_dtypes or ( - dtype1, + if (dtype1, dtype2) in self.skip_dtypes or ( + dtype2, dtype1, ) in self.skip_dtypes: - return xp.array([]) + pytest.skip() x1 = testing.shaped_random(self.shape_pair[0], xp, dtype1) x2 = testing.shaped_random(self.shape_pair[1], xp, dtype2) return operator.matmul(x1, x2) @@ -197,11 +197,11 @@ def test_operator_matmul(self, xp, dtype1, dtype2): rtol=1e-3, atol=1e-3, type_check=has_support_aspect64() ) # required for uint8 def test_cupy_matmul(self, xp, dtype1, dtype2): - if (dtype1, dtype1) in self.skip_dtypes or ( - dtype1, + if (dtype1, dtype2) in self.skip_dtypes or ( + dtype2, dtype1, ) in self.skip_dtypes: - return xp.array([]) + pytest.skip() shape1, shape2 = self.shape_pair x1 = testing.shaped_random(shape1, xp, dtype1) x2 = testing.shaped_random(shape2, xp, dtype2) @@ -211,7 +211,7 @@ def test_cupy_matmul(self, xp, dtype1, dtype2): @pytest.mark.parametrize( "shape1, shape2", [ - # TODO: include it when issue #1540 in dpctl is resolved + # the first one causes overflow which is undefined behavior # ((256, 256, 3, 2), (256, 256, 2, 4)), ((256, 256, 3, 2), (2, 4)), ((3, 2), (256, 256, 2, 4)), @@ -233,7 +233,7 @@ def test_cupy_matmul(self, xp, dtype, shape1, shape2): return xp.matmul(x1, x2) -@pytest.mark.skip("until issue #1540 in dpctl is resolved") +@pytest.mark.skip("overflow is undefined behavior.") class TestMatmulOverflow(unittest.TestCase): @testing.for_int_dtypes(name="dtype", no_bool=True) @testing.numpy_cupy_allclose(rtol=1e-3, atol=1e-3) # required for uint8 From ad0ba70d211b53513fd983702eb001a7d20c9328 Mon Sep 17 00:00:00 2001 From: Vahid Tavanashad Date: Tue, 13 Feb 2024 12:49:47 -0600 Subject: [PATCH 4/5] fix an error --- dpnp/dpnp_utils/dpnp_utils_linearalgebra.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py b/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py index 5a4d5a36c1b..57bc7a52346 100644 --- a/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py +++ b/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py @@ -525,16 +525,16 @@ def dpnp_matmul( x2_shape = x2.shape if out is not None: for i in range(x1_ndim - 2): - if x1_shape[i] != out_shape[i]: + if tmp_shape[i] != out_shape[i]: if not appended_axes: raise ValueError( "Output array could not be broadcast together with remapped shapes, " - f"{x1_shape[:-2]} is different from {out_shape[:-2]}." + f"{tmp_shape[:-2]} is different from {out_shape[:-2]}." ) elif len(appended_axes) == 1: raise ValueError( "Output array could not be broadcast together with remapped shapes, " - f"{x1_shape[:-2]} is different from {out_shape[:-1]}." + f"{tmp_shape[:-2]} is different from {out_shape[:-1]}." ) res_shape = tuple(tmp_shape) + (x1_shape[-2], x2_shape[-1]) From 5060adc1ff7f3572d512b9df81a37c6beb0bf799 Mon Sep 17 00:00:00 2001 From: Vahid Tavanashad Date: Tue, 13 Feb 2024 15:36:46 -0600 Subject: [PATCH 5/5] use a function for error msg --- dpnp/dpnp_utils/dpnp_utils_linearalgebra.py | 71 ++++++++++----------- 1 file changed, 34 insertions(+), 37 deletions(-) diff --git a/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py b/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py index 57bc7a52346..bb960b7f4b4 100644 --- a/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py +++ b/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py @@ -209,6 +209,31 @@ def _op_res_dtype(*arrays, dtype, casting, sycl_queue): return op_dtype, res_dtype +def _shape_error(a, b, core_dim, err_msg): + if err_msg == 0: + raise ValueError( + "Input arrays have a mismatch in their core dimensions. " + "The core dimensions should follow this signature: (n?,k),(k,m?)->(n?,m?) " + f"(size {a} is different from {b})" + ) + elif err_msg == 1: + raise ValueError( + f"Output array has a mismatch in its core dimension {core_dim}. " + "The core dimensions should follow this signature: (n?,k),(k,m?)->(n?,m?) " + f"(size {a} is different from {b})" + ) + elif err_msg == 2: + raise ValueError( + "Input arrays could not be broadcast together with remapped shapes, " + f"{a} is different from {b}." + ) + elif err_msg == 3: + raise ValueError( + "Output array could not be broadcast to input arrays with remapped shapes, " + f"{a} is different from {b}." + ) + + def _standardize_strides(strides, inherently_2D, shape, ndim): """ Standardizing the strides. @@ -436,42 +461,22 @@ def dpnp_matmul( x1_shape = x1.shape x2_shape = x2.shape if x1_shape[-1] != x2_shape[-2]: - raise ValueError( - "Input arrays have a mismatch in their core dimensions. " - "The core dimensions should follow this signature: (n?,k),(k,m?)->(n?,m?) " - f"(size {x1_shape[-1]} is different from {x2_shape[-2]})" - ) + _shape_error(x1_shape[-1], x2_shape[-2], None, 0) if out is not None: out_shape = out.shape if not appended_axes: if out_shape[-2] != x1_shape[-2]: - raise ValueError( - "Output array has a mismatch in its core dimension 0. " - "The core dimensions should follow this signature: (n?,k),(k,m?)->(n?,m?) " - f"(size {out_shape[-2]} is different from {x1_shape[-2]})" - ) + _shape_error(out_shape[-2], x1_shape[-2], 0, 1) if out_shape[-1] != x2_shape[-1]: - raise ValueError( - "Output array has a mismatch in its core dimension 1. " - "The core dimensions should follow this signature: (n?,k),(k,m?)->(n?,m?) " - f"(size {out_shape[-1]} is different from {x2_shape[-1]})" - ) + _shape_error(out_shape[-1], x2_shape[-1], 1, 1) elif len(appended_axes) == 1: if appended_axes[0] == -1: if out_shape[-1] != x1_shape[-2]: - raise ValueError( - "Output array has a mismatch in its core dimension 0. " - "The core dimensions should follow this signature: (n?,k),(k,m?)->(n?,m?) " - f"(size {out_shape[-1]} is different from {x1_shape[-2]})" - ) + _shape_error(out_shape[-1], x1_shape[-2], 0, 1) elif appended_axes[0] == -2: if out_shape[-1] != x2_shape[-1]: - raise ValueError( - "Output array has a mismatch in its core dimension 0. " - "The core dimensions should follow this signature: (n?,k),(k,m?)->(n?,m?) " - f"(size {out_shape[-1]} is different from {x2_shape[-1]})" - ) + _shape_error(out_shape[-1], x2_shape[-1], 0, 1) # Determine the appropriate data types gemm_dtype, res_dtype = _op_res_dtype( @@ -516,10 +521,7 @@ def dpnp_matmul( if not x2_is_2D: x2 = dpnp.repeat(x2, x1_shape[i], axis=i) else: - raise ValueError( - "Input arrays could not be broadcast together with remapped shapes, " - f"{x1_shape[:-2]} is different from {x2_shape[:-2]}." - ) + _shape_error(x1_shape[:-2], x2_shape[:-2], None, 2) x1_shape = x1.shape x2_shape = x2.shape @@ -527,15 +529,10 @@ def dpnp_matmul( for i in range(x1_ndim - 2): if tmp_shape[i] != out_shape[i]: if not appended_axes: - raise ValueError( - "Output array could not be broadcast together with remapped shapes, " - f"{tmp_shape[:-2]} is different from {out_shape[:-2]}." - ) + _shape_error(tuple(tmp_shape), out_shape[:-2], None, 3) elif len(appended_axes) == 1: - raise ValueError( - "Output array could not be broadcast together with remapped shapes, " - f"{tmp_shape[:-2]} is different from {out_shape[:-1]}." - ) + _shape_error(tuple(tmp_shape), out_shape[:-1], None, 3) + res_shape = tuple(tmp_shape) + (x1_shape[-2], x2_shape[-1]) # handling a special case to provide a similar result to NumPy