From abe99e8748ad7f583c87d1a6132ff2d79e70dd9c Mon Sep 17 00:00:00 2001
From: Eren Golge
Date: Sat, 8 Aug 2015 23:45:08 -0700
Subject: [PATCH] Implement RMSProp Solver

Implement the RMSProp solver and clean it up to match the new solver
interface, which uses accumulated gradients and refactored regularization.
---
 examples/mnist/lenet_solver_rmsprop.prototxt  |  27 ++
 examples/mnist/train_lenet_rmsprop.sh         |   3 +
 include/caffe/solver.hpp                      |  25 ++
 src/caffe/proto/caffe.proto                   |  25 +-
 src/caffe/solver.cpp                          |  76 ++++++
 src/caffe/test/test_gradient_based_solver.cpp | 245 ++++++++++++++----
 6 files changed, 353 insertions(+), 48 deletions(-)
 create mode 100644 examples/mnist/lenet_solver_rmsprop.prototxt
 create mode 100755 examples/mnist/train_lenet_rmsprop.sh

diff --git a/examples/mnist/lenet_solver_rmsprop.prototxt b/examples/mnist/lenet_solver_rmsprop.prototxt
new file mode 100644
index 00000000000..74dadc51069
--- /dev/null
+++ b/examples/mnist/lenet_solver_rmsprop.prototxt
@@ -0,0 +1,27 @@
+# The train/test net protocol buffer definition
+net: "examples/mnist/lenet_train_test.prototxt"
+# test_iter specifies how many forward passes the test should carry out.
+# In the case of MNIST, we have test batch size 100 and 100 test iterations,
+# covering the full 10,000 testing images.
+test_iter: 100
+# Carry out testing every 500 training iterations.
+test_interval: 500
+# The base learning rate, momentum and the weight decay of the network.
+base_lr: 0.01
+momentum: 0.0
+weight_decay: 0.0005
+# The learning rate policy
+lr_policy: "inv"
+gamma: 0.0001
+power: 0.75
+# Display every 100 iterations
+display: 100
+# The maximum number of iterations
+max_iter: 10000
+# snapshot intermediate results
+snapshot: 5000
+snapshot_prefix: "examples/mnist/lenet_rmsprop"
+# solver mode: CPU or GPU
+solver_mode: GPU
+solver_type: RMSPROP
+rms_decay: 0.98
diff --git a/examples/mnist/train_lenet_rmsprop.sh b/examples/mnist/train_lenet_rmsprop.sh
new file mode 100755
index 00000000000..621cab238bf
--- /dev/null
+++ b/examples/mnist/train_lenet_rmsprop.sh
@@ -0,0 +1,3 @@
+#!/usr/bin/env sh
+
+./build/tools/caffe train --solver=examples/mnist/lenet_solver_rmsprop.prototxt
diff --git a/include/caffe/solver.hpp b/include/caffe/solver.hpp
index 703434b5fcf..fbade9389ff 100644
--- a/include/caffe/solver.hpp
+++ b/include/caffe/solver.hpp
@@ -135,6 +135,29 @@ class AdaGradSolver : public SGDSolver<Dtype> {
   DISABLE_COPY_AND_ASSIGN(AdaGradSolver);
 };
 
+
+template <typename Dtype>
+class RMSPropSolver : public SGDSolver<Dtype> {
+ public:
+  explicit RMSPropSolver(const SolverParameter& param)
+      : SGDSolver<Dtype>(param) { constructor_sanity_check(); }
+  explicit RMSPropSolver(const string& param_file)
+      : SGDSolver<Dtype>(param_file) { constructor_sanity_check(); }
+
+ protected:
+  virtual void ComputeUpdateValue(int param_id, Dtype rate);
+  void constructor_sanity_check() {
+    CHECK_EQ(0, this->param_.momentum())
+        << "Momentum cannot be used with RMSProp.";
+    CHECK_GE(this->param_.rms_decay(), 0)
+        << "rms_decay should lie between 0 and 1.";
+    CHECK_LT(this->param_.rms_decay(), 1)
+        << "rms_decay should lie between 0 and 1.";
+  }
+
+  DISABLE_COPY_AND_ASSIGN(RMSPropSolver);
+};
+
 template <typename Dtype>
 Solver<Dtype>* GetSolver(const SolverParameter& param) {
   SolverParameter_SolverType type = param.solver_type();
@@ -146,6 +169,8 @@ Solver<Dtype>* GetSolver(const SolverParameter& param) {
     return new NesterovSolver<Dtype>(param);
   case SolverParameter_SolverType_ADAGRAD:
     return new AdaGradSolver<Dtype>(param);
+  case SolverParameter_SolverType_RMSPROP:
+    return new RMSPropSolver<Dtype>(param);
   default:
     LOG(FATAL) << "Unknown SolverType: " << type;
   }
diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto
index a13c0e79d80..89f14595ba6 100644
--- a/src/caffe/proto/caffe.proto
+++ b/src/caffe/proto/caffe.proto
@@ -98,7 +98,7 @@ message NetParameter {
 // NOTE
 // Update the next available ID when you add a new SolverParameter field.
 //
-// SolverParameter next available ID: 38 (last added: snapshot_format)
+// SolverParameter next available ID: 39 (last added: rms_decay)
 message SolverParameter {
   //////////////////////////////////////////////////////////////////////////////
   // Specifying the train and test networks
@@ -153,7 +153,23 @@ message SolverParameter {
   optional int32 max_iter = 7; // the maximum number of iterations
   // accumulate gradients over `iter_size` x `batch_size` instances
   optional int32 iter_size = 36 [default = 1];
-  optional string lr_policy = 8; // The learning rate decay policy.
+
+  // The learning rate decay policy. The currently implemented learning rate
+  // policies are as follows:
+  //    - fixed: always return base_lr.
+  //    - step: return base_lr * gamma ^ (floor(iter / step))
+  //    - exp: return base_lr * gamma ^ iter
+  //    - inv: return base_lr * (1 + gamma * iter) ^ (- power)
+  //    - multistep: similar to step but it allows non-uniform steps defined by
+  //      stepvalue
+  //    - poly: the effective learning rate follows a polynomial decay, to be
+  //      zero by the max_iter. return base_lr (1 - iter/max_iter) ^ (power)
+  //    - sigmoid: the effective learning rate follows a sigmoid decay
+  //      return base_lr ( 1/(1 + exp(-gamma * (iter - stepsize))))
+  //
+  // where base_lr, max_iter, gamma, step, stepvalue and power are defined
+  // in the solver parameter protocol buffer, and iter is the current iteration.
+  optional string lr_policy = 8;
   optional float gamma = 9; // The parameter to compute the learning rate.
   optional float power = 10; // The parameter to compute the learning rate.
   optional float momentum = 11; // The momentum value.
@@ -198,11 +214,16 @@ message SolverParameter {
     SGD = 0;
     NESTEROV = 1;
     ADAGRAD = 2;
+    RMSPROP = 3;
   }
   optional SolverType solver_type = 30 [default = SGD];
   // numerical stability for AdaGrad
   optional float delta = 31 [default = 1e-8];
 
+  // RMSProp decay value
+  // MeanSquare(t) = rms_decay*MeanSquare(t-1) + (1-rms_decay)*SquareGradient(t)
+  optional float rms_decay = 38;
+
   // If true, print information about the state of the net that may help with
   // debugging learning problems.
   optional bool debug_info = 23 [default = false];
diff --git a/src/caffe/solver.cpp b/src/caffe/solver.cpp
index 32276ac148a..43834c0c569 100644
--- a/src/caffe/solver.cpp
+++ b/src/caffe/solver.cpp
@@ -859,9 +859,85 @@ void AdaGradSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
   }
 }
 
+template <typename Dtype>
+void RMSPropSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
+  const vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_->params();
+  const vector<float>& net_params_lr = this->net_->params_lr();
+
+  // get the learning rate
+  Dtype delta = this->param_.delta();
+  Dtype rms_decay = this->param_.rms_decay();
+  Dtype local_rate = rate * net_params_lr[param_id];
+
+  switch (Caffe::mode()) {
+  case Caffe::CPU:
+    // compute square of gradient in update
+    caffe_powx(net_params[param_id]->count(),
+        net_params[param_id]->cpu_diff(), Dtype(2),
+        this->update_[param_id]->mutable_cpu_data());
+
+    // update history
+    caffe_cpu_axpby(net_params[param_id]->count(),
+        Dtype(1 - rms_decay), this->update_[param_id]->cpu_data(),
+        rms_decay, this->history_[param_id]->mutable_cpu_data());
+
+    // prepare update
+    caffe_powx(net_params[param_id]->count(),
+        this->history_[param_id]->cpu_data(), Dtype(0.5),
+        this->update_[param_id]->mutable_cpu_data());
+
+    caffe_add_scalar(net_params[param_id]->count(),
+        delta, this->update_[param_id]->mutable_cpu_data());
+
+    caffe_div(net_params[param_id]->count(),
+        net_params[param_id]->cpu_diff(), this->update_[param_id]->cpu_data(),
+        this->update_[param_id]->mutable_cpu_data());
+
+    // scale and copy
+    caffe_cpu_axpby(net_params[param_id]->count(), local_rate,
+        this->update_[param_id]->cpu_data(), Dtype(0),
+        net_params[param_id]->mutable_cpu_diff());
+    break;
+  case Caffe::GPU:
+#ifndef CPU_ONLY
+    // compute square of gradient in update
+    caffe_gpu_powx(net_params[param_id]->count(),
+        net_params[param_id]->gpu_diff(), Dtype(2),
+        this->update_[param_id]->mutable_gpu_data());
+
+    // update history
+    caffe_gpu_axpby(net_params[param_id]->count(),
+        Dtype(1 - rms_decay), this->update_[param_id]->gpu_data(),
+        rms_decay, this->history_[param_id]->mutable_gpu_data());
+
+    // prepare update
+    caffe_gpu_powx(net_params[param_id]->count(),
+        this->history_[param_id]->gpu_data(), Dtype(0.5),
+        this->update_[param_id]->mutable_gpu_data());
+
+    caffe_gpu_add_scalar(net_params[param_id]->count(),
+        delta, this->update_[param_id]->mutable_gpu_data());
+
+    caffe_gpu_div(net_params[param_id]->count(),
+        net_params[param_id]->gpu_diff(), this->update_[param_id]->gpu_data(),
+        this->update_[param_id]->mutable_gpu_data());
+
+    caffe_gpu_axpby(net_params[param_id]->count(), local_rate,
+        this->update_[param_id]->gpu_data(), Dtype(0),
+        net_params[param_id]->mutable_gpu_diff());
+#else
+    NO_GPU;
+#endif
+    break;
+  default:
+    LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
+  }
+}
+
 INSTANTIATE_CLASS(Solver);
 INSTANTIATE_CLASS(SGDSolver);
 INSTANTIATE_CLASS(NesterovSolver);
 INSTANTIATE_CLASS(AdaGradSolver);
+INSTANTIATE_CLASS(RMSPropSolver);
 
 }  // namespace caffe
diff --git a/src/caffe/test/test_gradient_based_solver.cpp b/src/caffe/test/test_gradient_based_solver.cpp
index 7bb0ec18a09..b09189228ba 100644
--- a/src/caffe/test/test_gradient_based_solver.cpp
+++ b/src/caffe/test/test_gradient_based_solver.cpp
@@ -52,13 +52,14 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
         LOG(FATAL) << "Unknown Caffe mode: " << Caffe::mode();
     }
     InitSolver(param);
-    delta_ = (solver_type() == SolverParameter_SolverType_ADAGRAD) ?
-        param.delta() : 0;
+    delta_ = (solver_type() == SolverParameter_SolverType_ADAGRAD ||
+        solver_type() == SolverParameter_SolverType_RMSPROP) ?
+        param.delta() : 0;
   }
 
   string RunLeastSquaresSolver(const Dtype learning_rate,
-      const Dtype weight_decay, const Dtype momentum, const int num_iters,
-      const int iter_size = 1, const bool snapshot = false,
+      const Dtype weight_decay, const Dtype momentum, const Dtype rms_decay,
+      const int num_iters, const int iter_size = 1, const bool snapshot = false,
       const char* from_snapshot = NULL) {
     ostringstream proto;
     proto <<
@@ -173,6 +174,9 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
     if (momentum != 0) {
       proto << "momentum: " << momentum << " ";
     }
+    if (rms_decay != 0) {
+      proto << "rms_decay: " << rms_decay << " ";
+    }
     MakeTempDir(&snapshot_prefix_);
     proto << "snapshot_prefix: '" << snapshot_prefix_ << "/' ";
     if (snapshot) {
@@ -204,7 +208,7 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
   // updated_params will store the updated weight and bias results,
   // using the blobs' diffs to hold the update values themselves.
   void ComputeLeastSquaresUpdate(const Dtype learning_rate,
-      const Dtype weight_decay, const Dtype momentum,
+      const Dtype weight_decay, const Dtype momentum, const Dtype rms_decay,
       vector<shared_ptr<Blob<Dtype> > >* updated_params) {
     const int N = num_;
     const int D = channels_ * height_ * width_;
@@ -287,6 +291,10 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
         case SolverParameter_SolverType_ADAGRAD:
          update_value /= std::sqrt(history_value + grad * grad) + delta_;
          break;
+        case SolverParameter_SolverType_RMSPROP:
+          update_value /= std::sqrt(rms_decay * history_value
+              + grad * grad * (1 - rms_decay)) + delta_;
+          break;
        default:
          LOG(FATAL) << "Unknown solver type: " << solver_type();
      }
@@ -352,13 +360,14 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
   }
 
   void CheckAccumulation(const Dtype kLearningRate, const Dtype kWeightDecay,
-      const Dtype kMomentum, const int kNumIters, const int kIterSize) {
+      const Dtype kMomentum, const Dtype rms_decay, const int kNumIters,
+      const int kIterSize) {
     const double kPrecision = 1e-2;
     const double kMinPrecision = 1e-7;
     constant_data_ = true;
     // Solve without accumulation and save parameters.
     this->RunLeastSquaresSolver(kLearningRate, kWeightDecay, kMomentum,
-        kNumIters);
+        rms_decay, kNumIters);
     // Save parameters for comparison.
     Net<Dtype>& net = *this->solver_->net();
     const vector<shared_ptr<Blob<Dtype> > >& param_blobs =
@@ -370,7 +379,7 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
     }
     // Solve by equivalent accumulation of gradients over divided batches.
     this->RunLeastSquaresSolver(kLearningRate, kWeightDecay, kMomentum,
-        kNumIters, kIterSize);
+        rms_decay, kNumIters, kIterSize);
     Net<Dtype>& net_accum = *this->solver_->net();
     const vector<shared_ptr<Blob<Dtype> > >& accum_params =
         net_accum.layer_by_name("innerprod")->blobs();
@@ -408,18 +417,19 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
   // matches the solver's (K+1)th update.
   void TestLeastSquaresUpdate(const Dtype learning_rate = 1.0,
       const Dtype weight_decay = 0.0, const Dtype momentum = 0.0,
-      const int iter_to_check = 0) {
+      const Dtype rms_decay = 0.0, const int iter_to_check = 0) {
     // Initialize the solver and run K (= iter_to_check) solver iterations.
-    RunLeastSquaresSolver(learning_rate, weight_decay, momentum, iter_to_check);
+    RunLeastSquaresSolver(learning_rate, weight_decay, momentum, rms_decay,
+        iter_to_check);
 
     // Compute the (K+1)th update using the analytic least squares gradient.
     vector<shared_ptr<Blob<Dtype> > > updated_params;
     ComputeLeastSquaresUpdate(learning_rate, weight_decay, momentum,
-        &updated_params);
+        rms_decay, &updated_params);
 
     // Reinitialize the solver and run K+1 solver iterations.
-    RunLeastSquaresSolver(learning_rate, weight_decay, momentum,
-        iter_to_check + 1);
+    RunLeastSquaresSolver(learning_rate, weight_decay, momentum, rms_decay,
+        iter_to_check + 1);
 
     // Check that the solver's solution matches ours.
     CheckLeastSquaresUpdate(updated_params);
   }
@@ -427,12 +437,12 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
 
   void TestSnapshot(const Dtype learning_rate = 1.0,
       const Dtype weight_decay = 0.0, const Dtype momentum = 0.0,
-      const int num_iters = 1) {
+      const Dtype rms_decay = 0.0, const int num_iters = 1) {
     // Run the solver for num_iters * 2 iterations.
     const int total_num_iters = num_iters * 2;
     bool snapshot = false;
     const int kIterSize = 1;
-    RunLeastSquaresSolver(learning_rate, weight_decay, momentum,
+    RunLeastSquaresSolver(learning_rate, weight_decay, momentum, rms_decay,
         total_num_iters, kIterSize, snapshot);
 
     // Save the resulting param values.
@@ -463,12 +473,12 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
     // Run the solver for num_iters iterations and snapshot.
     snapshot = true;
     string snapshot_name = RunLeastSquaresSolver(learning_rate, weight_decay,
-        momentum, num_iters, kIterSize, snapshot);
+        momentum, rms_decay, num_iters, kIterSize, snapshot);
 
     // Reinitialize the solver and run for num_iters more iterations.
     snapshot = false;
-    RunLeastSquaresSolver(learning_rate, weight_decay,
-        momentum, total_num_iters, kIterSize, snapshot, snapshot_name.c_str());
+    RunLeastSquaresSolver(learning_rate, weight_decay, momentum, rms_decay,
+        total_num_iters, kIterSize, snapshot, snapshot_name.c_str());
 
     // Check that params now match.
     const vector<Blob<Dtype>*>& params = solver_->net()->learnable_params();
@@ -548,9 +558,11 @@ TYPED_TEST(SGDSolverTest, TestLeastSquaresUpdateWithMomentum) {
   const Dtype kLearningRate = 0.01;
   const Dtype kWeightDecay = 0;
   const Dtype kMomentum = 0.5;
+  const Dtype kRMSDecay = 0;
   const int kNumIters = 1;
   for (int i = 0; i <= kNumIters; ++i) {
-    this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
+    this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum,
+        kRMSDecay, i);
   }
 }
 
@@ -559,9 +571,11 @@ TYPED_TEST(SGDSolverTest, TestLeastSquaresUpdateWithMomentumMultiIter) {
   const Dtype kLearningRate = 0.01;
   const Dtype kWeightDecay = 0;
   const Dtype kMomentum = 0.5;
+  const Dtype kRMSDecay = 0;
   const int kNumIters = 4;
   for (int i = 0; i <= kNumIters; ++i) {
-    this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
+    this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum,
+        kRMSDecay, i);
   }
 }
 
@@ -570,9 +584,11 @@ TYPED_TEST(SGDSolverTest, TestLeastSquaresUpdateWithEverything) {
   const Dtype kLearningRate = 0.01;
   const Dtype kWeightDecay = 0.5;
   const Dtype kMomentum = 0.5;
+  const Dtype kRMSDecay = 0;
   const int kNumIters = 4;
   for (int i = 0; i <= kNumIters; ++i) {
-    this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
+    this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum,
+        kRMSDecay, i);
   }
 }
 
@@ -581,10 +597,12 @@ TYPED_TEST(SGDSolverTest, TestLeastSquaresUpdateWithEverythingShare) {
   const Dtype kLearningRate = 0.01;
   const Dtype kWeightDecay = 0.5;
   const Dtype kMomentum = 0.5;
+  const Dtype kRMSDecay = 0;
   const int kNumIters = 4;
   this->share_ = true;
   for (int i = 0; i <= kNumIters; ++i) {
-    this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
+    this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum,
+        kRMSDecay, i);
   }
 }
 
@@ -593,10 +611,11 @@ TYPED_TEST(SGDSolverTest, TestLeastSquaresUpdateWithEverythingAccum) {
   const Dtype kLearningRate = 0.01;
   const Dtype kWeightDecay = 0.5;
   const Dtype kMomentum = 0.9;
+  const Dtype kRMSDecay = 0;
   const int kNumIters = 4;
   const int kIterSize = 2;
-  this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters,
-      kIterSize);
+  this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kRMSDecay,
+      kNumIters, kIterSize);
 }
 
@@ -604,11 +623,12 @@ TYPED_TEST(SGDSolverTest, TestLeastSquaresUpdateWithEverythingAccumShare) {
   const Dtype kLearningRate = 0.01;
   const Dtype kWeightDecay = 0.5;
   const Dtype kMomentum = 0.9;
+  const Dtype kRMSDecay = 0;
   const int kNumIters = 4;
   const int kIterSize = 2;
   this->share_ = true;
-  this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters,
-      kIterSize);
+  this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kRMSDecay,
+      kNumIters, kIterSize);
 }
 
@@ -616,9 +636,10 @@ TYPED_TEST(SGDSolverTest, TestSnapshot) {
   const Dtype kLearningRate = 0.01;
   const Dtype kWeightDecay = 0.5;
   const Dtype kMomentum = 0.9;
+  const Dtype kRMSDecay = 0;
   const int kNumIters = 4;
   for (int i = 1; i <= kNumIters; ++i) {
-    this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i);
+    this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, kRMSDecay, i);
   }
 }
 
@@ -627,10 +648,11 @@ TYPED_TEST(SGDSolverTest, TestSnapshotShare) {
   const Dtype kLearningRate = 0.01;
   const Dtype kWeightDecay = 0.5;
   const Dtype kMomentum = 0.9;
+  const Dtype kRMSDecay = 0;
   const int kNumIters = 4;
   this->share_ = true;
   for (int i = 1; i <= kNumIters; ++i) {
-    this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i);
+    this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, kRMSDecay, i);
   }
 }
 
@@ -672,22 +694,26 @@ TYPED_TEST(AdaGradSolverTest, TestAdaGradLeastSquaresUpdateWithEverything) {
   const Dtype kLearningRate = 0.01;
   const Dtype kWeightDecay = 0.5;
   const Dtype kMomentum = 0;
+  const Dtype kRMSDecay = 0;
   const int kNumIters = 4;
   for (int i = 0; i <= kNumIters; ++i) {
-    this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
+    this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum,
+        kRMSDecay, i);
   }
 }
 
 TYPED_TEST(AdaGradSolverTest,
-  TestAdaGradLeastSquaresUpdateWithEverythingShare) {
+    TestAdaGradLeastSquaresUpdateWithEverythingShare) {
   typedef typename TypeParam::Dtype Dtype;
   const Dtype kLearningRate = 0.01;
   const Dtype kWeightDecay = 0.5;
   const Dtype kMomentum = 0;
+  const Dtype kRMSDecay = 0;
   const int kNumIters = 4;
   this->share_ = true;
   for (int i = 0; i <= kNumIters; ++i) {
-    this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
+    this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum,
+        kRMSDecay, i);
   }
 }
 
@@ -696,10 +722,11 @@ TYPED_TEST(AdaGradSolverTest, TestLeastSquaresUpdateWithEverythingAccum) {
   const Dtype kLearningRate = 0.01;
   const Dtype kWeightDecay = 0.5;
   const Dtype kMomentum = 0;
+  const Dtype kRMSDecay = 0;
   const int kNumIters = 4;
   const int kIterSize = 2;
-  this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters,
-      kIterSize);
+  this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kRMSDecay,
+      kNumIters, kIterSize);
 }
 
@@ -707,11 +734,12 @@ TYPED_TEST(AdaGradSolverTest, TestLeastSquaresUpdateWithEverythingAccumShare) {
   const Dtype kLearningRate = 0.01;
   const Dtype kWeightDecay = 0.5;
   const Dtype kMomentum = 0;
+  const Dtype kRMSDecay = 0;
   const int kNumIters = 4;
   const int kIterSize = 2;
   this->share_ = true;
-  this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters,
-      kIterSize);
+  this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kRMSDecay,
+      kNumIters, kIterSize);
 }
 
@@ -719,9 +747,10 @@ TYPED_TEST(AdaGradSolverTest, TestSnapshot) {
   const Dtype kLearningRate = 0.01;
   const Dtype kWeightDecay = 0.5;
   const Dtype kMomentum = 0;
+  const Dtype kRMSDecay = 0;
   const int kNumIters = 4;
   for (int i = 1; i <= kNumIters; ++i) {
-    this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i);
+    this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, kRMSDecay, i);
   }
 }
 
@@ -730,10 +759,11 @@ TYPED_TEST(AdaGradSolverTest, TestSnapshotShare) {
   const Dtype kLearningRate = 0.01;
   const Dtype kWeightDecay = 0.5;
   const Dtype kMomentum = 0;
+  const Dtype kRMSDecay = 0;
   const int kNumIters = 4;
   this->share_ = true;
   for (int i = 1; i <= kNumIters; ++i) {
-    this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i);
+    this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, kRMSDecay, i);
   }
 }
 
@@ -787,9 +817,11 @@ TYPED_TEST(NesterovSolverTest, TestNesterovLeastSquaresUpdateWithMomentum) {
   const Dtype kLearningRate = 0.01;
   const Dtype kWeightDecay = 0;
   const Dtype kMomentum = 0.5;
+  const Dtype kRMSDecay = 0;
   const int kNumIters = 1;
   for (int i = 0; i <= kNumIters; ++i) {
-    this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
+    this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum,
+        kRMSDecay, i);
   }
 }
 
@@ -798,9 +830,11 @@ TYPED_TEST(NesterovSolverTest, TestLeastSquaresUpdateWithMomentumMultiIter) {
   const Dtype kLearningRate = 0.01;
   const Dtype kWeightDecay = 0;
   const Dtype kMomentum = 0.5;
+  const Dtype kRMSDecay = 0;
   const int kNumIters = 4;
   for (int i = 0; i <= kNumIters; ++i) {
-    this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
+    this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum,
+        kRMSDecay, i);
   }
 }
 
@@ -821,10 +855,12 @@ TYPED_TEST(NesterovSolverTest,
   const Dtype kLearningRate = 0.01;
   const Dtype kWeightDecay = 0.5;
   const Dtype kMomentum = 0.9;
+  const Dtype kRMSDecay = 0;
   const int kNumIters = 4;
   this->share_ = true;
   for (int i = 0; i <= kNumIters; ++i) {
-    this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
+    this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum,
+        kRMSDecay, i);
   }
 }
 
@@ -833,10 +869,11 @@ TYPED_TEST(NesterovSolverTest, TestLeastSquaresUpdateWithEverythingAccum) {
   const Dtype kLearningRate = 0.01;
   const Dtype kWeightDecay = 0.5;
   const Dtype kMomentum = 0.9;
+  const Dtype kRMSDecay = 0;
   const int kNumIters = 4;
   const int kIterSize = 2;
-  this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters,
-      kIterSize);
+  this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kRMSDecay,
+      kNumIters, kIterSize);
 }
 
@@ -844,11 +881,12 @@ TYPED_TEST(NesterovSolverTest, TestLeastSquaresUpdateWithEverythingAccumShare) {
   const Dtype kLearningRate = 0.01;
   const Dtype kWeightDecay = 0.5;
   const Dtype kMomentum = 0.9;
+  const Dtype kRMSDecay = 0;
   const int kNumIters = 4;
   const int kIterSize = 2;
   this->share_ = true;
-  this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters,
-      kIterSize);
+  this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kRMSDecay,
+      kNumIters, kIterSize);
 }
 
@@ -856,9 +894,10 @@ TYPED_TEST(NesterovSolverTest, TestSnapshot) {
   const Dtype kLearningRate = 0.01;
   const Dtype kWeightDecay = 0.5;
   const Dtype kMomentum = 0.9;
+  const Dtype kRMSDecay = 0;
   const int kNumIters = 4;
   for (int i = 1; i <= kNumIters; ++i) {
-    this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i);
+    this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, kRMSDecay, i);
  }
 }
 
@@ -867,10 +906,124 @@ TYPED_TEST(NesterovSolverTest, TestSnapshotShare) {
   const Dtype kLearningRate = 0.01;
   const Dtype kWeightDecay = 0.5;
   const Dtype kMomentum = 0.9;
+  const Dtype kRMSDecay = 0;
+  const int kNumIters = 4;
+  this->share_ = true;
+  for (int i = 1; i <= kNumIters; ++i) {
+    this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, kRMSDecay, i);
+  }
+}
+
+template <typename TypeParam>
+class RMSPropSolverTest : public GradientBasedSolverTest<TypeParam> {
+  typedef typename TypeParam::Dtype Dtype;
+
+ protected:
+  virtual void InitSolver(const SolverParameter& param) {
+    this->solver_.reset(new RMSPropSolver<Dtype>(param));
+  }
+  virtual SolverParameter_SolverType solver_type() {
+    return SolverParameter_SolverType_RMSPROP;
+  }
+};
+
+TYPED_TEST_CASE(RMSPropSolverTest, TestDtypesAndDevices);
+
+TYPED_TEST(RMSPropSolverTest, TestRMSPropLeastSquaresUpdateWithWeightDecay) {
+  typedef typename TypeParam::Dtype Dtype;
+  const Dtype kLearningRate = 1.0;
+  const Dtype kWeightDecay = 0.5;
+  this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay);
+}
+
+TYPED_TEST(RMSPropSolverTest, TestRMSPropLeastSquaresUpdateWithRmsDecay) {
+  typedef typename TypeParam::Dtype Dtype;
+  const Dtype kLearningRate = 0.01;
+  const Dtype kWeightDecay = 0.0;
+  const Dtype kMomentum = 0.0;
+  const Dtype kRMSDecay = 0.95;
+  const int kNumIters = 4;
+  for (int i = 0; i <= kNumIters; ++i) {
+    this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum,
+        kRMSDecay, i);
+  }
+}
+
+TYPED_TEST(RMSPropSolverTest, TestRMSPropLeastSquaresUpdateWithEverything) {
+  typedef typename TypeParam::Dtype Dtype;
+  const Dtype kLearningRate = 0.01;
+  const Dtype kWeightDecay = 0.5;
+  const Dtype kMomentum = 0.0;
+  const Dtype kRMSDecay = 0.95;
+  const int kNumIters = 4;
+  for (int i = 0; i <= kNumIters; ++i) {
+    this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum,
+        kRMSDecay, i);
+  }
+}
+
+TYPED_TEST(RMSPropSolverTest,
+    TestRMSPropLeastSquaresUpdateWithEverythingShare) {
+  typedef typename TypeParam::Dtype Dtype;
+  const Dtype kLearningRate = 0.01;
+  const Dtype kWeightDecay = 0.5;
+  const Dtype kMomentum = 0.0;
+  const Dtype kRMSDecay = 0.95;
+  const int kNumIters = 4;
+  this->share_ = true;
+  for (int i = 0; i <= kNumIters; ++i) {
+    this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum,
+        kRMSDecay, i);
+  }
+}
+
+TYPED_TEST(RMSPropSolverTest, TestLeastSquaresUpdateWithEverythingAccum) {
+  typedef typename TypeParam::Dtype Dtype;
+  const Dtype kLearningRate = 0.01;
+  const Dtype kWeightDecay = 0.5;
+  const Dtype kMomentum = 0.0;
+  const Dtype kRMSDecay = 0.95;
+  const int kNumIters = 4;
+  const int kIterSize = 2;
+  this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kRMSDecay,
+      kNumIters, kIterSize);
+}
+
+TYPED_TEST(RMSPropSolverTest, TestLeastSquaresUpdateWithEverythingAccumShare) {
+  typedef typename TypeParam::Dtype Dtype;
+  const Dtype kLearningRate = 0.01;
+  const Dtype kWeightDecay = 0.5;
+  const Dtype kMomentum = 0.0;
+  const Dtype kRMSDecay = 0.95;
+  const int kNumIters = 4;
+  const int kIterSize = 2;
+  this->share_ = true;
+  this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kRMSDecay,
+      kNumIters, kIterSize);
+}
+
+TYPED_TEST(RMSPropSolverTest, TestSnapshot) {
+  typedef typename TypeParam::Dtype Dtype;
+  const Dtype kLearningRate = 0.01;
+  const Dtype kWeightDecay = 0.5;
+  const Dtype kMomentum = 0;
+  const Dtype kRMSDecay = 0.95;
+  const int kNumIters = 4;
+  for (int i = 1; i <= kNumIters; ++i) {
+    this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, kRMSDecay, i);
+  }
+}
+
+TYPED_TEST(RMSPropSolverTest, TestSnapshotShare) {
+  typedef typename TypeParam::Dtype Dtype;
+  const Dtype kLearningRate = 0.01;
+  const Dtype kWeightDecay = 0.5;
+  const Dtype kMomentum = 0;
+  const Dtype kRMSDecay = 0.95;
   const int kNumIters = 4;
   this->share_ = true;
   for (int i = 1; i <= kNumIters; ++i) {
-    this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i);
+    this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, kRMSDecay, i);
   }
 }
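
Note on the update rule implemented by ComputeUpdateValue above: with
history_[param_id] holding the running mean of squared gradients, each step
computes

  MeanSquare(t) = rms_decay * MeanSquare(t-1) + (1 - rms_decay) * grad(t)^2
  update(t) = local_rate * grad(t) / (sqrt(MeanSquare(t)) + delta)

and writes update(t) back into the blob's diff, which the solver then
subtracts from the weights. The following is a minimal standalone C++ sketch
of that CPU path; the function and variable names (rmsprop_step, history) are
illustrative only and not part of Caffe:

  #include <cmath>
  #include <cstdio>
  #include <vector>

  // One RMSProp step over a parameter vector. `diff` holds the gradient on
  // entry and the final update value on exit, mirroring how the solver
  // rewrites each blob's diff in place.
  void rmsprop_step(std::vector<double>* history, std::vector<double>* diff,
                    double rms_decay, double delta, double local_rate) {
    for (size_t i = 0; i < diff->size(); ++i) {
      const double grad = (*diff)[i];
      // MeanSquare(t) = rms_decay * MeanSquare(t-1) + (1 - rms_decay) * grad^2
      (*history)[i] = rms_decay * (*history)[i]
          + (1.0 - rms_decay) * grad * grad;
      // update = local_rate * grad / (sqrt(MeanSquare(t)) + delta)
      (*diff)[i] = local_rate * grad / (std::sqrt((*history)[i]) + delta);
    }
  }

  int main() {
    std::vector<double> history(3, 0.0);
    std::vector<double> grad(3);
    grad[0] = 0.5; grad[1] = -0.2; grad[2] = 0.1;
    rmsprop_step(&history, &grad, 0.98, 1e-8, 0.01);
    for (size_t i = 0; i < grad.size(); ++i) {
      std::printf("update[%d] = %g\n", static_cast<int>(i), grad[i]);
    }
    return 0;
  }

On the first iteration the history starts at zero, so for gradients well
above delta the step magnitude is approximately
local_rate / sqrt(1 - rms_decay) regardless of the gradient's size; with
rms_decay = 0.98 that is about 7x the learning rate, which is one reason the
MNIST example pairs it with a modest base_lr of 0.01.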
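
Note on the example schedule: lenet_solver_rmsprop.prototxt uses the "inv"
policy documented in caffe.proto above, so the effective learning rate at
iteration iter is base_lr * (1 + gamma * iter) ^ (-power). With
base_lr = 0.01, gamma = 0.0001 and power = 0.75, the rate starts at 0.01 and
decays smoothly to 0.01 * (1 + 0.0001 * 10000) ^ (-0.75) = 0.01 * 2 ^ (-0.75),
roughly 0.0059, by max_iter = 10000.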