Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

A working alternative triangular solver (faster) for rocsolverrf #56

Merged
merged 2 commits into from
Nov 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion examples/r_KLU_rocSolverRf_FGMRES.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ int main(int argc, char *argv[])
std::cout<<"KLU analysis status: "<<status<<std::endl;
status = KLU->factorize();
std::cout<<"KLU factorization status: "<<status<<std::endl;

status = KLU->solve(vec_rhs, vec_x);
std::cout<<"KLU solve status: "<<status<<std::endl;
vec_r->update(rhs, ReSolve::memory::HOST, ReSolve::memory::DEVICE);
Expand All @@ -149,6 +150,7 @@ int main(int argc, char *argv[])
if (L == nullptr) {printf("ERROR");}
index_type* P = KLU->getPOrdering();
index_type* Q = KLU->getQOrdering();
Rf->setSolveMode(1);
Rf->setup(A, L, U, P, Q, vec_rhs);
Rf->refactorize();
std::cout<<"about to set FGMRES" <<std::endl;
Expand All @@ -162,7 +164,6 @@ int main(int argc, char *argv[])
std::cout<<"ROCSOLVER RF refactorization status: "<<status<<std::endl;
status = Rf->solve(vec_rhs, vec_x);
std::cout<<"ROCSOLVER RF solve status: "<<status<<std::endl;

vec_r->update(rhs, ReSolve::memory::HOST, ReSolve::memory::DEVICE);
norm_b = vector_handler->dot(vec_r, vec_r, "hip");
norm_b = sqrt(norm_b);
Expand Down
213 changes: 210 additions & 3 deletions resolve/LinSolverDirectRocSolverRf.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include <resolve/vector/Vector.hpp>
#include <resolve/matrix/Csr.hpp>
#include "LinSolverDirectRocSolverRf.hpp"
#include <resolve/hip/hipKernels.h>

namespace ReSolve
{
Expand All @@ -15,6 +16,12 @@ namespace ReSolve
{
mem_.deleteOnDevice(d_P_);
mem_.deleteOnDevice(d_Q_);

mem_.deleteOnDevice(d_aux1_);
mem_.deleteOnDevice(d_aux2_);

delete L_csr_;
delete U_csr_;
}

int LinSolverDirectRocSolverRf::setup(matrix::Sparse* A, matrix::Sparse* L, matrix::Sparse* U, index_type* P, index_type* Q, vector_type* rhs)
Expand Down Expand Up @@ -56,7 +63,109 @@ namespace ReSolve
mem_.deviceSynchronize();
error_sum += status_rocblas_;

// tri solve setup
if (solve_mode_ == 1) { // fast mode
L_csr_ = new ReSolve::matrix::Csr(L->getNumRows(), L->getNumColumns(), L->getNnz());
U_csr_ = new ReSolve::matrix::Csr(U->getNumRows(), U->getNumColumns(), U->getNnz());

L_csr_->allocateMatrixData(ReSolve::memory::DEVICE);
U_csr_->allocateMatrixData(ReSolve::memory::DEVICE);

rocsparse_create_mat_descr(&(descr_L_));
rocsparse_set_mat_fill_mode(descr_L_, rocsparse_fill_mode_lower);
rocsparse_set_mat_index_base(descr_L_, rocsparse_index_base_zero);

rocsparse_create_mat_descr(&(descr_U_));
rocsparse_set_mat_index_base(descr_U_, rocsparse_index_base_zero);
rocsparse_set_mat_fill_mode(descr_U_, rocsparse_fill_mode_upper);

rocsparse_create_mat_info(&info_L_);
rocsparse_create_mat_info(&info_U_);

// local variables
size_t L_buffer_size;
size_t U_buffer_size;

status_rocblas_ = rocsolver_dcsrrf_splitlu(workspace_->getRocblasHandle(),
n,
M_->getNnzExpanded(),
M_->getRowData(ReSolve::memory::DEVICE),
M_->getColData(ReSolve::memory::DEVICE),
M_->getValues(ReSolve::memory::DEVICE), //vals_,
L_csr_->getRowData(ReSolve::memory::DEVICE),
L_csr_->getColData(ReSolve::memory::DEVICE),
L_csr_->getValues(ReSolve::memory::DEVICE), //vals_,
U_csr_->getRowData(ReSolve::memory::DEVICE),
U_csr_->getColData(ReSolve::memory::DEVICE),
U_csr_->getValues(ReSolve::memory::DEVICE));

error_sum += status_rocblas_;

status_rocsparse_ = rocsparse_dcsrsv_buffer_size(workspace_->getRocsparseHandle(),
rocsparse_operation_none,
n,
L_csr_->getNnz(),
descr_L_,
L_csr_->getValues(ReSolve::memory::DEVICE), //vals_,
L_csr_->getRowData(ReSolve::memory::DEVICE),
L_csr_->getColData(ReSolve::memory::DEVICE),
info_L_,
&L_buffer_size);
error_sum += status_rocsparse_;

printf("buffer size for L %d status %d \n", L_buffer_size, status_rocsparse_);
// hipMalloc((void**)&(L_buffer), L_buffer_size);

mem_.allocateBufferOnDevice(&L_buffer_, L_buffer_size);
status_rocsparse_ = rocsparse_dcsrsv_buffer_size(workspace_->getRocsparseHandle(),
rocsparse_operation_none,
n,
U_csr_->getNnz(),
descr_U_,
U_csr_->getValues(ReSolve::memory::DEVICE), //vals_,
U_csr_->getRowData(ReSolve::memory::DEVICE),
U_csr_->getColData(ReSolve::memory::DEVICE),
info_U_,
&U_buffer_size);
error_sum += status_rocsparse_;
// hipMalloc((void**)&(U_buffer), U_buffer_size);
mem_.allocateBufferOnDevice(&U_buffer_, U_buffer_size);
printf("buffer size for U %d status %d \n", U_buffer_size, status_rocsparse_);

status_rocsparse_ = rocsparse_dcsrsv_analysis(workspace_->getRocsparseHandle(),
rocsparse_operation_none,
n,
L_csr_->getNnz(),
descr_L_,
L_csr_->getValues(ReSolve::memory::DEVICE), //vals_,
L_csr_->getRowData(ReSolve::memory::DEVICE),
L_csr_->getColData(ReSolve::memory::DEVICE),
info_L_,
rocsparse_analysis_policy_force,
rocsparse_solve_policy_auto,
L_buffer_);
error_sum += status_rocsparse_;
if (status_rocsparse_!=0)printf("status after analysis 1 %d \n", status_rocsparse_);
status_rocsparse_ = rocsparse_dcsrsv_analysis(workspace_->getRocsparseHandle(),
rocsparse_operation_none,
n,
U_csr_->getNnz(),
descr_U_,
U_csr_->getValues(ReSolve::memory::DEVICE), //vals_,
U_csr_->getRowData(ReSolve::memory::DEVICE),
U_csr_->getColData(ReSolve::memory::DEVICE),
info_U_,
rocsparse_analysis_policy_force,
rocsparse_solve_policy_auto,
U_buffer_);
error_sum += status_rocsparse_;
if (status_rocsparse_!=0)printf("status after analysis 2 %d \n", status_rocsparse_);
//allocate aux data

mem_.allocateArrayOnDevice(&d_aux1_,n);
mem_.allocateArrayOnDevice(&d_aux2_,n);

}
return error_sum;
}

