Skip to content

Commit

Permalink
interface changes in CUDA code
Browse files Browse the repository at this point in the history
  • Loading branch information
trahflow committed Mar 7, 2024
1 parent e6279d4 commit ba5c517
Show file tree
Hide file tree
Showing 5 changed files with 121 additions and 92 deletions.
98 changes: 56 additions & 42 deletions ext/DiffPointRasterisationCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,24 +13,21 @@ using StaticArrays


function raster_pullback_kernel!(
::Val{N_in},
ds_dout::AbstractArray{T, N_out_p1},
points,
rotations,
translations,
::Type{T},
ds_dout,
points::AbstractVector{<:StaticVector{N_in}},
rotations::AbstractVector{<:StaticMatrix{N_out, N_in, TR}},
translations::AbstractVector{<:StaticVector{N_out, TT}},
weights,
shifts,
scale,
projection_idxs,
# outputs:
ds_dpoints,
ds_dprojection_rotation,
ds_drotation,
ds_dtranslation,
ds_dweight,

) where {T, N_in, N_out_p1}
N_out = N_out_p1 - 1 # dimensionality of output, without batch dimension

) where {T, TR, TT, N_in, N_out}
n_voxel = blockDim().z
points_per_workgroup = blockDim().x
batchsize_per_workgroup = blockDim().y
Expand All @@ -46,25 +43,24 @@ function raster_pullback_kernel!(
neighbor_voxel_id = (blockIdx().z - 1) * n_voxel + s
point_idx = (blockIdx().x - 1) * points_per_workgroup + threadIdx().x
batch_idx = (blockIdx().y - 1) * batchsize_per_workgroup + b
in_batch = batch_idx <= size(rotations, 3)
in_batch = batch_idx <= length(rotations)

dimension = (N_out, n_voxel, batchsize_per_workgroup)
ds_dpoint_rot = CuDynamicSharedArray(T, dimension)
ds_dpoint_local = CuDynamicSharedArray(T, (N_in, batchsize_per_workgroup), sizeof(T) * prod(dimension))

rotation = in_batch ? @inbounds(SMatrix{N_in, N_in, T}(@view rotations[:, :, batch_idx])) : @SMatrix zeros(T, N_in, N_in)
point = @inbounds SVector{N_in, T}(@view points[:, point_idx])
rotation = @inbounds in_batch ? rotations[batch_idx] : @SMatrix zeros(TR, N_in, N_in)
point = @inbounds points[point_idx]

if in_batch
translation = @inbounds SVector{N_out, T}(@view translations[:, batch_idx])
translation = @inbounds translations[batch_idx]
weight = @inbounds weights[batch_idx]
shift = @inbounds shifts[neighbor_voxel_id]
origin = (-@SVector ones(T, N_out)) - translation
origin = (-@SVector ones(TT, N_out)) - translation

coord_reference_voxel, deltas = DiffPointRasterisation.reference_coordinate_and_deltas(
point,
rotation,
projection_idxs,
origin,
scale,
)
Expand All @@ -76,12 +72,11 @@ function raster_pullback_kernel!(
@inbounds ds_dweight_local = DiffPointRasterisation.voxel_weight(
deltas,
shift,
projection_idxs,
ds_dout[voxel_idx],
)

factor = ds_dout[voxel_idx] * weight
ds_dcoord_part = SVector(factor .* ntuple(n -> DiffPointRasterisation.interpolation_weight(n, N_out, deltas, shift), N_out))
ds_dcoord_part = SVector(factor .* ntuple(n -> DiffPointRasterisation.interpolation_weight(n, N_out, deltas, shift), Val(N_out)))
@inbounds ds_dpoint_rot[:, s, b] .= ds_dcoord_part .* scale
else
@inbounds ds_dpoint_rot[:, s, b] .= zero(T)
Expand Down Expand Up @@ -123,7 +118,7 @@ function raster_pullback_kernel!(
j = 1
while j <= N_in
val = coef * point[j]
@inbounds CUDA.@atomic ds_dprojection_rotation[dim, j, batch_idx] += val
@inbounds CUDA.@atomic ds_drotation[dim, j, batch_idx] += val
j += 1
end
end
Expand Down Expand Up @@ -175,37 +170,53 @@ function raster_pullback_kernel!(
nothing
end


# single image
# Single-image pullback has no GPU implementation. The method is attached to
# `DiffPointRasterisation.raster_pullback!` (fully qualified, like the batched
# method in this file) so that it extends the package's function; an
# unqualified definition would only do so if the name were explicitly imported.
# Any keyword arguments are accepted and ignored, since the method always throws.
function DiffPointRasterisation.raster_pullback!(
    ds_dout::CuArray{<:Number, N_out},
    points::AbstractVector{<:StaticVector{N_in, <:Number}},
    rotation::StaticMatrix{N_out, N_in, <:Number},
    translation::StaticVector{N_out, <:Number},
    background::Number,
    weight::Number,
    ds_dpoints::AbstractMatrix{<:Number};
    kwargs...
) where {N_in, N_out}
    return error("Not implemented: raster_pullback! for single image not implemented on GPU. Consider using CPU arrays")
end

# batch of images
function DiffPointRasterisation.raster_pullback!(
::Val{N_in},
ds_dout::CuArray{T, N_out_p1},
points::CuMatrix{<:Number},
rotation::CuArray{<:Number, 3},
translation::CuMatrix{<:Number},
# TODO: for some reason type inference fails if the following
# two arrays are FillArrays...
background::CuVector{<:Number}=CUDA.zeros(T, size(rotation, 3)),
weight::CuVector{<:Number}=CUDA.ones(T, size(rotation, 3)),
) where {T<:Number, N_in, N_out_p1}
N_out = N_out_p1 - 1
out_batch_dim = ndims(ds_dout)
batch_axis = axes(ds_dout, out_batch_dim)
n_points = size(points, 2)
ds_dout::CuArray{<:Number, N_out_p1},
points::CuVector{<:StaticVector{N_in, <:Number}},
rotation::CuVector{<:StaticMatrix{N_out, N_in, <:Number}},
translation::CuVector{<:StaticVector{N_out, <:Number}},
background::CuVector{<:Number},
weight::CuVector{<:Number},
ds_dpoints::CuMatrix{TP},
ds_drotation::CuArray{TR, 3},
ds_dtranslation::CuMatrix{TT},
ds_dbackground::CuVector{<:Number},
ds_dweight::CuVector{TW},
) where {N_in, N_out, N_out_p1, TP<:Number, TR<:Number, TT<:Number, TW<:Number}
T = promote_type(eltype(ds_dout), TP, TR, TT, TW)
batch_axis = axes(ds_dout, N_out_p1)
@argcheck N_out == N_out_p1 - 1
@argcheck batch_axis == axes(rotation, 1) == axes(translation, 1) == axes(background, 1) == axes(weight, 1)
@argcheck batch_axis == axes(ds_drotation, 3) == axes(ds_dtranslation, 2) == axes(ds_dbackground, 1) == axes(ds_dweight, 1)
@argcheck N_out == N_out_p1 - 1

n_points = length(points)
batch_size = length(batch_axis)
@argcheck axes(ds_dout, out_batch_dim) == axes(rotation, 3) == axes(translation, 2) == axes(background, 1) == axes(weight, 1)

ds_dbackground = dropdims(sum(ds_dout; dims=1:N_out); dims=ntuple(identity, Val(N_out)))
ds_dbackground = vec(sum!(reshape(ds_dbackground, ntuple(_ -> 1, Val(N_out))..., batch_size), ds_dout))

scale = SVector{N_out, T}(size(ds_dout)[1:end-1]) / T(2)
projection_idxs = SVector{N_out}(ntuple(identity, N_out))
shifts=DiffPointRasterisation.voxel_shifts(Val(N_out))

ds_dpoints = fill!(similar(points), zero(T))
ds_drotation = fill!(similar(rotation), zero(T))
ds_dtranslation = fill!(similar(translation), zero(T))
ds_dweight = fill!(similar(weight), zero(T))
ds_dpoints = fill!(ds_dpoints, zero(TP))
ds_drotation = fill!(ds_drotation, zero(TR))
ds_dtranslation = fill!(ds_dtranslation, zero(TT))
ds_dweight = fill!(ds_dweight, zero(TW))

args = (Val(N_in), ds_dout, points, rotation, translation, weight, shifts, scale, projection_idxs, ds_dpoints, ds_drotation, ds_dtranslation, ds_dweight)
args = (T, ds_dout, points, rotation, translation, weight, shifts, scale, ds_dpoints, ds_drotation, ds_dtranslation, ds_dweight)

ndrange = (n_points, batch_size, 2^N_out)

Expand All @@ -232,4 +243,7 @@ function DiffPointRasterisation.raster_pullback!(
)
end


# Default `ds_dpoints` buffer for the batched pullback when `points` is a
# `CuVector`: an uninitialised `(N_in, length(points))` array with element
# type `TP`, created via `similar(points, ...)`. `batch_size` is accepted but
# not used by this method.
function DiffPointRasterisation.default_ds_dpoints_batched(
    points::CuVector{<:AbstractVector{TP}}, N_in, batch_size
) where {TP<:Number}
    n_points = length(points)
    return similar(points, TP, (N_in, n_points))
end

end # module
21 changes: 13 additions & 8 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,16 @@ function raster(
)
eltypes = deep_eltype.(args)
T = promote_type(eltypes...)
points = args[1]
rotation = args[2]
if isa(rotation, AbstractMatrix)
# non-batched
out = similar(rotation, T, grid_size)
out = similar(points, T, grid_size)
else
# batched
@assert rotation isa AbstractVector{<:AbstractMatrix}
batch_size = length(rotation)
out = similar(rotation, T, (grid_size..., batch_size))
out = similar(points, T, (grid_size..., batch_size))
end
raster!(out, args...)
end
Expand Down Expand Up @@ -196,17 +197,17 @@ raster_pullback!(
inp_translation::StaticVector{N_out, <:Number},
inp_background::Number,
inp_weight::Number;
accumulate_ds_dpoints=false,
points::AbstractMatrix{T} = similar(inp_points, T, (N_in, length(inp_points)))
) where {N_in, N_out, T<:Number} =raster_pullback!(
points::AbstractMatrix{T} = default_ds_dpoints_single(inp_points, N_in),
kwargs...
) where {N_in, N_out, T<:Number} = raster_pullback!(
ds_dout,
inp_points,
inp_rotation,
inp_translation,
inp_background,
inp_weight,
points;
accumulate_ds_dpoints,
kwargs...
)

# batch of images
Expand All @@ -217,7 +218,7 @@ raster_pullback!(
inp_translation::AbstractVector{<:StaticVector{N_out, TT}},
inp_background::AbstractVector{TB},
inp_weight::AbstractVector{TW};
points::AbstractArray{TP, 3} = similar(inp_points, TP, (N_in, length(inp_points), min(length(inp_rotation), Threads.nthreads()))),
points::AbstractArray{TP} = default_ds_dpoints_batched(inp_points, N_in, length(inp_rotation)),
rotation::AbstractArray{TR, 3} = similar(inp_rotation, TR, (N_out, N_in, length(inp_rotation))),
translation::AbstractMatrix{TT} = similar(inp_translation, TT, (N_out, length(inp_translation))),
background::AbstractVector{TB} = similar(inp_background),
Expand Down Expand Up @@ -267,7 +268,7 @@ raster_pullback!(
::AbstractVector{<:StaticVector{N_out_trans, <:Number}},
::AbstractVector{<:Number},
::AbstractVector{<:Number},
::AbstractArray{<:Number, 3},
::AbstractArray{<:Number},
::AbstractArray{<:Number, 3},
::AbstractMatrix{<:Number},
::AbstractVector{<:Number},
Expand Down Expand Up @@ -305,6 +306,10 @@ default_weight(rotation::AbstractVector{<:AbstractMatrix}, T=eltype(eltype(rotat

default_weight(rotation::AbstractArray{_T, 3} where _T, T=eltype(rotation)) = Ones(T, size(rotation, 3))

# Default `ds_dpoints` buffer for the single-image pullback: an uninitialised
# `(N_in, length(points))` array with element type `TP`, created via
# `similar(points, ...)`.
function default_ds_dpoints_single(
    points::AbstractVector{<:AbstractVector{TP}}, N_in
) where {TP<:Number}
    n_points = length(points)
    return similar(points, TP, (N_in, n_points))
end

# Default `ds_dpoints` buffer for the batched pullback: an uninitialised
# `(N_in, length(points), k)` array with element type `TP`, where `k` is the
# smaller of `batch_size` and `Threads.nthreads()`.
function default_ds_dpoints_batched(
    points::AbstractVector{<:AbstractVector{TP}}, N_in, batch_size
) where {TP<:Number}
    n_slices = min(batch_size, Threads.nthreads())
    return similar(points, TP, (N_in, length(points), n_slices))
end


@testitem "raster interface" begin
include("../test/data.jl")
Expand Down
60 changes: 33 additions & 27 deletions src/raster_pullback.jl
Original file line number Diff line number Diff line change
@@ -1,33 +1,34 @@
# single image
function raster_pullback!(
ds_dout::AbstractArray{T, N_out},
ds_dout::AbstractArray{<:Number, N_out},
points::AbstractVector{<:StaticVector{N_in, <:Number}},
rotation::StaticMatrix{N_out, N_in, <:Number},
translation::StaticVector{N_out, <:Number},
rotation::StaticMatrix{N_out, N_in, TR},
translation::StaticVector{N_out, TT},
background::Number,
weight::Number,
ds_dpoints::AbstractMatrix{<:Number};
weight::TW,
ds_dpoints::AbstractMatrix{TP};
accumulate_ds_dpoints=false,
) where {N_in, N_out, T<:Number}
) where {N_in, N_out, TP<:Number, TR<:Number, TT<:Number, TW<:Number}
T = promote_type(eltype(ds_dout), TP, TR, TT, TW)
@argcheck size(ds_dpoints, 1) == N_in
# The strategy followed here is to redo some of the calculations
# made in the forward pass instead of caching them in the forward
# pass and reusing them here.
accumulate_ds_dpoints || fill!(ds_dpoints, zero(T))
accumulate_ds_dpoints || fill!(ds_dpoints, zero(TP))

origin = (-@SVector ones(T, N_out)) - translation
origin = (-@SVector ones(TT, N_out)) - translation
scale = SVector{N_out, T}(size(ds_dout)) / 2
shifts=voxel_shifts(Val(N_out))
all_density_idxs = CartesianIndices(ds_dout)

# initialize some output for accumulation
ds_dtranslation = @SVector zeros(T, N_out)
ds_drotation = @SMatrix zeros(T, N_out, N_in)
ds_dweight = zero(T)
ds_dtranslation = @SVector zeros(TT, N_out)
ds_drotation = @SMatrix zeros(TR, N_out, N_in)
ds_dweight = zero(TW)

# loop over points
for (pt_idx, point) in enumerate(points)
point = SVector{N_in, T}(point)
point = SVector{N_in, TP}(point)
coord_reference_voxel, deltas = reference_coordinate_and_deltas(
point,
rotation,
Expand Down Expand Up @@ -66,22 +67,22 @@ end
# batch of images
function raster_pullback!(
ds_dout::AbstractArray{<:Number, N_out_p1},
points::AbstractVector{<:StaticVector{N_in, TP}},
rotation::AbstractVector{<:StaticMatrix{N_out, N_in, TR}},
translation::AbstractVector{<:StaticVector{N_out, TT}},
background::AbstractVector{TB},
weight::AbstractVector{TW},
ds_dpoints::AbstractArray{TP, 3},
ds_drotation::AbstractArray{TR, 3},
ds_dtranslation::AbstractMatrix{TT},
ds_dbackground::AbstractVector{TB},
ds_dweight::AbstractVector{TW},
) where {N_in, N_out, N_out_p1, TP<:Number, TR<:Number, TT<:Number, TB<:Number, TW<:Number}
points::AbstractVector{<:StaticVector{N_in, <:Number}},
rotation::AbstractVector{<:StaticMatrix{N_out, N_in, <:Number}},
translation::AbstractVector{<:StaticVector{N_out, <:Number}},
background::AbstractVector{<:Number},
weight::AbstractVector{<:Number},
ds_dpoints::AbstractArray{<:Number, 3},
ds_drotation::AbstractArray{<:Number, 3},
ds_dtranslation::AbstractMatrix{<:Number},
ds_dbackground::AbstractVector{<:Number},
ds_dweight::AbstractVector{<:Number},
) where {N_in, N_out, N_out_p1}
batch_axis = axes(ds_dout, N_out_p1)
@argcheck N_out == N_out_p1 - 1
@argcheck batch_axis == axes(rotation, 1) == axes(translation, 1) == axes(background, 1) == axes(weight, 1)
@argcheck batch_axis == axes(ds_drotation, 3) == axes(ds_dtranslation, 2) == axes(ds_dbackground, 1) == axes(ds_dweight, 1)
fill!(ds_dpoints, zero(TP))
fill!(ds_dpoints, zero(eltype(ds_dpoints)))

n_threads = size(ds_dpoints, 3)

Expand Down Expand Up @@ -146,7 +147,7 @@ function interpolation_weight(n, N, deltas, shift)
end

@testitem "raster_pullback! inference and allocations" begin
using BenchmarkTools, CUDA
using BenchmarkTools, CUDA, Adapt
include("../test/data.jl")
ds_dout_3d = randn(D.grid_size_3d)
ds_dout_3d_batched = randn(D.grid_size_3d..., D.batch_size)
Expand Down Expand Up @@ -189,6 +190,11 @@ end
ds_dweights
)

# Adapt an argument tuple to `CuArray` storage and replace entry 7 (the
# ds_dpoints buffer) with its first slice along the third dimension, i.e.
# ds_dpoints without the batch dimension.
function to_cuda(args)
    gpu_args = adapt(CuArray, args)
    ds_dpoints_slice = gpu_args[7][:, :, 1]
    return Base.setindex(gpu_args, ds_dpoints_slice, 7)
end

# check type stability
# single image
@inferred DiffPointRasterisation.raster_pullback!(ds_dout_3d, D.points_static, D.rotation, D.translation_3d, D.background, D.weight, ds_dpoints)
Expand All @@ -197,9 +203,9 @@ end
@inferred DiffPointRasterisation.raster_pullback!(args_batched_3d...)
@inferred DiffPointRasterisation.raster_pullback!(args_batched_2d...)
if CUDA.functional()
cu_args_3d = cu.(args_batched_3d)
cu_args_3d = to_cuda(args_batched_3d)
@inferred DiffPointRasterisation.raster_pullback!(cu_args_3d...)
cu_args_2d = cu.(args_batched_2d)
cu_args_2d = to_cuda(args_batched_2d)
@inferred DiffPointRasterisation.raster_pullback!(cu_args_2d...)
end

Expand Down
14 changes: 7 additions & 7 deletions test/cuda.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@
include("util.jl")
cuda_available = CUDA.functional()

args = (D.grid_size_3d, D.more_points, D.rotations, D.translations_3d, D.backgrounds, D.weights)
args = (D.grid_size_3d, D.more_points, D.rotations_static, D.translations_3d_static, D.backgrounds, D.weights)
@test cuda_cpu_agree(raster, args...) skip=!cuda_available

args = (D.grid_size_2d, D.more_points, D.rotations, D.translations_2d, D.backgrounds, D.weights)
@test cuda_cpu_agree(raster_project, args...) skip=!cuda_available
args = (D.grid_size_2d, D.more_points, D.projections_static, D.translations_2d_static, D.backgrounds, D.weights)
@test cuda_cpu_agree(raster, args...) skip=!cuda_available
end

@testitem "CUDA backward" begin
Expand All @@ -21,12 +21,12 @@ end
cuda_available = CUDA.functional()

ds_dout_3d = randn(D.grid_size_3d..., D.batch_size)
args = (ds_dout_3d, D.more_points, D.rotations, D.translations_3d, D.backgrounds, D.weights)
args = (ds_dout_3d, D.more_points, D.rotations_static, D.translations_3d_static, D.backgrounds, D.weights)
@test cuda_cpu_agree(raster_pullback!, args...) skip=!cuda_available

ds_dout_2d = randn(D.grid_size_2d..., D.batch_size)
args = (ds_dout_2d, D.more_points, D.rotations, D.translations_2d, D.backgrounds, D.weights)
@test cuda_cpu_agree(raster_project_pullback!, args...) skip=!cuda_available
args = (ds_dout_2d, D.more_points, D.projections_static, D.translations_2d_static, D.backgrounds, D.weights)
@test cuda_cpu_agree(raster_pullback!, args...) skip=!cuda_available
end

# The following currently fails.
Expand All @@ -44,6 +44,6 @@ end
#
# ds_dout_2d = CUDA.randn(Float64, D.grid_size_2d..., D.batch_size)
# args = (D.grid_size_2d, c(D.points), c(D.rotations) ⊢ c(D.rotation_tangents), c(D.translations_2d), c(D.backgrounds), c(D.weights))
# test_rrule(raster_project, args...; output_tangent=ds_dout_2d)
# test_rrule(raster, args...; output_tangent=ds_dout_2d)
# end
# end
Loading

0 comments on commit ba5c517

Please sign in to comment.