From 7e1821dfa4bcb10b8f3183b7dccf00c01e7bccf1 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt <khyatt@flatironinstitute.org> Date: Thu, 3 Apr 2025 15:15:24 -0400 Subject: [PATCH 01/11] Very rough implementation of bcast for CuSparseVector --- lib/cusparse/broadcast.jl | 169 +++++++++++++++++++++++---- lib/cusparse/device.jl | 2 +- test/Project.toml | 1 + test/libraries/cusparse/broadcast.jl | 60 +++++++++- 4 files changed, 205 insertions(+), 27 deletions(-) diff --git a/lib/cusparse/broadcast.jl b/lib/cusparse/broadcast.jl index 273b028c2f..ad6b356ace 100644 --- a/lib/cusparse/broadcast.jl +++ b/lib/cusparse/broadcast.jl @@ -103,7 +103,6 @@ end ## COV_EXCL_START ## iteration helpers - """ CSRIterator{Ti}(row, args...) @@ -295,9 +294,9 @@ function _getindex(arg::Union{CuSparseDeviceMatrixCSR,CuSparseDeviceMatrixCSC}, @inbounds arg.nzVal[ptr] end end +_getindex(arg::CuDeviceArray, I, ptr) = @inbounds arg[I] _getindex(arg, I, ptr) = Broadcast._broadcast_getindex(arg, I) - ## sparse broadcast implementation iter_type(::Type{<:CuSparseMatrixCSC}, ::Type{Ti}) where {Ti} = CSCIterator{Ti} @@ -305,6 +304,42 @@ iter_type(::Type{<:CuSparseMatrixCSR}, ::Type{Ti}) where {Ti} = CSRIterator{Ti} iter_type(::Type{<:CuSparseDeviceMatrixCSC}, ::Type{Ti}) where {Ti} = CSCIterator{Ti} iter_type(::Type{<:CuSparseDeviceMatrixCSR}, ::Type{Ti}) where {Ti} = CSRIterator{Ti} +function compute_offsets_kernel(::Type{<:CuSparseVector}, first_row::Ti, last_row::Ti, offsets::AbstractVector{Ti}, args...) 
where Ti + row_ix = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x + row = row_ix + first_row - 1i32 + + row > last_row && return + + # TODO load arg.iPtr slices into shared memory + row_is_nnz = 0i32 + for arg in args + if arg isa CuSparseDeviceVector + for arg_row in arg.iPtr + if arg_row == row + row_is_nnz = 1i32 + @inbounds offsets[row] = row + end + arg_row > row && break + end + else + @inbounds offsets[row] = row + row_is_nnz = 1i32 + end + sync_warp() + end + # count nnz in the warp + offset = 16 + while (offset >= 1) + row_is_nnz = shfl_down_sync(FULL_MASK, row_is_nnz, offset) + offset = offset >> 1 + end + if threadIdx().x == 1 + CUDA.@atomic offsets[last_row+1] += row_is_nnz + end + return +end +# TODO: unify CSC/CSR kernels + # kernel to count the number of non-zeros in a row, to determine the row offsets function compute_offsets_kernel(T::Type{<:Union{CuSparseMatrixCSR, CuSparseMatrixCSC}}, offsets::AbstractVector{Ti}, args...) where Ti @@ -332,6 +367,30 @@ function compute_offsets_kernel(T::Type{<:Union{CuSparseMatrixCSR, CuSparseMatri end # broadcast kernels that iterate the elements of sparse arrays +function _getindex(A::CuSparseDeviceVector{Tv}, row, ptr) where {Tv} + for r_ix in 1:A.nnz + arg_row = @inbounds A.iPtr[r_ix] + arg_row == row && return @inbounds A.nzVal[r_ix] + arg_row > row && return zero(Tv) + end + return zero(Tv) +end + +function sparse_to_sparse_broadcast_kernel(f, first_row::Ti, last_row::Ti, output::CuSparseDeviceVector{Tv,Ti}, + offsets::Union{AbstractVector,Nothing}, + args::Vararg{Any, N}) where {Tv, Ti, N} + row_ix = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x + row_ix > output.nnz && return + row = @inbounds output.iPtr[row_ix + first_row - 1i32] + vals = ntuple(Val(N)) do i + arg = @inbounds args[i] + _getindex(arg, row, 0)::Tv + end + output_val = f(vals...) 
+ @inbounds output.nzVal[row_ix] = output_val + return +end + function sparse_to_sparse_broadcast_kernel(f, output::T, offsets::Union{AbstractVector,Nothing}, args...) where {Ti, T<:Union{CuSparseDeviceMatrixCSR{<:Any,Ti},CuSparseDeviceMatrixCSC{<:Any,Ti}}} # every thread processes an entire row leading_dim = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x @@ -392,8 +451,24 @@ function sparse_to_dense_broadcast_kernel(T::Type{<:Union{CuSparseMatrixCSR{Tv, return end -## COV_EXCL_STOP +function sparse_to_dense_broadcast_kernel(::Type{<:CuSparseVector}, f, + output::CuDeviceArray{Tv}, args::Vararg{Any, N}) where {Tv, N} + # every thread processes an entire row + row = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x + row > length(output) && return + + # set the values for this row + vals = ntuple(Val(N)) do i + arg = @inbounds args[i] + _getindex(arg, row, 0)::Tv + end + out_val = f(vals...) + @inbounds output[row] = out_val + return +end +## COV_EXCL_STOP +const N_VEC_THREADS = 512 function Broadcast.copy(bc::Broadcasted{<:Union{CuSparseVecStyle,CuSparseMatStyle}}) # find the sparse inputs bc = Broadcast.flatten(bc) @@ -405,12 +480,14 @@ function Broadcast.copy(bc::Broadcasted{<:Union{CuSparseVecStyle,CuSparseMatStyl error("broadcast with multiple types of sparse arrays ($(join(sparse_types, ", "))) is not supported") end sparse_typ = typeof(bc.args[first(sparse_args)]) - sparse_typ <: Union{CuSparseMatrixCSR,CuSparseMatrixCSC} || - error("broadcast with sparse arrays is currently only implemented for CSR and CSC matrices") + sparse_typ <: Union{CuSparseMatrixCSR,CuSparseMatrixCSC,CuSparseVector} || + error("broadcast with sparse arrays is currently only implemented for vectors and CSR and CSC matrices") Ti = if sparse_typ <: CuSparseMatrixCSR reduce(promote_type, map(i->eltype(bc.args[i].rowPtr), sparse_args)) elseif sparse_typ <: CuSparseMatrixCSC reduce(promote_type, map(i->eltype(bc.args[i].colPtr), sparse_args)) + elseif sparse_typ <: CuSparseVector + 
reduce(promote_type, map(i->eltype(bc.args[i].iPtr), sparse_args)) end # determine the output type @@ -433,21 +510,29 @@ function Broadcast.copy(bc::Broadcasted{<:Union{CuSparseVecStyle,CuSparseMatStyl # the kernels below parallelize across rows or cols, not elements, so it's unlikely # we'll launch many threads. to maximize utilization, parallelize across blocks first. - rows, cols = size(bc) + rows, cols = sparse_typ <: CuSparseVector ? (length(bc), 1) : size(bc) function compute_launch_config(kernel) config = launch_configuration(kernel.fun) if sparse_typ <: CuSparseMatrixCSR threads = min(rows, config.threads) - blocks = max(cld(rows, threads), config.blocks) + blocks = max(cld(rows, threads), config.blocks) threads = cld(rows, blocks) elseif sparse_typ <: CuSparseMatrixCSC threads = min(cols, config.threads) - blocks = max(cld(cols, threads), config.blocks) + blocks = max(cld(cols, threads), config.blocks) threads = cld(cols, blocks) + elseif sparse_typ <: CuSparseVector + threads = N_VEC_THREADS + blocks = max(cld(rows, threads), config.blocks) + threads = N_VEC_THREADS end (; threads, blocks) end - + # for CuSparseVec, figure out the actual row range we need to address, e.g. if m = 2^20 + # but the only rows present in any sparse vector input are between 2 and 128, no need to + # launch massive threads. 
TODO: use the difference here to set the thread count + overall_first_row = one(Ti) + overall_last_row = Ti(rows) # allocate the output container if !fpreszeros # either we have dense inputs, or the function isn't preserving zeros, @@ -472,14 +557,18 @@ function Broadcast.copy(bc::Broadcasted{<:Union{CuSparseVecStyle,CuSparseMatStyl sparse_arg = bc.args[first(sparse_args)] if sparse_typ <: CuSparseMatrixCSR offsets = rowPtr = sparse_arg.rowPtr - colVal = similar(sparse_arg.colVal) - nzVal = similar(sparse_arg.nzVal, Tv) - output = CuSparseMatrixCSR(rowPtr, colVal, nzVal, size(bc)) + colVal = similar(sparse_arg.colVal) + nzVal = similar(sparse_arg.nzVal, Tv) + output = CuSparseMatrixCSR(rowPtr, colVal, nzVal, size(bc)) elseif sparse_typ <: CuSparseMatrixCSC offsets = colPtr = sparse_arg.colPtr - rowVal = similar(sparse_arg.rowVal) - nzVal = similar(sparse_arg.nzVal, Tv) - output = CuSparseMatrixCSC(colPtr, rowVal, nzVal, size(bc)) + rowVal = similar(sparse_arg.rowVal) + nzVal = similar(sparse_arg.nzVal, Tv) + output = CuSparseMatrixCSC(colPtr, rowVal, nzVal, size(bc)) + elseif sparse_typ <: CuSparseVector + offsets = iPtr = sparse_arg.iPtr + nzVal = similar(sparse_arg.nzVal, Tv) + output = CuSparseVector(iPtr, nzVal, length(bc)) end # NOTE: we don't use CUSPARSE's similar, because that copies the structure arrays, # while we do that in our kernel (for consistency with other code paths) @@ -490,9 +579,31 @@ function Broadcast.copy(bc::Broadcasted{<:Union{CuSparseVecStyle,CuSparseMatStyl CuArray{Ti}(undef, rows+1) elseif sparse_typ <: CuSparseMatrixCSC CuArray{Ti}(undef, cols+1) + elseif sparse_typ <: CuSparseVector + CUDA.@allowscalar begin + arg_first_rows = ntuple(Val(length(bc.args))) do i + bc.args[i] isa CuSparseVector && return bc.args[i].iPtr[1] + return one(Ti) + end + arg_last_rows = ntuple(Val(length(bc.args))) do i + bc.args[i] isa CuSparseVector && return bc.args[i].iPtr[end] + return Ti(rows) + end + end + overall_first_row = min(arg_first_rows...) 
+ overall_last_row = max(arg_last_rows...) + offsets_len = overall_last_row - overall_first_row + # last element is new NNZ + CUDA.zeros(Ti, offsets_len + 1) end + # TODO for sparse vectors, is it worth it to store a "workspace" array of where every row in + # output is in each sparse argument? let - args = (sparse_typ, offsets, bc.args...) + if sparse_typ <: CuSparseVector + args = (sparse_typ, overall_first_row, overall_last_row, offsets, bc.args...) + else + args = (sparse_typ, offsets, bc.args...) + end kernel = @cuda launch=false compute_offsets_kernel(args...) threads, blocks = compute_launch_config(kernel) kernel(args...; threads, blocks) @@ -501,22 +612,36 @@ function Broadcast.copy(bc::Broadcasted{<:Union{CuSparseVecStyle,CuSparseMatStyl # accumulate these values so that we can use them directly as row pointer offsets, # as well as to get the total nnz count to allocate the sparse output array. # cusparseXcsrgeam2Nnz computes this in one go, but it doesn't seem worth the effort - accumulate!(Base.add_sum, offsets, offsets) - total_nnz = @allowscalar last(offsets[end]) - 1 - + if !(sparse_typ <: CuSparseVector) + accumulate!(Base.add_sum, offsets, offsets) + total_nnz = @allowscalar last(offsets[end]) - 1 + else + total_nnz = @allowscalar last(offsets[end]) + end + @assert total_nnz >= 0 output = if sparse_typ <: CuSparseMatrixCSR colVal = CuArray{Ti}(undef, total_nnz) - nzVal = CuArray{Tv}(undef, total_nnz) + nzVal = CuArray{Tv}(undef, total_nnz) CuSparseMatrixCSR(offsets, colVal, nzVal, size(bc)) elseif sparse_typ <: CuSparseMatrixCSC rowVal = CuArray{Ti}(undef, total_nnz) - nzVal = CuArray{Tv}(undef, total_nnz) + nzVal = CuArray{Tv}(undef, total_nnz) CuSparseMatrixCSC(offsets, rowVal, nzVal, size(bc)) + elseif sparse_typ <: CuSparseVector + nzVal = CuArray{Tv}(undef, total_nnz) + iPtr = CuArray{Ti}(undef, total_nnz) + copyto!(iPtr, 1, offsets, 1, total_nnz) + CuSparseVector(iPtr, nzVal, rows) end end # perform the actual broadcast - if output isa 
AbstractCuSparseArray + if output isa CuSparseVector + args = (bc.f, overall_first_row, overall_last_row, output, offsets, bc.args...) + kernel = @cuda launch=false sparse_to_sparse_broadcast_kernel(args...) + threads, blocks = compute_launch_config(kernel) + kernel(args...; threads, blocks) + elseif output isa AbstractCuSparseArray args = (bc.f, output, offsets, bc.args...) kernel = @cuda launch=false sparse_to_sparse_broadcast_kernel(args...) threads, blocks = compute_launch_config(kernel) diff --git a/lib/cusparse/device.jl b/lib/cusparse/device.jl index b81fb57a65..4e2925c3e1 100644 --- a/lib/cusparse/device.jl +++ b/lib/cusparse/device.jl @@ -19,7 +19,7 @@ struct CuSparseDeviceVector{Tv,Ti, A} <: AbstractSparseVector{Tv,Ti} nnz::Ti end -Base.length(g::CuSparseDeviceVector) = prod(g.dims) +Base.length(g::CuSparseDeviceVector) = g.len Base.size(g::CuSparseDeviceVector) = (g.len,) SparseArrays.nnz(g::CuSparseDeviceVector) = g.nnz diff --git a/test/Project.toml b/test/Project.toml index 5d6ea83e88..810a96e1cf 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -5,6 +5,7 @@ BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" CUDA_Driver_jll = "4ee394cb-3365-5eb0-8335-949819d2adfc" CUDA_Runtime_jll = "76a88914-d11a-5bdc-97e0-2f5a05c973a2" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +Cthulhu = "f68482b8-f384-11e8-15f7-abe071a5a75f" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" diff --git a/test/libraries/cusparse/broadcast.jl b/test/libraries/cusparse/broadcast.jl index 37fef9ff2a..e73f40bc4b 100644 --- a/test/libraries/cusparse/broadcast.jl +++ b/test/libraries/cusparse/broadcast.jl @@ -1,10 +1,9 @@ using CUDA.CUSPARSE, SparseArrays -m,n = 5,6 -p = 0.5 - for elty in [Int32, Int64, Float32, Float64] - @testset "$typ($elty)" for typ in [CuSparseMatrixCSR, CuSparseMatrixCSC] + @testset "$typ($elty)" for typ in [CuSparseMatrixCSR, 
CuSparseMatrixCSC] + m,n = 5,6 + p = 0.5 x = sprand(elty, m, n, p) dx = typ(x) @@ -34,6 +33,59 @@ for elty in [Int32, Int64, Float32, Float64] @test dz isa typ{elty} @test z == SparseMatrixCSC(dz) end + @testset "$typ($elty)" for typ in [CuSparseVector,] + m = 64 + p = 0.5 + x = sprand(elty, m, p) + dx = typ(x) + + # zero-preserving + y = x .* elty(1) + dy = dx .* elty(1) + @test dy isa typ{elty} + @test y == SparseVector(dy) + + # not zero-preserving + y = x .+ elty(1) + dy = dx .+ elty(1) + @test dy isa CuArray{elty} + @test y == Array(dy) + + # involving something dense - broken for now + #=y = x .+ ones(elty, m) + dy = dx .+ CUDA.ones(elty, m) + @test dy isa CuArray{elty} + @test y == Array(dy)=# + + # sparse to sparse + y = sprand(elty, m, p) + dy = typ(y) + dx = typ(x) + z = x .* y + dz = dx .* dy + @test dz isa typ{elty} + @test z == SparseVector(dz) + + # multiple inputs + #=y = sprand(elty, m, p) + w = sprand(elty, m, p) + dy = typ(y) + dx = typ(x) + dw = typ(w) + z = @. x * y * w + dz = @. 
dx * dy * w + @test dz isa typ{elty} + @test z == SparseVector(dz)=# + + # broken due to llvm IR + #=y = sprand(elty, m, p) + dy = typ(y) + dx = typ(x) + z = x .* y .* elty(2) + dz = dx .* dy .* elty(2) + @test dz isa typ{elty} + @test z == SparseVector(dz)=# + end end @testset "bug: type conversions" begin From 5fb2163f8deeba4e8488c1e463a068041a3ab11e Mon Sep 17 00:00:00 2001 From: Katharine Hyatt <khyatt@flatironinstitute.org> Date: Sat, 12 Apr 2025 15:01:01 -0400 Subject: [PATCH 02/11] Uncomment broken tests --- test/libraries/cusparse/broadcast.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/libraries/cusparse/broadcast.jl b/test/libraries/cusparse/broadcast.jl index e73f40bc4b..5e184d4404 100644 --- a/test/libraries/cusparse/broadcast.jl +++ b/test/libraries/cusparse/broadcast.jl @@ -52,10 +52,10 @@ for elty in [Int32, Int64, Float32, Float64] @test y == Array(dy) # involving something dense - broken for now - #=y = x .+ ones(elty, m) + y = x .+ ones(elty, m) dy = dx .+ CUDA.ones(elty, m) @test dy isa CuArray{elty} - @test y == Array(dy)=# + @test y == Array(dy) # sparse to sparse y = sprand(elty, m, p) @@ -78,13 +78,13 @@ for elty in [Int32, Int64, Float32, Float64] @test z == SparseVector(dz)=# # broken due to llvm IR - #=y = sprand(elty, m, p) + y = sprand(elty, m, p) dy = typ(y) dx = typ(x) z = x .* y .* elty(2) dz = dx .* dy .* elty(2) @test dz isa typ{elty} - @test z == SparseVector(dz)=# + @test z == SparseVector(dz) end end From d9d9f491d10c33b14024962b64ebe8474b83c913 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt <khyatt@flatironinstitute.org> Date: Mon, 14 Apr 2025 11:54:43 -0400 Subject: [PATCH 03/11] Use Vararg for compute_offsets_kernel --- lib/cusparse/broadcast.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/lib/cusparse/broadcast.jl b/lib/cusparse/broadcast.jl index ad6b356ace..0b328393c8 100644 --- a/lib/cusparse/broadcast.jl +++ b/lib/cusparse/broadcast.jl @@ -304,7 +304,7 @@ 
iter_type(::Type{<:CuSparseMatrixCSR}, ::Type{Ti}) where {Ti} = CSRIterator{Ti} iter_type(::Type{<:CuSparseDeviceMatrixCSC}, ::Type{Ti}) where {Ti} = CSCIterator{Ti} iter_type(::Type{<:CuSparseDeviceMatrixCSR}, ::Type{Ti}) where {Ti} = CSRIterator{Ti} -function compute_offsets_kernel(::Type{<:CuSparseVector}, first_row::Ti, last_row::Ti, offsets::AbstractVector{Ti}, args...) where Ti +function compute_offsets_kernel(::Type{<:CuSparseVector}, first_row::Ti, last_row::Ti, offsets::AbstractVector{Ti}, args::Vararg{Any, N}) where {Ti, N} row_ix = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x row = row_ix + first_row - 1i32 @@ -312,7 +312,8 @@ function compute_offsets_kernel(::Type{<:CuSparseVector}, first_row::Ti, last_ro # TODO load arg.iPtr slices into shared memory row_is_nnz = 0i32 - for arg in args + for i in 1:N + arg = @inbounds args[i] if arg isa CuSparseDeviceVector for arg_row in arg.iPtr if arg_row == row From dd799a6d15fa052d0c22c703376520b2782ce76a Mon Sep 17 00:00:00 2001 From: Katharine Hyatt <khyatt@flatironinstitute.org> Date: Mon, 14 Apr 2025 13:28:25 -0400 Subject: [PATCH 04/11] Uncomment more tests and add a @. for matrix bcast --- test/libraries/cusparse/broadcast.jl | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/test/libraries/cusparse/broadcast.jl b/test/libraries/cusparse/broadcast.jl index 5e184d4404..0918adc5cc 100644 --- a/test/libraries/cusparse/broadcast.jl +++ b/test/libraries/cusparse/broadcast.jl @@ -32,6 +32,14 @@ for elty in [Int32, Int64, Float32, Float64] dz = dx .* dy .* elty(2) @test dz isa typ{elty} @test z == SparseMatrixCSC(dz) + + # multiple inputs + w = sprand(elty, m, n, p) + dw = typ(w) + z = @. x * y * w + dz = @. 
dx * dy * w + @test dz isa typ{elty} + @test z == SparseMatrixCSC(dz) end @testset "$typ($elty)" for typ in [CuSparseVector,] m = 64 @@ -67,7 +75,7 @@ for elty in [Int32, Int64, Float32, Float64] @test z == SparseVector(dz) # multiple inputs - #=y = sprand(elty, m, p) + y = sprand(elty, m, p) w = sprand(elty, m, p) dy = typ(y) dx = typ(x) @@ -75,7 +83,7 @@ for elty in [Int32, Int64, Float32, Float64] z = @. x * y * w dz = @. dx * dy * w @test dz isa typ{elty} - @test z == SparseVector(dz)=# + @test z == SparseVector(dz) # broken due to llvm IR y = sprand(elty, m, p) From 19a50c28cce4910e9c9eac1561611d58b1a02539 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt <kshyatt@users.noreply.github.com> Date: Mon, 14 Apr 2025 15:01:02 -0400 Subject: [PATCH 05/11] Update lib/cusparse/broadcast.jl Co-authored-by: Valentin Churavy <v.churavy@gmail.com> --- lib/cusparse/broadcast.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/cusparse/broadcast.jl b/lib/cusparse/broadcast.jl index 0b328393c8..d919c448de 100644 --- a/lib/cusparse/broadcast.jl +++ b/lib/cusparse/broadcast.jl @@ -377,9 +377,9 @@ function _getindex(A::CuSparseDeviceVector{Tv}, row, ptr) where {Tv} return zero(Tv) end -function sparse_to_sparse_broadcast_kernel(f, first_row::Ti, last_row::Ti, output::CuSparseDeviceVector{Tv,Ti}, +function sparse_to_sparse_broadcast_kernel(f::F, first_row::Ti, last_row::Ti, output::CuSparseDeviceVector{Tv,Ti}, offsets::Union{AbstractVector,Nothing}, - args::Vararg{Any, N}) where {Tv, Ti, N} + args::Vararg{<:Any, N}) where {Tv, Ti, N, F} row_ix = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x row_ix > output.nnz && return row = @inbounds output.iPtr[row_ix + first_row - 1i32] From 3af6ed66271725235218178b5e0d7c40d6ed5b7f Mon Sep 17 00:00:00 2001 From: Katharine Hyatt <khyatt@flatironinstitute.org> Date: Mon, 14 Apr 2025 16:02:05 -0400 Subject: [PATCH 06/11] Fix test --- test/libraries/cusparse/broadcast.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 
deletions(-) diff --git a/test/libraries/cusparse/broadcast.jl b/test/libraries/cusparse/broadcast.jl index 0918adc5cc..1709a2016d 100644 --- a/test/libraries/cusparse/broadcast.jl +++ b/test/libraries/cusparse/broadcast.jl @@ -36,8 +36,8 @@ for elty in [Int32, Int64, Float32, Float64] # multiple inputs w = sprand(elty, m, n, p) dw = typ(w) - z = @. x * y * w - dz = @. dx * dy * w + z = x .* y .* w + dz = dx .* dy .* dw @test dz isa typ{elty} @test z == SparseMatrixCSC(dz) end @@ -81,7 +81,7 @@ for elty in [Int32, Int64, Float32, Float64] dx = typ(x) dw = typ(w) z = @. x * y * w - dz = @. dx * dy * w + dz = @. dx * dy * dw @test dz isa typ{elty} @test z == SparseVector(dz) From 80165324a6618e5ff98026e7805f9c247a25b462 Mon Sep 17 00:00:00 2001 From: Tim Besard <tim.besard@gmail.com> Date: Wed, 16 Apr 2025 12:06:36 +0200 Subject: [PATCH 07/11] Work around bad codegen. --- lib/cusparse/broadcast.jl | 88 ++++++++++++++++++++++----------------- 1 file changed, 49 insertions(+), 39 deletions(-) diff --git a/lib/cusparse/broadcast.jl b/lib/cusparse/broadcast.jl index d919c448de..31d221034e 100644 --- a/lib/cusparse/broadcast.jl +++ b/lib/cusparse/broadcast.jl @@ -1,5 +1,7 @@ using Base.Broadcast: Broadcasted +using Base.Cartesian: @nany + using CUDA: CuArrayStyle # TODO: support more types (SparseVector, SparseMatrixCSC, COO, BSR) @@ -304,40 +306,47 @@ iter_type(::Type{<:CuSparseMatrixCSR}, ::Type{Ti}) where {Ti} = CSRIterator{Ti} iter_type(::Type{<:CuSparseDeviceMatrixCSC}, ::Type{Ti}) where {Ti} = CSCIterator{Ti} iter_type(::Type{<:CuSparseDeviceMatrixCSR}, ::Type{Ti}) where {Ti} = CSRIterator{Ti} -function compute_offsets_kernel(::Type{<:CuSparseVector}, first_row::Ti, last_row::Ti, offsets::AbstractVector{Ti}, args::Vararg{Any, N}) where {Ti, N} - row_ix = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x - row = row_ix + first_row - 1i32 - - row > last_row && return +@eval @generated function compute_offsets_kernel(::Type{<:CuSparseVector}, first_row::Ti, + 
last_row::Ti, offsets::AbstractVector{Ti}, + args...) where {Ti} + N = length(args) + quote + row_ix = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x + row = row_ix + first_row - 1i32 - # TODO load arg.iPtr slices into shared memory - row_is_nnz = 0i32 - for i in 1:N - arg = @inbounds args[i] - if arg isa CuSparseDeviceVector - for arg_row in arg.iPtr - if arg_row == row - row_is_nnz = 1i32 - @inbounds offsets[row] = row + row > last_row && return + + # TODO load arg.iPtr slices into shared memory + row_is_nnz = @nany $N i -> begin + is_nnz = false + arg = @inbounds args[i] + if arg isa CuSparseDeviceVector + for arg_row in arg.iPtr + if arg_row == row + is_nnz = true + @inbounds offsets[row] = row + end + arg_row > row && break end - arg_row > row && break + else + @inbounds offsets[row] = row + is_nnz = true end - else - @inbounds offsets[row] = row - row_is_nnz = 1i32 + is_nnz end sync_warp() + + # count nnz in the warp + offset = 16 + while (offset >= 1) + row_is_nnz = shfl_down_sync(FULL_MASK, row_is_nnz, offset) + offset = offset >> 1 + end + if threadIdx().x == 1 + CUDA.@atomic offsets[last_row+1] += row_is_nnz + end + return end - # count nnz in the warp - offset = 16 - while (offset >= 1) - row_is_nnz = shfl_down_sync(FULL_MASK, row_is_nnz, offset) - offset = offset >> 1 - end - if threadIdx().x == 1 - CUDA.@atomic offsets[last_row+1] += row_is_nnz - end - return end # TODO: unify CSC/CSR kernels @@ -379,7 +388,7 @@ end function sparse_to_sparse_broadcast_kernel(f::F, first_row::Ti, last_row::Ti, output::CuSparseDeviceVector{Tv,Ti}, offsets::Union{AbstractVector,Nothing}, - args::Vararg{<:Any, N}) where {Tv, Ti, N, F} + args::Vararg{Any, N}) where {Tv, Ti, N, F} row_ix = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x row_ix > output.nnz && return row = @inbounds output.iPtr[row_ix + first_row - 1i32] @@ -405,7 +414,7 @@ function sparse_to_sparse_broadcast_kernel(f, output::T, offsets::Union{Abstract # fetch the row offset, and write it to the 
output @inbounds begin output_ptr = output_ptrs[leading_dim] = offsets[leading_dim] - if leading_dim == leading_dim_size + if leading_dim == leading_dim_size output_ptrs[leading_dim+1i32] = offsets[leading_dim+1i32] end end @@ -454,18 +463,19 @@ function sparse_to_dense_broadcast_kernel(T::Type{<:Union{CuSparseMatrixCSR{Tv, end function sparse_to_dense_broadcast_kernel(::Type{<:CuSparseVector}, f, - output::CuDeviceArray{Tv}, args::Vararg{Any, N}) where {Tv, N} + output::CuDeviceArray{Tv}, args...) where {Tv} # every thread processes an entire row row = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x row > length(output) && return # set the values for this row - vals = ntuple(Val(N)) do i - arg = @inbounds args[i] - _getindex(arg, row, 0)::Tv + vals = ntuple(Val(length(args))) do i + @inline # XXX: works around bad codegen + arg = @inbounds args[i] + _getindex(arg, row, 0) end - out_val = f(vals...) - @inbounds output[row] = out_val + out_val = f(vals...) + @inbounds output[row] = out_val return end ## COV_EXCL_STOP @@ -532,7 +542,7 @@ function Broadcast.copy(bc::Broadcasted{<:Union{CuSparseVecStyle,CuSparseMatStyl # for CuSparseVec, figure out the actual row range we need to address, e.g. if m = 2^20 # but the only rows present in any sparse vector input are between 2 and 128, no need to # launch massive threads. TODO: use the difference here to set the thread count - overall_first_row = one(Ti) + overall_first_row = one(Ti) overall_last_row = Ti(rows) # allocate the output container if !fpreszeros @@ -593,7 +603,7 @@ function Broadcast.copy(bc::Broadcasted{<:Union{CuSparseVecStyle,CuSparseMatStyl end overall_first_row = min(arg_first_rows...) overall_last_row = max(arg_last_rows...) 
- offsets_len = overall_last_row - overall_first_row + offsets_len = overall_last_row - overall_first_row # last element is new NNZ CUDA.zeros(Ti, offsets_len + 1) end @@ -637,7 +647,7 @@ function Broadcast.copy(bc::Broadcasted{<:Union{CuSparseVecStyle,CuSparseMatStyl end # perform the actual broadcast - if output isa CuSparseVector + if output isa CuSparseVector args = (bc.f, overall_first_row, overall_last_row, output, offsets, bc.args...) kernel = @cuda launch=false sparse_to_sparse_broadcast_kernel(args...) threads, blocks = compute_launch_config(kernel) From 9bc0f9819bb4ae0b8eb686a5da9cf16e1033e79e Mon Sep 17 00:00:00 2001 From: Katharine Hyatt <khyatt@flatironinstitute.org> Date: Mon, 14 Apr 2025 16:02:05 -0400 Subject: [PATCH 08/11] Fix test Get tests passing with a hand unroll Use offsets as intermediate ptr storage Don't use atomic --- lib/cusparse/broadcast.jl | 177 ++++++++++++--------------- test/libraries/cusparse/broadcast.jl | 38 ++++-- 2 files changed, 107 insertions(+), 108 deletions(-) diff --git a/lib/cusparse/broadcast.jl b/lib/cusparse/broadcast.jl index 31d221034e..b7bdfb1fdc 100644 --- a/lib/cusparse/broadcast.jl +++ b/lib/cusparse/broadcast.jl @@ -289,15 +289,18 @@ end end # helpers to index a sparse or dense array -function _getindex(arg::Union{CuSparseDeviceMatrixCSR,CuSparseDeviceMatrixCSC}, I, ptr) +@inline function _getindex(arg::Union{CuSparseDeviceMatrixCSR{Tv},CuSparseDeviceMatrixCSC{Tv},CuSparseDeviceVector{Tv}}, I, ptr)::Tv where {Tv} if ptr == 0 - zero(eltype(arg)) + return zero(Tv) else - @inbounds arg.nzVal[ptr] + return @inbounds arg.nzVal[ptr]::Tv end end -_getindex(arg::CuDeviceArray, I, ptr) = @inbounds arg[I] -_getindex(arg, I, ptr) = Broadcast._broadcast_getindex(arg, I) + +@inline function _getindex(arg::CuDeviceArray{Tv}, I, ptr)::Tv where {Tv} + return @inbounds arg[I]::Tv +end +@inline _getindex(arg, I, ptr) = Broadcast._broadcast_getindex(arg, I) ## sparse broadcast implementation @@ -306,49 +309,40 @@ 
iter_type(::Type{<:CuSparseMatrixCSR}, ::Type{Ti}) where {Ti} = CSRIterator{Ti} iter_type(::Type{<:CuSparseDeviceMatrixCSC}, ::Type{Ti}) where {Ti} = CSCIterator{Ti} iter_type(::Type{<:CuSparseDeviceMatrixCSR}, ::Type{Ti}) where {Ti} = CSRIterator{Ti} -@eval @generated function compute_offsets_kernel(::Type{<:CuSparseVector}, first_row::Ti, - last_row::Ti, offsets::AbstractVector{Ti}, - args...) where {Ti} - N = length(args) - quote - row_ix = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x - row = row_ix + first_row - 1i32 +_has_row(A, offsets, row::Int32, fpreszeros::Bool) = fpreszeros ? 0i32 : row +_has_row(A::CuDeviceArray, offsets, row::Int32, ::Bool) = row +function _has_row(A::CuSparseDeviceVector, offsets, row::Int32, ::Bool)::Int32 + for row_ix in 1i32:length(A.iPtr) + arg_row = @inbounds A.iPtr[row_ix] + arg_row == row && return row_ix + arg_row > row && break + end + return 0i32 +end - row > last_row && return +function _get_my_row(first_row)::Int32 + row_ix = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x + return row_ix + first_row - 1i32 +end - # TODO load arg.iPtr slices into shared memory - row_is_nnz = @nany $N i -> begin - is_nnz = false - arg = @inbounds args[i] - if arg isa CuSparseDeviceVector - for arg_row in arg.iPtr - if arg_row == row - is_nnz = true - @inbounds offsets[row] = row - end - arg_row > row && break - end - else - @inbounds offsets[row] = row - is_nnz = true - end - is_nnz - end - sync_warp() +function compute_offsets_kernel(::Type{<:CuSparseVector}, first_row::Ti, last_row::Ti, fpreszeros::Bool, offsets::AbstractVector{Pair{Ti, NTuple{N, Ti}}}, args...) 
where {Ti, N} + row = _get_my_row(first_row) + row > last_row && return - # count nnz in the warp - offset = 16 - while (offset >= 1) - row_is_nnz = shfl_down_sync(FULL_MASK, row_is_nnz, offset) - offset = offset >> 1 - end - if threadIdx().x == 1 - CUDA.@atomic offsets[last_row+1] += row_is_nnz - end - return + # TODO load arg.iPtr slices into shared memory + row_is_nnz = 0i32 + arg_row_is_nnz = ntuple(Val(N)) do i + arg = @inbounds args[i] + _has_row(arg, offsets, row, fpreszeros)::Int32 + end + row_is_nnz = 0i32 + for i in 1:N + row_is_nnz |= @inbounds arg_row_is_nnz[i] end + key = (row_is_nnz == 0i32) ? typemax(Ti) : row + @inbounds offsets[row - first_row + 1i32] = key => arg_row_is_nnz + return end -# TODO: unify CSC/CSR kernels # kernel to count the number of non-zeros in a row, to determine the row offsets function compute_offsets_kernel(T::Type{<:Union{CuSparseMatrixCSR, CuSparseMatrixCSC}}, offsets::AbstractVector{Ti}, @@ -376,27 +370,19 @@ function compute_offsets_kernel(T::Type{<:Union{CuSparseMatrixCSR, CuSparseMatri return end -# broadcast kernels that iterate the elements of sparse arrays -function _getindex(A::CuSparseDeviceVector{Tv}, row, ptr) where {Tv} - for r_ix in 1:A.nnz - arg_row = @inbounds A.iPtr[r_ix] - arg_row == row && return @inbounds A.nzVal[r_ix] - arg_row > row && return zero(Tv) - end - return zero(Tv) -end - -function sparse_to_sparse_broadcast_kernel(f::F, first_row::Ti, last_row::Ti, output::CuSparseDeviceVector{Tv,Ti}, - offsets::Union{AbstractVector,Nothing}, - args::Vararg{Any, N}) where {Tv, Ti, N, F} +function sparse_to_sparse_broadcast_kernel(f::F, output::CuSparseDeviceVector{Tv,Ti}, offsets::AbstractVector{Pair{Ti, NTuple{N, Ti}}}, args...) 
where {Tv, Ti, N, F} row_ix = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x row_ix > output.nnz && return - row = @inbounds output.iPtr[row_ix + first_row - 1i32] - vals = ntuple(Val(N)) do i + row_and_ptrs = @inbounds offsets[row_ix] + row = @inbounds row_and_ptrs[1] + args_are_nnz = @inbounds row_and_ptrs[2] + vals = ntuple(Val(N)) do i arg = @inbounds args[i] - _getindex(arg, row, 0)::Tv + arg_is_nnz = @inbounds args_are_nnz[i] + _getindex(arg, row, arg_is_nnz)::Tv end output_val = f(vals...) + @inbounds output.iPtr[row_ix] = row @inbounds output.nzVal[row_ix] = output_val return end @@ -462,20 +448,21 @@ function sparse_to_dense_broadcast_kernel(T::Type{<:Union{CuSparseMatrixCSR{Tv, return end -function sparse_to_dense_broadcast_kernel(::Type{<:CuSparseVector}, f, - output::CuDeviceArray{Tv}, args...) where {Tv} +function sparse_to_dense_broadcast_kernel(::Type{<:CuSparseVector}, f::F, + output::CuDeviceArray{Tv}, offsets::AbstractVector{Pair{Ti, NTuple{N, Ti}}}, args...) where {Tv, F, N, Ti} # every thread processes an entire row - row = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x - row > length(output) && return - - # set the values for this row + row_ix = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x + row_ix > length(output) && return + row_and_ptrs = @inbounds offsets[row_ix] + row = @inbounds row_and_ptrs[1] + args_are_nnz = @inbounds row_and_ptrs[2] vals = ntuple(Val(length(args))) do i - @inline # XXX: works around bad codegen arg = @inbounds args[i] - _getindex(arg, row, 0) + arg_is_nnz = @inbounds args_are_nnz[i] + _getindex(arg, row, arg_is_nnz)::Tv end out_val = f(vals...) - @inbounds output[row] = out_val + @inbounds output[row] = out_val return end ## COV_EXCL_STOP @@ -544,8 +531,9 @@ function Broadcast.copy(bc::Broadcasted{<:Union{CuSparseVecStyle,CuSparseMatStyl # launch massive threads. 
TODO: use the difference here to set the thread count overall_first_row = one(Ti) overall_last_row = Ti(rows) + offsets = nothing # allocate the output container - if !fpreszeros + if !fpreszeros && sparse_typ <: Union{CuSparseMatrixCSR, CuSparseMatrixCSC} # either we have dense inputs, or the function isn't preserving zeros, # so use a dense output to broadcast into. output = CuArray{Tv}(undef, size(bc)) @@ -562,7 +550,7 @@ function Broadcast.copy(bc::Broadcasted{<:Union{CuSparseVecStyle,CuSparseMatStyl end end broadcast!(bc.f, output, nonsparse_args...) - elseif length(sparse_args) == 1 + elseif length(sparse_args) == 1 && sparse_typ <: Union{CuSparseMatrixCSR, CuSparseMatrixCSC} # we only have a single sparse input, so we can reuse its structure for the output. # this avoids a kernel launch and costly synchronization. sparse_arg = bc.args[first(sparse_args)] @@ -576,10 +564,6 @@ function Broadcast.copy(bc::Broadcasted{<:Union{CuSparseVecStyle,CuSparseMatStyl rowVal = similar(sparse_arg.rowVal) nzVal = similar(sparse_arg.nzVal, Tv) output = CuSparseMatrixCSC(colPtr, rowVal, nzVal, size(bc)) - elseif sparse_typ <: CuSparseVector - offsets = iPtr = sparse_arg.iPtr - nzVal = similar(sparse_arg.nzVal, Tv) - output = CuSparseVector(iPtr, nzVal, length(bc)) end # NOTE: we don't use CUSPARSE's similar, because that copies the structure arrays, # while we do that in our kernel (for consistency with other code paths) @@ -603,15 +587,11 @@ function Broadcast.copy(bc::Broadcasted{<:Union{CuSparseVecStyle,CuSparseMatStyl end overall_first_row = min(arg_first_rows...) overall_last_row = max(arg_last_rows...) - offsets_len = overall_last_row - overall_first_row - # last element is new NNZ - CUDA.zeros(Ti, offsets_len + 1) + CuVector{Pair{Ti, NTuple{length(bc.args), Ti}}}(undef, overall_last_row - overall_first_row + 1) end - # TODO for sparse vectors, is it worth it to store a "workspace" array of where every row in - # output is in each sparse argument? 
let if sparse_typ <: CuSparseVector - args = (sparse_typ, overall_first_row, overall_last_row, offsets, bc.args...) + args = (sparse_typ, overall_first_row, overall_last_row, fpreszeros, offsets, bc.args...) else args = (sparse_typ, offsets, bc.args...) end @@ -619,7 +599,6 @@ function Broadcast.copy(bc::Broadcasted{<:Union{CuSparseVecStyle,CuSparseMatStyl threads, blocks = compute_launch_config(kernel) kernel(args...; threads, blocks) end - # accumulate these values so that we can use them directly as row pointer offsets, # as well as to get the total nnz count to allocate the sparse output array. # cusparseXcsrgeam2Nnz computes this in one go, but it doesn't seem worth the effort @@ -627,9 +606,9 @@ function Broadcast.copy(bc::Broadcasted{<:Union{CuSparseVecStyle,CuSparseMatStyl accumulate!(Base.add_sum, offsets, offsets) total_nnz = @allowscalar last(offsets[end]) - 1 else - total_nnz = @allowscalar last(offsets[end]) + sort!(offsets; by=first) + total_nnz = mapreduce(x->first(x) != typemax(first(x)), +, offsets) end - @assert total_nnz >= 0 output = if sparse_typ <: CuSparseMatrixCSR colVal = CuArray{Ti}(undef, total_nnz) nzVal = CuArray{Tv}(undef, total_nnz) @@ -638,27 +617,33 @@ function Broadcast.copy(bc::Broadcasted{<:Union{CuSparseVecStyle,CuSparseMatStyl rowVal = CuArray{Ti}(undef, total_nnz) nzVal = CuArray{Tv}(undef, total_nnz) CuSparseMatrixCSC(offsets, rowVal, nzVal, size(bc)) - elseif sparse_typ <: CuSparseVector - nzVal = CuArray{Tv}(undef, total_nnz) - iPtr = CuArray{Ti}(undef, total_nnz) - copyto!(iPtr, 1, offsets, 1, total_nnz) + elseif sparse_typ <: CuSparseVector && !fpreszeros + CuArray{Tv}(undef, size(bc)) + elseif sparse_typ <: CuSparseVector && fpreszeros + iPtr = CUDA.zeros(Ti, total_nnz) + nzVal = CUDA.zeros(Tv, total_nnz) CuSparseVector(iPtr, nzVal, rows) end + if sparse_typ <: CuSparseVector && !fpreszeros + nonsparse_args = map(bc.args) do arg + # NOTE: this assumes the broadcast is flattened, but not yet preprocessed + if arg isa
AbstractCuSparseArray + zero(eltype(arg)) + else + arg + end + end + broadcast!(bc.f, output, nonsparse_args...) + end end - # perform the actual broadcast - if output isa CuSparseVector - args = (bc.f, overall_first_row, overall_last_row, output, offsets, bc.args...) - kernel = @cuda launch=false sparse_to_sparse_broadcast_kernel(args...) - threads, blocks = compute_launch_config(kernel) - kernel(args...; threads, blocks) - elseif output isa AbstractCuSparseArray - args = (bc.f, output, offsets, bc.args...) + if output isa AbstractCuSparseArray + args = (bc.f, output, offsets, bc.args...) kernel = @cuda launch=false sparse_to_sparse_broadcast_kernel(args...) threads, blocks = compute_launch_config(kernel) kernel(args...; threads, blocks) else - args = (sparse_typ, bc.f, output, bc.args...) + args = sparse_typ <: CuSparseVector ? (sparse_typ, bc.f, output, offsets, bc.args...) : (sparse_typ, bc.f, output, bc.args...) kernel = @cuda launch=false sparse_to_dense_broadcast_kernel(args...) threads, blocks = compute_launch_config(kernel) kernel(args...; threads, blocks) diff --git a/test/libraries/cusparse/broadcast.jl b/test/libraries/cusparse/broadcast.jl index 1709a2016d..b1a89437b5 100644 --- a/test/libraries/cusparse/broadcast.jl +++ b/test/libraries/cusparse/broadcast.jl @@ -1,6 +1,6 @@ using CUDA.CUSPARSE, SparseArrays -for elty in [Int32, Int64, Float32, Float64] +@testset for elty in [Int32, Int64, Float32, Float64] @testset "$typ($elty)" for typ in [CuSparseMatrixCSR, CuSparseMatrixCSC] m,n = 5,6 p = 0.5 @@ -46,29 +46,33 @@ for elty in [Int32, Int64, Float32, Float64] p = 0.5 x = sprand(elty, m, p) dx = typ(x) - + # zero-preserving - y = x .* elty(1) + y = x .* elty(1) dy = dx .* elty(1) @test dy isa typ{elty} - @test y == SparseVector(dy) - + @test collect(dy.iPtr) == collect(dx.iPtr) + @test collect(dy.iPtr) == y.nzind + @test collect(dy.nzVal) == y.nzval + @test y == SparseVector(dy) + # not zero-preserving y = x .+ elty(1) dy = dx .+ elty(1) @test dy isa 
CuArray{elty} - @test y == Array(dy) + hy = Array(dy) + @test Array(y) == hy - # involving something dense - broken for now + # involving something dense y = x .+ ones(elty, m) dy = dx .+ CUDA.ones(elty, m) @test dy isa CuArray{elty} @test y == Array(dy) - + # sparse to sparse - y = sprand(elty, m, p) - dy = typ(y) dx = typ(x) + y = sprand(elty, m, p) + dy = typ(y) z = x .* y dz = dx .* dy @test dz isa typ{elty} @@ -84,8 +88,18 @@ for elty in [Int32, Int64, Float32, Float64] dz = @. dx * dy * dw @test dz isa typ{elty} @test z == SparseVector(dz) - - # broken due to llvm IR + + y = sprand(elty, m, p) + w = sprand(elty, m, p) + dense_arr = rand(elty, m) + d_dense_arr = CuArray(dense_arr) + dy = typ(y) + dw = typ(w) + z = @. x * y * w * dense_arr + dz = @. dx * dy * dw * d_dense_arr + @test dz isa CuArray{elty} + @test z == Array(dz) + y = sprand(elty, m, p) dy = typ(y) dx = typ(x) From 2de968bf03878e67c1edd96680abc91124c3ba6e Mon Sep 17 00:00:00 2001 From: Katharine Hyatt <khyatt@flatironinstitute.org> Date: Fri, 18 Apr 2025 11:31:22 -0400 Subject: [PATCH 09/11] Cleanup --- lib/cusparse/broadcast.jl | 20 ++++++++++++-------- test/Project.toml | 1 - 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/lib/cusparse/broadcast.jl b/lib/cusparse/broadcast.jl index b7bdfb1fdc..98cfd14e29 100644 --- a/lib/cusparse/broadcast.jl +++ b/lib/cusparse/broadcast.jl @@ -375,11 +375,13 @@ function sparse_to_sparse_broadcast_kernel(f::F, output::CuSparseDeviceVector{Tv row_ix > output.nnz && return row_and_ptrs = @inbounds offsets[row_ix] row = @inbounds row_and_ptrs[1] - args_are_nnz = @inbounds row_and_ptrs[2] + arg_ptrs = @inbounds row_and_ptrs[2] vals = ntuple(Val(N)) do i arg = @inbounds args[i] - arg_is_nnz = @inbounds args_are_nnz[i] - _getindex(arg, row, arg_is_nnz)::Tv + # ptr is 0 if the sparse vector doesn't have an element at this row + # ptr is 0 if the arg is a scalar AND f preserves zeros + ptr = @inbounds arg_ptrs[i] + _getindex(arg, row, ptr)::Tv end 
output_val = f(vals...) @inbounds output.iPtr[row_ix] = row @@ -455,14 +457,16 @@ function sparse_to_dense_broadcast_kernel(::Type{<:CuSparseVector}, f::F, row_ix > length(output) && return row_and_ptrs = @inbounds offsets[row_ix] row = @inbounds row_and_ptrs[1] - args_are_nnz = @inbounds row_and_ptrs[2] + arg_ptrs = @inbounds row_and_ptrs[2] vals = ntuple(Val(length(args))) do i arg = @inbounds args[i] - arg_is_nnz = @inbounds args_are_nnz[i] - _getindex(arg, row, arg_is_nnz)::Tv + # ptr is 0 if the sparse vector doesn't have an element at this row + # ptr is row if the arg is dense OR a scalar with non-zero-preserving f + # ptr is 0 if the arg is a scalar AND f preserves zeros + ptr = @inbounds arg_ptrs[i] + _getindex(arg, row, ptr)::Tv end - out_val = f(vals...) - @inbounds output[row] = out_val + @inbounds output[row] = f(vals...) return end ## COV_EXCL_STOP diff --git a/test/Project.toml b/test/Project.toml index 810a96e1cf..5d6ea83e88 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -5,7 +5,6 @@ BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" CUDA_Driver_jll = "4ee394cb-3365-5eb0-8335-949819d2adfc" CUDA_Runtime_jll = "76a88914-d11a-5bdc-97e0-2f5a05c973a2" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -Cthulhu = "f68482b8-f384-11e8-15f7-abe071a5a75f" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" From e2f3d4b1d2a3a2960fa28bb1b9ea3b446501211b Mon Sep 17 00:00:00 2001 From: Katharine Hyatt <khyatt@flatironinstitute.org> Date: Mon, 21 Apr 2025 14:18:37 -0400 Subject: [PATCH 10/11] Remove unneeded using --- lib/cusparse/broadcast.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/lib/cusparse/broadcast.jl b/lib/cusparse/broadcast.jl index 98cfd14e29..c52068c5a6 100644 --- a/lib/cusparse/broadcast.jl +++ b/lib/cusparse/broadcast.jl @@ -1,7 +1,5 @@ using Base.Broadcast: Broadcasted -using Base.Cartesian: @nany - using CUDA: 
CuArrayStyle # TODO: support more types (SparseVector, SparseMatrixCSC, COO, BSR) From 43dc26ffad87c573df95f5af64889c4087a0d531 Mon Sep 17 00:00:00 2001 From: Tim Besard <tim.besard@gmail.com> Date: Thu, 24 Apr 2025 09:17:31 +0200 Subject: [PATCH 11/11] Clean-ups. --- lib/cusparse/broadcast.jl | 56 ++++++++++++++++++++++++--------------- 1 file changed, 34 insertions(+), 22 deletions(-) diff --git a/lib/cusparse/broadcast.jl b/lib/cusparse/broadcast.jl index c52068c5a6..965c0ce478 100644 --- a/lib/cusparse/broadcast.jl +++ b/lib/cusparse/broadcast.jl @@ -84,7 +84,8 @@ end end end end -@inline function _capturescalars(arg) # this definition is just an optimization (to bottom out the recursion slightly sooner) +@inline function _capturescalars(arg) + # this definition is just an optimization (to bottom out the recursion slightly sooner) if scalararg(arg) return (), () -> (arg,) # add scalararg elseif scalarwrappedarg(arg) @@ -287,7 +288,9 @@ end end # helpers to index a sparse or dense array -@inline function _getindex(arg::Union{CuSparseDeviceMatrixCSR{Tv},CuSparseDeviceMatrixCSC{Tv},CuSparseDeviceVector{Tv}}, I, ptr)::Tv where {Tv} +@inline function _getindex(arg::Union{CuSparseDeviceMatrixCSR{Tv}, + CuSparseDeviceMatrixCSC{Tv}, + CuSparseDeviceVector{Tv}}, I, ptr)::Tv where {Tv} if ptr == 0 return zero(Tv) else @@ -323,7 +326,9 @@ function _get_my_row(first_row)::Int32 return row_ix + first_row - 1i32 end -function compute_offsets_kernel(::Type{<:CuSparseVector}, first_row::Ti, last_row::Ti, fpreszeros::Bool, offsets::AbstractVector{Pair{Ti, NTuple{N, Ti}}}, args...) where {Ti, N} +function compute_offsets_kernel(::Type{<:CuSparseVector}, first_row::Ti, last_row::Ti, + fpreszeros::Bool, offsets::AbstractVector{Pair{Ti, NTuple{N, Ti}}}, + args...) 
where {Ti, N} row = _get_my_row(first_row) row > last_row && return @@ -343,7 +348,8 @@ function compute_offsets_kernel(::Type{<:CuSparseVector}, first_row::Ti, last_ro end # kernel to count the number of non-zeros in a row, to determine the row offsets -function compute_offsets_kernel(T::Type{<:Union{CuSparseMatrixCSR, CuSparseMatrixCSC}}, offsets::AbstractVector{Ti}, +function compute_offsets_kernel(T::Type{<:Union{CuSparseMatrixCSR, CuSparseMatrixCSC}}, + offsets::AbstractVector{Ti}, args...) where Ti # every thread processes an entire row leading_dim = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x @@ -368,7 +374,9 @@ function compute_offsets_kernel(T::Type{<:Union{CuSparseMatrixCSR, CuSparseMatri return end -function sparse_to_sparse_broadcast_kernel(f::F, output::CuSparseDeviceVector{Tv,Ti}, offsets::AbstractVector{Pair{Ti, NTuple{N, Ti}}}, args...) where {Tv, Ti, N, F} +function sparse_to_sparse_broadcast_kernel(f::F, output::CuSparseDeviceVector{Tv,Ti}, + offsets::AbstractVector{Pair{Ti, NTuple{N, Ti}}}, + args...) where {Tv, Ti, N, F} row_ix = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x row_ix > output.nnz && return row_and_ptrs = @inbounds offsets[row_ix] @@ -382,12 +390,14 @@ function sparse_to_sparse_broadcast_kernel(f::F, output::CuSparseDeviceVector{Tv _getindex(arg, row, ptr)::Tv end output_val = f(vals...) - @inbounds output.iPtr[row_ix] = row + @inbounds output.iPtr[row_ix] = row @inbounds output.nzVal[row_ix] = output_val return end -function sparse_to_sparse_broadcast_kernel(f, output::T, offsets::Union{AbstractVector,Nothing}, args...) where {Ti, T<:Union{CuSparseDeviceMatrixCSR{<:Any,Ti},CuSparseDeviceMatrixCSC{<:Any,Ti}}} +function sparse_to_sparse_broadcast_kernel(f, output::T, offsets::Union{AbstractVector,Nothing}, + args...) 
where {Ti, T<:Union{CuSparseDeviceMatrixCSR{<:Any,Ti}, + CuSparseDeviceMatrixCSC{<:Any,Ti}}} # every thread processes an entire row leading_dim = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x leading_dim_size = output isa CuSparseDeviceMatrixCSR ? size(output, 1) : size(output, 2) @@ -423,7 +433,8 @@ function sparse_to_sparse_broadcast_kernel(f, output::T, offsets::Union{Abstract return end -function sparse_to_dense_broadcast_kernel(T::Type{<:Union{CuSparseMatrixCSR{Tv, Ti}, CuSparseMatrixCSC{Tv, Ti}}}, f, +function sparse_to_dense_broadcast_kernel(T::Type{<:Union{CuSparseMatrixCSR{Tv, Ti}, + CuSparseMatrixCSC{Tv, Ti}}}, f, output::CuDeviceArray, args...) where {Tv, Ti} # every thread processes an entire row leading_dim = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x @@ -449,7 +460,9 @@ function sparse_to_dense_broadcast_kernel(T::Type{<:Union{CuSparseMatrixCSR{Tv, end function sparse_to_dense_broadcast_kernel(::Type{<:CuSparseVector}, f::F, - output::CuDeviceArray{Tv}, offsets::AbstractVector{Pair{Ti, NTuple{N, Ti}}}, args...) where {Tv, F, N, Ti} + output::CuDeviceArray{Tv}, + offsets::AbstractVector{Pair{Ti, NTuple{N, Ti}}}, + args...) where {Tv, F, N, Ti} # every thread processes an entire row row_ix = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x row_ix > length(output) && return @@ -468,7 +481,7 @@ function sparse_to_dense_broadcast_kernel(::Type{<:CuSparseVector}, f::F, return end ## COV_EXCL_STOP -const N_VEC_THREADS = 512 + function Broadcast.copy(bc::Broadcasted{<:Union{CuSparseVecStyle,CuSparseMatStyle}}) # find the sparse inputs bc = Broadcast.flatten(bc) @@ -510,7 +523,7 @@ function Broadcast.copy(bc::Broadcasted{<:Union{CuSparseVecStyle,CuSparseMatStyl # the kernels below parallelize across rows or cols, not elements, so it's unlikely # we'll launch many threads. to maximize utilization, parallelize across blocks first. - rows, cols = sparse_typ <: CuSparseVector ? 
(length(bc), 1) : size(bc) + rows, cols = get(size(bc), 1, 1), get(size(bc), 2, 1) # `size(bc, ::Int)` is missing function compute_launch_config(kernel) config = launch_configuration(kernel.fun) if sparse_typ <: CuSparseMatrixCSR @@ -522,15 +535,15 @@ function Broadcast.copy(bc::Broadcasted{<:Union{CuSparseVecStyle,CuSparseMatStyl blocks = max(cld(cols, threads), config.blocks) threads = cld(cols, blocks) elseif sparse_typ <: CuSparseVector - threads = N_VEC_THREADS + threads = 512 blocks = max(cld(rows, threads), config.blocks) - threads = N_VEC_THREADS end (; threads, blocks) end # for CuSparseVec, figure out the actual row range we need to address, e.g. if m = 2^20 # but the only rows present in any sparse vector input are between 2 and 128, no need to - # launch massive threads. TODO: use the difference here to set the thread count + # launch massive threads. + # TODO: use the difference here to set the thread count overall_first_row = one(Ti) overall_last_row = Ti(rows) offsets = nothing @@ -592,10 +605,10 @@ function Broadcast.copy(bc::Broadcasted{<:Union{CuSparseVecStyle,CuSparseMatStyl CuVector{Pair{Ti, NTuple{length(bc.args), Ti}}}(undef, overall_last_row - overall_first_row + 1) end let - if sparse_typ <: CuSparseVector - args = (sparse_typ, overall_first_row, overall_last_row, fpreszeros, offsets, bc.args...) + args = if sparse_typ <: CuSparseVector + (sparse_typ, overall_first_row, overall_last_row, fpreszeros, offsets, bc.args...) else - args = (sparse_typ, offsets, bc.args...) + (sparse_typ, offsets, bc.args...) end kernel = @cuda launch=false compute_offsets_kernel(args...) threads, blocks = compute_launch_config(kernel) @@ -642,14 +655,13 @@ function Broadcast.copy(bc::Broadcasted{<:Union{CuSparseVecStyle,CuSparseMatStyl if output isa AbstractCuSparseArray args = (bc.f, output, offsets, bc.args...) kernel = @cuda launch=false sparse_to_sparse_broadcast_kernel(args...) 
- threads, blocks = compute_launch_config(kernel) - kernel(args...; threads, blocks) else - args = sparse_typ <: CuSparseVector ? (sparse_typ, bc.f, output, offsets, bc.args...) : (sparse_typ, bc.f, output, bc.args...) + args = sparse_typ <: CuSparseVector ? (sparse_typ, bc.f, output, offsets, bc.args...) : + (sparse_typ, bc.f, output, bc.args...) kernel = @cuda launch=false sparse_to_dense_broadcast_kernel(args...) - threads, blocks = compute_launch_config(kernel) - kernel(args...; threads, blocks) end + threads, blocks = compute_launch_config(kernel) + kernel(args...; threads, blocks) return output end