Add TokenClassification for Mistral, Mixtral and Qwen2 #29878
Conversation
…ephenguehard/transformers into add-mistral-for-token-classification
Hey! Thanks for the PR, could you explain the motivation behind this? No checkpoints are publicly available, and when there are no official checkpoints we usually open a feature request first, then integrate the feature if the community shows interest!
Thank you @josephenguehard for
since I've already released https://huggingface.co/KoichiYasuoka/Swallow-MS-7b-upos (using
from #26521 for my https://huggingface.co/KoichiYasuoka/Swallow-7b-plus-upos ?
Hi Both! Thank you for your replies! I will open a separate issue to discuss adding these classes. @KoichiYasuoka I'd be happy to add
The feature request issue linked to this PR is #29940
…ephenguehard/transformers into add-mistral-for-token-classification
@KoichiYasuoka I've added Llama support as well as its extensions (models where code from Llama was copied), and I think the PR is ready for review. One point I'd like to discuss is the Dropout: I see that your implementation includes a Dropout layer, following BertForTokenClassification.
Well, the Dropout layer... My implementation follows GPT2ForTokenClassification, which was released with #13290 in Transformers v4.10.0. But I'm unsure how GPT2 and LLaMA differ on that point...
It's very easy for me to add Dropout if there is a consensus on it.
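For context, here is a minimal sketch of the head under discussion, assuming the GPT2ForTokenClassification-style pattern (hidden states → dropout → linear classifier). The class and parameter names are illustrative, not the PR's actual code:

```python
import torch
import torch.nn as nn


class IllustrativeTokenClassificationHead(nn.Module):
    """Hypothetical head: per-token dropout followed by a linear projection to label logits."""

    def __init__(self, hidden_size: int, num_labels: int, classifier_dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(hidden_size, num_labels)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # hidden_states: (batch, seq_len, hidden_size) from the base decoder
        hidden_states = self.dropout(hidden_states)
        return self.classifier(hidden_states)  # (batch, seq_len, num_labels)
```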
Yesterday @SeanLee97 of @WhereIsAI released two models:
They use their own package BiLLM with LlamaForTokenClassification and MistralForTokenClassification.
Thank you for mentioning our released models. Happy to see that official 🤗 transformers supports token classification.
Thank you @SeanLee97, but bi-directional does not seem causal... IMHO we need
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
I'm waiting for your review @ArthurZucker
Thanks! Let's just add copied from for the test and good to go.
Seems to have decent traction, which is what I was waiting for!
@@ -367,6 +372,21 @@ def test_Gemma_sequence_classification_model_for_multi_label(self):
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))

def test_Gemma_token_classification_model(self):
tests can also be copied from!
also make sure to resolve the conflicts
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
@ArthurZucker I've added |
Thanks a lot! LGTM with only one nit - can you propagate these changes to all models with make fix-copies (once the suggestion is accepted)?
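For illustration (assumed module paths and a stubbed body, not the PR's literal code), this is roughly what a `# Copied from` marker consumed by `make fix-copies` looks like; the utility re-copies the referenced definition and applies the name substitutions listed after `with`:

```python
from transformers.models.mistral.modeling_mistral import MistralPreTrainedModel


# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Mistral,LLAMA->MISTRAL
class MistralForTokenClassification(MistralPreTrainedModel):
    # Body intentionally stubbed here; `make fix-copies` keeps the real body in sync
    # with the Llama class referenced in the comment above.
    ...
```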
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Done @younesbelkada 👍
Clean, thanks!
What does this PR do?
This PR adds token classification for Llama and its extensions (models where code from Llama was copied): Gemma, Mistral, Mixtral, Persimmon, Qwen2, Qwen2MoE, StableLm and StarCoder2:
GemmaForTokenClassification
LlamaForTokenClassification
MistralForTokenClassification
MixtralForTokenClassification
PersimmonForTokenClassification
Qwen2ForTokenClassification
Qwen2MoeForTokenClassification
StableLmForTokenClassification
StarCoder2ForTokenClassification
Note: The original PR was intended for Mistral only. Llama was added at @KoichiYasuoka's request, and the other models were added because they copy most of the Llama setup.
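As a rough usage sketch of what these classes enable (the checkpoint below is a base-model placeholder and `num_labels` is an arbitrary assumption; the freshly initialized classification head needs fine-tuning before its predictions are meaningful):

```python
import torch
from transformers import AutoTokenizer, MistralForTokenClassification

# Placeholder base checkpoint; no official token-classification checkpoints exist for these models yet.
checkpoint = "mistralai/Mistral-7B-v0.1"

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = MistralForTokenClassification.from_pretrained(checkpoint, num_labels=5)  # num_labels is assumed

inputs = tokenizer("My name is Sarah and I live in London", return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits  # shape: (batch, seq_len, num_labels)

predicted_token_classes = logits.argmax(dim=-1)
print(predicted_token_classes)
```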
Before submitting
Did you read the contributor guideline, Pull Request section?
Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?