Merge pull request #2867 from ronghanghu/rms-prop
RMSProp clean up and rebase
ronghanghu committed Aug 9, 2015
2 parents eb3e114 + abe99e8 commit 698fc76
Showing 6 changed files with 353 additions and 48 deletions.
27 changes: 27 additions & 0 deletions examples/mnist/lenet_solver_rmsprop.prototxt
@@ -0,0 +1,27 @@
# The train/test net protocol buffer definition
net: "examples/mnist/lenet_train_test.prototxt"
# test_iter specifies how many forward passes the test should carry out.
# In the case of MNIST, we have test batch size 100 and 100 test iterations,
# covering the full 10,000 testing images.
test_iter: 100
# Carry out testing every 500 training iterations.
test_interval: 500
# The base learning rate, momentum and the weight decay of the network.
base_lr: 0.01
momentum: 0.0
weight_decay: 0.0005
# The learning rate policy
lr_policy: "inv"
gamma: 0.0001
power: 0.75
# Display every 100 iterations
display: 100
# The maximum number of iterations
max_iter: 10000
# snapshot intermediate results
snapshot: 5000
snapshot_prefix: "examples/mnist/lenet_rmsprop"
# solver mode: CPU or GPU
solver_mode: GPU
solver_type: RMSPROP
rms_decay: 0.98
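For reference, the "inv" policy used here computes lr(iter) = base_lr * (1 + gamma * iter)^(-power) (see the lr_policy comment in caffe.proto below), so the rate starts at 0.01 and decays to 0.01 * (1 + 0.0001 * 10000)^(-0.75) = 0.01 * 2^(-0.75) ≈ 0.0059 by iteration 10000.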
3 changes: 3 additions & 0 deletions examples/mnist/train_lenet_rmsprop.sh
@@ -0,0 +1,3 @@
#!/usr/bin/env sh

./build/tools/caffe train --solver=examples/mnist/lenet_solver_rmsprop.prototxt
25 changes: 25 additions & 0 deletions include/caffe/solver.hpp
@@ -135,6 +135,29 @@ class AdaGradSolver : public SGDSolver<Dtype> {
DISABLE_COPY_AND_ASSIGN(AdaGradSolver);
};


template <typename Dtype>
class RMSPropSolver : public SGDSolver<Dtype> {
public:
explicit RMSPropSolver(const SolverParameter& param)
: SGDSolver<Dtype>(param) { constructor_sanity_check(); }
explicit RMSPropSolver(const string& param_file)
: SGDSolver<Dtype>(param_file) { constructor_sanity_check(); }

protected:
virtual void ComputeUpdateValue(int param_id, Dtype rate);
void constructor_sanity_check() {
CHECK_EQ(0, this->param_.momentum())
<< "Momentum cannot be used with RMSProp.";
CHECK_GE(this->param_.rms_decay(), 0)
<< "rms_decay should lie between 0 and 1.";
CHECK_LT(this->param_.rms_decay(), 1)
<< "rms_decay should lie between 0 and 1.";
}

DISABLE_COPY_AND_ASSIGN(RMSPropSolver);
};

template <typename Dtype>
Solver<Dtype>* GetSolver(const SolverParameter& param) {
SolverParameter_SolverType type = param.solver_type();
@@ -146,6 +169,8 @@ Solver<Dtype>* GetSolver(const SolverParameter& param) {
return new NesterovSolver<Dtype>(param);
case SolverParameter_SolverType_ADAGRAD:
return new AdaGradSolver<Dtype>(param);
case SolverParameter_SolverType_RMSPROP:
return new RMSPropSolver<Dtype>(param);
default:
LOG(FATAL) << "Unknown SolverType: " << type;
}
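For illustration, a minimal driver that exercises the new dispatch path (a sketch modeled on tools/caffe.cpp, not part of this diff; it assumes ReadProtoFromTextFileOrDie from caffe/util/io.hpp, which caffe/caffe.hpp pulls in):

#include "caffe/caffe.hpp"

int main() {
  caffe::SolverParameter param;
  // Parse the solver definition; its solver_type: RMSPROP field routes
  // GetSolver through the new case above.
  caffe::ReadProtoFromTextFileOrDie(
      "examples/mnist/lenet_solver_rmsprop.prototxt", &param);
  boost::shared_ptr<caffe::Solver<float> > solver(
      caffe::GetSolver<float>(param));
  solver->Solve();  // run the optimization to max_iter
  return 0;
}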
25 changes: 23 additions & 2 deletions src/caffe/proto/caffe.proto
@@ -98,7 +98,7 @@ message NetParameter {
// NOTE
// Update the next available ID when you add a new SolverParameter field.
//
// SolverParameter next available ID: 38 (last added: snapshot_format)
// SolverParameter next available ID: 39 (last added: rms_decay)
message SolverParameter {
//////////////////////////////////////////////////////////////////////////////
// Specifying the train and test networks
@@ -153,7 +153,23 @@ message SolverParameter {
optional int32 max_iter = 7; // the maximum number of iterations
// accumulate gradients over `iter_size` x `batch_size` instances
optional int32 iter_size = 36 [default = 1];
optional string lr_policy = 8; // The learning rate decay policy.

// The learning rate decay policy. The currently implemented learning rate
// policies are as follows:
// - fixed: always return base_lr.
// - step: return base_lr * gamma ^ (floor(iter / stepsize))
// - exp: return base_lr * gamma ^ iter
// - inv: return base_lr * (1 + gamma * iter) ^ (- power)
// - multistep: similar to step but it allows non-uniform steps defined by
// stepvalue
// - poly: the effective learning rate follows a polynomial decay, to be
// zero by max_iter: return base_lr * (1 - iter/max_iter) ^ power
// - sigmoid: the effective learning rate follows a sigmoid decay:
// return base_lr * (1 / (1 + exp(-gamma * (iter - stepsize))))
//
// where base_lr, max_iter, gamma, stepsize, stepvalue and power are defined
// in the solver parameter protocol buffer, and iter is the current iteration.
optional string lr_policy = 8;
optional float gamma = 9; // The parameter to compute the learning rate.
optional float power = 10; // The parameter to compute the learning rate.
optional float momentum = 11; // The momentum value.
@@ -198,11 +214,16 @@ message SolverParameter {
SGD = 0;
NESTEROV = 1;
ADAGRAD = 2;
RMSPROP = 3;
}
optional SolverType solver_type = 30 [default = SGD];
// numerical stability for AdaGrad and RMSProp
optional float delta = 31 [default = 1e-8];

// RMSProp decay value
// MeanSquare(t) = rms_decay*MeanSquare(t-1) + (1-rms_decay)*SquareGradient(t)
optional float rms_decay = 38;

// If true, print information about the state of the net that may help with
// debugging learning problems.
optional bool debug_info = 23 [default = false];
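Taken together with the delta field above, the update rule the new solver applies (implemented in src/caffe/solver.cpp below) is

MeanSquare(t) = rms_decay * MeanSquare(t-1) + (1 - rms_decay) * Gradient(t)^2
Update(t) = local_rate * Gradient(t) / (sqrt(MeanSquare(t)) + delta)

where local_rate is the global learning rate scaled by the parameter's lr multiplier; note that delta is added after the square root, matching the order of operations in ComputeUpdateValue.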
76 changes: 76 additions & 0 deletions src/caffe/solver.cpp
@@ -859,9 +859,85 @@ void AdaGradSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
}
}

template <typename Dtype>
void RMSPropSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
const vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_->params();
const vector<float>& net_params_lr = this->net_->params_lr();

// get the RMSProp hyperparameters and the local learning rate
Dtype delta = this->param_.delta();
Dtype rms_decay = this->param_.rms_decay();
Dtype local_rate = rate * net_params_lr[param_id];

switch (Caffe::mode()) {
case Caffe::CPU:
// compute square of gradient in update
caffe_powx(net_params[param_id]->count(),
net_params[param_id]->cpu_diff(), Dtype(2),
this->update_[param_id]->mutable_cpu_data());

// update history
caffe_cpu_axpby(net_params[param_id]->count(),
Dtype(1 - rms_decay), this->update_[param_id]->cpu_data(),
rms_decay, this->history_[param_id]->mutable_cpu_data());

// prepare update
caffe_powx(net_params[param_id]->count(),
this->history_[param_id]->cpu_data(), Dtype(0.5),
this->update_[param_id]->mutable_cpu_data());

caffe_add_scalar(net_params[param_id]->count(),
delta, this->update_[param_id]->mutable_cpu_data());

caffe_div(net_params[param_id]->count(),
net_params[param_id]->cpu_diff(), this->update_[param_id]->cpu_data(),
this->update_[param_id]->mutable_cpu_data());

// scale and copy
caffe_cpu_axpby(net_params[param_id]->count(), local_rate,
this->update_[param_id]->cpu_data(), Dtype(0),
net_params[param_id]->mutable_cpu_diff());
break;
case Caffe::GPU:
#ifndef CPU_ONLY
// compute square of gradient in update
caffe_gpu_powx(net_params[param_id]->count(),
net_params[param_id]->gpu_diff(), Dtype(2),
this->update_[param_id]->mutable_gpu_data());

// update history
caffe_gpu_axpby(net_params[param_id]->count(),
Dtype(1 - rms_decay), this->update_[param_id]->gpu_data(),
rms_decay, this->history_[param_id]->mutable_gpu_data());

// prepare update
caffe_gpu_powx(net_params[param_id]->count(),
this->history_[param_id]->gpu_data(), Dtype(0.5),
this->update_[param_id]->mutable_gpu_data());

caffe_gpu_add_scalar(net_params[param_id]->count(),
delta, this->update_[param_id]->mutable_gpu_data());

caffe_gpu_div(net_params[param_id]->count(),
net_params[param_id]->gpu_diff(), this->update_[param_id]->gpu_data(),
this->update_[param_id]->mutable_gpu_data());

caffe_gpu_axpby(net_params[param_id]->count(), local_rate,
this->update_[param_id]->gpu_data(), Dtype(0),
net_params[param_id]->mutable_gpu_diff());
#else
NO_GPU;
#endif
break;
default:
LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
}
}
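Element-wise, the blob-level calls above reduce to this scalar loop (an illustrative sketch with hypothetical names, not code from the commit):

#include <cmath>

// diff holds the raw gradient on entry and the scaled update on exit;
// history persists the running mean of squared gradients across iterations.
void rmsprop_update(int count, float* diff, float* history,
                    float rms_decay, float delta, float local_rate) {
  for (int i = 0; i < count; ++i) {
    // MeanSquare(t) = rms_decay * MeanSquare(t-1) + (1 - rms_decay) * grad^2
    history[i] = rms_decay * history[i] + (1 - rms_decay) * diff[i] * diff[i];
    // Scale the gradient by 1 / (sqrt(MeanSquare) + delta), then by the rate.
    diff[i] = local_rate * diff[i] / (std::sqrt(history[i]) + delta);
  }
}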

INSTANTIATE_CLASS(Solver);
INSTANTIATE_CLASS(SGDSolver);
INSTANTIATE_CLASS(NesterovSolver);
INSTANTIATE_CLASS(AdaGradSolver);
INSTANTIATE_CLASS(RMSPropSolver);

} // namespace caffe