Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added early stopping to the solver #76

Closed
wants to merge 5 commits into from

Conversation

tdomhan
Copy link
Contributor

@tdomhan tdomhan commented Feb 6, 2014

I added early stopping to the solver. You can now choose whether caffe should stop after a maximum number of iterations or alternatively if the test accuracy doesn't improve for a given number of tries.

@@ -0,0 +1,68 @@
// Copyright 2013 Yangqing Jia
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should have your copyright.

@shelhamer
Copy link
Member

Looks good to me. Builds and passes tests. Any comments @Yangqing ?

@sguada
Copy link
Contributor

sguada commented Feb 6, 2014

It looks good to me. I'm thinking adding something similar to reduce the
learning rate when the val accuracy don't change, rather than on fixed
steps.

Sergio
On Feb 6, 2014 10:43 AM, "Evan Shelhamer" notifications@github.com wrote:

Looks good to me. Builds and passes tests. Any comments @Yangqinghttps://github.com/Yangqing?

Reply to this email directly or view it on GitHubhttps://github.com//pull/76#issuecomment-34355660
.

@kloudkl
Copy link
Contributor

kloudkl commented Feb 7, 2014

How to extend TerminationCriterion to support more criteria? The current design requires at least 4 steps as follows.

  1. Add more methods like NotifyTestAccuracy and NotifyIteration to TerminationCriterion and all of its subclasses, and add new subclasses for the new criteria;
  2. Add new enum TerminationCriterion value and config for the new criteria in caffe.proto;
  3. Add if conditions in the constructor of Solver to handle the initialization of the new type of TerminationCriterion;
  4. Add the appropriate Notify* at the right place in Solver::Solve;

There are a few improvements worth considering.

  1. Condense the Notify* methods;
template <typename Dtype>
 class TerminationCriterion {
 public:
   TerminationCriterion(Dtype threshold): 
      criterion_met_(false), src/caffe/solver.cpp
      threshold_(threshold) {}

   virtual bool IsCriterionMet() {return criterion_met_;}

   virtual void Notify(Dtype value) = 0;
 protected:
   bool criterion_met_;
   const Dtype threshold_;
 };

template <typename Dtype>
class MaxIterTerminationCriterion : public TerminationCriterion<Dtype> {
public:
    MaxIterTerminationCriterion(Dtype threshold): 
        TerminationCriterion<Dtype>(threshold) {}

    virtual void Notify(Dtype value) {
        criterion_met_ = value > threshold_;
    }
};

template <typename Dtype>
class TestAccuracyTerminationCriterion : public TerminationCriterion<Dtype> {
public:
    TestAccuracyTerminationCriterion(Dtype threshold): 
        TerminationCriterion<Dtype>(threshold) {}

    virtual void Notify(Dtype value) {
        if (best_accuracy_ < value) {
            best_accuracy_ = value;
            count_down_ = threshold_;
        } else {
            --count_down_;
            criterion_met_ = count_down_ <= 0;
        }
    }
protected:
    Dtype best_accuracy_;
    uint32_t count_down_;
};
  1. Rename enum TerminationCriterion to be enum TerminationCriterionType;
  2. Use factory pattern to produce the desired TerminationCriterion class;
