Skip to content

Commit

Permalink
Fix fastpitch test nightly (#6742)
Browse files Browse the repository at this point in the history
Signed-off-by: hsiehjackson <c2hsieh@ucsd.edu>
  • Loading branch information
hsiehjackson authored May 26, 2023
1 parent 2e2df4a commit 4df8f33
Showing 1 changed file with 23 additions and 1 deletion.
24 changes: 23 additions & 1 deletion tests/collections/tts/models/test_fastpitch.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@
This file implemented unit tests for loading all pretrained FastPitch NGC checkpoints and generating Mel-spectrograms.
The test duration breakdowns are shown below. In general, each test for a single model is ~25 seconds on an NVIDIA RTX A6000.
"""
import random

import pytest
import torch

from nemo.collections.tts.models import FastPitchModel

Expand All @@ -38,4 +41,23 @@ def test_inference(pretrained_model, language_specific_text_example):
model, language_id = pretrained_model
text = language_specific_text_example[language_id]
parsed_text = model.parse(text)
_ = model.generate_spectrogram(tokens=parsed_text)

# Multi-Speaker
speaker_id = None
reference_spec = None
reference_spec_lens = None

if hasattr(model.fastpitch, 'speaker_emb'):
speaker_id = 0

if hasattr(model.fastpitch, 'speaker_encoder'):
if hasattr(model.fastpitch.speaker_encoder, 'lookup_module'):
speaker_id = 0
if hasattr(model.fastpitch.speaker_encoder, 'gst_module'):
bs, lens, t_spec = parsed_text.shape[0], random.randint(50, 100), model.cfg.n_mel_channels
reference_spec = torch.rand(bs, lens, t_spec)
reference_spec_lens = torch.tensor([lens]).long().expand(bs)

_ = model.generate_spectrogram(
tokens=parsed_text, speaker=speaker_id, reference_spec=reference_spec, reference_spec_lens=reference_spec_lens
)

0 comments on commit 4df8f33

Please sign in to comment.