
[FSDP2] full finetune: move state dict to cpu when cpu offloading #1495

Merged: 6 commits into pytorch:main on Sep 5, 2024

Conversation

weifengpy (Contributor) commented Sep 4, 2024

Context

What is the purpose of this PR? Fix a bug.

resolve: #1412

full finetune with cpu offload: tune run --nproc_per_node 2 full_finetune_distributed --config llama2/7B_full fsdp_cpu_offload=True

full finetune without cpu offload: tune run --nproc_per_node 2 full_finetune_distributed --config llama2/7B_full fsdp_cpu_offload=False

lora: tune run --nproc_per_node 2 lora_finetune_fsdp2 --config llama2/7B_lora

qlora: tune run --nproc_per_node 2 lora_finetune_fsdp2 --config llama2/7B_qlora

When CPU offloading, we can move the state dict to CPU to avoid spiking peak memory. As the snapshots show, peak memory dropped from 7GB to 3GB:

Memory behavior before: [memory snapshot screenshot, 2024-09-04 3:29 PM]

Memory behavior after: [memory snapshot screenshot, 2024-09-04 3:31 PM]
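For reference, a minimal sketch of the loading pattern this PR enables (simplified; the real logic lives in torchtune's load_from_full_model_state_dict, whose exact signature may differ): each full tensor is sharded on GPU, then the local shard is immediately moved to CPU so only one parameter's full tensor is GPU-resident at a time.

```python
import torch
import torch.nn as nn
from torch.distributed._tensor import distribute_tensor

def load_full_sd_with_cpu_offload(
    model: nn.Module, full_sd: dict[str, torch.Tensor], device: torch.device
) -> None:
    # After fully_shard, model.state_dict() holds DTensors that carry the target
    # device_mesh and placements for each parameter.
    meta_sd = model.state_dict()
    sharded_sd = {}
    for name, full_tensor in full_sd.items():
        meta_param = meta_sd[name]
        sharded = distribute_tensor(
            full_tensor.to(device),   # the shard scatter runs over NCCL, so stage on GPU
            meta_param.device_mesh,
            meta_param.placements,
        )
        # The change in this PR: with cpu_offload, keep only the local shard, on CPU.
        sharded_sd[name] = nn.Parameter(sharded.cpu())
    model.load_state_dict(sharded_sd, strict=False, assign=True)
```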

pytorch-bot commented Sep 4, 2024

🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1495
Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures
As of commit fedaa32 with merge base 5fcb931: 💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@@ -584,4 +587,4 @@ def shard_model(
            fully_shard(m, **fsdp_kwargs)

    # Finally shard the entire model to account for any stragglers
    fully_shard(model)
weifengpy (author):
This seems to be dropping **fsdp_kwargs by accident? It prevents CPU offloading for the root model.

Contributor:

Yeah I had meant to fix this, thanks for adding it here!
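For clarity, a hedged sketch of what the corrected root call looks like (the wrapping condition and kwargs contents are illustrative, not torchtune's exact shard_model): the same fsdp_kwargs, including offload_policy=CPUOffloadPolicy(), must be passed to the root fully_shard call too, otherwise parameters owned directly by the root module are never offloaded.

```python
import torch.nn as nn
from torch.distributed._composable.fsdp import CPUOffloadPolicy, fully_shard

def shard_with_cpu_offload(model: nn.Module, layer_cls: type) -> None:
    # Assumed kwargs for illustration; torchtune builds these from the recipe config.
    fsdp_kwargs = {"offload_policy": CPUOffloadPolicy()}
    for m in reversed(list(model.modules())):
        if isinstance(m, layer_cls):       # per-layer wrapping condition (hypothetical)
            fully_shard(m, **fsdp_kwargs)
    # The fix: shard the root with the same kwargs so CPU offloading applies to it as well.
    fully_shard(model, **fsdp_kwargs)
```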

@@ -338,6 +339,8 @@ def load_from_full_model_state_dict(
         sharded_meta_param.device_mesh,
         sharded_meta_param.placements,
     )
+    if cpu_offload:
+        sharded_tensor = sharded_tensor.cpu()
weifengpy (author):
sharded_tensor has device=cuda because distribute_tensor/DTensor requires NCCL. For CPU offloading, we can move the DTensor to device=cpu afterwards to avoid spiking peak memory.
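A small sketch of that device behavior (assuming an initialized NCCL process group and a CUDA device mesh passed in as mesh): distribute_tensor leaves the DTensor on cuda, and .cpu() then moves the local shard to host memory while preserving the sharding metadata.

```python
import torch
from torch.distributed._tensor import Shard, distribute_tensor
from torch.distributed.device_mesh import DeviceMesh

def shard_then_offload(full_tensor: torch.Tensor, mesh: DeviceMesh):
    dt = distribute_tensor(full_tensor, mesh, [Shard(0)])
    assert dt.device.type == "cuda"            # the scatter ran over NCCL
    dt_cpu = dt.cpu()                          # DTensor is a Tensor subclass, so .cpu() works
    assert dt_cpu.device.type == "cpu"         # local shard now lives in host memory
    assert dt_cpu.placements == dt.placements  # sharding metadata is preserved
    return dt_cpu
```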

weifengpy (author) commented:
This torchtune-side fix alone resolves the problem, because we move the state dict to CPU when fsdp_cpu_offload=True.

The PyTorch-side fix covers cases where the state dict is on GPU.

ebsmothers (Contributor) commented:
Thank you @weifengpy for figuring this out and landing the fix!

ebsmothers (Contributor) commented:
Btw @weifengpy I assume for save_checkpoint we will need to do the inverse operation, right? Move from CPU back to current device prior to resharding?

weifengpy (author) commented:
> Btw @weifengpy I assume for save_checkpoint we will need to do the inverse operation, right? Move from CPU back to current device prior to resharding?

You're right. I need to update the PR to cover save_checkpoint.
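A hedged sketch of that inverse direction (not the eventual PR code; the helper name and details are assumptions): before gathering a full state dict in save_checkpoint, move each CPU-offloaded shard back to the compute device so the all-gather can run over NCCL.

```python
import torch
from torch.distributed._tensor import DTensor

def gather_full_state_dict(model: torch.nn.Module, device: torch.device) -> dict:
    full_sd = {}
    for name, param in model.state_dict().items():
        if isinstance(param, DTensor):
            if param.device.type == "cpu":
                param = param.to(device)    # shards must be on GPU for the NCCL all-gather
            param = param.full_tensor()     # all-gather the shards into a full tensor
        full_sd[name] = param.cpu()         # keep the gathered copy off the GPU
    return full_sd
```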

weifengpy merged commit f437639 into pytorch:main on Sep 5, 2024
17 checks passed
@@ -338,6 +339,8 @@ def load_from_full_model_state_dict(
         sharded_meta_param.device_mesh,
         sharded_meta_param.placements,
     )
+    if cpu_offload:
+        sharded_tensor = sharded_tensor.cpu()
This change makes sense to me. I am not sure if we should support the user trying to load a GPU state dict into an FSDP module that has CPU offloading enabled. We can provide a better error message though.

weifengpy (author):
Got it. This reminds me of my overdue BE task to check for CPU device in lazy_init when CPU offloading is enabled.
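Along those lines, a sketch of the kind of guard being suggested (a hypothetical helper, not an existing PyTorch or torchtune API): fail fast with a clear message instead of a device/NCCL mismatch surfacing deep inside lazy_init.

```python
import torch
from torch.distributed._tensor import DTensor

def check_params_on_cpu_for_offload(model: torch.nn.Module) -> None:
    # Hypothetical validation: with CPU offloading enabled, every sharded parameter
    # is expected to live on CPU before training starts.
    for name, param in model.named_parameters():
        if isinstance(param, DTensor) and param.device.type != "cpu":
            raise ValueError(
                f"CPU offloading is enabled but parameter '{name}' is on {param.device}. "
                "Load the state dict with cpu_offload=True or move it to CPU first."
            )
```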

Successfully merging this pull request may close these issues.

Full finetune recipe not working with FSDP2 CPU offload