enable QLoRA + FSDP2 #909
Conversation
use torchao copy_
enable saving checkpoint
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/909
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit cbb3da8 with merge base 71741df. This comment was automatically generated by Dr. CI and updates every 15 minutes.
@@ -17,26 +17,3 @@ def clone(func, *args, **kwargs):
        in precision.
        """
        return to_nf4(args[0][0].get_original_weight())
Starting from torchao 0.2.0, we implemented `NF4.copy_`. It's a superset of `inplace_copy`: we cover both `NF4.copy_(bf16)` and `NF4.copy_(NF4)`.
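A minimal sketch of the two paths, assuming torchao >= 0.2.0 with the `NF4Tensor.copy_` override described above (shapes are arbitrary):

```python
import torch
from torchao.dtypes.nf4tensor import to_nf4  # torchao >= 0.2.0 assumed

bf16_weight = torch.randn(512, 512, dtype=torch.bfloat16)
nf4_weight = to_nf4(bf16_weight)

# NF4.copy_(bf16): quantizes the bf16 source into the existing NF4 tensor in place
nf4_weight.copy_(torch.randn(512, 512, dtype=torch.bfloat16))

# NF4.copy_(NF4): copies already-quantized data between NF4 tensors
nf4_weight.copy_(to_nf4(torch.randn(512, 512, dtype=torch.bfloat16)))
```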
_component_: torch.nn.CrossEntropyLoss

fsdp:
  cpu_offload: False
Comparing with 7B_qlora_single_device.yaml, `cpu_offload` is the new feature.
Strictly speaking you probably do not need to put `cpu_offload` config under `fsdp` (unless you think it helps for clarity, or anticipate other kwargs that we'd wanna pass through to `fully_shard`). No strong preference either way here though.
I know Less has been working on `cpu_offload` for activations, which might become a root-level config. That will be a completely different thing than `cpu_offload` for parameters in FSDP2.

Another thing is that other `fully_shard` kwargs like `reshard_after_forward=False` can boost QPS.

Does it make sense to you?
Ah yeah this is a good point, makes sense to me. Thanks for clarifying
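For concreteness, a rough sketch of how these knobs could map onto FSDP2's `fully_shard` (assumes a PyTorch build exposing `torch.distributed._composable.fsdp`; the config dict layout and the per-layer loop are illustrative, not the recipe's actual API):

```python
from torch.distributed._composable.fsdp import (
    CPUOffloadPolicy,
    OffloadPolicy,
    fully_shard,
)

def shard_model(model, fsdp_cfg: dict):
    # cpu_offload toggles offloading sharded parameters/gradients to CPU.
    offload = CPUOffloadPolicy() if fsdp_cfg.get("cpu_offload", False) else OffloadPolicy()
    # reshard_after_forward=False keeps gathered params between forward and
    # backward, trading memory for throughput (the QPS boost mentioned above).
    for layer in model.layers:  # assumes a transformer-style list of blocks
        fully_shard(layer, offload_policy=offload, reshard_after_forward=False)
    fully_shard(model, offload_policy=offload, reshard_after_forward=False)
    return model
```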
# tune download meta-llama/Llama-2-7b-hf --output-dir /tmp/Llama-2-7b-hf --hf-token <HF_TOKEN>
#
# To launch on a single device, run the following command from root:
# tune run lora_finetune_single_device --config llama2/7B_qlora_single_device
Will replace `lora_finetune_single_device` with `lora_finetune_fsdp2` before landing; leaving it for now to save CI and publish the PR earlier.
# Before loading the state dict, ensure the state dict keys for the base
# model and adapters (if available) match the keys in the full LoRA model
# This is a good sanity check to prevent silent errors
validate_state_dict_for_lora(
`validate_state_dict_for_lora` needs some refactoring since FSDP2 uses clean FQNs. FSDP2 also checks for params on the meta device and throws a RuntimeError if any are found (https://fburl.com/wrhxykn3):
"FSDP parameters should be materialized from meta device before training, "
f"but the following were still on meta device: {param_names_on_meta}\n"
"For example, call module.to_empty(device) to materialize to device and "
"call module.reset_parameters() on each module to initialize values."
Oh yeah this is what I was alluding to in the previous PR. You can try validate_missing_and_unexpected_for_lora as in the single-device recipe now that we have the correct FQN. If it doesn't work out of the box feel free to leave this as-is, we can come back to refactor after.
`validate_missing_and_unexpected_for_lora` works out of the box. Added it in this PR. Fingers crossed!
#
# Model Arguments
model:
Comparing with 70B_lora.yaml, this adds `output_proj` to the LoRA attention modules and sets `apply_lora_to_mlp: True`. The difference is similar to 7B_lora.yaml vs 7B_qlora_single_device.yaml.
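Expressed through the llama2 LoRA builder, the delta looks roughly like the sketch below (the builder name and argument names are assumptions based on the configs being compared; check the actual signatures):

```python
# Hedged sketch only: builder and argument names are assumed, not verified.
from torchtune.models.llama2 import lora_llama2_70b

model = lora_llama2_70b(
    lora_attn_modules=["q_proj", "k_proj", "v_proj", "output_proj"],  # output_proj added
    apply_lora_to_mlp=True,  # enabled here, unlike 70B_lora.yaml
    quantize_base=True,      # QLoRA: keep the frozen base weights in NF4
)
```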
@@ -0,0 +1,91 @@
# Config for single device QLoRA with lora_finetune_single_device.py |
Should we rename this config to 7B_qlora_fsdp2.yaml to match the LoRA one?
good point! updating now
    )
else:
    sharded_tensor = distribute_tensor(
        full_tensor,
No longer need to convert dtype as before?
It's moved to line 284: `full_tensor = full_tensor.to(sharded_meta_param.dtype).to(device)`. The extra `.to(device)` is for NF4, since quantization on CPU is prohibitively slow.
if isinstance(sharded_meta_param._local_tensor, NF4Tensor):
    full_tensor = to_nf4(full_tensor)
    # replicating logic from `_fsdp_param.py` `_init_sharded_param`
    # otherwise `distribute_tensor(DTensor(local=NF4))`
    # requires dispatching `c10d.scatter_`
    # long-term solution is `swap_tensor`
    mesh = sharded_meta_param.device_mesh
    if mesh.ndim > 1:
        raise NotImplementedError(f"only support 1D FSDP but got {mesh.ndim=}")
    shard_mesh_dim = 0
    shard_world_size = mesh.size(shard_mesh_dim)
    shard_rank = cast(
        torch.distributed.ProcessGroup, mesh.get_group(shard_mesh_dim)
    ).rank()
    chunk = list(torch.chunk(full_tensor, shard_world_size, dim=0))[shard_rank]
    sharded_param = full_tensor.new_zeros(chunk.size())
    sharded_param[: chunk.size(0)].copy_(chunk)
    sharded_tensor = DTensor(
        sharded_param,
        sharded_meta_param.device_mesh,
        sharded_meta_param.placements,
        shape=sharded_meta_param.size(),
        dtype=sharded_meta_param.dtype,
        requires_grad=sharded_meta_param.requires_grad,
        stride=sharded_meta_param.stride(),
    )
If it's not too much work, is it possible to extend the unit test you added in #855 to cover this case as well? As is I find it a bit hard to follow and want to make sure we have a reliable sanity check in case anything breaks in the future
You're right. Will try to come up with a unit test to cover NF4.
@ebsmothers I added pytest -s tests/torchtune/utils/test_distributed.py -k test_qlora_state_dict
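For reference, a minimal single-process sketch of the chunk-and-pad invariant the sharding code above relies on (world size and shapes are made up; the real test_qlora_state_dict exercises FSDP2 and NF4 end to end across ranks):

```python
import torch

def test_chunk_and_pad_roundtrip():
    world_size = 4
    full = torch.randn(18, 8)  # dim 0 not evenly divisible by world_size
    chunks = list(torch.chunk(full, world_size, dim=0))  # sizes 5, 5, 5, 3
    recovered = []
    for chunk in chunks:
        shard = full.new_zeros(chunks[0].size())  # pad every shard to the largest chunk
        shard[: chunk.size(0)].copy_(chunk)
        recovered.append(shard[: chunk.size(0)])
    # Concatenating the unpadded slices reconstructs the original full tensor.
    assert torch.equal(torch.cat(recovered, dim=0), full)
```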
This is looking good! As discussed offline we can separate out the checkpoint save support into a follow-up PR to unblock other ongoing work. Once CI is green this one should be good to land. Thank you for enabling a huge huge feature for our users!
Merged after confirming CI is all green. Will follow up on saving the state dict.
Co-authored-by: Kartikay Khandelwal <47255723+kartikayk@users.noreply.github.com>
Co-authored-by: ebsmothers <ebs@meta.com>
Co-authored-by: Rafi Ayub <33648637+RdoubleA@users.noreply.github.com>
Co-authored-by: Joe Cummings <jrcummings27@gmail.com>
Co-authored-by: Salman Mohammadi <salman.mohammadi@outlook.com>
Co-authored-by: Rohan Varma <rvarm1@fb.com>
Co-authored-by: Optimox <sebastien.fischman@gmail.com>
Co-authored-by: Tanish Ambulkar <tanish.ambulkar99@gmail.com>
Co-authored-by: Botao Chen <markchen1015@meta.com>
Co-authored-by: solitude-alive <44771751+solitude-alive@users.noreply.github.com>
Co-authored-by: christobill <christobill@users.noreply.github.com>
Co-authored-by: Philip Bontrager <pbontrager@gmail.com>
Co-authored-by: Evan Smothers <ebs@fb.com>
This PR is built on top of

Unit test:
pytest -s tests/torchtune/utils/test_distributed.py -k test_qlora_state_dict

Attaching snapshots of two runs:
tune run --nnodes 1 --nproc_per_node 8 lora_finetune_fsdp2 --config llama2/7B_qlora
tune run lora_finetune_single_device --config llama2/7B_qlora_single_device

Numerics: (snapshots attached)
tokens_per_second_per_gpu: 1st is A100, 2nd is H100; H100 scales better with faster memory read/write. (snapshots attached)
Memory: (snapshots attached)