Skip to content

Commit

Permalink
[CUSPARSE] Support in-place triangular solves with CUDA v12.x (#2164)
Browse files Browse the repository at this point in the history
  • Loading branch information
amontoison authored Nov 13, 2023
1 parent 98c171a commit 3e59794
Show file tree
Hide file tree
Showing 3 changed files with 178 additions and 132 deletions.
8 changes: 6 additions & 2 deletions lib/cusparse/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -535,13 +535,17 @@ function sm!(transa::SparseChar, transb::SparseChar, uplo::SparseChar, diag::Spa
transa = T <: Real && transa == 'C' ? 'T' : transa
transb = T <: Real && transb == 'C' ? 'T' : transb

# Check if we solve a triangular system in-place with transb != 'N'.
# In that case we need to update the descriptor of C such that it represents Bᵀ.
is_C_transposed = (B === C) && (transb != 'N')

if isa(A, CuSparseMatrixCSC) && transa == 'C' && T <: Complex
throw(ArgumentError("Backward and forward sweeps with the adjoint of a complex CSC matrix is not supported. Use a CSR or COO matrix instead."))
end

mA,nA = size(A)
mB,nB = size(B)
-mC,nC = size(C)
+mC,nC = !is_C_transposed ? size(C) : reverse(size(C))

(mA != nA) && throw(DimensionMismatch("A must be square, but has dimensions ($mA,$nA)!"))
(mC != mA) && throw(DimensionMismatch("C must have $mA rows, but has $mC rows"))
Expand All @@ -566,7 +570,7 @@ function sm!(transa::SparseChar, transb::SparseChar, uplo::SparseChar, diag::Spa
cusparseSpMatSetAttribute(descA, 'D', cusparse_diag, Csize_t(sizeof(cusparse_diag)))

descB = CuDenseMatrixDescriptor(B)
-descC = CuDenseMatrixDescriptor(C)
+descC = CuDenseMatrixDescriptor(C, transposed=is_C_transposed)

spsm_desc = CuSparseSpSMDescriptor()
function bufferSize()
Expand Down
Loading

0 comments on commit 3e59794

Please sign in to comment.