Commit

Llama3 changes (#793)

ebsmothers authored Apr 18, 2024
1 parent 83785f9 commit 20747cd

Showing 49 changed files with 4,209 additions and 83 deletions.
53 changes: 46 additions & 7 deletions README.md
@@ -3,6 +3,11 @@
![Recipe Integration Test](https://github.com/pytorch/torchtune/actions/workflows/recipe_test.yaml/badge.svg)
[![](https://dcbadge.vercel.app/api/server/4Xsdn8Rr9Q?style=flat)](https://discord.gg/4Xsdn8Rr9Q)

**Note: torchtune now supports Llama3! We currently support the Llama3 8B model with LoRA, QLoRA, and full fine-tuning. Find more details in the [Llama3](#llama3) section!**


# torchtune

@@ -40,6 +45,7 @@ torchtune currently supports the following models.

| Model | Sizes |
|-----------------------------------------------|-----------|
| [Llama3](https://llama.meta.com/llama3) | 8B [[models](torchtune/models/llama3/_model_builders.py), [configs](recipes/configs/llama3/)] |
| [Llama2](https://llama.meta.com/llama2/) | 7B, 13B [[models](torchtune/models/llama2/_model_builders.py), [configs](recipes/configs/llama2/)] |
| [Mistral](https://huggingface.co/mistralai) | 7B [[model](torchtune/models/mistral/_model_builders.py), [configs](recipes/configs/mistral/)] |
| [Gemma](https://huggingface.co/collections/google/gemma-release-65d5efbccdbb8c4202ec078b) | 2B [[model](torchtune/models/gemma/_model_builders.py), [configs](recipes/configs/gemma/)] |
@@ -54,8 +60,8 @@ torchtune provides the following fine-tuning recipes.

| Training | Fine-tuning Method |
|------------------------------------|------------------------------------|
| Distributed Training [1 to 8 GPUs] | Full [[code](recipes/full_finetune_distributed.py), [example](recipes/configs/llama2/7B_full.yaml)], LoRA [[code](recipes/lora_finetune_distributed.py), [example](recipes/configs/llama2/7B_lora.yaml)] |
| Single Device / Low Memory [1 GPU] | Full [[code](recipes/full_finetune_single_device.py), [example](recipes/configs/llama2/7B_full_low_memory.yaml)], LoRA + QLoRA [[code](recipes/lora_finetune_single_device.py), [example](recipes/configs/llama2/7B_qlora_single_device.yaml)] |
| Distributed Training [1 to 8 GPUs] | Full [[code](recipes/full_finetune_distributed.py), [example](recipes/configs/llama3/8B_full.yaml)], LoRA [[code](recipes/lora_finetune_distributed.py), [example](recipes/configs/llama3/8B_lora.yaml)] |
| Single Device / Low Memory [1 GPU] | Full [[code](recipes/full_finetune_single_device.py), [example](recipes/configs/llama3/8B_full_single_device.yaml)], LoRA + QLoRA [[code](recipes/lora_finetune_single_device.py), [example](recipes/configs/llama3/8B_lora_single_device.yaml)] |
| Single Device [1 GPU] | DPO [[code](recipes/full_finetune_distributed.py), [example](recipes/configs/llama2/7B_lora_dpo_single_device.yaml)] |
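
To browse everything available, the `tune ls` command lists all built-in recipes and their associated configs (the exact list depends on the installed version):

```bash
# List every built-in recipe and its associated configs
tune ls
```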
@@ -69,14 +75,47 @@ This table captures the minimum memory requirements for our different recipes us

| Example HW Resources | Finetuning Method | Config | Model | Peak Memory per GPU |
|--------------|-------------------|---------|------------|---------------------|
| 1 x RTX 4090 | QLoRA | [qlora_finetune_single_device](recipes/configs/llama2/7B_qlora_single_device.yaml) | Llama-7B | 9.29 GB |
| 2 x RTX 4090 | LoRA | [lora_finetune_distributed](recipes/configs/llama2/7B_lora.yaml) | Llama-7B | 20.95 GB |
| 1 x RTX 4090 | LoRA | [lora_finetune_single_device](recipes/configs/llama2/7B_lora_single_device.yaml) | Llama-7B | 17.18 GB |
| 1 x RTX 4090 | Full finetune | [full_finetune_single_device](recipes/configs/llama2/7B_full_low_memory.yaml) | Llama-7B | 14.97 GB |
| 4 x RTX 4090 | Full finetune | [full_finetune_distributed](recipes/configs/llama2/7B_full.yaml) | Llama-7B | 22.9 GB |
| 1 x RTX 4090 | QLoRA | [qlora_finetune_single_device](recipes/configs/llama2/7B_qlora_single_device.yaml) | Llama2-7B | 8.57 GB |
| 2 x RTX 4090 | LoRA | [lora_finetune_distributed](recipes/configs/llama2/7B_lora.yaml) | Llama2-7B | 20.95 GB |
| 1 x RTX 4090 | LoRA | [lora_finetune_single_device](recipes/configs/llama2/7B_lora_single_device.yaml) | Llama2-7B | 17.18 GB |
| 1 x RTX 4090 | Full finetune | [full_finetune_single_device](recipes/configs/llama2/7B_full_low_memory.yaml) | Llama2-7B | 14.97 GB |
| 4 x RTX 4090 | Full finetune | [full_finetune_distributed](recipes/configs/llama2/7B_full.yaml) | Llama2-7B | 22.9 GB |

* These numbers are averaged over multiple runs; there may be some variance depending on your setup. We'll update this table regularly.

## Llama3

torchtune supports fine-tuning of the Llama3 8B model, with 70B support on the way. We currently support LoRA, QLoRA, and full fine-tuning on a single GPU, as well as LoRA and full fine-tuning on multiple devices. For all the details, take a look at our [tutorial](https://pytorch.org/torchtune/main/tutorials/llama3.html).


In our initial experiments, QLoRA has a peak allocated memory of ``~9GB``, while LoRA on a single GPU has a peak allocated memory of ``~19GB``. To get started, you can use our default configs to kick off training:

- LoRA on a single GPU

```bash
tune run lora_finetune_single_device --config llama3/8B_lora_single_device
```

- QLoRA on a single GPU

```bash
tune run lora_finetune_single_device --config llama3/8B_qlora_single_device
```

- LoRA on 2 GPUs

```bash
tune run --nproc_per_node 2 lora_finetune_distributed --config llama3/8B_lora
```

- Full fine-tune on 2 GPUs

```bash
tune run --nproc_per_node 2 full_finetune_distributed --config llama3/8B_full
```
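
Individual config fields can also be overridden directly from the command line instead of editing the YAML. A minimal sketch (assuming `batch_size` and `gradient_accumulation_steps` are fields in the config you are running, as they are in the default fine-tuning configs):

```bash
# Append key=value pairs to override config fields at launch time
tune run lora_finetune_single_device --config llama3/8B_lora_single_device \
  batch_size=2 gradient_accumulation_steps=8
```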

20 changes: 20 additions & 0 deletions docs/source/api_ref_models.rst
@@ -4,6 +4,25 @@ torchtune.models

.. currentmodule:: torchtune.models

llama3
------

All models from the `Llama3 family <https://llama.meta.com/llama3/>`_.

.. code-block:: bash

    tune download meta-llama/Meta-Llama-3-8B --hf-token <ACCESS_TOKEN>

.. autosummary::
    :toctree: generated/
    :nosignatures:

    llama3.llama3_8b
    llama3.lora_llama3_8b
    llama3.qlora_llama3_8b


llama2
------

@@ -26,6 +45,7 @@ Pre-trained models can be downloaded from the Hugging Face Hub with the following
    llama2.lora_llama2_13b
    llama2.qlora_llama2_13b


mistral
-------

10 changes: 9 additions & 1 deletion docs/source/api_ref_modules.rst
@@ -17,10 +17,18 @@ Modeling Components and Building Blocks
get_cosine_schedule_with_warmup
RotaryPositionalEmbeddings
RMSNorm
Tokenizer
TransformerDecoderLayer
TransformerDecoder

Tokenizers
----------

.. autosummary::
    :toctree: generated/
    :nosignatures:

    tokenizers.SentencePieceTokenizer
    tokenizers.TikTokenTokenizer

PEFT Components
---------------
8 changes: 8 additions & 0 deletions docs/source/index.rst
@@ -43,6 +43,13 @@ torchtune tutorials.

.. customcardstart::

.. customcarditem::
    :header: Llama3 in torchtune
    :card_description:
    :image: _static/img/generic-pytorch-logo.png
    :link: tutorials/llama3.html
    :tags: finetuning,llama3

.. customcarditem::
    :header: Finetuning with LoRA in torchtune
    :card_description: Parameter-efficient finetuning of Llama2 using LoRA
@@ -88,6 +95,7 @@ torchtune tutorials.
    :caption: Tutorials
    :hidden:

    tutorials/llama3
    tutorials/lora_finetune
    tutorials/qlora_finetune
    tutorials/e2e_flow
2 changes: 2 additions & 0 deletions docs/source/tutorials/first_finetune_tutorial.rst
@@ -98,6 +98,8 @@ a single device. For a more in-depth discussion on LoRA in torchtune, you can see

|
.. _tune_cp_label:

Modifying a config
------------------
YAML configs hold most of the important information needed for running your recipe.
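
For quick experiments, you can copy any built-in config to a local file and edit it there. A minimal sketch using ``tune cp`` (the destination filename ``custom_config.yaml`` is just an example):

.. code-block:: bash

    # Copy a built-in config into the current directory for editing
    tune cp llama3/8B_lora_single_device custom_config.yaml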