Commit 0a511c4

optimize seamless-m4t model for text-to-speech generation

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
sywangyi committed Mar 21, 2024
1 parent bb43f6c commit 0a511c4
Showing 5 changed files with 973 additions and 2 deletions.
42 changes: 40 additions & 2 deletions optimum/habana/transformers/generation/utils.py
@@ -77,6 +77,7 @@
     "phi",
     "mixtral",
     "blip_text_model",
+    "seamless_m4t",
 ]

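The hunk above registers `seamless_m4t` in the list of model types that take the static-shape generation path. The list's name sits outside this hunk, so the constant below is illustrative, based on how optimum-habana gates this elsewhere; a minimal sketch of the gate:

    # Illustrative sketch; the real constant name is not visible in this hunk.
    MODELS_OPTIMIZED_WITH_STATIC_SHAPES = ["phi", "mixtral", "blip_text_model", "seamless_m4t"]

    def uses_static_shapes(model_type: str) -> bool:
        # Models in this list get fixed-size, pre-padded decoder buffers.
        return model_type in MODELS_OPTIMIZED_WITH_STATIC_SHAPES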

@@ -182,6 +183,7 @@ def _prepare_decoder_input_ids_for_generation(
         bos_token_id: int = None,
         device: torch.device = None,
         max_new_tokens: int = None,
+        pad_token_id: int = None,
     ) -> Tuple[torch.LongTensor, Dict[str, torch.Tensor]]:
         """Prepares `decoder_input_ids` for generation with encoder-decoder models"""
         # 1. Check whether the user has defined `decoder_input_ids` manually. To facilitate in terms of input naming,
@@ -216,7 +218,10 @@ def _prepare_decoder_input_ids_for_generation(
             # creating padded decoder_input_ids to achieve static shapes. New tokens, once generated, are copied into decoder_input_ids based on token_idx
             max_length = max_new_tokens + 1 if max_new_tokens is not None else self.generation_config.max_length
             decoder_input_ids_start = (
-                torch.ones((batch_size, max_length), dtype=torch.long, device=device) * decoder_start_token_id
+                torch.ones((batch_size, 1), dtype=torch.long, device=device) * decoder_start_token_id
             )
+            decoder_input_ids_start = torch.nn.functional.pad(
+                decoder_input_ids_start, (0, max_length - 1), value=pad_token_id
+            )

         # no user input -> use decoder_start_token_id as decoder_input_ids
@@ -236,14 +241,46 @@
             isinstance(decoder_start_token_id, torch.Tensor)
             and (decoder_input_ids[:, 0] != decoder_start_token_id[:, 0]).all().item()
         ):
-            decoder_input_ids = torch.cat([decoder_input_ids_start, decoder_input_ids], dim=-1)
+            if token_idx is None:
+                decoder_input_ids = torch.cat([decoder_input_ids_start, decoder_input_ids], dim=-1)
+            else:
+                max_length = max_new_tokens + 2 if max_new_tokens is not None else self.generation_config.max_length
+                if max_length != decoder_input_ids_start.shape[-1]:
+                    decoder_input_ids_start = torch.nn.functional.pad(
+                        decoder_input_ids_start,
+                        (0, max_length - decoder_input_ids_start.shape[-1]),
+                        value=pad_token_id,
+                    )
+                decoder_input_ids = decoder_input_ids_start.index_copy(1, token_idx, decoder_input_ids)
+                token_idx.add_(1)
             if "decoder_attention_mask" in model_kwargs:
                 decoder_attention_mask = model_kwargs["decoder_attention_mask"]
                 decoder_attention_mask = torch.cat(
                     (torch.ones_like(decoder_attention_mask)[:, :1], decoder_attention_mask),
                     dim=-1,
                 )
                 model_kwargs["decoder_attention_mask"] = decoder_attention_mask
+        else:
+            if token_idx is not None:
+                decoder_input_ids_len = decoder_input_ids.shape[-1]
+                max_length = (
+                    max_new_tokens + decoder_input_ids_len
+                    if max_new_tokens is not None
+                    else self.generation_config.max_length
+                )
+                decoder_input_ids = torch.nn.functional.pad(
+                    decoder_input_ids, (0, max_length - decoder_input_ids_len), value=pad_token_id
+                )
+                token_idx.copy_(decoder_input_ids_len)
+                if "decoder_attention_mask" in model_kwargs:
+                    decoder_attention_mask = model_kwargs["decoder_attention_mask"]
+                    pad_len = max_length - decoder_attention_mask.shape[-1]
+                    decoder_attention_mask = torch.cat(
+                        (torch.ones_like(decoder_attention_mask)[:, :pad_len], decoder_attention_mask),
+                        dim=-1,
+                    )
+                    model_kwargs["decoder_attention_mask"] = decoder_attention_mask

         return decoder_input_ids, model_kwargs

     def _update_model_kwargs_for_generation(
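To see why the branches above keep shapes static, here is a minimal, self-contained sketch (toy values, not the library's actual call path) of the pad-then-index_copy pattern the new code uses: the decoder buffer is allocated once at max_length, pre-filled with pad_token_id, and each new token is written in place at token_idx, so tensor shapes never change across decoding steps.

    import torch

    batch_size, max_length = 2, 8
    decoder_start_token_id, pad_token_id = 3, 0

    # Column 0 holds the start token; the rest of the buffer is padding.
    decoder_input_ids = torch.ones((batch_size, 1), dtype=torch.long) * decoder_start_token_id
    decoder_input_ids = torch.nn.functional.pad(
        decoder_input_ids, (0, max_length - 1), value=pad_token_id
    )

    # token_idx tracks the next slot to fill; index_copy_ writes the newly
    # generated token into that column without reallocating the buffer.
    token_idx = torch.tensor([1])
    new_tokens = torch.tensor([[5], [7]])  # one freshly sampled token per row
    decoder_input_ids.index_copy_(1, token_idx, new_tokens)
    token_idx.add_(1)

Because the buffer never grows, the graph compiled for one decoding step can be reused for every step instead of recompiling for each new sequence length.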
@@ -656,6 +693,7 @@ def generate(
                 bos_token_id=generation_config.bos_token_id,
                 device=inputs_tensor.device,
                 max_new_tokens=generation_config.max_new_tokens,
+                pad_token_id=generation_config.pad_token_id,
             )
         else:
             input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")
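Since the static buffers are pre-filled with pad_token_id, generation needs a defined pad token; threading it through from generation_config (as the hunk above does) keeps the buffer fill value well defined. A hedged usage sketch on the caller side, using the standard transformers API with illustrative names:

    # If the checkpoint defines no pad token, fall back to EOS before generating.
    if model.generation_config.pad_token_id is None:
        model.generation_config.pad_token_id = model.generation_config.eos_token_id
    output = model.generate(**inputs, max_new_tokens=64)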
45 changes: 45 additions & 0 deletions optimum/habana/transformers/modeling_utils.py
@@ -118,6 +118,16 @@
     gaudi_phi_model_forward,
     gaudi_rot_matmul,
     gaudi_rot_vec_mul,
+    gaudi_SeamlessM4TAttention_forward,
+    gaudi_SeamlessM4TCodeHifiGan_get_output_hifigan_lengths,
+    gaudi_SeamlessM4TDecoder_forward,
+    gaudi_SeamlessM4TDecoderLayer_forward,
+    gaudi_SeamlessM4TForTextToSpeech_forward,
+    gaudi_SeamlessM4TForTextToSpeech_generate,
+    gaudi_SeamlessM4TForTextToSpeech_prepare_inputs_for_generation,
+    gaudi_SeamlessM4TTextToUnitForConditionalGeneration_forward,
+    gaudi_SeamlessM4TTextToUnitForConditionalGeneration_prepare_inputs_for_generation,
+    gaudi_SeamlessM4TTextToUnitModel_forward,
     gaudi_SpeechT5Attention_forward,
     gaudi_SpeechT5Decoder_forward,
     gaudi_SpeechT5DecoderLayer_forward,
@@ -359,3 +369,38 @@ def adapt_transformers_to_gaudi():
     transformers.models.speecht5.modeling_speecht5.SpeechT5SpeechDecoderPrenet.forward = (
         gaudi_SpeechT5SpeechDecoderPrenet_forward
     )
+    transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TAttention.forward = (
+        gaudi_SeamlessM4TAttention_forward
+    )
+    transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TDecoderLayer.forward = (
+        gaudi_SeamlessM4TDecoderLayer_forward
+    )
+    transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TDecoder.forward = (
+        gaudi_SeamlessM4TDecoder_forward
+    )
+    transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TTextToUnitModel.forward = (
+        gaudi_SeamlessM4TTextToUnitModel_forward
+    )
+    transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TTextToUnitForConditionalGeneration.forward = (
+        gaudi_SeamlessM4TTextToUnitForConditionalGeneration_forward
+    )
+
+    transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TTextToUnitForConditionalGeneration.prepare_inputs_for_generation = gaudi_SeamlessM4TTextToUnitForConditionalGeneration_prepare_inputs_for_generation
+
+    transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TCodeHifiGan._get_output_hifigan_lengths = (
+        gaudi_SeamlessM4TCodeHifiGan_get_output_hifigan_lengths
+    )
+
+    transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TForTextToSpeech.forward = (
+        gaudi_SeamlessM4TForTextToSpeech_forward
+    )
+
+    transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TForTextToSpeech.generate = (
+        gaudi_SeamlessM4TForTextToSpeech_generate
+    )
+
+    transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TForTextToSpeech.prepare_inputs_for_generation = (
+        gaudi_SeamlessM4TForTextToSpeech_prepare_inputs_for_generation
+    )
+
+    transformers.models.seamless
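These assignments follow optimum-habana's monkey-patching convention: calling adapt_transformers_to_gaudi() swaps the stock transformers methods for the Gaudi-aware versions at runtime, so downstream code keeps using the ordinary transformers classes. A minimal sketch of activating the patches:

    from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi

    # After this call, the SeamlessM4T forward/generate paths resolve to the
    # gaudi_* implementations assigned above.
    adapt_transformers_to_gaudi()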
12 changes: 12 additions & 0 deletions optimum/habana/transformers/models/__init__.py
@@ -120,6 +120,18 @@
     gaudi_phi_decoder_layer_forward,
     gaudi_phi_model_forward,
 )
+from .seamless_m4t import (
+    gaudi_SeamlessM4TAttention_forward,
+    gaudi_SeamlessM4TCodeHifiGan_get_output_hifigan_lengths,
+    gaudi_SeamlessM4TDecoder_forward,
+    gaudi_SeamlessM4TDecoderLayer_forward,
+    gaudi_SeamlessM4TForTextToSpeech_forward,
+    gaudi_SeamlessM4TForTextToSpeech_generate,
+    gaudi_SeamlessM4TForTextToSpeech_prepare_inputs_for_generation,
+    gaudi_SeamlessM4TTextToUnitForConditionalGeneration_forward,
+    gaudi_SeamlessM4TTextToUnitForConditionalGeneration_prepare_inputs_for_generation,
+    gaudi_SeamlessM4TTextToUnitModel_forward,
+)
 from .speecht5 import (
     gaudi_generate_speech,
     gaudi_SpeechT5Attention_forward,
12 changes: 12 additions & 0 deletions optimum/habana/transformers/models/seamless_m4t/__init__.py
@@ -0,0 +1,12 @@
+from .modeling_seamless_m4t import (
+    gaudi_SeamlessM4TAttention_forward,
+    gaudi_SeamlessM4TCodeHifiGan_get_output_hifigan_lengths,
+    gaudi_SeamlessM4TDecoder_forward,
+    gaudi_SeamlessM4TDecoderLayer_forward,
+    gaudi_SeamlessM4TForTextToSpeech_forward,
+    gaudi_SeamlessM4TForTextToSpeech_generate,
+    gaudi_SeamlessM4TForTextToSpeech_prepare_inputs_for_generation,
+    gaudi_SeamlessM4TTextToUnitForConditionalGeneration_forward,
+    gaudi_SeamlessM4TTextToUnitForConditionalGeneration_prepare_inputs_for_generation,
+    gaudi_SeamlessM4TTextToUnitModel_forward,
+)
864 changes: 864 additions & 0 deletions optimum/habana/transformers/models/seamless_m4t/modeling_seamless_m4t.py
(diff not loaded on the page)

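Putting the pieces together, a hedged end-to-end sketch of text-to-speech generation with the patched model (checkpoint id and language code follow the public SeamlessM4T release; the HPU device setup is assumed and is not part of this diff):

    import torch
    from transformers import AutoProcessor, SeamlessM4TForTextToSpeech
    from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi

    adapt_transformers_to_gaudi()  # install the Gaudi-optimized methods first

    processor = AutoProcessor.from_pretrained("facebook/hf-seamless-m4t-medium")
    model = SeamlessM4TForTextToSpeech.from_pretrained("facebook/hf-seamless-m4t-medium")
    model = model.to("hpu")  # assumes habana_frameworks.torch is installed

    inputs = processor(text="Hello, my dog is cute.", src_lang="eng", return_tensors="pt").to("hpu")
    # SeamlessM4TForTextToSpeech.generate returns the synthesized waveform
    # first; the exact return type varies with the transformers version.
    waveform = model.generate(**inputs, tgt_lang="eng")[0].cpu()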