
added GPTNeoForTokenClassification #22908

Merged: 20 commits into huggingface:main on Apr 27, 2023

Conversation

@peter-sk (Contributor) commented on Apr 21, 2023

What does this PR do?

It adds the class GPTNeoForTokenClassification, which allows GPT Neo models to be used for token classification tasks. The implementation closely follows that of other models (such as GPT-2) and simply adds a linear classification layer on top of the hidden states.
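
For readers unfamiliar with the pattern, here is a minimal sketch of what such a token classification head looks like, modeled on GPT2ForTokenClassification. The class name and defaults below are illustrative, not the exact code merged in this PR:

```python
import torch
from torch import nn
from transformers import GPTNeoConfig, GPTNeoModel


class GPTNeoTokenClassifierSketch(nn.Module):
    """Illustrative only: a linear head over the per-token hidden states."""

    def __init__(self, config: GPTNeoConfig, num_labels: int = 2):
        super().__init__()
        self.transformer = GPTNeoModel(config)
        # The merged implementation reads the dropout probability from the
        # config (classifier_dropout); a fixed value is used here for brevity.
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(config.hidden_size, num_labels)

    def forward(self, input_ids: torch.LongTensor, attention_mask=None):
        hidden_states = self.transformer(
            input_ids, attention_mask=attention_mask
        ).last_hidden_state  # (batch, seq_len, hidden_size)
        logits = self.classifier(self.dropout(hidden_states))
        return logits  # (batch, seq_len, num_labels)
```

Once merged, the class is used like any other `*ForTokenClassification` model, e.g. something like `GPTNeoForTokenClassification.from_pretrained("EleutherAI/gpt-neo-125M", num_labels=9)`.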

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@ArthurZucker
@younesbelkada

@HuggingFaceDocBuilderDev commented on Apr 21, 2023

The documentation is not available anymore as the PR was closed or merged.

@ArthurZucker (Collaborator) commented on Apr 21, 2023

Hey! Could you make sure the CI tests are green? I can review then!

@peter-sk (Contributor, Author) commented

@ArthurZucker
Sure, I'm getting the hang of it. The only failing tests now are related to Flax and seem unrelated to this pull request.

@peter-sk (Contributor, Author) commented

If the flax errors are not due to the PR, this is ready to be reviewed, @ArthurZucker and @younesbelkada :-)

@peter-sk (Contributor, Author) commented

I just checked the logs for the remaining errors one more time. The errors are related to the import of the optax library, where jax.Array is used in a type annotation. Apparently there is no name "Array" in the top-level namespace of the jax module.

I cannot see how this could be related to my PR.

@peter-sk (Contributor, Author) commented

The jax version used in the examples_flax test is 0.3.6:
Collecting jax!=0.3.2,<=0.3.6,>=0.2.8 (from transformers==4.28.0.dev0)
Using cached jax-0.3.6-py3-none-any.whl
This version clearly has no jax.Array type.
I am unsure why such an old version is being used.
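
For context, the mismatch can be illustrated with a quick check. This is a hedged illustration of the failure mode, not output from the actual CI job:

```python
import jax

# On jax 0.3.6 the top-level name `Array` does not exist, so any library
# that annotates types with jax.Array at import time fails with an
# AttributeError before the tests even start.
print(jax.__version__)        # e.g. 0.3.6 in the examples_flax job
print(hasattr(jax, "Array"))  # False on 0.3.6, True on recent releases
```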

@peter-sk (Contributor, Author) commented

I figured out that optax <= 0.1.4 is needed, and upstream/main already has that change 👍 Now everything should be cleared for review.

@peter-sk (Contributor, Author) commented

Definitely ready for review, @ArthurZucker and @younesbelkada :-)

@ArthurZucker (Collaborator) commented

Cool! Reviewing now

@ArthurZucker (Collaborator) left a review comment

Thanks for working on this! Left a few comments but looks good otherwise!

Three review comments on src/transformers/models/gpt_neo/modeling_gpt_neo.py (outdated, resolved)
@younesbelkada (Contributor) left a review comment

Thanks a lot for your work on this, I have the same comments as @ArthurZucker !

@peter-sk (Contributor, Author) commented

All done and ready to be merged, @ArthurZucker and @younesbelkada 👍

@peter-sk (Contributor, Author) commented

I implemented the same change as for GPTNeoXForTokenClassification, i.e., I removed the hasattr check and use config.classifier_dropout directly.
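
For readers following along, a rough sketch of the before/after of that simplification (illustrative only; the exact code is in the diff):

```python
from torch import nn
from transformers import GPTNeoConfig

config = GPTNeoConfig()

# Before: guard against configs that might not define the attribute.
if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None:
    classifier_dropout = config.classifier_dropout
else:
    classifier_dropout = 0.1

# After: classifier_dropout is always present on GPTNeoConfig, so the
# value can be read directly.
dropout = nn.Dropout(config.classifier_dropout)
```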

@younesbelkada (Contributor) left a review comment

Thank you very much again!

@younesbelkada requested a review from sgugger on Apr 27, 2023 at 15:06
@sgugger (Collaborator) left a review comment

Thanks a lot! Can you just solve the conflicts so we can merge the PR?

@peter-sk (Contributor, Author) commented

@sgugger Ready to merge when the checks complete. Thanks for the fast action 👍

... and more to come in the coming weeks!

@sgugger merged commit d65b14e into huggingface:main on Apr 27, 2023
gojiteji pushed a commit to gojiteji/transformers that referenced this pull request on Jun 5, 2023:
* added GPTNeoForTokenClassification

* add to top-level init

* fixup

* test

* more fixup

* add to gpt_neo.mdx

* repo consistency

* dummy copy

* fix copies

* optax >= 0.1.5 assumes jax.Array exists - which it doesn't for jax <= 0.3.6

* merge with main made this superfluous

* added classifier_dropout

* remove legacy code

* removed fmt:on/off
removed expected_outputs

* doc style fix

* classifier_dropout is always in config

---------

Co-authored-by: Prof. Peter Schneider-Kamp <jps@ordbogen.com>
novice03 pushed a commit to novice03/transformers that referenced this pull request on Jun 23, 2023, with the same commit message as above.