Setup loss_type in config at model init time #34616
Conversation
Ensures no additional graph break is introduced when the model is run under torch.compile.
Fixes huggingface#34615
Signed-off-by: ChanderG <mail@chandergovind.org>
If this approach is fine, I will extend the PR to cover a wider set of models.
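For context, the problem is that resolving the loss function from the class name with `re.findall` inside `forward` forces `torch.compile` to break the graph at that point, whereas doing the lookup once at init keeps the hot path regex-free. A minimal sketch of that idea, assuming an illustrative `LOSS_MAPPING` registry and a hypothetical `MyModelForCausalLM` head (not the actual transformers code):

```python
import re
from torch import nn

# Illustrative loss registry keyed by head-name suffix; a stand-in for the
# library's real mapping, used only for this sketch.
LOSS_MAPPING = {"ForCausalLM": nn.CrossEntropyLoss()}


class MyModelForCausalLM(nn.Module):
    """Hypothetical model head used to illustrate the init-time lookup."""

    def __init__(self, hidden_size=16, vocab_size=32):
        super().__init__()
        self.lm_head = nn.Linear(hidden_size, vocab_size)
        # Resolve the loss type once, at init time, instead of in forward().
        matches = re.findall(r"For([A-Za-z]+)", self.__class__.__name__)
        self.loss_type = f"For{matches[0]}" if matches else None

    def forward(self, hidden_states, labels=None):
        logits = self.lm_head(hidden_states)
        loss = None
        if labels is not None:
            # No regex in the hot path, so torch.compile can keep a single graph here.
            loss_fct = LOSS_MAPPING[self.loss_type]
            loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
        return loss, logits
```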
Hi @ChanderG! Thanks for opening an issue and PR! Nice catch with the graph breaks!
Re the PR: it is usually not the best decision to modify a class attribute in place; however, this pattern is widely used across the classification models' loss computations:
```python
if self.num_labels == 1:
    self.config.problem_type = "regression"
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
    self.config.problem_type = "single_label_classification"
else:
    self.config.problem_type = "multi_label_classification"
```
Do you have other ideas on how to handle it without modifying the config attribute? Is there a way to fix the `re.findall` call instead of introducing the change for all models?
Alternatively, we could have a model attribute, something like
`self.loss_type = "ForCausalLM" if self.config.loss_type is None else self.config.loss_type`
but I don't feel strongly about this :)
A more generic option may be to run the regex call at model init time, instead of hardcoding. I need to check, but it should still be efficient placed there (i.e., it won't cause a graph break, since it only runs from init).
Hey! We should rather do this in the call to `super().__init__(config)`, as otherwise we have to monkey-patch each and every model!
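To illustrate what doing this once in the shared base class could look like, here is a minimal sketch; the `BasePreTrainedModel` class and `LOSS_MAPPING` registry below are placeholders for this example, not the actual transformers implementation:

```python
import re
from torch import nn

# Placeholder registry of supported loss types; stands in for whatever mapping
# the library actually keeps.
LOSS_MAPPING = {"ForCausalLM": None, "ForSequenceClassification": None}


class BasePreTrainedModel(nn.Module):
    """Hypothetical shared base class: every model calls super().__init__(config)."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        # Resolve loss_type once here, so no per-model change is needed and the
        # regex never runs inside a compiled forward pass.
        loss_type = getattr(config, "loss_type", None)
        if loss_type is None:
            matches = re.findall(r"For([A-Za-z]+)", self.__class__.__name__)
            loss_type = f"For{matches[0]}" if matches else None
        self.loss_type = loss_type if loss_type in LOSS_MAPPING else None
```

Because every concrete model already goes through the base class constructor, this covers the whole model zoo with a single change.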
Signed-off-by: ChanderG <mail@chandergovind.org>
@ArthurZucker Updated to a generic lookup. Tested with the repro here, #34615 (comment), and there are no graph breaks with this latest version as well.
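For anyone wanting to double-check such a fix locally, one way to count graph breaks is sketched below; the model and inputs are placeholders, and note that the exact return type of `torch._dynamo.explain` has varied across PyTorch releases. Running the repro with `TORCH_LOGS="graph_breaks"` is an alternative way to surface the same information.

```python
import torch


def count_graph_breaks(model, example_inputs):
    # torch._dynamo.explain traces the callable and reports where it falls back
    # to eager; a clean fix should report zero graph breaks.
    explanation = torch._dynamo.explain(model)(**example_inputs)
    print(explanation)
    return explanation.graph_break_count
```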
Waiting on this to put it in the upcoming release!! 🤗
Okay, this is failing tests because the …
Pushed directly, hope it's not a problem 🤗 wanted to get this fix in fast!
Thanks for your patience! I thought this had been merged a while ago!