Skip to content
8 changes: 2 additions & 6 deletions vllm/entrypoints/openai/speech_to_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,7 @@ async def _preprocess_speech_to_text(
audio_data: bytes,
) -> tuple[list[PromptType], float]:
# Validate request
# TODO language should be optional and can be guessed.
# For now we default to en. See
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/generation_whisper.py#L1520
lang = request.language or "en"
self.model_cls.validate_language(lang)
language = self.model_cls.validate_language(request.language)

if len(audio_data) / 1024**2 > self.max_audio_filesize_mb:
raise ValueError("Maximum file size exceeded.")
Expand All @@ -112,7 +108,7 @@ async def _preprocess_speech_to_text(
audio=chunk,
stt_config=self.asr_config,
model_config=self.model_config,
language=lang,
language=language,
task_type=self.task_type,
request_prompt=request.prompt)
prompts.append(prompt)
Expand Down
53 changes: 47 additions & 6 deletions vllm/model_executor/models/interfaces.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Iterable, MutableSequence
from collections.abc import Iterable, Mapping, MutableSequence
from typing import (TYPE_CHECKING, ClassVar, Literal, Optional, Protocol,
Union, overload, runtime_checkable)

import numpy as np
import torch
from torch import Tensor
from transformers.models.whisper.tokenization_whisper import LANGUAGES
from typing_extensions import Self, TypeIs

from vllm.config import ModelConfig, SpeechToTextConfig
Expand Down Expand Up @@ -685,6 +686,8 @@ def _find_quant_config(*args, **kwargs) -> Optional[QuantizationConfig]:
@runtime_checkable
class SupportsTranscription(Protocol):
"""The interface required for all models that support transcription."""
# Mapping from ISO 639-1 language codes to language names
supported_languages: ClassVar[Mapping[str, str]]

supports_transcription: ClassVar[Literal[True]] = True

