fix(clip): fix clip_grad_norm for empty tensor trees (#118)
* chore: update README / cmake / conda / make files

* fix(clip): fix `clip_grad_norm` for empty tensor trees
XuehaiPan committed Nov 22, 2022
1 parent 3e7c857 commit 927de85
Showing 1 changed file with 12 additions and 7 deletions.
19 changes: 12 additions & 7 deletions torchopt/clip.py
--- a/torchopt/clip.py
+++ b/torchopt/clip.py
@@ -17,6 +17,8 @@
 # ==============================================================================
 """Utilities for gradient clipping."""
 
+from typing import Union
+
 import torch
 
 from torchopt import pytree
@@ -30,12 +32,18 @@
 
 
 def clip_grad_norm(
-    max_norm: float, norm_type: float = 2.0, error_if_nonfinite: bool = False
+    max_norm: Union[float, int],
+    norm_type: Union[float, int] = 2.0,
+    error_if_nonfinite: bool = False,
 ) -> GradientTransformation:
     """Clips gradient norm of an iterable of parameters.
 
     Args:
-        max_delta: The maximum absolute value for each element in the update.
+        max_norm (float or int): The maximum absolute value for each element in the update.
+        norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
+            infinity norm.
+        error_if_nonfinite (bool): if :data:`True`, an error is thrown if the total norm of the
+            gradients from :attr:`updates` is ``nan``, ``inf``, or ``-inf``.
 
     Returns:
         An ``(init_fn, update_fn)`` tuple.
@@ -45,12 +53,9 @@ def init_fn(params):  # pylint: disable=unused-argument
         return ClipState()
 
     def update_fn(updates, state, *, params=None, inplace=True):  # pylint: disable=unused-argument
-        available_updates = []
-        for g in updates:
-            if g is not None:
-                available_updates.append(g)
+        available_updates = pytree.tree_leaves(updates)
         if len(available_updates) == 0:
-            return torch.tensor(0.0)
+            return updates, state
         device = available_updates[0].device
         with torch.no_grad():
             if norm_type == torch.inf:
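In practice, the fix changes what `update_fn` returns when the tree of updates has no leaves: previously it returned a bare `torch.tensor(0.0)`, breaking the `(updates, state)` contract. Below is a minimal sketch of the fixed behavior; it is not part of the commit, and it assumes the optax-style `init`/`update` interface implied by the docstring's ``(init_fn, update_fn)`` tuple.

    # Minimal sketch (not part of the commit), assuming the (init_fn, update_fn)
    # interface named in the docstring above.
    from torchopt.clip import clip_grad_norm

    clipper = clip_grad_norm(max_norm=1.0)

    # An empty tuple is a tensor tree with no leaves, so pytree.tree_leaves(())
    # returns [] and update_fn takes the early-return branch.
    state = clipper.init(())
    updates, state = clipper.update((), state)

    # Before this commit that branch returned a bare torch.tensor(0.0), so the
    # unpacking on the line above would fail; now the unchanged (updates, state)
    # pair comes back.
    assert updates == ()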
