Support bge-m3 sparse embeddings (lexical weights) #14526
base: main
Changes from all commits: 3d443c6, 925836f, 25817d5, 7d69a03, 589e143, 715ad4c, b0aa6a4, 69a721a, de8f381, 12455b2, daba294
`vllm/entrypoints/openai/serving_pooling.py`:

```diff
@@ -181,19 +181,21 @@ async def create_pooling(
     try:
         pooling_params = request.to_pooling_params()

-        if "token_embed" in self.supported_tasks:
-            pooling_task = "token_embed"
-        elif "token_classify" in self.supported_tasks:
-            pooling_task = "token_classify"
+        if pooling_params.task is None:
+            if "token_embed" in self.supported_tasks:
+                pooling_task = "token_embed"
+            elif "token_classify" in self.supported_tasks:
+                pooling_task = "token_classify"
+        else:
+            pooling_task = pooling_params.task

         if pooling_task not in self.supported_tasks:
             return self.create_error_response(
-                f"pooling_task must be one of {self.supported_tasks}."
+                f"Task {pooling_task} is not supported, it"
+                f" must be one of {self.supported_tasks}."
             )
```
Comment on lines -184 to 196: I plan to make the task parameter required in #25524, which can simplify this logic.
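As a rough illustration of that planned simplification (hypothetical, pending #25524; not part of this diff), the fallback branch would collapse to plain validation:

```python
# Hypothetical shape of the task resolution once `task` is required
# (per #25524): the None-fallback disappears and only validation remains.
def resolve_pooling_task(task: str, supported_tasks: set[str]) -> str:
    if task not in supported_tasks:
        raise ValueError(
            f"Task {task} is not supported, it must be one of {supported_tasks}."
        )
    return task


print(resolve_pooling_task("token_classify", {"token_embed", "token_classify"}))
```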
```diff
-        try:
-            pooling_params.verify(pooling_task, self.model_config)
-        except ValueError as e:
-            return self.create_error_response(str(e))
+        pooling_params.verify(pooling_task, self.model_config)

         for i, engine_prompt in enumerate(engine_prompts):
             request_id_item = f"{request_id}-{i}"
```
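To see the new resolution logic from the client side, here is a hedged sketch of a request against vLLM's OpenAI-compatible `/pooling` endpoint exercising both branches; the exact request fields (`model`, `input`, `task`) are assumptions based on the existing pooling API rather than something this diff pins down:

```python
# Client-side sketch (field names assumed): exercise the task fallback and
# the explicit-task path of create_pooling.
import requests

BASE_URL = "http://localhost:8000"  # assumed local vLLM server

# No "task" given: the server now falls back to "token_embed" or
# "token_classify", whichever the model supports.
resp = requests.post(
    f"{BASE_URL}/pooling",
    json={"model": "BAAI/bge-m3", "input": "What is BGE M3?"},
)
print(resp.json())

# Explicit "task": used as-is; an unsupported value triggers the improved
# "Task ... is not supported" error message.
resp = requests.post(
    f"{BASE_URL}/pooling",
    json={
        "model": "BAAI/bge-m3",
        "input": "What is BGE M3?",
        "task": "token_classify",
    },
)
print(resp.json())
```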
`vllm/model_executor/models/roberta.py`:

```diff
@@ -1,20 +1,24 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

+import itertools
 from collections.abc import Iterable

 import torch
 from torch import nn
 from transformers import RobertaConfig

-from vllm.config import ModelConfig, VllmConfig
+from vllm.config import ModelConfig, PoolerConfig, VllmConfig
 from vllm.model_executor.layers.pooler import (
+    BOSEOSFilter,
     ClassifierPooler,
     CLSPool,
+    DispatchPooler,
     Pooler,
 )
 from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
+from vllm.model_executor.model_loader.default_loader import DefaultModelLoader
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.bert import (
     TOKEN_TYPE_SHIFT,
     BertEmbeddingModel,
```
```diff
@@ -160,6 +164,83 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
         return loader.load_weights(weights_list, mapper=mapper)


+def filter_secondary_weights(
+    all_weights: Iterable[tuple[str, torch.Tensor]],
+    secondary_weights: list[str],
+) -> tuple[Iterable[tuple[str, torch.Tensor]], Iterable[tuple[str, torch.Tensor]]]:
+    all_weights1, all_weights2 = itertools.tee(all_weights)
+
+    def filtered(n):
+        return any(n.startswith(f) for f in secondary_weights)
+
+    return ((n, w) for n, w in all_weights1 if filtered(n)), (
+        (n, w) for n, w in all_weights2 if not filtered(n)
+    )
+
+
+class BgeM3EmbeddingModel(RobertaEmbeddingModel):
+    """A model that extends RobertaEmbeddingModel with sparse embeddings.
+
+    This class supports loading an additional sparse_linear.pt file
+    to create sparse embeddings as described in https://arxiv.org/abs/2402.03216
+    """
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        self.hidden_size = vllm_config.model_config.hf_config.hidden_size
+
+        model_config = vllm_config.model_config
+        self.head_dtype = model_config.head_dtype
+        self.bos_token_id = model_config.hf_config.bos_token_id
+        self.eos_token_id = model_config.hf_config.eos_token_id
+
+        super().__init__(vllm_config=vllm_config, prefix=prefix)
+        self.secondary_weight_prefix = "sparse_linear."
+
+        self.secondary_weights = [
+            DefaultModelLoader.Source(
+                model_or_path=vllm_config.model_config.model,
+                revision=None,
+                prefix=self.secondary_weight_prefix,
+                allow_patterns_overrides=["sparse_linear.pt"],
+            )
+        ]
+
+    def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
+        self.sparse_linear = nn.Linear(self.hidden_size, 1, dtype=self.head_dtype)
+        return DispatchPooler(
+            {
+                "embed": Pooler.for_embed(pooler_config),
+                "token_embed": BOSEOSFilter(
+                    Pooler.for_token_embed(pooler_config),
+                    self.bos_token_id,
+                    self.eos_token_id,
+                ),
+                "token_classify": BOSEOSFilter(
+                    Pooler.for_token_classify(
+                        pooler_config, classifier=self.sparse_linear, act_fn=torch.relu
+                    ),
+                    self.bos_token_id,
+                    self.eos_token_id,
+                ),
+            }
+        )
```
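For context on what the `token_classify` pooler computes here: per the BGE-M3 paper (https://arxiv.org/abs/2402.03216), the lexical (sparse) weight of a token is a ReLU-gated scalar projection of its hidden state, w_t = ReLU(W_lex · h_t), which is exactly what `sparse_linear` with `act_fn=torch.relu` wires up. A minimal tensor-level sketch (shapes illustrative):

```python
# Standalone sketch of the sparse head: one ReLU-gated scalar per token.
import torch
from torch import nn

hidden_size = 1024  # BGE-M3 is XLM-RoBERTa-large sized
sparse_linear = nn.Linear(hidden_size, 1)

hidden_states = torch.randn(7, hidden_size)  # [num_tokens, hidden_size]
lexical_weights = torch.relu(sparse_linear(hidden_states)).squeeze(-1)
print(lexical_weights.shape)  # torch.Size([7]): one weight per token
```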
|
Comment on lines
+210
to
+226
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Cool ! BGE-M3 Multi-Functionality:
Nothing stops us from using a plugin task to output everything at once. (after #26973 landing) This way, BGE-M3 will be the best demonstration of the flexibility of our new pooler API. @DarkLight1337 You must come and see this Please add examples to demonstrate how users can use it. As well as adding tests to guard this feature There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think the best is to use a plugin task to output everything all at once. This is more efficient. This may need to coordinate with #26973 I think a separate PR is still needed to inform everyone that the plugin pooling task has been added, although this PR makes few code changes Please feel free to modify anything in #26973, as well as any PR of mine. |
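In the spirit of the requested examples, here is a rough offline-inference sketch of how a user might recover both of BGE-M3's outputs from the poolers above. It assumes `LLM.encode` accepts a `pooling_task` argument and that `token_classify` returns one score per non-special token (BOS/EOS already dropped by `BOSEOSFilter`); treat it as illustrative, not as the final user-facing API:

```python
# Illustrative only: dense embedding + sparse lexical weights from one model.
from vllm import LLM

llm = LLM(model="BAAI/bge-m3", runner="pooling")  # runner name assumed
tokenizer = llm.get_tokenizer()

text = "What is BGE M3?"

# Dense embedding: one vector per prompt.
(dense_out,) = llm.encode([text], pooling_task="embed")
dense_vec = dense_out.outputs.data

# Sparse lexical weights: one ReLU-activated score per token.
(sparse_out,) = llm.encode([text], pooling_task="token_classify")
token_ids = tokenizer(text, add_special_tokens=False)["input_ids"]
scores = sparse_out.outputs.data.squeeze(-1).tolist()

# Collapse repeated tokens with max(), mirroring FlagEmbedding's convention
# for lexical weights in the BGE-M3 reference implementation.
lexical_weights: dict[int, float] = {}
for tok, score in zip(token_ids, scores):
    if score > 0:
        lexical_weights[tok] = max(score, lexical_weights.get(tok, 0.0))

print(dense_vec.shape, lexical_weights)
```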
```diff
+    def load_weights(self, all_weights: Iterable[tuple[str, torch.Tensor]]):
+        secondary, weights = filter_secondary_weights(
+            all_weights, [self.secondary_weight_prefix]
+        )
+
+        super().load_weights(weights)
+
+        params_dict = dict(self.named_parameters())
+
+        for name, loaded_weight in secondary:
+            if name.startswith(self.secondary_weight_prefix):
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+
+
 @default_pooling_type("CLS")
 class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
     """A model that uses Roberta to provide embedding functionalities.
```
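A side note on `filter_secondary_weights`: `itertools.tee` gives two independently filtered views over a single source iterator, so `load_weights` can route `sparse_linear.*` entries to the secondary path and everything else to the parent loader in one pass over the checkpoint. Note that `tee` buffers items internally until both views have consumed them, which is worth keeping in mind for large checkpoints. A minimal self-contained demonstration of the same pattern (names and values made up):

```python
# Demo of the tee-based lazy partition used by filter_secondary_weights.
import itertools
from collections.abc import Iterable


def partition_by_prefix(
    weights: Iterable[tuple[str, float]], prefixes: list[str]
) -> tuple[Iterable[tuple[str, float]], Iterable[tuple[str, float]]]:
    first, second = itertools.tee(weights)

    def matches(name: str) -> bool:
        return any(name.startswith(p) for p in prefixes)

    return (
        ((n, w) for n, w in first if matches(n)),
        ((n, w) for n, w in second if not matches(n)),
    )


weights = iter(
    [
        ("sparse_linear.weight", 0.1),
        ("roberta.encoder.layer.0.weight", 0.2),
        ("sparse_linear.bias", 0.3),
    ]
)
secondary, primary = partition_by_prefix(weights, ["sparse_linear."])
print(list(secondary))  # [('sparse_linear.weight', 0.1), ('sparse_linear.bias', 0.3)]
print(list(primary))    # [('roberta.encoder.layer.0.weight', 0.2)]
```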
Comment: I plan to add the task parameter in #25524 and make it required. Thanks for adding it now.