diff --git a/docs/notebooks/annotated_mnist.ipynb b/docs/notebooks/annotated_mnist.ipynb
index 80122d8474..32645bc7c1 100644
--- a/docs/notebooks/annotated_mnist.ipynb
+++ b/docs/notebooks/annotated_mnist.ipynb
@@ -71,11 +71,11 @@
     "## 2. Define network\n",
     "\n",
     "Create a convolutional neural network with the Linen API by subclassing\n",
-    "[`Module`](https://flax.readthedocs.io/en/latest/flax.linen.html#core-module-abstraction).\n",
+    "[Module](https://flax.readthedocs.io/en/latest/flax.linen.html#core-module-abstraction).\n",
     "Because the architecture in this example is relatively simple—you're just\n",
     "stacking layers—you can define the inlined submodules directly within the\n",
     "`__call__` method and wrap it with the\n",
-    "[`@compact`](https://flax.readthedocs.io/en/latest/flax.linen.html#compact-methods)\n",
+    "[@compact](https://flax.readthedocs.io/en/latest/flax.linen.html#compact-methods)\n",
     "decorator."
    ]
   },
@@ -115,10 +115,10 @@
     "## 3. Define loss\n",
     "\n",
     "Define a cross-entropy loss function using just\n",
-    "[`jax.numpy`](https://jax.readthedocs.io/en/latest/jax.numpy.html)\n",
+    "[jax.numpy](https://jax.readthedocs.io/en/latest/jax.numpy.html)\n",
     "that takes the model's logits and label vectors and returns a scalar loss. The\n",
     "labels can be one-hot encoded with\n",
-    "[`jax.nn.one_hot`](https://jax.readthedocs.io/en/latest/_autosummary/jax.nn.one_hot.html),\n",
+    "[jax.nn.one_hot](https://jax.readthedocs.io/en/latest/_autosummary/jax.nn.one_hot.html),\n",
     "as demonstrated below.\n",
     "\n",
     "Note that for demonstration purposes, we return `nn.log_softmax()` from\n",
@@ -249,17 +249,17 @@
     "\n",
     "- Evaluates the neural network given the parameters and a batch of input images\n",
     "  with the\n",
-    "  [`Module.apply`](https://flax.readthedocs.io/en/latest/flax.linen.html#flax.linen.Module.apply)\n",
+    "  [Module.apply](https://flax.readthedocs.io/en/latest/flax.linen.html#flax.linen.Module.apply)\n",
     "  method.\n",
     "- Computes the `cross_entropy_loss` loss function.\n",
     "- Evaluates the loss function and its gradient using\n",
-    "  [`jax.value_and_grad`](https://jax.readthedocs.io/en/latest/jax.html#jax.value_and_grad).\n",
+    "  [jax.value_and_grad](https://jax.readthedocs.io/en/latest/jax.html#jax.value_and_grad).\n",
    "- Applies a\n",
     "  [pytree](https://jax.readthedocs.io/en/latest/pytrees.html#pytrees-and-jax-functions)\n",
     "  of gradients to the optimizer to update the model's parameters.\n",
     "- Computes the metrics using `compute_metrics` (defined earlier).\n",
     "\n",
-    "Use JAX's [`@jit`](https://jax.readthedocs.io/en/latest/jax.html#jax.jit)\n",
+    "Use JAX's [@jit](https://jax.readthedocs.io/en/latest/jax.html#jax.jit)\n",
     "decorator to trace the entire `train_step` function and just-in-time compile\n",
     "it with [XLA](https://www.tensorflow.org/xla) into fused device operations\n",
     "that run faster and more efficiently on hardware accelerators."
@@ -296,7 +296,7 @@
     "## 8. Evaluation step\n",
     "\n",
     "Create a function that evaluates your model on the test set with\n",
-    "[`Module.apply`](https://flax.readthedocs.io/en/latest/flax.linen.html#flax.linen.Module.apply)"
+    "[Module.apply](https://flax.readthedocs.io/en/latest/flax.linen.html#flax.linen.Module.apply)"
    ]
   },
   {
@@ -324,7 +324,7 @@
     "Define a training function that:\n",
     "\n",
     "- Shuffles the training data before each epoch using\n",
-    "  [`jax.random.permutation`](https://jax.readthedocs.io/en/latest/_autosummary/jax.random.permutation.html)\n",
+    "  [jax.random.permutation](https://jax.readthedocs.io/en/latest/_autosummary/jax.random.permutation.html)\n",
     "  that takes a PRNGKey as a parameter (check the\n",
     "  [JAX - the sharp bits](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#JAX-PRNG)).\n",
     "- Runs an optimization step for each batch.\n",
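
For context, the markdown cells touched by this patch describe the notebook's model, loss, and training-step code, which is not itself part of the diff. Below is a minimal, hypothetical sketch of the pattern those cells describe (an `@nn.compact` CNN, a cross-entropy loss built on `jax.nn.one_hot`, and a `@jax.jit`-compiled train step using `Module.apply` and `jax.value_and_grad`). The plain SGD update via `jax.tree_util.tree_map` stands in for whatever optimizer the notebook actually uses, and the shapes and hyperparameters are illustrative, not taken from the notebook.

```python
import jax
import jax.numpy as jnp
from flax import linen as nn

LEARNING_RATE = 0.1  # illustrative value, not from the notebook


class CNN(nn.Module):
  """Simple CNN with submodules declared inline via @nn.compact."""

  @nn.compact
  def __call__(self, x):
    x = nn.Conv(features=32, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape[0], -1))  # flatten
    x = nn.Dense(features=10)(x)
    return nn.log_softmax(x)


def cross_entropy_loss(logits, labels):
  # One-hot encode the integer labels and average the loss over the batch.
  one_hot = jax.nn.one_hot(labels, num_classes=10)
  return -jnp.mean(jnp.sum(one_hot * logits, axis=-1))


@jax.jit
def train_step(params, images, labels):
  """One step: apply the model, take gradients, update the parameters."""

  def loss_fn(p):
    logits = CNN().apply({'params': p}, images)
    return cross_entropy_loss(logits, labels)

  loss, grads = jax.value_and_grad(loss_fn)(params)
  # Plain SGD over the parameter pytree; a real setup would use an optimizer.
  new_params = jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g,
                                      params, grads)
  return new_params, loss


# Usage: initialize parameters with a dummy batch, then run a single step.
rng = jax.random.PRNGKey(0)
params = CNN().init(rng, jnp.ones((1, 28, 28, 1)))['params']
images = jnp.zeros((8, 28, 28, 1))
labels = jnp.zeros((8,), dtype=jnp.int32)
params, loss = train_step(params, images, labels)
```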