# coding: utf-8
import numpy as np
from nltk.tokenize import word_tokenize
from lstm_vae import create_lstm_vae, inference
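
# NOTE: the local lstm_vae package is expected to expose create_lstm_vae(input_dim,
# batch_size=..., intermediate_dim=..., latent_dim=...) returning the full VAE plus
# its encoder, generator, and single-step decoder models, and an inference module
# with a decode_sequence() helper; this interface is inferred from the calls below.
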
def get_text_data(data_path, num_samples=1000):
    # Vectorize the data: read tab-separated lines, tokenize the source side,
    # and build one-hot encoder/decoder tensors.
    input_texts = []
    input_characters = set(["\t"])

    with open(data_path, "r", encoding="utf-8") as f:
        lines = f.read().lower().split("\n")

    for line in lines[: min(num_samples, len(lines) - 1)]:
        # Keep only the first (source-language) column; extra columns, if any, are ignored.
        input_text, *_ = line.split("\t")
        input_text = word_tokenize(input_text)
        input_text.append("<end>")
        input_texts.append(input_text)
        for char in input_text:
            if char not in input_characters:
                input_characters.add(char)

    input_characters = sorted(list(input_characters))
    num_encoder_tokens = len(input_characters)
    max_encoder_seq_length = max([len(txt) for txt in input_texts]) + 1

    print("Number of samples:", len(input_texts))
    print("Number of unique input tokens:", num_encoder_tokens)
    print("Max sequence length for inputs:", max_encoder_seq_length)

    input_token_index = dict([(char, i) for i, char in enumerate(input_characters)])
    reverse_input_char_index = dict((i, char) for char, i in input_token_index.items())

    # One-hot tensors: the decoder input is the encoder input shifted right by one
    # step and prefixed with the "\t" start token (teacher forcing).
    encoder_input_data = np.zeros((len(input_texts), max_encoder_seq_length, num_encoder_tokens), dtype="float32")
    decoder_input_data = np.zeros((len(input_texts), max_encoder_seq_length, num_encoder_tokens), dtype="float32")
    for i, input_text in enumerate(input_texts):
        decoder_input_data[i, 0, input_token_index["\t"]] = 1.0
        for t, char in enumerate(input_text):
            encoder_input_data[i, t, input_token_index[char]] = 1.0
            decoder_input_data[i, t + 1, input_token_index[char]] = 1.0

    return (max_encoder_seq_length, num_encoder_tokens, input_characters,
            input_token_index, reverse_input_char_index,
            encoder_input_data, decoder_input_data)
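
# Quick shape check (hypothetical example, not part of the training script): both
# returned arrays are one-hot tensors of shape
# (num_samples, max_encoder_seq_length, num_encoder_tokens).
#
#     max_len, n_tokens, chars, char2id, id2char, enc_x, dec_x = get_text_data("data/fra.txt", 100)
#     assert enc_x.shape == dec_x.shape == (len(enc_x), max_len, n_tokens)
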

if __name__ == "__main__":
    from argparse import ArgumentParser

    p = ArgumentParser()
    p.add_argument("--input", default="data/fra.txt", type=str)
    p.add_argument("--num_samples", default=3000, type=int)
    p.add_argument("--batch_size", default=1, type=int)
    p.add_argument("--epochs", default=40, type=int)
    p.add_argument("--latent_dim", default=191, type=int)
    p.add_argument("--inter_dim", default=353, type=int)
    p.add_argument("--samples", default=5, type=int)
    args = p.parse_args()

    timesteps_max, enc_tokens, characters, char2id, id2char, x, x_decoder = get_text_data(
        num_samples=args.num_samples, data_path=args.input)

    print(x.shape, "Creating model...")
    input_dim, timesteps = x.shape[-1], x.shape[-2]
    batch_size, latent_dim = args.batch_size, args.latent_dim
    intermediate_dim, epochs = args.inter_dim, args.epochs

    vae, enc, gen, stepper = create_lstm_vae(input_dim,
                                             batch_size=batch_size,
                                             intermediate_dim=intermediate_dim,
                                             latent_dim=latent_dim)

    print("Training model...")
    vae.fit([x, x_decoder], x, epochs=epochs, verbose=1)
    print("Fitted, predicting...")

    def decode(s):
        return inference.decode_sequence(s, gen, stepper, input_dim, char2id, id2char, timesteps_max)

    # Interpolate in latent space between pairs of randomly chosen training sentences.
    for _ in range(args.samples):
        id_from = np.random.randint(0, x.shape[0] - 1)
        id_to = np.random.randint(0, x.shape[0] - 1)

        # Encode both sentences and draw one latent sample for each
        # (z = mean + scale * epsilon, using the two encoder outputs).
        m_from, std_from = enc.predict([[x[id_from]]])
        m_to, std_to = enc.predict([[x[id_to]]])
        seq_from = np.random.normal(size=(latent_dim,))
        seq_from = m_from + std_from * seq_from
        seq_to = np.random.normal(size=(latent_dim,))
        seq_to = m_to + std_to * seq_to

        # Decode seven evenly spaced points on the line from seq_from to seq_to.
        print("== \t", " ".join([id2char[j] for j in np.argmax(x[id_from], axis=1)]), "==")
        for v in np.linspace(0, 1, 7):
            print("%.2f\t" % (1 - v), decode(v * seq_to + (1 - v) * seq_from))
        print("== \t", " ".join([id2char[j] for j in np.argmax(x[id_to], axis=1)]), "==")