-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add cumprod_grad composite #64432
add cumprod_grad composite #64432
Conversation
你的PR提交成功,感谢你对开源项目的贡献! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
辛苦提交PR,单测和代码可以再看一下
auto ones_tensor = full<T>(x_dim, 1.0, x.dtype()); | ||
auto replace_one = (1 - zero_mask) * x + zero_mask * ones_tensor; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
- 加法第二项
zero_mask * ones_tensor
是不是可以化简为zero_mask
,一个张量乘以全1矩阵ones_tensor
等于没乘? - 如果上面结论成立,可以删除
ones_tensor
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
@param.parameterized_class( | ||
('primal', 'dtype'), | ||
[ | ||
( | ||
np.random.rand(2, 3, 4), | ||
np.float32, | ||
), | ||
( | ||
np.random.rand(2, 3, 3, 4), | ||
np.float32, | ||
), | ||
], | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
- 形状可以添加0D、1D、2D、5D
- 数据类型添加int64
- 数据本身添加包含1个0、2个0,全部都是0的case
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
补充更多单测case,同test/prim/prim/vjp/eager/test_comp_eager_cumprod_grad.py
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
auto zero_mask_cumsum1 = cumsum<T>(zero_mask, dim, false, false, reverse); | ||
auto zero_mask_cumsum2 = cumsum<T>(zero_mask, dim, false, true, reverse); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
变量命名尽量准确符合语义,
zero_mask_cumsum1 --> zero_mask_cumsum_left
zero_mask_cumsum2 --> zero_mask_cumsum_right
auto ones_tensor = full<T>(x_dim, 1.0, x.dtype()); | ||
auto replace_one = (1 - zero_mask) * x + zero_mask * ones_tensor; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里的ones_tensor应该可以删掉吧
auto ones_tensor = full<T>(x_dim, 1.0, x.dtype()); | ||
auto replace_one = (1 - zero_mask) * x + zero_mask * ones_tensor; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
auto ones_tensor = full<T>(x_dim, 1.0, x.dtype()); | |
auto replace_one = (1 - zero_mask) * x + zero_mask * ones_tensor; | |
auto replace_one = (1 - zero_mask) * x + zero_mask; |
auto ones_tensor = full<T>(x_dim, 1.0, x.dtype()); | ||
auto replace_one = (1 - zero_mask) * x + zero_mask * ones_tensor; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
逻辑连续的代码块之间可以加上一些注释,比如此处将0的位置填充成1,可以加上: # fill the positions of 0 with 1.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
|
||
@param.parameterized_class( | ||
('primal', 'dtype'), | ||
[ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
加上一个单元素的测试:np.array(np.rand(), dtype="float32")
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
return out | ||
|
||
|
||
@param.parameterized_class( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上,补充单元素的0-Dcase
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
此处的实现跟details.h里的实现好像有一些区别?可以确认一下
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
auto zero_mask_cumsum_exclusive = | ||
cumsum<T>(zero_mask, dim, false, true, reverse); | ||
auto zero_mask_cumsum = | ||
zero_mask_cumsum_inclusive + zero_mask_cumsum_exclusive; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
inclusive + exclusive 是否相当于两倍的exclusive + x?两次cumprod应该比一次cumprod + scale + add耗时会更长,是否可以改成后者实现方式?
zero_mask_cumsum_inclusive + zero_mask_cumsum_exclusive --> scale<T>(zero_mask_cumsum_exclusive, 2) + zero_mask
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
todo: 建议适配optest,当前的这批单侧属于遗留case,未来会被清理
* add cumprod_grad composite * remove cout * update test and fix some bug * update test * add comment and test * remove cout * update * update static test * update * Update details.h
PR Category
Others
PR Types
New features
Description
基于支持了
inclusive
,reverse
参数的cumprod
算子(#64022),实现cumprod_grad
的组合逻辑,避免cumprod_grad
的组合算子因为输入含有 0 而导致输出除 0 出 nan 的问题