Adds PreTrainedModel.framework attribute #13817

Merged
merged 20 commits into huggingface:master on Oct 8, 2021

Conversation

StellaAthena
Contributor

What does this PR do?

This PR introduces an attribute called framework in PreTrainedModel, FlaxPreTrainedModel, and TFPreTrainedModel. The purpose of this attribute is to allow a user to know what framework a provided model is in, as that information is not currently very accessible.
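
A rough sketch of how such an attribute could be used from calling code, assuming it returns the short framework identifiers "pt", "tf", and "flax"; the count_parameters helper here is hypothetical and not part of transformers:

from transformers import AutoModel

def count_parameters(model):
    # Hypothetical helper: dispatch on model.framework instead of inspecting the class name.
    if model.framework == "pt":
        return sum(p.numel() for p in model.parameters())
    if model.framework == "tf":
        return model.count_params()
    if model.framework == "flax":
        import jax
        return sum(x.size for x in jax.tree_util.tree_leaves(model.params))
    raise ValueError(f"Unknown framework: {model.framework}")

model = AutoModel.from_pretrained("bert-base-uncased")
print(model.framework)           # expected "pt" for a PyTorch model
print(count_parameters(model))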

I'm a little confused as to whether this is correctly implemented. I based it on the implementation of base_model_prefix, which doesn't have a getattr in FlaxPreTrainedModel and TFPreTrainedModel despite those not (AFAICT) inheriting from PreTrainedModel.

Who can review?

@patil-suraj @LysandreJik

@LysandreJik
Member

Great, thank you @StellaAthena! I'm not sure I see the full picture of adding that argument - but I'm definitely not opposed to it if it's helpful for your use case. It's more robust than relying on the class name.

I believe the fact that the property is implemented in PyTorch (and not in Flax and TensorFlow) isn't deliberate: the former was implemented early (two years ago), and the latter was overlooked.

For this property in particular (framework), I believe having it as a simple attribute should be enough for all three frameworks.
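
For illustration, a plain class attribute on each of the three base classes might look like the following; these are stand-in class definitions, not the real transformers code:

# Stand-ins for the real base classes, shown only to illustrate the idea.
class PreTrainedModel:
    framework = "pt"      # PyTorch models

class TFPreTrainedModel:
    framework = "tf"      # TensorFlow models

class FlaxPreTrainedModel:
    framework = "flax"    # Flax/JAX models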

Thanks for offering a PR!

@patil-suraj
Contributor

Thanks for the PR!

The base_model_prefix serves a different purpose: it indicates the attribute name used for the base module in a model with a specific head on top. For example, the base_model_prefix for BERT is "bert", which the head models use as the attribute name for the base model:

class BertForSequenceClassification(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config
        self.bert = BertModel(config)  # attribute name matches base_model_prefix = "bert"

This attribute is useful when loading base model weights into a model with a head. The reason the base_model property is only added to the PyTorch PreTrainedModel and not to FlaxPreTrainedModel is that in PyTorch it's possible to return a submodule, through which the user can access the base model if needed (for example, to freeze it).
This is not possible in Flax, because Flax modules are stateless: returning base_model would return a reference to the module without weights. Hope this makes it clear.
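
To illustrate that PyTorch-only base_model property, here is a small self-contained sketch; the toy HeadModel and the nn.Linear standing in for BertModel are illustrative, not the actual library code:

import torch.nn as nn

class HeadModel(nn.Module):
    # Toy stand-in for a model with a head, e.g. BertForSequenceClassification.
    base_model_prefix = "bert"

    def __init__(self):
        super().__init__()
        self.bert = nn.Linear(4, 2)  # stand-in for BertModel(config)

    @property
    def base_model(self) -> nn.Module:
        # Look up the submodule named by base_model_prefix; fall back to self for base models.
        return getattr(self, self.base_model_prefix, self)

model = HeadModel()
for param in model.base_model.parameters():  # e.g. freeze the base model
    param.requires_grad = False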

As for the framework property, IMO we could simply add it as a getter property that returns the framework string. Adding it just as a getter will also prevent users from accidentally setting it.
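
A minimal sketch of that getter-only version, again with a stand-in class and assuming "pt" as the returned string for PyTorch:

class PreTrainedModel:  # stand-in for the real base class
    @property
    def framework(self) -> str:
        # Getter only: there is no setter, so assignment raises AttributeError.
        return "pt"

model = PreTrainedModel()
print(model.framework)    # "pt"
# model.framework = "tf"  # AttributeError: can't set attribute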

@StellaAthena
Contributor Author

Thank you both for the explanation; it helps me understand why the transformers code is the way it is.

Great, thank you @StellaAthena! I'm not sure I see the full picture of adding that argument - but I'm definitely not opposed to it if it's helpful for your use case. It's more robust than relying on the class name.

When writing code that takes a user-defined transformers model as an input, there are a lot of weird gotchas. The impetus for this PR was my attempt to generalize Google's BIG Bench to work with arbitrary transformer models, but I suspect it'll also be useful to EleutherAI's LM Eval Harness and other similar projects. Unfortunately, there are important properties of models that are impossible to derive from the config file. Another example of this is the fact that some tokenizers automatically append tokens to the end of generations while others do not.

As for the framework property, IMO we could simply add it as a getter property that returns the framework string. Adding it just as a getter will also prevent users from accidentally setting it.

That's an interesting idea. My thought was that this approach would cause the value to be encoded in config files, which seems like a good practice to follow.

@StellaAthena
Contributor Author

StellaAthena commented Oct 4, 2021

@patil-suraj I have updated the code to follow your suggestion. The failing tests seem to have to do with an indentation error that I cannot work out. I even copied an existing function rather than write my own, in case there was something funky about how my keyboard was registering!

Edit: it looks like I was being fooled by a misleading error message! Changing string to str solved the problem.

StellaAthena marked this pull request as ready for review October 4, 2021 03:46
@LysandreJik left a comment
Member

Thanks for the explanation and for working on this @StellaAthena, this looks good to me! There's an issue with code quality; do you mind running the code quality script? You can do so like this, from the root of your transformers clone:

pip install -e .[quality]
make fixup

@patil-suraj left a comment
Contributor

LGTM!

@StellaAthena
Contributor Author

@LysandreJik @patil-suraj I don't think I can do any more. I'm having trouble installing Jax, which may be the blocker? IDK. The below image shows me running make fixup and then the verification test that the readout asks me to run.

[Screenshot: Screen Shot 2021-10-07 at 3 21 29 PM, showing the output of make fixup and the follow-up verification command]

@patil-suraj
Contributor

Hi @StellaAthena, no problem. I could take care of this; would it be okay if I push to your branch?

@StellaAthena
Contributor Author

Hi @StellaAthena, no problem. I could take care of this; would it be okay if I push to your branch?

Absolutely! Thanks

patil-suraj merged commit de34481 into huggingface:master on Oct 8, 2021
@LysandreJik
Member

Thanks for working on the PR @StellaAthena!
