
Conversation

@maxdebayser
Contributor

@maxdebayser maxdebayser commented Mar 9, 2025

FIX #13609
FIX #15384
FIX #18469

Here I'm loading the extra sparse_linear.pt file using the secondary_weights loading mechanism introduced in the ultravox model when I detect that the model name is BAAI/bge-m3. It's a bit ugly, but I don't know if there is a more generic way to do this.

Currently, since the only permissible pooling return type is torch.Tensor, I'm just returning the token weights tensor directly. If the user wants to match tokens to the weights, they have to call tokenize and remove the BOS and EOS tokens; the indices of both vectors will then match.
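For illustration, the client-side alignment could look roughly like this (a sketch; match_tokens_to_weights is a hypothetical helper, and it assumes the server tokenizes with the model's standard tokenizer and returns one weight per non-special token, in order):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-m3")

def match_tokens_to_weights(text: str, weights: list[float]) -> dict[str, float]:
    # Tokenize the same way the server does, then drop the BOS and EOS
    # tokens so the remaining positions line up with the returned weights.
    token_ids = tokenizer(text)["input_ids"][1:-1]
    tokens = tokenizer.convert_ids_to_tokens(token_ids)
    return dict(zip(tokens, weights))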

To request sparse vectors, the user has to pass "additional_data": {"sparse_embeddings": true} in the request. This means that all sequences in that request will be treated as sparse. If the user wants to mix embedding types, separate calls have to be made for each type.

The FlagEmbedding API allows returning more than one type of embedding at the same time, but currently, due to the limitation of the pooling return type, we can only return a single tensor per sequence.

To show that this PoC is already returning the correct results, consider the code below:

from FlagEmbedding import BGEM3FlagModel

model = BGEM3FlagModel('BAAI/bge-m3',  use_fp16=True) # Setting use_fp16 to True speeds up computation with a slight performance degradation

sentences_1 = ["What is BGE M3?", "Defination of BM25"]

output_1 = model.encode(sentences_1, return_dense=True, return_sparse=True, return_colbert_vecs=False)
print(model.convert_id_to_token(output_1['lexical_weights']))

This code prints

[{'What': 0.08344, 'is': 0.08136, 'B': 0.1295, 'GE': 0.252, 'M': 0.1702, '3': 0.2695, '?': 0.04086}, {'De': 0.05023, 'fin': 0.1368, 'ation': 0.0452, 'of': 0.0635, 'BM': 0.2515, '25': 0.3337}]

With vLLM we get the following:

$ curl -s http://localhost:8000/v1/embeddings    -H "Content-Type: application/json"    -d '{
     "model": "BAAI/bge-m3",
     "input": ["What is BGE M3?", "Defination of BM25"],
     "additional_data": {"sparse_embeddings": true}
}' | jq
{
  "id": "embd-38ce076880b94d41b206ae99caae7b19",
  "object": "list",
  "created": 1741555561,
  "model": "BAAI/bge-m3",
  "data": [
    {
      "index": 0,
      "object": "embedding",
      "embedding": [
        0.0836181640625,
        0.08148193359375,
        0.1295166015625,
        0.251708984375,
        0.1700439453125,
        0.269775390625,
        0.040924072265625
      ]
    },
    {
      "index": 1,
      "object": "embedding",
      "embedding": [
        0.050201416015625,
        0.136962890625,
        0.04510498046875,
        0.0633544921875,
        0.25146484375,
        0.333740234375
      ]
    }
  ],
  "usage": {
    "prompt_tokens": 17,
    "total_tokens": 17,
    "completion_tokens": 0,
    "prompt_tokens_details": null
  }
}
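The same request can also be issued from Python; this is just a sketch using the requests library against the PoC endpoint described above:

import requests

# "additional_data" is the PR-specific extension described above for
# requesting sparse embeddings.
response = requests.post(
    "http://localhost:8000/v1/embeddings",
    json={
        "model": "BAAI/bge-m3",
        "input": ["What is BGE M3?", "Defination of BM25"],
        "additional_data": {"sparse_embeddings": True},
    },
)
for item in response.json()["data"]:
    print(item["index"], item["embedding"])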

@github-actions

github-actions bot commented Mar 9, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs will not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@DarkLight1337
Member

DarkLight1337 commented Mar 10, 2025

To support sparse+dense together, we need to actually implement #12249. I still don't have time to implement this though.

@maxdebayser
Contributor Author

I've changed the implementation so that the user now has to add --hf-overrides '{"architectures": ["BgeM3EmbeddingModel"]}' to the command line to activate this mode. But I agree that we need to implement #12249 to properly support this and other models like ibm-granite/granite-embedding-30m-sparse. Let's keep this PR in draft state for now.
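For reference, the full command is the same serve invocation shown later in the thread:

vllm serve BAAI/bge-m3 --hf-overrides '{"architectures": ["BgeM3EmbeddingModel"]}'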

@243006306

This is great; looking forward to the launch of this feature. How long will it take to become available?

@IllyaPysarchuk

+1, waiting for this feature.

@mergify

mergify bot commented Apr 1, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @maxdebayser.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Apr 1, 2025
@arjunasuresh300

+1

@Sam120204

any update?

@maxdebayser
Contributor Author

The V1 embedding PR is already approved but is now blocked by other unrelated test failures: #16188. The next step will be to add support for encoder models, as they were left out of the embedding model PR to keep it simpler.

@fufenghua

Is this still not supported?

@mergify mergify bot added the new-model (Requests to new models) label Jul 11, 2025
@DarkLight1337
Member

I think this should be possible now that we support multiple poolers

@maxdebayser
Contributor Author

I think this should be possible now that we support multiple poolers
We can select the embedding type per request, right? But can we have multiple pooling strategies applied to the same request?
Anyway, I'll revive this PR to support one pooling type per request for now.

@DarkLight1337
Member

DarkLight1337 commented Jul 24, 2025

We can support a different task per request in the model runner, but this isn't exposed in the API server yet

Now with the pooling task framework

Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
@maxdebayser
Contributor Author

@DarkLight1337 , I've updated the PR now that we have V1 embeddings and the new task refactoring. The new request form is:

curl -s http://localhost:8000/pooling    -H "Content-Type: application/json"    -d '{
     "model": "BAAI/bge-m3",
     "task": "embed-sparse",
     "input": ["What is BGE M3?", "Defination of BM25"]
}' | jq
{
  "id": "pool-f3ea25d3e28d4b40b686092badd99f91",
  "object": "list",
  "created": 1755018267,
  "model": "BAAI/bge-m3",
  "data": [
    {
      "index": 0,
      "object": "pooling",
      "data": [
        0.08349609375,
        0.0814208984375,
        0.1295166015625,
        0.251708984375,
        0.1700439453125,
        0.26953125,
        0.04083251953125
      ]
    },
    {
      "index": 1,
      "object": "pooling",
      "data": [
        0.05010986328125,
        0.136962890625,
        0.045013427734375,
        0.06341552734375,
        0.25146484375,
        0.33349609375
      ]
    }
  ],
  "usage": {
    "prompt_tokens": 17,
    "total_tokens": 17,
    "completion_tokens": 0,
    "prompt_tokens_details": null
  }
}

As a PoC, I created a new task "embed-sparse", but I'm not 100% happy with it: I don't think it will scale if we have to add many different new tasks. Maybe we should add model-defined sub-tasks that the dispatcher can use to route the requests.

Another point is that the output is not very expressive. To get the tokens, the user would have to call tokenize and match the tokens with the embeddings by position. I think we should make the PoolingResponse more generic to support task-specific outputs. This is related to the discussion in #21621.

Finally, I'm not sure what the best way to test this model is. We could test it against the outputs of the FlagEmbedding library, but that means we would have to add yet another dependency, and I think we already have too many. Maybe we could just test a request against a known output.

@DarkLight1337
Member

I'm not 100% happy with it, I don't think it will scale if we have to add many different new tasks

Agreed. Currently we allow the Pooler to define its own list of supported tasks, but in order for those tasks to work, we also have to update the PoolingParams checking and request dispatching, which could be quite complicated. Having subtasks would allow us to keep using the existing logic for the base task.

@DarkLight1337
Member

Another point is that the output is not very expressive. To get the tokens the user would have to have to call tokenize and match the tokens with the embeddings by position. I think we should make the PoolingResponse more generic to add task-specific outputs.

Yeah, I now see the need for a registry for each task to override how the response is transformed. This would greatly improve the user experience when using the encode method.

@DarkLight1337
Member

Finally, I'm not sure what the best way to test this model is.

We can generate the ground truth locally using FlagEmbedding (set up a helper function so it is easy for us to update the result in case of version changes), and then inside the CI we compare our impl to those generated results.
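A minimal sketch of such a helper (names are hypothetical; it regenerates the reference lexical weights with FlagEmbedding so the stored file can be refreshed when versions change):

import json

def regenerate_ground_truth(path="bge_m3_lexical_weights.json"):
    # Run FlagEmbedding locally to produce the reference lexical weights;
    # CI then compares vLLM's output against the stored file.
    from FlagEmbedding import BGEM3FlagModel

    model = BGEM3FlagModel("BAAI/bge-m3", use_fp16=True)
    sentences = ["What is BGE M3?", "Defination of BM25"]
    output = model.encode(sentences, return_dense=False, return_sparse=True)
    weights = [
        {token: float(w) for token, w in entry.items()}
        for entry in model.convert_id_to_token(output["lexical_weights"])
    ]
    with open(path, "w") as f:
        json.dump({"sentences": sentences, "lexical_weights": weights}, f, indent=2)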

Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
@maxdebayser maxdebayser marked this pull request as ready for review October 16, 2025 20:09

@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.


Comment on lines 1666 to 1680
class PoolingCompletionRequest(EmbeddingCompletionRequest):
    task: str | None = None

    def to_pooling_params(self):
        return PoolingParams(
            dimensions=self.dimensions, normalize=self.normalize, task=self.task
        )


class PoolingChatRequest(EmbeddingChatRequest):
    task: str | None = None

    def to_pooling_params(self):
        return PoolingParams(
            dimensions=self.dimensions, normalize=self.normalize, task=self.task
        )

P1: Preserve truncate_prompt_tokens in pooling requests

The new PoolingCompletionRequest.to_pooling_params() and PoolingChatRequest.to_pooling_params() no longer pass truncate_prompt_tokens to PoolingParams. Prior to this change, callers could limit the prompt length by setting truncate_prompt_tokens and the value was forwarded in Embedding*Request.to_pooling_params. After the refactor, any truncate_prompt_tokens sent with a pooling request is silently ignored, so long prompts will no longer be truncated even though the API accepts the parameter. This can lead to unexpectedly long contexts or failure when inputs exceed the model’s max length.

Useful? React with 👍 / 👎.
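A minimal sketch of one way to address this, delegating to the parent class so existing parameters such as truncate_prompt_tokens keep flowing through:

class PoolingCompletionRequest(EmbeddingCompletionRequest):
    task: str | None = None

    def to_pooling_params(self):
        # Reuse the parent implementation so truncate_prompt_tokens and
        # the other fields are still forwarded, then attach the new task.
        params = super().to_pooling_params()
        params.task = self.task
        return params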

Comment on lines 181 to 197
try:
    pooling_params = request.to_pooling_params()

    if "token_embed" in self.supported_tasks:
        pooling_task = "token_embed"
    elif "token_classify" in self.supported_tasks:
        pooling_task = "token_classify"
    else:
        return self.create_error_response(
            f"pooling_task must be one of {self.supported_tasks}."
        )
    if pooling_params.task is None:
        if "token_embed" in self.supported_tasks:
            pooling_task = "token_embed"
        elif "token_classify" in self.supported_tasks:
            pooling_task = "token_classify"
        else:
            return self.create_error_response(
                f"pooling_task must be one of {self.supported_tasks}."
            )

        try:
            pooling_params.verify(pooling_task, self.model_config)
        except ValueError as e:
            return self.create_error_response(str(e))
    else:
        if pooling_params.task not in self.supported_tasks:
            raise ValueError(f"Task {pooling_params.task} is not supported")


P1 Badge Validate pooling params even when task provided

When the client now supplies task in the pooling request, the server only checks membership in supported_tasks and skips pooling_params.verify. That verification step previously filled in default values (e.g. normalize embeddings, apply classification activations) and rejected incompatible parameters. With the new branch, normalize/activation stay None and no validation runs, so explicit task requests return un‑normalized embeddings and token classifications without the configured activation (e.g. the ReLU for sparse weights), and invalid parameter combinations are never rejected. PoolingParams.verify(pooling_params.task, …) still needs to run in this path.

Useful? React with 👍 / 👎.
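A sketch of the missing step the bot points out, assuming the verify call keeps its current signature:

else:
    if pooling_params.task not in self.supported_tasks:
        raise ValueError(f"Task {pooling_params.task} is not supported")
    # Run verification in this path too, so defaults (normalization,
    # activations) are applied and invalid combinations are rejected.
    pooling_params.verify(pooling_params.task, self.model_config)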

@maxdebayser
Contributor Author

Now that @noooop has added support for multi-vector retrieval with the token_embed and token_classify tasks, I've refactored this PR in terms of these tasks.

To start the server, the architecture has to be overridden, because otherwise the extra weight file for sparse embeddings (lexical weights) won't be loaded:

vllm serve BAAI/bge-m3 --hf-overrides '{"architectures": ["BgeM3EmbeddingModel"]}'

With this setting, the server supports regular dense embedding, token_embed and token_classify:

# "token_classify" returns the lexical weights
curl -s http://localhost:8000/pooling -H "Content-Type: application/json" -d '{
     "model": "BAAI/bge-m3",
     "task": "token_classify",
     "input": ["What is BGE M3?", "Defination of BM25"]
}'
curl -s http://localhost:8000/pooling    -H "Content-Type: application/json"    -d '{
     "model": "BAAI/bge-m3",
     "task": "token_embed",
     "input": ["What is BGE M3?", "Defination of BM25"]
}'

Please note that the token_classify request will return an array of scores, not a dict mapping decoded tokens to their scores. The API currently doesn't support rich formats like that.
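For client code that wants the FlagEmbedding-style token-to-weight mapping back, a sketch like the following should work. It assumes the scores arrive in token order with BOS/EOS already filtered out server-side, as described above:

import requests
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-m3")
prompts = ["What is BGE M3?", "Defination of BM25"]

resp = requests.post(
    "http://localhost:8000/pooling",
    json={"model": "BAAI/bge-m3", "task": "token_classify", "input": prompts},
)
for prompt, item in zip(prompts, resp.json()["data"]):
    # Drop BOS/EOS locally so token strings line up with the scores.
    ids = tokenizer(prompt)["input_ids"][1:-1]
    tokens = tokenizer.convert_ids_to_tokens(ids)
    print(dict(zip(tokens, item["data"])))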

The lexical weights can also be retrieved with the offline API:

from vllm import LLM

# Using the same example prompts as above.
prompts = ["What is BGE M3?", "Defination of BM25"]

llm = LLM(
    model="BAAI/bge-m3",
    runner="pooling",
    enforce_eager=True,
    hf_overrides={"architectures": ["BgeM3EmbeddingModel"]},
)

outputs = llm.encode(prompts, pooling_task="token_classify")
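Reading the results might then look like this (a sketch; the outputs.data field name is assumed from the pooling output API):

for prompt, output in zip(prompts, outputs):
    # One score per non-special token; BOS/EOS are filtered server-side.
    print(prompt, output.outputs.data)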

cc: @DarkLight1337

@maxdebayser maxdebayser changed the title First working PoC for bge-m3 sparse embeddings Support bge-m3 sparse embeddings (lexical weights) Oct 16, 2025
Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
Comment on lines +1666 to +1681
class PoolingCompletionRequest(EmbeddingCompletionRequest):
    task: str | None = None

    def to_pooling_params(self):
        params = super().to_pooling_params()
        params.task = self.task
        return params


class PoolingChatRequest(EmbeddingChatRequest):
    task: str | None = None

    def to_pooling_params(self):
        params = super().to_pooling_params()
        params.task = self.task
        return params
Collaborator


I plan to add the task parameter in #25524 and make it required. Thanks for adding it now.

Comment on lines -184 to 196

if "token_embed" in self.supported_tasks:
pooling_task = "token_embed"
elif "token_classify" in self.supported_tasks:
pooling_task = "token_classify"
if pooling_params.task is None:
if "token_embed" in self.supported_tasks:
pooling_task = "token_embed"
elif "token_classify" in self.supported_tasks:
pooling_task = "token_classify"
else:
pooling_task = pooling_params.task

if pooling_task not in self.supported_tasks:
return self.create_error_response(
f"pooling_task must be one of {self.supported_tasks}."
f"Task {pooling_task} is not supported, it"
f" must be one of {self.supported_tasks}."
)
Collaborator


I plan to make the task parameter required in #25524, which can simplify this logic.

Comment on lines +210 to +226
return DispatchPooler(
    {
        "embed": Pooler.for_embed(pooler_config),
        "token_embed": BOSEOSFilter(
            Pooler.for_token_embed(pooler_config),
            self.bos_token_id,
            self.eos_token_id,
        ),
        "token_classify": BOSEOSFilter(
            Pooler.for_token_classify(
                pooler_config, classifier=self.sparse_linear, act_fn=torch.relu
            ),
            self.bos_token_id,
            self.eos_token_id,
        ),
    }
)
Collaborator

@noooop noooop Oct 17, 2025


Cool !

BGE-M3 Multi-Functionality:

  • embed for dense retrieval
  • token_embed for multi-vector retrieval
  • token_classify for sparse retrieval

Nothing stops us from using a plugin task to output everything at once (after #26973 lands).

This way, BGE-M3 will be the best demonstration of the flexibility of our new pooler API.

@DarkLight1337 You must come and see this


Please add examples to demonstrate how users can use it, as well as tests to guard this feature.

Collaborator


I think the best approach is to use a plugin task to output everything at once. This is more efficient.

This may need to coordinate with #26973

I think a separate PR is still needed to inform everyone that the plugin pooling task has been added, although this PR makes few code changes.

Please feel free to modify anything in #26973, as well as any PR of mine.


Labels

frontend, new-model (Requests to new models)

Projects

None yet

8 participants