# models.py

import chainer
import chainer.functions as F
import chainer.links as L

from modules import VQ
from modules import WaveNet
from utils import ExponentialMovingAverage


class Encoder(chainer.Chain):
    """Downsampling encoder: a stack of six strided convolutions."""

    def __init__(self, d):
        super(Encoder, self).__init__()
        with self.init_scope():
            # Each layer uses a (4, 1) kernel with stride (2, 1) and
            # padding (1, 0), halving the time axis at every step.
            self.conv1 = L.Convolution2D(1, d, (4, 1), (2, 1), (1, 0))
            self.conv2 = L.Convolution2D(d, d, (4, 1), (2, 1), (1, 0))
            self.conv3 = L.Convolution2D(d, d, (4, 1), (2, 1), (1, 0))
            self.conv4 = L.Convolution2D(d, d, (4, 1), (2, 1), (1, 0))
            self.conv5 = L.Convolution2D(d, d, (4, 1), (2, 1), (1, 0))
            self.conv6 = L.Convolution2D(d, d, (4, 1), (2, 1), (1, 0))

    def __call__(self, x):
        h = F.relu(self.conv1(x))
        h = F.relu(self.conv2(h))
        h = F.relu(self.conv3(h))
        h = F.relu(self.conv4(h))
        h = F.relu(self.conv5(h))
        # No activation on the final layer: z is the continuous latent
        # that the vector quantizer snaps to its nearest codebook entry.
        z = self.conv6(h)
        return z


class VAE(chainer.Chain):
    def __init__(self, d, k, n_loop, n_layer, filter_size, quantize,
                 residual_channels, dilated_channels, skip_channels,
                 use_logistic, n_mixture, log_scale_min,
                 n_speaker, embed_channels, dropout_zero_rate, ema_mu, beta):
        super(VAE, self).__init__()
        self.beta = beta
        self.ema_mu = ema_mu
        self.quantize = quantize
        with self.init_scope():
            self.enc = Encoder(d)
            self.vq = VQ(k)
            dec = WaveNet(
                n_loop, n_layer, filter_size, quantize, residual_channels,
                dilated_channels, skip_channels, use_logistic,
                global_conditioned=True, local_conditioned=True,
                n_mixture=n_mixture, log_scale_min=log_scale_min,
                n_speaker=n_speaker, embed_dim=embed_channels,
                local_condition_dim=d, upsample_factor=64, use_deconv=False,
                dropout_zero_rate=dropout_zero_rate)
            # With ema_mu < 1 the decoder is wrapped so that an exponential
            # moving average of its weights is also maintained; otherwise
            # the plain decoder is registered directly.
            if ema_mu < 1:
                self.dec = ExponentialMovingAverage(dec, ema_mu)
            else:
                self.dec = dec

    def __call__(self, raw, one_hot, speaker, quantized):
        # forward
        z = self.enc(raw)
        e = self.vq(z)
        # Quantize a detached copy of z so that loss2 updates only the
        # codebook, not the encoder.
        e_ = self.vq(chainer.Variable(z.data))
        global_cond = speaker
        local_cond = e
        y = self.dec(one_hot, global_cond, local_cond)
        # calculate loss
        loss1 = F.softmax_cross_entropy(y, quantized)
        loss2 = F.mean((chainer.Variable(z.data) - e_) ** 2)
        loss3 = self.beta * F.mean((z - chainer.Variable(e.data)) ** 2)
        loss = loss1 + loss2 + loss3
        chainer.reporter.report(
            {'loss1': loss1, 'loss2': loss2, 'loss3': loss3, 'loss': loss},
            self)
        return loss1, loss2, loss3
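
    # The three terms above follow the VQ-VAE objective: loss1 is the
    # reconstruction term, loss2 trains the codebook against a detached z,
    # and loss3 is the beta-weighted commitment term that keeps encoder
    # outputs close to their codes. Gradients reach the encoder through
    # loss1 only if the VQ link applies the straight-through estimator
    # internally, which the modules.VQ import is assumed to do here.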

    def generate(self, raw, speaker, use_ema):
        # initialize and encode
        output = self.xp.zeros(raw.shape[2])
        # Choose the raw decoder or its EMA shadow for sampling.
        if self.ema_mu < 1:
            if use_ema:
                dec = self.dec.ema
            else:
                dec = self.dec.target
        else:
            dec = self.dec
        with chainer.using_config('enable_backprop', False):
            z = self.enc(raw)
            e = self.vq(z)
            global_cond = dec.embed_global_cond(speaker)
            local_cond = dec.upsample_local_cond(e)
            one_hot = chainer.Variable(self.xp.zeros(
                self.quantize, dtype=self.xp.float32).reshape((1, -1, 1, 1)))
            dec.initialize(1, global_cond)
            length = local_cond.shape[2]

        # generate
        for i in range(length - 1):
            with chainer.using_config('enable_backprop', False):
                out = dec.generate(one_hot, local_cond[:, :, i:i + 1])
            zeros = self.xp.zeros_like(one_hot.array)
            # Sample the next value from the softmax over quantize levels,
            # then feed it back in as the next one-hot input.
            value = self.xp.random.choice(
                self.quantize, size=1, p=F.softmax(out).array[0, :, 0, 0])
            output[i:i + 1] = value
            zeros[:, value, :, :] = 1
            one_hot = chainer.Variable(zeros)
        return output
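

# A minimal smoke-test sketch of how the VAE class might be driven. The
# hyperparameter values and tensor shapes below are illustrative assumptions
# (they are not taken from this repository's training script), and running
# it still requires the repo-local modules/utils imports above to resolve.
if __name__ == '__main__':
    import numpy as np

    model = VAE(
        d=256, k=512, n_loop=3, n_layer=10, filter_size=2, quantize=256,
        residual_channels=256, dilated_channels=512, skip_channels=256,
        use_logistic=False, n_mixture=10, log_scale_min=-40.0,
        n_speaker=109, embed_channels=128, dropout_zero_rate=0.0,
        ema_mu=0.9999, beta=0.25)

    # One 4096-sample clip: raw waveform, its one-hot mu-law encoding, a
    # speaker id, and integer targets for the cross-entropy loss. All
    # shapes are assumed, not verified against modules.WaveNet.
    raw = np.zeros((1, 1, 4096, 1), dtype=np.float32)
    one_hot = np.zeros((1, 256, 4096, 1), dtype=np.float32)
    one_hot[:, 128] = 1
    speaker = np.zeros(1, dtype=np.int32)
    quantized = np.full((1, 4096, 1), 128, dtype=np.int32)

    loss1, loss2, loss3 = model(raw, one_hot, speaker, quantized)
    print(loss1.array, loss2.array, loss3.array)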