CPU-Adam fix for scalar mode #735

Merged · 5 commits · Feb 18, 2021

87 changes: 46 additions & 41 deletions csrc/adam/cpu_adam.cpp
100755 → 100644
@@ -62,6 +62,8 @@ void Adam_Optimizer::Step(float* _params,
size_t copy_size = TILE;
if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
size_t offset = copy_size + t;
if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); }

#pragma omp parallel for
for (size_t i = t; i < offset; i += SIMD_WIDTH) {
AVX_Data grad_4;
@@ -101,47 +103,50 @@ void Adam_Optimizer::Step(float* _params,
SIMD_STORE(_exp_avg_sq + i, variance_4.data);
}
if (dev_params) {
launch_param_update(_doubled_buffer[_buf_index],
dev_params + t,
copy_size,
Context::Instance().GetCurrentStream());
launch_param_update(
_doubled_buffer[_buf_index], dev_params + t, copy_size, _streams[_buf_index]);
_buf_index = !_buf_index;
}
}

#endif

if (_param_size > rounded_size) {
for (size_t t = rounded_size; t < _param_size; t += TILE) {
size_t copy_size = TILE;
if ((t + TILE) > _param_size) copy_size = _param_size - t;
size_t offset = copy_size + t;
if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); }
#pragma omp parallel for
for (size_t k = rounded_size; k < _param_size; k++) {
float grad = grads[k];
float param = _params[k];
float momentum = _exp_avg[k];
float variance = _exp_avg_sq[k];
if (_weight_decay > 0 && !_adamw_mode) { grad = param * _weight_decay + grad; }
momentum = momentum * _betta1;
momentum = grad * betta1_minus1 + momentum;

variance = variance * _betta2;
grad = grad * grad;
variance = grad * betta2_minus1 + variance;

grad = sqrt(variance);
grad = grad * _bias_correction2 + _eps;
grad = momentum / grad;
if (_weight_decay > 0 && _adamw_mode) { param += w_decay * param; }
param = grad * step_size + param;
if (dev_params) _doubled_buffer[_buf_index][k - rounded_size] = (__half)param;

_params[k] = param;
_exp_avg[k] = momentum;
_exp_avg_sq[k] = variance;
}
if (dev_params) {
launch_param_update(_doubled_buffer[_buf_index],
dev_params + rounded_size,
(_param_size - rounded_size),
Context::Instance().GetCurrentStream());
for (size_t k = t; k < offset; k++) {
float grad = grads[k];
float param = _params[k];
float momentum = _exp_avg[k];
float variance = _exp_avg_sq[k];
if (_weight_decay > 0 && !_adamw_mode) { grad = param * _weight_decay + grad; }
momentum = momentum * _betta1;
momentum = grad * betta1_minus1 + momentum;

variance = variance * _betta2;
grad = grad * grad;
variance = grad * betta2_minus1 + variance;

grad = sqrt(variance);
grad = grad * _bias_correction2 + _eps;
grad = momentum / grad;
if (_weight_decay > 0 && _adamw_mode) { param += w_decay * param; }
param = grad * step_size + param;
if (dev_params) _doubled_buffer[_buf_index][k - t] = param;

_params[k] = param;
_exp_avg[k] = momentum;
_exp_avg_sq[k] = variance;
}
if (dev_params) {
launch_param_update(
_doubled_buffer[_buf_index], dev_params + t, (copy_size), _streams[_buf_index]);
_buf_index = !_buf_index;
}
}
}
}
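
The scalar tail now follows the same tiled, double-buffered staging pattern as the SIMD paths above: fill one pinned buffer per TILE-sized chunk, launch its device update asynchronously on that buffer's stream, and only wait on a stream when its buffer is about to be overwritten again (two tiles later). Below is a minimal stand-alone sketch of that pattern; the function name, the fixed `TILE` value, and the raw `cudaMemcpyAsync` are illustrative stand-ins for the repo's `launch_param_update`, not the DeepSpeed API.

```cpp
#include <cuda_runtime.h>
#include <cstddef>

constexpr size_t TILE = 1 << 20;  // illustrative tile size, not DeepSpeed's constant

// Double-buffered, tiled host->device copy.
void copy_tiles(const float* host_src, float* dev_dst, size_t n,
                float* pinned_buf[2], cudaStream_t streams[2])
{
    int buf = 0;
    for (size_t t = 0; t < n; t += TILE) {
        size_t copy_size = (t + TILE > n) ? n - t : TILE;

        // From the third tile onward, the buffer we are about to fill may still
        // have a copy in flight; wait on its stream before overwriting it.
        if ((t / TILE) >= 2) cudaStreamSynchronize(streams[buf]);

        for (size_t k = 0; k < copy_size; k++) pinned_buf[buf][k] = host_src[t + k];

        cudaMemcpyAsync(dev_dst + t, pinned_buf[buf], copy_size * sizeof(float),
                        cudaMemcpyHostToDevice, streams[buf]);
        buf ^= 1;  // alternate between the two pinned buffers
    }
    for (int i = 0; i < 2; i++) cudaStreamSynchronize(streams[i]);  // drain before reuse
}
```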
@@ -189,6 +194,7 @@ void Adam_Optimizer::Step_4(float* _params,
size_t copy_size = TILE;
if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
size_t offset = copy_size + t;
if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); }
#pragma omp parallel for
for (size_t i = t; i < offset; i += (SIMD_WIDTH << 2)) {
AVX_Data grad_4[4];
@@ -295,10 +301,8 @@
}

if (dev_params) {
launch_param_update(_doubled_buffer[_buf_index],
dev_params + t,
copy_size,
Context::Instance().GetCurrentStream());
launch_param_update(
_doubled_buffer[_buf_index], dev_params + t, copy_size, _streams[_buf_index]);
_buf_index = !_buf_index;
}
}
@@ -400,6 +404,7 @@ void Adam_Optimizer::Step_8(float* _params,
size_t copy_size = TILE;
if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
size_t offset = copy_size + t;
if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); }
#pragma omp parallel for
for (size_t i = t; i < offset; i += (SIMD_WIDTH << 3)) {
AVX_Data grad_4[8];
@@ -582,10 +587,8 @@
SIMD_STORE(_exp_avg_sq + i + SIMD_WIDTH * 7, variance_4[7].data);
}
if (dev_params) {
launch_param_update(_doubled_buffer[_buf_index],
dev_params + t,
copy_size,
Context::Instance().GetCurrentStream());
launch_param_update(
_doubled_buffer[_buf_index], dev_params + t, copy_size, _streams[_buf_index]);
_buf_index = !_buf_index;
}
}
@@ -628,6 +631,7 @@ int ds_adam_step(int optimizer_id,
opt->update_state(lr, epsilon, weight_decay, bias_correction);
opt->Step_8(params_ptr, grads_ptr, exp_avg_ptr, exp_avg_sq_ptr, params_c.size(0));

opt->SynchronizeStreams();
return 0;
}

@@ -664,6 +668,7 @@ int ds_adam_step_plus_copy(int optimizer_id,
opt->Step_8(
params_ptr, grads_ptr, exp_avg_ptr, exp_avg_sq_ptr, params_c.size(0), gpu_params_ptr);

opt->SynchronizeStreams();
return 0;
}

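Because the Step_* variants above now queue their device copies on the optimizer's private streams rather than the framework's current stream, the Python-facing bindings wait on both streams before returning (see the `opt->SynchronizeStreams()` calls in the `ds_adam_step` and `ds_adam_step_plus_copy` hunks). A hedged caller-side sketch of that contract, with an illustrative `run_step` wrapper and assuming the `cpu_adam.h` header shown below:

```cpp
#include <cstddef>
#include <cuda_fp16.h>
#include "cpu_adam.h"  // Adam_Optimizer from csrc/includes/cpu_adam.h below

// Illustrative wrapper: run one fused CPU-Adam step that also refreshes the
// fp16 parameter copy on the GPU, then drain both copy streams so the device
// parameters are guaranteed to be up to date when this returns.
static int run_step(Adam_Optimizer* opt,
                    float* params, float* grads,
                    float* exp_avg, float* exp_avg_sq,
                    size_t param_size, __half* gpu_params)
{
    opt->Step_8(params, grads, exp_avg, exp_avg_sq, param_size, gpu_params);
    opt->SynchronizeStreams();  // waits on both _streams[0] and _streams[1]
    return 0;
}
```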
2 changes: 2 additions & 0 deletions csrc/includes/context.h
100755 → 100644
@@ -81,6 +81,8 @@ class Context {
return stream;
}

cudaStream_t GetNewStream() { return at::cuda::getStreamFromPool(); }

cublasHandle_t GetCublasHandle() { return _cublasHandle; }

std::pair<uint64_t, uint64_t> IncrementOffset(uint64_t offset_inc)
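`GetNewStream()` hands out a stream from PyTorch's stream pool; the returned `at::cuda::CUDAStream` converts implicitly to `cudaStream_t`, which is why the one-liner above can return it directly and why the optimizer can store it in a plain `cudaStream_t` array. A minimal sketch, assuming the usual `ATen/cuda/CUDAContext.h` include:

```cpp
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>

// Illustrative only: a pool stream is distinct from the current stream, so work
// queued on it can overlap with work on at::cuda::getCurrentCUDAStream().
void pool_stream_example()
{
    auto pool_stream = at::cuda::getStreamFromPool();
    cudaStream_t raw = pool_stream;   // implicit conversion to the raw CUDA handle
    cudaStreamSynchronize(raw);       // usable with the plain CUDA runtime API
}
```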
10 changes: 9 additions & 1 deletion csrc/includes/cpu_adam.h
@@ -65,6 +65,9 @@ class Adam_Optimizer {
{
cudaMallocHost((void**)_doubled_buffer, TILE * sizeof(float));
cudaMallocHost((void**)(_doubled_buffer + 1), TILE * sizeof(float));

_streams[0] = Context::Instance().GetCurrentStream();
_streams[1] = Context::Instance().GetNewStream();
}
~Adam_Optimizer()
{
@@ -89,7 +92,10 @@
float* _exp_avg_sq,
size_t _param_size,
__half* dev_params = nullptr);

inline void SynchronizeStreams()
{
for (int i = 0; i < 2; i++) cudaStreamSynchronize(_streams[i]);
}
inline void IncrementStep(size_t step, float beta1, float beta2)
{
if (beta1 != _betta1 || beta2 != _betta2) {
@@ -152,4 +158,6 @@ class Adam_Optimizer {
float* _doubled_buffer[2];
bool _buf_index;
bool _adamw_mode;

cudaStream_t _streams[2];
};
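
Taken together, the header now pairs each pinned staging buffer with its own stream: `_doubled_buffer[i]` is filled on the CPU and flushed on `_streams[i]`, and `SynchronizeStreams()` drains both before control returns to Python. A condensed, hypothetical sketch of that ownership pattern (not the actual DeepSpeed class, which borrows its streams from the framework via `Context` instead of creating them):

```cpp
#include <cuda_runtime.h>
#include <cstddef>

struct DoubleBuffer {
    float* buf[2];
    cudaStream_t streams[2];

    explicit DoubleBuffer(size_t tile_elems)
    {
        for (int i = 0; i < 2; i++) {
            cudaMallocHost((void**)&buf[i], tile_elems * sizeof(float));  // pinned host memory
            cudaStreamCreate(&streams[i]);  // illustrative; the real code borrows framework streams
        }
    }
    ~DoubleBuffer()
    {
        for (int i = 0; i < 2; i++) {
            cudaStreamSynchronize(streams[i]);  // never free memory with copies in flight
            cudaFreeHost(buf[i]);
            cudaStreamDestroy(streams[i]);
        }
    }
    void synchronize()  // mirrors Adam_Optimizer::SynchronizeStreams()
    {
        for (int i = 0; i < 2; i++) cudaStreamSynchronize(streams[i]);
    }
};
```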