Skip to content

Commit 897654d

Browse files
committed
support hybrid dtype
1 parent b9f61e1 commit 897654d

File tree

29 files changed

+302
-84
lines changed

29 files changed

+302
-84
lines changed

.buildkite/test-pipeline.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -475,6 +475,7 @@ steps:
475475
- pytest -v -s models/test_utils.py
476476
- pytest -v -s models/test_vision.py
477477
- pytest -v -s models/test_initialization.py
478+
- pytest -v -s models/test_hybrid_dtype.py
478479

479480
- label: Language Models Test (Standard)
480481
mirror_hardwares: [amdexperimental]

tests/models/language/pooling/embed_utils.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,12 @@ def correctness_test_embed_models(hf_runner,
4848
example_prompts = [str(s).strip() for s in example_prompts]
4949

5050
vllm_extra_kwargs = vllm_extra_kwargs or {}
51-
vllm_extra_kwargs["dtype"] = model_info.dtype
51+
52+
if isinstance(model_info.dtype, str):
53+
vllm_extra_kwargs["dtype"] = model_info.dtype
54+
else:
55+
vllm_extra_kwargs["dtype"] = model_info.dtype.dtype
56+
vllm_extra_kwargs["attn_dtype"] = model_info.dtype.attn_dtype
5257

5358
with vllm_runner(model_info.name,
5459
task="embed",

tests/models/language/pooling/mteb_utils.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# SPDX-License-Identifier: Apache-2.0
2+
# ruff: noqa: SIM117
23
from collections.abc import Sequence
34

45
import mteb
@@ -88,7 +89,13 @@ def mteb_test_embed_models(hf_runner,
8889
pytest.skip("Skipping test.")
8990

9091
vllm_extra_kwargs = vllm_extra_kwargs or {}
91-
vllm_extra_kwargs["dtype"] = model_info.dtype
92+
93+
if isinstance(model_info.dtype, str):
94+
vllm_extra_kwargs["dtype"] = model_info.dtype
95+
96+
else:
97+
vllm_extra_kwargs["dtype"] = model_info.dtype.dtype
98+
vllm_extra_kwargs["attn_dtype"] = model_info.dtype.attn_dtype
9299

93100
with vllm_runner(model_info.name,
94101
task="embed",
@@ -102,18 +109,23 @@ def mteb_test_embed_models(hf_runner,
102109
vllm_main_score = run_mteb_embed_task(VllmMtebEncoder(vllm_model),
103110
MTEB_EMBED_TASKS)
104111
vllm_dtype = vllm_model.model.llm_engine.model_config.dtype
112+
model_dtype = getattr(
113+
vllm_model.model.llm_engine.model_config.hf_config, "torch_dtype",
114+
vllm_dtype)
105115

106-
with set_default_torch_dtype(vllm_dtype) and hf_runner(
107-
model_info.name, is_sentence_transformer=True,
108-
dtype=vllm_dtype) as hf_model:
116+
with set_default_torch_dtype(model_dtype):
117+
with hf_runner(model_info.name,
118+
is_sentence_transformer=True,
119+
dtype=model_dtype) as hf_model:
109120

110-
if hf_model_callback is not None:
111-
hf_model_callback(hf_model)
121+
if hf_model_callback is not None:
122+
hf_model_callback(hf_model)
112123

113-
st_main_score = run_mteb_embed_task(hf_model, MTEB_EMBED_TASKS)
124+
st_main_score = run_mteb_embed_task(hf_model, MTEB_EMBED_TASKS)
125+
st_dtype = next(hf_model.model.parameters()).dtype
114126

115-
print("VLLM:", vllm_main_score)
116-
print("SentenceTransformers:", st_main_score)
127+
print("VLLM:", vllm_dtype, vllm_main_score)
128+
print("SentenceTransformers:", model_dtype, st_dtype, st_main_score)
117129
print("Difference:", st_main_score - vllm_main_score)
118130

119131
assert st_main_score == pytest.approx(vllm_main_score, abs=MTEB_EMBED_TOL)

tests/models/language/pooling/test_gte.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,27 +10,24 @@
1010
########## BertModel
1111
EmbedModelInfo("thenlper/gte-large",
1212
architecture="BertModel",
13-
dtype="float32",
13+
dtype="hybrid",
1414
enable_test=True),
1515
EmbedModelInfo("thenlper/gte-base",
1616
architecture="BertModel",
17-
dtype="float32",
17+
dtype="hybrid",
1818
enable_test=False),
1919
EmbedModelInfo("thenlper/gte-small",
2020
architecture="BertModel",
21-
dtype="float32",
21+
dtype="hybrid",
2222
enable_test=False),
2323
EmbedModelInfo("thenlper/gte-large-zh",
2424
architecture="BertModel",
25-
dtype="float32",
2625
enable_test=False),
2726
EmbedModelInfo("thenlper/gte-base-zh",
2827
architecture="BertModel",
29-
dtype="float32",
3028
enable_test=False),
3129
EmbedModelInfo("thenlper/gte-small-zh",
3230
architecture="BertModel",
33-
dtype="float32",
3431
enable_test=False),
3532
########### NewModel
3633
EmbedModelInfo("Alibaba-NLP/gte-multilingual-base",
@@ -45,7 +42,6 @@
4542
########### Qwen2ForCausalLM
4643
EmbedModelInfo("Alibaba-NLP/gte-Qwen2-1.5B-instruct",
4744
architecture="Qwen2ForCausalLM",
48-
dtype="float32",
4945
enable_test=True),
5046
########## ModernBertModel
5147
EmbedModelInfo("Alibaba-NLP/gte-modernbert-base",
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
import pytest
3+
4+
from .embed_utils import EmbedModelInfo, correctness_test_embed_models
5+
from .mteb_utils import mteb_test_embed_models
6+
7+
# E5 checkpoints fed to the parametrized tests in this module. Entries that do
# not pass ``dtype`` fall back to the ``EmbedModelInfo`` default.
MODELS = [
    ########## BertModel
    EmbedModelInfo(
        "intfloat/e5-small",
        architecture="BertModel",
        enable_test=True,
    ),
    EmbedModelInfo(
        "intfloat/e5-base",
        architecture="BertModel",
        enable_test=False,
    ),
    EmbedModelInfo(
        "intfloat/e5-large",
        architecture="BertModel",
        enable_test=False,
    ),
    EmbedModelInfo(
        "intfloat/multilingual-e5-small",
        architecture="BertModel",
        enable_test=False,
    ),
    ########## XLMRobertaModel
    EmbedModelInfo(
        "intfloat/multilingual-e5-base",
        architecture="XLMRobertaModel",
        enable_test=True,
    ),
    EmbedModelInfo(
        "intfloat/multilingual-e5-large",
        architecture="XLMRobertaModel",
        enable_test=False,
    ),
    EmbedModelInfo(
        "intfloat/multilingual-e5-large-instruct",
        architecture="XLMRobertaModel",
        dtype="hybrid",
        enable_test=False,
    ),
]
33+
34+
35+
@pytest.mark.parametrize("model_info", MODELS)
def test_embed_models_mteb(
    hf_runner,
    vllm_runner,
    model_info: EmbedModelInfo,
) -> None:
    """Run the shared MTEB embedding check for one entry of ``MODELS``."""
    mteb_test_embed_models(hf_runner, vllm_runner, model_info)
39+
40+
41+
@pytest.mark.parametrize("model_info", MODELS)
def test_embed_models_correctness(
    hf_runner,
    vllm_runner,
    model_info: EmbedModelInfo,
    example_prompts,
) -> None:
    """Run the shared embedding correctness check for one ``MODELS`` entry.

    The ``example_prompts`` fixture is forwarded unchanged to the helper.
    """
    correctness_test_embed_models(
        hf_runner,
        vllm_runner,
        model_info,
        example_prompts,
    )

tests/models/language/pooling/test_jina.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,8 @@
3131
EMBEDDING_MODELS = [
3232
EmbedModelInfo("jinaai/jina-embeddings-v3",
3333
architecture="XLMRobertaModel",
34-
is_matryoshka=True,
35-
dtype="float32")
34+
dtype="hybrid",
35+
is_matryoshka=True)
3636
]
3737

3838

tests/models/language/pooling/test_nomic.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,18 +8,15 @@
88
MODELS = [
99
EmbedModelInfo("nomic-ai/nomic-embed-text-v1",
1010
architecture="NomicBertModel",
11-
dtype="float32",
1211
enable_test=True),
1312
EmbedModelInfo("nomic-ai/nomic-embed-text-v1.5",
1413
architecture="NomicBertModel",
15-
dtype="float32",
1614
enable_test=False),
1715
EmbedModelInfo("nomic-ai/CodeRankEmbed",
1816
architecture="NomicBertModel",
1917
enable_test=False),
2018
EmbedModelInfo("nomic-ai/nomic-embed-text-v2-moe",
2119
architecture="NomicBertModel",
22-
dtype="float32",
2320
enable_test=True)
2421
]
2522

tests/models/test_hybrid_dtype.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# ruff: noqa: SIM117
3+
import pytest
4+
import torch
5+
6+
from tests.models.language.pooling.mteb_utils import mteb_test_embed_models
7+
from tests.models.utils import Dtype, EmbedModelInfo
8+
from vllm.config import _STR_DTYPE_TO_TORCH_DTYPE
9+
10+
# Checkpoints loaded by the tests in this module.
embed_model = "intfloat/e5-small"
generate_model = "EleutherAI/pythia-70m"

# Dtype settings for which ``test_embed_models_mteb`` expects the MTEB
# comparison to pass.
high_precision_data_types = [
    Dtype(dtype="auto"),  # hybrid
    Dtype(dtype="float32"),
    Dtype(dtype="hybrid"),
    Dtype(dtype="float32", attn_dtype="float16"),
    Dtype(dtype="float32", attn_dtype="bfloat16"),
]
# Dtype settings for which ``test_embed_models_mteb`` expects the MTEB
# comparison to raise ``AssertionError``.
low_precision_data_types = [
    Dtype(dtype="float16"),
    Dtype(dtype="bfloat16"),
]
# Every setting, high-precision entries first.
data_types = [*high_precision_data_types, *low_precision_data_types]
21+
22+
23+
@pytest.mark.parametrize("dtype", data_types)
def test_dtype(vllm_runner, dtype: Dtype):
    """Check the dtypes stored on the model config after loading.

    ``embed_model`` is loaded with the requested settings, then one of
    three expectations is asserted, tried in this order:

    * ``dtype`` is ``"hybrid"`` or ``"auto"``: the config reports
      ``torch.float32`` as model dtype and ``torch.float16`` as attention
      dtype.
    * ``attn_dtype`` is ``"auto"``: the attention dtype equals the model
      dtype.
    * otherwise: both values equal the requested strings mapped through
      ``_STR_DTYPE_TO_TORCH_DTYPE``.
    """
    runner = vllm_runner(embed_model,
                         dtype=dtype.dtype,
                         max_model_len=None,
                         attn_dtype=dtype.attn_dtype)
    with runner as vllm_model:
        cfg = vllm_model.model.llm_engine.model_config

        if dtype.dtype in ("hybrid", "auto"):
            assert cfg.dtype == torch.float32
            assert cfg.attn_dtype == torch.float16
            return

        if dtype.attn_dtype == "auto":
            assert cfg.dtype == cfg.attn_dtype
            return

        # Both dtypes were requested explicitly.
        assert cfg.dtype == _STR_DTYPE_TO_TORCH_DTYPE[dtype.dtype]
        expected_attn = _STR_DTYPE_TO_TORCH_DTYPE[dtype.attn_dtype]
        assert cfg.attn_dtype == expected_attn
39+
40+
41+
@pytest.mark.parametrize("dtype", data_types)
def test_embed_models_mteb(hf_runner, vllm_runner, dtype: Dtype):
    """Run the MTEB comparison on ``embed_model`` for one dtype setting.

    Settings listed in ``high_precision_data_types`` must pass the
    comparison; for every other setting the comparison is expected to
    raise ``AssertionError``.
    """
    info = EmbedModelInfo(embed_model, architecture="BertModel", dtype=dtype)

    if info.dtype not in high_precision_data_types:
        # The helper's own assertion is expected to fail for these settings.
        with pytest.raises(AssertionError):
            mteb_test_embed_models(hf_runner, vllm_runner, info)
        return

    mteb_test_embed_models(hf_runner, vllm_runner, info)
52+
53+
54+
@pytest.mark.parametrize("model", [generate_model])
@pytest.mark.parametrize("dtype", data_types)
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [4])
def test_generate_models(hf_runner, vllm_runner, example_prompts, model: str,
                         dtype: Dtype, max_tokens: int,
                         num_logprobs: int) -> None:
    """Check which dtype settings ``model`` can be loaded with.

    Loading must succeed when ``attn_dtype`` is ``"auto"`` and ``dtype`` is
    not ``"hybrid"``. For every other setting, constructing or entering the
    runner is expected to raise ``ValueError``.

    ``hf_runner``, ``example_prompts``, ``max_tokens`` and ``num_logprobs``
    are accepted but not used by the body.
    """
    expect_rejection = dtype.attn_dtype != "auto" or dtype.dtype == "hybrid"

    if expect_rejection:
        # A single ``with`` with two managers behaves like the nested form:
        # an error raised while building the runner is seen by
        # ``pytest.raises``.
        with pytest.raises(ValueError), vllm_runner(
                model, dtype=dtype.dtype, attn_dtype=dtype.attn_dtype):
            pass
        return

    with vllm_runner(model, dtype=dtype.dtype, attn_dtype=dtype.attn_dtype):
        pass

tests/models/utils.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import torch
88
import torch.nn.functional as F
99

10-
from vllm.config import ModelConfig, TaskOption
10+
from vllm.config import AttnDType, ModelConfig, ModelDType, TaskOption
1111
from vllm.inputs import InputContext
1212
from vllm.sequence import Logprob, PromptLogprobs, SampleLogprobs
1313

@@ -328,10 +328,15 @@ def matryoshka_fy(tensor: torch.Tensor, dimensions: int):
328328
return tensor
329329

330330

331+
class Dtype(NamedTuple):
    """Model dtype and attention dtype settings, bundled for tests."""
    # Model dtype setting.
    dtype: ModelDType
    # Attention dtype setting; "auto" when not given.
    attn_dtype: AttnDType = "auto"
334+
335+
331336
class EmbedModelInfo(NamedTuple):
332337
name: str
333338
is_matryoshka: bool = False
334339
matryoshka_dimensions: Optional[list[int]] = None
335340
architecture: str = ""
336-
dtype: str = "auto"
341+
dtype: Union[str, Dtype] = "auto"
337342
enable_test: bool = True

tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_platform.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,6 @@
66
class DummyPlatform(CudaPlatform):
77
device_name = "DummyDevice"
88

9-
def get_attn_backend_cls(self, backend_name, head_size, dtype,
9+
def get_attn_backend_cls(self, backend_name, head_size, attn_dtype,
1010
kv_cache_dtype, block_size, use_v1, use_mla):
1111
return "vllm_add_dummy_platform.dummy_attention_backend.DummyAttentionBackend" # noqa E501

0 commit comments

Comments
 (0)