forked from ofirnachum/sequence_gan
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsimple_demo.py
110 lines (84 loc) · 2.8 KB
/
simple_demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
from __future__ import print_function
# Assigned explicitly because this statement follows the `from __future__`
# import above; a bare string literal placed here would not become the
# module docstring.
__doc__ = """Simple demo on toy data.
In this demo, our toy data is the set of 'valley' sequences. These are
sequences of numbers which are first decreasing and then increasing.
We train our generator via curriculum learning. We start by only training
it to predict the next token given a groundtruth sequence via cross-entropy
loss. As our generator learns, we start to train it directly against the
discriminator.
The discriminator always learns at the standard rate (half real examples
and half artificial examples generated by the generator).
"""
import model
import train
import numpy as np
import tensorflow as tf
import random
# --- Reproducibility ---------------------------------------------------------
SEED = 88  # used by main() to seed both `random` and `np.random`

# --- Vocabulary and sequence shape -------------------------------------------
NUM_EMB = 4  # vocabulary size; toy tokens are drawn from range(NUM_EMB)
START_TOKEN = 0  # reserved token; never appears in a valid valley sequence
SEQ_LENGTH = 5  # number of tokens in each toy sequence

# --- Model size (passed positionally to model.GRU) ---------------------------
EMB_DIM = 5  # presumably the token-embedding width -- confirm in model.GRU
HIDDEN_DIM = 10  # presumably the GRU hidden-state width -- confirm in model.GRU

# --- Optimisation schedule ---------------------------------------------------
# NOTE(review): learning rate is scaled by sequence length; the rationale is
# not visible in this file.
LEARNING_RATE = 0.01 * SEQ_LENGTH
EPOCH_ITER = 1000  # iterations per epoch, handed to train.train_epoch
TRAIN_ITER = 100000  # generator/discriminator alternating; main() runs TRAIN_ITER // EPOCH_ITER epochs
CURRICULUM_RATE = 0.03  # how quickly to move from supervised training to unsupervised (per-epoch drop)
D_STEPS = 2  # how many times to train the discriminator per generator step
def get_trainable_model():
    """Build the model that this demo trains.

    Every size and token setting comes from the module-level constants, so
    retuning the toy task only requires editing those constants.
    """
    gru_args = (NUM_EMB, EMB_DIM, HIDDEN_DIM, SEQ_LENGTH, START_TOKEN)
    return model.GRU(*gru_args, learning_rate=LEARNING_RATE)
def verify_sequence(seq):
    """Report whether *seq* is a 'valley' sequence.

    A valley never contains START_TOKEN. It is non-increasing until its
    first rise, and non-decreasing from that rise onward. Equal neighbours
    are allowed in both phases, and an empty sequence counts as a valley.

    Returns:
        True if *seq* is a valley, otherwise False.
    """
    descending = True
    # NUM_EMB exceeds every token in range(NUM_EMB), so the first in-vocab
    # token can never register as a rise.
    last = NUM_EMB
    for token in seq:
        if token == START_TOKEN:
            return False
        if not descending:
            # Once we have started climbing, any drop breaks the valley.
            if token < last:
                return False
        elif token > last:
            # First rise: switch from the descending to the ascending phase.
            descending = False
        last = token
    return True
def get_random_sequence():
    """Return a random valley sequence of length SEQ_LENGTH.

    Draws SEQ_LENGTH tokens uniformly, with replacement, from
    range(NUM_EMB) minus START_TOKEN. The first ``pivot + 1`` draws are
    sorted in descending order and the rest in ascending order, and the
    two runs are concatenated. A non-increasing run followed by a
    non-decreasing run is always a valley.
    """
    vocab = set(range(NUM_EMB))
    vocab.discard(START_TOKEN)
    vocab = list(vocab)
    # Index of the last token in the descending run; always in [0, SEQ_LENGTH).
    pivot = int(random.random() * SEQ_LENGTH)
    draws = [random.choice(vocab) for _ in range(SEQ_LENGTH)]
    descending_run = sorted(draws[:pivot + 1], reverse=True)
    ascending_run = sorted(draws[pivot + 1:])
    return descending_run + ascending_run
def test_sequence_definition():
    """Check that 1000 sequences from ``get_random_sequence`` pass ``verify_sequence``.

    The failure is raised explicitly instead of with an ``assert``
    statement. ``assert`` is stripped when Python runs with ``-O``, which
    would silently disable this pre-training sanity check.

    Raises:
        AssertionError: if a generated sequence is rejected. The message
            includes the offending sequence.
    """
    for _ in range(1000):
        seq = get_random_sequence()
        if not verify_sequence(seq):
            raise AssertionError(
                'get_random_sequence produced an invalid valley sequence: %r'
                % (seq,))
def main():
    """Train the demo model on random valley sequences with a curriculum.

    Seeds Python's `random` and NumPy's RNG, builds the model, and runs
    ``TRAIN_ITER // EPOCH_ITER`` epochs of ``train.train_epoch``. The
    proportion of supervised training starts at 1.0 and drops by
    ``CURRICULUM_RATE`` each epoch, floored at 0.0.
    """
    random.seed(SEED)
    np.random.seed(SEED)
    # NOTE(review): TensorFlow's graph-level seed is not set here, so weight
    # initialisation may still vary between runs -- confirm whether
    # model.GRU seeds itself.
    trainable_model = get_trainable_model()
    num_epochs = TRAIN_ITER // EPOCH_ITER
    # Use the session as a context manager so it is closed, and its
    # resources released, whether training finishes or raises.
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print('training')
        for epoch in range(num_epochs):
            print('epoch', epoch)
            proportion_supervised = max(0.0, 1.0 - CURRICULUM_RATE * epoch)
            train.train_epoch(
                sess, trainable_model, EPOCH_ITER,
                proportion_supervised=proportion_supervised,
                g_steps=1, d_steps=D_STEPS,
                next_sequence=get_random_sequence,
                verify_sequence=verify_sequence)
if __name__ == '__main__':
    # Validate the toy-data definition before training, so a mismatch
    # between the generator and the validator fails immediately rather
    # than after a long run.
    test_sequence_definition()
    main()