
Merge pull request #1228 from longjon/solver-step
Refactor Solver to allow interactive stepping
jeffdonahue committed Jan 7, 2015
2 parents b8715c6 + 10c2364 commit c6a88bf
Showing 4 changed files with 45 additions and 42 deletions.
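
The core change extracts the training loop out of Solver::Solve into a new public method, Solver::Step(int iters), so a caller can run a fixed number of iterations, inspect or modify state, and then continue; the same entry point is exported to Python as solver.step() in the binding change below. A minimal C++ sketch of the pattern this enables — not part of the commit, and the solver prototxt path is hypothetical:

```cpp
#include "caffe/caffe.hpp"

int main() {
  // Step() does not set the phase itself (Solve() does), so set it here.
  caffe::Caffe::set_phase(caffe::Caffe::TRAIN);
  caffe::SolverParameter param;
  // Hypothetical path; substitute your own solver definition.
  caffe::ReadProtoFromTextFileOrDie("solver.prototxt", &param);
  caffe::SGDSolver<float> solver(param);  // PreSolve() now runs here
  for (int i = 0; i < 10; ++i) {
    solver.Step(1);  // advance training by exactly one iteration
    // Between calls, solver.net() exposes the live net for inspection.
  }
  return 0;
}
```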
10 changes: 4 additions & 6 deletions include/caffe/solver.hpp
@@ -26,6 +26,7 @@ class Solver {
   // in a non-zero iter number to resume training for a pre-trained net.
   virtual void Solve(const char* resume_file = NULL);
   inline void Solve(const string resume_file) { Solve(resume_file.c_str()); }
+  void Step(int iters);
   virtual ~Solver() {}
   inline shared_ptr<Net<Dtype> > net() { return net_; }
   inline const vector<shared_ptr<Net<Dtype> > >& test_nets() {
@@ -34,9 +35,6 @@
   int iter() { return iter_; }

  protected:
-  // PreSolve is run before any solving iteration starts, allowing one to
-  // put up some scaffold.
-  virtual void PreSolve() {}
   // Get the update value for the current iteration.
   virtual void ComputeUpdateValue() = 0;
   // The Solver::Snapshot function implements the basic snapshotting utility
@@ -73,14 +71,14 @@ template <typename Dtype>
 class SGDSolver : public Solver<Dtype> {
  public:
   explicit SGDSolver(const SolverParameter& param)
-      : Solver<Dtype>(param) {}
+      : Solver<Dtype>(param) { PreSolve(); }
   explicit SGDSolver(const string& param_file)
-      : Solver<Dtype>(param_file) {}
+      : Solver<Dtype>(param_file) { PreSolve(); }

   const vector<shared_ptr<Blob<Dtype> > >& history() { return history_; }

  protected:
-  virtual void PreSolve();
+  void PreSolve();
   Dtype GetLearningRate();
   virtual void ComputeUpdateValue();
   virtual void SnapshotSolverState(SolverState * state);
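
Note that PreSolve() stops being a virtual hook invoked inside Solve() and becomes a plain method that each SGDSolver constructor calls directly, so the history blobs exist before the first Step(). One likely reason for dropping the virtual dispatch, illustrated in the sketch below: the base Solver constructor could not have made this call itself, because a virtual call made during a base-class constructor never reaches a derived override.

```cpp
#include <iostream>

// Illustration only (not from the commit): why a scaffolding hook must be
// called from the *derived* constructor rather than the base one.
struct Base {
  Base() { Hook(); }  // resolves to Base::Hook, even when constructing Derived
  virtual void Hook() { std::cout << "Base::Hook\n"; }
  virtual ~Base() {}
};

struct Derived : Base {
  Derived() { Hook(); }  // here Derived::Hook runs, as intended
  virtual void Hook() { std::cout << "Derived::Hook\n"; }
};

int main() {
  Derived d;  // prints "Base::Hook", then "Derived::Hook"
  return 0;
}
```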
3 changes: 2 additions & 1 deletion python/caffe/_caffe.cpp
@@ -197,7 +197,8 @@ BOOST_PYTHON_MODULE(_caffe) {
       .add_property("test_nets", &PySGDSolver::test_nets)
       .add_property("iter", &PySGDSolver::iter)
       .def("solve", &PySGDSolver::Solve)
-      .def("solve", &PySGDSolver::SolveResume);
+      .def("solve", &PySGDSolver::SolveResume)
+      .def("step", &PySGDSolver::Step);

   bp::class_<vector<shared_ptr<PyNet> > >("NetVec")
       .def(bp::vector_indexing_suite<vector<shared_ptr<PyNet> >, true>());
1 change: 1 addition & 0 deletions python/caffe/_caffe.hpp
@@ -181,6 +181,7 @@ class PySGDSolver {
   vector<shared_ptr<PyNet> > test_nets() { return test_nets_; }
   int iter() { return solver_->iter(); }
   void Solve() { return solver_->Solve(); }
+  void Step(int iters) { solver_->Step(iters); }
   void SolveResume(const string& resume_file);

  protected:
73 changes: 38 additions & 35 deletions src/caffe/solver.cpp
@@ -32,13 +32,16 @@ void Solver<Dtype>::Init(const SolverParameter& param) {
   LOG(INFO) << "Initializing solver from parameters: " << std::endl
       << param.DebugString();
   param_ = param;
+  CHECK_GE(param_.average_loss(), 1) << "average_loss should be non-negative.";
   if (param_.random_seed() >= 0) {
     Caffe::set_random_seed(param_.random_seed());
   }
   // Scaffolding code
   InitTrainNet();
   InitTestNets();
   LOG(INFO) << "Solver scaffolding done.";
+  iter_ = 0;
+  current_step_ = 0;
 }

 template <typename Dtype>
@@ -155,39 +158,15 @@ void Solver<Dtype>::InitTestNets() {
 }

 template <typename Dtype>
-void Solver<Dtype>::Solve(const char* resume_file) {
-  Caffe::set_phase(Caffe::TRAIN);
-  LOG(INFO) << "Solving " << net_->name();
-  LOG(INFO) << "Learning Rate Policy: " << param_.lr_policy();
-  PreSolve();
-
-  iter_ = 0;
-  current_step_ = 0;
-  if (resume_file) {
-    LOG(INFO) << "Restoring previous solver status from " << resume_file;
-    Restore(resume_file);
-  }
-  // Remember the initial iter_ value; will be non-zero if we loaded from a
-  // resume_file above.
+void Solver<Dtype>::Step(int iters) {
+  vector<Blob<Dtype>*> bottom_vec;
   const int start_iter = iter_;
-
+  const int stop_iter = iter_ + iters;
   int average_loss = this->param_.average_loss();
-
-  CHECK_GE(average_loss, 1) << "average_loss should be non-negative.";
-
   vector<Dtype> losses;
   Dtype smoothed_loss = 0;

-  // For a network that is trained by the solver, no bottom or top vecs
-  // should be given, and we will just provide dummy vecs.
-  vector<Blob<Dtype>*> bottom_vec;
-  for (; iter_ < param_.max_iter(); ++iter_) {
-    // Save a snapshot if needed.
-    if (param_.snapshot() && iter_ > start_iter &&
-        iter_ % param_.snapshot() == 0) {
-      Snapshot();
-    }
-
+  for (; iter_ < stop_iter; ++iter_) {
     if (param_.test_interval() && iter_ % param_.test_interval() == 0
         && (iter_ > 0 || param_.test_initialization())) {
       TestAll();
@@ -227,13 +206,36 @@ void Solver<Dtype>::Solve(const char* resume_file) {
         }
       }
     }

     ComputeUpdateValue();
     net_->Update();

+    // Save a snapshot if needed.
+    if (param_.snapshot() && (iter_ + 1) % param_.snapshot() == 0) {
+      Snapshot();
+    }
+  }
+}
+
+template <typename Dtype>
+void Solver<Dtype>::Solve(const char* resume_file) {
+  Caffe::set_phase(Caffe::TRAIN);
+  LOG(INFO) << "Solving " << net_->name();
+  LOG(INFO) << "Learning Rate Policy: " << param_.lr_policy();
+
+  if (resume_file) {
+    LOG(INFO) << "Restoring previous solver status from " << resume_file;
+    Restore(resume_file);
+  }
+
+  // For a network that is trained by the solver, no bottom or top vecs
+  // should be given, and we will just provide dummy vecs.
+  Step(param_.max_iter() - iter_);
-  // Always save a snapshot after optimization, unless overridden by setting
-  // snapshot_after_train := false.
-  if (param_.snapshot_after_train()) { Snapshot(); }
+  // If we haven't already, save a snapshot after optimization, unless
+  // overridden by setting snapshot_after_train := false
+  if (param_.snapshot_after_train()
+      && (!param_.snapshot() || iter_ % param_.snapshot() != 0)) {
+    Snapshot();
+  }
   // After the optimization is done, run an additional train and test pass to
   // display the train and test loss/outputs if appropriate (based on the
   // display and test_interval settings, respectively). Unlike in the rest of
@@ -242,7 +244,7 @@ void Solver<Dtype>::Solve(const char* resume_file) {
   // display the loss, which is computed in the forward pass.
   if (param_.display() && iter_ % param_.display() == 0) {
     Dtype loss;
-    net_->Forward(bottom_vec, &loss);
+    net_->ForwardPrefilled(&loss);
     LOG(INFO) << "Iteration " << iter_ << ", loss = " << loss;
   }
   if (param_.test_interval() && iter_ % param_.test_interval() == 0) {
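
With bottom_vec now local to Step(), the post-training display pass switches to Net::ForwardPrefilled, which runs a forward pass from the net's own pre-filled input layers rather than taking an explicit bottom vector (Forward with an empty bottom vector behaves the same way). A fragment-level sketch, assuming net is a shared_ptr<Net<float> > in TRAIN phase:

```cpp
float loss;
// ForwardPrefilled needs no bottom blobs; the data layers supply the input.
const std::vector<caffe::Blob<float>*>& result = net->ForwardPrefilled(&loss);
```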
@@ -328,14 +330,15 @@ void Solver<Dtype>::Snapshot() {
   string model_filename, snapshot_filename;
   const int kBufferSize = 20;
   char iter_str_buffer[kBufferSize];
-  snprintf(iter_str_buffer, kBufferSize, "_iter_%d", iter_);
+  // Add one to iter_ to get the number of iterations that have completed.
+  snprintf(iter_str_buffer, kBufferSize, "_iter_%d", iter_ + 1);
   filename += iter_str_buffer;
   model_filename = filename + ".caffemodel";
   LOG(INFO) << "Snapshotting to " << model_filename;
   WriteProtoToBinaryFile(net_param, model_filename.c_str());
   SolverState state;
   SnapshotSolverState(&state);
-  state.set_iter(iter_);
+  state.set_iter(iter_ + 1);
   state.set_learned_net(model_filename);
   state.set_current_step(current_step_);
   snapshot_filename = filename + ".solverstate";
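
The snapshot changes shift both the trigger and the name by one: a snapshot is now taken after the update of iteration iter_, at which point iter_ + 1 iterations have completed, and that count is what the filename and the SolverState record. A small illustration of the naming rule — not from the commit, and the prefix is hypothetical:

```cpp
#include <cstdio>
#include <string>

// Reproduces the filename logic above for illustration.
std::string SnapshotName(const std::string& prefix, int iter) {
  char buf[20];
  snprintf(buf, sizeof(buf), "_iter_%d", iter + 1);  // completed iterations
  return prefix + buf + ".caffemodel";
}

// With snapshot: 50 in the solver prototxt, the first snapshot fires at the
// end of iteration iter_ == 49, so SnapshotName("lenet", 49) yields
// "lenet_iter_50.caffemodel".
```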
