Skip to content

Commit

Permalink
Merge pull request #7 from trahflow/interface
Browse files Browse the repository at this point in the history
Change Interface from nd-array to vec-of-array
  • Loading branch information
trahflow authored Mar 7, 2024
2 parents e47e646 + ba5c517 commit 2acceca
Show file tree
Hide file tree
Showing 14 changed files with 983 additions and 532 deletions.
1 change: 1 addition & 0 deletions examples/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@
DiffPointRasterisation = "f984992d-3c45-4382-99a1-cf20f5c47c61"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
Images = "916415d5-f1e6-5110-898d-aaa5f9f070e0"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
18 changes: 12 additions & 6 deletions examples/logo.jl
Original file line number Diff line number Diff line change
@@ -1,21 +1,24 @@
using DiffPointRasterisation
using FFTW
using Images
using LinearAlgebra
using StaticArrays
using Zygote

load_image(path) = load(path) .|> Gray |> channelview

init_points(n) = (rand(Float32, 2, n) .- 0.5f0) .* Float32[2;2;;]
init_points(n) = [2 * rand(Float32, 2) .- 1f0 for _ in 1:n]

target_image = load_image("data/julia.png")

points = init_points(5_000)
rotation = Float32[1;0;;0;1;;]
rotation = I(2)
translation = zeros(Float32, 2)

