[Pten] Refactor the implementation of custom operator #37122
Conversation
LGTM
The unit-test reorganization is being flagged as removal of unit tests.
paddle/pten/common/data_type.h
Outdated
PD_THROW("Data type ", | ||
static_cast<int>(data_type), | ||
" is not supported by tensor."); |
The error-message style here is a bit inconsistent with the one at line 176 below. Are both styles acceptable?
Done, a quote was added here as well.
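A hedged sketch of what the adjusted message could look like (the exact quoting style is an assumption, not quoted from the final diff):

```cpp
// Possible adjusted form: the numeric data type value is wrapped in
// backquotes so this message matches the style of the other PD_THROW call
// mentioned above (illustrative only).
PD_THROW("Data type `",
         static_cast<int>(data_type),
         "` is not supported by tensor.");
```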
   * @brief Return the shape (dimensions) of Tensor.
   * The compatible method of `Tensor::dims()`.
   * This is a deprecated method and may be removed in the future!
   *
   * @return std::vector<int64_t>
   */
  std::vector<int64_t> shape() const;
If the shape function is removed in the future, will the C++ Tensor interface become inconsistent with the Python one?
What needs to happen here is not removing this interface but making an incompatible upgrade; a concrete hint will be provided in the interface later.
`shape` is better than `dims`
The shape interface currently needs to be kept for compatibility as well.
paddle/pten/api/include/tensor.h
Outdated
  bool is_cpu() const;

  /**
   * @brief Determine whether the tensor device is CPU
The "CPU" here should be "GPU".
done, thx
   */
  bool is_cpu() const { return paddle::platform::is_cpu_place(place()); }
  bool is_cuda() const { return paddle::platform::is_gpu_place(place()); }
  paddle::platform::Place inner_place() const;
If the place() interface above is removed in the future, can inner_place() here be renamed to place()?
Yes; an incompatible upgrade is needed here as well.
paddle/pten/api/include/tensor.h
Outdated
   * This is a deprecated method and may be removed in the future!
   *
   * @tparam T
   * @param target_place of target place, of which the tensor will copy to.
The grammar here feels a bit odd.
done, thx
paddle/pten/api/include/tensor.h
Outdated
   * @return None
   * @brief Transfer the current Tensor to the specified device and return.
   *
   * @param place of target place, of which the tensor will copy to.
Same as above.
done, thx
  /**
   * @brief Determine whether Tensor is initialized.
   * This is a deprecated method and may be removed in the future!
For these interfaces that are planned to be removed, could the comments point out the method that replaces the current interface?
Later, a warning will be emitted directly when the interface is called; a comment is of little use to users. This will be supplemented in a follow-up; for now this wording is used uniformly.
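A minimal sketch of how such a call-time warning could be emitted once per process (the macro name and mechanism are assumptions, not Paddle's actual implementation):

```cpp
#include <iostream>

// Hypothetical helper: warn the first time a deprecated compatibility API is
// called, instead of relying on a comment the user never sees.
#define PD_WARN_DEPRECATED_ONCE(msg)                  \
  do {                                                \
    static bool warned = false;                       \
    if (!warned) {                                    \
      std::cerr << "Warning: " << (msg) << std::endl; \
      warned = true;                                  \
    }                                                 \
  } while (0)

// Example usage inside a deprecated method such as Tensor::shape():
//   PD_WARN_DEPRECATED_ONCE(
//       "Tensor::shape() is a compatibility API and may change in a "
//       "future release.");
```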
LGTM overall.
paddle/pten/api/lib/tensor.cc
Outdated
template <typename T>
T *Tensor::mutable_data() {
  if (impl_->type_info().name() == "DenseTensor") {
Not important: could the check for whether it is a DenseTensor be factored into a separate function later? I see it is called in multiple places, so if the implementation changes later, the interface layer would not need to change.
Done, thx. For now an inline function was added to wrap the check; a better approach will be considered later.
                        DataType dtype = DataType::UNDEFINED,
                        Backend backend = Backend::UNDEFINED,
                        DataLayout layout = DataLayout::UNDEFINED);
PD_DLL_DECL Tensor full(const std::vector<int64_t>& shape,
What does PD_DLL_DECL mean, and why is it used?
On Windows, symbols need to be exported manually; otherwise they will not be found at runtime. This is different from Unix.
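For readers unfamiliar with the Windows requirement, the usual export/import macro pattern looks roughly like the sketch below; the actual definition of PD_DLL_DECL in Paddle may differ in its guard macro names.

```cpp
// Typical Windows export/import macro pattern (illustrative only; the real
// PD_DLL_DECL definition may use different guards).
#if defined(_WIN32)
#if defined(PADDLE_DLL_EXPORT)             // defined while building the DLL
#define PD_DLL_DECL __declspec(dllexport)
#else                                      // defined while consuming the DLL
#define PD_DLL_DECL __declspec(dllimport)
#endif
#else
#define PD_DLL_DECL                        // no-op on Linux/macOS
#endif
```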
  int numel = x.size();
  int block = 512;
  int grid = (numel + block - 1) / block;
  PD_DISPATCH_FLOATING_AND_HALF_TYPES(
      x.type(), "relu_cuda_forward_kernel", ([&] {
        auto cpu_input = x.copy_to<data_t>(paddle::PlaceType::kCPU);
Shouldn't copy_to be tested here?
copy_to is temporarily disabled, along with reshape, cast, and slice. They will be added back in the next PR, and the corresponding unit tests will be supplemented as well.
See the explanation in the TODO items of the PR description.
paddle/pten/api/include/registry.h
Outdated
  PD_DLL_DECL int RegisterSymbolsFor##name() { return 0; }

#define PT_DECLARE_API(name)                 \
  extern int RegisterSymbolsFor##name();     \
Please add PD_DLL_DECL here as well.
done, thx
#define PT_DECLARE_API(name)                 \
  extern int RegisterSymbolsFor##name();     \
  UNUSED static int use_pten_api_##name = RegisterSymbolsFor##name()
What does UNUSED do?
All ordinary functions and classes that are exposed externally need PD_DLL_DECL on their declarations.
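For context on the UNUSED question above: the static registration variable exists only to force the linker to pull in the translation unit that defines RegisterSymbolsFor##name, so nothing ever reads it, and UNUSED silences the resulting unused-variable warning. A sketch of a typical definition (the exact definition in Paddle is assumed, not quoted):

```cpp
// Typical definition of an "unused" attribute macro (illustrative only).
#if defined(__GNUC__) || defined(__clang__)
#define UNUSED __attribute__((unused))
#else
#define UNUSED
#endif
```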
namespace detail {

inline bool IsDenseTensor(
[TODO] This will later be changed to a no-rtti form:
pten::DenseTensor::classof(derived_a.get());
OK.
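A hedged sketch of what the no-rtti form could look like, next to the current string comparison; the include paths and the base type pten::TensorBase are assumptions based on the surrounding code:

```cpp
#include <memory>

#include "paddle/pten/core/dense_tensor.h"  // assumed header locations
#include "paddle/pten/core/tensor_base.h"

namespace detail {

// Illustrative only: an LLVM-style classof() check avoids comparing
// type_info name strings and does not require RTTI.
inline bool IsDenseTensor(
    const std::shared_ptr<pten::TensorBase>& tensor_impl) {
  // Current form (string comparison of the registered type name):
  //   return tensor_impl->type_info().name() == "DenseTensor";
  // Possible no-rtti form mentioned in the review:
  return pten::DenseTensor::classof(tensor_impl.get());
}

}  // namespace detail
```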
@@ -20,4 +25,5 @@ endif()
if(WITH_XPU)
  set(PTEN_DEPS ${PTEN_DEPS} manipulation_xpu)
endif()

remove additional blank line
This is not important; it is just a blank line to separate this block from the one above.
        platform::errors::InvalidArgument(
            "TensorImpl with nullptr is not supported"));
  }
  explicit Tensor(const PlaceType& place);
Is this a good way to stay compatible with the original custom Tensor? These two constructors here are not safe.
Neither of these constructors is safe, but they are needed for compatibility right now. A deprecated warning will be added later; removing them will take some time, for example deprecating them in release 2.4.
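One possible way to surface such a warning is the standard C++14 attribute, sketched below; whether Paddle will use this mechanism or a runtime warning is not decided by this PR.

```cpp
// Illustrative sketch: marking the compatibility constructor as deprecated
// makes every call site emit a compile-time warning until the constructor
// is removed (e.g. in the 2.4 release mentioned above).
class Tensor {
 public:
  [[deprecated(
      "Constructing a Tensor from a PlaceType alone is unsafe; this "
      "constructor is kept only for custom-operator compatibility.")]]
  explicit Tensor(const PlaceType& place);
  // ...
};
```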
   */
  paddle::experimental::DataType type() const { return impl_->data_type(); }
  DataType type() const;
type is not a good name... how about var_type
1. The type interface is also kept for compatibility with custom operators. 2. var_type is not a good name either, because there is no Variable in the current system.
}  // namespace detail

paddle::platform::DeviceContext* GetDeviceContextByBackend(
const&
   *
   * @return int64_t
   */
  int64_t size() const;
I prefer using size() instead of numel.
OK, I will remove numel in the next PR.
see #37237
LGTM
LGTM
PR types
Function optimization
PR changes
OPs
Describe
[Pten] Refactor the implementation of custom operator
The Paddle Tensor compute library (hereafter "pten") needs to provide more C++ computation APIs for custom operators so that the cost of developing custom operators can be reduced further. This requires connecting the custom operator mechanism with the pten implementation. The changes in this PR are as follows:
After this integration of custom operators with pten, developing custom operators becomes noticeably more convenient, and the amount of code required for complex kernels drops significantly. Take developing the forward kernel of a linear operator as an example:
1. Previous approach (hand-write the underlying implementation logic; relatively complex)
Note: this is only the main implementation of Paddle's internal matmul forward kernel and does not include linear's add operation; the internal implementation of add is another thousand-plus lines of code (#37034), which is too much to paste here. The sample code has not been compiled or verified and is only meant to compare the amount of code. Moreover, writing an efficient matmul and elementwise_add in an external custom operator is very difficult, for reasons including but not limited to:
2. Approach in this PR (a single line of code, supporting multiple devices; verified in the unit tests)
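As an illustration only (the function names, namespace, and signatures below are assumptions for readability, not quoted from this PR), the forward kernel body reduces to roughly one call:

```cpp
#include <vector>

#include "paddle/extension.h"

// Hypothetical sketch of a linear forward kernel written against the
// pten-backed C++ API; the matmul/add names and signatures are assumed here.
std::vector<paddle::Tensor> LinearForward(const paddle::Tensor& x,
                                          const paddle::Tensor& weight,
                                          const paddle::Tensor& bias) {
  // One expression replaces the hand-written matmul + elementwise_add
  // kernels and dispatches to the proper device (CPU/GPU) automatically.
  return {paddle::experimental::add(
      paddle::experimental::matmul(x, weight, false, false), bias)};
}
```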
Note: the C++ API used here is kept consistent with the corresponding Python API in the following respects:
Therefore, users can use the C++ API by referring directly to the Python API documentation, with no extra learning cost, and its performance has already been optimized by the Paddle team.
Of course, both of the approaches above are still supported at the moment; merging this PR does not mean the original style is no longer supported.
TODO items
- The reshape, copy_to, slice, and cast methods were previously implemented specifically for custom operators. After this integration, these APIs will directly reuse pten's C++ APIs and kernel implementations. Since the related pten kernels are still being migrated, these APIs will be finished in a follow-up PR (before 11.20); this PR temporarily disables them.
- paddle::experimental is used as the namespace prefix because the current APIs are still experimental and subject to change. The custom operator APIs were opened up too hastily before, so some insufficiently designed data structures and APIs were exposed directly to users and now cannot be dropped for compatibility reasons, e.g. the old PlaceType, Tensor's incomplete constructors, Tensor::stream(), and so on. Deprecation warnings will also be added to these poorly designed interfaces later.
作为命名空间前缀,目前的API还在实验阶段,尚有不确定性,之前自定义算子开放API过于草率,导致之前设计考虑不够充分的一些数据结构和API直接暴露给用户,现在为了确保兼容性又不能废弃,例如之前的PlaceType,及Tensor的不完全构造函数、Tensor::stream()等,后续这些设计不太好的接口也会添加Deprecated Warning