diff --git a/python/paddle/base/dygraph/base.py b/python/paddle/base/dygraph/base.py
index a6443248832446..c9d076b9df2643 100644
--- a/python/paddle/base/dygraph/base.py
+++ b/python/paddle/base/dygraph/base.py
@@ -17,7 +17,15 @@
 import inspect
 import sys
 import warnings
-from typing import Callable, ContextManager, TypeVar, overload
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    ContextManager,
+    Sequence,
+    TypeVar,
+    overload,
+)

 import decorator
 from typing_extensions import ParamSpec
@@ -31,6 +39,16 @@
 from ..wrapped_decorator import signature_safe_contextmanager, wrap_decorator
 from .tracer import Tracer

+if TYPE_CHECKING:
+    from collections import OrderedDict
+    from types import TracebackType
+    from typing import Generator
+
+    from typing_extensions import Self
+
+    from paddle import Tensor
+    from paddle._typing import PlaceLike
+
 __all__ = []

 _InputT = ParamSpec("_InputT")
@@ -39,7 +57,7 @@
 NON_PERSISTABLE_VAR_NAME_SUFFIX = "__non_persistable"


-def in_to_static_mode():
+def in_to_static_mode() -> bool:
     """
     Return a bool value that indicates whether running code under `@to_static`

@@ -82,7 +100,9 @@ def __impl__(*args: _InputT.args, **kwargs: _InputT.kwargs) -> _RetT:


 @signature_safe_contextmanager
-def _to_static_mode_guard_(is_to_static=True):
+def _to_static_mode_guard_(
+    is_to_static: bool = True,
+) -> Generator[None, None, None]:
     global global_var
     original_val = global_var._in_to_static_mode_
     global_var._in_to_static_mode_ = is_to_static
@@ -93,7 +113,9 @@ def _to_static_mode_guard_(is_to_static=True):


 @signature_safe_contextmanager
-def param_guard(parameters):
+def param_guard(
+    parameters: OrderedDict[str, Tensor]
+) -> Generator[None, None, None]:
     # Note: parameters is a reference of self._parameters or self._buffers
     if in_to_static_mode() and not paddle.in_dynamic_mode() and parameters:
         try:
@@ -155,7 +177,7 @@ def _convert_into_variable(tensor):
     return tensor


-def enabled():
+def enabled() -> bool:
     """
     This function checks whether the program runs in dynamic graph mode or not.
     You can enable dynamic graph mode with :ref:`api_paddle_disable_static` api,
@@ -184,7 +206,7 @@ def enabled():
     return framework.in_dygraph_mode()


-def enable_dygraph(place=None):
+def enable_dygraph(place: PlaceLike | None = None) -> None:
     """

     .. note::
@@ -227,7 +249,7 @@ def enable_dygraph(place=None):
     CleanupFuncRegistrar.register(disable_dygraph)


-def disable_dygraph():
+def disable_dygraph() -> None:
     """

     .. note::
@@ -261,7 +283,9 @@ def disable_dygraph():


 @signature_safe_contextmanager
-def _switch_tracer_mode_guard_(is_train=True):
+def _switch_tracer_mode_guard_(
+    is_train: bool = True,
+) -> Generator[None, None, None]:
     tracer = framework._dygraph_tracer()
     if tracer:
         has_grad = tracer._has_grad
