[WIP] Added some Callback functions #1067

Closed
wants to merge 6 commits

Conversation

@AdarshKumar712 (Contributor) commented Mar 2, 2020

I have added the following callback functions:

  1. terminateOnNaN
  2. HistoryCallback
  3. lrdecay
  4. ModelCheckpoint

As there aren't any built-in callback functions in Flux, I hope these may be helpful in improving the Flux user experience. Reference taken from Keras.

@johnnychen94 (Contributor) left a comment

There are several things that need to be answered/fixed before this PR is considered ready for a second round of review:

  • The usage of global variables is very problematic (non-const globals are slow and type-unstable in Julia), especially when you're using lowercase names. You need to minimize the number of global variables; my suggestion is to wrap all of them into a dict.
  • Is it possible to use all of these callbacks at the same time? E.g., will ctr be incremented twice if you use two callbacks?
  • Can other users reuse the data/structure of your init_cb to create new callbacks for their own projects without modifying Flux's codebase?
  • Styles need to be corrected according to the Julia style guide: indentation levels; TABs vs. spaces; docstrings.

src/callbacks.jl Outdated
history(x,y,accuracy)

Callback that records loss and accuracy into an array. Array can be accessed using `loss_history`.
Arguments:
Contributor:

Julia doesn't document inputs in this style; instead, you can use a plain markdown list.

Symbols need to be wrapped with backticks (`).

Contributor Author:

I will change that according to the documentation rules

Contributor Author:

Should the keyword arguments also be wrapped with backticks?

Contributor:

Doing it that way requires running the forward pass twice in the iteration where the callback is called. I'm not sure how to merge the Zygote pullback machinery with Flux callbacks, though.
BTW, I think most uses of train! will terminate anyway when a NaN loss emerges.

src/callbacks.jl Outdated
monitor: Quantity to be monitored for the provided (x,y). Can take values 'acc' or 'loss'. Default set to 'loss'
"""

function lrdecay(x,y;factor = 0.2,loss = loss,accuracy_fn = nothing,patience=5,min_lr = 1e-5,monitor="loss")
Contributor:

The usage of loss, accuracy_fn, and monitor doesn't look right to me. I prefer:

function lrdecay(optimizer, metric_fn, x, y; factor=0.2, patience=5, min_lr=1e-5, descend=true)
# by default it watches if metric_fn(x, y) descends.
end

My feel on model_checkpoint is the same.

Contributor Author:

Yes, I now also think that there should be just a metric function for that. Thanks for the suggestion. I will make the same changes to model_checkpoint as well.

src/callbacks.jl Outdated
best_acc = acc
last_improvement = ctr
elseif (ctr-last_improvement)>patience
if opt.eta>min_lr
Contributor:

Where is opt defined?

Contributor Author:

Oh! It should have been an input parameter. While testing, I happened to have opt defined as my optimizer, which is why it didn't throw an error at the time. I will add the optimizer as a parameter.

src/callbacks.jl Outdated
accuracy_fn: function to be used to calculate accuracy, when monitor set to 'acc'. Otherwise, optional

patience: number of epochs that produced the monitored quantity with no improvement,
after which learning rate will be reduced or training stopped. Default value set to 5.
Contributor:

replace TAB with four spaces here.

Contributor Author:

I will change that according to the documentation rules.

global best_acc = 0
global best_loss = Inf
global last_improvement = 0
end
@johnnychen94 (Contributor) commented Mar 3, 2020:

Instead of letting users call init_cb, you can define a "constant dict" and let these callbacks call and modify the dict at runtime.

const CALLBACK_VALUES = Dict()

function _init_cb()
    # BTW, what does ctr stand for?
    CALLBACK_VALUES["ctr"] = 0
end


function terminateOnNaN(x, y)
    haskey(CALLBACK_VALUES, "ctr") || _init_cb()

    ...
end

Contributor Author:

Actually, there are multiple methods that can be used for training, so there isn't a proper way to keep a counter of how many times a callback has been called; that's why I defined ctr. However, as you said above, this breaks when multiple callbacks are used. I am thinking of another way to keep the counter. Is there any way to make the callbacks aware of the other callbacks?

Contributor Author:

Would it be better to give each of the callbacks that needs an epoch count its own counter, like CALLBACK_VALUES["ctr_model_check"]?


@oxinabox (Member) commented Mar 3, 2020

The way I would do this is to have each callback hold its own state.
This lets us avoid using any globals,
which makes the code much easier to reason about and more performant (non-const globals are slow),
and avoids issues such as training two models at once with both writing to the same global.

For ad hoc use I use closures.

function make_history_callback(calc_metric)
    history = [calc_metric()]  # store initial loss
    return ()->push!(history, calc_metric())
end

which I would use:

const test_x, test_y = ...
const train_data = ...
const model = Chain(...)
test_loss() = mse(test_y, model(test_x))

const history_callback = make_history_callback(test_loss)

# may have signature a bit wrong but I trust you get the point
Flux.train!(mse, Iterators.repeat(train_data, 100), model, params(model), ADAM(); callback=history_callback)

using Plots
plot(history_callback.history)
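(The signature above is indeed a bit off; at the time, train! took roughly train!(loss, params, data, opt; cb = ...). A minimal corrected sketch, assuming train_data is a single (x, y) batch:)

using Flux
using Flux: mse, params

loss(x, y) = mse(model(x), y)    # train! splats each data batch into the loss
Flux.train!(loss, params(model), Iterators.repeated(train_data, 100), ADAM();
            cb = history_callback)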

A more explicit alternative to closures is to use functors.

struct HistoryCallback{F,T}
    calc_metric::F
    history::Vector{T}
end

HistoryCallback(calc_metric) = HistoryCallback(calc_metric, [calc_metric()])

(cb::HistoryCallback)() = push!(cb.history, cb.calc_metric())

Which I would use basically as before:

const history_callback = HistoryCallback(test_loss)

@johnnychen94 (Contributor):

struct HistoryCallback{F,T}
    calc_metric::F
    history::Vector{T}
end

Apparently it should be done like this!

@AdarshKumar712 (Contributor Author):

Thanks for the suggestions! I will make the changes accordingly.

@AdarshKumar712 AdarshKumar712 changed the title Added some Callback functions [WIP] Added some Callback functions Mar 4, 2020
@AdarshKumar712 (Contributor Author):

@johnnychen94 I have updated the functions. Please have a look.
However, some tasks are still left:

  1. How to ensure that training stops once stop() is called from the functions? (See the sketch below.)
  2. The documentation needs improvement; some lines are too long and need to be split.

Can you please suggest how these tasks can be addressed?
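(On question 1: Flux already provides Flux.stop(), which throws a StopException that train! catches to end training. A minimal sketch in this PR's functor style, assuming a terminateOnNaN struct with a calc_metric field:)

function (cb::terminateOnNaN)()
    # Flux.stop() raises Flux.Optimise.StopException, which train! catches
    isnan(cb.calc_metric()) && Flux.stop()
end

(In a hand-written training loop, the exception would need to be caught manually.)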

@johnnychen94 (Contributor) left a comment

Looks nice; my comments are mostly about style.

There are two more concerns (not necessarily about this PR):

  • As @bhvieira commented, we may need to find a way to reduce unnecessary metric computation across multiple callbacks, or even get the forward pass from Zygote. Currently, all metrics are computed multiple times even when they're the same.
  • Some callbacks don't need to be called so frequently; we may need a callback-scheduler type that dynamically either does nothing or calls the callbacks, i.e., a callback that calls callbacks, or an improved version of throttle.

For now, I think it's okay to get others involved to see if they find more things that need to be resolved. @dhairyagandhi96 @CarloLucibello

calc_metric::F
end

function (cb::terminateOnNaN)()
Contributor:

Feel free to ignore this comment if you don't like it.

I'm not sure how others feel, but instead of creating calc_metric manually, we could introduce a helper constructor like this:

terminateOnNaN(metric, model, X, Y) = terminateOnNaN(() -> metric(model(X), Y))

so that users can also do:

cb1 = terminateOnNaN(mse, model, valid_X, valid_Y)

The same suggestion for other callbacks.

Contributor Author:

That was one of the things I did previously, but I observed that callbacks are generally called without any arguments, with the data provided within the metric definition, like
test_loss() = mse(test_y, model(test_x)) from @oxinabox, and even in the documentation examples. Still, considering your suggestion, we could provide a proper definition for calc_metric and suggest something like calc_metric = () -> metric(valid_x, valid_y, model), where metric is defined in the manner you suggested. That way we can give a better intuition about calc_metric.

terminateOnNaN(calc_metric)

Callback that terminates training when `NaN` is encountered for `calc_metric` value.

Contributor:

Need to document what a valid calc_metric is, e.g., that it accepts no arguments and returns a number.

The same suggestion for other callbacks.

Contributor Author:

I will add that part. I was thinking it would be better if we provided more detail about this through examples in the main documentation.

- `calc_metric`: the metric to be monitored. Evaluated at every call of callback function
- `opt`: the optimizer used while training
- `factor::Float64=0.2`: factor by which learning rate is reduced in every updation.
- `patience::Int=5`: number of epochs that produced the monitored 'calc_metric' with no improvement, after which learning rate will be reduced or training stopped.
Contributor:

I'm personally okay with these long descriptions, but if you find them uncomfortable, you can manually split them into multiple lines; I think they would be parsed nicely:

- `patience::Int=5`: number of epochs that produced the monitored 'calc_metric' with
   no improvement, after which learning rate will be reduced or training stopped.

The four trailing spaces need to be removed here.

Contributor Author:

Yes, I guess that would be a better idea. I will split it that way, as I think it will be more readable, especially for someone going through the code. Otherwise, I think both are parsed the same way in the documentation.

min_lr::Float64
end

lrdecay(calc_metric, opt; factor=0.2, patience=5, descend = true, min_lr=1e-6) = lrdecay(calc_metric, opt, calc_metric(), 0, factor, 0, descend, patience, min_lr)
Contributor:

Speaking of the lrdecay callback, do you know that this functionality already exists in Flux?

https://fluxml.ai/Flux.jl/stable/training/optimisers/#Composing-Optimisers-1
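(For readers: the linked docs implement learning-rate decay by composing optimisers. A minimal sketch, assuming the ExpDecay constructor ExpDecay(eta, decay, decay_step, clip); exact defaults may differ between Flux versions:)

using Flux

# multiply the effective learning rate by 0.1 every 1000 gradient steps,
# clipped at 1e-4, composed with a plain ADAM step
opt = Flux.Optimiser(ExpDecay(0.001, 0.1, 1000, 1e-4), ADAM())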

@AdarshKumar712 (Contributor Author) commented Mar 9, 2020:

Oh! I hadn't seen that earlier. So, should I remove this then?
In its place, I think we could have something like an LRScheduler, where the user can provide custom functions to schedule the LR decay over epochs.

@johnnychen94 (Contributor) commented Mar 9, 2020:

For this PR I think it's better to remove this, but if you feel the LRScheduler way is more friendly to end users, open another issue for that, although I think such duplicated functionality is very likely to be rejected.

Contributor Author:

Ok, I will remove that. Thanks for pointing it out! :)

end
m = Chain(cb.model_arr...)
if cb.save_best_model_only==1
BSON.@save joinpath(cb.savepath, cb.filename) m epoch metric
Contributor:

Directly calling bson might be clearer for this case, i.e.,

bson(joinpath(cb.savepath, cb.filename), model=m, epoch=epoch, metric=metric)

Note that the keys of the constructed dictionary are different.
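(Spelling out the key difference: BSON.@save uses the variable names as dictionary keys, while the keyword form uses the keyword names:)

BSON.@save joinpath(cb.savepath, cb.filename) m epoch metric                    # keys :m, :epoch, :metric
bson(joinpath(cb.savepath, cb.filename), model=m, epoch=epoch, metric=metric)   # keys :model, :epoch, :metric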

where m represent a chain of all the sub-models of the function. To access each individual model, use m.layers[i] for ith sub-model.

"""

Contributor:

Need to remove this empty line.

verbose::I
end

function ModelCheckpoint(calc_metric, model_arr; descend=true, savepath="./", filename="model.bson", save_best_model_only=true, verbose=1)
Contributor:

- savepath="./"
+ savepath=pwd()

"./" doesn't work on windows

Perhaps just best_only?

- `savepath='./'`: the path of the directory in which the model is to be saved.
- `filename='model.bson'`: name of the file to which model is to be saved, must end with a '.bson' extension.
- `save_best_model_only=true`: whether only the best model is to be saved or model at each improvement is to be saved in a seperate file
- `verbose=1`: whether to display 'saving' message on model improvement or not. Can take values `0`=> no messsage or `1`=> display message.
Contributor:

If it's only a 0/1 difference, I prefer to use true/false for now.


mutable struct ModelCheckpoint{F,T,S,I}
calc_metric::F
model_arr::AbstractArray{T}
@johnnychen94 (Contributor) commented Mar 9, 2020:

Limiting it to AbstractArray doesn't look good to me. We still want to be able to pass a manually created Chain model here, right?

It could be done like this to fit your original idea: if it's an array, construct a Chain from it; if it isn't, do nothing. (See the sketch below.)
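(One possible shape for that dispatch: a sketch with a hypothetical helper name, not code from the PR:)

as_model(m) = m                                      # already a model: pass through
as_model(layers::AbstractArray) = Chain(layers...)   # an array of layers: wrap in a Chain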

@AdarshKumar712 (Contributor Author) commented Mar 9, 2020:

As you suggested above, I guess rather than doing either of those, it's better if we simply save the model as-is into a dict via bson, whether it's a single Chain model or multiple models; it will then be retrieved in the same form as it was passed. So there will be no need for Chain anymore. I will simply remove the Chain and AbstractArray parts.

path = cb.savepath
filename = cb.filename
if cb.save_best_model_only ==1
@warn(" -> Monitored metric improved ! Saving model out to $path$filename")
Contributor:

Use @info instead of @warn.

How about "improved from $(cb.best_metric) to $metric"?

Contributor Author:

Yeah! That's a better idea. I will change that.

@CarloLucibello (Member):

I didn't look into the details of the PR, so just two general comments:

  1. Callbacks are called after each minibatch update (if doing SGD and not full-batch updates).
  2. The callback architecture in some cases implies doing twice the computation of the forward pass, which is highly inefficient.

My opinion is that we shouldn't advertise the train!-and-callback approach at all, in favor of simple for loops. We shouldn't advertise bad coding practices. That said, if @johnnychen94 or other people think this kind of functionality is useful, and since some effort has gone into this PR, we should definitely have it. Just let's be mindful of 1) and 2).

@AdarshKumar712 (Contributor Author) commented Mar 9, 2020

> Looks nice; my comments are mostly about style.
>
> There are two more concerns (not necessarily about this PR):
>
> • As @bhvieira commented, we may need to find a way to reduce unnecessary metric computation across multiple callbacks, or even get the forward pass from Zygote. Currently, all metrics are computed multiple times even when they're the same.
> • Some callbacks don't need to be called so frequently; we may need a callback-scheduler type that dynamically either does nothing or calls the callbacks, i.e., a callback that calls callbacks.

For the first part, I think we can have a wrapper function; please tell me whether that would be meaningful or not. I have something like this in mind:

function callback(metric, cb_list)
    metric_val = metric()       # compute the shared metric once
    for each_cb in cb_list
        each_cb(metric_val)     # pass the same value to every callback
    end
end

cb_list = [...]    # list of callbacks
cb = () -> callback(metric, cb_list)    # wrapped so the training loop can call it

For the second point, I think we can achieve that by using throttle on the individual functions, like

cb_list = [throttle(f1, t1), throttle(f2, t2), ...]

But I don't know exactly whether this way of scheduling will be efficient or not.
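(Worth noting: Flux.throttle is time-based rather than call-count-based; the wrapped function runs at most once per timeout seconds. A minimal sketch mirroring the Flux docs:)

using Flux: throttle

# test_loss() is evaluated at most once every 10 seconds, no matter
# how many minibatch updates happen in between
evalcb = throttle(() -> @show(test_loss()), 10)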

@AdarshKumar712 (Contributor Author) commented Mar 9, 2020

> 1. Callbacks are called after each minibatch update (if doing SGD and not full-batch updates).

For 1), by using something like throttle, or a more efficient mechanism like a callback scheduler, we can minimise how often the callbacks are called.

> 2. The callback architecture in some cases implies doing twice the computation of the forward pass, which is highly inefficient.

Yes, I agree that callbacks may require running the forward pass twice. But I hope the extra forward pass is worth its cost if we use it to validate our data, which needs to be done in many models.
Also, I think these functions aren't just meant for the train! function but can equally be used in simple custom for loops while validating the data. I hope they can serve as a functionality to support easy validation and to provide further checks on that validation (like saving the best model).

@johnnychen94 (Contributor) commented Mar 10, 2020

Although it's not a blocking PR, I suggest holding this PR for a while until #1017 gets merged, and then deciding whether we really want this.

> Also, I think these functions aren't just meant for the train! function but can equally be used in simple custom for loops while validating the data.

I think people won't want to use it unless the redundant-computation issue can be solved.

The redundant-computation issue could be partially addressed by adding a keyword argument cache=nothing to the callbacks, so that if a pre-computed cache is passed, the calculation of model(x) is skipped. It would be nice if we could make cache passing more automatic. I'm out of time to think of a design, but I see it as a possible solution.
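(A rough sketch of what that keyword could look like on this thread's HistoryCallback, treating the cache as a pre-computed metric value; this is an assumption, since the comment leaves the cache's exact contents open:)

function (cb::HistoryCallback)(; cache=nothing)
    # reuse the training loop's pre-computed value when available;
    # otherwise fall back to recomputing the metric (and its forward pass)
    val = cache === nothing ? cb.calc_metric() : cache
    push!(cb.history, val)
end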

> My opinion is that we shouldn't advertise the train!-and-callback approach at all, in favor of simple for loops. We shouldn't advertise bad coding practices.

Speaking of flexibility and efficiency, I totally agree that train! isn't meant for power users, and this PR is unlikely to benefit advanced usage at all. After all, how hard would it be to write a custom loop manually?

But unless we decide to deprecate such usage, IMO there is still merit in providing some basic support for it. That said, not being very active in the Flux community, I don't want to make that call :P

@CarloLucibello (Member):

closing as outdated
