import tensorflow as tf


class WaveNet(object):
    '''Implements the WaveNet network for generative audio.

    Usage (with the architecture as in the DeepMind paper):
        dilations = [2**i for i in range(N)] * M
        channels = 2**8  # Quantize to 256 possible amplitude values.
        net = WaveNet(batch_size, channels, dilations)
        loss = net.loss(input_batch)
    '''

    def __init__(self, batch_size, channels, dilations, filter_width=2):
        self.batch_size = batch_size
        self.channels = channels
        self.dilations = dilations
        self.filter_width = filter_width

    def _create_dilation_layer(self, input_batch, layer_index, dilation):
        '''Adds a single causal dilated convolution layer.'''
        # The filter width can be configured as a hyperparameter.
        wf = tf.Variable(tf.truncated_normal(
            [1, self.filter_width, self.channels, self.channels],
            stddev=0.2),
            name="filter")
        wg = tf.Variable(tf.truncated_normal(
            [1, self.filter_width, self.channels, self.channels],
            stddev=0.2),
            name="gate")
        # TensorFlow has an operator for convolution with holes.
        tmp1 = tf.nn.atrous_conv2d(input_batch, wf,
                                   rate=dilation,
                                   padding="SAME",
                                   name="conv_f")
        tmp2 = tf.nn.atrous_conv2d(input_batch, wg,
                                   rate=dilation,
                                   padding="SAME",
                                   name="conv_g")
        out = tf.tanh(tmp1) * tf.sigmoid(tmp2)
        # Shift output to the right by the dilation count so that only
        # current and past values can influence the prediction.
        out = tf.slice(out, [0] * 4, [-1, -1, tf.shape(out)[2] - dilation, -1])
        out = tf.pad(out, [[0, 0], [0, 0], [dilation, 0], [0, 0]])
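        # Illustration (added note, not in the original file): with
        # dilation=4, the slice drops the last 4 time steps and the pad
        # re-inserts 4 zeros at the front, so the value emitted at step t
        # is the raw convolution output from step t - 4 and no future
        # sample can leak into the prediction.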
        w = tf.Variable(tf.truncated_normal(
            [1, 1, self.channels, self.channels], stddev=0.2),
            name="dense")
        transformed = tf.nn.conv2d(out, w, strides=[1] * 4,
                                   padding="SAME", name="dense")
        tf.histogram_summary('layer{}_filter'.format(layer_index), wf)
        tf.histogram_summary('layer{}_gate'.format(layer_index), wg)
        tf.histogram_summary('layer{}_weights'.format(layer_index), w)
        return transformed, input_batch + transformed

    def _preprocess(self, audio):
        '''Quantizes waveform amplitudes.'''
        with tf.name_scope('preprocessing'):
            mu = self.channels - 1
            # Perform mu-law companding transformation (ITU-T, 1988).
            magnitude = tf.log(1 + mu * tf.abs(audio)) / tf.log(1. + mu)
            signal = tf.sign(audio) * magnitude
            quantized = tf.cast((signal + 1) / 2 * mu, tf.int32)
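            # Worked example (added note, not in the original file): with
            # channels = 256 (mu = 255), an amplitude of 0.5 compands to
            # sign(0.5) * log(1 + 255 * 0.5) / log(256) ~= 0.876, which the
            # line above then maps from [-1, 1] to the integer bin
            # int((0.876 + 1) / 2 * 255) = 239.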
        return quantized

    def _create_network(self, input_batch):
        outputs = []
        current_layer = input_batch

        # Add all defined dilation layers.
        with tf.name_scope('dilated_stack'):
            for layer_index, dilation in enumerate(self.dilations):
                with tf.name_scope('layer{}'.format(layer_index)):
                    output, current_layer = self._create_dilation_layer(
                        current_layer,
                        layer_index,
                        dilation=dilation)
                    outputs.append(output)

        with tf.name_scope('postprocessing'):
            # Perform (+) -> ReLU -> 1x1 conv -> ReLU -> 1x1 conv to
            # postprocess the output.
            w1 = tf.Variable(tf.truncated_normal(
                [1, 1, self.channels, self.channels], stddev=0.3),
                name="postprocess1")
            w2 = tf.Variable(tf.truncated_normal(
                [1, 1, self.channels, self.channels], stddev=0.3),
                name="postprocess2")
            tf.histogram_summary('postprocess1_weights', w1)
            tf.histogram_summary('postprocess2_weights', w2)

            # We collect the skip connections from the outputs of each
            # layer, adding them all up here.
            total = sum(outputs)
            transformed1 = tf.nn.relu(total)
            conv1 = tf.nn.conv2d(transformed1, w1, [1] * 4, padding="SAME")
            transformed2 = tf.nn.relu(conv1)
            conv2 = tf.nn.conv2d(transformed2, w2, [1] * 4, padding="SAME")
        return conv2

    def loss(self, input_batch, name='wavenet'):
        with tf.variable_scope(name):
            input_batch = self._preprocess(input_batch)

            # One-hot encode waveform amplitudes, so we can define the
            # network as a categorical distribution over possible
            # amplitudes.
            with tf.name_scope('one_hot_encode'):
                encoded = tf.one_hot(input_batch, depth=self.channels,
                                     dtype=tf.float32)
                encoded = tf.reshape(encoded,
                                     [self.batch_size, 1, -1, self.channels])
            raw_output = self._create_network(encoded)

            with tf.name_scope('loss'):
                # Shift the original input left by one sample, so that each
                # output sample has to predict the next input sample.
                shifted = tf.slice(encoded, [0, 0, 1, 0],
                                   [-1, -1, tf.shape(encoded)[2] - 1, -1])
                shifted = tf.pad(shifted, [[0, 0], [0, 0], [0, 1], [0, 0]])
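                # Alignment note (added, not in the original file): after
                # the shift, shifted[:, :, t, :] holds input sample t + 1,
                # so the network output at step t is trained to predict the
                # next sample; the zero-padded final target is a
                # simplification at the sequence boundary.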
                prediction = tf.reshape(raw_output, [-1, self.channels])
                loss = tf.nn.softmax_cross_entropy_with_logits(
                    prediction,
                    tf.reshape(shifted, [-1, self.channels]))
                reduced_loss = tf.reduce_mean(loss)
                tf.scalar_summary('loss', reduced_loss)
                return reduced_loss
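

# Minimal usage sketch (added for illustration; it mirrors the class
# docstring but is not part of the original file). The placeholder shape
# and the dilation schedule are assumptions, and building the graph
# requires the TF 0.x API that the summary calls above are written against.
if __name__ == '__main__':
    batch_size = 1
    channels = 2**8
    dilations = [2**i for i in range(5)] * 2
    net = WaveNet(batch_size, channels, dilations)
    # Raw audio in [-1, 1], one amplitude per time step.
    input_batch = tf.placeholder(tf.float32, shape=[None])
    loss = net.loss(input_batch)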