Skip to content

Commit

Permalink
formatting and a small bug fix in Tacotron model
Browse files Browse the repository at this point in the history
  • Loading branch information
erogol committed Apr 15, 2021
1 parent 1ad838b commit 9cc17be
Show file tree
Hide file tree
Showing 7 changed files with 17 additions and 6 deletions.
4 changes: 3 additions & 1 deletion TTS/bin/train_vocoder_gan.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,9 @@ def evaluate(model_G, criterion_G, model_D, criterion_D, ap, global_step, epoch)
# Sample audio
predict_waveform = y_hat[0].squeeze(0).detach().cpu().numpy()
real_waveform = y_G[0].squeeze(0).cpu().numpy()
tb_logger.tb_eval_audios(global_step, {"eval/audio": predict_waveform, "eval/real_waveformo": real_waveform}, c.audio["sample_rate"])
tb_logger.tb_eval_audios(
global_step, {"eval/audio": predict_waveform, "eval/real_waveformo": real_waveform}, c.audio["sample_rate"]
)

tb_logger.tb_eval_stats(global_step, keep_avg.avg_values)

Expand Down
10 changes: 9 additions & 1 deletion TTS/tts/layers/tacotron/common_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,15 @@ class Prenet(nn.Module):
"""

# pylint: disable=dangerous-default-value
def __init__(self, in_features, prenet_type="original", prenet_dropout=True, dropout_at_inference=False, out_features=[256, 256], bias=True):
def __init__(
self,
in_features,
prenet_type="original",
prenet_dropout=True,
dropout_at_inference=False,
out_features=[256, 256],
bias=True,
):
super().__init__()
self.prenet_type = prenet_type
self.prenet_dropout = prenet_dropout
Expand Down
1 change: 0 additions & 1 deletion TTS/tts/layers/tacotron/tacotron.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,6 @@ def __init__(
# processed_inputs, processed_memory -> |Attention| -> Attention, attention, RNN_State
# attention_rnn generates queries for the attention mechanism
self.attention_rnn = nn.GRUCell(in_channels + 128, self.query_dim)

self.attention = init_attn(
attn_type=attn_type,
query_dim=self.query_dim,
Expand Down
1 change: 1 addition & 0 deletions TTS/tts/models/tacotron.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ def __init__(
attn_norm,
prenet_type,
prenet_dropout,
prenet_dropout_at_inference,
forward_attn,
trans_agent,
forward_attn_mask,
Expand Down
4 changes: 2 additions & 2 deletions TTS/tts/utils/generic_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
attn_norm=c.attention_norm,
prenet_type=c.prenet_type,
prenet_dropout=c.prenet_dropout,
prenet_dropout_at_inference=c.prenet_dropout_at_inference if 'prenet_dropout_at_inference' in c else False,
prenet_dropout_at_inference=c.prenet_dropout_at_inference if "prenet_dropout_at_inference" in c else False,
forward_attn=c.use_forward_attn,
trans_agent=c.transition_agent,
forward_attn_mask=c.forward_attn_mask,
Expand Down Expand Up @@ -97,7 +97,7 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
attn_norm=c.attention_norm,
prenet_type=c.prenet_type,
prenet_dropout=c.prenet_dropout,
prenet_dropout_at_inference=c.prenet_dropout_at_inference if 'prenet_dropout_at_inference' in c else False,
prenet_dropout_at_inference=c.prenet_dropout_at_inference if "prenet_dropout_at_inference" in c else False,
forward_attn=c.use_forward_attn,
trans_agent=c.transition_agent,
forward_attn_mask=c.forward_attn_mask,
Expand Down
2 changes: 1 addition & 1 deletion TTS/tts/utils/text/symbols.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def make_symbols(
characters, phonemes=None, punctuations="!'(),-.:;? ", pad="_", eos="~", bos="^"
): # pylint: disable=redefined-outer-name
""" Function to create symbols and phonemes """
_symbols = list(characters)
_symbols = list(characters)
_symbols = [bos] + _symbols if len(bos) > 0 and bos is not None else _symbols
_symbols = [eos] + _symbols if len(bos) > 0 and eos is not None else _symbols
_symbols = [pad] + _symbols if len(bos) > 0 and pad is not None else _symbols
Expand Down
1 change: 1 addition & 0 deletions TTS/vocoder/tf/models/melgan_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from TTS.vocoder.tf.layers.melgan import ReflectionPad1d, ResidualStack


# pylint: disable=too-many-ancestors
# pylint: disable=abstract-method
class MelganGenerator(tf.keras.models.Model):
Expand Down

0 comments on commit 9cc17be

Please sign in to comment.