@@ -354,7 +378,9 @@ def __impl__(
 class _DecoratorContextManager:
     """Allow a context manager to be used as a decorator"""

-    def __call__(self, func):
+    def __call__(
+        self, func: Callable[_InputT, _RetT]
+    ) -> Callable[_InputT, _RetT]:
         @decorator.decorator
         def _decorate_function(func, *args, **kwargs):
             with self:
@@ -371,18 +397,23 @@ def _decorate_generator(func, *args, **kwargs):
         else:
             return _decorate_function(func)

-    def __enter__(self):
+    def __enter__(self) -> Any:
         raise NotImplementedError

-    def __exit__(self, exc_type, exc_value, traceback):
+    def __exit__(
+        self,
+        exc_type: type[BaseException] | None,
+        exc_value: BaseException | None,
+        traceback: TracebackType | None,
+    ) -> bool:
         raise NotImplementedError

-    def clone(self):
+    def clone(self) -> Self:
         # override this method if your children class takes __init__ parameters
         return self.__class__()


-def is_grad_enabled():
+def is_grad_enabled() -> bool:
     """
     Returns whether current dygraph gradient calculation mode is enabled.

@@ -410,7 +441,7 @@ def is_grad_enabled():
     return tracer._has_grad if tracer else False


-def _set_grad_enabled(mode):
+def _set_grad_enabled(mode: bool) -> None:
     tracer = framework._dygraph_tracer()
     if tracer:
         tracer._has_grad = mode
@@ -448,18 +479,18 @@ class set_grad_enabled(_DecoratorContextManager):
             True
     """

-    def __init__(self, mode):
+    def __init__(self, mode) -> None:
         self.prev = is_grad_enabled()
         _set_grad_enabled(mode)
         self.mode = mode

-    def __enter__(self):
+    def __enter__(self) -> None:
         ...

-    def __exit__(self, *args):
+    def __exit__(self, *args: object) -> None:
         _set_grad_enabled(self.prev)

-    def clone(self):
+    def clone(self) -> Self:
         return self.__class__(self.mode)


@@ -511,11 +542,11 @@ class no_grad_(_DecoratorContextManager):
             >>> test_layer()
     """

-    def __enter__(self):
+    def __enter__(self) -> None:
         self.prev = is_grad_enabled()
         _set_grad_enabled(False)

-    def __exit__(self, *args):
+    def __exit__(self, *args: object) -> None:
         _set_grad_enabled(self.prev)


@@ -559,16 +590,16 @@ class enable_grad(_DecoratorContextManager):
             >>> assert(z.stop_gradient == False)
     """

-    def __enter__(self):
+    def __enter__(self) -> None:
         self.prev = is_grad_enabled()
         _set_grad_enabled(True)

-    def __exit__(self, *args):
+    def __exit__(self, *args: object) -> None:
         _set_grad_enabled(self.prev)


 @signature_safe_contextmanager
-def guard(place=None):
+def guard(place: PlaceLike | None = None) -> Generator[None, None, None]:
     """
     :api_attr: imperative

@@ -617,15 +648,15 @@ def guard(place=None):

 @framework.non_static_only
 def grad(
-    outputs,
-    inputs,
-    grad_outputs=None,
-    retain_graph=None,
-    create_graph=False,
-    only_inputs=True,
-    allow_unused=False,
-    no_grad_vars=None,
-):
+    outputs: Tensor | Sequence[Tensor],
+    inputs: Tensor | Sequence[Tensor],
+    grad_outputs: Tensor | Sequence[Tensor | None] | None = None,
+    retain_graph: bool | None = None,
+    create_graph: bool = False,
+    only_inputs: bool = True,
+    allow_unused: bool = False,
+    no_grad_vars: Tensor | Sequence[Tensor] | set[Tensor] | None = None,
+) -> list[Tensor]:
     '''
     .. note::
         **This API is ONLY available in imperative mode.**
@@ -633,12 +664,12 @@ def grad(
     This API computes the sum of gradients of `outputs` with respect to each `inputs` .

     Parameters:
-        outputs (Tensor|list(Tensor)|tuple(Tensor)): the output Tensor or
+        outputs (Tensor|list[Tensor]|tuple[Tensor]): the output Tensor or
             Tensor list/tuple of the graph to compute gradients.
-        inputs (Tensor|list(Tensor)|tuple(Tensor)): the input Tensor or
+        inputs (Tensor|list[Tensor]|tuple[Tensor]): the input Tensor or
             Tensor list/tuple of the graph to compute gradients. The returned
             values of this API are the gradients of `inputs` .
-        grad_outputs (Tensor|list(Tensor|None)|tuple(Tensor|None), optional):
+        grad_outputs (Tensor|list[Tensor|None]|tuple[Tensor|None], optional):
             initial gradient values of `outputs` . If `grad_outputs` is None,
             the initial gradient values of `outputs` would be Tensors filled with 1;
             if `grad_outputs` is not None, it must have the same length as `outputs` ,
@@ -646,7 +677,7 @@
             be: (1) a Tensor filled with 1 when the i-th element of `grad_outputs`
             is None; (2) the i-th element of `grad_outputs` when the i-th element of
             `grad_outputs` is a Tensor. Default None.
-        retain_graph (bool, optional): whether to retain the forward graph which
+        retain_graph (bool|None, optional): whether to retain the forward graph which
            is used to calculate the gradient. When it is True, the graph would
            be retained, in which way users can calculate backward twice for the
            same graph. When it is False, the graph would be freed. Default None,
@@ -666,7 +697,7 @@
            `inputs` are unreachable in the graph (i.e., their gradients are None),
            error would be raised if allow_unused=False, or None would be returned as
            their gradients if allow_unused=True. Default False.
-        no_grad_vars (Tensor|list(Tensor)|tuple(Tensor)|set(Tensor), optional):
+        no_grad_vars (Tensor|list[Tensor]|tuple[Tensor]|set[Tensor], optional):
            the Tensors whose gradients are not needed to compute. Default None.

     Returns:
@@ -687,10 +718,11 @@
             ...
             ...         # Since y = x * x, dx = 2 * x
             ...         dx = paddle.grad(
-            ...                 outputs=[y],
-            ...                 inputs=[x],
-            ...                 create_graph=create_graph,
-            ...                 retain_graph=True)[0]
+            ...             outputs=[y],
+            ...             inputs=[x],
+            ...             create_graph=create_graph,
+            ...             retain_graph=True
+            ...         )[0]
             ...
             ...         z = y + dx
             ...