diff --git a/core/device_hooks/common_kernels.inc.cpp b/core/device_hooks/common_kernels.inc.cpp index ee19b5cd594..59105572137 100644 --- a/core/device_hooks/common_kernels.inc.cpp +++ b/core/device_hooks/common_kernels.inc.cpp @@ -185,6 +185,12 @@ GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_CG_STEP_2_KERNEL); namespace lower_trs { +GKO_DECLARE_LOWER_TRS_SHOULD_PERFORM_TRANSPOSE_KERNEL() +GKO_NOT_COMPILED(GKO_HOOK_MODULE); + +GKO_DECLARE_LOWER_TRS_INIT_STRUCT_KERNEL() +GKO_NOT_COMPILED(GKO_HOOK_MODULE); + template GKO_DECLARE_LOWER_TRS_GENERATE_KERNEL(ValueType, IndexType) GKO_NOT_COMPILED(GKO_HOOK_MODULE); diff --git a/core/solver/lower_trs.cpp b/core/solver/lower_trs.cpp index b0345cf586a..987e8609573 100644 --- a/core/solver/lower_trs.cpp +++ b/core/solver/lower_trs.cpp @@ -48,23 +48,32 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. namespace gko { namespace solver { - - namespace lower_trs { GKO_REGISTER_OPERATION(generate, lower_trs::generate); +GKO_REGISTER_OPERATION(init_struct, lower_trs::init_struct); +GKO_REGISTER_OPERATION(should_perform_transpose, + lower_trs::should_perform_transpose); GKO_REGISTER_OPERATION(solve, lower_trs::solve); } // namespace lower_trs +template +void LowerTrs::init_trs_solve_struct() +{ + this->get_executor()->run(lower_trs::make_init_struct(this->solve_struct_)); +} + + template void LowerTrs::generate() { - this->get_executor()->run( - lower_trs::make_generate(gko::lend(system_matrix_), gko::lend(b_))); + this->get_executor()->run(lower_trs::make_generate( + gko::lend(system_matrix_), gko::lend(this->solve_struct_), + parameters_.num_rhs)); } @@ -77,8 +86,26 @@ void LowerTrs::apply_impl(const LinOp *b, LinOp *x) const auto dense_b = as(b); auto dense_x = as(x); - exec->run( - lower_trs::make_solve(gko::lend(system_matrix_), dense_b, dense_x)); + // This kernel checks if a transpose is needed for the multiple rhs case. + // Currently only the algorithm for CUDA version <=9.1 needs this + // transposition due to the limitation in the cusparse algorithm. The other + // executors (omp and reference) do not use the transpose (trans_x and + // trans_b) and hence are passed in empty pointers. + bool do_transpose = false; + std::shared_ptr trans_b; + std::shared_ptr trans_x; + this->get_executor()->run( + lower_trs::make_should_perform_transpose(do_transpose)); + if (do_transpose) { + trans_b = Vector::create(exec, gko::transpose(dense_b->get_size())); + trans_x = Vector::create(exec, gko::transpose(dense_x->get_size())); + } else { + trans_b = Vector::create(exec); + trans_x = Vector::create(exec); + } + exec->run(lower_trs::make_solve( + gko::lend(system_matrix_), gko::lend(this->solve_struct_), + gko::lend(trans_b), gko::lend(trans_x), dense_b, dense_x)); } diff --git a/core/solver/lower_trs_kernels.hpp b/core/solver/lower_trs_kernels.hpp index 611c898ae44..b2c931d76cf 100644 --- a/core/solver/lower_trs_kernels.hpp +++ b/core/solver/lower_trs_kernels.hpp @@ -40,6 +40,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include #include #include +#include namespace gko { @@ -47,19 +48,34 @@ namespace kernels { namespace lower_trs { +#define GKO_DECLARE_LOWER_TRS_SHOULD_PERFORM_TRANSPOSE_KERNEL() \ + void should_perform_transpose(std::shared_ptr exec, \ + bool &do_transpose) + + +#define GKO_DECLARE_LOWER_TRS_INIT_STRUCT_KERNEL() \ + void init_struct(std::shared_ptr exec, \ + std::shared_ptr &solve_struct) + + #define GKO_DECLARE_LOWER_TRS_GENERATE_KERNEL(_vtype, _itype) \ void generate(std::shared_ptr exec, \ const matrix::Csr<_vtype, _itype> *matrix, \ - const matrix::Dense<_vtype> *b) + solver::SolveStruct *solve_struct, \ + const gko::size_type num_rhs) -#define GKO_DECLARE_LOWER_TRS_SOLVE_KERNEL(_vtype, _itype) \ - void solve(std::shared_ptr exec, \ - const matrix::Csr<_vtype, _itype> *matrix, \ +#define GKO_DECLARE_LOWER_TRS_SOLVE_KERNEL(_vtype, _itype) \ + void solve(std::shared_ptr exec, \ + const matrix::Csr<_vtype, _itype> *matrix, \ + const solver::SolveStruct *solve_struct, \ + matrix::Dense<_vtype> *trans_b, matrix::Dense<_vtype> *trans_x, \ const matrix::Dense<_vtype> *b, matrix::Dense<_vtype> *x) #define GKO_DECLARE_ALL_AS_TEMPLATES \ + GKO_DECLARE_LOWER_TRS_SHOULD_PERFORM_TRANSPOSE_KERNEL(); \ + GKO_DECLARE_LOWER_TRS_INIT_STRUCT_KERNEL(); \ template \ GKO_DECLARE_LOWER_TRS_SOLVE_KERNEL(ValueType, IndexType); \ template \ diff --git a/core/test/solver/lower_trs.cpp b/core/test/solver/lower_trs.cpp index fca789448f9..f08ad46bf87 100644 --- a/core/test/solver/lower_trs.cpp +++ b/core/test/solver/lower_trs.cpp @@ -43,6 +43,9 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include +#include "core/test/utils/assertions.hpp" + + namespace { diff --git a/cuda/base/cusparse_bindings.hpp b/cuda/base/cusparse_bindings.hpp index b79cf00f193..e9da6b9952b 100644 --- a/cuda/base/cusparse_bindings.hpp +++ b/cuda/base/cusparse_bindings.hpp @@ -34,6 +34,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #define GKO_CUDA_BASE_CUSPARSE_BINDINGS_HPP_ +#include #include @@ -44,6 +45,87 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. namespace gko { +namespace solver { + + +#if (defined(CUDA_VERSION) && (CUDA_VERSION >= 9020)) + + +struct SolveStruct { + int algorithm; + csrsm2Info_t solve_info; + cusparseSolvePolicy_t policy; + cusparseMatDescr_t factor_descr; + size_t factor_work_size; + void *factor_work_vec; + SolveStruct() + { + factor_work_vec = nullptr; + GKO_ASSERT_NO_CUSPARSE_ERRORS(cusparseCreateMatDescr(&factor_descr)); + GKO_ASSERT_NO_CUSPARSE_ERRORS( + cusparseSetMatIndexBase(factor_descr, CUSPARSE_INDEX_BASE_ZERO)); + GKO_ASSERT_NO_CUSPARSE_ERRORS( + cusparseSetMatType(factor_descr, CUSPARSE_MATRIX_TYPE_GENERAL)); + GKO_ASSERT_NO_CUSPARSE_ERRORS( + cusparseSetMatDiagType(factor_descr, CUSPARSE_DIAG_TYPE_NON_UNIT)); + GKO_ASSERT_NO_CUSPARSE_ERRORS(cusparseCreateCsrsm2Info(&solve_info)); + algorithm = 0; + policy = CUSPARSE_SOLVE_POLICY_USE_LEVEL; + } + SolveStruct(const SolveStruct &) : SolveStruct() {} + SolveStruct(SolveStruct &&) : SolveStruct() {} + SolveStruct &operator=(const SolveStruct &) { return *this; } + SolveStruct &operator=(SolveStruct &&) { return *this; } + ~SolveStruct() + { + cusparseDestroyMatDescr(factor_descr); + if (solve_info) { + cusparseDestroyCsrsm2Info(solve_info); + } + if (factor_work_vec != nullptr) { + cudaFree(factor_work_vec); + factor_work_vec = nullptr; + } + } +}; + + +#elif (defined(CUDA_VERSION) && (CUDA_VERSION < 9020)) + + +struct SolveStruct { + cusparseSolveAnalysisInfo_t solve_info; + cusparseMatDescr_t factor_descr; + SolveStruct() + { + GKO_ASSERT_NO_CUSPARSE_ERRORS( + cusparseCreateSolveAnalysisInfo(&solve_info)); + GKO_ASSERT_NO_CUSPARSE_ERRORS(cusparseCreateMatDescr(&factor_descr)); + GKO_ASSERT_NO_CUSPARSE_ERRORS( + cusparseSetMatIndexBase(factor_descr, CUSPARSE_INDEX_BASE_ZERO)); + GKO_ASSERT_NO_CUSPARSE_ERRORS( + cusparseSetMatType(factor_descr, CUSPARSE_MATRIX_TYPE_GENERAL)); + GKO_ASSERT_NO_CUSPARSE_ERRORS( + cusparseSetMatDiagType(factor_descr, CUSPARSE_DIAG_TYPE_NON_UNIT)); + } + SolveStruct(const SolveStruct &) : SolveStruct() {} + SolveStruct(SolveStruct &&) : SolveStruct() {} + SolveStruct &operator=(const SolveStruct &) { return *this; } + SolveStruct &operator=(SolveStruct &&) { return *this; } + ~SolveStruct() + { + cusparseDestroyMatDescr(factor_descr); + cusparseDestroySolveAnalysisInfo(solve_info); + } +}; + + +#endif + + +} // namespace solver + + namespace kernels { namespace cuda { /** @@ -491,6 +573,264 @@ inline void destroy(cusparseMatDescr_t descr) } +// CUDA versions 9.2 and above have csrsm2. +#if (defined(CUDA_VERSION) && (CUDA_VERSION >= 9020)) + + +#define GKO_BIND_CUSPARSE32_BUFFERSIZEEXT(ValueType, CusparseName) \ + inline void buffer_size_ext( \ + cusparseHandle_t handle, int algo, cusparseOperation_t trans1, \ + cusparseOperation_t trans2, size_type m, size_type n, size_type nnz, \ + const ValueType *one, const cusparseMatDescr_t descr, \ + const ValueType *csrVal, const int32 *csrRowPtr, \ + const int32 *csrColInd, const ValueType *rhs, int32 sol_size, \ + csrsm2Info_t factor_info, cusparseSolvePolicy_t policy, \ + size_t *factor_work_size) \ + { \ + GKO_ASSERT_NO_CUSPARSE_ERRORS( \ + CusparseName(handle, algo, trans1, trans2, m, n, nnz, \ + as_culibs_type(one), descr, as_culibs_type(csrVal), \ + csrRowPtr, csrColInd, as_culibs_type(rhs), sol_size, \ + factor_info, policy, factor_work_size)); \ + } \ + static_assert(true, \ + "This assert is used to counter the false positive extra " \ + "semi-colon warnings") + +#define GKO_BIND_CUSPARSE64_BUFFERSIZEEXT(ValueType, CusparseName) \ + inline void buffer_size_ext( \ + cusparseHandle_t handle, int algo, cusparseOperation_t trans1, \ + cusparseOperation_t trans2, size_type m, size_type n, size_type nnz, \ + const ValueType *one, const cusparseMatDescr_t descr, \ + const ValueType *csrVal, const int64 *csrRowPtr, \ + const int64 *csrColInd, const ValueType *rhs, int64 sol_size, \ + csrsm2Info_t factor_info, cusparseSolvePolicy_t policy, \ + size_t *factor_work_size) GKO_NOT_IMPLEMENTED; \ + static_assert(true, \ + "This assert is used to counter the false positive extra " \ + "semi-colon warnings") + +GKO_BIND_CUSPARSE32_BUFFERSIZEEXT(float, cusparseScsrsm2_bufferSizeExt); +GKO_BIND_CUSPARSE32_BUFFERSIZEEXT(double, cusparseDcsrsm2_bufferSizeExt); +GKO_BIND_CUSPARSE32_BUFFERSIZEEXT(std::complex, + cusparseCcsrsm2_bufferSizeExt); +GKO_BIND_CUSPARSE32_BUFFERSIZEEXT(std::complex, + cusparseZcsrsm2_bufferSizeExt); +GKO_BIND_CUSPARSE64_BUFFERSIZEEXT(float, cusparseScsrsm2_bufferSizeExt); +GKO_BIND_CUSPARSE64_BUFFERSIZEEXT(double, cusparseDcsrsm2_bufferSizeExt); +GKO_BIND_CUSPARSE64_BUFFERSIZEEXT(std::complex, + cusparseCcsrsm2_bufferSizeExt); +GKO_BIND_CUSPARSE64_BUFFERSIZEEXT(std::complex, + cusparseZcsrsm2_bufferSizeExt); +template +GKO_BIND_CUSPARSE32_BUFFERSIZEEXT(ValueType, detail::not_implemented); +template +GKO_BIND_CUSPARSE64_BUFFERSIZEEXT(ValueType, detail::not_implemented); +#undef GKO_BIND_CUSPARSE32_BUFFERSIZEEXT +#undef GKO_BIND_CUSPARSE64_BUFFERSIZEEXT + + +#define GKO_BIND_CUSPARSE32_CSRSM2_ANALYSIS(ValueType, CusparseName) \ + inline void csrsm2_analysis( \ + cusparseHandle_t handle, int algo, cusparseOperation_t trans1, \ + cusparseOperation_t trans2, size_type m, size_type n, size_type nnz, \ + const ValueType *one, const cusparseMatDescr_t descr, \ + const ValueType *csrVal, const int32 *csrRowPtr, \ + const int32 *csrColInd, const ValueType *rhs, int32 sol_size, \ + csrsm2Info_t factor_info, cusparseSolvePolicy_t policy, \ + void *factor_work_vec) \ + { \ + GKO_ASSERT_NO_CUSPARSE_ERRORS( \ + CusparseName(handle, algo, trans1, trans2, m, n, nnz, \ + as_culibs_type(one), descr, as_culibs_type(csrVal), \ + csrRowPtr, csrColInd, as_culibs_type(rhs), sol_size, \ + factor_info, policy, factor_work_vec)); \ + } \ + static_assert(true, \ + "This assert is used to counter the false positive extra " \ + "semi-colon warnings") + +#define GKO_BIND_CUSPARSE64_CSRSM2_ANALYSIS(ValueType, CusparseName) \ + inline void csrsm2_analysis( \ + cusparseHandle_t handle, int algo, cusparseOperation_t trans1, \ + cusparseOperation_t trans2, size_type m, size_type n, size_type nnz, \ + const ValueType *one, const cusparseMatDescr_t descr, \ + const ValueType *csrVal, const int64 *csrRowPtr, \ + const int64 *csrColInd, const ValueType *rhs, int64 sol_size, \ + csrsm2Info_t factor_info, cusparseSolvePolicy_t policy, \ + void *factor_work_vec) GKO_NOT_IMPLEMENTED; \ + static_assert(true, \ + "This assert is used to counter the false positive extra " \ + "semi-colon warnings") + +GKO_BIND_CUSPARSE32_CSRSM2_ANALYSIS(float, cusparseScsrsm2_analysis); +GKO_BIND_CUSPARSE32_CSRSM2_ANALYSIS(double, cusparseDcsrsm2_analysis); +GKO_BIND_CUSPARSE32_CSRSM2_ANALYSIS(std::complex, + cusparseCcsrsm2_analysis); +GKO_BIND_CUSPARSE32_CSRSM2_ANALYSIS(std::complex, + cusparseZcsrsm2_analysis); +GKO_BIND_CUSPARSE64_CSRSM2_ANALYSIS(float, cusparseScsrsm2_analysis); +GKO_BIND_CUSPARSE64_CSRSM2_ANALYSIS(double, cusparseDcsrsm2_analysis); +GKO_BIND_CUSPARSE64_CSRSM2_ANALYSIS(std::complex, + cusparseCcsrsm2_analysis); +GKO_BIND_CUSPARSE64_CSRSM2_ANALYSIS(std::complex, + cusparseZcsrsm2_analysis); +template +GKO_BIND_CUSPARSE32_CSRSM2_ANALYSIS(ValueType, detail::not_implemented); +template +GKO_BIND_CUSPARSE64_CSRSM2_ANALYSIS(ValueType, detail::not_implemented); +#undef GKO_BIND_CUSPARSE32_CSRSM2_ANALYSIS +#undef GKO_BIND_CUSPARSE64_CSRSM2_ANALYSIS + + +#define GKO_BIND_CUSPARSE32_CSRSM2_SOLVE(ValueType, CusparseName) \ + inline void csrsm2_solve( \ + cusparseHandle_t handle, int algo, cusparseOperation_t trans1, \ + cusparseOperation_t trans2, size_type m, size_type n, size_type nnz, \ + const ValueType *one, const cusparseMatDescr_t descr, \ + const ValueType *csrVal, const int32 *csrRowPtr, \ + const int32 *csrColInd, ValueType *rhs, int32 sol_stride, \ + csrsm2Info_t factor_info, cusparseSolvePolicy_t policy, \ + void *factor_work_vec) \ + { \ + GKO_ASSERT_NO_CUSPARSE_ERRORS( \ + CusparseName(handle, algo, trans1, trans2, m, n, nnz, \ + as_culibs_type(one), descr, as_culibs_type(csrVal), \ + csrRowPtr, csrColInd, as_culibs_type(rhs), \ + sol_stride, factor_info, policy, factor_work_vec)); \ + } \ + static_assert(true, \ + "This assert is used to counter the false positive extra " \ + "semi-colon warnings") + +#define GKO_BIND_CUSPARSE64_CSRSM2_SOLVE(ValueType, CusparseName) \ + inline void csrsm2_solve( \ + cusparseHandle_t handle, int algo, cusparseOperation_t trans1, \ + cusparseOperation_t trans2, size_type m, size_type n, size_type nnz, \ + const ValueType *one, const cusparseMatDescr_t descr, \ + const ValueType *csrVal, const int64 *csrRowPtr, \ + const int64 *csrColInd, ValueType *rhs, int64 sol_stride, \ + csrsm2Info_t factor_info, cusparseSolvePolicy_t policy, \ + void *factor_work_vec) GKO_NOT_IMPLEMENTED; \ + static_assert(true, \ + "This assert is used to counter the false positive extra " \ + "semi-colon warnings") + +GKO_BIND_CUSPARSE32_CSRSM2_SOLVE(float, cusparseScsrsm2_solve); +GKO_BIND_CUSPARSE32_CSRSM2_SOLVE(double, cusparseDcsrsm2_solve); +GKO_BIND_CUSPARSE32_CSRSM2_SOLVE(std::complex, cusparseCcsrsm2_solve); +GKO_BIND_CUSPARSE32_CSRSM2_SOLVE(std::complex, cusparseZcsrsm2_solve); +GKO_BIND_CUSPARSE64_CSRSM2_SOLVE(float, cusparseScsrsm2_solve); +GKO_BIND_CUSPARSE64_CSRSM2_SOLVE(double, cusparseDcsrsm2_solve); +GKO_BIND_CUSPARSE64_CSRSM2_SOLVE(std::complex, cusparseCcsrsm2_solve); +GKO_BIND_CUSPARSE64_CSRSM2_SOLVE(std::complex, cusparseZcsrsm2_solve); +template +GKO_BIND_CUSPARSE32_CSRSM2_SOLVE(ValueType, detail::not_implemented); +template +GKO_BIND_CUSPARSE64_CSRSM2_SOLVE(ValueType, detail::not_implemented); +#undef GKO_BIND_CUSPARSE32_CSRSM2_SOLVE +#undef GKO_BIND_CUSPARSE64_CSRSM2_SOLVE + + +// CUDA_VERSION<=9.1 do not support csrsm2. +#elif (defined(CUDA_VERSION) && (CUDA_VERSION < 9020)) + + +#define GKO_BIND_CUSPARSE32_CSRSM_ANALYSIS(ValueType, CusparseName) \ + inline void csrsm_analysis( \ + cusparseHandle_t handle, cusparseOperation_t trans, size_type m, \ + size_type nnz, const cusparseMatDescr_t descr, \ + const ValueType *csrVal, const int32 *csrRowPtr, \ + const int32 *csrColInd, cusparseSolveAnalysisInfo_t factor_info) \ + { \ + GKO_ASSERT_NO_CUSPARSE_ERRORS( \ + CusparseName(handle, trans, m, nnz, descr, as_culibs_type(csrVal), \ + csrRowPtr, csrColInd, factor_info)); \ + } \ + static_assert(true, \ + "This assert is used to counter the false positive extra " \ + "semi-colon warnings") + +#define GKO_BIND_CUSPARSE64_CSRSM_ANALYSIS(ValueType, CusparseName) \ + inline void csrsm_analysis( \ + cusparseHandle_t handle, cusparseOperation_t trans, size_type m, \ + size_type nnz, const cusparseMatDescr_t descr, \ + const ValueType *csrVal, const int64 *csrRowPtr, \ + const int64 *csrColInd, cusparseSolveAnalysisInfo_t factor_info) \ + GKO_NOT_IMPLEMENTED; \ + static_assert(true, \ + "This assert is used to counter the " \ + "false positive extra " \ + "semi-colon warnings") + +GKO_BIND_CUSPARSE32_CSRSM_ANALYSIS(float, cusparseScsrsm_analysis); +GKO_BIND_CUSPARSE32_CSRSM_ANALYSIS(double, cusparseDcsrsm_analysis); +GKO_BIND_CUSPARSE32_CSRSM_ANALYSIS(std::complex, + cusparseCcsrsm_analysis); +GKO_BIND_CUSPARSE32_CSRSM_ANALYSIS(std::complex, + cusparseZcsrsm_analysis); +GKO_BIND_CUSPARSE64_CSRSM_ANALYSIS(float, cusparseScsrsm_analysis); +GKO_BIND_CUSPARSE64_CSRSM_ANALYSIS(double, cusparseDcsrsm_analysis); +GKO_BIND_CUSPARSE64_CSRSM_ANALYSIS(std::complex, + cusparseCcsrsm_analysis); +GKO_BIND_CUSPARSE64_CSRSM_ANALYSIS(std::complex, + cusparseZcsrsm_analysis); +template +GKO_BIND_CUSPARSE32_CSRSM_ANALYSIS(ValueType, detail::not_implemented); +template +GKO_BIND_CUSPARSE64_CSRSM_ANALYSIS(ValueType, detail::not_implemented); +#undef GKO_BIND_CUSPARSE32_CSRSM_ANALYSIS +#undef GKO_BIND_CUSPARSE64_CSRSM_ANALYSIS + +#define GKO_BIND_CUSPARSE32_CSRSM_SOLVE(ValueType, CusparseName) \ + inline void csrsm_solve( \ + cusparseHandle_t handle, cusparseOperation_t trans, size_type m, \ + size_type n, const ValueType *one, const cusparseMatDescr_t descr, \ + const ValueType *csrVal, const int32 *csrRowPtr, \ + const int32 *csrColInd, cusparseSolveAnalysisInfo_t factor_info, \ + ValueType *rhs, int32 rhs_stride, ValueType *sol, int32 sol_stride) \ + { \ + GKO_ASSERT_NO_CUSPARSE_ERRORS( \ + CusparseName(handle, trans, m, n, as_culibs_type(one), descr, \ + as_culibs_type(csrVal), csrRowPtr, csrColInd, \ + factor_info, as_culibs_type(rhs), rhs_stride, \ + as_culibs_type(sol), sol_stride)); \ + } \ + static_assert(true, \ + "This assert is used to counter the false positive extra " \ + "semi-colon warnings") + +#define GKO_BIND_CUSPARSE64_CSRSM_SOLVE(ValueType, CusparseName) \ + inline void csrsm_solve( \ + cusparseHandle_t handle, cusparseOperation_t trans1, size_type m, \ + size_type n, const ValueType *one, const cusparseMatDescr_t descr, \ + const ValueType *csrVal, const int64 *csrRowPtr, \ + const int64 *csrColInd, cusparseSolveAnalysisInfo_t factor_info, \ + ValueType *rhs, int64 rhs_stride, ValueType *sol, int64 sol_stride) \ + GKO_NOT_IMPLEMENTED; \ + static_assert(true, \ + "This assert is used to counter the false positive extra " \ + "semi-colon warnings") + +GKO_BIND_CUSPARSE32_CSRSM_SOLVE(float, cusparseScsrsm_solve); +GKO_BIND_CUSPARSE32_CSRSM_SOLVE(double, cusparseDcsrsm_solve); +GKO_BIND_CUSPARSE32_CSRSM_SOLVE(std::complex, cusparseCcsrsm_solve); +GKO_BIND_CUSPARSE32_CSRSM_SOLVE(std::complex, cusparseZcsrsm_solve); +GKO_BIND_CUSPARSE64_CSRSM_SOLVE(float, cusparseScsrsm_solve); +GKO_BIND_CUSPARSE64_CSRSM_SOLVE(double, cusparseDcsrsm_solve); +GKO_BIND_CUSPARSE64_CSRSM_SOLVE(std::complex, cusparseCcsrsm_solve); +GKO_BIND_CUSPARSE64_CSRSM_SOLVE(std::complex, cusparseZcsrsm_solve); +template +GKO_BIND_CUSPARSE32_CSRSM_SOLVE(ValueType, detail::not_implemented); +template +GKO_BIND_CUSPARSE64_CSRSM_SOLVE(ValueType, detail::not_implemented); +#undef GKO_BIND_CUSPARSE32_CSRSM_SOLVE +#undef GKO_BIND_CUSPARSE64_CSRSM_SOLVE + + +#endif + + } // namespace cusparse } // namespace cuda } // namespace kernels diff --git a/cuda/components/atomic.cuh b/cuda/components/atomic.cuh index 3dd24eeb72c..8031fe70b7d 100644 --- a/cuda/components/atomic.cuh +++ b/cuda/components/atomic.cuh @@ -96,7 +96,7 @@ GKO_BIND_ATOMIC_HELPER_STRUCTURE(unsigned int); GKO_BIND_ATOMIC_HELPER_STRUCTURE(unsigned short int); #endif -#undef GKO_BIND_ATOMIC_HELPER__STRUCTURE +#undef GKO_BIND_ATOMIC_HELPER_STRUCTURE } // namespace detail diff --git a/cuda/solver/lower_trs_kernels.cu b/cuda/solver/lower_trs_kernels.cu index f032d7880e4..051fd11be09 100644 --- a/cuda/solver/lower_trs_kernels.cu +++ b/cuda/solver/lower_trs_kernels.cu @@ -33,11 +33,24 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include "core/solver/lower_trs_kernels.hpp" +#include +#include + + +#include +#include + + #include #include +#include +#include "core/matrix/dense_kernels.hpp" +#include "core/solver/lower_trs_kernels.hpp" +#include "core/synthesizer/implementation_selection.hpp" #include "cuda/base/cusparse_bindings.hpp" +#include "cuda/base/device_guard.hpp" #include "cuda/base/math.hpp" #include "cuda/base/types.hpp" @@ -53,10 +66,98 @@ namespace cuda { namespace lower_trs { +void should_perform_transpose(std::shared_ptr exec, + bool &do_transpose) +{ +#if (defined(CUDA_VERSION) && (CUDA_VERSION >= 9020)) + + + do_transpose = false; + + +#elif (defined(CUDA_VERSION) && (CUDA_VERSION < 9020)) + + + do_transpose = true; + + +#endif +} + + +void init_struct(std::shared_ptr exec, + std::shared_ptr &solve_struct) +{ + solve_struct = + std::shared_ptr(new solver::SolveStruct()); +} + + template void generate(std::shared_ptr exec, const matrix::Csr *matrix, - const matrix::Dense *b) GKO_NOT_IMPLEMENTED; + solver::SolveStruct *solve_struct, const gko::size_type num_rhs) +{ + if (cusparse::is_supported::value) { + auto handle = exec->get_cusparse_handle(); + + +#if (defined(CUDA_VERSION) && (CUDA_VERSION >= 9020)) + + + ValueType one = 1.0; + + GKO_ASSERT_NO_CUSPARSE_ERRORS( + cusparseSetPointerMode(handle, CUSPARSE_POINTER_MODE_HOST)); + cusparse::buffer_size_ext( + handle, solve_struct->algorithm, CUSPARSE_OPERATION_NON_TRANSPOSE, + CUSPARSE_OPERATION_TRANSPOSE, matrix->get_size()[0], num_rhs, + matrix->get_num_stored_elements(), &one, solve_struct->factor_descr, + matrix->get_const_values(), matrix->get_const_row_ptrs(), + matrix->get_const_col_idxs(), nullptr, num_rhs, + solve_struct->solve_info, solve_struct->policy, + &solve_struct->factor_work_size); + + // allocate workspace + if (solve_struct->factor_work_vec != nullptr) { + GKO_ASSERT_NO_CUDA_ERRORS(cudaFree(solve_struct->factor_work_vec)); + } + solve_struct->factor_work_vec = + exec->alloc(solve_struct->factor_work_size); + + cusparse::csrsm2_analysis( + handle, solve_struct->algorithm, CUSPARSE_OPERATION_NON_TRANSPOSE, + CUSPARSE_OPERATION_TRANSPOSE, matrix->get_size()[0], num_rhs, + matrix->get_num_stored_elements(), &one, solve_struct->factor_descr, + matrix->get_const_values(), matrix->get_const_row_ptrs(), + matrix->get_const_col_idxs(), nullptr, num_rhs, + solve_struct->solve_info, solve_struct->policy, + solve_struct->factor_work_vec); + GKO_ASSERT_NO_CUSPARSE_ERRORS( + cusparseSetPointerMode(handle, CUSPARSE_POINTER_MODE_DEVICE)); + + +#elif (defined(CUDA_VERSION) && (CUDA_VERSION < 9020)) + + + GKO_ASSERT_NO_CUSPARSE_ERRORS( + cusparseSetPointerMode(handle, CUSPARSE_POINTER_MODE_HOST)); + cusparse::csrsm_analysis( + handle, CUSPARSE_OPERATION_NON_TRANSPOSE, matrix->get_size()[0], + matrix->get_num_stored_elements(), solve_struct->factor_descr, + matrix->get_const_values(), matrix->get_const_row_ptrs(), + matrix->get_const_col_idxs(), solve_struct->solve_info); + GKO_ASSERT_NO_CUSPARSE_ERRORS( + cusparseSetPointerMode(handle, CUSPARSE_POINTER_MODE_DEVICE)); + + +#endif + + + } else { + GKO_NOT_IMPLEMENTED; + } +} GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( GKO_DECLARE_LOWER_TRS_GENERATE_KERNEL); @@ -65,8 +166,70 @@ GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( template void solve(std::shared_ptr exec, const matrix::Csr *matrix, - const matrix::Dense *b, - matrix::Dense *x) GKO_NOT_IMPLEMENTED; + const solver::SolveStruct *solve_struct, + matrix::Dense *trans_b, matrix::Dense *trans_x, + const matrix::Dense *b, matrix::Dense *x) +{ + using vec = matrix::Dense; + if (cusparse::is_supported::value) { + ValueType one = 1.0; + auto handle = exec->get_cusparse_handle(); + + +#if (defined(CUDA_VERSION) && (CUDA_VERSION >= 9020)) + + + x->copy_from(gko::lend(b)); + GKO_ASSERT_NO_CUSPARSE_ERRORS( + cusparseSetPointerMode(handle, CUSPARSE_POINTER_MODE_HOST)); + cusparse::csrsm2_solve( + handle, solve_struct->algorithm, CUSPARSE_OPERATION_NON_TRANSPOSE, + CUSPARSE_OPERATION_TRANSPOSE, matrix->get_size()[0], + b->get_stride(), matrix->get_num_stored_elements(), &one, + solve_struct->factor_descr, matrix->get_const_values(), + matrix->get_const_row_ptrs(), matrix->get_const_col_idxs(), + x->get_values(), b->get_stride(), solve_struct->solve_info, + solve_struct->policy, solve_struct->factor_work_vec); + GKO_ASSERT_NO_CUSPARSE_ERRORS( + cusparseSetPointerMode(handle, CUSPARSE_POINTER_MODE_DEVICE)); + + +#elif (defined(CUDA_VERSION) && (CUDA_VERSION < 9020)) + + + GKO_ASSERT_NO_CUSPARSE_ERRORS( + cusparseSetPointerMode(handle, CUSPARSE_POINTER_MODE_HOST)); + if (b->get_stride() == 1) { + auto temp_b = const_cast(b->get_const_values()); + cusparse::csrsm_solve( + handle, CUSPARSE_OPERATION_NON_TRANSPOSE, matrix->get_size()[0], + b->get_stride(), &one, solve_struct->factor_descr, + matrix->get_const_values(), matrix->get_const_row_ptrs(), + matrix->get_const_col_idxs(), solve_struct->solve_info, temp_b, + b->get_size()[0], x->get_values(), x->get_size()[0]); + } else { + dense::transpose(exec, trans_b, b); + dense::transpose(exec, trans_x, x); + cusparse::csrsm_solve( + handle, CUSPARSE_OPERATION_NON_TRANSPOSE, matrix->get_size()[0], + trans_b->get_size()[0], &one, solve_struct->factor_descr, + matrix->get_const_values(), matrix->get_const_row_ptrs(), + matrix->get_const_col_idxs(), solve_struct->solve_info, + trans_b->get_values(), trans_b->get_size()[1], + trans_x->get_values(), trans_x->get_size()[1]); + dense::transpose(exec, x, trans_x); + } + GKO_ASSERT_NO_CUSPARSE_ERRORS( + cusparseSetPointerMode(handle, CUSPARSE_POINTER_MODE_DEVICE)); + + +#endif + + + } else { + GKO_NOT_IMPLEMENTED; + } +} GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( GKO_DECLARE_LOWER_TRS_SOLVE_KERNEL); diff --git a/cuda/test/solver/CMakeLists.txt b/cuda/test/solver/CMakeLists.txt index 2db21d26e4c..ede4999a510 100644 --- a/cuda/test/solver/CMakeLists.txt +++ b/cuda/test/solver/CMakeLists.txt @@ -4,3 +4,4 @@ ginkgo_create_test(cgs_kernels) ginkgo_create_test(fcg_kernels) ginkgo_create_test(gmres_kernels) ginkgo_create_test(ir_kernels) +ginkgo_create_test(lower_trs_kernels) diff --git a/cuda/test/solver/lower_trs_kernels.cpp b/cuda/test/solver/lower_trs_kernels.cpp new file mode 100644 index 00000000000..47648a8b0e3 --- /dev/null +++ b/cuda/test/solver/lower_trs_kernels.cpp @@ -0,0 +1,179 @@ +/************************************************************* +Copyright (c) 2017-2019, the Ginkgo authors +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions +are met: + +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in the +documentation and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS +IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED +TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*************************************************************/ + +#include + + +#include +#include + + +#include +#include + + +#include +#include +#include +#include + + +#include "core/solver/lower_trs_kernels.hpp" +#include "core/test/utils.hpp" + + +namespace { + + +class LowerTrs : public ::testing::Test { +protected: + using CsrMtx = gko::matrix::Csr; + using Mtx = gko::matrix::Dense<>; + + LowerTrs() : rand_engine(30) {} + + void SetUp() + { + ASSERT_GT(gko::CudaExecutor::get_num_devices(), 0); + ref = gko::ReferenceExecutor::create(); + cuda = gko::CudaExecutor::create(0, ref); + } + + void TearDown() + { + if (cuda != nullptr) { + ASSERT_NO_THROW(cuda->synchronize()); + } + } + + std::unique_ptr gen_mtx(int num_rows, int num_cols) + { + return gko::test::generate_random_matrix( + num_rows, num_cols, + std::uniform_int_distribution<>(num_cols, num_cols), + std::normal_distribution<>(-1.0, 1.0), rand_engine, ref); + } + + std::unique_ptr gen_l_mtx(int num_rows, int num_cols) + { + return gko::test::generate_random_lower_triangular_matrix( + num_rows, num_cols, false, + std::uniform_int_distribution<>(num_cols, num_cols), + std::normal_distribution<>(-1.0, 1.0), rand_engine, ref); + } + + void initialize_data(int m, int n) + { + mtx = gen_l_mtx(m, m); + b = gen_mtx(m, n); + x = gen_mtx(m, n); + csr_mtx = CsrMtx::create(ref); + mtx->convert_to(csr_mtx.get()); + d_csr_mtx = CsrMtx::create(cuda); + d_x = Mtx::create(cuda); + d_x->copy_from(x.get()); + d_csr_mtx->copy_from(csr_mtx.get()); + b2 = Mtx::create(ref); + d_b2 = Mtx::create(cuda); + d_b2->copy_from(b.get()); + b2->copy_from(b.get()); + } + + std::shared_ptr b; + std::shared_ptr b2; + std::shared_ptr x; + std::shared_ptr mtx; + std::shared_ptr csr_mtx; + std::shared_ptr d_b; + std::shared_ptr d_b2; + std::shared_ptr d_x; + std::shared_ptr d_csr_mtx; + std::shared_ptr ref; + std::shared_ptr cuda; + std::ranlux48 rand_engine; +}; + + +TEST_F(LowerTrs, CudaLowerTrsFlagCheckIsCorrect) +{ + bool trans_flag = true; + bool expected_flag = false; + + +#if (defined(CUDA_VERSION) && (CUDA_VERSION < 9020)) + + + expected_flag = true; + + +#endif + + + gko::kernels::cuda::lower_trs::should_perform_transpose(cuda, trans_flag); + + ASSERT_EQ(expected_flag, trans_flag); +} + + +TEST_F(LowerTrs, CudaSingleRhsApplyIsEquivalentToRef) +{ + initialize_data(50, 1); + auto lower_trs_factory = gko::solver::LowerTrs<>::build().on(ref); + auto d_lower_trs_factory = gko::solver::LowerTrs<>::build().on(cuda); + auto solver = lower_trs_factory->generate(csr_mtx); + auto d_solver = d_lower_trs_factory->generate(d_csr_mtx); + + solver->apply(b2.get(), x.get()); + d_solver->apply(d_b2.get(), d_x.get()); + + GKO_ASSERT_MTX_NEAR(d_x, x, 1e-14); +} + + +TEST_F(LowerTrs, CudaMultipleRhsApplyIsEquivalentToRef) +{ + initialize_data(50, 3); + auto lower_trs_factory = + gko::solver::LowerTrs<>::build().with_num_rhs(3u).on(ref); + auto d_lower_trs_factory = + gko::solver::LowerTrs<>::build().with_num_rhs(3u).on(cuda); + auto solver = lower_trs_factory->generate(csr_mtx); + auto d_solver = d_lower_trs_factory->generate(d_csr_mtx); + + solver->apply(b2.get(), x.get()); + d_solver->apply(d_b2.get(), d_x.get()); + + GKO_ASSERT_MTX_NEAR(d_x, x, 1e-14); +} + + +} // namespace diff --git a/include/ginkgo/core/solver/lower_trs.hpp b/include/ginkgo/core/solver/lower_trs.hpp index 1dc79a97450..7a41c670b2a 100644 --- a/include/ginkgo/core/solver/lower_trs.hpp +++ b/include/ginkgo/core/solver/lower_trs.hpp @@ -54,112 +54,7 @@ namespace gko { namespace solver { -template -class LowerTrs; - - -/** - * This struct is used to pass parameters to the - * EnableDefaultLowerTrsFactory::generate() method. It is the - * ComponentsType of LowerTrsFactory. - * - * @tparam ValueType precision of matrix elements - */ -template -struct LowerTrsArgs { - std::shared_ptr system_matrix; - std::shared_ptr> b; - - - LowerTrsArgs(std::shared_ptr system_matrix, - std::shared_ptr> b) - : system_matrix{system_matrix}, b{b} - {} -}; - - -/** - * Declares an Abstract Factory specialized for LowerTrs solver. - * - * @tparam ValueType precision of matrix elements - * @tparam IndexType precision of matrix indexes - */ -template -using LowerTrsFactory = - AbstractFactory, LowerTrsArgs>; - - -/** - * This is an alias for the EnableDefaultFactory mixin, which correctly sets the - * template parameters to enable a subclass of LowerTrsFactory. - * - * @tparam ConcreteFactory the concrete factory which is being implemented - * [CRTP parmeter] - * @tparam ConcreteLowerTrs the concrete LowerTrs type which this factory - * produces, needs to have a constructor which takes - * a const ConcreteFactory *, and a - * const LowerTrsArgs * as parameters. - * @tparam ParametersType a subclass of enable_parameters_type template which - * defines all of the parameters of the factory - * @tparam ValueType precision of matrix elements - * @tparam IndexType precision of matrix indexes - */ -template -using EnableDefaultLowerTrsFactory = - EnableDefaultFactory>; - - -/** - * This macro will generate a default implementation of a LowerTrsFactory for - * the LowerTrs subclass it is defined in. - * - * This macro is very similar to the macro #ENABLE_LIN_OP_FACTORY(). A more - * detailed description of the use of these type of macros can be found there. - * - * @param _lower_trs concrete operator for which the factory is to be created - * [CRTP parameter] - * @param _parameters_name name of the parameters member in the class - * (its type is `<_parameters_name>_type`, the - * protected member's name is `<_parameters_name>_`, - * and the public getter's name is - * `get_<_parameters_name>()`) - * @param _factory_name name of the generated factory type - * - * @ingroup solvers - */ -#define GKO_ENABLE_LOWER_TRS_FACTORY(_lower_trs, _parameters_name, \ - _factory_name) \ -public: \ - const _parameters_name##_type &get_##_parameters_name() const \ - { \ - return _parameters_name##_; \ - } \ - \ - class _factory_name : public ::gko::solver::EnableDefaultLowerTrsFactory< \ - _factory_name, _lower_trs, \ - _parameters_name##_type, ValueType, IndexType> { \ - friend class ::gko::EnablePolymorphicObject< \ - _factory_name, \ - ::gko::solver::LowerTrsFactory>; \ - friend class ::gko::enable_parameters_type<_parameters_name##_type, \ - _factory_name>; \ - using ::gko::solver::EnableDefaultLowerTrsFactory< \ - _factory_name, _lower_trs, _parameters_name##_type, ValueType, \ - IndexType>::EnableDefaultLowerTrsFactory; \ - }; \ - friend ::gko::solver::EnableDefaultLowerTrsFactory< \ - _factory_name, _lower_trs, _parameters_name##_type, ValueType, \ - IndexType>; \ - \ -private: \ - _parameters_name##_type _parameters_name##_; \ - \ -public: \ - static_assert(true, \ - "This assert is used to counter the false positive extra " \ - "semi-colon warnings") +struct SolveStruct; /** @@ -168,6 +63,11 @@ public: \ * format. If the matrix is not in CSR, then the generate step converts it into * a CSR matrix. The generation fails if the matrix is not convertible to CSR. * + * @note As the constructor uses the copy and convert functionality, it is not + * possible to create a empty solver or a solver with a matrix in any + * other format other than CSR, if none of the executor modules are being + * compiled with. + * * @tparam ValueType precision of matrix elements * @tparam IndexType precision of matrix indices * @@ -195,16 +95,6 @@ class LowerTrs : public EnableLinOp>, return system_matrix_; } - /** - * Gets the right hand side of the linear system. - * - * @return the right hand side - */ - std::shared_ptr> get_rhs() const - { - return b_; - } - /** * Returns the preconditioner operator used by the solver. * @@ -222,11 +112,24 @@ class LowerTrs : public EnableLinOp>, */ std::shared_ptr GKO_FACTORY_PARAMETER( preconditioner, nullptr); + + /** + * Number of right hand sides. + * + * @note This value is currently a dummy value which is not used by the + * analysis step. It is possible that future algorithms (cusparse + * csrsm2) make use of the number of right hand sides for a more + * sophisticated implementation. Hence this parameter is left + * here. But currently, there is no need to use it. + */ + gko::size_type GKO_FACTORY_PARAMETER(num_rhs, 1u); }; - GKO_ENABLE_LOWER_TRS_FACTORY(LowerTrs, parameters, Factory); + GKO_ENABLE_LIN_OP_FACTORY(LowerTrs, parameters, Factory); GKO_ENABLE_BUILD_METHOD(Factory); protected: + void init_trs_solve_struct(); + void apply_impl(const LinOp *b, LinOp *x) const override; void apply_impl(const LinOp *alpha, const LinOp *b, const LinOp *beta, @@ -243,24 +146,23 @@ class LowerTrs : public EnableLinOp>, {} explicit LowerTrs(const Factory *factory, - const LowerTrsArgs &args) - : parameters_{factory->get_parameters()}, - EnableLinOp(factory->get_executor(), - transpose(args.system_matrix->get_size())), - b_{std::move(args.b)}, + std::shared_ptr system_matrix) + : EnableLinOp(factory->get_executor(), + transpose(system_matrix->get_size())), + parameters_{factory->get_parameters()}, system_matrix_{} { using CsrMatrix = matrix::Csr; - GKO_ASSERT_IS_SQUARE_MATRIX(args.system_matrix); + GKO_ASSERT_IS_SQUARE_MATRIX(system_matrix); // This is needed because it does not make sense to call the copy and // convert if the existing matrix is empty. const auto exec = this->get_executor(); - if (!args.system_matrix->get_size()) { + if (!system_matrix->get_size()) { system_matrix_ = CsrMatrix::create(exec); } else { system_matrix_ = - copy_and_convert_to(exec, args.system_matrix); + copy_and_convert_to(exec, system_matrix); } if (parameters_.preconditioner) { preconditioner_ = @@ -269,13 +171,14 @@ class LowerTrs : public EnableLinOp>, preconditioner_ = matrix::Identity::create( this->get_executor(), this->get_size()[0]); } + this->init_trs_solve_struct(); this->generate(); } private: std::shared_ptr> system_matrix_{}; - std::shared_ptr> b_{}; std::shared_ptr preconditioner_{}; + std::shared_ptr solve_struct_; }; diff --git a/include/ginkgo/ginkgo.hpp b/include/ginkgo/ginkgo.hpp index 50f47d95dc4..427c422b660 100644 --- a/include/ginkgo/ginkgo.hpp +++ b/include/ginkgo/ginkgo.hpp @@ -83,7 +83,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include #include #include -#include +#include #include #include diff --git a/omp/solver/lower_trs_kernels.cpp b/omp/solver/lower_trs_kernels.cpp index 55cd57a1840..bdfd73e94b1 100644 --- a/omp/solver/lower_trs_kernels.cpp +++ b/omp/solver/lower_trs_kernels.cpp @@ -45,6 +45,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include #include #include +#include namespace gko { @@ -58,10 +59,25 @@ namespace omp { namespace lower_trs { +void should_perform_transpose(std::shared_ptr exec, + bool &do_transpose) +{ + do_transpose = false; +} + + +void init_struct(std::shared_ptr exec, + std::shared_ptr &solve_struct) +{ + // This init kernel is here to allow initialization of the solve struct for + // a more sophisticated implementation as for other executors. +} + + template void generate(std::shared_ptr exec, const matrix::Csr *matrix, - const matrix::Dense *b) + solver::SolveStruct *solve_struct, const gko::size_type num_rhs) { // This generate kernel is here to allow for a more sophisticated // implementation as for other executors. This kernel would perform the @@ -72,9 +88,15 @@ GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( GKO_DECLARE_LOWER_TRS_GENERATE_KERNEL); +/** + * The parameters trans_x and trans_b are used only in the CUDA executor for + * versions <=9.1 due to a limitation in the cssrsm_solve algorithm + */ template void solve(std::shared_ptr exec, const matrix::Csr *matrix, + const solver::SolveStruct *solve_struct, + matrix::Dense *trans_b, matrix::Dense *trans_x, const matrix::Dense *b, matrix::Dense *x) { auto row_ptrs = matrix->get_const_row_ptrs(); diff --git a/omp/test/solver/lower_trs_kernels.cpp b/omp/test/solver/lower_trs_kernels.cpp index 7cf7cd2dd46..071ba5c0bec 100644 --- a/omp/test/solver/lower_trs_kernels.cpp +++ b/omp/test/solver/lower_trs_kernels.cpp @@ -74,6 +74,14 @@ class LowerTrs : public ::testing::Test { } std::shared_ptr gen_mtx(int num_rows, int num_cols) + { + return gko::test::generate_random_matrix( + num_rows, num_cols, + std::uniform_int_distribution<>(num_cols, num_cols), + std::normal_distribution<>(-1.0, 1.0), rand_engine, ref); + } + + std::shared_ptr gen_l_mtx(int num_rows, int num_cols) { return gko::test::generate_random_lower_triangular_matrix( num_rows, num_cols, false, @@ -85,11 +93,19 @@ class LowerTrs : public ::testing::Test { { b = gen_mtx(m, n); x = gen_mtx(m, n); + t_b = Mtx::create(ref); + t_x = Mtx::create(ref); + t_b->copy_from(b.get()); + t_x->copy_from(x.get()); d_b = Mtx::create(omp); d_b->copy_from(b.get()); d_x = Mtx::create(omp); d_x->copy_from(x.get()); - mat = gen_mtx(m, m); + dt_b = Mtx::create(omp); + dt_b->copy_from(b.get()); + dt_x = Mtx::create(omp); + dt_x->copy_from(x.get()); + mat = gen_l_mtx(m, m); csr_mat = CsrMtx::create(ref); mat->convert_to(csr_mat.get()); d_mat = Mtx::create(omp); @@ -105,23 +121,41 @@ class LowerTrs : public ::testing::Test { std::shared_ptr b; std::shared_ptr x; + std::shared_ptr t_b; + std::shared_ptr t_x; std::shared_ptr mat; std::shared_ptr csr_mat; std::shared_ptr d_b; std::shared_ptr d_x; + std::shared_ptr dt_b; + std::shared_ptr dt_x; std::shared_ptr d_mat; std::shared_ptr d_csr_mat; + std::shared_ptr solve_struct; }; +TEST_F(LowerTrs, OmpLowerTrsFlagCheckIsCorrect) +{ + bool trans_flag = true; + bool expected_flag = false; + + gko::kernels::omp::lower_trs::should_perform_transpose(omp, trans_flag); + + ASSERT_EQ(expected_flag, trans_flag); +} + + TEST_F(LowerTrs, OmpLowerTrsSolveIsEquivalentToRef) { initialize_data(59, 43); - gko::kernels::reference::lower_trs::solve(ref, csr_mat.get(), b.get(), - x.get()); - gko::kernels::omp::lower_trs::solve(omp, d_csr_mat.get(), d_b.get(), - d_x.get()); + gko::kernels::reference::lower_trs::solve(ref, csr_mat.get(), + solve_struct.get(), t_b.get(), + t_x.get(), b.get(), x.get()); + gko::kernels::omp::lower_trs::solve(omp, d_csr_mat.get(), + solve_struct.get(), dt_b.get(), + dt_x.get(), d_b.get(), d_x.get()); GKO_ASSERT_MTX_NEAR(d_x, x, 1e-14); } @@ -132,8 +166,8 @@ TEST_F(LowerTrs, ApplyIsEquivalentToRef) initialize_data(59, 3); auto lower_trs_factory = gko::solver::LowerTrs<>::build().on(ref); auto d_lower_trs_factory = gko::solver::LowerTrs<>::build().on(omp); - auto solver = lower_trs_factory->generate(csr_mat, b); - auto d_solver = d_lower_trs_factory->generate(d_csr_mat, d_b); + auto solver = lower_trs_factory->generate(csr_mat); + auto d_solver = d_lower_trs_factory->generate(d_csr_mat); solver->apply(b.get(), x.get()); d_solver->apply(d_b.get(), d_x.get()); diff --git a/reference/solver/lower_trs_kernels.cpp b/reference/solver/lower_trs_kernels.cpp index 3bb46a3b91e..c8f698ed711 100644 --- a/reference/solver/lower_trs_kernels.cpp +++ b/reference/solver/lower_trs_kernels.cpp @@ -41,6 +41,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include #include #include +#include namespace gko { @@ -54,10 +55,25 @@ namespace reference { namespace lower_trs { +void should_perform_transpose(std::shared_ptr exec, + bool &do_transpose) +{ + do_transpose = false; +} + + +void init_struct(std::shared_ptr exec, + std::shared_ptr &solve_struct) +{ + // This init kernel is here to allow initialization of the solve struct for + // a more sophisticated implementation as for other executors. +} + + template void generate(std::shared_ptr exec, const matrix::Csr *matrix, - const matrix::Dense *b) + solver::SolveStruct *solve_struct, const gko::size_type num_rhs) { // This generate kernel is here to allow for a more sophisticated // implementation as for other executors. This kernel would perform the @@ -68,9 +84,16 @@ GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( GKO_DECLARE_LOWER_TRS_GENERATE_KERNEL); +/** + * The parameters trans_x and trans_b are used only in the CUDA executor for + * versions <=9.1 due to a limitation in the cssrsm_solve algorithm and hence + * here essentially unused. + */ template void solve(std::shared_ptr exec, const matrix::Csr *matrix, + const solver::SolveStruct *solve_struct, + matrix::Dense *trans_b, matrix::Dense *trans_x, const matrix::Dense *b, matrix::Dense *x) { auto row_ptrs = matrix->get_const_row_ptrs(); diff --git a/reference/test/solver/lower_trs.cpp b/reference/test/solver/lower_trs.cpp index 114e57acc76..fbb58b546d4 100644 --- a/reference/test/solver/lower_trs.cpp +++ b/reference/test/solver/lower_trs.cpp @@ -39,7 +39,6 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include -#include #include #include #include @@ -53,102 +52,100 @@ namespace { class LowerTrs : public ::testing::Test { protected: - using CsrMtx = gko::matrix::Csr; using Mtx = gko::matrix::Dense<>; + using CsrMtx = gko::matrix::Csr<>; using Solver = gko::solver::LowerTrs<>; LowerTrs() : exec(gko::ReferenceExecutor::create()), mtx(gko::initialize( - {{2, 0.0, 0.0}, {3.0, 1, 0.0}, {1.0, 2.0, 3}}, exec)), - b(gko::initialize({{2, 0.0, 0.0}}, exec)), - csr_mtx(gko::copy_and_convert_to(exec, gko::lend(mtx))), + {{2, -1.0, 0.0}, {-1.0, 2, -1.0}, {0.0, -1.0, 2}}, exec)), + csr_mtx(gko::initialize( + {{2, -1.0, 0.0}, {-1.0, 2, -1.0}, {0.0, -1.0, 2}}, exec)), lower_trs_factory(Solver::build().on(exec)), - lower_trs_solver(lower_trs_factory->generate(mtx, b)) + solver(lower_trs_factory->generate(mtx)) {} std::shared_ptr exec; std::shared_ptr mtx; - std::shared_ptr b; std::shared_ptr csr_mtx; std::unique_ptr lower_trs_factory; - std::unique_ptr lower_trs_solver; + std::unique_ptr solver; }; TEST_F(LowerTrs, LowerTrsFactoryCreatesCorrectSolver) { - auto sys_mtx = lower_trs_solver->get_system_matrix(); - auto d_sys_mtx = Mtx::create(exec); - sys_mtx->convert_to(gko::lend(d_sys_mtx)); + auto sys_mtx = solver->get_system_matrix(); - ASSERT_EQ(lower_trs_solver->get_size(), gko::dim<2>(3, 3)); + ASSERT_EQ(solver->get_size(), gko::dim<2>(3, 3)); ASSERT_NE(sys_mtx, nullptr); - ASSERT_NE(lower_trs_solver->get_rhs(), nullptr); - GKO_ASSERT_MTX_NEAR(d_sys_mtx, mtx, 0); - ASSERT_EQ(lower_trs_solver->get_rhs(), b); + GKO_ASSERT_MTX_NEAR(sys_mtx, csr_mtx, 0); } TEST_F(LowerTrs, CanBeCopied) { - auto copy = Solver::build().on(exec)->generate(Mtx::create(exec), - Mtx::create(exec)); + auto copy = Solver::build().on(exec)->generate(Mtx::create(exec)); - copy->copy_from(gko::lend(lower_trs_solver)); + copy->copy_from(gko::lend(solver)); auto copy_mtx = copy->get_system_matrix(); - auto d_copy_mtx = Mtx::create(exec); - copy_mtx->convert_to(gko::lend(d_copy_mtx)); - auto copy_b = copy->get_rhs(); ASSERT_EQ(copy->get_size(), gko::dim<2>(3, 3)); - GKO_ASSERT_MTX_NEAR(d_copy_mtx, mtx, 0); - GKO_ASSERT_MTX_NEAR(copy_b, b, 0); + GKO_ASSERT_MTX_NEAR(copy_mtx.get(), csr_mtx.get(), 0); } TEST_F(LowerTrs, CanBeMoved) { - auto copy = - lower_trs_factory->generate(Mtx::create(exec), Mtx::create(exec)); + auto copy = Solver::build().on(exec)->generate(Mtx::create(exec)); - copy->copy_from(std::move(lower_trs_solver)); + copy->copy_from(std::move(solver)); auto copy_mtx = copy->get_system_matrix(); - auto d_copy_mtx = Mtx::create(exec); - copy_mtx->convert_to(gko::lend(d_copy_mtx)); - auto copy_b = copy->get_rhs(); ASSERT_EQ(copy->get_size(), gko::dim<2>(3, 3)); - GKO_ASSERT_MTX_NEAR(d_copy_mtx, mtx, 0); - GKO_ASSERT_MTX_NEAR(copy_b, b, 0); + GKO_ASSERT_MTX_NEAR(copy_mtx.get(), csr_mtx.get(), 0); } TEST_F(LowerTrs, CanBeCloned) { - auto clone = lower_trs_solver->clone(); + auto clone = solver->clone(); auto clone_mtx = clone->get_system_matrix(); - auto d_clone_mtx = Mtx::create(exec); - clone_mtx->convert_to(gko::lend(d_clone_mtx)); - auto clone_b = clone->get_rhs(); ASSERT_EQ(clone->get_size(), gko::dim<2>(3, 3)); - GKO_ASSERT_MTX_NEAR(d_clone_mtx, mtx, 0); - GKO_ASSERT_MTX_NEAR(clone_b, b, 0); + GKO_ASSERT_MTX_NEAR(clone_mtx.get(), csr_mtx.get(), 0); } TEST_F(LowerTrs, CanBeCleared) { - lower_trs_solver->clear(); + solver->clear(); - auto solver_mtx = lower_trs_solver->get_system_matrix(); - auto solver_b = lower_trs_solver->get_rhs(); + auto solver_mtx = solver->get_system_matrix(); - ASSERT_EQ(lower_trs_solver->get_size(), gko::dim<2>(0, 0)); ASSERT_EQ(solver_mtx, nullptr); - ASSERT_EQ(solver_b, nullptr); + ASSERT_EQ(solver->get_size(), gko::dim<2>(0, 0)); +} + + +TEST_F(LowerTrs, CanSetPreconditionerGenerator) +{ + auto lower_trs_factory = + Solver::build().with_preconditioner(Solver::build().on(exec)).on(exec); + auto solver = lower_trs_factory->generate(mtx); + + auto precond = dynamic_cast *>( + static_cast *>(solver.get()) + ->get_preconditioner() + .get()); + + ASSERT_NE(precond, nullptr); + ASSERT_EQ(precond->get_size(), gko::dim<2>(3, 3)); + GKO_ASSERT_MTX_NEAR( + static_cast(precond->get_system_matrix().get()), + csr_mtx.get(), 0); } diff --git a/reference/test/solver/lower_trs_kernels.cpp b/reference/test/solver/lower_trs_kernels.cpp index a7e580febc7..22ba58a8912 100644 --- a/reference/test/solver/lower_trs_kernels.cpp +++ b/reference/test/solver/lower_trs_kernels.cpp @@ -33,6 +33,9 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include +#include + + #include @@ -46,6 +49,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include +#include "core/solver/lower_trs_kernels.hpp" #include "core/test/utils/assertions.hpp" @@ -57,11 +61,14 @@ class LowerTrs : public ::testing::Test { using Mtx = gko::matrix::Dense<>; LowerTrs() : exec(gko::ReferenceExecutor::create()), + ref(gko::ReferenceExecutor::create()), mtx(gko::initialize( {{1, 0.0, 0.0}, {3.0, 1, 0.0}, {1.0, 2.0, 1}}, exec)), mtx2(gko::initialize( {{2, 0.0, 0.0}, {3.0, 3, 0.0}, {1.0, 2.0, 4}}, exec)), lower_trs_factory(gko::solver::LowerTrs<>::build().on(exec)), + lower_trs_factory_mrhs( + gko::solver::LowerTrs<>::build().with_num_rhs(2u).on(exec)), mtx_big(gko::initialize({{124.0, 0.0, 0.0, 0.0, 0.0}, {43.0, -789.0, 0.0, 0.0, 0.0}, {134.5, -651.0, 654.0, 0.0, 0.0}, @@ -72,19 +79,33 @@ class LowerTrs : public ::testing::Test { {} std::shared_ptr exec; + std::shared_ptr ref; std::shared_ptr mtx; std::shared_ptr mtx2; std::shared_ptr mtx_big; std::unique_ptr::Factory> lower_trs_factory; + std::unique_ptr::Factory> lower_trs_factory_mrhs; std::unique_ptr::Factory> lower_trs_factory_big; }; +TEST_F(LowerTrs, RefLowerTrsFlagCheckIsCorrect) +{ + bool trans_flag = true; + bool expected_flag = false; + + gko::kernels::reference::lower_trs::should_perform_transpose(ref, + trans_flag); + + ASSERT_EQ(expected_flag, trans_flag); +} + + TEST_F(LowerTrs, SolvesTriangularSystem) { std::shared_ptr b = gko::initialize({1.0, 2.0, 1.0}, exec); auto x = gko::initialize({0.0, 0.0, 0.0}, exec); - auto solver = lower_trs_factory->generate(mtx, b); + auto solver = lower_trs_factory->generate(mtx); solver->apply(b.get(), x.get()); @@ -97,7 +118,7 @@ TEST_F(LowerTrs, SolvesMultipleTriangularSystems) std::shared_ptr b = gko::initialize({{3.0, 4.0}, {1.0, 0.0}, {1.0, -1.0}}, exec); auto x = gko::initialize({{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}, exec); - auto solver = lower_trs_factory->generate(mtx, b); + auto solver = lower_trs_factory_mrhs->generate(mtx); solver->apply(b.get(), x.get()); @@ -109,7 +130,7 @@ TEST_F(LowerTrs, SolvesNonUnitTriangularSystem) { std::shared_ptr b = gko::initialize({2.0, 12.0, 3.0}, exec); auto x = gko::initialize({0.0, 0.0, 0.0}, exec); - auto solver = lower_trs_factory->generate(mtx2, b); + auto solver = lower_trs_factory->generate(mtx2); solver->apply(b.get(), x.get()); @@ -122,7 +143,7 @@ TEST_F(LowerTrs, SolvesTriangularSystemUsingAdvancedApply) auto beta = gko::initialize({-1.0}, exec); std::shared_ptr b = gko::initialize({1.0, 2.0, 1.0}, exec); auto x = gko::initialize({1.0, -1.0, 1.0}, exec); - auto solver = lower_trs_factory->generate(mtx, b); + auto solver = lower_trs_factory->generate(mtx); solver->apply(alpha.get(), b.get(), beta.get(), x.get()); @@ -138,7 +159,7 @@ TEST_F(LowerTrs, SolvesMultipleTriangularSystemsUsingAdvancedApply) gko::initialize({{3.0, 4.0}, {1.0, 0.0}, {1.0, -1.0}}, exec); auto x = gko::initialize({{1.0, 2.0}, {-1.0, -1.0}, {0.0, -2.0}}, exec); - auto solver = lower_trs_factory->generate(mtx, b); + auto solver = lower_trs_factory_mrhs->generate(mtx); solver->apply(alpha.get(), b.get(), beta.get(), x.get()); @@ -152,7 +173,7 @@ TEST_F(LowerTrs, SolvesBigDenseSystem) std::shared_ptr b = gko::initialize({-124.0, -3199.0, 3147.5, 5151.0, -6021.0}, exec); auto x = gko::initialize({0.0, 0.0, 0.0, 0.0, 0.0}, exec); - auto solver = lower_trs_factory_big->generate(mtx_big, b); + auto solver = lower_trs_factory_big->generate(mtx_big); solver->apply(b.get(), x.get());