# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import types

import numpy as np
import pytest
import torch
import torch.nn as nn

from vllm.model_executor.models.bert import (
    BertMLMHead,
    SPLADESparsePooler,
)

# ---------------------------------------------------------------------
# 1) Functional test: SPLADE formula correctness (no HF download needed)
# ---------------------------------------------------------------------


@pytest.mark.parametrize("B,T,H,V", [(2, 3, 5, 7)])
def test_splade_pooler_matches_reference_formula(B, T, H, V):
    """Ensure SPLADESparsePooler.forward() matches the mathematical formula:
    log1p(relu(logits)) -> max over sequence length (after masking)."""
    torch.manual_seed(0)

    # Prepare B sequences of shape [T, H]
    hs_list = [torch.randn(T, H) for _ in range(B)]

    # Simulate PoolingMetadata (only the required fields)
    prompt_lens = [T, T - 1]
    token_ids = torch.tensor(
        [
            [101, 5, 102],  # Batch 0: [CLS], token, [SEP]
            [101, 6, 6],  # Batch 1: [CLS], token, token (last token ignored)
        ],
        dtype=torch.long,
    )
    meta = types.SimpleNamespace(prompt_lens=prompt_lens, prompt_token_ids=token_ids)

    # MLM head (prefer BertMLMHead; fall back to a plain Linear if unavailable)
    try:
        mlm_head = BertMLMHead(hidden_size=H, vocab_size=V, layer_norm_eps=1e-12)
    except Exception:
        mlm_head = nn.Linear(H, V, bias=True)

    # Forward pass through the SPLADE pooler
    pooler = SPLADESparsePooler(mlm_head=mlm_head, pooling="max", remove_cls_sep=True)
    pooled = pooler(hidden_states=hs_list, pooling_metadata=meta)  # list of [V]

    # Basic output checks
    assert isinstance(pooled, list) and len(pooled) == B
    for vec in pooled:
        assert vec.shape == (V,)
        assert torch.isfinite(vec).all()
        assert (vec >= 0).all(), "SPLADE outputs must be non-negative."

    # Reference implementation for comparison
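    # For each vocab dimension v, the expected score is
    #   score[v] = max over kept tokens t of log1p(relu(mlm_head(h_t)[v])),
    # with the [CLS]/[SEP] positions masked out before pooling.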
    def ref_one(hs: torch.Tensor, L: int, tid_row: torch.Tensor) -> torch.Tensor:
        keep = torch.ones(L, dtype=torch.bool)
        if L > 0 and tid_row[0].item() == 101:  # remove CLS
            keep[0] = False
        if L > 0 and tid_row[L - 1].item() == 102:  # remove SEP
            keep[L - 1] = False

        valid = hs[:L][keep[:L]]
        if valid.numel() == 0:
            return torch.zeros(V, dtype=torch.float32)

        logits = mlm_head(valid)  # [L', V]
        scores = torch.log1p(torch.relu(logits))  # [L', V]
        return scores.max(dim=0).values.to(torch.float32)

    torch.testing.assert_close(
        pooled[0],
        ref_one(hs_list[0], prompt_lens[0], token_ids[0]),
        rtol=1e-4,
        atol=1e-4,
    )
    torch.testing.assert_close(
        pooled[1],
        ref_one(hs_list[1], prompt_lens[1], token_ids[1]),
        rtol=1e-4,
        atol=1e-4,
    )


# ---------------------------------------------------------------------
# 2) Integration smoke test: end-to-end embedding path wiring
# ---------------------------------------------------------------------


@pytest.mark.cpu_model
def test_bert_splade_sparse_embed_smoke(vllm_runner, monkeypatch):
    """Ensure BertSpladeSparseEmbeddingModel loads and produces sparse embeddings."""
    from transformers import AutoTokenizer

    MODEL_ID = "hf-internal-testing/tiny-random-bert"
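    # Overriding `architectures` makes vLLM resolve this checkpoint to the
    # SPLADE embedding model class instead of the stock BERT architecture.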
    hf_overrides = {"architectures": ["BertSpladeSparseEmbeddingModel"]}

    # Enforce CPU-only execution (optional)
    monkeypatch.setenv("CUDA_VISIBLE_DEVICES", "")
    monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", "False")

    tok = AutoTokenizer.from_pretrained(MODEL_ID)
    vocab_size = tok.vocab_size

    # The embed path should route through SPLADESparsePooler
    with vllm_runner(
        MODEL_ID,
        runner="pooling",
        max_model_len=64,
        hf_overrides=hf_overrides,
    ) as vm:
        outs = vm.embed(["hello world", "splade sparse test"])

        # Basic sanity checks
        assert len(outs) == 2
        assert outs[0].shape[0] == vocab_size
        assert outs[1].shape[0] == vocab_size
        assert np.isfinite(outs[0]).all() and (outs[0] >= 0).all()
        assert np.isfinite(outs[1]).all() and (outs[1] >= 0).all()
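

# ---------------------------------------------------------------------
# Illustrative sketch (not a test): interpreting a SPLADE vector.
# Each pooled output is a [vocab_size] array of non-negative term weights,
# so reading off the largest entries gives the sparse bag-of-words expansion.
# The helper below is an illustrative assumption, not part of the vLLM API.
# ---------------------------------------------------------------------


def _splade_top_terms_sketch(vec: torch.Tensor, tokenizer, k: int = 5):
    """Map the k highest-weighted vocabulary ids back to readable tokens."""
    weights, ids = torch.topk(vec, k=k)
    return list(zip(tokenizer.convert_ids_to_tokens(ids.tolist()), weights.tolist()))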