Skip to content

Commit

Permalink
add AMSGRAD and PAdam
Browse files Browse the repository at this point in the history
  • Loading branch information
twmht committed Aug 26, 2018
1 parent 99bd997 commit f3e8681
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 19 deletions.
3 changes: 3 additions & 0 deletions include/caffe/util/math_functions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@

namespace caffe {

template <typename Dtype>
void caffe_cpu_max(const int n, const Dtype* x, const Dtype *y, Dtype *z);

// Caffe gemm provides a simpler interface to the gemm functions, with the
// limitation that the data has to be contiguous in memory.
template <typename Dtype>
Expand Down
6 changes: 6 additions & 0 deletions src/caffe/proto/caffe.proto
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,12 @@ message SolverParameter {
// weights parameter separated by ',' (like in a command string) or
// in repeated weights parameters separately.
repeated string weights = 42;

// AMSGRAD
optional bool amsgrad = 43 [default=false];

// Partial ADAM
optional float partial = 44 [default=0.5];
}

// A message that stores the solver snapshots
Expand Down
34 changes: 25 additions & 9 deletions src/caffe/solvers/adam_solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ void AdamSolver<Dtype>::AdamPreSolve() {
// Add the extra history entries for Adam after those from
// SGDSolver::PreSolve
const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
for (int i = 0; i < net_params.size(); ++i) {
const bool amsgrad = this->param_.amsgrad();
int size = amsgrad? net_params.size(): net_params.size() * 2;
for (int i = 0; i < size; ++i) {
const vector<int>& shape = net_params[i]->shape();
this->history_.push_back(
shared_ptr<Blob<Dtype> >(new Blob<Dtype>(shape)));
Expand All @@ -18,8 +20,8 @@ void AdamSolver<Dtype>::AdamPreSolve() {

#ifndef CPU_ONLY
template <typename Dtype>
void adam_update_gpu(int N, Dtype* g, Dtype* m, Dtype* v, Dtype beta1,
Dtype beta2, Dtype eps_hat, Dtype corrected_local_rate);
void adam_update_gpu(int N, Dtype* g, Dtype* m, Dtype* v, Dtype *max_v, Dtype beta1,
Dtype beta2, Dtype eps_hat, Dtype corrected_local_rate, float partial);
#endif

template <typename Dtype>
Expand All @@ -35,12 +37,18 @@ void AdamSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
Blob<Dtype>* val_m = this->history_[param_id].get();
Blob<Dtype>* val_v = this->history_[param_id + update_history_offset].get();
Blob<Dtype>* val_t = this->temp_[param_id].get();
const bool amsgrad = this->param_.amsgrad();
Blob<Dtype>* max_v = NULL;
if (amsgrad) {
max_v = this->history_[param_id + update_history_offset * 2].get();
}

const int t = this->iter_ + 1;
const Dtype correction = std::sqrt(Dtype(1) - pow(beta2, t)) /
(Dtype(1.) - pow(beta1, t));
const int N = net_params[param_id]->count();
const Dtype eps_hat = this->param_.delta();
const float partial = this->param_.partial();

switch (Caffe::mode()) {
case Caffe::CPU: {
Expand All @@ -59,10 +67,18 @@ void AdamSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
val_v->mutable_cpu_data());

// set update
caffe_powx(N,
val_v->cpu_data(), Dtype(0.5),
val_t->mutable_cpu_data());
caffe_add_scalar(N, eps_hat, val_t->mutable_cpu_data());
if (amsgrad) {
caffe_cpu_max(N, val_v->cpu_data(), max_v->cpu_data(), max_v->mutable_cpu_data());
caffe_powx(N,
max_v->cpu_data(), Dtype(partial),
val_t->mutable_cpu_data());
caffe_add_scalar(N, eps_hat, val_t->mutable_cpu_data());
} else {
caffe_powx(N,
val_v->cpu_data(), Dtype(partial),
val_t->mutable_cpu_data());
caffe_add_scalar(N, eps_hat, val_t->mutable_cpu_data());
}
caffe_div(N,
val_m->cpu_data(),
val_t->cpu_data(),
Expand All @@ -76,8 +92,8 @@ void AdamSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
case Caffe::GPU: {
#ifndef CPU_ONLY
adam_update_gpu(N, net_params[param_id]->mutable_gpu_diff(),
val_m->mutable_gpu_data(), val_v->mutable_gpu_data(), beta1, beta2,
eps_hat, local_rate*correction);
val_m->mutable_gpu_data(), val_v->mutable_gpu_data(), max_v->mutable_gpu_data(), beta1, beta2,
eps_hat, local_rate*correction, partial);
#else
NO_GPU;
#endif
Expand Down
25 changes: 15 additions & 10 deletions src/caffe/solvers/adam_solver.cu
Original file line number Diff line number Diff line change
Expand Up @@ -4,26 +4,31 @@
namespace caffe {

template <typename Dtype>
__global__ void AdamUpdate(int N, Dtype* g, Dtype* m, Dtype* v,
Dtype beta1, Dtype beta2, Dtype eps_hat, Dtype corrected_local_rate) {
__global__ void AdamUpdate(int N, Dtype* g, Dtype* m, Dtype* v, Dtype *max_v,
Dtype beta1, Dtype beta2, Dtype eps_hat, Dtype corrected_local_rate, float partial) {
CUDA_KERNEL_LOOP(i, N) {
float gi = g[i];
float mi = m[i] = m[i]*beta1 + gi*(1-beta1);
float vi = v[i] = v[i]*beta2 + gi*gi*(1-beta2);
g[i] = corrected_local_rate * mi / (sqrt(vi) + eps_hat);
if (max_v) {
float maxvi = v[i] > max_v[i]? v[i]: max_v[i];
g[i] = corrected_local_rate * mi / (powf(maxvi, partial) + eps_hat);
} else {
g[i] = corrected_local_rate * mi / (powf(vi, partial) + eps_hat);
}
}
}
template <typename Dtype>
void adam_update_gpu(int N, Dtype* g, Dtype* m, Dtype* v, Dtype beta1,
Dtype beta2, Dtype eps_hat, Dtype corrected_local_rate) {
void adam_update_gpu(int N, Dtype* g, Dtype* m, Dtype* v, Dtype* max_v, Dtype beta1,
Dtype beta2, Dtype eps_hat, Dtype corrected_local_rate, float partial) {
AdamUpdate<Dtype> // NOLINT_NEXT_LINE(whitespace/operators)
<<<CAFFE_GET_BLOCKS(N), CAFFE_CUDA_NUM_THREADS>>>(
N, g, m, v, beta1, beta2, eps_hat, corrected_local_rate);
N, g, m, v, max_v, beta1, beta2, eps_hat, corrected_local_rate, partial);
CUDA_POST_KERNEL_CHECK;
}
template void adam_update_gpu<float>(int, float*, float*, float*,
float, float, float, float);
template void adam_update_gpu<double>(int, double*, double*, double*,
double, double, double, double);
template void adam_update_gpu<float>(int, float*, float*, float*, float*,
float, float, float, float, float);
template void adam_update_gpu<double>(int, double*, double*, double*, double*,
double, double, double, double, float);

} // namespace caffe
13 changes: 13 additions & 0 deletions src/caffe/util/math_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,19 @@

namespace caffe {

template <typename Dtype>
void caffe_cpu_max(const int n, const Dtype* x, const Dtype *y, Dtype* z) {
for (int i = 0; i < n; i++) {
z[i] = (x[i] > y[i])? x[i] : y[i];
}
}

template
void caffe_cpu_max<float>(const int n, const float* x, const float *y, float* z);

template
void caffe_cpu_max<double>(const int n, const double* x, const double *y, double* z);

template<>
void caffe_cpu_gemm<float>(const CBLAS_TRANSPOSE TransA,
const CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
Expand Down

0 comments on commit f3e8681

Please sign in to comment.