Commit: Add open in colab links to docs (#124)

jakevdp authored Dec 5, 2024
1 parent e1022e1 commit 2801700

Showing 33 changed files with 59 additions and 23 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -24,7 +24,7 @@ using many of the same open source packages that Google developers are using
 in their everyday work.
 
 To get started with the JAX AI stack, you can check out [Getting started with JAX](
-https://github.com/jax-ml/jax-ai-stack/blob/main/docs/getting_started_with_jax_for_AI.ipynb).
+https://github.com/jax-ml/jax-ai-stack/blob/main/docs/source/getting_started_with_jax_for_AI.ipynb).
 This is still a work in progress; please check back for more documentation and tutorials
 in the coming weeks!
2 changes: 2 additions & 0 deletions docs/source/JAX_Vision_transformer.ipynb
@@ -7,6 +7,8 @@
 "source": [
 "# Vision Transformer with JAX & FLAX\n",
 "\n",
+"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/JAX_Vision_transformer.ipynb)\n",
+"\n",
 "In this tutorial we implement the Vision Transformer (ViT) model from scratch, based on the paper by Dosovitskiy et al.: [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929). We load the ImageNet pretrained weights and finetune this model on the [Food 101](https://huggingface.co/datasets/ethz/food101) dataset.\n",
 "This tutorial is originally inspired by the [HuggingFace Image classification tutorial](https://huggingface.co/docs/transformers/tasks/image_classification)."
 ]
2 changes: 2 additions & 0 deletions docs/source/JAX_Vision_transformer.md
@@ -14,6 +14,8 @@ kernelspec:
 
 # Vision Transformer with JAX & FLAX
 
+[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/JAX_Vision_transformer.ipynb)
+
 In this tutorial we implement the Vision Transformer (ViT) model from scratch, based on the paper by Dosovitskiy et al.: [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929). We load the ImageNet pretrained weights and finetune this model on the [Food 101](https://huggingface.co/datasets/ethz/food101) dataset.
 This tutorial is originally inspired by the [HuggingFace Image classification tutorial](https://huggingface.co/docs/transformers/tasks/image_classification).
 
2 changes: 2 additions & 0 deletions docs/source/JAX_basic_text_classification.ipynb
@@ -7,6 +7,8 @@
 "source": [
 "# 1D Convnet for basic text classification\n",
 "\n",
+"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/JAX_basic_text_classification.ipynb)\n",
+"\n",
 "In this tutorial we learn how to perform text classification from raw text data and train a basic 1D Convnet to perform sentiment analysis using JAX. This tutorial is originally inspired by [\"Text classification from scratch with Keras\"](https://keras.io/examples/nlp/text_classification_from_scratch/#build-a-model).\n",
 "\n",
 "We will use the IMDB movie review dataset to classify reviews into \"positive\" and \"negative\" classes. We implement a simple model from scratch using Flax, train it, and compute metrics on the test set."
2 changes: 2 additions & 0 deletions docs/source/JAX_basic_text_classification.md
@@ -14,6 +14,8 @@ kernelspec:
 
 # 1D Convnet for basic text classification
 
+[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/JAX_basic_text_classification.ipynb)
+
 In this tutorial we learn how to perform text classification from raw text data and train a basic 1D Convnet to perform sentiment analysis using JAX. This tutorial is originally inspired by ["Text classification from scratch with Keras"](https://keras.io/examples/nlp/text_classification_from_scratch/#build-a-model).
 
 We will use the IMDB movie review dataset to classify reviews into "positive" and "negative" classes. We implement a simple model from scratch using Flax, train it, and compute metrics on the test set.
1 change: 1 addition & 0 deletions docs/source/JAX_examples_image_segmentation.ipynb
@@ -7,6 +7,7 @@
 "source": [
 "# UNETR for image segmentation\n",
 "\n",
+"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/JAX_examples_image_segmentation.ipynb)\n",
 "\n",
 "This tutorial demonstrates how to implement and train a model on an image segmentation task. Below, we will be using the [Oxford Pets dataset](https://www.robots.ox.ac.uk/%7Evgg/data/pets/) containing images and masks of cats and dogs. We will implement the [UNETR](https://arxiv.org/abs/2103.10504) model from scratch using Flax NNX. We will train the model on a training set and compute image segmentation metrics on the training and validation sets. We will use the [Orbax checkpoint manager](https://orbax.readthedocs.io/en/latest/api_reference/checkpoint.checkpoint_manager.html) to store the best models during training."
 ]
1 change: 1 addition & 0 deletions docs/source/JAX_examples_image_segmentation.md
@@ -14,6 +14,7 @@ kernelspec:
 
 # UNETR for image segmentation
 
+[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/JAX_examples_image_segmentation.ipynb)
 
 This tutorial demonstrates how to implement and train a model on an image segmentation task. Below, we will be using the [Oxford Pets dataset](https://www.robots.ox.ac.uk/%7Evgg/data/pets/) containing images and masks of cats and dogs. We will implement the [UNETR](https://arxiv.org/abs/2103.10504) model from scratch using Flax NNX. We will train the model on a training set and compute image segmentation metrics on the training and validation sets. We will use the [Orbax checkpoint manager](https://orbax.readthedocs.io/en/latest/api_reference/checkpoint.checkpoint_manager.html) to store the best models during training.
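For illustration, best-model checkpointing with the Orbax checkpoint manager might be wired up as in the sketch below. The tutorial's own training code is elided from this diff, so the metric name `val_iou`, the directory, and `max_to_keep` are assumptions, not values from the notebook:

```python
import orbax.checkpoint as ocp

# Keep only the highest-scoring checkpoints, ranked by a validation metric.
options = ocp.CheckpointManagerOptions(
    max_to_keep=2,
    best_fn=lambda metrics: metrics["val_iou"],  # hypothetical metric name
    best_mode="max",  # higher IoU is better
)
mngr = ocp.CheckpointManager("/tmp/unetr-checkpoints", options=options)

# Inside the training loop, after computing validation metrics, something like:
# mngr.save(step, args=ocp.args.StandardSave(train_state), metrics={"val_iou": iou})
```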

2 changes: 1 addition & 1 deletion docs/source/JAX_for_LLM_pretraining.ipynb
@@ -8,7 +8,7 @@
 "source": [
 "# Pretraining an LLM using JAX\n",
 "\n",
-"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/JAX_for_LLM_pretraining.ipynb)\n",
+"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/JAX_for_LLM_pretraining.ipynb)\n",
 "\n",
 "This tutorial demonstrates how to use JAX/Flax for LLM pretraining via data and tensor parallelism. It is originally inspired by this [Keras miniGPT tutorial](https://keras.io/examples/generative/text_generation_with_miniature_gpt/).\n",
 "\n",
2 changes: 1 addition & 1 deletion docs/source/JAX_for_LLM_pretraining.md
@@ -15,7 +15,7 @@ kernelspec:
 
 # Pretraining an LLM using JAX
 
-[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/JAX_for_LLM_pretraining.ipynb)
+[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/JAX_for_LLM_pretraining.ipynb)
 
 This tutorial demonstrates how to use JAX/Flax for LLM pretraining via data and tensor parallelism. It is originally inspired by this [Keras miniGPT tutorial](https://keras.io/examples/generative/text_generation_with_miniature_gpt/).
 
2 changes: 2 additions & 0 deletions docs/source/JAX_for_PyTorch_users.ipynb
@@ -8,6 +8,8 @@
 "source": [
 "# JAX for PyTorch users\n",
 "\n",
+"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/JAX_for_PyTorch_users.ipynb)\n",
+"\n",
 "This is a quick overview of JAX and the JAX AI stack written for those who are familiar with PyTorch.\n",
 "\n",
 "First, we cover how to manipulate JAX Arrays following the well-known [PyTorch tensors tutorial](https://pytorch.org/tutorials/beginner/basics/tensorqs_tutorial.html). Next, we explore automatic differentiation with JAX, followed by how to build a model and optimize its parameters.\n",
2 changes: 2 additions & 0 deletions docs/source/JAX_for_PyTorch_users.md
@@ -16,6 +16,8 @@ kernelspec:
 
 # JAX for PyTorch users
 
+[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/JAX_for_PyTorch_users.ipynb)
+
 This is a quick overview of JAX and the JAX AI stack written for those who are familiar with PyTorch.
 
 First, we cover how to manipulate JAX Arrays following the well-known [PyTorch tensors tutorial](https://pytorch.org/tutorials/beginner/basics/tensorqs_tutorial.html). Next, we explore automatic differentiation with JAX, followed by how to build a model and optimize its parameters.
2 changes: 2 additions & 0 deletions docs/source/JAX_image_captioning.ipynb
@@ -7,6 +7,8 @@
 "source": [
 "# Image Captioning with JAX & FLAX\n",
 "\n",
+"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/JAX_image_captioning.ipynb)\n",
+"\n",
 "In this tutorial we implement from scratch and train a transformer-based model on the image captioning task. This task consists of generating a caption text for the input image. We train the model on the [Flickr8k](http://hockenmaier.cs.illinois.edu/Framing_Image_Description/KCCA.html) dataset and briefly test the trained model on a few test images. This tutorial is inspired by [\"Image Captioning with Keras\"](https://keras.io/examples/vision/image_captioning/)."
 ]
 },
2 changes: 2 additions & 0 deletions docs/source/JAX_image_captioning.md
@@ -14,6 +14,8 @@ kernelspec:
 
 # Image Captioning with JAX & FLAX
 
+[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/JAX_image_captioning.ipynb)
+
 In this tutorial we implement from scratch and train a transformer-based model on the image captioning task. This task consists of generating a caption text for the input image. We train the model on the [Flickr8k](http://hockenmaier.cs.illinois.edu/Framing_Image_Description/KCCA.html) dataset and briefly test the trained model on a few test images. This tutorial is inspired by ["Image Captioning with Keras"](https://keras.io/examples/vision/image_captioning/).
 
 +++
4 changes: 3 additions & 1 deletion docs/source/JAX_machine_translation.ipynb
@@ -5,7 +5,9 @@
 "id": "ee3e1116-f6cd-497e-b617-1d89d5d1f744",
 "metadata": {},
 "source": [
-"# NLP: JAX Machine Translation"
+"# NLP: JAX Machine Translation\n",
+"\n",
+"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/JAX_machine_translation.ipynb)"
 ]
 },
 {
2 changes: 2 additions & 0 deletions docs/source/JAX_machine_translation.md
@@ -14,6 +14,8 @@ kernelspec:
 
 # NLP: JAX Machine Translation
 
+[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/JAX_machine_translation.ipynb)
+
 +++
 
 Adapted from https://keras.io/examples/nlp/neural_machine_translation_with_transformer/, which is itself an adaptation from https://www.manning.com/books/deep-learning-with-python-second-edition
1 change: 1 addition & 0 deletions docs/source/JAX_porting_PyTorch_model.ipynb
@@ -7,6 +7,7 @@
 "source": [
 "# Porting a PyTorch model to JAX\n",
 "\n",
+"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/JAX_porting_PyTorch_model.ipynb)\n",
 "\n",
 "In this tutorial we will learn how to port a PyTorch model to JAX and [Flax](https://flax.readthedocs.io/en/latest/nnx_basics.html). Flax provides an API very similar to the PyTorch `torch.nn` module, so porting PyTorch models is rather straightforward. To install Flax, we can simply execute the following command: `pip install -U flax treescope`.\n",
 "\n",
1 change: 1 addition & 0 deletions docs/source/JAX_porting_PyTorch_model.md
@@ -14,6 +14,7 @@ kernelspec:
 
 # Porting a PyTorch model to JAX
 
+[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/JAX_porting_PyTorch_model.ipynb)
 
 In this tutorial we will learn how to port a PyTorch model to JAX and [Flax](https://flax.readthedocs.io/en/latest/nnx_basics.html). Flax provides an API very similar to the PyTorch `torch.nn` module, so porting PyTorch models is rather straightforward. To install Flax, we can simply execute the following command: `pip install -U flax treescope`.
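To give a feel for that similarity, here is a minimal sketch (not taken from the tutorial; the module name and layer sizes are illustrative assumptions) of how a small `torch.nn`-style block maps onto Flax NNX, with `__call__` playing the role of `forward`:

```python
from flax import nnx


class MLP(nnx.Module):
    """Roughly the NNX analogue of a two-layer torch.nn.Module."""

    def __init__(self, din: int, dhidden: int, dout: int, *, rngs: nnx.Rngs):
        self.fc1 = nnx.Linear(din, dhidden, rngs=rngs)  # like torch.nn.Linear
        self.fc2 = nnx.Linear(dhidden, dout, rngs=rngs)

    def __call__(self, x):  # plays the role of forward()
        return self.fc2(nnx.relu(self.fc1(x)))


model = MLP(din=784, dhidden=256, dout=10, rngs=nnx.Rngs(0))
```

Unlike `torch.nn`, parameter initialization is explicit through the `rngs` argument rather than driven by global random state, which is the main habit to adjust when porting.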

6 changes: 4 additions & 2 deletions docs/source/JAX_time_series_classification.ipynb
@@ -6,6 +6,8 @@
 "source": [
 "# Time series classification with JAX\n",
 "\n",
+"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/JAX_time_series_classification.ipynb)\n",
+"\n",
 "In this tutorial, we're going to perform time series classification with a Convolutional Neural Network.\n",
 "We will use the FordA dataset from the [UCR archive](https://www.cs.ucr.edu/%7Eeamonn/time_series_data_2018/),\n",
 "which contains measurements of engine noise captured by a motor sensor.\n",
@@ -154,7 +156,7 @@
 "\n",
 "Grain follows the source-sampler-loader paradigm. Grain supports custom setups where data sources\n",
 "might come in different forms, but they all need to implement the `grain.RandomAccessDataSource`\n",
-"interface. See [PyGrain Data Sources](https://github.com/google/grain/blob/main/docs/data_sources.md)\n",
+"interface. See [PyGrain Data Sources](https://github.com/google/grain/blob/main/docs/source/data_sources.md)\n",
 "for more details.\n",
 "\n",
 "Our dataset is comprised of relatively small NumPy arrays so our `DataSource` is uncomplicated:"
@@ -193,7 +195,7 @@
 "metadata": {},
 "source": [
 "Samplers determine the order in which records are processed, and we'll use the\n",
-"[`IndexSampler`](https://github.com/google/grain/blob/main/docs/data_loader/samplers.md#index-sampler)\n",
+"[`IndexSampler`](https://github.com/google/grain/blob/main/docs/source/data_loader/samplers.md#index-sampler)\n",
 "recommended by Grain.\n",
 "\n",
 "Finally, we'll create `DataLoader`s that handle orchestration of loading.\n",
6 changes: 4 additions & 2 deletions docs/source/JAX_time_series_classification.md
@@ -14,6 +14,8 @@ kernelspec:
 
 # Time series classification with JAX
 
+[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/JAX_time_series_classification.ipynb)
+
 In this tutorial, we're going to perform time series classification with a Convolutional Neural Network.
 We will use the FordA dataset from the [UCR archive](https://www.cs.ucr.edu/%7Eeamonn/time_series_data_2018/),
 which contains measurements of engine noise captured by a motor sensor.
@@ -108,7 +110,7 @@ Flax models.
 
 Grain follows the source-sampler-loader paradigm. Grain supports custom setups where data sources
 might come in different forms, but they all need to implement the `grain.RandomAccessDataSource`
-interface. See [PyGrain Data Sources](https://github.com/google/grain/blob/main/docs/data_sources.md)
+interface. See [PyGrain Data Sources](https://github.com/google/grain/blob/main/docs/source/data_sources.md)
 for more details.
 
 Our dataset is comprised of relatively small NumPy arrays so our `DataSource` is uncomplicated:
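The class body itself is elided from this diff; as a hedged sketch consistent with the visible context (the record field names here are assumptions), a Grain-compatible source over in-memory NumPy arrays could look like:

```python
import numpy as np
import grain.python as grain


class DataSource(grain.RandomAccessDataSource):
    """Wraps in-memory NumPy arrays so Grain can fetch records by index."""

    def __init__(self, x: np.ndarray, y: np.ndarray):
        self._x = x
        self._y = y

    def __len__(self) -> int:
        return len(self._x)

    def __getitem__(self, idx):
        # Grain calls this with integer record keys.
        return {"measurement": self._x[idx], "label": self._y[idx]}
```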
@@ -132,7 +134,7 @@ test_source = DataSource(x_test, y_test)
 ```
 
 Samplers determine the order in which records are processed, and we'll use the
-[`IndexSampler`](https://github.com/google/grain/blob/main/docs/data_loader/samplers.md#index-sampler)
+[`IndexSampler`](https://github.com/google/grain/blob/main/docs/source/data_loader/samplers.md#index-sampler)
 recommended by Grain.
 
 Finally, we'll create `DataLoader`s that handle orchestration of loading.
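The concrete settings are elided from this diff; a hedged sketch of that sampler-plus-loader wiring (the batch size, seed, and epoch count are illustrative assumptions) could look like:

```python
import grain.python as grain

# x_train / y_train come from the dataset-loading cells elided above.
train_source = DataSource(x_train, y_train)

train_sampler = grain.IndexSampler(
    num_records=len(train_source),
    shard_options=grain.NoSharding(),  # single-process training
    shuffle=True,
    num_epochs=1,
    seed=12,
)

train_loader = grain.DataLoader(
    data_source=train_source,
    sampler=train_sampler,
    operations=[grain.Batch(batch_size=128, drop_remainder=True)],
    worker_count=0,  # load in the main process; increase for parallel workers
)
```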
4 changes: 3 additions & 1 deletion docs/source/JAX_transformer_text_classification.ipynb
@@ -6,6 +6,8 @@
 "source": [
 "# Text classification with transformer model\n",
 "\n",
+"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/JAX_transformer_text_classification.ipynb)\n",
+"\n",
 "In this short tutorial, we will perform sentiment analysis on movie reviews. We would like\n",
 "to know if a review is overall positive or negative, so we're facing a classification task.\n",
 "\n",
@@ -162,7 +164,7 @@
 "For handling input data we're going to use Grain, a pure Python package developed for JAX and\n",
 "Flax models. Grain supports custom setups where data sources might come in different forms, but\n",
 "they all need to implement the `grain.RandomAccessDataSource` interface. See\n",
-"[PyGrain Data Sources](https://github.com/google/grain/blob/main/docs/data_sources.md)\n",
+"[PyGrain Data Sources](https://github.com/google/grain/blob/main/docs/source/data_sources.md)\n",
 "for more details.\n",
 "\n",
 "Our dataset is comprised of relatively small NumPy arrays so our `DataSource` is uncomplicated:"
4 changes: 3 additions & 1 deletion docs/source/JAX_transformer_text_classification.md
@@ -14,6 +14,8 @@ kernelspec:
 
 # Text classification with transformer model
 
+[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/JAX_transformer_text_classification.ipynb)
+
 In this short tutorial, we will perform sentiment analysis on movie reviews. We would like
 to know if a review is overall positive or negative, so we're facing a classification task.
 
@@ -126,7 +128,7 @@ x_test = pad_sequences(x_test, max_len=maxlen)
 For handling input data we're going to use Grain, a pure Python package developed for JAX and
 Flax models. Grain supports custom setups where data sources might come in different forms, but
 they all need to implement the `grain.RandomAccessDataSource` interface. See
-[PyGrain Data Sources](https://github.com/google/grain/blob/main/docs/data_sources.md)
+[PyGrain Data Sources](https://github.com/google/grain/blob/main/docs/source/data_sources.md)
 for more details.
 
 Our dataset is comprised of relatively small NumPy arrays so our `DataSource` is uncomplicated:
6 changes: 4 additions & 2 deletions docs/source/JAX_visualizing_models_metrics.ipynb
@@ -4,14 +4,16 @@
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"# JAX and Tensorboard / NNX Display"
+"# JAX and Tensorboard / NNX Display\n",
+"\n",
+"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/JAX_visualizing_models_metrics.ipynb)"
 ]
 },
 {
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"To keep things straightforward and familiar, we reuse the model and data from '[Getting started with JAX for AI](https://jax-ai-stack.readthedocs.io/en/latest/getting_started_with_jax_for_AI.html)' - if you haven't read that yet and want the primer, start there before returning.\n",
+"To keep things straightforward and familiar, we reuse the model and data from [Getting started with JAX for AI](https://jax-ai-stack.readthedocs.io/en/latest/getting_started_with_jax_for_AI.html) - if you haven't read that yet and want the primer, start there before returning.\n",
 "\n",
 "All of the modeling and training code is the same here. What we have added are the TensorBoard connections and the discussion around them."
 ]
4 changes: 3 additions & 1 deletion docs/source/JAX_visualizing_models_metrics.md
@@ -14,9 +14,11 @@ kernelspec:
 
 # JAX and Tensorboard / NNX Display
 
+[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/JAX_visualizing_models_metrics.ipynb)
+
 +++
 
-To keep things straightforward and familiar, we reuse the model and data from '[Getting started with JAX for AI](https://jax-ai-stack.readthedocs.io/en/latest/getting_started_with_jax_for_AI.html)' - if you haven't read that yet and want the primer, start there before returning.
+To keep things straightforward and familiar, we reuse the model and data from [Getting started with JAX for AI](https://jax-ai-stack.readthedocs.io/en/latest/getting_started_with_jax_for_AI.html) - if you haven't read that yet and want the primer, start there before returning.
 
 All of the modeling and training code is the same here. What we have added are the TensorBoard connections and the discussion around them.
 