diff --git a/docs/design/io_processor_plugins.md b/docs/design/io_processor_plugins.md index fb64a7bb9c8f..2f4b17f191a5 100644 --- a/docs/design/io_processor_plugins.md +++ b/docs/design/io_processor_plugins.md @@ -79,7 +79,7 @@ The `post_process*` methods take `PoolingRequestOutput` objects as input and gen The `validate_or_generate_params` method is used for validating with the plugin any `SamplingParameters`/`PoolingParameters` received with the user request, or to generate new ones if none are specified. The function always returns the validated/generated parameters. The `output_to_response` method is used only for online serving and converts the plugin output to the `IOProcessorResponse` type that is then returned by the API Server. The implementation of the `/pooling` serving endpoint is available here [vllm/entrypoints/openai/serving_pooling.py](../../vllm/entrypoints/openai/serving_pooling.py). -An example implementation of a plugin that enables generating geotiff images with the PrithviGeospatialMAE model is available [here](https://github.com/IBM/terratorch/tree/main/terratorch/vllm/plugins/segmentation). Please, also refer to our online ([examples/online_serving/prithvi_geospatial_mae.py](../../examples/online_serving/prithvi_geospatial_mae.py)) and offline ([examples/offline_inference/prithvi_geospatial_mae_io_processor.py](../../examples/offline_inference/prithvi_geospatial_mae_io_processor.py)) inference examples. +An example implementation of a plugin that enables generating geotiff images with the PrithviGeospatialMAE model is available [here](https://github.com/IBM/terratorch/tree/main/terratorch/vllm/plugins/segmentation). Please, also refer to our online ([examples/online_serving/pooling/prithvi_geospatial_mae.py](../../examples/online_serving/pooling/prithvi_geospatial_mae.py)) and offline ([examples/offline_inference/pooling/prithvi_geospatial_mae_io_processor.py](../../examples/offline_inference/pooling/prithvi_geospatial_mae_io_processor.py)) inference examples. ## Using an IO Processor plugin diff --git a/docs/models/pooling_models.md b/docs/models/pooling_models.md index 40651be1d449..18bb645ea9a9 100644 --- a/docs/models/pooling_models.md +++ b/docs/models/pooling_models.md @@ -30,11 +30,11 @@ If `--runner pooling` has been set (manually or automatically) but the model doe vLLM will attempt to automatically convert the model according to the architecture names shown in the table below. -| Architecture | `--convert` | Supported pooling tasks | -|-------------------------------------------------|-------------|-------------------------------| -| `*ForTextEncoding`, `*EmbeddingModel`, `*Model` | `embed` | `encode`, `embed` | -| `*For*Classification`, `*ClassificationModel` | `classify` | `encode`, `classify`, `score` | -| `*ForRewardModeling`, `*RewardModel` | `reward` | `encode` | +| Architecture | `--convert` | Supported pooling tasks | +|-------------------------------------------------|-------------|---------------------------------------| +| `*ForTextEncoding`, `*EmbeddingModel`, `*Model` | `embed` | `token_embed`, `embed` | +| `*For*Classification`, `*ClassificationModel` | `classify` | `token_classify`, `classify`, `score` | +| `*ForRewardModeling`, `*RewardModel` | `reward` | `token_classify` | !!! tip You can explicitly set `--convert ` to specify how to convert the model. @@ -45,12 +45,14 @@ Each pooling model in vLLM supports one or more of these tasks according to [Pooler.get_supported_tasks][vllm.model_executor.layers.pooler.Pooler.get_supported_tasks], enabling the corresponding APIs: -| Task | APIs | -|------------|--------------------------------------| -| `encode` | `LLM.reward(...)` | -| `embed` | `LLM.embed(...)`, `LLM.score(...)`\* | -| `classify` | `LLM.classify(...)` | -| `score` | `LLM.score(...)` | +| Task | APIs | +|------------------|-------------------------------------------------------------------------------| +| `embed` | `LLM.embed(...)`, `LLM.score(...)`\*, `LLM.encode(..., pooling_task="embed")` | +| `classify` | `LLM.classify(...)`, `LLM.encode(..., pooling_task="classify")` | +| `score` | `LLM.score(...)` | +| `token_classify` | `LLM.reward(...)`, `LLM.encode(..., pooling_task="token_classify")` | +| `token_embed` | `LLM.encode(..., pooling_task="token_embed")` | +| `plugin` | `LLM.encode(..., pooling_task="plugin")` | \* The `LLM.score(...)` API falls back to `embed` task if the model does not support `score` task. @@ -144,7 +146,6 @@ A code example can be found here: [examples/offline_inference/basic/score.py](.. ### `LLM.reward` The [reward][vllm.LLM.reward] method is available to all reward models in vLLM. -It returns the extracted hidden states directly. ```python from vllm import LLM @@ -161,15 +162,17 @@ A code example can be found here: [examples/offline_inference/basic/reward.py](. ### `LLM.encode` The [encode][vllm.LLM.encode] method is available to all pooling models in vLLM. -It returns the extracted hidden states directly. !!! note Please use one of the more specific methods or set the task directly when using `LLM.encode`: - For embeddings, use `LLM.embed(...)` or `pooling_task="embed"`. - For classification logits, use `LLM.classify(...)` or `pooling_task="classify"`. - - For rewards, use `LLM.reward(...)` or `pooling_task="reward"`. - For similarity scores, use `LLM.score(...)`. + - For rewards, use `LLM.reward(...)` or `pooling_task="token_classify"`. + - For token classification, use `pooling_task="token_classify"`. + - For multi-vector retrieval, use `pooling_task="token_embed"` + - For IO Processor Plugins , use `pooling_task="plugin"` ```python from vllm import LLM @@ -185,10 +188,47 @@ print(f"Data: {data!r}") Our [OpenAI-Compatible Server](../serving/openai_compatible_server.md) provides endpoints that correspond to the offline APIs: -- [Pooling API](../serving/openai_compatible_server.md#pooling-api) is similar to `LLM.encode`, being applicable to all types of pooling models. - [Embeddings API](../serving/openai_compatible_server.md#embeddings-api) is similar to `LLM.embed`, accepting both text and [multi-modal inputs](../features/multimodal_inputs.md) for embedding models. - [Classification API](../serving/openai_compatible_server.md#classification-api) is similar to `LLM.classify` and is applicable to sequence classification models. - [Score API](../serving/openai_compatible_server.md#score-api) is similar to `LLM.score` for cross-encoder models. +- [Pooling API](../serving/openai_compatible_server.md#pooling-api) is similar to `LLM.encode`, being applicable to all types of pooling models. + +!!! note + Please use one of the more specific methods or set the task directly when using [Pooling API](../serving/openai_compatible_server.md#pooling-api) api.: + + - For embeddings, use [Embeddings API](../serving/openai_compatible_server.md#embeddings-api) or `"task":"embed"`. + - For classification logits, use [Classification API](../serving/openai_compatible_server.md#classification-api) or `task":"classify"`. + - For similarity scores, use [Score API](../serving/openai_compatible_server.md#score-api). + - For rewards, `task":"token_classify"`. + - For token classification, use `task":"token_classify"`. + - For multi-vector retrieval, use `task":"token_embed"` + - For IO Processor Plugins , use `task":"plugin"` + +```python +# start a supported embeddings model server with `vllm serve`, e.g. +# vllm serve intfloat/e5-small +import requests + +host = "localhost" +port = "8000" +model_name = "intfloat/e5-small" + +api_url = f"http://{host}:{port}/pooling" + +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] +prompt = {"model": model_name, "input": prompts, "task": "embed"} + +response = requests.post(api_url, json=prompt) + +for output in response.json()["data"]: + data = output["data"] + print(f"Data: {data!r} (size={len(data)})") +``` ## Matryoshka Embeddings @@ -265,3 +305,16 @@ Expected output: ``` An OpenAI client example can be found here: [examples/online_serving/pooling/openai_embedding_matryoshka_fy.py](../../examples/online_serving/pooling/openai_embedding_matryoshka_fy.py) + +## Deprecated Features + +### Encode task + +We have split the `encode` task into two more specific token wise tasks: `token_embed` and `token_classify`: + +- `token_embed` is the same as embed, using normalize as activation. +- `token_classify` is the same as classify, default using softmax as activation. + +### Remove softmax from PoolingParams + +We are going to remove `softmax` and `activation` from `PoolingParams`. Instead, you should set `use_activation`, since we actually allow `classify` and `token_classify` to use any activation function. diff --git a/docs/serving/openai_compatible_server.md b/docs/serving/openai_compatible_server.md index 1414718a697d..e331b3422ea6 100644 --- a/docs/serving/openai_compatible_server.md +++ b/docs/serving/openai_compatible_server.md @@ -638,7 +638,7 @@ Usually, the score for a sentence pair refers to the similarity between two sent You can find the documentation for cross encoder models at [sbert.net](https://www.sbert.net/docs/package_reference/cross_encoder/cross_encoder.html). -Code example: [examples/online_serving/openai_cross_encoder_score.py](../../examples/online_serving/openai_cross_encoder_score.py) +Code example: [examples/online_serving/pooling/openai_cross_encoder_score.py](../../examples/online_serving/pooling/openai_cross_encoder_score.py) #### Single inference @@ -819,7 +819,7 @@ You can pass multi-modal inputs to scoring models by passing `content` including print("Scoring output:", response_json["data"][0]["score"]) print("Scoring output:", response_json["data"][1]["score"]) ``` -Full example: [examples/online_serving/openai_cross_encoder_score_for_multimodal.py](../../examples/online_serving/openai_cross_encoder_score_for_multimodal.py) +Full example: [examples/online_serving/pooling/openai_cross_encoder_score_for_multimodal.py](../../examples/online_serving/pooling/openai_cross_encoder_score_for_multimodal.py) #### Extra parameters diff --git a/examples/offline_inference/pooling/README.md b/examples/offline_inference/pooling/README.md index cd9717122b16..ad78be38716b 100644 --- a/examples/offline_inference/pooling/README.md +++ b/examples/offline_inference/pooling/README.md @@ -38,6 +38,18 @@ python examples/offline_inference/pooling/multi_vector_retrieval.py python examples/offline_inference/pooling/ner.py ``` +## Prithvi Geospatial MAE usage + +```bash +python examples/offline_inference/pooling/prithvi_geospatial_mae.py +``` + +## IO Processor Plugins for Prithvi Geospatial MAE + +```bash +python examples/offline_inference/pooling/prithvi_geospatial_mae_io_processor.py +``` + ## Qwen3 reranker usage ```bash diff --git a/examples/offline_inference/pooling/ner.py b/examples/offline_inference/pooling/ner.py index b2dffdd6c5ee..34c80e7ccffd 100644 --- a/examples/offline_inference/pooling/ner.py +++ b/examples/offline_inference/pooling/ner.py @@ -33,7 +33,7 @@ def main(args: Namespace): label_map = llm.llm_engine.vllm_config.model_config.hf_config.id2label # Run inference - outputs = llm.encode(prompts) + outputs = llm.encode(prompts, pooling_task="token_classify") for prompt, output in zip(prompts, outputs): logits = output.outputs.data diff --git a/examples/offline_inference/prithvi_geospatial_mae.py b/examples/offline_inference/pooling/prithvi_geospatial_mae.py similarity index 100% rename from examples/offline_inference/prithvi_geospatial_mae.py rename to examples/offline_inference/pooling/prithvi_geospatial_mae.py diff --git a/examples/offline_inference/prithvi_geospatial_mae_io_processor.py b/examples/offline_inference/pooling/prithvi_geospatial_mae_io_processor.py similarity index 100% rename from examples/offline_inference/prithvi_geospatial_mae_io_processor.py rename to examples/offline_inference/pooling/prithvi_geospatial_mae_io_processor.py diff --git a/examples/online_serving/pooling/README.md b/examples/online_serving/pooling/README.md index 3b6da20d5f0f..b76ad21f0481 100644 --- a/examples/online_serving/pooling/README.md +++ b/examples/online_serving/pooling/README.md @@ -3,65 +3,95 @@ ## Cohere rerank usage ```bash +# vllm serve BAAI/bge-reranker-base python examples/online_serving/pooling/cohere_rerank_client.py ``` ## Embedding requests base64 encoding_format usage ```bash +# vllm serve intfloat/e5-small python examples/online_serving/pooling/embedding_requests_base64_client.py ``` ## Embedding requests bytes encoding_format usage ```bash +# vllm serve intfloat/e5-small python examples/online_serving/pooling/embedding_requests_bytes_client.py ``` ## Jinaai rerank usage ```bash +# vllm serve BAAI/bge-reranker-base python examples/online_serving/pooling/jinaai_rerank_client.py ``` ## Multi vector retrieval usage ```bash +# vllm serve BAAI/bge-m3 python examples/online_serving/pooling/multi_vector_retrieval_client.py ``` ## Named Entity Recognition (NER) usage ```bash +# vllm serve boltuix/NeuroBERT-NER python examples/online_serving/pooling/ner_client.py ``` -## Openai chat embedding for multimodal usage +## OpenAI chat embedding for multimodal usage ```bash python examples/online_serving/pooling/openai_chat_embedding_client_for_multimodal.py ``` -## Openai classification usage +## OpenAI classification usage ```bash +# vllm serve jason9693/Qwen2.5-1.5B-apeach python examples/online_serving/pooling/openai_classification_client.py ``` -## Openai embedding usage +## OpenAI cross_encoder score usage ```bash +# vllm serve BAAI/bge-reranker-v2-m3 +python examples/online_serving/pooling/openai_cross_encoder_score.py +``` + +## OpenAI cross_encoder score for multimodal usage + +```bash +# vllm serve jinaai/jina-reranker-m0 +python examples/online_serving/pooling/openai_cross_encoder_score_for_multimodal.py +``` + +## OpenAI embedding usage + +```bash +# vllm serve intfloat/e5-small python examples/online_serving/pooling/openai_embedding_client.py ``` -## Openai embedding matryoshka dimensions usage +## OpenAI embedding matryoshka dimensions usage ```bash +# vllm serve jinaai/jina-embeddings-v3 --trust-remote-code python examples/online_serving/pooling/openai_embedding_matryoshka_fy.py ``` -## Openai pooling usage +## OpenAI pooling usage ```bash +# vllm serve internlm/internlm2-1_8b-reward --trust-remote-code python examples/online_serving/pooling/openai_pooling_client.py ``` + +## Online Prithvi Geospatial MAE usage + +```bash +python examples/online_serving/pooling/prithvi_geospatial_mae.py +``` diff --git a/examples/online_serving/openai_cross_encoder_score.py b/examples/online_serving/pooling/openai_cross_encoder_score.py similarity index 100% rename from examples/online_serving/openai_cross_encoder_score.py rename to examples/online_serving/pooling/openai_cross_encoder_score.py diff --git a/examples/online_serving/openai_cross_encoder_score_for_multimodal.py b/examples/online_serving/pooling/openai_cross_encoder_score_for_multimodal.py similarity index 100% rename from examples/online_serving/openai_cross_encoder_score_for_multimodal.py rename to examples/online_serving/pooling/openai_cross_encoder_score_for_multimodal.py diff --git a/examples/online_serving/prithvi_geospatial_mae.py b/examples/online_serving/pooling/prithvi_geospatial_mae.py similarity index 100% rename from examples/online_serving/prithvi_geospatial_mae.py rename to examples/online_serving/pooling/prithvi_geospatial_mae.py diff --git a/tests/entrypoints/pooling/llm/test_classify.py b/tests/entrypoints/pooling/llm/test_classify.py index 96f634ee0a8c..1063c3b6b755 100644 --- a/tests/entrypoints/pooling/llm/test_classify.py +++ b/tests/entrypoints/pooling/llm/test_classify.py @@ -37,15 +37,17 @@ def llm(): @pytest.mark.skip_global_cleanup def test_pooling_params(llm: LLM): - def get_outputs(activation): + def get_outputs(use_activation): outputs = llm.classify( - prompts, pooling_params=PoolingParams(activation=activation), use_tqdm=False + prompts, + pooling_params=PoolingParams(use_activation=use_activation), + use_tqdm=False, ) return torch.tensor([x.outputs.probs for x in outputs]) - default = get_outputs(activation=None) - w_activation = get_outputs(activation=True) - wo_activation = get_outputs(activation=False) + default = get_outputs(use_activation=None) + w_activation = get_outputs(use_activation=True) + wo_activation = get_outputs(use_activation=False) assert torch.allclose(default, w_activation, atol=1e-2), ( "Default should use activation." diff --git a/tests/entrypoints/pooling/llm/test_reward.py b/tests/entrypoints/pooling/llm/test_reward.py index 81058dbad891..0255704cecd9 100644 --- a/tests/entrypoints/pooling/llm/test_reward.py +++ b/tests/entrypoints/pooling/llm/test_reward.py @@ -37,15 +37,17 @@ def llm(): def test_pooling_params(llm: LLM): - def get_outputs(activation): + def get_outputs(use_activation): outputs = llm.reward( - prompts, pooling_params=PoolingParams(activation=activation), use_tqdm=False + prompts, + pooling_params=PoolingParams(use_activation=use_activation), + use_tqdm=False, ) return torch.cat([x.outputs.data for x in outputs]) - default = get_outputs(activation=None) - w_activation = get_outputs(activation=True) - wo_activation = get_outputs(activation=False) + default = get_outputs(use_activation=None) + w_activation = get_outputs(use_activation=True) + wo_activation = get_outputs(use_activation=False) assert torch.allclose(default, w_activation, atol=1e-2), ( "Default should use activation." diff --git a/tests/entrypoints/pooling/llm/test_score.py b/tests/entrypoints/pooling/llm/test_score.py index 2df973dd7863..b69c6a47c191 100644 --- a/tests/entrypoints/pooling/llm/test_score.py +++ b/tests/entrypoints/pooling/llm/test_score.py @@ -34,21 +34,21 @@ def llm(): def test_pooling_params(llm: LLM): - def get_outputs(activation): + def get_outputs(use_activation): text_1 = "What is the capital of France?" text_2 = "The capital of France is Paris." outputs = llm.score( text_1, text_2, - pooling_params=PoolingParams(activation=activation), + pooling_params=PoolingParams(use_activation=use_activation), use_tqdm=False, ) return torch.tensor([x.outputs.score for x in outputs]) - default = get_outputs(activation=None) - w_activation = get_outputs(activation=True) - wo_activation = get_outputs(activation=False) + default = get_outputs(use_activation=None) + w_activation = get_outputs(use_activation=True) + wo_activation = get_outputs(use_activation=False) assert torch.allclose(default, w_activation, atol=1e-2), ( "Default should use activation." diff --git a/tests/entrypoints/pooling/openai/test_classification.py b/tests/entrypoints/pooling/openai/test_classification.py index 92d40efad21c..671bb948780a 100644 --- a/tests/entrypoints/pooling/openai/test_classification.py +++ b/tests/entrypoints/pooling/openai/test_classification.py @@ -7,7 +7,7 @@ import torch.nn.functional as F from tests.utils import RemoteOpenAIServer -from vllm.entrypoints.openai.protocol import ClassificationResponse +from vllm.entrypoints.openai.protocol import ClassificationResponse, PoolingResponse MODEL_NAME = "jason9693/Qwen2.5-1.5B-apeach" DTYPE = "float32" # Use float32 to avoid NaN issue @@ -163,20 +163,24 @@ async def test_invocations(server: RemoteOpenAIServer): @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) -async def test_activation(server: RemoteOpenAIServer, model_name: str): +async def test_use_activation(server: RemoteOpenAIServer, model_name: str): input_text = ["This product was excellent and exceeded my expectations"] - async def get_outputs(activation): + async def get_outputs(use_activation): response = requests.post( server.url_for("classify"), - json={"model": model_name, "input": input_text, "activation": activation}, + json={ + "model": model_name, + "input": input_text, + "use_activation": use_activation, + }, ) outputs = response.json() return torch.tensor([x["probs"] for x in outputs["data"]]) - default = await get_outputs(activation=None) - w_activation = await get_outputs(activation=True) - wo_activation = await get_outputs(activation=False) + default = await get_outputs(use_activation=None) + w_activation = await get_outputs(use_activation=True) + wo_activation = await get_outputs(use_activation=False) assert torch.allclose(default, w_activation, atol=1e-2), ( "Default should use activation." @@ -191,18 +195,7 @@ async def get_outputs(activation): @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) -def test_pooling(server: RemoteOpenAIServer, model_name: str): - # pooling api uses ALL pooling, which does not support chunked prefill. - response = requests.post( - server.url_for("pooling"), - json={"model": model_name, "input": "test", "encoding_format": "float"}, - ) - assert response.json()["error"]["type"] == "BadRequestError" - - -@pytest.mark.asyncio -@pytest.mark.parametrize("model_name", [MODEL_NAME]) -def test_score(server: RemoteOpenAIServer, model_name: str): +async def test_score(server: RemoteOpenAIServer, model_name: str): # score api is only enabled for num_labels == 1. response = requests.post( server.url_for("score"), @@ -217,7 +210,7 @@ def test_score(server: RemoteOpenAIServer, model_name: str): @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) -def test_rerank(server: RemoteOpenAIServer, model_name: str): +async def test_rerank(server: RemoteOpenAIServer, model_name: str): # rerank api is only enabled for num_labels == 1. response = requests.post( server.url_for("rerank"), @@ -228,3 +221,62 @@ def test_rerank(server: RemoteOpenAIServer, model_name: str): }, ) assert response.json()["error"]["type"] == "BadRequestError" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_pooling_classify(server: RemoteOpenAIServer, model_name: str): + input_text = "This product was excellent and exceeded my expectations" + response = requests.post( + server.url_for("pooling"), + json={ + "model": model_name, + "input": input_text, + "encoding_format": "float", + "task": "classify", + }, + ) + poolings = PoolingResponse.model_validate(response.json()) + assert len(poolings.data) == 1 + assert len(poolings.data[0].data) == 2 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_pooling_token_classify(server: RemoteOpenAIServer, model_name: str): + # token_classify uses ALL pooling, which does not support chunked prefill. + task = "token_classify" + response = requests.post( + server.url_for("pooling"), + json={ + "model": model_name, + "input": "test", + "encoding_format": "float", + "task": task, + }, + ) + assert response.json()["error"]["type"] == "BadRequestError" + assert response.json()["error"]["message"].startswith( + f"Task {task} is not supported" + ) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +@pytest.mark.parametrize("task", ["embed", "token_embed", "plugin"]) +async def test_pooling_not_supported( + server: RemoteOpenAIServer, model_name: str, task: str +): + response = requests.post( + server.url_for("pooling"), + json={ + "model": model_name, + "input": "test", + "encoding_format": "float", + "task": task, + }, + ) + assert response.json()["error"]["type"] == "BadRequestError" + assert response.json()["error"]["message"].startswith( + f"Task {task} is not supported" + ) diff --git a/tests/entrypoints/pooling/openai/test_embedding.py b/tests/entrypoints/pooling/openai/test_embedding.py index b3f12283fdbd..e971b23e8f1a 100644 --- a/tests/entrypoints/pooling/openai/test_embedding.py +++ b/tests/entrypoints/pooling/openai/test_embedding.py @@ -562,12 +562,40 @@ async def get_outputs(normalize): @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) -async def test_pooling(server: RemoteOpenAIServer, model_name: str): +async def test_pooling_embed(server: RemoteOpenAIServer, model_name: str): + task = "embed" input_text = ["The chef prepared a delicious meal."] response = requests.post( server.url_for("pooling"), - json={"model": model_name, "input": input_text, "encoding_format": "float"}, + json={ + "model": model_name, + "input": input_text, + "encoding_format": "float", + "task": task, + }, + ) + + poolings = PoolingResponse.model_validate(response.json()) + + assert len(poolings.data) == 1 + assert len(poolings.data[0].data) == 384 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_pooling_token_embed(server: RemoteOpenAIServer, model_name: str): + task = "token_embed" + input_text = ["The chef prepared a delicious meal."] + + response = requests.post( + server.url_for("pooling"), + json={ + "model": model_name, + "input": input_text, + "encoding_format": "float", + "task": task, + }, ) poolings = PoolingResponse.model_validate(response.json()) @@ -575,3 +603,24 @@ async def test_pooling(server: RemoteOpenAIServer, model_name: str): assert len(poolings.data) == 1 assert len(poolings.data[0].data) == 11 assert len(poolings.data[0].data[0]) == 384 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +@pytest.mark.parametrize("task", ["classify", "token_classify", "plugin"]) +async def test_pooling_not_supported( + server: RemoteOpenAIServer, model_name: str, task: str +): + response = requests.post( + server.url_for("pooling"), + json={ + "model": model_name, + "input": "test", + "encoding_format": "float", + "task": task, + }, + ) + assert response.json()["error"]["type"] == "BadRequestError" + assert response.json()["error"]["message"].startswith( + f"Task {task} is not supported" + ) diff --git a/tests/entrypoints/pooling/openai/test_rerank.py b/tests/entrypoints/pooling/openai/test_rerank.py index e43148d25fee..1d85190c12a1 100644 --- a/tests/entrypoints/pooling/openai/test_rerank.py +++ b/tests/entrypoints/pooling/openai/test_rerank.py @@ -125,8 +125,8 @@ def test_invocations(server: RemoteOpenAIServer): @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) -async def test_activation(server: RemoteOpenAIServer, model_name: str): - async def get_outputs(activation): +async def test_use_activation(server: RemoteOpenAIServer, model_name: str): + async def get_outputs(use_activation): query = "What is the capital of France?" documents = [ "The capital of Brazil is Brasilia.", @@ -139,16 +139,16 @@ async def get_outputs(activation): "model": model_name, "query": query, "documents": documents, - "activation": activation, + "use_activation": use_activation, }, ) outputs = response.json() return torch.tensor([x["relevance_score"] for x in outputs["results"]]) - default = await get_outputs(activation=None) - w_activation = await get_outputs(activation=True) - wo_activation = await get_outputs(activation=False) + default = await get_outputs(use_activation=None) + w_activation = await get_outputs(use_activation=True) + wo_activation = await get_outputs(use_activation=False) assert torch.allclose(default, w_activation, atol=1e-2), ( "Default should use activation." @@ -163,7 +163,25 @@ async def get_outputs(activation): @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) -async def test_pooling(server: RemoteOpenAIServer, model_name: str): +async def test_pooling_classify(server: RemoteOpenAIServer, model_name: str): + input_text = "This product was excellent and exceeded my expectations" + response = requests.post( + server.url_for("pooling"), + json={ + "model": model_name, + "input": input_text, + "encoding_format": "float", + "task": "classify", + }, + ) + poolings = PoolingResponse.model_validate(response.json()) + assert len(poolings.data) == 1 + assert len(poolings.data[0].data) == 1 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_pooling_token_classify(server: RemoteOpenAIServer, model_name: str): input_text = ["The chef prepared a delicious meal."] response = requests.post( @@ -176,3 +194,24 @@ async def test_pooling(server: RemoteOpenAIServer, model_name: str): assert len(poolings.data) == 1 assert len(poolings.data[0].data) == 11 assert len(poolings.data[0].data[0]) == 1 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +@pytest.mark.parametrize("task", ["embed", "token_embed", "plugin"]) +async def test_pooling_not_supported( + server: RemoteOpenAIServer, model_name: str, task: str +): + response = requests.post( + server.url_for("pooling"), + json={ + "model": model_name, + "input": "test", + "encoding_format": "float", + "task": task, + }, + ) + assert response.json()["error"]["type"] == "BadRequestError" + assert response.json()["error"]["message"].startswith( + f"Task {task} is not supported" + ) diff --git a/tests/entrypoints/pooling/openai/test_score.py b/tests/entrypoints/pooling/openai/test_score.py index ef213ab0ea18..b8f796d47efa 100644 --- a/tests/entrypoints/pooling/openai/test_score.py +++ b/tests/entrypoints/pooling/openai/test_score.py @@ -218,8 +218,8 @@ def test_invocations(self, server: RemoteOpenAIServer, model: dict[str, Any]): # TODO: reset this tolerance to 0.01 once we find # an alternative to flash_attn with bfloat16 - def test_activation(self, server: RemoteOpenAIServer, model: dict[str, Any]): - def get_outputs(activation): + def test_use_activation(self, server: RemoteOpenAIServer, model: dict[str, Any]): + def get_outputs(use_activation): text_1 = "What is the capital of France?" text_2 = "The capital of France is Paris." response = requests.post( @@ -228,7 +228,7 @@ def get_outputs(activation): "model": model["name"], "text_1": text_1, "text_2": text_2, - "activation": activation, + "use_activation": use_activation, }, ) if response.status_code != 200: @@ -238,9 +238,9 @@ def get_outputs(activation): return torch.tensor([x["score"] for x in outputs["data"]]) if model["is_cross_encoder"]: - default = get_outputs(activation=None) - w_activation = get_outputs(activation=True) - wo_activation = get_outputs(activation=False) + default = get_outputs(use_activation=None) + w_activation = get_outputs(use_activation=True) + wo_activation = get_outputs(use_activation=False) assert torch.allclose(default, w_activation, atol=1e-2), ( "Default should use activation." @@ -252,8 +252,8 @@ def get_outputs(activation): "w_activation should be close to activation(wo_activation)." ) else: - get_outputs(activation=None) + get_outputs(use_activation=None) # The activation parameter only works for the is_cross_encoder model - response = get_outputs(activation=True) + response = get_outputs(use_activation=True) assert response.status_code == 400 diff --git a/tests/models/language/pooling/test_pooler_config_init_behaviour.py b/tests/models/language/pooling/test_pooler_config_init_behaviour.py index 55663ee3f1b4..deb5de984d90 100644 --- a/tests/models/language/pooling/test_pooler_config_init_behaviour.py +++ b/tests/models/language/pooling/test_pooler_config_init_behaviour.py @@ -24,7 +24,7 @@ def test_classify_models_using_activation( model, max_model_len=512, dtype=dtype, - pooler_config=PoolerConfig(activation=False), + pooler_config=PoolerConfig(use_activation=False), ) as vllm_model: wo_activation_out = vllm_model.classify(example_prompts) @@ -32,7 +32,7 @@ def test_classify_models_using_activation( model, max_model_len=512, dtype=dtype, - pooler_config=PoolerConfig(activation=True), + pooler_config=PoolerConfig(use_activation=True), ) as vllm_model: w_activation_out = vllm_model.classify(example_prompts) @@ -104,7 +104,7 @@ def test_reward_models_using_activation( model, max_model_len=1024, dtype=dtype, - pooler_config=PoolerConfig(activation=False), + pooler_config=PoolerConfig(use_activation=False), ) as vllm_model: wo_activation = vllm_model.reward(example_prompts) @@ -112,7 +112,7 @@ def test_reward_models_using_activation( model, max_model_len=1024, dtype=dtype, - pooler_config=PoolerConfig(activation=True), + pooler_config=PoolerConfig(use_activation=True), ) as vllm_model: w_activation = vllm_model.reward(example_prompts) diff --git a/tests/test_pooling_params.py b/tests/test_pooling_params.py index e73d7efc1483..7812562c8948 100644 --- a/tests/test_pooling_params.py +++ b/tests/test_pooling_params.py @@ -17,7 +17,7 @@ ), ] -classify_parameters = ["activation"] +classify_parameters = ["use_activation"] embed_parameters = ["dimensions", "normalize"] step_pooling_parameters = ["step_tag_id", "returned_token_ids"] @@ -88,13 +88,13 @@ def test_embed_dimensions(model_info: EmbedModelInfo): def test_classify(task): model_config = MockModelConfig(pooler_config=PoolerConfig(pooling_type="CLS")) - pooling_params = PoolingParams(activation=None) + pooling_params = PoolingParams(use_activation=None) pooling_params.verify(task=task, model_config=model_config) - pooling_params = PoolingParams(activation=True) + pooling_params = PoolingParams(use_activation=True) pooling_params.verify(task=task, model_config=model_config) - pooling_params = PoolingParams(activation=False) + pooling_params = PoolingParams(use_activation=False) pooling_params.verify(task=task, model_config=model_config) invalid_parameters = embed_parameters + step_pooling_parameters @@ -137,13 +137,13 @@ def test_token_classify(pooling_type: str): pooler_config=PoolerConfig(pooling_type=pooling_type) ) - pooling_params = PoolingParams(activation=None) + pooling_params = PoolingParams(use_activation=None) pooling_params.verify(task=task, model_config=model_config) - pooling_params = PoolingParams(activation=True) + pooling_params = PoolingParams(use_activation=True) pooling_params.verify(task=task, model_config=model_config) - pooling_params = PoolingParams(activation=False) + pooling_params = PoolingParams(use_activation=False) pooling_params.verify(task=task, model_config=model_config) invalid_parameters = embed_parameters diff --git a/vllm/config/pooler.py b/vllm/config/pooler.py index 0590f74aa4c9..6bece8d0785b 100644 --- a/vllm/config/pooler.py +++ b/vllm/config/pooler.py @@ -7,6 +7,9 @@ from pydantic.dataclasses import dataclass from vllm.config.utils import config +from vllm.logger import init_logger + +logger = init_logger(__name__) @config @@ -48,7 +51,15 @@ class PoolerConfig: """ ## for classification models - activation: bool | None = None + softmax: float | None = None + """ + softmax will be deprecated, please use use_activation instead. + """ + activation: float | None = None + """ + activation will be deprecated, please use use_activation instead. + """ + use_activation: bool | None = None """ Whether to apply activation function to the classification outputs. Defaults to True. @@ -59,11 +70,6 @@ class PoolerConfig: """ ## for reward models - softmax: bool | None = None - """ - Whether to apply softmax to the reward outputs. - Defaults to True. - """ step_tag_id: int | None = None """ If set, only the score corresponding to the `step_tag_id` in the @@ -77,6 +83,10 @@ class PoolerConfig: `math-shepherd-mistral-7b-prm` model. """ + def __post_init__(self): + # raise deprecated warning for softmax and activation + self.use_activation = get_use_activation(self) + def compute_hash(self) -> str: """ WARNING: Whenever a new field is added to this config, @@ -94,3 +104,19 @@ def compute_hash(self) -> str: factors: list[Any] = [] hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest() return hash_str + + +def get_use_activation(o: object): + if softmax := getattr(o, "softmax", None) is not None: + logger.warning_once( + "softmax will be deprecated, please use use_activation instead." + ) + return softmax + + if activation := getattr(o, "activation", None) is not None: + logger.warning_once( + "activation will be deprecated, please use use_activation instead." + ) + return activation + + return getattr(o, "use_activation", None) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 71939d6c41df..f3aa5351e530 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -107,6 +107,7 @@ ) from vllm.logger import init_logger from vllm.reasoning import ReasoningParserManager +from vllm.tasks import POOLING_TASKS from vllm.usage.usage_lib import UsageContext from vllm.utils.argparse_utils import FlexibleArgumentParser from vllm.utils.network_utils import is_valid_ipv6_address @@ -1748,12 +1749,7 @@ async def init_app_state( log_error_stack=args.log_error_stack, ) ) - if ( - any( - task in supported_tasks - for task in ["token_embed", "token_classify", "plugin"] - ) - ) + if any(task in POOLING_TASKS for task in supported_tasks) else None ) state.openai_serving_embedding = ( diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 0778e4d78790..d0061f9d5b40 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -49,6 +49,8 @@ ) from openai_harmony import Message as OpenAIHarmonyMessage +from vllm.config.pooler import get_use_activation +from vllm.tasks import PoolingTask from vllm.utils.serial_utils import ( EmbedDType, EncodingFormat, @@ -1669,8 +1671,58 @@ def to_pooling_params(self): EmbeddingRequest: TypeAlias = EmbeddingCompletionRequest | EmbeddingChatRequest -PoolingCompletionRequest = EmbeddingCompletionRequest -PoolingChatRequest = EmbeddingChatRequest + +class PoolingCompletionRequest(EmbeddingCompletionRequest): + task: PoolingTask | None = None + softmax: bool | None = Field( + default=None, + description="softmax will be deprecated, please use use_activation instead.", + ) + activation: bool | None = Field( + default=None, + description="activation will be deprecated, please use use_activation instead.", + ) + use_activation: bool | None = Field( + default=None, + description="Whether to use activation for classification outputs. " + "If it is a classify or token_classify task, the default is True; " + "for other tasks, this value should be None.", + ) + + def to_pooling_params(self): + return PoolingParams( + truncate_prompt_tokens=self.truncate_prompt_tokens, + dimensions=self.dimensions, + normalize=self.normalize, + use_activation=get_use_activation(self), + ) + + +class PoolingChatRequest(EmbeddingChatRequest): + task: PoolingTask | None = None + softmax: bool | None = Field( + default=None, + description="softmax will be deprecated, please use use_activation instead.", + ) + activation: bool | None = Field( + default=None, + description="activation will be deprecated, please use use_activation instead.", + ) + use_activation: bool | None = Field( + default=None, + description="Whether to use activation for classification outputs. " + "If it is a classify or token_classify task, the default is True; " + "for other tasks, this value should be None.", + ) + + def to_pooling_params(self): + return PoolingParams( + truncate_prompt_tokens=self.truncate_prompt_tokens, + dimensions=self.dimensions, + normalize=self.normalize, + use_activation=get_use_activation(self), + ) + T = TypeVar("T") @@ -1686,6 +1738,7 @@ class IOProcessorRequest(OpenAIBaseModel, Generic[T]): """ data: T + task: PoolingTask = "plugin" encoding_format: EncodingFormat = "float" embed_dtype: EmbedDType = Field( default="float32", @@ -1749,14 +1802,27 @@ class ScoreRequest(OpenAIBaseModel): ), ) - activation: bool | None = None + softmax: bool | None = Field( + default=None, + description="softmax will be deprecated, please use use_activation instead.", + ) + activation: bool | None = Field( + default=None, + description="activation will be deprecated, please use use_activation instead.", + ) + + use_activation: bool | None = Field( + default=None, + description="Whether to use activation for classification outputs. " + "Default is True.", + ) # --8<-- [end:score-extra-params] def to_pooling_params(self): return PoolingParams( truncate_prompt_tokens=self.truncate_prompt_tokens, - activation=self.activation, + use_activation=get_use_activation(self), ) @@ -1783,14 +1849,27 @@ class RerankRequest(OpenAIBaseModel): ), ) - activation: bool | None = None + softmax: bool | None = Field( + default=None, + description="softmax will be deprecated, please use use_activation instead.", + ) + activation: bool | None = Field( + default=None, + description="activation will be deprecated, please use use_activation instead.", + ) + + use_activation: bool | None = Field( + default=None, + description="Whether to use activation for classification outputs. " + "Default is True.", + ) # --8<-- [end:rerank-extra-params] def to_pooling_params(self): return PoolingParams( truncate_prompt_tokens=self.truncate_prompt_tokens, - activation=self.activation, + use_activation=get_use_activation(self), ) @@ -1958,14 +2037,27 @@ class ClassificationRequest(OpenAIBaseModel): ), ) - activation: bool | None = None + softmax: bool | None = Field( + default=None, + description="softmax will be deprecated, please use use_activation instead.", + ) + + activation: bool | None = Field( + default=None, + description="activation will be deprecated, please use use_activation instead.", + ) + use_activation: bool | None = Field( + default=None, + description="Whether to use activation for classification outputs. " + "Default is True.", + ) # --8<-- [end:classification-extra-params] def to_pooling_params(self): return PoolingParams( truncate_prompt_tokens=self.truncate_prompt_tokens, - activation=self.activation, + use_activation=get_use_activation(self), ) diff --git a/vllm/entrypoints/openai/serving_pooling.py b/vllm/entrypoints/openai/serving_pooling.py index 568896ccbf1b..0eade272111f 100644 --- a/vllm/entrypoints/openai/serving_pooling.py +++ b/vllm/entrypoints/openai/serving_pooling.py @@ -170,15 +170,24 @@ async def create_pooling( pooling_params = request.to_pooling_params() pooling_task: PoolingTask - if "token_embed" in self.supported_tasks: - pooling_task = "token_embed" - elif "token_classify" in self.supported_tasks: - pooling_task = "token_classify" - elif "plugin" in self.supported_tasks: - pooling_task = "plugin" + if request.task is None: + if "token_embed" in self.supported_tasks: + pooling_task = "token_embed" + elif "token_classify" in self.supported_tasks: + pooling_task = "token_classify" + elif "plugin" in self.supported_tasks: + pooling_task = "plugin" + else: + return self.create_error_response( + f"pooling_task must be one of {self.supported_tasks}." + ) else: + pooling_task = request.task + + if pooling_task not in self.supported_tasks: return self.create_error_response( - f"pooling_task must be one of {self.supported_tasks}." + f"Task {pooling_task} is not supported, it" + f" must be one of {self.supported_tasks}." ) try: diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py index 145f18f23566..7dd02e32ff21 100644 --- a/vllm/model_executor/layers/pooler.py +++ b/vllm/model_executor/layers/pooler.py @@ -607,7 +607,7 @@ def forward( pooled_data -= self.logit_bias pooling_params = get_pooling_params(pooling_metadata) - flags = [p.activation for p in pooling_params] + flags = [p.use_activation for p in pooling_params] if len(set(flags)) == 1: scores = self.act_fn(pooled_data) if flags[0] else pooled_data @@ -681,7 +681,7 @@ def forward( if self.logit_bias is not None: scores -= self.logit_bias - if pooling_param.activation: + if pooling_param.use_activation: scores = self.act_fn(scores) # scores shape: [n_token, num_labels] diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index ac5949cda9de..3bd02121f018 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -53,8 +53,8 @@ class JambaForSequenceClassificationConfig(VerifyAndUpdateConfig): @staticmethod def verify_and_update_config(vllm_config: "VllmConfig") -> None: pooler_config = vllm_config.model_config.pooler_config - if pooler_config.activation is None: - pooler_config.activation = False + if pooler_config.use_activation is None: + pooler_config.use_activation = False class JinaRobertaModelConfig(VerifyAndUpdateConfig): diff --git a/vllm/pooling_params.py b/vllm/pooling_params.py index 090d92414465..72a8320cc1bf 100644 --- a/vllm/pooling_params.py +++ b/vllm/pooling_params.py @@ -2,16 +2,15 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from copy import deepcopy -from typing import TYPE_CHECKING, Annotated, Any, Optional +from typing import Annotated, Any, Optional import msgspec +from vllm.config import ModelConfig, PoolerConfig +from vllm.config.pooler import get_use_activation from vllm.sampling_params import RequestOutputKind from vllm.tasks import PoolingTask -if TYPE_CHECKING: - from vllm.config import ModelConfig, PoolerConfig - class PoolingParams( msgspec.Struct, @@ -25,10 +24,12 @@ class PoolingParams( Set to -1 to use the model's default truncation size. Set to k to keep only the last k tokens (left truncation). Set to None to disable truncation. - normalize: Whether to normalize the embeddings outputs. dimensions: Reduce the dimensions of embeddings if model support matryoshka representation. - activation: Whether to apply activation function to + normalize: Whether to normalize the embeddings outputs. + softmax: softmax will be deprecated, please use use_activation instead. + activation: activation will be deprecated, please use use_activation instead. + use_activation: Whether to apply activation function to the classification outputs. """ @@ -44,7 +45,9 @@ class PoolingParams( ## for classification, scoring and rerank # --8<-- [start:classification-pooling-params] + softmax: bool | None = None activation: bool | None = None + use_activation: bool | None = None # --8<-- [end:classification-pooling-params] ## for step pooling models @@ -59,16 +62,16 @@ class PoolingParams( @property def all_parameters(self) -> list[str]: - return ["dimensions", "normalize", "activation"] + return ["dimensions", "normalize", "use_activation"] @property def valid_parameters(self): return { "embed": ["dimensions", "normalize"], - "classify": ["activation"], - "score": ["activation"], + "classify": ["use_activation"], + "score": ["use_activation"], "token_embed": ["dimensions", "normalize"], - "token_classify": ["activation"], + "token_classify": ["use_activation"], } def clone(self) -> "PoolingParams": @@ -84,6 +87,9 @@ def verify( msg = f"You cannot overwrite {self.task=!r} with {task=!r}!" raise ValueError(msg) + # raise deprecated warning for softmax and activation + self.use_activation = get_use_activation(self) + # plugin task uses io_processor.parse_request to verify inputs, # skipping PoolingParams verify if self.task == "plugin": @@ -168,8 +174,8 @@ def _set_default_parameters(self, model_config: Optional["ModelConfig"]): raise ValueError("Dimensions must be greater than 0") elif self.task in ["classify", "score", "token_classify"]: - if self.activation is None: - self.activation = True + if self.use_activation is None: + self.use_activation = True else: raise ValueError(f"Unknown pooling task: {self.task}") @@ -197,7 +203,7 @@ def __repr__(self) -> str: f"task={self.task}, " f"normalize={self.normalize}, " f"dimensions={self.dimensions}, " - f"activation={self.activation}, " + f"use_activation={self.use_activation}, " f"step_tag_id={self.step_tag_id}, " f"returned_token_ids={self.returned_token_ids}, " f"requires_token_ids={self.requires_token_ids}, "