diff --git a/docs/getting_started.ipynb b/docs/getting_started.ipynb
index e053e5e88..dbee4ed31 100644
--- a/docs/getting_started.ipynb
+++ b/docs/getting_started.ipynb
@@ -276,8 +276,8 @@
     "  [Module.apply](https://flax.readthedocs.io/en/latest/flax.linen.html#flax.linen.Module.apply)\n",
     "  method.\n",
     "- Computes the `cross_entropy_loss` loss function.\n",
-    "- Evaluates the loss function and its gradient using\n",
-    "  [jax.value_and_grad](https://jax.readthedocs.io/en/latest/jax.html#jax.value_and_grad).\n",
+    "- Evaluates the gradient of the loss function using\n",
+    "  [jax.grad](https://jax.readthedocs.io/en/latest/jax.html#jax.grad).\n",
     "- Applies a\n",
     "  [pytree](https://jax.readthedocs.io/en/latest/pytrees.html#pytrees-and-jax-functions)\n",
     "  of gradients to the optimizer to update the model's parameters.\n",
@@ -304,8 +304,8 @@
     "    logits = CNN().apply({'params': params}, batch['image'])\n",
     "    loss = cross_entropy_loss(logits=logits, labels=batch['label'])\n",
     "    return loss, logits\n",
-    "  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)\n",
-    "  (_, logits), grads = grad_fn(state.params)\n",
+    "  grad_fn = jax.grad(loss_fn, has_aux=True)\n",
+    "  grads, logits = grad_fn(state.params)\n",
     "  state = state.apply_gradients(grads=grads)\n",
     "  metrics = compute_metrics(logits=logits, labels=batch['label'])\n",
     "  return state, metrics"
diff --git a/docs/getting_started.md b/docs/getting_started.md
index fccb06fd9..8d0163d1b 100644
--- a/docs/getting_started.md
+++ b/docs/getting_started.md
@@ -198,8 +198,8 @@ A function that:
   [Module.apply](https://flax.readthedocs.io/en/latest/flax.linen.html#flax.linen.Module.apply)
   method.
 - Computes the `cross_entropy_loss` loss function.
-- Evaluates the loss function and its gradient using
-  [jax.value_and_grad](https://jax.readthedocs.io/en/latest/jax.html#jax.value_and_grad).
+- Evaluates the gradient of the loss function using
+  [jax.grad](https://jax.readthedocs.io/en/latest/jax.html#jax.grad).
 - Applies a
   [pytree](https://jax.readthedocs.io/en/latest/pytrees.html#pytrees-and-jax-functions)
   of gradients to the optimizer to update the model's parameters.
@@ -220,8 +220,8 @@ def train_step(state, batch):
     logits = CNN().apply({'params': params}, batch['image'])
     loss = cross_entropy_loss(logits=logits, labels=batch['label'])
     return loss, logits
-  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
-  (_, logits), grads = grad_fn(state.params)
+  grad_fn = jax.grad(loss_fn, has_aux=True)
+  grads, logits = grad_fn(state.params)
   state = state.apply_gradients(grads=grads)
   metrics = compute_metrics(logits=logits, labels=batch['label'])
   return state, metrics
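
Note on the change: it hinges on the different return shapes of jax.value_and_grad
and jax.grad when has_aux=True. A minimal sketch of the two call patterns, using a
hypothetical stand-in loss function rather than the tutorial's CNN:

    import jax
    import jax.numpy as jnp

    def loss_fn(params):
        logits = params * 2.0        # stand-in for CNN().apply(...)
        loss = jnp.sum(logits ** 2)  # stand-in for cross_entropy_loss(...)
        return loss, logits          # (loss, aux) pair, as required by has_aux=True

    params = jnp.array([1.0, -2.0])

    # Old pattern: value_and_grad returns ((loss, aux), grads); the loss value
    # was already being discarded via the underscore in the tutorial code.
    (loss, logits), grads = jax.value_and_grad(loss_fn, has_aux=True)(params)

    # New pattern: grad returns (grads, aux); the loss value is not returned,
    # and metrics are computed downstream from the logits via compute_metrics.
    grads, logits = jax.grad(loss_fn, has_aux=True)(params)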