optimize seamless-m4t/vits model for text-to-speech generation #825

Merged: 12 commits, Jun 6, 2024
1 change: 1 addition & 0 deletions Makefile
@@ -44,6 +44,7 @@ fast_tests_diffusers:
# Run single-card non-regression tests
slow_tests_1x: test_installs
python -m pytest tests/test_examples.py -v -s -k "single_card"
python -m pytest tests/test_pipeline.py

# Run multi-card non-regression tests
slow_tests_8x: test_installs
3 changes: 3 additions & 0 deletions examples/text-to-speech/README.md
@@ -31,7 +31,10 @@ pip install -r requirements.txt
python3 run_pipeline.py \
--model_name_or_path microsoft/speecht5_tts \
--text "Hello, my dog is cooler than you!" \
--use_hpu_graphs \
--bf16
```
Models that have been validated:
- [microsoft/speecht5_tts](https://huggingface.co/microsoft/speecht5_tts)
- [facebook/hf-seamless-m4t-medium](https://huggingface.co/facebook/hf-seamless-m4t-medium)
- [facebook/mms-tts-eng](https://huggingface.co/facebook/mms-tts-eng)
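
For reference, here is a minimal Python sketch of the pipeline call that `run_pipeline.py` performs for the seamless-m4t checkpoint; the model name mirrors the list above, and the dtype, device, and `tgt_lang` values follow the script's seamless-m4t path:

```python
# Minimal sketch, assuming a Gaudi host with optimum-habana, transformers,
# torch, and soundfile installed; mirrors the seamless-m4t path of run_pipeline.py.
import soundfile as sf
import torch
from transformers import pipeline

from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi

adapt_transformers_to_gaudi()  # swap in the Gaudi-optimized forwards patched below

generator = pipeline(
    "text-to-speech",
    model="facebook/hf-seamless-m4t-medium",
    torch_dtype=torch.bfloat16,
    device="hpu",
)
# seamless-m4t needs a target language in forward_params
speech = generator("Hello, my dog is cooler than you!", forward_params={"tgt_lang": "eng"})
sf.write("speech.wav", speech["audio"].squeeze(), samplerate=speech["sampling_rate"])
```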
37 changes: 29 additions & 8 deletions examples/text-to-speech/run_pipeline.py
@@ -58,11 +58,18 @@ def main():
parser.add_argument("--batch_size", type=int, default=1, help="Input batch size.")
parser.add_argument("--warmup", type=int, default=3, help="Number of warmup iterations for benchmarking.")
parser.add_argument("--n_iterations", type=int, default=5, help="Number of inference iterations for benchmarking.")
parser.add_argument("--seed", type=int, default=555, help="make speech generation deterministic")
parser.add_argument(
"--use_hpu_graphs",
action="store_true",
help="Whether to use HPU graphs or not. Using HPU graphs should give better latencies.",
)
args = parser.parse_args()

adapt_transformers_to_gaudi()
text = args.text
text_bs = len(text)
set_seed(args.seed)

if args.batch_size > text_bs:
# Dynamically extends to support larger batch sizes
@@ -84,32 +91,46 @@ def main():
device="hpu",
)

embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
speaker_embedding = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0).to("hpu")
if args.use_hpu_graphs:
from habana_frameworks.torch.hpu import wrap_in_hpu_graph

generator.model = wrap_in_hpu_graph(generator.model)

forward_params = None
if generator.model.config.model_type == "speecht5":
embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
speaker_embedding = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0).to("hpu")
forward_params = {"speaker_embeddings": speaker_embedding}
if generator.model.config.model_type == "seamless_m4t":
forward_params = {"tgt_lang": "eng"}

generate_kwargs = None
if generator.model.can_generate():
generate_kwargs = {"lazy_mode": True, "ignore_eos": False, "hpu_graphs": args.use_hpu_graphs}

with torch.autocast("hpu", torch.bfloat16, enabled=args.bf16), torch.no_grad(), torch.inference_mode():
with torch.autocast("hpu", torch.bfloat16, enabled=args.bf16), torch.inference_mode():
# warm up
for i in range(args.warmup):
if generator.model.config.model_type == "speecht5":
# SpeechT5 forces a dropout with training=True, which may zero out some elements randomly.
# A random dropout may need different lengths of spectrograms to fit probability thresholds,
# which violates the HPU static shape, so we have to fix the seed here.
set_seed(555)
generator(text, batch_size=args.batch_size, forward_params={"speaker_embeddings": speaker_embedding})
set_seed(args.seed)
generator(text, batch_size=args.batch_size, forward_params=forward_params, generate_kwargs=generate_kwargs)

start = time.time()
for i in range(args.n_iterations):
if generator.model.config.model_type == "speecht5":
# SpeechT5 forces a dropout with training=True, which may zero out some elements randomly.
# A random dropout may need different lengths of spectrograms to fit probability thresholds,
# which violates the HPU static shape, so we have to fix the seed here.
set_seed(555)
set_seed(args.seed)
speech = generator(
text, batch_size=args.batch_size, forward_params={"speaker_embeddings": speaker_embedding}
text, batch_size=args.batch_size, forward_params=forward_params, generate_kwargs=generate_kwargs
)
end = time.time()
logger.info(f"speech = {speech} time = {(end-start) * 1000 / args.n_iterations }ms")
sf.write("speech.wav", speech[0]["audio"], samplerate=speech[0]["sampling_rate"])
sf.write("speech.wav", speech[0]["audio"].squeeze(), samplerate=speech[0]["sampling_rate"])


if __name__ == "__main__":
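
The reseeding inside both loops is what keeps SpeechT5 static-shaped on HPU. A self-contained sketch of the underlying effect, using only torch and transformers:

```python
# Dropout applied with training=True is only reproducible under a fixed seed.
# SpeechT5 keeps dropout active at inference, so without reseeding each run the
# stop condition can trigger at a different step, the spectrogram length changes,
# and the HPU graph would have to recompile for the new shape.
import torch
from transformers import set_seed

x = torch.ones(8)
set_seed(555)
a = torch.nn.functional.dropout(x, p=0.5, training=True)
set_seed(555)
b = torch.nn.functional.dropout(x, p=0.5, training=True)
assert torch.equal(a, b)  # same seed -> same dropout mask -> same output length
```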
42 changes: 40 additions & 2 deletions optimum/habana/transformers/generation/utils.py
@@ -90,6 +90,7 @@
"mixtral",
"gemma",
"blip_text_model",
"seamless_m4t",
"starcoder2",
"persimmon",
"qwen2",
@@ -172,6 +173,7 @@ def _prepare_decoder_input_ids_for_generation(
bos_token_id: int = None,
device: torch.device = None,
max_new_tokens: int = None,
pad_token_id: int = None,
) -> Tuple[torch.LongTensor, Dict[str, torch.Tensor]]:
"""Prepares `decoder_input_ids` for generation with encoder-decoder models"""
# 1. Check whether the user has defined `decoder_input_ids` manually. To facilitate in terms of input naming,
@@ -206,7 +208,10 @@ def _prepare_decoder_input_ids_for_generation(
# creating padded decoder_input_ids to achieve static shapes. New tokens, once generated, are later copied into decoder_input_ids based on token_idx
max_length = max_new_tokens + 1 if max_new_tokens is not None else self.generation_config.max_length
decoder_input_ids_start = (
torch.ones((batch_size, max_length), dtype=torch.long, device=device) * decoder_start_token_id
torch.ones((batch_size, 1), dtype=torch.long, device=device) * decoder_start_token_id
)
decoder_input_ids_start = torch.nn.functional.pad(
decoder_input_ids_start, (0, max_length - 1), value=pad_token_id
)

# no user input -> use decoder_start_token_id as decoder_input_ids
@@ -226,14 +231,46 @@
isinstance(decoder_start_token_id, torch.Tensor)
and (decoder_input_ids[:, 0] != decoder_start_token_id[:, 0]).all().item()
):
decoder_input_ids = torch.cat([decoder_input_ids_start, decoder_input_ids], dim=-1)
if token_idx is None:
decoder_input_ids = torch.cat([decoder_input_ids_start, decoder_input_ids], dim=-1)
else:
max_length = max_new_tokens + 2 if max_new_tokens is not None else self.generation_config.max_length
if max_length != decoder_input_ids_start.shape[-1]:
decoder_input_ids_start = torch.nn.functional.pad(
decoder_input_ids_start,
(0, max_length - decoder_input_ids_start.shape[-1]),
value=pad_token_id,
)
decoder_input_ids = decoder_input_ids_start.index_copy(1, token_idx, decoder_input_ids)
token_idx.add_(1)
if "decoder_attention_mask" in model_kwargs:
decoder_attention_mask = model_kwargs["decoder_attention_mask"]
decoder_attention_mask = torch.cat(
(torch.ones_like(decoder_attention_mask)[:, :1], decoder_attention_mask),
dim=-1,
)
model_kwargs["decoder_attention_mask"] = decoder_attention_mask
else:
if token_idx is not None:
decoder_input_ids_len = decoder_input_ids.shape[-1]
max_length = (
max_new_tokens + decoder_input_ids_len
if max_new_tokens is not None
else self.generation_config.max_length
)
decoder_input_ids = torch.nn.functional.pad(
decoder_input_ids, (0, max_length - decoder_input_ids_len), value=pad_token_id
)
token_idx.copy_(decoder_input_ids_len)
if "decoder_attention_mask" in model_kwargs:
decoder_attention_mask = model_kwargs["decoder_attention_mask"]
pad_len = max_length - decoder_attention_mask.shape[-1]
decoder_attention_mask = torch.cat(
(torch.ones_like(decoder_attention_mask)[:, :pad_len], decoder_attention_mask),
dim=-1,
)
model_kwargs["decoder_attention_mask"] = decoder_attention_mask

return decoder_input_ids, model_kwargs
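
The `token_idx` branches above preallocate `decoder_input_ids` at its final length and write tokens in by index instead of concatenating, so the tensor shape never changes during decoding. A toy illustration of the pattern (illustrative values, not real generation state):

```python
# Static-shape decoding sketch: pre-pad to max_length, index_copy user tokens in,
# and let later steps overwrite pad slots rather than growing the tensor.
import torch

pad_id, start_id, max_length = 0, 2, 8
ids = torch.full((1, max_length), pad_id, dtype=torch.long)
ids[:, 0] = start_id                      # decoder_start_token_id slot
user_tokens = torch.tensor([[5, 6]])
positions = torch.tensor([1, 2])          # positions tracked via token_idx
ids = ids.index_copy(1, positions, user_tokens)
print(ids)  # tensor([[2, 5, 6, 0, 0, 0, 0, 0]]); shape stays (1, 8) throughout
```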

@staticmethod
@@ -866,6 +903,7 @@ def generate(
bos_token_id=generation_config.bos_token_id,
device=inputs_tensor.device,
max_new_tokens=generation_config.max_new_tokens,
pad_token_id=generation_config.pad_token_id,
)
else:
input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")
52 changes: 52 additions & 0 deletions optimum/habana/transformers/modeling_utils.py
@@ -146,6 +146,16 @@
gaudi_qwen2_model_forward,
gaudi_rot_matmul,
gaudi_rot_vec_mul,
gaudi_SeamlessM4TAttention_forward,
gaudi_SeamlessM4TCodeHifiGan_get_output_hifigan_lengths,
gaudi_SeamlessM4TDecoder_forward,
gaudi_SeamlessM4TDecoderLayer_forward,
gaudi_SeamlessM4TForTextToSpeech_forward,
gaudi_SeamlessM4TForTextToSpeech_generate,
gaudi_SeamlessM4TForTextToSpeech_prepare_inputs_for_generation,
gaudi_SeamlessM4TTextToUnitForConditionalGeneration_forward,
gaudi_SeamlessM4TTextToUnitForConditionalGeneration_prepare_inputs_for_generation,
gaudi_SeamlessM4TTextToUnitModel_forward,
gaudi_SpeechT5Attention_forward,
gaudi_SpeechT5Decoder_forward,
gaudi_SpeechT5DecoderLayer_forward,
@@ -161,8 +171,10 @@
gaudi_T5ForConditionalGeneration_prepare_inputs_for_generation,
gaudi_T5LayerSelfAttention_forward,
gaudi_T5Stack_forward,
gaudi_unconstrained_rational_quadratic_spline,
gaudi_VisionEncoderDecoderModel_prepare_inputs_for_generation,
gaudi_vit_self_attention_forward,
gaudi_VitsResidualCouplingLayer_forward,
gaudi_wav2vec2_encoder_forward,
gaudi_wav2vec2_forward,
gaudi_wav2vec2_tdnnlayer_forward,
@@ -431,6 +443,46 @@ def adapt_transformers_to_gaudi():
gaudi_persimmon_decoder_layer_forward
)

# Optimization for seamless m4t on Gaudi
transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TAttention.forward = (
gaudi_SeamlessM4TAttention_forward
)
transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TDecoderLayer.forward = (
gaudi_SeamlessM4TDecoderLayer_forward
)
transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TDecoder.forward = (
gaudi_SeamlessM4TDecoder_forward
)
transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TTextToUnitModel.forward = (
gaudi_SeamlessM4TTextToUnitModel_forward
)
transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TTextToUnitForConditionalGeneration.forward = (
gaudi_SeamlessM4TTextToUnitForConditionalGeneration_forward
)

transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TTextToUnitForConditionalGeneration.prepare_inputs_for_generation = gaudi_SeamlessM4TTextToUnitForConditionalGeneration_prepare_inputs_for_generation

transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TCodeHifiGan._get_output_hifigan_lengths = (
gaudi_SeamlessM4TCodeHifiGan_get_output_hifigan_lengths
)

transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TForTextToSpeech.forward = (
gaudi_SeamlessM4TForTextToSpeech_forward
)

transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TForTextToSpeech.generate = (
gaudi_SeamlessM4TForTextToSpeech_generate
)

transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TForTextToSpeech.prepare_inputs_for_generation = (
gaudi_SeamlessM4TForTextToSpeech_prepare_inputs_for_generation
)

transformers.models.vits.modeling_vits._unconstrained_rational_quadratic_spline = (
gaudi_unconstrained_rational_quadratic_spline
)
transformers.models.vits.modeling_vits.VitsResidualCouplingLayer.forward = gaudi_VitsResidualCouplingLayer_forward

# Optimization for starcoder2 on Gaudi
transformers.models.starcoder2.modeling_starcoder2.Starcoder2ForCausalLM = GaudiStarcoder2ForCausalLM
transformers.models.starcoder2.modeling_starcoder2.Starcoder2Model.forward = gaudi_starcoder2_model_forward
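
Because `adapt_transformers_to_gaudi()` patches `transformers` in place, the seamless-m4t and VITS overrides above can be verified directly; a quick sketch, assuming optimum-habana is installed:

```python
# Sketch: confirm the monkey patches applied by adapt_transformers_to_gaudi().
import transformers

from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi

adapt_transformers_to_gaudi()
patched = transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TAttention.forward
print(patched.__name__)  # expected: "gaudi_SeamlessM4TAttention_forward"
```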
16 changes: 16 additions & 0 deletions optimum/habana/transformers/models/__init__.py
@@ -146,6 +146,18 @@
gaudi_qwen2_attention_forward,
gaudi_qwen2_model_forward,
)
from .seamless_m4t import (
gaudi_SeamlessM4TAttention_forward,
gaudi_SeamlessM4TCodeHifiGan_get_output_hifigan_lengths,
gaudi_SeamlessM4TDecoder_forward,
gaudi_SeamlessM4TDecoderLayer_forward,
gaudi_SeamlessM4TForTextToSpeech_forward,
gaudi_SeamlessM4TForTextToSpeech_generate,
gaudi_SeamlessM4TForTextToSpeech_prepare_inputs_for_generation,
gaudi_SeamlessM4TTextToUnitForConditionalGeneration_forward,
gaudi_SeamlessM4TTextToUnitForConditionalGeneration_prepare_inputs_for_generation,
gaudi_SeamlessM4TTextToUnitModel_forward,
)
from .speecht5 import (
gaudi_generate_speech,
gaudi_SpeechT5Attention_forward,
@@ -178,6 +190,10 @@
gaudi_VisionEncoderDecoderModel_prepare_inputs_for_generation,
)
from .vit import gaudi_vit_self_attention_forward
from .vits import (
gaudi_unconstrained_rational_quadratic_spline,
gaudi_VitsResidualCouplingLayer_forward,
)
from .wav2vec2 import (
_gaudi_wav2vec2_compute_mask_indices,
_gaudi_wav2vec2_mask_hidden_states,
12 changes: 12 additions & 0 deletions optimum/habana/transformers/models/seamless_m4t/__init__.py
@@ -0,0 +1,12 @@
from .modeling_seamless_m4t import (
gaudi_SeamlessM4TAttention_forward,
gaudi_SeamlessM4TCodeHifiGan_get_output_hifigan_lengths,
gaudi_SeamlessM4TDecoder_forward,
gaudi_SeamlessM4TDecoderLayer_forward,
gaudi_SeamlessM4TForTextToSpeech_forward,
gaudi_SeamlessM4TForTextToSpeech_generate,
gaudi_SeamlessM4TForTextToSpeech_prepare_inputs_for_generation,
gaudi_SeamlessM4TTextToUnitForConditionalGeneration_forward,
gaudi_SeamlessM4TTextToUnitForConditionalGeneration_prepare_inputs_for_generation,
gaudi_SeamlessM4TTextToUnitModel_forward,
)