diff --git a/include/caffe/solver.hpp b/include/caffe/solver.hpp index ab12ef1b1bd..33297d73efe 100644 --- a/include/caffe/solver.hpp +++ b/include/caffe/solver.hpp @@ -1,6 +1,6 @@ #ifndef CAFFE_OPTIMIZATION_SOLVER_HPP_ #define CAFFE_OPTIMIZATION_SOLVER_HPP_ - +#include #include #include @@ -8,6 +8,11 @@ namespace caffe { +/** + * @brief Type of a function that returns a Solver Action enumeration. + */ +typedef boost::function ActionCallback; + /** * @brief An interface for classes that perform optimization on Net%s. * @@ -23,6 +28,12 @@ class Solver { void Init(const SolverParameter& param); void InitTrainNet(); void InitTestNets(); + + // Client of the Solver optionally may call this in order to set the function + // that the solver uses to see what action it should take (e.g. snapshot or + // exit training early). + void SetActionFunction(ActionCallback func); + SolverParameter_Action GetRequestedAction(); // The main entry of the solver function. In default, iter will be zero. Pass // in a non-zero iter number to resume training for a pre-trained net. virtual void Solve(const char* resume_file = NULL); @@ -66,6 +77,8 @@ class Solver { string SnapshotToBinaryProto(); string SnapshotToHDF5(); // The test routine + // requested_early_exit_ will be set to true iff a request to stop training + // was received whilst testing. void TestAll(); void Test(const int test_net_id = 0); virtual void SnapshotSolverState(const string& model_filename) = 0; @@ -84,6 +97,13 @@ class Solver { // in data parallelism const Solver* const root_solver_; + // A function that can be set by a client of the Solver to provide indication + // that it wants a snapshot saved and/or to exit early. + ActionCallback action_request_function_; + + // True iff a request to stop early was received. 
+ bool requested_early_exit_; + DISABLE_COPY_AND_ASSIGN(Solver); }; diff --git a/include/caffe/util/signal_handler.h b/include/caffe/util/signal_handler.h new file mode 100644 index 00000000000..81286aa7700 --- /dev/null +++ b/include/caffe/util/signal_handler.h @@ -0,0 +1,24 @@ +#ifndef INCLUDE_CAFFE_UTIL_SIGNAL_HANDLER_H_ +#define INCLUDE_CAFFE_UTIL_SIGNAL_HANDLER_H_ + +#include "caffe/proto/caffe.pb.h" +#include "caffe/solver.hpp" + +namespace caffe { + +class SignalHandler { + public: + // Constructor. Specify what action to take when a signal is received. + SignalHandler(SolverParameter_Action SIGINT_action, + SolverParameter_Action SIGHUP_action); + ActionCallback GetActionFunction(); + private: + SignalHandler(); // Not implemented. + SolverParameter_Action CheckForSignals() const; + SolverParameter_Action SIGINT_action_; + SolverParameter_Action SIGHUP_action_; +}; + +} // namespace caffe + +#endif // INCLUDE_CAFFE_UTIL_SIGNAL_HANDLER_H_ diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto index d4c97d2bd06..75638a3388b 100644 --- a/src/caffe/proto/caffe.proto +++ b/src/caffe/proto/caffe.proto @@ -234,6 +234,19 @@ message SolverParameter { // If false, don't save a snapshot after training finishes. optional bool snapshot_after_train = 28 [default = true]; + + // Enumeration of actions that a client of the Solver may request by + // implementing the Solver's action request function, which + // a client may optionally provide in order to request early termination + // or saving a snapshot without exiting. In the executable caffe, this + // mechanism is used to allow the snapshot to be saved when stopping + // execution with a SIGINT (Ctrl-C). + enum Action { + NONE = 0; // Take no special action. + STOP = 1; // Stop training. snapshot_after_train controls whether a snapshot + // is created. + SNAPSHOT = 2; // Take a snapshot, and keep training. 
+ } } // A message that stores the solver snapshots diff --git a/src/caffe/solver.cpp b/src/caffe/solver.cpp index 9348e11c249..b7dbdd1fe2d 100644 --- a/src/caffe/solver.cpp +++ b/src/caffe/solver.cpp @@ -17,15 +17,31 @@ namespace caffe { +template +void Solver::SetActionFunction(ActionCallback func) { + action_request_function_ = func; +} + +template +SolverParameter_Action Solver::GetRequestedAction() { + if (action_request_function_) { + // If the external request function has been set, call it. + return action_request_function_(); + } + return SolverParameter_Action_NONE; +} + template -Solver::Solver(const SolverParameter& param, const Solver* root_solver) - : net_(), callbacks_(), root_solver_(root_solver) { +Solver::Solver(const SolverParameter& param) + : net_(), callbacks_(), root_solver_(root_solver), + requested_early_exit_(false) { Init(param); } template -Solver::Solver(const string& param_file, const Solver* root_solver) - : net_(), callbacks_(), root_solver_(root_solver) { +Solver::Solver(const string& param_file) + : net_(), callbacks_(), root_solver_(root_solver), + requested_early_exit_(false) { SolverParameter param; ReadProtoFromTextFileOrDie(param_file, ¶m); Init(param); @@ -195,6 +211,10 @@ void Solver::Step(int iters) { && (iter_ > 0 || param_.test_initialization()) && Caffe::root_solver()) { TestAll(); + if (requested_early_exit_) { + // Break out of the while loop because stop was requested while testing. + break; + } } for (int i = 0; i < callbacks_.size(); ++i) { @@ -250,12 +270,20 @@ void Solver::Step(int iters) { // the number of times the weights have been updated. ++iter_; + SolverParameter_Action request = GetRequestedAction(); + // Save a snapshot if needed. 
- if (param_.snapshot() - && iter_ % param_.snapshot() == 0 - && Caffe::root_solver()) { + if ((param_.snapshot() + && iter_ % param_.snapshot() == 0 + && Caffe::root_solver()) || + (request == SolverParameter_Action_SNAPSHOT)) { Snapshot(); } + if (SolverParameter_Action_STOP == request) { + requested_early_exit_ = true; + // Break out of training loop. + break; + } } } @@ -265,6 +293,9 @@ void Solver::Solve(const char* resume_file) { LOG(INFO) << "Solving " << net_->name(); LOG(INFO) << "Learning Rate Policy: " << param_.lr_policy(); + // Initialize to false every time we start solving. + requested_early_exit_ = false; + if (resume_file) { LOG(INFO) << "Restoring previous solver status from " << resume_file; Restore(resume_file); @@ -279,6 +310,10 @@ void Solver::Solve(const char* resume_file) { && (!param_.snapshot() || iter_ % param_.snapshot() != 0)) { Snapshot(); } + if (requested_early_exit_) { + LOG(INFO) << "Optimization stopped early."; + return; + } // After the optimization is done, run an additional train and test pass to // display the train and test loss/outputs if appropriate (based on the // display and test_interval settings, respectively). Unlike in the rest of @@ -296,10 +331,11 @@ void Solver::Solve(const char* resume_file) { LOG(INFO) << "Optimization Done."; } - template void Solver::TestAll() { - for (int test_net_id = 0; test_net_id < test_nets_.size(); ++test_net_id) { + for (int test_net_id = 0; + test_net_id < test_nets_.size() && !requested_early_exit_; + ++test_net_id) { Test(test_net_id); } } @@ -317,6 +353,21 @@ void Solver::Test(const int test_net_id) { const shared_ptr >& test_net = test_nets_[test_net_id]; Dtype loss = 0; for (int i = 0; i < param_.test_iter(test_net_id); ++i) { + SolverParameter_Action request = GetRequestedAction(); + // Check to see if stoppage of testing/training has been requested. 
+ while (request != SolverParameter_Action_NONE) { + if (SolverParameter_Action_SNAPSHOT == request) { + Snapshot(); + } else if (SolverParameter_Action_STOP == request) { + requested_early_exit_ = true; + } + request = GetRequestedAction(); + } + if (requested_early_exit_) { + // break out of test loop. + break; + } + Dtype iter_loss; const vector*>& result = test_net->Forward(bottom_vec, &iter_loss); @@ -341,6 +392,10 @@ void Solver::Test(const int test_net_id) { } } } + if (requested_early_exit_) { + LOG(INFO) << "Test interrupted."; + return; + } if (param_.test_compute_loss()) { loss /= param_.test_iter(test_net_id); LOG(INFO) << "Test loss: " << loss; @@ -361,7 +416,6 @@ void Solver::Test(const int test_net_id) { } } - template void Solver::Snapshot() { CHECK(Caffe::root_solver()); diff --git a/src/caffe/util/signal_handler.cpp b/src/caffe/util/signal_handler.cpp new file mode 100644 index 00000000000..9b268d8e7b3 --- /dev/null +++ b/src/caffe/util/signal_handler.cpp @@ -0,0 +1,91 @@ +#include +#include +#include +#include + +#include +#include + +#include "caffe/util/signal_handler.h" + +namespace { + static volatile sig_atomic_t got_sigint = false; + static volatile sig_atomic_t got_sighup = false; + static bool already_hooked_up = false; + + void handle_signal(int signal) { + switch (signal) { + case SIGHUP: + got_sighup = true; + break; + case SIGINT: + got_sigint = true; + break; + } + } + + void HookupHandler() { + if (already_hooked_up) { + LOG(FATAL) << "Tried to hookup signal handlers more than once."; + } + already_hooked_up = true; + + struct sigaction sa; + // Set up the signal handler + sa.sa_handler = &handle_signal; + // Restart the system call, if at all possible + sa.sa_flags = SA_RESTART; + // Block every signal during the handler + sigfillset(&sa.sa_mask); + // Intercept SIGHUP and SIGINT + if (sigaction(SIGHUP, &sa, NULL) == -1) { + LOG(FATAL) << "Cannot install SIGHUP handler."; + } + if (sigaction(SIGINT, &sa, NULL) == -1) { + LOG(FATAL) 
<< "Cannot install SIGINT handler."; + } + } + + // Return true iff a SIGINT has been received since the last time this + // function was called. + bool GotSIGINT() { + bool result = got_sigint; + got_sigint = false; + return result; + } + + // Return true iff a SIGHUP has been received since the last time this + // function was called. + bool GotSIGHUP() { + bool result = got_sighup; + got_sighup = false; + return result; + } +} // namespace + +namespace caffe { + +SignalHandler::SignalHandler(SolverParameter_Action SIGINT_action, + SolverParameter_Action SIGHUP_action): + SIGINT_action_(SIGINT_action), + SIGHUP_action_(SIGHUP_action) { + HookupHandler(); +} + +SolverParameter_Action SignalHandler::CheckForSignals() const { + if (GotSIGHUP()) { + return SIGHUP_action_; + } + if (GotSIGINT()) { + return SIGINT_action_; + } + return SolverParameter_Action_NONE; +} + +// Return the function that the solver can use to find out if a snapshot or +// early exit is being requested. +ActionCallback SignalHandler::GetActionFunction() { + return boost::bind(&SignalHandler::CheckForSignals, this); +} + +} // namespace caffe diff --git a/tools/caffe.cpp b/tools/caffe.cpp index 9f31b37ac2b..0071ea81c15 100644 --- a/tools/caffe.cpp +++ b/tools/caffe.cpp @@ -12,6 +12,7 @@ namespace bp = boost::python; #include "boost/algorithm/string.hpp" #include "caffe/caffe.hpp" +#include "caffe/util/signal_handler.h" using caffe::Blob; using caffe::Caffe; @@ -39,6 +40,12 @@ DEFINE_string(weights, "", "separated by ','. Cannot be set simultaneously with snapshot."); DEFINE_int32(iterations, 50, "The number of iterations to run."); +DEFINE_string(sigint_effect, "stop", + "Optional; action to take when a SIGINT signal is received: " + "snapshot, stop or none."); +DEFINE_string(sighup_effect, "snapshot", + "Optional; action to take when a SIGHUP signal is received: " + "snapshot, stop or none."); // A simple registry for caffe commands. 
typedef int (*BrewFunction)(); @@ -126,6 +133,20 @@ void CopyLayers(caffe::Solver* solver, const std::string& model_list) { } } +caffe::SolverParameter_Action GetRequestedAction( + const std::string& flag_value) { + if (flag_value == "stop") { + return caffe::SolverParameter_Action_STOP; + } + if (flag_value == "snapshot") { + return caffe::SolverParameter_Action_SNAPSHOT; + } + if (flag_value == "none") { + return caffe::SolverParameter_Action_NONE; + } + LOG(FATAL) << "Invalid signal effect \""<< flag_value << "\" was specified"; +} + // Train / Finetune a model. int train() { CHECK_GT(FLAGS_solver.size(), 0) << "Need a solver definition to train."; @@ -165,7 +186,14 @@ int train() { Caffe::set_solver_count(gpus.size()); } - shared_ptr > solver(caffe::GetSolver(solver_param)); + caffe::SignalHandler signal_handler( + GetRequestedAction(FLAGS_sigint_effect), + GetRequestedAction(FLAGS_sighup_effect)); + + shared_ptr > + solver(caffe::GetSolver(solver_param)); + + solver->SetActionFunction(signal_handler.GetActionFunction()); if (FLAGS_snapshot.size()) { LOG(INFO) << "Resuming from " << FLAGS_snapshot;