Skip to content

Commit

Permalink
update BLAS extension routines
Browse files Browse the repository at this point in the history
  • Loading branch information
vtavana committed Jun 15, 2024
1 parent 38fd39d commit 77d387d
Show file tree
Hide file tree
Showing 11 changed files with 192 additions and 275 deletions.
57 changes: 30 additions & 27 deletions dpnp/backend/extensions/blas/blas_py.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,17 +37,17 @@
#include "gemm.hpp"
#include "gemv.hpp"

namespace blas_ext = dpnp::backend::ext::blas;
namespace blas_ns = dpnp::extensions::blas;
namespace py = pybind11;
namespace dot_ext = blas_ext::dot;
using dot_ext::dot_impl_fn_ptr_t;
namespace dot_ns = blas_ns::dot;
using dot_ns::dot_impl_fn_ptr_t;

// populate dispatch vectors and tables
void init_dispatch_vectors_tables(void)
{
blas_ext::init_gemm_batch_dispatch_table();
blas_ext::init_gemm_dispatch_table();
blas_ext::init_gemv_dispatch_vector();
blas_ns::init_gemm_batch_dispatch_table();
blas_ns::init_gemm_dispatch_table();
blas_ns::init_gemv_dispatch_vector();
}

static dot_impl_fn_ptr_t dot_dispatch_vector[dpctl_td_ns::num_types];
Expand All @@ -62,14 +62,15 @@ PYBIND11_MODULE(_blas_impl, m)
using event_vecT = std::vector<sycl::event>;

