Skip to content

Commit

Permalink
Test with Yota, too (#105)
Browse files Browse the repository at this point in the history
* test with Yota too, and document this

* also test destructure

* actually try out the doc examples

* tidy, add summarysize

* add again changes made on website which got lost in a local rebase without checking first because I forgot about this for ages

* Yota 0.8.2, etc

* skip Yota tests on 1.9 & later

* skip more tests
  • Loading branch information
mcabbott authored Dec 8, 2022
1 parent acc9b16 commit 79269be
Show file tree
Hide file tree
Showing 5 changed files with 142 additions and 17 deletions.
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,15 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
[compat]
ChainRulesCore = "1"
Functors = "0.3, 0.4"
Yota = "0.8.2"
Zygote = "0.6.40"
julia = "1.6"

[extras]
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Yota = "cd998857-8626-517d-b929-70ad188a48f0"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Test", "StaticArrays", "Zygote"]
test = ["Test", "StaticArrays", "Yota", "Zygote"]
76 changes: 61 additions & 15 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ to adjust the model:

```julia

using Flux, Metalhead, Optimisers
using Flux, Metalhead, Zygote, Optimisers

model = Metalhead.ResNet(18) |> gpu # define a model to train
image = rand(Float32, 224, 224, 3, 1) |> gpu; # dummy data
Expand All @@ -52,7 +52,7 @@ state = Optimisers.setup(rule, model); # initialise this optimiser's momentum e
end;

state, model = Optimisers.update(state, model, ∇model);
@show sum(model(image));
@show sum(model(image)); # reduced

```

Expand All @@ -62,8 +62,14 @@ tree formed by the model and update the parameters using the gradients.

There is also [`Optimisers.update!`](@ref) which similarly returns a new model and new state,
but is free to mutate arrays within the old one for efficiency.
The method of `apply!` for each rule is likewise free to mutate arrays within its state;
they are defensively copied when this rule is used with `update`.
(The method of `apply!` above is likewise free to mutate arrays within its state;
they are defensively copied when this rule is used with `update`.)
For `Adam()`, there are two momenta per parameter, thus `state` is about twice the size of `model`:

```julia
Base.summarysize(model) / 1024^2 # about 45MB
Base.summarysize(state) / 1024^2 # about 90MB
```

Optimisers.jl does not depend on any one automatic differentiation package,
but for now the most likely source of gradients is [Zygote.jl](https://fluxml.ai/Zygote.jl).
Expand All @@ -72,14 +78,34 @@ This `∇model` is another tree structure, rather than the dictionary-like objec
Zygote's "implicit" mode `gradient(() -> loss(...), Flux.params(model))` -- see
[Zygote's documentation](https://fluxml.ai/Zygote.jl/dev/#Explicit-and-Implicit-Parameters-1) for more about this difference.


## Usage with [Yota.jl](https://github.com/dfdx/Yota.jl)

Yota is another modern automatic differentiation package, an alternative to Zygote.

Its main function is `Yota.grad`, which returns the loss as well as the gradient (like `Zygote.withgradient`)
but also returns a gradient component for the loss function.
To extract what Optimisers.jl needs, you can write (for the Flux model above):

```julia
using Yota

loss, (∇function, ∇model, ∇image) = Yota.grad(model, image) do m, x
sum(m(x)
end;

# Or else, this may save computing ∇image:
loss, (_, ∇model) = grad(m -> sum(m(image)), model);
```
## Usage with [Lux.jl](https://github.com/avik-pal/Lux.jl)
The main design difference of Lux is that the tree of parameters is separate from
The main design difference of Lux from Flux is that the tree of parameters is separate from
the layer structure. It is these parameters which `setup` and `update` need to know about.
Lux describes this separation of parameter storage from model description as "explicit" parameters.
Beware that it has nothing to do with Zygote's notion of "explicit" gradients.
(If the same model is written in Flux and Lux, `∇model` above and `∇params` below will often be
(If the same model is written in Flux and Lux, `∇model` above and `∇params` below will be nearly
identical trees of nested `NamedTuple`s.)
```julia
Expand All @@ -88,27 +114,47 @@ using Lux, Boltz, Zygote, Optimisers

lux_model, params, lux_state = Boltz.resnet(:resnet18) |> gpu; # define and initialise model
images = rand(Float32, 224, 224, 3, 4) |> gpu; # batch of dummy data
y, _ = Lux.apply(lux_model, images, params, lux_state); # run the model
@show sum(y) # initial dummy loss
y, lux_state = Lux.apply(lux_model, images, params, lux_state); # run the model
@show sum(y); # initial dummy loss

rule = Optimisers.Adam()
opt_state = Optimisers.setup(rule, params); # optimiser state based on model parameters

∇params, _ = gradient(params, images) do p, x # gradient with respect to parameter tree
y, _ = Lux.apply(lux_model, x, p, lux_state)
sum(y)
(loss, lux_state), back = Zygote.pullback(params, images) do p, x
y, st = Lux.apply(lux_model, x, p, lux_state)
sum(y), st # return both the loss, and the updated lux_state
end;
∇params, _ = back((one.(loss), nothing)); # gradient of only the loss, with respect to parameter tree
loss == sum(y) # not yet changed

opt_state, params = Optimisers.update!(opt_state, params, ∇params);

y, _ = Lux.apply(lux_model, images, params, lux_state);
@show sum(y)
y, lux_state = Lux.apply(lux_model, images, params, lux_state);
@show sum(y); # now reduced

```
Besides the parameters stored in `params` and gradually optimised, any other model state
is stored in `lux_state`. For simplicity this example does not show how to propagate the
updated `lux_state` to the next iteration, see Lux's documentation.
is stored in `lux_state`, and updated by `Lux.apply`. (In this example, BatchNorm has state.)
This is completely unrelated to Optimisers.jl's state, although designed in a similar spirit.
```julia
Base.summarysize(lux_model) / 1024 # just 2KB
Base.summarysize(params) / 1024^2 # about 45MB, same as Flux model
Base.summarysize(lux_state) / 1024 # 40KB
Base.summarysize(opt_state) / 1024^2 # about 90MB, with Adam
```
If you are certain there is no model state, then the gradient calculation can
be simplified to use `Zygote.gradient` instead of `Zygote.pullback`:
```julia
∇params, _ = gradient(params, images) do p, x
y, _ = Lux.apply(lux_model, x, p, lux_state) # discards new lux_state
sum(y)
end;
```
## Non-`trainable` Parameters
Expand Down
55 changes: 55 additions & 0 deletions test/destructure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,31 @@ end
# Zygote error in (::typeof(∂(canonicalize)))(Δ::NamedTuple{(:backing,), Tuple{NamedTuple{(:x, :y, :z)
# Diffractor error in perform_optic_transform
end

VERSION < v"1.9-" && @testset "using Yota" begin
@test Yota_gradient(m -> destructure(m)[1][1], m1)[1] == [1,0,0]
@test Yota_gradient(m -> destructure(m)[1][2], m2)[1] == ([0,1,0], [0,0,0])
@test Yota_gradient(m -> destructure(m)[1][3], (m1, m1))[1] == ([0,0,1], nothing)
@test Yota_gradient(m -> destructure(m)[1][1], m3)[1] == (x = [1,0,0], y = nothing, z = [0,0,0])
@test Yota_gradient(m -> destructure(m)[1][2], m4)[1] == (x = [0,1,0], y = nothing, z = [0,0,0])

g5 = Yota_gradient(m -> destructure(m)[1][3], m5)[1]
@test g5.a[1].x == [0,0,1]
@test g5.a[2] === nothing

g6 = Yota_gradient(m -> imag(destructure(m)[1][4]), m6)[1]
@test g6.a == [0,0,0]
@test g6.a isa Vector{Float64}
@test g6.b == [0+im]

g8 = Yota_gradient(m -> sum(abs2, destructure(m)[1]), m8)[1]
@test g8[1].x == [2,4,6]
@test g8[2].b.x == [8]
@test g8[3] == [[10.0]]

g9 = Yota_gradient(m -> sum(sqrt, destructure(m)[1]), m9)[1]
@test g9.c === nothing
end
end

@testset "gradient of rebuild" begin
Expand Down Expand Up @@ -149,6 +174,36 @@ end
# Not fixed by this:
# Zygote.@adjoint Tangent{T,B}(x::NamedTuple) where {T,B<:NamedTuple} = Tangent{T,B}(x), dx -> (dx,)
end

VERSION < v"1.9-" && @testset "using Yota" begin
re1 = destructure(m1)[2]
@test Yota_gradient(x -> re1(x)[1], rand(3))[1] == [1,0,0]
re2 = destructure(m2)[2]
@test Yota_gradient(x -> re2(x)[1][2], rand(6))[1] == [0,1,0,0,0,0]
re3 = destructure(m3)[2]
@test Yota_gradient(x -> re3(x).x[3], rand(6))[1] == [0,0,1,0,0,0]
@test Yota_gradient(x -> re3(x).z[1], rand(6))[1] == [0,0,0,1,0,0]

re4 = destructure(m4)[2]
@test Yota_gradient(x -> re4(x).x[1], rand(6))[1] == [1,0,0,0,0,0]
@test Yota_gradient(x -> re4(x).y[2], rand(6))[1] == [0,1,0,0,0,0]
@test Yota_gradient(rand(6)) do x
m = re4(x)
m.x[1] + 2*m.y[2] + 3*m.z[3]
end[1] == [1,2,0, 0,0,3]

re7 = destructure(m7)[2]
@test Yota_gradient(x -> re7(x).a[2][3], rand(3))[1] == [0,0,1]
@test Yota_gradient(x -> re7(x).b[2][2], rand(3))[1] == [0,0,0]
@test Yota_gradient(x -> re7(x).c[2][1], rand(3))[1] == [0,0,0]

v8, re8 = destructure(m8)
@test Yota_gradient(x -> sum(abs2, re8(x)[1].y), v8)[1] == [2,4,6,0,0]
@test Yota_gradient(x -> only(sum(re8(x)[3]))^2, v8)[1] == [0,0,0,0,10]

re9 = destructure(m9)[2]
@test Yota_gradient(x -> sum(abs2, re9(x).c[1]), 1:7)[1] == [0,0,0, 8,10,12,14]
end
end

@testset "Flux issue 1826" begin
Expand Down
15 changes: 15 additions & 0 deletions test/rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -229,3 +229,18 @@ end
@test static_loss(static_model) < 1.9
end
end

VERSION < v"1.9-" && @testset "using Yota" begin
@testset "$(name(o))" for o in RULES
w′ = (abc == rand(3, 3), β = rand(3, 3), γ = rand(3)), d == rand(3), ε = eps))
w = (abc == 5rand(3, 3), β = rand(3, 3), γ = rand(3)), d == rand(3), ε = eps))
st = Optimisers.setup(o, w)
loss(x, y) = mean((x.abc.α .* x.abc.β .- y.abc.α .* y.abc.β) .^ 2) # does not use γ, δ, ε
@test loss(w, w′) > 0.5
for i = 1:10^4
_, (_, g, _) = Yota.grad(loss, w, w′)
st, w = Optimisers.update(st, w, g)
end
@test loss(w, w′) < 0.001
end
end
9 changes: 8 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using Optimisers
using ChainRulesCore, Functors, StaticArrays, Zygote
using ChainRulesCore, Functors, StaticArrays, Zygote, Yota
using LinearAlgebra, Statistics, Test, Random
using Optimisers: @.., @lazy

Expand Down Expand Up @@ -37,6 +37,13 @@ function Optimisers.apply!(o::BiRule, state, x, dx, dx2)
return state, dx
end

# Make Yota's output look like Zygote's:

Yota_gradient(f, xs...) = map(y2z, Base.tail(Yota.grad(f, xs...)[2]))
y2z(::AbstractZero) = nothing # we don't care about different flavours of zero
y2z(t::Tangent) = map(y2z, ChainRulesCore.backing(canonicalize(t))) # namedtuples!
y2z(x) = x

@testset verbose=true "Optimisers.jl" begin
@testset verbose=true "Features" begin

Expand Down

0 comments on commit 79269be

Please sign in to comment.