-
Notifications
You must be signed in to change notification settings - Fork 27.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Generate: add model class validation #18902
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for fixing!
src/transformers/generation_utils.py
Outdated
if not hasattr(self, "prepare_inputs_for_generation"): | ||
model_class = self.__class__.__name__ | ||
raise TypeError( | ||
f"The current model class ({model_class}) is not compatible with `.generate()`, as it doesn't have a ", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(nit) Wondering if it could make sense to make a really nifty error message here where we check if any class of the model type of the given class in in one of the auto classes below - maybe over-engineered though
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sounds good! Wondering if it could make sense to make the error message a bit more detailed E.g. if someone uses BertModel
and gets an error that one should use one of the following AutoModel classes, I'm not sure how helpful this is). Could be super cool that if one uses BertModel
and gets as an answer that BertForCausalLM
supports generate
it could be a nicer UX. But maybe over engineered.
The documentation is not available anymore as the PR was closed or merged. |
@sgugger @patrickvonplaten I've requested a re-review of this PR. As per @patrickvonplaten's suggestion, the PR was upgraded to contain the exact class the user should use in the exception (as opposed to pointing to all generate-compatible auto classes). In the process of building it, I've noticed that the previous version of this PR was incorrect anyways -- PT and TF had a default Here's an example with the current version of the PR:
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good to me! I haven't tested thoroughly but from a few quick tests it seems to work well.
Thanks @gante!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh missed the notification that re-requested the review. Still LGTM!
What does this PR do?
Fixes #18210
This PR adds model class validation at the start of generate (all model classes inherit
GenerationMixin
, but few can usegenerate()
). It also adds an exception that attempts to redirect the users to the right classes.