{
dot_ext::init_dot_dispatch_vector<dot_impl_fn_ptr_t,
blas_ext::DotContigFactory>(
dot_ns::init_dot_dispatch_vector<dot_impl_fn_ptr_t,
blas_ns::DotContigFactory>(
dot_dispatch_vector);

auto dot_pyapi = [&](sycl::queue exec_q, arrayT src1, arrayT src2,
arrayT dst, const event_vecT &depends = {}) {
return dot_ext::dot_func(exec_q, src1, src2, dst, depends,
dot_dispatch_vector);
auto dot_pyapi = [&](sycl::queue &exec_q, const arrayT &src1,
const arrayT &src2, const arrayT &dst,
const event_vecT &depends = {}) {
return dot_ns::dot_func(exec_q, src1, src2, dst, depends,
dot_dispatch_vector);
};

m.def("_dot", dot_pyapi,
Expand All @@ -80,14 +81,15 @@ PYBIND11_MODULE(_blas_impl, m)
}

{
dot_ext::init_dot_dispatch_vector<dot_impl_fn_ptr_t,
blas_ext::DotcContigFactory>(
dot_ns::init_dot_dispatch_vector<dot_impl_fn_ptr_t,
blas_ns::DotcContigFactory>(
dotc_dispatch_vector);

auto dotc_pyapi = [&](sycl::queue exec_q, arrayT src1, arrayT src2,
arrayT dst, const event_vecT &depends = {}) {
return dot_ext::dot_func(exec_q, src1, src2, dst, depends,
dotc_dispatch_vector);
auto dotc_pyapi = [&](sycl::queue &exec_q, const arrayT &src1,
const arrayT &src2, const arrayT &dst,
const event_vecT &depends = {}) {
return dot_ns::dot_func(exec_q, src1, src2, dst, depends,
dotc_dispatch_vector);
};

m.def("_dotc", dotc_pyapi,
Expand All @@ -99,14 +101,15 @@ PYBIND11_MODULE(_blas_impl, m)
}

{
dot_ext::init_dot_dispatch_vector<dot_impl_fn_ptr_t,
blas_ext::DotuContigFactory>(
dot_ns::init_dot_dispatch_vector<dot_impl_fn_ptr_t,
blas_ns::DotuContigFactory>(
dotu_dispatch_vector);

auto dotu_pyapi = [&](sycl::queue exec_q, arrayT src1, arrayT src2,
arrayT dst, const event_vecT &depends = {}) {
return dot_ext::dot_func(exec_q, src1, src2, dst, depends,
dotu_dispatch_vector);
auto dotu_pyapi = [&](sycl::queue &exec_q, const arrayT &src1,
const arrayT &src2, const arrayT &dst,
const event_vecT &depends = {}) {
return dot_ns::dot_func(exec_q, src1, src2, dst, depends,
dotu_dispatch_vector);
};

m.def("_dotu", dotu_pyapi,
Expand All @@ -117,23 +120,23 @@ PYBIND11_MODULE(_blas_impl, m)
}

{
m.def("_gemm", &blas_ext::gemm,
m.def("_gemm", &blas_ns::gemm,
"Call `gemm` from OneMKL BLAS library to compute "
"the matrix-matrix product with 2-D matrices.",
py::arg("sycl_queue"), py::arg("matrixA"), py::arg("matrixB"),
py::arg("resultC"), py::arg("depends") = py::list());
}

{
m.def("_gemm_batch", &blas_ext::gemm_batch,
m.def("_gemm_batch", &blas_ns::gemm_batch,
"Call `gemm_batch` from OneMKL BLAS library to compute "
"the matrix-matrix product for a batch of 2-D matrices.",
py::arg("sycl_queue"), py::arg("matrixA"), py::arg("matrixB"),
py::arg("resultC"), py::arg("depends") = py::list());
}

{
m.def("_gemv", &blas_ext::gemv,
m.def("_gemv", &blas_ns::gemv,
"Call `gemv` from OneMKL BLAS library to compute "
"the matrix-vector product with a general matrix.",
py::arg("sycl_queue"), py::arg("matrixA"), py::arg("vectorX"),
Expand Down
21 changes: 6 additions & 15 deletions dpnp/backend/extensions/blas/dot.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,31 +27,25 @@

#include "dot_common.hpp"

namespace dpnp
{
namespace backend
{
namespace ext
{
namespace blas
namespace dpnp::extensions::blas
{
namespace mkl_blas = oneapi::mkl::blas;
namespace type_utils = dpctl::tensor::type_utils;

template <typename T>
static sycl::event dot_impl(sycl::queue &exec_q,
const std::int64_t n,
char *vectorX,
const char *vectorX,
const std::int64_t incx,
char *vectorY,
const char *vectorY,
const std::int64_t incy,
char *result,
const std::vector<sycl::event> &depends)
{
type_utils::validate_type_for_device<T>(exec_q);

T *x = reinterpret_cast<T *>(vectorX);
T *y = reinterpret_cast<T *>(vectorY);
const T *x = reinterpret_cast<const T *>(vectorX);
const T *y = reinterpret_cast<const T *>(vectorY);
T *res = reinterpret_cast<T *>(result);

std::stringstream error_msg;
Expand Down Expand Up @@ -99,7 +93,4 @@ struct DotContigFactory
}
}
};
} // namespace blas
} // namespace ext
} // namespace backend
} // namespace dpnp
} // namespace dpnp::extensions::blas
44 changes: 16 additions & 28 deletions dpnp/backend/extensions/blas/dot_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,21 +36,13 @@

#include "types_matrix.hpp"

namespace dpnp
{
namespace backend
{
namespace ext
{
namespace blas
{
namespace dot
namespace dpnp::extensions::blas::dot
{
typedef sycl::event (*dot_impl_fn_ptr_t)(sycl::queue &,
const std::int64_t,
char *,
const char *,
const std::int64_t,
char *,
const char *,
const std::int64_t,
char *,
const std::vector<sycl::event> &);
Expand All @@ -61,9 +53,9 @@ namespace py = pybind11;
template <typename dispatchT>
std::pair<sycl::event, sycl::event>
dot_func(sycl::queue &exec_q,
dpctl::tensor::usm_ndarray vectorX,
dpctl::tensor::usm_ndarray vectorY,
dpctl::tensor::usm_ndarray result,
const dpctl::tensor::usm_ndarray &vectorX,
const dpctl::tensor::usm_ndarray &vectorY,
const dpctl::tensor::usm_ndarray &result,
const std::vector<sycl::event> &depends,
const dispatchT &dot_dispatch_vector)
{
Expand Down Expand Up @@ -109,30 +101,30 @@ std::pair<sycl::event, sycl::event>
"USM allocations are not compatible with the execution queue.");
}

size_t src_nelems = 1;
const int src_nelems = 1;
dpctl::tensor::validation::CheckWritable::throw_if_not_writable(result);
dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(result,
src_nelems);

py::ssize_t x_size = vectorX.get_size();
py::ssize_t y_size = vectorY.get_size();
const py::ssize_t x_size = vectorX.get_size();
const py::ssize_t y_size = vectorY.get_size();
const std::int64_t n = x_size;
if (x_size != y_size) {
throw py::value_error("The size of the first input array must be "
"equal to the size of the second input array.");
}

int vectorX_typenum = vectorX.get_typenum();
int vectorY_typenum = vectorY.get_typenum();
int result_typenum = result.get_typenum();
const int vectorX_typenum = vectorX.get_typenum();
const int vectorY_typenum = vectorY.get_typenum();
const int result_typenum = result.get_typenum();

if (result_typenum != vectorX_typenum || result_typenum != vectorY_typenum)
{
throw py::value_error("Given arrays must be of the same type.");
}

auto array_types = dpctl_td_ns::usm_ndarray_types();
int type_id = array_types.typenum_to_lookup_id(vectorX_typenum);
const int type_id = array_types.typenum_to_lookup_id(vectorX_typenum);

dot_impl_fn_ptr_t dot_fn = dot_dispatch_vector[type_id];
if (dot_fn == nullptr) {
Expand All @@ -144,8 +136,8 @@ std::pair<sycl::event, sycl::event>
char *y_typeless_ptr = vectorY.get_data();
char *r_typeless_ptr = result.get_data();

std::vector<py::ssize_t> x_stride = vectorX.get_strides_vector();
std::vector<py::ssize_t> y_stride = vectorY.get_strides_vector();
const std::vector<py::ssize_t> x_stride = vectorX.get_strides_vector();
const std::vector<py::ssize_t> y_stride = vectorY.get_strides_vector();
const int x_elemsize = vectorX.get_elemsize();
const int y_elemsize = vectorY.get_elemsize();

Expand Down Expand Up @@ -184,8 +176,4 @@ void init_dot_dispatch_vector(dispatchT dot_dispatch_vector[])
contig;
contig.populate_dispatch_vector(dot_dispatch_vector);
}
} // namespace dot
} // namespace blas
} // namespace ext
} // namespace backend
} // namespace dpnp
} // namespace dpnp::extensions::blas::dot
21 changes: 6 additions & 15 deletions dpnp/backend/extensions/blas/dotc.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,31 +27,25 @@

#include "dot_common.hpp"

namespace dpnp
{
namespace backend
{
namespace ext
{
namespace blas
namespace dpnp::extensions::blas
{
namespace mkl_blas = oneapi::mkl::blas;
namespace type_utils = dpctl::tensor::type_utils;

template <typename T>
static sycl::event dotc_impl(sycl::queue &exec_q,
const std::int64_t n,
char *vectorX,
const char *vectorX,
const std::int64_t incx,
char *vectorY,
const char *vectorY,
const std::int64_t incy,
char *result,
const std::vector<sycl::event> &depends)
{
type_utils::validate_type_for_device<T>(exec_q);

T *x = reinterpret_cast<T *>(vectorX);
T *y = reinterpret_cast<T *>(vectorY);
const T *x = reinterpret_cast<const T *>(vectorX);
const T *y = reinterpret_cast<const T *>(vectorY);
T *res = reinterpret_cast<T *>(result);

std::stringstream error_msg;
Expand Down Expand Up @@ -100,7 +94,4 @@ struct DotcContigFactory
}
};

} // namespace blas
} // namespace ext
} // namespace backend
} // namespace dpnp
} // namespace dpnp::extensions::blas
21 changes: 6 additions & 15 deletions dpnp/backend/extensions/blas/dotu.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,31 +27,25 @@

#include "dot_common.hpp"

namespace dpnp
{
namespace backend
{
namespace ext
{
namespace blas
namespace dpnp::extensions::blas
{
namespace mkl_blas = oneapi::mkl::blas;
namespace type_utils = dpctl::tensor::type_utils;

template <typename T>
static sycl::event dotu_impl(sycl::queue &exec_q,
const std::int64_t n,
char *vectorX,
const char *vectorX,
const std::int64_t incx,
char *vectorY,
const char *vectorY,
const std::int64_t incy,
char *result,
const std::vector<sycl::event> &depends)
{
type_utils::validate_type_for_device<T>(exec_q);

T *x = reinterpret_cast<T *>(vectorX);
T *y = reinterpret_cast<T *>(vectorY);
const T *x = reinterpret_cast<const T *>(vectorX);
const T *y = reinterpret_cast<const T *>(vectorY);
T *res = reinterpret_cast<T *>(result);

std::stringstream error_msg;
Expand Down Expand Up @@ -99,7 +93,4 @@ struct DotuContigFactory
}
}
};
} // namespace blas
} // namespace ext
} // namespace backend
} // namespace dpnp
} // namespace dpnp::extensions::blas
Loading

0 comments on commit 77d387d

Please sign in to comment.