
Commit

[Typing][B-85] Add type annotations for `python/paddle/nn/utils/clip_grad_norm_.py` (#65808)
megemini authored Jul 9, 2024
1 parent 4867094 commit 3998e3a
Showing 1 changed file with 12 additions and 5 deletions.
17 changes: 12 additions & 5 deletions python/paddle/nn/utils/clip_grad_norm_.py
@@ -12,18 +12,25 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Iterable
+
 import paddle
 
+if TYPE_CHECKING:
+    from paddle import Tensor
+
 __all__ = []
 
 
 @paddle.autograd.no_grad()
 def clip_grad_norm_(
-    parameters,
-    max_norm,
-    norm_type=2.0,
-    error_if_nonfinite=False,
-):
+    parameters: Iterable[Tensor] | Tensor,
+    max_norm: float,
+    norm_type: float = 2.0,
+    error_if_nonfinite: bool = False,
+) -> Tensor:
     r"""Clips gradient norm of the iteratable parameters.
 
     Norms are calculated together on all gradients, just as they are
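For context, a minimal usage sketch of the annotated signature follows; it is not part of the commit, and the small Linear model is only a placeholder used to produce gradients.

# Minimal usage sketch (assumption: typical dygraph workflow, not from this commit).
# clip_grad_norm_ takes an iterable of Tensors (or a single Tensor), clips the
# combined gradient norm in place, and returns the total norm as a Tensor.
import paddle
from paddle.nn.utils import clip_grad_norm_

linear = paddle.nn.Linear(10, 10)            # placeholder model, just to get gradients
loss = linear(paddle.rand([4, 10])).sum()
loss.backward()

total_norm = clip_grad_norm_(
    linear.parameters(),                     # parameters: Iterable[Tensor] | Tensor
    max_norm=1.0,                            # max_norm: float
    norm_type=2.0,                           # norm_type: float = 2.0
    error_if_nonfinite=False,                # error_if_nonfinite: bool = False
)
print(type(total_norm))                      # a paddle.Tensor, matching the new return annotation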
