Skip to content

Commit

Permalink
move part of getindex computation to compilation time
Browse files Browse the repository at this point in the history
Also add `laplacian` and `laplacian!`
  • Loading branch information
johnnychen94 committed Oct 12, 2021
1 parent e72cb9a commit 39058e6
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 90 deletions.
2 changes: 2 additions & 0 deletions src/ImageBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ export
fdiff!,
fdiv,
fdiv!,
flaplacian,
flaplacian!,
DiffView,

# basic image statistics, from Images.jl
Expand Down
159 changes: 88 additions & 71 deletions src/diff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,92 +3,88 @@ struct Periodic <: BoundaryCondition end
struct ZeroFill <: BoundaryCondition end

"""
DiffView(A::AbstractArray, [rev=Val(false)], [bc::BoundaryCondition=Periodic()]; dims)
DiffView(A::AbstractArray, dims::Val{D}, [bc::BoundaryCondition=Periodic()], [rev=Val(false)])
Lazy version of finite difference [`fdiff`](@ref).
!!! tip
For performance, `rev` should be stable type `Val(false)` or `Val(true)`.
For performance, both `dims` and `rev` require `Val` types.
# Arguments
- `dims::Val{D}`
Specify the dimension D that dinite difference is applied to.
- `rev::Bool`
If `rev==Val(true)`, then it computes the backward difference
`(A[end]-A[1], A[1]-A[2], ..., A[end-1]-A[end])`.
- `boundary::BoundaryCondition`
By default it computes periodically in the boundary, i.e., `Periodic()`.
In some cases, one can fill zero values with `ZeroFill()`.
"""
struct DiffView{T,N,AT<:AbstractArray,BC,REV} <: AbstractArray{T,N}
struct DiffView{T,N,D,BC,REV,AT<:AbstractArray} <: AbstractArray{T,N}
data::AT
dims::Int
end
function DiffView(
data::AbstractArray{T,N},
::Val{D},
bc::BoundaryCondition=Periodic(),
rev::Union{Val, Bool}=Val(false);
dims=_fdiff_default_dims(data)) where {T,N}
isnothing(dims) && throw(UndefKeywordError(:dims))
rev = to_static_bool(rev)
DiffView{maybe_floattype(T),N,typeof(data),typeof(bc),typeof(rev)}(data, dims)
end
function DiffView(
data::AbstractArray,
rev::Union{Val, Bool},
bc::BoundaryCondition = Periodic();
kwargs...)
DiffView(data, bc, rev; kwargs...)
end

to_static_bool(x::Union{Val{true},Val{false}}) = x
function to_static_bool(x::Bool)
@warn "Please use `Val($x)` for performance"
return Val(x)
rev::Val = Val(false)
) where {T,N,D}
DiffView{maybe_floattype(T),N,D,typeof(bc),typeof(rev),typeof(data)}(data)
end

Base.size(A::DiffView) = size(A.data)
Base.axes(A::DiffView) = axes(A.data)
Base.IndexStyle(::DiffView) = IndexCartesian()

