Skip to content

Commit

Permalink
freezingdocs
Browse files Browse the repository at this point in the history
  • Loading branch information
isentropic committed Mar 13, 2024
1 parent 5f84b68 commit d3b800b
Show file tree
Hide file tree
Showing 4 changed files with 133 additions and 42 deletions.
1 change: 1 addition & 0 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ makedocs(
=#
# Not really sure where this belongs... some in Fluxperimental, aim to delete?
"Custom Layers" => "models/advanced.md", # TODO move freezing to Training
"Advanced tweaking of models" => "tutorials/misc-model-tweaking.md",
],
],
format = Documenter.HTML(
Expand Down
41 changes: 0 additions & 41 deletions docs/src/models/advanced.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,47 +75,6 @@ Flux.@layer Affine trainable=(W,)

There is a second, more severe, kind of restriction possible. This is not recommended, but is included here for completeness. Calling `Functors.@functor Affine (W,)` means that all no exploration of the model will ever visit the other fields: They will not be moved to the GPU by [`gpu`](@ref), and their precision will not be changed by `f32`. This requires the `struct` to have a corresponding constructor that accepts only `W` as an argument.


## Freezing Layer Parameters

When it is desired to not include all the model parameters (for e.g. transfer learning), we can simply not pass in those layers into our call to `params`.

!!! compat "Flux ≤ 0.14"
The mechanism described here is for Flux's old "implicit" training style.
When upgrading for Flux 0.15, it should be replaced by [`freeze!`](@ref Flux.freeze!) and `thaw!`.

Consider a simple multi-layer perceptron model where we want to avoid optimising the first two `Dense` layers. We can obtain
this using the slicing features `Chain` provides:

```julia
m = Chain(
Dense(784 => 64, relu),
Dense(64 => 64, relu),
Dense(32 => 10)
);

ps = Flux.params(m[3:end])
```

The `Zygote.Params` object `ps` now holds a reference to only the parameters of the layers passed to it.

During training, the gradients will only be computed for (and applied to) the last `Dense` layer, therefore only that would have its parameters changed.

`Flux.params` also takes multiple inputs to make it easy to collect parameters from heterogenous models with a single call. A simple demonstration would be if we wanted to omit optimising the second `Dense` layer in the previous example. It would look something like this:

```julia
Flux.params(m[1], m[3:end])
```

Sometimes, a more fine-tuned control is needed.
We can freeze a specific parameter of a specific layer which already entered a `Params` object `ps`,
by simply deleting it from `ps`:

```julia
ps = Flux.params(m)
delete!(ps, m[2].bias)
```

## Custom multiple input or output layer

Sometimes a model needs to receive several separate inputs at once or produce several separate outputs at once. In other words, there multiple paths within this high-level layer, each processing a different input or producing a different output. A simple example of this in machine learning literature is the [inception module](https://www.cv-foundation.org/openaccess/content_cvpr_2016/papers/Szegedy_Rethinking_the_Inception_CVPR_2016_paper.pdf).
Expand Down
2 changes: 1 addition & 1 deletion docs/src/training/optimisers.md
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ Flux.Optimise.Optimiser

## Scheduling Optimisers

In practice, it is fairly common to schedule the learning rate of an optimiser to obtain faster convergence. There are a variety of popular scheduling policies, and you can find implementations of them in [ParameterSchedulers.jl](http://fluxml.ai/ParameterSchedulers.jl/dev/README.html). The documentation for ParameterSchedulers.jl provides a more detailed overview of the different scheduling policies, and how to use them with Flux optimisers. Below, we provide a brief snippet illustrating a [cosine annealing](https://arxiv.org/pdf/1608.03983.pdf) schedule with a momentum optimiser.
In practice, it is fairly common to schedule the learning rate of an optimiser to obtain faster convergence. There are a variety of popular scheduling policies, and you can find implementations of them in [ParameterSchedulers.jl](http://fluxml.ai/ParameterSchedulers.jl/dev). The documentation for ParameterSchedulers.jl provides a more detailed overview of the different scheduling policies, and how to use them with Flux optimisers. Below, we provide a brief snippet illustrating a [cosine annealing](https://arxiv.org/pdf/1608.03983.pdf) schedule with a momentum optimiser.

First, we import ParameterSchedulers.jl and initialize a cosine annealing schedule to vary the learning rate between `1e-4` and `1e-2` every 10 steps. We also create a new [`Momentum`](@ref) optimiser.
```julia
Expand Down
131 changes: 131 additions & 0 deletions docs/src/tutorials/misc-model-tweaking.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
# Choosing differentiable/gpu parts of the model
!!! note
This tutorial features somewhat disconnected topics about customizing your
models even further. It is advised to be familiar with
[`Flux.@layer`](@ref), [`Flux.@functor`](@ref), [`freeze!`](@ref
Flux.freeze!) and other basics of Flux.

Flux provides several ways of freezing, excluding from backprop entirely and
marking custom struct fields not to be moved to the GPU
([Functors.@functor](@ref)) hence excluded from being trained. The following
subsections should make it clear which one suits your needs the best.

## On-the-fly freezing per model instance
Perhaps you'd like to freeze some of the weights of the model (even at
mid-training), and Flux accomplishes this through [`freeze!`](@ref Flux.freeze!) and `thaw!`.

```julia
m = Chain(
Dense(784 => 64, relu), # freeze this one
Dense(64 => 64, relu),
Dense(32 => 10)
)
opt_state = Flux.setup(Momentum(), m);

# Freeze some layers right away
Flux.freeze!(opt_state.layers[1])

for data in train_set
input, label = data

# Some params could be frozen during the training:
Flux.freeze!(opt_state.layers[2])

grads = Flux.gradient(m) do m
result = m(input)
loss(result, label)
end
Flux.update!(opt_state, m, grads[1])

# Optionally unfreeze the params later
Flux.thaw!(opt_state.layers[1])
end
```

## Static freezing per model definition
Sometimes some parts of the model ([`Flux.@layer`](@ref)) needn't to be trained at all but these params
still need to reside on the GPU (these params are still needed in the forward
and/or backward pass).
```julia
struct MaskedLayer{T}
chain::Chain
mask::T
end
Flux.@layer MyLayer trainable=(chain,)
# mask field will not be updated in the training loop

function (m::MaskedLayer)(x)
# mask field will still move to to gpu for efficient operations:
return m.chain(x) + x + m.mask
end

model = MaskedLayer(...) # this model will not have the `mask` field trained
```
Note how this method permanently sets some model fields to be excluded from
training without on-the-fly changing.

## Excluding from model definition
Sometimes some parameters aren't just "not trainable" but they shouldn't even
transfer to the GPU (or be part of the functor). All scalar fields are like this
by default, so things like learning rate multipliers are not trainable nor
transferred to the GPU by default.
```julia
struct CustomLayer{T, F}
chain::Chain
activation_results::Vector{F}
lr_multiplier::Float32
end
Flux.@functor CustomLayer (chain, ) # Explicitly leaving out `activation_results`

function (m::CustomLayer)(x)
result = m.chain(x) + x

# `activation_results` are not part of the GPU loop, hence we could do
# things like `push!`
push!(m.activation_results, mean(result))
return result
end
```
See more about this in [`Flux.@functor`](@ref)


## Freezing Layer Parameters (deprecated)

When it is desired to not include all the model parameters (for e.g. transfer learning), we can simply not pass in those layers into our call to `params`.

!!! compat "Flux ≤ 0.14"
The mechanism described here is for Flux's old "implicit" training style.
When upgrading for Flux 0.15, it should be replaced by [`freeze!`](@ref Flux.freeze!) and `thaw!`.

Consider a simple multi-layer perceptron model where we want to avoid optimising the first two `Dense` layers. We can obtain
this using the slicing features `Chain` provides:

```julia
m = Chain(
Dense(784 => 64, relu),
Dense(64 => 64, relu),
Dense(32 => 10)
);

ps = Flux.params(m[3:end])
```

The `Zygote.Params` object `ps` now holds a reference to only the parameters of the layers passed to it.

During training, the gradients will only be computed for (and applied to) the last `Dense` layer, therefore only that would have its parameters changed.

`Flux.params` also takes multiple inputs to make it easy to collect parameters from heterogenous models with a single call. A simple demonstration would be if we wanted to omit optimising the second `Dense` layer in the previous example. It would look something like this:

```julia
Flux.params(m[1], m[3:end])
```

Sometimes, a more fine-tuned control is needed.
We can freeze a specific parameter of a specific layer which already entered a `Params` object `ps`,
by simply deleting it from `ps`:

```julia
ps = Flux.params(m)
delete!(ps, m[2].bias)
```

0 comments on commit d3b800b

Please sign in to comment.