[RFC] Improve TorchTune Extensibility and Build Interop with Ecosystem #442

Merged · 20 commits · Mar 11, 2024
5 changes: 4 additions & 1 deletion .github/workflows/recipe_integration_test.yaml
@@ -55,7 +55,10 @@ jobs:
          python3 -m pip install awscli==1.32.6
          mkdir -p /tmp/test-artifacts
          aws s3 cp s3://pytorch-multimodal/llama2-7b/tokenizer.model /tmp/test-artifacts
-         aws s3 cp s3://pytorch-multimodal/llama2-7b-01242024 /tmp/test-artifacts
+         aws s3 cp s3://pytorch-multimodal/llama2-7b-torchtune.pt /tmp/test-artifacts
+         aws s3 cp s3://pytorch-multimodal/small-ckpt-hf-03082024.pt /tmp/test-artifacts
+         aws s3 cp s3://pytorch-multimodal/small-ckpt-tune-03082024.pt /tmp/test-artifacts
+         aws s3 cp s3://pytorch-multimodal/small-ckpt-meta-03082024.pt /tmp/test-artifacts
Member:
This is going to get pretty annoying to hardcode this tempfile directory, as users on the same box can overwrite each other's stuff / not have access to this directory. We should at least add some sort of unique id to this.
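A minimal sketch of that idea for the `run:` step (illustrative only, not part of this diff; `GITHUB_RUN_ID` is GitHub Actions' built-in run identifier, and the directory naming here is hypothetical):

```
# Give each run its own artifact directory instead of a shared hardcoded path.
# Falls back to a random mktemp suffix if GITHUB_RUN_ID is unset.
ARTIFACT_DIR="/tmp/test-artifacts-${GITHUB_RUN_ID:-$(mktemp -u XXXXXX)}"
mkdir -p "$ARTIFACT_DIR"
aws s3 cp s3://pytorch-multimodal/llama2-7b/tokenizer.model "$ARTIFACT_DIR"
```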

      - name: Install dependencies
        run: |
          python -m pip install -r requirements.txt
3 changes: 3 additions & 0 deletions .github/workflows/recipe_test.yaml
@@ -52,6 +52,9 @@ jobs:
          mkdir -p /tmp/test-artifacts
          aws s3 cp s3://pytorch-multimodal/llama2-7b/tokenizer.model /tmp/test-artifacts
          aws s3 cp s3://pytorch-multimodal/small-ckpt-01242024 /tmp/test-artifacts
+         aws s3 cp s3://pytorch-multimodal/small-ckpt-hf-03082024.pt /tmp/test-artifacts
+         aws s3 cp s3://pytorch-multimodal/small-ckpt-tune-03082024.pt /tmp/test-artifacts
+         aws s3 cp s3://pytorch-multimodal/small-ckpt-meta-03082024.pt /tmp/test-artifacts
      - name: Install dependencies
        run: |
          python -m pip install -r requirements.txt
28 changes: 20 additions & 8 deletions README.md
@@ -91,17 +91,24 @@
You can find your token at https://huggingface.co/settings/tokens

```
-tune download --repo-id meta-llama/Llama-2-7b --hf-token <HF_TOKEN> --output-dir /tmp/llama2
+tune download --repo-id meta-llama/Llama-2-7b \
+    --hf-token <HF_TOKEN> \
+    --output-dir /tmp/llama2
```

&nbsp;

-#### Converting the checkpoint into PyTorch-native
+#### Converting the checkpoint into PyTorch-native for LoRA

-Now that you have the Llama2 model weights, convert them into a PyTorch-native format supported by TorchTune.
+Now that you have the Llama2 model weights, convert them into a PyTorch-native format supported by TorchTune. This is only
+needed if you're running LoRA. For full fine-tuning, you should be able to use the downloaded checkpoints without any
+conversion. See the [Running Recipes](#running-recipes) section for more details.

```
-tune convert_checkpoint --checkpoint-path /tmp/llama2/consolidated.00.pth --output-path /tmp/llama2_native --model llama2
+tune convert_checkpoint --checkpoint-path /tmp/llama2/consolidated.00.pth \
+    --output-path /tmp/llama2/llama2_native.pt \
+    --model llama2 \
+    --train-type full
```

&nbsp;
@@ -110,18 +117,23 @@

TorchTune contains recipes for [full finetuning](https://github.com/pytorch-labs/torchtune/blob/e802c057d17773f65cf80721807086724e4fa7db/recipes/full_finetune.py), [LoRA finetuning](https://github.com/pytorch-labs/torchtune/blob/e802c057d17773f65cf80721807086724e4fa7db/recipes/lora_finetune.py), and [generation](https://github.com/pytorch-labs/torchtune/blob/e802c057d17773f65cf80721807086724e4fa7db/recipes/alpaca_generate.py).

-To run a full finetune on two devices on the Alpaca dataset using FSDP:
+Full-finetuning runs without the need for any model conversion. To run a full finetune on two devices on the Alpaca dataset using FSDP:

```
-tune --nnodes 1 --nproc_per_node 2 full_finetune --config alpaca_llama2_full_finetune
+tune --nnodes 1 --nproc_per_node 2 \
+    full_finetune \
+    --config alpaca_llama2_full_finetune
```

The argument passed to `--nproc_per_node` can be varied depending on how many GPUs you have. A full finetune can be memory-intensive, so make sure you are running on enough devices. See [this table](https://github.com/pytorch-labs/torchtune/blob/main/README.md#finetuning-resource-requirements) for resource requirements on common hardware setups.

-Similarly, you can finetune with LoRA on the Alpaca dataset on two devices via
+Similarly, you can finetune with LoRA on the Alpaca dataset on two devices via the following. Remember to convert your
+model with `train_type` set to `lora`:

```
-tune --nnodes 1 --nproc_per_node 2 lora_finetune --config alpaca_llama2_lora_finetune
+tune --nnodes 1 --nproc_per_node 2 \
+    lora_finetune \
+    --config alpaca_llama2_lora_finetune
```

Again, the argument to `--nproc_per_node` can be varied subject to memory constraints of your device(s).
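For example (an illustration, not part of this diff), the same LoRA invocation on a four-GPU machine only changes the process count:

```
tune --nnodes 1 --nproc_per_node 4 \
    lora_finetune \
    --config alpaca_llama2_lora_finetune
```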
11 changes: 9 additions & 2 deletions recipes/configs/alpaca_llama2_full_finetune.yaml
@@ -18,7 +18,14 @@ shuffle: True
# Model Arguments
model:
  _component_: torchtune.models.llama2.llama2_7b
-model_checkpoint: /tmp/llama2_native
+
+checkpointer:
+  _component_: torchtune.utils.FullModelMetaCheckpointer
+  checkpoint_dir: /tmp/llama2
+  checkpoint_files: [consolidated.00.pth]
+  model_type: LLAMA2
Member:
Where can I, as a user, find the supported "model_type"s?

Contributor (author):
I'll add more info on this after the LoRA change. For now it's a copy-paste from the config.
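As a purely illustrative sketch (not part of this diff) of how the checkpointer block and `model_type` pair up for other checkpoint formats, assuming an HF-format counterpart such as `FullModelHFCheckpointer` exists (the directory and file names below are hypothetical):

```
# Hypothetical config — checkpointer class, paths, and file names are assumptions
checkpointer:
  _component_: torchtune.utils.FullModelHFCheckpointer
  checkpoint_dir: /tmp/llama2-hf
  checkpoint_files: [pytorch_model-00001-of-00002.bin, pytorch_model-00002-of-00002.bin]
  model_type: LLAMA2
```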


+resume_from_checkpoint: False

# Fine-tuning arguments
batch_size: 2
@@ -32,7 +39,7 @@ max_steps_per_epoch: null
gradient_accumulation_steps: 1
log_every_n_steps: null
run_generation: null
-resume_from_checkpoint: False
+


# Distributed
device: cuda
57 changes: 27 additions & 30 deletions recipes/full_finetune.py
@@ -4,7 +4,6 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

-import os
import sys

from functools import partial
@@ -17,6 +16,7 @@
from torch import nn
from torch.cuda.amp import GradScaler
from torch.distributed import init_process_group
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler

@@ -93,13 +93,20 @@ def __init__(self, cfg: DictConfig) -> None:
        self.max_steps_per_epoch = cfg.max_steps_per_epoch
        self.total_training_steps = 0

-    def load_checkpoint(self, ckpt_path: str):
+    def load_checkpoint(self, cfg: DictConfig) -> Dict[str, Any]:
        """
-        Extract the checkpoint state from file and validate.
+        Extract and load state dict from checkpoint file.
        """
-        ckpt_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
-        utils.validate_checkpoint(ckpt_dict, self._resume_from_checkpoint)
-        return ckpt_dict
+        self._checkpointer = config.instantiate(
+            cfg,
+            output_dir=self._output_dir,
+            resume_from_checkpoint=self._resume_from_checkpoint,
+        )
+        checkpoint_dict = self._checkpointer.load_checkpoint()
+
+        if self._resume_from_checkpoint:
+            self._update_recipe_state(checkpoint_dict)
+        return checkpoint_dict

    def setup(self, cfg: DictConfig) -> None:
        """
@@ -108,13 +115,7 @@ def setup(self, cfg: DictConfig) -> None:
        """
        self._metric_logger = config.instantiate(cfg.metric_logger)

-        ckpt_dict = self.load_checkpoint(ckpt_path=cfg.model_checkpoint)
-
-        # If we're resuming from checkpoint, the recipe's state should be updated before
-        # initializing the training components. This ensures that the seed is correctly
-        # propagated to the relevant components
-        if self._resume_from_checkpoint:
-            self._update_recipe_state(ckpt_dict)
+        ckpt_dict = self.load_checkpoint(cfg.checkpointer)
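A rough sketch of what this wiring amounts to, assuming `config.instantiate` resolves `_component_` to a class and forwards the remaining config keys along with the explicit kwargs (the expansion below is hypothetical, using the values from the alpaca_llama2_full_finetune.yaml shown earlier):

```
# Hypothetical expansion of config.instantiate(cfg.checkpointer, ...)
# using the values from alpaca_llama2_full_finetune.yaml above.
self._checkpointer = utils.FullModelMetaCheckpointer(
    checkpoint_dir="/tmp/llama2",
    checkpoint_files=["consolidated.00.pth"],
    model_type="LLAMA2",
    output_dir=self._output_dir,
    resume_from_checkpoint=self._resume_from_checkpoint,
)
```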

        # ``_setup_model`` handles initialization and loading the state dict. This method
        # should be called before ``_setup_optimizer`` since transforming the optimizer

@@ -239,6 +240,7 @@ def _setup_optimizer(
            for FSDP.
        """
        optimizer = config.instantiate(cfg_optimizer, self._model.parameters())
+
        if opt_state_dict:
            opt_state_dict = utils.transform_opt_state_dict(
                opt_state_dict, self._model, optimizer
@@ -290,35 +292,30 @@

    def save_checkpoint(self, epoch: int) -> None:
        """
-        Checkpoint the relevant state of a recipe.
-
-        This makes use of the `save_checkpoint` utility which is responsible for
-        writing the checkpoint dictionary to file. The contents of the dict are dictated
-        by whether training is complete or not.
-
-        If training is ongoing, optimizer state, seed and epochs_run are saved along with the
-        model weights.
+        Save state dict to file.
        """
-        os.makedirs(self._output_dir, exist_ok=True)
-        output_loc = f"{self._output_dir}/model_{epoch}.ckpt"
-        ckpt_dict = {MODEL_KEY: self._model}
-
+        ckpt_dict = {MODEL_KEY: self._model.state_dict()}
        # if training is in-progress, checkpoint the optimizer state as well
        if epoch + 1 < self.total_epochs:
+            optimizer_state_dict = (
Member:
Didn't we have a helper function for this? Any reason we prefer it directly in the recipe for now?

Contributor (author):
It was just a 2-line function specific to this recipe, so I just pulled it in here.
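For reference, a sketch of what that recipe-local helper might have looked like if kept as a function (hypothetical; the diff below simply inlines the same conditional):

```
# Hypothetical helper — the recipe now inlines this logic instead.
def _optimizer_state_dict(model: nn.Module, optimizer: Optimizer) -> Dict[str, Any]:
    if utils.contains_fsdp(model):
        return FSDP.optim_state_dict(model, optimizer)
    return optimizer.state_dict()
```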

+                FSDP.optim_state_dict(self._model, self._optimizer)
+                if utils.contains_fsdp(self._model)
+                else self._optimizer.state_dict()
+            )
            ckpt_dict.update(
                {
-                    OPT_KEY: self._optimizer,
+                    OPT_KEY: optimizer_state_dict,
                    SEED_KEY: self.seed,
                    EPOCHS_KEY: self.epochs_run,
                    TOTAL_EPOCHS_KEY: self.total_epochs,
                    MAX_STEPS_KEY: self.max_steps_per_epoch,
                }
            )
-        utils.save_checkpoint(ckpt_dict, output_loc)
-
-        if self._is_rank_zero:
-            log.info(
-                f"Model checkpoint of size {os.path.getsize(output_loc) >> 20} MB saved to {output_loc}"
+        self._checkpointer.save_checkpoint(
+            ckpt_dict,
+            epoch=epoch,
+            intermediate_checkpoint=(epoch + 1 < self.total_epochs),
        )

    def _should_update_weights(self, current_iteration: int) -> bool:
3 changes: 0 additions & 3 deletions requirements.txt
@@ -1,6 +1,3 @@
-# Install PyTorch
-torch>=2.2.0
-
# HuggingFace Integration Reqs
datasets==2.15.0
huggingface_hub==0.19.4
8 changes: 7 additions & 1 deletion tests/recipes/run_test.sh
@@ -10,7 +10,13 @@
# Longer Test using llama 7b checkpoint: ./run_test.sh --large-scale

LOCAL_DIR="/tmp/test-artifacts"
S3_URLS=("s3://pytorch-multimodal/llama2-7b/tokenizer.model" "s3://pytorch-multimodal/small-ckpt-01242024")
S3_URLS=(
"s3://pytorch-multimodal/llama2-7b/tokenizer.model"
"s3://pytorch-multimodal/small-ckpt-01242024"
Contributor:
Why is this one still around?

Contributor (author):
This is needed for LoRA.
"s3://pytorch-multimodal/small-ckpt-tune-03082024.pt"
"s3://pytorch-multimodal/small-ckpt-meta-03082024.pt"
"s3://pytorch-multimodal/small-ckpt-hf-03082024.pt"
)
PYTEST_COMMAND="pytest tests/recipes -s"

if [[ $# -gt 0 ]]; then