Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add lppool implementation #447

Merged
merged 12 commits into from
Jan 7, 2023
3 changes: 2 additions & 1 deletion docs/src/reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,13 @@ logsoftmax

## Pooling

`Flux`'s `AdaptiveMaxPool`, `AdaptiveMeanPool`, `GlobalMaxPool`, `GlobalMeanPool`, `MaxPool`, and `MeanPool` use `NNlib.PoolDims`, `NNlib.maxpool`, and `NNlib.meanpool` as their backend.
`Flux`'s `AdaptiveMaxPool`, `AdaptiveMeanPool`, `GlobalMaxPool`, `GlobalMeanPool`, `MaxPool`, `MeanPool` and `LPPool` use `NNlib.PoolDims`, `NNlib.maxpool`, `NNlib.meanpool` and `NNlib.lppool` as their backend.

```@docs
PoolDims
maxpool
meanpool
lppool
```

## Padding
Expand Down
4 changes: 2 additions & 2 deletions src/NNlib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@ include("ctc.jl")
export ctc_loss

include("pooling.jl")
export maxpool, maxpool!, meanpool, meanpool!,
∇maxpool, ∇maxpool!, ∇meanpool, ∇meanpool!
export maxpool, maxpool!, meanpool, meanpool!, lppool, lppool!,
∇maxpool, ∇maxpool!, ∇meanpool, ∇meanpool!, ∇lppool, ∇lppool!

include("padding.jl")
export pad_constant, pad_repeat, pad_reflect, pad_zeros
Expand Down
38 changes: 32 additions & 6 deletions src/impl/pooling_direct.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
# Pooling is so similar, we abstract over meanpooling and maxpooling, simply replacing
# the inner loop operation and a few initialization parameters.
for name in (:max, :mean)
for name in (:max, :mean, :lp)
@eval function $((Symbol("$(name)pool_direct!")))(
y::AbstractArray{T, 5}, x::AbstractArray{T, 5},
pdims::PoolDims; alpha::T=T(1), beta::T=T(0)) where T
pdims::PoolDims; alpha::T=T(1), beta::T=T(0), kwargs...) where T
$((Symbol("$(name)pool_direct!")))(
y, x, pdims,
Val(kernel_size(pdims)), Val(channels_out(pdims)),
Val(padding(pdims)), Val(dilation(pdims)), Val(stride(pdims));
alpha, beta)
alpha, beta, kwargs...)
return y
end

Expand All @@ -17,7 +17,7 @@ for name in (:max, :mean)
pdims::PoolDims,
# kernel size, channels out, padding, dilation, stride
::Val{K}, ::Val{C}, ::Val{P}, ::Val{D}, ::Val{S};
alpha::T=T(1), beta::T=T(0),
alpha::T=T(1), beta::T=T(0), kwargs...
) where {T, K, C, P, D, S}
@assert beta == T(0) "beta not supported yet"
check_dims(size(x), size(y), pdims)
Expand All @@ -41,10 +41,15 @@ for name in (:max, :mean)
alpha = alpha / prod(K)
end

p = if $(name != :lp) 0 else
!haskey(kwargs, :p) && error("lppool needs keyword argument `p`")
kwargs[:p]
end

# Each loop, we initialize `m` to something, set that here.
m_init = if $(name == :max)
T <: AbstractFloat ? nextfloat(typemin(T)) : typemin(T)
elseif $(name == :mean)
elseif $(name == :mean) || $(name == :lp)
T(0)
else
error("Unimplemented codegen path")
Expand Down Expand Up @@ -78,11 +83,17 @@ for name in (:max, :mean)
end
elseif $(name == :mean)
m += x[input_kw, input_kh, input_kd, c, batch_idx]
elseif $(name == :lp)
# y = (∑ x^p)^(1/p), here to calculate (∑ x^p)
m += x[input_kw, input_kh, input_kd, c, batch_idx]^p
else
error("Unimplemented codegen path")
end
end

# for lppool, y = (∑ x^p)^(1/p)
m = $(name == :lp) ? m^(T(1) / p) : m

