Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Orthogonal initialization feature. #1496

Merged
merged 39 commits into from
Feb 11, 2021
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
c277919
Add Orthogonal initialization feature.
SomTambe Feb 3, 2021
8bca693
Make necessary changes.
SomTambe Feb 3, 2021
95f94d4
Make nice changes.
SomTambe Feb 3, 2021
338c378
Update docstring.
SomTambe Feb 3, 2021
082a971
Make citation better.
SomTambe Feb 3, 2021
44965f0
Replace mapreduce thing.
SomTambe Feb 3, 2021
4a3d12b
Minor docstring change.
SomTambe Feb 3, 2021
1feb19e
Add tests.
SomTambe Feb 3, 2021
b2bc5cc
Minor changes.
SomTambe Feb 3, 2021
7a84f42
Rectified silly mistakes.
SomTambe Feb 3, 2021
28d05df
Modified docstring a bit.
SomTambe Feb 3, 2021
a8b15d1
Minor change.
SomTambe Feb 3, 2021
090fd7e
Update src/utils.jl
SomTambe Feb 3, 2021
8af7659
Removed the unwanted example.
SomTambe Feb 3, 2021
2735e0c
Merge branch 'master' of https://github.com/SomTambe/Flux.jl
SomTambe Feb 3, 2021
64d2e66
dims::Integer to give better error messages.
SomTambe Feb 3, 2021
a0191a5
Update NEWS.md
SomTambe Feb 4, 2021
d542d70
Rectified mistake.
SomTambe Feb 4, 2021
9221196
Changed orthogonal to orthogonal_init
SomTambe Feb 4, 2021
943baf2
Change the docs a bit to see if doctesting works.
SomTambe Feb 4, 2021
644bfef
Minor docstring changes.
SomTambe Feb 4, 2021
da935bb
Trying to make the doctests work
SomTambe Feb 4, 2021
4b04fdd
slight change
SomTambe Feb 5, 2021
b28b8db
Change for dims > 2.
SomTambe Feb 6, 2021
57b9af3
Add tests for dims > 2.
SomTambe Feb 6, 2021
f897c75
Merge branch 'master' into master
SomTambe Feb 6, 2021
418b316
Changed structure. Also changed the documentation a bit.
SomTambe Feb 6, 2021
5e801e2
Merge pull request #1 from FluxML/master
SomTambe Feb 8, 2021
21cdfc8
Make necessary changes.
SomTambe Feb 8, 2021
3e749da
Add `orthogonal` to docs/src/utilities.md
SomTambe Feb 8, 2021
7a2b610
Update src/utils.jl
SomTambe Feb 8, 2021
23a9c5b
Update src/utils.jl
SomTambe Feb 8, 2021
691ca35
Update src/utils.jl
SomTambe Feb 8, 2021
786eb8e
Changed the docs a bit.
SomTambe Feb 8, 2021
54bf710
Update src/utils.jl
SomTambe Feb 9, 2021
a80bea9
Add `rng` which I had forgotten.
SomTambe Feb 9, 2021
c632cf2
Slight change
SomTambe Feb 9, 2021
2c9f4b8
modified tests a bit
SomTambe Feb 9, 2021
8f2e4ed
Rectified mistake.
SomTambe Feb 9, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 61 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,67 @@ end
kaiming_normal(dims...; kwargs...) = kaiming_normal(Random.GLOBAL_RNG, dims...; kwargs...)
kaiming_normal(rng::AbstractRNG; kwargs...) = (dims...; kwargs...) -> kaiming_normal(rng, dims...; kwargs...)

