From 8c46b756501e1a7b1626ddc2eec1482423afab51 Mon Sep 17 00:00:00 2001 From: Vincent Roulet Date: Tue, 9 Jul 2024 05:44:29 -0700 Subject: [PATCH] Fix doctest normalize_by_update_norm PiperOrigin-RevId: 650594334 --- optax/_src/transform.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/optax/_src/transform.py b/optax/_src/transform.py index 8b85d9405..be09b6539 100644 --- a/optax/_src/transform.py +++ b/optax/_src/transform.py @@ -1504,15 +1504,14 @@ def normalize_by_update_norm( >>> def f(x): return jnp.sum(x ** 2) # simple quadratic function >>> solver = optax.normalize_by_update_norm(scale_factor=-1.0) >>> params = jnp.array([1., 2., 3.]) - >>> print('Objective function: ', f(params)) - Objective function: 14.0 + >>> print('Objective function:', f(params)) + Objective function: 14.0 >>> opt_state = solver.init(params) >>> for _ in range(5): ... grad = jax.grad(f)(params) ... updates, opt_state = solver.update(grad, opt_state, params) ... params = optax.apply_updates(params, updates) ... print('Objective function: {:.2E}'.format(f(params))) - Objective function: 14.0 Objective function: 7.52E+00 Objective function: 3.03E+00 Objective function: 5.50E-01