enable LoRA + FSDP2 #855
Changes from 37 commits
```diff
@@ -17,16 +17,18 @@

 from torch import nn
 from torch.distributed import destroy_process_group, init_process_group
-from torch.distributed.fsdp import (
-    FullOptimStateDictConfig,
-    FullStateDictConfig,
-    FullyShardedDataParallel as FSDP,
-    StateDictType,
+from torch.distributed._composable.fsdp import fully_shard
+from torch.distributed.checkpoint.state_dict import (
+    get_optimizer_state_dict,
+    set_optimizer_state_dict,
+    StateDictOptions,
 )

 from torch.optim import Optimizer
 from torch.utils.data import DataLoader, DistributedSampler
 from torchtune import config, modules, utils
 from torchtune.datasets import ConcatDataset
 from torchtune.modules.peft import LoRALinear
 from torchtune.modules.peft.peft_utils import (
     get_adapter_params,
     get_merged_lora_ckpt,
```
```diff
@@ -279,16 +281,13 @@ def _setup_model(
         """

         if self._is_rank_zero:
-            log.info("FSDP is enabled. Instantiating Model on CPU for Rank 0 ...")
+            log.info("FSDP is enabled. Model init and checkpoint loading on Rank 0 ...")
             init_start = time.perf_counter()

-            with utils.set_default_dtype(self._dtype):
-                model = config.instantiate(cfg_model)
-
-            log.info(
-                f"Model instantiation took {time.perf_counter() - init_start:.2f} secs"
-            )
+        with utils.set_default_dtype(self._dtype), torch.device("meta"):
```
**Reviewer:** Sorry, not able to comment above, but the docstring of this function should be updated since we're no longer initializing on CPU?

**Author:** The docstring used to be …

**Author:** Oh, just got your point. Updated the docstring for …
```diff
+            model = config.instantiate(cfg_model)

         if self._is_rank_zero:
             # The model contains LoRA params which won't have any matching keys in
             # the state dict. As a result, we need to load with strict=False.
             # Before loading the state dict, ensure the state dict keys for the base
```
```diff
@@ -307,16 +306,36 @@ def _setup_model(
                 base_model_state_dict_keys=base_model_state_dict.keys(),
             )

-            # Load both the base model weights and (if available) the adapter weights. Both
-            # of this should happen only on Rank 0
-            model.load_state_dict(base_model_state_dict, strict=False)
-            if lora_weights_state_dict:
-                model.load_state_dict(lora_weights_state_dict, strict=False)
+        self.adapter_params = get_adapter_params(model)
+        set_trainable_params(model, self.adapter_params)

-        else:
-            # For non-zero ranks, load the model on meta device
-            with utils.set_default_dtype(self._dtype), torch.device("meta"):
-                model = config.instantiate(cfg_model)
+        if enable_activation_checkpointing:
+            utils.set_activation_checkpointing(
+                model, auto_wrap_policy={modules.TransformerDecoderLayer}
+            )

+        for m in model.modules():
+            if isinstance(m, modules.TransformerDecoderLayer):
+                fully_shard(m)
+        fully_shard(model)
```
**Reviewer:** Sorry for the noob question, but can you help me understand what's going on here? Why do I need to call `fully_shard` on each layer and then again on the full model? An unrelated question: if I have enough GPU memory, should I be thinking about using something similar to `SHARD_GRAD_OP`?

**Author:** In FSDP1, we wrap each `TransformerDecoderLayer` (and then the root model) through the `auto_wrap_policy`, which is a black box. In FSDP2, we un-blackboxed it into this for-loop. If you prefer, this can be factored into a util function in torchtune so users call it directly. Personally I'm biased towards the un-blackboxed approach, since people can modify the for-loop to achieve different wrapping.

**Author:** The equivalent of `SHARD_GRAD_OP` in FSDP2 is … (see the sketch below).

**Reviewer:** Thanks for the explanation! I love the un-blackboxed approach here; it just needs more documentation and explanation :) After reading the FSDP2 RFC, this became a lot clearer.
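For readers following along, here is a minimal sketch of the wrapping loop discussed above. The helper name is made up, and using `reshard_after_forward=False` as a `SHARD_GRAD_OP`-like trade-off is an assumption about the FSDP2 API, not something taken from this diff.

```python
# Illustrative sketch, not recipe code: explicit FSDP2 wrapping of a torchtune model.
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard

from torchtune import modules


def shard_model(model: nn.Module, reshard_after_forward: bool = True) -> nn.Module:
    # Shard each transformer block as its own FSDP2 unit...
    for m in model.modules():
        if isinstance(m, modules.TransformerDecoderLayer):
            fully_shard(m, reshard_after_forward=reshard_after_forward)
    # ...then shard the root so the remaining params (embeddings, output
    # projection, final norm) form the last unit.
    fully_shard(model)
    return model
```

Passing `reshard_after_forward=False` keeps each unit's parameters unsharded between forward and backward, trading memory for fewer all-gathers, which is the closest analogue to FSDP1's `SHARD_GRAD_OP`.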
```diff

+        utils.load_from_full_model_state_dict(
+            model, base_model_state_dict, self._device, self._is_rank_zero
+        )
+        if lora_weights_state_dict:
+            utils.load_from_full_model_state_dict(
+                model, lora_weights_state_dict, self._device, self._is_rank_zero
+            )
```
**Author:** Pros and cons of meta init: the pro is a ~4.5x speed-up during model init and thus a shorter TTFB. The con is that the user needs to call the re-initialization below for params that never appear in the checkpoint.

**Reviewer:** Is this because these params are not being loaded from the checkpoint? Or do I misunderstand? If this is indeed the reason, how do we handle this code block when the LoRA params are being loaded from a checkpoint (e.g. when resuming training)?

**Author:** You are right. When finetuning from an original HF checkpoint, the LoRA params are not in the state dict and are initialized here. For resuming training, `lora_weights_state_dict` is loaded above, so the `not lora_weights_state_dict` check skips re-initialization.

**Reviewer:** Got you, thanks so much for the explanation! I think something that would be super helpful would be to document here, in the form of comments, the relationship between these cases. Also I think there was a technical reason with FSDP1 to call the function …

**Author:** Good point! Will add a comment to explain what FSDP1 calls … It's just that FSDP1 has a contract to call the overridden `reset_parameters` …
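As a companion to this thread, a condensed sketch of the meta-init flow, assuming a model that was instantiated on the meta device and already sharded; the function name is illustrative and the logic mirrors the diff below, with the comments the reviewers asked for.

```python
# Illustrative sketch (not the recipe verbatim) of re-initializing the params
# that are never present in the checkpoint after meta-device instantiation.
import torch

from torchtune import modules, utils
from torchtune.modules.peft import LoRALinear


def init_params_missing_from_checkpoint(
    model: torch.nn.Module,
    device: torch.device,
    dtype: torch.dtype,
    lora_weights_loaded: bool,
) -> None:
    with utils.set_default_dtype(dtype), device:
        for m in model.modules():
            # LoRA A/B matrices are not in the base HF checkpoint, so unless an
            # adapter checkpoint was loaded (resume), allocate and init them here.
            if isinstance(m, LoRALinear) and not lora_weights_loaded:
                m.lora_a.to_empty(device=device)
                m.lora_b.to_empty(device=device)
                m.initialize_parameters()
            # RoPE caches are buffers that are never saved to the state dict,
            # so they always need to be rebuilt after meta init.
            if isinstance(m, modules.RotaryPositionalEmbeddings):
                m.reset_parameters()
```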
```diff

+        with utils.set_default_dtype(self._dtype), self._device:
+            for m in model.modules():
+                if isinstance(m, LoRALinear) and not lora_weights_state_dict:
+                    m.lora_a.to_empty(device=self._device)
+                    m.lora_b.to_empty(device=self._device)
+                    m.initialize_parameters()
+                if isinstance(m, modules.RotaryPositionalEmbeddings):
```
**Reviewer:** Just to clarify, we special-case RoPE because the buffer is not being loaded from the state dict, right?

**Author:** That's correct.

**Reviewer:** Similar comment here: let's document what's happening so that users can easily understand why we initialize these modules separately.
```diff
+                    m.reset_parameters()
+        model = model.to(self._device)

         if self._dtype == torch.bfloat16:
             model = model.to(torch.bfloat16)
```
```diff
@@ -325,39 +344,13 @@ def _setup_model(
         self._lora_rank = cfg_model.lora_rank
         self._lora_alpha = cfg_model.lora_alpha

-        # Note: this needs to be set before wrapping with FSDP
-        self.adapter_params = get_adapter_params(model)
-        set_trainable_params(model, self.adapter_params)
-
-        model = FSDP(
-            module=model,
-            auto_wrap_policy=utils.lora_fsdp_wrap_policy(
-                modules_to_wrap={modules.TransformerDecoderLayer}
-            ),
-            sharding_strategy=torch.distributed.fsdp.ShardingStrategy.FULL_SHARD,
-            device_id=self._device,
-            # this recipe does not currently support mixed precision training
-            mixed_precision=None,
-            # Ensure we broadcast params and buffers from rank 0
-            sync_module_states=True,
-            # Initialize empty modules on all non-zero ranks
-            param_init_fn=(
-                lambda module: module.to_empty(
-                    device=torch.device("cuda"), recurse=False
-                )
-                if not self._is_rank_zero
-                else None
-            ),
-        )
-
-        # Ensure no params and buffers are on meta device
-        utils.validate_no_params_on_meta_device(model)
-
-        if enable_activation_checkpointing:
-            utils.set_activation_checkpointing(
-                model, auto_wrap_policy={modules.TransformerDecoderLayer}
-            )
         if self._is_rank_zero:
+            log.info(
+                f"Model init and checkpoint loading took {time.perf_counter() - init_start:.2f} secs"
+            )
             memory_stats = utils.get_memory_stats(device=self._device)
             utils.log_memory_stats(memory_stats)
```
```diff
@@ -371,12 +364,14 @@ def _setup_optimizer(
     ) -> Optimizer:
         optimizer = config.instantiate(cfg_optimizer, self._model.parameters())
         if opt_state_dict:
-            # Note: technically we should check _contains_fsdp for
-            # just the state dict of the adapter cfg, but should be equivalent
-            opt_state_dict = utils.transform_opt_state_dict(
-                opt_state_dict, self._model, optimizer
+            set_optimizer_state_dict(
+                self._model,
+                optimizer,
+                optim_state_dict=opt_state_dict,
+                options=StateDictOptions(
+                    broadcast_from_rank0=True, full_state_dict=True
+                ),
```
**Reviewer:** Can you help explain what this is doing?

**Author:** For both FSDP1 and FSDP2, parameters are sharded 1/N per GPU, so each rank's optimizer only sees its 1/N portion of the parameters. We start with the full optimizer state dict on rank 0, and each rank should load only its 1/N of it. In FSDP2, `set_optimizer_state_dict` with `broadcast_from_rank0=True` and `full_state_dict=True` broadcasts from rank 0 and reshards onto each rank (see the sketch after this hunk).
```diff
             )
-            optimizer.load_state_dict(opt_state_dict)

         if self._is_rank_zero:
             log.info("Optimizer and loss are initialized.")
```
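To make the thread above concrete, here is a sketch of the full-state-dict round trip across `save_checkpoint` and `_setup_optimizer`, assuming a sharded model and its optimizer; the wrapper functions are illustrative only, not part of the recipe.

```python
# Illustrative sketch of the save/load round trip for optimizer state with FSDP2.
import torch
from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    get_optimizer_state_dict,
    set_optimizer_state_dict,
)


def save_full_optim_state(
    model: torch.nn.Module, optimizer: torch.optim.Optimizer
) -> dict:
    # Gather the sharded optimizer state into a full, CPU-offloaded state dict
    # (materialized on rank 0), as save_checkpoint does.
    return get_optimizer_state_dict(
        model,
        optimizer,
        options=StateDictOptions(full_state_dict=True, cpu_offload=True),
    )


def load_full_optim_state(
    model: torch.nn.Module, optimizer: torch.optim.Optimizer, full_osd: dict
) -> None:
    # Broadcast the rank-0 full state dict and let each rank load only its own
    # shard, replacing the FSDP1-era transform_opt_state_dict + load_state_dict.
    set_optimizer_state_dict(
        model,
        optimizer,
        optim_state_dict=full_osd,
        options=StateDictOptions(broadcast_from_rank0=True, full_state_dict=True),
    )
```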
```diff
@@ -461,17 +456,19 @@ def save_checkpoint(
         intermediate_checkpoint = epoch + 1 < self.total_epochs
         # To prevent GPU memory from spiking during checkpoint save,
         # we consolidate the full model and optim state dicts on CPU for rank 0
-        with FSDP.state_dict_type(
+        cpu_state_dict = utils.get_full_model_state_dict(
             self._model,
-            StateDictType.FULL_STATE_DICT,
-            FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
-            FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True),
-        ):
-            cpu_state_dict = self._model.state_dict()
-            if intermediate_checkpoint:
-                opt_state_dict = FSDP.optim_state_dict(self._model, self._optimizer)
-            else:
-                opt_state_dict = None
+            self._is_rank_zero,
+        )

+        if intermediate_checkpoint:
+            opt_state_dict = get_optimizer_state_dict(
+                self._model,
+                self._optimizer,
+                options=StateDictOptions(full_state_dict=True, cpu_offload=True),
+            )
+        else:
+            opt_state_dict = None

         # Now that we have the model and opt state dict, create the actual checkpoint dict
         # to be sent to the checkpointer and ultimately written to file
```
**Author:** `from torch.testing._internal.common_utils import run_tests` has a dependency on `pytest==7.4.0` and `expecttest`, borrowed from the pytorch repo.

**Reviewer:** Is `run_tests` strictly required for the usage of `FSDPTest`, or is it more used for convenience? (Either way, not a huge issue.)

**Author:** It's strictly required for the usage of `FSDPTest`.
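For context, a minimal sketch of an `FSDPTest`-based test driven by `run_tests`, assuming at least two GPUs; the class and test names are made up for illustration and this is not the test added in the PR.

```python
# Illustrative sketch of the FSDPTest + run_tests pattern discussed above.
import torch
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_utils import run_tests


class TestFullyShardSketch(FSDPTest):
    # FSDPTest spawns `world_size` processes per test and initializes the
    # process group in each of them.
    @property
    def world_size(self) -> int:
        return 2

    @skip_if_lt_x_gpu(2)
    def test_sharded_forward_matches_reference(self):
        torch.cuda.set_device(self.rank)
        torch.manual_seed(0)
        # Build a sharded model and an identical unsharded reference.
        model = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 16)).cuda()
        ref = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 16)).cuda()
        ref.load_state_dict(model.state_dict())

        for layer in model:
            fully_shard(layer)
        fully_shard(model)

        x = torch.randn(4, 16, device="cuda")
        torch.testing.assert_close(model(x), ref(x))


# run_tests() is the entry point FSDPTest expects, which is why the thread
# above calls it strictly required.
if __name__ == "__main__":
    run_tests()
```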