diff --git a/dpnp/backend/extensions/blas/CMakeLists.txt b/dpnp/backend/extensions/blas/CMakeLists.txt
index d19f60c9792..fe3a92d2181 100644
--- a/dpnp/backend/extensions/blas/CMakeLists.txt
+++ b/dpnp/backend/extensions/blas/CMakeLists.txt
@@ -1,5 +1,5 @@
 # *****************************************************************************
-# Copyright (c) 2016-2023, Intel Corporation
+# Copyright (c) 2024, Intel Corporation
 # All rights reserved.
 #
 # Redistribution and use in source and binary forms, with or without
@@ -27,6 +27,8 @@ set(python_module_name _blas_impl)

 set(_module_src
     ${CMAKE_CURRENT_SOURCE_DIR}/blas_py.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/dot.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/dotu.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/gemm.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/gemm_batch.cpp
 )
diff --git a/dpnp/backend/extensions/blas/blas_py.cpp b/dpnp/backend/extensions/blas/blas_py.cpp
index 524f16fcc7d..7d5237381b1 100644
--- a/dpnp/backend/extensions/blas/blas_py.cpp
+++ b/dpnp/backend/extensions/blas/blas_py.cpp
@@ -1,5 +1,5 @@
 //*****************************************************************************
-// Copyright (c) 2023, Intel Corporation
+// Copyright (c) 2024, Intel Corporation
 // All rights reserved.
 //
 // Redistribution and use in source and binary forms, with or without
@@ -30,6 +30,7 @@
 #include <pybind11/pybind11.h>
 #include <pybind11/stl.h>

+#include "dot.hpp"
 #include "gemm.hpp"

 namespace blas_ext = dpnp::backend::ext::blas;
@@ -38,6 +39,8 @@ namespace py = pybind11;
 // populate dispatch tables
 void init_dispatch_tables(void)
 {
+    blas_ext::init_dot_dispatch_table();
+    blas_ext::init_dotu_dispatch_table();
     blas_ext::init_gemm_batch_dispatch_table();
     blas_ext::init_gemm_dispatch_table();
 }
@@ -46,6 +49,22 @@ PYBIND11_MODULE(_blas_impl, m)
 {
     init_dispatch_tables();

+    {
+        m.def("_dot", &blas_ext::dot,
+              "Call `dot` from OneMKL BLAS library to return "
+              "the dot product of two real-valued vectors.",
+              py::arg("sycl_queue"), py::arg("vectorA"), py::arg("vectorB"),
+              py::arg("result"), py::arg("depends") = py::list());
+    }
+
+    {
+        m.def("_dotu", &blas_ext::dotu,
+              "Call `dotu` from OneMKL BLAS library to return "
+              "the dot product of two complex vectors.",
+              py::arg("sycl_queue"), py::arg("vectorA"), py::arg("vectorB"),
+              py::arg("result"), py::arg("depends") = py::list());
+    }
+
     {
         m.def("_gemm", &blas_ext::gemm,
               "Call `gemm` from OneMKL LAPACK library to return "
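For orientation, the two new bindings surface in Python as `_blas_impl._dot` and `_blas_impl._dotu`, each returning a pair of SYCL events (host task, compute event). A minimal sketch of a direct call, with made-up array values; `dpnp.dot` further below is the supported entry point, this is illustrative only:

    import dpnp
    import dpnp.backend.extensions.blas._blas_impl as bi

    a = dpnp.array([1.0, 2.0, 3.0])
    b = dpnp.array([4.0, 5.0, 6.0])
    # the result of a 1-D x 1-D dot product is a 0-d array
    result = dpnp.empty((), dtype=a.dtype, sycl_queue=a.sycl_queue)

    ht_ev, dot_ev = bi._dot(
        a.sycl_queue,
        dpnp.get_usm_ndarray(a),
        dpnp.get_usm_ndarray(b),
        dpnp.get_usm_ndarray(result),
        [],  # no dependent events
    )
    ht_ev.wait()       # host event keeps a, b and result alive until done
    print(result)      # array(32.)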
diff --git a/dpnp/backend/extensions/blas/dot.cpp b/dpnp/backend/extensions/blas/dot.cpp
new file mode 100644
index 00000000000..048738f57a9
--- /dev/null
+++ b/dpnp/backend/extensions/blas/dot.cpp
@@ -0,0 +1,238 @@
+//*****************************************************************************
+// Copyright (c) 2024, Intel Corporation
+// All rights reserved.
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+// - Redistributions of source code must retain the above copyright notice,
+//   this list of conditions and the following disclaimer.
+// - Redistributions in binary form must reproduce the above copyright notice,
+//   this list of conditions and the following disclaimer in the documentation
+//   and/or other materials provided with the distribution.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
+// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
+// THE POSSIBILITY OF SUCH DAMAGE.
+//*****************************************************************************
+
+#include <pybind11/pybind11.h>
+
+// dpctl tensor headers
+#include "utils/memory_overlap.hpp"
+#include "utils/type_utils.hpp"
+
+#include "dot.hpp"
+#include "types_matrix.hpp"
+
+#include "dpnp_utils.hpp"
+
+namespace dpnp
+{
+namespace backend
+{
+namespace ext
+{
+namespace blas
+{
+namespace mkl_blas = oneapi::mkl::blas;
+namespace py = pybind11;
+namespace type_utils = dpctl::tensor::type_utils;
+
+typedef sycl::event (*dot_impl_fn_ptr_t)(sycl::queue &,
+                                         const std::int64_t,
+                                         char *,
+                                         const std::int64_t,
+                                         char *,
+                                         const std::int64_t,
+                                         char *,
+                                         const std::vector<sycl::event> &);
+
+static dot_impl_fn_ptr_t dot_dispatch_table[dpctl_td_ns::num_types]
+                                           [dpctl_td_ns::num_types];
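+
+// The dispatch table is indexed by the dpctl type-lookup ids of the input
+// vectors and of the result array; (input, result) pairs that oneMKL `dot`
+// does not support are left as nullptr by DotContigFactory below and are
+// rejected in dot() with a py::value_error.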
+
+template <typename Tab, typename Tc>
+static sycl::event dot_impl(sycl::queue &exec_q,
+                            const std::int64_t n,
+                            char *vectorA,
+                            const std::int64_t stride_a,
+                            char *vectorB,
+                            const std::int64_t stride_b,
+                            char *result,
+                            const std::vector<sycl::event> &depends)
+{
+    type_utils::validate_type_for_device<Tab>(exec_q);
+    type_utils::validate_type_for_device<Tc>(exec_q);
+
+    Tab *a = reinterpret_cast<Tab *>(vectorA);
+    Tab *b = reinterpret_cast<Tab *>(vectorB);
+    Tc *res = reinterpret_cast<Tc *>(result);
+
+    std::stringstream error_msg;
+    bool is_exception_caught = false;
+
+    sycl::event dot_event;
+    try {
+        dot_event = mkl_blas::row_major::dot(exec_q,
+                                             n, // size of the input vectors
+                                             a, // Pointer to vector a.
+                                             stride_a, // Stride of vector a.
+                                             b, // Pointer to vector b.
+                                             stride_b, // Stride of vector b.
+                                             res, // Pointer to result.
+                                             depends);
+    } catch (oneapi::mkl::exception const &e) {
+        error_msg
+            << "Unexpected MKL exception caught during dot() call:\nreason: "
+            << e.what();
+        is_exception_caught = true;
+    } catch (sycl::exception const &e) {
+        error_msg << "Unexpected SYCL exception caught during dot() call:\n"
+                  << e.what();
+        is_exception_caught = true;
+    }
+
+    if (is_exception_caught) // an unexpected error occurs
+    {
+        throw std::runtime_error(error_msg.str());
+    }
+
+    return dot_event;
+}
+
+std::pair<sycl::event, sycl::event>
+    dot(sycl::queue &exec_q,
+        dpctl::tensor::usm_ndarray vectorA,
+        dpctl::tensor::usm_ndarray vectorB,
+        dpctl::tensor::usm_ndarray result,
+        const std::vector<sycl::event> &depends)
+{
+    const int vectorA_nd = vectorA.get_ndim();
+    const int vectorB_nd = vectorB.get_ndim();
+    const int result_nd = result.get_ndim();
+
+    if ((vectorA_nd != 1)) {
+        throw py::value_error(
+            "The first input array has ndim=" + std::to_string(vectorA_nd) +
+            ", but a 1-dimensional array is expected.");
+    }
+
+    if ((vectorB_nd != 1)) {
+        throw py::value_error(
+            "The second input array has ndim=" + std::to_string(vectorB_nd) +
+            ", but a 1-dimensional array is expected.");
+    }
+
+    if ((result_nd != 0)) {
+        throw py::value_error(
+            "The output array has ndim=" + std::to_string(result_nd) +
+            ", but a 0-dimensional array is expected.");
+    }
+
+    auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
+    if (overlap(vectorA, result)) {
+        throw py::value_error(
+            "The first input array and output array are overlapping "
+            "segments of memory");
+    }
+    if (overlap(vectorB, result)) {
+        throw py::value_error(
+            "The second input array and output array are overlapping "
+            "segments of memory");
+    }
+
+    // check compatibility of execution queue and allocation queue
+    if (!dpctl::utils::queues_are_compatible(
+            exec_q,
+            {vectorA.get_queue(), vectorB.get_queue(), result.get_queue()}))
+    {
+        throw py::value_error(
+            "USM allocations are not compatible with the execution queue.");
+    }
+
+    py::ssize_t a_size = vectorA.get_size();
+    py::ssize_t b_size = vectorB.get_size();
+    if (a_size != b_size) {
+        throw py::value_error("The size of the first input array must be "
+                              "equal to the size of the second input array.");
+    }
+
+    std::vector<py::ssize_t> a_stride = vectorA.get_strides_vector();
+    std::vector<py::ssize_t> b_stride = vectorB.get_strides_vector();
+
+    const std::int64_t n = a_size;
+    const std::int64_t str_a = a_stride[0];
+    const std::int64_t str_b = b_stride[0];
+
+    int vectorA_typenum = vectorA.get_typenum();
+    int vectorB_typenum = vectorB.get_typenum();
+    int result_typenum = result.get_typenum();
+
+    if (vectorA_typenum != vectorB_typenum) {
+        throw py::value_error("vectorA and vectorB must be of the same type.");
+    }
+
+    auto array_types = dpctl_td_ns::usm_ndarray_types();
+    int vectorAB_type_id = array_types.typenum_to_lookup_id(vectorA_typenum);
+    int result_type_id = array_types.typenum_to_lookup_id(result_typenum);
+
+    dot_impl_fn_ptr_t dot_fn =
+        dot_dispatch_table[vectorAB_type_id][result_type_id];
+    if (dot_fn == nullptr) {
+        throw py::value_error(
+            "Types of input vectors and result array are mismatched.");
+    }
+
+    char *a_typeless_ptr = vectorA.get_data();
+    char *b_typeless_ptr = vectorB.get_data();
+    char *r_typeless_ptr = result.get_data();
+
+    const int a_elemsize = vectorA.get_elemsize();
+    const int b_elemsize = vectorB.get_elemsize();
+    if (str_a < 0) {
+        a_typeless_ptr -= (n - 1) * std::abs(str_a) * a_elemsize;
+    }
+    if (str_b < 0) {
+        b_typeless_ptr -= (n - 1) * std::abs(str_b) * b_elemsize;
+    }
+
+    sycl::event dot_ev = dot_fn(exec_q, n, a_typeless_ptr, str_a,
+                                b_typeless_ptr, str_b, r_typeless_ptr, depends);
+
+    sycl::event args_ev = dpctl::utils::keep_args_alive(
+        exec_q, {vectorA, vectorB, result}, {dot_ev});
+
+    return std::make_pair(args_ev, dot_ev);
+}
+
+template <typename fnT, typename Tab, typename Tc>
+struct DotContigFactory
+{
+    fnT get()
+    {
+        if constexpr (types::DotTypePairSupportFactory<Tab, Tc>::is_defined) {
+            return dot_impl<Tab, Tc>;
+        }
+        else {
+            return nullptr;
+        }
+    }
+};
+
+void init_dot_dispatch_table(void)
+{
+    dpctl_td_ns::DispatchTableBuilder<dot_impl_fn_ptr_t, DotContigFactory,
+                                      dpctl_td_ns::num_types>
+        contig;
+    contig.populate_dispatch_table(dot_dispatch_table);
+}
+} // namespace blas
+} // namespace ext
+} // namespace backend
+} // namespace dpnp
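The pointer adjustment at the end of `dot()` follows the BLAS convention for negative increments: the routine expects the address of the lowest-addressed element, so the base pointer is shifted back by (n - 1) * |stride| elements. A small Python illustration of the same offset arithmetic (not part of the patch; `first_element_offset` is a hypothetical helper):

    def first_element_offset(n, stride):
        # element offset applied to the typeless base pointer in dot()/dotu()
        return -(n - 1) * abs(stride) if stride < 0 else 0

    # a view a[::-2] of a 6-element vector has n = 3 and stride = -2;
    # its lowest-addressed element sits 4 elements before the view's origin
    assert first_element_offset(3, -2) == -4
    assert first_element_offset(3, 2) == 0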
diff --git a/dpnp/backend/extensions/blas/dot.hpp b/dpnp/backend/extensions/blas/dot.hpp
new file mode 100644
index 00000000000..3468196f760
--- /dev/null
+++ b/dpnp/backend/extensions/blas/dot.hpp
@@ -0,0 +1,60 @@
+//*****************************************************************************
+// Copyright (c) 2024, Intel Corporation
+// All rights reserved.
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+// - Redistributions of source code must retain the above copyright notice,
+//   this list of conditions and the following disclaimer.
+// - Redistributions in binary form must reproduce the above copyright notice,
+//   this list of conditions and the following disclaimer in the documentation
+//   and/or other materials provided with the distribution.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
+// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
+// THE POSSIBILITY OF SUCH DAMAGE.
+//*****************************************************************************
+
+#pragma once
+
+#include <CL/sycl.hpp>
+#include <oneapi/mkl.hpp>
+
+#include <dpctl4pybind11.hpp>
+
+namespace dpnp
+{
+namespace backend
+{
+namespace ext
+{
+namespace blas
+{
+extern std::pair<sycl::event, sycl::event>
+    dot(sycl::queue &exec_q,
+        dpctl::tensor::usm_ndarray vectorA,
+        dpctl::tensor::usm_ndarray vectorB,
+        dpctl::tensor::usm_ndarray result,
+        const std::vector<sycl::event> &depends);
+
+extern std::pair<sycl::event, sycl::event>
+    dotu(sycl::queue &exec_q,
+         dpctl::tensor::usm_ndarray vectorA,
+         dpctl::tensor::usm_ndarray vectorB,
+         dpctl::tensor::usm_ndarray result,
+         const std::vector<sycl::event> &depends);
+
+extern void init_dot_dispatch_table(void);
+extern void init_dotu_dispatch_table(void);
+} // namespace blas
+} // namespace ext
+} // namespace backend
+} // namespace dpnp
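`dotu` is the non-conjugating complex dot product, which matches `numpy.dot` semantics; oneMKL's conjugating variant, `dotc`, corresponds to `vdot` instead. A quick NumPy check of the convention the new `_dotu` binding is expected to reproduce:

    import numpy as np

    a = np.array([2j, 3j])
    b = np.array([2j, 3j])

    # numpy.dot conjugates neither argument ("dotu" semantics) ...
    assert np.dot(a, b) == -13 + 0j
    # ... while numpy.vdot conjugates its first argument ("dotc" semantics)
    assert np.vdot(a, b) == 13 + 0j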
diff --git a/dpnp/backend/extensions/blas/dotu.cpp b/dpnp/backend/extensions/blas/dotu.cpp
new file mode 100644
index 00000000000..8c4b43f8034
--- /dev/null
+++ b/dpnp/backend/extensions/blas/dotu.cpp
@@ -0,0 +1,241 @@
+//*****************************************************************************
+// Copyright (c) 2024, Intel Corporation
+// All rights reserved.
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+// - Redistributions of source code must retain the above copyright notice,
+//   this list of conditions and the following disclaimer.
+// - Redistributions in binary form must reproduce the above copyright notice,
+//   this list of conditions and the following disclaimer in the documentation
+//   and/or other materials provided with the distribution.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
+// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
+// THE POSSIBILITY OF SUCH DAMAGE.
+//*****************************************************************************
+
+#include <pybind11/pybind11.h>
+
+// dpctl tensor headers
+#include "utils/memory_overlap.hpp"
+#include "utils/type_utils.hpp"
+
+#include "dot.hpp"
+#include "types_matrix.hpp"
+
+#include "dpnp_utils.hpp"
+
+namespace dpnp
+{
+namespace backend
+{
+namespace ext
+{
+namespace blas
+{
+namespace mkl_blas = oneapi::mkl::blas;
+namespace py = pybind11;
+namespace type_utils = dpctl::tensor::type_utils;
+
+typedef sycl::event (*dotu_impl_fn_ptr_t)(sycl::queue &,
+                                          const std::int64_t,
+                                          char *,
+                                          const std::int64_t,
+                                          char *,
+                                          const std::int64_t,
+                                          char *,
+                                          const std::vector<sycl::event> &);
+
+static dotu_impl_fn_ptr_t dotu_dispatch_table[dpctl_td_ns::num_types]
+                                             [dpctl_td_ns::num_types];
+
+template <typename Tab, typename Tc>
+static sycl::event dotu_impl(sycl::queue &exec_q,
+                             const std::int64_t n,
+                             char *vectorA,
+                             const std::int64_t stride_a,
+                             char *vectorB,
+                             const std::int64_t stride_b,
+                             char *result,
+                             const std::vector<sycl::event> &depends)
+{
+    type_utils::validate_type_for_device<Tab>(exec_q);
+    type_utils::validate_type_for_device<Tc>(exec_q);
+
+    Tab *a = reinterpret_cast<Tab *>(vectorA);
+    Tab *b = reinterpret_cast<Tab *>(vectorB);
+    Tc *res = reinterpret_cast<Tc *>(result);
+
+    std::stringstream error_msg;
+    bool is_exception_caught = false;
+
+    sycl::event dotu_event;
+    try {
+        dotu_event = mkl_blas::row_major::dotu(exec_q,
+                                               n, // size of the input vectors
+                                               a, // Pointer to vector a.
+                                               stride_a, // Stride of vector a.
+                                               b, // Pointer to vector b.
+                                               stride_b, // Stride of vector b.
+                                               res, // Pointer to result.
+                                               depends);
+    } catch (oneapi::mkl::exception const &e) {
+        error_msg
+            << "Unexpected MKL exception caught during dotu() call:\nreason: "
+            << e.what();
+        is_exception_caught = true;
+    } catch (sycl::exception const &e) {
+        error_msg << "Unexpected SYCL exception caught during dotu() call:\n"
+                  << e.what();
+        is_exception_caught = true;
+    }
+
+    if (is_exception_caught) // an unexpected error occurs
+    {
+        throw std::runtime_error(error_msg.str());
+    }
+
+    return dotu_event;
+}
+
+std::pair<sycl::event, sycl::event>
+    dotu(sycl::queue &exec_q,
+         dpctl::tensor::usm_ndarray vectorA,
+         dpctl::tensor::usm_ndarray vectorB,
+         dpctl::tensor::usm_ndarray result,
+         const std::vector<sycl::event> &depends)
+{
+    const int vectorA_nd = vectorA.get_ndim();
+    const int vectorB_nd = vectorB.get_ndim();
+    const int result_nd = result.get_ndim();
+
+    if ((vectorA_nd != 1)) {
+        throw py::value_error(
+            "The first input array has ndim=" + std::to_string(vectorA_nd) +
+            ", but a 1-dimensional array is expected.");
+    }
+
+    if ((vectorB_nd != 1)) {
+        throw py::value_error(
+            "The second input array has ndim=" + std::to_string(vectorB_nd) +
+            ", but a 1-dimensional array is expected.");
+    }
+
+    if ((result_nd != 0)) {
+        throw py::value_error(
+            "The output array has ndim=" + std::to_string(result_nd) +
+            ", but a 0-dimensional array is expected.");
+    }
+
+    auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
+    if (overlap(vectorA, result)) {
+        throw py::value_error(
+            "The first input array and output array are overlapping "
+            "segments of memory");
+    }
+    if (overlap(vectorB, result)) {
+        throw py::value_error(
+            "The second input array and output array are overlapping "
+            "segments of memory");
+    }
+
+    // check compatibility of execution queue and allocation queue
+    if (!dpctl::utils::queues_are_compatible(
+            exec_q,
+            {vectorA.get_queue(), vectorB.get_queue(), result.get_queue()}))
+    {
+        throw py::value_error(
+            "USM allocations are not compatible with the execution queue.");
+    }
+
+    py::ssize_t a_size = vectorA.get_size();
+    py::ssize_t b_size = vectorB.get_size();
+    if (a_size != b_size) {
+        throw py::value_error("The size of the first input array must be "
+                              "equal to the size of the second input array.");
+    }
+
+    std::vector<py::ssize_t> a_stride = vectorA.get_strides_vector();
+    std::vector<py::ssize_t> b_stride = vectorB.get_strides_vector();
+
+    const std::int64_t n = a_size;
+    const std::int64_t str_a = a_stride[0];
+    const std::int64_t str_b = b_stride[0];
+
+    int vectorA_typenum = vectorA.get_typenum();
+    int vectorB_typenum = vectorB.get_typenum();
+    int result_typenum = result.get_typenum();
+
+    if (vectorA_typenum != vectorB_typenum) {
+        throw py::value_error(
+            "Input arrays must be of the same type.");
+    }
+
+    auto array_types = dpctl_td_ns::usm_ndarray_types();
+    int vectorAB_type_id = array_types.typenum_to_lookup_id(vectorA_typenum);
+    int result_type_id = array_types.typenum_to_lookup_id(result_typenum);
+
+    dotu_impl_fn_ptr_t dotu_fn =
+        dotu_dispatch_table[vectorAB_type_id][result_type_id];
+    if (dotu_fn == nullptr) {
+        throw py::value_error(
+            "Types of input vectors and result array are mismatched.");
+    }
+
+    char *a_typeless_ptr = vectorA.get_data();
+    char *b_typeless_ptr = vectorB.get_data();
+    char *r_typeless_ptr = result.get_data();
+
+    const int a_elemsize = vectorA.get_elemsize();
+    const int b_elemsize = vectorB.get_elemsize();
+    if (str_a < 0) {
+        a_typeless_ptr -= (n - 1) * std::abs(str_a) * a_elemsize;
+    }
+    if (str_b < 0) {
+        b_typeless_ptr -= (n - 1) * std::abs(str_b) * b_elemsize;
+    }
+
+    sycl::event dotu_ev =
+        dotu_fn(exec_q, n, a_typeless_ptr, str_a, b_typeless_ptr, str_b,
+                r_typeless_ptr, depends);
+
+    sycl::event args_ev = dpctl::utils::keep_args_alive(
+        exec_q, {vectorA, vectorB, result}, {dotu_ev});
+
+    return std::make_pair(args_ev, dotu_ev);
+}
+
+template <typename fnT, typename Tab, typename Tc>
+struct DotuContigFactory
+{
+    fnT get()
+    {
+        if constexpr (types::DotuTypePairSupportFactory<Tab, Tc>::is_defined) {
+            return dotu_impl<Tab, Tc>;
+        }
+        else {
+            return nullptr;
+        }
+    }
+};
+
+void init_dotu_dispatch_table(void)
+{
+    dpctl_td_ns::DispatchTableBuilder<dotu_impl_fn_ptr_t, DotuContigFactory,
+                                      dpctl_td_ns::num_types>
+        contig;
+    contig.populate_dispatch_table(dotu_dispatch_table);
+}
+} // namespace blas
+} // namespace ext
+} // namespace backend
+} // namespace dpnp
diff --git a/dpnp/backend/extensions/blas/gemm.cpp b/dpnp/backend/extensions/blas/gemm.cpp
index 5526ecd3c1b..a26420f49b3 100644
--- a/dpnp/backend/extensions/blas/gemm.cpp
+++ b/dpnp/backend/extensions/blas/gemm.cpp
@@ -1,5 +1,5 @@
 //*****************************************************************************
-// Copyright (c) 2023, Intel Corporation
+// Copyright (c) 2024, Intel Corporation
 // All rights reserved.
 //
 // Redistribution and use in source and binary forms, with or without
diff --git a/dpnp/backend/extensions/blas/gemm.hpp b/dpnp/backend/extensions/blas/gemm.hpp
index 25f78b5b850..3f1ec6e745a 100644
--- a/dpnp/backend/extensions/blas/gemm.hpp
+++ b/dpnp/backend/extensions/blas/gemm.hpp
@@ -1,5 +1,5 @@
 //*****************************************************************************
-// Copyright (c) 2023, Intel Corporation
+// Copyright (c) 2024, Intel Corporation
 // All rights reserved.
 //
 // Redistribution and use in source and binary forms, with or without
diff --git a/dpnp/backend/extensions/blas/gemm_batch.cpp b/dpnp/backend/extensions/blas/gemm_batch.cpp
index 32f592f6b8a..9359901edd8 100644
--- a/dpnp/backend/extensions/blas/gemm_batch.cpp
+++ b/dpnp/backend/extensions/blas/gemm_batch.cpp
@@ -1,5 +1,5 @@
 //*****************************************************************************
-// Copyright (c) 2023, Intel Corporation
+// Copyright (c) 2024, Intel Corporation
 // All rights reserved.
 //
 // Redistribution and use in source and binary forms, with or without
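The type-support factories added to types_matrix.hpp below decide which dtype pairs receive a real kernel; everything else falls through to NotDefinedEntry and stays a nullptr in the dispatch tables. At the Python level (see `dpnp_dot` later in this diff) that translates into three routes, sketched here with assumed default dtypes:

    import dpnp

    for dt in [dpnp.int32, dpnp.float32, dpnp.complex64]:
        a = dpnp.ones(4, dtype=dt)
        r = dpnp.dot(a, a)
        # int32     -> dpctl.tensor.vecdot (no integer entries in the factories)
        # float32   -> oneMKL dot  via DotTypePairSupportFactory
        # complex64 -> oneMKL dotu via DotuTypePairSupportFactory
        print(dt, r)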
diff --git a/dpnp/backend/extensions/blas/types_matrix.hpp b/dpnp/backend/extensions/blas/types_matrix.hpp
index 49154df03c4..c36ae0e2045 100644
--- a/dpnp/backend/extensions/blas/types_matrix.hpp
+++ b/dpnp/backend/extensions/blas/types_matrix.hpp
@@ -1,5 +1,5 @@
 //*****************************************************************************
-// Copyright (c) 2023, Intel Corporation
+// Copyright (c) 2024, Intel Corporation
 // All rights reserved.
 //
 // Redistribution and use in source and binary forms, with or without
@@ -43,6 +43,49 @@ namespace blas
 {
 namespace types
 {
+/**
+ * @brief A factory to define pairs of supported types for which
+ * MKL BLAS library provides support in oneapi::mkl::blas::dot
+ * function.
+ *
+ * @tparam Tab Type of arrays containing input vectors A and B.
+ * @tparam Tc Type of array containing output.
+ */
+template <typename Tab, typename Tc>
+struct DotTypePairSupportFactory
+{
+    static constexpr bool is_defined = std::disjunction<
+        dpctl_td_ns::TypePairDefinedEntry<Tab, float, Tc, float>,
+        dpctl_td_ns::TypePairDefinedEntry<Tab, double, Tc, double>,
+        dpctl_td_ns::TypePairDefinedEntry<Tab, float, Tc, double>,
+        // fall-through
+        dpctl_td_ns::NotDefinedEntry>::is_defined;
+};
+
+/**
+ * @brief A factory to define pairs of supported types for which
+ * MKL BLAS library provides support in oneapi::mkl::blas::dotu
+ * function.
+ *
+ * @tparam Tab Type of arrays containing input vectors A and B.
+ * @tparam Tc Type of array containing output.
+ */
+template <typename Tab, typename Tc>
+struct DotuTypePairSupportFactory
+{
+    static constexpr bool is_defined = std::disjunction<
+        dpctl_td_ns::TypePairDefinedEntry<Tab,
+                                          std::complex<float>,
+                                          Tc,
+                                          std::complex<float>>,
+        dpctl_td_ns::TypePairDefinedEntry<Tab,
+                                          std::complex<double>,
+                                          Tc,
+                                          std::complex<double>>,
+        // fall-through
+        dpctl_td_ns::NotDefinedEntry>::is_defined;
+};
+
 /**
  * @brief A factory to define pairs of supported types for which
  * MKL BLAS library provides support in oneapi::mkl::blas::gemm
diff --git a/dpnp/backend/kernels/dpnp_krnl_common.cpp b/dpnp/backend/kernels/dpnp_krnl_common.cpp
index e664c30b848..04eac54310d 100644
--- a/dpnp/backend/kernels/dpnp_krnl_common.cpp
+++ b/dpnp/backend/kernels/dpnp_krnl_common.cpp
@@ -1040,6 +1040,7 @@ void func_map_init_linalg(func_map_t &fmap)
     fmap[DPNPFuncName::DPNP_FN_DOT][eft_DBL][eft_DBL] = {
         eft_DBL, (void *)dpnp_dot_default_c<double, double, double>};

+    // needed for "dpnp_correlate_c" function in dpnp_krnl_statistics.cpp
     fmap[DPNPFuncName::DPNP_FN_DOT_EXT][eft_INT][eft_INT] = {
         eft_INT, (void *)dpnp_dot_ext_c<int32_t, int32_t, int32_t>};
     fmap[DPNPFuncName::DPNP_FN_DOT_EXT][eft_INT][eft_LNG] = {
diff --git a/dpnp/dpnp_algo/dpnp_algo.pxd b/dpnp/dpnp_algo/dpnp_algo.pxd
index 28e21340647..2fc7e1b4a3b 100644
--- a/dpnp/dpnp_algo/dpnp_algo.pxd
+++ b/dpnp/dpnp_algo/dpnp_algo.pxd
@@ -54,8 +54,6 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName":  # need this na
         DPNP_FN_DIAG_INDICES_EXT
         DPNP_FN_DIAGONAL
         DPNP_FN_DIAGONAL_EXT
-        DPNP_FN_DOT
-        DPNP_FN_DOT_EXT
         DPNP_FN_EDIFF1D
         DPNP_FN_EDIFF1D_EXT
         DPNP_FN_EIG
@@ -282,11 +280,6 @@ cpdef dpnp_descriptor dpnp_isclose(dpnp_descriptor input1, dpnp_descriptor input
                                    double rtol=*, double atol=*, cpp_bool equal_nan=*)

-"""
-Linear algebra
-"""
-cpdef dpnp_descriptor dpnp_dot(dpnp_descriptor in_array1, dpnp_descriptor in_array2)
-
 """
 Array creation routines
 """
diff --git a/dpnp/dpnp_algo/dpnp_algo_linearalgebra.pxi b/dpnp/dpnp_algo/dpnp_algo_linearalgebra.pxi
index 9b4faf2a1b5..09336b5aaa3 100644
--- a/dpnp/dpnp_algo/dpnp_algo_linearalgebra.pxi
+++ b/dpnp/dpnp_algo/dpnp_algo_linearalgebra.pxi
@@ -36,7 +36,6 @@ and the rest of the library
 # NO IMPORTs here. All imports must be placed into main "dpnp_algo.pyx" file

 __all__ += [
-    "dpnp_dot",
     "dpnp_inner",
     "dpnp_kron",
 ]
@@ -47,105 +46,6 @@ ctypedef c_dpctl.DPCTLSyclEventRef(*fptr_2in_1out_shapes_t)(c_dpctl.DPCTLSyclQueueRef,
                                                             void *, void * , void * ,
                                                             shape_elem_type * , shape_elem_type *,
                                                             shape_elem_type * , size_t,
                                                             const c_dpctl.DPCTLEventVectorRef)
-ctypedef c_dpctl.DPCTLSyclEventRef(*fptr_2in_1out_dot_t)(c_dpctl.DPCTLSyclQueueRef,
-                                                         void * , const size_t, const size_t,
-                                                         const shape_elem_type *, const shape_elem_type * ,
-                                                         void * , const size_t, const size_t,
-                                                         const shape_elem_type *, const shape_elem_type * ,
-                                                         void * , const size_t, const size_t,
-                                                         const shape_elem_type *, const shape_elem_type * ,
-                                                         const c_dpctl.DPCTLEventVectorRef) except +
-
-
-cpdef utils.dpnp_descriptor dpnp_dot(utils.dpnp_descriptor in_array1,
-                                     utils.dpnp_descriptor in_array2,
-                                     utils.dpnp_descriptor out=None):
-    cdef shape_type_c shape1, shape2
-
-    shape1 = in_array1.shape
-    shape2 = in_array2.shape
-
-    # convert string type names (array.dtype) to C enum DPNPFuncType
-    cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(in_array1.dtype)
-    cdef DPNPFuncType param2_type = dpnp_dtype_to_DPNPFuncType(in_array2.dtype)
-
-    # get the FPTR data structure
-    cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_DOT_EXT, param1_type, param2_type)
-    cdef utils.dpnp_descriptor result
-
-    ndim1 = in_array1.ndim
-    ndim2 = in_array2.ndim
-    cdef shape_type_c result_shape
-    if ndim1 == 0:
-        result_shape = shape2
-    elif ndim2 == 0:
-        result_shape = shape1
-    elif ndim1 == 1 and ndim2 == 1:
-        result_shape = ()
-    elif ndim1 == 1:  # ndim2 > 1
-        result_shape = shape2[::-2] if ndim2 == 2 else shape2[::2]
-    elif ndim2 == 1:  # ndim1 > 1
-        result_shape = shape1[:-1]
-    else:
-        if ndim1 == 1:
-            shape1 = (1, shape1[0])
-        if ndim2 == 1:
-            shape2 = (shape1[0], 1)
-        result_shape = shape1[:-1] + shape2[:-2] + shape2[-1:]
-
-    result_sycl_device, result_usm_type, result_sycl_queue = utils.get_common_usm_allocation(in_array1, in_array2)
-
-    if out is None:
-        # create result array with type given by FPTR data
-        result = utils.create_output_descriptor(result_shape,
-                                                kernel_data.return_type,
-                                                None,
-                                                device=result_sycl_device,
-                                                usm_type=result_usm_type,
-                                                sycl_queue=result_sycl_queue)
-    else:
-        result_type = dpnp_DPNPFuncType_to_dtype(< size_t > kernel_data.return_type)
-        if out.dtype != result_type:
-            utils.checker_throw_value_error('dot', 'out.dtype', out.dtype, result_type)
-        if out.shape != result_shape:
-            utils.checker_throw_value_error('dot', 'out.shape', out.shape, result_shape)
-
-        result = out
-
-        utils.get_common_usm_allocation(in_array1, result)  # check USM allocation is common
-
-    cdef shape_type_c result_strides = utils.strides_to_vector(result.strides, result.shape)
-    cdef shape_type_c in_array1_shape = in_array1.shape
-    cdef shape_type_c in_array1_strides = utils.strides_to_vector(in_array1.strides, in_array1.shape)
-    cdef shape_type_c in_array2_shape = in_array2.shape
-    cdef shape_type_c in_array2_strides = utils.strides_to_vector(in_array2.strides, in_array2.shape)
-
-    cdef c_dpctl.SyclQueue q = result_sycl_queue
-    cdef c_dpctl.DPCTLSyclQueueRef q_ref = q.get_queue_ref()
-
-    cdef fptr_2in_1out_dot_t func = kernel_data.ptr
-    # call FPTR function
-    cdef c_dpctl.DPCTLSyclEventRef event_ref = func(q_ref,
-                                                    result.get_data(),
-                                                    result.size,
-                                                    result.ndim,
-                                                    result_shape.data(),
-                                                    result_strides.data(),
-                                                    in_array1.get_data(),
-                                                    in_array1.size,
-                                                    in_array1.ndim,
-                                                    in_array1_shape.data(),
-                                                    in_array1_strides.data(),
-                                                    in_array2.get_data(),
-                                                    in_array2.size,
-                                                    in_array2.ndim,
-                                                    in_array2_shape.data(),
-                                                    in_array2_strides.data(),
-                                                    NULL)  # dep_events_ref
-
-    with nogil: c_dpctl.DPCTLEvent_WaitAndThrow(event_ref)
-    c_dpctl.DPCTLEvent_Delete(event_ref)
-
-    return result


 cpdef utils.dpnp_descriptor dpnp_inner(dpnp_descriptor array1, dpnp_descriptor array2):
diff --git a/dpnp/dpnp_array.py b/dpnp/dpnp_array.py
index cf848b50690..b5e75dde07c 100644
--- a/dpnp/dpnp_array.py
+++ b/dpnp/dpnp_array.py
@@ -704,8 +704,29 @@ def diagonal(input, offset=0, axis1=0, axis2=1):

         return dpnp.diagonal(input, offset, axis1, axis2)

-    def dot(self, other, out=None):
-        return dpnp.dot(self, other, out)
+    def dot(self, b, out=None):
+        """
+        Dot product of two arrays.
+
+        For full documentation refer to :obj:`dpnp.dot`.
+
+        Examples
+        --------
+        >>> import dpnp as np
+        >>> a = np.eye(2)
+        >>> b = np.ones((2, 2)) * 2
+        >>> a.dot(b)
+        array([[2., 2.],
+               [2., 2.]])
+
+        This array method can be conveniently chained:
+
+        >>> a.dot(b).dot(b)
+        array([[8., 8.],
+               [8., 8.]])
+        """
+
+        return dpnp.dot(self, b, out)

     @property
     def dtype(self):
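The `out` contract documented in `dot`'s new docstring below is strict: matching shape and dtype, C-contiguous, no casting. A short sketch with made-up values (assumes a device whose default floating type is float64):

    import dpnp as np

    a = np.array([1.0, 2.0, 3.0])
    b = np.array([4.0, 5.0, 6.0])

    # a 1-D x 1-D dot product yields a 0-d array, so `out` must be 0-d too,
    # with exactly the expected dtype
    out = np.empty((), dtype=a.dtype)
    np.dot(a, b, out=out)
    print(out)  # array(32.)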
diff --git a/dpnp/dpnp_iface_linearalgebra.py b/dpnp/dpnp_iface_linearalgebra.py
index d39b84a50ec..9d63f7f8c3d 100644
--- a/dpnp/dpnp_iface_linearalgebra.py
+++ b/dpnp/dpnp_iface_linearalgebra.py
@@ -38,13 +38,12 @@

 """

-import dpctl.tensor as dpt
 import numpy

 import dpnp
 from dpnp.dpnp_algo import *
 from dpnp.dpnp_utils import *
-from dpnp.dpnp_utils.dpnp_utils_linearalgebra import dpnp_matmul
+from dpnp.dpnp_utils.dpnp_utils_linearalgebra import dpnp_dot, dpnp_matmul

 __all__ = [
     "dot",
@@ -59,87 +58,99 @@
 ]


-def dot(x1, x2, out=None, **kwargs):
+def dot(a, b, out=None):
     """
-    Dot product of `x1` and `x2`.
+    Dot product of `a` and `b`.

     For full documentation refer to :obj:`numpy.dot`.

+    Parameters
+    ----------
+    a : {dpnp_array, usm_ndarray, scalar}
+        First input array. Both inputs `a` and `b` can not be scalars at the same time.
+    b : {dpnp_array, usm_ndarray, scalar}
+        Second input array. Both inputs `a` and `b` can not be scalars at the same time.
+    out : {dpnp.ndarray, usm_ndarray}, optional
+        Alternative output array in which to place the result. It must have
+        the same shape and data type as the expected output and should be
+        C-contiguous. If these conditions are not met, an exception is
+        raised, instead of attempting to be flexible.
+
     Returns
     -------
-    y : dpnp.ndarray
-        Returns the dot product of `x1` and `x2`.
+    out : dpnp.ndarray
+        Returns the dot product of `a` and `b`.
+        If `out` is given, then it is returned.

-    Limitations
-    -----------
-    Parameters `x1` and `x2` are supported as either scalar, :class:`dpnp.ndarray`
-    or :class:`dpctl.tensor.usm_ndarray`, but both `x1` and `x2` can not be scalars at the same time.
-    Keyword argument ``kwargs`` is currently unsupported.
-    Otherwise the functions will be executed sequentially on CPU.
-    Input array data types are limited by supported DPNP :ref:`Data types`.
-
     See Also
     --------
+    :obj:`dpnp.ndarray.dot` : Equivalent method.
     :obj:`dpnp.tensordot` : Sum products over arbitrary axes.
     :obj:`dpnp.vdot` : Complex-conjugating dot product.
+    :obj:`dpnp.einsum` : Einstein summation convention.
+    :obj:`dpnp.matmul` : Matrix product of two arrays.
+    :obj:`dpnp.linalg.multi_dot` : Chained dot product.

     Examples
     --------
-    >>> import dpnp as dp
-    >>> a = dp.array([1, 2, 3])
-    >>> b = dp.array([1, 2, 3])
-    >>> dp.dot(a, b)
-    14
+    >>> import dpnp as np
+    >>> a = np.array([1, 2, 3])
+    >>> b = np.array([1, 2, 3])
+    >>> np.dot(a, b)
+    array(14)
+
+    Neither argument is complex-conjugated:
+
+    >>> np.dot(np.array([2j, 3j]), np.array([2j, 3j]))
+    array(-13+0j)
+
+    For 2-D arrays it is the matrix product:
+
+    >>> a = np.array([[1, 0], [0, 1]])
+    >>> b = np.array([[4, 1], [2, 2]])
+    >>> np.dot(a, b)
+    array([[4, 1],
+           [2, 2]])
+
+    >>> a = np.arange(3*4*5*6).reshape((3,4,5,6))
+    >>> b = np.arange(3*4*5*6)[::-1].reshape((5,4,6,3))
+    >>> np.dot(a, b)[2,3,2,1,2,2]
+    array(499128)
+    >>> sum(a[2,3,2,:] * b[1,2,:,2])
+    array(499128)

     """

-    if kwargs:
-        pass
-    elif dpnp.isscalar(x1) and dpnp.isscalar(x2):
-        # at least either x1 or x2 has to be an array
-        pass
+    dpnp.check_supported_arrays_type(a, scalar_type=True)
+    dpnp.check_supported_arrays_type(b, scalar_type=True)
+
+    if out is not None:
+        dpnp.check_supported_arrays_type(out)
+        if not out.flags.c_contiguous:
+            raise ValueError("Only C-contiguous array is acceptable.")
+
+    if dpnp.isscalar(a) or dpnp.isscalar(b):
+        # TODO: investigate usage of axpy (axpy_batch) or scal
+        # functions from BLAS here instead of dpnp.multiply
+        return dpnp.multiply(a, b, out=out)
+    elif a.ndim == 0 or b.ndim == 0:
+        # TODO: investigate usage of axpy (axpy_batch) or scal
+        # functions from BLAS here instead of dpnp.multiply
+        return dpnp.multiply(a, b, out=out)
+    elif a.ndim == 1 and b.ndim == 1:
+        return dpnp_dot(a, b, out=out)
+    elif a.ndim == 2 and b.ndim == 2:
+        # NumPy does not allow casting even if it is safe
+        return dpnp.matmul(a, b, out=out, casting="no")
+    elif a.ndim == 1 or b.ndim == 1:
+        # NumPy does not allow casting even if it is safe
+        return dpnp.matmul(a, b, out=out, casting="no")
     else:
-        # get USM type and queue to copy scalar from the host memory into a USM allocation
-        usm_type, queue = (
-            get_usm_allocations([x1, x2])
-            if dpnp.isscalar(x1) or dpnp.isscalar(x2)
-            else (None, None)
-        )
-
-        x1_desc = dpnp.get_dpnp_descriptor(
-            x1,
-            copy_when_strides=False,
-            copy_when_nondefault_queue=False,
-            alloc_usm_type=usm_type,
-            alloc_queue=queue,
-        )
-        x2_desc = dpnp.get_dpnp_descriptor(
-            x2,
-            copy_when_strides=False,
-            copy_when_nondefault_queue=False,
-            alloc_usm_type=usm_type,
-            alloc_queue=queue,
-        )
-        if x1_desc and x2_desc:
-            if out is not None:
-                if not isinstance(out, (dpnp.ndarray, dpt.usm_ndarray)):
-                    raise TypeError(
-                        "return array must be of supported array type"
-                    )
-                out_desc = (
-                    dpnp.get_dpnp_descriptor(
-                        out,
-                        copy_when_strides=False,
-                        copy_when_nondefault_queue=False,
-                    )
-                    or None
-                )
-            else:
-                out_desc = None
-            return dpnp_dot(x1_desc, x2_desc, out=out_desc).get_pyobj()
-
-    return call_origin(numpy.dot, x1, x2, out=out, **kwargs)
+        # TODO: investigate usage of matmul for some possible
+        # use cases instead of dpnp.tensordot
+        result = dpnp.tensordot(a, b, axes=(-1, -2))
+        # NumPy does not allow casting even if it is safe
+        return dpnp.get_result_array(result, out, casting="no")


 def einsum(*args, **kwargs):
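The rewritten `dot` above no longer goes through dpnp descriptors; it dispatches purely on the dimensionality of its inputs. Restating the routes with small illustrative shapes:

    import dpnp as np

    v = np.arange(3.0)            # 1-D
    m = np.eye(3)                 # 2-D
    a = np.ones((2, 3, 4))        # N-D
    b = np.ones((5, 4, 2))

    np.dot(2.0, v)   # scalar operand -> dpnp.multiply
    np.dot(v, v)     # 1-D x 1-D      -> dpnp_dot (oneMKL dot/dotu)
    np.dot(m, v)     # 2-D x 1-D      -> dpnp.matmul with casting="no"
    np.dot(a, b)     # N-D x M-D      -> dpnp.tensordot over axes (-1, -2);
                     #                   result has shape (2, 3, 5, 2)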
diff --git a/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py b/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py
index d0add55eee3..65d97befa98 100644
--- a/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py
+++ b/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py
@@ -24,69 +24,46 @@
 # *****************************************************************************

 import dpctl
+import dpctl.tensor as dpt
 import dpctl.tensor._tensor_impl as ti
 import numpy

 import dpnp
 import dpnp.backend.extensions.blas._blas_impl as bi
+from dpnp.dpnp_array import dpnp_array
 from dpnp.dpnp_utils import get_usm_allocations

-__all__ = ["dpnp_matmul"]
+__all__ = ["dpnp_dot", "dpnp_matmul"]


-def _gemm_res_dtype(*arrays, dtype, casting, sycl_queue):
+def _copy_array(x, dep_events, host_events, contig_copy=False, dtype=None):
     """
-    Determines the output array data type and the intermediate data type.
-
-    If dtype is ``None``, the output array data type is determined based on
-    the Promotion Type Rule and device capabilities. Otherwise, `dtype` is
-    used as output array dtype if input arrays can cast to it according to
-    the casting rule determined. If casting cannot be done, a ``TypeError``
-    is raised.
-    The intermediate data type is the data type used for performing matmul
-    operation calculations. If output array dtype is a floating-point data type,
-    it is also used for the intermediate data type. If output array dtype is an
-    integral data type, the default floating point data type of the device where
-    input arrays are allocated on are used for intermediate data type.
+    Creating a copy of input array if needed.

-    Parameters
-    ----------
-    arrays : {dpnp.ndarray, usm_ndarray}
-        Input arrays.
-    dtype : dtype
-        If not ``None``, data type of the output array.
-    casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe'}, optional
-        Controls what kind of data casting may occur.
-    sycl_queue : {SyclQueue}
-        A SYCL queue to use for determining default floating point datat type.
-
-    Returns
-    -------
-    gemm_dtype, res_dtype :
-        `gemm_dtype` is the data type used in performing matmul calculations.
-        The input arrays of matmul function are cast to `gemm_dtype` and then
-        the calculations are performed.
-        `res_dtype` is the output data type. When the result is obtained, it is cast
-        to `res_dtype`.
+    If `contig_copy` is ``True``, a C-contiguous copy of input array is returned.
+    In this case, the copy array has the input array data type unless `dtype` is
+    determined.
+    If `contig_copy` is ``False`` and input array data type is different than `dtype`,
+    a C-contiguous copy of input array with specified `dtype` is returned.

     """

-    res_dtype = dpnp.result_type(*arrays)
-    default_dtype = dpnp.default_float_type(sycl_queue=sycl_queue)
-
-    if dtype is not None:
-        if dpnp.can_cast(res_dtype, dtype, casting=casting):
-            res_dtype = dtype
-        else:
-            raise TypeError(
-                f"Cannot cast ufunc 'matmul' output from dtype({res_dtype}) to dtype({dtype}) with casting rule {casting}"
-            )
-
-    gemm_dtype = (
-        res_dtype if dpnp.issubdtype(res_dtype, dpnp.inexact) else default_dtype
-    )
+    if contig_copy:
+        copy = contig_copy
+    else:
+        copy = x.dtype != dtype if dtype is not None else False

-    return gemm_dtype
+    if copy:
+        x_copy = dpnp.empty_like(x, dtype=dtype, order="C")
+        ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
+            src=dpnp.get_usm_ndarray(x),
+            dst=x_copy.get_array(),
+            sycl_queue=x.sycl_queue,
+        )
+        dep_events.append(copy_ev)
+        host_events.append(ht_copy_ev)
+        return x_copy
+    return x
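The helper has two triggers: an unconditional C-contiguous copy (`contig_copy=True`) and a dtype-driven copy otherwise. A hedged usage sketch; the helper is module-private and imported here only for illustration:

    import dpctl
    import dpnp
    from dpnp.dpnp_utils.dpnp_utils_linearalgebra import _copy_array

    x = dpnp.ones((3, 3), dtype=dpnp.float32, order="F")  # F-contiguous

    dep_events, host_events = [], []
    y = _copy_array(x, dep_events, host_events, contig_copy=True)
    dpctl.SyclEvent.wait_for(host_events)  # the copy runs asynchronously
    assert y.flags.c_contiguous

    # with contig_copy=False a copy happens only on a dtype mismatch
    z = _copy_array(x, dep_events, host_events, dtype=dpnp.float32)
    assert z is x  # dtypes match: the input is returned as-is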
@@ -95,8 +72,10 @@ def _gemm_batch_matmul(exec_q, x1, x2, res, x1_is_2D, x2_is_2D, dev_tasks_list):
     # when the input array is F-contiguous, the data of 2D array
     # that needs to be called in mkl::gemm_batch are not contiguous.
     ht_tasks_list = []
-    x1 = _get_gemm_contig_array(x1, dev_tasks_list, ht_tasks_list)
-    x2 = _get_gemm_contig_array(x2, dev_tasks_list, ht_tasks_list)
+    contig_copy = not x1.flags.c_contiguous
+    x1 = _copy_array(x1, dev_tasks_list, ht_tasks_list, contig_copy=contig_copy)
+    contig_copy = not x2.flags.c_contiguous
+    x2 = _copy_array(x2, dev_tasks_list, ht_tasks_list, contig_copy=contig_copy)

     x1_strides = x1.strides
     x2_strides = x2.strides
@@ -149,41 +128,133 @@ def _gemm_batch_matmul(exec_q, x1, x2, res, x1_is_2D, x2_is_2D, dev_tasks_list):
     return ht_blas_ev, ht_tasks_list, res


-def _get_gemm_contig_array(x, dep_events, host_events, dtype=None):
+def _op_res_dtype(*arrays, dtype, casting, sycl_queue):
+    """
+    _op_res_dtype(*arrays, dtype, casting, sycl_queue)
+
+    Determines the output array data type and an intermediate data type
+    used in performing calculations related to a specific math function.
+    If dtype is ``None``, the output array data type of the operation is
+    determined based on the Promotion Type Rule and device capabilities.
+    Otherwise, `dtype` is used as output array dtype, if input arrays
+    can cast to it according to the casting rule determined. If casting
+    cannot be done, a ``TypeError`` is raised.
+    The intermediate data type is the data type used for performing the math
+    function calculations. If output array dtype is a floating-point data type,
+    it is also used for the intermediate data type. If output array dtype is an
+    integral data type, the default floating point data type of the device where
+    the input arrays are allocated is used for the intermediate data type.
+
+    Parameters
+    ----------
+    arrays : {dpnp.ndarray, usm_ndarray}
+        Input arrays.
+    dtype : dtype
+        If not ``None``, data type of the output array.
+    casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe'}, optional
+        Controls what kind of data casting may occur.
+    sycl_queue : {SyclQueue}
+        A SYCL queue to use for determining default floating point data type.
+
+    Returns
+    -------
+    op_dtype, res_dtype :
+        `op_dtype` is the data type used in performing math function calculations.
+        The input arrays of the math function are cast to `op_dtype` and then
+        the calculations are performed.
+        `res_dtype` is the output data type. When the result is obtained, it is cast
+        to `res_dtype`.
+
     """
-    Creating a copy of input array if needed.
-    This function has two use cases. In the first use case, which is more general,
-    if the input array is not c-contiguous or f-contiguous, we ensure it becomes
-    c-contiguous. Additionally, if the input array has an integral dtype, we
-    convert it to an appropriate floating-point data type specified by `dtype`.
-    In the second use case, which is for N-dimensional arrays with N>2, we need
-    to ensure c-contiguity. This is crucial because the implementation of the
-    `gemm_batch` function in dpnp only works for C-contiguous arrays. This use case
-    is essential when the input array is f-contiguous with floating point dtype for
-    which the array is not modified in the first use case.
+    res_dtype = dpnp.result_type(*arrays)
+    default_dtype = dpnp.default_float_type(sycl_queue=sycl_queue)
+
+    if dtype is not None:
+        if dpnp.can_cast(res_dtype, dtype, casting=casting):
+            res_dtype = dtype
+        else:
+            raise TypeError(
+                f"Cannot cast ufunc 'matmul' output from dtype({res_dtype}) to dtype({dtype}) with casting rule {casting}"
+            )
+
+    op_dtype = (
+        res_dtype if dpnp.issubdtype(res_dtype, dpnp.inexact) else default_dtype
+    )
+
+    return op_dtype, res_dtype
+
+
+def dpnp_dot(a, b, /, out=None):
+    """
+    Return the dot product of two arrays.
+
-    if dtype is None:
-        copy = not x.flags.c_contiguous
-    else:
-        copy = (
-            not (x.flags.c_contiguous or x.flags.f_contiguous)
-            or x.dtype != dtype
-        )
+    The routine that is used to perform the main calculation
+    depends on input array data types: 1) For integer and boolean data types,
+    `dpctl.tensor.vecdot` from the Data Parallel Control library is used,
+    2) For floating point real-valued data types, `dot` routines from
+    BLAS library of OneMKL are used, and 3) For complex data types,
+    `dotu` routines from BLAS library of OneMKL are used.

-    if copy:
-        x_copy = dpnp.empty_like(x, dtype=dtype, order="C")
-        ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
-            src=dpnp.get_usm_ndarray(x),
-            dst=x_copy.get_array(),
-            sycl_queue=x.sycl_queue,
+    """
+
+    if a.size != b.size:
+        raise ValueError(
+            "Input arrays have a mismatch in their size. "
+            f"(size {a.size} is different from {b.size})"
         )
-        dep_events.append(copy_ev)
-        host_events.append(ht_copy_ev)
-        return x_copy
-    return x
+
+    res_usm_type, exec_q = get_usm_allocations([a, b])
+
+    # Determine the appropriate data types
+    # casting is irrelevant here since dtype is `None`
+    dot_dtype, res_dtype = _op_res_dtype(
+        a, b, dtype=None, casting="no", sycl_queue=exec_q
+    )
+
+    # create result array
+    result = dpnp.empty(
+        (),
+        dtype=dot_dtype,
+        usm_type=res_usm_type,
+        sycl_queue=exec_q,
+    )
+
+    # input arrays should have the proper data type
+    dep_events_list = []
+    host_tasks_list = []
+    if dpnp.issubdtype(res_dtype, dpnp.inexact):
+        # copying is needed if dtypes of input arrays are different
+        a = _copy_array(a, dep_events_list, host_tasks_list, dtype=dot_dtype)
+        b = _copy_array(b, dep_events_list, host_tasks_list, dtype=dot_dtype)
+        if dpnp.issubdtype(res_dtype, dpnp.complexfloating):
+            ht_ev, _ = bi._dotu(
+                exec_q,
+                dpnp.get_usm_ndarray(a),
+                dpnp.get_usm_ndarray(b),
+                dpnp.get_usm_ndarray(result),
+                dep_events_list,
+            )
+        else:
+            ht_ev, _ = bi._dot(
+                exec_q,
+                dpnp.get_usm_ndarray(a),
+                dpnp.get_usm_ndarray(b),
+                dpnp.get_usm_ndarray(result),
+                dep_events_list,
+            )
+        host_tasks_list.append(ht_ev)
+        dpctl.SyclEvent.wait_for(host_tasks_list)
+    else:
+        dpt_a = dpnp.get_usm_ndarray(a)
+        dpt_b = dpnp.get_usm_ndarray(b)
+        result = dpnp_array._create_from_usm_ndarray(dpt.vecdot(dpt_a, dpt_b))
+
+    if dot_dtype != res_dtype:
+        result = result.astype(res_dtype, copy=False)
+
+    # NumPy does not allow casting even if it is safe
+    return dpnp.get_result_array(result, out, casting="no")


 def dpnp_matmul(
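For dot the `dtype` argument is always ``None``, so `_op_res_dtype` only performs promotion; integer inputs get a floating intermediate type because the BLAS type tables above define no integer kernels. A sketch of the returned pair (the helper is private, and the concrete output is device-dependent; float64 assumes fp64 support):

    import dpnp
    from dpnp.dpnp_utils.dpnp_utils_linearalgebra import _op_res_dtype

    a = dpnp.array([1, 2, 3])  # integer input
    op_dt, res_dt = _op_res_dtype(
        a, a, dtype=None, casting="no", sycl_queue=a.sycl_queue
    )
    print(op_dt, res_dt)  # e.g. float64 int64 on an fp64-capable device
    # dpnp_dot then skips MKL for this case and uses dpctl.tensor.vecdot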
@@ -197,8 +268,6 @@ def dpnp_matmul(
     dtype=None,
 ):
     """
-    dpnp_matmul(x1, x2, out=None, casting="same_kind", order="K", dtype=None)
-
     Return the matrix product of two arrays.

     The main calculation is done by calling an extension function
@@ -222,14 +291,16 @@ def dpnp_matmul(

     res_usm_type, exec_q = get_usm_allocations([x1, x2])

-    squeeze_flag = x1_ndim == 1 or x2_ndim == 1
+    appended_axes = []

     if x1_ndim == 1:
         x1 = x1[dpnp.newaxis, :]
         x1_ndim = x1.ndim
+        appended_axes.append(-2)

     if x2_ndim == 1:
         x2 = x2[:, dpnp.newaxis]
         x2_ndim = x2.ndim
+        appended_axes.append(-1)

     x1_shape = x1.shape
     x2_shape = x2.shape
@@ -241,7 +312,7 @@ def dpnp_matmul(
     )

     # Determine the appropriate data types
-    gemm_dtype, res_dtype = _gemm_res_dtype(
+    gemm_dtype, res_dtype = _op_res_dtype(
         x1, x2, dtype=dtype, casting=casting, sycl_queue=exec_q
     )
@@ -306,13 +377,28 @@ def dpnp_matmul(
         # and be C_CONTIGUOUS or F_CONTIGUOUS
         dep_events_list = []
         host_tasks_list = []
-        x1 = _get_gemm_contig_array(
-            x1, dep_events_list, host_tasks_list, gemm_dtype
+        contig_copy = not (x1.flags.c_contiguous or x1.flags.f_contiguous)
+        x1 = _copy_array(
+            x1,
+            dep_events_list,
+            host_tasks_list,
+            contig_copy=contig_copy,
+            dtype=gemm_dtype,
         )
-        x2 = _get_gemm_contig_array(
-            x2, dep_events_list, host_tasks_list, gemm_dtype
+        contig_copy = not (x2.flags.c_contiguous or x2.flags.f_contiguous)
+        x2 = _copy_array(
+            x2,
+            dep_events_list,
+            host_tasks_list,
+            contig_copy=contig_copy,
+            dtype=gemm_dtype,
         )

+        # TODO: investigate usage of gemv (gemv_batch) function
+        # from BLAS when one of the inputs is a vector to
+        # gain performance.
+        # TODO: investigate usage of syrk function from BLAS in
+        # case of a.T @ a and a @ a.T to gain performance.
         if x1_is_2D and x2_is_2D:
             ht_blas_ev, _ = bi._gemm(
                 exec_q,
@@ -340,8 +426,8 @@ def dpnp_matmul(
         host_tasks_list.append(ht_blas_ev)
         dpctl.SyclEvent.wait_for(host_tasks_list)

-    if squeeze_flag:
-        result = dpnp.squeeze(result)
+    if appended_axes:
+        result = dpnp.squeeze(result, tuple(appended_axes))

     if x1_is_2D and x2_is_2D:
         # add new axes only if one of the input arrays
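Recording exactly which axes were added for 1-D operands (instead of the old bare `squeeze_flag`) keeps `dpnp.squeeze` from also collapsing genuine size-1 dimensions of the result; this is the kind of case exercised by the newly enabled dot tests below. A quick illustration:

    import dpnp as np

    a = np.ones((1, 2))  # genuine size-1 leading axis
    v = np.ones(2)       # promoted to shape (2, 1) inside dpnp_matmul

    r = np.dot(a, v)
    print(r.shape)  # (1,) -- only the appended axis (-1) is squeezed;
                    # an unqualified dpnp.squeeze would have returned shape ()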
diff --git a/tests/skipped_tests.tbl b/tests/skipped_tests.tbl
index 018255c1e40..f91a4f23289 100644
--- a/tests/skipped_tests.tbl
+++ b/tests/skipped_tests.tbl
@@ -331,13 +331,12 @@ tests/third_party/cupy/linalg_tests/test_einsum.py::TestEinSumUnaryOperationWith
 tests/third_party/cupy/linalg_tests/test_einsum.py::TestListArgEinSumError::test_dim_mismatch3
 tests/third_party/cupy/linalg_tests/test_einsum.py::TestListArgEinSumError::test_invalid_sub1
 tests/third_party/cupy/linalg_tests/test_einsum.py::TestListArgEinSumError::test_too_many_dims3
+
 tests/third_party/cupy/linalg_tests/test_product.py::TestMatrixPower::test_matrix_power_invlarge
 tests/third_party/cupy/linalg_tests/test_product.py::TestMatrixPower::test_matrix_power_large
 tests/third_party/cupy/linalg_tests/test_product.py::TestMatrixPower::test_matrix_power_of_two
-tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_dot_vec2
 tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_multidim_vdot
 tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_tensordot_zero_dim
-tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_transposed_dot_with_out_f_contiguous
 tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_transposed_multidim_vdot
 tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_transposed_tensordot
 tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_transposed_tensordot_with_int_axes
diff --git a/tests/skipped_tests_gpu.tbl b/tests/skipped_tests_gpu.tbl
index fe3671ecf7f..c3464096085 100644
--- a/tests/skipped_tests_gpu.tbl
+++ b/tests/skipped_tests_gpu.tbl
@@ -151,8 +151,6 @@ tests/third_party/cupy/linalg_tests/test_einsum.py::TestEinSumError::test_too_ma
 tests/third_party/cupy/linalg_tests/test_einsum.py::TestListArgEinSumError::test_dim_mismatch3
 tests/third_party/cupy/linalg_tests/test_einsum.py::TestListArgEinSumError::test_too_many_dims3

-tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_reversed_vdot
-
 tests/third_party/cupy/random_tests/test_distributions.py::TestDistributionsMultivariateNormal_param_0_{d=2, shape=(4, 3, 2)}::test_normal
 tests/third_party/cupy/random_tests/test_distributions.py::TestDistributionsMultivariateNormal_param_1_{d=2, shape=(3, 2)}::test_normal
 tests/third_party/cupy/random_tests/test_distributions.py::TestDistributionsMultivariateNormal_param_2_{d=4, shape=(4, 3, 2)}::test_normal
@@ -435,17 +433,17 @@ tests/third_party/cupy/linalg_tests/test_einsum.py::TestEinSumLarge_param_9_{opt
 tests/third_party/cupy/linalg_tests/test_einsum.py::TestEinSumUnaryOperationWithScalar::test_scalar_float
 tests/third_party/cupy/linalg_tests/test_einsum.py::TestEinSumUnaryOperationWithScalar::test_scalar_int
 tests/third_party/cupy/linalg_tests/test_einsum.py::TestListArgEinSumError::test_invalid_sub1
+
 tests/third_party/cupy/linalg_tests/test_product.py::TestMatrixPower::test_matrix_power_invlarge
 tests/third_party/cupy/linalg_tests/test_product.py::TestMatrixPower::test_matrix_power_large
 tests/third_party/cupy/linalg_tests/test_product.py::TestMatrixPower::test_matrix_power_of_two
-tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_dot_vec2
 tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_multidim_vdot
 tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_transposed_tensordot
 tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_transposed_tensordot_with_int_axes
 tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_transposed_tensordot_with_list_axes
 tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_tensordot_zero_dim
-tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_transposed_dot_with_out_f_contiguous
 tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_transposed_multidim_vdot
+tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_reversed_vdot

 tests/third_party/cupy/logic_tests/test_comparison.py::TestArrayEqual::test_array_equal_broadcast_not_allowed
 tests/third_party/cupy/logic_tests/test_comparison.py::TestArrayEqual::test_array_equal_diff_dtypes_is_equal
diff --git a/tests/skipped_tests_gpu_no_fp64.tbl b/tests/skipped_tests_gpu_no_fp64.tbl
index 26e11a70062..d724a6043e5 100644
--- a/tests/skipped_tests_gpu_no_fp64.tbl
+++ b/tests/skipped_tests_gpu_no_fp64.tbl
@@ -30,91 +30,6 @@ tests/test_umath.py::test_umaths[('floor_divide', 'ff')]
 tests/third_party/cupy/linalg_tests/test_eigenvalue.py::TestEigenvalue_param_0_{UPLO='U'}::test_eigh_batched
 tests/third_party/cupy/linalg_tests/test_eigenvalue.py::TestEigenvalue_param_1_{UPLO='L'}::test_eigh_batched

-tests/third_party/cupy/linalg_tests/test_product.py::TestDot_param_0_{shape=((2, 3, 4), (3, 4, 2)), trans_a=True, trans_b=True}::test_dot
-tests/third_party/cupy/linalg_tests/test_product.py::TestDot_param_1_{shape=((2, 3, 4), (3, 4, 2)), trans_a=True, trans_b=False}::test_dot
-tests/third_party/cupy/linalg_tests/test_product.py::TestDot_param_2_{shape=((2, 3, 4), (3, 4, 2)), trans_a=False, trans_b=True}::test_dot
-tests/third_party/cupy/linalg_tests/test_product.py::TestDot_param_3_{shape=((2, 3, 4), (3, 4, 2)), trans_a=False, trans_b=False}::test_dot
-tests/third_party/cupy/linalg_tests/test_product.py::TestDot_param_4_{shape=((1, 1), (1, 1)), trans_a=True, trans_b=True}::test_dot
-tests/third_party/cupy/linalg_tests/test_product.py::TestDot_param_5_{shape=((1, 1), (1, 1)), trans_a=True, trans_b=False}::test_dot
-tests/third_party/cupy/linalg_tests/test_product.py::TestDot_param_6_{shape=((1, 1), (1, 1)), trans_a=False, trans_b=True}::test_dot
-tests/third_party/cupy/linalg_tests/test_product.py::TestDot_param_7_{shape=((1, 1), (1, 1)), trans_a=False, trans_b=False}::test_dot
-tests/third_party/cupy/linalg_tests/test_product.py::TestDot_param_8_{shape=((1, 1), (1, 2)), trans_a=True, trans_b=True}::test_dot
-tests/third_party/cupy/linalg_tests/test_product.py::TestDot_param_9_{shape=((1, 1), (1, 2)), trans_a=True, trans_b=False}::test_dot
-tests/third_party/cupy/linalg_tests/test_product.py::TestDot_param_10_{shape=((1, 1), (1, 2)), trans_a=False, trans_b=True}::test_dot
-tests/third_party/cupy/linalg_tests/test_product.py::TestDot_param_11_{shape=((1, 1), (1, 2)), trans_a=False, trans_b=False}::test_dot
-tests/third_party/cupy/linalg_tests/test_product.py::TestDot_param_12_{shape=((1, 2), (2, 1)), trans_a=True, trans_b=True}::test_dot
-tests/third_party/cupy/linalg_tests/test_product.py::TestDot_param_13_{shape=((1, 2), (2, 1)), trans_a=True, trans_b=False}::test_dot
-tests/third_party/cupy/linalg_tests/test_product.py::TestDot_param_14_{shape=((1, 2), (2, 1)), trans_a=False, trans_b=True}::test_dot
-tests/third_party/cupy/linalg_tests/test_product.py::TestDot_param_15_{shape=((1, 2), (2, 1)), trans_a=False, trans_b=False}::test_dot
-tests/third_party/cupy/linalg_tests/test_product.py::TestDot_param_16_{shape=((2, 1), (1, 1)), trans_a=True, trans_b=True}::test_dot
-tests/third_party/cupy/linalg_tests/test_product.py::TestDot_param_17_{shape=((2, 1), (1, 1)), trans_a=True, trans_b=False}::test_dot
-tests/third_party/cupy/linalg_tests/test_product.py::TestDot_param_18_{shape=((2, 1), (1, 1)), trans_a=False, trans_b=True}::test_dot
-tests/third_party/cupy/linalg_tests/test_product.py::TestDot_param_19_{shape=((2, 1), (1, 1)), trans_a=False, trans_b=False}::test_dot
-tests/third_party/cupy/linalg_tests/test_product.py::TestDot_param_20_{shape=((1, 2), (2, 3)), trans_a=True, trans_b=True}::test_dot
-tests/third_party/cupy/linalg_tests/test_product.py::TestDot_param_21_{shape=((1, 2), (2, 3)), trans_a=True, trans_b=False}::test_dot
-tests/third_party/cupy/linalg_tests/test_product.py::TestDot_param_22_{shape=((1, 2), (2, 3)), trans_a=False, trans_b=True}::test_dot
-tests/third_party/cupy/linalg_tests/test_product.py::TestDot_param_23_{shape=((1, 2), (2, 3)), trans_a=False, trans_b=False}::test_dot
-tests/third_party/cupy/linalg_tests/test_product.py::TestDot_param_24_{shape=((2, 1), (1, 3)), trans_a=True, trans_b=True}::test_dot
-tests/third_party/cupy/linalg_tests/test_product.py::TestDot_param_25_{shape=((2, 1), (1, 3)), trans_a=True, trans_b=False}::test_dot
-tests/third_party/cupy/linalg_tests/test_product.py::TestDot_param_26_{shape=((2, 1), (1, 3)), trans_a=False, trans_b=True}::test_dot
-tests/third_party/cupy/linalg_tests/test_product.py::TestDot_param_27_{shape=((2, 1), (1, 3)), trans_a=False, trans_b=False}::test_dot
-tests/third_party/cupy/linalg_tests/test_product.py::TestDot_param_28_{shape=((2, 3), (3, 1)), trans_a=True, trans_b=True}::test_dot
-tests/third_party/cupy/linalg_tests/test_product.py::TestDot_param_29_{shape=((2, 3), (3, 1)), trans_a=True, trans_b=False}::test_dot
-tests/third_party/cupy/linalg_tests/test_product.py::TestDot_param_30_{shape=((2, 3), (3, 1)), trans_a=False, trans_b=True}::test_dot
-tests/third_party/cupy/linalg_tests/test_product.py::TestDot_param_31_{shape=((2, 3), (3, 1)), trans_a=False, trans_b=False}::test_dot
-tests/third_party/cupy/linalg_tests/test_product.py::TestDot_param_32_{shape=((2, 3), (3, 4)), trans_a=True, trans_b=True}::test_dot
-tests/third_party/cupy/linalg_tests/test_product.py::TestDot_param_33_{shape=((2, 3), (3, 4)), trans_a=True, trans_b=False}::test_dot
-tests/third_party/cupy/linalg_tests/test_product.py::TestDot_param_34_{shape=((2, 3), (3, 4)), trans_a=False, trans_b=True}::test_dot
-tests/third_party/cupy/linalg_tests/test_product.py::TestDot_param_35_{shape=((2, 3), (3, 4)), trans_a=False, trans_b=False}::test_dot
-tests/third_party/cupy/linalg_tests/test_product.py::TestDot_param_36_{shape=((0, 3), (3, 4)), trans_a=True, trans_b=True}::test_dot
-tests/third_party/cupy/linalg_tests/test_product.py::TestDot_param_37_{shape=((0, 3), (3, 4)), trans_a=True, trans_b=False}::test_dot
-tests/third_party/cupy/linalg_tests/test_product.py::TestDot_param_38_{shape=((0, 3), (3, 4)), trans_a=False, trans_b=True}::test_dot
-tests/third_party/cupy/linalg_tests/test_product.py::TestDot_param_39_{shape=((0, 3), (3, 4)), trans_a=False, trans_b=False}::test_dot
-tests/third_party/cupy/linalg_tests/test_product.py::TestDot_param_40_{shape=((2, 3), (3, 0)), trans_a=True, trans_b=True}::test_dot
-tests/third_party/cupy/linalg_tests/test_product.py::TestDot_param_41_{shape=((2, 3), (3, 0)), trans_a=True, trans_b=False}::test_dot
-tests/third_party/cupy/linalg_tests/test_product.py::TestDot_param_42_{shape=((2, 3), (3, 0)), trans_a=False, trans_b=True}::test_dot
-tests/third_party/cupy/linalg_tests/test_product.py::TestDot_param_43_{shape=((2, 3), (3, 0)), trans_a=False, trans_b=False}::test_dot
-tests/third_party/cupy/linalg_tests/test_product.py::TestDot_param_44_{shape=((0, 3), (3, 0)), trans_a=True, trans_b=True}::test_dot
-tests/third_party/cupy/linalg_tests/test_product.py::TestDot_param_45_{shape=((0, 3), (3, 0)), trans_a=True, trans_b=False}::test_dot
-tests/third_party/cupy/linalg_tests/test_product.py::TestDot_param_46_{shape=((0, 3), (3, 0)), trans_a=False, trans_b=True}::test_dot
-tests/third_party/cupy/linalg_tests/test_product.py::TestDot_param_47_{shape=((0, 3), (3, 0)), trans_a=False, trans_b=False}::test_dot
-tests/third_party/cupy/linalg_tests/test_product.py::TestDot_param_48_{shape=((3, 0), (0, 4)), trans_a=True, trans_b=True}::test_dot
-tests/third_party/cupy/linalg_tests/test_product.py::TestDot_param_49_{shape=((3, 0), (0, 4)), trans_a=True, trans_b=False}::test_dot
-tests/third_party/cupy/linalg_tests/test_product.py::TestDot_param_50_{shape=((3, 0), (0, 4)), trans_a=False, trans_b=True}::test_dot
-tests/third_party/cupy/linalg_tests/test_product.py::TestDot_param_51_{shape=((3, 0), (0, 4)), trans_a=False, trans_b=False}::test_dot
-tests/third_party/cupy/linalg_tests/test_product.py::TestDot_param_52_{shape=((2, 3, 0), (3, 0, 2)), trans_a=True, trans_b=True}::test_dot
-tests/third_party/cupy/linalg_tests/test_product.py::TestDot_param_53_{shape=((2, 3, 0), (3, 0, 2)), trans_a=True, trans_b=False}::test_dot
-tests/third_party/cupy/linalg_tests/test_product.py::TestDot_param_54_{shape=((2, 3, 0), (3, 0, 2)), trans_a=False, trans_b=True}::test_dot
-tests/third_party/cupy/linalg_tests/test_product.py::TestDot_param_55_{shape=((2, 3, 0), (3, 0, 2)), trans_a=False, trans_b=False}::test_dot
-tests/third_party/cupy/linalg_tests/test_product.py::TestDot_param_56_{shape=((0, 0), (0, 0)), trans_a=True, trans_b=True}::test_dot
-tests/third_party/cupy/linalg_tests/test_product.py::TestDot_param_57_{shape=((0, 0), (0, 0)), trans_a=True, trans_b=False}::test_dot
-tests/third_party/cupy/linalg_tests/test_product.py::TestDot_param_58_{shape=((0, 0), (0, 0)), trans_a=False, trans_b=True}::test_dot
-tests/third_party/cupy/linalg_tests/test_product.py::TestDot_param_59_{shape=((0, 0), (0, 0)), trans_a=False, trans_b=False}::test_dot
-tests/third_party/cupy/linalg_tests/test_product.py::TestDot_param_60_{shape=((3,), (3,)), trans_a=True, trans_b=True}::test_dot
-tests/third_party/cupy/linalg_tests/test_product.py::TestDot_param_61_{shape=((3,), (3,)), trans_a=True, trans_b=False}::test_dot
-tests/third_party/cupy/linalg_tests/test_product.py::TestDot_param_62_{shape=((3,), (3,)), trans_a=False, trans_b=True}::test_dot
-tests/third_party/cupy/linalg_tests/test_product.py::TestDot_param_63_{shape=((3,), (3,)), trans_a=False, trans_b=False}::test_dot
-tests/third_party/cupy/linalg_tests/test_product.py::TestDot_param_64_{shape=((2,), (2, 4)), trans_a=True, trans_b=True}::test_dot
-tests/third_party/cupy/linalg_tests/test_product.py::TestDot_param_65_{shape=((2,), (2, 4)), trans_a=True, trans_b=False}::test_dot
-tests/third_party/cupy/linalg_tests/test_product.py::TestDot_param_66_{shape=((2,), (2, 4)), trans_a=False, trans_b=True}::test_dot
-tests/third_party/cupy/linalg_tests/test_product.py::TestDot_param_67_{shape=((2,), (2, 4)), trans_a=False, trans_b=False}::test_dot
-tests/third_party/cupy/linalg_tests/test_product.py::TestDot_param_68_{shape=((4, 2), (2,)), trans_a=True, trans_b=True}::test_dot
-tests/third_party/cupy/linalg_tests/test_product.py::TestDot_param_69_{shape=((4, 2), (2,)), trans_a=True, trans_b=False}::test_dot
-tests/third_party/cupy/linalg_tests/test_product.py::TestDot_param_70_{shape=((4, 2), (2,)), trans_a=False, trans_b=True}::test_dot
-tests/third_party/cupy/linalg_tests/test_product.py::TestDot_param_71_{shape=((4, 2), (2,)), trans_a=False, trans_b=False}::test_dot
-tests/third_party/cupy/linalg_tests/test_product.py::TestDotFor0Dim_param_0_{shape=((), ()), trans_a=True, trans_b=True}::test_dot
-tests/third_party/cupy/linalg_tests/test_product.py::TestDotFor0Dim_param_1_{shape=((), ()), trans_a=True, trans_b=False}::test_dot
-tests/third_party/cupy/linalg_tests/test_product.py::TestDotFor0Dim_param_2_{shape=((), ()), trans_a=False, trans_b=True}::test_dot
-tests/third_party/cupy/linalg_tests/test_product.py::TestDotFor0Dim_param_3_{shape=((), ()), trans_a=False, trans_b=False}::test_dot
-tests/third_party/cupy/linalg_tests/test_product.py::TestDotFor0Dim_param_4_{shape=((), (2, 4)), trans_a=True, trans_b=True}::test_dot
-tests/third_party/cupy/linalg_tests/test_product.py::TestDotFor0Dim_param_5_{shape=((), (2, 4)), trans_a=True, trans_b=False}::test_dot
-tests/third_party/cupy/linalg_tests/test_product.py::TestDotFor0Dim_param_6_{shape=((), (2, 4)), trans_a=False, trans_b=True}::test_dot
-tests/third_party/cupy/linalg_tests/test_product.py::TestDotFor0Dim_param_7_{shape=((), (2, 4)), trans_a=False, trans_b=False}::test_dot
-tests/third_party/cupy/linalg_tests/test_product.py::TestDotFor0Dim_param_8_{shape=((4, 2), ()), trans_a=True, trans_b=True}::test_dot
-tests/third_party/cupy/linalg_tests/test_product.py::TestDotFor0Dim_param_9_{shape=((4, 2), ()), trans_a=True, trans_b=False}::test_dot
-tests/third_party/cupy/linalg_tests/test_product.py::TestDotFor0Dim_param_10_{shape=((4, 2), ()), trans_a=False, trans_b=True}::test_dot
-tests/third_party/cupy/linalg_tests/test_product.py::TestDotFor0Dim_param_11_{shape=((4, 2), ()), trans_a=False, trans_b=False}::test_dot
-
 tests/third_party/cupy/random_tests/test_distributions.py::TestDistributionsBeta_param_6_{a_shape=(3, 2), b_shape=(3, 2), shape=(4, 3, 2)}::test_beta
 tests/third_party/cupy/random_tests/test_distributions.py::TestDistributionsBeta_param_7_{a_shape=(3, 2), b_shape=(3, 2), shape=(3, 2)}::test_beta
 tests/third_party/cupy/random_tests/test_distributions.py::TestDistributionsChisquare_param_0_{df_shape=(), shape=(4, 3, 2)}::test_chisquare
diff --git a/tests/test_dot.py b/tests/test_dot.py
index 80da5090e1b..55884b00cd3 100644
--- a/tests/test_dot.py
+++ b/tests/test_dot.py
@@ -1,52 +1,373 @@
+import dpctl
 import numpy
 import pytest
 from numpy.testing import assert_allclose, assert_array_equal
 
-import dpnp as inp
+import dpnp
 
-from .helper import get_all_dtypes
+from .helper import assert_dtype_allclose, get_all_dtypes, get_complex_dtypes
 
 
-@pytest.mark.parametrize("type", get_all_dtypes(no_bool=True, no_complex=True))
-def test_dot_ones(type):
-    n = 10**5
-    a = numpy.ones(n, dtype=type)
-    b = numpy.ones(n, dtype=type)
-    ia = inp.array(a)
-    ib = inp.array(b)
-
-    result = inp.dot(ia, ib)
-    expected = numpy.dot(a, b)
-    assert_array_equal(expected, result)
+class Testdot:
+    @pytest.mark.parametrize("dtype", get_all_dtypes())
+    def test_dot_ones(self, dtype):
+        n = 10**5
+        a = numpy.ones(n, dtype=dtype)
+        b = numpy.ones(n, dtype=dtype)
+        ia = dpnp.array(a)
+        ib = dpnp.array(b)
+
+        result = dpnp.dot(ia, ib)
+        expected = numpy.dot(a, b)
+        assert_dtype_allclose(result, expected)
+
+    @pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True))
+    def test_dot_arange(self, dtype):
+        n = 10**2
+        m = 10**3 if dtype is not dpnp.float32 else 10**2
+        a = numpy.hstack((numpy.arange(n, dtype=dtype),) * m)
+        b = numpy.flipud(a)
+        ia = dpnp.array(a)
+        ib = dpnp.array(b)
+
+        result = dpnp.dot(ia, ib)
+        expected = numpy.dot(a, b)
+        assert_dtype_allclose(result, expected)
+
+    @pytest.mark.parametrize("dtype", get_all_dtypes())
+    def test_dot_scalar(self, dtype):
+        a = 2
+        b = numpy.array(numpy.random.uniform(-5, 5, 10), dtype=dtype)
+        ib = dpnp.array(b)
+
+        result = dpnp.dot(a, ib)
+        expected = numpy.dot(a, b)
+        assert_allclose(result, expected)
+
+    # TODO: get rid of the fallback on NumPy when tensordot
+    # is implemented using OneMKL
+    @pytest.mark.usefixtures("allow_fall_back_on_numpy")
+    @pytest.mark.parametrize("dtype", get_all_dtypes(no_complex=True))
+    @pytest.mark.parametrize(
+        "array_info",
+        [
+            (1, 10, (), (10,)),
+            (10, 1, (10,), ()),
+            (1, 1, (), ()),
+            (10, 10, (10,), (10,)),
+            (12, 6, (4, 3), (3, 2)),
+            (12, 3, (4, 3), (3,)),
+            (60, 3, (5, 4, 3), (3,)),
+            (4, 8, (4,), (4, 2)),
+            (60, 48, (5, 3, 4), (6, 4, 2)),
+        ],
+        ids=[
+            "0d_1d",
+            "1d_0d",
+            "0d_0d",
+            "1d_1d",
+            "2d_2d",
+            "2d_1d",
+            "3d_1d",
+            "1d_2d",
+            "3d_3d",
+        ],
+    )
+    def test_dot(self, dtype, array_info):
+        size1, size2, shape1, shape2 = array_info
+        a = numpy.array(
+            numpy.random.uniform(-5, 5, size1), dtype=dtype
+        ).reshape(shape1)
+        b = numpy.array(
+            numpy.random.uniform(-5, 5, size2), dtype=dtype
+        ).reshape(shape2)
+        ia = dpnp.array(a)
+        ib = dpnp.array(b)
+
+        result = dpnp.dot(ia, ib)
+        expected = numpy.dot(a, b)
+        assert_dtype_allclose(result, expected)
+
+    # TODO: get rid of the fallback on NumPy when tensordot
+    # is implemented using OneMKL
+    @pytest.mark.usefixtures("allow_fall_back_on_numpy")
+    @pytest.mark.parametrize("dtype", get_complex_dtypes())
+    @pytest.mark.parametrize(
+        "array_info",
+        [
+            (1, 10, (), (10,)),
+            (10, 1, (10,), ()),
+            (1, 1, (), ()),
+            (10, 10, (10,), (10,)),
+            (12, 6, (4, 3), (3, 2)),
+            (12, 3, (4, 3), (3,)),
+            (60, 3, (5, 4, 3), (3,)),
+            (4, 8, (4,), (4, 2)),
+            (60, 48, (5, 3, 4), (6, 4, 2)),
+        ],
+        ids=[
+            "0d_1d",
+            "1d_0d",
+            "0d_0d",
+            "1d_1d",
+            "2d_2d",
+            "2d_1d",
+            "3d_1d",
+            "1d_2d",
+            "3d_3d",
+        ],
+    )
+    def test_dot_complex(self, dtype, array_info):
+        size1, size2, shape1, shape2 = array_info
+        x11 = numpy.random.uniform(-5, 5, size1)
+        x12 = numpy.random.uniform(-5, 5, size1)
+        x21 = numpy.random.uniform(-5, 5, size2)
+        x22 = numpy.random.uniform(-5, 5, size2)
+        a = numpy.array(x11 + 1j * x12, dtype=dtype).reshape(shape1)
+        b = numpy.array(x21 + 1j * x22, dtype=dtype).reshape(shape2)
+        ia = dpnp.array(a)
+        ib = dpnp.array(b)
+
+        result = dpnp.dot(ia, ib)
+        expected = numpy.dot(a, b)
+        assert_dtype_allclose(result, expected)
+
+    # TODO: get rid of the fallback on NumPy when tensordot
+    # is implemented using OneMKL
+    @pytest.mark.usefixtures("allow_fall_back_on_numpy")
+    @pytest.mark.parametrize("dtype", get_all_dtypes())
+    @pytest.mark.parametrize(
+        "array_info",
+        [
+            (1, 10, (), (10,)),
+            (10, 1, (10,), ()),
+            (1, 1, (), ()),
+            (10, 10, (10,), (10,)),
+            (12, 6, (4, 3), (3, 2)),
+            (12, 3, (4, 3), (3,)),
+            (60, 3, (5, 4, 3), (3,)),
+            (4, 8, (4,), (4, 2)),
+            (60, 48, (5, 3, 4), (6, 4, 2)),
+        ],
+        ids=[
+            "0d_1d",
+            "1d_0d",
+            "0d_0d",
+            "1d_1d",
+            "2d_2d",
+            "2d_1d",
+            "3d_1d",
+            "1d_2d",
+            "3d_3d",
+        ],
+    )
+    def test_dot_ndarray(self, dtype, array_info):
+        size1, size2, shape1, shape2 = array_info
+        a = numpy.array(
+            numpy.random.uniform(-5, 5, size1), dtype=dtype
+        ).reshape(shape1)
+        b = numpy.array(
+            numpy.random.uniform(-5, 5, size2), dtype=dtype
+        ).reshape(shape2)
+        ia = dpnp.array(a)
+        ib = dpnp.array(b)
+
+        result = ia.dot(ib)
+        expected = a.dot(b)
+        assert_dtype_allclose(result, expected)
+
+    @pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True))
+    def test_dot_strided(self, dtype):
+        a = numpy.arange(25, dtype=dtype)
+        b = numpy.arange(25, dtype=dtype)
+        ia = dpnp.array(a)
+        ib = dpnp.array(b)
+
+        result = dpnp.dot(ia[::3], ib[::3])
+        expected = numpy.dot(a[::3], b[::3])
+        assert_dtype_allclose(result, expected)
+
+        result = dpnp.dot(ia, ib[::-1])
+        expected = numpy.dot(a, b[::-1])
+        assert_dtype_allclose(result, expected)
+
+        result = dpnp.dot(ia[::-2], ib[::-2])
+        expected = numpy.dot(a[::-2], b[::-2])
+        assert_dtype_allclose(result, expected)
+
+        result = dpnp.dot(ia[::-5], ib[::-5])
+        expected = numpy.dot(a[::-5], b[::-5])
+        assert_dtype_allclose(result, expected)
+
+    @pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True))
+    def test_dot_out_scalar(self, dtype):
+        size = 10
+        a = 2
+        b = numpy.array(numpy.random.uniform(-5, 5, size), dtype=dtype)
+        ia = 2
+        ib = dpnp.array(b)
+
+        dp_out = dpnp.empty((size,), dtype=dtype)
+        result = dpnp.dot(ia, ib, out=dp_out)
+        expected = numpy.dot(a, b)
+
+        assert result is dp_out
+        assert_allclose(result, expected)
+
+    # TODO: get rid of the fallback on NumPy when tensordot
+    # is implemented using OneMKL
+    @pytest.mark.usefixtures("allow_fall_back_on_numpy")
+    @pytest.mark.parametrize("dtype", get_all_dtypes())
+    @pytest.mark.parametrize(
+        "array_info",
+        [
+            (1, 10, (), (10,), (10,)),
+            (10, 1, (10,), (), (10,)),
+            (1, 1, (), (), ()),
+            (10, 10, (10,), (10,), ()),
+            (12, 6, (4, 3), (3, 2), (4, 2)),
+            (12, 3, (4, 3), (3,), (4,)),
+            (60, 3, (5, 4, 3), (3,), (5, 4)),
+            (4, 8, (4,), (4, 2), (2,)),
+            (60, 48, (5, 3, 4), (6, 4, 2), (5, 3, 6, 2)),
+        ],
+        ids=[
+            "0d_1d",
+            "1d_0d",
+            "0d_0d",
+            "1d_1d",
+            "2d_2d",
+            "2d_1d",
+            "3d_1d",
+            "1d_2d",
+            "3d_3d",
+        ],
+    )
+    def test_dot_out(self, dtype, array_info):
+        size1, size2, shape1, shape2, out_shape = array_info
+        a = numpy.array(
+            numpy.random.uniform(-5, 5, size1), dtype=dtype
+        ).reshape(shape1)
+        b = numpy.array(
+            numpy.random.uniform(-5, 5, size2), dtype=dtype
+        ).reshape(shape2)
+        ia = dpnp.array(a)
+        ib = dpnp.array(b)
+
+        dp_out = dpnp.empty(out_shape, dtype=dtype)
+        result = dpnp.dot(ia, ib, out=dp_out)
+        expected = numpy.dot(a, b)
+
+        assert result is dp_out
+        assert_dtype_allclose(result, expected)
+
+    @pytest.mark.parametrize("dtype1", get_all_dtypes())
+    @pytest.mark.parametrize("dtype2", get_all_dtypes())
+    def test_dot_input_dtype_matrix(self, dtype1, dtype2):
+        a = numpy.array(numpy.random.uniform(-5, 5, 10), dtype=dtype1)
+        b = numpy.array(numpy.random.uniform(-5, 5, 10), dtype=dtype2)
+        ia = dpnp.array(a)
+        ib = dpnp.array(b)
+
+        result = dpnp.dot(ia, ib)
+        expected = numpy.dot(a, b)
+        assert_dtype_allclose(result, expected)
+
+    def test_dot_1d_error(self):
+        a = dpnp.ones(25)
+        b = dpnp.ones(24)
+        # sizes of the input arrays differ
+        with pytest.raises(ValueError):
+            dpnp.dot(a, b)
+
+    def test_dot_sycl_queue_error(self):
+        a = dpnp.ones((5,), sycl_queue=dpctl.SyclQueue())
+        b = dpnp.ones((5,), sycl_queue=dpctl.SyclQueue())
+        with pytest.raises(ValueError):
+            dpnp.dot(a, b)
+
+    # NumPy does not raise an error for the following test;
+    # it just does not update the out keyword if it is not properly defined
+    @pytest.mark.parametrize("ia", [1, dpnp.ones((), dtype=dpnp.int32)])
+    def test_dot_out_error_scalar(self, ia):
+        ib = dpnp.ones(10, dtype=dpnp.int32)
+
+        # output data type is incorrect
+        dp_out = dpnp.empty((10,), dtype=dpnp.int64)
+        # TODO: change it to ValueError, when updated
+        # dpctl is being used in internal CI
+        with pytest.raises((ValueError, TypeError)):
+            dpnp.dot(ia, ib, out=dp_out)
+
+        # output shape is incorrect
+        dp_out = dpnp.empty((2,), dtype=dpnp.int32)
+        # TODO: change it to ValueError, when updated
+        # dpctl is being used in internal CI
+        with pytest.raises((ValueError, TypeError)):
+            dpnp.dot(ia, ib, out=dp_out)
+
+    # TODO: get rid of the fallback on NumPy when tensordot
+    # is implemented using OneMKL
+    @pytest.mark.usefixtures("allow_fall_back_on_numpy")
+    @pytest.mark.parametrize(
+        "shape_pair",
+        [
+            ((10,), (10,), ()),
+            ((3, 4), (4, 2), (3, 2)),
+            ((3, 4), (4,), (3,)),
+            ((5, 4, 3), (3,), (5, 4)),
+            ((4,), (3, 4, 2), (3, 2)),
+            ((5, 3, 4), (6, 4, 2), (5, 3, 6, 2)),
+        ],
+        ids=["1d_1d", "2d_2d", "2d_1d", "3d_1d", "1d_3d", "3d_3d"],
+    )
+    def test_dot_out_error(self, shape_pair):
+        shape1, shape2, shape_out = shape_pair
+        a = numpy.ones(shape1, dtype=numpy.int32)
+        b = numpy.ones(shape2, dtype=numpy.int32)
+        ia = dpnp.array(a)
+        ib = dpnp.array(b)
+        # output data type is incorrect
+        np_out = numpy.empty(shape_out, dtype=numpy.int64)
+        dp_out = dpnp.empty(shape_out, dtype=dpnp.int64)
+        with pytest.raises(TypeError):
+            dpnp.dot(ia, ib, out=dp_out)
+        with pytest.raises(ValueError):
+            numpy.dot(a, b, out=np_out)
-
-@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True, no_complex=True))
-def test_dot_arange(dtype):
-    n = 10**2
-    m = 10**3 if dtype is not inp.float32 else 10**2
-    a = numpy.hstack((numpy.arange(n, dtype=dtype),) * m)
-    b = numpy.flipud(a)
-    ia = inp.array(a)
-    ib = inp.array(b)
+        # output shape is incorrect
+        np_out = numpy.empty((2, 3), dtype=numpy.int32)
+        dp_out = dpnp.empty((2, 3), dtype=dpnp.int32)
+        with pytest.raises(ValueError):
+            dpnp.dot(ia, ib, out=dp_out)
+        with pytest.raises(ValueError):
+            numpy.dot(a, b, out=np_out)
-
-    result = inp.dot(ia, ib)
-    expected = numpy.dot(a, b)
-    assert_allclose(expected, result)
+        # "F" or "C" is irrelevant for 0d or 1d arrays
+        if not (len(shape_out) in [0, 1]):
+            # output should be C-contiguous
+            np_out = numpy.empty(shape_out, dtype=numpy.int32, order="F")
+            dp_out = dpnp.empty(shape_out, dtype=dpnp.int32, order="F")
+            with pytest.raises(ValueError):
+                dpnp.dot(ia, ib, out=dp_out)
+            with pytest.raises(ValueError):
+                numpy.dot(a, b, out=np_out)
 
 
 @pytest.mark.parametrize("type", get_all_dtypes(no_bool=True, no_complex=True))
 def test_multi_dot(type):
     n = 16
-    a = inp.reshape(inp.arange(n, dtype=type), (4, 4))
-    b = inp.reshape(inp.arange(n, dtype=type), (4, 4))
-    c = inp.reshape(inp.arange(n, dtype=type), (4, 4))
-    d = inp.reshape(inp.arange(n, dtype=type), (4, 4))
+    a = dpnp.reshape(dpnp.arange(n, dtype=type), (4, 4))
+    b = dpnp.reshape(dpnp.arange(n, dtype=type), (4, 4))
+    c = dpnp.reshape(dpnp.arange(n, dtype=type), (4, 4))
+    d = dpnp.reshape(dpnp.arange(n, dtype=type), (4, 4))
 
     a1 = numpy.arange(n, dtype=type).reshape((4, 4))
     b1 = numpy.arange(n, dtype=type).reshape((4, 4))
     c1 = numpy.arange(n, dtype=type).reshape((4, 4))
     d1 = numpy.arange(n, dtype=type).reshape((4, 4))
 
-    result = inp.linalg.multi_dot([a, b, c, d])
+    result = dpnp.linalg.multi_dot([a, b, c, d])
     expected = numpy.linalg.multi_dot([a1, b1, c1, d1])
     assert_array_equal(expected, result)
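For orientation, the rewritten Testdot suite above pins down the behavior of the new OneMKL-backed dpnp.dot (real vectors via the dot binding, complex vectors via dotu). A minimal sketch of what it exercises; shapes and values here are illustrative only and assume dpnp is installed with a default SYCL device available:

import dpnp

# 1-d dot products, including strided and reversed views, run on the device.
a = dpnp.arange(25, dtype=dpnp.float32)
b = dpnp.arange(25, dtype=dpnp.float32)
r = dpnp.dot(a[::3], b[::3])

# Complex vectors are now supported as well.
c = dpnp.ones(10, dtype=dpnp.complex64)
r_c = dpnp.dot(c, c)

# An explicit out= array must match the result exactly; for two 1-d
# vectors the result is 0-d, and the passed array is returned as-is.
out = dpnp.empty((), dtype=dpnp.float32)
assert dpnp.dot(a, b, out=out) is out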
diff --git a/tests/test_mathematical.py b/tests/test_mathematical.py
index 1faa0620f7d..56be3db6d92 100644
--- a/tests/test_mathematical.py
+++ b/tests/test_mathematical.py
@@ -2517,6 +2517,7 @@ class TestMatmul:
             ((4,), (4,)),
             ((4,), (4, 2)),
             ((2, 4), (4,)),
+            ((1, 4), (4,)),  # output should be 1-d not 0-d
             ((2, 4), (4, 3)),
             ((1, 2, 3), (1, 3, 5)),
             ((4, 2, 3), (4, 3, 5)),
@@ -2672,7 +2673,7 @@ def test_matmul_dtype(self, dtype, shape_pair):
             "((6, 7, 4, 3), (6, 7, 3, 5))",
         ],
     )
-    def test_matmul_dtype_matrix_inputs(self, dtype1, dtype2, shape_pair):
+    def test_matmul_dtype_matrix_inout(self, dtype1, dtype2, shape_pair):
         shape1, shape2 = shape_pair
         a1 = numpy.arange(numpy.prod(shape1), dtype=dtype1).reshape(shape1)
         a2 = numpy.arange(numpy.prod(shape2), dtype=dtype1).reshape(shape2)
@@ -2703,7 +2704,7 @@ def test_matmul_dtype_matrix_inout(self, dtype1, dtype2, shape_pair):
             "((6, 7, 4, 3), (6, 7, 3, 5))",
         ],
     )
-    def test_matmul_dtype_matrix_inout(self, dtype1, dtype2, shape_pair):
+    def test_matmul_dtype_matrix_inputs(self, dtype1, dtype2, shape_pair):
         shape1, shape2 = shape_pair
         a1 = numpy.arange(numpy.prod(shape1), dtype=dtype1).reshape(shape1)
         a2 = numpy.arange(numpy.prod(shape2), dtype=dtype2).reshape(shape2)
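The new ((1, 4), (4,)) case above guards a shape corner of matmul: the 1-d operand is promoted to a column vector and the appended axis is removed again afterwards, so the leading dimension of the 2-d operand survives. A short illustration (hypothetical values, assuming dpnp with a default device):

import dpnp

a = dpnp.ones((1, 4))
b = dpnp.ones(4)
# The result keeps the leading dimension: shape (1,), not a 0-d scalar.
assert dpnp.matmul(a, b).shape == (1,)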
diff --git a/tests/test_sycl_queue.py b/tests/test_sycl_queue.py
index 78a869fac9d..a8b8be52009 100644
--- a/tests/test_sycl_queue.py
+++ b/tests/test_sycl_queue.py
@@ -534,8 +534,8 @@ def test_reduce_hypot(device):
     ),
     pytest.param(
         "dot",
-        [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]],
-        [[4.0, 4.0], [4.0, 4.0], [4.0, 4.0]],
+        [3.0, 4.0, 5.0],
+        [1.0, 2.0, 3.0],
     ),
     pytest.param(
         "floor_divide", [1.0, 2.0, 3.0, 4.0], [2.5, 2.5, 2.5, 2.5]
@@ -842,8 +842,8 @@ def test_out_1in_1out(func, data, device):
     ),
     pytest.param(
         "dot",
-        [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]],
-        [[4.0, 4.0], [4.0, 4.0], [4.0, 4.0]],
+        [3.0, 4.0, 5.0],
+        [1.0, 2.0, 3.0],
     ),
     pytest.param(
         "floor_divide", [1.0, 2.0, 3.0, 4.0], [2.5, 2.5, 2.5, 2.5]
diff --git a/tests/test_usm_type.py b/tests/test_usm_type.py
index 5a29e677747..171e979facf 100644
--- a/tests/test_usm_type.py
+++ b/tests/test_usm_type.py
@@ -494,8 +494,8 @@ def test_1in_1out(func, data, usm_type):
     pytest.param("copysign", [0.0, 1.0, 2.0], [-1.0, 0.0, 1.0]),
     pytest.param(
         "dot",
-        [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]],
-        [[4.0, 4.0], [4.0, 4.0], [4.0, 4.0]],
+        [3.0, 4.0, 5.0],
+        [1.0, 2.0, 3.0],
     ),
     pytest.param("fmax", [[0.0, 1.0, 2.0]], [[3.0, 4.0, 5.0]]),
     pytest.param("fmin", [[0.0, 1.0, 2.0]], [[3.0, 4.0, 5.0]]),
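The sycl_queue and usm_type suites now feed dot plain 1-d vectors, the case the new OneMKL path covers directly. A sketch of the compute-follows-data property these suites check, assuming a default SYCL device; the attribute names come from dpctl/dpnp:

import dpctl
import dpnp

q = dpctl.SyclQueue()
a = dpnp.array([3.0, 4.0, 5.0], sycl_queue=q, usm_type="device")
b = dpnp.array([1.0, 2.0, 3.0], sycl_queue=q, usm_type="device")

# The result is expected to be allocated on the inputs' queue with their
# USM kind; mixing arrays from two different queues raises ValueError.
r = dpnp.dot(a, b)
assert r.usm_type == "device"
assert r.sycl_queue == q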
diff --git a/tests/third_party/cupy/linalg_tests/test_eigenvalue.py b/tests/third_party/cupy/linalg_tests/test_eigenvalue.py
index 99dcfb2127c..b620bd39e98 100644
--- a/tests/third_party/cupy/linalg_tests/test_eigenvalue.py
+++ b/tests/third_party/cupy/linalg_tests/test_eigenvalue.py
@@ -15,12 +15,6 @@ def _get_hermitian(xp, a, UPLO):
     return xp.tril(a) + xp.tril(a, k=-1).swapaxes(-2, -1).conj()
 
 
-# TODO:
-# remove once dpnp.dot and dpnp.matmul support complex types
-def _wrap_as_numpy_array(xp, a):
-    return a.asnumpy() if xp is cupy else a
-
-
 @testing.parameterize(
     *testing.product(
         {
@@ -57,20 +51,12 @@ def test_eigh(self, xp, dtype):
         else:
             tol = 1e-5
 
-        # TODO: remove _wrap_as_numpy_array() once @ support complex types
-        testing.assert_allclose(
-            _wrap_as_numpy_array(xp, A) @ _wrap_as_numpy_array(xp, v),
-            _wrap_as_numpy_array(xp, v)
-            @ numpy.diag(_wrap_as_numpy_array(xp, w)),
-            atol=tol,
-            rtol=tol,
-        )
+        testing.assert_allclose(A @ v, v @ xp.diag(w), atol=tol, rtol=tol)
 
         # Check if v @ vt is an identity matrix
         testing.assert_allclose(
-            _wrap_as_numpy_array(xp, v)
-            @ _wrap_as_numpy_array(xp, v).swapaxes(-2, -1).conj(),
-            numpy.identity(_wrap_as_numpy_array(xp, A).shape[-1], _dtype),
+            v @ v.swapaxes(-2, -1).conj(),
+            xp.identity(A.shape[-1], _dtype),
             atol=tol,
             rtol=tol,
         )
@@ -121,11 +107,6 @@ def test_eigh_complex_batched(self, xp, dtype):
         # them through the eigen equation A*v=w*v.
         A = _get_hermitian(xp, a, self.UPLO)
 
-        # TODO: remove _wrap_as_numpy_array() once dpnp.dot() support complex types
-        A = _wrap_as_numpy_array(xp, A)
-        v = _wrap_as_numpy_array(xp, v)
-        w = _wrap_as_numpy_array(xp, w)
-
         for i in range(a.shape[0]):
             testing.assert_allclose(
                 A[i].dot(v[i]), w[i] * v[i], rtol=1e-5, atol=1e-5
             )
diff --git a/tests/third_party/cupy/linalg_tests/test_product.py b/tests/third_party/cupy/linalg_tests/test_product.py
index 93b13c93e87..1fd048356b4 100644
--- a/tests/third_party/cupy/linalg_tests/test_product.py
+++ b/tests/third_party/cupy/linalg_tests/test_product.py
@@ -36,10 +36,12 @@
         }
     )
 )
-@testing.gpu
+# TODO: get rid of the fallback on NumPy when tensordot
+# is implemented using OneMKL
+@pytest.mark.usefixtures("allow_fall_back_on_numpy")
 class TestDot(unittest.TestCase):
     @testing.for_all_dtypes_combination(["dtype_a", "dtype_b"])
-    @testing.numpy_cupy_allclose()
+    @testing.numpy_cupy_allclose(type_check=has_support_aspect64())
     def test_dot(self, xp, dtype_a, dtype_b):
         shape_a, shape_b = self.shape
         if self.trans_a:
@@ -71,8 +73,13 @@ def test_dot_with_out(self, xp, dtype_a, dtype_b, dtype_c):
         else:
             shape_c = shape_a[:-1] + shape_b[:-2] + shape_b[-1:]
         c = xp.empty(shape_c, dtype=dtype_c)
-        out = xp.dot(a, b, out=c)
-        self.assertIs(out, c)
+        try:
+            out = xp.dot(a, b, out=c)
+        except TypeError:
+            # When the output dtype is incorrect, NumPy raises ValueError,
+            # while DPNP raises TypeError, so we convert it to ValueError
+            raise ValueError
+        assert out is c
         return c
 
 
@@ -128,10 +135,11 @@ def test_cross(self, xp, dtype_a, dtype_b):
         }
     )
 )
-@testing.gpu
 class TestDotFor0Dim(unittest.TestCase):
     @testing.for_all_dtypes_combination(["dtype_a", "dtype_b"])
-    @testing.numpy_cupy_allclose(contiguous_check=False)
+    @testing.numpy_cupy_allclose(
+        type_check=has_support_aspect64(), contiguous_check=False
+    )
     def test_dot(self, xp, dtype_a, dtype_b):
         shape_a, shape_b = self.shape
         if self.trans_a:
@@ -145,8 +153,7 @@ def test_dot(self, xp, dtype_a, dtype_b):
         return xp.dot(a, b)
 
 
-@testing.gpu
-class TestProduct(unittest.TestCase):
+class TestProduct:
     @testing.for_all_dtypes()
     @testing.numpy_cupy_allclose()
     def test_dot_vec1(self, xp, dtype):
@@ -154,6 +161,9 @@ def test_dot_vec1(self, xp, dtype):
         b = testing.shaped_arange((2,), xp, dtype)
         return xp.dot(a, b)
 
+    # TODO: get rid of the fallback on NumPy when tensordot
+    # is implemented using OneMKL
+    @pytest.mark.usefixtures("allow_fall_back_on_numpy")
     @testing.for_all_dtypes()
     @testing.numpy_cupy_allclose()
     def test_dot_vec2(self, xp, dtype):
@@ -168,6 +178,9 @@ def test_dot_vec3(self, xp, dtype):
         b = testing.shaped_arange((2,), xp, dtype)
         return xp.dot(a, b)
 
+    # TODO: get rid of the fallback on NumPy when tensordot
+    # is implemented using OneMKL
+    @pytest.mark.usefixtures("allow_fall_back_on_numpy")
     @testing.for_all_dtypes()
     @testing.numpy_cupy_allclose()
     def test_transposed_dot(self, xp, dtype):
@@ -175,6 +188,9 @@ def test_transposed_dot(self, xp, dtype):
         b = testing.shaped_arange((2, 3, 4), xp, dtype).transpose(0, 2, 1)
         return xp.dot(a, b)
 
+    # TODO: get rid of the fallback on NumPy when tensordot
+    # is implemented using OneMKL
+    @pytest.mark.usefixtures("allow_fall_back_on_numpy")
     @testing.for_all_dtypes()
     @testing.numpy_cupy_allclose()
     def test_transposed_dot_with_out(self, xp, dtype):
@@ -184,6 +200,9 @@ def test_transposed_dot_with_out(self, xp, dtype):
         xp.dot(a, b, out=c)
         return c
 
+    # TODO: get rid of the fallback on NumPy when tensordot
+    # is implemented using OneMKL
+    @pytest.mark.usefixtures("allow_fall_back_on_numpy")
     @testing.for_all_dtypes()
     def test_transposed_dot_with_out_f_contiguous(self, dtype):
         for xp in (numpy, cupy):
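The try/except shim added to test_dot_with_out above papers over a deliberate divergence between the two libraries, so the numpy/cupy comparison decorator sees matching errors. In isolation (illustrative shapes, assuming dpnp with a default device):

import dpnp

a = dpnp.ones((3, 4), dtype=dpnp.int32)
b = dpnp.ones((4, 2), dtype=dpnp.int32)

# dpnp reports a mismatched out= dtype as TypeError, whereas numpy.dot
# raises ValueError for the same misuse.
bad_out = dpnp.empty((3, 2), dtype=dpnp.int64)
try:
    dpnp.dot(a, b, out=bad_out)
except TypeError:
    pass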
diff --git a/tests/third_party/cupy/math_tests/test_matmul.py b/tests/third_party/cupy/math_tests/test_matmul.py
index d21ec7a2d68..887ed9ae1b9 100644
--- a/tests/third_party/cupy/math_tests/test_matmul.py
+++ b/tests/third_party/cupy/math_tests/test_matmul.py
@@ -73,6 +73,61 @@ def test_cupy_matmul(self, xp, dtype1):
         return xp.matmul(x1, x2)
 
 
+@testing.parameterize(
+    *testing.product(
+        {
+            "shape_pair": [
+                # dot test
+                ((2, 3), (3, 4), (2, 4)),
+                # ((0,), (0,), (0,)),
+                # matmul test
+                ((5, 3, 2), (5, 2, 4), (5, 3, 4)),
+                ((0, 3, 2), (0, 2, 4), (0, 3, 4)),
+            ],
+        }
+    )
+)
+class TestMatmulOut(unittest.TestCase):
+    @testing.for_all_dtypes(name="dtype1")
+    @testing.for_all_dtypes(name="dtype2")
+    @testing.numpy_cupy_allclose(
+        rtol=1e-3, atol=1e-3, accept_error=TypeError  # required for uint8
+    )
+    def test_cupy_matmul_noncontiguous(self, xp, dtype1, dtype2):
+        x1 = testing.shaped_arange(self.shape_pair[0], xp, dtype1)
+        x2 = testing.shaped_arange(self.shape_pair[1], xp, dtype2)
+        out = xp.zeros(self.shape_pair[2], dtype=dtype1)[::-1]
+        ret = xp.matmul(x1, x2, out=out)
+        assert ret is out
+        return ret
+
+    @testing.for_all_dtypes(name="dtype1")
+    @testing.for_all_dtypes(name="dtype2")
+    @testing.numpy_cupy_allclose(rtol=1e-3, atol=1e-3)  # required for uint8
+    def test_cupy_matmul_out_cast(self, xp, dtype1, dtype2):
+        x1 = testing.shaped_arange(self.shape_pair[0], xp, dtype1)
+        x2 = testing.shaped_arange(self.shape_pair[1], xp, dtype2)
+        out = xp.zeros(self.shape_pair[2], dtype=bool)
+        ret = xp.matmul(x1, x2, out=out, casting="unsafe")
+        assert ret is out
+        return ret
+
+
+class TestMatmulOutOverlap:
+    @pytest.mark.parametrize(
+        "shape",
+        [
+            (900, 900),
+            (2, 600, 600),
+        ],
+    )
+    @testing.for_dtypes([numpy.int32, numpy.float64])
+    @testing.numpy_cupy_allclose(rtol=1e-5, atol=1e-5)
+    def test_overlap_both(self, xp, dtype, shape):
+        a = xp.ones(shape, dtype=dtype)
+        return xp.matmul(a, a, out=a)
+
+
 @testing.parameterize(
     *testing.product(
         {