Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Test with Yota, too #105

Merged
merged 8 commits into from
Dec 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,15 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
[compat]
ChainRulesCore = "1"
Functors = "0.3, 0.4"
Yota = "0.8.2"
Zygote = "0.6.40"
julia = "1.6"

[extras]
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Yota = "cd998857-8626-517d-b929-70ad188a48f0"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Test", "StaticArrays", "Zygote"]
test = ["Test", "StaticArrays", "Yota", "Zygote"]
76 changes: 61 additions & 15 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ to adjust the model:

```julia

using Flux, Metalhead, Optimisers
using Flux, Metalhead, Zygote, Optimisers

model = Metalhead.ResNet(18) |> gpu # define a model to train
image = rand(Float32, 224, 224, 3, 1) |> gpu; # dummy data
Expand All @@ -52,7 +52,7 @@ state = Optimisers.setup(rule, model); # initialise this optimiser's momentum e
end;

state, model = Optimisers.update(state, model, ∇model);
@show sum(model(image));
@show sum(model(image)); # reduced

```

Expand All @@ -62,8 +62,14 @@ tree formed by the model and update the parameters using the gradients.

There is also [`Optimisers.update!`](@ref) which similarly returns a new model and new state,
but is free to mutate arrays within the old one for efficiency.
The method of `apply!` for each rule is likewise free to mutate arrays within its state;
they are defensively copied when this rule is used with `update`.
(The method of `apply!` above is likewise free to mutate arrays within its state;
they are defensively copied when this rule is used with `update`.)
For `Adam()`, there are two momenta per parameter, thus `state` is about twice the size of `model`:

```julia
Base.summarysize(model) / 1024^2 # about 45MB
Base.summarysize(state) / 1024^2 # about 90MB
```

