
Commit 374b19f

noooop authored and 0xrushi committed
[Model][0/N] Improve all pooling task | clean up (vllm-project#25817)
Signed-off-by: wang.yuqi <noooop@126.com>
Signed-off-by: 0xrushi <6279035+0xrushi@users.noreply.github.com>
1 parent d456373 commit 374b19f

19 files changed (+197, -188 lines)

docs/models/supported_models.md

Lines changed: 1 addition & 1 deletion

@@ -581,7 +581,7 @@ These models primarily support the [`LLM.encode`](./pooling_models.md#llmencode)
 | `ModernBertForTokenClassification` | ModernBERT-based | `disham993/electrical-ner-ModernBERT-base` | | | ✅︎ |

 !!! note
-    Named Entity Recognition (NER) usage, please refer to <gh-file:examples/offline_inference/pooling/ner.py>, <gh-file:examples/online_serving/pooling/ner.py>.
+    Named Entity Recognition (NER) usage, please refer to <gh-file:examples/offline_inference/pooling/ner.py>, <gh-file:examples/online_serving/pooling/ner_client.py>.

 [](){ #supported-mm-models }

examples/online_serving/pooling/README.md

Lines changed: 1 addition & 1 deletion

@@ -15,7 +15,7 @@ python examples/online_serving/pooling/jinaai_rerank_client.py
 ## Named Entity Recognition (NER) usage

 ```bash
-python examples/online_serving/pooling/ner.py
+python examples/online_serving/pooling/ner_client.py
 ```

 ## Openai chat embedding for multimodal usage

tests/ci_envs.py

Lines changed: 6 additions & 0 deletions

@@ -8,6 +8,8 @@
 from collections.abc import Callable
 from typing import TYPE_CHECKING, Any

+from vllm.envs import maybe_convert_bool
+
 if TYPE_CHECKING:
     VLLM_CI_NO_SKIP: bool = False
     VLLM_CI_DTYPE: str | None = None
@@ -25,6 +27,10 @@
     "VLLM_CI_HEAD_DTYPE": lambda: os.getenv("VLLM_CI_HEAD_DTYPE", None),
     # Allow changing the head dtype used by transformers in tests
     "VLLM_CI_HF_DTYPE": lambda: os.getenv("VLLM_CI_HF_DTYPE", None),
+    # Allow control over whether tests use enforce_eager
+    "VLLM_CI_ENFORCE_EAGER": lambda: maybe_convert_bool(
+        os.getenv("VLLM_CI_ENFORCE_EAGER", None)
+    ),
 }
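As a usage sketch (not part of this diff): a test that wants to honor the new switch could read it through `ci_envs` and fall back to its own default when the variable is unset. The `resolve_enforce_eager` helper below is hypothetical.

```python
import tests.ci_envs as ci_envs


def resolve_enforce_eager(default: bool = False) -> bool:
    # Hypothetical helper: VLLM_CI_ENFORCE_EAGER is parsed by
    # maybe_convert_bool, so it is None when unset and a bool otherwise.
    if ci_envs.VLLM_CI_ENFORCE_EAGER is not None:
        return ci_envs.VLLM_CI_ENFORCE_EAGER
    return default


# Example: vllm_extra_kwargs["enforce_eager"] = resolve_enforce_eager(False)
```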

tests/entrypoints/pooling/llm/test_classify.py

Lines changed: 2 additions & 0 deletions

@@ -58,7 +58,9 @@ def get_outputs(activation):
     )


+@pytest.mark.skip_global_cleanup
 def test_encode_api(llm: LLM):
+    # chunked prefill does not support all pooling
     err_msg = "pooling_task must be one of.+"
     with pytest.raises(ValueError, match=err_msg):
         llm.encode(prompts, use_tqdm=False)
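For reference, a minimal sketch of what the test above is checking, assuming `LLM.encode` accepts an explicit `pooling_task` keyword as the error pattern suggests; the value "embed" is an assumed example, not taken from this diff.

```python
# Sketch only: without a pooling task the call is rejected
# (matched by err_msg = "pooling_task must be one of.+").
# llm.encode(prompts, use_tqdm=False)

# Passing a task explicitly is the expected usage; "embed" is an assumed value.
outputs = llm.encode(prompts, pooling_task="embed", use_tqdm=False)
```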

tests/entrypoints/pooling/llm/test_embedding.py

Lines changed: 0 additions & 1 deletion

@@ -35,7 +35,6 @@ def llm():
     cleanup_dist_env_and_memory()


-@pytest.mark.skip_global_cleanup
 def test_pooling_params(llm: LLM):
     def get_outputs(normalize):
         outputs = llm.embed(

tests/entrypoints/pooling/llm/test_encode.py

Lines changed: 0 additions & 1 deletion

@@ -74,7 +74,6 @@ def test_multiple_pooling_params(llm: LLM):
     assert len(PROMPTS) == len(outputs)


-@pytest.mark.skip_global_cleanup
 def test_right_side_truncation(llm: LLM):
     # Embeddings models should truncate the end of the prompt
     tokenizer = llm.get_tokenizer()

tests/entrypoints/pooling/llm/test_score.py

Lines changed: 0 additions & 1 deletion

@@ -33,7 +33,6 @@ def llm():
     cleanup_dist_env_and_memory()


-@pytest.mark.skip_global_cleanup
 def test_pooling_params(llm: LLM):
     def get_outputs(activation):
         text_1 = "What is the capital of France?"

tests/models/language/generation_ppl_test/ppl_utils.py

Lines changed: 6 additions & 20 deletions

@@ -3,12 +3,15 @@
 # Adapted from https://huggingface.co/docs/transformers/perplexity
 from typing import cast

-import pytest
 import torch
 from datasets import load_dataset

 import tests.ci_envs as ci_envs
-from tests.models.utils import GenerateModelInfo, TokensTextLogprobsPromptLogprobs
+from tests.models.utils import (
+    GenerateModelInfo,
+    TokensTextLogprobsPromptLogprobs,
+    get_vllm_extra_kwargs,
+)
 from vllm.logprobs import Logprob

 # See #24485
@@ -25,27 +28,10 @@ def wikitext_ppl_test(
     vllm_extra_kwargs=None,
     atol=PPL_TOL,
 ):
-    # A model family has many models with the same architecture,
-    # and we don't need to test each one.
-    if not ci_envs.VLLM_CI_NO_SKIP and not model_info.enable_test:
-        pytest.skip("Skipping test.")
+    vllm_extra_kwargs = get_vllm_extra_kwargs(model_info, vllm_extra_kwargs)

     dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")

-    # Allow vllm to test using the given dtype, such as float32
-    vllm_extra_kwargs = vllm_extra_kwargs or {}
-    vllm_extra_kwargs["dtype"] = ci_envs.VLLM_CI_DTYPE or model_info.dtype
-
-    # Allow vllm to test using hf_overrides
-    if model_info.hf_overrides is not None:
-        vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides
-
-    # Allow changing the head dtype used by vllm in tests
-    if ci_envs.VLLM_CI_HEAD_DTYPE is not None:
-        if "hf_overrides" not in vllm_extra_kwargs:
-            vllm_extra_kwargs["hf_overrides"] = {}
-        vllm_extra_kwargs["hf_overrides"]["head_dtype"] = ci_envs.VLLM_CI_HEAD_DTYPE
-
     with vllm_runner(
         model_info.name,
         gpu_memory_utilization=0.7,
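The deleted block above is what `get_vllm_extra_kwargs` now centralizes. A minimal sketch of that helper, reconstructed from the removed lines and assuming the real version lives in `tests/models/utils.py` (it may also handle other CI switches, such as enforce_eager):

```python
import pytest

import tests.ci_envs as ci_envs


def get_vllm_extra_kwargs(model_info, vllm_extra_kwargs):
    # A model family has many models with the same architecture,
    # and we don't need to test each one.
    if not ci_envs.VLLM_CI_NO_SKIP and not model_info.enable_test:
        pytest.skip("Skipping test.")

    # Allow vllm to test using the given dtype, such as float32.
    vllm_extra_kwargs = vllm_extra_kwargs or {}
    vllm_extra_kwargs["dtype"] = ci_envs.VLLM_CI_DTYPE or model_info.dtype

    # Allow vllm to test using hf_overrides.
    if model_info.hf_overrides is not None:
        vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides

    # Allow changing the head dtype used by vllm in tests.
    if ci_envs.VLLM_CI_HEAD_DTYPE is not None:
        vllm_extra_kwargs.setdefault("hf_overrides", {})
        vllm_extra_kwargs["hf_overrides"]["head_dtype"] = ci_envs.VLLM_CI_HEAD_DTYPE

    return vllm_extra_kwargs
```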
Lines changed: 47 additions & 0 deletions

@@ -0,0 +1,47 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import pytest
+import torch
+from transformers import AutoModelForSequenceClassification
+
+
+@pytest.mark.parametrize(
+    "model",
+    ["nie3e/sentiment-polish-gpt2-small"],
+)
+@pytest.mark.parametrize("dtype", ["half"])
+def test_classify_models(
+    hf_runner,
+    vllm_runner,
+    example_prompts,
+    model: str,
+    dtype: str,
+) -> None:
+    with hf_runner(
+        model, dtype=dtype, auto_cls=AutoModelForSequenceClassification
+    ) as hf_model:
+        hf_outputs = hf_model.classify(example_prompts)
+
+    for head_dtype_str in ["float32", "model"]:
+        with vllm_runner(
+            model,
+            max_model_len=512,
+            dtype=dtype,
+            hf_overrides={"head_dtype": head_dtype_str},
+        ) as vllm_model:
+            model_config = vllm_model.llm.llm_engine.model_config
+            model_dtype = model_config.dtype
+            head_dtype = model_config.head_dtype
+
+            if head_dtype_str == "float32":
+                assert head_dtype == torch.float32
+            elif head_dtype_str == "model":
+                assert head_dtype == model_dtype
+
+            vllm_outputs = vllm_model.classify(example_prompts)
+
+            for hf_output, vllm_output in zip(hf_outputs, vllm_outputs):
+                hf_output = torch.tensor(hf_output).float()
+                vllm_output = torch.tensor(vllm_output).float()
+
+                assert torch.allclose(hf_output, vllm_output, atol=1e-2)