"""
orthogonal([rng=GLOBAL_RNG], dims...; gain = 1)

Return an `Array` of size `dims` which is a (semi) orthogonal matrix, as described in *Exact solutions to the nonlinear dynamics of learning in deep linear neural networks - Saxe, A. et al. (2013)*.
SomTambe marked this conversation as resolved.
Show resolved Hide resolved

The input tensor must have at least 2 dimensions.
SomTambe marked this conversation as resolved.
Show resolved Hide resolved

# Examples
```jldoctest; setup = :(using Random; Random.seed!(0));
julia> using LinearAlgebra
SomTambe marked this conversation as resolved.
Show resolved Hide resolved

julia> W = Flux.orthogonal(5, 7);

julia> summary(W)
"5×7 Array{Float32,2}"

julia> W * W'
5×5 Array{Float32,2}:
1.0 -2.42898f-8 6.32759f-8 -1.37195f-7 -2.19659f-8
-2.42898f-8 1.0 4.03295f-8 -1.34284f-7 1.06978f-7
6.32759f-8 4.03295f-8 1.0 7.93047f-8 2.6339f-7
-1.37195f-7 -1.34284f-7 7.93047f-8 1.0 6.60169f-8
-2.19659f-8 1.06978f-7 2.6339f-7 6.60169f-8 1.0
mcabbott marked this conversation as resolved.
Show resolved Hide resolved

julia> W * W' ≈ I(5)
true

julia> W2 = Flux.orthogonal(7, 5);

julia> W2 * W2' ≈ I(7)
false

julia> W2' * W2 ≈ I(5)
true
darsnack marked this conversation as resolved.
Show resolved Hide resolved
```

SomTambe marked this conversation as resolved.
Show resolved Hide resolved
# References
[1] Saxe, McClelland, Ganguli. "Exact solutions to the nonlinear dynamics of learning in deep linear neural networks", ICLR 2014, https://arxiv.org/abs/1312.6120

"""
function orthogonal(rng::AbstractRNG, dims...; gain = 1)
if length(dims) < 2
throw(ArgumentError("Only Arrays with 2 or more dimensions are supported"))
end

rows = dims[1]
cols = div(prod(dims),rows)
mat = rows > cols ? randn(Float32, rows, cols) : randn(Float32, cols, rows)

Q, R = LinearAlgebra.qr(mat)
Q = Array(Q) * sign.(LinearAlgebra.Diagonal(R))
if rows < cols
Q = transpose(Q)
end
Comment on lines +226 to +228
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another small one-liner trick and feel free to take any of it, or just ignore it.

Suggested change
if rows < cols
Q = transpose(Q)
end
Q = rows < cols ? transpose(Q) : Q

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I should keep my thing, looks more elegant 😄

Copy link
Member

@mcabbott mcabbott Feb 8, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the reason this strikes several of us as weird is partly that it's not type-stable to re-use Q, not just for different things, but for different types depending on the values of rows, cols. This isn't performance-critical code but that's where everyone's taste was honed.

Again, I would write

return rows > cols ? gain .* M : gain .* transpose(M)

where M is some name for the thing which isn't Q anymore, and the two branches match the branches which generate the random numbers above. They could both be written out on several lines, mat = if rows > cos; randn(... etc, but however they are written, I think they should put the then/else clauses in the same order.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah the mirrored if-else clause are a bit confusing. Should change that if nothing else.


return gain * reshape(Q, dims)
end

orthogonal(dims...; kwargs...) = orthogonal(Random.GLOBAL_RNG, dims...; kwargs...)
orthogonal(rng::AbstractRNG; kwargs...) = (dims...; kwargs...) -> orthogonal(rng, dims...; kwargs...)
SomTambe marked this conversation as resolved.
Show resolved Hide resolved

"""
sparse_init([rng=GLOBAL_RNG], dims...; sparsity, std = 0.01)

Expand Down
10 changes: 9 additions & 1 deletion test/utils.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using Flux
using Flux: throttle, nfan, glorot_uniform, glorot_normal, kaiming_normal, kaiming_uniform, sparse_init, stack, unstack, Zeros
using Flux: throttle, nfan, glorot_uniform, glorot_normal, kaiming_normal, kaiming_uniform, orthogonal, sparse_init, stack, unstack, Zeros
using StatsBase: var, std
using Random
using Test
Expand Down Expand Up @@ -96,6 +96,14 @@ end
end
end

@testset "orthogonal" begin
# A matrix of dim = (m,n) with m > n should produce a QR decomposition. In the other case, the transpose should be taken to compute the QR decomposition.
for (rows,cols) in [(5,3),(3,5)]
v = orthogonal(rows, cols)
rows < cols ? (@test v * v' ≈ I(rows)) : (@test v' * v ≈ I(cols))
end
end

@testset "sparse_init" begin
# sparse_init should yield an error for non 2-d dimensions
# sparse_init should yield no zero elements if sparsity < 0
Expand Down