Base.@propagate_inbounds function Base.getindex(A::DiffView{T,N,AT,Periodic,Val{true}}, I::Vararg{Int, N}) where {T,N,AT}
Base.@propagate_inbounds function Base.getindex(A::DiffView{T,N,D,Periodic,Val{true}}, I::Vararg{Int, N}) where {T,N,D}
data = A.data
I_prev = map(ntuple(identity, N), I, axes(data)) do i, p, r
i == A.dims || return p
p == first(r) && return last(r)
p - 1
end
r = axes(data, D)
x = I[D]
x_prev = first(r) == x ? last(r) : x - 1
I_prev = update_tuple(I, x_prev, Val(D))
return convert(T, data[I...]) - convert(T, data[I_prev...])
end
Base.@propagate_inbounds function Base.getindex(A::DiffView{T,N,AT,Periodic,Val{false}}, I::Vararg{Int, N}) where {T,N,AT}
Base.@propagate_inbounds function Base.getindex(A::DiffView{T,N,D,Periodic,Val{false}}, I::Vararg{Int, N}) where {T,N,D}
data = A.data
I_next = map(ntuple(identity, N), I, axes(data)) do i, p, r
i == A.dims || return p
p == last(r) && return first(r)
p + 1
end
r = axes(data, D)
x = I[D]
x_next = last(r) == x ? first(r) : x + 1
I_next = update_tuple(I, x_next, Val(D))
return convert(T, data[I_next...]) - convert(T, data[I...])
end
Base.@propagate_inbounds function Base.getindex(A::DiffView{T,N,AT,ZeroFill,Val{false}}, I::Vararg{Int, N}) where {T,N,AT}
Base.@propagate_inbounds function Base.getindex(A::DiffView{T,N,D,ZeroFill,Val{false}}, I::Vararg{Int, N}) where {T,N,D}
data = A.data
I_next = I .+ ntuple(i->i==A.dims, N)
if checkbounds(Bool, data, I_next...)
vi = convert(T, data[I...]) # it requires the caller to pass @inbounds
@inbounds convert(T, data[I_next...]) - vi
else
x = I[D]
if last(axes(data, D)) == x
zero(T)
else
I_next = update_tuple(I, x+1, Val(D))
convert(T, data[I_next...]) - convert(T, data[I...])
end
end
Base.@propagate_inbounds function Base.getindex(A::DiffView{T,N,AT,ZeroFill,Val{true}}, I::Vararg{Int, N}) where {T,N,AT}
Base.@propagate_inbounds function Base.getindex(A::DiffView{T,N,D,ZeroFill,Val{true}}, I::Vararg{Int, N}) where {T,N,D}
data = A.data
I_prev = I .- ntuple(i->i==A.dims, N)
if checkbounds(Bool, data, I_prev...)
vi = convert(T, data[I...]) # it requires the caller to pass @inbounds
@inbounds vi - convert(T, data[I_prev...])
else
x = I[D]
if first(axes(data, D)) == x
zero(T)
else
I_prev = update_tuple(I, x-1, Val(D))
convert(T, data[I...]) - convert(T, data[I_prev...])
end
end

@generated function update_tuple(A::NTuple{N, T}, x::T, ::Val{i}) where {T, N, i}
# This is equivalent to `ntuple(j->j==i ? x : A[j], N)` but is optimized by moving
# the if branches to compilation time.
ex = :()
for j in Base.OneTo(N)
new_x = i == j ? :(x) : :(A[$j])
ex = :($ex..., $new_x)
end
return ex
end

