[Feature] Add support for naver/splade-v3 (BERT-based sparse embedding model) #26339
Conversation
Signed-off-by: gjgjos <gjgjos@naver.com>
Code Review
This pull request adds support for the naver/splade-v3 sparse embedding model by introducing BertSpladeSparseEmbeddingModel and SPLADESparsePooler. The implementation is well-tested and demonstrates correctness against Hugging Face and TEI frameworks. My review focuses on improving the robustness and maintainability of the new BertSpladeSparseEmbeddingModel class, particularly in the load_weights method, where I've identified opportunities for optimization and safer error handling.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
💡 Codex Review
Here are some automated review suggestions for this pull request.
        
          
vllm/model_executor/models/bert.py (outdated)
          
        
```python
@torch.no_grad()
def forward(
    self,
    hidden_states: Union[torch.Tensor, list[torch.Tensor]],
    pooling_metadata: PoolingMetadata,
) -> Union[torch.Tensor, list[torch.Tensor]]:
    if isinstance(hidden_states, torch.Tensor):
        hs_list = [hidden_states]
    else:
        hs_list = list(hidden_states)

    for i, hs in enumerate(hs_list):
        if hs.dim() == 3 and hs.size(0) == 1:
            hs_list[i] = hs.squeeze(0)  # [L, H]
        elif hs.dim() != 2:
            raise ValueError(f"Expected [L,H] or [1,L,H], got {tuple(hs.shape)}")

    B = len(hs_list)
    H = hs_list[0].size(-1)

    raw_lens = getattr(pooling_metadata, "prompt_lens", None)

    def _fallback_lens_from_hs():
        return [int(h.size(0)) for h in hs_list]

    if raw_lens is None:
        lens = _fallback_lens_from_hs()
    elif isinstance(raw_lens, int):
        lens = [int(raw_lens)] * B
    else:
        try:
            tmp = list(raw_lens)
            if len(tmp) == B:
                lens = [int(x) for x in tmp]
            elif len(tmp) == 1:
                lens = [int(tmp[0])] * B
            else:
                lens = _fallback_lens_from_hs()
        except TypeError:
            lens = _fallback_lens_from_hs()

    max_len = max(int(h.size(0)) for h in hs_list)
    device = hs_list[0].device

    # pad to [B, T, H]
    padded = hs_list[0].new_zeros((B, max_len, H))  # zeros
    attn_mask = torch.zeros((B, max_len), dtype=torch.bool, device=device)

    for i, (hs, L) in enumerate(zip(hs_list, lens)):
        L = int(L)
        L = min(L, max_len)
        padded[i, :L] = hs[:L]
        attn_mask[i, :L] = True
```
  Pooler ignores batching layout and drops extra requests
The new SPLADESparsePooler.forward wraps the incoming hidden_states tensor into a single item whenever it is a 2‑D tensor (lines 638‑649) and never consults the pooling_metadata.pooling_cursor that encodes how multiple requests are concatenated. In the vLLM runner, embeddings are pooled from a single [total_tokens, hidden] tensor containing all prompts in a batch. With the current logic only the first prompt in the batch is padded and pooled while the remaining prompts are silently ignored, causing incorrect or missing embeddings whenever more than one request is processed together. The pooler should use pooling_cursor (as done in SimplePooler) to split the tensor per request before applying the MLM head.
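As an illustration of the reviewer's point, splitting a concatenated batch back into per-request chunks before pooling could look roughly like this (a minimal sketch driven only by per-prompt lengths; the actual vLLM pooler should take the offsets from pooling_metadata / pooling_cursor):

```python
import torch

def split_by_request(hidden_states: torch.Tensor,
                     prompt_lens: list[int]) -> list[torch.Tensor]:
    """Split a flat [total_tokens, hidden] tensor into per-request chunks.

    vLLM concatenates all prompts of a batch along the token dimension,
    so a pooler has to slice the tensor per request before applying the
    MLM head; pooling only the first chunk drops the other requests.
    """
    assert hidden_states.size(0) == sum(prompt_lens)
    return list(torch.split(hidden_states, prompt_lens, dim=0))

# Toy example: two prompts of 4 and 7 tokens, hidden size 16.
chunks = split_by_request(torch.randn(11, 16), [4, 7])
assert [c.size(0) for c in chunks] == [4, 7]
```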
@hmellor I guess the transformers backend can't really handle custom poolers based on the current design, right?

Right now, no, there is no way to register custom poolers. It wouldn't be too hard to add one. Or do you mean a mechanism to register custom poolers in the Transformers backend with no upstream changes?

Yeah, that's what I'm thinking. I guess implementing this in vLLM is the most reasonable solution without upstream changes then.
Force-pushed 2799f7f to 3106979
Oh, that wouldn't require any upstream changes. These changes would be made in https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/transformers_pooling.py

The only caveat is that it would mean users have to install Transformers from source, because the Transformers-side refactor that enables the Transformers backend for BERT models is not in a release yet.
…h.no_grad() (handled by vLLM framework)
- Added model loading entry to tests/models/registry.py
- Added SPLADESparsePooler functional + smoke tests to ensure future stability

Signed-off-by: gjgjos <gjgjos@naver.com>
Force-pushed 3ab178a to 657860b
/gemini review
Code Review
This pull request adds support for the naver/splade-v3 sparse embedding model. The implementation is well-structured, introducing BertSpladeSparseEmbeddingModel and SPLADESparsePooler. The accompanying tests are thorough, covering both functional correctness and integration with vLLM's serving capabilities.
My review identifies two high-severity issues. First, a broad except Exception: pass in the weight loading logic could mask critical errors and lead to silent failures. Second, the SPLADE pooling method is hardcoded to 'max', preventing users from selecting the 'sum' method, which is mentioned as supported. Addressing these points will improve the robustness and configurability of the new model support.
        
          
vllm/model_executor/models/bert.py (outdated)
          
        
```python
try:
    emb_w = self.model.embeddings.word_embeddings.weight
    dec_w = self.mlm_head.decoder.weight
    if dec_w.shape == emb_w.shape and dec_w.data_ptr() != emb_w.data_ptr():
        self.mlm_head.decoder.weight = emb_w
except Exception:
    pass
```
The try...except Exception: pass block is too broad and can hide important errors during weight loading. For instance, if self.model.embeddings or other attributes do not exist due to a model structure mismatch, an AttributeError would be silently ignored, making debugging difficult. This could lead to weights not being tied when they should be, resulting in incorrect model behavior. It's better to catch more specific exceptions, like AttributeError, or at least log a warning if an exception occurs.
Suggested change:

```python
# current
try:
    emb_w = self.model.embeddings.word_embeddings.weight
    dec_w = self.mlm_head.decoder.weight
    if dec_w.shape == emb_w.shape and dec_w.data_ptr() != emb_w.data_ptr():
        self.mlm_head.decoder.weight = emb_w
except Exception:
    pass
```

```python
# suggested
try:
    emb_w = self.model.embeddings.word_embeddings.weight
    dec_w = self.mlm_head.decoder.weight
    if dec_w.shape == emb_w.shape and dec_w.data_ptr() != emb_w.data_ptr():
        self.mlm_head.decoder.weight = emb_w
except AttributeError:
    # It's possible that some BERT variants may not have this structure.
    # If we can't find the weights to tie, it's not a critical
    # error, as the model can still function with untied weights.
    pass
```
Force-pushed 58d045b to 83c2b7d
    Signed-off-by: gjgjos <gjgjos@naver.com>
Force-pushed 83c2b7d to 706a735
Thanks for the refactoring, @gjgjos, it looks very clean now. I've left a few extra comments based on the fact that we no longer need to handle the case where hidden_states is a list, because that was deprecated.
Signed-off-by: gjgjos <gjgjos@naver.com>
Thanks! I've cleaned up the remaining list-handling logic as suggested; the code now fully assumes a single concatenated tensor.

Stamping, I assume you have tested this model already.

Yes, I've already tested it; everything works as expected. Thank you for your help!
Hi @gjgjos, I noticed that the two newly added tests are failing on the main branch, for both CUDA and CPU.

```
(EngineCore_DP0 pid=1338945) ERROR 10-13 05:39:04 [core.py:790]   File "/workspace/vllm/vllm/v1/worker/gpu_worker.py", line 229, in load_model
(EngineCore_DP0 pid=1338945) ERROR 10-13 05:39:04 [core.py:790]     self.model_runner.load_model(eep_scale_up=eep_scale_up)
(EngineCore_DP0 pid=1338945) ERROR 10-13 05:39:04 [core.py:790]   File "/workspace/vllm/vllm/v1/worker/cpu_model_runner.py", line 68, in load_model
(EngineCore_DP0 pid=1338945) ERROR 10-13 05:39:04 [core.py:790]     self.model = get_model(vllm_config=self.vllm_config)
(EngineCore_DP0 pid=1338945) ERROR 10-13 05:39:04 [core.py:790]                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=1338945) ERROR 10-13 05:39:04 [core.py:790]   File "/workspace/vllm/vllm/model_executor/model_loader/__init__.py", line 130, in get_model
(EngineCore_DP0 pid=1338945) ERROR 10-13 05:39:04 [core.py:790]     return loader.load_model(vllm_config=vllm_config, model_config=model_config)
(EngineCore_DP0 pid=1338945) ERROR 10-13 05:39:04 [core.py:790]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=1338945) ERROR 10-13 05:39:04 [core.py:790]   File "/workspace/vllm/vllm/model_executor/model_loader/base_loader.py", line 55, in load_model
(EngineCore_DP0 pid=1338945) ERROR 10-13 05:39:04 [core.py:790]     self.load_weights(model, model_config)
(EngineCore_DP0 pid=1338945) ERROR 10-13 05:39:04 [core.py:790]   File "/workspace/vllm/vllm/model_executor/model_loader/default_loader.py", line 323, in load_weights
(EngineCore_DP0 pid=1338945) ERROR 10-13 05:39:04 [core.py:790]     raise ValueError(
(EngineCore_DP0 pid=1338945) ERROR 10-13 05:39:04 [core.py:790] ValueError: Following weights were not initialized from checkpoint: {'mlm_head.decoder.bias'}
```

Any idea about this?
Will fix in #25817.
Purpose
This PR adds official support for the `naver/splade-v3` model, a BERT-based sparse retrieval model utilizing the SPLADE pooling mechanism.

The implementation introduces the `BertSpladeSparseEmbeddingModel` class, extending `BertEmbeddingModel` to generate sparse lexical embeddings from the MLM head output (log1p(ReLU(logits))), fully compatible with vLLM's embedding API (`/v1/embeddings` and `/pooling` endpoints).

This enables users to serve SPLADE models via vLLM with high performance and verified consistency against Hugging Face's `SparseEncoder` and TEI (Text Embeddings Inference) frameworks.

Implementation Details
New model registration
Architecture
- Backbone: `bert`
- Head: MLM head (`cls.predictions.*`)
- Pooling: `SPLADESparsePooler` (supports `max` or `sum`)
- Output: sparse lexical embedding vector (dimension = vocab size ≈ 30k)
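For reference, the SPLADE transformation listed above (MLM-head logits passed through log1p(ReLU(·)), then reduced over the token dimension with max or sum) can be sketched in plain PyTorch as follows; this is an illustrative sketch of the pooling math, not the vLLM implementation:

```python
import torch

def splade_activation(logits: torch.Tensor, method: str = "max") -> torch.Tensor:
    """SPLADE sparse embedding from MLM-head logits of shape [L, vocab].

    log1p(ReLU(logits)) keeps only positive evidence per token, then the
    token dimension is reduced with either max or sum, yielding one weight
    per vocabulary entry (~30k dims), mostly zeros.
    """
    scores = torch.log1p(torch.relu(logits))
    if method == "max":
        return scores.max(dim=0).values
    if method == "sum":
        return scores.sum(dim=0)
    raise ValueError(f"unknown pooling method: {method}")

# Toy example: 5 tokens over a 30k-entry vocabulary.
emb = splade_activation(torch.randn(5, 30522))
print(emb.shape, int((emb > 0).sum()))  # torch.Size([30522]) <nonzero count>
```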
Modified files
- `bert.py` → added `BertSpladeSparseEmbeddingModel`
- `registry.py` → registered the model under the "bert" family

Test Plan
1️⃣ vLLM-based Docker serving
Run script
Server log highlights
✅ Successfully initialized with torch.compile graph caching and KVCache disabled (sparse embedding mode).
The `/v1/embeddings` route was available for inference.

2️⃣ vLLM Inference Test (Python Client) — Actual response & parsed preview
Request
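The original request script is not reproduced here; a minimal client call along these lines (a hypothetical sketch using the requests library, with host, port, and input text assumed, and the model name taken from the response below) exercises the /v1/embeddings route:

```python
import requests

# Hypothetical request; host/port and input text are assumptions.
resp = requests.post(
    "http://localhost:8000/v1/embeddings",
    json={"model": "splade-v3", "input": "The capital of France is Paris."},
    timeout=30,
)
resp.raise_for_status()
payload = resp.json()
print(payload["model"], payload["usage"])
```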
Actual response JSON (shape)
{ "id": "embd-c1899570dd224953adf527b49be8120e", "object": "list", "created": 1759815423, "model": "splade-v3", "data": { "embeddings": [ /* ... dense array of size ~30k, mostly zeros, e.g. 0, ..., 1.08984375, 0.55126953125, 0.0, 0.16845703125, 0.0, 0.0, 0.308837890625, 0.0, 0.0, 1.689453125, 0.0, 0.671875, 0.0, 1.255859375, ... */ ] }, "usage": { "prompt_tokens": 9, "total_tokens": 9, "completion_tokens": 0, "prompt_tokens_details": null } }Parsing helper & preview
Observed output
3️⃣ Hugging Face SparseEncoder Verification

Result
✅ The vLLM and Hugging Face results are numerically identical (within 1e-4 float tolerance) across all nonzero indices and values.
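The author verified against Hugging Face's SparseEncoder; an equivalent reference computation with plain transformers (a sketch assuming the standard BERT MLM head and an arbitrary example sentence, not the author's exact script) would be:

```python
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer

model_id = "naver/splade-v3"
tok = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForMaskedLM.from_pretrained(model_id).eval()

text = "The capital of France is Paris."  # example sentence (assumption)
inputs = tok(text, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits            # [1, L, vocab]

# SPLADE: log1p(relu(logits)), max-pooled over real (non-padding) tokens
mask = inputs["attention_mask"].unsqueeze(-1)  # [1, L, 1]
scores = torch.log1p(torch.relu(logits)) * mask
ref = scores.max(dim=1).values.squeeze(0)      # [vocab]

# Compare against the vector returned by vLLM, e.g.:
# assert torch.allclose(ref, torch.tensor(vllm_vector), atol=1e-4)
```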
4️⃣ TEI (Text Embeddings Inference) Consistency Test
Container launch
Test via curl
Response
✅ The TEI server’s output is functionally equivalent to the vLLM response, confirming correct sparse pooling and alignment of activation magnitudes.
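For reference, a TEI sparse-embedding request can be issued like this (a hedged sketch: the /embed_sparse route and payload shape are assumptions based on TEI's documented API, and host/port are placeholders rather than the author's container settings):

```python
import requests

resp = requests.post(
    "http://localhost:8080/embed_sparse",   # route name assumed from TEI docs
    json={"inputs": "The capital of France is Paris."},
    timeout=30,
)
resp.raise_for_status()
print(resp.json())  # per-input list of {index, value} activations
```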
Test Result Summary
All three implementations produce identical sparse activation patterns and values, demonstrating full correctness and interoperability.
Notes
- No regression for existing `BertEmbeddingModel` or dense embedding workflows.
- Sparse embedding fully integrated with `PoolingTask.embed`.
- Works with the FlashAttention backend and torch.compile graph caching.
- TEI consistency ensures vLLM can serve SPLADE models interchangeably in hybrid retrieval systems.
- Clearly described purpose (add SPLADE support for `naver/splade-v3`)
- Test plan included (vLLM, HF, TEI parity)
- Verified consistent outputs across frameworks
- Registry and pooling code updated
- No backward-compatibility issues introduced
- (Optional) Update `supported_models.md`
- (Optional) Add a release note entry
✅ Summary:
This PR adds end-to-end integration of the BERT-based `naver/splade-v3` sparse embedding model into vLLM.