y[w, h, d, c, batch_idx] = alpha * m # + beta * y[w, h, d, c, batch_idx]
end
end
Expand Down Expand Up @@ -128,12 +139,15 @@ for name in (:max, :mean)
end
elseif $(name == :mean)
m += x[input_kw, input_kh, input_kd, c, batch_idx]
elseif $(name == :lp)
m += x[input_kw, input_kh, input_kd, c, batch_idx]^p
else
error("Unimplemented codegen path")
end
end
end
end
$(name == :lp) && (m = m^(T(1) / p))
y[w, h, d, c, batch_idx] = alpha * m # + beta * y[w, h, d, c, batch_idx]
end
end
Expand All @@ -159,7 +173,7 @@ for name in (:max, :mean)
dx::AbstractArray{T,5}, dy::AbstractArray{T,5},
y::AbstractArray{T,5}, x::AbstractArray{T,5},
pdims::PoolDims, ::Val{K}; # == kernel_size(pdims)
alpha::T=T(1), beta::T=T(0)) where {T, K}
alpha::T=T(1), beta::T=T(0), kwargs...) where {T, K}
check_dims(size(x), size(dy), pdims)

width, height, depth = input_size(pdims)
Expand All @@ -182,6 +196,11 @@ for name in (:max, :mean)
alpha = alpha / prod(K)
skyleaworlder marked this conversation as resolved.
Show resolved Hide resolved
end

p = if $(name != :lp) 0 else
!haskey(kwargs, :p) && error("lppool must pass p")
kwargs[:p]
end

# Start with the central region
w_region, h_region, d_region = central_region
@inbounds for batch_idx in 1:size(x, 5), c in 1:out_c
Expand Down Expand Up @@ -226,6 +245,10 @@ for name in (:max, :mean)
elseif $(name == :mean)
# Either does meanpool :(
dx[input_kw, input_kh, input_kd, c, batch_idx] += dy_idx * alpha
elseif $(name == :lp)
# y = (∑ x^p)^(1/p), ∂y/∂x = x^(p-1) × y^(1-p)
grad = x[input_kw, input_kh, input_kd, c, batch_idx]^(p-1) * y_idx^(1-p)
dx[input_kw, input_kh, input_kd, c, batch_idx] += dy_idx * grad
else
error("Unimplemented codegen path")
end
Expand Down Expand Up @@ -286,6 +309,9 @@ for name in (:max, :mean)
end
elseif $(name == :mean)
dx[input_kw, input_kh, input_kd, c, batch_idx] += dy_idx * alpha #+ beta * dx[x_idxs...]
elseif $(name == :lp)
grad = x[input_kw, input_kh, input_kd, c, batch_idx]^(p-1) * y_idx^(1-p)
dx[input_kw, input_kh, input_kd, c, batch_idx] += dy_idx * grad
else
error("Unimplemented codegen path")
end
Expand Down
74 changes: 69 additions & 5 deletions src/pooling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,15 @@
# - maxpool!(y, x, pdims)
# - meanpool(x, pdims)
# - meanpool!(y, x, pdims)
# - lppool(x, pdims)
# - lppool!(y, x, pdims)
# - Pooling input backprop
# - ∇maxpool(dy, y, x, pdims)
# - ∇maxpool!(dx, dy, y, x, pdims)
# - ∇meanpool(dy, y, x, pdims)
# - ∇meanpool!(dx, dy, y, x pdims)
# - ∇lppool(dy, y, x, pdims)
# - ∇lppool!(dx, dy, y, x pdims)
#
# All methods require a `PoolDims` object to define the dimensions and optional
# elements of the convolution (stride, dilation, etc...), which is easily constructable
Expand All @@ -26,6 +30,7 @@ for (front_name, backend) in (
# This maps from public, front-facing name, to internal backend name
:maxpool => :direct,
:meanpool => :direct,
:lppool => :direct,
)

