-
Notifications
You must be signed in to change notification settings - Fork 274
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【Hackathon 5th No.1】 为 Paddle 新增 copysign API (RFC update) #793
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -196,6 +196,37 @@ NPY_NO_EXPORT void | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
PyTorch和Numpy实现方式基本一致,都是底层调用cpp的math库实现`copysign`,PyTorch可进行backward。 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
paddle的promotion机制暂时正在建设中,故现仅考虑输入的两个元素的类型相同的情况,下面是对竞品面对不同输入类型的行为记录(pytorch和numpy均不支持complex类型): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| x | y | np.copysign(x,y)/torch.copysign(x,y) | grad_x(grad_y) | | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ---------------- | ---------------- | ------------------------------------ | -------------- | | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| np.uint8 | np.uint8 | np.float16 | / | | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| torch.uint8 | torch.uint8 | torch.float32 | / | | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| np.int8 | np.int8 | np.float16 | / | | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| torch.int8 | torch.int8 | torch.float32 | / | | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| np.int16 | np.int16 | np.float32 | / | | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| torch.int16 | torch.int16 | torch.float32 | / | | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| np.int32 | np.int32 | np.float64 | / | | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| torch.int32 | torch.int32 | torch.float32 | / | | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| np.int64 | np.int64 | np.float64 | / | | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| torch.int64 | torch.int64 | torch.float32 | / | | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| np.float16 | np.float16 | np.float16 | / | | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| torch.float16 | torch.float16 | torch.float16 | torch.float16 | | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| np.float32 | np.float32 | np.float32 | / | | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| torch.float32 | torch.float32 | torch.float32 | torch.float32 | | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| np.float64 | np.float64 | np.float64 | / | | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| torch.float64 | torch.float64 | torch.float64 | torch.float64 | | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| np.complex64 | np.complex64 | / | / | | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| torch.complex64 | torch.complex64 | / | / | | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| np.complex128 | np.complex128 | / | / | | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| torch.complex128 | torch.complex128 | / | / | | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| np.bool | np.bool | np.float16 | / | | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| torch.bool | torch.bool | torch.float32 | / | | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| torch.bfloat16 | torch.bfloat16 | torch.bfloat16 | torch.bfloat16 | | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
+ 可以发现,在整型输入时,numpy和pytorch的行为略有不同:pytorch面对整型输入,均保持`float32`作为输出,而numpy在整型输入时,仅当dtype为`int16`时,输出的dtype与pytorch对齐(均为`float32`)。 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
+ 另外,pytorch不支持整型(包括bool)的反向传播(但是paddle目前似乎并未对此作限制)。对于浮点数,输入和输出类型保持一致。 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# 五、设计思路与实现方案 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
## 命名与参数设计 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -214,43 +245,51 @@ API的设计为: | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
## 底层OP设计 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
参考PyTorch与Numpy中的设计,调用底层cpp实现OP,反向 kernel impl 大致如下: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
参考`elementwise_compute`类型的其他op,支持broadcast。 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
```cpp | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
template<typename T> | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
struct CopySignGradFunctor { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
CopySignGradFunctor(const T* x_data, const T* y_data, const T* dout, T* dx, int64_t numel) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
: x_data_(x_data), y_data_(y_data), dout_(dout), dx_(dx), numel_(numel) {} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
// backward 逻辑如下 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
HOSTDEVICE void operator()(int64_t idx) const { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if (x_data_[idx] == T(0)) dx_[idx] = T(0); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
else dx_[idx] = T(dout_[idx]) * (T(std::copysign(x_data_[idx], y_data_[idx]) / x_data_[idx])); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
**需要进一步确认的点:** | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
const T* x_data_; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
const T* y_data_; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
const T* dout_; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
T* dx_; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
int64_t numel_; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
}; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
template <typename T, typename Context> | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
void CopySignGradKernel(const Context& dev_ctx, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
const DenseTensor& x, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
const DenseTensor& y, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
const DenseTensor& out_grad, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
DenseTensor* x_grad) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
dev_ctx.template Alloc<T>(x_grad); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
auto x_data = x.data<T>(), y_data = y.data<T>(), out_grad_data = out_grad.data<T>(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
auto x_grad_data = x_grad->data<T>(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
phi::funcs::ForRange<Context> for_range(dev_ctx, x.numel()); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
phi::CopySignGradFunctor<T> functor(x_data, y_data, out_grad_data, x_grad_data, x.numel()); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
for_range(functor); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
``` | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
+ 竞品均直接采用cpp标准库中的`std::copysign`来实现Functor,然而这个库函数在接收整型输入时,自动会进行promotion为浮点数: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
```cpp | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
#ifndef __CORRECT_ISO_CPP11_MATH_H_PROTO_FP | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
constexpr float | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
copysign(float __x, float __y) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
{ return __builtin_copysignf(__x, __y); } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
constexpr long double | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
copysign(long double __x, long double __y) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
{ return __builtin_copysignl(__x, __y); } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
#endif | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
#ifndef __CORRECT_ISO_CPP11_MATH_H_PROTO_INT | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
template<typename _Tp, typename _Up> | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
constexpr typename __gnu_cxx::__promote_2<_Tp, _Up>::__type | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
copysign(_Tp __x, _Up __y) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
typedef typename __gnu_cxx::__promote_2<_Tp, _Up>::__type __type; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return copysign(__type(__x), __type(__y)); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
#endif | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
``` | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
而且numpy和pytorch遇到整型输入,得到输出dtype常常不同,需要确定paddle的实现。 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
> 答:paddle选择更加符合直觉的行为,例如输入整型,输出同样为整型(而不像竞品会自动提升到浮点数) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
+ 反向传播的dtype是否一直保持不变?如果是,那么遇到dtype为整型的时(paddle没有严格限制反向传播过程的dtype不能为整型,pytorch有强制限制反向传播过程不能为整型),求梯度会有f(x,y)得到float类型,就变了。paddle目前支持反向传播过程中数据类型发生变化吗? | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
> 答:这个问题,仍然需要拆分到grad OP的视角来看,由grad OP控制,比如下面这个例子,可以试着把x,y分别切换成float / int类型看看结果 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
单从功能上来看,`copysign`实现的逻辑比较简单:取第一个变量的绝对值的大小,取第二个变量符号,两者拼接。感觉从功能上来看,可以不拘泥于跟竞品一样调用标准库的`std::copysign`,而是直接在Functor中判断来实现,而且输入和输出dtype保持相同。 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 是否可以仍然调用标准库,只是kernel计算逻辑里额外做下cast;包括这里设计的方案都可以尝试下,看看目前实现起来是否有堵点, There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
现在就是在kernel里面加了这个: using U = typename std::conditional_t<std::is_integral<T>::value, float, T>;
dev_ctx.template Alloc<U>(out); 就是根据注册的dtype来申请输出的内存空间,如果注册的dtype为整型相关的,那么就申请
如果是浮点数或者float16,bfloat16这两个特殊类型,就保持输入什么dtype,输出就什么dtype,这里也跟pytorch实现的行为一致: >>> ****************** float16 *******************
>>> x = paddle.to_tensor([10], dtype=paddle.float16)
>>> y = paddle.to_tensor([-10], dtype=paddle.float16)
>>> func(x,y)
Tensor(shape=[1], dtype=float16, place=Place(gpu:0), stop_gradient=True,
[-10.])
>>> ****************** bfloat16 *******************
>>> x = paddle.to_tensor([10], dtype=paddle.bfloat16)
>>> y = paddle.to_tensor([-10], dtype=paddle.bfloat16)
>>> func(x,y)
Tensor(shape=[1], dtype=bfloat16, place=Place(gpu:0), stop_gradient=True,
[-10.]) 输入类型和输出类型一样。 然后下面是一些整型: >>> x = paddle.to_tensor([10], dtype=paddle.uint8)
>>> y = paddle.to_tensor([10], dtype=paddle.uint8)
>>> func(x,y)
Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=True,
[10.])
>>> x = paddle.to_tensor([10], dtype=paddle.int8)
>>> y = paddle.to_tensor([10], dtype=paddle.int8)
>>> func(x,y)
Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=True,
[10.])
>>> x = paddle.to_tensor([10], dtype=paddle.int16)
>>> y = paddle.to_tensor([10], dtype=paddle.int16)
>>> func(x,y)
Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=True,
[10.])
>>> x = paddle.to_tensor([10], dtype=paddle.int32)
>>> y = paddle.to_tensor([10], dtype=paddle.int32)
>>> func(x,y)
Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=True,
[10.])
>>> x = paddle.to_tensor([10], dtype=paddle.int64)
>>> y = paddle.to_tensor([10], dtype=paddle.int64)
>>> func(x,y)
Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=True,
[10.])
>>> 可以看到输出都是float32。 现在就是目前的实现,仍然调用标准库,而且和pytorch对齐了的。 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 从前面的讨论来看,大致存在两个做法:
即有符号int仍输出有符号int,无符号int需要额外考虑,不过由于本身也不存在符号,似乎也可以保持原输入dtype 看到前面已经阐述了第一个方案具备可行性。想了解下第二种方案是否具备可行性呢。 目前主要还是想评估一下两个方案的优劣情况,确定这个API的最终行为 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
对,这两个方案应该都可以,只是想要对齐pytorch的行为的话,就是目前的版本;如果想要第二种做法的话,在调用 无符号数在调用标准库的实现的时候,它应该已经自动先转为浮点数(而且是正数)了,而不会像 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @cocoshe 你好,我们内部讨论了下,决定还是按语义符合直觉的方向去实现更好,,即第二种输入输出dtype一致的方案。辛苦修改下RFC和代码呢~ |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
> 从前面的讨论来看,大致存在两个做法: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
> | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
> 1. 输入整型时,输出提升为浮点数 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
> 2. 输出dtype保持与输入dtype一致 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
> | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
> 经过讨论,决定还是按语义符合直觉的方向去实现更好,即第二种输入输出dtype一致的方案。 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
## API实现方案 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -265,14 +304,14 @@ void CopySignGradKernel(const Context& dev_ctx, | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
测试考虑的case如下: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
+ **编程范式场景**:常规覆盖动态图和静态图的测试场景 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
+ **硬件场景**:常规需覆盖 CPU、GPU 两种测试场景 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
+ **参数组合场景**:常规覆盖 API 的全部入参,需要对全部入参进行参数有效性和边界值测试,同时可选参数也需有相应的测试覆盖 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
+ **计算精度**:需要保证前向计算、反向计算的精度正确性 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
+ 前向计算:通过 numpy 实现的函数的对比结果 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
+ 反向计算:通过 numpy 推导,计算反向结果的正确性 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
+ **维度测试**:Paddle API 支持的最低维度为 0 维,单测中应编写相应的 0 维尺寸测试 case | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
+ **边界测试**:y为0、+0、-0时,测试与numpy结果的一致性 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
+ **类型检测**:输入与输出dtype保持相同 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# 七、可行性分析及规划排期 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
梯度输出的类型,理论上还是应该最终由grad OP控制吧,如果从OP的设计上是统一类型的输入输出,那么结果应该是不变的?
这个问题,仍然需要拆分到grad OP的视角来看,由grad OP控制,比如下面这个例子,可以试着把x,y分别切换成float / int类型看看结果
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
嗯嗯我尝试了一下,前两天我去仔细看了下phi算子的注册逻辑,但是就是在这里,有个地方不太理解是怎么实现的。就像之前提到的,在写kernel的时候,我们其实一开始就要给out申请空间,这里就要确定out的类型了:
dev_ctx.template Alloc<T>(out);
比如在你这个
z2 = z1 / 10
的时候,z1
是int32
,但是z2
却是一个float32
。我看到文档里面提到除法就是
divide
包了一层魔法函数,然后去看了下devide
的kernel:它注册的时候:
也并没有去指定他的输出类型,采用的是默认的,也就是说
template<T, Context>
中的T
分别注册了float,double,int8_t,uint8_t,int16_t,int,int64_t,bool,complex64,complex128
这些类型,每次注册的时候,由于都是默认的,所以kernel的参数:const DenseTensor& x,const DenseTensor& y,DenseTensor* out
都应该是当前注册时候的类型,例如注册int32
的时候,输入和输出的tensor dtype都必须全为int32
?我试了一下直接调用
divide
这个api:为什么直接
z1 / 10
能够正常呢?难道是在进kernel之前把输入的两个dtype检测同步了一下吗?因为单从这个kernel来看,应该是仅支持所有输入、输出dtype相同There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
运算符和api在Paddle中有些差异,感兴趣具体可以看
tensor__div__method
的实现,中间有插入额外cast操作。不过我理解上述case主要目的和这个除法无关,主要想表示int tensor在网络中参与反向时的情况。There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
嗯嗯谢谢~,fluid部分我还不是很熟悉。
嗯嗯明白,backward确实取决于grad op的操作,但是目前大部分grad op也都是输入和输出dtype必须保持一致吧。
例如上面您说的那个例子:
结果
x.grad
和y.grad
输出是全0,是错误的,类型是int32
可以发现两个梯度是正确的,类型是
float32
这样看来,divide的grad kernel只能接受输入和输出dtype相同的情况。所以在backward计算的时候,似乎基本都不太支持dtype发生改变。
感觉应该是grad kernel中存在forward时候的下一个op(backward时的上一个op)的
dout
,作为其中的一个参数,然后计算当前输入变量的梯度,如果中间类型发生变化,那么对grad kernel来说,又是变成了"输入两个变量类型不同"的情况。