Encoder-decoder models: move embedding scale to nn.Module #30410
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
LGTM, thank you for fixing 🙌
There are more models where this inconsistency happens (e.g. MVP, NllbMoe, ...), would you be able to propagate the pattern?
@zucchini-nlp also, can you add a test covering this? We'll tag the core maintainer after we ensure there's a test!
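For context, a minimal sketch of what such an equivalence check could look like; `check_inputs_embeds_matches_input_ids` and `ToyTextModel` are hypothetical names, and real encoder-decoder models would also need decoder inputs and return model-output objects rather than plain tensors:

```python
import torch
import torch.nn as nn


def check_inputs_embeds_matches_input_ids(model: nn.Module, input_ids: torch.Tensor):
    """Forward the same tokens once as `input_ids` and once as pre-computed
    `inputs_embeds`, then check that the outputs match."""
    model.eval()
    with torch.no_grad():
        out_ids = model(input_ids=input_ids)
        inputs_embeds = model.get_input_embeddings()(input_ids)
        out_embeds = model(inputs_embeds=inputs_embeds)
    torch.testing.assert_close(out_ids, out_embeds)


class ToyTextModel(nn.Module):
    """Toy model whose forward accepts either `input_ids` or `inputs_embeds`."""

    def __init__(self, vocab_size: int = 100, hidden_size: int = 16):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        self.proj = nn.Linear(hidden_size, hidden_size)

    def get_input_embeddings(self):
        return self.embed_tokens

    def forward(self, input_ids=None, inputs_embeds=None):
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        return self.proj(inputs_embeds)


check_inputs_embeds_matches_input_ids(ToyTextModel(), torch.tensor([[1, 2, 3]]))
```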
Hmm, okay, I thought the test for generation with `inputs_embeds` would be enough. There is no such test, yes. If that's needed I can add it, and that may trigger handling other models/cases if there are peculiarities :(
Sorry for giving you tons of extra work here 😅 The result will be very cool in the long run, though 💪
Seems like it's ready for the core maintainer's review. I reverted the previous commit with skips and verified that checking the signature works for most cases. Where it does not, model-specific test skipping is kept.
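(For illustration, a sketch of the signature-check idea; `accepts_inputs_embeds` and `ToyModel` are hypothetical names, not the code used in the common test.)

```python
import inspect

import torch.nn as nn


def accepts_inputs_embeds(model: nn.Module) -> bool:
    """Return True if `inputs_embeds` appears in the model's forward signature."""
    return "inputs_embeds" in inspect.signature(model.forward).parameters


class ToyModel(nn.Module):
    """Stand-in for a real model; only the forward signature matters here."""

    def forward(self, input_ids=None, inputs_embeds=None):
        return input_ids if inputs_embeds is None else inputs_embeds


print(accepts_inputs_embeds(ToyModel()))  # True
```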
@unittest.skip(reason="""Bridge Tower does not have input/output embeddings. Thus this test is not applicable.""")
def test_inputs_embeds_matches_input_ids(self):
    pass
Shouldn't this (and most other skips) be caught in the skip conditions of the main test? What's missing?
Some models have all the pieces implemented but do not use `inputs_embeds` anywhere, even though they accept it in `forward`. I did not remove unused args like this for backward compatibility, but they can be cleaned up in another PR if needed.
IMO -- instead of adding custom test code to handle a silent failure, let's properly handle the failure (if the user passes `inputs_embeds` to `forward`, raise a `NotImplementedError`, which can be caught in the test).
You've opened a Pandora's box of work in this PR, but it will be great in the long run! 🤗
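A minimal sketch of that suggestion, assuming a toy model (`ToyEncoder` is a hypothetical class, not code from this PR): the model keeps `inputs_embeds` in its signature but raises instead of silently ignoring it.

```python
import torch
import torch.nn as nn


class ToyEncoder(nn.Module):
    """Toy stand-in for a model that accepts `inputs_embeds` for API consistency
    but has no code path that actually uses it."""

    def __init__(self, vocab_size: int = 100, hidden_size: int = 16):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)

    def forward(self, input_ids=None, inputs_embeds=None):
        if inputs_embeds is not None:
            # Fail loudly instead of silently ignoring the argument; the test
            # suite (or the user) can then handle this explicitly.
            raise NotImplementedError(
                f"`inputs_embeds` is not supported for {self.__class__.__name__}."
            )
        return self.embed_tokens(input_ids)


model = ToyEncoder()
hidden = model(input_ids=torch.tensor([[1, 2, 3]]))  # works
# model(inputs_embeds=hidden)                        # would raise NotImplementedError
```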
Hmm, so in cases where the model accepts `inputs_embeds` but does not use it, should we remove the `get_input_embeddings()` method? Or raise an error inside `forward` when `inputs_embeds` is passed?
I would like to entirely remove `inputs_embeds` from the `forward` argument list, but that leads to my question: should we have the same list of arguments in all `forward` methods in all models, regardless of usage? And would it break anything?
should we have the same list of arguments in all forwards in all models
Not necessarily, e.g. vision models don't accept `input_ids`. All models which are grouped together should have a common subset of inputs, which enables a full forward pass of the model. Most important is that the inputs have standardised names, e.g. we don't have `token_ids` for one model and `input_ids` for another, and that their behaviour is consistent, e.g. `inputs_embeds` "means" the same thing across models.
In the case of text models, I'd say yes: if they accept `input_ids`, then they should also accept `inputs_embeds`, and throw an error within the forward pass after `inputs_embeds` is passed.
Okay, thanks for clarifying. I added `NotImplementedError` where possible, leaving skips on some of the tests with explanations.
Btw, `test_inputs_embeds` can now also be cleaned up and all model-specific skips removed. I can do it later in another PR :)
The failing tests pass for me locally and are not related to these changes.
Thanks for fixing this - definitely a behaviour we want.
V. nicely handled and tested - just a few small things to address before merge
tests/test_modeling_common.py (outdated)
inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
pad_token_id = config.pad_token_id if config.pad_token_id is not None else 1
print(inputs.keys())
print(inputs.keys())
We shouldn't have any print statements in tests
tests/test_modeling_common.py (outdated)
pad_token_id = config.pad_token_id if config.pad_token_id is not None else 1
print(inputs.keys())
try:
We shouldn't have try/except patterns in tests: either we're deliberately triggering an error or we're not. Raising exceptions is for within code, allowing us to elegantly handle errors at runtime. Instead, models which don't use `inputs_embeds` should explicitly skip this test, and the try/except block should be removed.
Hmm, I thought we were adding all the `NotImplementedError`s to get rid of many copies of skipped tests, as discussed with @gante above:
IMO -- instead of adding custom test code to handle a silent failure, let's properly handle the failure (if the user passes inputs_embeds to forward, raise a NotImplementedError, which can be caught in the test)
I can bring back all the skips if that's needed for test consistency, but if we are specifically catching only `NotImplementedError`, isn't it okay?
Ah, OK, sorry, I missed the bit about handling in the tests.
We should raise `NotImplementedError` on the model side, but let users handle that however they want, and then explicitly skip in the tests using `unittest.skip` for the specific models. This avoids accidentally skipping because a different `NotImplementedError` is raised. It's true we end up with more code, but it's better for tests to be DAMP, i.e. very clear and explicit.
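A sketch of that pattern, with a hypothetical model and an illustrative skip reason (in the real suite the test class would also inherit the shared model tester mixin):

```python
import unittest


class ToyModelTest(unittest.TestCase):
    # Explicit, model-specific skip of the common test, rather than a
    # try/except inside the common test itself.
    @unittest.skip(reason="ToyModel accepts `inputs_embeds` but its forward raises NotImplementedError.")
    def test_inputs_embeds_matches_input_ids(self):
        pass


if __name__ == "__main__":
    unittest.main()
```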
hehe the opposite of what I've been telling @zucchini-nlp in this PR 🙈 My bad :D
Not at all - the test suite is one of the most inconsistent places in our codebase 😬
Ah, okay, TIL about the DAMP principle
For the musicgen tests, a fix was pushed to main; rebasing should resolve them.
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
@amyeroberts done, all the models raising `NotImplementedError` are skipped in their modeling test files and the CI is green.
Beautiful - thanks for fixing this tricky issue and making sure it's well tested!
What does this PR do?
This PR moves the embedding scale into an `nn.Module` in encoder-decoder models, so that users who pass `inputs_embeds` to `forward` get the same results as if they had passed `input_ids`, where `inputs_embeds = model.get_input_embeddings()(input_ids)`.
Generation from embeddings is not supported for these models, which is why the inconsistency went unnoticed. I don't think we need specific tests here; I can add support for generation from embeddings in another PR if needed.
All the tests (+slow) for the changed models pass on my end.
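A minimal sketch of the pattern, assuming a hypothetical `ScaledWordEmbedding` module (the classes actually touched by the PR may differ): with the scale applied inside the embedding module, `get_input_embeddings()(input_ids)` already returns scaled embeddings, so the `inputs_embeds` path matches the `input_ids` path.

```python
import math

import torch
import torch.nn as nn


class ScaledWordEmbedding(nn.Embedding):
    """Embedding layer that applies `embed_scale` internally, so the parent
    model no longer scales the embeddings after the lookup."""

    def __init__(self, num_embeddings: int, embedding_dim: int, embed_scale: float = 1.0):
        super().__init__(num_embeddings, embedding_dim)
        self.embed_scale = embed_scale

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        return super().forward(input_ids) * self.embed_scale


d_model = 16
embed_tokens = ScaledWordEmbedding(100, d_model, embed_scale=math.sqrt(d_model))
input_ids = torch.tensor([[4, 8, 15]])

# Because the scale lives inside the module, embeddings produced via the
# embedding module are already scaled, so passing them as `inputs_embeds`
# matches the `input_ids` path.
scaled = embed_tokens(input_ids)
assert torch.allclose(scaled, embed_tokens.weight[input_ids] * embed_tokens.embed_scale)
```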