Skip to content

Commit e8ca424

Browse files
authored
[API-Compat] ForbidKeywordsDecorator now warns user (#74725)
* [API-Compat] ForbidKeywordsDecorator now warns user * [API-Compat] Largely cut down the decorator overhead
1 parent 2d73b40 commit e8ca424

File tree

4 files changed

+29
-1
lines changed

4 files changed

+29
-1
lines changed

python/paddle/nn/layer/common.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1914,6 +1914,7 @@ class Unfold(Layer):
19141914
illegal_keys={"kernel_size", "dilation", "padding", "stride"},
19151915
func_name="paddle.nn.Unfold",
19161916
correct_name="paddle.compat.Unfold",
1917+
url_suffix="nn/torch.nn.Unfold",
19171918
)
19181919
def __init__(
19191920
self,

python/paddle/tensor/manipulation.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2743,6 +2743,7 @@ def row_stack(x: Sequence[Tensor], name: str | None = None) -> Tensor:
27432743
illegal_keys={"tensor", "split_size_or_sections", "dim"},
27442744
func_name="paddle.split",
27452745
correct_name="paddle.compat.split",
2746+
url_suffix="torch/torch.split",
27462747
)
27472748
def split(
27482749
x: Tensor,

python/paddle/tensor/search.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -644,6 +644,7 @@ def _restrict_nonzero(condition: Tensor, total_true_num: int) -> Tensor:
644644
illegal_keys={'input', 'dim'},
645645
func_name='paddle.sort',
646646
correct_name='paddle.compat.sort',
647+
url_suffix="torch/torch.sort",
647648
)
648649
def sort(
649650
x: Tensor,

python/paddle/utils/decorator_utils.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -324,12 +324,32 @@ class ForbidKeywordsDecorator(DecoratorBase):
324324
"""A decorator that hints users to use the correct `compat` functions, when erroneous keyword arguments are detected"""
325325

326326
def __init__(
327-
self, illegal_keys: set[str], func_name: str, correct_name: str
327+
self,
328+
illegal_keys: set[str],
329+
func_name: str,
330+
correct_name: str,
331+
url_suffix: str = "",
328332
) -> None:
333+
"""
334+
Args:
335+
illegal_keys (set[str]): the keywords to reject
336+
func_name (str): the name of the function being decorated (should incorporate module name, like paddle.nn.Unfold)
337+
correct_name (str): the user hint that points to the correct function
338+
url_suffix (str, optional): Only specified in non paddle.compat functions. If specified, the function being decorated
339+
will emit a warning upon the first call, warning the users about the API difference and points to Docs.
340+
Please correctly specifying the `url_suffix`, this should be the suffix of the api-difference doc. For example:
341+
342+
(prefix omitted)/docs/zh/develop/guides/model_convert/convert_from_pytorch/api_difference/**torch/torch.nn.Unfold**.html
343+
344+
In this example, the correct `url_suffix` should be 'torch/torch.nn.Unfold'. Defaults to an empty str.
345+
"""
329346
super().__init__()
330347
self.illegal_keys = illegal_keys
331348
self.func_name = func_name
332349
self.correct_name = correct_name
350+
self.warn_msg = None
351+
if url_suffix:
352+
self.warn_msg = f"\nNon compatible API. Please refer to https://www.paddlepaddle.org.cn/documentation/docs/en/develop/guides/model_convert/convert_from_pytorch/api_difference/{url_suffix}.html first."
333353

334354
def process(
335355
self, args: tuple[Any, ...], kwargs: dict[str, Any]
@@ -345,6 +365,11 @@ def process(
345365
f"{self.func_name}() received unexpected keyword argument{plural} {keys_str}. "
346366
f"\nDid you mean to use {self.correct_name}() instead?"
347367
)
368+
if self.warn_msg is not None:
369+
warnings.warn(
370+
self.warn_msg,
371+
category=Warning,
372+
)
348373
return args, kwargs
349374

350375

0 commit comments

Comments
 (0)