Commit 2bf52ef: Address remaining feedback

matthiasplappert committed Aug 10, 2015
1 parent e4eb50b
Showing 3 changed files with 53 additions and 69 deletions.
include/caffe/solver.hpp (7 changes: 3 additions & 4 deletions)

@@ -161,13 +161,12 @@ template <typename Dtype>
 class AdaDeltaSolver : public SGDSolver<Dtype> {
  public:
   explicit AdaDeltaSolver(const SolverParameter& param)
-      : SGDSolver<Dtype>(param) { PreSolve(); }
+      : SGDSolver<Dtype>(param) { AdaDeltaPreSolve(); }
   explicit AdaDeltaSolver(const string& param_file)
-      : SGDSolver<Dtype>(param_file) { PreSolve(); }
+      : SGDSolver<Dtype>(param_file) { AdaDeltaPreSolve(); }

  protected:
-  void PreSolve();
-  virtual void Regularize(int param_id);
+  void AdaDeltaPreSolve();
   virtual void ComputeUpdateValue(int param_id, Dtype rate);

   DISABLE_COPY_AND_ASSIGN(AdaDeltaSolver);
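A note on the rename: SGDSolver<Dtype>'s constructors call PreSolve() themselves, and PreSolve is not virtual, so a second PreSolve() declared in AdaDeltaSolver merely hides the base member; in any case a C++ base-class constructor can never dispatch into a derived class. Calling a distinctly named AdaDeltaPreSolve() from the AdaDelta constructors makes the extra set-up explicit, which is the likely motivation for the rename. A minimal sketch of the pattern, using hypothetical BaseSolver/DerivedSolver stand-ins rather than the real Caffe classes:

#include <iostream>

// BaseSolver's constructor runs its own set-up; a derived class that wants
// additional set-up calls a distinctly named hook from its own constructor.
struct BaseSolver {
  BaseSolver() { PreSolve(); }  // always BaseSolver::PreSolve, never a derived one
  void PreSolve() { std::cout << "base history set-up\n"; }
};

struct DerivedSolver : public BaseSolver {
  DerivedSolver() { DerivedPreSolve(); }  // runs after BaseSolver() has finished
  void DerivedPreSolve() { std::cout << "extra history entries\n"; }
};

int main() {
  DerivedSolver s;  // prints "base history set-up", then "extra history entries"
  return 0;
}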
src/caffe/solver.cpp (68 changes: 3 additions & 65 deletions)

@@ -860,82 +860,20 @@ void AdaGradSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
 }

 template <typename Dtype>
-void AdaDeltaSolver<Dtype>::PreSolve() {
+void AdaDeltaSolver<Dtype>::AdaDeltaPreSolve() {
   // Add the extra history entries for AdaDelta after those from
   // SGDSolver::PreSolve
-  const vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_->params();
+  const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
   for (int i = 0; i < net_params.size(); ++i) {
     const vector<int>& shape = net_params[i]->shape();
     this->history_.push_back(
         shared_ptr<Blob<Dtype> >(new Blob<Dtype>(shape)));
   }
 }

-template <typename Dtype>
-void AdaDeltaSolver<Dtype>::Regularize(int param_id) {
-  const vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_->params();
-  const vector<float>& net_params_weight_decay =
-      this->net_->params_weight_decay();
-  Dtype weight_decay = this->param_.weight_decay();
-  string regularization_type = this->param_.regularization_type();
-  Dtype local_decay = weight_decay * net_params_weight_decay[param_id];
-  switch (Caffe::mode()) {
-  case Caffe::CPU: {
-    if (local_decay) {
-      if (regularization_type == "L2") {
-        // add weight decay
-        caffe_axpy(net_params[param_id]->count(),
-            local_decay,
-            net_params[param_id]->cpu_data(),
-            net_params[param_id]->mutable_cpu_diff());
-      } else if (regularization_type == "L1") {
-        caffe_cpu_sign(net_params[param_id]->count(),
-            net_params[param_id]->cpu_data(),
-            this->temp_[param_id]->mutable_cpu_data());
-        caffe_axpy(net_params[param_id]->count(),
-            local_decay,
-            this->temp_[param_id]->cpu_data(),
-            net_params[param_id]->mutable_cpu_diff());
-      } else {
-        LOG(FATAL) << "Unknown regularization type: " << regularization_type;
-      }
-    }
-    break;
-  }
-  case Caffe::GPU: {
-#ifndef CPU_ONLY
-    if (local_decay) {
-      if (regularization_type == "L2") {
-        // add weight decay
-        caffe_gpu_axpy(net_params[param_id]->count(),
-            local_decay,
-            net_params[param_id]->gpu_data(),
-            net_params[param_id]->mutable_gpu_diff());
-      } else if (regularization_type == "L1") {
-        caffe_gpu_sign(net_params[param_id]->count(),
-            net_params[param_id]->gpu_data(),
-            this->temp_[param_id]->mutable_gpu_data());
-        caffe_gpu_axpy(net_params[param_id]->count(),
-            local_decay,
-            this->temp_[param_id]->gpu_data(),
-            net_params[param_id]->mutable_gpu_diff());
-      } else {
-        LOG(FATAL) << "Unknown regularization type: " << regularization_type;
-      }
-    }
-#else
-    NO_GPU;
-#endif
-    break;
-  }
-  default:
-    LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
-  }
-}
-
 template <typename Dtype>
 void AdaDeltaSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
-  const vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_->params();
+  const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
   const vector<float>& net_params_lr = this->net_->params_lr();
   Dtype delta = this->param_.delta();
   Dtype momentum = this->param_.momentum();
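Two contextual notes on this file. First, the deleted Regularize override appears to be a verbatim copy of SGDSolver<Dtype>::Regularize, so removing it and inheriting the base implementation should change no behavior. Second, AdaDeltaPreSolve appends one extra history blob per learnable parameter because AdaDelta (Zeiler, 2012) maintains two decayed running averages per parameter: one of squared gradients and one of squared updates. With Caffe's momentum parameter as the decay \rho and delta as the conditioning term \varepsilon, the rule that ComputeUpdateValue implements is, in LaTeX:

E[g^2]_t = \rho \, E[g^2]_{t-1} + (1 - \rho) \, g_t^2
\Delta x_t = -\frac{\sqrt{E[\Delta x^2]_{t-1} + \varepsilon}}{\sqrt{E[g^2]_t + \varepsilon}} \, g_t
E[\Delta x^2]_t = \rho \, E[\Delta x^2]_{t-1} + (1 - \rho) \, (\Delta x_t)^2

The history entries created by SGDSolver::PreSolve hold one of these averages and the extra entries appended above hold the other; the exact layout is an implementation detail not visible in this hunk.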
src/caffe/test/test_gradient_based_solver.cpp (47 changes: 47 additions & 0 deletions)

@@ -973,6 +973,18 @@ TYPED_TEST(AdaDeltaSolverTest, TestAdaDeltaLeastSquaresUpdateWithEverything) {
   }
 }

+TYPED_TEST(AdaDeltaSolverTest, TestAdaDeltaLeastSquaresUpdateWithEverythingShare) {
+  typedef typename TypeParam::Dtype Dtype;
+  const Dtype kLearningRate = 1.0;
+  const Dtype kWeightDecay = 0.1;
+  const Dtype kMomentum = 0.95;
+  const int kNumIters = 4;
+  this->share_ = true;
+  for (int i = 0; i <= kNumIters; ++i) {
+    this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
+  }
+}
+
 TYPED_TEST(AdaDeltaSolverTest, TestLeastSquaresUpdateWithEverythingAccum) {
   typedef typename TypeParam::Dtype Dtype;
   const Dtype kLearningRate = 1.0;
@@ -984,6 +996,41 @@ TYPED_TEST(AdaDeltaSolverTest, TestLeastSquaresUpdateWithEverythingAccum) {
       kIterSize);
 }

+TYPED_TEST(AdaDeltaSolverTest, TestLeastSquaresUpdateWithEverythingAccumShare) {
+  typedef typename TypeParam::Dtype Dtype;
+  const Dtype kLearningRate = 1.0;
+  const Dtype kWeightDecay = 0.1;
+  const Dtype kMomentum = 0.95;
+  const int kNumIters = 4;
+  const int kIterSize = 2;
+  this->share_ = true;
+  this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters,
+      kIterSize);
+}
+
+TYPED_TEST(AdaDeltaSolverTest, TestSnapshot) {
+  typedef typename TypeParam::Dtype Dtype;
+  const Dtype kLearningRate = 1.0;
+  const Dtype kWeightDecay = 0.1;
+  const Dtype kMomentum = 0.95;
+  const int kNumIters = 4;
+  for (int i = 1; i <= kNumIters; ++i) {
+    this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i);
+  }
+}
+
+TYPED_TEST(AdaDeltaSolverTest, TestSnapshotShare) {
+  typedef typename TypeParam::Dtype Dtype;
+  const Dtype kLearningRate = 1.0;
+  const Dtype kWeightDecay = 0.1;
+  const Dtype kMomentum = 0.95;
+  const int kNumIters = 4;
+  this->share_ = true;
+  for (int i = 1; i <= kNumIters; ++i) {
+    this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i);
+  }
+}
+
 template <typename TypeParam>
 class RMSPropSolverTest : public GradientBasedSolverTest<TypeParam> {
   typedef typename TypeParam::Dtype Dtype;
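For context: the new *Share test variants set share_ = true so the test net includes weight-shared parameter blobs, which plausibly exercises the learnable_params()-based accesses the solver changes above switch to; the snapshot tests check that training for some iterations, snapshotting, and resuming yields the same state as an uninterrupted run. A toy illustration of that snapshot invariant, with a hypothetical ToySolver in place of Caffe's actual test harness:

#include <cassert>

// One weight plus one history slot; Step() is a stand-in for a solver iteration.
struct ToySolver {
  double w = 1.0;
  double hist = 0.0;
  void Step() {
    hist = 0.95 * hist + 0.05 * w;  // decayed running average, as in AdaDelta
    w -= 0.1 * hist;
  }
};

int main() {
  ToySolver uninterrupted;
  for (int i = 0; i < 4; ++i) uninterrupted.Step();

  ToySolver first;
  for (int i = 0; i < 2; ++i) first.Step();
  ToySolver resumed = first;  // "snapshot + restore" means copying all solver state
  for (int i = 0; i < 2; ++i) resumed.Step();

  // Bitwise-identical because both paths ran the same operation sequence.
  assert(resumed.w == uninterrupted.w && resumed.hist == uninterrupted.hist);
  return 0;
}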
