diff --git a/docs/stateful-computations.md b/docs/stateful-computations.md index fe84fc0d7f0a..30c626bec4e3 100644 --- a/docs/stateful-computations.md +++ b/docs/stateful-computations.md @@ -195,7 +195,7 @@ def update(params: Params, x: jnp.ndarray, y: jnp.ndarray) -> Params: # and then use `updates` instead of `grad` to actually update the params. # (And we'd include `new_optimizer_state` in the output, naturally.) - new_params = jax.tree_map( + new_params = jax.tree.map( lambda param, g: param - g * LEARNING_RATE, params, grad) return new_params