diff --git a/examples/mnist/lenet_adadelta_solver.prototxt b/examples/mnist/lenet_adadelta_solver.prototxt
new file mode 100644
index 00000000000..776d1e06139
--- /dev/null
+++ b/examples/mnist/lenet_adadelta_solver.prototxt
@@ -0,0 +1,24 @@
+# The train/test net protocol buffer definition
+net: "examples/mnist/lenet_train_test.prototxt"
+# test_iter specifies how many forward passes the test should carry out.
+# In the case of MNIST, we have test batch size 100 and 100 test iterations,
+# covering the full 10,000 testing images.
+test_iter: 100
+# Carry out testing every 500 training iterations.
+test_interval: 500
+# The base learning rate, momentum, and weight decay of the network.
+base_lr: 1.0
+lr_policy: "fixed"
+momentum: 0.95
+weight_decay: 0.0005
+# Display every 100 iterations
+display: 100
+# The maximum number of iterations
+max_iter: 10000
+# snapshot intermediate results
+snapshot: 5000
+snapshot_prefix: "examples/mnist/lenet_adadelta"
+# solver mode: CPU or GPU
+solver_mode: GPU
+solver_type: ADADELTA
+delta: 1e-6
diff --git a/examples/mnist/mnist_autoencoder_solver_adadelta.prototxt b/examples/mnist/mnist_autoencoder_solver_adadelta.prototxt
new file mode 100644
index 00000000000..065647df31b
--- /dev/null
+++ b/examples/mnist/mnist_autoencoder_solver_adadelta.prototxt
@@ -0,0 +1,19 @@
+net: "examples/mnist/mnist_autoencoder.prototxt"
+test_state: { stage: 'test-on-train' }
+test_iter: 500
+test_state: { stage: 'test-on-test' }
+test_iter: 100
+test_interval: 500
+test_compute_loss: true
+base_lr: 1.0
+lr_policy: "fixed"
+momentum: 0.95
+delta: 1e-8
+display: 100
+max_iter: 65000
+weight_decay: 0.0005
+snapshot: 10000
+snapshot_prefix: "examples/mnist/mnist_autoencoder_adadelta_train"
+# solver mode: CPU or GPU
+solver_mode: GPU
+solver_type: ADADELTA
diff --git a/examples/mnist/train_mnist_autoencoder_adadelta.sh b/examples/mnist/train_mnist_autoencoder_adadelta.sh
new file mode 100755
index 00000000000..4be0ebddedc
--- /dev/null
+++ b/examples/mnist/train_mnist_autoencoder_adadelta.sh
@@ -0,0 +1,4 @@
+#!/bin/bash
+
+./build/tools/caffe train \
+  --solver=examples/mnist/mnist_autoencoder_solver_adadelta.prototxt
diff --git a/include/caffe/solver.hpp b/include/caffe/solver.hpp
index fbade9389ff..5fefd01e549 100644
--- a/include/caffe/solver.hpp
+++ b/include/caffe/solver.hpp
@@ -158,6 +158,21 @@ class RMSPropSolver : public SGDSolver<Dtype> {
   DISABLE_COPY_AND_ASSIGN(RMSPropSolver);
 };
 
+template <typename Dtype>
+class AdaDeltaSolver : public SGDSolver<Dtype> {
+ public:
+  explicit AdaDeltaSolver(const SolverParameter& param)
+      : SGDSolver<Dtype>(param) { AdaDeltaPreSolve(); }
+  explicit AdaDeltaSolver(const string& param_file)
+      : SGDSolver<Dtype>(param_file) { AdaDeltaPreSolve(); }
+
+ protected:
+  void AdaDeltaPreSolve();
+  virtual void ComputeUpdateValue(int param_id, Dtype rate);
+
+  DISABLE_COPY_AND_ASSIGN(AdaDeltaSolver);
+};
+
 template <typename Dtype>
 Solver<Dtype>* GetSolver(const SolverParameter& param) {
   SolverParameter_SolverType type = param.solver_type();
@@ -171,6 +186,8 @@ Solver<Dtype>* GetSolver(const SolverParameter& param) {
       return new AdaGradSolver<Dtype>(param);
     case SolverParameter_SolverType_RMSPROP:
       return new RMSPropSolver<Dtype>(param);
+    case SolverParameter_SolverType_ADADELTA:
+      return new AdaDeltaSolver<Dtype>(param);
     default:
       LOG(FATAL) << "Unknown SolverType: " << type;
   }
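For reference, the update rule this solver implements is AdaDelta from Zeiler, "ADADELTA: An Adaptive Learning Rate Method" (arXiv:1212.5701). In the solver prototxts above, momentum plays the role of the decay rate ρ and delta of the stability term ε:

$$
\begin{aligned}
E[g^2]_t &= \rho\, E[g^2]_{t-1} + (1-\rho)\, g_t^2 \\
\Delta x_t &= -\frac{\sqrt{E[\Delta x^2]_{t-1} + \varepsilon}}{\sqrt{E[g^2]_t + \varepsilon}}\, g_t \\
E[\Delta x^2]_t &= \rho\, E[\Delta x^2]_{t-1} + (1-\rho)\, \Delta x_t^2
\end{aligned}
$$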
diff --git a/python/caffe/net_spec.py b/python/caffe/net_spec.py
index 31cde7ad946..77a0e0070ae 100644
--- a/python/caffe/net_spec.py
+++ b/python/caffe/net_spec.py
@@ -1,7 +1,7 @@
 """Python net specification.
 
 This module provides a way to write nets directly in Python, using a natural,
-functional style. See examples/python_nets/caffenet.py for an example.
+functional style. See examples/pycaffe/caffenet.py for an example.
 
 Currently this works as a thin wrapper around the Python protobuf interface,
 with layers and parameters automatically generated for the "layers" and
diff --git a/scripts/download_model_binary.py b/scripts/download_model_binary.py
index 48e9015fd26..03a50f6776a 100755
--- a/scripts/download_model_binary.py
+++ b/scripts/download_model_binary.py
@@ -32,7 +32,7 @@ def parse_readme_frontmatter(dirname):
     with open(readme_filename) as f:
         lines = [line.strip() for line in f.readlines()]
     top = lines.index('---')
-    bottom = lines[top + 1:].index('---')
+    bottom = lines.index('---', top + 1)
     frontmatter = yaml.load('\n'.join(lines[top + 1:bottom]))
     assert all(key in frontmatter for key in required_keys)
     return dirname, frontmatter
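The one-line fix above deserves a note: `list.index` on a slice returns a position relative to that slice, not to the original list, so the old code truncated the YAML frontmatter it was trying to extract. A minimal illustration with a hypothetical README (not part of the patch):

```python
lines = ['---', 'name: Some Model', 'caffemodel: some_model.caffemodel', '---', 'body']
top = lines.index('---')                   # 0
old_bottom = lines[top + 1:].index('---')  # 2: relative to the slice, not to lines
new_bottom = lines.index('---', top + 1)   # 3: absolute index of the closing '---'
print(lines[top + 1:old_bottom])  # ['name: Some Model']; last key silently dropped
print(lines[top + 1:new_bottom])  # both keys, as intended
```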
diff --git a/scripts/download_model_from_gist.sh b/scripts/download_model_from_gist.sh
index a1dccf78b5b..89527b7516f 100755
--- a/scripts/download_model_from_gist.sh
+++ b/scripts/download_model_from_gist.sh
@@ -18,7 +18,7 @@ fi
 
 echo "Downloading Caffe model info to $MODEL_DIR ..."
 mkdir -p $MODEL_DIR
-wget https://gist.github.com/$GIST/download -O $MODEL_DIR/gist.tar.gz
-tar xzf $MODEL_DIR/gist.tar.gz --directory=$MODEL_DIR --strip-components=1
-rm $MODEL_DIR/gist.tar.gz
+wget https://gist.github.com/$GIST/download -O $MODEL_DIR/gist.zip
+unzip -j $MODEL_DIR/gist.zip -d $MODEL_DIR
+rm $MODEL_DIR/gist.zip
 echo "Done"
diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto
index 89f14595ba6..7cfcaa8bac7 100644
--- a/src/caffe/proto/caffe.proto
+++ b/src/caffe/proto/caffe.proto
@@ -215,6 +215,7 @@ message SolverParameter {
     NESTEROV = 1;
     ADAGRAD = 2;
     RMSPROP = 3;
+    ADADELTA = 4;
   }
   optional SolverType solver_type = 30 [default = SGD];
   // numerical stability for AdaGrad
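With the new enum value in place, an AdaDelta configuration can also be assembled programmatically. A sketch using the Python protobuf bindings generated from caffe.proto (assuming the standard pycaffe layout where `caffe_pb2` is available; illustrative, not part of the patch), mirroring lenet_adadelta_solver.prototxt:

```python
from caffe.proto import caffe_pb2

param = caffe_pb2.SolverParameter()
param.net = 'examples/mnist/lenet_train_test.prototxt'
param.base_lr = 1.0
param.lr_policy = 'fixed'
param.momentum = 0.95  # decay rate for AdaDelta's running averages
param.delta = 1e-6     # numerical stability term
param.solver_type = caffe_pb2.SolverParameter.ADADELTA
print(param)           # serializes back to the prototxt text format
```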
diff --git a/src/caffe/solver.cpp b/src/caffe/solver.cpp
index 54e085a63e5..78902ca0ebc 100644
--- a/src/caffe/solver.cpp
+++ b/src/caffe/solver.cpp
@@ -934,10 +934,157 @@ void RMSPropSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
   }
 }
 
+template <typename Dtype>
+void AdaDeltaSolver<Dtype>::AdaDeltaPreSolve() {
+  // Add the extra history entries for AdaDelta after those from
+  // SGDSolver::PreSolve
+  const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
+  for (int i = 0; i < net_params.size(); ++i) {
+    const vector<int>& shape = net_params[i]->shape();
+    this->history_.push_back(
+        shared_ptr<Blob<Dtype> >(new Blob<Dtype>(shape)));
+  }
+}
+
+template <typename Dtype>
+void AdaDeltaSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
+  const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
+  const vector<float>& net_params_lr = this->net_->params_lr();
+  Dtype delta = this->param_.delta();
+  Dtype momentum = this->param_.momentum();
+  Dtype local_rate = rate * net_params_lr[param_id];
+  size_t update_history_offset = net_params.size();
+  switch (Caffe::mode()) {
+  case Caffe::CPU: {
+    // compute square of gradient in update
+    caffe_powx(net_params[param_id]->count(),
+        net_params[param_id]->cpu_diff(), Dtype(2),
+        this->update_[param_id]->mutable_cpu_data());
+
+    // update history of gradients
+    caffe_cpu_axpby(net_params[param_id]->count(), Dtype(1) - momentum,
+        this->update_[param_id]->cpu_data(), momentum,
+        this->history_[param_id]->mutable_cpu_data());
+
+    // add delta to history to guard against dividing by zero later
+    caffe_set(net_params[param_id]->count(), delta,
+        this->temp_[param_id]->mutable_cpu_data());
+
+    caffe_add(net_params[param_id]->count(),
+        this->temp_[param_id]->cpu_data(),
+        this->history_[update_history_offset + param_id]->cpu_data(),
+        this->update_[param_id]->mutable_cpu_data());
+
+    caffe_add(net_params[param_id]->count(),
+        this->temp_[param_id]->cpu_data(),
+        this->history_[param_id]->cpu_data(),
+        this->temp_[param_id]->mutable_cpu_data());
+
+    // divide history of updates by history of gradients
+    caffe_div(net_params[param_id]->count(),
+        this->update_[param_id]->cpu_data(),
+        this->temp_[param_id]->cpu_data(),
+        this->update_[param_id]->mutable_cpu_data());
+
+    // jointly compute the RMS of both the update and gradient histories
+    caffe_powx(net_params[param_id]->count(),
+        this->update_[param_id]->cpu_data(), Dtype(0.5),
+        this->update_[param_id]->mutable_cpu_data());
+
+    // compute the update
+    caffe_mul(net_params[param_id]->count(),
+        net_params[param_id]->cpu_diff(),
+        this->update_[param_id]->cpu_data(),
+        net_params[param_id]->mutable_cpu_diff());
+
+    // compute square of update
+    caffe_powx(net_params[param_id]->count(),
+        net_params[param_id]->cpu_diff(), Dtype(2),
+        this->update_[param_id]->mutable_cpu_data());
+
+    // update history of updates
+    caffe_cpu_axpby(net_params[param_id]->count(), Dtype(1) - momentum,
+        this->update_[param_id]->cpu_data(), momentum,
+        this->history_[update_history_offset + param_id]->mutable_cpu_data());
+
+    // apply learning rate
+    caffe_cpu_scale(net_params[param_id]->count(), local_rate,
+        net_params[param_id]->cpu_diff(),
+        net_params[param_id]->mutable_cpu_diff());
+    break;
+  }
+  case Caffe::GPU: {
+#ifndef CPU_ONLY
+    // compute square of gradient in update
+    caffe_gpu_powx(net_params[param_id]->count(),
+        net_params[param_id]->gpu_diff(), Dtype(2),
+        this->update_[param_id]->mutable_gpu_data());
+
+    // update history of gradients
+    caffe_gpu_axpby(net_params[param_id]->count(), Dtype(1) - momentum,
+        this->update_[param_id]->gpu_data(), momentum,
+        this->history_[param_id]->mutable_gpu_data());
+
+    // add delta to history to guard against dividing by zero later
+    caffe_gpu_set(net_params[param_id]->count(), delta,
+        this->temp_[param_id]->mutable_gpu_data());
+
+    caffe_gpu_add(net_params[param_id]->count(),
+        this->temp_[param_id]->gpu_data(),
+        this->history_[update_history_offset + param_id]->gpu_data(),
+        this->update_[param_id]->mutable_gpu_data());
+
+    caffe_gpu_add(net_params[param_id]->count(),
+        this->temp_[param_id]->gpu_data(),
+        this->history_[param_id]->gpu_data(),
+        this->temp_[param_id]->mutable_gpu_data());
+
+    // divide history of updates by history of gradients
+    caffe_gpu_div(net_params[param_id]->count(),
+        this->update_[param_id]->gpu_data(),
+        this->temp_[param_id]->gpu_data(),
+        this->update_[param_id]->mutable_gpu_data());
+
+    // jointly compute the RMS of both the update and gradient histories
+    caffe_gpu_powx(net_params[param_id]->count(),
+        this->update_[param_id]->gpu_data(), Dtype(0.5),
+        this->update_[param_id]->mutable_gpu_data());
+
+    // compute the update and copy to net_diff
+    caffe_gpu_mul(net_params[param_id]->count(),
+        net_params[param_id]->gpu_diff(),
+        this->update_[param_id]->gpu_data(),
+        net_params[param_id]->mutable_gpu_diff());
+
+    // compute square of update
+    caffe_gpu_powx(net_params[param_id]->count(),
+        net_params[param_id]->gpu_diff(), Dtype(2),
+        this->update_[param_id]->mutable_gpu_data());
+
+    // update history of updates
+    caffe_gpu_axpby(net_params[param_id]->count(), Dtype(1) - momentum,
+        this->update_[param_id]->gpu_data(), momentum,
+        this->history_[update_history_offset + param_id]->mutable_gpu_data());
+
+    // apply learning rate
+    caffe_gpu_scale(net_params[param_id]->count(), local_rate,
+        net_params[param_id]->gpu_diff(),
+        net_params[param_id]->mutable_gpu_diff());
+#else
+    NO_GPU;
+#endif
+    break;
+  }
+  default:
+    LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
+  }
+}
+
 INSTANTIATE_CLASS(Solver);
 INSTANTIATE_CLASS(SGDSolver);
 INSTANTIATE_CLASS(NesterovSolver);
 INSTANTIATE_CLASS(AdaGradSolver);
 INSTANTIATE_CLASS(RMSPropSolver);
+INSTANTIATE_CLASS(AdaDeltaSolver);
 
 }  // namespace caffe
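To make the vectorized BLAS sequence above easier to follow, here is a minimal NumPy sketch of one AdaDelta step with the same operation order and the same history layout: the first half of history_ holds the decayed squared gradients, and the second half, appended by AdaDeltaPreSolve, holds the decayed squared updates. Names are illustrative; this is not part of the patch:

```python
import numpy as np

def adadelta_step(param, grad, grad_hist, update_hist,
                  momentum=0.95, delta=1e-6, local_rate=1.0):
    # E[g^2] <- momentum * E[g^2] + (1 - momentum) * g^2
    grad_hist *= momentum
    grad_hist += (1 - momentum) * grad * grad
    # scale the gradient by RMS of past updates over RMS of gradients
    update = grad * np.sqrt((update_hist + delta) / (grad_hist + delta))
    # E[dx^2] <- momentum * E[dx^2] + (1 - momentum) * dx^2
    update_hist *= momentum
    update_hist += (1 - momentum) * update * update
    # apply the (per-parameter) learning rate and descend
    param -= local_rate * update
    return param
```

With local_rate = 1.0, as in both example prototxts, this reduces exactly to Zeiler's method; base_lr and per-blob lr multipliers simply rescale the step, mirroring the final caffe_*_scale call.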
diff --git a/src/caffe/test/test_gradient_based_solver.cpp b/src/caffe/test/test_gradient_based_solver.cpp
index eaa7a759b9b..c97d4ede3b3 100644
--- a/src/caffe/test/test_gradient_based_solver.cpp
+++ b/src/caffe/test/test_gradient_based_solver.cpp
@@ -64,7 +64,8 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
     }
     InitSolver(param);
     delta_ = (solver_type() == SolverParameter_SolverType_ADAGRAD ||
-              solver_type() == SolverParameter_SolverType_RMSPROP) ?
+              solver_type() == SolverParameter_SolverType_RMSPROP ||
+              solver_type() == SolverParameter_SolverType_ADADELTA) ?
              param.delta() : 0;
   }
 
@@ -266,7 +267,11 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
         ((i == D) ? bias.cpu_data()[0] : weights.cpu_data()[i]);
     // Finally, compute update.
     const vector<shared_ptr<Blob<Dtype> > >& history = solver_->history();
-    ASSERT_EQ(2, history.size());  // 1 blob for weights, 1 for bias
+    if (solver_type() != SolverParameter_SolverType_ADADELTA) {
+      ASSERT_EQ(2, history.size());  // 1 blob for weights, 1 for bias
+    } else {
+      ASSERT_EQ(4, history.size());  // additional blobs for update history
+    }
     Dtype update_value = learning_rate * grad;
     const Dtype history_value = (i == D) ?
         history[1]->cpu_data()[0] : history[0]->cpu_data()[i];
@@ -289,6 +294,19 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
             + grad * grad * (1 - rms_decay)) + delta_;
       }
       break;
+      case SolverParameter_SolverType_ADADELTA:
+      {
+        const Dtype update_history_value = (i == D) ?
+            history[3]->cpu_data()[0] : history[2]->cpu_data()[i];
+        const Dtype weighted_gradient_average =
+            momentum * history_value + (1 - momentum) * (grad * grad);
+        update_value = grad * std::sqrt((update_history_value + delta_) /
+            (weighted_gradient_average + delta_));
+        // not actually needed, just here for illustrative purposes
+        // const Dtype weighted_update_average =
+        //     momentum * update_history_value + (1 - momentum) * (update_value);
+        break;
+      }
       default:
         LOG(FATAL) << "Unknown solver type: " << solver_type();
     }
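In this expected-value check, history[0] and history[1] hold the squared-gradient averages for the weights and bias, while history[2] and history[3] hold the squared-update averages appended by AdaDeltaPreSolve. Note that the expected update omits the learning-rate factor the solver applies last; every AdaDelta test below passes kLearningRate = 1.0, so the comparison is still exact. Per element, the check reduces to (illustrative restatement, not part of the patch):

```python
def expected_adadelta_update(grad, grad_hist, update_hist, momentum, delta):
    # grad_hist mirrors history[0]/history[1]; update_hist, history[2]/history[3]
    avg = momentum * grad_hist + (1 - momentum) * grad * grad
    return grad * ((update_hist + delta) / (avg + delta)) ** 0.5
```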
@@ -875,6 +893,139 @@ TYPED_TEST(NesterovSolverTest, TestSnapshotShare) {
   }
 }
 
+template <typename TypeParam>
+class AdaDeltaSolverTest : public GradientBasedSolverTest<TypeParam> {
+  typedef typename TypeParam::Dtype Dtype;
+
+ protected:
+  virtual void InitSolver(const SolverParameter& param) {
+    this->solver_.reset(new AdaDeltaSolver<Dtype>(param));
+  }
+
+  virtual SolverParameter_SolverType solver_type() {
+    return SolverParameter_SolverType_ADADELTA;
+  }
+};
+
+TYPED_TEST_CASE(AdaDeltaSolverTest, TestDtypesAndDevices);
+
+TYPED_TEST(AdaDeltaSolverTest, TestAdaDeltaLeastSquaresUpdate) {
+  typedef typename TypeParam::Dtype Dtype;
+  const Dtype kLearningRate = 1.0;
+  this->TestLeastSquaresUpdate(kLearningRate);
+}
+
+TYPED_TEST(AdaDeltaSolverTest, TestAdaDeltaLeastSquaresUpdateWithWeightDecay) {
+  typedef typename TypeParam::Dtype Dtype;
+  const Dtype kLearningRate = 1.0;
+  const Dtype kWeightDecay = 0.5;
+  const Dtype kMomentum = 0.95;
+  this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum);
+}
+
+TYPED_TEST(AdaDeltaSolverTest, TestAdaDeltaLeastSquaresUpdateWithHalfMomentum) {
+  typedef typename TypeParam::Dtype Dtype;
+  const Dtype kLearningRate = 1.0;
+  const Dtype kWeightDecay = 0.0;
+  const Dtype kMomentum = 0.5;
+  const int kNumIters = 1;
+  for (int i = 0; i <= kNumIters; ++i) {
+    this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum);
+  }
+}
+
+TYPED_TEST(AdaDeltaSolverTest, TestAdaDeltaLeastSquaresUpdateWithMomentum) {
+  typedef typename TypeParam::Dtype Dtype;
+  const Dtype kLearningRate = 1.0;
+  const Dtype kWeightDecay = 0.0;
+  const Dtype kMomentum = 0.95;
+  const int kNumIters = 1;
+  for (int i = 0; i <= kNumIters; ++i) {
+    this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum);
+  }
+}
+
+TYPED_TEST(AdaDeltaSolverTest, TestLeastSquaresUpdateWithMomentumMultiIter) {
+  typedef typename TypeParam::Dtype Dtype;
+  const Dtype kLearningRate = 1.0;
+  const Dtype kWeightDecay = 0.0;
+  const Dtype kMomentum = 0.95;
+  const int kNumIters = 4;
+  for (int i = 0; i <= kNumIters; ++i) {
+    this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
+  }
+}
+
+TYPED_TEST(AdaDeltaSolverTest, TestAdaDeltaLeastSquaresUpdateWithEverything) {
+  typedef typename TypeParam::Dtype Dtype;
+  const Dtype kLearningRate = 1.0;
+  const Dtype kWeightDecay = 0.1;
+  const Dtype kMomentum = 0.95;
+  const int kNumIters = 4;
+  for (int i = 0; i <= kNumIters; ++i) {
+    this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
+  }
+}
+
+TYPED_TEST(AdaDeltaSolverTest,
+           TestAdaDeltaLeastSquaresUpdateWithEverythingShare) {
+  typedef typename TypeParam::Dtype Dtype;
+  const Dtype kLearningRate = 1.0;
+  const Dtype kWeightDecay = 0.1;
+  const Dtype kMomentum = 0.95;
+  const int kNumIters = 4;
+  this->share_ = true;
+  for (int i = 0; i <= kNumIters; ++i) {
+    this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
+  }
+}
+
+TYPED_TEST(AdaDeltaSolverTest, TestLeastSquaresUpdateWithEverythingAccum) {
+  typedef typename TypeParam::Dtype Dtype;
+  const Dtype kLearningRate = 1.0;
+  const Dtype kWeightDecay = 0.1;
+  const Dtype kMomentum = 0.95;
+  const int kNumIters = 4;
+  const int kIterSize = 2;
+  this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters,
+      kIterSize);
+}
+
+TYPED_TEST(AdaDeltaSolverTest, TestLeastSquaresUpdateWithEverythingAccumShare) {
+  typedef typename TypeParam::Dtype Dtype;
+  const Dtype kLearningRate = 1.0;
+  const Dtype kWeightDecay = 0.1;
+  const Dtype kMomentum = 0.95;
+  const int kNumIters = 4;
+  const int kIterSize = 2;
+  this->share_ = true;
+  this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters,
+      kIterSize);
+}
+
+TYPED_TEST(AdaDeltaSolverTest, TestSnapshot) {
+  typedef typename TypeParam::Dtype Dtype;
+  const Dtype kLearningRate = 1.0;
+  const Dtype kWeightDecay = 0.1;
+  const Dtype kMomentum = 0.95;
+  const int kNumIters = 4;
+  for (int i = 1; i <= kNumIters; ++i) {
+    this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i);
+  }
+}
+
+TYPED_TEST(AdaDeltaSolverTest, TestSnapshotShare) {
+  typedef typename TypeParam::Dtype Dtype;
+  const Dtype kLearningRate = 1.0;
+  const Dtype kWeightDecay = 0.1;
+  const Dtype kMomentum = 0.95;
+  const int kNumIters = 4;
+  this->share_ = true;
+  for (int i = 1; i <= kNumIters; ++i) {
+    this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i);
+  }
+}
+
 template <typename TypeParam>
 class RMSPropSolverTest : public GradientBasedSolverTest<TypeParam> {
   typedef typename TypeParam::Dtype Dtype;