Skip to content

Commit e4e9446

Browse files
[API compatibility][Typing] Fix the loss of conductive information in modifier types (#74629)
1 parent e2c10ca commit e4e9446

File tree

1 file changed

+19
-14
lines changed

1 file changed

+19
-14
lines changed

python/paddle/utils/decorator_utils.py

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,10 @@
1818
from collections.abc import Iterable
1919
from typing import Any, Callable, TypeVar, cast
2020

21-
_F = TypeVar("_F", bound=Callable[..., Any])
21+
from typing_extensions import ParamSpec
22+
23+
_InputT = ParamSpec("_InputT")
24+
_RetT = TypeVar("_RetT")
2225

2326

2427
class DecoratorBase:
@@ -31,17 +34,19 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
3134
self.args = args
3235
self.kwargs = kwargs
3336

34-
def __call__(self, func: _F) -> _F:
37+
def __call__(
38+
self, func: Callable[_InputT, _RetT]
39+
) -> Callable[_InputT, _RetT]:
3540
"""As an entry point for decorative applications"""
3641

3742
@functools.wraps(func)
38-
def wrapper(*args, **kwargs):
43+
def wrapper(*args: _InputT.args, **kwargs: _InputT.kwargs) -> _RetT:
3944
# Pretreatment parameters
4045
processed_args, processed_kwargs = self.process(args, kwargs)
4146
return func(*processed_args, **processed_kwargs)
4247

4348
wrapper.__signature__ = inspect.signature(func)
44-
return cast("_F", wrapper)
49+
return cast("Callable[_InputT, _RetT]", wrapper)
4550

4651
def process(
4752
self, args: tuple[Any, ...], kwargs: dict[str, Any]
@@ -151,9 +156,9 @@ def process(
151156

152157

153158
def param_one_alias(alias_list):
154-
def decorator(func):
159+
def decorator(func: Callable[_InputT, _RetT]) -> Callable[_InputT, _RetT]:
155160
@functools.wraps(func)
156-
def wrapper(*args, **kwargs):
161+
def wrapper(*args: _InputT.args, **kwargs: _InputT.kwargs) -> _RetT:
157162
if not kwargs:
158163
return func(*args, **kwargs)
159164
if (alias_list[0] not in kwargs) and (alias_list[1] in kwargs):
@@ -167,9 +172,9 @@ def wrapper(*args, **kwargs):
167172

168173

169174
def param_two_alias(alias_list1, alias_list2):
170-
def decorator(func):
175+
def decorator(func: Callable[_InputT, _RetT]) -> Callable[_InputT, _RetT]:
171176
@functools.wraps(func)
172-
def wrapper(*args, **kwargs):
177+
def wrapper(*args: _InputT.args, **kwargs: _InputT.kwargs) -> _RetT:
173178
if not kwargs:
174179
return func(*args, **kwargs)
175180
if (alias_list1[0] not in kwargs) and (alias_list1[1] in kwargs):
@@ -185,9 +190,9 @@ def wrapper(*args, **kwargs):
185190

186191

187192
def param_two_alias_one_default(alias_list1, alias_list2, default_param):
188-
def decorator(func):
193+
def decorator(func: Callable[_InputT, _RetT]) -> Callable[_InputT, _RetT]:
189194
@functools.wraps(func)
190-
def wrapper(*args, **kwargs):
195+
def wrapper(*args: _InputT.args, **kwargs: _InputT.kwargs) -> _RetT:
191196
if not kwargs:
192197
return func(*args, **kwargs)
193198

@@ -253,9 +258,9 @@ def process(
253258

254259

255260
def view_decorator():
256-
def decorator(func):
261+
def decorator(func: Callable[_InputT, _RetT]) -> Callable[_InputT, _RetT]:
257262
@functools.wraps(func)
258-
def wrapper(*args, **kwargs):
263+
def wrapper(*args: _InputT.args, **kwargs: _InputT.kwargs) -> _RetT:
259264
if ("dtype" in kwargs) and ("shape_or_dtype" not in kwargs):
260265
kwargs["shape_or_dtype"] = kwargs.pop("dtype")
261266
elif ("size" in kwargs) and ("shape_or_dtype" not in kwargs):
@@ -310,9 +315,9 @@ def reshape_decorator():
310315
tensor_x.reshape(-1, 1, 3) -> paddle.reshape(tensor_x, -1, 1, 3])
311316
"""
312317

313-
def decorator(func):
318+
def decorator(func: Callable[_InputT, _RetT]) -> Callable[_InputT, _RetT]:
314319
@functools.wraps(func)
315-
def wrapper(*args, **kwargs):
320+
def wrapper(*args: _InputT.args, **kwargs: _InputT.kwargs) -> _RetT:
316321
if ("input" in kwargs) and ("x" not in kwargs):
317322
kwargs["x"] = kwargs.pop("input")
318323
elif len(args) >= 2 and type(args[1]) is int:

0 commit comments

Comments
 (0)