Update Full Finetune for MM #1548

Merged: pbontrager merged 15 commits into pytorch:main from mm_recipe on Sep 19, 2024

Conversation

pbontrager (Contributor)

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

Generalize full_finetune_single_device.py to be compatible with Flamingo. The changes in the recipe itself are small, but they required a number of bug fixes for multimodal datasets and fusion models. This PR is a grab bag of the bug fixes that were required to get the recipe working with Flamingo.

Test plan

Please make sure to do each of the following if applicable to your PR. (If you're not sure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.)

  • run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • add unit tests for any new functionality
  • update docstrings for any new or updated methods or classes
  • run unit tests via pytest tests
  • run recipe tests via pytest tests -m integration_test
  • manually run any new or modified recipes with sufficient proof of correctness
  • include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)

UX

If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Example of docstring:


Example in our docs: https://pytorch.org/torchtune/main/tutorials/qat_finetune.html#applying-qat-to-llama3-models

  • I did not change any public API;
  • I have added an example to docs or docstrings;

pytorch-bot (bot) commented on Sep 11, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1548

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit fe1a781 with merge base c5db813:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label (authors need to sign the CLA before a PR can be reviewed) on Sep 11, 2024
@codecov-commenter commented:

Codecov Report

Attention: Patch coverage is 87.12871% with 13 lines in your changes missing coverage. Please review.

Project coverage is 71.25%. Comparing base (221031a) to head (360108c).
Report is 5 commits behind head on main.

Files with missing lines | Patch % | Missing lines
recipes/full_finetune_single_device.py | 0.00% | 9
torchtune/datasets/multimodal/_llava_instruct.py | 75.00% | 1
torchtune/datasets/multimodal/_the_cauldron.py | 90.00% | 1
torchtune/models/flamingo/_transform.py | 0.00% | 1
torchtune/modules/model_fusion/_fusion.py | 92.30% | 1
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1548      +/-   ##
==========================================
- Coverage   73.36%   71.25%   -2.11%     
==========================================
  Files         287      290       +3     
  Lines       14142    14233      +91     
==========================================
- Hits        10375    10142     -233     
- Misses       3767     4091     +324     


@RdoubleA (Contributor) left a comment:

I know we need to expose collate_fn for multimodal, but it feels odd that users need to remember to set a collater (or even know what that is) if they want to train multimodal. can we do an alternative approach, like a multimodal flag, or infer it from the dataset?

@@ -240,10 +242,12 @@ def setup(self, cfg: DictConfig) -> None:

# sampler and dataloader depend on the tokenizer and loss_fn and should be
# setup after both of these are initialized
collate_name = cfg.get("collate_fn", "torchtune.data.padded_collate_sft")
Contributor:

not a consideration for this PR, but I really don't like the proliferation of all these config defaults. It makes the yaml config stray further from being the source of truth because there are hidden parameters. We should figure out some process of adding these to all our configs

Contributor:

If we are planning to make collate_fn configurable from this point on, it might be worth actually updating all our configs to show this. but also not sure if we anticipate this being configured often

pbontrager (Author):

This is a good point. I'd be interested in compiling a list of guidelines for recipes and deciding on what our standard is going forward.

@@ -423,6 +427,7 @@ def _setup_data(
cfg_dataset: DictConfig,
shuffle: bool,
batch_size: int,
collate_fn: str,
Contributor:

There's no documentation anywhere on how to specify the collate_fn. Since it's not surfaced in our configs, it's quite obscure how to customize this. I know we don't have docstrings for these setup methods, but I would say collate_fn should have a quick explanation, especially since it does some dotpath magic.
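For anyone finding this thread later, a rough sketch of how the override is meant to work, assuming the default shown in the diff above (the custom dotpath below is purely hypothetical):

# In the YAML config (optional key; omitting it falls back to the SFT collater):
#   collate_fn: torchtune.data.padded_collate_sft          # default
#   collate_fn: my_project.data.my_multimodal_collate      # hypothetical custom collater
from torchtune.config._utils import _get_component_from_path

cfg = {"collate_fn": "torchtune.data.padded_collate_sft"}  # stand-in for the recipe's DictConfig
collate_name = cfg.get("collate_fn", "torchtune.data.padded_collate_sft")
collate_fn = _get_component_from_path(collate_name)        # resolves the dotpath string to a callable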

else partial(
padded_collate_packed,
),
else padded_collate_packed,
Contributor:

should also explain somewhere that the collate_fn is ignored if packed=True

pbontrager (Author):

I actually don't like this if/else construct here. I think after this we need to do an update and standardization of our padding function offerings.

Contributor:

Yeah I agree, we have a bunch of collate functions now and they are all added in a very ad hoc fashion. I think @RdoubleA's comments still hold though: the change in how we use collate_fn can potentially cause problems for people. Also, I hate to say it, but all our usage of _get_component_from_path is starting to feel like we are just hacking in a registry (not saying we shouldn't do it here btw, I actually think it's the best approach rn).
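To make the packed behavior concrete, a small standalone sketch of the selection logic (the helper name is illustrative, not the recipe's exact code):

from functools import partial
from torchtune.data import padded_collate_packed, padded_collate_sft

def choose_collate(collate_fn, packed: bool, pad_id: int, ignore_idx: int = -100):
    # the config-provided collate_fn only applies to unpacked datasets;
    # packed datasets always use the packed collater and silently ignore collate_fn
    if packed:
        return padded_collate_packed
    return partial(collate_fn, padding_idx=pad_id, ignore_idx=ignore_idx)

# e.g. for an unpacked SFT dataset:
collate = choose_collate(padded_collate_sft, packed=False, pad_id=0)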

for i, message in enumerate(sample[self._column_map["texts"]]):
user_content = [{"type": "text", "content": message["user"]}]
if i == 0:
user_content = img_content + user_content
messages.append(
Message(
role="user",
Contributor:

not every message you loop through here would be a user message?

pbontrager (Author):

Every "message" here is a user/message pair

Contributor:

Just commenting here for convenience, but is the example given in the docstring actually correct? Seems to me like it isn't
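To illustrate the user/assistant pairing described above, a toy sketch of how a Cauldron-style sample maps to messages (the exact image payload format is handled by the transform; the details here are illustrative, not the merged code):

from torchtune.data import Message

sample = {
    "images": ["<PIL.Image>"],  # stand-in for a PIL image
    "texts": [
        {"user": "What is in the image?", "assistant": "A golden retriever."},
        {"user": "What is it doing?", "assistant": "Catching a frisbee."},
    ],
}
img_content = [{"type": "image", "content": img} for img in sample["images"]]
messages = []
for i, message in enumerate(sample["texts"]):
    user_content = [{"type": "text", "content": message["user"]}]
    if i == 0:
        # the image(s) are only attached to the first user turn
        user_content = img_content + user_content
    messages.append(Message(role="user", content=user_content, masked=True))
    messages.append(
        Message(role="assistant", content=[{"type": "text", "content": message["assistant"]}])
    )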


"""
fusion_params = {}
for k, v in model.named_modules():
Contributor:

could you make the for loop variables more descriptive? it's hard to follow what's happening

pbontrager (Author):

This is a direct port of the get_peft_params function, so I'd prefer to keep both of those implementations in sync for now.

Contributor:

The proliferation of my laziness..
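For readers of this thread, a sketch of the same loop with more descriptive names (the merged version deliberately mirrors the adapter-params helper, so this is for illustration only):

from typing import Dict
import torch.nn as nn

def get_fusion_params_verbose(model: nn.Module) -> Dict[str, nn.Parameter]:
    fusion_params = {}
    for module_name, module in model.named_modules():
        # modules register their fusion params by exposing a fusion_params() method
        if hasattr(module, "fusion_params") and callable(module.fusion_params):
            remaining = module.fusion_params()
            for param_name, param in module.named_parameters(recurse=True):
                if param_name in remaining:
                    full_name = f"{module_name}.{param_name}" if module_name else param_name
                    fusion_params[full_name] = param
                    remaining.remove(param_name)
            assert remaining == [], f"Fusion params {remaining} not found in {module_name}"
    return fusion_params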

current_fusion_params.remove(n)
assert (
current_fusion_params == []
), f"Fusion params {current_adapter_params} not converted"
Contributor:

maybe "Fusion params retrieved but not found in model's named parameters?"

elif isinstance(v, torch.Tensor):
batch[k] = v.to(device)
else:
raise AttributeError(
Contributor:

This should be a ValueError or TypeError
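For context, a minimal sketch of the helper being discussed with the suggested exception type (the function name and behavior here are illustrative, not the merged API):

import torch

def batch_to_device(batch: dict, device: torch.device) -> None:
    # Move all tensors in a (possibly nested) batch dict to device, in place.
    for k, v in batch.items():
        if isinstance(v, dict):
            batch_to_device(v, device)  # e.g. a nested encoder_input dict
        elif isinstance(v, torch.Tensor):
            batch[k] = v.to(device)
        else:
            # ValueError rather than AttributeError, per the review comment
            raise ValueError(
                f"Expected tensors or nested dicts of tensors in the batch, "
                f"but got {type(v)} for key '{k}'"
            )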

@joecummings (Contributor) left a comment:

+1 to all Rafi's comments

@@ -34,8 +34,8 @@ Multimodal datasets
:toctree: generated/
:nosignatures:

llava_instruct_dataset
the_cauldron_dataset
multimodal.llava_instruct_dataset
Contributor:

thank u

@@ -37,3 +37,33 @@ def fusion_params(self) -> List[str]:
return [k for k, v in self.named_parameters()]

module.fusion_params = functools.partial(fusion_params, module)


def get_fusion_params(model: nn.Module) -> Dict[str, nn.Parameter]:
Contributor:

What's the purpose of this function? Why would you need it?

pbontrager (Author):

It allows you to control which parameters you want to freeze. It's similar to get_peft_params. For an example of this, look in the DeepFusionModel init.
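A brief usage sketch of the freezing pattern described here (simplified, assuming model is a fused model whose modules register fusion params; not the exact DeepFusionModel code):

# Train only the fusion layers: freeze everything, then re-enable fusion params
fusion_params = get_fusion_params(model)
for name, param in model.named_parameters():
    param.requires_grad = name in fusion_params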

@joecummings (Contributor) left a comment:

How does this deal with the kv cache changes?

(it doesn't)

@ebsmothers (Contributor) left a comment:

No huge concerns from me, just a bunch of small comments

"SFTDataset",
"hh_rlhf_helpful_dataset",
"llava_instruct_dataset",
"multimodal",
Contributor:

Why not just expose the APIs directly in datasets/multimodal/__init__.py? Feels a bit weird to list a folder as a public API like this.

pbontrager (Author):

I tried not including multimodal in this init and the import path stopped working

@@ -100,6 +100,7 @@ def __init__(
self.stop_tokens = self.tokenizer.stop_tokens
self.max_seq_len = max_seq_len
self.prompt_template = prompt_template
self.pad_id = self.tokenizer.pad_id
Contributor:

I know there's not a clear way around it, but I don't love that we have to do this. It'll be a gotcha for anyone trying to write their own transforms

pbontrager (Author):

There is a way around this, but it involves inheriting from the tokenizer.
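To spell out the gotcha: a model transform that wraps a tokenizer has to surface attributes like pad_id itself, since the recipe and collater read them off the transform. A hypothetical minimal transform (all names here are illustrative):

class MyMultimodalTransform:
    def __init__(self, tokenizer, image_transform, max_seq_len=None):
        self.tokenizer = tokenizer
        self.image_transform = image_transform
        self.max_seq_len = max_seq_len
        # forwarded so the recipe can do partial(collate_fn, padding_idx=transform.pad_id)
        self.pad_id = tokenizer.pad_id
        self.stop_tokens = tokenizer.stop_tokens

    def __call__(self, sample, inference=False):
        # tokenize text and transform images here
        ...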

Comment on lines +274 to +275
out = out.masked_scatter(mask, embeds)
out = out.masked_scatter(~mask, fusion_embeds)
Contributor:

Is this cause we need to track the grads?

pbontrager (Author):

It's because using the in-place masked_scatter was giving bad gradients.
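A toy sketch of the out-of-place pattern with made-up shapes, since the in-place masked_scatter_ variant was corrupting gradients here:

import torch

b, s, d = 2, 6, 4
mask = (torch.arange(s) < 4).view(1, s, 1).expand(b, s, 1)  # True = regular token position
embeds = torch.randn(b, 4, d, requires_grad=True)           # embeddings for the 4 regular tokens
fusion_embeds = torch.randn(b, 2, d, requires_grad=True)    # embeddings for the 2 fusion tokens

out = embeds.new_empty(b, s, d)
out = out.masked_scatter(mask, embeds)            # out-of-place: returns a new tensor
out = out.masked_scatter(~mask, fusion_embeds)    # keeps the autograd graph intact
out.sum().backward()                              # grads flow back to both embedding outputs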

Comment on lines +340 to +351
if encoder_trainable:
trainable_params |= {
f"encoder.{n}" for n, p in self.encoder.named_parameters()
}
if decoder_trainable:
trainable_params |= {
f"decoder.{n}" for n, p in self.decoder.named_parameters()
}
if fusion_trainable:
trainable_params |= set(get_fusion_params(self))
else:
trainable_params -= set(get_fusion_params(self))
Contributor:

This assumes a default state of (decoder_trainable, encoder_trainable, fusion_trainable) = (False, False, True), right? It's a bit unintuitive to me that we remove only the fusion params explicitly but not the other ones

pbontrager (Author):

It's because fusion params are not their own separate module but are part of the encoder and decoder: encoder = pretrained_encoder + fusion_params and decoder = pretrained_decoder + fusion_params. So you first add the encoder and/or decoder params. Then you either remove all fusion_params if fusion_trainable = False, or you add them all explicitly, since if the encoder/decoder params weren't added you might otherwise be missing some.


"""
fusion_params = {}
for k, v in model.named_modules():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The proliferation of my laziness..

@@ -261,22 +261,27 @@ def forward(

# A mask of tokens (x) with no encoder_input
skip_mask = self._skip_mask(encoder_mask)
if encoder_mask is not None:
# TODO: remove after PyTorch 2.5 is released
# This unmasks the skipped rows to avoid NaNs in SPDA Softmax backward
Contributor:

Suggested change
# This unmasks the skipped rows to avoid NaNs in SPDA Softmax backward
# This unmasks the skipped rows to avoid NaNs in SDPA Softmax backward

Contributor:

Also is this todo only related to encoder_mask? Or can we remove the usage of skip_mask altogether when 2.5 is released? I had assumed it was the latter

pbontrager (Author):

We still need skip_mask, we just don't need this extra step of updating the encoder_mask. The update would make skip_mask optional in the sense that you wouldn't have to worry about NaNs, but you'd still get different behavior since you're not masking out the ffwd or output in attention.
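A standalone toy version of the workaround being discussed (shapes made up; not the exact merged code):

import torch
import torch.nn.functional as F

b, heads, s_x, s_y, d = 1, 2, 4, 3, 8
q = torch.randn(b, heads, s_x, d, requires_grad=True)
k = torch.randn(b, heads, s_y, d)
v = torch.randn(b, heads, s_y, d)

encoder_mask = torch.zeros(b, s_x, s_y, dtype=torch.bool)
encoder_mask[:, :2, :] = True                        # rows 2 and 3 have no encoder input at all
skip_mask = ~encoder_mask.any(dim=-1, keepdim=True)  # [b, s_x, 1], True where a row is fully masked

# Pre PyTorch 2.5, a fully masked row makes SDPA's softmax backward produce NaNs,
# so unmask those rows for the attention call and zero their output afterwards.
unmasked = encoder_mask.masked_fill(skip_mask, True)
attn = F.scaled_dot_product_attention(q, k, v, attn_mask=unmasked.unsqueeze(1))
attn = attn.masked_fill(skip_mask.unsqueeze(1), 0)   # discard output for skipped rows
attn.sum().backward()                                # no NaNs in q.grad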



logits = self._model(tokens, mask=mask, input_pos=input_pos)
logits = self._model(**batch)
Contributor:

Now that we're fully in dictionary land here, we should think about doing some key validation down the line (or at least giving very clear documentation on what the expected fields are)

pbontrager (Author):

We don't strictly need the dictionary unpacking here since we have a standard input for TransformerDecoder, but I think this is much cleaner and makes future changes easier.
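One possible shape of that validation, as a hypothetical helper (the key names below are assumptions about the batch schema, not an existing API):

_REQUIRED_KEYS = {"tokens"}
_OPTIONAL_KEYS = {"labels", "mask", "input_pos", "encoder_input", "encoder_mask"}

def validate_batch_keys(batch: dict) -> None:
    missing = _REQUIRED_KEYS - batch.keys()
    unknown = batch.keys() - _REQUIRED_KEYS - _OPTIONAL_KEYS
    if missing or unknown:
        raise ValueError(
            f"Unexpected batch schema before model(**batch): missing={missing}, unknown={unknown}"
        )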


@pbontrager merged commit 63208c6 into pytorch:main on Sep 19, 2024
17 checks passed
@pbontrager deleted the mm_recipe branch on September 19, 2024 at 18:06