From 21de8e84ade8b4492b37570da0454f9afb36f9fe Mon Sep 17 00:00:00 2001 From: zhengshengning Date: Fri, 15 Aug 2025 12:37:23 +0800 Subject: [PATCH] Fix the loss of conductive information in modifier types --- python/paddle/utils/decorator_utils.py | 33 +++++++++++++++----------- 1 file changed, 19 insertions(+), 14 deletions(-) diff --git a/python/paddle/utils/decorator_utils.py b/python/paddle/utils/decorator_utils.py index 54e6654bf2a94d..dee8a38197d7bc 100644 --- a/python/paddle/utils/decorator_utils.py +++ b/python/paddle/utils/decorator_utils.py @@ -18,7 +18,10 @@ from collections.abc import Iterable from typing import Any, Callable, TypeVar, cast -_F = TypeVar("_F", bound=Callable[..., Any]) +from typing_extensions import ParamSpec + +_InputT = ParamSpec("_InputT") +_RetT = TypeVar("_RetT") class DecoratorBase: @@ -31,17 +34,19 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: self.args = args self.kwargs = kwargs - def __call__(self, func: _F) -> _F: + def __call__( + self, func: Callable[_InputT, _RetT] + ) -> Callable[_InputT, _RetT]: """As an entry point for decorative applications""" @functools.wraps(func) - def wrapper(*args, **kwargs): + def wrapper(*args: _InputT.args, **kwargs: _InputT.kwargs) -> _RetT: # Pretreatment parameters processed_args, processed_kwargs = self.process(args, kwargs) return func(*processed_args, **processed_kwargs) wrapper.__signature__ = inspect.signature(func) - return cast("_F", wrapper) + return cast("Callable[_InputT, _RetT]", wrapper) def process( self, args: tuple[Any, ...], kwargs: dict[str, Any] @@ -151,9 +156,9 @@ def process( def param_one_alias(alias_list): - def decorator(func): + def decorator(func: Callable[_InputT, _RetT]) -> Callable[_InputT, _RetT]: @functools.wraps(func) - def wrapper(*args, **kwargs): + def wrapper(*args: _InputT.args, **kwargs: _InputT.kwargs) -> _RetT: if not kwargs: return func(*args, **kwargs) if (alias_list[0] not in kwargs) and (alias_list[1] in kwargs): @@ -167,9 +172,9 @@ def wrapper(*args, **kwargs): def param_two_alias(alias_list1, alias_list2): - def decorator(func): + def decorator(func: Callable[_InputT, _RetT]) -> Callable[_InputT, _RetT]: @functools.wraps(func) - def wrapper(*args, **kwargs): + def wrapper(*args: _InputT.args, **kwargs: _InputT.kwargs) -> _RetT: if not kwargs: return func(*args, **kwargs) if (alias_list1[0] not in kwargs) and (alias_list1[1] in kwargs): @@ -185,9 +190,9 @@ def wrapper(*args, **kwargs): def param_two_alias_one_default(alias_list1, alias_list2, default_param): - def decorator(func): + def decorator(func: Callable[_InputT, _RetT]) -> Callable[_InputT, _RetT]: @functools.wraps(func) - def wrapper(*args, **kwargs): + def wrapper(*args: _InputT.args, **kwargs: _InputT.kwargs) -> _RetT: if not kwargs: return func(*args, **kwargs) @@ -253,9 +258,9 @@ def process( def view_decorator(): - def decorator(func): + def decorator(func: Callable[_InputT, _RetT]) -> Callable[_InputT, _RetT]: @functools.wraps(func) - def wrapper(*args, **kwargs): + def wrapper(*args: _InputT.args, **kwargs: _InputT.kwargs) -> _RetT: if ("dtype" in kwargs) and ("shape_or_dtype" not in kwargs): kwargs["shape_or_dtype"] = kwargs.pop("dtype") elif ("size" in kwargs) and ("shape_or_dtype" not in kwargs): @@ -282,9 +287,9 @@ def reshape_decorator(): tensor_x.reshape(-1, 1, 3) -> paddle.reshape(tensor_x, -1, 1, 3]) """ - def decorator(func): + def decorator(func: Callable[_InputT, _RetT]) -> Callable[_InputT, _RetT]: @functools.wraps(func) - def wrapper(*args, **kwargs): + def wrapper(*args: _InputT.args, **kwargs: _InputT.kwargs) -> _RetT: if ("input" in kwargs) and ("x" not in kwargs): kwargs["x"] = kwargs.pop("input") elif len(args) >= 2 and type(args[1]) is int: