tests.py
# -*- coding: utf-8 -*-
import tensorflow as tf
import numpy as np
from tensorflow.python.platform import test
import attention_cell
from sample_cell import SampleCell
from ind_cat_cell import IndCatCell
import time

TIME_STEPS = 100000  # TODO: adjust this
# Largest allowed recurrent weight; recurrent_max ** TIME_STEPS == 2.
# (Float division so the bound is also correct under Python 2.)
recurrent_max = pow(2, 1.0 / TIME_STEPS)
class ModelTest(test.TestCase):

    def sample_cell_test(self):
        # Run one step of SampleCell on concatenated [logits, x] input and
        # fetch its categorical sample and normalized coarse output.
        with self.test_session() as sess:
            x = tf.placeholder(tf.float32, shape=(1, 2))
            logits = tf.placeholder(tf.float32, shape=(1, 256))
            init_state = tf.placeholder(tf.float32, shape=(1, 256))
            cell = SampleCell(recurrent_max)
            out, _ = cell(tf.concat([logits, x], 1), init_state)
            sess.run([tf.global_variables_initializer()])
            cat, coarse = sess.run(
                [cell.cat, cell.coarse_norm],
                {x.name: np.array([[0.5, 0.1]]),
                 logits.name: np.random.normal(size=(1, 256)),
                 init_state.name: np.random.normal(size=(1, 256))})
    def batchwise_conv_test(self):
        # Check that the two batchwise convolution implementations agree
        # on random data, and print the wall-clock time of each.
        with self.test_session() as sess:
            filt_shape = (32, 10, 80, 32)
            seq_shape = (32, 100, 80)
            filt = tf.placeholder(tf.float32, shape=filt_shape)
            seq = tf.placeholder(tf.float32, shape=seq_shape)
            conv = attention_cell.batchwise_conv(seq, filt)
            conv2 = attention_cell.batchwise_conv_2(seq, filt)
            filt_d = np.random.normal(size=filt_shape)
            seq_d = np.random.normal(size=seq_shape)

            start_time = time.time()
            result1 = sess.run([conv],
                               {filt.name: filt_d, seq.name: seq_d})
            print(time.time() - start_time)

            start_time = time.time()
            result2 = sess.run([conv2],
                               {filt.name: filt_d, seq.name: seq_d})
            print(time.time() - start_time)

            self.assertAllClose(result1, result2, atol=1e-04)
    def indcat_cell_test(self):
        # Run one step of IndCatCell on concatenated [prev, x] input.
        with self.test_session() as sess:
            prev = tf.placeholder(tf.float32, shape=(1, 32))
            x = tf.placeholder(tf.float32, shape=(1, 2))
            init_state = tf.placeholder(tf.float32, shape=(1, 256))
            cell = IndCatCell(256, recurrent_max)
            out, _ = cell(tf.concat([prev, x], 1), init_state)
            sess.run([tf.global_variables_initializer()])
            result = sess.run([out],
                              {x.name: np.array([[0.5, 0.1]]),
                               prev.name: np.random.normal(size=(1, 32)),
                               init_state.name: np.random.normal(size=(1, 256))})
            print(result)
    def attention_cell_test(self):
        # Run one step of AttentionCell attending over a random feature sequence.
        with self.test_session() as sess:
            x = tf.placeholder(tf.float32, shape=(1, 2))
            features = tf.placeholder(tf.float32, shape=(1, 7, 5))  # batch, length, dict size
            init_state = tf.placeholder(tf.float32, shape=(1, 800))
            cell = attention_cell.AttentionCell(32, features, recurrent_max)
            out, _ = cell(x, init_state)
            sess.run([tf.global_variables_initializer()])
            result = sess.run([out],
                              {x.name: np.array([[0.5, 0.1]]),
                               features.name: np.random.normal(size=(1, 7, 5)),
                               init_state.name: np.random.normal(size=(1, 800))})
            print(result)


tests = ModelTest()
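
# Note: the methods above do not follow unittest's `test_*` naming convention,
# so `test.main()` would not discover them, and the bare ModelTest()
# instantiation above does not run the checks. A minimal sketch of one way to
# invoke them explicitly (an assumption, not part of the original file):
if __name__ == "__main__":
    import unittest

    suite = unittest.TestSuite(
        ModelTest(name) for name in ("sample_cell_test",
                                     "batchwise_conv_test",
                                     "indcat_cell_test",
                                     "attention_cell_test"))
    unittest.TextTestRunner(verbosity=2).run(suite)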