break out Step from Solver
longjon committed Dec 31, 2014
1 parent b8715c6 commit 291a2c2
Showing 2 changed files with 40 additions and 26 deletions.
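
This commit breaks the optimization loop out of Solve() into a reusable Step(iters), so training can be advanced in increments rather than in a single run to max_iter. A minimal sketch of the call pattern this enables (the prototxt path and iteration counts below are hypothetical, not part of this commit):

#include "caffe/solver.hpp"
#include "caffe/util/io.hpp"

int main() {
  caffe::Caffe::set_phase(caffe::Caffe::TRAIN);

  caffe::SolverParameter param;
  caffe::ReadProtoFromTextFileOrDie("examples/mnist/lenet_solver.prototxt",
                                    &param);
  caffe::SGDSolver<float> solver(param);

  // Advance training 100 iterations at a time; on the first call, the
  // initialized_ guard in Step() runs PreSolve(), so no Solve() is required.
  for (int i = 0; i < 10; ++i) {
    solver.Step(100);
    // Inspect solver.net() between chunks: log metrics, snapshot, etc.
  }
  return 0;
}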
5 changes: 4 additions & 1 deletion include/caffe/solver.hpp
@@ -26,6 +26,7 @@ class Solver {
   // in a non-zero iter number to resume training for a pre-trained net.
   virtual void Solve(const char* resume_file = NULL);
   inline void Solve(const string resume_file) { Solve(resume_file.c_str()); }
+  void Step(int iters);
   virtual ~Solver() {}
   inline shared_ptr<Net<Dtype> > net() { return net_; }
   inline const vector<shared_ptr<Net<Dtype> > >& test_nets() {
@@ -36,7 +37,7 @@ class Solver {
  protected:
   // PreSolve is run before any solving iteration starts, allowing one to
   // put up some scaffold.
-  virtual void PreSolve() {}
+  virtual void PreSolve();
   // Get the update value for the current iteration.
   virtual void ComputeUpdateValue() = 0;
   // The Solver::Snapshot function implements the basic snapshotting utility
@@ -57,9 +58,11 @@ class Solver {
 
   SolverParameter param_;
   int iter_;
+  int start_iter_;
   int current_step_;
   shared_ptr<Net<Dtype> > net_;
   vector<shared_ptr<Net<Dtype> > > test_nets_;
+  bool initialized_;
 
   DISABLE_COPY_AND_ASSIGN(Solver);
 };
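A note on the header change above: PreSolve() loses its empty inline body and becomes a real base implementation (defined in solver.cpp below) that records initialization and resets the iteration counters, so overriding solvers now need to chain to it, as SGDSolver::PreSolve does at the end of this diff. A hypothetical subclass illustrating the pattern (MySolver is illustrative only, not part of Caffe):

#include "caffe/solver.hpp"

template <typename Dtype>
class MySolver : public caffe::SGDSolver<Dtype> {
 public:
  explicit MySolver(const caffe::SolverParameter& param)
      : caffe::SGDSolver<Dtype>(param) {}

 protected:
  virtual void PreSolve() {
    // Chain up so initialized_, iter_, start_iter_, and current_step_ are
    // set before any stepping happens; SGDSolver::PreSolve in turn calls
    // Solver<Dtype>::PreSolve() and then allocates its history blobs.
    caffe::SGDSolver<Dtype>::PreSolve();
    // ... set up any additional per-parameter state here ...
  }
};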
61 changes: 36 additions & 25 deletions src/caffe/solver.cpp
@@ -29,9 +29,11 @@ Solver<Dtype>::Solver(const string& param_file)
 
 template <typename Dtype>
 void Solver<Dtype>::Init(const SolverParameter& param) {
+  initialized_ = false;
   LOG(INFO) << "Initializing solver from parameters: " << std::endl
             << param.DebugString();
   param_ = param;
+  CHECK_GE(param_.average_loss(), 1) << "average_loss should be non-negative.";
   if (param_.random_seed() >= 0) {
     Caffe::set_random_seed(param_.random_seed());
   }
@@ -155,35 +157,20 @@ void Solver<Dtype>::InitTestNets() {
 }
 
 template <typename Dtype>
-void Solver<Dtype>::Solve(const char* resume_file) {
-  Caffe::set_phase(Caffe::TRAIN);
-  LOG(INFO) << "Solving " << net_->name();
-  LOG(INFO) << "Learning Rate Policy: " << param_.lr_policy();
-  PreSolve();
-
-  iter_ = 0;
-  current_step_ = 0;
-  if (resume_file) {
-    LOG(INFO) << "Restoring previous solver status from " << resume_file;
-    Restore(resume_file);
+void Solver<Dtype>::Step(int iters) {
+  if (!initialized_) {
+    PreSolve();
   }
-  // Remember the initial iter_ value; will be non-zero if we loaded from a
-  // resume_file above.
-  const int start_iter = iter_;
-
+
+  vector<Blob<Dtype>*> bottom_vec;
+  const int stop_iter = iter_ + iters;
   int average_loss = this->param_.average_loss();
-
-  CHECK_GE(average_loss, 1) << "average_loss should be non-negative.";
-
   vector<Dtype> losses;
   Dtype smoothed_loss = 0;
 
-  // For a network that is trained by the solver, no bottom or top vecs
-  // should be given, and we will just provide dummy vecs.
-  vector<Blob<Dtype>*> bottom_vec;
-  for (; iter_ < param_.max_iter(); ++iter_) {
+  for (; iter_ < stop_iter; ++iter_) {
     // Save a snapshot if needed.
-    if (param_.snapshot() && iter_ > start_iter &&
+    if (param_.snapshot() && iter_ > start_iter_ &&
         iter_ % param_.snapshot() == 0) {
       Snapshot();
     }
@@ -227,10 +214,33 @@ void Solver<Dtype>::Solve(const char* resume_file) {
         }
       }
     }
 
     ComputeUpdateValue();
     net_->Update();
   }
+}
+
+template <typename Dtype>
+void Solver<Dtype>::PreSolve() {
+  initialized_ = true;
+  start_iter_ = iter_ = 0;
+  current_step_ = 0;
+}
+
+template <typename Dtype>
+void Solver<Dtype>::Solve(const char* resume_file) {
+  Caffe::set_phase(Caffe::TRAIN);
+  LOG(INFO) << "Solving " << net_->name();
+  LOG(INFO) << "Learning Rate Policy: " << param_.lr_policy();
+
+  PreSolve();
+  if (resume_file) {
+    LOG(INFO) << "Restoring previous solver status from " << resume_file;
+    Restore(resume_file);
+  }
+  // For a network that is trained by the solver, no bottom or top vecs
+  // should be given, and we will just provide dummy vecs.
+  Step(param_.max_iter() - iter_);
   // Always save a snapshot after optimization, unless overridden by setting
   // snapshot_after_train := false.
   if (param_.snapshot_after_train()) { Snapshot(); }
@@ -242,7 +252,7 @@ void Solver<Dtype>::Solve(const char* resume_file) {
   // display the loss, which is computed in the forward pass.
   if (param_.display() && iter_ % param_.display() == 0) {
     Dtype loss;
-    net_->Forward(bottom_vec, &loss);
+    net_->ForwardPrefilled(&loss);
     LOG(INFO) << "Iteration " << iter_ << ", loss = " << loss;
   }
   if (param_.test_interval() && iter_ % param_.test_interval() == 0) {
@@ -352,7 +362,7 @@ void Solver<Dtype>::Restore(const char* state_file) {
     ReadProtoFromBinaryFile(state.learned_net().c_str(), &net_param);
     net_->CopyTrainedLayersFrom(net_param);
   }
-  iter_ = state.iter();
+  start_iter_ = iter_ = state.iter();
   current_step_ = state.current_step();
   RestoreSolverState(state);
 }
Expand Down Expand Up @@ -414,6 +424,7 @@ Dtype SGDSolver<Dtype>::GetLearningRate() {

template <typename Dtype>
void SGDSolver<Dtype>::PreSolve() {
Solver<Dtype>::PreSolve();
// Initialize the history
vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_->params();
history_.clear();
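
For reference, the loss-averaging state that Step() inherits from Solve() (losses, smoothed_loss, average_loss) feeds display logic that sits in the collapsed context of this diff. A standalone sketch of that running-average update, reconstructed from Caffe's solver rather than quoted from this commit:

#include <vector>

// Maintains a running mean of the last `average_loss` iteration losses for
// the smoothed display loss. Reconstructed sketch; names mirror the diff.
template <typename Dtype>
void UpdateSmoothedLoss(Dtype loss, int iter, int start_iter, int average_loss,
                        std::vector<Dtype>* losses, Dtype* smoothed_loss) {
  if (static_cast<int>(losses->size()) < average_loss) {
    // Window not yet full: append and update the mean incrementally.
    losses->push_back(loss);
    const int size = losses->size();
    *smoothed_loss = (*smoothed_loss * (size - 1) + loss) / size;
  } else {
    // Window full: overwrite the oldest entry and adjust the mean in O(1).
    const int idx = (iter - start_iter) % average_loss;
    *smoothed_loss += (loss - (*losses)[idx]) / average_loss;
    (*losses)[idx] = loss;
  }
}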
