[Model] Add AutoWeightsLoader support for BERT, RoBERTa #20534
Conversation
Signed-off-by: Jennifer He <islandhe@gmail.com>
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs will not trigger a full CI run by default. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.
Summary of Changes
Hello @panyuhe, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request enhances the BERT model implementation by adding support for AutoWeightsLoader to streamline the weight loading process. It also addresses a critical bug that prevented successful checkpoint loading due to an incorrect parameter registration, ensuring the model can be loaded and used correctly.
Highlights
- Model Weights Loading: Adds `AutoWeightsLoader` support for the BERT model, enabling more flexible and automated weight loading.
- Bug Fix: Fixes a checkpoint loading error by registering `position_ids` as a buffer instead of a parameter, avoiding a `ValueError` (see the sketch below).
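For readers skimming the thread, here is a minimal, self-contained sketch of why buffer registration sidesteps the loading error. The class and argument names are illustrative, not the exact code in this PR:

```python
import torch
import torch.nn as nn


class EmbeddingSketch(nn.Module):
    """Hypothetical module illustrating buffer vs. parameter registration."""

    def __init__(self, max_position_embeddings: int = 512):
        super().__init__()
        # As a buffer, position_ids is excluded from named_parameters(),
        # so a strict checkpoint loader never expects a weight for it.
        # Registering it as an nn.Parameter instead would make the loader
        # look for a matching checkpoint tensor and fail with a ValueError.
        self.register_buffer(
            "position_ids",
            torch.arange(max_position_embeddings).expand((1, -1)))


m = EmbeddingSketch()
assert "position_ids" not in dict(m.named_parameters())
assert "position_ids" in dict(m.named_buffers())
```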
Code Review
This pull request introduces a correctness fix by changing `position_ids` to a buffer, resolving a model loading error. It also refactors the weight loading logic for `BertForSequenceClassification` to use `AutoWeightsLoader`, enhancing code maintainability.
vllm/model_executor/models/bert.py (Outdated)
```python
self.register_buffer(
    "position_ids",
    torch.arange(config.max_position_embeddings).expand((1, -1)))
```
```diff
 def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
-    self_weights = []
-
-    def weight_filter():
-        for name, weight in weights:
-            if name.startswith("bert."):
-                yield (name[len("bert."):], weight)
-            else:
-                self_weights.append((name, weight))
+    bert_weights = []
+    classifier_weights = []
+    for name, weight in weights:
+        if name.startswith("bert."):
+            bert_weights.append((name, weight))
+        else:
+            classifier_weights.append((name, weight))

-    self.bert.load_weights(weight_filter())
+    loader = AutoWeightsLoader(self)
+    loaded_params = loader.load_weights(bert_weights)

     params_dict = dict(self.named_parameters())

-    for name, loaded_weight in self_weights:
-        if name.startswith("classifier"):
+    for name, loaded_weight in classifier_weights:
+        if name in params_dict:
             param = params_dict[name]
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
             weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+
+    return loaded_params
```
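For context on what the helper does: AutoWeightsLoader walks the model's submodules and routes each checkpoint tensor to the submodule matching the weight name's prefix, which is why the manual "bert." prefix stripping is no longer needed. A minimal usage sketch (hedged; the one-argument constructor and load_weights call mirror the diff above, everything else is simplified):

```python
from collections.abc import Iterable

import torch

from vllm.model_executor.models.utils import AutoWeightsLoader


def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
    # Names like "bert.encoder.layer.0.attention..." are dispatched to
    # self.bert, which applies its own load_weights logic (QKV fusion etc.),
    # and the returned set records which parameters were initialized.
    loader = AutoWeightsLoader(self)
    return loader.load_weights(weights)
```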
```python
self._pooler = ClassifierPooler(vllm_config.model_config,
                                self.classifier, self.bert.pooler)

def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
```
You need to consider the Roberta model in roberta.py as well.
Thanks for the pointer! I've refactored the RoBERTa models, and ran the following tests to verify that the models load successfully (also updated the PR description):
RobertaForSequenceClassification:
- cardiffnlp/twitter-roberta-base-sentiment-latest
- jinaai/jina-embeddings-v3

RobertaEmbeddingModel:
- FacebookAI/roberta-base
- sentence-transformers/stsb-roberta-base-v2
You should verify all the tests under /tests/models/language/pooling. I am testing locally; this change touches a lot of models.
Thanks for taking this on. I was wondering if we could also simplify the logic in BertModel. There are several special cases; I think they can be refactored using AutoWeightsLoader, but it could be tricky to get all models right.
Thanks for the comments! Looking at the bias checks, the main challenge for fully migrating BertModel.load_weights() to AutoWeightsLoader is the QKV fusion logic, which requires the 3-parameter weight_loader signature (a sketch of this pattern follows).
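For readers unfamiliar with this part of vLLM, the sketch below shows roughly why the fused-QKV path resists AutoWeightsLoader: checkpoint shards are renamed onto the fused qkv_proj parameter and loaded via a three-argument weight_loader(param, loaded_weight, shard_id). This is a simplified sketch of the pattern in BertModel.load_weights, not the verbatim source:

```python
from vllm.model_executor.model_loader.weight_utils import default_weight_loader


def load_weights(self, weights):
    # (fused param name, checkpoint shard name, shard id)
    stacked_params_mapping = [
        ("qkv_proj", "query", "q"),
        ("qkv_proj", "key", "k"),
        ("qkv_proj", "value", "v"),
    ]
    params_dict = dict(self.named_parameters())
    for name, loaded_weight in weights:
        for param_name, weight_name, shard_id in stacked_params_mapping:
            if weight_name not in name:
                continue
            # Fused case: QKVParallelLinear's weight_loader needs a third
            # argument saying which slice of the fused tensor to fill.
            name = name.replace(weight_name, param_name)
            param = params_dict[name]
            param.weight_loader(param, loaded_weight, shard_id)
            break
        else:
            # Unfused case: the generic two-argument loader signature.
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
```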
RobertaForSequenceClassification:
- python3 -m vllm.entrypoints.cli.main serve cardiffnlp/twitter-roberta-base-sentiment-latest --served-model-name roberta-sentiment --trust-remote-code
- python3 -m vllm.entrypoints.cli.main serve jinaai/jina-embeddings-v3 --served-model-name jina-v3 --trust-remote-code

RobertaEmbeddingModel:
- python3 -m vllm.entrypoints.cli.main serve FacebookAI/roberta-base --served-model-name roberta-base --trust-remote-code
- python3 -m vllm.entrypoints.cli.main serve sentence-transformers/stsb-roberta-base-v2 --served-model-name stsb-roberta --trust-remote-code

BertEmbeddingModel:
- python3 -m vllm.entrypoints.cli.main serve sentence-transformers/all-MiniLM-L6-v2 --served-model-name bert-embeddings --trust-remote-code

Signed-off-by: <islandhe@gmail.com>
Signed-off-by: Jen H <islandhe@gmail.com>
Force-pushed from a7769db to ed9c1ae
noooop left a comment
LGTM
Thank you for simplifying this code.
Help enable the Language Models Test (Extended Pooling). I'm not sure if all models are loaded correctly.
Can we pause CI testing first and wait for my local testing? I only meant that I think the code is OK, not that it's ready.
It's too complicated to cancel all the tests except the extended pooling test, so I'll just keep them running.
@panyuhe Please click it so that each submission does not trigger the full CI run, which takes more than 3 hours. Wait until we have solved most of the problems in local testing before starting it.
I'm testing locally and will send you the failed tests.
Overall, most models were loaded successfully; only a few models fail:

ValueError: Following weights were not initialized from checkpoint: {'model.embeddings.position_ids'}

FAILED tests/models/language/pooling/test_scoring.py::test_cross_encoder_1_to_1[BAAI/bge-reranker-v2-m3]
ERROR tests/entrypoints/openai/test_embedding_dimensions.py::test_matryoshka[model_info1-bfloat16]
FAILED tests/models/language/pooling/test_snowflake_arctic_embed.py::test_embed_models_mteb[model_info0]

Also, roberta_task_weights_filter in roberta.py may no longer be used.
@panyuhe It took me a long time to learn how to run these tests. (╯‵□′)╯︵┻━┻ vLLM is really too complicated.
Thank you for helping to run the tests. I'm able to reproduce the same error messages using vllm.entrypoints.cli.main on the failing models. I'm still working on setting up the test environment to run the actual pooling tests locally, and will reach out if I run into difficulties. Thank you!
You may need to install mteb[bm25s]>=1.38.11, <2 and pytest-asyncio to run the tests. For the full list of dependencies, please refer to the test requirements, but installing everything will be very slow. Then you can launch the failing tests using pytest, e.g.:

pytest -s -vvv tests/models/language/pooling/test_snowflake_arctic_embed.py::test_embed_models_mteb[model_info0]
@noooop All of the originally failing tests are passing now, and I pasted the results of running the full pooling test suite below. Thank you for your help!

One amendment to my previous statement: AutoWeightsLoader does not actually check whether a weight is in the model's named parameters, so I added back that check (see the sketch below).

Full pooling test results (some failing due to HW)
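Concretely, the restored guard looks like the loop in the diff earlier in the thread, shown again here as an isolated sketch (classifier_weights, loaded_params, and default_weight_loader are as defined in that diff):

```python
params_dict = dict(self.named_parameters())
for name, loaded_weight in classifier_weights:
    # AutoWeightsLoader handled the "bert."-prefixed tensors; for the
    # rest, only load tensors that actually exist on this model.
    if name not in params_dict:
        continue
    param = params_dict[name]
    weight_loader = getattr(param, "weight_loader", default_weight_loader)
    weight_loader(param, loaded_weight)
    loaded_params.add(name)
```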
Also, switch position_ids to be initialized as a buffer and clean up unused code. Signed-off-by: Jen H <islandhe@gmail.com> Signed-off-by: <islandhe@gmail.com>
Force-pushed from d3a7a66 to 216f147
The buildkite/ci/pr/async-engine-inputs-utils-worker-test failure is unrelated to this PR. All tests passed, and my local tests also passed. @maxdebayser Is there anything else to add?
Merging from the main branch should fix the failing async engine test.
Done!
Purpose
FIX (partial) #15697
Also fixes a checkpoint loading error by adding `requires_grad=False` to `position_ids`, which avoids a `ValueError` when loading weights.
Test Plan
BertEmbeddingModel:
BertForSequenceClassification:
RobertaForSequenceClassification:
RobertaEmbeddingModel:
Test Result
Models load successfully