From e1f1273f8ab74a12b46ef0343182746bf2ab2aa0 Mon Sep 17 00:00:00 2001 From: Enigmatisms Date: Tue, 19 Aug 2025 07:47:57 +0000 Subject: [PATCH 1/2] [API-Compat] ForbidKeywordsDecorator now warns user --- python/paddle/nn/layer/common.py | 1 + python/paddle/tensor/manipulation.py | 1 + python/paddle/tensor/search.py | 1 + python/paddle/utils/decorator_utils.py | 31 +++++++++++++++++++++++++- 4 files changed, 33 insertions(+), 1 deletion(-) diff --git a/python/paddle/nn/layer/common.py b/python/paddle/nn/layer/common.py index eed4eaca760f52..6ba4ef9f76290a 100644 --- a/python/paddle/nn/layer/common.py +++ b/python/paddle/nn/layer/common.py @@ -1914,6 +1914,7 @@ class Unfold(Layer): illegal_keys={"kernel_size", "dilation", "padding", "stride"}, func_name="paddle.nn.Unfold", correct_name="paddle.compat.Unfold", + url_suffix="nn/torch.nn.Unfold", ) def __init__( self, diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 403f48d17c2334..8c421a8ad8f81a 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -2742,6 +2742,7 @@ def row_stack(x: Sequence[Tensor], name: str | None = None) -> Tensor: illegal_keys={"tensor", "split_size_or_sections", "dim"}, func_name="paddle.split", correct_name="paddle.compat.split", + url_suffix="torch/torch.split", ) def split( x: Tensor, diff --git a/python/paddle/tensor/search.py b/python/paddle/tensor/search.py index 5a40997626ba7b..5ab629e5fc5bc2 100755 --- a/python/paddle/tensor/search.py +++ b/python/paddle/tensor/search.py @@ -631,6 +631,7 @@ def _restrict_nonzero(condition: Tensor, total_true_num: int) -> Tensor: illegal_keys={'input', 'dim'}, func_name='paddle.sort', correct_name='paddle.compat.sort', + url_suffix="torch/torch.sort", ) def sort( x: Tensor, diff --git a/python/paddle/utils/decorator_utils.py b/python/paddle/utils/decorator_utils.py index 8f0c55e38caf5c..0260a6ed361f92 100644 --- a/python/paddle/utils/decorator_utils.py +++ b/python/paddle/utils/decorator_utils.py @@ -323,13 +323,36 @@ def wrapper(*args: _InputT.args, **kwargs: _InputT.kwargs) -> _RetT: class ForbidKeywordsDecorator(DecoratorBase): """A decorator that hints users to use the correct `compat` functions, when erroneous keyword arguments are detected""" + _site_format = ( + "https://www.paddlepaddle.org.cn/documentation/docs/en/develop/" + "guides/model_convert/convert_from_pytorch/api_difference/{url_suffix}.html" + ) + def __init__( - self, illegal_keys: set[str], func_name: str, correct_name: str + self, + illegal_keys: set[str], + func_name: str, + correct_name: str, + url_suffix: str = "", ) -> None: + """ + Args: + illegal_keys (set[str]): the keywords to reject + func_name (str): the name of the function being decorated (should incorporate module name, like paddle.nn.Unfold) + correct_name (str): the user hint that points to the correct function + url_suffix (str, optional): Only specified in non paddle.compat functions. If specified, the function being decorated + will emit a warning upon the first call, warning the users about the API difference and points to Docs. + Please correctly specifying the `url_suffix`, this should be the suffix of the api-difference doc. For example: + + (prefix omitted)/docs/zh/develop/guides/model_convert/convert_from_pytorch/api_difference/**torch/torch.nn.Unfold**.html + + In this example, the correct `url_suffix` should be 'torch/torch.nn.Unfold'. Defaults to an empty str. + """ super().__init__() self.illegal_keys = illegal_keys self.func_name = func_name self.correct_name = correct_name + self.url_suffix = url_suffix def process( self, args: tuple[Any, ...], kwargs: dict[str, Any] @@ -345,6 +368,12 @@ def process( f"{self.func_name}() received unexpected keyword argument{plural} {keys_str}. " f"\nDid you mean to use {self.correct_name}() instead?" ) + if self.url_suffix: + warnings.warn( + f"\nThis is a non compatible API. Please refer to {self._site_format.format(url_suffix=self.url_suffix)} first." + f"\nA compatible version of this API: `{self.correct_name}` can be also used, make sure the correct API is called.", + category=Warning, + ) return args, kwargs From 7b50c1474e5752f8b480493d9789729115e784bb Mon Sep 17 00:00:00 2001 From: Enigmatisms Date: Wed, 20 Aug 2025 05:13:40 +0000 Subject: [PATCH 2/2] [API-Compat] Largely cut down the decorator overhead --- python/paddle/utils/decorator_utils.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/python/paddle/utils/decorator_utils.py b/python/paddle/utils/decorator_utils.py index 0260a6ed361f92..77123a62d4e27d 100644 --- a/python/paddle/utils/decorator_utils.py +++ b/python/paddle/utils/decorator_utils.py @@ -323,11 +323,6 @@ def wrapper(*args: _InputT.args, **kwargs: _InputT.kwargs) -> _RetT: class ForbidKeywordsDecorator(DecoratorBase): """A decorator that hints users to use the correct `compat` functions, when erroneous keyword arguments are detected""" - _site_format = ( - "https://www.paddlepaddle.org.cn/documentation/docs/en/develop/" - "guides/model_convert/convert_from_pytorch/api_difference/{url_suffix}.html" - ) - def __init__( self, illegal_keys: set[str], @@ -352,7 +347,9 @@ def __init__( self.illegal_keys = illegal_keys self.func_name = func_name self.correct_name = correct_name - self.url_suffix = url_suffix + self.warn_msg = None + if url_suffix: + self.warn_msg = f"\nNon compatible API. Please refer to https://www.paddlepaddle.org.cn/documentation/docs/en/develop/guides/model_convert/convert_from_pytorch/api_difference/{url_suffix}.html first." def process( self, args: tuple[Any, ...], kwargs: dict[str, Any] @@ -368,10 +365,9 @@ def process( f"{self.func_name}() received unexpected keyword argument{plural} {keys_str}. " f"\nDid you mean to use {self.correct_name}() instead?" ) - if self.url_suffix: + if self.warn_msg is not None: warnings.warn( - f"\nThis is a non compatible API. Please refer to {self._site_format.format(url_suffix=self.url_suffix)} first." - f"\nA compatible version of this API: `{self.correct_name}` can be also used, make sure the correct API is called.", + self.warn_msg, category=Warning, ) return args, kwargs