-
Notifications
You must be signed in to change notification settings - Fork 0
/
radtts.py
752 lines (647 loc) · 33.4 KB
/
radtts.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: MIT
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
import torch
from torch import nn
from common import Encoder, LengthRegulator, ConvAttention
from common import Invertible1x1ConvLUS, Invertible1x1Conv
from common import AffineTransformationLayer, LinearNorm, ExponentialClass
from common import get_mask_from_lengths
from attribute_prediction_model import get_attribute_prediction_model
from alignment import mas_width1 as mas
class FlowStep(nn.Module):
    """A single flow step: an invertible 1x1 convolution paired with an
    affine transformation layer that is conditioned on ``context``.

    Training direction (``inverse`` falsy) applies the convolution first and
    the affine layer second; the inference direction applies them in the
    opposite order.
    """

    def __init__(self, n_mel_channels, n_context_dim, n_layers,
                 affine_model='simple_conv', scaling_fn='exp',
                 matrix_decomposition='', affine_activation='softplus',
                 use_partial_padding=False, cache_inverse=False):
        super().__init__()
        # Only the exact string 'LUS' selects the LUS variant; every other
        # value falls back to the plain invertible convolution.
        conv_cls = (Invertible1x1ConvLUS if matrix_decomposition == 'LUS'
                    else Invertible1x1Conv)
        self.invtbl_conv = conv_cls(
            n_mel_channels, cache_inverse=cache_inverse)

        affine_options = {
            'affine_model': affine_model,
            'scaling_fn': scaling_fn,
            'affine_activation': affine_activation,
            'use_partial_padding': use_partial_padding,
        }
        self.affine_tfn = AffineTransformationLayer(
            n_mel_channels, n_context_dim, n_layers, **affine_options)

    def enable_inverse_cache(self):
        """Turn on the ``cache_inverse`` flag of the 1x1 convolution."""
        self.invtbl_conv.cache_inverse = True

    def forward(self, z, context, inverse=False, seq_lens=None):
        """Run the step in either direction.

        Returns ``(z, log_det_W, log_s)`` in the training direction
        (mel -> z) and just the transformed tensor in the inverse direction
        (z -> mel).
        """
        if not inverse:
            # training: mel -> z
            z, log_det_W = self.invtbl_conv(z)
            z, log_s = self.affine_tfn(z, context, seq_lens=seq_lens)
            return z, log_det_W, log_s

        # inference: z -> mel, undoing the two layers in reverse order
        z = self.affine_tfn(z, context, inverse, seq_lens=seq_lens)
        return self.invtbl_conv(z, inverse)
