Skip to content

Commit

Permalink
Fix2130 (kokkos#2132)
Browse files Browse the repository at this point in the history
* Fix kokkos#2130

- Do not call BsrMatrix spmv impl if block size is 1
- Instead, convert it to unmanaged CrsMatrix and call spmv again
  - cuSPARSE returned an error code in this case
  - Better performance

* Formatting

* Remove redundant remove_pointer_t

Handle is already a non-pointer type
  • Loading branch information
brian-kelley committed Mar 14, 2024
1 parent bdbaeae commit 629c423
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 10 deletions.
44 changes: 35 additions & 9 deletions sparse/src/KokkosSparse_spmv.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,15 +110,20 @@ void spmv(const ExecutionSpace& space, Handle* handle, const char mode[],
"KokkosSparse::spmv: Output Vector must be non-const.");

// Check that A, X, Y types match that of the Handle
static_assert(
std::is_same_v<AMatrix, typename Handle::AMatrixType>,
"KokkosSparse::spmv: AMatrix must be identical to Handle::AMatrixType");
static_assert(
std::is_same_v<XVector, typename Handle::XVectorType>,
"KokkosSparse::spmv: XVector must be identical to Handle::XVectorType");
static_assert(
std::is_same_v<YVector, typename Handle::YVectorType>,
"KokkosSparse::spmv: YVector must be identical to Handle::YVectorType");
// But only check this if Handle is the user-facing type (SPMVHandle).
// We may internally call spmv with SPMVHandleImpl, which does not include
// the matrix and vector types.
if constexpr (KokkosSparse::Impl::is_spmv_handle_v<Handle>) {
static_assert(
std::is_same_v<AMatrix, typename Handle::AMatrixType>,
"KokkosSparse::spmv: AMatrix must be identical to Handle::AMatrixType");
static_assert(
std::is_same_v<XVector, typename Handle::XVectorType>,
"KokkosSparse::spmv: XVector must be identical to Handle::XVectorType");
static_assert(
std::is_same_v<YVector, typename Handle::YVectorType>,
"KokkosSparse::spmv: YVector must be identical to Handle::YVectorType");
}

constexpr bool isBSR = Experimental::is_bsr_matrix_v<AMatrix>;

Expand Down Expand Up @@ -167,6 +172,7 @@ void spmv(const ExecutionSpace& space, Handle* handle, const char mode[],
return;
}

// Get the "impl" parent class of Handle, if it's not already the impl
using HandleImpl = typename Handle::ImplType;

using ACrs_Internal = CrsMatrix<
Expand All @@ -181,6 +187,26 @@ void spmv(const ExecutionSpace& space, Handle* handle, const char mode[],
using AMatrix_Internal =
std::conditional_t<isBSR, ABsr_Internal, ACrs_Internal>;

// Intercept special case: A is a BsrMatrix with blockDim() == 1
// This is exactly equivalent to CrsMatrix (more performant)
// and cuSPARSE actually errors out in that case.
//
// This relies on the fact that this codepath will always be taken for
// this particular matrix (so internally, this handle is only ever used for
// Crs)
if constexpr (isBSR) {
if (A.blockDim() == 1) {
// Construct an ACrs_Internal (unmanaged memory) from A's views
typename ACrs_Internal::row_map_type rowmap(A.graph.row_map);
typename ACrs_Internal::index_type entries(A.graph.entries);
typename ACrs_Internal::values_type values(A.values);
ACrs_Internal ACrs(std::string{}, A.numRows(), A.numCols(), A.nnz(),
values, rowmap, entries);
spmv(space, handle->get_impl(), mode, alpha, ACrs, x, beta, y);
return;
}
}

AMatrix_Internal A_i(A);

// Note: data_type of a View includes both the scalar and rank
Expand Down
29 changes: 28 additions & 1 deletion sparse/src/KokkosSparse_spmv_handle.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,9 @@ template <class ExecutionSpace, class MemorySpace, class Scalar, class Offset,
class Ordinal>
struct SPMVHandleImpl {
using ExecutionSpaceType = ExecutionSpace;
// This is its own ImplType
using ImplType =
SPMVHandleImpl<ExecutionSpace, MemorySpace, Scalar, Offset, Ordinal>;
// Do not allow const qualifier on Scalar, Ordinal, Offset (otherwise this
// type won't match the ETI'd type). Users should not use SPMVHandleImpl
// directly and SPMVHandle explicitly removes const, so this should never
Expand All @@ -268,6 +271,10 @@ struct SPMVHandleImpl {
void set_exec_space(const ExecutionSpace& exec) {
if (tpl) tpl->set_exec_space(exec);
}

/// Get the SPMVAlgorithm used by this handle
SPMVAlgorithm get_algorithm() const { return this->algo; }

bool is_set_up = false;
const SPMVAlgorithm algo = SPMV_DEFAULT;
TPL_SpMV_Data<ExecutionSpace>* tpl = nullptr;
Expand Down Expand Up @@ -385,9 +392,29 @@ struct SPMVHandle
}

/// Get the SPMVAlgorithm used by this handle
SPMVAlgorithm get_algorithm() const { return this->algo; }
SPMVAlgorithm get_algorithm() const {
// Note: get_algorithm is also a method of parent ImplType, but for
// documentation purposes it should appear directly in the public interface
// of SPMVHandle
return this->algo;
}

/// Get pointer to this as the impl type
ImplType* get_impl() { return static_cast<ImplType*>(this); }
};

namespace Impl {
template <typename>
struct is_spmv_handle : public std::false_type {};
template <typename... P>
struct is_spmv_handle<SPMVHandle<P...>> : public std::true_type {};
template <typename... P>
struct is_spmv_handle<const SPMVHandle<P...>> : public std::true_type {};

template <typename T>
inline constexpr bool is_spmv_handle_v = is_spmv_handle<T>::value;
} // namespace Impl

} // namespace KokkosSparse

#endif

0 comments on commit 629c423

Please sign in to comment.