enable QLoRA + FSDP2 #909

Merged Jun 5, 2024

Changes from all commits (100 commits)
e5826a1
enable LoRA + FSDP2
weifengpy Apr 24, 2024
64fc870
reset params for lora weights and rope
weifengpy Apr 24, 2024
0cd21c6
support lora weights checkpoint and checkpoint utils
weifengpy Apr 24, 2024
589191e
fix lora meta device bug
weifengpy Apr 24, 2024
c801f26
save optim state dict
weifengpy Apr 25, 2024
19a2d70
mark TODO
weifengpy Apr 25, 2024
441da10
optimizer foreach=True for DTensor
weifengpy Apr 25, 2024
750b9e5
clip grad norm
weifengpy Apr 25, 2024
3d632d5
switch to ptd state dict api
weifengpy Apr 26, 2024
cb3abb3
add profiler
weifengpy May 1, 2024
dfcdde3
qlora 7b config
weifengpy May 1, 2024
e68804a
use torchao copy_
weifengpy May 1, 2024
b6fad93
Merge pull request #1 from weifengpy/fsdp2
weifengpy May 1, 2024
d6af9a2
enable saving checkpoint
weifengpy May 1, 2024
7bbe522
Merge pull request #2 from weifengpy/fsdp2
weifengpy May 1, 2024
b616394
optimizer state dict: load on rank0 and broadcast
weifengpy May 1, 2024
a400497
import Optimizer
weifengpy May 1, 2024
e9de63c
resume training
weifengpy May 3, 2024
05d3895
prepare for full test
weifengpy May 3, 2024
7a5bb80
prepare for full test
weifengpy May 3, 2024
64bf49c
remove profiler
weifengpy May 3, 2024
cb1bba4
passed integration test
weifengpy May 4, 2024
ac516e9
remove uncesssary change
weifengpy May 4, 2024
bfde704
Merge branch 'main' into fsdp2
weifengpy May 4, 2024
102db31
bring back state dict validation
weifengpy May 4, 2024
0b66651
align indent on comment
weifengpy May 4, 2024
672aabb
remove unused import
weifengpy May 4, 2024
6af2723
switch to ptd state dict and keep self implemented in record
weifengpy May 8, 2024
42ad99c
clean unused code
weifengpy May 8, 2024
74f6175
remove cuda value error
weifengpy May 8, 2024
f1b8a5e
comment on to_empty
weifengpy May 8, 2024
36e6829
fix memory issues by switching model state dict api
weifengpy May 8, 2024
08cd1fd
clean for review
weifengpy May 8, 2024
559bc4d
Merge branch 'main' into fsdp2
weifengpy May 8, 2024
2333134
fix linter
weifengpy May 9, 2024
49a0364
fix checkpoint loading
weifengpy May 9, 2024
dc2ce02
expecttest CI depedency
weifengpy May 9, 2024
0a604aa
ci depdencecy
weifengpy May 9, 2024
fa83140
fix CI issue
weifengpy May 10, 2024
6203a1f
Merge branch 'main' into qlora
weifengpy May 10, 2024
4b5a895
Merge branch 'pytorch:main' into fsdp2
weifengpy May 10, 2024
1080e2c
Merge branch 'fsdp2' into qlora
weifengpy May 10, 2024
1a70498
rebase qlora
weifengpy May 10, 2024
cb862e9
rebase qlora
weifengpy May 10, 2024
21f5458
sync lora changes
weifengpy May 14, 2024
33773bd
push qlora for perf measurement
weifengpy May 14, 2024
483028b
fix meta init + cpu offloading
weifengpy May 15, 2024
cf42618
init RotaryPositionalEmbeddings in both fresh training and resume
weifengpy May 15, 2024
b519d50
import cpu offloading when needed
weifengpy May 17, 2024
8600ced
FSDP(CheckpointWrapper(Model))
weifengpy May 22, 2024
b2fd531
bring back cpu offloading
weifengpy May 22, 2024
bb8a8bc
remove model.to
weifengpy May 29, 2024
db71c5c
apply nf4 when loading model state dict
weifengpy May 30, 2024
16bf2de
move lora to cpu when cpu offloading
weifengpy May 30, 2024
df6e535
Update documentation tab to point to main instead of stable (#960)
kartikayk May 11, 2024
5f621e1
Update tokens_per_sec to tokens_per_sec_per_gpu (#956)
kartikayk May 11, 2024
7d92b1c
Delete init_weights_with_constant test util (#974)
ebsmothers May 13, 2024
588871e
Sample packing for map datasets with correct RoPE encoding and no cro…
RdoubleA May 15, 2024
1a5bf1a
Utilize compile on model rather than via torch API (#953)
joecummings May 15, 2024
ae7de20
Add better formatting for Eleuther eval results (#986)
joecummings May 16, 2024
23cea56
updating help docs for hf-token arg in download.py (#991)
SalmanMohammadi May 16, 2024
be06efa
Fix position embeddings for Phi3 when packing + nits (#992)
RdoubleA May 17, 2024
79ef995
Llama3-8b memory efficient full finetune (#990)
rohan-varma May 17, 2024
5f55c16
Fix Gemma 2B model forward call (#998)
joecummings May 17, 2024
b88fa2d
fix: lora dropout applied to all models (#995)
Optimox May 17, 2024
b47ee93
fix: different rope base between phi3 and lora_phi3 (#997)
Optimox May 17, 2024
d86b454
Add support for free generation tasks in evals (#975)
joecummings May 19, 2024
2b109f4
Filter out special tokens and placeholder tokens for Phi-3 (#983)
joecummings May 20, 2024
9bd07a6
TorchTune --> torchtune (#1007)
joecummings May 20, 2024
f5cb12e
Support for unstructured text corpus datasets for CPT (#868)
RdoubleA May 21, 2024
a2066f9
Save adapter config and remapped adapter weights for loading into PEF…
ebsmothers May 21, 2024
29d1761
Datasets tutorial improvements (#994)
RdoubleA May 21, 2024
3a01d7f
Fix TypeError: tuple indices must be integers or slices, not str issu…
tambulkar May 22, 2024
1d6b4a2
Add recipe test for llama3 (#929)
SLR722 May 23, 2024
c74c9a9
Fix the Gemma generation (#1016)
solitude-alive May 24, 2024
00f96ff
Update chat tutorial so that it works as is (#1004)
christobill May 28, 2024
62192df
[fix] llama3 70B_lora update commented instructions (#1030)
pbontrager May 30, 2024
ecd5e7e
Move nf4 op registration from utils to modules (#1035)
ebsmothers May 31, 2024
99c549b
feat: add gemma7b support (#971)
Optimox May 31, 2024
7d11a89
Llama3-70b: Full Finetune w/CPU offload + fused optimizer (#993)
rohan-varma Jun 1, 2024
0080795
enable LoRA + FSDP2 (#855)
weifengpy Jun 3, 2024
00360f7
Merge branch 'weifengpy-qlora' into qlora
weifengpy Jun 4, 2024
d8664a3
Merge branch 'main' into qlora
weifengpy Jun 4, 2024
f58f9b2
rebase
weifengpy Jun 4, 2024
b9bfd41
revert lora_finetune_distributed.py
weifengpy Jun 4, 2024
7a3d9a1
rebase and register recipe
weifengpy Jun 4, 2024
2835d2a
del logits to save memory
weifengpy Jun 4, 2024
559b81d
fix linter
weifengpy Jun 4, 2024
85f978b
gate NF4.copy_ on TorchAO==0.2.0
weifengpy Jun 4, 2024
dbae23c
improve torchao gating comment
weifengpy Jun 4, 2024
f4a8dfa
upgrade torchao to 0.2
weifengpy Jun 4, 2024
10e304d
gate torchao 0.2
weifengpy Jun 4, 2024
e117a21
replace with lora_finetune_fsdp2
weifengpy Jun 4, 2024
4bb5e0f
add llama2-70B
weifengpy Jun 4, 2024
174d916
replace with qlora and lora_finetune_fsdp2 in yaml
weifengpy Jun 4, 2024
5fdcefb
rename yaml to _fsdp2.yaml
weifengpy Jun 5, 2024
b878018
add unit test for nf4 state dict
weifengpy Jun 5, 2024
a8f1a9a
python 3.8 style dict union
weifengpy Jun 5, 2024
ae49684
validate lora sd missing
weifengpy Jun 5, 2024
cbb3da8
skip test if <2 gpu
weifengpy Jun 5, 2024
90 changes: 90 additions & 0 deletions recipes/configs/dev/llama2/70B_qlora_fsdp2.yaml
@@ -0,0 +1,90 @@
# Config for multi-device QLoRA in lora_finetune_fsdp2.py
# using a Llama2 70B model
#
# This config assumes that you've run the following command before launching
# this run:
# tune download meta-llama/Llama-2-70b-hf --output-dir /tmp/Llama-2-70b-hf --hf-token <HF_TOKEN>
#
# This config needs 8 GPUs to run:
#   tune run --nproc_per_node 8 lora_finetune_fsdp2 --config llama2/70B_qlora
#

# Model Arguments
model:
_component_: torchtune.models.llama2.qlora_llama2_70b
lora_attn_modules: ['q_proj', 'v_proj', 'k_proj', 'output_proj']
apply_lora_to_mlp: True
apply_lora_to_output: False
lora_rank: 16
lora_alpha: 32

tokenizer:
_component_: torchtune.models.llama2.llama2_tokenizer
path: /tmp/Llama-2-70b-hf/tokenizer.model

checkpointer:
_component_: torchtune.utils.FullModelHFCheckpointer
checkpoint_dir: /tmp/Llama-2-70b-hf
checkpoint_files: [
pytorch_model-00001-of-00015.bin,
pytorch_model-00002-of-00015.bin,
pytorch_model-00003-of-00015.bin,
pytorch_model-00004-of-00015.bin,
pytorch_model-00005-of-00015.bin,
pytorch_model-00006-of-00015.bin,
pytorch_model-00007-of-00015.bin,
pytorch_model-00008-of-00015.bin,
pytorch_model-00009-of-00015.bin,
pytorch_model-00010-of-00015.bin,
pytorch_model-00011-of-00015.bin,
pytorch_model-00012-of-00015.bin,
pytorch_model-00013-of-00015.bin,
pytorch_model-00014-of-00015.bin,
pytorch_model-00015-of-00015.bin,
]
recipe_checkpoint: null
output_dir: /tmp/Llama-2-70b-hf
model_type: LLAMA2
resume_from_checkpoint: False

# Dataset and Sampler
dataset:
_component_: torchtune.datasets.alpaca_dataset
train_on_input: True
seed: null
shuffle: True
batch_size: 2

# Optimizer and Scheduler
optimizer:
_component_: torch.optim.AdamW
weight_decay: 0.01
lr: 3e-4
lr_scheduler:
_component_: torchtune.modules.get_cosine_schedule_with_warmup
num_warmup_steps: 100

loss:
_component_: torch.nn.CrossEntropyLoss

fsdp:
cpu_offload: False

# Training
epochs: 1
max_steps_per_epoch: null
gradient_accumulation_steps: 1
compile: False

# Logging
output_dir: /tmp/qlora_finetune_output
metric_logger:
_component_: torchtune.utils.metric_logging.DiskLogger
log_dir: ${output_dir}
log_every_n_steps: 1
log_peak_memory_stats: False

# Environment
device: cuda
dtype: bf16
enable_activation_checkpointing: True
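For orientation, lora_rank and lora_alpha in the config above size and scale the low-rank update added to each frozen (NF4-quantized in QLoRA) projection. Below is a minimal sketch of the standard LoRA forward pass, not torchtune's implementation; SketchLoRALinear and the 8192 hidden size are illustrative assumptions only.

import torch
import torch.nn as nn


class SketchLoRALinear(nn.Module):
    def __init__(self, in_dim, out_dim, rank=16, alpha=32.0):
        super().__init__()
        self.base = nn.Linear(in_dim, out_dim, bias=False)
        self.base.weight.requires_grad = False          # frozen base weight (NF4 in QLoRA)
        self.lora_a = nn.Linear(in_dim, rank, bias=False)
        self.lora_b = nn.Linear(rank, out_dim, bias=False)
        nn.init.zeros_(self.lora_b.weight)              # low-rank update starts at zero
        self.scaling = alpha / rank                     # e.g. 32 / 16 = 2.0 for this config

    def forward(self, x):
        return self.base(x) + self.scaling * self.lora_b(self.lora_a(x))


x = torch.randn(2, 8192)                                # 8192 is the Llama2-70B hidden size
layer = SketchLoRALinear(8192, 8192, rank=16, alpha=32.0)
print(layer(x).shape)                                   # torch.Size([2, 8192])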
84 changes: 84 additions & 0 deletions recipes/configs/dev/llama2/7B_qlora_fsdp2.yaml
@@ -0,0 +1,84 @@
# Config for multi-device QLoRA in lora_finetune_fsdp2.py
# using a Llama2 7B model
#
# This config assumes that you've run the following command before launching
# this run:
# tune download meta-llama/Llama-2-7b-hf --output-dir /tmp/Llama-2-7b-hf --hf-token <HF_TOKEN>
#
# To launch on 2 devices, run the following command from root:
#   tune run --nproc_per_node 2 lora_finetune_fsdp2 --config llama2/7B_qlora
#
# You can add specific overrides through the command line. For example
# to override the checkpointer directory while launching training
# you can run:
#   tune run --nproc_per_node 2 lora_finetune_fsdp2 --config llama2/7B_qlora checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
#
# This config is intended for multi-device training with FSDP2.

# Model Arguments
model:
_component_: torchtune.models.llama2.qlora_llama2_7b
lora_attn_modules: ['q_proj', 'v_proj', 'k_proj', 'output_proj']
apply_lora_to_mlp: True
apply_lora_to_output: False
lora_rank: 8
lora_alpha: 16

tokenizer:
_component_: torchtune.models.llama2.llama2_tokenizer
path: /tmp/Llama-2-7b-hf/tokenizer.model

checkpointer:
_component_: torchtune.utils.FullModelHFCheckpointer
checkpoint_dir: /tmp/Llama-2-7b-hf
checkpoint_files: [
pytorch_model-00001-of-00002.bin,
pytorch_model-00002-of-00002.bin
]
adapter_checkpoint: null
recipe_checkpoint: null
output_dir: /tmp/Llama-2-7b-hf
model_type: LLAMA2
resume_from_checkpoint: False

# Dataset and Sampler
dataset:
_component_: torchtune.datasets.alpaca_cleaned_dataset
train_on_input: True
seed: null
shuffle: True
batch_size: 2

# Optimizer and Scheduler
optimizer:
_component_: torch.optim.AdamW
weight_decay: 0.01
lr: 3e-4
lr_scheduler:
_component_: torchtune.modules.get_cosine_schedule_with_warmup
num_warmup_steps: 100

loss:
_component_: torch.nn.CrossEntropyLoss

fsdp:
cpu_offload: False

# Training
epochs: 1
max_steps_per_epoch: null
gradient_accumulation_steps: 16
compile: False

# Logging
output_dir: /tmp/qlora_finetune_output/
metric_logger:
_component_: torchtune.utils.metric_logging.DiskLogger
log_dir: ${output_dir}
log_every_n_steps: 1
log_peak_memory_stats: False

# Environment
device: cuda
dtype: bf16
enable_activation_checkpointing: True
66 changes: 31 additions & 35 deletions recipes/dev/lora_finetune_fsdp2.py
@@ -9,7 +9,7 @@
import time

from functools import partial
from typing import Any, Dict, Optional, Tuple
from typing import Any, Dict, Optional, Tuple, Union
from warnings import warn

import torch
@@ -32,7 +32,7 @@
get_lora_module_names,
get_merged_lora_ckpt,
set_trainable_params,
validate_state_dict_for_lora,
validate_missing_and_unexpected_for_lora,
)
from torchtune.recipe_interfaces import FTRecipeInterface

@@ -214,6 +214,7 @@ def setup(self, cfg: DictConfig) -> None:
if self._resume_from_checkpoint
else None
),
cfg_fsdp=cfg.fsdp if hasattr(cfg, "fsdp") else None,
)
self._tokenizer = config.instantiate(cfg.tokenizer)

@@ -265,6 +266,7 @@ def _setup_model(
enable_activation_checkpointing: bool,
base_model_state_dict: Dict[str, Any],
lora_weights_state_dict: Optional[Dict[str, Any]] = None,
cfg_fsdp: Optional[Union[DictConfig, None]] = None,
) -> nn.Module:
"""
Model initialization has some important considerations:
@@ -290,25 +292,6 @@
with utils.set_default_dtype(self._dtype), torch.device("meta"):
model = config.instantiate(cfg_model)

if self._is_rank_zero:
# The model contains LoRA params which won't have any matching keys in
# the state dict. As a result, we need to load with strict=False.
# Before loading the state dict, ensure the state dict keys for the base
# model and adapters (if available) match the keys in the full LoRA model
# This is a good sanity check to prevent silent errors
validate_state_dict_for_lora(
lora_attn_modules=cfg_model.lora_attn_modules,
apply_lora_to_mlp=cfg_model.apply_lora_to_mlp,
apply_lora_to_output=getattr(cfg_model, "apply_lora_to_output", False),
full_model_state_dict_keys=model.state_dict().keys(),
lora_state_dict_keys=(
lora_weights_state_dict.keys()
if lora_weights_state_dict is not None
else None
),
base_model_state_dict_keys=base_model_state_dict.keys(),
)

Review thread on the removed validate_state_dict_for_lora call:

weifengpy (Contributor, Author), Jun 4, 2024:
validate_state_dict_for_lora needs some refactoring since FSDP2 uses clean FQNs. FSDP2 also checks for params still on the meta device and raises a RuntimeError if any remain (https://fburl.com/wrhxykn3):

"FSDP parameters should be materialized from meta device before training, "
f"but the following were still on meta device: {param_names_on_meta}\n"
"For example, call module.to_empty(device) to materialize to device and "
"call module.reset_parameters() on each module to initialize values."

ebsmothers (Contributor):
Oh yeah, this is what I was alluding to in the previous PR. You can try validate_missing_and_unexpected_for_lora as in the single-device recipe now that we have the correct FQNs. If it doesn't work out of the box, feel free to leave this as-is; we can come back to refactor after.

weifengpy (Contributor, Author):
validate_missing_and_unexpected_for_lora works out of the box. Added it in this PR. Fingers crossed!
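To make the thread concrete, here is a minimal single-process illustration (not torchtune code; WithAdapter and its key names are hypothetical) of the missing/unexpected-keys bookkeeping that the validation helper relies on: loading a base-model state dict into a LoRA-augmented module with strict=False reports only the LoRA params as missing.

import torch
import torch.nn as nn


class WithAdapter(nn.Module):
    """Hypothetical stand-in for a LoRA-augmented module: frozen base plus adapter params."""

    def __init__(self):
        super().__init__()
        self.base = nn.Linear(16, 16, bias=False)
        self.lora_a = nn.Linear(16, 4, bias=False)
        self.lora_b = nn.Linear(4, 16, bias=False)


model = WithAdapter()
base_sd = {"base.weight": torch.randn(16, 16)}          # pretend base-model checkpoint

# strict=False is required because the LoRA params have no counterpart in the checkpoint
result = model.load_state_dict(base_sd, strict=False)
print(result.missing_keys)     # ['lora_a.weight', 'lora_b.weight']: expected, LoRA only
print(result.unexpected_keys)  # []: the checkpoint contains nothing unknown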

self.adapter_params = get_adapter_params(model)
set_trainable_params(model, self.adapter_params)

@@ -317,47 +300,58 @@
model, auto_wrap_policy={modules.TransformerDecoderLayer}
)

fsdp_kwargs = {}
if cfg_fsdp and cfg_fsdp.cpu_offload:
from torch.distributed._composable.fsdp import CPUOffloadPolicy

fsdp_kwargs["offload_policy"] = CPUOffloadPolicy()
# iterating from lower modules to higher,
# e.g. grouping LoRA adapters before the transformer block
for m in reversed(list(model.modules())):
if isinstance(m, nn.Linear) and m.weight.requires_grad:
fully_shard(m)
fully_shard(m, **fsdp_kwargs)
# TransformerDecoderLayer is wrapped by CheckpointWrapper
# when enable_activation_checkpointing
if enable_activation_checkpointing:
if isinstance(m, CheckpointWrapper):
fully_shard(m)
fully_shard(m, **fsdp_kwargs)
else:
if isinstance(m, modules.TransformerDecoderLayer):
fully_shard(m)

fully_shard(model)
fully_shard(m, **fsdp_kwargs)
fully_shard(model, **fsdp_kwargs)
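A self-contained sketch of the bottom-up fully_shard wrapping above, an assumption-laden illustration rather than torchtune's recipe. It assumes a PyTorch build that ships the composable FSDP2 API (torch.distributed._composable.fsdp), at least 2 CUDA GPUs, and a launch such as torchrun --nproc_per_node 2 sketch_fsdp2.py; TinyBlock and shard_bottom_up are hypothetical stand-ins for TransformerDecoderLayer and the recipe loop.

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._composable.fsdp import CPUOffloadPolicy, fully_shard


class TinyBlock(nn.Module):
    """Stand-in for a transformer layer: a frozen base Linear plus a trainable adapter."""

    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Linear(dim, dim, bias=False)
        self.adapter = nn.Linear(dim, dim, bias=False)

    def forward(self, x):
        return self.proj(x) + self.adapter(x)


def shard_bottom_up(model, cpu_offload=False):
    kwargs = {"offload_policy": CPUOffloadPolicy()} if cpu_offload else {}
    # Walk modules from leaves to root so each trainable adapter Linear becomes its own
    # FSDP parameter group before the enclosing block (and finally the root) is sharded.
    for m in reversed(list(model.modules())):
        if isinstance(m, nn.Linear) and m.weight.requires_grad:
            fully_shard(m, **kwargs)
        elif isinstance(m, TinyBlock):
            fully_shard(m, **kwargs)
    fully_shard(model, **kwargs)
    return model


if __name__ == "__main__":
    dist.init_process_group("nccl")
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
    model = nn.Sequential(TinyBlock(64), TinyBlock(64), nn.Linear(64, 64)).cuda()
    for p in model.parameters():
        p.requires_grad = False                      # freeze base weights
    for block in model[:2]:
        block.adapter.weight.requires_grad = True    # train only the adapters
    shard_bottom_up(model, cpu_offload=False)
    out = model(torch.randn(8, 64, device="cuda"))
    print(out.shape)
    dist.destroy_process_group()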

if lora_weights_state_dict:
utils.load_from_full_model_state_dict(
lora_missing, lora_unexpected = utils.load_from_full_model_state_dict(
model, lora_weights_state_dict, self._device, self._is_rank_zero
)
else:
lora_missing, lora_unexpected = None, None

with utils.set_default_dtype(self._dtype), self._device:
lora_device = "cpu" if cfg_fsdp and cfg_fsdp.cpu_offload else self._device
for m in model.modules():
if isinstance(m, LoRALinear) and not lora_weights_state_dict:
# LoRA params may not be covered in the state dict
# when fine-tuning for the first time
m.lora_a.to_empty(device=self._device)
m.lora_b.to_empty(device=self._device)
m.lora_a.to_empty(device=lora_device)
m.lora_b.to_empty(device=lora_device)
m.initialize_parameters()
# RoPE is not covered in state dict
if isinstance(m, modules.RotaryPositionalEmbeddings):
m.reset_parameters()
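The loop above materializes meta-device LoRA and RoPE params before the real weights are loaded. Below is a minimal single-process sketch of that to_empty-then-initialize pattern, with TinyAdapter as a hypothetical stand-in for torchtune's LoRALinear; it is not the recipe code itself.

import torch
import torch.nn as nn


class TinyAdapter(nn.Module):
    def __init__(self, dim, rank):
        super().__init__()
        self.lora_a = nn.Linear(dim, rank, bias=False)
        self.lora_b = nn.Linear(rank, dim, bias=False)

    def initialize_parameters(self):
        nn.init.kaiming_uniform_(self.lora_a.weight)
        nn.init.zeros_(self.lora_b.weight)

    def forward(self, x):
        return self.lora_b(self.lora_a(x))


with torch.device("meta"):
    adapter = TinyAdapter(dim=16, rank=4)     # parameters exist but have no storage
assert adapter.lora_a.weight.is_meta

device = "cuda" if torch.cuda.is_available() else "cpu"
adapter.lora_a.to_empty(device=device)        # allocate real (uninitialized) storage
adapter.lora_b.to_empty(device=device)
adapter.initialize_parameters()               # then fill the fresh storage with real values

print(adapter(torch.randn(2, 16, device=device)).shape)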

utils.load_from_full_model_state_dict(
base_missing, base_unexpected = utils.load_from_full_model_state_dict(
model, base_model_state_dict, self._device, self._is_rank_zero
)

# LoRA hyper-params needed for merging weights while saving checkpoints
self._lora_rank = cfg_model.lora_rank
self._lora_alpha = cfg_model.lora_alpha

validate_missing_and_unexpected_for_lora(
lora_attn_modules=self._lora_attn_modules,
apply_lora_to_mlp=self._apply_lora_to_mlp,
apply_lora_to_output=self._apply_lora_to_output,
base_missing=base_missing,
base_unexpected=base_unexpected,
lora_missing=lora_missing,
lora_unexpected=lora_unexpected,
)
# Ensure no params and buffers are on meta device
utils.validate_no_params_on_meta_device(model)

@@ -590,6 +584,8 @@ def train(self) -> None:
logits = logits.transpose(1, 2)
# Compute loss
loss = self._loss_fn(logits, labels)
# free logits otherwise it peaks backward memory
del logits

loss = loss / self._gradient_accumulation_steps
running_loss += loss
45 changes: 43 additions & 2 deletions tests/torchtune/utils/test_distributed.py
@@ -16,9 +16,11 @@
from torch.distributed import launcher

from torch.distributed._composable.fsdp import fully_shard
from torch.distributed._tensor import DTensor
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from torch.testing._internal.common_fsdp import FSDPTest, MLP
from torchao.dtypes.nf4tensor import NF4Tensor
from torchtune import modules, utils
from torchtune.models.llama2._component_builders import llama2, lora_llama2
from torchtune.models.llama3._component_builders import llama3
@@ -237,7 +239,7 @@ def test_lora_meta_device_init_fsdp(self):
embed_dim=EMBED_DIM,
max_seq_len=MAX_SEQ_LEN,
lora_rank=4,
lora_alpha=1.0,
lora_alpha=8,
)
utils.prepare_model_for_fsdp_with_meta_device(lora)
for m in lora.modules():
@@ -257,7 +259,7 @@ def world_size(self) -> int:
return 2

@gpu_test(gpu_count=2)
def test_state_dict(self):
def test_lora_state_dict(self):
rank = self.rank
is_rank_zero = rank == 0
mlp_dim = 4
@@ -393,6 +395,45 @@ def test_state_dict(self):
for key, value in sharded_model_sd.items():
self.assertEqual(value, expected_sharded_model_sd[key])

@gpu_test(gpu_count=2)
def test_qlora_state_dict(self):
is_rank_zero = self.rank == 0
torch.manual_seed(42)
kwargs = {
"lora_attn_modules": ["q_proj", "v_proj", "k_proj", "output_proj"],
"apply_lora_to_mlp": True,
"apply_lora_to_output": False,
"vocab_size": 1024,
"num_layers": 3,
"num_heads": 4,
"num_kv_heads": 2,
"embed_dim": 1024,
"max_seq_len": 64,
"lora_rank": 4,
"lora_alpha": 1.0,
}
with torch.device("cuda"):
lora_kwargs = dict({"quantize_base": False}, **kwargs)
model_lora = lora_llama2(**lora_kwargs)
full_sd = model_lora.cpu().state_dict()
with torch.device("meta"):
qlora_kwargs = dict({"quantize_base": True}, **kwargs)
model_qlora = lora_llama2(**qlora_kwargs)
set_trainable_params(model_qlora, get_adapter_params(model_qlora))
for m in model_qlora.modules():
if isinstance(m, modules.TransformerDecoderLayer):
fully_shard(m)
fully_shard(model_qlora)
utils.load_from_full_model_state_dict(
model_qlora, full_sd, "cuda", is_rank_zero
)
# LoRALinear base weights should be DTensor(NF4)
for name, module in model_qlora.named_modules():
if isinstance(module, LoRALinear):
self.assertTrue(isinstance(module.weight, DTensor))
self.assertTrue(isinstance(module.weight._local_tensor, NF4Tensor))
self.assertEqual(module.weight.device.type, "cuda")
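A small single-process sketch of the storage format the new test asserts (DTensor wrapping NF4Tensor after sharding): under the assumption of torchao >= 0.2, which this PR gates on, to_nf4 quantizes a weight into 4-bit NormalFloat blocks and get_original_weight dequantizes it for a quick round-trip check. This is an illustration, not part of the test.

import torch
from torchao.dtypes.nf4tensor import NF4Tensor, to_nf4

weight = torch.randn(64, 64, dtype=torch.bfloat16)
nf4_weight = to_nf4(weight)                      # quantize to 4-bit NormalFloat blocks
assert isinstance(nf4_weight, NF4Tensor)

# Dequantize and look at the round-trip error
restored = nf4_weight.get_original_weight()
print((weight - restored).abs().mean())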

def _broadcast_full_state_dict(self, full_sd):
result = []
if torch.distributed.get_rank() == 0: