-
Notifications
You must be signed in to change notification settings - Fork 259
Add diagm in CUBLAS #2786
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add diagm in CUBLAS #2786
Conversation
|
Your PR requires formatting changes to meet the project's style guidelines. Click here to view the suggested changes.diff --git a/lib/cublas/linalg.jl b/lib/cublas/linalg.jl
index 193b7c92d..7c01ef730 100644
--- a/lib/cublas/linalg.jl
+++ b/lib/cublas/linalg.jl
@@ -387,12 +387,12 @@ end
# diagm
-LinearAlgebra.diagm(kv::Pair{<:Integer,<:CuVector}...) = _cuda_diagm(nothing, kv...)
-LinearAlgebra.diagm(m::Integer, n::Integer, kv::Pair{<:Integer,<:CuVector}...) = _cuda_diagm((Int(m),Int(n)), kv...)
+LinearAlgebra.diagm(kv::Pair{<:Integer, <:CuVector}...) = _cuda_diagm(nothing, kv...)
+LinearAlgebra.diagm(m::Integer, n::Integer, kv::Pair{<:Integer, <:CuVector}...) = _cuda_diagm((Int(m), Int(n)), kv...)
LinearAlgebra.diagm(v::CuVector) = LinearAlgebra.diagm(0 => v)
LinearAlgebra.diagm(m::Integer, n::Integer, v::CuVector) = LinearAlgebra.diagm(m, n, 0 => v)
-function _cuda_diagm(size, kv::Pair{<:Integer,<:CuVector}...)
+function _cuda_diagm(size, kv::Pair{<:Integer, <:CuVector}...)
A = LinearAlgebra.diagm_container(size, kv...)
for p in kv
inds = LinearAlgebra.diagind(A, p.first)
@@ -401,19 +401,19 @@ function _cuda_diagm(size, kv::Pair{<:Integer,<:CuVector}...)
return A
end
-function LinearAlgebra.diagm_container(size, kv::Pair{<:Integer,<:CuVector}...)
+function LinearAlgebra.diagm_container(size, kv::Pair{<:Integer, <:CuVector}...)
T = promote_type(map(x -> eltype(x.second), kv)...)
U = promote_type(T, typeof(zero(T)))
return cu(zeros(U, LinearAlgebra.diagm_size(size, kv...)...))
end
-function LinearAlgebra.diagm_size(size::Nothing, kv::Pair{<:Integer,<:CuVector}...)
- mnmax = mapreduce(x -> length(x.second) + abs(Int(x.first)), max, kv; init=0)
+function LinearAlgebra.diagm_size(size::Nothing, kv::Pair{<:Integer, <:CuVector}...)
+ mnmax = mapreduce(x -> length(x.second) + abs(Int(x.first)), max, kv; init = 0)
return mnmax, mnmax
end
-function LinearAlgebra.diagm_size(size::Tuple{Int,Int}, kv::Pair{<:Integer,<:CuVector}...)
- mmax = mapreduce(x -> length(x.second) - min(0,Int(x.first)), max, kv; init=0)
- nmax = mapreduce(x -> length(x.second) + max(0,Int(x.first)), max, kv; init=0)
+function LinearAlgebra.diagm_size(size::Tuple{Int, Int}, kv::Pair{<:Integer, <:CuVector}...)
+ mmax = mapreduce(x -> length(x.second) - min(0, Int(x.first)), max, kv; init = 0)
+ nmax = mapreduce(x -> length(x.second) + max(0, Int(x.first)), max, kv; init = 0)
m, n = size
(m ≥ mmax && n ≥ nmax) || throw(DimensionMismatch(lazy"invalid size=$size"))
return m, n |
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## master #2786 +/- ##
==========================================
+ Coverage 89.45% 89.59% +0.13%
==========================================
Files 153 153
Lines 13274 13298 +24
==========================================
+ Hits 11874 11914 +40
+ Misses 1400 1384 -16 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
|
Thanks! |
|
Great! Can we cut a release with recent changes? I need them in my package testing.@maleadt |
|
These are technically features, and we've only just cut a release... I'll add these to a patch version then. |
|
It seems there are some failures with JuliaRegister 8b6a2a0#commitcomment-157997774, the latest version is not registered. @maleadt |
|
Seems like it worked now: JuliaRegistries/General#131869 |
Fix: #2785