[WIP] Added some Callback functions #1067
Conversation
There are several things that need to be answered/fixed before this PR is considered ready for a second round of review:

- The usage of global variables is extremely unstable, especially when you're using lowercase names. You need to minimize the number of global variables; my suggestion is to wrap all of them into a dict.
- Is it possible to use all of these callbacks at the same time? E.g., will `ctr` be increased twice if you use two callbacks?
- Can other users reuse the data/structure of your `init_cb` to create new callbacks for their own projects without modifying Flux's codebase?
- Styles need to be corrected according to the Julia style guide: indentation levels, tabs vs. spaces, docstrings.
src/callbacks.jl
Outdated
```julia
    history(x,y,accuracy)

Callback that records loss and accuracy into an array. Array can be accessed using `loss_history`.

Arguments:
```
Julia doesn't document inputs in this style; instead, you can use a plain Markdown list. Symbols need to be wrapped with backticks.
I will change that according to the documentation rules.
Should the keyword arguments also be wrapped with backticks?
Doing it that way requires the forward pass twice in the iteration the callback is called. I'm not sure how to merge the Zygote pullback machinery with Flux callbacks, though.
BTW, I think most uses of `train!` will terminate anyway when a NaN loss emerges.
src/callbacks.jl
Outdated
```julia
    monitor: Quantity to be monitored for the provided (x,y). Can take values 'acc' or 'loss'. Default set to 'loss'
"""

function lrdecay(x, y; factor=0.2, loss=loss, accuracy_fn=nothing, patience=5, min_lr=1e-5, monitor="loss")
```
The usage of `loss`, `accuracy_fn`, and `monitor` doesn't look right to me. I prefer

```julia
function lrdecay(optimizer, metric_fn, x, y; factor=0.2, patience=5, min_lr=1e-5, descend=true)
    # by default it watches whether metric_fn(x, y) descends
end
```

My feeling on `model_checkpoint` is the same.
Yes, I also now think there should just be a metric function for that. Thanks for the suggestion. I will make similar changes to `model_checkpoint` as well.
src/callbacks.jl
Outdated
```julia
        best_acc = acc
        last_improvement = ctr
    elseif (ctr - last_improvement) > patience
        if opt.eta > min_lr
```
Where is `opt` defined?
Oh! It should have been an input parameter. While testing, I happened to have `opt` defined as my optimizer, which is why it didn't throw any error at the time. I will add the optimizer as a parameter.
src/callbacks.jl
Outdated
```julia
    accuracy_fn: function to be used to calculate accuracy, when monitor set to 'acc'. Otherwise, optional

    patience: number of epochs that produced the monitored quantity with no improvement,
    after which learning rate will be reduced or training stopped. Default value set to 5.
```
Replace the TAB with four spaces here.
I will change that according to documentation rules.
```julia
    global best_acc = 0
    global best_loss = Inf
    global last_improvement = 0
end
```
Instead of letting users call `init_cb`, you can define a "constant dict" and let these callbacks call and modify the dict at runtime:

```julia
const CALLBACK_VALUES = Dict()

function _init_cb()
    # BTW, what does ctr stand for?
    CALLBACK_VALUES["ctr"] = 0
end

function terminateOnNaN(x, y)
    haskey(CALLBACK_VALUES, "ctr") || _init_cb()
    ...
end
```
Actually, there can be multiple methods used for training, so there isn't a proper way to keep a count of how many times a callback is called; that's why I defined `ctr`. However, as you say above, in the case of multiple callbacks this will get affected. I am thinking of using some other way to keep the counter. Is there any way to make the callbacks aware of multiple callbacks?
Would it be better to give each of the callbacks with an epoch-count requirement a separate counter, like CALLBACK_VALUES["ctr_model_check"]?
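A minimal sketch of that per-callback counter idea (the key names and helper are hypothetical, not from the PR):

```julia
const CALLBACK_VALUES = Dict{String,Int}()

# Each callback increments its own key, so two callbacks running in the
# same training loop never interfere with each other's call counts.
_bump_counter!(key::String) = (CALLBACK_VALUES[key] = get(CALLBACK_VALUES, key, 0) + 1)

function model_checkpoint_cb()
    _bump_counter!("ctr_model_check")
    # ... checkpointing logic using CALLBACK_VALUES["ctr_model_check"] ...
end
```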
The way I would do this is to have each callback hold state. For ad hoc use I use closures:

```julia
function make_history_callback(calc_metric)
    history = [calc_metric()]  # store initial loss
    return () -> push!(history, calc_metric())
end
```

which I would use:

```julia
const test_x, test_y = ...
const train_data = ...
const model = Chain(...)
test_loss() = mse(test_y, model(test_x))
const history_callback = make_history_callback(test_loss)

# may have the signature a bit wrong but I trust you get the point
Flux.train!(mse, Iterators.repeat(train_data, 100), model, params(model), ADAM(); callback=history_callback)

using Plots
plot(history_callback.history)
```

To be more explicit than using closures would be to use functors:

```julia
struct HistoryCallback{F,T}
    calc_metric::F
    history::Vector{T}
end
HistoryCallback(calc_metric) = HistoryCallback(calc_metric, [calc_metric()])
(cb::HistoryCallback)() = push!(cb.history, cb.calc_metric())
```

which I would use basically as before:

```julia
const history_callback = HistoryCallback(test_loss)
```
Apparently it should be done like this!
Thanks for the suggestions! I will make the changes accordingly.
@johnnychen94 I have updated the functions. Please have a look.
Looks nice; my comments are more about style.
There are two more concerns (not necessarily about this PR):
- As @bhvieira commented, we may need to find a way to reduce unnecessary metric computation across multiple callbacks, or even get the forward pass from Zygote. Currently, all metrics are computed multiple times even if they're the same.
- Some callbacks don't need to be called so frequently; we may need a callback-scheduler type that dynamically does nothing or calls the callbacks, i.e., a callback that calls callbacks, or an improved version of `throttle`.

For now, I think it's okay to get others involved to see if they find more things that need to be resolved. @dhairyagandhi96 @CarloLucibello
```julia
    calc_metric::F
end

function (cb::terminateOnNaN)()
```
Feel free to ignore this comment if you don't like it.
I'm not sure how others feel, but instead of creating `calc_metric` manually, we could introduce a helper constructor like this:

```julia
terminateOnNaN(metric, model, X, Y) = terminateOnNaN(() -> metric(model(X), Y))
```

so that users can also do:

```julia
cb1 = terminateOnNaN(mse, model, valid_X, valid_Y)
```

The same suggestion applies to the other callbacks.
That was one of the things I did previously, but I observed that callbacks are generally called without any arguments, with the data provided inside the metric definition, like `test_loss() = mse(test_y, model(test_x))` by @oxinabox, and even in the documentation examples. Still, considering your suggestion, we can provide a proper definition for `calc_metric`, where we can suggest something like `calc_metric = () -> metric(valid_x, valid_y, model)`, where `metric` can be defined in your suggested manner. That way we can provide better intuition about `calc_metric`.
```julia
    terminateOnNaN(calc_metric)

Callback that terminates training when `NaN` is encountered for `calc_metric` value.
```
Need to document what a valid `calc_metric` is, e.g., that it doesn't accept any arguments and returns a number.
The same suggestion applies to the other callbacks.
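For instance, a docstring fragment spelling out that contract might look like this (wording illustrative, not from the PR):

```julia
"""
    terminateOnNaN(calc_metric)

Terminate training when `calc_metric()` returns `NaN`.

`calc_metric` must be a zero-argument callable that returns a number,
e.g. `() -> mse(model(test_x), test_y)`.
"""
```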
I will add that part. I was thinking it would be better if we provide more detail about this through examples in the main documentation.
```julia
- `calc_metric`: the metric to be monitored. Evaluated at every call of the callback function
- `opt`: the optimizer used while training
- `factor::Float64=0.2`: factor by which learning rate is reduced in every update.
- `patience::Int=5`: number of epochs that produced the monitored `calc_metric` with no improvement, after which learning rate will be reduced or training stopped.
```
I'm personally okay with these long descriptions, but if you find them uncomfortable, you can manually split them into multiple lines; I think they would be parsed nicely:

```julia
- `patience::Int=5`: number of epochs that produced the monitored `calc_metric` with
    no improvement, after which learning rate will be reduced or training stopped.
```

The four trailing spaces need to be removed here.
Yes, I guess that would be a better idea. I will split it that way, as I think it will be more readable, especially for someone going through the code; otherwise, I think both are parsed the same way in the documentation.
```julia
    min_lr::Float64
end

lrdecay(calc_metric, opt; factor=0.2, patience=5, descend=true, min_lr=1e-6) =
    lrdecay(calc_metric, opt, calc_metric(), 0, factor, 0, descend, patience, min_lr)
```
Speaking of the `lrdecay` callback, do you know that there's already such functionality in Flux?
https://fluxml.ai/Flux.jl/stable/training/optimisers/#Composing-Optimisers-1
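For reference, a sketch of the composed-optimiser approach those docs describe, assuming the `Optimiser`/`ExpDecay` API of Flux at the time of this PR (hyperparameter values are illustrative):

```julia
using Flux

# ExpDecay(η, decay, decay_step, clip) multiplies the effective learning
# rate by `decay` every `decay_step` gradient calls, clipped at `clip`.
# Composing it with ADAM gives built-in LR decay without any callback.
opt = Flux.Optimiser(ExpDecay(0.001, 0.1, 1000, 1e-4), ADAM())
```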
Oh! I hadn't seen that earlier. So should I remove this then?
In its place, I think we can have something like an LRScheduler, where the user can provide custom functions to schedule the LR decay over epochs.
For this PR I think it's better to remove this; but if you feel the LRScheduler way is more friendly to end users, open another issue for that, although I think these duplicated functionalities are very likely to be rejected.
Ok, I will remove that. Thanks for pointing that out! :)
```julia
    end
    m = Chain(cb.model_arr...)
    if cb.save_best_model_only == 1
        BSON.@save joinpath(cb.savepath, cb.filename) m epoch metric
```
Directly calling `bson` might be clearer for this case, i.e.,

```julia
bson(joinpath(cb.savepath, cb.filename), model=m, epoch=epoch, metric=metric)
```

Note that the keys of the constructed dictionary are different.
```julia
where m represents a chain of all the sub-models of the function. To access each individual model, use m.layers[i] for the ith sub-model.
"""
```
Need to remove this empty line.
```julia
    verbose::I
end

function ModelCheckpoint(calc_metric, model_arr; descend=true, savepath="./", filename="model.bson", save_best_model_only=true, verbose=1)
```
```diff
- savepath="./"
+ savepath=pwd()
```

"./" doesn't work on Windows.

Perhaps just `best_only`?
```julia
- `savepath='./'`: the path of the directory in which the model is to be saved.
- `filename='model.bson'`: name of the file to which the model is to be saved; must end with a '.bson' extension.
- `save_best_model_only=true`: whether only the best model is saved, or the model at each improvement is saved to a separate file
- `verbose=1`: whether to display a 'saving' message on model improvement. Can take values `0` => no message or `1` => display message.
```
If it's only a 0/1 difference, I prefer to use `true`/`false` for now.
```julia
mutable struct ModelCheckpoint{F,T,S,I}
    calc_metric::F
    model_arr::AbstractArray{T}
```
Limiting it to `AbstractArray` doesn't look good to me. We still want to pass a manually created `Chain` model here, right?
It can be done like this to fit your original idea: if it's an array, construct a `Chain` for it, and if it isn't, do nothing.
As you suggested above, rather than doing any of the above, I guess it's better if we simply save the model as-is into a dict via BSON, whether it's a single `Chain` model or multiple models, and retrieve them in the same manner as passed. So there will be no need for `Chain` anymore; I will simply remove the `Chain` and `AbstractArray` parts.
```julia
    path = cb.savepath
    filename = cb.filename
    if cb.save_best_model_only == 1
        @warn(" -> Monitored metric improved ! Saving model out to $path$filename")
```
Use `@info` instead of `@warn`.
How about "improved from $(cb.best_metric) to $metric"?
Yeah, that's a better idea. I will change that.
I didn't look into the details of the PR, so just two general comments:
My opinion is that we shouldn't advertise the
But I don't know exactly if this way of scheduling will be efficient or not.
For 1), using something like `throttle` or some more efficient way like a callback scheduler, we can minimise calling callbacks.
Yes, I agree that callbacks may require twice the forward pass, but I hope the extra forward pass is worth its cost if we use it for validating our data, which needs to be done in many models.
Although it's not a blocking PR, I suggest holding this PR for a while until #1017 gets merged and then deciding if we really want this.
I think people won't like to use it unless the redundant computation issue can be solved. The redundant computation issue could be partially addressed if we can add a keyword arg
Speaking of flexibility and efficiency, I totally agree that But unless we decide to deprecate such usage, IMO there are still merits to providing some basic support for it. That said, not being very active in the Flux community, I don't want to make such a call :P
Closing as outdated.
I have added the following callback functions:
As there aren't any built-in callback functions in Flux, I hope these may be helpful in improving the Flux user experience. Reference taken from Keras.