Expand All @@ -78,15 +187,38 @@ namespace ReSolve
d_Q_,
infoM_);


mem_.deviceSynchronize();
error_sum += status_rocblas_;

if (solve_mode_ == 1) {
//split M, fill L and U with correct values
printf("solve mode 1, splitting the factors again \n");
status_rocblas_ = rocsolver_dcsrrf_splitlu(workspace_->getRocblasHandle(),
A_->getNumRows(),
M_->getNnzExpanded(),
M_->getRowData(ReSolve::memory::DEVICE),
M_->getColData(ReSolve::memory::DEVICE),
M_->getValues(ReSolve::memory::DEVICE), //vals_,
L_csr_->getRowData(ReSolve::memory::DEVICE),
L_csr_->getColData(ReSolve::memory::DEVICE),
L_csr_->getValues(ReSolve::memory::DEVICE), //vals_,
U_csr_->getRowData(ReSolve::memory::DEVICE),
U_csr_->getColData(ReSolve::memory::DEVICE),
U_csr_->getValues(ReSolve::memory::DEVICE));

mem_.deviceSynchronize();
error_sum += status_rocblas_;

}

return error_sum;
}

// solution is returned in RHS
int LinSolverDirectRocSolverRf::solve(vector_type* rhs)
{
int error_sum = 0;
if (solve_mode_ == 0) {
mem_.deviceSynchronize();
status_rocblas_ = rocsolver_dcsrrf_solve(workspace_->getRocblasHandle(),
Expand All @@ -104,15 +236,51 @@ namespace ReSolve
mem_.deviceSynchronize();
} else {
// not implemented yet
permuteVectorP(A_->getNumRows(), d_P_, rhs->getData(ReSolve::memory::DEVICE), d_aux1_);
mem_.deviceSynchronize();
rocsparse_dcsrsv_solve(workspace_->getRocsparseHandle(),
rocsparse_operation_none,
A_->getNumRows(),
L_csr_->getNnz(),
&(constants::ONE),
descr_L_,
L_csr_->getValues(ReSolve::memory::DEVICE), //vals_,
L_csr_->getRowData(ReSolve::memory::DEVICE),
L_csr_->getColData(ReSolve::memory::DEVICE),
info_L_,
d_aux1_,
d_aux2_, //result
rocsparse_solve_policy_auto,
L_buffer_);
error_sum += status_rocsparse_;

rocsparse_dcsrsv_solve(workspace_->getRocsparseHandle(),
rocsparse_operation_none,
A_->getNumRows(),
U_csr_->getNnz(),
&(constants::ONE),
descr_L_,
U_csr_->getValues(ReSolve::memory::DEVICE), //vals_,
U_csr_->getRowData(ReSolve::memory::DEVICE),
U_csr_->getColData(ReSolve::memory::DEVICE),
info_U_,
d_aux2_, //input
d_aux1_,//result
rocsparse_solve_policy_auto,
U_buffer_);
error_sum += status_rocsparse_;

permuteVectorQ(A_->getNumRows(), d_Q_,d_aux1_,rhs->getData(ReSolve::memory::DEVICE));
mem_.deviceSynchronize();
}
return status_rocblas_;
return error_sum;
}

int LinSolverDirectRocSolverRf::solve(vector_type* rhs, vector_type* x)
{
x->update(rhs->getData(ReSolve::memory::DEVICE), ReSolve::memory::DEVICE, ReSolve::memory::DEVICE);
x->setDataUpdated(ReSolve::memory::DEVICE);

int error_sum = 0;
if (solve_mode_ == 0) {
mem_.deviceSynchronize();
status_rocblas_ = rocsolver_dcsrrf_solve(workspace_->getRocblasHandle(),
Expand All @@ -127,11 +295,50 @@ namespace ReSolve
x->getData(ReSolve::memory::DEVICE),
A_->getNumRows(),
infoM_);
error_sum += status_rocblas_;
mem_.deviceSynchronize();
} else {
// not implemented yet

permuteVectorP(A_->getNumRows(), d_P_, rhs->getData(ReSolve::memory::DEVICE), d_aux1_);
mem_.deviceSynchronize();

rocsparse_dcsrsv_solve(workspace_->getRocsparseHandle(),
rocsparse_operation_none,
A_->getNumRows(),
L_csr_->getNnz(),
&(constants::ONE),
descr_L_,
L_csr_->getValues(ReSolve::memory::DEVICE), //vals_,
L_csr_->getRowData(ReSolve::memory::DEVICE),
L_csr_->getColData(ReSolve::memory::DEVICE),
info_L_,
d_aux1_,
d_aux2_, //result
rocsparse_solve_policy_auto,
L_buffer_);
error_sum += status_rocsparse_;

rocsparse_dcsrsv_solve(workspace_->getRocsparseHandle(),
rocsparse_operation_none,
A_->getNumRows(),
U_csr_->getNnz(),
&(constants::ONE),
descr_U_,
U_csr_->getValues(ReSolve::memory::DEVICE), //vals_,
U_csr_->getRowData(ReSolve::memory::DEVICE),
U_csr_->getColData(ReSolve::memory::DEVICE),
info_U_,
d_aux2_, //input
d_aux1_,//result
rocsparse_solve_policy_auto,
U_buffer_);
error_sum += status_rocsparse_;

permuteVectorQ(A_->getNumRows(), d_Q_,d_aux1_,x->getData(ReSolve::memory::DEVICE));
mem_.deviceSynchronize();
}
return status_rocblas_;
return error_sum;
}

int LinSolverDirectRocSolverRf::setSolveMode(int mode)
Expand Down
22 changes: 19 additions & 3 deletions resolve/LinSolverDirectRocSolverRf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ namespace ReSolve
int getSolveMode(); //should be enum too

private:
rocblas_status status_rocblas_;

rocblas_status status_rocblas_;
rocsparse_status status_rocsparse_;
index_type* d_P_;
index_type* d_Q_;

Expand All @@ -54,6 +54,22 @@ namespace ReSolve
void addFactors(matrix::Sparse* L, matrix::Sparse* U); //create L+U from sepeate L, U factors
rocsolver_rfinfo infoM_;
matrix::Sparse* M_;//the matrix that contains added factors
int solve_mode_;
int solve_mode_; // 0 is default and 1 is fast

// not used by default - for fast solve
rocsparse_mat_descr descr_L_{nullptr};
rocsparse_mat_descr descr_U_{nullptr};

rocsparse_mat_info info_L_{nullptr};
rocsparse_mat_info info_U_{nullptr};

void* L_buffer_{nullptr};
void* U_buffer_{nullptr};

ReSolve::matrix::Csr* L_csr_;
ReSolve::matrix::Csr* U_csr_;

real_type* d_aux1_{nullptr};
real_type* d_aux2_{nullptr};
};
}
11 changes: 11 additions & 0 deletions resolve/hip/hipKernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,14 @@ void matrix_row_sums(int n,
int* a_ia,
double* a_val,
double* result);

// needed for triangular solve

void permuteVectorP(int n,
int* perm_vector,
double* vec_in,
double* vec_out);
void permuteVectorQ(int n,
int* perm_vector,
double* vec_in,
double* vec_out);
Loading