Expand All @@ -694,21 +697,59 @@ class SupportsTranscription(Protocol):
`True`.
"""

def __init_subclass__(cls, **kwargs):
    """Validate `supported_languages` when a subclass is defined.

    Raises:
        ValueError: If `cls.supported_languages` contains a language code
            that is not a key of the Whisper `LANGUAGES` map.
    """
    super().__init_subclass__(**kwargs)
    # Language codes in supported_languages that don't exist in the
    # full (Whisper) language map.
    invalid = set(cls.supported_languages) - set(LANGUAGES.keys())
    if invalid:
        # The sentence-ending period belongs before the line break,
        # not at the start of the next line.
        raise ValueError(
            f"{cls.__name__}.supported_languages contains invalid "
            f"language codes: {sorted(invalid)}.\n"
            f"Valid choices are: {sorted(LANGUAGES.keys())}")

@classmethod
def get_generation_prompt(cls, audio: np.ndarray,
stt_config: SpeechToTextConfig,
model_config: ModelConfig, language: str,
task_type: str,
model_config: ModelConfig,
language: Optional[str], task_type: str,
request_prompt: str) -> PromptType:
"""Get the prompt for the ASR model.
The model has control over the construction, as long as it
returns a valid PromptType."""
...

@classmethod
def validate_language(cls, language: str) -> bool:
"""Check if the model supports a specific ISO639_1 language."""
...
def get_other_languages(cls) -> Mapping[str, str]:
    """Return the entries of the Whisper `LANGUAGES` map whose code is
    not listed in `cls.supported_languages`."""
    natively_supported = cls.supported_languages
    return {
        code: name
        for code, name in LANGUAGES.items()
        if code not in natively_supported
    }

@classmethod
def validate_language(cls, language: Optional[str]) -> Optional[str]:
    """
    Validate the language specified in a transcription request.

    Args:
        language: Language code from the request, or `None` when the
            request did not specify one.

    Returns:
        `language` unchanged when it is `None`, is a key of
        `cls.supported_languages`, or is a key of
        `cls.get_other_languages()`. In the last case a warning is
        logged (but no exception is raised), since the language is
        valid but not natively supported by the model.

    Raises:
        ValueError: If `language` is in neither mapping.
    """
    if language is None or language in cls.supported_languages:
        return language

    # Build the fallback map once; it is used both for the membership
    # test and for the error message below.
    other_languages = cls.get_other_languages()
    if language in other_languages:
        logger.warning(
            "Language %r is not natively supported by %s; "
            "results may be less accurate. Supported languages: %r",
            language,
            cls.__name__,
            list(cls.supported_languages.keys()),
        )
        return language

    # List every accepted code: the natively supported ones and the
    # ones that are accepted with a warning.
    raise ValueError(
        f"Unsupported language: {language!r}. Must be one of "
        f"{list(cls.supported_languages.keys())} (natively supported) "
        f"or {list(other_languages.keys())} (accepted, but results "
        f"may be less accurate).")

@classmethod
def get_speech_to_text_config(
Expand Down
25 changes: 16 additions & 9 deletions vllm/model_executor/models/voxtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,7 @@
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models import SupportsPP
# yapf: disable
from vllm.model_executor.models.whisper import (
WhisperEncoder, WhisperForConditionalGeneration)
from vllm.model_executor.models.whisper import WhisperEncoder
# yapf: enable
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
Expand All @@ -50,6 +49,18 @@

logger = init_logger(__name__)

# Language codes mapped to their English names. Assigned to
# `supported_languages` on VoxtralForConditionalGeneration.
ISO639_1_SUPPORTED_LANGS = {
"ar": "Arabic",
"nl": "Dutch",
"en": "English",
"fr": "French",
"de": "German",
"hi": "Hindi",
"it": "Italian",
"pt": "Portuguese",
"es": "Spanish",
}


class VoxtralProcessorAdapter:
"""
Expand Down Expand Up @@ -301,6 +312,7 @@ def _get_data_parser(self) -> MultiModalDataParser:
dummy_inputs=VoxtralDummyInputsBuilder)
class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP, SupportsTranscription):
supported_languages = ISO639_1_SUPPORTED_LANGS

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
Expand Down Expand Up @@ -441,8 +453,8 @@ def get_speech_to_text_config(cls, model_config: ModelConfig,
# for speech-to-text transcription
def get_generation_prompt(cls, audio: np.ndarray,
model_config: ModelConfig,
stt_config: SpeechToTextConfig, language: str,
task_type: str,
stt_config: SpeechToTextConfig,
language: Optional[str], task_type: str,
request_prompt: str) -> PromptType:
tokenizer = cached_tokenizer_from_config(model_config)
audio = Audio(audio, int(stt_config.sample_rate),
Expand All @@ -457,11 +469,6 @@ def get_generation_prompt(cls, audio: np.ndarray,
prompts_dict["prompt_token_ids"] = tokenized.tokens
return cast(PromptType, prompts_dict)

@classmethod
def validate_language(cls, language: str) -> bool:
# same as whisper
return WhisperForConditionalGeneration.validate_language(language)

@classmethod
def get_num_audio_tokens(cls, audio_duration_s: float,
stt_config: SpeechToTextConfig,
Expand Down
74 changes: 15 additions & 59 deletions vllm/model_executor/models/whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,51 +109,6 @@
"vi": "Vietnamese",
"cy": "Welsh"
}
ISO639_1_OTHER_LANGS = {
"lo": "Lao",
"jw": "Javanese",
"tk": "Turkmen",
"yi": "Yiddish",
"so": "Somali",
"bn": "Bengali",
"nn": "Norwegian Nynorsk",
"si": "Sinhala",
"yo": "Yoruba",
"sa": "Sanskrit",
"mi": "Māori",
"fo": "Faroese", # codespell:ignore
"mt": "Maltese",
"tg": "Tajik",
"mg": "Malagasy",
"haw": "Hawaiian",
"km": "Khmer",
"br": "Breton",
"ps": "Pashto",
"ln": "Lingala",
"la": "Latin",
"ml": "Malayalam",
"sq": "Albanian",
"su": "Sundanese",
"eu": "Basque",
"ka": "Georgian",
"uz": "Uzbek",
"sn": "Shona",
"ht": "Haitian",
"as": "Assamese",
"mn": "Mongolian",
"te": "Telugu",
"pa": "Panjabi",
"tt": "Tatar",
"gu": "Gujarati",
"oc": "Occitan",
"ha": "Hausa",
"ba": "Bashkir",
"my": "Burmese",
"sd": "Sindhi",
"am": "Amharic",
"lb": "Luxembourgish",
"bo": "Tibetan"
}


class WhisperAudioInputs(TypedDict):
Expand Down Expand Up @@ -807,32 +762,33 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,

# Whisper only supports audio-conditioned generation.
supports_transcription_only = True
supported_languages = ISO639_1_SUPPORTED_LANGS

@classmethod
def validate_language(cls, language: str) -> bool:
if language in ISO639_1_SUPPORTED_LANGS:
return True
elif language in ISO639_1_OTHER_LANGS:
def validate_language(cls, language: Optional[str]) -> Optional[str]:
if language is None:
# TODO language should be optional and can be guessed.
# For now we default to en. See
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/generation_whisper.py#L1520
logger.warning(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am thinking perhaps this should be a logger.warning_once case.

Copy link
Contributor Author

@sanchit-gandhi sanchit-gandhi Jul 30, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a pretty big assumption that's being made in creating the token ids, e.g. it completely nullifies the model's ability to do multilingual transcription.

Until the TODO is resolved, the best practice when using Whisper is to always specify the language. Otherwise, you end up with undefined behaviour (e.g. audio in Spanish, task set to "transcribe", lang token set to "en" → mismatch with how the model was trained!). This would happen silently if you passed a mix of transcription requests, some with and some without the language field.

Ideally, we would enforce this best practice by throwing an exception if the language is not specified. However, since that's not backwards compatible, a persistent warning is the best we can do. So I'd be in favour of keeping it as logger.warning!

"The selected language %s has limited accuracy with"
" reported WER>=0.5. Results may be less accurate "
"for this choice.", language)
return True
else:
raise ValueError(f"Unsupported language: {language}."
"Language should be one of:" +
f" {list(ISO639_1_SUPPORTED_LANGS.values())}" +
f"or {list(ISO639_1_OTHER_LANGS.values())}")
"Defaulting to language='en'. If you wish to transcribe "
"audio in a different language, pass the `language` field "
"in the TranscriptionRequest.")
language = "en"
return super().validate_language(language)

@classmethod
def get_generation_prompt(
cls,
audio: np.ndarray,
model_config: ModelConfig, # not needed here
stt_config: SpeechToTextConfig,
language: str,
language: Optional[str],
task_type: str,
request_prompt: str) -> PromptType:
if language is None:
raise ValueError(
"Language must be specified when creating the Whisper prompt")
prompt = {
"encoder_prompt": {
# Whisper does not support encoder prompt.
Expand Down