Skip to content

Commit

Permalink
Merge pull request #481 from brentyi:fix_global_norm_sig
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 507720585
  • Loading branch information
OptaxDev committed Feb 7, 2023
2 parents a77d69c + 3d0c422 commit 4919451
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion optax/_src/linear_algebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from optax._src import numerics


def global_norm(updates: base.Updates) -> base.Updates:
def global_norm(updates: base.PyTree) -> chex.Array:
"""Compute the global norm across a nested structure of tensors."""
return jnp.sqrt(sum(
jnp.sum(numerics.abs_sq(x)) for x in jax.tree_util.tree_leaves(updates)))
Expand Down

0 comments on commit 4919451

Please sign in to comment.