Skip to content

Commit 045b396

Browse files
authored
[Bugfix][CI/Build] Fix failing Mteb CI (#26638)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
1 parent 7685201 commit 045b396

File tree

5 files changed

+13
-2
lines changed

5 files changed

+13
-2
lines changed

tests/models/language/pooling_mteb_test/mteb_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ def mteb_test_embed_models(
191191
with vllm_runner(
192192
model_info.name,
193193
runner="pooling",
194-
max_model_len=None,
194+
max_model_len=model_info.max_model_len,
195195
**vllm_extra_kwargs,
196196
) as vllm_model:
197197
model_config = vllm_model.llm.llm_engine.model_config

tests/models/language/pooling_mteb_test/test_jina.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,11 @@
2525
mteb_score=0.824413164,
2626
architecture="XLMRobertaModel",
2727
is_matryoshka=True,
28+
# The default max length of the model is 8194, which will crash
29+
# CUDAGraph due to odd length for Gemm. We set it to 8192 to avoid
30+
# avoid this issue.
31+
max_model_len=8192,
32+
dtype="float32",
2833
)
2934
]
3035

tests/models/language/pooling_mteb_test/test_st_projector.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
architecture="Gemma3TextModel",
2424
mteb_score=0.7473819294684156,
2525
enable_test=True,
26+
dtype="float32",
2627
),
2728
]
2829

tests/models/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,7 @@ class ModelInfo:
369369
name: str
370370
architecture: str = ""
371371
dtype: str = "auto"
372+
max_model_len: Optional[int] = None
372373
hf_dtype: str = "float32"
373374
hf_overrides: Optional[dict[str, Any]] = None
374375
default_pooling_type: str = ""

vllm/model_executor/layers/layernorm.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -318,7 +318,11 @@ def forward_static(
318318
"""PyTorch-native implementation equivalent to forward()."""
319319
orig_dtype = x.dtype
320320
if residual is not None:
321-
x = x + residual.float() if orig_dtype == torch.float16 else x + residual
321+
x = (
322+
x.float() + residual.float()
323+
if orig_dtype == torch.float16
324+
else x + residual
325+
)
322326
residual = x
323327

324328
x = x.float()

0 commit comments

Comments
 (0)