diff --git a/examples/Project.toml b/examples/Project.toml index 4c1560e..26ee241 100644 --- a/examples/Project.toml +++ b/examples/Project.toml @@ -2,4 +2,5 @@ DiffPointRasterisation = "f984992d-3c45-4382-99a1-cf20f5c47c61" FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" Images = "916415d5-f1e6-5110-898d-aaa5f9f070e0" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/examples/logo.jl b/examples/logo.jl index 7e33e58..9115ba7 100644 --- a/examples/logo.jl +++ b/examples/logo.jl @@ -1,21 +1,24 @@ using DiffPointRasterisation using FFTW using Images +using LinearAlgebra +using StaticArrays using Zygote load_image(path) = load(path) .|> Gray |> channelview -init_points(n) = (rand(Float32, 2, n) .- 0.5f0) .* Float32[2;2;;] +init_points(n) = [2 * rand(Float32, 2) .- 1f0 for _ in 1:n] target_image = load_image("data/julia.png") points = init_points(5_000) -rotation = Float32[1;0;;0;1;;] +rotation = I(2) translation = zeros(Float32, 2) function model(points, log_bandwidth, log_weight) + # raster points to 2d-image rough_image = raster(size(target_image), points, rotation, translation, 0f0, exp(log_weight)) - # smooth with gaussian kernel + # smooth image with gaussian kernel kernel = gaussian_kernel(log_bandwidth, size(target_image)...) image = convolve_image(rough_image, kernel) image @@ -36,18 +39,20 @@ convolve_image(image, kernel) = irfft(rfft(image) .* rfft(kernel), size(image, 1 function loss(points, log_bandwidth, log_weight) model_image = model(points, log_bandwidth, log_weight) - sum((model_image .- target_image).^2) + sum(points.^2) + # squared error plus regularization term for points + sum((model_image .- target_image).^2) + sum(stack(points).^2) end logrange(s, e, n) = round.(Int, exp.(range(log(s), log(e), n))) function langevin!(points, log_bandwidth, log_weight, eps, n, update_bandwidth=true, update_weight=true, eps_after_init=eps; n_init=n, n_logs_init=15) + # Langevin sampling for points and optionally log_bandwidth and log_weight. logs_init = logrange(1, n_init, n_logs_init) log_every = false logstep = 1 for i in 1:n l, grads = Zygote.withgradient(loss, points, log_bandwidth, log_weight) - points .+= sqrt(eps) .* randn(Float32, size(points)) .- eps .* 0.5f0 .* grads[1] + points .+= sqrt(eps) .* reinterpret(reshape, SVector{2, Float32}, randn(Float32, 2, length(points))) .- eps .* 0.5f0 .* grads[1] if update_bandwidth log_bandwidth += sqrt(eps) * randn(Float32) - eps * 0.5f0 * grads[2] end @@ -57,6 +62,7 @@ function langevin!(points, log_bandwidth, log_weight, eps, n, update_bandwidth=t if i == n_init log_every = true + eps = eps_after_init end if log_every || (i in logs_init) println("iteration $logstep, $i: loss = $l, bandwidth = $(exp(log_bandwidth)), weight = $(exp(log_weight))") @@ -68,4 +74,4 @@ function langevin!(points, log_bandwidth, log_weight, eps, n, update_bandwidth=t points, log_bandwidth end -isinteractive() || langevin!(points, log(0.5f0), 0f0, 5f-6, 6_030, false, true, 5f-4; n_init=6_000) +isinteractive() || langevin!(points, log(0.5f0), 0f0, 5f-6, 6_030, false, true, 4f-5; n_init=6_000) diff --git a/ext/DiffPointRasterisationCUDAExt.jl b/ext/DiffPointRasterisationCUDAExt.jl index b8cde78..61c830a 100644 --- a/ext/DiffPointRasterisationCUDAExt.jl +++ b/ext/DiffPointRasterisationCUDAExt.jl @@ -13,24 +13,21 @@ using StaticArrays function raster_pullback_kernel!( - ::Val{N_in}, - ds_dout::AbstractArray{T, N_out_p1}, - points, - rotations, - translations, + ::Type{T}, + ds_dout, + points::AbstractVector{<:StaticVector{N_in}}, + rotations::AbstractVector{<:StaticMatrix{N_out, N_in, TR}}, + translations::AbstractVector{<:StaticVector{N_out, TT}}, weights, shifts, scale, - projection_idxs, # outputs: ds_dpoints, - ds_dprojection_rotation, + ds_drotation, ds_dtranslation, ds_dweight, -) where {T, N_in, N_out_p1} - N_out = N_out_p1 - 1 # dimensionality of output, without batch dimension - +) where {T, TR, TT, N_in, N_out} n_voxel = blockDim().z points_per_workgroup = blockDim().x batchsize_per_workgroup = blockDim().y @@ -46,25 +43,24 @@ function raster_pullback_kernel!( neighbor_voxel_id = (blockIdx().z - 1) * n_voxel + s point_idx = (blockIdx().x - 1) * points_per_workgroup + threadIdx().x batch_idx = (blockIdx().y - 1) * batchsize_per_workgroup + b - in_batch = batch_idx <= size(rotations, 3) + in_batch = batch_idx <= length(rotations) dimension = (N_out, n_voxel, batchsize_per_workgroup) ds_dpoint_rot = CuDynamicSharedArray(T, dimension) ds_dpoint_local = CuDynamicSharedArray(T, (N_in, batchsize_per_workgroup), sizeof(T) * prod(dimension)) - rotation = in_batch ? @inbounds(SMatrix{N_in, N_in, T}(@view rotations[:, :, batch_idx])) : @SMatrix zeros(T, N_in, N_in) - point = @inbounds SVector{N_in, T}(@view points[:, point_idx]) + rotation = @inbounds in_batch ? rotations[batch_idx] : @SMatrix zeros(TR, N_in, N_in) + point = @inbounds points[point_idx] if in_batch - translation = @inbounds SVector{N_out, T}(@view translations[:, batch_idx]) + translation = @inbounds translations[batch_idx] weight = @inbounds weights[batch_idx] shift = @inbounds shifts[neighbor_voxel_id] - origin = (-@SVector ones(T, N_out)) - translation + origin = (-@SVector ones(TT, N_out)) - translation coord_reference_voxel, deltas = DiffPointRasterisation.reference_coordinate_and_deltas( point, rotation, - projection_idxs, origin, scale, ) @@ -76,12 +72,11 @@ function raster_pullback_kernel!( @inbounds ds_dweight_local = DiffPointRasterisation.voxel_weight( deltas, shift, - projection_idxs, ds_dout[voxel_idx], ) factor = ds_dout[voxel_idx] * weight - ds_dcoord_part = SVector(factor .* ntuple(n -> DiffPointRasterisation.interpolation_weight(n, N_out, deltas, shift), N_out)) + ds_dcoord_part = SVector(factor .* ntuple(n -> DiffPointRasterisation.interpolation_weight(n, N_out, deltas, shift), Val(N_out))) @inbounds ds_dpoint_rot[:, s, b] .= ds_dcoord_part .* scale else @inbounds ds_dpoint_rot[:, s, b] .= zero(T) @@ -123,7 +118,7 @@ function raster_pullback_kernel!( j = 1 while j <= N_in val = coef * point[j] - @inbounds CUDA.@atomic ds_dprojection_rotation[dim, j, batch_idx] += val + @inbounds CUDA.@atomic ds_drotation[dim, j, batch_idx] += val j += 1 end end @@ -175,37 +170,53 @@ function raster_pullback_kernel!( nothing end - -function DiffPointRasterisation._raster_pullback!( - ::Val{N_in}, - ds_dout::CuArray{T, N_out_p1}, - points::CuMatrix{<:Number}, - rotation::CuArray{<:Number, 3}, - translation::CuMatrix{<:Number}, - # TODO: for some reason type inference fails if the following - # two arrays are FillArrays... - background::CuVector{<:Number}=CUDA.zeros(T, size(rotation, 3)), - weight::CuVector{<:Number}=CUDA.ones(T, size(rotation, 3)), -) where {T<:Number, N_in, N_out_p1} - N_out = N_out_p1 - 1 - out_batch_dim = ndims(ds_dout) - batch_axis = axes(ds_dout, out_batch_dim) - n_points = size(points, 2) +# single image +raster_pullback!( + ds_dout::CuArray{<:Number, N_out}, + points::AbstractVector{<:StaticVector{N_in, <:Number}}, + rotation::StaticMatrix{N_out, N_in, <:Number}, + translation::StaticVector{N_out, <:Number}, + background::Number, + weight::Number, + ds_dpoints::AbstractMatrix{<:Number}; + kwargs... +) where {N_in, N_out} = error("Not implemented: raster_pullback! for single image not implemented on GPU. Consider using CPU arrays") + +# batch of images +function DiffPointRasterisation.raster_pullback!( + ds_dout::CuArray{<:Number, N_out_p1}, + points::CuVector{<:StaticVector{N_in, <:Number}}, + rotation::CuVector{<:StaticMatrix{N_out, N_in, <:Number}}, + translation::CuVector{<:StaticVector{N_out, <:Number}}, + background::CuVector{<:Number}, + weight::CuVector{<:Number}, + ds_dpoints::CuMatrix{TP}, + ds_drotation::CuArray{TR, 3}, + ds_dtranslation::CuMatrix{TT}, + ds_dbackground::CuVector{<:Number}, + ds_dweight::CuVector{TW}, +) where {N_in, N_out, N_out_p1, TP<:Number, TR<:Number, TT<:Number, TW<:Number} + T = promote_type(eltype(ds_dout), TP, TR, TT, TW) + batch_axis = axes(ds_dout, N_out_p1) + @argcheck N_out == N_out_p1 - 1 + @argcheck batch_axis == axes(rotation, 1) == axes(translation, 1) == axes(background, 1) == axes(weight, 1) + @argcheck batch_axis == axes(ds_drotation, 3) == axes(ds_dtranslation, 2) == axes(ds_dbackground, 1) == axes(ds_dweight, 1) + @argcheck N_out == N_out_p1 - 1 + + n_points = length(points) batch_size = length(batch_axis) - @argcheck axes(ds_dout, out_batch_dim) == axes(rotation, 3) == axes(translation, 2) == axes(background, 1) == axes(weight, 1) - ds_dbackground = dropdims(sum(ds_dout; dims=1:N_out); dims=ntuple(identity, Val(N_out))) + ds_dbackground = vec(sum!(reshape(ds_dbackground, ntuple(_ -> 1, Val(N_out))..., batch_size), ds_dout)) scale = SVector{N_out, T}(size(ds_dout)[1:end-1]) / T(2) - projection_idxs = SVector{N_out}(ntuple(identity, N_out)) shifts=DiffPointRasterisation.voxel_shifts(Val(N_out)) - ds_dpoints = fill!(similar(points), zero(T)) - ds_drotation = fill!(similar(rotation), zero(T)) - ds_dtranslation = fill!(similar(translation), zero(T)) - ds_dweight = fill!(similar(weight), zero(T)) + ds_dpoints = fill!(ds_dpoints, zero(TP)) + ds_drotation = fill!(ds_drotation, zero(TR)) + ds_dtranslation = fill!(ds_dtranslation, zero(TT)) + ds_dweight = fill!(ds_dweight, zero(TW)) - args = (Val(N_in), ds_dout, points, rotation, translation, weight, shifts, scale, projection_idxs, ds_dpoints, ds_drotation, ds_dtranslation, ds_dweight) + args = (T, ds_dout, points, rotation, translation, weight, shifts, scale, ds_dpoints, ds_drotation, ds_dtranslation, ds_dweight) ndrange = (n_points, batch_size, 2^N_out) @@ -232,4 +243,7 @@ function DiffPointRasterisation._raster_pullback!( ) end + +DiffPointRasterisation.default_ds_dpoints_batched(points::CuVector{<:AbstractVector{TP}}, N_in, batch_size) where {TP<:Number} = similar(points, TP, (N_in, length(points))) + end # module \ No newline at end of file diff --git a/ext/DiffPointRasterisationChainRulesCoreExt.jl b/ext/DiffPointRasterisationChainRulesCoreExt.jl index f5a5c13..3c94989 100644 --- a/ext/DiffPointRasterisationChainRulesCoreExt.jl +++ b/ext/DiffPointRasterisationChainRulesCoreExt.jl @@ -1,40 +1,92 @@ module DiffPointRasterisationChainRulesCoreExt -using DiffPointRasterisation, ChainRulesCore +using DiffPointRasterisation, ChainRulesCore, StaticArrays +# single image function ChainRulesCore.rrule( - ::typeof(raster), + ::typeof(DiffPointRasterisation.raster), grid_size, - args..., -) - out = raster(grid_size, args...) + points::AbstractVector{<:StaticVector{N_in, T}}, + rotation::AbstractMatrix{<:Number}, + translation::AbstractVector{<:Number}, + optional_args... +) where {N_in, T<:Number} + out = raster(grid_size, points, rotation, translation, optional_args...) - raster_pullback(ds_dout) = NoTangent(), NoTangent(), values( - raster_pullback!( + function raster_pullback(ds_dout) + out_pb = raster_pullback!( unthunk(ds_dout), - args..., + points, + rotation, + translation, + optional_args..., ) - )[1:length(args)]... + ds_dpoints = reinterpret(reshape, SVector{N_in, T}, out_pb.points) + return NoTangent(), NoTangent(), ds_dpoints, values(out_pb)[2:3+length(optional_args)]... + end return out, raster_pullback end +ChainRulesCore.rrule( + f::typeof(DiffPointRasterisation.raster), + grid_size, + points::AbstractVector{<:AbstractVector{<:Number}}, + rotation::AbstractMatrix{<:Number}, + translation::AbstractVector{<:Number}, + optional_args... +) = ChainRulesCore.rrule( + f, + grid_size, + DiffPointRasterisation.inner_to_sized(points), + rotation, + translation, + optional_args... +) +# batch of images function ChainRulesCore.rrule( - ::typeof(raster_project), + ::typeof(DiffPointRasterisation.raster), grid_size, - args..., -) - out = raster_project(grid_size, args...) + points::AbstractVector{<:StaticVector{N_in, TP}}, + rotation::AbstractVector{<:StaticMatrix{N_out, N_in, TR}}, + translation::AbstractVector{<:StaticVector{N_out, TT}}, + optional_args... +) where {N_in, N_out, TP<:Number, TR<:Number, TT<:Number} + out = raster(grid_size, points, rotation, translation, optional_args...) - raster_project_pullback(ds_dout) = NoTangent(), NoTangent(), values( - raster_project_pullback!( + function raster_pullback(ds_dout) + out_pb = raster_pullback!( unthunk(ds_dout), - args..., + points, + rotation, + translation, + optional_args..., ) - )[1:length(args)]... + ds_dpoints = reinterpret(reshape, SVector{N_in, TP}, out_pb.points) + L = N_out * N_in + ds_drotation = reinterpret(reshape, SMatrix{N_out, N_in, TR, L}, reshape(out_pb.rotation, L, :)) + ds_dtranslation = reinterpret(reshape, SVector{N_out, TT}, out_pb.translation) + return NoTangent(), NoTangent(), ds_dpoints, ds_drotation, ds_dtranslation, values(out_pb)[4:3+length(optional_args)]... + end - return out, raster_project_pullback + return out, raster_pullback end +ChainRulesCore.rrule( + f::typeof(DiffPointRasterisation.raster), + grid_size, + points::AbstractVector{<:AbstractVector{<:Number}}, + rotation::AbstractVector{<:AbstractMatrix{<:Number}}, + translation::AbstractVector{<:AbstractVector{<:Number}}, + optional_args... +) = ChainRulesCore.rrule( + f, + grid_size, + DiffPointRasterisation.inner_to_sized(points), + DiffPointRasterisation.inner_to_sized(rotation), + DiffPointRasterisation.inner_to_sized(translation), + optional_args... +) + end # module DiffPointRasterisationChainRulesCoreExt \ No newline at end of file diff --git a/logo.gif b/logo.gif index f402e4f..8566630 100644 Binary files a/logo.gif and b/logo.gif differ diff --git a/src/DiffPointRasterisation.jl b/src/DiffPointRasterisation.jl index a4cf651..8c517ee 100644 --- a/src/DiffPointRasterisation.jl +++ b/src/DiffPointRasterisation.jl @@ -12,7 +12,8 @@ using TestItems include("util.jl") include("raster.jl") include("raster_pullback.jl") +include("interface.jl") -export raster, raster!, raster_project, raster_project!, raster_pullback!, raster_project_pullback! +export raster, raster!, raster_pullback! end diff --git a/src/interface.jl b/src/interface.jl new file mode 100644 index 0000000..1a3dc55 --- /dev/null +++ b/src/interface.jl @@ -0,0 +1,492 @@ +""" + raster(grid_size, points, rotation, translation, [background, weight]) + +Interpolate points (multi-) linearly into an Nd-array of size `grid_size`. + +Before `points` are interpolated into the array, each point ``p`` is first +transformed according to +```math +\\hat{p} = R p + t +``` +with `rotation` ``R`` and `translation` ``t``. + +Points ``\\hat{p}`` that fall into the N-dimensional hypercube +with edges spanning from (-1, 1) in each dimension, are interpolated +into the output array. + +The total `weight` of each point is distributed onto the 2^N nearest +pixels/voxels of the output array (according to the closeness of the +voxel center to the coordinates of point ``\\hat{p}``) via +N-linear interpolation. + +`rotation`, `translation`, `background` and `weight` can have an +additional "batch" dimension (as last dimension, and the axis +along this dimension must agree across the four arguments). +In this case, the output will also have that additional dimension. +This is useful if the same scene/points should be rastered from +different perspectives. +""" +function raster end + +############################################### +# Step 1: Allocate output +############################################### + +function raster( + grid_size::Tuple, + args..., +) + eltypes = deep_eltype.(args) + T = promote_type(eltypes...) + points = args[1] + rotation = args[2] + if isa(rotation, AbstractMatrix) + # non-batched + out = similar(points, T, grid_size) + else + # batched + @assert rotation isa AbstractVector{<:AbstractMatrix} + batch_size = length(rotation) + out = similar(points, T, (grid_size..., batch_size)) + end + raster!(out, args...) +end + +deep_eltype(el) = deep_eltype(typeof(el)) +deep_eltype(t::Type) = t +deep_eltype(t::Type{<:AbstractArray}) = deep_eltype(eltype(t)) + + +############################################### +# Step 2: Fill default arguments if necessary +############################################### + +@inline raster!(out::AbstractArray{<:Number}, args::Vararg{Any, 3}) = raster!(out, args..., default_background(args[2])) +@inline raster!(out::AbstractArray{<:Number}, args::Vararg{Any, 4}) = raster!(out, args..., default_weight(args[2])) + +############################################### +# Step 3: Convenience interface for single image: +# Convert arguments for single image to +# length-1 vec of arguments +############################################### + +raster!( + out::AbstractArray{<:Number}, + points::AbstractVector{<:AbstractVector{<:Number}}, + rotation::AbstractMatrix{<:Number}, + translation::AbstractVector{<:Number}, + background::Number, + weight::Number, +) = drop_last_dim( + raster!( + append_singleton_dim(out), + points, + @SVector([rotation]), + @SVector([translation]), + @SVector([background]), + @SVector([weight]), + ) +) + +############################################### +# Step 4: Convert arguments to canonical form, +# i.e. vectors of statically sized arrays +############################################### + +raster!(out::AbstractArray{<:Number}, args::Vararg{AbstractVector, 5}) = raster!(out, inner_to_sized.(args)...) + +############################################### +# Step 5: Error on inconsistent dimensions +############################################### + +# if N_out_rot == N_out_trans this should not be called +# because the actual implementation specializes on N_out +function raster!( + ::AbstractArray{<:Number, N_out}, + ::AbstractVector{<:StaticVector{N_in, <:Number}}, + ::AbstractVector{<:StaticMatrix{N_out_rot, N_in_rot, <:Number}}, + ::AbstractVector{<:StaticVector{N_out_trans, <:Number}}, + ::AbstractVector{<:Number}, + ::AbstractVector{<:Number}, +) where {N_in, N_out, N_in_rot, N_out_rot, N_out_trans} + if N_out_trans != N_out + error("Dimension of translation (got $N_out_trans) and output dimentsion (got $N_out) must agree!") + end + if N_out_rot != N_out + error("Row dimension of rotation (got $N_out_rot) and output dimentsion (got $N_out) must agree!") + end + if N_in_rot != N_in + error("Column dimension of rotation (got $N_in_rot) and points (got $N_in) must agree!") + end + error("Dispatch error. Should not arrive here. Please file a bug.") +end + +# now similar for pullback +""" + raster_pullback!( + ds_dout, args...; + [points, rotation, translation, background, weight] + ) + +Pullback for `raster(grid_size, args...)`/`raster!(out, args...)`. + +Take as input `ds_dout` the sensitivity of some quantity (`s` for "scalar") +to the *output* `out` of the function `out = raster(grid_size, args...)` +(or `out = raster!(out, args...)`), as well as +the exact same arguments `args` that were passed to `raster`/`raster!`, and +return the sensitivities of `s` to the *inputs* `args` of the function +`raster`/`raster!`. + +Optionally, pre-allocated output arrays for each input sensitivity can be +specified as keyword arguments with the name of the original argument to +`raster` as key, and a nd-array as value, where the n-th dimension is the +batch dimension. +For example to provide a pre-allocated array for the sensitivity of `s` to +the `translation` argument of `raster`, do: +`sensitivities = raster_pullback!(ds_dout, args...; translation = [zeros(2) for _ in 1:8])` +for 2-dimensional points and a batch size of 8. +""" +function raster_pullback! end + + +############################################### +# Step 1: Fill default arguments if necessary +############################################### + +@inline raster_pullback!(ds_out::AbstractArray{<:Number}, args::Vararg{Any, 3}; kwargs...) = raster_pullback!(ds_out, args..., default_background(args[2]); kwargs...) +@inline raster_pullback!(ds_dout::AbstractArray{<:Number}, args::Vararg{Any, 4}; kwargs...) = raster_pullback!(ds_dout, args..., default_weight(args[2]); kwargs...) + + +############################################### +# Step 2: Convert arguments to canonical form, +# i.e. vectors of statically sized arrays +############################################### + +# single image +raster_pullback!( + ds_dout::AbstractArray{<:Number}, + points::AbstractVector{<:AbstractVector{<:Number}}, + rotation::AbstractMatrix{<:Number}, + translation::AbstractVector{<:Number}, + background::Number, + weight::Number; + kwargs... +) = raster_pullback!( + ds_dout, + inner_to_sized(points), + to_sized(rotation), + to_sized(translation), + background, + weight; + kwargs... +) + +# batch of images +raster_pullback!(ds_dout::AbstractArray{<:Number}, args::Vararg{AbstractVector, 5}; kwargs...) = raster_pullback!(ds_dout, inner_to_sized.(args)...; kwargs...) + + +############################################### +# Step 3: Allocate output +############################################### + +# single image +raster_pullback!( + ds_dout::AbstractArray{<:Number, N_out}, + inp_points::AbstractVector{<:StaticVector{N_in, T}}, + inp_rotation::StaticMatrix{N_out, N_in, <:Number}, + inp_translation::StaticVector{N_out, <:Number}, + inp_background::Number, + inp_weight::Number; + points::AbstractMatrix{T} = default_ds_dpoints_single(inp_points, N_in), + kwargs... +) where {N_in, N_out, T<:Number} = raster_pullback!( + ds_dout, + inp_points, + inp_rotation, + inp_translation, + inp_background, + inp_weight, + points; + kwargs... +) + +# batch of images +raster_pullback!( + ds_dout::AbstractArray{<:Number}, + inp_points::AbstractVector{<:StaticVector{N_in, TP}}, + inp_rotation::AbstractVector{<:StaticMatrix{N_out, N_in, TR}}, + inp_translation::AbstractVector{<:StaticVector{N_out, TT}}, + inp_background::AbstractVector{TB}, + inp_weight::AbstractVector{TW}; + points::AbstractArray{TP} = default_ds_dpoints_batched(inp_points, N_in, length(inp_rotation)), + rotation::AbstractArray{TR, 3} = similar(inp_rotation, TR, (N_out, N_in, length(inp_rotation))), + translation::AbstractMatrix{TT} = similar(inp_translation, TT, (N_out, length(inp_translation))), + background::AbstractVector{TB} = similar(inp_background), + weight::AbstractVector{TW} = similar(inp_weight), +) where {N_in, N_out, TP<:Number, TR<:Number, TT<:Number, TB<:Number, TW<:Number} = raster_pullback!( + ds_dout, + inp_points, + inp_rotation, + inp_translation, + inp_background, + inp_weight, + points, + rotation, + translation, + background, + weight, +) + + +############################################### +# Step 4: Error on inconsistent dimensions +############################################### + +# single image +raster_pullback!( + ::AbstractArray{<:Number, N_out}, + ::AbstractVector{<:StaticVector{N_in, <:Number}}, + ::StaticMatrix{N_out_rot, N_in_rot, <:Number}, + ::StaticVector{N_out_trans, <:Number}, + ::Number, + ::Number, + ::AbstractMatrix{<:Number}; + kwargs... +) where {N_in, N_out, N_in_rot, N_out_rot, N_out_trans} = error_dimensions( + N_in, + N_out, + N_in_rot, + N_out_rot, + N_out_trans +) + +# batch of images +raster_pullback!( + ::AbstractArray{<:Number, N_out_p1}, + ::AbstractVector{<:StaticVector{N_in, <:Number}}, + ::AbstractVector{<:StaticMatrix{N_out_rot, N_in_rot, <:Number}}, + ::AbstractVector{<:StaticVector{N_out_trans, <:Number}}, + ::AbstractVector{<:Number}, + ::AbstractVector{<:Number}, + ::AbstractArray{<:Number}, + ::AbstractArray{<:Number, 3}, + ::AbstractMatrix{<:Number}, + ::AbstractVector{<:Number}, + ::AbstractVector{<:Number}, +) where {N_in, N_out_p1, N_in_rot, N_out_rot, N_out_trans} = error_dimensions( + N_in, + N_out_p1 - 1, + N_in_rot, + N_out_rot, + N_out_trans +) + +function error_dimensions(N_in, N_out, N_in_rot, N_out_rot, N_out_trans) + if N_out_trans != N_out + error("Dimension of translation (got $N_out_trans) and output dimentsion (got $N_out) must agree!") + end + if N_out_rot != N_out + error("Row dimension of rotation (got $N_out_rot) and output dimentsion (got $N_out) must agree!") + end + if N_in_rot != N_in + error("Column dimension of rotation (got $N_in_rot) and points (got $N_in) must agree!") + end + error("Dispatch error. Should not arrive here. Please file a bug.") +end + +default_background(rotation::AbstractMatrix, T=eltype(rotation)) = zero(T) + +default_background(rotation::AbstractVector{<:AbstractMatrix}, T=eltype(eltype(rotation))) = Zeros(T, length(rotation)) + +default_background(rotation::AbstractArray{_T, 3} where _T, T=eltype(rotation)) = Zeros(T, size(rotation, 3)) + +default_weight(rotation::AbstractMatrix, T=eltype(rotation)) = one(T) + +default_weight(rotation::AbstractVector{<:AbstractMatrix}, T=eltype(eltype(rotation))) = Ones(T, length(rotation)) + +default_weight(rotation::AbstractArray{_T, 3} where _T, T=eltype(rotation)) = Ones(T, size(rotation, 3)) + +default_ds_dpoints_single(points::AbstractVector{<:AbstractVector{TP}}, N_in) where {TP<:Number} = similar(points, TP, (N_in, length(points))) + +default_ds_dpoints_batched(points::AbstractVector{<:AbstractVector{TP}}, N_in, batch_size) where {TP<:Number} = similar(points, TP, (N_in, length(points), min(batch_size, Threads.nthreads()))) + + +@testitem "raster interface" begin + include("../test/data.jl") + + @testset "no projection" begin + local out + @testset "canonical arguments (vec of staticarray)" begin + out = raster( + D.grid_size_3d, + D.points_static, + D.rotations_static, + D.translations_3d_static, + D.backgrounds, + D.weights + ) + end + @testset "reinterpret nd-array as vec-of-array" begin + @test out ≈ raster( + D.grid_size_3d, + D.points_reinterp, + D.rotations_reinterp, + D.translations_3d_reinterp, + D.backgrounds, + D.weights + ) + end + @testset "point as non-static vector" begin + @test out ≈ raster( + D.grid_size_3d, + D.points, + D.rotations_static, + D.translations_3d_static, + D.backgrounds, + D.weights + ) + end + @testset "rotation as non-static matrix" begin + @test out ≈ raster( + D.grid_size_3d, + D.points_static, + D.rotations, + D.translations_3d_static, + D.backgrounds, + D.weights + ) + end + @testset "translation as non-static vector" begin + @test out ≈ raster( + D.grid_size_3d, + D.points_static, + D.rotations_static, + D.translations_3d, + D.backgrounds, + D.weights + ) + end + @testset "all as non-static array" begin + @test out ≈ raster( + D.grid_size_3d, + D.points, + D.rotations, + D.translations_3d, + D.backgrounds, + D.weights + ) + end + out = raster( + D.grid_size_3d, + D.points_static, + D.rotations_static, + D.translations_3d_static, + zeros(D.batch_size), + ones(D.batch_size), + ) + @testset "default argmuments canonical" begin + @test out ≈ raster( + D.grid_size_3d, + D.points_static, + D.rotations_static, + D.translations_3d_static, + ) + end + @testset "default arguments all as non-static array" begin + @test out ≈ raster( + D.grid_size_3d, + D.points, + D.rotations, + D.translations_3d, + ) + end + end + + @testset "projection" begin + local out + @testset "canonical arguments (vec of staticarray)" begin + out = raster( + D.grid_size_2d, + D.points_static, + D.projections_static, + D.translations_2d_static, + D.backgrounds, + D.weights + ) + end + @testset "reinterpret nd-array as vec-of-array" begin + @test out ≈ raster( + D.grid_size_2d, + D.points_reinterp, + D.projections_reinterp, + D.translations_2d_reinterp, + D.backgrounds, + D.weights + ) + end + @testset "point as non-static vector" begin + @test out ≈ raster( + D.grid_size_2d, + D.points, + D.projections_static, + D.translations_2d_static, + D.backgrounds, + D.weights + ) + end + @testset "projection as non-static matrix" begin + @test out ≈ raster( + D.grid_size_2d, + D.points_static, + D.projections, + D.translations_2d_static, + D.backgrounds, + D.weights + ) + end + @testset "translation as non-static vector" begin + @test out ≈ raster( + D.grid_size_2d, + D.points_static, + D.projections_static, + D.translations_2d, + D.backgrounds, + D.weights + ) + end + @testset "all as non-static array" begin + @test out ≈ raster( + D.grid_size_2d, + D.points_static, + D.projections, + D.translations_2d, + D.backgrounds, + D.weights + ) + end + out = raster( + D.grid_size_2d, + D.points_static, + D.projections_static, + D.translations_2d_static, + zeros(D.batch_size), + ones(D.batch_size), + ) + @testset "default argmuments canonical" begin + @test out ≈ raster( + D.grid_size_2d, + D.points_static, + D.projections_static, + D.translations_2d_static, + ) + end + @testset "default arguments all as non-static array" begin + @test out ≈ raster( + D.grid_size_2d, + D.points, + D.projections, + D.translations_2d, + ) + end + end +end \ No newline at end of file diff --git a/src/raster.jl b/src/raster.jl index 57fcc24..d4ea71a 100644 --- a/src/raster.jl +++ b/src/raster.jl @@ -1,211 +1,25 @@ -""" - raster(grid_size, points, rotation, translation, [background, weight]) - -Interpolate points (multi-) linearly into an Nd-array of size `grid_size`. - -Before `points` are interpolated into the array, each point ``p`` is first -transformed according to -```math -\\hat{p} = R p + t -``` -with `rotation` ``R`` and `translation` ``t``. - -Points ``\\hat{p}`` that fall into the N-dimensional hypercube -with edges spanning from (-1, 1) in each dimension, are interpolated -into the output array. - -The total `weight` of each point is distributed onto the 2^N nearest -pixels/voxels of the output array (according to the closeness of the -voxel center to the coordinates of point ``\\hat{p}``) via -N-linear interpolation. - -`rotation`, `translation`, `background` and `weight` can have an -additional "batch" dimension (as last dimension, and the axis -along this dimension must agree across the four arguments). -In this case, the output will also have that additional dimension. -This is useful if the same scene/points should be rastered from -different perspectives. -""" -function raster end - -# single image -raster( - grid_size, - points::AbstractMatrix{T}, - rotation::AbstractMatrix{<:Number}, - translation, - background=zero(T), - weight=one(T), -) where {T} = raster!( - similar(points, grid_size), - points, - rotation, - translation, - background, - weight, -) - -# batched version -raster( - grid_size, - points::AbstractMatrix{T}, - rotation::AbstractArray{<:Number, 3}, - translation, - background=Zeros(T, size(rotation, 3)), - weight=Ones(T, size(rotation, 3)), -) where {T} = raster!( - similar(points, (grid_size..., size(rotation, 3))), - points, - rotation, - translation, - background, - weight, -) - -""" - raster_project(grid_size, points, rotation, translation, [background, weight]) - -Interpolate N-dimensional points (multi-) linearly into an N-1 dimensional-array -of size `grid_size`. - -Before `points` are interpolated into the array, each point ``p`` is first -transformed according to -```math -\\hat{p} = P R p + t -``` - -The remaining behaviour is the same as for `raster` -""" -function raster_project end - -# single image -raster_project( - grid_size, - points::AbstractMatrix{T}, - rotation::AbstractMatrix{<:Number}, - translation, - background=zero(T), - weight=one(T), -) where {T} = raster_project!( - similar(points, grid_size), - points, - rotation, - translation, - background, - weight, -) - -# batched version -raster_project( - grid_size, - points::AbstractMatrix{T}, - rotation::AbstractArray{<:Number, 3}, - translation, - background=Zeros(T, size(rotation, 3)), - weight=Ones(T, size(rotation, 3)), -) where {T} = raster_project!( - similar(points, (grid_size..., size(rotation, 3))), - points, - rotation, - translation, - background, - weight, -) - - -""" - raster!(out, points, rotation, translation, [background, weight]) - -Inplace version of `raster`. +############################################### +# Step 6: Actual implementation +############################################### -Write output into `out` and return `out`. -""" -raster!( - out::AbstractArray{T, N_out}, - points, - rotation::AbstractArray{<:Number, N_rotation}, - translation, - background=N_rotation == 2 ? zero(T) : Zeros(T, size(rotation, 3)), - weight=N_rotation == 2 ? one(T) : Ones(T, size(rotation, 3)), -) where {N_out, N_rotation, T<:Number} = _raster!( - Val(N_out - (N_rotation - 2)), - out, - points, - rotation, - translation, - background, - weight, -) - - -""" - raster_project!(out, points, rotation, translation, [background, weight]) - -Inplace version of `raster_project`. - -Write output into `out` and return `out`. -""" -raster_project!( - out::AbstractArray{T, N_out}, - points, - rotation::AbstractArray{<:Number, N_rotation}, - translation, - background=N_rotation == 2 ? zero(T) : Zeros(T, size(rotation, 3)), - weight=N_rotation == 2 ? one(T) : Ones(T, size(rotation, 3)), -) where {N_out, N_rotation, T<:Number} = _raster!( - Val(N_out + 1- (N_rotation - 2)), - out, - points, - rotation, - translation, - background, - weight, -) - - -_raster!( - ::Val{N_in}, - out::AbstractArray{T, N_out}, - points::AbstractMatrix{<:Number}, - rotation::AbstractMatrix{<:Number}, - translation::AbstractVector{<:Number}, - background::Number, - weight::Number, -) where {N_in, N_out, T<:Number} = drop_last_dim( - _raster!( - Val(N_in), - append_singleton_dim(out), - points, - append_singleton_dim(rotation), - append_singleton_dim(translation), - fill!(similar(out, 1), background), - fill!(similar(out, 1), weight), - ) -) - - -function _raster!( - ::Val{N_in}, +function raster!( out::AbstractArray{T, N_out_p1}, - points::AbstractMatrix{<:Number}, - rotation::AbstractArray{<:Number, 3}, - translation::AbstractMatrix{<:Number}, + points::AbstractVector{<:StaticVector{N_in, <:Number}}, + rotation::AbstractVector{<:StaticMatrix{N_out, N_in, <:Number}}, + translation::AbstractVector{<:StaticVector{N_out, <:Number}}, background::AbstractVector{<:Number}, weight::AbstractVector{<:Number}, -) where {T<:Number, N_in, N_out_p1} - N_out = N_out_p1 - 1 +) where {T<:Number, N_in, N_out, N_out_p1} + @argcheck N_out == N_out_p1 - 1 DimensionMismatch out_batch_dim = ndims(out) batch_size = size(out, out_batch_dim) - @argcheck size(points, 1) == size(rotation, 1) == size(rotation, 2) == N_in - @argcheck batch_size == size(rotation, 3) == size(translation, 2) == size(background, 1) == size(weight, 1) - @argcheck size(translation, 1) == N_out - n_points = size(points, 2) + @argcheck batch_size == length(rotation) == length(translation) == length(background) == length(weight) DimensionMismatch + n_points = length(points) scale = SVector{N_out, T}(size(out)[1:end-1]) / T(2) - projection_idxs = SVector{N_out}(ntuple(identity, N_out)) shifts=voxel_shifts(Val(N_out)) - out .= reshape(background, ntuple(_ -> 1, N_out)..., length(background)) - args = (Val(N_in), out, points, rotation, translation, weight, shifts, scale, projection_idxs) + out .= reshape(background, ntuple(_ -> 1, Val(N_out))..., length(background)) + args = (out, points, rotation, translation, weight, shifts, scale) backend = get_backend(out) ndrange = (2^N_out, n_points, batch_size) workgroup_size = 1024 @@ -213,13 +27,12 @@ function _raster!( out end -@kernel function raster_kernel!(::Val{N_in}, out::AbstractArray{T, N_out_p1}, points, rotations, translations, weights, shifts, scale, projection_idxs) where {T, N_in, N_out_p1} - N_out = N_out_p1 - 1 # dimensionality of output, without batch dimension +@kernel function raster_kernel!(out::AbstractArray{T}, points, rotations, translations::AbstractVector{<:StaticVector{N_out}}, weights, shifts, scale) where {T, N_out} neighbor_voxel_id, point_idx, batch_idx = @index(Global, NTuple) - point = @inbounds SVector{N_in, T}(@view points[:, point_idx]) - rotation = @inbounds SMatrix{N_in, N_in, T}(@view rotations[:, :, batch_idx]) - translation = @inbounds SVector{N_out, T}(@view translations[:, batch_idx]) + point = @inbounds points[point_idx] + rotation = @inbounds rotations[batch_idx] + translation = @inbounds translations[batch_idx] weight = @inbounds weights[batch_idx] shift = @inbounds shifts[neighbor_voxel_id] origin = (-@SVector ones(T, N_out)) - translation @@ -227,21 +40,20 @@ end coord_reference_voxel, deltas = reference_coordinate_and_deltas( point, rotation, - projection_idxs, origin, scale, ) voxel_idx = CartesianIndex(CartesianIndex(Tuple(coord_reference_voxel)) + CartesianIndex(shift), batch_idx) if voxel_idx in CartesianIndices(out) - val = voxel_weight(deltas, shift, projection_idxs, weight) + val = voxel_weight(deltas, shift, weight) @inbounds Atomix.@atomic out[voxel_idx] += val end nothing end """ - reference_coordinate_and_deltas(point, rotation, projection_idxs, origin, scale) + reference_coordinate_and_deltas(point, rotation, origin, scale) Return - The cartesian coordinate of the voxel of an N-dimensional rectangular @@ -260,12 +72,10 @@ Before `point` is discretized into this grid, it is first translated by @inline function reference_coordinate_and_deltas( point::AbstractVector{T}, rotation, - projection_idxs, origin, scale, ) where {T} - rotated_point = rotation * point - projected_point = @inbounds rotated_point[projection_idxs] + projected_point = rotation * point # coordinate of transformed point in output coordinate system # which is defined by the (integer) coordinates of the pixels/voxels # in the output array. @@ -280,9 +90,9 @@ Before `point` is discretized into this grid, it is first translated by coord_reference_voxel, deltas end -@inline function voxel_weight(deltas, shift, projection_idxs, point_weight) +@inline function voxel_weight(deltas, shift::NTuple{N, Int}, point_weight) where {N} lower_upper = mod1.(shift, 2) - delta_idxs = CartesianIndex.(projection_idxs, lower_upper) + delta_idxs = SVector{N}(CartesianIndex.(ntuple(identity, Val(N)), lower_upper)) val = prod(@inbounds @view deltas[delta_idxs]) * point_weight val end @@ -291,15 +101,15 @@ end using Rotations grid_size = (5, 5) - points_single_center = zeros(2, 1) - points_single_1pix_right = [0.0;0.4;;] - points_single_1pix_up = [-0.4;0.0;;] - points_single_1pix_left = [0.0;-0.4;;] - points_single_1pix_down = [0.4;0.0;;] - points_single_halfpix_down = [0.2;0.0;;] - points_single_halfpix_down_and_right = [0.2;0.2;;] + points_single_center = [zeros(2)] + points_single_1pix_right = [[0.0, 0.4]] + points_single_1pix_up = [[-0.4, 0.0]] + points_single_1pix_left = [[0.0, -0.4]] + points_single_1pix_down = [[0.4, 0.0]] + points_single_halfpix_down = [[0.2, 0.0]] + points_single_halfpix_down_and_right = [[0.2, 0.2]] points_four_cross = reduce( - hcat, + vcat, [ points_single_1pix_right, points_single_1pix_up, points_single_1pix_left, points_single_1pix_down ] @@ -396,22 +206,30 @@ end @testitem "raster inference and allocations" begin - using BenchmarkTools, CUDA + using BenchmarkTools, CUDA, StaticArrays include("../test/data.jl") # check type stability + # single image - @inferred DiffPointRasterisation.raster(D.grid_size_3d, D.points, D.rotation, D.translation_3d) - @inferred DiffPointRasterisation.raster_project(D.grid_size_2d, D.points, D.rotation, D.translation_2d) - # batched - @inferred DiffPointRasterisation.raster(D.grid_size_3d, D.points, D.rotations, D.translations_3d) - @inferred DiffPointRasterisation.raster_project(D.grid_size_2d, D.points, D.rotations, D.translations_2d) + @inferred DiffPointRasterisation.raster(D.grid_size_3d, D.points_static, D.rotation, D.translation_3d) + @inferred DiffPointRasterisation.raster(D.grid_size_2d, D.points_static, D.projection, D.translation_2d) + + # batched canonical + @inferred DiffPointRasterisation.raster(D.grid_size_3d, D.points_static, D.rotations_static, D.translations_3d_static) + @inferred DiffPointRasterisation.raster(D.grid_size_2d, D.points_static, D.projections_static, D.translations_2d_static) + + # batched reinterpret reshape + @inferred DiffPointRasterisation.raster(D.grid_size_3d, D.points_reinterp, D.rotations_reinterp, D.translations_3d_reinterp) + @inferred DiffPointRasterisation.raster(D.grid_size_2d, D.points_reinterp, D.projections_reinterp, D.translations_2d_reinterp) if CUDA.functional() - @inferred DiffPointRasterisation.raster(D.grid_size_3d, cu(D.points), cu(D.rotation), cu(D.translation_3d)) - @inferred DiffPointRasterisation.raster_project(D.grid_size_2d, cu(D.points), cu(D.rotation), cu(D.translation_2d)) + # single image + @inferred DiffPointRasterisation.raster(D.grid_size_3d, cu(D.points_static), cu(D.rotation), cu(D.translation_3d)) + @inferred DiffPointRasterisation.raster(D.grid_size_2d, cu(D.points_static), cu(D.projection), cu(D.translation_2d)) + # batched - @inferred DiffPointRasterisation.raster(D.grid_size_3d, cu(D.points), cu(D.rotations), cu(D.translations_3d)) - @inferred DiffPointRasterisation.raster_project(D.grid_size_2d, cu(D.points), cu(D.rotations), cu(D.translations_2d)) + @inferred DiffPointRasterisation.raster(D.grid_size_3d, cu(D.points_static), cu(D.rotations_static), cu(D.translations_3d_static)) + @inferred DiffPointRasterisation.raster(D.grid_size_2d, cu(D.points_static), cu(D.projections_static), cu(D.translations_2d_static)) end # Ideally the sinlge image (non batched) case would be allocation-free. @@ -419,9 +237,9 @@ end # set test to broken for now. out_3d = Array{Float64, 3}(undef, D.grid_size_3d...) out_2d = Array{Float64, 2}(undef, D.grid_size_2d...) - allocations = @ballocated DiffPointRasterisation.raster!($out_3d, $D.points, $D.rotation, $D.translation_3d) evals=1 samples=1 + allocations = @ballocated DiffPointRasterisation.raster!($out_3d, $D.points_static, $D.rotation, $D.translation_3d) evals=1 samples=1 @test allocations == 0 broken=true - allocations = @ballocated DiffPointRasterisation.raster_project!($out_2d, $D.points, $D.rotation, $D.translation_2d) evals=1 samples=1 + allocations = @ballocated DiffPointRasterisation.raster!($out_2d, $D.points_static, $D.projection, $D.translation_2d) evals=1 samples=1 @test allocations == 0 broken=true end @@ -433,7 +251,7 @@ end out_3d = zeros(D.grid_size_3d..., D.batch_size) out_3d_batched = zeros(D.grid_size_3d..., D.batch_size) - for (out_i, args...) in zip(eachslice(out_3d, dims=4), eachslice(D.rotations, dims=3), eachcol(D.translations_3d), D.backgrounds, D.weights) + for (out_i, args...) in zip(eachslice(out_3d, dims=4), D.rotations, D.translations_3d, D.backgrounds, D.weights) raster!(out_i, D.more_points, args...) end @@ -443,11 +261,11 @@ end out_2d = zeros(D.grid_size_2d..., D.batch_size) out_2d_batched = zeros(D.grid_size_2d..., D.batch_size) - for (out_i, args...) in zip(eachslice(out_2d, dims=3), eachslice(D.rotations, dims=3), eachcol(D.translations_2d), D.backgrounds, D.weights) - DiffPointRasterisation.raster_project!(out_i, D.more_points, args...) + for (out_i, args...) in zip(eachslice(out_2d, dims=3), D.projections, D.translations_2d, D.backgrounds, D.weights) + DiffPointRasterisation.raster!(out_i, D.more_points, args...) end - DiffPointRasterisation.raster_project!(out_2d_batched, D.more_points, D.rotations, D.translations_2d, D.backgrounds, D.weights) + DiffPointRasterisation.raster!(out_2d_batched, D.more_points, D.projections, D.translations_2d, D.backgrounds, D.weights) @test out_2d_batched ≈ out_2d end \ No newline at end of file diff --git a/src/raster_pullback.jl b/src/raster_pullback.jl index 7dcf462..316fd05 100644 --- a/src/raster_pullback.jl +++ b/src/raster_pullback.jl @@ -1,109 +1,37 @@ -""" - raster_pullback!( - ds_dout, points, rotation, translation, [background, weight]; - [ds_dpoints, ds_drotation, ds_dtranslation, ds_dbackground, ds_dweight] - ) - -Pullback for `raster(...)`/`raster!(...)`. - -Take as input `ds_dout` the sensitivity of some quantity (`s` for "scalar") -to the *output* `out` of the function `raster(args...)`, as well as -the exact same arguments `args` that were passed to `raster`, and -return the sensitivities of `s` to the *inputs* `args` of the function -`raster()`/`raster!()`. - -Optionally, pre-allocated output arrays for each input sensitivity can be -specified as `ds_d\$INPUT_NAME`, e.g. `ds_dtranslation = [zeros(2) for _ in 1:8]` -for 2-dimensional points and a batch size of 8. -""" -raster_pullback!( - ds_dout::AbstractArray{<:Number, N_out}, - points, - rotation::AbstractArray{<:Number, N_rotation}, - args...; - prealloc... -) where {N_out, N_rotation} = _raster_pullback!( - Val(N_out - (N_rotation - 2)), - ds_dout, - points, - rotation, - args...; - prealloc... -) - - - -""" - raster_project_pullback!( - ds_dout, points, rotation, translation, [background, weight]; - [ds_dpoints, ds_drotation, ds_dtranslation, ds_dbackground, ds_dweight] - ) - -Pullback for `raster_project(...)`/`raster_project!(...)`. - -Take as input `ds_dout` the sensitivity of some quantity (`s` for "scalar") -to the *output* `out` of the function `raster_project(args...)`, as well as -the exact same arguments `args` that were passed to `raster_project`, and -return the sensitivities of `s` to the *inputs* `args` of the function -`raster_project()`/`raster_project!()`. - -Optionally, pre-allocated output arrays for each input sensitivity can be -specified as `ds_d\$INPUT_NAME`, e.g. `ds_dtranslation = [zeros(2) for _ in 1:8]` -for 3-dimensional points and a batch size of 8. -""" -raster_project_pullback!( +# single image +function raster_pullback!( ds_dout::AbstractArray{<:Number, N_out}, - points, - rotation::AbstractArray{<:Number, N_rotation}, - args...; - prealloc... -) where {N_out, N_rotation} = _raster_pullback!( - Val(N_out + 1 - (N_rotation - 2)), - ds_dout, - points, - rotation, - args...; - prealloc... -) - - -function _raster_pullback!( - ::Val{N_in}, - ds_dout::AbstractArray{T, N_out}, - points::AbstractMatrix{<:Number}, - rotation::AbstractMatrix{<:Number}, - translation::AbstractVector{<:Number}, - background::Number=zero(T), - weight::Number=one(T); - accumulate_prealloc=false, - prealloc..., -) where {N_in, N_out, T<:Number} + points::AbstractVector{<:StaticVector{N_in, <:Number}}, + rotation::StaticMatrix{N_out, N_in, TR}, + translation::StaticVector{N_out, TT}, + background::Number, + weight::TW, + ds_dpoints::AbstractMatrix{TP}; + accumulate_ds_dpoints=false, +) where {N_in, N_out, TP<:Number, TR<:Number, TT<:Number, TW<:Number} + T = promote_type(eltype(ds_dout), TP, TR, TT, TW) + @argcheck size(ds_dpoints, 1) == N_in # The strategy followed here is to redo some of the calculations # made in the forward pass instead of caching them in the forward # pass and reusing them here. - args = (;points,) - @unpack ds_dpoints = _pullback_alloc_serial(args, NamedTuple(prealloc)) - accumulate_prealloc || fill!(ds_dpoints, zero(T)) + accumulate_ds_dpoints || fill!(ds_dpoints, zero(TP)) - rotation = convert(SMatrix{N_in, N_in}, rotation) - origin = (-@SVector ones(T, N_out)) - translation - projection_idxs = SVector(ntuple(identity, N_out)) + origin = (-@SVector ones(TT, N_out)) - translation scale = SVector{N_out, T}(size(ds_dout)) / 2 shifts=voxel_shifts(Val(N_out)) all_density_idxs = CartesianIndices(ds_dout) # initialize some output for accumulation - ds_dtranslation = @SVector zeros(T, N_out) - ds_dprojection_rotation = @SMatrix zeros(T, N_out, N_in) - ds_dweight = zero(T) + ds_dtranslation = @SVector zeros(TT, N_out) + ds_drotation = @SMatrix zeros(TR, N_out, N_in) + ds_dweight = zero(TW) # loop over points - for (pt_idx, point) in enumerate(eachcol(points)) - point = SVector{N_in, T}(point) + for (pt_idx, point) in enumerate(points) + point = SVector{N_in, TP}(point) coord_reference_voxel, deltas = reference_coordinate_and_deltas( point, rotation, - projection_idxs, origin, scale, ) @@ -114,93 +42,186 @@ function _raster_pullback!( voxel_idx = CartesianIndex(Tuple(coord_reference_voxel)) + CartesianIndex(shift) (voxel_idx in all_density_idxs) || continue + ds_dout_i = ds_dout[voxel_idx] + ds_dweight += voxel_weight( deltas, shift, - projection_idxs, - ds_dout[voxel_idx], + ds_dout_i, ) - factor = ds_dout[voxel_idx] * weight + factor = ds_dout_i * weight # loop over dimensions of point ds_dcoord += SVector(factor .* ntuple(n -> interpolation_weight(n, N_out, deltas, shift), N_out)) end scaled = ds_dcoord .* scale ds_dtranslation += scaled - ds_dprojection_rotation += scaled * point' - ds_dpoint = @view(rotation[projection_idxs, :])' * scaled + ds_drotation += scaled * point' + ds_dpoint = rotation' * scaled @view(ds_dpoints[:, pt_idx]) .+= ds_dpoint + end - ds_drotation = N_out == N_in ? ds_dprojection_rotation : vcat(ds_dprojection_rotation, @SMatrix zeros(T, 1, N_in)) return (; points=ds_dpoints, rotation=ds_drotation, translation=ds_dtranslation, background=sum(ds_dout), weight=ds_dweight) end +# batch of images +function raster_pullback!( + ds_dout::AbstractArray{<:Number, N_out_p1}, + points::AbstractVector{<:StaticVector{N_in, <:Number}}, + rotation::AbstractVector{<:StaticMatrix{N_out, N_in, <:Number}}, + translation::AbstractVector{<:StaticVector{N_out, <:Number}}, + background::AbstractVector{<:Number}, + weight::AbstractVector{<:Number}, + ds_dpoints::AbstractArray{<:Number, 3}, + ds_drotation::AbstractArray{<:Number, 3}, + ds_dtranslation::AbstractMatrix{<:Number}, + ds_dbackground::AbstractVector{<:Number}, + ds_dweight::AbstractVector{<:Number}, +) where {N_in, N_out, N_out_p1} + batch_axis = axes(ds_dout, N_out_p1) + @argcheck N_out == N_out_p1 - 1 + @argcheck batch_axis == axes(rotation, 1) == axes(translation, 1) == axes(background, 1) == axes(weight, 1) + @argcheck batch_axis == axes(ds_drotation, 3) == axes(ds_dtranslation, 2) == axes(ds_dbackground, 1) == axes(ds_dweight, 1) + fill!(ds_dpoints, zero(eltype(ds_dpoints))) + + n_threads = size(ds_dpoints, 3) + + Threads.@threads for (idxs, ichunk) in chunks(batch_axis, n_threads) + for i in idxs + args_i = (selectdim(ds_dout, N_out_p1, i), points, rotation[i], translation[i], background[i], weight[i]) + result_i = raster_pullback!(args_i..., view(ds_dpoints, :, :, ichunk); accumulate_ds_dpoints=true) + ds_drotation[:, :, i] .= result_i.rotation + ds_dtranslation[:, i] = result_i.translation + ds_dbackground[i] = result_i.background + ds_dweight[i] = result_i.weight + end + end + return (; points=dropdims(sum(ds_dpoints; dims=3); dims=3), rotation=ds_drotation, translation=ds_dtranslation, background=ds_dbackground, weight=ds_dweight) +end + + +prefix(s::Symbol) = Symbol("ds_d" * string(s)) + + +_pullback_alloc_serial(args, prealloc) = _pullback_alloc_points_serial(args, prealloc) + +function _pullback_alloc_threaded(args, prealloc, n) + points = _pullback_alloc_points_threaded(args, prealloc, n) + other_args = Base.structdiff(args, NamedTuple{(:points,)}) + others = _pullback_alloc_others_threaded(other_args, prealloc) + merge(points, others) +end + +function _pullback_alloc_others_threaded(need_allocation, ::NamedTuple{}) + keys_alloc = prefix.(keys(need_allocation)) + vals = similar.(values(need_allocation)) + NamedTuple{keys_alloc}(vals) +end + +function _pullback_alloc_others_threaded(args, prealloc) + # it's a bit tricky to get this type-stable, but the following does the trick + need_allocation = Base.structdiff(args, prealloc) + keys_alloc = prefix.(keys(need_allocation)) + vals = similar.(values(need_allocation)) + alloc = NamedTuple{keys_alloc}(vals) + keys_prealloc = prefix.(keys(prealloc)) + prefixed_prealloc = NamedTuple{keys_prealloc}(values(prealloc)) + merge(prefixed_prealloc, alloc) +end + +_pullback_alloc_points_serial(args, prealloc) = (;ds_dpoints = get(() -> similar(args.points), prealloc, :points)) + +_pullback_alloc_points_threaded(args, prealloc, n) = (;ds_dpoints = get(() -> similar(args.points, (size(args.points)..., n)), prealloc, :points)) + + +function interpolation_weight(n, N, deltas, shift) + val = @inbounds shift[n] == 1 ? one(eltype(deltas)) : -one(eltype(deltas)) + # loop over other dimensions + @inbounds for other_n in 1:N + if n == other_n + continue + end + val *= deltas[other_n, mod1(shift[other_n], 2)] + end + val +end + @testitem "raster_pullback! inference and allocations" begin - using BenchmarkTools, CUDA + using BenchmarkTools, CUDA, Adapt include("../test/data.jl") ds_dout_3d = randn(D.grid_size_3d) ds_dout_3d_batched = randn(D.grid_size_3d..., D.batch_size) ds_dout_2d = randn(D.grid_size_2d) ds_dout_2d_batched = randn(D.grid_size_2d..., D.batch_size) - ds_dpoints = similar(D.points) + + ds_dpoints = similar(D.points_array) + ds_dpoints_batched = similar(D.points_array, (size(D.points_array)..., Threads.nthreads())) + ds_drotations = similar(D.rotations_array) + ds_dprojections = similar(D.projections_array) + ds_dtranslations_3d = similar(D.translations_3d_array) + ds_dtranslations_2d = similar(D.translations_2d_array) + ds_dbackgrounds = similar(D.backgrounds) + ds_dweights = similar(D.weights) + + args_batched_3d = ( + ds_dout_3d_batched, + D.points_static, + D.rotations_static, + D.translations_3d_static, + D.backgrounds, + D.weights, + ds_dpoints_batched, + ds_drotations, + ds_dtranslations_3d, + ds_dbackgrounds, + ds_dweights + ) + args_batched_2d = ( + ds_dout_2d_batched, + D.points_static, + D.projections_static, + D.translations_2d_static, + D.backgrounds, + D.weights, + ds_dpoints_batched, + ds_dprojections, + ds_dtranslations_2d, + ds_dbackgrounds, + ds_dweights + ) + + function to_cuda(args) + args_cu = adapt(CuArray, args) + Base.setindex(args_cu, args_cu[7][:, :, 1], 7) # ds_dpoint without batch dim + end # check type stability # single image - @inferred DiffPointRasterisation.raster_pullback!(ds_dout_3d, D.points, D.rotation, D.translation_3d) - @inferred DiffPointRasterisation.raster_project_pullback!(ds_dout_2d, D.points, D.rotation, D.translation_2d) + @inferred DiffPointRasterisation.raster_pullback!(ds_dout_3d, D.points_static, D.rotation, D.translation_3d, D.background, D.weight, ds_dpoints) + @inferred DiffPointRasterisation.raster_pullback!(ds_dout_2d, D.points_static, D.projection, D.translation_2d, D.background, D.weight, ds_dpoints) # batched - @inferred DiffPointRasterisation.raster_pullback!(ds_dout_3d_batched, D.points, D.rotations, D.translations_3d) - @inferred DiffPointRasterisation.raster_project_pullback!(ds_dout_2d_batched, D.points, D.rotations, D.translations_2d) + @inferred DiffPointRasterisation.raster_pullback!(args_batched_3d...) + @inferred DiffPointRasterisation.raster_pullback!(args_batched_2d...) if CUDA.functional() - @inferred DiffPointRasterisation.raster_pullback!(cu(ds_dout_3d_batched), cu(D.points), cu(D.rotations), cu(D.translations_3d)) - @inferred DiffPointRasterisation.raster_project_pullback!(cu(ds_dout_2d_batched), cu(D.points), cu(D.rotations), cu(D.translations_2d)) + cu_args_3d = to_cuda(args_batched_3d) + @inferred DiffPointRasterisation.raster_pullback!(cu_args_3d...) + cu_args_2d = to_cuda(args_batched_2d) + @inferred DiffPointRasterisation.raster_pullback!(cu_args_2d...) end # check that single-imge pullback is allocation-free allocations = @ballocated DiffPointRasterisation.raster_pullback!( $ds_dout_3d, - $(D.points), + $(D.points_static), $(D.rotation), $(D.translation_3d), $(D.background), - $(D.weight); - points=$ds_dpoints, + $(D.weight), + $ds_dpoints, ) evals=1 samples=1 @test allocations == 0 end -function _raster_pullback!( - ::Val{N_in}, - ds_dout::AbstractArray{T}, - points::AbstractMatrix{<:Number}, - rotation::AbstractArray{<:Number, 3}, - translation::AbstractMatrix{<:Number}, - # TODO: for some reason type inference fails if the following - # two arrays are FillArrays... - background::AbstractVector{<:Number}=zeros(T, size(rotation, 3)), - weight::AbstractVector{<:Number}=ones(T, size(rotation, 3)); - prealloc... -) where {N_in, T<:Number} - out_batch_dim = ndims(ds_dout) - batch_axis = axes(ds_dout, out_batch_dim) - @argcheck axes(ds_dout, out_batch_dim) == axes(rotation, 3) == axes(translation, 2) == axes(background, 1) == axes(weight, 1) - args = (;points, rotation, translation, background, weight) - @unpack ds_dpoints, ds_drotation, ds_dtranslation, ds_dbackground, ds_dweight = _pullback_alloc_threaded(args, NamedTuple(prealloc), min(length(batch_axis), Threads.nthreads())) - @assert ndims(ds_dpoints) == 3 - fill!(ds_dpoints, zero(T)) - - Threads.@threads for (idxs, ichunk) in chunks(batch_axis, size(ds_dpoints, 3)) - for i in idxs - args_i = (selectdim(ds_dout, out_batch_dim, i), points, view(rotation, :, :, i), view(translation, :, i), background[i], weight[i]) - result_i = _raster_pullback!(Val(N_in), args_i...; accumulate_prealloc=true, points=view(ds_dpoints, :, :, ichunk)) - ds_drotation[:, :, i] .= result_i.rotation - ds_dtranslation[:, i] = result_i.translation - ds_dbackground[i] = result_i.background - ds_dweight[i] = result_i.weight - end - end - return (; points=dropdims(sum(ds_dpoints; dims=3); dims=3), rotation=ds_drotation, translation=ds_dtranslation, background=ds_dbackground, weight=ds_dweight) -end @testitem "raster_pullback! threaded" begin include("../test/data.jl") @@ -211,7 +232,7 @@ end ds_dpoints = Matrix{Float64}[] for i in 1:D.batch_size - ds_dargs_i = @views raster_pullback!(ds_dout[:, :, :, i], D.more_points, D.rotations[:, :, i], D.translations_3d[:, i], D.backgrounds[i], D.weights[i]) + ds_dargs_i = @views raster_pullback!(ds_dout[:, :, :, i], D.more_points, D.rotations[i], D.translations_3d[i], D.backgrounds[i], D.weights[i]) push!(ds_dpoints, ds_dargs_i.points) @views begin @test ds_dargs_threaded.rotation[:, :, i] ≈ ds_dargs_i.rotation @@ -221,18 +242,15 @@ end end end @test ds_dargs_threaded.points ≈ sum(ds_dpoints) -end -@testitem "raster_project_pullback! threaded" begin - include("../test/data.jl") ds_dout = zeros(D.grid_size_2d..., D.batch_size) - ds_dargs_threaded = DiffPointRasterisation.raster_project_pullback!(ds_dout, D.more_points, D.rotations, D.translations_2d, D.backgrounds, D.weights) + ds_dargs_threaded = DiffPointRasterisation.raster_pullback!(ds_dout, D.more_points, D.projections, D.translations_2d, D.backgrounds, D.weights) ds_dpoints = Matrix{Float64}[] for i in 1:D.batch_size - ds_dargs_i = @views raster_project_pullback!(ds_dout[:, :, i], D.more_points, D.rotations[:, :, i], D.translations_2d[:, i], D.backgrounds[i], D.weights[i]) + ds_dargs_i = @views raster_pullback!(ds_dout[:, :, i], D.more_points, D.projections[i], D.translations_2d[i], D.backgrounds[i], D.weights[i]) push!(ds_dpoints, ds_dargs_i.points) @views begin @test ds_dargs_threaded.rotation[:, :, i] ≈ ds_dargs_i.rotation @@ -242,51 +260,4 @@ end end end @test ds_dargs_threaded.points ≈ sum(ds_dpoints) -end - - -prefix(s::Symbol) = Symbol("ds_d" * string(s)) - - -_pullback_alloc_serial(args, prealloc) = _pullback_alloc_points_serial(args, prealloc) - -function _pullback_alloc_threaded(args, prealloc, n) - points = _pullback_alloc_points_threaded(args, prealloc, n) - other_args = Base.structdiff(args, NamedTuple{(:points,)}) - others = _pullback_alloc_others_threaded(other_args, prealloc) - merge(points, others) -end - -function _pullback_alloc_others_threaded(need_allocation, ::NamedTuple{}) - keys_alloc = prefix.(keys(need_allocation)) - vals = similar.(values(need_allocation)) - NamedTuple{keys_alloc}(vals) -end - -function _pullback_alloc_others_threaded(args, prealloc) - # it's a bit tricky to get this type-stable, but the following does the trick - need_allocation = Base.structdiff(args, prealloc) - keys_alloc = prefix.(keys(need_allocation)) - vals = similar.(values(need_allocation)) - alloc = NamedTuple{keys_alloc}(vals) - keys_prealloc = prefix.(keys(prealloc)) - prefixed_prealloc = NamedTuple{keys_prealloc}(values(prealloc)) - merge(prefixed_prealloc, alloc) -end - -_pullback_alloc_points_serial(args, prealloc) = (;ds_dpoints = get(() -> similar(args.points), prealloc, :points)) - -_pullback_alloc_points_threaded(args, prealloc, n) = (;ds_dpoints = get(() -> similar(args.points, (size(args.points)..., n)), prealloc, :points)) - - -function interpolation_weight(n, N, deltas, shift) - val = @inbounds shift[n] == 1 ? one(eltype(deltas)) : -one(eltype(deltas)) - # loop over other dimensions - @inbounds for other_n in 1:N - if n == other_n - continue - end - val *= deltas[other_n, mod1(shift[other_n], 2)] - end - val end \ No newline at end of file diff --git a/src/util.jl b/src/util.jl index c150799..5482a8e 100644 --- a/src/util.jl +++ b/src/util.jl @@ -22,11 +22,67 @@ For a N-dimensional voxel grid, return a 2^N-tuple of N-tuples, where each element of the outer tuple is a cartesian coordinate shift from the "upper left" voxel. """ -voxel_shifts(::Val{N}, int_type=Int64) where {N} = ntuple(k -> digitstuple(k-1, Val(N), int_type), 2^N) +voxel_shifts(::Val{N}, int_type=Int64) where {N} = ntuple(k -> digitstuple(k-1, Val(N), int_type), Val(2^N)) + +@testitem "voxel_shifts" begin + @inferred DiffPointRasterisation.voxel_shifts(Val(4)) + + @test DiffPointRasterisation.voxel_shifts(Val(1)) == ((0,), (1,)) + + @test DiffPointRasterisation.voxel_shifts(Val(2)) == ((0, 0), (1, 0), (0, 1), (1, 1)) + + @test DiffPointRasterisation.voxel_shifts(Val(3)) == ((0, 0, 0), (1, 0, 0), (0, 1, 0), (1, 1, 0), (0, 0, 1), (1, 0, 1), (0, 1, 1), (1, 1, 1)) +end + +to_sized(arg::StaticArray{<:Any, <:Number}) = arg + +to_sized(arg::AbstractArray{T}) where {T<:Number} = SizedArray{Tuple{size(arg)...}, T}(arg) + +inner_to_sized(arg::AbstractVector{<:Number}) = arg + +inner_to_sized(arg::AbstractVector{<:StaticArray}) = arg + +inner_to_sized(arg::AbstractVector{<:AbstractArray{<:Number}}) = inner_to_sized(arg, Val(size(arg[1]))) + +inner_to_sized(arg::AbstractVector{<:AbstractArray{T}}, ::Val{sz}) where {sz, T<:Number} = SizedArray{Tuple{sz...}, T}.(arg) + +@testitem "inner_to_sized" begin + using StaticArrays + @testset "vector" begin + inp = randn(3) + @inferred DiffPointRasterisation.inner_to_sized(inp) + out = DiffPointRasterisation.inner_to_sized(inp) + @test out === inp + end + + @testset "vec of dynamic vec" begin + inp = [randn(3) for _ in 1:5] + out = DiffPointRasterisation.inner_to_sized(inp) + @test out == inp + @test out isa Vector{<:StaticVector{3}} + end + + @testset "vec of static vec" begin + inp = [@SVector randn(3) for _ in 1:5] + @inferred DiffPointRasterisation.inner_to_sized(inp) + out = DiffPointRasterisation.inner_to_sized(inp) + @test out === inp + @test out isa Vector{<:StaticVector{3}} + end + + @testset "vec of dynamic matrix" begin + inp = [randn(3, 2) for _ in 1:5] + out = DiffPointRasterisation.inner_to_sized(inp) + @test out == inp + @test out isa Vector{<:StaticMatrix{3, 2}} + end +end @inline append_singleton_dim(a) = reshape(a, size(a)..., 1) +@inline append_singleton_dim(a::Number) = [a] + @inline drop_last_dim(a) = dropdims(a; dims=ndims(a)) @testitem "append drop dim" begin diff --git a/test/chainrules.jl b/test/chainrules.jl index 267202c..d41912b 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -1,32 +1,32 @@ @testitem "ChainRules single" begin - using ChainRulesTestUtils + using ChainRulesTestUtils, ChainRulesCore include("data.jl") - test_rrule(raster, D.grid_size_3d, D.points, D.rotation ⊢ D.rotation_tangent, D.translation_3d, D.background, D.weight) + test_rrule(raster, D.grid_size_3d, D.points_static, D.rotation ⊢ D.rotation_tangent, D.translation_3d, D.background, D.weight) # default arguments - test_rrule(raster, D.grid_size_3d, D.points, D.rotation ⊢ D.rotation_tangent, D.translation_3d) + test_rrule(raster, D.grid_size_3d, D.points_static, D.rotation ⊢ D.rotation_tangent, D.translation_3d) - test_rrule(raster_project, D.grid_size_2d, D.points, D.rotation ⊢ D.rotation_tangent, D.translation_2d, D.background, D.weight) + test_rrule(raster, D.grid_size_2d, D.points_static, D.projection ⊢ D.projection_tangent, D.translation_2d, D.background, D.weight) # default arguments - test_rrule(raster_project, D.grid_size_2d, D.points, D.rotation ⊢ D.rotation_tangent, D.translation_2d) + test_rrule(raster, D.grid_size_2d, D.points_static, D.projection ⊢ D.projection_tangent, D.translation_2d) end @testitem "ChainRules batch" begin using ChainRulesTestUtils include("data.jl") - test_rrule(raster, D.grid_size_3d, D.points, D.rotations ⊢ D.rotation_tangents, D.translations_3d, D.backgrounds, D.weights) + test_rrule(raster, D.grid_size_3d, D.points_static, D.rotations_static ⊢ D.rotation_tangents_static, D.translations_3d_static, D.backgrounds, D.weights) # default arguments - test_rrule(raster, D.grid_size_3d, D.points, D.rotations ⊢ D.rotation_tangents, D.translations_3d) + test_rrule(raster, D.grid_size_3d, D.points_static, D.rotations_static ⊢ D.rotation_tangents_static, D.translations_3d_static) - test_rrule(raster_project, D.grid_size_2d, D.points, D.rotations ⊢ D.rotation_tangents, D.translations_2d, D.backgrounds, D.weights) + test_rrule(raster, D.grid_size_2d, D.points_static, D.projections_static ⊢ D.projection_tangents_static, D.translations_2d_static, D.backgrounds, D.weights) # default arguments - test_rrule(raster_project, D.grid_size_2d, D.points, D.rotations ⊢ D.rotation_tangents, D.translations_2d) + test_rrule(raster, D.grid_size_2d, D.points_static, D.projections_static ⊢ D.projection_tangents_static, D.translations_2d_static) end \ No newline at end of file diff --git a/test/cuda.jl b/test/cuda.jl index f220fad..717ca1f 100644 --- a/test/cuda.jl +++ b/test/cuda.jl @@ -6,11 +6,11 @@ include("util.jl") cuda_available = CUDA.functional() - args = (D.grid_size_3d, D.more_points, D.rotations, D.translations_3d, D.backgrounds, D.weights) + args = (D.grid_size_3d, D.more_points, D.rotations_static, D.translations_3d_static, D.backgrounds, D.weights) @test cuda_cpu_agree(raster, args...) skip=!cuda_available - args = (D.grid_size_2d, D.more_points, D.rotations, D.translations_2d, D.backgrounds, D.weights) - @test cuda_cpu_agree(raster_project, args...) skip=!cuda_available + args = (D.grid_size_2d, D.more_points, D.projections_static, D.translations_2d_static, D.backgrounds, D.weights) + @test cuda_cpu_agree(raster, args...) skip=!cuda_available end @testitem "CUDA backward" begin @@ -21,12 +21,12 @@ end cuda_available = CUDA.functional() ds_dout_3d = randn(D.grid_size_3d..., D.batch_size) - args = (ds_dout_3d, D.more_points, D.rotations, D.translations_3d, D.backgrounds, D.weights) + args = (ds_dout_3d, D.more_points, D.rotations_static, D.translations_3d_static, D.backgrounds, D.weights) @test cuda_cpu_agree(raster_pullback!, args...) skip=!cuda_available ds_dout_2d = randn(D.grid_size_2d..., D.batch_size) - args = (ds_dout_2d, D.more_points, D.rotations, D.translations_2d, D.backgrounds, D.weights) - @test cuda_cpu_agree(raster_project_pullback!, args...) skip=!cuda_available + args = (ds_dout_2d, D.more_points, D.projections_static, D.translations_2d_static, D.backgrounds, D.weights) + @test cuda_cpu_agree(raster_pullback!, args...) skip=!cuda_available end # The follwing currently fails. @@ -44,6 +44,6 @@ end # # ds_dout_2d = CUDA.randn(Float64, D.grid_size_2d..., D.batch_size) # args = (D.grid_size_2d, c(D.points), c(D.rotations) ⊢ c(D.rotation_tangents), c(D.translations_2d), c(D.backgrounds), c(D.weights)) -# test_rrule(raster_project, args...; output_tangent=ds_dout_2d) +# test_rrule(raster, args...; output_tangent=ds_dout_2d) # end # end diff --git a/test/data.jl b/test/data.jl index 0cbb201..d0d9a7e 100644 --- a/test/data.jl +++ b/test/data.jl @@ -1,6 +1,6 @@ module D -using Rotations +using Rotations, StaticArrays function batch_size_for_test() local batch_size = Threads.nthreads() + 1 @@ -10,22 +10,58 @@ function batch_size_for_test() batch_size end +const P = @SMatrix Float64[ + 1 0 0 + 0 1 0 +] const grid_size_3d = (8, 8, 8) const grid_size_2d = (8, 8) const batch_size = batch_size_for_test() -const points = 0.4 .* randn(3, 10) -const more_points = 0.4 .* randn(3, 100_000) -const rotation = Array(rand(RotMatrix3)) + +const points = [0.4 * randn(3) for _ in 1:10] +const points_static = SVector{3}.(points) +const points_array = Matrix{Float64}(undef, 3, length(points)) +eachcol(points_array) .= points +const points_reinterp = reinterpret(reshape, SVector{3, Float64}, points_array) +const more_points = [0.4 * @SVector randn(3) for _ in 1:100_000] + +const rotation = rand(RotMatrix3{Float64}) +const rotations_static = rand(RotMatrix3{Float64}, batch_size)::Vector{<:StaticMatrix} +const rotations = (Array.(rotations_static))::Vector{Matrix{Float64}} +const rotations_array = Array{Float64, 3}(undef, 3, 3, batch_size) +eachslice(rotations_array; dims=3) .= rotations +const rotations_reinterp = reinterpret(reshape, SMatrix{3, 3, Float64, 9}, reshape(rotations_array, 9, :)) const rotation_tangent = Array(rand(RotMatrix3)) -const rotations = stack(rand(RotMatrix3, batch_size)) -const rotation_tangents = stack(rand(RotMatrix3, batch_size)) -const translation_3d = 0.1 .* randn(3) -const translation_2d = 0.1 .* randn(2) -const translations_3d = zeros(3, batch_size) -const translations_2d = zeros(2, batch_size) +const rotation_tangents_static = rand(RotMatrix3{Float64}, batch_size)::Vector{<:StaticMatrix} +const rotation_tangents = (Array.(rotation_tangents_static))::Vector{Matrix{Float64}} + +const projection = P * rand(RotMatrix3) +const projections_static = Ref(P) .* rand(RotMatrix3{Float64}, batch_size) +const projections = (Array.(projections_static))::Vector{Matrix{Float64}} +const projections_array = Array{Float64, 3}(undef, 2, 3, batch_size) +eachslice(projections_array; dims=3) .= projections +const projections_reinterp = reinterpret(reshape, SMatrix{2, 3, Float64, 6}, reshape(projections_array, 6, :)) +const projection_tangent = Array(P * rand(RotMatrix3)) +const projection_tangents_static = Ref(P) .* rand(RotMatrix3{Float64}, batch_size) +const projection_tangents = (Array.(projection_tangents_static))::Vector{Matrix{Float64}} + +const translation_3d = 0.1 * @SVector randn(3) +const translation_2d = 0.1 * @SVector randn(2) +const translations_3d_static = [0.1 * @SVector randn(3) for _ in 1:batch_size] +const translations_3d = (Array.(translations_3d_static))::Vector{Vector{Float64}} +const translations_3d_array = Matrix{Float64}(undef, 3, batch_size) +eachcol(translations_3d_array) .= translations_3d +const translations_3d_reinterp = reinterpret(reshape, SVector{3, Float64}, translations_3d_array) +const translations_2d_static = [0.1 * @SVector randn(2) for _ in 1:batch_size] +const translations_2d = (Array.(translations_2d_static))::Vector{Vector{Float64}} +const translations_2d_array = Matrix{Float64}(undef, 2, batch_size) +eachcol(translations_2d_array) .= translations_2d +const translations_2d_reinterp = reinterpret(reshape, SVector{2, Float64}, translations_2d_array) + const background = 0.1 const backgrounds = collect(1:1.0:batch_size) + const weight = rand() const weights = 10 .* rand(batch_size) diff --git a/test/util.jl b/test/util.jl index bf1c14f..0db64a8 100644 --- a/test/util.jl +++ b/test/util.jl @@ -1,10 +1,10 @@ -function run_cuda(f::F, args::Vararg{Any, N}) where {F, N} +function run_cuda(f, args...) cu_args = adapt(CuArray, args) return f(cu_args...) end -function cuda_cpu_agree(f::F, args...) where {F} +function cuda_cpu_agree(f, args...) out_cpu = f(args...) out_cuda = run_cuda(f, args...) is_approx_equal(out_cuda, out_cpu) @@ -18,12 +18,16 @@ end function is_approx_equal(actual::NamedTuple, expected::NamedTuple) actual_cpu = adapt(Array, actual) for prop in propertynames(expected) - # (prop in (:points,)) && continue - actual_elem = getproperty(actual_cpu, prop) - expected_elem = getproperty(expected, prop) - if !(actual_elem ≈ expected_elem) - throw("Element '$(string(prop))' differs:\nActual: $actual_elem \nExpected: $expected_elem") - return false + try + actual_elem = getproperty(actual_cpu, prop) + expected_elem = getproperty(expected, prop) + if !(actual_elem ≈ expected_elem) + throw("Values differ:\nActual: $(string(actual_elem)) \nExpected: $(string(expected_elem))") + return false + end + catch e + println("Error while trying to compare element $(string(prop))") + rethrow() end end true