Planned TODOs #1
At the moment, I think I have finished implementing the basic features (batch/incremental inference, local/global conditioning) and confirmed that an unconditioned WaveNet trained on CMU ARCTIC (~1200 utterances, 16 kHz) can generate speech-like sounds. Audio samples are attached (top: real speech, bottom: generated speech). Only the first sample of the real speech was fed to the WaveNet decoder as the initial input. |
For reference, these are other wavenet projects I know of: |
Still not quite high quality, but the vocoder conditioned on mel-spectrograms has started to work. Audio samples from a model trained for 10 hours are attached. |
Finished transposed convolution support at 8c0b5a9. Started training again. |
Hi, I've already tried to use linguistic features as local features, but I found there might be a problem: linguistic features are at the phoneme level and mel-spectrograms are at the frame level, while the local conditioning inputs to WaveNet are at the sample level. Here is a case: if a phoneme's duration is 0.25 s and the sample rate is 16 kHz, then to create the WaveNet inputs I have to duplicate that phoneme's linguistic feature vector int(0.25 * 16000) times to serve as the local features of its samples. Do you think this practice is right or not? How do you process the mel-spectrogram features given that they are at the frame level? Thanks for answering me. |
Can WaveNet capture the differences even if many samples share the same local features, as long as its receptive field is wide enough? |
@jamestang0219 I think you are right. In the paper http://www.isca-speech.org/archive/Interspeech_2017/pdfs/0314.PDF, they use log-F0 and mel-cepstrum as conditional features and duplicate them to adjust the time resolution. I also tried this idea and got reasonable results. |
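For illustration, a minimal NumPy sketch of this duplication approach (the hop size and array shapes here are assumptions for the example, not values taken from the thread):

import numpy as np

sr = 16000
hop_size = 256                        # one conditioning frame every 256 samples (assumed)
mel = np.random.randn(100, 80)        # (n_frames, n_mels) -- placeholder features

# Repeat each frame-level feature vector so that every audio sample gets one:
mel_upsampled = np.repeat(mel, hop_size, axis=0)   # shape (100 * 256, 80)

# A phoneme-level feature with a 0.25 s duration would likewise be repeated
# int(0.25 * sr) = 4000 times to cover its samples.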
Latest audio sample attached. Mel-spectrograms are repeated to adjust the time resolution. See Lines 39 to 40 in b8ee2ce
|
@r9y9 In your source code, do you use transposed convolution to implement the upsampling? Have you checked which method is better for upsampling? |
@jamestang0219 I implemented transposed convolution but haven't had success with it yet. I suspect 256x upsampling is hard to train, especially for the small dataset I'm experimenting with now. The WaveNet authors reported that transposed convolution is better, though. |
Lines 43 to 47 in 3c9deb1
For now I am not using transposed convolution. |
@r9y9 May I ask your hyperparameters for extracting the mel-spectrogram? Is the frame shift 0.0125 s and the frame width 0.05 s? If those are your parameters, why do you use 256 as the upsample factor instead of sr (16000) * frame_shift (0.0125) = 200? Any tricks here? Forgive me for the many questions :( I also want to reproduce the Tacotron 2 result. |
@jamestang0219 Hyper parameters for audio parameter extraction: Lines 19 to 28 in 3c9deb1
I use a frame shift of 256 samples (256 / 16000 = 16 ms). |
@r9y9 Thanks:) |
@r9y9 I notice that Tacotron 2 uses two upsampling layers with transposed convolution, but in my WaveNet implementation it still doesn't work. |
@npuichigo Could you share what parameters (padding, kernel_size, etc.) you are using? I applied a 1D transposed convolution with stride=16, kernel_size=16, padding=0 twice to upsample the inputs by 256x. wavenet_vocoder/wavenet_vocoder/wavenet.py Lines 105 to 112 in 8c0b5a9
|
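For reference, a rough PyTorch sketch of the configuration described above (the channel count and input shape are illustrative; this is not the repository code):

import torch
from torch import nn

# Two 1D transposed convolutions with stride=16, kernel_size=16, padding=0,
# giving 16 * 16 = 256x upsampling of the conditioning features along time.
upsample = nn.Sequential(
    nn.ConvTranspose1d(80, 80, kernel_size=16, stride=16, padding=0),
    nn.ConvTranspose1d(80, 80, kernel_size=16, stride=16, padding=0),
)
c = torch.randn(1, 80, 50)      # (batch, n_mels, frames) -- assumed shape
print(upsample(c).shape)        # torch.Size([1, 80, 12800]) == 50 frames * 256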
@r9y9 My parameters are listed below. Because I use a frame shift of 12.5 ms, the upsampling factor is 16000 * 0.0125 = 200.
|
https://r9y9.github.io/wavenet_vocoder/ Created a simple project page and uploaded audio samples for speaker-dependent WaveNet vocoder. I'm working on global conditioning (speaker embedding) now. |
@r9y9 Regarding the upsampling network, I found that 2D transposed convolution works well, while the 1D version generates speech with unnatural prosody, maybe because the 2D transposed convolution only considers local information in the frequency domain.

import tensorflow as tf  # TF 1.x API (tf.layers)

height_width = 3  # kernel width along the frequency axis
# lc_batch: (batch, frames, n_mels) -> add a channel axis for conv2d_transpose
up_lc_batch = tf.expand_dims(lc_batch, 3)
# Two stages of upsampling along the time axis: 10x then 20x = 200x total,
# matching the 200-sample (12.5 ms) frame shift.
up_lc_batch = tf.layers.conv2d_transpose(
    up_lc_batch, 1, (10, height_width),
    strides=(10, 1), padding='SAME',
    kernel_initializer=tf.constant_initializer(1.0 / height_width))
up_lc_batch = tf.layers.conv2d_transpose(
    up_lc_batch, 1, (20, height_width),
    strides=(20, 1), padding='SAME',
    kernel_initializer=tf.constant_initializer(1.0 / height_width))
up_lc_batch = tf.squeeze(up_lc_batch, 3) |
@npuichigo Thank you for sharing that! Did you check the output of the upsampling network? Could the upsampling network actually learn upsampling? I mean, did you get a high-resolution mel-spectrogram? I was wondering if I need to add a loss term for the upsampling (e.g., MSE between the coarse mel-spectrogram and the 1-shift high-resolution mel-spectrogram), and I'm curious whether it can be learned without an upsampling-specific loss. |
@r9y9 I think a transposed convolution with the same stride and kernel size is similar to duplicating. As in the picture I attached, if the kernel is one everywhere, then it's just duplicating. So maybe I need to check the kernel values after training. |
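A quick way to verify that claim (a minimal sketch; the sizes are illustrative):

import torch
from torch import nn

# With kernel_size == stride and all-ones weights, a transposed convolution
# simply repeats each input value stride times, i.e. nearest-neighbor duplication.
conv = nn.ConvTranspose1d(1, 1, kernel_size=4, stride=4, bias=False)
conv.weight.data.fill_(1.0)

x = torch.tensor([[[1.0, 2.0, 3.0]]])   # (batch, channels, frames)
print(conv(x))                          # [[[1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3]]]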
https://r9y9.github.io/wavenet_vocoder/ Added audio samples for multi-speaker version of WaveNet vocoder. |
Hello @r9y9, great work and awesome samples. Would you mind sharing the weights of the wavenet_vocoder network trained on mel-spectrograms with the CMU ARCTIC dataset, without speaker embedding? I would like to compare them with Griffin-Lim reconstruction to see which works better. |
@rishabh135 Not at all. Here it is: https://www.dropbox.com/sh/b1p32sxywo6xdnb/AAB2TU2DGhPDJgUzNc38Cz75a?dl=0 Note that you have to use exactly the same mel-spectrogram extraction: Lines 66 to 69 in f05e520
Lines 20 to 28 in f05e520
|
Using the transposed convolution below, I can get a good initialization for the upsampling network. Very nice, thanks @npuichigo !

from torch import nn

kernel_size = 3
padding = (kernel_size - 1) // 2
upsample_factor = 16
conv = nn.ConvTranspose2d(1, 1, kernel_size=(kernel_size, upsample_factor),
                          stride=(1, upsample_factor), padding=(padding, 0))
conv.bias.data.zero_()
conv.weight.data.fill_(1.0 / kernel_size)

(Figures: mel-spectrogram with hop_size = 256, and the 16x upsampled mel-spectrogram.) |
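For illustration, a hypothetical usage of the layer defined above, continuing from that `conv` (the mel shape is an assumption):

import torch

mel = torch.randn(1, 80, 100)              # (batch, n_mels, frames) -- placeholder
up = conv(mel.unsqueeze(1)).squeeze(1)     # add/remove the channel dimension
print(up.shape)                            # torch.Size([1, 80, 1600]): 16x more time steps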
https://r9y9.github.io/wavenet_vocoder/ Updated samples of the multi-speaker WaveNet. Used a mixture of logistic distributions. It was quite costly to train. Also added ground-truth audio samples for ease of comparison. |
@r9y9 What do you mean by costly to train? What are the biggest challenges? |
I meant it's very time-consuming. It took a week or more to get sufficiently good quality for LJSpeech and CMU ARCTIC. |
Can you share the loss curve? |
I'm on a short business trip and do not have access to my GPU PC right now. I can share it when I come back home in a week. |
That's great, Ryuichi! Thank you! |
In the original Salimans PixelCNN++ code, the loss is converted to bits per output dimension, which is actually quite handy for comparison with other implementations and experiments. For this, just divide the loss by the dimensionality of the output times ln(2). How many bits is the model able to predict? |
This is unclear to me, probably because I haven't read the paper. You're saying that n_bits = loss / C, that is, the higher the loss, the more bits the model can output?
|
The loss is the negative log probability, and averaged over the output dimension it is an estimate of the entropy in a sample. In the original paper (predicting pixels in an image) the residual entropy was around 3 bits (out of 8, so predicting 5 bits). Since it is not easy for me to figure out the output dimension of this wavenet implementation, a loss of 56-57 doesn't tell me much. |
I see now, it's just the loss but normalized to bits, thus facilitating comparison as you mentioned! From what I understand, the model has a mixture of 10 logistics with 3 params each (pi, mean, log-scale), producing a total of 30 channels. This is what I understand from what @r9y9 has on the |
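For illustration, a small sketch of the conversion described above (the example numbers follow the PixelCNN++ image case mentioned earlier; nothing here comes from this repository):

import math

def nats_to_bits_per_dim(loss_nats, output_dim):
    # Convert a negative log-likelihood summed over `output_dim` dimensions
    # (in nats) to bits per output dimension.
    return loss_nats / (output_dim * math.log(2))

# e.g. an image NLL of ~6700 nats summed over 32 * 32 * 3 dimensions:
print(nats_to_bits_per_dim(6700.0, 32 * 32 * 3))   # ~3.1 bits per dimension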
@r9y9 |
Yes, current master is the latest one and it is what I have locally. Maybe the training procedure I described in #1 (comment) is important for quality. |
@r9y9 would you mind re-sharing your weights for the mel-conditioned wavenet? The link you shared earlier is broken. Thanks! |
@r9y9 For multi-GPU training, I found that we only need to fix Line 639 in 26e4305
y_hat = torch.nn.parallel.data_parallel(model, (x, c, g, False)) and increase num_workers and batch_size. We can also set device_ids and output_device via different command-line args
|
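For reference, a hedged sketch of that change (model, x, c, and g are the objects used in the training script and are not defined here; the device ids are illustrative and would come from command-line args):

import torch

device_ids = [0, 1]                       # assumed GPU ids
y_hat = torch.nn.parallel.data_parallel(
    model, (x, c, g, False),              # same positional inputs as the single-GPU call
    device_ids=device_ids,
    output_device=device_ids[0])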
Efficient Neural Audio Synthesis https://arxiv.org/abs/1802.08435 |
https://github.com/r9y9/wavenet_vocoder#pre-trained-models Added link to pre-trained models. |
Hi @r9y9, Thank you so much for sharing your work. We have followed it and got some results in TensorFlow. While we have not tested much yet, it works with the same parameters as yours, except without the dropout and weight-normalization techniques. You can find some results here. If I get more information during testing, I'll let you know. Thanks! |
@twidddj Nice! I'm looking forward to your results. |
I think I can close this now. Discussion on remaining issues (e.g., DeepVoice + WaveNet) can continue in specific issues. |
This is an umbrella issue to track progress for my planned TODOs. Comments and requests are welcome.
Goal
Model
Training script
Experiments
Misc
[ ] Time sliced data generator?
Sampling frequency
Advanced (lower priority)