class RADTTS(torch.nn.Module):
def __init__(self, n_speakers, n_speaker_dim, n_text, n_text_dim, n_flows,
             n_conv_layers_per_step, n_mel_channels, n_hidden,
             mel_encoder_n_hidden, dummy_speaker_embedding, n_early_size,
             n_early_every, n_group_size, affine_model, dur_model_config,
             f0_model_config, energy_model_config, v_model_config=None,
             include_modules='dec', scaling_fn='exp',
             matrix_decomposition='', learn_alignments=False,
             affine_activation='softplus', attn_use_CTC=True,
             use_speaker_emb_for_alignment=False, use_context_lstm=False,
             context_lstm_norm=None, text_encoder_lstm_norm=None,
             n_f0_dims=0, n_energy_avg_dims=0,
             context_lstm_w_f0_and_energy=True,
             use_first_order_features=False, unvoiced_bias_activation='',
             ap_pred_log_f0=False, **kwargs):
    """Build the sub-modules selected by ``include_modules``.

    ``include_modules`` is tested with ``in`` for the tags 'atn', 'dec',
    'dpm', 'apm' and 'vpred'; each tag enables the attention/decoder,
    duration predictor, attribute (f0/energy) predictors and voiced
    predictor respectively.

    Notes grounded in this constructor:
      * ``n_hidden`` and ``mel_encoder_n_hidden`` are accepted but not used
        here.
      * Optional switches are read from ``kwargs``: 'do_mel_descaling',
        'decoder_use_partial_padding', 'decoder_use_unvoiced_bias',
        'ap_use_unvoiced_bias', 'ap_use_voiced_embeddings' (all default
        True) and 'attn_straight_through_estimator' (default False).
      * The ``*_model_config`` dicts are modified in place (their
        'hparams' entries are filled in) before being handed to
        ``get_attribute_prediction_model``.

    Raises:
        AssertionError: if ``n_early_size`` or ``n_speaker_dim`` is odd, or
            if an unvoiced bias is requested with an activation other than
            'relu' / 'exp'.
        ValueError: if ``context_lstm_norm`` names neither 'spectral' nor
            'weight', or if the voiced predictor is needed but
            ``v_model_config`` is None.
    """
    super(RADTTS, self).__init__()
    assert(n_early_size % 2 == 0)
    self.do_mel_descaling = kwargs.get('do_mel_descaling', True)
    self.n_mel_channels = n_mel_channels
    self.n_f0_dims = n_f0_dims  # >= 1 to train with f0
    self.n_energy_avg_dims = n_energy_avg_dims  # >= 1 to train with energy
    self.decoder_use_partial_padding = kwargs.get(
        'decoder_use_partial_padding', True)
    self.n_speaker_dim = n_speaker_dim
    assert(self.n_speaker_dim % 2 == 0)
    self.speaker_embedding = torch.nn.Embedding(
        n_speakers, self.n_speaker_dim)
    self.embedding = torch.nn.Embedding(n_text, n_text_dim)
    self.flows = torch.nn.ModuleList()
    self.encoder = Encoder(encoder_embedding_dim=n_text_dim,
                           norm_fn=nn.InstanceNorm1d,
                           lstm_norm_fn=text_encoder_lstm_norm)
    self.dummy_speaker_embedding = dummy_speaker_embedding
    self.learn_alignments = learn_alignments
    self.affine_activation = affine_activation
    self.include_modules = include_modules
    self.attn_use_CTC = bool(attn_use_CTC)
    self.use_speaker_emb_for_alignment = use_speaker_emb_for_alignment
    self.use_context_lstm = bool(use_context_lstm)
    self.context_lstm_norm = context_lstm_norm
    self.context_lstm_w_f0_and_energy = context_lstm_w_f0_and_energy
    self.length_regulator = LengthRegulator()
    self.use_first_order_features = bool(use_first_order_features)
    self.decoder_use_unvoiced_bias = kwargs.get(
        'decoder_use_unvoiced_bias', True)
    self.ap_pred_log_f0 = ap_pred_log_f0
    self.ap_use_unvoiced_bias = kwargs.get('ap_use_unvoiced_bias', True)
    self.attn_straight_through_estimator = kwargs.get(
        'attn_straight_through_estimator', False)

    if 'atn' in include_modules or 'dec' in include_modules:
        if self.learn_alignments:
            if self.use_speaker_emb_for_alignment:
                self.attention = ConvAttention(
                    n_mel_channels, n_text_dim + self.n_speaker_dim)
            else:
                self.attention = ConvAttention(n_mel_channels, n_text_dim)

        self.n_flows = n_flows
        self.n_group_size = n_group_size

        n_flowstep_cond_dims = (
            self.n_speaker_dim +
            (n_text_dim + n_f0_dims + n_energy_avg_dims) * n_group_size)

        if self.use_context_lstm:
            n_in_context_lstm = (
                self.n_speaker_dim + n_text_dim * n_group_size)
            n_context_lstm_hidden = int(
                (self.n_speaker_dim + n_text_dim * n_group_size) / 2)

            if self.context_lstm_w_f0_and_energy:
                # f0/energy go through the LSTM, so only its input widens.
                # The hidden size is deliberately left alone: the
                # bidirectional output (2 * hidden) must keep matching
                # n_flowstep_cond_dims set just below.
                n_in_context_lstm = (
                    n_f0_dims + n_energy_avg_dims + n_text_dim)
                n_in_context_lstm *= n_group_size
                n_in_context_lstm += self.n_speaker_dim

                n_flowstep_cond_dims = (
                    self.n_speaker_dim + n_text_dim * n_group_size)

            self.context_lstm = torch.nn.LSTM(
                input_size=n_in_context_lstm,
                hidden_size=n_context_lstm_hidden, num_layers=1,
                batch_first=True, bidirectional=True)

            if context_lstm_norm is not None:
                if 'spectral' in context_lstm_norm:
                    print("Applying spectral norm to context encoder LSTM")
                    lstm_norm_fn_pntr = torch.nn.utils.spectral_norm
                elif 'weight' in context_lstm_norm:
                    print("Applying weight norm to context encoder LSTM")
                    lstm_norm_fn_pntr = torch.nn.utils.weight_norm
                else:
                    # previously fell through to an unbound local variable
                    raise ValueError(
                        "context_lstm_norm must contain 'spectral' or "
                        "'weight', got {!r}".format(context_lstm_norm))

                self.context_lstm = lstm_norm_fn_pntr(
                    self.context_lstm, 'weight_hh_l0')
                self.context_lstm = lstm_norm_fn_pntr(
                    self.context_lstm, 'weight_hh_l0_reverse')

        if self.n_group_size > 1:
            self.unfold_params = {'kernel_size': (n_group_size, 1),
                                  'stride': n_group_size,
                                  'padding': 0, 'dilation': 1}
            self.unfold = nn.Unfold(**self.unfold_params)

        self.exit_steps = []
        self.n_early_size = n_early_size
        n_mel_channels = n_mel_channels * n_group_size

        for i in range(self.n_flows):
            if i > 0 and i % n_early_every == 0:  # early exiting
                n_mel_channels -= self.n_early_size
                self.exit_steps.append(i)
            self.flows.append(FlowStep(
                n_mel_channels, n_flowstep_cond_dims,
                n_conv_layers_per_step, affine_model, scaling_fn,
                matrix_decomposition, affine_activation=affine_activation,
                use_partial_padding=self.decoder_use_partial_padding))

    if 'dpm' in include_modules:
        dur_model_config['hparams']['n_speaker_dim'] = n_speaker_dim
        self.dur_pred_layer = get_attribute_prediction_model(
            dur_model_config)

    self.use_unvoiced_bias = False
    self.use_vpred_module = False
    self.ap_use_voiced_embeddings = kwargs.get(
        'ap_use_voiced_embeddings', True)

    if self.decoder_use_unvoiced_bias or self.ap_use_unvoiced_bias:
        assert (unvoiced_bias_activation in {'relu', 'exp'})
        self.use_unvoiced_bias = True
        if unvoiced_bias_activation == 'relu':
            unvbias_nonlin = nn.ReLU()
        elif unvoiced_bias_activation == 'exp':
            unvbias_nonlin = ExponentialClass()
        else:
            # only reachable when assertions are disabled (python -O)
            raise ValueError(
                "unvoiced_bias_activation must be 'relu' or 'exp', "
                "got {!r}".format(unvoiced_bias_activation))
        self.unvoiced_bias_module = nn.Sequential(
            LinearNorm(n_text_dim, 1), unvbias_nonlin)

    # all situations in which the vpred module is necessary
    if (self.ap_use_voiced_embeddings or self.use_unvoiced_bias
            or 'vpred' in include_modules):
        self.use_vpred_module = True

    if self.use_vpred_module:
        if v_model_config is None:
            raise ValueError(
                "v_model_config is required when the voiced predictor is "
                "enabled (voiced embeddings, unvoiced bias or 'vpred')")
        v_model_config['hparams']['n_speaker_dim'] = n_speaker_dim
        self.v_pred_module = get_attribute_prediction_model(v_model_config)
        # 4 embeddings, first two are scales, second two are biases
        if self.ap_use_voiced_embeddings:
            self.v_embeddings = torch.nn.Embedding(4, n_text_dim)

    if 'apm' in include_modules:
        for attr_config in (f0_model_config, energy_model_config):
            hparams = attr_config['hparams']
            hparams['n_speaker_dim'] = n_speaker_dim
            if self.use_first_order_features:
                # value + first-order difference -> two input channels
                hparams['n_in_dim'] = 2
            spline_params = hparams.get('spline_flow_params')
            if spline_params is not None:
                # keep the spline flow input width in sync with n_in_dim
                spline_params['n_in_channels'] = hparams['n_in_dim']

        self.f0_pred_module = get_attribute_prediction_model(
            f0_model_config)
        self.energy_pred_module = get_attribute_prediction_model(
            energy_model_config)
