[FSDP2] full finetune: move state dict to cpu when cpu offloading #1495
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1495
Note: links to docs will display an error until the docs builds have been completed.
✅ No failures as of commit fedaa32 with merge base 5fcb931. This comment was automatically generated by Dr. CI and updates every 15 minutes.
@@ -584,4 +587,4 @@ def shard_model(
            fully_shard(m, **fsdp_kwargs)

    # Finally shard the entire model to account for any stragglers
    fully_shard(model)
This seems to be dropping **fsdp_kwargs by accident? It prevents CPU offloading for the root model.
Yeah I had meant to fix this, thanks for adding it here!
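For reference, a minimal sketch of the sharding sequence with the kwargs forwarded to the root call; the loop and the shard_condition predicate are illustrative rather than the exact torchtune code:

from torch.distributed._composable.fsdp import fully_shard

# Shard the wrapped submodules first, forwarding the shared FSDP kwargs
# (e.g. the CPU offload policy) so every layer picks up CPU offloading.
for m in reversed(list(model.modules())):
    if shard_condition(m):  # hypothetical predicate for which modules to wrap
        fully_shard(m, **fsdp_kwargs)

# Finally shard the entire model to account for any stragglers; passing
# **fsdp_kwargs here is what enables CPU offloading for the root parameters.
fully_shard(model, **fsdp_kwargs)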
@@ -338,6 +339,8 @@ def load_from_full_model_state_dict(
            sharded_meta_param.device_mesh,
            sharded_meta_param.placements,
        )
        if cpu_offload:
            sharded_tensor = sharded_tensor.cpu()
sharded_tensor has device=cuda because distribute_tensor/DTensor requires NCCL. For CPU offloading, we can move the DTensor to device=cpu afterwards to avoid a memory peak.
This torchtune-side fix alone can resolve the problem, because we move the state dict to CPU when CPU offloading; the PyTorch-side fix covers cases where the state dict is on GPU.
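In code, the load path shown in the diff above looks roughly like the following (a sketch assuming distribute_tensor is imported from torch.distributed._tensor and a cpu_offload flag is threaded in from the recipe):

from torch.distributed._tensor import distribute_tensor

# distribute_tensor shards full_tensor over the (NCCL-backed) device mesh,
# so the resulting DTensor lives on GPU even when the final destination is CPU.
sharded_tensor = distribute_tensor(
    full_tensor,
    sharded_meta_param.device_mesh,
    sharded_meta_param.placements,
)
# Moving each sharded DTensor to CPU right away keeps the accumulated sharded
# state dict off the GPU, which is what lowers peak memory during loading.
if cpu_offload:
    sharded_tensor = sharded_tensor.cpu()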
Thank you @weifengpy for figuring this out and landing the fix!
Btw @weifengpy I assume for save_checkpoint we will need to do the inverse operation, right? Move from CPU back to current device prior to resharding?
you're right. I need to update the PR to cover save_checkpoint |
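A hypothetical sketch of that inverse step; the helper name and the exact hook point in save_checkpoint are assumptions, not the final PR code:

import torch

def _move_back_to_device(sharded_tensor: torch.Tensor, device: torch.device) -> torch.Tensor:
    # Inverse of the load-time offload: gathering/resharding over NCCL needs the
    # sharded DTensor back on the current CUDA device, so move it there first.
    if sharded_tensor.device.type == "cpu":
        sharded_tensor = sharded_tensor.to(device)
    return sharded_tensor

# e.g. called on each sharded parameter before gathering a full state dict
# inside save_checkpoint.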
@@ -338,6 +339,8 @@ def load_from_full_model_state_dict(
            sharded_meta_param.device_mesh,
            sharded_meta_param.placements,
        )
        if cpu_offload:
            sharded_tensor = sharded_tensor.cpu()
This change makes sense to me. I am not sure if we should support the user trying to load a GPU state dict into an FSDP module that has CPU offloading enabled. We can provide a better error message though.
Got it. This reminds me of my overdue BE task to check for a CPU device in lazy_init when CPU offloading is enabled.
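A minimal sketch of the kind of check being discussed; variable names follow the diff above and the exact message is illustrative:

# Fail fast with a clear message instead of silently mixing devices when a GPU
# state dict is loaded into an FSDP module that has CPU offloading enabled.
if cpu_offload and full_tensor.device.type != "cpu":
    raise ValueError(
        "Expected the full state dict on CPU when cpu_offload=True, but got a "
        f"tensor on {full_tensor.device}. Load the checkpoint with "
        "map_location='cpu' before calling load_from_full_model_state_dict."
    )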
Context
What is the purpose of this PR? It is to resolve #1412.
full finetune with cpu offload:
tune run --nproc_per_node 2 full_finetune_distributed --config llama2/7B_full fsdp_cpu_offload=True
full finetune without cpu offload:
tune run --nproc_per_node 2 full_finetune_distributed --config llama2/7B_full fsdp_cpu_offload=False
lora:
tune run --nproc_per_node 2 lora_finetune_fsdp2 --config llama2/7B_lora
qlora:
tune run --nproc_per_node 2 lora_finetune_fsdp2 --config llama2/7B_qlora
When CPU offloading, we can move the state dict to CPU to avoid a memory peak. As the snapshots show, peak memory dropped from 7GB to 3GB:
[Memory snapshot: behavior before]
[Memory snapshot: behavior after]
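For context, a minimal sketch of how a CPU-offload flag is typically turned into FSDP2 sharding kwargs, assuming the CPUOffloadPolicy API in torch.distributed._composable.fsdp; the exact torchtune plumbing may differ:

from torch.distributed._composable.fsdp import CPUOffloadPolicy

def build_fsdp_kwargs(cpu_offload: bool) -> dict:
    # With CPUOffloadPolicy, FSDP2 keeps sharded parameters on CPU and only
    # moves them to GPU around forward/backward, trading memory for H2D traffic.
    kwargs = {}
    if cpu_offload:
        kwargs["offload_policy"] = CPUOffloadPolicy()
    return kwargs

# These kwargs are then passed to fully_shard(...) for both the submodules and
# the root model, as in shard_model above.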