
enable LoRA + FSDP2 #855

Merged
58 commits merged on Jun 3, 2024
Changes from 37 commits

Commits (58)
e5826a1
enable LoRA + FSDP2
weifengpy Apr 24, 2024
64fc870
reset params for lora weights and rope
weifengpy Apr 24, 2024
0cd21c6
support lora weights checkpoint and checkpoint utils
weifengpy Apr 24, 2024
589191e
fix lora meta device bug
weifengpy Apr 24, 2024
c801f26
save optim state dict
weifengpy Apr 25, 2024
19a2d70
mark TODO
weifengpy Apr 25, 2024
441da10
optimizer foreach=True for DTensor
weifengpy Apr 25, 2024
750b9e5
clip grad norm
weifengpy Apr 25, 2024
3d632d5
switch to ptd state dict api
weifengpy Apr 26, 2024
cb3abb3
add profiler
weifengpy May 1, 2024
e68804a
use torchao copy_
weifengpy May 1, 2024
d6af9a2
enable saving checkpoint
weifengpy May 1, 2024
b616394
optimizer state dict: load on rank0 and broadcast
weifengpy May 1, 2024
a400497
import Optimizer
weifengpy May 1, 2024
e9de63c
resume training
weifengpy May 3, 2024
05d3895
prepare for full test
weifengpy May 3, 2024
7a5bb80
prepare for full test
weifengpy May 3, 2024
64bf49c
remove profiler
weifengpy May 3, 2024
cb1bba4
passed integration test
weifengpy May 4, 2024
ac516e9
remove uncesssary change
weifengpy May 4, 2024
bfde704
Merge branch 'main' into fsdp2
weifengpy May 4, 2024
102db31
bring back state dict validation
weifengpy May 4, 2024
0b66651
align indent on comment
weifengpy May 4, 2024
672aabb
remove unused import
weifengpy May 4, 2024
6af2723
switch to ptd state dict and keep self implemented in record
weifengpy May 8, 2024
42ad99c
clean unused code
weifengpy May 8, 2024
74f6175
remove cuda value error
weifengpy May 8, 2024
f1b8a5e
comment on to_empty
weifengpy May 8, 2024
36e6829
fix memory issues by switching model state dict api
weifengpy May 8, 2024
08cd1fd
clean for review
weifengpy May 8, 2024
559bc4d
Merge branch 'main' into fsdp2
weifengpy May 8, 2024
2333134
fix linter
weifengpy May 9, 2024
49a0364
fix checkpoint loading
weifengpy May 9, 2024
dc2ce02
expecttest CI depedency
weifengpy May 9, 2024
0a604aa
ci depdencecy
weifengpy May 9, 2024
fa83140
fix CI issue
weifengpy May 10, 2024
4b5a895
Merge branch 'pytorch:main' into fsdp2
weifengpy May 10, 2024
a2e34ec
support resuming training
weifengpy May 14, 2024
6142031
update docstring
weifengpy May 14, 2024
7607e14
remove depdency on broadcast_from_rank0
weifengpy May 14, 2024
1899beb
remove the need for model.to(device)
weifengpy May 15, 2024
c1cfabb
wrap lora and TransformerBlock
weifengpy May 17, 2024
d7382ae
require torch version 2.4.0
weifengpy May 17, 2024
d1ff53b
FSDP(CheckpointWrapper(model))
weifengpy May 22, 2024
1eb9e87
remove model.to()
weifengpy May 29, 2024
695e959
add docstrings and remove depdency on dcp
weifengpy May 31, 2024
e10f638
remove try...catch FSDPModule
weifengpy Jun 1, 2024
b1e3d30
Merge branch 'main' into fsdp2
weifengpy Jun 1, 2024
944a723
fsdp2 as dev recipe
weifengpy Jun 1, 2024
ac5f7aa
restore lora_finetune_distributed
weifengpy Jun 1, 2024
d769626
test cudnn ci error
weifengpy Jun 2, 2024
f90c3cc
test CI error
weifengpy Jun 3, 2024
42ef49a
address CI error for setting seed
weifengpy Jun 3, 2024
170de94
add back pytest
weifengpy Jun 3, 2024
f8a7018
add expecttest
weifengpy Jun 3, 2024
a3b2f3e
pytest 7.4.0
weifengpy Jun 3, 2024
1a692b3
add dev/recipe
weifengpy Jun 3, 2024
8fbbc4b
update yaml with lora_finetune_fsdp2
weifengpy Jun 3, 2024
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -41,12 +41,13 @@ tune = "torchtune._cli.tune:main"
dev = [
"bitsandbytes>=0.43.0",
"pre-commit",
"pytest",
"pytest==7.4.0",
Contributor Author:

from torch.testing._internal.common_utils import run_tests has a dependency on pytest==7.4.0 and expecttest, borrowed from the pytorch repo

Contributor:

Is run_tests strictly required for the usage of FSDPTest, or is it more used for convenience? (Either way not a huge issue)

Contributor Author:

it's strictly required for the usage of FSDPTest
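For context on this pin, a minimal sketch of the FSDPTest/run_tests pattern being referenced (illustrative test, not torchtune code; class and test names are made up, and it assumes 2 GPUs):

```python
# Illustrative sketch only: FSDPTest spawns one process per rank, and the
# suite is driven by run_tests(), which is what pulls in pytest==7.4.0 and
# expecttest as dev dependencies.
import torch
import torch.distributed as dist
from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_utils import run_tests


class SketchDistributedTest(FSDPTest):
    @property
    def world_size(self) -> int:
        return 2

    def test_all_reduce(self) -> None:
        # Each rank contributes 1; after all_reduce the value equals world_size.
        t = torch.ones(1, device=f"cuda:{self.rank}")
        dist.all_reduce(t)
        assert t.item() == self.world_size


if __name__ == "__main__":
    run_tests()
```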

"pytest-cov",
"pytest-mock",
"pytest-integration",
"tensorboard",
"wandb",
"expecttest==0.1.6",
]

[tool.setuptools.dynamic]
127 changes: 62 additions & 65 deletions recipes/lora_finetune_distributed.py
@@ -17,16 +17,18 @@

from torch import nn
from torch.distributed import destroy_process_group, init_process_group
from torch.distributed.fsdp import (
FullOptimStateDictConfig,
FullStateDictConfig,
FullyShardedDataParallel as FSDP,
StateDictType,
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.checkpoint.state_dict import (
get_optimizer_state_dict,
set_optimizer_state_dict,
StateDictOptions,
)

from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler
from torchtune import config, modules, utils
from torchtune.datasets import ConcatDataset
from torchtune.modules.peft import LoRALinear
from torchtune.modules.peft.peft_utils import (
get_adapter_params,
get_merged_lora_ckpt,
@@ -279,16 +281,13 @@ def _setup_model(
"""

if self._is_rank_zero:
log.info("FSDP is enabled. Instantiating Model on CPU for Rank 0 ...")
log.info("FSDP is enabled. Model init and checkpoint loading on Rank 0 ...")
Contributor:

Suggested change
log.info("FSDP is enabled. Model init and checkpoint loading on Rank 0 ...")
log.info("FSDP is enabled. Instantiating model and loading checkpoint on Rank 0 ...")

init_start = time.perf_counter()

with utils.set_default_dtype(self._dtype):
model = config.instantiate(cfg_model)

log.info(
f"Model instantiation took {time.perf_counter() - init_start:.2f} secs"
)
with utils.set_default_dtype(self._dtype), torch.device("meta"):
Contributor:

Sorry, I'm not able to comment above, but the docstring of this function should be updated since we're no longer initializing on CPU?

Contributor Author:

The docstring used to say "Instantiating Model on CPU" (left) and I removed the mention of CPU. I did not mention the meta device because the timer now measures meta init + checkpoint loading. Happy to improve if you are referring to this docstring.

Contributor Author:

Oh, just got your point. Updated the docstring for _setup_model.

model = config.instantiate(cfg_model)

if self._is_rank_zero:
# The model contains LoRA params which won't have any matching keys in
# the state dict. As a result, we need to load with strict=False.
# Before loading the state dict, ensure the state dict keys for the base
@@ -307,16 +306,36 @@
base_model_state_dict_keys=base_model_state_dict.keys(),
)

# Load both the base model weights and (if available) the adapter weights. Both
# of this should happen only on Rank 0
model.load_state_dict(base_model_state_dict, strict=False)
if lora_weights_state_dict:
model.load_state_dict(lora_weights_state_dict, strict=False)
self.adapter_params = get_adapter_params(model)
set_trainable_params(model, self.adapter_params)

else:
# For non-zero ranks, load the model on meta device
with utils.set_default_dtype(self._dtype), torch.device("meta"):
model = config.instantiate(cfg_model)
if enable_activation_checkpointing:
utils.set_activation_checkpointing(
model, auto_wrap_policy={modules.TransformerDecoderLayer}
)

Contributor Author:

if isinstance(m, modules.TransformerDecoderLayer): is the equivalent of auto_wrap_policy in FSDP1

for m in model.modules():
if isinstance(m, modules.TransformerDecoderLayer):
fully_shard(m)
fully_shard(model)
Contributor:

Sorry for the noob question, but can you help me understand what's going on here? Why do I need to fully_shard the TransformerDecoderLayer and then call fully_shard on the model?

An unrelated question: if I have enough GPU memory, should I be thinking about using something similar to SHARD_GRAD_OP with FSDP2?

Contributor Author:

In FSDP1, we wrap each TransformerDecoderLayer and then the root model as well. It's blackboxed in auto_wrap_policy=utils.lora_fsdp_wrap_policy(modules_to_wrap={modules.TransformerDecoderLayer}).

In FSDP2, we un-blackbox it into this for-loop. If you prefer, this can be factored into a util function in torchtune so users can call util.fully_shard(model, modules_to_wrap).

Personally I'm biased towards the un-blackboxed approach, since people can modify the for-loop to achieve different wrapping.

Contributor Author:

The equivalent of SHARD_GRAD_OP in FSDP2 is reshard_after_forward=False. Do you want it exposed as a config in the .yaml?

fully_shard(model, reshard_after_forward=False)

Contributor:

thanks for the explanation! I love the un-blackboxed approach here - just needs more documentation and explanation :) After reading the FSDP2 RFC, this became a lot clearer.
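
If this loop were ever factored out as suggested above, a torchtune-side helper might look roughly like the sketch below (hypothetical utility, not part of this PR; fully_shard and reshard_after_forward are the FSDP2 APIs discussed in this thread). Usage would mirror the diff: shard_model(model, {modules.TransformerDecoderLayer}).

```python
# Hypothetical helper (not in this PR): the un-blackboxed for-loop packaged so
# callers choose which module types to shard, and whether parameters stay
# unsharded after forward (FSDP2's analogue of SHARD_GRAD_OP).
from typing import Set, Type

from torch import nn
from torch.distributed._composable.fsdp import fully_shard


def shard_model(
    model: nn.Module,
    modules_to_wrap: Set[Type[nn.Module]],
    *,
    reshard_after_forward: bool = True,
) -> None:
    # Shard each matching submodule first, mirroring FSDP1's auto_wrap_policy...
    for m in model.modules():
        if isinstance(m, tuple(modules_to_wrap)):
            fully_shard(m, reshard_after_forward=reshard_after_forward)
    # ...then shard the root so the remaining parameters form their own group.
    fully_shard(model, reshard_after_forward=reshard_after_forward)
```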


utils.load_from_full_model_state_dict(
model, base_model_state_dict, self._device, self._is_rank_zero
)
if lora_weights_state_dict:
utils.load_from_full_model_state_dict(
model, lora_weights_state_dict, self._device, self._is_rank_zero
)

Contributor Author:

Pros and cons of meta init: the pro is a 4.5x speedup during model init and thus a shorter TTFB; the con is that the user needs to call initialize_parameters on LoRALinear explicitly to move it from meta to GPU.

Contributor:

Is this because these params are not being loaded from checkpoint? Or do I misunderstand?

If this is indeed the reason, how do we handle this code block when the LoRA params are being loaded from checkpoint (eg: when resuming training)?

Contributor Author:

You are right. When finetuning from an original HF checkpoint, lora_weights_state_dict = None.

When resuming training, lora_weights_state_dict is not None and we avoid calling m.initialize_parameters() again.

Contributor:

Got you, thanks so much for the explanation! I think something that would be super helpful would be to document here, in the form of comments, the relationship between:

  • the modules on which we call fully_shard
  • init on meta device
  • calling initialize_parameters and reset_parameters

Also I think there was a technical reason with FSDP1 to call the function reset_parameters. Is that still true? Or can we standardize this with initialize_parameters in the modules code? Happy to chat about this offline!

Contributor Author:

Good point! Will add a comment to explain fully_shard, meta init, and reset/initialize_parameters.

FSDP1 calls reset_parameters for the exact same reason FSDP2 calls reset/initialize_parameters: RoPE buffers are not covered in checkpoints, and lora_a and lora_b are not covered in checkpoints when resume_training=False.

It's just that FSDP1 has a contract to call the overridden nn.Module.reset_parameters through FSDP(model, param_init_fn=), while FSDP2 does not impose overriding reset_parameters, so now you can name it reset_parameters or initialize_parameters.
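
To make that promised comment concrete, here is a toy, self-contained sketch of the mechanism this thread describes — meta init, to_empty, then manual init for parameters no checkpoint provides. The module is a stand-in, not the recipe's LoRALinear:

```python
# Toy sketch (not recipe code): parameters created under the meta device have
# no storage; to_empty() allocates real storage on the target device, and a
# module-level init then fills in values that never come from a checkpoint
# (standing in for lora_a/lora_b and the RoPE cache discussed above).
import torch
from torch import nn


class TinyAdapter(nn.Module):
    def __init__(self, dim: int = 8, rank: int = 2):
        super().__init__()
        self.lora_a = nn.Linear(dim, rank, bias=False)
        self.lora_b = nn.Linear(rank, dim, bias=False)

    def initialize_parameters(self) -> None:
        nn.init.kaiming_uniform_(self.lora_a.weight, a=5**0.5)
        nn.init.zeros_(self.lora_b.weight)


with torch.device("meta"):
    m = TinyAdapter()  # no real storage allocated yet

device = "cuda" if torch.cuda.is_available() else "cpu"
# Checkpoints never contain these params, so materialize and init them by hand,
# mirroring the LoRALinear / RotaryPositionalEmbeddings handling below.
m.lora_a.to_empty(device=device)
m.lora_b.to_empty(device=device)
m.initialize_parameters()
assert not any(p.is_meta for p in m.parameters())
```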

with utils.set_default_dtype(self._dtype), self._device:
for m in model.modules():
if isinstance(m, LoRALinear) and not lora_weights_state_dict:
m.lora_a.to_empty(device=self._device)
m.lora_b.to_empty(device=self._device)
m.initialize_parameters()
if isinstance(m, modules.RotaryPositionalEmbeddings):
Contributor:

Just to clarify, we special-case RoPE because its buffer is not being loaded from the state dict, right?

Contributor Author:

that's correct

Contributor:

Similar comment here, let's document what's happening so that users can easily understand why we initialize these modules separately.

m.reset_parameters()
model = model.to(self._device)

if self._dtype == torch.bfloat16:
model = model.to(torch.bfloat16)
Expand All @@ -325,39 +344,13 @@ def _setup_model(
self._lora_rank = cfg_model.lora_rank
self._lora_alpha = cfg_model.lora_alpha

# Note: this needs to be set before wrapping with FSDP
self.adapter_params = get_adapter_params(model)
set_trainable_params(model, self.adapter_params)

model = FSDP(
module=model,
auto_wrap_policy=utils.lora_fsdp_wrap_policy(
modules_to_wrap={modules.TransformerDecoderLayer}
),
sharding_strategy=torch.distributed.fsdp.ShardingStrategy.FULL_SHARD,
device_id=self._device,
# this recipe does not currently support mixed precision training
mixed_precision=None,
# Ensure we broadcast params and buffers from rank 0
sync_module_states=True,
# Initialize empty modules on all non-zero ranks
param_init_fn=(
lambda module: module.to_empty(
device=torch.device("cuda"), recurse=False
)
if not self._is_rank_zero
else None
),
)

# Ensure no params and buffers are on meta device
utils.validate_no_params_on_meta_device(model)

if enable_activation_checkpointing:
utils.set_activation_checkpointing(
model, auto_wrap_policy={modules.TransformerDecoderLayer}
)
if self._is_rank_zero:
log.info(
f"Model init and checkpoint loading took {time.perf_counter() - init_start:.2f} secs"
)
memory_stats = utils.get_memory_stats(device=self._device)
utils.log_memory_stats(memory_stats)

@@ -371,12 +364,14 @@ def _setup_optimizer(
) -> Optimizer:
optimizer = config.instantiate(cfg_optimizer, self._model.parameters())
if opt_state_dict:
# Note: technically we should check _contains_fsdp for
# just the state dict of the adapter cfg, but should be equivalent
opt_state_dict = utils.transform_opt_state_dict(
opt_state_dict, self._model, optimizer
set_optimizer_state_dict(
self._model,
optimizer,
optim_state_dict=opt_state_dict,
options=StateDictOptions(
broadcast_from_rank0=True, full_state_dict=True
),
Contributor:

Can you help explain what this is doing?

Contributor Author:

For both FSDP1 and FSDP2, parameters are sharded so that each GPU holds 1/world_size of them, and each optimizer only sees its own 1/world_size portion of the parameters. We start with the full optimizer state dict on rank 0, and each rank should load only its 1/world_size slice of it.

In FSDP2, set_optimizer_state_dict serves this purpose; in FSDP1, utils.transform_opt_state_dict did (see the standalone sketch after this hunk).

)
optimizer.load_state_dict(opt_state_dict)

if self._is_rank_zero:
log.info("Optimizer and loss are initialized.")
@@ -461,17 +456,19 @@ def save_checkpoint(
intermediate_checkpoint = epoch + 1 < self.total_epochs
# To prevent GPU memory from spiking during checkpoint save,
# we consolidate the full model and optim state dicts on CPU for rank 0
with FSDP.state_dict_type(
cpu_state_dict = utils.get_full_model_state_dict(
self._model,
StateDictType.FULL_STATE_DICT,
FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True),
):
cpu_state_dict = self._model.state_dict()
if intermediate_checkpoint:
opt_state_dict = FSDP.optim_state_dict(self._model, self._optimizer)
else:
opt_state_dict = None
self._is_rank_zero,
)

if intermediate_checkpoint:
opt_state_dict = get_optimizer_state_dict(
self._model,
self._optimizer,
options=StateDictOptions(full_state_dict=True, cpu_offload=True),
)
else:
opt_state_dict = None

# Now that we have the model and opt state dict, create the actual checkpoint dict
# to be sent to the checkpointer and ultimately written to file
10 changes: 10 additions & 0 deletions tests/recipes/test_lora_finetune_distributed.py
@@ -25,6 +25,7 @@
get_loss_values_from_metric_logger,
gpu_test,
)
from torch.distributed.checkpoint.state_dict import StateDictOptions
from torchtune import config


@@ -51,6 +52,10 @@ def _fetch_expected_loss_values(self):

@pytest.mark.integration_test
@gpu_test(gpu_count=2)
@pytest.mark.skipif(
not hasattr(StateDictOptions, "broadcast_from_rank0"),
reason="need latest pytorch nightly",
)
def test_loss(self, tmpdir, monkeypatch):
ckpt = "small_test_ckpt_tune"
ckpt_path = Path(CKPT_MODEL_PATHS[ckpt])
@@ -87,6 +92,7 @@ def test_loss(self, tmpdir, monkeypatch):

@pytest.mark.integration_test
@gpu_test(gpu_count=2)
@pytest.mark.skipif(True, reason="resolve FSDP2 optimizer state dict and enable")
def test_training_state_on_resume(self, tmpdir, monkeypatch):
"""Test whether the recipe state is correctly updated on resume. Since this
is model agnostic, we should run this on the small model only. The test
@@ -161,6 +167,10 @@ def test_training_state_on_resume(self, tmpdir, monkeypatch):

@pytest.mark.integration_test
@gpu_test(gpu_count=2)
@pytest.mark.skipif(
not hasattr(StateDictOptions, "broadcast_from_rank0"),
reason="need latest pytorch nightly",
)
def test_save_and_load_merged_weights(self, tmpdir, monkeypatch):
ckpt = "small_test_ckpt_tune"
