Fix typos in VAE tutorial (#75)
jakevdp authored Oct 18, 2024
1 parent 04144e3 commit 7654af7
Showing 2 changed files with 14 additions and 10 deletions.
12 changes: 7 additions & 5 deletions docs/digits_vae.ipynb
@@ -8,6 +8,8 @@
"source": [
"# Debugging in JAX: a Variational autoencoder (VAE) model\n",
"\n",
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/digits_vae.ipynb)\n",
"\n",
"In [Getting started with JAX](https://jax-ai-stack.readthedocs.io/en/latest/getting_started_with_jax_for_AI.html) we built a simple neural network for classification of handwritten digits, and covered some of the key features of JAX, including its NumPy-style interface in the `jax.numpy`, as well as its transformations for JIT compilation with `jax.jit`, automatic vectorization with `jax.vmap`, and automatic differentiation with `jax.grad`.\n",
"\n",
"This tutorial will explore a slightly more involved model: a simplified version of a [Variational Autoencoder (VAE)](https://en.wikipedia.org/wiki/Variational_autoencoder) trained on the same simple digits data. Along the way, we'll learn a bit more about how JAX's JIT compilation actually works, and what this means for debugging JAX programs."
@@ -110,7 +112,7 @@
"source": [
"In this network we had one output per class, and the loss function was designed such that once trained, the output corresponding to the correct class would return the strongest signal, thus predicting the correct label in upwards of 95% of cases.\n",
"\n",
"In this VAE example we use similar building blocks to instead output a small probabilisitic model representing the data. While classic `VAE` is generally based on convolutional layers, we use linear layers for simplicity. The sub-network that produces this probabilistic encoding is our `Encoder`:"
"In this VAE example we use similar building blocks to instead output a small probabilistic model representing the data. While classic `VAE` is generally based on convolutional layers, we use linear layers for simplicity. The sub-network that produces this probabilistic encoding is our `Encoder`:"
]
},
{
@@ -180,7 +182,7 @@
},
"source": [
"Now the full VAE model is a single network built from the encoder and decoder.\n",
"It returns both the reconstructed image and then internal latent space model:"
"It returns both the reconstructed image and the internal latent space model:"
]
},
{
@@ -225,7 +227,7 @@
"1. the `logits` output faithfully reconstruct the input image.\n",
"2. the model represented by `mean` and `std` faithfully represents the \"true\" latent distribution.\n",
"\n",
"VAE uses a loss function based on the [Evidence lower bound](https://en.wikipedia.org/wiki/Evidence_lower_bound) to quantify theset two goals in a single loss value:"
"VAE uses a loss function based on the [Evidence lower bound](https://en.wikipedia.org/wiki/Evidence_lower_bound) to quantify these two goals in a single loss value:"
]
},
{
@@ -498,7 +500,7 @@
"\n",
"optimizer = nnx.Optimizer(model, optax.adam(1e-3))\n",
"\n",
"for epoch in range(301):\n",
"for epoch in range(501):\n",
" loss = train_step(model, optimizer, images_train)\n",
" if epoch % 50 == 0:\n",
" print(f'Epoch {epoch} loss: {loss}')"
@@ -512,7 +514,7 @@
"source": [
"It looks like our loss value is decreasing toward negative infinity until the point where the values are no longer well-represented by floating point math.\n",
"\n",
"At this point, we may wish to expect the values within the loss function itself to see where the diverging loss might be coming from.\n",
"At this point, we may wish to inspect the values within the loss function itself to see where the diverging loss might be coming from.\n",
"In typical Python programs you can do this by inserting either a `print` statement or a `breakpoint` in the loss function; it might look something like this:"
]
},
12 changes: 7 additions & 5 deletions docs/digits_vae.md
@@ -15,6 +15,8 @@ kernelspec:

# Debugging in JAX: a Variational autoencoder (VAE) model

+[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/digits_vae.ipynb)
+
In [Getting started with JAX](https://jax-ai-stack.readthedocs.io/en/latest/getting_started_with_jax_for_AI.html) we built a simple neural network for classification of handwritten digits, and covered some of the key features of JAX, including its NumPy-style interface in the `jax.numpy` module, as well as its transformations for JIT compilation with `jax.jit`, automatic vectorization with `jax.vmap`, and automatic differentiation with `jax.grad`.

This tutorial will explore a slightly more involved model: a simplified version of a [Variational Autoencoder (VAE)](https://en.wikipedia.org/wiki/Variational_autoencoder) trained on the same simple digits data. Along the way, we'll learn a bit more about how JAX's JIT compilation actually works, and what this means for debugging JAX programs.
@@ -78,7 +80,7 @@ class SimpleNN(nnx.Module):

In this network we had one output per class, and the loss function was designed such that once trained, the output corresponding to the correct class would return the strongest signal, thus predicting the correct label in upwards of 95% of cases.

-In this VAE example we use similar building blocks to instead output a small probabilisitic model representing the data. While classic `VAE` is generally based on convolutional layers, we use linear layers for simplicity. The sub-network that produces this probabilistic encoding is our `Encoder`:
+In this VAE example we use similar building blocks to instead output a small probabilistic model representing the data. While classic `VAE` is generally based on convolutional layers, we use linear layers for simplicity. The sub-network that produces this probabilistic encoding is our `Encoder`:
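A rough sketch of what such an encoder can look like with Flax NNX (the hidden width, attribute names, and the log-std head are illustrative assumptions, not the tutorial's exact cell):

```python
import jax
import jax.numpy as jnp
from flax import nnx

class Encoder(nnx.Module):
    """Maps an input batch to the parameters of a Gaussian over latent space."""

    def __init__(self, input_size: int, latent_size: int, *, rngs: nnx.Rngs):
        self.hidden = nnx.Linear(input_size, 32, rngs=rngs)
        self.mean_head = nnx.Linear(32, latent_size, rngs=rngs)
        self.log_std_head = nnx.Linear(32, latent_size, rngs=rngs)

    def __call__(self, x: jax.Array):
        h = jax.nn.relu(self.hidden(x))
        # Parameterize the standard deviation through its log so it stays positive.
        return self.mean_head(h), jnp.exp(self.log_std_head(h))
```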

```{code-cell}
:id: Hj7mtR5vmcGr
@@ -128,7 +130,7 @@ class Decoder(nnx.Module):
+++ {"id": "0QaT-KY6npSc"}

Now the full VAE model is a single network built from the encoder and decoder.
-It returns both the reconstructed image and then internal latent space model:
+It returns both the reconstructed image and the internal latent space model:
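A minimal sketch of that composition, assuming the `Encoder` and `Decoder` defined above and an `nnx.Rngs` stream for the sampling noise (constructor signatures here are illustrative assumptions):

```python
class VAE(nnx.Module):
    """Sketch of a VAE wiring an encoder and decoder together."""

    def __init__(self, input_size: int, latent_size: int, *, rngs: nnx.Rngs):
        self.encoder = Encoder(input_size, latent_size, rngs=rngs)
        self.decoder = Decoder(latent_size, input_size, rngs=rngs)
        self.rngs = rngs  # kept so we can draw fresh sampling noise at call time

    def __call__(self, x: jax.Array):
        mean, std = self.encoder(x)
        # Reparameterization trick: z = mean + std * eps with eps ~ N(0, 1),
        # which keeps the sampling step differentiable with respect to mean and std.
        eps = jax.random.normal(self.rngs.noise(), mean.shape)
        z = mean + std * eps
        logits = self.decoder(z)
        return logits, mean, std
```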

```{code-cell}
:id: Myo2MdxXnzlT
@@ -163,7 +165,7 @@ Next is the loss function – there are two components to the model that we want
1. the `logits` output faithfully reconstruct the input image.
2. the model represented by `mean` and `std` faithfully represents the "true" latent distribution.

-VAE uses a loss function based on the [Evidence lower bound](https://en.wikipedia.org/wiki/Evidence_lower_bound) to quantify theset two goals in a single loss value:
+VAE uses a loss function based on the [Evidence lower bound](https://en.wikipedia.org/wiki/Evidence_lower_bound) to quantify these two goals in a single loss value:
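A minimal sketch of such an ELBO-style loss, assuming a sigmoid binary-cross-entropy reconstruction term plus the closed-form KL divergence between the encoded Gaussian and a standard-normal prior (the tutorial's exact formulation may differ):

```python
import jax.numpy as jnp
import optax

def elbo_loss(logits, images, mean, std):
    # Goal 1: the decoder logits should reproduce the input pixels.
    reconstruction = jnp.mean(optax.sigmoid_binary_cross_entropy(logits, images))
    # Goal 2: KL(N(mean, std^2) || N(0, 1)) keeps the latent distribution close to the prior.
    kl = jnp.mean(0.5 * (mean**2 + std**2 - 2.0 * jnp.log(std) - 1.0))
    return reconstruction + kl
```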

```{code-cell}
:id: bMpxj8-Wsvui
@@ -272,7 +274,7 @@ model = VAE(
optimizer = nnx.Optimizer(model, optax.adam(1e-3))
-for epoch in range(301):
+for epoch in range(501):
loss = train_step(model, optimizer, images_train)
if epoch % 50 == 0:
print(f'Epoch {epoch} loss: {loss}')
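`train_step` is defined in an earlier cell of the tutorial; a sketch of what such a step typically looks like with NNX and Optax, reusing the illustrative `elbo_loss` from above:

```python
@nnx.jit  # compile the whole update step; nnx.jit tracks model and optimizer state
def train_step(model: VAE, optimizer: nnx.Optimizer, images: jax.Array):
    def loss_fn(model: VAE):
        logits, mean, std = model(images)
        return elbo_loss(logits, images, mean, std)

    loss, grads = nnx.value_and_grad(loss_fn)(model)
    optimizer.update(grads)  # apply the Adam update in place
    return loss
```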
@@ -282,7 +284,7 @@

It looks like our loss value is decreasing toward negative infinity until the point where the values are no longer well-represented by floating point math.

-At this point, we may wish to expect the values within the loss function itself to see where the diverging loss might be coming from.
+At this point, we may wish to inspect the values within the loss function itself to see where the diverging loss might be coming from.
In typical Python programs you can do this by inserting either a `print` statement or a `breakpoint` in the loss function; it might look something like this:
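Purely as an illustrative sketch (not necessarily the tutorial's own cell), instrumenting the loss with a bare `print` might look like:

```python
def loss_with_print(model, images):
    logits, mean, std = model(images)
    print("mean:", mean)  # under jax.jit this runs at trace time and shows tracers, not values
    print("std:", std)
    return elbo_loss(logits, images, mean, std)
```

Because `jax.jit` traces the function with abstract values, a plain `print` like this fires only once during tracing; JAX provides `jax.debug.print` for printing actual runtime values inside compiled code.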

```{code-cell}
