From feaf248184d07aacee54870c623cf16619c57385 Mon Sep 17 00:00:00 2001
From: Reza Yazdani
Date: Mon, 8 Feb 2021 18:08:46 +0000
Subject: [PATCH 1/3] continue to copy in tiled manner when in scalar mode

---
 csrc/adam/cpu_adam.cpp | 64 +++++++++++++++++++++++-------------------
 1 file changed, 35 insertions(+), 29 deletions(-)
 mode change 100755 => 100644 csrc/adam/cpu_adam.cpp

diff --git a/csrc/adam/cpu_adam.cpp b/csrc/adam/cpu_adam.cpp
old mode 100755
new mode 100644
index e817322630b8..6694a947da43
--- a/csrc/adam/cpu_adam.cpp
+++ b/csrc/adam/cpu_adam.cpp
@@ -112,36 +112,42 @@ void Adam_Optimizer::Step(float* _params,
 #endif
 
     if (_param_size > rounded_size) {
+        for (size_t t = rounded_size; t < _param_size; t += TILE) {
+            size_t copy_size = TILE;
+            if ((t + TILE) > _param_size) copy_size = _param_size - t;
+            size_t offset = copy_size + t;
 #pragma omp parallel for
-        for (size_t k = rounded_size; k < _param_size; k++) {
-            float grad = grads[k];
-            float param = _params[k];
-            float momentum = _exp_avg[k];
-            float variance = _exp_avg_sq[k];
-            if (_weight_decay > 0 && !_adamw_mode) { grad = param * _weight_decay + grad; }
-            momentum = momentum * _betta1;
-            momentum = grad * betta1_minus1 + momentum;
-
-            variance = variance * _betta2;
-            grad = grad * grad;
-            variance = grad * betta2_minus1 + variance;
-
-            grad = sqrt(variance);
-            grad = grad * _bias_correction2 + _eps;
-            grad = momentum / grad;
-            if (_weight_decay > 0 && _adamw_mode) { param += w_decay * param; }
-            param = grad * step_size + param;
-            if (dev_params) _doubled_buffer[_buf_index][k - rounded_size] = (__half)param;
-
-            _params[k] = param;
-            _exp_avg[k] = momentum;
-            _exp_avg_sq[k] = variance;
-        }
-        if (dev_params) {
-            launch_param_update(_doubled_buffer[_buf_index],
-                                dev_params + rounded_size,
-                                (_param_size - rounded_size),
-                                Context::Instance().GetCurrentStream());
+            for (size_t k = t; k < offset; k++) {
+                float grad = grads[k];
+                float param = _params[k];
+                float momentum = _exp_avg[k];
+                float variance = _exp_avg_sq[k];
+                if (_weight_decay > 0 && !_adamw_mode) { grad = param * _weight_decay + grad; }
+                momentum = momentum * _betta1;
+                momentum = grad * betta1_minus1 + momentum;
+
+                variance = variance * _betta2;
+                grad = grad * grad;
+                variance = grad * betta2_minus1 + variance;
+
+                grad = sqrt(variance);
+                grad = grad * _bias_correction2 + _eps;
+                grad = momentum / grad;
+                if (_weight_decay > 0 && _adamw_mode) { param += w_decay * param; }
+                param = grad * step_size + param;
+                if (dev_params) _doubled_buffer[_buf_index][k - t] = param;
+
+                _params[k] = param;
+                _exp_avg[k] = momentum;
+                _exp_avg_sq[k] = variance;
+            }
+            if (dev_params) {
+                launch_param_update(_doubled_buffer[_buf_index],
+                                    dev_params + t,
+                                    (copy_size),
+                                    Context::Instance().GetCurrentStream());
+                _buf_index = !_buf_index;
+            }
         }
     }
 }

From d3dbec7e0f1b5201c231a5788686a5c95b378f0d Mon Sep 17 00:00:00 2001
From: Reza Yazdani
Date: Mon, 8 Feb 2021 20:36:38 +0000
Subject: [PATCH 2/3] guarding the tiled-copy overlapping

---
 csrc/adam/cpu_adam.cpp   | 31 +++++++++++++++----------------
 csrc/includes/context.h  |  2 ++
 csrc/includes/cpu_adam.h |  5 +++++
 3 files changed, 22 insertions(+), 16 deletions(-)
 mode change 100755 => 100644 csrc/includes/context.h

diff --git a/csrc/adam/cpu_adam.cpp b/csrc/adam/cpu_adam.cpp
index 6694a947da43..471155243ac9 100644
--- a/csrc/adam/cpu_adam.cpp
+++ b/csrc/adam/cpu_adam.cpp
@@ -62,6 +62,8 @@ void Adam_Optimizer::Step(float* _params,
         size_t copy_size = TILE;
         if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
         size_t offset = copy_size + t;
+        if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); }
+
 #pragma omp parallel for
         for (size_t i = t; i < offset; i += SIMD_WIDTH) {
             AVX_Data grad_4;
@@ -101,10 +103,8 @@ void Adam_Optimizer::Step(float* _params,
             SIMD_STORE(_exp_avg_sq + i, variance_4.data);
         }
         if (dev_params) {
-            launch_param_update(_doubled_buffer[_buf_index],
-                                dev_params + t,
-                                copy_size,
-                                Context::Instance().GetCurrentStream());
+            launch_param_update(
+                _doubled_buffer[_buf_index], dev_params + t, copy_size, _streams[_buf_index]);
             _buf_index = !_buf_index;
         }
     }
@@ -116,6 +116,7 @@ void Adam_Optimizer::Step(float* _params,
             size_t copy_size = TILE;
             if ((t + TILE) > _param_size) copy_size = _param_size - t;
             size_t offset = copy_size + t;
+            if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); }
 #pragma omp parallel for
             for (size_t k = t; k < offset; k++) {
                 float grad = grads[k];
@@ -142,10 +143,8 @@ void Adam_Optimizer::Step(float* _params,
                 _exp_avg_sq[k] = variance;
             }
             if (dev_params) {
-                launch_param_update(_doubled_buffer[_buf_index],
-                                    dev_params + t,
-                                    (copy_size),
-                                    Context::Instance().GetCurrentStream());
+                launch_param_update(
+                    _doubled_buffer[_buf_index], dev_params + t, (copy_size), _streams[_buf_index]);
                 _buf_index = !_buf_index;
             }
         }
@@ -195,6 +194,7 @@ void Adam_Optimizer::Step_4(float* _params,
         size_t copy_size = TILE;
         if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
         size_t offset = copy_size + t;
+        if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); }
 #pragma omp parallel for
         for (size_t i = t; i < offset; i += (SIMD_WIDTH << 2)) {
             AVX_Data grad_4[4];
@@ -301,10 +301,8 @@ void Adam_Optimizer::Step_4(float* _params,
         }
 
         if (dev_params) {
-            launch_param_update(_doubled_buffer[_buf_index],
-                                dev_params + t,
-                                copy_size,
-                                Context::Instance().GetCurrentStream());
+            launch_param_update(
+                _doubled_buffer[_buf_index], dev_params + t, copy_size, _streams[_buf_index]);
             _buf_index = !_buf_index;
         }
     }
@@ -406,6 +404,7 @@ void Adam_Optimizer::Step_8(float* _params,
         size_t copy_size = TILE;
         if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
         size_t offset = copy_size + t;
+        if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); }
 #pragma omp parallel for
         for (size_t i = t; i < offset; i += (SIMD_WIDTH << 3)) {
             AVX_Data grad_4[8];
@@ -588,10 +587,8 @@ void Adam_Optimizer::Step_8(float* _params,
             SIMD_STORE(_exp_avg_sq + i + SIMD_WIDTH * 7, variance_4[7].data);
         }
         if (dev_params) {
-            launch_param_update(_doubled_buffer[_buf_index],
-                                dev_params + t,
-                                copy_size,
-                                Context::Instance().GetCurrentStream());
+            launch_param_update(
+                _doubled_buffer[_buf_index], dev_params + t, copy_size, _streams[_buf_index]);
             _buf_index = !_buf_index;
         }
     }
@@ -634,6 +631,7 @@ int ds_adam_step(int optimizer_id,
     opt->update_state(lr, epsilon, weight_decay, bias_correction);
     opt->Step_8(params_ptr, grads_ptr, exp_avg_ptr, exp_avg_sq_ptr, params_c.size(0));
 
+    for (int i = 0; i < 2; i++) cudaStreamSynchronize(opt->_streams[i]);
     return 0;
 }
 
@@ -670,6 +668,7 @@ int ds_adam_step_plus_copy(int optimizer_id,
     opt->Step_8(
         params_ptr, grads_ptr, exp_avg_ptr, exp_avg_sq_ptr, params_c.size(0), gpu_params_ptr);
 
+    for (int i = 0; i < 2; i++) cudaStreamSynchronize(opt->_streams[i]);
     return 0;
 }
 
diff --git a/csrc/includes/context.h b/csrc/includes/context.h
old mode 100755
new mode 100644
index c2e26cdfa708..5f0424116546
--- a/csrc/includes/context.h
+++ b/csrc/includes/context.h
@@ -81,6 +81,8 @@ class Context {
         return stream;
     }
 
+    cudaStream_t GetNewStream() { return at::cuda::getStreamFromPool(); }
+
     cublasHandle_t GetCublasHandle() { return _cublasHandle; }
 
     std::pair<uint64_t, uint64_t> IncrementOffset(uint64_t offset_inc)
diff --git a/csrc/includes/cpu_adam.h b/csrc/includes/cpu_adam.h
index 0f45409186c1..b647b740a91f 100755
--- a/csrc/includes/cpu_adam.h
+++ b/csrc/includes/cpu_adam.h
@@ -65,6 +65,9 @@ class Adam_Optimizer {
     {
         cudaMallocHost((void**)_doubled_buffer, TILE * sizeof(float));
         cudaMallocHost((void**)(_doubled_buffer + 1), TILE * sizeof(float));
+
+        _streams[0] = Context::Instance().GetCurrentStream();
+        _streams[1] = Context::Instance().GetNewStream();
     }
     ~Adam_Optimizer()
     {
@@ -152,4 +155,6 @@ class Adam_Optimizer {
     float* _doubled_buffer[2];
     bool _buf_index;
     bool _adamw_mode;
+
+    cudaStream_t _streams[2];
 };

From e347b6487c9131ca5d5a2152ba7e506df9849907 Mon Sep 17 00:00:00 2001
From: Reza Yazdani
Date: Mon, 8 Feb 2021 20:45:59 +0000
Subject: [PATCH 3/3] move stream synchronize in class method

---
 csrc/adam/cpu_adam.cpp   | 4 ++--
 csrc/includes/cpu_adam.h | 5 ++++-
 2 files changed, 6 insertions(+), 3 deletions(-)

diff --git a/csrc/adam/cpu_adam.cpp b/csrc/adam/cpu_adam.cpp
index 471155243ac9..d425dc3169ef 100644
--- a/csrc/adam/cpu_adam.cpp
+++ b/csrc/adam/cpu_adam.cpp
@@ -631,7 +631,7 @@ int ds_adam_step(int optimizer_id,
     opt->update_state(lr, epsilon, weight_decay, bias_correction);
     opt->Step_8(params_ptr, grads_ptr, exp_avg_ptr, exp_avg_sq_ptr, params_c.size(0));
 
-    for (int i = 0; i < 2; i++) cudaStreamSynchronize(opt->_streams[i]);
+    opt->SynchronizeStreams();
     return 0;
 }
 
@@ -668,7 +668,7 @@ int ds_adam_step_plus_copy(int optimizer_id,
     opt->Step_8(
         params_ptr, grads_ptr, exp_avg_ptr, exp_avg_sq_ptr, params_c.size(0), gpu_params_ptr);
 
-    for (int i = 0; i < 2; i++) cudaStreamSynchronize(opt->_streams[i]);
+    opt->SynchronizeStreams();
     return 0;
 }
 
diff --git a/csrc/includes/cpu_adam.h b/csrc/includes/cpu_adam.h
index b647b740a91f..5fae35261f55 100755
--- a/csrc/includes/cpu_adam.h
+++ b/csrc/includes/cpu_adam.h
@@ -92,7 +92,10 @@ class Adam_Optimizer {
                 float* _exp_avg_sq,
                 size_t _param_size,
                 __half* dev_params = nullptr);
-
+    inline void SynchronizeStreams()
+    {
+        for (int i = 0; i < 2; i++) cudaStreamSynchronize(_streams[i]);
+    }
    inline void IncrementStep(size_t step, float beta1, float beta2)
    {
        if (beta1 != _betta1 || beta2 != _betta2) {
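The pattern these three patches build up is a tiled, double-buffered parameter copy: each TILE-sized chunk of updated parameters is staged in one of two pinned host buffers and pushed to the GPU on its own CUDA stream, and before a buffer is reused the stream that last consumed it is synchronized (the `if ((t / TILE) >= 2)` guard). The standalone sketch below illustrates that scheme outside of DeepSpeed; the buffer and stream names are illustrative, and a plain cudaMemcpyAsync stands in for the launch_param_update kernel.

// Sketch only (not DeepSpeed code): tiled host->device copy with two pinned
// ping-pong buffers and two CUDA streams; error checking omitted for brevity.
#include <cuda_runtime.h>
#include <cstdio>
#include <vector>

int main()
{
    const size_t TILE = 1 << 20;         // elements per tile
    const size_t N = (1 << 22) + 12345;  // total elements, not a multiple of TILE

    std::vector<float> params(N, 1.0f);  // CPU-side parameters being "updated"
    float* dev_params = nullptr;
    cudaMalloc((void**)&dev_params, N * sizeof(float));

    float* buf[2];                       // pinned ping-pong staging buffers
    cudaMallocHost((void**)&buf[0], TILE * sizeof(float));
    cudaMallocHost((void**)&buf[1], TILE * sizeof(float));

    cudaStream_t streams[2];
    cudaStreamCreate(&streams[0]);
    cudaStreamCreate(&streams[1]);

    bool buf_index = false;
    for (size_t t = 0; t < N; t += TILE) {
        size_t copy_size = (t + TILE > N) ? (N - t) : TILE;

        // Before reusing a buffer, wait for the copy issued two tiles ago
        // (the last one that used this buffer/stream) to finish.
        if ((t / TILE) >= 2) cudaStreamSynchronize(streams[buf_index]);

        // "Optimizer step" for this tile: write results into the pinned buffer.
        for (size_t k = 0; k < copy_size; k++) buf[buf_index][k] = params[t + k] * 0.5f;

        // Push the tile to the GPU asynchronously while the CPU moves on.
        cudaMemcpyAsync(dev_params + t,
                        buf[buf_index],
                        copy_size * sizeof(float),
                        cudaMemcpyHostToDevice,
                        streams[buf_index]);
        buf_index = !buf_index;
    }

    // Drain both streams before the buffers go away (what SynchronizeStreams() does).
    for (int i = 0; i < 2; i++) cudaStreamSynchronize(streams[i]);

    printf("copied %zu elements in tiles of %zu\n", N, TILE);

    cudaStreamDestroy(streams[0]);
    cudaStreamDestroy(streams[1]);
    cudaFreeHost(buf[0]);
    cudaFreeHost(buf[1]);
    cudaFree(dev_params);
    return 0;
}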