-
Notifications
You must be signed in to change notification settings - Fork 275
/
masked_conv.py
57 lines (44 loc) · 1.52 KB
/
masked_conv.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
"""Implement masked 1d convolution with max-pooling"""
__author__ = "Guillaume Genthial"
from six.moves import reduce
import tensorflow as tf
def masked_conv1d_and_max(t, weights, filters, kernel_size):
    """Applies a 1d convolution followed by a masked max-pooling over time.

    All leading dimensions of ``t`` are flattened so the computation runs on a
    rank-3 tensor ``[d1 * ... * dn-2, dn-1, dn]``. Positions where ``weights``
    is False are zeroed before the convolution and excluded from the max.
    The leading dimensions are restored on the output.

    Parameters
    ----------
    t : tf.Tensor
        A floating-point tensor with at least 3 dimensions
        [d1, d2, ..., dn-1, dn]. The last dimension must be statically known
        (the convolution kernel is created from it).
    weights : tf.Tensor of tf.bool
        A Tensor of shape [d1, d2, ..., dn-1]; True marks valid positions.
    filters : int
        number of filters
    kernel_size : int
        kernel size for the temporal convolution

    Returns
    -------
    tf.Tensor
        A tensor of shape [d1, d2, ..., dn-2, filters]. Rows in which every
        position is masked out are all zeros.

    Raises
    ------
    ValueError
        If the rank of ``t`` is not statically known or is lower than 3.
    """
    ndims = t.shape.ndims
    if ndims is None or ndims < 3:
        # Without this check the failure is an opaque TypeError from
        # `None - 2` or from `reduce` over an empty list.
        raise ValueError(
            "masked_conv1d_and_max expects a tensor of static rank >= 3, "
            "got rank {}".format(ndims))

    # Get shape and parameters
    shape = tf.shape(t)
    leading_dims = [shape[i] for i in range(ndims - 2)]
    dim1 = reduce(lambda x, y: x * y, leading_dims)  # flattened leading dims
    dim2 = shape[-2]    # axis that is convolved over and max-pooled
    dim3 = t.shape[-1]  # static channel size, needed by conv1d

    # Reshape the mask to broadcast over channels and cast it to the input
    # dtype (tf.to_float always produced float32, which broke the products
    # below for float16/float64 inputs).
    weights = tf.reshape(weights, shape=[dim1, dim2, 1])
    weights = tf.cast(weights, t.dtype)

    # Reshape input and zero out masked positions
    t = tf.reshape(t, shape=[dim1, dim2, dim3])
    t *= weights

    # Apply convolution, then zero the outputs at masked positions
    t_conv = tf.layers.conv1d(t, filters, kernel_size, padding='same')
    t_conv *= weights

    # Masked max: overwrite masked positions with the per-row minimum. That
    # minimum is taken after masked slots were zeroed, so it is <= every
    # valid value; masked slots therefore never win the max over valid ones,
    # and a fully masked row stays at 0.
    t_conv += (1. - weights) * tf.reduce_min(t_conv, axis=-2, keepdims=True)
    t_max = tf.reduce_max(t_conv, axis=-2)

    # Restore the leading dimensions: [d1, ..., dn-2, filters]
    final_shape = leading_dims + [filters]
    t_max = tf.reshape(t_max, shape=final_shape)
    return t_max