Skip to content

Commit

Permalink
Fix doctest normalize_by_update_norm
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 650580171
  • Loading branch information
vroulet authored and OptaxDev committed Jul 9, 2024
1 parent 5b55ccc commit bdb6050
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions optax/_src/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -1504,15 +1504,14 @@ def normalize_by_update_norm(
>>> def f(x): return jnp.sum(x ** 2) # simple quadratic function
>>> solver = optax.normalize_by_update_norm(scale_factor=-1.0)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function: 14.0
>>> print('Objective function:', f(params))
Objective function: 14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
... grad = jax.grad(f)(params)
... updates, opt_state = solver.update(grad, opt_state, params)
... params = optax.apply_updates(params, updates)
... print('Objective function: {:.2E}'.format(f(params)))
Objective function: 14.0
Objective function: 7.52E+00
Objective function: 3.03E+00
Objective function: 5.50E-01
Expand Down

0 comments on commit bdb6050

Please sign in to comment.