-
Notifications
You must be signed in to change notification settings - Fork 24
/
evaluate.py
52 lines (42 loc) · 2.05 KB
/
evaluate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
from utils import CharacterTable, transform
from utils import restore_model, decode_sequences
from utils import read_text, tokenize
# --- Evaluation configuration -------------------------------------------
# Fraction/probability of artificial spelling noise injected per token by
# `transform` (passed as its `error_rate` kwarg) — exact semantics defined
# in utils.transform.
error_rate = 0.6
# Passed to `decode_sequences`; presumably feeds encoder input reversed —
# must match how the checkpoint was trained. TODO confirm against utils.
reverse = True
# Path to the trained seq2seq Keras checkpoint restored by `restore_model`.
model_path = './models/seq2seq.h5'
# Recurrent hidden-state size handed to `restore_model`; must match the
# architecture of the saved weights.
hidden_size = 512
# Decoding strategy name forwarded to `decode_sequences` (greedy argmax).
sample_mode = 'argmax'
# Directory containing the plain-text corpora listed in `books`.
data_path = './data'
# Corpus files used to build the vocabulary / character tables.
books = ['nietzsche.txt', 'pride_and_prejudice.txt', 'shakespeare.txt', 'war_and_peace.txt']
# Held-out sentence: its tokens are artificially misspelled and then decoded
# to demonstrate the model's corrections.
test_sentence = 'The rabbit-hole went straight on like a tunnel for some way, and then dipped suddenly down, so suddenly that Alice had not a moment to think about stopping herself before she found herself falling down a very deep well.'
if __name__ == '__main__':
    # Build the training vocabulary: unique, non-empty tokens drawn from
    # the full concatenated corpus.
    corpus = read_text(data_path, books)
    unique_tokens = [tok for tok in set(tokenize(corpus)) if tok]

    # `maxlen` is the length of the longest vocabulary token, plus two
    # slots reserved for the SOS and EOS characters.
    maxlen = max(len(tok) for tok in unique_tokens) + 2

    # Noisy/clean training sequences — only needed here to derive the
    # input/target character alphabets below.
    train_encoder, train_decoder, _ = transform(
        unique_tokens, maxlen, error_rate=error_rate, shuffle=False)

    # Tokenize the demo sentence and corrupt it with the same noise model.
    sentence_tokens = [tok for tok in tokenize(test_sentence) if tok]
    nb_tokens = len(sentence_tokens)
    misspelled_tokens, _, target_tokens = transform(
        sentence_tokens, maxlen, error_rate=error_rate, shuffle=False)

    # Character tables over every character seen in the space-joined
    # training sequences (the join puts ' ' itself into each alphabet).
    input_ctable = CharacterTable(set(' '.join(train_encoder)))
    target_ctable = CharacterTable(set(' '.join(train_decoder)))

    # Restore the trained encoder/decoder pair and decode the noisy tokens.
    encoder_model, decoder_model = restore_model(model_path, hidden_size)
    input_tokens, target_tokens, decoded_tokens = decode_sequences(
        misspelled_tokens, target_tokens, input_ctable, target_ctable,
        maxlen, reverse, encoder_model, decoder_model, nb_tokens,
        sample_mode=sample_mode, random=False)

    # Show the corrupted input, the model's correction, and the reference.
    print('-')
    print('Input sentence: ', ' '.join(input_tokens))
    print('-')
    print('Decoded sentence:', ' '.join(decoded_tokens))
    print('-')
    print('Target sentence: ', ' '.join(target_tokens))