Lars op optimization with cudaLaunchCooperativeKernel method #35652

Merged

Conversation

@JamesLim-sy (Contributor) commented Sep 10, 2021

PR types

Performance optimization

PR changes

OPs

Describe

  • Conception
    Conceptually, the lars op consists of 2 steps:

    • Compute the L2 norms of grad and param respectively, which provide the fundamental hyper-parameters for the lars update.
    • Perform the lars update.
  • Drawback:
    Originally, the L2 norm was implemented with the combined eigen expressions below. As the code shows, each combined L2 norm relies on two eigen kernels, so while training with the lars optimizer, eigen kernels are called 4 times in total. Taking the lars update cuda kernel into account, 5 kernels were needed for a single lars op.

     eigen_p.template cast<MPDType>().square().sum().sqrt();
    (eigen_g.template cast<MPDType>() * rescale_grad).square().sum().sqrt();

[Screenshot, 2021-09-06 19:13]

  • Optimization method
  1. Instead of 4 eigen kernels, the whole L2 norm calculation is combined into only 1 cuda kernel.
  2. The L2 norm is acquired in two steps (a simplified sketch of the fused structure is attached at the end of this description):
    • (1) Each block computes a partial reduction of pow(grad, 2) and pow(param, 2) and stores the result into temporary global memory.
    • (2) The partial reduction results are read back from global memory and reduced to obtain the final L2 norm.
  3. By design, no more than 1024 * sizeof(DataType) of temporary memory is needed for the partial reduction in the L2 norm, so that not too much temporary memory is occupied.
  4. Apart from the L2 norm, the lars update part is optimized with vectorized IO to gain performance.
  5. When running on an A100 device, the whole lars op is combined into just 1 cuda kernel with the help of cudaLaunchCooperativeKernel. Otherwise, two kernels are needed: step (1) is handled by one kernel, while step (2) and the lars update are handled by another kernel.
  • Performance
  1. After optimization of the lars op without cudaLaunchCooperativeKernel, the training performance increases from 2969 ips to 3040 ips, and the total number of kernels needed shrinks from 5 to 2.
    [Screenshot]

  2. After optimization of the lars op with cudaLaunchCooperativeKernel, the training performance increases from 2969 ips to 3110.58 ips, and the total number of kernels needed shrinks from 5 to 1.

[Screenshot, 2021-09-11 17:05]

  • Note
    Besides cudaLaunchCooperativeKernel and DynamicParallelism, five other methods found on the internet were tried, but none of them worked well for combining CUDA kernels. So I recommend the two methods mentioned here; they truly may help.

  • Convergence
    The loss comparison while training with and without the perf-optimized lars kernel on one CUDA card is shown in the figure below:
    [Screenshot, 2021-09-24 15:11: loss comparison]

As for resnet50, it converged at epoch 35; the convergence stats are shown below:

loss 29.42091, acc1 0.75918, acc5 0.92844, time 17.7330 sec
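
To make the fused structure described above concrete, here is a minimal, self-contained sketch of the single-kernel approach (fp32 only; the name FusedLarsSketch and the simplified block-level reduction are hypothetical, and rescale_grad handling is omitted; this is not the PR's actual kernel):

// Hypothetical sketch of the fused single-kernel approach (illustration only).
#include <cooperative_groups.h>
namespace cg = cooperative_groups;

__global__ void FusedLarsSketch(const float* param, const float* grad,
                                float* velocity, float* param_out,
                                float* p_buffer, float* g_buffer, int numel,
                                float lr, float mu, float lars_coeff,
                                float lars_weight_decay, float epsilon) {
  cg::grid_group grid = cg::this_grid();
  __shared__ float s_p_sum, s_g_sum;
  if (threadIdx.x == 0) { s_p_sum = 0.f; s_g_sum = 0.f; }
  __syncthreads();

  // Step (1): per-block partial reduction of param^2 and grad^2.
  float p_sq = 0.f, g_sq = 0.f;
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel;
       i += gridDim.x * blockDim.x) {
    p_sq += param[i] * param[i];
    g_sq += grad[i] * grad[i];
  }
  atomicAdd(&s_p_sum, p_sq);
  atomicAdd(&s_g_sum, g_sq);
  __syncthreads();
  if (threadIdx.x == 0) {
    p_buffer[blockIdx.x] = s_p_sum;  // partial results in temporary
    g_buffer[blockIdx.x] = s_g_sum;  // global memory, one slot per block
  }
  grid.sync();  // requires launch via cudaLaunchCooperativeKernel

  // Step (2): every block reads back the partial results -> global L2 norms.
  float p_norm = 0.f, g_norm = 0.f;
  for (int i = 0; i < gridDim.x; ++i) {
    p_norm += p_buffer[i];
    g_norm += g_buffer[i];
  }
  p_norm = sqrtf(p_norm);
  g_norm = sqrtf(g_norm);
  float local_lr = lr;
  if (lars_weight_decay > 0.f && p_norm > 0.f && g_norm > 0.f) {
    local_lr = lr * lars_coeff * p_norm /
               (epsilon + g_norm + lars_weight_decay * p_norm);
  }

  // Lars update in the same launch (vectorized IO omitted for brevity).
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel;
       i += gridDim.x * blockDim.x) {
    float v = fmaf(velocity[i], mu,
                   local_lr * fmaf(lars_weight_decay, param[i], grad[i]));
    velocity[i] = v;
    param_out[i] = param[i] - v;
  }
}

On the host side, such a kernel is launched with cudaLaunchCooperativeKernel and its grid size is capped by occupancy so all blocks can be resident at once; a matching host-side sketch is given further down in the review thread.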

@paddle-bot-old

Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@JamesLim-sy JamesLim-sy changed the title A leap of try for cudaLaunchCooperativeKernel Optimization of Lars Op with cudaLaunchCooperativeKernel method Sep 11, 2021
@JamesLim-sy JamesLim-sy changed the title Optimization of Lars Op with cudaLaunchCooperativeKernel method Lars op optimization with cudaLaunchCooperativeKernel method Sep 11, 2021
#if CUDA_VERSION >= 11000
#define FUNCTION_FLAG __device__
#else
#define FUNCTION_FLAG __global__
Contributor:

Can a __global__ function call another __global__ function inside it?

Contributor Author:

This is mainly for code reuse. For CUDA_VERSION >= 11000, i.e. CUDA 11, L2NormKernel can be merged into the MomentumUpdateKernel implementation, in which case it is a __device__ kernel; otherwise it is kept separate from MomentumUpdateKernel as a standalone __global__ kernel.
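
For readers following this thread, the intent can be illustrated with a small sketch (identifiers such as L2NormKernelSketch and MomentumUpdateKernelSketch are hypothetical, not the PR's actual code). The same function body is compiled either as a __device__ helper that the fused momentum kernel calls directly, or as a standalone __global__ kernel that is launched separately:

#include <cuda.h>  // for CUDA_VERSION

#if defined(__NVCC__) && CUDA_VERSION >= 11000
#define LARS_FUNCTION_FLAG __device__   // called from inside the fused kernel
#else
#define LARS_FUNCTION_FLAG __global__   // launched as a separate kernel
#endif

template <typename T>
LARS_FUNCTION_FLAG void L2NormKernelSketch(const T* x, T* partial_sum,
                                           int numel) {
  // ... per-block partial sum of squares written to partial_sum ...
}

#if defined(__NVCC__) && CUDA_VERSION >= 11000
template <typename T>
__global__ void MomentumUpdateKernelSketch(const T* x, T* partial_sum,
                                           int numel) {
  // A __global__ kernel cannot call another __global__ kernel as a plain
  // function (that would need dynamic parallelism), but it can call a
  // __device__ function, so the L2-norm step is inlined here.
  L2NormKernelSketch<T>(x, partial_sum, numel);
  // ... grid.sync() and the lars update follow in the same launch ...
}
#endif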


namespace paddle {
namespace operators {

template <typename T>
using MultiPrecisionType = typename details::MPTypeTrait<T>::Type;

__device__ __forceinline__ float SquareRoot(float x) { return sqrtf(x); }
__device__ __forceinline__ double SquareRoot(double x) { return sqrt(x); }
Contributor:

CUDA's sqrt is already overloaded for the float type, so there is no need to wrap it in another layer.

Contributor Author:

OK, will revise as suggested.


namespace paddle {
namespace operators {

template <typename T>
using MultiPrecisionType = typename details::MPTypeTrait<T>::Type;

__device__ __forceinline__ float SquareRoot(float x) { return sqrtf(x); }
__device__ __forceinline__ double SquareRoot(double x) { return sqrt(x); }
__device__ __forceinline__ float FmaRoot(float x, float y, float z) {
Contributor:

Same as above.

Contributor Author:

Will revise as suggested.

return fma(x, y, z);
}

template <typename MT, int VesSize>
Contributor:

VesSize->VecSize

Contributor Author:

Will revise as suggested.

MT local_lr = lr;
const MT p_n = static_cast<MT>(p_norm[0]);
const MT g_n = static_cast<MT>(g_norm[0]);
__device__ inline void VectorizeLarsUpdateMP(
Contributor:

Can VectorizeLarsUpdateMP and VectorizeLarsUpdate be written as a single function?

Contributor Author:

The merged code got too messy, so I split it into two versions.

Contributor:

I also think it is better to merge them into one function; the code is very similar, and in theory it would not be "too messy".

Contributor Author:

OK, will revise as suggested.

@Xreki (Contributor) left a comment:

Please verify the convergence of ResNet50 training, plot the loss and training-accuracy curves, and attach them to the PR description.


#if defined(__NVCC__) && CUDA_VERSION >= 11000
#include <cooperative_groups.h>
#define LARS_FUNCTION_FLAG __device__
Contributor:

Add some comments to explain why it is defined this way.

Contributor Author:

OK, will add comments as suggested.


namespace paddle {
namespace operators {

template <typename T>
using MultiPrecisionType = typename details::MPTypeTrait<T>::Type;

template <typename MT, int VecSize>
__device__ inline void VectorizeLarsUpdate(
const MT* __restrict__ g, const MT* __restrict__ v, MT* __restrict__ p_out,
Contributor:

Please use the full names for p, g and v: param, grad, velocity.

Contributor Author:

Will revise as suggested.


#pragma unroll
for (int j = 0; j < VecSize; ++j) {
MT grad = g_data.val[j] * rescale_grad;
Contributor:

This can be accessed directly with g_data[j].

Contributor Author:

OK, will revise as suggested.


namespace paddle {
namespace operators {

template <typename T>
using MultiPrecisionType = typename details::MPTypeTrait<T>::Type;

template <typename MT, int VecSize>
Contributor:

Use T here.

Contributor Author:

Will revise as suggested.

MT local_lr = lr;
const MT p_n = static_cast<MT>(p_norm[0]);
const MT g_n = static_cast<MT>(g_norm[0]);
__device__ inline void VectorizeLarsUpdateMP(
Contributor:

I also think it is better to merge them into one function; the code is very similar, and in theory it would not be "too messy".

// For multi-precision, types T and MT are at most fp16 and fp32 respectively,
// so the maximum vectorized IO size can be set to 4.
using VecType = paddle::platform::AlignedVector<T, 4>;
using VecMType = paddle::platform::AlignedVector<MT, 4>;
Contributor:

Why is this hard-coded to 4?

Contributor Author:

For FP16 data, type T is fp16 and type MT is fp32; a single vectorized IO of 4 elements does not exceed the 128-bit limit, so 4 is written explicitly here. A comment will be added later.
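
As a quick sanity check of the 128-bit claim, here is a standalone sketch (AlignedVec is a stand-in for paddle::platform::AlignedVector, which packs Size elements contiguously; the check is an assumption for illustration, not the PR's code):

#include <cuda_fp16.h>
#include <cstdio>

// Stand-in for an aligned, packed vector of Size elements.
template <typename T, int Size>
struct AlignedVec {
  alignas(sizeof(T) * Size) T val[Size];
};

int main() {
  // 4 x fp16 = 64 bits and 4 x fp32 = 128 bits: a single vectorized load or
  // store of 4 elements never exceeds the 128-bit transaction limit.
  static_assert(sizeof(AlignedVec<__half, 4>) == 8, "4 x fp16 == 64 bits");
  static_assert(sizeof(AlignedVec<float, 4>) == 16, "4 x fp32 == 128 bits");
  printf("VecSize = 4 fits the 128-bit IO limit for both fp16 and fp32.\n");
  return 0;
}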

grid_stride, numel);
} else {
if (std::is_same<T, float>::value ||
std::is_same<T, paddle::platform::float16>::value) {
Contributor:

Don't we need to check whether the addresses are aligned?

Contributor:

Can this branch still be the float16 type?

std::min(std::min(sm_num * num_blocks_per_sm, grid), LARS_BLOCK_SIZE);
framework::Tensor tmp_buffer_t =
ctx.AllocateTmpTensor<MT, platform::CUDADeviceContext>(
{LARS_BLOCK_SIZE << 1}, cuda_ctx);
Contributor:

Does the cooperative-groups approach also need temporary memory?

Contributor Author:

The temporary memory is allocated to store each block's partial L2 norm, in preparation for computing the global L2 norm afterwards; cooperative groups are introduced to achieve the grid-sync operation.
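
For completeness, a minimal host-side sketch of such a cooperative launch (hypothetical names matching the device sketch in the PR description; error handling and the cudaDevAttrCooperativeLaunch capability check are reduced to comments):

#include <cuda_runtime.h>
#include <algorithm>

// Forward declaration of the fused kernel sketched in the PR description.
__global__ void FusedLarsSketch(const float*, const float*, float*, float*,
                                float*, float*, int, float, float, float,
                                float, float);

void LaunchFusedLarsSketch(const float* param, const float* grad,
                           float* velocity, float* param_out, float* p_buffer,
                           float* g_buffer, int numel, float lr, float mu,
                           float lars_coeff, float lars_weight_decay,
                           float epsilon, cudaStream_t stream) {
  constexpr int kBlockSize = 256;
  int dev = 0, sm_count = 0, blocks_per_sm = 0;
  cudaGetDevice(&dev);
  // A real implementation should first check cudaDevAttrCooperativeLaunch.
  cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev);
  cudaOccupancyMaxActiveBlocksPerMultiprocessor(&blocks_per_sm, FusedLarsSketch,
                                                kBlockSize, 0);
  // Cooperative launch requires every block to be resident at the same time,
  // so the grid is capped by occupancy; this also bounds the size of the
  // temporary buffers holding the per-block partial L2 norms.
  int grid = std::min(sm_count * blocks_per_sm,
                      (numel + kBlockSize - 1) / kBlockSize);
  void* args[] = {&param,    &grad,       &velocity,          &param_out,
                  &p_buffer, &g_buffer,   &numel,             &lr,
                  &mu,       &lars_coeff, &lars_weight_decay, &epsilon};
  cudaLaunchCooperativeKernel((void*)FusedLarsSketch, dim3(grid),
                              dim3(kBlockSize), args, 0, stream);
}

Because the grid is capped this way, the per-block partial results fit in a small fixed-size buffer, which is consistent with the LARS_BLOCK_SIZE << 1 temporary tensor allocated in the snippet above.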

Xreki previously approved these changes Sep 26, 2021
@Xreki (Contributor) left a comment:

LGTM. If the accuracy verification shows no problems, this PR can be merged first.

return fma(x, y, z);
}

template <typename T, typename MT, int VecSize, bool IsAmp = false>
Contributor:

There is no need to add a template parameter like IsAmp; you can just check whether T is the float16 type.

Contributor Author:

What puzzled me when I modified this earlier was whether, when calling the VectorizeLarsUpdate interface, there could be a case where the data is fp16 but IsAmp is still false. That is why I added a template parameter instead of using std::is_same<>.

Contributor Author:

I went through the construction of the lars optimizer again; it contains the following:

find_master = self._multi_precision and param_and_grad[
    0].dtype == core.VarDesc.VarType.FP16
master_weight = (self._master_weights[param_and_grad[0].name]
                 if find_master else None)
attrs = {
    "mu": self._momentum,
    "lars_coeff": self._lars_coeff,
    "lars_weight_decay": _lars_weight_decay,
    "multi_precision": find_master,
    "rescale_grad": self._rescale_grad
}
inputs = {
    "Param": param_and_grad[0],
    "Grad": param_and_grad[1],
    "Velocity": velocity_acc,
    "LearningRate": lr
}
outputs = {"ParamOut": param_and_grad[0], "VelocityOut": velocity_acc}
if find_master:
    inputs["MasterParam"] = master_weight
    outputs["MasterParamOut"] = master_weight
# create the momentum optimize op
momentum_op = block.append_op(
    type=self.type,
    inputs=inputs,
    outputs=outputs,
    attrs=attrs,
    stop_gradient=True)

From this logic, master_param can only be used when both type == fp16 and IsAmp == true hold, so checking the float16 type alone should not be enough. But I am not sure whether the AMP design has a logic issue here; I will ask the AMP colleagues.


template <typename T, typename MT, int VecSize, bool IsAmp = false>
__device__ inline void VectorizeLarsUpdate(
const T* __restrict__ grad, const MT* __restrict__ param,
Contributor:

param is actually master_param.

Contributor Author:

For AMP computation, master_param is passed in; for non-AMP computation, param itself is passed in directly. The formal parameter is uniformly named param here.

Fma(velocity_data[j], mu,
local_lr * Fma(lars_weight_decay, param_data[j], grad_val));
param_tmp[j] = param_data[j] - velocity_tmp[j];
param_out_tmp[j] = static_cast<T>(param_tmp[j]);
Contributor:

The naming of these two variables is not very intuitive; param_tmp should be master_param_out_tmp.

Contributor Author:

Will revise as suggested in the next commit.

VecMType* master_param_out_vec;
if (IsAmp) {
master_param_out_vec = reinterpret_cast<VecMType*>(master_param_out);
}
Contributor:

From the name, I would assume _vec is an AlignedVector type, but it turns out to be a pointer. Would it be better to use the Load/Store functions wrapped in aligned_vector.h?

Contributor Author:

The naming is indeed problematic.

grid_stride, numel);
} else {
if (std::is_same<T, float>::value ||
std::is_same<T, paddle::platform::float16>::value) {
Contributor:

Can this branch still be the float16 type?

@JamesLim-sy JamesLim-sy merged commit a112ce4 into PaddlePaddle:develop Sep 27, 2021
AnnaTrainingG pushed a commit to AnnaTrainingG/Paddle that referenced this pull request Sep 29, 2021
…ddle#35652)

* A leap of try for cudaLaunchCooperativeKernel

* fix bugs

* Totally replace the lar cuda kernel

* Fix bugs

* fix code according to comments

* fix codes according to  review comments

* adding some function overload

* relocate the power operation.