From 8007bb8fa7999adc4705c62820629ca30622f7ca Mon Sep 17 00:00:00 2001 From: Pavithra Eswaramoorthy Date: Fri, 20 Dec 2024 19:47:24 +0100 Subject: [PATCH 1/6] :truck: Split into tutorials and example applications Signed-off-by: Pavithra Eswaramoorthy --- docs/source/data_loaders.md | 10 ++++++++++ docs/source/examples.md | 16 ++++++++++++++++ docs/source/index.rst | 1 + docs/source/tutorials.md | 25 +++++++------------------ 4 files changed, 34 insertions(+), 18 deletions(-) create mode 100644 docs/source/data_loaders.md create mode 100644 docs/source/examples.md diff --git a/docs/source/data_loaders.md b/docs/source/data_loaders.md new file mode 100644 index 0000000..832d1cd --- /dev/null +++ b/docs/source/data_loaders.md @@ -0,0 +1,10 @@ +# Introduction to Data Loaders + +Learn about data loading strategies on different hardware systems: + +```{toctree} +:maxdepth: 1 + +data_loaders_on_cpu_with_jax +data_loaders_on_gpu_with_jax +``` diff --git a/docs/source/examples.md b/docs/source/examples.md new file mode 100644 index 0000000..00a18cc --- /dev/null +++ b/docs/source/examples.md @@ -0,0 +1,16 @@ +# Example applications + +The following pages provide examples of common applications of the JAX AI stack: + +```{toctree} +:maxdepth: 1 + +JAX_for_LLM_pretraining +JAX_basic_text_classification +JAX_transformer_text_classification +JAX_machine_translation +JAX_examples_image_segmentation +JAX_image_captioning +JAX_Vision_transformer +JAX_time_series_classification +``` diff --git a/docs/source/index.rst b/docs/source/index.rst index 5f09cf7..2ab31e5 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -10,4 +10,5 @@ Jax AI Stack install tutorials + examples contributing diff --git a/docs/source/tutorials.md b/docs/source/tutorials.md index 23cdd58..45106bc 100644 --- a/docs/source/tutorials.md +++ b/docs/source/tutorials.md @@ -1,32 +1,21 @@ # Tutorials -```{note} -This is a work in progress; visit again soon for updated content! -``` - -The following tutorials are meant as an intro to the full stack: +The following tutorials are meant as an introduction to the full stack: ```{toctree} -:maxdepth: 2 +:maxdepth: 1 getting_started_with_jax_for_AI digits_vae -JAX_for_PyTorch_users -JAX_porting_PyTorch_model digits_diffusion_model -JAX_for_LLM_pretraining -JAX_basic_text_classification -JAX_examples_image_segmentation -JAX_Vision_transformer -JAX_machine_translation JAX_visualizing_models_metrics -JAX_image_captioning -JAX_time_series_classification -JAX_transformer_text_classification -data_loaders_on_cpu_with_jax -data_loaders_on_gpu_with_jax +data_loaders +JAX_for_PyTorch_users +JAX_porting_PyTorch_model ``` +## Further references + Once you've gone through this content, you can refer to package-specific documentation for resources that go into more depth on various topics: From 6360732706da76a257119780b93c408d8283aebd Mon Sep 17 00:00:00 2001 From: Pavithra Eswaramoorthy Date: Fri, 20 Dec 2024 19:59:04 +0100 Subject: [PATCH 2/6] :memo: Update contributing guide to point to docs/source/* Signed-off-by: Pavithra Eswaramoorthy --- docs/source/contributing.md | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/docs/source/contributing.md b/docs/source/contributing.md index b7affa5..0a30d7d 100644 --- a/docs/source/contributing.md +++ b/docs/source/contributing.md @@ -1,6 +1,4 @@ -# Developer docs - -## Contributing to the JAX AI Stack Documentation +# Contribute to documentation The documentation in the `jax-ai-stack` repository is meant to build on documentation of individual packages, and specifically cover topics that touch on multiple packages @@ -38,24 +36,25 @@ To add a new notebook to the repository, first move the notebook into the approp location in the `docs` directory: ```bash -mv ~/new-tutorial.ipynb docs/new_tutorial.ipynb +mv ~/new-tutorial.ipynb docs/source/new_tutorial.ipynb ``` Next, we use `jupytext` to mark the notebook for syncing with Markdown: ```bash -jupytext --set-formats ipynb,md:myst docs/new_tutorial.ipynb +jupytext --set-formats ipynb,md:myst docs/source/new_tutorial.ipynb ``` Finally, we can sync the notebook and markdown source: ```bash -jupytext --sync docs/new_tutorial.ipynb +jupytext --sync docs/source/new_tutorial.ipynb ``` To ensure that the new notebook is rendered as part of the site, be sure to add references to a `toctree` declaration somewhere in the source tree, for example -in `docs/tutorials.md`. You will also need to add references in `docs/conf.py` +in `docs/source/tutorials.md` or `docs/source/examples.md`. +You will also need to add references in `docs/conf.py` to specify whether the notebook should be executed, and to specify which file sphinx should use when generating the site. @@ -70,9 +69,9 @@ you can do the following: ```bash pip install pre-commit -git add docs/new_tutorial.* # stage the new changes +git add docs/source/new_tutorial.* # stage the new changes pre-commit run # run pre-commit checks on added files -git add docs/new_tutorial.* # stage the files updated by pre-commit +git add docs/source/new_tutorial.* # stage the files updated by pre-commit git commit -m "update new tutorial" # commit to the branch ``` From c8fa1c48917bf29bcde7d2d719f0fc250076c438 Mon Sep 17 00:00:00 2001 From: Pavithra Eswaramoorthy Date: Fri, 20 Dec 2024 19:59:39 +0100 Subject: [PATCH 3/6] :memo: Update titles (and some text) for consitenncy & flow Signed-off-by: Pavithra Eswaramoorthy --- docs/source/JAX_Vision_transformer.ipynb | 2 +- docs/source/JAX_Vision_transformer.md | 4 ++-- docs/source/JAX_basic_text_classification.ipynb | 4 ++-- docs/source/JAX_basic_text_classification.md | 6 +++--- docs/source/JAX_examples_image_segmentation.ipynb | 2 +- docs/source/JAX_examples_image_segmentation.md | 4 ++-- docs/source/JAX_for_LLM_pretraining.ipynb | 2 +- docs/source/JAX_for_LLM_pretraining.md | 4 ++-- docs/source/JAX_image_captioning.ipynb | 2 +- docs/source/JAX_image_captioning.md | 4 ++-- docs/source/JAX_machine_translation.ipynb | 4 ++-- docs/source/JAX_machine_translation.md | 6 +++--- docs/source/JAX_time_series_classification.ipynb | 2 +- docs/source/JAX_time_series_classification.md | 4 ++-- docs/source/JAX_visualizing_models_metrics.ipynb | 2 +- docs/source/JAX_visualizing_models_metrics.md | 4 ++-- docs/source/digits_diffusion_model.ipynb | 6 +++--- docs/source/digits_diffusion_model.md | 8 ++++---- 18 files changed, 35 insertions(+), 35 deletions(-) diff --git a/docs/source/JAX_Vision_transformer.ipynb b/docs/source/JAX_Vision_transformer.ipynb index dcd766f..2fc4e9f 100644 --- a/docs/source/JAX_Vision_transformer.ipynb +++ b/docs/source/JAX_Vision_transformer.ipynb @@ -5,7 +5,7 @@ "id": "8744b685-7ff5-429a-b610-940506455a54", "metadata": {}, "source": [ - "# Vision Transformer with JAX & FLAX\n", + "# Implement Vision Transformer (ViT) model from scratch\n", "\n", "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/JAX_Vision_transformer.ipynb)\n", "\n", diff --git a/docs/source/JAX_Vision_transformer.md b/docs/source/JAX_Vision_transformer.md index 6a507a9..22dd4c7 100644 --- a/docs/source/JAX_Vision_transformer.md +++ b/docs/source/JAX_Vision_transformer.md @@ -5,14 +5,14 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.15.2 + jupytext_version: 1.16.4 kernelspec: display_name: Python 3 (ipykernel) language: python name: python3 --- -# Vision Transformer with JAX & FLAX +# Implement Vision Transformer (ViT) model from scratch [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/JAX_Vision_transformer.ipynb) diff --git a/docs/source/JAX_basic_text_classification.ipynb b/docs/source/JAX_basic_text_classification.ipynb index 1763b9b..72f2384 100644 --- a/docs/source/JAX_basic_text_classification.ipynb +++ b/docs/source/JAX_basic_text_classification.ipynb @@ -5,11 +5,11 @@ "id": "072f8ce2-7014-4f04-83a6-96953e9c8a79", "metadata": {}, "source": [ - "# 1D Convnet for basic text classification\n", + "# Basic text classification with 1D CNN\n", "\n", "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/JAX_basic_text_classification.ipynb)\n", "\n", - "In this tutorial we learn how to perform text classification from raw text data and train a basic 1D Convnet to perform sentiment analysis using JAX. This tutorial is originally inspired by [\"Text classification from scratch with Keras\"](https://keras.io/examples/nlp/text_classification_from_scratch/#build-a-model).\n", + "In this tutorial we learn how to perform text classification from raw text data and train a basic 1D Convolutional Neural Network to perform sentiment analysis using JAX. This tutorial is originally inspired by [\"Text classification from scratch with Keras\"](https://keras.io/examples/nlp/text_classification_from_scratch/#build-a-model).\n", "\n", "We will use the IMDB movie review dataset to classify the review to \"positive\" and \"negative\" classes. We implement from scratch a simple model using Flax, train it and compute metrics on the test set." ] diff --git a/docs/source/JAX_basic_text_classification.md b/docs/source/JAX_basic_text_classification.md index 5cddc18..07d3e3a 100644 --- a/docs/source/JAX_basic_text_classification.md +++ b/docs/source/JAX_basic_text_classification.md @@ -5,18 +5,18 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.15.2 + jupytext_version: 1.16.4 kernelspec: display_name: Python 3 (ipykernel) language: python name: python3 --- -# 1D Convnet for basic text classification +# Basic text classification with 1D CNN [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/JAX_basic_text_classification.ipynb) -In this tutorial we learn how to perform text classification from raw text data and train a basic 1D Convnet to perform sentiment analysis using JAX. This tutorial is originally inspired by ["Text classification from scratch with Keras"](https://keras.io/examples/nlp/text_classification_from_scratch/#build-a-model). +In this tutorial we learn how to perform text classification from raw text data and train a basic 1D Convolutional Neural Network to perform sentiment analysis using JAX. This tutorial is originally inspired by ["Text classification from scratch with Keras"](https://keras.io/examples/nlp/text_classification_from_scratch/#build-a-model). We will use the IMDB movie review dataset to classify the review to "positive" and "negative" classes. We implement from scratch a simple model using Flax, train it and compute metrics on the test set. diff --git a/docs/source/JAX_examples_image_segmentation.ipynb b/docs/source/JAX_examples_image_segmentation.ipynb index b26242b..c1a34db 100644 --- a/docs/source/JAX_examples_image_segmentation.ipynb +++ b/docs/source/JAX_examples_image_segmentation.ipynb @@ -5,7 +5,7 @@ "id": "343f53e8-f28e-4fed-a0eb-c5c76c73d5a7", "metadata": {}, "source": [ - "# UNETR for image segmentation\n", + "# Image segmentation with UNETR model\n", "\n", "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/JAX_examples_image_segmentation.ipynb)\n", "\n", diff --git a/docs/source/JAX_examples_image_segmentation.md b/docs/source/JAX_examples_image_segmentation.md index 7640b59..21a9c17 100644 --- a/docs/source/JAX_examples_image_segmentation.md +++ b/docs/source/JAX_examples_image_segmentation.md @@ -5,14 +5,14 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.15.2 + jupytext_version: 1.16.4 kernelspec: display_name: Python 3 (ipykernel) language: python name: python3 --- -# UNETR for image segmentation +# Image segmentation with UNETR model [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/JAX_examples_image_segmentation.ipynb) diff --git a/docs/source/JAX_for_LLM_pretraining.ipynb b/docs/source/JAX_for_LLM_pretraining.ipynb index b8228cb..aa33340 100644 --- a/docs/source/JAX_for_LLM_pretraining.ipynb +++ b/docs/source/JAX_for_LLM_pretraining.ipynb @@ -6,7 +6,7 @@ "id": "NIOXoY1xgiww" }, "source": [ - "# Pretraining an LLM using JAX\n", + "# Pre-training an LLM (miniGPT)\n", "\n", "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/JAX_for_LLM_pretraining.ipynb)\n", "\n", diff --git a/docs/source/JAX_for_LLM_pretraining.md b/docs/source/JAX_for_LLM_pretraining.md index 4f514ef..c674b22 100644 --- a/docs/source/JAX_for_LLM_pretraining.md +++ b/docs/source/JAX_for_LLM_pretraining.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.15.2 + jupytext_version: 1.16.4 kernelspec: display_name: Python 3 name: python3 @@ -13,7 +13,7 @@ kernelspec: +++ {"id": "NIOXoY1xgiww"} -# Pretraining an LLM using JAX +# Pre-training an LLM (miniGPT) [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/JAX_for_LLM_pretraining.ipynb) diff --git a/docs/source/JAX_image_captioning.ipynb b/docs/source/JAX_image_captioning.ipynb index e37b65e..1f85d2f 100644 --- a/docs/source/JAX_image_captioning.ipynb +++ b/docs/source/JAX_image_captioning.ipynb @@ -5,7 +5,7 @@ "id": "03487afe-bbca-420c-9b1f-28ea4506c250", "metadata": {}, "source": [ - "# Image Captioning with JAX & FLAX\n", + "# Image Captioning with Vision Transformer (ViT) model\n", "\n", "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/JAX_image_captioning.ipynb)\n", "\n", diff --git a/docs/source/JAX_image_captioning.md b/docs/source/JAX_image_captioning.md index 566f5d8..2083525 100644 --- a/docs/source/JAX_image_captioning.md +++ b/docs/source/JAX_image_captioning.md @@ -5,14 +5,14 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.15.2 + jupytext_version: 1.16.4 kernelspec: display_name: Python 3 (ipykernel) language: python name: python3 --- -# Image Captioning with JAX & FLAX +# Image Captioning with Vision Transformer (ViT) model [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/JAX_image_captioning.ipynb) diff --git a/docs/source/JAX_machine_translation.ipynb b/docs/source/JAX_machine_translation.ipynb index f588445..99394c6 100644 --- a/docs/source/JAX_machine_translation.ipynb +++ b/docs/source/JAX_machine_translation.ipynb @@ -5,7 +5,7 @@ "id": "ee3e1116-f6cd-497e-b617-1d89d5d1f744", "metadata": {}, "source": [ - "# NLP: JAX Machine Translation\n", + "# Machine Translation with encoder-decoder transformer model\n", "\n", "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/JAX_machine_translation.ipynb)" ] @@ -15,7 +15,7 @@ "id": "50f0bd58-dcc6-41f4-9dc4-3a08c8ef751b", "metadata": {}, "source": [ - "Adapted from https://keras.io/examples/nlp/neural_machine_translation_with_transformer/, which is itself an adaptation from https://www.manning.com/books/deep-learning-with-python-second-edition\n", + "This tutorial is adapted from [Keras' documentation on English-to-Spanish translation with a sequence-to-sequence Transformer](https://keras.io/examples/nlp/neural_machine_translation_with_transformer/), which is itself an adaptation from the book [Deep Learning with Python, Second Edition by François Chollet](https://www.manning.com/books/deep-learning-with-python-second-edition)\n", "\n", "We step through an encoder-decoder transformer in JAX and train a model for English->Spanish translation." ] diff --git a/docs/source/JAX_machine_translation.md b/docs/source/JAX_machine_translation.md index 1989bec..f5ec231 100644 --- a/docs/source/JAX_machine_translation.md +++ b/docs/source/JAX_machine_translation.md @@ -5,20 +5,20 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.15.2 + jupytext_version: 1.16.4 kernelspec: display_name: Python 3 (ipykernel) language: python name: python3 --- -# NLP: JAX Machine Translation +# Machine Translation with encoder-decoder transformer model [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/JAX_machine_translation.ipynb) +++ -Adapted from https://keras.io/examples/nlp/neural_machine_translation_with_transformer/, which is itself an adaptation from https://www.manning.com/books/deep-learning-with-python-second-edition +This tutorial is adapted from [Keras' documentation on English-to-Spanish translation with a sequence-to-sequence Transformer](https://keras.io/examples/nlp/neural_machine_translation_with_transformer/), which is itself an adaptation from the book [Deep Learning with Python, Second Edition by François Chollet](https://www.manning.com/books/deep-learning-with-python-second-edition) We step through an encoder-decoder transformer in JAX and train a model for English->Spanish translation. diff --git a/docs/source/JAX_time_series_classification.ipynb b/docs/source/JAX_time_series_classification.ipynb index d9b32eb..be8a69c 100644 --- a/docs/source/JAX_time_series_classification.ipynb +++ b/docs/source/JAX_time_series_classification.ipynb @@ -4,7 +4,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# Time series classification with JAX\n", + "# Time series classification with CNN\n", "\n", "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/JAX_time_series_classification.ipynb)\n", "\n", diff --git a/docs/source/JAX_time_series_classification.md b/docs/source/JAX_time_series_classification.md index 82a542c..069b8ce 100644 --- a/docs/source/JAX_time_series_classification.md +++ b/docs/source/JAX_time_series_classification.md @@ -5,14 +5,14 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.15.2 + jupytext_version: 1.16.4 kernelspec: display_name: jax-env language: python name: python3 --- -# Time series classification with JAX +# Time series classification with CNN [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/JAX_time_series_classification.ipynb) diff --git a/docs/source/JAX_visualizing_models_metrics.ipynb b/docs/source/JAX_visualizing_models_metrics.ipynb index 4975b9c..6ece7e0 100644 --- a/docs/source/JAX_visualizing_models_metrics.ipynb +++ b/docs/source/JAX_visualizing_models_metrics.ipynb @@ -4,7 +4,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# JAX and Tensorboard / NNX Display\n", + "# Visualize JAX model metrics with TensorBoard\n", "\n", "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/JAX_visualizing_models_metrics.ipynb)" ] diff --git a/docs/source/JAX_visualizing_models_metrics.md b/docs/source/JAX_visualizing_models_metrics.md index 1910192..2abd111 100644 --- a/docs/source/JAX_visualizing_models_metrics.md +++ b/docs/source/JAX_visualizing_models_metrics.md @@ -5,14 +5,14 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.15.2 + jupytext_version: 1.16.4 kernelspec: display_name: Python 3 (ipykernel) language: python name: python3 --- -# JAX and Tensorboard / NNX Display +# Visualize JAX model metrics with TensorBoard [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/JAX_visualizing_models_metrics.ipynb) diff --git a/docs/source/digits_diffusion_model.ipynb b/docs/source/digits_diffusion_model.ipynb index 09a5203..6472765 100644 --- a/docs/source/digits_diffusion_model.ipynb +++ b/docs/source/digits_diffusion_model.ipynb @@ -6,12 +6,12 @@ "id": "Kzqlx7fpXRnJ" }, "source": [ - "# Image generation in JAX: a simple diffusion model\n", + "# Simple diffusion model for image generation in JAX\n", "\n", "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/digits_diffusion_model.ipynb)\n", "\n", - "In [Debugging in JAX: a Variational autoencoder (VAE) model](digits_vae.ipynb) you explored a simplified version of a [Variational Autoencoder (VAE)](https://en.wikipedia.org/wiki/Variational_autoencoder) trained on the simple digits data. In this tutorial you will find the steps to develop, train and perform inferences with a simple diffusion model developed with JAX, Flax, NNX and Optax. It includes:\n", - "- preparing the dataset\n", + "In [Variational autoencoder (VAE) and debugging in JAX](digits_vae.ipynb) you explored a simplified version of a [Variational Autoencoder (VAE)](https://en.wikipedia.org/wiki/Variational_autoencoder) trained on the simple digits data. In this tutorial, you will find the steps to develop, train and perform inferences with a simple diffusion model developed with JAX, Flax, NNX and Optax. It includes:\n", + "- Preparing the dataset\n", "- Developing the custom diffusion model\n", "- Creating the loss and training functions\n", "- Perform the model training using Colab TPU v2 as a hardware accelerator\n", diff --git a/docs/source/digits_diffusion_model.md b/docs/source/digits_diffusion_model.md index 212c6df..3fb1c91 100644 --- a/docs/source/digits_diffusion_model.md +++ b/docs/source/digits_diffusion_model.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.15.2 + jupytext_version: 1.16.4 kernelspec: display_name: Python 3 name: python3 @@ -13,12 +13,12 @@ kernelspec: +++ {"id": "Kzqlx7fpXRnJ"} -# Image generation in JAX: a simple diffusion model +# Simple diffusion model for image generation in JAX [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/digits_diffusion_model.ipynb) -In [Debugging in JAX: a Variational autoencoder (VAE) model](digits_vae.ipynb) you explored a simplified version of a [Variational Autoencoder (VAE)](https://en.wikipedia.org/wiki/Variational_autoencoder) trained on the simple digits data. In this tutorial you will find the steps to develop, train and perform inferences with a simple diffusion model developed with JAX, Flax, NNX and Optax. It includes: -- preparing the dataset +In [Variational autoencoder (VAE) and debugging in JAX](digits_vae.ipynb) you explored a simplified version of a [Variational Autoencoder (VAE)](https://en.wikipedia.org/wiki/Variational_autoencoder) trained on the simple digits data. In this tutorial, you will find the steps to develop, train and perform inferences with a simple diffusion model developed with JAX, Flax, NNX and Optax. It includes: +- Preparing the dataset - Developing the custom diffusion model - Creating the loss and training functions - Perform the model training using Colab TPU v2 as a hardware accelerator From d87bd1525459c241034bb291eedd91b33d976b50 Mon Sep 17 00:00:00 2001 From: Pavithra Eswaramoorthy Date: Fri, 20 Dec 2024 23:15:53 +0100 Subject: [PATCH 4/6] :rewind: Revert accidental jupytext version change Signed-off-by: Pavithra Eswaramoorthy --- docs/source/JAX_Vision_transformer.md | 2 +- docs/source/JAX_basic_text_classification.md | 2 +- docs/source/JAX_examples_image_segmentation.md | 2 +- docs/source/JAX_for_LLM_pretraining.md | 2 +- docs/source/JAX_image_captioning.md | 2 +- docs/source/JAX_machine_translation.md | 2 +- docs/source/JAX_time_series_classification.md | 2 +- docs/source/JAX_visualizing_models_metrics.md | 2 +- docs/source/digits_diffusion_model.md | 2 +- 9 files changed, 9 insertions(+), 9 deletions(-) diff --git a/docs/source/JAX_Vision_transformer.md b/docs/source/JAX_Vision_transformer.md index 22dd4c7..eae6a68 100644 --- a/docs/source/JAX_Vision_transformer.md +++ b/docs/source/JAX_Vision_transformer.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.4 + jupytext_version: 1.15.2 kernelspec: display_name: Python 3 (ipykernel) language: python diff --git a/docs/source/JAX_basic_text_classification.md b/docs/source/JAX_basic_text_classification.md index 07d3e3a..9fa39a3 100644 --- a/docs/source/JAX_basic_text_classification.md +++ b/docs/source/JAX_basic_text_classification.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.4 + jupytext_version: 1.15.2 kernelspec: display_name: Python 3 (ipykernel) language: python diff --git a/docs/source/JAX_examples_image_segmentation.md b/docs/source/JAX_examples_image_segmentation.md index 21a9c17..c5268db 100644 --- a/docs/source/JAX_examples_image_segmentation.md +++ b/docs/source/JAX_examples_image_segmentation.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.4 + jupytext_version: 1.15.2 kernelspec: display_name: Python 3 (ipykernel) language: python diff --git a/docs/source/JAX_for_LLM_pretraining.md b/docs/source/JAX_for_LLM_pretraining.md index c674b22..75ad425 100644 --- a/docs/source/JAX_for_LLM_pretraining.md +++ b/docs/source/JAX_for_LLM_pretraining.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.4 + jupytext_version: 1.15.2 kernelspec: display_name: Python 3 name: python3 diff --git a/docs/source/JAX_image_captioning.md b/docs/source/JAX_image_captioning.md index 2083525..62b1a32 100644 --- a/docs/source/JAX_image_captioning.md +++ b/docs/source/JAX_image_captioning.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.4 + jupytext_version: 1.15.2 kernelspec: display_name: Python 3 (ipykernel) language: python diff --git a/docs/source/JAX_machine_translation.md b/docs/source/JAX_machine_translation.md index f5ec231..16e8777 100644 --- a/docs/source/JAX_machine_translation.md +++ b/docs/source/JAX_machine_translation.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.4 + jupytext_version: 1.15.2 kernelspec: display_name: Python 3 (ipykernel) language: python diff --git a/docs/source/JAX_time_series_classification.md b/docs/source/JAX_time_series_classification.md index 069b8ce..ec7f417 100644 --- a/docs/source/JAX_time_series_classification.md +++ b/docs/source/JAX_time_series_classification.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.4 + jupytext_version: 1.15.2 kernelspec: display_name: jax-env language: python diff --git a/docs/source/JAX_visualizing_models_metrics.md b/docs/source/JAX_visualizing_models_metrics.md index 2abd111..2e27bc3 100644 --- a/docs/source/JAX_visualizing_models_metrics.md +++ b/docs/source/JAX_visualizing_models_metrics.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.4 + jupytext_version: 1.15.2 kernelspec: display_name: Python 3 (ipykernel) language: python diff --git a/docs/source/digits_diffusion_model.md b/docs/source/digits_diffusion_model.md index 3fb1c91..9020702 100644 --- a/docs/source/digits_diffusion_model.md +++ b/docs/source/digits_diffusion_model.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.4 + jupytext_version: 1.15.2 kernelspec: display_name: Python 3 name: python3 From 8d24dd9891c7c05e2aae21aef91440f051032407 Mon Sep 17 00:00:00 2001 From: Pavithra Eswaramoorthy Date: Fri, 20 Dec 2024 23:56:46 +0100 Subject: [PATCH 5/6] :truck: Create section for PyTorch Signed-off-by: Pavithra Eswaramoorthy --- docs/source/pytorch_users.md | 10 ++++++++++ docs/source/tutorials.md | 3 +-- 2 files changed, 11 insertions(+), 2 deletions(-) create mode 100644 docs/source/pytorch_users.md diff --git a/docs/source/pytorch_users.md b/docs/source/pytorch_users.md new file mode 100644 index 0000000..edf469e --- /dev/null +++ b/docs/source/pytorch_users.md @@ -0,0 +1,10 @@ +# From PyTorch to JAX + +The following tutorials provide an onboarding path to JAX, for users who are familiar with PyTorch: + +```{toctree} +:maxdepth: 1 + +JAX_for_PyTorch_users +JAX_porting_PyTorch_model +``` diff --git a/docs/source/tutorials.md b/docs/source/tutorials.md index 45106bc..3e984c9 100644 --- a/docs/source/tutorials.md +++ b/docs/source/tutorials.md @@ -10,8 +10,7 @@ digits_vae digits_diffusion_model JAX_visualizing_models_metrics data_loaders -JAX_for_PyTorch_users -JAX_porting_PyTorch_model +pytorch_users ``` ## Further references From a70c08b6d57ff3b72c55a8c2709c75426cd0fb00 Mon Sep 17 00:00:00 2001 From: Pavithra Eswaramoorthy Date: Fri, 20 Dec 2024 23:57:07 +0100 Subject: [PATCH 6/6] :memo: Fix plain links Signed-off-by: Pavithra Eswaramoorthy --- docs/source/JAX_for_PyTorch_users.ipynb | 2 +- docs/source/JAX_for_PyTorch_users.md | 2 +- docs/source/JAX_porting_PyTorch_model.ipynb | 4 ++-- docs/source/JAX_porting_PyTorch_model.md | 4 ++-- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/docs/source/JAX_for_PyTorch_users.ipynb b/docs/source/JAX_for_PyTorch_users.ipynb index a6585df..adfb32b 100644 --- a/docs/source/JAX_for_PyTorch_users.ipynb +++ b/docs/source/JAX_for_PyTorch_users.ipynb @@ -1554,7 +1554,7 @@ "source": [ "### Further reading\n", "\n", - "- https://jax.readthedocs.io/en/latest/jit-compilation.html" + "- [JAX documentation on Just-in-time compilation](https://jax.readthedocs.io/en/latest/jit-compilation.html)" ] } ], diff --git a/docs/source/JAX_for_PyTorch_users.md b/docs/source/JAX_for_PyTorch_users.md index f0d4964..7e2cef7 100644 --- a/docs/source/JAX_for_PyTorch_users.md +++ b/docs/source/JAX_for_PyTorch_users.md @@ -853,4 +853,4 @@ jit_matmul_relu(x, y) ### Further reading -- https://jax.readthedocs.io/en/latest/jit-compilation.html +- [JAX documentation on Just-in-time compilation](https://jax.readthedocs.io/en/latest/jit-compilation.html) diff --git a/docs/source/JAX_porting_PyTorch_model.ipynb b/docs/source/JAX_porting_PyTorch_model.ipynb index dd02604..e1b1456 100644 --- a/docs/source/JAX_porting_PyTorch_model.ipynb +++ b/docs/source/JAX_porting_PyTorch_model.ipynb @@ -2215,8 +2215,8 @@ "source": [ "## Further reading\n", "\n", - "- https://flax.readthedocs.io/en/latest/examples/core_examples.html\n", - "- https://jax-ai-stack.readthedocs.io/en/latest/tutorials.html" + "- [Flax documentation: Core Exampels](https://flax.readthedocs.io/en/latest/examples/core_examples.html)\n", + "- [JAX AI Stack tutorials](https://jax-ai-stack.readthedocs.io/en/latest/tutorials.html)" ] } ], diff --git a/docs/source/JAX_porting_PyTorch_model.md b/docs/source/JAX_porting_PyTorch_model.md index b940635..f043aac 100644 --- a/docs/source/JAX_porting_PyTorch_model.md +++ b/docs/source/JAX_porting_PyTorch_model.md @@ -1615,5 +1615,5 @@ cosine_dist ## Further reading -- https://flax.readthedocs.io/en/latest/examples/core_examples.html -- https://jax-ai-stack.readthedocs.io/en/latest/tutorials.html +- [Flax documentation: Core Exampels](https://flax.readthedocs.io/en/latest/examples/core_examples.html) +- [JAX AI Stack tutorials](https://jax-ai-stack.readthedocs.io/en/latest/tutorials.html)