Tacotron2 + WaveRNN experiments #26

Closed
8 tasks done
erogol opened this issue May 6, 2018 · 80 comments
Labels: experiment (experimental things), improvement (a new feature)


erogol commented May 6, 2018

Tacotron2: https://arxiv.org/pdf/1712.05884.pdf
WaveRNN: https://github.com/erogol/WaveRNN forked from https://github.com/fatchord/WaveRNN

The idea is to add Tacotron2 as another alternative, if it proves more useful than the current model.

  • Code the boilerplate Tacotron2 architecture.
  • Train Tacotron2 and compare results (baseline).
  • Train the current TTS model at a size comparable to T2. (The current TTS model has 7M parameters and Tacotron2 has 28M.)
  • Add TTS-specific architectural changes to T2 and compare with the baseline.
  • Train WaveRNN as a vocoder on generated spectrograms.
  • Train a better stopnet. The stopnet sometimes misses the stop prediction, which leads to unstable outputs. Maybe it is better to use an RNN as in the previous TTS version (see the sketch after this list).
  • Release the LJSpeech Tacotron2 model. (soon)
  • Release the LJSpeech WaveRNN model. (https://github.com/erogol/WaveRNN)
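A minimal sketch of what an RNN-based stopnet could look like (a hypothetical module, not the repo's implementation; layer sizes are assumptions): a GRUCell runs over decoder output frames and emits a stop logit per frame, trained with a binary cross-entropy loss.

    import torch
    import torch.nn as nn

    class RNNStopNet(nn.Module):
        def __init__(self, in_dim, hidden_dim=256):
            super().__init__()
            self.rnn = nn.GRUCell(in_dim, hidden_dim)
            self.fc = nn.Linear(hidden_dim, 1)

        def forward(self, frame, hidden):
            # frame: [B, in_dim] decoder output for the current step
            hidden = self.rnn(frame, hidden)
            stop_logit = self.fc(hidden)  # train with nn.BCEWithLogitsLoss
            return stop_logit, hidden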

Best result so far: https://soundcloud.com/user-565970875/ljspeech-logistic-wavernn

Some findings:

  • Adding an entropy loss on the attention weights seems to help cases where the alignment is hard to learn. It forces the network to learn sparser, noise-free alignment weights:
entropy = torch.distributions.Categorical(probs=alignments).entropy()
# normalize and average, then add to the main loss with a small weight
entropy_loss = (entropy / np.log(alignments.shape[1])).mean()
loss += 1e-4 * entropy_loss

Here is the alignment with the entropy loss. However, if you keep the loss weight too high, it degrades the model's generalization to new words.
[attention alignment plot with entropy loss]

  • Replacing the Prenet with a BatchNorm version enhances the performance quite a lot (see the sketch after this list).
  • A network with a BN Prenet has a harder time learning the attention. It looks like the network needs some noise on the autoregressive connection in order to relate the encoder output to the network output. Otherwise, in teacher-forcing mode, the network does not need the encoder output, since the previous prediction frame is enough to generate the next frame.
  • Forward attention seems more robust to longer sequences and faster to align. (https://arxiv.org/abs/1807.06736)
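A minimal sketch of the BatchNorm Prenet idea (assumed layer sizes; not the exact TTS module): each linear layer is followed by BatchNorm and ReLU instead of dropout, which removes the noise the original Prenet injected into the autoregressive connection.

    import torch.nn as nn

    class BNPrenet(nn.Module):
        def __init__(self, in_dim, sizes=(256, 128)):
            super().__init__()
            layers = []
            for out_dim in sizes:
                layers += [nn.Linear(in_dim, out_dim),
                           nn.BatchNorm1d(out_dim),
                           nn.ReLU()]
                in_dim = out_dim
            self.net = nn.Sequential(*layers)

        def forward(self, x):
            # x: [B, in_dim] mel frame fed back on the autoregressive connection
            return self.net(x)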
@erogol added the improvement label on May 18, 2018
@erogol changed the title from "Update model to Tacotron2." to "Adding Tacotron2" on Mar 1, 2019
erogol commented Mar 1, 2019

Tacotron2 baseline implementation aligns after 35K iterations.
[attention alignment plot after 35K iterations]

It also gives a better mel-spectrogram L1 loss compared to Tacotron1 (Tacotron2 trained for 150K iterations, Tacotron1 for 320K):

Tacotron1 loss: 0.34
Tacotron2 decoder loss: 0.3115
Tacotron2 postnet loss: 0.2985

erogol commented Mar 1, 2019

Tacotron2 with the TTS updates aligns much faster, after 14K iterations. It also reaches the same loss values that the baseline implementation only reaches after 100K.


Here is the attention alignment after 14K

[attention alignment plot after 14K iterations]

@erogol changed the title from "Adding Tacotron2" to "Tacotron2 experimentw" on Mar 1, 2019
@erogol changed the title from "Tacotron2 experimentw" to "Tacotron2 experiments" on Mar 1, 2019
@erogol added the experiment label on Mar 1, 2019
erogol commented Mar 4, 2019

Here is the first WaveRNN based result.

https://soundcloud.com/user-565970875/wavernn-taco2


@erogol changed the title from "Tacotron2 experiments" to "Tacotron2 + WaveRNN experiments" on Mar 7, 2019
erogol commented Mar 7, 2019

Here is a pocket article read by the model (WaveRNN + Tacotron2):
Tacotron2 - 90K iterations
WaveRNN - 550K iterations
https://soundcloud.com/user-565970875/pocket-article-wavernn-and-tacotron2

@ZohaibAhmed

@erogol Awesome results. Did you train Tacotron2 with a batch size of 32 and r=1? Was it done on multi-GPU or just one?

G-Wang commented Mar 7, 2019

@erogol Fatchord's WaveRNN works very well with 10-bit audio as well, which helps to eliminate quite a bit of the background static.

erogol commented Mar 9, 2019

@ZohaibAhmed I trained it on 4 GPUs with a batch size of 16 per GPU, which is the max I can fit into a 1080 Ti. Much better results are coming through with a couple of architectural changes.

@G-Wang thanks for pointing that out. I am going to try it. I have only trained WaveRNN once so far, and now I plan to explore some model updates and quantization schemes. I'll post updates here.

m-toman commented Mar 9, 2019

I'm currently pretty happy with https://github.com/h-meru/Tacotron-WaveRNN/blob/master/README.md

I struggled for some time when I integrated a similar model (something between that and the Amazon universal-vocoder architecture), but in the end the reason was simply that the Noam LR schedule performed much worse than a fixed LR with Adam, which in the worst case I could reduce a little manually when the loss starts to behave strangely.
But for LJSpeech the loss decreases textbook-like.
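For reference, the Noam schedule mentioned above follows roughly the formula in this sketch (a generic version; the warmup and d_model values are assumptions, not the repo's settings):

    import torch

    def noam_lambda(step, warmup=4000, d_model=512):
        # rises linearly for `warmup` steps, then decays proportionally to step**-0.5
        step = max(step, 1)
        return d_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)

    params = [torch.nn.Parameter(torch.zeros(1))]  # stand-in for model.parameters()
    optimizer = torch.optim.Adam(params, lr=1.0)   # base LR is scaled by the lambda
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=noam_lambda)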

erogol commented Mar 11, 2019

Training WaveRNN with 10-bit quantization. Also, the final solution in #50 improved the model performance significantly. Results to be shared soon...

G-Wang commented Mar 11, 2019

@erogol I've tried a few variants such as mu-law, Gaussian/beta output, and mixture of logistics, but at the end of the day I found that 9- and 10-bit gave the best results and the fastest training. You can hear some samples here: https://github.com/G-Wang/WaveRNN-Pytorch
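For context, a minimal sketch of the N-bit quantization and mu-law companding being compared here (generic formulas, not code from either repo):

    import numpy as np

    def quantize(x, bits=10):
        # x in [-1, 1] -> integer classes in [0, 2**bits - 1]
        return np.clip((x + 1.0) * (2 ** bits - 1) / 2, 0, 2 ** bits - 1).astype(np.int64)

    def mu_law_encode(x, mu=255):
        # standard mu-law companding: quieter samples get finer quantization steps
        return np.sign(x) * np.log1p(mu * np.abs(x)) / np.log1p(mu)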

geneing commented Mar 11, 2019

@erogol There is also a highly optimized PyTorch and C++ implementation, which I am training on the output of the TTS extract_features script: https://github.com/geneing/WaveRNN-Pytorch

@erogol Is your solution to #50 checked into the master branch?

erogol commented Mar 11, 2019

@geneing that's cool. After a certain point, we can definitely consult your repo for inference optimization. @reuben might also find it interesting.

@geneing not yet; it needs a bit of testing, but it will come soon with the trained model.

@OswaldoBornemann

@erogol thanks, erogol. When will we be able to train the Tacotron2 + WaveRNN model on our own datasets?

mrgloom commented Mar 12, 2019

Is any pretrained WaveRNN model available for testing?

erogol commented Mar 14, 2019

This is the latest result:
10-bit WaveRNN trained for 400K iterations, initialized from the 9-bit model.
Tacotron2 trained for 170K iterations - postnet loss: 0.019, decoder loss: 0.023.

Problems:

  • There is a high peak on the "s" sound due to the dataset quality.
  • Static background noise due to the 10-bit quantization.
  • Some intonation problems due to faulty Tacotron2 alignment. Maybe it is better to use an earlier checkpoint.

Below is a comparison of BN vs. dropout Prenets, as explained in #50.

[BN vs. dropout Prenet comparison plot]

Audio Sample:
https://soundcloud.com/user-565970875/commonvoice-1

Audio Sample with the best GL based model:
https://soundcloud.com/user-565970875/sets/ljspeech-model-185k-iters-commit-db7f3d3

m-toman commented Mar 14, 2019

Interesting results!
Do you use ground-truth-aligned mel specs generated by Tacotron for training the neural vocoder, or do you train it on features extracted from the recordings?

erogol commented Mar 14, 2019

@m-toman I don't see what you mean by the second option.

I trained it with the specs from Tacotron with teacher forcing.

m-toman commented Mar 15, 2019

Thanks. By the second option I mean just calculating mel specs from the recordings and then training on those, without using Tacotron at all.
This had a lower MOS in the WaveNet paper, so I wondered which approach you used.

In the Amazon universal vocoder paper they did that (I assume) because they train on lots of data from many different speakers and recording conditions, so it's trickier to use GTA mel specs (although you could do it with a very large multi-speaker model, I suppose).
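For reference, extracting mel specs directly from the recordings (the "second option") looks roughly like this sketch; the STFT/mel parameters here are illustrative assumptions, not the repo's config:

    import librosa
    import numpy as np

    wav, sr = librosa.load("sample.wav", sr=22050)
    mel = librosa.feature.melspectrogram(y=wav, sr=sr, n_fft=1024,
                                         hop_length=256, n_mels=80)
    log_mel = np.log(np.clip(mel, 1e-5, None))  # [80, frames]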

erogol commented Mar 15, 2019

@m-toman I was not aware that they showed better results using ground-truth mel specs. However, it does not really make sense to me (without any experimentation), since I believe the vocoder is able to learn to compensate for the mistakes made by the first network when it is trained on synthesized specs. But I guess to be sure, I need to try it first.

m-toman commented Mar 15, 2019

No, they were worse (lower MOS on a 1-5 scale, with 5 as the best score), as expected.
Just checked: it wasn't in the WaveNet paper, it was in the Taco2 paper,
https://arxiv.org/pdf/1712.05884.pdf, section 3.3.1:

"As expected, the best performance is obtained when the features
used for training match those used for inference. However, when
trained on ground truth features and made to synthesize from predicted features, the result is worse than the opposite. This is due to
the tendency of the predicted spectrograms to be oversmoothed and
less detailed than the ground truth – a consequence of the squared
error loss optimized by the feature prediction network. When trained
on ground truth spectrograms, the network does not learn to generate
high quality speech waveforms from oversmoothed features."

So that's fine.
Is there a script in the repo to produce the teacher forced data for the training set?
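For illustration, GTA (teacher-forced) extraction conceptually looks like the sketch below. This is a hypothetical outline rather than the repo's script (the repo's ExtractTTSpectrogram.ipynb serves this purpose), and the model forward signature, loader fields, and output paths are assumptions:

    import numpy as np
    import torch

    def extract_gta_mels(model, loader, out_dir="gta_mels"):
        # dump teacher-forced (GTA) mel predictions for every training utterance
        model.eval()
        with torch.no_grad():
            for text, text_lengths, mel, utt_ids in loader:
                # teacher forcing: the decoder consumes ground-truth mel frames,
                # so the predicted frames stay time-aligned with the original audio
                _, postnet_out, _, _ = model(text, text_lengths, mel)
                for spec, utt_id in zip(postnet_out.cpu().numpy(), utt_ids):
                    np.save(f"{out_dir}/{utt_id}.npy", spec)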

ljackov commented May 12, 2019

Can you give me any hints on porting the Tacotron model to OpenCV?

erogol commented May 12, 2019

sorry, no idea.

ljackov commented May 16, 2019

Hi again,

Is there a reason to pad externally with ConstantPad1d and not use the Conv1d padding in Tacotron/CBHG/BatchNorm1d? If there is no difference, I think using the Conv1d padding would be faster.

erogol commented May 16, 2019

Conv1d does not support asymmetric padding. And please don't post unrelated questions in random places.
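To illustrate the point (a self-contained example, not the repo's code): with an even kernel size, keeping the output length equal to the input requires asymmetric padding, which Conv1d's single padding argument cannot express, hence the external ConstantPad1d.

    import torch
    import torch.nn as nn

    x = torch.randn(1, 80, 50)  # [batch, channels, time]

    # symmetric padding: padding=2 on both sides with kernel_size=4 gives 51 frames
    sym = nn.Conv1d(80, 80, kernel_size=4, padding=2)

    # asymmetric padding: 1 frame left, 2 frames right keeps the length at 50
    asym = nn.Sequential(nn.ConstantPad1d((1, 2), 0.0),
                         nn.Conv1d(80, 80, kernel_size=4))

    print(sym(x).shape)   # torch.Size([1, 80, 51])
    print(asym(x).shape)  # torch.Size([1, 80, 50])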

mrgloom commented May 17, 2019

I have tested the Tacotron2 + WaveRNN models on CPU. Checking np.sum(mel_specs) and np.sum(waveform), I get a different output on each run; the Tacotron2 results are consistent, so it is the WaveRNN results that differ across runs.
As I understand it, the source of randomness is in sample_from_discretized_mix_logistic, so I added

torch.manual_seed(0)
torch.cuda.manual_seed_all(0)

to https://github.com/erogol/WaveRNN/blob/master/utils/distribution.py, and this fixes the reproducibility issue.

mrgloom commented May 17, 2019

But if I call the same model on the same input twice in a row, I still don't get the same result. Does it have some internal state that should be released before the second run?

As I can see here:
https://github.com/erogol/WaveRNN/blob/master/models/wavernn.py#L108
the model uses nn.GRU layers. I'm not sure why nn.GRU is converted to nn.GRUCell here:
https://github.com/erogol/WaveRNN/blob/master/models/wavernn.py#L227

Also, it seems the hidden state is set to zeros here, so do I need to set it to zeros before calling the generate function?
https://github.com/erogol/WaveRNN/blob/master/models/wavernn.py#L117

Update:
The problem was the random seed; to fix it:

    def generate(self, mels, batched_inference, target, overlap):
        # Fix random seed for reproducible results
        seed = 42
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        np.random.seed(seed)

@xinqipony

(quoting an earlier comment) "…al model (dropout prenet) after attention gets aligned, I switch to BN prenet. Here is the difference be…"

Hey Eren,
when you "switch to BN prenet" and continue training, do you make any changes to the LR, or do you just use the restored LR?

erogol commented May 21, 2019

Hey! ... no change.

mrgloom commented May 31, 2019

I'm trying to reproduce the WaveRNN results on a single GPU, and it's quite slow (i.e. days of training).

Here are my samples at checkpoint_230000 vs. checkpoint_433000 (provided by @erogol) for the same predicted mel spectrogram from Tacotron2: examples.zip

I wonder whether this meaningless speech is normal and just due to the smaller number of steps, or whether something is really broken.

erogol commented Jun 3, 2019

Something is really broken. I guess the character symbol order in the code does not match the model.

mrgloom commented Jun 3, 2019

Yes, it is still broken at 600K steps, but I'm training the WaveRNN model only (on mel spectrograms prepared by the Tacotron2 model), so as I understand it the character symbol order is not relevant here, because I only have mel spectrograms and wav files.
https://github.com/erogol/WaveRNN/blob/master/dataset.py#L17

Also, I have checked the prepared mel-spectrogram *.npy files with Griffin-Lim and with the checkpoint_433000 model, and in both cases they work fine.

mrgloom commented Jun 3, 2019

In ExtractTTSpectrogram.ipynb, for the LJSpeech data preparation I used the config.json from the ljspeech-260k Tacotron2 model, which can affect the AudioProcessor; however, the only difference I can see is the do_trim_silence parameter. Can it cause problems?

In Tacotron2 config:
"do_trim_silence": true // enable trimming of slience of audio as you load it. LJspeech (false), TWEB (false), Nancy (true)

In WaveRNN config:
"do_trim_silence": false // KEEP ALWAYS FALSE
https://github.com/erogol/WaveRNN/blob/master/config.json#L25

Also, I have tried simply cloning the master branch of WaveRNN and using the pretrained model:

git clone https://github.com/erogol/WaveRNN.git WaveRNN-master-test

In models/wavernn.py (changed the relative import to an absolute one):
#from ..utils.distribution import sample_from_gaussian, sample_from_discretized_mix_logistic
from utils.distribution import sample_from_gaussian, sample_from_discretized_mix_logistic

CUDA_VISIBLE_DEVICES=1 python train.py --data_path /data_large/tts-datasets/LJSpeech-1.1-wavernn/ --output_path experiment_temp --restore_path ./mold_ljspeech_best_model/checkpoint_433000.pth.tar --config_path ./mold_ljspeech_best_model/config.json

The next saved checkpoint, checkpoint_434000.pth.tar, produces garbage, but ./mold_ljspeech_best_model/checkpoint_433000.pth.tar works fine.

mrgloom commented Jun 3, 2019

Yes, it seems the problem was do_trim_silence; running from the pretrained model gives sensible results now, and the loss is about 5.2.

But why does WaveRNN have problems with do_trim_silence? As I understand it, the predicted mel spectrogram from Tacotron2 is not obliged to be "aligned" with the wav file.

erogol commented Jun 5, 2019

I am closing this thread, since we have solved Tacotron2 + WaveRNN.

mrgloom commented Jun 7, 2019

It seems about 300K steps is sufficient:
wavernn_samples.zip

@erogol Can you elaborate on the do_trim_silence parameter?

erogol commented Jun 7, 2019

It trims silences at the beginning and the end by thresholding.
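For illustration, threshold-based trimming of this kind can be done with librosa (a minimal sketch; the actual AudioProcessor implementation and threshold may differ):

    import librosa

    wav, sr = librosa.load("sample.wav", sr=22050)
    # drop leading/trailing audio whose energy is more than 60 dB below the peak
    trimmed, _ = librosa.effects.trim(wav, top_db=60)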

chynphh commented Jan 14, 2020

When training WaveRNN with mels from Tacotron2, which wavs do you use: the ground-truth wav files or the wavs generated by Tacotron2?

When I use the ground-truth wavs, there are bugs in MyDataset.collate at coarse = np.stack(coarse).astype(np.float32): the items in coarse have different shapes. I think this is due to wrong sig_offsets; the slicing operation x[1][sig_offsets[i] : sig_offsets[i] + seq_len + 1] sometimes goes outside the range of x[1].

But I don't know how to fix it.
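A minimal sketch of one possible guard (a hypothetical helper, not the repo's code): pad the waveform before slicing so the window can never run past the end when the mel and wav lengths disagree. Whether this is appropriate depends on why the lengths disagree in the first place (e.g. the silence-trimming mismatch discussed above).

    import numpy as np

    def safe_slice(wav, offset, seq_len):
        # pad with zeros so wav[offset : offset + seq_len + 1] always has full length
        end = offset + seq_len + 1
        if end > len(wav):
            wav = np.pad(wav, (0, end - len(wav)), mode="constant")
        return wav[offset:end]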

@OswaldoBornemann

May I ask how much time it took you to train 300K steps with WaveRNN?

@OswaldoBornemann

@chynphh I faced that problem too. Have you solved it yet?

@Ivona221

(quoting @chynphh's question above about ground-truth wavs and the MyDataset.collate shape mismatch)

I am facing the same issue. Is it solved yet?

erogol commented Mar 27, 2020

It might be about trimming the silence. Try disabling it if it is enabled, or the other way around.

@ethanstan

@Ivona221 @chynphh I am having the same issue. I tried disabling and enabling the silence trimming as @erogol recommended, but I still get the same error.

Did any of you all find a fix?

coarse = np.stack(coarse).astype(np.float32) ValueError: all input arrays must have the same shape

@Ivona221

Hi @ethanstan, which branches are you using? I found out that you need to use Tacotron2 from the branch https://github.com/mozilla/TTS/tree/Tacotron2-iter-260K-824c091, and I think I used WaveRNN from the master branch.

@ethanstan

@Ivona221 I'm using the current TTS tacotron 2 master branch. Does this mean I need to retrain from that branch? I'm pretty happy with the output from Tacotron2 right now. Do you know if it was something in the config or what changed? Thank you!

@Ivona221

@Ivona221 I'm using the current TTS tacotron 2 master branch. Does this mean I need to retrain from that branch? I'm pretty happy with the output from Tacotron2 right now. Do you know if it was something in the config or what changed? Thank you!

Yes, when I was training I needed to change to that branch and train another model. The results are a bit worse, at least for my dataset (Macedonian language), but it is the only way I made it work.
