From 2d3347fcf79866ac18a8fc812262a76c453a46fa Mon Sep 17 00:00:00 2001 From: huseinzol05 Date: Thu, 27 Jun 2024 13:46:44 +0800 Subject: [PATCH 01/16] initial --- vllm/model_executor/models/whisper.py | 200 ++++++++++++++++++++++++++ 1 file changed, 200 insertions(+) create mode 100644 vllm/model_executor/models/whisper.py diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py new file mode 100644 index 0000000000000..d80acbe1d9d2e --- /dev/null +++ b/vllm/model_executor/models/whisper.py @@ -0,0 +1,200 @@ +from typing import Any, Dict, Iterable, List, Optional, Tuple + +import torch +from torch import nn +from transformers import WhisperConfig + +from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) +from vllm.model_executor.layers.activation import FastGELU +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, kv_cache_scales_loader) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import SamplerOutput +from vllm.utils import is_hip, print_warning_once + +def sinusoids(length: int, channels: int, max_timescale: float = 10000) -> torch.Tensor: + """Returns sinusoids for positional embedding""" + if channels % 2 != 0: + raise ValueError( + f"Number of channels has to be divisible by 2 for sinusoidal positional embeddings, got {channels} channels." + ) + log_timescale_increment = math.log(max_timescale) / (channels // 2 - 1) + inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2)) + scaled_time = torch.arange(length).view(-1, 1) * inv_timescales.view(1, -1) + return torch.cat([scaled_time.sin(), scaled_time.cos()], dim=1) + +class WhisperPositionalEmbedding(nn.Embedding): + def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None): + super().__init__(num_positions, embedding_dim) + + def forward(self, input_ids, past_key_values_length=0, position_ids=None): + if position_ids is None: + return self.weight[past_key_values_length : past_key_values_length + input_ids.shape[1]] + else: + return self.weight[position_ids] + +class WhisperAttention(nn.Module): + def __init__( + self, + embed_dim: int, + num_heads: int, + is_decoder: bool = False, + bias: bool = True, + is_causal: bool = False, + config: Optional[WhisperConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + cache_config: Optional[CacheConfig] = None, + ): + super().__init__() + tp_size = get_tensor_model_parallel_world_size() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.num_kv_heads = max(1, self.num_heads // tp_size) + self.head_dim = embed_dim // num_heads + self.config = config + + if (self.head_dim * num_heads) != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {num_heads})." 
+ ) + self.scaling = self.head_dim**-0.5 + self.is_decoder = is_decoder + self.is_causal = is_causal + + self.k_proj = RowParallelLinear( + input_size = embed_dim, + output_size = embed_dim, + bias = False + ) + self.k_proj = RowParallelLinear( + input_size = embed_dim, + output_size = embed_dim, + bias = bias + ) + self.q_proj = RowParallelLinear( + input_size = embed_dim, + output_size = embed_dim, + bias = bias + ) + self.out_proj = RowParallelLinear( + input_size = embed_dim, + output_size = embed_dim, + bias = bias + ) + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config + ) + + def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim) + + def forward( + self, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + ): + is_cross_attention = key_value_states is not None + bsz, tgt_len, _ = hidden_states.size() + q, _ = self.q_proj(hidden_states) * self.scaling + k, _ = self.k_proj(key_value_states) + v, _ = self.v_proj(hidden_states) + if is_cross_attention: + # reuse k,v, cross_attentions + q = self._shape(q, tgt_len, bsz) + k = self._shape(k, -1, bsz) + v = self._shape(v, -1, bsz) + + attn_output = torch.nn.functional.scaled_dot_product_attention( + q, + k, + v, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=self.is_causal and tgt_len > 1, + ) + attn_output = attn_output.reshape(bsz, q_len, -1) + output = self.out_proj(attn_output) + elif past_key_value is not None: + + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + + output, _ = self.o_proj(attn_output) + return output + +# class WhisperAttention(nn.Module): +# def __init__( +# self, +# embed_dim: int, +# num_heads: int, +# dropout: float = 0.0, +# is_decoder: bool = False, +# bias: bool = True, +# is_causal: bool = False, +# config: Optional[WhisperConfig] = None, +# cache_config: Optional[CacheConfig] = None, +# ): + +class WhisperEncoderLayer(nn.Module): + def __init__( + self, + config: WhisperConfig, + quant_config: Optional[QuantizationConfig] = None, + cache_config: Optional[CacheConfig] = None, + ): + super().__init__() + self.embed_dim = config.d_model + self.self_attn = WhisperAttention( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + config=config, + quant_config=quant_config, + cache_config=cache_config, + ) + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.activation_fn = FastGELU() + self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) + self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ): + + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + hidden_states = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = self.fc2(hidden_states) + hidden_states = residual + hidden_states From 860b70aae37c3203cc4035b2ab78aaf0c4838a38 Mon Sep 17 00:00:00 2001 
From: huseinzol05 Date: Thu, 27 Jun 2024 18:11:27 +0800 Subject: [PATCH 02/16] added whisper encoder decoder --- vllm/model_executor/models/whisper.py | 250 +++++++++++++++++++++----- 1 file changed, 209 insertions(+), 41 deletions(-) diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index d80acbe1d9d2e..88f5221082440 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -40,11 +40,8 @@ class WhisperPositionalEmbedding(nn.Embedding): def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None): super().__init__(num_positions, embedding_dim) - def forward(self, input_ids, past_key_values_length=0, position_ids=None): - if position_ids is None: - return self.weight[past_key_values_length : past_key_values_length + input_ids.shape[1]] - else: - return self.weight[position_ids] + def forward(self, input_ids, position_ids=None): + return self.weight[position_ids] class WhisperAttention(nn.Module): def __init__( @@ -52,8 +49,8 @@ def __init__( embed_dim: int, num_heads: int, is_decoder: bool = False, - bias: bool = True, is_causal: bool = False, + bias: bool = True, config: Optional[WhisperConfig] = None, quant_config: Optional[QuantizationConfig] = None, cache_config: Optional[CacheConfig] = None, @@ -78,22 +75,26 @@ def __init__( self.k_proj = RowParallelLinear( input_size = embed_dim, output_size = embed_dim, - bias = False + bias = False, + quant_config=quant_config, ) self.k_proj = RowParallelLinear( input_size = embed_dim, output_size = embed_dim, - bias = bias + bias = bias, + quant_config=quant_config ) self.q_proj = RowParallelLinear( input_size = embed_dim, output_size = embed_dim, - bias = bias + bias = bias, + quant_config=quant_config ) self.out_proj = RowParallelLinear( input_size = embed_dim, output_size = embed_dim, - bias = bias + bias = bias, + quant_config=quant_config ) self.attn = Attention( self.num_heads, @@ -110,51 +111,44 @@ def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int): def forward( self, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value = None, + kv_cache: torch.Tensor = None, + attn_metadata: AttentionMetadata = None, ): - is_cross_attention = key_value_states is not None + is_cross_attention = past_key_value is not None bsz, tgt_len, _ = hidden_states.size() q, _ = self.q_proj(hidden_states) * self.scaling - k, _ = self.k_proj(key_value_states) - v, _ = self.v_proj(hidden_states) - if is_cross_attention: - # reuse k,v, cross_attentions + + + if kv_cache is None: q = self._shape(q, tgt_len, bsz) - k = self._shape(k, -1, bsz) - v = self._shape(v, -1, bsz) + + if is_cross_attention: + k = past_key_value[0] + v = past_key_value[1] + else: + k, _ = self.k_proj(key_value_states) + v, _ = self.v_proj(hidden_states) + k = self._shape(k, -1, bsz) + v = self._shape(v, -1, bsz) attn_output = torch.nn.functional.scaled_dot_product_attention( q, k, v, - attn_mask=attention_mask, dropout_p=0.0, is_causal=self.is_causal and tgt_len > 1, ) attn_output = attn_output.reshape(bsz, q_len, -1) output = self.out_proj(attn_output) - elif past_key_value is not None: - + else: + k, _ = self.k_proj(hidden_states) + v, _ = self.v_proj(hidden_states) attn_output = self.attn(q, k, v, kv_cache, attn_metadata) output, _ = self.o_proj(attn_output) return output -# class WhisperAttention(nn.Module): -# def __init__( -# self, -# embed_dim: int, -# 
num_heads: int, -# dropout: float = 0.0, -# is_decoder: bool = False, -# bias: bool = True, -# is_causal: bool = False, -# config: Optional[WhisperConfig] = None, -# cache_config: Optional[CacheConfig] = None, -# ): - class WhisperEncoderLayer(nn.Module): def __init__( self, @@ -173,8 +167,18 @@ def __init__( ) self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.activation_fn = FastGELU() - self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) - self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) + self.fc1 = RowParallelLinear( + input_size = self.embed_dim, + output_size = config.encoder_ffn_dim, + bias = True, + quant_config=quant_config, + ) + self.fc2 = RowParallelLinear( + input_size = config.encoder_ffn_dim, + output_size = self.embed_dim, + bias = True, + quant_config=quant_config, + ) self.final_layer_norm = nn.LayerNorm(self.embed_dim) def forward( @@ -188,9 +192,6 @@ def forward( hidden_states = self.self_attn_layer_norm(hidden_states) hidden_states = self.self_attn( hidden_states=hidden_states, - attention_mask=attention_mask, - layer_head_mask=layer_head_mask, - output_attentions=output_attentions, ) hidden_states = residual + hidden_states residual = hidden_states @@ -198,3 +199,170 @@ def forward( hidden_states = self.activation_fn(self.fc1(hidden_states)) hidden_states = self.fc2(hidden_states) hidden_states = residual + hidden_states + + return hidden_states + +class WhisperDecoderLayer(nn.Module): + def __init__( + self, + config: WhisperConfig, + quant_config: Optional[QuantizationConfig] = None, + cache_config: Optional[CacheConfig] = None, + ): + super().__init__() + self.embed_dim = config.d_model + self.self_attn = WhisperAttention( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + is_decoder=True, + is_causal=True, + config=config, + quant_config=quant_config, + cache_config=cache_config, + ) + self.activation_fn = FastGELU() + + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.encoder_attn = WhisperAttention( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + is_decoder=True, + config=config, + quant_config=quant_config, + cache_config=cache_config, + ) + self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.fc1 = RowParallelLinear( + input_size = self.embed_dim, + output_size = config.decoder_ffn_dim, + bias = True, + quant_config=quant_config, + ) + self.fc2 = RowParallelLinear( + input_size = config.decoder_ffn_dim, + output_size = self.embed_dim, + bias = True, + quant_config=quant_config, + ) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + past_key_value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + hidden_states = self.self_attn( + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata + ) + hidden_states = residual + hidden_states + + hidden_states = self.self_attn( + hidden_states=hidden_states, + past_key_value=past_key_value, + attn_metadata=attn_metadata + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = self.fc2(hidden_states) + hidden_states = residual + hidden_states + + return outputs + +class WhisperEncoder(nn.Module): + def __init__( + self, + config: 
WhisperConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + embed_dim = config.d_model + self.num_mel_bins = config.num_mel_bins + self.padding_idx = config.pad_token_id + self.max_source_positions = config.max_source_positions + self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 + + self.conv1 = nn.Conv1d(self.num_mel_bins, embed_dim, kernel_size=3, padding=1) + self.conv2 = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, stride=2, padding=1) + + self.embed_positions = nn.Embedding(self.max_source_positions, embed_dim) + + self.layers = nn.ModuleList([WhisperEncoderLayer(config, quant_config=quant_config, cache_config=cache_config) + for layer_idx in range(config.decoder_layers)]) + + self.layer_norm = nn.LayerNorm(config.d_model) + + def forward( + self, + input_features, + ): + inputs_embeds = nn.functional.gelu(self.conv1(input_features)) + inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds)) + inputs_embeds = inputs_embeds.permute(0, 2, 1) + embed_pos = self.embed_positions.weight + + hidden_states = inputs_embeds + embed_pos + for idx, encoder_layer in enumerate(self.layers): + hidden_states = encoder_layer(hidden_states) + + hidden_states = self.layer_norm(hidden_states) + return hidden_states + +class WhisperDecoder(nn.Module): + def __init__( + self, + config: WhisperConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.layerdrop = config.decoder_layerdrop + self.padding_idx = config.pad_token_id + self.max_target_positions = config.max_target_positions + self.max_source_positions = config.max_source_positions + self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 + + self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.d_model) + self.embed_positions = WhisperPositionalEmbedding(self.max_target_positions, config.d_model) + self.layers = nn.ModuleList([WhisperDecoderLayer(config, quant_config=quant_config, cache_config=cache_config) + for layer_idx in range(config.decoder_layers)]) + self.layer_norm = nn.LayerNorm(config.d_model) + + def forward( + self, + input_ids, + positions: torch.Tensor, + past_key_values, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + inputs_embeds: Optional[torch.Tensor] = None, + ): + inputs_embeds = self.embed_tokens(input_ids) + positions = self.embed_positions(input_ids, positions) + hidden_states = inputs_embeds + positions + for idx, decoder_layer in enumerate(self.layers): + hidden_states = decoder_layer( + hidden_states, + kv_cache=kv_caches[idx], + output_attentions=output_attentions, + attn_metadata=attn_metadata + ) + + hidden_states = self.layer_norm(hidden_states) + return hidden_states + + + + + From c15bafc1756fac227f9ad36f80f8886935a804e2 Mon Sep 17 00:00:00 2001 From: huseinzol05 Date: Fri, 28 Jun 2024 07:31:38 +0800 Subject: [PATCH 03/16] added configs --- vllm/config.py | 24 ++++ vllm/engine/arg_utils.py | 33 +++++- vllm/model_executor/models/whisper.py | 113 ++++++++++++++++--- vllm/transformers_utils/whisper_processor.py | 44 ++++++++ 4 files changed, 198 insertions(+), 16 deletions(-) create mode 100644 vllm/transformers_utils/whisper_processor.py diff --git a/vllm/config.py b/vllm/config.py index 0217a2b569928..7abca2030478c 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1218,6 +1218,30 @@ def as_cli_args_dict(self) -> Dict[str, Any]: return result +@dataclass 
+class WhisperConfig: + whisper_processor: Optional[str] + whisper_processor_revision: Optional[str] + + def as_cli_args_dict(self) -> Dict[str, Any]: + """Flatten vision language config to pure args. + + Compatible with what llm entrypoint expects. + """ + result: Dict[str, Any] = {} + for f in fields(self): + value = getattr(self, f.name) + if isinstance(value, enum.Enum): + result[f.name] = value.name.lower() + elif isinstance(value, tuple): + result[f.name] = ",".join([str(item) for item in value]) + else: + result[f.name] = value + + result["disable_image_processor"] = self.image_processor is None + + return result + _STR_DTYPE_TO_TORCH_DTYPE = { "half": torch.float16, diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 16374098b23d4..e6cb811f7e051 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -9,7 +9,7 @@ EngineConfig, LoadConfig, LoRAConfig, ModelConfig, ObservabilityConfig, ParallelConfig, SchedulerConfig, SpeculativeConfig, TokenizerPoolConfig, - VisionLanguageConfig) + VisionLanguageConfig, WhisperConfig) from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS from vllm.utils import FlexibleArgumentParser, str_to_int_tuple @@ -88,6 +88,11 @@ class EngineArgs: image_processor_revision: Optional[str] = None disable_image_processor: bool = False + # Related to Whisper + whisper_input_type = 'input_features' + whisper_processor: Optional[str] = None + whisper_processor_revision: Optional[str] = None + scheduler_delay_factor: float = 0.0 enable_chunked_prefill: bool = False @@ -156,6 +161,30 @@ def add_cli_args_for_vlm( return parser + @staticmethod + def add_cli_args_for_whisper( + parser: FlexibleArgumentParser) -> FlexibleArgumentParser: + parser.add_argument( + '--image-processor', + type=str, + default=EngineArgs.image_processor, + help='Name or path of the huggingface image processor to use. ' + 'If unspecified, model name or path will be used.') + parser.add_argument( + '--image-processor-revision', + type=str, + default=None, + help='Revision of the huggingface image processor version to use. ' + 'It can be a branch name, a tag name, or a commit id. 
' + 'If unspecified, will use the default version.') + parser.add_argument( + '--disable-image-processor', + action='store_true', + help='Disables the use of image processor, even if one is defined ' + 'for the model on huggingface.') + + return parser + @staticmethod def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: """Shared CLI arguments for vLLM engine.""" @@ -772,6 +801,8 @@ def create_engine_config(self, ) -> EngineConfig: ) else: vision_language_config = None + + if self.whisper_input_type decoding_config = DecodingConfig( guided_decoding_backend=self.guided_decoding_backend) diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index 88f5221082440..478ce5b32612c 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -111,21 +111,31 @@ def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int): def forward( self, hidden_states: torch.Tensor, + encoder_hidden_states = None, past_key_value = None, kv_cache: torch.Tensor = None, attn_metadata: AttentionMetadata = None, ): - is_cross_attention = past_key_value is not None + is_cross_attention = encoder_hidden_states is not None bsz, tgt_len, _ = hidden_states.size() q, _ = self.q_proj(hidden_states) * self.scaling + past_key_value = None if kv_cache is None: q = self._shape(q, tgt_len, bsz) if is_cross_attention: - k = past_key_value[0] - v = past_key_value[1] + if past_key_value is not None: + k = past_key_value[0] + v = past_key_value[1] + else: + k, _ = self.k_proj(encoder_hidden_states) + v, _ = self.v_proj(encoder_hidden_states) + k = self._shape(k, -1, bsz) + v = self._shape(v, -1, bsz) + + past_key_value = (k, v) else: k, _ = self.k_proj(key_value_states) v, _ = self.v_proj(hidden_states) @@ -147,7 +157,7 @@ def forward( attn_output = self.attn(q, k, v, kv_cache, attn_metadata) output, _ = self.o_proj(attn_output) - return output + return output, past_key_value class WhisperEncoderLayer(nn.Module): def __init__( @@ -190,7 +200,7 @@ def forward( residual = hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) - hidden_states = self.self_attn( + hidden_states, _ = self.self_attn( hidden_states=hidden_states, ) hidden_states = residual + hidden_states @@ -211,7 +221,7 @@ def __init__( ): super().__init__() self.embed_dim = config.d_model - self.self_attn = WhisperAttention( + self.self_attn, _ = WhisperAttention( embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, is_decoder=True, @@ -249,10 +259,11 @@ def __init__( def forward( self, hidden_states: torch.Tensor, - past_key_value: torch.Tensor, + encoder_hidden_states: torch.Tensor, + past_key_value: Tuple[torch.Tensor, torch.Tensor], kv_cache: torch.Tensor, attn_metadata: AttentionMetadata, - ) -> torch.Tensor: + ): residual = hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) @@ -264,8 +275,9 @@ def forward( ) hidden_states = residual + hidden_states - hidden_states = self.self_attn( + hidden_states, cross_attention_past_key_value = self.self_attn( hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, past_key_value=past_key_value, attn_metadata=attn_metadata ) @@ -277,7 +289,7 @@ def forward( hidden_states = self.fc2(hidden_states) hidden_states = residual + hidden_states - return outputs + return outputs, cross_attention_past_key_value class WhisperEncoder(nn.Module): def __init__( @@ -333,7 +345,7 @@ def __init__( self.max_source_positions = config.max_source_positions self.embed_scale = 
math.sqrt(config.d_model) if config.scale_embedding else 1.0 - self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.d_model) + self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx) self.embed_positions = WhisperPositionalEmbedding(self.max_target_positions, config.d_model) self.layers = nn.ModuleList([WhisperDecoderLayer(config, quant_config=quant_config, cache_config=cache_config) for layer_idx in range(config.decoder_layers)]) @@ -343,26 +355,97 @@ def forward( self, input_ids, positions: torch.Tensor, - past_key_values, + encoder_hidden_states: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, - inputs_embeds: Optional[torch.Tensor] = None, + past_key_values = None, ): inputs_embeds = self.embed_tokens(input_ids) positions = self.embed_positions(input_ids, positions) hidden_states = inputs_embeds + positions + + cross_attention_past_key_values = [] + for idx, decoder_layer in enumerate(self.layers): - hidden_states = decoder_layer( + hidden_states, cross_attention_past_key_value = decoder_layer( hidden_states, + encoder_hidden_states=encoder_hidden_states, + past_key_value=None if past_key_values is None else past_key_values[idx], kv_cache=kv_caches[idx], output_attentions=output_attentions, attn_metadata=attn_metadata ) + cross_attention_past_key_values.append(cross_attention_past_key_value) hidden_states = self.layer_norm(hidden_states) - return hidden_states + return hidden_states, cross_attention_past_key_values +class WhisperModel(nn.Module): + def __init__( + self, + config: WhisperConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__(config) + + self.encoder = WhisperEncoder(config, cache_config=cache_config, quant_config=quant_config) + self.decoder = WhisperDecoder(config, cache_config=cache_config, quant_config=quant_config) + + def forward( + self, + input_features: torch.FloatTensor, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + past_key_values = None, + ): + encoder_outputs = self.encoder( + input_features, + ) + decoder_outputs, cross_attention_past_key_values = self.decoder( + input_ids=input_ids, + positions=positions, + encoder_hidden_states=encoder_outputs, + kv_caches=kv_caches, + attn_metadata=attn_metadata, + past_key_values=past_key_values + ) + return decoder_outputs, cross_attention_past_key_values +class WhisperForConditionalGeneration(nn.Module): + def __init__( + self, + config: WhisperConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__(config) + self.model = WhisperModel(config, cache_config=cache_config, quant_config=quant_config) + self.proj_out = RowParallelLinear( + input_size = config.d_model, + output_size = config.vocab_size, + bias = False, + quant_config=quant_config, + ) + def forward( + self, + input_features: torch.FloatTensor, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + past_key_values = None, + ): + outputs = self.model( + input_features=input_features, + input_ids=input_ids, + positions=positions, + kv_caches=kv_caches, + attn_metadata=attn_metadata, + past_key_values=past_key_values, + ) \ No newline at end of file diff --git a/vllm/transformers_utils/whisper_processor.py b/vllm/transformers_utils/whisper_processor.py new file mode 100644 
index 0000000000000..9bec4213d0240 --- /dev/null +++ b/vllm/transformers_utils/whisper_processor.py @@ -0,0 +1,44 @@ +from functools import lru_cache +from typing import Optional + +from transformers import WhisperProcessor + +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +def get_whisper_processor( + processor_name: str, + *args, + trust_remote_code: bool = False, + revision: Optional[str] = None, + **kwargs, +) -> BaseImageProcessor: + """Gets an image processor for the given model name via HuggingFace.""" + try: + processor: WhisperProcessor = WhisperProcessor.from_pretrained( + processor_name, + *args, + trust_remote_code=trust_remote_code, + revision=revision, + **kwargs) + except ValueError as e: + # If the error pertains to the processor class not existing or not + # currently being imported, suggest using the --trust-remote-code flag. + # Unlike AutoTokenizer, AutoImageProcessor does not separate such errors + if not trust_remote_code: + err_msg = ( + "Failed to load the whisper processor. If the whisper processor is " + "a custom processor not yet available in the HuggingFace " + "transformers library, consider setting " + "`trust_remote_code=True` in LLM or using the " + "`--trust-remote-code` flag in the CLI.") + raise RuntimeError(err_msg) from e + else: + raise e + + return processor + + +get_whisper_processor = lru_cache(get_whisper_processor) From 2c82ba7e3b9084fb00802c73b2c59a02ff049e9e Mon Sep 17 00:00:00 2001 From: huseinzol05 Date: Fri, 28 Jun 2024 15:32:21 +0800 Subject: [PATCH 04/16] able to load --- vllm/config.py | 11 ++- vllm/engine/arg_utils.py | 16 +++- vllm/engine/llm_engine.py | 7 +- vllm/entrypoints/llm.py | 3 +- vllm/executor/executor_base.py | 4 +- vllm/model_executor/models/__init__.py | 2 + vllm/model_executor/models/whisper.py | 61 ++++++++++--- vllm/multimodal/audio.py | 96 ++++++++++++++++++++ vllm/multimodal/registry.py | 21 ++++- vllm/transformers_utils/whisper_processor.py | 6 +- 10 files changed, 200 insertions(+), 27 deletions(-) create mode 100644 vllm/multimodal/audio.py diff --git a/vllm/config.py b/vllm/config.py index 7abca2030478c..7347a1f2f1522 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1220,8 +1220,10 @@ def as_cli_args_dict(self) -> Dict[str, Any]: @dataclass class WhisperConfig: - whisper_processor: Optional[str] - whisper_processor_revision: Optional[str] + whisper_input_type: Optional[str] = 'input_features' + whisper_processor: Optional[str] = 'openai/whisper-large-v3' + whisper_processor_revision: Optional[str] = 'openai/whisper-large-v3' + sample_rate: Optional[int] = 16000 def as_cli_args_dict(self) -> Dict[str, Any]: """Flatten vision language config to pure args. @@ -1238,8 +1240,6 @@ def as_cli_args_dict(self) -> Dict[str, Any]: else: result[f.name] = value - result["disable_image_processor"] = self.image_processor is None - return result @@ -1323,6 +1323,8 @@ def _get_and_verify_max_len( "max_sequence_length", "max_seq_length", "seq_len", + # Whisper + "max_length", ] # Choose the smallest "max_length" from the possible keys. 
max_len_key = None @@ -1459,6 +1461,7 @@ class EngineConfig: load_config: LoadConfig lora_config: Optional[LoRAConfig] vision_language_config: Optional[VisionLanguageConfig] + whisper_config: Optional[WhisperConfig] speculative_config: Optional[SpeculativeConfig] decoding_config: Optional[DecodingConfig] observability_config: Optional[ObservabilityConfig] diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index e6cb811f7e051..a0328a88016ac 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -89,9 +89,10 @@ class EngineArgs: disable_image_processor: bool = False # Related to Whisper - whisper_input_type = 'input_features' + whisper_input_type: Optional[str] = 'input_features' whisper_processor: Optional[str] = None whisper_processor_revision: Optional[str] = None + sample_rate: Optional[int] = 16000 scheduler_delay_factor: float = 0.0 enable_chunked_prefill: bool = False @@ -802,7 +803,17 @@ def create_engine_config(self, ) -> EngineConfig: else: vision_language_config = None - if self.whisper_input_type + if self.whisper_input_type: + if self.whisper_processor is None: + self.whisper_processor = self.model + whisper_config = WhisperConfig( + self.whisper_input_type, + self.whisper_processor, + self.whisper_processor_revision, + self.sample_rate, + ) + else: + whisper_config = None decoding_config = DecodingConfig( guided_decoding_backend=self.guided_decoding_backend) @@ -825,6 +836,7 @@ def create_engine_config(self, ) -> EngineConfig: device_config=device_config, lora_config=lora_config, vision_language_config=vision_language_config, + whisper_config=whisper_config, speculative_config=speculative_config, load_config=load_config, decoding_config=decoding_config, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index f7eae257fdd16..8c96afd62d695 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -9,7 +9,7 @@ from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ObservabilityConfig, ParallelConfig, SchedulerConfig, SpeculativeConfig, - VisionLanguageConfig) + VisionLanguageConfig, WhisperConfig) from vllm.core.scheduler import (ScheduledSequenceGroup, Scheduler, SchedulerOutputs) from vllm.engine.arg_utils import EngineArgs @@ -154,6 +154,7 @@ def __init__( load_config: LoadConfig, lora_config: Optional[LoRAConfig], vision_language_config: Optional[VisionLanguageConfig], + whisper_config: Optional[WhisperConfig], speculative_config: Optional[SpeculativeConfig], decoding_config: Optional[DecodingConfig], observability_config: Optional[ObservabilityConfig], @@ -206,6 +207,7 @@ def __init__( self.cache_config = cache_config self.lora_config = lora_config self.vision_language_config = vision_language_config + self.whisper_config = whisper_config self.parallel_config = parallel_config self.scheduler_config = scheduler_config self.device_config = device_config @@ -227,6 +229,8 @@ def __init__( self.generation_config_fields = _load_generation_config_dict( model_config) + print(executor_class) + self.model_executor = executor_class( model_config=model_config, cache_config=cache_config, @@ -235,6 +239,7 @@ def __init__( device_config=device_config, lora_config=lora_config, vision_language_config=vision_language_config, + whisper_config=whisper_config, speculative_config=speculative_config, load_config=load_config, ) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 9e923493160ed..aca0ecac4a81b 100644 --- a/vllm/entrypoints/llm.py +++ 
b/vllm/entrypoints/llm.py @@ -141,6 +141,7 @@ def __init__( disable_custom_all_reduce=disable_custom_all_reduce, **kwargs, ) + print(engine_args) self.llm_engine = LLMEngine.from_engine_args( engine_args, usage_context=UsageContext.LLM_CLASS) self.request_counter = Counter() @@ -575,4 +576,4 @@ def _run_engine( # Sort the outputs by request ID. # This is necessary because some requests may be finished earlier than # its previous requests. - return sorted(outputs, key=lambda x: int(x.request_id)) + return sorted(outputs, key=lambda x: int(x.request_id)) \ No newline at end of file diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index 7c2520b5a64f5..b0ce00448171b 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -3,7 +3,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, - SpeculativeConfig, VisionLanguageConfig) + SpeculativeConfig, VisionLanguageConfig, WhisperConfig) from vllm.lora.request import LoRARequest from vllm.sequence import ExecuteModelRequest, SamplerOutput @@ -26,6 +26,7 @@ def __init__( load_config: LoadConfig, lora_config: Optional[LoRAConfig], vision_language_config: Optional[VisionLanguageConfig], + whisper_config: Optional[WhisperConfig], speculative_config: Optional[SpeculativeConfig], ) -> None: self.model_config = model_config @@ -36,6 +37,7 @@ def __init__( self.scheduler_config = scheduler_config self.device_config = device_config self.vision_language_config = vision_language_config + self.whisper_config = whisper_config self.speculative_config = speculative_config self._init_executor() diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 5afb2e1d44d39..55535ddd94291 100755 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -61,6 +61,8 @@ "XverseForCausalLM": ("xverse", "XverseForCausalLM"), "Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"), "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"), + "WhisperForConditionalGeneration": + ("whisper", "WhisperForConditionalGeneration"), } _EMBEDDING_MODELS = { diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index 478ce5b32612c..60c0571c94221 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Iterable, List, Optional, Tuple +from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union import torch from torch import nn @@ -22,6 +22,8 @@ from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, kv_cache_scales_loader) from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.audio import get_dummy_audio_data from vllm.sequence import SamplerOutput from vllm.utils import is_hip, print_warning_once @@ -221,7 +223,7 @@ def __init__( ): super().__init__() self.embed_dim = config.d_model - self.self_attn, _ = WhisperAttention( + self.self_attn = WhisperAttention( embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, is_decoder=True, @@ -289,7 +291,7 @@ def forward( hidden_states = self.fc2(hidden_states) hidden_states = residual + hidden_states - return outputs, cross_attention_past_key_value + return outputs, cross_attention_past_key_value class WhisperEncoder(nn.Module): def __init__( @@ -311,7 +313,7 @@ def 
__init__( self.embed_positions = nn.Embedding(self.max_source_positions, embed_dim) self.layers = nn.ModuleList([WhisperEncoderLayer(config, quant_config=quant_config, cache_config=cache_config) - for layer_idx in range(config.decoder_layers)]) + for layer_idx in range(config.encoder_layers)]) self.layer_norm = nn.LayerNorm(config.d_model) @@ -347,6 +349,7 @@ def __init__( self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx) self.embed_positions = WhisperPositionalEmbedding(self.max_target_positions, config.d_model) + self.layers = nn.ModuleList([WhisperDecoderLayer(config, quant_config=quant_config, cache_config=cache_config) for layer_idx in range(config.decoder_layers)]) self.layer_norm = nn.LayerNorm(config.d_model) @@ -387,7 +390,7 @@ def __init__( cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): - super().__init__(config) + super().__init__() self.encoder = WhisperEncoder(config, cache_config=cache_config, quant_config=quant_config) self.decoder = WhisperDecoder(config, cache_config=cache_config, quant_config=quant_config) @@ -415,7 +418,13 @@ def forward( ) return decoder_outputs, cross_attention_past_key_values +class AudioInputs(TypedDict): + type: Literal["input_features"] + data: torch.Tensor + """Shape: (batch_size, num_channels, 3000)""" +@MULTIMODAL_REGISTRY.register_audio_input() +@MULTIMODAL_REGISTRY.register_dummy_data(get_dummy_audio_data) class WhisperForConditionalGeneration(nn.Module): def __init__( self, @@ -423,7 +432,7 @@ def __init__( cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): - super().__init__(config) + super().__init__() self.model = WhisperModel(config, cache_config=cache_config, quant_config=quant_config) self.proj_out = RowParallelLinear( input_size = config.d_model, @@ -432,20 +441,50 @@ def __init__( quant_config=quant_config, ) + def _parse_and_validate_audio_input( + self, **kwargs: object) -> Optional[AudioInputs]: + input_features = kwargs.pop("input_features", None) + + return AudioInputs( + type="pixel_values", + data=input_features + ) + def forward( self, - input_features: torch.FloatTensor, - input_ids: Optional[torch.Tensor], + input_ids: torch.Tensor, positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, past_key_values = None, - ): - outputs = self.model( + **kwargs: object, + ) -> SamplerOutput: + + input_features = self._parse_and_validate_audio_input(**kwargs) + + decoder_outputs, cross_attention_past_key_values = self.model( input_features=input_features, input_ids=input_ids, positions=positions, kv_caches=kv_caches, attn_metadata=attn_metadata, past_key_values=past_key_values, - ) \ No newline at end of file + ) + return decoder_outputs + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.proj_out.weight, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + pass \ No newline at end of file diff --git a/vllm/multimodal/audio.py b/vllm/multimodal/audio.py new file mode 100644 index 0000000000000..3bed5ae6b3a50 --- /dev/null +++ b/vllm/multimodal/audio.py @@ -0,0 +1,96 @@ +from typing import Dict, Tuple, Type, 
Union + +import torch +import numpy as np +from vllm.config import ModelConfig, WhisperConfig +from vllm.logger import init_logger +from vllm.sequence import SequenceData +from vllm.transformers_utils.whisper_processor import cached_get_whisper_processor + +from .base import MultiModalData, MultiModalPlugin + +logger = init_logger(__name__) + + +def _get_dummy_seq_data(seq_len: int, + whisper_config: WhisperConfig) -> SequenceData: + token_ids = [0, 0, 0] + + return SequenceData(token_ids) + + +def _get_dummy_values(whisper_config: WhisperConfig) -> torch.Tensor: + values_dtype = torch.float16 + + return torch.zeros((30 * whisper_config), dtype=values_dtype) + + +def get_dummy_audio_data( + seq_len: int, + model_config: ModelConfig, + whisper_config: WhisperConfig, +) -> Tuple[SequenceData, MultiModalData]: + """Standard dummy data factory for image data (to be used in + :meth:`vlm.multimodal.MultiModalRegistry.register_dummy_data`).""" + seq_data = _get_dummy_seq_data(seq_len, whisper_config) + values = _get_dummy_values(whisper_config) + + + fake_mm_data: MultiModalData + if config_input_type == ImageInputType.PIXEL_VALUES: + fake_mm_data = ImagePixelData(values) + elif config_input_type == ImageInputType.IMAGE_FEATURES: + fake_mm_data = ImageFeatureData(values) + else: + raise NotImplementedError + + return seq_data, fake_mm_data + + +class AudioData(MultiModalData): + """ + The pixel data of an image. Can be one of: + + - :class:``PIL.Image``: An image object. Requires that a HuggingFace + processor is available to the model. + - :class:``torch.Tensor``: The raw pixel data which is passed to the model + without additional pre-processing. + """ + + def __init__(self, audio: Union[np.array, torch.Tensor]) -> None: + + self.audio = audio + + def __repr__(self) -> str: + return str(self.audio) + + +class AudioPlugin(MultiModalPlugin[AudioData]): + + def get_data_type(self) -> Type[AudioData]: + return AudioData + + def _get_hf_whisper_processor(self, model_config: ModelConfig, + whisper_config: WhisperConfig): + if whisper_config is None or whisper_config.whisper_processor is None: + return None + + return cached_get_whisper_processor( + whisper_config.whisper_processor, + revision=whisper_config.whisper_processor_revision, + ) + + def _default_input_processor( + self, data: AudioData, model_config: ModelConfig, + whisper_config: WhisperConfig) -> Dict[str, torch.Tensor]: + audio = data.audio + + processor = self._get_hf_whisper_processor(model_config, whisper_config) + if processor is None: + raise RuntimeError("No HuggingFace processor is available" + "to process the audio object") + try: + return processor(image, return_tensors="pt").to(model_config.dtype) + except Exception: + logger.error("Failed to process audio (%s)", image) + raise diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py index 4789ce5ce4cfe..92cd484cea366 100644 --- a/vllm/multimodal/registry.py +++ b/vllm/multimodal/registry.py @@ -2,12 +2,13 @@ from typing import (TYPE_CHECKING, Any, Callable, Dict, Optional, Sequence, Tuple, Type, TypeVar) -from vllm.config import ModelConfig, VisionLanguageConfig +from vllm.config import ModelConfig, VisionLanguageConfig, WhisperConfig from vllm.logger import init_logger from .base import MultiModalData, MultiModalPlugin from .image import (ImageFeatureData, ImageFeaturePlugin, ImagePixelData, ImagePixelPlugin) +from .audio import AudioData, AudioPlugin if TYPE_CHECKING: import torch @@ -20,9 +21,9 @@ D = TypeVar("D", bound=MultiModalData) N = TypeVar("N", 
bound=Type["nn.Module"]) -MultiModalInputProcessor = Callable[[D, ModelConfig, VisionLanguageConfig], +MultiModalInputProcessor = Callable[[D, ModelConfig, VisionLanguageConfig, WhisperConfig], Dict[str, "torch.Tensor"]] -MultiModalDummyFactory = Callable[[int, ModelConfig, VisionLanguageConfig], +MultiModalDummyFactory = Callable[[int, ModelConfig, VisionLanguageConfig, WhisperConfig], Tuple["SequenceData", MultiModalData]] @@ -32,7 +33,7 @@ class MultiModalRegistry: according to its modality and the target model. """ - DEFAULT_PLUGINS = (ImageFeaturePlugin(), ImagePixelPlugin()) + DEFAULT_PLUGINS = (ImageFeaturePlugin(), ImagePixelPlugin(), AudioPlugin()) def __init__(self, *, @@ -119,6 +120,17 @@ def register_image_pixel_input( """ return self.register_input(ImagePixelData, processor) + def register_audio_input( + self, + processor: Optional[ + MultiModalInputProcessor[AudioData]] = None): + """ + Register an input processor for image pixel data to a model class. + + See :meth:`MultiModalPlugin.register_input_processor` for more details. + """ + return self.register_input(AudioData, processor) + def register_image_feature_input( self, processor: Optional[ @@ -129,6 +141,7 @@ def register_image_feature_input( See :meth:`MultiModalPlugin.register_input_processor` for more details. """ return self.register_input(ImageFeatureData, processor) + def process_input(self, data: MultiModalData, model_config: ModelConfig, vlm_config: VisionLanguageConfig): diff --git a/vllm/transformers_utils/whisper_processor.py b/vllm/transformers_utils/whisper_processor.py index 9bec4213d0240..3b0f20ca101dd 100644 --- a/vllm/transformers_utils/whisper_processor.py +++ b/vllm/transformers_utils/whisper_processor.py @@ -14,8 +14,8 @@ def get_whisper_processor( trust_remote_code: bool = False, revision: Optional[str] = None, **kwargs, -) -> BaseImageProcessor: - """Gets an image processor for the given model name via HuggingFace.""" +) -> WhisperProcessor: + """Gets an whisper processor for the given model name via HuggingFace.""" try: processor: WhisperProcessor = WhisperProcessor.from_pretrained( processor_name, @@ -41,4 +41,4 @@ def get_whisper_processor( return processor -get_whisper_processor = lru_cache(get_whisper_processor) +cached_get_whisper_processor = lru_cache(get_whisper_processor) From d46c5c823dcd22a3bfc4339e9bc8131fbf6c0d01 Mon Sep 17 00:00:00 2001 From: huseinzol05 Date: Fri, 28 Jun 2024 16:50:58 +0800 Subject: [PATCH 05/16] able to forward --- vllm/executor/gpu_executor.py | 1 + vllm/model_executor/models/whisper.py | 51 +++++++++----------- vllm/multimodal/audio.py | 24 ++------- vllm/transformers_utils/whisper_processor.py | 1 + vllm/worker/model_runner.py | 26 +++++++--- vllm/worker/worker.py | 4 +- 6 files changed, 53 insertions(+), 54 deletions(-) diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index 0a654200ed796..3f13f06e36cc7 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -44,6 +44,7 @@ def _get_worker_kwargs( distributed_init_method=distributed_init_method, lora_config=self.lora_config, vision_language_config=self.vision_language_config, + whisper_config=self.whisper_config, speculative_config=self.speculative_config, is_driver_worker=rank == 0, ) diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index 60c0571c94221..16966ddb607c7 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -80,7 +80,7 @@ def __init__( bias = False, 
quant_config=quant_config, ) - self.k_proj = RowParallelLinear( + self.v_proj = RowParallelLinear( input_size = embed_dim, output_size = embed_dim, bias = bias, @@ -107,8 +107,8 @@ def __init__( quant_config=quant_config ) - def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim) + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() def forward( self, @@ -119,8 +119,14 @@ def forward( attn_metadata: AttentionMetadata = None, ): is_cross_attention = encoder_hidden_states is not None - bsz, tgt_len, _ = hidden_states.size() - q, _ = self.q_proj(hidden_states) * self.scaling + sizes = hidden_states.size() + if len(sizes) == 3: + bsz, tgt_len, _ = sizes + else: + tgt_len, _ = sizes + print(sizes) + q, _ = self.q_proj(hidden_states) + q = q * self.scaling past_key_value = None @@ -139,7 +145,7 @@ def forward( past_key_value = (k, v) else: - k, _ = self.k_proj(key_value_states) + k, _ = self.k_proj(hidden_states) v, _ = self.v_proj(hidden_states) k = self._shape(k, -1, bsz) v = self._shape(v, -1, bsz) @@ -151,14 +157,14 @@ def forward( dropout_p=0.0, is_causal=self.is_causal and tgt_len > 1, ) - attn_output = attn_output.reshape(bsz, q_len, -1) - output = self.out_proj(attn_output) + attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + output, _ = self.out_proj(attn_output) else: k, _ = self.k_proj(hidden_states) v, _ = self.v_proj(hidden_states) attn_output = self.attn(q, k, v, kv_cache, attn_metadata) - output, _ = self.o_proj(attn_output) + output, _ = self.out_proj(attn_output) return output, past_key_value class WhisperEncoderLayer(nn.Module): @@ -196,10 +202,7 @@ def __init__( def forward( self, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ): - residual = hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) hidden_states, _ = self.self_attn( @@ -208,8 +211,9 @@ def forward( hidden_states = residual + hidden_states residual = hidden_states hidden_states = self.final_layer_norm(hidden_states) - hidden_states = self.activation_fn(self.fc1(hidden_states)) - hidden_states = self.fc2(hidden_states) + hidden_states, _ = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states, _ = self.fc2(hidden_states) hidden_states = residual + hidden_states return hidden_states @@ -287,8 +291,9 @@ def forward( residual = hidden_states hidden_states = self.final_layer_norm(hidden_states) - hidden_states = self.activation_fn(self.fc1(hidden_states)) - hidden_states = self.fc2(hidden_states) + hidden_states, _ = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states, _ = self.fc2(hidden_states) hidden_states = residual + hidden_states return outputs, cross_attention_past_key_value @@ -375,7 +380,6 @@ def forward( encoder_hidden_states=encoder_hidden_states, past_key_value=None if past_key_values is None else past_key_values[idx], kv_cache=kv_caches[idx], - output_attentions=output_attentions, attn_metadata=attn_metadata ) cross_attention_past_key_values.append(cross_attention_past_key_value) @@ -417,11 +421,6 @@ def forward( past_key_values=past_key_values ) return decoder_outputs, cross_attention_past_key_values - -class AudioInputs(TypedDict): - type: Literal["input_features"] - data: torch.Tensor - """Shape: (batch_size, num_channels, 3000)""" 
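For illustration only (this block is not part of the patch): a minimal, self-contained PyTorch sketch of the cross-attention caching pattern that the diff above threads through `past_key_value` / `cross_attention_past_key_values`. The keys and values projected from the encoder output are computed once, on the first decoding step, and reused on every later step, since the encoder states do not change during generation. The class and helper names below are hypothetical and this is not vLLM's Attention API; it only restates the caching idea in plain PyTorch.

import torch
from torch import nn


class CachedCrossAttention(nn.Module):
    """Cross-attention that reuses encoder-derived K/V across decode steps."""

    def __init__(self, embed_dim: int, num_heads: int):
        super().__init__()
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

    def _shape(self, x: torch.Tensor) -> torch.Tensor:
        # (batch, seq, embed) -> (batch, heads, seq, head_dim)
        bsz, seq_len, _ = x.shape
        return x.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, hidden_states, encoder_hidden_states=None, past_key_value=None):
        q = self._shape(self.q_proj(hidden_states))
        if past_key_value is None:
            # First decode step: project the (fixed) encoder output once.
            k = self._shape(self.k_proj(encoder_hidden_states))
            v = self._shape(self.v_proj(encoder_hidden_states))
            past_key_value = (k, v)
        else:
            # Later steps: reuse the cached projections.
            k, v = past_key_value
        attn = torch.nn.functional.scaled_dot_product_attention(q, k, v)
        bsz, _, tgt_len, _ = attn.shape
        attn = attn.transpose(1, 2).reshape(bsz, tgt_len, -1)
        return self.out_proj(attn), past_key_value


# Example usage with whisper-large-v3-like sizes (d_model=1280, 20 heads):
# layer = CachedCrossAttention(1280, 20)
# out, cache = layer(dec_x, encoder_hidden_states=enc_out)   # first step
# out, cache = layer(next_dec_x, past_key_value=cache)       # later steps

Caching matters here because Whisper's encoder output has a fixed length (1500 frames for a 30 s window), so re-projecting it on every decoding step would repeat the same work for the whole generation.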
@MULTIMODAL_REGISTRY.register_audio_input() @MULTIMODAL_REGISTRY.register_dummy_data(get_dummy_audio_data) @@ -442,13 +441,10 @@ def __init__( ) def _parse_and_validate_audio_input( - self, **kwargs: object) -> Optional[AudioInputs]: + self, **kwargs: object) -> torch.Tensor: input_features = kwargs.pop("input_features", None) - return AudioInputs( - type="pixel_values", - data=input_features - ) + return input_features def forward( self, @@ -461,6 +457,7 @@ def forward( ) -> SamplerOutput: input_features = self._parse_and_validate_audio_input(**kwargs) + print(input_features, input_ids.shape, kv_caches) decoder_outputs, cross_attention_past_key_values = self.model( input_features=input_features, diff --git a/vllm/multimodal/audio.py b/vllm/multimodal/audio.py index 3bed5ae6b3a50..36ba056ae8cd1 100644 --- a/vllm/multimodal/audio.py +++ b/vllm/multimodal/audio.py @@ -22,7 +22,7 @@ def _get_dummy_seq_data(seq_len: int, def _get_dummy_values(whisper_config: WhisperConfig) -> torch.Tensor: values_dtype = torch.float16 - return torch.zeros((30 * whisper_config), dtype=values_dtype) + return torch.zeros((30 * whisper_config.sample_rate), dtype=values_dtype) def get_dummy_audio_data( @@ -35,27 +35,11 @@ def get_dummy_audio_data( seq_data = _get_dummy_seq_data(seq_len, whisper_config) values = _get_dummy_values(whisper_config) - - fake_mm_data: MultiModalData - if config_input_type == ImageInputType.PIXEL_VALUES: - fake_mm_data = ImagePixelData(values) - elif config_input_type == ImageInputType.IMAGE_FEATURES: - fake_mm_data = ImageFeatureData(values) - else: - raise NotImplementedError - + fake_mm_data = AudioData(values) return seq_data, fake_mm_data class AudioData(MultiModalData): - """ - The pixel data of an image. Can be one of: - - - :class:``PIL.Image``: An image object. Requires that a HuggingFace - processor is available to the model. - - :class:``torch.Tensor``: The raw pixel data which is passed to the model - without additional pre-processing. 
- """ def __init__(self, audio: Union[np.array, torch.Tensor]) -> None: @@ -90,7 +74,7 @@ def _default_input_processor( raise RuntimeError("No HuggingFace processor is available" "to process the audio object") try: - return processor(image, return_tensors="pt").to(model_config.dtype) + return processor(audio, return_tensors="pt").to(model_config.dtype) except Exception: - logger.error("Failed to process audio (%s)", image) + logger.error("Failed to process audio (%s)", audio) raise diff --git a/vllm/transformers_utils/whisper_processor.py b/vllm/transformers_utils/whisper_processor.py index 3b0f20ca101dd..ea15035da48e9 100644 --- a/vllm/transformers_utils/whisper_processor.py +++ b/vllm/transformers_utils/whisper_processor.py @@ -17,6 +17,7 @@ def get_whisper_processor( ) -> WhisperProcessor: """Gets an whisper processor for the given model name via HuggingFace.""" try: + print('processor_name', processor_name) processor: WhisperProcessor = WhisperProcessor.from_pretrained( processor_name, *args, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index a321eafce1a2f..6085df494d1db 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -11,7 +11,7 @@ from vllm.attention import AttentionMetadata, get_attn_backend from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, - VisionLanguageConfig) + VisionLanguageConfig, WhisperConfig) from vllm.distributed import broadcast_tensor_dict from vllm.distributed.parallel_state import graph_capture from vllm.logger import init_logger @@ -86,6 +86,7 @@ def __init__( kv_cache_dtype: Optional[str] = "auto", is_driver_worker: bool = False, vision_language_config: Optional[VisionLanguageConfig] = None, + whisper_config: Optional[WhisperConfig] = None, return_hidden_states: bool = False, ): self.model_config = model_config @@ -97,6 +98,7 @@ def __init__( self.load_config = load_config self.is_driver_worker = is_driver_worker self.vision_language_config = vision_language_config + self.whisper_config = whisper_config self.return_hidden_states = return_hidden_states self.device = self.device_config.device @@ -137,6 +139,12 @@ def __init__( self.model_config, self.vision_language_config, ) + elif self.whisper_config is not None: + self.multi_modal_input_processor = MULTIMODAL_REGISTRY \ + .create_input_processor( + self.model_config, + self.whisper_config, + ) else: self.multi_modal_input_processor = None @@ -822,6 +830,7 @@ def profile_run(self) -> None: # of images processed. 
model_config = self.model_config vlm_config = self.vision_language_config + whisper_config = self.whisper_config if vlm_config: max_num_seqs = min( @@ -830,13 +839,18 @@ def profile_run(self) -> None: for group_id in range(max_num_seqs): seq_len = (max_num_batched_tokens // max_num_seqs + (group_id < max_num_batched_tokens % max_num_seqs)) - - if vlm_config is None: - seq_data = SequenceData([0] * seq_len) - dummy_multi_modal_data = None - else: + + if vlm_config is not None: seq_data, dummy_multi_modal_data = MULTIMODAL_REGISTRY \ .dummy_data_for_profiling(seq_len, model_config, vlm_config) + elif whisper_config is not None: + seq_data, dummy_multi_modal_data = MULTIMODAL_REGISTRY \ + .dummy_data_for_profiling(seq_len, model_config, whisper_config) + else: + seq_data = SequenceData([0] * seq_len) + dummy_multi_modal_data = None + + print(group_id, seq_len, dummy_multi_modal_data) seq = SequenceGroupMetadata( request_id=str(group_id), diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index c60764ef1bed8..840201cf4942b 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -8,7 +8,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, - SpeculativeConfig, VisionLanguageConfig) + SpeculativeConfig, VisionLanguageConfig, WhisperConfig) from vllm.distributed import (broadcast_tensor_dict, ensure_model_parallel_initialized, init_distributed_environment, @@ -44,6 +44,7 @@ def __init__( distributed_init_method: str, lora_config: Optional[LoRAConfig] = None, vision_language_config: Optional[VisionLanguageConfig] = None, + whisper_config: Optional[WhisperConfig] = None, speculative_config: Optional[SpeculativeConfig] = None, is_driver_worker: bool = False, ) -> None: @@ -91,6 +92,7 @@ def __init__( kv_cache_dtype=self.cache_config.cache_dtype, is_driver_worker=is_driver_worker, vision_language_config=vision_language_config, + whisper_config=whisper_config, **speculative_args, ) # Uninitialized cache engine. 
Will be initialized by From 60517822fed8cd250c1590dd490eb9f80fbcc1bc Mon Sep 17 00:00:00 2001 From: huseinzol05 Date: Fri, 28 Jun 2024 22:42:30 +0800 Subject: [PATCH 06/16] added example --- examples/whisper_example.py | 31 +++++++ vllm/engine/llm_engine.py | 2 - vllm/entrypoints/llm.py | 1 - vllm/model_executor/models/whisper.py | 95 ++++++++++++-------- vllm/transformers_utils/whisper_processor.py | 1 - vllm/worker/model_runner.py | 7 +- 6 files changed, 96 insertions(+), 41 deletions(-) create mode 100644 examples/whisper_example.py diff --git a/examples/whisper_example.py b/examples/whisper_example.py new file mode 100644 index 0000000000000..479bf2c642593 --- /dev/null +++ b/examples/whisper_example.py @@ -0,0 +1,31 @@ +import vllm +import torch +import requests +from vllm import LLM +from vllm.multimodal.audio import AudioData +from datasets import Audio + + +def main(): + sr = 16000 + audio = Audio(sampling_rate=sr) + llm = LLM( + model="openai/whisper-large-v3", + max_num_seqs = 1, + max_seq_len_to_capture = 448, + max_model_len = 448, + gpu_memory_utilization = 0.4 + ) + + r = requests.get('https://github.com/mesolitica/malaya-speech/raw/master/speech/singlish/singlish0.wav') + y = audio.decode_example(audio.encode_example(r.content))['array'] + prompt = '<|startoftranscript|><|en|><|transcribe|>' + outputs = llm.generate({ + "prompt": prompt, + "multi_modal_data": AudioData(y), + }) + print(outputs[0].outputs[0].text) + + +if __name__ == "__main__": + main() diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 8c96afd62d695..ec67c91c5c04c 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -229,8 +229,6 @@ def __init__( self.generation_config_fields = _load_generation_config_dict( model_config) - print(executor_class) - self.model_executor = executor_class( model_config=model_config, cache_config=cache_config, diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index aca0ecac4a81b..70651cad73e59 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -141,7 +141,6 @@ def __init__( disable_custom_all_reduce=disable_custom_all_reduce, **kwargs, ) - print(engine_args) self.llm_engine = LLMEngine.from_engine_args( engine_args, usage_context=UsageContext.LLM_CLASS) self.request_counter = Counter() diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index 16966ddb607c7..699b261aeedfb 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -1,5 +1,6 @@ from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union +import math import torch from torch import nn from transformers import WhisperConfig @@ -26,6 +27,8 @@ from vllm.multimodal.audio import get_dummy_audio_data from vllm.sequence import SamplerOutput from vllm.utils import is_hip, print_warning_once +from xformers import ops as xops +from vllm.utils import is_hip def sinusoids(length: int, channels: int, max_timescale: float = 10000) -> torch.Tensor: """Returns sinusoids for positional embedding""" @@ -98,17 +101,20 @@ def __init__( bias = bias, quant_config=quant_config ) - self.attn = Attention( - self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config - ) + if self.is_causal: + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config + ) + else: + self.attn = None def 
_shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + return tensor.view(1, seq_len, self.num_heads, self.head_dim).contiguous() def forward( self, @@ -117,54 +123,56 @@ def forward( past_key_value = None, kv_cache: torch.Tensor = None, attn_metadata: AttentionMetadata = None, + is_cross_attention = False, ): - is_cross_attention = encoder_hidden_states is not None sizes = hidden_states.size() if len(sizes) == 3: bsz, tgt_len, _ = sizes else: tgt_len, _ = sizes - print(sizes) q, _ = self.q_proj(hidden_states) q = q * self.scaling past_key_value = None - - if kv_cache is None: - q = self._shape(q, tgt_len, bsz) - if is_cross_attention: + if is_cross_attention or not self.is_decoder: + if is_cross_attention and encoder_hidden_states is not None: if past_key_value is not None: k = past_key_value[0] v = past_key_value[1] else: k, _ = self.k_proj(encoder_hidden_states) v, _ = self.v_proj(encoder_hidden_states) - k = self._shape(k, -1, bsz) - v = self._shape(v, -1, bsz) past_key_value = (k, v) else: k, _ = self.k_proj(hidden_states) v, _ = self.v_proj(hidden_states) - k = self._shape(k, -1, bsz) - v = self._shape(v, -1, bsz) - attn_output = torch.nn.functional.scaled_dot_product_attention( + q = self._shape(q, -1, 1) + k = self._shape(k, -1, 1) + v = self._shape(k, -1, 1) + + attn_output = xops.memory_efficient_attention_forward( q, k, v, - dropout_p=0.0, - is_causal=self.is_causal and tgt_len > 1, + attn_bias=None, + p=0.0, + scale=None, + op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if + (is_hip()) else None, ) - attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) - output, _ = self.out_proj(attn_output) + + attn_output = attn_output.reshape(-1, self.embed_dim) + else: k, _ = self.k_proj(hidden_states) v, _ = self.v_proj(hidden_states) attn_output = self.attn(q, k, v, kv_cache, attn_metadata) - output, _ = self.out_proj(attn_output) + output, _ = self.out_proj(attn_output) + return output, past_key_value class WhisperEncoderLayer(nn.Module): @@ -274,7 +282,7 @@ def forward( residual = hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) - hidden_states = self.self_attn( + hidden_states, _ = self.self_attn( hidden_states=hidden_states, kv_cache=kv_cache, attn_metadata=attn_metadata @@ -285,7 +293,8 @@ def forward( hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, past_key_value=past_key_value, - attn_metadata=attn_metadata + attn_metadata=attn_metadata, + is_cross_attention=True, ) hidden_states = residual + hidden_states @@ -296,7 +305,7 @@ def forward( hidden_states, _ = self.fc2(hidden_states) hidden_states = residual + hidden_states - return outputs, cross_attention_past_key_value + return hidden_states, cross_attention_past_key_value class WhisperEncoder(nn.Module): def __init__( @@ -321,6 +330,10 @@ def __init__( for layer_idx in range(config.encoder_layers)]) self.layer_norm = nn.LayerNorm(config.d_model) + + with torch.no_grad(): + embed_positions = self.embed_positions.weight + embed_positions.copy_(sinusoids(*embed_positions.shape)) def forward( self, @@ -328,7 +341,8 @@ def forward( ): inputs_embeds = nn.functional.gelu(self.conv1(input_features)) inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds)) - inputs_embeds = inputs_embeds.permute(0, 2, 1) + inputs_embeds = inputs_embeds.permute(1, 0) + embed_pos = self.embed_positions.weight hidden_states = inputs_embeds + embed_pos @@ -408,10 +422,12 @@ def forward( 
attn_metadata: AttentionMetadata, past_key_values = None, ): - - encoder_outputs = self.encoder( - input_features, - ) + if input_features is not None: + encoder_outputs = self.encoder( + input_features[0], + ) + else: + encoder_outputs = None decoder_outputs, cross_attention_past_key_values = self.decoder( input_ids=input_ids, positions=positions, @@ -433,12 +449,17 @@ def __init__( ): super().__init__() self.model = WhisperModel(config, cache_config=cache_config, quant_config=quant_config) + self.unpadded_vocab_size = config.vocab_size self.proj_out = RowParallelLinear( input_size = config.d_model, output_size = config.vocab_size, bias = False, quant_config=quant_config, ) + logit_scale = getattr(config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size, logit_scale) + self.sampler = Sampler() def _parse_and_validate_audio_input( self, **kwargs: object) -> torch.Tensor: @@ -457,7 +478,6 @@ def forward( ) -> SamplerOutput: input_features = self._parse_and_validate_audio_input(**kwargs) - print(input_features, input_ids.shape, kv_caches) decoder_outputs, cross_attention_past_key_values = self.model( input_features=input_features, @@ -484,4 +504,9 @@ def sample( return next_tokens def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - pass \ No newline at end of file + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) \ No newline at end of file diff --git a/vllm/transformers_utils/whisper_processor.py b/vllm/transformers_utils/whisper_processor.py index ea15035da48e9..3b0f20ca101dd 100644 --- a/vllm/transformers_utils/whisper_processor.py +++ b/vllm/transformers_utils/whisper_processor.py @@ -17,7 +17,6 @@ def get_whisper_processor( ) -> WhisperProcessor: """Gets an whisper processor for the given model name via HuggingFace.""" try: - print('processor_name', processor_name) processor: WhisperProcessor = WhisperProcessor.from_pretrained( processor_name, *args, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 6085df494d1db..8bcd45814e290 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -850,8 +850,6 @@ def profile_run(self) -> None: seq_data = SequenceData([0] * seq_len) dummy_multi_modal_data = None - print(group_id, seq_len, dummy_multi_modal_data) - seq = SequenceGroupMetadata( request_id=str(group_id), is_prompt=True, @@ -936,6 +934,11 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: slot_mapping.fill_(_PAD_SLOT_ID) seq_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda() block_tables = torch.from_numpy(self.graph_block_tables).cuda() + _, dummy_multi_modal_data = MULTIMODAL_REGISTRY.dummy_data_for_profiling( + max_batch_size, + self.model_config, + self.whisper_config + ) # Prepare buffer for outputs. These will be reused for all batch sizes. # It will be filled after the first graph capture. 
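
Note on the attention split introduced in the patch above: the causal decoder self-attention keeps going through vLLM's paged Attention with a KV cache, while the encoder self-attention and the encoder-decoder cross-attention have no cache and are computed with xformers' memory-efficient attention. Below is a minimal standalone sketch of that non-cached call, assuming xformers is installed; the head count, head size, and sequence length are illustrative placeholders, not values taken from the patch.

    # Sketch only: how the non-cached encoder/cross-attention path calls xformers.
    # Shapes follow the _shape helper used in the patch: [batch, seq, num_heads, head_dim].
    import torch
    from xformers import ops as xops

    num_heads, head_dim, seq_len = 8, 64, 1500  # hypothetical Whisper-like sizes
    q = torch.randn(1, seq_len, num_heads, head_dim, dtype=torch.float16, device="cuda")
    k = torch.randn(1, seq_len, num_heads, head_dim, dtype=torch.float16, device="cuda")
    v = torch.randn(1, seq_len, num_heads, head_dim, dtype=torch.float16, device="cuda")

    # Non-causal attention with no KV cache and no attention bias.
    out = xops.memory_efficient_attention_forward(q, k, v, attn_bias=None, p=0.0, scale=None)
    out = out.reshape(-1, num_heads * head_dim)  # flatten back to [seq, embed_dim]
    print(out.shape)  # torch.Size([1500, 512])

With scale=None, xformers applies the default 1/sqrt(head_dim) scaling internally, so an explicit q * scaling step before the call is not required.
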
From 3a8258ff70611523360ef45a515a0cc5befee3c9 Mon Sep 17 00:00:00 2001 From: huseinzol05 Date: Sat, 29 Jun 2024 00:22:40 +0800 Subject: [PATCH 07/16] fix load weights --- vllm/model_executor/models/whisper.py | 24 ++++++++++++++++++------ vllm/multimodal/audio.py | 2 +- 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index 699b261aeedfb..fb119c04e5092 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -114,7 +114,7 @@ def __init__( self.attn = None def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(1, seq_len, self.num_heads, self.head_dim).contiguous() + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).contiguous() def forward( self, @@ -131,7 +131,6 @@ def forward( else: tgt_len, _ = sizes q, _ = self.q_proj(hidden_states) - q = q * self.scaling past_key_value = None @@ -151,7 +150,7 @@ def forward( q = self._shape(q, -1, 1) k = self._shape(k, -1, 1) - v = self._shape(k, -1, 1) + v = self._shape(v, -1, 1) attn_output = xops.memory_efficient_attention_forward( q, @@ -224,6 +223,12 @@ def forward( hidden_states, _ = self.fc2(hidden_states) hidden_states = residual + hidden_states + if hidden_states.dtype == torch.float16 and ( + torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any() + ): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + return hidden_states class WhisperDecoderLayer(nn.Module): @@ -332,8 +337,7 @@ def __init__( self.layer_norm = nn.LayerNorm(config.d_model) with torch.no_grad(): - embed_positions = self.embed_positions.weight - embed_positions.copy_(sinusoids(*embed_positions.shape)) + self.embed_positions.weight.copy_(sinusoids(*self.embed_positions.weight.shape)) def forward( self, @@ -509,4 +513,12 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, loaded_weight) \ No newline at end of file + weight_loader(param, loaded_weight) + + if name == 'model.decoder.embed_tokens.weight': + param = params_dict['proj_out.weight'] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + + param = params_dict[name] \ No newline at end of file diff --git a/vllm/multimodal/audio.py b/vllm/multimodal/audio.py index 36ba056ae8cd1..65dabac7ec3ef 100644 --- a/vllm/multimodal/audio.py +++ b/vllm/multimodal/audio.py @@ -74,7 +74,7 @@ def _default_input_processor( raise RuntimeError("No HuggingFace processor is available" "to process the audio object") try: - return processor(audio, return_tensors="pt").to(model_config.dtype) + return processor(audio, return_tensors="pt", sampling_rate = 16000).to(model_config.dtype) except Exception: logger.error("Failed to process audio (%s)", audio) raise From 8a837c075eac4f68d4feb4c1e449b6b9a5b8b669 Mon Sep 17 00:00:00 2001 From: huseinzol05 Date: Sat, 29 Jun 2024 12:22:46 +0800 Subject: [PATCH 08/16] fix whisper --- vllm/model_executor/models/whisper.py | 14 ++++++++------ vllm/multimodal/audio.py | 5 +++-- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index fb119c04e5092..274ec8c05f0ad 100644 --- a/vllm/model_executor/models/whisper.py +++ 
b/vllm/model_executor/models/whisper.py @@ -45,7 +45,7 @@ class WhisperPositionalEmbedding(nn.Embedding): def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None): super().__init__(num_positions, embedding_dim) - def forward(self, input_ids, position_ids=None): + def forward(self, position_ids): return self.weight[position_ids] class WhisperAttention(nn.Module): @@ -286,7 +286,6 @@ def forward( residual = hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) - hidden_states, _ = self.self_attn( hidden_states=hidden_states, kv_cache=kv_cache, @@ -294,11 +293,14 @@ def forward( ) hidden_states = residual + hidden_states - hidden_states, cross_attention_past_key_value = self.self_attn( + residual = hidden_states + hidden_states = self.encoder_attn_layer_norm(hidden_states) + hidden_states, cross_attention_past_key_value = self.encoder_attn( hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, past_key_value=past_key_value, - attn_metadata=attn_metadata, + kv_cache=None, + attn_metadata=None, is_cross_attention=True, ) hidden_states = residual + hidden_states @@ -387,7 +389,7 @@ def forward( past_key_values = None, ): inputs_embeds = self.embed_tokens(input_ids) - positions = self.embed_positions(input_ids, positions) + positions = self.embed_positions(positions) hidden_states = inputs_embeds + positions cross_attention_past_key_values = [] @@ -482,7 +484,7 @@ def forward( ) -> SamplerOutput: input_features = self._parse_and_validate_audio_input(**kwargs) - + decoder_outputs, cross_attention_past_key_values = self.model( input_features=input_features, input_ids=input_ids, diff --git a/vllm/multimodal/audio.py b/vllm/multimodal/audio.py index 65dabac7ec3ef..9700a4d88622b 100644 --- a/vllm/multimodal/audio.py +++ b/vllm/multimodal/audio.py @@ -14,8 +14,9 @@ def _get_dummy_seq_data(seq_len: int, whisper_config: WhisperConfig) -> SequenceData: - token_ids = [0, 0, 0] - + + # '<|startoftranscript|><|en|><|transcribe|>' + token_ids = [50258, 50259, 50360] return SequenceData(token_ids) From 2d62f4559fa9ec0f3311416fffeb5c8bfa6f95f6 Mon Sep 17 00:00:00 2001 From: huseinzol05 Date: Sat, 29 Jun 2024 12:50:34 +0800 Subject: [PATCH 09/16] added predict lang in whisper example --- examples/whisper_example.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/examples/whisper_example.py b/examples/whisper_example.py index 479bf2c642593..1079e9ab6d410 100644 --- a/examples/whisper_example.py +++ b/examples/whisper_example.py @@ -19,11 +19,17 @@ def main(): r = requests.get('https://github.com/mesolitica/malaya-speech/raw/master/speech/singlish/singlish0.wav') y = audio.decode_example(audio.encode_example(r.content))['array'] - prompt = '<|startoftranscript|><|en|><|transcribe|>' + + output_lang = llm.generate({ + "prompt_token_ids": [50258], + "multi_modal_data": AudioData(y), + }, sampling_params = SamplingParams(max_tokens = 1, temperature = 0)) + outputs = llm.generate({ - "prompt": prompt, + "prompt_token_ids": [50258, output_lang[0].outputs[0].token_ids[0], 50360], "multi_modal_data": AudioData(y), - }) + }, sampling_params = SamplingParams(max_tokens = 10, temperature = 0)) + print(outputs[0].outputs[0].text) From 1eb13dd3f60cb3f02b1305ac9c53ce6e73373552 Mon Sep 17 00:00:00 2001 From: huseinzol05 Date: Sun, 30 Jun 2024 15:34:11 +0800 Subject: [PATCH 10/16] added cudagraph, whisper no longer as multimodal --- examples/whisper_example.py | 2 +- vllm/config.py | 13 + vllm/engine/arg_utils.py | 3 +- 
vllm/engine/llm_engine.py | 9 +- vllm/inputs.py | 9 +- vllm/model_executor/models/__init__.py | 11 +- vllm/model_executor/models/whisper.py | 260 +++++++----- vllm/multimodal/__init__.py | 3 +- vllm/multimodal/audio.py | 5 +- vllm/sequence.py | 10 +- vllm/worker/model_runner.py | 23 +- vllm/worker/whisper_model_runner.py | 528 +++++++++++++++++++++++++ vllm/worker/worker.py | 13 +- 13 files changed, 758 insertions(+), 131 deletions(-) create mode 100644 vllm/worker/whisper_model_runner.py diff --git a/examples/whisper_example.py b/examples/whisper_example.py index 1079e9ab6d410..7bf3e9cc6ea37 100644 --- a/examples/whisper_example.py +++ b/examples/whisper_example.py @@ -17,7 +17,7 @@ def main(): gpu_memory_utilization = 0.4 ) - r = requests.get('https://github.com/mesolitica/malaya-speech/raw/master/speech/singlish/singlish0.wav') + r = requests.get('https://github.com/mesolitica/malaya-speech/raw/master/speech/7021-79759-0004.wav') y = audio.decode_example(audio.encode_example(r.content))['array'] output_lang = llm.generate({ diff --git a/vllm/config.py b/vllm/config.py index 7347a1f2f1522..3623adfe27b0e 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -24,6 +24,7 @@ _GB = 1 << 30 _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768 +_WHISPER_MAX_NUM_BATCHED_TOKENS = 448 class ModelConfig: @@ -149,6 +150,7 @@ def __init__( if not self.skip_tokenizer_init: self._verify_tokenizer_mode() self._verify_embedding_mode() + self._verify_whisper_mode() self._verify_quantization() self._verify_cuda_graph() @@ -165,6 +167,11 @@ def _verify_embedding_mode(self) -> None: self.embedding_mode = any( ModelRegistry.is_embedding_model(arch) for arch in architectures) + def _verify_whisper_mode(self) -> None: + architectures = getattr(self.hf_config, "architectures", []) + self.whisper_mode = any( + ModelRegistry.is_whisper_model(arch) for arch in architectures) + def _parse_quant_hf_config(self): quant_cfg = getattr(self.hf_config, "quantization_config", None) if quant_cfg is None: @@ -682,6 +689,7 @@ class SchedulerConfig: enable_chunked_prefill: If True, prefill requests can be chunked based on the remaining max_num_batched_tokens. embedding_mode: Whether the running model is for embedding. + whisper_mode: Whether the running model is for whisper. preemption_mode: Whether to perform preemption by swapping or recomputation. If not specified, we determine the mode as follows: We use recomputation by default since it incurs lower overhead than @@ -699,6 +707,7 @@ def __init__(self, delay_factor: float = 0.0, enable_chunked_prefill: bool = False, embedding_mode: Optional[bool] = False, + whisper_mode: Optional[bool] = False, preemption_mode: Optional[str] = None) -> None: if max_num_batched_tokens is not None: self.max_num_batched_tokens = max_num_batched_tokens @@ -711,6 +720,9 @@ def __init__(self, # For embedding, choose specific value for higher throughput self.max_num_batched_tokens = max( max_model_len, _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS) + elif whisper_mode: + self.max_num_batched_tokens = max( + max_model_len, _WHISPER_MAX_NUM_BATCHED_TOKENS) else: # If max_model_len is too short, use 2048 as the default value # for higher throughput. 
@@ -725,6 +737,7 @@ def __init__(self, self.delay_factor = delay_factor self.chunked_prefill_enabled = enable_chunked_prefill self.embedding_mode = embedding_mode + self.whisper_mode = whisper_mode self.preemption_mode = preemption_mode self._verify_args() diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index a0328a88016ac..0d71b63b45753 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -89,7 +89,7 @@ class EngineArgs: disable_image_processor: bool = False # Related to Whisper - whisper_input_type: Optional[str] = 'input_features' + whisper_input_type: Optional[str] = None whisper_processor: Optional[str] = None whisper_processor_revision: Optional[str] = None sample_rate: Optional[int] = 16000 @@ -747,6 +747,7 @@ def create_engine_config(self, ) -> EngineConfig: delay_factor=self.scheduler_delay_factor, enable_chunked_prefill=self.enable_chunked_prefill, embedding_mode=model_config.embedding_mode, + whisper_mode=model_config.whisper_mode, preemption_mode=self.preemption_mode, ) lora_config = LoRAConfig( diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index ec67c91c5c04c..9dfe5b088da93 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -228,7 +228,7 @@ def __init__( self.seq_counter = Counter() self.generation_config_fields = _load_generation_config_dict( model_config) - + self.model_executor = executor_class( model_config=model_config, cache_config=cache_config, @@ -504,6 +504,10 @@ def process_model_inputs( if isinstance(inputs, str): inputs = {"prompt": inputs} + if 'whisper_input' in inputs: + if self.whisper_config is None: + raise ValueError(f"Whisper config is None, must initialize a Whisper model.") + if "prompt_token_ids" not in inputs: tokenizer = self.get_tokenizer_group("prompts must be None if " "skip_tokenizer_init is True") @@ -516,7 +520,8 @@ def process_model_inputs( return LLMInputs(prompt_token_ids=prompt_token_ids, prompt=inputs.get("prompt"), - multi_modal_data=inputs.get("multi_modal_data")) + multi_modal_data=inputs.get("multi_modal_data"), + whisper_data=inputs.get('whisper_data')) def add_request( self, diff --git a/vllm/inputs.py b/vllm/inputs.py index 026903e19a26e..12a60b6aa9af4 100644 --- a/vllm/inputs.py +++ b/vllm/inputs.py @@ -4,7 +4,7 @@ from typing_extensions import NotRequired if TYPE_CHECKING: - from vllm.multimodal import MultiModalData + from vllm.multimodal import MultiModalData, WhisperData class ParsedText(TypedDict): @@ -73,6 +73,8 @@ class TextPrompt(TypedDict): """The input text to be tokenized before passing to the model.""" multi_modal_data: NotRequired["MultiModalData"] + + whisper_data: NotRequired["WhisperData"] """ Optional multi-modal data to pass to the model, if the model supports it. @@ -86,6 +88,8 @@ class TokensPrompt(TypedDict): """A list of token IDs to pass to the model.""" multi_modal_data: NotRequired["MultiModalData"] + + whisper_data: NotRequired["WhisperData"] """ Optional multi-modal data to pass to the model, if the model supports it. @@ -105,6 +109,8 @@ class TextTokensPrompt(TypedDict): tokenizer to convert the prompts to token IDs.""" multi_modal_data: NotRequired["MultiModalData"] + + whisper_data: NotRequired["WhisperData"] """ Optional multi-modal data to pass to the model, if the model supports it. 
@@ -128,3 +134,4 @@ class LLMInputs(TypedDict): prompt_token_ids: List[int] prompt: NotRequired[Optional[str]] multi_modal_data: NotRequired[Optional["MultiModalData"]] + whisper_data: NotRequired[Optional["WhisperData"]] diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 55535ddd94291..b9bb666a495ac 100755 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -61,15 +61,17 @@ "XverseForCausalLM": ("xverse", "XverseForCausalLM"), "Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"), "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"), - "WhisperForConditionalGeneration": - ("whisper", "WhisperForConditionalGeneration"), } _EMBEDDING_MODELS = { "MistralModel": ("llama_embedding", "LlamaEmbeddingModel"), } -_MODELS = {**_GENERATION_MODELS, **_EMBEDDING_MODELS} +_WHISPER_MODELS = { + "WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"), +} + +_MODELS = {**_GENERATION_MODELS, **_EMBEDDING_MODELS, **_WHISPER_MODELS} # Architecture -> type. # out of tree models @@ -131,6 +133,9 @@ def register_model(model_arch: str, model_cls: Type[nn.Module]): def is_embedding_model(model_arch: str) -> bool: return model_arch in _EMBEDDING_MODELS + @staticmethod + def is_whisper_model(model_arch: str) -> bool: + return model_arch in _WHISPER_MODELS __all__ = [ "ModelRegistry", diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index 274ec8c05f0ad..4825071d924e0 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -23,8 +23,6 @@ from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, kv_cache_scales_loader) from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.audio import get_dummy_audio_data from vllm.sequence import SamplerOutput from vllm.utils import is_hip, print_warning_once from xformers import ops as xops @@ -53,8 +51,6 @@ def __init__( self, embed_dim: int, num_heads: int, - is_decoder: bool = False, - is_causal: bool = False, bias: bool = True, config: Optional[WhisperConfig] = None, quant_config: Optional[QuantizationConfig] = None, @@ -74,8 +70,6 @@ def __init__( f" and `num_heads`: {num_heads})." 
) self.scaling = self.head_dim**-0.5 - self.is_decoder = is_decoder - self.is_causal = is_causal self.k_proj = RowParallelLinear( input_size = embed_dim, @@ -101,78 +95,163 @@ def __init__( bias = bias, quant_config=quant_config ) - if self.is_causal: - self.attn = Attention( - self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config - ) - else: - self.attn = None - + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).contiguous() + + +class WhisperEncoderAttention(WhisperAttention): + def __init__( + self, + embed_dim: int, + num_heads: int, + bias: bool = True, + config: Optional[WhisperConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + cache_config: Optional[CacheConfig] = None, + ): + super().__init__( + embed_dim=embed_dim, + num_heads=num_heads, + bias=bias, + config=config, + quant_config=quant_config, + cache_config=cache_config + ) def forward( self, hidden_states: torch.Tensor, - encoder_hidden_states = None, - past_key_value = None, + ): + sizes = hidden_states.size() + if len(sizes) == 3: + bsz, tgt_len, _ = sizes + else: + tgt_len, _ = sizes + q, _ = self.q_proj(hidden_states) + k, _ = self.k_proj(hidden_states) + v, _ = self.v_proj(hidden_states) + + q = self._shape(q, -1, 1) + k = self._shape(k, -1, 1) + v = self._shape(v, -1, 1) + + attn_output = xops.memory_efficient_attention_forward( + q, + k, + v, + attn_bias=None, + p=0.0, + scale=None, + op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if + (is_hip()) else None, + ) + + attn_output = attn_output.reshape(-1, self.embed_dim) + output, _ = self.out_proj(attn_output) + return output + +class WhisperDecoderAttention(WhisperAttention): + def __init__( + self, + embed_dim: int, + num_heads: int, + bias: bool = True, + config: Optional[WhisperConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + cache_config: Optional[CacheConfig] = None, + ): + super().__init__( + embed_dim=embed_dim, + num_heads=num_heads, + bias=bias, + config=config, + quant_config=quant_config, + cache_config=cache_config + ) + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config + ) + + def forward( + self, + hidden_states: torch.Tensor, kv_cache: torch.Tensor = None, attn_metadata: AttentionMetadata = None, - is_cross_attention = False, ): sizes = hidden_states.size() if len(sizes) == 3: bsz, tgt_len, _ = sizes else: tgt_len, _ = sizes + q, _ = self.q_proj(hidden_states) - - past_key_value = None - - if is_cross_attention or not self.is_decoder: - if is_cross_attention and encoder_hidden_states is not None: - if past_key_value is not None: - k = past_key_value[0] - v = past_key_value[1] - else: - k, _ = self.k_proj(encoder_hidden_states) - v, _ = self.v_proj(encoder_hidden_states) - - past_key_value = (k, v) - else: - k, _ = self.k_proj(hidden_states) - v, _ = self.v_proj(hidden_states) - - q = self._shape(q, -1, 1) - k = self._shape(k, -1, 1) - v = self._shape(v, -1, 1) - - attn_output = xops.memory_efficient_attention_forward( - q, - k, - v, - attn_bias=None, - p=0.0, - scale=None, - op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if - (is_hip()) else None, - ) - - attn_output = attn_output.reshape(-1, self.embed_dim) + k, _ = self.k_proj(hidden_states) + v, _ = self.v_proj(hidden_states) + attn_output = self.attn(q, k, v, 
kv_cache, attn_metadata) + + output, _ = self.out_proj(attn_output) + return output + +class WhisperDecoderCrossAttention(WhisperAttention): + def __init__( + self, + embed_dim: int, + num_heads: int, + bias: bool = True, + config: Optional[WhisperConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + cache_config: Optional[CacheConfig] = None, + ): + super().__init__( + embed_dim=embed_dim, + num_heads=num_heads, + bias=bias, + config=config, + quant_config=quant_config, + cache_config=cache_config + ) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + attn_metadata: AttentionMetadata = None, + ): + sizes = hidden_states.size() + if len(sizes) == 3: + bsz, tgt_len, _ = sizes else: - k, _ = self.k_proj(hidden_states) - v, _ = self.v_proj(hidden_states) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + tgt_len, _ = sizes + q, _ = self.q_proj(hidden_states) + k, _ = self.k_proj(encoder_hidden_states) + v, _ = self.v_proj(encoder_hidden_states) + + q = self._shape(q, -1, 1) + k = self._shape(k, -1, 1) + v = self._shape(v, -1, 1) + + attn_output = xops.memory_efficient_attention_forward( + q, + k, + v, + attn_bias=None, + p=0.0, + scale=None, + op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if + (is_hip()) else None, + ) + + attn_output = attn_output.reshape(-1, self.embed_dim) output, _ = self.out_proj(attn_output) - - return output, past_key_value + return output class WhisperEncoderLayer(nn.Module): def __init__( @@ -183,7 +262,7 @@ def __init__( ): super().__init__() self.embed_dim = config.d_model - self.self_attn = WhisperAttention( + self.self_attn = WhisperEncoderAttention( embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, config=config, @@ -212,7 +291,7 @@ def forward( ): residual = hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) - hidden_states, _ = self.self_attn( + hidden_states = self.self_attn( hidden_states=hidden_states, ) hidden_states = residual + hidden_states @@ -240,11 +319,9 @@ def __init__( ): super().__init__() self.embed_dim = config.d_model - self.self_attn = WhisperAttention( + self.self_attn = WhisperDecoderAttention( embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, - is_decoder=True, - is_causal=True, config=config, quant_config=quant_config, cache_config=cache_config, @@ -252,10 +329,9 @@ def __init__( self.activation_fn = FastGELU() self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.encoder_attn = WhisperAttention( + self.encoder_attn = WhisperDecoderCrossAttention( embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, - is_decoder=True, config=config, quant_config=quant_config, cache_config=cache_config, @@ -279,14 +355,13 @@ def forward( self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, - past_key_value: Tuple[torch.Tensor, torch.Tensor], kv_cache: torch.Tensor, attn_metadata: AttentionMetadata, ): residual = hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) - hidden_states, _ = self.self_attn( + hidden_states = self.self_attn( hidden_states=hidden_states, kv_cache=kv_cache, attn_metadata=attn_metadata @@ -295,13 +370,10 @@ def forward( residual = hidden_states hidden_states = self.encoder_attn_layer_norm(hidden_states) - hidden_states, cross_attention_past_key_value = self.encoder_attn( + hidden_states = self.encoder_attn( hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, - past_key_value=past_key_value, - kv_cache=None, - 
attn_metadata=None, - is_cross_attention=True, + attn_metadata=attn_metadata, ) hidden_states = residual + hidden_states @@ -312,7 +384,7 @@ def forward( hidden_states, _ = self.fc2(hidden_states) hidden_states = residual + hidden_states - return hidden_states, cross_attention_past_key_value + return hidden_states class WhisperEncoder(nn.Module): def __init__( @@ -392,20 +464,16 @@ def forward( positions = self.embed_positions(positions) hidden_states = inputs_embeds + positions - cross_attention_past_key_values = [] - for idx, decoder_layer in enumerate(self.layers): - hidden_states, cross_attention_past_key_value = decoder_layer( + hidden_states = decoder_layer( hidden_states, encoder_hidden_states=encoder_hidden_states, - past_key_value=None if past_key_values is None else past_key_values[idx], kv_cache=kv_caches[idx], attn_metadata=attn_metadata ) - cross_attention_past_key_values.append(cross_attention_past_key_value) hidden_states = self.layer_norm(hidden_states) - return hidden_states, cross_attention_past_key_values + return hidden_states class WhisperModel(nn.Module): def __init__( @@ -426,26 +494,22 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, - past_key_values = None, ): - if input_features is not None: - encoder_outputs = self.encoder( - input_features[0], - ) + if hasattr(attn_metadata, 'encoder_outputs'): + encoder_outputs = attn_metadata.encoder_outputs else: - encoder_outputs = None - decoder_outputs, cross_attention_past_key_values = self.decoder( + encoder_outputs = self.encoder(input_features) + attn_metadata.encoder_outputs = encoder_outputs + + decoder_outputs = self.decoder( input_ids=input_ids, positions=positions, encoder_hidden_states=encoder_outputs, kv_caches=kv_caches, attn_metadata=attn_metadata, - past_key_values=past_key_values ) - return decoder_outputs, cross_attention_past_key_values + return decoder_outputs -@MULTIMODAL_REGISTRY.register_audio_input() -@MULTIMODAL_REGISTRY.register_dummy_data(get_dummy_audio_data) class WhisperForConditionalGeneration(nn.Module): def __init__( self, @@ -467,31 +531,21 @@ def __init__( config.vocab_size, logit_scale) self.sampler = Sampler() - def _parse_and_validate_audio_input( - self, **kwargs: object) -> torch.Tensor: - input_features = kwargs.pop("input_features", None) - - return input_features - def forward( self, + input_features: torch.Tensor, input_ids: torch.Tensor, positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, - past_key_values = None, - **kwargs: object, ) -> SamplerOutput: - - input_features = self._parse_and_validate_audio_input(**kwargs) - decoder_outputs, cross_attention_past_key_values = self.model( + decoder_outputs = self.model( input_features=input_features, input_ids=input_ids, positions=positions, kv_caches=kv_caches, attn_metadata=attn_metadata, - past_key_values=past_key_values, ) return decoder_outputs diff --git a/vllm/multimodal/__init__.py b/vllm/multimodal/__init__.py index 270012e7d1c3b..df3b5d51e8ff9 100644 --- a/vllm/multimodal/__init__.py +++ b/vllm/multimodal/__init__.py @@ -1,7 +1,8 @@ from .base import MultiModalData, MultiModalPlugin +from .audio import WhisperData from .registry import MULTIMODAL_REGISTRY, MultiModalRegistry __all__ = [ "MultiModalData", "MultiModalPlugin", "MULTIMODAL_REGISTRY", - "MultiModalRegistry" + "MultiModalRegistry", "WhisperData" ] diff --git a/vllm/multimodal/audio.py b/vllm/multimodal/audio.py index 9700a4d88622b..bc3670caaf7f6 100644 --- 
a/vllm/multimodal/audio.py +++ b/vllm/multimodal/audio.py @@ -11,10 +11,13 @@ logger = init_logger(__name__) +class WhisperData: + pass + def _get_dummy_seq_data(seq_len: int, whisper_config: WhisperConfig) -> SequenceData: - + # '<|startoftranscript|><|en|><|transcribe|>' token_ids = [50258, 50259, 50360] return SequenceData(token_ids) diff --git a/vllm/sequence.py b/vllm/sequence.py index 287e1b9df6165..70a9286f8091c 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -14,7 +14,7 @@ from vllm.sampling_params import SamplingParams if TYPE_CHECKING: - from vllm.multimodal import MultiModalData + from vllm.multimodal import MultiModalData, WhisperData from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics @@ -463,6 +463,12 @@ def multi_modal_data(self) -> Optional["MultiModalData"]: # We use the multi-modal data of an arbitrary sequence. return next(iter(self.seqs_dict.values())).multi_modal_data + @property + def whisper_data(self) -> Optional["WhisperData"]: + # All sequences in the group should have the same multi-modal data. + # We use the multi-modal data of an arbitrary sequence. + return next(iter(self.seqs_dict.values())).whisper_data + @property def lora_int_id(self) -> int: return self.lora_request.lora_int_id if self.lora_request else 0 @@ -641,6 +647,7 @@ def __init__( computed_block_nums: Optional[List[int]] = None, state: Optional[SequenceGroupState] = None, multi_modal_data: Optional["MultiModalData"] = None, + whisper_data: Optional["WhisperData"] = None, encoder_seq_data: Optional[SequenceData] = None, cross_block_table: Optional[List[int]] = None, ) -> None: @@ -653,6 +660,7 @@ def __init__( self.lora_request = lora_request self.computed_block_nums = computed_block_nums self.multi_modal_data = multi_modal_data + self.whisper_data = whisper_data self.state = SequenceGroupState() if state is None else state self.encoder_seq_data = encoder_seq_data self.cross_block_table = cross_block_table diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 8bcd45814e290..3fa761ec8d37c 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -26,6 +26,7 @@ from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata from vllm.utils import (CudaMemoryProfiler, get_kv_cache_torch_dtype, is_hip, is_pin_memory_available, make_tensor_with_pad) +from vllm.transformers_utils.whisper_processor import cached_get_whisper_processor logger = init_logger(__name__) @@ -139,14 +140,15 @@ def __init__( self.model_config, self.vision_language_config, ) - elif self.whisper_config is not None: - self.multi_modal_input_processor = MULTIMODAL_REGISTRY \ - .create_input_processor( - self.model_config, - self.whisper_config, - ) else: self.multi_modal_input_processor = None + + if self.whisper_config is not None: + self.whisper_processor = cached_get_whisper_processor( + self.whisper_config.whisper_processor + ) + else: + self.whisper_processor = None # Lazy initialization self.model: nn.Module # Set after load_model @@ -830,7 +832,6 @@ def profile_run(self) -> None: # of images processed. 
model_config = self.model_config vlm_config = self.vision_language_config - whisper_config = self.whisper_config if vlm_config: max_num_seqs = min( @@ -843,9 +844,6 @@ def profile_run(self) -> None: if vlm_config is not None: seq_data, dummy_multi_modal_data = MULTIMODAL_REGISTRY \ .dummy_data_for_profiling(seq_len, model_config, vlm_config) - elif whisper_config is not None: - seq_data, dummy_multi_modal_data = MULTIMODAL_REGISTRY \ - .dummy_data_for_profiling(seq_len, model_config, whisper_config) else: seq_data = SequenceData([0] * seq_len) dummy_multi_modal_data = None @@ -934,11 +932,6 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: slot_mapping.fill_(_PAD_SLOT_ID) seq_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda() block_tables = torch.from_numpy(self.graph_block_tables).cuda() - _, dummy_multi_modal_data = MULTIMODAL_REGISTRY.dummy_data_for_profiling( - max_batch_size, - self.model_config, - self.whisper_config - ) # Prepare buffer for outputs. These will be reused for all batch sizes. # It will be filled after the first graph capture. diff --git a/vllm/worker/whisper_model_runner.py b/vllm/worker/whisper_model_runner.py new file mode 100644 index 0000000000000..7e4091679be67 --- /dev/null +++ b/vllm/worker/whisper_model_runner.py @@ -0,0 +1,528 @@ +import gc +import time +from typing import Dict, List, NamedTuple, Optional, Set, Tuple, Union + +import torch +import torch.nn as nn + +from vllm.attention import AttentionMetadata +from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, + ModelConfig, ParallelConfig, SchedulerConfig, + VisionLanguageConfig, WhisperConfig) +from vllm.distributed import broadcast_tensor_dict +from vllm.distributed.parallel_state import graph_capture +from vllm.model_executor import SamplingMetadata +from vllm.logger import init_logger +from vllm.lora.layers import LoRAMapping +from vllm.lora.request import LoRARequest +from vllm.sampling_params import SamplingParams +from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata +from vllm.worker.model_runner import ModelRunner + +logger = init_logger(__name__) + +_PAD_SLOT_ID = -1 +LORA_WARMUP_RANK = 8 +_BATCH_SIZE_ALIGNMENT = 8 +_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [ + _BATCH_SIZE_ALIGNMENT * i for i in range(1, 33) +] +_NUM_WARMUP_ITERS = 2 + +class WhisperModelRunner(ModelRunner): + + def __init__( + self, + model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + cache_config: CacheConfig, + load_config: LoadConfig, + lora_config: Optional[LoRAConfig], + kv_cache_dtype: Optional[str] = "auto", + is_driver_worker: bool = False, + vision_language_config: Optional[VisionLanguageConfig] = None, + whisper_config: Optional[WhisperConfig] = None, + ): + super().__init__(model_config, + parallel_config, + scheduler_config, + device_config, + cache_config, + load_config, + lora_config=lora_config, + kv_cache_dtype=kv_cache_dtype, + is_driver_worker=is_driver_worker, + vision_language_config=vision_language_config, + whisper_config=whisper_config) + + self._check_whisper_unsupported_scenarios() + + def _check_whisper_unsupported_scenarios(self): + if self.scheduler_config.chunked_prefill_enabled: + # Fail if chunked prefill is enabled + raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL) + + if self.cache_config.enable_prefix_caching: + # Fail if prefix caching is enabled + raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE) + + if 
self.sliding_window is not None: + # Fail if sliding window is enabled + raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_SWA) + + @torch.inference_mode() + def execute_model( + self, + seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], + kv_caches: List[torch.Tensor], + ) -> Optional[SamplerOutput]: + (whisper_input, input_tokens, input_positions, attn_metadata, sampling_metadata, + lora_requests, lora_mapping, multi_modal_kwargs + ) = self.prepare_input_tensors(seq_group_metadata_list) + + if self.lora_config: + self.set_active_loras(lora_requests, lora_mapping) + + # Currently cuda graph is only supported by the decode phase. + prefill_meta = attn_metadata.prefill_metadata + decode_meta = attn_metadata.decode_metadata + if prefill_meta is None and decode_meta.use_cuda_graph: + graph_batch_size = input_tokens.shape[0] + model_executable = self.graph_runners[graph_batch_size] + else: + model_executable = self.model + + hidden_states = model_executable( + input_features=whisper_input, + input_ids=input_tokens, + positions=input_positions, + kv_caches=kv_caches, + attn_metadata=attn_metadata, + **multi_modal_kwargs, + ) + + # Compute the logits. + logits = self.model.compute_logits(hidden_states, sampling_metadata) + + # Only perform sampling in the driver worker. + if not self.is_driver_worker: + return None + + output = self.model.sample( + logits=logits, + sampling_metadata=sampling_metadata, + ) + + return output + + def _prepare_encoder_model_input( + self, seq_group_metadata_list: List[SequenceGroupMetadata], + attn_metadata: AttentionMetadata): + + whisper_data_list = [] + + for seq_group_metadata in seq_group_metadata_list: + seq_ids = list(seq_group_metadata.seq_data.keys()) + is_prompt = seq_group_metadata.is_prompt + + for seq_id in seq_ids: + + whisper_data = seq_group_metadata.whisper_data + if whisper_data is not None: + # Process multi-modal data + if self.whisper_processor is None: + raise ValueError("Whisper Processor not initialized") + + whisper_data = self.whisper_processor(whisper_data, return_tensors = 'pt') + whisper_data = whisper_data.to(self.model_config.dtype).input_features[0] + whisper_data_list.append(whisper_data) + + whisper_input = torch.cat(whisper_data_list, dim = 1).cuda() + + return whisper_input + + def prepare_input_tensors( + self, + seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], + ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, + Set[LoRARequest], LoRAMapping, Dict[str, torch.Tensor]]: + if self.is_driver_worker: + assert seq_group_metadata_list is not None + # Prepare input tensors. 
+ ( + input_tokens, + input_positions, + attn_metadata, + seq_lens, + query_lens, + lora_mapping, + lora_requests, + multi_modal_kwargs, + slot_mapping, + num_prefill_tokens, + num_decode_tokens, + num_prefills, + ) = self._prepare_model_input(seq_group_metadata_list) + whisper_input = self._prepare_encoder_model_input(seq_group_metadata_list, attn_metadata) + sampling_metadata = SamplingMetadata.prepare( + seq_group_metadata_list, seq_lens, query_lens, self.device, + self.pin_memory) + + metadata_dict = { + 'whisper_input': whisper_input, + "input_tokens": input_tokens, + "input_positions": input_positions, + "lora_requests": lora_requests, + "lora_mapping": lora_mapping, + "multi_modal_kwargs": multi_modal_kwargs, + "num_prefill_tokens": num_prefill_tokens, + "num_decode_tokens": num_decode_tokens, + "slot_mapping": slot_mapping, + "num_prefills": num_prefills, + } + if attn_metadata: + metadata_dict.update(attn_metadata.asdict_zerocopy()) + broadcast_tensor_dict(metadata_dict, src=0) + else: + metadata_dict = broadcast_tensor_dict(src=0) + whisper_input = metadata_dict.pop("whisper_input") + input_tokens = metadata_dict.pop("input_tokens") + input_positions = metadata_dict.pop("input_positions") + lora_mapping = metadata_dict.pop("lora_mapping") + lora_requests = metadata_dict.pop("lora_requests") + multi_modal_kwargs = metadata_dict.pop("multi_modal_kwargs") + if metadata_dict: + attn_metadata = self.attn_backend.make_metadata( + **metadata_dict) + else: + attn_metadata = None + sampling_metadata = SamplingMetadata( + seq_groups=None, + selected_token_indices=selected_token_indices, + categorized_sample_indices=None, + num_prompts=0, + ) + + return (whisper_input, input_tokens, input_positions, attn_metadata, + sampling_metadata, lora_requests, lora_mapping, + multi_modal_kwargs) + + + @torch.inference_mode() + def profile_run(self) -> None: + # Enable top-k sampling to reflect the accurate memory usage. + sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1) + max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens + max_num_seqs = self.scheduler_config.max_num_seqs + # This represents the maximum number of different requests + # that will have unique loras, an therefore the max amount of memory + # consumption create dummy lora request copies from the lora request + # passed in, which contains a lora from the lora warmup path. + dummy_lora_requests: List[LoRARequest] = [] + dummy_lora_requests_per_seq: List[LoRARequest] = [] + if self.lora_config: + assert self.lora_manager is not None + with self.lora_manager.dummy_lora_cache(): + for idx in range(self.lora_config.max_loras): + lora_id = idx + 1 + dummy_lora_request = LoRARequest( + lora_name=f"warmup_{lora_id}", + lora_int_id=lora_id, + lora_local_path="/not/a/real/path", + ) + self.lora_manager.add_dummy_lora(dummy_lora_request, + rank=LORA_WARMUP_RANK) + dummy_lora_requests.append(dummy_lora_request) + dummy_lora_requests_per_seq = [ + dummy_lora_requests[idx % len(dummy_lora_requests)] + for idx in range(max_num_seqs) + ] + + # Profile memory usage with max_num_sequences sequences and the total + # number of tokens equal to max_num_batched_tokens. + seqs: List[SequenceGroupMetadata] = [] + # Additional GPU memory may be needed for vision encoding, which needs + # to be accounted for when calculating the GPU blocks for + # vLLM blocker manager. + # To exercise the worst scenario for GPU memory consumption, + # the number of seqs (batch_size) is chosen to maximize the number + # of images processed. 
+ model_config = self.model_config + vlm_config = self.vision_language_config + whisper_config = self.whisper_config + + if vlm_config: + max_num_seqs = min( + max_num_seqs, + int(max_num_batched_tokens / vlm_config.image_feature_size)) + for group_id in range(max_num_seqs): + seq_len = (max_num_batched_tokens // max_num_seqs + + (group_id < max_num_batched_tokens % max_num_seqs)) + + whisper_data = None + if vlm_config is not None: + seq_data, dummy_multi_modal_data = MULTIMODAL_REGISTRY \ + .dummy_data_for_profiling(seq_len, model_config, vlm_config) + else: + seq_data = SequenceData([0] * seq_len) + dummy_multi_modal_data = None + + if whisper_config is not None: + whisper_data = torch.zeros( + (30 * whisper_config.sample_rate), + dtype=torch.float16) + + seq = SequenceGroupMetadata( + request_id=str(group_id), + is_prompt=True, + seq_data={group_id: seq_data}, + sampling_params=sampling_params, + block_tables=None, + lora_request=dummy_lora_requests_per_seq[group_id] + if dummy_lora_requests_per_seq else None, + multi_modal_data=dummy_multi_modal_data, + whisper_data=whisper_data + ) + seqs.append(seq) + + # Run the model with the dummy inputs. + num_layers = self.model_config.get_num_layers(self.parallel_config) + kv_caches = [None] * num_layers + self.execute_model(seqs, kv_caches) + torch.cuda.synchronize() + return + + @torch.inference_mode() + def capture_model(self, kv_caches: List[torch.Tensor]) -> None: + """Cuda graph capture a model. + + Note that CUDA graph's performance gain is negligible if number + of batched tokens are larger than 200. And since CUDA graph + requires fixed sized tensors, supporting large/variable batch + size requires high GPU memory overhead. Thus, vLLM only captures + decoding requests. Mixed batch (chunked prefill + decoding) or + prefill requests are not captured. + + Since it is used for decoding-only, it assumes there's only 1 token + per sequence in the batch. + """ + assert not self.model_config.enforce_eager + logger.info("Capturing the model for CUDA graphs. This may lead to " + "unexpected consequences if the model is not static. To " + "run the model in eager mode, set 'enforce_eager=True' or " + "use '--enforce-eager' in the CLI.") + logger.info("CUDA graphs can take additional 1~3 GiB memory per GPU. " + "If you are running out of memory, consider decreasing " + "`gpu_memory_utilization` or enforcing eager mode. " + "You can also reduce the `max_num_seqs` as needed " + "to decrease memory usage.") + start_time = time.perf_counter() + + # Prepare dummy inputs. These will be reused for all batch sizes. + max_batch_size = max(_BATCH_SIZES_TO_CAPTURE) + input_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda() + input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda() + slot_mapping = torch.empty(max_batch_size, dtype=torch.long).cuda() + slot_mapping.fill_(_PAD_SLOT_ID) + seq_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda() + block_tables = torch.from_numpy(self.graph_block_tables).cuda() + whisper_data = torch.zeros( + (30 * self.whisper_config.sample_rate), + dtype=torch.float16) + whisper_input = self.whisper_processor(whisper_data, return_tensors = 'pt') + whisper_input = whisper_input.to(self.model_config.dtype).input_features[0].cuda() + + # Prepare buffer for outputs. These will be reused for all batch sizes. + # It will be filled after the first graph capture. 
+ hidden_states: Optional[torch.Tensor] = None + + graph_batch_size = _get_graph_batch_size( + self.scheduler_config.max_num_seqs) + batch_size_capture_list = [ + bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size + ] + + with graph_capture() as graph_capture_context: + # NOTE: Capturing the largest batch size first may help reduce the + # memory usage of CUDA graph. + for batch_size in reversed(batch_size_capture_list): + # Create dummy attn_metadata. + attn_metadata = self.attn_backend.make_metadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=batch_size, + slot_mapping=slot_mapping[:batch_size], + seq_lens=None, + seq_lens_tensor=seq_lens[:batch_size], + max_query_len=None, + max_prefill_seq_len=0, + max_decode_seq_len=self.max_seq_len_to_capture, + query_start_loc=None, + seq_start_loc=None, + context_lens_tensor=None, + block_tables=block_tables[:batch_size], + use_cuda_graph=True, + ) + + if self.lora_config: + lora_mapping = LoRAMapping( + [0] * batch_size, + [0] * batch_size, + ) + self.set_active_loras(set(), lora_mapping) + + graph_runner = CUDAGraphRunner(self.model) + hidden_states = graph_runner.capture( + whisper_input, + input_tokens[:batch_size], + input_positions[:batch_size], + hidden_states[:batch_size] + if hidden_states is not None else None, + kv_caches, + attn_metadata, + memory_pool=self.graph_memory_pool, + stream=graph_capture_context.stream, + ) + self.graph_memory_pool = graph_runner.graph.pool() + self.graph_runners[batch_size] = graph_runner + + end_time = time.perf_counter() + elapsed_time = end_time - start_time + # This usually takes < 10 seconds. + logger.info("Graph capturing finished in %.0f secs.", elapsed_time) + +class CUDAGraphRunner: + + def __init__(self, model: nn.Module): + self.model = model + self.input_buffers: Dict[str, torch.Tensor] = {} + self.output_buffers: Dict[str, torch.Tensor] = {} + + self._graph: Optional[torch.cuda.CUDAGraph] = None + + @property + def graph(self): + assert self._graph is not None + return self._graph + + def capture( + self, + whisper_input: torch.Tensor, + input_ids: torch.Tensor, + positions: torch.Tensor, + hidden_states: Optional[torch.Tensor], + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + memory_pool: Optional[Tuple[int, int]], + stream: torch.cuda.Stream, + **kwargs, + ) -> torch.Tensor: + assert self._graph is None + # Run the model a few times without capturing the graph. + # This is to make sure that the captured graph does not include the + # kernel launches for initial benchmarking (e.g., Triton autotune). + # Note one iteration is not enough for torch.jit.script + for _ in range(_NUM_WARMUP_ITERS): + self.model( + whisper_input, + input_ids, + positions, + kv_caches, + attn_metadata, + **kwargs, + ) + torch.cuda.synchronize() + + # Capture the graph. + self._graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream): + output_hidden_states = self.model( + whisper_input, + input_ids, + positions, + kv_caches, + attn_metadata, + **kwargs, + ) + if hidden_states is not None: + hidden_states.copy_(output_hidden_states) + else: + hidden_states = output_hidden_states + del output_hidden_states + # make sure `output_hidden_states` is deleted + # in the graph's memory pool + gc.collect() + torch.cuda.synchronize() + + # Save the input and output buffers. 
+ self.input_buffers = { + "whisper_input": whisper_input, + "input_ids": input_ids, + "positions": positions, + "kv_caches": kv_caches, + "slot_mapping": attn_metadata.slot_mapping, + "seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor, + "block_tables": attn_metadata.decode_metadata.block_tables, + } + self.output_buffers = {"hidden_states": hidden_states} + return hidden_states + + def forward( + self, + whisper_input: torch.Tensor, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + **kwargs, + ) -> torch.Tensor: + # KV caches are fixed tensors, so we don't need to copy them. + del kv_caches + + # Copy the input tensors to the input buffers. + self.input_buffers["whisper_input"].copy_(whisper_input, non_blocking=True) + self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True) + self.input_buffers["positions"].copy_(positions, non_blocking=True) + self.input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping, + non_blocking=True) + self.input_buffers["seq_lens_tensor"].copy_( + attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True) + self.input_buffers["block_tables"].copy_( + attn_metadata.decode_metadata.block_tables, non_blocking=True) + # Run the graph. + self.graph.replay() + + # Return the output tensor. + return self.output_buffers["hidden_states"] + + def __call__(self, *args, **kwargs): + return self.forward(*args, **kwargs) + + +def _get_graph_batch_size(batch_size: int) -> int: + """Returns the padded batch size given actual batch size. + + Batch sizes are 1, 2, 4, _BATCH_SIZE_ALIGNMENT, + 2*_BATCH_SIZE_ALIGNMENT, 3*_BATCH_SIZE_ALIGNMENT... + """ + if batch_size <= 2: + return batch_size + elif batch_size <= 4: + return 4 + else: + return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) // + _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT) + + +def _is_block_tables_empty(block_tables: Union[None, Dict]): + """ + Check if block_tables is None or a dictionary with all None values. 
+ """ + if block_tables is None: + return True + if isinstance(block_tables, dict) and all( + value is None for value in block_tables.values()): + return True + return False \ No newline at end of file diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 840201cf4942b..6cde97b53227e 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -20,6 +20,7 @@ from vllm.worker.cache_engine import CacheEngine from vllm.worker.embedding_model_runner import EmbeddingModelRunner from vllm.worker.model_runner import ModelRunner +from vllm.worker.whisper_model_runner import WhisperModelRunner from vllm.worker.worker_base import WorkerBase @@ -70,6 +71,10 @@ def __init__( if self.vision_language_config: assert not self.lora_config, ( "To be tested: vision language model with LoRA settings.") + self.whisper_config = whisper_config + if self.whisper_config: + assert not self.lora_config, ( + "Whisper does not support with LoRA settings.") # Return hidden states from target model if the draft model is an # mlp_speculator @@ -79,8 +84,12 @@ def __init__( or (speculative_config.draft_model_config.hf_config.model_type != "mlp_speculator") else {"return_hidden_states": True} - ModelRunnerClass = (EmbeddingModelRunner if - self.model_config.embedding_mode else ModelRunner) + if self.model_config.embedding_mode: + ModelRunnerClass = EmbeddingModelRunner + elif self.whisper_config is not None: + ModelRunnerClass = WhisperModelRunner + else: + ModelRunnerClass = ModelRunner self.model_runner = ModelRunnerClass( model_config, parallel_config, From 3db307f36c01d2e9c685e3fd9dc26df13dc47285 Mon Sep 17 00:00:00 2001 From: huseinzol05 Date: Sun, 30 Jun 2024 20:33:15 +0800 Subject: [PATCH 11/16] able to decode properly --- examples/whisper_example.py | 8 ++-- vllm/core/scheduler.py | 1 + vllm/engine/llm_engine.py | 23 ++++++++- vllm/model_executor/models/whisper.py | 4 +- vllm/sequence.py | 4 ++ vllm/worker/whisper_model_runner.py | 67 +++++++++++++-------------- 6 files changed, 65 insertions(+), 42 deletions(-) diff --git a/examples/whisper_example.py b/examples/whisper_example.py index 7bf3e9cc6ea37..1663df72f53f6 100644 --- a/examples/whisper_example.py +++ b/examples/whisper_example.py @@ -12,9 +12,9 @@ def main(): llm = LLM( model="openai/whisper-large-v3", max_num_seqs = 1, - max_seq_len_to_capture = 448, max_model_len = 448, - gpu_memory_utilization = 0.4 + gpu_memory_utilization = 0.4, + dtype = 'bfloat16', ) r = requests.get('https://github.com/mesolitica/malaya-speech/raw/master/speech/7021-79759-0004.wav') @@ -28,10 +28,12 @@ def main(): outputs = llm.generate({ "prompt_token_ids": [50258, output_lang[0].outputs[0].token_ids[0], 50360], "multi_modal_data": AudioData(y), - }, sampling_params = SamplingParams(max_tokens = 10, temperature = 0)) + }, sampling_params = SamplingParams(min_tokens = 20, max_tokens = 20, temperature = 0)) + # ' without going to any such extreme as this we can easily see on reflection how vast an influence on the' print(outputs[0].outputs[0].text) + if __name__ == "__main__": main() diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 48c34625c08ae..c03fe7fffb0f6 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -1006,6 +1006,7 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: # `multi_modal_data` will be None. 
multi_modal_data=seq_group.multi_modal_data if scheduler_outputs.num_prefill_groups > 0 else None, + whisper_data=seq_group.whisper_data ) seq_group_metadata_list.append(seq_group_metadata) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 9dfe5b088da93..2cf6cea79ad4d 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -36,6 +36,7 @@ from vllm.transformers_utils.detokenizer import Detokenizer from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup, get_tokenizer_group) +from vllm.transformers_utils.whisper_processor import cached_get_whisper_processor from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, usage_message) from vllm.utils import Counter @@ -225,6 +226,13 @@ def __init__( self.tokenizer = None self.detokenizer = None + if self.whisper_config is not None: + self.whisper_processor = cached_get_whisper_processor( + self.whisper_config.whisper_processor + ) + else: + self.whisper_processor = None + self.seq_counter = Counter() self.generation_config_fields = _load_generation_config_dict( model_config) @@ -504,9 +512,19 @@ def process_model_inputs( if isinstance(inputs, str): inputs = {"prompt": inputs} - if 'whisper_input' in inputs: + if 'whisper_data' in inputs: if self.whisper_config is None: raise ValueError(f"Whisper config is None, must initialize a Whisper model.") + if self.whisper_processor is None: + raise ValueError(f"Whisper Processor is not initialized.") + whisper_data = self.whisper_processor( + inputs['whisper_data'], + sampling_rate = self.whisper_config.sample_rate, + return_tensors = 'pt', + ) + whisper_data = whisper_data.to(self.model_config.dtype).input_features[0] + else: + whisper_data = None if "prompt_token_ids" not in inputs: tokenizer = self.get_tokenizer_group("prompts must be None if " @@ -517,11 +535,12 @@ def process_model_inputs( lora_request=lora_request) else: prompt_token_ids = inputs["prompt_token_ids"] + return LLMInputs(prompt_token_ids=prompt_token_ids, prompt=inputs.get("prompt"), multi_modal_data=inputs.get("multi_modal_data"), - whisper_data=inputs.get('whisper_data')) + whisper_data=whisper_data) def add_request( self, diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index 4825071d924e0..8b0d9e54d546a 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -533,7 +533,7 @@ def __init__( def forward( self, - input_features: torch.Tensor, + whisper_data: torch.Tensor, input_ids: torch.Tensor, positions: torch.Tensor, kv_caches: List[torch.Tensor], @@ -541,7 +541,7 @@ def forward( ) -> SamplerOutput: decoder_outputs = self.model( - input_features=input_features, + input_features=whisper_data, input_ids=input_ids, positions=positions, kv_caches=kv_caches, diff --git a/vllm/sequence.py b/vllm/sequence.py index 70a9286f8091c..2e94a98449b2e 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -260,6 +260,10 @@ def prompt_token_ids(self) -> List[int]: def multi_modal_data(self) -> Optional["MultiModalData"]: return self.inputs.get("multi_modal_data") + @property + def whisper_data(self) -> Optional["WhisperData"]: + return self.inputs.get("whisper_data") + @property def lora_int_id(self) -> int: return self.lora_request.lora_int_id if self.lora_request else 0 diff --git a/vllm/worker/whisper_model_runner.py b/vllm/worker/whisper_model_runner.py index 7e4091679be67..76533c687f470 100644 --- a/vllm/worker/whisper_model_runner.py +++ b/vllm/worker/whisper_model_runner.py @@ -57,28 
+57,13 @@ def __init__( vision_language_config=vision_language_config, whisper_config=whisper_config) - self._check_whisper_unsupported_scenarios() - - def _check_whisper_unsupported_scenarios(self): - if self.scheduler_config.chunked_prefill_enabled: - # Fail if chunked prefill is enabled - raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL) - - if self.cache_config.enable_prefix_caching: - # Fail if prefix caching is enabled - raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE) - - if self.sliding_window is not None: - # Fail if sliding window is enabled - raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_SWA) - @torch.inference_mode() def execute_model( self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], kv_caches: List[torch.Tensor], ) -> Optional[SamplerOutput]: - (whisper_input, input_tokens, input_positions, attn_metadata, sampling_metadata, + (whisper_data, input_tokens, input_positions, attn_metadata, sampling_metadata, lora_requests, lora_mapping, multi_modal_kwargs ) = self.prepare_input_tensors(seq_group_metadata_list) @@ -94,8 +79,10 @@ def execute_model( else: model_executable = self.model + model_executable = self.model + hidden_states = model_executable( - input_features=whisper_input, + whisper_data=whisper_data, input_ids=input_tokens, positions=input_positions, kv_caches=kv_caches, @@ -134,14 +121,20 @@ def _prepare_encoder_model_input( # Process multi-modal data if self.whisper_processor is None: raise ValueError("Whisper Processor not initialized") - - whisper_data = self.whisper_processor(whisper_data, return_tensors = 'pt') - whisper_data = whisper_data.to(self.model_config.dtype).input_features[0] + + if len(whisper_data.shape) == 1: + whisper_data = self.whisper_processor( + whisper_data, + sampling_rate = self.whisper_config.sample_rate, + return_tensors = 'pt', + ) + whisper_data = whisper_data.to(self.model_config.dtype).input_features[0] + whisper_data_list.append(whisper_data) - whisper_input = torch.cat(whisper_data_list, dim = 1).cuda() + whisper_data = torch.cat(whisper_data_list, dim = 1).cuda() - return whisper_input + return whisper_data def prepare_input_tensors( self, @@ -165,13 +158,13 @@ def prepare_input_tensors( num_decode_tokens, num_prefills, ) = self._prepare_model_input(seq_group_metadata_list) - whisper_input = self._prepare_encoder_model_input(seq_group_metadata_list, attn_metadata) + whisper_data = self._prepare_encoder_model_input(seq_group_metadata_list, attn_metadata) sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, seq_lens, query_lens, self.device, self.pin_memory) metadata_dict = { - 'whisper_input': whisper_input, + 'whisper_data': whisper_data, "input_tokens": input_tokens, "input_positions": input_positions, "lora_requests": lora_requests, @@ -187,7 +180,7 @@ def prepare_input_tensors( broadcast_tensor_dict(metadata_dict, src=0) else: metadata_dict = broadcast_tensor_dict(src=0) - whisper_input = metadata_dict.pop("whisper_input") + whisper_data = metadata_dict.pop("whisper_data") input_tokens = metadata_dict.pop("input_tokens") input_positions = metadata_dict.pop("input_positions") lora_mapping = metadata_dict.pop("lora_mapping") @@ -205,7 +198,7 @@ def prepare_input_tensors( num_prompts=0, ) - return (whisper_input, input_tokens, input_positions, attn_metadata, + return (whisper_data, input_tokens, input_positions, attn_metadata, sampling_metadata, lora_requests, lora_mapping, multi_modal_kwargs) @@ -331,8 +324,12 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> 
None: whisper_data = torch.zeros( (30 * self.whisper_config.sample_rate), dtype=torch.float16) - whisper_input = self.whisper_processor(whisper_data, return_tensors = 'pt') - whisper_input = whisper_input.to(self.model_config.dtype).input_features[0].cuda() + whisper_data = self.whisper_processor( + whisper_data, + sampling_rate = self.whisper_config.sample_rate, + return_tensors = 'pt', + ) + whisper_data = whisper_data.to(self.model_config.dtype).input_features[0].cuda() # Prepare buffer for outputs. These will be reused for all batch sizes. # It will be filled after the first graph capture. @@ -375,7 +372,7 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: graph_runner = CUDAGraphRunner(self.model) hidden_states = graph_runner.capture( - whisper_input, + whisper_data, input_tokens[:batch_size], input_positions[:batch_size], hidden_states[:batch_size] @@ -409,7 +406,7 @@ def graph(self): def capture( self, - whisper_input: torch.Tensor, + whisper_data: torch.Tensor, input_ids: torch.Tensor, positions: torch.Tensor, hidden_states: Optional[torch.Tensor], @@ -426,7 +423,7 @@ def capture( # Note one iteration is not enough for torch.jit.script for _ in range(_NUM_WARMUP_ITERS): self.model( - whisper_input, + whisper_data, input_ids, positions, kv_caches, @@ -439,7 +436,7 @@ def capture( self._graph = torch.cuda.CUDAGraph() with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream): output_hidden_states = self.model( - whisper_input, + whisper_data, input_ids, positions, kv_caches, @@ -458,7 +455,7 @@ def capture( # Save the input and output buffers. self.input_buffers = { - "whisper_input": whisper_input, + "whisper_data": whisper_data, "input_ids": input_ids, "positions": positions, "kv_caches": kv_caches, @@ -471,7 +468,7 @@ def capture( def forward( self, - whisper_input: torch.Tensor, + whisper_data: torch.Tensor, input_ids: torch.Tensor, positions: torch.Tensor, kv_caches: List[torch.Tensor], @@ -482,7 +479,7 @@ def forward( del kv_caches # Copy the input tensors to the input buffers. 
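         # The keys used here must match the buffer names recorded in capture()
         # above, so the whisper_input -> whisper_data rename applies to both.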
- self.input_buffers["whisper_input"].copy_(whisper_input, non_blocking=True) + self.input_buffers["whisper_data"].copy_(whisper_data, non_blocking=True) self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True) self.input_buffers["positions"].copy_(positions, non_blocking=True) self.input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping, From 00ed7ea454a991264f3cea2310b034ca7d797653 Mon Sep 17 00:00:00 2001 From: huseinzol05 Date: Mon, 1 Jul 2024 19:55:20 +0800 Subject: [PATCH 12/16] initial whisper serving --- examples/whisper_example.py | 2 +- vllm/engine/llm_engine.py | 3 +- vllm/entrypoints/openai/api_server.py | 34 +- vllm/entrypoints/openai/cli_args.py | 4 + vllm/entrypoints/openai/serving_whisper.py | 461 ++++++++++++++++++ vllm/model_executor/models/whisper.py | 8 +- .../tokenizer_group/tokenizer_group.py | 6 +- vllm/worker/model_runner.py | 2 + vllm/worker/whisper_model_runner.py | 2 - 9 files changed, 509 insertions(+), 13 deletions(-) create mode 100644 vllm/entrypoints/openai/serving_whisper.py diff --git a/examples/whisper_example.py b/examples/whisper_example.py index 1663df72f53f6..21db65fda1d51 100644 --- a/examples/whisper_example.py +++ b/examples/whisper_example.py @@ -28,7 +28,7 @@ def main(): outputs = llm.generate({ "prompt_token_ids": [50258, output_lang[0].outputs[0].token_ids[0], 50360], "multi_modal_data": AudioData(y), - }, sampling_params = SamplingParams(min_tokens = 20, max_tokens = 20, temperature = 0)) + }, sampling_params = SamplingParams(max_tokens = 100, temperature = 0)) # ' without going to any such extreme as this we can easily see on reflection how vast an influence on the' print(outputs[0].outputs[0].text) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 2cf6cea79ad4d..24a0ef2e7ee89 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -532,7 +532,8 @@ def process_model_inputs( prompt_token_ids = tokenizer.encode(request_id=request_id, prompt=inputs["prompt"], - lora_request=lora_request) + lora_request=lora_request, + add_special_tokens=self.whisper_processor is None) else: prompt_token_ids = inputs["prompt_token_ids"] diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index ea6275920c79d..f912e55b67398 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -26,6 +26,7 @@ from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding +from vllm.entrypoints.openai.serving_whisper import OpenAIServingWhisper from vllm.logger import init_logger from vllm.usage.usage_lib import UsageContext from vllm.version import __version__ as VLLM_VERSION @@ -35,6 +36,7 @@ openai_serving_chat: OpenAIServingChat openai_serving_completion: OpenAIServingCompletion openai_serving_embedding: OpenAIServingEmbedding +openai_serving_whisper: OpenAIServingWhisper logger = init_logger('vllm.entrypoints.openai.api_server') @@ -137,7 +139,35 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request): else: return JSONResponse(content=generator.model_dump()) - +@app.post("/audio/transcriptions") +async def audio_transcriptions( + file: bytes = File(), + language: str = Form(None), + response_format: str = Form('text'), + timestamp_granularities: str = Form('segment'), + repetition_penalty: float = Form(1.0), + stream: bool = Form(False), + raw_request: 
Request = None, +): + generator = await openai_serving_chat.create_audio_transcriptions( + file=file, + language=language, + response_format=response_format, + timestamp_granularities=timestamp_granularities, + repetition_penalty=repetition_penalty, + stream=stream, + raw_request=raw_request, + ) + if isinstance(generator, ErrorResponse): + return JSONResponse(content=generator.model_dump(), + status_code=generator.code) + if stream: + return StreamingResponse(content=generator, + media_type="text/event-stream") + else: + assert isinstance(generator, ChatCompletionResponse) + return JSONResponse(content=generator.model_dump()) + if __name__ == "__main__": args = parse_args() @@ -219,6 +249,8 @@ async def authentication(request: Request, call_next): engine, model_config, served_model_names, args.lora_modules) openai_serving_embedding = OpenAIServingEmbedding(engine, model_config, served_model_names) + openai_serving_whisper = OpenAIServingWhisper( + engine, model_config, served_model_names, args.max_size_whisper) app.root_path = args.root_path uvicorn.run(app, host=args.host, diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index 59ad73bf097c8..30d16a3d7f755 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -71,6 +71,10 @@ def make_arg_parser(): help="The file path to the chat template, " "or the template in single-line form " "for the specified model") + parser.add_argument("--max-size-whisper", + type=int, + default=100, + help="max size of audio to transcribe using Whisper in term of MB.") parser.add_argument("--response-role", type=nullable_str, default="assistant", diff --git a/vllm/entrypoints/openai/serving_whisper.py b/vllm/entrypoints/openai/serving_whisper.py new file mode 100644 index 0000000000000..c8d3e3514d89d --- /dev/null +++ b/vllm/entrypoints/openai/serving_whisper.py @@ -0,0 +1,461 @@ +import codecs +import time +import torchaudio +import numpy as np +from dataclasses import dataclass, field +from typing import (AsyncGenerator, AsyncIterator, Awaitable, Dict, Iterable, + List, Optional) +from typing import Sequence as GenericSequence +from typing import TypedDict, Union, cast, final + +from fastapi import Request +from openai.types.chat import (ChatCompletionContentPartImageParam, + ChatCompletionContentPartTextParam) + +from vllm.config import ModelConfig, VisionLanguageConfig +from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.entrypoints.openai.protocol import ( + ChatCompletionContentPartParam, ChatCompletionLogProb, + ChatCompletionLogProbs, ChatCompletionLogProbsContent, + ChatCompletionMessageParam, ChatCompletionNamedToolChoiceParam, + ChatCompletionRequest, ChatCompletionResponse, + ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice, + ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse, + FunctionCall, ToolCall, UsageInfo) +from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, + OpenAIServing) +from vllm.inputs import PromptInputs +from vllm.logger import init_logger +from vllm.model_executor.guided_decoding import ( + get_guided_decoding_logits_processor) +from vllm.multimodal.image import ImagePixelData +from vllm.multimodal.utils import (async_get_and_parse_image, + get_full_image_text_prompt) +from vllm.outputs import RequestOutput +from vllm.sequence import Logprob +from vllm.tracing import (contains_trace_headers, extract_trace_headers, + log_tracing_disabled_warning) +from vllm.sampling_params import 
SamplingParams +from vllm.utils import random_uuid + +logger = init_logger(__name__) + +buffer_size = 4096 +sample_rate = 16000 +segment_length = sample_rate * 1 +maxlen = 30 +replaces = ['<|startoftranscript|>', '<|endoftext|>', '<|transcribe|>'] +pattern = r'<\|\-?\d+\.?\d*\|>' +pattern_pair = r'<\|(\d+\.\d+)\|>(.*?)<\|(\d+\.\d+)\|>' + + +@dataclass(frozen=True) +class ChatMessageParseResult: + messages: List[ConversationMessage] + image_futures: List[Awaitable[ImagePixelData]] = field( + default_factory=list) + + +class OpenAIServingWhisper(OpenAIServing): + + def __init__(self, + engine: AsyncLLMEngine, + model_config: ModelConfig, + served_model_names: List[str], + max_size_whisper: 100, + ): + super().__init__(engine=engine, + model_config=model_config, + served_model_names=served_model_names, + lora_modules=None) + + self._check_whisper_mode(model_config.whisper_mode) + self.max_size_whisper = max_size_whisper + + def _check_whisper_mode(self, whisper_mode: bool): + if not whisper_mode: + logger.warning( + "whisper_mode is False. Whisper API will not work.") + else: + logger.info("Activating the server engine with whisper enabled.") + + async def create_audio_transcriptions( + self, + file, + language, + response_format, + timestamp_granularities, + repetition_penalty, + stream, + raw_request: Optional[Request] = None + ) -> Union[ErrorResponse, AsyncGenerator[str, None], + ChatCompletionResponse]: + + if len(file) > self.max_size_whisper: + return self.create_error_response(f"maximum size for file is {self.max_size_whisper}MB only") + + request_id = f"cmpl-{random_uuid()}" + + sampling_params = SamplingParams( + max_tokens = self.engine.model_config.max_model_len - 4, + temperature = 0.0, + skip_special_tokens = False, + stop_token_ids = [50257], + repetition_penalty = repetition_penalty + ) + + if isinstance(language, str) and language.lower() == 'null': + language = None + + streamer = StreamReader( + src=file, + format=None, + option=None, + buffer_size=buffer_size + ) + streamer.add_basic_audio_stream( + frames_per_chunk=segment_length, + sample_rate=sample_rate + ) + stream_iterator = streamer.stream() + + is_tracing_enabled = await self.engine.is_tracing_enabled() + trace_headers = None + if is_tracing_enabled and raw_request: + trace_headers = extract_trace_headers(raw_request.headers) + if not is_tracing_enabled and raw_request and contains_trace_headers( + raw_request.headers): + log_tracing_disabled_warning() + + # Streaming response + if request.stream: + return self.audio_transcription_stream_generator( + request, sampling_params, stream_iterator, language, request_id, trace_headers) + else: + try: + return await self.audio_transcription_full_generator( + request, sampling_params, stream_iterator, language, request_id, trace_headers) + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) + + async def generate( + self, + sampling_params, + language, + wav_data, + last_timestamp, + request_id, + trace_headers, + ): + prompt_text = '<|startoftranscript|>' + prompt_ids = self.tokenizer.encode(prompt_text, add_special_tokens = False) + + inputs: PromptInputs = { + "prompt": prompt_text, + "prompt_token_ids": prompt_ids, + "whisper_data": wav_data + } + sampling_params = SamplingParams(max_tokens = 1, temperature = 0, skip_special_tokens = False) + + result_generator = self.engine.generate( + inputs, + sampling_params, + request_id=request_id, + lora_request = None, + trace_headers=trace_headers, + ) + + + + async 
def audio_transcription_stream_generator( + self, + request: ChatCompletionRequest, + sampling_params, + stream_iterator, + language, + request_id: str, + ) -> AsyncGenerator[str, None]: + + wav_data = np.array([], dtype=np.float32) + last_timestamp = 0 + for chunk in stream_iterator: + wav_data = np.concatenate([wav_data, frame]) + audio_len = len(wav_data) / sample_rate + if audio_len >= maxlen: + async for t in generate( + sampling_params=sampling_params, + language=language, + wav_data=wav_data, + last_timestamp=last_timestamp, + ): + yield t + + last_timestamp += audio_len + wav_data = np.array([], dtype=np.float32) + + if len(wav_data): + + + inputs: PromptInputs = { + "prompt": prompt_text, + "prompt_token_ids": prompt_ids, + } + + # Send response for each token for each request.n (index) + result_generator = self.engine.generate( + inputs, + sampling_params, + request_id, + lora_request, + trace_headers=trace_headers, + ) + assert request.n is not None + previous_texts = [""] * request.n + previous_num_tokens = [0] * request.n + finish_reason_sent = [False] * request.n + try: + async for res in result_generator: + # We need to do it here, because if there are exceptions in + # the result_generator, it needs to be sent as the FIRST + # response (by the try...catch). + if first_iteration: + # Send first response for each request.n (index) with + # the role + role = self.get_chat_request_role(request) + for i in range(request.n): + choice_data = ChatCompletionResponseStreamChoice( + index=i, + delta=DeltaMessage(role=role), + logprobs=None, + finish_reason=None) + chunk = ChatCompletionStreamResponse( + id=request_id, + object=chunk_object_type, + created=created_time, + choices=[choice_data], + model=model_name) + if (request.stream_options + and request.stream_options.include_usage): + chunk.usage = None + data = chunk.model_dump_json(exclude_unset=True) + yield f"data: {data}\n\n" + + # Send response to echo the input portion of the + # last message + if request.echo: + last_msg_content = "" + if conversation and conversation[-1].get( + "content") and conversation[-1].get( + "role") == role: + last_msg_content = conversation[-1]["content"] + + if last_msg_content: + for i in range(request.n): + choice_data = ( + ChatCompletionResponseStreamChoice( + index=i, + delta=DeltaMessage( + content=last_msg_content), + finish_reason=None)) + chunk = ChatCompletionStreamResponse( + id=request_id, + object=chunk_object_type, + created=created_time, + choices=[choice_data], + logprobs=None, + model=model_name) + if (request.stream_options and + request.stream_options.include_usage): + chunk.usage = None + data = chunk.model_dump_json( + exclude_unset=True) + yield f"data: {data}\n\n" + first_iteration = False + + for output in res.outputs: + i = output.index + + if finish_reason_sent[i]: + continue + + delta_token_ids = output.token_ids[previous_num_tokens[i]:] + out_logprobs = output.logprobs[ + previous_num_tokens[i]:] if output.logprobs else None + + if request.logprobs and request.top_logprobs is not None: + assert out_logprobs is not None, ( + "Did not output logprobs") + logprobs = self._create_chat_logprobs( + token_ids=delta_token_ids, + top_logprobs=out_logprobs, + num_output_top_logprobs=request.top_logprobs, + ) + else: + logprobs = None + + delta_text = output.text[len(previous_texts[i]):] + previous_texts[i] = output.text + previous_num_tokens[i] = len(output.token_ids) + + if request.tool_choice and type( + request.tool_choice + ) is ChatCompletionNamedToolChoiceParam: + 
delta_message = DeltaMessage(tool_calls=[ + ToolCall(function=FunctionCall( + name=request.tool_choice.function.name, + arguments=delta_text)) + ]) + else: + delta_message = DeltaMessage(content=delta_text) + + if output.finish_reason is None: + # Send token-by-token response for each request.n + + choice_data = ChatCompletionResponseStreamChoice( + index=i, + delta=delta_message, + logprobs=logprobs, + finish_reason=None) + chunk = ChatCompletionStreamResponse( + id=request_id, + object=chunk_object_type, + created=created_time, + choices=[choice_data], + model=model_name) + if (request.stream_options + and request.stream_options.include_usage): + chunk.usage = None + data = chunk.model_dump_json(exclude_unset=True) + yield f"data: {data}\n\n" + else: + # Send the finish response for each request.n only once + prompt_tokens = len(res.prompt_token_ids) + choice_data = ChatCompletionResponseStreamChoice( + index=i, + delta=delta_message, + logprobs=logprobs, + finish_reason=output.finish_reason, + stop_reason=output.stop_reason) + chunk = ChatCompletionStreamResponse( + id=request_id, + object=chunk_object_type, + created=created_time, + choices=[choice_data], + model=model_name) + if (request.stream_options + and request.stream_options.include_usage): + chunk.usage = None + data = chunk.model_dump_json(exclude_unset=True) + yield f"data: {data}\n\n" + finish_reason_sent[i] = True + + if (request.stream_options + and request.stream_options.include_usage): + final_usage = UsageInfo( + prompt_tokens=prompt_tokens, + completion_tokens=previous_num_tokens[i], + total_tokens=prompt_tokens + previous_num_tokens[i], + ) + + final_usage_chunk = ChatCompletionStreamResponse( + id=request_id, + object=chunk_object_type, + created=created_time, + choices=[], + model=model_name, + usage=final_usage) + final_usage_data = (final_usage_chunk.model_dump_json( + exclude_unset=True, exclude_none=True)) + yield f"data: {final_usage_data}\n\n" + + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + data = self.create_streaming_error_response(str(e)) + yield f"data: {data}\n\n" + # Send the final done message after all response.n are finished + yield "data: [DONE]\n\n" + + async def chat_completion_full_generator( + self, request: ChatCompletionRequest, raw_request: Optional[Request], + result_generator: AsyncIterator[RequestOutput], request_id: str, + conversation: List[ConversationMessage] + ) -> Union[ErrorResponse, ChatCompletionResponse]: + + model_name = self.served_model_names[0] + created_time = int(time.time()) + final_res: Optional[RequestOutput] = None + + async for res in result_generator: + if raw_request is not None and await raw_request.is_disconnected(): + # Abort the request if the client disconnects. 
+ await self.engine.abort(request_id) + return self.create_error_response("Client disconnected") + final_res = res + assert final_res is not None + + choices: List[ChatCompletionResponseChoice] = [] + + role = self.get_chat_request_role(request) + for output in final_res.outputs: + token_ids = output.token_ids + out_logprobs = output.logprobs + + if request.logprobs and request.top_logprobs is not None: + assert out_logprobs is not None, "Did not output logprobs" + logprobs = self._create_chat_logprobs( + token_ids=token_ids, + top_logprobs=out_logprobs, + num_output_top_logprobs=request.top_logprobs, + ) + else: + logprobs = None + + if request.tool_choice and type( + request.tool_choice) is ChatCompletionNamedToolChoiceParam: + message = ChatMessage( + role=role, + content="", + tool_calls=[ + ToolCall(function=FunctionCall( + name=request.tool_choice.function.name, + arguments=output.text)) + ]) + elif not request.tool_choice or request.tool_choice == "none": + message = ChatMessage(role=role, content=output.text) + + choice_data = ChatCompletionResponseChoice( + index=output.index, + message=message, + logprobs=logprobs, + finish_reason=output.finish_reason, + stop_reason=output.stop_reason) + choices.append(choice_data) + + if request.echo: + last_msg_content = "" + if conversation and conversation[-1].get( + "content") and conversation[-1].get("role") == role: + last_msg_content = conversation[-1]["content"] + + for choice in choices: + full_message = last_msg_content + choice.message.content + choice.message.content = full_message + + num_prompt_tokens = len(final_res.prompt_token_ids) + num_generated_tokens = sum( + len(output.token_ids) for output in final_res.outputs) + usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + completion_tokens=num_generated_tokens, + total_tokens=num_prompt_tokens + num_generated_tokens, + ) + response = ChatCompletionResponse( + id=request_id, + created=created_time, + model=model_name, + choices=choices, + usage=usage, + ) + + return response diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index 8b0d9e54d546a..e543393cc03e9 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -494,12 +494,8 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, - ): - if hasattr(attn_metadata, 'encoder_outputs'): - encoder_outputs = attn_metadata.encoder_outputs - else: - encoder_outputs = self.encoder(input_features) - attn_metadata.encoder_outputs = encoder_outputs + ): + encoder_outputs = self.encoder(input_features) decoder_outputs = self.decoder( input_ids=input_ids, diff --git a/vllm/transformers_utils/tokenizer_group/tokenizer_group.py b/vllm/transformers_utils/tokenizer_group/tokenizer_group.py index 9614f01d2b955..4ab768f0e5ea8 100644 --- a/vllm/transformers_utils/tokenizer_group/tokenizer_group.py +++ b/vllm/transformers_utils/tokenizer_group/tokenizer_group.py @@ -49,9 +49,11 @@ def _raise_if_input_too_long(self, def encode(self, prompt: str, request_id: Optional[str] = None, - lora_request: Optional[LoRARequest] = None) -> List[int]: + lora_request: Optional[LoRARequest] = None, + add_special_tokens: bool = True, + ) -> List[int]: tokenizer = self.get_lora_tokenizer(lora_request) - ret = tokenizer.encode(prompt) + ret = tokenizer.encode(prompt, add_special_tokens = add_special_tokens) self._raise_if_input_too_long(ret, lora_request) return ret diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py 
index 3fa761ec8d37c..adda874e1c970 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -760,6 +760,8 @@ def execute_model( else: model_executable = self.model + print(input_tokens, input_positions, attn_metadata) + hidden_states = model_executable( input_ids=input_tokens, positions=input_positions, diff --git a/vllm/worker/whisper_model_runner.py b/vllm/worker/whisper_model_runner.py index 76533c687f470..ac4559e8d5a03 100644 --- a/vllm/worker/whisper_model_runner.py +++ b/vllm/worker/whisper_model_runner.py @@ -79,8 +79,6 @@ def execute_model( else: model_executable = self.model - model_executable = self.model - hidden_states = model_executable( whisper_data=whisper_data, input_ids=input_tokens, From 5b169829df845455eae5f2ced0b164b35b1967ac Mon Sep 17 00:00:00 2001 From: huseinzol05 Date: Tue, 2 Jul 2024 10:54:53 +0800 Subject: [PATCH 13/16] improve whisper serving --- vllm/engine/arg_utils.py | 27 +- vllm/entrypoints/openai/api_server.py | 1 + vllm/entrypoints/openai/serving_whisper.py | 355 ++++----------------- 3 files changed, 80 insertions(+), 303 deletions(-) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 0d71b63b45753..3246e11203825 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -166,23 +166,31 @@ def add_cli_args_for_vlm( def add_cli_args_for_whisper( parser: FlexibleArgumentParser) -> FlexibleArgumentParser: parser.add_argument( - '--image-processor', + '--whisper-input-type', + type=nullable_str, + default=EngineArgs.whisper_input_type, + choices=[ + 'input_features' + ], + help=('The audio input type for whisper passed into vLLM.')) + parser.add_argument( + '--whisper-processor', type=str, - default=EngineArgs.image_processor, - help='Name or path of the huggingface image processor to use. ' + default=EngineArgs.whisper_processor, + help='Name or path of the huggingface whisper processor to use. ' 'If unspecified, model name or path will be used.') parser.add_argument( - '--image-processor-revision', + '--whisper-processor-revision', type=str, default=None, - help='Revision of the huggingface image processor version to use. ' + help='Revision of the huggingface whisper processor version to use. ' 'It can be a branch name, a tag name, or a commit id. 
' 'If unspecified, will use the default version.') parser.add_argument( - '--disable-image-processor', - action='store_true', - help='Disables the use of image processor, even if one is defined ' - 'for the model on huggingface.') + '--sample-rate', + type=int, + default=EngineArgs.sample_rate, + help='sample rate for whisper processor') return parser @@ -543,6 +551,7 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: # Related to Vision-language models such as llava parser = EngineArgs.add_cli_args_for_vlm(parser) + parser = EngineArgs.add_cli_args_for_whisper(parser) parser.add_argument( '--scheduler-delay-factor', diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index f912e55b67398..60dc5d723e129 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -12,6 +12,7 @@ from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, Response, StreamingResponse +from fastapi import File, Form, UploadFile from prometheus_client import make_asgi_app from starlette.routing import Mount diff --git a/vllm/entrypoints/openai/serving_whisper.py b/vllm/entrypoints/openai/serving_whisper.py index c8d3e3514d89d..69dd84f9dc035 100644 --- a/vllm/entrypoints/openai/serving_whisper.py +++ b/vllm/entrypoints/openai/serving_whisper.py @@ -49,13 +49,6 @@ pattern_pair = r'<\|(\d+\.\d+)\|>(.*?)<\|(\d+\.\d+)\|>' -@dataclass(frozen=True) -class ChatMessageParseResult: - messages: List[ConversationMessage] - image_futures: List[Awaitable[ImagePixelData]] = field( - default_factory=list) - - class OpenAIServingWhisper(OpenAIServing): def __init__(self, @@ -88,8 +81,7 @@ async def create_audio_transcriptions( repetition_penalty, stream, raw_request: Optional[Request] = None - ) -> Union[ErrorResponse, AsyncGenerator[str, None], - ChatCompletionResponse]: + ): if len(file) > self.max_size_whisper: return self.create_error_response(f"maximum size for file is {self.max_size_whisper}MB only") @@ -148,24 +140,48 @@ async def generate( request_id, trace_headers, ): - prompt_text = '<|startoftranscript|>' - prompt_ids = self.tokenizer.encode(prompt_text, add_special_tokens = False) + if language is None: + prompt_ids = [50258] + inputs: PromptInputs = { + "prompt": None, + "prompt_token_ids": prompt_ids, + "whisper_data": wav_data + } + lang_sampling_params = SamplingParams( + max_tokens = 1, temperature = 0, skip_special_tokens = False) + + result_generator = self.engine.generate( + inputs, + lang_sampling_params, + request_id=request_id, + lora_request = None, + trace_headers=trace_headers, + ) + async for res in result_generator: + if raw_request is not None and await raw_request.is_disconnected(): + # Abort the request if the client disconnects. 
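+                        # abort() frees the request's scheduler slot and KV-cache
+                        # blocks in the engine, so a disconnected client stops
+                        # consuming GPU memory while language detection runs.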
+ await self.engine.abort(request_id) + yield self.create_error_response("Client disconnected") + final_res = res + assert final_res is not None + + lang_token = final_res.outputs[0].token_ids[0] + else: + lang_token = self.tokenizer.encode( + lang_token = f'<|{language}|>', add_special_tokens = False)[0] + prompt_ids = [50258, lang_token, 50360, 50365] inputs: PromptInputs = { - "prompt": prompt_text, + "prompt": None, "prompt_token_ids": prompt_ids, "whisper_data": wav_data } - sampling_params = SamplingParams(max_tokens = 1, temperature = 0, skip_special_tokens = False) - result_generator = self.engine.generate( - inputs, - sampling_params, - request_id=request_id, - lora_request = None, - trace_headers=trace_headers, - ) - + text = processor.tokenizer.decode(prompt_ids, decode_with_timestamps = True) + + async for res in result_generator: + print(res.outputs) + yield res.outputs async def audio_transcription_stream_generator( @@ -179,283 +195,34 @@ async def audio_transcription_stream_generator( wav_data = np.array([], dtype=np.float32) last_timestamp = 0 - for chunk in stream_iterator: - wav_data = np.concatenate([wav_data, frame]) - audio_len = len(wav_data) / sample_rate - if audio_len >= maxlen: - async for t in generate( - sampling_params=sampling_params, - language=language, - wav_data=wav_data, - last_timestamp=last_timestamp, - ): - yield t - - last_timestamp += audio_len - wav_data = np.array([], dtype=np.float32) - - if len(wav_data): - - - inputs: PromptInputs = { - "prompt": prompt_text, - "prompt_token_ids": prompt_ids, - } - - # Send response for each token for each request.n (index) - result_generator = self.engine.generate( - inputs, - sampling_params, - request_id, - lora_request, - trace_headers=trace_headers, - ) - assert request.n is not None - previous_texts = [""] * request.n - previous_num_tokens = [0] * request.n - finish_reason_sent = [False] * request.n try: - async for res in result_generator: - # We need to do it here, because if there are exceptions in - # the result_generator, it needs to be sent as the FIRST - # response (by the try...catch). 
- if first_iteration: - # Send first response for each request.n (index) with - # the role - role = self.get_chat_request_role(request) - for i in range(request.n): - choice_data = ChatCompletionResponseStreamChoice( - index=i, - delta=DeltaMessage(role=role), - logprobs=None, - finish_reason=None) - chunk = ChatCompletionStreamResponse( - id=request_id, - object=chunk_object_type, - created=created_time, - choices=[choice_data], - model=model_name) - if (request.stream_options - and request.stream_options.include_usage): - chunk.usage = None - data = chunk.model_dump_json(exclude_unset=True) - yield f"data: {data}\n\n" - - # Send response to echo the input portion of the - # last message - if request.echo: - last_msg_content = "" - if conversation and conversation[-1].get( - "content") and conversation[-1].get( - "role") == role: - last_msg_content = conversation[-1]["content"] - - if last_msg_content: - for i in range(request.n): - choice_data = ( - ChatCompletionResponseStreamChoice( - index=i, - delta=DeltaMessage( - content=last_msg_content), - finish_reason=None)) - chunk = ChatCompletionStreamResponse( - id=request_id, - object=chunk_object_type, - created=created_time, - choices=[choice_data], - logprobs=None, - model=model_name) - if (request.stream_options and - request.stream_options.include_usage): - chunk.usage = None - data = chunk.model_dump_json( - exclude_unset=True) - yield f"data: {data}\n\n" - first_iteration = False - - for output in res.outputs: - i = output.index - - if finish_reason_sent[i]: - continue - - delta_token_ids = output.token_ids[previous_num_tokens[i]:] - out_logprobs = output.logprobs[ - previous_num_tokens[i]:] if output.logprobs else None - - if request.logprobs and request.top_logprobs is not None: - assert out_logprobs is not None, ( - "Did not output logprobs") - logprobs = self._create_chat_logprobs( - token_ids=delta_token_ids, - top_logprobs=out_logprobs, - num_output_top_logprobs=request.top_logprobs, - ) - else: - logprobs = None - - delta_text = output.text[len(previous_texts[i]):] - previous_texts[i] = output.text - previous_num_tokens[i] = len(output.token_ids) - - if request.tool_choice and type( - request.tool_choice - ) is ChatCompletionNamedToolChoiceParam: - delta_message = DeltaMessage(tool_calls=[ - ToolCall(function=FunctionCall( - name=request.tool_choice.function.name, - arguments=delta_text)) - ]) - else: - delta_message = DeltaMessage(content=delta_text) - - if output.finish_reason is None: - # Send token-by-token response for each request.n - - choice_data = ChatCompletionResponseStreamChoice( - index=i, - delta=delta_message, - logprobs=logprobs, - finish_reason=None) - chunk = ChatCompletionStreamResponse( - id=request_id, - object=chunk_object_type, - created=created_time, - choices=[choice_data], - model=model_name) - if (request.stream_options - and request.stream_options.include_usage): - chunk.usage = None - data = chunk.model_dump_json(exclude_unset=True) - yield f"data: {data}\n\n" - else: - # Send the finish response for each request.n only once - prompt_tokens = len(res.prompt_token_ids) - choice_data = ChatCompletionResponseStreamChoice( - index=i, - delta=delta_message, - logprobs=logprobs, - finish_reason=output.finish_reason, - stop_reason=output.stop_reason) - chunk = ChatCompletionStreamResponse( - id=request_id, - object=chunk_object_type, - created=created_time, - choices=[choice_data], - model=model_name) - if (request.stream_options - and request.stream_options.include_usage): - chunk.usage = None - data = 
chunk.model_dump_json(exclude_unset=True) - yield f"data: {data}\n\n" - finish_reason_sent[i] = True - - if (request.stream_options - and request.stream_options.include_usage): - final_usage = UsageInfo( - prompt_tokens=prompt_tokens, - completion_tokens=previous_num_tokens[i], - total_tokens=prompt_tokens + previous_num_tokens[i], - ) - - final_usage_chunk = ChatCompletionStreamResponse( - id=request_id, - object=chunk_object_type, - created=created_time, - choices=[], - model=model_name, - usage=final_usage) - final_usage_data = (final_usage_chunk.model_dump_json( - exclude_unset=True, exclude_none=True)) - yield f"data: {final_usage_data}\n\n" + for chunk in stream_iterator: + wav_data = np.concatenate([wav_data, frame]) + audio_len = len(wav_data) / sample_rate + if audio_len >= maxlen: + async for t in generate( + sampling_params=sampling_params, + language=language, + wav_data=wav_data, + last_timestamp=last_timestamp, + ): + yield t + + last_timestamp += audio_len + wav_data = np.array([], dtype=np.float32) + + if len(wav_data): + async for t in generate( + sampling_params=sampling_params, + language=language, + wav_data=wav_data, + last_timestamp=last_timestamp, + ): + yield t except ValueError as e: # TODO: Use a vllm-specific Validation Error data = self.create_streaming_error_response(str(e)) yield f"data: {data}\n\n" # Send the final done message after all response.n are finished - yield "data: [DONE]\n\n" - - async def chat_completion_full_generator( - self, request: ChatCompletionRequest, raw_request: Optional[Request], - result_generator: AsyncIterator[RequestOutput], request_id: str, - conversation: List[ConversationMessage] - ) -> Union[ErrorResponse, ChatCompletionResponse]: - - model_name = self.served_model_names[0] - created_time = int(time.time()) - final_res: Optional[RequestOutput] = None - - async for res in result_generator: - if raw_request is not None and await raw_request.is_disconnected(): - # Abort the request if the client disconnects. 
- await self.engine.abort(request_id) - return self.create_error_response("Client disconnected") - final_res = res - assert final_res is not None - - choices: List[ChatCompletionResponseChoice] = [] - - role = self.get_chat_request_role(request) - for output in final_res.outputs: - token_ids = output.token_ids - out_logprobs = output.logprobs - - if request.logprobs and request.top_logprobs is not None: - assert out_logprobs is not None, "Did not output logprobs" - logprobs = self._create_chat_logprobs( - token_ids=token_ids, - top_logprobs=out_logprobs, - num_output_top_logprobs=request.top_logprobs, - ) - else: - logprobs = None - - if request.tool_choice and type( - request.tool_choice) is ChatCompletionNamedToolChoiceParam: - message = ChatMessage( - role=role, - content="", - tool_calls=[ - ToolCall(function=FunctionCall( - name=request.tool_choice.function.name, - arguments=output.text)) - ]) - elif not request.tool_choice or request.tool_choice == "none": - message = ChatMessage(role=role, content=output.text) - - choice_data = ChatCompletionResponseChoice( - index=output.index, - message=message, - logprobs=logprobs, - finish_reason=output.finish_reason, - stop_reason=output.stop_reason) - choices.append(choice_data) - - if request.echo: - last_msg_content = "" - if conversation and conversation[-1].get( - "content") and conversation[-1].get("role") == role: - last_msg_content = conversation[-1]["content"] - - for choice in choices: - full_message = last_msg_content + choice.message.content - choice.message.content = full_message - - num_prompt_tokens = len(final_res.prompt_token_ids) - num_generated_tokens = sum( - len(output.token_ids) for output in final_res.outputs) - usage = UsageInfo( - prompt_tokens=num_prompt_tokens, - completion_tokens=num_generated_tokens, - total_tokens=num_prompt_tokens + num_generated_tokens, - ) - response = ChatCompletionResponse( - id=request_id, - created=created_time, - model=model_name, - choices=choices, - usage=usage, - ) - - return response + yield "data: [DONE]\n\n" \ No newline at end of file From 7828dc90933834d4514b15556349da9388106566 Mon Sep 17 00:00:00 2001 From: huseinzol05 Date: Tue, 2 Jul 2024 16:05:55 +0800 Subject: [PATCH 14/16] added streaming token --- vllm/engine/async_llm_engine.py | 17 ++- vllm/entrypoints/openai/api_server.py | 2 +- vllm/entrypoints/openai/serving_whisper.py | 165 ++++++++++++++++++--- vllm/worker/whisper_model_runner.py | 2 +- 4 files changed, 160 insertions(+), 26 deletions(-) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index df25eb111e87f..4e418e5964e58 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -278,9 +278,24 @@ async def process_model_inputs_async( else: prompt_token_ids = inputs["prompt_token_ids"] + if 'whisper_data' in inputs: + if self.whisper_config is None: + raise ValueError(f"Whisper config is None, must initialize a Whisper model.") + if self.whisper_processor is None: + raise ValueError(f"Whisper Processor is not initialized.") + whisper_data = self.whisper_processor( + inputs['whisper_data'], + sampling_rate = self.whisper_config.sample_rate, + return_tensors = 'pt', + ) + whisper_data = whisper_data.to(self.model_config.dtype).input_features[0] + else: + whisper_data = None + return LLMInputs(prompt_token_ids=prompt_token_ids, prompt=inputs.get("prompt"), - multi_modal_data=inputs.get("multi_modal_data")) + multi_modal_data=inputs.get("multi_modal_data"), + whisper_data=whisper_data) async def add_request_async( 
self, diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 60dc5d723e129..74d0226af76f7 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -150,7 +150,7 @@ async def audio_transcriptions( stream: bool = Form(False), raw_request: Request = None, ): - generator = await openai_serving_chat.create_audio_transcriptions( + generator = await openai_serving_whisper.create_audio_transcriptions( file=file, language=language, response_format=response_format, diff --git a/vllm/entrypoints/openai/serving_whisper.py b/vllm/entrypoints/openai/serving_whisper.py index 69dd84f9dc035..fc0c7c287ae8c 100644 --- a/vllm/entrypoints/openai/serving_whisper.py +++ b/vllm/entrypoints/openai/serving_whisper.py @@ -1,7 +1,10 @@ import codecs import time +import re +import json import torchaudio import numpy as np +from torchaudio.io import StreamReader from dataclasses import dataclass, field from typing import (AsyncGenerator, AsyncIterator, Awaitable, Dict, Iterable, List, Optional) @@ -48,6 +51,26 @@ pattern = r'<\|\-?\d+\.?\d*\|>' pattern_pair = r'<\|(\d+\.\d+)\|>(.*?)<\|(\d+\.\d+)\|>' +def format_timestamp( + seconds: float, always_include_hours: bool = False, decimal_marker: str = "." +): + assert seconds >= 0, "non-negative timestamp expected" + milliseconds = round(seconds * 1000.0) + + hours = milliseconds // 3_600_000 + milliseconds -= hours * 3_600_000 + + minutes = milliseconds // 60_000 + milliseconds -= minutes * 60_000 + + seconds = milliseconds // 1_000 + milliseconds -= seconds * 1_000 + + hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else "" + return ( + f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}" + ) + class OpenAIServingWhisper(OpenAIServing): @@ -63,7 +86,8 @@ def __init__(self, lora_modules=None) self._check_whisper_mode(model_config.whisper_mode) - self.max_size_whisper = max_size_whisper + self.model_config = model_config + self.max_size_whisper = max_size_whisper * 1024 * 1024 def _check_whisper_mode(self, whisper_mode: bool): if not whisper_mode: @@ -84,12 +108,19 @@ async def create_audio_transcriptions( ): if len(file) > self.max_size_whisper: - return self.create_error_response(f"maximum size for file is {self.max_size_whisper}MB only") + return self.create_error_response(f"maximum size for `file` is {self.max_size_whisper}MB only") + + if timestamp_granularities.lower().strip() != 'segment': + return self.create_error_response("currently `timestamp_granularities` only support `segment`") + + if response_format.lower() not in {'text', 'json', 'verbose_json', 'srt'}: + return self.create_error_response( + 'currently `response_format` only support `text`, `json`, `verbose_json` and `srt`') request_id = f"cmpl-{random_uuid()}" sampling_params = SamplingParams( - max_tokens = self.engine.model_config.max_model_len - 4, + max_tokens = self.model_config.max_model_len - 4, temperature = 0.0, skip_special_tokens = False, stop_token_ids = [50257], @@ -120,13 +151,27 @@ async def create_audio_transcriptions( log_tracing_disabled_warning() # Streaming response - if request.stream: + if stream: return self.audio_transcription_stream_generator( - request, sampling_params, stream_iterator, language, request_id, trace_headers) + sampling_params, + stream_iterator, + language, + response_format, + request_id, + trace_headers, + raw_request, + ) else: try: return await self.audio_transcription_full_generator( - request, sampling_params, stream_iterator, 
language, request_id, trace_headers) + sampling_params, + stream_iterator, + language, + response_format, + request_id, + trace_headers, + raw_request, + ) except ValueError as e: # TODO: Use a vllm-specific Validation Error return self.create_error_response(str(e)) @@ -134,11 +179,14 @@ async def create_audio_transcriptions( async def generate( self, sampling_params, - language, + language, wav_data, last_timestamp, + last_i, + response_format, request_id, trace_headers, + raw_request, ): if language is None: prompt_ids = [50258] @@ -153,7 +201,7 @@ async def generate( result_generator = self.engine.generate( inputs, lang_sampling_params, - request_id=request_id, + request_id=request_id + '-predict-lang', lora_request = None, trace_headers=trace_headers, ) @@ -166,6 +214,7 @@ async def generate( assert final_res is not None lang_token = final_res.outputs[0].token_ids[0] + language = self.tokenizer.decode([lang_token])[2:-2] else: lang_token = self.tokenizer.encode( lang_token = f'<|{language}|>', add_special_tokens = False)[0] @@ -177,48 +226,118 @@ async def generate( "whisper_data": wav_data } - text = processor.tokenizer.decode(prompt_ids, decode_with_timestamps = True) + result_generator = self.engine.generate( + inputs, + sampling_params, + request_id=request_id + f'-{last_i}', + lora_request = None, + trace_headers=trace_headers, + ) + + texts = f'<|{language}|><|{last_timestamp}|>' + if response_format != 'srt': + text = texts + if response_format == 'json': + text = json.dumps({'token': texts}) + + yield f"data: {text}\n\n" + + """ + [CompletionOutput(index=0, text=' and', token_ids=[293], cumulative_logprob=-1.7037980556488037, logprobs=None, finish_reason=None, stop_reason=None)] + """ async for res in result_generator: - print(res.outputs) - yield res.outputs + + token = self.tokenizer.convert_ids_to_tokens([res.outputs[0].token_ids[-1]]) + text = self.tokenizer.convert_tokens_to_string(token) + + for r in replaces: + text = text.replace(r, '') + matches = re.findall(pattern, text) + for match in matches: + timestamp = float(match.split('|')[1]) + timestamp += last_timestamp + timestamp = f'<|{timestamp}|>' + text = text.replace(match, timestamp) + if len(text): + texts += text + matches = re.findall(pattern_pair, texts) + if response_format == 'srt': + if len(matches): + match = matches[0] + if len(match[1]) > 2: + start = float(match[0]) + last_timestamp + end = float(match[-1]) + last_timestamp + text_between = match[1].strip() + ids = f"{last_i + 1}\n" + r = [ + ids, + f"{format_timestamp(start, always_include_hours=True, decimal_marker=',')} --> ", + f"{format_timestamp(end, always_include_hours=True, decimal_marker=',')}\n", + f"{text_between.replace('-->', '->')}\n"] + + combined = ''.join(r) + '\n' + last_i += 1 + yield f"data: {combined}\n\n" + + texts = text.split('|>')[-2] + '|>' + else: + if response_format == 'json': + text = json.dumps({'token': text}) + + yield f"data: {text}\n\n" async def audio_transcription_stream_generator( self, - request: ChatCompletionRequest, sampling_params, stream_iterator, language, + response_format, request_id: str, + trace_headers, + raw_request, ) -> AsyncGenerator[str, None]: - wav_data = np.array([], dtype=np.float32) - last_timestamp = 0 + last_i = 0 + last_timestamp = 0.0 try: for chunk in stream_iterator: + frame = chunk[0][:, 0].numpy() wav_data = np.concatenate([wav_data, frame]) audio_len = len(wav_data) / sample_rate if audio_len >= maxlen: - async for t in generate( + async for t in self.generate( 
                     sampling_params=sampling_params,
                     language=language,
                     wav_data=wav_data,
                     last_timestamp=last_timestamp,
+                    last_i=last_i,
+                    response_format=response_format,
+                    request_id=request_id,
+                    trace_headers=trace_headers,
+                    raw_request=raw_request,
                 ):
                     yield t
+                    last_i += 1

                 last_timestamp += audio_len
                 wav_data = np.array([], dtype=np.float32)

-            if len(wav_data):
-                async for t in generate(
-                    sampling_params=sampling_params,
-                    language=language,
-                    wav_data=wav_data,
-                    last_timestamp=last_timestamp,
-                ):
-                    yield t
+            if len(wav_data):
+                async for t in self.generate(
+                    sampling_params=sampling_params,
+                    language=language,
+                    wav_data=wav_data,
+                    last_timestamp=last_timestamp,
+                    last_i=last_i,
+                    response_format=response_format,
+                    request_id=request_id,
+                    trace_headers=trace_headers,
+                    raw_request=raw_request,
+                ):
+                    yield t
+                    last_i += 1

         except ValueError as e:
             # TODO: Use a vllm-specific Validation Error
diff --git a/vllm/worker/whisper_model_runner.py b/vllm/worker/whisper_model_runner.py
index ac4559e8d5a03..e148fabb181ac 100644
--- a/vllm/worker/whisper_model_runner.py
+++ b/vllm/worker/whisper_model_runner.py
@@ -130,7 +130,7 @@ def _prepare_encoder_model_input(

             whisper_data_list.append(whisper_data)

-        whisper_data = torch.cat(whisper_data_list, dim = 1).cuda()
+        whisper_data = whisper_data_list[0].cuda()

         return whisper_data


From fa81def0aab015cf183b662ea8cb2d89ab1be428 Mon Sep 17 00:00:00 2001
From: huseinzol05
Date: Tue, 2 Jul 2024 22:56:25 +0800
Subject: [PATCH 15/16] added non streaming

---
 vllm/entrypoints/openai/api_server.py      | 11 ++-
 vllm/entrypoints/openai/cli_args.py        |  4 +-
 vllm/entrypoints/openai/protocol.py        | 22 +++++
 vllm/entrypoints/openai/serving_whisper.py | 97 +++++++++++++++++-----
 4 files changed, 110 insertions(+), 24 deletions(-)

diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index 74d0226af76f7..b47d95b855d99 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -23,6 +23,8 @@
 from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               ChatCompletionResponse,
                                               CompletionRequest,
+                                              TranscriptionVerboseJsonResponse,
+                                              TranscriptionJsonResponse,
                                               EmbeddingRequest, ErrorResponse)
 from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
 from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
@@ -143,6 +145,7 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
 @app.post("/audio/transcriptions")
 async def audio_transcriptions(
     file: bytes = File(),
+    model: str = 'whisper',
     language: str = Form(None),
     response_format: str = Form('text'),
     timestamp_granularities: str = Form('segment'),
@@ -166,8 +169,10 @@ async def audio_transcriptions(
         return StreamingResponse(content=generator,
                                  media_type="text/event-stream")
     else:
-        assert isinstance(generator, ChatCompletionResponse)
-        return JSONResponse(content=generator.model_dump())
+        if isinstance(generator, str):
+            return generator
+        else:
+            return JSONResponse(content=generator.model_dump())

 if __name__ == "__main__":
     args = parse_args()
@@ -251,7 +256,7 @@ async def authentication(request: Request, call_next):
     openai_serving_embedding = OpenAIServingEmbedding(engine, model_config,
                                                       served_model_names)
     openai_serving_whisper = OpenAIServingWhisper(
-        engine, model_config, served_model_names, args.max_size_whisper)
+        engine, model_config, served_model_names, args.max_size_mb_whisper)
     app.root_path = args.root_path
     uvicorn.run(app,
                 host=args.host,
diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py
index 30d16a3d7f755..a134a0cbeae09
--- a/vllm/entrypoints/openai/cli_args.py
+++ b/vllm/entrypoints/openai/cli_args.py
@@ -71,9 +71,9 @@ def make_arg_parser():
                         help="The file path to the chat template, "
                         "or the template in single-line form "
                         "for the specified model")
-    parser.add_argument("--max-size-whisper",
+    parser.add_argument("--max-size-mb-whisper",
                         type=int,
-                        default=100,
+                        default=200,
                         help="max size of audio to transcribe using Whisper in term of MB.")
     parser.add_argument("--response-role",
                         type=nullable_str,
diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index b57d79859aec5..d1719998b8854 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -699,3 +699,25 @@ class BatchRequestOutput(OpenAIBaseModel):
     # For requests that failed with a non-HTTP error, this will contain more
     # information on the cause of the failure.
     error: Optional[Any]
+
+class Segment(BaseModel):
+    id: int
+    seek: int
+    start: float
+    end: float
+    text: str
+    tokens: list[int]
+    temperature: float
+    avg_logprob: float
+    compression_ratio: float
+    no_speech_prob: float
+
+class TranscriptionVerboseJsonResponse(BaseModel):
+    task: str = "transcribe"
+    language: str
+    duration: float
+    text: str
+    segments: list[Segment]
+
+class TranscriptionJsonResponse(BaseModel):
+    text: str
\ No newline at end of file
diff --git a/vllm/entrypoints/openai/serving_whisper.py b/vllm/entrypoints/openai/serving_whisper.py
index fc0c7c287ae8c..3f46816233c00 100644
--- a/vllm/entrypoints/openai/serving_whisper.py
+++ b/vllm/entrypoints/openai/serving_whisper.py
@@ -18,13 +18,7 @@
 from vllm.config import ModelConfig, VisionLanguageConfig
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.entrypoints.openai.protocol import (
-    ChatCompletionContentPartParam, ChatCompletionLogProb,
-    ChatCompletionLogProbs, ChatCompletionLogProbsContent,
-    ChatCompletionMessageParam, ChatCompletionNamedToolChoiceParam,
-    ChatCompletionRequest, ChatCompletionResponse,
-    ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
-    ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
-    FunctionCall, ToolCall, UsageInfo)
+    Segment, TranscriptionVerboseJsonResponse, TranscriptionJsonResponse)
 from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
                                                     OpenAIServing)
 from vllm.inputs import PromptInputs
@@ -78,7 +72,7 @@ def __init__(self,
                  engine: AsyncLLMEngine,
                  model_config: ModelConfig,
                  served_model_names: List[str],
-                 max_size_whisper: 100,
+                 max_size_mb_whisper: int = 200,
                  ):
         super().__init__(engine=engine,
                          model_config=model_config,
@@ -87,7 +81,7 @@ def __init__(self,

         self._check_whisper_mode(model_config.whisper_mode)
         self.model_config = model_config
-        self.max_size_whisper = max_size_whisper * 1024 * 1024
+        self.max_size_mb_whisper = max_size_mb_whisper * 1024 * 1024

     def _check_whisper_mode(self, whisper_mode: bool):
         if not whisper_mode:
@@ -107,8 +101,8 @@ async def create_audio_transcriptions(
         raw_request: Optional[Request] = None
     ):

-        if len(file) > self.max_size_whisper:
-            return self.create_error_response(f"maximum size for `file` is {self.max_size_whisper}MB only")
+        if len(file) > self.max_size_mb_whisper:
+            return self.create_error_response(f"maximum size for `file` is {self.max_size_mb_whisper // (1024 * 1024)}MB")

         if timestamp_granularities.lower().strip() != 'segment':
             return self.create_error_response("currently `timestamp_granularities` only support `segment`")
@@ -216,8 +210,7 @@ async def generate(
             lang_token = final_res.outputs[0].token_ids[0]
             language = self.tokenizer.decode([lang_token])[2:-2]
         else:
-            lang_token = self.tokenizer.encode(
-                lang_token = f'<|{language}|>', add_special_tokens = False)[0]
+            lang_token = self.tokenizer.encode(f'<|{language}|>', add_special_tokens = False)[0]

         prompt_ids = [50258, lang_token, 50360, 50365]
         inputs: PromptInputs = {
@@ -241,7 +234,7 @@ async def generate(

             if response_format == 'json':
                 text = json.dumps({'token': texts})
-                yield f"data: {text}\n\n"
+                yield text

             """
             [CompletionOutput(index=0, text=' and', token_ids=[293], cumulative_logprob=-1.7037980556488037, logprobs=None, finish_reason=None, stop_reason=None)]
@@ -278,14 +271,14 @@ async def generate(

                     combined = ''.join(r) + '\n'
                     last_i += 1
-                    yield f"data: {combined}\n\n"
+                    yield combined

                     texts = text.split('|>')[-2] + '|>'

             else:
                 if response_format == 'json':
                     text = json.dumps({'token': text})
-                yield f"data: {text}\n\n"
+                yield text


     async def audio_transcription_stream_generator(
@@ -318,7 +311,7 @@ async def audio_transcription_stream_generator(
                     trace_headers=trace_headers,
                     raw_request=raw_request,
                 ):
-                    yield t
+                    yield f"data: {t}\n\n"
                     last_i += 1

                 last_timestamp += audio_len
@@ -336,7 +329,7 @@ async def audio_transcription_stream_generator(
                     trace_headers=trace_headers,
                     raw_request=raw_request,
                 ):
-                    yield t
+                    yield f"data: {t}\n\n"
                     last_i += 1

         except ValueError as e:
@@ -344,4 +337,70 @@ async def audio_transcription_stream_generator(
             data = self.create_streaming_error_response(str(e))
             yield f"data: {data}\n\n"
         # Send the final done message after all response.n are finished
-        yield "data: [DONE]\n\n"
\ No newline at end of file
+        yield "data: [DONE]\n\n"
+
+    async def audio_transcription_full_generator(
+        self,
+        sampling_params,
+        stream_iterator,
+        language,
+        response_format,
+        request_id: str,
+        trace_headers,
+        raw_request,
+    ):
+        tokens = []
+        async for data in self.audio_transcription_stream_generator(
+            sampling_params=sampling_params,
+            stream_iterator=stream_iterator,
+            language=language,
+            response_format='json',
+            request_id=request_id,
+            trace_headers=trace_headers,
+            raw_request=raw_request,
+        ):
+            if isinstance(data, str):
+                if '[DONE]' in data:
+                    break
+                data = json.loads(data.split('data:')[1].strip())
+                tokens.append(data['token'])
+
+        tokens = ''.join(tokens)
+        lang = tokens.split('|')[1]
+        matches = re.findall(pattern_pair, tokens)
+        print(matches)
+        segments = []
+        all_texts = []
+        for no, (start, substring, end) in enumerate(matches):
+            start_timestamp = float(start)
+            end_timestamp = float(end)
+            segment = Segment(
+                id=no,
+                seek=0,
+                start=start_timestamp,
+                end=end_timestamp,
+                text=substring.strip(),
+                tokens=self.tokenizer.encode(substring.strip(), add_special_tokens=False),
+                temperature=0.0,
+                avg_logprob=0.0,
+                compression_ratio=1.0,
+                no_speech_prob=0.0,
+            )
+            segments.append(segment)
+            all_texts.append(substring)
+
+        all_texts = ''.join(all_texts).strip()
+        if response_format == 'verbose_json':
+            return TranscriptionVerboseJsonResponse(
+                task='transcribe',
+                language=lang,
+                duration=segments[-1].end,
+                text=all_texts,
+                segments=segments
+            )
+        elif response_format == 'json':
+            return TranscriptionJsonResponse(
+                text=all_texts
+            )
+        else:
+            return all_texts
\ No newline at end of file

From ebf1cbfd77a42ef5772b1fbfa78c998620cc7e9e Mon Sep 17 00:00:00 2001
From: huseinzol05
Date: Sun, 28 Jul 2024 22:23:11 +0800
Subject: [PATCH 16/16] fix whisper example

---
 examples/whisper_example.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/examples/whisper_example.py b/examples/whisper_example.py
index 21db65fda1d51..d325870b20d51 100644
--- a/examples/whisper_example.py
+++ b/examples/whisper_example.py
@@ -2,7 +2,6 @@
 import torch
 import requests
 from vllm import LLM
-from vllm.multimodal.audio import AudioData
 from datasets import Audio


@@ -22,12 +21,12 @@ def main():

     output_lang = llm.generate({
         "prompt_token_ids": [50258],
-        "multi_modal_data": AudioData(y),
+        "whisper_data": y,
     }, sampling_params = SamplingParams(max_tokens = 1, temperature = 0))

     outputs = llm.generate({
         "prompt_token_ids": [50258, output_lang[0].outputs[0].token_ids[0], 50360],
-        "multi_modal_data": AudioData(y),
+        "whisper_data": y,
     }, sampling_params = SamplingParams(max_tokens = 100, temperature = 0))

     # ' without going to any such extreme as this we can easily see on reflection how vast an influence on the'
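The updated offline example feeds a raw waveform to the engine through the "whisper_data" field. One plausible way to produce that array, not shown in the hunk above, is to download an audio file and decode it to the 16 kHz mono waveform Whisper expects; the URL and the use of datasets.Audio.decode_example below are illustrative assumptions, not part of the patch.

# Sketch of producing the `y` waveform used by examples/whisper_example.py.
# The URL is a placeholder and the decode path is an assumption; any 16 kHz
# mono float array works as "whisper_data".
import requests
from datasets import Audio

r = requests.get("https://example.com/sample.mp3")  # placeholder audio file
audio = Audio(sampling_rate=16000)
y = audio.decode_example({"path": None, "bytes": r.content})["array"]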
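On the server side, the /audio/transcriptions route added in api_server.py accepts multipart form data and, for non-streaming requests, returns either plain text or one of the new JSON response models. A minimal client sketch, assuming a server listening on localhost:8000 with the default root path (host, port, and file name are placeholders):

# Minimal client for the /audio/transcriptions route; localhost:8000 and
# sample.mp3 are placeholders. `file`, `language`, `response_format` and
# `timestamp_granularities` are the form fields declared in api_server.py.
import requests

with open("sample.mp3", "rb") as f:
    resp = requests.post(
        "http://localhost:8000/audio/transcriptions",
        files={"file": f},
        data={
            "language": "en",                    # optional, auto-detected if omitted
            "response_format": "verbose_json",   # 'text', 'json' or 'verbose_json'
            "timestamp_granularities": "segment",
        },
    )

# With response_format='verbose_json' the body follows
# TranscriptionVerboseJsonResponse from protocol.py.
print(resp.json()["text"])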
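Internally, audio_transcription_full_generator rebuilds segments by pairing Whisper timestamp tokens in the accumulated text with the module-level pattern_pair regex, which is defined elsewhere in serving_whisper.py and not shown in this diff. A standalone sketch of that step, using an assumed pattern with the same (start, text, end) capture groups the loop above unpacks:

# Standalone sketch of the segment reconstruction step. `pattern_pair` below is
# an assumption that mirrors the (start, substring, end) triples unpacked in
# audio_transcription_full_generator; the real pattern lives elsewhere in
# serving_whisper.py.
import re

pattern_pair = r"<\|([\d.]+)\|>(.*?)<\|([\d.]+)\|>"

decoded = "<|en|><|0.00|> Hello there.<|2.40|><|2.40|> How are you?<|4.80|>"
print(decoded.split('|')[1])  # 'en', as in `lang = tokens.split('|')[1]`
for no, (start, text, end) in enumerate(re.findall(pattern_pair, decoded)):
    print(no, float(start), float(end), text.strip())
# 0 0.0 2.4 Hello there.
# 1 2.4 4.8 How are you?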