20 changes: 18 additions & 2 deletions vllm/entrypoints/openai/protocol.py
@@ -1662,8 +1662,24 @@ def to_pooling_params(self):

EmbeddingRequest: TypeAlias = EmbeddingCompletionRequest | EmbeddingChatRequest

-PoolingCompletionRequest = EmbeddingCompletionRequest
-PoolingChatRequest = EmbeddingChatRequest
+class PoolingCompletionRequest(EmbeddingCompletionRequest):
+    task: str | None = None
+
+    def to_pooling_params(self):
+        params = super().to_pooling_params()
+        params.task = self.task
+        return params
+
+
+class PoolingChatRequest(EmbeddingChatRequest):
+    task: str | None = None
+
+    def to_pooling_params(self):
+        params = super().to_pooling_params()
+        params.task = self.task
+        return params
Comment on lines +1666 to +1681

Collaborator: I plan to add the task parameter in #25524 and make it required. Thanks for adding it now.
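
For reviewers who want to poke at this: a minimal client-side sketch of the new field (the payload shape is taken from this diff; the server address, model name, and use of `requests` are assumptions):

```python
import requests  # any HTTP client works; `requests` is an assumption

# Assumes a vLLM server with a pooling-capable model running locally.
resp = requests.post(
    "http://localhost:8000/pooling",
    json={
        "model": "BAAI/bge-m3",       # assumed model name
        "input": "What is BGE-M3?",
        "task": "token_classify",     # the new optional field
    },
)
print(resp.json())
```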



T = TypeVar("T")

20 changes: 11 additions & 9 deletions vllm/entrypoints/openai/serving_pooling.py
@@ -181,19 +181,21 @@ async def create_pooling(
        try:
            pooling_params = request.to_pooling_params()

-            if "token_embed" in self.supported_tasks:
-                pooling_task = "token_embed"
-            elif "token_classify" in self.supported_tasks:
-                pooling_task = "token_classify"
+            if pooling_params.task is None:
+                if "token_embed" in self.supported_tasks:
+                    pooling_task = "token_embed"
+                elif "token_classify" in self.supported_tasks:
+                    pooling_task = "token_classify"
+            else:
+                pooling_task = pooling_params.task

            if pooling_task not in self.supported_tasks:
                return self.create_error_response(
-                    f"pooling_task must be one of {self.supported_tasks}."
+                    f"Task {pooling_task} is not supported, it"
+                    f" must be one of {self.supported_tasks}."
                )
Comment on lines -184 to 196

Collaborator: I plan to make the task parameter required in #25524, which can simplify this logic.


-            try:
-                pooling_params.verify(pooling_task, self.model_config)
-            except ValueError as e:
-                return self.create_error_response(str(e))
+            pooling_params.verify(pooling_task, self.model_config)

            for i, engine_prompt in enumerate(engine_prompts):
                request_id_item = f"{request_id}-{i}"
45 changes: 45 additions & 0 deletions vllm/model_executor/layers/pooler.py
@@ -803,3 +803,48 @@ def forward(
    def extra_repr(self) -> str:
        s = f"supported_task={self.get_supported_tasks()}"
        return s


class BOSEOSFilter(Pooler):
    """Filters the BOS and EOS token results from outputs."""

    def __init__(
        self,
        pooler: Pooler,
        bos_token_id: int,
        eos_token_id: int,
    ) -> None:
        super().__init__()

        self.pooler = pooler
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return self.pooler.get_supported_tasks()

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        # Token ids are required so BOS/EOS positions can be detected.
        return PoolingParamsUpdate(requires_token_ids=True)

    def forward(
        self,
        hidden_states: torch.Tensor | list[torch.Tensor],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        pooled_outputs = self.pooler(hidden_states, pooling_metadata)
        assert isinstance(pooled_outputs, list)

        for i, prompt_len in enumerate(pooling_metadata.prompt_lens):
            pooled_data = pooled_outputs[i]
            assert (
                isinstance(pooled_data, torch.Tensor)
                and pooled_data.shape[0] == prompt_len
            )
            token_ids = pooling_metadata.prompt_token_ids[i]
            if token_ids[0] == self.bos_token_id:
                pooled_data = pooled_data[1:]
            if token_ids[-1] == self.eos_token_id:
                pooled_data = pooled_data[:-1]
            pooled_outputs[i] = pooled_data.squeeze()

        return pooled_outputs
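
A self-contained toy illustration of the trimming behavior above (made-up token ids and tensor shapes; no vLLM types involved):

```python
import torch

BOS, EOS = 0, 2  # XLM-R-style special token ids; illustrative assumption

token_ids = [0, 17, 23, 42, 2]            # <s> tok tok tok </s>
pooled = torch.randn(len(token_ids), 4)   # (prompt_len, hidden_size)

if token_ids[0] == BOS:
    pooled = pooled[1:]   # drop the BOS row
if token_ids[-1] == EOS:
    pooled = pooled[:-1]  # drop the EOS row

print(pooled.shape)  # torch.Size([3, 4])
```
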
1 change: 1 addition & 0 deletions vllm/model_executor/models/registry.py
@@ -201,6 +201,7 @@
"RobertaModel": ("roberta", "RobertaEmbeddingModel"),
"TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
"XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
"BgeM3EmbeddingModel": ("roberta", "BgeM3EmbeddingModel"),
# [Multimodal]
"CLIPModel": ("clip", "CLIPEmbeddingModel"),
"LlavaNextForConditionalGeneration": (
83 changes: 82 additions & 1 deletion vllm/model_executor/models/roberta.py
@@ -1,20 +1,24 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

+import itertools
from collections.abc import Iterable

import torch
from torch import nn
from transformers import RobertaConfig

-from vllm.config import ModelConfig, VllmConfig
+from vllm.config import ModelConfig, PoolerConfig, VllmConfig
from vllm.model_executor.layers.pooler import (
+    BOSEOSFilter,
    ClassifierPooler,
    CLSPool,
    DispatchPooler,
    Pooler,
)
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
+from vllm.model_executor.model_loader.default_loader import DefaultModelLoader
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.bert import (
    TOKEN_TYPE_SHIFT,
    BertEmbeddingModel,
@@ -160,6 +164,83 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        return loader.load_weights(weights_list, mapper=mapper)


def filter_secondary_weights(
    all_weights: Iterable[tuple[str, torch.Tensor]],
    secondary_weights: list[str],
) -> tuple[Iterable[tuple[str, torch.Tensor]], Iterable[tuple[str, torch.Tensor]]]:
    # Duplicate the weight stream with itertools.tee so it can be split into
    # (secondary, primary) generators without materializing it.
    all_weights1, all_weights2 = itertools.tee(all_weights)

    def filtered(n):
        return any(n.startswith(f) for f in secondary_weights)

    return ((n, w) for n, w in all_weights1 if filtered(n)), (
        (n, w) for n, w in all_weights2 if not filtered(n)
    )
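
A quick sanity sketch of this helper (fake weight names and tensors, purely illustrative):

```python
# Assumes filter_secondary_weights from this file is importable.
import torch

weights = [
    ("sparse_linear.weight", torch.zeros(1, 4)),
    ("roberta.embeddings.word_embeddings.weight", torch.zeros(8, 4)),
]
secondary, primary = filter_secondary_weights(weights, ["sparse_linear."])
print([n for n, _ in secondary])  # ['sparse_linear.weight']
print([n for n, _ in primary])    # ['roberta.embeddings.word_embeddings.weight']
```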


class BgeM3EmbeddingModel(RobertaEmbeddingModel):
    """A model that extends RobertaEmbeddingModel with sparse embeddings.

    This class supports loading an additional sparse_linear.pt file
    to create sparse embeddings as described in https://arxiv.org/abs/2402.03216
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        self.hidden_size = vllm_config.model_config.hf_config.hidden_size

        model_config = vllm_config.model_config
        self.head_dtype = model_config.head_dtype
        self.bos_token_id = model_config.hf_config.bos_token_id
        self.eos_token_id = model_config.hf_config.eos_token_id

        super().__init__(vllm_config=vllm_config, prefix=prefix)
        self.secondary_weight_prefix = "sparse_linear."

        self.secondary_weights = [
            DefaultModelLoader.Source(
                model_or_path=vllm_config.model_config.model,
                revision=None,
                prefix=self.secondary_weight_prefix,
                allow_patterns_overrides=["sparse_linear.pt"],
            )
        ]

    def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
        # sparse_linear maps each token's hidden state to a single weight.
        self.sparse_linear = nn.Linear(self.hidden_size, 1, dtype=self.head_dtype)
        return DispatchPooler(
            {
                "embed": Pooler.for_embed(pooler_config),
                "token_embed": BOSEOSFilter(
                    Pooler.for_token_embed(pooler_config),
                    self.bos_token_id,
                    self.eos_token_id,
                ),
                "token_classify": BOSEOSFilter(
                    Pooler.for_token_classify(
                        pooler_config, classifier=self.sparse_linear, act_fn=torch.relu
                    ),
                    self.bos_token_id,
                    self.eos_token_id,
                ),
            }
        )
Comment on lines +210 to +226

Collaborator @noooop commented on Oct 17, 2025:

Cool!

BGE-M3 Multi-Functionality:

  • embed for dense retrieval
  • token_embed for multi-vector retrieval
  • token_classify for sparse retrieval

Nothing stops us from using a plugin task to output everything at once (after #26973 lands).

This way, BGE-M3 will be the best demonstration of the flexibility of our new pooler API.

@DarkLight1337 You must come and see this.

Please add examples to demonstrate how users can use it, as well as tests to guard this feature.
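
A hedged end-to-end sketch of the three modes over HTTP (server address, launch setup, and response handling are assumptions; the task names come from the pooler mapping above):

```python
import requests  # assumed HTTP client

BASE = "http://localhost:8000"  # assumes a local vLLM server running BAAI/bge-m3
payload = {"model": "BAAI/bge-m3", "input": "What is BGE-M3?"}

# Dense retrieval: one vector per input, via the embeddings API.
dense = requests.post(f"{BASE}/v1/embeddings", json=payload).json()

# Multi-vector retrieval: one vector per token, via the pooling API.
multi = requests.post(
    f"{BASE}/pooling", json={**payload, "task": "token_embed"}
).json()

# Sparse retrieval: one per-token weight from the sparse_linear head.
sparse = requests.post(
    f"{BASE}/pooling", json={**payload, "task": "token_classify"}
).json()
```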

Collaborator: I think the best approach is to use a plugin task to output everything all at once; this is more efficient. It may need to coordinate with #26973.

I think a separate PR is still needed to inform everyone that the plugin pooling task has been added, although this PR makes few code changes.

Please feel free to modify anything in #26973, as well as any PR of mine.


    def load_weights(self, all_weights: Iterable[tuple[str, torch.Tensor]]):
        secondary, weights = filter_secondary_weights(
            all_weights, [self.secondary_weight_prefix]
        )

        super().load_weights(weights)

        params_dict = dict(self.named_parameters())

        # Load sparse_linear.pt weights directly; they are not part of the
        # base Roberta checkpoint handled by super().load_weights().
        for name, loaded_weight in secondary:
            if name.startswith(self.secondary_weight_prefix):
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)


@default_pooling_type("CLS")
class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
"""A model that uses Roberta to provide embedding functionalities.