Skip to content

Commit 657860b

Browse files
committed
Add SPLADE sparse embedding model and tests. Removed unnecessary torch.no_grad() (handled by the vLLM framework); added a model-loading entry to tests/models/registry.py; added SPLADESparsePooler functional and smoke tests to ensure future stability.
Signed-off-by: gjgjos <gjgjos@naver.com>
1 parent 3827c27 commit 657860b

File tree

3 files changed

+127
-8
lines changed

3 files changed

+127
-8
lines changed
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
4+
import types
5+
6+
import numpy as np
7+
import pytest
8+
import torch
9+
import torch.nn as nn
10+
11+
from vllm.model_executor.models.bert import (
12+
BertMLMHead,
13+
SPLADESparsePooler,
14+
)
15+
16+
# ---------------------------------------------------------------------
17+
# 1) Functional test: SPLADE formula correctness (no HF download needed)
18+
# ---------------------------------------------------------------------
19+
20+
21+
@pytest.mark.parametrize("B,T,H,V", [(2, 3, 5, 7)])
def test_splade_pooler_matches_reference_formula(B, T, H, V):
    """Verify SPLADESparsePooler.forward() against the SPLADE definition:
    log1p(relu(logits)), followed by a max-reduction over the (masked)
    sequence dimension."""
    torch.manual_seed(0)

    # One hidden-state tensor of shape [T, H] per sequence in the batch.
    hidden_states = [torch.randn(T, H) for _ in range(B)]

    # Minimal stand-in for PoolingMetadata: only the fields the pooler reads.
    lens = [T, T - 1]
    ids = torch.tensor(
        [
            [101, 5, 102],  # seq 0: [CLS], token, [SEP]
            [101, 6, 6],    # seq 1: [CLS], token, token (last position unused)
        ],
        dtype=torch.long,
    )
    metadata = types.SimpleNamespace(prompt_lens=lens, prompt_token_ids=ids)

    # Build the MLM head; fall back to a plain Linear if BertMLMHead cannot
    # be constructed in this environment.
    try:
        head = BertMLMHead(hidden_size=H, vocab_size=V, layer_norm_eps=1e-12)
    except Exception:
        head = nn.Linear(H, V, bias=True)

    # Run the pooler under test; result is a list of [V] vectors.
    pooler = SPLADESparsePooler(mlm_head=head, pooling="max", remove_cls_sep=True)
    pooled = pooler(hidden_states=hidden_states, pooling_metadata=metadata)

    # Structural sanity: one finite, non-negative [V]-vector per sequence.
    assert isinstance(pooled, list)
    assert len(pooled) == B
    for vec in pooled:
        assert vec.shape == (V,)
        assert torch.isfinite(vec).all()
        assert (vec >= 0).all(), "SPLADE outputs must be non-negative."

    # Independent reference implementation of the SPLADE formula.
    def reference(hs: torch.Tensor, length: int, row: torch.Tensor) -> torch.Tensor:
        mask = torch.ones(length, dtype=torch.bool)
        if length > 0 and row[0].item() == 101:  # drop leading [CLS]
            mask[0] = False
        if length > 0 and row[length - 1].item() == 102:  # drop trailing [SEP]
            mask[length - 1] = False

        kept = hs[:length][mask[:length]]
        if kept.numel() == 0:
            return torch.zeros(V, dtype=torch.float32)

        logits = head(kept)                           # [L', V]
        activated = torch.log1p(torch.relu(logits))   # [L', V]
        return activated.max(dim=0).values.to(torch.float32)

    # The pooler must agree with the reference for every sequence.
    for i in range(B):
        torch.testing.assert_close(
            pooled[i],
            reference(hidden_states[i], lens[i], ids[i]),
            rtol=1e-4,
            atol=1e-4,
        )
86+
87+
88+
# ---------------------------------------------------------------------
89+
# 2) Integration smoke test: end-to-end embedding path wiring
90+
# ---------------------------------------------------------------------
91+
92+
93+
@pytest.mark.cpu_model
def test_bert_splade_sparse_embed_smoke(vllm_runner, monkeypatch):
    """Smoke test: the SPLADE embed path loads end-to-end and emits
    vocabulary-sized, finite, non-negative sparse embeddings."""
    from transformers import AutoTokenizer

    model_id = "hf-internal-testing/tiny-random-bert"
    overrides = {"architectures": ["BertSpladeSparseEmbeddingModel"]}

    # Force CPU-only execution (optional safeguard).
    monkeypatch.setenv("CUDA_VISIBLE_DEVICES", "")
    monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", "False")

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    expected_dim = tokenizer.vocab_size

    # The embed path should route through SPLADESparsePooler.
    with vllm_runner(
        model_id,
        runner="pooling",
        max_model_len=64,
        hf_overrides=overrides,
    ) as runner:
        embeddings = runner.embed(["hello world", "splade sparse test"])

    # Basic sanity checks: two vocab-sized, finite, non-negative vectors.
    assert len(embeddings) == 2
    for emb in embeddings:
        assert emb.shape[0] == expected_dim
        assert np.isfinite(emb).all()
        assert (emb >= 0).all()

tests/models/registry.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -483,6 +483,9 @@ def check_available_online(
483483
"RobertaModel": _HfExamplesInfo("sentence-transformers/stsb-roberta-base-v2"),
484484
"RobertaForMaskedLM": _HfExamplesInfo("sentence-transformers/all-roberta-large-v1"),
485485
"XLMRobertaModel": _HfExamplesInfo("intfloat/multilingual-e5-small"),
486+
"BertSpladeSparseEmbeddingModel": _HfExamplesInfo(
487+
"naver/splade-v3", is_available_online=False
488+
),
486489
# [Multimodal]
487490
"CLIPModel": _HfExamplesInfo("openai/clip-vit-base-patch32"),
488491
"LlavaNextForConditionalGeneration": _HfExamplesInfo("royokong/e5-v"),

vllm/model_executor/models/bert.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -629,7 +629,6 @@ def get_supported_tasks(self) -> Set[PoolingTask]:
629629
def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
630630
return PoolingParamsUpdate(requires_token_ids=True)
631631

632-
@torch.no_grad()
633632
def forward(
634633
self,
635634
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
@@ -649,7 +648,7 @@ def forward(
649648
B = len(hs_list)
650649
H = hs_list[0].size(-1)
651650

652-
raw_lens = getattr(pooling_metadata, "prompt_lens", None)
651+
raw_lens = pooling_metadata.prompt_lens
653652

654653
def _fallback_lens_from_hs():
655654
return [int(h.size(0)) for h in hs_list]
@@ -724,12 +723,7 @@ def _fallback_lens_from_hs():
724723
torch.isneginf(pooled), torch.zeros_like(pooled), pooled
725724
)
726725

727-
outs: list[torch.Tensor] = []
728-
for i in range(B):
729-
vec = pooled[i].to(torch.float32).contiguous().view(-1) # [V]
730-
outs.append(vec)
731-
732-
return outs
726+
return pooled.contiguous()
733727

734728

735729
@default_pooling_type("CLS")

0 commit comments

Comments
 (0)