Skip to content

Added customio for seq2seq models and updated input names #375

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Apr 29, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 24 additions & 6 deletions QEfficient/transformers/models/modeling_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -1811,6 +1811,23 @@ def compile(
if num_speculative_tokens:
logger.warning("Speculative decoding is not yet enabled for AutoModelForSpeechSeq2Seq")

output_names = self.model.get_output_names()

kv_cache_dtype = "float16"
custom_io = {}

custom_io["input_features"] = kv_cache_dtype

# Slice output_names to get input names
for output_name in output_names:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comment says inputs, but the line below is iterating output_names. We need to add input_features to custom_io.

Otherwise, input_features are still float32, as in the generation code input_features are explicitly converted to float32:
line 1905 inputs["input_features"] = inputs["input_features"].numpy().astype(np.float32) and line 1939 inputs["input_features"] = np.zeros((self.batch_size, self.model.config.num_mel_bins, 1)).astype(np.float32)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kdulla Could you add input_features to custom_io and see if dtype conversion can be removed?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we are adding inputs to custom_io in this loop, we get the names of inputs by slicing "_RetainedState" off the end of output_names, this is just the most straightforward to get the input_names that will be compatible with future Seq2Seq models added.

input_features are in float32 because that is the output type of WhisperProcessor, the generate is written in a way that the processor inputs are taken directly as input, so we expect float32.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

input_features are in float32 because that is the output type of WhisperProcessor, the generate is written in a way that the processor inputs are taken directly as input, so we expect float32.

However, for vision models, even though they use AutoProcessor which outputs pixel_values in float32, pixel_values is still part of the custom_io and is set to float16. : https://github.com/quic/efficient-transformers/pull/336/files Additionally, inside the generate function vision_inputs["pixel_values"] = vision_inputs["pixel_values"].astype("float16")

So I don't see any reason that seq2seq models can't do the same. It would be great if the design choices were more consistent, so that in vLLM we don't have to convert multimodal inputs to different data types for different models.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Have updated seq2seq models to match vision models and use float16

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. I have verified the PR in vLLM and it looks good to me.

if output_name.endswith("_RetainedState"):
custom_io[output_name[: -len("_RetainedState")]] = kv_cache_dtype

# Get output names
for output_name in output_names:
if output_name.endswith("_RetainedState"):
custom_io[output_name] = kv_cache_dtype

return self._compile(
onnx_path,
compile_dir,
Expand All @@ -1821,6 +1838,7 @@ def compile(
mxfp6_matmul=mxfp6_matmul,
mdp_ts_num_devices=num_devices,
aic_num_cores=num_cores,
custom_io=custom_io,
**compiler_options,
)

Expand Down Expand Up @@ -1852,14 +1870,14 @@ def generate(
self.qpc_session = QAICInferenceSession(str(self.qpc_path), device_ids)
self.batch_size = self.qpc_session.bindings[0].dims[0]

inputs["input_features"] = inputs["input_features"].numpy().astype(np.float32)
inputs["input_features"] = inputs["input_features"].numpy().astype(np.float16)

# add start token id and initial position ids to inputs
seq_len = 1
inputs["decoder_input_ids"] = (
inputs["input_ids"] = (
torch.ones((self.batch_size, seq_len), dtype=torch.int64) * self.model.config.decoder_start_token_id
).numpy()
inputs["decoder_position_ids"] = (
inputs["position_ids"] = (
torch.arange(seq_len, dtype=torch.int64).view(1, seq_len).repeat(self.batch_size, 1).numpy()
)

Expand All @@ -1886,7 +1904,7 @@ def generate(
if streamer:
streamer.put(next_token)

inputs["input_features"] = np.zeros((self.batch_size, self.model.config.num_mel_bins, 1)).astype(np.float32)
inputs["input_features"] = np.zeros((self.batch_size, self.model.config.num_mel_bins, 1)).astype(np.float16)

loop_start = perf_counter()
for num_tokens in range(generation_len):
Expand All @@ -1898,8 +1916,8 @@ def generate(
if next_token[0][0] == self.model.config.eos_token_id:
break

inputs["decoder_input_ids"] = next_token
inputs["decoder_position_ids"] += 1
inputs["input_ids"] = next_token
inputs["position_ids"] += 1

if streamer:
streamer.put(next_token)
Expand Down
79 changes: 73 additions & 6 deletions QEfficient/transformers/models/whisper/modeling_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,15 @@
#
# ----------------------------------------------------------------------------

from typing import Optional, Tuple
from typing import Optional, Tuple, Union

import torch
from torch import nn
from transformers.cache_utils import Cache, StaticCache
from transformers.cache_utils import Cache, EncoderDecoderCache, StaticCache
from transformers.modeling_outputs import (
BaseModelOutputWithCrossAttentions,
BaseModelOutputWithPastAndCrossAttentions,
Seq2SeqLMOutput,
Seq2SeqModelOutput,
)
from transformers.models.whisper.modeling_whisper import (
Expand Down Expand Up @@ -700,8 +701,74 @@ class QEffWhisperForConditionalGeneration(WhisperForConditionalGeneration):

The only differences are:
- Added get_dummy_inputs, get_onnx_dynamic_axes, get_output_names for AutoModel export
- changed forward inputs decoder_input_ids and decoder_position_ids to input_ids and position_ids
"""

def forward(
self,
input_features: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.Tensor] = None,
decoder_head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
past_key_values: Optional[Union[EncoderDecoderCache, Tuple[torch.FloatTensor]]] = None,
decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None,
position_ids: Optional[Tuple[torch.LongTensor]] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

outputs = self.model(
input_features,
attention_mask=attention_mask,
decoder_input_ids=input_ids,
encoder_outputs=encoder_outputs,
decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values,
decoder_inputs_embeds=decoder_inputs_embeds,
decoder_position_ids=position_ids,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
)
lm_logits = self.proj_out(outputs[0])

loss = None
if labels is not None:
loss_fct = torch.nn.CrossEntropyLoss()
# move labels to correct device to enable PP
labels = labels.to(lm_logits.device)
loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.reshape(-1))

if not return_dict:
output = (lm_logits,) + outputs[1:]
return ((loss,) + output) if loss is not None else output

return Seq2SeqLMOutput(
loss=loss,
logits=lm_logits,
past_key_values=outputs.past_key_values,
decoder_hidden_states=outputs.decoder_hidden_states,
decoder_attentions=outputs.decoder_attentions,
cross_attentions=outputs.cross_attentions,
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
encoder_hidden_states=outputs.encoder_hidden_states,
encoder_attentions=outputs.encoder_attentions,
)

def get_dummy_inputs(
self,
):
Expand All @@ -715,8 +782,8 @@ def get_dummy_inputs(

inputs = {
"input_features": torch.zeros((bs, encoder_feature_count, 1), dtype=torch.float32),
"decoder_input_ids": torch.zeros((bs, seq_len), dtype=torch.int64),
"decoder_position_ids": torch.arange(seq_len, dtype=torch.int64).view(1, seq_len).repeat(bs, 1),
"input_ids": torch.zeros((bs, seq_len), dtype=torch.int64),
"position_ids": torch.arange(seq_len, dtype=torch.int64).view(1, seq_len).repeat(bs, 1),
"past_key_values": [[] for _ in range(num_layers)],
}

Expand Down Expand Up @@ -769,8 +836,8 @@ def get_onnx_dynamic_axes(

dynamic_axes = {
"input_features": {0: "batch_size", 2: "feature_len"},
"decoder_input_ids": {0: "batch_size", 1: "seq_len"},
"decoder_position_ids": {0: "batch_size", 1: "seq_len"},
"input_ids": {0: "batch_size", 1: "seq_len"},
"position_ids": {0: "batch_size", 1: "seq_len"},
}
pkv_self_dynamic_axes = {
0: "batch_size",
Expand Down
24 changes: 10 additions & 14 deletions tests/transformers/models/test_speech_seq2seq_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,8 +143,8 @@ def run_seq2seq_pytorch_with_kv(

model_inputs = dict(
input_features=input_features,
decoder_input_ids=decoder_input_ids,
decoder_position_ids=decoder_position_ids,
input_ids=decoder_input_ids,
position_ids=decoder_position_ids,
past_key_values=[[] for _ in range(config.num_hidden_layers)],
)

Expand All @@ -169,9 +169,7 @@ def run_seq2seq_pytorch_with_kv(
next_token = logits.argmax(-1)
generated_ids[:, 1] = next_token.squeeze(1)

model_inputs["input_features"] = torch.tensor(
np.random.randn(batch_size, config.num_mel_bins, 1).astype(np.float32)
)
model_inputs["input_features"] = torch.tensor(np.zeros((batch_size, config.num_mel_bins, 1)).astype(np.float32))
model_inputs["past_key_values"] = outputs["past_key_values"]

for num_tokens in range(generation_len):
Expand All @@ -183,8 +181,8 @@ def run_seq2seq_pytorch_with_kv(
if next_token[0][0] == processor.tokenizer.eos_token_id:
break

model_inputs["decoder_input_ids"] = next_token
model_inputs["decoder_position_ids"] += 1
model_inputs["input_ids"] = next_token
model_inputs["position_ids"] += 1
model_inputs["past_key_values"] = outputs["past_key_values"]

return generated_ids[0]
Expand Down Expand Up @@ -234,8 +232,8 @@ def run_seq2seq_ort(

model_inputs = dict(
input_features=input_features,
decoder_input_ids=decoder_input_ids,
decoder_position_ids=decoder_position_ids,
input_ids=decoder_input_ids,
position_ids=decoder_position_ids,
)

# prepare dummy past kvs and cross kvs
Expand Down Expand Up @@ -263,9 +261,7 @@ def run_seq2seq_ort(
next_token = logits.argmax(-1)
generated_ids[:, 1] = next_token.squeeze(1)

model_inputs["input_features"] = torch.tensor(
np.random.randn(batch_size, config.num_mel_bins, 1).astype(np.float32)
)
model_inputs["input_features"] = torch.tensor(np.zeros((batch_size, config.num_mel_bins, 1)).astype(np.float32))
for i, name in enumerate(pkv_names):
model_inputs[name.split("_RetainedState")[0]] = outputs[1 + i]

Expand All @@ -280,8 +276,8 @@ def run_seq2seq_ort(
if next_token[0][0] == processor.tokenizer.eos_token_id:
break

model_inputs["decoder_input_ids"] = next_token
model_inputs["decoder_position_ids"] += 1
model_inputs["input_ids"] = next_token
model_inputs["position_ids"] += 1
for i, name in enumerate(pkv_names):
model_inputs[name.split("_RetainedState")[0]] = outputs[1 + i]

Expand Down
Loading