From e4aee908d0b4a1ab2608f14c30c1cde2bca4613d Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Thu, 23 Jan 2025 14:06:50 -0500 Subject: [PATCH 01/27] whisper-async working poc MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: NickLucche Signed-off-by: Daniele Trifirò --- query_transcription.py | 24 +++ vllm/entrypoints/openai/api_server.py | 38 ++++- vllm/entrypoints/openai/protocol.py | 160 +++++++++++++++++- .../openai/serving_transcription.py | 137 +++++++++++++++ 4 files changed, 356 insertions(+), 3 deletions(-) create mode 100644 query_transcription.py create mode 100644 vllm/entrypoints/openai/serving_transcription.py diff --git a/query_transcription.py b/query_transcription.py new file mode 100644 index 000000000000..8132958a077b --- /dev/null +++ b/query_transcription.py @@ -0,0 +1,24 @@ +from openai import OpenAI +from openai.types.audio import TranscriptionCreateParams +from pathlib import Path +import io + +mary_had_lamb = Path('/home/varun/.cache/vllm/assets/vllm_public_assets/mary_had_lamb.ogg') +winning_call = Path('/home/varun/.cache/vllm/assets/vllm_public_assets/winning_call.ogg') + +# Modify OpenAI's API key and API base to use vLLM's API server. +openai_api_key = "EMPTY" +openai_api_base = "http://localhost:8000/v1" +client = OpenAI( + api_key=openai_api_key, + base_url=openai_api_base, +) +with open(str(mary_had_lamb), "rb") as f: + transcription = client.audio.transcriptions.create( + file=f, + model="openai/whisper-large-v3", + language="en", + prompt="<|startoftranscript|>", + response_format="text", + temperature=0.0) + print("transcription result:", transcription) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index b8f54d6c7804..7d4811073129 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -17,10 +17,10 @@ from contextlib import asynccontextmanager from functools import partial from http import HTTPStatus -from typing import AsyncIterator, Dict, Optional, Set, Tuple, Union +from typing import AsyncIterator, Dict, Optional, Set, Tuple, Union, List, Annotated import uvloop -from fastapi import APIRouter, FastAPI, HTTPException, Request +from fastapi import APIRouter, FastAPI, HTTPException, Request, Form from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, Response, StreamingResponse @@ -76,6 +76,7 @@ from vllm.entrypoints.openai.serving_score import OpenAIServingScores from vllm.entrypoints.openai.serving_tokenization import ( OpenAIServingTokenization) +from vllm.entrypoints.openai.serving_transcription import OpenAIServingTranscription from vllm.entrypoints.openai.tool_parsers import ToolParserManager from vllm.entrypoints.utils import with_cancellation from vllm.logger import init_logger @@ -318,6 +319,9 @@ def rerank(request: Request) -> Optional[JinaAIServingRerank]: def tokenization(request: Request) -> OpenAIServingTokenization: return request.app.state.openai_serving_tokenization +def transcription(request: Request) -> OpenAIServingTranscription: + return request.app.state.openai_serving_transcription + def engine_client(request: Request) -> EngineClient: return request.app.state.engine_client @@ -510,6 +514,29 @@ async def create_score_v1(request: ScoreRequest, raw_request: Request): return await create_score(request, raw_request) +@router.post("/v1/audio/transcriptions") +@with_cancellation +async def 
create_transcriptions(request: Annotated[TranscriptionRequest, Form()], + raw_request: Request): + + audio_data = await request.file.read() + + handler = transcription(raw_request) + if handler is None: + return base(raw_request).create_error_response( + message="The model does not support Transcriptions API") + + generator = await handler.create_transcription(audio_data, request, raw_request) + + if isinstance(generator, ErrorResponse): + return JSONResponse(content=generator.model_dump(), + status_code=generator.code) + + elif isinstance(generator, TranscriptionResponse): + return JSONResponse(content=generator.model_dump()) + + return StreamingResponse(content=generator, media_type="text/event-stream") + @router.post("/rerank") @with_cancellation @@ -734,6 +761,7 @@ async def init_app_state( state: State, args: Namespace, ) -> None: + if args.served_model_name is not None: served_model_names = args.served_model_name else: @@ -821,6 +849,12 @@ async def init_app_state( chat_template=resolved_chat_template, chat_template_content_format=args.chat_template_content_format, ) + state.openai_serving_transcription = OpenAIServingTranscription( + engine_client, + model_config, + state.openai_serving_models, + request_logger=request_logger, + ) state.task = model_config.task diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 83b841826231..5bcf9f324260 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -5,12 +5,14 @@ import re import time from argparse import Namespace -from typing import Any, ClassVar, Dict, List, Literal, Optional, Set, Union +from os import PathLike +from typing import Any, Dict, List, Literal, Optional, Union, TypeAlias, TYPE_CHECKING, Tuple, Mapping import torch from pydantic import (BaseModel, ConfigDict, Field, TypeAdapter, ValidationInfo, field_validator, model_validator) from typing_extensions import Annotated +from fastapi import UploadFile from vllm.entrypoints.chat_utils import ChatCompletionMessageParam from vllm.logger import init_logger @@ -1426,3 +1428,159 @@ class LoadLoraAdapterRequest(BaseModel): class UnloadLoraAdapterRequest(BaseModel): lora_name: str lora_int_id: Optional[int] = Field(default=None) + +## Protocols for Audio +AudioResponseFormat: TypeAlias = Literal["json", "text", "srt", "verbose_json", "vtt"] + +class TranscriptionRequest(OpenAIBaseModel): + # Ordered by official OpenAI API documentation + #https://platform.openai.com/docs/api-reference/audio/createTranscription + + file: UploadFile + """ + The audio file object (not file name) to transcribe, in one of these formats: + flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm. + """ + + model: str + """ID of the model to use. + """ + + language: str + """The language of the input audio. + + Supplying the input language in + [ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes) format will + improve accuracy and latency. + """ + + prompt: str = Field(default="") + """An optional text to guide the model's style or continue a previous audio + segment. + + The [prompt](https://platform.openai.com/docs/guides/speech-to-text#prompting) + should match the audio language. + """ + + response_format: AudioResponseFormat = Field(default="json") + """ + The format of the output, in one of these options: `json`, `text`, `srt`, + `verbose_json`, or `vtt`. + """ + + ## TODO (varun) : Support if set to 0, certain thresholds are met !! 
+ temperature: float = Field(default=0.0) + """The sampling temperature, between 0 and 1. + + Higher values like 0.8 will make the output more random, while lower values like + 0.2 will make it more focused and deterministic. If set to 0, the model will use + [log probability](https://en.wikipedia.org/wiki/Log_probability) to + automatically increase the temperature until certain thresholds are hit. + """ + + timestamp_granularities: List[Literal["word", "segment"]] = Field(alias="timestamp_granularities[]", default=[]) + """The timestamp granularities to populate for this transcription. + + `response_format` must be set `verbose_json` to use timestamp granularities. + Either or both of these options are supported: `word`, or `segment`. Note: There + is no additional latency for segment timestamps, but generating word timestamps + incurs additional latency. + """ + + # Default sampling parameters for transcription requests. + _DEFAULT_SAMPLING_PARAMS: dict = { + "temperature": 0, + } + + def to_sampling_params( + self, + default_max_tokens: int, + default_sampling_params: Optional[dict] = None) -> SamplingParams: + # TODO(#9845): remove max_tokens when field is removed from OpenAI API + max_tokens = default_max_tokens + + if default_sampling_params is None: + default_sampling_params = {} + # Default parameters + if (temperature := self.temperature) is None: + temperature = default_sampling_params.get( + "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]) + + # TODO (varun) : ATM the max_tokens are set to the max-model-len - len(prompt_ids). + # Tbis makes sense. but is this okay ? + return SamplingParams.from_optional( + temperature=temperature, + max_tokens=max_tokens) + +# Transcription response objects +class TranscriptionResponse(OpenAIBaseModel): + text: str + """The transcribed text.""" + +class TranscriptionWord(OpenAIBaseModel): + end: float + """End time of the word in seconds.""" + + start: float + """Start time of the word in seconds.""" + + word: str + """The text content of the word.""" + +class TranscriptionSegment(OpenAIBaseModel): + id: int + """Unique identifier of the segment.""" + + avg_logprob: float + """Average logprob of the segment. + + If the value is lower than -1, consider the logprobs failed. + """ + + compression_ratio: float + """Compression ratio of the segment. + + If the value is greater than 2.4, consider the compression failed. + """ + + end: float + """End time of the segment in seconds.""" + + no_speech_prob: float + """Probability of no speech in the segment. + + If the value is higher than 1.0 and the `avg_logprob` is below -1, consider this + segment silent. 
+ """ + + seek: int + """Seek offset of the segment.""" + + start: float + """Start time of the segment in seconds.""" + + temperature: float + """Temperature parameter used for generating the segment.""" + + text: str + """Text content of the segment.""" + + tokens: List[int] + """Array of token IDs for the text content.""" + +class TranscriptionResponseVerbose(OpenAIBaseModel): + duration: str + """The duration of the input audio.""" + + language: str + """The language of the input audio.""" + + text: str + """The transcribed text.""" + + segments: Optional[List[TranscriptionSegment]] = None + """Segments of the transcribed text and their corresponding details.""" + + words: Optional[List[TranscriptionWord]] = None + """Extracted words and their corresponding timestamps.""" + diff --git a/vllm/entrypoints/openai/serving_transcription.py b/vllm/entrypoints/openai/serving_transcription.py new file mode 100644 index 000000000000..c1d46ca0581a --- /dev/null +++ b/vllm/entrypoints/openai/serving_transcription.py @@ -0,0 +1,137 @@ +import asyncio +import io +#import time +from typing import Any, AsyncGenerator, Dict, Optional, Union + +## TODO (varun) : This is used for testing.. use pydub instead ????? +import librosa +from fastapi import Request + +from vllm.config import ModelConfig +from vllm.engine.protocol import EngineClient +from vllm.entrypoints.logger import RequestLogger +from vllm.entrypoints.openai.protocol import (RequestResponseMetadata, + TranscriptionRequest, + TranscriptionResponse, + TranscriptionResponseVerbose) +from vllm.entrypoints.openai.serving_engine import OpenAIServing +from vllm.entrypoints.openai.serving_models import OpenAIServingModels +from vllm.logger import init_logger +from vllm.outputs import RequestOutput + +logger = init_logger(__name__) + + +class OpenAIServingTranscription(OpenAIServing): + + def __init__( + self, + engine_client: EngineClient, + model_config: ModelConfig, + models: OpenAIServingModels, + *, + request_logger: Optional[RequestLogger], + return_tokens_as_token_ids: bool = False, + ): + super().__init__(engine_client=engine_client, + model_config=model_config, + models=models, + request_logger=request_logger, + return_tokens_as_token_ids=return_tokens_as_token_ids) + + diff_sampling_param = self.model_config.get_diff_sampling_param() + if diff_sampling_param: + logger.info( + "Overwriting default completion sampling param with: %s", + diff_sampling_param) + + # TODO (varun) : pass in a tokenizer and return tokenized values !! + async def _preprocess_transcription( + self, audio_data: bytes, + request: TranscriptionRequest) -> Dict[Any, Any]: + return { + "encoder_prompt": { + "prompt": "", + "multi_modal_data": { + "audio": librosa.load(io.BytesIO(audio_data)), + }, + }, + # TODO (Varun) : Should this instead be encoder prompt ??? + "decoder_prompt": f"{request.prompt}", + } + + # TODO (varun) : Make verbose response work ! + async def create_transcription( + self, audio_data: bytes, request: TranscriptionRequest, + raw_request: Request + ) -> Union[TranscriptionResponse, TranscriptionResponseVerbose]: + """Completion API similar to OpenAI's API. + + See https://platform.openai.com/docs/api-reference/audio/createTranscription + for the API specification. This API mimics the OpenAI completion API. + """ + + assert request.response_format in ['text', 'json'] + + error_check_ret = await self._check_model(request) + if error_check_ret is not None: + return error_check_ret + + # If the engine is dead, raise the engine's DEAD_ERROR. 
+ # This is required for the streaming case, where we return a + # success status before we actually start generating text :). + if self.engine_client.errored: + raise self.engine_client.dead_error + + request_id = f"cmpl-{self._base_request_id(raw_request)}" + # TODO (varun) : other serving_* files use this -- we should use + # it as well. + #created_time = int(time.time()) + + request_metadata = RequestResponseMetadata(request_id=request_id) + if raw_request: + raw_request.state.request_metadata = request_metadata + + # TODO (varun) : Does Whisper have LoRA ? + #tokenizer = await self.engine_client.get_tokenizer(None) + + prompt = await self._preprocess_transcription(audio_data, request) + + default_sampling_params = (self.model_config.get_diff_sampling_param()) + + # TODO (Varun) : figure out default_max_tokens by tokenizing first + default_max_tokens = 200 + sampling_params = request.to_sampling_params(default_max_tokens, + default_sampling_params) + + self._log_inputs( + request_id, + prompt['decoder_prompt'], + params=sampling_params, + lora_request=None, + prompt_adapter_request=None, + ) + + generator: AsyncGenerator[RequestOutput, None] = None + try: + generator = self.engine_client.generate( + prompt, + sampling_params, + request_id, + ) + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) + + # Non-streaming response + result: Optional[RequestOutput] = None + + try: + async for op in generator: + result = op + return TranscriptionResponse(text=result.outputs[0].text) + except asyncio.CancelledError: + return self.create_error_response("Client disconnected") + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) From 025084bc76d15b09f78381e48a9151741ef35c4d Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Mon, 3 Feb 2025 02:16:28 +0000 Subject: [PATCH 02/27] updated MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: rshaw@neuralmagic.com Signed-off-by: Daniele Trifirò --- query_transcription.py | 20 ++--- vllm/assets/audio.py | 5 ++ vllm/entrypoints/openai/api_server.py | 16 ++-- vllm/entrypoints/openai/protocol.py | 55 ++++++------ .../openai/serving_transcription.py | 88 +++++++++++-------- 5 files changed, 106 insertions(+), 78 deletions(-) diff --git a/query_transcription.py b/query_transcription.py index 8132958a077b..b0c9c95c8818 100644 --- a/query_transcription.py +++ b/query_transcription.py @@ -1,10 +1,9 @@ from openai import OpenAI -from openai.types.audio import TranscriptionCreateParams -from pathlib import Path -import io -mary_had_lamb = Path('/home/varun/.cache/vllm/assets/vllm_public_assets/mary_had_lamb.ogg') -winning_call = Path('/home/varun/.cache/vllm/assets/vllm_public_assets/winning_call.ogg') +from vllm.assets.audio import AudioAsset + +mary_had_lamb = AudioAsset('mary_had_lamb').get_asset_path() +winning_call = AudioAsset('winning_call').get_asset_path() # Modify OpenAI's API key and API base to use vLLM's API server. 
openai_api_key = "EMPTY" @@ -15,10 +14,9 @@ ) with open(str(mary_had_lamb), "rb") as f: transcription = client.audio.transcriptions.create( - file=f, - model="openai/whisper-large-v3", - language="en", - prompt="<|startoftranscript|>", - response_format="text", - temperature=0.0) + file=f, + model="openai/whisper-large-v3", + language="en", + response_format="text", + temperature=0.0) print("transcription result:", transcription) diff --git a/vllm/assets/audio.py b/vllm/assets/audio.py index d9e51082e6ca..c465527b4e89 100644 --- a/vllm/assets/audio.py +++ b/vllm/assets/audio.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 from dataclasses import dataclass +from pathlib import Path from typing import Literal from urllib.parse import urljoin @@ -28,6 +29,10 @@ def audio_and_sample_rate(self) -> tuple[npt.NDArray, float]: s3_prefix=ASSET_DIR) return librosa.load(audio_path, sr=None) + def get_asset_path(self) -> Path: + return get_vllm_public_assets(filename=f"{self.name}.ogg", + s3_prefix=ASSET_DIR) + @property def url(self) -> str: return urljoin(VLLM_S3_BUCKET_URL, f"{ASSET_DIR}/{self.name}.ogg") diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 7d4811073129..fe5c37e70244 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -17,10 +17,10 @@ from contextlib import asynccontextmanager from functools import partial from http import HTTPStatus -from typing import AsyncIterator, Dict, Optional, Set, Tuple, Union, List, Annotated +from typing import Annotated, AsyncIterator, Dict, Optional, Set, Tuple, Union import uvloop -from fastapi import APIRouter, FastAPI, HTTPException, Request, Form +from fastapi import APIRouter, FastAPI, Form, HTTPException, Request from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, Response, StreamingResponse @@ -76,7 +76,8 @@ from vllm.entrypoints.openai.serving_score import OpenAIServingScores from vllm.entrypoints.openai.serving_tokenization import ( OpenAIServingTokenization) -from vllm.entrypoints.openai.serving_transcription import OpenAIServingTranscription +from vllm.entrypoints.openai.serving_transcription import ( + OpenAIServingTranscription) from vllm.entrypoints.openai.tool_parsers import ToolParserManager from vllm.entrypoints.utils import with_cancellation from vllm.logger import init_logger @@ -319,6 +320,7 @@ def rerank(request: Request) -> Optional[JinaAIServingRerank]: def tokenization(request: Request) -> OpenAIServingTokenization: return request.app.state.openai_serving_tokenization + def transcription(request: Request) -> OpenAIServingTranscription: return request.app.state.openai_serving_transcription @@ -514,9 +516,11 @@ async def create_score_v1(request: ScoreRequest, raw_request: Request): return await create_score(request, raw_request) + @router.post("/v1/audio/transcriptions") @with_cancellation -async def create_transcriptions(request: Annotated[TranscriptionRequest, Form()], +async def create_transcriptions(request: Annotated[TranscriptionRequest, + Form()], raw_request: Request): audio_data = await request.file.read() @@ -526,7 +530,8 @@ async def create_transcriptions(request: Annotated[TranscriptionRequest, Form()] return base(raw_request).create_error_response( message="The model does not support Transcriptions API") - generator = await handler.create_transcription(audio_data, request, raw_request) + generator = await 
handler.create_transcription(audio_data, request, + raw_request) if isinstance(generator, ErrorResponse): return JSONResponse(content=generator.model_dump(), @@ -761,7 +766,6 @@ async def init_app_state( state: State, args: Namespace, ) -> None: - if args.served_model_name is not None: served_model_names = args.served_model_name else: diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 5bcf9f324260..c6991f21b78e 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -5,14 +5,13 @@ import re import time from argparse import Namespace -from os import PathLike -from typing import Any, Dict, List, Literal, Optional, Union, TypeAlias, TYPE_CHECKING, Tuple, Mapping +from typing import Any, Dict, List, Literal, Optional, TypeAlias, Union import torch from pydantic import (BaseModel, ConfigDict, Field, TypeAdapter, ValidationInfo, field_validator, model_validator) -from typing_extensions import Annotated from fastapi import UploadFile +from typing_extensions import Annotated from vllm.entrypoints.chat_utils import ChatCompletionMessageParam from vllm.logger import init_logger @@ -1429,17 +1428,20 @@ class UnloadLoraAdapterRequest(BaseModel): lora_name: str lora_int_id: Optional[int] = Field(default=None) + ## Protocols for Audio -AudioResponseFormat: TypeAlias = Literal["json", "text", "srt", "verbose_json", "vtt"] +AudioResponseFormat: TypeAlias = Literal["json", "text", "srt", "verbose_json", + "vtt"] + class TranscriptionRequest(OpenAIBaseModel): # Ordered by official OpenAI API documentation #https://platform.openai.com/docs/api-reference/audio/createTranscription - file: UploadFile + file: UploadFile """ - The audio file object (not file name) to transcribe, in one of these formats: - flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm. + The audio file object (not file name) to transcribe, in one of these + formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm. """ model: str @@ -1450,8 +1452,8 @@ class TranscriptionRequest(OpenAIBaseModel): """The language of the input audio. Supplying the input language in - [ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes) format will - improve accuracy and latency. + [ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes) format + will improve accuracy and latency. """ prompt: str = Field(default="") @@ -1468,23 +1470,24 @@ class TranscriptionRequest(OpenAIBaseModel): `verbose_json`, or `vtt`. """ - ## TODO (varun) : Support if set to 0, certain thresholds are met !! + ## TODO (varun) : Support if set to 0, certain thresholds are met !! temperature: float = Field(default=0.0) """The sampling temperature, between 0 and 1. - Higher values like 0.8 will make the output more random, while lower values like - 0.2 will make it more focused and deterministic. If set to 0, the model will use - [log probability](https://en.wikipedia.org/wiki/Log_probability) to - automatically increase the temperature until certain thresholds are hit. + Higher values like 0.8 will make the output more random, while lower values + like 0.2 will make it more focused / deterministic. If set to 0, the model + will use [log probability](https://en.wikipedia.org/wiki/Log_probability) + to automatically increase the temperature until certain thresholds are hit. 
""" - timestamp_granularities: List[Literal["word", "segment"]] = Field(alias="timestamp_granularities[]", default=[]) + timestamp_granularities: List[Literal["word", "segment"]] = Field( + alias="timestamp_granularities[]", default=[]) """The timestamp granularities to populate for this transcription. `response_format` must be set `verbose_json` to use timestamp granularities. - Either or both of these options are supported: `word`, or `segment`. Note: There - is no additional latency for segment timestamps, but generating word timestamps - incurs additional latency. + Either or both of these options are supported: `word`, or `segment`. Note: + There is no additional latency for segment timestamps, but generating word + timestamps incurs additional latency. """ # Default sampling parameters for transcription requests. @@ -1506,17 +1509,16 @@ def to_sampling_params( temperature = default_sampling_params.get( "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]) - # TODO (varun) : ATM the max_tokens are set to the max-model-len - len(prompt_ids). - # Tbis makes sense. but is this okay ? - return SamplingParams.from_optional( - temperature=temperature, - max_tokens=max_tokens) + return SamplingParams.from_optional(temperature=temperature, + max_tokens=max_tokens) + # Transcription response objects class TranscriptionResponse(OpenAIBaseModel): text: str """The transcribed text.""" + class TranscriptionWord(OpenAIBaseModel): end: float """End time of the word in seconds.""" @@ -1527,6 +1529,7 @@ class TranscriptionWord(OpenAIBaseModel): word: str """The text content of the word.""" + class TranscriptionSegment(OpenAIBaseModel): id: int """Unique identifier of the segment.""" @@ -1549,8 +1552,8 @@ class TranscriptionSegment(OpenAIBaseModel): no_speech_prob: float """Probability of no speech in the segment. - If the value is higher than 1.0 and the `avg_logprob` is below -1, consider this - segment silent. + If the value is higher than 1.0 and the `avg_logprob` is below -1, consider + this segment silent. """ seek: int @@ -1568,6 +1571,7 @@ class TranscriptionSegment(OpenAIBaseModel): tokens: List[int] """Array of token IDs for the text content.""" + class TranscriptionResponseVerbose(OpenAIBaseModel): duration: str """The duration of the input audio.""" @@ -1583,4 +1587,3 @@ class TranscriptionResponseVerbose(OpenAIBaseModel): words: Optional[List[TranscriptionWord]] = None """Extracted words and their corresponding timestamps.""" - diff --git a/vllm/entrypoints/openai/serving_transcription.py b/vllm/entrypoints/openai/serving_transcription.py index c1d46ca0581a..572a6d1ae73c 100644 --- a/vllm/entrypoints/openai/serving_transcription.py +++ b/vllm/entrypoints/openai/serving_transcription.py @@ -1,9 +1,7 @@ import asyncio import io -#import time from typing import Any, AsyncGenerator, Dict, Optional, Union -## TODO (varun) : This is used for testing.. use pydub instead ????? import librosa from fastapi import Request @@ -45,10 +43,11 @@ def __init__( "Overwriting default completion sampling param with: %s", diff_sampling_param) - # TODO (varun) : pass in a tokenizer and return tokenized values !! 
async def _preprocess_transcription( - self, audio_data: bytes, - request: TranscriptionRequest) -> Dict[Any, Any]: + self, + request: TranscriptionRequest, + audio_data: bytes, + ) -> Dict[Any, Any]: return { "encoder_prompt": { "prompt": "", @@ -56,8 +55,10 @@ async def _preprocess_transcription( "audio": librosa.load(io.BytesIO(audio_data)), }, }, - # TODO (Varun) : Should this instead be encoder prompt ??? - "decoder_prompt": f"{request.prompt}", + # TODO(rob): tokenize here. + "decoder_prompt": + "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>" + # "decoder_prompt": f"{request.prompt}", } # TODO (varun) : Make verbose response work ! @@ -71,8 +72,6 @@ async def create_transcription( for the API specification. This API mimics the OpenAI completion API. """ - assert request.response_format in ['text', 'json'] - error_check_ret = await self._check_model(request) if error_check_ret is not None: return error_check_ret @@ -83,38 +82,54 @@ async def create_transcription( if self.engine_client.errored: raise self.engine_client.dead_error + if request.response_format not in ['text', 'json']: + return self.create_error_response( + "Currently only support response_format `text` or `json`") + request_id = f"cmpl-{self._base_request_id(raw_request)}" - # TODO (varun) : other serving_* files use this -- we should use - # it as well. - #created_time = int(time.time()) request_metadata = RequestResponseMetadata(request_id=request_id) if raw_request: raw_request.state.request_metadata = request_metadata - # TODO (varun) : Does Whisper have LoRA ? - #tokenizer = await self.engine_client.get_tokenizer(None) - - prompt = await self._preprocess_transcription(audio_data, request) - - default_sampling_params = (self.model_config.get_diff_sampling_param()) - - # TODO (Varun) : figure out default_max_tokens by tokenizing first - default_max_tokens = 200 - sampling_params = request.to_sampling_params(default_max_tokens, - default_sampling_params) + try: + ( + lora_request, + prompt_adapter_request, + ) = self._maybe_get_adapters(request) + + if lora_request: + return self.create_error_response( + "Currently do not support LoRA for Transcription.") + if prompt_adapter_request: + return self.create_error_response( + "Currently do not support PromptAdapter for Transcription." + ) + + prompt = await self._preprocess_transcription( + request=request, + audio_data=audio_data, + ) - self._log_inputs( - request_id, - prompt['decoder_prompt'], - params=sampling_params, - lora_request=None, - prompt_adapter_request=None, - ) + except ValueError as e: + logger.exception("Error in preprocessing prompt inputs") + return self.create_error_response(str(e)) - generator: AsyncGenerator[RequestOutput, None] = None + result_generator: AsyncGenerator[RequestOutput, None] = None try: - generator = self.engine_client.generate( + # TODO(rob): subtract len of tokenized prompt. 
+ default_max_tokens = self.model_config.max_model_len + default_params = self.model_config.get_diff_sampling_param() + sampling_params = request.to_sampling_params( + default_max_tokens, default_params) + + self._log_inputs(request_id, + prompt['decoder_prompt'], + params=sampling_params, + lora_request=None, + prompt_adapter_request=None) + + result_generator = self.engine_client.generate( prompt, sampling_params, request_id, @@ -123,11 +138,14 @@ async def create_transcription( # TODO: Use a vllm-specific Validation Error return self.create_error_response(str(e)) - # Non-streaming response - result: Optional[RequestOutput] = None + # TODO(rob): figure out a way to pipe streaming in. + stream = False + if stream: + return None + # Non-streaming response. try: - async for op in generator: + async for op in result_generator: result = op return TranscriptionResponse(text=result.outputs[0].text) except asyncio.CancelledError: From e37ea61c72332192491e8538efb86f54714863c0 Mon Sep 17 00:00:00 2001 From: NickLucche Date: Wed, 5 Feb 2025 18:10:52 +0000 Subject: [PATCH 03/27] language+prompt+validation and first tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: NickLucche Signed-off-by: Daniele Trifirò --- .../openai/test_transcription_validation.py | 70 ++++++++ vllm/entrypoints/openai/protocol.py | 2 +- .../openai/serving_transcription.py | 149 +++++++++++++++++- 3 files changed, 214 insertions(+), 7 deletions(-) create mode 100644 tests/entrypoints/openai/test_transcription_validation.py diff --git a/tests/entrypoints/openai/test_transcription_validation.py b/tests/entrypoints/openai/test_transcription_validation.py new file mode 100644 index 000000000000..7b15b3fc628a --- /dev/null +++ b/tests/entrypoints/openai/test_transcription_validation.py @@ -0,0 +1,70 @@ +# SPDX-License-Identifier: Apache-2.0 + +# imports for guided decoding tests +import re + +import openai +import pytest + +from ...utils import RemoteOpenAIServer +from vllm.assets.audio import AudioAsset +import json + +@pytest.fixture +def mary_had_lamb(): + path = AudioAsset('mary_had_lamb').get_asset_path() + with open(str(path), "rb") as f: + yield f + +@pytest.fixture +def winning_call(): + path = AudioAsset('winning_call').get_asset_path() + with open(str(path), "rb") as f: + yield f + +@pytest.mark.asyncio +async def test_basic_audio(mary_had_lamb): + model_name = "openai/whisper-large-v3-turbo" + server_args = ["--enforce-eager"] + # Based on https://github.com/openai/openai-cookbook/blob/main/examples/Whisper_prompting_guide.ipynb. 
+ prompt="THE FIRST WORDS I SPOKE" + with RemoteOpenAIServer(model_name, server_args) as remote_server: + client = remote_server.get_async_client() + transcription = await client.audio.transcriptions.create( + model=model_name, + file=mary_had_lamb, + language="en", + response_format="text", + temperature=0.0) + out = json.loads(transcription)['text'] + assert "Mary had a little lamb," in out + # This should "force" whisper to continue prompt in all caps + transcription_wprompt = await client.audio.transcriptions.create( + model=model_name, + file=mary_had_lamb, + language="en", + response_format="text", + prompt=prompt, + temperature=0.0) + out_capital = json.loads(transcription_wprompt)['text'] + assert prompt not in out_capital + print(out_capital.capitalize(), out_capital) + + +@pytest.mark.asyncio +async def test_bad_requests(mary_had_lamb): + model_name = "openai/whisper-small" + server_args = ["--enforce-eager"] + with RemoteOpenAIServer(model_name, server_args) as remote_server: + client = remote_server.get_async_client() + + # invalid language + with pytest.raises(openai.BadRequestError): + await client.audio.transcriptions.create( + model=model_name, + file=mary_had_lamb, + language="hh", + response_format="text", + temperature=0.0) + + # TODO audio too long diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index c6991f21b78e..0427c542b199 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -1448,7 +1448,7 @@ class TranscriptionRequest(OpenAIBaseModel): """ID of the model to use. """ - language: str + language: Optional[str] = None """The language of the input audio. Supplying the input language in diff --git a/vllm/entrypoints/openai/serving_transcription.py b/vllm/entrypoints/openai/serving_transcription.py index 572a6d1ae73c..7677867688a7 100644 --- a/vllm/entrypoints/openai/serving_transcription.py +++ b/vllm/entrypoints/openai/serving_transcription.py @@ -19,6 +19,120 @@ logger = init_logger(__name__) +# From https://platform.openai.com/docs/guides/speech-to-text/supported-languages#supported-languages +# TODO these configs should live somewhere with the model so we can support +# additional ones +ISO639_1_SUPPORTED_LANGS = { + "af": "Afrikaans", +"ar": "Arabic", +"hy": "Armenian", +"az": "Azerbaijani", +"be": "Belarusian", +"bs": "Bosnian", +"bg": "Bulgarian", +"ca": "Catalan", +"zh": "Chinese", +"hr": "Croatian", +"cs": "Czech", +"da": "Danish", +"nl": "Dutch", +"en": "English", +"et": "Estonian", +"fi": "Finnish", +"fr": "French", +"gl": "Galician", +"de": "German", +"el": "Greek", +"he": "Hebrew", +"hi": "Hindi", +"hu": "Hungarian", +"is": "Icelandic", +"id": "Indonesian", +"it": "Italian", +"ja": "Japanese", +"kn": "Kannada", +"kk": "Kazakh", +"ko": "Korean", +"lv": "Latvian", +"lt": "Lithuanian", +"mk": "Macedonian", +"ms": "Malay", +"mr": "Marathi", +"mi": "Maori", +"ne": "Nepali", +"no": "Norwegian", +"fa": "Persian", +"pl": "Polish", +"pt": "Portuguese", +"ro": "Romanian", +"ru": "Russian", +"sr": "Serbian", +"sk": "Slovak", +"sl": "Slovenian", +"es": "Spanish", +"sw": "Swahili", +"sv": "Swedish", +"tl": "Tagalog", +"ta": "Tamil", +"th": "Thai", +"tr": "Turkish", +"uk": "Ukrainian", +"ur": "Urdu", +"vi": "Vietnamese", +"cy": "Welsh" +} +ISO639_1_OTHER_LANGS = { + "lo": "Lao", + "jw": "Javanese", + "tk": "Turkmen", + "yi": "Yiddish", + "so": "Somali", + "bn": "Bengali", + "nn": "Norwegian Nynorsk", + "si": "Sinhala", + "yo": "Yoruba", + "sa": "Sanskrit", + "mi": "Māori", + "fo": "Faroese", 
+ "mt": "Maltese", + "tg": "Tajik", + "mg": "Malagasy", + "haw": "Hawaiian", + "km": "Khmer", + "br": "Breton", + "ps": "Pashto", + "ln": "Lingala", + "la": "Latin", + "ml": "Malayalam", + "sq": "Albanian", + "su": "Sundanese", + "eu": "Basque", + "ka": "Georgian", + "uz": "Uzbek", + "sn": "Shona", + "ht": "Haitian", + "as": "Assamese", + "mn": "Mongolian", + "te": "Telugu", + "pa": "Panjabi", + "tt": "Tatar", + "gu": "Gujarati", + "oc": "Occitan", + "ha": "Hausa", + "ba": "Bashkir", + "my": "Burmese", + "sd": "Sindhi", + "am": "Amharic", + "lb": "Luxembourgish", + "bo": "Tibetan" +} + +# As per https://platform.openai.com/docs/guides/speech-to-text#overview. +# TODO configurable +MAX_AUDIO_CLIP_FILESIZE_MB = 25 +# TODO get from processor.feature_extractor.chunk_length +MAX_AUDIO_CLIP_DURATION_S = 30 + class OpenAIServingTranscription(OpenAIServing): @@ -48,17 +162,39 @@ async def _preprocess_transcription( request: TranscriptionRequest, audio_data: bytes, ) -> Dict[Any, Any]: + # Validate request + # TODO language should be optional and can be guessed. + # For now we default to en. See + # https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/generation_whisper.py#L1520 + lang_token = f"<|{request.language}|>" if request.language else "" #"<|en|>" + if request.language: + if request.language in ISO639_1_SUPPORTED_LANGS: + pass + elif request.language in ISO639_1_OTHER_LANGS: + logger.warning(f"The selected language {request.language} has"+ + " limited accuracy with reported WER>=0.5."+ + " Results may be less accurate for this choice.") + else: + raise ValueError(f"Unsupported language: {request.language}." + f"Language should be one of:" + + f" {list(ISO639_1_SUPPORTED_LANGS.values())} or {list(ISO639_1_OTHER_LANGS.values())}") + + if len(audio_data)/1024**2 > MAX_AUDIO_CLIP_FILESIZE_MB: + raise ValueError("") + + y, sr = librosa.load(io.BytesIO(audio_data)) + if librosa.get_duration(y=y, sr=sr) > MAX_AUDIO_CLIP_DURATION_S: + raise ValueError("") + return { "encoder_prompt": { "prompt": "", "multi_modal_data": { - "audio": librosa.load(io.BytesIO(audio_data)), + "audio": (y, sr), }, }, - # TODO(rob): tokenize here. "decoder_prompt": - "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>" - # "decoder_prompt": f"{request.prompt}", + f"<|startoftranscript|>{lang_token}<|transcribe|><|notimestamps|>{request.prompt}" } # TODO (varun) : Make verbose response work ! @@ -66,10 +202,10 @@ async def create_transcription( self, audio_data: bytes, request: TranscriptionRequest, raw_request: Request ) -> Union[TranscriptionResponse, TranscriptionResponseVerbose]: - """Completion API similar to OpenAI's API. + """Transcription API similar to OpenAI's API. See https://platform.openai.com/docs/api-reference/audio/createTranscription - for the API specification. This API mimics the OpenAI completion API. + for the API specification. This API mimics the OpenAI transcription API. """ error_check_ret = await self._check_model(request) @@ -86,6 +222,7 @@ async def create_transcription( return self.create_error_response( "Currently only support response_format `text` or `json`") + # TODO cmpl->transcription? 
request_id = f"cmpl-{self._base_request_id(raw_request)}" request_metadata = RequestResponseMetadata(request_id=request_id) From 2edbd0e30cc7e20047e7bd88cf2732bee371e075 Mon Sep 17 00:00:00 2001 From: NickLucche Date: Thu, 6 Feb 2025 15:00:59 +0000 Subject: [PATCH 04/27] error msgs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: NickLucche Signed-off-by: Daniele Trifirò --- .../openai/serving_transcription.py | 146 +++++++++--------- 1 file changed, 76 insertions(+), 70 deletions(-) diff --git a/vllm/entrypoints/openai/serving_transcription.py b/vllm/entrypoints/openai/serving_transcription.py index 7677867688a7..2a84a4692cc6 100644 --- a/vllm/entrypoints/openai/serving_transcription.py +++ b/vllm/entrypoints/openai/serving_transcription.py @@ -22,64 +22,65 @@ # From https://platform.openai.com/docs/guides/speech-to-text/supported-languages#supported-languages # TODO these configs should live somewhere with the model so we can support # additional ones + ISO639_1_SUPPORTED_LANGS = { "af": "Afrikaans", -"ar": "Arabic", -"hy": "Armenian", -"az": "Azerbaijani", -"be": "Belarusian", -"bs": "Bosnian", -"bg": "Bulgarian", -"ca": "Catalan", -"zh": "Chinese", -"hr": "Croatian", -"cs": "Czech", -"da": "Danish", -"nl": "Dutch", -"en": "English", -"et": "Estonian", -"fi": "Finnish", -"fr": "French", -"gl": "Galician", -"de": "German", -"el": "Greek", -"he": "Hebrew", -"hi": "Hindi", -"hu": "Hungarian", -"is": "Icelandic", -"id": "Indonesian", -"it": "Italian", -"ja": "Japanese", -"kn": "Kannada", -"kk": "Kazakh", -"ko": "Korean", -"lv": "Latvian", -"lt": "Lithuanian", -"mk": "Macedonian", -"ms": "Malay", -"mr": "Marathi", -"mi": "Maori", -"ne": "Nepali", -"no": "Norwegian", -"fa": "Persian", -"pl": "Polish", -"pt": "Portuguese", -"ro": "Romanian", -"ru": "Russian", -"sr": "Serbian", -"sk": "Slovak", -"sl": "Slovenian", -"es": "Spanish", -"sw": "Swahili", -"sv": "Swedish", -"tl": "Tagalog", -"ta": "Tamil", -"th": "Thai", -"tr": "Turkish", -"uk": "Ukrainian", -"ur": "Urdu", -"vi": "Vietnamese", -"cy": "Welsh" + "ar": "Arabic", + "hy": "Armenian", + "az": "Azerbaijani", + "be": "Belarusian", + "bs": "Bosnian", + "bg": "Bulgarian", + "ca": "Catalan", + "zh": "Chinese", + "hr": "Croatian", + "cs": "Czech", + "da": "Danish", + "nl": "Dutch", + "en": "English", + "et": "Estonian", + "fi": "Finnish", + "fr": "French", + "gl": "Galician", + "de": "German", + "el": "Greek", + "he": "Hebrew", + "hi": "Hindi", + "hu": "Hungarian", + "is": "Icelandic", + "id": "Indonesian", + "it": "Italian", + "ja": "Japanese", + "kn": "Kannada", + "kk": "Kazakh", + "ko": "Korean", + "lv": "Latvian", + "lt": "Lithuanian", + "mk": "Macedonian", + "ms": "Malay", + "mr": "Marathi", + "mi": "Maori", + "ne": "Nepali", + "no": "Norwegian", + "fa": "Persian", + "pl": "Polish", + "pt": "Portuguese", + "ro": "Romanian", + "ru": "Russian", + "sr": "Serbian", + "sk": "Slovak", + "sl": "Slovenian", + "es": "Spanish", + "sw": "Swahili", + "sv": "Swedish", + "tl": "Tagalog", + "ta": "Tamil", + "th": "Thai", + "tr": "Turkish", + "uk": "Ukrainian", + "ur": "Urdu", + "vi": "Vietnamese", + "cy": "Welsh" } ISO639_1_OTHER_LANGS = { "lo": "Lao", @@ -129,9 +130,9 @@ # As per https://platform.openai.com/docs/guides/speech-to-text#overview. 
# TODO configurable -MAX_AUDIO_CLIP_FILESIZE_MB = 25 +MAX_AUDIO_CLIP_FILESIZE_MB = 25 # TODO get from processor.feature_extractor.chunk_length -MAX_AUDIO_CLIP_DURATION_S = 30 +MAX_AUDIO_CLIP_DURATION_S = 30 class OpenAIServingTranscription(OpenAIServing): @@ -163,28 +164,33 @@ async def _preprocess_transcription( audio_data: bytes, ) -> Dict[Any, Any]: # Validate request - # TODO language should be optional and can be guessed. + # TODO language should be optional and can be guessed. # For now we default to en. See # https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/generation_whisper.py#L1520 - lang_token = f"<|{request.language}|>" if request.language else "" #"<|en|>" + lang_token = f"<|{request.language}|>" if request.language else "<|en|>" if request.language: if request.language in ISO639_1_SUPPORTED_LANGS: pass elif request.language in ISO639_1_OTHER_LANGS: - logger.warning(f"The selected language {request.language} has"+ - " limited accuracy with reported WER>=0.5."+ - " Results may be less accurate for this choice.") + logger.warning( + "The selected language %s has limited accuracy with" + " reported WER>=0.5. Results may be less accurate " + "for this choice.", request.language) else: - raise ValueError(f"Unsupported language: {request.language}." - f"Language should be one of:" + - f" {list(ISO639_1_SUPPORTED_LANGS.values())} or {list(ISO639_1_OTHER_LANGS.values())}") + raise ValueError( + f"Unsupported language: {request.language}." + "Language should be one of:" + + f" {list(ISO639_1_SUPPORTED_LANGS.values())}" + + f"or {list(ISO639_1_OTHER_LANGS.values())}") + + if len(audio_data) / 1024**2 > MAX_AUDIO_CLIP_FILESIZE_MB: + raise ValueError("Maximum file size exceeded.") - if len(audio_data)/1024**2 > MAX_AUDIO_CLIP_FILESIZE_MB: - raise ValueError("") - y, sr = librosa.load(io.BytesIO(audio_data)) if librosa.get_duration(y=y, sr=sr) > MAX_AUDIO_CLIP_DURATION_S: - raise ValueError("") + raise ValueError( + f"Maximum clip duration ({MAX_AUDIO_CLIP_DURATION_S}s)\ + exceeded.") return { "encoder_prompt": { From 150a32a9e857adc499b38c4685c13a05a240ce47 Mon Sep 17 00:00:00 2001 From: NickLucche Date: Fri, 7 Feb 2025 13:24:13 +0000 Subject: [PATCH 05/27] more tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: NickLucche Signed-off-by: Daniele Trifirò --- tests/entrypoints/openai/test_audio.py | 1 + .../openai/test_transcription_validation.py | 44 +++++++++++++------ .../openai/serving_transcription.py | 6 +-- 3 files changed, 34 insertions(+), 17 deletions(-) diff --git a/tests/entrypoints/openai/test_audio.py b/tests/entrypoints/openai/test_audio.py index fe7299a48e6f..f85a1ae45888 100644 --- a/tests/entrypoints/openai/test_audio.py +++ b/tests/entrypoints/openai/test_audio.py @@ -17,6 +17,7 @@ ] +# TODO rename to test chat audio or multimodal @pytest.fixture(scope="module") def server(): args = [ diff --git a/tests/entrypoints/openai/test_transcription_validation.py b/tests/entrypoints/openai/test_transcription_validation.py index 7b15b3fc628a..98a1a8d74940 100644 --- a/tests/entrypoints/openai/test_transcription_validation.py +++ b/tests/entrypoints/openai/test_transcription_validation.py @@ -1,14 +1,19 @@ # SPDX-License-Identifier: Apache-2.0 # imports for guided decoding tests -import re +import io +import json +import librosa +import numpy as np import openai import pytest +import soundfile as sf -from ...utils import RemoteOpenAIServer from vllm.assets.audio import AudioAsset -import 
json + +from ...utils import RemoteOpenAIServer + @pytest.fixture def mary_had_lamb(): @@ -16,18 +21,20 @@ def mary_had_lamb(): with open(str(path), "rb") as f: yield f + @pytest.fixture def winning_call(): path = AudioAsset('winning_call').get_asset_path() with open(str(path), "rb") as f: yield f + @pytest.mark.asyncio async def test_basic_audio(mary_had_lamb): model_name = "openai/whisper-large-v3-turbo" server_args = ["--enforce-eager"] # Based on https://github.com/openai/openai-cookbook/blob/main/examples/Whisper_prompting_guide.ipynb. - prompt="THE FIRST WORDS I SPOKE" + prompt = "THE FIRST WORDS I SPOKE" with RemoteOpenAIServer(model_name, server_args) as remote_server: client = remote_server.get_async_client() transcription = await client.audio.transcriptions.create( @@ -48,8 +55,7 @@ async def test_basic_audio(mary_had_lamb): temperature=0.0) out_capital = json.loads(transcription_wprompt)['text'] assert prompt not in out_capital - print(out_capital.capitalize(), out_capital) - + @pytest.mark.asyncio async def test_bad_requests(mary_had_lamb): @@ -60,11 +66,21 @@ async def test_bad_requests(mary_had_lamb): # invalid language with pytest.raises(openai.BadRequestError): - await client.audio.transcriptions.create( - model=model_name, - file=mary_had_lamb, - language="hh", - response_format="text", - temperature=0.0) - - # TODO audio too long + await client.audio.transcriptions.create(model=model_name, + file=mary_had_lamb, + language="hh", + temperature=0.0) + + # Expect audio too long: repeat the timeseries + mary_had_lamb.seek(0) + audio, sr = librosa.load(mary_had_lamb) + repeated_audio = np.tile(audio, 10) + # Repeated audio to buffer + buffer = io.BytesIO() + sf.write(buffer, repeated_audio, sr, format='WAV') + buffer.seek(0) + with pytest.raises(openai.BadRequestError): + await client.audio.transcriptions.create(model=model_name, + file=buffer, + language="en", + temperature=0.0) diff --git a/vllm/entrypoints/openai/serving_transcription.py b/vllm/entrypoints/openai/serving_transcription.py index 2a84a4692cc6..fc82b480b7f9 100644 --- a/vllm/entrypoints/openai/serving_transcription.py +++ b/vllm/entrypoints/openai/serving_transcription.py @@ -94,7 +94,7 @@ "yo": "Yoruba", "sa": "Sanskrit", "mi": "Māori", - "fo": "Faroese", + "fo": "Faroese", # codespell:ignore "mt": "Maltese", "tg": "Tajik", "mg": "Malagasy", @@ -189,8 +189,8 @@ async def _preprocess_transcription( y, sr = librosa.load(io.BytesIO(audio_data)) if librosa.get_duration(y=y, sr=sr) > MAX_AUDIO_CLIP_DURATION_S: raise ValueError( - f"Maximum clip duration ({MAX_AUDIO_CLIP_DURATION_S}s)\ - exceeded.") + f"Maximum clip duration ({MAX_AUDIO_CLIP_DURATION_S}s) " + "exceeded.") return { "encoder_prompt": { From c3511a182c6ea542aeb37c0a218316dd636285b8 Mon Sep 17 00:00:00 2001 From: NickLucche Date: Fri, 7 Feb 2025 14:13:48 +0000 Subject: [PATCH 06/27] docs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: NickLucche Signed-off-by: Daniele Trifirò --- .../serving/openai_compatible_server.md | 11 +++++++++ .../openai_transcription_client.py | 23 +++++++++++++++++++ 2 files changed, 34 insertions(+) create mode 100644 examples/online_serving/openai_transcription_client.py diff --git a/docs/source/serving/openai_compatible_server.md b/docs/source/serving/openai_compatible_server.md index 82ef54c16daf..68f3a177d211 100644 --- a/docs/source/serving/openai_compatible_server.md +++ b/docs/source/serving/openai_compatible_server.md @@ -41,6 +41,8 @@ We currently support the 
following OpenAI APIs: - *Note: `parallel_tool_calls` and `user` parameters are ignored.* - [Embeddings API](#embeddings-api) (`/v1/embeddings`) - Only applicable to [embedding models](../models/pooling_models.md) (`--task embed`). +- [Transcriptions API](#transcriptions-api) (`/v1/audio/transcriptions`) + - Only applicable to Automatic Speech Recognition (ASR) models (OpenAI Whisper) (`--task generate`). In addition, we have the following custom APIs: @@ -298,6 +300,15 @@ For chat-like input (i.e. if `messages` is passed), these extra parameters are s (tokenizer-api)= +### Transcriptions API + +Our Transcriptions API is compatible with [OpenAI's Transcriptions API](https://platform.openai.com/docs/api-reference/audio/createTranscription); +you can use the [official OpenAI Python client](https://github.com/openai/openai-python) to interact with it. + + + +Code example: + ### Tokenizer API Our Tokenizer API is a simple wrapper over [HuggingFace-style tokenizers](https://huggingface.co/docs/transformers/en/main_classes/tokenizer). diff --git a/examples/online_serving/openai_transcription_client.py b/examples/online_serving/openai_transcription_client.py new file mode 100644 index 000000000000..3e50b5d52575 --- /dev/null +++ b/examples/online_serving/openai_transcription_client.py @@ -0,0 +1,23 @@ +# SPDX-License-Identifier: Apache-2.0 +from openai import OpenAI + +from vllm.assets.audio import AudioAsset + +mary_had_lamb = AudioAsset('mary_had_lamb').get_asset_path() +winning_call = AudioAsset('winning_call').get_asset_path() + +# Modify OpenAI's API key and API base to use vLLM's API server. +openai_api_key = "EMPTY" +openai_api_base = "http://localhost:8000/v1" +client = OpenAI( + api_key=openai_api_key, + base_url=openai_api_base, +) +with open(str(mary_had_lamb), "rb") as f: + transcription = client.audio.transcriptions.create( + file=f, + model="openai/whisper-large-v3", + language="en", + response_format="text", + temperature=0.0) + print("transcription result:", transcription) From 806a07e7ff46fcada391c6a9bdc2c431c5a857a3 Mon Sep 17 00:00:00 2001 From: NickLucche Date: Fri, 7 Feb 2025 14:23:35 +0000 Subject: [PATCH 07/27] cleanup MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: NickLucche Signed-off-by: Daniele Trifirò --- query_transcription.py | 22 ---------------------- 1 file changed, 22 deletions(-) delete mode 100644 query_transcription.py diff --git a/query_transcription.py b/query_transcription.py deleted file mode 100644 index b0c9c95c8818..000000000000 --- a/query_transcription.py +++ /dev/null @@ -1,22 +0,0 @@ -from openai import OpenAI - -from vllm.assets.audio import AudioAsset - -mary_had_lamb = AudioAsset('mary_had_lamb').get_asset_path() -winning_call = AudioAsset('winning_call').get_asset_path() - -# Modify OpenAI's API key and API base to use vLLM's API server. 
-openai_api_key = "EMPTY" -openai_api_base = "http://localhost:8000/v1" -client = OpenAI( - api_key=openai_api_key, - base_url=openai_api_base, -) -with open(str(mary_had_lamb), "rb") as f: - transcription = client.audio.transcriptions.create( - file=f, - model="openai/whisper-large-v3", - language="en", - response_format="text", - temperature=0.0) - print("transcription result:", transcription) From 8a7d4d8505493d64f6fd162eadaaf94b5502ecd9 Mon Sep 17 00:00:00 2001 From: NickLucche Date: Fri, 7 Feb 2025 15:58:22 +0000 Subject: [PATCH 08/27] CI correctness tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: NickLucche Signed-off-by: Daniele Trifirò --- .buildkite/asr-eval/run-tests.sh | 16 ++ .../test_transcription_api_correctness.py | 208 ++++++++++++++++++ .buildkite/test-pipeline.yaml | 10 + requirements-test.in | 1 + 4 files changed, 235 insertions(+) create mode 100644 .buildkite/asr-eval/run-tests.sh create mode 100644 .buildkite/asr-eval/test_transcription_api_correctness.py diff --git a/.buildkite/asr-eval/run-tests.sh b/.buildkite/asr-eval/run-tests.sh new file mode 100644 index 000000000000..9cbaf9f7c3fd --- /dev/null +++ b/.buildkite/asr-eval/run-tests.sh @@ -0,0 +1,16 @@ +#!/bin/bash +set -x + +# Start server +python3 -m vllm.entrypoints.openai.api_server --model openai/whisper-large-v3 $@ & +server_pid=$! + +# Wait for server to start, timeout after 600 seconds +timeout 180 bash -c 'until curl localhost:8000/v1/models; do sleep 4; done' || exit 1 + +# NOTE: Expected WER measured with hf.transformers equivalent model on same dataset. +# Original dataset split is about 23GB in size, hence we use a pre-filtered slice. +python test_transcription_api_correctness.py -m openai/whisper-large-v3 -dr D4nt3/esb-datasets-earnings22-validation-tiny-filtered --expected-wer 12.744980 + +# Wait for graceful exit +kill $server_pid diff --git a/.buildkite/asr-eval/test_transcription_api_correctness.py b/.buildkite/asr-eval/test_transcription_api_correctness.py new file mode 100644 index 000000000000..f3fb98527e3d --- /dev/null +++ b/.buildkite/asr-eval/test_transcription_api_correctness.py @@ -0,0 +1,208 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Evaluate Transcription API correctness by computing Word Error Rate (WER) +on a given ASR dataset. When provided, it will also compare the WER against +a baseline. 
+""" +import asyncio +import io +import time +from argparse import ArgumentParser +from statistics import mean, median +from typing import List, Optional + +import librosa +import soundfile +import torch +from datasets import load_dataset +from evaluate import load +from openai import AsyncOpenAI +from transformers import AutoTokenizer + +openai_api_base = "http://localhost:8000/v1" +client = AsyncOpenAI(api_key="EMPTY", base_url=openai_api_base) + + +def to_bytes(y, sr): + buffer = io.BytesIO() + soundfile.write(buffer, y, sr, format="WAV") + buffer.seek(0) + return buffer + + +async def transcribe_audio(client, tokenizer, y, sr): + status = 200 + try: + # Send loaded audio directly instead of loading from disk, + # dont account for that time though + with to_bytes(y, sr) as f: + start_time = time.perf_counter() + transcription = await client.audio.transcriptions.create( + file=f, + model=tokenizer.name_or_path, + language="en", + temperature=0.0, + ) + end_time = time.perf_counter() + # NOTE there's no streaming in transcriptions, can't measure ttft + except Exception as e: + print(f"Error: {e}") + status = 500 + # Hard check on server working properly + assert status == 200 + latency = end_time - start_time + num_output_tokens = len( + tokenizer(transcription.text, add_special_tokens=False).input_ids) + return latency, num_output_tokens, transcription.text + + +async def bound_transcribe(model_name, sem, client, audio, reference): + tokenizer = AutoTokenizer.from_pretrained(model_name) + # Use semaphore to limit concurrent requests. + async with sem: + result = await transcribe_audio(client, tokenizer, *audio) + # Normalize *english* output/reference for evaluation. + out = tokenizer.normalize(result[2]) + ref = tokenizer.normalize(reference) + return result[:2] + (out, ref) + + +async def process_dataset(model, data, concurrent_request): + sem = asyncio.Semaphore(concurrent_request) + tasks: List[asyncio.Task] = [] + for sample in data: + audio, sr = sample["audio"]["array"], sample["audio"]["sampling_rate"] + task = asyncio.create_task( + bound_transcribe(model, sem, client, (audio, sr), sample["text"])) + tasks.append(task) + return await asyncio.gather(*tasks) + + +def print_performance_metrics(results, total_time): + latencies = [res[0] for res in results] + total_tokens = sum([res[1] for res in results]) + + total = len(results) + print(f"Total Requests: {total}") + print(f"Successful Requests: {len(latencies)}") + print(f"Average Latency: {mean(latencies):.4f} seconds") + print(f"Median Latency: {median(latencies):.4f} seconds") + perc = sorted(latencies)[int(len(latencies) * 0.95) - 1] + print(f"95th Percentile Latency: {perc:.4f} seconds") + # Throughput + req_throughput = len(latencies) / total_time + print(f"Estimated req_Throughput: {req_throughput:.2f} requests/s") + throughput = total_tokens / total_time + print(f"Estimated Throughput: {throughput:.2f} tok/s") + + +def add_duration(sample): + y, sr = sample['audio']["array"], sample['audio']["sampling_rate"] + sample['duration_ms'] = librosa.get_duration(y=y, sr=sr) * 1000 + return sample + + +def load_hf_dataset(dataset_repo: str, + dataset_name: str, + split='validation', + **hf_kwargs): + ## Load and filter the dataset + dataset = load_dataset(dataset_repo, + dataset_name, + split=split, + **hf_kwargs) + if 'duration_ms' not in dataset[0]: + # compute duration to filter + dataset = dataset.map(add_duration) + + # Whisper max supported duration + dataset = dataset.filter(lambda example: example['duration_ms'] < 30000) + 
return dataset + + +def run_evaluation(model: str, + dataset, + n_examples: int = -1, + max_concurrent_reqs: Optional[int] = None, + print_metrics: bool = True): + if n_examples > 0: + dataset = dataset.select(range(n_examples)) + + # Warmup + _ = asyncio.run( + process_dataset(model, dataset.select(range(1)), max_concurrent_reqs)) + + start = time.perf_counter() + results = asyncio.run(process_dataset(model, dataset, max_concurrent_reqs)) + end = time.perf_counter() + total_time = end - start + print(f"Total Test Time: {total_time:.4f} seconds") + if print_metrics: + print_performance_metrics(results, total_time) + # Compute WER + predictions = [res[2] for res in results] + references = [res[3] for res in results] + wer = load("wer") + wer_score = 100 * wer.compute(references=references, + predictions=predictions) + print("WER:", wer_score) + return wer_score + + +if __name__ == "__main__": + args = ArgumentParser() + # alternatives "openai/whisper-large-v2", "openai/whisper-large-v3-turbo".. + args.add_argument("-m", + "--model-name", + type=str, + help="Name of the ASR model to evaluate.", + default="openai/whisper-large-v3") + args.add_argument("-dr", + "--dataset-repo", + type=str, + help="Path/repo of the hf asr dataset to test on.") + args.add_argument("-dn", + "--dataset-name", + type=str, + help="Name of the hf asr dataset to test on.") + args.add_argument("--n-examples", + type=int, + help="Limit the number of examples to evaluate on.", + default=-1) + args.add_argument( + "--max-concurrent-request", + type=int, + help="Limit the number of requests sent to the server at the same time" + ) + args.add_argument("--expected-wer", + type=float, + help="Expected WER to compare against.") + args.add_argument( + "--extra", + nargs="*", + help="Extra keyword arguments (key=value pairs) to be passed " + "to hf `load_dataset`") + args = args.parse_args() + + extra_kwargs = {} + if args.extra: + for item in args.extra: + key, value = item.split("=", 1) + extra_kwargs[key] = value + + print("Running evaluation with args", vars(args)) + dataset = load_hf_dataset(args.dataset_repo, args.dataset_name, + **extra_kwargs) + + if not args.max_concurrent_request: + # No max concurrency + args.max_concurrent_request = args.n_examples if args.n_examples > 0\ + else len(dataset) + + wer = run_evaluation(args.model_name, dataset, args.n_examples, + args.max_concurrent_request) + if args.expected_wer: + torch.testing.assert_close(wer, + args.expected_wer, + atol=1e-1, + rtol=1e-2) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index e26b1bf3818e..271a8053a048 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -339,6 +339,16 @@ steps: - export VLLM_WORKER_MULTIPROC_METHOD=spawn - bash ./run-tests.sh -c configs/models-small.txt -t 1 +- label: Transcription API correctness + working_dir: "/vllm-workspace/.buildkite/asr-eval" + source_file_dependencies: + - csrc/ + - vllm/entrypoints/openai/serving_transcription.py + - vllm/model_executor/models/whisper.py + commands: + - export VLLM_WORKER_MULTIPROC_METHOD=spawn + - bash ./run-tests.sh + - label: Encoder Decoder tests # 5min source_file_dependencies: - vllm/ diff --git a/requirements-test.in b/requirements-test.in index 229d743ec802..ecf874ecc50f 100644 --- a/requirements-test.in +++ b/requirements-test.in @@ -19,6 +19,7 @@ pqdm ray[adag]==2.40.0 sentence-transformers # required for embedding tests soundfile # required for audio tests +jiwer # required for audio tests timm # required for internvl 
test torch==2.5.1 torchaudio==2.5.1 From 4ac9f43ccf574856850e4af16e49651a4a0412a5 Mon Sep 17 00:00:00 2001 From: NickLucche Date: Fri, 7 Feb 2025 16:32:10 +0000 Subject: [PATCH 09/27] clean up MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: NickLucche Signed-off-by: Daniele Trifirò --- tests/entrypoints/openai/test_audio.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/entrypoints/openai/test_audio.py b/tests/entrypoints/openai/test_audio.py index f85a1ae45888..fe7299a48e6f 100644 --- a/tests/entrypoints/openai/test_audio.py +++ b/tests/entrypoints/openai/test_audio.py @@ -17,7 +17,6 @@ ] -# TODO rename to test chat audio or multimodal @pytest.fixture(scope="module") def server(): args = [ From 7ef7a916a0ad503f41c4c07e2534543571cea257 Mon Sep 17 00:00:00 2001 From: NickLucche Date: Fri, 7 Feb 2025 17:55:06 +0000 Subject: [PATCH 10/27] rebase leftovers MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: NickLucche Signed-off-by: Daniele Trifirò --- vllm/entrypoints/openai/api_server.py | 2 ++ vllm/entrypoints/openai/protocol.py | 5 +++-- vllm/entrypoints/openai/serving_transcription.py | 1 + 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index fe5c37e70244..f0205cec193a 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -62,6 +62,8 @@ ScoreRequest, ScoreResponse, TokenizeRequest, TokenizeResponse, + TranscriptionRequest, + TranscriptionResponse, UnloadLoraAdapterRequest) from vllm.entrypoints.openai.reasoning_parsers import ReasoningParserManager # yapf: enable diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 0427c542b199..af1cb8ad1bf3 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -5,12 +5,13 @@ import re import time from argparse import Namespace -from typing import Any, Dict, List, Literal, Optional, TypeAlias, Union +from typing import (Any, ClassVar, Dict, List, Literal, Optional, Set, + TypeAlias, Union) import torch +from fastapi import UploadFile from pydantic import (BaseModel, ConfigDict, Field, TypeAdapter, ValidationInfo, field_validator, model_validator) -from fastapi import UploadFile from typing_extensions import Annotated from vllm.entrypoints.chat_utils import ChatCompletionMessageParam diff --git a/vllm/entrypoints/openai/serving_transcription.py b/vllm/entrypoints/openai/serving_transcription.py index fc82b480b7f9..ca1cc147ad4e 100644 --- a/vllm/entrypoints/openai/serving_transcription.py +++ b/vllm/entrypoints/openai/serving_transcription.py @@ -1,3 +1,4 @@ +# SPDX-License-Identifier: Apache-2.0 import asyncio import io from typing import Any, AsyncGenerator, Dict, Optional, Union From 49165ad2c8e43e9e730f1af890d402c17ee66e78 Mon Sep 17 00:00:00 2001 From: NickLucche Date: Mon, 10 Feb 2025 09:51:10 +0000 Subject: [PATCH 11/27] move correctness to tests to use remoteopenaiserver MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: NickLucche Signed-off-by: Daniele Trifirò --- .buildkite/asr-eval/run-tests.sh | 16 -- .../test_transcription_api_correctness.py | 208 ------------------ .buildkite/test-pipeline.yaml | 4 +- .../openai/test_transcription_validation.py | 18 +- 4 files changed, 17 insertions(+), 229 deletions(-) delete mode 100644 .buildkite/asr-eval/run-tests.sh delete 
mode 100644 .buildkite/asr-eval/test_transcription_api_correctness.py diff --git a/.buildkite/asr-eval/run-tests.sh b/.buildkite/asr-eval/run-tests.sh deleted file mode 100644 index 9cbaf9f7c3fd..000000000000 --- a/.buildkite/asr-eval/run-tests.sh +++ /dev/null @@ -1,16 +0,0 @@ -#!/bin/bash -set -x - -# Start server -python3 -m vllm.entrypoints.openai.api_server --model openai/whisper-large-v3 $@ & -server_pid=$! - -# Wait for server to start, timeout after 600 seconds -timeout 180 bash -c 'until curl localhost:8000/v1/models; do sleep 4; done' || exit 1 - -# NOTE: Expected WER measured with hf.transformers equivalent model on same dataset. -# Original dataset split is about 23GB in size, hence we use a pre-filtered slice. -python test_transcription_api_correctness.py -m openai/whisper-large-v3 -dr D4nt3/esb-datasets-earnings22-validation-tiny-filtered --expected-wer 12.744980 - -# Wait for graceful exit -kill $server_pid diff --git a/.buildkite/asr-eval/test_transcription_api_correctness.py b/.buildkite/asr-eval/test_transcription_api_correctness.py deleted file mode 100644 index f3fb98527e3d..000000000000 --- a/.buildkite/asr-eval/test_transcription_api_correctness.py +++ /dev/null @@ -1,208 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -""" -Evaluate Transcription API correctness by computing Word Error Rate (WER) -on a given ASR dataset. When provided, it will also compare the WER against -a baseline. -""" -import asyncio -import io -import time -from argparse import ArgumentParser -from statistics import mean, median -from typing import List, Optional - -import librosa -import soundfile -import torch -from datasets import load_dataset -from evaluate import load -from openai import AsyncOpenAI -from transformers import AutoTokenizer - -openai_api_base = "http://localhost:8000/v1" -client = AsyncOpenAI(api_key="EMPTY", base_url=openai_api_base) - - -def to_bytes(y, sr): - buffer = io.BytesIO() - soundfile.write(buffer, y, sr, format="WAV") - buffer.seek(0) - return buffer - - -async def transcribe_audio(client, tokenizer, y, sr): - status = 200 - try: - # Send loaded audio directly instead of loading from disk, - # dont account for that time though - with to_bytes(y, sr) as f: - start_time = time.perf_counter() - transcription = await client.audio.transcriptions.create( - file=f, - model=tokenizer.name_or_path, - language="en", - temperature=0.0, - ) - end_time = time.perf_counter() - # NOTE there's no streaming in transcriptions, can't measure ttft - except Exception as e: - print(f"Error: {e}") - status = 500 - # Hard check on server working properly - assert status == 200 - latency = end_time - start_time - num_output_tokens = len( - tokenizer(transcription.text, add_special_tokens=False).input_ids) - return latency, num_output_tokens, transcription.text - - -async def bound_transcribe(model_name, sem, client, audio, reference): - tokenizer = AutoTokenizer.from_pretrained(model_name) - # Use semaphore to limit concurrent requests. - async with sem: - result = await transcribe_audio(client, tokenizer, *audio) - # Normalize *english* output/reference for evaluation. 
- out = tokenizer.normalize(result[2]) - ref = tokenizer.normalize(reference) - return result[:2] + (out, ref) - - -async def process_dataset(model, data, concurrent_request): - sem = asyncio.Semaphore(concurrent_request) - tasks: List[asyncio.Task] = [] - for sample in data: - audio, sr = sample["audio"]["array"], sample["audio"]["sampling_rate"] - task = asyncio.create_task( - bound_transcribe(model, sem, client, (audio, sr), sample["text"])) - tasks.append(task) - return await asyncio.gather(*tasks) - - -def print_performance_metrics(results, total_time): - latencies = [res[0] for res in results] - total_tokens = sum([res[1] for res in results]) - - total = len(results) - print(f"Total Requests: {total}") - print(f"Successful Requests: {len(latencies)}") - print(f"Average Latency: {mean(latencies):.4f} seconds") - print(f"Median Latency: {median(latencies):.4f} seconds") - perc = sorted(latencies)[int(len(latencies) * 0.95) - 1] - print(f"95th Percentile Latency: {perc:.4f} seconds") - # Throughput - req_throughput = len(latencies) / total_time - print(f"Estimated req_Throughput: {req_throughput:.2f} requests/s") - throughput = total_tokens / total_time - print(f"Estimated Throughput: {throughput:.2f} tok/s") - - -def add_duration(sample): - y, sr = sample['audio']["array"], sample['audio']["sampling_rate"] - sample['duration_ms'] = librosa.get_duration(y=y, sr=sr) * 1000 - return sample - - -def load_hf_dataset(dataset_repo: str, - dataset_name: str, - split='validation', - **hf_kwargs): - ## Load and filter the dataset - dataset = load_dataset(dataset_repo, - dataset_name, - split=split, - **hf_kwargs) - if 'duration_ms' not in dataset[0]: - # compute duration to filter - dataset = dataset.map(add_duration) - - # Whisper max supported duration - dataset = dataset.filter(lambda example: example['duration_ms'] < 30000) - return dataset - - -def run_evaluation(model: str, - dataset, - n_examples: int = -1, - max_concurrent_reqs: Optional[int] = None, - print_metrics: bool = True): - if n_examples > 0: - dataset = dataset.select(range(n_examples)) - - # Warmup - _ = asyncio.run( - process_dataset(model, dataset.select(range(1)), max_concurrent_reqs)) - - start = time.perf_counter() - results = asyncio.run(process_dataset(model, dataset, max_concurrent_reqs)) - end = time.perf_counter() - total_time = end - start - print(f"Total Test Time: {total_time:.4f} seconds") - if print_metrics: - print_performance_metrics(results, total_time) - # Compute WER - predictions = [res[2] for res in results] - references = [res[3] for res in results] - wer = load("wer") - wer_score = 100 * wer.compute(references=references, - predictions=predictions) - print("WER:", wer_score) - return wer_score - - -if __name__ == "__main__": - args = ArgumentParser() - # alternatives "openai/whisper-large-v2", "openai/whisper-large-v3-turbo".. 
- args.add_argument("-m", - "--model-name", - type=str, - help="Name of the ASR model to evaluate.", - default="openai/whisper-large-v3") - args.add_argument("-dr", - "--dataset-repo", - type=str, - help="Path/repo of the hf asr dataset to test on.") - args.add_argument("-dn", - "--dataset-name", - type=str, - help="Name of the hf asr dataset to test on.") - args.add_argument("--n-examples", - type=int, - help="Limit the number of examples to evaluate on.", - default=-1) - args.add_argument( - "--max-concurrent-request", - type=int, - help="Limit the number of requests sent to the server at the same time" - ) - args.add_argument("--expected-wer", - type=float, - help="Expected WER to compare against.") - args.add_argument( - "--extra", - nargs="*", - help="Extra keyword arguments (key=value pairs) to be passed " - "to hf `load_dataset`") - args = args.parse_args() - - extra_kwargs = {} - if args.extra: - for item in args.extra: - key, value = item.split("=", 1) - extra_kwargs[key] = value - - print("Running evaluation with args", vars(args)) - dataset = load_hf_dataset(args.dataset_repo, args.dataset_name, - **extra_kwargs) - - if not args.max_concurrent_request: - # No max concurrency - args.max_concurrent_request = args.n_examples if args.n_examples > 0\ - else len(dataset) - - wer = run_evaluation(args.model_name, dataset, args.n_examples, - args.max_concurrent_request) - if args.expected_wer: - torch.testing.assert_close(wer, - args.expected_wer, - atol=1e-1, - rtol=1e-2) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 271a8053a048..2044b8523a81 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -340,14 +340,12 @@ steps: - bash ./run-tests.sh -c configs/models-small.txt -t 1 - label: Transcription API correctness - working_dir: "/vllm-workspace/.buildkite/asr-eval" source_file_dependencies: - csrc/ - vllm/entrypoints/openai/serving_transcription.py - vllm/model_executor/models/whisper.py commands: - - export VLLM_WORKER_MULTIPROC_METHOD=spawn - - bash ./run-tests.sh + - pytest -s entrypoints/openai/test_transcription_api_correctness.py - label: Encoder Decoder tests # 5min source_file_dependencies: diff --git a/tests/entrypoints/openai/test_transcription_validation.py b/tests/entrypoints/openai/test_transcription_validation.py index 98a1a8d74940..1cf72be7ff41 100644 --- a/tests/entrypoints/openai/test_transcription_validation.py +++ b/tests/entrypoints/openai/test_transcription_validation.py @@ -17,14 +17,14 @@ @pytest.fixture def mary_had_lamb(): - path = AudioAsset('mary_had_lamb').get_asset_path() + path = AudioAsset('mary_had_lamb').get_local_path() with open(str(path), "rb") as f: yield f @pytest.fixture def winning_call(): - path = AudioAsset('winning_call').get_asset_path() + path = AudioAsset('winning_call').get_local_path() with open(str(path), "rb") as f: yield f @@ -84,3 +84,17 @@ async def test_bad_requests(mary_had_lamb): file=buffer, language="en", temperature=0.0) + + +@pytest.mark.asyncio +async def test_non_asr_model(winning_call): + # text to text model + model_name = "JackFram/llama-68m" + server_args = ["--enforce-eager"] + with RemoteOpenAIServer(model_name, server_args) as remote_server: + client = remote_server.get_async_client() + # with pytest.raises(openai.BadRequestError): + await client.audio.transcriptions.create(model=model_name, + file=winning_call, + language="hh", + temperature=0.0) From 1ecb9b4d842d4c7ab5d995541b904e3b2ecd2002 Mon Sep 17 00:00:00 2001 From: NickLucche Date: Mon, 10 Feb 
2025 10:04:39 +0000 Subject: [PATCH 12/27] optional librosa import; get_local_path; transcription endpoint only for whisper MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: NickLucche Signed-off-by: Daniele Trifirò --- .../openai_transcription_client.py | 4 +- .../test_transcription_api_correctness.py | 164 ++++++++++++++++++ .../openai/test_transcription_validation.py | 11 +- vllm/assets/audio.py | 2 +- vllm/entrypoints/openai/api_server.py | 3 +- .../openai/serving_transcription.py | 7 +- 6 files changed, 181 insertions(+), 10 deletions(-) create mode 100644 tests/entrypoints/openai/test_transcription_api_correctness.py diff --git a/examples/online_serving/openai_transcription_client.py b/examples/online_serving/openai_transcription_client.py index 3e50b5d52575..bd3c02a8a95e 100644 --- a/examples/online_serving/openai_transcription_client.py +++ b/examples/online_serving/openai_transcription_client.py @@ -3,8 +3,8 @@ from vllm.assets.audio import AudioAsset -mary_had_lamb = AudioAsset('mary_had_lamb').get_asset_path() -winning_call = AudioAsset('winning_call').get_asset_path() +mary_had_lamb = AudioAsset('mary_had_lamb').get_local_path() +winning_call = AudioAsset('winning_call').get_local_path() # Modify OpenAI's API key and API base to use vLLM's API server. openai_api_key = "EMPTY" diff --git a/tests/entrypoints/openai/test_transcription_api_correctness.py b/tests/entrypoints/openai/test_transcription_api_correctness.py new file mode 100644 index 000000000000..a895ad0fd42f --- /dev/null +++ b/tests/entrypoints/openai/test_transcription_api_correctness.py @@ -0,0 +1,164 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Evaluate Transcription API correctness by computing Word Error Rate (WER) +on a given ASR dataset. When provided, it will also compare the WER against +a baseline. +""" +import asyncio +import io +import time +from statistics import mean, median +from typing import List + +import librosa +import pytest +import soundfile +import torch +from datasets import load_dataset +from evaluate import load +from transformers import AutoTokenizer + +from ...utils import RemoteOpenAIServer + + +def to_bytes(y, sr): + buffer = io.BytesIO() + soundfile.write(buffer, y, sr, format="WAV") + buffer.seek(0) + return buffer + + +async def transcribe_audio(client, tokenizer, y, sr): + # Send loaded audio directly instead of loading from disk, + # dont account for that time though + with to_bytes(y, sr) as f: + start_time = time.perf_counter() + transcription = await client.audio.transcriptions.create( + file=f, + model=tokenizer.name_or_path, + language="en", + temperature=0.0, + ) + end_time = time.perf_counter() + # NOTE there's no streaming in transcriptions, can't measure ttft + latency = end_time - start_time + num_output_tokens = len( + tokenizer(transcription.text, add_special_tokens=False).input_ids) + return latency, num_output_tokens, transcription.text + + +async def bound_transcribe(model_name, sem, client, audio, reference): + tokenizer = AutoTokenizer.from_pretrained(model_name) + # Use semaphore to limit concurrent requests. + async with sem: + result = await transcribe_audio(client, tokenizer, *audio) + # Normalize *english* output/reference for evaluation. 
+ out = tokenizer.normalize(result[2]) + ref = tokenizer.normalize(reference) + return result[:2] + (out, ref) + + +async def process_dataset(model, client, data, concurrent_request): + sem = asyncio.Semaphore(concurrent_request) + + # Warmup call as the first `librosa.load` server-side is quite slow. + audio, sr = data[0]["audio"]["array"], data[0]["audio"]["sampling_rate"] + _ = await bound_transcribe(model, sem, client, (audio, sr), "") + + tasks: List[asyncio.Task] = [] + for sample in data: + audio, sr = sample["audio"]["array"], sample["audio"]["sampling_rate"] + task = asyncio.create_task( + bound_transcribe(model, sem, client, (audio, sr), sample["text"])) + tasks.append(task) + return await asyncio.gather(*tasks) + + +def print_performance_metrics(results, total_time): + latencies = [res[0] for res in results] + total_tokens = sum([res[1] for res in results]) + + total = len(results) + print(f"Total Requests: {total}") + print(f"Successful Requests: {len(latencies)}") + print(f"Average Latency: {mean(latencies):.4f} seconds") + print(f"Median Latency: {median(latencies):.4f} seconds") + perc = sorted(latencies)[int(len(latencies) * 0.95) - 1] + print(f"95th Percentile Latency: {perc:.4f} seconds") + # Throughput + req_throughput = len(latencies) / total_time + print(f"Estimated req_Throughput: {req_throughput:.2f} requests/s") + throughput = total_tokens / total_time + print(f"Estimated Throughput: {throughput:.2f} tok/s") + + +def add_duration(sample): + y, sr = sample['audio']["array"], sample['audio']["sampling_rate"] + sample['duration_ms'] = librosa.get_duration(y=y, sr=sr) * 1000 + return sample + + +def load_hf_dataset(dataset_repo: str, split='validation', **hf_kwargs): + ## Load and filter the dataset + dataset = load_dataset(dataset_repo, split=split, **hf_kwargs) + if 'duration_ms' not in dataset[0]: + # compute duration to filter + dataset = dataset.map(add_duration) + + # Whisper max supported duration + dataset = dataset.filter(lambda example: example['duration_ms'] < 30000) + return dataset + + +def run_evaluation(model: str, + client, + dataset, + max_concurrent_reqs: int, + n_examples: int = -1, + print_metrics: bool = True): + if n_examples > 0: + dataset = dataset.select(range(n_examples)) + start = time.perf_counter() + results = asyncio.run( + process_dataset(model, client, dataset, max_concurrent_reqs)) + end = time.perf_counter() + total_time = end - start + print(f"Total Test Time: {total_time:.4f} seconds") + if print_metrics: + print_performance_metrics(results, total_time) + # Compute WER + predictions = [res[2] for res in results] + references = [res[3] for res in results] + wer = load("wer") + wer_score = 100 * wer.compute(references=references, + predictions=predictions) + print("WER:", wer_score) + return wer_score + + +# alternatives "openai/whisper-large-v2", "openai/whisper-large-v3-turbo".. +@pytest.mark.parametrize("model_name", ["openai/whisper-large-v3"]) +# Original dataset is 20GB+ in size, hence we use a pre-filtered slice. +@pytest.mark.parametrize( + "dataset_repo", ["D4nt3/esb-datasets-earnings22-validation-tiny-filtered"]) +# NOTE: Expected WER measured with equivalent hf.transformers args: +# whisper-large-v3 + esb-datasets-earnings22-validation-tiny-filtered. 
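The reference number below can be reproduced outside vLLM with a plain hf.transformers pipeline on the same pre-filtered split; the following is only a minimal sketch of such a baseline run (the generation arguments are an assumption, only the model name, dataset repo and normalization mirror the test in this file).

from datasets import load_dataset
from evaluate import load
from transformers import AutoTokenizer, pipeline

model_id = "openai/whisper-large-v3"
asr = pipeline("automatic-speech-recognition", model=model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
data = load_dataset("D4nt3/esb-datasets-earnings22-validation-tiny-filtered",
                    split="validation")

predictions, references = [], []
for sample in data:
    audio = sample["audio"]
    text = asr({"array": audio["array"],
                "sampling_rate": audio["sampling_rate"]},
               generate_kwargs={"language": "en", "temperature": 0.0})["text"]
    # Same English normalization as the test, so WER is not inflated by
    # casing or punctuation differences.
    predictions.append(tokenizer.normalize(text))
    references.append(tokenizer.normalize(sample["text"]))

print("baseline WER:",
      100 * load("wer").compute(references=references,
                                predictions=predictions))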
+@pytest.mark.parametrize("expected_wer", [12.744980]) +def test_wer_correctness(model_name, + dataset_repo, + expected_wer, + n_examples=-1, + max_concurrent_request=None): + with RemoteOpenAIServer(model_name, ['--enforce-eager']) as remote_server: + dataset = load_hf_dataset(dataset_repo) + + if not max_concurrent_request: + # No max concurrency + max_concurrent_request = n_examples if n_examples > 0\ + else len(dataset) + + client = remote_server.get_async_client() + wer = run_evaluation(model_name, client, dataset, + max_concurrent_request, n_examples) + if expected_wer: + torch.testing.assert_close(wer, expected_wer, atol=1e-1, rtol=1e-2) diff --git a/tests/entrypoints/openai/test_transcription_validation.py b/tests/entrypoints/openai/test_transcription_validation.py index 1cf72be7ff41..aac53c4ef0ce 100644 --- a/tests/entrypoints/openai/test_transcription_validation.py +++ b/tests/entrypoints/openai/test_transcription_validation.py @@ -93,8 +93,9 @@ async def test_non_asr_model(winning_call): server_args = ["--enforce-eager"] with RemoteOpenAIServer(model_name, server_args) as remote_server: client = remote_server.get_async_client() - # with pytest.raises(openai.BadRequestError): - await client.audio.transcriptions.create(model=model_name, - file=winning_call, - language="hh", - temperature=0.0) + res = await client.audio.transcriptions.create(model=model_name, + file=winning_call, + language="en", + temperature=0.0) + assert res.code == 400 and res.text == "" + assert res.message == "The model does not support Transcriptions API" diff --git a/vllm/assets/audio.py b/vllm/assets/audio.py index c465527b4e89..0203dc092a71 100644 --- a/vllm/assets/audio.py +++ b/vllm/assets/audio.py @@ -29,7 +29,7 @@ def audio_and_sample_rate(self) -> tuple[npt.NDArray, float]: s3_prefix=ASSET_DIR) return librosa.load(audio_path, sr=None) - def get_asset_path(self) -> Path: + def get_local_path(self) -> Path: return get_vllm_public_assets(filename=f"{self.name}.ogg", s3_prefix=ASSET_DIR) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index f0205cec193a..c381b23c18da 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -860,7 +860,8 @@ async def init_app_state( model_config, state.openai_serving_models, request_logger=request_logger, - ) + ) if ("WhisperForConditionalGeneration" + in model_config.hf_config.architectures) else None state.task = model_config.task diff --git a/vllm/entrypoints/openai/serving_transcription.py b/vllm/entrypoints/openai/serving_transcription.py index ca1cc147ad4e..f27b901c4814 100644 --- a/vllm/entrypoints/openai/serving_transcription.py +++ b/vllm/entrypoints/openai/serving_transcription.py @@ -3,7 +3,6 @@ import io from typing import Any, AsyncGenerator, Dict, Optional, Union -import librosa from fastapi import Request from vllm.config import ModelConfig @@ -17,6 +16,12 @@ from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.logger import init_logger from vllm.outputs import RequestOutput +from vllm.utils import PlaceholderModule + +try: + import librosa +except ImportError: + librosa = PlaceholderModule("librosa") # type: ignore[assignment] logger = init_logger(__name__) From d33647faf2fa8601cd2db946aad4c2637e1fbed5 Mon Sep 17 00:00:00 2001 From: NickLucche Date: Mon, 10 Feb 2025 11:37:21 +0000 Subject: [PATCH 13/27] fix docs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: NickLucche Signed-off-by: 
Daniele Trifirò --- docs/source/serving/openai_compatible_server.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/docs/source/serving/openai_compatible_server.md b/docs/source/serving/openai_compatible_server.md index 68f3a177d211..64439475fdb5 100644 --- a/docs/source/serving/openai_compatible_server.md +++ b/docs/source/serving/openai_compatible_server.md @@ -298,7 +298,7 @@ For chat-like input (i.e. if `messages` is passed), these extra parameters are s :end-before: end-chat-embedding-extra-params ::: -(tokenizer-api)= +(transcriptions-api)= ### Transcriptions API @@ -309,6 +309,8 @@ you can use the [official OpenAI Python client](https://github.com/openai/openai Code example: +(tokenizer-api)= + ### Tokenizer API Our Tokenizer API is a simple wrapper over [HuggingFace-style tokenizers](https://huggingface.co/docs/transformers/en/main_classes/tokenizer). From 5e27adbd18b2e32bf5569e4625e9951301e884ac Mon Sep 17 00:00:00 2001 From: NickLucche Date: Mon, 10 Feb 2025 12:35:53 +0000 Subject: [PATCH 14/27] fix multipart import issue MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: NickLucche Signed-off-by: Daniele Trifirò --- vllm/entrypoints/openai/api_server.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index c381b23c18da..c98256381c96 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -525,13 +525,12 @@ async def create_transcriptions(request: Annotated[TranscriptionRequest, Form()], raw_request: Request): - audio_data = await request.file.read() - handler = transcription(raw_request) if handler is None: return base(raw_request).create_error_response( message="The model does not support Transcriptions API") + audio_data = await request.file.read() generator = await handler.create_transcription(audio_data, request, raw_request) From e418f26048f19fb5db7c1bc80537fe50ad12ec42 Mon Sep 17 00:00:00 2001 From: NickLucche Date: Mon, 10 Feb 2025 12:49:40 +0000 Subject: [PATCH 15/27] group openai api correctness tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: NickLucche Signed-off-by: Daniele Trifirò --- .buildkite/test-pipeline.yaml | 10 +++++----- .../openai/{ => correctness}/test_accuracy.py | 0 .../test_transcription_api_correctness.py | 2 ++ 3 files changed, 7 insertions(+), 5 deletions(-) rename tests/entrypoints/openai/{ => correctness}/test_accuracy.py (100%) rename tests/entrypoints/openai/{ => correctness}/test_transcription_api_correctness.py (98%) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 2044b8523a81..8c12b01e7a5f 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -117,7 +117,7 @@ steps: - pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process - pytest -v -s entrypoints/llm/test_generate_multiple_loras.py # it needs a clean process - pytest -v -s entrypoints/llm/test_guided_generate.py # it needs a clean process - - pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_oot_registration.py + - pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/correctness/ - pytest -v -s entrypoints/test_chat_utils.py - pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests @@ -339,13 +339,13 @@ steps: - export 
VLLM_WORKER_MULTIPROC_METHOD=spawn - bash ./run-tests.sh -c configs/models-small.txt -t 1 -- label: Transcription API correctness +- label: OpenAI API correctness source_file_dependencies: - csrc/ - - vllm/entrypoints/openai/serving_transcription.py + - vllm/entrypoints/openai/ - vllm/model_executor/models/whisper.py - commands: - - pytest -s entrypoints/openai/test_transcription_api_correctness.py + commands: # LMEval+Transcription WER check + - pytest -s entrypoints/openai/correctness/ - label: Encoder Decoder tests # 5min source_file_dependencies: diff --git a/tests/entrypoints/openai/test_accuracy.py b/tests/entrypoints/openai/correctness/test_accuracy.py similarity index 100% rename from tests/entrypoints/openai/test_accuracy.py rename to tests/entrypoints/openai/correctness/test_accuracy.py diff --git a/tests/entrypoints/openai/test_transcription_api_correctness.py b/tests/entrypoints/openai/correctness/test_transcription_api_correctness.py similarity index 98% rename from tests/entrypoints/openai/test_transcription_api_correctness.py rename to tests/entrypoints/openai/correctness/test_transcription_api_correctness.py index a895ad0fd42f..628b47e66d67 100644 --- a/tests/entrypoints/openai/test_transcription_api_correctness.py +++ b/tests/entrypoints/openai/correctness/test_transcription_api_correctness.py @@ -3,6 +3,8 @@ Evaluate Transcription API correctness by computing Word Error Rate (WER) on a given ASR dataset. When provided, it will also compare the WER against a baseline. +This simulates real work usage of the API and makes sure that the frontend and +AsyncLLMEngine are working correctly. """ import asyncio import io From 133e783c5ecc17e9b81eca6dd4093a2f4081e7f2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniele=20Trifir=C3=B2?= Date: Mon, 10 Feb 2025 16:15:00 +0100 Subject: [PATCH 16/27] deps: use fastapi[standard], remove redudant dependencies MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Daniele Trifirò --- requirements-common.txt | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/requirements-common.txt b/requirements-common.txt index cfa02025629f..0b7253cc121d 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -8,12 +8,11 @@ py-cpuinfo transformers >= 4.48.2 # Required for Bamba model and Transformers backend. tokenizers >= 0.19.1 # Required for Llama 3. protobuf # Required by LlamaTokenizer. 
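The switch below to `fastapi[standard]` is what lets the new transcriptions route parse multipart uploads: the standard extra pulls in python-multipart (and uvicorn via fastapi-cli), which is why the separate uvicorn pin removed further down becomes redundant. A minimal sketch of why the multipart parser is needed (a hypothetical toy app, not vLLM's actual handler):

from fastapi import FastAPI, Form, UploadFile

app = FastAPI()

@app.post("/v1/audio/transcriptions")
async def transcribe(file: UploadFile, model: str = Form(...)):
    # Declaring UploadFile/Form parameters makes FastAPI parse
    # multipart/form-data bodies, which requires python-multipart.
    audio_bytes = await file.read()
    return {"model": model, "num_bytes": len(audio_bytes)}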
-fastapi >= 0.107.0, < 0.113.0; python_version < '3.9' -fastapi >= 0.107.0, != 0.113.*, != 0.114.0; python_version >= '3.9' +fastapi[standard] >= 0.107.0, < 0.113.0; python_version < '3.9' +fastapi[standard] >= 0.107.0, != 0.113.*, != 0.114.0; python_version >= '3.9' aiohttp openai >= 1.52.0 # Ensure modern openai package (ensure types module present and max_completion_tokens field support) -uvicorn[standard] -pydantic >= 2.9 # Required for fastapi >= 0.113.0 +pydantic >= 2.9 prometheus_client >= 0.18.0 pillow # Required for image processing prometheus-fastapi-instrumentator >= 7.0.0 From 7ef6932837550f05cb64219ed73e0617d4cbabcf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniele=20Trifir=C3=B2?= Date: Mon, 10 Feb 2025 16:39:27 +0100 Subject: [PATCH 17/27] audio transcription: use BytesIO context manager MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Daniele Trifirò --- vllm/entrypoints/openai/serving_transcription.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/entrypoints/openai/serving_transcription.py b/vllm/entrypoints/openai/serving_transcription.py index f27b901c4814..25307e841ab1 100644 --- a/vllm/entrypoints/openai/serving_transcription.py +++ b/vllm/entrypoints/openai/serving_transcription.py @@ -192,7 +192,8 @@ async def _preprocess_transcription( if len(audio_data) / 1024**2 > MAX_AUDIO_CLIP_FILESIZE_MB: raise ValueError("Maximum file size exceeded.") - y, sr = librosa.load(io.BytesIO(audio_data)) + with io.BytesIO(audio_data) as bytes_: + y, sr = librosa.load(bytes_) if librosa.get_duration(y=y, sr=sr) > MAX_AUDIO_CLIP_DURATION_S: raise ValueError( f"Maximum clip duration ({MAX_AUDIO_CLIP_DURATION_S}s) " From 70f83aa9e78fe69a54e4a25bd172b7ec770764c4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniele=20Trifir=C3=B2?= Date: Mon, 10 Feb 2025 16:54:15 +0100 Subject: [PATCH 18/27] serving transcription: fix type hints MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Daniele Trifirò --- vllm/entrypoints/openai/serving_transcription.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/entrypoints/openai/serving_transcription.py b/vllm/entrypoints/openai/serving_transcription.py index 25307e841ab1..503122861d97 100644 --- a/vllm/entrypoints/openai/serving_transcription.py +++ b/vllm/entrypoints/openai/serving_transcription.py @@ -265,7 +265,7 @@ async def create_transcription( logger.exception("Error in preprocessing prompt inputs") return self.create_error_response(str(e)) - result_generator: AsyncGenerator[RequestOutput, None] = None + result_generator: AsyncGenerator[RequestOutput, None] | None = None try: # TODO(rob): subtract len of tokenized prompt. 
default_max_tokens = self.model_config.max_model_len From 3e6307d62d9c7f8db41c95f6ab1f019aa7b5b33e Mon Sep 17 00:00:00 2001 From: NickLucche Date: Mon, 10 Feb 2025 17:09:07 +0000 Subject: [PATCH 19/27] make mypy happy Signed-off-by: NickLucche --- vllm/entrypoints/openai/serving_engine.py | 6 ++-- .../openai/serving_transcription.py | 30 +++++++++---------- 2 files changed, 19 insertions(+), 17 deletions(-) diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 9efb5e6fa398..785117ca1d45 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -31,7 +31,8 @@ ErrorResponse, RerankRequest, ScoreRequest, TokenizeChatRequest, - TokenizeCompletionRequest) + TokenizeCompletionRequest, + TranscriptionRequest) from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.entrypoints.openai.tool_parsers import ToolParser # yapf: enable @@ -57,7 +58,8 @@ ChatLikeRequest = Union[ChatCompletionRequest, EmbeddingChatRequest, TokenizeChatRequest] -AnyRequest = Union[CompletionLikeRequest, ChatLikeRequest] +AnyRequest = Union[CompletionLikeRequest, ChatLikeRequest, + TranscriptionRequest] class TextTokensPrompt(TypedDict): diff --git a/vllm/entrypoints/openai/serving_transcription.py b/vllm/entrypoints/openai/serving_transcription.py index 503122861d97..8e36070bb8c4 100644 --- a/vllm/entrypoints/openai/serving_transcription.py +++ b/vllm/entrypoints/openai/serving_transcription.py @@ -1,19 +1,21 @@ # SPDX-License-Identifier: Apache-2.0 import asyncio import io -from typing import Any, AsyncGenerator, Dict, Optional, Union +from typing import AsyncGenerator, Optional, Union, cast from fastapi import Request from vllm.config import ModelConfig from vllm.engine.protocol import EngineClient from vllm.entrypoints.logger import RequestLogger -from vllm.entrypoints.openai.protocol import (RequestResponseMetadata, +from vllm.entrypoints.openai.protocol import (ErrorResponse, + RequestResponseMetadata, TranscriptionRequest, TranscriptionResponse, TranscriptionResponseVerbose) from vllm.entrypoints.openai.serving_engine import OpenAIServing from vllm.entrypoints.openai.serving_models import OpenAIServingModels +from vllm.inputs.data import PromptType from vllm.logger import init_logger from vllm.outputs import RequestOutput from vllm.utils import PlaceholderModule @@ -168,7 +170,7 @@ async def _preprocess_transcription( self, request: TranscriptionRequest, audio_data: bytes, - ) -> Dict[Any, Any]: + ) -> PromptType: # Validate request # TODO language should be optional and can be guessed. # For now we default to en. See @@ -199,7 +201,7 @@ async def _preprocess_transcription( f"Maximum clip duration ({MAX_AUDIO_CLIP_DURATION_S}s) " "exceeded.") - return { + prompt = { "encoder_prompt": { "prompt": "", "multi_modal_data": { @@ -209,18 +211,19 @@ async def _preprocess_transcription( "decoder_prompt": f"<|startoftranscript|>{lang_token}<|transcribe|><|notimestamps|>{request.prompt}" } + return cast(PromptType, prompt) # TODO (varun) : Make verbose response work ! async def create_transcription( self, audio_data: bytes, request: TranscriptionRequest, raw_request: Request - ) -> Union[TranscriptionResponse, TranscriptionResponseVerbose]: + ) -> Union[TranscriptionResponse, TranscriptionResponseVerbose, + ErrorResponse]: """Transcription API similar to OpenAI's API. See https://platform.openai.com/docs/api-reference/audio/createTranscription for the API specification. 
This API mimics the OpenAI transcription API. """ - error_check_ret = await self._check_model(request) if error_check_ret is not None: return error_check_ret @@ -273,11 +276,12 @@ async def create_transcription( sampling_params = request.to_sampling_params( default_max_tokens, default_params) - self._log_inputs(request_id, - prompt['decoder_prompt'], - params=sampling_params, - lora_request=None, - prompt_adapter_request=None) + self._log_inputs( + request_id, + prompt['decoder_prompt'], # type: ignore + params=sampling_params, + lora_request=None, + prompt_adapter_request=None) result_generator = self.engine_client.generate( prompt, @@ -289,10 +293,6 @@ async def create_transcription( return self.create_error_response(str(e)) # TODO(rob): figure out a way to pipe streaming in. - stream = False - if stream: - return None - # Non-streaming response. try: async for op in result_generator: From 4290464f7b7735d9c5c58460413a394cb709f9cb Mon Sep 17 00:00:00 2001 From: NickLucche Date: Mon, 10 Feb 2025 17:31:03 +0000 Subject: [PATCH 20/27] make mypy happy Signed-off-by: NickLucche --- vllm/entrypoints/openai/protocol.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index af1cb8ad1bf3..2bcfdc235776 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -5,14 +5,13 @@ import re import time from argparse import Namespace -from typing import (Any, ClassVar, Dict, List, Literal, Optional, Set, - TypeAlias, Union) +from typing import Any, ClassVar, Dict, List, Literal, Optional, Set, Union import torch from fastapi import UploadFile from pydantic import (BaseModel, ConfigDict, Field, TypeAdapter, ValidationInfo, field_validator, model_validator) -from typing_extensions import Annotated +from typing_extensions import Annotated, TypeAlias from vllm.entrypoints.chat_utils import ChatCompletionMessageParam from vllm.logger import init_logger From 9ceb623312a8d05fee16df8ac22972b54d129483 Mon Sep 17 00:00:00 2001 From: NickLucche Date: Mon, 10 Feb 2025 17:44:57 +0000 Subject: [PATCH 21/27] make mypy happy Signed-off-by: NickLucche --- .../openai/correctness/{test_accuracy.py => test_lmeval.py} | 0 vllm/entrypoints/openai/serving_transcription.py | 2 +- 2 files changed, 1 insertion(+), 1 deletion(-) rename tests/entrypoints/openai/correctness/{test_accuracy.py => test_lmeval.py} (100%) diff --git a/tests/entrypoints/openai/correctness/test_accuracy.py b/tests/entrypoints/openai/correctness/test_lmeval.py similarity index 100% rename from tests/entrypoints/openai/correctness/test_accuracy.py rename to tests/entrypoints/openai/correctness/test_lmeval.py diff --git a/vllm/entrypoints/openai/serving_transcription.py b/vllm/entrypoints/openai/serving_transcription.py index 8e36070bb8c4..da4930e0e2d8 100644 --- a/vllm/entrypoints/openai/serving_transcription.py +++ b/vllm/entrypoints/openai/serving_transcription.py @@ -268,7 +268,7 @@ async def create_transcription( logger.exception("Error in preprocessing prompt inputs") return self.create_error_response(str(e)) - result_generator: AsyncGenerator[RequestOutput, None] | None = None + result_generator: Optional[AsyncGenerator[RequestOutput, None]] = None try: # TODO(rob): subtract len of tokenized prompt. 
default_max_tokens = self.model_config.max_model_len From 3fea4d6321e37ae6ecbeb90b424cf3e521adaae9 Mon Sep 17 00:00:00 2001 From: NickLucche Date: Wed, 12 Feb 2025 08:40:25 +0000 Subject: [PATCH 22/27] fix test import Signed-off-by: NickLucche --- tests/entrypoints/openai/correctness/__init__.py | 0 tests/entrypoints/openai/correctness/test_lmeval.py | 2 +- .../openai/correctness/test_transcription_api_correctness.py | 2 +- tests/entrypoints/openai/test_transcription_validation.py | 2 +- 4 files changed, 3 insertions(+), 3 deletions(-) create mode 100644 tests/entrypoints/openai/correctness/__init__.py diff --git a/tests/entrypoints/openai/correctness/__init__.py b/tests/entrypoints/openai/correctness/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/entrypoints/openai/correctness/test_lmeval.py b/tests/entrypoints/openai/correctness/test_lmeval.py index df25780cd0f4..ebb2ea4d9d14 100644 --- a/tests/entrypoints/openai/correctness/test_lmeval.py +++ b/tests/entrypoints/openai/correctness/test_lmeval.py @@ -13,7 +13,7 @@ from vllm.platforms import current_platform -from ...utils import RemoteOpenAIServer +from ....utils import RemoteOpenAIServer MODEL_NAME = "Qwen/Qwen2-1.5B-Instruct" NUM_CONCURRENT = 500 diff --git a/tests/entrypoints/openai/correctness/test_transcription_api_correctness.py b/tests/entrypoints/openai/correctness/test_transcription_api_correctness.py index 628b47e66d67..19d4735b9dde 100644 --- a/tests/entrypoints/openai/correctness/test_transcription_api_correctness.py +++ b/tests/entrypoints/openai/correctness/test_transcription_api_correctness.py @@ -20,7 +20,7 @@ from evaluate import load from transformers import AutoTokenizer -from ...utils import RemoteOpenAIServer +from ....utils import RemoteOpenAIServer def to_bytes(y, sr): diff --git a/tests/entrypoints/openai/test_transcription_validation.py b/tests/entrypoints/openai/test_transcription_validation.py index aac53c4ef0ce..723677ca5c7f 100644 --- a/tests/entrypoints/openai/test_transcription_validation.py +++ b/tests/entrypoints/openai/test_transcription_validation.py @@ -97,5 +97,5 @@ async def test_non_asr_model(winning_call): file=winning_call, language="en", temperature=0.0) - assert res.code == 400 and res.text == "" + assert res.code == 400 and not res.text assert res.message == "The model does not support Transcriptions API" From 19c6ccbd4193992255a2b1cc414b074e494472e6 Mon Sep 17 00:00:00 2001 From: NickLucche Date: Wed, 12 Feb 2025 10:27:54 +0000 Subject: [PATCH 23/27] requirements-test update Signed-off-by: NickLucche --- requirements-test.txt | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/requirements-test.txt b/requirements-test.txt index e032aac710dd..648a2626c857 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -66,6 +66,7 @@ charset-normalizer==3.4.0 click==8.1.7 # via # black + # jiwer # nltk # ray colorama==0.4.6 @@ -187,6 +188,8 @@ jinja2==3.1.4 # via # datamodel-code-generator # torch +jiwer==3.0.5 + # via -r requirements-test.in jmespath==1.0.1 # via # boto3 @@ -470,6 +473,8 @@ pyyaml==6.0.2 # timm # transformers # vocos +rapidfuzz==3.12.1 + # via jiwer ray[adag]==2.40.0 # via -r requirements-test.in redis==5.2.0 From cad9b305d23223a99e72d3c38db3be75c0d1b13b Mon Sep 17 00:00:00 2001 From: NickLucche Date: Wed, 12 Feb 2025 11:51:41 +0000 Subject: [PATCH 24/27] v1 test update Signed-off-by: NickLucche --- .buildkite/test-pipeline.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.buildkite/test-pipeline.yaml 
b/.buildkite/test-pipeline.yaml index 8c12b01e7a5f..9991060a3162 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -205,7 +205,7 @@ steps: - VLLM_USE_V1=1 pytest -v -s v1/e2e # Integration test for streaming correctness (requires special branch). - pip install -U git+https://github.com/robertgshaw2-neuralmagic/lm-evaluation-harness.git@streaming-api - - pytest -v -s entrypoints/openai/test_accuracy.py::test_lm_eval_accuracy_v1_engine + - pytest -v -s entrypoints/openai/correctness/test_lmeval.py::test_lm_eval_accuracy_v1_engine - label: Examples Test # 25min working_dir: "/vllm-workspace/examples" From b2323cc8e92578e7825c05ca596230a1cb25d3b8 Mon Sep 17 00:00:00 2001 From: NickLucche Date: Wed, 12 Feb 2025 18:51:21 +0000 Subject: [PATCH 25/27] transcription endpoint disables the others: transcription task Signed-off-by: NickLucche --- .../openai/test_transcription_validation.py | 21 +++++++++++++++++++ tests/test_config.py | 1 + vllm/config.py | 11 +++++++--- vllm/entrypoints/openai/api_server.py | 3 +-- 4 files changed, 31 insertions(+), 5 deletions(-) diff --git a/tests/entrypoints/openai/test_transcription_validation.py b/tests/entrypoints/openai/test_transcription_validation.py index 723677ca5c7f..5d4a5de4badd 100644 --- a/tests/entrypoints/openai/test_transcription_validation.py +++ b/tests/entrypoints/openai/test_transcription_validation.py @@ -99,3 +99,24 @@ async def test_non_asr_model(winning_call): temperature=0.0) assert res.code == 400 and not res.text assert res.message == "The model does not support Transcriptions API" + + +@pytest.mark.asyncio +async def test_completion_endpoints(): + # text to text model + model_name = "openai/whisper-small" + server_args = ["--enforce-eager"] + with RemoteOpenAIServer(model_name, server_args) as remote_server: + client = remote_server.get_async_client() + res = await client.chat.completions.create( + model=model_name, + messages=[{ + "role": "system", + "content": "You are a helpful assistant." 
+ }]) + assert res.code == 400 + assert res.message == "The model does not support Chat Completions API" + + res = await client.completions.create(model=model_name, prompt="Hello") + assert res.code == 400 + assert res.message == "The model does not support Completions API" diff --git a/tests/test_config.py b/tests/test_config.py index 2dfae218b47d..3fb83b4c0328 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -17,6 +17,7 @@ ("jason9693/Qwen2.5-1.5B-apeach", "pooling", "classify"), ("cross-encoder/ms-marco-MiniLM-L-6-v2", "pooling", "score"), ("Qwen/Qwen2.5-Math-RM-72B", "pooling", "reward"), + ("openai/whisper-small", "transcription", "transcription"), ], ) def test_auto_task(model_id, expected_runner_type, expected_task): diff --git a/vllm/config.py b/vllm/config.py index 1740871e7c10..211ba7f7c38c 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -54,17 +54,18 @@ _MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120 TaskOption = Literal["auto", "generate", "embedding", "embed", "classify", - "score", "reward"] + "score", "reward", "transcription"] _ResolvedTask = Literal["generate", "embed", "classify", "score", "reward", - "draft"] + "draft", "transcription"] -RunnerType = Literal["generate", "pooling", "draft"] +RunnerType = Literal["generate", "pooling", "draft", "transcription"] _RUNNER_TASKS: Dict[RunnerType, List[_ResolvedTask]] = { "generate": ["generate"], "pooling": ["embed", "classify", "score", "reward"], "draft": ["draft"], + "transcription": ["transcription"], } _TASK_RUNNER: Dict[_ResolvedTask, RunnerType] = { @@ -484,6 +485,8 @@ def _get_preferred_task( return "embed" if ModelRegistry.is_cross_encoder_model(architectures): return "score" + if "WhisperForConditionalGeneration" in architectures: + return "transcription" suffix_to_preferred_task: List[Tuple[str, _ResolvedTask]] = [ # Other models follow this pattern @@ -516,6 +519,8 @@ def _resolve_task( runner_support: Dict[RunnerType, bool] = { # NOTE: Listed from highest to lowest priority, # in case the model supports multiple of them + "transcription": "WhisperForConditionalGeneration" in \ + hf_config.architectures, "generate": ModelRegistry.is_text_generation_model(architectures), "pooling": ModelRegistry.is_pooling_model(architectures), } diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index c98256381c96..781cff350529 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -859,8 +859,7 @@ async def init_app_state( model_config, state.openai_serving_models, request_logger=request_logger, - ) if ("WhisperForConditionalGeneration" - in model_config.hf_config.architectures) else None + ) if model_config.runner_type == "transcription" else None state.task = model_config.task From 7ffa0e0de7c48df5ccaaffad9439bd25864c7a87 Mon Sep 17 00:00:00 2001 From: NickLucche Date: Thu, 13 Feb 2025 09:19:21 +0000 Subject: [PATCH 26/27] SupportsTranscription interface Signed-off-by: NickLucche --- vllm/config.py | 6 +++--- vllm/model_executor/models/interfaces.py | 27 ++++++++++++++++++++++++ vllm/model_executor/models/registry.py | 12 +++++++++-- vllm/model_executor/models/whisper.py | 5 +++-- 4 files changed, 43 insertions(+), 7 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 211ba7f7c38c..10004b8f6291 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -485,7 +485,7 @@ def _get_preferred_task( return "embed" if ModelRegistry.is_cross_encoder_model(architectures): return "score" - if "WhisperForConditionalGeneration" in 
architectures: + if ModelRegistry.is_transcription_model(architectures): return "transcription" suffix_to_preferred_task: List[Tuple[str, _ResolvedTask]] = [ @@ -519,8 +519,8 @@ def _resolve_task( runner_support: Dict[RunnerType, bool] = { # NOTE: Listed from highest to lowest priority, # in case the model supports multiple of them - "transcription": "WhisperForConditionalGeneration" in \ - hf_config.architectures, + "transcription": + ModelRegistry.is_transcription_model(architectures), "generate": ModelRegistry.is_text_generation_model(architectures), "pooling": ModelRegistry.is_pooling_model(architectures), } diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index 0fc5c4db179c..a0a1b69ad502 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -441,3 +441,30 @@ def supports_cross_encoding( model: Union[Type[object], object], ) -> Union[TypeIs[Type[SupportsCrossEncoding]], TypeIs[SupportsCrossEncoding]]: return is_pooling_model(model) and _supports_cross_encoding(model) + + +@runtime_checkable +class SupportsTranscription(Protocol): + """The interface required for all models that support transcription.""" + + supports_transcription: ClassVar[Literal[True]] = True + + +@overload +def supports_transcription( + model: Type[object]) -> TypeIs[Type[SupportsTranscription]]: + ... + + +@overload +def supports_transcription(model: object) -> TypeIs[SupportsTranscription]: + ... + + +def supports_transcription( + model: Union[Type[object], object], +) -> Union[TypeIs[Type[SupportsTranscription]], TypeIs[SupportsTranscription]]: + if isinstance(model, type): + return isinstance(model, SupportsTranscription) + + return isinstance(model, SupportsTranscription) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 198b6d134718..92797730256b 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -22,7 +22,7 @@ from .interfaces import (has_inner_state, is_attention_free, is_hybrid, supports_cross_encoding, supports_multimodal, - supports_pp) + supports_pp, supports_transcription) from .interfaces_base import is_text_generation_model logger = init_logger(__name__) @@ -225,6 +225,7 @@ class _ModelInfo: has_inner_state: bool is_attention_free: bool is_hybrid: bool + supports_transcription: bool @staticmethod def from_model_cls(model: Type[nn.Module]) -> "_ModelInfo": @@ -238,7 +239,7 @@ def from_model_cls(model: Type[nn.Module]) -> "_ModelInfo": has_inner_state=has_inner_state(model), is_attention_free=is_attention_free(model), is_hybrid=is_hybrid(model), - ) + supports_transcription=supports_transcription(model)) class _BaseRegisteredModel(ABC): @@ -486,6 +487,13 @@ def is_hybrid_model( model_cls, _ = self.inspect_model_cls(architectures) return model_cls.is_hybrid + def is_transcription_model( + self, + architectures: Union[str, List[str]], + ) -> bool: + model_cls, _ = self.inspect_model_cls(architectures) + return model_cls.supports_transcription + ModelRegistry = _ModelRegistry({ model_arch: diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index 0a3011d36101..0b506072094e 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -31,7 +31,7 @@ from vllm.sequence import SequenceData from vllm.transformers_utils.processor import cached_get_processor -from .interfaces import SupportsMultiModal +from .interfaces import SupportsMultiModal, 
SupportsTranscription from .utils import AutoWeightsLoader, WeightsMapper, make_layers logger = init_logger(__name__) @@ -637,7 +637,8 @@ def input_mapper_for_whisper( @MULTIMODAL_REGISTRY.register_input_mapper("audio", input_mapper_for_whisper) @MULTIMODAL_REGISTRY.register_max_multimodal_tokens( "audio", get_max_whisper_audio_tokens) -class WhisperForConditionalGeneration(nn.Module, SupportsMultiModal): +class WhisperForConditionalGeneration(nn.Module, SupportsTranscription, + SupportsMultiModal): packed_modules_mapping = { "self_attn.qkv_proj": [ "self_attn.q_proj", From 74027ce91cca6b772134867e0fa277b9a1e9805a Mon Sep 17 00:00:00 2001 From: Vaibhav Jain Date: Thu, 13 Feb 2025 20:22:22 +0530 Subject: [PATCH 27/27] resolve conflict with V Signed-off-by: NickLucche --- tests/entrypoints/openai/test_basic.py | 16 ++++++++++ vllm/entrypoints/openai/api_server.py | 42 +++++++++++++++++--------- 2 files changed, 43 insertions(+), 15 deletions(-) diff --git a/tests/entrypoints/openai/test_basic.py b/tests/entrypoints/openai/test_basic.py index 0d44a7611aed..a970981b7562 100644 --- a/tests/entrypoints/openai/test_basic.py +++ b/tests/entrypoints/openai/test_basic.py @@ -156,3 +156,19 @@ async def test_request_cancellation(server: RemoteOpenAIServer): max_tokens=10) assert len(response.choices) == 1 + + +@pytest.mark.asyncio +async def test_request_wrong_content_type(server: RemoteOpenAIServer): + + chat_input = [{"role": "user", "content": "Write a long story"}] + client = server.get_async_client() + + with pytest.raises(openai.APIStatusError): + await client.chat.completions.create( + messages=chat_input, + model=MODEL_NAME, + max_tokens=10000, + extra_headers={ + "Content-Type": "application/x-www-form-urlencoded" + }) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 781cff350529..167cc46893dd 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -20,7 +20,7 @@ from typing import Annotated, AsyncIterator, Dict, Optional, Set, Tuple, Union import uvloop -from fastapi import APIRouter, FastAPI, Form, HTTPException, Request +from fastapi import APIRouter, Depends, FastAPI, Form, HTTPException, Request from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, Response, StreamingResponse @@ -257,6 +257,15 @@ def _cleanup_ipc_path(): multiprocess.mark_process_dead(engine_process.pid) +async def validate_json_request(raw_request: Request): + content_type = raw_request.headers.get("content-type", "").lower() + if content_type != "application/json": + raise HTTPException( + status_code=HTTPStatus.UNSUPPORTED_MEDIA_TYPE, + detail="Unsupported Media Type: Only 'application/json' is allowed" + ) + + router = APIRouter() @@ -344,7 +353,7 @@ async def ping(raw_request: Request) -> Response: return await health(raw_request) -@router.post("/tokenize") +@router.post("/tokenize", dependencies=[Depends(validate_json_request)]) @with_cancellation async def tokenize(request: TokenizeRequest, raw_request: Request): handler = tokenization(raw_request) @@ -359,7 +368,7 @@ async def tokenize(request: TokenizeRequest, raw_request: Request): assert_never(generator) -@router.post("/detokenize") +@router.post("/detokenize", dependencies=[Depends(validate_json_request)]) @with_cancellation async def detokenize(request: DetokenizeRequest, raw_request: Request): handler = tokenization(raw_request) @@ -388,7 +397,8 @@ async def show_version(): 
return JSONResponse(content=ver) -@router.post("/v1/chat/completions") +@router.post("/v1/chat/completions", + dependencies=[Depends(validate_json_request)]) @with_cancellation async def create_chat_completion(request: ChatCompletionRequest, raw_request: Request): @@ -409,7 +419,7 @@ async def create_chat_completion(request: ChatCompletionRequest, return StreamingResponse(content=generator, media_type="text/event-stream") -@router.post("/v1/completions") +@router.post("/v1/completions", dependencies=[Depends(validate_json_request)]) @with_cancellation async def create_completion(request: CompletionRequest, raw_request: Request): handler = completion(raw_request) @@ -427,7 +437,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request): return StreamingResponse(content=generator, media_type="text/event-stream") -@router.post("/v1/embeddings") +@router.post("/v1/embeddings", dependencies=[Depends(validate_json_request)]) @with_cancellation async def create_embedding(request: EmbeddingRequest, raw_request: Request): handler = embedding(raw_request) @@ -473,7 +483,7 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request): assert_never(generator) -@router.post("/pooling") +@router.post("/pooling", dependencies=[Depends(validate_json_request)]) @with_cancellation async def create_pooling(request: PoolingRequest, raw_request: Request): handler = pooling(raw_request) @@ -491,7 +501,7 @@ async def create_pooling(request: PoolingRequest, raw_request: Request): assert_never(generator) -@router.post("/score") +@router.post("/score", dependencies=[Depends(validate_json_request)]) @with_cancellation async def create_score(request: ScoreRequest, raw_request: Request): handler = score(raw_request) @@ -509,7 +519,7 @@ async def create_score(request: ScoreRequest, raw_request: Request): assert_never(generator) -@router.post("/v1/score") +@router.post("/v1/score", dependencies=[Depends(validate_json_request)]) @with_cancellation async def create_score_v1(request: ScoreRequest, raw_request: Request): logger.warning( @@ -544,7 +554,7 @@ async def create_transcriptions(request: Annotated[TranscriptionRequest, return StreamingResponse(content=generator, media_type="text/event-stream") -@router.post("/rerank") +@router.post("/rerank", dependencies=[Depends(validate_json_request)]) @with_cancellation async def do_rerank(request: RerankRequest, raw_request: Request): handler = rerank(raw_request) @@ -561,7 +571,7 @@ async def do_rerank(request: RerankRequest, raw_request: Request): assert_never(generator) -@router.post("/v1/rerank") +@router.post("/v1/rerank", dependencies=[Depends(validate_json_request)]) @with_cancellation async def do_rerank_v1(request: RerankRequest, raw_request: Request): logger.warning_once( @@ -572,7 +582,7 @@ async def do_rerank_v1(request: RerankRequest, raw_request: Request): return await do_rerank(request, raw_request) -@router.post("/v2/rerank") +@router.post("/v2/rerank", dependencies=[Depends(validate_json_request)]) @with_cancellation async def do_rerank_v2(request: RerankRequest, raw_request: Request): return await do_rerank(request, raw_request) @@ -616,7 +626,7 @@ async def reset_prefix_cache(raw_request: Request): return Response(status_code=200) -@router.post("/invocations") +@router.post("/invocations", dependencies=[Depends(validate_json_request)]) async def invocations(raw_request: Request): """ For SageMaker, routes requests to other handlers based on model `task`. 
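# ---------------------------------------------------------------------------
# A minimal, standalone sketch (not part of the patch itself) of the
# content-type guard pattern applied to the JSON routes above. The `app`
# object and the `/echo` route are illustrative names only; the dependency
# mirrors the `validate_json_request` helper added in this series. Note that
# the strict equality also rejects values such as
# "application/json; charset=utf-8", and that /v1/audio/transcriptions is
# left without the dependency because it takes a multipart form upload
# rather than a JSON body.
from http import HTTPStatus

from fastapi import Depends, FastAPI, HTTPException, Request


async def validate_json_request(raw_request: Request):
    content_type = raw_request.headers.get("content-type", "").lower()
    if content_type != "application/json":
        raise HTTPException(
            status_code=HTTPStatus.UNSUPPORTED_MEDIA_TYPE,
            detail="Unsupported Media Type: Only 'application/json' is allowed"
        )


app = FastAPI()


@app.post("/echo", dependencies=[Depends(validate_json_request)])
async def echo(payload: dict):
    # Reached only when the client sent Content-Type: application/json; any
    # other media type is rejected with HTTP 415 before the body is parsed,
    # which is the behaviour exercised by test_request_wrong_content_type
    # above.
    return payload
# ---------------------------------------------------------------------------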
@@ -666,7 +676,8 @@ async def stop_profile(raw_request: Request): "Lora dynamic loading & unloading is enabled in the API server. " "This should ONLY be used for local development!") - @router.post("/v1/load_lora_adapter") + @router.post("/v1/load_lora_adapter", + dependencies=[Depends(validate_json_request)]) async def load_lora_adapter(request: LoadLoraAdapterRequest, raw_request: Request): handler = models(raw_request) @@ -677,7 +688,8 @@ async def load_lora_adapter(request: LoadLoraAdapterRequest, return Response(status_code=200, content=response) - @router.post("/v1/unload_lora_adapter") + @router.post("/v1/unload_lora_adapter", + dependencies=[Depends(validate_json_request)]) async def unload_lora_adapter(request: UnloadLoraAdapterRequest, raw_request: Request): handler = models(raw_request)
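# ---------------------------------------------------------------------------
# A minimal, standalone sketch (not part of the patch itself) of the
# SupportsTranscription interface pattern introduced earlier in this series:
# a @runtime_checkable Protocol carrying a ClassVar marker, which is probed
# with isinstance(). The protocol is re-declared here so the snippet is
# self-contained; DummyWhisper and DummyEmbeddingModel are illustrative
# classes only.
from typing import ClassVar, Literal, Protocol, runtime_checkable


@runtime_checkable
class SupportsTranscription(Protocol):
    """The interface required for all models that support transcription."""

    supports_transcription: ClassVar[Literal[True]] = True


class DummyWhisper(SupportsTranscription):
    """Opts in by inheriting the marker attribute from the protocol."""


class DummyEmbeddingModel:
    """Does not advertise transcription support."""


# isinstance() against a runtime_checkable Protocol only checks that the
# declared members are present, so it works for both the class object and
# its instances. The supports_transcription() helper in interfaces.py layers
# TypeIs-based narrowing on top of the same check, and
# ModelRegistry.is_transcription_model() reads the flag recorded per
# architecture in _ModelInfo.
assert isinstance(DummyWhisper, SupportsTranscription)
assert isinstance(DummyWhisper(), SupportsTranscription)
assert not isinstance(DummyEmbeddingModel, SupportsTranscription)
# ---------------------------------------------------------------------------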