Optimisers.jl does not depend on any one automatic differentiation package,
but for now the most likely source of gradients is [Zygote.jl](https://fluxml.ai/Zygote.jl).
Expand All @@ -72,14 +78,34 @@ This `∇model` is another tree structure, rather than the dictionary-like objec
Zygote's "implicit" mode `gradient(() -> loss(...), Flux.params(model))` -- see
[Zygote's documentation](https://fluxml.ai/Zygote.jl/dev/#Explicit-and-Implicit-Parameters-1) for more about this difference.


## Usage with [Yota.jl](https://github.com/dfdx/Yota.jl)

Yota is another modern automatic differentiation package, an alternative to Zygote.

Its main function is `Yota.grad`, which returns the loss as well as the gradient (like `Zygote.withgradient`)
but also returns a gradient component for the loss function.
To extract what Optimisers.jl needs, you can write (for the Flux model above):

```julia
using Yota

loss, (∇function, ∇model, ∇image) = Yota.grad(model, image) do m, x
sum(m(x)
end;
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I rebased this and tests pass!

This example does not, it fails with a seemingly simple error:

julia> loss, (∇function, ∇model, ∇image) = Yota.grad(model, image) do m, x
          sum(m(x))
        end;


loss, (_, ∇model) = Yota.grad(m -> sum(m(image)), model)ERROR: No derivative rule found for op %454 = ntuple(%452, 4)::NTuple{4, Int64} , try defining it using 

	ChainRulesCore.rrule(::typeof(ntuple), ::Flux.var"#336#337"{4, Array{Float32, 4}}, ::Int64) = ...

Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] step_back!(tape::Umlaut.Tape{Yota.GradCtx}, y::Umlaut.Variable)
    @ Yota ~/.julia/packages/Yota/KJQ6n/src/grad.jl:219

That was on tagged Yota; on latest everything instead it seems to take forever, and interrupts here:

julia> loss, (∇function, ∇model, ∇image) = Yota.grad(model, image) do m, x
          sum(m(x))
        end;

^CERROR: InterruptException:
Stacktrace:
   [1] collect(itr::Base.Generator{Vector{Umlaut.Variable}, Yota.var"#68#72"{Umlaut.Tape{Yota.GradCtx}}})
     @ Base ./array.jl:792
   [2] todo_list(tape::Umlaut.Tape{Yota.GradCtx}, y::Umlaut.Variable)
     @ Yota ~/.julia/packages/Yota/5CVY7/src/grad.jl:113
   [3] #68
     @ ./none:0 [inlined]
   [4] iterate
     @ ./generator.jl:47 [inlined]
   [5] collect(itr::Base.Generator{Vector{Umlaut.Variable}, Yota.var"#68#72"{Umlaut.Tape{Yota.GradCtx}}})
     @ Base ./array.jl:787
   [6] todo_list(tape::Umlaut.Tape{Yota.GradCtx}, y::Umlaut.Variable)
     @ Yota ~/.julia/packages/Yota/5CVY7/src/grad.jl:113
   [7] #68
     @ ./array.jl:0 [inlined]
   [8] iterate
     @ ./generator.jl:47 [inlined]
   [9] collect_to!(dest::Vector{Vector{Umlaut.Variable}}, itr::Base.Generator{Vector{Umlaut.Variable}, Yota.var"#68#72"{Umlaut.Tape{Yota.GradCtx}}}, offs::Int64, st::Int64)
     @ Base ./array.jl:845
  [10] collect_to_with_first!(dest::Vector{Vector{Umlaut.Variable}}, v1::Vector{Umlaut.Variable}, itr::Base.Generator{Vector{Umlaut.Variable}, Yota.var"#68#72"{Umlaut.Tape{Yota.GradCtx}}}, st::Int64)
     @ Base ./array.jl:823
  [11] collect(itr::Base.Generator{Vector{Umlaut.Variable}, Yota.var"#68#72"{Umlaut.Tape{Yota.GradCtx}}})
     @ Base ./array.jl:797
--- the last 10 lines are repeated 2 more times ---

(jl_aZPcXz) pkg> st
Status `/private/var/folders/yq/4p2zwd614y59gszh7y9ypyhh0000gn/T/jl_aZPcXz/Project.toml`
  [dbeba491] Metalhead v0.8.0-DEV `https://github.com/FluxML/Metalhead.jl.git#master`
  [3bd65402] Optimisers v0.2.10 `~/.julia/dev/Optimisers`
  [09ab397b] StructArrays v0.6.13 `https://github.com/JuliaArrays/StructArrays.jl.git#master`
  [cd998857] Yota v0.8.1 `https://github.com/dfdx/Yota.jl.git#main`

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, I was indeed investigating incredibly long processing time, but profiler blamed type inference/abstract interpreter, so I started a long search for a better way to trace functions (e.g. see my recent post on Discourse). However, your stacktrace implies the problem may actually appear after the tracing. I will try to investigate this option too closer to the end of the week.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FYI: I opened an issue to track this.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed. The ResNet(18) example now compiles and runs in 61 second (compared to 47 seconds with Zygote). Subsequent calls take ~0.4 seconds on my CPU.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great, I see something similar locally, on 0.8.2


# Or else, this may save computing ∇image:
loss, (_, ∇model) = grad(m -> sum(m(image)), model);
```

## Usage with [Lux.jl](https://github.com/avik-pal/Lux.jl)

The main design difference of Lux is that the tree of parameters is separate from
The main design difference of Lux from Flux is that the tree of parameters is separate from
the layer structure. It is these parameters which `setup` and `update` need to know about.

Lux describes this separation of parameter storage from model description as "explicit" parameters.
Beware that it has nothing to do with Zygote's notion of "explicit" gradients.
(If the same model is written in Flux and Lux, `∇model` above and `∇params` below will often be
(If the same model is written in Flux and Lux, `∇model` above and `∇params` below will be nearly
identical trees of nested `NamedTuple`s.)

```julia
Expand All @@ -88,27 +114,47 @@ using Lux, Boltz, Zygote, Optimisers

lux_model, params, lux_state = Boltz.resnet(:resnet18) |> gpu; # define and initialise model
images = rand(Float32, 224, 224, 3, 4) |> gpu; # batch of dummy data
y, _ = Lux.apply(lux_model, images, params, lux_state); # run the model
@show sum(y) # initial dummy loss
y, lux_state = Lux.apply(lux_model, images, params, lux_state); # run the model
@show sum(y); # initial dummy loss

rule = Optimisers.Adam()
opt_state = Optimisers.setup(rule, params); # optimiser state based on model parameters

∇params, _ = gradient(params, images) do p, x # gradient with respect to parameter tree
y, _ = Lux.apply(lux_model, x, p, lux_state)
sum(y)
(loss, lux_state), back = Zygote.pullback(params, images) do p, x
y, st = Lux.apply(lux_model, x, p, lux_state)
sum(y), st # return both the loss, and the updated lux_state
end;
∇params, _ = back((one.(loss), nothing)); # gradient of only the loss, with respect to parameter tree
loss == sum(y) # not yet changed

opt_state, params = Optimisers.update!(opt_state, params, ∇params);

y, _ = Lux.apply(lux_model, images, params, lux_state);
@show sum(y)
y, lux_state = Lux.apply(lux_model, images, params, lux_state);
@show sum(y); # now reduced

```

Besides the parameters stored in `params` and gradually optimised, any other model state
is stored in `lux_state`. For simplicity this example does not show how to propagate the
updated `lux_state` to the next iteration, see Lux's documentation.
is stored in `lux_state`, and updated by `Lux.apply`. (In this example, BatchNorm has state.)
This is completely unrelated to Optimisers.jl's state, although designed in a similar spirit.

```julia
Base.summarysize(lux_model) / 1024 # just 2KB
Base.summarysize(params) / 1024^2 # about 45MB, same as Flux model
Base.summarysize(lux_state) / 1024 # 40KB
Base.summarysize(opt_state) / 1024^2 # about 90MB, with Adam
```

If you are certain there is no model state, then the gradient calculation can
be simplified to use `Zygote.gradient` instead of `Zygote.pullback`:

```julia
∇params, _ = gradient(params, images) do p, x
y, _ = Lux.apply(lux_model, x, p, lux_state) # discards new lux_state
sum(y)
end;
```


## Non-`trainable` Parameters

Expand Down
55 changes: 55 additions & 0 deletions test/destructure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,31 @@ end
# Zygote error in (::typeof(∂(canonicalize)))(Δ::NamedTuple{(:backing,), Tuple{NamedTuple{(:x, :y, :z)
# Diffractor error in perform_optic_transform
end

VERSION < v"1.9-" && @testset "using Yota" begin
@test Yota_gradient(m -> destructure(m)[1][1], m1)[1] == [1,0,0]
@test Yota_gradient(m -> destructure(m)[1][2], m2)[1] == ([0,1,0], [0,0,0])
@test Yota_gradient(m -> destructure(m)[1][3], (m1, m1))[1] == ([0,0,1], nothing)
@test Yota_gradient(m -> destructure(m)[1][1], m3)[1] == (x = [1,0,0], y = nothing, z = [0,0,0])
@test Yota_gradient(m -> destructure(m)[1][2], m4)[1] == (x = [0,1,0], y = nothing, z = [0,0,0])

g5 = Yota_gradient(m -> destructure(m)[1][3], m5)[1]
@test g5.a[1].x == [0,0,1]
@test g5.a[2] === nothing

g6 = Yota_gradient(m -> imag(destructure(m)[1][4]), m6)[1]
@test g6.a == [0,0,0]
@test g6.a isa Vector{Float64}
@test g6.b == [0+im]

g8 = Yota_gradient(m -> sum(abs2, destructure(m)[1]), m8)[1]
@test g8[1].x == [2,4,6]
@test g8[2].b.x == [8]
@test g8[3] == [[10.0]]

g9 = Yota_gradient(m -> sum(sqrt, destructure(m)[1]), m9)[1]
@test g9.c === nothing
end
end

@testset "gradient of rebuild" begin
Expand Down Expand Up @@ -149,6 +174,36 @@ end
# Not fixed by this:
# Zygote.@adjoint Tangent{T,B}(x::NamedTuple) where {T,B<:NamedTuple} = Tangent{T,B}(x), dx -> (dx,)
end

VERSION < v"1.9-" && @testset "using Yota" begin
re1 = destructure(m1)[2]
@test Yota_gradient(x -> re1(x)[1], rand(3))[1] == [1,0,0]
re2 = destructure(m2)[2]
@test Yota_gradient(x -> re2(x)[1][2], rand(6))[1] == [0,1,0,0,0,0]
re3 = destructure(m3)[2]
@test Yota_gradient(x -> re3(x).x[3], rand(6))[1] == [0,0,1,0,0,0]
@test Yota_gradient(x -> re3(x).z[1], rand(6))[1] == [0,0,0,1,0,0]

re4 = destructure(m4)[2]
@test Yota_gradient(x -> re4(x).x[1], rand(6))[1] == [1,0,0,0,0,0]
@test Yota_gradient(x -> re4(x).y[2], rand(6))[1] == [0,1,0,0,0,0]
@test Yota_gradient(rand(6)) do x
m = re4(x)
m.x[1] + 2*m.y[2] + 3*m.z[3]
end[1] == [1,2,0, 0,0,3]

re7 = destructure(m7)[2]
@test Yota_gradient(x -> re7(x).a[2][3], rand(3))[1] == [0,0,1]
@test Yota_gradient(x -> re7(x).b[2][2], rand(3))[1] == [0,0,0]
@test Yota_gradient(x -> re7(x).c[2][1], rand(3))[1] == [0,0,0]

v8, re8 = destructure(m8)
@test Yota_gradient(x -> sum(abs2, re8(x)[1].y), v8)[1] == [2,4,6,0,0]
@test Yota_gradient(x -> only(sum(re8(x)[3]))^2, v8)[1] == [0,0,0,0,10]

re9 = destructure(m9)[2]
@test Yota_gradient(x -> sum(abs2, re9(x).c[1]), 1:7)[1] == [0,0,0, 8,10,12,14]
end
end

@testset "Flux issue 1826" begin
Expand Down
15 changes: 15 additions & 0 deletions test/rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -229,3 +229,18 @@ end
@test static_loss(static_model) < 1.9
end
end

VERSION < v"1.9-" && @testset "using Yota" begin
@testset "$(name(o))" for o in RULES
w′ = (abc = (α = rand(3, 3), β = rand(3, 3), γ = rand(3)), d = (δ = rand(3), ε = eps))
w = (abc = (α = 5rand(3, 3), β = rand(3, 3), γ = rand(3)), d = (δ = rand(3), ε = eps))
st = Optimisers.setup(o, w)
loss(x, y) = mean((x.abc.α .* x.abc.β .- y.abc.α .* y.abc.β) .^ 2) # does not use γ, δ, ε
@test loss(w, w′) > 0.5
for i = 1:10^4
_, (_, g, _) = Yota.grad(loss, w, w′)
st, w = Optimisers.update(st, w, g)
end
@test loss(w, w′) < 0.001
end
end
9 changes: 8 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using Optimisers
using ChainRulesCore, Functors, StaticArrays, Zygote
using ChainRulesCore, Functors, StaticArrays, Zygote, Yota
using LinearAlgebra, Statistics, Test, Random
using Optimisers: @.., @lazy

Expand Down Expand Up @@ -37,6 +37,13 @@ function Optimisers.apply!(o::BiRule, state, x, dx, dx2)
return state, dx
end

# Make Yota's output look like Zygote's:

Yota_gradient(f, xs...) = map(y2z, Base.tail(Yota.grad(f, xs...)[2]))
y2z(::AbstractZero) = nothing # we don't care about different flavours of zero
y2z(t::Tangent) = map(y2z, ChainRulesCore.backing(canonicalize(t))) # namedtuples!
y2z(x) = x

@testset verbose=true "Optimisers.jl" begin
@testset verbose=true "Features" begin

Expand Down