Fix CLIP pos embedding interpolation to work on DTensors #1739
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1739
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit f98351a with merge base fc0249d.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
local_device = inpt_local_pos_embed.device
if isinstance(inpt_local_pos_embed, DTensor):
    inpt_local_pos_embed = inpt_local_pos_embed.full_tensor()
Delete "local_device" since it's not used.
Put this inside of the "if inpt_local_pos_embed is not None:" block.
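For illustration, a minimal sketch of how that restructure could look, factored into a hypothetical helper (not the actual torchtune code; the name _gather_if_present is made up):

```python
from torch.distributed._tensor import DTensor


def _gather_if_present(embed):
    """Gather a possibly-sharded positional embedding into a full tensor.

    The unused local_device is dropped, and nothing happens when the embedding
    is missing or is already a plain tensor, so single device stays a no-op.
    """
    if embed is not None and isinstance(embed, DTensor):
        return embed.full_tensor()
    return embed
```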
@@ -159,6 +164,13 @@ def _load_state_dict_hook(
            tgt_patch_grid_size=int(math.sqrt(tgt_n_tokens_per_tile - 1)),
        )

        if isinstance(inpt_local_pos_embed, DTensor):
inpt_local_pos_embed is not a DTensor here, because it's the full tensor now.
We need to check if self.local_token_positional_embedding is a DTensor. We should probably do the same in the previous check (there you check if inpt_local_pos_embed is a DTensor).
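A hedged sketch of what that second check could look like, factored as a helper (names are illustrative, not the merged code):

```python
import torch
from torch.distributed._tensor import DTensor, distribute_tensor


def _match_param_sharding(interpolated: torch.Tensor, param: torch.Tensor) -> torch.Tensor:
    """Re-shard the interpolated embedding to match the module's parameter.

    The check is on the module parameter (e.g.
    self.local_token_positional_embedding), not on the state-dict tensor,
    because the latter has already been gathered with full_tensor().
    """
    if isinstance(param, DTensor):
        return distribute_tensor(interpolated, param.device_mesh, param.placements)
    return interpolated
```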
local_device = inpt_local_pos_embed.device
if isinstance(inpt_local_pos_embed, DTensor):
    inpt_local_pos_embed = inpt_local_pos_embed.full_tensor()

if inpt_local_pos_embed is not None:
We should probably add here:
if inpt_local_pos_embed is not None and inpt_local_pos_embed.shape != self.local_token_positional_embedding.shape
But testing becomes a bit trickier. Maybe for now it's better to not add it until testing is completed.
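For reference, the guard being discussed could be factored as a small helper like this (illustrative only; per the discussion below, it was ultimately left out of this PR):

```python
from typing import Optional

import torch


def _needs_resize(ckpt_embed: Optional[torch.Tensor], param: torch.Tensor) -> bool:
    """Return True only when the checkpoint embedding exists and its shape
    differs from the module parameter, i.e. when interpolation is needed."""
    return ckpt_embed is not None and ckpt_embed.shape != param.shape
```

It would be used as if _needs_resize(inpt_local_pos_embed, self.local_token_positional_embedding): before running the interpolation.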
Yeah, will leave it out for now.
I think we should add it before the PR is finalized, though. What do you think?
Why do you say testing becomes trickier?
After thinking about it more, I don't think we should add it. The DTensor fix resolves the issue, and there's no need to add extra logic on top of it that wasn't present before.
global_device = inpt_global_pos_embed.device
if isinstance(inpt_global_pos_embed, DTensor):
    inpt_global_pos_embed = inpt_global_pos_embed.full_tensor()
same comments as above
if isinstance(embedding, DTensor):
    embedding = embedding.full_tensor()
same comments as above
To test it, please either use the pretrained version, since it has a different shape for the local and global pos embeddings (the other pos embedding is the same), or change the max_num_tiles and patch size in the configs.
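A tiny self-contained illustration of why those setups matter: the hook only has work to do when the checkpoint's embedding shape differs from the model's (the shapes below are made up for the example):

```python
import torch

# Made-up shapes: a "pretrained" pos embedding vs. a freshly initialized one.
ckpt_pos_embed = torch.randn(4, 1601, 1280)
model_pos_embed = torch.randn(1, 101, 1280)

# The interpolation hook only kicks in when these differ; identical shapes
# would make the resize a no-op and would not exercise the DTensor handling.
assert ckpt_pos_embed.shape != model_pos_embed.shape
```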
local_embed_device_mesh = inpt_local_pos_embed.device_mesh
local_embed_placements = inpt_local_pos_embed.placements
inpt_local_pos_embed = inpt_local_pos_embed.full_tensor()
I thought that we had to use the device_mesh and placements from self.local_token_positional_embedding.
I think they should be the same, which is why it was working before. But in my mind this is the more "correct" thing to do: we apply some operation to a DTensor, then restore it to its original state after the fact.
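To make the round-trip concrete, here is a minimal sketch of the gather → modify → restore pattern being described (the helper name and resize_fn argument are illustrative, assuming we restore to the input tensor's original mesh and placements):

```python
from torch.distributed._tensor import DTensor, distribute_tensor


def _resize_preserving_sharding(embed, resize_fn):
    """Apply resize_fn to a possibly-sharded tensor and restore its original
    DTensor layout afterwards. Plain tensors pass straight through, so this
    is a no-op on single device."""
    if not isinstance(embed, DTensor):
        return resize_fn(embed)
    # Remember the original layout, gather, operate, then re-shard.
    mesh, placements = embed.device_mesh, embed.placements
    full = embed.full_tensor()
    return distribute_tensor(resize_fn(full), mesh, placements)
```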
Will this increase the chances of OOMs, or is the pos_embed small enough that this is not a concern?
@RdoubleA well, as of today it doesn't work, so I guess the current chance of OOM is NaN. This is supposed to be a no-op for single device (hence wrapping everything in isinstance(..., DTensor) checks), so no memory implications there.
Codecov Report
Attention: Patch coverage is

Additional details and impacted files
@@           Coverage Diff            @@
##             main    #1739      +/-   ##
==========================================
- Coverage   25.99%   25.75%   -0.24%
==========================================
  Files         304      305       +1
  Lines       15627    15901     +274
==========================================
+ Hits         4062     4096      +34
- Misses      11565    11805     +240

View full report in Codecov by Sentry.
Thanks Evan! I am a bit confused: I am not sure how single device is a good test, since it doesn't require FSDP. I think that the best test is to run the pretrained model on single device, save the interpolated embeddings (I shared them with you), run it in distributed, and confirm that the interpolated embeddings are the same as those produced in single device.
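A rough sketch of that comparison (the file names are placeholders, not files from this PR):

```python
import torch

# Interpolated embeddings saved from the single-device run vs. the distributed
# run; the file names here are placeholders.
single = torch.load("pos_embed_single_device.pt", map_location="cpu")
distributed = torch.load("pos_embed_distributed.pt", map_location="cpu")
torch.testing.assert_close(single, distributed)
```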
@felipemello1 the purpose of single device is to show that we haven't broken anything there. I can run the test you're describing as well.
The unit tests should cover single device. I added this to the functions, where I saved with single device and loaded with distributed. It passed. We are good to merge.
Currently our state dict hooks to resize CLIP positional embeddings do not work with FSDP2. This is because some of the ops used by interpolate (e.g. aten.upsample_bilinear2d.default) are not natively defined on DTensor, and, as can be seen here, we shard the state dict tensors into DTensors before calling load_state_dict (where the interpolation hook is triggered). Until these ops are defined on DTensors, we need to do something like this.

This PR adds logic to the CLIP positional embedding state dict hooks to handle the case that tensors in the state dict are already sharded. For each relevant tensor, we:

1. Check whether it is a DTensor. If it is, save its device mesh and placements and gather it with .full_tensor(). If it isn't, we don't do anything.
2. Run the interpolation on the full tensor.
3. If the tensor was originally a DTensor, re-shard the result with distribute_tensor using the original mesh and placements. Otherwise we don't do anything.
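As a rough illustration of that per-tensor handling (not the exact torchtune implementation; the bilinear call is a simplified stand-in for the real CLIP pos-embedding interpolation, and it assumes the embedding is already laid out as (N, C, H, W)):

```python
import torch
import torch.nn.functional as F
from torch.distributed._tensor import DTensor, distribute_tensor


def _interpolate_state_dict_embed(embed: torch.Tensor, tgt_size) -> torch.Tensor:
    """Sketch of the handling applied to one positional-embedding tensor.

    1. If it is a DTensor, remember its mesh/placements and gather it, since
       ops like aten.upsample_bilinear2d.default are not defined on DTensor.
    2. Interpolate the full tensor.
    3. If it was a DTensor, re-shard the result with distribute_tensor using
       the original mesh and placements.
    """
    was_dtensor = isinstance(embed, DTensor)
    if was_dtensor:
        mesh, placements = embed.device_mesh, embed.placements
        embed = embed.full_tensor()

    # Simplified stand-in: resize an (N, C, H, W) tensor bilinearly.
    embed = F.interpolate(embed, size=tgt_size, mode="bilinear", align_corners=True)

    if was_dtensor:
        embed = distribute_tensor(embed, mesh, placements)
    return embed
```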
Test plan:

Compare loss curves on the recipe before and after the fix on the base model. First, update 11B_lora_single_device.yaml to match this:
Before fix (on main):
After fix (this PR):
Test that the distributed recipe now runs: change the 11B_lora config to use the base model (here). Then run: