diff --git a/tacotron/hparams.py b/tacotron/hparams.py
index f0370297..0ff5a8ff 100644
--- a/tacotron/hparams.py
+++ b/tacotron/hparams.py
@@ -19,29 +19,35 @@
 	cmu_dict=False,
 
 	#Model
-	outputs_per_step = 1,
-	attention_dim = 128,
-	parameter_init = 0.5,
-	sharpening_factor = 1.0,
-	max_decode_length = None,
-	num_classes = None,
-	time_major = False,
-	hidden_dim = 128,
-	embedding_dim = 512,
-	num_decoder_layers=2,
+	outputs_per_step = 1, #number of frames to generate at each decoding step
+	embedding_dim = 512, #dimension of embedding space
+	enc_conv_num_layers=3, #number of encoder convolutional layers
+	enc_conv_kernel_size=(5, ), #size of encoder convolution filters for each layer
+	enc_conv_channels=512, #number of encoder convolution filters for each layer
+	encoder_lstm_units=256, #number of lstm units for each direction (forward and backward)
+	attention_dim = 128, #dimension of attention space
+	attention_stddev_init = 0.1, #initial standard deviation for the attention projection (normal initializer)
+	prenet_layers=[128, 128], #number of layers and number of units of prenet
+	decoder_layers=2, #number of decoder lstm layers
+	decoder_lstm_units=512, #number of decoder lstm units on each layer
+	postnet_num_layers=5, #number of postnet convolutional layers
+	postnet_kernel_size=(5, ), #size of postnet convolution filters for each layer
+	postnet_channels=512, #number of postnet convolution filters for each layer
 	max_iters=808, #Max decoder steps during inference (feel free to change it)
 
 	#Training
-	batch_size = 32,
-	reg_weight = 10e-6,
-	decay_learning_rate = True,
-	decay_steps = 50000,
-	decay_rate = 0.97,
-	initial_learning_rate = 10e-3,
-	final_learning_rate = 10e-5,
-	adam_beta1 = 0.9,
-	adam_beta2 = 0.999,
-	adam_epsilon = 10e-6,
+	batch_size = 16, #number of training samples on each training step
+	reg_weight = 10e-6, #regularization weight (for L2 regularization)
+	decay_learning_rate = True, #boolean, determines if the learning rate will follow an exponential decay
+	decay_steps = 50000, #starting point for learning rate decay (and determines the decay slope)
+	decay_rate = 0.97, #learning rate decay rate
+	initial_learning_rate = 10e-3, #starting learning rate
+	final_learning_rate = 10e-5, #minimal learning rate
+	adam_beta1 = 0.9, #AdamOptimizer beta1 parameter
+	adam_beta2 = 0.999, #AdamOptimizer beta2 parameter
+	adam_epsilon = 10e-6, #AdamOptimizer epsilon parameter
+	zoneout_rate=0.1, #zoneout rate for all LSTM cells in the network
+	dropout_rate=0.5, #dropout rate for all convolutional layers + prenet
 
 	#Eval sentences
 	sentences = [
@@ -52,8 +58,22 @@
 	'The Senate\'s bill to repeal and replace the Affordable Care Act is now imperiled.',
 	# From Google's Tacotron example page:
 	'Generative adversarial network or variational auto-encoder.',
-	'The buses aren\'t the problem, they actually provide a solution.',
-	'Does the quick brown fox jump over the lazy dog?',
+	'Basilar membrane and otolaryngology are not auto-correlations.',
+	'He has read the whole thing.',
+	'He reads books.',
+	"Don't desert me here in the desert!",
+	'He thought it was time to present the present.',
+	'Thisss isrealy awhsome.',
+	'Punctuation sensitivity, is working.',
+	'Punctuation sensitivity is working.',
+	"The buses aren't the problem, they actually provide a solution.",
+	"The buses aren't the PROBLEM, they actually provide a SOLUTION.",
+	"The quick brown fox jumps over the lazy dog.",
+	"Does the quick brown fox jump over the lazy dog?",
+	"Peter Piper picked a peck of pickled peppers. How many pickled peppers did Peter Piper pick?",
+	"She sells sea-shells on the sea-shore. The shells she sells are sea-shells I'm sure.",
+	"The blue lagoon is a nineteen eighty American romance adventure film.",
+	"Tajima Airport serves Toyooka.",
 	'Talib Kweli confirmed to AllHipHop that he will be releasing an album in the next year.',
 	]
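Two things worth flagging in the training block: `10e-3` is 1e-2 (and `10e-5` is 1e-4), so the learning-rate bounds are ten times larger than the 1e-3/1e-5 a reader might expect, and `decay_steps` serves as both the decay starting point and the decay horizon. A minimal sketch of how these hparams plausibly combine (the body of `_learning_rate_decay` is not shown in this diff, so treat the exact form as an assumption):

import tensorflow as tf

def learning_rate_decay_sketch(init_lr, global_step, hp):
    # Exponential decay assumed to kick in after hp.decay_steps iterations,
    # floored at hp.final_learning_rate.
    lr = tf.train.exponential_decay(init_lr,
        global_step - hp.decay_steps,  # shift so decay starts at decay_steps
        hp.decay_steps, hp.decay_rate, name='exponential_decay')
    return tf.maximum(lr, hp.final_learning_rate)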
This + tensor should be shaped `[batch_size, max_time, ...]`. + memory_sequence_length (optional): Sequence lengths for the batch entries + in memory. If provided, the memory tensor rows are masked with zeros + for values past the respective sequence lengths. + probability_fn: (optional) A `callable`. Converts the score to + probabilities. The default is @{tf.nn.softmax}. Other options include + @{tf.contrib.seq2seq.hardmax} and @{tf.contrib.sparsemax.sparsemax}. + Its signature should be: `probabilities = probability_fn(score)`. + score_mask_value: (optional): The mask value for score before passing into + `probability_fn`. The default is -inf. Only used if + `memory_sequence_length` is not None. + name: Name to use when creating ops. + """ + if probability_fn is None: + probability_fn = nn_ops.softmax + wrapped_probability_fn = lambda score, _: probability_fn(score) + super(LocationBasedAttention, self).__init__( + query_layer=layers_core.Dense( + num_units, name='query_layer', use_bias=False), + memory_layer=layers_core.Dense( + num_units, name='memory_layer', use_bias=False), + memory=memory, + probability_fn=wrapped_probability_fn, + memory_sequence_length=memory_sequence_length, + score_mask_value=score_mask_value, + name=name) + self._num_units = num_units + self._name = name + + def __call__(self, query, state): + """Score the query based on the keys and values. + Args: + query: Tensor of dtype matching `self.values` and shape + `[batch_size, query_depth]`. + previous_alignments: Tensor of dtype matching `self.values` and shape + `[batch_size, alignments_size]` + (`alignments_size` is memory's `max_time`). + Returns: + alignments: Tensor of dtype matching `self.values` and shape + `[batch_size, alignments_size]` (`alignments_size` is memory's + `max_time`). + """ + previous_alignments = state + with variable_scope.variable_scope(None, "location_based_attention", [query]): + # processed_query shape [batch_size, query_depth] -> [batch_size, attention_dim] + processed_query = self.query_layer(query) if self.query_layer else query + # energy shape [batch_size, max_time] + energy = _location_based_score(processed_query, previous_alignments, self._keys) + # alignments shape = energy shape = [batch_size, max_time] + alignments = self._probability_fn(energy, previous_alignments) + #Seems pretty useless but tensorflow attention wrapper requires it to work properly + next_state = alignments + return alignments, next_state \ No newline at end of file diff --git a/tacotron/models/custom_decoder.py b/tacotron/models/custom_decoder.py index f3c3aa64..6901501b 100644 --- a/tacotron/models/custom_decoder.py +++ b/tacotron/models/custom_decoder.py @@ -3,6 +3,7 @@ from __future__ import print_function import collections +import tensorflow as tf from tensorflow.contrib.seq2seq.python.ops import decoder from tensorflow.contrib.seq2seq.python.ops import helper as helper_py @@ -11,125 +12,142 @@ from tensorflow.python.layers import base as layers_base from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.util import nest +from .modules import stop_token_projection +from .helpers import TacoTrainingHelper, TacoTestHelper class CustomDecoderOutput( - collections.namedtuple("CustomDecoderOutput", ("rnn_output", "sample_id"))): - pass + collections.namedtuple("CustomDecoderOutput", ("rnn_output", "sample_id"))): + pass class CustomDecoder(decoder.Decoder): - """Custom sampling decoder. - - Allows for stop token prediction at inference time - and returns equivalent loss in training time. 
- """ - - def __init__(self, cell, helper, initial_state, output_layer=None): - """Initialize CustomDecoder. - Args: - cell: An `RNNCell` instance. - helper: A `Helper` instance. - initial_state: A (possibly nested tuple of...) tensors and TensorArrays. - The initial state of the RNNCell. - output_layer: (Optional) An instance of `tf.layers.Layer`, i.e., - `tf.layers.Dense`. Optional layer to apply to the RNN output prior - to storing the result or sampling. - Raises: - TypeError: if `cell`, `helper` or `output_layer` have an incorrect type. - """ - if not rnn_cell_impl._like_rnncell(cell): # pylint: disable=protected-access - raise TypeError("cell must be an RNNCell, received: %s" % type(cell)) - if not isinstance(helper, helper_py.Helper): - raise TypeError("helper must be a Helper, received: %s" % type(helper)) - if (output_layer is not None - and not isinstance(output_layer, layers_base.Layer)): - raise TypeError( - "output_layer must be a Layer, received: %s" % type(output_layer)) - self._cell = cell - self._helper = helper - self._initial_state = initial_state - self._output_layer = output_layer - - @property - def batch_size(self): - return self._helper.batch_size - - def _rnn_output_size(self): - size = self._cell.output_size - if self._output_layer is None: - return size - else: - # To use layer's compute_output_shape, we need to convert the - # RNNCell's output_size entries into shapes with an unknown - # batch size. We then pass this through the layer's - # compute_output_shape and read off all but the first (batch) - # dimensions to get the output size of the rnn with the layer - # applied to the top. - output_shape_with_unknown_batch = nest.map_structure( - lambda s: tensor_shape.TensorShape([None]).concatenate(s), - size) - layer_output_shape = self._output_layer._compute_output_shape( # pylint: disable=protected-access - output_shape_with_unknown_batch) - return nest.map_structure(lambda s: s[1:], layer_output_shape) - - @property - def output_size(self): - # Return the cell output and the id - return CustomDecoderOutput( - rnn_output=self._rnn_output_size(), - sample_id=self._helper.sample_ids_shape) - - @property - def output_dtype(self): - # Assume the dtype of the cell is the output_size structure - # containing the input_state's first component's dtype. - # Return that structure and the sample_ids_dtype from the helper. - dtype = nest.flatten(self._initial_state)[0].dtype - return CustomDecoderOutput( - nest.map_structure(lambda _: dtype, self._rnn_output_size()), - self._helper.sample_ids_dtype) - - def initialize(self, name=None): - """Initialize the decoder. - Args: - name: Name scope for any created operations. - Returns: - `(finished, first_inputs, initial_state)`. - """ - return self._helper.initialize() + (self._initial_state,) - - def step(self, time, inputs, state, error, name=None): - """Perform a decoding step. - Args: - time: scalar `int32` tensor. - inputs: A (structure of) input tensors. - state: A (structure of) state tensors and TensorArrays. - name: Name scope for any created operations. - Returns: - `(outputs, next_state, next_inputs, finished)`. 
- """ - with ops.name_scope(name, "BasicDecoderStep", (time, inputs, state, error)): - cell_outputs, cell_state, LSTM_output = self._cell(inputs, state) - - if self._output_layer is not None: - cell_outputs = self._output_layer(cell_outputs) - sample_ids = self._helper.sample( - time=time, outputs=cell_outputs, state=cell_state) - - (finished, next_inputs, next_state, stop_error) = self._helper.next_inputs( - time=time, - cell_outputs=cell_outputs, - state=cell_state, - LSTM_output=LSTM_output, - sample_ids=sample_ids) - - #we don't care about this error at synthesis time - if stop_error is not None: - #Cumulating stop token prediction error - error += stop_error - - - outputs = CustomDecoderOutput(cell_outputs, sample_ids) - return (outputs, next_state, next_inputs, finished, error) \ No newline at end of file + """Custom sampling decoder. + + Allows for stop token prediction at inference time + and returns equivalent loss in training time. + + Note: + Only use this decoder with Tacotron 2 as it only accepts tacotron custom helpers + """ + + def __init__(self, cell, helper, initial_state, output_layer=None): + """Initialize CustomDecoder. + Args: + cell: An `RNNCell` instance. + helper: A `Helper` instance. + initial_state: A (possibly nested tuple of...) tensors and TensorArrays. + The initial state of the RNNCell. + output_layer: (Optional) An instance of `tf.layers.Layer`, i.e., + `tf.layers.Dense`. Optional layer to apply to the RNN output prior + to storing the result or sampling. + Raises: + TypeError: if `cell`, `helper` or `output_layer` have an incorrect type. + """ + if not rnn_cell_impl._like_rnncell(cell): # pylint: disable=protected-access + raise TypeError("cell must be an RNNCell, received: %s" % type(cell)) + if not isinstance(helper, helper_py.Helper): + raise TypeError("helper must be a Helper, received: %s" % type(helper)) + if (output_layer is not None + and not isinstance(output_layer, layers_base.Layer)): + raise TypeError( + "output_layer must be a Layer, received: %s" % type(output_layer)) + self._cell = cell + self._helper = helper + self._initial_state = initial_state + self._output_layer = output_layer + + @property + def batch_size(self): + return self._helper.batch_size + + def _rnn_output_size(self): + size = self._cell.output_size + if self._output_layer is None: + return size + else: + # To use layer's compute_output_shape, we need to convert the + # RNNCell's output_size entries into shapes with an unknown + # batch size. We then pass this through the layer's + # compute_output_shape and read off all but the first (batch) + # dimensions to get the output size of the rnn with the layer + # applied to the top. + output_shape_with_unknown_batch = nest.map_structure( + lambda s: tensor_shape.TensorShape([None]).concatenate(s), + size) + layer_output_shape = self._output_layer._compute_output_shape( # pylint: disable=protected-access + output_shape_with_unknown_batch) + return nest.map_structure(lambda s: s[1:], layer_output_shape) + + @property + def output_size(self): + # Return the cell output and the id + return CustomDecoderOutput( + rnn_output=self._rnn_output_size(), + sample_id=self._helper.sample_ids_shape) + + @property + def output_dtype(self): + # Assume the dtype of the cell is the output_size structure + # containing the input_state's first component's dtype. + # Return that structure and the sample_ids_dtype from the helper. 
+ dtype = nest.flatten(self._initial_state)[0].dtype + return CustomDecoderOutput( + nest.map_structure(lambda _: dtype, self._rnn_output_size()), + self._helper.sample_ids_dtype) + + def initialize(self, name=None): + """Initialize the decoder. + Args: + name: Name scope for any created operations. + Returns: + `(finished, first_inputs, initial_state)`. + """ + return self._helper.initialize() + (self._initial_state,) + + def step(self, time, inputs, state, name=None): + """Perform a custom decoding step. + The difference compared to basic decoder is that it tries to determine + when to stop decoding at each step + Args: + time: scalar `int32` tensor. + inputs: A (structure of) input tensors. + state: A (structure of) state tensors and TensorArrays. + name: Name scope for any created operations. + Returns: + `(outputs, next_state, next_inputs, finished)`. + """ + with ops.name_scope(name, "BasicDecoderStep", (time, inputs, state)): + #Call outputprojection wrapper cell + cell_outputs, cell_state = self._cell(inputs, state) + + #apply output_layer (if existant) + if self._output_layer is not None: + cell_outputs = self._output_layer(cell_outputs) + sample_ids = self._helper.sample( + time=time, outputs=cell_outputs, state=cell_state) + + #extract LSTM output concatenated to context vector using previous wrappers + lstm_cat_context = self._cell._cell.lstm_concat_context + + #Predict dynamic (Preferred to handle it in decoder step rather that inside the helper) + #Basically this extra "output" is trying to determine when to output a (similarly to NMT) + #Since Tacotron is basically trying to infer real values instead of one of possible classes (case of NMT) + #it requires a "binary-classifier" to determine when to stop since it will never output a perfect by regression. + if isinstance(self._helper, TacoTrainingHelper): + finished_p = tf.squeeze(stop_token_projection(lstm_cat_context), [1]) + elif isinstance(self._helper, TacoTestHelper): + finished_p = tf.squeeze(stop_token_projection(lstm_cat_context, activation=tf.nn.sigmoid), [1]) + else: + raise TypeError('Helper used does not belong to any supported Tacotron helpers (TacoTestHelper, TacoTrainingHelper)') + + (finished, next_inputs, next_state) = self._helper.next_inputs( + time=time, + outputs=cell_outputs, + state=cell_state, + sample_ids=sample_ids, + stop_token_prediction=finished_p) + + outputs = CustomDecoderOutput(cell_outputs, sample_ids) + return (outputs, next_state, next_inputs, finished) \ No newline at end of file diff --git a/tacotron/models/dynamic_decoder.py b/tacotron/models/dynamic_decoder.py index cd38f689..25d9c74d 100644 --- a/tacotron/models/dynamic_decoder.py +++ b/tacotron/models/dynamic_decoder.py @@ -1,4 +1,6 @@ """Seq2seq layer operations for use in neural networks. +Customized to support dynamic decoding of Tacotron 2. +Only use this dynamic decoder with Tacotron 2. For other applications use the original one from tensorflow. """ from __future__ import absolute_import @@ -18,6 +20,7 @@ from tensorflow.python.util import nest import tensorflow as tf +from .helpers import TacoTrainingHelper, TacoTestHelper def _transpose_batch_time(x): @@ -142,7 +145,7 @@ def condition(unused_time, unused_outputs_ta, unused_state, unused_inputs, finished, unused_error): return math_ops.logical_not(math_ops.reduce_all(finished)) - def body(time, outputs_ta, state, inputs, finished, error): + def body(time, outputs_ta, state, inputs, finished, loss): """Internal while_loop body. Args: time: scalar int32 tensor. 
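`stop_token_projection` is imported from `modules.py`, whose definition is not part of this diff. A shape-compatible stand-in consistent with how `step` uses it (raw logits for the training helper, sigmoid probabilities for the test helper) could look like the following; this is a hypothetical sketch, not the repository's actual implementation:

import tensorflow as tf

def stop_token_projection(x, activation=None, scope='stop_token_projection'):
    # Hypothetical stand-in: project [batch, lstm_units + context_dim] down to a
    # single scalar per batch entry; step() squeezes away the trailing axis.
    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        return tf.layers.dense(x, units=1, activation=activation)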
diff --git a/tacotron/models/dynamic_decoder.py b/tacotron/models/dynamic_decoder.py
index cd38f689..25d9c74d 100644
--- a/tacotron/models/dynamic_decoder.py
+++ b/tacotron/models/dynamic_decoder.py
@@ -1,4 +1,6 @@
 """Seq2seq layer operations for use in neural networks.
+Customized to support dynamic decoding of Tacotron 2.
+Only use this dynamic decoder with Tacotron 2. For other applications, use the original one from tensorflow.
 """
 
 from __future__ import absolute_import
@@ -18,6 +20,7 @@
 from tensorflow.python.util import nest
 
 import tensorflow as tf
+from .helpers import TacoTrainingHelper, TacoTestHelper
 
 
 def _transpose_batch_time(x):
@@ -142,7 +145,7 @@ def condition(unused_time, unused_outputs_ta, unused_state,
                 unused_inputs, finished, unused_error):
    return math_ops.logical_not(math_ops.reduce_all(finished))
 
-  def body(time, outputs_ta, state, inputs, finished, error):
+  def body(time, outputs_ta, state, inputs, finished, loss):
    """Internal while_loop body.
    Args:
      time: scalar int32 tensor.
@@ -154,7 +157,7 @@ def body(time, outputs_ta, state, inputs, finished, error):
      `(time + 1, outputs_ta, next_state, next_inputs, next_finished)`.
    """
    (next_outputs, decoder_state, next_inputs,
-    decoder_finished, stop_error) = decoder.step(time, inputs, state, error)
+    decoder_finished) = decoder.step(time, inputs, state)
    next_finished = math_ops.logical_or(decoder_finished, finished)
 
    if maximum_iterations is not None:
@@ -194,7 +197,15 @@ def _maybe_copy_state(new, cur):
    outputs_ta = nest.map_structure(lambda ta, out: ta.write(time, out),
                                    outputs_ta, emit)
 
-    return (time + 1, outputs_ta, next_state, next_inputs, next_finished, stop_error)
+    #Accumulate the loss along decoding steps
+    if isinstance(decoder._helper, TacoTrainingHelper):
+      stop_token_loss = loss + decoder._helper.stop_token_loss
+    elif isinstance(decoder._helper, TacoTestHelper):
+      stop_token_loss = loss
+    else:
+      raise TypeError('Helper used does not belong to any supported Tacotron helpers (TacoTestHelper, TacoTrainingHelper)')
+
+    return (time + 1, outputs_ta, next_state, next_inputs, next_finished, stop_token_loss)
 
   res = control_flow_ops.while_loop(
      condition,
@@ -210,10 +221,10 @@ def _maybe_copy_state(new, cur):
   final_state = res[2]
   steps = tf.cast(res[0], tf.float32)
 
-  stop_error = res[5]
+  stop_token_loss = res[5]
 
-  #Average stop_error
-  avg_stop_error = stop_error / steps
+  #Average the error over decoding steps
+  avg_stop_loss = stop_token_loss / steps
 
   final_outputs = nest.map_structure(
      lambda ta: ta.stack(), final_outputs_ta)
@@ -221,4 +232,4 @@ def _maybe_copy_state(new, cur):
    final_outputs = nest.map_structure(
        _transpose_batch_time, final_outputs)
 
-  return final_outputs, final_state, avg_stop_error
\ No newline at end of file
+  return final_outputs, final_state, avg_stop_loss
\ No newline at end of file
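The accumulation pattern above (carry a scalar through the `while_loop`, then normalize by the step count) in isolation, with a toy stand-in for `helper.stop_token_loss`:

import tensorflow as tf

def cond(t, loss):
    return t < 5

def body(t, loss):
    step_loss = 0.1 * tf.cast(t, tf.float32)  # stand-in for the per-step stop token loss
    return t + 1, loss + step_loss

steps, total = tf.while_loop(cond, body, [tf.constant(0), tf.constant(0.0)])
avg_stop_loss = total / tf.cast(steps, tf.float32)

with tf.Session() as sess:
    print(sess.run(avg_stop_loss))  # (0.0+0.1+0.2+0.3+0.4)/5, prints ~0.2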
diff --git a/tacotron/models/helpers.py b/tacotron/models/helpers.py
index f613b7b4..aa2ad49f 100644
--- a/tacotron/models/helpers.py
+++ b/tacotron/models/helpers.py
@@ -1,15 +1,14 @@
 import numpy as np
 import tensorflow as tf
 from tensorflow.contrib.seq2seq import Helper
-from .modules import stop_token_projection
 
 
-# Adapted from tf.contrib.seq2seq.GreedyEmbeddingHelper
 class TacoTestHelper(Helper):
-  def __init__(self, batch_size, output_dim, r=1):
+  def __init__(self, batch_size, output_dim, r):
    with tf.name_scope('TacoTestHelper'):
      self._batch_size = batch_size
      self._output_dim = output_dim
+     self._end_token = tf.tile([0.0], [output_dim * r])
 
   @property
   def batch_size(self):
@@ -29,28 +28,19 @@ def initialize(self, name=None):
   def sample(self, time, outputs, state, name=None):
    return tf.tile([0], [self._batch_size])  # Return all 0; we ignore them
 
-  def next_inputs(self, time, cell_outputs, state, LSTM_output, sample_ids, name=None):
+  def next_inputs(self, time, outputs, state, sample_ids, stop_token_prediction, name=None):
    '''Stop on EOS. Otherwise, pass the last output as the next input and pass through state.'''
    with tf.name_scope('TacoTestHelper'):
-      #At inference time, stop_error = None
-      stop_error = None # we don't need it
-
-      context = state.attention # Get context vector
-      #finished = tf.reduce_all(tf.equal(outputs, self._end_token), axis=1)
-
-      #Predict if the encoder should stop (dynamic end token)
-      concat = tf.concat([LSTM_output, context], axis=-1)
-      scalar = tf.squeeze(stop_token_projection(concat, activation=tf.nn.sigmoid), [1])
-      finished = tf.cast(tf.round(scalar), tf.bool)
-
+      finished = tf.cast(tf.round(stop_token_prediction), tf.bool)
+
      # Feed last output frame as next input. outputs is [N, output_dim * r]
-      next_inputs = cell_outputs
+      next_inputs = outputs[:, -self._output_dim:]
      next_state = state
-      return (finished, next_inputs, next_state, stop_error)
+      return (finished, next_inputs, next_state)
 
 
 class TacoTrainingHelper(Helper):
-  def __init__(self, inputs, targets, output_dim, r=1):
+  def __init__(self, inputs, targets, output_dim, r):
    # inputs is [N, T_in], targets is [N, T_out, D]
    with tf.name_scope('TacoTrainingHelper'):
      self._batch_size = tf.shape(inputs)[0]
@@ -59,9 +49,13 @@ def __init__(self, inputs, targets, output_dim, r=1):
      # Feed every r-th target frame as input
      self._targets = targets[:, r-1::r, :]
 
+      # <stop_token> (same value as the padding frames) used to train the dynamic stop
+      self._end_token = tf.tile([0.0], [output_dim * r])
+
      # Use full length for every target because we don't want to mask the padding frames
      num_steps = tf.shape(self._targets)[1]
      self._lengths = tf.tile([num_steps], [self._batch_size])
+      self._num_steps = tf.cast(num_steps, tf.float32)
 
   @property
   def batch_size(self):
@@ -81,24 +75,23 @@ def initialize(self, name=None):
   def sample(self, time, outputs, state, name=None):
    return tf.tile([0], [self._batch_size])  # Return all 0; we ignore them
 
-  def next_inputs(self, time, cell_outputs, state, LSTM_output, sample_ids, name=None):
+  def next_inputs(self, time, outputs, state, sample_ids, stop_token_prediction, name=None):
    with tf.name_scope(name or 'TacoTrainingHelper'):
-      context = state.attention #Get context vector
-      finished = (time + 1 >= self._lengths) #return true finished
+      #A sequence is finished when we reach its full true length or when we encounter padding.
+      #It is essential to train the model to stop on padding
+      #to obtain the desired dynamic generation.
+      true_finished = tf.logical_or((time + 1 >= self._lengths),
+        tf.reduce_all(tf.equal(outputs, self._end_token), axis=1))
 
-      #Compute model prediction to stop token
-      concat = tf.concat([LSTM_output, context], axis=-1)
-      finished_p = tf.squeeze(stop_token_projection(concat), [1])
+      #Compute the stop_token_loss of the current decoding step (for dynamic stop training)
+      self.stop_token_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
+        labels=tf.cast(true_finished, tf.float32),
+        logits=stop_token_prediction)) / self._num_steps
 
-      #Compute the stop token error for infer time
-      stop_error = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.cast(finished, tf.float32),
-                                                                          logits=finished_p))
-
      next_inputs = self._targets[:, time, :]  #Teacher-forcing: return true frame
-      next_state = state
-      return (finished, next_inputs, next_state, stop_error)
+      next_state = state  #No change to the cell state
+      return (true_finished, next_inputs, next_state)  #return the true "finished" state
 
 
 def _go_frames(batch_size, output_dim):
   '''Returns all-zero frames for a given batch size and output dimension'''
-  return tf.tile([[0.0]], [batch_size, output_dim])
-
+  return tf.tile([[0.0]], [batch_size, output_dim])
\ No newline at end of file
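A quick numpy illustration of the two conventions the training helper relies on: the every-r-th-frame target slicing, and the all-zero padding frame doubling as the end token:

import numpy as np

r, output_dim = 2, 3
targets = np.arange(24, dtype=np.float32).reshape(1, 8, 3)  # [N, T_out, D]

# Feed every r-th target frame as input (teacher forcing):
decoder_targets = targets[:, r-1::r, :]  # frames 1, 3, 5, 7
assert decoder_targets.shape == (1, 4, 3)

# An all-zero frame (the padding value) marks "finished" during training:
pad_frame = np.zeros((1, output_dim * r), dtype=np.float32)
print(np.all(pad_frame == 0.0, axis=1))  # [ True]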
diff --git a/tacotron/models/modules.py b/tacotron/models/modules.py
index 39ff3c82..7487916d 100644
--- a/tacotron/models/modules.py
+++ b/tacotron/models/modules.py
@@ -5,7 +5,7 @@
 
 def conv1d(inputs, kernel_size, channels, activation, is_training, scope):
-  drop_rate = 0.5
+  drop_rate = hparams.dropout_rate
 
   with tf.variable_scope(scope):
    conv1d_output = tf.layers.conv1d(
@@ -18,15 +18,13 @@ def conv1d(inputs, kernel_size, channels, activation, is_training, scope):
    return tf.layers.dropout(batched, rate=drop_rate, training=is_training,
                name='dropout_{}'.format(scope))
 
-
-
 def enc_conv_layers(inputs, is_training, kernel_size=(5, ), channels=512, activation=tf.nn.relu, scope=None):
   if scope is None:
    scope = 'enc_conv_layers'
 
   with tf.variable_scope(scope):
    x = inputs
-    for i in range(3):
+    for i in range(hparams.enc_conv_num_layers):
      x = conv1d(x, kernel_size, channels, activation, is_training, 'conv_layer_{}_'.format(i + 1)+scope)
   return x
 
@@ -37,23 +35,23 @@ def postnet(inputs, is_training, kernel_size=(5, ), channels=512, activation=tf.
 
   with tf.variable_scope(scope):
    x = inputs
-    for i in range(4):
+    for i in range(hparams.postnet_num_layers - 1):
      x = conv1d(x, kernel_size, channels, activation, is_training, 'conv_layer_{}_'.format(i + 1)+scope)
-    x = conv1d(x, kernel_size, channels, lambda _: _, is_training, 'conv_layer_{}_'.format(5)+scope)
+    x = conv1d(x, kernel_size, channels, lambda _: _, is_training, 'conv_layer_{}_'.format(hparams.postnet_num_layers)+scope)
   return x
 
-def bidirectional_LSTM(inputs, input_lengths, scope, is_training):
+def bidirectional_LSTM(inputs, input_lengths, scope, is_training, size=256, zoneout=0.1):
   with tf.variable_scope(scope):
    outputs, (fw_state, bw_state) = tf.nn.bidirectional_dynamic_rnn(
-      ZoneoutLSTMCell(256,
-                      is_training,
-                      zoneout_factor_cell=0.1,
-                      zoneout_factor_output=0.1,),
-      ZoneoutLSTMCell(256,
-                      is_training,
-                      zoneout_factor_cell=0.1,
-                      zoneout_factor_output=0.1,),
+      ZoneoutLSTMCell(size,
+                      is_training,
+                      zoneout_factor_cell=zoneout,
+                      zoneout_factor_output=zoneout),
+      ZoneoutLSTMCell(size,
+                      is_training,
+                      zoneout_factor_cell=zoneout,
+                      zoneout_factor_output=zoneout),
      inputs,
      sequence_length=input_lengths,
      dtype=tf.float32)
@@ -65,16 +63,28 @@ def bidirectional_LSTM(inputs, input_lengths, scope, is_training):
    encoder_final_state_h = tf.concat(
      (fw_state.h, bw_state.h), 1)
 
-    #Get the final state to pass it to attention mechanism as initial state
+    #Get the final state; we don't really use it in our case.
+    #I'll keep it just in case
    final_state = LSTMStateTuple(
      c=encoder_final_state_c,
      h=encoder_final_state_h)
 
   return tf.concat(outputs, axis=2), final_state  # Concat forward + backward outputs and return with final states
 
+def unidirectional_LSTM(input_cell, is_training, layers=2, size=512, zoneout=0.1):
+  #Create a set of LSTM layers
+  rnn_layers = [ZoneoutLSTMCell(size, is_training,
+                                zoneout_factor_cell=zoneout,
+                                zoneout_factor_output=zoneout) for i in range(layers)]
+
+  #Prepend the prenet+attention concatenation wrapper as the first layer
+  rnn_layers = [input_cell] + rnn_layers
+
+  return tf.nn.rnn_cell.MultiRNNCell(rnn_layers, state_is_tuple=True)
+
 def prenet(inputs, is_training, layer_sizes=[128, 128], scope=None):
   x = inputs
-  drop_rate = 0.5
+  drop_rate = hparams.dropout_rate
 
   if scope is None:
    scope = 'prenet'
@@ -84,23 +94,11 @@ def prenet(inputs, is_training, layer_sizes=[128, 128], scope=None):
      dense = tf.layers.dense(x, units=size, activation=tf.nn.relu, name='dense_{}'.format(i + 1))
      #The paper discussed introducing diversity in generation at inference time
      #by using a dropout of 0.5 only in prenet layers.
-      #In this implementation we're supposing they meant to keep the dropout even at inference time
-      #So we set training=True at all times (even during synthesis)
-      x = tf.layers.dropout(dense, rate=drop_rate, training=True,
+      x = tf.layers.dropout(dense, rate=drop_rate, training=is_training,
                            name='dropout_{}_'.format(i + 1) + scope)
   return x
 
-
-def unidirectional_LSTM(is_training, layers=2, size=512):
-
-  rnn_layers = [ZoneoutLSTMCell(size, is_training, zoneout_factor_cell=0.1,
-                                zoneout_factor_output=0.1,
-                                ext_proj=hparams.num_mels) for i in range(layers)]
-
-  stacked_LSTM_Cell = tf.nn.rnn_cell.MultiRNNCell(rnn_layers)
-  return stacked_LSTM_Cell
-
-def projection(x, shape=512, activation=None, scope=None):
+def projection(x, shape=80, activation=None, scope=None):  #default shape matches num_mels
   if scope is None:
    scope = 'linear_projection'
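The first hunk above elides the middle of `conv1d`, so the origin of `batched` is not visible in this diff. A hedged reconstruction of the presumed full block (convolution, then batch normalization, then activation, with `batched` naming the result) would also explain why `add_optimizer` in tacotron.py below takes a dependency on `UPDATE_OPS`; the ordering and names here are assumptions:

import tensorflow as tf

def conv1d_sketch(inputs, kernel_size, channels, activation, is_training, scope, drop_rate=0.5):
    # Assumed body of modules.conv1d; only the first and last lines appear in the diff.
    with tf.variable_scope(scope):
        conv1d_output = tf.layers.conv1d(inputs, filters=channels,
            kernel_size=kernel_size, padding='same', activation=None)
        # batch norm is what makes the UPDATE_OPS control dependency necessary
        batched = activation(tf.layers.batch_normalization(conv1d_output, training=is_training))
        return tf.layers.dropout(batched, rate=drop_rate, training=is_training,
            name='dropout_{}'.format(scope))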
diff --git a/tacotron/models/rnn_wrappers.py b/tacotron/models/rnn_wrappers.py
index 64b1ac4c..3722f753 100644
--- a/tacotron/models/rnn_wrappers.py
+++ b/tacotron/models/rnn_wrappers.py
@@ -1,22 +1,21 @@
+"""A set of RNN wrappers useful for the Tacotron 2 architecture.
+All notations and variable names follow the original tensorflow implementation.
+Some tensors are passed through wrappers to make sure we respect the described architecture.
+"""
+
 import numpy as np
 import tensorflow as tf
 from tensorflow.contrib.rnn import RNNCell
 from .modules import prenet, projection
+from tensorflow.python.framework import ops
 from hparams import hparams
 
-class TacotronDecoderWrapper(RNNCell):
-  """Computes custom Tacotron decoder and return decoder output and state at each step
-
-  decoder architecture:
-    Prenet: 2 dense layers, 128 units each
-    * concat(Prenet output + context vector)
-    RNNStack (LSTM): 2 uni-directional LSTM layers with 512 units each
-    * concat(LSTM output + context vector)
-    Linear projection layer: output_dim = decoder_output
-  """
+
+class DecoderPrenetWrapper(RNNCell):
+  '''Runs RNN inputs through a prenet before sending them to the cell.'''
   def __init__(self, cell, is_training):
-    super(TacotronDecoderWrapper, self).__init__()
+    super(DecoderPrenetWrapper, self).__init__()
    self._cell = cell
    self._is_training = is_training
 
@@ -26,31 +25,107 @@ def state_size(self):
 
   @property
   def output_size(self):
-    #return (self.batch_size, hparams.num_mels)
    return self._cell.output_size
 
   def call(self, inputs, state):
-    #Get context vector from cell state
-    context_vector = state.attention
-    cell_state = state.cell_state
+    prenet_out = prenet(inputs, self._is_training, hparams.prenet_layers, scope='decoder_attention_prenet')
+    self._prenet_out = prenet_out
+    return self._cell(prenet_out, state)
+
+  def zero_state(self, batch_size, dtype):
+    return self._cell.zero_state(batch_size, dtype)
+
+
+class ConcatPrenetAndAttentionWrapper(RNNCell):
+  '''Concatenates prenet output with the attention context vector.
+  This is expected to wrap a cell wrapped with an AttentionWrapper constructed with
+  attention_layer_size=None and output_attention=False. Such a cell's state will include an
+  "attention" field that is the context vector.
+  '''
+  def __init__(self, cell):
+    super(ConcatPrenetAndAttentionWrapper, self).__init__()
+    self._cell = cell
+
+  @property
+  def state_size(self):
+    return self._cell.state_size
+
+  @property
+  def output_size(self):
+    #attention is stored in the attention wrapper cell state
+    return self._cell.output_size + self._cell.state_size.attention
+
+  def call(self, inputs, state):
+    #We assume the paper's authors mean the attention network output when
+    #they say "The pre-net output and attention context vector are concatenated and
+    #passed through a stack of 2 uni-directional LSTM layers".
+    #We rely on the original tacotron architecture for this hypothesis.
+    output, res_state = self._cell(inputs, state)
 
-    #Compute prenet output
-    prenet_outputs = prenet(inputs, self._is_training, scope='decoder_prenet_layer')
+    #Store the attention context in this wrapper to make access easier for future wrappers
+    self._context_vector = res_state.attention
+    return tf.concat([output, self._context_vector], axis=-1), res_state
 
-    #Concat prenet output and context vector
-    concat_output_prenet = tf.concat([prenet_outputs, context_vector], axis=-1)
+  def zero_state(self, batch_size, dtype):
+    return self._cell.zero_state(batch_size, dtype)
 
-    #Compute LSTM output
-    LSTM_output, next_cell_state = self._cell(concat_output_prenet, cell_state)
 
+class ConcatLSTMOutputAndAttentionWrapper(RNNCell):
+  '''Concatenates decoder RNN cell output with the attention context vector.
+  This is expected to wrap a cell wrapped with an AttentionWrapper constructed with
+  attention_layer_size=None and output_attention=False. Such a cell's state will include an
+  "attention" field that is the context vector.
+  '''
+  def __init__(self, cell):
+    super(ConcatLSTMOutputAndAttentionWrapper, self).__init__()
+    self._cell = cell
+    self._prenet_attention_cell = self._cell._cells[0]
 
-    #Concat LSTM output and context vector
-    concat_output_LSTM = tf.concat([LSTM_output, context_vector], axis=-1)
+  @property
+  def state_size(self):
+    return self._cell.state_size
 
-    #Linear projection
-    proj_shape = hparams.num_mels
-    cell_output = (projection(concat_output_LSTM, proj_shape, scope='decoder_projection_layer'), LSTM_output)
+  @property
+  def output_size(self):
+    return self._cell.output_size + self._prenet_attention_cell.state_size.attention
+
+  def call(self, inputs, state):
+    output, res_state = self._cell(inputs, state)
+    context_vector = self._prenet_attention_cell._context_vector
+    self.lstm_concat_context = tf.concat([output, context_vector], axis=-1)
+    return self.lstm_concat_context, res_state
 
-    return cell_output, next_cell_state
 
   def zero_state(self, batch_size, dtype):
    return self._cell.zero_state(batch_size, dtype)
+
+
+# class LinearProjectionWrapper(RNNCell):
+#   """Operator adding an output projection to the given cell.
+#   This wrapper performs a linear transformation with an optional activation function (default: None).
+#   """
+#   def __init__(self, cell, projection_dim, activation=None):
+#     super(LinearProjectionWrapper, self).__init__()
+#     self._cell = cell
+#     self._projection_dim = projection_dim
+#     self._activation = activation
+
+#   @property
+#   def state_size(self):
+#     return self._cell.state_size
+
+#   @property
+#   def output_size(self):
+#     return self._projection_dim
+
+#   def zero_state(self, batch_size, dtype):
+#     with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
+#       return self._cell.zero_state(batch_size, dtype)
+
+#   def call(self, inputs, state):
+#     """Run the cell and output projection on inputs, starting from state."""
+#     output, res_state = self._cell(inputs, state)
+#     projected = projection(output, self._projection_dim)
+#     if self._activation:
+#       projected = self._activation(projected)
+
+#     return projected, res_state
\ No newline at end of file
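For orientation, here is how tacotron.py (the next file) nests these wrappers into the final decoder cell, sketched top-down with hparams elided:

# OutputProjectionWrapper(                       -> [batch, num_mels * outputs_per_step]
#   ConcatLSTMOutputAndAttentionWrapper(         -> concat(LSTM output, context vector)
#     MultiRNNCell([
#       ConcatPrenetAndAttentionWrapper(         -> concat(prenet output, context vector)
#         AttentionWrapper(
#           DecoderPrenetWrapper(ZoneoutLSTMCell(attention_dim)),
#           LocationBasedAttention(attention_dim, encoder_outputs))),
#       ZoneoutLSTMCell(decoder_lstm_units),     # decoder LSTM layer 1
#       ZoneoutLSTMCell(decoder_lstm_units)])))  # decoder LSTM layer 2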
diff --git a/tacotron/models/tacotron.py b/tacotron/models/tacotron.py
index f4c278e5..2de4649e 100644
--- a/tacotron/models/tacotron.py
+++ b/tacotron/models/tacotron.py
@@ -3,11 +3,13 @@
 from utils.infolog import log
 from .helpers import TacoTrainingHelper, TacoTestHelper
 from .modules import *
-from .rnn_wrappers import TacotronDecoderWrapper
 from models.zoneout_LSTM import ZoneoutLSTMCell
-from .dynamic_decoder import dynamic_decode
+from tensorflow.contrib.seq2seq import AttentionWrapper
+from .rnn_wrappers import *
+from tensorflow.contrib.rnn import MultiRNNCell, OutputProjectionWrapper
+from .attention import LocationBasedAttention
 from .custom_decoder import CustomDecoder
-from .attention_wrapper import AttentionWrapper, LocationBasedAttention
+from .dynamic_decoder import dynamic_decode
 
 
 class Tacotron():
@@ -31,8 +33,6 @@ def initialize(self, inputs, input_lengths, mel_targets=None, gta=False):
    """
    with tf.variable_scope('inference') as scope:
      is_training = mel_targets is not None and not gta
-      print('training: ', is_training)
-      print('gta: ', gta)
      batch_size = tf.shape(inputs)[0]
      hp = self._hparams
 
@@ -43,57 +43,69 @@ def initialize(self, inputs, input_lengths, mel_targets=None, gta=False):
      embedded_inputs = tf.nn.embedding_lookup(embedding_table, inputs)
 
      #Encoder
-      enc_conv_outputs = enc_conv_layers(embedded_inputs, is_training)
+      enc_conv_outputs = enc_conv_layers(embedded_inputs, is_training,
+        kernel_size=hp.enc_conv_kernel_size, channels=hp.enc_conv_channels)
      #Paper doesn't specify what to do with final encoder state
-      #We send them however to the attention mechanism as source state
-      #(direct link between source and targets cells)
+      #So we will simply drop it
      encoder_outputs, encoder_states = bidirectional_LSTM(enc_conv_outputs, input_lengths,
-        'encoder_LSTM', is_training=is_training)
-
-      #DecoderWrapper
-      decoder_cell = TacotronDecoderWrapper(
-        unidirectional_LSTM(is_training, layers=hp.num_decoder_layers, size=512),
-        is_training)
-
-      #AttentionWrapper on top of TacotronDecoderWrapper
-      attention_decoder = AttentionWrapper(
-        decoder_cell,
+        'encoder_LSTM', is_training=is_training, size=hp.encoder_lstm_units,
+        zoneout=hp.zoneout_rate)
+
+      #Attention
+      attention_cell = AttentionWrapper(
+        DecoderPrenetWrapper(ZoneoutLSTMCell(hp.attention_dim, is_training,
+          zoneout_factor_cell=hp.zoneout_rate,
+          zoneout_factor_output=hp.zoneout_rate), is_training),
        LocationBasedAttention(hp.attention_dim, encoder_outputs),
        alignment_history=True,
        output_attention=False,
-        name='attention_decoder_wrapper')
+        name='attention_cell')
 
-      #We pass (num_decoder_layers times) encoder final states to the decoder of #layers (num_decoder_layers)
-      decoder_init_state = attention_decoder.zero_state(batch_size=batch_size, dtype=tf.float32).clone(
-        cell_state=tuple(encoder_states for _ in range(hp.num_decoder_layers)))
+      #Concat prenet output with the context vector
+      concat_cell = ConcatPrenetAndAttentionWrapper(attention_cell)
+
+      #Decoder layers (attention pre-net + 2 unidirectional LSTM Cells)
+      decoder_cell = unidirectional_LSTM(concat_cell, is_training,
+        layers=hp.decoder_layers, size=hp.decoder_lstm_units,
+        zoneout=hp.zoneout_rate)
+
+      #Concat LSTM output with the context vector
+      concat_decoder_cell = ConcatLSTMOutputAndAttentionWrapper(decoder_cell)
+
+      #Projection to mel-spectrogram dimension (linear transformation)
+      output_cell = OutputProjectionWrapper(concat_decoder_cell, hp.num_mels * hp.outputs_per_step)
 
      #Define the helper for our decoder
      if is_training or gta:
-        helper = TacoTrainingHelper(inputs, mel_targets, hp.num_mels, hp.outputs_per_step)
+        self.helper = TacoTrainingHelper(inputs, mel_targets, hp.num_mels, hp.outputs_per_step)
      else:
-        helper = TacoTestHelper(batch_size, hp.num_mels, hp.outputs_per_step)
+        self.helper = TacoTestHelper(batch_size, hp.num_mels, hp.outputs_per_step)
 
      #We'll only limit decoder time steps during inference (consult hparams.py to modify the value)
      max_iterations = None if is_training else hp.max_iters
 
+      #initial decoder state
+      decoder_init_state = output_cell.zero_state(batch_size=batch_size, dtype=tf.float32)
+
      #Decode
-      (decoder_output, _), final_decoder_state, self.stop_error = dynamic_decode(
-        CustomDecoder(attention_decoder, helper, decoder_init_state),
-        impute_finished=True, maximum_iterations=max_iterations)
+      (decoder_output, _), final_decoder_state, self.stop_token_loss = dynamic_decode(
+        CustomDecoder(output_cell, self.helper, decoder_init_state),
+        impute_finished=True,  #Cut out padded parts
+        maximum_iterations=max_iterations)
 
      #Compute residual using post-net
-      residual = postnet(decoder_output, is_training)
+      residual = postnet(decoder_output, is_training,
+        kernel_size=hp.postnet_kernel_size, channels=hp.postnet_channels)
 
-      #Project residual to same dimension as mel spectogram
-      proj_dim = hp.num_mels
-      projected_residual = projection(residual, shape=proj_dim,
+      #Project residual to same dimension as mel spectrogram
+      projected_residual = projection(residual, shape=hp.num_mels,
        scope='residual_projection')
 
-      #Compute the mel spectogram
+      #Compute the mel spectrogram
      mel_outputs = decoder_output + projected_residual
 
      #Grab alignments from the final decoder state
-      alignments = tf.transpose(final_decoder_state.alignment_history.stack(), [1, 2, 0])
+      alignments = tf.transpose(final_decoder_state[0].alignment_history.stack(), [1, 2, 0])
 
      self.inputs = inputs
      self.input_lengths = input_lengths
@@ -131,7 +143,7 @@ def add_loss(self):
      self.after_loss = after
      self.regularization_loss = regularization
 
-      self.loss = self.before_loss + self.after_loss + self.regularization_loss + self.stop_error
+      self.loss = self.before_loss + self.after_loss + self.stop_token_loss + self.regularization_loss
 
   def add_optimizer(self, global_step):
    '''Adds optimizer. Sets "gradients" and "optimize" fields. add_loss must have been called.
@@ -148,11 +160,17 @@ def add_optimizer(self, global_step):
      else:
        self.learning_rate = tf.convert_to_tensor(hp.initial_learning_rate)
 
-      self.optimize = tf.train.AdamOptimizer(self.learning_rate,
-        hp.adam_beta1,
-        hp.adam_beta2,
-        hp.adam_epsilon).minimize(self.loss,
-        global_step=global_step)
+      optimizer = tf.train.AdamOptimizer(self.learning_rate, hp.adam_beta1, hp.adam_beta2, hp.adam_epsilon)
+      gradients, variables = zip(*optimizer.compute_gradients(self.loss))
+      self.gradients = gradients
+      #Clip the gradients to avoid rnn gradient explosion
+      clipped_gradients, _ = tf.clip_by_global_norm(gradients, 1.0)
+
+      # Add dependency on UPDATE_OPS; otherwise batchnorm won't work correctly. See:
+      # https://github.com/tensorflow/tensorflow/issues/1122
+      with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
+        self.optimize = optimizer.apply_gradients(zip(clipped_gradients, variables),
+          global_step=global_step)
 
   def _learning_rate_decay(self, init_lr, global_step):
    # Exponential decay starting after 50,000 iterations
diff --git a/tacotron/preprocess.py b/tacotron/preprocess.py
index 08bcd512..46d8d689 100644
--- a/tacotron/preprocess.py
+++ b/tacotron/preprocess.py
@@ -27,7 +27,7 @@ def write_metadata(metadata, out_dir):
 def main():
   parser = argparse.ArgumentParser()
   parser.add_argument('--base_dir', default=os.path.dirname(os.path.realpath(__file__)))
-  parser.add_argument('--input', default='LJSpeech-1.0')
+  parser.add_argument('--input', default='LJSpeech-1.1')
   parser.add_argument('--output', default='training')
   parser.add_argument('--n_jobs', type=int, default=cpu_count())
   args = parser.parse_args()
diff --git a/tacotron/synthesize.py b/tacotron/synthesize.py
index 795cafd6..86d8acc4 100644
--- a/tacotron/synthesize.py
+++ b/tacotron/synthesize.py
@@ -17,9 +17,13 @@ def run_eval(args, checkpoint_path):
   #Create output path if it doesn't exist
   os.makedirs(eval_dir, exist_ok=True)
 
-  for i, text in enumerate(tqdm(hparams.sentences)):
-    start = time.time()
-    synth.synthesize(text, i, eval_dir, None)
+  with open(os.path.join(eval_dir, 'map.txt'), 'w') as file:
+    file.write('"input"|"generated_mel"\n')
+    for i, text in enumerate(tqdm(hparams.sentences)):
+      start = time.time()
+      mel_filename = synth.synthesize(text, i+1, eval_dir, None)
+
+      file.write('"{}"|"{}"\n'.format(text, mel_filename))
 
   print('synthesized mel spectrograms at {}'.format(eval_dir))
 
 def run_synthesis(args, checkpoint_path):
@@ -42,10 +46,14 @@ def run_synthesis(args, checkpoint_path):
   os.makedirs(synth_dir, exist_ok=True)
 
   print('starting synthesis')
-  for i, meta in enumerate(tqdm(metadata)):
-    text = meta[2]
-    mel_filename = os.path.join(args.input_dir, meta[0])
-    synth.synthesize(text, i, synth_dir, mel_filename)
+  with open(os.path.join(synth_dir, 'map.txt'), 'w') as file:
+    file.write('"input"|"frames"|"target_mel"|"generated_mel"\n')
+    for i, meta in enumerate(tqdm(metadata)):
+      text = meta[2]
+      mel_filename = os.path.join(args.input_dir, meta[0])
+      mel_output_filename = synth.synthesize(text, i+1, synth_dir, mel_filename)
+
+      file.write('"{}"|"{}"|"{}"|"{}"\n'.format(text, meta[1], mel_filename, mel_output_filename))
 
   print('synthesized mel spectrograms at {}'.format(synth_dir))
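For reference, an eval-mode map.txt produced by the loop above would look roughly like this; the paths are illustrative placeholders, with the actual file names coming from synthesizer.py below:

"input"|"generated_mel"
"Generative adversarial network or variational auto-encoder."|"<eval_dir>/ljspeech-mel-00007.npy"
"Tajima Airport serves Toyooka."|"<eval_dir>/ljspeech-mel-00016.npy"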
diff --git a/tacotron/synthesizer.py b/tacotron/synthesizer.py
index 4f205690..4718c541 100644
--- a/tacotron/synthesizer.py
+++ b/tacotron/synthesizer.py
@@ -44,6 +44,8 @@ def synthesize(self, text, index, out_dir, mel_filename):
    mels = self.session.run(self.mel_outputs, feed_dict=feed_dict)
 
    # Write the spectrogram to disk
-    mel_filename = 'ljspeech-mel-{:05d}.npy'.format(index)
-    np.save(os.path.join(out_dir, mel_filename), mels, allow_pickle=False)
+    #Note: output mel-spectrogram files and input ones have the same names, just in different folders
+    mel_filename = os.path.join(out_dir, 'ljspeech-mel-{:05d}.npy'.format(index))
+    np.save(mel_filename, mels, allow_pickle=False)
+
+    return mel_filename
\ No newline at end of file
diff --git a/tacotron/train.py b/tacotron/train.py
index 156200aa..7abdbe31 100644
--- a/tacotron/train.py
+++ b/tacotron/train.py
@@ -22,7 +22,7 @@ def add_stats(model):
    tf.summary.scalar('before_loss', model.before_loss)
    tf.summary.scalar('after_loss', model.after_loss)
    tf.summary.scalar('regularization_loss', model.regularization_loss)
-    tf.summary.scalar('stop_token_loss', model.stop_error)
+    tf.summary.scalar('stop_token_loss', model.stop_token_loss)
    tf.summary.scalar('loss', model.loss)
    return tf.summary.merge_all()