def is_attribute_unconditional(self):
    """
    Tell whether the decoder runs without attribute conditioning, i.e.
    both ``n_f0_dims`` and ``n_energy_avg_dims`` are zero.
    """
    uses_f0 = self.n_f0_dims != 0
    uses_energy = self.n_energy_avg_dims != 0
    return not (uses_f0 or uses_energy)
def encode_speaker(self, spk_ids):
    """Look up speaker embeddings for ``spk_ids``.

    When ``dummy_speaker_embedding`` is set, every id is replaced by 0 so
    all speakers share the first embedding row.
    """
    if self.dummy_speaker_embedding:
        spk_ids = spk_ids * 0
    return self.speaker_embedding(spk_ids)
def encode_text(self, text, in_lens):
    """Embed and encode a batch of token ids.

    Args:
        text: token ids; presumably b x len_text -- TODO confirm.
        in_lens: per-item text lengths, or None to use the encoder's
            ``infer`` path, which takes no lengths.

    Returns:
        ``(text_enc, text_embeddings)``. ``text_embeddings`` is the
        embedding output with dims 1 and 2 swapped; ``text_enc`` is the
        encoder output with dims 1 and 2 swapped.
    """
    text_embeddings = self.embedding(text).transpose(1, 2)
    if in_lens is not None:
        encoded = self.encoder(text_embeddings, in_lens)
    else:
        encoded = self.encoder.infer(text_embeddings)
    return encoded.transpose(1, 2), text_embeddings
