@noooop noooop commented May 30, 2025

RFC #18342

Summary

Most models can run at their original precision, but a few require float32, and FlashAttention does not support float32.

For models where precision drops significantly at float16, hybrid dtype might be a better choice, especially for models that support long context.

Hybrid dtype (float32 weights with float16 or bfloat16 attention) is used by default for float32 pooling models, which makes the default parameters more robust.
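The constraint in the summary (FlashAttention has no float32 support) is why the attention dtype has to be tracked separately from the model dtype. A minimal sketch of that idea follows; `pick_attn_backend`, `FLASH_ATTN_DTYPES`, and the backend labels are hypothetical names for illustration, not vLLM's actual selector.

```python
# Illustrative only: the helper and backend labels are hypothetical,
# not vLLM's real attention selector.
FLASH_ATTN_DTYPES = {"float16", "bfloat16"}


def pick_attn_backend(attn_dtype: str) -> str:
    """Return a backend label for the dtype used inside attention."""
    if attn_dtype in FLASH_ATTN_DTYPES:
        return "FLASH_ATTN"
    # float32 attention cannot use FlashAttention and needs another kernel.
    return "FALLBACK"


# Pure float32 model: attention also runs in float32.
print(pick_attn_backend("float32"))
# Hybrid dtype: weights stay float32, attention runs in float16.
print(pick_attn_backend("float16"))
```

With hybrid dtype, only the value passed as the attention dtype changes, so a float32 model can still be routed to a half-precision attention kernel.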

Proposed Change.

  • Hybrid dtype supports only pooling models; it does not support draft or generate tasks.
  • add attn_dtype
  • Platform.get_attn_backend_cls
    def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
                             dtype: torch.dtype, kv_cache_dtype: Optional[str],
                             block_size: int, use_v1: bool,
                             use_mla: bool) -> str:

    dtype -> attn_dtype

    def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
                             attn_dtype: torch.dtype, kv_cache_dtype: Optional[str],
                             block_size: int, use_v1: bool,
                             use_mla: bool) -> str:
  • vllm.attention.selector:get_attn_backend
def get_attn_backend(
    head_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: Optional[str],
    block_size: int,
    is_attention_free: bool,
    is_blocksparse: bool = False,
    use_mla: bool = False,
) -> Type[AttentionBackend]:

dtype -> attn_dtype

def get_attn_backend(
    head_size: int,
    attn_dtype: torch.dtype,
    kv_cache_dtype: Optional[str],
    block_size: int,
    is_attention_free: bool,
    is_blocksparse: bool = False,
    use_mla: bool = False,
) -> Type[AttentionBackend]:
  • add "hybrid dtype" for dtype
  • For pooling models, if config_dtype is float32, use hybrid dtype by default.
        # For pooling models
        config_dtype = _find_dtype(self.model,
                                   self.hf_config,
                                   revision=self.revision)

        # If config_dtype is float32, use hybrid dtype by default.
        if config_dtype == torch.float32 and self.dtype == "auto":
            self.dtype = "hybrid"

        if self.dtype == "hybrid":
            self.dtype = torch.float32
            if config_dtype == torch.bfloat16:
                self.attn_dtype = torch.bfloat16
            else:
                self.attn_dtype = torch.float16
            return

        self.dtype = _get_and_verify_dtype(
            self.model,
            self.hf_config,
            self.dtype,
            is_pooling_model=is_pooling_model,
            revision=self.revision,
        )
        if self.attn_dtype == "auto":
            self.attn_dtype = self.dtype
        else:
            self.attn_dtype = _get_and_verify_dtype(
                self.model,
                self.hf_config,
                self.attn_dtype,
                is_pooling_model=is_pooling_model,
                revision=self.revision,
            )
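
The resolution rules above can be sketched as a standalone function. This is a simplified stand-in under stated assumptions: dtypes are plain strings instead of `torch.dtype`, `resolve_dtypes` is a hypothetical name, and "auto" is treated as "use the config dtype", whereas the real `_get_and_verify_dtype` performs additional checks.

```python
from typing import Tuple

# Simplified stand-in: dtypes are strings, and "auto" just means
# "use the checkpoint's config dtype". Not the real vLLM code path.


def resolve_dtypes(config_dtype: str, dtype: str = "auto",
                   attn_dtype: str = "auto") -> Tuple[str, str]:
    """Return (model dtype, attention dtype) for a pooling model."""
    # float32 checkpoints default to hybrid when dtype is left on "auto".
    if config_dtype == "float32" and dtype == "auto":
        dtype = "hybrid"

    if dtype == "hybrid":
        # Weights stay in float32; attention uses a half-precision dtype.
        half = "bfloat16" if config_dtype == "bfloat16" else "float16"
        return "float32", half

    model_dtype = config_dtype if dtype == "auto" else dtype
    resolved_attn = model_dtype if attn_dtype == "auto" else attn_dtype
    return model_dtype, resolved_attn


print(resolve_dtypes("float32"))             # float32 checkpoint, defaults
print(resolve_dtypes("bfloat16", "hybrid"))  # hybrid requested explicitly
print(resolve_dtypes("float16"))             # half-precision checkpoint
```

As in the snippet above, the hybrid branch returns early, so an explicitly passed `attn_dtype` has no effect once hybrid is selected.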

Improve Precision

A total of 46 models (including skipped models) participate in the CI test test_embed_models_mteb.

For test-related information, PTAL at #17175 and #18747.

https://github.com/vllm-project/vllm/blob/main/tests/models/language/pooling/mteb_utils.py

MTEB_EMBED_TOL = 1e-4
assert st_main_score == pytest.approx(vllm_main_score, abs=MTEB_EMBED_TOL)

Of these, 15 models fail the test with the auto dtype strategy on main. They fall into two categories: float32 and float16/bfloat16.

  • For the 10 models whose dtype is float32, this PR uses hybrid dtype by default, and 9 of them pass the tests:
    - nomic-ai/nomic-embed-text-v1
    - nomic-ai/nomic-embed-text-v1.5
    - nomic-ai/nomic-embed-text-v2-moe
    - intfloat/e5-small
    - intfloat/e5-base
    - intfloat/e5-large
    - intfloat/multilingual-e5-small
    - intfloat/multilingual-e5-base
    - intfloat/multilingual-e5-large

The remaining model, BAAI/bge-code-v1, must use float32 to pass the test.

This indicates that using hybrid dtype by default for float32 pooling models makes the default parameters more robust.

  • For the 5 models whose dtype is float16/bfloat16, manually setting the dtype to hybrid makes all of them pass the tests:
    - jinaai/jina-embeddings-v3 (bfloat16)
    - thenlper/gte-large (float16)
    - thenlper/gte-base (float16)
    - thenlper/gte-small (float16)
    - intfloat/multilingual-e5-large-instruct (float16)

For some models, SentenceTransformers defaults to float32 and ignores the torch_dtype parameter.
This is why results from vLLM with torch_dtype set differ from those of SentenceTransformers.

mteb/STS12 Result

https://github.com/noooop/snippet/blob/main/benchmarks/test_mteb/test_mteb.py

  • intfloat/multilingual-e5-small

SentenceTransformer torch.float32 0.7805425596252846
float16 diff:-0.2749311085815237 std:0.006216913108536066
bfloat16 diff:-0.2910827032299663 std:0.013966753081608172
float32 diff:-3.215640069775816e-07 std:8.464281838713533e-06
attn_dtype: float16 diff:2.298023865754395e-06 std:7.7257068905813e-06
attn_dtype: bfloat16 diff:-1.7649813619513566e-06 std:1.2307026385810296e-05

  • nomic-ai/nomic-embed-text-v1

SentenceTransformer torch.float32 0.7375691474332452
float16 diff:-0.0023990889416417582 std:0.0005159950361333374
bfloat16 diff:-0.08114092824625219 std:0.005852545461415668
float32 diff:-2.484723391815713e-06 std:7.641653300452288e-06
attn_dtype: float16 diff:-2.1050127918531558e-07 std:8.164661176164923e-06
attn_dtype: bfloat16 diff:-3.4379344918678e-06 std:1.1283457501932804e-05
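
To relate these numbers to the CI assertion, the reported diffs can be checked against MTEB_EMBED_TOL. The sketch below copies the diff values from the results above and uses `math.isclose` with an absolute tolerance as a stand-in for `pytest.approx(..., abs=...)`; `within_tolerance` is a hypothetical helper name.

```python
import math

MTEB_EMBED_TOL = 1e-4  # same tolerance as the CI assertion

# Score differences vs. SentenceTransformers float32, copied from above.
reported_diffs = {
    "intfloat/multilingual-e5-small": {
        "float16": -0.2749311085815237,
        "bfloat16": -0.2910827032299663,
        "float32": -3.215640069775816e-07,
        "attn_dtype: float16": 2.298023865754395e-06,
        "attn_dtype: bfloat16": -1.7649813619513566e-06,
    },
    "nomic-ai/nomic-embed-text-v1": {
        "float16": -0.0023990889416417582,
        "bfloat16": -0.08114092824625219,
        "float32": -2.484723391815713e-06,
        "attn_dtype: float16": -2.1050127918531558e-07,
        "attn_dtype: bfloat16": -3.4379344918678e-06,
    },
}


def within_tolerance(diff: float, tol: float = MTEB_EMBED_TOL) -> bool:
    # Absolute-only comparison, mirroring an abs-tolerance approx check.
    return math.isclose(diff, 0.0, rel_tol=0.0, abs_tol=tol)


for model, diffs in reported_diffs.items():
    for config, diff in diffs.items():
        status = "pass" if within_tolerance(diff) else "fail"
        print(f"{model} {config}: {status}")
```

Only the float32 and attn_dtype (hybrid) rows fall inside the tolerance; the plain float16 and bfloat16 rows do not.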

Precision-Efficiency trade-off

mteb/T2Reranking Result

https://github.com/noooop/snippet/blob/main/benchmarks/test_mteb/test_speed.py
https://github.com/noooop/snippet/blob/main/benchmarks/test_mteb/test_speed_long.py

  • intfloat/multilingual-e5-small

512
float32 0.6562587376880025 1670.16it/s 100%
hybrid 0.6562531966067329 2180.74it/s 130%
float16 0.763608731060288 4130.90it/s 247%
bfloat16 0.6508541895430957 4146.73it/s 248%

  • nomic-ai/nomic-embed-text-v1

512
float32 0.6217034721042816 563.92it/s 100%
hybrid 0.6216922978501486 621.49it/s 110%
float16 0.6835500899468472 1169.96it/s 207%

2048
float32 0.6192403917441829 261.67it/s 100%
hybrid 0.6191907625942774 344.23it/s 131%
float16 0.68911642552594 633.15it/s 242%

8192
float32 0.6174638370330193 161.90it/s 100%
hybrid 0.6174442576960775 264.05it/s 163%
float16 NaN 461.94it/s 285%
bfloat16 0.6532416687333537 465.72it/s 285%
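
The trade-off for hybrid dtype can be summarized from the float32 and hybrid rows alone. The sketch below copies those rows from the 512, 2048, and 8192 blocks above and computes the score change and the throughput ratio; the row labels and the `summarize` helper are hypothetical names for illustration.

```python
# (label, float32 score, float32 it/s, hybrid score, hybrid it/s),
# copied from the float32 and hybrid rows above.
rows = [
    ("multilingual-e5-small/512", 0.6562587376880025, 1670.16,
     0.6562531966067329, 2180.74),
    ("nomic-embed-text-v1/512", 0.6217034721042816, 563.92,
     0.6216922978501486, 621.49),
    ("nomic-embed-text-v1/2048", 0.6192403917441829, 261.67,
     0.6191907625942774, 344.23),
    ("nomic-embed-text-v1/8192", 0.6174638370330193, 161.90,
     0.6174442576960775, 264.05),
]


def summarize(row):
    """Return (label, absolute score change, hybrid/float32 throughput ratio)."""
    label, fp32_score, fp32_speed, hybrid_score, hybrid_speed = row
    return label, abs(hybrid_score - fp32_score), hybrid_speed / fp32_speed


for label, delta, speedup in map(summarize, rows):
    print(f"{label}: score delta {delta:.1e}, speedup {speedup:.2f}x")
```

In every row the hybrid score stays within 1e-4 of the float32 score, and for nomic-ai/nomic-embed-text-v1 the throughput ratio grows from the 512 block to the 8192 block.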

mergify bot commented Jun 1, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @noooop.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jun 1, 2025
@noooop noooop closed this Jun 2, 2025
@noooop noooop reopened this Jun 2, 2025
mergify bot commented Jun 2, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @noooop.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot removed the needs-rebase label Jun 2, 2025
Collaborator Author

noooop commented Jun 3, 2025

@DarkLight1337

Please open Language Models Test (Extended Pooling) and Language Models Test (Extended Generation),

see what I've broken.

Collaborator Author

noooop commented Jun 3, 2025

Please open Language Models Test (Extended Pooling) and Language Models Test (Extended Generation),

see what I've broken.

Nothing is broken. Lucky.

https://buildkite.com/vllm/fastcheck/builds/26074

@mergify mergify bot added the frontend label Jun 3, 2025
EmbedModelInfo("thenlper/gte-large",
               architecture="BertModel",
-              dtype="float32",
+              dtype="hybrid",
Member

Need to change these to use "auto"

Collaborator Author

"auto" cannot pass the test.

Collaborator Author

VLLM: torch.float16 0.5363271844116764
SentenceTransformers: torch.float16 0.7680302125398653
Difference: 0.23170302812818888

The SentenceTransformers dtype handling confuses me.

Member

Does the test pass if float32 is specified for both cases?

Collaborator Author

Added a HybridDType = DTypeInfo(dtype="float32", attn_dtype="float16") to make the test pass

Member

Can you manually check the parameter types in vLLM vs ST impl?

Collaborator Author

I am studying the code of SentenceTransformers to see how it loads the models.

Collaborator Author

@noooop noooop Jun 3, 2025

The SentenceTransformers and vLLM model weights are exactly the same, but the results are different.

import torch

from tests.conftest import HfRunner
from tests.models.language.pooling.mteb_utils import run_mteb_embed_task, MTEB_EMBED_TASKS, VllmMtebEncoder

model = "thenlper/gte-large"


hf_model = HfRunner(
        model,
        dtype="float16",
        is_sentence_transformer=True,
)

st_main_score = run_mteb_embed_task(hf_model, MTEB_EMBED_TASKS)
# 0.7680302125398653

print(hf_model.model[0].auto_model.encoder.layer[0].intermediate.dense.weight)


"""
hf_model.model[0].auto_model.encoder.layer[0].intermediate.dense.weight
Parameter containing:
tensor([[ 0.0728, -0.0249,  0.0490,  ...,  0.0595,  0.0363,  0.0181],
        [ 0.0327,  0.0181,  0.0223,  ...,  0.0534, -0.0257,  0.0275],
        [-0.0246,  0.0504,  0.0197,  ...,  0.0159, -0.0805,  0.0060],
        ...,
        [ 0.0048,  0.0043, -0.0147,  ...,  0.0363, -0.0344,  0.0045],
        [ 0.0347,  0.0642, -0.0393,  ...,  0.0238,  0.0541,  0.0966],
        [-0.0337,  0.0023,  0.0086,  ..., -0.0003, -0.0105, -0.0164]],
       device='cuda:0', dtype=torch.float16, requires_grad=True)

"""

from tests.conftest import VllmRunner


vllm_model = VllmRunner(
    model,
    task="embed",
    max_model_len=None,
    dtype="float16",
)

vllm_main_score = run_mteb_embed_task(VllmMtebEncoder(vllm_model),
                                      MTEB_EMBED_TASKS)

# 0.5363271844116764
vllm_dtype = vllm_model.model.llm_engine.model_config.dtype


m = vllm_model.model.llm_engine.model_executor.driver_worker.worker.get_model()

"""
m.model.encoder.layer[0].intermediate.dense.weight
Parameter containing:
tensor([[ 0.0728, -0.0249,  0.0490,  ...,  0.0595,  0.0363,  0.0181],
        [ 0.0327,  0.0181,  0.0223,  ...,  0.0534, -0.0257,  0.0275],
        [-0.0246,  0.0504,  0.0197,  ...,  0.0159, -0.0805,  0.0060],
        ...,
        [ 0.0048,  0.0043, -0.0147,  ...,  0.0363, -0.0344,  0.0045],
        [ 0.0347,  0.0642, -0.0393,  ...,  0.0238,  0.0541,  0.0966],
        [-0.0337,  0.0023,  0.0086,  ..., -0.0003, -0.0105, -0.0164]],
       device='cuda:0', dtype=torch.float16)
"""

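vllm_weight = list(m.named_parameters())
st_weight = dict(hf_model.model[0].named_parameters())

# vLLM names start with "model.", SentenceTransformers names with "auto_model.".
for name, weight in vllm_weight:
    try:
        print(name, torch.mean(weight - st_weight["auto_" + name]).item())
    except Exception:
        # Skip parameters that cannot be compared directly
        # (e.g. missing name or mismatched shape).
        pass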

"""
model.embeddings.position_embeddings.weight 0.0
model.embeddings.LayerNorm.weight 0.0
model.embeddings.LayerNorm.bias 0.0
model.encoder.layer.0.attention.output.dense.weight 0.0
model.encoder.layer.0.attention.output.dense.bias 0.0
model.encoder.layer.0.attention.output.LayerNorm.weight 0.0
model.encoder.layer.0.attention.output.LayerNorm.bias 0.0
model.encoder.layer.0.intermediate.dense.weight 0.0
model.encoder.layer.0.intermediate.dense.bias 0.0
model.encoder.layer.0.output.dense.weight 0.0
model.encoder.layer.0.output.dense.bias 0.0
model.encoder.layer.0.output.LayerNorm.weight 0.0
model.encoder.layer.0.output.LayerNorm.bias 0.0
"""

Setting attn_dtype="float32" in vLLM still gives results that differ from SentenceTransformers. The difference probably comes from the MLP layer.
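
One caveat with the weight check above: the mean of signed differences can be 0.0 even when two tensors differ, because positive and negative differences cancel. A stricter check is the maximum absolute difference (in PyTorch, `(a - b).abs().max()` or `torch.equal(a, b)`). The toy sketch below uses plain Python lists and hypothetical helper names to show the effect.

```python
from typing import Sequence


def mean_signed_diff(a: Sequence[float], b: Sequence[float]) -> float:
    """Mean of element-wise differences; opposite signs can cancel."""
    return sum(x - y for x, y in zip(a, b)) / len(a)


def max_abs_diff(a: Sequence[float], b: Sequence[float]) -> float:
    """Largest element-wise absolute difference; zero only if all match."""
    return max(abs(x - y) for x, y in zip(a, b))


# Toy "weights": same values, swapped positions.
w_a = [0.5, -0.25]
w_b = [-0.25, 0.5]

print(mean_signed_diff(w_a, w_b))  # signed differences cancel out
print(max_abs_diff(w_a, w_b))      # the mismatch is still visible
```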

Collaborator Author

It sounds like a bug. I'll raise an issue and ask experts for help.

Collaborator Author

@noooop noooop Jun 3, 2025

@DarkLight1337

(╯‵□′)╯︵┻━┻

Converting the output embedding dtype to float32 can fix this issue.

Perhaps we don't even need hybrid dtype.

Tomorrow I will do further testing.
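
The thread does not explain why casting the output embeddings to float32 helps, so the following is only an illustration of how float16 arithmetic on embeddings can lose accuracy, not a diagnosis of this case. It rounds every intermediate result to half precision using the standard library's `struct` format `e`; the helper names are hypothetical, and real GPU kernels usually accumulate in higher precision.

```python
import struct


def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def dot_fp16(a, b) -> float:
    """Dot product where every intermediate result is rounded to float16."""
    total = 0.0
    for x, y in zip(a, b):
        product = to_fp16(to_fp16(x) * to_fp16(y))
        total = to_fp16(total + product)
    return total


# A 4096-dimensional unit vector; every entry (2**-6) is exact in float16.
v = [2.0 ** -6] * 4096

exact = sum(x * x for x in v)  # accumulated in double precision
half = dot_fp16(v, v)          # accumulated in float16
print(exact, half)
```

Once the running sum reaches 0.5, each further term is only half of the float16 spacing at that magnitude and is rounded away, so the float16 result stalls well below the true squared norm of 1.0.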

@noooop noooop closed this Jun 3, 2025
@noooop noooop deleted the hybrid4pooling branch July 10, 2025 04:46