model.py
import sugartensor as tf


num_blocks = 3     # dilated blocks
num_dim = 128      # latent dimension


#
# logit-calculating graph using atrous (dilated) convolution
#
def get_logit(x, voca_size):

    # residual block
    def res_block(tensor, size, rate, block, dim=num_dim):

        with tf.sg_context(name='block_%d_%d' % (block, rate)):

            # filter convolution
            conv_filter = tensor.sg_aconv1d(size=size, rate=rate, act='tanh', bn=True, name='conv_filter')

            # gate convolution
            conv_gate = tensor.sg_aconv1d(size=size, rate=rate, act='sigmoid', bn=True, name='conv_gate')

            # gated output: element-wise product of filter and gate
            out = conv_filter * conv_gate

            # 1x1 convolution on the gated output
            out = out.sg_conv1d(size=1, dim=dim, act='tanh', bn=True, name='conv_out')

            # residual output and skip output
            return out + tensor, out

    # expand input dimension to the latent size
    with tf.sg_context(name='front'):
        z = x.sg_conv1d(size=1, dim=num_dim, act='tanh', bn=True, name='conv_in')

    # dilated conv block loop
    skip = 0  # sum of skip connections
    for i in range(num_blocks):
        for r in [1, 2, 4, 8, 16]:
            z, s = res_block(z, size=7, rate=r, block=i)
            skip += s

    # final logit layers
    with tf.sg_context(name='logit'):
        logit = (skip
                 .sg_conv1d(size=1, act='tanh', bn=True, name='conv_1')
                 .sg_conv1d(size=1, dim=voca_size, name='conv_2'))

    return logit
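

#
# Usage sketch (illustrative, not part of the original file): a minimal example
# of feeding a feature tensor through get_logit(). It assumes a TF1-style
# placeholder exposed through sugartensor's re-exported TensorFlow namespace;
# the batch size, per-frame feature dimension, and vocabulary size below are
# hypothetical values chosen for the example, not taken from this file.
#
if __name__ == '__main__':

    batch_size = 16      # hypothetical batch size
    num_features = 20    # hypothetical per-frame feature dimension (e.g. MFCC)
    voca_size = 28       # hypothetical vocabulary size (letters + space + blank)

    # (batch, time, feature) input tensor with a variable-length time axis
    x = tf.placeholder(dtype=tf.sg_floatx, shape=(batch_size, None, num_features))

    # (batch, time, voca_size) logits, e.g. for a CTC loss or beam-search decoder
    logit = get_logit(x, voca_size=voca_size)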