Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions python/paddle/nn/layer/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1914,6 +1914,7 @@ class Unfold(Layer):
illegal_keys={"kernel_size", "dilation", "padding", "stride"},
func_name="paddle.nn.Unfold",
correct_name="paddle.compat.Unfold",
url_suffix="nn/torch.nn.Unfold",
)
def __init__(
self,
Expand Down
1 change: 1 addition & 0 deletions python/paddle/tensor/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2742,6 +2742,7 @@ def row_stack(x: Sequence[Tensor], name: str | None = None) -> Tensor:
illegal_keys={"tensor", "split_size_or_sections", "dim"},
func_name="paddle.split",
correct_name="paddle.compat.split",
url_suffix="torch/torch.split",
)
def split(
x: Tensor,
Expand Down
1 change: 1 addition & 0 deletions python/paddle/tensor/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -631,6 +631,7 @@ def _restrict_nonzero(condition: Tensor, total_true_num: int) -> Tensor:
illegal_keys={'input', 'dim'},
func_name='paddle.sort',
correct_name='paddle.compat.sort',
url_suffix="torch/torch.sort",
)
def sort(
x: Tensor,
Expand Down
27 changes: 26 additions & 1 deletion python/paddle/utils/decorator_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,12 +324,32 @@ class ForbidKeywordsDecorator(DecoratorBase):
"""A decorator that hints users to use the correct `compat` functions, when erroneous keyword arguments are detected"""

def __init__(
self, illegal_keys: set[str], func_name: str, correct_name: str
self,
illegal_keys: set[str],
func_name: str,
correct_name: str,
url_suffix: str = "",
) -> None:
"""
Args:
illegal_keys (set[str]): the keywords to reject
func_name (str): the name of the function being decorated (should incorporate module name, like paddle.nn.Unfold)
correct_name (str): the user hint that points to the correct function
url_suffix (str, optional): Only specified in non paddle.compat functions. If specified, the function being decorated
will emit a warning upon the first call, warning the users about the API difference and points to Docs.
Please correctly specifying the `url_suffix`, this should be the suffix of the api-difference doc. For example:

(prefix omitted)/docs/zh/develop/guides/model_convert/convert_from_pytorch/api_difference/**torch/torch.nn.Unfold**.html

In this example, the correct `url_suffix` should be 'torch/torch.nn.Unfold'. Defaults to an empty str.
"""
super().__init__()
self.illegal_keys = illegal_keys
self.func_name = func_name
self.correct_name = correct_name
self.warn_msg = None
if url_suffix:
self.warn_msg = f"\nNon compatible API. Please refer to https://www.paddlepaddle.org.cn/documentation/docs/en/develop/guides/model_convert/convert_from_pytorch/api_difference/{url_suffix}.html first."

def process(
self, args: tuple[Any, ...], kwargs: dict[str, Any]
Expand All @@ -345,6 +365,11 @@ def process(
f"{self.func_name}() received unexpected keyword argument{plural} {keys_str}. "
f"\nDid you mean to use {self.correct_name}() instead?"
)
if self.warn_msg is not None:
warnings.warn(
self.warn_msg,
category=Warning,
)
return args, kwargs


Expand Down