Sidebar restructure and minor clean-up #136

Merged: 6 commits on Dec 20, 2024

2 changes: 1 addition & 1 deletion docs/source/JAX_Vision_transformer.ipynb
@@ -5,7 +5,7 @@
"id": "8744b685-7ff5-429a-b610-940506455a54",
"metadata": {},
"source": [
"# Vision Transformer with JAX & FLAX\n",
"# Implement Vision Transformer (ViT) model from scratch\n",
"\n",
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/JAX_Vision_transformer.ipynb)\n",
"\n",
4 changes: 2 additions & 2 deletions docs/source/JAX_Vision_transformer.md
@@ -5,14 +5,14 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.15.2
jupytext_version: 1.16.4
kernelspec:
display_name: Python 3 (ipykernel)
language: python
name: python3
---

# Vision Transformer with JAX & FLAX
# Implement Vision Transformer (ViT) model from scratch

[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/JAX_Vision_transformer.ipynb)

4 changes: 2 additions & 2 deletions docs/source/JAX_basic_text_classification.ipynb
@@ -5,11 +5,11 @@
"id": "072f8ce2-7014-4f04-83a6-96953e9c8a79",
"metadata": {},
"source": [
"# 1D Convnet for basic text classification\n",
"# Basic text classification with 1D CNN\n",
"\n",
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/JAX_basic_text_classification.ipynb)\n",
"\n",
"In this tutorial we learn how to perform text classification from raw text data and train a basic 1D Convnet to perform sentiment analysis using JAX. This tutorial is originally inspired by [\"Text classification from scratch with Keras\"](https://keras.io/examples/nlp/text_classification_from_scratch/#build-a-model).\n",
"In this tutorial we learn how to perform text classification from raw text data and train a basic 1D Convolutional Neural Network to perform sentiment analysis using JAX. This tutorial is originally inspired by [\"Text classification from scratch with Keras\"](https://keras.io/examples/nlp/text_classification_from_scratch/#build-a-model).\n",
"\n",
"We will use the IMDB movie review dataset to classify the review to \"positive\" and \"negative\" classes. We implement from scratch a simple model using Flax, train it and compute metrics on the test set."
]
6 changes: 3 additions & 3 deletions docs/source/JAX_basic_text_classification.md
@@ -5,18 +5,18 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.15.2
jupytext_version: 1.16.4
kernelspec:
display_name: Python 3 (ipykernel)
language: python
name: python3
---

# 1D Convnet for basic text classification
# Basic text classification with 1D CNN

[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/JAX_basic_text_classification.ipynb)

In this tutorial we learn how to perform text classification from raw text data and train a basic 1D Convnet to perform sentiment analysis using JAX. This tutorial is originally inspired by ["Text classification from scratch with Keras"](https://keras.io/examples/nlp/text_classification_from_scratch/#build-a-model).
In this tutorial we learn how to perform text classification from raw text data and train a basic 1D Convolutional Neural Network to perform sentiment analysis using JAX. This tutorial is originally inspired by ["Text classification from scratch with Keras"](https://keras.io/examples/nlp/text_classification_from_scratch/#build-a-model).

We will use the IMDB movie review dataset to classify the review to "positive" and "negative" classes. We implement from scratch a simple model using Flax, train it and compute metrics on the test set.
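
For orientation, a minimal sketch of the kind of model this tutorial builds is shown below. It uses the Flax NNX API, but the vocabulary size, layer widths and kernel size are illustrative assumptions; the notebook itself defines the actual architecture and training loop.

```python
# Illustrative sketch only -- layer sizes and names are assumptions, not the tutorial's code.
import jax.numpy as jnp
from flax import nnx

class TextConvNet(nnx.Module):
    """A small 1D CNN that maps token ids to a single sentiment logit."""

    def __init__(self, rngs: nnx.Rngs, vocab_size: int = 20_000, embed_dim: int = 128):
        self.embed = nnx.Embed(num_embeddings=vocab_size, features=embed_dim, rngs=rngs)
        self.conv = nnx.Conv(in_features=embed_dim, out_features=128, kernel_size=(7,), rngs=rngs)
        self.out = nnx.Linear(in_features=128, out_features=1, rngs=rngs)

    def __call__(self, token_ids):            # token_ids: (batch, seq_len) int32
        x = self.embed(token_ids)              # (batch, seq_len, embed_dim)
        x = nnx.relu(self.conv(x))             # 1D convolution over the sequence axis
        x = jnp.max(x, axis=1)                 # global max pooling over time
        return self.out(x)                     # one logit: "positive" vs. "negative"

model = TextConvNet(nnx.Rngs(0))
logits = model(jnp.ones((2, 200), dtype=jnp.int32))  # shape (2, 1)
```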

2 changes: 1 addition & 1 deletion docs/source/JAX_examples_image_segmentation.ipynb
@@ -5,7 +5,7 @@
"id": "343f53e8-f28e-4fed-a0eb-c5c76c73d5a7",
"metadata": {},
"source": [
"# UNETR for image segmentation\n",
"# Image segmentation with UNETR model\n",
"\n",
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/JAX_examples_image_segmentation.ipynb)\n",
"\n",
4 changes: 2 additions & 2 deletions docs/source/JAX_examples_image_segmentation.md
@@ -5,14 +5,14 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.15.2
jupytext_version: 1.16.4
kernelspec:
display_name: Python 3 (ipykernel)
language: python
name: python3
---

# UNETR for image segmentation
# Image segmentation with UNETR model

[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/JAX_examples_image_segmentation.ipynb)

2 changes: 1 addition & 1 deletion docs/source/JAX_for_LLM_pretraining.ipynb
@@ -6,7 +6,7 @@
"id": "NIOXoY1xgiww"
},
"source": [
"# Pretraining an LLM using JAX\n",
"# Pre-training an LLM (miniGPT)\n",
"\n",
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/JAX_for_LLM_pretraining.ipynb)\n",
"\n",
4 changes: 2 additions & 2 deletions docs/source/JAX_for_LLM_pretraining.md
@@ -5,15 +5,15 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.15.2
jupytext_version: 1.16.4
kernelspec:
display_name: Python 3
name: python3
---

+++ {"id": "NIOXoY1xgiww"}

# Pretraining an LLM using JAX
# Pre-training an LLM (miniGPT)

[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/JAX_for_LLM_pretraining.ipynb)

2 changes: 1 addition & 1 deletion docs/source/JAX_image_captioning.ipynb
@@ -5,7 +5,7 @@
"id": "03487afe-bbca-420c-9b1f-28ea4506c250",
"metadata": {},
"source": [
"# Image Captioning with JAX & FLAX\n",
"# Image Captioning with Vision Transformer (ViT) model\n",
"\n",
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/JAX_image_captioning.ipynb)\n",
"\n",
4 changes: 2 additions & 2 deletions docs/source/JAX_image_captioning.md
@@ -5,14 +5,14 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.15.2
jupytext_version: 1.16.4
kernelspec:
display_name: Python 3 (ipykernel)
language: python
name: python3
---

# Image Captioning with JAX & FLAX
# Image Captioning with Vision Transformer (ViT) model

[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/JAX_image_captioning.ipynb)

4 changes: 2 additions & 2 deletions docs/source/JAX_machine_translation.ipynb
@@ -5,7 +5,7 @@
"id": "ee3e1116-f6cd-497e-b617-1d89d5d1f744",
"metadata": {},
"source": [
"# NLP: JAX Machine Translation\n",
"# Machine Translation with encoder-decoder transformer model\n",
"\n",
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/JAX_machine_translation.ipynb)"
]
@@ -15,7 +15,7 @@
"id": "50f0bd58-dcc6-41f4-9dc4-3a08c8ef751b",
"metadata": {},
"source": [
"Adapted from https://keras.io/examples/nlp/neural_machine_translation_with_transformer/, which is itself an adaptation from https://www.manning.com/books/deep-learning-with-python-second-edition\n",
"This tutorial is adapted from [Keras' documentation on English-to-Spanish translation with a sequence-to-sequence Transformer](https://keras.io/examples/nlp/neural_machine_translation_with_transformer/), which is itself an adaptation from the book [Deep Learning with Python, Second Edition by François Chollet](https://www.manning.com/books/deep-learning-with-python-second-edition)\n",
"\n",
"We step through an encoder-decoder transformer in JAX and train a model for English->Spanish translation."
]
6 changes: 3 additions & 3 deletions docs/source/JAX_machine_translation.md
@@ -5,20 +5,20 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.15.2
jupytext_version: 1.16.4
kernelspec:
display_name: Python 3 (ipykernel)
language: python
name: python3
---

# NLP: JAX Machine Translation
# Machine Translation with encoder-decoder transformer model

[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/JAX_machine_translation.ipynb)

+++

Adapted from https://keras.io/examples/nlp/neural_machine_translation_with_transformer/, which is itself an adaptation from https://www.manning.com/books/deep-learning-with-python-second-edition
This tutorial is adapted from [Keras' documentation on English-to-Spanish translation with a sequence-to-sequence Transformer](https://keras.io/examples/nlp/neural_machine_translation_with_transformer/), which is itself an adaptation from the book [Deep Learning with Python, Second Edition by François Chollet](https://www.manning.com/books/deep-learning-with-python-second-edition)

We step through an encoder-decoder transformer in JAX and train a model for English->Spanish translation.

2 changes: 1 addition & 1 deletion docs/source/JAX_time_series_classification.ipynb
@@ -4,7 +4,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"# Time series classification with JAX\n",
"# Time series classification with CNN\n",
"\n",
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/JAX_time_series_classification.ipynb)\n",
"\n",
4 changes: 2 additions & 2 deletions docs/source/JAX_time_series_classification.md
@@ -5,14 +5,14 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.15.2
jupytext_version: 1.16.4
kernelspec:
display_name: jax-env
language: python
name: python3
---

# Time series classification with JAX
# Time series classification with CNN

[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/JAX_time_series_classification.ipynb)

2 changes: 1 addition & 1 deletion docs/source/JAX_visualizing_models_metrics.ipynb
@@ -4,7 +4,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"# JAX and Tensorboard / NNX Display\n",
"# Visualize JAX model metrics with TensorBoard\n",
"\n",
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/JAX_visualizing_models_metrics.ipynb)"
]
4 changes: 2 additions & 2 deletions docs/source/JAX_visualizing_models_metrics.md
@@ -5,14 +5,14 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.15.2
jupytext_version: 1.16.4
kernelspec:
display_name: Python 3 (ipykernel)
language: python
name: python3
---

# JAX and Tensorboard / NNX Display
# Visualize JAX model metrics with TensorBoard

[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/JAX_visualizing_models_metrics.ipynb)

17 changes: 8 additions & 9 deletions docs/source/contributing.md
@@ -1,6 +1,4 @@
# Developer docs

## Contributing to the JAX AI Stack Documentation
# Contribute to documentation

The documentation in the `jax-ai-stack` repository is meant to build on documentation
of individual packages, and specifically cover topics that touch on multiple packages
@@ -38,24 +36,25 @@ To add a new notebook to the repository, first move the notebook into the approp
location in the `docs` directory:

```bash
mv ~/new-tutorial.ipynb docs/new_tutorial.ipynb
mv ~/new-tutorial.ipynb docs/source/new_tutorial.ipynb
```

Next, we use `jupytext` to mark the notebook for syncing with Markdown:

```bash
jupytext --set-formats ipynb,md:myst docs/new_tutorial.ipynb
jupytext --set-formats ipynb,md:myst docs/source/new_tutorial.ipynb
```

Finally, we can sync the notebook and markdown source:

```bash
jupytext --sync docs/new_tutorial.ipynb
jupytext --sync docs/source/new_tutorial.ipynb
```

To ensure that the new notebook is rendered as part of the site, be sure to add
references to a `toctree` declaration somewhere in the source tree, for example
in `docs/tutorials.md`. You will also need to add references in `docs/conf.py`
in `docs/source/tutorials.md` or `docs/source/examples.md`.
You will also need to add references in `docs/conf.py`
to specify whether the notebook should be executed, and to specify which file
sphinx should use when generating the site.
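
For reference, a minimal sketch of what those additions could look like is below. It assumes a standard Sphinx + myst-nb setup and reuses the hypothetical `new_tutorial` file from the steps above; the actual option names in `docs/conf.py` may differ, so treat this as an illustration rather than the project's exact configuration. The `toctree` entry itself is just the page name on its own line inside the existing list in `docs/source/tutorials.md` or `docs/source/examples.md`.

```python
# docs/conf.py -- illustrative sketch, not necessarily this repository's exact settings.

# Build the page from only one of the paired files (here, the MyST Markdown copy):
exclude_patterns = [
    "new_tutorial.ipynb",
]

# Skip executing the notebook at build time if it is slow or needs an accelerator
# (nb_execution_excludepatterns is a myst-nb option):
nb_execution_excludepatterns = [
    "new_tutorial.md",
]
```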

@@ -70,9 +69,9 @@ you can do the following:

```bash
pip install pre-commit
git add docs/new_tutorial.* # stage the new changes
git add docs/source/new_tutorial.* # stage the new changes
pre-commit run # run pre-commit checks on added files
git add docs/new_tutorial.* # stage the files updated by pre-commit
git add docs/source/new_tutorial.* # stage the files updated by pre-commit
git commit -m "update new tutorial" # commit to the branch
```

10 changes: 10 additions & 0 deletions docs/source/data_loaders.md
@@ -0,0 +1,10 @@
# Introduction to Data Loaders

Learn about data loading strategies on different hardware systems:

```{toctree}
:maxdepth: 1

data_loaders_on_cpu_with_jax
data_loaders_on_gpu_with_jax
```
6 changes: 3 additions & 3 deletions docs/source/digits_diffusion_model.ipynb
@@ -6,12 +6,12 @@
"id": "Kzqlx7fpXRnJ"
},
"source": [
"# Image generation in JAX: a simple diffusion model\n",
"# Simple diffusion model for image generation in JAX\n",
"\n",
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/digits_diffusion_model.ipynb)\n",
"\n",
"In [Debugging in JAX: a Variational autoencoder (VAE) model](digits_vae.ipynb) you explored a simplified version of a [Variational Autoencoder (VAE)](https://en.wikipedia.org/wiki/Variational_autoencoder) trained on the simple digits data. In this tutorial you will find the steps to develop, train and perform inferences with a simple diffusion model developed with JAX, Flax, NNX and Optax. It includes:\n",
"- preparing the dataset\n",
"In [Variational autoencoder (VAE) and debugging in JAX](digits_vae.ipynb) you explored a simplified version of a [Variational Autoencoder (VAE)](https://en.wikipedia.org/wiki/Variational_autoencoder) trained on the simple digits data. In this tutorial, you will find the steps to develop, train and perform inferences with a simple diffusion model developed with JAX, Flax, NNX and Optax. It includes:\n",
"- Preparing the dataset\n",
"- Developing the custom diffusion model\n",
"- Creating the loss and training functions\n",
"- Perform the model training using Colab TPU v2 as a hardware accelerator\n",
8 changes: 4 additions & 4 deletions docs/source/digits_diffusion_model.md
@@ -5,20 +5,20 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.15.2
jupytext_version: 1.16.4
kernelspec:
display_name: Python 3
name: python3
---

+++ {"id": "Kzqlx7fpXRnJ"}

# Image generation in JAX: a simple diffusion model
# Simple diffusion model for image generation in JAX

[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/digits_diffusion_model.ipynb)

In [Debugging in JAX: a Variational autoencoder (VAE) model](digits_vae.ipynb) you explored a simplified version of a [Variational Autoencoder (VAE)](https://en.wikipedia.org/wiki/Variational_autoencoder) trained on the simple digits data. In this tutorial you will find the steps to develop, train and perform inferences with a simple diffusion model developed with JAX, Flax, NNX and Optax. It includes:
- preparing the dataset
In [Variational autoencoder (VAE) and debugging in JAX](digits_vae.ipynb) you explored a simplified version of a [Variational Autoencoder (VAE)](https://en.wikipedia.org/wiki/Variational_autoencoder) trained on the simple digits data. In this tutorial, you will find the steps to develop, train and perform inferences with a simple diffusion model developed with JAX, Flax, NNX and Optax. It includes:
- Preparing the dataset
- Developing the custom diffusion model
- Creating the loss and training functions
- Perform the model training using Colab TPU v2 as a hardware accelerator
16 changes: 16 additions & 0 deletions docs/source/examples.md
@@ -0,0 +1,16 @@
# Example applications

The following pages provide examples of common applications of the JAX AI stack:

```{toctree}
:maxdepth: 1

JAX_for_LLM_pretraining
JAX_basic_text_classification
JAX_transformer_text_classification
JAX_machine_translation
JAX_examples_image_segmentation
JAX_image_captioning
JAX_Vision_transformer
JAX_time_series_classification
```
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -10,4 +10,5 @@ Jax AI Stack

install
tutorials
examples
contributing
25 changes: 7 additions & 18 deletions docs/source/tutorials.md
@@ -1,32 +1,21 @@
# Tutorials

```{note}
This is a work in progress; visit again soon for updated content!
```

The following tutorials are meant as an intro to the full stack:
The following tutorials are meant as an introduction to the full stack:

```{toctree}
:maxdepth: 2
:maxdepth: 1

getting_started_with_jax_for_AI
digits_vae
JAX_for_PyTorch_users
JAX_porting_PyTorch_model
digits_diffusion_model
JAX_for_LLM_pretraining
JAX_basic_text_classification
JAX_examples_image_segmentation
JAX_Vision_transformer
JAX_machine_translation
JAX_visualizing_models_metrics
JAX_image_captioning
JAX_time_series_classification
JAX_transformer_text_classification
data_loaders_on_cpu_with_jax
data_loaders_on_gpu_with_jax
data_loaders
JAX_for_PyTorch_users
JAX_porting_PyTorch_model
```

## Further references

Once you've gone through this content, you can refer to package-specific
documentation for resources that go into more depth on various topics:
