Skip to content

Commit

Permalink
Clean up RMSprop to be compatible with new solver interface
Browse files Browse the repository at this point in the history
Clean up the RMS prop solver to adjust to new solver interface that uses
accumulated gradients and refactored regularization.
  • Loading branch information
ronghanghu committed Aug 7, 2015
1 parent b6d2d2d commit 430db2b
Show file tree
Hide file tree
Showing 6 changed files with 185 additions and 302 deletions.
6 changes: 2 additions & 4 deletions examples/mnist/lenet_solver_rmsprop.prototxt
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,5 @@ snapshot: 5000
snapshot_prefix: "examples/mnist/lenet_rmsprop"
# solver mode: CPU or GPU
solver_mode: GPU
solver_type:RMSPROP
rms_decay:0.98


solver_type: RMSPROP
rms_decay: 0.98
28 changes: 18 additions & 10 deletions include/caffe/solver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -130,17 +130,25 @@ class AdaGradSolver : public SGDSolver<Dtype> {


template <typename Dtype>
class RMSpropSolver : public SGDSolver<Dtype> {
public:
explicit RMSpropSolver(const SolverParameter& param)
: SGDSolver<Dtype>(param) { }
explicit RMSpropSolver(const string& param_file)
: SGDSolver<Dtype>(param_file) { }
class RMSPropSolver : public SGDSolver<Dtype> {
public:
explicit RMSPropSolver(const SolverParameter& param)
: SGDSolver<Dtype>(param) { constructor_sanity_check(); }
explicit RMSPropSolver(const string& param_file)
: SGDSolver<Dtype>(param_file) { constructor_sanity_check(); }

protected:
virtual void ComputeUpdateValue();
protected:
virtual void ComputeUpdateValue(int param_id, Dtype rate);
void constructor_sanity_check() {
CHECK_EQ(0, this->param_.momentum())
<< "Momentum cannot be used with RMSProp.";
CHECK_GE(this->param_.rms_decay(), 0)
<< "rms_decay should lie between 0 and 1.";
CHECK_LT(this->param_.rms_decay(), 1)
<< "rms_decay should lie between 0 and 1.";
}

DISABLE_COPY_AND_ASSIGN(RMSpropSolver);
DISABLE_COPY_AND_ASSIGN(RMSPropSolver);
};

template <typename Dtype>
Expand All @@ -155,7 +163,7 @@ Solver<Dtype>* GetSolver(const SolverParameter& param) {
case SolverParameter_SolverType_ADAGRAD:
return new AdaGradSolver<Dtype>(param);
case SolverParameter_SolverType_RMSPROP:
return new RMSpropSolver<Dtype>(param);
return new RMSPropSolver<Dtype>(param);
default:
LOG(FATAL) << "Unknown SolverType: " << type;
}
Expand Down
17 changes: 0 additions & 17 deletions python/caffe/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,24 +23,7 @@ class Classifier(caffe.Net):
def __init__(self, model_file, pretrained_file, image_dims=None,
mean=None, input_scale=None, raw_scale=None,
channel_swap=None):
<<<<<<< HEAD
caffe.Net.__init__(self, model_file, pretrained_file, caffe.TEST)
=======
"""
Take
image_dims: dimensions to scale input for cropping/sampling.
Default is to scale to net input size for whole-image crop.
gpu, mean, input_scale, raw_scale, channel_swap: params for
preprocessing options.
"""
caffe.Net.__init__(self, model_file, pretrained_file, caffe.TEST)
caffe.set_phase_test()

if gpu:
caffe.set_mode_gpu()
else:
caffe.set_mode_cpu()
>>>>>>> Implement RMSprop

# configure pre-processing
in_ = self.inputs[0]
Expand Down
31 changes: 22 additions & 9 deletions src/caffe/proto/caffe.proto
Original file line number Diff line number Diff line change
Expand Up @@ -96,11 +96,7 @@ message NetParameter {
// NOTE
// Update the next available ID when you add a new SolverParameter field.
//
<<<<<<< HEAD
// SolverParameter next available ID: 37 (last added: iter_size)
=======
// SolverParameter next available ID: 37 (last added: rms_decay)
>>>>>>> Implement RMSprop
// SolverParameter next available ID: 38 (last added: rms_decay)
message SolverParameter {
//////////////////////////////////////////////////////////////////////////////
// Specifying the train and test networks
Expand Down Expand Up @@ -155,7 +151,23 @@ message SolverParameter {
optional int32 max_iter = 7; // the maximum number of iterations
// accumulate gradients over `iter_size` x `batch_size` instances
optional int32 iter_size = 36 [default = 1];
optional string lr_policy = 8; // The learning rate decay policy.

// The learning rate decay policy. The currently implemented learning rate
// policies are as follows:
// - fixed: always return base_lr.
// - step: return base_lr * gamma ^ (floor(iter / step))
// - exp: return base_lr * gamma ^ iter
// - inv: return base_lr * (1 + gamma * iter) ^ (- power)
// - multistep: similar to step but it allows non uniform steps defined by
// stepvalue
// - poly: the effective learning rate follows a polynomial decay, to be
// zero by the max_iter. return base_lr (1 - iter/max_iter) ^ (power)
// - sigmoid: the effective learning rate follows a sigmod decay
// return base_lr ( 1/(1 + exp(-gamma * (iter - stepsize))))
//
// where base_lr, max_iter, gamma, step, stepvalue and power are defined
// in the solver parameter protocol buffer, and iter is the current iteration.
optional string lr_policy = 8;
optional float gamma = 9; // The parameter to compute the learning rate.
optional float power = 10; // The parameter to compute the learning rate.
optional float momentum = 11; // The momentum value.
Expand Down Expand Up @@ -200,9 +212,10 @@ message SolverParameter {
optional SolverType solver_type = 30 [default = SGD];
// numerical stability for AdaGrad
optional float delta = 31 [default = 1e-8];

//RMSprop decay value
optional float rms_decay = 36;

// RMSProp decay value
// MeanSquare(t) = rms_decay*MeanSquare(t-1) + (1-rms_decay)*SquareGradient(t)
optional float rms_decay = 37;

// If true, print information about the state of the net that may help with
// debugging learning problems.
Expand Down
200 changes: 44 additions & 156 deletions src/caffe/solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -644,7 +644,6 @@ void NesterovSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
Dtype momentum = this->param_.momentum();
Dtype local_rate = rate * net_params_lr[param_id];
switch (Caffe::mode()) {
<<<<<<< HEAD
case Caffe::CPU: {
// save history momentum for stepping back
caffe_copy(net_params[param_id]->count(),
Expand All @@ -665,53 +664,6 @@ void NesterovSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
caffe_copy(net_params[param_id]->count(),
this->update_[param_id]->cpu_data(),
net_params[param_id]->mutable_cpu_diff());
=======
case Caffe::CPU:
for (int param_id = 0; param_id < net_params.size(); ++param_id) {
// save history momentum for stepping back
caffe_copy(net_params[param_id]->count(),
this->history_[param_id]->cpu_data(),
this->update_[param_id]->mutable_cpu_data());

Dtype local_rate = rate * net_params_lr[param_id];
Dtype local_decay = weight_decay * net_params_weight_decay[param_id];

if (local_decay) {
if (regularization_type == "L2") {
// add weight decay
caffe_axpy(net_params[param_id]->count(),
local_decay,
net_params[param_id]->cpu_data(),
net_params[param_id]->mutable_cpu_diff());
} else if (regularization_type == "L1") {
caffe_cpu_sign(net_params[param_id]->count(),
net_params[param_id]->cpu_data(),
this->temp_[param_id]->mutable_cpu_data());
caffe_axpy(net_params[param_id]->count(),
local_decay,
this->temp_[param_id]->cpu_data(),
net_params[param_id]->mutable_cpu_diff());
} else {
LOG(FATAL) << "Unknown regularization type: " << regularization_type;
}
}

// update history
caffe_cpu_axpby(net_params[param_id]->count(), local_rate,
net_params[param_id]->cpu_diff(), momentum,
this->history_[param_id]->mutable_cpu_data());

// compute uppate: step back then over step
caffe_cpu_axpby(net_params[param_id]->count(), Dtype(1) + momentum,
this->history_[param_id]->cpu_data(), -momentum,
this->update_[param_id]->mutable_cpu_data());

// copy
caffe_copy(net_params[param_id]->count(),
this->update_[param_id]->cpu_data(),
net_params[param_id]->mutable_cpu_diff());
}
>>>>>>> Implement RMSprop
break;
}
case Caffe::GPU: {
Expand Down Expand Up @@ -824,135 +776,71 @@ void AdaGradSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
}

template <typename Dtype>
void RMSpropSolver<Dtype>::ComputeUpdateValue() {
void RMSPropSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
const vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_->params();
const vector<float>& net_params_lr = this->net_->params_lr();
const vector<float>& net_params_weight_decay =
this->net_->params_weight_decay();

// get the learning rate
Dtype rate = this->GetLearningRate();
Dtype delta = this->param_.delta();
Dtype rms_decay = this->param_.rms_decay();
Dtype local_rate = rate * net_params_lr[param_id];

if (this->param_.display() && this->iter_ % this->param_.display() == 0) {
LOG(INFO) << "Iteration " << this->iter_ << ", lr = " << rate;
}
Dtype weight_decay = this->param_.weight_decay();
string regularization_type = this->param_.regularization_type();
switch (Caffe::mode()) {
case Caffe::CPU:
for (int param_id = 0; param_id < net_params.size(); ++param_id) {
Dtype local_rate = rate * net_params_lr[param_id];
Dtype local_decay = weight_decay * net_params_weight_decay[param_id];

if (local_decay) {
if (regularization_type == "L2") {
// add weight decay
caffe_axpy(net_params[param_id]->count(),
local_decay,
net_params[param_id]->cpu_data(),
net_params[param_id]->mutable_cpu_diff());
} else if (regularization_type == "L1") {
caffe_cpu_sign(net_params[param_id]->count(),
net_params[param_id]->cpu_data(),
this->temp_[param_id]->mutable_cpu_data());
caffe_axpy(net_params[param_id]->count(),
local_decay,
this->temp_[param_id]->cpu_data(),
net_params[param_id]->mutable_cpu_diff());
} else {
LOG(FATAL) << "Unknown regularization type: " << regularization_type;
}
}

//Compute RMSstep
// compute square of gradient in update
caffe_powx(net_params[param_id]->count(),
net_params[param_id]->cpu_diff(), Dtype(2),
this->update_[param_id]->mutable_cpu_data());

// update history
caffe_cpu_axpby(net_params[param_id] -> count(),
Dtype(1-rms_decay), this->update_[param_id]->cpu_data(),
rms_decay, this->history_[param_id]-> mutable_cpu_data());
// compute square of gradient in update
caffe_powx(net_params[param_id]->count(),
net_params[param_id]->cpu_diff(), Dtype(2),
this->update_[param_id]->mutable_cpu_data());

// prepare update
caffe_powx(net_params[param_id]->count(),
this->history_[param_id]->cpu_data(), Dtype(0.5),
this->update_[param_id]->mutable_cpu_data());
// update history
caffe_cpu_axpby(net_params[param_id] -> count(),
Dtype(1-rms_decay), this->update_[param_id]->cpu_data(),
rms_decay, this->history_[param_id]-> mutable_cpu_data());

// prepare update
caffe_powx(net_params[param_id]->count(),
this->history_[param_id]->cpu_data(), Dtype(0.5),
this->update_[param_id]->mutable_cpu_data());

caffe_add_scalar(net_params[param_id]->count(),
delta, this->update_[param_id]->mutable_cpu_data());
caffe_add_scalar(net_params[param_id]->count(),
delta, this->update_[param_id]->mutable_cpu_data());

caffe_div(net_params[param_id]->count(),
net_params[param_id]->cpu_diff(),
this->update_[param_id]->cpu_data(),
this->update_[param_id]->mutable_cpu_data());
caffe_div(net_params[param_id]->count(),
net_params[param_id]->cpu_diff(), this->update_[param_id]->cpu_data(),
this->update_[param_id]->mutable_cpu_data());

// scale and copy
caffe_cpu_axpby(net_params[param_id]->count(), local_rate,
this->update_[param_id]->cpu_data(), Dtype(0),
net_params[param_id]->mutable_cpu_diff());
}
// scale and copy
caffe_cpu_axpby(net_params[param_id]->count(), local_rate,
this->update_[param_id]->cpu_data(), Dtype(0),
net_params[param_id]->mutable_cpu_diff());
break;
case Caffe::GPU:
#ifndef CPU_ONLY
for (int param_id = 0; param_id < net_params.size(); ++param_id) {

Dtype local_rate = rate * net_params_lr[param_id];
Dtype local_decay = weight_decay * net_params_weight_decay[param_id];

if (local_decay) {
if (regularization_type == "L2") {
// add weight decay
caffe_gpu_axpy(net_params[param_id]->count(),
local_decay,
net_params[param_id]->gpu_data(),
net_params[param_id]->mutable_gpu_diff());
} else if (regularization_type == "L1") {
caffe_gpu_sign(net_params[param_id]->count(),
net_params[param_id]->gpu_data(),
this->temp_[param_id]->mutable_gpu_data());
caffe_gpu_axpy(net_params[param_id]->count(),
local_decay,
this->temp_[param_id]->gpu_data(),
net_params[param_id]->mutable_gpu_diff());
} else {
LOG(FATAL) << "Unknown regularization type: " << regularization_type;
}
}

//Compute RMSstep
// compute square of gradient in update
caffe_gpu_powx(net_params[param_id]->count(),
net_params[param_id]->gpu_diff(), Dtype(2),
this->update_[param_id]->mutable_gpu_data());

// update history
caffe_gpu_axpby(net_params[param_id] -> count(),
Dtype(1-rms_decay), this->update_[param_id]->gpu_data(),
rms_decay, this->history_[param_id]-> mutable_gpu_data());
// compute square of gradient in update
caffe_gpu_powx(net_params[param_id]->count(),
net_params[param_id]->gpu_diff(), Dtype(2),
this->update_[param_id]->mutable_gpu_data());

// prepare update
caffe_gpu_powx(net_params[param_id]->count(),
this->history_[param_id]->gpu_data(), Dtype(0.5),
this->update_[param_id]->mutable_gpu_data());
// update history
caffe_gpu_axpby(net_params[param_id] -> count(),
Dtype(1-rms_decay), this->update_[param_id]->gpu_data(),
rms_decay, this->history_[param_id]-> mutable_gpu_data());

// prepare update
caffe_gpu_powx(net_params[param_id]->count(),
this->history_[param_id]->gpu_data(), Dtype(0.5),
this->update_[param_id]->mutable_gpu_data());

caffe_gpu_add_scalar(net_params[param_id]->count(),
delta, this->update_[param_id]->mutable_gpu_data());
caffe_gpu_add_scalar(net_params[param_id]->count(),
delta, this->update_[param_id]->mutable_gpu_data());

caffe_gpu_div(net_params[param_id]->count(),
net_params[param_id]->gpu_diff(),
this->update_[param_id]->gpu_data(),
this->update_[param_id]->mutable_gpu_data());
caffe_gpu_div(net_params[param_id]->count(),
net_params[param_id]->gpu_diff(), this->update_[param_id]->gpu_data(),
this->update_[param_id]->mutable_gpu_data());

caffe_gpu_axpby(net_params[param_id]->count(), local_rate,
this->update_[param_id]->gpu_data(), Dtype(0),
net_params[param_id]->mutable_gpu_diff());
}
caffe_gpu_axpby(net_params[param_id]->count(), local_rate,
this->update_[param_id]->gpu_data(), Dtype(0),
net_params[param_id]->mutable_gpu_diff());
#else
NO_GPU;
#endif
Expand All @@ -966,6 +854,6 @@ INSTANTIATE_CLASS(Solver);
INSTANTIATE_CLASS(SGDSolver);
INSTANTIATE_CLASS(NesterovSolver);
INSTANTIATE_CLASS(AdaGradSolver);
INSTANTIATE_CLASS(RMSpropSolver);
INSTANTIATE_CLASS(RMSPropSolver);

} // namespace caffe
Loading

0 comments on commit 430db2b

Please sign in to comment.