Modify the elementwise op according to the kernel primitive API #34456
Conversation
Thanks for your contribution!
Resolved review threads (now outdated) on:
paddle/fluid/operators/kernel_primitives/datamover_primitives.h (7 threads)
paddle/fluid/operators/elementwise/elementwise_op_broadcast_api.cu.h (1 thread)
Force-pushed from 6c0f358 to bbc68cf, then from bbc68cf to 5065b23.
Removed the unnecessary template parameters.
Resolved review thread (now outdated) on paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h
LGTM
2 similar comments
LGTM
LGTM
namespace paddle {
namespace operators {

#define MAX_INPUT_NUM 3  // the max num of ET for BroadcastConfig
Define this inside ElementwiseType; a final enumerator kMaxArity = 4 can be added at the end.
    LoadVectorizedDataByDivmod(args[j], tid, j);
  }
}
template <typename T, int VecSize, int ShapeSize, bool IsBoundary = false>
Suggest renaming ShapeSize to Rank.
}
template <ElementwiseType ET, typename InT, typename OutT, int ShapeSize,
          int VecSize, typename Functor, bool IsBoundary = false>
__device__ void DealSegment(
The function name is not appropriate.
  broadcast_wrapper.LoadVectorizedData(args, tid);
template <typename InT, typename OutT, ElementwiseType ET, int VecSize,
          int Size, typename Functor>
void LaunchKernel(const platform::CUDADeviceContext &ctx,
How about LaunchBroadcastKernel, to distinguish the function name?
                  framework::Tensor *out, Functor func,
                  DimensionsTransform merge_dims) {
  int numel = out->numel();
  const int threads = 256;
The thread count was originally controlled by GetThreadsConfig, which can effectively tune the thread configuration for some small cases.
  OutT *out_data = out->data<OutT>();

  framework::Array<kps::details::BroadcastConfig<Size>, MAX_INPUT_NUM>
      configlists;
configlists -> config_list
  InT args[ET][VecSize];
  broadcast_wrapper.LoadVectorizedData(args, tid);
template <typename InT, typename OutT, ElementwiseType ET, int VecSize,
          int Size, typename Functor>
Isn't Size here also Rank?
inline __device__ void LoadScalarizedData(InT args[], int tid) {
template <ElementwiseType ET, int VecSize, typename InT, typename OutT,
          typename Functor, bool IsBoundary>
__device__ void DealSegment(
The function name likewise needs to be changed. Also, the same-dims version of this function and the broadcast version are nearly identical; consider merging them.
PR types
Function optimization
PR changes
APIs
Describe
Modify the elementwise op according to the kernel primitive API
1. Split ElementwiseVectorKernel in elementwise_op_impl.cu.h into 3 CUDA kernels according to the ET type, to adapt it to the kernel primitive API.
2. Split ElementwiseBroadcastKernel in elementwise_op_broadcast.cu.h into 3 CUDA kernels according to the ET type, to adapt it to the kernel primitive API.
3. Refactor the function-call structure in elementwise_op_broadcast.cu.h: define a BroadcastConfig struct to simplify the broadcast configuration.
Performance: after the replacement, performance is on par with before, and some cases exceed the original performance.