diff --git a/docs_nnx/conf.py b/docs_nnx/conf.py index 4a6075455c..7eee470630 100644 --- a/docs_nnx/conf.py +++ b/docs_nnx/conf.py @@ -148,6 +148,7 @@ 'transfer_learning.ipynb', # <-- transformers requires flax<=0.7.0 'flax/nnx', # exclude nnx 'guides/demo.ipynb', # TODO(cgarciae): broken, remove or update + 'guides/gemma.ipynb', ] # raise exceptions on execution so CI can catch errors nb_execution_allow_errors = False diff --git a/docs_nnx/guides/gemma.ipynb b/docs_nnx/guides/gemma.ipynb new file mode 100644 index 0000000000..7230ca1b02 --- /dev/null +++ b/docs_nnx/guides/gemma.ipynb @@ -0,0 +1,292 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Copyright 2024 The Flax Authors.\n", + "\n", + "Licensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with the License. You may obtain a copy of the License at\n", + "\n", + "http://www.apache.org/licenses/LICENSE-2.0\n", + "\n", + "Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.\n", + "\n", + "---" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Getting Started with Gemma Sampling using NNX: A Step-by-Step Guide\n", + "\n", + "You will find in this colab a detailed tutorial explaining how to use NNX to load a Gemma checkpoint and sample from it." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Installation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "! pip install --no-deps -U flax\n", + "! pip install jaxtyping kagglehub penzai" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Downloading the checkpoint\n", + "\n", + "\"To use Gemma's checkpoints, you'll need a Kaggle account and API key. Here's how to get them:\n", + "\n", + "1. Visit https://www.kaggle.com/ and create an account.\n", + "2. Go to your account settings, then the 'API' section.\n", + "3. Click 'Create new token' to download your key.\n", + "\n", + "Then run the cell below." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "ename": "ModuleNotFoundError", + "evalue": "No module named 'kagglehub'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[1], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mkagglehub\u001b[39;00m\n\u001b[1;32m 2\u001b[0m kagglehub\u001b[38;5;241m.\u001b[39mlogin()\n", + "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'kagglehub'" + ] + } + ], + "source": [ + "import kagglehub\n", + "kagglehub.login()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "If everything went well, you should see:\n", + "```\n", + "Kaggle credentials set.\n", + "Kaggle credentials successfully validated.\n", + "```\n", + "\n", + "Now select and download the checkpoint you want to try. Note that you will need an A100 runtime for the 7b models." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "VARIANT = '2b-it' # @param ['2b', '2b-it', '7b', '7b-it'] {type:\"string\"}\n", + "weights_dir = kagglehub.model_download(f'google/gemma/Flax/{VARIANT}')\n", + "ckpt_path = f'{weights_dir}/{VARIANT}'\n", + "vocab_path = f'{weights_dir}/tokenizer.model'" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Python imports" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from flax import nnx\n", + "import sentencepiece as spm" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Flax examples are not exposed as packages so you need to use the workaround in the next cells to import from NNX's Gemma example." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "! git clone https://github.com/google/flax.git flax_examples" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "\n", + "sys.path.append(\"./flax_examples/flax/nnx/examples/gemma\")\n", + "import params as params_lib\n", + "import sampler as sampler_lib\n", + "import transformer as transformer_lib\n", + "sys.path.pop();" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Start Generating with Your Model\n", + "\n", + "Load and prepare your LLM's checkpoint for use with Flax." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form" + }, + "outputs": [], + "source": [ + "# Load parameters\n", + "params = params_lib.load_and_format_params(ckpt_path)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Load your tokenizer, which we'll construct using the [SentencePiece](https://github.com/google/sentencepiece) library." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form" + }, + "outputs": [], + "source": [ + "vocab = spm.SentencePieceProcessor()\n", + "vocab.Load(vocab_path)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Use the `transformer_lib.TransformerConfig.from_params` function to automatically load the correct configuration from a checkpoint. Note that the vocabulary size is smaller than the number of input embeddings due to unused tokens in this release." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "transformer = transformer_lib.Transformer.from_params(params)\n", + "nnx.display(transformer)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Finally, build a sampler on top of your model and your tokenizer." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form" + }, + "outputs": [], + "source": [ + "# Create a sampler with the right param shapes.\n", + "sampler = sampler_lib.Sampler(\n", + " transformer=transformer,\n", + " vocab=vocab,\n", + " params=params['transformer'],\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You're ready to start sampling ! This sampler uses just-in-time compilation, so changing the input shape triggers recompilation, which can slow things down. For the fastest and most efficient results, keep your batch size consistent." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form" + }, + "outputs": [], + "source": [ + "input_batch = [\n", + " \"\\n# Python program for implementation of Bubble Sort\\n\\ndef bubbleSort(arr):\",\n", + " \"What are the planets of the solar system?\",\n", + " ]\n", + "\n", + "out_data = sampler(\n", + " input_strings=input_batch,\n", + " total_generation_steps=300, # number of steps performed when generating\n", + " )\n", + "\n", + "for input_string, out_string in zip(input_batch, out_data.text):\n", + " print(f\"Prompt:\\n{input_string}\\nOutput:\\n{out_string}\")\n", + " print()\n", + " print(10*'#')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You should get an implementation of bubble sort and a description of the solar system." + ] + } + ], + "metadata": { + "jupytext": { + "formats": "ipynb,md:myst" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/docs_nnx/guides/gemma.md b/docs_nnx/guides/gemma.md new file mode 100644 index 0000000000..69ff17acb5 --- /dev/null +++ b/docs_nnx/guides/gemma.md @@ -0,0 +1,150 @@ +--- +jupytext: + formats: ipynb,md:myst + text_representation: + extension: .md + format_name: myst + format_version: 0.13 + jupytext_version: 1.13.8 +--- + +Copyright 2024 The Flax Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. + +--- + ++++ + +# Getting Started with Gemma Sampling using NNX: A Step-by-Step Guide + +You will find in this colab a detailed tutorial explaining how to use NNX to load a Gemma checkpoint and sample from it. + ++++ + +## Installation + +```{code-cell} ipython3 +! pip install --no-deps -U flax +! pip install jaxtyping kagglehub penzai +``` + +## Downloading the checkpoint + +"To use Gemma's checkpoints, you'll need a Kaggle account and API key. Here's how to get them: + +1. Visit https://www.kaggle.com/ and create an account. +2. Go to your account settings, then the 'API' section. +3. Click 'Create new token' to download your key. + +Then run the cell below. + +```{code-cell} ipython3 +import kagglehub +kagglehub.login() +``` + +If everything went well, you should see: +``` +Kaggle credentials set. +Kaggle credentials successfully validated. +``` + +Now select and download the checkpoint you want to try. Note that you will need an A100 runtime for the 7b models. + +```{code-cell} ipython3 +VARIANT = '2b-it' # @param ['2b', '2b-it', '7b', '7b-it'] {type:"string"} +weights_dir = kagglehub.model_download(f'google/gemma/Flax/{VARIANT}') +ckpt_path = f'{weights_dir}/{VARIANT}' +vocab_path = f'{weights_dir}/tokenizer.model' +``` + +## Python imports + +```{code-cell} ipython3 +from flax import nnx +import sentencepiece as spm +``` + +Flax examples are not exposed as packages so you need to use the workaround in the next cells to import from NNX's Gemma example. + +```{code-cell} ipython3 +! git clone https://github.com/google/flax.git flax_examples +``` + +```{code-cell} ipython3 +import sys + +sys.path.append("./flax_examples/flax/nnx/examples/gemma") +import params as params_lib +import sampler as sampler_lib +import transformer as transformer_lib +sys.path.pop(); +``` + +## Start Generating with Your Model + +Load and prepare your LLM's checkpoint for use with Flax. + +```{code-cell} ipython3 +:cellView: form + +# Load parameters +params = params_lib.load_and_format_params(ckpt_path) +``` + +Load your tokenizer, which we'll construct using the [SentencePiece](https://github.com/google/sentencepiece) library. + +```{code-cell} ipython3 +:cellView: form + +vocab = spm.SentencePieceProcessor() +vocab.Load(vocab_path) +``` + +Use the `transformer_lib.TransformerConfig.from_params` function to automatically load the correct configuration from a checkpoint. Note that the vocabulary size is smaller than the number of input embeddings due to unused tokens in this release. + +```{code-cell} ipython3 +transformer = transformer_lib.Transformer.from_params(params) +nnx.display(transformer) +``` + +Finally, build a sampler on top of your model and your tokenizer. + +```{code-cell} ipython3 +:cellView: form + +# Create a sampler with the right param shapes. +sampler = sampler_lib.Sampler( + transformer=transformer, + vocab=vocab, + params=params['transformer'], +) +``` + +You're ready to start sampling ! This sampler uses just-in-time compilation, so changing the input shape triggers recompilation, which can slow things down. For the fastest and most efficient results, keep your batch size consistent. + +```{code-cell} ipython3 +:cellView: form + +input_batch = [ + "\n# Python program for implementation of Bubble Sort\n\ndef bubbleSort(arr):", + "What are the planets of the solar system?", + ] + +out_data = sampler( + input_strings=input_batch, + total_generation_steps=300, # number of steps performed when generating + ) + +for input_string, out_string in zip(input_batch, out_data.text): + print(f"Prompt:\n{input_string}\nOutput:\n{out_string}") + print() + print(10*'#') +``` + +You should get an implementation of bubble sort and a description of the solar system. diff --git a/docs_nnx/guides/index.rst b/docs_nnx/guides/index.rst index f8d444bf48..5d917be3e3 100644 --- a/docs_nnx/guides/index.rst +++ b/docs_nnx/guides/index.rst @@ -14,3 +14,4 @@ Guides checkpointing jax_and_nnx_transforms haiku_to_flax + gemma diff --git a/uv.lock b/uv.lock index bbd346cb34..bb7c89e2db 100644 --- a/uv.lock +++ b/uv.lock @@ -767,7 +767,7 @@ wheels = [ [[package]] name = "flax" -version = "0.9.0" +version = "0.10.0" source = { editable = "." } dependencies = [ { name = "jax" }, @@ -2212,7 +2212,7 @@ wheels = [ [[package]] name = "orbax-checkpoint" -version = "0.6.0" +version = "0.7.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "absl-py" }, @@ -2220,7 +2220,6 @@ dependencies = [ { name = "etils", version = "1.9.2", source = { registry = "https://pypi.org/simple" }, extra = ["epath", "epy"], marker = "python_full_version >= '3.11'" }, { name = "humanize" }, { name = "jax" }, - { name = "jaxlib" }, { name = "msgpack" }, { name = "nest-asyncio" }, { name = "numpy" }, @@ -2230,9 +2229,9 @@ dependencies = [ { name = "tensorstore" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/07/4f/f6b372e70fb3785656d31edd9b99a151dc1b4955486e85a1935e9e0273c5/orbax_checkpoint-0.6.0.tar.gz", hash = "sha256:313586128267e0923d6d2095855da5edcd45acee1f9d2e86d1e8330f69acb110", size = 187560 } +sdist = { url = "https://files.pythonhosted.org/packages/b7/5a/e07d3b2a9dacc6fe882a255080d4af3ac180bc190fd8ce22ab64cf0bfe26/orbax_checkpoint-0.7.0.tar.gz", hash = "sha256:f5a59babbf86fdafacddcfd2fb1c6d45b4fa0685b38a87a4598a5702bb70a657", size = 201557 } wheels = [ - { url = "https://files.pythonhosted.org/packages/3c/e0/194d62674be60e3bf2cb764f653e8f06db86b02b6c9c9243ea9af0f48bf1/orbax_checkpoint-0.6.0-py3-none-any.whl", hash = "sha256:fce1d61b1a378939f55b03fb4ac9922ad0def0b846822b1f5e70f4a81d24dbc2", size = 253044 }, + { url = "https://files.pythonhosted.org/packages/3a/63/45b63b51b320d104f21cb7f2a5d0ae2b37558e24296c02d33521a291ad87/orbax_checkpoint-0.7.0-py3-none-any.whl", hash = "sha256:0469030dd70729f7416981712a9ea8a82bd02c65ca82c933675c9e3ed4763f9b", size = 279660 }, ] [[package]]