Skip to content

Commit

Permalink
Merge pull request #2253 from jyegerlehner/snapshot_on_signal
Browse files Browse the repository at this point in the history
Snapshot on signal
  • Loading branch information
ronghanghu committed Aug 22, 2015
2 parents 6232233 + ff19d5f commit 12e1432
Show file tree
Hide file tree
Showing 5 changed files with 268 additions and 10 deletions.
37 changes: 36 additions & 1 deletion include/caffe/solver.hpp
Original file line number Diff line number Diff line change
@@ -1,13 +1,35 @@
#ifndef CAFFE_OPTIMIZATION_SOLVER_HPP_
#define CAFFE_OPTIMIZATION_SOLVER_HPP_

#include <boost/function.hpp>
#include <string>
#include <vector>

#include "caffe/net.hpp"

namespace caffe {

/**
* @brief Enumeration of actions that a client of the Solver may request by
* implementing the Solver's action request function, which a client
* may optionally provide in order to request early termination
* or saving a snapshot without exiting. In the executable caffe, this
* mechanism is used to allow the snapshot to be saved when stopping
* execution with a SIGINT (Ctrl-C).
*
* The enum is wrapped in its own namespace so the enumerators are spelled
* SolverAction::NONE, SolverAction::STOP and SolverAction::SNAPSHOT, and the
* type itself is SolverAction::Enum.
*/
namespace SolverAction {
enum Enum {
NONE = 0, // Take no special action.
STOP = 1, // Stop training. snapshot_after_train controls whether a
// snapshot is created.
SNAPSHOT = 2 // Take a snapshot, and keep training.
};
}

/**
* @brief Type of a function that returns a Solver Action enumeration.
*
* The callable takes no arguments and returns the SolverAction::Enum value
* that the client wants carried out; SolverAction::NONE means no request.
*/
typedef boost::function<SolverAction::Enum()> ActionCallback;

/**
* @brief An interface for classes that perform optimization on Net%s.
*
Expand All @@ -23,6 +45,12 @@ class Solver {
void Init(const SolverParameter& param);
void InitTrainNet();
void InitTestNets();

// Client of the Solver optionally may call this in order to set the function
// that the solver uses to see what action it should take (e.g. snapshot or
// exit training early).
void SetActionFunction(ActionCallback func);
SolverAction::Enum GetRequestedAction();
// The main entry of the solver function. By default, iter will be zero. Pass
// in a non-zero iter number to resume training for a pre-trained net.
virtual void Solve(const char* resume_file = NULL);
Expand Down Expand Up @@ -84,6 +112,13 @@ class Solver {
// in data parallelism
const Solver* const root_solver_;

// A function that can be set by a client of the Solver to provide indication
// that it wants a snapshot saved and/or to exit early.
ActionCallback action_request_function_;

// True iff a request to stop early was received.
bool requested_early_exit_;

DISABLE_COPY_AND_ASSIGN(Solver);
};

Expand Down
24 changes: 24 additions & 0 deletions include/caffe/util/signal_handler.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#ifndef INCLUDE_CAFFE_UTIL_SIGNAL_HANDLER_H_
#define INCLUDE_CAFFE_UTIL_SIGNAL_HANDLER_H_

#include "caffe/proto/caffe.pb.h"
#include "caffe/solver.hpp"

namespace caffe {

// Translates SIGINT and SIGHUP into SolverAction requests.
//
// Constructing an instance installs the signal handlers and destroying it
// restores the default handlers. The implementation aborts with a fatal log
// message if handlers are installed while they are already hooked up, so only
// one instance may be alive at a time.
class SignalHandler {
public:
// Constructor. Specify what action to take when a signal is received.
SignalHandler(SolverAction::Enum SIGINT_action,
SolverAction::Enum SIGHUP_action);
// Restores the default handling of SIGINT and SIGHUP.
~SignalHandler();
// Returns a callback that reports the action for any signal received since
// the previous call. The callback is bound to this object, so it must not be
// invoked after this object has been destroyed.
ActionCallback GetActionFunction();
private:
// Returns the configured action for a pending signal (SIGHUP is checked
// before SIGINT), or SolverAction::NONE if no signal is pending.
SolverAction::Enum CheckForSignals() const;
// Action reported when a SIGINT has been received.
SolverAction::Enum SIGINT_action_;
// Action reported when a SIGHUP has been received.
SolverAction::Enum SIGHUP_action_;
};

} // namespace caffe

#endif // INCLUDE_CAFFE_UTIL_SIGNAL_HANDLER_H_
70 changes: 62 additions & 8 deletions src/caffe/solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,31 @@

namespace caffe {

// Stores the client-supplied callback that GetRequestedAction() consults to
// find out whether a snapshot or an early stop is being requested. Any
// previously stored callback is replaced.
template <typename Dtype>
void Solver<Dtype>::SetActionFunction(ActionCallback func) {
  this->action_request_function_ = func;
}

// Asks the client callback, if one has been set, which action it currently
// requests. Returns SolverAction::NONE when no callback has been set.
template <typename Dtype>
SolverAction::Enum Solver<Dtype>::GetRequestedAction() {
  // Without an external request function there is never anything to do.
  if (!action_request_function_) {
    return SolverAction::NONE;
  }
  return action_request_function_();
}

// Constructs a solver from an already-parsed SolverParameter.
//
// param:       solver configuration, forwarded to Init().
// root_solver: stored in root_solver_.
//
// requested_early_exit_ starts out false: no stop request has been received
// yet. Only one member-initializer list is allowed, so the superseded list
// that did not initialize requested_early_exit_ has been dropped.
template <typename Dtype>
Solver<Dtype>::Solver(const SolverParameter& param, const Solver* root_solver)
    : net_(), callbacks_(), root_solver_(root_solver),
      requested_early_exit_(false) {
  Init(param);
}

template <typename Dtype>
Solver<Dtype>::Solver(const string& param_file, const Solver* root_solver)
: net_(), callbacks_(), root_solver_(root_solver) {
: net_(), callbacks_(), root_solver_(root_solver),
requested_early_exit_(false) {
SolverParameter param;
ReadProtoFromTextFileOrDie(param_file, &param);
Init(param);
Expand Down Expand Up @@ -195,6 +211,10 @@ void Solver<Dtype>::Step(int iters) {
&& (iter_ > 0 || param_.test_initialization())
&& Caffe::root_solver()) {
TestAll();
if (requested_early_exit_) {
// Break out of the while loop because stop was requested while testing.
break;
}
}

for (int i = 0; i < callbacks_.size(); ++i) {
Expand Down Expand Up @@ -250,12 +270,20 @@ void Solver<Dtype>::Step(int iters) {
// the number of times the weights have been updated.
++iter_;

SolverAction::Enum request = GetRequestedAction();

// Save a snapshot if needed.
if (param_.snapshot()
&& iter_ % param_.snapshot() == 0
&& Caffe::root_solver()) {
if ((param_.snapshot()
&& iter_ % param_.snapshot() == 0
&& Caffe::root_solver()) ||
(request == SolverAction::SNAPSHOT)) {
Snapshot();
}
if (SolverAction::STOP == request) {
requested_early_exit_ = true;
// Break out of training loop.
break;
}
}
}

Expand All @@ -265,6 +293,9 @@ void Solver<Dtype>::Solve(const char* resume_file) {
LOG(INFO) << "Solving " << net_->name();
LOG(INFO) << "Learning Rate Policy: " << param_.lr_policy();

// Initialize to false every time we start solving.
requested_early_exit_ = false;

if (resume_file) {
LOG(INFO) << "Restoring previous solver status from " << resume_file;
Restore(resume_file);
Expand All @@ -279,6 +310,10 @@ void Solver<Dtype>::Solve(const char* resume_file) {
&& (!param_.snapshot() || iter_ % param_.snapshot() != 0)) {
Snapshot();
}
if (requested_early_exit_) {
LOG(INFO) << "Optimization stopped early.";
return;
}
// After the optimization is done, run an additional train and test pass to
// display the train and test loss/outputs if appropriate (based on the
// display and test_interval settings, respectively). Unlike in the rest of
Expand All @@ -296,10 +331,11 @@ void Solver<Dtype>::Solve(const char* resume_file) {
LOG(INFO) << "Optimization Done.";
}


// Runs Test() on every test net in index order.
//
// The loop condition re-checks requested_early_exit_ before each net, so once
// a stop has been requested the remaining test nets are skipped. Only the
// loop header that carries this check is kept; the superseded header without
// it left the function with an unbalanced brace.
template <typename Dtype>
void Solver<Dtype>::TestAll() {
  for (int test_net_id = 0;
       test_net_id < test_nets_.size() && !requested_early_exit_;
       ++test_net_id) {
    Test(test_net_id);
  }
}
Expand All @@ -317,6 +353,21 @@ void Solver<Dtype>::Test(const int test_net_id) {
const shared_ptr<Net<Dtype> >& test_net = test_nets_[test_net_id];
Dtype loss = 0;
for (int i = 0; i < param_.test_iter(test_net_id); ++i) {
SolverAction::Enum request = GetRequestedAction();
// Check to see if stoppage of testing/training has been requested.
while (request != SolverAction::NONE) {
if (SolverAction::SNAPSHOT == request) {
Snapshot();
} else if (SolverAction::STOP == request) {
requested_early_exit_ = true;
}
request = GetRequestedAction();
}
if (requested_early_exit_) {
// break out of test loop.
break;
}

Dtype iter_loss;
const vector<Blob<Dtype>*>& result =
test_net->Forward(bottom_vec, &iter_loss);
Expand All @@ -341,6 +392,10 @@ void Solver<Dtype>::Test(const int test_net_id) {
}
}
}
if (requested_early_exit_) {
LOG(INFO) << "Test interrupted.";
return;
}
if (param_.test_compute_loss()) {
loss /= param_.test_iter(test_net_id);
LOG(INFO) << "Test loss: " << loss;
Expand All @@ -361,7 +416,6 @@ void Solver<Dtype>::Test(const int test_net_id) {
}
}


template <typename Dtype>
void Solver<Dtype>::Snapshot() {
CHECK(Caffe::root_solver());
Expand Down
115 changes: 115 additions & 0 deletions src/caffe/util/signal_handler.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
#include <boost/bind.hpp>
#include <glog/logging.h>

#include <signal.h>
#include <csignal>

#include "caffe/util/signal_handler.h"

namespace {
// Flags raised by handle_signal() and consumed by GotSIGINT()/GotSIGHUP().
// They are volatile sig_atomic_t because they are written from inside a
// signal handler.
static volatile sig_atomic_t got_sigint = false;
static volatile sig_atomic_t got_sighup = false;
// True while our handlers are installed; guards against hooking up twice.
static bool already_hooked_up = false;

// Signal handler: records which of the two signals of interest arrived.
// Any other signal number is ignored.
void handle_signal(int signal) {
  if (signal == SIGHUP) {
    got_sighup = true;
  } else if (signal == SIGINT) {
    got_sigint = true;
  }
}

// Installs handle_signal as the handler for SIGHUP and SIGINT.
// Logs a fatal error if the handlers are already installed or if either
// installation fails.
void HookupHandler() {
  if (already_hooked_up) {
    LOG(FATAL) << "Tried to hookup signal handlers more than once.";
  }
  already_hooked_up = true;

  struct sigaction action;
  // Block every signal while the handler runs.
  sigfillset(&action.sa_mask);
  // Restart interrupted system calls, if at all possible.
  action.sa_flags = SA_RESTART;
  // Route the signal to our flag-setting handler.
  action.sa_handler = &handle_signal;

  // SIGHUP first, then SIGINT.
  const int sighup_status = sigaction(SIGHUP, &action, NULL);
  if (sighup_status == -1) {
    LOG(FATAL) << "Cannot install SIGHUP handler.";
  }
  const int sigint_status = sigaction(SIGINT, &action, NULL);
  if (sigint_status == -1) {
    LOG(FATAL) << "Cannot install SIGINT handler.";
  }
}

// Restores the default handling of SIGHUP and SIGINT. Does nothing if our
// handlers are not currently installed. Logs a fatal error if either signal
// cannot be reset.
void UnhookHandler() {
  if (!already_hooked_up) {
    return;
  }

  struct sigaction default_action;
  // Hand the signal back to the system default handler.
  default_action.sa_handler = SIG_DFL;
  // Same flags and mask as used when the handlers were installed.
  default_action.sa_flags = SA_RESTART;
  sigfillset(&default_action.sa_mask);

  // Reset SIGHUP first, then SIGINT.
  const int sighup_status = sigaction(SIGHUP, &default_action, NULL);
  if (sighup_status == -1) {
    LOG(FATAL) << "Cannot uninstall SIGHUP handler.";
  }
  const int sigint_status = sigaction(SIGINT, &default_action, NULL);
  if (sigint_status == -1) {
    LOG(FATAL) << "Cannot uninstall SIGINT handler.";
  }

  already_hooked_up = false;
}

// Reports whether a SIGINT arrived since the previous call, and clears the
// pending flag so the same signal is not reported twice.
bool GotSIGINT() {
  const bool was_pending = (got_sigint != 0);
  got_sigint = false;
  return was_pending;
}

// Reports whether a SIGHUP arrived since the previous call, and clears the
// pending flag so the same signal is not reported twice.
bool GotSIGHUP() {
  const bool was_pending = (got_sighup != 0);
  got_sighup = false;
  return was_pending;
}
} // namespace

namespace caffe {

// Remembers which action each signal maps to, then installs the SIGHUP and
// SIGINT handlers. Installing while handlers are already hooked up is a
// fatal error (see HookupHandler).
SignalHandler::SignalHandler(SolverAction::Enum SIGINT_action,
                             SolverAction::Enum SIGHUP_action)
    : SIGINT_action_(SIGINT_action), SIGHUP_action_(SIGHUP_action) {
  HookupHandler();
}

// Puts SIGHUP and SIGINT back to their default handlers if ours are
// currently installed.
SignalHandler::~SignalHandler() {
UnhookHandler();
}

// Maps a pending signal to the action configured for it.
//
// SIGHUP is polled first. If it is pending, SIGINT is not polled on this
// call, so a pending SIGINT stays pending until a later call. Returns
// SolverAction::NONE when neither signal has been received.
SolverAction::Enum SignalHandler::CheckForSignals() const {
  SolverAction::Enum action = SolverAction::NONE;
  if (GotSIGHUP()) {
    action = SIGHUP_action_;
  } else if (GotSIGINT()) {
    action = SIGINT_action_;
  }
  return action;
}

// Builds the callback the solver polls to learn whether a snapshot or an
// early exit is being requested. The callback is bound to this object, so it
// must not be invoked after this SignalHandler has been destroyed.
ActionCallback SignalHandler::GetActionFunction() {
  ActionCallback poll_signals =
      boost::bind(&SignalHandler::CheckForSignals, this);
  return poll_signals;
}

} // namespace caffe
Loading

0 comments on commit 12e1432

Please sign in to comment.