
Support Ternary ops in elementwise and broadcast #33976

Merged

Conversation

@JamesLim-sy (Contributor) commented Jul 5, 2021

PR types

Function optimization

PR changes

OPs

Describe

  • Change the elementwise branch to support ternary ops (a hedged sketch follows this list).
  • Move the GetVectorizedSize function into fast_divmod.h to make it common.
  • Change the sub-namespace of fast_divmod.h from "operator" to "platform", and update the code that calls the functions inside fast_divmod.h.
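
As a concrete illustration of what "ternary" means here, the extended branch lets the elementwise entry dispatch three-input functors. The following is a minimal sketch with a made-up functor name (TernaryFMAFunctor); it is not the PR's actual code:

template <typename T>
struct TernaryFMAFunctor {
  // A three-input elementwise op: out = a * b + c, applied per element.
  inline __device__ T operator()(const T &a, const T &b, const T &c) const {
    return a * b + c;
  }
};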

@paddle-bot-old (bot) commented Jul 5, 2021

Thanks for your contribution!
Please wait for the result of CI first. See the Paddle CI Manual for details.

using InVecType = platform::CudaAlignedVector<InT, VecSize>;
using OutVecType = platform::CudaAlignedVector<OutT, VecSize>;

const InT *__restrict__ in_data[ET];
Contributor:

Why not just call it ins?

Contributor Author:

Because the name ins has already been assigned to const std::vector<const framework::Tensor *> &ins, which refers to the vector of input tensors.

// store
data.store_scalar(out, idx);
}
// load
Contributor:

Isn't the loop here still needed?

Contributor Author:

Adding the loop here looked rather odd to me while adding ternary support, because the original condition was:

  int remain = size - VecSize * tid;
  remain = remain > 0 ? remain : 0;
  ... ...
  if (remain >= VecSize) {
     VectorizedKernelImpl(data, func, tid);
  } else {
     ScalarKernelImpl(data, func, tid * VecSize, remain);
  }

That condition already singles out the thread tid that handles only the non-vectorizable tail data, along with the tail's start position tid * VecSize. But because of the preceding remain computation and the tid * VecSize step, only a single thread actually enters ScalarKernelImpl and performs the computation. Once inside, that thread then runs a for loop:

for (int i = 0; i < remain; ++i) {
   int idx = start + i;
   data.load_scalar(ins, idx);
   out = func(ins);
   data.store_scalar(out, idx);
}

However, the scalar tail can also be computed with multiple threads, even though the gain is small. The current change uses multiple threads as far as possible:

 if (tid < tail_tid) {    // directly selects the threads numbered 0, 1, 2, ...
    ScalarKernelImpl<ET, DataWarpper, InT, OutT, Functor>(data, func, tid);
  }

 if (tid < numel) {    // directly selects the threads numbered 0, 1, 2, ...
    ScalarKernelImpl<ET, DataWarpper, InT, OutT, Functor>(data, func, tid);
  }

  // Meanwhile, DataWarpper carries a pre-recorded scalar start offset,
  // scalar_cal_offset, so the line below replaces start + i:
  args[i] = in_data[i][tid + scalar_cal_offset];

I just ran a local OP Benchmark case with x = [31, 129] and y = [31, 129], and it passed the accuracy test as well.
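
To make the vectorized-body / multi-threaded-tail split above concrete, here is a minimal self-contained CUDA sketch. The kernel name, VEC_SIZE, and the float4 accesses are illustrative assumptions, not the PR's actual code:

constexpr int VEC_SIZE = 4;

__global__ void VecAddWithTailSketch(const float *x, const float *y,
                                     float *out, int numel) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int main_numel = (numel / VEC_SIZE) * VEC_SIZE;  // vectorizable prefix
  int tail_numel = numel - main_numel;             // leftover tail, < VEC_SIZE

  if (tid < main_numel / VEC_SIZE) {
    // Vectorized path: each thread processes VEC_SIZE contiguous elements.
    const float4 *vx = reinterpret_cast<const float4 *>(x);
    const float4 *vy = reinterpret_cast<const float4 *>(y);
    float4 *vout = reinterpret_cast<float4 *>(out);
    float4 a = vx[tid];
    float4 b = vy[tid];
    vout[tid] = make_float4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w);
  }
  if (tid < tail_numel) {
    // Scalar tail: threads 0, 1, 2, ... each take one trailing element,
    // offset by the pre-computed scalar start (cf. scalar_cal_offset).
    int idx = main_numel + tid;
    out[idx] = x[idx] + y[idx];
  }
}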

template <typename InT, typename OutT>
int GetVectorizedSize(const std::vector<const framework::Tensor *> &ins,
const std::vector<framework::Tensor *> &outs) {
int GetVectorizedSizeImpl(const std::vector<const framework::Tensor *> &ins,
Contributor:

xxxImpl is usually the concrete implementation of xxx, so logically it makes more sense for it to be the callee.

Contributor:

Could it be renamed to something like GetVectorizedSizeForTensors?

Contributor Author (@JamesLim-sy, Jul 14, 2021):

I feel it could be changed to GetVectorizedSizeForIO.
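
For context on what this helper computes: such a function typically reduces an alignment check over every input and output pointer and returns the smallest usable vector width. A minimal sketch of that per-pointer check, with an assumed function name and assumed widths (not the PR's exact code):

#include <cstdint>

template <typename T>
int GetVectorizedSizeSketch(const T *pointer) {
  auto address = reinterpret_cast<uintptr_t>(pointer);
  if (address % (sizeof(T) * 4) == 0) return 4;  // e.g. float4-aligned
  if (address % (sizeof(T) * 2) == 0) return 2;  // e.g. float2-aligned
  return 1;                                      // fall back to scalar access
}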

using OutVecType = CudaAlignedVector<OutT, VecSize>;
template <ElementwiseType ET, int VecSize, typename DataWarpper, typename InT,
typename OutT, typename Functor>
__device__ inline void VectorizedKernelImpl(DataWarpper data, Functor func,
Contributor (@ZzSean, Jul 8, 2021):

Is the DataWarpper template parameter added so that other DataWarpper types can use this function later?

Contributor:

Let's unify these class names: ElementwiseArgsWrapper and BroadcastArgsWrapper.

Contributor Author:

The DataWarpper template parameter is there purely to pass a data type through for the formal parameter:

template <typename DataWarpper, ...>
__device__ inline void VectorizedKernelImpl(DataWarpper data, ...)

The next commit will rename it to ElementwiseArgsWrapper.
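
A self-contained sketch of that pattern, with made-up names (DummyArgsWrapper, ScalarKernelSketch): the wrapper type is a template parameter only so the kernel body stays agnostic to which args-wrapper (elementwise or broadcast) it receives:

struct DummyArgsWrapper {
  const float *in;
  float *out;
  __device__ void load_scalar(float &arg, int idx) const { arg = in[idx]; }
  __device__ void store_scalar(float val, int idx) const { out[idx] = val; }
};

template <typename ArgsWrapper, typename Functor>
__device__ inline void ScalarKernelSketch(ArgsWrapper data, Functor func,
                                          int idx) {
  float arg;
  data.load_scalar(arg, idx);      // the wrapper decides how loads are indexed
  float result = func(arg);        // apply the elementwise functor
  data.store_scalar(result, idx);  // the wrapper decides how stores are indexed
}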

@JamesLim-sy JamesLim-sy requested a review from Xreki July 14, 2021 11:00
int tid) {
using InVecType = CudaAlignedVector<InT, VecSize>;
using OutVecType = CudaAlignedVector<OutT, VecSize>;
template <ElementwiseType ET, int VecSize, typename ElementwiseWarpper,
Contributor:

warpper -> wrapper

Contributor Author:

Got it, I will fix this in the next commit.

__device__ inline void ScalarKernelImpl(
ElementwiseDataWrapper<ET, VecSize, InT, OutT> data, Functor func,
int start, int remain) {
template <ElementwiseType ET, typename ElementwiseWarpper, typename InT,
Contributor:

Same as above.

Contributor Author:

Got it, I will fix this in the next commit.

@ZzSean (Contributor) commented Jul 28, 2021

LGTM

using OutVecType = CudaAlignedVector<OutT, VecSize>;
template <ElementwiseType ET, int VecSize, typename ElementwiseWrapper,
typename InT, typename OutT, typename Functor>
__device__ inline void VectorizedKernelImpl(ElementwiseWrapper data,
Contributor:

Is there any need to keep ElementwiseWrapper as a template parameter?

Contributor Author:

It can be removed; I will delete it in the next commit.

@Xreki Xreki merged commit 1d7b75d into PaddlePaddle:develop Aug 5, 2021