@@ -336,14 +336,55 @@ function SparseArrays.sparsevec(I::CuArray{Ti}, V::CuArray{Tv}, n::Integer) wher
336336 CuSparseVector (I, V, n)
337337end
338338
339- function SparseArrays. spdiagm (v:: CuVector{Tv} ) where {Tv}
340- nzVal = v
341- N = Int32 (length (nzVal))
342-
343- colPtr = CuArray (one (Int32): (N + one (Int32)))
344- rowVal = CuArray (one (Int32): N)
345- dims = (N, N)
346- CuSparseMatrixCSC (colPtr, rowVal, nzVal, dims)
339+ SparseArrays. spdiagm (kv:: Pair{<:Integer,<:CuVector} ...) = _cuda_spdiagm (nothing , kv... )
340+ SparseArrays. spdiagm (m:: Integer , n:: Integer , kv:: Pair{<:Integer,<:CuVector} ...) = _cuda_spdiagm ((Int (m),Int (n)), kv... )
341+ SparseArrays. spdiagm (v:: CuVector ) = _cuda_spdiagm (nothing , 0 => v)
342+ SparseArrays. spdiagm (m:: Integer , n:: Integer , v:: CuVector ) = _cuda_spdiagm ((Int (m), Int (n)), 0 => v)
343+
344+ function _cuda_spdiagm (size, kv:: Pair{<:Integer, <:CuVector} ...)
345+ I, J, V, mmax, nmax = _cuda_spdiagm_internal (kv... )
346+ mnmax = max (mmax, nmax)
347+ m, n = something (size, (mnmax,mnmax))
348+ (m ≥ mmax && n ≥ nmax) || throw (DimensionMismatch (" invalid size=$size " ))
349+ return sparse (CuVector (I), CuVector (J), V, m, n)
350+ end
351+
352+ function _cuda_spdiagm_internal (kv:: Pair{T,<:CuVector} ...) where {T<: Integer }
353+ ncoeffs = 0
354+ for p in kv
355+ ncoeffs += SparseArrays. _nnz (p. second)
356+ end
357+ I = Vector {T} (undef, ncoeffs)
358+ J = Vector {T} (undef, ncoeffs)
359+ V = CuArray {promote_type(map(x -> eltype(x.second), kv)...)} (undef, ncoeffs)
360+ i = 0
361+ m = 0
362+ n = 0
363+ for p in kv
364+ k = p. first
365+ v = p. second
366+ if k < 0
367+ row = - k
368+ col = 0
369+ elseif k > 0
370+ row = 0
371+ col = k
372+ else
373+ row = 0
374+ col = 0
375+ end
376+ numel = SparseArrays. _nnz (v)
377+ r = 1 + i: numel+ i
378+ I_r, J_r = SparseArrays. _indices (v, row, col)
379+ copyto! (view (I, r), I_r)
380+ copyto! (view (J, r), J_r)
381+ copyto! (view (V, r), v)
382+ veclen = length (v)
383+ m = max (m, row + veclen)
384+ n = max (n, col + veclen)
385+ i += numel
386+ end
387+ return I, J, V, m, n
347388end
348389
349390LinearAlgebra. issymmetric (M:: Union{CuSparseMatrixCSC,CuSparseMatrixCSR} ) = size (M, 1 ) == size (M, 2 ) ? norm (M - transpose (M), Inf ) == 0 : false
0 commit comments