diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index 5b8a118280da..5df9fcb7b3e3 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -1662,8 +1662,24 @@ def to_pooling_params(self):
 
 EmbeddingRequest: TypeAlias = EmbeddingCompletionRequest | EmbeddingChatRequest
 
-PoolingCompletionRequest = EmbeddingCompletionRequest
-PoolingChatRequest = EmbeddingChatRequest
+
+class PoolingCompletionRequest(EmbeddingCompletionRequest):
+    task: str | None = None
+
+    def to_pooling_params(self):
+        params = super().to_pooling_params()
+        params.task = self.task
+        return params
+
+
+class PoolingChatRequest(EmbeddingChatRequest):
+    task: str | None = None
+
+    def to_pooling_params(self):
+        params = super().to_pooling_params()
+        params.task = self.task
+        return params
+
 
 T = TypeVar("T")
diff --git a/vllm/entrypoints/openai/serving_pooling.py b/vllm/entrypoints/openai/serving_pooling.py
index 102a29fe35cd..9b02f2814f31 100644
--- a/vllm/entrypoints/openai/serving_pooling.py
+++ b/vllm/entrypoints/openai/serving_pooling.py
@@ -181,19 +181,21 @@ async def create_pooling(
         try:
             pooling_params = request.to_pooling_params()
 
-            if "token_embed" in self.supported_tasks:
-                pooling_task = "token_embed"
-            elif "token_classify" in self.supported_tasks:
-                pooling_task = "token_classify"
+            if pooling_params.task is None:
+                if "token_embed" in self.supported_tasks:
+                    pooling_task = "token_embed"
+                elif "token_classify" in self.supported_tasks:
+                    pooling_task = "token_classify"
             else:
+                pooling_task = pooling_params.task
+
+            if pooling_task not in self.supported_tasks:
                 return self.create_error_response(
-                    f"pooling_task must be one of {self.supported_tasks}."
+                    f"Task {pooling_task} is not supported, it"
+                    f" must be one of {self.supported_tasks}."
                 )
 
-            try:
-                pooling_params.verify(pooling_task, self.model_config)
-            except ValueError as e:
-                return self.create_error_response(str(e))
+            pooling_params.verify(pooling_task, self.model_config)
 
             for i, engine_prompt in enumerate(engine_prompts):
                 request_id_item = f"{request_id}-{i}"
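Note (not part of the diff): with the task field added to PoolingCompletionRequest/PoolingChatRequest, a client can select the pooling task explicitly on /pooling requests; omitting it keeps the previous server-side auto-selection. A minimal client sketch, assuming a vLLM OpenAI-compatible server on localhost:8000 serving BAAI/bge-m3; the exact response layout is an assumption:

import requests

# POST to the /pooling endpoint; "task" is the new optional field from this diff.
resp = requests.post(
    "http://localhost:8000/pooling",
    json={
        "model": "BAAI/bge-m3",        # assumed served model name
        "input": "What is BGE M3?",
        "task": "token_classify",      # request per-token sparse weights explicitly
    },
)
resp.raise_for_status()
# Response layout assumed: one result object per input.
print(resp.json()["data"][0])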
diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py
index 84e176f0ea89..82f6196ddd90 100644
--- a/vllm/model_executor/layers/pooler.py
+++ b/vllm/model_executor/layers/pooler.py
@@ -803,3 +803,48 @@ def forward(
     def extra_repr(self) -> str:
         s = f"supported_task={self.get_supported_tasks()}"
         return s
+
+
+class BOSEOSFilter(Pooler):
+    """Filters the BOS and EOS token results from outputs."""
+
+    def __init__(
+        self,
+        pooler: Pooler,
+        bos_token_id: int,
+        eos_token_id: int,
+    ) -> None:
+        super().__init__()
+
+        self.pooler = pooler
+        self.bos_token_id = bos_token_id
+        self.eos_token_id = eos_token_id
+
+    def get_supported_tasks(self) -> Set[PoolingTask]:
+        return self.pooler.get_supported_tasks()
+
+    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
+        return PoolingParamsUpdate(requires_token_ids=True)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor | list[torch.Tensor],
+        pooling_metadata: PoolingMetadata,
+    ) -> PoolerOutput:
+        pooled_outputs = self.pooler(hidden_states, pooling_metadata)
+        assert isinstance(pooled_outputs, list)
+
+        for i, prompt_len in enumerate(pooling_metadata.prompt_lens):
+            pooled_data = pooled_outputs[i]
+            assert (
+                isinstance(pooled_data, torch.Tensor)
+                and pooled_data.shape[0] == prompt_len
+            )
+            token_ids = pooling_metadata.prompt_token_ids[i]
+            if token_ids[0] == self.bos_token_id:
+                pooled_data = pooled_data[1:]
+            if token_ids[-1] == self.eos_token_id:
+                pooled_data = pooled_data[:-1]
+            pooled_outputs[i] = pooled_data.squeeze()
+
+        return pooled_outputs
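Note (not part of the diff): BOSEOSFilter wraps another pooler and drops the leading BOS / trailing EOS positions from per-token outputs, so consumers only see results for real content tokens. A standalone sketch of the trimming it applies per prompt, with invented token ids and tensor sizes:

import torch

bos_token_id, eos_token_id = 0, 2        # XLM-R-style special token ids (assumed)
token_ids = [0, 1437, 8831, 2]           # [BOS, tok, tok, EOS]
pooled = torch.randn(len(token_ids), 8)  # one pooled vector per token

# Mirror BOSEOSFilter.forward: strip the BOS/EOS rows when present.
if token_ids[0] == bos_token_id:
    pooled = pooled[1:]
if token_ids[-1] == eos_token_id:
    pooled = pooled[:-1]

print(pooled.shape)  # torch.Size([2, 8]): special-token rows removed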
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index d119c161f6b3..b4f7f3f7dbed 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -201,6 +201,7 @@
     "RobertaModel": ("roberta", "RobertaEmbeddingModel"),
     "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
     "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
+    "BgeM3EmbeddingModel": ("roberta", "BgeM3EmbeddingModel"),
     # [Multimodal]
     "CLIPModel": ("clip", "CLIPEmbeddingModel"),
     "LlavaNextForConditionalGeneration": (
diff --git a/vllm/model_executor/models/roberta.py b/vllm/model_executor/models/roberta.py
index cfccb904f46c..b97a0fe2331e 100644
--- a/vllm/model_executor/models/roberta.py
+++ b/vllm/model_executor/models/roberta.py
@@ -1,20 +1,24 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+import itertools
 from collections.abc import Iterable
 
 import torch
 from torch import nn
 from transformers import RobertaConfig
 
-from vllm.config import ModelConfig, VllmConfig
+from vllm.config import ModelConfig, PoolerConfig, VllmConfig
 from vllm.model_executor.layers.pooler import (
+    BOSEOSFilter,
     ClassifierPooler,
     CLSPool,
     DispatchPooler,
     Pooler,
 )
 from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
+from vllm.model_executor.model_loader.default_loader import DefaultModelLoader
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.bert import (
     TOKEN_TYPE_SHIFT,
     BertEmbeddingModel,
@@ -160,6 +164,83 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
         return loader.load_weights(weights_list, mapper=mapper)
 
 
+def filter_secondary_weights(
+    all_weights: Iterable[tuple[str, torch.Tensor]],
+    secondary_weights: list[str],
+) -> tuple[Iterable[tuple[str, torch.Tensor]], Iterable[tuple[str, torch.Tensor]]]:
+    all_weights1, all_weights2 = itertools.tee(all_weights)
+
+    def filtered(n):
+        return any(n.startswith(f) for f in secondary_weights)
+
+    return ((n, w) for n, w in all_weights1 if filtered(n)), (
+        (n, w) for n, w in all_weights2 if not filtered(n)
+    )
+
+
+class BgeM3EmbeddingModel(RobertaEmbeddingModel):
+    """A model that extends RobertaEmbeddingModel with sparse embeddings.
+
+    This class supports loading an additional sparse_linear.pt file
+    to create sparse embeddings as described in https://arxiv.org/abs/2402.03216
+    """
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        self.hidden_size = vllm_config.model_config.hf_config.hidden_size
+
+        model_config = vllm_config.model_config
+        self.head_dtype = model_config.head_dtype
+        self.bos_token_id = model_config.hf_config.bos_token_id
+        self.eos_token_id = model_config.hf_config.eos_token_id
+
+        super().__init__(vllm_config=vllm_config, prefix=prefix)
+        self.secondary_weight_prefix = "sparse_linear."
+
+        self.secondary_weights = [
+            DefaultModelLoader.Source(
+                model_or_path=vllm_config.model_config.model,
+                revision=None,
+                prefix=self.secondary_weight_prefix,
+                allow_patterns_overrides=["sparse_linear.pt"],
+            )
+        ]
+
+    def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
+        self.sparse_linear = nn.Linear(self.hidden_size, 1, dtype=self.head_dtype)
+        return DispatchPooler(
+            {
+                "embed": Pooler.for_embed(pooler_config),
+                "token_embed": BOSEOSFilter(
+                    Pooler.for_token_embed(pooler_config),
+                    self.bos_token_id,
+                    self.eos_token_id,
+                ),
+                "token_classify": BOSEOSFilter(
+                    Pooler.for_token_classify(
+                        pooler_config, classifier=self.sparse_linear, act_fn=torch.relu
+                    ),
+                    self.bos_token_id,
+                    self.eos_token_id,
+                ),
+            }
+        )
+
+    def load_weights(self, all_weights: Iterable[tuple[str, torch.Tensor]]):
+        secondary, weights = filter_secondary_weights(
+            all_weights, [self.secondary_weight_prefix]
+        )
+
+        super().load_weights(weights)
+
+        params_dict = dict(self.named_parameters())
+
+        for name, loaded_weight in secondary:
+            if name.startswith(self.secondary_weight_prefix):
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+
+
 @default_pooling_type("CLS")
 class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
     """A model that uses Roberta to provide embedding functionalities.
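Note (not part of the diff): _build_pooler wires the token_classify task to a one-output linear head with a ReLU activation, which matches the sparse ("lexical") weighting of BGE-M3 (arXiv:2402.03216): each token gets a single non-negative importance scalar. A sketch of that head in isolation, with invented shapes:

import torch
from torch import nn

hidden_size, seq_len = 1024, 6
sparse_linear = nn.Linear(hidden_size, 1)          # mirrors self.sparse_linear above
hidden_states = torch.randn(seq_len, hidden_size)  # one hidden state per token

# relu(W @ h + b) per token -> one non-negative sparse weight each.
weights = torch.relu(sparse_linear(hidden_states))
print(weights.squeeze(-1).shape)  # torch.Size([6])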