
add kaiming initialization and relevant docstrings #1243

Merged · 2 commits · Jun 30, 2020
Conversation

johnnychen94 (Contributor) commented Jun 20, 2020

This is an updated version of #425. Distributions is not used because it always generates Array{Float64, N} instead of Array{Float32, N}.

  • also updates the docstrings of the glorot initializers
  • adds a method nfan(::Tuple) for robustness; otherwise nfan((100, 400)) would return (1, (100, 400)), which is incorrect.

These methods are not exported, since the glorot_* initializers aren't, either.

If this gets merged, we could switch Conv from glorot_uniform to kaiming_uniform, since relu is the most common activation nowadays, but that belongs in another PR.

closes #425 closes #424

Co-authored-by: Aniket Das <aniketd@iitk.ac.in>
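
For context, a minimal sketch of the approach described above, in Julia; the names follow the PR but the details are illustrative and need not match the merged code exactly:

```julia
# Illustrative sketch (not necessarily the merged code). nfan returns
# (fan_in, fan_out); the Tuple method makes nfan((100, 400)) == nfan(100, 400).
nfan() = 1, 1                                   # scalar
nfan(n) = 1, n                                  # vector
nfan(n_out, n_in) = n_in, n_out                 # dense weight matrix
nfan(dims...) = prod(dims[1:end-2]) .* (dims[end-1], dims[end])  # conv kernel
nfan(dims::Tuple) = nfan(dims...)

# Kaiming uniform: bound chosen so that std = gain / sqrt(fan_in),
# producing Float32 arrays rather than Distributions' Float64.
function kaiming_uniform(dims...; gain = sqrt(2f0))
  bound = Float32(sqrt(3f0) * gain / sqrt(first(nfan(dims...))))
  return (rand(Float32, dims...) .- 0.5f0) .* 2bound
end

# Kaiming normal: std = gain / sqrt(fan_in).
function kaiming_normal(dims...; gain = sqrt(2f0))
  std = Float32(gain / sqrt(first(nfan(dims...))))
  return randn(Float32, dims...) .* std
end
```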

PR Checklist

  • Tests are added
  • Entry in NEWS.md
  • Documentation, if applicable
  • Final review from @MikeInnes or @dhairyagandhi96 (for API changes).

johnnychen94 changed the title from "add kaiming initialization" to "add kaiming initialization and relevant docstrings" on Jun 22, 2020
CarloLucibello (Member) commented Jun 23, 2020

For future reference:

  • Keras' default is glorot uniform on [-limit, limit], where limit = sqrt(6 / (fan_in + fan_out)) (https://keras.io/api/layers/initializers/). This implies std = sqrt(2 / (fan_in + fan_out)), which should be the correct value for identity activations when fan_in = fan_out (by a variance-propagation argument).
  • Pytorch's default instead is uniform on [-limit, limit], where limit = sqrt(1 / fan_in) (see "nn.linear module weight initialization fix", pytorch/pytorch#19526 (comment)). I don't have a mathematical explanation for Pytorch's default; I think it is there for backward compatibility, perhaps an initialization suggested by a LeCun paper from the '90s, though I don't remember which one.

So neither library uses the ReLU rescaling factor sqrt(2) by default. Let's keep this in mind if we are going to change the default init.
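
A quick check of the relation between the Keras limit and the implied std, using the fact that Uniform(-a, a) has variance a^2 / 3 (the fan values are illustrative):

```julia
fan_in, fan_out = 100, 400                # illustrative values
limit = sqrt(6 / (fan_in + fan_out))      # Keras' glorot_uniform bound
std = limit / sqrt(3)                     # std of Uniform(-limit, limit)
std ≈ sqrt(2 / (fan_in + fan_out))        # true
```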

johnnychen94 (Contributor, Author) commented:

bump

DhairyaLGandhi (Member) commented:

LGTM, although the references to the other functions are a bit verbose. We should also add a section on initialisation to the docs (in a future PR).

Thanks @johnnychen94

DhairyaLGandhi (Member) commented:

We should add a bit to the docs about how to use these initializations in regular layers.
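
For instance, a docs example could look like this (assuming the initW keyword that Dense accepted at the time; kaiming_uniform is unexported, hence the Flux. prefix):

```julia
using Flux

# Pass an unexported initializer through the layer's weight-init keyword.
m = Dense(784, 256, relu; initW = Flux.kaiming_uniform)
```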

DhairyaLGandhi (Member) commented Jun 28, 2020

This might be a future PR too, but could we consider taking T as user input? Although we default to Float32, there are times when higher (or lower) precision may be preferred.
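
A hypothetical sketch of such an API (the Type-first method and its fallback are illustrative, not part of this PR):

```julia
# Hypothetical: accept the element type as a leading argument,
# keeping Float32 as the default for backward compatibility.
function kaiming_uniform(::Type{T}, dims...; gain = sqrt(T(2))) where {T<:AbstractFloat}
  bound = T(sqrt(T(3)) * gain / sqrt(first(nfan(dims...))))
  return (rand(T, dims...) .- T(0.5)) .* 2bound
end
kaiming_uniform(dims...; kwargs...) = kaiming_uniform(Float32, dims...; kwargs...)
```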

johnnychen94 and others added 2 commits June 29, 2020 07:51
* updates the docstring of glorot initialization
* add a method for nfan(::Tuple) for robustness; otherwise nfan((100, 400)) would return (1, (100, 400)), which isn't correct.

Co-authored-by: Aniket Das <aniketd@iitk.ac.in>
Co-authored-by: CarloLucibello <carlo.lucibello@gmail.com>
johnnychen94 (Contributor, Author) commented Jun 28, 2020

Rebased commits with no content changes.

CarloLucibello (Member) commented:

bors r+

bors bot (Contributor) commented Jun 30, 2020

Build succeeded.

bors bot merged commit 469ceca into FluxML:master on Jun 30, 2020
johnnychen94 deleted the jc/kaiming branch on June 30, 2020
johnnychen94 (Contributor, Author) commented Jul 28, 2020

> Pytorch's default instead is uniform on [-limit, limit], where limit = sqrt(1 / fan_in) (pytorch/pytorch#19526 (comment)).

1.7.0-dev: https://github.com/pytorch/pytorch/blob/4f723825b48e555512813fcffd56e89e5b16eeaf/torch/nn/modules/conv.py#L85-L90

Pytorch's default is kaiming_uniform with a = sqrt(5) for weights and a uniform distribution for bias.
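
For the record, a = sqrt(5) recovers the sqrt(1 / fan_in) bound quoted above: kaiming_uniform uses the leaky_relu gain sqrt(2 / (1 + a^2)) and sets bound = gain * sqrt(3 / fan_in). A quick check (the fan_in value is illustrative):

```julia
fan_in = 144                       # illustrative, e.g. a 3×3×16 conv kernel
a = sqrt(5)
gain = sqrt(2 / (1 + a^2))         # leaky_relu gain with negative_slope = a
bound = gain * sqrt(3 / fan_in)    # kaiming_uniform's uniform bound
bound ≈ sqrt(1 / fan_in)           # true
```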

Linked issue: Support for Kaiming Initialization