Skip to content

Commit

Permalink
change interface backward pass
Browse files Browse the repository at this point in the history
  • Loading branch information
trahflow committed Mar 6, 2024
1 parent 2fe18ed commit e6279d4
Show file tree
Hide file tree
Showing 12 changed files with 450 additions and 250 deletions.
1 change: 1 addition & 0 deletions examples/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@
DiffPointRasterisation = "f984992d-3c45-4382-99a1-cf20f5c47c61"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
Images = "916415d5-f1e6-5110-898d-aaa5f9f070e0"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
18 changes: 12 additions & 6 deletions examples/logo.jl
Original file line number Diff line number Diff line change
@@ -1,21 +1,24 @@
using DiffPointRasterisation
using FFTW
using Images
using LinearAlgebra
using StaticArrays
using Zygote

load_image(path) = load(path) .|> Gray |> channelview

init_points(n) = (rand(Float32, 2, n) .- 0.5f0) .* Float32[2;2;;]
init_points(n) = [2 * rand(Float32, 2) .- 1f0 for _ in 1:n]

target_image = load_image("data/julia.png")

points = init_points(5_000)
rotation = Float32[1;0;;0;1;;]
rotation = I(2)
translation = zeros(Float32, 2)

function model(points, log_bandwidth, log_weight)
# raster points to 2d-image
rough_image = raster(size(target_image), points, rotation, translation, 0f0, exp(log_weight))
# smooth with gaussian kernel
# smooth image with gaussian kernel
kernel = gaussian_kernel(log_bandwidth, size(target_image)...)
image = convolve_image(rough_image, kernel)
image
Expand All @@ -36,18 +39,20 @@ convolve_image(image, kernel) = irfft(rfft(image) .* rfft(kernel), size(image, 1

function loss(points, log_bandwidth, log_weight)
model_image = model(points, log_bandwidth, log_weight)
sum((model_image .- target_image).^2) + sum(points.^2)
# squared error plus regularization term for points
sum((model_image .- target_image).^2) + sum(stack(points).^2)
end

logrange(s, e, n) = round.(Int, exp.(range(log(s), log(e), n)))

function langevin!(points, log_bandwidth, log_weight, eps, n, update_bandwidth=true, update_weight=true, eps_after_init=eps; n_init=n, n_logs_init=15)
# Langevin sampling for points and optionally log_bandwidth and log_weight.
logs_init = logrange(1, n_init, n_logs_init)
log_every = false
logstep = 1
for i in 1:n
l, grads = Zygote.withgradient(loss, points, log_bandwidth, log_weight)
points .+= sqrt(eps) .* randn(Float32, size(points)) .- eps .* 0.5f0 .* grads[1]
points .+= sqrt(eps) .* reinterpret(reshape, SVector{2, Float32}, randn(Float32, 2, length(points))) .- eps .* 0.5f0 .* grads[1]
if update_bandwidth
log_bandwidth += sqrt(eps) * randn(Float32) - eps * 0.5f0 * grads[2]
end
Expand All @@ -57,6 +62,7 @@ function langevin!(points, log_bandwidth, log_weight, eps, n, update_bandwidth=t

if i == n_init
log_every = true
eps = eps_after_init
end
if log_every || (i in logs_init)
println("iteration $logstep, $i: loss = $l, bandwidth = $(exp(log_bandwidth)), weight = $(exp(log_weight))")
Expand All @@ -68,4 +74,4 @@ function langevin!(points, log_bandwidth, log_weight, eps, n, update_bandwidth=t
points, log_bandwidth
end

isinteractive() || langevin!(points, log(0.5f0), 0f0, 5f-6, 6_030, false, true, 5f-4; n_init=6_000)
isinteractive() || langevin!(points, log(0.5f0), 0f0, 5f-6, 6_030, false, true, 4f-5; n_init=6_000)
2 changes: 1 addition & 1 deletion ext/DiffPointRasterisationCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ function raster_pullback_kernel!(
end


function DiffPointRasterisation._raster_pullback!(
function DiffPointRasterisation.raster_pullback!(
::Val{N_in},
ds_dout::CuArray{T, N_out_p1},
points::CuMatrix{<:Number},
Expand Down
88 changes: 70 additions & 18 deletions ext/DiffPointRasterisationChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
@@ -1,40 +1,92 @@
module DiffPointRasterisationChainRulesCoreExt

using DiffPointRasterisation, ChainRulesCore
using DiffPointRasterisation, ChainRulesCore, StaticArrays

# single image
function ChainRulesCore.rrule(
::typeof(raster),
::typeof(DiffPointRasterisation.raster),
grid_size,
args...,
)
out = raster(grid_size, args...)
points::AbstractVector{<:StaticVector{N_in, T}},
rotation::AbstractMatrix{<:Number},
translation::AbstractVector{<:Number},
optional_args...
) where {N_in, T<:Number}
out = raster(grid_size, points, rotation, translation, optional_args...)

raster_pullback(ds_dout) = NoTangent(), NoTangent(), values(
raster_pullback!(
function raster_pullback(ds_dout)
out_pb = raster_pullback!(
unthunk(ds_dout),
args...,
points,
rotation,
translation,
optional_args...,
)
)[1:length(args)]...
ds_dpoints = reinterpret(reshape, SVector{N_in, T}, out_pb.points)
return NoTangent(), NoTangent(), ds_dpoints, values(out_pb)[2:3+length(optional_args)]...
end

return out, raster_pullback
end

ChainRulesCore.rrule(
f::typeof(DiffPointRasterisation.raster),
grid_size,
points::AbstractVector{<:AbstractVector{<:Number}},
rotation::AbstractMatrix{<:Number},
translation::AbstractVector{<:Number},
optional_args...
) = ChainRulesCore.rrule(
f,
grid_size,
DiffPointRasterisation.inner_to_sized(points),
rotation,
translation,
optional_args...
)

# batch of images
function ChainRulesCore.rrule(
::typeof(raster_project),
::typeof(DiffPointRasterisation.raster),
grid_size,
args...,
)
out = raster_project(grid_size, args...)
points::AbstractVector{<:StaticVector{N_in, TP}},
rotation::AbstractVector{<:StaticMatrix{N_out, N_in, TR}},
translation::AbstractVector{<:StaticVector{N_out, TT}},
optional_args...
) where {N_in, N_out, TP<:Number, TR<:Number, TT<:Number}
out = raster(grid_size, points, rotation, translation, optional_args...)

raster_project_pullback(ds_dout) = NoTangent(), NoTangent(), values(
raster_project_pullback!(
function raster_pullback(ds_dout)
out_pb = raster_pullback!(
unthunk(ds_dout),
args...,
points,
rotation,
translation,
optional_args...,
)
)[1:length(args)]...
ds_dpoints = reinterpret(reshape, SVector{N_in, TP}, out_pb.points)
L = N_out * N_in
ds_drotation = reinterpret(reshape, SMatrix{N_out, N_in, TR, L}, reshape(out_pb.rotation, L, :))
ds_dtranslation = reinterpret(reshape, SVector{N_out, TT}, out_pb.translation)
return NoTangent(), NoTangent(), ds_dpoints, ds_drotation, ds_dtranslation, values(out_pb)[4:3+length(optional_args)]...
end

return out, raster_project_pullback
return out, raster_pullback
end

ChainRulesCore.rrule(
f::typeof(DiffPointRasterisation.raster),
grid_size,
points::AbstractVector{<:AbstractVector{<:Number}},
rotation::AbstractVector{<:AbstractMatrix{<:Number}},
translation::AbstractVector{<:AbstractVector{<:Number}},
optional_args...
) = ChainRulesCore.rrule(
f,
grid_size,
DiffPointRasterisation.inner_to_sized(points),
DiffPointRasterisation.inner_to_sized(rotation),
DiffPointRasterisation.inner_to_sized(translation),
optional_args...
)

end # module DiffPointRasterisationChainRulesCoreExt
Binary file modified logo.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion src/DiffPointRasterisation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,6 @@ include("raster.jl")
include("raster_pullback.jl")
include("interface.jl")

export raster, raster!, raster_project, raster_project!, raster_pullback!, raster_project_pullback!
export raster, raster!, raster_pullback!

end
176 changes: 174 additions & 2 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ deep_eltype(t::Type{<:AbstractArray}) = deep_eltype(eltype(t))

raster!(
out::AbstractArray{<:Number},
points::AbstractVecOrMat,
points::AbstractVector{<:AbstractVector{<:Number}},
rotation::AbstractMatrix{<:Number},
translation::AbstractVector{<:Number},
background::Number,
Expand All @@ -92,7 +92,7 @@ raster!(
# i.e. vectors of statically sized arrays
###############################################

raster!(out::AbstractArray{<:Number}, args::Vararg{AbstractVector, 5}) = raster!(out, canonical_arg.(args)...)
raster!(out::AbstractArray{<:Number}, args::Vararg{AbstractVector, 5}) = raster!(out, inner_to_sized.(args)...)

###############################################
# Step 5: Error on inconsistent dimensions
Expand Down Expand Up @@ -120,6 +120,178 @@ function raster!(
error("Dispatch error. Should not arrive here. Please file a bug.")
end

# now similar for pullback
"""
raster_pullback!(
ds_dout, args...;
[points, rotation, translation, background, weight]
)
Pullback for `raster(grid_size, args...)`/`raster!(out, args...)`.
Take as input `ds_dout` the sensitivity of some quantity (`s` for "scalar")
to the *output* `out` of the function `out = raster(grid_size, args...)`
(or `out = raster!(out, args...)`), as well as
the exact same arguments `args` that were passed to `raster`/`raster!`, and
return the sensitivities of `s` to the *inputs* `args` of the function
`raster`/`raster!`.
Optionally, pre-allocated output arrays for each input sensitivity can be
specified as keyword arguments with the name of the original argument to
`raster` as key, and a nd-array as value, where the n-th dimension is the
batch dimension.
For example to provide a pre-allocated array for the sensitivity of `s` to
the `translation` argument of `raster`, do:
`sensitivities = raster_pullback!(ds_dout, args...; translation = [zeros(2) for _ in 1:8])`
for 2-dimensional points and a batch size of 8.
"""
function raster_pullback! end


###############################################
# Step 1: Fill default arguments if necessary
###############################################

@inline raster_pullback!(ds_out::AbstractArray{<:Number}, args::Vararg{Any, 3}; kwargs...) = raster_pullback!(ds_out, args..., default_background(args[2]); kwargs...)
@inline raster_pullback!(ds_dout::AbstractArray{<:Number}, args::Vararg{Any, 4}; kwargs...) = raster_pullback!(ds_dout, args..., default_weight(args[2]); kwargs...)


###############################################
# Step 2: Convert arguments to canonical form,
# i.e. vectors of statically sized arrays
###############################################

# single image
raster_pullback!(
ds_dout::AbstractArray{<:Number},
points::AbstractVector{<:AbstractVector{<:Number}},
rotation::AbstractMatrix{<:Number},
translation::AbstractVector{<:Number},
background::Number,
weight::Number;
kwargs...
) = raster_pullback!(
ds_dout,
inner_to_sized(points),
to_sized(rotation),
to_sized(translation),
background,
weight;
kwargs...
)

# batch of images
raster_pullback!(ds_dout::AbstractArray{<:Number}, args::Vararg{AbstractVector, 5}; kwargs...) = raster_pullback!(ds_dout, inner_to_sized.(args)...; kwargs...)


###############################################
# Step 3: Allocate output
###############################################

# single image
raster_pullback!(
ds_dout::AbstractArray{<:Number, N_out},
inp_points::AbstractVector{<:StaticVector{N_in, T}},
inp_rotation::StaticMatrix{N_out, N_in, <:Number},
inp_translation::StaticVector{N_out, <:Number},
inp_background::Number,
inp_weight::Number;
accumulate_ds_dpoints=false,
points::AbstractMatrix{T} = similar(inp_points, T, (N_in, length(inp_points)))
) where {N_in, N_out, T<:Number} =raster_pullback!(
ds_dout,
inp_points,
inp_rotation,
inp_translation,
inp_background,
inp_weight,
points;
accumulate_ds_dpoints,
)

# batch of images
raster_pullback!(
ds_dout::AbstractArray{<:Number},
inp_points::AbstractVector{<:StaticVector{N_in, TP}},
inp_rotation::AbstractVector{<:StaticMatrix{N_out, N_in, TR}},
inp_translation::AbstractVector{<:StaticVector{N_out, TT}},
inp_background::AbstractVector{TB},
inp_weight::AbstractVector{TW};
points::AbstractArray{TP, 3} = similar(inp_points, TP, (N_in, length(inp_points), min(length(inp_rotation), Threads.nthreads()))),
rotation::AbstractArray{TR, 3} = similar(inp_rotation, TR, (N_out, N_in, length(inp_rotation))),
translation::AbstractMatrix{TT} = similar(inp_translation, TT, (N_out, length(inp_translation))),
background::AbstractVector{TB} = similar(inp_background),
weight::AbstractVector{TW} = similar(inp_weight),
) where {N_in, N_out, TP<:Number, TR<:Number, TT<:Number, TB<:Number, TW<:Number} = raster_pullback!(
ds_dout,
inp_points,
inp_rotation,
inp_translation,
inp_background,
inp_weight,
points,
rotation,
translation,
background,
weight,
)


###############################################
# Step 4: Error on inconsistent dimensions
###############################################

# single image
raster_pullback!(
::AbstractArray{<:Number, N_out},
::AbstractVector{<:StaticVector{N_in, <:Number}},
::StaticMatrix{N_out_rot, N_in_rot, <:Number},
::StaticVector{N_out_trans, <:Number},
::Number,
::Number,
::AbstractMatrix{<:Number};
kwargs...
) where {N_in, N_out, N_in_rot, N_out_rot, N_out_trans} = error_dimensions(
N_in,
N_out,
N_in_rot,
N_out_rot,
N_out_trans
)

# batch of images
raster_pullback!(
::AbstractArray{<:Number, N_out_p1},
::AbstractVector{<:StaticVector{N_in, <:Number}},
::AbstractVector{<:StaticMatrix{N_out_rot, N_in_rot, <:Number}},
::AbstractVector{<:StaticVector{N_out_trans, <:Number}},
::AbstractVector{<:Number},
::AbstractVector{<:Number},
::AbstractArray{<:Number, 3},
::AbstractArray{<:Number, 3},
::AbstractMatrix{<:Number},
::AbstractVector{<:Number},
::AbstractVector{<:Number},
) where {N_in, N_out_p1, N_in_rot, N_out_rot, N_out_trans} = error_dimensions(
N_in,
N_out_p1 - 1,
N_in_rot,
N_out_rot,
N_out_trans
)

function error_dimensions(N_in, N_out, N_in_rot, N_out_rot, N_out_trans)
if N_out_trans != N_out
error("Dimension of translation (got $N_out_trans) and output dimentsion (got $N_out) must agree!")
end
if N_out_rot != N_out
error("Row dimension of rotation (got $N_out_rot) and output dimentsion (got $N_out) must agree!")
end
if N_in_rot != N_in
error("Column dimension of rotation (got $N_in_rot) and points (got $N_in) must agree!")
end
error("Dispatch error. Should not arrive here. Please file a bug.")
end

default_background(rotation::AbstractMatrix, T=eltype(rotation)) = zero(T)

Expand Down
Loading

0 comments on commit e6279d4

Please sign in to comment.