diff --git a/docs/source/models/pooling_models.md b/docs/source/models/pooling_models.md index 8c8d1832d382..3fd35e2e8bd1 100644 --- a/docs/source/models/pooling_models.md +++ b/docs/source/models/pooling_models.md @@ -140,6 +140,7 @@ Our [OpenAI-Compatible Server](#openai-compatible-server) provides endpoints tha - [Pooling API](#pooling-api) is similar to `LLM.encode`, being applicable to all types of pooling models. - [Embeddings API](#embeddings-api) is similar to `LLM.embed`, accepting both text and [multi-modal inputs](#multimodal-inputs) for embedding models. +- [Classification API](#classification-api) is similar to `LLM.classify` and is applicable to sequence classification models. - [Score API](#score-api) is similar to `LLM.score` for cross-encoder models. ## Matryoshka Embeddings diff --git a/docs/source/serving/openai_compatible_server.md b/docs/source/serving/openai_compatible_server.md index 34382c87a484..07bd211c2375 100644 --- a/docs/source/serving/openai_compatible_server.md +++ b/docs/source/serving/openai_compatible_server.md @@ -61,6 +61,8 @@ In addition, we have the following custom APIs: - Applicable to any model with a tokenizer. - [Pooling API](#pooling-api) (`/pooling`) - Applicable to all [pooling models](../models/pooling_models.md). +- [Classification API](#classification-api) (`/classify`) + - Only applicable to [classification models](../models/pooling_models.md) (`--task classify`). - [Score API](#score-api) (`/score`) - Applicable to embedding models and [cross-encoder models](../models/pooling_models.md) (`--task score`). 
- [Re-rank API](#rerank-api) (`/rerank`, `/v1/rerank`, `/v2/rerank`) @@ -443,6 +445,130 @@ The input format is the same as [Embeddings API](#embeddings-api), but the outpu Code example: +(classification-api)= + +### Classification API + +Our Classification API directly supports Hugging Face sequence-classification models such as [ai21labs/Jamba-tiny-reward-dev](https://huggingface.co/ai21labs/Jamba-tiny-reward-dev) and [jason9693/Qwen2.5-1.5B-apeach](https://huggingface.co/jason9693/Qwen2.5-1.5B-apeach). + +We automatically wrap any other transformer via `as_classification_model()`, which pools on the last token, attaches a `RowParallelLinear` head, and applies a softmax to produce per-class probabilities. + +Code example: <gh-file:examples/online_serving/openai_classification_client.py> + +#### Example Requests + +You can classify multiple texts by passing an array of strings: + +Request: + +```bash +curl -v "http://127.0.0.1:8000/classify" \ + -H "Content-Type: application/json" \ + -d '{ + "model": "jason9693/Qwen2.5-1.5B-apeach", + "input": [ + "Loved the new café—coffee was great.", + "This update broke everything. Frustrating." + ] + }' +``` + +Response: + +```json +{ + "id": "classify-7c87cac407b749a6935d8c7ce2a8fba2", + "object": "list", + "created": 1745383065, + "model": "jason9693/Qwen2.5-1.5B-apeach", + "data": [ + { + "index": 0, + "label": "Default", + "probs": [ + 0.565970778465271, + 0.4340292513370514 + ], + "num_classes": 2 + }, + { + "index": 1, + "label": "Spoiled", + "probs": [ + 0.26448777318000793, + 0.7355121970176697 + ], + "num_classes": 2 + } + ], + "usage": { + "prompt_tokens": 20, + "total_tokens": 20, + "completion_tokens": 0, + "prompt_tokens_details": null + } +} +``` + +You can also pass a string directly to the `input` field: + +Request: + +```bash +curl -v "http://127.0.0.1:8000/classify" \ + -H "Content-Type: application/json" \ + -d '{ + "model": "jason9693/Qwen2.5-1.5B-apeach", + "input": "Loved the new café—coffee was great." 
+ }' +``` + +Response: + +```json +{ + "id": "classify-9bf17f2847b046c7b2d5495f4b4f9682", + "object": "list", + "created": 1745383213, + "model": "jason9693/Qwen2.5-1.5B-apeach", + "data": [ + { + "index": 0, + "label": "Default", + "probs": [ + 0.565970778465271, + 0.4340292513370514 + ], + "num_classes": 2 + } + ], + "usage": { + "prompt_tokens": 10, + "total_tokens": 10, + "completion_tokens": 0, + "prompt_tokens_details": null + } +} +``` + +#### Extra parameters + +The following [pooling parameters](#pooling-params) are supported. + +:::{literalinclude} ../../../vllm/entrypoints/openai/protocol.py +:language: python +:start-after: begin-classification-pooling-params +:end-before: end-classification-pooling-params +::: + +The following extra parameters are supported: + +:::{literalinclude} ../../../vllm/entrypoints/openai/protocol.py +:language: python +:start-after: begin-classification-extra-params +:end-before: end-classification-extra-params +::: + (score-api)= ### Score API diff --git a/examples/online_serving/openai_classification_client.py b/examples/online_serving/openai_classification_client.py new file mode 100644 index 000000000000..99241346373e --- /dev/null +++ b/examples/online_serving/openai_classification_client.py @@ -0,0 +1,49 @@ +# SPDX-License-Identifier: Apache-2.0 + +import argparse +import pprint + +import requests + + +def post_http_request(payload: dict, api_url: str) -> requests.Response: + headers = {"User-Agent": "Test Client"} + response = requests.post(api_url, headers=headers, json=payload) + return response + + +def parse_args(): + parse = argparse.ArgumentParser() + parse.add_argument("--host", type=str, default="localhost") + parse.add_argument("--port", type=int, default=8000) + parse.add_argument("--model", + type=str, + default="jason9693/Qwen2.5-1.5B-apeach") + return parse.parse_args() + + +def main(args): + host = args.host + port = args.port + model_name = args.model + + api_url = f"http://{host}:{port}/classify" + 
prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + + payload = { + "model": model_name, + "input": prompts, + } + + classify_response = post_http_request(payload=payload, api_url=api_url) + pprint.pprint(classify_response.json()) + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/tests/entrypoints/openai/test_classification.py b/tests/entrypoints/openai/test_classification.py new file mode 100644 index 000000000000..97124c85e0d3 --- /dev/null +++ b/tests/entrypoints/openai/test_classification.py @@ -0,0 +1,156 @@ +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import requests + +from vllm.entrypoints.openai.protocol import ClassificationResponse + +from ...utils import RemoteOpenAIServer + +MODEL_NAME = "jason9693/Qwen2.5-1.5B-apeach" +DTYPE = "float32" # Use float32 to avoid NaN issue + + +@pytest.fixture(scope="module") +def server(): + args = [ + "--enforce-eager", + "--max-model-len", + "512", + "--dtype", + DTYPE, + ] + + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + yield remote_server + + +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +def test_single_input_classification(server: RemoteOpenAIServer, + model_name: str): + input_text = "This product was excellent and exceeded my expectations" + + classification_response = requests.post( + server.url_for("classify"), + json={ + "model": model_name, + "input": input_text + }, + ) + + classification_response.raise_for_status() + output = ClassificationResponse.model_validate( + classification_response.json()) + + assert output.object == "list" + assert output.model == MODEL_NAME + assert len(output.data) == 1 + assert hasattr(output.data[0], "label") + assert hasattr(output.data[0], "probs") + + +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +def test_multiple_inputs_classification(server: RemoteOpenAIServer, + model_name: str): + input_texts = [ + "The 
product arrived on time and works perfectly", + "I'm very satisfied with my purchase, would buy again", + "The customer service was helpful and resolved my issue quickly", + "This product broke after one week, terrible quality", + "I'm very disappointed with this purchase, complete waste of money", + "The customer service was rude and unhelpful", + ] + + classification_response = requests.post( + server.url_for("classify"), + json={ + "model": model_name, + "input": input_texts + }, + ) + output = ClassificationResponse.model_validate( + classification_response.json()) + + assert len(output.data) == len(input_texts) + for i, item in enumerate(output.data): + assert item.index == i + assert hasattr(item, "label") + assert hasattr(item, "probs") + assert len(item.probs) == item.num_classes + assert item.label in ["Default", "Spoiled"] + + +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +def test_truncate_prompt_tokens(server: RemoteOpenAIServer, model_name: str): + long_text = "hello " * 600 + + classification_response = requests.post( + server.url_for("classify"), + json={ + "model": model_name, + "input": long_text, + "truncate_prompt_tokens": 5 + }, + ) + + classification_response.raise_for_status() + output = ClassificationResponse.model_validate( + classification_response.json()) + + assert len(output.data) == 1 + assert output.data[0].index == 0 + assert hasattr(output.data[0], "probs") + assert output.usage.prompt_tokens == 5 + assert output.usage.total_tokens == 5 + + +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +def test_invalid_truncate_prompt_tokens_error(server: RemoteOpenAIServer, + model_name: str): + classification_response = requests.post( + server.url_for("classify"), + json={ + "model": model_name, + "input": "test", + "truncate_prompt_tokens": 513 + }, + ) + + error = classification_response.json() + assert classification_response.status_code == 400 + assert error["object"] == "error" + assert "truncate_prompt_tokens" in 
error["message"] + + +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +def test_empty_input_error(server: RemoteOpenAIServer, model_name: str): + classification_response = requests.post( + server.url_for("classify"), + json={ + "model": model_name, + "input": "" + }, + ) + + error = classification_response.json() + assert classification_response.status_code == 400 + assert error["object"] == "error" + + +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +def test_batch_classification_empty_list(server: RemoteOpenAIServer, + model_name: str): + classification_response = requests.post( + server.url_for("classify"), + json={ + "model": model_name, + "input": [] + }, + ) + classification_response.raise_for_status() + output = ClassificationResponse.model_validate( + classification_response.json()) + + assert output.object == "list" + assert isinstance(output.data, list) + assert len(output.data) == 0 diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index e034eacb24ef..3c6852b41b98 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -48,6 +48,8 @@ # yapf: disable from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, ChatCompletionResponse, + ClassificationRequest, + ClassificationResponse, CompletionRequest, CompletionResponse, DetokenizeRequest, @@ -71,6 +73,8 @@ UnloadLoRAAdapterRequest) # yapf: enable from vllm.entrypoints.openai.serving_chat import OpenAIServingChat +from vllm.entrypoints.openai.serving_classification import ( + ServingClassification) from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding from vllm.entrypoints.openai.serving_engine import OpenAIServing @@ -373,6 +377,10 @@ def score(request: Request) -> Optional[ServingScores]: return request.app.state.openai_serving_scores +def classify(request: Request) -> Optional[ServingClassification]: 
+ return request.app.state.openai_serving_classification + + def rerank(request: Request) -> Optional[ServingScores]: return request.app.state.openai_serving_scores @@ -405,6 +413,7 @@ async def get_server_load_metrics(request: Request): # - /v1/audio/transcriptions # - /v1/embeddings # - /pooling + # - /classify # - /score # - /v1/score # - /rerank @@ -572,6 +581,27 @@ async def create_pooling(request: PoolingRequest, raw_request: Request): assert_never(generator) +@router.post("/classify", dependencies=[Depends(validate_json_request)]) +@with_cancellation +@load_aware_call +async def create_classify(request: ClassificationRequest, + raw_request: Request): + handler = classify(raw_request) + if handler is None: + return base(raw_request).create_error_response( + message="The model does not support Classification API") + + generator = await handler.create_classify(request, raw_request) + if isinstance(generator, ErrorResponse): + return JSONResponse(content=generator.model_dump(), + status_code=generator.code) + + elif isinstance(generator, ClassificationResponse): + return JSONResponse(content=generator.model_dump()) + + assert_never(generator) + + @router.post("/score", dependencies=[Depends(validate_json_request)]) @with_cancellation @load_aware_call @@ -1001,6 +1031,12 @@ async def init_app_state( state.openai_serving_models, request_logger=request_logger) if model_config.task in ( "score", "embed", "pooling") else None + state.openai_serving_classification = ServingClassification( + engine_client, + model_config, + state.openai_serving_models, + request_logger=request_logger, + ) if model_config.task == "classify" else None state.jinaai_serving_reranking = ServingScores( engine_client, model_config, diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index aa01e785f21a..4e09240f23af 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -1292,6 +1292,47 @@ class 
ScoreResponse(OpenAIBaseModel): usage: UsageInfo +class ClassificationRequest(OpenAIBaseModel): + model: Optional[str] = None + input: Union[list[str], str] + truncate_prompt_tokens: Optional[int] = None + user: Optional[str] = None + + # doc: begin-classification-pooling-params + additional_data: Optional[Any] = None + # doc: end-classification-pooling-params + + # doc: begin-classification-extra-params + priority: int = Field( + default=0, + description=( + "The priority of the request (lower means earlier handling; " + "default: 0). Any priority other than 0 will raise an error " + "if the served model does not use priority scheduling."), + ) + + # doc: end-classification-extra-params + + def to_pooling_params(self): + return PoolingParams(additional_data=self.additional_data) + + +class ClassificationData(OpenAIBaseModel): + index: int + label: Optional[str] + probs: list[float] + num_classes: int + + +class ClassificationResponse(OpenAIBaseModel): + id: str = Field(default_factory=lambda: f"classify-{random_uuid()}") + object: str = "list" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + data: list[ClassificationData] + usage: UsageInfo + + class FunctionCall(OpenAIBaseModel): name: str arguments: str diff --git a/vllm/entrypoints/openai/serving_classification.py b/vllm/entrypoints/openai/serving_classification.py new file mode 100644 index 000000000000..90cdd389d59f --- /dev/null +++ b/vllm/entrypoints/openai/serving_classification.py @@ -0,0 +1,159 @@ +# SPDX-License-Identifier: Apache-2.0 + +from http import HTTPStatus +from typing import Optional, Union, cast + +import numpy as np +from fastapi import Request + +from vllm.config import ModelConfig +from vllm.engine.protocol import EngineClient +from vllm.entrypoints.logger import RequestLogger +from vllm.entrypoints.openai.protocol import (ClassificationData, + ClassificationRequest, + ClassificationResponse, + ErrorResponse, UsageInfo) +# yapf: enable +from 
vllm.entrypoints.openai.serving_engine import (ClassificationServeContext, + OpenAIServing, + ServeContext) +from vllm.entrypoints.openai.serving_models import OpenAIServingModels +from vllm.logger import init_logger +from vllm.outputs import ClassificationOutput, PoolingRequestOutput + +logger = init_logger(__name__) + + +class ClassificationMixin(OpenAIServing): + + async def _preprocess( + self, + ctx: ServeContext, + ) -> Optional[ErrorResponse]: + """ + Process classification inputs: tokenize text, resolve adapters, + and prepare model-specific inputs. + """ + ctx = cast(ClassificationServeContext, ctx) + if isinstance(ctx.request.input, str) and not ctx.request.input: + return self.create_error_response( + "Input cannot be empty for classification", + status_code=HTTPStatus.BAD_REQUEST, + ) + + if isinstance(ctx.request.input, list) and len(ctx.request.input) == 0: + return None + + try: + ( + ctx.lora_request, + ctx.prompt_adapter_request, + ) = self._maybe_get_adapters(ctx.request) + + ctx.tokenizer = await self.engine_client.get_tokenizer( + ctx.lora_request) + + if ctx.prompt_adapter_request is not None: + raise NotImplementedError( + "Prompt adapter is not supported for classification models" + ) + + ( + ctx.request_prompts, + ctx.engine_prompts, + ) = await self._preprocess_completion( + ctx.request, + ctx.tokenizer, + ctx.request.input, + truncate_prompt_tokens=ctx.request.truncate_prompt_tokens, + ) + + return None + + except (ValueError, TypeError) as e: + logger.exception("Error in preprocessing prompt inputs") + return self.create_error_response(str(e)) + + def _build_response( + self, + ctx: ServeContext, + ) -> Union[ClassificationResponse, ErrorResponse]: + """ + Convert model outputs to a formatted classification response + with probabilities and labels. 
+ """ + ctx = cast(ClassificationServeContext, ctx) + items: list[ClassificationData] = [] + num_prompt_tokens = 0 + + final_res_batch_checked = cast(list[PoolingRequestOutput], + ctx.final_res_batch) + + for idx, final_res in enumerate(final_res_batch_checked): + classify_res = ClassificationOutput.from_base(final_res.outputs) + + probs = classify_res.probs + predicted_index = int(np.argmax(probs)) + label = getattr(self.model_config.hf_config, "id2label", + {}).get(predicted_index) + + item = ClassificationData( + index=idx, + label=label, + probs=probs, + num_classes=len(probs), + ) + + items.append(item) + prompt_token_ids = final_res.prompt_token_ids + num_prompt_tokens += len(prompt_token_ids) + + usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + total_tokens=num_prompt_tokens, + ) + + return ClassificationResponse( + id=ctx.request_id, + created=ctx.created_time, + model=ctx.model_name, + data=items, + usage=usage, + ) + + +class ServingClassification(ClassificationMixin): + request_id_prefix = "classify" + + def __init__( + self, + engine_client: EngineClient, + model_config: ModelConfig, + models: OpenAIServingModels, + *, + request_logger: Optional[RequestLogger], + ) -> None: + super().__init__( + engine_client=engine_client, + model_config=model_config, + models=models, + request_logger=request_logger, + ) + + async def create_classify( + self, + request: ClassificationRequest, + raw_request: Request, + ) -> Union[ClassificationResponse, ErrorResponse]: + model_name = self._get_model_name(request.model) + request_id = (f"{self.request_id_prefix}-" + f"{self._base_request_id(raw_request)}") + + ctx = ClassificationServeContext( + request=request, + raw_request=raw_request, + model_name=model_name, + request_id=request_id, + ) + + return await super().handle(ctx) # type: ignore diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index 4b4d2d8b76f4..3785d2642f9d 100644 --- 
a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -1,14 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 -import asyncio import base64 -import time -from collections.abc import AsyncGenerator from typing import Final, Literal, Optional, Union, cast import numpy as np from fastapi import Request -from typing_extensions import assert_never +from typing_extensions import assert_never, override from vllm.config import ModelConfig from vllm.engine.protocol import EngineClient @@ -19,13 +16,13 @@ EmbeddingResponse, EmbeddingResponseData, ErrorResponse, UsageInfo) -from vllm.entrypoints.openai.serving_engine import OpenAIServing +from vllm.entrypoints.openai.serving_engine import (EmbeddingServeContext, + OpenAIServing, + ServeContext) from vllm.entrypoints.openai.serving_models import OpenAIServingModels -from vllm.entrypoints.utils import _validate_truncation_size from vllm.logger import init_logger from vllm.outputs import (EmbeddingOutput, EmbeddingRequestOutput, PoolingRequestOutput) -from vllm.utils import merge_async_iterators logger = init_logger(__name__) @@ -45,180 +42,77 @@ def _get_embedding( assert_never(encoding_format) -class OpenAIServingEmbedding(OpenAIServing): +class EmbeddingMixin(OpenAIServing): - def __init__( - self, - engine_client: EngineClient, - model_config: ModelConfig, - models: OpenAIServingModels, - *, - request_logger: Optional[RequestLogger], - chat_template: Optional[str], - chat_template_content_format: ChatTemplateContentFormatOption, - ) -> None: - super().__init__(engine_client=engine_client, - model_config=model_config, - models=models, - request_logger=request_logger) - - self.chat_template = chat_template - self.chat_template_content_format: Final = chat_template_content_format - - async def create_embedding( + async def _preprocess( self, - request: EmbeddingRequest, - raw_request: Optional[Request] = None, - ) -> Union[EmbeddingResponse, ErrorResponse]: - """ - Embedding API 
similar to OpenAI's API. - - See https://platform.openai.com/docs/api-reference/embeddings/create - for the API specification. This API mimics the OpenAI Embedding API. - """ - error_check_ret = await self._check_model(request) - if error_check_ret is not None: - return error_check_ret - - encoding_format = request.encoding_format - - model_name = self._get_model_name(request.model) - request_id = f"embd-{self._base_request_id(raw_request)}" - created_time = int(time.time()) - - truncate_prompt_tokens = request.truncate_prompt_tokens - - pooling_params = request.to_pooling_params() - - try: - pooling_params.verify(self.model_config) - except ValueError as e: - return self.create_error_response(str(e)) - + ctx: ServeContext, + ) -> Optional[ErrorResponse]: + ctx = cast(EmbeddingServeContext, ctx) try: - truncate_prompt_tokens = _validate_truncation_size( - self.max_model_len, truncate_prompt_tokens) ( - lora_request, - prompt_adapter_request, - ) = self._maybe_get_adapters(request) + ctx.lora_request, + ctx.prompt_adapter_request, + ) = self._maybe_get_adapters(ctx.request) - tokenizer = await self.engine_client.get_tokenizer(lora_request) + tokenizer = await self.engine_client.get_tokenizer(ctx.lora_request + ) - if prompt_adapter_request is not None: + if ctx.prompt_adapter_request is not None: raise NotImplementedError("Prompt adapter is not supported " "for embedding models") - if isinstance(request, EmbeddingChatRequest): + if isinstance(ctx.request, EmbeddingChatRequest): ( _, - request_prompts, - engine_prompts, + ctx.request_prompts, + ctx.engine_prompts, ) = await self._preprocess_chat( - request, + ctx.request, tokenizer, - request.messages, - chat_template=request.chat_template or self.chat_template, - chat_template_content_format=self. + ctx.request.messages, + chat_template=ctx.request.chat_template + or ctx.chat_template, + chat_template_content_format=ctx. 
chat_template_content_format, # In embedding requests, we are not generating tokens, # so there is no need to append extra tokens to the input add_generation_prompt=False, continue_final_message=False, - truncate_prompt_tokens=truncate_prompt_tokens, - add_special_tokens=request.add_special_tokens, + truncate_prompt_tokens=ctx.truncate_prompt_tokens, + add_special_tokens=ctx.request.add_special_tokens, ) else: - (request_prompts, - engine_prompts) = await self._preprocess_completion( - request, + (ctx.request_prompts, + ctx.engine_prompts) = await self._preprocess_completion( + ctx.request, tokenizer, - request.input, - truncate_prompt_tokens=truncate_prompt_tokens, - add_special_tokens=request.add_special_tokens, + ctx.request.input, + truncate_prompt_tokens=ctx.truncate_prompt_tokens, + add_special_tokens=ctx.request.add_special_tokens, ) + return None except (ValueError, TypeError) as e: logger.exception("Error in preprocessing prompt inputs") return self.create_error_response(str(e)) - # Schedule the request and get the result generator. 
- generators: list[AsyncGenerator[PoolingRequestOutput, None]] = [] - try: - for i, engine_prompt in enumerate(engine_prompts): - request_id_item = f"{request_id}-{i}" - - self._log_inputs(request_id_item, - request_prompts[i], - params=pooling_params, - lora_request=lora_request, - prompt_adapter_request=prompt_adapter_request) - - trace_headers = (None if raw_request is None else await - self._get_trace_headers(raw_request.headers)) - - generator = self.engine_client.encode( - engine_prompt, - pooling_params, - request_id_item, - lora_request=lora_request, - trace_headers=trace_headers, - priority=request.priority, - ) - - generators.append(generator) - except ValueError as e: - # TODO: Use a vllm-specific Validation Error - return self.create_error_response(str(e)) - - result_generator = merge_async_iterators(*generators) - - num_prompts = len(engine_prompts) - - # Non-streaming response - final_res_batch: list[Optional[PoolingRequestOutput]] - final_res_batch = [None] * num_prompts - try: - async for i, res in result_generator: - final_res_batch[i] = res - - assert all(final_res is not None for final_res in final_res_batch) - - final_res_batch_checked = cast(list[PoolingRequestOutput], - final_res_batch) - - response = self.request_output_to_embedding_response( - final_res_batch_checked, - request_id, - created_time, - model_name, - encoding_format, - ) - except asyncio.CancelledError: - return self.create_error_response("Client disconnected") - except ValueError as e: - # TODO: Use a vllm-specific Validation Error - return self.create_error_response(str(e)) - - return response - - def request_output_to_embedding_response( + def _build_response( self, - final_res_batch: list[PoolingRequestOutput], - request_id: str, - created_time: int, - model_name: str, - encoding_format: Literal["float", "base64"], - ) -> EmbeddingResponse: + ctx: ServeContext, + ) -> Union[EmbeddingResponse, ErrorResponse]: items: list[EmbeddingResponseData] = [] num_prompt_tokens = 0 - for 
idx, final_res in enumerate(final_res_batch): + final_res_batch_checked = cast(list[PoolingRequestOutput], + ctx.final_res_batch) + + for idx, final_res in enumerate(final_res_batch_checked): embedding_res = EmbeddingRequestOutput.from_base(final_res) item = EmbeddingResponseData( index=idx, embedding=_get_embedding(embedding_res.outputs, - encoding_format), + ctx.request.encoding_format), ) prompt_token_ids = final_res.prompt_token_ids @@ -231,9 +125,76 @@ def request_output_to_embedding_response( ) return EmbeddingResponse( - id=request_id, - created=created_time, - model=model_name, + id=ctx.request_id, + created=ctx.created_time, + model=ctx.model_name, data=items, usage=usage, ) + + +class OpenAIServingEmbedding(EmbeddingMixin): + request_id_prefix = "embd" + + def __init__( + self, + engine_client: EngineClient, + model_config: ModelConfig, + models: OpenAIServingModels, + *, + request_logger: Optional[RequestLogger], + chat_template: Optional[str], + chat_template_content_format: ChatTemplateContentFormatOption, + ) -> None: + super().__init__(engine_client=engine_client, + model_config=model_config, + models=models, + request_logger=request_logger) + + self.chat_template = chat_template + self.chat_template_content_format: Final = chat_template_content_format + + async def create_embedding( + self, + request: EmbeddingRequest, + raw_request: Optional[Request] = None, + ) -> Union[EmbeddingResponse, ErrorResponse]: + """ + Embedding API similar to OpenAI's API. + + See https://platform.openai.com/docs/api-reference/embeddings/create + for the API specification. This API mimics the OpenAI Embedding API. 
+ """ + model_name = self._get_model_name(request.model) + request_id = (f"{self.request_id_prefix}-" + f"{self._base_request_id(raw_request)}") + + ctx = EmbeddingServeContext( + request=request, + raw_request=raw_request, + model_name=model_name, + request_id=request_id, + chat_template=self.chat_template, + chat_template_content_format=self.chat_template_content_format, + ) + + return await super().handle(ctx) # type: ignore + + @override + def _validate_request( + self, + ctx: ServeContext[EmbeddingRequest], + ) -> Optional[ErrorResponse]: + if error := super()._validate_request(ctx): + return error + + ctx.truncate_prompt_tokens = ctx.request.truncate_prompt_tokens + + pooling_params = ctx.request.to_pooling_params() + + try: + pooling_params.verify(self.model_config) + except ValueError as e: + return self.create_error_response(str(e)) + + return None diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index bb11650815ec..37134cfb3da3 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -1,13 +1,16 @@ # SPDX-License-Identifier: Apache-2.0 import json -from collections.abc import Iterable, Iterator, Mapping, Sequence +import time +from collections.abc import (AsyncGenerator, Iterable, Iterator, Mapping, + Sequence) from concurrent.futures.thread import ThreadPoolExecutor from http import HTTPStatus -from typing import Annotated, Any, Callable, Optional, TypedDict, Union +from typing import (Annotated, Any, Callable, ClassVar, Generic, Optional, + TypedDict, TypeVar, Union) from fastapi import Request -from pydantic import Field +from pydantic import BaseModel, ConfigDict, Field from starlette.datastructures import Headers import vllm.envs as envs @@ -24,15 +27,23 @@ resolve_chat_template_content_format) from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + ChatCompletionResponse, + 
ClassificationRequest, + ClassificationResponse, CompletionRequest, + CompletionResponse, DetokenizeRequest, EmbeddingChatRequest, EmbeddingCompletionRequest, - ErrorResponse, RerankRequest, - ScoreRequest, + EmbeddingRequest, + EmbeddingResponse, ErrorResponse, + PoolingResponse, RerankRequest, + ScoreRequest, ScoreResponse, TokenizeChatRequest, TokenizeCompletionRequest, - TranscriptionRequest) + TokenizeResponse, + TranscriptionRequest, + TranscriptionResponse) from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.entrypoints.openai.tool_parsers import ToolParser # yapf: enable @@ -40,6 +51,9 @@ from vllm.inputs.parse import parse_and_batch_prompt from vllm.logger import init_logger from vllm.lora.request import LoRARequest +from vllm.multimodal import ( # noqa: F401 - Required to resolve Pydantic error in RequestProcessingMixin + MultiModalDataDict) +from vllm.outputs import PoolingRequestOutput, RequestOutput from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import BeamSearchParams, SamplingParams @@ -47,13 +61,15 @@ from vllm.tracing import (contains_trace_headers, extract_trace_headers, log_tracing_disabled_warning) from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer -from vllm.utils import is_list_of, make_async, random_uuid +from vllm.utils import (is_list_of, make_async, merge_async_iterators, + random_uuid) logger = init_logger(__name__) CompletionLikeRequest = Union[CompletionRequest, DetokenizeRequest, EmbeddingCompletionRequest, RerankRequest, - ScoreRequest, TokenizeCompletionRequest] + ClassificationRequest, ScoreRequest, + TokenizeCompletionRequest] ChatLikeRequest = Union[ChatCompletionRequest, EmbeddingChatRequest, TokenizeChatRequest] @@ -61,6 +77,17 @@ AnyRequest = Union[CompletionLikeRequest, ChatLikeRequest, TranscriptionRequest] +AnyResponse = Union[ + CompletionResponse, + ChatCompletionResponse, + 
EmbeddingResponse, + TranscriptionResponse, + TokenizeResponse, + PoolingResponse, + ClassificationResponse, + ScoreResponse, +] + class TextTokensPrompt(TypedDict): prompt: str @@ -69,8 +96,79 @@ class TextTokensPrompt(TypedDict): RequestPrompt = Union[list[int], str, TextTokensPrompt] +RequestT = TypeVar("RequestT", bound=AnyRequest) + + +class RequestProcessingMixin(BaseModel): + """ + Mixin for request processing, + handling prompt preparation and engine input. + """ + request_prompts: Optional[Sequence[RequestPrompt]] = \ + Field(default_factory=list) + engine_prompts: Optional[list[TokensPrompt]] = \ + Field(default_factory=list) + + model_config = ConfigDict(arbitrary_types_allowed=True) + + +class ResponseGenerationMixin(BaseModel): + """ + Mixin for response generation, + managing result generators and final batch results. + """ + result_generator: Optional[AsyncGenerator[tuple[int, Union[ + RequestOutput, PoolingRequestOutput]], None]] = None + final_res_batch: list[Union[RequestOutput, PoolingRequestOutput]] = Field( + default_factory=list) + + model_config = ConfigDict(arbitrary_types_allowed=True) + + +class ServeContext(RequestProcessingMixin, ResponseGenerationMixin, BaseModel, + Generic[RequestT]): + # Shared across all requests + request: RequestT + raw_request: Optional[Request] = None + model_name: str + request_id: str + created_time: int = Field(default_factory=lambda: int(time.time())) + lora_request: Optional[LoRARequest] = None + prompt_adapter_request: Optional[PromptAdapterRequest] = None + + # Shared across most requests + tokenizer: Optional[AnyTokenizer] = None + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None + + # `protected_namespaces` resolves Pydantic v2's warning + # on conflict with protected namespace "model_" + model_config = ConfigDict( + protected_namespaces=(), + arbitrary_types_allowed=True, + ) + + +ClassificationServeContext = ServeContext[ClassificationRequest] + + +class 
EmbeddingServeContext(ServeContext[EmbeddingRequest]): + chat_template: Optional[str] = None + chat_template_content_format: ChatTemplateContentFormatOption + + +# Used to resolve the Pydantic error related to +# forward reference of MultiModalDataDict in TokensPrompt +RequestProcessingMixin.model_rebuild() +ServeContext.model_rebuild() +ClassificationServeContext.model_rebuild() +EmbeddingServeContext.model_rebuild() + class OpenAIServing: + request_id_prefix: ClassVar[str] = """ + A short string prepended to every request’s ID (e.g. "embd", "classify") + so you can easily tell “this ID came from Embedding vs Classification.” + """ def __init__( self, @@ -100,6 +198,167 @@ def __init__( self._tokenize_prompt_input_or_inputs, executor=self._tokenizer_executor) + async def _preprocess( + self, + ctx: ServeContext, + ) -> Optional[ErrorResponse]: + """ + Default preprocessing hook. Subclasses may override + to prepare `ctx` (classification, embedding, etc.). + """ + return None + + def _build_response( + self, + ctx: ServeContext, + ) -> Union[AnyResponse, ErrorResponse]: + """ + Default response builder. Subclass may override this method + to return the appropriate response object. 
+ """ + return self.create_error_response("unimplemented endpoint") + + async def handle( + self, + ctx: ServeContext, + ) -> Union[AnyResponse, ErrorResponse]: + generation: AsyncGenerator[Union[AnyResponse, ErrorResponse], None] + generation = self._pipeline(ctx) + + async for response in generation: + return response + + return self.create_error_response("No response yielded from pipeline") + + async def _pipeline( + self, + ctx: ServeContext, + ) -> AsyncGenerator[Union[AnyResponse, ErrorResponse], None]: + """Execute the request processing pipeline yielding responses.""" + if error := await self._check_model(ctx.request): + yield error + if error := self._validate_request(ctx): + yield error + + preprocess_ret = await self._preprocess(ctx) + if isinstance(preprocess_ret, ErrorResponse): + yield preprocess_ret + + generators_ret = await self._prepare_generators(ctx) + if isinstance(generators_ret, ErrorResponse): + yield generators_ret + + collect_ret = await self._collect_batch(ctx) + if isinstance(collect_ret, ErrorResponse): + yield collect_ret + + yield self._build_response(ctx) + + def _validate_request(self, ctx: ServeContext) -> Optional[ErrorResponse]: + truncate_prompt_tokens = getattr(ctx.request, "truncate_prompt_tokens", + None) + + if truncate_prompt_tokens is not None: + if truncate_prompt_tokens <= self.max_model_len: + ctx.truncate_prompt_tokens = truncate_prompt_tokens + else: + return self.create_error_response( + "truncate_prompt_tokens value is " + "greater than max_model_len." 
+ " Please, select a smaller truncation size.") + return None + + async def _prepare_generators( + self, + ctx: ServeContext, + ) -> Optional[ErrorResponse]: + """Schedule the request and get the result generator.""" + generators: list[AsyncGenerator[Union[RequestOutput, + PoolingRequestOutput], + None]] = [] + + try: + trace_headers = (None if ctx.raw_request is None else await + self._get_trace_headers(ctx.raw_request.headers)) + + if not hasattr(ctx.request, "to_pooling_params"): + return self.create_error_response( + "Request type does not support pooling parameters") + + pooling_params = ctx.request.to_pooling_params() + + if ctx.engine_prompts is None: + return self.create_error_response( + "Engine prompts not available") + + for i, engine_prompt in enumerate(ctx.engine_prompts): + request_id_item = f"{ctx.request_id}-{i}" + + if ctx.request_prompts is None: + return self.create_error_response( + "Request prompts not available") + + self._log_inputs( + request_id_item, + ctx.request_prompts[i], + params=pooling_params, + lora_request=ctx.lora_request, + prompt_adapter_request=ctx.prompt_adapter_request) + + generator = self.engine_client.encode( + engine_prompt, + pooling_params, + request_id_item, + lora_request=ctx.lora_request, + trace_headers=trace_headers, + priority=getattr(ctx.request, "priority", 0), + ) + + generators.append(generator) + + ctx.result_generator = merge_async_iterators(*generators) + + return None + + except Exception as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) + + async def _collect_batch( + self, + ctx: ServeContext, + ) -> Optional[ErrorResponse]: + """Collect batch results from the result generator.""" + try: + if ctx.engine_prompts is None: + return self.create_error_response( + "Engine prompts not available") + + num_prompts = len(ctx.engine_prompts) + final_res_batch: list[Optional[Union[RequestOutput, + PoolingRequestOutput]]] + final_res_batch = [None] * num_prompts + + if 
ctx.result_generator is None: + return self.create_error_response( + "Result generator not available") + + async for i, res in ctx.result_generator: + final_res_batch[i] = res + + if None in final_res_batch: + return self.create_error_response( + "Failed to generate results for all prompts") + + ctx.final_res_batch = [ + res for res in final_res_batch if res is not None + ] + + return None + + except Exception as e: + return self.create_error_response(str(e)) + def create_error_response( self, message: str, @@ -183,6 +442,12 @@ def _normalize_prompt_text_to_input( if truncate_prompt_tokens is None: encoded = tokenizer(prompt, add_special_tokens=add_special_tokens) + elif truncate_prompt_tokens < 0: + # Negative means we cap at the model's max length + encoded = tokenizer(prompt, + add_special_tokens=add_special_tokens, + truncation=True, + max_length=self.max_model_len) else: encoded = tokenizer(prompt, add_special_tokens=add_special_tokens, @@ -204,6 +469,8 @@ def _normalize_prompt_tokens_to_input( ) -> TextTokensPrompt: if truncate_prompt_tokens is None: input_ids = prompt_ids + elif truncate_prompt_tokens < 0: + input_ids = prompt_ids[-self.max_model_len:] else: input_ids = prompt_ids[-truncate_prompt_tokens:] @@ -219,13 +486,16 @@ def _validate_input( ) -> TextTokensPrompt: token_num = len(input_ids) - # Note: EmbeddingRequest and ScoreRequest doesn't have max_tokens + # Note: EmbeddingRequest, ClassificationRequest, + # and ScoreRequest doesn't have max_tokens if isinstance(request, (EmbeddingChatRequest, EmbeddingCompletionRequest, - ScoreRequest, RerankRequest)): + ScoreRequest, RerankRequest, ClassificationRequest)): + operation = { + ScoreRequest: "score", + ClassificationRequest: "classification" + }.get(type(request), "embedding generation") - operation = "score" if isinstance(request, ScoreRequest) \ - else "embedding generation" if token_num > self.max_model_len: raise ValueError( f"This model's maximum context length is " @@ -247,7 +517,7 @@ def 
_validate_input( # TODO(#9845): remove max_tokens when field dropped from OpenAI API max_tokens = request.max_completion_tokens or request.max_tokens else: - max_tokens = request.max_tokens + max_tokens = getattr(request, "max_tokens", None) if max_tokens is None: if token_num >= self.max_model_len: raise ValueError(