
Fix CLIP pos embedding interpolation to work on DTensors #1739

Merged
7 commits merged into pytorch:main on Oct 2, 2024

Conversation

@ebsmothers (Contributor) commented Oct 2, 2024

Currently our state dict hooks to resize CLIP positional embeddings do not work with FSDP2. This is because some of the ops used by interpolate (e.g. aten.upsample_bilinear2d.default) are not natively defined on DTensor, and as can be seen here, we shard the state dict tensors into DTensors before calling load_state_dict (where the interpolation hook is triggered).
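For context, a minimal repro sketch (not from this PR) of the failure mode: once the embedding has been sharded into a DTensor, calling interpolate on it hits an op with no DTensor implementation. It assumes a single-rank gloo process group on CPU; the exact exception type and message may vary across PyTorch versions.

import os
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.distributed._tensor import Shard, distribute_tensor
from torch.distributed.device_mesh import init_device_mesh

# Single-process "distributed" setup just to build a DTensor on CPU.
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)
mesh = init_device_mesh("cpu", (1,))

# Shape (N, C, H, W) so bilinear interpolation can be called directly.
pos_embed = distribute_tensor(torch.randn(1, 1, 16, 16), mesh, [Shard(0)])
try:
    F.interpolate(pos_embed, size=(32, 32), mode="bilinear")
except Exception as e:  # typically a missing sharding strategy for aten.upsample_bilinear2d
    print(f"interpolate failed on DTensor: {e}")

dist.destroy_process_group()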

Until these ops are defined on DTensors, we need to do something like this. This PR adds logic to the CLIP positional embedding state dict hooks to handle the case where tensors in the state dict are already sharded. For each relevant tensor, we (see the sketch after this list):

  1. Check if the state dict tensor is a DTensor
  2. If it is, we gather via .full_tensor(). If it isn't, we don't do anything.
  3. Call interpolate (along with various reshapes) on the vanilla Tensor from (2)
  4. If the original state dict tensor was a DTensor, we call distribute_tensor using the original mesh and placements. Otherwise we don't do anything.
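As a rough illustration of steps (1)-(4), here is a minimal sketch of the gather/interpolate/redistribute pattern, assuming a 2D positional embedding of shape (n_tokens, dim). It is not the actual torchtune hook: it omits the CLS-token and tile handling the real hook performs, and resize_pos_embed is a made-up name.

import torch
import torch.nn.functional as F
from torch.distributed._tensor import DTensor, distribute_tensor


def resize_pos_embed(pos_embed: torch.Tensor, tgt_grid_size: int) -> torch.Tensor:
    # (1) Check whether the incoming state dict tensor is sharded.
    is_dtensor = isinstance(pos_embed, DTensor)
    if is_dtensor:
        # (2) Record the sharding and gather the full tensor, since the ops
        # behind F.interpolate are not defined on DTensor.
        device_mesh, placements = pos_embed.device_mesh, pos_embed.placements
        pos_embed = pos_embed.full_tensor()

    # (3) Interpolate on the vanilla tensor: (n_tokens, dim) -> grid -> bilinear resize -> flatten.
    n_tokens, dim = pos_embed.shape
    src_grid_size = int(n_tokens**0.5)
    grid = pos_embed.reshape(1, src_grid_size, src_grid_size, dim).permute(0, 3, 1, 2)
    grid = F.interpolate(grid, size=(tgt_grid_size, tgt_grid_size), mode="bilinear", align_corners=True)
    pos_embed = grid.permute(0, 2, 3, 1).reshape(tgt_grid_size * tgt_grid_size, dim)

    # (4) Restore the original sharding if the input was a DTensor.
    if is_dtensor:
        pos_embed = distribute_tensor(pos_embed, device_mesh, placements)
    return pos_embed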

Test plan:

pytest tests/torchtune/models/clip/test_pos_embedding_interpolation.py
...
========== 15 passed in 0.66s ============

Compare loss curves on the recipe before and after the fix on the base model. First, update 11B_lora_single_device.yaml to match this:

Before fix (on main):

tune run lora_finetune_single_device --config llama3_2_vision/11B_lora_single_device max_steps_per_epoch=100 \
metric_logger=torchtune.training.metric_logging.WandBLogger metric_logger.project=test-1739 \
metric_logger.name=pre-fix

After fix (this PR):

tune run lora_finetune_single_device --config llama3_2_vision/11B_lora_single_device max_steps_per_epoch=100 \
metric_logger=torchtune.training.metric_logging.WandBLogger metric_logger.project=test-1739 \
metric_logger.name=post-fix
(Screenshot: loss curve comparison, pre-fix vs. post-fix)

Test that the distributed recipe now runs: change the 11B_lora config to use the base model (here). Then run:

tune run --nproc_per_node 2 lora_finetune_distributed --config llama3_2_vision/11B_lora
...
1|2|Loss: 1.3108230829238892:   0%|                                                                                                                         | 2/10359 [00:57<82:38:19, 28.72s/it]


pytorch-bot bot commented Oct 2, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1739

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit f98351a with merge base fc0249d:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Oct 2, 2024
@ebsmothers changed the title from "hacky fsdp interpolate" to "Fix CLIP pos embedding interpolation to work on DTensors" Oct 2, 2024
Comment on lines 141 to 143
local_device = inpt_local_pos_embed.device
if isinstance(inpt_local_pos_embed, DTensor):
inpt_local_pos_embed = inpt_local_pos_embed.full_tensor()
Contributor

delete "local_device" since its not used

put this inside of the "if inpt_local_pos_embed is not None:"

@@ -159,6 +164,13 @@ def _load_state_dict_hook(
tgt_patch_grid_size=int(math.sqrt(tgt_n_tokens_per_tile - 1)),
)

if isinstance(inpt_local_pos_embed, DTensor):
Contributor

`inpt_local_pos_embed` is not a DTensor at this point, because it's the full tensor now.

We need to check whether `self.local_token_positional_embedding` is a DTensor. We should probably do the same in the previous check (there you check whether `inpt_local_pos_embed` is a DTensor).
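For clarity, a minimal sketch of the kind of check being suggested here. reshard_like is a hypothetical helper (not torchtune code); in the hook, param would be the module's own parameter, e.g. self.local_token_positional_embedding, and full_tensor the interpolated result.

import torch
from torch.distributed._tensor import DTensor, distribute_tensor


def reshard_like(param: torch.Tensor, full_tensor: torch.Tensor) -> torch.Tensor:
    # If the module's parameter is a DTensor, re-shard the interpolated full
    # tensor to match its mesh and placements; otherwise return it unchanged.
    if isinstance(param, DTensor):
        return distribute_tensor(full_tensor, param.device_mesh, param.placements)
    return full_tensor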

local_device = inpt_local_pos_embed.device
if isinstance(inpt_local_pos_embed, DTensor):
inpt_local_pos_embed = inpt_local_pos_embed.full_tensor()

if inpt_local_pos_embed is not None:
Contributor

We should probably add here:

if inpt_local_pos_embed is not None and inpt_local_pos_embed.shape != self.local_token_positional_embedding.shape:

But testing becomes a bit trickier. Maybe for now it's better to not add it until testing is completed.

Contributor Author

Yeah, will leave it out for now.

Contributor

I think we should add it before the PR is finalized though. What do you think?

Contributor Author

Why do you say testing becomes trickier?

Contributor Author

After thinking about it more, I don't think we should add it. The DTensor fix resolves the issue, and there's no need to add extra logic on top of this that wasn't present before.

Comment on lines 191 to 193
global_device = inpt_global_pos_embed.device
if isinstance(inpt_global_pos_embed, DTensor):
inpt_global_pos_embed = inpt_global_pos_embed.full_tensor()
Contributor

same comments as above

Comment on lines 524 to 525
if isinstance(embedding, DTensor):
embedding = embedding.full_tensor()
Contributor

same comments as above

@felipemello1 (Contributor) left a comment

To test it, please either use the pretrained version, since it has a different shape for local and global (the other pos embedding is the same), or change the max_num_tiles and patch size in the configs.

@ebsmothers ebsmothers marked this pull request as ready for review October 2, 2024 16:21
Comment on lines +149 to +151
local_embed_device_mesh = inpt_local_pos_embed.device_mesh
local_embed_placements = inpt_local_pos_embed.placements
inpt_local_pos_embed = inpt_local_pos_embed.full_tensor()
Contributor

I thought that we had to use the device_mesh and placements from self.local_token_positional_embedding.

Contributor Author

I think they should be the same, which is why it was working before. But in my mind this is the more "correct" thing to do: we apply some operation to a DTensor, then restore it to its original state after the fact.

Contributor

Will this increase the chance of OOMs, or is the pos_embed small enough that this is not a concern?

Contributor Author

@RdoubleA well as of today it doesn't work so I guess the current chance of OOM is NaN. This is supposed to be a no-op for single device (hence wrapping everything in isinstance(..., DTensor) checks), so no memory implications there.

@codecov-commenter commented Oct 2, 2024

Codecov Report

Attention: Patch coverage is 4.00000% with 24 lines in your changes missing coverage. Please review.

Project coverage is 25.75%. Comparing base (10b02e0) to head (f98351a).
Report is 9 commits behind head on main.

Files with missing lines                        Patch %   Lines
torchtune/models/clip/_position_embeddings.py   4.00%     24 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1739      +/-   ##
==========================================
- Coverage   25.99%   25.75%   -0.24%     
==========================================
  Files         304      305       +1     
  Lines       15627    15901     +274     
==========================================
+ Hits         4062     4096      +34     
- Misses      11565    11805     +240     


@felipemello1 (Contributor)

Thanks Evan! I am a bit confused: I am not sure how single device is a good test, since it doesn't require FSDP.

I think that the best test is to run the pretrained model on a single device, save the interpolated embeddings (I shared them with you), run it distributed, and confirm that the interpolated embeddings are the same as those produced on a single device.

@ebsmothers (Contributor Author)

@felipemello1 the purpose of single device is to show that we haven't broken anything there. I can run the test you're describing as well

@felipemello1 (Contributor)

The unit tests should cover single device. I added this to the functions, where I saved with single device and loaded with distributed. It passed. We are good to merge:

# Saved during the single-device run, then compared against the distributed run:
# torch.save(inpt_local_pos_embed, "inpt_local_pos_embed.pt")
inpt_local_pos_embed_single = torch.load("inpt_local_pos_embed.pt", map_location=torch.device("cpu"))
inpt_local_pos_embed_single = inpt_local_pos_embed_single.to(inpt_local_pos_embed.device)
assert torch.allclose(inpt_local_pos_embed, inpt_local_pos_embed_single, rtol=1e-03, atol=1e-03)
print("passed local")

# torch.save(inpt_global_pos_embed, "inpt_global_pos_embed.pt")
inpt_global_pos_embed_single = torch.load("inpt_global_pos_embed.pt", map_location=torch.device("cpu"))
inpt_global_pos_embed_single = inpt_global_pos_embed_single.to(inpt_global_pos_embed.device)
assert torch.allclose(inpt_global_pos_embed, inpt_global_pos_embed_single, rtol=1e-03, atol=1e-03)
print("passed global")

@RdoubleA RdoubleA merged commit 7cf656b into pytorch:main Oct 2, 2024
17 checks passed