def preprocess_context(self, context, speaker_vecs, out_lens=None, f0=None,
                       energy_avg=None):
    """Assemble the conditioning tensor fed to the flow steps.

    The context is grouped with ``self.unfold`` when ``n_group_size > 1``,
    joined with the time-expanded speaker vectors along dim 1 and,
    optionally, passed through the bidirectional context LSTM. ``f0`` and
    ``energy_avg`` (when given) are appended either before the LSTM
    (``context_lstm_w_f0_and_energy`` truthy) or after it (falsy).

    NOTE: with the LSTM disabled and ``context_lstm_w_f0_and_energy``
    truthy, ``f0`` / ``energy_avg`` are not appended at all.
    ``out_lens`` is only read when the LSTM is enabled.
    """
    if self.n_group_size > 1:
        # unfolding zero-padded values
        context = self.unfold(context.unsqueeze(-1))
        if f0 is not None:
            f0 = self.unfold(f0[:, None, :, None])
        if energy_avg is not None:
            energy_avg = self.unfold(energy_avg[:, None, :, None])

    n_steps = context.shape[2]
    speaker_vecs = speaker_vecs[..., None].expand(-1, -1, n_steps)
    attributes = [feat for feat in (f0, energy_avg) if feat is not None]
    pieces = [context, speaker_vecs]

    if self.use_context_lstm:
        if self.context_lstm_w_f0_and_energy:
            pieces = pieces + attributes
        lstm_input = torch.cat(pieces, 1).transpose(1, 2)
        grouped_lens = (out_lens // self.n_group_size).long().cpu()
        packed_input = nn.utils.rnn.pack_padded_sequence(
            lstm_input, grouped_lens,
            batch_first=True, enforce_sorted=False)
        self.context_lstm.flatten_parameters()
        packed_output, _ = self.context_lstm(packed_input)
        padded_output, _ = nn.utils.rnn.pad_packed_sequence(
            packed_output, batch_first=True)
        context_w_spkvec = padded_output.transpose(1, 2)
    else:
        context_w_spkvec = torch.cat(pieces, 1)

    if not self.context_lstm_w_f0_and_energy and attributes:
        context_w_spkvec = torch.cat([context_w_spkvec] + attributes, 1)

    return context_w_spkvec
def enable_inverse_cache(self):
    """Ask every flow step in ``self.flows`` to enable its inverse cache."""
    for step in self.flows:
        step.enable_inverse_cache()
def fold(self, mel):
    """Undo the grouping ("squeeze") done by ``self.unfold(mel.unsqueeze(-1))``.

    Args:
        mel: B x C x T tensor of temporal data

    Returns:
        The tensor folded back with ``self.unfold_params``; its time axis
        is ``n_group_size`` times longer and the trailing singleton
        dimension is squeezed away.
    """
    target_size = (mel.shape[2] * self.n_group_size, 1)
    refolded = nn.functional.fold(
        mel, output_size=target_size, **self.unfold_params)
    return refolded.squeeze(-1)
def binarize_attention(self, attn, in_lens, out_lens):
    """For training purposes only. Binarizes attention with MAS. These will
    no longer receive a gradient.

    Args:
        attn: B x 1 x max_mel_len x max_text_len
        in_lens: per-item text lengths, indexable by batch position
        out_lens: per-item mel lengths, indexable by batch position

    Returns:
        A tensor shaped like ``attn`` on the same device; for each item
        only the ``[:out_len, :in_len]`` window is filled with the MAS
        result, everything else stays zero.
    """
    b_size = attn.shape[0]
    with torch.no_grad():
        attn_cpu = attn.data.cpu().numpy()
        attn_out = torch.zeros_like(attn)
        for ind in range(b_size):
            n_frames = int(out_lens[ind])
            n_tokens = int(in_lens[ind])
            hard_attn = mas(attn_cpu[ind, 0, :n_frames, :n_tokens])
            # attn.device instead of attn.get_device(): the latter is -1
            # for CPU tensors, which is not a valid device argument.
            attn_out[ind, 0, :n_frames, :n_tokens] = torch.tensor(
                hard_attn, device=attn.device)
    return attn_out
def get_first_order_features(self, feats, out_lens, dilation=1):
"""
feats: b x max_length
out_lens: b-dim
"""
# add an extra column
feats_extended_R = torch.cat(
(feats, torch.zeros_like(feats[:, 0:dilation])), dim=1)
feats_extended_L = torch.cat(
(torch.zeros_like(feats[:, 0:dilation]), feats), dim=1)
dfeats_R = feats_extended_R[:, dilation:] - feats
dfeats_L = feats - feats_extended_L[:, 0:-dilation]
return (dfeats_R + dfeats_L) * 0.5
def apply_voice_mask_to_text(self, text_enc, voiced_mask):
    """
    Scale and shift ``text_enc`` with embeddings selected by the voiced
    mask.

    text_enc: b x C x N
    voiced_mask: b x N

    Rows 0/1 of ``self.v_embeddings`` are the voiced/unvoiced scales
    (squashed by a sigmoid) and rows 2/3 the voiced/unvoiced biases
    (squashed to 0.1 * tanh).
    """
    voiced = voiced_mask.unsqueeze(1)
    unvoiced = 1 - voiced
    table = self.v_embeddings.weight

    def row(idx):
        # one embedding row shaped 1 x C x 1 so it broadcasts over b and N
        return table[idx:idx + 1, :, None]

    scale_logits = row(0) * voiced + row(1) * unvoiced
    bias_logits = row(2) * voiced + row(3) * unvoiced
    scale = torch.sigmoid(scale_logits)
    bias = 0.1 * torch.tanh(bias_logits)
    return text_enc * scale + bias
def forward(self, mel, speaker_ids, text, in_lens, out_lens,
            binarize_attention=False, attn_prior=None,
            f0=None, energy_avg=None, voiced_mask=None, p_voiced=None):
    """Training forward pass over the modules named in ``include_modules``.

    Stages, each gated by a tag in ``self.include_modules``:
      * 'atn' / 'dec': soft (optionally binarized) attention between mel
        and text, and the attention-weighted text context.
      * 'dec': mel -> z through the flow steps, with early exits.
      * 'dpm': duration predictor on durations from the hard attention.
      * 'apm': voiced / f0 / energy predictors.

    ``p_voiced`` is accepted but not used in this method. Tensor layouts
    are not fixed here beyond what the indexing implies; the attention is
    treated as B x 1 x mel_len x text_len (see ``binarize_attention``).

    Returns:
        dict with keys 'z_mel', 'log_det_W_list', 'log_s_list',
        'duration_model_outputs', 'f0_model_outputs',
        'energy_model_outputs', 'vpred_model_outputs', 'attn_soft',
        'attn', 'text_embeddings' and 'attn_logprob'. Entries belonging to
        a stage that did not run keep their initial value (None or an
        empty list).
    """
    speaker_vecs = self.encode_speaker(speaker_ids)
    text_enc, text_embeddings = self.encode_text(text, in_lens)

    log_s_list, log_det_W_list, z_mel = [], [], []
    attn = None
    attn_soft = None
    attn_hard = None
    # Bound up front: it is always read when building `outputs`, but was
    # previously only assigned inside the attention branch.
    attn_logprob = None
    if 'atn' in self.include_modules or 'dec' in self.include_modules:
        # make sure to do the alignments before folding
        attn_mask = get_mask_from_lengths(in_lens)[..., None] == 0

        text_embeddings_for_attn = text_embeddings
        if self.use_speaker_emb_for_alignment:
            speaker_vecs_expd = speaker_vecs[:, :, None].expand(
                -1, -1, text_embeddings.shape[2])
            text_embeddings_for_attn = torch.cat(
                (text_embeddings_for_attn, speaker_vecs_expd.detach()), 1)

        # attn_mask should be 1 for unused timesteps in the text tensor
        attn_soft, attn_logprob = self.attention(
            mel, text_embeddings_for_attn, out_lens, attn_mask,
            key_lens=in_lens, attn_prior=attn_prior)

        if binarize_attention:
            attn = self.binarize_attention(attn_soft, in_lens, out_lens)
            attn_hard = attn
            if self.attn_straight_through_estimator:
                # forward value is the hard attention, gradient flows
                # through the soft attention
                attn_hard = attn_soft + (attn_hard - attn_soft).detach()
        else:
            attn = attn_soft

        context = torch.bmm(text_enc, attn.squeeze(1).transpose(1, 2))

    f0_bias = 0
    # unvoiced bias forward pass
    if self.use_unvoiced_bias:
        f0_bias = self.unvoiced_bias_module(context.permute(0, 2, 1))
        f0_bias = -f0_bias[..., 0]
        # the bias only applies to unvoiced frames
        f0_bias = f0_bias * (~voiced_mask.bool()).float()

    # mel decoder forward pass
    if 'dec' in self.include_modules:
        if self.n_group_size > 1:
            # might truncate some frames at the end, but that's ok
            # sometimes referred to as the "squeeze" operation
            # invert this by calling self.fold(mel_or_z)
            mel = self.unfold(mel.unsqueeze(-1))

        # where context is folded
        # mask f0 in case values are interpolated
        if f0 is None:
            f0_aug = None
        elif self.decoder_use_unvoiced_bias:
            f0_aug = f0 * voiced_mask + f0_bias
        else:
            f0_aug = f0 * voiced_mask

        context_w_spkvec = self.preprocess_context(
            context, speaker_vecs, out_lens, f0_aug,
            energy_avg)

        log_s_list, log_det_W_list, z_out = [], [], []
        unfolded_seq_lens = out_lens // self.n_group_size
        for i, flow_step in enumerate(self.flows):
            if i in self.exit_steps:
                # early exit: peel off the first n_early_size channels
                z = mel[:, :self.n_early_size]
                z_out.append(z)
                mel = mel[:, self.n_early_size:]
            mel, log_det_W, log_s = flow_step(
                mel, context_w_spkvec, seq_lens=unfolded_seq_lens)
            log_s_list.append(log_s)
            log_det_W_list.append(log_det_W)

        z_out.append(mel)
        z_mel = torch.cat(z_out, 1)

    # duration predictor forward pass
    duration_model_outputs = None
    if 'dpm' in self.include_modules:
        if attn_hard is None:
            attn_hard = self.binarize_attention(
                attn_soft, in_lens, out_lens)

        # convert hard attention to durations
        attn_hard_reduced = attn_hard.sum(2)[:, 0, :]
        duration_model_outputs = self.dur_pred_layer(
            torch.detach(text_enc),
            torch.detach(speaker_vecs),
            torch.detach(attn_hard_reduced.float()), in_lens)

    # f0, energy, vpred predictors forward pass
    f0_model_outputs = None
    energy_model_outputs = None
    vpred_model_outputs = None
    if 'apm' in self.include_modules:
        if attn_hard is None:
            attn_hard = self.binarize_attention(
                attn_soft, in_lens, out_lens)

        # expand the text encoding in time with the hard attention
        if binarize_attention:
            text_enc_time_expanded = context.clone()
        else:
            text_enc_time_expanded = torch.bmm(
                text_enc, attn_hard.squeeze(1).transpose(1, 2))

        if self.use_vpred_module:
            # unvoiced bias requires voiced mask prediction
            vpred_model_outputs = self.v_pred_module(
                torch.detach(text_enc_time_expanded),
                torch.detach(speaker_vecs),
                torch.detach(voiced_mask), out_lens)

            # affine transform context using voiced mask
            if self.ap_use_voiced_embeddings:
                text_enc_time_expanded = self.apply_voice_mask_to_text(
                    text_enc_time_expanded, voiced_mask)

        # whether to use the unvoiced bias in the attribute predictor
        # circumvent in-place modification
        f0_target = f0.clone()
        if self.ap_use_unvoiced_bias:
            f0_target = torch.detach(f0_target * voiced_mask + f0_bias)
        else:
            f0_target = torch.detach(f0_target)

        # fit to log f0 in f0 predictor
        f0_target[voiced_mask.bool()] = torch.log(
            f0_target[voiced_mask.bool()])
        f0_target = f0_target / 6  # scale to ~ [0, 1] in log space
        energy_avg = energy_avg * 2 - 1  # scale to ~ [-1, 1]

        if self.use_first_order_features:
            df0 = self.get_first_order_features(f0_target, out_lens)
            denergy_avg = self.get_first_order_features(
                energy_avg, out_lens)

            f0_voiced = torch.cat(
                (f0_target[:, None], df0[:, None]), dim=1)
            energy_avg = torch.cat(
                (energy_avg[:, None], denergy_avg[:, None]), dim=1)

            f0_voiced = f0_voiced * 3  # scale to ~ 1 std
            energy_avg = energy_avg * 3  # scale to ~ 1 std
        else:
            f0_voiced = f0_target * 2  # scale to ~ 1 std
            energy_avg = energy_avg * 1.4  # scale to ~ 1 std

        f0_model_outputs = self.f0_pred_module(
            text_enc_time_expanded, torch.detach(speaker_vecs),
            f0_voiced, out_lens)

        energy_model_outputs = self.energy_pred_module(
            text_enc_time_expanded, torch.detach(speaker_vecs),
            energy_avg, out_lens)

    outputs = {'z_mel': z_mel,
               'log_det_W_list': log_det_W_list,
               'log_s_list': log_s_list,
               'duration_model_outputs': duration_model_outputs,
               'f0_model_outputs': f0_model_outputs,
               'energy_model_outputs': energy_model_outputs,
               'vpred_model_outputs': vpred_model_outputs,
               'attn_soft': attn_soft,
               'attn': attn,
               'text_embeddings': text_embeddings,
               'attn_logprob': attn_logprob
               }
    return outputs
def infer(self, speaker_id, text, sigma, sigma_dur=0.8, sigma_f0=0.8,
          sigma_energy=0.8, token_dur_scaling=1.0, token_duration_max=100,
          speaker_id_text=None, speaker_id_attributes=None, dur=None,
          f0=None, energy_avg=None, voiced_mask=None, f0_mean=0.0,
          f0_std=0.0, energy_mean=0.0, energy_std=0.0):
    """Synthesize mel-spectrograms by running the flow in reverse.

    Durations, voiced mask, f0 and energy are sampled from their
    predictors unless passed in explicitly. ``sigma*`` scale the Gaussian
    noise given to the decoder and to each predictor.
    ``speaker_id_text`` / ``speaker_id_attributes`` optionally override
    the speaker used for the duration predictor and for the voiced / f0
    predictors. ``energy_mean`` and ``energy_std`` are accepted but not
    used in this method.

    When ``f0_mean > 0`` the voiced f0 values are re-standardized to
    ``f0_mean`` and ``f0_std`` (the original spread is kept when
    ``f0_std <= 0``).

    All noise is created on the device and with the dtype of the text
    encoding, so inference also works with the model off the CPU.

    Returns:
        dict with keys 'mel', 'dur', 'f0', 'energy_avg' and
        'voiced_mask'.
    """
    batch_size = text.shape[0]
    n_tokens = text.shape[1]
    spk_vec = self.encode_speaker(speaker_id)
    spk_vec_text, spk_vec_attributes = spk_vec, spk_vec
    if speaker_id_text is not None:
        spk_vec_text = self.encode_speaker(speaker_id_text)
    if speaker_id_attributes is not None:
        spk_vec_attributes = self.encode_speaker(speaker_id_attributes)

    txt_enc, txt_emb = self.encode_text(text, None)

    def sample_noise(*shape):
        # standard normal noise colocated with the model activations
        return torch.randn(
            *shape, dtype=txt_enc.dtype, device=txt_enc.device)

    if dur is None:
        # get token durations
        z_dur = sample_noise(batch_size, 1, n_tokens) * sigma_dur

        dur = self.dur_pred_layer.infer(z_dur, txt_enc, spk_vec_text)
        if dur.shape[-1] < txt_enc.shape[-1]:
            to_pad = txt_enc.shape[-1] - dur.shape[2]
            pad_fn = nn.ReplicationPad1d((0, to_pad))
            dur = pad_fn(dur)
        dur = dur[:, 0]
        dur = dur.clamp(0, token_duration_max)
        dur = dur * token_dur_scaling if token_dur_scaling > 0 else dur
        dur = (dur + 0.5).floor().int()  # round to whole frames

    # frames per item, one code path for every batch size
    out_lens = dur.sum(1).long().to(txt_enc.device)
    max_n_frames = int(out_lens.max())

    # get attributes f0, energy, vpred, etc)
    txt_enc_time_expanded = self.length_regulator(
        txt_enc.transpose(1, 2), dur).transpose(1, 2)

    if not self.is_attribute_unconditional():
        # if explicitly modeling attributes
        if voiced_mask is None:
            if self.use_vpred_module:
                # get logits
                voiced_mask = self.v_pred_module.infer(
                    None, txt_enc_time_expanded, spk_vec_attributes)
                voiced_mask = (torch.sigmoid(voiced_mask[:, 0]) > 0.5)
                voiced_mask = voiced_mask.float()

        ap_txt_enc_time_expanded = txt_enc_time_expanded
        # voice mask augmentation only used for attribute prediction
        if self.ap_use_voiced_embeddings:
            ap_txt_enc_time_expanded = self.apply_voice_mask_to_text(
                txt_enc_time_expanded, voiced_mask)

        f0_bias = 0
        # unvoiced bias forward pass
        if self.use_unvoiced_bias:
            f0_bias = self.unvoiced_bias_module(
                txt_enc_time_expanded.permute(0, 2, 1))
            f0_bias = -f0_bias[..., 0]
            f0_bias = f0_bias * (~voiced_mask.bool()).float()

        if f0 is None:
            n_f0_feature_channels = 2 if self.use_first_order_features else 1
            z_f0 = sample_noise(
                batch_size, n_f0_feature_channels, max_n_frames) * sigma_f0
            f0 = self.infer_f0(
                z_f0, ap_txt_enc_time_expanded, spk_vec_attributes,
                voiced_mask, out_lens)[:, 0]

        if f0_mean > 0.0:
            # re-standardize the voiced frames to the requested statistics
            vmask_bool = voiced_mask.bool()
            f0_mu, f0_sigma = f0[vmask_bool].mean(), f0[vmask_bool].std()
            f0[vmask_bool] = (f0[vmask_bool] - f0_mu) / f0_sigma
            f0_std = f0_std if f0_std > 0 else f0_sigma
            f0[vmask_bool] = f0[vmask_bool] * f0_std + f0_mean

        if energy_avg is None:
            n_energy_feature_channels = 2 if self.use_first_order_features else 1
            z_energy_avg = sample_noise(
                batch_size, n_energy_feature_channels,
                max_n_frames) * sigma_energy
            energy_avg = self.infer_energy(
                z_energy_avg, ap_txt_enc_time_expanded, spk_vec,
                out_lens)[:, 0]

        # replication pad, because ungrouping with different group sizes
        # may lead to mismatched lengths
        # (lengths are compared against the first batch item only)
        first_len = int(out_lens[0])
        if energy_avg.shape[1] < first_len:
            to_pad = first_len - energy_avg.shape[1]
            pad_fn = nn.ReplicationPad1d((0, to_pad))
            f0 = pad_fn(f0[None])[0]
            energy_avg = pad_fn(energy_avg[None])[0]
        if f0.shape[1] < first_len:
            to_pad = first_len - f0.shape[1]
            pad_fn = nn.ReplicationPad1d((0, to_pad))
            f0 = pad_fn(f0[None])[0]

        if self.decoder_use_unvoiced_bias:
            context_w_spkvec = self.preprocess_context(
                txt_enc_time_expanded, spk_vec, out_lens,
                f0 * voiced_mask + f0_bias, energy_avg)
        else:
            context_w_spkvec = self.preprocess_context(
                txt_enc_time_expanded, spk_vec, out_lens, f0 * voiced_mask,
                energy_avg)
    else:
        context_w_spkvec = self.preprocess_context(
            txt_enc_time_expanded, spk_vec, out_lens, None,
            None)

    # latent sized from the configured mel channel count rather than a
    # hard-coded 80
    residual = sample_noise(
        batch_size, self.n_mel_channels * self.n_group_size,
        max_n_frames // self.n_group_size) * sigma

    # map from z sample to data
    exit_steps_stack = self.exit_steps.copy()
    mel = residual[:, len(exit_steps_stack) * self.n_early_size:]
    remaining_residual = residual[:, :len(exit_steps_stack) * self.n_early_size]
    unfolded_seq_lens = out_lens // self.n_group_size
    for i, flow_step in enumerate(reversed(self.flows)):
        curr_step = len(self.flows) - i - 1
        mel = flow_step(mel, context_w_spkvec, inverse=True,
                        seq_lens=unfolded_seq_lens)
        if len(exit_steps_stack) > 0 and curr_step == exit_steps_stack[-1]:
            # concatenate the next chunk of z
            exit_steps_stack.pop()
            residual_to_add = remaining_residual[
                :, len(exit_steps_stack) * self.n_early_size:]
            remaining_residual = remaining_residual[
                :, :len(exit_steps_stack) * self.n_early_size]
            mel = torch.cat((residual_to_add, mel), 1)

    if self.n_group_size > 1:
        mel = self.fold(mel)
    if self.do_mel_descaling:
        mel = mel * 2 - 5.5

    return {'mel': mel,
            'dur': dur,
            'f0': f0,
            'energy_avg': energy_avg,
            'voiced_mask': voiced_mask
            }
def infer_f0(self, residual, txt_enc_time_expanded, spk_vec,
             voiced_mask=None, lens=None):
    """Sample f0 from the f0 predictor and undo the training-time scaling.

    A 2-D ``voiced_mask`` gets a channel axis inserted at dim 1; when no
    mask is given, frames with rescaled f0 > 0 count as voiced. With
    ``ap_pred_log_f0`` set, voiced frames are exponentiated so the
    returned f0 is linear. Unvoiced frames are set to 0.
    """
    f0 = self.f0_pred_module.infer(
        residual, txt_enc_time_expanded, spk_vec, lens)

    if voiced_mask is not None and voiced_mask.dim() == 2:
        voiced_mask = voiced_mask[:, None]

    # constants
    predicts_log_f0 = self.ap_pred_log_f0
    if predicts_log_f0:
        if self.use_first_order_features:
            # keep only the first channel (the f0 value itself)
            f0 = f0[:, 0:1, :] / 3
        else:
            f0 = f0 / 2
        f0 = f0 * 6
    else:
        f0 = f0 / 6
        f0 = f0 / 640

    voiced_mask = (f0 > 0.0) if voiced_mask is None else voiced_mask.bool()

    # due to grouping, f0 might be 1 frame short
    voiced_mask = voiced_mask[:, :, :f0.shape[-1]]
    if predicts_log_f0:
        # if variable is set, decoder sees linear f0
        f0[voiced_mask] = torch.exp(f0[voiced_mask])
    f0[~voiced_mask] = 0.0
    return f0
def infer_energy(self, residual, txt_enc_time_expanded, spk_vec, lens):
    """Sample energy from the energy predictor and undo the scaling.

    The prediction is divided by 3 (first-order features) or 1.4, then
    mapped with ``(x + 1) / 2``.
    """
    raw_energy = self.energy_pred_module.infer(
        residual, txt_enc_time_expanded, spk_vec, lens)

    # magic constants
    std_scale = 3 if self.use_first_order_features else 1.4
    return (raw_energy / std_scale + 1) / 2
def remove_norms(self):
    """Removes spectral and weight norms from the model. Call before
    inference.

    Every sub-module is tried in turn for a spectral norm on
    'weight_hh_l0' and 'weight_hh_l0_reverse' and for a weight norm.
    Modules that do not carry the norm are skipped: the torch removal
    helpers signal "not found" with ``ValueError``, which is the only
    exception ignored here.
    """
    removal_attempts = (
        (nn.utils.remove_spectral_norm, {'name': 'weight_hh_l0'},
         "Removed spectral norm from {}"),
        (nn.utils.remove_spectral_norm, {'name': 'weight_hh_l0_reverse'},
         "Removed spectral norm from {}"),
        (nn.utils.remove_weight_norm, {},
         "Removed wnorm from {}"),
    )
    for name, module in self.named_modules():
        for remove_fn, remove_kwargs, message in removal_attempts:
            try:
                remove_fn(module, **remove_kwargs)
            except ValueError:
                # this module does not have that norm applied
                continue
            print(message.format(name))