added GPTNeoForTokenClassification #22908
Conversation
The documentation is not available anymore as the PR was closed or merged.
Hey! Could you make sure the CI tests are green? Can review then!
@ArthurZucker
If the flax errors are not due to the PR, this is ready to be reviewed, @ArthurZucker and @younesbelkada :-)
I just checked the logs for the remaining errors one more time. The errors come from importing the optax library, where jax.Array is used in a type annotation. Apparently there is no name "Array" in the top-level namespace of the jax module. I cannot see how this could be related to my PR.
The jax version used in the examples_flax test is 0.3.6.
Figured out that optax <= 0.1.4 is needed, and found that upstream/main already has that change 👍 Now everything should be clear for review.
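For anyone hitting the same failure, here is a minimal sketch of the version constraint involved (illustrative only and not part of the PR; the actual fix was simply picking up the optax pin already on upstream/main):

```python
# Illustrative guard for the incompatibility described above:
# optax >= 0.1.5 refers to `jax.Array`, which does not exist in jax <= 0.3.6,
# so importing optax fails on those older jax versions.
import importlib.metadata

from packaging import version

jax_version = version.parse(importlib.metadata.version("jax"))
optax_version = version.parse(importlib.metadata.version("optax"))

if jax_version <= version.parse("0.3.6") and optax_version >= version.parse("0.1.5"):
    raise RuntimeError("jax <= 0.3.6 has no jax.Array; pin optax <= 0.1.4 instead")
```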
Definitely ready for review, @ArthurZucker and @younesbelkada :-)
Cool! Reviewing now |
Thanks for working on this! Left a few comments but looks good otherwise!
Thanks a lot for your work on this, I have the same comments as @ArthurZucker !
removed expected_outputs
All done and ready to be merged, @ArthurZucker and @younesbelkada 👍
I implemented the same change as for GPTNeoXForTokenClassification, i.e., I removed the hasattr check and just use config.classifier_dropout directly.
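In rough terms, the head now reads the dropout probability straight from the config; a minimal sketch of that pattern (the class and attribute names here are illustrative, not copied from the PR):

```python
import torch.nn as nn

class TokenClassifierHead(nn.Module):
    """Illustrative head: dropout followed by a linear layer over the hidden states."""

    def __init__(self, config):
        super().__init__()
        # classifier_dropout is always set on the config now, so no hasattr() fallback is needed
        self.dropout = nn.Dropout(config.classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, hidden_states):
        return self.classifier(self.dropout(hidden_states))
```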
Thank you very much again!
Thanks a lot! Can you just solve the conflicts so we can merge the PR?
@sgugger Ready to merge when the checks complete. Thanks for the fast action 👍 ... and more to come in the coming weeks!
* added GPTNeoForTokenClassification
* add to top-level init
* fixup
* test
* more fixup
* add to gpt_neo.mdx
* repo consistency
* dummy copy
* fix copies
* optax >= 0.1.5 assumes jax.Array exists - which it doesn't for jax <= 0.3.6
* merge with main made this superfluous
* added classifier_dropout
* remove legacy code
* removed fmt:on/off removed expected_outputs
* doc style fix
* classifier_dropout is always in config

---------

Co-authored-by: Prof. Peter Schneider-Kamp <jps@ordbogen.com>
What does this PR do?
It adds the class GPTNeoForTokenClassification, which allows using GPT Neo models for token classification tasks. The implementation follows the one for other models (such as GPT2) closely and simply adds a linear layer after the hidden states.
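A rough usage sketch of the new class (the checkpoint name and number of labels below are placeholders, not part of this PR):

```python
import torch
from transformers import AutoTokenizer, GPTNeoForTokenClassification

checkpoint = "EleutherAI/gpt-neo-125m"  # placeholder GPT Neo checkpoint
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = GPTNeoForTokenClassification.from_pretrained(checkpoint, num_labels=5)

inputs = tokenizer("HuggingFace is based in New York City", return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits  # shape: (batch, sequence_length, num_labels)

predicted_token_classes = logits.argmax(dim=-1)
```

Since the classification head is newly initialized, predictions are random until the model is fine-tuned on a labeled token classification dataset.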
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@ArthurZucker
@younesbelkada