# TODO: add keyword `shrink` to give a consistant result on Base
# when this is done, then we can propose this change to upstream Base
"""
Expand Down Expand Up @@ -201,45 +197,66 @@ maybe_floattype(::Type{CT}) where CT<:Color = base_color_type(CT){maybe_floattyp


"""
fdiv(Vs::AbstractArray...; boundary=:periodic)
fdiv(Vs::AbstractArray...)
Discrete divergence operator for vector field (V₁, V₂, ..., Vₙ).
# Example
Laplacian operator of array `A` is the divergence of its gradient vector field (∂₁A, ∂₂A, ..., ∂ₙA):
```jldoctest
julia> using ImageFiltering, ImageBase
julia> X = Float32.(rand(1:9, 7, 7));
julia> laplacian(X) = fdiv(ntuple(i->DiffView(X, dims=i), ndims(X))...)
laplacian (generic function with 1 method)
julia> laplacian(X) == imfilter(X, Kernel.Laplacian(), "circular")
true
```
See also [`fdiv!`](@ref) for the in-place version.
"""
function fdiv(V₁::AbstractArray, Vs...; kwargs...)
fdiv!(similar(V₁, floattype(eltype(V₁))), V₁, Vs...; kwargs...)
end
fdiv(V₁::AbstractArray, Vs...) = fdiv!(similar(V₁, floattype(eltype(V₁))), V₁, Vs...)

"""
fdiv!(dst::AbstractArray, Vs::AbstractArray...)
The in-place version of [`fdiv`](@ref).
"""
function fdiv!(dst::AbstractArray, Vs::AbstractArray...)
= map(ntuple(identity, length(Vs)), Vs) do n, V
DiffView(V, Val(true), dims=n)
end
# negative adjoint of gradient is equivalent to the reversed finite difference
= fnegative_adjoint_gradient(Vs...)
@inbounds for i in CartesianIndices(dst)
dst[i] = sum(x->_inbound_getindex(x, i), ∇)
dst[i] = heterogeneous_getindex_sum(i, ∇...)
end
return dst
end

@inline _inbound_getindex(x, i) = @inbounds x[i]
@generated function heterogeneous_getindex_sum(i, Vs::Vararg{<:AbstractArray, N}) where N
# This method is equivalent to `sum(V->V[i], Vs)` but is optimized for heterogeneous arrays
ex = :(zero(eltype(Vs[1])))
for j in Base.OneTo(N)
ex = :($ex + Vs[$j][i])
end
return ex
end

"""
flaplacian(X::AbstractArray)
The Laplacian operator ∇² is the divergence of the gradient operator.
"""
flaplacian(X::AbstractArray) = flaplacian!(similar(X, maybe_floattype(eltype(X))), X)

"""
flaplacian!(dst::AbstractArray, X::AbstractArray)
The in-place version of the Laplacian operator [`laplacian`](@ref).
"""
flaplacian!(dst::AbstractArray, X::AbstractArray) = fdiv!(dst, fgradient(X)...)

# These two functions pass dimension information `Val(i)` to DiffView so that
# we can move computations to compilation time.
@generated function fgradient(X::AbstractArray{T, N}) where {T, N}
ex = :()
for i in Base.OneTo(N)
new_x = :(DiffView(X, Val($i), Periodic(), Val(false)))
ex = :($ex..., $new_x)
end
return ex
end
@generated function fnegative_adjoint_gradient(Vs::Vararg{<:AbstractArray, N}) where N
ex = :()
for i in Base.OneTo(N)
new_x = :(DiffView(Vs[$i], Val($i), Periodic(), Val(true)))
ex = :($ex..., $new_x)
end
return ex
end
31 changes: 12 additions & 19 deletions test/diff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -109,46 +109,39 @@
end

@testset "DiffView" begin
A = rand(6)
@test DiffView(A) ===
DiffView(A, Val(false), ImageBase.Periodic()) ===
DiffView(A, ImageBase.Periodic(), Val(false)) ===
@test_logs((:warn, "Please use `Val(false)` for performance"), DiffView(A, false))

for T in generate_test_types([N0f8, Float32], [Gray, RGB])
A = rand(T, 6)
Av = DiffView(A)
@test Av == DiffView(A, ImageBase.Periodic(), Val(false))
Av = DiffView(A, Val(1))
@test Av == DiffView(A, Val(1), ImageBase.Periodic(), Val(false))
@test eltype(Av) == floattype(T)
@test axes(Av) == axes(A)
@test Av == fdiff(A)
@test DiffView(A, Val(true)) == fdiff(A; rev=true)
@test DiffView(A, ImageBase.ZeroFill()) == fdiff(A; boundary=:zero)
@test DiffView(A, ImageBase.ZeroFill(), Val(true)) == fdiff(A; boundary=:zero, rev=true)
@test DiffView(A, Val(1), ImageBase.Periodic(), Val(true)) == fdiff(A; rev=true)
@test DiffView(A, Val(1), ImageBase.ZeroFill()) == fdiff(A; boundary=:zero)
@test DiffView(A, Val(1), ImageBase.ZeroFill(), Val(true)) == fdiff(A; boundary=:zero, rev=true)

A = rand(T, 6, 6)
Av = DiffView(A, dims=1)
Av = DiffView(A, Val(1))
@test eltype(Av) == floattype(T)
@test axes(Av) == axes(A)
@test Av == fdiff(A, dims=1)
@test DiffView(A, Val(true), dims=1) == fdiff(A; dims=1, rev=true)
@test DiffView(A, ImageBase.ZeroFill(), dims=1) == fdiff(A; boundary=:zero, dims=1)
@test DiffView(A, ImageBase.ZeroFill(), Val(true), dims=1) == fdiff(A; boundary=:zero, rev=true, dims=1)
@test DiffView(A, Val(1), ImageBase.Periodic(), Val(true)) == fdiff(A; dims=1, rev=true)
@test DiffView(A, Val(1), ImageBase.ZeroFill()) == fdiff(A; boundary=:zero, dims=1)
@test DiffView(A, Val(1), ImageBase.ZeroFill(), Val(true)) == fdiff(A; boundary=:zero, rev=true, dims=1)
end

A = OffsetArray(rand(6, 6), -1, -1)
Av = DiffView(A, dims=1)
Av = DiffView(A, Val(1))
@test axes(Av) == axes(A)
@test Av == fdiff(A, dims=1)
end

@testset "fdiv" begin
laplacian(X) = fdiv(ntuple(i->DiffView(X, dims=i), ndims(X))...)
@testset "fdiv/flaplacian" begin
ref_laplacian(X) = imfilter(X, Kernel.Laplacian(ntuple(x->true, ndims(X))), "circular")
for T in generate_test_types([N0f8, Float32], [Gray, RGB])
for sz in [(7,), (7, 7), (7, 7, 7)]
A = rand(T, sz...)
@test laplacian(A) ref_laplacian(A)
@test flaplacian(A) ref_laplacian(A)
end
end
end

0 comments on commit 39058e6

Please sign in to comment.