Skip to content

Commit 843916c

Browse files
gjgjosDarkLight1337
authored andcommitted
[Feature] Add support for naver/splade-v3 (BERT-based sparse embedding model) (vllm-project#26339)
Signed-off-by: gjgjos <gjgjos@naver.com> Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com> Signed-off-by: xuebwang-amd <xuebwang@amd.com>
1 parent 6813726 commit 843916c

File tree

4 files changed

+340
-0
lines changed

4 files changed

+340
-0
lines changed
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
4+
import types
5+
6+
import numpy as np
7+
import pytest
8+
import torch
9+
import torch.nn as nn
10+
11+
from vllm.model_executor.models.bert import (
12+
BertMLMHead,
13+
SPLADESparsePooler,
14+
)
15+
16+
# ---------------------------------------------------------------------
17+
# 1) Functional test: SPLADE formula correctness (no HF download needed)
18+
# ---------------------------------------------------------------------
19+
20+
21+
@pytest.mark.parametrize("B,T,H,V", [(2, 3, 5, 7)])
22+
def test_splade_pooler_matches_reference_formula(B, T, H, V):
23+
"""Ensure SPLADESparsePooler forward() matches the mathematical formula:
24+
log1p(relu(logits)) -> max over sequence length (after masking)."""
25+
torch.manual_seed(0)
26+
27+
# Prepare [B] sequences of shape [T, H]
28+
hs_list = [torch.randn(T, H) for _ in range(B)]
29+
30+
# Simulate PoolingMetadata (only required fields)
31+
prompt_lens = [T, T - 1]
32+
token_ids = torch.tensor(
33+
[
34+
[101, 5, 102], # Batch 0: [CLS], token, [SEP]
35+
[101, 6, 6], # Batch 1: [CLS], token, token (last token ignored)
36+
],
37+
dtype=torch.long,
38+
)
39+
meta = types.SimpleNamespace(prompt_lens=prompt_lens, prompt_token_ids=token_ids)
40+
41+
# MLM head (prefer BertMLMHead, fallback to Linear if unavailable)
42+
try:
43+
mlm_head = BertMLMHead(hidden_size=H, vocab_size=V, layer_norm_eps=1e-12)
44+
except Exception:
45+
mlm_head = nn.Linear(H, V, bias=True)
46+
47+
# Forward pass through SPLADE pooler
48+
pooler = SPLADESparsePooler(mlm_head=mlm_head, pooling="max", remove_cls_sep=True)
49+
pooled = pooler(hidden_states=hs_list, pooling_metadata=meta) # list of [V]
50+
51+
# Basic output checks
52+
assert isinstance(pooled, list) and len(pooled) == B
53+
for vec in pooled:
54+
assert vec.shape == (V,)
55+
assert torch.isfinite(vec).all()
56+
assert (vec >= 0).all(), "SPLADE outputs must be non-negative."
57+
58+
# Reference implementation for comparison
59+
def ref_one(hs: torch.Tensor, L: int, tid_row: torch.Tensor) -> torch.Tensor:
60+
keep = torch.ones(L, dtype=torch.bool)
61+
if L > 0 and tid_row[0].item() == 101: # remove CLS
62+
keep[0] = False
63+
if L > 0 and tid_row[L - 1].item() == 102: # remove SEP
64+
keep[L - 1] = False
65+
66+
valid = hs[:L][keep[:L]]
67+
if valid.numel() == 0:
68+
return torch.zeros(V, dtype=torch.float32)
69+
70+
logits = mlm_head(valid) # [L', V]
71+
scores = torch.log1p(torch.relu(logits)) # [L', V]
72+
return scores.max(dim=0).values.to(torch.float32)
73+
74+
torch.testing.assert_close(
75+
pooled[0],
76+
ref_one(hs_list[0], prompt_lens[0], token_ids[0]),
77+
rtol=1e-4,
78+
atol=1e-4,
79+
)
80+
torch.testing.assert_close(
81+
pooled[1],
82+
ref_one(hs_list[1], prompt_lens[1], token_ids[1]),
83+
rtol=1e-4,
84+
atol=1e-4,
85+
)
86+
87+
88+
# ---------------------------------------------------------------------
89+
# 2) Integration smoke test: end-to-end embedding path wiring
90+
# ---------------------------------------------------------------------
91+
92+
93+
@pytest.mark.cpu_model
94+
def test_bert_splade_sparse_embed_smoke(vllm_runner, monkeypatch):
95+
"""Ensure BertSpladeSparseEmbeddingModel loads and produces sparse embeddings."""
96+
from transformers import AutoTokenizer
97+
98+
MODEL_ID = "hf-internal-testing/tiny-random-bert"
99+
hf_overrides = {"architectures": ["BertSpladeSparseEmbeddingModel"]}
100+
101+
# Enforce CPU-only execution (optional)
102+
monkeypatch.setenv("CUDA_VISIBLE_DEVICES", "")
103+
monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", "False")
104+
105+
tok = AutoTokenizer.from_pretrained(MODEL_ID)
106+
vocab_size = tok.vocab_size
107+
108+
# The embed path should route through SPLADESparsePooler
109+
with vllm_runner(
110+
MODEL_ID,
111+
runner="pooling",
112+
max_model_len=64,
113+
hf_overrides=hf_overrides,
114+
) as vm:
115+
outs = vm.embed(["hello world", "splade sparse test"])
116+
117+
# Basic sanity checks
118+
assert len(outs) == 2
119+
assert outs[0].shape[0] == vocab_size
120+
assert outs[1].shape[0] == vocab_size
121+
assert np.isfinite(outs[0]).all() and (outs[0] >= 0).all()
122+
assert np.isfinite(outs[1]).all() and (outs[1] >= 0).all()

tests/models/registry.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -486,6 +486,9 @@ def check_available_online(
486486
"RobertaModel": _HfExamplesInfo("sentence-transformers/stsb-roberta-base-v2"),
487487
"RobertaForMaskedLM": _HfExamplesInfo("sentence-transformers/all-roberta-large-v1"),
488488
"XLMRobertaModel": _HfExamplesInfo("intfloat/multilingual-e5-small"),
489+
"BertSpladeSparseEmbeddingModel": _HfExamplesInfo(
490+
"naver/splade-v3", is_available_online=False
491+
),
489492
# [Multimodal]
490493
"CLIPModel": _HfExamplesInfo("openai/clip-vit-base-patch32"),
491494
"LlavaNextForConditionalGeneration": _HfExamplesInfo("royokong/e5-v"),

vllm/model_executor/models/bert.py

Lines changed: 214 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -572,6 +572,220 @@ def _decode_token_type_ids(input_ids: torch.Tensor) -> torch.Tensor:
572572
return token_type_ids
573573

574574

575+
class BertMLMHead(nn.Module):
576+
def __init__(
577+
self, hidden_size: int, vocab_size: int, layer_norm_eps: float = 1e-12
578+
):
579+
super().__init__()
580+
self.dense = nn.Linear(hidden_size, hidden_size)
581+
self.activation = nn.GELU()
582+
self.layer_norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
583+
self.decoder = nn.Linear(hidden_size, vocab_size, bias=True)
584+
585+
def tie_weights_with_embeddings(self, embeddings_weight: torch.Tensor):
586+
self.decoder.weight = embeddings_weight
587+
588+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
589+
x = self.dense(hidden_states)
590+
x = self.activation(x)
591+
x = self.layer_norm(x)
592+
logits = self.decoder(x)
593+
return logits
594+
595+
596+
class SPLADESparsePooler(Pooler):
597+
"""
598+
SPLADE sparse pooling:
599+
logits = mlm_head(hidden_states)
600+
-> log1p(relu(logits))
601+
-> (max|sum over L)
602+
-> [V]
603+
604+
Padding is masked with an attention mask,
605+
[CLS]/[SEP] is removed (selected),
606+
and then pooled.
607+
"""
608+
609+
def __init__(
610+
self,
611+
mlm_head: nn.Module,
612+
cls_token_id: Optional[int] = 101,
613+
sep_token_id: Optional[int] = 102,
614+
pooling: str = "max",
615+
remove_cls_sep: bool = True,
616+
):
617+
super().__init__()
618+
assert pooling in ("max", "sum")
619+
self.mlm_head = mlm_head
620+
self.cls_token_id = cls_token_id
621+
self.sep_token_id = sep_token_id
622+
self.pooling = pooling
623+
self.remove_cls_sep = remove_cls_sep
624+
625+
def get_supported_tasks(self) -> Set[PoolingTask]:
626+
return {"embed"}
627+
628+
def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
629+
return PoolingParamsUpdate(requires_token_ids=True)
630+
631+
def forward(
632+
self,
633+
hidden_states: torch.Tensor,
634+
pooling_metadata: PoolingMetadata,
635+
) -> torch.Tensor:
636+
assert isinstance(hidden_states, torch.Tensor) and hidden_states.dim() == 2
637+
638+
lens_tensor: torch.Tensor = pooling_metadata.prompt_lens
639+
lens: list[int] = lens_tensor.tolist()
640+
B: int = len(lens)
641+
642+
token_ids = pooling_metadata.prompt_token_ids
643+
offset = 0
644+
pooled_list: list[torch.Tensor] = []
645+
646+
for i in range(B):
647+
L = int(lens[i])
648+
hs = hidden_states[offset : offset + L]
649+
650+
start_idx = 0
651+
end_idx = L
652+
if self.remove_cls_sep and token_ids is not None:
653+
if (
654+
self.cls_token_id is not None
655+
and token_ids[i, 0].item() == self.cls_token_id
656+
):
657+
start_idx = 1
658+
if (
659+
self.sep_token_id is not None
660+
and token_ids[i, L - 1].item() == self.sep_token_id
661+
):
662+
end_idx = max(start_idx, L - 1)
663+
664+
if end_idx <= start_idx:
665+
V = int(self.mlm_head.decoder.out_features)
666+
pooled_list.append(hs.new_zeros((V,)))
667+
offset += L
668+
continue
669+
670+
logits_i = self.mlm_head(hs[start_idx:end_idx])
671+
scores_i = torch.log1p(torch.relu(logits_i))
672+
673+
if self.pooling == "sum":
674+
pooled_i = scores_i.sum(dim=0)
675+
else: # "max"
676+
pooled_i = scores_i.max(dim=0).values
677+
678+
pooled_list.append(pooled_i.contiguous())
679+
offset += L
680+
681+
return torch.stack(pooled_list, dim=0).contiguous()
682+
683+
684+
@default_pooling_type("CLS")
685+
class BertSpladeSparseEmbeddingModel(BertEmbeddingModel):
686+
"""
687+
BertEmbeddingModel + SPLADE sparse embedding.
688+
- Make logits by self.mlm_head
689+
- pooler: SPLADESparsePooler(mlm_head...)
690+
"""
691+
692+
def __init__(
693+
self, *, vllm_config: VllmConfig, prefix: str = "", splade_pooling: str = "max"
694+
):
695+
super().__init__(vllm_config=vllm_config, prefix=prefix)
696+
cfg = vllm_config.model_config.hf_config
697+
698+
# MLM head
699+
self.mlm_head = BertMLMHead(
700+
hidden_size=cfg.hidden_size,
701+
vocab_size=cfg.vocab_size,
702+
layer_norm_eps=getattr(cfg, "layer_norm_eps", 1e-12),
703+
)
704+
705+
self._splade_pooling = splade_pooling
706+
pooler_config = vllm_config.model_config.pooler_config
707+
assert pooler_config is not None
708+
self.pooler = self._build_pooler(pooler_config)
709+
710+
def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
711+
cfg = self.model.config
712+
713+
if not hasattr(self, "mlm_head"):
714+
self.mlm_head = BertMLMHead(
715+
hidden_size=cfg.hidden_size,
716+
vocab_size=cfg.vocab_size,
717+
layer_norm_eps=getattr(cfg, "layer_norm_eps", 1e-12),
718+
)
719+
720+
pooling_mode = getattr(self, "_splade_pooling", "max")
721+
722+
cls_id = getattr(cfg, "cls_token_id", None)
723+
sep_id = getattr(cfg, "sep_token_id", None)
724+
725+
return DispatchPooler(
726+
{
727+
"encode": Pooler.for_encode(pooler_config),
728+
"embed": SPLADESparsePooler(
729+
mlm_head=self.mlm_head,
730+
cls_token_id=cls_id,
731+
sep_token_id=sep_id,
732+
pooling=pooling_mode, # "max" or "sum"
733+
remove_cls_sep=True,
734+
),
735+
}
736+
)
737+
738+
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
739+
if not hasattr(self, "mlm_head"):
740+
cfg = self.model.config
741+
self.mlm_head = BertMLMHead(
742+
hidden_size=cfg.hidden_size,
743+
vocab_size=cfg.vocab_size,
744+
layer_norm_eps=getattr(cfg, "layer_norm_eps", 1e-12),
745+
)
746+
747+
def _strip(name: str) -> str:
748+
for p in ("model.", "bert."):
749+
if name.startswith(p):
750+
name = name[len(p) :]
751+
return name
752+
753+
weights_list = list(weights)
754+
model_side: list[tuple[str, torch.Tensor]] = []
755+
mlm_side: list[tuple[str, torch.Tensor]] = []
756+
757+
for k, w in weights_list:
758+
name = _strip(k)
759+
if name.startswith("cls.predictions."):
760+
mlm_side.append((name, w))
761+
else:
762+
model_side.append((name, w))
763+
764+
loaded: set[str] = set()
765+
loaded_model = self.model.load_weights(model_side)
766+
loaded.update({"model." + n for n in loaded_model})
767+
768+
if mlm_side:
769+
name_map = {
770+
"cls.predictions.transform.dense.weight": "mlm_head.dense.weight",
771+
"cls.predictions.transform.dense.bias": "mlm_head.dense.bias",
772+
("cls.predictions.transform.LayerNorm.weight"): (
773+
"mlm_head.layer_norm.weight"
774+
),
775+
("cls.predictions.transform.LayerNorm.bias"): (
776+
"mlm_head.layer_norm.bias"
777+
),
778+
"cls.predictions.decoder.weight": "mlm_head.decoder.weight",
779+
"cls.predictions.decoder.bias": "mlm_head.decoder.bias",
780+
}
781+
remapped = [(name_map[n], w) for n, w in mlm_side if n in name_map]
782+
if remapped:
783+
loaded_mlm = AutoWeightsLoader(self).load_weights(remapped)
784+
loaded.update(loaded_mlm)
785+
786+
return loaded
787+
788+
575789
@default_pooling_type("CLS")
576790
class BertForSequenceClassification(nn.Module, SupportsCrossEncoding, SupportsQuant):
577791
"""A model that uses Bert to provide embedding functionalities.

vllm/model_executor/models/registry.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,7 @@
172172
_EMBEDDING_MODELS = {
173173
# [Text-only]
174174
"BertModel": ("bert", "BertEmbeddingModel"),
175+
"BertSpladeSparseEmbeddingModel": ("bert", "BertSpladeSparseEmbeddingModel"),
175176
"DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
176177
"Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
177178
"Gemma3TextModel": ("gemma3", "Gemma3Model"),

0 commit comments

Comments
 (0)