[Typing][B-44] Add type annotations for dygraph base #66006

Merged
merged 1 commit on Jul 13, 2024
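The diff below adds parameter and return-type annotations across python/paddle/base/dygraph/base.py. As a rough, standalone sketch of the annotation styles it applies (the `example_*` names are hypothetical, not from the PR):

```python
# Illustrative only; assumes paddle provides Tensor and paddle._typing.PlaceLike,
# as the imports in the diff below suggest.
from __future__ import annotations

from typing import TYPE_CHECKING, Sequence

if TYPE_CHECKING:
    from paddle import Tensor
    from paddle._typing import PlaceLike


def example_predicate() -> bool:
    # Plain return-type annotation, as on in_to_static_mode()/enabled().
    return True


def example_enable(place: PlaceLike | None = None) -> None:
    # Optional parameter plus None return, as on enable_dygraph()/disable_dygraph().
    ...


def example_gather(values: Tensor | Sequence[Tensor]) -> list[Tensor]:
    # Union over a single Tensor or a sequence of Tensors, as on grad().
    return list(values) if isinstance(values, (list, tuple)) else [values]
```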
114 changes: 73 additions & 41 deletions python/paddle/base/dygraph/base.py
@@ -17,7 +17,15 @@
import inspect
import sys
import warnings
from typing import Callable, ContextManager, TypeVar, overload
from typing import (
TYPE_CHECKING,
Any,
Callable,
ContextManager,
Sequence,
TypeVar,
overload,
)

import decorator
from typing_extensions import ParamSpec
@@ -31,6 +39,16 @@
from ..wrapped_decorator import signature_safe_contextmanager, wrap_decorator
from .tracer import Tracer

if TYPE_CHECKING:
from collections import OrderedDict
from types import TracebackType
from typing import Generator

from typing_extensions import Self

from paddle import Tensor
from paddle._typing import PlaceLike

__all__ = []

_InputT = ParamSpec("_InputT")
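The `_InputT`/`_RetT` pair introduced here types this module's decorators so that wrapping a function preserves its parameter list and return type. A minimal sketch of that pattern outside Paddle (the `keep_signature` decorator and `add` function are made up for illustration):

```python
from __future__ import annotations

from typing import Callable, TypeVar

from typing_extensions import ParamSpec

_InputT = ParamSpec("_InputT")
_RetT = TypeVar("_RetT")


def keep_signature(func: Callable[_InputT, _RetT]) -> Callable[_InputT, _RetT]:
    # *args/**kwargs are tied to the ParamSpec, so the wrapper exposes the
    # same signature as `func` from the type checker's point of view.
    def __impl__(*args: _InputT.args, **kwargs: _InputT.kwargs) -> _RetT:
        return func(*args, **kwargs)

    return __impl__


@keep_signature
def add(a: int, b: int = 1) -> int:
    return a + b


assert add(2) == 3  # the checker still sees add's parameters and int return
```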
@@ -39,7 +57,7 @@
NON_PERSISTABLE_VAR_NAME_SUFFIX = "__non_persistable"


def in_to_static_mode():
def in_to_static_mode() -> bool:
"""
Return a bool value that indicates whether running code under `@to_static`

@@ -82,7 +100,9 @@ def __impl__(*args: _InputT.args, **kwargs: _InputT.kwargs) -> _RetT:


@signature_safe_contextmanager
def _to_static_mode_guard_(is_to_static=True):
def _to_static_mode_guard_(
is_to_static: bool = True,
) -> Generator[None, None, None]:
global global_var
original_val = global_var._in_to_static_mode_
global_var._in_to_static_mode_ = is_to_static
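Because `_to_static_mode_guard_` (and the other guards below) are generator-based context managers, their return type is written as `Generator[None, None, None]`. A self-contained sketch of the same annotation, using the standard library's `contextlib.contextmanager` in place of Paddle's `signature_safe_contextmanager` (the `flag_guard` name and `_state` dict are hypothetical):

```python
from __future__ import annotations

from contextlib import contextmanager
from typing import Generator

_state = {"flag": False}


@contextmanager
def flag_guard(value: bool = True) -> Generator[None, None, None]:
    # The generator yields nothing, receives nothing via send(), and returns
    # nothing, hence Generator[None, None, None].
    original = _state["flag"]
    _state["flag"] = value
    try:
        yield
    finally:
        _state["flag"] = original


with flag_guard(True):
    assert _state["flag"] is True
assert _state["flag"] is False
```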
@@ -93,7 +113,9 @@ def _to_static_mode_guard_(is_to_static=True):


@signature_safe_contextmanager
def param_guard(parameters):
def param_guard(
parameters: OrderedDict[str, Tensor]
) -> Generator[None, None, None]:
# Note: parameters is a reference of self._parameters or self._buffers
if in_to_static_mode() and not paddle.in_dynamic_mode() and parameters:
try:
@@ -155,7 +177,7 @@ def _convert_into_variable(tensor):
return tensor


def enabled():
def enabled() -> bool:
"""
This function checks whether the program runs in dynamic graph mode or not.
You can enable dynamic graph mode with :ref:`api_paddle_disable_static` api,
@@ -184,7 +206,7 @@ def enabled():
return framework.in_dygraph_mode()


def enable_dygraph(place=None):
def enable_dygraph(place: PlaceLike | None = None) -> None:
"""

.. note::
@@ -227,7 +249,7 @@ def enable_dygraph(place=None):
CleanupFuncRegistrar.register(disable_dygraph)


def disable_dygraph():
def disable_dygraph() -> None:
"""

.. note::
@@ -261,7 +283,9 @@ def disable_dygraph():


@signature_safe_contextmanager
def _switch_tracer_mode_guard_(is_train=True):
def _switch_tracer_mode_guard_(
is_train: bool = True,
) -> Generator[None, None, None]:
tracer = framework._dygraph_tracer()
if tracer:
has_grad = tracer._has_grad
@@ -354,7 +378,9 @@ def __impl__(
class _DecoratorContextManager:
"""Allow a context manager to be used as a decorator"""

def __call__(self, func):
def __call__(
self, func: Callable[_InputT, _RetT]
) -> Callable[_InputT, _RetT]:
@decorator.decorator
def _decorate_function(func, *args, **kwargs):
with self:
@@ -371,18 +397,23 @@ def _decorate_generator(func, *args, **kwargs):
else:
return _decorate_function(func)

def __enter__(self):
def __enter__(self) -> Any:
raise NotImplementedError

def __exit__(self, exc_type, exc_value, traceback):
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: TracebackType | None,
) -> bool:
raise NotImplementedError

def clone(self):
def clone(self) -> Self:
# override this method if your children class takes __init__ parameters
return self.__class__()


def is_grad_enabled():
def is_grad_enabled() -> bool:
"""
Returns whether current dygraph gradient calculation mode is enabled.

@@ -410,7 +441,7 @@ def is_grad_enabled():
return tracer._has_grad if tracer else False


def _set_grad_enabled(mode):
def _set_grad_enabled(mode: bool) -> None:
tracer = framework._dygraph_tracer()
if tracer:
tracer._has_grad = mode
@@ -448,18 +479,18 @@ class set_grad_enabled(_DecoratorContextManager):
True
"""

def __init__(self, mode):
def __init__(self, mode) -> None:
self.prev = is_grad_enabled()
_set_grad_enabled(mode)
self.mode = mode

def __enter__(self):
def __enter__(self) -> None:
...

def __exit__(self, *args):
def __exit__(self, *args: object) -> None:
_set_grad_enabled(self.prev)

def clone(self):
def clone(self) -> Self:
return self.__class__(self.mode)


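For the class-based guards such as `set_grad_enabled`, `__enter__` returns `None` (nothing is bound by `as`), `__exit__` takes `*args: object` and returns `None` (it never suppresses exceptions), and `clone` returns `Self` so subclasses stay precisely typed. A rough standalone analogue (the `set_flag` class and `_grad_enabled` global are invented for the sketch):

```python
from __future__ import annotations

from typing_extensions import Self

_grad_enabled = True


class set_flag:
    def __init__(self, mode: bool) -> None:
        global _grad_enabled
        self.prev = _grad_enabled
        _grad_enabled = mode
        self.mode = mode

    def __enter__(self) -> None:  # nothing is bound by `as`, hence None
        ...

    def __exit__(self, *args: object) -> None:  # restores state, never suppresses
        global _grad_enabled
        _grad_enabled = self.prev

    def clone(self) -> Self:  # Self keeps subclass return types accurate
        return self.__class__(self.mode)


with set_flag(False):
    assert _grad_enabled is False
assert _grad_enabled is True
```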
@@ -511,11 +542,11 @@ class no_grad_(_DecoratorContextManager):
>>> test_layer()
"""

def __enter__(self):
def __enter__(self) -> None:
self.prev = is_grad_enabled()
_set_grad_enabled(False)

def __exit__(self, *args):
def __exit__(self, *args: object) -> None:
_set_grad_enabled(self.prev)


@@ -559,16 +590,16 @@ class enable_grad(_DecoratorContextManager):
>>> assert(z.stop_gradient == False)
"""

def __enter__(self):
def __enter__(self) -> None:
self.prev = is_grad_enabled()
_set_grad_enabled(True)

def __exit__(self, *args):
def __exit__(self, *args: object) -> None:
_set_grad_enabled(self.prev)


@signature_safe_contextmanager
def guard(place=None):
def guard(place: PlaceLike | None = None) -> Generator[None, None, None]:
"""
:api_attr: imperative

@@ -617,36 +648,36 @@ def guard(place=None):

@framework.non_static_only
def grad(
outputs,
inputs,
grad_outputs=None,
retain_graph=None,
create_graph=False,
only_inputs=True,
allow_unused=False,
no_grad_vars=None,
):
outputs: Tensor | Sequence[Tensor],
inputs: Tensor | Sequence[Tensor],
grad_outputs: Tensor | Sequence[Tensor | None] | None = None,
retain_graph: bool | None = None,
create_graph: bool = False,
only_inputs: bool = True,
allow_unused: bool = False,
no_grad_vars: Tensor | Sequence[Tensor] | set[Tensor] | None = None,
) -> list[Tensor]:
'''
.. note::
**This API is ONLY available in imperative mode.**

This API computes the sum of gradients of `outputs` with respect to each `inputs` .

Parameters:
outputs (Tensor|list(Tensor)|tuple(Tensor)): the output Tensor or
outputs (Tensor|list[Tensor]|tuple[Tensor]): the output Tensor or
Tensor list/tuple of the graph to compute gradients.
inputs (Tensor|list(Tensor)|tuple(Tensor)): the input Tensor or
inputs (Tensor|list[Tensor]|tuple[Tensor]): the input Tensor or
Tensor list/tuple of the graph to compute gradients. The returned
values of this API are the gradients of `inputs` .
grad_outputs (Tensor|list(Tensor|None)|tuple(Tensor|None), optional):
grad_outputs (Tensor|list[Tensor|None]|tuple[Tensor|None], optional):
initial gradient values of `outputs` . If `grad_outputs` is None,
the initial gradient values of `outputs` would be Tensors filled with 1;
if `grad_outputs` is not None, it must have the same length as `outputs` ,
and in this case, the initial gradient value of the i-th `outputs` would
be: (1) a Tensor filled with 1 when the i-th element of `grad_outputs`
is None; (2) the i-th element of `grad_outputs` when the i-th element of
`grad_outputs` is a Tensor. Default None.
retain_graph (bool, optional): whether to retain the forward graph which
retain_graph (bool|None, optional): whether to retain the forward graph which
is used to calculate the gradient. When it is True, the graph would
be retained, in which way users can calculate backward twice for the
same graph. When it is False, the graph would be freed. Default None,
@@ -666,7 +697,7 @@ def grad(
`inputs` are unreachable in the graph (i.e., their gradients are None),
error would be raised if allow_unused=False, or None would be returned as
their gradients if allow_unused=True. Default False.
no_grad_vars (Tensor|list(Tensor)|tuple(Tensor)|set(Tensor), optional):
no_grad_vars (Tensor|list[Tensor]|tuple[Tensor]|set[Tensor], optional):
the Tensors whose gradients are not needed to compute. Default None.

Returns:
@@ -687,10 +718,11 @@ def grad(
...
... # Since y = x * x, dx = 2 * x
... dx = paddle.grad(
... outputs=[y],
... inputs=[x],
... create_graph=create_graph,
... retain_graph=True)[0]
... outputs=[y],
... inputs=[x],
... create_graph=create_graph,
... retain_graph=True
... )[0]
...
... z = y + dx
...
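As a complement to the docstring example above, a short sketch of what the annotated signature gives a caller; this assumes a standard Paddle install in dynamic-graph mode and is not part of the diff:

```python
import paddle

x = paddle.to_tensor([1.0, 2.0, 3.0], stop_gradient=False)
y = x * x

# With grad() annotated as (...) -> list[Tensor], a type checker knows that
# unpacking or indexing the result yields a Tensor rather than Any.
(dx,) = paddle.grad(outputs=[y], inputs=[x])
print(dx)  # gradient of sum(y) w.r.t. x, i.e. 2 * x
```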