Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

introduce trilinear upsampling #315

Merged
merged 3 commits into from
May 4, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 153 additions & 1 deletion src/upsample.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
export upsample_nearest, ∇upsample_nearest,
upsample_bilinear, ∇upsample_bilinear,
upsample_trilinear, ∇upsample_trilinear,
pixel_shuffle

"""
Expand All @@ -9,7 +10,7 @@ export upsample_nearest, ∇upsample_nearest,
Upsamples the array `x` by integer multiples along the first `S` dimensions.
Subsequent dimensions of `x` are not altered.

Either the `scale` factors or the final output `size` can be specified.
Either the `scale` factors or the final output `size` can be specified.

See also [`upsample_bilinear`](@ref), for two dimensions of an `N=4` array.

Expand Down Expand Up @@ -257,6 +258,157 @@ function rrule(::typeof(upsample_bilinear), x; size)
return Ω, upsample_bilinear_pullback
end

###########
# trilinear
###########
"""
upsample_trilinear(x::AbstractArray{T,5}, scale::NTuple{3,Real})
upsample_trilinear(x::AbstractArray{T,5}; size::NTuple{3,Integer})

Upsamples the first 3 dimensions of the array `x` by the upsample factors stored in `scale`,
using trilinear interpolation. As an alternative to using `scale`, the resulting image `size`
can be directly specified with a keyword argument.

The size of the output is equal to
`(scale[1]*S1, scale[2]*S2, scale[3]*S3, S4, S5)`, where `S1, S2, S3, S4, S5 = size(x)`.

# Examples

```julia
upsample_trilinear(x, (2, 3, 4))
upsample_trilinear(x; size=(4, 9, 11)) # specify ouput size instead
upsample_trilinear(x, (2.5, 3.5, pi)) # non-integer scaling factors are allowed
```
"""
function upsample_trilinear(x::AbstractArray{<:Any,5}, scale::NTuple{3,Real})
outsize = ntuple(i -> floor(Int, scale[i] * Base.size(x, i)), 3)
return upsample_trilinear(x; size=outsize)
end

upsample_trilinear(x, scale::Real) = upsample_trilinear(x, (scale,scale,scale))

function upsample_trilinear(x::AbstractArray{T,5}; size::NTuple{3,Integer}) where T
w,h,d,c,n = Base.size(x)
if (w,h,d) == size
return x
end
y = similar(x, T, size..., c, n)
return upsample_trilinear_whdcn!(y, x)
end

function upsample_trilinear(x::AbstractArray{T,5}; size::NTuple{3,Integer}) where T<:Integer
y = float.(x)
res = upsample_trilinear(y; size=size)
return round.(T, res)
end

function upsample_trilinear_whdcn!(output::AbstractArray{T,5}, input::AbstractArray{T,5}) where T
size(input)[4:5] == size(output)[4:5] || error("Number of input and output channels and batches must match. Got input $(size(input)) and output $(size(output))")
in_w, in_h, in_d, channels, batches = size(input)
# treat batch and channel dimension as one for better parallelization granularity
channels *= batches
out_w, out_h, out_d, _, _ = size(output)
output_slice_size = out_h * out_w * out_d

# T() and // so that we can handle rationals (super slow)
width_scale = T((in_w - 1) // (out_w - 1))
height_scale = T((in_h - 1) // (out_h - 1))
depth_scale = T((in_d - 1) // (out_d - 1))

@inline idx(c, d, h, w) = c * in_d * in_h * in_w + d * in_h * in_w + h * in_w + w + 1

@inbounds Threads.@threads for c in 0:channels-1
for od in 0:out_d-1
id0, id1, d0lambda, d1lambda = compute_source_index_and_lambda(depth_scale, od, in_d, out_d)
for oh in 0:out_h-1
ih0, ih1, h0lambda, h1lambda = compute_source_index_and_lambda(height_scale, oh, in_h, out_h)
for ow in 0:out_w-1
iw0, iw1, w0lambda, w1lambda = compute_source_index_and_lambda(width_scale, ow, in_w, out_w)
output_offset = c * output_slice_size + od * out_w * out_h + oh * out_w + ow + 1
output[output_offset] =
d0lambda * h0lambda * w0lambda * input[idx(c, id0, ih0, iw0)] + # d0 * h0 * w0 * i000
d0lambda * h0lambda * w1lambda * input[idx(c, id0, ih0, iw1)] + # d0 * h0 * w1 * i001
d0lambda * h1lambda * w0lambda * input[idx(c, id0, ih1, iw0)] + # d0 * h1 * w0 * i010
d0lambda * h1lambda * w1lambda * input[idx(c, id0, ih1, iw1)] + # d0 * h1 * w1 * i011
d1lambda * h0lambda * w0lambda * input[idx(c, id1, ih0, iw0)] + # d1 * h0 * w0 * i100
d1lambda * h0lambda * w1lambda * input[idx(c, id1, ih0, iw1)] + # d1 * h0 * w1 * i101
d1lambda * h1lambda * w0lambda * input[idx(c, id1, ih1, iw0)] + # d1 * h1 * w0 * i110
d1lambda * h1lambda * w1lambda * input[idx(c, id1, ih1, iw1)] # d1 * h1 * w1 * i111
end
end
end
end
return output
end

"""
∇upsample_trilinear(Δ::AbstractArray{T,5}; size::NTuple{3,Integer}) where T

