diff --git a/include/caffe/util/math_functions.hpp b/include/caffe/util/math_functions.hpp
index e549120a933..6495b3a4c08 100644
--- a/include/caffe/util/math_functions.hpp
+++ b/include/caffe/util/math_functions.hpp
@@ -12,6 +12,11 @@
 
 namespace caffe {
 
+// Element-wise maximum: z[i] = max(x[i], y[i]) for i in [0, n).
+// z may alias x or y.
+template <typename Dtype>
+void caffe_cpu_max(const int n, const Dtype* x, const Dtype *y, Dtype *z);
+
 // Caffe gemm provides a simpler interface to the gemm functions, with the
 // limitation that the data has to be contiguous in memory.
 template <typename Dtype>
diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto
index 3dcad697f6d..88b8f705f46 100644
--- a/src/caffe/proto/caffe.proto
+++ b/src/caffe/proto/caffe.proto
@@ -255,6 +255,15 @@ message SolverParameter {
   // weights parameter separated by ',' (like in a command string) or
   // in repeated weights parameters separately.
   repeated string weights = 42;
+
+  // Read by AdamSolver: if true, use the AMSGrad variant, which keeps the
+  // element-wise running maximum of the second-moment estimate.
+  optional bool amsgrad = 43 [default=false];
+
+  // Read by AdamSolver: exponent applied to the second-moment estimate in
+  // the update denominator ("partial" Adam). The default 0.5 matches the
+  // square root used before this option existed.
+  optional float partial = 44 [default=0.5];
 }
 
 // A message that stores the solver snapshots
diff --git a/src/caffe/solvers/adam_solver.cpp b/src/caffe/solvers/adam_solver.cpp
index 4a91f00bd49..c14b49b3dea 100644
--- a/src/caffe/solvers/adam_solver.cpp
+++ b/src/caffe/solvers/adam_solver.cpp
@@ -9,17 +9,23 @@ void AdamSolver<Dtype>::AdamPreSolve() {
   // Add the extra history entries for Adam after those from
   // SGDSolver::PreSolve
   const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
-  for (int i = 0; i < net_params.size(); ++i) {
-    const vector<int>& shape = net_params[i]->shape();
-    this->history_.push_back(
-            shared_ptr<Blob<Dtype> >(new Blob<Dtype>(shape)));
+  // One extra set holds the second moment; AMSGrad adds another set for the
+  // running maximum of the second moment.
+  const int extra_sets = this->param_.amsgrad() ? 2 : 1;
+  for (int k = 0; k < extra_sets; ++k) {
+    for (int i = 0; i < net_params.size(); ++i) {
+      const vector<int>& shape = net_params[i]->shape();
+      this->history_.push_back(
+              shared_ptr<Blob<Dtype> >(new Blob<Dtype>(shape)));
+    }
   }
 }
 
 #ifndef CPU_ONLY
 template <typename Dtype>
-void adam_update_gpu(int N, Dtype* g, Dtype* m, Dtype* v, Dtype beta1,
-    Dtype beta2, Dtype eps_hat, Dtype corrected_local_rate);
+void adam_update_gpu(int N, Dtype* g, Dtype* m, Dtype* v, Blob<Dtype>* max_v,
+    Dtype beta1, Dtype beta2, Dtype eps_hat, Dtype corrected_local_rate,
+    float partial);
 #endif
 
 template <typename Dtype>
@@ -35,12 +41,21 @@ void AdamSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
   Blob<Dtype>* val_m = this->history_[param_id].get();
   Blob<Dtype>* val_v = this->history_[param_id + update_history_offset].get();
   Blob<Dtype>* val_t = this->temp_[param_id].get();
+  // AMSGrad keeps the running maximum of v in a third history set; max_v
+  // stays NULL when AMSGrad is disabled.
+  const bool amsgrad = this->param_.amsgrad();
+  Blob<Dtype>* max_v = NULL;
+  if (amsgrad) {
+    max_v = this->history_[param_id + update_history_offset * 2].get();
+  }
 
   const int t = this->iter_ + 1;
   const Dtype correction = std::sqrt(Dtype(1) - pow(beta2, t)) /
       (Dtype(1.) - pow(beta1, t));
   const int N = net_params[param_id]->count();
   const Dtype eps_hat = this->param_.delta();
+  // Exponent applied to the second-moment term in the denominator.
+  const float partial = this->param_.partial();
 
   switch (Caffe::mode()) {
     case Caffe::CPU: {
@@ -59,10 +74,15 @@ void AdamSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
         val_v->mutable_cpu_data());
 
     // set update
+    if (amsgrad) {
+      // max_v <- max(max_v, v); the denominator then uses max_v instead of v.
+      caffe_cpu_max(N, val_v->cpu_data(), max_v->cpu_data(),
+          max_v->mutable_cpu_data());
+    }
     caffe_powx(N,
-        val_v->cpu_data(), Dtype(0.5),
+        amsgrad ? max_v->cpu_data() : val_v->cpu_data(), Dtype(partial),
         val_t->mutable_cpu_data());
     caffe_add_scalar(N, eps_hat, val_t->mutable_cpu_data());
     caffe_div(N,
         val_m->cpu_data(),
         val_t->cpu_data(),
@@ -76,8 +96,8 @@ void AdamSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
   case Caffe::GPU: {
 #ifndef CPU_ONLY
     adam_update_gpu(N, net_params[param_id]->mutable_gpu_diff(),
-        val_m->mutable_gpu_data(), val_v->mutable_gpu_data(), beta1, beta2,
-        eps_hat, local_rate*correction);
+        val_m->mutable_gpu_data(), val_v->mutable_gpu_data(), max_v, beta1,
+        beta2, eps_hat, local_rate*correction, partial);
 #else
     NO_GPU;
 #endif
diff --git a/src/caffe/solvers/adam_solver.cu b/src/caffe/solvers/adam_solver.cu
index 917ae100246..65500a83ee5 100644
--- a/src/caffe/solvers/adam_solver.cu
+++ b/src/caffe/solvers/adam_solver.cu
@@ -1,29 +1,47 @@
+#include "caffe/blob.hpp"
 #include "caffe/util/math_functions.hpp"
 
 
 namespace caffe {
 
+// Adam update for one parameter blob. A non-NULL max_v enables AMSGrad:
+// max_v holds the running element-wise maximum of v and replaces v in the
+// denominator. `partial` is the exponent applied to that denominator term.
 template <typename Dtype>
-__global__ void AdamUpdate(int N, Dtype* g, Dtype* m, Dtype* v,
-    Dtype beta1, Dtype beta2, Dtype eps_hat, Dtype corrected_local_rate) {
+__global__ void AdamUpdate(int N, Dtype* g, Dtype* m, Dtype* v, Dtype* max_v,
+    Dtype beta1, Dtype beta2, Dtype eps_hat, Dtype corrected_local_rate,
+    float partial) {
   CUDA_KERNEL_LOOP(i, N) {
     float gi = g[i];
     float mi = m[i] = m[i]*beta1 + gi*(1-beta1);
     float vi = v[i] = v[i]*beta2 + gi*gi*(1-beta2);
-    g[i] = corrected_local_rate * mi / (sqrt(vi) + eps_hat);
+    if (max_v) {
+      // Write the maximum back so that it persists across iterations.
+      float maxvi = max_v[i] = v[i] > max_v[i] ? v[i] : max_v[i];
+      g[i] = corrected_local_rate * mi / (powf(maxvi, partial) + eps_hat);
+    } else {
+      g[i] = corrected_local_rate * mi / (powf(vi, partial) + eps_hat);
+    }
   }
 }
 template <typename Dtype>
-void adam_update_gpu(int N, Dtype* g, Dtype* m, Dtype* v, Dtype beta1,
-    Dtype beta2, Dtype eps_hat, Dtype corrected_local_rate) {
+void adam_update_gpu(int N, Dtype* g, Dtype* m, Dtype* v, Blob<Dtype>* max_v,
+    Dtype beta1, Dtype beta2, Dtype eps_hat, Dtype corrected_local_rate,
+    float partial) {
+  // The kernel takes a raw pointer; NULL means AMSGrad is disabled.
+  Dtype* max_v_data = NULL;
+  if (max_v) {
+    max_v_data = max_v->mutable_gpu_data();
+  }
   AdamUpdate<Dtype>  // NOLINT_NEXT_LINE(whitespace/operators)
       <<<CAFFE_GET_BLOCKS(N), CAFFE_CUDA_NUM_THREADS>>>(
-    N, g, m, v, beta1, beta2, eps_hat, corrected_local_rate);
+    N, g, m, v, max_v_data, beta1, beta2, eps_hat, corrected_local_rate,
+    partial);
   CUDA_POST_KERNEL_CHECK;
 }
 template void adam_update_gpu<float>(int, float*, float*, float*,
-    float, float, float, float);
+    Blob<float>*, float, float, float, float, float);
 template void adam_update_gpu<double>(int, double*, double*, double*,
-    double, double, double, double);
+    Blob<double>*, double, double, double, double, float);
 
 }  // namespace caffe
diff --git a/src/caffe/util/math_functions.cpp b/src/caffe/util/math_functions.cpp
index 59625bc05ce..f845cd14cb6 100644
--- a/src/caffe/util/math_functions.cpp
+++ b/src/caffe/util/math_functions.cpp
@@ -9,6 +9,21 @@
 
 namespace caffe {
 
+// Element-wise maximum: z[i] = max(x[i], y[i]) for i in [0, n).
+// Each element is read before it is written, so z may alias x or y.
+template <typename Dtype>
+void caffe_cpu_max(const int n, const Dtype* x, const Dtype* y, Dtype* z) {
+  for (int i = 0; i < n; ++i) {
+    z[i] = (x[i] > y[i]) ? x[i] : y[i];
+  }
+}
+
+template void caffe_cpu_max<float>(const int n, const float* x,
+    const float* y, float* z);
+
+template void caffe_cpu_max<double>(const int n, const double* x,
+    const double* y, double* z);
+
 template<>
 void caffe_cpu_gemm<float>(const CBLAS_TRANSPOSE TransA,
     const CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,