Skip to content

Commit

Permalink
add fallback to dpcpp sparselib csr
Browse files Browse the repository at this point in the history
  • Loading branch information
yhmtsai committed Apr 20, 2023
1 parent 4d2daf2 commit 8108f0f
Showing 1 changed file with 131 additions and 110 deletions.
241 changes: 131 additions & 110 deletions dpcpp/matrix/csr_kernels.dp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1234,6 +1234,80 @@ void load_balance_spmv(std::shared_ptr<const DpcppExecutor> exec,
}


template <typename ValueType, typename IndexType>
bool try_general_sparselib_spmv(std::shared_ptr<const DpcppExecutor> exec,
const ValueType host_alpha,
const matrix::Csr<ValueType, IndexType>* a,
const matrix::Dense<ValueType>* b,
const ValueType host_beta,
matrix::Dense<ValueType>* c)
{
bool try_sparselib = !is_complex<ValueType>();
if (try_sparselib) {
oneapi::mkl::sparse::matrix_handle_t mat_handle;
oneapi::mkl::sparse::init_matrix_handle(&mat_handle);
oneapi::mkl::sparse::set_csr_data(
mat_handle, IndexType(a->get_size()[0]),
IndexType(a->get_size()[1]), oneapi::mkl::index_base::zero,
const_cast<IndexType*>(a->get_const_row_ptrs()),
const_cast<IndexType*>(a->get_const_col_idxs()),
const_cast<ValueType*>(a->get_const_values()));
if (b->get_size()[1] == 1 && b->get_stride() == 1) {
oneapi::mkl::sparse::gemv(
*exec->get_queue(), oneapi::mkl::transpose::nontrans,
host_alpha, mat_handle,
const_cast<ValueType*>(b->get_const_values()), host_beta,
c->get_values());
} else {
oneapi::mkl::sparse::gemm(
*exec->get_queue(), oneapi::mkl::layout::row_major,
oneapi::mkl::transpose::nontrans,
oneapi::mkl::transpose::nontrans, host_alpha, mat_handle,
const_cast<ValueType*>(b->get_const_values()), b->get_size()[1],
b->get_stride(), host_beta, c->get_values(), c->get_stride());
}
oneapi::mkl::sparse::release_matrix_handle(&mat_handle);
}
return try_sparselib;
}


template <typename MatrixValueType, typename InputValueType,
typename OutputValueType, typename IndexType,
typename = std::enable_if_t<
!std::is_same<MatrixValueType, InputValueType>::value ||
!std::is_same<MatrixValueType, OutputValueType>::value>>
bool try_sparselib_spmv(std::shared_ptr<const DpcppExecutor> exec,
const matrix::Csr<MatrixValueType, IndexType>* a,
const matrix::Dense<InputValueType>* b,
matrix::Dense<OutputValueType>* c,
const matrix::Dense<MatrixValueType>* alpha = nullptr,
const matrix::Dense<OutputValueType>* beta = nullptr)
{
// TODO: support sparselib mixed
return false;
}

template <typename ValueType, typename IndexType>
bool try_sparselib_spmv(std::shared_ptr<const DpcppExecutor> exec,
const matrix::Csr<ValueType, IndexType>* a,
const matrix::Dense<ValueType>* b,
matrix::Dense<ValueType>* c,
const matrix::Dense<ValueType>* alpha = nullptr,
const matrix::Dense<ValueType>* beta = nullptr)
{
// onemkl only supports host scalar
if (alpha) {
return try_general_sparselib_spmv(
exec, exec->copy_val_to_host(alpha->get_const_values()), a, b,
exec->copy_val_to_host(beta->get_const_values()), c);
} else {
return try_general_sparselib_spmv(exec, one<ValueType>(), a, b,
zero<ValueType>(), c);
}
}


} // namespace host_kernel


