diff --git a/paddle/fluid/operators/optimizers/lars_momentum_op.cu b/paddle/fluid/operators/optimizers/lars_momentum_op.cu
index a971bfa7d6efe..4faadb9e91bf7 100644
--- a/paddle/fluid/operators/optimizers/lars_momentum_op.cu
+++ b/paddle/fluid/operators/optimizers/lars_momentum_op.cu
@@ -28,119 +28,119 @@ limitations under the License. */
 #define LARS_BLOCK_SIZE 512
 #endif
 
+#if CUDA_VERSION >= 11000
+#define L2_NORM_FUNCTION_FLAG __device__
+#else
+#define L2_NORM_FUNCTION_FLAG __global__
+#endif
+
 namespace paddle {
 namespace operators {
 
 template <typename T>
 using MultiPrecisionType = typename details::MPTypeTrait<T>::Type;
 
-__device__ __forceinline__ float square_root(float x) { return sqrtf(x); }
-__device__ __forceinline__ double square_root(double x) { return sqrt(x); }
-__device__ __forceinline__ float fma_root(float x, float y, float z) {
+__device__ __forceinline__ float SquareRoot(float x) { return sqrtf(x); }
+__device__ __forceinline__ double SquareRoot(double x) { return sqrt(x); }
+__device__ __forceinline__ float FmaRoot(float x, float y, float z) {
   return fmaf(x, y, z);
 }
-__device__ __forceinline__ double fma_root(double x, double y, double z) {
+__device__ __forceinline__ double FmaRoot(double x, double y, double z) {
   return fma(x, y, z);
 }
 
-#if CUDA_VERSION >= 11000
-template <typename T, typename MT>
-__device__ inline MT L2NormCalculation(
-    const cooperative_groups::grid_group& cg, const T* __restrict__ data,
-    MT* tmp_buffer, int tid, const int repeat_times, const int grid_stride,
-    const int64_t numel, const MT rescale_grad = static_cast<MT>(1)) {
-  MT rescale_grad_pow = rescale_grad * rescale_grad;
-  __shared__ MT s_buffer;
-  s_buffer = static_cast<MT>(0);
-
-  MT tmp_val = static_cast<MT>(0);
-  if (repeat_times == 1) {
-    if (tid < numel) {
-      tmp_val = static_cast<MT>(data[tid]);
-    }
-    s_buffer += math::blockReduceSum<MT>(tmp_val * tmp_val, FINAL_MASK);
-  } else {
-    for (int i = 0; i < repeat_times - 1; ++i) {
-      if (tid < numel) {
-        tmp_val = static_cast<MT>(data[tid]);
-      }
-      tid += grid_stride;
-      s_buffer += math::blockReduceSum<MT>(tmp_val * tmp_val, FINAL_MASK);
-      __syncthreads();
-    }
-    MT val = static_cast<MT>(0);
-    if (tid < numel) {
-      val = static_cast<MT>(data[tid]);
+template <typename MT, int VecSize>
+__device__ inline void VectorizeLarsUpdate(
+    const MT* __restrict__ g, const MT* __restrict__ v, MT* __restrict__ p_out,
+    MT* __restrict__ v_out, const MT* __restrict__ p, const MT mu, MT local_lr,
+    const MT lars_weight_decay, const MT rescale_grad, const int tid,
+    const int grid_stride, const int numel) {
+  using VecMType = paddle::platform::AlignedVector<MT, VecSize>;
+  int main = numel >> (VecSize >> 1);
+  int tail_offset = main * VecSize;
+
+  const VecMType* __restrict__ g_vec = reinterpret_cast<const VecMType*>(g);
+  const VecMType* __restrict__ v_vec = reinterpret_cast<const VecMType*>(v);
+  const VecMType* __restrict__ p_vec = reinterpret_cast<const VecMType*>(p);
+  VecMType* p_out_vec = reinterpret_cast<VecMType*>(p_out);
+  VecMType* v_out_vec = reinterpret_cast<VecMType*>(v_out);
+
+  for (int i = tid; i < main; i += grid_stride) {
+    VecMType v_new, p_new;
+    VecMType g_data = g_vec[i];
+    VecMType v_data = v_vec[i];
+    VecMType p_data = p_vec[i];
+
+#pragma unroll
+    for (int j = 0; j < VecSize; ++j) {
+      MT grad = g_data.val[j] * rescale_grad;
+      v_new.val[j] =
+          FmaRoot(v_data.val[j], mu,
+                  local_lr * FmaRoot(lars_weight_decay, p_data.val[j], grad));
+      p_new.val[j] = p_data.val[j] - v_new.val[j];
     }
-    s_buffer += math::blockReduceSum<MT>(val * val, FINAL_MASK);
+    v_out_vec[i] = v_new;
+    p_out_vec[i] = p_new;
   }
-  __syncthreads();
-  if (threadIdx.x == 0) {
-    tmp_buffer[blockIdx.x] = s_buffer;
+  for (int i = tid + tail_offset; i < numel; i += grid_stride) {
+    MT grad = g[i] * rescale_grad;
+    MT param = p[i];
+    MT v_new =
+        FmaRoot(v[i], mu, local_lr * FmaRoot(lars_weight_decay, param, grad));
+    v_out[i] = v_new;
+    p_out[i] = param - v_new;
   }
-  // Grid sync for completely writring partial result back to gloabl memory
-  cg.sync();
-  MT val = threadIdx.x < gridDim.x ? tmp_buffer[threadIdx.x] : 0;
-  MT result = math::blockReduceSum<MT>(val, FINAL_MASK);
-  return square_root(rescale_grad_pow * result);
 }
-#endif
 
 template <typename T, typename MT>
-__device__ inline void LarsUpdateMP(const T* __restrict__ g,
-                                    const MT* __restrict__ v, T* p_out,
-                                    MT* v_out, const MT* __restrict__ master_p,
-                                    MT* __restrict__ master_p_out, const MT mu,
-                                    MT local_lr, const MT lars_weight_decay,
-                                    const MT rescale_grad, const int tid,
-                                    const int grid_stride, const int numel) {
+__device__ inline void VectorizeLarsUpdateMP(
+    const T* __restrict__ g, const MT* __restrict__ v, T* __restrict__ p_out,
+    MT* __restrict__ v_out, const MT* __restrict__ master_p,
+    MT* __restrict__ master_p_out, const MT mu, MT local_lr,
+    const MT lars_weight_decay, const MT rescale_grad, const int tid,
+    const int grid_stride, const int numel) {
   // As for multiple-precision, type T and MT cannot be more than fp16 or fp32,
   // Then, the maximum data IO size could be set to 4.
-  using VecType = platform::AlignedVector<T, 4>;
-  using VecMType = platform::AlignedVector<MT, 4>;
+  using VecType = paddle::platform::AlignedVector<T, 4>;
+  using VecMType = paddle::platform::AlignedVector<MT, 4>;
   int main = numel >> 2;
   int tail_offset = main << 2;
 
-  const VecType* g_4 = reinterpret_cast<const VecType*>(g);
-  const VecMType* v_4 = reinterpret_cast<const VecMType*>(v);
-  const VecMType* master_p_4 = reinterpret_cast<const VecMType*>(master_p);
-  VecType* p_out_4 = reinterpret_cast<VecType*>(p_out);
-  VecMType* v_out_4 = reinterpret_cast<VecMType*>(v_out);
-  VecMType* master_p_out_4 = reinterpret_cast<VecMType*>(master_p_out);
-
-  for (int i = tid; i < main; i += LARS_BLOCK_SIZE * gridDim.x) {
-    VecType p_out_tmp;
-    VecMType v_new_tmp, master_p_new_tmp;
-
-    VecType g_tmp = g_4[i];
-    VecMType v_tmp = v_4[i];
-    VecMType p_tmp = master_p_4[i];
-    T* g_data = reinterpret_cast<T*>(&g_tmp);
-    T* p_out_data = reinterpret_cast<T*>(&p_out_tmp);
-    MT* v_data = reinterpret_cast<MT*>(&v_tmp);
-    MT* p_data = reinterpret_cast<MT*>(&p_tmp);
-    MT* v_new = reinterpret_cast<MT*>(&v_new_tmp);
-    MT* p_new = reinterpret_cast<MT*>(&master_p_new_tmp);
+  const VecType* __restrict__ g_vec = reinterpret_cast<const VecType*>(g);
+  const VecMType* __restrict__ v_vec = reinterpret_cast<const VecMType*>(v);
+  const VecMType* __restrict__ master_p_vec =
+      reinterpret_cast<const VecMType*>(master_p);
+  VecType* p_out_vec = reinterpret_cast<VecType*>(p_out);
+  VecMType* v_out_vec = reinterpret_cast<VecMType*>(v_out);
+  VecMType* master_p_out_vec = reinterpret_cast<VecMType*>(master_p_out);
+
+  for (int i = tid; i < main; i += grid_stride) {
+    VecType p_out;
+    VecMType v_new, p_new;
+    VecType g_data = g_vec[i];
+    VecMType v_data = v_vec[i];
+    VecMType p_data = master_p_vec[i];
 
 #pragma unroll
     for (int j = 0; j < 4; ++j) {
-      MT grad = static_cast<MT>(g_data[j]) * rescale_grad;
-      v_new[j] = fma_root(v_data[j], mu, local_lr * fma_root(lars_weight_decay,
-                                                             p_data[j], grad));
-      p_new[j] = p_data[j] - v_new[j];
-      p_out_data[j] = static_cast<T>(p_new[j]);
+      MT grad = static_cast<MT>(g_data.val[j]) * rescale_grad;
+      v_new.val[j] =
+          FmaRoot(v_data.val[j], mu,
+                  local_lr * FmaRoot(lars_weight_decay, p_data.val[j], grad));
+      p_new.val[j] = p_data.val[j] - v_new.val[j];
+      p_out.val[j] = static_cast<T>(p_new.val[j]);
     }
-    v_out_4[i] = v_new_tmp;
-    p_out_4[i] = p_out_tmp;
-    master_p_out_4[i] = master_p_new_tmp;
+    v_out_vec[i] = v_new;
+    p_out_vec[i] = p_out;
+    master_p_out_vec[i] = p_new;
   }
 
   for (int i = tid + tail_offset; i < numel; i += grid_stride) {
     MT grad = static_cast<MT>(g[i]) * rescale_grad;
     MT param = master_p[i];
     MT v_new =
-        fma_root(v[i], mu, local_lr * fma_root(lars_weight_decay, param, grad));
+        FmaRoot(v[i], mu, local_lr * FmaRoot(lars_weight_decay, param, grad));
     MT p_new = param - v_new;
     v_out[i] = v_new;
     p_out[i] = static_cast<T>(p_new);
@@ -148,48 +148,117 @@ __device__ inline void LarsUpdateMP(const T* __restrict__ g,
   }
 }
 
+template <typename T, typename MT>
+L2_NORM_FUNCTION_FLAG inline void L2NormKernel(
+    const T* __restrict__ p_data, const T* __restrict__ g_data,
+    MT* __restrict__ p_tmp_buffer, MT* __restrict__ g_tmp_buffer,
+    const int repeat_times, const int64_t numel, const MT rescale_grad,
+    MT* __restrict__ p_n = nullptr, MT* __restrict__ g_n = nullptr) {
+  int tid = threadIdx.x + blockDim.x * blockIdx.x;
+  int grid_stride = LARS_BLOCK_SIZE * gridDim.x;
+  const MT rescale_grad_pow = rescale_grad * rescale_grad;
+  __shared__ MT s_buffer[2];
+  s_buffer[0] = static_cast<MT>(0);
+  s_buffer[1] = static_cast<MT>(0);
+  MT p_tmp_val = static_cast<MT>(0);
+  MT g_tmp_val = static_cast<MT>(0);
+
+  if (repeat_times == 0) {
+    if (tid < numel) {
+      p_tmp_val = static_cast<MT>(p_data[tid]);
+      g_tmp_val = static_cast<MT>(g_data[tid]);
+    }
+    s_buffer[0] += math::blockReduceSum<MT>(p_tmp_val * p_tmp_val, FINAL_MASK);
+    s_buffer[1] += math::blockReduceSum<MT>(g_tmp_val * g_tmp_val, FINAL_MASK);
+  } else {
+    /* To avoid occupying too much temp buffer, the data is sliced into two
+    parts: the front part, whose size is exactly a multiple of the total
+    grid-thread number, is handled in the for loop below; the rest is handled
+    in an extra step so that no address beyond the bound is visited. */
+    for (int i = 0; i < repeat_times; ++i) {
+      p_tmp_val = static_cast<MT>(p_data[tid]);
+      g_tmp_val = static_cast<MT>(g_data[tid]);
+      tid += grid_stride;
+      s_buffer[0] +=
+          math::blockReduceSum<MT>(p_tmp_val * p_tmp_val, FINAL_MASK);
+      s_buffer[1] +=
+          math::blockReduceSum<MT>(g_tmp_val * g_tmp_val, FINAL_MASK);
+      __syncthreads();
+    }
+    MT p_val = 0;
+    MT g_val = 0;
+    if (tid < numel) {
+      p_val = static_cast<MT>(p_data[tid]);
+      g_val = static_cast<MT>(g_data[tid]);
+    }
+    s_buffer[0] += math::blockReduceSum<MT>(p_val * p_val, FINAL_MASK);
+    s_buffer[1] += math::blockReduceSum<MT>(g_val * g_val, FINAL_MASK);
+  }
+  __syncthreads();
+
+  if (threadIdx.x == 0) {
+    p_tmp_buffer[blockIdx.x] = s_buffer[0];
+    g_tmp_buffer[blockIdx.x] = rescale_grad_pow * s_buffer[1];
+  }
+
+#if CUDA_VERSION >= 11000
+  // Grid sync for completely writing partial results back to global memory
+  const cooperative_groups::grid_group cg = cooperative_groups::this_grid();
+  cg.sync();
+  MT p_partial_sum = threadIdx.x < gridDim.x ? p_tmp_buffer[threadIdx.x] : 0;
+  MT g_partial_sum = threadIdx.x < gridDim.x ? g_tmp_buffer[threadIdx.x] : 0;
+  *p_n = SquareRoot(math::blockReduceSum<MT>(p_partial_sum, FINAL_MASK));
+  *g_n = SquareRoot(math::blockReduceSum<MT>(g_partial_sum, FINAL_MASK));
+#endif
+}
+
 template <typename T, typename MT>
 __global__ void MomentumLarsKernel(
     const T* __restrict__ p, const T* __restrict__ g, const MT* __restrict__ v,
     T* p_out, MT* v_out, const MT* __restrict__ master_p,
     MT* __restrict__ master_p_out, const MT* __restrict__ learning_rate,
-    MT* __restrict__ tmp_buffer, MT* __restrict__ tmp_buffer_2, const MT mu,
+    MT* __restrict__ p_tmp_buffer, MT* __restrict__ g_tmp_buffer, const MT mu,
     const MT lars_coeff, const MT lars_weight_decay, const MT epsilon,
-    const MT rescale_grad, const int64_t numel) {
+    const MT rescale_grad, const int repeat_times, const int thresh,
+    const int64_t numel) {
   int tid = threadIdx.x + blockIdx.x * blockDim.x;
   int grid_stride = gridDim.x * LARS_BLOCK_SIZE;
 #if CUDA_VERSION >= 11000
-  const cooperative_groups::grid_group cg = cooperative_groups::this_grid();
-  const int repeat_times = (numel + grid_stride - 1) / grid_stride;
-  MT p_n = L2NormCalculation(cg, p, tmp_buffer, tid, repeat_times,
-                             grid_stride, numel);
-  MT g_n = L2NormCalculation(cg, g, tmp_buffer, tid, repeat_times,
-                             grid_stride, numel, rescale_grad);
+  MT p_n = static_cast<MT>(0);
+  MT g_n = static_cast<MT>(0);
+  L2NormKernel(p, g, p_tmp_buffer, g_tmp_buffer, repeat_times, numel,
+               rescale_grad, &p_n, &g_n);
 #else
-  const MT p_n = tmp_buffer[0];
-  const MT g_n = tmp_buffer_2[0];
+  MT p_val = threadIdx.x < thresh ? p_tmp_buffer[threadIdx.x] : 0;
+  MT g_val = threadIdx.x < thresh ? g_tmp_buffer[threadIdx.x] : 0;
+  __syncthreads();
+  MT p_n = SquareRoot(math::blockReduceSum<MT>(p_val, FINAL_MASK));
+  MT g_n = SquareRoot(math::blockReduceSum<MT>(g_val, FINAL_MASK));
 #endif
+
   const MT lr = learning_rate[0];
   MT local_lr = lr;
-  if (lars_weight_decay > static_cast<MT>(0) && p_n > static_cast<MT>(0) &&
-      g_n > static_cast<MT>(0)) {
+  if (lars_weight_decay > static_cast<MT>(0)) {
     local_lr = lr * lars_coeff * p_n /
-               (fma_root(lars_weight_decay, p_n, g_n) + epsilon);
+               (FmaRoot(lars_weight_decay, p_n, g_n) + epsilon);
   }
 
   if (master_p) {
-    LarsUpdateMP(g, v, p_out, v_out, master_p, master_p_out, mu,
-                 local_lr, lars_weight_decay, rescale_grad, tid,
-                 grid_stride, numel);
+    VectorizeLarsUpdateMP(g, v, p_out, v_out, master_p, master_p_out, mu,
+                          local_lr, lars_weight_decay, rescale_grad, tid,
+                          grid_stride, numel);
   } else {
-    for (int i = tid; i < numel; i += grid_stride) {
-      MT grad = static_cast<MT>(g[i]) * rescale_grad;
-      MT param = static_cast<MT>(p[i]);
-      MT v_new = fma_root(v[i], mu,
-                          local_lr * fma_root(lars_weight_decay, param, grad));
-      MT p_new = param - v_new;
-      v_out[i] = v_new;
-      p_out[i] = static_cast<T>(p_new);
+    if (std::is_same<T, float>::value ||
+        std::is_same<T, paddle::platform::float16>::value) {
+      VectorizeLarsUpdate<MT, 4>(
+          reinterpret_cast<const MT*>(g), v, reinterpret_cast<MT*>(p_out),
+          v_out, reinterpret_cast<const MT*>(p), mu, local_lr,
+          lars_weight_decay, rescale_grad, tid, grid_stride, numel);
+    } else {
+      VectorizeLarsUpdate<MT, 2>(
+          reinterpret_cast<const MT*>(g), v, reinterpret_cast<MT*>(p_out),
+          v_out, reinterpret_cast<const MT*>(p), mu, local_lr,
+          lars_weight_decay, rescale_grad, tid, grid_stride, numel);
     }
   }
 }
@@ -201,12 +270,6 @@ class LarsMomentumOpCUDAKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
     const bool multi_precision = ctx.Attr<bool>("multi_precision");
-    InnerCompute(ctx, multi_precision);
-  }
-
- private:
-  void InnerCompute(const framework::ExecutionContext& ctx,
-                    const bool multi_precision) const {
     auto param_out = ctx.Output<framework::LoDTensor>("ParamOut");
     auto velocity_out = ctx.Output<framework::LoDTensor>("VelocityOut");
     auto param = ctx.Input<framework::LoDTensor>("Param");
@@ -214,6 +277,8 @@ class LarsMomentumOpCUDAKernel : public framework::OpKernel<T> {
     auto grad = ctx.Input<framework::LoDTensor>("Grad");
     auto learning_rate = ctx.Input<framework::LoDTensor>("LearningRate");
 
+    int64_t numel = param->numel();
+    int grid = (numel + LARS_BLOCK_SIZE - 1) / LARS_BLOCK_SIZE;
     const framework::Tensor* master_param = nullptr;
     framework::Tensor* master_param_out = nullptr;
     const MT* master_p = nullptr;
@@ -246,9 +311,7 @@ class LarsMomentumOpCUDAKernel : public framework::OpKernel<T> {
     auto* g = grad->data<T>();
     auto* v = velocity->data<MT>();
    auto* lr = learning_rate->data<MT>();
-    int64_t numel = param->numel();
-    int grid = (numel + LARS_BLOCK_SIZE - 1) / LARS_BLOCK_SIZE;
-    auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
+    auto& cuda_ctx = ctx.template device_context<platform::CUDADeviceContext>();
 
 #if CUDA_VERSION >= 11000
     /*
@@ -266,7 +329,7 @@ class LarsMomentumOpCUDAKernel : public framework::OpKernel<T> {
       lanuch, essential basis is to control all grid-threads while running.
      Apart from normal lanuch form, cuda9.0 provides
      `cudaLaunchCooperativeKernel` api :
-      - The thread quantity shall equal to pyhsical SM limited threads
+      - The thread quantity shall be less than the physical SM thread limit
       - Launches a device function where thread blocks can cooperate and
        synchronize as they execute.
     */
@@ -275,13 +338,17 @@ class LarsMomentumOpCUDAKernel : public framework::OpKernel<T> {
     cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks_per_sm,
                                                   MomentumLarsKernel<T, MT>,
                                                   LARS_BLOCK_SIZE, sizeof(MT));
-    int sm_num = dev_ctx.GetSMCount();
-    int grid_real = std::min(sm_num * num_blocks_per_sm, grid);
+    int sm_num = cuda_ctx.GetSMCount();
+    int grid_real =
+        std::min(std::min(sm_num * num_blocks_per_sm, grid), LARS_BLOCK_SIZE);
     framework::Tensor tmp_buffer_t =
-        ctx.AllocateTmpTensor<MT, platform::CUDADeviceContext>({grid_real},
-                                                               dev_ctx);
-    auto* tmp_buffer = tmp_buffer_t.mutable_data<MT>(ctx.GetPlace());
-    MT* tmp_buffer_2 = nullptr;
+        ctx.AllocateTmpTensor<MT, platform::CUDADeviceContext>(
+            {LARS_BLOCK_SIZE << 1}, cuda_ctx);
+    auto* p_tmp_buffer = tmp_buffer_t.mutable_data<MT>(ctx.GetPlace());
+    auto* g_tmp_buffer = p_tmp_buffer + LARS_BLOCK_SIZE;
+    int grid_stride = LARS_BLOCK_SIZE * grid;
+    int repeat_times = (numel + grid_stride - 1) / grid_stride - 1;
+    int thresh = 0;
 
     // Uniform kernel parameter for cudaLaunchCooperativeKernel
     void* cuda_param[] = {
@@ -293,41 +360,46 @@ class LarsMomentumOpCUDAKernel : public framework::OpKernel<T> {
         reinterpret_cast<void*>(&master_p),
         reinterpret_cast<void*>(&master_p_out),
         reinterpret_cast<void*>(&lr),
-        reinterpret_cast<void*>(&tmp_buffer),
-        reinterpret_cast<void*>(&tmp_buffer_2),  // Just a placeholder
+        reinterpret_cast<void*>(&p_tmp_buffer),
+        reinterpret_cast<void*>(&g_tmp_buffer),
         reinterpret_cast<void*>(&mu),
         reinterpret_cast<void*>(&lars_coeff),
         reinterpret_cast<void*>(&lars_weight_decay),
         reinterpret_cast<void*>(&epsilon),
         reinterpret_cast<void*>(&rescale_grad),
+        reinterpret_cast<void*>(&repeat_times),
+        reinterpret_cast<void*>(&thresh),
        reinterpret_cast<void*>(&numel),
     };
     // Lanuch all sm theads.
     cudaLaunchCooperativeKernel(
        reinterpret_cast<void*>(MomentumLarsKernel<T, MT>), grid_real,
-        LARS_BLOCK_SIZE, cuda_param, 0, dev_ctx.stream());
+        LARS_BLOCK_SIZE, cuda_param, 0, cuda_ctx.stream());
 #else
-    auto eigen_p = framework::EigenVector<T>::Flatten(*param);
-    auto eigen_g = framework::EigenVector<T>::Flatten(*grad);
-    // calculate norms using eigein and launch the kernel.
-    framework::Tensor p_norm_t =
-        ctx.AllocateTmpTensor<MT, platform::CUDADeviceContext>({1}, dev_ctx);
-    framework::Tensor g_norm_t =
-        ctx.AllocateTmpTensor<MT, platform::CUDADeviceContext>({1}, dev_ctx);
-    auto* p_norm_data = p_norm_t.mutable_data<MT>(ctx.GetPlace());
-    auto* g_norm_data = g_norm_t.mutable_data<MT>(ctx.GetPlace());
-    auto ep_norm = framework::EigenScalar<MT>::From(p_norm_t);
-    auto eg_norm = framework::EigenScalar<MT>::From(g_norm_t);
-    auto* place = dev_ctx.eigen_device();
-    // eigen unsupport fp16 l2-norm
-    ep_norm.device(*place) = eigen_p.template cast<MT>().square().sum().sqrt();
-    eg_norm.device(*place) =
-        (eigen_g.template cast<MT>() * rescale_grad).square().sum().sqrt();
-
-    MomentumLarsKernel<<<grid, LARS_BLOCK_SIZE, 0, dev_ctx.stream()>>>(
-        p, g, v, p_out, v_out, master_p, master_p_out, lr, p_norm_data,
-        g_norm_data, mu, lars_coeff, lars_weight_decay, epsilon, rescale_grad,
-        numel);
+    // Determine to read 4 fp16 or float values at once, but only 2 doubles.
+    int grid_lars =
+        sizeof(T) < sizeof(double)
+            ? (numel + (LARS_BLOCK_SIZE << 2) - 1) / (LARS_BLOCK_SIZE << 2)
+            : (numel + (LARS_BLOCK_SIZE << 1) - 1) / (LARS_BLOCK_SIZE << 1);
+
+    grid = std::min(grid, LARS_BLOCK_SIZE);
+    framework::Tensor p_buffer_t =
+        ctx.AllocateTmpTensor<MT, platform::CUDADeviceContext>(
+            {LARS_BLOCK_SIZE << 1}, cuda_ctx);
+    auto* p_tmp_buffer = p_buffer_t.mutable_data<MT>(ctx.GetPlace());
+    auto* g_tmp_buffer = p_tmp_buffer + LARS_BLOCK_SIZE;
+
+    const int grid_stride = LARS_BLOCK_SIZE * grid;
+    const int repeat_times = (numel + grid_stride - 1) / grid_stride - 1;
+
+    L2NormKernel<<<grid, LARS_BLOCK_SIZE, 0, cuda_ctx.stream()>>>(
+        p, g, p_tmp_buffer, g_tmp_buffer, repeat_times, numel, rescale_grad);
+
+    MomentumLarsKernel<
+        T, MT><<<grid_lars, LARS_BLOCK_SIZE, 0, cuda_ctx.stream()>>>(
+        p, g, v, p_out, v_out, master_p, master_p_out, lr, p_tmp_buffer,
+        g_tmp_buffer, mu, lars_coeff, lars_weight_decay, epsilon, rescale_grad,
+        repeat_times, grid, numel);
 #endif
   }
 };
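Reviewer note (not part of the patch): the CUDA >= 11000 path above depends on a grid-wide reduction that is only legal under a cooperative launch. Below is a minimal standalone sketch of that pattern under the same constraint the host code enforces (grid size capped by cudaOccupancyMaxActiveBlocksPerMultiprocessor times the SM count). The kernel and buffer names (GridL2NormKernel, block_sums) are illustrative only, and grid synchronization typically requires building with relocatable device code (-rdc=true).

// Illustrative sketch only -- not part of lars_momentum_op.cu.
// Assumed build command: nvcc -arch=sm_70 -rdc=true coop_l2norm.cu
#include <cooperative_groups.h>
#include <cmath>
#include <cstdio>

namespace cg = cooperative_groups;

constexpr int kBlockSize = 256;

// Stage 1: every block reduces its grid-stride slice of x*x into block_sums.
// Stage 2: grid.sync() makes all partial sums visible, then one thread folds
// them, mirroring how L2NormKernel hands p_n / g_n back to MomentumLarsKernel.
__global__ void GridL2NormKernel(const float* x, float* block_sums, float* out,
                                 int n) {
  cg::grid_group grid = cg::this_grid();
  __shared__ float smem[kBlockSize];

  float local = 0.f;
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += gridDim.x * blockDim.x) {
    local += x[i] * x[i];
  }
  smem[threadIdx.x] = local;
  __syncthreads();
  for (int s = blockDim.x / 2; s > 0; s >>= 1) {  // in-block tree reduction
    if (threadIdx.x < s) smem[threadIdx.x] += smem[threadIdx.x + s];
    __syncthreads();
  }
  if (threadIdx.x == 0) block_sums[blockIdx.x] = smem[0];

  grid.sync();  // every block's partial sum is now globally visible

  if (blockIdx.x == 0 && threadIdx.x == 0) {
    float total = 0.f;
    for (int b = 0; b < gridDim.x; ++b) total += block_sums[b];
    *out = sqrtf(total);
  }
}

int main() {
  int n = 1 << 20;
  float *x, *block_sums, *out;
  cudaMallocManaged(&x, n * sizeof(float));
  for (int i = 0; i < n; ++i) x[i] = 1.f;

  // A cooperative launch must not exceed the number of co-resident blocks.
  int blocks_per_sm = 0, sm_count = 0;
  cudaOccupancyMaxActiveBlocksPerMultiprocessor(&blocks_per_sm, GridL2NormKernel,
                                                kBlockSize, 0);
  cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, 0);
  int grid = blocks_per_sm * sm_count;

  cudaMallocManaged(&block_sums, grid * sizeof(float));
  cudaMallocManaged(&out, sizeof(float));

  void* args[] = {&x, &block_sums, &out, &n};
  cudaLaunchCooperativeKernel(reinterpret_cast<void*>(GridL2NormKernel), grid,
                              kBlockSize, args, 0, nullptr);
  cudaDeviceSynchronize();
  printf("l2 norm = %f (expected %f)\n", *out, sqrtf(static_cast<float>(n)));

  cudaFree(x);
  cudaFree(block_sums);
  cudaFree(out);
  return 0;
}

On devices that report no support for cooperative launch, the two-launch fallback in the #else branch of the patch (a __global__ L2NormKernel followed by MomentumLarsKernel re-reducing the per-block partials) is the portable alternative to grid.sync().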
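Reviewer note (not part of the patch): VectorizeLarsUpdate / VectorizeLarsUpdateMP get their bandwidth win from widening each memory transaction through paddle::platform::AlignedVector. The sketch below uses a hypothetical stand-in AlignedVector struct and a trivial ScaleKernel to show the same vectorized main-loop / scalar tail-loop split that the main, tail_offset, and grid-stride arithmetic above implements.

// Illustrative sketch only -- AlignedVector here is a stand-in, not Paddle's.
#include <cstdio>

template <typename T, int VecSize>
struct alignas(sizeof(T) * VecSize) AlignedVector {
  T val[VecSize];
};

template <typename T, int VecSize>
__global__ void ScaleKernel(const T* __restrict__ in, T* __restrict__ out,
                            T alpha, int numel) {
  using VecT = AlignedVector<T, VecSize>;
  int num_vec = numel / VecSize;        // elements covered by vector accesses
  int tail_offset = num_vec * VecSize;  // first element of the scalar tail
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int stride = gridDim.x * blockDim.x;

  const VecT* in_vec = reinterpret_cast<const VecT*>(in);
  VecT* out_vec = reinterpret_cast<VecT*>(out);
  for (int i = tid; i < num_vec; i += stride) {
    VecT v = in_vec[i];  // one wide load instead of VecSize scalar loads
#pragma unroll
    for (int j = 0; j < VecSize; ++j) v.val[j] *= alpha;
    out_vec[i] = v;
  }
  // Scalar tail: the last numel % VecSize elements.
  for (int i = tid + tail_offset; i < numel; i += stride) {
    out[i] = in[i] * alpha;
  }
}

int main() {
  int numel = 1003;  // deliberately not a multiple of 4 to exercise the tail
  float *in, *out;
  cudaMallocManaged(&in, numel * sizeof(float));
  cudaMallocManaged(&out, numel * sizeof(float));
  for (int i = 0; i < numel; ++i) in[i] = 1.f;

  ScaleKernel<float, 4><<<4, 256>>>(in, out, 2.f, numel);
  cudaDeviceSynchronize();
  printf("out[0]=%f out[%d]=%f\n", out[0], numel - 1, out[numel - 1]);

  cudaFree(in);
  cudaFree(out);
  return 0;
}

With VecSize = 4 and float data, each main-loop iteration issues one 128-bit load/store instead of four 32-bit ones, which is why the host code sizes grid_lars by LARS_BLOCK_SIZE << 2 for fp16/float but only LARS_BLOCK_SIZE << 1 for double.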