Skip to content

Commit

Permalink
Add signal handler and early exit/snapshot to Solver.
Browse files Browse the repository at this point in the history
Add signal handler and early exit/snapshot to Solver.

Add signal handler and early exit/snapshot to Solver.

Also check for exit and snapshot when testing.

Skip running test after early exit.

Fix more lint.

Rebase on master.
  • Loading branch information
jyegerlehner committed May 18, 2015
1 parent 352aef4 commit 17baa33
Show file tree
Hide file tree
Showing 6 changed files with 237 additions and 11 deletions.
22 changes: 21 additions & 1 deletion include/caffe/solver.hpp
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
#ifndef CAFFE_OPTIMIZATION_SOLVER_HPP_
#define CAFFE_OPTIMIZATION_SOLVER_HPP_

#include <boost/function.hpp>
#include <string>
#include <vector>

#include "caffe/net.hpp"

namespace caffe {

/**
* @brief Type of a function that returns a Solver Action enumeration.
*
* The Solver invokes the callback, when one has been set, to find out whether
* its client is requesting a snapshot or an early stop; the callback takes no
* arguments and returns the requested SolverParameter_Action.
*/
typedef boost::function<SolverParameter_Action()> ActionCallback;

/**
* @brief An interface for classes that perform optimization on Net%s.
*
Expand All @@ -22,6 +27,12 @@ class Solver {
void Init(const SolverParameter& param);
void InitTrainNet();
void InitTestNets();

// Client of the Solver optionally may call this in order to set the function
// that the solver uses to see what action it should take (e.g. snapshot or
// exit training early).
void SetActionFunction(ActionCallback func);
SolverParameter_Action GetRequestedAction();
// The main entry of the solver function. In default, iter will be zero. Pass
// in a non-zero iter number to resume training for a pre-trained net.
virtual void Solve(const char* resume_file = NULL);
Expand All @@ -47,6 +58,8 @@ class Solver {
// written to disk together with the learned net.
void Snapshot();
// The test routine
// requested_early_exit_ will be set to true iff a request to stop training
// was received whilst testing.
void TestAll();
void Test(const int test_net_id = 0);
virtual void SnapshotSolverState(SolverState* state) = 0;
Expand All @@ -59,6 +72,13 @@ class Solver {
shared_ptr<Net<Dtype> > net_;
vector<shared_ptr<Net<Dtype> > > test_nets_;

// A function that can be set by a client of the Solver to provide indication
// that it wants a snapshot saved and/or to exit early.
ActionCallback action_request_function_;

// True iff a request to stop early was received.
bool requested_early_exit_;

DISABLE_COPY_AND_ASSIGN(Solver);
};

Expand Down
24 changes: 24 additions & 0 deletions include/caffe/util/signal_handler.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#ifndef INCLUDE_CAFFE_UTIL_SIGNAL_HANDLER_H_
#define INCLUDE_CAFFE_UTIL_SIGNAL_HANDLER_H_

#include "caffe/proto/caffe.pb.h"
#include "caffe/solver.hpp"

namespace caffe {

// Translates SIGINT / SIGHUP signals received by the process into Solver
// actions. Constructing an instance installs the signal handlers; the
// implementation aborts with a fatal log message if handlers are installed
// more than once.
class SignalHandler {
public:
// Constructor. Specify what action to take when a signal is received.
// SIGINT_action is reported after a SIGINT, SIGHUP_action after a SIGHUP.
SignalHandler(SolverParameter_Action SIGINT_action,
SolverParameter_Action SIGHUP_action);
// Returns a callback, bound to this object, that the Solver can poll to
// learn which action (if any) is currently being requested.
ActionCallback GetActionFunction();
private:
SignalHandler(); // Not implemented.
// Returns the action configured for a signal received since the previous
// check, or SolverParameter_Action_NONE when no signal is pending.
SolverParameter_Action CheckForSignals() const;
// Action reported when a SIGINT has been received.
SolverParameter_Action SIGINT_action_;
// Action reported when a SIGHUP has been received.
SolverParameter_Action SIGHUP_action_;
};

} // namespace caffe

#endif // INCLUDE_CAFFE_UTIL_SIGNAL_HANDLER_H_
13 changes: 13 additions & 0 deletions src/caffe/proto/caffe.proto
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,19 @@ message SolverParameter {

// If false, don't save a snapshot after training finishes.
optional bool snapshot_after_train = 28 [default = true];

// Enumeration of actions that a client of the Solver may request by
// implementing the Solver's action request function, which a client
// may optionally provide in order to request early termination
// or saving a snapshot without exiting. In the executable caffe, this
// mechanism is used to allow the snapshot to be saved when stopping
// execution with a SIGINT (Ctrl-C).
enum Action {
NONE = 0; // Take no special action.
STOP = 1; // Stop training. snapshot_after_train controls whether a snapshot
// is created.
SNAPSHOT = 2; // Take a snapshot, and keep training.
}
}

// A message that stores the solver snapshots
Expand Down
70 changes: 61 additions & 9 deletions src/caffe/solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,31 @@

namespace caffe {

// Registers the callback that the solver polls to learn whether its client
// wants a snapshot or an early stop. Overwrites any callback set earlier.
template<typename Dtype>
void Solver<Dtype>::SetActionFunction(ActionCallback func) {
  this->action_request_function_ = func;
}

// Asks the client-supplied callback which action is being requested.
// When no callback has been registered, reports SolverParameter_Action_NONE.
template<typename Dtype>
SolverParameter_Action Solver<Dtype>::GetRequestedAction() {
  return action_request_function_
      ? action_request_function_()
      : SolverParameter_Action_NONE;
}

template <typename Dtype>
Solver<Dtype>::Solver(const SolverParameter& param)
: net_() {
: net_(),
requested_early_exit_(false) {
Init(param);
}

template <typename Dtype>
Solver<Dtype>::Solver(const string& param_file)
: net_() {
: net_(),
requested_early_exit_(false) {
SolverParameter param;
ReadProtoFromTextFileOrDie(param_file, &param);
Init(param);
Expand Down Expand Up @@ -167,10 +183,14 @@ void Solver<Dtype>::Step(int iters) {
vector<Dtype> losses;
Dtype smoothed_loss = 0;

while (iter_ < stop_iter) {
while (iter_ < stop_iter) {
if (param_.test_interval() && iter_ % param_.test_interval() == 0
&& (iter_ > 0 || param_.test_initialization())) {
TestAll();
if (requested_early_exit_) {
// Break out of the while loop because stop was requested while testing.
break;
}
}

const bool display = param_.display() && iter_ % param_.display() == 0;
Expand Down Expand Up @@ -214,10 +234,18 @@ void Solver<Dtype>::Step(int iters) {
// the number of times the weights have been updated.
++iter_;

SolverParameter_Action request = GetRequestedAction();

// Save a snapshot if needed.
if (param_.snapshot() && iter_ % param_.snapshot() == 0) {
if ((param_.snapshot() && iter_ % param_.snapshot() == 0) ||
(request == SolverParameter_Action_SNAPSHOT)) {
Snapshot();
}
if (SolverParameter_Action_STOP == request) {
requested_early_exit_ = true;
// Break out of training loop.
break;
}
}
}

Expand All @@ -226,6 +254,9 @@ void Solver<Dtype>::Solve(const char* resume_file) {
LOG(INFO) << "Solving " << net_->name();
LOG(INFO) << "Learning Rate Policy: " << param_.lr_policy();

// Initialize to false every time we start solving.
requested_early_exit_ = false;

if (resume_file) {
LOG(INFO) << "Restoring previous solver status from " << resume_file;
Restore(resume_file);
Expand All @@ -240,6 +271,10 @@ void Solver<Dtype>::Solve(const char* resume_file) {
&& (!param_.snapshot() || iter_ % param_.snapshot() != 0)) {
Snapshot();
}
if (requested_early_exit_) {
LOG(INFO) << "Optimization stopped early.";
return;
}
// After the optimization is done, run an additional train and test pass to
// display the train and test loss/outputs if appropriate (based on the
// display and test_interval settings, respectively). Unlike in the rest of
Expand All @@ -257,10 +292,11 @@ void Solver<Dtype>::Solve(const char* resume_file) {
LOG(INFO) << "Optimization Done.";
}


template <typename Dtype>
void Solver<Dtype>::TestAll() {
for (int test_net_id = 0; test_net_id < test_nets_.size(); ++test_net_id) {
for (int test_net_id = 0;
test_net_id < test_nets_.size() && !requested_early_exit_;
++test_net_id) {
Test(test_net_id);
}
}
Expand All @@ -277,6 +313,21 @@ void Solver<Dtype>::Test(const int test_net_id) {
const shared_ptr<Net<Dtype> >& test_net = test_nets_[test_net_id];
Dtype loss = 0;
for (int i = 0; i < param_.test_iter(test_net_id); ++i) {
SolverParameter_Action request = GetRequestedAction();
// Check to see if stoppage of testing/training has been requested.
while (request != SolverParameter_Action_NONE) {
if (SolverParameter_Action_SNAPSHOT == request) {
Snapshot();
} else if (SolverParameter_Action_STOP == request) {
requested_early_exit_ = true;
}
request = GetRequestedAction();
}
if (requested_early_exit_) {
// break out of test loop.
break;
}

Dtype iter_loss;
const vector<Blob<Dtype>*>& result =
test_net->Forward(bottom_vec, &iter_loss);
Expand All @@ -301,6 +352,10 @@ void Solver<Dtype>::Test(const int test_net_id) {
}
}
}
if (requested_early_exit_) {
LOG(INFO) << "Test interrupted.";
return;
}
if (param_.test_compute_loss()) {
loss /= param_.test_iter(test_net_id);
LOG(INFO) << "Test loss: " << loss;
Expand All @@ -321,7 +376,6 @@ void Solver<Dtype>::Test(const int test_net_id) {
}
}


template <typename Dtype>
void Solver<Dtype>::Snapshot() {
NetParameter net_param;
Expand Down Expand Up @@ -360,7 +414,6 @@ void Solver<Dtype>::Restore(const char* state_file) {
RestoreSolverState(state);
}


// Return the current learning rate. The currently implemented learning rate
// policies are as follows:
// - fixed: always return base_lr.
Expand Down Expand Up @@ -826,5 +879,4 @@ INSTANTIATE_CLASS(Solver);
INSTANTIATE_CLASS(SGDSolver);
INSTANTIATE_CLASS(NesterovSolver);
INSTANTIATE_CLASS(AdaGradSolver);

} // namespace caffe
91 changes: 91 additions & 0 deletions src/caffe/util/signal_handler.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
#include <boost/bind.hpp>
#include <boost/thread/mutex.hpp>
#include <boost/thread/thread.hpp>
#include <glog/logging.h>

#include <signal.h>
#include <csignal>

#include "caffe/util/signal_handler.h"

namespace {
// Set to true by handle_signal when a SIGINT arrives; reset by GotSIGINT.
static volatile sig_atomic_t got_sigint = false;
// Set to true by handle_signal when a SIGHUP arrives; reset by GotSIGHUP.
static volatile sig_atomic_t got_sighup = false;
// True once HookupHandler has run; used to reject a second installation.
static bool already_hooked_up = false;

// Signal handler installed for SIGHUP and SIGINT. It only records that the
// signal arrived by raising the matching flag; any other signal number is
// ignored.
void handle_signal(int signal) {
  if (signal == SIGHUP) {
    got_sighup = true;
  } else if (signal == SIGINT) {
    got_sigint = true;
  }
}

// Installs handle_signal as the handler for SIGHUP and SIGINT.
// May be called only once: a second call logs a fatal error. Failure to
// install either handler is also fatal.
void HookupHandler() {
  if (already_hooked_up) {
    LOG(FATAL) << "Tried to hookup signal handlers more than once.";
  }
  already_hooked_up = true;

  // Value-initialize the struct so that no member (including any
  // platform-specific ones not set below) is passed to sigaction() with an
  // indeterminate value.
  struct sigaction sa = {};
  // Setup the SIGHUP/SIGINT handler
  sa.sa_handler = &handle_signal;
  // Restart the system call, if at all possible
  sa.sa_flags = SA_RESTART;
  // Block every signal during the handler
  sigfillset(&sa.sa_mask);
  // Intercept SIGHUP and SIGINT
  if (sigaction(SIGHUP, &sa, NULL) == -1) {
    LOG(FATAL) << "Cannot install SIGHUP handler.";
  }
  if (sigaction(SIGINT, &sa, NULL) == -1) {
    LOG(FATAL) << "Cannot install SIGINT handler.";
  }
}

// Return true iff a SIGINT has been received since the last time this
// function was called.
//
// The flag is cleared only when it was observed set. Clearing it
// unconditionally would discard a SIGINT delivered between reading the flag
// (as false) and resetting it, silently losing the request.
bool GotSIGINT() {
  if (!got_sigint) {
    return false;
  }
  got_sigint = false;
  return true;
}

// Return true iff a SIGHUP has been received since the last time this
// function was called.
//
// The flag is cleared only when it was observed set. Clearing it
// unconditionally would discard a SIGHUP delivered between reading the flag
// (as false) and resetting it, silently losing the request.
bool GotSIGHUP() {
  if (!got_sighup) {
    return false;
  }
  got_sighup = false;
  return true;
}
} // namespace

namespace caffe {

// Remembers which action to report for each signal, then installs the
// SIGINT/SIGHUP handlers via HookupHandler().
SignalHandler::SignalHandler(SolverParameter_Action SIGINT_action,
                             SolverParameter_Action SIGHUP_action)
    : SIGINT_action_(SIGINT_action), SIGHUP_action_(SIGHUP_action) {
  HookupHandler();
}

// Maps a pending signal to its configured action. SIGHUP is polled first;
// when it is pending, SIGINT is not polled in this call and so stays pending
// for the next one. With nothing pending the result is
// SolverParameter_Action_NONE.
SolverParameter_Action SignalHandler::CheckForSignals() const {
  SolverParameter_Action requested = SolverParameter_Action_NONE;
  if (GotSIGHUP()) {
    requested = SIGHUP_action_;
  } else if (GotSIGINT()) {
    requested = SIGINT_action_;
  }
  return requested;
}

// Builds the callback the solver polls to discover whether a snapshot or an
// early exit is being requested. The callback is bound to this object's
// pointer and forwards to CheckForSignals().
ActionCallback SignalHandler::GetActionFunction() {
  ActionCallback poll_signals =
      boost::bind(&SignalHandler::CheckForSignals, this);
  return poll_signals;
}

} // namespace caffe
Loading

0 comments on commit 17baa33

Please sign in to comment.