Implemented BLAS backend to work with oneMKL Interfaces #1981

Merged 11 commits on Aug 26, 2024
2 changes: 1 addition & 1 deletion CMakeLists.txt
@@ -108,7 +108,7 @@ if(_use_onemkl_interfaces)
set(BUILD_FUNCTIONAL_TESTS False)
set(BUILD_EXAMPLES False)
if(_use_onemkl_interfaces_cuda)
# set(ENABLE_CUBLAS_BACKEND True)
set(ENABLE_CUBLAS_BACKEND True)
set(ENABLE_CUSOLVER_BACKEND True)
set(ENABLE_CUFFT_BACKEND True)
# set(ENABLE_CURAND_BACKEND True)
24 changes: 23 additions & 1 deletion dpnp/backend/extensions/blas/CMakeLists.txt
@@ -35,6 +35,20 @@ set(_module_src
pybind11_add_module(${python_module_name} MODULE ${_module_src})
add_sycl_to_target(TARGET ${python_module_name} SOURCES ${_module_src})

if(_dpnp_sycl_targets)
# make fat binary
target_compile_options(
${python_module_name}
PRIVATE
-fsycl-targets=${_dpnp_sycl_targets}
)
target_link_options(
${python_module_name}
PRIVATE
-fsycl-targets=${_dpnp_sycl_targets}
)
endif()

if (WIN32)
if (${CMAKE_VERSION} VERSION_LESS "3.27")
# this is a work-around for target_link_options inserting option after -link option, cause
@@ -69,7 +83,15 @@ if (DPNP_GENERATE_COVERAGE)
target_link_options(${python_module_name} PRIVATE -fprofile-instr-generate -fcoverage-mapping)
endif()

target_link_libraries(${python_module_name} PUBLIC MKL::MKL_SYCL::BLAS)
if(_use_onemkl_interfaces)
target_link_libraries(${python_module_name} PUBLIC onemkl)
target_compile_options(${python_module_name} PRIVATE -DUSE_ONEMKL_INTERFACES)
if(_use_onemkl_interfaces_cuda)
target_compile_options(${python_module_name} PRIVATE -DUSE_ONEMKL_CUBLAS)
endif()
else()
target_link_libraries(${python_module_name} PUBLIC MKL::MKL_SYCL::BLAS)
endif()

install(TARGETS ${python_module_name}
DESTINATION "dpnp/backend/extensions/blas"
10 changes: 10 additions & 0 deletions dpnp/backend/extensions/blas/blas_py.cpp
@@ -142,5 +142,15 @@ PYBIND11_MODULE(_blas_impl, m)
py::arg("sycl_queue"), py::arg("matrixA"), py::arg("vectorX"),
py::arg("vectorY"), py::arg("transpose"),
py::arg("depends") = py::list());
m.def(
"_row_major_is_available",
[](void) {
#if defined(USE_ONEMKL_CUBLAS)
return false;
#else
return true;
#endif // USE_ONEMKL_CUBLAS
},
"Check if the onemkl::blas::row_major can be used.");
}
}
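The new _row_major_is_available helper lets the Python layer pick a layout strategy up front. A minimal sketch of querying it follows; the import path is an assumption based on the module name and install destination above, and the dpnp utilities further below access the same module through the alias bi.

    import dpnp.backend.extensions.blas._blas_impl as bi  # assumed import path

    use_row_major = bi._row_major_is_available()
    # True for the MKL backend; False when the extension is built against
    # cuBLAS through oneMKL Interfaces, which exposes only column_major routines.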
16 changes: 8 additions & 8 deletions dpnp/backend/extensions/blas/dot.hpp
@@ -53,14 +53,14 @@ static sycl::event dot_impl(sycl::queue &exec_q,

sycl::event dot_event;
try {
dot_event = mkl_blas::row_major::dot(exec_q,
n, // size of the input vectors
x, // Pointer to vector x.
incx, // Stride of vector x.
y, // Pointer to vector y.
incy, // Stride of vector y.
res, // Pointer to result.
depends);
dot_event = mkl_blas::column_major::dot(exec_q,
n, // size of the input vectors
x, // Pointer to vector x.
incx, // Stride of vector x.
y, // Pointer to vector y.
incy, // Stride of vector y.
res, // Pointer to result.
depends);
} catch (oneapi::mkl::exception const &e) {
error_msg
<< "Unexpected MKL exception caught during dot() call:\nreason: "
17 changes: 9 additions & 8 deletions dpnp/backend/extensions/blas/dotc.hpp
@@ -53,14 +53,15 @@ static sycl::event dotc_impl(sycl::queue &exec_q,

sycl::event dotc_event;
try {
dotc_event = mkl_blas::row_major::dotc(exec_q,
n, // size of the input vectors
x, // Pointer to vector x.
incx, // Stride of vector x.
y, // Pointer to vector y.
incy, // Stride of vector y.
res, // Pointer to result.
depends);
dotc_event =
mkl_blas::column_major::dotc(exec_q,
n, // size of the input vectors
x, // Pointer to vector x.
incx, // Stride of vector x.
y, // Pointer to vector y.
incy, // Stride of vector y.
res, // Pointer to result.
depends);
} catch (oneapi::mkl::exception const &e) {
error_msg
<< "Unexpected MKL exception caught during dotc() call:\nreason: "
17 changes: 9 additions & 8 deletions dpnp/backend/extensions/blas/dotu.hpp
@@ -53,14 +53,15 @@ static sycl::event dotu_impl(sycl::queue &exec_q,

sycl::event dotu_event;
try {
dotu_event = mkl_blas::row_major::dotu(exec_q,
n, // size of the input vectors
x, // Pointer to vector x.
incx, // Stride of vector x.
y, // Pointer to vector y.
incy, // Stride of vector y.
res, // Pointer to result.
depends);
dotu_event =
mkl_blas::column_major::dotu(exec_q,
n, // size of the input vectors
x, // Pointer to vector x.
incx, // Stride of vector x.
y, // Pointer to vector y.
incy, // Stride of vector y.
res, // Pointer to result.
depends);
} catch (oneapi::mkl::exception const &e) {
error_msg
<< "Unexpected MKL exception caught during dotu() call:\nreason: "
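For these 1-D kernels the switch from row_major to column_major does not change results, since the layout namespaces only differ in how matrices are interpreted; column_major is simply the variant that stays available when _row_major_is_available() reports false. A small illustrative check, assuming dpnp.dot and dpnp.vdot dispatch to these dot/dotu/dotc kernels for complex inputs:

    import numpy as np
    import dpnp

    x = dpnp.array([1 + 2j, 3 - 1j, 0.5 + 0j], dtype=dpnp.complex64)
    y = dpnp.array([2 - 1j, 1 + 1j, 4 + 3j], dtype=dpnp.complex64)

    # dot (no conjugation, dotu-like) and vdot (conjugated x, dotc-like)
    assert np.allclose(dpnp.asnumpy(dpnp.dot(x, y)),
                       np.dot(dpnp.asnumpy(x), dpnp.asnumpy(y)))
    assert np.allclose(dpnp.asnumpy(dpnp.vdot(x, y)),
                       np.vdot(dpnp.asnumpy(x), dpnp.asnumpy(y)))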
4 changes: 4 additions & 0 deletions dpnp/backend/extensions/blas/gemm.cpp
@@ -95,9 +95,13 @@ static sycl::event gemm_impl(sycl::queue &exec_q,
const std::int64_t ldb, Tab beta, Tc *c, const std::int64_t ldc,
const std::vector<sycl::event> &deps) -> sycl::event {
if (is_row_major) {
#if defined(USE_ONEMKL_CUBLAS)
throw py::value_error("Input matrices are not f-contiguous");
#else
return mkl_blas::row_major::gemm(q, transA, transB, m, n, k,
alpha, a, lda, b, ldb, beta, c,
ldc, deps);
#endif // USE_ONEMKL_CUBLAS
}
else {
return mkl_blas::column_major::gemm(q, transA, transB, m, n, k,
13 changes: 10 additions & 3 deletions dpnp/backend/extensions/blas/gemm_batch.cpp
@@ -107,9 +107,14 @@ static sycl::event gemm_batch_impl(sycl::queue &exec_q,
const std::int64_t batch_size,
const std::vector<sycl::event> &deps) -> sycl::event {
if (is_row_major) {
#if defined(USE_ONEMKL_CUBLAS)
throw py::value_error(
"last 2-dimensions of input matrices are not f-contiguous");
#else
return mkl_blas::row_major::gemm_batch(
q, transA, transB, m, n, k, alpha, a, lda, stridea, b, ldb,
strideb, beta, c, ldc, stridec, batch_size, deps);
#endif // USE_ONEMKL_CUBLAS
}
else {
return mkl_blas::column_major::gemm_batch(
@@ -273,11 +278,12 @@ std::tuple<sycl::event, sycl::event, bool>
standardize_strides_to_nonzero(b_stride, b_shape);
standardize_strides_to_nonzero(c_stride, c_shape);
const bool A_base_is_f_contig =
a_stride[1] == 1 && a_stride[2] == a_shape[1];
(a_stride[1] == 1 || a_stride[1] == a_shape[2]) &&
a_stride[2] == a_shape[1];
const bool A_base_is_c_contig =
a_stride[1] == a_shape[2] && a_stride[2] == 1;
const bool B_base_is_f_contig =
b_stride[1] == 1 && b_stride[2] == b_shape[1];
b_stride[1] == 1 && (b_stride[2] == b_shape[1] || b_stride[2] == 1);
const bool B_base_is_c_contig =
b_stride[1] == b_shape[2] && b_stride[2] == 1;
const bool C_base_is_f_contig =
@@ -380,7 +386,8 @@ struct GemmBatchContigFactory
fnT get()
{
if constexpr (types::GemmBatchTypePairSupportFactory<Tab,
Tc>::is_defined) {
Tc>::is_defined)
{
return gemm_batch_impl<Tab, Tc>;
}
else {
7 changes: 6 additions & 1 deletion dpnp/backend/extensions/blas/gemv.cpp
@@ -88,8 +88,12 @@ static sycl::event gemv_impl(sycl::queue &exec_q,
T beta, T *y, const std::int64_t incy,
const std::vector<sycl::event> &deps) -> sycl::event {
if (is_row_major) {
#if defined(USE_ONEMKL_CUBLAS)
throw py::value_error("Input matrix is not f-contiguous");
#else
return mkl_blas::row_major::gemv(q, transA, m, n, alpha, a, lda,
x, incx, beta, y, incy, deps);
#endif // USE_ONEMKL_CUBLAS
}
else {
return mkl_blas::column_major::gemv(q, transA, m, n, alpha, a,
@@ -223,7 +227,8 @@ std::pair<sycl::event, sycl::event>
const int vectorY_typenum = vectorY.get_typenum();

if (matrixA_typenum != vectorX_typenum ||
matrixA_typenum != vectorY_typenum) {
matrixA_typenum != vectorY_typenum)
{
throw py::value_error("Given arrays must be of the same type.");
}

4 changes: 4 additions & 0 deletions dpnp/backend/extensions/blas/types_matrix.hpp
@@ -110,8 +110,10 @@ template <typename Tab, typename Tc>
struct GemmTypePairSupportFactory
{
static constexpr bool is_defined = std::disjunction<
#if !defined(USE_ONEMKL_INTERFACES)
dpctl_td_ns::TypePairDefinedEntry<Tab, std::int8_t, Tc, std::int32_t>,
dpctl_td_ns::TypePairDefinedEntry<Tab, std::int8_t, Tc, float>,
#endif // USE_ONEMKL_INTERFACES
dpctl_td_ns::TypePairDefinedEntry<Tab, sycl::half, Tc, float>,
dpctl_td_ns::TypePairDefinedEntry<Tab, sycl::half, Tc, sycl::half>,
dpctl_td_ns::TypePairDefinedEntry<Tab, float, Tc, float>,
@@ -140,8 +142,10 @@ template <typename Tab, typename Tc>
struct GemmBatchTypePairSupportFactory
{
static constexpr bool is_defined = std::disjunction<
#if !defined(USE_ONEMKL_INTERFACES)
dpctl_td_ns::TypePairDefinedEntry<Tab, std::int8_t, Tc, std::int32_t>,
dpctl_td_ns::TypePairDefinedEntry<Tab, std::int8_t, Tc, float>,
#endif // USE_ONEMKL_INTERFACES
dpctl_td_ns::TypePairDefinedEntry<Tab, sycl::half, Tc, float>,
dpctl_td_ns::TypePairDefinedEntry<Tab, sycl::half, Tc, sycl::half>,
dpctl_td_ns::TypePairDefinedEntry<Tab, float, Tc, float>,
55 changes: 47 additions & 8 deletions dpnp/dpnp_utils/dpnp_utils_linearalgebra.py
@@ -325,12 +325,28 @@ def _get_result_shape(x1, x2, out, np_flag):

def _gemm_batch_matmul(exec_q, x1, x2, res):
# arrays here are already at least 3D, make them 3D
x1_shape = x1.shape
x2_shape = x2.shape
x1 = dpnp.reshape(x1, (-1, x1_shape[-2], x1_shape[-1]))
x2 = dpnp.reshape(x2, (-1, x2_shape[-2], x2_shape[-1]))
orig_shape = res.shape
res = dpnp.reshape(res, (-1, orig_shape[-2], orig_shape[-1]))
if not bi._row_major_is_available():
tmp = x1
x1 = x2
x2 = tmp
x1_shape = x1.shape
x2_shape = x2.shape
if not x1.flags.c_contiguous:
x1 = dpnp.asarray(x1, order="C")
if not x2.flags.c_contiguous:
x2 = dpnp.asarray(x2, order="C")
x1 = dpnp.reshape(x1, (-1, x1_shape[-2], x1_shape[-1]))
x2 = dpnp.reshape(x2, (-1, x2_shape[-2], x2_shape[-1]))
res = dpnp.reshape(res, (-1, orig_shape[-1], orig_shape[-2]))
x1 = x1.transpose(0, 2, 1)
x2 = x2.transpose(0, 2, 1)
else:
x1_shape = x1.shape
x2_shape = x2.shape
x1 = dpnp.reshape(x1, (-1, x1_shape[-2], x1_shape[-1]))
x2 = dpnp.reshape(x2, (-1, x2_shape[-2], x2_shape[-1]))
res = dpnp.reshape(res, (-1, orig_shape[-2], orig_shape[-1]))
res_shape = res.shape

# gemm_batch does not handle negative strides, make a copy if needed
@@ -383,7 +399,7 @@ def _gemm_batch_matmul(exec_q, x1, x2, res):
.reshape(res_shape[1], res_shape[2], batch_size)
.transpose(2, 0, 1)
)
else:
elif bi._row_major_is_available():
if res_is_c_contig:
# read data of each 2D array in the batch in "C" order and
# write it in "F" order
@@ -419,6 +435,12 @@ def _gemm_matmul(exec_q, x1, x2, res):
if res.flags.c_contiguous is True:
# read data in "C" order and write it in "F" order
res = dpnp.ravel(res, order="C").reshape(res.shape, order="F")
elif (
not bi._row_major_is_available() and res.flags.f_contiguous is True
):
# read data in "C" order and write it in "C" order
# make result similar for row_major call
res = dpnp.ravel(res, order="C").reshape(res.shape, order="C")

return res
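The operand swap and transposes in the column-major fallback of _gemm_batch_matmul above rest on the identity (A·B)^T = B^T·A^T: a row-major product can be recovered from a column-major-style gemm on the swapped, transposed operands. A minimal NumPy sketch of the same idea, illustrative only and independent of dpnp:

    import numpy as np

    a = np.random.rand(2, 3, 4)  # batch of two (3, 4) matrices
    b = np.random.rand(2, 4, 5)  # batch of two (4, 5) matrices

    # Swap the operands, transpose each matrix in the batch, multiply, and
    # transpose the result back: this reproduces a @ b using only the
    # "reversed" (column-major-style) product.
    c_fallback = np.matmul(b.transpose(0, 2, 1),
                           a.transpose(0, 2, 1)).transpose(0, 2, 1)

    assert np.allclose(c_fallback, np.matmul(a, b))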

@@ -794,7 +816,25 @@ def dpnp_matmul(
x1 = dpnp.reshape(x1, x1.size)
x2 = dpnp.reshape(x2, x2_shape[-2:])
res_shape = (x2_shape[-1],)
if not bi._row_major_is_available() and x2.flags.c_contiguous:
x2 = dpnp.asarray(x2, order="F")
elif x1_is_2D and x2_is_1D:
call_flag = "gemv"
x1 = dpnp.reshape(x1, x1_shape[-2:])
x2 = dpnp.reshape(x2, x2.size)
res_shape = (x1_shape[-2],)
if not bi._row_major_is_available() and x1.flags.c_contiguous:
x1 = dpnp.asarray(x1, order="F")
elif x1_is_2D and x2_is_2D:
call_flag = "gemm"
x1 = dpnp.reshape(x1, x1_shape[-2:])
x2 = dpnp.reshape(x2, x2_shape[-2:])
res_shape = (x1_shape[-2], x2_shape[-1])
if not bi._row_major_is_available():
if x1.flags.c_contiguous:
x1 = dpnp.asarray(x1, order="F")
if x2.flags.c_contiguous:
x2 = dpnp.asarray(x2, order="F")
elif x1_base_is_1D:
# TODO: implement gemv_batch to use it here with transpose
call_flag = "gemm_batch"
@@ -839,6 +870,14 @@
x2_contig_flag, _, x2_f = _define_contig_flag(x2)

res_order = "F" if (x1_f and x2_f and call_flag == "gemm") else "C"

if bi._row_major_is_available():
array_order = res_order
elif (call_flag == "gemm") or (call_flag == "gemv"):
array_order = "F"
else:
array_order = "C"

result = _create_result_array(
x1,
x2,
@@ -862,13 +901,13 @@
x1,
copy_flag=not x1_contig_flag,
dtype=compute_dtype,
order=res_order,
order=array_order,
)
x2 = _copy_array(
x2,
copy_flag=not x2_contig_flag,
dtype=compute_dtype,
order=res_order,
order=array_order,
)

if call_flag == "gemv":
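Taken together, these dpnp_matmul changes keep the cuBLAS-backed build transparent to callers: when row_major is unavailable, gemm/gemv inputs are copied to "F" order and results are reordered to match the default backend. A quick sanity check one might run on either build, illustrative only:

    import numpy as np
    import dpnp

    a = dpnp.reshape(dpnp.arange(6, dtype=dpnp.float32), (2, 3))        # C-contiguous
    b = dpnp.asarray(dpnp.ones((3, 4), dtype=dpnp.float32), order="F")  # F-contiguous

    c = dpnp.matmul(a, b)
    assert np.allclose(dpnp.asnumpy(c),
                       np.matmul(dpnp.asnumpy(a), dpnp.asnumpy(b)))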
2 changes: 1 addition & 1 deletion scripts/build_locally.py
@@ -170,7 +170,7 @@ def run(
type=str,
)
driver.add_argument(
"--onemkl_interfaces",
"--onemkl-interfaces",
help="Build using oneMKL Interfaces",
dest="onemkl_interfaces",
action="store_true",
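With the renamed option, a local build against oneMKL Interfaces can be requested with, for example, python scripts/build_locally.py --onemkl-interfaces; any further options, such as one selecting the CUDA backend, are outside this diff.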