-
Notifications
You must be signed in to change notification settings - Fork 7
/
predict.py
executable file
·39 lines (30 loc) · 1.43 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
import json
import tensorflow as tf
import coref_model as cm
import util
def _parse_paths(argv):
    """Return the ``(input_filename, output_filename)`` pair from *argv*.

    Both paths are required and are read from ``argv[2]`` and ``argv[3]``.
    Exits with a usage message when they are missing, instead of failing
    later with a bare IndexError. Extra trailing arguments are tolerated.
    """
    # NOTE(review): argv[1] is not read here; it is presumably consumed by
    # util.initialize_from_env() -- confirm against util.py.
    if len(argv) < 4:
        sys.exit("Usage: {} <name> <input.jsonlines> <output.jsonlines>".format(argv[0]))
    return argv[2], argv[3]


def _iter_examples(input_file):
    """Yield one parsed JSON object per non-blank line of *input_file*.

    The file is iterated lazily rather than via readlines(), so large inputs
    are not held in memory all at once. Blank lines (e.g. a trailing newline
    at end of file) are skipped instead of crashing json.loads.
    """
    for line in input_file:
        if line.strip():
            yield json.loads(line)


def _predict_clusters(model, session, example):
    """Run *model* on one *example* and return its predicted clusters."""
    tensorized_example = model.tensorize_example(example, is_training=False)
    feed_dict = dict(zip(model.input_tensors, tensorized_example))
    # model.predictions yields seven outputs; only the last four are needed here.
    (_, _, _, top_span_starts, top_span_ends,
     top_antecedents, top_antecedent_scores) = session.run(model.predictions, feed_dict=feed_dict)
    predicted_antecedents = model.get_predicted_antecedents(top_antecedents, top_antecedent_scores)
    predicted_clusters, _ = model.get_predicted_clusters(
        top_span_starts, top_span_ends, predicted_antecedents)
    return predicted_clusters


def main():
    """Annotate every example of a .jsonlines file with predicted clusters.

    Reads examples from the input file given on the command line. Each
    example is written back unchanged, except for an added
    ``"predicted_clusters"`` key, one JSON object per line in the output file.
    """
    # Validate the command line before doing any expensive setup.
    input_filename, output_filename = _parse_paths(sys.argv)
    config = util.initialize_from_env()
    model = cm.CorefModel(config)
    with tf.Session() as session:
        model.restore(session)
        # Open the input first so a bad input path does not truncate the output.
        with open(input_filename) as input_file, open(output_filename, "w") as output_file:
            # example_num counts decoded (non-blank) examples, starting at 0.
            for example_num, example in enumerate(_iter_examples(input_file)):
                example["predicted_clusters"] = _predict_clusters(model, session, example)
                output_file.write(json.dumps(example))
                output_file.write("\n")
                if example_num % 100 == 0:
                    print("Decoded {} examples.".format(example_num + 1))


if __name__ == "__main__":
    main()