// include/caffe/solver.hpp
template<typename Dtype>
class TerminationCriterionFactory {
public:
    static 
shared<TerminationCriterion<Dtype> > GetTerminationCriterion(TerminationCriterionType type);
}
// src/caffe/solver.cpp
template<typename Dtype>
shared<TerminationCriterion<Dtype> >
TerminationCriterionFactory::GetTerminationCriterion(const TerminationCriterionType type, const SolverParameter& param) {
    shared<TerminationCriterion<Dtype> > termination_criterion;
    switch (type) {
    case SolverParameter::MAX_ITER: {
        termination_criterion.reset(new MaxIterTerminationCriterion<Dtype >(param.max_iter()));
        break;
    } case SolverParameter::TEST_ACCURACY: {
        CHECK(param.has_test_net()) << "Test network needed for TestAccuracyTerminationCriterion.";
        termination_criterion.reset(new TestAccuracyTerminationCriterion<Dtype >(param.test_accuracy_stop_countdown()));
         break;
    } default: {
        LOG(ERROR) << "Unknown TerminationCriterionType " << type;
        break;
    }
    return termination_criterion;
}
  1. Subclass to add new criterion to Solver;
template<typaname Dtype>
class Solver {
public:
    ...
    void Solve(...) {
        do {
            ++iter;
            ...
            PostIteration();
        } while(!GetShouldStop());
    }
    void PostIteration();   
    bool GetShouldStop();  
    void SetShouldStop(const bool stop); 
    Dtype GetTestAccuracy();
protected:
    bool shouldStop_;
}

template<typaname Dtype>
class TestAccuracyEarlyStoppingSolver: public Solver<Dtype> {
public:
    public:
        TestAccuracyEarlyStoppingSolver(): 
            criterion_(TerminationCriterionFactory::GetTerminationCriterion(SolverParameter::TEST_ACCURACY)) {}
        void PostIteration() {
            ...
            criterion_->Notify(GetTestAccuracy());
            SetShouldStop(criterion_->IsCriterionMet());
        }
private:
    shared_ptr<TerminationCriterion<Type> > criterion_;
}

@tdomhan
Copy link
Contributor Author

tdomhan commented Feb 7, 2014

I agree that the notify* will be a bit a bottle neck here and if other termination criteria will be added in the future, you need to add new notifications and add new functions to each criterion.

I don't think subclassing won’t make things easier though. You will not only need a factory for the TerminationCriterion, but also a Factory for the different Solvers. And also: Why subclass according to TerminationCriterion, but not for example according to learning rate policy?
So I think this will make things overly complicated, especially if the solver will be subclasses for other reasons in the future.

Another solution I could see is:

  • remove the notify function
  • each Termination Criterion knows the Solver instances
  • get the information needed through e.g. solver.GetCurrentIteration(), solver.GetLastTestError()
    The problem here is of course that the test error doesn’t chance in each iteration and you need to take this into account.

So why not the following:

template <typename Dtype>
 class TerminationCriterion {
 public:
   TerminationCriterion(Dtype threshold): 
      criterion_met_(false), src/caffe/solver.cpp
      threshold_(threshold) {}

   virtual bool IsCriterionMet() {return criterion_met_;}

   virtual void Notify(const char* type, Dtype value) = 0;
 protected:
   bool criterion_met_;
   const Dtype threshold_;
 };

This way it's still clear when calling Notify, what we are actually notifying about, e.g. termination_criterion_.Notify("test_accuracy", 0.5)

This way:

  • things are kept simple
  • adding a new TerminationCriterion will not result in adding any unnecessary functions to the existing ones
  • each TerminationCriterion can pick what information it is interested in and ignore the rest

@kloudkl
Copy link
Contributor

kloudkl commented Feb 7, 2014

If we do not subclass Solver, maybe Solver should just hold both MaxIterTerminationCriterion and TestAccuracyTerminationCriterion whose default thresholds never end the training. Before new criterion is really added, keep the design as simple as possible.

class Solver {
    Solve() {
          max_iter_termination_criterion_.Notify(iter_);
          test_accuracy_termination_criterion_.Notify(test_accuracy);
    }

    MaxIterTerminationCriterion max_iter_termination_criterion_;
    TestAccuracyTerminationCriterion test_accuracy_termination_criterion_;
}

@tdomhan
Copy link
Contributor Author

tdomhan commented Feb 7, 2014

I'm not sure why that should improve anything. Your would need a switch
statement everywhere you use the termination criterion.
On Feb 7, 2014 2:02 PM, "kloudkl" notifications@github.com wrote:

If we do not subclass Solver, maybe Solver should just hold both
MaxIterTerminationCriterion and TestAccuracyTerminationCriterion. Before
new criterion is really added, keep the design as simple as possible.

class Solver {
Solve() {
max_iter_termination_criterion_.Notify(iter_);
test_accuracy_termination_criterion_.Notify(test_accuracy);
}

MaxIterTerminationCriterion max_iter_termination_criterion_;
TestAccuracyTerminationCriterion test_accuracy_termination_criterion_;}

Reply to this email directly or view it on GitHubhttps://github.com//pull/76#issuecomment-34434516
.

@kloudkl
Copy link
Contributor

kloudkl commented Feb 8, 2014

Sorry, the above code snippet is too simple and a bit confusing. I meant to use both the criteria in the solver so that whichever is met earlier the training stops. The details can be seen in the newly added line notes. Basically, there is no switch and only a few changes shortening the Notify methods to simplify extension of TerminationCriterion.

@tdomhan
Copy link
Contributor Author

tdomhan commented Feb 10, 2014

If you want more than one TerminationCriterion, I'd add AndCombinedTerminationCriterion and OrCombinedTerminationCriterion. Both take a vector of TerminationCriterions and compute the logical and/or of their members. In the protobuf file the TerminationCriterion would become repeated and another field will need to be added, to determine how to combine them.

@kloudkl
Copy link
Contributor

kloudkl commented Feb 10, 2014

It smells that the requirement is ballooning. Bad signal. I would like this issue not to stop making progress. You might want the simpler version to be merged first and let it be improved in a future PR. Thanks for your commit!

@Yangqing
Copy link
Member

Thanks all for your comments.

I am leaning towards a much simpler solution, instead of heavily factoring
a mechanism that in the end will implement only a few criteria. The code is
research code anyway, and we do not want a new PhD student to come and have
to dig up Design Pattern from his/her undergrad courses before being able
to track down things. Thus, I think a simpler way is to just extend the
Solver, change the while statement

https://github.com/BVLC/caffe/blob/master/src/caffe/solver.cpp#L59

to something like

while (ShouldExit())

and add a function ShouldExit directly in solver, rather have a separate
TerminationCriterion class. My argument for this is that there are - in the
forseeable future - maybe 5 criteria tops? Something like (a) wall-time (b)
iteration (c) accuracy. We could keep things simple first and then, when we
hit 5 or 10 criteria, start thinking about refactoring.

As a concise example, please take a look at the LayerFactory that is
currently used in caffe:

https://github.com/BVLC/caffe/blob/master/src/caffe/layer_factory.cpp#L20

I am not proud of it - honestly, with 20 if statements we are probably
hitting a too-ugly-need-refactoring point. But it is just my feeling that
things should start simple, e.g. in this case.

Yangqing

On Mon, Feb 10, 2014 at 7:14 AM, kloudkl notifications@github.com wrote:

It smells that the requirement is ballooning. Bad signal. I would like
this issue not to stop making progress. You might want the simpler version
to be merged first and let it be improved in a future PR. Thanks for your
commit!

Reply to this email directly or view it on GitHubhttps://github.com//pull/76#issuecomment-34642357
.

@shelhamer
Copy link
Member

This has been a good discussion to keep in mind as we outfit Caffe with more features and generality.

As Yangqing outlined, for research code simplest-first is a helpful motto. We can always refine further along the way and put to rest speculation in the meantime.

@tdomhan, I am for the inclusion of early stopping, and would be happy to merge this PR rewritten along the lines of @Yangqing's simple proposal or your refactoring suggested in #76 (comment). Thanks for your work on this and thanks everyone for the discussion.

@tdomhan
Copy link
Contributor Author

tdomhan commented Feb 13, 2014

I also like to keep things simple, but at the same time I think it's nice to include the termination criterion in the unit tests, without building up a whole solver instance.
@shelhamer I refactored according the the above comment.

@tdomhan
Copy link
Contributor Author

tdomhan commented Feb 25, 2014

@kloudkl well one enum is for selecting different criteria and the other one specifies what information is passed to the object. I first just wanted to pass a string as I said above, but then I thought it might be nicer to do it with an enum.

@shelhamer
Copy link
Member

@tdomhan thanks again for this suggestion, but please implement according to @Yangqing's suggestion: code it directly into the solver as ifs/elses. Plus, please rebase on the latest dev since this is not currently a clean merge.

@tdomhan
Copy link
Contributor Author

tdomhan commented Mar 15, 2014

@shelhamer alright, I'll implement early stopping as suggested as soon as I find some time. Should I maybe directly base this off the branch sguada:new_lr_policies (aka #190 )?

@kloudkl
Copy link
Contributor

kloudkl commented Mar 15, 2014

According to the new development guideline, new PR should generally base on the dev branch. Because there is no way to change the target of an existing PR, you may want to open a new one to replace this one.

@tdomhan
Copy link
Contributor Author

tdomhan commented Mar 15, 2014

Thanks, that makes sense.
On Mar 15, 2014 12:30 PM, "kloudkl" notifications@github.com wrote:

According to the new development guideline, new PR should generally base
on the dev branch. Because there is no way to change the target of an
existing PR, you may want to open a new one to replace this one.

Reply to this email directly or view it on GitHubhttps://github.com//pull/76#issuecomment-37723511
.

@sguada
Copy link
Contributor

sguada commented Mar 16, 2014

@tdomhan for clarity and maintainability is better to make different PR for new functionalities. However you can get inspired in my PR #190. Probably I will make the changes keeping in mind that you plan to this PR. Since for instance the list of Test accuracies is needed for early stop and early step.

@jeffdonahue
Copy link
Contributor

Closing as this seems to have been abandoned after the refactor suggestion -- @tdomhan, if you resume work please comment and I can reopen. (Or anyone else who wants to take over may open a new PR.)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants