-
Notifications
You must be signed in to change notification settings - Fork 5.9k
[API Compatiblity] Support paddle.index_add
#76170
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -734,3 +734,64 @@ def wrapper(*args: _InputT.args, **kwargs: _InputT.kwargs) -> _RetT: | |
| return wrapper | ||
|
|
||
| return decorator | ||
|
|
||
|
|
||
| def index_add_decorator() -> Callable[ | ||
| [Callable[_InputT, _RetT]], Callable[_InputT, _RetT] | ||
| ]: | ||
| """ | ||
| Usage Example: | ||
| PyTorch: index_add(input, dim, index, source, *, alpha=1) | ||
| torch.index_add(input_tensor, 1, indices, source_tensor) | ||
|
|
||
| Paddle: index_add(x, index, axis, value, alpha=1) | ||
| paddle.index_add(x=input_tensor, index=indices, axis=1, value=source_tensor) | ||
| paddle.index_add(input_tensor, indices, 1, source_tensor) | ||
| """ | ||
|
|
||
| def decorator(func: Callable[_InputT, _RetT]) -> Callable[_InputT, _RetT]: | ||
| @functools.wraps(func) | ||
| def wrapper(*args: _InputT.args, **kwargs: _InputT.kwargs) -> _RetT: | ||
| if "input" in kwargs and "x" not in kwargs: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这些判断一般只用假设一个条件,要么是torch的这种用法,要么是paddle的这种用法。其他情况无需判断。 |
||
| kwargs["x"] = kwargs.pop("input") | ||
| if "dim" in kwargs and "axis" not in kwargs: | ||
| kwargs["axis"] = kwargs.pop("dim") | ||
| if "source" in kwargs and "value" not in kwargs: | ||
| kwargs["value"] = kwargs.pop("source") | ||
|
|
||
| if len(args) >= 2 and isinstance(args[1], int): | ||
| if len(args) < 3 and "index" not in kwargs: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这些报错检查可以非必要不加 |
||
| raise TypeError( | ||
| "index_add() missing 1 required positional argument: 'index'" | ||
| ) | ||
| if ( | ||
| len(args) < 4 | ||
| and "source" not in kwargs | ||
| and "value" not in kwargs | ||
| ): | ||
| raise TypeError( | ||
| "index_add() missing 1 required positional argument: 'source'" | ||
| ) | ||
| if "x" not in kwargs: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这些if not判断是否不需要,此时必然是torch的用法。 |
||
| kwargs["x"] = args[0] | ||
| if "axis" not in kwargs: | ||
| kwargs["axis"] = args[1] | ||
| if len(args) == 3: | ||
| if "index" not in kwargs: | ||
| kwargs["index"] = args[2] | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 直接写一个 不用这么多判断,减少装饰器的非必要损耗。 |
||
| args = args[3:] | ||
| elif len(args) >= 4: | ||
| if "index" not in kwargs: | ||
| kwargs["index"] = args[2] | ||
| if "value" not in kwargs: | ||
| kwargs["value"] = args[3] | ||
| args = args[4:] | ||
| else: | ||
| args = args[2:] | ||
|
|
||
| return func(*args, **kwargs) | ||
|
|
||
| wrapper.__signature__ = inspect.signature(func) | ||
| return wrapper | ||
|
|
||
| return decorator | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个参考下gather,使用下多种overload签名吧,这样代码可读性会好一些。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
gather是把不同的参数分发给了两个ops的kernel,但是index_add是把两套参数都转化为一套参数,传入给底层的ops的kernel中,所以改成那样的方式之后,仍然减少不了大量的判断的逻辑
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@fangfangssj 只提升一下代码的可读性,代码逻辑不变。将多套签名在代码里展示出来。参考#76149