@@ -324,12 +324,32 @@ class ForbidKeywordsDecorator(DecoratorBase):
324324 """A decorator that hints users to use the correct `compat` functions, when erroneous keyword arguments are detected"""
325325
326326 def __init__ (
327- self , illegal_keys : set [str ], func_name : str , correct_name : str
327+ self ,
328+ illegal_keys : set [str ],
329+ func_name : str ,
330+ correct_name : str ,
331+ url_suffix : str = "" ,
328332 ) -> None :
333+ """
334+ Args:
335+ illegal_keys (set[str]): the keywords to reject
336+ func_name (str): the name of the function being decorated (should incorporate module name, like paddle.nn.Unfold)
337+ correct_name (str): the user hint that points to the correct function
338+ url_suffix (str, optional): Only specified in non paddle.compat functions. If specified, the function being decorated
339+ will emit a warning upon the first call, warning the users about the API difference and points to Docs.
340+ Please correctly specifying the `url_suffix`, this should be the suffix of the api-difference doc. For example:
341+
342+ (prefix omitted)/docs/zh/develop/guides/model_convert/convert_from_pytorch/api_difference/**torch/torch.nn.Unfold**.html
343+
344+ In this example, the correct `url_suffix` should be 'torch/torch.nn.Unfold'. Defaults to an empty str.
345+ """
329346 super ().__init__ ()
330347 self .illegal_keys = illegal_keys
331348 self .func_name = func_name
332349 self .correct_name = correct_name
350+ self .warn_msg = None
351+ if url_suffix :
352+ self .warn_msg = f"\n Non compatible API. Please refer to https://www.paddlepaddle.org.cn/documentation/docs/en/develop/guides/model_convert/convert_from_pytorch/api_difference/{ url_suffix } .html first."
333353
334354 def process (
335355 self , args : tuple [Any , ...], kwargs : dict [str , Any ]
@@ -345,6 +365,11 @@ def process(
345365 f"{ self .func_name } () received unexpected keyword argument{ plural } { keys_str } . "
346366 f"\n Did you mean to use { self .correct_name } () instead?"
347367 )
368+ if self .warn_msg is not None :
369+ warnings .warn (
370+ self .warn_msg ,
371+ category = Warning ,
372+ )
348373 return args , kwargs
349374
350375
0 commit comments