Add lppool implementation #447

Merged · 12 commits · Jan 7, 2023
4 changes: 2 additions & 2 deletions src/NNlib.jl
@@ -67,8 +67,8 @@ include("ctc.jl")
export ctc_loss

include("pooling.jl")
export maxpool, maxpool!, meanpool, meanpool!,
∇maxpool, ∇maxpool!, ∇meanpool, ∇meanpool!
export maxpool, maxpool!, meanpool, meanpool!, lppool, lppool!,
∇maxpool, ∇maxpool!, ∇meanpool, ∇meanpool!, ∇lppool, ∇lppool!

include("padding.jl")
export pad_constant, pad_repeat, pad_reflect, pad_zeros
40 changes: 33 additions & 7 deletions src/impl/pooling_direct.jl
@@ -1,14 +1,14 @@
# Pooling is so similar, we abstract over mean-, max-, and lp-pooling, simply replacing
# the inner loop operation and a few initialization parameters.
for name in (:max, :mean)
for name in (:max, :mean, :lp)
@eval function $((Symbol("$(name)pool_direct!")))(
y::AbstractArray{T, 5}, x::AbstractArray{T, 5},
pdims::PoolDims; alpha::T=T(1), beta::T=T(0)) where T
pdims::PoolDims; alpha::T=T(1), beta::T=T(0), kwargs...) where T
$((Symbol("$(name)pool_direct!")))(
y, x, pdims,
Val(kernel_size(pdims)), Val(channels_out(pdims)),
Val(padding(pdims)), Val(dilation(pdims)), Val(stride(pdims));
alpha, beta)
alpha, beta, kwargs...)
return y
end

@@ -17,7 +17,7 @@ for name in (:max, :mean)
pdims::PoolDims,
# kernel size, channels out, padding, dilation, stride
::Val{K}, ::Val{C}, ::Val{P}, ::Val{D}, ::Val{S};
alpha::T=T(1), beta::T=T(0),
alpha::T=T(1), beta::T=T(0), kwargs...
) where {T, K, C, P, D, S}
@assert beta == T(0) "beta not supported yet"
check_dims(size(x), size(y), pdims)
@@ -41,10 +41,15 @@ for name in (:max, :mean)
alpha = alpha / prod(K)
end

p = if $(name != :lp) 0 else
!haskey(kwargs, :p) && error("lppool must pass p")
kwargs[:p]
end

# Each loop, we initialize `m` to something, set that here.
m_init = if $(name == :max)
T <: AbstractFloat ? nextfloat(typemin(T)) : typemin(T)
elseif $(name == :mean)
elseif $(name == :mean) || $(name == :lp)
T(0)
else
error("Unimplemented codegen path")
@@ -78,11 +83,17 @@
end
elseif $(name == :mean)
m += x[input_kw, input_kh, input_kd, c, batch_idx]
elseif $(name == :lp)
# y = (∑ x^p)^(1/p); here we accumulate m = ∑ x^p
m += x[input_kw, input_kh, input_kd, c, batch_idx]^p
else
error("Unimplemented codegen path")
end
end

# for lppool, y = (∑ x^p)^(1/p)
m = $(name == :lp) ? m^(1 / p) : m

y[w, h, d, c, batch_idx] = alpha * m # + beta * y[w, h, d, c, batch_idx]
end
end
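To make the lp forward path concrete: for the first window of the 1-d strided test case below (inputs 1 and 2, p = 3), the inner loop accumulates m = 1^3 + 2^3 = 9 and the final step takes m^(1/p), which matches the first y_nostride entry in lppool_answer_dict. A standalone arithmetic check:

julia> p = 3.0; m = 1.0^p + 2.0^p
9.0

julia> m^(1/p)
2.080083823051904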
@@ -128,12 +139,15 @@
end
elseif $(name == :mean)
m += x[input_kw, input_kh, input_kd, c, batch_idx]
elseif $(name == :lp)
m += x[input_kw, input_kh, input_kd, c, batch_idx]^p
else
error("Unimplemented codegen path")
end
end
end
end
$(name == :lp) && (m = m^(1 / p))
y[w, h, d, c, batch_idx] = alpha * m # + beta * y[w, h, d, c, batch_idx]
end
end
@@ -159,7 +173,7 @@ for name in (:max, :mean)
dx::AbstractArray{T,5}, dy::AbstractArray{T,5},
y::AbstractArray{T,5}, x::AbstractArray{T,5},
pdims::PoolDims, ::Val{K}; # == kernel_size(pdims)
alpha::T=T(1), beta::T=T(0)) where {T, K}
alpha::T=T(1), beta::T=T(0), kwargs...) where {T, K}
check_dims(size(x), size(dy), pdims)

width, height, depth = input_size(pdims)
@@ -178,10 +192,15 @@

# If we're doing mean or lp pooling, we represent division by kernel size by rolling
# it into the `alpha` multiplier.
if $(name == :mean)
if $(name == :mean) || $(name == :lp)
alpha = alpha / prod(K)
end

p = if $(name != :lp) 0 else
!haskey(kwargs, :p) && error("lppool must pass p")
kwargs[:p]
end

# Start with the central region
w_region, h_region, d_region = central_region
@inbounds for batch_idx in 1:size(x, 5), c in 1:out_c
@@ -226,6 +245,10 @@
elseif $(name == :mean)
# Either does meanpool :(
dx[input_kw, input_kh, input_kd, c, batch_idx] += dy_idx * alpha
elseif $(name == :lp)
# y = (∑ x^p)^(1/p), ∂y/∂x = x^(p-1) × y^(1-p)
grad = x[input_kw, input_kh, input_kd, c, batch_idx]^(p-1) * y_idx^(1-p)
dx[input_kw, input_kh, input_kd, c, batch_idx] += dy_idx * grad
else
error("Unimplemented codegen path")
end
@@ -286,6 +309,9 @@
end
elseif $(name == :mean)
dx[input_kw, input_kh, input_kd, c, batch_idx] += dy_idx * alpha #+ beta * dx[x_idxs...]
elseif $(name == :lp)
grad = x[input_kw, input_kh, input_kd, c, batch_idx]^(p-1) * y_idx^(1-p)
dx[input_kw, input_kh, input_kd, c, batch_idx] += dy_idx * grad
else
error("Unimplemented codegen path")
end
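For reference, the gradient used in the lp branches above follows from differentiating y = (∑ x^p)^(1/p) with respect to one entry: ∂y/∂x = x^(p-1) * (∑ x^p)^(1/p - 1) = x^(p-1) * y^(1-p). A standalone finite-difference sanity check of that formula (a sketch independent of the PR code; all names are illustrative):

julia> x = rand(4); p = 2.5; y = sum(x .^ p)^(1/p);

julia> analytic = x .^ (p - 1) .* y^(1 - p);

julia> numeric = map(1:4) do j
           ε = 1e-7; xε = copy(x); xε[j] += ε
           (sum(xε .^ p)^(1/p) - y) / ε
       end;

julia> isapprox(analytic, numeric; rtol=1e-4, atol=1e-6)
true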
25 changes: 22 additions & 3 deletions src/pooling.jl
@@ -8,11 +8,15 @@
# - maxpool!(y, x, pdims)
# - meanpool(x, pdims)
# - meanpool!(y, x, pdims)
# - lppool(x, pdims)
# - lppool!(y, x, pdims)
# - Pooling input backprop
# - ∇maxpool(dy, y, x, pdims)
# - ∇maxpool!(dx, dy, y, x, pdims)
# - ∇meanpool(dy, y, x, pdims)
# - ∇meanpool!(dx, dy, y, x, pdims)
# - ∇lppool(dy, y, x, pdims)
# - ∇lppool!(dx, dy, y, x, pdims)
#
# All methods require a `PoolDims` object to define the dimensions and optional
# elements of the convolution (stride, dilation, etc...), which is easily constructable
@@ -26,6 +30,7 @@ for (front_name, backend) in (
# This maps from public, front-facing name, to internal backend name
:maxpool => :direct,
:meanpool => :direct,
:lppool => :direct,
)

# We only define 3d pooling primitives, we reshape lower down to get 1d and 2d pooling
@@ -42,6 +47,7 @@ end
for (front_name, backend) in (
:∇maxpool => :direct,
:∇meanpool => :direct,
:∇lppool => :direct,
)
@eval begin
function $(Symbol("$(front_name)!"))(
@@ -57,7 +63,7 @@ end
# Our strategy for pooling is to reshape to an array with three spatial dimensions, which
# makes things MUCH EASIER for us on the backend side, and is in general pretty fast,
# since we can specialize on sizes.
for front_name in (:maxpool, :meanpool)
for front_name in (:maxpool, :meanpool, :lppool)
for backend in (Symbol(), :_direct)
for N in (3, 4)
@eval begin
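An aside on the reshape strategy described in the comment above (a sketch, not code from this PR): a 2-d pooling input of size (W, H, C, N) is viewed as a 5-d array with a singleton third spatial dimension, so the 3-d primitives cover every rank.

julia> x = rand(Float32, 8, 8, 3, 2);  # 2-d pooling input: (W, H, C, N)

julia> x5 = reshape(x, size(x, 1), size(x, 2), 1, size(x, 3), size(x, 4));  # insert singleton depth

julia> size(x5)
(8, 8, 1, 3, 2)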
@@ -103,7 +109,7 @@ end
# Finally, let's generate auto-allocating versions of all our functions, for all backends:
for backend in (Symbol(), :_direct, :_nnpack)
# First make auto-allocating versions of the basic pooling calls:
for name in (:maxpool, :meanpool)
for name in (:maxpool, :meanpool, :lppool)
@eval begin
function $(Symbol("$(name)$(backend)"))(
Member comment:

Maybe worth noting that this `similar` means that lower-precision input gives the correct output without promotion, and that integer input fails here (unlike other pooling):

julia> NNlib.lpnormpool(ones(Float32, 4,1,1), 2.001, (2,), stride=1)
3×1×1 Array{Float32, 3}:
[:, :, 1] =
 1.4139687
 1.4139687
 1.4139687

julia> NNlib.lpnormpool(ones(Int, 4,1,1), 2.001, (2,), stride=1)
ERROR: InexactError: Int64(1.4139686415190424)

julia> NNlib.maxpool(ones(Int, 4,1,1), (2,), stride=1)  # is OK with integers
3×1×1 Array{Int64, 3}:
[:, :, 1] =
 1
 1
 1

julia> NNlib.meanpool(ones(Int, 4,1,1), (2,), stride=1)  # is OK with integers
3×1×1 Array{Int64, 3}:
[:, :, 1] =
 1
 1
 1
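
A possible workaround for the integer case, assuming the lppool(x, p, k; pad, stride) front-end added below: convert to floating point before pooling (a sketch, not verified against this exact commit; the expected value is 2^(1/2.001), as in the InexactError above).

julia> NNlib.lppool(Float64.(ones(Int, 4, 1, 1)), 2.001, (2,), stride=1)
3×1×1 Array{Float64, 3}:
[:, :, 1] =
 1.4139686415190424
 1.4139686415190424
 1.4139686415190424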

x::AbstractArray{xT,N},
@@ -166,7 +172,20 @@ function meanpool(x, k::NTuple{N, Integer}; pad=0, stride=k) where N
end


for pool in [:maxpool, :meanpool]
"""
lppool(x, p::Number, k::NTuple; pad=0, stride=k)

Perform Lp pooling with `p`-norm and window size `k` on the input tensor `x`.
"""
Member comment:

Can this explain a bit more?

  • It should say that it requires ndims(x) == length(k)+2.

  • I think it should also say lppool(x, 1, k) ./ prod(k) ≈ meanpool(x, k), and maybe lppool(x, 2, k).^2 ./ prod(k) ≈ meanpool(x.^2, k).

  • What range of p is allowed? E.g. someone might expect lppool(reshape(1:10.0,:,1,1), Inf, (3,)) to be maxpool, but it isn't... maybe anything which works for norm but not here should be an error?

  • Can it briefly say what types are allowed for stride, pad, or refer to fuller docs elsewhere?

@ToucheSir (Member), Dec 27, 2022:

The PyTorch layer docstring would be a good inspiration for this.

@skyleaworlder (Contributor, Author), Dec 28, 2022:

> Can this explain a bit more?
>
>   • It should say that it requires ndims(x) == length(k)+2.
>   • I think it should also say lppool(x, 1, k) ./ prod(k) ≈ meanpool(x, k), and maybe lppool(x, 2, k).^2 ./ prod(k) ≈ meanpool(x.^2, k).
>   • What range of p is allowed? E.g. someone might expect lppool(reshape(1:10.0,:,1,1), Inf, (3,)) to be maxpool, but it isn't... maybe anything which works for norm but not here should be an error?
>   • Can it briefly say what types are allowed for stride, pad, or refer to fuller docs elsewhere?

I added docs for these aspects in a new commit.

> The PyTorch layer docstring would be a good inspiration for this.

I wonder whether super-detailed docs are appropriate here; Flux.jl will wrap the pooling layer after all.

Member comment:

Our docstrings for meanpool and maxpool are quite short. Given that LP-pooling is far less well known, however, it warrants a bit more explanation.

function lppool(x, p::Number, k::NTuple{N, Integer}; pad=0, stride=k) where N
pad = expand(Val(N), pad)
stride = expand(Val(N), stride)
pdims = PoolDims(x, k; padding=pad, stride=stride)
return lppool(x, pdims; p=p)
end
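
Following the identities suggested in the review thread above, a quick REPL check (a sketch, assuming those relations hold for this implementation):

julia> using NNlib; x = rand(Float64, 6, 1, 1);

julia> lppool(x, 1, (2,)) ./ prod((2,)) ≈ meanpool(x, (2,))
true

julia> lppool(x, 2, (2,)) .^ 2 ./ prod((2,)) ≈ meanpool(x .^ 2, (2,))
true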


for pool in [:maxpool, :meanpool, :lppool]
∇pool = Symbol(:∇, pool)
pullback = Symbol(pool, :_pullback)
@eval function rrule(::typeof($pool), x, pdims::PoolDims; kw...)
1 change: 1 addition & 0 deletions test/perf/perf_report.jl
@@ -93,6 +93,7 @@ for rank in (2,),
for (pool, ∇pool, name) in (
(NNlib.maxpool!, NNlib.∇maxpool!, "maxpool"),
(NNlib.meanpool!, NNlib.∇meanpool!, "meanpool"),
(NNlib.lppool!, NNlib.∇lppool!, "lppool"),
)

t_fwd = @benchmark $(pool)( $y, $x, $pdims)
106 changes: 106 additions & 0 deletions test/pooling.jl
@@ -248,6 +248,70 @@ meanpool_answer_dict = Dict(
)
)

lppool_answer_dict = Dict(
1 => Dict(
"y" => [2.019312856150994, 4.221163518110637],
"y_nostride" => [
2.080083823051904, 3.2710663101885897,
4.497941445275415, 5.738793548317167
],
"y_pad" => [1.0, 3.605551275463989, 6.4031242374328485],
"dx" => [
0.17258020254042603, 1.9525221042381296,
1.2774501198988355, 3.496467771732918, 0.0
],
"dx_nostride" => [
0.48074985676913606, 3.1458422620080637,
4.752311710531486, 6.345225258061685, 4.356316321455918
],
"dx_pad" => [1.0, 2.0, 3.0, 4.0, 5.0],
"p" => 4.5,
"p_nostride" => 3.0,
"p_pad" => 2.0
),
2 => Dict(
"y" => [
8.71909 24.9703;
11.7336 28.3804
],
"y_nostride" => [
11.1128 23.134 35.5704;
13.4219 25.6082 38.0707;
15.8033 28.0907 40.5735;
18.2249 30.5795 43.0782
],
"y_pad" => [
1.0 11.3616 16.0;
3.19158 15.9662 21.3545;
5.56869 18.7771 23.7903
],
"dx" => [
0.33866 4.97727 7.30092 12.8076;
0.957876 6.27208 8.31879 14.0269;
1.51693 6.6057 8.79844 14.3351;
2.33547 7.8822 9.83293 15.5461;
0.0 0.0 0.0 0.0
],
"dx_nostride" => [
3.33359 19.9471 35.7329 23.8564;
9.89551 44.627 76.2257 50.0307;
13.231 50.9101 82.5686 53.2022;
16.4888 57.223 88.9133 56.3742;
9.54591 30.9869 46.8371 29.3524
],
"dx_pad" => [
1.0 2.30261 10.4791 16.0;
0.992125 2.0321 7.81903 12.075;
2.73398 2.83743 9.5512 13.9299;
2.43512 2.98652 9.0132 13.5608;
4.25398 3.8865 10.7099 15.4161
],
"p" => 2.5,
"p_nostride" => 1.5,
"p_pad" => 3.5
)
)

for rank in (1, 2, 3)
@testset "pool$(rank)d" begin
for (pool, ∇pool, answer_dict) in (
@@ -297,6 +361,48 @@
end
end

for rank in (1, 2)
for (pool, ∇pool, answer_dict) in (
(lppool, ∇lppool, lppool_answer_dict),
(NNlib.lppool_direct, NNlib.∇lppool_direct, lppool_answer_dict),)
@testset "$(pool)$(rank)d" begin
y = answer_dict[rank]["y"]
y_nostride = answer_dict[rank]["y_nostride"]
y_pad = answer_dict[rank]["y_pad"]
dx = answer_dict[rank]["dx"]
dx_nostride = answer_dict[rank]["dx_nostride"]
dx_pad = answer_dict[rank]["dx_pad"]
p = answer_dict[rank]["p"]
p_nostride = answer_dict[rank]["p_nostride"]
p_pad = answer_dict[rank]["p_pad"]

x = reshape(Float64[1:prod(size(dx));], size(dx)..., 1, 1)

ddims(x) = dropdims(x, dims=(rank + 1, rank + 2))

@test pool(x, PoolDims(x, 1); p=p) ≈ x atol = 1e-3

# Test vanilla pooling
pdims = PoolDims(x, 2)
y_hat = pool(x, pdims; p=p)
@test ddims(y_hat) ≈ y atol = 1e-3
@test ddims(∇pool(y_hat, y_hat, x, pdims; p=p)) ≈ dx atol = 1e-3

# Strided pooling
pdims = PoolDims(x, 2; stride=1)
y_hat = pool(x, pdims; p=p_nostride)
@test ddims(y_hat) ≈ y_nostride atol = 1e-3
@test ddims(∇pool(y_hat, y_hat, x, pdims; p=p_nostride)) ≈ dx_nostride atol = 1e-3

# Padded pooling
pdims = PoolDims(x, 2; padding=1)
y_hat = pool(x, pdims; p=p_pad)
@test ddims(y_hat) ≈ y_pad atol = 1e-3
@test ddims(∇pool(y_hat, y_hat, x, pdims; p=p_pad)) ≈ dx_pad atol = 1e-3
end
end
end

@testset "Pooling - Check Sizes" begin
x = rand(10, 10, 3, 10)
@test size(maxpool(x, (2, 2))) == (5, 5, 3, 10)