
Add TokenClassification for Mistral, Mixtral and Qwen2 #29878


Conversation

@josephenguehard (Contributor) commented Mar 26, 2024

What does this PR do?

This PR adds token classification for Llama and its extensions (models whose code was copied from Llama): Gemma, Mistral, Mixtral, Persimmon, Qwen2, Qwen2MoE, StableLm and StarCoder2:

  • GemmaForTokenClassification
  • LlamaForTokenClassification
  • MistralForTokenClassification
  • MixtralForTokenClassification
  • PersimmonForTokenClassification
  • Qwen2ForTokenClassification
  • Qwen2MoeForTokenClassification
  • StableLmForTokenClassification
  • StarCoder2ForTokenClassification

Note: The original PR was intended for Mistral only. Llama was added at the request of @KoichiYasuoka, and the other models were added because they copy most of the Llama implementation.
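For context, a minimal usage sketch of what these classes enable through the standard Auto API. The checkpoint name and label count below are placeholders, not part of this PR:

```python
from transformers import AutoModelForTokenClassification, AutoTokenizer

# Placeholder checkpoint and label count: any checkpoint of a supported
# architecture (Llama, Gemma, Mistral, Mixtral, Persimmon, Qwen2, Qwen2MoE,
# StableLm, StarCoder2) would work the same way. The classification head is
# freshly initialized, so the model still needs fine-tuning before use.
model_name = "mistralai/Mistral-7B-v0.1"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForTokenClassification.from_pretrained(model_name, num_labels=9)

inputs = tokenizer("Hugging Face is based in New York City", return_tensors="pt")
logits = model(**inputs).logits
print(logits.shape)  # (batch_size, sequence_length, num_labels)
```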

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@josephenguehard changed the title from "Add MistralForTokenClassification" to "Add TokenClassification for Mistral, Mixtral and Qwen2" on Mar 26, 2024
@ArthurZucker (Collaborator) left a comment

Hey! Thanks for the PR, could you explain the motivation behind this? No checkpoints are publicly available, and when there are no official checkpoints we usually open a feature request first, then integrate it if the community shows interest!

@KoichiYasuoka (Contributor) commented

Thank you @josephenguehard for

  • MistralForTokenClassification

since I've already released https://huggingface.co/KoichiYasuoka/Swallow-MS-7b-upos (using trust_remote_code=True). And would you kindly include

  • LlamaForTokenClassification

from #26521 for my https://huggingface.co/KoichiYasuoka/Swallow-7b-plus-upos ?

@josephenguehard (Contributor, Author) commented

Hi both!

Thank you for your replies! I will open a separate issue to discuss adding these classes.

@KoichiYasuoka I'd be happy to add LlamaForTokenClassification! I'll try to have this done by the end of the week.

@josephenguehard (Contributor, Author) commented

The feature request issue linked to this PR is #29940

@josephenguehard (Contributor, Author) commented

@KoichiYasuoka I've added Llama support as well as its extensions (models where code from Llama was copied), and I think the PR is ready for review.

One point I'd like to discuss is the Dropout layer. I see that your implementation includes a Dropout layer, following BertForTokenClassification.
However, BertForSequenceClassification also has a Dropout layer, while LlamaForSequenceClassification does not, so for consistency I think we should not include one in LlamaForTokenClassification either. But happy to get others' opinions!
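To make the trade-off concrete, here is an illustrative sketch (plain PyTorch, not code from this PR) of the two classification-head variants being discussed:

```python
import torch.nn as nn

# Variant A: Dropout before the linear projection, in the spirit of
# BertForTokenClassification and GPT2ForTokenClassification.
class TokenHeadWithDropout(nn.Module):
    def __init__(self, hidden_size: int, num_labels: int, classifier_dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(hidden_size, num_labels)

    def forward(self, hidden_states):
        # hidden_states: (batch, seq_len, hidden_size) from the base model
        return self.classifier(self.dropout(hidden_states))


# Variant B: a bare linear projection, mirroring how LlamaForSequenceClassification
# maps hidden states to logits without any Dropout.
class TokenHeadWithoutDropout(nn.Module):
    def __init__(self, hidden_size: int, num_labels: int):
        super().__init__()
        self.score = nn.Linear(hidden_size, num_labels)

    def forward(self, hidden_states):
        return self.score(hidden_states)
```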

@KoichiYasuoka (Contributor) commented

Well, about the Dropout layer... My implementation follows GPT2ForTokenClassification, which was released with #13290 in Transformers v4.10.0. But I'm not sure how GPT2 and LLaMA differ on that point...

@josephenguehard (Contributor, Author) commented

It's very easy for me to add Dropout if there is a consensus on it.
FYI I'm away and won't have access to my laptop until April 15th, will pick this up then.

@KoichiYasuoka (Contributor) commented

Yesterday @SeanLee97 of @WhereIsAI released two models:

They use their own package BiLLM with LlamaForTokenClassification and MistralForTokenClassification.

@SeanLee97 commented

Thank you for mentioning our released models. Happy to see that official 🤗 transformers supports token classification.
However, based on our experiments in the paper (https://arxiv.org/abs/2310.01208), LLMs with the causal mask cannot perform well on token classification tasks. Thus, we convert the causal attention mask from uni- to bi-directional. This change can improve the token classification performance significantly. Hopefully this feature can be added to the official transformers in the future.

@KoichiYasuoka (Contributor) commented

Thank you @SeanLee97, but bi-directional attention doesn't seem causal... IMHO we need a causal option in LlamaConfig, as in XLMConfig, with causal=True (the default) meaning uni-directional attention and causal=False meaning bi-directional.
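Purely as a sketch of the two suggestions above (no such causal flag existed on LlamaConfig at the time of this discussion, and this is not the BiLLM implementation), the difference between the two masking modes comes down to the additive attention bias:

```python
import torch

def build_attention_bias(seq_len: int, causal: bool = True) -> torch.Tensor:
    """Additive attention bias: 0.0 where attention is allowed, -inf where masked."""
    if causal:
        # Uni-directional: token i may only attend to tokens j <= i.
        return torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)
    # Bi-directional: every token attends to every other token.
    return torch.zeros(seq_len, seq_len)

print(build_attention_bias(4, causal=True))   # upper triangle is -inf
print(build_attention_bias(4, causal=False))  # all zeros
```

A config-level flag would thread this choice through mask construction; as described above, BiLLM instead makes the conversion from uni- to bi-directional in its own package.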


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@KoichiYasuoka (Contributor) commented

I'm waiting for your review @ArthurZucker

@ArthurZucker (Collaborator) left a comment

Thanks! Let's just add copied from for the test and good to go.
Seems to have decent traction, which is what I was waiting for!

@@ -367,6 +372,21 @@ def test_Gemma_sequence_classification_model_for_multi_label(self):
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))

def test_Gemma_token_classification_model(self):

tests can also be copied from!
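For readers unfamiliar with the convention: "Copied from" is a comment marker that the repository's consistency tooling (e.g. make fix-copies) uses to keep duplicated code in sync with its source. A hedged illustration follows; the exact source path and rename mapping are illustrative, not quoted from this PR:

```python
# Illustrative only: the path and the `with old->new` rename mapping below are
# examples of the convention, not the lines added in this PR. The marker tells
# the check/fix-copies tooling to mirror the body of the referenced test,
# applying the listed renames.
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_token_classification_model with llama->gemma,Llama->Gemma
def test_Gemma_token_classification_model(self):
    ...  # body kept identical to the Llama test by the tooling, up to the renames
```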

@ArthurZucker (Collaborator) commented May 13, 2024

Also make sure to resolve the conflicts.

@HuggingFaceDocBuilderDev commented

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@josephenguehard (Contributor, Author) commented

@ArthurZucker I've added Copied from for the tests and resolved the conflicts.

@younesbelkada (Contributor) left a comment

Thanks a lot! LGTM with only one nit: can you propagate these changes to all models with make fix-copies (once the suggestion is accepted)?

@josephenguehard (Contributor, Author) commented

Done @younesbelkada 👍

@younesbelkada (Contributor) left a comment

Clean, thanks !

@younesbelkada younesbelkada merged commit 07bf2df into huggingface:main May 20, 2024
@josephenguehard josephenguehard deleted the add-mistral-for-token-classification branch May 20, 2024 11:44