-
Notifications
You must be signed in to change notification settings - Fork 5.9k
[API Compatiblity] Support paddle.index_add
#76170
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Conversation
|
你的PR提交成功,感谢你对开源项目的贡献! |
|
/re-run all-failed |
Codecov Report❌ Patch coverage is
❌ Your patch status has failed because the patch coverage (78.26%) is below the target coverage (90.00%). You can increase the patch coverage or adjust the target coverage. Additional details and impacted files@@ Coverage Diff @@
## develop #76170 +/- ##
==========================================
Coverage ? 78.26%
==========================================
Files ? 2
Lines ? 46
Branches ? 0
==========================================
Hits ? 36
Misses ? 10
Partials ? 0 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
| kwargs["value"] = kwargs.pop("source") | ||
|
|
||
| if len(args) >= 2 and isinstance(args[1], int): | ||
| if len(args) < 3 and "index" not in kwargs: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这些报错检查可以非必要不加
| def decorator(func: Callable[_InputT, _RetT]) -> Callable[_InputT, _RetT]: | ||
| @functools.wraps(func) | ||
| def wrapper(*args: _InputT.args, **kwargs: _InputT.kwargs) -> _RetT: | ||
| if "input" in kwargs and "x" not in kwargs: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这些判断一般只用假设一个条件,要么是torch的这种用法,要么是paddle的这种用法。其他情况无需判断。
| raise TypeError( | ||
| "index_add() missing 1 required positional argument: 'source'" | ||
| ) | ||
| if "x" not in kwargs: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这些if not判断是否不需要,此时必然是torch的用法。
| kwargs["axis"] = args[1] | ||
| if len(args) == 3: | ||
| if "index" not in kwargs: | ||
| kwargs["index"] = args[2] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
直接写一个
args_list = ['x', 'axis', 'index', 'value']
for ele in args:
# 顺序读取args
不用这么多判断,减少装饰器的非必要损耗。
| ) | ||
|
|
||
|
|
||
| @index_add_decorator() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个参考下gather,使用下多种overload签名吧,这样代码可读性会好一些。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
gather是把不同的参数分发给了两个ops的kernel,但是index_add是把两套参数都转化为一套参数,传入给底层的ops的kernel中,所以改成那样的方式之后,仍然减少不了大量的判断的逻辑
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
gather是把不同的参数分发给了两个ops的kernel,但是index_add是把两套参数都转化为一套参数,传入给底层的ops的kernel中,所以改成那样的方式之后,仍然减少不了大量的判断的逻辑
@fangfangssj 只提升一下代码的可读性,代码逻辑不变。将多套签名在代码里展示出来。参考#76149
PR Category
User Experience
PR Types
New features
Description
Support compatibility for
paddle.index_add,paddle.index_add_andpaddle.Tensor.index_add为index_add增加torch格式的调用,增加alpha和out参数
参考PR #74873