1818from collections .abc import Iterable
1919from typing import Any , Callable , TypeVar , cast
2020
21- _F = TypeVar ("_F" , bound = Callable [..., Any ])
21+ from typing_extensions import ParamSpec
22+
23+ _InputT = ParamSpec ("_InputT" )
24+ _RetT = TypeVar ("_RetT" )
2225
2326
2427class DecoratorBase :
@@ -31,17 +34,19 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
3134 self .args = args
3235 self .kwargs = kwargs
3336
34- def __call__ (self , func : _F ) -> _F :
37+ def __call__ (
38+ self , func : Callable [_InputT , _RetT ]
39+ ) -> Callable [_InputT , _RetT ]:
3540 """As an entry point for decorative applications"""
3641
3742 @functools .wraps (func )
38- def wrapper (* args , ** kwargs ) :
43+ def wrapper (* args : _InputT . args , ** kwargs : _InputT . kwargs ) -> _RetT :
3944 # Pretreatment parameters
4045 processed_args , processed_kwargs = self .process (args , kwargs )
4146 return func (* processed_args , ** processed_kwargs )
4247
4348 wrapper .__signature__ = inspect .signature (func )
44- return cast ("_F " , wrapper )
49+ return cast ("Callable[_InputT, _RetT] " , wrapper )
4550
4651 def process (
4752 self , args : tuple [Any , ...], kwargs : dict [str , Any ]
@@ -151,9 +156,9 @@ def process(
151156
152157
153158def param_one_alias (alias_list ):
154- def decorator (func ) :
159+ def decorator (func : Callable [ _InputT , _RetT ]) -> Callable [ _InputT , _RetT ] :
155160 @functools .wraps (func )
156- def wrapper (* args , ** kwargs ) :
161+ def wrapper (* args : _InputT . args , ** kwargs : _InputT . kwargs ) -> _RetT :
157162 if not kwargs :
158163 return func (* args , ** kwargs )
159164 if (alias_list [0 ] not in kwargs ) and (alias_list [1 ] in kwargs ):
@@ -167,9 +172,9 @@ def wrapper(*args, **kwargs):
167172
168173
169174def param_two_alias (alias_list1 , alias_list2 ):
170- def decorator (func ) :
175+ def decorator (func : Callable [ _InputT , _RetT ]) -> Callable [ _InputT , _RetT ] :
171176 @functools .wraps (func )
172- def wrapper (* args , ** kwargs ) :
177+ def wrapper (* args : _InputT . args , ** kwargs : _InputT . kwargs ) -> _RetT :
173178 if not kwargs :
174179 return func (* args , ** kwargs )
175180 if (alias_list1 [0 ] not in kwargs ) and (alias_list1 [1 ] in kwargs ):
@@ -185,9 +190,9 @@ def wrapper(*args, **kwargs):
185190
186191
187192def param_two_alias_one_default (alias_list1 , alias_list2 , default_param ):
188- def decorator (func ) :
193+ def decorator (func : Callable [ _InputT , _RetT ]) -> Callable [ _InputT , _RetT ] :
189194 @functools .wraps (func )
190- def wrapper (* args , ** kwargs ) :
195+ def wrapper (* args : _InputT . args , ** kwargs : _InputT . kwargs ) -> _RetT :
191196 if not kwargs :
192197 return func (* args , ** kwargs )
193198
@@ -253,9 +258,9 @@ def process(
253258
254259
255260def view_decorator ():
256- def decorator (func ) :
261+ def decorator (func : Callable [ _InputT , _RetT ]) -> Callable [ _InputT , _RetT ] :
257262 @functools .wraps (func )
258- def wrapper (* args , ** kwargs ) :
263+ def wrapper (* args : _InputT . args , ** kwargs : _InputT . kwargs ) -> _RetT :
259264 if ("dtype" in kwargs ) and ("shape_or_dtype" not in kwargs ):
260265 kwargs ["shape_or_dtype" ] = kwargs .pop ("dtype" )
261266 elif ("size" in kwargs ) and ("shape_or_dtype" not in kwargs ):
@@ -310,9 +315,9 @@ def reshape_decorator():
310315 tensor_x.reshape(-1, 1, 3) -> paddle.reshape(tensor_x, -1, 1, 3])
311316 """
312317
313- def decorator (func ) :
318+ def decorator (func : Callable [ _InputT , _RetT ]) -> Callable [ _InputT , _RetT ] :
314319 @functools .wraps (func )
315- def wrapper (* args , ** kwargs ) :
320+ def wrapper (* args : _InputT . args , ** kwargs : _InputT . kwargs ) -> _RetT :
316321 if ("input" in kwargs ) and ("x" not in kwargs ):
317322 kwargs ["x" ] = kwargs .pop ("input" )
318323 elif len (args ) >= 2 and type (args [1 ]) is int :
0 commit comments