Skip to content

Commit

Permalink
update dpnp.dot implementation (#1669)
Browse files Browse the repository at this point in the history
* dot_func

* using mkl::dotu instead mkl::dotc for complex

* fix a test

* fix negative strides

* add a temporary workaround

* address comments

* add a TODO comment

* call dpt.vecdot for integer data types

* update doc string

* pass argument by reference

* update doc to add boolean dtype

---------

Co-authored-by: Anton <100830759+antonwolfy@users.noreply.github.com>
  • Loading branch information
vtavana and antonwolfy authored Feb 6, 2024
1 parent 554bcdd commit ac1fca7
Show file tree
Hide file tree
Showing 25 changed files with 1,329 additions and 425 deletions.
4 changes: 3 additions & 1 deletion dpnp/backend/extensions/blas/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# *****************************************************************************
# Copyright (c) 2016-2023, Intel Corporation
# Copyright (c) 2024, Intel Corporation
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
Expand Down Expand Up @@ -27,6 +27,8 @@
set(python_module_name _blas_impl)
set(_module_src
${CMAKE_CURRENT_SOURCE_DIR}/blas_py.cpp
${CMAKE_CURRENT_SOURCE_DIR}/dot.cpp
${CMAKE_CURRENT_SOURCE_DIR}/dotu.cpp
${CMAKE_CURRENT_SOURCE_DIR}/gemm.cpp
${CMAKE_CURRENT_SOURCE_DIR}/gemm_batch.cpp
)
Expand Down
21 changes: 20 additions & 1 deletion dpnp/backend/extensions/blas/blas_py.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
//*****************************************************************************
// Copyright (c) 2023, Intel Corporation
// Copyright (c) 2024, Intel Corporation
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
Expand Down Expand Up @@ -30,6 +30,7 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include "dot.hpp"
#include "gemm.hpp"

namespace blas_ext = dpnp::backend::ext::blas;
Expand All @@ -38,6 +39,8 @@ namespace py = pybind11;
// populate dispatch tables
void init_dispatch_tables(void)
{
blas_ext::init_dot_dispatch_table();
blas_ext::init_dotu_dispatch_table();
blas_ext::init_gemm_batch_dispatch_table();
blas_ext::init_gemm_dispatch_table();
}
Expand All @@ -46,6 +49,22 @@ PYBIND11_MODULE(_blas_impl, m)
{
init_dispatch_tables();

{
m.def("_dot", &blas_ext::dot,
"Call `dot` from OneMKL LAPACK library to return "
"the dot product of two real-valued vectors.",
py::arg("sycl_queue"), py::arg("vectorA"), py::arg("vectorB"),
py::arg("result"), py::arg("depends") = py::list());
}

{
m.def("_dotu", &blas_ext::dotu,
"Call `dotu` from OneMKL LAPACK library to return "
"the dot product of two complex vectors.",
py::arg("sycl_queue"), py::arg("vectorA"), py::arg("vectorB"),
py::arg("result"), py::arg("depends") = py::list());
}

{
m.def("_gemm", &blas_ext::gemm,
"Call `gemm` from OneMKL LAPACK library to return "
Expand Down
238 changes: 238 additions & 0 deletions dpnp/backend/extensions/blas/dot.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,238 @@
//*****************************************************************************
// Copyright (c) 2024, Intel Corporation
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
// - Redistributions of source code must retain the above copyright notice,
// this list of conditions and the following disclaimer.
// - Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
// THE POSSIBILITY OF SUCH DAMAGE.
//*****************************************************************************

#include <pybind11/pybind11.h>

// dpctl tensor headers
#include "utils/memory_overlap.hpp"
#include "utils/type_utils.hpp"

#include "dot.hpp"
#include "types_matrix.hpp"

#include "dpnp_utils.hpp"

namespace dpnp
{
namespace backend
{
namespace ext
{
namespace blas
{
namespace mkl_blas = oneapi::mkl::blas;
namespace py = pybind11;
namespace type_utils = dpctl::tensor::type_utils;

typedef sycl::event (*dot_impl_fn_ptr_t)(sycl::queue &,
const std::int64_t,
char *,
const std::int64_t,
char *,
const std::int64_t,
char *,
const std::vector<sycl::event> &);

static dot_impl_fn_ptr_t dot_dispatch_table[dpctl_td_ns::num_types]
[dpctl_td_ns::num_types];

template <typename Tab, typename Tc>
static sycl::event dot_impl(sycl::queue &exec_q,
const std::int64_t n,
char *vectorA,
const std::int64_t stride_a,
char *vectorB,
const std::int64_t stride_b,
char *result,
const std::vector<sycl::event> &depends)
{
type_utils::validate_type_for_device<Tab>(exec_q);
type_utils::validate_type_for_device<Tc>(exec_q);

Tab *a = reinterpret_cast<Tab *>(vectorA);
Tab *b = reinterpret_cast<Tab *>(vectorB);
Tc *res = reinterpret_cast<Tc *>(result);

std::stringstream error_msg;
bool is_exception_caught = false;

sycl::event dot_event;
try {
dot_event = mkl_blas::row_major::dot(exec_q,
n, // size of the input vectors
a, // Pointer to vector a.
stride_a, // Stride of vector a.
b, // Pointer to vector b.
stride_b, // Stride of vector b.
res, // Pointer to result.
depends);
} catch (oneapi::mkl::exception const &e) {
error_msg
<< "Unexpected MKL exception caught during dot() call:\nreason: "
<< e.what();
is_exception_caught = true;
} catch (sycl::exception const &e) {
error_msg << "Unexpected SYCL exception caught during dot() call:\n"
<< e.what();
is_exception_caught = true;
}

if (is_exception_caught) // an unexpected error occurs
{
throw std::runtime_error(error_msg.str());
}

return dot_event;
}

std::pair<sycl::event, sycl::event> dot(sycl::queue &exec_q,
dpctl::tensor::usm_ndarray vectorA,
dpctl::tensor::usm_ndarray vectorB,
dpctl::tensor::usm_ndarray result,
const std::vector<sycl::event> &depends)
{
const int vectorA_nd = vectorA.get_ndim();
const int vectorB_nd = vectorB.get_ndim();
const int result_nd = result.get_ndim();

if ((vectorA_nd != 1)) {
throw py::value_error(
"The first input array has ndim=" + std::to_string(vectorA_nd) +
", but a 1-dimensional array is expected.");
}

if ((vectorB_nd != 1)) {
throw py::value_error(
"The second input array has ndim=" + std::to_string(vectorB_nd) +
", but a 1-dimensional array is expected.");
}

if ((result_nd != 0)) {
throw py::value_error(
"The output array has ndim=" + std::to_string(result_nd) +
", but a 0-dimensional array is expected.");
}

auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
if (overlap(vectorA, result)) {
throw py::value_error(
"The first input array and output array are overlapping "
"segments of memory");
}
if (overlap(vectorB, result)) {
throw py::value_error(
"The second input array and output array are overlapping "
"segments of memory");
}

// check compatibility of execution queue and allocation queue
if (!dpctl::utils::queues_are_compatible(
exec_q,
{vectorA.get_queue(), vectorB.get_queue(), result.get_queue()}))
{
throw py::value_error(
"USM allocations are not compatible with the execution queue.");
}

py::ssize_t a_size = vectorA.get_size();
py::ssize_t b_size = vectorB.get_size();
if (a_size != b_size) {
throw py::value_error("The size of the first input array must be "
"equal to the size of the second input array.");
}

std::vector<py::ssize_t> a_stride = vectorA.get_strides_vector();
std::vector<py::ssize_t> b_stride = vectorB.get_strides_vector();

const std::int64_t n = a_size;
const std::int64_t str_a = a_stride[0];
const std::int64_t str_b = b_stride[0];

int vectorA_typenum = vectorA.get_typenum();
int vectorB_typenum = vectorB.get_typenum();
int result_typenum = result.get_typenum();

if (vectorA_typenum != vectorB_typenum) {
throw py::value_error("vectorA and vectorB must be of the same type.");
}

auto array_types = dpctl_td_ns::usm_ndarray_types();
int vectorAB_type_id = array_types.typenum_to_lookup_id(vectorA_typenum);
int result_type_id = array_types.typenum_to_lookup_id(result_typenum);

dot_impl_fn_ptr_t dot_fn =
dot_dispatch_table[vectorAB_type_id][result_type_id];
if (dot_fn == nullptr) {
throw py::value_error(
"Types of input vectors and result array are mismatched.");
}

char *a_typeless_ptr = vectorA.get_data();
char *b_typeless_ptr = vectorB.get_data();
char *r_typeless_ptr = result.get_data();

const int a_elemsize = vectorA.get_elemsize();
const int b_elemsize = vectorB.get_elemsize();
if (str_a < 0) {
a_typeless_ptr -= (n - 1) * std::abs(str_a) * a_elemsize;
}
if (str_b < 0) {
b_typeless_ptr -= (n - 1) * std::abs(str_b) * b_elemsize;
}

sycl::event dot_ev = dot_fn(exec_q, n, a_typeless_ptr, str_a,
b_typeless_ptr, str_b, r_typeless_ptr, depends);

sycl::event args_ev = dpctl::utils::keep_args_alive(
exec_q, {vectorA, vectorB, result}, {dot_ev});

return std::make_pair(args_ev, dot_ev);
}

template <typename fnT, typename Tab, typename Tc>
struct DotContigFactory
{
fnT get()
{
if constexpr (types::DotTypePairSupportFactory<Tab, Tc>::is_defined) {
return dot_impl<Tab, Tc>;
}
else {
return nullptr;
}
}
};

void init_dot_dispatch_table(void)
{
dpctl_td_ns::DispatchTableBuilder<dot_impl_fn_ptr_t, DotContigFactory,
dpctl_td_ns::num_types>
contig;
contig.populate_dispatch_table(dot_dispatch_table);
}
} // namespace blas
} // namespace ext
} // namespace backend
} // namespace dpnp
60 changes: 60 additions & 0 deletions dpnp/backend/extensions/blas/dot.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
//*****************************************************************************
// Copyright (c) 2024, Intel Corporation
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
// - Redistributions of source code must retain the above copyright notice,
// this list of conditions and the following disclaimer.
// - Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
// THE POSSIBILITY OF SUCH DAMAGE.
//*****************************************************************************

#pragma once

#include <CL/sycl.hpp>
#include <oneapi/mkl.hpp>

#include <dpctl4pybind11.hpp>

namespace dpnp
{
namespace backend
{
namespace ext
{
namespace blas
{
extern std::pair<sycl::event, sycl::event>
dot(sycl::queue &exec_q,
dpctl::tensor::usm_ndarray vectorA,
dpctl::tensor::usm_ndarray vectorB,
dpctl::tensor::usm_ndarray result,
const std::vector<sycl::event> &depends);

extern std::pair<sycl::event, sycl::event>
dotu(sycl::queue &exec_q,
dpctl::tensor::usm_ndarray vectorA,
dpctl::tensor::usm_ndarray vectorB,
dpctl::tensor::usm_ndarray result,
const std::vector<sycl::event> &depends);

extern void init_dot_dispatch_table(void);
extern void init_dotu_dispatch_table(void);
} // namespace blas
} // namespace ext
} // namespace backend
} // namespace dpnp
Loading

0 comments on commit ac1fca7

Please sign in to comment.