[CUSPARSE] Update the interface for triangular solves #2164
18a3fad to 37ab076
@dkarrasch
into
? If the right-hand side |
No, at least it is not designed like that. The reason is that triangular solves are typically 2-arg only. But you could create your own:
generic_trimatdiv!(C, uploc, isunitc, tfun, A, B::SomeCuMatrix) = # as usual
generic_trimatdiv!(C, uploc, isunitc, tfun, A, B::Transpose{<:Any,<:SomeCuMatrix}) =
    _generic_trimatdiv!(C, uploc, isunitc, tfunA, A, tfunB, B)
(with the other slots type-annotated as appropriate), which then calls whatever CUSPARSE offers.
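A minimal CPU-only sketch of that forwarding pattern, for illustration only. `ToyMatrix` is a hypothetical stand-in for `SomeCuMatrix`, `_toy_trimatdiv!` is a hypothetical helper, and `BLAS.trsm!` stands in for whatever CUSPARSE would offer. It assumes Julia 1.10 or later, where `LinearAlgebra.generic_trimatdiv!` exists, and that the helper's output does not alias the parent of a transposed right-hand side.

```julia
using LinearAlgebra

# Hypothetical stand-in for SomeCuMatrix: a thin CPU wrapper around Matrix{Float64}.
struct ToyMatrix <: AbstractMatrix{Float64}
    data::Matrix{Float64}
end
Base.size(M::ToyMatrix) = size(M.data)
Base.getindex(M::ToyMatrix, i::Int, j::Int) = M.data[i, j]
Base.setindex!(M::ToyMatrix, v, i::Int, j::Int) = (M.data[i, j] = v)

# Hypothetical helper: the operations on A and on B arrive in separate slots.
function _toy_trimatdiv!(C::ToyMatrix, uploc, isunitc, tfunA, A::ToyMatrix, tfunB, B::ToyMatrix)
    copyto!(C.data, tfunB(B.data))   # materialize op(B) into the output
    transA = tfunA === identity ? 'N' : tfunA === transpose ? 'T' : 'C'
    BLAS.trsm!('L', uploc, transA, isunitc, 1.0, A.data, C.data)  # solves in place in C
    return C
end

# Plain right-hand side: forward with tfunB = identity.
LinearAlgebra.generic_trimatdiv!(C::ToyMatrix, uploc, isunitc, tfun::Function,
                                 A::ToyMatrix, B::ToyMatrix) =
    _toy_trimatdiv!(C, uploc, isunitc, tfun, A, identity, B)

# Transposed right-hand side: unwrap it and forward with tfunB = transpose.
LinearAlgebra.generic_trimatdiv!(C::ToyMatrix, uploc, isunitc, tfun::Function,
                                 A::ToyMatrix, B::Transpose{Float64,ToyMatrix}) =
    _toy_trimatdiv!(C, uploc, isunitc, tfun, A, transpose, parent(B))

A  = ToyMatrix([2.0 1.0 0.5; 0.0 3.0 1.0; 0.0 0.0 4.0])
Bt = ToyMatrix([1.0 3.0 5.0; 2.0 4.0 6.0])   # 2×3, so transpose(Bt) is 3×2
C  = ToyMatrix(zeros(3, 2))

ldiv!(C, UpperTriangular(A), transpose(Bt))

# Dense reference computed on plain arrays.
expected = UpperTriangular(A.data) \ Matrix(transpose(Bt.data))
ok = C.data ≈ expected
```

The `tfun::Function` annotation mirrors how the slot is usually typed, so that the new methods are strictly more specific than the generic fallback.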
I already did that, but in the case where the triangular solves are in-place, we can use 2-arg methods again with CUDA v12.x, so it could simplify many things.
The dispatch in LinearAlgebra is as follows: if you call |
Thanks @dkarrasch! I understand how the dispatch with
function LinearAlgebra.generic_trimatdiv!(C::DenseCuMatrix{T}, uploc, isunitc, tfun, A, B::AdjOrTrans{T,<:DenseCuMatrix{T}}) where {T<:BlasFloat}
    ...
end
function LinearAlgebra.generic_trimatdiv!(C::Transpose{T,<:DenseCuMatrix{T}}, uploc, isunitc, tfun, A, B::Transpose{T,<:DenseCuMatrix{T}}) where {T<:BlasFloat}
    ...
end
function LinearAlgebra.generic_trimatdiv!(C::Adjoint{T,<:DenseCuMatrix{T}}, uploc, isunitc, tfun, A, B::Adjoint{T,<:DenseCuMatrix{T}}) where {T<:BlasFloat}
    ...
end
3c62bc9 to 48ed919
@dkarrasch
It seems that NVIDIA reintroduced in-place triangular solves in CUSPARSE with recent CUDA toolkits.
I also found a few typos that you added in #1946. Can you review the PR?
Can you also confirm that ldiv!(C, A, B), ldiv!(A, B) and A \ B are automatically defined for the four triangular wrappers if we define generic_trimatdiv!?
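One way to probe this question empirically on the CPU is sketched below. It is not the CUSPARSE implementation: `ToyMatrix` and `HOOK_CALLS` are hypothetical names, `BLAS.trsm!` stands in for the GPU routine, and Julia 1.10 or later is assumed. The sketch defines a single `generic_trimatdiv!` method, then exercises the three entry points for each of the four triangular wrappers and compares against a dense reference.

```julia
using LinearAlgebra

# Hypothetical stand-in for a GPU matrix type: a thin wrapper around Matrix{Float64}.
struct ToyMatrix <: AbstractMatrix{Float64}
    data::Matrix{Float64}
end
Base.size(M::ToyMatrix) = size(M.data)
Base.getindex(M::ToyMatrix, i::Int, j::Int) = M.data[i, j]
Base.setindex!(M::ToyMatrix, v, i::Int, j::Int) = (M.data[i, j] = v)

# Records (uploc, isunitc, tfun) each time the method below is reached.
const HOOK_CALLS = Any[]

function LinearAlgebra.generic_trimatdiv!(C::ToyMatrix, uploc, isunitc, tfun::Function,
                                          A::ToyMatrix, B::ToyMatrix)
    push!(HOOK_CALLS, (uploc, isunitc, tfun))
    C === B || copyto!(C.data, B.data)   # trsm! overwrites its right-hand side
    trans = tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C'
    BLAS.trsm!('L', uploc, trans, isunitc, 1.0, A.data, C.data)
    return C
end

A = ToyMatrix([2.0 1.0 0.5; 1.0 3.0 1.0; 0.5 1.0 4.0])
B = ToyMatrix([1.0 2.0; 3.0 4.0; 5.0 6.0])

results = Dict{Symbol,Bool}()
for W in (LowerTriangular, UnitLowerTriangular, UpperTriangular, UnitUpperTriangular)
    expected = W(A.data) \ B.data        # dense reference on plain arrays
    C = ToyMatrix(zeros(3, 2))
    ldiv!(C, W(A), B)                    # 3-arg form
    B2 = ToyMatrix(copy(B.data))
    ldiv!(W(A), B2)                      # 2-arg in-place form
    # A \ B allocates its output through `similar`, which for this toy type
    # returns a plain Matrix, so this call may not reach the method above;
    # only its result is checked here.
    X = W(A) \ B
    results[nameof(W)] = C.data ≈ expected && B2.data ≈ expected && Matrix(X) ≈ expected
end
```

Inspecting `HOOK_CALLS` afterwards shows which calls were routed through the new method and with which `uploc`, `isunitc` and `tfun` values.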