Expand Down Expand Up @@ -1267,61 +1341,36 @@ void spmv(std::shared_ptr<const DpcppExecutor> exec,
return items_per_thread == compiled_info;
},
syn::value_list<int>(), syn::type_list<>(), exec, a, b, c);
} else if (a->get_strategy()->get_name() == "classical") {
IndexType max_length_per_row = 0;
using Tcsr = matrix::Csr<MatrixValueType, IndexType>;
if (auto strategy =
std::dynamic_pointer_cast<const typename Tcsr::classical>(
a->get_strategy())) {
max_length_per_row = strategy->get_max_length_per_row();
} else if (auto strategy = std::dynamic_pointer_cast<
const typename Tcsr::automatical>(a->get_strategy())) {
max_length_per_row = strategy->get_max_length_per_row();
} else {
GKO_NOT_SUPPORTED(a->get_strategy());
} else {
bool use_classical = true;
if (a->get_strategy()->get_name() == "sparselib" ||
a->get_strategy()->get_name() == "cusparse") {
use_classical = !host_kernel::try_sparselib_spmv(exec, a, b, c);
}
max_length_per_row = std::max<size_type>(max_length_per_row, 1);
host_kernel::select_classical_spmv(
classical_kernels(),
[&max_length_per_row](int compiled_info) {
return max_length_per_row >= compiled_info;
},
syn::value_list<int>(), syn::type_list<>(), exec, a, b, c);
} else if (a->get_strategy()->get_name() == "sparselib" ||
a->get_strategy()->get_name() == "cusparse") {
if constexpr (!is_complex<MatrixValueType>() &&
std::is_same<MatrixValueType, InputValueType>::value &&
std::is_same<MatrixValueType, OutputValueType>::value) {
oneapi::mkl::sparse::matrix_handle_t mat_handle;
oneapi::mkl::sparse::init_matrix_handle(&mat_handle);
oneapi::mkl::sparse::set_csr_data(
mat_handle, IndexType(a->get_size()[0]),
IndexType(a->get_size()[1]), oneapi::mkl::index_base::zero,
const_cast<IndexType*>(a->get_const_row_ptrs()),
const_cast<IndexType*>(a->get_const_col_idxs()),
const_cast<MatrixValueType*>(a->get_const_values()));
if (b->get_size()[1] == 1 && b->get_stride() == 1) {
oneapi::mkl::sparse::gemv(
*exec->get_queue(), oneapi::mkl::transpose::nontrans,
one<MatrixValueType>(), mat_handle,
const_cast<MatrixValueType*>(b->get_const_values()),
zero<MatrixValueType>(), c->get_values());
if (use_classical) {
IndexType max_length_per_row = 0;
using Tcsr = matrix::Csr<MatrixValueType, IndexType>;
if (auto strategy =
std::dynamic_pointer_cast<const typename Tcsr::classical>(
a->get_strategy())) {
max_length_per_row = strategy->get_max_length_per_row();
} else if (auto strategy = std::dynamic_pointer_cast<
const typename Tcsr::automatical>(
a->get_strategy())) {
max_length_per_row = strategy->get_max_length_per_row();
} else {
oneapi::mkl::sparse::gemm(
*exec->get_queue(), oneapi::mkl::layout::row_major,
oneapi::mkl::transpose::nontrans,
oneapi::mkl::transpose::nontrans, one<MatrixValueType>(),
mat_handle,
const_cast<MatrixValueType*>(b->get_const_values()),
b->get_size()[1], b->get_stride(), zero<MatrixValueType>(),
c->get_values(), c->get_stride());
// as a fall-back: use average row length, at least 1
max_length_per_row = a->get_num_stored_elements() /
std::max<size_type>(a->get_size()[0], 1);
}
oneapi::mkl::sparse::release_matrix_handle(&mat_handle);
} else {
GKO_NOT_IMPLEMENTED;
max_length_per_row = std::max<size_type>(max_length_per_row, 1);
host_kernel::select_classical_spmv(
classical_kernels(),
[&max_length_per_row](int compiled_info) {
return max_length_per_row >= compiled_info;
},
syn::value_list<int>(), syn::type_list<>(), exec, a, b, c);
}
} else {
GKO_NOT_IMPLEMENTED;
}
}

Expand Down Expand Up @@ -1349,64 +1398,6 @@ void advanced_spmv(std::shared_ptr<const DpcppExecutor> exec,
}
if (a->get_strategy()->get_name() == "load_balance") {
host_kernel::load_balance_spmv(exec, a, b, c, alpha, beta);
} else if (a->get_strategy()->get_name() == "sparselib" ||
a->get_strategy()->get_name() == "cusparse") {
if constexpr (!is_complex<MatrixValueType>() &&
std::is_same<MatrixValueType, InputValueType>::value &&
std::is_same<MatrixValueType, OutputValueType>::value) {
oneapi::mkl::sparse::matrix_handle_t mat_handle;
oneapi::mkl::sparse::init_matrix_handle(&mat_handle);
oneapi::mkl::sparse::set_csr_data(
mat_handle, IndexType(a->get_size()[0]),
IndexType(a->get_size()[1]), oneapi::mkl::index_base::zero,
const_cast<IndexType*>(a->get_const_row_ptrs()),
const_cast<IndexType*>(a->get_const_col_idxs()),
const_cast<MatrixValueType*>(a->get_const_values()));
if (b->get_size()[1] == 1 && b->get_stride() == 1) {
oneapi::mkl::sparse::gemv(
*exec->get_queue(), oneapi::mkl::transpose::nontrans,
exec->copy_val_to_host(alpha->get_const_values()),
mat_handle,
const_cast<MatrixValueType*>(b->get_const_values()),
exec->copy_val_to_host(beta->get_const_values()),
c->get_values());
} else {
oneapi::mkl::sparse::gemm(
*exec->get_queue(), oneapi::mkl::layout::row_major,
oneapi::mkl::transpose::nontrans,
oneapi::mkl::transpose::nontrans,
exec->copy_val_to_host(alpha->get_const_values()),
mat_handle,
const_cast<MatrixValueType*>(b->get_const_values()),
b->get_size()[1], b->get_stride(),
exec->copy_val_to_host(beta->get_const_values()),
c->get_values(), c->get_stride());
}
oneapi::mkl::sparse::release_matrix_handle(&mat_handle);
} else {
GKO_NOT_IMPLEMENTED;
}
} else if (a->get_strategy()->get_name() == "classical") {
IndexType max_length_per_row = 0;
using Tcsr = matrix::Csr<MatrixValueType, IndexType>;
if (auto strategy =
std::dynamic_pointer_cast<const typename Tcsr::classical>(
a->get_strategy())) {
max_length_per_row = strategy->get_max_length_per_row();
} else if (auto strategy = std::dynamic_pointer_cast<
const typename Tcsr::automatical>(a->get_strategy())) {
max_length_per_row = strategy->get_max_length_per_row();
} else {
GKO_NOT_SUPPORTED(a->get_strategy());
}
max_length_per_row = std::max<size_type>(max_length_per_row, 1);
host_kernel::select_classical_spmv(
classical_kernels(),
[&max_length_per_row](int compiled_info) {
return max_length_per_row >= compiled_info;
},
syn::value_list<int>(), syn::type_list<>(), exec, a, b, c, alpha,
beta);
} else if (a->get_strategy()->get_name() == "merge_path") {
using arithmetic_type =
highest_precision<InputValueType, OutputValueType, MatrixValueType>;
Expand All @@ -1421,7 +1412,37 @@ void advanced_spmv(std::shared_ptr<const DpcppExecutor> exec,
syn::value_list<int>(), syn::type_list<>(), exec, a, b, c, alpha,
beta);
} else {
GKO_NOT_IMPLEMENTED;
bool use_classical = true;
if (a->get_strategy()->get_name() == "sparselib" ||
a->get_strategy()->get_name() == "cusparse") {
use_classical =
!host_kernel::try_sparselib_spmv(exec, a, b, c, alpha, beta);
}
if (use_classical) {
IndexType max_length_per_row = 0;
using Tcsr = matrix::Csr<MatrixValueType, IndexType>;
if (auto strategy =
std::dynamic_pointer_cast<const typename Tcsr::classical>(
a->get_strategy())) {
max_length_per_row = strategy->get_max_length_per_row();
} else if (auto strategy = std::dynamic_pointer_cast<
const typename Tcsr::automatical>(
a->get_strategy())) {
max_length_per_row = strategy->get_max_length_per_row();
} else {
// as a fall-back: use average row length, at least 1
max_length_per_row = a->get_num_stored_elements() /
std::max<size_type>(a->get_size()[0], 1);
}
max_length_per_row = std::max<size_type>(max_length_per_row, 1);
host_kernel::select_classical_spmv(
classical_kernels(),
[&max_length_per_row](int compiled_info) {
return max_length_per_row >= compiled_info;
},
syn::value_list<int>(), syn::type_list<>(), exec, a, b, c,
alpha, beta);
}
}
}

Expand Down

0 comments on commit 8108f0f

Please sign in to comment.