enable QLoRA + FSDP2 #909

Merged · 100 commits merged into pytorch:main on Jun 5, 2024
Conversation

weifengpy (Contributor Author) commented May 1, 2024

This PR is built on top of:

  • TorchAO nightly that contains the NF4Tensor FSDP2 ops (PR1, PR2)
  • PyTorch nightly that contains meta init + CPU offloading (PR)

Unit test: pytest -s tests/torchtune/utils/test_distributed.py -k test_qlora_state_dict

Attaching snapshots of 2 runs:

  • QLoRA + FSDP2 on 8 GPUs: tune run --nnodes 1 --nproc_per_node 8 lora_finetune_fsdp2 --config llama2/7B_qlora
  • QLoRA on 1 GPU: tune run lora_finetune_single_device --config llama2/7B_qlora_single_device
Numerics: [screenshot]

tokens_per_second_per_gpu (1st is A100, 2nd is H100; H100 scales better with faster memory read/write): [screenshots]

Memory: [screenshot]

weifengpy and others added 15 commits April 23, 2024 17:45
enable saving checkpoint

pytorch-bot bot commented May 1, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/909

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit cbb3da8 with merge base 71741df:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

facebook-github-bot added the CLA Signed label May 1, 2024
weifengpy marked this pull request as draft May 1, 2024 06:45
weifengpy and others added 11 commits May 1, 2024 00:33
@@ -17,26 +17,3 @@ def clone(func, *args, **kwargs):
in precision.
"""
return to_nf4(args[0][0].get_original_weight())


weifengpy (Contributor Author) commented Jun 4, 2024
Starting from TorchAO==0.2.0, we implemented NF4.copy_. It's a superset of inplace_copy; we cover both NF4.copy_(bf16) and NF4.copy_(NF4).
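
A minimal sketch of that copy_ behavior (not code from this PR), assuming torchao >= 0.2.0 exposes to_nf4 under torchao.dtypes.nf4tensor; the shape is illustrative:

# Hedged sketch: NF4.copy_ from both bf16 and NF4 sources.
import torch
from torchao.dtypes.nf4tensor import to_nf4

nf4_weight = to_nf4(torch.randn(512, 512, dtype=torch.bfloat16))

# NF4.copy_(bf16): quantize a high-precision tensor and copy it in place
nf4_weight.copy_(torch.randn(512, 512, dtype=torch.bfloat16))

# NF4.copy_(NF4): copy already-quantized data between two NF4 tensors
other_nf4 = to_nf4(torch.randn(512, 512, dtype=torch.bfloat16))
nf4_weight.copy_(other_nf4)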

  _component_: torch.nn.CrossEntropyLoss

fsdp:
  cpu_offload: False
weifengpy (Contributor Author) commented
Compared with 7B_qlora_single_device.yaml, cpu_offload is the new feature.

Contributor commented
Strictly speaking you probably do not need to put the cpu_offload config under fsdp (unless you think it helps for clarity, or anticipate other kwargs that we'd want to pass through to fully_shard). No strong preference either way here, though.

weifengpy (Contributor Author) commented Jun 4, 2024
I know Less has been working on cpu_offload for activations, which might become a root-level config. That will be a completely different thing from cpu_offload for parameters in FSDP2.

Another thing is that other fully_shard kwargs, like reshard_after_forward=False, can boost QPS.

Does that make sense to you?

Contributor commented
Ah yeah this is a good point, makes sense to me. Thanks for clarifying
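
A hedged sketch of how such kwargs could be threaded through to fully_shard (this is not the recipe code in this PR; the two-layer toy model and the shard_model helper are illustrative, and it assumes the FSDP2 fully_shard/CPUOffloadPolicy APIs from a recent PyTorch build, launched via torchrun):

# Hedged sketch: forwarding cpu_offload and reshard_after_forward to fully_shard.
# Launch with: torchrun --nproc_per_node 2 sketch.py
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._composable.fsdp import CPUOffloadPolicy, fully_shard


def shard_model(model: nn.Sequential, cpu_offload: bool) -> nn.Sequential:
    # reshard_after_forward=False keeps gathered params between forward and
    # backward, trading memory for throughput (QPS).
    kwargs = {"reshard_after_forward": False}
    if cpu_offload:
        kwargs["offload_policy"] = CPUOffloadPolicy()
    for block in model:  # shard each block first ...
        fully_shard(block, **kwargs)
    fully_shard(model, **kwargs)  # ... then the root module
    return model


if __name__ == "__main__":
    dist.init_process_group("nccl")
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
    model = nn.Sequential(nn.Linear(128, 128), nn.Linear(128, 128)).cuda()
    shard_model(model, cpu_offload=False)
    dist.destroy_process_group()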

# tune download meta-llama/Llama-2-7b-hf --output-dir /tmp/Llama-2-7b-hf --hf-token <HF_TOKEN>
#
# To launch on a single device, run the following command from root:
# tune run lora_finetune_single_device --config llama2/7B_qlora_single_device
weifengpy (Contributor Author) commented
Will replace lora_finetune_single_device with lora_finetune_fsdp2 before landing; keeping it for now to save CI and publish the PR earlier.

# Before loading the state dict, ensure the state dict keys for the base
# model and adapters (if available) match the keys in the full LoRA model
# This is a good sanity check to prevent silent errors
validate_state_dict_for_lora(
weifengpy (Contributor Author) commented Jun 4, 2024
validate_state_dict_for_lora needs some refactoring since FSDP2 uses clean FQNs. FSDP2 also checks for params on the meta device and throws a RuntimeError if any remain (https://fburl.com/wrhxykn3):

"FSDP parameters should be materialized from meta device before training, "
f"but the following were still on meta device: {param_names_on_meta}\n"
"For example, call module.to_empty(device) to materialize to device and "
"call module.reset_parameters() on each module to initialize values."

Contributor commented
Oh yeah this is what I was alluding to in the previous PR. You can try validate_missing_and_unexpected_for_lora as in the single-device recipe now that we have the correct FQN. If it doesn't work out of the box feel free to leave this as-is, we can come back to refactor after.

weifengpy (Contributor Author) commented
validate_missing_and_unexpected_for_lora works out of the box. Added it in this PR. Fingers crossed!

weifengpy marked this pull request as ready for review June 4, 2024 22:32
weifengpy requested a review from ebsmothers June 4, 2024 22:33
#

# Model Arguments
model:
weifengpy (Contributor Author) commented Jun 4, 2024
Compared with 70B_lora.yaml, this adds output_proj and apply_lora_to_mlp: True. The difference is similar to 7B_lora.yaml vs 7B_qlora_single_device.yaml.

@@ -0,0 +1,91 @@
# Config for single device QLoRA with lora_finetune_single_device.py
Contributor commented
Should we rename this config as 7B_qlora_fsdp2.yaml to match the LoRA one?

weifengpy (Contributor Author) commented
good point! updating now

)
else:
    sharded_tensor = distribute_tensor(
        full_tensor,
Contributor commented
No longer need to convert dtype as before?

weifengpy (Contributor Author) commented
It's moved to line 284: full_tensor = full_tensor.to(sharded_meta_param.dtype).to(device). The extra .to(device) is for NF4, since quantization on CPU is prohibitively slow.
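
A hedged sketch of that ordering (cast to the target dtype, move to the GPU, then quantize); this is not the recipe code itself, the shape and names are illustrative, and it assumes torchao's to_nf4:

# Hedged sketch: quantize to NF4 on the GPU rather than on the CPU.
import torch
from torchao.dtypes.nf4tensor import to_nf4

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
target_dtype = torch.bfloat16            # stands in for sharded_meta_param.dtype

full_tensor = torch.randn(1024, 1024)    # full-precision weight from the checkpoint
full_tensor = full_tensor.to(target_dtype).to(device)  # cast, then move
nf4_weight = to_nf4(full_tensor)         # quantizing on CPU would be prohibitively slow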

Comment on lines +285 to +310
if isinstance(sharded_meta_param._local_tensor, NF4Tensor):
    full_tensor = to_nf4(full_tensor)
    # replicating logic from `_fsdp_param.py` `_init_sharded_param`
    # otherwise `distribute_tensor(DTensor(local=NF4))`
    # requires dispatching `c10d.scatter_`
    # long-term solution is `swap_tensor`
    mesh = sharded_meta_param.device_mesh
    if mesh.ndim > 1:
        raise NotImplementedError(f"only support 1D FSDP but got {mesh.ndim=}")
    shard_mesh_dim = 0
    shard_world_size = mesh.size(shard_mesh_dim)
    shard_rank = cast(
        torch.distributed.ProcessGroup, mesh.get_group(shard_mesh_dim)
    ).rank()
    chunk = list(torch.chunk(full_tensor, shard_world_size, dim=0))[shard_rank]
    sharded_param = full_tensor.new_zeros(chunk.size())
    sharded_param[: chunk.size(0)].copy_(chunk)
    sharded_tensor = DTensor(
        sharded_param,
        sharded_meta_param.device_mesh,
        sharded_meta_param.placements,
        shape=sharded_meta_param.size(),
        dtype=sharded_meta_param.dtype,
        requires_grad=sharded_meta_param.requires_grad,
        stride=sharded_meta_param.stride(),
    )
Contributor commented
If it's not too much work, is it possible to extend the unit test you added in #855 to cover this case as well? As is I find it a bit hard to follow and want to make sure we have a reliable sanity check in case anything breaks in the future

weifengpy (Contributor Author) commented
you're right. will try to come up with a unit test to cover NF4

weifengpy (Contributor Author) commented Jun 5, 2024
@ebsmothers I added pytest -s tests/torchtune/utils/test_distributed.py -k test_qlora_state_dict

weifengpy changed the title from "[WIP] enable QLoRA + FSDP2" to "enable QLoRA + FSDP2" Jun 5, 2024
ebsmothers (Contributor) left a comment

This is looking good! As discussed offline we can separate out the checkpoint save support into a follow-up PR to unblock other ongoing work. Once CI is green this one should be good to land. Thank you for enabling a huge huge feature for our users!

weifengpy merged commit f9cb9e6 into pytorch:main Jun 5, 2024
29 checks passed
weifengpy (Contributor Author) commented
Merged after confirming CI is all green. Will follow up on saving the state dict.

maximegmd pushed a commit to maximegmd/torchtune that referenced this pull request Jul 13, 2024
Co-authored-by: Kartikay Khandelwal <47255723+kartikayk@users.noreply.github.com>
Co-authored-by: ebsmothers <ebs@meta.com>
Co-authored-by: Rafi Ayub <33648637+RdoubleA@users.noreply.github.com>
Co-authored-by: Joe Cummings <jrcummings27@gmail.com>
Co-authored-by: Salman Mohammadi <salman.mohammadi@outlook.com>
Co-authored-by: Rohan Varma <rvarm1@fb.com>
Co-authored-by: Optimox <sebastien.fischman@gmail.com>
Co-authored-by: Tanish Ambulkar <tanish.ambulkar99@gmail.com>
Co-authored-by: Botao Chen <markchen1015@meta.com>
Co-authored-by: solitude-alive <44771751+solitude-alive@users.noreply.github.com>
Co-authored-by: christobill <christobill@users.noreply.github.com>
Co-authored-by: Philip Bontrager <pbontrager@gmail.com>
Co-authored-by: Evan Smothers <ebs@fb.com>
RdoubleA mentioned this pull request Jul 19, 2024