-
Notifications
You must be signed in to change notification settings - Fork 88
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
CUDA Coo spmm implementaion #345
Conversation
Looks good! Do you have performance result of a comparison between our csr spmv with this spmm and one nrhs? |
Codecov Report
@@ Coverage Diff @@
## develop #345 +/- ##
===========================================
+ Coverage 98.23% 98.23% +<.01%
===========================================
Files 238 238
Lines 18076 18086 +10
===========================================
+ Hits 17757 17767 +10
Misses 319 319
Continue to review full report at Codecov.
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
* @param num_elems the maximum number of nonzeros in each warp | ||
* @param val the value array of the matrix | ||
* @param col the column index array of the matrix | ||
* @param row the row index array of the matrix |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@param num_cols
is missing.
@tcojean |
So it looks like this version pays off if we have 8 and more RHS - otherwise the scalar versions are preferred. Correct? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
@@ -83,7 +84,9 @@ namespace { | |||
* @param col the column index array of the matrix |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: the second parameter is num_lines
instead of num_line
cuda/matrix/coo_kernels.cu
Outdated
a->get_const_col_idxs(), as_cuda_type(a->get_const_row_idxs()), | ||
as_cuda_type(b->get_const_values()), b->get_stride(), | ||
as_cuda_type(c->get_values()), c->get_stride()); | ||
if (b->get_size()[1] == 1) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Seeing the significant performance difference from your last results, I would also include at least the value 2
here. There seems to be something in particular which makes 2 nrhs particularly slow (even slower than one).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, I choose 4 for the condition
@hartwiganzt Yes, but I prefer to use 4 as the condition because this PR is faster when #nonzeros is large. |
cuda/matrix/coo_kernels.cu
Outdated
__global__ __launch_bounds__(default_block_size) void set_zero( | ||
const size_type nnz, ValueType *__restrict__ val) | ||
/** | ||
* The device function of COO spmv |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
s/spmv/spmm/
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Some minor comment to improve the code duplication number.
cuda/matrix/coo_kernels.cu
Outdated
as_cuda_type(c->get_values()), c->get_stride()); | ||
if (b->get_size()[1] < 4) { | ||
int num_lines = ceildiv(nnz, nwarps * cuda_config::warp_size); | ||
const dim3 coo_block(cuda_config::warp_size, warps_in_block, 1); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This variable is the same as the else { }
part. You can move it outside of the if {} else{}
block.
cuda/matrix/coo_kernels.cu
Outdated
as_cuda_type(c->get_values()), c->get_stride()); | ||
if (b->get_size()[1] < 4) { | ||
int num_lines = ceildiv(nnz, nwarps * cuda_config::warp_size); | ||
const dim3 coo_block(cuda_config::warp_size, warps_in_block, 1); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same here.
add subwarp_size template parameter and the advanced apply
The Ginkgo team is proud to announce the new minor release of Ginkgo version 1.1.0. This release brings several performance improvements, adds Windows support, adds support for factorizations inside Ginkgo and a new ILU preconditioner based on ParILU algorithm, among other things. For detailed information, check the respective issue. Supported systems and requirements: + For all platforms, cmake 3.9+ + Linux and MacOS + gcc: 5.3+, 6.3+, 7.3+, 8.1+ + clang: 3.9+ + Intel compiler: 2017+ + Apple LLVM: 8.0+ + CUDA module: CUDA 9.0+ + Windows + MinGW and CygWin: gcc 5.3+, 6.3+, 7.3+, 8.1+ + Microsoft Visual Studio: VS 2017 15.7+ + CUDA module: CUDA 9.0+, Microsoft Visual Studio + OpenMP module: MinGW or CygWin. The current known issues can be found in the [known issues page](https://github.com/ginkgo-project/ginkgo/wiki/Known-Issues). Additions: + Upper and lower triangular solvers ([#327](#327), [#336](#336), [#341](#341), [#342](#342)) + New factorization support in Ginkgo, and addition of the ParILU algorithm ([#305](#305), [#315](#315), [#319](#319), [#324](#324)) + New ILU preconditioner ([#348](#348), [#353](#353)) + Windows MinGW and Cygwin support ([#347](#347)) + Windows Visual studio support ([#351](#351)) + New example showing how to use ParILU as a preconditioner ([#358](#358)) + New example on using loggers for debugging ([#360](#360)) + Add two new 9pt and 27pt stencil examples ([#300](#300), [#306](#306)) + Allow benchmarking CuSPARSE spmv formats through Ginkgo's benchmarks ([#303](#303)) + New benchmark for sparse matrix format conversions ([#312](https://github.com/ginkgo-project/ginkgo/issues/312)[#317](https://github.com/ginkgo-project/ginkgo/issues/317)) + Add conversions between CSR and Hybrid formats ([#302](#302), [#310](#310)) + Support for sorting rows in the CSR format by column idices ([#322](#322)) + Addition of a CUDA COO SpMM kernel for improved performance ([#345](#345)) + Addition of a LinOp to handle perturbations of the form (identity + scalar * basis * projector) ([#334](#334)) + New sparsity matrix representation format with Reference and OpenMP kernels ([#349](#349), [#350](#350)) Fixes: + Accelerate GMRES solver for CUDA executor ([#363](#363)) + Fix BiCGSTAB solver convergence ([#359](#359)) + Fix CGS logging by reporting the residual for every sub iteration ([#328](#328)) + Fix CSR,Dense->Sellp conversion's memory access violation ([#295](#295)) + Accelerate CSR->Ell,Hybrid conversions on CUDA ([#313](#313), [#318](#318)) + Fixed slowdown of COO SpMV on OpenMP ([#340](#340)) + Fix gcc 6.4.0 internal compiler error ([#316](#316)) + Fix compilation issue on Apple clang++ 10 ([#322](#322)) + Make Ginkgo able to compile on Intel 2017 and above ([#337](#337)) + Make the benchmarks spmv/solver use the same matrix formats ([#366](#366)) + Fix self-written isfinite function ([#348](#348)) + Fix Jacobi issues shown by cuda-memcheck Tools and ecosystem: + Multiple improvements to the CI system and tools ([#296](#296), [#311](#311), [#365](#365)) + Multiple improvements to the Ginkgo containers ([#328](#328), [#361](#361)) + Add sonarqube analysis to Ginkgo ([#304](#304), [#308](#308), [#309](#309)) + Add clang-tidy and iwyu support to Ginkgo ([#298](#298)) + Improve Ginkgo's support of xSDK M12 policy by adding the `TPL_` arguments to CMake ([#300](#300)) + Add support for the xSDK R7 policy ([#325](#325)) + Fix examples in html documentation ([#367](#367))
The Ginkgo team is proud to announce the new minor release of Ginkgo version 1.1.0. This release brings several performance improvements, adds Windows support, adds support for factorizations inside Ginkgo and a new ILU preconditioner based on ParILU algorithm, among other things. For detailed information, check the respective issue. Supported systems and requirements: + For all platforms, cmake 3.9+ + Linux and MacOS + gcc: 5.3+, 6.3+, 7.3+, 8.1+ + clang: 3.9+ + Intel compiler: 2017+ + Apple LLVM: 8.0+ + CUDA module: CUDA 9.0+ + Windows + MinGW and Cygwin: gcc 5.3+, 6.3+, 7.3+, 8.1+ + Microsoft Visual Studio: VS 2017 15.7+ + CUDA module: CUDA 9.0+, Microsoft Visual Studio + OpenMP module: MinGW or Cygwin. The current known issues can be found in the [known issues page](https://github.com/ginkgo-project/ginkgo/wiki/Known-Issues). ### Additions + Upper and lower triangular solvers ([#327](#327), [#336](#336), [#341](#341), [#342](#342)) + New factorization support in Ginkgo, and addition of the ParILU algorithm ([#305](#305), [#315](#315), [#319](#319), [#324](#324)) + New ILU preconditioner ([#348](#348), [#353](#353)) + Windows MinGW and Cygwin support ([#347](#347)) + Windows Visual Studio support ([#351](#351)) + New example showing how to use ParILU as a preconditioner ([#358](#358)) + New example on using loggers for debugging ([#360](#360)) + Add two new 9pt and 27pt stencil examples ([#300](#300), [#306](#306)) + Allow benchmarking CuSPARSE spmv formats through Ginkgo's benchmarks ([#303](#303)) + New benchmark for sparse matrix format conversions ([#312](https://github.com/ginkgo-project/ginkgo/issues/312)[#317](https://github.com/ginkgo-project/ginkgo/issues/317)) + Add conversions between CSR and Hybrid formats ([#302](#302), [#310](#310)) + Support for sorting rows in the CSR format by column idices ([#322](#322)) + Addition of a CUDA COO SpMM kernel for improved performance ([#345](#345)) + Addition of a LinOp to handle perturbations of the form (identity + scalar * basis * projector) ([#334](#334)) + New sparsity matrix representation format with Reference and OpenMP kernels ([#349](#349), [#350](#350)) ### Fixes + Accelerate GMRES solver for CUDA executor ([#363](#363)) + Fix BiCGSTAB solver convergence ([#359](#359)) + Fix CGS logging by reporting the residual for every sub iteration ([#328](#328)) + Fix CSR,Dense->Sellp conversion's memory access violation ([#295](#295)) + Accelerate CSR->Ell,Hybrid conversions on CUDA ([#313](#313), [#318](#318)) + Fixed slowdown of COO SpMV on OpenMP ([#340](#340)) + Fix gcc 6.4.0 internal compiler error ([#316](#316)) + Fix compilation issue on Apple clang++ 10 ([#322](#322)) + Make Ginkgo able to compile on Intel 2017 and above ([#337](#337)) + Make the benchmarks spmv/solver use the same matrix formats ([#366](#366)) + Fix self-written isfinite function ([#348](#348)) + Fix Jacobi issues shown by cuda-memcheck ### Tools and ecosystem improvements + Multiple improvements to the CI system and tools ([#296](#296), [#311](#311), [#365](#365)) + Multiple improvements to the Ginkgo containers ([#328](#328), [#361](#361)) + Add sonarqube analysis to Ginkgo ([#304](#304), [#308](#308), [#309](#309)) + Add clang-tidy and iwyu support to Ginkgo ([#298](#298)) + Improve Ginkgo's support of xSDK M12 policy by adding the `TPL_` arguments to CMake ([#300](#300)) + Add support for the xSDK R7 policy ([#325](#325)) + Fix examples in html documentation ([#367](#367)) Related PR: #370
This PR implements the COO spmm.
The implementation is simple. Read the needed elements, compute the result, and then atomicAdd the result to the output vectors.
I have tried that using warp to read a line of nonzeros and shuffle the data in previous commit warp_coo However, in my testing cases, this warp implementation is not better than the simple one because the warp implementation needs more condition and index calculation.
The matrices set contain 100 matrices which obtained by using K-means on the whole real matrices set of SuiteSparse.
nrhs = 32