-
Notifications
You must be signed in to change notification settings - Fork 18.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Added early stopping to the solver #76
Conversation
@@ -0,0 +1,68 @@ | |||
// Copyright 2013 Yangqing Jia |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should have your copyright.
Looks good to me. Builds and passes tests. Any comments @Yangqing ? |
It looks good to me. I'm thinking adding something similar to reduce the Sergio
|
How to extend TerminationCriterion to support more criteria? The current design requires at least 4 steps as follows.
There are a few improvements worth considering.
template <typename Dtype>
class TerminationCriterion {
public:
TerminationCriterion(Dtype threshold):
criterion_met_(false), src/caffe/solver.cpp
threshold_(threshold) {}
virtual bool IsCriterionMet() {return criterion_met_;}
virtual void Notify(Dtype value) = 0;
protected:
bool criterion_met_;
const Dtype threshold_;
};
template <typename Dtype>
class MaxIterTerminationCriterion : public TerminationCriterion<Dtype> {
public:
MaxIterTerminationCriterion(Dtype threshold):
TerminationCriterion<Dtype>(threshold) {}
virtual void Notify(Dtype value) {
criterion_met_ = value > threshold_;
}
};
template <typename Dtype>
class TestAccuracyTerminationCriterion : public TerminationCriterion<Dtype> {
public:
TestAccuracyTerminationCriterion(Dtype threshold):
TerminationCriterion<Dtype>(threshold) {}
virtual void Notify(Dtype value) {
if (best_accuracy_ < value) {
best_accuracy_ = value;
count_down_ = threshold_;
} else {
--count_down_;
criterion_met_ = count_down_ <= 0;
}
}
protected:
Dtype best_accuracy_;
uint32_t count_down_;
};
// include/caffe/solver.hpp
template<typename Dtype>
class TerminationCriterionFactory {
public:
static
shared<TerminationCriterion<Dtype> > GetTerminationCriterion(TerminationCriterionType type);
}
// src/caffe/solver.cpp
template<typename Dtype>
shared<TerminationCriterion<Dtype> >
TerminationCriterionFactory::GetTerminationCriterion(const TerminationCriterionType type, const SolverParameter& param) {
shared<TerminationCriterion<Dtype> > termination_criterion;
switch (type) {
case SolverParameter::MAX_ITER: {
termination_criterion.reset(new MaxIterTerminationCriterion<Dtype >(param.max_iter()));
break;
} case SolverParameter::TEST_ACCURACY: {
CHECK(param.has_test_net()) << "Test network needed for TestAccuracyTerminationCriterion.";
termination_criterion.reset(new TestAccuracyTerminationCriterion<Dtype >(param.test_accuracy_stop_countdown()));
break;
} default: {
LOG(ERROR) << "Unknown TerminationCriterionType " << type;
break;
}
return termination_criterion;
}
template<typaname Dtype>
class Solver {
public:
...
void Solve(...) {
do {
++iter;
...
PostIteration();
} while(!GetShouldStop());
}
void PostIteration();
bool GetShouldStop();
void SetShouldStop(const bool stop);
Dtype GetTestAccuracy();
protected:
bool shouldStop_;
}
template<typaname Dtype>
class TestAccuracyEarlyStoppingSolver: public Solver<Dtype> {
public:
public:
TestAccuracyEarlyStoppingSolver():
criterion_(TerminationCriterionFactory::GetTerminationCriterion(SolverParameter::TEST_ACCURACY)) {}
void PostIteration() {
...
criterion_->Notify(GetTestAccuracy());
SetShouldStop(criterion_->IsCriterionMet());
}
private:
shared_ptr<TerminationCriterion<Type> > criterion_;
} |
I agree that the notify* will be a bit a bottle neck here and if other termination criteria will be added in the future, you need to add new notifications and add new functions to each criterion. I don't think subclassing won’t make things easier though. You will not only need a factory for the TerminationCriterion, but also a Factory for the different Solvers. And also: Why subclass according to TerminationCriterion, but not for example according to learning rate policy? Another solution I could see is:
So why not the following: template <typename Dtype>
class TerminationCriterion {
public:
TerminationCriterion(Dtype threshold):
criterion_met_(false), src/caffe/solver.cpp
threshold_(threshold) {}
virtual bool IsCriterionMet() {return criterion_met_;}
virtual void Notify(const char* type, Dtype value) = 0;
protected:
bool criterion_met_;
const Dtype threshold_;
}; This way it's still clear when calling This way:
|
If we do not subclass Solver, maybe Solver should just hold both MaxIterTerminationCriterion and TestAccuracyTerminationCriterion whose default thresholds never end the training. Before new criterion is really added, keep the design as simple as possible. class Solver {
Solve() {
max_iter_termination_criterion_.Notify(iter_);
test_accuracy_termination_criterion_.Notify(test_accuracy);
}
MaxIterTerminationCriterion max_iter_termination_criterion_;
TestAccuracyTerminationCriterion test_accuracy_termination_criterion_;
} |
I'm not sure why that should improve anything. Your would need a switch
|
Sorry, the above code snippet is too simple and a bit confusing. I meant to use both the criteria in the solver so that whichever is met earlier the training stops. The details can be seen in the newly added line notes. Basically, there is no switch and only a few changes shortening the Notify methods to simplify extension of TerminationCriterion. |
If you want more than one |
It smells that the requirement is ballooning. Bad signal. I would like this issue not to stop making progress. You might want the simpler version to be merged first and let it be improved in a future PR. Thanks for your commit! |
Thanks all for your comments. I am leaning towards a much simpler solution, instead of heavily factoring https://github.com/BVLC/caffe/blob/master/src/caffe/solver.cpp#L59 to something like
and add a function ShouldExit directly in solver, rather have a separate As a concise example, please take a look at the LayerFactory that is https://github.com/BVLC/caffe/blob/master/src/caffe/layer_factory.cpp#L20 I am not proud of it - honestly, with 20 if statements we are probably Yangqing On Mon, Feb 10, 2014 at 7:14 AM, kloudkl notifications@github.com wrote:
|
This has been a good discussion to keep in mind as we outfit Caffe with more features and generality. As Yangqing outlined, for research code simplest-first is a helpful motto. We can always refine further along the way and put to rest speculation in the meantime. @tdomhan, I am for the inclusion of early stopping, and would be happy to merge this PR rewritten along the lines of @Yangqing's simple proposal or your refactoring suggested in #76 (comment). Thanks for your work on this and thanks everyone for the discussion. |
I also like to keep things simple, but at the same time I think it's nice to include the termination criterion in the unit tests, without building up a whole solver instance. |
@kloudkl well one enum is for selecting different criteria and the other one specifies what information is passed to the object. I first just wanted to pass a string as I said above, but then I thought it might be nicer to do it with an enum. |
@shelhamer alright, I'll implement early stopping as suggested as soon as I find some time. Should I maybe directly base this off the branch |
According to the new development guideline, new PR should generally base on the |
Thanks, that makes sense.
|
@tdomhan for clarity and maintainability is better to make different PR for new functionalities. However you can get inspired in my PR #190. Probably I will make the changes keeping in mind that you plan to this PR. Since for instance the list of Test accuracies is needed for early stop and early step. |
Closing as this seems to have been abandoned after the refactor suggestion -- @tdomhan, if you resume work please comment and I can reopen. (Or anyone else who wants to take over may open a new PR.) |
I added early stopping to the solver. You can now choose whether caffe should stop after a maximum number of iterations or alternatively if the test accuracy doesn't improve for a given number of tries.