diff --git a/examples/whisper_example.py b/examples/whisper_example.py
new file mode 100644
index 0000000000000..d325870b20d51
--- /dev/null
+++ b/examples/whisper_example.py
@@ -0,0 +1,38 @@
+import vllm
+import torch
+import requests
+from vllm import LLM, SamplingParams
+from datasets import Audio
+
+
+def main():
+    sr = 16000
+    audio = Audio(sampling_rate=sr)
+    llm = LLM(
+        model="openai/whisper-large-v3",
+        max_num_seqs=1,
+        max_model_len=448,
+        gpu_memory_utilization=0.4,
+        dtype='bfloat16',
+    )
+
+    r = requests.get('https://github.com/mesolitica/malaya-speech/raw/master/speech/7021-79759-0004.wav')
+    y = audio.decode_example(audio.encode_example(r.content))['array']
+
+    output_lang = llm.generate({
+        "prompt_token_ids": [50258],
+        "whisper_data": y,
+    }, sampling_params=SamplingParams(max_tokens=1, temperature=0))
+
+    outputs = llm.generate({
+        "prompt_token_ids": [50258, output_lang[0].outputs[0].token_ids[0], 50360],
+        "whisper_data": y,
+    }, sampling_params=SamplingParams(max_tokens=100, temperature=0))
+
+    # ' without going to any such extreme as this we can easily see on reflection how vast an influence on the'
+    print(outputs[0].outputs[0].text)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/vllm/config.py b/vllm/config.py
index 0217a2b569928..3623adfe27b0e 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -24,6 +24,7 @@
 _GB = 1 << 30
 _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
+_WHISPER_MAX_NUM_BATCHED_TOKENS = 448
 
 
 class ModelConfig:
@@ -149,6 +150,7 @@ def __init__(
         if not self.skip_tokenizer_init:
             self._verify_tokenizer_mode()
         self._verify_embedding_mode()
+        self._verify_whisper_mode()
         self._verify_quantization()
         self._verify_cuda_graph()
 
@@ -165,6 +167,11 @@ def _verify_embedding_mode(self) -> None:
         self.embedding_mode = any(
             ModelRegistry.is_embedding_model(arch) for arch in architectures)
 
+    def _verify_whisper_mode(self) -> None:
+        architectures = getattr(self.hf_config, "architectures", [])
+        self.whisper_mode = any(
+            ModelRegistry.is_whisper_model(arch) for arch in architectures)
+
     def _parse_quant_hf_config(self):
         quant_cfg = getattr(self.hf_config, "quantization_config", None)
         if quant_cfg is None:
@@ -682,6 +689,7 @@ class SchedulerConfig:
         enable_chunked_prefill: If True, prefill requests can be chunked based
             on the remaining max_num_batched_tokens.
         embedding_mode: Whether the running model is for embedding.
+        whisper_mode: Whether the running model is for whisper.
         preemption_mode: Whether to perform preemption by swapping or
             recomputation. If not specified, we determine the mode as follows:
             We use recomputation by default since it incurs lower overhead than
@@ -699,6 +707,7 @@ def __init__(self,
                  delay_factor: float = 0.0,
                  enable_chunked_prefill: bool = False,
                  embedding_mode: Optional[bool] = False,
+                 whisper_mode: Optional[bool] = False,
                  preemption_mode: Optional[str] = None) -> None:
         if max_num_batched_tokens is not None:
             self.max_num_batched_tokens = max_num_batched_tokens
@@ -711,6 +720,9 @@ def __init__(self,
                 # For embedding, choose specific value for higher throughput
                 self.max_num_batched_tokens = max(
                     max_model_len, _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS)
+            elif whisper_mode:
+                self.max_num_batched_tokens = max(
+                    max_model_len, _WHISPER_MAX_NUM_BATCHED_TOKENS)
             else:
                 # If max_model_len is too short, use 2048 as the default value
                 # for higher throughput.
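Note on the SchedulerConfig change above: when whisper_mode is detected, the prefill token budget is capped near the decoder length instead of the usual 2048/32768 defaults. A minimal standalone sketch of that selection, reusing the constant names from this diff (illustration only, not the vLLM source):

```python
_EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
_WHISPER_MAX_NUM_BATCHED_TOKENS = 448


def pick_max_num_batched_tokens(max_model_len: int,
                                embedding_mode: bool = False,
                                whisper_mode: bool = False) -> int:
    # Mirrors SchedulerConfig.__init__ when max_num_batched_tokens is unset
    # and chunked prefill is disabled.
    if embedding_mode:
        return max(max_model_len, _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS)
    if whisper_mode:
        return max(max_model_len, _WHISPER_MAX_NUM_BATCHED_TOKENS)
    return max(max_model_len, 2048)


print(pick_max_num_batched_tokens(448, whisper_mode=True))  # 448
```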
@@ -725,6 +737,7 @@ def __init__(self,
         self.delay_factor = delay_factor
         self.chunked_prefill_enabled = enable_chunked_prefill
         self.embedding_mode = embedding_mode
+        self.whisper_mode = whisper_mode
         self.preemption_mode = preemption_mode
 
         self._verify_args()
@@ -1218,6 +1231,30 @@ def as_cli_args_dict(self) -> Dict[str, Any]:
         return result
 
 
+@dataclass
+class WhisperConfig:
+    whisper_input_type: Optional[str] = 'input_features'
+    whisper_processor: Optional[str] = 'openai/whisper-large-v3'
+    whisper_processor_revision: Optional[str] = None
+    sample_rate: Optional[int] = 16000
+
+    def as_cli_args_dict(self) -> Dict[str, Any]:
+        """Flatten whisper config to pure args.
+
+        Compatible with what llm entrypoint expects.
+        """
+        result: Dict[str, Any] = {}
+        for f in fields(self):
+            value = getattr(self, f.name)
+            if isinstance(value, enum.Enum):
+                result[f.name] = value.name.lower()
+            elif isinstance(value, tuple):
+                result[f.name] = ",".join([str(item) for item in value])
+            else:
+                result[f.name] = value
+
+        return result
+
+
 _STR_DTYPE_TO_TORCH_DTYPE = {
     "half": torch.float16,
@@ -1299,6 +1336,8 @@ def _get_and_verify_max_len(
         "max_sequence_length",
         "max_seq_length",
         "seq_len",
+        # Whisper
+        "max_length",
     ]
     # Choose the smallest "max_length" from the possible keys.
     max_len_key = None
@@ -1435,6 +1474,7 @@ class EngineConfig:
     load_config: LoadConfig
     lora_config: Optional[LoRAConfig]
     vision_language_config: Optional[VisionLanguageConfig]
+    whisper_config: Optional[WhisperConfig]
     speculative_config: Optional[SpeculativeConfig]
     decoding_config: Optional[DecodingConfig]
     observability_config: Optional[ObservabilityConfig]
diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py
index 48c34625c08ae..c03fe7fffb0f6 100644
--- a/vllm/core/scheduler.py
+++ b/vllm/core/scheduler.py
@@ -1006,6 +1006,7 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
                 # `multi_modal_data` will be None.
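The extra "max_length" key added to _get_and_verify_max_len above is what resolves a usable context length for Whisper, whose HF config exposes max_length=448 rather than the usual max_position_embeddings. A rough sketch of that key scan, assuming a dict-like config and an illustrative 2048 fallback:

```python
from typing import Dict

POSSIBLE_KEYS = [
    "max_position_embeddings",
    "max_sequence_length",
    "max_seq_length",
    "seq_len",
    # Whisper
    "max_length",
]


def derive_max_model_len(hf_config: Dict[str, int]) -> int:
    # Pick the smallest value among the keys present, mirroring the intent of
    # _get_and_verify_max_len in vllm/config.py.
    candidates = [hf_config[k] for k in POSSIBLE_KEYS if k in hf_config]
    return min(candidates) if candidates else 2048


# A whisper-large-v3 style config carries max_length=448.
print(derive_max_model_len({"max_length": 448}))  # 448
```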
multi_modal_data=seq_group.multi_modal_data if scheduler_outputs.num_prefill_groups > 0 else None, + whisper_data=seq_group.whisper_data ) seq_group_metadata_list.append(seq_group_metadata) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 16374098b23d4..3246e11203825 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -9,7 +9,7 @@ EngineConfig, LoadConfig, LoRAConfig, ModelConfig, ObservabilityConfig, ParallelConfig, SchedulerConfig, SpeculativeConfig, TokenizerPoolConfig, - VisionLanguageConfig) + VisionLanguageConfig, WhisperConfig) from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS from vllm.utils import FlexibleArgumentParser, str_to_int_tuple @@ -88,6 +88,12 @@ class EngineArgs: image_processor_revision: Optional[str] = None disable_image_processor: bool = False + # Related to Whisper + whisper_input_type: Optional[str] = None + whisper_processor: Optional[str] = None + whisper_processor_revision: Optional[str] = None + sample_rate: Optional[int] = 16000 + scheduler_delay_factor: float = 0.0 enable_chunked_prefill: bool = False @@ -156,6 +162,38 @@ def add_cli_args_for_vlm( return parser + @staticmethod + def add_cli_args_for_whisper( + parser: FlexibleArgumentParser) -> FlexibleArgumentParser: + parser.add_argument( + '--whisper-input-type', + type=nullable_str, + default=EngineArgs.whisper_input_type, + choices=[ + 'input_features' + ], + help=('The audio input type for whisper passed into vLLM.')) + parser.add_argument( + '--whisper-processor', + type=str, + default=EngineArgs.whisper_processor, + help='Name or path of the huggingface whisper processor to use. ' + 'If unspecified, model name or path will be used.') + parser.add_argument( + '--whisper-processor-revision', + type=str, + default=None, + help='Revision of the huggingface whisper processor version to use. ' + 'It can be a branch name, a tag name, or a commit id. 
' + 'If unspecified, will use the default version.') + parser.add_argument( + '--sample-rate', + type=int, + default=EngineArgs.sample_rate, + help='sample rate for whisper processor') + + return parser + @staticmethod def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: """Shared CLI arguments for vLLM engine.""" @@ -513,6 +551,7 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: # Related to Vision-language models such as llava parser = EngineArgs.add_cli_args_for_vlm(parser) + parser = EngineArgs.add_cli_args_for_whisper(parser) parser.add_argument( '--scheduler-delay-factor', @@ -717,6 +756,7 @@ def create_engine_config(self, ) -> EngineConfig: delay_factor=self.scheduler_delay_factor, enable_chunked_prefill=self.enable_chunked_prefill, embedding_mode=model_config.embedding_mode, + whisper_mode=model_config.whisper_mode, preemption_mode=self.preemption_mode, ) lora_config = LoRAConfig( @@ -772,6 +812,18 @@ def create_engine_config(self, ) -> EngineConfig: ) else: vision_language_config = None + + if self.whisper_input_type: + if self.whisper_processor is None: + self.whisper_processor = self.model + whisper_config = WhisperConfig( + self.whisper_input_type, + self.whisper_processor, + self.whisper_processor_revision, + self.sample_rate, + ) + else: + whisper_config = None decoding_config = DecodingConfig( guided_decoding_backend=self.guided_decoding_backend) @@ -794,6 +846,7 @@ def create_engine_config(self, ) -> EngineConfig: device_config=device_config, lora_config=lora_config, vision_language_config=vision_language_config, + whisper_config=whisper_config, speculative_config=speculative_config, load_config=load_config, decoding_config=decoding_config, diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index df25eb111e87f..4e418e5964e58 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -278,9 +278,24 @@ async def process_model_inputs_async( else: prompt_token_ids = inputs["prompt_token_ids"] + if 'whisper_data' in inputs: + if self.whisper_config is None: + raise ValueError(f"Whisper config is None, must initialize a Whisper model.") + if self.whisper_processor is None: + raise ValueError(f"Whisper Processor is not initialized.") + whisper_data = self.whisper_processor( + inputs['whisper_data'], + sampling_rate = self.whisper_config.sample_rate, + return_tensors = 'pt', + ) + whisper_data = whisper_data.to(self.model_config.dtype).input_features[0] + else: + whisper_data = None + return LLMInputs(prompt_token_ids=prompt_token_ids, prompt=inputs.get("prompt"), - multi_modal_data=inputs.get("multi_modal_data")) + multi_modal_data=inputs.get("multi_modal_data"), + whisper_data=whisper_data) async def add_request_async( self, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index f7eae257fdd16..24a0ef2e7ee89 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -9,7 +9,7 @@ from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ObservabilityConfig, ParallelConfig, SchedulerConfig, SpeculativeConfig, - VisionLanguageConfig) + VisionLanguageConfig, WhisperConfig) from vllm.core.scheduler import (ScheduledSequenceGroup, Scheduler, SchedulerOutputs) from vllm.engine.arg_utils import EngineArgs @@ -36,6 +36,7 @@ from vllm.transformers_utils.detokenizer import Detokenizer from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup, get_tokenizer_group) +from 
vllm.transformers_utils.whisper_processor import cached_get_whisper_processor from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, usage_message) from vllm.utils import Counter @@ -154,6 +155,7 @@ def __init__( load_config: LoadConfig, lora_config: Optional[LoRAConfig], vision_language_config: Optional[VisionLanguageConfig], + whisper_config: Optional[WhisperConfig], speculative_config: Optional[SpeculativeConfig], decoding_config: Optional[DecodingConfig], observability_config: Optional[ObservabilityConfig], @@ -206,6 +208,7 @@ def __init__( self.cache_config = cache_config self.lora_config = lora_config self.vision_language_config = vision_language_config + self.whisper_config = whisper_config self.parallel_config = parallel_config self.scheduler_config = scheduler_config self.device_config = device_config @@ -223,10 +226,17 @@ def __init__( self.tokenizer = None self.detokenizer = None + if self.whisper_config is not None: + self.whisper_processor = cached_get_whisper_processor( + self.whisper_config.whisper_processor + ) + else: + self.whisper_processor = None + self.seq_counter = Counter() self.generation_config_fields = _load_generation_config_dict( model_config) - + self.model_executor = executor_class( model_config=model_config, cache_config=cache_config, @@ -235,6 +245,7 @@ def __init__( device_config=device_config, lora_config=lora_config, vision_language_config=vision_language_config, + whisper_config=whisper_config, speculative_config=speculative_config, load_config=load_config, ) @@ -501,19 +512,36 @@ def process_model_inputs( if isinstance(inputs, str): inputs = {"prompt": inputs} + if 'whisper_data' in inputs: + if self.whisper_config is None: + raise ValueError(f"Whisper config is None, must initialize a Whisper model.") + if self.whisper_processor is None: + raise ValueError(f"Whisper Processor is not initialized.") + whisper_data = self.whisper_processor( + inputs['whisper_data'], + sampling_rate = self.whisper_config.sample_rate, + return_tensors = 'pt', + ) + whisper_data = whisper_data.to(self.model_config.dtype).input_features[0] + else: + whisper_data = None + if "prompt_token_ids" not in inputs: tokenizer = self.get_tokenizer_group("prompts must be None if " "skip_tokenizer_init is True") prompt_token_ids = tokenizer.encode(request_id=request_id, prompt=inputs["prompt"], - lora_request=lora_request) + lora_request=lora_request, + add_special_tokens=self.whisper_processor is None) else: prompt_token_ids = inputs["prompt_token_ids"] + return LLMInputs(prompt_token_ids=prompt_token_ids, prompt=inputs.get("prompt"), - multi_modal_data=inputs.get("multi_modal_data")) + multi_modal_data=inputs.get("multi_modal_data"), + whisper_data=whisper_data) def add_request( self, diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 9e923493160ed..70651cad73e59 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -575,4 +575,4 @@ def _run_engine( # Sort the outputs by request ID. # This is necessary because some requests may be finished earlier than # its previous requests. 
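The whisper_data branch added to process_model_inputs above is equivalent to running the HuggingFace WhisperProcessor by hand. A small sketch of the same preprocessing, assuming 16 kHz mono float32 audio as in the example script:

```python
import numpy as np
import torch
from transformers import WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-large-v3")

# One second of silence stands in for real speech here.
audio = np.zeros(16000, dtype=np.float32)

features = processor(audio, sampling_rate=16000, return_tensors="pt")
input_features = features.to(torch.bfloat16).input_features[0]
print(input_features.shape)  # torch.Size([128, 3000]) for whisper-large-v3
```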
- return sorted(outputs, key=lambda x: int(x.request_id)) + return sorted(outputs, key=lambda x: int(x.request_id)) \ No newline at end of file diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index ea6275920c79d..b47d95b855d99 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -12,6 +12,7 @@ from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, Response, StreamingResponse +from fastapi import File, Form, UploadFile from prometheus_client import make_asgi_app from starlette.routing import Mount @@ -22,10 +23,13 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, ChatCompletionResponse, CompletionRequest, + TranscriptionVerboseJsonResponse, + TranscriptionJsonResponse, EmbeddingRequest, ErrorResponse) from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding +from vllm.entrypoints.openai.serving_whisper import OpenAIServingWhisper from vllm.logger import init_logger from vllm.usage.usage_lib import UsageContext from vllm.version import __version__ as VLLM_VERSION @@ -35,6 +39,7 @@ openai_serving_chat: OpenAIServingChat openai_serving_completion: OpenAIServingCompletion openai_serving_embedding: OpenAIServingEmbedding +openai_serving_whisper: OpenAIServingWhisper logger = init_logger('vllm.entrypoints.openai.api_server') @@ -137,7 +142,38 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request): else: return JSONResponse(content=generator.model_dump()) - +@app.post("/audio/transcriptions") +async def audio_transcriptions( + file: bytes = File(), + model: str = 'whisper', + language: str = Form(None), + response_format: str = Form('text'), + timestamp_granularities: str = Form('segment'), + repetition_penalty: float = Form(1.0), + stream: bool = Form(False), + raw_request: Request = None, +): + generator = await openai_serving_whisper.create_audio_transcriptions( + file=file, + language=language, + response_format=response_format, + timestamp_granularities=timestamp_granularities, + repetition_penalty=repetition_penalty, + stream=stream, + raw_request=raw_request, + ) + if isinstance(generator, ErrorResponse): + return JSONResponse(content=generator.model_dump(), + status_code=generator.code) + if stream: + return StreamingResponse(content=generator, + media_type="text/event-stream") + else: + if isinstance(generator, str): + return generator + else: + return JSONResponse(content=generator.model_dump()) + if __name__ == "__main__": args = parse_args() @@ -219,6 +255,8 @@ async def authentication(request: Request, call_next): engine, model_config, served_model_names, args.lora_modules) openai_serving_embedding = OpenAIServingEmbedding(engine, model_config, served_model_names) + openai_serving_whisper = OpenAIServingWhisper( + engine, model_config, served_model_names, args.max_size_mb_whisper) app.root_path = args.root_path uvicorn.run(app, host=args.host, diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index 59ad73bf097c8..a134a0cbeae09 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -71,6 +71,10 @@ def make_arg_parser(): help="The file path to the chat template, " "or the template in single-line form " "for the specified 
model") + parser.add_argument("--max-size-mb-whisper", + type=int, + default=200, + help="max size of audio to transcribe using Whisper in term of MB.") parser.add_argument("--response-role", type=nullable_str, default="assistant", diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index b57d79859aec5..d1719998b8854 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -699,3 +699,25 @@ class BatchRequestOutput(OpenAIBaseModel): # For requests that failed with a non-HTTP error, this will contain more # information on the cause of the failure. error: Optional[Any] + +class Segment(BaseModel): + id: int + seek: int + start: float + end: float + text: str + tokens: list[int] + temperature: float + avg_logprob: float + compression_ratio: float + no_speech_prob: float + +class TranscriptionVerboseJsonResponse(BaseModel): + task: str = "transcribe" + language: str + duration: float + text: str + segments: list[Segment] + +class TranscriptionJsonResponse(BaseModel): + text: str \ No newline at end of file diff --git a/vllm/entrypoints/openai/serving_whisper.py b/vllm/entrypoints/openai/serving_whisper.py new file mode 100644 index 0000000000000..3f46816233c00 --- /dev/null +++ b/vllm/entrypoints/openai/serving_whisper.py @@ -0,0 +1,406 @@ +import codecs +import time +import re +import json +import torchaudio +import numpy as np +from torchaudio.io import StreamReader +from dataclasses import dataclass, field +from typing import (AsyncGenerator, AsyncIterator, Awaitable, Dict, Iterable, + List, Optional) +from typing import Sequence as GenericSequence +from typing import TypedDict, Union, cast, final + +from fastapi import Request +from openai.types.chat import (ChatCompletionContentPartImageParam, + ChatCompletionContentPartTextParam) + +from vllm.config import ModelConfig, VisionLanguageConfig +from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.entrypoints.openai.protocol import ( + Segment, TranscriptionVerboseJsonResponse, TranscriptionJsonResponse) +from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, + OpenAIServing) +from vllm.inputs import PromptInputs +from vllm.logger import init_logger +from vllm.model_executor.guided_decoding import ( + get_guided_decoding_logits_processor) +from vllm.multimodal.image import ImagePixelData +from vllm.multimodal.utils import (async_get_and_parse_image, + get_full_image_text_prompt) +from vllm.outputs import RequestOutput +from vllm.sequence import Logprob +from vllm.tracing import (contains_trace_headers, extract_trace_headers, + log_tracing_disabled_warning) +from vllm.sampling_params import SamplingParams +from vllm.utils import random_uuid + +logger = init_logger(__name__) + +buffer_size = 4096 +sample_rate = 16000 +segment_length = sample_rate * 1 +maxlen = 30 +replaces = ['<|startoftranscript|>', '<|endoftext|>', '<|transcribe|>'] +pattern = r'<\|\-?\d+\.?\d*\|>' +pattern_pair = r'<\|(\d+\.\d+)\|>(.*?)<\|(\d+\.\d+)\|>' + +def format_timestamp( + seconds: float, always_include_hours: bool = False, decimal_marker: str = "." 
+): + assert seconds >= 0, "non-negative timestamp expected" + milliseconds = round(seconds * 1000.0) + + hours = milliseconds // 3_600_000 + milliseconds -= hours * 3_600_000 + + minutes = milliseconds // 60_000 + milliseconds -= minutes * 60_000 + + seconds = milliseconds // 1_000 + milliseconds -= seconds * 1_000 + + hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else "" + return ( + f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}" + ) + + +class OpenAIServingWhisper(OpenAIServing): + + def __init__(self, + engine: AsyncLLMEngine, + model_config: ModelConfig, + served_model_names: List[str], + max_size_mb_whisper: 200, + ): + super().__init__(engine=engine, + model_config=model_config, + served_model_names=served_model_names, + lora_modules=None) + + self._check_whisper_mode(model_config.whisper_mode) + self.model_config = model_config + self.max_size_mb_whisper = max_size_mb_whisper * 1024 * 1024 + + def _check_whisper_mode(self, whisper_mode: bool): + if not whisper_mode: + logger.warning( + "whisper_mode is False. Whisper API will not work.") + else: + logger.info("Activating the server engine with whisper enabled.") + + async def create_audio_transcriptions( + self, + file, + language, + response_format, + timestamp_granularities, + repetition_penalty, + stream, + raw_request: Optional[Request] = None + ): + + if len(file) > self.max_size_mb_whisper: + return self.create_error_response(f"maximum size for `file` is {self.max_size_mb_whisper}MB only") + + if timestamp_granularities.lower().strip() != 'segment': + return self.create_error_response("currently `timestamp_granularities` only support `segment`") + + if response_format.lower() not in {'text', 'json', 'verbose_json', 'srt'}: + return self.create_error_response( + 'currently `response_format` only support `text`, `json`, `verbose_json` and `srt`') + + request_id = f"cmpl-{random_uuid()}" + + sampling_params = SamplingParams( + max_tokens = self.model_config.max_model_len - 4, + temperature = 0.0, + skip_special_tokens = False, + stop_token_ids = [50257], + repetition_penalty = repetition_penalty + ) + + if isinstance(language, str) and language.lower() == 'null': + language = None + + streamer = StreamReader( + src=file, + format=None, + option=None, + buffer_size=buffer_size + ) + streamer.add_basic_audio_stream( + frames_per_chunk=segment_length, + sample_rate=sample_rate + ) + stream_iterator = streamer.stream() + + is_tracing_enabled = await self.engine.is_tracing_enabled() + trace_headers = None + if is_tracing_enabled and raw_request: + trace_headers = extract_trace_headers(raw_request.headers) + if not is_tracing_enabled and raw_request and contains_trace_headers( + raw_request.headers): + log_tracing_disabled_warning() + + # Streaming response + if stream: + return self.audio_transcription_stream_generator( + sampling_params, + stream_iterator, + language, + response_format, + request_id, + trace_headers, + raw_request, + ) + else: + try: + return await self.audio_transcription_full_generator( + sampling_params, + stream_iterator, + language, + response_format, + request_id, + trace_headers, + raw_request, + ) + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) + + async def generate( + self, + sampling_params, + language, + wav_data, + last_timestamp, + last_i, + response_format, + request_id, + trace_headers, + raw_request, + ): + if language is None: + prompt_ids = [50258] + inputs: PromptInputs = { + 
"prompt": None, + "prompt_token_ids": prompt_ids, + "whisper_data": wav_data + } + lang_sampling_params = SamplingParams( + max_tokens = 1, temperature = 0, skip_special_tokens = False) + + result_generator = self.engine.generate( + inputs, + lang_sampling_params, + request_id=request_id + '-predict-lang', + lora_request = None, + trace_headers=trace_headers, + ) + async for res in result_generator: + if raw_request is not None and await raw_request.is_disconnected(): + # Abort the request if the client disconnects. + await self.engine.abort(request_id) + yield self.create_error_response("Client disconnected") + final_res = res + assert final_res is not None + + lang_token = final_res.outputs[0].token_ids[0] + language = self.tokenizer.decode([lang_token])[2:-2] + else: + lang_token = self.tokenizer.encode(f'<|{language}|>', add_special_tokens = False)[0] + + prompt_ids = [50258, lang_token, 50360, 50365] + inputs: PromptInputs = { + "prompt": None, + "prompt_token_ids": prompt_ids, + "whisper_data": wav_data + } + + result_generator = self.engine.generate( + inputs, + sampling_params, + request_id=request_id + f'-{last_i}', + lora_request = None, + trace_headers=trace_headers, + ) + + texts = f'<|{language}|><|{last_timestamp}|>' + + if response_format != 'srt': + text = texts + if response_format == 'json': + text = json.dumps({'token': texts}) + + yield text + + """ + [CompletionOutput(index=0, text=' and', token_ids=[293], cumulative_logprob=-1.7037980556488037, logprobs=None, finish_reason=None, stop_reason=None)] + """ + async for res in result_generator: + + token = self.tokenizer.convert_ids_to_tokens([res.outputs[0].token_ids[-1]]) + text = self.tokenizer.convert_tokens_to_string(token) + + for r in replaces: + text = text.replace(r, '') + matches = re.findall(pattern, text) + for match in matches: + timestamp = float(match.split('|')[1]) + timestamp += last_timestamp + timestamp = f'<|{timestamp}|>' + text = text.replace(match, timestamp) + if len(text): + texts += text + matches = re.findall(pattern_pair, texts) + if response_format == 'srt': + if len(matches): + match = matches[0] + if len(match[1]) > 2: + start = float(match[0]) + last_timestamp + end = float(match[-1]) + last_timestamp + text_between = match[1].strip() + ids = f"{last_i + 1}\n" + r = [ + ids, + f"{format_timestamp(start, always_include_hours=True, decimal_marker=',')} --> ", + f"{format_timestamp(end, always_include_hours=True, decimal_marker=',')}\n", + f"{text_between.replace('-->', '->')}\n"] + + combined = ''.join(r) + '\n' + last_i += 1 + yield combined + + texts = text.split('|>')[-2] + '|>' + else: + if response_format == 'json': + text = json.dumps({'token': text}) + + yield text + + + async def audio_transcription_stream_generator( + self, + sampling_params, + stream_iterator, + language, + response_format, + request_id: str, + trace_headers, + raw_request, + ) -> AsyncGenerator[str, None]: + wav_data = np.array([], dtype=np.float32) + last_i = 0 + last_timestamp = 0.0 + try: + for chunk in stream_iterator: + frame = chunk[0][:, 0].numpy() + wav_data = np.concatenate([wav_data, frame]) + audio_len = len(wav_data) / sample_rate + if audio_len >= maxlen: + async for t in self.generate( + sampling_params=sampling_params, + language=language, + wav_data=wav_data, + last_timestamp=last_timestamp, + last_i=last_i, + response_format=response_format, + request_id=request_id, + trace_headers=trace_headers, + raw_request=raw_request, + ): + yield f"data: {t}\n\n" + last_i += 1 + + last_timestamp += audio_len + 
wav_data = np.array([], dtype=np.float32) + + if len(wav_data): + async for t in self.generate( + sampling_params=sampling_params, + language=language, + wav_data=wav_data, + last_timestamp=last_timestamp, + last_i=last_i, + response_format=response_format, + request_id=request_id, + trace_headers=trace_headers, + raw_request=raw_request, + ): + yield f"data: {t}\n\n" + last_i += 1 + + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + data = self.create_streaming_error_response(str(e)) + yield f"data: {data}\n\n" + # Send the final done message after all response.n are finished + yield "data: [DONE]\n\n" + + async def audio_transcription_full_generator( + self, + sampling_params, + stream_iterator, + language, + response_format, + request_id: str, + trace_headers, + raw_request, + ): + tokens = [] + async for data in self.audio_transcription_stream_generator( + sampling_params=sampling_params, + stream_iterator=stream_iterator, + language=language, + response_format='json', + request_id=request_id, + trace_headers=trace_headers, + raw_request=raw_request, + ): + if isinstance(data, str): + if '[DONE]' in data: + break + data = json.loads(data.split('data:')[1].strip()) + tokens.append(data['token']) + + tokens = ''.join(tokens) + lang = tokens.split('|')[1] + matches = re.findall(pattern_pair, tokens) + print(matches) + segments = [] + all_texts = [] + for no, (start, substring, end) in enumerate(matches): + start_timestamp = float(start) + end_timestamp = float(end) + segment = Segment( + id=no, + seek=0, + start=start_timestamp, + end=end_timestamp, + text=substring.strip(), + tokens=self.tokenizer.encode(substring.strip(), add_special_tokens=False), + temperature=0.0, + avg_logprob=0.0, + compression_ratio=1.0, + no_speech_prob=0.0, + ) + segments.append(segment) + all_texts.append(substring) + + all_texts = ''.join(all_texts).strip() + if response_format == 'verbose_json': + return TranscriptionVerboseJsonResponse( + task='transcribe', + language=lang, + duration=segments[-1].end, + text=all_texts, + segments=segments + ) + elif response_format == 'json': + return TranscriptionJsonResponse( + text=all_texts + ) + else: + return all_texts \ No newline at end of file diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index 7c2520b5a64f5..b0ce00448171b 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -3,7 +3,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, - SpeculativeConfig, VisionLanguageConfig) + SpeculativeConfig, VisionLanguageConfig, WhisperConfig) from vllm.lora.request import LoRARequest from vllm.sequence import ExecuteModelRequest, SamplerOutput @@ -26,6 +26,7 @@ def __init__( load_config: LoadConfig, lora_config: Optional[LoRAConfig], vision_language_config: Optional[VisionLanguageConfig], + whisper_config: Optional[WhisperConfig], speculative_config: Optional[SpeculativeConfig], ) -> None: self.model_config = model_config @@ -36,6 +37,7 @@ def __init__( self.scheduler_config = scheduler_config self.device_config = device_config self.vision_language_config = vision_language_config + self.whisper_config = whisper_config self.speculative_config = speculative_config self._init_executor() diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index 0a654200ed796..3f13f06e36cc7 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -44,6 +44,7 @@ def _get_worker_kwargs( 
distributed_init_method=distributed_init_method, lora_config=self.lora_config, vision_language_config=self.vision_language_config, + whisper_config=self.whisper_config, speculative_config=self.speculative_config, is_driver_worker=rank == 0, ) diff --git a/vllm/inputs.py b/vllm/inputs.py index 026903e19a26e..12a60b6aa9af4 100644 --- a/vllm/inputs.py +++ b/vllm/inputs.py @@ -4,7 +4,7 @@ from typing_extensions import NotRequired if TYPE_CHECKING: - from vllm.multimodal import MultiModalData + from vllm.multimodal import MultiModalData, WhisperData class ParsedText(TypedDict): @@ -73,6 +73,8 @@ class TextPrompt(TypedDict): """The input text to be tokenized before passing to the model.""" multi_modal_data: NotRequired["MultiModalData"] + + whisper_data: NotRequired["WhisperData"] """ Optional multi-modal data to pass to the model, if the model supports it. @@ -86,6 +88,8 @@ class TokensPrompt(TypedDict): """A list of token IDs to pass to the model.""" multi_modal_data: NotRequired["MultiModalData"] + + whisper_data: NotRequired["WhisperData"] """ Optional multi-modal data to pass to the model, if the model supports it. @@ -105,6 +109,8 @@ class TextTokensPrompt(TypedDict): tokenizer to convert the prompts to token IDs.""" multi_modal_data: NotRequired["MultiModalData"] + + whisper_data: NotRequired["WhisperData"] """ Optional multi-modal data to pass to the model, if the model supports it. @@ -128,3 +134,4 @@ class LLMInputs(TypedDict): prompt_token_ids: List[int] prompt: NotRequired[Optional[str]] multi_modal_data: NotRequired[Optional["MultiModalData"]] + whisper_data: NotRequired[Optional["WhisperData"]] diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 5afb2e1d44d39..b9bb666a495ac 100755 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -67,7 +67,11 @@ "MistralModel": ("llama_embedding", "LlamaEmbeddingModel"), } -_MODELS = {**_GENERATION_MODELS, **_EMBEDDING_MODELS} +_WHISPER_MODELS = { + "WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"), +} + +_MODELS = {**_GENERATION_MODELS, **_EMBEDDING_MODELS, **_WHISPER_MODELS} # Architecture -> type. 
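With the _WHISPER_MODELS entry above, whisper support is detected from the architectures list in the model's HF config. A tiny sketch of that lookup, mirroring ModelRegistry.is_whisper_model (illustrative, not the registry code itself):

```python
_WHISPER_MODELS = {
    "WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"),
}


def is_whisper_model(model_arch: str) -> bool:
    return model_arch in _WHISPER_MODELS


# hf_config.architectures for openai/whisper-large-v3
architectures = ["WhisperForConditionalGeneration"]
print(any(is_whisper_model(arch) for arch in architectures))  # True
```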
# out of tree models @@ -129,6 +133,9 @@ def register_model(model_arch: str, model_cls: Type[nn.Module]): def is_embedding_model(model_arch: str) -> bool: return model_arch in _EMBEDDING_MODELS + @staticmethod + def is_whisper_model(model_arch: str) -> bool: + return model_arch in _WHISPER_MODELS __all__ = [ "ModelRegistry", diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py new file mode 100644 index 0000000000000..e543393cc03e9 --- /dev/null +++ b/vllm/model_executor/models/whisper.py @@ -0,0 +1,576 @@ +from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union + +import math +import torch +from torch import nn +from transformers import WhisperConfig + +from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) +from vllm.model_executor.layers.activation import FastGELU +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, kv_cache_scales_loader) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import SamplerOutput +from vllm.utils import is_hip, print_warning_once +from xformers import ops as xops +from vllm.utils import is_hip + +def sinusoids(length: int, channels: int, max_timescale: float = 10000) -> torch.Tensor: + """Returns sinusoids for positional embedding""" + if channels % 2 != 0: + raise ValueError( + f"Number of channels has to be divisible by 2 for sinusoidal positional embeddings, got {channels} channels." + ) + log_timescale_increment = math.log(max_timescale) / (channels // 2 - 1) + inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2)) + scaled_time = torch.arange(length).view(-1, 1) * inv_timescales.view(1, -1) + return torch.cat([scaled_time.sin(), scaled_time.cos()], dim=1) + +class WhisperPositionalEmbedding(nn.Embedding): + def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None): + super().__init__(num_positions, embedding_dim) + + def forward(self, position_ids): + return self.weight[position_ids] + +class WhisperAttention(nn.Module): + def __init__( + self, + embed_dim: int, + num_heads: int, + bias: bool = True, + config: Optional[WhisperConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + cache_config: Optional[CacheConfig] = None, + ): + super().__init__() + tp_size = get_tensor_model_parallel_world_size() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.num_kv_heads = max(1, self.num_heads // tp_size) + self.head_dim = embed_dim // num_heads + self.config = config + + if (self.head_dim * num_heads) != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {num_heads})." 
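The sinusoids helper above builds Whisper's fixed positional table for the audio encoder. A quick shape check, assuming whisper-large-v3 dimensions (d_model 1280, max_source_positions 1500):

```python
import math
import torch


def sinusoids(length: int, channels: int, max_timescale: float = 10000) -> torch.Tensor:
    """Same construction as the helper above (even channel count assumed)."""
    log_timescale_increment = math.log(max_timescale) / (channels // 2 - 1)
    inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
    scaled_time = torch.arange(length).view(-1, 1) * inv_timescales.view(1, -1)
    return torch.cat([scaled_time.sin(), scaled_time.cos()], dim=1)


table = sinusoids(1500, 1280)
print(table.shape)  # torch.Size([1500, 1280]), copied into encoder.embed_positions
```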
+ ) + self.scaling = self.head_dim**-0.5 + + self.k_proj = RowParallelLinear( + input_size = embed_dim, + output_size = embed_dim, + bias = False, + quant_config=quant_config, + ) + self.v_proj = RowParallelLinear( + input_size = embed_dim, + output_size = embed_dim, + bias = bias, + quant_config=quant_config + ) + self.q_proj = RowParallelLinear( + input_size = embed_dim, + output_size = embed_dim, + bias = bias, + quant_config=quant_config + ) + self.out_proj = RowParallelLinear( + input_size = embed_dim, + output_size = embed_dim, + bias = bias, + quant_config=quant_config + ) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).contiguous() + + +class WhisperEncoderAttention(WhisperAttention): + def __init__( + self, + embed_dim: int, + num_heads: int, + bias: bool = True, + config: Optional[WhisperConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + cache_config: Optional[CacheConfig] = None, + ): + super().__init__( + embed_dim=embed_dim, + num_heads=num_heads, + bias=bias, + config=config, + quant_config=quant_config, + cache_config=cache_config + ) + + def forward( + self, + hidden_states: torch.Tensor, + ): + sizes = hidden_states.size() + if len(sizes) == 3: + bsz, tgt_len, _ = sizes + else: + tgt_len, _ = sizes + q, _ = self.q_proj(hidden_states) + k, _ = self.k_proj(hidden_states) + v, _ = self.v_proj(hidden_states) + + q = self._shape(q, -1, 1) + k = self._shape(k, -1, 1) + v = self._shape(v, -1, 1) + + attn_output = xops.memory_efficient_attention_forward( + q, + k, + v, + attn_bias=None, + p=0.0, + scale=None, + op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if + (is_hip()) else None, + ) + + attn_output = attn_output.reshape(-1, self.embed_dim) + output, _ = self.out_proj(attn_output) + return output + +class WhisperDecoderAttention(WhisperAttention): + def __init__( + self, + embed_dim: int, + num_heads: int, + bias: bool = True, + config: Optional[WhisperConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + cache_config: Optional[CacheConfig] = None, + ): + super().__init__( + embed_dim=embed_dim, + num_heads=num_heads, + bias=bias, + config=config, + quant_config=quant_config, + cache_config=cache_config + ) + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config + ) + + def forward( + self, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor = None, + attn_metadata: AttentionMetadata = None, + ): + sizes = hidden_states.size() + if len(sizes) == 3: + bsz, tgt_len, _ = sizes + else: + tgt_len, _ = sizes + + q, _ = self.q_proj(hidden_states) + k, _ = self.k_proj(hidden_states) + v, _ = self.v_proj(hidden_states) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + + output, _ = self.out_proj(attn_output) + + return output + +class WhisperDecoderCrossAttention(WhisperAttention): + def __init__( + self, + embed_dim: int, + num_heads: int, + bias: bool = True, + config: Optional[WhisperConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + cache_config: Optional[CacheConfig] = None, + ): + super().__init__( + embed_dim=embed_dim, + num_heads=num_heads, + bias=bias, + config=config, + quant_config=quant_config, + cache_config=cache_config + ) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + attn_metadata: AttentionMetadata = None, + ): + sizes = hidden_states.size() + 
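A note on the attention paths in this file: encoder self-attention and decoder cross-attention keep no KV cache and attend over the full audio features with xformers' memory-efficient kernel, while decoder self-attention goes through vLLM's paged Attention. Functionally, the un-cached paths compute ordinary scaled dot-product attention; a reference sketch in stock PyTorch (not the kernel actually invoked), with whisper-large-v3 head sizes assumed:

```python
import torch
import torch.nn.functional as F

heads, head_dim, audio_len = 20, 64, 1500

# (batch, seq, heads, head_dim), matching the layout produced by _shape() above.
q = torch.randn(1, 1, heads, head_dim)          # one decoder position
k = torch.randn(1, audio_len, heads, head_dim)  # encoder hidden states
v = torch.randn(1, audio_len, heads, head_dim)

out = F.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2))
out = out.transpose(1, 2).reshape(-1, heads * head_dim)
print(out.shape)  # torch.Size([1, 1280])
```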
if len(sizes) == 3: + bsz, tgt_len, _ = sizes + else: + tgt_len, _ = sizes + + q, _ = self.q_proj(hidden_states) + k, _ = self.k_proj(encoder_hidden_states) + v, _ = self.v_proj(encoder_hidden_states) + + q = self._shape(q, -1, 1) + k = self._shape(k, -1, 1) + v = self._shape(v, -1, 1) + + attn_output = xops.memory_efficient_attention_forward( + q, + k, + v, + attn_bias=None, + p=0.0, + scale=None, + op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if + (is_hip()) else None, + ) + + attn_output = attn_output.reshape(-1, self.embed_dim) + output, _ = self.out_proj(attn_output) + return output + +class WhisperEncoderLayer(nn.Module): + def __init__( + self, + config: WhisperConfig, + quant_config: Optional[QuantizationConfig] = None, + cache_config: Optional[CacheConfig] = None, + ): + super().__init__() + self.embed_dim = config.d_model + self.self_attn = WhisperEncoderAttention( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + config=config, + quant_config=quant_config, + cache_config=cache_config, + ) + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.activation_fn = FastGELU() + self.fc1 = RowParallelLinear( + input_size = self.embed_dim, + output_size = config.encoder_ffn_dim, + bias = True, + quant_config=quant_config, + ) + self.fc2 = RowParallelLinear( + input_size = config.encoder_ffn_dim, + output_size = self.embed_dim, + bias = True, + quant_config=quant_config, + ) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + ): + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + hidden_states = self.self_attn( + hidden_states=hidden_states, + ) + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states, _ = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states, _ = self.fc2(hidden_states) + hidden_states = residual + hidden_states + + if hidden_states.dtype == torch.float16 and ( + torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any() + ): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + return hidden_states + +class WhisperDecoderLayer(nn.Module): + def __init__( + self, + config: WhisperConfig, + quant_config: Optional[QuantizationConfig] = None, + cache_config: Optional[CacheConfig] = None, + ): + super().__init__() + self.embed_dim = config.d_model + self.self_attn = WhisperDecoderAttention( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + config=config, + quant_config=quant_config, + cache_config=cache_config, + ) + self.activation_fn = FastGELU() + + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.encoder_attn = WhisperDecoderCrossAttention( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + config=config, + quant_config=quant_config, + cache_config=cache_config, + ) + self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.fc1 = RowParallelLinear( + input_size = self.embed_dim, + output_size = config.decoder_ffn_dim, + bias = True, + quant_config=quant_config, + ) + self.fc2 = RowParallelLinear( + input_size = config.decoder_ffn_dim, + output_size = self.embed_dim, + bias = True, + quant_config=quant_config, + ) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + 
encoder_hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ): + + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + hidden_states = self.self_attn( + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.encoder_attn_layer_norm(hidden_states) + hidden_states = self.encoder_attn( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + attn_metadata=attn_metadata, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states, _ = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states, _ = self.fc2(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + +class WhisperEncoder(nn.Module): + def __init__( + self, + config: WhisperConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + embed_dim = config.d_model + self.num_mel_bins = config.num_mel_bins + self.padding_idx = config.pad_token_id + self.max_source_positions = config.max_source_positions + self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 + + self.conv1 = nn.Conv1d(self.num_mel_bins, embed_dim, kernel_size=3, padding=1) + self.conv2 = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, stride=2, padding=1) + + self.embed_positions = nn.Embedding(self.max_source_positions, embed_dim) + + self.layers = nn.ModuleList([WhisperEncoderLayer(config, quant_config=quant_config, cache_config=cache_config) + for layer_idx in range(config.encoder_layers)]) + + self.layer_norm = nn.LayerNorm(config.d_model) + + with torch.no_grad(): + self.embed_positions.weight.copy_(sinusoids(*self.embed_positions.weight.shape)) + + def forward( + self, + input_features, + ): + inputs_embeds = nn.functional.gelu(self.conv1(input_features)) + inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds)) + inputs_embeds = inputs_embeds.permute(1, 0) + + embed_pos = self.embed_positions.weight + + hidden_states = inputs_embeds + embed_pos + for idx, encoder_layer in enumerate(self.layers): + hidden_states = encoder_layer(hidden_states) + + hidden_states = self.layer_norm(hidden_states) + return hidden_states + +class WhisperDecoder(nn.Module): + def __init__( + self, + config: WhisperConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.layerdrop = config.decoder_layerdrop + self.padding_idx = config.pad_token_id + self.max_target_positions = config.max_target_positions + self.max_source_positions = config.max_source_positions + self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 + + self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx) + self.embed_positions = WhisperPositionalEmbedding(self.max_target_positions, config.d_model) + + self.layers = nn.ModuleList([WhisperDecoderLayer(config, quant_config=quant_config, cache_config=cache_config) + for layer_idx in range(config.decoder_layers)]) + self.layer_norm = nn.LayerNorm(config.d_model) + + def forward( + self, + input_ids, + positions: torch.Tensor, + encoder_hidden_states: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + past_key_values = None, + ): + 
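One detail of the WhisperEncoder above: conv1 preserves the frame count while conv2 uses stride 2, so the 3000 log-mel frames of a padded 30 s window shrink to the 1500 positions covered by embed_positions. A quick check, assuming whisper-large-v3 sizes:

```python
import torch
import torch.nn as nn

num_mel_bins, d_model = 128, 1280
conv1 = nn.Conv1d(num_mel_bins, d_model, kernel_size=3, padding=1)
conv2 = nn.Conv1d(d_model, d_model, kernel_size=3, stride=2, padding=1)

mel = torch.randn(num_mel_bins, 3000)  # one 30 s window of log-mel features
x = conv2(conv1(mel))
print(x.shape)                         # torch.Size([1280, 1500])
```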
inputs_embeds = self.embed_tokens(input_ids) + positions = self.embed_positions(positions) + hidden_states = inputs_embeds + positions + + for idx, decoder_layer in enumerate(self.layers): + hidden_states = decoder_layer( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + kv_cache=kv_caches[idx], + attn_metadata=attn_metadata + ) + + hidden_states = self.layer_norm(hidden_states) + return hidden_states + +class WhisperModel(nn.Module): + def __init__( + self, + config: WhisperConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + + self.encoder = WhisperEncoder(config, cache_config=cache_config, quant_config=quant_config) + self.decoder = WhisperDecoder(config, cache_config=cache_config, quant_config=quant_config) + + def forward( + self, + input_features: torch.FloatTensor, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + ): + encoder_outputs = self.encoder(input_features) + + decoder_outputs = self.decoder( + input_ids=input_ids, + positions=positions, + encoder_hidden_states=encoder_outputs, + kv_caches=kv_caches, + attn_metadata=attn_metadata, + ) + return decoder_outputs + +class WhisperForConditionalGeneration(nn.Module): + def __init__( + self, + config: WhisperConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.model = WhisperModel(config, cache_config=cache_config, quant_config=quant_config) + self.unpadded_vocab_size = config.vocab_size + self.proj_out = RowParallelLinear( + input_size = config.d_model, + output_size = config.vocab_size, + bias = False, + quant_config=quant_config, + ) + logit_scale = getattr(config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size, logit_scale) + self.sampler = Sampler() + + def forward( + self, + whisper_data: torch.Tensor, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + ) -> SamplerOutput: + + decoder_outputs = self.model( + input_features=whisper_data, + input_ids=input_ids, + positions=positions, + kv_caches=kv_caches, + attn_metadata=attn_metadata, + ) + return decoder_outputs + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.proj_out.weight, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + + if name == 'model.decoder.embed_tokens.weight': + param = params_dict['proj_out.weight'] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + + param = params_dict[name] \ No newline at end of file diff --git a/vllm/multimodal/__init__.py b/vllm/multimodal/__init__.py index 270012e7d1c3b..df3b5d51e8ff9 100644 --- a/vllm/multimodal/__init__.py +++ b/vllm/multimodal/__init__.py @@ -1,7 +1,8 @@ from .base 
import MultiModalData, MultiModalPlugin +from .audio import WhisperData from .registry import MULTIMODAL_REGISTRY, MultiModalRegistry __all__ = [ "MultiModalData", "MultiModalPlugin", "MULTIMODAL_REGISTRY", - "MultiModalRegistry" + "MultiModalRegistry", "WhisperData" ] diff --git a/vllm/multimodal/audio.py b/vllm/multimodal/audio.py new file mode 100644 index 0000000000000..bc3670caaf7f6 --- /dev/null +++ b/vllm/multimodal/audio.py @@ -0,0 +1,84 @@ +from typing import Dict, Tuple, Type, Union + +import torch +import numpy as np +from vllm.config import ModelConfig, WhisperConfig +from vllm.logger import init_logger +from vllm.sequence import SequenceData +from vllm.transformers_utils.whisper_processor import cached_get_whisper_processor + +from .base import MultiModalData, MultiModalPlugin + +logger = init_logger(__name__) + +class WhisperData: + pass + + +def _get_dummy_seq_data(seq_len: int, + whisper_config: WhisperConfig) -> SequenceData: + + # '<|startoftranscript|><|en|><|transcribe|>' + token_ids = [50258, 50259, 50360] + return SequenceData(token_ids) + + +def _get_dummy_values(whisper_config: WhisperConfig) -> torch.Tensor: + values_dtype = torch.float16 + + return torch.zeros((30 * whisper_config.sample_rate), dtype=values_dtype) + + +def get_dummy_audio_data( + seq_len: int, + model_config: ModelConfig, + whisper_config: WhisperConfig, +) -> Tuple[SequenceData, MultiModalData]: + """Standard dummy data factory for image data (to be used in + :meth:`vlm.multimodal.MultiModalRegistry.register_dummy_data`).""" + seq_data = _get_dummy_seq_data(seq_len, whisper_config) + values = _get_dummy_values(whisper_config) + + fake_mm_data = AudioData(values) + return seq_data, fake_mm_data + + +class AudioData(MultiModalData): + + def __init__(self, audio: Union[np.array, torch.Tensor]) -> None: + + self.audio = audio + + def __repr__(self) -> str: + return str(self.audio) + + +class AudioPlugin(MultiModalPlugin[AudioData]): + + def get_data_type(self) -> Type[AudioData]: + return AudioData + + def _get_hf_whisper_processor(self, model_config: ModelConfig, + whisper_config: WhisperConfig): + if whisper_config is None or whisper_config.whisper_processor is None: + return None + + return cached_get_whisper_processor( + whisper_config.whisper_processor, + revision=whisper_config.whisper_processor_revision, + ) + + def _default_input_processor( + self, data: AudioData, model_config: ModelConfig, + whisper_config: WhisperConfig) -> Dict[str, torch.Tensor]: + audio = data.audio + + processor = self._get_hf_whisper_processor(model_config, whisper_config) + if processor is None: + raise RuntimeError("No HuggingFace processor is available" + "to process the audio object") + try: + return processor(audio, return_tensors="pt", sampling_rate = 16000).to(model_config.dtype) + except Exception: + logger.error("Failed to process audio (%s)", audio) + raise diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py index 4789ce5ce4cfe..92cd484cea366 100644 --- a/vllm/multimodal/registry.py +++ b/vllm/multimodal/registry.py @@ -2,12 +2,13 @@ from typing import (TYPE_CHECKING, Any, Callable, Dict, Optional, Sequence, Tuple, Type, TypeVar) -from vllm.config import ModelConfig, VisionLanguageConfig +from vllm.config import ModelConfig, VisionLanguageConfig, WhisperConfig from vllm.logger import init_logger from .base import MultiModalData, MultiModalPlugin from .image import (ImageFeatureData, ImageFeaturePlugin, ImagePixelData, ImagePixelPlugin) +from .audio import AudioData, AudioPlugin if 
TYPE_CHECKING: import torch @@ -20,9 +21,9 @@ D = TypeVar("D", bound=MultiModalData) N = TypeVar("N", bound=Type["nn.Module"]) -MultiModalInputProcessor = Callable[[D, ModelConfig, VisionLanguageConfig], +MultiModalInputProcessor = Callable[[D, ModelConfig, VisionLanguageConfig, WhisperConfig], Dict[str, "torch.Tensor"]] -MultiModalDummyFactory = Callable[[int, ModelConfig, VisionLanguageConfig], +MultiModalDummyFactory = Callable[[int, ModelConfig, VisionLanguageConfig, WhisperConfig], Tuple["SequenceData", MultiModalData]] @@ -32,7 +33,7 @@ class MultiModalRegistry: according to its modality and the target model. """ - DEFAULT_PLUGINS = (ImageFeaturePlugin(), ImagePixelPlugin()) + DEFAULT_PLUGINS = (ImageFeaturePlugin(), ImagePixelPlugin(), AudioPlugin()) def __init__(self, *, @@ -119,6 +120,17 @@ def register_image_pixel_input( """ return self.register_input(ImagePixelData, processor) + def register_audio_input( + self, + processor: Optional[ + MultiModalInputProcessor[AudioData]] = None): + """ + Register an input processor for image pixel data to a model class. + + See :meth:`MultiModalPlugin.register_input_processor` for more details. + """ + return self.register_input(AudioData, processor) + def register_image_feature_input( self, processor: Optional[ @@ -129,6 +141,7 @@ def register_image_feature_input( See :meth:`MultiModalPlugin.register_input_processor` for more details. """ return self.register_input(ImageFeatureData, processor) + def process_input(self, data: MultiModalData, model_config: ModelConfig, vlm_config: VisionLanguageConfig): diff --git a/vllm/sequence.py b/vllm/sequence.py index 287e1b9df6165..2e94a98449b2e 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -14,7 +14,7 @@ from vllm.sampling_params import SamplingParams if TYPE_CHECKING: - from vllm.multimodal import MultiModalData + from vllm.multimodal import MultiModalData, WhisperData from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics @@ -260,6 +260,10 @@ def prompt_token_ids(self) -> List[int]: def multi_modal_data(self) -> Optional["MultiModalData"]: return self.inputs.get("multi_modal_data") + @property + def whisper_data(self) -> Optional["WhisperData"]: + return self.inputs.get("whisper_data") + @property def lora_int_id(self) -> int: return self.lora_request.lora_int_id if self.lora_request else 0 @@ -463,6 +467,12 @@ def multi_modal_data(self) -> Optional["MultiModalData"]: # We use the multi-modal data of an arbitrary sequence. return next(iter(self.seqs_dict.values())).multi_modal_data + @property + def whisper_data(self) -> Optional["WhisperData"]: + # All sequences in the group should have the same multi-modal data. + # We use the multi-modal data of an arbitrary sequence. 
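The plumbing in sequence.py and the registry above forwards one optional field end to end: processed features are stored under "whisper_data" in LLMInputs, surfaced via Sequence.whisper_data and SequenceGroup.whisper_data, and finally attached to SequenceGroupMetadata for the worker. A toy illustration of that accessor pattern, with plain dicts standing in for the real classes:

```python
from typing import Any, Dict, Optional

import torch

llm_inputs: Dict[str, Any] = {
    "prompt_token_ids": [50258, 50259, 50360],
    "prompt": None,
    "whisper_data": torch.zeros(128, 3000),  # processed input_features
}


def whisper_data(inputs: Dict[str, Any]) -> Optional[torch.Tensor]:
    # Same .get() used by the new Sequence/SequenceGroup properties.
    return inputs.get("whisper_data")


print(whisper_data(llm_inputs).shape)  # torch.Size([128, 3000])
```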
+ return next(iter(self.seqs_dict.values())).whisper_data + @property def lora_int_id(self) -> int: return self.lora_request.lora_int_id if self.lora_request else 0 @@ -641,6 +651,7 @@ def __init__( computed_block_nums: Optional[List[int]] = None, state: Optional[SequenceGroupState] = None, multi_modal_data: Optional["MultiModalData"] = None, + whisper_data: Optional["WhisperData"] = None, encoder_seq_data: Optional[SequenceData] = None, cross_block_table: Optional[List[int]] = None, ) -> None: @@ -653,6 +664,7 @@ def __init__( self.lora_request = lora_request self.computed_block_nums = computed_block_nums self.multi_modal_data = multi_modal_data + self.whisper_data = whisper_data self.state = SequenceGroupState() if state is None else state self.encoder_seq_data = encoder_seq_data self.cross_block_table = cross_block_table diff --git a/vllm/transformers_utils/tokenizer_group/tokenizer_group.py b/vllm/transformers_utils/tokenizer_group/tokenizer_group.py index 9614f01d2b955..4ab768f0e5ea8 100644 --- a/vllm/transformers_utils/tokenizer_group/tokenizer_group.py +++ b/vllm/transformers_utils/tokenizer_group/tokenizer_group.py @@ -49,9 +49,11 @@ def _raise_if_input_too_long(self, def encode(self, prompt: str, request_id: Optional[str] = None, - lora_request: Optional[LoRARequest] = None) -> List[int]: + lora_request: Optional[LoRARequest] = None, + add_special_tokens: bool = True, + ) -> List[int]: tokenizer = self.get_lora_tokenizer(lora_request) - ret = tokenizer.encode(prompt) + ret = tokenizer.encode(prompt, add_special_tokens=add_special_tokens) self._raise_if_input_too_long(ret, lora_request) return ret diff --git a/vllm/transformers_utils/whisper_processor.py b/vllm/transformers_utils/whisper_processor.py new file mode 100644 index 0000000000000..3b0f20ca101dd --- /dev/null +++ b/vllm/transformers_utils/whisper_processor.py @@ -0,0 +1,44 @@ +from functools import lru_cache +from typing import Optional + +from transformers import WhisperProcessor + +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +def get_whisper_processor( + processor_name: str, + *args, + trust_remote_code: bool = False, + revision: Optional[str] = None, + **kwargs, +) -> WhisperProcessor: + """Gets a whisper processor for the given model name via HuggingFace.""" + try: + processor: WhisperProcessor = WhisperProcessor.from_pretrained( + processor_name, + *args, + trust_remote_code=trust_remote_code, + revision=revision, + **kwargs) + except ValueError as e: + # If the error pertains to the processor class not existing or not + # currently being imported, suggest using the --trust-remote-code flag. + # Unlike AutoTokenizer, WhisperProcessor does not separate such errors + if not trust_remote_code: + err_msg = ( + "Failed to load the whisper processor. 
If the whisper processor is " + "a custom processor not yet available in the HuggingFace " + "transformers library, consider setting " + "`trust_remote_code=True` in LLM or using the " + "`--trust-remote-code` flag in the CLI.") + raise RuntimeError(err_msg) from e + else: + raise e + + return processor + + +cached_get_whisper_processor = lru_cache(get_whisper_processor) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index a321eafce1a2f..adda874e1c970 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -11,7 +11,7 @@ from vllm.attention import AttentionMetadata, get_attn_backend from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, - VisionLanguageConfig) + VisionLanguageConfig, WhisperConfig) from vllm.distributed import broadcast_tensor_dict from vllm.distributed.parallel_state import graph_capture from vllm.logger import init_logger @@ -26,6 +26,7 @@ from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata from vllm.utils import (CudaMemoryProfiler, get_kv_cache_torch_dtype, is_hip, is_pin_memory_available, make_tensor_with_pad) +from vllm.transformers_utils.whisper_processor import cached_get_whisper_processor logger = init_logger(__name__) @@ -86,6 +87,7 @@ def __init__( kv_cache_dtype: Optional[str] = "auto", is_driver_worker: bool = False, vision_language_config: Optional[VisionLanguageConfig] = None, + whisper_config: Optional[WhisperConfig] = None, return_hidden_states: bool = False, ): self.model_config = model_config @@ -97,6 +99,7 @@ def __init__( self.load_config = load_config self.is_driver_worker = is_driver_worker self.vision_language_config = vision_language_config + self.whisper_config = whisper_config self.return_hidden_states = return_hidden_states self.device = self.device_config.device @@ -139,6 +142,13 @@ def __init__( ) else: self.multi_modal_input_processor = None + + if self.whisper_config is not None: + self.whisper_processor = cached_get_whisper_processor( + self.whisper_config.whisper_processor + ) + else: + self.whisper_processor = None # Lazy initialization self.model: nn.Module # Set after load_model @@ -830,13 +842,13 @@ def profile_run(self) -> None: for group_id in range(max_num_seqs): seq_len = (max_num_batched_tokens // max_num_seqs + (group_id < max_num_batched_tokens % max_num_seqs)) - - if vlm_config is None: - seq_data = SequenceData([0] * seq_len) - dummy_multi_modal_data = None - else: + + if vlm_config is not None: seq_data, dummy_multi_modal_data = MULTIMODAL_REGISTRY \ .dummy_data_for_profiling(seq_len, model_config, vlm_config) + else: + seq_data = SequenceData([0] * seq_len) + dummy_multi_modal_data = None seq = SequenceGroupMetadata( request_id=str(group_id), diff --git a/vllm/worker/whisper_model_runner.py b/vllm/worker/whisper_model_runner.py new file mode 100644 index 0000000000000..e148fabb181ac --- /dev/null +++ b/vllm/worker/whisper_model_runner.py @@ -0,0 +1,523 @@ +import gc +import time +from typing import Dict, List, NamedTuple, Optional, Set, Tuple, Union + +import torch +import torch.nn as nn + +from vllm.attention import AttentionMetadata +from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, + ModelConfig, ParallelConfig, SchedulerConfig,
VisionLanguageConfig, WhisperConfig) +from vllm.distributed import broadcast_tensor_dict +from vllm.distributed.parallel_state import graph_capture +from vllm.model_executor import SamplingMetadata +from vllm.logger import init_logger +from vllm.lora.layers import LoRAMapping +from vllm.lora.request import LoRARequest +from vllm.sampling_params import SamplingParams +from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata +from vllm.worker.model_runner import ModelRunner + +logger = init_logger(__name__) + +_PAD_SLOT_ID = -1 +LORA_WARMUP_RANK = 8 +_BATCH_SIZE_ALIGNMENT = 8 +_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [ + _BATCH_SIZE_ALIGNMENT * i for i in range(1, 33) +] +_NUM_WARMUP_ITERS = 2 + +class WhisperModelRunner(ModelRunner): + + def __init__( + self, + model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + cache_config: CacheConfig, + load_config: LoadConfig, + lora_config: Optional[LoRAConfig], + kv_cache_dtype: Optional[str] = "auto", + is_driver_worker: bool = False, + vision_language_config: Optional[VisionLanguageConfig] = None, + whisper_config: Optional[WhisperConfig] = None, + ): + super().__init__(model_config, + parallel_config, + scheduler_config, + device_config, + cache_config, + load_config, + lora_config=lora_config, + kv_cache_dtype=kv_cache_dtype, + is_driver_worker=is_driver_worker, + vision_language_config=vision_language_config, + whisper_config=whisper_config) + + @torch.inference_mode() + def execute_model( + self, + seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], + kv_caches: List[torch.Tensor], + ) -> Optional[SamplerOutput]: + (whisper_data, input_tokens, input_positions, attn_metadata, sampling_metadata, + lora_requests, lora_mapping, multi_modal_kwargs + ) = self.prepare_input_tensors(seq_group_metadata_list) + + if self.lora_config: + self.set_active_loras(lora_requests, lora_mapping) + + # Currently cuda graph is only supported by the decode phase. + prefill_meta = attn_metadata.prefill_metadata + decode_meta = attn_metadata.decode_metadata + if prefill_meta is None and decode_meta.use_cuda_graph: + graph_batch_size = input_tokens.shape[0] + model_executable = self.graph_runners[graph_batch_size] + else: + model_executable = self.model + + hidden_states = model_executable( + whisper_data=whisper_data, + input_ids=input_tokens, + positions=input_positions, + kv_caches=kv_caches, + attn_metadata=attn_metadata, + **multi_modal_kwargs, + ) + + # Compute the logits. + logits = self.model.compute_logits(hidden_states, sampling_metadata) + + # Only perform sampling in the driver worker. 
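# --- Illustrative sketch (not part of the diff) -----------------------------
# Why the dummy decoder prompt is [50258, 50259, 50360] and why encode() grew
# an `add_special_tokens` switch.  Whisper decoder prompts are built from
# special tokens the caller supplies explicitly (as in the example script),
# so presumably the whisper path can pass add_special_tokens=False to keep the
# tokenizer from injecting its own.  Uses only the public HF API.
from transformers import WhisperTokenizer

tok = WhisperTokenizer.from_pretrained("openai/whisper-large-v3")

# The fixed dummy prompt used for profiling decodes to the usual prefix:
print(tok.decode([50258, 50259, 50360]))
# '<|startoftranscript|><|en|><|transcribe|>'

# With the default add_special_tokens=True, the HF tokenizer adds its own
# prefix/suffix special tokens around the text; disabling it keeps only the
# tokens the caller asked for.
print(tok.encode("<|startoftranscript|>", add_special_tokens=False))  # [50258]
# -----------------------------------------------------------------------------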
+ if not self.is_driver_worker: + return None + + output = self.model.sample( + logits=logits, + sampling_metadata=sampling_metadata, + ) + + return output + + def _prepare_encoder_model_input( + self, seq_group_metadata_list: List[SequenceGroupMetadata], + attn_metadata: AttentionMetadata): + + whisper_data_list = [] + + for seq_group_metadata in seq_group_metadata_list: + seq_ids = list(seq_group_metadata.seq_data.keys()) + is_prompt = seq_group_metadata.is_prompt + + for seq_id in seq_ids: + + whisper_data = seq_group_metadata.whisper_data + if whisper_data is not None: + # Turn a raw waveform into log-mel input features. + if self.whisper_processor is None: + raise ValueError("Whisper Processor not initialized") + + if len(whisper_data.shape) == 1: + whisper_data = self.whisper_processor( + whisper_data, + sampling_rate=self.whisper_config.sample_rate, + return_tensors='pt', + ) + whisper_data = whisper_data.to(self.model_config.dtype).input_features[0] + + whisper_data_list.append(whisper_data) + + # Only the first sequence group's features are used; the whisper path + # currently assumes a single request per batch (max_num_seqs == 1). + whisper_data = whisper_data_list[0].cuda() + + return whisper_data + + def prepare_input_tensors( + self, + seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, AttentionMetadata, + SamplingMetadata, Set[LoRARequest], LoRAMapping, + Dict[str, torch.Tensor]]: + if self.is_driver_worker: + assert seq_group_metadata_list is not None + # Prepare input tensors. + ( + input_tokens, + input_positions, + attn_metadata, + seq_lens, + query_lens, + lora_mapping, + lora_requests, + multi_modal_kwargs, + slot_mapping, + num_prefill_tokens, + num_decode_tokens, + num_prefills, + ) = self._prepare_model_input(seq_group_metadata_list) + whisper_data = self._prepare_encoder_model_input(seq_group_metadata_list, attn_metadata) + sampling_metadata = SamplingMetadata.prepare( + seq_group_metadata_list, seq_lens, query_lens, self.device, + self.pin_memory) + + metadata_dict = { + "whisper_data": whisper_data, + "input_tokens": input_tokens, + "input_positions": input_positions, + "selected_token_indices": sampling_metadata.selected_token_indices, + "lora_requests": lora_requests, + "lora_mapping": lora_mapping, + "multi_modal_kwargs": multi_modal_kwargs, + "num_prefill_tokens": num_prefill_tokens, + "num_decode_tokens": num_decode_tokens, + "slot_mapping": slot_mapping, + "num_prefills": num_prefills, + } + if attn_metadata: + metadata_dict.update(attn_metadata.asdict_zerocopy()) + broadcast_tensor_dict(metadata_dict, src=0) + else: + metadata_dict = broadcast_tensor_dict(src=0) + whisper_data = metadata_dict.pop("whisper_data") + input_tokens = metadata_dict.pop("input_tokens") + input_positions = metadata_dict.pop("input_positions") + selected_token_indices = metadata_dict.pop("selected_token_indices") + lora_mapping = metadata_dict.pop("lora_mapping") + lora_requests = metadata_dict.pop("lora_requests") + multi_modal_kwargs = metadata_dict.pop("multi_modal_kwargs") + if metadata_dict: + attn_metadata = self.attn_backend.make_metadata( + **metadata_dict) + else: + attn_metadata = None + sampling_metadata = SamplingMetadata( + seq_groups=None, + selected_token_indices=selected_token_indices, + categorized_sample_indices=None, + num_prompts=0, + ) + + return (whisper_data, input_tokens, input_positions, attn_metadata, + sampling_metadata, lora_requests, lora_mapping, + multi_modal_kwargs) + + + @torch.inference_mode() + def profile_run(self) -> None: + # Enable top-k sampling to reflect the accurate memory usage.
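# --- Illustrative sketch (not part of the diff) -----------------------------
# The shape contract behind `_prepare_encoder_model_input`: a raw 1-D waveform
# becomes fixed-size log-mel features, and only `input_features[0]` is kept.
# Standalone example against the public HF API; requires `transformers`.
import numpy as np
import torch
from transformers import WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-large-v3")

waveform = np.zeros(5 * 16000, dtype=np.float32)   # 5 s of silence at 16 kHz
assert waveform.ndim == 1                           # hits the `len(shape) == 1` branch

features = processor(waveform, sampling_rate=16000, return_tensors="pt")
features = features.to(torch.bfloat16).input_features[0]

# The time axis is always padded/truncated to 30 s (3000 frames); the number
# of mel bins depends on the checkpoint (128 for whisper-large-v3, 80 for
# earlier Whisper checkpoints).
print(features.shape, features.dtype)
# -----------------------------------------------------------------------------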
+ sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1) + max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens + max_num_seqs = self.scheduler_config.max_num_seqs + # This represents the maximum number of different requests + # that will have unique loras, and therefore the max amount of memory + # consumption. Create dummy lora request copies from the lora request + # passed in, which contains a lora from the lora warmup path. + dummy_lora_requests: List[LoRARequest] = [] + dummy_lora_requests_per_seq: List[LoRARequest] = [] + if self.lora_config: + assert self.lora_manager is not None + with self.lora_manager.dummy_lora_cache(): + for idx in range(self.lora_config.max_loras): + lora_id = idx + 1 + dummy_lora_request = LoRARequest( + lora_name=f"warmup_{lora_id}", + lora_int_id=lora_id, + lora_local_path="/not/a/real/path", + ) + self.lora_manager.add_dummy_lora(dummy_lora_request, + rank=LORA_WARMUP_RANK) + dummy_lora_requests.append(dummy_lora_request) + dummy_lora_requests_per_seq = [ + dummy_lora_requests[idx % len(dummy_lora_requests)] + for idx in range(max_num_seqs) + ] + + # Profile memory usage with max_num_sequences sequences and the total + # number of tokens equal to max_num_batched_tokens. + seqs: List[SequenceGroupMetadata] = [] + # Additional GPU memory may be needed for vision or audio encoding, which + # needs to be accounted for when calculating the GPU blocks for + # the vLLM block manager. + # To exercise the worst scenario for GPU memory consumption, + # the number of seqs (batch_size) is chosen to maximize the number + # of multi-modal inputs processed. + model_config = self.model_config + vlm_config = self.vision_language_config + whisper_config = self.whisper_config + + if vlm_config: + max_num_seqs = min( + max_num_seqs, + int(max_num_batched_tokens / vlm_config.image_feature_size)) + for group_id in range(max_num_seqs): + seq_len = (max_num_batched_tokens // max_num_seqs + + (group_id < max_num_batched_tokens % max_num_seqs)) + + whisper_data = None + if vlm_config is not None: + seq_data, dummy_multi_modal_data = MULTIMODAL_REGISTRY \ + .dummy_data_for_profiling(seq_len, model_config, vlm_config) + else: + seq_data = SequenceData([0] * seq_len) + dummy_multi_modal_data = None + + if whisper_config is not None: + # Dummy 30 s waveform of silence at the configured sample rate. + whisper_data = torch.zeros( + (30 * whisper_config.sample_rate), + dtype=torch.float16) + + seq = SequenceGroupMetadata( + request_id=str(group_id), + is_prompt=True, + seq_data={group_id: seq_data}, + sampling_params=sampling_params, + block_tables=None, + lora_request=dummy_lora_requests_per_seq[group_id] + if dummy_lora_requests_per_seq else None, + multi_modal_data=dummy_multi_modal_data, + whisper_data=whisper_data, + ) + seqs.append(seq) + + # Run the model with the dummy inputs. + num_layers = self.model_config.get_num_layers(self.parallel_config) + kv_caches = [None] * num_layers + self.execute_model(seqs, kv_caches) + torch.cuda.synchronize() + return + + @torch.inference_mode() + def capture_model(self, kv_caches: List[torch.Tensor]) -> None: + """Cuda graph capture a model. + + Note that CUDA graph's performance gain is negligible if number + of batched tokens are larger than 200. And since CUDA graph + requires fixed sized tensors, supporting large/variable batch + size requires high GPU memory overhead. Thus, vLLM only captures + decoding requests. Mixed batch (chunked prefill + decoding) or + prefill requests are not captured. + + Since it is used for decoding-only, it assumes there's only 1 token + per sequence in the batch.
+ """ + assert not self.model_config.enforce_eager + logger.info("Capturing the model for CUDA graphs. This may lead to " + "unexpected consequences if the model is not static. To " + "run the model in eager mode, set 'enforce_eager=True' or " + "use '--enforce-eager' in the CLI.") + logger.info("CUDA graphs can take additional 1~3 GiB memory per GPU. " + "If you are running out of memory, consider decreasing " + "`gpu_memory_utilization` or enforcing eager mode. " + "You can also reduce the `max_num_seqs` as needed " + "to decrease memory usage.") + start_time = time.perf_counter() + + # Prepare dummy inputs. These will be reused for all batch sizes. + max_batch_size = max(_BATCH_SIZES_TO_CAPTURE) + input_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda() + input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda() + slot_mapping = torch.empty(max_batch_size, dtype=torch.long).cuda() + slot_mapping.fill_(_PAD_SLOT_ID) + seq_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda() + block_tables = torch.from_numpy(self.graph_block_tables).cuda() + whisper_data = torch.zeros( + (30 * self.whisper_config.sample_rate), + dtype=torch.float16) + whisper_data = self.whisper_processor( + whisper_data, + sampling_rate = self.whisper_config.sample_rate, + return_tensors = 'pt', + ) + whisper_data = whisper_data.to(self.model_config.dtype).input_features[0].cuda() + + # Prepare buffer for outputs. These will be reused for all batch sizes. + # It will be filled after the first graph capture. + hidden_states: Optional[torch.Tensor] = None + + graph_batch_size = _get_graph_batch_size( + self.scheduler_config.max_num_seqs) + batch_size_capture_list = [ + bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size + ] + + with graph_capture() as graph_capture_context: + # NOTE: Capturing the largest batch size first may help reduce the + # memory usage of CUDA graph. + for batch_size in reversed(batch_size_capture_list): + # Create dummy attn_metadata. + attn_metadata = self.attn_backend.make_metadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=batch_size, + slot_mapping=slot_mapping[:batch_size], + seq_lens=None, + seq_lens_tensor=seq_lens[:batch_size], + max_query_len=None, + max_prefill_seq_len=0, + max_decode_seq_len=self.max_seq_len_to_capture, + query_start_loc=None, + seq_start_loc=None, + context_lens_tensor=None, + block_tables=block_tables[:batch_size], + use_cuda_graph=True, + ) + + if self.lora_config: + lora_mapping = LoRAMapping( + [0] * batch_size, + [0] * batch_size, + ) + self.set_active_loras(set(), lora_mapping) + + graph_runner = CUDAGraphRunner(self.model) + hidden_states = graph_runner.capture( + whisper_data, + input_tokens[:batch_size], + input_positions[:batch_size], + hidden_states[:batch_size] + if hidden_states is not None else None, + kv_caches, + attn_metadata, + memory_pool=self.graph_memory_pool, + stream=graph_capture_context.stream, + ) + self.graph_memory_pool = graph_runner.graph.pool() + self.graph_runners[batch_size] = graph_runner + + end_time = time.perf_counter() + elapsed_time = end_time - start_time + # This usually takes < 10 seconds. 
+ logger.info("Graph capturing finished in %.0f secs.", elapsed_time) + +class CUDAGraphRunner: + + def __init__(self, model: nn.Module): + self.model = model + self.input_buffers: Dict[str, torch.Tensor] = {} + self.output_buffers: Dict[str, torch.Tensor] = {} + + self._graph: Optional[torch.cuda.CUDAGraph] = None + + @property + def graph(self): + assert self._graph is not None + return self._graph + + def capture( + self, + whisper_data: torch.Tensor, + input_ids: torch.Tensor, + positions: torch.Tensor, + hidden_states: Optional[torch.Tensor], + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + memory_pool: Optional[Tuple[int, int]], + stream: torch.cuda.Stream, + **kwargs, + ) -> torch.Tensor: + assert self._graph is None + # Run the model a few times without capturing the graph. + # This is to make sure that the captured graph does not include the + # kernel launches for initial benchmarking (e.g., Triton autotune). + # Note one iteration is not enough for torch.jit.script + for _ in range(_NUM_WARMUP_ITERS): + self.model( + whisper_data, + input_ids, + positions, + kv_caches, + attn_metadata, + **kwargs, + ) + torch.cuda.synchronize() + + # Capture the graph. + self._graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream): + output_hidden_states = self.model( + whisper_data, + input_ids, + positions, + kv_caches, + attn_metadata, + **kwargs, + ) + if hidden_states is not None: + hidden_states.copy_(output_hidden_states) + else: + hidden_states = output_hidden_states + del output_hidden_states + # make sure `output_hidden_states` is deleted + # in the graph's memory pool + gc.collect() + torch.cuda.synchronize() + + # Save the input and output buffers. + self.input_buffers = { + "whisper_data": whisper_data, + "input_ids": input_ids, + "positions": positions, + "kv_caches": kv_caches, + "slot_mapping": attn_metadata.slot_mapping, + "seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor, + "block_tables": attn_metadata.decode_metadata.block_tables, + } + self.output_buffers = {"hidden_states": hidden_states} + return hidden_states + + def forward( + self, + whisper_data: torch.Tensor, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + **kwargs, + ) -> torch.Tensor: + # KV caches are fixed tensors, so we don't need to copy them. + del kv_caches + + # Copy the input tensors to the input buffers. + self.input_buffers["whisper_data"].copy_(whisper_data, non_blocking=True) + self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True) + self.input_buffers["positions"].copy_(positions, non_blocking=True) + self.input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping, + non_blocking=True) + self.input_buffers["seq_lens_tensor"].copy_( + attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True) + self.input_buffers["block_tables"].copy_( + attn_metadata.decode_metadata.block_tables, non_blocking=True) + # Run the graph. + self.graph.replay() + + # Return the output tensor. + return self.output_buffers["hidden_states"] + + def __call__(self, *args, **kwargs): + return self.forward(*args, **kwargs) + + +def _get_graph_batch_size(batch_size: int) -> int: + """Returns the padded batch size given actual batch size. + + Batch sizes are 1, 2, 4, _BATCH_SIZE_ALIGNMENT, + 2*_BATCH_SIZE_ALIGNMENT, 3*_BATCH_SIZE_ALIGNMENT... 
+ """ + if batch_size <= 2: + return batch_size + elif batch_size <= 4: + return 4 + else: + return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) // + _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT) + + +def _is_block_tables_empty(block_tables: Union[None, Dict]): + """ + Check if block_tables is None or a dictionary with all None values. + """ + if block_tables is None: + return True + if isinstance(block_tables, dict) and all( + value is None for value in block_tables.values()): + return True + return False \ No newline at end of file diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index c60764ef1bed8..6cde97b53227e 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -8,7 +8,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, - SpeculativeConfig, VisionLanguageConfig) + SpeculativeConfig, VisionLanguageConfig, WhisperConfig) from vllm.distributed import (broadcast_tensor_dict, ensure_model_parallel_initialized, init_distributed_environment, @@ -20,6 +20,7 @@ from vllm.worker.cache_engine import CacheEngine from vllm.worker.embedding_model_runner import EmbeddingModelRunner from vllm.worker.model_runner import ModelRunner +from vllm.worker.whisper_model_runner import WhisperModelRunner from vllm.worker.worker_base import WorkerBase @@ -44,6 +45,7 @@ def __init__( distributed_init_method: str, lora_config: Optional[LoRAConfig] = None, vision_language_config: Optional[VisionLanguageConfig] = None, + whisper_config: Optional[WhisperConfig] = None, speculative_config: Optional[SpeculativeConfig] = None, is_driver_worker: bool = False, ) -> None: @@ -69,6 +71,10 @@ def __init__( if self.vision_language_config: assert not self.lora_config, ( "To be tested: vision language model with LoRA settings.") + self.whisper_config = whisper_config + if self.whisper_config: + assert not self.lora_config, ( + "Whisper does not support with LoRA settings.") # Return hidden states from target model if the draft model is an # mlp_speculator @@ -78,8 +84,12 @@ def __init__( or (speculative_config.draft_model_config.hf_config.model_type != "mlp_speculator") else {"return_hidden_states": True} - ModelRunnerClass = (EmbeddingModelRunner if - self.model_config.embedding_mode else ModelRunner) + if self.model_config.embedding_mode: + ModelRunnerClass = EmbeddingModelRunner + elif self.whisper_config is not None: + ModelRunnerClass = WhisperModelRunner + else: + ModelRunnerClass = ModelRunner self.model_runner = ModelRunnerClass( model_config, parallel_config, @@ -91,6 +101,7 @@ def __init__( kv_cache_dtype=self.cache_config.cache_dtype, is_driver_worker=is_driver_worker, vision_language_config=vision_language_config, + whisper_config=whisper_config, **speculative_args, ) # Uninitialized cache engine. Will be initialized by