Whisper support #5964

Closed
wants to merge 16 commits
38 changes: 38 additions & 0 deletions examples/whisper_example.py
@@ -0,0 +1,38 @@
import requests
from datasets import Audio
from vllm import LLM, SamplingParams


def main():
    sr = 16000
    audio = Audio(sampling_rate=sr)
    llm = LLM(
        model="openai/whisper-large-v3",
        # The engine only builds a WhisperConfig when whisper_input_type is set.
        whisper_input_type="input_features",
        max_num_seqs=1,
        max_model_len=448,
        gpu_memory_utilization=0.4,
        dtype="bfloat16",
    )

    # Download a sample WAV file and decode it to a float array at 16 kHz.
    r = requests.get(
        "https://github.com/mesolitica/malaya-speech/raw/master/speech/7021-79759-0004.wav"
    )
    y = audio.decode_example(audio.encode_example(r.content))["array"]

    # First pass: detect the language. 50258 is <|startoftranscript|>; the
    # model's first generated token is the language token.
    output_lang = llm.generate({
        "prompt_token_ids": [50258],
        "whisper_data": y,
    }, sampling_params=SamplingParams(max_tokens=1, temperature=0))

    # Second pass: transcribe, prompting with <|startoftranscript|>, the
    # detected language token, and <|transcribe|> (50360 in the large-v3 vocabulary).
    outputs = llm.generate({
        "prompt_token_ids": [50258, output_lang[0].outputs[0].token_ids[0], 50360],
        "whisper_data": y,
    }, sampling_params=SamplingParams(max_tokens=100, temperature=0))

    # ' without going to any such extreme as this we can easily see on reflection how vast an influence on the'
    print(outputs[0].outputs[0].text)


if __name__ == "__main__":
    main()
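
Note (not part of this diff): the hard-coded prompt IDs above correspond to Whisper's special tokens. A quick sanity check with the Hugging Face tokenizer, assuming the large-v3 vocabulary, is sketched below.

from transformers import WhisperTokenizer

tok = WhisperTokenizer.from_pretrained("openai/whisper-large-v3")
# Expected to print 50258 (<|startoftranscript|>) and 50360 (<|transcribe|>)
# for the large-v3 vocabulary.
print(tok.convert_tokens_to_ids("<|startoftranscript|>"))
print(tok.convert_tokens_to_ids("<|transcribe|>"))
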
40 changes: 40 additions & 0 deletions vllm/config.py
@@ -24,6 +24,7 @@

_GB = 1 << 30
_EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
_WHISPER_MAX_NUM_BATCHED_TOKENS = 448


class ModelConfig:
@@ -149,6 +150,7 @@ def __init__(
if not self.skip_tokenizer_init:
self._verify_tokenizer_mode()
self._verify_embedding_mode()
self._verify_whisper_mode()
self._verify_quantization()
self._verify_cuda_graph()

@@ -165,6 +167,11 @@ def _verify_embedding_mode(self) -> None:
self.embedding_mode = any(
ModelRegistry.is_embedding_model(arch) for arch in architectures)

def _verify_whisper_mode(self) -> None:
architectures = getattr(self.hf_config, "architectures", [])
self.whisper_mode = any(
ModelRegistry.is_whisper_model(arch) for arch in architectures)

def _parse_quant_hf_config(self):
quant_cfg = getattr(self.hf_config, "quantization_config", None)
if quant_cfg is None:
@@ -682,6 +689,7 @@ class SchedulerConfig:
enable_chunked_prefill: If True, prefill requests can be chunked based
on the remaining max_num_batched_tokens.
embedding_mode: Whether the running model is for embedding.
whisper_mode: Whether the running model is a Whisper model.
preemption_mode: Whether to perform preemption by swapping or
recomputation. If not specified, we determine the mode as follows:
We use recomputation by default since it incurs lower overhead than
@@ -699,6 +707,7 @@ def __init__(self,
delay_factor: float = 0.0,
enable_chunked_prefill: bool = False,
embedding_mode: Optional[bool] = False,
whisper_mode: Optional[bool] = False,
preemption_mode: Optional[str] = None) -> None:
if max_num_batched_tokens is not None:
self.max_num_batched_tokens = max_num_batched_tokens
@@ -711,6 +720,9 @@ def __init__(self,
# For embedding, choose specific value for higher throughput
self.max_num_batched_tokens = max(
max_model_len, _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS)
elif whisper_mode:
self.max_num_batched_tokens = max(
max_model_len, _WHISPER_MAX_NUM_BATCHED_TOKENS)
else:
# If max_model_len is too short, use 2048 as the default value
# for higher throughput.
@@ -725,6 +737,7 @@ def __init__(self,
self.delay_factor = delay_factor
self.chunked_prefill_enabled = enable_chunked_prefill
self.embedding_mode = embedding_mode
self.whisper_mode = whisper_mode
self.preemption_mode = preemption_mode

self._verify_args()
@@ -1218,6 +1231,30 @@ def as_cli_args_dict(self) -> Dict[str, Any]:

return result

@dataclass
class WhisperConfig:
whisper_input_type: Optional[str] = 'input_features'
whisper_processor: Optional[str] = 'openai/whisper-large-v3'
whisper_processor_revision: Optional[str] = None
sample_rate: Optional[int] = 16000

def as_cli_args_dict(self) -> Dict[str, Any]:
"""Flatten vision language config to pure args.

Compatible with what llm entrypoint expects.
"""
result: Dict[str, Any] = {}
for f in fields(self):
value = getattr(self, f.name)
if isinstance(value, enum.Enum):
result[f.name] = value.name.lower()
elif isinstance(value, tuple):
result[f.name] = ",".join([str(item) for item in value])
else:
result[f.name] = value

return result


_STR_DTYPE_TO_TORCH_DTYPE = {
"half": torch.float16,
@@ -1299,6 +1336,8 @@ def _get_and_verify_max_len(
"max_sequence_length",
"max_seq_length",
"seq_len",
# Whisper
"max_length",
]
# Choose the smallest "max_length" from the possible keys.
max_len_key = None
@@ -1435,6 +1474,7 @@ class EngineConfig:
load_config: LoadConfig
lora_config: Optional[LoRAConfig]
vision_language_config: Optional[VisionLanguageConfig]
whisper_config: Optional[WhisperConfig]
speculative_config: Optional[SpeculativeConfig]
decoding_config: Optional[DecodingConfig]
observability_config: Optional[ObservabilityConfig]
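Reviewer note (not part of this diff): as_cli_args_dict mirrors the existing vision-language config flattening. A minimal usage sketch, assuming this branch is installed:

from vllm.config import WhisperConfig

cfg = WhisperConfig(whisper_input_type="input_features",
                    whisper_processor="openai/whisper-large-v3")
# Yields plain keyword arguments that an entrypoint can forward to EngineArgs,
# e.g. {'whisper_input_type': 'input_features', 'sample_rate': 16000, ...}.
print(cfg.as_cli_args_dict())
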
1 change: 1 addition & 0 deletions vllm/core/scheduler.py
@@ -1006,6 +1006,7 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
# `multi_modal_data` will be None.
multi_modal_data=seq_group.multi_modal_data
if scheduler_outputs.num_prefill_groups > 0 else None,
whisper_data=seq_group.whisper_data
)
seq_group_metadata_list.append(seq_group_metadata)

55 changes: 54 additions & 1 deletion vllm/engine/arg_utils.py
@@ -9,7 +9,7 @@
EngineConfig, LoadConfig, LoRAConfig, ModelConfig,
ObservabilityConfig, ParallelConfig, SchedulerConfig,
SpeculativeConfig, TokenizerPoolConfig,
VisionLanguageConfig)
VisionLanguageConfig, WhisperConfig)
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.utils import FlexibleArgumentParser, str_to_int_tuple

@@ -88,6 +88,12 @@ class EngineArgs:
image_processor_revision: Optional[str] = None
disable_image_processor: bool = False

# Related to Whisper
whisper_input_type: Optional[str] = None
whisper_processor: Optional[str] = None
whisper_processor_revision: Optional[str] = None
sample_rate: Optional[int] = 16000

scheduler_delay_factor: float = 0.0
enable_chunked_prefill: bool = False

@@ -156,6 +162,38 @@ def add_cli_args_for_vlm(

return parser

@staticmethod
def add_cli_args_for_whisper(
parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
parser.add_argument(
'--whisper-input-type',
type=nullable_str,
default=EngineArgs.whisper_input_type,
choices=[
'input_features'
],
help=('The audio input type for whisper passed into vLLM.'))
parser.add_argument(
'--whisper-processor',
type=str,
default=EngineArgs.whisper_processor,
help='Name or path of the huggingface whisper processor to use. '
'If unspecified, model name or path will be used.')
parser.add_argument(
'--whisper-processor-revision',
type=str,
default=None,
help='Revision of the huggingface whisper processor version to use. '
'It can be a branch name, a tag name, or a commit id. '
'If unspecified, will use the default version.')
parser.add_argument(
'--sample-rate',
type=int,
default=EngineArgs.sample_rate,
help='Sample rate used by the whisper processor.')

return parser

@staticmethod
def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
"""Shared CLI arguments for vLLM engine."""
@@ -513,6 +551,7 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:

# Related to Vision-language models such as llava
parser = EngineArgs.add_cli_args_for_vlm(parser)
parser = EngineArgs.add_cli_args_for_whisper(parser)

parser.add_argument(
'--scheduler-delay-factor',
@@ -717,6 +756,7 @@ def create_engine_config(self, ) -> EngineConfig:
delay_factor=self.scheduler_delay_factor,
enable_chunked_prefill=self.enable_chunked_prefill,
embedding_mode=model_config.embedding_mode,
whisper_mode=model_config.whisper_mode,
preemption_mode=self.preemption_mode,
)
lora_config = LoRAConfig(
@@ -772,6 +812,18 @@ def create_engine_config(self, ) -> EngineConfig:
)
else:
vision_language_config = None

if self.whisper_input_type:
if self.whisper_processor is None:
self.whisper_processor = self.model
whisper_config = WhisperConfig(
self.whisper_input_type,
self.whisper_processor,
self.whisper_processor_revision,
self.sample_rate,
)
else:
whisper_config = None

decoding_config = DecodingConfig(
guided_decoding_backend=self.guided_decoding_backend)
@@ -794,6 +846,7 @@ def create_engine_config(self, ) -> EngineConfig:
device_config=device_config,
lora_config=lora_config,
vision_language_config=vision_language_config,
whisper_config=whisper_config,
speculative_config=speculative_config,
load_config=load_config,
decoding_config=decoding_config,
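Note (not part of this diff): a minimal sketch, assuming this branch is installed, of how the new flags surface as a WhisperConfig on the resulting EngineConfig:

from vllm.engine.arg_utils import EngineArgs

args = EngineArgs(
    model="openai/whisper-large-v3",
    whisper_input_type="input_features",
    max_model_len=448,
)
# whisper_processor falls back to the model name when left unset.
engine_config = args.create_engine_config()
print(engine_config.whisper_config)
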
17 changes: 16 additions & 1 deletion vllm/engine/async_llm_engine.py
@@ -278,9 +278,24 @@ async def process_model_inputs_async(
else:
prompt_token_ids = inputs["prompt_token_ids"]

if 'whisper_data' in inputs:
if self.whisper_config is None:
raise ValueError("Whisper config is not set; the engine must be initialized with a Whisper model.")
if self.whisper_processor is None:
raise ValueError("Whisper processor is not initialized.")
whisper_data = self.whisper_processor(
inputs['whisper_data'],
sampling_rate=self.whisper_config.sample_rate,
return_tensors='pt',
)
whisper_data = whisper_data.to(self.model_config.dtype).input_features[0]
else:
whisper_data = None

return LLMInputs(prompt_token_ids=prompt_token_ids,
prompt=inputs.get("prompt"),
multi_modal_data=inputs.get("multi_modal_data"))
multi_modal_data=inputs.get("multi_modal_data"),
whisper_data=whisper_data)

async def add_request_async(
self,
36 changes: 32 additions & 4 deletions vllm/engine/llm_engine.py
@@ -9,7 +9,7 @@
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig,
LoRAConfig, ModelConfig, ObservabilityConfig,
ParallelConfig, SchedulerConfig, SpeculativeConfig,
VisionLanguageConfig)
VisionLanguageConfig, WhisperConfig)
from vllm.core.scheduler import (ScheduledSequenceGroup, Scheduler,
SchedulerOutputs)
from vllm.engine.arg_utils import EngineArgs
@@ -36,6 +36,7 @@
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup,
get_tokenizer_group)
from vllm.transformers_utils.whisper_processor import cached_get_whisper_processor
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
usage_message)
from vllm.utils import Counter
@@ -154,6 +155,7 @@ def __init__(
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig],
whisper_config: Optional[WhisperConfig],
speculative_config: Optional[SpeculativeConfig],
decoding_config: Optional[DecodingConfig],
observability_config: Optional[ObservabilityConfig],
@@ -206,6 +208,7 @@ def __init__(
self.cache_config = cache_config
self.lora_config = lora_config
self.vision_language_config = vision_language_config
self.whisper_config = whisper_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config
@@ -223,10 +226,17 @@ def __init__(
self.tokenizer = None
self.detokenizer = None

if self.whisper_config is not None:
self.whisper_processor = cached_get_whisper_processor(
self.whisper_config.whisper_processor
)
else:
self.whisper_processor = None

self.seq_counter = Counter()
self.generation_config_fields = _load_generation_config_dict(
model_config)

self.model_executor = executor_class(
model_config=model_config,
cache_config=cache_config,
@@ -235,6 +245,7 @@ def __init__(
device_config=device_config,
lora_config=lora_config,
vision_language_config=vision_language_config,
whisper_config=whisper_config,
speculative_config=speculative_config,
load_config=load_config,
)
@@ -501,19 +512,36 @@ def process_model_inputs(
if isinstance(inputs, str):
inputs = {"prompt": inputs}

if 'whisper_data' in inputs:
if self.whisper_config is None:
raise ValueError("Whisper config is not set; the engine must be initialized with a Whisper model.")
if self.whisper_processor is None:
raise ValueError("Whisper processor is not initialized.")
whisper_data = self.whisper_processor(
inputs['whisper_data'],
sampling_rate=self.whisper_config.sample_rate,
return_tensors='pt',
)
whisper_data = whisper_data.to(self.model_config.dtype).input_features[0]
else:
whisper_data = None

if "prompt_token_ids" not in inputs:
tokenizer = self.get_tokenizer_group("prompts must be None if "
"skip_tokenizer_init is True")

prompt_token_ids = tokenizer.encode(request_id=request_id,
prompt=inputs["prompt"],
lora_request=lora_request)
lora_request=lora_request,
add_special_tokens=self.whisper_processor is None)
else:
prompt_token_ids = inputs["prompt_token_ids"]


return LLMInputs(prompt_token_ids=prompt_token_ids,
prompt=inputs.get("prompt"),
multi_modal_data=inputs.get("multi_modal_data"))
multi_modal_data=inputs.get("multi_modal_data"),
whisper_data=whisper_data)

def add_request(
self,
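Note (not part of this diff): the new vllm/transformers_utils/whisper_processor.py module is not shown in this view. A minimal sketch of what cached_get_whisper_processor presumably looks like, assuming it follows the existing cached_get_image_processor pattern:

from functools import lru_cache

from transformers import WhisperProcessor


@lru_cache
def cached_get_whisper_processor(processor_name: str) -> WhisperProcessor:
    # Load and memoize the Hugging Face Whisper processor, whose feature
    # extractor converts raw audio into log-mel input_features.
    return WhisperProcessor.from_pretrained(processor_name)
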
2 changes: 1 addition & 1 deletion vllm/entrypoints/llm.py
@@ -575,4 +575,4 @@ def _run_engine(
# Sort the outputs by request ID.
# This is necessary because some requests may be finished earlier than
# its previous requests.
return sorted(outputs, key=lambda x: int(x.request_id))
return sorted(outputs, key=lambda x: int(x.request_id))