
Add lppool implementation #447

Merged: 12 commits merged into FluxML:master on Jan 7, 2023
Conversation

@skyleaworlder (Contributor) commented Dec 24, 2022

Addresses FluxML/Flux.jl#1431: LPPool1d and LPPool2d.

PR Checklist

Related tests and documentation will come soon.

  • Tests are added
  • Documentation, if applicable

@skyleaworlder marked this pull request as ready for review on December 26, 2022.
@mcabbott (Member) reviewed and left a comment:

Looks good, thanks! I have not checked the implementation closely but have a few comments in the meantime...

```julia
"""
    lppool(x, p::Number, k::NTuple; pad=0, stride=k)

Perform Lp pool operation with `p`-norm and window size `k` on input tensor `x`.
"""
```
A Member commented on the docstring above:
Can this explain a bit more?

  • It should say that it requires `ndims(x) == length(k)+2`.

  • I think it should also say `lppool(x, 1, k) ./ prod(k) ≈ meanpool(x, k)`, and maybe `lppool(x, 2, k).^2 ./ prod(k) ≈ meanpool(x.^2, k)` (see the sketch after this list).

  • What range of `p` is allowed? E.g. someone might expect `lppool(reshape(1:10.0,:,1,1), Inf, (3,))` to be `maxpool`, but it isn't... maybe anything which works for `norm` but not here should be an error?

  • Can it briefly say what types are allowed for `stride`, `pad`, or refer to fuller docs elsewhere?
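A minimal sketch of the two identities above, assuming the merged NNlib, where the function ended up named `lpnormpool` (it was still `lppool` at the time of this review):

```julia
using NNlib

x = rand(Float32, 8, 8, 3, 2)   # note ndims(x) == length(k) + 2
k = (2, 2)

# p = 1: for non-negative input, LP pooling is a scaled mean pool.
@assert NNlib.lpnormpool(x, 1, k) ./ prod(k) ≈ NNlib.meanpool(x, k)

# p = 2: squared LP pooling is a scaled mean pool of the squares.
@assert NNlib.lpnormpool(x, 2, k) .^ 2 ./ prod(k) ≈ NNlib.meanpool(x .^ 2, k)
```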

@ToucheSir (Member) commented Dec 27, 2022:
The PyTorch layer docstring would be a good inspiration for this.

@skyleaworlder (Contributor, Author) commented Dec 28, 2022:

> Can this explain a bit more?
>
> • It should say that it requires `ndims(x) == length(k)+2`.
> • I think it should also say `lppool(x, 1, k) ./ prod(k) ≈ meanpool(x, k)`, and maybe `lppool(x, 2, k).^2 ./ prod(k) ≈ meanpool(x.^2, k)`.
> • What range of `p` is allowed? E.g. someone might expect `lppool(reshape(1:10.0,:,1,1), Inf, (3,))` to be `maxpool`, but it isn't... maybe anything which works for `norm` but not here should be an error?
> • Can it briefly say what types are allowed for `stride`, `pad`, or refer to fuller docs elsewhere?

I added docs for these aspects in a new commit.

> The PyTorch layer docstring would be a good inspiration for this.

I wonder whether it's suitable to write very detailed docs here; Flux.jl will wrap this in a pooling layer after all.

A Member replied:

Our docstrings for meanpool and maxpool are quite short. Given that LP pooling is far less well known, however, it warrants a bit more explanation.

@mcabbott (Member) commented Dec 29, 2022:
Re the name: we have meanpool and maxpool, corresponding to mapping mean and maximum (which is reduce(max, ...)) over moving windows.

This one maps norm over a moving window. Shouldn't it be normpool?

I realise this breaks with PyTorch, but we aren't tied to copying them. Also, lppool is a weird name to read (how many p's in a row?). Obviously we can say "also known as ..." so that searching for variations of this name finds the right function.

@ToucheSir (Member) commented:

The PyTorch implementation uses "power-average pooling", so that's another option.

One concerning thing I did find while looking around for name ideas: nobody seems to be using this. GitHub search turns up only a couple of instances of LPPool2d that weren't just copies of the PyTorch repo, and I couldn't identify the source paper/presentation/document in which this operation was first introduced.

@skyleaworlder (Contributor, Author) commented:

> The PyTorch implementation uses "power-average pooling", so that's another option.
>
> One concerning thing I did find while looking around for name ideas: nobody seems to be using this. GitHub search turns up only a couple of instances of LPPool2d that weren't just copies of the PyTorch repo, and I couldn't identify the source paper/presentation/document in which this operation was first introduced.

The paper that proposed LPPool might be "Learned-Norm Pooling for Deep Feedforward and Recurrent Neural Networks". In this paper, they argue it's better to treat the $p$ of $L_p$ as a hyperparameter instead of a fixed value. As for the other concern, I must admit I haven't used this kind of pooling method either; it's mostly maxpool for me. Actually, I just noticed Flux.jl#1431 and felt like taking this on to learn about Flux.

@ToucheSir (Member) commented:

Nice find! Do you mind adding it as the primary reference in the docstring (see how this is done for Flux layers)? Also, is lpnormpool too much of a mouthful, and if so which parts of that name should stay since the paper uses both $L_p$ and "norm"?

@skyleaworlder (Contributor, Author) commented:

> Nice find! Do you mind adding it as the primary reference in the docstring (see how this is done for Flux layers)? Also, is lpnormpool too much of a mouthful, and if so which parts of that name should stay since the paper uses both $L_p$ and "norm"?

OK, I've changed the name in a recent commit and added the paper reference.

@mcabbott (Member) commented Jan 4, 2023:

The paper's title does just have "Norm Pooling", as if they felt squeezing $L_p$ in there was too much?

Also, LinearAlgebra's function could have been called lpnorm, but it isn't; they settled on just norm. That seems like the obvious thing to echo.
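A minimal sketch of that correspondence, assuming the merged `lpnormpool`: each output element is `LinearAlgebra.norm` applied to one pooling window.

```julia
using LinearAlgebra, NNlib

v = reshape(Float64.(1:10), :, 1, 1)          # (width, channels, batch)
y = NNlib.lpnormpool(v, 3, (3,); stride=1)

# The first output element is the 3-norm of the first window:
@assert y[1, 1, 1] ≈ norm(v[1:3, 1, 1], 3)
```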

@skyleaworlder (Contributor, Author) commented:

I think lppool is OK, or at least "lp" is needed, because PyTorch uses this name. If normpool is used as the name, users may spend time figuring out that normpool equals PyTorch's lppool.

@CarloLucibello (Member) commented:

Having the more articulate name lpnormpool for a rarely used operator seems fine.

@skyleaworlder requested a review from mcabbott on January 6, 2023.
@skyleaworlder (Contributor, Author) commented:

@mcabbott Thanks for your patient review! I resolved all the points mentioned above and re-requested a review just now :)

@mcabbott merged commit 16b7486 into FluxML:master on Jan 7, 2023.
@skyleaworlder (Contributor, Author) commented:

@mcabbott @ToucheSir @CarloLucibello Thanks again for your patience and kindness! I really learned a lot from this PR.

```diff
@@ -103,7 +109,7 @@ end
 # Finally, let's generate auto-allocating versions of all our functions, for all backends:
 for backend in (Symbol(), :_direct, :_nnpack)
     # First make auto-allocating versions of the basic pooling calls:
-    for name in (:maxpool, :meanpool)
+    for name in (:maxpool, :meanpool, :lpnormpool)
         @eval begin
             function $(Symbol("$(name)$(backend)"))(
```
A Member commented on this hunk:

Maybe worth noting that this `similar` means that lower-precision input gives the correct output without promotion, and that integer input fails here (unlike the other pooling functions):

```julia
julia> NNlib.lpnormpool(ones(Float32, 4,1,1), 2.001, (2,), stride=1)
3×1×1 Array{Float32, 3}:
[:, :, 1] =
 1.4139687
 1.4139687
 1.4139687

julia> NNlib.lpnormpool(ones(Int, 4,1,1), 2.001, (2,), stride=1)
ERROR: InexactError: Int64(1.4139686415190424)

julia> NNlib.maxpool(ones(Int, 4,1,1), (2,), stride=1)  # is OK with integers
3×1×1 Array{Int64, 3}:
[:, :, 1] =
 1
 1
 1

julia> NNlib.meanpool(ones(Int, 4,1,1), (2,), stride=1)  # is OK with integers
3×1×1 Array{Int64, 3}:
[:, :, 1] =
 1
 1
 1
```
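If integer input is ever needed, one possible workaround (a sketch, not something this PR adds) is to promote to floating point before pooling:

```julia
using NNlib

x_int = ones(Int, 4, 1, 1)
# float(x_int) is a Float64 array, so the fractional result can be stored:
NNlib.lpnormpool(float(x_int), 2.001, (2,), stride=1)
```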
