Skip to content

Commit

Permalink
[Typing][B-40] Add type annotations for `python/paddle/autograd/backw…
Browse files Browse the repository at this point in the history
…ard_mode.py` (PaddlePaddle#66277)
  • Loading branch information
tlxd authored and inaomIIsfarell committed Jul 31, 2024
1 parent e81b518 commit bb965fe
Showing 1 changed file with 19 additions and 5 deletions.
24 changes: 19 additions & 5 deletions python/paddle/autograd/backward_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,27 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import TYPE_CHECKING, Sequence

import paddle
from paddle.base import core, framework
from paddle.base.backward import gradients_with_optimizer # noqa: F401

if TYPE_CHECKING:
from paddle import Tensor


__all__ = []


@framework.dygraph_only
def backward(tensors, grad_tensors=None, retain_graph=False):
def backward(
tensors: list[Tensor],
grad_tensors: list[Tensor | None] | None = None,
retain_graph: bool = False,
) -> None:
"""
Compute the backward gradients of given tensors.
Expand Down Expand Up @@ -80,19 +92,21 @@ def backward(tensors, grad_tensors=None, retain_graph=False):
"""

def check_tensors(in_out_list, name):
def check_tensors(
in_out_list: Sequence[Tensor] | Tensor, name: str
) -> Sequence[Tensor]:
assert in_out_list is not None, f"{name} should not be None"

if isinstance(in_out_list, (list, tuple)):
assert len(in_out_list) > 0, f"{name} cannot be empty"
for each_var in in_out_list:
assert isinstance(
each_var, (paddle.Tensor, core.eager.Tensor)
each_var, paddle.Tensor
), f"Elements of {name} must be paddle.Tensor"
return in_out_list
else:
assert isinstance(
in_out_list, (paddle.Tensor, core.eager.Tensor)
in_out_list, paddle.Tensor
), f"{name} must be Tensor or list of Tensor"
return [in_out_list]

Expand All @@ -109,7 +123,7 @@ def check_tensors(in_out_list, name):
for each_tensor in grad_tensors:
if each_tensor is not None:
assert isinstance(
each_tensor, (paddle.Tensor, core.eager.Tensor)
each_tensor, paddle.Tensor
), "The argument 'grad_tensors' of paddle.autograd.backward is invalid, it can be 'None', 'paddle.Tensor' or 'list[None/paddle.Tensor]'."
else:
grad_tensors = []
Expand Down

0 comments on commit bb965fe

Please sign in to comment.