function model(points, log_bandwidth, log_weight)
# raster points to 2d-image
rough_image = raster(size(target_image), points, rotation, translation, 0f0, exp(log_weight))
# smooth with gaussian kernel
# smooth image with gaussian kernel
kernel = gaussian_kernel(log_bandwidth, size(target_image)...)
image = convolve_image(rough_image, kernel)
image
Expand All @@ -36,18 +39,20 @@ convolve_image(image, kernel) = irfft(rfft(image) .* rfft(kernel), size(image, 1

function loss(points, log_bandwidth, log_weight)
model_image = model(points, log_bandwidth, log_weight)
sum((model_image .- target_image).^2) + sum(points.^2)
# squared error plus regularization term for points
sum((model_image .- target_image).^2) + sum(stack(points).^2)
end

logrange(s, e, n) = round.(Int, exp.(range(log(s), log(e), n)))

function langevin!(points, log_bandwidth, log_weight, eps, n, update_bandwidth=true, update_weight=true, eps_after_init=eps; n_init=n, n_logs_init=15)
# Langevin sampling for points and optionally log_bandwidth and log_weight.
logs_init = logrange(1, n_init, n_logs_init)
log_every = false
logstep = 1
for i in 1:n
l, grads = Zygote.withgradient(loss, points, log_bandwidth, log_weight)
points .+= sqrt(eps) .* randn(Float32, size(points)) .- eps .* 0.5f0 .* grads[1]
points .+= sqrt(eps) .* reinterpret(reshape, SVector{2, Float32}, randn(Float32, 2, length(points))) .- eps .* 0.5f0 .* grads[1]
if update_bandwidth
log_bandwidth += sqrt(eps) * randn(Float32) - eps * 0.5f0 * grads[2]
end
Expand All @@ -57,6 +62,7 @@ function langevin!(points, log_bandwidth, log_weight, eps, n, update_bandwidth=t

if i == n_init
log_every = true
eps = eps_after_init
end
if log_every || (i in logs_init)
println("iteration $logstep, $i: loss = $l, bandwidth = $(exp(log_bandwidth)), weight = $(exp(log_weight))")
Expand All @@ -68,4 +74,4 @@ function langevin!(points, log_bandwidth, log_weight, eps, n, update_bandwidth=t
points, log_bandwidth
end

isinteractive() || langevin!(points, log(0.5f0), 0f0, 5f-6, 6_030, false, true, 5f-4; n_init=6_000)
isinteractive() || langevin!(points, log(0.5f0), 0f0, 5f-6, 6_030, false, true, 4f-5; n_init=6_000)
100 changes: 57 additions & 43 deletions ext/DiffPointRasterisationCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,24 +13,21 @@ using StaticArrays


function raster_pullback_kernel!(
::Val{N_in},
ds_dout::AbstractArray{T, N_out_p1},
points,
rotations,
translations,
::Type{T},
ds_dout,
points::AbstractVector{<:StaticVector{N_in}},
rotations::AbstractVector{<:StaticMatrix{N_out, N_in, TR}},
translations::AbstractVector{<:StaticVector{N_out, TT}},
weights,
shifts,
scale,
projection_idxs,
# outputs:
ds_dpoints,
ds_dprojection_rotation,
ds_drotation,
ds_dtranslation,
ds_dweight,

) where {T, N_in, N_out_p1}
N_out = N_out_p1 - 1 # dimensionality of output, without batch dimension

) where {T, TR, TT, N_in, N_out}
n_voxel = blockDim().z
points_per_workgroup = blockDim().x
batchsize_per_workgroup = blockDim().y
Expand All @@ -46,25 +43,24 @@ function raster_pullback_kernel!(
neighbor_voxel_id = (blockIdx().z - 1) * n_voxel + s
point_idx = (blockIdx().x - 1) * points_per_workgroup + threadIdx().x
batch_idx = (blockIdx().y - 1) * batchsize_per_workgroup + b
in_batch = batch_idx <= size(rotations, 3)
in_batch = batch_idx <= length(rotations)

dimension = (N_out, n_voxel, batchsize_per_workgroup)
ds_dpoint_rot = CuDynamicSharedArray(T, dimension)
ds_dpoint_local = CuDynamicSharedArray(T, (N_in, batchsize_per_workgroup), sizeof(T) * prod(dimension))

rotation = in_batch ? @inbounds(SMatrix{N_in, N_in, T}(@view rotations[:, :, batch_idx])) : @SMatrix zeros(T, N_in, N_in)
point = @inbounds SVector{N_in, T}(@view points[:, point_idx])
rotation = @inbounds in_batch ? rotations[batch_idx] : @SMatrix zeros(TR, N_in, N_in)
point = @inbounds points[point_idx]

if in_batch
translation = @inbounds SVector{N_out, T}(@view translations[:, batch_idx])
translation = @inbounds translations[batch_idx]
weight = @inbounds weights[batch_idx]
shift = @inbounds shifts[neighbor_voxel_id]
origin = (-@SVector ones(T, N_out)) - translation
origin = (-@SVector ones(TT, N_out)) - translation

coord_reference_voxel, deltas = DiffPointRasterisation.reference_coordinate_and_deltas(
point,
rotation,
projection_idxs,
origin,
scale,
)
Expand All @@ -76,12 +72,11 @@ function raster_pullback_kernel!(
@inbounds ds_dweight_local = DiffPointRasterisation.voxel_weight(
deltas,
shift,
projection_idxs,
ds_dout[voxel_idx],
)

factor = ds_dout[voxel_idx] * weight
ds_dcoord_part = SVector(factor .* ntuple(n -> DiffPointRasterisation.interpolation_weight(n, N_out, deltas, shift), N_out))
ds_dcoord_part = SVector(factor .* ntuple(n -> DiffPointRasterisation.interpolation_weight(n, N_out, deltas, shift), Val(N_out)))
@inbounds ds_dpoint_rot[:, s, b] .= ds_dcoord_part .* scale
else
@inbounds ds_dpoint_rot[:, s, b] .= zero(T)
Expand Down Expand Up @@ -123,7 +118,7 @@ function raster_pullback_kernel!(
j = 1
while j <= N_in
val = coef * point[j]
@inbounds CUDA.@atomic ds_dprojection_rotation[dim, j, batch_idx] += val
@inbounds CUDA.@atomic ds_drotation[dim, j, batch_idx] += val
j += 1
end
end
Expand Down Expand Up @@ -175,37 +170,53 @@ function raster_pullback_kernel!(
nothing
end


function DiffPointRasterisation._raster_pullback!(
::Val{N_in},
ds_dout::CuArray{T, N_out_p1},
points::CuMatrix{<:Number},
rotation::CuArray{<:Number, 3},
translation::CuMatrix{<:Number},
# TODO: for some reason type inference fails if the following
# two arrays are FillArrays...
background::CuVector{<:Number}=CUDA.zeros(T, size(rotation, 3)),
weight::CuVector{<:Number}=CUDA.ones(T, size(rotation, 3)),
) where {T<:Number, N_in, N_out_p1}
N_out = N_out_p1 - 1
out_batch_dim = ndims(ds_dout)
batch_axis = axes(ds_dout, out_batch_dim)
n_points = size(points, 2)
# single image
raster_pullback!(
ds_dout::CuArray{<:Number, N_out},
points::AbstractVector{<:StaticVector{N_in, <:Number}},
rotation::StaticMatrix{N_out, N_in, <:Number},
translation::StaticVector{N_out, <:Number},
background::Number,
weight::Number,
ds_dpoints::AbstractMatrix{<:Number};
kwargs...
) where {N_in, N_out} = error("Not implemented: raster_pullback! for single image not implemented on GPU. Consider using CPU arrays")

# batch of images
function DiffPointRasterisation.raster_pullback!(
ds_dout::CuArray{<:Number, N_out_p1},
points::CuVector{<:StaticVector{N_in, <:Number}},
rotation::CuVector{<:StaticMatrix{N_out, N_in, <:Number}},
translation::CuVector{<:StaticVector{N_out, <:Number}},
background::CuVector{<:Number},
weight::CuVector{<:Number},
ds_dpoints::CuMatrix{TP},
ds_drotation::CuArray{TR, 3},
ds_dtranslation::CuMatrix{TT},
ds_dbackground::CuVector{<:Number},
ds_dweight::CuVector{TW},
) where {N_in, N_out, N_out_p1, TP<:Number, TR<:Number, TT<:Number, TW<:Number}
T = promote_type(eltype(ds_dout), TP, TR, TT, TW)
batch_axis = axes(ds_dout, N_out_p1)
@argcheck N_out == N_out_p1 - 1
@argcheck batch_axis == axes(rotation, 1) == axes(translation, 1) == axes(background, 1) == axes(weight, 1)
@argcheck batch_axis == axes(ds_drotation, 3) == axes(ds_dtranslation, 2) == axes(ds_dbackground, 1) == axes(ds_dweight, 1)
@argcheck N_out == N_out_p1 - 1

n_points = length(points)
batch_size = length(batch_axis)
@argcheck axes(ds_dout, out_batch_dim) == axes(rotation, 3) == axes(translation, 2) == axes(background, 1) == axes(weight, 1)

ds_dbackground = dropdims(sum(ds_dout; dims=1:N_out); dims=ntuple(identity, Val(N_out)))
ds_dbackground = vec(sum!(reshape(ds_dbackground, ntuple(_ -> 1, Val(N_out))..., batch_size), ds_dout))

scale = SVector{N_out, T}(size(ds_dout)[1:end-1]) / T(2)
projection_idxs = SVector{N_out}(ntuple(identity, N_out))
shifts=DiffPointRasterisation.voxel_shifts(Val(N_out))

ds_dpoints = fill!(similar(points), zero(T))
ds_drotation = fill!(similar(rotation), zero(T))
ds_dtranslation = fill!(similar(translation), zero(T))
ds_dweight = fill!(similar(weight), zero(T))
ds_dpoints = fill!(ds_dpoints, zero(TP))
ds_drotation = fill!(ds_drotation, zero(TR))
ds_dtranslation = fill!(ds_dtranslation, zero(TT))
ds_dweight = fill!(ds_dweight, zero(TW))

args = (Val(N_in), ds_dout, points, rotation, translation, weight, shifts, scale, projection_idxs, ds_dpoints, ds_drotation, ds_dtranslation, ds_dweight)
args = (T, ds_dout, points, rotation, translation, weight, shifts, scale, ds_dpoints, ds_drotation, ds_dtranslation, ds_dweight)

ndrange = (n_points, batch_size, 2^N_out)

Expand All @@ -232,4 +243,7 @@ function DiffPointRasterisation._raster_pullback!(
)
end


DiffPointRasterisation.default_ds_dpoints_batched(points::CuVector{<:AbstractVector{TP}}, N_in, batch_size) where {TP<:Number} = similar(points, TP, (N_in, length(points)))

end # module
88 changes: 70 additions & 18 deletions ext/DiffPointRasterisationChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
@@ -1,40 +1,92 @@
module DiffPointRasterisationChainRulesCoreExt

using DiffPointRasterisation, ChainRulesCore
using DiffPointRasterisation, ChainRulesCore, StaticArrays

# single image
function ChainRulesCore.rrule(
::typeof(raster),
::typeof(DiffPointRasterisation.raster),
grid_size,
args...,
)
out = raster(grid_size, args...)
points::AbstractVector{<:StaticVector{N_in, T}},
rotation::AbstractMatrix{<:Number},
translation::AbstractVector{<:Number},
optional_args...
) where {N_in, T<:Number}
out = raster(grid_size, points, rotation, translation, optional_args...)

raster_pullback(ds_dout) = NoTangent(), NoTangent(), values(
raster_pullback!(
function raster_pullback(ds_dout)
out_pb = raster_pullback!(
unthunk(ds_dout),
args...,
points,
rotation,
translation,
optional_args...,
)
)[1:length(args)]...
ds_dpoints = reinterpret(reshape, SVector{N_in, T}, out_pb.points)
return NoTangent(), NoTangent(), ds_dpoints, values(out_pb)[2:3+length(optional_args)]...
end

return out, raster_pullback
end

ChainRulesCore.rrule(
f::typeof(DiffPointRasterisation.raster),
grid_size,
points::AbstractVector{<:AbstractVector{<:Number}},
rotation::AbstractMatrix{<:Number},
translation::AbstractVector{<:Number},
optional_args...
) = ChainRulesCore.rrule(
f,
grid_size,
DiffPointRasterisation.inner_to_sized(points),
rotation,
translation,
optional_args...
)

# batch of images
function ChainRulesCore.rrule(
::typeof(raster_project),
::typeof(DiffPointRasterisation.raster),
grid_size,
args...,
)
out = raster_project(grid_size, args...)
points::AbstractVector{<:StaticVector{N_in, TP}},
rotation::AbstractVector{<:StaticMatrix{N_out, N_in, TR}},
translation::AbstractVector{<:StaticVector{N_out, TT}},
optional_args...
) where {N_in, N_out, TP<:Number, TR<:Number, TT<:Number}
out = raster(grid_size, points, rotation, translation, optional_args...)

raster_project_pullback(ds_dout) = NoTangent(), NoTangent(), values(
raster_project_pullback!(
function raster_pullback(ds_dout)
out_pb = raster_pullback!(
unthunk(ds_dout),
args...,
points,
rotation,
translation,
optional_args...,
)
)[1:length(args)]...
ds_dpoints = reinterpret(reshape, SVector{N_in, TP}, out_pb.points)
L = N_out * N_in
ds_drotation = reinterpret(reshape, SMatrix{N_out, N_in, TR, L}, reshape(out_pb.rotation, L, :))
ds_dtranslation = reinterpret(reshape, SVector{N_out, TT}, out_pb.translation)
return NoTangent(), NoTangent(), ds_dpoints, ds_drotation, ds_dtranslation, values(out_pb)[4:3+length(optional_args)]...
end

return out, raster_project_pullback
return out, raster_pullback
end

ChainRulesCore.rrule(
f::typeof(DiffPointRasterisation.raster),
grid_size,
points::AbstractVector{<:AbstractVector{<:Number}},
rotation::AbstractVector{<:AbstractMatrix{<:Number}},
translation::AbstractVector{<:AbstractVector{<:Number}},
optional_args...
) = ChainRulesCore.rrule(
f,
grid_size,
DiffPointRasterisation.inner_to_sized(points),
DiffPointRasterisation.inner_to_sized(rotation),
DiffPointRasterisation.inner_to_sized(translation),
optional_args...
)

end # module DiffPointRasterisationChainRulesCoreExt
Binary file modified logo.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
3 changes: 2 additions & 1 deletion src/DiffPointRasterisation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ using TestItems
include("util.jl")
include("raster.jl")
include("raster_pullback.jl")
include("interface.jl")

export raster, raster!, raster_project, raster_project!, raster_pullback!, raster_project_pullback!
export raster, raster!, raster_pullback!

end
Loading

0 comments on commit 2acceca

Please sign in to comment.