Fixup overflow #2218

Merged: 6 commits, Dec 14, 2022

5 changes: 3 additions & 2 deletions TTS/tts/configs/overflow_config.py
@@ -169,8 +169,9 @@ class OverflowConfig(BaseTTSConfig):  # The classname has to be camel case
     lr_scheduler: str = None
 
     # overrides
-    min_seq_len: int = 3
-    max_seq_len: int = 500
+    min_text_len: int = 10
+    max_text_len: int = 500
+    min_audio_len: int = 512
 
     # testing
     test_sentences: List[str] = field(
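
The override block above replaces Overflow's old min_seq_len/max_seq_len limits with the base config's text and audio length filters. A minimal sketch of setting them (field names taken from this diff; the values and the snippet itself are illustrative, not part of the PR):

    # Sketch only: assumes a local install where OverflowConfig already has the
    # renamed length-filter fields from this change.
    from TTS.tts.configs.overflow_config import OverflowConfig

    config = OverflowConfig(
        min_text_len=10,    # drop samples whose transcript is shorter than 10 characters
        max_text_len=500,   # drop samples whose transcript is longer than 500 characters
        min_audio_len=512,  # drop clips shorter than 512 audio samples
    )
    print(config.min_text_len, config.max_text_len, config.min_audio_len)
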
9 changes: 9 additions & 0 deletions TTS/tts/configs/shared_configs.py
@@ -230,6 +230,13 @@ class BaseTTSConfig(BaseTrainingConfig):
             If True, the data loader will start loading the longest batch first. It is useful for checking OOM issues.
             Defaults to False.
 
+        shuffle (bool):
+            If True, the data loader will shuffle the dataset when no sampler is defined. Defaults to True.
+
+        drop_last (bool):
+            If True, the data loader will drop the last batch if it is not complete. It helps prevent
+            issues that arise from partial-batch statistics. Defaults to True.
+
         add_blank (bool):
             Add blank characters between every two characters. It improves performance for some models at the expense
             of slower run-time due to the longer input sequence.
@@ -309,6 +316,8 @@ class BaseTTSConfig(BaseTrainingConfig):
     precompute_num_workers: int = 0
     use_noise_augment: bool = False
     start_by_longest: bool = False
+    shuffle: bool = False
+    drop_last: bool = False
     # dataset
     datasets: List[BaseDatasetConfig] = field(default_factory=lambda: [BaseDatasetConfig()])
     # optimizer
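
The new shuffle and drop_last fields are forwarded to the data loader (see the base_tts.py change below). For context on what drop_last controls, here is a plain-PyTorch illustration, not TTS code: with drop_last=True the final batch is discarded when it is smaller than batch_size, which avoids skewed per-batch statistics.

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    data = TensorDataset(torch.arange(10).float())  # 10 samples, batch_size 4 -> batches of 4, 4, 2
    kept = list(DataLoader(data, batch_size=4, drop_last=False))
    dropped = list(DataLoader(data, batch_size=4, drop_last=True))
    print(len(kept), len(dropped))  # 3 2 -- the trailing partial batch of 2 is dropped
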
4 changes: 2 additions & 2 deletions TTS/tts/models/base_tts.py
@@ -345,9 +345,9 @@ def get_data_loader(
             loader = DataLoader(
                 dataset,
                 batch_size=config.eval_batch_size if is_eval else config.batch_size,
-                shuffle=True,  # if there is no other sampler
+                shuffle=config.shuffle if sampler is None else False,  # if there is no other sampler
                 collate_fn=dataset.collate_fn,
-                drop_last=False,  # setting this False might cause issues in AMP training.
+                drop_last=config.drop_last,  # setting this False might cause issues in AMP training.
                 sampler=sampler,
                 num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
                 pin_memory=False,
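
The guard around shuffle matters because PyTorch's DataLoader rejects shuffle=True combined with an explicit sampler, so shuffling can only be honored when no sampler is supplied. A standalone sketch of the same pattern (plain PyTorch, not TTS code; a literal True stands in for config.shuffle):

    import torch
    from torch.utils.data import DataLoader, TensorDataset, SequentialSampler

    data = TensorDataset(torch.arange(8).float())
    sampler = SequentialSampler(data)  # stands in for a weighted or distributed sampler

    # same guard as in get_data_loader(): only shuffle when no sampler is supplied
    loader = DataLoader(data, batch_size=4, shuffle=True if sampler is None else False, sampler=sampler)

    try:
        DataLoader(data, batch_size=4, shuffle=True, sampler=sampler)
    except ValueError as err:
        print(err)  # sampler option is mutually exclusive with shuffle
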
14 changes: 12 additions & 2 deletions TTS/tts/models/overflow.py
@@ -159,6 +159,15 @@ def forward(self, text, text_len, mels, mel_len):
 
         return outputs
 
+    @staticmethod
+    def _training_stats(batch):
+        stats = {}
+        stats["avg_text_length"] = batch["text_lengths"].float().mean()
+        stats["avg_spec_length"] = batch["mel_lengths"].float().mean()
+        stats["avg_text_batch_occupancy"] = (batch["text_lengths"].float() / batch["text_lengths"].float().max()).mean()
+        stats["avg_spec_batch_occupancy"] = (batch["mel_lengths"].float() / batch["mel_lengths"].float().max()).mean()
+        return stats
+
     def train_step(self, batch: dict, criterion: nn.Module):
         text_input = batch["text_input"]
         text_lengths = batch["text_lengths"]
@@ -171,9 +180,10 @@ def train_step(self, batch: dict, criterion: nn.Module):
             mels=mel_input,
             mel_len=mel_lengths,
         )
-        loss_dict = criterion(outputs["log_probs"] / (mel_lengths.sum() + text_lengths.sum()))
+        loss_dict = criterion(outputs["log_probs"])
 
+        # for printing useful statistics on terminal
+        loss_dict.update(self._training_stats(batch))
         return outputs, loss_dict
 
     def eval_step(self, batch: Dict, criterion: nn.Module):
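
To make the new batch statistics concrete, here is a standalone recomputation of what _training_stats reports on a toy batch (tensor keys follow the diff above; the length values are made up):

    import torch

    batch = {
        "text_lengths": torch.tensor([10.0, 20.0, 40.0]),
        "mel_lengths": torch.tensor([100.0, 200.0, 400.0]),
    }

    stats = {
        "avg_text_length": batch["text_lengths"].mean(),
        "avg_spec_length": batch["mel_lengths"].mean(),
        # occupancy: average length of each example relative to the longest one in the batch,
        # i.e. how much of the padded batch tensor holds real (non-padding) characters/frames
        "avg_text_batch_occupancy": (batch["text_lengths"] / batch["text_lengths"].max()).mean(),
        "avg_spec_batch_occupancy": (batch["mel_lengths"] / batch["mel_lengths"].max()).mean(),
    }
    print({k: round(v.item(), 3) for k, v in stats.items()})
    # avg_text_length 23.333, avg_spec_length 233.333, both occupancies ~0.583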