@@ -385,6 +385,40 @@ function LinearAlgebra.mul!(C::CuMatrix{T}, A::Diagonal{T,<:CuVector}, B::Adjoin
385385 return C
386386end
387387
388+ # diagm
389+
390+ LinearAlgebra. diagm (kv:: Pair{<:Integer,<:CuVector} ...) = _cuda_diagm (nothing , kv... )
391+ LinearAlgebra. diagm (m:: Integer , n:: Integer , kv:: Pair{<:Integer,<:CuVector} ...) = _cuda_diagm ((Int (m),Int (n)), kv... )
392+ LinearAlgebra. diagm (v:: CuVector ) = LinearAlgebra. diagm (0 => v)
393+ LinearAlgebra. diagm (m:: Integer , n:: Integer , v:: CuVector ) = LinearAlgebra. diagm (m, n, 0 => v)
394+
395+ function _cuda_diagm (size, kv:: Pair{<:Integer,<:CuVector} ...)
396+ A = LinearAlgebra. diagm_container (size, kv... )
397+ for p in kv
398+ inds = LinearAlgebra. diagind (A, p. first)
399+ copyto! (view (A, inds), p. second)
400+ end
401+ return A
402+ end
403+
404+ function LinearAlgebra. diagm_container (size, kv:: Pair{<:Integer,<:CuVector} ...)
405+ T = promote_type (map (x -> eltype (x. second), kv)... )
406+ U = promote_type (T, typeof (zero (T)))
407+ return cu (zeros (U, LinearAlgebra. diagm_size (size, kv... )... ))
408+ end
409+
410+ function LinearAlgebra. diagm_size (size:: Nothing , kv:: Pair{<:Integer,<:CuVector} ...)
411+ mnmax = mapreduce (x -> length (x. second) + abs (Int (x. first)), max, kv; init= 0 )
412+ return mnmax, mnmax
413+ end
414+ function LinearAlgebra. diagm_size (size:: Tuple{Int,Int} , kv:: Pair{<:Integer,<:CuVector} ...)
415+ mmax = mapreduce (x -> length (x. second) - min (0 ,Int (x. first)), max, kv; init= 0 )
416+ nmax = mapreduce (x -> length (x. second) + max (0 ,Int (x. first)), max, kv; init= 0 )
417+ m, n = size
418+ (m ≥ mmax && n ≥ nmax) || throw (DimensionMismatch (lazy " invalid size=$size" ))
419+ return m, n
420+ end
421+
388422# symmetric mul!
389423
390424op_wrappers = ((identity, T -> ' N' , identity),
0 commit comments