From d9b184473e26602d0cedd2cee0ddccd1d6be3588 Mon Sep 17 00:00:00 2001 From: Natalia Polina Date: Mon, 12 Aug 2024 13:30:19 -0700 Subject: [PATCH 1/8] Implemented BLAS backend for work with oneMKL Interfaces --- CMakeLists.txt | 2 +- dpnp/backend/extensions/blas/CMakeLists.txt | 24 +++++++- dpnp/backend/extensions/blas/blas_py.cpp | 10 ++++ dpnp/backend/extensions/blas/dot.hpp | 16 +++--- dpnp/backend/extensions/blas/dotc.hpp | 17 +++--- dpnp/backend/extensions/blas/dotu.hpp | 17 +++--- dpnp/backend/extensions/blas/gemm.cpp | 4 ++ dpnp/backend/extensions/blas/gemm_batch.cpp | 13 ++++- dpnp/backend/extensions/blas/gemv.cpp | 7 ++- dpnp/backend/extensions/blas/types_matrix.hpp | 4 ++ dpnp/dpnp_utils/dpnp_utils_linearalgebra.py | 55 ++++++++++++++++--- scripts/build_locally.py | 2 +- 12 files changed, 132 insertions(+), 39 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 18b543ceb83..aa1cbabf8ff 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -108,7 +108,7 @@ if(_use_onemkl_interfaces) set(BUILD_FUNCTIONAL_TESTS False) set(BUILD_EXAMPLES False) if(_use_onemkl_interfaces_cuda) - # set(ENABLE_CUBLAS_BACKEND True) + set(ENABLE_CUBLAS_BACKEND True) set(ENABLE_CUSOLVER_BACKEND True) set(ENABLE_CUFFT_BACKEND True) # set(ENABLE_CURAND_BACKEND True) diff --git a/dpnp/backend/extensions/blas/CMakeLists.txt b/dpnp/backend/extensions/blas/CMakeLists.txt index 7e2ce831870..1695785e794 100644 --- a/dpnp/backend/extensions/blas/CMakeLists.txt +++ b/dpnp/backend/extensions/blas/CMakeLists.txt @@ -35,6 +35,20 @@ set(_module_src pybind11_add_module(${python_module_name} MODULE ${_module_src}) add_sycl_to_target(TARGET ${python_module_name} SOURCES ${_module_src}) +if(_dpnp_sycl_targets) + # make fat binary + target_compile_options( + ${python_module_name} + PRIVATE + -fsycl-targets=${_dpnp_sycl_targets} + ) + target_link_options( + ${python_module_name} + PRIVATE + -fsycl-targets=${_dpnp_sycl_targets} + ) +endif() + if (WIN32) if (${CMAKE_VERSION} VERSION_LESS "3.27") # this is a work-around for target_link_options inserting option after -link option, cause @@ -69,7 +83,15 @@ if (DPNP_GENERATE_COVERAGE) target_link_options(${python_module_name} PRIVATE -fprofile-instr-generate -fcoverage-mapping) endif() -target_link_libraries(${python_module_name} PUBLIC MKL::MKL_SYCL::BLAS) +if(_use_onemkl_interfaces) + target_link_libraries(${python_module_name} PUBLIC onemkl) + target_compile_options(${python_module_name} PRIVATE -DUSE_ONEMKL_INTERFACES) + if(_use_onemkl_interfaces_cuda) + target_compile_options(${python_module_name} PRIVATE -DUSE_ONEMKL_CUBLAS) + endif() +else() + target_link_libraries(${python_module_name} PUBLIC MKL::MKL_SYCL::BLAS) +endif() install(TARGETS ${python_module_name} DESTINATION "dpnp/backend/extensions/blas" diff --git a/dpnp/backend/extensions/blas/blas_py.cpp b/dpnp/backend/extensions/blas/blas_py.cpp index 54fde4f4fea..aa5ef52be9e 100644 --- a/dpnp/backend/extensions/blas/blas_py.cpp +++ b/dpnp/backend/extensions/blas/blas_py.cpp @@ -142,5 +142,15 @@ PYBIND11_MODULE(_blas_impl, m) py::arg("sycl_queue"), py::arg("matrixA"), py::arg("vectorX"), py::arg("vectorY"), py::arg("transpose"), py::arg("depends") = py::list()); + m.def( + "_row_major_is_available", + [](void) { +#if defined(USE_ONEMKL_CUBLAS) + return false; +#else + return true; +#endif // USE_ONEMKL_CUBLAS + }, + "Check if the onemkl::blas::row_major can be used."); } } diff --git a/dpnp/backend/extensions/blas/dot.hpp b/dpnp/backend/extensions/blas/dot.hpp index e700f983097..5e8f1e304e9 100644 --- a/dpnp/backend/extensions/blas/dot.hpp +++ b/dpnp/backend/extensions/blas/dot.hpp @@ -53,14 +53,14 @@ static sycl::event dot_impl(sycl::queue &exec_q, sycl::event dot_event; try { - dot_event = mkl_blas::row_major::dot(exec_q, - n, // size of the input vectors - x, // Pointer to vector x. - incx, // Stride of vector x. - y, // Pointer to vector y. - incy, // Stride of vector y. - res, // Pointer to result. - depends); + dot_event = mkl_blas::column_major::dot(exec_q, + n, // size of the input vectors + x, // Pointer to vector x. + incx, // Stride of vector x. + y, // Pointer to vector y. + incy, // Stride of vector y. + res, // Pointer to result. + depends); } catch (oneapi::mkl::exception const &e) { error_msg << "Unexpected MKL exception caught during dot() call:\nreason: " diff --git a/dpnp/backend/extensions/blas/dotc.hpp b/dpnp/backend/extensions/blas/dotc.hpp index 417c832bf06..24371b849a0 100644 --- a/dpnp/backend/extensions/blas/dotc.hpp +++ b/dpnp/backend/extensions/blas/dotc.hpp @@ -53,14 +53,15 @@ static sycl::event dotc_impl(sycl::queue &exec_q, sycl::event dotc_event; try { - dotc_event = mkl_blas::row_major::dotc(exec_q, - n, // size of the input vectors - x, // Pointer to vector x. - incx, // Stride of vector x. - y, // Pointer to vector y. - incy, // Stride of vector y. - res, // Pointer to result. - depends); + dotc_event = + mkl_blas::column_major::dotc(exec_q, + n, // size of the input vectors + x, // Pointer to vector x. + incx, // Stride of vector x. + y, // Pointer to vector y. + incy, // Stride of vector y. + res, // Pointer to result. + depends); } catch (oneapi::mkl::exception const &e) { error_msg << "Unexpected MKL exception caught during dotc() call:\nreason: " diff --git a/dpnp/backend/extensions/blas/dotu.hpp b/dpnp/backend/extensions/blas/dotu.hpp index 51c30735d22..57b89a508d0 100644 --- a/dpnp/backend/extensions/blas/dotu.hpp +++ b/dpnp/backend/extensions/blas/dotu.hpp @@ -53,14 +53,15 @@ static sycl::event dotu_impl(sycl::queue &exec_q, sycl::event dotu_event; try { - dotu_event = mkl_blas::row_major::dotu(exec_q, - n, // size of the input vectors - x, // Pointer to vector x. - incx, // Stride of vector x. - y, // Pointer to vector y. - incy, // Stride of vector y. - res, // Pointer to result. - depends); + dotu_event = + mkl_blas::column_major::dotu(exec_q, + n, // size of the input vectors + x, // Pointer to vector x. + incx, // Stride of vector x. + y, // Pointer to vector y. + incy, // Stride of vector y. + res, // Pointer to result. + depends); } catch (oneapi::mkl::exception const &e) { error_msg << "Unexpected MKL exception caught during dotu() call:\nreason: " diff --git a/dpnp/backend/extensions/blas/gemm.cpp b/dpnp/backend/extensions/blas/gemm.cpp index f47f8ebe7ae..86191ac67be 100644 --- a/dpnp/backend/extensions/blas/gemm.cpp +++ b/dpnp/backend/extensions/blas/gemm.cpp @@ -95,9 +95,13 @@ static sycl::event gemm_impl(sycl::queue &exec_q, const std::int64_t ldb, Tab beta, Tc *c, const std::int64_t ldc, const std::vector &deps) -> sycl::event { if (is_row_major) { +#if defined(USE_ONEMKL_CUBLAS) + throw py::value_error("Input matrices are not f-contiguous"); +#else return mkl_blas::row_major::gemm(q, transA, transB, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, deps); +#endif // USE_ONEMKL_CUBLAS } else { return mkl_blas::column_major::gemm(q, transA, transB, m, n, k, diff --git a/dpnp/backend/extensions/blas/gemm_batch.cpp b/dpnp/backend/extensions/blas/gemm_batch.cpp index 74aa6d9b1dd..670cf3f6bfa 100644 --- a/dpnp/backend/extensions/blas/gemm_batch.cpp +++ b/dpnp/backend/extensions/blas/gemm_batch.cpp @@ -107,9 +107,14 @@ static sycl::event gemm_batch_impl(sycl::queue &exec_q, const std::int64_t batch_size, const std::vector &deps) -> sycl::event { if (is_row_major) { +#if defined(USE_ONEMKL_CUBLAS) + throw py::value_error( + "last 2-dimensions of input matrices are not f-contiguous"); +#else return mkl_blas::row_major::gemm_batch( q, transA, transB, m, n, k, alpha, a, lda, stridea, b, ldb, strideb, beta, c, ldc, stridec, batch_size, deps); +#endif // USE_ONEMKL_CUBLAS } else { return mkl_blas::column_major::gemm_batch( @@ -273,11 +278,12 @@ std::tuple standardize_strides_to_nonzero(b_stride, b_shape); standardize_strides_to_nonzero(c_stride, c_shape); const bool A_base_is_f_contig = - a_stride[1] == 1 && a_stride[2] == a_shape[1]; + (a_stride[1] == 1 || a_stride[1] == a_shape[2]) && + a_stride[2] == a_shape[1]; const bool A_base_is_c_contig = a_stride[1] == a_shape[2] && a_stride[2] == 1; const bool B_base_is_f_contig = - b_stride[1] == 1 && b_stride[2] == b_shape[1]; + b_stride[1] == 1 && (b_stride[2] == b_shape[1] || b_stride[2] == 1); const bool B_base_is_c_contig = b_stride[1] == b_shape[2] && b_stride[2] == 1; const bool C_base_is_f_contig = @@ -380,7 +386,8 @@ struct GemmBatchContigFactory fnT get() { if constexpr (types::GemmBatchTypePairSupportFactory::is_defined) { + Tc>::is_defined) + { return gemm_batch_impl; } else { diff --git a/dpnp/backend/extensions/blas/gemv.cpp b/dpnp/backend/extensions/blas/gemv.cpp index 7104c9023f8..56049e1f30a 100644 --- a/dpnp/backend/extensions/blas/gemv.cpp +++ b/dpnp/backend/extensions/blas/gemv.cpp @@ -88,8 +88,12 @@ static sycl::event gemv_impl(sycl::queue &exec_q, T beta, T *y, const std::int64_t incy, const std::vector &deps) -> sycl::event { if (is_row_major) { +#if defined(USE_ONEMKL_CUBLAS) + throw py::value_error("Input matrix is not f-contiguous"); +#else return mkl_blas::row_major::gemv(q, transA, m, n, alpha, a, lda, x, incx, beta, y, incy, deps); +#endif // USE_ONEMKL_CUBLAS } else { return mkl_blas::column_major::gemv(q, transA, m, n, alpha, a, @@ -223,7 +227,8 @@ std::pair const int vectorY_typenum = vectorY.get_typenum(); if (matrixA_typenum != vectorX_typenum || - matrixA_typenum != vectorY_typenum) { + matrixA_typenum != vectorY_typenum) + { throw py::value_error("Given arrays must be of the same type."); } diff --git a/dpnp/backend/extensions/blas/types_matrix.hpp b/dpnp/backend/extensions/blas/types_matrix.hpp index 1d9bf637780..0e1afd37d35 100644 --- a/dpnp/backend/extensions/blas/types_matrix.hpp +++ b/dpnp/backend/extensions/blas/types_matrix.hpp @@ -110,8 +110,10 @@ template struct GemmTypePairSupportFactory { static constexpr bool is_defined = std::disjunction< +#if !defined(USE_ONEMKL_INTERFACES) dpctl_td_ns::TypePairDefinedEntry, dpctl_td_ns::TypePairDefinedEntry, +#endif // USE_ONEMKL_INTERFACES dpctl_td_ns::TypePairDefinedEntry, dpctl_td_ns::TypePairDefinedEntry, dpctl_td_ns::TypePairDefinedEntry, @@ -140,8 +142,10 @@ template struct GemmBatchTypePairSupportFactory { static constexpr bool is_defined = std::disjunction< +#if !defined(USE_ONEMKL_INTERFACES) dpctl_td_ns::TypePairDefinedEntry, dpctl_td_ns::TypePairDefinedEntry, +#endif // USE_ONEMKL_INTERFACES dpctl_td_ns::TypePairDefinedEntry, dpctl_td_ns::TypePairDefinedEntry, dpctl_td_ns::TypePairDefinedEntry, diff --git a/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py b/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py index dbbbada4e1f..77c12049a58 100644 --- a/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py +++ b/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py @@ -325,12 +325,28 @@ def _get_result_shape(x1, x2, out, np_flag): def _gemm_batch_matmul(exec_q, x1, x2, res): # arrays here are already at least 3D, make them 3D - x1_shape = x1.shape - x2_shape = x2.shape - x1 = dpnp.reshape(x1, (-1, x1_shape[-2], x1_shape[-1])) - x2 = dpnp.reshape(x2, (-1, x2_shape[-2], x2_shape[-1])) orig_shape = res.shape - res = dpnp.reshape(res, (-1, orig_shape[-2], orig_shape[-1])) + if not bi._row_major_is_available(): + tmp = x1 + x1 = x2 + x2 = tmp + x1_shape = x1.shape + x2_shape = x2.shape + if not x1.flags.c_contiguous: + x1 = dpnp.asarray(x1, order="C") + if not x2.flags.c_contiguous: + x2 = dpnp.asarray(x2, order="C") + x1 = dpnp.reshape(x1, (-1, x1_shape[-2], x1_shape[-1])) + x2 = dpnp.reshape(x2, (-1, x2_shape[-2], x2_shape[-1])) + res = dpnp.reshape(res, (-1, orig_shape[-1], orig_shape[-2])) + x1 = x1.transpose(0, 2, 1) + x2 = x2.transpose(0, 2, 1) + else: + x1_shape = x1.shape + x2_shape = x2.shape + x1 = dpnp.reshape(x1, (-1, x1_shape[-2], x1_shape[-1])) + x2 = dpnp.reshape(x2, (-1, x2_shape[-2], x2_shape[-1])) + res = dpnp.reshape(res, (-1, orig_shape[-2], orig_shape[-1])) res_shape = res.shape # gemm_batch does not handle negative strides, make a copy if needed @@ -383,7 +399,7 @@ def _gemm_batch_matmul(exec_q, x1, x2, res): .reshape(res_shape[1], res_shape[2], batch_size) .transpose(2, 0, 1) ) - else: + elif bi._row_major_is_available(): if res_is_c_contig: # read data of each 2D array in the batch in "C" order and # write it in "F" order @@ -419,6 +435,12 @@ def _gemm_matmul(exec_q, x1, x2, res): if res.flags.c_contiguous is True: # read data in "C" order and write it in "F" order res = dpnp.ravel(res, order="C").reshape(res.shape, order="F") + elif ( + not bi._row_major_is_available() and res.flags.f_contiguous is True + ): + # read data in "C" order and write it in "C" order + # make result similar for row_major call + res = dpnp.ravel(res, order="C").reshape(res.shape, order="C") return res @@ -794,16 +816,25 @@ def dpnp_matmul( x1 = dpnp.reshape(x1, x1.size) x2 = dpnp.reshape(x2, x2_shape[-2:]) res_shape = (x2_shape[-1],) + if not bi._row_major_is_available() and x2.flags.c_contiguous: + x2 = dpnp.asarray(x2, order="F") elif x1_is_2D and x2_is_1D: call_flag = "gemv" x1 = dpnp.reshape(x1, x1_shape[-2:]) x2 = dpnp.reshape(x2, x2.size) res_shape = (x1_shape[-2],) + if not bi._row_major_is_available() and x1.flags.c_contiguous: + x1 = dpnp.asarray(x1, order="F") elif x1_is_2D and x2_is_2D: call_flag = "gemm" x1 = dpnp.reshape(x1, x1_shape[-2:]) x2 = dpnp.reshape(x2, x2_shape[-2:]) res_shape = (x1_shape[-2], x2_shape[-1]) + if not bi._row_major_is_available(): + if x1.flags.c_contiguous: + x1 = dpnp.asarray(x1, order="F") + if x2.flags.c_contiguous: + x2 = dpnp.asarray(x2, order="F") elif x1_base_is_1D: # TODO: implement gemv_batch to use it here with transpose call_flag = "gemm_batch" @@ -839,6 +870,14 @@ def dpnp_matmul( x2_contig_flag, _, x2_f = _define_contig_flag(x2) res_order = "F" if (x1_f and x2_f and call_flag == "gemm") else "C" + + if bi._row_major_is_available(): + array_order = res_order + elif (call_flag == "gemm") or (call_flag == "gemv"): + array_order = "F" + else: + array_order = "C" + result = _create_result_array( x1, x2, @@ -862,13 +901,13 @@ def dpnp_matmul( x1, copy_flag=not x1_contig_flag, dtype=compute_dtype, - order=res_order, + order=array_order, ) x2 = _copy_array( x2, copy_flag=not x2_contig_flag, dtype=compute_dtype, - order=res_order, + order=array_order, ) if call_flag == "gemv": diff --git a/scripts/build_locally.py b/scripts/build_locally.py index d5f14102837..0c0b29efd3b 100644 --- a/scripts/build_locally.py +++ b/scripts/build_locally.py @@ -170,7 +170,7 @@ def run( type=str, ) driver.add_argument( - "--onemkl_interfaces", + "--onemkl-interfaces", help="Build using oneMKL Interfaces", dest="onemkl_interfaces", action="store_true", From e00f6c06c401bb9ae44cbb70dee069ec7e24936a Mon Sep 17 00:00:00 2001 From: Natalia Polina Date: Tue, 20 Aug 2024 14:19:52 -0700 Subject: [PATCH 2/8] update gemm --- dpnp/backend/extensions/blas/gemm.cpp | 51 +++++++++++++++------ dpnp/dpnp_utils/dpnp_utils_linearalgebra.py | 15 +----- 2 files changed, 39 insertions(+), 27 deletions(-) diff --git a/dpnp/backend/extensions/blas/gemm.cpp b/dpnp/backend/extensions/blas/gemm.cpp index 86191ac67be..578466361da 100644 --- a/dpnp/backend/extensions/blas/gemm.cpp +++ b/dpnp/backend/extensions/blas/gemm.cpp @@ -94,20 +94,22 @@ static sycl::event gemm_impl(sycl::queue &exec_q, const Tab *a, const std::int64_t lda, const Tab *b, const std::int64_t ldb, Tab beta, Tc *c, const std::int64_t ldc, const std::vector &deps) -> sycl::event { - if (is_row_major) { #if defined(USE_ONEMKL_CUBLAS) - throw py::value_error("Input matrices are not f-contiguous"); + return mkl_blas::column_major::gemm(q, transA, transB, m, n, k, + alpha, a, lda, b, ldb, beta, c, + ldc, deps); #else + if (is_row_major) { return mkl_blas::row_major::gemm(q, transA, transB, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, deps); -#endif // USE_ONEMKL_CUBLAS } else { return mkl_blas::column_major::gemm(q, transA, transB, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, deps); } +#endif // USE_ONEMKL_CUBLAS }; gemm_event = gemm_func( exec_q, @@ -226,26 +228,43 @@ std::tuple throw py::value_error( "Result array is not c-contiguous nor f-contiguous."); } + + oneapi::mkl::transpose transA; + oneapi::mkl::transpose transB; + std::int64_t lda; + std::int64_t ldb; + +#if defined(USE_ONEMKL_CUBLAS) + bool is_row_major = false; + + transA = is_matrixA_c_contig ? oneapi::mkl::transpose::T + : oneapi::mkl::transpose::N; + transB = is_matrixB_c_contig ? oneapi::mkl::transpose::T + : oneapi::mkl::transpose::N; + + if (transA == oneapi::mkl::transpose::N) { + lda = m; + } + else { + lda = k; + } + if (transB == oneapi::mkl::transpose::N) { + ldb = k; + } + else { + ldb = n; + } +#else bool is_row_major = true; if (is_matrixA_f_contig && is_matrixB_f_contig) { is_row_major = false; } - oneapi::mkl::transpose transA; - oneapi::mkl::transpose transB; + if (is_row_major) { transA = is_matrixA_f_contig ? oneapi::mkl::transpose::T : oneapi::mkl::transpose::N; transB = is_matrixB_f_contig ? oneapi::mkl::transpose::T : oneapi::mkl::transpose::N; - } - else { - transA = oneapi::mkl::transpose::N; - transB = oneapi::mkl::transpose::N; - } - - std::int64_t lda; - std::int64_t ldb; - if (is_row_major) { if (transA == oneapi::mkl::transpose::N) { lda = k; } @@ -260,9 +279,13 @@ std::tuple } } else { + transA = oneapi::mkl::transpose::N; + transB = oneapi::mkl::transpose::N; lda = m; ldb = k; } +#endif // USE_ONEMKL_CUBLAS + const std::int64_t ldc = is_row_major ? n : m; const int matrixA_typenum = matrixA.get_typenum(); diff --git a/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py b/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py index 77c12049a58..f6ab4cf638a 100644 --- a/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py +++ b/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py @@ -435,12 +435,6 @@ def _gemm_matmul(exec_q, x1, x2, res): if res.flags.c_contiguous is True: # read data in "C" order and write it in "F" order res = dpnp.ravel(res, order="C").reshape(res.shape, order="F") - elif ( - not bi._row_major_is_available() and res.flags.f_contiguous is True - ): - # read data in "C" order and write it in "C" order - # make result similar for row_major call - res = dpnp.ravel(res, order="C").reshape(res.shape, order="C") return res @@ -830,11 +824,6 @@ def dpnp_matmul( x1 = dpnp.reshape(x1, x1_shape[-2:]) x2 = dpnp.reshape(x2, x2_shape[-2:]) res_shape = (x1_shape[-2], x2_shape[-1]) - if not bi._row_major_is_available(): - if x1.flags.c_contiguous: - x1 = dpnp.asarray(x1, order="F") - if x2.flags.c_contiguous: - x2 = dpnp.asarray(x2, order="F") elif x1_base_is_1D: # TODO: implement gemv_batch to use it here with transpose call_flag = "gemm_batch" @@ -873,10 +862,10 @@ def dpnp_matmul( if bi._row_major_is_available(): array_order = res_order - elif (call_flag == "gemm") or (call_flag == "gemv"): + elif call_flag == "gemv": array_order = "F" else: - array_order = "C" + array_order = res_order result = _create_result_array( x1, From fe1c091cc54c95514b3167a5b488b4d0b203a7e1 Mon Sep 17 00:00:00 2001 From: Natalia Polina Date: Thu, 22 Aug 2024 15:00:45 -0700 Subject: [PATCH 3/8] Update gemv --- dpnp/backend/extensions/blas/gemv.cpp | 87 +++++++++++++++++++-- dpnp/dpnp_utils/dpnp_utils_linearalgebra.py | 15 +--- 2 files changed, 83 insertions(+), 19 deletions(-) diff --git a/dpnp/backend/extensions/blas/gemv.cpp b/dpnp/backend/extensions/blas/gemv.cpp index 56049e1f30a..7a5facbb037 100644 --- a/dpnp/backend/extensions/blas/gemv.cpp +++ b/dpnp/backend/extensions/blas/gemv.cpp @@ -180,20 +180,91 @@ std::pair "Input matrix is not c-contiguous nor f-contiguous."); } + const py::ssize_t *a_shape = matrixA.get_shape_raw(); + const py::ssize_t *x_shape = vectorX.get_shape_raw(); + const py::ssize_t *y_shape = vectorY.get_shape_raw(); + + oneapi::mkl::transpose transA; + std::size_t src_nelems; + +#if defined(USE_ONEMKL_CUBLAS) + bool is_row_major = false; + std::int64_t m; + std::int64_t n; + + if (is_matrixA_f_contig) { + m = a_shape[0]; + n = a_shape[1]; + if (transpose) { + transA = oneapi::mkl::transpose::T; + src_nelems = n; + if (m != x_shape[0]) { + throw py::value_error( + "The number of rows in A must be equal to " + "the number of elements in X."); + } + if (n != y_shape[0]) { + throw py::value_error( + "The number of columns in A must be equal to " + "the number of elements in Y."); + } + } + else { + transA = oneapi::mkl::transpose::N; + src_nelems = m; + if (n != x_shape[0]) { + throw py::value_error( + "The number of columns in A must be equal to " + "the number of elements in X."); + } + if (m != y_shape[0]) { + throw py::value_error( + "The number of rows in A must be equal to " + "the number of elements in Y."); + } + } + } + else { + m = a_shape[1]; + n = a_shape[0]; + if (transpose) { + transA = oneapi::mkl::transpose::N; + src_nelems = m; + if (n != x_shape[0]) { + throw py::value_error( + "The number of rows in A must be equal to " + "the number of elements in X."); + } + if (m != y_shape[0]) { + throw py::value_error( + "The number of columns in A must be equal to " + "the number of elements in Y."); + } + } + else { + transA = oneapi::mkl::transpose::T; + src_nelems = n; + if (m != x_shape[0]) { + throw py::value_error( + "The number of columns in A must be equal to " + "the number of elements in X."); + } + if (n != y_shape[0]) { + throw py::value_error( + "The number of rows in A must be equal to " + "the number of elements in Y."); + } + } + } +#else bool is_row_major = true; if (is_matrixA_f_contig) { is_row_major = false; } - const py::ssize_t *a_shape = matrixA.get_shape_raw(); - const py::ssize_t *x_shape = vectorX.get_shape_raw(); - const py::ssize_t *y_shape = vectorY.get_shape_raw(); const std::int64_t m = a_shape[0]; const std::int64_t n = a_shape[1]; - const std::int64_t lda = is_row_major ? n : m; - oneapi::mkl::transpose transA; - std::size_t src_nelems; if (transpose) { transA = oneapi::mkl::transpose::T; src_nelems = n; @@ -218,6 +289,10 @@ std::pair "the number of elements in Y."); } } +#endif // USE_ONEMKL_CUBLAS + + const std::int64_t lda = is_row_major ? n : m; + dpctl::tensor::validation::CheckWritable::throw_if_not_writable(vectorY); dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(vectorY, src_nelems); diff --git a/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py b/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py index f6ab4cf638a..6f9fd478477 100644 --- a/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py +++ b/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py @@ -810,15 +810,11 @@ def dpnp_matmul( x1 = dpnp.reshape(x1, x1.size) x2 = dpnp.reshape(x2, x2_shape[-2:]) res_shape = (x2_shape[-1],) - if not bi._row_major_is_available() and x2.flags.c_contiguous: - x2 = dpnp.asarray(x2, order="F") elif x1_is_2D and x2_is_1D: call_flag = "gemv" x1 = dpnp.reshape(x1, x1_shape[-2:]) x2 = dpnp.reshape(x2, x2.size) res_shape = (x1_shape[-2],) - if not bi._row_major_is_available() and x1.flags.c_contiguous: - x1 = dpnp.asarray(x1, order="F") elif x1_is_2D and x2_is_2D: call_flag = "gemm" x1 = dpnp.reshape(x1, x1_shape[-2:]) @@ -860,13 +856,6 @@ def dpnp_matmul( res_order = "F" if (x1_f and x2_f and call_flag == "gemm") else "C" - if bi._row_major_is_available(): - array_order = res_order - elif call_flag == "gemv": - array_order = "F" - else: - array_order = res_order - result = _create_result_array( x1, x2, @@ -890,13 +879,13 @@ def dpnp_matmul( x1, copy_flag=not x1_contig_flag, dtype=compute_dtype, - order=array_order, + order=res_order, ) x2 = _copy_array( x2, copy_flag=not x2_contig_flag, dtype=compute_dtype, - order=array_order, + order=res_order, ) if call_flag == "gemv": From a3ed4d719a719a4d53c11d9fb1be8197132de28b Mon Sep 17 00:00:00 2001 From: Natalia Polina Date: Fri, 23 Aug 2024 09:52:23 -0700 Subject: [PATCH 4/8] Update gemv_impl --- dpnp/backend/extensions/blas/gemv.cpp | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/dpnp/backend/extensions/blas/gemv.cpp b/dpnp/backend/extensions/blas/gemv.cpp index 7a5facbb037..3726ffb4b8f 100644 --- a/dpnp/backend/extensions/blas/gemv.cpp +++ b/dpnp/backend/extensions/blas/gemv.cpp @@ -87,19 +87,20 @@ static sycl::event gemv_impl(sycl::queue &exec_q, const std::int64_t lda, const T *x, const std::int64_t incx, T beta, T *y, const std::int64_t incy, const std::vector &deps) -> sycl::event { - if (is_row_major) { #if defined(USE_ONEMKL_CUBLAS) - throw py::value_error("Input matrix is not f-contiguous"); + return mkl_blas::column_major::gemv(q, transA, m, n, alpha, a, lda, + x, incx, beta, y, incy, deps); #else + if (is_row_major) { return mkl_blas::row_major::gemv(q, transA, m, n, alpha, a, lda, x, incx, beta, y, incy, deps); -#endif // USE_ONEMKL_CUBLAS } else { return mkl_blas::column_major::gemv(q, transA, m, n, alpha, a, lda, x, incx, beta, y, incy, deps); } +#endif // USE_ONEMKL_CUBLAS }; gemv_event = gemv_func( exec_q, From 29d3694512e23aba868d27e154740f2e0a857633 Mon Sep 17 00:00:00 2001 From: Natalia Polina Date: Fri, 23 Aug 2024 10:02:43 -0700 Subject: [PATCH 5/8] Update gemm_batch --- dpnp/backend/extensions/blas/gemm_batch.cpp | 45 ++++++++++++++++----- dpnp/dpnp_utils/dpnp_utils_linearalgebra.py | 28 +++---------- 2 files changed, 40 insertions(+), 33 deletions(-) diff --git a/dpnp/backend/extensions/blas/gemm_batch.cpp b/dpnp/backend/extensions/blas/gemm_batch.cpp index 670cf3f6bfa..18b71d35816 100644 --- a/dpnp/backend/extensions/blas/gemm_batch.cpp +++ b/dpnp/backend/extensions/blas/gemm_batch.cpp @@ -106,21 +106,22 @@ static sycl::event gemm_batch_impl(sycl::queue &exec_q, Tc *c, const std::int64_t ldc, const std::int64_t stridec, const std::int64_t batch_size, const std::vector &deps) -> sycl::event { - if (is_row_major) { #if defined(USE_ONEMKL_CUBLAS) - throw py::value_error( - "last 2-dimensions of input matrices are not f-contiguous"); + return mkl_blas::column_major::gemm_batch( + q, transA, transB, m, n, k, alpha, a, lda, stridea, b, ldb, + strideb, beta, c, ldc, stridec, batch_size, deps); #else + if (is_row_major) { return mkl_blas::row_major::gemm_batch( q, transA, transB, m, n, k, alpha, a, lda, stridea, b, ldb, strideb, beta, c, ldc, stridec, batch_size, deps); -#endif // USE_ONEMKL_CUBLAS } else { return mkl_blas::column_major::gemm_batch( q, transA, transB, m, n, k, alpha, a, lda, stridea, b, ldb, strideb, beta, c, ldc, stridec, batch_size, deps); } +#endif // USE_ONEMKL_CUBLAS }; gemm_batch_event = gemm_batch_func( exec_q, @@ -291,11 +292,6 @@ std::tuple const bool C_base_is_c_contig = c_stride[1] == c_shape[2] && c_stride[2] == 1; - bool is_row_major = true; - if (A_base_is_f_contig && B_base_is_f_contig) { - is_row_major = false; - } - if (!A_base_is_f_contig and !A_base_is_c_contig) { throw py::value_error("The 2D base of the first input array is not " "c-contiguous nor f-contiguous."); @@ -311,6 +307,34 @@ std::tuple oneapi::mkl::transpose transA; oneapi::mkl::transpose transB; + std::int64_t lda; + std::int64_t ldb; + +#if defined(USE_ONEMKL_CUBLAS) + bool is_row_major = false; + transA = A_base_is_c_contig ? oneapi::mkl::transpose::T + : oneapi::mkl::transpose::N; + transB = B_base_is_c_contig ? oneapi::mkl::transpose::T + : oneapi::mkl::transpose::N; + + if (transA == oneapi::mkl::transpose::N) { + lda = m; + } + else { + lda = k; + } + if (transB == oneapi::mkl::transpose::N) { + ldb = k; + } + else { + ldb = n; + } +#else + bool is_row_major = true; + if (A_base_is_f_contig && B_base_is_f_contig) { + is_row_major = false; + } + if (is_row_major) { transA = A_base_is_f_contig ? oneapi::mkl::transpose::T : oneapi::mkl::transpose::N; @@ -322,8 +346,6 @@ std::tuple transB = oneapi::mkl::transpose::N; } - std::int64_t lda; - std::int64_t ldb; if (is_row_major) { if (transA == oneapi::mkl::transpose::N) { lda = k; @@ -342,6 +364,7 @@ std::tuple lda = m; ldb = k; } +#endif // USE_ONEMKL_CUBLAS const std::int64_t ldc = is_row_major ? n : m; const int matrixA_typenum = matrixA.get_typenum(); diff --git a/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py b/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py index 6f9fd478477..da27409819c 100644 --- a/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py +++ b/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py @@ -325,28 +325,12 @@ def _get_result_shape(x1, x2, out, np_flag): def _gemm_batch_matmul(exec_q, x1, x2, res): # arrays here are already at least 3D, make them 3D + x1_shape = x1.shape + x2_shape = x2.shape + x1 = dpnp.reshape(x1, (-1, x1_shape[-2], x1_shape[-1])) + x2 = dpnp.reshape(x2, (-1, x2_shape[-2], x2_shape[-1])) orig_shape = res.shape - if not bi._row_major_is_available(): - tmp = x1 - x1 = x2 - x2 = tmp - x1_shape = x1.shape - x2_shape = x2.shape - if not x1.flags.c_contiguous: - x1 = dpnp.asarray(x1, order="C") - if not x2.flags.c_contiguous: - x2 = dpnp.asarray(x2, order="C") - x1 = dpnp.reshape(x1, (-1, x1_shape[-2], x1_shape[-1])) - x2 = dpnp.reshape(x2, (-1, x2_shape[-2], x2_shape[-1])) - res = dpnp.reshape(res, (-1, orig_shape[-1], orig_shape[-2])) - x1 = x1.transpose(0, 2, 1) - x2 = x2.transpose(0, 2, 1) - else: - x1_shape = x1.shape - x2_shape = x2.shape - x1 = dpnp.reshape(x1, (-1, x1_shape[-2], x1_shape[-1])) - x2 = dpnp.reshape(x2, (-1, x2_shape[-2], x2_shape[-1])) - res = dpnp.reshape(res, (-1, orig_shape[-2], orig_shape[-1])) + res = dpnp.reshape(res, (-1, orig_shape[-2], orig_shape[-1])) res_shape = res.shape # gemm_batch does not handle negative strides, make a copy if needed @@ -399,7 +383,7 @@ def _gemm_batch_matmul(exec_q, x1, x2, res): .reshape(res_shape[1], res_shape[2], batch_size) .transpose(2, 0, 1) ) - elif bi._row_major_is_available(): + else: if res_is_c_contig: # read data of each 2D array in the batch in "C" order and # write it in "F" order From a3f92cbc4165833c0c97bcced6dc8b8c55b16241 Mon Sep 17 00:00:00 2001 From: Natalia Polina Date: Fri, 23 Aug 2024 10:18:47 -0700 Subject: [PATCH 6/8] Fix pre-commit --- dpnp/backend/extensions/blas/gemm_batch.cpp | 8 +++----- dpnp/backend/extensions/blas/gemv.cpp | 3 +-- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/dpnp/backend/extensions/blas/gemm_batch.cpp b/dpnp/backend/extensions/blas/gemm_batch.cpp index 18b71d35816..25e06ba1539 100644 --- a/dpnp/backend/extensions/blas/gemm_batch.cpp +++ b/dpnp/backend/extensions/blas/gemm_batch.cpp @@ -279,12 +279,11 @@ std::tuple standardize_strides_to_nonzero(b_stride, b_shape); standardize_strides_to_nonzero(c_stride, c_shape); const bool A_base_is_f_contig = - (a_stride[1] == 1 || a_stride[1] == a_shape[2]) && - a_stride[2] == a_shape[1]; + a_stride[1] == 1 && a_stride[2] == a_shape[1]; const bool A_base_is_c_contig = a_stride[1] == a_shape[2] && a_stride[2] == 1; const bool B_base_is_f_contig = - b_stride[1] == 1 && (b_stride[2] == b_shape[1] || b_stride[2] == 1); + b_stride[1] == 1 && b_stride[2] == b_shape[1]; const bool B_base_is_c_contig = b_stride[1] == b_shape[2] && b_stride[2] == 1; const bool C_base_is_f_contig = @@ -409,8 +408,7 @@ struct GemmBatchContigFactory fnT get() { if constexpr (types::GemmBatchTypePairSupportFactory::is_defined) - { + Tc>::is_defined) { return gemm_batch_impl; } else { diff --git a/dpnp/backend/extensions/blas/gemv.cpp b/dpnp/backend/extensions/blas/gemv.cpp index 3726ffb4b8f..09d69f97a40 100644 --- a/dpnp/backend/extensions/blas/gemv.cpp +++ b/dpnp/backend/extensions/blas/gemv.cpp @@ -303,8 +303,7 @@ std::pair const int vectorY_typenum = vectorY.get_typenum(); if (matrixA_typenum != vectorX_typenum || - matrixA_typenum != vectorY_typenum) - { + matrixA_typenum != vectorY_typenum) { throw py::value_error("Given arrays must be of the same type."); } From 83bbb17d93e004c7551117bc255c4faaf523fbcf Mon Sep 17 00:00:00 2001 From: Vahid Tavanashad Date: Mon, 26 Aug 2024 11:06:06 -0700 Subject: [PATCH 7/8] updates to remove duplication --- dpnp/backend/extensions/blas/gemm.cpp | 2 +- dpnp/backend/extensions/blas/gemm_batch.cpp | 12 ++-- dpnp/backend/extensions/blas/gemv.cpp | 62 +++++---------------- 3 files changed, 19 insertions(+), 57 deletions(-) diff --git a/dpnp/backend/extensions/blas/gemm.cpp b/dpnp/backend/extensions/blas/gemm.cpp index 578466361da..e7043af59d2 100644 --- a/dpnp/backend/extensions/blas/gemm.cpp +++ b/dpnp/backend/extensions/blas/gemm.cpp @@ -235,7 +235,7 @@ std::tuple std::int64_t ldb; #if defined(USE_ONEMKL_CUBLAS) - bool is_row_major = false; + const bool is_row_major = false; transA = is_matrixA_c_contig ? oneapi::mkl::transpose::T : oneapi::mkl::transpose::N; diff --git a/dpnp/backend/extensions/blas/gemm_batch.cpp b/dpnp/backend/extensions/blas/gemm_batch.cpp index 25e06ba1539..2d10bff3773 100644 --- a/dpnp/backend/extensions/blas/gemm_batch.cpp +++ b/dpnp/backend/extensions/blas/gemm_batch.cpp @@ -310,7 +310,8 @@ std::tuple std::int64_t ldb; #if defined(USE_ONEMKL_CUBLAS) - bool is_row_major = false; + const bool is_row_major = false; + transA = A_base_is_c_contig ? oneapi::mkl::transpose::T : oneapi::mkl::transpose::N; transB = B_base_is_c_contig ? oneapi::mkl::transpose::T @@ -339,13 +340,7 @@ std::tuple : oneapi::mkl::transpose::N; transB = B_base_is_f_contig ? oneapi::mkl::transpose::T : oneapi::mkl::transpose::N; - } - else { - transA = oneapi::mkl::transpose::N; - transB = oneapi::mkl::transpose::N; - } - if (is_row_major) { if (transA == oneapi::mkl::transpose::N) { lda = k; } @@ -360,10 +355,13 @@ std::tuple } } else { + transA = oneapi::mkl::transpose::N; + transB = oneapi::mkl::transpose::N; lda = m; ldb = k; } #endif // USE_ONEMKL_CUBLAS + const std::int64_t ldc = is_row_major ? n : m; const int matrixA_typenum = matrixA.get_typenum(); diff --git a/dpnp/backend/extensions/blas/gemv.cpp b/dpnp/backend/extensions/blas/gemv.cpp index 09d69f97a40..a9408935bf6 100644 --- a/dpnp/backend/extensions/blas/gemv.cpp +++ b/dpnp/backend/extensions/blas/gemv.cpp @@ -189,7 +189,7 @@ std::pair std::size_t src_nelems; #if defined(USE_ONEMKL_CUBLAS) - bool is_row_major = false; + const bool is_row_major = false; std::int64_t m; std::int64_t n; @@ -199,30 +199,10 @@ std::pair if (transpose) { transA = oneapi::mkl::transpose::T; src_nelems = n; - if (m != x_shape[0]) { - throw py::value_error( - "The number of rows in A must be equal to " - "the number of elements in X."); - } - if (n != y_shape[0]) { - throw py::value_error( - "The number of columns in A must be equal to " - "the number of elements in Y."); - } } else { transA = oneapi::mkl::transpose::N; src_nelems = m; - if (n != x_shape[0]) { - throw py::value_error( - "The number of columns in A must be equal to " - "the number of elements in X."); - } - if (m != y_shape[0]) { - throw py::value_error( - "The number of rows in A must be equal to " - "the number of elements in Y."); - } } } else { @@ -231,30 +211,10 @@ std::pair if (transpose) { transA = oneapi::mkl::transpose::N; src_nelems = m; - if (n != x_shape[0]) { - throw py::value_error( - "The number of rows in A must be equal to " - "the number of elements in X."); - } - if (m != y_shape[0]) { - throw py::value_error( - "The number of columns in A must be equal to " - "the number of elements in Y."); - } } else { transA = oneapi::mkl::transpose::T; src_nelems = n; - if (m != x_shape[0]) { - throw py::value_error( - "The number of columns in A must be equal to " - "the number of elements in X."); - } - if (n != y_shape[0]) { - throw py::value_error( - "The number of rows in A must be equal to " - "the number of elements in Y."); - } } } #else @@ -269,31 +229,35 @@ std::pair if (transpose) { transA = oneapi::mkl::transpose::T; src_nelems = n; - if (m != x_shape[0]) { + } + else { + transA = oneapi::mkl::transpose::N; + src_nelems = m; + } +#endif // USE_ONEMKL_CUBLAS + + if (transpose) { + if (a_shape[0] != x_shape[0]) { throw py::value_error("The number of rows in A must be equal to " "the number of elements in X."); } - if (n != y_shape[0]) { + if (a_shape[1] != y_shape[0]) { throw py::value_error("The number of columns in A must be equal to " "the number of elements in Y."); } } else { - transA = oneapi::mkl::transpose::N; - src_nelems = m; - if (n != x_shape[0]) { + if (a_shape[1] != x_shape[0]) { throw py::value_error("The number of columns in A must be equal to " "the number of elements in X."); } - if (m != y_shape[0]) { + if (a_shape[0] != y_shape[0]) { throw py::value_error("The number of rows in A must be equal to " "the number of elements in Y."); } } -#endif // USE_ONEMKL_CUBLAS const std::int64_t lda = is_row_major ? n : m; - dpctl::tensor::validation::CheckWritable::throw_if_not_writable(vectorY); dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(vectorY, src_nelems); From 14ffa846e56f571f31756042a9cdf83cae922bab Mon Sep 17 00:00:00 2001 From: Vahid Tavanashad Date: Mon, 26 Aug 2024 11:10:59 -0700 Subject: [PATCH 8/8] fix two issues: 1) when order is given as "A" 2) when axes is given and column_major is called --- dpnp/dpnp_utils/dpnp_utils_linearalgebra.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py b/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py index da27409819c..0b13db32f7a 100644 --- a/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py +++ b/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py @@ -412,11 +412,11 @@ def _gemm_matmul(exec_q, x1, x2, res): _manager.add_event_pair(ht_ev, gemm_ev) if row_major: - if res.flags.f_contiguous is True: + if res.flags.f_contiguous: # read data in "F" order and write it in "C" order res = dpnp.ravel(res, order="F").reshape(res.shape, order="C") else: - if res.flags.c_contiguous is True: + if res.flags.c_contiguous: # read data in "C" order and write it in "F" order res = dpnp.ravel(res, order="C").reshape(res.shape, order="F") @@ -729,6 +729,12 @@ def dpnp_matmul( "Input and output allocation queues are not compatible" ) + if order in ["a", "A"]: + if x1.flags.f_contiguous and x2.flags.f_contiguous: + order = "F" + else: + order = "C" + x1_ndim = x1.ndim x2_ndim = x2.ndim if axes is not None: @@ -921,7 +927,7 @@ def dpnp_matmul( result = dpnp.moveaxis(result, (-2, -1), axes_res) elif len(axes_res) == 1: result = dpnp.moveaxis(result, (-1,), axes_res) - return result + return dpnp.ascontiguousarray(result) # If `order` was not passed as default # we need to update it to match the passed `order`.