Fixes RST link formatting. #1621

Merged
merged 3 commits on Jan 19, 2022
18 changes: 9 additions & 9 deletions docs/notebooks/annotated_mnist.ipynb
@@ -71,11 +71,11 @@
"## 2. Define network\n",
"\n",
"Create a convolutional neural network with the Linen API by subclassing\n",
"[`Module`](https://flax.readthedocs.io/en/latest/flax.linen.html#core-module-abstraction).\n",
"[Module](https://flax.readthedocs.io/en/latest/flax.linen.html#core-module-abstraction).\n",
"Because the architecture in this example is relatively simple—you're just\n",
"stacking layers—you can define the inlined submodules directly within the\n",
"`__call__` method and wrap it with the\n",
"[`@compact`](https://flax.readthedocs.io/en/latest/flax.linen.html#compact-methods)\n",
"[@compact](https://flax.readthedocs.io/en/latest/flax.linen.html#compact-methods)\n",
"decorator."
]
},
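As a companion to the cell above, here is a minimal sketch of a CNN written in this `@compact` style; the layer sizes and pooling choices are illustrative assumptions, not necessarily the notebook's exact architecture:

```python
import flax.linen as nn

class CNN(nn.Module):
    """A small convnet with submodules declared inline in __call__."""

    @nn.compact
    def __call__(self, x):
        # Each nn.Conv/nn.Dense is created inline on first call; @compact
        # removes the need for a separate setup() method.
        x = nn.Conv(features=32, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.Conv(features=64, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))  # flatten to (batch, features)
        x = nn.Dense(features=256)(x)
        x = nn.relu(x)
        x = nn.Dense(features=10)(x)  # 10 MNIST classes
        return nn.log_softmax(x)
```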
@@ -115,10 +115,10 @@
"## 3. Define loss\n",
"\n",
"Define a cross-entropy loss function using just\n",
"[`jax.numpy`](https://jax.readthedocs.io/en/latest/jax.numpy.html)\n",
"[jax.numpy](https://jax.readthedocs.io/en/latest/jax.numpy.html)\n",
"that takes the model's logits and label vectors and returns a scalar loss. The\n",
"labels can be one-hot encoded with\n",
"[`jax.nn.one_hot`](https://jax.readthedocs.io/en/latest/_autosummary/jax.nn.one_hot.html),\n",
"[jax.nn.one_hot](https://jax.readthedocs.io/en/latest/_autosummary/jax.nn.one_hot.html),\n",
"as demonstrated below.\n",
"\n",
"Note that for demonstration purposes, we return `nn.log_softmax()` from\n",
@@ -249,17 +249,17 @@
"\n",
"- Evaluates the neural network given the parameters and a batch of input images\n",
" with the\n",
" [`Module.apply`](https://flax.readthedocs.io/en/latest/flax.linen.html#flax.linen.Module.apply)\n",
" [Module.apply](https://flax.readthedocs.io/en/latest/flax.linen.html#flax.linen.Module.apply)\n",
" method.\n",
"- Computes the `cross_entropy_loss` loss function.\n",
"- Evaluates the loss function and its gradient using\n",
" [`jax.value_and_grad`](https://jax.readthedocs.io/en/latest/jax.html#jax.value_and_grad).\n",
" [jax.value_and_grad](https://jax.readthedocs.io/en/latest/jax.html#jax.value_and_grad).\n",
"- Applies a\n",
" [pytree](https://jax.readthedocs.io/en/latest/pytrees.html#pytrees-and-jax-functions)\n",
" of gradients to the optimizer to update the model's parameters.\n",
"- Computes the metrics using `compute_metrics` (defined earlier).\n",
"\n",
"Use JAX's [`@jit`](https://jax.readthedocs.io/en/latest/jax.html#jax.jit)\n",
"Use JAX's [@jit](https://jax.readthedocs.io/en/latest/jax.html#jax.jit)\n",
"decorator to trace the entire `train_step` function and just-in-time compile\n",
"it with [XLA](https://www.tensorflow.org/xla) into fused device operations\n",
"that run faster and more efficiently on hardware accelerators."
@@ -296,7 +296,7 @@
"## 8. Evaluation step\n",
"\n",
"Create a function that evaluates your model on the test set with\n",
"[`Module.apply`](https://flax.readthedocs.io/en/latest/flax.linen.html#flax.linen.Module.apply)"
"[Module.apply](https://flax.readthedocs.io/en/latest/flax.linen.html#flax.linen.Module.apply)"
]
},
{
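A matching sketch of the evaluation step, reusing the `CNN` and `compute_metrics` assumed in the sketches above:

```python
import jax

@jax.jit
def eval_step(params, batch):
    # Run the model on a test batch via Module.apply; no gradients needed.
    logits = CNN().apply({'params': params}, batch['image'])
    return compute_metrics(logits=logits, labels=batch['label'])
```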
@@ -324,7 +324,7 @@
"Define a training function that:\n",
"\n",
"- Shuffles the training data before each epoch using\n",
" [`jax.random.permutation`](https://jax.readthedocs.io/en/latest/_autosummary/jax.random.permutation.html)\n",
" [jax.random.permutation](https://jax.readthedocs.io/en/latest/_autosummary/jax.random.permutation.html)\n",
" that takes a PRNGKey as a parameter (check the\n",
" [JAX - the sharp bits](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#JAX-PRNG)).\n",
"- Runs an optimization step for each batch.\n",