# We only define 3d pooling primitives, we reshape lower down to get 1d and 2d pooling
Expand All @@ -42,6 +47,7 @@ end
for (front_name, backend) in (
:∇maxpool => :direct,
:∇meanpool => :direct,
:∇lppool => :direct,
)
@eval begin
function $(Symbol("$(front_name)!"))(
Expand All @@ -57,7 +63,7 @@ end
# Our strategy for pooling is to reshape to an array with three spatial dimensions, which
# makes things MUCH EASIER for us on the backend side, and is in general pretty fast,
# since we can specialize on sizes.
for front_name in (:maxpool, :meanpool)
for front_name in (:maxpool, :meanpool, :lppool)
for backend in (Symbol(), :_direct)
for N in (3, 4)
@eval begin
Expand Down Expand Up @@ -103,7 +109,7 @@ end
# Finally, let's generate auto-allocating versions of all our functions, for all backends:
for backend in (Symbol(), :_direct, :_nnpack)
# First make auto-allocating versions of the basic pooling calls:
for name in (:maxpool, :meanpool)
for name in (:maxpool, :meanpool, :lppool)
@eval begin
function $(Symbol("$(name)$(backend)"))(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe worth noting that this similar means that lower-precision input gives the correct output without promotion, and that integer input fails here (unlike other pooling):

julia> NNlib.lpnormpool(ones(Float32, 4,1,1), 2.001, (2,), stride=1)
3×1×1 Array{Float32, 3}:
[:, :, 1] =
 1.4139687
 1.4139687
 1.4139687

julia> NNlib.lpnormpool(ones(Int, 4,1,1), 2.001, (2,), stride=1)
ERROR: InexactError: Int64(1.4139686415190424)

julia> NNlib.maxpool(ones(Int, 4,1,1), (2,), stride=1)  # is OK with integers
3×1×1 Array{Int64, 3}:
[:, :, 1] =
 1
 1
 1

julia> NNlib.meanpool(ones(Int, 4,1,1), (2,), stride=1)  # is OK with integers
3×1×1 Array{Int64, 3}:
[:, :, 1] =
 1
 1
 1

x::AbstractArray{xT,N},
Expand Down Expand Up @@ -141,11 +147,19 @@ expand(N, i::Integer) = ntuple(_ -> i, N)


"""
maxpool(x, k::NTuple; pad=0, stride=k)
maxpool(x, k::NTuple{N, Integer}; pad=0, stride=k)

Perform max pool operation with window size `k` on input tensor `x`.

* `x` and `k`: Usually, ndim(x) ∈ [3, 5], length(k) ∈ [1, 3], s.t. ndim(x) == length(k) + 2
skyleaworlder marked this conversation as resolved.
Show resolved Hide resolved
* `pad`: See [`pad_zeros`](@ref) for details.
* `stride`: Stride for each spatial axis. `k` as default if not present.
skyleaworlder marked this conversation as resolved.
Show resolved Hide resolved
"""
function maxpool(x, k::NTuple{N, Integer}; pad=0, stride=k) where N
ndims(x) == length(k) + 2 || error("maxpool expects ndims(x) == length(k)+2,
dimension of x is $(ndims(x)),
length of k need $(ndims(x) - 2),
but now it's $(length(k))")
pad = expand(Val(N), pad)
stride = expand(Val(N), stride)
pdims = PoolDims(x, k; padding=pad, stride=stride)
Expand All @@ -154,19 +168,69 @@ end


"""
meanpool(x, k::NTuple; pad=0, stride=k)
meanpool(x, k::NTuple{N, Integer}; pad=0, stride=k)

Perform mean pool operation with window size `k` on input tensor `x`.

* `x` and `k`: Usually, ndim(x) ∈ [3, 5], length(k) ∈ [1, 3], s.t. ndim(x) == length(k) + 2
skyleaworlder marked this conversation as resolved.
Show resolved Hide resolved
* `pad`: See [`pad_zeros`](@ref) for details.
* `stride`: Stride for each spatial axis. `k` as default if not present.
skyleaworlder marked this conversation as resolved.
Show resolved Hide resolved
"""
function meanpool(x, k::NTuple{N, Integer}; pad=0, stride=k) where N
ndims(x) == length(k) + 2 || error("meanpool expects ndims(x) == length(k)+2,
dimension of x is $(ndims(x)),
length of k need $(ndims(x) - 2),
but now it's $(length(k))")
pad = expand(Val(N), pad)
stride = expand(Val(N), stride)
pdims = PoolDims(x, k; padding=pad, stride=stride)
return meanpool(x, pdims)
end


for pool in [:maxpool, :meanpool]
"""
lppool(x, p::Number, k::NTuple{N, Integer}; pad=0, stride=k)

Perform Lp pool operation with value of the Lp norm `p` and `window size `k` on input tensor `x`.

* `x` and `k`: Usually, ndim(x) ∈ [3, 5], length(k) ∈ [1, 3], s.t. ndim(x) == length(k) + 2
skyleaworlder marked this conversation as resolved.
Show resolved Hide resolved
* `pad`: See [`pad_zeros`](@ref) for details.
* `stride`: Stride for each spatial axis. `k` as default if not present.
skyleaworlder marked this conversation as resolved.
Show resolved Hide resolved

For each element `x` in (k × k) window, lppool computes `(∑ x^p)^(1 / p)` as output.

* When p = 1, lppool(x, p, k) ./ prod(k) ≈ meanpool(x, k)
* When p = 2, lppool(x, p, k).^2 ./ prod(k) ≈ meanpool(x.^2, k)

!!! warning

Theoretically, when `p -> ∞`, lppool(x, p, k) ≈ maxpool(x, k).
But it's not correct in julia. Given an arbitrary Number `n`,
```jldoctest
julia> n = 10
10

julia> ans^Inf
Inf

julia> ans^(1/Inf)
1.0
mcabbott marked this conversation as resolved.
Show resolved Hide resolved
```
Please use `meanpool` and `maxpool` directly when needed.
"""
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can this explain a bit more?

  • It should say that it requires ndims(x) == length(k)+2.

  • I think it should also say lppool(x, 1, k) ./ prod(k) ≈ meanpool(x, k), and maybe lppool(x, 2, k).^2 ./ prod(k) ≈ meanpool(x.^2, k).

  • What range of p is allowed? E.g. someone might expect lppool(reshape(1:10.0,:,1,1), Inf, (3,)) to be maxpool, but it isn't... maybe anything which works for norm but not here should be an error?

  • Can it briefly say what types are allowed for stride, pad, or refer to fuller docs elsewhere?

Copy link
Member

@ToucheSir ToucheSir Dec 27, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The PyTorch layer docstring would be a good inspiration for this.

Copy link
Contributor Author

@skyleaworlder skyleaworlder Dec 28, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can this explain a bit more?

  • It should say that it requires ndims(x) == length(k)+2.
  • I think it should also say lppool(x, 1, k) ./ prod(k) ≈ meanpool(x, k), and maybe lppool(x, 2, k).^2 ./ prod(k) ≈ meanpool(x.^2, k).
  • What range of p is allowed? E.g. someone might expect lppool(reshape(1:10.0,:,1,1), Inf, (3,)) to be maxpool, but it isn't... maybe anything which works for norm but not here should be an error?
  • Can it briefly say what types are allowed for stride, pad, or refer to fuller docs elsewhere?

I add docs for these aspects in new commit.

The PyTorch layer docstring would be a good inspiration for this.

I wonder if it might be suitable to write super detailed docs here. Flux.jl would wrap pooling layer after all.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Our docstrings for meanpool and maxpool are quite short. Given that LP-pooling is far less well known however, it warrants a bit more explanation.

function lppool(x, p::Number, k::NTuple{N, Integer}; pad=0, stride=k) where N
ndims(x) == length(k) + 2 || error("lppool expects ndims(x) == length(k)+2,
dimension of x is $(ndims(x)),
length of k need $(ndims(x) - 2),
but now it's $(length(k))")
pad = expand(Val(N), pad)
stride = expand(Val(N), stride)
pdims = PoolDims(x, k; padding=pad, stride=stride)
skyleaworlder marked this conversation as resolved.
Show resolved Hide resolved
return lppool(x, pdims; p=p)
end


for pool in [:maxpool, :meanpool, :lppool]
∇pool = Symbol(:∇, pool)
pullback = Symbol(pool, :_pullback)
@eval function rrule(::typeof($pool), x, pdims::PoolDims; kw...)
Expand Down
1 change: 1 addition & 0 deletions test/perf/perf_report.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ for rank in (2,),
for (pool, ∇pool, name) in (
(NNlib.maxpool!, NNlib.∇maxpool!, "maxpool"),
(NNlib.meanpool!, NNlib.∇meanpool!, "meanpool"),
(NNlib.lppool!, NNlib.∇lppool!, "lppool"),
)

t_fwd = @benchmark $(pool)( $y, $x, $pdims)
Expand Down
Loading