Merged
10 changes: 5 additions & 5 deletions docs/source/models/pooling_models.md
@@ -159,14 +159,14 @@ For example, setting `dimensions` parameter while using the `BAAI/bge-m3` model

### Manually enable Matryoshka Embeddings

There is currently no official interface for specifying support for Matryoshka Embeddings. In vLLM, we simply check the existence of the fields `is_matryoshka` or `matryoshka_dimensions` inside `config.json`.
There is currently no official interface for specifying support for Matryoshka Embeddings. In vLLM, if `is_matryoshka` is `True` in `config.json`, the output may be changed to arbitrary dimensions; `matryoshka_dimensions` can be used to restrict the allowed output dimensions.
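For example, a `config.json` that declares both fields might look like the following (the specific dimension values here are hypothetical):

```json
{
  "is_matryoshka": true,
  "matryoshka_dimensions": [64, 128, 256]
}
```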

For models that support Matryoshka Embeddings but not recognized by vLLM, please manually override the config using `hf_overrides={"is_matryoshka": True}` (offline) or `--hf_overrides '{"is_matryoshka": true}'` (online).
For models that support Matryoshka Embeddings but are not recognized by vLLM, please manually override the config using `hf_overrides={"is_matryoshka": True}` or `hf_overrides={"matryoshka_dimensions": [<allowed output dimensions>]}` (offline), or `--hf_overrides '{"is_matryoshka": true}'` or `--hf_overrides '{"matryoshka_dimensions": [<allowed output dimensions>]}'` (online).

Here is an example to serve a model with Matryoshka Embeddings enabled.

```text
vllm serve Snowflake/snowflake-arctic-embed-m-v1.5 --hf_overrides '{"is_matryoshka":true}'
vllm serve Snowflake/snowflake-arctic-embed-m-v1.5 --hf_overrides '{"matryoshka_dimensions":[256]}'
```
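For offline use, a minimal sketch of the equivalent override on the `LLM` constructor (the model and dimension value simply mirror the serve example above):

```python
from vllm import LLM

# Manually enable Matryoshka Embeddings by overriding config.json,
# restricting requests to a 256-dimensional output.
llm = LLM(
    model="Snowflake/snowflake-arctic-embed-m-v1.5",
    task="embed",
    hf_overrides={"matryoshka_dimensions": [256]},
)
```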

### Offline Inference
@@ -204,14 +204,14 @@ curl http://127.0.0.1:8000/v1/embeddings \
"input": "Follow the white rabbit.",
"model": "jinaai/jina-embeddings-v3",
"encoding_format": "float",
"dimensions": 1
"dimensions": 32
}'
```

Expected output:

```json
{"id":"embd-0aab28c384d348c3b8f0eb783109dc5f","object":"list","created":1744195454,"model":"jinaai/jina-embeddings-v3","data":[{"index":0,"object":"embedding","embedding":[-1.0]}],"usage":{"prompt_tokens":10,"total_tokens":10,"completion_tokens":0,"prompt_tokens_details":null}}
{"id":"embd-5c21fc9a5c9d4384a1b021daccaf9f64","object":"list","created":1745476417,"model":"jinaai/jina-embeddings-v3","data":[{"index":0,"object":"embedding","embedding":[-0.3828125,-0.1357421875,0.03759765625,0.125,0.21875,0.09521484375,-0.003662109375,0.1591796875,-0.130859375,-0.0869140625,-0.1982421875,0.1689453125,-0.220703125,0.1728515625,-0.2275390625,-0.0712890625,-0.162109375,-0.283203125,-0.055419921875,-0.0693359375,0.031982421875,-0.04052734375,-0.2734375,0.1826171875,-0.091796875,0.220703125,0.37890625,-0.0888671875,-0.12890625,-0.021484375,-0.0091552734375,0.23046875]}],"usage":{"prompt_tokens":8,"total_tokens":8,"completion_tokens":0,"prompt_tokens_details":null}}
```

An OpenAI client example can be found here: <gh-file:examples/online_serving/openai_embedding_matryoshka_fy.py>
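For offline inference, here is a minimal sketch of the same request, assuming `LLM.embed` accepts `pooling_params` as exercised by the tests in this PR:

```python
from vllm import LLM, PoolingParams

llm = LLM(model="jinaai/jina-embeddings-v3",
          task="embed",
          trust_remote_code=True)

# Request a 32-dimensional Matryoshka embedding instead of the full size.
(output, ) = llm.embed(["Follow the white rabbit."],
                       pooling_params=PoolingParams(dimensions=32))
print(len(output.outputs.embedding))  # 32
```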
4 changes: 2 additions & 2 deletions examples/online_serving/openai_embedding_matryoshka_fy.py
@@ -25,11 +25,11 @@ def main():
responses = client.embeddings.create(
input=["Follow the white rabbit."],
model=model,
dimensions=1,
dimensions=32,
)

for data in responses.data:
print(data.embedding) # List of float of len 1
print(data.embedding) # List of float of len 32


if __name__ == "__main__":
42 changes: 24 additions & 18 deletions tests/entrypoints/openai/test_embedding.py
@@ -11,11 +11,12 @@
from vllm.entrypoints.openai.protocol import EmbeddingResponse
from vllm.transformers_utils.tokenizer import get_tokenizer

from ...models.embedding.utils import check_embeddings_close
from ...models.embedding.utils import correctness_test
from ...utils import RemoteOpenAIServer

MODEL_NAME = "intfloat/multilingual-e5-small"
DUMMY_CHAT_TEMPLATE = """{% for message in messages %}{{message['role'] + ': ' + message['content'] + '\\n'}}{% endfor %}""" # noqa: E501
DTYPE = "bfloat16"


@pytest.fixture(scope="module")
@@ -25,7 +25,7 @@ def server():
"embed",
# use half precision for speed and memory savings in CI environment
"--dtype",
"bfloat16",
DTYPE,
"--enforce-eager",
"--max-model-len",
"512",
@@ -43,9 +43,17 @@ async def client(server):
yield async_client


@pytest.fixture(scope="module")
def hf_model(hf_runner):
with hf_runner(MODEL_NAME, dtype=DTYPE,
is_sentence_transformer=True) as hf_model:
yield hf_model


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_single_embedding(client: openai.AsyncOpenAI, model_name: str):
async def test_single_embedding(hf_model, client: openai.AsyncOpenAI,
model_name: str):
input_texts = [
"The chef prepared a delicious meal.",
]
@@ -66,6 +75,9 @@ async def test_single_embedding(client: openai.AsyncOpenAI, model_name: str):
assert embeddings.usage.prompt_tokens == 11
assert embeddings.usage.total_tokens == 11

vllm_outputs = [d.embedding for d in embeddings.data]
correctness_test(hf_model, input_texts, vllm_outputs)

# test using token IDs
input_tokens = [1, 1, 1, 1, 1]
embedding_response = await client.embeddings.create(
@@ -86,7 +98,8 @@ async def test_single_embedding(client: openai.AsyncOpenAI, model_name: str):

@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_batch_embedding(client: openai.AsyncOpenAI, model_name: str):
async def test_batch_embedding(hf_model, client: openai.AsyncOpenAI,
model_name: str):
# test list[str]
input_texts = [
"The cat sat on the mat.", "A feline was resting on a rug.",
@@ -107,6 +120,9 @@ async def test_batch_embedding(client: openai.AsyncOpenAI, model_name: str):
assert embeddings.usage.prompt_tokens == 33
assert embeddings.usage.total_tokens == 33

vllm_outputs = [d.embedding for d in embeddings.data]
correctness_test(hf_model, input_texts, vllm_outputs)

# test list[list[int]]
input_tokens = [[4, 5, 7, 9, 20], [15, 29, 499], [24, 24, 24, 24, 24],
[25, 32, 64, 77]]
@@ -181,7 +197,7 @@ async def test_conversation_embedding(server: RemoteOpenAIServer,

@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_batch_base64_embedding(client: openai.AsyncOpenAI,
async def test_batch_base64_embedding(hf_model, client: openai.AsyncOpenAI,
model_name: str):
input_texts = [
"Hello my name is",
@@ -192,6 +208,7 @@ async def test_batch_base64_embedding(client: openai.AsyncOpenAI,
model=model_name,
encoding_format="float")
float_data = [d.embedding for d in responses_float.data]
correctness_test(hf_model, input_texts, float_data)

responses_base64 = await client.embeddings.create(input=input_texts,
model=model_name,
@@ -202,24 +219,13 @@
np.frombuffer(base64.b64decode(data.embedding),
dtype="float32").tolist())

check_embeddings_close(
embeddings_0_lst=float_data,
embeddings_1_lst=base64_data,
name_0="float",
name_1="base64",
)
correctness_test(hf_model, input_texts, base64_data)

# Default response is float32 decoded from base64 by OpenAI Client
responses_default = await client.embeddings.create(input=input_texts,
model=model_name)
default_data = [d.embedding for d in responses_default.data]

check_embeddings_close(
embeddings_0_lst=float_data,
embeddings_1_lst=default_data,
name_0="float",
name_1="default",
)
correctness_test(hf_model, input_texts, default_data)


@pytest.mark.asyncio
134 changes: 91 additions & 43 deletions tests/entrypoints/openai/test_embedding_dimensions.py
@@ -3,73 +3,121 @@
Run `pytest tests/entrypoints/openai/test_embedding_dimensions.py`.
"""

from typing import Optional

import openai
import pytest

from vllm.entrypoints.openai.protocol import EmbeddingResponse

from ...models.embedding.utils import EmbedModelInfo
from ...conftest import HfRunner
from ...models.embedding.utils import EmbedModelInfo, correctness_test
from ...utils import RemoteOpenAIServer

MODELS = [
EmbedModelInfo(name="BAAI/bge-m3", is_matryoshka=False),
EmbedModelInfo(name="jinaai/jina-embeddings-v3", is_matryoshka=True),
EmbedModelInfo("intfloat/multilingual-e5-small", is_matryoshka=False),
EmbedModelInfo("Snowflake/snowflake-arctic-embed-m-v1.5",
is_matryoshka=True,
matryoshka_dimensions=[256]),
]

input_texts = [
"The chef prepared a delicious meal.",
] * 3
]


@pytest.mark.asyncio
@pytest.mark.parametrize("model", MODELS)
async def test_validating_dimensions(model: EmbedModelInfo):
@pytest.fixture(scope="module", params=MODELS)
def model_info(request):
return request.param


@pytest.fixture(scope="module", params=["bfloat16"])
def dtype(request):
return request.param


@pytest.fixture(scope="module")
def server(model_info, dtype: str):
args = [
"--task",
"embed",
# use half precision for speed and memory savings in CI environment
"--dtype",
"bfloat16",
dtype,
"--enforce-eager",
"--max-model-len",
"512",
"--trust_remote_code"
"512"
]
with RemoteOpenAIServer(model.name, args) as remote_server:
client = remote_server.get_async_client()

async def make_request(dimensions):
embedding_response = await client.embeddings.create(
model=model.name,
input=input_texts,
dimensions=dimensions,
encoding_format="float",
)
embeddings = EmbeddingResponse.model_validate(
embedding_response.model_dump(mode="json"))

assert embeddings.id is not None
assert len(embeddings.data) == 3
assert len(embeddings.data[0].embedding) > 0
assert embeddings.usage.completion_tokens == 0
assert embeddings.usage.prompt_tokens > 0
assert embeddings.usage.total_tokens > 0

if dimensions is not None:
assert len(embeddings.data[0].embedding) == dimensions

if model.is_matryoshka:
for dimensions in [None, 16]:
await make_request(dimensions)

if model_info.name == "Snowflake/snowflake-arctic-embed-m-v1.5":
# Manually enable Matryoshka Embeddings
args.extend([
"--trust_remote_code", "--hf_overrides",
'{"matryoshka_dimensions":[256]}'
])

with RemoteOpenAIServer(model_info.name, args) as remote_server:
yield remote_server


@pytest.fixture(scope="module")
def hf_model(hf_runner, model_info, dtype: str):
with hf_runner(model_info.name, dtype=dtype,
is_sentence_transformer=True) as hf_model:
yield hf_model


@pytest.mark.asyncio
async def test_matryoshka(model_info: EmbedModelInfo,
server: RemoteOpenAIServer, hf_model: HfRunner):
client = server.get_async_client()

async def make_request_and_correctness_test(dimensions):
prompts = input_texts * 3

embedding_response = await client.embeddings.create(
model=model_info.name,
input=prompts,
dimensions=dimensions,
encoding_format="float",
)
embeddings = EmbeddingResponse.model_validate(
embedding_response.model_dump(mode="json"))

assert embeddings.id is not None
assert len(embeddings.data) == 3
assert len(embeddings.data[0].embedding) > 0
assert embeddings.usage.completion_tokens == 0
assert embeddings.usage.prompt_tokens > 0
assert embeddings.usage.total_tokens > 0

if dimensions is not None:
assert len(embeddings.data[0].embedding) == dimensions

vllm_outputs = [d.embedding for d in embeddings.data]
correctness_test(hf_model, prompts, vllm_outputs, dimensions)

if model_info.is_matryoshka:
valid_dimensions: list[Optional[int]] = [None]
if model_info.matryoshka_dimensions is not None:
valid_dimensions += model_info.matryoshka_dimensions[:2]

for dimensions in valid_dimensions:
await make_request_and_correctness_test(dimensions)

invalid_dimensions: list[Optional[int]] = [-1]
if model_info.matryoshka_dimensions is not None:
assert 5 not in model_info.matryoshka_dimensions
invalid_dimensions.append(5)

for dimensions in invalid_dimensions:
with pytest.raises(openai.BadRequestError):
for dimensions in [-1]:
await make_request(dimensions)
await make_request_and_correctness_test(dimensions)

else:
for dimensions in [None]:
await make_request(dimensions)
else:
for dimensions in [None]:
await make_request_and_correctness_test(dimensions)

for dimensions in [-1, 16]:
with pytest.raises(openai.BadRequestError):
for dimensions in [-1, 16]:
await make_request(dimensions)
await make_request_and_correctness_test(dimensions)
32 changes: 21 additions & 11 deletions tests/models/embedding/language/test_jina.py
@@ -153,14 +153,24 @@ def test_matryoshka(

with vllm_runner(model, task="embed", dtype=dtype,
max_model_len=None) as vllm_model:
vllm_outputs = vllm_model.encode(
example_prompts,
pooling_params=PoolingParams(dimensions=dimensions))

check_embeddings_close(
embeddings_0_lst=hf_outputs,
embeddings_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm",
tol=1e-2,
)
matryoshka_dimensions = (
vllm_model.model.llm_engine.model_config.matryoshka_dimensions)
assert matryoshka_dimensions is not None

if dimensions not in matryoshka_dimensions:
with pytest.raises(ValueError):
vllm_model.encode(
example_prompts,
pooling_params=PoolingParams(dimensions=dimensions))
else:
vllm_outputs = vllm_model.encode(
example_prompts,
pooling_params=PoolingParams(dimensions=dimensions))

check_embeddings_close(
embeddings_0_lst=hf_outputs,
embeddings_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm",
tol=1e-2,
)
21 changes: 20 additions & 1 deletion tests/models/embedding/utils.py
@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0

from collections.abc import Sequence
from typing import NamedTuple
from typing import NamedTuple, Optional

import torch
import torch.nn.functional as F
@@ -43,5 +43,24 @@ def matryoshka_fy(tensor, dimensions):
class EmbedModelInfo(NamedTuple):
name: str
is_matryoshka: bool
matryoshka_dimensions: Optional[list[int]] = None
architecture: str = ""
enable_test: bool = True


def correctness_test(hf_model,
inputs,
vllm_outputs: Sequence[list[float]],
dimensions: Optional[int] = None):

hf_outputs = hf_model.encode(inputs)
if dimensions:
hf_outputs = matryoshka_fy(hf_outputs, dimensions)

check_embeddings_close(
embeddings_0_lst=hf_outputs,
embeddings_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm",
tol=1e-2,
)