# Arguments
- `Δ`: Incoming gradient array, backpropagated from downstream layers
- `size`: Lateral size & depth (W,H,D) of the image upsampled in the first place

# Outputs
- `dx`: Downsampled version of `Δ`
"""
function ∇upsample_trilinear(Δ::AbstractArray{T,5}; size::NTuple{3,Integer}) where T
w, h, d, c, n = Base.size(Δ)
out_w, out_h, out_d = size
if (w,h,d) == (out_w, out_h, out_d)
return Δ
end
dx = zero(similar(Δ, T, size..., c, n))
return ∇upsample_trilinear_whdcn!(dx, Δ)
end

function ∇upsample_trilinear_whdcn!(dx::AbstractArray{T,5}, Δ::AbstractArray{T,5}) where T
size(dx)[4:5] == size(Δ)[4:5] || error("Number of input and output channels and batches must match. Got dx $(size(dx)) and Δ $(size(Δ))")
in_w, in_h, in_d, channels, batches = size(dx)
# treat batch and channel dimension as one for better parallelization granularity
channels *= batches
out_w, out_h, out_d, _, _ = size(Δ)
output_slice_size = out_h * out_w * out_d

# T() and // so that we can handle rationals (super slow)
width_scale = T((in_w - 1) // (out_w - 1))
height_scale = T((in_h - 1) // (out_h - 1))
depth_scale = T((in_d - 1) // (out_d - 1))

@inline idx(c, d, h, w) = c * in_d * in_h * in_w + d * in_h * in_w + h * in_w + w + 1

@inbounds Threads.@threads for c in 0:channels-1
for od in 0:out_d-1
id0, id1, d0lambda, d1lambda = compute_source_index_and_lambda(depth_scale, od, in_d, out_d)
for oh in 0:out_h-1
ih0, ih1, h0lambda, h1lambda = compute_source_index_and_lambda(height_scale, oh, in_h, out_h)
for ow in 0:out_w-1
iw0, iw1, w0lambda, w1lambda = compute_source_index_and_lambda(width_scale, ow, in_w, out_w)
output_offset = c * output_slice_size + od * out_w * out_h + oh * out_w + ow + 1
Δ_value = Δ[output_offset]
dx[idx(c, id0, ih0, iw0)] += d0lambda * h0lambda * w0lambda * Δ_value # /* i000 */
dx[idx(c, id0, ih0, iw1)] += d0lambda * h0lambda * w1lambda * Δ_value # /* i001 */
dx[idx(c, id0, ih1, iw0)] += d0lambda * h1lambda * w0lambda * Δ_value # /* i010 */
dx[idx(c, id0, ih1, iw1)] += d0lambda * h1lambda * w1lambda * Δ_value # /* i011 */
dx[idx(c, id1, ih0, iw0)] += d1lambda * h0lambda * w0lambda * Δ_value # /* i100 */
dx[idx(c, id1, ih0, iw1)] += d1lambda * h0lambda * w1lambda * Δ_value # /* i101 */
dx[idx(c, id1, ih1, iw0)] += d1lambda * h1lambda * w0lambda * Δ_value # /* i110 */
dx[idx(c, id1, ih1, iw1)] += d1lambda * h1lambda * w1lambda * Δ_value # /* i111 */
end
end
end
end
return dx
end

function rrule(::typeof(upsample_trilinear), x; size)
Ω = upsample_trilinear(x; size=size)
function upsample_trilinear_pullback(Δ)
(NO_FIELDS, ∇upsample_trilinear(Δ; size=(Base.size(x,1), Base.size(x,2), Base.size(x,3))))
end
return Ω, upsample_trilinear_pullback
end


"""
pixel_shuffle(x, r::Integer)

Expand Down
33 changes: 32 additions & 1 deletion test/upsample.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
y = upsample_nearest(x, (2,3))
@test size(y) == (4,6,1,1)
∇upsample_nearest(y, (2,3)) == [6 12; 18 24]

gradtest(x -> upsample_nearest(x, (2,3)), rand(2,2,1,1))

y2 = upsample_nearest(x, size=(4,6))
Expand Down Expand Up @@ -65,6 +65,37 @@ end
@test y == y_true_int
end

@testset "Trilinear upsampling" begin
# Layout: WHDCN, where D is depth
# we generate data which is constant along W & H and differs in D
# then we upsample along all dimensions
x = ones(Float32, 3,3,3,1,1)
x[:,:,1,:,:] .= 1.
x[:,:,2,:,:] .= 2.
x[:,:,3,:,:] .= 3.

y_true = ones(Float32, 5,5,5,1,1)
y_true[:,:,1,:,:] .= 1.
y_true[:,:,2,:,:] .= 1.5
y_true[:,:,3,:,:] .= 2.
y_true[:,:,4,:,:] .= 2.5
y_true[:,:,5,:,:] .= 3.

y = upsample_trilinear(x; size=(5,5,5))

@test size(y) == size(y_true)
@test eltype(y) == Float32
@test collect(y) ≈ collect(y_true)

# this test only works when align_corners=false (not present for CPU yet)
# o = ones(Float32,8,8,8,1,1)
# grad_true = 8*ones(Float32,4,4,4,1,1)
# @test ∇upsample_trilinear(o; size=(4,4,4)) ≈ grad_true

x = Float64.(x)
gradtest(x -> upsample_trilinear(x, (2,2,2)), x)
end

@testset "pixel_shuffle" begin
x = reshape(1:16, (2, 2, 4, 1))
# [:, :, 1, 1] =
Expand Down