From 9f92fa75e63eaa525b1d006106933497a04d8884 Mon Sep 17 00:00:00 2001 From: kswirydo Date: Wed, 1 Nov 2023 23:54:26 -0400 Subject: [PATCH 1/2] a WORKING alternative triangular solver (faster) for rocsolverrf --- examples/r_KLU_rocSolverRf_FGMRES.cpp | 3 +- resolve/LinSolverDirectRocSolverRf.cpp | 211 ++++++++++++++++++++++++- resolve/LinSolverDirectRocSolverRf.hpp | 22 ++- resolve/hip/hipKernels.h | 11 ++ resolve/hip/hipKernels.hip | 44 ++++++ resolve/hip/hipVectorKernels.hip | 1 + 6 files changed, 285 insertions(+), 7 deletions(-) diff --git a/examples/r_KLU_rocSolverRf_FGMRES.cpp b/examples/r_KLU_rocSolverRf_FGMRES.cpp index d2e5f7a6..45fe4681 100644 --- a/examples/r_KLU_rocSolverRf_FGMRES.cpp +++ b/examples/r_KLU_rocSolverRf_FGMRES.cpp @@ -131,6 +131,7 @@ int main(int argc, char *argv[]) std::cout<<"KLU analysis status: "<factorize(); std::cout<<"KLU factorization status: "<solve(vec_rhs, vec_x); std::cout<<"KLU solve status: "<update(rhs, ReSolve::memory::HOST, ReSolve::memory::DEVICE); @@ -149,6 +150,7 @@ int main(int argc, char *argv[]) if (L == nullptr) {printf("ERROR");} index_type* P = KLU->getPOrdering(); index_type* Q = KLU->getQOrdering(); + Rf->setSolveMode(1); Rf->setup(A, L, U, P, Q, vec_rhs); Rf->refactorize(); std::cout<<"about to set FGMRES" <solve(vec_rhs, vec_x); std::cout<<"ROCSOLVER RF solve status: "<update(rhs, ReSolve::memory::HOST, ReSolve::memory::DEVICE); norm_b = vector_handler->dot(vec_r, vec_r, "hip"); norm_b = sqrt(norm_b); diff --git a/resolve/LinSolverDirectRocSolverRf.cpp b/resolve/LinSolverDirectRocSolverRf.cpp index 5869756d..bae9d088 100644 --- a/resolve/LinSolverDirectRocSolverRf.cpp +++ b/resolve/LinSolverDirectRocSolverRf.cpp @@ -1,6 +1,7 @@ #include #include #include "LinSolverDirectRocSolverRf.hpp" +#include namespace ReSolve { @@ -15,6 +16,12 @@ namespace ReSolve { mem_.deleteOnDevice(d_P_); mem_.deleteOnDevice(d_Q_); + + mem_.deleteOnDevice(d_aux1_); + mem_.deleteOnDevice(d_aux2_); + + delete L_csr_; + delete U_csr_; } int LinSolverDirectRocSolverRf::setup(matrix::Sparse* A, matrix::Sparse* L, matrix::Sparse* U, index_type* P, index_type* Q, vector_type* rhs) @@ -56,7 +63,109 @@ namespace ReSolve mem_.deviceSynchronize(); error_sum += status_rocblas_; + // tri solve setup + if (solve_mode_ == 1) { // fast mode + L_csr_ = new ReSolve::matrix::Csr(L->getNumRows(), L->getNumColumns(), L->getNnz()); + U_csr_ = new ReSolve::matrix::Csr(U->getNumRows(), U->getNumColumns(), U->getNnz()); + + L_csr_->allocateMatrixData(ReSolve::memory::DEVICE); + U_csr_->allocateMatrixData(ReSolve::memory::DEVICE); + + rocsparse_create_mat_descr(&(descr_L_)); + rocsparse_set_mat_fill_mode(descr_L_, rocsparse_fill_mode_lower); + rocsparse_set_mat_index_base(descr_L_, rocsparse_index_base_zero); + + rocsparse_create_mat_descr(&(descr_U_)); + rocsparse_set_mat_index_base(descr_U_, rocsparse_index_base_zero); + rocsparse_set_mat_fill_mode(descr_U_, rocsparse_fill_mode_upper); + + rocsparse_create_mat_info(&info_L_); + rocsparse_create_mat_info(&info_U_); + + // local variables + size_t L_buffer_size; + size_t U_buffer_size; + + status_rocblas_ = rocsolver_dcsrrf_splitlu(workspace_->getRocblasHandle(), + n, + M_->getNnzExpanded(), + M_->getRowData(ReSolve::memory::DEVICE), + M_->getColData(ReSolve::memory::DEVICE), + M_->getValues(ReSolve::memory::DEVICE), //vals_, + L_csr_->getRowData(ReSolve::memory::DEVICE), + L_csr_->getColData(ReSolve::memory::DEVICE), + L_csr_->getValues(ReSolve::memory::DEVICE), //vals_, + U_csr_->getRowData(ReSolve::memory::DEVICE), + U_csr_->getColData(ReSolve::memory::DEVICE), + U_csr_->getValues(ReSolve::memory::DEVICE)); + error_sum += status_rocblas_; + + status_rocsparse_ = rocsparse_dcsrsv_buffer_size(workspace_->getRocsparseHandle(), + rocsparse_operation_none, + n, + L_csr_->getNnz(), + descr_L_, + L_csr_->getValues(ReSolve::memory::DEVICE), //vals_, + L_csr_->getRowData(ReSolve::memory::DEVICE), + L_csr_->getColData(ReSolve::memory::DEVICE), + info_L_, + &L_buffer_size); + error_sum += status_rocsparse_; + + printf("buffer size for L %d status %d \n", L_buffer_size, status_rocsparse_); + // hipMalloc((void**)&(L_buffer), L_buffer_size); + + mem_.allocateBufferOnDevice(&L_buffer_, L_buffer_size); + status_rocsparse_ = rocsparse_dcsrsv_buffer_size(workspace_->getRocsparseHandle(), + rocsparse_operation_none, + n, + U_csr_->getNnz(), + descr_U_, + U_csr_->getValues(ReSolve::memory::DEVICE), //vals_, + U_csr_->getRowData(ReSolve::memory::DEVICE), + U_csr_->getColData(ReSolve::memory::DEVICE), + info_U_, + &U_buffer_size); + error_sum += status_rocsparse_; + // hipMalloc((void**)&(U_buffer), U_buffer_size); + mem_.allocateBufferOnDevice(&U_buffer_, U_buffer_size); + printf("buffer size for U %d status %d \n", U_buffer_size, status_rocsparse_); + + status_rocsparse_ = rocsparse_dcsrsv_analysis(workspace_->getRocsparseHandle(), + rocsparse_operation_none, + n, + L_csr_->getNnz(), + descr_L_, + L_csr_->getValues(ReSolve::memory::DEVICE), //vals_, + L_csr_->getRowData(ReSolve::memory::DEVICE), + L_csr_->getColData(ReSolve::memory::DEVICE), + info_L_, + rocsparse_analysis_policy_force, + rocsparse_solve_policy_auto, + L_buffer_); + error_sum += status_rocsparse_; + if (status_rocsparse_!=0)printf("status after analysis 1 %d \n", status_rocsparse_); + status_rocsparse_ = rocsparse_dcsrsv_analysis(workspace_->getRocsparseHandle(), + rocsparse_operation_none, + n, + U_csr_->getNnz(), + descr_U_, + U_csr_->getValues(ReSolve::memory::DEVICE), //vals_, + U_csr_->getRowData(ReSolve::memory::DEVICE), + U_csr_->getColData(ReSolve::memory::DEVICE), + info_U_, + rocsparse_analysis_policy_force, + rocsparse_solve_policy_auto, + U_buffer_); + error_sum += status_rocsparse_; + if (status_rocsparse_!=0)printf("status after analysis 2 %d \n", status_rocsparse_); + //allocate aux data + + mem_.allocateArrayOnDevice(&d_aux1_,n); + mem_.allocateArrayOnDevice(&d_aux2_,n); + + } return error_sum; } @@ -78,15 +187,38 @@ namespace ReSolve d_Q_, infoM_); + mem_.deviceSynchronize(); error_sum += status_rocblas_; + if (solve_mode_ == 1) { + //split M, fill L and U with correct values +printf("solve mode 1, splitting the factors again \n"); + status_rocblas_ = rocsolver_dcsrrf_splitlu(workspace_->getRocblasHandle(), + A_->getNumRows(), + M_->getNnzExpanded(), + M_->getRowData(ReSolve::memory::DEVICE), + M_->getColData(ReSolve::memory::DEVICE), + M_->getValues(ReSolve::memory::DEVICE), //vals_, + L_csr_->getRowData(ReSolve::memory::DEVICE), + L_csr_->getColData(ReSolve::memory::DEVICE), + L_csr_->getValues(ReSolve::memory::DEVICE), //vals_, + U_csr_->getRowData(ReSolve::memory::DEVICE), + U_csr_->getColData(ReSolve::memory::DEVICE), + U_csr_->getValues(ReSolve::memory::DEVICE)); + + mem_.deviceSynchronize(); + error_sum += status_rocblas_; + + } + return error_sum; } // solution is returned in RHS int LinSolverDirectRocSolverRf::solve(vector_type* rhs) { + int error_sum = 0; if (solve_mode_ == 0) { mem_.deviceSynchronize(); status_rocblas_ = rocsolver_dcsrrf_solve(workspace_->getRocblasHandle(), @@ -104,15 +236,49 @@ namespace ReSolve mem_.deviceSynchronize(); } else { // not implemented yet + permuteVectorP(A_->getNumRows(), d_P_, rhs->getData(ReSolve::memory::DEVICE), d_aux1_); + rocsparse_dcsrsv_solve(workspace_->getRocsparseHandle(), + rocsparse_operation_none, + A_->getNumRows(), + L_csr_->getNnz(), + &(constants::ONE), + descr_L_, + L_csr_->getValues(ReSolve::memory::DEVICE), //vals_, + L_csr_->getRowData(ReSolve::memory::DEVICE), + L_csr_->getColData(ReSolve::memory::DEVICE), + info_L_, + d_aux1_, + d_aux2_, //result + rocsparse_solve_policy_auto, + L_buffer_); + error_sum += status_rocsparse_; + + rocsparse_dcsrsv_solve(workspace_->getRocsparseHandle(), + rocsparse_operation_none, + A_->getNumRows(), + U_csr_->getNnz(), + &(constants::ONE), + descr_L_, + U_csr_->getValues(ReSolve::memory::DEVICE), //vals_, + U_csr_->getRowData(ReSolve::memory::DEVICE), + U_csr_->getColData(ReSolve::memory::DEVICE), + info_U_, + d_aux2_, //input + d_aux1_,//result + rocsparse_solve_policy_auto, + U_buffer_); + error_sum += status_rocsparse_; + + permuteVectorQ(A_->getNumRows(), d_Q_,d_aux1_,rhs->getData(ReSolve::memory::DEVICE)); } - return status_rocblas_; + return error_sum; } int LinSolverDirectRocSolverRf::solve(vector_type* rhs, vector_type* x) { x->update(rhs->getData(ReSolve::memory::DEVICE), ReSolve::memory::DEVICE, ReSolve::memory::DEVICE); x->setDataUpdated(ReSolve::memory::DEVICE); - + int error_sum = 0; if (solve_mode_ == 0) { mem_.deviceSynchronize(); status_rocblas_ = rocsolver_dcsrrf_solve(workspace_->getRocblasHandle(), @@ -127,11 +293,50 @@ namespace ReSolve x->getData(ReSolve::memory::DEVICE), A_->getNumRows(), infoM_); + error_sum += status_rocblas_; mem_.deviceSynchronize(); } else { // not implemented yet + + permuteVectorP(A_->getNumRows(), d_P_, rhs->getData(ReSolve::memory::DEVICE), d_aux1_); + mem_.deviceSynchronize(); + + rocsparse_dcsrsv_solve(workspace_->getRocsparseHandle(), + rocsparse_operation_none, + A_->getNumRows(), + L_csr_->getNnz(), + &(constants::ONE), + descr_L_, + L_csr_->getValues(ReSolve::memory::DEVICE), //vals_, + L_csr_->getRowData(ReSolve::memory::DEVICE), + L_csr_->getColData(ReSolve::memory::DEVICE), + info_L_, + d_aux1_, + d_aux2_, //result + rocsparse_solve_policy_auto, + L_buffer_); + error_sum += status_rocsparse_; + + rocsparse_dcsrsv_solve(workspace_->getRocsparseHandle(), + rocsparse_operation_none, + A_->getNumRows(), + U_csr_->getNnz(), + &(constants::ONE), + descr_U_, + U_csr_->getValues(ReSolve::memory::DEVICE), //vals_, + U_csr_->getRowData(ReSolve::memory::DEVICE), + U_csr_->getColData(ReSolve::memory::DEVICE), + info_U_, + d_aux2_, //input + d_aux1_,//result + rocsparse_solve_policy_auto, + U_buffer_); + error_sum += status_rocsparse_; + + permuteVectorQ(A_->getNumRows(), d_Q_,d_aux1_,x->getData(ReSolve::memory::DEVICE)); + mem_.deviceSynchronize(); } - return status_rocblas_; + return error_sum; } int LinSolverDirectRocSolverRf::setSolveMode(int mode) diff --git a/resolve/LinSolverDirectRocSolverRf.hpp b/resolve/LinSolverDirectRocSolverRf.hpp index 5804393f..eb3a11a6 100644 --- a/resolve/LinSolverDirectRocSolverRf.hpp +++ b/resolve/LinSolverDirectRocSolverRf.hpp @@ -42,8 +42,8 @@ namespace ReSolve int getSolveMode(); //should be enum too private: - rocblas_status status_rocblas_; - + rocblas_status status_rocblas_; + rocsparse_status status_rocsparse_; index_type* d_P_; index_type* d_Q_; @@ -54,6 +54,22 @@ namespace ReSolve void addFactors(matrix::Sparse* L, matrix::Sparse* U); //create L+U from sepeate L, U factors rocsolver_rfinfo infoM_; matrix::Sparse* M_;//the matrix that contains added factors - int solve_mode_; + int solve_mode_; // 0 is default and 1 is fast + + // not used by default - for fast solve + rocsparse_mat_descr descr_L_{nullptr}; + rocsparse_mat_descr descr_U_{nullptr}; + + rocsparse_mat_info info_L_{nullptr}; + rocsparse_mat_info info_U_{nullptr}; + + void* L_buffer_{nullptr}; + void* U_buffer_{nullptr}; + + ReSolve::matrix::Csr* L_csr_; + ReSolve::matrix::Csr* U_csr_; + + real_type* d_aux1_{nullptr}; + real_type* d_aux2_{nullptr}; }; } diff --git a/resolve/hip/hipKernels.h b/resolve/hip/hipKernels.h index 9c48783a..986efc84 100644 --- a/resolve/hip/hipKernels.h +++ b/resolve/hip/hipKernels.h @@ -12,3 +12,14 @@ void matrix_row_sums(int n, int* a_ia, double* a_val, double* result); + +// needed for triangular solve + +void permuteVectorP(int n, + int* perm_vector, + double* vec_in, + double* vec_out); +void permuteVectorQ(int n, + int* perm_vector, + double* vec_in, + double* vec_out); diff --git a/resolve/hip/hipKernels.hip b/resolve/hip/hipKernels.hip index 13f53d85..abad5b39 100644 --- a/resolve/hip/hipKernels.hip +++ b/resolve/hip/hipKernels.hip @@ -143,6 +143,34 @@ __global__ void matrixInfNormPart1(const int n, } +__global__ void permuteVectorP_kernel(const int n, + const int* perm_vector, + const double* vec_in, + double* vec_out){ + + //one thread per vector entry, pass through rows + + int idx = blockIdx.x*blockDim.x + threadIdx.x; + while (idx Date: Thu, 2 Nov 2023 11:23:09 -0400 Subject: [PATCH 2/2] fixing synchronizations --- resolve/LinSolverDirectRocSolverRf.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/resolve/LinSolverDirectRocSolverRf.cpp b/resolve/LinSolverDirectRocSolverRf.cpp index bae9d088..f9f73b4a 100644 --- a/resolve/LinSolverDirectRocSolverRf.cpp +++ b/resolve/LinSolverDirectRocSolverRf.cpp @@ -237,6 +237,7 @@ printf("solve mode 1, splitting the factors again \n"); } else { // not implemented yet permuteVectorP(A_->getNumRows(), d_P_, rhs->getData(ReSolve::memory::DEVICE), d_aux1_); + mem_.deviceSynchronize(); rocsparse_dcsrsv_solve(workspace_->getRocsparseHandle(), rocsparse_operation_none, A_->getNumRows(), @@ -270,6 +271,7 @@ printf("solve mode 1, splitting the factors again \n"); error_sum += status_rocsparse_; permuteVectorQ(A_->getNumRows(), d_Q_,d_aux1_,rhs->getData(ReSolve::memory::DEVICE)); + mem_.deviceSynchronize(); } return error_sum; }