Update getting_started.md
chiamp authored Oct 28, 2022
1 parent 8e7e190 commit ed53891
Showing 1 changed file with 4 additions and 4 deletions: docs/getting_started.md
@@ -198,8 +198,8 @@ A function that:
   [Module.apply](https://flax.readthedocs.io/en/latest/flax.linen.html#flax.linen.Module.apply)
   method.
 - Computes the `cross_entropy_loss` loss function.
-- Evaluates the loss function and its gradient using
-  [jax.value_and_grad](https://jax.readthedocs.io/en/latest/jax.html#jax.value_and_grad).
+- Evaluates the gradient of the loss function using
+  [jax.grad](https://jax.readthedocs.io/en/latest/jax.html#jax.grad).
 - Applies a
   [pytree](https://jax.readthedocs.io/en/latest/pytrees.html#pytrees-and-jax-functions)
   of gradients to the optimizer to update the model's parameters.
@@ -220,8 +220,8 @@ def train_step(state, batch):
     logits = CNN().apply({'params': params}, batch['image'])
     loss = cross_entropy_loss(logits=logits, labels=batch['label'])
     return loss, logits
-  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
-  (_, logits), grads = grad_fn(state.params)
+  grad_fn = jax.grad(loss_fn, has_aux=True)
+  grads, logits = grad_fn(state.params)
   state = state.apply_gradients(grads=grads)
   metrics = compute_metrics(logits=logits, labels=batch['label'])
   return state, metrics
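The change above replaces `jax.value_and_grad` with `jax.grad(loss_fn, has_aux=True)`. With `has_aux=True`, JAX expects `loss_fn` to return a pair `(loss, aux)` and `grad_fn` then returns `(grads, aux)` rather than the gradients alone, which is why the unpacking line also changes. A minimal sketch of that behavior, using a hypothetical one-parameter loss in place of the tutorial's `CNN` and `cross_entropy_loss`:

```python
import jax
import jax.numpy as jnp

def loss_fn(params, x):
    # Stand-in for CNN().apply(...): a single scalar weight.
    logits = params['w'] * x
    loss = jnp.sum((logits - 1.0) ** 2)
    # Returning (loss, aux) is what has_aux=True expects.
    return loss, logits

# has_aux=True makes grad_fn return (grads, aux) instead of grads alone.
grad_fn = jax.grad(loss_fn, has_aux=True)

params = {'w': jnp.array(2.0)}
grads, logits = grad_fn(params, jnp.array(3.0))
# d/dw of (w*x - 1)^2 at w=2, x=3 is 2*(6-1)*3 = 30
print(grads['w'], logits)  # 30.0 6.0
```

Note that `grads` keeps the same pytree structure as `params` (here a dict with key `'w'`), which is what lets `state.apply_gradients(grads=grads)` pair each gradient leaf with its parameter.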
