enable LoRA + FSDP2 #855

Status: Merged (58 commits, Jun 3, 2024)

Commits (all by weifengpy; the diff below shows changes from 2 commits):
e5826a1  enable LoRA + FSDP2 (Apr 24, 2024)
64fc870  reset params for lora weights and rope (Apr 24, 2024)
0cd21c6  support lora weights checkpoint and checkpoint utils (Apr 24, 2024)
589191e  fix lora meta device bug (Apr 24, 2024)
c801f26  save optim state dict (Apr 25, 2024)
19a2d70  mark TODO (Apr 25, 2024)
441da10  optimizer foreach=True for DTensor (Apr 25, 2024)
750b9e5  clip grad norm (Apr 25, 2024)
3d632d5  switch to ptd state dict api (Apr 26, 2024)
cb3abb3  add profiler (May 1, 2024)
e68804a  use torchao copy_ (May 1, 2024)
d6af9a2  enable saving checkpoint (May 1, 2024)
b616394  optimizer state dict: load on rank0 and broadcast (May 1, 2024)
a400497  import Optimizer (May 1, 2024)
e9de63c  resume training (May 3, 2024)
05d3895  prepare for full test (May 3, 2024)
7a5bb80  prepare for full test (May 3, 2024)
64bf49c  remove profiler (May 3, 2024)
cb1bba4  passed integration test (May 4, 2024)
ac516e9  remove uncesssary change (May 4, 2024)
bfde704  Merge branch 'main' into fsdp2 (May 4, 2024)
102db31  bring back state dict validation (May 4, 2024)
0b66651  align indent on comment (May 4, 2024)
672aabb  remove unused import (May 4, 2024)
6af2723  switch to ptd state dict and keep self implemented in record (May 8, 2024)
42ad99c  clean unused code (May 8, 2024)
74f6175  remove cuda value error (May 8, 2024)
f1b8a5e  comment on to_empty (May 8, 2024)
36e6829  fix memory issues by switching model state dict api (May 8, 2024)
08cd1fd  clean for review (May 8, 2024)
559bc4d  Merge branch 'main' into fsdp2 (May 8, 2024)
2333134  fix linter (May 9, 2024)
49a0364  fix checkpoint loading (May 9, 2024)
dc2ce02  expecttest CI depedency (May 9, 2024)
0a604aa  ci depdencecy (May 9, 2024)
fa83140  fix CI issue (May 10, 2024)
4b5a895  Merge branch 'pytorch:main' into fsdp2 (May 10, 2024)
a2e34ec  support resuming training (May 14, 2024)
6142031  update docstring (May 14, 2024)
7607e14  remove depdency on broadcast_from_rank0 (May 14, 2024)
1899beb  remove the need for model.to(device) (May 15, 2024)
c1cfabb  wrap lora and TransformerBlock (May 17, 2024)
d7382ae  require torch version 2.4.0 (May 17, 2024)
d1ff53b  FSDP(CheckpointWrapper(model)) (May 22, 2024)
1eb9e87  remove model.to() (May 29, 2024)
695e959  add docstrings and remove depdency on dcp (May 31, 2024)
e10f638  remove try...catch FSDPModule (Jun 1, 2024)
b1e3d30  Merge branch 'main' into fsdp2 (Jun 1, 2024)
944a723  fsdp2 as dev recipe (Jun 1, 2024)
ac5f7aa  restore lora_finetune_distributed (Jun 1, 2024)
d769626  test cudnn ci error (Jun 2, 2024)
f90c3cc  test CI error (Jun 3, 2024)
42ef49a  address CI error for setting seed (Jun 3, 2024)
170de94  add back pytest (Jun 3, 2024)
f8a7018  add expecttest (Jun 3, 2024)
a3b2f3e  pytest 7.4.0 (Jun 3, 2024)
1a692b3  add dev/recipe (Jun 3, 2024)
8fbbc4b  update yaml with lora_finetune_fsdp2 (Jun 3, 2024)
recipes/lora_finetune_distributed.py (163 changes: 94 additions, 69 deletions)

@@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import itertools
import os
import sys
import time
@@ -17,20 +18,16 @@

from torch import nn
from torch.distributed import destroy_process_group, init_process_group
from torch.distributed.fsdp import (
FullOptimStateDictConfig,
FullStateDictConfig,
FullyShardedDataParallel as FSDP,
StateDictType,
)
from torch.distributed._composable.fsdp import FSDP, fully_shard
from torch.distributed._tensor import distribute_tensor
from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler
from torchtune import config, modules, utils
from torchtune.modules.peft.peft_utils import (
get_adapter_params,
get_merged_lora_ckpt,
set_trainable_params,
validate_state_dict_for_lora,
# validate_state_dict_for_lora,
)
from torchtune.recipe_interfaces import FTRecipeInterface

@@ -277,45 +274,98 @@ def _setup_model(
the correct device.
"""

if self._is_rank_zero:
log.info("FSDP is enabled. Instantiating Model on CPU for Rank 0 ...")
init_start = time.perf_counter()
if self._device.type != "cuda":
raise ValueError(
f'FSDP needs device="cuda" but found device={self._device.type}'
)

with utils.set_default_dtype(self._dtype):
model = config.instantiate(cfg_model)
with utils.set_default_dtype(self._dtype), torch.device("meta"):
Contributor:

Sorry not able to comment above, but the docstring of this function should be updated since we're no longer initializing on CPU?

Contributor Author:

The docstring used to say "Instantiating Model on CPU" (left) and I removed the mention of CPU. I did not mention meta device because the timing now measures meta init + checkpoint loading. Happy to improve if you are referring to this docstring.

Contributor Author:
Oh, just got your point. Updated the docstring for _setup_model.

model = config.instantiate(cfg_model)

log.info(
f"Model instantiation took {time.perf_counter() - init_start:.2f} secs"
if enable_activation_checkpointing:
utils.set_activation_checkpointing(
model, auto_wrap_policy={modules.TransformerDecoderLayer}
)

Contributor Author:
if isinstance(m, modules.TransformerDecoderLayer): is the equivalent of auto_wrap_policy in FSDP1.

for m in model.modules():
if isinstance(m, modules.TransformerDecoderLayer):
fully_shard(m)
fully_shard(model)
Contributor:
Sorry for the noob question, but can you help me understand what's going on here? Why do I need to fully_shard the TransformerDecoderLayer and then call fully_shard on the model?

An unrelated question: if I have enough GPU memory, should I be thinking about using something similar to SHARD_GRAD_OP with FSDP2?

Contributor Author:

In FSDP1, we wrap each TransformerDecoderLayer and then the root model as well. It's blackboxed in auto_wrap_policy=utils.lora_fsdp_wrap_policy(modules_to_wrap={modules.TransformerDecoderLayer}).

In FSDP2, we un-blackboxed it into this for-loop. If you prefer, this can be factored into a util function in torchtune so users can call util.fully_shard(model, modules_to_wrap).

Personally I have a bias towards the un-blackboxed approach, since people can modify the for-loop to achieve different wrapping.
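
For illustration, a minimal sketch of what such a torchtune utility could look like; the helper name shard_model, its signature, and the fsdp_kwargs pass-through are assumptions, not part of this PR:

```python
# Hypothetical helper mirroring the explicit for-loop in the recipe:
# apply FSDP2's fully_shard to each selected submodule, then to the root model.
from typing import Set, Type

import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard


def shard_model(
    model: nn.Module,
    modules_to_wrap: Set[Type[nn.Module]],
    **fsdp_kwargs,
) -> None:
    for m in model.modules():
        if isinstance(m, tuple(modules_to_wrap)):
            fully_shard(m, **fsdp_kwargs)
    # shard whatever remains (e.g. embeddings, output projection) at the root
    fully_shard(model, **fsdp_kwargs)
```

Usage would then be shard_model(model, {modules.TransformerDecoderLayer}), and options such as the reshard_after_forward=False behavior discussed next could be passed through fsdp_kwargs.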

Contributor Author:

The equivalent of SHARD_GRAD_OP in FSDP2 is reshard_after_forward=False. Do you want it as a config in the .yaml?

fully_shard(model, reshard_after_forward=False)

Contributor:

thanks for the explanation! I love the un-blackboxed approach here - just needs more documentation and explanation :) After reading the FSDP2 RFC, this became a lot clearer.

meta_sharded_sd = model.state_dict()
sharded_sd = {}

if self._is_rank_zero:

# The model contains LoRA params which won't have any matching keys in
# the state dict. As a result, we need to load with strict=False.
# Before loading the state dict, ensure the state dict keys for the base
# model and adapters (if available) match the keys in the full LoRA model
# This is a good sanity check to prevent silent errors
validate_state_dict_for_lora(
lora_attn_modules=cfg_model.lora_attn_modules,
apply_lora_to_mlp=cfg_model.apply_lora_to_mlp,
apply_lora_to_output=getattr(cfg_model, "apply_lora_to_output", False),
full_model_state_dict_keys=model.state_dict().keys(),
lora_state_dict_keys=(
lora_weights_state_dict.keys()
if lora_weights_state_dict is not None
else None
),
base_model_state_dict_keys=base_model_state_dict.keys(),
)
# validate_state_dict_for_lora(
# lora_attn_modules=cfg_model.lora_attn_modules,
# apply_lora_to_mlp=cfg_model.apply_lora_to_mlp,
# apply_lora_to_output=getattr(cfg_model, "apply_lora_to_output", False),
# full_model_state_dict_keys=model.state_dict().keys(),
# lora_state_dict_keys=(
# lora_weights_state_dict.keys()
# if lora_weights_state_dict is not None
# else None
# ),
# base_model_state_dict_keys=base_model_state_dict.keys(),
# )

# Load both the base model weights and (if available) the adapter weights. Both
# of this should happen only on Rank 0
model.load_state_dict(base_model_state_dict, strict=False)
if lora_weights_state_dict:
model.load_state_dict(lora_weights_state_dict, strict=False)

log.info("FSDP is enabled. Loading checkpoints for Rank 0 ...")
init_start = time.perf_counter()
for param_name, full_param in base_model_state_dict.items():
sharded_meta_param = meta_sharded_sd.get(param_name)
full_param = full_param.detach().to(self._device)
mesh = sharded_meta_param.device_mesh
torch.distributed.broadcast(full_param, src=0, group=mesh.get_group(0))
sharded_tensor = distribute_tensor(
full_param, mesh, sharded_meta_param.placements
)
sharded_sd[param_name] = nn.Parameter(sharded_tensor)
log.info(
f"Loading checkpoints took {time.perf_counter() - init_start:.2f} secs"
)
# TODO: lora_weights_state_dict
# if lora_weights_state_dict:
# model.load_state_dict(lora_weights_state_dict, strict=False)
else:
# For non-zero ranks, load the model on meta device
with utils.set_default_dtype(self._dtype), torch.device("meta"):
model = config.instantiate(cfg_model)
for param_name, full_param in base_model_state_dict.items():
sharded_meta_param = meta_sharded_sd.get(param_name)
full_tensor = torch.empty(
sharded_meta_param.size(),
device=self._device,
dtype=sharded_meta_param.dtype,
)
mesh = sharded_meta_param.device_mesh
torch.distributed.broadcast(full_tensor, src=0, group=mesh.get_group(0))
sharded_tensor = distribute_tensor(
full_tensor, mesh, sharded_meta_param.placements
)
sharded_sd[param_name] = nn.Parameter(sharded_tensor)

model.load_state_dict(sharded_sd, strict=False, assign=True)

Contributor Author:

Pros and cons of meta init: the pro is a 4.5x speed-up during model init and thus shorter TTFB. The con is that the user needs to call initialize_parameters on LoRALinear explicitly to move it from meta to GPU.
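
As a self-contained sketch of the trade-off described above (a plain nn.Linear is used purely for illustration, and a CUDA device is assumed): tensors created under the meta device have no storage, so they must be explicitly materialized and re-initialized before use.

```python
import torch
import torch.nn as nn

# Construct on the meta device: no real memory is allocated, so init is fast.
with torch.device("meta"):
    layer = nn.Linear(16, 16)
assert layer.weight.is_meta

# Materialize: allocate (uninitialized) storage on the target device, then
# re-run the module's own init to fill it with real values.
layer.to_empty(device="cuda")
layer.reset_parameters()
```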

Contributor:

Is this because these params are not being loaded from checkpoint? Or do I misunderstand?

If this is indeed the reason, how do we handle this code block when the LoRA params are being loaded from checkpoint (eg: when resuming training)?

Contributor Author:

You are right. When finetuning from an original HF checkpoint, lora_weights_state_dict = None.

For resuming training, lora_weights_state_dict is not None and we avoid calling m.initialize_parameters() again.

Contributor:

Got you, thanks so much for the explanation! I think something that would be super helpful would be to document here, in the form of comments, the relationship between:

  • the modules on which we call fully_shard
  • init on meta device
  • calling initialize_parameters and reset_parameters

Also I think there was a technical reason with FSDP1 to call the function reset_parameters. Is that still true? Or can we standardize this with initialize_parameters in the modules code? Happy to chat about this offline!

Contributor Author:

Good point! I will add a comment to explain fully_shard, meta init, and reset/initialize_parameters.

FSDP1 calls reset_parameters for the exact same reason FSDP2 calls reset/initialize_parameters: the RoPE buffers are not covered in checkpoints, and lora_a and lora_b are not covered in checkpoints when resume_training=False.

It's just that FSDP1 has a contract to call the overridden nn.Module.reset_parameters through FSDP(model, param_init_fn=...), but FSDP2 does not impose overriding reset_parameters. Now users can name it reset_parameters or initialize_parameters.
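
To make that contract concrete, here is a toy sketch (not the torchtune implementation) of what the materialization loop below expects from a module whose state never appears in a checkpoint: after to_empty() moves it off the meta device, reset_parameters() must be able to rebuild its tensors from scratch.

```python
import torch
import torch.nn as nn


class ToyRoPE(nn.Module):
    """Toy rotary-embedding-like module: theta is a non-persistent buffer,
    so it is never saved to or loaded from a checkpoint."""

    def __init__(self, dim: int, base: int = 10_000) -> None:
        super().__init__()
        self.dim = dim
        self.base = base
        self.register_buffer("theta", self._compute_theta(), persistent=False)

    def _compute_theta(self) -> torch.Tensor:
        return 1.0 / (
            self.base ** (torch.arange(0, self.dim, 2).float() / self.dim)
        )

    def reset_parameters(self) -> None:
        # Called by the recipe after to_empty(): recompute the buffer in place.
        self.theta.copy_(self._compute_theta())
```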

# LoRALinear weights are still on meta if lora_weights_state_dict is None
# RotaryPositionalEmbeddings.theta is a buffer and does not exist in checkpoints
for m in model.modules():
if isinstance(m, FSDP):
continue
param_is_meta = [
x.is_meta
for x in itertools.chain(
m.parameters(recurse=False), m.buffers(recurse=False)
)
]
if len(param_is_meta) > 0 and all(param_is_meta):
m.to_empty(device=self._device)
if not hasattr(m, "reset_parameters"):
raise ValueError(f"Need to implement reset_parameters in {m}")
m.reset_parameters()

if self._dtype == torch.bfloat16:
model = model.to(torch.bfloat16)
@@ -328,34 +378,9 @@ def _setup_model(
self.adapter_params = get_adapter_params(model)
set_trainable_params(model, self.adapter_params)

model = FSDP(
module=model,
auto_wrap_policy=utils.lora_fsdp_wrap_policy(
modules_to_wrap={modules.TransformerDecoderLayer}
),
sharding_strategy=torch.distributed.fsdp.ShardingStrategy.FULL_SHARD,
device_id=self._device,
# this recipe does not currently support mixed precision training
mixed_precision=None,
# Ensure we broadcast params and buffers from rank 0
sync_module_states=True,
# Initialize empty modules on all non-zero ranks
param_init_fn=(
lambda module: module.to_empty(
device=torch.device("cuda"), recurse=False
)
if not self._is_rank_zero
else None
),
)

# Ensure no params and buffers are on meta device
utils.validate_no_params_on_meta_device(model)

if enable_activation_checkpointing:
utils.set_activation_checkpointing(
model, auto_wrap_policy={modules.TransformerDecoderLayer}
)
if self._is_rank_zero:
memory_stats = utils.get_memory_stats(device=self._device)
utils.log_memory_stats(memory_stats)
@@ -451,22 +476,22 @@ def save_checkpoint(
intermediate_checkpoint = epoch + 1 < self.total_epochs
# To prevent GPU memory from spiking during checkpoint save,
# we consolidate the full model and optim state dicts on CPU for rank 0
with FSDP.state_dict_type(
self._model,
StateDictType.FULL_STATE_DICT,
FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True),
):
cpu_state_dict = self._model.state_dict()
if intermediate_checkpoint:
opt_state_dict = FSDP.optim_state_dict(self._model, self._optimizer)

sharded_sd = self._model.state_dict()
cpu_state_dict = {}
for param_name, sharded_param in sharded_sd.items():
full_param = sharded_param.full_tensor()
if self._is_rank_zero:
cpu_state_dict[param_name] = full_param.cpu()
else:
opt_state_dict = None
del full_param

# TODO: implement optim state dict
opt_state_dict = None

# Now that we have the model and opt state dict, create the actual checkpoint dict
# to be sent to the checkpointer and ultimately written to file
if self._is_rank_zero:

# Filter out the adapter keys and weights from the model state dict. These will
# be saved separately
adapter_key_filter = lambda x: x in self.adapter_params
torchtune/modules/peft/lora.py (3 changes: 3 additions, 0 deletions)

@@ -84,6 +84,9 @@ def initialize_parameters(self):
_lora_a_init_params(self.lora_a)
_lora_b_init_params(self.lora_b)

def reset_parameters(self):
self.initialize_parameters()

def _create_weight_and_bias(self):
"""
Creates a linear weight and bias tensor, using NF4 dtype if we're quantizing
torchtune/modules/position_embeddings.py (11 changes: 6 additions, 5 deletions)

@@ -42,20 +42,21 @@ def __init__(
self.dim = dim
self.base = base
self.max_seq_len = max_seq_len
self._rope_init()
theta = self.get_theta()
self.register_buffer("theta", theta, persistent=False)
self.build_rope_cache(self.max_seq_len)

# We need to explicitly define reset_parameters for FSDP initialization, see
# https://github.com/pytorch/pytorch/blob/797d4fbdf423dd9320ebe383fb57ffb1135c4a99/torch/distributed/fsdp/_init_utils.py#L885
def reset_parameters(self):
self._rope_init()
self.theta.copy_(self.get_theta())

def _rope_init(self):
def get_theta(self):
theta = 1.0 / (
self.base
** (torch.arange(0, self.dim, 2)[: (self.dim // 2)].float() / self.dim)
)
self.register_buffer("theta", theta, persistent=False)
self.build_rope_cache(self.max_seq_len)
return theta

def build_rope_cache(self, max_seq_len: int = 4096) -> None:
# Create position indexes `[0, 1, ..., max_seq_len - 1]`
Expand Down