Skip to content

Commit

Permalink
Fixup overflow (coqui-ai#2218)
Browse files Browse the repository at this point in the history
* Update overflow config

* Pulling shuffle and drop_last from config

* Print training stats for overflow
  • Loading branch information
erogol authored and shivammehta25 committed Dec 23, 2022
1 parent aedd795 commit 253b03f
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 6 deletions.
5 changes: 3 additions & 2 deletions TTS/tts/configs/overflow_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,8 +169,9 @@ class OverflowConfig(BaseTTSConfig): # The classname has to be camel case
lr_scheduler: str = None

# overrides
min_seq_len: int = 3
max_seq_len: int = 500
min_text_len: int = 10
max_text_len: int = 500
min_audio_len: int = 512

# testing
test_sentences: List[str] = field(
Expand Down
9 changes: 9 additions & 0 deletions TTS/tts/configs/shared_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,13 @@ class BaseTTSConfig(BaseTrainingConfig):
If True, the data loader will start loading the longest batch first. It is useful for checking OOM issues.
Defaults to False.
shuffle (bool):
If True, the data loader will shuffle the dataset when there is no sampler defined. Defaults to True.
drop_last (bool):
If True, the data loader will drop the last batch if it is not complete. It helps to prevent
issues that emerge from the partial batch statistics. Defaults to True.
add_blank (bool):
Add blank characters between every two characters. It improves performance for some models at the expense
of slower run-time due to the longer input sequence.
Expand Down Expand Up @@ -309,6 +316,8 @@ class BaseTTSConfig(BaseTrainingConfig):
precompute_num_workers: int = 0
use_noise_augment: bool = False
start_by_longest: bool = False
shuffle: bool = False
drop_last: bool = False
# dataset
datasets: List[BaseDatasetConfig] = field(default_factory=lambda: [BaseDatasetConfig()])
# optimizer
Expand Down
4 changes: 2 additions & 2 deletions TTS/tts/models/base_tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,9 +345,9 @@ def get_data_loader(
loader = DataLoader(
dataset,
batch_size=config.eval_batch_size if is_eval else config.batch_size,
shuffle=True, # if there is no other sampler
shuffle=config.shuffle if sampler is not None else False, # if there is no other sampler
collate_fn=dataset.collate_fn,
drop_last=False, # setting this False might cause issues in AMP training.
drop_last=config.drop_last, # setting this False might cause issues in AMP training.
sampler=sampler,
num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
pin_memory=False,
Expand Down
14 changes: 12 additions & 2 deletions TTS/tts/models/overflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,15 @@ def forward(self, text, text_len, mels, mel_len):

return outputs

@staticmethod
def _training_stats(batch):
stats = {}
stats["avg_text_length"] = batch["text_lengths"].float().mean()
stats["avg_spec_length"] = batch["mel_lengths"].float().mean()
stats["avg_text_batch_occupancy"] = (batch["text_lengths"].float() / batch["text_lengths"].float().max()).mean()
stats["avg_spec_batch_occupancy"] = (batch["mel_lengths"].float() / batch["mel_lengths"].float().max()).mean()
return stats

def train_step(self, batch: dict, criterion: nn.Module):
text_input = batch["text_input"]
text_lengths = batch["text_lengths"]
Expand All @@ -171,9 +180,10 @@ def train_step(self, batch: dict, criterion: nn.Module):
mels=mel_input,
mel_len=mel_lengths,
)
loss_dict = criterion(outputs["log_probs"] / (mel_lengths.sum() + text_lengths.sum()))

loss_dict = criterion(outputs["log_probs"])

# for printing useful statistics on terminal
loss_dict.update(self._training_stats(batch))
return outputs, loss_dict

def eval_step(self, batch: Dict, criterion: nn.Module):
Expand Down

0 comments on commit 253b03f

Please sign in to comment.