Adds PreTrainedModel.framework attribute #13817

Conversation
Great, thank you @StellaAthena! I'm not sure I see the full picture of adding that argument, but I'm definitely not opposed to it if it's helpful for your use case. It's more robust than relying on the class name. I believe the property being implemented in PyTorch (and not in Flax and TensorFlow) isn't deliberate: the former was implemented early (two years ago), and the latter was overlooked. For this property in particular ( Thanks for offering a PR!
Thanks for the PR! See transformers/src/transformers/models/bert/modeling_bert.py, lines 1486 to 1492, at 8bbb53e.

This attribute is useful when loading base model weights into a model with a head. And the reason And for this property
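To illustrate the role such a class-level prefix plays when loading base-model weights into a model with a head, here is a minimal sketch. The `strip_prefix` helper and the weight dictionaries are hypothetical, not the actual transformers loading code; only the general pattern is shown.

```python
# Hypothetical sketch: a class-level prefix lets generic loading code
# strip (or add) the base-model prefix when moving weights between a
# bare encoder and a model that wraps it with a task head.

class BaseModel:
    base_model_prefix = "bert"  # assumed value, for illustration only

def strip_prefix(state_dict, prefix):
    """Remove '<prefix>.' from keys so headed-model weights can be
    loaded into a bare base model; keys without the prefix pass through."""
    return {
        (key[len(prefix) + 1:] if key.startswith(prefix + ".") else key): value
        for key, value in state_dict.items()
    }

# Weights saved from a model with a head nest the encoder under the prefix.
headed_weights = {"bert.embeddings.weight": [0.1], "classifier.weight": [0.2]}
base_weights = strip_prefix(headed_weights, BaseModel.base_model_prefix)
print(sorted(base_weights))  # prints ['classifier.weight', 'embeddings.weight']
```

The real loading logic in transformers handles many more cases (missing keys, the reverse direction, and so on); this only shows why the prefix has to live on the class.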
Thank you both for the explanation; it makes understanding why the
When writing code that takes a user-defined
That's an interesting idea. My thought was that this approach would cause it to be encoded in
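The use case discussed above can be sketched as follows. The `describe_framework` helper and `DummyPyTorchModel` class are purely illustrative (not part of transformers), and the `"pt"`/`"tf"`/`"flax"` string values are assumptions about the attribute's convention.

```python
# Hypothetical sketch: code that accepts a user-supplied model can branch
# on a `framework` attribute instead of inspecting the class name.

def describe_framework(model):
    """Return a human-readable framework name, falling back gracefully
    when the attribute is absent."""
    framework = getattr(model, "framework", None)
    names = {"pt": "PyTorch", "tf": "TensorFlow", "flax": "Flax"}
    return names.get(framework, "unknown")

class DummyPyTorchModel:
    framework = "pt"  # assumed string convention, for illustration

print(describe_framework(DummyPyTorchModel()))  # prints "PyTorch"
print(describe_framework(object()))             # prints "unknown"
```

Dispatching on an explicit attribute is more robust than parsing class names such as `TFBertModel`, which is the motivation given earlier in the thread.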
@patil-suraj I have updated the code to follow your suggestion. The failing tests seem to be caused by an indentation error that I cannot work out. I even copied an existing function rather than writing my own, in case there was something funky about how my keyboard was registering keystrokes! Edit: it looks like I was being fooled by a misleading error message! Changing
Thanks for the explanation and for working on this @StellaAthena, this looks good to me! There's an issue with code quality; do you mind running the code quality script? You can do so like this, from the root of your transformers clone:

pip install -e .[quality]
make fixup
LGTM!
@LysandreJik @patil-suraj I don't think I can do any more. I'm having trouble installing Jax, which may be the blocker? IDK. The image below shows me running
Hi @StellaAthena, no problem. I could take care of this; would it be okay if I push to your branch?
Absolutely! Thanks
Thanks for working on the PR @StellaAthena!
What does this PR do?
This PR introduces an attribute called framework in PreTrainedModel, FlaxPreTrainedModel, and TFPreTrainedModel. The purpose of this attribute is to let a user know which framework a given model is implemented in, as that information is not currently very accessible.

I'm a little confused as to whether this is correctly implemented. I based it on the implementation of base_model_prefix, which doesn't have a getattr in FlaxPreTrainedModel and TFPreTrainedModel, despite those not (AFAICT) inheriting from PreTrainedModel.

Who can review?
@patil-suraj @LysandreJik
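For readers following along, the core idea of the PR can be sketched in a few lines. The classes below are simplified stand-ins (the real base classes live in transformers and carry far more functionality), and the `"pt"`/`"tf"`/`"flax"` values are the assumed framework identifiers.

```python
# Illustrative stand-ins for the three transformers base classes; only
# the new class-level attribute is shown.

class PreTrainedModel:
    framework = "pt"  # PyTorch models

class TFPreTrainedModel:
    framework = "tf"  # TensorFlow models

class FlaxPreTrainedModel:
    framework = "flax"  # Flax/JAX models

# A concrete model inherits the attribute from its framework base class,
# so callers can query it uniformly without inspecting class names.
class BertModel(PreTrainedModel):
    pass

print(BertModel.framework)  # prints "pt"
```

Because the attribute is defined on each framework's base class, every model in that framework exposes it automatically, which is exactly the discoverability the PR description asks for.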