Update Full Finetune for MM #1548
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1548
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit fe1a781 with merge base c5db813.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Codecov Report
Attention: Patch coverage is

Additional details and impacted files

@@            Coverage Diff             @@
##             main    #1548      +/-   ##
==========================================
- Coverage   73.36%   71.25%   -2.11%
==========================================
  Files         287      290       +3
  Lines       14142    14233      +91
==========================================
- Hits        10375    10142     -233
- Misses       3767     4091     +324

☔ View full report in Codecov by Sentry.
I know we need to expose collate_fn for multimodal, but it feels odd that users need to remember to set a collator (or even know what that is) if they want to train multimodal. Can we do an alternative approach, like a multimodal flag, or infer it from the dataset?
@@ -240,10 +242,12 @@ def setup(self, cfg: DictConfig) -> None:

        # sampler and dataloader depend on the tokenizer and loss_fn and should be
        # setup after both of these are initialized
        collate_name = cfg.get("collate_fn", "torchtune.data.padded_collate_sft")
not a consideration for this PR, but I really don't like the proliferation of all these config defaults. It makes the yaml config stray further from being the source of truth because there are hidden parameters. We should figure out some process of adding these to all our configs
If we are planning to make collate_fn configurable from this point on, it might be worth actually updating all our configs to show this. But I'm also not sure if we anticipate this being configured often.
This is a good point. I'd be interested in compiling a list of guidelines for recipes and deciding on what our standard is going forward.
@@ -423,6 +427,7 @@ def _setup_data(
        cfg_dataset: DictConfig,
        shuffle: bool,
        batch_size: int,
        collate_fn: str,
There's no documentation anywhere on how to specify the collate_fn. Since it's not surfaced in our configs, it's quite obscure how to customize this. I know we don't have docstrings for these setup methods, but I would say collate_fn should have a quick explanation, especially since it does some dotpath magic.
-else partial(
-    padded_collate_packed,
-),
+else padded_collate_packed,
should also explain somewhere that the collate_fn is ignored if packed=True
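To make that behavior concrete, here is a minimal sketch of the relevant recipe logic, not the exact shipped code: the local names (collate_name, ds, sampler, batch_size, packed, self._tokenizer), the import paths, and the padding_idx kwarg are assumptions for illustration.

```python
from functools import partial

from torch.utils.data import DataLoader
from torchtune.config._utils import _get_component_from_path  # assumed import path
from torchtune.data import padded_collate_packed  # assumed import path

# Resolve the dotpath string from the config, e.g. "torchtune.data.padded_collate_sft".
collate_fn = _get_component_from_path(collate_name)

dataloader = DataLoader(
    dataset=ds,
    batch_size=batch_size,
    sampler=sampler,
    collate_fn=(
        partial(collate_fn, padding_idx=self._tokenizer.pad_id)  # kwarg name is an assumption
        if not packed
        # when packed=True, the collate_fn from the config is silently ignored
        else padded_collate_packed
    ),
)
```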
I actually don't like this if/else construct here. I think after this we need to do an update and standardization of our padding function offerings.
Yeah I agree, we have a bunch of collate functions now and they are all added in a very ad hoc fashion. I think @RdoubleA's comments still hold though... the change in how we use collate_fn can potentially cause problems for people. Also, I hate to say it, but all our usage of _get_component_from_path is starting to feel like we are just hacking in a registry (not saying we shouldn't do it here btw, I actually think it's the best approach rn).
        for i, message in enumerate(sample[self._column_map["texts"]]):
            user_content = [{"type": "text", "content": message["user"]}]
            if i == 0:
                user_content = img_content + user_content
            messages.append(
                Message(
                    role="user",
not every message you loop through here would be a user message?
Every "message" here is a user/message pair
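For readers following along, a hedged illustration of the assumed sample layout (field values are made up): each entry in the "texts" column is one user/assistant pair, and img_content is prepended only to the first user turn.

```python
sample = {
    "images": [...],  # image data for the sample
    "texts": [
        {"user": "What is shown in the image?", "assistant": "A stop sign."},
        {"user": "What color is it?", "assistant": "Red."},
    ],
}
```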
Just commenting here for convenience, but is the example given in the docstring actually correct? Seems to me like it isn't
    """
    fusion_params = {}
    for k, v in model.named_modules():
could you make the for loop variables more descriptive? it's hard to follow what's happening
This is a direct port of the get_peft_params function, so I'd prefer to keep both of those implementations in sync for now.
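For reference, a hedged sketch of the same loop with more descriptive names (the shipped code keeps k/v to stay in sync with the PEFT version; the exact name-joining logic below is an assumption):

```python
fusion_params = {}
for module_name, module in model.named_modules():
    # Modules opt in by exposing a fusion_params() method that lists their
    # fusion parameter names.
    if hasattr(module, "fusion_params") and callable(module.fusion_params):
        current_fusion_params = module.fusion_params()
        for param_name, param in module.named_parameters(recurse=True):
            if param_name in current_fusion_params:
                full_name = f"{module_name}.{param_name}" if module_name else param_name
                fusion_params[full_name] = param
                current_fusion_params.remove(param_name)
```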
The proliferation of my laziness..
                current_fusion_params.remove(n)
        assert (
            current_fusion_params == []
        ), f"Fusion params {current_adapter_params} not converted"
maybe "Fusion params retrieved but not found in model's named parameters?"
torchtune/utils/_device.py (Outdated)

        elif isinstance(v, torch.Tensor):
            batch[k] = v.to(device)
        else:
            raise AttributeError(
This should be a ValueError or TypeError
+1 to all Rafi's comments
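For reference, a sketch of what the fixed helper could look like. The name batch_to_device and the Tensor branch come from the diff above; the nested-dict branch and the choice of ValueError (rather than TypeError) are assumptions, not the final implementation.

```python
import torch


def batch_to_device(batch: dict, device: torch.device) -> None:
    """Move every tensor in a (possibly nested) batch dict onto `device`, in place."""
    for k, v in batch.items():
        if isinstance(v, dict):
            batch_to_device(v, device)
        elif isinstance(v, torch.Tensor):
            batch[k] = v.to(device)
        else:
            raise ValueError(
                f"Expected all values in the batch to be dicts or Tensors, "
                f"but key '{k}' has type {type(v)}"
            )
```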
@@ -34,8 +34,8 @@ Multimodal datasets
    :toctree: generated/
    :nosignatures:

    llava_instruct_dataset
    the_cauldron_dataset
    multimodal.llava_instruct_dataset
thank u
@@ -37,3 +37,33 @@ def fusion_params(self) -> List[str]:
        return [k for k, v in self.named_parameters()]

    module.fusion_params = functools.partial(fusion_params, module)


def get_fusion_params(model: nn.Module) -> Dict[str, nn.Parameter]:
What's the purpose of this function? Why would you need it?
It allows you to control which parameters you want to freeze. It's similar to get_peft_params. For an example, look in the DeepFusionModel init.
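A hedged usage sketch of what that enables (the import path is an assumption, and the manual requires_grad loop is just one way to apply it):

```python
from torchtune.modules.model_fusion import get_fusion_params  # assumed import path

# Freeze the whole model except the fusion layers, analogous to the
# get_peft_params pattern mentioned above.
fusion_param_names = set(get_fusion_params(model))
for name, param in model.named_parameters():
    param.requires_grad = name in fusion_param_names
```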
How does this deal with the kv cache changes?
(it doesn't)
No huge concerns from me, just a bunch of small comments
    "SFTDataset",
    "hh_rlhf_helpful_dataset",
    "llava_instruct_dataset",
    "multimodal",
Why not just expose the APIs directly in datasets/multimodal/__init__.py? Feels a bit weird to list a folder as a public API like this.
I tried not including multimodal in this init and the import path stopped working
@@ -100,6 +100,7 @@ def __init__(
        self.stop_tokens = self.tokenizer.stop_tokens
        self.max_seq_len = max_seq_len
        self.prompt_template = prompt_template
        self.pad_id = self.tokenizer.pad_id
I know there's not a clear way around it, but I don't love that we have to do this. It'll be a gotcha for anyone trying to write their own transforms
There is a way around this, but it involves inheriting from the tokenizer.
        out = out.masked_scatter(mask, embeds)
        out = out.masked_scatter(~mask, fusion_embeds)
Is this because we need to track the grads?
It's because using the inplace masked_scatter was giving bad gradients
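To spell out the difference (a hedged illustration of the same two lines, not a claim about a specific autograd bug):

```python
# Out-of-place masked_scatter returns a new tensor, leaving the original `out`
# (which autograd may have saved for backward) untouched:
out = out.masked_scatter(mask, embeds)
out = out.masked_scatter(~mask, fusion_embeds)

# The in-place variant that was giving bad gradients here would be:
# out.masked_scatter_(mask, embeds)
# out.masked_scatter_(~mask, fusion_embeds)
```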
        if encoder_trainable:
            trainable_params |= {
                f"encoder.{n}" for n, p in self.encoder.named_parameters()
            }
        if decoder_trainable:
            trainable_params |= {
                f"decoder.{n}" for n, p in self.decoder.named_parameters()
            }
        if fusion_trainable:
            trainable_params |= set(get_fusion_params(self))
        else:
            trainable_params -= set(get_fusion_params(self))
This assumes a default state of (decoder_trainable, encoder_trainable, fusion_trainable) = (False, False, True), right? It's a bit unintuitive to me that we remove only the fusion params explicitly but not the other ones
It's because fusion params are not their own separate module but part of the encoder and decoder: encoder = pretrained_encoder + fusion_params and decoder = pretrained_decoder + fusion_params. So you add the encoder and/or decoder params first. Then you either remove all fusion_params (if fusion_trainable=False) or add them all, since if the encoder/decoder isn't trainable you might otherwise be missing some.
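A small worked illustration of that logic, using hypothetical parameter names:

```python
# Fusion params live *inside* the encoder/decoder, so they ride along with
# whichever of those is marked trainable.
encoder = {"encoder.layer.weight", "encoder.fusion_proj.weight"}
decoder = {"decoder.layer.weight", "decoder.fusion_gate.weight"}
fusion = {"encoder.fusion_proj.weight", "decoder.fusion_gate.weight"}

# decoder_trainable=True, encoder_trainable=False, fusion_trainable=False:
trainable = set() | decoder
trainable -= fusion  # explicitly strip the fusion param that came in via the decoder
assert trainable == {"decoder.layer.weight"}

# encoder/decoder frozen, fusion_trainable=True: the |= branch is what pulls in the
# encoder-side fusion param that no other branch would have added.
trainable = set() | fusion
assert trainable == {"encoder.fusion_proj.weight", "decoder.fusion_gate.weight"}
```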
torchtune/modules/transformer.py (Outdated)

@@ -261,22 +261,27 @@ def forward(

        # A mask of tokens (x) with no encoder_input
        skip_mask = self._skip_mask(encoder_mask)
        if encoder_mask is not None:
            # TODO: remove after PyTorch 2.5 is released
            # This unmasks the skipped rows to avoid NaNs in SPDA Softmax backward
-            # This unmasks the skipped rows to avoid NaNs in SPDA Softmax backward
+            # This unmasks the skipped rows to avoid NaNs in SDPA Softmax backward
Also is this todo only related to encoder_mask? Or can we remove the usage of skip_mask altogether when 2.5 is released? I had assumed it was the latter
We still need skip_mask, we just don't need this extra step of updating the encoder_mask. The update would make skip_mask optional as you wouldn't have to worry about NaN's but you still get different behavior as you're not masking out the ffwd or output in attention.
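A hedged sketch of the pre-2.5 workaround being discussed (the exact tensor ops are assumptions): rows of encoder_mask that are entirely False belong to tokens with no encoder input, and fully masked rows make SDPA's softmax backward produce NaNs, so those rows are temporarily unmasked here while skip_mask zeroes their attention output afterwards.

```python
if encoder_mask is not None:
    # Unmask the rows flagged by skip_mask so softmax backward stays finite.
    encoder_mask = encoder_mask.masked_fill(skip_mask, True)
```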
-logits = self._model(tokens, mask=mask, input_pos=input_pos)
+logits = self._model(**batch)
Now that we're fully in dictionary land here, we should think about doing some key validation down the line (or at least giving very clear documentation on what the expected fields are)
We don't strictly need the dictionary unpacking here since we have a standard input for TransformerDecoder, but I think this is much cleaner and makes future changes easier.
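A hypothetical key-validation sketch for the "down the line" idea above (not part of this PR, and not the recipe's actual contract); it simply checks the batch keys against the model's forward signature before unpacking:

```python
import inspect

allowed = set(inspect.signature(self._model.forward).parameters)
unexpected = set(batch) - allowed
if unexpected:
    raise ValueError(f"Batch contains keys the model forward does not accept: {unexpected}")

logits = self._model(**batch)
```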
Context
What is the purpose of this PR?
Generalize full_finetune_single_device.py to be compatible with Flamingo. The changes in the recipe are small but required a number of bug fixes for multimodal datasets and fusion models. This PR is a grab bag of the bug fixes that were required to get the recipe working with Flamingo.
Test plan
Please make sure to do each of the following if applicable to your PR. (If you're not sure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.)
- pre-commit install
- pytest tests
- pytest tests -m integration_test
UX
If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Example of docstring:
torchtune/torchtune/modules/vision_transformer.py, line 285 (commit 6a7951f)
Example in our docs: https://pytorch.org/torchtune/main/tutorials/qat_finetune.html#applying-qat-to-llama3-models