From 7e1821dfa4bcb10b8f3183b7dccf00c01e7bccf1 Mon Sep 17 00:00:00 2001
From: Katharine Hyatt <khyatt@flatironinstitute.org>
Date: Thu, 3 Apr 2025 15:15:24 -0400
Subject: [PATCH 01/11] Very rough implementation of bcast for CuSparseVector

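Roughly, this enables elementwise broadcast over CuSparseVector inputs,
mirroring the tests added below. A sketch of the intended usage (assumes
a CUDA-capable device):

    using CUDA, CUDA.CUSPARSE, SparseArrays

    x  = sprand(Float32, 64, 0.5)   # host sparse vector
    dx = CuSparseVector(x)          # device sparse vector

    dy = dx .* 2f0   # zero-preserving: result stays a CuSparseVector{Float32}
    dz = dx .+ 1f0   # not zero-preserving: result densifies into a CuArray{Float32}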
---
 lib/cusparse/broadcast.jl            | 169 +++++++++++++++++++++++----
 lib/cusparse/device.jl               |   2 +-
 test/Project.toml                    |   1 +
 test/libraries/cusparse/broadcast.jl |  60 +++++++++-
 4 files changed, 205 insertions(+), 27 deletions(-)

diff --git a/lib/cusparse/broadcast.jl b/lib/cusparse/broadcast.jl
index 273b028c2f..ad6b356ace 100644
--- a/lib/cusparse/broadcast.jl
+++ b/lib/cusparse/broadcast.jl
@@ -103,7 +103,6 @@ end
 
 ## COV_EXCL_START
 ## iteration helpers
-
 """
     CSRIterator{Ti}(row, args...)
 
@@ -295,9 +294,9 @@ function _getindex(arg::Union{CuSparseDeviceMatrixCSR,CuSparseDeviceMatrixCSC},
         @inbounds arg.nzVal[ptr]
     end
 end
+_getindex(arg::CuDeviceArray, I, ptr) = @inbounds arg[I]
 _getindex(arg, I, ptr) = Broadcast._broadcast_getindex(arg, I)
 
-
 ## sparse broadcast implementation
 
 iter_type(::Type{<:CuSparseMatrixCSC}, ::Type{Ti}) where {Ti} = CSCIterator{Ti}
@@ -305,6 +304,42 @@ iter_type(::Type{<:CuSparseMatrixCSR}, ::Type{Ti}) where {Ti} = CSRIterator{Ti}
 iter_type(::Type{<:CuSparseDeviceMatrixCSC}, ::Type{Ti}) where {Ti} = CSCIterator{Ti}
 iter_type(::Type{<:CuSparseDeviceMatrixCSR}, ::Type{Ti}) where {Ti} = CSRIterator{Ti}
 
+function compute_offsets_kernel(::Type{<:CuSparseVector}, first_row::Ti, last_row::Ti, offsets::AbstractVector{Ti}, args...) where Ti
+    row_ix = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x
+    row    = row_ix + first_row - 1i32
+    
+    row > last_row && return
+
+    # TODO load arg.iPtr slices into shared memory
+    row_is_nnz = 0i32
+    for arg in args
+        if arg isa CuSparseDeviceVector
+            for arg_row in arg.iPtr
+                if arg_row == row 
+                    row_is_nnz = 1i32
+                    @inbounds offsets[row] = row
+                end
+                arg_row > row && break
+            end
+        else
+            @inbounds offsets[row] = row
+            row_is_nnz = 1i32
+        end
+        sync_warp()
+    end
+    # count nnz in the warp
+    offset = 16
+    while (offset >= 1)
+        row_is_nnz = shfl_down_sync(FULL_MASK, row_is_nnz, offset)
+        offset = offset >> 1
+    end
+    if threadIdx().x == 1
+        CUDA.@atomic offsets[last_row+1] += row_is_nnz
+    end
+    return
+end
+# TODO: unify CSC/CSR kernels
+
 # kernel to count the number of non-zeros in a row, to determine the row offsets
 function compute_offsets_kernel(T::Type{<:Union{CuSparseMatrixCSR, CuSparseMatrixCSC}}, offsets::AbstractVector{Ti},
                                 args...) where Ti
@@ -332,6 +367,30 @@ function compute_offsets_kernel(T::Type{<:Union{CuSparseMatrixCSR, CuSparseMatri
 end
 
 # broadcast kernels that iterate the elements of sparse arrays
+function _getindex(A::CuSparseDeviceVector{Tv}, row, ptr) where {Tv}
+    for r_ix in 1:A.nnz
+        arg_row = @inbounds A.iPtr[r_ix]
+        arg_row == row && return @inbounds A.nzVal[r_ix]
+        arg_row > row && return zero(Tv)
+    end
+    return zero(Tv)
+end
+
+function sparse_to_sparse_broadcast_kernel(f, first_row::Ti, last_row::Ti, output::CuSparseDeviceVector{Tv,Ti},
+                                           offsets::Union{AbstractVector,Nothing},
+                                           args::Vararg{Any, N}) where {Tv, Ti, N}
+    row_ix = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x
+    row_ix > output.nnz && return
+    row    = @inbounds output.iPtr[row_ix + first_row - 1i32]
+    vals   = ntuple(Val(N)) do i
+        arg = @inbounds args[i]
+        _getindex(arg, row, 0)::Tv
+    end
+    output_val = f(vals...)
+    @inbounds output.nzVal[row_ix] = output_val
+    return
+end
+
 function sparse_to_sparse_broadcast_kernel(f, output::T, offsets::Union{AbstractVector,Nothing}, args...) where {Ti, T<:Union{CuSparseDeviceMatrixCSR{<:Any,Ti},CuSparseDeviceMatrixCSC{<:Any,Ti}}}
     # every thread processes an entire row
     leading_dim = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x
@@ -392,8 +451,24 @@ function sparse_to_dense_broadcast_kernel(T::Type{<:Union{CuSparseMatrixCSR{Tv,
 
     return
 end
-## COV_EXCL_STOP
 
+function sparse_to_dense_broadcast_kernel(::Type{<:CuSparseVector}, f,
+                                          output::CuDeviceArray{Tv}, args::Vararg{Any, N}) where {Tv, N}
+    # every thread processes an entire row
+    row = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x
+    row > length(output) && return
+
+    # set the values for this row
+    vals = ntuple(Val(N)) do i
+        arg = @inbounds args[i] 
+        _getindex(arg, row, 0)::Tv
+    end
+    out_val =  f(vals...)
+    @inbounds output[row] = out_val 
+    return
+end
+## COV_EXCL_STOP
+const N_VEC_THREADS = 512
 function Broadcast.copy(bc::Broadcasted{<:Union{CuSparseVecStyle,CuSparseMatStyle}})
     # find the sparse inputs
     bc = Broadcast.flatten(bc)
@@ -405,12 +480,14 @@ function Broadcast.copy(bc::Broadcasted{<:Union{CuSparseVecStyle,CuSparseMatStyl
         error("broadcast with multiple types of sparse arrays ($(join(sparse_types, ", "))) is not supported")
     end
     sparse_typ = typeof(bc.args[first(sparse_args)])
-    sparse_typ <: Union{CuSparseMatrixCSR,CuSparseMatrixCSC} ||
-        error("broadcast with sparse arrays is currently only implemented for CSR and CSC matrices")
+    sparse_typ <: Union{CuSparseMatrixCSR,CuSparseMatrixCSC,CuSparseVector} ||
+        error("broadcast with sparse arrays is currently only implemented for vectors and CSR and CSC matrices")
     Ti = if sparse_typ <: CuSparseMatrixCSR
         reduce(promote_type, map(i->eltype(bc.args[i].rowPtr), sparse_args))
     elseif sparse_typ <: CuSparseMatrixCSC
         reduce(promote_type, map(i->eltype(bc.args[i].colPtr), sparse_args))
+    elseif sparse_typ <: CuSparseVector
+        reduce(promote_type, map(i->eltype(bc.args[i].iPtr), sparse_args))
     end
 
     # determine the output type
@@ -433,21 +510,29 @@ function Broadcast.copy(bc::Broadcasted{<:Union{CuSparseVecStyle,CuSparseMatStyl
 
     # the kernels below parallelize across rows or cols, not elements, so it's unlikely
     # we'll launch many threads. to maximize utilization, parallelize across blocks first.
-    rows, cols = size(bc)
+    rows, cols = sparse_typ <: CuSparseVector ? (length(bc), 1) : size(bc)
     function compute_launch_config(kernel)
         config = launch_configuration(kernel.fun)
         if sparse_typ <: CuSparseMatrixCSR
             threads = min(rows, config.threads)
-            blocks = max(cld(rows, threads), config.blocks)
+            blocks  = max(cld(rows, threads), config.blocks)
             threads = cld(rows, blocks)
         elseif sparse_typ <: CuSparseMatrixCSC
             threads = min(cols, config.threads)
-            blocks = max(cld(cols, threads), config.blocks)
+            blocks  = max(cld(cols, threads), config.blocks)
             threads = cld(cols, blocks)
+        elseif sparse_typ <: CuSparseVector
+            threads = N_VEC_THREADS
+            blocks  = max(cld(rows, threads), config.blocks)
+            threads = N_VEC_THREADS
         end
         (; threads, blocks)
     end
-
+    # for CuSparseVec, figure out the actual row range we need to address, e.g. if m = 2^20
+    # but the only rows present in any sparse vector input are between 2 and 128, no need to
+    # launch massive threads. TODO: use the difference here to set the thread count
+    overall_first_row = one(Ti) 
+    overall_last_row = Ti(rows)
     # allocate the output container
     if !fpreszeros
         # either we have dense inputs, or the function isn't preserving zeros,
@@ -472,14 +557,18 @@ function Broadcast.copy(bc::Broadcasted{<:Union{CuSparseVecStyle,CuSparseMatStyl
         sparse_arg = bc.args[first(sparse_args)]
         if sparse_typ <: CuSparseMatrixCSR
             offsets = rowPtr = sparse_arg.rowPtr
-            colVal = similar(sparse_arg.colVal)
-            nzVal = similar(sparse_arg.nzVal, Tv)
-            output = CuSparseMatrixCSR(rowPtr, colVal, nzVal, size(bc))
+            colVal  = similar(sparse_arg.colVal)
+            nzVal   = similar(sparse_arg.nzVal, Tv)
+            output  = CuSparseMatrixCSR(rowPtr, colVal, nzVal, size(bc))
         elseif sparse_typ <: CuSparseMatrixCSC
             offsets = colPtr = sparse_arg.colPtr
-            rowVal = similar(sparse_arg.rowVal)
-            nzVal = similar(sparse_arg.nzVal, Tv)
-            output = CuSparseMatrixCSC(colPtr, rowVal, nzVal, size(bc))
+            rowVal  = similar(sparse_arg.rowVal)
+            nzVal   = similar(sparse_arg.nzVal, Tv)
+            output  = CuSparseMatrixCSC(colPtr, rowVal, nzVal, size(bc))
+        elseif sparse_typ <: CuSparseVector
+            offsets = iPtr = sparse_arg.iPtr
+            nzVal   = similar(sparse_arg.nzVal, Tv)
+            output  = CuSparseVector(iPtr, nzVal, length(bc))
         end
         # NOTE: we don't use CUSPARSE's similar, because that copies the structure arrays,
         #       while we do that in our kernel (for consistency with other code paths)
@@ -490,9 +579,31 @@ function Broadcast.copy(bc::Broadcasted{<:Union{CuSparseVecStyle,CuSparseMatStyl
             CuArray{Ti}(undef, rows+1)
         elseif sparse_typ <: CuSparseMatrixCSC
             CuArray{Ti}(undef, cols+1)
+        elseif sparse_typ <: CuSparseVector
+            CUDA.@allowscalar begin
+                arg_first_rows = ntuple(Val(length(bc.args))) do i
+                    bc.args[i] isa CuSparseVector && return bc.args[i].iPtr[1]
+                    return one(Ti)
+                end
+                arg_last_rows = ntuple(Val(length(bc.args))) do i
+                    bc.args[i] isa CuSparseVector && return bc.args[i].iPtr[end]
+                    return Ti(rows)
+                end
+            end
+            overall_first_row = min(arg_first_rows...)
+            overall_last_row  = max(arg_last_rows...)
+            offsets_len = overall_last_row - overall_first_row 
+            # last element is new NNZ
+            CUDA.zeros(Ti, offsets_len + 1)
         end
+        # TODO for sparse vectors, is it worth it to store a "workspace" array of where every row in
+        # output is in each sparse argument?
         let
-            args = (sparse_typ, offsets, bc.args...)
+            if sparse_typ <: CuSparseVector
+                args = (sparse_typ, overall_first_row, overall_last_row, offsets, bc.args...)
+            else
+                args = (sparse_typ, offsets, bc.args...)
+            end
             kernel = @cuda launch=false compute_offsets_kernel(args...)
             threads, blocks = compute_launch_config(kernel)
             kernel(args...; threads, blocks)
@@ -501,22 +612,36 @@ function Broadcast.copy(bc::Broadcasted{<:Union{CuSparseVecStyle,CuSparseMatStyl
         # accumulate these values so that we can use them directly as row pointer offsets,
         # as well as to get the total nnz count to allocate the sparse output array.
         # cusparseXcsrgeam2Nnz computes this in one go, but it doesn't seem worth the effort
-        accumulate!(Base.add_sum, offsets, offsets)
-        total_nnz = @allowscalar last(offsets[end]) - 1
-
+        if !(sparse_typ <: CuSparseVector)
+            accumulate!(Base.add_sum, offsets, offsets)
+            total_nnz = @allowscalar last(offsets[end]) - 1
+        else
+            total_nnz = @allowscalar last(offsets[end])
+        end
+        @assert total_nnz >= 0
         output = if sparse_typ <: CuSparseMatrixCSR
             colVal = CuArray{Ti}(undef, total_nnz)
-            nzVal = CuArray{Tv}(undef, total_nnz)
+            nzVal  = CuArray{Tv}(undef, total_nnz)
             CuSparseMatrixCSR(offsets, colVal, nzVal, size(bc))
         elseif sparse_typ <: CuSparseMatrixCSC
             rowVal = CuArray{Ti}(undef, total_nnz)
-            nzVal = CuArray{Tv}(undef, total_nnz)
+            nzVal  = CuArray{Tv}(undef, total_nnz)
             CuSparseMatrixCSC(offsets, rowVal, nzVal, size(bc))
+        elseif sparse_typ <: CuSparseVector
+            nzVal  = CuArray{Tv}(undef, total_nnz)
+            iPtr   = CuArray{Ti}(undef, total_nnz)
+            copyto!(iPtr, 1, offsets, 1, total_nnz)
+            CuSparseVector(iPtr, nzVal, rows)
         end
     end
 
     # perform the actual broadcast
-    if output isa AbstractCuSparseArray
+    if output isa CuSparseVector 
+        args = (bc.f, overall_first_row, overall_last_row, output, offsets, bc.args...)
+        kernel = @cuda launch=false sparse_to_sparse_broadcast_kernel(args...)
+        threads, blocks = compute_launch_config(kernel)
+        kernel(args...; threads, blocks)
+    elseif output isa AbstractCuSparseArray
         args = (bc.f, output, offsets, bc.args...)
         kernel = @cuda launch=false sparse_to_sparse_broadcast_kernel(args...)
         threads, blocks = compute_launch_config(kernel)
diff --git a/lib/cusparse/device.jl b/lib/cusparse/device.jl
index b81fb57a65..4e2925c3e1 100644
--- a/lib/cusparse/device.jl
+++ b/lib/cusparse/device.jl
@@ -19,7 +19,7 @@ struct CuSparseDeviceVector{Tv,Ti, A} <: AbstractSparseVector{Tv,Ti}
     nnz::Ti
 end
 
-Base.length(g::CuSparseDeviceVector) = prod(g.dims)
+Base.length(g::CuSparseDeviceVector) = g.len
 Base.size(g::CuSparseDeviceVector) = (g.len,)
 SparseArrays.nnz(g::CuSparseDeviceVector) = g.nnz
 
diff --git a/test/Project.toml b/test/Project.toml
index 5d6ea83e88..810a96e1cf 100644
--- a/test/Project.toml
+++ b/test/Project.toml
@@ -5,6 +5,7 @@ BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b"
 CUDA_Driver_jll = "4ee394cb-3365-5eb0-8335-949819d2adfc"
 CUDA_Runtime_jll = "76a88914-d11a-5bdc-97e0-2f5a05c973a2"
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
+Cthulhu = "f68482b8-f384-11e8-15f7-abe071a5a75f"
 DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
 Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
 Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
diff --git a/test/libraries/cusparse/broadcast.jl b/test/libraries/cusparse/broadcast.jl
index 37fef9ff2a..e73f40bc4b 100644
--- a/test/libraries/cusparse/broadcast.jl
+++ b/test/libraries/cusparse/broadcast.jl
@@ -1,10 +1,9 @@
 using CUDA.CUSPARSE, SparseArrays
 
-m,n = 5,6
-p = 0.5
-
 for elty in [Int32, Int64, Float32, Float64]
-    @testset "$typ($elty)" for typ in [CuSparseMatrixCSR, CuSparseMatrixCSC]
+   @testset "$typ($elty)" for typ in [CuSparseMatrixCSR, CuSparseMatrixCSC]
+        m,n = 5,6
+        p = 0.5
         x = sprand(elty, m, n, p)
         dx = typ(x)
 
@@ -34,6 +33,59 @@ for elty in [Int32, Int64, Float32, Float64]
         @test dz isa typ{elty}
         @test z == SparseMatrixCSC(dz)
     end
+    @testset "$typ($elty)" for typ in [CuSparseVector,]
+        m = 64 
+        p = 0.5
+        x = sprand(elty, m, p)
+        dx = typ(x)
+
+        # zero-preserving
+        y = x .* elty(1)
+        dy = dx .* elty(1)
+        @test dy isa typ{elty}
+        @test y == SparseVector(dy)
+
+        # not zero-preserving
+        y = x .+ elty(1)
+        dy = dx .+ elty(1)
+        @test dy isa CuArray{elty}
+        @test y == Array(dy)
+
+        # involving something dense - broken for now
+        #=y = x .+ ones(elty, m)
+        dy = dx .+ CUDA.ones(elty, m)
+        @test dy isa CuArray{elty}
+        @test y == Array(dy)=#
+        
+        # sparse to sparse 
+        y = sprand(elty, m, p)
+        dy = typ(y)
+        dx = typ(x)
+        z  = x .* y
+        dz = dx .* dy
+        @test dz isa typ{elty}
+        @test z == SparseVector(dz)
+
+        # multiple inputs
+        #=y = sprand(elty, m, p)
+        w = sprand(elty, m, p)
+        dy = typ(y)
+        dx = typ(x)
+        dw = typ(w)
+        z  = @. x * y * w
+        dz = @. dx * dy * w
+        @test dz isa typ{elty}
+        @test z == SparseVector(dz)=#
+
+        # broken due to llvm IR
+        #=y = sprand(elty, m, p)
+        dy = typ(y)
+        dx = typ(x)
+        z  = x .* y .* elty(2)
+        dz = dx .* dy .* elty(2)
+        @test dz isa typ{elty}
+        @test z == SparseVector(dz)=#
+    end
 end
 
 @testset "bug: type conversions" begin

From 5fb2163f8deeba4e8488c1e463a068041a3ab11e Mon Sep 17 00:00:00 2001
From: Katharine Hyatt <khyatt@flatironinstitute.org>
Date: Sat, 12 Apr 2025 15:01:01 -0400
Subject: [PATCH 02/11] Uncomment broken tests

---
 test/libraries/cusparse/broadcast.jl | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/test/libraries/cusparse/broadcast.jl b/test/libraries/cusparse/broadcast.jl
index e73f40bc4b..5e184d4404 100644
--- a/test/libraries/cusparse/broadcast.jl
+++ b/test/libraries/cusparse/broadcast.jl
@@ -52,10 +52,10 @@ for elty in [Int32, Int64, Float32, Float64]
         @test y == Array(dy)
 
         # involving something dense - broken for now
-        #=y = x .+ ones(elty, m)
+        y = x .+ ones(elty, m)
         dy = dx .+ CUDA.ones(elty, m)
         @test dy isa CuArray{elty}
-        @test y == Array(dy)=#
+        @test y == Array(dy)
         
         # sparse to sparse 
         y = sprand(elty, m, p)
@@ -78,13 +78,13 @@ for elty in [Int32, Int64, Float32, Float64]
         @test z == SparseVector(dz)=#
 
         # broken due to llvm IR
-        #=y = sprand(elty, m, p)
+        y = sprand(elty, m, p)
         dy = typ(y)
         dx = typ(x)
         z  = x .* y .* elty(2)
         dz = dx .* dy .* elty(2)
         @test dz isa typ{elty}
-        @test z == SparseVector(dz)=#
+        @test z == SparseVector(dz)
     end
 end
 

From d9d9f491d10c33b14024962b64ebe8474b83c913 Mon Sep 17 00:00:00 2001
From: Katharine Hyatt <khyatt@flatironinstitute.org>
Date: Mon, 14 Apr 2025 11:54:43 -0400
Subject: [PATCH 03/11] Use Vararg for compute_offsets_kernel

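With args::Vararg{Any, N}, the number of broadcast arguments becomes a
static type parameter, so the kernel can index the heterogeneous argument
tuple with "for i in 1:N" instead of iterating it. A toy host-side
illustration (hypothetical g, not part of this patch):

    # N is available in the body and can drive ntuple/Val
    g(args::Vararg{Any, N}) where {N} = ntuple(i -> typeof(args[i]), Val(N))

    g(1, 2.0, "x")   # (Int64, Float64, String) on 64-bit Julia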
---
 lib/cusparse/broadcast.jl | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/lib/cusparse/broadcast.jl b/lib/cusparse/broadcast.jl
index ad6b356ace..0b328393c8 100644
--- a/lib/cusparse/broadcast.jl
+++ b/lib/cusparse/broadcast.jl
@@ -304,7 +304,7 @@ iter_type(::Type{<:CuSparseMatrixCSR}, ::Type{Ti}) where {Ti} = CSRIterator{Ti}
 iter_type(::Type{<:CuSparseDeviceMatrixCSC}, ::Type{Ti}) where {Ti} = CSCIterator{Ti}
 iter_type(::Type{<:CuSparseDeviceMatrixCSR}, ::Type{Ti}) where {Ti} = CSRIterator{Ti}
 
-function compute_offsets_kernel(::Type{<:CuSparseVector}, first_row::Ti, last_row::Ti, offsets::AbstractVector{Ti}, args...) where Ti
+function compute_offsets_kernel(::Type{<:CuSparseVector}, first_row::Ti, last_row::Ti, offsets::AbstractVector{Ti}, args::Vararg{Any, N}) where {Ti, N}
     row_ix = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x
     row    = row_ix + first_row - 1i32
     
@@ -312,7 +312,8 @@ function compute_offsets_kernel(::Type{<:CuSparseVector}, first_row::Ti, last_ro
 
     # TODO load arg.iPtr slices into shared memory
     row_is_nnz = 0i32
-    for arg in args
+    for i in 1:N
+        arg = @inbounds args[i]
         if arg isa CuSparseDeviceVector
             for arg_row in arg.iPtr
                 if arg_row == row 

From dd799a6d15fa052d0c22c703376520b2782ce76a Mon Sep 17 00:00:00 2001
From: Katharine Hyatt <khyatt@flatironinstitute.org>
Date: Mon, 14 Apr 2025 13:28:25 -0400
Subject: [PATCH 04/11] Uncomment more tests and add a @. for matrix bcast

---
 test/libraries/cusparse/broadcast.jl | 12 ++++++++++--
 1 file changed, 10 insertions(+), 2 deletions(-)

diff --git a/test/libraries/cusparse/broadcast.jl b/test/libraries/cusparse/broadcast.jl
index 5e184d4404..0918adc5cc 100644
--- a/test/libraries/cusparse/broadcast.jl
+++ b/test/libraries/cusparse/broadcast.jl
@@ -32,6 +32,14 @@ for elty in [Int32, Int64, Float32, Float64]
         dz = dx .* dy .* elty(2)
         @test dz isa typ{elty}
         @test z == SparseMatrixCSC(dz)
+        
+        # multiple inputs
+        w = sprand(elty, m, n, p)
+        dw = typ(w)
+        z  = @. x * y * w
+        dz = @. dx * dy * w
+        @test dz isa typ{elty}
+        @test z == SparseMatrixCSC(dz)
     end
     @testset "$typ($elty)" for typ in [CuSparseVector,]
         m = 64 
@@ -67,7 +75,7 @@ for elty in [Int32, Int64, Float32, Float64]
         @test z == SparseVector(dz)
 
         # multiple inputs
-        #=y = sprand(elty, m, p)
+        y = sprand(elty, m, p)
         w = sprand(elty, m, p)
         dy = typ(y)
         dx = typ(x)
@@ -75,7 +83,7 @@ for elty in [Int32, Int64, Float32, Float64]
         z  = @. x * y * w
         dz = @. dx * dy * w
         @test dz isa typ{elty}
-        @test z == SparseVector(dz)=#
+        @test z == SparseVector(dz)
 
         # broken due to llvm IR
         y = sprand(elty, m, p)

From 19a50c28cce4910e9c9eac1561611d58b1a02539 Mon Sep 17 00:00:00 2001
From: Katharine Hyatt <kshyatt@users.noreply.github.com>
Date: Mon, 14 Apr 2025 15:01:02 -0400
Subject: [PATCH 05/11] Update lib/cusparse/broadcast.jl

Co-authored-by: Valentin Churavy <v.churavy@gmail.com>
---
 lib/cusparse/broadcast.jl | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/lib/cusparse/broadcast.jl b/lib/cusparse/broadcast.jl
index 0b328393c8..d919c448de 100644
--- a/lib/cusparse/broadcast.jl
+++ b/lib/cusparse/broadcast.jl
@@ -377,9 +377,9 @@ function _getindex(A::CuSparseDeviceVector{Tv}, row, ptr) where {Tv}
     return zero(Tv)
 end
 
-function sparse_to_sparse_broadcast_kernel(f, first_row::Ti, last_row::Ti, output::CuSparseDeviceVector{Tv,Ti},
+function sparse_to_sparse_broadcast_kernel(f::F, first_row::Ti, last_row::Ti, output::CuSparseDeviceVector{Tv,Ti},
                                            offsets::Union{AbstractVector,Nothing},
-                                           args::Vararg{Any, N}) where {Tv, Ti, N}
+                                           args::Vararg{<:Any, N}) where {Tv, Ti, N, F}
     row_ix = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x
     row_ix > output.nnz && return
     row    = @inbounds output.iPtr[row_ix + first_row - 1i32]

From 3af6ed66271725235218178b5e0d7c40d6ed5b7f Mon Sep 17 00:00:00 2001
From: Katharine Hyatt <khyatt@flatironinstitute.org>
Date: Mon, 14 Apr 2025 16:02:05 -0400
Subject: [PATCH 06/11] Fix test

---
 test/libraries/cusparse/broadcast.jl | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/test/libraries/cusparse/broadcast.jl b/test/libraries/cusparse/broadcast.jl
index 0918adc5cc..1709a2016d 100644
--- a/test/libraries/cusparse/broadcast.jl
+++ b/test/libraries/cusparse/broadcast.jl
@@ -36,8 +36,8 @@ for elty in [Int32, Int64, Float32, Float64]
         # multiple inputs
         w = sprand(elty, m, n, p)
         dw = typ(w)
-        z  = @. x * y * w
-        dz = @. dx * dy * w
+        z = x .* y .* w
+        dz = dx .* dy .* dw
         @test dz isa typ{elty}
         @test z == SparseMatrixCSC(dz)
     end
@@ -81,7 +81,7 @@ for elty in [Int32, Int64, Float32, Float64]
         dx = typ(x)
         dw = typ(w)
         z  = @. x * y * w
-        dz = @. dx * dy * w
+        dz = @. dx * dy * dw
         @test dz isa typ{elty}
         @test z == SparseVector(dz)
 

From 80165324a6618e5ff98026e7805f9c247a25b462 Mon Sep 17 00:00:00 2001
From: Tim Besard <tim.besard@gmail.com>
Date: Wed, 16 Apr 2025 12:06:36 +0200
Subject: [PATCH 07/11] Work around bad codegen.

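Two workarounds: compute_offsets_kernel becomes a generated function whose
per-argument loop is unrolled with Base.Cartesian.@nany, and the ntuple
closure in sparse_to_dense_broadcast_kernel gets an explicit @inline. A
host-side sketch of the unrolling macro (illustration only, not part of
this patch):

    using Base.Cartesian: @nany

    # expands at macro-expansion time to x[1] == 0 || x[2] == 0 || x[3] == 0
    anyzero3(x) = @nany 3 i->(x[i] == 0)

    anyzero3((1, 0, 2))   # true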
---
 lib/cusparse/broadcast.jl | 88 ++++++++++++++++++++++-----------------
 1 file changed, 49 insertions(+), 39 deletions(-)

diff --git a/lib/cusparse/broadcast.jl b/lib/cusparse/broadcast.jl
index d919c448de..31d221034e 100644
--- a/lib/cusparse/broadcast.jl
+++ b/lib/cusparse/broadcast.jl
@@ -1,5 +1,7 @@
 using Base.Broadcast: Broadcasted
 
+using Base.Cartesian: @nany
+
 using CUDA: CuArrayStyle
 
 # TODO: support more types (SparseVector, SparseMatrixCSC, COO, BSR)
@@ -304,40 +306,47 @@ iter_type(::Type{<:CuSparseMatrixCSR}, ::Type{Ti}) where {Ti} = CSRIterator{Ti}
 iter_type(::Type{<:CuSparseDeviceMatrixCSC}, ::Type{Ti}) where {Ti} = CSCIterator{Ti}
 iter_type(::Type{<:CuSparseDeviceMatrixCSR}, ::Type{Ti}) where {Ti} = CSRIterator{Ti}
 
-function compute_offsets_kernel(::Type{<:CuSparseVector}, first_row::Ti, last_row::Ti, offsets::AbstractVector{Ti}, args::Vararg{Any, N}) where {Ti, N}
-    row_ix = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x
-    row    = row_ix + first_row - 1i32
-    
-    row > last_row && return
+@eval @generated function compute_offsets_kernel(::Type{<:CuSparseVector}, first_row::Ti,
+                                                 last_row::Ti, offsets::AbstractVector{Ti},
+                                                 args...) where {Ti}
+    N = length(args)
+    quote
+        row_ix = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x
+        row    = row_ix + first_row - 1i32
 
-    # TODO load arg.iPtr slices into shared memory
-    row_is_nnz = 0i32
-    for i in 1:N
-        arg = @inbounds args[i]
-        if arg isa CuSparseDeviceVector
-            for arg_row in arg.iPtr
-                if arg_row == row 
-                    row_is_nnz = 1i32
-                    @inbounds offsets[row] = row
+        row > last_row && return
+
+        # TODO load arg.iPtr slices into shared memory
+        row_is_nnz = @nany $N i -> begin
+            is_nnz = false
+            arg = @inbounds args[i]
+            if arg isa CuSparseDeviceVector
+                for arg_row in arg.iPtr
+                    if arg_row == row
+                        is_nnz = true
+                        @inbounds offsets[row] = row
+                    end
+                    arg_row > row && break
                 end
-                arg_row > row && break
+            else
+                @inbounds offsets[row] = row
+                is_nnz = true
             end
-        else
-            @inbounds offsets[row] = row
-            row_is_nnz = 1i32
+            is_nnz
         end
         sync_warp()
+
+        # count nnz in the warp
+        offset = 16
+        while (offset >= 1)
+            row_is_nnz = shfl_down_sync(FULL_MASK, row_is_nnz, offset)
+            offset = offset >> 1
+        end
+        if threadIdx().x == 1
+            CUDA.@atomic offsets[last_row+1] += row_is_nnz
+        end
+        return
     end
-    # count nnz in the warp
-    offset = 16
-    while (offset >= 1)
-        row_is_nnz = shfl_down_sync(FULL_MASK, row_is_nnz, offset)
-        offset = offset >> 1
-    end
-    if threadIdx().x == 1
-        CUDA.@atomic offsets[last_row+1] += row_is_nnz
-    end
-    return
 end
 # TODO: unify CSC/CSR kernels
 
@@ -379,7 +388,7 @@ end
 
 function sparse_to_sparse_broadcast_kernel(f::F, first_row::Ti, last_row::Ti, output::CuSparseDeviceVector{Tv,Ti},
                                            offsets::Union{AbstractVector,Nothing},
-                                           args::Vararg{<:Any, N}) where {Tv, Ti, N, F}
+                                           args::Vararg{Any, N}) where {Tv, Ti, N, F}
     row_ix = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x
     row_ix > output.nnz && return
     row    = @inbounds output.iPtr[row_ix + first_row - 1i32]
@@ -405,7 +414,7 @@ function sparse_to_sparse_broadcast_kernel(f, output::T, offsets::Union{Abstract
     # fetch the row offset, and write it to the output
     @inbounds begin
         output_ptr = output_ptrs[leading_dim] = offsets[leading_dim]
-        if leading_dim == leading_dim_size 
+        if leading_dim == leading_dim_size
             output_ptrs[leading_dim+1i32] = offsets[leading_dim+1i32]
         end
     end
@@ -454,18 +463,19 @@ function sparse_to_dense_broadcast_kernel(T::Type{<:Union{CuSparseMatrixCSR{Tv,
 end
 
 function sparse_to_dense_broadcast_kernel(::Type{<:CuSparseVector}, f,
-                                          output::CuDeviceArray{Tv}, args::Vararg{Any, N}) where {Tv, N}
+                                         output::CuDeviceArray{Tv}, args...) where {Tv}
     # every thread processes an entire row
     row = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x
     row > length(output) && return
 
     # set the values for this row
-    vals = ntuple(Val(N)) do i
-        arg = @inbounds args[i] 
-        _getindex(arg, row, 0)::Tv
+    vals = ntuple(Val(length(args))) do i
+        @inline # XXX: works around bad codegen
+        arg = @inbounds args[i]
+        _getindex(arg, row, 0)
     end
-    out_val =  f(vals...)
-    @inbounds output[row] = out_val 
+    out_val = f(vals...)
+    @inbounds output[row] = out_val
     return
 end
 ## COV_EXCL_STOP
@@ -532,7 +542,7 @@ function Broadcast.copy(bc::Broadcasted{<:Union{CuSparseVecStyle,CuSparseMatStyl
     # for CuSparseVec, figure out the actual row range we need to address, e.g. if m = 2^20
     # but the only rows present in any sparse vector input are between 2 and 128, no need to
     # launch massive threads. TODO: use the difference here to set the thread count
-    overall_first_row = one(Ti) 
+    overall_first_row = one(Ti)
     overall_last_row = Ti(rows)
     # allocate the output container
     if !fpreszeros
@@ -593,7 +603,7 @@ function Broadcast.copy(bc::Broadcasted{<:Union{CuSparseVecStyle,CuSparseMatStyl
             end
             overall_first_row = min(arg_first_rows...)
             overall_last_row  = max(arg_last_rows...)
-            offsets_len = overall_last_row - overall_first_row 
+            offsets_len = overall_last_row - overall_first_row
             # last element is new NNZ
             CUDA.zeros(Ti, offsets_len + 1)
         end
@@ -637,7 +647,7 @@ function Broadcast.copy(bc::Broadcasted{<:Union{CuSparseVecStyle,CuSparseMatStyl
     end
 
     # perform the actual broadcast
-    if output isa CuSparseVector 
+    if output isa CuSparseVector
         args = (bc.f, overall_first_row, overall_last_row, output, offsets, bc.args...)
         kernel = @cuda launch=false sparse_to_sparse_broadcast_kernel(args...)
         threads, blocks = compute_launch_config(kernel)

From 9bc0f9819bb4ae0b8eb686a5da9cf16e1033e79e Mon Sep 17 00:00:00 2001
From: Katharine Hyatt <khyatt@flatironinstitute.org>
Date: Mon, 14 Apr 2025 16:02:05 -0400
Subject: [PATCH 08/11] Fix test

Get tests passing with a hand unroll

Use offsets as intermediate ptr storage

Don't use atomics
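
Instead of atomically accumulating a count, each thread now writes a
row => per-argument-pointers Pair into the offsets buffer; rows absent
from every sparse argument get the key typemax(Ti), so sorting by the key
pushes them to the end and the output nnz falls out of a reduction. A
host-side sketch of that bookkeeping (plain Vectors standing in for
device arrays):

    Ti = Int32
    offsets = [Ti(3)       => (Ti(1), Ti(0)),   # row 3: nonzero in arg 1 only
               typemax(Ti) => (Ti(0), Ti(0)),   # row absent from every sparse arg
               Ti(1)       => (Ti(0), Ti(2))]   # row 1: nonzero in arg 2 only

    sort!(offsets; by=first)   # typemax keys sink to the end
    total_nnz = mapreduce(x -> first(x) != typemax(first(x)), +, offsets)   # == 2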
---
 lib/cusparse/broadcast.jl            | 177 ++++++++++++---------------
 test/libraries/cusparse/broadcast.jl |  38 ++++--
 2 files changed, 107 insertions(+), 108 deletions(-)

diff --git a/lib/cusparse/broadcast.jl b/lib/cusparse/broadcast.jl
index 31d221034e..b7bdfb1fdc 100644
--- a/lib/cusparse/broadcast.jl
+++ b/lib/cusparse/broadcast.jl
@@ -289,15 +289,18 @@ end
 end
 
 # helpers to index a sparse or dense array
-function _getindex(arg::Union{CuSparseDeviceMatrixCSR,CuSparseDeviceMatrixCSC}, I, ptr)
+@inline function _getindex(arg::Union{CuSparseDeviceMatrixCSR{Tv},CuSparseDeviceMatrixCSC{Tv},CuSparseDeviceVector{Tv}}, I, ptr)::Tv where {Tv}
     if ptr == 0
-        zero(eltype(arg))
+        return zero(Tv)
     else
-        @inbounds arg.nzVal[ptr]
+        return @inbounds arg.nzVal[ptr]::Tv
     end
 end
-_getindex(arg::CuDeviceArray, I, ptr) = @inbounds arg[I]
-_getindex(arg, I, ptr) = Broadcast._broadcast_getindex(arg, I)
+
+@inline function _getindex(arg::CuDeviceArray{Tv}, I, ptr)::Tv where {Tv}
+    return @inbounds arg[I]::Tv
+end
+@inline _getindex(arg, I, ptr) = Broadcast._broadcast_getindex(arg, I)
 
 ## sparse broadcast implementation
 
@@ -306,49 +309,40 @@ iter_type(::Type{<:CuSparseMatrixCSR}, ::Type{Ti}) where {Ti} = CSRIterator{Ti}
 iter_type(::Type{<:CuSparseDeviceMatrixCSC}, ::Type{Ti}) where {Ti} = CSCIterator{Ti}
 iter_type(::Type{<:CuSparseDeviceMatrixCSR}, ::Type{Ti}) where {Ti} = CSRIterator{Ti}
 
-@eval @generated function compute_offsets_kernel(::Type{<:CuSparseVector}, first_row::Ti,
-                                                 last_row::Ti, offsets::AbstractVector{Ti},
-                                                 args...) where {Ti}
-    N = length(args)
-    quote
-        row_ix = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x
-        row    = row_ix + first_row - 1i32
+_has_row(A, offsets, row::Int32, fpreszeros::Bool) = fpreszeros ? 0i32 : row
+_has_row(A::CuDeviceArray, offsets, row::Int32, ::Bool) = row
+function _has_row(A::CuSparseDeviceVector, offsets, row::Int32, ::Bool)::Int32
+    for row_ix in 1i32:length(A.iPtr)
+        arg_row = @inbounds A.iPtr[row_ix]
+        arg_row == row && return row_ix
+        arg_row > row && break
+    end
+    return 0i32
+end
 
-        row > last_row && return
+function _get_my_row(first_row)::Int32
+    row_ix = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x
+    return row_ix + first_row - 1i32
+end
 
-        # TODO load arg.iPtr slices into shared memory
-        row_is_nnz = @nany $N i -> begin
-            is_nnz = false
-            arg = @inbounds args[i]
-            if arg isa CuSparseDeviceVector
-                for arg_row in arg.iPtr
-                    if arg_row == row
-                        is_nnz = true
-                        @inbounds offsets[row] = row
-                    end
-                    arg_row > row && break
-                end
-            else
-                @inbounds offsets[row] = row
-                is_nnz = true
-            end
-            is_nnz
-        end
-        sync_warp()
+function compute_offsets_kernel(::Type{<:CuSparseVector}, first_row::Ti, last_row::Ti, fpreszeros::Bool, offsets::AbstractVector{Pair{Ti, NTuple{N, Ti}}}, args...) where {Ti, N}
+    row = _get_my_row(first_row)
+    row > last_row && return
 
-        # count nnz in the warp
-        offset = 16
-        while (offset >= 1)
-            row_is_nnz = shfl_down_sync(FULL_MASK, row_is_nnz, offset)
-            offset = offset >> 1
-        end
-        if threadIdx().x == 1
-            CUDA.@atomic offsets[last_row+1] += row_is_nnz
-        end
-        return
+    # TODO load arg.iPtr slices into shared memory
+    row_is_nnz = 0i32
+    arg_row_is_nnz = ntuple(Val(N)) do i
+        arg = @inbounds args[i]
+        _has_row(arg, offsets, row, fpreszeros)::Int32
+    end
+    row_is_nnz = 0i32
+    for i in 1:N
+        row_is_nnz |= @inbounds arg_row_is_nnz[i]
     end
+    key = (row_is_nnz == 0i32) ? typemax(Ti) : row
+    @inbounds offsets[row - first_row + 1i32] = key => arg_row_is_nnz
+    return
 end
-# TODO: unify CSC/CSR kernels
 
 # kernel to count the number of non-zeros in a row, to determine the row offsets
 function compute_offsets_kernel(T::Type{<:Union{CuSparseMatrixCSR, CuSparseMatrixCSC}}, offsets::AbstractVector{Ti},
@@ -376,27 +370,19 @@ function compute_offsets_kernel(T::Type{<:Union{CuSparseMatrixCSR, CuSparseMatri
     return
 end
 
-# broadcast kernels that iterate the elements of sparse arrays
-function _getindex(A::CuSparseDeviceVector{Tv}, row, ptr) where {Tv}
-    for r_ix in 1:A.nnz
-        arg_row = @inbounds A.iPtr[r_ix]
-        arg_row == row && return @inbounds A.nzVal[r_ix]
-        arg_row > row && return zero(Tv)
-    end
-    return zero(Tv)
-end
-
-function sparse_to_sparse_broadcast_kernel(f::F, first_row::Ti, last_row::Ti, output::CuSparseDeviceVector{Tv,Ti},
-                                           offsets::Union{AbstractVector,Nothing},
-                                           args::Vararg{Any, N}) where {Tv, Ti, N, F}
+function sparse_to_sparse_broadcast_kernel(f::F, output::CuSparseDeviceVector{Tv,Ti}, offsets::AbstractVector{Pair{Ti, NTuple{N, Ti}}}, args...) where {Tv, Ti, N, F}
     row_ix = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x
     row_ix > output.nnz && return
-    row    = @inbounds output.iPtr[row_ix + first_row - 1i32]
-    vals   = ntuple(Val(N)) do i
+    row_and_ptrs = @inbounds offsets[row_ix]
+    row          = @inbounds row_and_ptrs[1]
+    args_are_nnz = @inbounds row_and_ptrs[2]
+    vals = ntuple(Val(N)) do i
         arg = @inbounds args[i]
-        _getindex(arg, row, 0)::Tv
+        arg_is_nnz = @inbounds args_are_nnz[i]
+        _getindex(arg, row, arg_is_nnz)::Tv
     end
     output_val = f(vals...)
+    @inbounds output.iPtr[row_ix]  = row 
     @inbounds output.nzVal[row_ix] = output_val
     return
 end
@@ -462,20 +448,21 @@ function sparse_to_dense_broadcast_kernel(T::Type{<:Union{CuSparseMatrixCSR{Tv,
     return
 end
 
-function sparse_to_dense_broadcast_kernel(::Type{<:CuSparseVector}, f,
-                                         output::CuDeviceArray{Tv}, args...) where {Tv}
+function sparse_to_dense_broadcast_kernel(::Type{<:CuSparseVector}, f::F,
+                                          output::CuDeviceArray{Tv}, offsets::AbstractVector{Pair{Ti, NTuple{N, Ti}}}, args...) where {Tv, F, N, Ti}
     # every thread processes an entire row
-    row = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x
-    row > length(output) && return
-
-    # set the values for this row
+    row_ix = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x
+    row_ix > length(output) && return
+    row_and_ptrs = @inbounds offsets[row_ix]
+    row          = @inbounds row_and_ptrs[1]
+    args_are_nnz = @inbounds row_and_ptrs[2]
     vals = ntuple(Val(length(args))) do i
-        @inline # XXX: works around bad codegen
         arg = @inbounds args[i]
-        _getindex(arg, row, 0)
+        arg_is_nnz = @inbounds args_are_nnz[i]
+        _getindex(arg, row, arg_is_nnz)::Tv
     end
     out_val = f(vals...)
-    @inbounds output[row] = out_val
+    @inbounds output[row] = out_val 
     return
 end
 ## COV_EXCL_STOP
@@ -544,8 +531,9 @@ function Broadcast.copy(bc::Broadcasted{<:Union{CuSparseVecStyle,CuSparseMatStyl
     # launch massive threads. TODO: use the difference here to set the thread count
     overall_first_row = one(Ti)
     overall_last_row = Ti(rows)
+    offsets = nothing
     # allocate the output container
-    if !fpreszeros
+    if !fpreszeros && sparse_typ <: Union{CuSparseMatrixCSR, CuSparseMatrixCSC}
         # either we have dense inputs, or the function isn't preserving zeros,
         # so use a dense output to broadcast into.
         output = CuArray{Tv}(undef, size(bc))
@@ -562,7 +550,7 @@ function Broadcast.copy(bc::Broadcasted{<:Union{CuSparseVecStyle,CuSparseMatStyl
             end
         end
         broadcast!(bc.f, output, nonsparse_args...)
-    elseif length(sparse_args) == 1
+    elseif length(sparse_args) == 1 && sparse_typ <: Union{CuSparseMatrixCSR, CuSparseMatrixCSC}
         # we only have a single sparse input, so we can reuse its structure for the output.
         # this avoids a kernel launch and costly synchronization.
         sparse_arg = bc.args[first(sparse_args)]
@@ -576,10 +564,6 @@ function Broadcast.copy(bc::Broadcasted{<:Union{CuSparseVecStyle,CuSparseMatStyl
             rowVal  = similar(sparse_arg.rowVal)
             nzVal   = similar(sparse_arg.nzVal, Tv)
             output  = CuSparseMatrixCSC(colPtr, rowVal, nzVal, size(bc))
-        elseif sparse_typ <: CuSparseVector
-            offsets = iPtr = sparse_arg.iPtr
-            nzVal   = similar(sparse_arg.nzVal, Tv)
-            output  = CuSparseVector(iPtr, nzVal, length(bc))
         end
         # NOTE: we don't use CUSPARSE's similar, because that copies the structure arrays,
         #       while we do that in our kernel (for consistency with other code paths)
@@ -603,15 +587,11 @@ function Broadcast.copy(bc::Broadcasted{<:Union{CuSparseVecStyle,CuSparseMatStyl
             end
             overall_first_row = min(arg_first_rows...)
             overall_last_row  = max(arg_last_rows...)
-            offsets_len = overall_last_row - overall_first_row
-            # last element is new NNZ
-            CUDA.zeros(Ti, offsets_len + 1)
+            CuVector{Pair{Ti, NTuple{length(bc.args), Ti}}}(undef, overall_last_row - overall_first_row + 1)
         end
-        # TODO for sparse vectors, is it worth it to store a "workspace" array of where every row in
-        # output is in each sparse argument?
         let
             if sparse_typ <: CuSparseVector
-                args = (sparse_typ, overall_first_row, overall_last_row, offsets, bc.args...)
+                args = (sparse_typ, overall_first_row, overall_last_row, fpreszeros, offsets, bc.args...)
             else
                 args = (sparse_typ, offsets, bc.args...)
             end
@@ -619,7 +599,6 @@ function Broadcast.copy(bc::Broadcasted{<:Union{CuSparseVecStyle,CuSparseMatStyl
             threads, blocks = compute_launch_config(kernel)
             kernel(args...; threads, blocks)
         end
-
         # accumulate these values so that we can use them directly as row pointer offsets,
         # as well as to get the total nnz count to allocate the sparse output array.
         # cusparseXcsrgeam2Nnz computes this in one go, but it doesn't seem worth the effort
@@ -627,9 +606,9 @@ function Broadcast.copy(bc::Broadcasted{<:Union{CuSparseVecStyle,CuSparseMatStyl
             accumulate!(Base.add_sum, offsets, offsets)
             total_nnz = @allowscalar last(offsets[end]) - 1
         else
-            total_nnz = @allowscalar last(offsets[end])
+            sort!(offsets; by=first)
+            total_nnz = mapreduce(x->first(x) != typemax(first(x)), +, offsets)
         end
-        @assert total_nnz >= 0
         output = if sparse_typ <: CuSparseMatrixCSR
             colVal = CuArray{Ti}(undef, total_nnz)
             nzVal  = CuArray{Tv}(undef, total_nnz)
@@ -638,27 +617,33 @@ function Broadcast.copy(bc::Broadcasted{<:Union{CuSparseVecStyle,CuSparseMatStyl
             rowVal = CuArray{Ti}(undef, total_nnz)
             nzVal  = CuArray{Tv}(undef, total_nnz)
             CuSparseMatrixCSC(offsets, rowVal, nzVal, size(bc))
-        elseif sparse_typ <: CuSparseVector
-            nzVal  = CuArray{Tv}(undef, total_nnz)
-            iPtr   = CuArray{Ti}(undef, total_nnz)
-            copyto!(iPtr, 1, offsets, 1, total_nnz)
+        elseif sparse_typ <: CuSparseVector && !fpreszeros
+            CuArray{Tv}(undef, size(bc))
+        elseif sparse_typ <: CuSparseVector && fpreszeros
+            iPtr   = CUDA.zeros(Ti, total_nnz)
+            nzVal  = CUDA.zeros(Tv, total_nnz)
             CuSparseVector(iPtr, nzVal, rows)
         end
+        if sparse_typ <: CuSparseVector && !fpreszeros
+            nonsparse_args = map(bc.args) do arg
+                # NOTE: this assumes the broadcast is flattened, but not yet preprocessed
+                if arg isa AbstractCuSparseArray
+                    zero(eltype(arg))
+                else
+                    arg
+                end
+            end
+            broadcast!(bc.f, output, nonsparse_args...)
+        end
     end
-
     # perform the actual broadcast
-    if output isa CuSparseVector
-        args = (bc.f, overall_first_row, overall_last_row, output, offsets, bc.args...)
-        kernel = @cuda launch=false sparse_to_sparse_broadcast_kernel(args...)
-        threads, blocks = compute_launch_config(kernel)
-        kernel(args...; threads, blocks)
-    elseif output isa AbstractCuSparseArray
-        args = (bc.f, output, offsets, bc.args...)
+    if output isa AbstractCuSparseArray
+        args   = (bc.f, output, offsets, bc.args...)
         kernel = @cuda launch=false sparse_to_sparse_broadcast_kernel(args...)
         threads, blocks = compute_launch_config(kernel)
         kernel(args...; threads, blocks)
     else
-        args = (sparse_typ, bc.f, output, bc.args...)
+        args   = sparse_typ <: CuSparseVector ? (sparse_typ, bc.f, output, offsets, bc.args...) : (sparse_typ, bc.f, output, bc.args...)
         kernel = @cuda launch=false sparse_to_dense_broadcast_kernel(args...)
         threads, blocks = compute_launch_config(kernel)
         kernel(args...; threads, blocks)
diff --git a/test/libraries/cusparse/broadcast.jl b/test/libraries/cusparse/broadcast.jl
index 1709a2016d..b1a89437b5 100644
--- a/test/libraries/cusparse/broadcast.jl
+++ b/test/libraries/cusparse/broadcast.jl
@@ -1,6 +1,6 @@
 using CUDA.CUSPARSE, SparseArrays
 
-for elty in [Int32, Int64, Float32, Float64]
+@testset for elty in [Int32, Int64, Float32, Float64]
    @testset "$typ($elty)" for typ in [CuSparseMatrixCSR, CuSparseMatrixCSC]
         m,n = 5,6
         p = 0.5
@@ -46,29 +46,33 @@ for elty in [Int32, Int64, Float32, Float64]
         p = 0.5
         x = sprand(elty, m, p)
         dx = typ(x)
-
+        
         # zero-preserving
-        y = x .* elty(1)
+        y  = x .* elty(1)
         dy = dx .* elty(1)
         @test dy isa typ{elty}
-        @test y == SparseVector(dy)
-
+        @test collect(dy.iPtr) == collect(dx.iPtr) 
+        @test collect(dy.iPtr) == y.nzind
+        @test collect(dy.nzVal) == y.nzval
+        @test y  == SparseVector(dy)
+        
         # not zero-preserving
         y = x .+ elty(1)
         dy = dx .+ elty(1)
         @test dy isa CuArray{elty}
-        @test y == Array(dy)
+        hy = Array(dy)
+        @test Array(y) == hy 
 
-        # involving something dense - broken for now
+        # involving something dense
         y = x .+ ones(elty, m)
         dy = dx .+ CUDA.ones(elty, m)
         @test dy isa CuArray{elty}
         @test y == Array(dy)
-        
+         
         # sparse to sparse 
-        y = sprand(elty, m, p)
-        dy = typ(y)
         dx = typ(x)
+        y  = sprand(elty, m, p)
+        dy = typ(y)
         z  = x .* y
         dz = dx .* dy
         @test dz isa typ{elty}
@@ -84,8 +88,18 @@ for elty in [Int32, Int64, Float32, Float64]
         dz = @. dx * dy * dw
         @test dz isa typ{elty}
         @test z == SparseVector(dz)
-
-        # broken due to llvm IR
+        
+        y = sprand(elty, m, p)
+        w = sprand(elty, m, p)
+        dense_arr   = rand(elty, m)
+        d_dense_arr = CuArray(dense_arr) 
+        dy = typ(y)
+        dw = typ(w)
+        z  = @. x * y * w * dense_arr 
+        dz = @. dx * dy * dw * d_dense_arr 
+        @test dz isa CuArray{elty}
+        @test z == Array(dz)
+        
         y = sprand(elty, m, p)
         dy = typ(y)
         dx = typ(x)

From 2de968bf03878e67c1edd96680abc91124c3ba6e Mon Sep 17 00:00:00 2001
From: Katharine Hyatt <khyatt@flatironinstitute.org>
Date: Fri, 18 Apr 2025 11:31:22 -0400
Subject: [PATCH 09/11] Cleanup

---
 lib/cusparse/broadcast.jl | 20 ++++++++++++--------
 test/Project.toml         |  1 -
 2 files changed, 12 insertions(+), 9 deletions(-)

diff --git a/lib/cusparse/broadcast.jl b/lib/cusparse/broadcast.jl
index b7bdfb1fdc..98cfd14e29 100644
--- a/lib/cusparse/broadcast.jl
+++ b/lib/cusparse/broadcast.jl
@@ -375,11 +375,13 @@ function sparse_to_sparse_broadcast_kernel(f::F, output::CuSparseDeviceVector{Tv
     row_ix > output.nnz && return
     row_and_ptrs = @inbounds offsets[row_ix]
     row          = @inbounds row_and_ptrs[1]
-    args_are_nnz = @inbounds row_and_ptrs[2]
+    arg_ptrs     = @inbounds row_and_ptrs[2]
     vals = ntuple(Val(N)) do i
         arg = @inbounds args[i]
-        arg_is_nnz = @inbounds args_are_nnz[i]
-        _getindex(arg, row, arg_is_nnz)::Tv
+        # ptr is 0 if the sparse vector doesn't have an element at this row
+        # ptr is 0 if the arg is a scalar AND f preserves zeros
+        ptr = @inbounds arg_ptrs[i]
+        _getindex(arg, row, ptr)::Tv
     end
     output_val = f(vals...)
     @inbounds output.iPtr[row_ix]  = row 
@@ -455,14 +457,16 @@ function sparse_to_dense_broadcast_kernel(::Type{<:CuSparseVector}, f::F,
     row_ix > length(output) && return
     row_and_ptrs = @inbounds offsets[row_ix]
     row          = @inbounds row_and_ptrs[1]
-    args_are_nnz = @inbounds row_and_ptrs[2]
+    arg_ptrs     = @inbounds row_and_ptrs[2]
     vals = ntuple(Val(length(args))) do i
         arg = @inbounds args[i]
-        arg_is_nnz = @inbounds args_are_nnz[i]
-        _getindex(arg, row, arg_is_nnz)::Tv
+        # ptr is 0 if the sparse vector doesn't have an element at this row
+        # ptr is row if the arg is dense OR a scalar with non-zero-preserving f
+        # ptr is 0 if the arg is a scalar AND f preserves zeros
+        ptr = @inbounds arg_ptrs[i]
+        _getindex(arg, row, ptr)::Tv
     end
-    out_val = f(vals...)
-    @inbounds output[row] = out_val 
+    @inbounds output[row] = f(vals...)
     return
 end
 ## COV_EXCL_STOP
diff --git a/test/Project.toml b/test/Project.toml
index 810a96e1cf..5d6ea83e88 100644
--- a/test/Project.toml
+++ b/test/Project.toml
@@ -5,7 +5,6 @@ BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b"
 CUDA_Driver_jll = "4ee394cb-3365-5eb0-8335-949819d2adfc"
 CUDA_Runtime_jll = "76a88914-d11a-5bdc-97e0-2f5a05c973a2"
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
-Cthulhu = "f68482b8-f384-11e8-15f7-abe071a5a75f"
 DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
 Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
 Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"

From e2f3d4b1d2a3a2960fa28bb1b9ea3b446501211b Mon Sep 17 00:00:00 2001
From: Katharine Hyatt <khyatt@flatironinstitute.org>
Date: Mon, 21 Apr 2025 14:18:37 -0400
Subject: [PATCH 10/11] Remove unneeded using

---
 lib/cusparse/broadcast.jl | 2 --
 1 file changed, 2 deletions(-)

diff --git a/lib/cusparse/broadcast.jl b/lib/cusparse/broadcast.jl
index 98cfd14e29..c52068c5a6 100644
--- a/lib/cusparse/broadcast.jl
+++ b/lib/cusparse/broadcast.jl
@@ -1,7 +1,5 @@
 using Base.Broadcast: Broadcasted
 
-using Base.Cartesian: @nany
-
 using CUDA: CuArrayStyle
 
 # TODO: support more types (SparseVector, SparseMatrixCSC, COO, BSR)

From 43dc26ffad87c573df95f5af64889c4087a0d531 Mon Sep 17 00:00:00 2001
From: Tim Besard <tim.besard@gmail.com>
Date: Thu, 24 Apr 2025 09:17:31 +0200
Subject: [PATCH 11/11] Clean-ups.

---
 lib/cusparse/broadcast.jl | 56 ++++++++++++++++++++++++---------------
 1 file changed, 34 insertions(+), 22 deletions(-)

diff --git a/lib/cusparse/broadcast.jl b/lib/cusparse/broadcast.jl
index c52068c5a6..965c0ce478 100644
--- a/lib/cusparse/broadcast.jl
+++ b/lib/cusparse/broadcast.jl
@@ -84,7 +84,8 @@ end
         end
     end
 end
-@inline function _capturescalars(arg) # this definition is just an optimization (to bottom out the recursion slightly sooner)
+@inline function _capturescalars(arg)
+    # this definition is just an optimization (to bottom out the recursion slightly sooner)
     if scalararg(arg)
         return (), () -> (arg,) # add scalararg
     elseif scalarwrappedarg(arg)
@@ -287,7 +288,9 @@ end
 end
 
 # helpers to index a sparse or dense array
-@inline function _getindex(arg::Union{CuSparseDeviceMatrixCSR{Tv},CuSparseDeviceMatrixCSC{Tv},CuSparseDeviceVector{Tv}}, I, ptr)::Tv where {Tv}
+@inline function _getindex(arg::Union{CuSparseDeviceMatrixCSR{Tv},
+                                      CuSparseDeviceMatrixCSC{Tv},
+                                      CuSparseDeviceVector{Tv}}, I, ptr)::Tv where {Tv}
     if ptr == 0
         return zero(Tv)
     else
@@ -323,7 +326,9 @@ function _get_my_row(first_row)::Int32
     return row_ix + first_row - 1i32
 end
 
-function compute_offsets_kernel(::Type{<:CuSparseVector}, first_row::Ti, last_row::Ti, fpreszeros::Bool, offsets::AbstractVector{Pair{Ti, NTuple{N, Ti}}}, args...) where {Ti, N}
+function compute_offsets_kernel(::Type{<:CuSparseVector}, first_row::Ti, last_row::Ti,
+                                fpreszeros::Bool, offsets::AbstractVector{Pair{Ti, NTuple{N, Ti}}},
+                                args...) where {Ti, N}
     row = _get_my_row(first_row)
     row > last_row && return
 
@@ -343,7 +348,8 @@ function compute_offsets_kernel(::Type{<:CuSparseVector}, first_row::Ti, last_ro
 end
 
 # kernel to count the number of non-zeros in a row, to determine the row offsets
-function compute_offsets_kernel(T::Type{<:Union{CuSparseMatrixCSR, CuSparseMatrixCSC}}, offsets::AbstractVector{Ti},
+function compute_offsets_kernel(T::Type{<:Union{CuSparseMatrixCSR, CuSparseMatrixCSC}},
+                                offsets::AbstractVector{Ti},
                                 args...) where Ti
     # every thread processes an entire row
     leading_dim = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x
@@ -368,7 +374,9 @@ function compute_offsets_kernel(T::Type{<:Union{CuSparseMatrixCSR, CuSparseMatri
     return
 end
 
-function sparse_to_sparse_broadcast_kernel(f::F, output::CuSparseDeviceVector{Tv,Ti}, offsets::AbstractVector{Pair{Ti, NTuple{N, Ti}}}, args...) where {Tv, Ti, N, F}
+function sparse_to_sparse_broadcast_kernel(f::F, output::CuSparseDeviceVector{Tv,Ti},
+                                           offsets::AbstractVector{Pair{Ti, NTuple{N, Ti}}},
+                                           args...) where {Tv, Ti, N, F}
     row_ix = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x
     row_ix > output.nnz && return
     row_and_ptrs = @inbounds offsets[row_ix]
@@ -382,12 +390,14 @@ function sparse_to_sparse_broadcast_kernel(f::F, output::CuSparseDeviceVector{Tv
         _getindex(arg, row, ptr)::Tv
     end
     output_val = f(vals...)
-    @inbounds output.iPtr[row_ix]  = row 
+    @inbounds output.iPtr[row_ix]  = row
     @inbounds output.nzVal[row_ix] = output_val
     return
 end
 
-function sparse_to_sparse_broadcast_kernel(f, output::T, offsets::Union{AbstractVector,Nothing}, args...) where {Ti, T<:Union{CuSparseDeviceMatrixCSR{<:Any,Ti},CuSparseDeviceMatrixCSC{<:Any,Ti}}}
+function sparse_to_sparse_broadcast_kernel(f, output::T, offsets::Union{AbstractVector,Nothing},
+                                           args...) where {Ti, T<:Union{CuSparseDeviceMatrixCSR{<:Any,Ti},
+                                                                        CuSparseDeviceMatrixCSC{<:Any,Ti}}}
     # every thread processes an entire row
     leading_dim = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x
     leading_dim_size = output isa CuSparseDeviceMatrixCSR ? size(output, 1) : size(output, 2)
@@ -423,7 +433,8 @@ function sparse_to_sparse_broadcast_kernel(f, output::T, offsets::Union{Abstract
 
     return
 end
-function sparse_to_dense_broadcast_kernel(T::Type{<:Union{CuSparseMatrixCSR{Tv, Ti}, CuSparseMatrixCSC{Tv, Ti}}}, f,
+function sparse_to_dense_broadcast_kernel(T::Type{<:Union{CuSparseMatrixCSR{Tv, Ti},
+                                                          CuSparseMatrixCSC{Tv, Ti}}}, f,
                                           output::CuDeviceArray, args...) where {Tv, Ti}
     # every thread processes an entire row
     leading_dim = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x
@@ -449,7 +460,9 @@ function sparse_to_dense_broadcast_kernel(T::Type{<:Union{CuSparseMatrixCSR{Tv,
 end
 
 function sparse_to_dense_broadcast_kernel(::Type{<:CuSparseVector}, f::F,
-                                          output::CuDeviceArray{Tv}, offsets::AbstractVector{Pair{Ti, NTuple{N, Ti}}}, args...) where {Tv, F, N, Ti}
+                                          output::CuDeviceArray{Tv},
+                                          offsets::AbstractVector{Pair{Ti, NTuple{N, Ti}}},
+                                          args...) where {Tv, F, N, Ti}
     # every thread processes an entire row
     row_ix = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x
     row_ix > length(output) && return
@@ -468,7 +481,7 @@ function sparse_to_dense_broadcast_kernel(::Type{<:CuSparseVector}, f::F,
     return
 end
 ## COV_EXCL_STOP
-const N_VEC_THREADS = 512
+
 function Broadcast.copy(bc::Broadcasted{<:Union{CuSparseVecStyle,CuSparseMatStyle}})
     # find the sparse inputs
     bc = Broadcast.flatten(bc)
@@ -510,7 +523,7 @@ function Broadcast.copy(bc::Broadcasted{<:Union{CuSparseVecStyle,CuSparseMatStyl
 
     # the kernels below parallelize across rows or cols, not elements, so it's unlikely
     # we'll launch many threads. to maximize utilization, parallelize across blocks first.
-    rows, cols = sparse_typ <: CuSparseVector ? (length(bc), 1) : size(bc)
+    rows, cols = get(size(bc), 1, 1), get(size(bc), 2, 1)   # `size(bc, ::Int)` is missing
     function compute_launch_config(kernel)
         config = launch_configuration(kernel.fun)
         if sparse_typ <: CuSparseMatrixCSR
@@ -522,15 +535,15 @@ function Broadcast.copy(bc::Broadcasted{<:Union{CuSparseVecStyle,CuSparseMatStyl
             blocks  = max(cld(cols, threads), config.blocks)
             threads = cld(cols, blocks)
         elseif sparse_typ <: CuSparseVector
-            threads = N_VEC_THREADS
+            threads = 512
             blocks  = max(cld(rows, threads), config.blocks)
-            threads = N_VEC_THREADS
         end
         (; threads, blocks)
     end
     # for CuSparseVec, figure out the actual row range we need to address, e.g. if m = 2^20
     # but the only rows present in any sparse vector input are between 2 and 128, no need to
-    # launch massive threads. TODO: use the difference here to set the thread count
+    # launch massive threads.
+    # TODO: use the difference here to set the thread count
     overall_first_row = one(Ti)
     overall_last_row = Ti(rows)
     offsets = nothing
@@ -592,10 +605,10 @@ function Broadcast.copy(bc::Broadcasted{<:Union{CuSparseVecStyle,CuSparseMatStyl
             CuVector{Pair{Ti, NTuple{length(bc.args), Ti}}}(undef, overall_last_row - overall_first_row + 1)
         end
         let
-            if sparse_typ <: CuSparseVector
-                args = (sparse_typ, overall_first_row, overall_last_row, fpreszeros, offsets, bc.args...)
+            args = if sparse_typ <: CuSparseVector
+                (sparse_typ, overall_first_row, overall_last_row, fpreszeros, offsets, bc.args...)
             else
-                args = (sparse_typ, offsets, bc.args...)
+                (sparse_typ, offsets, bc.args...)
             end
             kernel = @cuda launch=false compute_offsets_kernel(args...)
             threads, blocks = compute_launch_config(kernel)
@@ -642,14 +655,13 @@ function Broadcast.copy(bc::Broadcasted{<:Union{CuSparseVecStyle,CuSparseMatStyl
     if output isa AbstractCuSparseArray
         args   = (bc.f, output, offsets, bc.args...)
         kernel = @cuda launch=false sparse_to_sparse_broadcast_kernel(args...)
-        threads, blocks = compute_launch_config(kernel)
-        kernel(args...; threads, blocks)
     else
-        args   = sparse_typ <: CuSparseVector ? (sparse_typ, bc.f, output, offsets, bc.args...) : (sparse_typ, bc.f, output, bc.args...)
+        args   = sparse_typ <: CuSparseVector ? (sparse_typ, bc.f, output, offsets, bc.args...) :
+                                                (sparse_typ, bc.f, output, bc.args...)
         kernel = @cuda launch=false sparse_to_dense_broadcast_kernel(args...)
-        threads, blocks = compute_launch_config(kernel)
-        kernel(args...; threads, blocks)
     end
+    threads, blocks = compute_launch_config(kernel)
+    kernel(args...; threads, blocks)
 
     return output
 end