Skip to content

Commit 99955db

Browse files
noooop authored and amitm02 committed
[CI] improve embed testing (vllm-project#18747)
Signed-off-by: amit <amit.man@gmail.com>
1 parent 6721806 commit 99955db

File tree

13 files changed

+244
-174
lines changed

13 files changed

+244
-174
lines changed

tests/entrypoints/openai/correctness/test_mteb.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import pytest
55

66
from tests.models.language.pooling.mteb_utils import (MTEB_EMBED_TASKS,
7+
MTEB_EMBED_TOL,
78
OpenAIClientMtebEncoder,
89
run_mteb_embed_task,
910
run_mteb_embed_task_st)
@@ -38,4 +39,4 @@ def test_mteb(server):
3839
print("SentenceTransformer main score: ", st_main_score)
3940
print("Difference: ", st_main_score - vllm_main_score)
4041

41-
assert st_main_score == pytest.approx(vllm_main_score, rel=1e-4)
42+
assert st_main_score == pytest.approx(vllm_main_score, abs=MTEB_EMBED_TOL)

tests/entrypoints/openai/test_embedding.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@
1111
from vllm.entrypoints.openai.protocol import EmbeddingResponse
1212
from vllm.transformers_utils.tokenizer import get_tokenizer
1313

14-
from ...models.utils import run_embedding_correctness_test
14+
from ...models.language.pooling.embed_utils import (
15+
run_embedding_correctness_test)
1516
from ...utils import RemoteOpenAIServer
1617

1718
MODEL_NAME = "intfloat/multilingual-e5-small"

tests/entrypoints/openai/test_embedding_dimensions.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@
1111
from vllm.entrypoints.openai.protocol import EmbeddingResponse
1212

1313
from ...conftest import HfRunner
14-
from ...models.utils import EmbedModelInfo, run_embedding_correctness_test
14+
from ...models.language.pooling.embed_utils import (
15+
run_embedding_correctness_test)
16+
from ...models.utils import EmbedModelInfo
1517
from ...utils import RemoteOpenAIServer
1618

1719
MODELS = [
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
from collections.abc import Sequence
3+
from typing import Optional
4+
5+
import pytest
6+
7+
from tests.conftest import HfRunner
8+
from tests.models.utils import (EmbedModelInfo, check_embeddings_close,
9+
matryoshka_fy)
10+
11+
12+
def run_embedding_correctness_test(
13+
hf_model: "HfRunner",
14+
inputs: list[str],
15+
vllm_outputs: Sequence[list[float]],
16+
dimensions: Optional[int] = None,
17+
):
18+
hf_outputs = hf_model.encode(inputs)
19+
if dimensions:
20+
hf_outputs = matryoshka_fy(hf_outputs, dimensions)
21+
22+
check_embeddings_close(
23+
embeddings_0_lst=hf_outputs,
24+
embeddings_1_lst=vllm_outputs,
25+
name_0="hf",
26+
name_1="vllm",
27+
tol=1e-2,
28+
)
29+
30+
31+
def correctness_test_embed_models(hf_runner,
32+
vllm_runner,
33+
model_info: EmbedModelInfo,
34+
example_prompts,
35+
vllm_extra_kwargs=None,
36+
hf_model_callback=None):
37+
if not model_info.enable_test:
38+
# A model family has many models with the same architecture,
39+
# and we don't need to test each one.
40+
pytest.skip("Skipping test.")
41+
42+
# The example_prompts has ending "\n", for example:
43+
# "Write a short story about a robot that dreams for the first time.\n"
44+
# sentence_transformers will strip the input texts, see:
45+
# https://github.com/UKPLab/sentence-transformers/blob/v3.1.1/sentence_transformers/models/Transformer.py#L159
46+
# This makes the input_ids different between hf_model and vllm_model.
47+
# So we need to strip the input texts to avoid test failing.
48+
example_prompts = [str(s).strip() for s in example_prompts]
49+
50+
vllm_extra_kwargs = vllm_extra_kwargs or {}
51+
vllm_extra_kwargs["dtype"] = model_info.dtype
52+
53+
with vllm_runner(model_info.name,
54+
task="embed",
55+
max_model_len=None,
56+
**vllm_extra_kwargs) as vllm_model:
57+
vllm_outputs = vllm_model.encode(example_prompts)
58+
vllm_dtype = vllm_model.model.llm_engine.model_config.dtype
59+
model_dtype = getattr(
60+
vllm_model.model.llm_engine.model_config.hf_config, "torch_dtype",
61+
vllm_dtype)
62+
63+
with hf_runner(
64+
model_info.name,
65+
dtype=model_dtype,
66+
is_sentence_transformer=True,
67+
) as hf_model:
68+
69+
if hf_model_callback is not None:
70+
hf_model_callback(hf_model)
71+
72+
run_embedding_correctness_test(hf_model, example_prompts, vllm_outputs)

tests/models/language/pooling/mteb_utils.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,18 +80,19 @@ def run_mteb_embed_task_st(model_name, tasks):
8080
def mteb_test_embed_models(hf_runner,
8181
vllm_runner,
8282
model_info: EmbedModelInfo,
83-
vllm_extra_kwargs=None):
83+
vllm_extra_kwargs=None,
84+
hf_model_callback=None):
8485
if not model_info.enable_test:
8586
# A model family has many models with the same architecture,
8687
# and we don't need to test each one.
8788
pytest.skip("Skipping test.")
8889

8990
vllm_extra_kwargs = vllm_extra_kwargs or {}
91+
vllm_extra_kwargs["dtype"] = model_info.dtype
9092

9193
with vllm_runner(model_info.name,
9294
task="embed",
9395
max_model_len=None,
94-
dtype=model_info.dtype,
9596
**vllm_extra_kwargs) as vllm_model:
9697

9798
if model_info.architecture:
@@ -108,10 +109,14 @@ def mteb_test_embed_models(hf_runner,
108109
with set_default_torch_dtype(model_dtype) and hf_runner(
109110
model_info.name, is_sentence_transformer=True,
110111
dtype=model_dtype) as hf_model:
112+
113+
if hf_model_callback is not None:
114+
hf_model_callback(hf_model)
115+
111116
st_main_score = run_mteb_embed_task(hf_model, MTEB_EMBED_TASKS)
112117

113118
print("VLLM:", vllm_dtype, vllm_main_score)
114119
print("SentenceTransformer:", model_dtype, st_main_score)
115120
print("Difference:", st_main_score - vllm_main_score)
116121

117-
assert st_main_score == pytest.approx(vllm_main_score, rel=MTEB_EMBED_TOL)
122+
assert st_main_score == pytest.approx(vllm_main_score, abs=MTEB_EMBED_TOL)
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
import pytest
3+
4+
from .embed_utils import EmbedModelInfo, correctness_test_embed_models
5+
from .mteb_utils import mteb_test_embed_models
6+
7+
MODELS = [
8+
########## BertModel
9+
EmbedModelInfo("BAAI/bge-base-en",
10+
architecture="BertModel",
11+
enable_test=True),
12+
EmbedModelInfo("BAAI/bge-base-zh",
13+
architecture="BertModel",
14+
enable_test=False),
15+
EmbedModelInfo("BAAI/bge-small-en",
16+
architecture="BertModel",
17+
enable_test=False),
18+
EmbedModelInfo("BAAI/bge-small-zh",
19+
architecture="BertModel",
20+
enable_test=False),
21+
EmbedModelInfo("BAAI/bge-large-en",
22+
architecture="BertModel",
23+
enable_test=False),
24+
EmbedModelInfo("BAAI/bge-large-zh",
25+
architecture="BertModel",
26+
enable_test=False),
27+
EmbedModelInfo("BAAI/bge-large-zh-noinstruct",
28+
architecture="BertModel",
29+
enable_test=False),
30+
EmbedModelInfo("BAAI/bge-base-en-v1.5",
31+
architecture="BertModel",
32+
enable_test=False),
33+
EmbedModelInfo("BAAI/bge-base-zh-v1.5",
34+
architecture="BertModel",
35+
enable_test=False),
36+
EmbedModelInfo("BAAI/bge-small-en-v1.5",
37+
architecture="BertModel",
38+
enable_test=False),
39+
EmbedModelInfo("BAAI/bge-small-zh-v1.5",
40+
architecture="BertModel",
41+
enable_test=False),
42+
EmbedModelInfo("BAAI/bge-large-en-v1.5",
43+
architecture="BertModel",
44+
enable_test=False),
45+
EmbedModelInfo("BAAI/bge-large-zh-v1.5",
46+
architecture="BertModel",
47+
enable_test=False),
48+
########## XLMRobertaModel
49+
EmbedModelInfo("BAAI/bge-m3",
50+
architecture="XLMRobertaModel",
51+
enable_test=True),
52+
########## Qwen2Model
53+
EmbedModelInfo("BAAI/bge-code-v1",
54+
architecture="Qwen2Model",
55+
dtype="float32",
56+
enable_test=True),
57+
]
58+
59+
60+
@pytest.mark.parametrize("model_info", MODELS)
61+
def test_embed_models_mteb(hf_runner, vllm_runner,
62+
model_info: EmbedModelInfo) -> None:
63+
mteb_test_embed_models(hf_runner, vllm_runner, model_info)
64+
65+
66+
@pytest.mark.parametrize("model_info", MODELS)
67+
def test_embed_models_correctness(hf_runner, vllm_runner,
68+
model_info: EmbedModelInfo,
69+
example_prompts) -> None:
70+
correctness_test_embed_models(hf_runner, vllm_runner, model_info,
71+
example_prompts)

tests/models/language/pooling/test_gte.py

Lines changed: 9 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33

44
import pytest
55

6-
from ...utils import EmbedModelInfo, run_embedding_correctness_test
6+
from .embed_utils import EmbedModelInfo, correctness_test_embed_models
7+
from .mteb_utils import mteb_test_embed_models
78

89
MODELS = [
910
########## BertModel
@@ -53,9 +54,8 @@
5354

5455

5556
@pytest.mark.parametrize("model_info", MODELS)
56-
def test_models_mteb(hf_runner, vllm_runner,
57-
model_info: EmbedModelInfo) -> None:
58-
from .mteb_utils import mteb_test_embed_models
57+
def test_embed_models_mteb(hf_runner, vllm_runner,
58+
model_info: EmbedModelInfo) -> None:
5959

6060
vllm_extra_kwargs: dict[str, Any] = {}
6161
if model_info.architecture == "GteNewModel":
@@ -66,28 +66,13 @@ def test_models_mteb(hf_runner, vllm_runner,
6666

6767

6868
@pytest.mark.parametrize("model_info", MODELS)
69-
def test_models_correctness(hf_runner, vllm_runner, model_info: EmbedModelInfo,
70-
example_prompts) -> None:
71-
if not model_info.enable_test:
72-
pytest.skip("Skipping test.")
73-
74-
# ST will strip the input texts, see test_embedding.py
75-
example_prompts = [str(s).strip() for s in example_prompts]
69+
def test_embed_models_correctness(hf_runner, vllm_runner,
70+
model_info: EmbedModelInfo,
71+
example_prompts) -> None:
7672

7773
vllm_extra_kwargs: dict[str, Any] = {}
7874
if model_info.architecture == "GteNewModel":
7975
vllm_extra_kwargs["hf_overrides"] = {"architectures": ["GteNewModel"]}
8076

81-
with vllm_runner(model_info.name,
82-
task="embed",
83-
dtype=model_info.dtype,
84-
max_model_len=None,
85-
**vllm_extra_kwargs) as vllm_model:
86-
vllm_outputs = vllm_model.encode(example_prompts)
87-
88-
with hf_runner(
89-
model_info.name,
90-
dtype=model_info.dtype,
91-
is_sentence_transformer=True,
92-
) as hf_model:
93-
run_embedding_correctness_test(hf_model, example_prompts, vllm_outputs)
77+
correctness_test_embed_models(hf_runner, vllm_runner, model_info,
78+
example_prompts, vllm_extra_kwargs)

0 commit comments

Comments (0)