Skip to content

Commit

Permalink
[TTS] Fix TTS export test
Browse files (browse the repository at this point in the history)
Signed-off-by: Ryan <rlangman@nvidia.com>
  • Loading branch information
rlangman committed May 15, 2023
1 parent c3009e8 commit 6ebfe8f
Showing 1 changed file with 18 additions and 20 deletions.
38 changes: 18 additions & 20 deletions nemo/collections/tts/models/fastpitch.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -101,6 +101,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
if not self.ds_class in [
"nemo.collections.tts.data.dataset.TTSDataset",
"nemo.collections.tts.data.text_to_speech_dataset.TextToSpeechDataset",
"nemo.collections.tts.torch.data.TTSDataset",
]:
raise ValueError(f"Unknown dataset class: {self.ds_class}.")

Expand Down Expand Up @@ -383,10 +384,10 @@ def training_step(self, batch, batch_idx):
None,
)
if self.learn_alignment:
if self.ds_class == "nemo.collections.tts.data.dataset.TTSDataset":
batch_dict = process_batch(batch, self._train_dl.dataset.sup_data_types_set)
else:
if self.ds_class == "nemo.collections.tts.data.text_to_speech_dataset.TextToSpeechDataset":
batch_dict = batch
else:
batch_dict = process_batch(batch, self._train_dl.dataset.sup_data_types_set)
audio = batch_dict.get("audio")
audio_lens = batch_dict.get("audio_lens")
text = batch_dict.get("text")
Expand Down Expand Up @@ -496,10 +497,10 @@ def validation_step(self, batch, batch_idx):
None,
)
if self.learn_alignment:
if self.ds_class == "nemo.collections.tts.data.dataset.TTSDataset":
batch_dict = process_batch(batch, self._train_dl.dataset.sup_data_types_set)
else:
if self.ds_class == "nemo.collections.tts.data.text_to_speech_dataset.TextToSpeechDataset":
batch_dict = batch
else:
batch_dict = process_batch(batch, self._train_dl.dataset.sup_data_types_set)
audio = batch_dict.get("audio")
audio_lens = batch_dict.get("audio_lens")
text = batch_dict.get("text")
Expand Down Expand Up @@ -624,20 +625,17 @@ def __setup_dataloader_from_config(self, cfg, shuffle_should_be: bool = True, na
elif cfg.dataloader_params.shuffle:
logging.error(f"The {name} dataloader for {self} has shuffle set to True!!!")

if self.ds_class == "nemo.collections.tts.data.dataset.TTSDataset":
phon_mode = contextlib.nullcontext()
if hasattr(self.vocab, "set_phone_prob"):
phon_mode = self.vocab.set_phone_prob(prob=None if name == "val" else self.vocab.phoneme_probability)

with phon_mode:
dataset = instantiate(
cfg.dataset,
text_normalizer=self.normalizer,
text_normalizer_call_kwargs=self.text_normalizer_call_kwargs,
text_tokenizer=self.vocab,
)
else:
dataset = instantiate(cfg.dataset)
phon_mode = contextlib.nullcontext()
if hasattr(self.vocab, "set_phone_prob"):
phon_mode = self.vocab.set_phone_prob(prob=None if name == "val" else self.vocab.phoneme_probability)

with phon_mode:
dataset = instantiate(
cfg.dataset,
text_normalizer=self.normalizer,
text_normalizer_call_kwargs=self.text_normalizer_call_kwargs,
text_tokenizer=self.vocab,
)

return torch.utils.data.DataLoader(dataset, collate_fn=dataset.collate_fn, **cfg.dataloader_params)

Expand Down

0 comments on commit 6ebfe8f

Please sign in to comment.