diff --git a/optax/_src/transform.py b/optax/_src/transform.py index 8b85d9405..be09b6539 100644 --- a/optax/_src/transform.py +++ b/optax/_src/transform.py @@ -1504,15 +1504,14 @@ def normalize_by_update_norm( >>> def f(x): return jnp.sum(x ** 2) # simple quadratic function >>> solver = optax.normalize_by_update_norm(scale_factor=-1.0) >>> params = jnp.array([1., 2., 3.]) - >>> print('Objective function: ', f(params)) - Objective function: 14.0 + >>> print('Objective function:', f(params)) + Objective function: 14.0 >>> opt_state = solver.init(params) >>> for _ in range(5): ... grad = jax.grad(f)(params) ... updates, opt_state = solver.update(grad, opt_state, params) ... params = optax.apply_updates(params, updates) ... print('Objective function: {:.2E}'.format(f(params))) - Objective function: 14.0 Objective function: 7.52E+00 Objective function: 3.03E+00 Objective function: 5.50E-01